Spaces:
Sleeping
Sleeping
Update: backend v3.2 — privacy pipeline, consensus, new scripts
Browse files- app.py +114 -98
- src/__init__.py +3 -2
- src/api/main.py +269 -44
- src/api/schemas.py +25 -0
- src/modules/entity_verifier.py +71 -3
- src/pipeline/generator.py +77 -1
- src/pipeline/retriever.py +3 -1
- src/pipeline/semantic_cache.py +112 -0
- tests/test_modules.py +24 -0
app.py
CHANGED
|
@@ -1,98 +1,114 @@
|
|
| 1 |
-
"""
|
| 2 |
-
MediRAG Backend - FastAPI only (No Gradio)
|
| 3 |
-
React frontend on Vercel, this is just the API backend
|
| 4 |
-
"""
|
| 5 |
-
|
| 6 |
-
import os
|
| 7 |
-
import sys
|
| 8 |
-
import subprocess
|
| 9 |
-
import logging
|
| 10 |
-
import requests
|
| 11 |
-
|
| 12 |
-
# Configure logging
|
| 13 |
-
logging.basicConfig(level=logging.INFO)
|
| 14 |
-
logger = logging.getLogger(__name__)
|
| 15 |
-
|
| 16 |
-
# Set cache directories for Hugging Face
|
| 17 |
-
os.environ["TRANSFORMERS_CACHE"] = "/tmp/transformers_cache"
|
| 18 |
-
os.environ["HF_HOME"] = "/tmp/hf_home"
|
| 19 |
-
os.environ["TORCH_HOME"] = "/tmp/torch_cache"
|
| 20 |
-
|
| 21 |
-
# Add src to path
|
| 22 |
-
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "src"))
|
| 23 |
-
|
| 24 |
-
# Install spaCy model if not present (optional — server starts without it)
|
| 25 |
-
try:
|
| 26 |
-
import spacy
|
| 27 |
-
try:
|
| 28 |
-
spacy.load("en_core_sci_lg")
|
| 29 |
-
logger.info("spaCy model en_core_sci_lg loaded.")
|
| 30 |
-
except OSError:
|
| 31 |
-
# Try installing the model at runtime
|
| 32 |
-
try:
|
| 33 |
-
logger.info("Attempting to install scispacy model en_core_sci_lg...")
|
| 34 |
-
subprocess.run([
|
| 35 |
-
sys.executable, "-m", "pip", "install", "--quiet",
|
| 36 |
-
"https://s3-us-west-2.amazonaws.com/ai2-s2-scispacy/releases/v0.5.4/en_core_sci_lg-0.5.4.tar.gz"
|
| 37 |
-
], check=True, timeout=300)
|
| 38 |
-
spacy.load("en_core_sci_lg")
|
| 39 |
-
logger.info("spaCy model installed and loaded.")
|
| 40 |
-
except Exception as model_err:
|
| 41 |
-
logger.warning(f"Could not install spaCy model: {model_err}. NER features will be limited.")
|
| 42 |
-
except ImportError:
|
| 43 |
-
logger.warning("spacy/scispacy not installed. NER features will be limited but server will still start.")
|
| 44 |
-
|
| 45 |
-
# Download datasets using huggingface_hub
|
| 46 |
-
from huggingface_hub import hf_hub_download
|
| 47 |
-
|
| 48 |
-
|
| 49 |
-
|
| 50 |
-
|
| 51 |
-
os.
|
| 52 |
-
|
| 53 |
-
|
| 54 |
-
|
| 55 |
-
|
| 56 |
-
|
| 57 |
-
|
| 58 |
-
|
| 59 |
-
|
| 60 |
-
|
| 61 |
-
|
| 62 |
-
|
| 63 |
-
|
| 64 |
-
|
| 65 |
-
|
| 66 |
-
|
| 67 |
-
|
| 68 |
-
|
| 69 |
-
|
| 70 |
-
|
| 71 |
-
|
| 72 |
-
|
| 73 |
-
|
| 74 |
-
|
| 75 |
-
|
| 76 |
-
|
| 77 |
-
|
| 78 |
-
|
| 79 |
-
|
| 80 |
-
|
| 81 |
-
|
| 82 |
-
|
| 83 |
-
|
| 84 |
-
|
| 85 |
-
|
| 86 |
-
|
| 87 |
-
|
| 88 |
-
|
| 89 |
-
|
| 90 |
-
|
| 91 |
-
|
| 92 |
-
if
|
| 93 |
-
|
| 94 |
-
|
| 95 |
-
|
| 96 |
-
|
| 97 |
-
|
| 98 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
MediRAG Backend - FastAPI only (No Gradio)
|
| 3 |
+
React frontend on Vercel, this is just the API backend
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
import os
|
| 7 |
+
import sys
|
| 8 |
+
import subprocess
|
| 9 |
+
import logging
|
| 10 |
+
import requests
|
| 11 |
+
|
| 12 |
+
# Configure logging
|
| 13 |
+
logging.basicConfig(level=logging.INFO)
|
| 14 |
+
logger = logging.getLogger(__name__)
|
| 15 |
+
|
| 16 |
+
# Set cache directories for Hugging Face
|
| 17 |
+
os.environ["TRANSFORMERS_CACHE"] = "/tmp/transformers_cache"
|
| 18 |
+
os.environ["HF_HOME"] = "/tmp/hf_home"
|
| 19 |
+
os.environ["TORCH_HOME"] = "/tmp/torch_cache"
|
| 20 |
+
|
| 21 |
+
# Add src to path
|
| 22 |
+
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "src"))
|
| 23 |
+
|
| 24 |
+
# Install spaCy model if not present (optional — server starts without it)
|
| 25 |
+
try:
|
| 26 |
+
import spacy
|
| 27 |
+
try:
|
| 28 |
+
spacy.load("en_core_sci_lg")
|
| 29 |
+
logger.info("spaCy model en_core_sci_lg loaded.")
|
| 30 |
+
except OSError:
|
| 31 |
+
# Try installing the model at runtime
|
| 32 |
+
try:
|
| 33 |
+
logger.info("Attempting to install scispacy model en_core_sci_lg...")
|
| 34 |
+
subprocess.run([
|
| 35 |
+
sys.executable, "-m", "pip", "install", "--quiet",
|
| 36 |
+
"https://s3-us-west-2.amazonaws.com/ai2-s2-scispacy/releases/v0.5.4/en_core_sci_lg-0.5.4.tar.gz"
|
| 37 |
+
], check=True, timeout=300)
|
| 38 |
+
spacy.load("en_core_sci_lg")
|
| 39 |
+
logger.info("spaCy model installed and loaded.")
|
| 40 |
+
except Exception as model_err:
|
| 41 |
+
logger.warning(f"Could not install spaCy model: {model_err}. NER features will be limited.")
|
| 42 |
+
except ImportError:
|
| 43 |
+
logger.warning("spacy/scispacy not installed. NER features will be limited but server will still start.")
|
| 44 |
+
|
| 45 |
+
# Download datasets using huggingface_hub
|
| 46 |
+
from huggingface_hub import hf_hub_download
|
| 47 |
+
import yaml
|
| 48 |
+
from pathlib import Path
|
| 49 |
+
|
| 50 |
+
# Check if config_local.yaml exists or USE_LOCAL_DATASET is set to skip HF downloads
|
| 51 |
+
config_path = os.environ.get("MEDIRAG_CONFIG", "config_local.yaml" if Path("config_local.yaml").exists() else "config.yaml")
|
| 52 |
+
try:
|
| 53 |
+
with open(config_path, "r", encoding="utf-8") as f:
|
| 54 |
+
config_data = yaml.safe_load(f)
|
| 55 |
+
except Exception:
|
| 56 |
+
config_data = {}
|
| 57 |
+
|
| 58 |
+
use_local_dataset = config_data.get("retrieval", {}).get("use_local_dataset", False) or os.environ.get("USE_LOCAL_DATASET", "false").lower() == "true"
|
| 59 |
+
|
| 60 |
+
# Check and download index and data files
|
| 61 |
+
data_dir = os.path.join(os.path.dirname(__file__), "data")
|
| 62 |
+
index_dir = os.path.join(data_dir, "index")
|
| 63 |
+
os.makedirs(index_dir, exist_ok=True)
|
| 64 |
+
|
| 65 |
+
faiss_path = os.path.join(index_dir, "faiss.index")
|
| 66 |
+
metadata_path = os.path.join(index_dir, "metadata_store.pkl")
|
| 67 |
+
bm25_path = os.path.join(index_dir, "bm25_cache.pkl")
|
| 68 |
+
vocab_path = os.path.join(data_dir, "drugbank vocabulary.csv")
|
| 69 |
+
rxnorm_path = os.path.join(data_dir, "rxnorm_cache.csv")
|
| 70 |
+
|
| 71 |
+
def download_dataset_files():
|
| 72 |
+
"""Download FAISS index and other core data from Hugging Face Dataset"""
|
| 73 |
+
if use_local_dataset:
|
| 74 |
+
logger.info("[LOCAL MODE] Bypassing Hugging Face repository download. Relying on local datasets in data/index/.")
|
| 75 |
+
return
|
| 76 |
+
|
| 77 |
+
repo_id = "joytheslothh/MediRAG-Index-Data"
|
| 78 |
+
token = os.environ.get("HF_TOKEN")
|
| 79 |
+
if not token:
|
| 80 |
+
logger.warning("HF_TOKEN environment variable is not set. Dataset download might fail if repo is private.")
|
| 81 |
+
|
| 82 |
+
try:
|
| 83 |
+
if not os.path.exists(faiss_path):
|
| 84 |
+
logger.info("Downloading faiss.index from HF dataset...")
|
| 85 |
+
hf_hub_download(repo_id=repo_id, filename="index/faiss.index", local_dir=data_dir, repo_type="dataset", token=token)
|
| 86 |
+
if not os.path.exists(metadata_path):
|
| 87 |
+
logger.info("Downloading metadata_store.pkl from HF dataset...")
|
| 88 |
+
hf_hub_download(repo_id=repo_id, filename="index/metadata_store.pkl", local_dir=data_dir, repo_type="dataset", token=token)
|
| 89 |
+
if not os.path.exists(bm25_path):
|
| 90 |
+
logger.info("Downloading bm25_cache.pkl from HF dataset...")
|
| 91 |
+
hf_hub_download(repo_id=repo_id, filename="index/bm25_cache.pkl", local_dir=data_dir, repo_type="dataset", token=token)
|
| 92 |
+
if not os.path.exists(vocab_path):
|
| 93 |
+
logger.info("Downloading drugbank vocabulary.csv from HF dataset...")
|
| 94 |
+
hf_hub_download(repo_id=repo_id, filename="drugbank vocabulary.csv", local_dir=data_dir, repo_type="dataset", token=token)
|
| 95 |
+
if not os.path.exists(rxnorm_path):
|
| 96 |
+
logger.info("Downloading rxnorm_cache.csv from HF dataset...")
|
| 97 |
+
hf_hub_download(repo_id=repo_id, filename="rxnorm_cache.csv", local_dir=data_dir, repo_type="dataset", token=token)
|
| 98 |
+
except Exception as e:
|
| 99 |
+
logger.error(f"Failed to download dataset files: {e}")
|
| 100 |
+
logger.warning("Backend may not start correctly or queries may fail.")
|
| 101 |
+
|
| 102 |
+
# Trigger download at startup
|
| 103 |
+
download_dataset_files()
|
| 104 |
+
|
| 105 |
+
# Import FastAPI app - this is the main backend for React frontend
|
| 106 |
+
from src.api.main import app
|
| 107 |
+
|
| 108 |
+
if __name__ == "__main__":
|
| 109 |
+
import uvicorn
|
| 110 |
+
# Get port from environment (Hugging Face uses 7860)
|
| 111 |
+
port = int(os.environ.get("PORT", 7860))
|
| 112 |
+
|
| 113 |
+
logger.info("Starting FastAPI backend on port {}".format(port))
|
| 114 |
+
uvicorn.run(app, host="0.0.0.0", port=port)
|
src/__init__.py
CHANGED
|
@@ -15,10 +15,11 @@ def _setup_logging() -> None:
|
|
| 15 |
log_format = "%(asctime)s [%(levelname)s] %(name)s: %(message)s"
|
| 16 |
log_file = "logs/medirag.log"
|
| 17 |
|
| 18 |
-
# Try to load level from config.yaml
|
| 19 |
try:
|
| 20 |
import yaml
|
| 21 |
-
|
|
|
|
| 22 |
cfg = yaml.safe_load(f)
|
| 23 |
level_str = cfg.get("logging", {}).get("level", "INFO")
|
| 24 |
log_level = getattr(logging, level_str.upper(), logging.INFO)
|
|
|
|
| 15 |
log_format = "%(asctime)s [%(levelname)s] %(name)s: %(message)s"
|
| 16 |
log_file = "logs/medirag.log"
|
| 17 |
|
| 18 |
+
# Try to load level from config_local.yaml or config.yaml
|
| 19 |
try:
|
| 20 |
import yaml
|
| 21 |
+
config_path = os.environ.get("MEDIRAG_CONFIG", "config_local.yaml" if os.path.exists("config_local.yaml") else "config.yaml")
|
| 22 |
+
with open(config_path, "r", encoding="utf-8") as f:
|
| 23 |
cfg = yaml.safe_load(f)
|
| 24 |
level_str = cfg.get("logging", {}).get("level", "INFO")
|
| 25 |
log_level = getattr(logging, level_str.upper(), logging.INFO)
|
src/api/main.py
CHANGED
|
@@ -32,6 +32,7 @@ from datetime import datetime
|
|
| 32 |
from fastapi import FastAPI, HTTPException, File, UploadFile
|
| 33 |
from fastapi.middleware.cors import CORSMiddleware
|
| 34 |
from fastapi.responses import RedirectResponse
|
|
|
|
| 35 |
|
| 36 |
import threading
|
| 37 |
from src.api.schemas import (
|
|
@@ -54,7 +55,8 @@ from src.pipeline.retriever import Retriever
|
|
| 54 |
# Logging
|
| 55 |
# ---------------------------------------------------------------------------
|
| 56 |
try:
|
| 57 |
-
|
|
|
|
| 58 |
_log_level = _cfg.get("logging", {}).get("level", "INFO")
|
| 59 |
_ollama_base = _cfg.get("llm", {}).get("base_url", "http://localhost:11434")
|
| 60 |
_api_cfg = _cfg.get("api", {})
|
|
@@ -354,6 +356,40 @@ def evaluate(req: EvaluateRequest) -> EvaluateResponse:
|
|
| 354 |
)
|
| 355 |
|
| 356 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 357 |
# ---------------------------------------------------------------------------
|
| 358 |
# POST /query — end-to-end: question → retrieve → generate → evaluate
|
| 359 |
# ---------------------------------------------------------------------------
|
|
@@ -373,20 +409,120 @@ def query(req: QueryRequest) -> QueryResponse:
|
|
| 373 |
import time as _time
|
| 374 |
t_total = _time.perf_counter()
|
| 375 |
|
| 376 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 377 |
|
| 378 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 379 |
retriever: Optional[Retriever] = getattr(app.state, "retriever", None)
|
| 380 |
if retriever is None:
|
| 381 |
-
# Fallback: instantiate now (slower first call)
|
| 382 |
try:
|
| 383 |
retriever = Retriever(_cfg)
|
|
|
|
| 384 |
except Exception as exc:
|
| 385 |
raise HTTPException(status_code=503,
|
| 386 |
detail=f"Retriever unavailable: {exc}") from exc
|
| 387 |
|
| 388 |
try:
|
| 389 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 390 |
except FileNotFoundError as exc:
|
| 391 |
raise HTTPException(status_code=503,
|
| 392 |
detail=f"FAISS index not found: {exc}") from exc
|
|
@@ -429,32 +565,16 @@ def query(req: QueryRequest) -> QueryResponse:
|
|
| 429 |
top_faiss_cosine = (
|
| 430 |
raw_results[0][1].get("_top_faiss_cosine", 0.0) if raw_results else 0.0
|
| 431 |
)
|
| 432 |
-
|
| 433 |
-
# Convert request overrides into a dict for generator
|
| 434 |
-
llm_overrides = {}
|
| 435 |
-
if req.llm_provider:
|
| 436 |
-
llm_overrides["provider"] = req.llm_provider
|
| 437 |
-
if req.llm_api_key:
|
| 438 |
-
llm_overrides["api_key"] = req.llm_api_key
|
| 439 |
-
if req.llm_model:
|
| 440 |
-
llm_overrides["model"] = req.llm_model
|
| 441 |
-
if req.ollama_url:
|
| 442 |
-
llm_overrides["ollama_url"] = req.ollama_url
|
| 443 |
-
if req.system_prompt:
|
| 444 |
-
llm_overrides["system_prompt"] = req.system_prompt
|
| 445 |
-
if req.persona:
|
| 446 |
-
llm_overrides["persona"] = req.persona
|
| 447 |
-
|
| 448 |
# =========================================================================
|
| 449 |
# Step 2a: PRIVACY SHIELD — MediRAG redacts PHI (Option 1)
|
| 450 |
# =========================================================================
|
| 451 |
p_mapping = {}
|
| 452 |
privacy_applied = False
|
| 453 |
-
question_to_gen =
|
| 454 |
|
| 455 |
if req.use_privacy_shield:
|
| 456 |
from src.pipeline.privacy import shield
|
| 457 |
-
question_to_gen, p_mapping = shield.redact(
|
| 458 |
if p_mapping:
|
| 459 |
privacy_applied = True
|
| 460 |
logger.info("PRIVACY INTERVENTION: Redacted %d items from question.", len(p_mapping))
|
|
@@ -489,7 +609,7 @@ def query(req: QueryRequest) -> QueryResponse:
|
|
| 489 |
providers.append("ollama") # fallback to local if no second key
|
| 490 |
|
| 491 |
logger.info("Running Consensus Layer with %s", providers)
|
| 492 |
-
consensus_results = run_consensus_check(
|
| 493 |
|
| 494 |
# If consensus finds a safer merged answer, we promote it
|
| 495 |
# and update the primary answer for the evaluation loop
|
|
@@ -505,7 +625,7 @@ def query(req: QueryRequest) -> QueryResponse:
|
|
| 505 |
# Step 3: Evaluate
|
| 506 |
try:
|
| 507 |
eval_result = run_evaluation(
|
| 508 |
-
question=
|
| 509 |
answer=answer,
|
| 510 |
context_chunks=context_chunks,
|
| 511 |
run_ragas=req.run_ragas,
|
|
@@ -514,7 +634,7 @@ def query(req: QueryRequest) -> QueryResponse:
|
|
| 514 |
except Exception as exc:
|
| 515 |
logger.exception("Evaluation failed: %s", exc)
|
| 516 |
try:
|
| 517 |
-
log_audit("query",
|
| 518 |
int((_time.perf_counter() - t_total) * 1000),
|
| 519 |
False, {"error": str(exc), "error_type": "evaluation_failure"})
|
| 520 |
except Exception:
|
|
@@ -538,6 +658,50 @@ def query(req: QueryRequest) -> QueryResponse:
|
|
| 538 |
original_answer = None
|
| 539 |
intervention_details = None
|
| 540 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 541 |
faith_score = (mod_results.get("faithfulness") or {}).get("score", 1.0)
|
| 542 |
|
| 543 |
# Source-credibility-aware faith threshold: high-credibility sources get more tolerance
|
|
@@ -570,7 +734,7 @@ def query(req: QueryRequest) -> QueryResponse:
|
|
| 570 |
# FDA direct lookup can still retrieve the right data even when initial FAISS
|
| 571 |
# retrieval missed it. Don't label those as coverage gaps — let intervention run.
|
| 572 |
_ev_entities = (mod_results.get("entity_verifier") or {}).get("details", {}).get("entities", [])
|
| 573 |
-
_q_lower_cg =
|
| 574 |
_drug_in_question = any(
|
| 575 |
e.get("rxcui") and e.get("entity", "").lower() in _q_lower_cg
|
| 576 |
for e in _ev_entities
|
|
@@ -606,17 +770,33 @@ def query(req: QueryRequest) -> QueryResponse:
|
|
| 606 |
is_refusal_answer, top_faiss_cosine, faith_score,
|
| 607 |
)
|
| 608 |
|
| 609 |
-
# Tier 1: CRITICAL BLOCK (HRS ≥
|
| 610 |
# Coverage gap: skip both tiers — regenerating from an empty DB won't help
|
| 611 |
-
if
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 612 |
logger.info("COVERAGE_GAP — skipping intervention (regeneration cannot add missing data).")
|
| 613 |
-
elif hrs >=
|
| 614 |
original_answer = answer
|
| 615 |
answer = (
|
| 616 |
"⛔ UNSAFE RESPONSE BLOCKED by MediRAG Safety Gate.\n\n"
|
| 617 |
"The generated answer was flagged as CRITICAL risk "
|
| 618 |
-
f"(Health Risk Score: {hrs}/100).
|
| 619 |
-
"It showed signs of hallucination or contradiction with the retrieved evidence. "
|
| 620 |
"Please consult a qualified medical professional or rephrase your question."
|
| 621 |
)
|
| 622 |
intervention_applied = True
|
|
@@ -624,12 +804,12 @@ def query(req: QueryRequest) -> QueryResponse:
|
|
| 624 |
intervention_details = {
|
| 625 |
"hrs_original": hrs,
|
| 626 |
"faithfulness": faith_score,
|
| 627 |
-
"message": "Response blocked: HRS
|
| 628 |
}
|
| 629 |
-
logger.warning("INTERVENTION: CRITICAL_BLOCKED — HRS=%d", hrs)
|
| 630 |
|
| 631 |
# Tier 2: HIGH RISK REGENERATION
|
| 632 |
-
elif hrs >=
|
| 633 |
original_answer = answer
|
| 634 |
original_hrs = hrs
|
| 635 |
logger.warning(
|
|
@@ -651,7 +831,7 @@ def query(req: QueryRequest) -> QueryResponse:
|
|
| 651 |
e["entity"] for e in ev_details.get("entities", [])
|
| 652 |
if e.get("status") == "VERIFIED" and e.get("rxcui")
|
| 653 |
]
|
| 654 |
-
q_lower =
|
| 655 |
for drug in verified_drugs:
|
| 656 |
if drug.lower() in q_lower:
|
| 657 |
fda_direct += app.state.retriever.get_fda_chunks(drug)
|
|
@@ -668,7 +848,7 @@ def query(req: QueryRequest) -> QueryResponse:
|
|
| 668 |
guideline_direct: list[dict] = []
|
| 669 |
if top_faiss_cosine < 0.85:
|
| 670 |
try:
|
| 671 |
-
guideline_direct = app.state.retriever.get_guideline_chunks(
|
| 672 |
if guideline_direct:
|
| 673 |
logger.info("Direct guideline lookup found %d chunks", len(guideline_direct))
|
| 674 |
except Exception as gl_exc:
|
|
@@ -682,11 +862,11 @@ def query(req: QueryRequest) -> QueryResponse:
|
|
| 682 |
# For drug/clinical questions, expand query toward authoritative sources
|
| 683 |
_drug_terms = ("contraindication", "dosage", "dose", "interaction",
|
| 684 |
"warning", "adverse", "side effect", "mechanism")
|
| 685 |
-
_q_lower =
|
| 686 |
retry_query = (
|
| 687 |
-
f"FDA drug label clinical guideline {
|
| 688 |
if any(t in _q_lower for t in _drug_terms)
|
| 689 |
-
else
|
| 690 |
)
|
| 691 |
fresh_results = app.state.retriever.search(retry_query, top_k=req.top_k)
|
| 692 |
fresh_chunks: list[dict] = []
|
|
@@ -705,10 +885,10 @@ def query(req: QueryRequest) -> QueryResponse:
|
|
| 705 |
except Exception:
|
| 706 |
retry_chunks = context_chunks
|
| 707 |
|
| 708 |
-
answer = generate_strict_answer(
|
| 709 |
# Re-evaluate the corrected answer
|
| 710 |
eval_result = run_evaluation(
|
| 711 |
-
question=
|
| 712 |
answer=answer,
|
| 713 |
context_chunks=retry_chunks,
|
| 714 |
run_ragas=False, # skip RAGAS on retry to reduce latency
|
|
@@ -740,15 +920,59 @@ def query(req: QueryRequest) -> QueryResponse:
|
|
| 740 |
logger.info("POST /query → HRS=%d (%s) intervention=%s in %d ms total",
|
| 741 |
hrs, details.get("risk_band", "?"), intervention_reason or "none", total_ms)
|
| 742 |
|
| 743 |
-
log_audit("query",
|
| 744 |
"module_results": mod_results,
|
| 745 |
"confidence_level": details.get("confidence_level", "UNKNOWN"),
|
| 746 |
"intervention_reason": intervention_reason,
|
| 747 |
"original_answer": original_answer,
|
| 748 |
})
|
| 749 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 750 |
return QueryResponse(
|
| 751 |
-
question=
|
| 752 |
generated_answer=answer,
|
| 753 |
retrieved_chunks=retrieved_chunks_out,
|
| 754 |
composite_score=composite,
|
|
@@ -770,6 +994,7 @@ def query(req: QueryRequest) -> QueryResponse:
|
|
| 770 |
consensus_results=consensus_results,
|
| 771 |
privacy_applied=privacy_applied,
|
| 772 |
privacy_details={"redacted_count": len(p_mapping)} if privacy_applied else None,
|
|
|
|
| 773 |
coverage_gap=coverage_gap,
|
| 774 |
coverage_gap_details=coverage_gap_details,
|
| 775 |
)
|
|
|
|
| 32 |
from fastapi import FastAPI, HTTPException, File, UploadFile
|
| 33 |
from fastapi.middleware.cors import CORSMiddleware
|
| 34 |
from fastapi.responses import RedirectResponse
|
| 35 |
+
from pydantic import BaseModel
|
| 36 |
|
| 37 |
import threading
|
| 38 |
from src.api.schemas import (
|
|
|
|
| 55 |
# Logging
|
| 56 |
# ---------------------------------------------------------------------------
|
| 57 |
try:
|
| 58 |
+
_config_path = os.environ.get("MEDIRAG_CONFIG", "config_local.yaml" if Path("config_local.yaml").exists() else "config.yaml")
|
| 59 |
+
_cfg = yaml.safe_load(Path(_config_path).read_text())
|
| 60 |
_log_level = _cfg.get("logging", {}).get("level", "INFO")
|
| 61 |
_ollama_base = _cfg.get("llm", {}).get("base_url", "http://localhost:11434")
|
| 62 |
_api_cfg = _cfg.get("api", {})
|
|
|
|
| 356 |
)
|
| 357 |
|
| 358 |
|
| 359 |
+
# ---------------------------------------------------------------------------
|
| 360 |
+
# POST /translate — lightweight Hinglish to English translation route
|
| 361 |
+
# ---------------------------------------------------------------------------
|
| 362 |
+
class TranslateRequest(BaseModel):
|
| 363 |
+
text: str
|
| 364 |
+
llm_provider: Optional[str] = None
|
| 365 |
+
llm_api_key: Optional[str] = None
|
| 366 |
+
llm_model: Optional[str] = None
|
| 367 |
+
ollama_url: Optional[str] = None
|
| 368 |
+
|
| 369 |
+
class TranslateResponse(BaseModel):
|
| 370 |
+
translated_text: str
|
| 371 |
+
|
| 372 |
+
@app.post("/translate", response_model=TranslateResponse, tags=["translation"])
|
| 373 |
+
def translate(req: TranslateRequest) -> TranslateResponse:
|
| 374 |
+
llm_overrides = {}
|
| 375 |
+
if req.llm_provider:
|
| 376 |
+
llm_overrides["provider"] = req.llm_provider
|
| 377 |
+
if req.llm_api_key:
|
| 378 |
+
llm_overrides["api_key"] = req.llm_api_key
|
| 379 |
+
if req.llm_model:
|
| 380 |
+
llm_overrides["model"] = req.llm_model
|
| 381 |
+
if req.ollama_url:
|
| 382 |
+
llm_overrides["ollama_url"] = req.ollama_url
|
| 383 |
+
|
| 384 |
+
try:
|
| 385 |
+
from src.pipeline.generator import translate_hinglish_to_english
|
| 386 |
+
translated = translate_hinglish_to_english(req.text, _cfg, overrides=llm_overrides)
|
| 387 |
+
return TranslateResponse(translated_text=translated)
|
| 388 |
+
except Exception as exc:
|
| 389 |
+
logger.exception("Translation endpoint failed: %s", exc)
|
| 390 |
+
raise HTTPException(status_code=500, detail=str(exc))
|
| 391 |
+
|
| 392 |
+
|
| 393 |
# ---------------------------------------------------------------------------
|
| 394 |
# POST /query — end-to-end: question → retrieve → generate → evaluate
|
| 395 |
# ---------------------------------------------------------------------------
|
|
|
|
| 409 |
import time as _time
|
| 410 |
t_total = _time.perf_counter()
|
| 411 |
|
| 412 |
+
# Extract request overrides into a dict for translator + generator
|
| 413 |
+
llm_overrides = {}
|
| 414 |
+
if req.llm_provider:
|
| 415 |
+
llm_overrides["provider"] = req.llm_provider
|
| 416 |
+
if req.llm_api_key:
|
| 417 |
+
llm_overrides["api_key"] = req.llm_api_key
|
| 418 |
+
if req.llm_model:
|
| 419 |
+
llm_overrides["model"] = req.llm_model
|
| 420 |
+
if req.ollama_url:
|
| 421 |
+
llm_overrides["ollama_url"] = req.ollama_url
|
| 422 |
+
if req.system_prompt:
|
| 423 |
+
llm_overrides["system_prompt"] = req.system_prompt
|
| 424 |
+
if req.persona:
|
| 425 |
+
llm_overrides["persona"] = req.persona
|
| 426 |
+
|
| 427 |
+
original_hinglish = None
|
| 428 |
+
question_to_use = req.question
|
| 429 |
|
| 430 |
+
if req.translate_hinglish:
|
| 431 |
+
if req.original_hinglish_query:
|
| 432 |
+
original_hinglish = req.original_hinglish_query
|
| 433 |
+
question_to_use = req.question
|
| 434 |
+
logger.info("PRE-TRANSLATED AUDIT SUBMISSION: %r -> %r", original_hinglish, question_to_use)
|
| 435 |
+
else:
|
| 436 |
+
try:
|
| 437 |
+
from src.pipeline.generator import translate_hinglish_to_english
|
| 438 |
+
translated_q = translate_hinglish_to_english(req.question, _cfg, overrides=llm_overrides)
|
| 439 |
+
if translated_q.strip().lower() != req.question.strip().lower():
|
| 440 |
+
original_hinglish = req.question
|
| 441 |
+
question_to_use = translated_q
|
| 442 |
+
logger.info("AUTO-TRANSLATED HINGLISH QUERY: %r -> %r", original_hinglish, question_to_use)
|
| 443 |
+
except Exception as exc:
|
| 444 |
+
logger.error("Hinglish translation module failed: %s", exc)
|
| 445 |
+
|
| 446 |
+
logger.info("POST /query — question=%r, processed_question=%r, top_k=%d",
|
| 447 |
+
req.question[:50], question_to_use[:50], req.top_k)
|
| 448 |
+
|
| 449 |
+
# Safe Semantic Cache lookup
|
| 450 |
+
q_vec = None
|
| 451 |
+
from src.pipeline.semantic_cache import SafeSemanticCache
|
| 452 |
+
semantic_cache = SafeSemanticCache()
|
| 453 |
+
|
| 454 |
+
# Check if retriever can encode
|
| 455 |
retriever: Optional[Retriever] = getattr(app.state, "retriever", None)
|
| 456 |
if retriever is None:
|
|
|
|
| 457 |
try:
|
| 458 |
retriever = Retriever(_cfg)
|
| 459 |
+
app.state.retriever = retriever
|
| 460 |
except Exception as exc:
|
| 461 |
raise HTTPException(status_code=503,
|
| 462 |
detail=f"Retriever unavailable: {exc}") from exc
|
| 463 |
|
| 464 |
try:
|
| 465 |
+
retriever._load_model()
|
| 466 |
+
if retriever._model:
|
| 467 |
+
q_vec = retriever._model.encode(
|
| 468 |
+
[question_to_use.strip()],
|
| 469 |
+
normalize_embeddings=True,
|
| 470 |
+
convert_to_numpy=True,
|
| 471 |
+
)[0].astype(np.float32)
|
| 472 |
+
|
| 473 |
+
cache_hit = semantic_cache.get(
|
| 474 |
+
query_emb=q_vec,
|
| 475 |
+
patient_allergies=req.patient_allergies or [],
|
| 476 |
+
department=req.department or "default",
|
| 477 |
+
overrides=llm_overrides
|
| 478 |
+
)
|
| 479 |
+
if cache_hit:
|
| 480 |
+
logger.info("SEMANTIC CACHE HIT: Returning safe cached response instantly.")
|
| 481 |
+
retrieved_chunks = [
|
| 482 |
+
RetrievedChunk(
|
| 483 |
+
chunk_id=c.get("chunk_id"),
|
| 484 |
+
text=c.get("text"),
|
| 485 |
+
source=c.get("source", ""),
|
| 486 |
+
pub_type=c.get("pub_type", ""),
|
| 487 |
+
pub_year=c.get("pub_year"),
|
| 488 |
+
title=c.get("title", ""),
|
| 489 |
+
similarity_score=c.get("similarity_score", 0.0)
|
| 490 |
+
) for c in cache_hit.get("retrieved_chunks", [])
|
| 491 |
+
]
|
| 492 |
+
mr_dict = cache_hit.get("module_results", {})
|
| 493 |
+
return QueryResponse(
|
| 494 |
+
question=cache_hit.get("question", question_to_use),
|
| 495 |
+
generated_answer=cache_hit.get("generated_answer"),
|
| 496 |
+
retrieved_chunks=retrieved_chunks,
|
| 497 |
+
composite_score=cache_hit.get("composite_score", 1.0),
|
| 498 |
+
hrs=cache_hit.get("hrs", 0),
|
| 499 |
+
confidence_level=cache_hit.get("confidence_level", "UNKNOWN"),
|
| 500 |
+
risk_band=cache_hit.get("risk_band", "UNKNOWN"),
|
| 501 |
+
module_results=ModuleResults(
|
| 502 |
+
faithfulness=_module_score(mr_dict, "faithfulness"),
|
| 503 |
+
entity_verifier=_module_score(mr_dict, "entity_verifier"),
|
| 504 |
+
source_credibility=_module_score(mr_dict, "source_credibility"),
|
| 505 |
+
contradiction=_module_score(mr_dict, "contradiction"),
|
| 506 |
+
ragas=_module_score(mr_dict, "ragas"),
|
| 507 |
+
),
|
| 508 |
+
total_pipeline_ms=0,
|
| 509 |
+
intervention_applied=cache_hit.get("intervention_applied", False),
|
| 510 |
+
intervention_reason=cache_hit.get("intervention_reason"),
|
| 511 |
+
original_answer=cache_hit.get("original_answer"),
|
| 512 |
+
intervention_details=cache_hit.get("intervention_details"),
|
| 513 |
+
consensus_results=cache_hit.get("consensus_results"),
|
| 514 |
+
privacy_applied=cache_hit.get("privacy_applied", False),
|
| 515 |
+
privacy_details=cache_hit.get("privacy_details"),
|
| 516 |
+
original_hinglish_query=cache_hit.get("original_hinglish_query"),
|
| 517 |
+
coverage_gap=cache_hit.get("coverage_gap", False),
|
| 518 |
+
coverage_gap_details=cache_hit.get("coverage_gap_details"),
|
| 519 |
+
)
|
| 520 |
+
except Exception as exc:
|
| 521 |
+
logger.error("Failed semantic cache retrieval lookup: %s", exc)
|
| 522 |
+
|
| 523 |
+
# Step 1: Retrieve
|
| 524 |
+
try:
|
| 525 |
+
raw_results = retriever.search(question_to_use, top_k=req.top_k)
|
| 526 |
except FileNotFoundError as exc:
|
| 527 |
raise HTTPException(status_code=503,
|
| 528 |
detail=f"FAISS index not found: {exc}") from exc
|
|
|
|
| 565 |
top_faiss_cosine = (
|
| 566 |
raw_results[0][1].get("_top_faiss_cosine", 0.0) if raw_results else 0.0
|
| 567 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 568 |
# =========================================================================
|
| 569 |
# Step 2a: PRIVACY SHIELD — MediRAG redacts PHI (Option 1)
|
| 570 |
# =========================================================================
|
| 571 |
p_mapping = {}
|
| 572 |
privacy_applied = False
|
| 573 |
+
question_to_gen = question_to_use
|
| 574 |
|
| 575 |
if req.use_privacy_shield:
|
| 576 |
from src.pipeline.privacy import shield
|
| 577 |
+
question_to_gen, p_mapping = shield.redact(question_to_use)
|
| 578 |
if p_mapping:
|
| 579 |
privacy_applied = True
|
| 580 |
logger.info("PRIVACY INTERVENTION: Redacted %d items from question.", len(p_mapping))
|
|
|
|
| 609 |
providers.append("ollama") # fallback to local if no second key
|
| 610 |
|
| 611 |
logger.info("Running Consensus Layer with %s", providers)
|
| 612 |
+
consensus_results = run_consensus_check(question_to_use, context_chunks, _cfg, providers=providers)
|
| 613 |
|
| 614 |
# If consensus finds a safer merged answer, we promote it
|
| 615 |
# and update the primary answer for the evaluation loop
|
|
|
|
| 625 |
# Step 3: Evaluate
|
| 626 |
try:
|
| 627 |
eval_result = run_evaluation(
|
| 628 |
+
question=question_to_use,
|
| 629 |
answer=answer,
|
| 630 |
context_chunks=context_chunks,
|
| 631 |
run_ragas=req.run_ragas,
|
|
|
|
| 634 |
except Exception as exc:
|
| 635 |
logger.exception("Evaluation failed: %s", exc)
|
| 636 |
try:
|
| 637 |
+
log_audit("query", question_to_use, answer, 100, "EVAL_ERROR", 0.0,
|
| 638 |
int((_time.perf_counter() - t_total) * 1000),
|
| 639 |
False, {"error": str(exc), "error_type": "evaluation_failure"})
|
| 640 |
except Exception:
|
|
|
|
| 658 |
original_answer = None
|
| 659 |
intervention_details = None
|
| 660 |
|
| 661 |
+
# Dynamic Department-specific Safety Policies & Patient Allergy Gates
|
| 662 |
+
hrs_block_threshold = 86
|
| 663 |
+
hrs_retry_threshold = 61
|
| 664 |
+
|
| 665 |
+
if req.department:
|
| 666 |
+
dept_lower = req.department.lower()
|
| 667 |
+
if "pediatric" in dept_lower:
|
| 668 |
+
hrs_block_threshold = 50
|
| 669 |
+
hrs_retry_threshold = 20
|
| 670 |
+
logger.info("DEPARTMENT SAFETY TUNING (Pediatrics): Block >= 50, Regenerate >= 20")
|
| 671 |
+
elif "oncology" in dept_lower:
|
| 672 |
+
hrs_block_threshold = 55
|
| 673 |
+
hrs_retry_threshold = 25
|
| 674 |
+
logger.info("DEPARTMENT SAFETY TUNING (Oncology): Block >= 55, Regenerate >= 25")
|
| 675 |
+
elif "cardiology" in dept_lower:
|
| 676 |
+
hrs_block_threshold = 60
|
| 677 |
+
hrs_retry_threshold = 30
|
| 678 |
+
logger.info("DEPARTMENT SAFETY TUNING (Cardiology): Block >= 60, Regenerate >= 30")
|
| 679 |
+
elif "emergency" in dept_lower or "er" in dept_lower:
|
| 680 |
+
hrs_block_threshold = 70
|
| 681 |
+
hrs_retry_threshold = 50
|
| 682 |
+
logger.info("DEPARTMENT SAFETY TUNING (ER): Block >= 70, Regenerate >= 50")
|
| 683 |
+
elif "opd" in dept_lower:
|
| 684 |
+
hrs_block_threshold = 80
|
| 685 |
+
hrs_retry_threshold = 60
|
| 686 |
+
logger.info("DEPARTMENT SAFETY TUNING (OPD): Block >= 80, Regenerate >= 60")
|
| 687 |
+
|
| 688 |
+
# Custom Admin Console Override
|
| 689 |
+
if req.custom_hrs_limit is not None:
|
| 690 |
+
hrs_retry_threshold = req.custom_hrs_limit
|
| 691 |
+
hrs_block_threshold = min(95, req.custom_hrs_limit + (30 if req.custom_hrs_limit <= 30 else 20))
|
| 692 |
+
logger.info(f"HOSPITAL CONSOLE OVERRIDE: Block >= {hrs_block_threshold}, Regenerate >= {hrs_retry_threshold}")
|
| 693 |
+
|
| 694 |
+
# Patient Allergy Safety Interception
|
| 695 |
+
allergy_intercepted = False
|
| 696 |
+
allergen_matched = None
|
| 697 |
+
if req.patient_allergies:
|
| 698 |
+
text_to_scan = (question_to_use + " " + answer).lower()
|
| 699 |
+
for allergen in req.patient_allergies:
|
| 700 |
+
if allergen.strip().lower() in text_to_scan:
|
| 701 |
+
allergy_intercepted = True
|
| 702 |
+
allergen_matched = allergen.strip().capitalize()
|
| 703 |
+
break
|
| 704 |
+
|
| 705 |
faith_score = (mod_results.get("faithfulness") or {}).get("score", 1.0)
|
| 706 |
|
| 707 |
# Source-credibility-aware faith threshold: high-credibility sources get more tolerance
|
|
|
|
| 734 |
# FDA direct lookup can still retrieve the right data even when initial FAISS
|
| 735 |
# retrieval missed it. Don't label those as coverage gaps — let intervention run.
|
| 736 |
_ev_entities = (mod_results.get("entity_verifier") or {}).get("details", {}).get("entities", [])
|
| 737 |
+
_q_lower_cg = question_to_use.lower()
|
| 738 |
_drug_in_question = any(
|
| 739 |
e.get("rxcui") and e.get("entity", "").lower() in _q_lower_cg
|
| 740 |
for e in _ev_entities
|
|
|
|
| 770 |
is_refusal_answer, top_faiss_cosine, faith_score,
|
| 771 |
)
|
| 772 |
|
| 773 |
+
# Tier 1: CRITICAL BLOCK (HRS ≥ hrs_block_threshold) — response is too dangerous to show
|
| 774 |
# Coverage gap: skip both tiers — regenerating from an empty DB won't help
|
| 775 |
+
if allergy_intercepted:
|
| 776 |
+
original_answer = answer
|
| 777 |
+
answer = (
|
| 778 |
+
"⛔ PATIENT SAFETY SHIELD — ALLERGY CONTRADICTION BLOCKED\n\n"
|
| 779 |
+
f"Prescribing or recommending {allergen_matched} is STRICTLY CONTRAINDICATED "
|
| 780 |
+
f"because this patient is flagged as severely ALLERGIC to: {allergen_matched}.\n\n"
|
| 781 |
+
"Immediate Action: Cancel drug order and consult guidelines for safe alternative therapies (e.g. Paracetamol instead of NSAIDs)."
|
| 782 |
+
)
|
| 783 |
+
hrs = 100
|
| 784 |
+
intervention_applied = True
|
| 785 |
+
intervention_reason = "CRITICAL_ALLERGY_BLOCKED"
|
| 786 |
+
intervention_details = {
|
| 787 |
+
"hrs_original": 100,
|
| 788 |
+
"message": f"Response blocked: Patient has an active chart allergy to {allergen_matched}.",
|
| 789 |
+
}
|
| 790 |
+
logger.warning("INTERVENTION: CRITICAL_ALLERGY_BLOCKED — allergen=%s", allergen_matched)
|
| 791 |
+
elif coverage_gap:
|
| 792 |
logger.info("COVERAGE_GAP — skipping intervention (regeneration cannot add missing data).")
|
| 793 |
+
elif hrs >= hrs_block_threshold:
|
| 794 |
original_answer = answer
|
| 795 |
answer = (
|
| 796 |
"⛔ UNSAFE RESPONSE BLOCKED by MediRAG Safety Gate.\n\n"
|
| 797 |
"The generated answer was flagged as CRITICAL risk "
|
| 798 |
+
f"(Health Risk Score: {hrs}/100, Ward Limit: {hrs_block_threshold}%).\n\n"
|
| 799 |
+
"It showed signs of clinical hallucination or contradiction with the retrieved evidence. "
|
| 800 |
"Please consult a qualified medical professional or rephrase your question."
|
| 801 |
)
|
| 802 |
intervention_applied = True
|
|
|
|
| 804 |
intervention_details = {
|
| 805 |
"hrs_original": hrs,
|
| 806 |
"faithfulness": faith_score,
|
| 807 |
+
"message": f"Response blocked: HRS >= {hrs_block_threshold} (Ward limit exceeded).",
|
| 808 |
}
|
| 809 |
+
logger.warning("INTERVENTION: CRITICAL_BLOCKED — HRS=%d (limit=%d)", hrs, hrs_block_threshold)
|
| 810 |
|
| 811 |
# Tier 2: HIGH RISK REGENERATION
|
| 812 |
+
elif hrs >= hrs_retry_threshold or faith_score < faith_threshold:
|
| 813 |
original_answer = answer
|
| 814 |
original_hrs = hrs
|
| 815 |
logger.warning(
|
|
|
|
| 831 |
e["entity"] for e in ev_details.get("entities", [])
|
| 832 |
if e.get("status") == "VERIFIED" and e.get("rxcui")
|
| 833 |
]
|
| 834 |
+
q_lower = question_to_use.lower()
|
| 835 |
for drug in verified_drugs:
|
| 836 |
if drug.lower() in q_lower:
|
| 837 |
fda_direct += app.state.retriever.get_fda_chunks(drug)
|
|
|
|
| 848 |
guideline_direct: list[dict] = []
|
| 849 |
if top_faiss_cosine < 0.85:
|
| 850 |
try:
|
| 851 |
+
guideline_direct = app.state.retriever.get_guideline_chunks(question_to_use)
|
| 852 |
if guideline_direct:
|
| 853 |
logger.info("Direct guideline lookup found %d chunks", len(guideline_direct))
|
| 854 |
except Exception as gl_exc:
|
|
|
|
| 862 |
# For drug/clinical questions, expand query toward authoritative sources
|
| 863 |
_drug_terms = ("contraindication", "dosage", "dose", "interaction",
|
| 864 |
"warning", "adverse", "side effect", "mechanism")
|
| 865 |
+
_q_lower = question_to_use.lower()
|
| 866 |
retry_query = (
|
| 867 |
+
f"FDA drug label clinical guideline {question_to_use}"
|
| 868 |
if any(t in _q_lower for t in _drug_terms)
|
| 869 |
+
else question_to_use
|
| 870 |
)
|
| 871 |
fresh_results = app.state.retriever.search(retry_query, top_k=req.top_k)
|
| 872 |
fresh_chunks: list[dict] = []
|
|
|
|
| 885 |
except Exception:
|
| 886 |
retry_chunks = context_chunks
|
| 887 |
|
| 888 |
+
answer = generate_strict_answer(question_to_use, retry_chunks, _cfg, overrides=llm_overrides)
|
| 889 |
# Re-evaluate the corrected answer
|
| 890 |
eval_result = run_evaluation(
|
| 891 |
+
question=question_to_use,
|
| 892 |
answer=answer,
|
| 893 |
context_chunks=retry_chunks,
|
| 894 |
run_ragas=False, # skip RAGAS on retry to reduce latency
|
|
|
|
| 920 |
logger.info("POST /query → HRS=%d (%s) intervention=%s in %d ms total",
|
| 921 |
hrs, details.get("risk_band", "?"), intervention_reason or "none", total_ms)
|
| 922 |
|
| 923 |
+
log_audit("query", question_to_use, answer, hrs, details.get("risk_band", "UNKNOWN"), composite, total_ms, intervention_applied, {
|
| 924 |
"module_results": mod_results,
|
| 925 |
"confidence_level": details.get("confidence_level", "UNKNOWN"),
|
| 926 |
"intervention_reason": intervention_reason,
|
| 927 |
"original_answer": original_answer,
|
| 928 |
})
|
| 929 |
|
| 930 |
+
# Save successful evaluation to Safe Semantic Cache
|
| 931 |
+
if q_vec is not None and not coverage_gap:
|
| 932 |
+
try:
|
| 933 |
+
response_dict = {
|
| 934 |
+
"question": question_to_use,
|
| 935 |
+
"generated_answer": answer,
|
| 936 |
+
"retrieved_chunks": [
|
| 937 |
+
{
|
| 938 |
+
"chunk_id": c.chunk_id,
|
| 939 |
+
"text": c.text,
|
| 940 |
+
"source": c.source,
|
| 941 |
+
"pub_type": c.pub_type,
|
| 942 |
+
"pub_year": c.pub_year,
|
| 943 |
+
"title": c.title,
|
| 944 |
+
"similarity_score": c.similarity_score
|
| 945 |
+
} for c in retrieved_chunks_out
|
| 946 |
+
],
|
| 947 |
+
"composite_score": composite,
|
| 948 |
+
"hrs": hrs,
|
| 949 |
+
"confidence_level": details.get("confidence_level", "UNKNOWN"),
|
| 950 |
+
"risk_band": details.get("risk_band", "UNKNOWN"),
|
| 951 |
+
"module_results": mod_results,
|
| 952 |
+
"intervention_applied": intervention_applied,
|
| 953 |
+
"intervention_reason": intervention_reason,
|
| 954 |
+
"original_answer": original_answer,
|
| 955 |
+
"intervention_details": intervention_details,
|
| 956 |
+
"consensus_results": consensus_results,
|
| 957 |
+
"privacy_applied": privacy_applied,
|
| 958 |
+
"privacy_details": {"redacted_count": len(p_mapping)} if privacy_applied else None,
|
| 959 |
+
"original_hinglish_query": original_hinglish,
|
| 960 |
+
"coverage_gap": coverage_gap,
|
| 961 |
+
"coverage_gap_details": coverage_gap_details,
|
| 962 |
+
}
|
| 963 |
+
semantic_cache.store(
|
| 964 |
+
query_text=question_to_use,
|
| 965 |
+
query_emb=q_vec,
|
| 966 |
+
response=response_dict,
|
| 967 |
+
patient_allergies=req.patient_allergies or [],
|
| 968 |
+
department=req.department or "default",
|
| 969 |
+
overrides=llm_overrides
|
| 970 |
+
)
|
| 971 |
+
except Exception as exc:
|
| 972 |
+
logger.error("Failed to store in semantic cache: %s", exc)
|
| 973 |
+
|
| 974 |
return QueryResponse(
|
| 975 |
+
question=question_to_use,
|
| 976 |
generated_answer=answer,
|
| 977 |
retrieved_chunks=retrieved_chunks_out,
|
| 978 |
composite_score=composite,
|
|
|
|
| 994 |
consensus_results=consensus_results,
|
| 995 |
privacy_applied=privacy_applied,
|
| 996 |
privacy_details={"redacted_count": len(p_mapping)} if privacy_applied else None,
|
| 997 |
+
original_hinglish_query=original_hinglish,
|
| 998 |
coverage_gap=coverage_gap,
|
| 999 |
coverage_gap_details=coverage_gap_details,
|
| 1000 |
)
|
src/api/schemas.py
CHANGED
|
@@ -206,6 +206,14 @@ class QueryRequest(BaseModel):
|
|
| 206 |
default=False,
|
| 207 |
description="Automatically redact PHI/PII (names, IDs) before external API calls.",
|
| 208 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 209 |
system_prompt: Optional[str] = Field(
|
| 210 |
default=None,
|
| 211 |
description="Custom system prompt to override the default clinical persona."
|
|
@@ -214,6 +222,22 @@ class QueryRequest(BaseModel):
|
|
| 214 |
default="physician",
|
| 215 |
description="The target audience for the response: 'physician' or 'patient'."
|
| 216 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 217 |
|
| 218 |
|
| 219 |
class RetrievedChunk(BaseModel):
|
|
@@ -264,6 +288,7 @@ class QueryResponse(BaseModel):
|
|
| 264 |
# Privacy Shield fields
|
| 265 |
privacy_applied: bool = Field(default=False)
|
| 266 |
privacy_details: Optional[Dict[str, Any]] = Field(default=None)
|
|
|
|
| 267 |
# Coverage gap gate — distinguishes missing DB coverage from hallucination
|
| 268 |
coverage_gap: bool = Field(
|
| 269 |
default=False,
|
|
|
|
| 206 |
default=False,
|
| 207 |
description="Automatically redact PHI/PII (names, IDs) before external API calls.",
|
| 208 |
)
|
| 209 |
+
translate_hinglish: bool = Field(
|
| 210 |
+
default=False,
|
| 211 |
+
description="Translate query from Hinglish to English before processing",
|
| 212 |
+
)
|
| 213 |
+
original_hinglish_query: Optional[str] = Field(
|
| 214 |
+
default=None,
|
| 215 |
+
description="[OPTIONAL] Pre-translated original Hinglish text from the front-end audit gate."
|
| 216 |
+
)
|
| 217 |
system_prompt: Optional[str] = Field(
|
| 218 |
default=None,
|
| 219 |
description="Custom system prompt to override the default clinical persona."
|
|
|
|
| 222 |
default="physician",
|
| 223 |
description="The target audience for the response: 'physician' or 'patient'."
|
| 224 |
)
|
| 225 |
+
department: Optional[str] = Field(
|
| 226 |
+
default=None,
|
| 227 |
+
description="[OPTIONAL] Active department (oncology, cardiology, pediatrics, opd, etc.) to trigger custom safety thresholds."
|
| 228 |
+
)
|
| 229 |
+
patient_allergies: Optional[list[str]] = Field(
|
| 230 |
+
default=None,
|
| 231 |
+
description="[OPTIONAL] List of patient drug allergies to scan against recommended medications."
|
| 232 |
+
)
|
| 233 |
+
custom_hrs_limit: Optional[int] = Field(
|
| 234 |
+
default=None,
|
| 235 |
+
description="[OPTIONAL] Custom HRS risk tolerance percentage (0-100) set by the hospital console."
|
| 236 |
+
)
|
| 237 |
+
custom_latency_limit: Optional[int] = Field(
|
| 238 |
+
default=None,
|
| 239 |
+
description="[OPTIONAL] Custom max allowed latency in ms."
|
| 240 |
+
)
|
| 241 |
|
| 242 |
|
| 243 |
class RetrievedChunk(BaseModel):
|
|
|
|
| 288 |
# Privacy Shield fields
|
| 289 |
privacy_applied: bool = Field(default=False)
|
| 290 |
privacy_details: Optional[Dict[str, Any]] = Field(default=None)
|
| 291 |
+
original_hinglish_query: Optional[str] = Field(default=None, description="The original Hinglish query before translation")
|
| 292 |
# Coverage gap gate — distinguishes missing DB coverage from hallucination
|
| 293 |
coverage_gap: bool = Field(
|
| 294 |
default=False,
|
src/modules/entity_verifier.py
CHANGED
|
@@ -172,6 +172,56 @@ def _lookup_rxnorm_api(drug_name: str, timeout: int = 4) -> Optional[str]:
|
|
| 172 |
return None
|
| 173 |
|
| 174 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 175 |
# ---------------------------------------------------------------------------
|
| 176 |
# Public API
|
| 177 |
# ---------------------------------------------------------------------------
|
|
@@ -306,12 +356,29 @@ def verify_entities(
|
|
| 306 |
|
| 307 |
entity_results.append(result)
|
| 308 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 309 |
# --- Score ---------------------------------------------------------------
|
| 310 |
# Score is based on drug entities only (per SRS Section 6.2)
|
| 311 |
if drug_total == 0:
|
| 312 |
score = 0.5 # neutral — no drug entities to verify
|
| 313 |
else:
|
| 314 |
-
score
|
|
|
|
|
|
|
|
|
|
|
|
|
| 315 |
|
| 316 |
details = {
|
| 317 |
"total_entities": len(raw_entities),
|
|
@@ -319,12 +386,13 @@ def verify_entities(
|
|
| 319 |
"verified_count": drug_verified,
|
| 320 |
"flagged_count": drug_flagged,
|
| 321 |
"entities": entity_results,
|
|
|
|
| 322 |
}
|
| 323 |
|
| 324 |
latency_ms = int((time.perf_counter() - t0) * 1000)
|
| 325 |
logger.info(
|
| 326 |
-
"Entity verification: %.3f (%d/%d drugs verified) in %d ms",
|
| 327 |
-
score, drug_verified, drug_total, latency_ms,
|
| 328 |
)
|
| 329 |
return EvalResult(
|
| 330 |
module_name="entity_verifier",
|
|
|
|
| 172 |
return None
|
| 173 |
|
| 174 |
|
| 175 |
+
@lru_cache(maxsize=1024)
|
| 176 |
+
def _cached_drug_interactions(rxcuis_tuple: tuple[str, ...], timeout: int) -> list[dict]:
|
| 177 |
+
"""
|
| 178 |
+
Synchronous cached NIH REST request to resolve drug interactions.
|
| 179 |
+
"""
|
| 180 |
+
rxcuis_str = "+".join(rxcuis_tuple)
|
| 181 |
+
url = f"https://rxnav.nlm.nih.gov/REST/interaction/list.json?rxcuis={rxcuis_str}"
|
| 182 |
+
try:
|
| 183 |
+
resp = requests.get(url, timeout=timeout)
|
| 184 |
+
if resp.status_code != 200:
|
| 185 |
+
return []
|
| 186 |
+
|
| 187 |
+
data = resp.json()
|
| 188 |
+
interactions = []
|
| 189 |
+
|
| 190 |
+
# ONCHigh returns fullInteractionTypeGroup
|
| 191 |
+
groups = data.get("fullInteractionTypeGroup", [])
|
| 192 |
+
for group in groups:
|
| 193 |
+
for fit in group.get("fullInteractionType", []):
|
| 194 |
+
for pair in fit.get("interactionPair", []):
|
| 195 |
+
# Extract concepts
|
| 196 |
+
concepts = pair.get("interactionConcept", [])
|
| 197 |
+
drugs_involved = [c.get("minConcept", {}).get("name", "Unknown") for c in concepts]
|
| 198 |
+
severity = pair.get("severity", "high").lower() # default to high since it's from ONCHigh
|
| 199 |
+
description = pair.get("description", "")
|
| 200 |
+
interactions.append({
|
| 201 |
+
"drugs": drugs_involved,
|
| 202 |
+
"severity": severity,
|
| 203 |
+
"description": description
|
| 204 |
+
})
|
| 205 |
+
return interactions
|
| 206 |
+
except Exception as e:
|
| 207 |
+
logger.error("Failed to fetch drug interactions from RxNav: %s", e)
|
| 208 |
+
return []
|
| 209 |
+
|
| 210 |
+
|
| 211 |
+
def check_drug_interactions(rxcuis: list[str], timeout: int = 5) -> list[dict]:
|
| 212 |
+
"""
|
| 213 |
+
Query RxNav API for drug-drug interactions between a list of RxCUIs.
|
| 214 |
+
Uses sorted tuple transformation to enable efficient order-independent caching.
|
| 215 |
+
"""
|
| 216 |
+
if len(rxcuis) < 2:
|
| 217 |
+
return []
|
| 218 |
+
|
| 219 |
+
# Sort RxCUIs to ensure cache consistency regardless of list order
|
| 220 |
+
rxcuis_tuple = tuple(sorted(rxcuis))
|
| 221 |
+
return _cached_drug_interactions(rxcuis_tuple, timeout)
|
| 222 |
+
|
| 223 |
+
|
| 224 |
+
|
| 225 |
# ---------------------------------------------------------------------------
|
| 226 |
# Public API
|
| 227 |
# ---------------------------------------------------------------------------
|
|
|
|
| 356 |
|
| 357 |
entity_results.append(result)
|
| 358 |
|
| 359 |
+
# --- Drug-Drug Interaction Check (DDI) -----------------------------------
|
| 360 |
+
# Gather standard RxCUIs for the verified drugs
|
| 361 |
+
rxcuis = [ent["rxcui"] for ent in entity_results if ent.get("rxcui")]
|
| 362 |
+
unique_rxcuis = list(set(rxcuis))
|
| 363 |
+
interactions = []
|
| 364 |
+
|
| 365 |
+
if len(unique_rxcuis) >= 2:
|
| 366 |
+
logger.info("Multiple drugs detected in answer (%s) — checking for interactions...", unique_rxcuis)
|
| 367 |
+
interactions = check_drug_interactions(unique_rxcuis)
|
| 368 |
+
if interactions:
|
| 369 |
+
logger.warning("DDI Check: Found %d drug interactions!", len(interactions))
|
| 370 |
+
drug_flagged += len(interactions)
|
| 371 |
+
|
| 372 |
# --- Score ---------------------------------------------------------------
|
| 373 |
# Score is based on drug entities only (per SRS Section 6.2)
|
| 374 |
if drug_total == 0:
|
| 375 |
score = 0.5 # neutral — no drug entities to verify
|
| 376 |
else:
|
| 377 |
+
# Base score is drug_verified / drug_total
|
| 378 |
+
base_score = drug_verified / drug_total
|
| 379 |
+
# Deduct score for multi-drug interactions (0.2 deduction per interaction, cap at 0.0)
|
| 380 |
+
interaction_deduction = len(interactions) * 0.20
|
| 381 |
+
score = max(0.0, base_score - interaction_deduction)
|
| 382 |
|
| 383 |
details = {
|
| 384 |
"total_entities": len(raw_entities),
|
|
|
|
| 386 |
"verified_count": drug_verified,
|
| 387 |
"flagged_count": drug_flagged,
|
| 388 |
"entities": entity_results,
|
| 389 |
+
"interactions": interactions,
|
| 390 |
}
|
| 391 |
|
| 392 |
latency_ms = int((time.perf_counter() - t0) * 1000)
|
| 393 |
logger.info(
|
| 394 |
+
"Entity verification: %.3f (%d/%d drugs verified, %d DDI found) in %d ms",
|
| 395 |
+
score, drug_verified, drug_total, len(interactions), latency_ms,
|
| 396 |
)
|
| 397 |
return EvalResult(
|
| 398 |
module_name="entity_verifier",
|
src/pipeline/generator.py
CHANGED
|
@@ -50,7 +50,8 @@ _load_env()
|
|
| 50 |
|
| 51 |
def _load_config() -> dict:
|
| 52 |
try:
|
| 53 |
-
|
|
|
|
| 54 |
except Exception:
|
| 55 |
return {}
|
| 56 |
|
|
@@ -582,3 +583,78 @@ def generate_strict_answer(
|
|
| 582 |
return _generate_groq(prompt, effective_config)
|
| 583 |
else:
|
| 584 |
raise RuntimeError(f"Unknown LLM provider '{provider}'.")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 50 |
|
| 51 |
def _load_config() -> dict:
|
| 52 |
try:
|
| 53 |
+
config_path = os.environ.get("MEDIRAG_CONFIG", "config_local.yaml" if Path("config_local.yaml").exists() else "config.yaml")
|
| 54 |
+
return yaml.safe_load(Path(config_path).read_text())
|
| 55 |
except Exception:
|
| 56 |
return {}
|
| 57 |
|
|
|
|
| 583 |
return _generate_groq(prompt, effective_config)
|
| 584 |
else:
|
| 585 |
raise RuntimeError(f"Unknown LLM provider '{provider}'.")
|
| 586 |
+
|
| 587 |
+
|
| 588 |
+
def generate_simple_prompt(
|
| 589 |
+
prompt: str,
|
| 590 |
+
config: Optional[dict] = None,
|
| 591 |
+
overrides: Optional[dict] = None,
|
| 592 |
+
) -> str:
|
| 593 |
+
"""Execute a simple prompt on the active LLM provider without context formatting."""
|
| 594 |
+
if config is None:
|
| 595 |
+
config = _load_config()
|
| 596 |
+
|
| 597 |
+
effective_llm = dict(config.get("llm", {}))
|
| 598 |
+
if overrides:
|
| 599 |
+
if overrides.get("provider"):
|
| 600 |
+
effective_llm["provider"] = overrides["provider"]
|
| 601 |
+
if overrides.get("api_key"):
|
| 602 |
+
pk = (overrides.get("provider") or "gemini").lower()
|
| 603 |
+
key_map = {
|
| 604 |
+
"gemini": "gemini_api_key",
|
| 605 |
+
"openai": "openai_api_key",
|
| 606 |
+
"mistral": "mistral_api_key",
|
| 607 |
+
"groq": "groq_api_key",
|
| 608 |
+
}
|
| 609 |
+
effective_llm[key_map.get(pk, "gemini_api_key")] = overrides["api_key"]
|
| 610 |
+
if overrides.get("model"):
|
| 611 |
+
pk = (overrides.get("provider") or "gemini").lower()
|
| 612 |
+
model_map = {
|
| 613 |
+
"gemini": "gemini_model",
|
| 614 |
+
"openai": "openai_model",
|
| 615 |
+
"mistral": "model",
|
| 616 |
+
"groq": "groq_model",
|
| 617 |
+
}
|
| 618 |
+
effective_llm[model_map.get(pk, "gemini_model")] = overrides["model"]
|
| 619 |
+
if overrides.get("ollama_url"):
|
| 620 |
+
effective_llm["base_url"] = overrides["ollama_url"]
|
| 621 |
+
|
| 622 |
+
effective_config = {**config, "llm": effective_llm}
|
| 623 |
+
provider = effective_llm.get("provider", "gemini").lower()
|
| 624 |
+
|
| 625 |
+
if provider == "gemini":
|
| 626 |
+
return _generate_gemini(prompt, effective_config)
|
| 627 |
+
elif provider == "openai":
|
| 628 |
+
return _generate_openai(prompt, effective_config)
|
| 629 |
+
elif provider == "ollama":
|
| 630 |
+
return _generate_ollama(prompt, effective_config)
|
| 631 |
+
elif provider == "mistral":
|
| 632 |
+
return _generate_mistral(prompt, effective_config)
|
| 633 |
+
elif provider == "groq":
|
| 634 |
+
return _generate_groq(prompt, effective_config)
|
| 635 |
+
else:
|
| 636 |
+
raise RuntimeError(f"Unknown LLM provider '{provider}'.")
|
| 637 |
+
|
| 638 |
+
|
| 639 |
+
def translate_hinglish_to_english(
|
| 640 |
+
question: str,
|
| 641 |
+
config: Optional[dict] = None,
|
| 642 |
+
overrides: Optional[dict] = None,
|
| 643 |
+
) -> str:
|
| 644 |
+
"""Translate clinical query from Hinglish or standard Hindi to professional English."""
|
| 645 |
+
prompt = (
|
| 646 |
+
"You are an expert bilingual clinical query translator. You will receive a medical question "
|
| 647 |
+
"written in Hinglish (a mixture of Hindi and English written in the Latin alphabet) or standard Hindi. "
|
| 648 |
+
"Convert the Hinglish/Hindi question into a clear, professional, grammatically correct English clinical query. "
|
| 649 |
+
"If the input query is already completely in English, return it exactly as it is with no edits. "
|
| 650 |
+
"Do NOT add any conversational preamble, greetings, explanation, or formatting. Only return the translated English query.\n\n"
|
| 651 |
+
f"Query: {question}\n"
|
| 652 |
+
"English Translation:"
|
| 653 |
+
)
|
| 654 |
+
try:
|
| 655 |
+
translated = generate_simple_prompt(prompt, config=config, overrides=overrides)
|
| 656 |
+
return translated.strip().strip('"').strip("'")
|
| 657 |
+
except Exception as exc:
|
| 658 |
+
logger.warning("Hinglish translation failed: %s. Using original query.", exc)
|
| 659 |
+
return question
|
| 660 |
+
|
src/pipeline/retriever.py
CHANGED
|
@@ -433,7 +433,9 @@ class Retriever:
|
|
| 433 |
# ---------------------------------------------------------------------------
|
| 434 |
|
| 435 |
def _load_config() -> dict:
|
| 436 |
-
|
|
|
|
|
|
|
| 437 |
return yaml.safe_load(f)
|
| 438 |
|
| 439 |
|
|
|
|
| 433 |
# ---------------------------------------------------------------------------
|
| 434 |
|
| 435 |
def _load_config() -> dict:
|
| 436 |
+
import os
|
| 437 |
+
config_path = os.environ.get("MEDIRAG_CONFIG", "config_local.yaml" if os.path.exists("config_local.yaml") else "config.yaml")
|
| 438 |
+
with open(config_path, "r", encoding="utf-8") as f:
|
| 439 |
return yaml.safe_load(f)
|
| 440 |
|
| 441 |
|
src/pipeline/semantic_cache.py
ADDED
|
@@ -0,0 +1,112 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import sqlite3
|
| 2 |
+
import json
|
| 3 |
+
import hashlib
|
| 4 |
+
import numpy as np
|
| 5 |
+
import logging
|
| 6 |
+
from pathlib import Path
|
| 7 |
+
|
| 8 |
+
logger = logging.getLogger(__name__)
|
| 9 |
+
|
| 10 |
+
class SafeSemanticCache:
|
| 11 |
+
def __init__(self, db_path="data/cache.db", threshold=0.97):
|
| 12 |
+
self.db_path = db_path
|
| 13 |
+
self.threshold = threshold
|
| 14 |
+
self._init_db()
|
| 15 |
+
|
| 16 |
+
def _init_db(self):
|
| 17 |
+
# Ensure containing directory exists
|
| 18 |
+
Path(self.db_path).parent.mkdir(parents=True, exist_ok=True)
|
| 19 |
+
conn = sqlite3.connect(self.db_path)
|
| 20 |
+
conn.execute("""
|
| 21 |
+
CREATE TABLE IF NOT EXISTS semantic_cache (
|
| 22 |
+
id INTEGER PRIMARY KEY,
|
| 23 |
+
query_text TEXT,
|
| 24 |
+
embedding BLOB,
|
| 25 |
+
cache_hash TEXT UNIQUE,
|
| 26 |
+
response_json TEXT,
|
| 27 |
+
timestamp DATETIME DEFAULT CURRENT_TIMESTAMP
|
| 28 |
+
)
|
| 29 |
+
""")
|
| 30 |
+
conn.commit()
|
| 31 |
+
conn.close()
|
| 32 |
+
|
| 33 |
+
def _generate_hash(self, query_emb: np.ndarray, patient_allergies: list[str], department: str, overrides: dict) -> str:
|
| 34 |
+
# Create a deterministic representation of the safety environment
|
| 35 |
+
allergies_str = ",".join(sorted([a.lower().strip() for a in patient_allergies]))
|
| 36 |
+
dept_str = department.lower().strip()
|
| 37 |
+
overrides_str = json.dumps(overrides, sort_keys=True)
|
| 38 |
+
|
| 39 |
+
# Round embedding to 4 decimals to ensure stability against float discrepancies
|
| 40 |
+
emb_str = np.round(query_emb, 4).tobytes()
|
| 41 |
+
|
| 42 |
+
hasher = hashlib.sha256()
|
| 43 |
+
hasher.update(emb_str)
|
| 44 |
+
hasher.update(allergies_str.encode('utf-8'))
|
| 45 |
+
hasher.update(dept_str.encode('utf-8'))
|
| 46 |
+
hasher.update(overrides_str.encode('utf-8'))
|
| 47 |
+
return hasher.hexdigest()
|
| 48 |
+
|
| 49 |
+
def get(self, query_emb: np.ndarray, patient_allergies: list[str], department: str, overrides: dict) -> dict | None:
|
| 50 |
+
target_hash = self._generate_hash(query_emb, patient_allergies, department, overrides)
|
| 51 |
+
|
| 52 |
+
conn = sqlite3.connect(self.db_path)
|
| 53 |
+
cursor = conn.cursor()
|
| 54 |
+
# Direct hash lookup first (O(1) fast path)
|
| 55 |
+
cursor.execute("SELECT response_json FROM semantic_cache WHERE cache_hash = ?", (target_hash,))
|
| 56 |
+
row = cursor.fetchone()
|
| 57 |
+
|
| 58 |
+
if row:
|
| 59 |
+
conn.close()
|
| 60 |
+
logger.info("Semantic Cache: Direct hash hit! Returning safe response.")
|
| 61 |
+
try:
|
| 62 |
+
return json.loads(row[0])
|
| 63 |
+
except Exception as e:
|
| 64 |
+
logger.error(f"Failed to parse cached JSON: {e}")
|
| 65 |
+
return None
|
| 66 |
+
|
| 67 |
+
# Fuzzy lookup (Cosine similarity fallback under identical safety settings)
|
| 68 |
+
cursor.execute("SELECT query_text, embedding, response_json, cache_hash FROM semantic_cache")
|
| 69 |
+
rows = cursor.fetchall()
|
| 70 |
+
conn.close()
|
| 71 |
+
|
| 72 |
+
for query, emb_bytes, response_json, cached_hash in rows:
|
| 73 |
+
saved_emb = np.frombuffer(emb_bytes, dtype=np.float32)
|
| 74 |
+
# Compute cosine similarity
|
| 75 |
+
norm_product = np.linalg.norm(query_emb) * np.linalg.norm(saved_emb)
|
| 76 |
+
if norm_product == 0:
|
| 77 |
+
continue
|
| 78 |
+
cosine = np.dot(query_emb, saved_emb) / norm_product
|
| 79 |
+
|
| 80 |
+
if cosine >= self.threshold:
|
| 81 |
+
# Re-verify that the safety hash matches (no allergy difference)
|
| 82 |
+
# To prevent cross-contamination, fuzzy match requires identical allergies/department config
|
| 83 |
+
candidate_hash = self._generate_hash(saved_emb, patient_allergies, department, overrides)
|
| 84 |
+
if candidate_hash == cached_hash:
|
| 85 |
+
logger.info(f"Semantic Cache: Fuzzy similarity hit! ({cosine:.4f})")
|
| 86 |
+
try:
|
| 87 |
+
return json.loads(response_json)
|
| 88 |
+
except Exception as e:
|
| 89 |
+
logger.error(f"Failed to parse cached JSON in fuzzy match: {e}")
|
| 90 |
+
continue
|
| 91 |
+
|
| 92 |
+
return None
|
| 93 |
+
|
| 94 |
+
def store(self, query_text: str, query_emb: np.ndarray, response: dict, patient_allergies: list[str], department: str, overrides: dict):
|
| 95 |
+
target_hash = self._generate_hash(query_emb, patient_allergies, department, overrides)
|
| 96 |
+
conn = sqlite3.connect(self.db_path)
|
| 97 |
+
try:
|
| 98 |
+
conn.execute("""
|
| 99 |
+
INSERT OR REPLACE INTO semantic_cache (query_text, embedding, cache_hash, response_json)
|
| 100 |
+
VALUES (?, ?, ?, ?)
|
| 101 |
+
""", (
|
| 102 |
+
query_text,
|
| 103 |
+
query_emb.tobytes(),
|
| 104 |
+
target_hash,
|
| 105 |
+
json.dumps(response)
|
| 106 |
+
))
|
| 107 |
+
conn.commit()
|
| 108 |
+
logger.info("Saved successful evaluation to Semantic Cache.")
|
| 109 |
+
except Exception as e:
|
| 110 |
+
logger.error(f"Failed to store in cache: {e}")
|
| 111 |
+
finally:
|
| 112 |
+
conn.close()
|
tests/test_modules.py
CHANGED
|
@@ -64,3 +64,27 @@ def test_aggregator_logic():
|
|
| 64 |
assert abs(res.score - 0.9) < 0.01
|
| 65 |
assert res.details["hrs"] == 10
|
| 66 |
assert res.details["risk_band"] == "LOW"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 64 |
assert abs(res.score - 0.9) < 0.01
|
| 65 |
assert res.details["hrs"] == 10
|
| 66 |
assert res.details["risk_band"] == "LOW"
|
| 67 |
+
|
| 68 |
+
|
| 69 |
+
def test_drug_interactions_and_entity_verifier():
|
| 70 |
+
from src.modules.entity_verifier import check_drug_interactions, verify_entities
|
| 71 |
+
|
| 72 |
+
# 1. Test DDI check directly with known interactive drugs (Warfarin: 11289, Ibuprofen: 5640)
|
| 73 |
+
interactions = check_drug_interactions(["11289", "5640"])
|
| 74 |
+
|
| 75 |
+
# Verify interactions structure is a valid list of dicts
|
| 76 |
+
assert isinstance(interactions, list)
|
| 77 |
+
if interactions:
|
| 78 |
+
assert "drugs" in interactions[0]
|
| 79 |
+
assert "severity" in interactions[0]
|
| 80 |
+
assert "description" in interactions[0]
|
| 81 |
+
|
| 82 |
+
# 2. Verify fallback & interface safety of verify_entities
|
| 83 |
+
res = verify_entities(
|
| 84 |
+
answer="Patient is taking Metformin and Lisinopril.",
|
| 85 |
+
question="What medications is the patient on?",
|
| 86 |
+
context_docs=["The patient is prescribed Metformin 500mg and Lisinopril 10mg."]
|
| 87 |
+
)
|
| 88 |
+
assert res.score is not None
|
| 89 |
+
assert isinstance(res.details, dict)
|
| 90 |
+
|