# MedicalRAG / main.py
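"""FastAPI service exposing a /chat endpoint for a medical-policy RAG chatbot.

Each query is embedded with a SentenceTransformer (intfloat/e5-base-v2 by
default), the nearest chunks are retrieved from a FAISS index, and an answer
is generated with a seq2seq model (google/flan-t5-large by default).
Confirmation follow-ups are tracked per session in Redis, with an in-process
dict as a fallback when Redis is unreachable.
"""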
import os
import json
from typing import List, Optional
import faiss
import pickle
import redis
from fastapi import FastAPI, HTTPException
from fastapi.responses import RedirectResponse
from pydantic import BaseModel
from sentence_transformers import SentenceTransformer
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
app = FastAPI(title="Medical Policy RAG Chatbot API")
# --------------------------
# Configuration & Helpers
# --------------------------
REDIS_URL = os.getenv("REDIS_URL", "redis://localhost:6379/0")
redis_client = redis.from_url(REDIS_URL, socket_connect_timeout=1)
# Local fallback for session data if Redis is unavailable
local_cache = {}
INDEX_PATH = os.getenv("FAISS_INDEX_PATH", "data/faiss.index")
METADATA_PATH = os.getenv("METADATA_PATH", "data/metadata.pkl")
EMBEDDING_MODEL = os.getenv("EMBED_MODEL", "intfloat/e5-base-v2")
LLM_MODEL = os.getenv("LLM_MODEL", "google/flan-t5-large")
# Load the heavy models once at startup – they stay cached in process memory.
def load_models():
    embedder = SentenceTransformer(EMBEDDING_MODEL)
    tokenizer = AutoTokenizer.from_pretrained(LLM_MODEL)
    llm = AutoModelForSeq2SeqLM.from_pretrained(LLM_MODEL)
    return embedder, tokenizer, llm

def load_index():
    if not os.path.exists(INDEX_PATH) or not os.path.exists(METADATA_PATH):
        raise FileNotFoundError("FAISS index or metadata not found. Run the ingestion pipeline first.")
    index = faiss.read_index(INDEX_PATH)
    with open(METADATA_PATH, "rb") as f:
        docs = pickle.load(f)
    return index, docs

embedder, tokenizer, llm = load_models()
index, documents = load_index()
# --------------------------
# Request / Response models
# --------------------------
class ChatRequest(BaseModel):
    session_id: str
    query: str
    confirm: Optional[bool] = None  # Used for the scenario-based follow-up

class ChatResponse(BaseModel):
    answer: str
    sources: List[dict] = []
    follow_up: Optional[str] = None
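
# Illustrative request/response shapes (the session id, query text, and snippet
# values below are made-up examples, not fixtures from this repo):
#   POST /chat  {"session_id": "abc123", "query": "What is the process for medical leave?"}
#     -> {"answer": "May I proceed to explain the medical leave process?",
#         "sources": [], "follow_up": "Please respond with 'confirm': true or false."}
#   POST /chat  {"session_id": "abc123", "query": "yes", "confirm": true}
#     -> {"answer": "<generated answer>", "sources": [{"source": "...", "snippet": "..."}], "follow_up": null}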
# --------------------------
# Utility functions
# --------------------------
def embed_text(text: str):
    return embedder.encode([text])

def retrieve(query: str, k: int = 5):
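    # intfloat/e5-* models are trained with a "query: " prefix on search queries
    # (and a "passage: " prefix on indexed text), hence the prefix added below.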
    q_emb = embed_text("query: " + query)
    D, I = index.search(q_emb, k)
    retrieved = [documents[i] for i in I[0]]
    return retrieved

def build_prompt(query: str, retrieved: List[dict]):
    context_str = "\n\n".join([doc["text"] for doc in retrieved])
    prompt = f"""Question: {query}\n\nBased on the context below, write a concise, factual answer.\nIf the answer is unknown, say \"I cannot find this info in the documents.\"\n\nContext:\n{context_str}\n\nAnswer:"""
    return prompt

def generate_answer(prompt: str) -> str:
    inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=1536).input_ids
    outputs = llm.generate(inputs, max_new_tokens=500, min_length=60, num_beams=4, early_stopping=True)
    return tokenizer.decode(outputs[0], skip_special_tokens=True)

def store_pending(session_id: str, data: dict):
    try:
        redis_client.setex(f"session:{session_id}:pending", 600, json.dumps(data))
    except Exception:
        local_cache[f"session:{session_id}:pending"] = json.dumps(data)

def get_pending(session_id: str) -> Optional[dict]:
    try:
        raw = redis_client.get(f"session:{session_id}:pending")
    except Exception:
        raw = local_cache.get(f"session:{session_id}:pending")
    if raw:
        return json.loads(raw)
    return None

def clear_pending(session_id: str):
    try:
        redis_client.delete(f"session:{session_id}:pending")
    except Exception:
        pass
    local_cache.pop(f"session:{session_id}:pending", None)
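
# Pending follow-up state is stored as JSON under "session:<session_id>:pending"
# with a 10-minute TTL (the 600 above); the payload carries the original query,
# the retrieved chunks, and the pre-built prompt (see the /chat handler below).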
# --------------------------
# Core endpoint
# --------------------------
@app.get("/", include_in_schema=False)
def root():
    return RedirectResponse(url="/docs")

@app.post("/chat", response_model=ChatResponse)
def chat(req: ChatRequest):
    # Check if there is a pending confirmation for this session
    pending = get_pending(req.session_id)
    if pending:
        # User is responding to a follow-up question
        if req.confirm is None:
            raise HTTPException(status_code=400, detail="Missing 'confirm' field for pending follow-up.")
        if not req.confirm:
            # User declined – abort the flow
            clear_pending(req.session_id)
            return ChatResponse(answer="Okay, let me know if you need anything else.")
        # User confirmed – continue with the stored context
        query = pending["original_query"]
        retrieved = pending["retrieved"]
        prompt = pending["prompt"]
        answer = generate_answer(prompt)
        clear_pending(req.session_id)
        return ChatResponse(
            answer=answer,
            sources=[{"source": d["source"], "snippet": d["text"][:200]} for d in retrieved],
        )

    # No pending follow-up – normal processing
    # Simple keyword detection for the scenario-based flow (can be extended)
    lowered = req.query.lower()
    if "leave" in lowered and "process" in lowered:
        # Ask for confirmation before revealing the full procedure
        retrieved = retrieve(req.query)
        prompt = build_prompt(req.query, retrieved)
        # Store the pending state for this session
        store_pending(req.session_id, {
            "original_query": req.query,
            "retrieved": retrieved,
            "prompt": prompt,
        })
        return ChatResponse(
            answer="May I proceed to explain the medical leave process?",
            follow_up="Please respond with 'confirm': true or false.",
        )

    # Regular RAG path
    retrieved = retrieve(req.query)
    prompt = build_prompt(req.query, retrieved)
    answer = generate_answer(prompt)
    return ChatResponse(
        answer=answer,
        sources=[{"source": d["source"], "snippet": d["text"][:200]} for d in retrieved],
    )
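
# Optional local-run convenience: a minimal sketch assuming uvicorn is installed
# (it is not imported or declared anywhere above); the host and port are
# arbitrary example values.
if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=8000)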