""" src/api/main.py — MediRAG FastAPI Application ============================================= FR-18: Two endpoints: GET /health → liveness check + Ollama status POST /evaluate → calls run_evaluation(), returns FR-17 JSON Design decisions: - DeBERTa model is loaded once at app startup (not per-request) - If any module raises an exception, partial results are returned (no HTTP 500) - HTTP 422 Pydantic validation errors are automatic - RAGAS is disabled by default (run_ragas=False) — set to True only if Ollama/OpenAI is available; the RAGAS module already fails gracefully. To run: uvicorn src.api.main:app --reload --host 0.0.0.0 --port 8000 """ from __future__ import annotations import os import logging import time from contextlib import asynccontextmanager from pathlib import Path from typing import Optional import requests import json import sqlite3 import yaml from datetime import datetime from fastapi import FastAPI, HTTPException, File, UploadFile from fastapi.middleware.cors import CORSMiddleware from fastapi.responses import RedirectResponse import threading from src.api.schemas import ( HealthResponse, EvaluateRequest, EvaluateResponse, QueryRequest, QueryResponse, RetrievedChunk, IngestRequest, ChatRequest, ModuleScore, ModuleResults, ) from src.evaluate import run_evaluation from src.pipeline.generator import generate_answer from src.pipeline.retriever import Retriever # --------------------------------------------------------------------------- # Logging # --------------------------------------------------------------------------- try: _cfg = yaml.safe_load(Path("config.yaml").read_text()) _log_level = _cfg.get("logging", {}).get("level", "INFO") _ollama_base = _cfg.get("llm", {}).get("base_url", "http://localhost:11434") _api_cfg = _cfg.get("api", {}) except Exception: _log_level = "INFO" _ollama_base = "http://localhost:11434" _api_cfg = {} logging.basicConfig( level=getattr(logging, _log_level, logging.INFO), format="%(asctime)s [%(levelname)s] %(name)s: %(message)s", ) logger = logging.getLogger(__name__) # --------------------------------------------------------------------------- # Database settings # --------------------------------------------------------------------------- def init_db(): Path("data").mkdir(exist_ok=True) conn = sqlite3.connect("data/logs.db") c = conn.cursor() c.execute(""" CREATE TABLE IF NOT EXISTS audit_logs ( id INTEGER PRIMARY KEY AUTOINCREMENT, timestamp TEXT, endpoint TEXT, question TEXT, answer TEXT, hrs INTEGER, risk_band TEXT, composite_score REAL, latency_ms INTEGER, intervention_applied BOOLEAN, details TEXT ) """) conn.commit() conn.close() def log_audit(endpoint: str, question: str, answer: str, hrs: int, risk_band: str, composite: float, latency: int, intervention: bool, details: dict): try: conn = sqlite3.connect("data/logs.db") c = conn.cursor() c.execute(""" INSERT INTO audit_logs (timestamp, endpoint, question, answer, hrs, risk_band, composite_score, latency_ms, intervention_applied, details) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?) """, ( datetime.utcnow().strftime("%Y-%m-%dT%H:%M:%SZ"), endpoint, question, answer, hrs, risk_band, composite, latency, intervention, json.dumps(details) )) conn.commit() conn.close() except Exception as e: logger.error(f"Failed to save audit log to DB: {e}") # --------------------------------------------------------------------------- # Lifespan: warm DeBERTa once at startup so the first request isn't slow # --------------------------------------------------------------------------- @asynccontextmanager async def lifespan(app: FastAPI): """Pre-warm DeBERTa and Retriever at startup.""" init_db() logger.info("MediRAG API starting — pre-warming models...") try: from src.modules.faithfulness import _get_model _get_model() logger.info("DeBERTa pre-warm complete.") except Exception as exc: logger.warning("DeBERTa pre-warm skipped: %s", exc) # Pre-load the retriever (BioBERT + FAISS index) into app state try: app.state.retriever = Retriever(_cfg) # Trigger lazy load now so first /query request isn't slow app.state.retriever._load_model() app.state.retriever._load_index() logger.info("Retriever pre-warm complete.") except Exception as exc: logger.warning("Retriever pre-warm skipped: %s", exc) app.state.retriever = None yield logger.info("MediRAG API shutting down.") # --------------------------------------------------------------------------- # App # --------------------------------------------------------------------------- app = FastAPI( title="MediRAG Evaluation API", description=( "Evaluate LLM-generated medical answers against retrieved evidence. " "Returns faithfulness, entity accuracy, source credibility, " "contradiction risk, and a composite Health Risk Score (HRS)." ), version="0.1.0", lifespan=lifespan, ) # Allow all origins for local dev / React frontend on same machine app.add_middleware( CORSMiddleware, allow_origins=["*"], allow_methods=["GET", "POST"], allow_headers=["*"], ) # --------------------------------------------------------------------------- # Helper: check Ollama # --------------------------------------------------------------------------- def _check_ollama() -> bool: """Return True if Ollama API is reachable.""" try: resp = requests.get(f"{_ollama_base}/api/tags", timeout=2) return resp.status_code == 200 except Exception: return False # --------------------------------------------------------------------------- # Helper: convert EvalResult details → ModuleScore # --------------------------------------------------------------------------- def _module_score(module_results: dict, key: str) -> Optional[ModuleScore]: data = module_results.get(key) if data is None: return None return ModuleScore( score=data.get("score", 0.0), details=data.get("details", {}), error=data.get("error"), latency_ms=data.get("latency_ms"), ) # --------------------------------------------------------------------------- # GET / → redirect to /docs # --------------------------------------------------------------------------- @app.post("/project-guide") def project_guide(req: ChatRequest): """ Proxy endpoint for the Project Guide chatbot. Routes requests to Groq API using the local GROQ_API_KEY. """ groq_url = "https://api.groq.com/openai/v1/chat/completions" api_key = os.getenv("GROQ_API_KEY") if not api_key: raise HTTPException(status_code=500, detail="GROQ_API_KEY not found in server environment") headers = { "Authorization": f"Bearer {api_key}", "Content-Type": "application/json" } # Format messages for Groq messages = [] if req.system_prompt: messages.append({"role": "system", "content": req.system_prompt}) for m in req.messages: messages.append({"role": m.role, "content": m.content}) payload = { "model": "mixtral-8x7b-32768", "messages": messages, "temperature": 0.5, "max_tokens": 1024 } try: resp = requests.post(groq_url, headers=headers, json=payload, timeout=30) resp.raise_for_status() return resp.json() except Exception as e: logger.error(f"Groq Proxy Error: {e}") raise HTTPException(status_code=500, detail=str(e)) @app.get("/", include_in_schema=False) def root(): return RedirectResponse(url="/docs") # --------------------------------------------------------------------------- # GET /health # --------------------------------------------------------------------------- @app.get("/health", response_model=HealthResponse, tags=["system"]) def health() -> HealthResponse: """ Liveness check. Returns {"status": "ok", "ollama_available": true/false}. Always returns 200 — the caller decides what to do with `ollama_available`. """ return HealthResponse( status="ok", ollama_available=_check_ollama(), ) # --------------------------------------------------------------------------- # POST /evaluate # --------------------------------------------------------------------------- @app.post("/evaluate", response_model=EvaluateResponse, tags=["evaluation"]) def evaluate(req: EvaluateRequest) -> EvaluateResponse: """ Run the full MediRAG evaluation pipeline on a question + answer + context. - Validates inputs (FR-18: length limits, chunk count) - Runs Faithfulness, Entity Verification, Source Credibility, Contradiction - Optionally runs RAGAS (set `run_ragas=true` if Ollama/OpenAI is available) - Returns composite Health Risk Score (HRS) + per-module breakdown **Note on `run_ragas`**: RAGAS requires a running LLM backend (Ollama or OpenAI). If unavailable, RAGAS will gracefully return score=0.5 as a neutral fallback — it will NOT crash the request. """ logger.info( "POST /evaluate — question=%r, chunks=%d, run_ragas=%s", req.question[:80], len(req.context_chunks), req.run_ragas, ) # Convert Pydantic ContextChunk → plain dicts for the pipeline context_dicts: list[dict] = [chunk.model_dump(exclude_none=True) for chunk in req.context_chunks] t0 = time.perf_counter() try: result = run_evaluation( question=req.question, answer=req.answer, context_chunks=context_dicts, rxnorm_cache_path=req.rxnorm_cache_path, run_ragas=req.run_ragas, config=_cfg, ) except Exception as exc: logger.exception("run_evaluation raised an unhandled exception: %s", exc) raise HTTPException( status_code=500, detail=f"Evaluation pipeline error: {type(exc).__name__}: {exc}", ) from exc total_ms = int((time.perf_counter() - t0) * 1000) # Extract composite score + details composite = float(result.score) details = result.details or {} hrs = details.get("hrs", int(round(100 * (1.0 - composite)))) hrs = max(0, min(100, hrs)) confidence_level = details.get("confidence_level", "UNKNOWN") risk_band = details.get("risk_band", "UNKNOWN") pipeline_ms = details.get("total_pipeline_ms", total_ms) # Build per-module scores mod_results: dict = details.get("module_results", {}) module_scores = ModuleResults( faithfulness=_module_score(mod_results, "faithfulness"), entity_verifier=_module_score(mod_results, "entity_verifier"), source_credibility=_module_score(mod_results, "source_credibility"), contradiction=_module_score(mod_results, "contradiction"), ragas=_module_score(mod_results, "ragas"), ) logger.info( "POST /evaluate → HRS=%d (%s) in %d ms", hrs, risk_band, pipeline_ms, ) log_audit("evaluate", req.question, req.answer, hrs, risk_band, composite, pipeline_ms, False, { "module_results": mod_results, "confidence_level": confidence_level }) return EvaluateResponse( composite_score=composite, hrs=hrs, confidence_level=confidence_level, risk_band=risk_band, module_results=module_scores, total_pipeline_ms=pipeline_ms, ) # --------------------------------------------------------------------------- # POST /query — end-to-end: question → retrieve → generate → evaluate # --------------------------------------------------------------------------- @app.post("/query", response_model=QueryResponse, tags=["query"]) def query(req: QueryRequest) -> QueryResponse: """ Full end-to-end MediRAG pipeline. 1. Retrieves top-k context chunks from FAISS (BioBERT) 2. Generates a grounded answer using Mistral (Ollama) 3. Evaluates the answer with all 4 modules + aggregator 4. Returns the answer, retrieved chunks, HRS score, and full breakdown **Requires Ollama running locally with Mistral pulled.** No fallback — returns 503 if Ollama is unavailable. """ import time as _time t_total = _time.perf_counter() logger.info("POST /query — question=%r, top_k=%d", req.question[:80], req.top_k) # Step 1: Retrieve retriever: Optional[Retriever] = getattr(app.state, "retriever", None) if retriever is None: # Fallback: instantiate now (slower first call) try: retriever = Retriever(_cfg) except Exception as exc: raise HTTPException(status_code=503, detail=f"Retriever unavailable: {exc}") from exc try: raw_results = retriever.search(req.question, top_k=req.top_k) except FileNotFoundError as exc: raise HTTPException(status_code=503, detail=f"FAISS index not found: {exc}") from exc except Exception as exc: raise HTTPException(status_code=500, detail=f"Retrieval error: {exc}") from exc if not raw_results: raise HTTPException(status_code=404, detail="No relevant documents found for this question.") # Convert retriever output → chunk dicts for generator + evaluate context_chunks: list[dict] = [] retrieved_chunks_out: list[RetrievedChunk] = [] for chunk_text, meta, score in raw_results: d = { "text": chunk_text, "chunk_id": meta.get("chunk_id"), "source": meta.get("source", ""), "pub_type": meta.get("pub_type", ""), "pub_year": meta.get("pub_year"), "title": meta.get("title", ""), } context_chunks.append(d) retrieved_chunks_out.append(RetrievedChunk( chunk_id=meta.get("chunk_id"), text=chunk_text[:500], # truncate for response readability source=meta.get("source", ""), pub_type=meta.get("pub_type", ""), pub_year=meta.get("pub_year"), title=meta.get("title", ""), similarity_score=round(score, 4), )) logger.info("Retrieved %d chunks (top score=%.4f)", len(context_chunks), raw_results[0][2] if raw_results else 0.0) # Raw FAISS cosine similarity for coverage gap gate. # IndexFlatIP + L2-norm = cosine in [-1, 1]. < 0.60 means no close semantic match in DB. top_faiss_cosine = ( raw_results[0][1].get("_top_faiss_cosine", 0.0) if raw_results else 0.0 ) # Convert request overrides into a dict for generator llm_overrides = {} if req.llm_provider: llm_overrides["provider"] = req.llm_provider if req.llm_api_key: llm_overrides["api_key"] = req.llm_api_key if req.llm_model: llm_overrides["model"] = req.llm_model if req.ollama_url: llm_overrides["ollama_url"] = req.ollama_url if req.system_prompt: llm_overrides["system_prompt"] = req.system_prompt if req.persona: llm_overrides["persona"] = req.persona # ========================================================================= # Step 2a: PRIVACY SHIELD — MediRAG redacts PHI (Option 1) # ========================================================================= p_mapping = {} privacy_applied = False question_to_gen = req.question if req.use_privacy_shield: from src.pipeline.privacy import shield question_to_gen, p_mapping = shield.redact(req.question) if p_mapping: privacy_applied = True logger.info("PRIVACY INTERVENTION: Redacted %d items from question.", len(p_mapping)) # Step 2: Generate answer via LLM (Gemini or Ollama) try: # Use the potentially redacted question for generation answer = generate_answer(question_to_gen, context_chunks, _cfg, overrides=llm_overrides) except RuntimeError as exc: raise HTTPException(status_code=503, detail=f"LLM generation failed: {exc}") from exc # Restore the PHI for the final display so the user sees the actual names if privacy_applied: from src.pipeline.privacy import shield answer = shield.restore(answer, p_mapping) # ========================================================================= # ========================================================================= # Step 2b: CONSENSUS CHECK — MediRAG compares multiple models (Option 2) # ========================================================================= consensus_results = None if req.use_consensus: from src.pipeline.consensus import run_consensus_check # Determine which providers to use based on available config/overrides providers = ["gemini"] if os.environ.get("GROQ_API_KEY"): providers.append("groq") elif os.environ.get("MISTRAL_API_KEY"): providers.append("mistral") else: providers.append("ollama") # fallback to local if no second key logger.info("Running Consensus Layer with %s", providers) consensus_results = run_consensus_check(req.question, context_chunks, _cfg, providers=providers) # If consensus finds a safer merged answer, we promote it # and update the primary answer for the evaluation loop answer = consensus_results.get("consensus_answer", answer) # ========================================================================= # [DEMO MODE] Inject a false claim to demonstrate the intervention system if req.inject_hallucination: logger.warning("DEMO MODE: Injecting hallucinated claim into answer: '%s'", req.inject_hallucination) answer = answer + " " + req.inject_hallucination.strip() # Step 3: Evaluate try: eval_result = run_evaluation( question=req.question, answer=answer, context_chunks=context_chunks, run_ragas=req.run_ragas, config=_cfg, ) except Exception as exc: logger.exception("Evaluation failed: %s", exc) try: log_audit("query", req.question, answer, 100, "EVAL_ERROR", 0.0, int((_time.perf_counter() - t_total) * 1000), False, {"error": str(exc), "error_type": "evaluation_failure"}) except Exception: pass raise HTTPException(status_code=500, detail=f"Evaluation error: {exc}") from exc # ========================================================================= # Step 3b: INTERVENTION LOOP — MediRAG acts on evaluation results # ========================================================================= from src.pipeline.generator import generate_strict_answer details = eval_result.details or {} composite = float(eval_result.score) hrs = int(round(100 * (1.0 - composite))) hrs = max(0, min(100, hrs)) mod_results: dict = details.get("module_results", {}) intervention_applied = False intervention_reason = None original_answer = None intervention_details = None faith_score = (mod_results.get("faithfulness") or {}).get("score", 1.0) # Source-credibility-aware faith threshold: high-credibility sources get more tolerance source_cred = float(details.get("component_scores", {}).get("source_credibility", 0.5)) faith_threshold = max(0.3, 0.7 - (source_cred * 0.4)) # 0.30 for cred=1.0, 0.66 for cred=0.3 # ── Coverage Gap Gate ──────────────────────────────────────────────────── # Two signals combined: # 1. Refusal answer — LLM says "not in context / insufficient evidence" # → LLM itself confirms the DB doesn't cover this topic. # 2. FAISS cosine — genuinely poor semantic match vs. the query. # BioBERT clusters medical dosing texts, so threshold must be high (0.75). _REFUSAL_PATTERNS = ( "not mentioned in the provided context", "not provided in the retrieved context", "insufficient evidence in retrieved context", "no information about", "not in the provided context", "cannot find information", "the retrieved context does not contain", "the context does not contain", "not mentioned in the context", "is not provided in the context", ) _answer_lower = answer.lower() is_refusal_answer = any(p in _answer_lower for p in _REFUSAL_PATTERNS) is_low_faiss = top_faiss_cosine < 0.75 # If a verified drug with rxcui appears in the question, the intervention's # FDA direct lookup can still retrieve the right data even when initial FAISS # retrieval missed it. Don't label those as coverage gaps — let intervention run. _ev_entities = (mod_results.get("entity_verifier") or {}).get("details", {}).get("entities", []) _q_lower_cg = req.question.lower() _drug_in_question = any( e.get("rxcui") and e.get("entity", "").lower() in _q_lower_cg for e in _ev_entities ) # Refusal is a standalone COVERAGE_GAP signal — faith_score is unreliable here # because NLI scores refusal sentences as NEUTRAL (0.5), not low. # Exception: if a drug is named in the question, FDA lookup can still help. # HALLUCINATION: specific claims made but not grounded in available context. if is_refusal_answer and not _drug_in_question: gap_type = "COVERAGE_GAP" elif faith_score < faith_threshold and is_low_faiss and not _drug_in_question: gap_type = "COVERAGE_GAP" # poor retrieval + low faith = DB lacks this topic elif faith_score < faith_threshold: gap_type = "HALLUCINATION" # relevant context exists but answer ignores it else: gap_type = None coverage_gap = gap_type == "COVERAGE_GAP" coverage_gap_details: dict | None = { "gap_type": gap_type, "top_faiss_cosine": round(top_faiss_cosine, 4), "is_refusal_answer": is_refusal_answer, "note": ( "Database coverage may be insufficient for this topic. " "The answer could not be verified against retrieved evidence. " "Consult primary medical literature or a specialist." ) if coverage_gap else None, } if gap_type else None if coverage_gap: logger.warning( "COVERAGE_GAP detected — refusal=%s, faiss=%.4f, faith=%.2f", is_refusal_answer, top_faiss_cosine, faith_score, ) # Tier 1: CRITICAL BLOCK (HRS ≥ 86) — response is too dangerous to show # Coverage gap: skip both tiers — regenerating from an empty DB won't help if coverage_gap: logger.info("COVERAGE_GAP — skipping intervention (regeneration cannot add missing data).") elif hrs >= 86: original_answer = answer answer = ( "⛔ UNSAFE RESPONSE BLOCKED by MediRAG Safety Gate.\n\n" "The generated answer was flagged as CRITICAL risk " f"(Health Risk Score: {hrs}/100). " "It showed signs of hallucination or contradiction with the retrieved evidence. " "Please consult a qualified medical professional or rephrase your question." ) intervention_applied = True intervention_reason = "CRITICAL_BLOCKED" intervention_details = { "hrs_original": hrs, "faithfulness": faith_score, "message": "Response blocked: HRS ≥ 86 (CRITICAL risk band).", } logger.warning("INTERVENTION: CRITICAL_BLOCKED — HRS=%d", hrs) # Tier 2: HIGH RISK REGENERATION elif hrs >= 61 or faith_score < faith_threshold: original_answer = answer original_hrs = hrs logger.warning( "INTERVENTION: HIGH_RISK_REGENERATED — HRS=%d, faith=%.2f. Regenerating with strict prompt.", hrs, faith_score ) try: # Re-retrieve from shared index — find better chunks than the ones that failed try: # Direct FDA lookup — only for drugs named in the question itself. # Drugs found in the answer but NOT in the question (e.g. metformin # mentioned incidentally in a general "first-line treatment" answer) # should not trigger FDA lookup; that would replace relevant context # with the wrong label sections (contraindications instead of treatment). fda_direct: list[dict] = [] try: ev_details = eval_result.details.get("module_results", {}).get("entity_verifier", {}).get("details", {}) verified_drugs = [ e["entity"] for e in ev_details.get("entities", []) if e.get("status") == "VERIFIED" and e.get("rxcui") ] q_lower = req.question.lower() for drug in verified_drugs: if drug.lower() in q_lower: fda_direct += app.state.retriever.get_fda_chunks(drug) if fda_direct: logger.info("Direct FDA lookup found %d chunks for drugs: %s", len(fda_direct), [d for d in verified_drugs if d.lower() in q_lower]) except Exception as fda_exc: logger.debug("Direct FDA lookup skipped: %s", fda_exc) # Direct guideline lookup — only when original retrieval was poor. # If FAISS cosine ≥ 0.85 the original chunks were already relevant; # adding guideline sections here can pull in wrong topic areas # (e.g., ADA Section 2 Diagnosis instead of Section 9 Treatment). guideline_direct: list[dict] = [] if top_faiss_cosine < 0.85: try: guideline_direct = app.state.retriever.get_guideline_chunks(req.question) if guideline_direct: logger.info("Direct guideline lookup found %d chunks", len(guideline_direct)) except Exception as gl_exc: logger.debug("Direct guideline lookup skipped: %s", gl_exc) else: logger.debug("Skipping guideline direct lookup (FAISS cosine=%.4f ≥ 0.85, original retrieval was high-quality)", top_faiss_cosine) # Merge: guideline chunks + FDA chunks + fresh retrieval fda_direct = guideline_direct + fda_direct # For drug/clinical questions, expand query toward authoritative sources _drug_terms = ("contraindication", "dosage", "dose", "interaction", "warning", "adverse", "side effect", "mechanism") _q_lower = req.question.lower() retry_query = ( f"FDA drug label clinical guideline {req.question}" if any(t in _q_lower for t in _drug_terms) else req.question ) fresh_results = app.state.retriever.search(retry_query, top_k=req.top_k) fresh_chunks: list[dict] = [] for chunk_text, meta, score in fresh_results: fresh_chunks.append({ "text": chunk_text, "chunk_id": meta.get("chunk_id"), "source": meta.get("source", ""), "pub_type": meta.get("pub_type", ""), "pub_year": meta.get("pub_year"), "title": meta.get("title", ""), }) # Merge: direct lookups first (FDA/guidelines), then fresh retrieval base_chunks = fresh_chunks if fresh_chunks else context_chunks retry_chunks = (fda_direct + base_chunks)[:req.top_k] if fda_direct else base_chunks logger.info("Re-retrieval for intervention: %d fresh chunks (top source: %s)", len(retry_chunks), retry_chunks[0].get("pub_type", "?") if retry_chunks else "none") except Exception: retry_chunks = context_chunks answer = generate_strict_answer(req.question, retry_chunks, _cfg, overrides=llm_overrides) # Re-evaluate the corrected answer eval_result = run_evaluation( question=req.question, answer=answer, context_chunks=retry_chunks, run_ragas=False, # skip RAGAS on retry to reduce latency config=_cfg, ) details = eval_result.details or {} composite = float(eval_result.score) hrs = int(round(100 * (1.0 - composite))) hrs = max(0, min(100, hrs)) mod_results = details.get("module_results", {}) except Exception as exc: logger.error("Strict regeneration failed: %s — keeping original answer", exc) answer = original_answer # fall back gracefully original_answer = None intervention_applied = True intervention_reason = "HIGH_RISK_REGENERATED" intervention_details = { "hrs_original": original_hrs, "hrs_corrected": hrs, "faithfulness_original": faith_score, "faithfulness_corrected": (mod_results.get("faithfulness") or {}).get("score", 0), "message": "Response regenerated with strict context-only prompt due to high risk score.", } # ========================================================================= # Step 4: Build response total_ms = int((_time.perf_counter() - t_total) * 1000) logger.info("POST /query → HRS=%d (%s) intervention=%s in %d ms total", hrs, details.get("risk_band", "?"), intervention_reason or "none", total_ms) log_audit("query", req.question, answer, hrs, details.get("risk_band", "UNKNOWN"), composite, total_ms, intervention_applied, { "module_results": mod_results, "confidence_level": details.get("confidence_level", "UNKNOWN"), "intervention_reason": intervention_reason, "original_answer": original_answer, }) return QueryResponse( question=req.question, generated_answer=answer, retrieved_chunks=retrieved_chunks_out, composite_score=composite, hrs=hrs, confidence_level=details.get("confidence_level", "UNKNOWN"), risk_band=details.get("risk_band", "UNKNOWN"), module_results=ModuleResults( faithfulness=_module_score(mod_results, "faithfulness"), entity_verifier=_module_score(mod_results, "entity_verifier"), source_credibility=_module_score(mod_results, "source_credibility"), contradiction=_module_score(mod_results, "contradiction"), ragas=_module_score(mod_results, "ragas"), ), total_pipeline_ms=total_ms, intervention_applied=intervention_applied, intervention_reason=intervention_reason, original_answer=original_answer, intervention_details=intervention_details, consensus_results=consensus_results, privacy_applied=privacy_applied, privacy_details={"redacted_count": len(p_mapping)} if privacy_applied else None, coverage_gap=coverage_gap, coverage_gap_details=coverage_gap_details, ) # --------------------------------------------------------------------------- # POST /ingest — dynamically append new documents to the FAISS index # --------------------------------------------------------------------------- _faiss_lock = threading.Lock() @app.post("/ingest", tags=["ingestion"]) def ingest_document(req: IngestRequest): """ Dynamically ingest a new document into the running FAISS index. Thread-safe implementation uses a lock to prevent concurrent write corruption. """ import pickle import faiss from src.pipeline.chunker import chunk_documents retriever = getattr(app.state, "retriever", None) if retriever is None or retriever._index is None: raise HTTPException(status_code=503, detail="Retriever not pre-warmed. Cannot ingest.") logger.info("POST /ingest — title='%s', size=%d chars", req.title, len(req.text)) # 1. Chunk the document doc = { "text": req.text, "doc_id": "custom_" + req.title[:10], "title": req.title, "source": req.source, "pub_type": req.pub_type, "pub_year": 2026, } chunks = chunk_documents([doc], _cfg) if not chunks: raise HTTPException(status_code=400, detail="Document produced 0 chunks.") # 2. Embed the chunks using the same BioBERT model as the retriever from src.pipeline.embedder import encode_texts import numpy as np # Reuse already-loaded SentenceTransformer from the retriever to avoid double RAM load if retriever._model is None: retriever._load_model() st_model = retriever._model texts = [c["chunk_text"] for c in chunks] embeddings = np.array(st_model.encode(texts, show_progress_bar=False), dtype=np.float32) faiss.normalize_L2(embeddings) # Required: index is IndexFlatIP = cosine sim # 3. Thread-safe Index Update with atomic disk writes with _faiss_lock: import os idx_path = Path(_cfg["retrieval"]["index_path"]) meta_path = Path(_cfg["retrieval"]["metadata_path"]) index = retriever._index metadata_store = retriever._metadata start_id = len(metadata_store) # Add to in-memory structures for i, chunk in enumerate(chunks): metadata_store[start_id + i] = chunk # Add to FAISS in memory index.add(embeddings) # Atomic FAISS write: write to temp → rename (never leaves a half-written file) idx_tmp = str(idx_path) + ".tmp" faiss.write_index(index, idx_tmp) os.replace(idx_tmp, str(idx_path)) # Atomic metadata write meta_tmp = str(meta_path) + ".tmp" with open(meta_tmp, "wb") as f: pickle.dump(metadata_store, f) os.replace(meta_tmp, str(meta_path)) # 4. Rebuild BM25 for the running instance retriever.rebuild_bm25() logger.info("Successfully injected %d chunks for '%s' into FAISS and BM25.", len(chunks), req.title) return {"status": "success", "chunks_added": len(chunks), "title": req.title} # --------------------------------------------------------------------------- # GET /logs and /stats — fetch history for dashboard # --------------------------------------------------------------------------- @app.get("/logs", tags=["dashboard"]) def get_logs(limit: int = 50): try: conn = sqlite3.connect("data/logs.db") conn.row_factory = sqlite3.Row c = conn.cursor() c.execute("SELECT * FROM audit_logs ORDER BY id DESC LIMIT ?", (limit,)) rows = c.fetchall() conn.close() return [dict(ix) for ix in rows] except Exception as e: return [] @app.get("/stats", tags=["dashboard"]) def get_stats(): try: conn = sqlite3.connect("data/logs.db") c = conn.cursor() c.execute("SELECT COUNT(*), AVG(hrs), SUM(CASE WHEN risk_band='CRITICAL' THEN 1 ELSE 0 END) FROM audit_logs") total_evals, avg_hrs, crit_alerts = c.fetchone() c.execute("SELECT SUM(CASE WHEN intervention_applied=1 THEN 1 ELSE 0 END) FROM audit_logs") interventions = c.fetchone()[0] # Monthly data example monthly_query = "SELECT SUBSTR(timestamp, 1, 7) as month, AVG(hrs) FROM audit_logs GROUP BY month ORDER BY month LIMIT 12" c.execute(monthly_query) monthly_data = [{"month": row[0], "avg_hrs": row[1]} for row in c.fetchall()] conn.close() return { "totalEvals": total_evals or 0, "avgHrs": round(avg_hrs or 0, 1), "critAlerts": crit_alerts or 0, "interventions": interventions or 0, "monthly": monthly_data } except Exception as e: return { "totalEvals": 0, "avgHrs": 0, "critAlerts": 0, "interventions": 0, "monthly": [] } # --------------------------------------------------------------------------- # POST /parse_file — helper for frontend to extract PDF/DOCX text # --------------------------------------------------------------------------- @app.post("/parse_file", tags=["ingestion"]) async def parse_file(file: UploadFile = File(...)): """Extract text from uploaded txt, md, pdf, or docx files.""" content = await file.read() filename = file.filename.lower() text = "" try: if filename.endswith(".pdf"): import fitz doc = fitz.open(stream=content, filetype="pdf") msgs = [] for page in doc: msgs.append(page.get_text()) text = "\n".join(msgs) elif filename.endswith(".docx"): import docx from io import BytesIO doc = docx.Document(BytesIO(content)) text = "\n".join([p.text for p in doc.paragraphs]) else: text = content.decode("utf-8", errors="replace") return {"status": "success", "text": text} except Exception as e: raise HTTPException(status_code=400, detail=f"Failed to parse file: {e}")