# Self-Service KB Assistant — Gradio demo app (Hugging Face Spaces).
| import os | |
| import glob | |
| from typing import List, Tuple | |
| import gradio as gr | |
| import numpy as np | |
| from sentence_transformers import SentenceTransformer | |
| from transformers import AutoTokenizer, AutoModelForSeq2SeqLM | |
| import torch | |
# -----------------------------
# CONFIG
# -----------------------------
KB_DIR = "./kb"  # optional: folder with .txt or .md files
EMBEDDING_MODEL_NAME = "sentence-transformers/all-MiniLM-L6-v2"
GEN_MODEL_NAME = "google/flan-t5-base"
TOP_K = 3
CHUNK_SIZE = 500  # characters
CHUNK_OVERLAP = 100  # characters


# -----------------------------
# UTILITIES
# -----------------------------
def chunk_text(text: str, chunk_size: int = CHUNK_SIZE, overlap: int = CHUNK_OVERLAP) -> List[str]:
    """Split long text into overlapping chunks so retrieval is more precise.

    Args:
        text: Raw document text (may be empty).
        chunk_size: Maximum chunk length in characters.
        overlap: Characters shared between consecutive chunks; must be
            smaller than ``chunk_size``.

    Returns:
        List of non-empty, stripped chunks covering ``text`` in order.

    Raises:
        ValueError: If ``overlap >= chunk_size`` — the scan position would
            never advance, so the original loop would spin forever.
    """
    if not text:
        return []
    step = chunk_size - overlap
    if step <= 0:
        # Fail fast instead of looping infinitely on a bad configuration.
        raise ValueError("overlap must be smaller than chunk_size")
    chunks: List[str] = []
    start = 0
    length = len(text)
    while start < length:
        end = min(start + chunk_size, length)
        piece = text[start:end].strip()
        if piece:
            chunks.append(piece)
        if end == length:
            # Reached the end of the text. Without this break, the loop could
            # emit one more trailing chunk that is a pure substring of the
            # previous chunk, bloating the index with duplicate content.
            break
        start += step
    return chunks
def load_kb_texts(kb_dir: str = KB_DIR) -> List[Tuple[str, str]]:
    """
    Collect every .txt and .md file under *kb_dir*.

    Returns a list of (source_name, content) pairs. Unreadable files are
    skipped with a console warning; if nothing usable is found, a small
    built-in demo document is returned instead so the app still works.
    """
    collected: List[Tuple[str, str]] = []
    if os.path.isdir(kb_dir):
        candidates = []
        for pattern in ("*.txt", "*.md"):
            candidates.extend(glob.glob(os.path.join(kb_dir, pattern)))
        for path in candidates:
            try:
                with open(path, "r", encoding="utf-8") as handle:
                    body = handle.read()
            except Exception as e:
                print(f"Could not read {path}: {e}")
                continue
            if body.strip():
                collected.append((os.path.basename(path), body))
    if collected:
        return collected
    # Nothing on disk — fall back to bundled demo content.
    print("No KB files found. Using built-in demo content.")
    demo_text = """
Welcome to the Self-Service KB Assistant.
This assistant is meant to help you find information inside a knowledge base.
In a real setup, it would be connected to your own articles, procedures,
troubleshooting guides and FAQs.
Good knowledge base content is:
- Clear and structured with headings, steps and expected outcomes.
- Written in a customer-friendly tone.
- Easy to scan, with short paragraphs and bullet points.
- Maintained regularly to reflect product and process changes.
Example use cases for a KB assistant:
- Agents quickly searching for internal procedures.
- Customers asking “how do I…” style questions.
- Managers analyzing gaps in documentation based on repeated queries.
"""
    collected.append(("demo_content.txt", demo_text))
    return collected
| # ----------------------------- | |
| # KB INDEX | |
| # ----------------------------- | |
class KBIndex:
    """In-memory embedding index over the knowledge-base chunks."""

    def __init__(self, model_name: str = EMBEDDING_MODEL_NAME):
        print("Loading embedding model...")
        self.model = SentenceTransformer(model_name)
        print("Model loaded.")
        self.chunks: List[str] = []          # chunk texts, aligned with embeddings
        self.chunk_sources: List[str] = []   # source file name per chunk
        self.embeddings: np.ndarray | None = None
        self.build_index()

    def build_index(self):
        """Load KB texts, split them into chunks, and embed every chunk."""
        chunk_list: List[str] = []
        source_list: List[str] = []
        for source_name, content in load_kb_texts(KB_DIR):
            pieces = chunk_text(content)
            chunk_list.extend(pieces)
            source_list.extend([source_name] * len(pieces))
        if not chunk_list:
            print("⚠️ No chunks found for KB index.")
            self.chunks = []
            self.chunk_sources = []
            self.embeddings = None
            return
        print(f"Creating embeddings for {len(chunk_list)} chunks...")
        vectors = self.model.encode(chunk_list, show_progress_bar=False, convert_to_numpy=True)
        self.chunks = chunk_list
        self.chunk_sources = source_list
        self.embeddings = vectors
        print("KB index ready.")

    def search(self, query: str, top_k: int = TOP_K) -> List[Tuple[str, str, float]]:
        """Return the top-k (chunk, source_name, cosine score) for *query*."""
        if not query.strip():
            return []
        if self.embeddings is None or not len(self.chunks):
            return []
        query_vec = self.model.encode([query], show_progress_bar=False, convert_to_numpy=True)[0]
        # Cosine similarity between the query and every stored chunk
        # (epsilons guard against division by zero on degenerate vectors).
        doc_norms = np.linalg.norm(self.embeddings, axis=1)
        query_norm = np.linalg.norm(query_vec) + 1e-10
        scores = (self.embeddings @ query_vec) / (doc_norms * query_norm + 1e-10)
        ranked = np.argsort(scores)[::-1][:top_k]
        return [
            (self.chunks[i], self.chunk_sources[i], float(scores[i]))
            for i in ranked
        ]
# Build the retrieval index and load the generation model once at import
# time, so the first chat request does not pay the startup cost.
kb_index = KBIndex()

print("Loading generation model...")
gen_tokenizer = AutoTokenizer.from_pretrained(GEN_MODEL_NAME)
gen_model = AutoModelForSeq2SeqLM.from_pretrained(GEN_MODEL_NAME)
# Prefer GPU when available; inference only, so switch to eval mode.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
gen_model.to(device)
gen_model.eval()
print("Generation model ready.")
# -----------------------------
# LLM (FLAN-T5-Large) - lazy load
# -----------------------------
# FIX: the code below referenced FLAN_MODEL_NAME, but that constant was never
# defined anywhere in the file, so the first call to get_llm() raised
# NameError. Define it here (the docstring and log messages say Large).
FLAN_MODEL_NAME = "google/flan-t5-large"
_llm_pipeline = None


def get_llm():
    """
    Lazily load FLAN-T5-Large as a text2text-generation pipeline.

    The pipeline is built on first use and cached in ``_llm_pipeline`` so
    repeated calls are cheap and module import is not blocked by the
    download.

    Returns:
        A ``transformers`` text2text-generation pipeline.
    """
    global _llm_pipeline
    if _llm_pipeline is not None:
        return _llm_pipeline
    print("Loading FLAN-T5-Large model...")
    # Imported lazily so merely importing this module stays fast.
    from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline
    import torch

    tokenizer = AutoTokenizer.from_pretrained(FLAN_MODEL_NAME)
    model = AutoModelForSeq2SeqLM.from_pretrained(FLAN_MODEL_NAME)
    device = 0 if torch.cuda.is_available() else -1  # pipeline device index
    _llm_pipeline = pipeline(
        "text2text-generation",
        model=model,
        tokenizer=tokenizer,
        device=device,
    )
    print("FLAN-T5-Large loaded.")
    return _llm_pipeline
# -----------------------------
# CHAT LOGIC
# -----------------------------
def build_context_from_results(results: List[Tuple[str, str, float]]) -> str:
    """
    Join retrieved chunks into one compact context string for the LLM,
    labelling each chunk with the file it came from. Scores are ignored.
    """
    labelled = [
        f"From {source}:\n{chunk.strip()}"
        for chunk, source, _score in results
    ]
    return "\n\n".join(labelled)
def build_answer(query: str) -> str:
    """
    Answer *query* by retrieving relevant KB chunks and asking FLAN-T5 to
    write a natural answer grounded ONLY in that retrieved context.

    Args:
        query: The user's question.

    Returns:
        The generated answer followed by a short "Based on: ..." citation
        line, or a canned message when retrieval finds nothing.
    """
    results = kb_index.search(query, top_k=TOP_K)
    if not results:
        return (
            "I couldn't find anything relevant in the knowledge base for this query yet.\n\n"
            "If this were connected to your real KB, this would be a good moment to:\n"
            "- Create a new article, or\n"
            "- Improve the existing documentation for this topic."
        )
    # Build context for the model
    context = build_context_from_results(results)
    # FIX: deduplicate source names while keeping retrieval order — the
    # original list({...}) made the citation line's order nondeterministic.
    source_names = list(dict.fromkeys(src for _, src, _ in results))
    source_line = "Based on: " + ", ".join(source_names)
    # Prompt for FLAN-T5
    prompt = (
        "You are a helpful knowledge base assistant.\n"
        "Using ONLY the information in the context below, answer the user's question "
        "in a clear, concise, and natural way. Focus on practical guidance.\n\n"
        f"Context:\n{context}\n\n"
        f"Question: {query}\n\n"
        "Answer in 2–5 short paragraphs. If something is not covered in the context, say that.\n"
    )
    inputs = gen_tokenizer(
        prompt,
        return_tensors="pt",
        truncation=True,
        max_length=2048,
    ).to(device)
    with torch.no_grad():
        # Deterministic beam search. FIX: the original also passed
        # temperature=0.7 and top_p=0.95, but without do_sample=True those
        # are ignored by generate() and only trigger warnings.
        output_ids = gen_model.generate(
            **inputs,
            max_length=512,
            num_beams=4,
        )
    answer_text = gen_tokenizer.decode(output_ids[0], skip_special_tokens=True).strip()
    # Add a subtle source hint at the end
    return f"{answer_text}\n\n— {source_line}"
def chat_respond(message: str, history):
    """
    Reply callback for gr.ChatInterface (type='messages').

    Gradio supplies the latest user message plus the prior conversation;
    the history is unused here because each answer is retrieved fresh from
    the knowledge base. Returns the assistant's reply as a string.
    """
    return build_answer(message)
# -----------------------------
# GRADIO UI
# -----------------------------
description = """
Ask questions as if you were talking to a knowledge base assistant.
In a real scenario, this assistant would be connected to your own
help center or internal documentation. Here, it's using a small demo
knowledge base to show how retrieval-based self-service can work.
"""

# Chat UI wired to the retrieval + generation pipeline defined above.
chat = gr.ChatInterface(
    fn=chat_respond,
    title="Self-Service KB Assistant",
    description=description,
    type="messages",  # use new-style message format
    examples=[
        "What makes a good knowledge base article?",
        "How could a KB assistant help agents?",
        "Why is self-service important for customer support?",
    ],
    cache_examples=False,  # avoid example pre-caching issues on HF Spaces
)

if __name__ == "__main__":
    # Launch the app only when run as a script (HF Spaces runs this file).
    chat.launch()