from fastapi import FastAPI
from sentence_transformers import SentenceTransformer
import chromadb
from chromadb.config import Settings
import uuid
from huggingface_hub import InferenceClient
import os
from docx import Document
import google.generativeai as genai

# --- 0. Config ---
GEMINI_API_KEY = os.getenv("GEMINI_API_KEY")
if not GEMINI_API_KEY:
    raise RuntimeError("GEMINI_API_KEY is not set in environment.")

# Configure the SDK
genai.configure(api_key=GEMINI_API_KEY)

# Choose the model
MODEL_NAME = "gemini-2.5-flash-lite"
LLM = genai.GenerativeModel(MODEL_NAME)

app = FastAPI()

# -----------------------------
# 1. SETUP: Embeddings + LLM
# -----------------------------
EMBED_MODEL = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")

# -----------------------------
# 2. SETUP: ChromaDB
# -----------------------------
chroma_client = chromadb.PersistentClient(path="./chroma_db")
collection = chroma_client.get_or_create_collection(name="knowledge_base")

# -----------------------------
# Helper: Extract text from docx
# -----------------------------
def extract_docx_text(file_path):
    """Return the full text of a .docx file, one paragraph per line."""
    doc = Document(file_path)
    return "\n".join(para.text for para in doc.paragraphs)

# -----------------------------
# 3. STARTUP INGEST
# -----------------------------
DOCS_DIR = "./documents"
MIN_CHUNK_CHARS = 50  # skip tiny fragments that carry no useful context

@app.on_event("startup")
def ingest_documents():
    """Populate the vector store from ./documents on first boot.

    Skips ingestion when the collection already holds data, so restarts do
    not duplicate chunks. Fixes vs. original: tolerates a missing documents
    directory (os.listdir raised FileNotFoundError), and batch-encodes /
    batch-adds once per file instead of one model call and one DB call per
    chunk.
    """
    print("Checking if KB already has data...")
    if collection.count() > 0:
        print("KB exists. Skipping ingest.")
        return

    print("Empty KB. Ingesting files...")
    if not os.path.isdir(DOCS_DIR):
        print(f"No {DOCS_DIR} directory found. Nothing to ingest.")
        return

    for fname in os.listdir(DOCS_DIR):
        if not fname.endswith(".docx"):
            continue
        text = extract_docx_text(os.path.join(DOCS_DIR, fname))
        # Simple paragraph-gap chunking, keeping only substantial chunks.
        chunks = [c for c in text.split("\n\n") if len(c.strip()) >= MIN_CHUNK_CHARS]
        if not chunks:
            continue
        # One embedding call and one DB insert per file.
        embeddings = EMBED_MODEL.encode(chunks).tolist()
        collection.add(
            ids=[str(uuid.uuid4()) for _ in chunks],
            embeddings=embeddings,
            documents=chunks,
            metadatas=[{"source": fname} for _ in chunks],
        )
    print("Ingest complete.")
# -----------------------------
# 4. LLM for Intent detection
# -----------------------------
def get_intent(query):
    """Classify the user's query into one warehouse-process intent label.

    Fix vs. original: `genai.GenerativeModel` has no `text_generation()`
    method — that is the huggingface_hub InferenceClient API. Use
    `generate_content()` and read `.text`; `max_new_tokens` becomes
    `generation_config={"max_output_tokens": ...}`.
    """
    prompt = f"""
Classify the user's intent from the list:
- receiving
- inventory_adjustment
- update_footprint
- picking
- shipping
- trailer_close

User query: "{query}"

Respond ONLY with the intent label.
"""
    resp = LLM.generate_content(
        prompt,
        generation_config={"max_output_tokens": 10},
    )
    return resp.text.strip()

# -----------------------------
# 5. Hybrid Search (vector + keyword)
# -----------------------------
def hybrid_search(query, intent, top_k=3):
    """Vector search plus a small keyword boost for the detected intent.

    Returns a list of (document, score) tuples, best score first.
    NOTE(review): `1 - distance` treats the distance as cosine-like, but
    Chroma's default space is L2, where distances can exceed 1 — confirm
    the collection's metric before trusting the 0.89 threshold downstream.
    """
    # Vector search
    emb = EMBED_MODEL.encode(query).tolist()
    results = collection.query(query_embeddings=[emb], n_results=top_k)
    docs = results["documents"][0]
    scores = results["distances"][0]

    # Convert distances to similarity
    similarities = [1 - d for d in scores]
    combined = list(zip(docs, similarities))

    # Simple keyword boost: nudge chunks mentioning the intent phrase.
    boosted = []
    for text, sim in combined:
        score = sim
        if intent.replace("_", " ") in text.lower():
            score += 0.05
        boosted.append((text, score))

    boosted.sort(key=lambda x: x[1], reverse=True)
    return boosted

# -----------------------------
# 6. LLM Format (rephrase KB)
# -----------------------------
def format_with_llm(answer):
    """Rephrase a KB answer politely without adding information.

    Fix vs. original: same `text_generation` → `generate_content` repair
    as get_intent; returns the response text, not the response object.
    """
    prompt = f"""
Rewrite this answer clearly and politely without adding new information:

{answer}
"""
    resp = LLM.generate_content(
        prompt,
        generation_config={"max_output_tokens": 150},
    )
    return resp.text

# -----------------------------
# 7. RAG Fallback
# -----------------------------
def rag_fallback(query, docs):
    """Answer from retrieved context only; the model says "not found" otherwise.

    `docs` is the (text, score) list produced by hybrid_search.
    Fix vs. original: same `text_generation` → `generate_content` repair.
    """
    context = "\n\n".join(d for d, _ in docs)
    prompt = f"""
Use ONLY the information below to answer the question.
If the answer is not found, say "not found".

Context:
{context}

Question: {query}

Answer:
"""
    resp = LLM.generate_content(
        prompt,
        generation_config={"max_output_tokens": 200},
    )
    return resp.text

# -----------------------------
# 8. INCIDENT NUMBER GENERATOR
# -----------------------------
def generate_incident():
    """Return a pseudo-random incident id like "INC1A2B3C4D"."""
    return "INC" + str(uuid.uuid4())[:8].upper()
# -----------------------------
# 9. MAIN CHAT ENDPOINT
# -----------------------------
@app.post("/chat")
def chat(query: str):
    """RAG chat flow: intent → hybrid search → direct answer | RAG | incident.

    Fix vs. original: guards against an empty retrieval result — `docs[0]`
    raised IndexError when the collection had no matches.
    """
    # Step 2: Detect intent
    intent = get_intent(query)

    # Step 3–4: Hybrid search
    docs = hybrid_search(query, intent)

    # No KB matches at all → skip straight to incident creation.
    if not docs:
        incident = generate_incident()
        return {
            "answer": f"I couldn't find this information. I've created incident {incident}.",
            "incident": incident,
            "intent": intent
        }

    top_answer, top_score = docs[0]

    # Step 5: High confidence (≥ 0.89)
    if top_score >= 0.89:
        reply = format_with_llm(top_answer)
        return {"answer": reply, "intent": intent, "confidence": top_score}

    # Step 6: RAG fallback
    rag_answer = rag_fallback(query, docs)
    if "not found" not in rag_answer.lower() and len(rag_answer.split()) > 5:
        return {"answer": rag_answer, "intent": intent, "mode": "RAG"}

    # Step 7: Still not resolved → create incident
    incident = generate_incident()
    return {
        "answer": f"I couldn't find this information. I've created incident {incident}.",
        "incident": incident,
        "intent": intent
    }