# app.py — Self-Service Knowledge Base Assistant (Hugging Face Space)
# NOTE: removed non-Python web-page residue captured during export:
#   "sofzcc's picture / Update app.py / bf0ef35 verified / raw / history blame / 9.21 kB"
import os
import re
import json
from pathlib import Path
from typing import List, Dict, Tuple
import numpy as np
import faiss
import gradio as gr
from transformers import pipeline, AutoTokenizer, AutoModelForQuestionAnswering
from sentence_transformers import SentenceTransformer
# ----------- Paths -----------
KB_DIR = Path("./kb")          # source markdown knowledge-base articles
INDEX_DIR = Path("./.index")   # cache dir for embeddings / FAISS index / metadata
INDEX_DIR.mkdir(exist_ok=True, parents=True)
# ----------- Models (free) -----------
# Bi-encoder used to embed KB chunks and queries for vector retrieval.
EMBEDDING_MODEL_NAME = "sentence-transformers/all-MiniLM-L6-v2"
# Extractive QA model that pulls an answer span out of a retrieved chunk.
READER_MODEL_NAME = "deepset/roberta-base-squad2"
# On-disk cache artifacts; presence of all three lets KBIndex.load() skip a rebuild.
EMBEDDINGS_PATH = INDEX_DIR / "kb_embeddings.npy"
METADATA_PATH = INDEX_DIR / "kb_metadata.json"
FAISS_PATH = INDEX_DIR / "kb_faiss.index"
# Matches any markdown heading level (# .. ######); group(2) is the heading text.
HEADING_RE = re.compile(r"^(#{1,6})\s+(.*)$", re.MULTILINE)
# ----------- Load Markdown -----------
def read_markdown_files(kb_dir: Path) -> List[Dict]:
    """Load every ``*.md`` file directly under *kb_dir*.

    Returns a list of dicts with keys ``filepath``, ``filename``, ``title``
    and ``text``, ordered by filename. The title is taken from the first
    level-1 markdown heading when present, otherwise derived from the
    file stem (underscores → spaces, title-cased).
    """
    loaded: List[Dict] = []
    for path in sorted(kb_dir.glob("*.md")):
        body = path.read_text(encoding="utf-8", errors="ignore")
        heading = re.search(r"^#\s+(.*)$", body, flags=re.MULTILINE)
        if heading:
            doc_title = heading.group(1).strip()
        else:
            doc_title = path.stem.replace("_", " ").title()
        loaded.append(
            {
                "filepath": str(path),
                "filename": path.name,
                "title": doc_title,
                "text": body,
            }
        )
    return loaded
def chunk_markdown(doc: Dict, chunk_chars: int = 1200, overlap: int = 150) -> List[Dict]:
    """Split one loaded markdown document into overlapping character chunks.

    The text is first split at ``##`` / ``###`` headings (a zero-width
    lookahead keeps each heading attached to its own section), then each
    section is sliced into windows of at most *chunk_chars* characters with
    *overlap* characters of carry-over between consecutive windows.

    Each chunk dict carries ``doc_title``, ``filename``, ``filepath``,
    ``section`` (nearest heading text, falling back to the document title)
    and ``content``.

    Fixes vs. the original:
    - removed the ``if len(sections) == 1: sections = [text]`` branch — a
      no-op, since ``re.split`` already returns ``[text]`` when nothing matches;
    - the sliding window now always advances: the original
      ``start = max(0, end - overlap)`` looped forever when
      ``overlap >= chunk_chars``.
    """
    # Local compile keeps the function self-contained (same pattern as HEADING_RE).
    heading_re = re.compile(r"^(#{1,6})\s+(.*)$", re.MULTILINE)
    text = doc["text"]
    sections = re.split(r"(?=^##\s+|\n##\s+|\n###\s+|^###\s+)", text, flags=re.MULTILINE)
    chunks: List[Dict] = []
    for sec in sections:
        sec = sec.strip()
        if not sec:
            continue
        heading_match = heading_re.search(sec)
        section_heading = heading_match.group(2).strip() if heading_match else doc["title"]
        start = 0
        while start < len(sec):
            end = min(start + chunk_chars, len(sec))
            chunk_text = sec[start:end].strip()
            if chunk_text:
                chunks.append({
                    "doc_title": doc["title"],
                    "filename": doc["filename"],
                    "filepath": doc["filepath"],
                    "section": section_heading,
                    "content": chunk_text,
                })
            if end == len(sec):
                break
            # BUG FIX: guarantee forward progress even when overlap >= chunk_chars.
            next_start = end - overlap
            start = next_start if next_start > start else end
    return chunks
# ----------- KB Index -----------
class KBIndex:
    """Vector index over KB chunks plus an extractive QA reader.

    Pipeline: markdown files → chunks → MiniLM embeddings → FAISS inner-product
    index (L2-normalized vectors, i.e. cosine similarity). Queries retrieve the
    nearest chunks; a RoBERTa SQuAD2 reader then extracts an answer span.
    Artifacts are cached under INDEX_DIR so Space restarts can skip embedding.
    """

    def __init__(self):
        # Model downloads happen here — construction is slow and needs network
        # on first run (weights are cached by the HF libraries afterwards).
        self.embedder = SentenceTransformer(EMBEDDING_MODEL_NAME)
        self.reader_tokenizer = AutoTokenizer.from_pretrained(READER_MODEL_NAME)
        self.reader_model = AutoModelForQuestionAnswering.from_pretrained(READER_MODEL_NAME)
        self.reader = pipeline("question-answering", model=self.reader_model, tokenizer=self.reader_tokenizer)
        self.index = None        # faiss.IndexFlatIP once built/loaded
        self.embeddings = None   # np.ndarray (n_chunks, dim), L2-normalized
        self.metadata = []       # chunk dicts, row-aligned with the index

    def build(self, kb_dir: Path):
        """Embed and index every markdown file in *kb_dir*; persist to disk.

        Raises RuntimeError when the KB directory yields no files or no chunks.
        """
        docs = read_markdown_files(kb_dir)
        if not docs:
            raise RuntimeError(f"No markdown files found in {kb_dir.resolve()}")
        all_chunks = []
        for d in docs:
            all_chunks.extend(chunk_markdown(d))
        texts = [c["content"] for c in all_chunks]
        if not texts:
            raise RuntimeError("No content chunks generated from KB.")
        embeddings = self.embedder.encode(texts, batch_size=32, convert_to_numpy=True, show_progress_bar=False)
        # Normalize so inner product == cosine similarity.
        faiss.normalize_L2(embeddings)
        dim = embeddings.shape[1]
        index = faiss.IndexFlatIP(dim)
        index.add(embeddings)
        self.index = index
        self.embeddings = embeddings
        self.metadata = all_chunks
        # Persist all three artifacts; load() requires the full set.
        np.save(EMBEDDINGS_PATH, embeddings)
        with open(METADATA_PATH, "w", encoding="utf-8") as f:
            json.dump(self.metadata, f, ensure_ascii=False, indent=2)
        faiss.write_index(index, str(FAISS_PATH))

    def load(self):
        """Restore cached artifacts. Returns True on success, False if any is missing."""
        if not (EMBEDDINGS_PATH.exists() and METADATA_PATH.exists() and FAISS_PATH.exists()):
            return False
        self.embeddings = np.load(EMBEDDINGS_PATH)
        with open(METADATA_PATH, "r", encoding="utf-8") as f:
            self.metadata = json.load(f)
        self.index = faiss.read_index(str(FAISS_PATH))
        return True

    def retrieve(self, query: str, top_k: int = 4) -> List[Tuple[int, float]]:
        """Return up to *top_k* ``(chunk_index, cosine_similarity)`` pairs, best first."""
        q_emb = self.embedder.encode([query], convert_to_numpy=True)
        faiss.normalize_L2(q_emb)
        D, I = self.index.search(q_emb, top_k)
        # BUG FIX: FAISS pads results with id -1 when the index holds fewer
        # than top_k vectors; the original passed -1 through and callers then
        # read self.metadata[-1] (silently the *last* chunk). Filter them out.
        return [(int(idx), float(sim)) for idx, sim in zip(I[0], D[0]) if idx != -1]

    def answer(self, question: str, retrieved: List[Tuple[int, float]]):
        """Run the QA reader over each retrieved chunk; keep the highest-scoring span.

        Returns ``(answer_text, reader_score, citations)``; ``(None, 0.0, [])``
        when no reader call produced a usable span. Citations come from the top
        two retrieved chunks, de-duplicated by (filename, section).
        """
        best = {"text": None, "score": -1e9, "meta": None, "sim": 0.0}
        for idx, sim in retrieved:
            meta = self.metadata[idx]
            ctx = meta["content"]
            try:
                out = self.reader(question=question, context=ctx)
            except Exception:
                # Best-effort: a reader failure on one chunk shouldn't kill the query.
                continue
            score = float(out.get("score", 0.0))
            if score > best["score"]:
                best = {"text": out.get("answer", "").strip(), "score": score, "meta": meta, "sim": float(sim)}
        if not best["text"]:
            return None, 0.0, []
        citations = []
        seen = set()
        for idx, _ in retrieved[:2]:
            m = self.metadata[idx]
            key = (m["filename"], m["section"])
            if key in seen:
                continue
            seen.add(key)
            citations.append({"title": m["doc_title"], "filename": m["filename"], "section": m["section"]})
        return best["text"], best["score"], citations
# Module-level singleton: constructing it downloads/loads both models at import time.
kb = KBIndex()
def ensure_index():
    """Load the cached index from INDEX_DIR; rebuild from KB_DIR if absent."""
    # Build on first run in Space; load if cached
    if not kb.load():
        kb.build(KB_DIR)
# Import-time side effect: the Space is ready to answer as soon as the UI starts.
ensure_index()
# ----------- Guardrails -----------
# Below these thresholds the bot hedges its answer and offers escalation.
LOW_CONF_THRESHOLD = 0.20  # minimum QA-reader span score
LOW_SIM_THRESHOLD = 0.30   # minimum best retrieval cosine similarity
# (button label, canned query sent to the retriever) pairs for quick actions.
HELPFUL_SUGGESTIONS = [
("Connect WhatsApp", "How do I connect my WhatsApp number?"),
("Reset Password", "I can't sign in / forgot my password"),
("First Automation", "How do I create my first automation?"),
("Billing & Invoices", "How do I download invoices for billing?"),
("Fix Instagram Connect", "Why can't I connect Instagram?")
]
def format_citations(citations: List[Dict]) -> str:
    """Render citation dicts as a markdown bullet list; empty string when none."""
    if not citations:
        return ""
    bullets = []
    for cite in citations:
        bullets.append(f"• **{cite['title']}** — _{cite['section']}_ (`{cite['filename']}`)")
    return "\n".join(bullets)
def respond(user_msg, history):
    """Chat handler: retrieve KB chunks, extract an answer span, attach citations.

    Hedges the reply when the reader score or best retrieval similarity falls
    below the module guardrail thresholds. *history* is accepted for the
    Gradio chat signature but not used.
    """
    user_msg = (user_msg or "").strip()
    if not user_msg:
        return "How can I help? Try: **Connect WhatsApp** or **Reset password**."
    hits = kb.retrieve(user_msg, top_k=4)
    if not hits:
        return "I couldn’t find anything yet. Try rephrasing or pick a quick action below."
    span, score, citations = kb.answer(user_msg, hits)
    if not span:
        closest = "\n".join(f"- {c['title']} — _{c['section']}_" for c in citations) or "- Try a different query."
        return f"I’m not fully sure. Here are the closest matches:\n\n{closest}"
    best_sim = max([s for _, s in hits]) if hits else 0.0
    low_conf = (score < LOW_CONF_THRESHOLD) or (best_sim < LOW_SIM_THRESHOLD)
    citations_md = format_citations(citations)
    if len(span) > 3:
        base_answer = span
    else:
        base_answer = "I found a relevant section. Opening the steps in the cited article."
    if low_conf:
        return (
            f"{base_answer}\n\n---\n**I may be uncertain.** Here are relevant articles:\n{citations_md}\n\n"
            f"If this doesn’t solve it, ask me to *escalate to human support*."
        )
    return f"{base_answer}\n\n---\n**Sources:**\n{citations_md}\n_Tip: Say **show full steps** for more detail._"
def quick_intent(label):
    """Map a quick-action button label to its canned query; "" when unknown."""
    matches = [query for name, query in HELPFUL_SUGGESTIONS if name == label]
    return matches[0] if matches else ""
def rebuild_index():
    """Admin hook: re-embed and re-index everything under KB_DIR, then report status."""
    kb.build(KB_DIR)
    # gr.update feeds the status Markdown component wired in the Admin accordion.
    return gr.update(value="✅ Index rebuilt from KB.")
# ----------- Gradio UI -----------
with gr.Blocks(title="Self-Service KB Assistant", fill_height=True) as demo:
    gr.Markdown(
        """
# 🧩 Self-Service Knowledge Assistant
Ask about setup, automations, billing, or troubleshooting.
The assistant answers **only from the Knowledge Base** and cites the articles it used.
**Quick actions:** Try one of the buttons below.
"""
    )
    with gr.Row():
        # One button per quick-action suggestion.
        outputs = []
        for label, _ in HELPFUL_SUGGESTIONS:
            btn = gr.Button(label)
            # NOTE(review): this wiring looks broken — the click handler's result
            # is discarded (outputs=None) and .then() targets a brand-new
            # gr.Chatbot() that is never placed in the layout, so the response
            # has nowhere visible to go. Confirm intent / rewire to `chat` below.
            btn.click(fn=lambda L=label: quick_intent(L), outputs=None)\
                .then(fn=respond, inputs=[gr.State(quick_intent(label)), gr.State([])], outputs=gr.Chatbot())
            outputs.append(btn)
    chat = gr.ChatInterface(
        fn=respond,
        chatbot=gr.Chatbot(height=420, show_copy_button=True),
        textbox=gr.Textbox(placeholder="e.g., How do I connect WhatsApp?"),
        # NOTE(review): retry_btn/undo_btn/clear_btn were removed from
        # gr.ChatInterface in Gradio 4.x — verify the Space pins a gradio
        # version that still accepts these keyword arguments.
        retry_btn="Retry",
        undo_btn="Undo",
        clear_btn="Clear",
    )
    with gr.Accordion("Admin", open=False):
        gr.Markdown("Rebuild the search index after changing files in `/kb`.")
        rebuild = gr.Button("Rebuild Index")
        status = gr.Markdown("")
        rebuild.click(fn=rebuild_index, outputs=status)
if __name__ == "__main__":
    demo.launch()