"""Self-Service KB Assistant.

A Gradio chat app that answers questions with retrieval-augmented
generation: KB text files are chunked and embedded with
sentence-transformers, and FLAN-T5 writes an answer grounded in the
retrieved chunks.
"""
import os
import glob
from typing import List, Tuple
import gradio as gr
import numpy as np
from sentence_transformers import SentenceTransformer
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
import torch
# -----------------------------
# CONFIG
# -----------------------------
KB_DIR = "./kb"  # optional: folder with .txt or .md files
EMBEDDING_MODEL_NAME = "sentence-transformers/all-MiniLM-L6-v2"
GEN_MODEL_NAME = "google/flan-t5-base"
TOP_K = 3
CHUNK_SIZE = 500  # characters per chunk
CHUNK_OVERLAP = 100  # characters shared between consecutive chunks


# -----------------------------
# UTILITIES
# -----------------------------
def chunk_text(text: str, chunk_size: int = CHUNK_SIZE, overlap: int = CHUNK_OVERLAP) -> List[str]:
    """Split long text into overlapping chunks so retrieval is more precise.

    Args:
        text: Raw document text; may be empty.
        chunk_size: Maximum characters per chunk; must be positive.
        overlap: Characters shared between consecutive chunks. Clamped so
            the scan always advances by at least one character.

    Returns:
        List of non-empty, stripped chunks covering the whole text.

    Raises:
        ValueError: If chunk_size is not positive.
    """
    if not text:
        return []
    if chunk_size <= 0:
        raise ValueError("chunk_size must be positive")
    # Bug fix: the original advanced by `chunk_size - overlap`, which is
    # <= 0 whenever overlap >= chunk_size, so the while-loop never
    # terminated. Clamp the step to at least 1 (and ignore negative overlap).
    step = max(1, chunk_size - max(0, overlap))
    chunks: List[str] = []
    for start in range(0, len(text), step):
        chunk = text[start:start + chunk_size].strip()
        if chunk:
            chunks.append(chunk)
    return chunks
def load_kb_texts(kb_dir: str = KB_DIR) -> List[Tuple[str, str]]:
    """
    Collect the contents of every .txt and .md file in *kb_dir*.

    Returns a list of (source_name, content) pairs. When the directory is
    missing or contains nothing readable, a single built-in demo document
    is returned instead so the app always has something to index.
    """
    collected: List[Tuple[str, str]] = []
    if os.path.isdir(kb_dir):
        matched = []
        for pattern in ("*.txt", "*.md"):
            matched.extend(glob.glob(os.path.join(kb_dir, pattern)))
        for file_path in matched:
            try:
                with open(file_path, "r", encoding="utf-8") as handle:
                    body = handle.read()
            except Exception as exc:
                # Best-effort: skip unreadable files, keep going.
                print(f"Could not read {file_path}: {exc}")
                continue
            if body.strip():
                collected.append((os.path.basename(file_path), body))
    # Nothing usable on disk -> fall back to built-in demo content.
    if not collected:
        print("No KB files found. Using built-in demo content.")
        demo_text = """
Welcome to the Self-Service KB Assistant.
This assistant is meant to help you find information inside a knowledge base.
In a real setup, it would be connected to your own articles, procedures,
troubleshooting guides and FAQs.
Good knowledge base content is:
- Clear and structured with headings, steps and expected outcomes.
- Written in a customer-friendly tone.
- Easy to scan, with short paragraphs and bullet points.
- Maintained regularly to reflect product and process changes.
Example use cases for a KB assistant:
- Agents quickly searching for internal procedures.
- Customers asking “how do I…” style questions.
- Managers analyzing gaps in documentation based on repeated queries.
"""
        collected.append(("demo_content.txt", demo_text))
    return collected
# -----------------------------
# KB INDEX
# -----------------------------
class KBIndex:
    """In-memory cosine-similarity index over embedded KB chunks."""

    def __init__(self, model_name: str = EMBEDDING_MODEL_NAME):
        print("Loading embedding model...")
        self.model = SentenceTransformer(model_name)
        print("Model loaded.")
        # Parallel structures: chunk text, its source file name, and one
        # embedding row per chunk.
        self.chunks: List[str] = []
        self.chunk_sources: List[str] = []
        self.embeddings: np.ndarray | None = None
        self.build_index()

    def build_index(self):
        """Load KB texts, split into chunks, and build an embedding index."""
        documents = load_kb_texts(KB_DIR)
        collected_chunks: List[str] = []
        collected_sources: List[str] = []
        for doc_name, doc_text in documents:
            pieces = chunk_text(doc_text)
            collected_chunks.extend(pieces)
            collected_sources.extend([doc_name] * len(pieces))
        if not collected_chunks:
            print("⚠️ No chunks found for KB index.")
            self.chunks = []
            self.chunk_sources = []
            self.embeddings = None
            return
        print(f"Creating embeddings for {len(collected_chunks)} chunks...")
        vectors = self.model.encode(collected_chunks, show_progress_bar=False, convert_to_numpy=True)
        self.chunks = collected_chunks
        self.chunk_sources = collected_sources
        self.embeddings = vectors
        print("KB index ready.")

    def search(self, query: str, top_k: int = TOP_K) -> List[Tuple[str, str, float]]:
        """Return top-k (chunk, source_name, score) for a given query."""
        if not query.strip():
            return []
        if self.embeddings is None or not len(self.chunks):
            return []
        query_vec = self.model.encode([query], show_progress_bar=False, convert_to_numpy=True)[0]
        # Cosine similarity of the query against every chunk row.
        numerators = self.embeddings @ query_vec
        doc_norms = np.linalg.norm(self.embeddings, axis=1)
        query_norm = np.linalg.norm(query_vec) + 1e-10
        similarities = numerators / (doc_norms * query_norm + 1e-10)
        ranked = np.argsort(similarities)[::-1][:top_k]
        return [
            (self.chunks[i], self.chunk_sources[i], float(similarities[i]))
            for i in ranked
        ]
# Build the retrieval index once at startup (loads the embedding model and
# embeds every KB chunk; see KBIndex.build_index).
kb_index = KBIndex()

# Load the seq2seq model used to generate answers from retrieved context.
print("Loading generation model...")
gen_tokenizer = AutoTokenizer.from_pretrained(GEN_MODEL_NAME)
gen_model = AutoModelForSeq2SeqLM.from_pretrained(GEN_MODEL_NAME)
# Prefer GPU when available; `device` is reused at inference time in build_answer().
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
gen_model.to(device)
gen_model.eval()  # inference only: disables dropout and other training behavior
print("Generation model ready.")
# -----------------------------
# LLM (FLAN-T5-Large) - lazy load
# -----------------------------
# Model id for the larger, lazily-loaded generator. Bug fix: the original
# referenced FLAN_MODEL_NAME without ever defining it anywhere in the file,
# so every call to get_llm() raised NameError.
FLAN_MODEL_NAME = "google/flan-t5-large"

_llm_pipeline = None


def get_llm():
    """
    Lazily load FLAN-T5-Large as a text2text-generation pipeline.

    The pipeline is created on first call and cached in the module-level
    `_llm_pipeline`, so app startup is not blocked by the large download.

    Returns:
        A transformers "text2text-generation" pipeline, on GPU if
        available, otherwise CPU.
    """
    global _llm_pipeline
    if _llm_pipeline is not None:
        return _llm_pipeline
    print("Loading FLAN-T5-Large model...")
    # Local imports keep the heavy dependencies off the startup path.
    from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline
    import torch

    tokenizer = AutoTokenizer.from_pretrained(FLAN_MODEL_NAME)
    model = AutoModelForSeq2SeqLM.from_pretrained(FLAN_MODEL_NAME)
    # pipeline() convention: CUDA device index, or -1 for CPU.
    device = 0 if torch.cuda.is_available() else -1
    _llm_pipeline = pipeline(
        "text2text-generation",
        model=model,
        tokenizer=tokenizer,
        device=device,
    )
    print("FLAN-T5-Large loaded.")
    return _llm_pipeline
# -----------------------------
# CHAT LOGIC
# -----------------------------
def build_context_from_results(results: List[Tuple[str, str, float]]) -> str:
    """
    Format retrieved (chunk, source, score) triples into one context string.

    Each chunk is stripped and prefixed with the name of its source file;
    the similarity score is not shown to the model. Entries are separated
    by blank lines.
    """
    formatted = [
        f"From {source}:\n{chunk.strip()}"
        for chunk, source, _score in results
    ]
    return "\n\n".join(formatted)
def build_answer(query: str) -> str:
    """
    Answer *query* with retrieval-augmented generation.

    Retrieves the TOP_K most similar KB chunks via `kb_index`, then asks
    FLAN-T5 to write an answer grounded ONLY in that context. Returns the
    generated text followed by a short "Based on: ..." source line, or a
    canned fallback message when nothing relevant is found.
    """
    results = kb_index.search(query, top_k=TOP_K)
    if not results:
        return (
            "I couldn't find anything relevant in the knowledge base for this query yet.\n\n"
            "If this were connected to your real KB, this would be a good moment to:\n"
            "- Create a new article, or\n"
            "- Improve the existing documentation for this topic."
        )
    # Build context for the model
    context = build_context_from_results(results)
    # Deduplicated sources for a small citation line (set order is arbitrary).
    source_names = list({src for _, src, _ in results})
    source_line = "Based on: " + ", ".join(source_names)
    # Prompt for FLAN-T5
    prompt = (
        "You are a helpful knowledge base assistant.\n"
        "Using ONLY the information in the context below, answer the user's question "
        "in a clear, concise, and natural way. Focus on practical guidance.\n\n"
        f"Context:\n{context}\n\n"
        f"Question: {query}\n\n"
        "Answer in 2–5 short paragraphs. If something is not covered in the context, say that.\n"
    )
    inputs = gen_tokenizer(
        prompt,
        return_tensors="pt",
        truncation=True,
        max_length=2048,
    ).to(device)
    with torch.no_grad():
        # Beam-search decoding. The original also passed temperature=0.7 and
        # top_p=0.95, but those only take effect with do_sample=True, so
        # transformers ignored them (emitting a warning). Dropping them keeps
        # the output identical and silences the warning.
        output_ids = gen_model.generate(
            **inputs,
            max_length=512,
            num_beams=4,
        )
    answer_text = gen_tokenizer.decode(output_ids[0], skip_special_tokens=True).strip()
    # Add a subtle source hint at the end
    return f"{answer_text}\n\n— {source_line}"
def chat_respond(message: str, history):
    """
    Callback for gr.ChatInterface (type='messages').

    Gradio supplies the latest user message and the running conversation
    history; the history is unused here because each question is answered
    by a fresh retrieval pass. Returns the assistant reply as a string.
    """
    return build_answer(message)
# -----------------------------
# GRADIO UI
# -----------------------------
# User-facing blurb rendered under the chat title.
description = """
Ask questions as if you were talking to a knowledge base assistant.
In a real scenario, this assistant would be connected to your own
help center or internal documentation. Here, it's using a small demo
knowledge base to show how retrieval-based self-service can work.
"""
# Single-callback chat UI; Gradio manages the message history itself.
chat = gr.ChatInterface(
    fn=chat_respond,
    title="Self-Service KB Assistant",
    description=description,
    type="messages",  # use new-style message format
    examples=[
        "What makes a good knowledge base article?",
        "How could a KB assistant help agents?",
        "Why is self-service important for customer support?",
    ],
    cache_examples=False,  # avoid example pre-caching issues on HF Spaces
)

if __name__ == "__main__":
    chat.launch()