# app.py (7.85 kB) - last commit "Update app.py" by Prakyath01 (de72d5d, verified).
# Repository-viewer links that accompanied this listing: raw, history, blame.
import os
import re
import time
import requests
import pandas as pd
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
import gradio as gr
from bs4 import BeautifulSoup
from langchain_core.documents import Document
from langchain_text_splitters import RecursiveCharacterTextSplitter
from langchain_huggingface import HuggingFaceEmbeddings
from langchain_community.vectorstores import Chroma
from rank_bm25 import BM25Okapi
from sentence_transformers import CrossEncoder
# Directory handed to Chroma as its persist_directory; build_or_load_kb checks
# for its existence to decide between loading and rebuilding the index.
PERSIST_DIR = "k8s_chroma_db"
# Kubernetes documentation pages to index, keyed by a short topic name.
# build_or_load_kb passes every (name, url) pair to scrape_page, which stores
# the name as "doc_id" and the address as "url" in the Document metadata.
URLS = {
"pods": "https://kubernetes.io/docs/concepts/workloads/pods/",
"deployments": "https://kubernetes.io/docs/concepts/workloads/controllers/deployment/",
"services": "https://kubernetes.io/docs/concepts/services-networking/service/",
"namespaces": "https://kubernetes.io/docs/concepts/overview/working-with-objects/namespaces/",
"nodes": "https://kubernetes.io/docs/concepts/architecture/nodes/",
"statefulsets": "https://kubernetes.io/docs/concepts/workloads/controllers/statefulset/",
"rbac": "https://kubernetes.io/docs/reference/access-authn-authz/rbac/",
"persistent-volumes": "https://kubernetes.io/docs/concepts/storage/persistent-volumes/",
"ingress": "https://kubernetes.io/docs/concepts/services-networking/ingress/",
"autoscaling": "https://kubernetes.io/docs/tasks/run-application/horizontal-pod-autoscale/",
}
def scrape_page(name, url):
    """Download one documentation page and wrap its main content in a Document.

    Args:
        name: Short topic identifier, stored as ``doc_id`` in the metadata.
        url: Address of the page to fetch; stored as ``url`` in the metadata.

    Returns:
        A ``Document`` holding the text of the page's ``div.td-content``
        element, or ``None`` when the request fails, the server answers with
        an HTTP error status, or the page contains no such element.
    """
    try:
        response = requests.get(url, timeout=20)
        # Without this check a 404/500 error page would be parsed and could
        # end up indexed as if it were documentation.
        response.raise_for_status()
    except requests.RequestException:
        # Best effort: the caller simply skips pages that yield None.
        return None

    soup = BeautifulSoup(response.text, "html.parser")
    content = soup.find("div", class_="td-content")
    if not content:
        return None

    text = content.get_text(separator="\n").strip()
    return Document(page_content=text, metadata={"doc_id": name, "url": url})
def build_or_load_kb():
    """Return the vector store and its text chunks, building them if needed.

    When ``PERSIST_DIR`` exists and already holds documents, the stored
    collection is reopened and its documents are turned back into
    ``Document`` objects. Otherwise every page in ``URLS`` is scraped, split
    into overlapping chunks and embedded into a new persisted collection.

    Returns:
        A ``(vectordb, chunks)`` tuple: the Chroma store and the list of
        ``Document`` chunks it contains.

    Raises:
        RuntimeError: If no page could be scraped (or the scraped pages
            produced no chunks), so there is nothing to index.
    """
    embedding_model = HuggingFaceEmbeddings(
        model_name="sentence-transformers/all-MiniLM-L6-v2"
    )

    if os.path.isdir(PERSIST_DIR):
        vectordb = Chroma(
            embedding_function=embedding_model,
            persist_directory=PERSIST_DIR,
        )
        # Public accessor instead of reaching into the private _collection.
        data = vectordb.get(include=["documents", "metadatas"])
        chunks = [
            Document(page_content=text, metadata=meta or {})
            for text, meta in zip(data["documents"], data["metadatas"])
        ]
        if chunks:
            return vectordb, chunks
        # The directory exists but holds no documents (e.g. an earlier build
        # failed part-way). Fall through and rebuild rather than returning an
        # empty knowledge base.

    scraped = (scrape_page(name, url) for name, url in URLS.items())
    docs = [doc for doc in scraped if doc]
    if not docs:
        raise RuntimeError(
            "Could not scrape any Kubernetes documentation page; "
            "the knowledge base cannot be built."
        )

    splitter = RecursiveCharacterTextSplitter(chunk_size=900, chunk_overlap=200)
    chunks = splitter.split_documents(docs)
    if not chunks:
        raise RuntimeError(
            "Scraped pages produced no text chunks; "
            "the knowledge base cannot be built."
        )

    vectordb = Chroma.from_documents(
        chunks,
        embedding_model,
        persist_directory=PERSIST_DIR,
    )
    return vectordb, chunks
# Shared retrieval components, created once when the module is loaded.
vectordb, chunks = build_or_load_kb()

# hybrid_search lower-cases the query before tokenizing it, so the corpus must
# be lower-cased the same way; otherwise capitalized terms in the docs
# ("Pod", "Deployment", ...) can never match a query token.
bm25_corpus = [doc.page_content.lower().split() for doc in chunks]
bm25 = BM25Okapi(bm25_corpus)

# Cross-encoder used by hybrid_search to re-score (query, chunk) pairs.
reranker = CrossEncoder("cross-encoder/ms-marco-MiniLM-L-12-v2")

# Dense retriever: up to 8 chunks whose similarity score passes the threshold.
retriever = vectordb.as_retriever(
    search_type="similarity_score_threshold",
    search_kwargs={"k": 8, "score_threshold": 0.4},
)
def hybrid_search(query):
    """Retrieve the chunks most relevant to ``query`` using dense + BM25 search.

    Candidates come from the vector retriever and from the top five BM25
    matches; duplicates are removed and the remainder is re-scored with the
    cross-encoder.

    Args:
        query: The user's question as plain text.

    Returns:
        Up to five ``Document`` objects ordered from most to least relevant.
        Each one is a fresh copy whose metadata carries an extra
        ``rerank_score`` float. An empty list is returned when neither
        retriever finds a candidate.
    """
    dense_hits = retriever.invoke(query)

    bm_scores = bm25.get_scores(query.lower().split())
    bm_ranked = sorted(
        zip(bm_scores, chunks), key=lambda pair: pair[0], reverse=True
    )
    # A score of zero means the chunk shares no term with the query; keeping
    # such chunks would only pad the candidate list with arbitrary text.
    sparse_hits = [doc for score, doc in bm_ranked[:5] if score > 0]

    seen = set()
    unique = []
    for doc in dense_hits + sparse_hits:
        # Compare the whole text: two different chunks of the same page may
        # start with the same few characters.
        key = (doc.metadata.get("doc_id"), doc.page_content)
        if key not in seen:
            seen.add(key)
            unique.append(doc)

    if not unique:
        return []

    scores = reranker.predict([(query, doc.page_content) for doc in unique])
    ranked = sorted(
        zip(scores, unique), key=lambda pair: pair[0], reverse=True
    )[:5]

    # Return copies instead of writing the score into the candidates: the BM25
    # candidates are the very objects stored in the module-level `chunks`
    # list, so mutating them would leak one request's scores into another.
    return [
        Document(
            page_content=doc.page_content,
            metadata={**doc.metadata, "rerank_score": float(score)},
        )
        for score, doc in ranked
    ]
def call_llm(prompt):
    """Send ``prompt`` to the OpenRouter chat-completions API and return the reply.

    The API key is read from the ``OPENROUTER_API_KEY`` environment variable.

    Args:
        prompt: Full prompt text, sent as a single user message.

    Returns:
        The model's reply text. This function does not raise: when the key is
        missing or the request fails, it returns a warning message that ends
        with ``Groundedness: 0%`` so callers can parse it like a normal reply.
    """
    api_key = os.getenv("OPENROUTER_API_KEY")
    if not api_key:
        return "⚠️ Missing OPENROUTER_API_KEY environment variable.\nGroundedness: 0%"

    try:
        res = requests.post(
            "https://openrouter.ai/api/v1/chat/completions",
            headers={
                "Authorization": f"Bearer {api_key}",
                "HTTP-Referer": "https://huggingface.co/",
                "X-Title": "Kubernetes RAG Assistant",
            },
            json={
                "model": "meta-llama/llama-3.1-8b-instruct",
                "messages": [{"role": "user", "content": prompt}],
                "max_tokens": 300,
                "temperature": 0.2,
            },
            # requests waits indefinitely by default; bound the wait so a
            # stalled API cannot hang the UI handler forever.
            timeout=60,
        )
        res.raise_for_status()
        return res.json()["choices"][0]["message"]["content"]
    except Exception as e:
        # Deliberate boundary: network errors, HTTP errors and unexpected
        # response shapes are all reported to the user as text.
        return f"⚠️ LLM Error: {e}\nGroundedness: 0%"
def build_context(query):
    """Assemble the numbered context block passed to the LLM for ``query``.

    Args:
        query: The user's question.

    Returns:
        A ``(context, sources, scores)`` tuple. ``context`` is the
        concatenation of the retrieved chunks (each cut to 900 characters,
        numbered from 1 and followed by its source URL), ``sources`` is a list
        of ``"[n] → url"`` strings and ``scores`` holds each chunk's
        ``rerank_score``. When nothing is retrieved the result is
        ``("", [], [])``.
    """
    hits = hybrid_search(query)
    if not hits:
        return "", [], []

    sections = []
    srcs = []
    scores = []
    for number, hit in enumerate(hits, start=1):
        url = hit.metadata["url"]
        sections.append(f"[{number}] {hit.page_content[:900]}\nSource: {url}\n\n")
        srcs.append(f"[{number}] → {url}")
        scores.append(hit.metadata["rerank_score"])

    return "".join(sections), srcs, scores
def init_metrics():
    """Create an empty metrics store: one independent list per tracked column.

    The columns are the ones ``answer`` appends to for every answered
    question: ``q`` (question), ``lat`` (latency), ``tok`` (token count),
    ``g`` (groundedness), ``cit`` (citation count), ``r`` (average rerank
    score) and ``type`` (question category).
    """
    columns = ("q", "lat", "tok", "g", "cit", "r", "type")
    return {column: [] for column in columns}
def classify(q):
    """Label a question as ``"how-to"``, ``"debug"`` or ``"general"``.

    Args:
        q: The user's question; matching is case-insensitive.

    Returns:
        ``"how-to"`` when the question contains the word "how", otherwise
        ``"debug"`` when it mentions "error", otherwise ``"general"``.
    """
    q = q.lower()
    # Match "how" as a whole word: a plain substring test would also fire on
    # "show", "somehow", "shower", ...
    if re.search(r"\bhow\b", q):
        return "how-to"
    if "error" in q:
        return "debug"
    return "general"
def answer(q, history, metrics):
    """Answer one chat question and record its metrics.

    Args:
        q: The question typed by the user.
        history: Chat history as a list of ``(question, answer)`` tuples;
            ``None`` is treated as an empty history.
        metrics: Per-session metrics store as produced by ``init_metrics``;
            ``None`` creates a fresh one.

    Returns:
        ``(history, "", metrics)``: the updated history, an empty string (bound
        to the input textbox in the UI, which clears it) and the updated
        metrics. A blank question leaves history and metrics untouched. When
        no context is retrieved, a "Not in docs." reply is added to the
        history and no metrics are recorded.
    """
    if metrics is None:
        metrics = init_metrics()
    if history is None:
        history = []
    if not q or not q.strip():
        # Nothing to answer; avoid a pointless retrieval and LLM call.
        return history, "", metrics

    start = time.time()
    ctx, srcs, scores = build_context(q)
    if not ctx:
        history.append((q, "Not in docs.\nGroundedness: 0%"))
        return history, "", metrics

    prompt = f"""Use context ONLY. Cite every sentence as [n].
User question: {q}
Context:
{ctx}
Groundedness MUST appear as: Groundedness: XX%"""
    txt = call_llm(prompt)
    latency = time.time() - start

    # The model may write "Groundedness" without the exact "NN%" form (e.g.
    # "Groundedness: high" or "85.5%"); treat that as 0 instead of crashing on
    # a failed match.
    found = re.search(r"Groundedness:\s*(\d+)%", txt)
    grounded = int(found.group(1)) if found else 0

    tokens = len(txt.split())
    cites = len(set(re.findall(r"\[(\d+)\]", txt)))
    avg = sum(scores) / len(scores)

    final = txt + "\n\nSources:\n" + "\n".join(srcs)
    history.append((q, final))

    metrics["q"].append(q)
    metrics["lat"].append(latency)
    metrics["tok"].append(tokens)
    metrics["g"].append(grounded)
    metrics["cit"].append(cites)
    metrics["r"].append(avg)
    metrics["type"].append(classify(q))
    return history, "", metrics
def render(metrics):
    """Turn the metrics store into dashboard rows and three averages.

    Args:
        metrics: Store with the parallel lists ``q``, ``lat``, ``tok``, ``g``,
            ``r``, ``cit`` and ``type``.

    Returns:
        ``(rows, mean_groundedness, mean_latency, mean_tokens)``. Each row is
        ``[id, question, latency (3 dp), tokens, groundedness, relevance
        (2 dp), citations, type]`` with ids starting at 1. With no recorded
        questions the result is ``([], 0, 0, 0)``.
    """
    questions = metrics["q"]
    if not questions:
        return [], 0, 0, 0

    def mean(values):
        return sum(values) / len(values)

    rows = []
    for idx, question in enumerate(questions):
        rows.append([
            idx + 1,
            question,
            round(metrics["lat"][idx], 3),
            metrics["tok"][idx],
            metrics["g"][idx],
            round(metrics["r"][idx], 2),
            metrics["cit"][idx],
            metrics["type"][idx],
        ])

    mean_groundedness = mean(metrics["g"])
    mean_latency = mean(metrics["lat"])
    mean_tokens = mean(metrics["tok"])
    return rows, mean_groundedness, mean_latency, mean_tokens
with gr.Blocks(title="Kubernetes RAG Assistant") as app:
    # The State component is created inside the Blocks context so that it
    # belongs to this app; it is used as an input/output of the events below.
    metrics_state = gr.State(init_metrics())

    gr.Markdown("# ☸ Kubernetes RAG Assistant")

    with gr.Tab("Chat"):
        chat = gr.Chatbot()
        inp = gr.Textbox(label="Ask anything about Kubernetes")
        clear = gr.Button("Reset")
        # answer returns (history, "", metrics): refresh the chat, clear the
        # textbox and store the updated metrics.
        inp.submit(answer, [inp, chat, metrics_state], [chat, inp, metrics_state])
        clear.click(
            lambda: ([], "", init_metrics()),
            None,
            [chat, inp, metrics_state],
        )

    with gr.Tab("Analytics"):
        table = gr.DataFrame(
            headers=[
                "ID", "Query", "Latency", "Tokens",
                "Grounded", "Relevance", "Citations", "Type",
            ]
        )
        # NOTE: despite their names, avgL shows groundedness and avgG shows
        # latency; this matches the order of the values returned by render.
        avgL = gr.Number(label="Avg Groundedness")
        avgG = gr.Number(label="Avg Latency")
        avgT = gr.Number(label="Avg Tokens")
        refresh = gr.Button("Update Dashboard")
        refresh.click(render, [metrics_state], [table, avgL, avgG, avgT])

app.launch()