Prakyath01 committed
Commit de72d5d · verified · 1 Parent(s): 3ddf607

Update app.py

Files changed (1)
  1. app.py +119 -117
app.py CHANGED
@@ -44,14 +44,19 @@ def scrape_page(name, url):
    return None

def build_or_load_kb():
-    embedding_model = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")
-
    if os.path.isdir(PERSIST_DIR):
-        vectordb = Chroma(embedding_function=embedding_model, persist_directory=PERSIST_DIR)
-        raw = vectordb._collection.get(include=["documents", "metadatas"])
        chunks = [
            Document(page_content=doc, metadata=meta)
-            for doc, meta in zip(raw["documents"], raw["metadatas"])
        ]
        return vectordb, chunks

@@ -64,165 +69,162 @@ def build_or_load_kb():
    splitter = RecursiveCharacterTextSplitter(chunk_size=900, chunk_overlap=200)
    chunks = splitter.split_documents(docs)

-    vectordb = Chroma.from_documents(chunks, embedding_model, persist_directory=PERSIST_DIR)
-    vectordb.persist()
-
    return vectordb, chunks

vectordb, chunks = build_or_load_kb()

bm25_corpus = [doc.page_content.split() for doc in chunks]
bm25 = BM25Okapi(bm25_corpus)
reranker = CrossEncoder("cross-encoder/ms-marco-MiniLM-L-12-v2")

retriever = vectordb.as_retriever(
    search_type="similarity_score_threshold",
-    search_kwargs={"k": 8, "score_threshold": 0.35},
)

-def hybrid_search(query, top_k=5):
-    vector_results = retriever.invoke(query)
-    tokenized_query = query.lower().split()
-    bm25_scores = bm25.get_scores(tokenized_query)
-    bm25_ranked = sorted(zip(bm25_scores, chunks), key=lambda x: x[0], reverse=True)
-    bm25_results = [d for _, d in bm25_ranked[:top_k]]
-    combined = vector_results + bm25_results
    seen = set()
    unique = []
    for d in combined:
-        key = (d.metadata.get("doc_id"), d.page_content[:80])
        if key not in seen:
            seen.add(key)
            unique.append(d)
    if not unique:
        return []
-    pairs = [(query, doc.page_content) for doc in unique]
    scores = reranker.predict(pairs)
-    ranked = sorted(zip(scores, unique), key=lambda x: x[0], reverse=True)[:top_k]
-    for s, doc in ranked:
-        doc.metadata["rerank_score"] = float(s)
-    return [doc for _, doc in ranked]

def call_llm(prompt):
-    url = "https://openrouter.ai/api/v1/chat/completions"
    api_key = os.getenv("OPENROUTER_API_KEY")
    if not api_key:
-        return "⚠ Missing API key.\nGroundedness: 0%"
-    res = requests.post(url, headers={
-        "Authorization": f"Bearer {api_key}",
-        "HTTP-Referer": "https://huggingface.co/",
-        "X-Title": "Kubernetes RAG Assistant"
-    }, json={
-        "model": "meta-llama/llama-3.1-8b-instruct",
-        "messages": [{"role": "user", "content": prompt}],
-        "max_tokens": 400,
-        "temperature": 0
-    }).json()
-    return res["choices"][0]["message"]["content"]
-
-def build_context(query, history):
    docs = hybrid_search(query)
    if not docs:
        return "", [], []
-    context, sources, scores = "", [], []
    for i, d in enumerate(docs, start=1):
-        label = f"[{i}]"
-        context += f"{label} {d.page_content[:900]}\nSource: {d.metadata['url']}\n\n"
-        sources.append(f"{label} → {d.metadata['url']}")
        scores.append(d.metadata["rerank_score"])
-    return context, sources, scores

-def classify_query(q):
    q=q.lower()
-    if "how" in q: return "how-to"
-    if "error" in q: return "debug"
-    return "general"

-def init_metrics():
-    return {"q":[], "lat":[], "tok":[], "g":[], "r":[], "c":[], "t":[]}

-def answer_question(query, history, metrics):
-    if metrics is None or metrics == {}: metrics = init_metrics()
-    start = time.time()
-    ctx, sources, scores = build_context(query, history)
    if not ctx:
-        reply="Not in docs.\nGroundedness: 0%"
-        history.append((query, reply))
        return history,"",metrics
-    prompt=f"""
-Use ONLY context. Every sentence must end with citation [n].
-Answer:
-Question: {query}
Context:
{ctx}
-Groundedness must be in final line as: Groundedness: XX%
-"""
-    answer=call_llm(prompt)
-    latency=time.time()-start
-    grounded=int(re.search(r"Groundedness:\s*(\d+)%", answer).group(1)) if "Groundedness" in answer else 0
-    cites=len(set(re.findall(r"\[(\d+)\]", answer)))
-    avg_score=sum(scores)/len(scores)
-    tokens=len(answer.split())+len(prompt.split())
-    alert="⚠ Low support.\n\n" if grounded<70 or cites==0 else ""
-    final=alert+answer+"\n\n---\nSources:\n"+"\n".join(sources)
-    history.append((query,final))
-    metrics["q"].append(query)
    metrics["lat"].append(latency)
    metrics["tok"].append(tokens)
    metrics["g"].append(grounded)
-    metrics["r"].append(avg_score)
-    metrics["c"].append(cites)
-    metrics["t"].append(classify_query(query))
    return history,"",metrics

def render(metrics):
-    rows=[[i+1,metrics["q"][i],round(metrics["lat"][i],3),
-          metrics["tok"][i],metrics["g"][i],
-          round(metrics["r"][i],3),metrics["c"][i],metrics["t"][i]]
-         for i in range(len(metrics["q"]))]
-    avg_lat=sum(metrics["lat"])/len(metrics["lat"])
-    avg_g=sum(metrics["g"])/len(metrics["g"])
-    avg_tok=sum(metrics["tok"])/len(metrics["tok"])
-    return rows,avg_lat,avg_g,avg_tok
-
-def charts(metrics):
-    df=pd.DataFrame({
-        "Latency":metrics["lat"],
-        "Groundedness":metrics["g"],
-        "Tokens":metrics["tok"],
-        "Type":metrics["t"]
-    })
-    fig_l,ax=plt.subplots();ax.plot(df["Latency"]);ax.set_title("Latency");ax.set_xlabel("#");ax.set_ylabel("s")
-    fig_g,ax=plt.subplots();ax.plot(df["Groundedness"]);ax.set_title("Groundedness");ax.set_xlabel("#");ax.set_ylabel("%")
-    fig_t,ax=plt.subplots();ax.plot(df["Tokens"]);ax.set_title("Tokens");ax.set_xlabel("#");ax.set_ylabel("count")
-    fig_p,ax=plt.subplots();df["Type"].value_counts().plot.pie(ax=ax,autopct="%1.1f%");ax.set_ylabel("");ax.set_title("Query Types")
-    return fig_l,fig_g,fig_t,fig_p
-
-def export_csv(metrics):
-    df=pd.DataFrame(metrics)
-    path="analytics.csv";df.to_csv(path,index=False);return path
-
-def clear_all(): return [],"",init_metrics()
-
-metrics_state=gr.State(init_metrics())
-
-with gr.Blocks() as app:
    gr.Markdown("# ☸ Kubernetes RAG Assistant")
    with gr.Tab("Chat"):
-        chat=gr.Chatbot()
-        user_in=gr.Textbox(label="Ask about Kubernetes")
-        clear=gr.Button("Clear")
-        user_in.submit(answer_question,[user_in,chat,metrics_state],[chat,user_in,metrics_state])
-        clear.click(clear_all,outputs=[chat,user_in,metrics_state])
    with gr.Tab("Analytics"):
-        table=gr.Dataframe(headers=["ID","Query","Latency","Tokens","Grounded","Rerank","Citations","Type"])
-        avgL=gr.Number(label="Avg Latency");avgG=gr.Number(label="Avg Grounded");avgT=gr.Number(label="Avg Tokens")
-        p1,p2,p3,p4=gr.Plot(),gr.Plot(),gr.Plot(),gr.Plot()
-        refresh=gr.Button("Refresh")
-        export=gr.Button("Export CSV")
-        file=gr.File()
        refresh.click(render,[metrics_state],[table,avgL,avgG,avgT])
-        refresh.click(charts,[metrics_state],[p1,p2,p3,p4])
-        export.click(export_csv,[metrics_state],[file])

app.launch()
 
app.py (new version):

    return None

def build_or_load_kb():
+    embedding_model = HuggingFaceEmbeddings(
+        model_name="sentence-transformers/all-MiniLM-L6-v2"
+    )
+
    if os.path.isdir(PERSIST_DIR):
+        vectordb = Chroma(
+            embedding_function=embedding_model,
+            persist_directory=PERSIST_DIR
+        )
+        data = vectordb._collection.get(include=["documents", "metadatas"])
        chunks = [
            Document(page_content=doc, metadata=meta)
+            for doc, meta in zip(data["documents"], data["metadatas"])
        ]
        return vectordb, chunks

    splitter = RecursiveCharacterTextSplitter(chunk_size=900, chunk_overlap=200)
    chunks = splitter.split_documents(docs)

+    vectordb = Chroma.from_documents(
+        chunks,
+        embedding_model,
+        persist_directory=PERSIST_DIR
+    )
    return vectordb, chunks
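
Note on the rewritten build_or_load_kb: the explicit vectordb.persist() call was dropped, which matches recent Chroma releases that write to persist_directory automatically. A minimal sanity check that the on-disk knowledge base reloads (a sketch; the import paths and the "chroma_db" value are assumptions rather than values taken from this diff, since PERSIST_DIR is defined earlier in app.py):

# Sketch: verify the persisted Chroma collection reloads with content.
from langchain_huggingface import HuggingFaceEmbeddings
from langchain_chroma import Chroma

emb = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")
db = Chroma(embedding_function=emb, persist_directory="chroma_db")  # assumed PERSIST_DIR
print(db._collection.count())  # > 0 means chunks were loaded from disk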
 
vectordb, chunks = build_or_load_kb()

bm25_corpus = [doc.page_content.split() for doc in chunks]
bm25 = BM25Okapi(bm25_corpus)
+
reranker = CrossEncoder("cross-encoder/ms-marco-MiniLM-L-12-v2")

retriever = vectordb.as_retriever(
    search_type="similarity_score_threshold",
+    search_kwargs={"k": 8, "score_threshold": 0.4},
)
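
The lexical half of hybrid_search is plain BM25. A self-contained sketch of how BM25Okapi scores a tokenized query against a tokenized corpus (toy documents; the real code scores against the chunks built above):

from rank_bm25 import BM25Okapi

# Tokenized the same way as bm25_corpus above: str.split().
corpus = [
    "kubectl get pods lists pods in the current namespace".split(),
    "a deployment manages replicasets and rolling updates".split(),
]
bm25 = BM25Okapi(corpus)
print(bm25.get_scores("how do i list pods".split()))  # one score per document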
 
+def hybrid_search(query):
+    vresults = retriever.invoke(query)
+    tokens = query.lower().split()
+    bm_scores = bm25.get_scores(tokens)
+    bm_ranked = sorted(zip(bm_scores, chunks), key=lambda x: x[0], reverse=True)
+    bmresults = [d for _, d in bm_ranked[:5]]
+
+    combined = vresults + bmresults
    seen = set()
    unique = []
    for d in combined:
+        key = (d.metadata.get("doc_id"), d.page_content[:50])
        if key not in seen:
            seen.add(key)
            unique.append(d)
+
    if not unique:
        return []
+
+    pairs = [(query, d.page_content) for d in unique]
    scores = reranker.predict(pairs)
+    ranked = sorted(zip(scores, unique), key=lambda x: x[0], reverse=True)[:5]
+
+    for s, d in ranked:
+        d.metadata["rerank_score"] = float(s)
+
+    return [d for _, d in ranked]
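
The reranking stage in isolation: CrossEncoder.predict scores (query, passage) pairs jointly, which is how hybrid_search orders the merged candidates. A sketch with the same model name and toy passages:

from sentence_transformers import CrossEncoder

reranker = CrossEncoder("cross-encoder/ms-marco-MiniLM-L-12-v2")
pairs = [
    ("how do I scale a deployment", "kubectl scale sets the replica count of a deployment"),
    ("how do I scale a deployment", "a ConfigMap stores non-confidential key-value data"),
]
scores = reranker.predict(pairs)  # higher score = more relevant pair
best_score, best_pair = max(zip(scores, pairs), key=lambda sp: sp[0])
print(best_pair[1])  # the passage the reranker prefers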
 
def call_llm(prompt):
    api_key = os.getenv("OPENROUTER_API_KEY")
    if not api_key:
+        return "⚠️ Missing OPENROUTER_API_KEY environment variable.\nGroundedness: 0%"
+
+    try:
+        res = requests.post(
+            "https://openrouter.ai/api/v1/chat/completions",
+            headers={
+                "Authorization": f"Bearer {api_key}",
+                "HTTP-Referer": "https://huggingface.co/",
+                "X-Title": "Kubernetes RAG Assistant"
+            },
+            json={
+                "model": "meta-llama/llama-3.1-8b-instruct",
+                "messages": [{"role": "user", "content": prompt}],
+                "max_tokens": 300,
+                "temperature": 0.2
+            }
+        )
+        res.raise_for_status()
+        return res.json()["choices"][0]["message"]["content"]
+    except Exception as e:
+        return f"⚠️ LLM Error: {e}\nGroundedness: 0%"
+
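
One caveat the new call_llm still carries: requests.post is issued without a timeout, so a stalled upstream call can block the Gradio worker indefinitely. A variant with the same endpoint and payload plus a timeout (a sketch, not part of this commit):

import os
import requests

def call_llm_with_timeout(prompt, timeout=60):
    # Identical OpenRouter request to call_llm above; the timeout makes it
    # fail fast instead of hanging the UI.
    res = requests.post(
        "https://openrouter.ai/api/v1/chat/completions",
        headers={"Authorization": f"Bearer {os.environ['OPENROUTER_API_KEY']}"},
        json={
            "model": "meta-llama/llama-3.1-8b-instruct",
            "messages": [{"role": "user", "content": prompt}],
            "max_tokens": 300,
            "temperature": 0.2,
        },
        timeout=timeout,  # seconds
    )
    res.raise_for_status()
    return res.json()["choices"][0]["message"]["content"]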
+def build_context(query):
    docs = hybrid_search(query)
    if not docs:
        return "", [], []
+    ctx, srcs, scores = "", [], []
    for i, d in enumerate(docs, start=1):
+        ctx += f"[{i}] {d.page_content[:900]}\nSource: {d.metadata['url']}\n\n"
+        srcs.append(f"[{i}] {d.metadata['url']}")
        scores.append(d.metadata["rerank_score"])
+    return ctx, srcs, scores

+def init_metrics():
+    return {"q":[], "lat":[], "tok":[], "g":[], "cit":[], "r":[], "type":[]}
+
+def classify(q):
    q=q.lower()
+    return "how-to" if "how" in q else ("debug" if "error" in q else "general")
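
classify routes on bare substrings, which is cheap but crude; concretely (a self-contained copy of the one-liner above):

def classify(q):
    q = q.lower()
    return "how-to" if "how" in q else ("debug" if "error" in q else "general")

assert classify("How do I create a pod?") == "how-to"
assert classify("CrashLoopBackOff error in my pod") == "debug"
assert classify("What is a StatefulSet?") == "general"
assert classify("however, the pod restarts") == "how-to"  # substring match misfires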
 
+def answer(q, history, metrics):
+    if metrics is None: metrics = init_metrics()
+    start = time.time()
+
+    ctx, srcs, scores = build_context(q)
    if not ctx:
+        txt = "Not in docs.\nGroundedness: 0%"
+        history.append((q, txt))
        return history,"",metrics
+
+    prompt = f"""Use context ONLY. Cite every sentence as [n].
+User question: {q}
+
Context:
{ctx}
+Groundedness MUST appear as: Groundedness: XX%"""
+    txt = call_llm(prompt)
+
+    latency = time.time() - start
+    grounded = int(re.search(r"Groundedness:\s*(\d+)%", txt).group(1)) if "Groundedness" in txt else 0
+    tokens = len(txt.split())
+    cites = len(set(re.findall(r"\[(\d+)\]", txt)))
+    avg = sum(scores)/len(scores)
+
+    final = txt+"\n\nSources:\n"+"\n".join(srcs)
+    history.append((q, final))
+
+    metrics["q"].append(q)
    metrics["lat"].append(latency)
    metrics["tok"].append(tokens)
    metrics["g"].append(grounded)
+    metrics["cit"].append(cites)
+    metrics["r"].append(avg)
+    metrics["type"].append(classify(q))
+
    return history,"",metrics
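
A robustness note on the groundedness parse in answer: guarding on the substring does not guarantee re.search matches, e.g. when the model echoes the literal "Groundedness: XX%" from the prompt, and .group(1) then raises AttributeError. A safer parse (sketch):

import re

def parse_groundedness(text):
    # Match first, then branch on the match object, so an echoed
    # "Groundedness: XX%" degrades to 0 instead of crashing.
    m = re.search(r"Groundedness:\s*(\d+)%", text)
    return int(m.group(1)) if m else 0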
 
def render(metrics):
+    if len(metrics["q"])==0: return [],0,0,0
+    rows=[[
+        i+1, metrics["q"][i], round(metrics["lat"][i],3),
+        metrics["tok"][i], metrics["g"][i],
+        round(metrics["r"][i],2), metrics["cit"][i], metrics["type"][i]
+    ] for i in range(len(metrics["q"]))]
+    avgL=sum(metrics["g"])/len(metrics["g"])
+    avgG=sum(metrics["lat"])/len(metrics["lat"])
+    avgT=sum(metrics["tok"])/len(metrics["tok"])
+    return rows,avgL,avgG,avgT
+
+metrics_state = gr.State(init_metrics())
+
+with gr.Blocks(title="Kubernetes RAG Assistant") as app:
    gr.Markdown("# ☸ Kubernetes RAG Assistant")
    with gr.Tab("Chat"):
+        chat = gr.Chatbot()
+        inp = gr.Textbox(label="Ask anything about Kubernetes")
+        clear = gr.Button("Reset")
+        inp.submit(answer,[inp,chat,metrics_state],[chat,inp,metrics_state])
+        clear.click(lambda: ([], "", init_metrics()), None, [chat,inp,metrics_state])
+
    with gr.Tab("Analytics"):
+        table = gr.DataFrame(headers=["ID","Query","Latency","Tokens","Grounded","Relevance","Citations","Type"])
+        avgL = gr.Number(label="Avg Groundedness")
+        avgG = gr.Number(label="Avg Latency")
+        avgT = gr.Number(label="Avg Tokens")
+        refresh = gr.Button("Update Dashboard")
        refresh.click(render,[metrics_state],[table,avgL,avgG,avgT])

app.launch()
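
Finally, the Chat tab keeps the tuple-style gr.Chatbot history. Newer Gradio releases deprecate tuples in favor of type="messages"; a sketch of the equivalent (assumes a recent Gradio version, not part of this commit):

import gradio as gr

with gr.Blocks() as demo:
    chat = gr.Chatbot(type="messages")  # history: list of {"role", "content"} dicts

def append_turn(history, user_msg, assistant_msg):
    # Messages-format equivalent of history.append((q, final)) above.
    return history + [
        {"role": "user", "content": user_msg},
        {"role": "assistant", "content": assistant_msg},
    ]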