Hugging Face Spaces status: Sleeping
| import os | |
| import json | |
| from typing import List, Optional | |
| import faiss | |
| import pickle | |
| import redis | |
| from fastapi import FastAPI, HTTPException | |
| from fastapi.responses import RedirectResponse | |
| from pydantic import BaseModel | |
| from sentence_transformers import SentenceTransformer | |
| from transformers import AutoTokenizer, AutoModelForSeq2SeqLM | |
| app = FastAPI(title="Medical Policy RAG Chatbot API") | |
| # -------------------------- | |
| # Configuration & Helpers | |
| # -------------------------- | |
| REDIS_URL = os.getenv("REDIS_URL", "redis://localhost:6379/0") | |
| redis_client = redis.from_url(REDIS_URL, socket_connect_timeout=1) | |
| # Local fallback for session data if Redis is unavailable | |
| local_cache = {} | |
| INDEX_PATH = os.getenv("FAISS_INDEX_PATH", "data/faiss.index") | |
| METADATA_PATH = os.getenv("METADATA_PATH", "data/metadata.pkl") | |
| EMBEDDING_MODEL = os.getenv("EMBED_MODEL", "intfloat/e5-base-v2") | |
| LLM_MODEL = os.getenv("LLM_MODEL", "google/flan-t5-large") | |
| # Load heavy models once – they are cached in the process memory. | |
| def load_models(): | |
| embedder = SentenceTransformer(EMBEDDING_MODEL) | |
| tokenizer = AutoTokenizer.from_pretrained(LLM_MODEL) | |
| llm = AutoModelForSeq2SeqLM.from_pretrained(LLM_MODEL) | |
| return embedder, tokenizer, llm | |
| def load_index(): | |
| if not os.path.exists(INDEX_PATH) or not os.path.exists(METADATA_PATH): | |
| raise FileNotFoundError("FAISS index or metadata not found. Run the ingestion pipeline first.") | |
| index = faiss.read_index(INDEX_PATH) | |
| with open(METADATA_PATH, "rb") as f: | |
| docs = pickle.load(f) | |
| return index, docs | |
| embedder, tokenizer, llm = load_models() | |
| index, documents = load_index() | |
| # -------------------------- | |
| # Request / Response models | |
| # -------------------------- | |
| class ChatRequest(BaseModel): | |
| session_id: str | |
| query: str | |
| confirm: Optional[bool] = None # Used for scenario‑based follow‑up | |
| class ChatResponse(BaseModel): | |
| answer: str | |
| sources: List[dict] = [] | |
| follow_up: Optional[str] = None | |
| # -------------------------- | |
| # Utility functions | |
| # -------------------------- | |
| def embed_text(text: str): | |
| return embedder.encode([text]) | |
| def retrieve(query: str, k: int = 5): | |
| q_emb = embed_text("query: " + query) | |
| D, I = index.search(q_emb, k) | |
| retrieved = [documents[i] for i in I[0]] | |
| return retrieved | |
| def build_prompt(query: str, retrieved: List[dict]): | |
| context_str = "\n\n".join([doc["text"] for doc in retrieved]) | |
| system = f"""Question: {query}\n\nBased on the context below, write a concise, factual answer.\nIf the answer is unknown, say \"I cannot find this info in the documents.\"\n\nContext:\n{context_str}\n\nAnswer:""" | |
| return system | |
| def generate_answer(prompt: str) -> str: | |
| inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=1536).input_ids | |
| outputs = llm.generate(inputs, max_new_tokens=500, min_length=60, num_beams=4, early_stopping=True) | |
| return tokenizer.decode(outputs[0], skip_special_tokens=True) | |
| def store_pending(session_id: str, data: dict): | |
| try: | |
| redis_client.setex(f"session:{session_id}:pending", 600, json.dumps(data)) | |
| except Exception: | |
| local_cache[f"session:{session_id}:pending"] = json.dumps(data) | |
| def get_pending(session_id: str) -> Optional[dict]: | |
| try: | |
| raw = redis_client.get(f"session:{session_id}:pending") | |
| except Exception: | |
| raw = local_cache.get(f"session:{session_id}:pending") | |
| if raw: | |
| return json.loads(raw) | |
| return None | |
| def clear_pending(session_id: str): | |
| try: | |
| redis_client.delete(f"session:{session_id}:pending") | |
| except Exception: | |
| pass | |
| local_cache.pop(f"session:{session_id}:pending", None) | |
| # -------------------------- | |
| # Core endpoint | |
| # -------------------------- | |
| def root(): | |
| return RedirectResponse(url="/docs") | |
| def chat(req: ChatRequest): | |
| # Check if there is a pending confirmation for this session | |
| pending = get_pending(req.session_id) | |
| if pending: | |
| # User is responding to a follow‑up question | |
| if req.confirm is None: | |
| raise HTTPException(status_code=400, detail="Missing 'confirm' field for pending follow‑up.") | |
| if not req.confirm: | |
| # User declined – abort the flow | |
| clear_pending(req.session_id) | |
| return ChatResponse(answer="Okay, let me know if you need anything else.") | |
| # User confirmed – continue with stored context | |
| query = pending["original_query"] | |
| retrieved = pending["retrieved"] | |
| prompt = pending["prompt"] | |
| answer = generate_answer(prompt) | |
| clear_pending(req.session_id) | |
| return ChatResponse(answer=answer, sources=[{"source": d["source"], "snippet": d["text"][:200]} for d in retrieved]) | |
| # No pending – normal processing | |
| # Simple keyword detection for scenario‑based flow (can be extended) | |
| lowered = req.query.lower() | |
| if "leave" in lowered and "process" in lowered: | |
| # Ask for confirmation before revealing the full procedure | |
| retrieved = retrieve(req.query) | |
| prompt = build_prompt(req.query, retrieved) | |
| # Store pending state | |
| store_pending(req.session_id, { | |
| "original_query": req.query, | |
| "retrieved": retrieved, | |
| "prompt": prompt | |
| }) | |
| return ChatResponse( | |
| answer="May I proceed to explain the medical leave process?", | |
| follow_up="Please respond with 'confirm': true or false." | |
| ) | |
| # Regular RAG path | |
| retrieved = retrieve(req.query) | |
| prompt = build_prompt(req.query, retrieved) | |
| answer = generate_answer(prompt) | |
| return ChatResponse( | |
| answer=answer, | |
| sources=[{"source": d["source"], "snippet": d["text"][:200]} for d in retrieved] | |
| ) | |