# NOTE(review): the lines below were Hugging Face Spaces UI residue
# ("Spaces: Sleeping") captured when this file was scraped — not code.
| import os | |
| import glob | |
| from typing import List, Tuple | |
| import time | |
| import gradio as gr | |
| import numpy as np | |
| from sentence_transformers import SentenceTransformer | |
# -----------------------------
# CONFIG
# -----------------------------
KB_DIR = "./kb"  # folder with .txt or .md files
EMBEDDING_MODEL_NAME = "sentence-transformers/all-MiniLM-L6-v2"
TOP_K = 3
CHUNK_SIZE = 500  # characters
CHUNK_OVERLAP = 100  # characters
MIN_SIMILARITY_THRESHOLD = 0.3  # Minimum similarity score to include results


# -----------------------------
# UTILITIES
# -----------------------------
def chunk_text(text: str, chunk_size: int = CHUNK_SIZE, overlap: int = CHUNK_OVERLAP) -> List[str]:
    """Split long text into overlapping chunks so retrieval is more precise.

    Args:
        text: Raw document text.
        chunk_size: Maximum chunk length in characters.
        overlap: Number of characters shared between consecutive chunks.

    Returns:
        List of non-empty, whitespace-stripped chunks covering ``text``.
    """
    if not text:
        return []
    # BUGFIX: the original advanced by ``chunk_size - overlap`` unconditionally,
    # which is an infinite loop whenever overlap >= chunk_size. Clamp the
    # stride to at least 1 character; behavior is unchanged for sane arguments.
    step = max(1, chunk_size - overlap)
    chunks = []
    length = len(text)
    for start in range(0, length, step):
        chunk = text[start:start + chunk_size].strip()
        if chunk:  # skip all-whitespace windows
            chunks.append(chunk)
    return chunks
def load_kb_texts(kb_dir: str = KB_DIR) -> List[Tuple[str, str]]:
    """
    Load all .txt and .md files from the KB directory.

    Args:
        kb_dir: Directory scanned (non-recursively) for knowledge-base files.

    Returns:
        A list of (source_name, content) tuples. Falls back to built-in
        demo content when the directory is missing or contains no usable files.
    """
    texts = []
    if os.path.isdir(kb_dir):
        paths = glob.glob(os.path.join(kb_dir, "*.txt")) + glob.glob(os.path.join(kb_dir, "*.md"))
        # BUGFIX: glob order is filesystem-dependent, which made the chunk
        # index (and therefore tie-broken retrieval results) vary between
        # runs/machines. Sort for deterministic indexing.
        for path in sorted(paths):
            try:
                with open(path, "r", encoding="utf-8") as f:
                    content = f.read()
                if content.strip():  # ignore empty/whitespace-only files
                    texts.append((os.path.basename(path), content))
            except Exception as e:
                # Best-effort: skip unreadable files but keep indexing the rest.
                print(f"Could not read {path}: {e}")
    # If no files found, fall back to built-in demo content
    if not texts:
        print("No KB files found. Using built-in demo content.")
        demo_text = """
Welcome to the Self-Service KB Assistant.
This assistant is meant to help you find information inside a knowledge base.
In a real setup, it would be connected to your own articles, procedures,
troubleshooting guides and FAQs.
Good knowledge base content is:
- Clear and structured with headings, steps and expected outcomes.
- Written in a customer-friendly tone.
- Easy to scan, with short paragraphs and bullet points.
- Maintained regularly to reflect product and process changes.
Example use cases for a KB assistant:
- Agents quickly searching for internal procedures.
- Customers asking "how do I…" style questions.
- Managers analyzing gaps in documentation based on repeated queries.
"""
        texts.append(("demo_content.txt", demo_text))
    return texts
# -----------------------------
# KB INDEX
# -----------------------------
class KBIndex:
    """In-memory semantic search index over the knowledge-base chunks."""

    def __init__(self, model_name: str = EMBEDDING_MODEL_NAME):
        # Loading the SentenceTransformer is the slow part; do it once here.
        print("Loading embedding model...")
        self.model = SentenceTransformer(model_name)
        print("Embedding model loaded.")
        self.chunks: List[str] = []          # chunk texts, index-aligned with embeddings
        self.chunk_sources: List[str] = []   # originating file name per chunk
        self.embeddings = None               # numpy matrix, one row per chunk
        self.build_index()

    def build_index(self):
        """Load KB texts, split into chunks, and build an embedding index."""
        collected_chunks: List[str] = []
        collected_sources: List[str] = []
        for source_name, content in load_kb_texts(KB_DIR):
            pieces = chunk_text(content)
            collected_chunks.extend(pieces)
            # Keep the source list index-aligned with the chunk list.
            collected_sources.extend([source_name] * len(pieces))
        if not collected_chunks:
            print("⚠️ No chunks found for KB index.")
            self.chunks = []
            self.chunk_sources = []
            self.embeddings = None
            return
        print(f"Creating embeddings for {len(collected_chunks)} chunks...")
        vectors = self.model.encode(collected_chunks, show_progress_bar=False, convert_to_numpy=True)
        self.chunks = collected_chunks
        self.chunk_sources = collected_sources
        self.embeddings = vectors
        print("KB index ready.")

    def search(self, query: str, top_k: int = TOP_K) -> List[Tuple[str, str, float]]:
        """Return top-k (chunk, source_name, score) for a given query."""
        if not query.strip():
            return []
        if self.embeddings is None or not len(self.chunks):
            return []
        query_vec = self.model.encode([query], show_progress_bar=False, convert_to_numpy=True)[0]
        # Cosine similarity of every chunk embedding against the query;
        # small epsilons avoid division by zero for degenerate vectors.
        doc_norms = np.linalg.norm(self.embeddings, axis=1)
        query_norm = np.linalg.norm(query_vec) + 1e-10
        sims = (self.embeddings @ query_vec) / (doc_norms * query_norm + 1e-10)
        ranked = np.argsort(sims)[::-1][:top_k]
        return [(self.chunks[i], self.chunk_sources[i], float(sims[i])) for i in ranked]
# Initialize KB index
print("Initializing KB index...")
kb_index = KBIndex()

# Initialize LLM for answer generation
print("Loading LLM for answer generation...")
try:
    from transformers import AutoTokenizer, AutoModelForCausalLM
    import torch

    # Use a small but capable model for faster responses
    LLM_MODEL_NAME = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"  # Fast and good quality
    print(f"Loading {LLM_MODEL_NAME}...")
    llm_tokenizer = AutoTokenizer.from_pretrained(LLM_MODEL_NAME)
    # Probe the hardware once and reuse the answer below.
    _on_gpu = torch.cuda.is_available()
    llm_model = AutoModelForCausalLM.from_pretrained(
        LLM_MODEL_NAME,
        torch_dtype=torch.float16 if _on_gpu else torch.float32,
        device_map="auto" if _on_gpu else None,
    )
    if not _on_gpu:
        llm_model = llm_model.to("cpu")
    llm_model.eval()  # inference only — disable dropout etc.
    print(f"✅ LLM loaded successfully on {'GPU' if _on_gpu else 'CPU'}")
    llm_available = True
except Exception as e:
    # Any failure (download, memory, missing deps) degrades gracefully to
    # direct-retrieval answers instead of crashing the app.
    print(f"⚠️ Could not load LLM: {e}")
    print("⚠️ Will use fallback mode (direct retrieval)")
    llm_available = False
    llm_tokenizer = None
    llm_model = None

print("✅ KB Assistant ready!")
# -----------------------------
# CHAT LOGIC (With LLM Answer Generation)
# -----------------------------
def clean_context(text: str) -> str:
    """Clean up text for context, removing markdown and excess whitespace."""
    # Drop every '#' (markdown heading markers), then collapse all runs of
    # whitespace — including newlines — into single spaces. The join/split
    # combination also trims leading/trailing whitespace.
    without_hashes = text.replace('#', '')
    return ' '.join(without_hashes.split())
def generate_answer_with_llm(query: str, context: str, sources: List[str]) -> str:
    """
    Generate a natural, conversational answer using LLM based on retrieved context.

    Args:
        query: The user's question.
        context: Concatenated KB chunks to ground the answer in.
        sources: File names to cite under the answer.

    Returns:
        A markdown answer with a sources footer, or None when the LLM is
        unavailable or generation fails (the caller then falls back).
    """
    if not llm_available:
        return None
    # TinyLlama/Zephyr-style chat-template markers.
    prompt = f"""<|system|>
You are a helpful knowledge base assistant. Answer the user's question based ONLY on the provided context. Be conversational, clear, and concise. If the context doesn't contain enough information, say so.
</s>
<|user|>
Context from knowledge base:
{context}
Question: {query}
</s>
<|assistant|>
"""
    try:
        # Tokenize (truncate long contexts so prompt + answer fit the model)
        inputs = llm_tokenizer(
            prompt,
            return_tensors="pt",
            truncation=True,
            max_length=1024
        )
        if torch.cuda.is_available():
            inputs = {k: v.to("cuda") for k, v in inputs.items()}
        # Generate
        with torch.no_grad():
            outputs = llm_model.generate(
                **inputs,
                max_new_tokens=256,
                temperature=0.7,
                top_p=0.9,
                do_sample=True,
                pad_token_id=llm_tokenizer.eos_token_id,
            )
        # BUGFIX: the original split the full decoded text on "<|assistant|>",
        # but skip_special_tokens=True can strip that marker, in which case the
        # reply echoed the entire prompt. Decoding only the tokens generated
        # after the prompt is the reliable way to isolate the answer.
        prompt_len = inputs["input_ids"].shape[1]
        answer = llm_tokenizer.decode(outputs[0][prompt_len:], skip_special_tokens=True)
        # Clean up the answer
        answer = answer.replace("</s>", "").strip()
        # Add source attribution
        sources_text = ", ".join(sources)
        return f"{answer}\n\n---\n📚 **Sources:** {sources_text}"
    except Exception as e:
        # Never crash the chat loop — signal failure so the caller falls back.
        print(f"Error in LLM generation: {e}")
        return None
def format_fallback_answer(results: List[Tuple[str, str, float]]) -> str:
    """
    Fallback formatting when LLM is not available or fails.

    Args:
        results: (chunk, source, score) tuples, best match first.

    Returns:
        A markdown answer built from the top chunk, with pointers to any
        other matched sources.
    """
    if not results:
        return (
            "I couldn't find any relevant information in the knowledge base.\n\n"
            "**Try:**\n"
            "- Rephrasing your question\n"
            "- Using different keywords\n"
            "- Breaking down complex questions"
        )
    # Get best result
    best_chunk, best_source, _best_score = results[0]
    # Clean markdown
    cleaned = clean_context(best_chunk)
    # Format nicely
    answer = f"**From {best_source}:**\n\n{cleaned}"
    # Add other sources if available
    if len(results) > 1:
        # BUGFIX: list(set(...)) yields an arbitrary, run-dependent order;
        # dict.fromkeys dedupes while preserving the retrieval ranking.
        other_sources = list(dict.fromkeys(src for _, src, _ in results[1:]))
        if other_sources:
            answer += f"\n\n💡 **Also see:** {', '.join(other_sources)}"
    return answer
def build_answer(query: str) -> str:
    """
    Main answer generation function using LLM for natural responses.

    Pipeline:
    1. Retrieve relevant chunks from the KB.
    2. Drop weak matches below the similarity threshold.
    3. Assemble a compact context from the top hits.
    4. Ask the LLM for a natural answer; fall back to raw retrieval output.
    """
    # Step 1: Search the knowledge base
    results = kb_index.search(query, top_k=TOP_K)
    if not results:
        return (
            "I couldn't find any relevant information in the knowledge base to answer your question.\n\n"
            "**Suggestions:**\n"
            "- Try rephrasing with different words\n"
            "- Check if the topic is covered in the KB\n"
            "- Be more specific about what you're looking for"
        )
    # Step 2: Filter by similarity threshold
    filtered_results = [hit for hit in results if hit[2] >= MIN_SIMILARITY_THRESHOLD]
    if not filtered_results:
        return (
            "I found some content, but it doesn't seem relevant enough to your question.\n\n"
            "Please try being more specific or using different keywords."
        )
    # Step 3: Build context from the two most relevant chunks
    context_parts = []
    sources = []
    for chunk, source, _score in filtered_results[:2]:
        context_parts.append(clean_context(chunk))
        if source not in sources:  # dedupe while keeping ranking order
            sources.append(source)
    # Combine context (limit to 1000 chars for speed)
    context = " ".join(context_parts)[:1000]
    # Step 4: Generate answer with LLM
    if llm_available:
        llm_answer = generate_answer_with_llm(query, context, sources)
        if llm_answer:
            return llm_answer
    # Step 5: Fallback if LLM fails or unavailable
    return format_fallback_answer(filtered_results)
def chat_respond(message: str, history):
    """
    Gradio ChatInterface callback.

    Args:
        message: Latest user message (str)
        history: List of previous messages (handled by Gradio)

    Returns:
        Assistant's reply as a string
    """
    # Guard clause: empty or whitespace-only input gets a gentle prompt.
    if not (message and message.strip()):
        return "Please ask me a question about the knowledge base."
    try:
        return build_answer(message.strip())
    except Exception as e:
        # Surface the failure in chat rather than breaking the UI.
        print(f"Error generating answer: {e}")
        return f"Sorry, I encountered an error processing your question: {str(e)}"
# -----------------------------
# GRADIO UI
# -----------------------------
# Markdown blurb rendered above the chat box.
description = """
🚀 **Fast Knowledge Base Search Assistant**
Ask questions and get instant answers from the knowledge base. This assistant uses semantic search to find the most relevant information quickly.
**Tips for better results:**
- Be specific in your questions
- Use keywords related to your topic
- Ask one question at a time
"""

# Create ChatInterface (without 'type' parameter for compatibility)
chat_interface = gr.ChatInterface(
    fn=chat_respond,
    title="🤖 Self-Service KB Assistant",
    description=description,
    examples=[
        "What makes a good knowledge base article?",
        "How could a KB assistant help agents?",
        "Why is self-service important for customer support?",
    ],
    # Cached examples would run the model at build time; keep startup fast.
    cache_examples=False,
)
# Launch
if __name__ == "__main__":
    # Detect environment and launch appropriately
    # SPACE_ID is set by the Hugging Face Spaces runtime.
    is_huggingface = os.getenv('SPACE_ID') is not None
    # /.dockerenv exists inside Docker; KUBERNETES_SERVICE_HOST inside k8s pods.
    is_container = os.path.exists('/.dockerenv') or os.getenv('KUBERNETES_SERVICE_HOST') is not None
    if is_huggingface:
        print("🤗 Launching on HuggingFace Spaces...")
        # Spaces expect the app bound to 0.0.0.0:7860.
        chat_interface.launch(server_name="0.0.0.0", server_port=7860)
    elif is_container:
        print("🐳 Launching in container environment...")
        chat_interface.launch(server_name="0.0.0.0", server_port=7860, share=False)
    else:
        print("💻 Launching locally...")
        chat_interface.launch(share=False)