import os
import json
import pickle
from typing import List, Optional

import faiss
import redis
from fastapi import FastAPI, HTTPException
from fastapi.responses import RedirectResponse
from pydantic import BaseModel
from sentence_transformers import SentenceTransformer
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

app = FastAPI(title="Medical Policy RAG Chatbot API")

# --------------------------
# Configuration & Helpers
# --------------------------
REDIS_URL = os.getenv("REDIS_URL", "redis://localhost:6379/0")
redis_client = redis.from_url(REDIS_URL, socket_connect_timeout=1)

# Local fallback for session data if Redis is unavailable
local_cache = {}

INDEX_PATH = os.getenv("FAISS_INDEX_PATH", "data/faiss.index")
METADATA_PATH = os.getenv("METADATA_PATH", "data/metadata.pkl")
EMBEDDING_MODEL = os.getenv("EMBED_MODEL", "intfloat/e5-base-v2")
LLM_MODEL = os.getenv("LLM_MODEL", "google/flan-t5-large")


# Load heavy models once – they are cached in process memory.
def load_models():
    embedder = SentenceTransformer(EMBEDDING_MODEL)
    tokenizer = AutoTokenizer.from_pretrained(LLM_MODEL)
    llm = AutoModelForSeq2SeqLM.from_pretrained(LLM_MODEL)
    return embedder, tokenizer, llm


def load_index():
    if not os.path.exists(INDEX_PATH) or not os.path.exists(METADATA_PATH):
        raise FileNotFoundError(
            "FAISS index or metadata not found. Run the ingestion pipeline first."
        )
    index = faiss.read_index(INDEX_PATH)
    with open(METADATA_PATH, "rb") as f:
        docs = pickle.load(f)
    return index, docs


embedder, tokenizer, llm = load_models()
index, documents = load_index()


# --------------------------
# Request / Response models
# --------------------------
class ChatRequest(BaseModel):
    session_id: str
    query: str
    confirm: Optional[bool] = None  # Used for scenario-based follow-up


class ChatResponse(BaseModel):
    answer: str
    sources: List[dict] = []
    follow_up: Optional[str] = None


# --------------------------
# Utility functions
# --------------------------
def embed_text(text: str):
    return embedder.encode([text])


def retrieve(query: str, k: int = 5):
    # E5-style models expect the "query: " prefix on query embeddings
    q_emb = embed_text("query: " + query)
    D, I = index.search(q_emb, k)
    # Skip the -1 padding indices FAISS returns when fewer than k hits exist
    retrieved = [documents[i] for i in I[0] if i != -1]
    return retrieved


def build_prompt(query: str, retrieved: List[dict]):
    context_str = "\n\n".join([doc["text"] for doc in retrieved])
    prompt = (
        f"Question: {query}\n\n"
        "Based on the context below, write a concise, factual answer.\n"
        "If the answer is unknown, say \"I cannot find this info in the documents.\"\n\n"
        f"Context:\n{context_str}\n\n"
        "Answer:"
    )
    return prompt


def generate_answer(prompt: str) -> str:
    inputs = tokenizer(
        prompt, return_tensors="pt", truncation=True, max_length=1536
    ).input_ids
    outputs = llm.generate(
        inputs,
        max_new_tokens=500,
        min_length=60,
        num_beams=4,
        early_stopping=True,
    )
    return tokenizer.decode(outputs[0], skip_special_tokens=True)


def store_pending(session_id: str, data: dict):
    try:
        redis_client.setex(f"session:{session_id}:pending", 600, json.dumps(data))
    except Exception:
        local_cache[f"session:{session_id}:pending"] = json.dumps(data)


def get_pending(session_id: str) -> Optional[dict]:
    try:
        raw = redis_client.get(f"session:{session_id}:pending")
    except Exception:
        raw = local_cache.get(f"session:{session_id}:pending")
    if raw:
        return json.loads(raw)
    return None


def clear_pending(session_id: str):
    try:
        redis_client.delete(f"session:{session_id}:pending")
    except Exception:
        pass
    local_cache.pop(f"session:{session_id}:pending", None)


# --------------------------
# Core endpoints
# --------------------------
@app.get("/", include_in_schema=False)
def root():
    return RedirectResponse(url="/docs")


@app.post("/chat", response_model=ChatResponse)
def chat(req: ChatRequest):
    # Check if there is a pending confirmation for this session
    pending = get_pending(req.session_id)
    if pending:
        # User is responding to a follow-up question
        if req.confirm is None:
            raise HTTPException(
                status_code=400,
                detail="Missing 'confirm' field for pending follow-up.",
            )
        if not req.confirm:
            # User declined – abort the flow
            clear_pending(req.session_id)
            return ChatResponse(answer="Okay, let me know if you need anything else.")
        # User confirmed – continue with the stored context
        retrieved = pending["retrieved"]
        prompt = pending["prompt"]
        answer = generate_answer(prompt)
        clear_pending(req.session_id)
        return ChatResponse(
            answer=answer,
            sources=[
                {"source": d["source"], "snippet": d["text"][:200]} for d in retrieved
            ],
        )

    # No pending state – normal processing.
    # Simple keyword detection for the scenario-based flow (can be extended).
    lowered = req.query.lower()
    if "leave" in lowered and "process" in lowered:
        # Ask for confirmation before revealing the full procedure
        retrieved = retrieve(req.query)
        prompt = build_prompt(req.query, retrieved)
        # Store the pending state so the follow-up turn can resume it
        store_pending(
            req.session_id,
            {
                "original_query": req.query,
                "retrieved": retrieved,
                "prompt": prompt,
            },
        )
        return ChatResponse(
            answer="May I proceed to explain the medical leave process?",
            follow_up="Please respond with 'confirm': true or false.",
        )

    # Regular RAG path
    retrieved = retrieve(req.query)
    prompt = build_prompt(req.query, retrieved)
    answer = generate_answer(prompt)
    return ChatResponse(
        answer=answer,
        sources=[
            {"source": d["source"], "snippet": d["text"][:200]} for d in retrieved
        ],
    )
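

# --------------------------
# Local entry point (a minimal sketch, not part of the original module):
# assumes `uvicorn` is installed; host, port, and the "demo" session id below
# are illustrative values only.
#
# Example two-step interaction against the /chat endpoint:
#   POST /chat  {"session_id": "demo", "query": "What is the medical leave process?"}
#     -> answer asking for confirmation, plus a follow_up hint
#   POST /chat  {"session_id": "demo", "query": "yes", "confirm": true}
#     -> full generated answer with source snippets
# --------------------------
if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=8000)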