Prakyath01 commited on
Commit
27bcc7f
·
verified ·
1 Parent(s): fc9656f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +136 -501
app.py CHANGED
@@ -1,21 +1,14 @@
1
- # ========================================================
2
- # ☸ Kubernetes RAG Assistant
3
- # Hybrid Search • Reranked • Cited • Monitored 📌
4
- # Ready for Hugging Face Spaces (Gradio)
5
- # ========================================================
6
-
7
  import os
8
  import re
9
  import time
10
  import requests
11
  import pandas as pd
12
  import matplotlib
13
- matplotlib.use("Agg") # Non-GUI backend for servers
14
  import matplotlib.pyplot as plt
15
  import gradio as gr
16
 
17
  from bs4 import BeautifulSoup
18
-
19
  from langchain_core.documents import Document
20
  from langchain_text_splitters import RecursiveCharacterTextSplitter
21
  from langchain_huggingface import HuggingFaceEmbeddings
@@ -23,8 +16,6 @@ from langchain_community.vectorstores import Chroma
23
  from rank_bm25 import BM25Okapi
24
  from sentence_transformers import CrossEncoder
25
 
26
- # -------------------- CONFIG -------------------- #
27
-
28
  PERSIST_DIR = "k8s_chroma_db"
29
 
30
  URLS = {
@@ -40,101 +31,48 @@ URLS = {
40
  "autoscaling": "https://kubernetes.io/docs/tasks/run-application/horizontal-pod-autoscale/",
41
  }
42
 
43
- # -------------------- SCRAPING -------------------- #
44
-
45
- def scrape_page(name: str, url: str):
46
  try:
47
  r = requests.get(url, timeout=20)
48
- r.raise_for_status()
49
  soup = BeautifulSoup(r.text, "html.parser")
50
  content = soup.find("div", class_="td-content")
51
  if not content:
52
- print(f"[WARN] No td-content for {url}")
53
  return None
54
  text = content.get_text(separator="\n").strip()
55
  return Document(page_content=text, metadata={"doc_id": name, "url": url})
56
- except Exception as e:
57
- print(f"[ERROR] scraping {url}: {e}")
58
  return None
59
 
60
- def scrape_k8s_docs():
61
- print("[INFO] Scraping Kubernetes docs...")
 
 
 
 
 
 
 
 
 
 
62
  docs = []
63
  for name, url in URLS.items():
64
  d = scrape_page(name, url)
65
  if d:
66
  docs.append(d)
67
- print(f"[INFO] Scraped {len(docs)} docs.")
68
- return docs
69
-
70
- # -------------------- KNOWLEDGE BASE SETUP -------------------- #
71
-
72
- def build_or_load_kb():
73
- """
74
- If a Chroma DB exists, load it.
75
- Otherwise, scrape → chunk → embed → create DB → persist.
76
- Returns: vectordb, chunks_for_bm25
77
- """
78
- print("[INFO] Initializing knowledge base...")
79
- embedding_model = HuggingFaceEmbeddings(
80
- model_name="sentence-transformers/all-MiniLM-L6-v2"
81
- )
82
 
83
- # If persistent dir exists, load vectordb and docs from it
84
- if os.path.isdir(PERSIST_DIR):
85
- print("[INFO] Found existing Chroma DB. Loading...")
86
- vectordb = Chroma(
87
- embedding_function=embedding_model,
88
- persist_directory=PERSIST_DIR,
89
- )
90
- # Pull all docs from collection
91
- try:
92
- raw = vectordb._collection.get(include=["documents", "metadatas"])
93
- docs = [
94
- Document(page_content=doc, metadata=meta)
95
- for doc, meta in zip(raw["documents"], raw["metadatas"])
96
- ]
97
- print(f"[INFO] Loaded {len(docs)} chunks from existing DB.")
98
- chunks = docs
99
- except Exception as e:
100
- print(f"[WARN] Failed to load docs from DB, rescraping. Error: {e}")
101
- docs = scrape_k8s_docs()
102
- splitter = RecursiveCharacterTextSplitter(
103
- chunk_size=900, chunk_overlap=200
104
- )
105
- chunks = splitter.split_documents(docs)
106
- vectordb = Chroma.from_documents(
107
- chunks,
108
- embedding_model,
109
- persist_directory=PERSIST_DIR,
110
- )
111
- vectordb.persist()
112
- else:
113
- print("[INFO] No existing DB, scraping + building...")
114
- docs = scrape_k8s_docs()
115
- splitter = RecursiveCharacterTextSplitter(
116
- chunk_size=900, chunk_overlap=200
117
- )
118
- chunks = splitter.split_documents(docs)
119
- vectordb = Chroma.from_documents(
120
- chunks,
121
- embedding_model,
122
- persist_directory=PERSIST_DIR,
123
- )
124
- vectordb.persist()
125
- print("[INFO] Chroma DB built and persisted.")
126
 
127
- return vectordb, chunks, embedding_model
 
128
 
129
- vectordb, chunks, embedding_model = build_or_load_kb()
130
 
131
- # -------------------- HYBRID SEARCH + RERANKER -------------------- #
132
 
133
- print("[INFO] Initializing BM25 + CrossEncoder reranker...")
134
  bm25_corpus = [doc.page_content.split() for doc in chunks]
135
  bm25 = BM25Okapi(bm25_corpus)
136
-
137
- # Balanced reranker model (Option B you chose)
138
  reranker = CrossEncoder("cross-encoder/ms-marco-MiniLM-L-12-v2")
139
 
140
  retriever = vectordb.as_retriever(
@@ -142,452 +80,149 @@ retriever = vectordb.as_retriever(
142
  search_kwargs={"k": 8, "score_threshold": 0.35},
143
  )
144
 
145
- def hybrid_search(query: str, top_k: int = 5):
146
- # Vector search
147
  vector_results = retriever.invoke(query)
148
-
149
- # BM25 keyword search
150
  tokenized_query = query.lower().split()
151
  bm25_scores = bm25.get_scores(tokenized_query)
152
- bm25_ranked = sorted(
153
- zip(bm25_scores, chunks), key=lambda x: x[0], reverse=True
154
- )
155
  bm25_results = [d for _, d in bm25_ranked[:top_k]]
156
-
157
- # Combine + dedupe
158
  combined = vector_results + bm25_results
159
- unique = []
160
  seen = set()
 
161
  for d in combined:
162
- key = (d.metadata.get("doc_id", ""), d.page_content[:80])
163
  if key not in seen:
164
  seen.add(key)
165
  unique.append(d)
166
-
167
  if not unique:
168
  return []
169
-
170
- # Rerank with cross-encoder
171
  pairs = [(query, doc.page_content) for doc in unique]
172
  scores = reranker.predict(pairs)
 
 
 
 
173
 
174
- scored_docs = sorted(zip(scores, unique), key=lambda x: x[0], reverse=True)
175
- top_docs = scored_docs[:top_k]
176
-
177
- reranked = []
178
- for score, doc in top_docs:
179
- doc.metadata["rerank_score"] = float(score)
180
- reranked.append(doc)
181
-
182
- return reranked
183
-
184
- # -------------------- LLM CALL (OpenRouter) -------------------- #
185
-
186
- def call_llm(prompt: str) -> str:
187
  url = "https://openrouter.ai/api/v1/chat/completions"
188
  api_key = os.getenv("OPENROUTER_API_KEY")
189
  if not api_key:
190
- print("[ERROR] OPENROUTER_API_KEY not set.")
191
- return (
192
- "⚠️ Model failed: missing OPENROUTER_API_KEY environment variable.\n\n"
193
- "Groundedness: 0%"
194
- )
195
-
196
- headers = {
197
  "Authorization": f"Bearer {api_key}",
198
  "HTTP-Referer": "https://huggingface.co/",
199
- "X-Title": "Kubernetes RAG Assistant",
200
- }
201
- data = {
202
  "model": "meta-llama/llama-3.1-8b-instruct",
203
  "messages": [{"role": "user", "content": prompt}],
204
  "max_tokens": 400,
205
- "temperature": 0.0,
206
- }
207
-
208
- try:
209
- r = requests.post(url, headers=headers, json=data, timeout=60)
210
- r.raise_for_status()
211
- res = r.json()
212
- except Exception as e:
213
- print(f"[ERROR] LLM call failed: {e}")
214
- return "⚠️ Model failed. Please retry.\n\nGroundedness: 0%"
215
-
216
- if "choices" in res and res["choices"]:
217
- return res["choices"][0]["message"]["content"]
218
- print("[ERROR] Unexpected LLM response:", res)
219
- return "⚠️ Model failed. Please retry.\n\nGroundedness: 0%"
220
-
221
- # -------------------- CONTEXT + CITATIONS -------------------- #
222
-
223
- def build_context_with_citations(query: str, history=None, top_k: int = 5):
224
- """
225
- Use hybrid search + conversation-aware follow-up handling.
226
- """
227
- effective_query = query
228
 
229
- if history:
230
- last_user_q = history[-1][0] if history[-1] else ""
231
- followup_tokens = [
232
- "and", "also", "that", "those", "it", "them", "one",
233
- "this", "these", "more", "what about"
234
- ]
235
- if len(query.split()) <= 4 or any(t in query.lower() for t in followup_tokens):
236
- effective_query = f"{last_user_q} | Follow-up: {query}"
237
-
238
- docs = hybrid_search(effective_query, top_k=top_k)
239
  if not docs:
240
- return "", [], [], []
241
-
242
- context = ""
243
- sources = []
244
- scores = []
245
- doc_ids = []
246
-
247
  for i, d in enumerate(docs, start=1):
248
  label = f"[{i}]"
249
- snippet = d.page_content[:900].strip()
250
- url = d.metadata.get("url", "N/A")
251
- score = float(d.metadata.get("rerank_score", 0.0))
252
-
253
- context += (
254
- f"{label} (score={score:.2f})\n"
255
- f"{snippet}\n"
256
- f"Source: {url}\n\n"
257
- )
258
- sources.append(f"{label} → {url}")
259
- scores.append(score)
260
- doc_ids.append(d.metadata.get("doc_id", "k8s-doc"))
261
-
262
- return context, sources, scores, doc_ids
263
-
264
- # -------------------- QUERY CLASSIFIER -------------------- #
265
-
266
- def classify_query(query: str) -> str:
267
- q = query.lower()
268
- if any(q.startswith(p) for p in ["what is", "define", "explain"]):
269
- return "definition"
270
- if any(k in q for k in ["how to", "how do i", "steps", "tutorial"]):
271
- return "how-to"
272
- if any(k in q for k in ["error", "failed", "crash", "issue", "troubleshoot"]):
273
- return "debugging"
274
- if any(k in q for k in ["best practice", "recommend", "should i"]):
275
- return "best-practice"
276
  return "general"
277
 
278
- # -------------------- ANALYTICS STORAGE -------------------- #
279
-
280
- def init_analytics():
281
- return {
282
- "queries": [],
283
- "latency": [],
284
- "approx_tokens": [],
285
- "groundedness": [],
286
- "avg_rerank_score": [],
287
- "citation_count": [],
288
- "query_type": [],
289
- }
290
-
291
- # -------------------- MAIN ANSWER FUNCTION -------------------- #
292
-
293
- def answer_question(query, history, analytics):
294
- if analytics is None or analytics == {}:
295
- analytics = init_analytics()
296
-
297
- start_time = time.time()
298
-
299
- context, sources, scores, doc_ids = build_context_with_citations(query, history)
300
-
301
- # Retrieval failure – safe response
302
- if not context:
303
- resp = (
304
- "Not in documentation or insufficient context to answer confidently.\n\n"
305
- "Possible reasons:\n"
306
- "- The question is too vague or missing key details.\n"
307
- "- The topic may not be covered in the scraped Kubernetes docs.\n\n"
308
- "Try rephrasing with more detail.\n\n"
309
- "Groundedness: 0%"
310
- )
311
- latency = time.time() - start_time
312
-
313
- analytics["queries"].append(query)
314
- analytics["latency"].append(latency)
315
- analytics["approx_tokens"].append(len(resp.split()))
316
- analytics["groundedness"].append(0)
317
- analytics["avg_rerank_score"].append(0.0)
318
- analytics["citation_count"].append(0)
319
- analytics["query_type"].append(classify_query(query))
320
-
321
- history.append((query, resp))
322
- return history, "", analytics
323
-
324
- # Recent conversation context (not for citations)
325
- conversation_context = ""
326
- if history:
327
- last_turns = history[-3:]
328
- for uq, aq in last_turns:
329
- conversation_context += f"User: {uq}\nAssistant: {aq}\n\n"
330
-
331
- prompt = f"""
332
- You are a strict Kubernetes documentation assistant.
333
-
334
- RULES:
335
- - Answer ONLY using the Context section.
336
- - EVERY sentence must end with at least one citation like [1] or [2].
337
- - If the answer is not found in the context, respond exactly:
338
- "Not in documentation: Please rephrase or check the official Kubernetes docs."
339
- - Do NOT invent APIs, flags, YAML fields, or behaviors not shown in the context.
340
- - Use short, precise sentences.
341
- - At the END, output a separate line: Groundedness: XX%
342
- - XX is an integer from 0 to 100.
343
- - 100 means every statement is directly and clearly supported.
344
- - Lower if you are uncertain or context is thin.
345
-
346
- User Question:
347
- {query}
348
-
349
- Recent Conversation (for context, not citations):
350
- {conversation_context}
351
-
352
- Context (with source ids and rerank scores):
353
- {context}
354
  """
355
-
356
- answer = call_llm(prompt)
357
- latency = time.time() - start_time
358
-
359
- approx_tokens = len(prompt.split()) + len(answer.split())
360
-
361
- groundedness_match = re.search(r"Groundedness:\s*(\d+)%", answer)
362
- groundedness = int(groundedness_match.group(1)) if groundedness_match else 0
363
-
364
- citation_matches = re.findall(r"\[(\d+)\]", answer)
365
- unique_citations = set(citation_matches)
366
- citation_count = len(unique_citations)
367
-
368
- avg_rerank_score = sum(scores) / len(scores) if scores else 0.0
369
-
370
- # Low groundedness / no citations alert
371
- alert = ""
372
- if groundedness < 70 or citation_count == 0:
373
- alert = (
374
- "⚠️ Warning: This response may not be fully supported by the retrieved Kubernetes documentation.\n"
375
- "Consider rephrasing your question with more specific details, or verifying in the official docs.\n\n"
376
- )
377
-
378
- final_answer = alert + answer + "\n\n---\nSources:\n" + "\n".join(sources)
379
-
380
- history.append((query, final_answer))
381
-
382
- analytics["queries"].append(query)
383
- analytics["latency"].append(latency)
384
- analytics["approx_tokens"].append(approx_tokens)
385
- analytics["groundedness"].append(groundedness)
386
- analytics["avg_rerank_score"].append(avg_rerank_score)
387
- analytics["citation_count"].append(citation_count)
388
- analytics["query_type"].append(classify_query(query))
389
-
390
- return history, "", analytics
391
-
392
- # -------------------- ANALYTICS RENDERING -------------------- #
393
-
394
- def render_analytics(analytics):
395
- if not analytics or len(analytics["queries"]) == 0:
396
- return [], 0.0, 0.0, 0.0
397
-
398
- rows = []
399
- for i, q in enumerate(analytics["queries"]):
400
- rows.append([
401
- i + 1,
402
- q,
403
- round(analytics["latency"][i], 3),
404
- analytics["approx_tokens"][i],
405
- analytics["groundedness"][i],
406
- round(analytics["avg_rerank_score"][i], 3),
407
- analytics["citation_count"][i],
408
- analytics["query_type"][i],
409
- ])
410
-
411
- avg_latency = sum(analytics["latency"]) / len(analytics["latency"])
412
- avg_grounded = sum(analytics["groundedness"]) / len(analytics["groundedness"])
413
- avg_tokens = sum(analytics["approx_tokens"]) / len(analytics["approx_tokens"])
414
-
415
- return rows, avg_latency, avg_grounded, avg_tokens
416
-
417
- def generate_charts(analytics):
418
- if not analytics or len(analytics["queries"]) == 0:
419
- return None, None, None, None
420
-
421
- df = pd.DataFrame({
422
- "Latency": analytics["latency"],
423
- "Groundedness": analytics["groundedness"],
424
- "Tokens": analytics["approx_tokens"],
425
- "Query Type": analytics["query_type"],
426
- })
427
-
428
- # Latency chart
429
- fig_latency, ax1 = plt.subplots()
430
- ax1.plot(df["Latency"])
431
- ax1.set_title("Latency Over Time")
432
- ax1.set_xlabel("Query #")
433
- ax1.set_ylabel("Seconds")
434
-
435
- # Groundedness chart
436
- fig_ground, ax2 = plt.subplots()
437
- ax2.plot(df["Groundedness"])
438
- ax2.set_title("Groundedness Trend")
439
- ax2.set_xlabel("Query #")
440
- ax2.set_ylabel("Groundedness (%)")
441
-
442
- # Token usage chart
443
- fig_tokens, ax3 = plt.subplots()
444
- ax3.plot(df["Tokens"])
445
- ax3.set_title("Token Usage Over Time")
446
- ax3.set_xlabel("Query #")
447
- ax3.set_ylabel("Approx Tokens")
448
-
449
- # Query type distribution pie chart
450
- fig_pie, ax4 = plt.subplots()
451
- df["Query Type"].value_counts().plot.pie(
452
- ax=ax4,
453
- autopct="%1.1f%%",
454
- )
455
- ax4.set_ylabel("")
456
- ax4.set_title("Query Types Distribution")
457
-
458
- return fig_latency, fig_ground, fig_tokens, fig_pie
459
-
460
- def export_csv(analytics):
461
- if not analytics or len(analytics["queries"]) == 0:
462
- path = "analytics.csv"
463
- pd.DataFrame(columns=[
464
- "query", "latency", "approx_tokens", "groundedness",
465
- "avg_rerank_score", "citation_count", "query_type"
466
- ]).to_csv(path, index=False)
467
- return path
468
-
469
- df = pd.DataFrame({
470
- "query": analytics["queries"],
471
- "latency": analytics["latency"],
472
- "approx_tokens": analytics["approx_tokens"],
473
- "groundedness": analytics["groundedness"],
474
- "avg_rerank_score": analytics["avg_rerank_score"],
475
- "citation_count": analytics["citation_count"],
476
- "query_type": analytics["query_type"],
477
  })
478
- path = "analytics.csv"
479
- df.to_csv(path, index=False)
480
- return path
481
-
482
- def clear_all():
483
- return [], "", init_analytics()
484
-
485
- # -------------------- GRADIO UI -------------------- #
486
-
487
- custom_css = """
488
- .source-box {
489
- background: #1e293b;
490
- color: #dbeafe;
491
- padding: 10px;
492
- border-radius: 7px;
493
- border: 1px solid #3b82f6;
494
- }
495
- """
496
-
497
- with gr.Blocks(theme="soft") as app:
498
- gr.HTML(f"<style>{custom_css}</style>")
499
- gr.HTML(
500
- "<h1 style='text-align:center;color:#3b82f6'>☸ Kubernetes RAG Assistant</h1>"
501
- "<p style='text-align:center;color:#cbd5e1'>Hybrid Search • Reranked • Cited • Monitored 📌</p>"
502
- )
503
-
504
- analytics_state = gr.State(init_analytics())
505
-
506
- with gr.Tab("Chatbot"):
507
- chat = gr.Chatbot(label="Conversation", height=450)
508
- msg = gr.Textbox(
509
- label="Ask anything about Kubernetes…",
510
- placeholder="e.g., What is RBAC?",
511
- )
512
- clear = gr.Button("Clear Conversation")
513
-
514
- msg.submit(
515
- answer_question,
516
- inputs=[msg, chat, analytics_state],
517
- outputs=[chat, msg, analytics_state],
518
- )
519
-
520
- clear.click(
521
- clear_all,
522
- inputs=None,
523
- outputs=[chat, msg, analytics_state],
524
- )
525
-
526
- with gr.Tab("Analytics Dashboard"):
527
- gr.Markdown("### 📊 System Metrics")
528
- gr.Markdown(
529
- "- Each row is a user query\n"
530
- "- Latency = retrieval + LLM time\n"
531
- "- Groundedness = model-reported confidence based on docs\n"
532
- "- Rerank score = cross-encoder relevance\n"
533
- "- Citation count = number of unique [n] labels used in the answer"
534
- )
535
-
536
- analytics_table = gr.Dataframe(
537
- headers=[
538
- "ID",
539
- "Query",
540
- "Latency (s)",
541
- "Approx Tokens",
542
- "Groundedness (%)",
543
- "Avg Rerank Score",
544
- "Citations Used",
545
- "Query Type",
546
- ],
547
- row_count=0,
548
- col_count=8,
549
- interactive=False,
550
- label="Query Stats",
551
- )
552
-
553
- avg_latency_box = gr.Number(label="Average Latency (s)", precision=3)
554
- avg_ground_box = gr.Number(label="Average Groundedness (%)", precision=1)
555
- avg_tokens_box = gr.Number(label="Average Tokens per Answer", precision=1)
556
-
557
- plot_latency = gr.Plot(label="Latency Trend")
558
- plot_ground = gr.Plot(label="Groundedness Trend")
559
- plot_tokens = gr.Plot(label="Token Usage Trend")
560
- plot_pie = gr.Plot(label="Query Types Distribution")
561
-
562
- refresh_btn = gr.Button("Refresh Analytics")
563
- export_btn = gr.Button("Export Analytics as CSV")
564
- file_out = gr.File(label="Download CSV")
565
-
566
- # Refresh metrics table + summary
567
- refresh_btn.click(
568
- render_analytics,
569
- inputs=[analytics_state],
570
- outputs=[
571
- analytics_table,
572
- avg_latency_box,
573
- avg_ground_box,
574
- avg_tokens_box,
575
- ],
576
- )
577
-
578
- # Refresh charts
579
- refresh_btn.click(
580
- generate_charts,
581
- inputs=[analytics_state],
582
- outputs=[plot_latency, plot_ground, plot_tokens, plot_pie],
583
- )
584
-
585
- # Export CSV
586
- export_btn.click(
587
- export_csv,
588
- inputs=[analytics_state],
589
- outputs=[file_out],
590
- )
591
-
592
- if __name__ == "__main__":
593
- app.launch()
 
 
 
 
 
 
 
1
  import os
2
  import re
3
  import time
4
  import requests
5
  import pandas as pd
6
  import matplotlib
7
+ matplotlib.use("Agg")
8
  import matplotlib.pyplot as plt
9
  import gradio as gr
10
 
11
  from bs4 import BeautifulSoup
 
12
  from langchain_core.documents import Document
13
  from langchain_text_splitters import RecursiveCharacterTextSplitter
14
  from langchain_huggingface import HuggingFaceEmbeddings
 
16
  from rank_bm25 import BM25Okapi
17
  from sentence_transformers import CrossEncoder
18
 
 
 
19
  PERSIST_DIR = "k8s_chroma_db"
20
 
21
  URLS = {
 
31
  "autoscaling": "https://kubernetes.io/docs/tasks/run-application/horizontal-pod-autoscale/",
32
  }
33
 
34
+ def scrape_page(name, url):
 
 
35
  try:
36
  r = requests.get(url, timeout=20)
 
37
  soup = BeautifulSoup(r.text, "html.parser")
38
  content = soup.find("div", class_="td-content")
39
  if not content:
 
40
  return None
41
  text = content.get_text(separator="\n").strip()
42
  return Document(page_content=text, metadata={"doc_id": name, "url": url})
43
+ except:
 
44
  return None
45
 
46
+ def build_or_load_kb():
47
+ embedding_model = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")
48
+
49
+ if os.path.isdir(PERSIST_DIR):
50
+ vectordb = Chroma(embedding_function=embedding_model, persist_directory=PERSIST_DIR)
51
+ raw = vectordb._collection.get(include=["documents", "metadatas"])
52
+ chunks = [
53
+ Document(page_content=doc, metadata=meta)
54
+ for doc, meta in zip(raw["documents"], raw["metadatas"])
55
+ ]
56
+ return vectordb, chunks
57
+
58
  docs = []
59
  for name, url in URLS.items():
60
  d = scrape_page(name, url)
61
  if d:
62
  docs.append(d)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
63
 
64
+ splitter = RecursiveCharacterTextSplitter(chunk_size=900, chunk_overlap=200)
65
+ chunks = splitter.split_documents(docs)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
66
 
67
+ vectordb = Chroma.from_documents(chunks, embedding_model, persist_directory=PERSIST_DIR)
68
+ vectordb.persist()
69
 
70
+ return vectordb, chunks
71
 
72
+ vectordb, chunks = build_or_load_kb()
73
 
 
74
  bm25_corpus = [doc.page_content.split() for doc in chunks]
75
  bm25 = BM25Okapi(bm25_corpus)
 
 
76
  reranker = CrossEncoder("cross-encoder/ms-marco-MiniLM-L-12-v2")
77
 
78
  retriever = vectordb.as_retriever(
 
80
  search_kwargs={"k": 8, "score_threshold": 0.35},
81
  )
82
 
83
+ def hybrid_search(query, top_k=5):
 
84
  vector_results = retriever.invoke(query)
 
 
85
  tokenized_query = query.lower().split()
86
  bm25_scores = bm25.get_scores(tokenized_query)
87
+ bm25_ranked = sorted(zip(bm25_scores, chunks), key=lambda x: x[0], reverse=True)
 
 
88
  bm25_results = [d for _, d in bm25_ranked[:top_k]]
 
 
89
  combined = vector_results + bm25_results
 
90
  seen = set()
91
+ unique = []
92
  for d in combined:
93
+ key = (d.metadata.get("doc_id"), d.page_content[:80])
94
  if key not in seen:
95
  seen.add(key)
96
  unique.append(d)
 
97
  if not unique:
98
  return []
 
 
99
  pairs = [(query, doc.page_content) for doc in unique]
100
  scores = reranker.predict(pairs)
101
+ ranked = sorted(zip(scores, unique), key=lambda x: x[0], reverse=True)[:top_k]
102
+ for s, doc in ranked:
103
+ doc.metadata["rerank_score"] = float(s)
104
+ return [doc for _, doc in ranked]
105
 
106
+ def call_llm(prompt):
 
 
 
 
 
 
 
 
 
 
 
 
107
  url = "https://openrouter.ai/api/v1/chat/completions"
108
  api_key = os.getenv("OPENROUTER_API_KEY")
109
  if not api_key:
110
+ return " Missing API key.\nGroundedness: 0%"
111
+ res = requests.post(url, headers={
 
 
 
 
 
112
  "Authorization": f"Bearer {api_key}",
113
  "HTTP-Referer": "https://huggingface.co/",
114
+ "X-Title": "Kubernetes RAG Assistant"
115
+ }, json={
 
116
  "model": "meta-llama/llama-3.1-8b-instruct",
117
  "messages": [{"role": "user", "content": prompt}],
118
  "max_tokens": 400,
119
+ "temperature": 0
120
+ }).json()
121
+ return res["choices"][0]["message"]["content"]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
122
 
123
+ def build_context(query, history):
124
+ docs = hybrid_search(query)
 
 
 
 
 
 
 
 
125
  if not docs:
126
+ return "", [], []
127
+ context, sources, scores = "", [], []
 
 
 
 
 
128
  for i, d in enumerate(docs, start=1):
129
  label = f"[{i}]"
130
+ context += f"{label} {d.page_content[:900]}\nSource: {d.metadata['url']}\n\n"
131
+ sources.append(f"{label} {d.metadata['url']}")
132
+ scores.append(d.metadata["rerank_score"])
133
+ return context, sources, scores
134
+
135
+ def classify_query(q):
136
+ q=q.lower()
137
+ if "how" in q: return "how-to"
138
+ if "error" in q: return "debug"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
139
  return "general"
140
 
141
+ def init_metrics():
142
+ return {"q":[], "lat":[], "tok":[], "g":[],"r":[],"c":[],"t":[]}
143
+
144
+ def answer_question(query, history, metrics):
145
+ if metrics is None or metrics == {}: metrics = init_metrics()
146
+ start = time.time()
147
+ ctx, sources, scores = build_context(query, history)
148
+ if not ctx:
149
+ reply="Not in docs.\nGroundedness: 0%"
150
+ history.append((query, reply))
151
+ return history,"",metrics
152
+ prompt=f"""
153
+ Use ONLY context. Every sentence must end with citation [n].
154
+ Answer:
155
+ Question: {query}
156
+ Context:
157
+ {ctx}
158
+ Groundedness must be in final line as: Groundedness: XX%
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
159
  """
160
+ answer=call_llm(prompt)
161
+ latency=time.time()-start
162
+ grounded=int(re.search(r"Groundedness:\s*(\d+)%", answer).group(1)) if "Groundedness" in answer else 0
163
+ cites=len(set(re.findall(r"\[(\d+)\]", answer)))
164
+ avg_score=sum(scores)/len(scores)
165
+ tokens=len(answer.split())+len(prompt.split())
166
+ alert="⚠ Low support.\n\n" if grounded<70 or cites==0 else ""
167
+ final=alert+answer+"\n\n---\nSources:\n"+"\n".join(sources)
168
+ history.append((query,final))
169
+ metrics["q"].append(query)
170
+ metrics["lat"].append(latency)
171
+ metrics["tok"].append(tokens)
172
+ metrics["g"].append(grounded)
173
+ metrics["r"].append(avg_score)
174
+ metrics["c"].append(cites)
175
+ metrics["t"].append(classify_query(query))
176
+ return history,"",metrics
177
+
178
+ def render(metrics):
179
+ rows=[[i+1,metrics["q"][i],round(metrics["lat"][i],3),
180
+ metrics["tok"][i],metrics["g"][i],
181
+ round(metrics["r"][i],3),metrics["c"][i],metrics["t"][i]]
182
+ for i in range(len(metrics["q"]))]
183
+ avg_lat=sum(metrics["lat"])/len(metrics["lat"])
184
+ avg_g=sum(metrics["g"])/len(metrics["g"])
185
+ avg_tok=sum(metrics["tok"])/len(metrics["tok"])
186
+ return rows,avg_lat,avg_g,avg_tok
187
+
188
+ def charts(metrics):
189
+ df=pd.DataFrame({
190
+ "Latency":metrics["lat"],
191
+ "Groundedness":metrics["g"],
192
+ "Tokens":metrics["tok"],
193
+ "Type":metrics["t"]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
194
  })
195
+ fig_l,ax=plt.subplots();ax.plot(df["Latency"]);ax.set_title("Latency");ax.set_xlabel("#");ax.set_ylabel("s")
196
+ fig_g,ax=plt.subplots();ax.plot(df["Groundedness"]);ax.set_title("Groundedness");ax.set_xlabel("#");ax.set_ylabel("%")
197
+ fig_t,ax=plt.subplots();ax.plot(df["Tokens"]);ax.set_title("Tokens");ax.set_xlabel("#");ax.set_ylabel("count")
198
+ fig_p,ax=plt.subplots();df["Type"].value_counts().plot.pie(ax=ax,autopct="%1.1f%");ax.set_ylabel("");ax.set_title("Query Types")
199
+ return fig_l,fig_g,fig_t,fig_p
200
+
201
+ def export_csv(metrics):
202
+ df=pd.DataFrame(metrics)
203
+ path="analytics.csv";df.to_csv(path,index=False);return path
204
+
205
+ def clear_all(): return [],"",init_metrics()
206
+
207
+ metrics_state=gr.State(init_metrics())
208
+
209
+ with gr.Blocks() as app:
210
+ gr.Markdown("# Kubernetes RAG Assistant")
211
+ with gr.Tab("Chat"):
212
+ chat=gr.Chatbot()
213
+ user_in=gr.Textbox(label="Ask about Kubernetes")
214
+ clear=gr.Button("Clear")
215
+ user_in.submit(answer_question,[user_in,chat,metrics_state],[chat,user_in,metrics_state])
216
+ clear.click(clear_all,outputs=[chat,user_in,metrics_state])
217
+ with gr.Tab("Analytics"):
218
+ table=gr.Dataframe(headers=["ID","Query","Latency","Tokens","Grounded","Rerank","Citations","Type"])
219
+ avgL=gr.Number(label="Avg Latency");avgG=gr.Number(label="Avg Grounded");avgT=gr.Number(label="Avg Tokens")
220
+ p1,p2,p3,p4=gr.Plot(),gr.Plot(),gr.Plot(),gr.Plot()
221
+ refresh=gr.Button("Refresh")
222
+ export=gr.Button("Export CSV")
223
+ file=gr.File()
224
+ refresh.click(render,[metrics_state],[table,avgL,avgG,avgT])
225
+ refresh.click(charts,[metrics_state],[p1,p2,p3,p4])
226
+ export.click(export_csv,[metrics_state],[file])
227
+
228
+ app.launch()