Prakyath01 commited on
Commit
eab6d5a
·
verified ·
1 Parent(s): fd8c579

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +100 -138
app.py CHANGED
@@ -16,6 +16,7 @@ from langchain_community.vectorstores import Chroma
16
  from rank_bm25 import BM25Okapi
17
  from sentence_transformers import CrossEncoder
18
 
 
19
  PERSIST_DIR = "k8s_chroma_db"
20
 
21
  URLS = {
@@ -32,16 +33,18 @@ URLS = {
32
  }
33
 
34
 
35
- # ----------------- SCRAPING + KB ----------------- #
36
 
37
  def scrape_page(name, url):
38
  try:
39
- r = requests.get(url, timeout=20)
40
- r.raise_for_status()
41
- soup = BeautifulSoup(r.text, "html.parser")
42
  content = soup.find("div", class_="td-content")
 
43
  if not content:
44
  return None
 
45
  text = content.get_text(separator="\n").strip()
46
  return Document(page_content=text, metadata={"doc_id": name, "url": url})
47
  except Exception as e:
@@ -50,48 +53,48 @@ def scrape_page(name, url):
50
 
51
 
52
  def build_or_load_kb():
 
53
  embedding_model = HuggingFaceEmbeddings(
54
  model_name="sentence-transformers/all-MiniLM-L6-v2"
55
  )
56
 
57
- # If DB exists, load it
58
  if os.path.isdir(PERSIST_DIR):
59
- print("[INFO] Loading existing Chroma DB")
60
  vectordb = Chroma(
61
  embedding_function=embedding_model,
62
  persist_directory=PERSIST_DIR,
63
  )
64
  raw = vectordb._collection.get(include=["documents", "metadatas"])
65
  chunks = [
66
- Document(page_content=doc, metadata=meta)
67
- for doc, meta in zip(raw["documents"], raw["metadatas"])
68
  ]
69
  return vectordb, chunks
70
 
71
- # Else: scrape + build
72
- print("[INFO] No DB found, scraping docs...")
73
  docs = []
74
  for name, url in URLS.items():
75
- d = scrape_page(name, url)
76
- if d:
77
- docs.append(d)
78
  print(f"[INFO] Scraped {len(docs)} docs")
79
 
80
  splitter = RecursiveCharacterTextSplitter(chunk_size=900, chunk_overlap=200)
81
  chunks = splitter.split_documents(docs)
82
 
83
- vectordb = Chroma.from_documents(
84
- chunks, embedding_model, persist_directory=PERSIST_DIR
85
- )
 
86
  return vectordb, chunks
87
 
88
 
89
  vectordb, chunks = build_or_load_kb()
90
 
91
- # ----------------- HYBRID SEARCH ----------------- #
92
 
93
- bm25_corpus = [doc.page_content.split() for doc in chunks]
94
- bm25 = BM25Okapi(bm25_corpus)
 
95
  reranker = CrossEncoder("cross-encoder/ms-marco-MiniLM-L-12-v2")
96
 
97
  retriever = vectordb.as_retriever(
@@ -103,41 +106,40 @@ retriever = vectordb.as_retriever(
103
  def hybrid_search(query, top_k=5):
104
  vector_results = retriever.invoke(query)
105
 
106
- tokenized_query = query.lower().split()
107
- bm25_scores = bm25.get_scores(tokenized_query)
108
- bm25_ranked = sorted(zip(bm25_scores, chunks), key=lambda x: x[0], reverse=True)
109
- bm25_results = [d for _, d in bm25_ranked[:top_k]]
110
 
111
- combined = vector_results + bm25_results
112
  seen = set()
113
- unique = []
114
- for d in combined:
115
- key = (d.metadata.get("doc_id"), d.page_content[:80])
116
  if key not in seen:
117
  seen.add(key)
118
- unique.append(d)
119
 
120
- if not unique:
121
  return []
122
 
123
- pairs = [(query, doc.page_content) for doc in unique]
124
- scores = reranker.predict(pairs)
125
- ranked = sorted(zip(scores, unique), key=lambda x: x[0], reverse=True)[:top_k]
126
 
127
  for s, doc in ranked:
128
  doc.metadata["rerank_score"] = float(s)
 
129
  return [doc for _, doc in ranked]
130
 
131
 
132
- # ----------------- LLM CALL ----------------- #
133
 
134
- def call_llm(prompt: str) -> str:
135
  api_key = os.getenv("OPENROUTER_API_KEY")
136
  if not api_key:
137
- return "⚠️ Missing OPENROUTER_API_KEY in Space secrets.\nGroundedness: 0%"
138
 
139
  try:
140
- r = requests.post(
141
  "https://openrouter.ai/api/v1/chat/completions",
142
  headers={
143
  "Authorization": f"Bearer {api_key}",
@@ -147,36 +149,20 @@ def call_llm(prompt: str) -> str:
147
  json={
148
  "model": "meta-llama/llama-3.1-8b-instruct",
149
  "messages": [{"role": "user", "content": prompt}],
150
- "max_tokens": 400,
151
  "temperature": 0.0,
 
152
  },
153
- timeout=60,
154
  )
155
- r.raise_for_status()
156
- data = r.json()
157
  return data["choices"][0]["message"]["content"]
158
  except Exception as e:
159
- print("[ERROR] LLM:", e)
160
- return f"⚠️ LLM error: {e}\nGroundedness: 0%"
161
-
162
-
163
- # ----------------- CONTEXT + METRICS ----------------- #
164
-
165
- def build_context(query: str):
166
- docs = hybrid_search(query)
167
- if not docs:
168
- return "", [], []
169
 
170
- context, sources, scores = "", [], []
171
- for i, d in enumerate(docs, start=1):
172
- label = f"[{i}]"
173
- context += f"{label} {d.page_content[:900]}\nSource: {d.metadata['url']}\n\n"
174
- sources.append(f"{label} → {d.metadata['url']}")
175
- scores.append(d.metadata["rerank_score"])
176
- return context, sources, scores
177
 
 
178
 
179
- def classify_query(q: str) -> str:
180
  q = q.lower()
181
  if "how" in q:
182
  return "how-to"
@@ -185,134 +171,110 @@ def classify_query(q: str) -> str:
185
  return "general"
186
 
187
 
188
- def init_metrics():
189
- return {"q": [], "lat": [], "tok": [], "g": [], "r": [], "c": [], "t": []}
190
-
191
 
192
- # global analytics, no gr.State
193
- METRICS = init_metrics()
194
 
195
-
196
- # ----------------- CHAT HANDLER ----------------- #
197
 
198
  def answer_question(query, history):
199
- global METRICS
200
- if METRICS is None:
201
- METRICS = init_metrics()
202
-
203
  start = time.time()
204
- ctx, sources, scores = build_context(query)
205
 
206
- if not ctx:
207
- reply = "Not in docs or insufficient context.\nGroundedness: 0%"
208
- history.append((query, reply))
209
- return history, ""
 
 
 
 
 
 
 
 
 
 
 
210
 
211
  prompt = f"""
212
- Use ONLY the context below to answer.
213
- Every sentence MUST end with a citation like [1].
214
 
215
  Question: {query}
216
 
217
  Context:
218
  {ctx}
219
 
220
- At the end add a line: Groundedness: XX%
221
  """
 
222
  answer = call_llm(prompt)
223
  latency = time.time() - start
224
 
225
- # robust groundedness parsing
226
  grounded = 0
227
  m = re.search(r"Groundedness:\s*(\d+)%", answer)
228
  if m:
229
- try:
230
- grounded = int(m.group(1))
231
- except ValueError:
232
- grounded = 0
233
 
234
  cites = len(set(re.findall(r"\[(\d+)\]", answer)))
235
- avg_score = sum(scores) / len(scores) if scores else 0.0
236
- tokens = len(answer.split()) + len(prompt.split())
237
-
238
- alert = ""
239
- if grounded < 70 or cites == 0:
240
- alert = "⚠️ Low support from docs; please verify in official Kubernetes docs.\n\n"
241
-
242
- final = alert + answer + "\n\n---\nSources:\n" + "\n".join(sources)
243
 
244
- history.append((query, final))
245
 
246
  METRICS["q"].append(query)
247
  METRICS["lat"].append(latency)
248
- METRICS["tok"].append(tokens)
249
  METRICS["g"].append(grounded)
250
  METRICS["r"].append(avg_score)
251
  METRICS["c"].append(cites)
252
  METRICS["t"].append(classify_query(query))
253
 
 
 
254
  return history, ""
255
 
256
 
257
- # ----------------- ANALYTICS HELPERS ----------------- #
 
 
 
 
 
 
 
 
 
 
258
 
259
- def render_metrics():
260
- if len(METRICS["q"]) == 0:
261
- return [], 0.0, 0.0, 0.0
262
 
263
- rows = []
264
- for i, q in enumerate(METRICS["q"]):
265
- rows.append([
266
- i + 1,
267
- q,
268
- round(METRICS["lat"][i], 3),
269
- METRICS["tok"][i],
270
- METRICS["g"][i],
271
- round(METRICS["r"][i], 3),
272
- METRICS["c"][i],
273
- METRICS["t"][i],
274
- ])
275
 
276
- avg_ground = sum(METRICS["g"]) / len(METRICS["g"])
277
- avg_lat = sum(METRICS["lat"]) / len(METRICS["lat"])
278
- avg_tok = sum(METRICS["tok"]) / len(METRICS["tok"])
279
 
280
- return rows, avg_ground, avg_lat, avg_tok
281
-
282
-
283
- # ----------------- GRADIO UI ----------------- #
284
 
285
  with gr.Blocks(title="Kubernetes RAG Assistant") as app:
286
  gr.Markdown("# ☸ Kubernetes RAG Assistant")
287
 
288
  with gr.Tab("Chat"):
289
  chat = gr.Chatbot(height=450)
290
- inp = gr.Textbox(label="Ask anything about Kubernetes")
291
- clear_btn = gr.Button("Reset Conversation")
292
 
293
- inp.submit(answer_question, [inp, chat], [chat, inp])
294
- clear_btn.click(lambda: ([], ""), None, [chat, inp])
295
 
296
  with gr.Tab("Analytics"):
297
- gr.Markdown("### 📊 Query Analytics (this session)")
298
- table = gr.DataFrame(
299
- headers=[
300
- "ID",
301
- "Query",
302
- "Latency (s)",
303
- "Tokens",
304
- "Groundedness (%)",
305
- "Avg Rerank Score",
306
- "Citations",
307
- "Type",
308
- ],
309
- interactive=False,
310
- )
311
- avgG = gr.Number(label="Avg Groundedness (%)")
312
- avgL = gr.Number(label="Avg Latency (s)")
313
  avgT = gr.Number(label="Avg Tokens")
314
-
315
- refresh = gr.Button("Update Dashboard")
316
- refresh.click(render_metrics, None, [table, avgG, avgL, avgT])
317
 
318
  app.launch()
 
16
  from rank_bm25 import BM25Okapi
17
  from sentence_transformers import CrossEncoder
18
 
19
+
20
  PERSIST_DIR = "k8s_chroma_db"
21
 
22
  URLS = {
 
33
  }
34
 
35
 
36
+ # ================= Knowledge Base ================= #
37
 
38
  def scrape_page(name, url):
39
  try:
40
+ response = requests.get(url, timeout=20)
41
+ response.raise_for_status()
42
+ soup = BeautifulSoup(response.text, "html.parser")
43
  content = soup.find("div", class_="td-content")
44
+
45
  if not content:
46
  return None
47
+
48
  text = content.get_text(separator="\n").strip()
49
  return Document(page_content=text, metadata={"doc_id": name, "url": url})
50
  except Exception as e:
 
53
 
54
 
55
  def build_or_load_kb():
56
+ print("[INFO] Loading embedding model...")
57
  embedding_model = HuggingFaceEmbeddings(
58
  model_name="sentence-transformers/all-MiniLM-L6-v2"
59
  )
60
 
 
61
  if os.path.isdir(PERSIST_DIR):
62
+ print("[INFO] Loading existing vector DB...")
63
  vectordb = Chroma(
64
  embedding_function=embedding_model,
65
  persist_directory=PERSIST_DIR,
66
  )
67
  raw = vectordb._collection.get(include=["documents", "metadatas"])
68
  chunks = [
69
+ Document(page_content=d, metadata=m)
70
+ for d, m in zip(raw["documents"], raw["metadatas"])
71
  ]
72
  return vectordb, chunks
73
 
74
+ print("[INFO] No DB found — scraping docs...")
 
75
  docs = []
76
  for name, url in URLS.items():
77
+ doc = scrape_page(name, url)
78
+ if doc:
79
+ docs.append(doc)
80
  print(f"[INFO] Scraped {len(docs)} docs")
81
 
82
  splitter = RecursiveCharacterTextSplitter(chunk_size=900, chunk_overlap=200)
83
  chunks = splitter.split_documents(docs)
84
 
85
+ vectordb = Chroma.from_documents(chunks, embedding_model, persist_directory=PERSIST_DIR)
86
+ vectordb.persist()
87
+
88
+ print("[INFO] Vector DB built & saved.")
89
  return vectordb, chunks
90
 
91
 
92
  vectordb, chunks = build_or_load_kb()
93
 
 
94
 
95
+ # ================= Search & Reranker ================= #
96
+
97
+ bm25 = BM25Okapi([c.page_content.split() for c in chunks])
98
  reranker = CrossEncoder("cross-encoder/ms-marco-MiniLM-L-12-v2")
99
 
100
  retriever = vectordb.as_retriever(
 
106
  def hybrid_search(query, top_k=5):
107
  vector_results = retriever.invoke(query)
108
 
109
+ bm_scores = bm25.get_scores(query.lower().split())
110
+ bm_ranked = sorted(zip(bm_scores, chunks), reverse=True)
111
+ bm_results = [doc for _, doc in bm_ranked[:top_k]]
 
112
 
113
+ unique_docs = []
114
  seen = set()
115
+ for doc in vector_results + bm_results:
116
+ key = (doc.metadata.get("doc_id"), doc.page_content[:50])
 
117
  if key not in seen:
118
  seen.add(key)
119
+ unique_docs.append(doc)
120
 
121
+ if not unique_docs:
122
  return []
123
 
124
+ rerank_pairs = [(query, doc.page_content) for doc in unique_docs]
125
+ scores = reranker.predict(rerank_pairs)
126
+ ranked = sorted(zip(scores, unique_docs), reverse=True)[:top_k]
127
 
128
  for s, doc in ranked:
129
  doc.metadata["rerank_score"] = float(s)
130
+
131
  return [doc for _, doc in ranked]
132
 
133
 
134
+ # ================= LLM ================= #
135
 
136
+ def call_llm(prompt):
137
  api_key = os.getenv("OPENROUTER_API_KEY")
138
  if not api_key:
139
+ return "⚠️ Missing API key.\nGroundedness: 0%"
140
 
141
  try:
142
+ res = requests.post(
143
  "https://openrouter.ai/api/v1/chat/completions",
144
  headers={
145
  "Authorization": f"Bearer {api_key}",
 
149
  json={
150
  "model": "meta-llama/llama-3.1-8b-instruct",
151
  "messages": [{"role": "user", "content": prompt}],
 
152
  "temperature": 0.0,
153
+ "max_tokens": 400,
154
  },
 
155
  )
156
+ res.raise_for_status()
157
+ data = res.json()
158
  return data["choices"][0]["message"]["content"]
159
  except Exception as e:
160
+ return f"⚠️ LLM Error: {e}\nGroundedness: 0%"
 
 
 
 
 
 
 
 
 
161
 
 
 
 
 
 
 
 
162
 
163
+ # ================= Analytics ================= #
164
 
165
+ def classify_query(q):
166
  q = q.lower()
167
  if "how" in q:
168
  return "how-to"
 
171
  return "general"
172
 
173
 
174
+ METRICS = {"q": [], "lat": [], "tok": [], "g": [], "r": [], "c": [], "t": []}
 
 
175
 
 
 
176
 
177
+ # ================= Chat Handler ================= #
 
178
 
179
  def answer_question(query, history):
 
 
 
 
180
  start = time.time()
181
+ docs = hybrid_search(query)
182
 
183
+ if not docs:
184
+ reply = "Not found in docs.\nGroundedness: 0%"
185
+ return history + [
186
+ {"role": "user", "content": query},
187
+ {"role": "assistant", "content": reply}
188
+ ], ""
189
+
190
+ ctx = ""
191
+ sources = []
192
+ scores = []
193
+ for i, d in enumerate(docs, 1):
194
+ label = f"[{i}]"
195
+ ctx += f"{label} {d.page_content[:900]}\nSource: {d.metadata['url']}\n\n"
196
+ sources.append(f"{label} → {d.metadata['url']}")
197
+ scores.append(d.metadata["rerank_score"])
198
 
199
  prompt = f"""
200
+ Answer the question ONLY using the context below.
201
+ Each sentence MUST end with a citation like [1].
202
 
203
  Question: {query}
204
 
205
  Context:
206
  {ctx}
207
 
208
+ End with: Groundedness: XX%
209
  """
210
+
211
  answer = call_llm(prompt)
212
  latency = time.time() - start
213
 
 
214
  grounded = 0
215
  m = re.search(r"Groundedness:\s*(\d+)%", answer)
216
  if m:
217
+ grounded = int(m.group(1"))
 
 
 
218
 
219
  cites = len(set(re.findall(r"\[(\d+)\]", answer)))
220
+ avg_score = sum(scores) / len(scores) if scores else 0
 
 
 
 
 
 
 
221
 
222
+ final = answer + "\n\n---\nSources:\n" + "\n".join(sources)
223
 
224
  METRICS["q"].append(query)
225
  METRICS["lat"].append(latency)
226
+ METRICS["tok"].append(len(answer.split()))
227
  METRICS["g"].append(grounded)
228
  METRICS["r"].append(avg_score)
229
  METRICS["c"].append(cites)
230
  METRICS["t"].append(classify_query(query))
231
 
232
+ history.append({"role": "user", "content": query})
233
+ history.append({"role": "assistant", "content": final})
234
  return history, ""
235
 
236
 
237
+ def update_dashboard():
238
+ rows = list(zip(
239
+ range(1, len(METRICS["q"])+1),
240
+ METRICS["q"],
241
+ METRICS["lat"],
242
+ METRICS["tok"],
243
+ METRICS["g"],
244
+ METRICS["r"],
245
+ METRICS["c"],
246
+ METRICS["t"],
247
+ ))
248
 
249
+ avgG = round(sum(METRICS["g"]) / len(METRICS["g"]), 2)
250
+ avgL = round(sum(METRICS["lat"]) / len(METRICS["lat"]), 2)
251
+ avgT = round(sum(METRICS["tok"]) / len(METRICS["tok"]), 2)
252
 
253
+ return rows, avgG, avgL, avgT
 
 
 
 
 
 
 
 
 
 
 
254
 
 
 
 
255
 
256
+ # ================= UI ================= #
 
 
 
257
 
258
  with gr.Blocks(title="Kubernetes RAG Assistant") as app:
259
  gr.Markdown("# ☸ Kubernetes RAG Assistant")
260
 
261
  with gr.Tab("Chat"):
262
  chat = gr.Chatbot(height=450)
263
+ user_in = gr.Textbox(label="Ask anything about Kubernetes")
264
+ reset = gr.Button("Reset")
265
 
266
+ user_in.submit(answer_question, [user_in, chat], [chat, user_in])
267
+ reset.click(lambda: ([], ""), None, [chat, user_in])
268
 
269
  with gr.Tab("Analytics"):
270
+ gr.Markdown("### 📊 Analytics This Session")
271
+ table = gr.DataFrame(headers=[
272
+ "ID","Query","Latency","Tokens","Grounded","Rerank","Citations","Type"
273
+ ], interactive=False)
274
+ avgG = gr.Number(label="Avg Groundedness")
275
+ avgL = gr.Number(label="Avg Latency")
 
 
 
 
 
 
 
 
 
 
276
  avgT = gr.Number(label="Avg Tokens")
277
+ refresh = gr.Button("Refresh")
278
+ refresh.click(update_dashboard, None, [table, avgG, avgL, avgT])
 
279
 
280
  app.launch()