# NOTE: the "Spaces: / Sleeping" lines that appeared here were Hugging Face
# Spaces page chrome captured during copy/paste, not part of the program.
| from fastapi import FastAPI | |
| from sentence_transformers import SentenceTransformer | |
| import chromadb | |
| from chromadb.config import Settings | |
| import uuid | |
| from huggingface_hub import InferenceClient | |
| import os | |
| from docx import Document | |
| import google.generativeai as genai | |
# --- 0. Config ---
# Fail fast at import time if the Gemini key is absent, so the service
# never starts half-configured.
GEMINI_API_KEY = os.getenv("GEMINI_API_KEY")
if not GEMINI_API_KEY:
    raise RuntimeError("GEMINI_API_KEY is not set in environment.")
# Configure the SDK
genai.configure(api_key=GEMINI_API_KEY)
# Choose the model
MODEL_NAME = "gemini-2.5-flash-lite"
LLM = genai.GenerativeModel(MODEL_NAME)
app = FastAPI()
# -----------------------------
# 1. SETUP: Embeddings + LLM
# -----------------------------
# Sentence-transformers model used for both ingest-time and query-time
# embeddings; loaded once at import time (slow first call is expected).
EMBED_MODEL = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")
# -----------------------------
# 2. SETUP: ChromaDB
# -----------------------------
# Persistent on-disk vector store: the collection survives restarts, which is
# what ingest_documents() relies on when it checks collection.count().
chroma_client = chromadb.PersistentClient(path="./chroma_db")
collection = chroma_client.get_or_create_collection(name="knowledge_base")
# -----------------------------
# Helper: Extract text from docx
# -----------------------------
def extract_docx_text(file_path):
    """Read a .docx file and return all paragraph text joined by newlines."""
    document = Document(file_path)
    paragraph_texts = (paragraph.text for paragraph in document.paragraphs)
    return "\n".join(paragraph_texts)
# -----------------------------
# 3. STARTUP INGEST
# -----------------------------
def ingest_documents(doc_dir="./documents", min_chunk_len=50):
    """Populate the Chroma collection from the .docx files in *doc_dir*.

    Idempotent: returns immediately when the collection already holds data.
    Splits each document on blank lines and skips chunks shorter than
    *min_chunk_len* characters so trivial fragments don't pollute retrieval.

    Args:
        doc_dir: Directory scanned for ``.docx`` files (default matches the
            original hard-coded ``./documents``).
        min_chunk_len: Minimum stripped chunk length to keep.
    """
    print("Checking if KB already has data...")
    if collection.count() > 0:
        print("KB exists. Skipping ingest.")
        return
    print("Empty KB. Ingesting files...")
    # Robustness fix: a missing documents folder previously raised
    # FileNotFoundError from os.listdir and crashed startup.
    if not os.path.isdir(doc_dir):
        print(f"Document directory {doc_dir!r} not found. Nothing to ingest.")
        return
    for fname in os.listdir(doc_dir):
        if not fname.endswith(".docx"):
            continue
        text = extract_docx_text(os.path.join(doc_dir, fname))
        chunks = text.split("\n\n")  # simple chunking for beginners
        for chunk in chunks:
            if len(chunk.strip()) < min_chunk_len:
                continue
            embedding = EMBED_MODEL.encode(chunk).tolist()
            collection.add(
                ids=[str(uuid.uuid4())],
                embeddings=[embedding],
                documents=[chunk],
                metadatas=[{"source": fname}],
            )
    print("Ingest complete.")
# -----------------------------
# 4. LLM for Intent detection
# -----------------------------
def get_intent(query):
    """Classify *query* into one warehouse-workflow intent label.

    Returns the raw label string emitted by Gemini (e.g. ``"picking"``),
    whitespace-stripped. Garbage in the model output is passed through;
    hybrid_search only uses it for a soft keyword boost.
    """
    prompt = f"""
Classify the user's intent from the list:
- receiving
- inventory_adjustment
- update_footprint
- picking
- shipping
- trailer_close
User query: "{query}"
Respond ONLY with the intent label.
"""
    # BUG FIX: .text_generation() is the huggingface_hub InferenceClient API;
    # genai.GenerativeModel has no such method (AttributeError at runtime).
    # generate_content() is the correct call, and the label is in .text.
    # (The old max_new_tokens=10 cap is dropped: max_output_tokens on this
    # model family can truncate before any visible text is produced.)
    resp = LLM.generate_content(prompt)
    return resp.text.strip()
# -----------------------------
# 5. Hybrid Search (vector + keyword)
# -----------------------------
def hybrid_search(query, intent, top_k=3):
    """Vector-search the KB and nudge results that mention the intent.

    Returns a list of (document_text, score) pairs sorted by score,
    highest first. Scores are (1 - distance) plus a flat 0.05 boost when
    the humanized intent appears in the document text.
    """
    # Vector search against the Chroma collection.
    query_embedding = EMBED_MODEL.encode(query).tolist()
    hits = collection.query(query_embeddings=[query_embedding], n_results=top_k)
    documents = hits["documents"][0]
    distances = hits["distances"][0]
    # Chroma reports distances; flip each into a similarity-like score and
    # apply the keyword boost in a single pass.
    needle = intent.replace("_", " ")
    ranked = [
        (text, (1 - dist) + (0.05 if needle in text.lower() else 0))
        for text, dist in zip(documents, distances)
    ]
    ranked.sort(key=lambda pair: pair[1], reverse=True)
    return ranked
# -----------------------------
# 6. LLM Format (rephrase KB)
# -----------------------------
def format_with_llm(answer):
    """Rephrase a KB answer clearly and politely; returns a plain string."""
    prompt = f"""
Rewrite this answer clearly and politely without adding new information:
{answer}
"""
    # BUG FIX: .text_generation() belongs to huggingface_hub's
    # InferenceClient, not genai.GenerativeModel — the original call raised
    # AttributeError. Use generate_content() and return the response text.
    return LLM.generate_content(prompt).text
# -----------------------------
# 7. RAG Fallback
# -----------------------------
def rag_fallback(query, docs):
    """Answer *query* strictly from the retrieved context.

    Args:
        query: User question.
        docs: List of (text, score) pairs as produced by hybrid_search().

    Returns:
        A plain string — the model's grounded answer, or a reply containing
        "not found" when the context doesn't cover the question (chat()
        relies on that sentinel).
    """
    context = "\n\n".join(text for text, _ in docs)
    prompt = f"""
Use ONLY the information below to answer the question.
If the answer is not found, say "not found".
Context:
{context}
Question: {query}
Answer:
"""
    # BUG FIX: genai.GenerativeModel exposes generate_content(), not
    # text_generation(); chat() calls .lower()/.split() on the result, so a
    # plain string (.text) must be returned.
    return LLM.generate_content(prompt).text
# -----------------------------
# 8. INCIDENT NUMBER GENERATOR
# -----------------------------
def generate_incident():
    """Return a pseudo-random incident id like ``INC1A2B3C4D``."""
    suffix = uuid.uuid4().hex[:8].upper()
    return f"INC{suffix}"
# -----------------------------
# 9. MAIN CHAT ENDPOINT
# -----------------------------
# NOTE(review): this handler is never registered on `app` (no @app.get /
# @app.post decorator visible in this file) — confirm routing is wired up
# elsewhere before deploying.
def chat(query: str):
    """End-to-end pipeline: intent -> hybrid search -> format / RAG / incident.

    Returns a dict with at least "answer" and "intent"; adds "confidence"
    on a high-confidence KB hit, "mode" on a RAG answer, and "incident"
    when nothing could be found.
    """
    # Step 2: Detect intent
    intent = get_intent(query)
    # Step 3–4: Hybrid search
    docs = hybrid_search(query, intent)
    # Robustness fix: an empty KB (or zero hits) previously raised
    # IndexError on docs[0]; treat it as unresolved and open an incident.
    if not docs:
        incident = generate_incident()
        return {
            "answer": f"I couldn't find this information. I've created incident {incident}.",
            "incident": incident,
            "intent": intent,
        }
    top_answer, top_score = docs[0]
    # Step 5: High confidence (≥ 0.89)
    if top_score >= 0.89:
        reply = format_with_llm(top_answer)
        return {"answer": reply, "intent": intent, "confidence": top_score}
    # Step 6: RAG fallback
    rag_answer = rag_fallback(query, docs)
    if "not found" not in rag_answer.lower() and len(rag_answer.split()) > 5:
        return {"answer": rag_answer, "intent": intent, "mode": "RAG"}
    # Step 7: Still not resolved → create incident
    incident = generate_incident()
    return {
        "answer": f"I couldn't find this information. I've created incident {incident}.",
        "incident": incident,
        "intent": intent,
    }