# ========================================================
# ☸ Kubernetes RAG Assistant
# Hybrid Search • Reranked • Cited • Monitored 📌
# Ready for Hugging Face Spaces (Gradio)
# ========================================================
import os
import re
import time

import requests
import pandas as pd
import matplotlib

matplotlib.use("Agg")  # Non-GUI backend for servers

import matplotlib.pyplot as plt
import gradio as gr
from bs4 import BeautifulSoup
from langchain_core.documents import Document
from langchain_text_splitters import RecursiveCharacterTextSplitter
from langchain_huggingface import HuggingFaceEmbeddings
from langchain_community.vectorstores import Chroma
from rank_bm25 import BM25Okapi
from sentence_transformers import CrossEncoder
# -------------------- CONFIG -------------------- #
PERSIST_DIR = "k8s_chroma_db"

URLS = {
    "pods": "https://kubernetes.io/docs/concepts/workloads/pods/",
    "deployments": "https://kubernetes.io/docs/concepts/workloads/controllers/deployment/",
    "services": "https://kubernetes.io/docs/concepts/services-networking/service/",
    "namespaces": "https://kubernetes.io/docs/concepts/overview/working-with-objects/namespaces/",
    "nodes": "https://kubernetes.io/docs/concepts/architecture/nodes/",
    "statefulsets": "https://kubernetes.io/docs/concepts/workloads/controllers/statefulset/",
    "rbac": "https://kubernetes.io/docs/reference/access-authn-authz/rbac/",
    "persistent-volumes": "https://kubernetes.io/docs/concepts/storage/persistent-volumes/",
    "ingress": "https://kubernetes.io/docs/concepts/services-networking/ingress/",
    "autoscaling": "https://kubernetes.io/docs/tasks/run-application/horizontal-pod-autoscale/",
}
# -------------------- SCRAPING -------------------- #
def scrape_page(name: str, url: str):
    try:
        r = requests.get(url, timeout=20)
        r.raise_for_status()
        soup = BeautifulSoup(r.text, "html.parser")
        content = soup.find("div", class_="td-content")
        if not content:
            print(f"[WARN] No td-content for {url}")
            return None
        text = content.get_text(separator="\n").strip()
        return Document(page_content=text, metadata={"doc_id": name, "url": url})
    except Exception as e:
        print(f"[ERROR] scraping {url}: {e}")
        return None
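
# Illustrative local sanity check (not executed on import):
#   doc = scrape_page("pods", URLS["pods"])
#   print(doc.metadata, len(doc.page_content))
# The "td-content" selector targets the main-content container of the Docsy
# theme used by kubernetes.io at the time of writing; if the site layout
# changes, this selector is the first thing to update.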
def scrape_k8s_docs():
    print("[INFO] Scraping Kubernetes docs...")
    docs = []
    for name, url in URLS.items():
        d = scrape_page(name, url)
        if d:
            docs.append(d)
    print(f"[INFO] Scraped {len(docs)} docs.")
    return docs
# -------------------- KNOWLEDGE BASE SETUP -------------------- #
def build_or_load_kb():
    """
    If a persisted Chroma DB exists, load it.
    Otherwise, scrape → chunk → embed → create DB → persist.
    Returns: (vectordb, chunks_for_bm25, embedding_model)
    """
    print("[INFO] Initializing knowledge base...")
    embedding_model = HuggingFaceEmbeddings(
        model_name="sentence-transformers/all-MiniLM-L6-v2"
    )

    def build_fresh():
        # Scrape → chunk → embed → persist; shared by both branches below.
        docs = scrape_k8s_docs()
        splitter = RecursiveCharacterTextSplitter(
            chunk_size=900, chunk_overlap=200
        )
        chunks = splitter.split_documents(docs)
        vectordb = Chroma.from_documents(
            chunks,
            embedding_model,
            persist_directory=PERSIST_DIR,
        )
        vectordb.persist()
        return vectordb, chunks

    if os.path.isdir(PERSIST_DIR):
        print("[INFO] Found existing Chroma DB. Loading...")
        vectordb = Chroma(
            embedding_function=embedding_model,
            persist_directory=PERSIST_DIR,
        )
        # Pull all stored chunks back out of the collection for BM25 indexing
        try:
            raw = vectordb._collection.get(include=["documents", "metadatas"])
            chunks = [
                Document(page_content=doc, metadata=meta)
                for doc, meta in zip(raw["documents"], raw["metadatas"])
            ]
            print(f"[INFO] Loaded {len(chunks)} chunks from existing DB.")
        except Exception as e:
            print(f"[WARN] Failed to load docs from DB, rescraping. Error: {e}")
            vectordb, chunks = build_fresh()
    else:
        print("[INFO] No existing DB, scraping + building...")
        vectordb, chunks = build_fresh()
        print("[INFO] Chroma DB built and persisted.")
    return vectordb, chunks, embedding_model

vectordb, chunks, embedding_model = build_or_load_kb()
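# Note: on Hugging Face Spaces without persistent storage the filesystem is
# ephemeral, so k8s_chroma_db is typically rebuilt on every cold start; the
# load-from-disk path above mainly benefits local runs and Spaces that have
# persistent storage enabled.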
# -------------------- HYBRID SEARCH + RERANKER -------------------- #
print("[INFO] Initializing BM25 + CrossEncoder reranker...")
# Lowercase the corpus so its tokens match the lowercased query tokens used
# in hybrid_search(); BM25 matching is exact at the token level.
bm25_corpus = [doc.page_content.lower().split() for doc in chunks]
bm25 = BM25Okapi(bm25_corpus)

# Balanced reranker: the 12-layer MiniLM cross-encoder trades a little speed
# for noticeably better ranking than the smaller 6-layer variant.
reranker = CrossEncoder("cross-encoder/ms-marco-MiniLM-L-12-v2")

retriever = vectordb.as_retriever(
    search_type="similarity_score_threshold",
    search_kwargs={"k": 8, "score_threshold": 0.35},
)
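# Note: with search_type="similarity_score_threshold", the retriever returns
# only chunks whose normalized relevance score clears the threshold, so the
# vector leg of the hybrid search can legitimately come back empty for
# off-topic queries; BM25 then acts as the keyword fallback. k=8 and the
# 0.35 threshold are tuning choices for this corpus, not prescribed values.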
def hybrid_search(query: str, top_k: int = 5):
    # Vector search
    vector_results = retriever.invoke(query)

    # BM25 keyword search
    tokenized_query = query.lower().split()
    bm25_scores = bm25.get_scores(tokenized_query)
    bm25_ranked = sorted(
        zip(bm25_scores, chunks), key=lambda x: x[0], reverse=True
    )
    bm25_results = [d for _, d in bm25_ranked[:top_k]]

    # Combine + dedupe
    combined = vector_results + bm25_results
    unique = []
    seen = set()
    for d in combined:
        key = (d.metadata.get("doc_id", ""), d.page_content[:80])
        if key not in seen:
            seen.add(key)
            unique.append(d)
    if not unique:
        return []

    # Rerank with cross-encoder
    pairs = [(query, doc.page_content) for doc in unique]
    scores = reranker.predict(pairs)
    scored_docs = sorted(zip(scores, unique), key=lambda x: x[0], reverse=True)
    top_docs = scored_docs[:top_k]
    reranked = []
    for score, doc in top_docs:
        doc.metadata["rerank_score"] = float(score)
        reranked.append(doc)
    return reranked
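
# Illustrative usage of the retrieval flow:
#   docs = hybrid_search("How does a Deployment roll back?")
#   for d in docs:
#       print(d.metadata["doc_id"], d.metadata["rerank_score"])
# Each returned Document carries its cross-encoder score in metadata, which
# build_context_with_citations() below surfaces next to the citation label.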
# -------------------- LLM CALL (OpenRouter) -------------------- #
def call_llm(prompt: str) -> str:
    url = "https://openrouter.ai/api/v1/chat/completions"
    api_key = os.getenv("OPENROUTER_API_KEY")
    if not api_key:
        print("[ERROR] OPENROUTER_API_KEY not set.")
        return (
            "⚠️ Model failed: missing OPENROUTER_API_KEY environment variable.\n\n"
            "Groundedness: 0%"
        )
    headers = {
        "Authorization": f"Bearer {api_key}",
        "HTTP-Referer": "https://huggingface.co/",
        "X-Title": "Kubernetes RAG Assistant",
    }
    data = {
        "model": "meta-llama/llama-3.1-8b-instruct",
        "messages": [{"role": "user", "content": prompt}],
        "max_tokens": 400,
        "temperature": 0.0,
    }
    try:
        r = requests.post(url, headers=headers, json=data, timeout=60)
        r.raise_for_status()
        res = r.json()
    except Exception as e:
        print(f"[ERROR] LLM call failed: {e}")
        return "⚠️ Model failed. Please retry.\n\nGroundedness: 0%"
    if "choices" in res and res["choices"]:
        return res["choices"][0]["message"]["content"]
    print("[ERROR] Unexpected LLM response:", res)
    return "⚠️ Model failed. Please retry.\n\nGroundedness: 0%"
# -------------------- CONTEXT + CITATIONS -------------------- #
def build_context_with_citations(query: str, history=None, top_k: int = 5):
    """
    Use hybrid search plus conversation-aware follow-up handling.
    """
    effective_query = query
    if history:
        last_user_q = history[-1][0] if history[-1] else ""
        followup_tokens = {
            "and", "also", "that", "those", "it", "them", "one",
            "this", "these", "more",
        }
        query_tokens = set(query.lower().split())
        # Match whole tokens (not substrings) so words like "command" do not
        # trip the "and" heuristic; "what about" is checked as a phrase.
        is_followup = (
            len(query.split()) <= 4
            or query_tokens & followup_tokens
            or "what about" in query.lower()
        )
        if is_followup:
            effective_query = f"{last_user_q} | Follow-up: {query}"

    docs = hybrid_search(effective_query, top_k=top_k)
    if not docs:
        return "", [], [], []

    context = ""
    sources = []
    scores = []
    doc_ids = []
    for i, d in enumerate(docs, start=1):
        label = f"[{i}]"
        snippet = d.page_content[:900].strip()
        url = d.metadata.get("url", "N/A")
        score = float(d.metadata.get("rerank_score", 0.0))
        context += (
            f"{label} (score={score:.2f})\n"
            f"{snippet}\n"
            f"Source: {url}\n\n"
        )
        sources.append(f"{label} → {url}")
        scores.append(score)
        doc_ids.append(d.metadata.get("doc_id", "k8s-doc"))
    return context, sources, scores, doc_ids
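
# Each context entry looks roughly like this (illustrative values):
#   [1] (score=7.42)
#   <first 900 characters of the chunk>
#   Source: https://kubernetes.io/docs/concepts/workloads/pods/
# The [n] labels are what the prompt instructs the model to cite, and the
# same labels appear in the "Sources:" footer appended to each answer.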
# -------------------- QUERY CLASSIFIER -------------------- #
def classify_query(query: str) -> str:
    q = query.lower()
    if any(q.startswith(p) for p in ["what is", "define", "explain"]):
        return "definition"
    if any(k in q for k in ["how to", "how do i", "steps", "tutorial"]):
        return "how-to"
    if any(k in q for k in ["error", "failed", "crash", "issue", "troubleshoot"]):
        return "debugging"
    if any(k in q for k in ["best practice", "recommend", "should i"]):
        return "best-practice"
    return "general"
# -------------------- ANALYTICS STORAGE -------------------- #
def init_analytics():
    return {
        "queries": [],
        "latency": [],
        "approx_tokens": [],
        "groundedness": [],
        "avg_rerank_score": [],
        "citation_count": [],
        "query_type": [],
    }
# -------------------- MAIN ANSWER FUNCTION -------------------- #
def answer_question(query, history, analytics):
    if analytics is None or analytics == {}:
        analytics = init_analytics()
    start_time = time.time()

    context, sources, scores, doc_ids = build_context_with_citations(query, history)

    # Retrieval failure – safe response
    if not context:
        resp = (
            "Not in documentation or insufficient context to answer confidently.\n\n"
            "Possible reasons:\n"
            "- The question is too vague or missing key details.\n"
            "- The topic may not be covered in the scraped Kubernetes docs.\n\n"
            "Try rephrasing with more detail.\n\n"
            "Groundedness: 0%"
        )
        latency = time.time() - start_time
        analytics["queries"].append(query)
        analytics["latency"].append(latency)
        analytics["approx_tokens"].append(len(resp.split()))
        analytics["groundedness"].append(0)
        analytics["avg_rerank_score"].append(0.0)
        analytics["citation_count"].append(0)
        analytics["query_type"].append(classify_query(query))
        history.append((query, resp))
        return history, "", analytics

    # Recent conversation context (not for citations)
    conversation_context = ""
    if history:
        last_turns = history[-3:]
        for uq, aq in last_turns:
            conversation_context += f"User: {uq}\nAssistant: {aq}\n\n"

    prompt = f"""
You are a strict Kubernetes documentation assistant.

RULES:
- Answer ONLY using the Context section.
- EVERY sentence must end with at least one citation like [1] or [2].
- If the answer is not found in the context, respond exactly:
  "Not in documentation: Please rephrase or check the official Kubernetes docs."
- Do NOT invent APIs, flags, YAML fields, or behaviors not shown in the context.
- Use short, precise sentences.
- At the END, output a separate line: Groundedness: XX%
  - XX is an integer from 0 to 100.
  - 100 means every statement is directly and clearly supported.
  - Lower if you are uncertain or context is thin.

User Question:
{query}

Recent Conversation (for context, not citations):
{conversation_context}

Context (with source ids and rerank scores):
{context}
"""

    answer = call_llm(prompt)
    latency = time.time() - start_time
    approx_tokens = len(prompt.split()) + len(answer.split())

    groundedness_match = re.search(r"Groundedness:\s*(\d+)%", answer)
    groundedness = int(groundedness_match.group(1)) if groundedness_match else 0

    citation_matches = re.findall(r"\[(\d+)\]", answer)
    unique_citations = set(citation_matches)
    citation_count = len(unique_citations)

    avg_rerank_score = sum(scores) / len(scores) if scores else 0.0

    # Low groundedness / no citations alert
    alert = ""
    if groundedness < 70 or citation_count == 0:
        alert = (
            "⚠️ Warning: This response may not be fully supported by the retrieved Kubernetes documentation.\n"
            "Consider rephrasing your question with more specific details, or verifying in the official docs.\n\n"
        )

    final_answer = alert + answer + "\n\n---\nSources:\n" + "\n".join(sources)
    history.append((query, final_answer))

    analytics["queries"].append(query)
    analytics["latency"].append(latency)
    analytics["approx_tokens"].append(approx_tokens)
    analytics["groundedness"].append(groundedness)
    analytics["avg_rerank_score"].append(avg_rerank_score)
    analytics["citation_count"].append(citation_count)
    analytics["query_type"].append(classify_query(query))
    return history, "", analytics
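
# Caveat: "Groundedness" is the model's self-reported estimate parsed out of
# its own answer, not an independently computed faithfulness metric; treat
# the dashboard trend as a heuristic signal rather than ground truth.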
# -------------------- ANALYTICS RENDERING -------------------- #
def render_analytics(analytics):
    if not analytics or len(analytics["queries"]) == 0:
        return [], 0.0, 0.0, 0.0
    rows = []
    for i, q in enumerate(analytics["queries"]):
        rows.append([
            i + 1,
            q,
            round(analytics["latency"][i], 3),
            analytics["approx_tokens"][i],
            analytics["groundedness"][i],
            round(analytics["avg_rerank_score"][i], 3),
            analytics["citation_count"][i],
            analytics["query_type"][i],
        ])
    avg_latency = sum(analytics["latency"]) / len(analytics["latency"])
    avg_grounded = sum(analytics["groundedness"]) / len(analytics["groundedness"])
    avg_tokens = sum(analytics["approx_tokens"]) / len(analytics["approx_tokens"])
    return rows, avg_latency, avg_grounded, avg_tokens
def generate_charts(analytics):
    if not analytics or len(analytics["queries"]) == 0:
        return None, None, None, None
    df = pd.DataFrame({
        "Latency": analytics["latency"],
        "Groundedness": analytics["groundedness"],
        "Tokens": analytics["approx_tokens"],
        "Query Type": analytics["query_type"],
    })

    # Latency chart
    fig_latency, ax1 = plt.subplots()
    ax1.plot(df["Latency"])
    ax1.set_title("Latency Over Time")
    ax1.set_xlabel("Query #")
    ax1.set_ylabel("Seconds")

    # Groundedness chart
    fig_ground, ax2 = plt.subplots()
    ax2.plot(df["Groundedness"])
    ax2.set_title("Groundedness Trend")
    ax2.set_xlabel("Query #")
    ax2.set_ylabel("Groundedness (%)")

    # Token usage chart
    fig_tokens, ax3 = plt.subplots()
    ax3.plot(df["Tokens"])
    ax3.set_title("Token Usage Over Time")
    ax3.set_xlabel("Query #")
    ax3.set_ylabel("Approx Tokens")

    # Query type distribution pie chart
    fig_pie, ax4 = plt.subplots()
    df["Query Type"].value_counts().plot.pie(
        ax=ax4,
        autopct="%1.1f%%",
    )
    ax4.set_ylabel("")
    ax4.set_title("Query Types Distribution")

    return fig_latency, fig_ground, fig_tokens, fig_pie
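
# Note: each refresh allocates four new Figure objects, and the Agg backend
# keeps them open until closed. For a long-running Space, calling
# plt.close("all") at the start of this function (before the new figures are
# created) is a safe way to bound matplotlib's memory use.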
def export_csv(analytics):
    if not analytics or len(analytics["queries"]) == 0:
        path = "analytics.csv"
        pd.DataFrame(columns=[
            "query", "latency", "approx_tokens", "groundedness",
            "avg_rerank_score", "citation_count", "query_type"
        ]).to_csv(path, index=False)
        return path
    df = pd.DataFrame({
        "query": analytics["queries"],
        "latency": analytics["latency"],
        "approx_tokens": analytics["approx_tokens"],
        "groundedness": analytics["groundedness"],
        "avg_rerank_score": analytics["avg_rerank_score"],
        "citation_count": analytics["citation_count"],
        "query_type": analytics["query_type"],
    })
    path = "analytics.csv"
    df.to_csv(path, index=False)
    return path


def clear_all():
    return [], "", init_analytics()
# -------------------- GRADIO UI -------------------- #
custom_css = """
.source-box {
    background: #1e293b;
    color: #dbeafe;
    padding: 10px;
    border-radius: 7px;
    border: 1px solid #3b82f6;
}
"""

with gr.Blocks(theme="soft") as app:
    gr.HTML(f"<style>{custom_css}</style>")
    gr.HTML(
        "<h1 style='text-align:center;color:#3b82f6'>☸ Kubernetes RAG Assistant</h1>"
        "<p style='text-align:center;color:#cbd5e1'>Hybrid Search • Reranked • Cited • Monitored 📌</p>"
    )

    analytics_state = gr.State(init_analytics())

    with gr.Tab("Chatbot"):
        chat = gr.Chatbot(label="Conversation", height=450)
        msg = gr.Textbox(
            label="Ask anything about Kubernetes…",
            placeholder="e.g., What is RBAC?",
        )
        clear = gr.Button("Clear Conversation")
        msg.submit(
            answer_question,
            inputs=[msg, chat, analytics_state],
            outputs=[chat, msg, analytics_state],
        )
        clear.click(
            clear_all,
            inputs=None,
            outputs=[chat, msg, analytics_state],
        )

    with gr.Tab("Analytics Dashboard"):
        gr.Markdown("### 📊 System Metrics")
        gr.Markdown(
            "- Each row is a user query\n"
            "- Latency = retrieval + LLM time\n"
            "- Groundedness = model-reported confidence based on docs\n"
            "- Rerank score = cross-encoder relevance\n"
            "- Citation count = number of unique [n] labels used in the answer"
        )
        analytics_table = gr.Dataframe(
            headers=[
                "ID",
                "Query",
                "Latency (s)",
                "Approx Tokens",
                "Groundedness (%)",
                "Avg Rerank Score",
                "Citations Used",
                "Query Type",
            ],
            row_count=0,
            col_count=8,
            interactive=False,
            label="Query Stats",
        )
        avg_latency_box = gr.Number(label="Average Latency (s)", precision=3)
        avg_ground_box = gr.Number(label="Average Groundedness (%)", precision=1)
        avg_tokens_box = gr.Number(label="Average Tokens per Answer", precision=1)

        plot_latency = gr.Plot(label="Latency Trend")
        plot_ground = gr.Plot(label="Groundedness Trend")
        plot_tokens = gr.Plot(label="Token Usage Trend")
        plot_pie = gr.Plot(label="Query Types Distribution")

        refresh_btn = gr.Button("Refresh Analytics")
        export_btn = gr.Button("Export Analytics as CSV")
        file_out = gr.File(label="Download CSV")

        # Refresh metrics table + summary
        refresh_btn.click(
            render_analytics,
            inputs=[analytics_state],
            outputs=[
                analytics_table,
                avg_latency_box,
                avg_ground_box,
                avg_tokens_box,
            ],
        )
        # Refresh charts
        refresh_btn.click(
            generate_charts,
            inputs=[analytics_state],
            outputs=[plot_latency, plot_ground, plot_tokens, plot_pie],
        )
        # Export CSV
        export_btn.click(
            export_csv,
            inputs=[analytics_state],
            outputs=[file_out],
        )

if __name__ == "__main__":
    app.launch()