joytheslothh committed on
Commit
1bf0a27
·
verified ·
1 Parent(s): b6f9fa8

Update: backend v3.2 — privacy pipeline, consensus, new scripts

Browse files
app.py CHANGED
@@ -1,98 +1,114 @@
1
- """
2
- MediRAG Backend - FastAPI only (No Gradio)
3
- React frontend on Vercel, this is just the API backend
4
- """
5
-
6
- import os
7
- import sys
8
- import subprocess
9
- import logging
10
- import requests
11
-
12
- # Configure logging
13
- logging.basicConfig(level=logging.INFO)
14
- logger = logging.getLogger(__name__)
15
-
16
- # Set cache directories for Hugging Face
17
- os.environ["TRANSFORMERS_CACHE"] = "/tmp/transformers_cache"
18
- os.environ["HF_HOME"] = "/tmp/hf_home"
19
- os.environ["TORCH_HOME"] = "/tmp/torch_cache"
20
-
21
- # Add src to path
22
- sys.path.insert(0, os.path.join(os.path.dirname(__file__), "src"))
23
-
24
- # Install spaCy model if not present (optional — server starts without it)
25
- try:
26
- import spacy
27
- try:
28
- spacy.load("en_core_sci_lg")
29
- logger.info("spaCy model en_core_sci_lg loaded.")
30
- except OSError:
31
- # Try installing the model at runtime
32
- try:
33
- logger.info("Attempting to install scispacy model en_core_sci_lg...")
34
- subprocess.run([
35
- sys.executable, "-m", "pip", "install", "--quiet",
36
- "https://s3-us-west-2.amazonaws.com/ai2-s2-scispacy/releases/v0.5.4/en_core_sci_lg-0.5.4.tar.gz"
37
- ], check=True, timeout=300)
38
- spacy.load("en_core_sci_lg")
39
- logger.info("spaCy model installed and loaded.")
40
- except Exception as model_err:
41
- logger.warning(f"Could not install spaCy model: {model_err}. NER features will be limited.")
42
- except ImportError:
43
- logger.warning("spacy/scispacy not installed. NER features will be limited but server will still start.")
44
-
45
- # Download datasets using huggingface_hub
46
- from huggingface_hub import hf_hub_download
47
-
48
- # Check and download index and data files
49
- data_dir = os.path.join(os.path.dirname(__file__), "data")
50
- index_dir = os.path.join(data_dir, "index")
51
- os.makedirs(index_dir, exist_ok=True)
52
-
53
- faiss_path = os.path.join(index_dir, "faiss.index")
54
- metadata_path = os.path.join(index_dir, "metadata_store.pkl")
55
- bm25_path = os.path.join(index_dir, "bm25_cache.pkl")
56
- vocab_path = os.path.join(data_dir, "drugbank vocabulary.csv")
57
- rxnorm_path = os.path.join(data_dir, "rxnorm_cache.csv")
58
-
59
- def download_dataset_files():
60
- """Download FAISS index and other core data from Hugging Face Dataset"""
61
- repo_id = "joytheslothh/MediRAG-Index-Data"
62
- token = os.environ.get("HF_TOKEN")
63
- if not token:
64
- logger.warning("HF_TOKEN environment variable is not set. Dataset download might fail if repo is private.")
65
-
66
- try:
67
- if not os.path.exists(faiss_path):
68
- logger.info("Downloading faiss.index from HF dataset...")
69
- hf_hub_download(repo_id=repo_id, filename="index/faiss.index", local_dir=data_dir, repo_type="dataset", token=token)
70
- if not os.path.exists(metadata_path):
71
- logger.info("Downloading metadata_store.pkl from HF dataset...")
72
- hf_hub_download(repo_id=repo_id, filename="index/metadata_store.pkl", local_dir=data_dir, repo_type="dataset", token=token)
73
- if not os.path.exists(bm25_path):
74
- logger.info("Downloading bm25_cache.pkl from HF dataset...")
75
- hf_hub_download(repo_id=repo_id, filename="index/bm25_cache.pkl", local_dir=data_dir, repo_type="dataset", token=token)
76
- if not os.path.exists(vocab_path):
77
- logger.info("Downloading drugbank vocabulary.csv from HF dataset...")
78
- hf_hub_download(repo_id=repo_id, filename="drugbank vocabulary.csv", local_dir=data_dir, repo_type="dataset", token=token)
79
- if not os.path.exists(rxnorm_path):
80
- logger.info("Downloading rxnorm_cache.csv from HF dataset...")
81
- hf_hub_download(repo_id=repo_id, filename="rxnorm_cache.csv", local_dir=data_dir, repo_type="dataset", token=token)
82
- except Exception as e:
83
- logger.error(f"Failed to download dataset files: {e}")
84
- logger.warning("Backend may not start correctly or queries may fail.")
85
-
86
- # Trigger download at startup
87
- download_dataset_files()
88
-
89
- # Import FastAPI app - this is the main backend for React frontend
90
- from src.api.main import app
91
-
92
- if __name__ == "__main__":
93
- import uvicorn
94
- # Get port from environment (Hugging Face uses 7860)
95
- port = int(os.environ.get("PORT", 7860))
96
-
97
- logger.info("Starting FastAPI backend on port {}".format(port))
98
- uvicorn.run(app, host="0.0.0.0", port=port)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ MediRAG Backend - FastAPI only (No Gradio)
3
+ React frontend on Vercel, this is just the API backend
4
+ """
5
+
6
+ import os
7
+ import sys
8
+ import subprocess
9
+ import logging
10
+ import requests
11
+
12
+ # Configure logging
13
+ logging.basicConfig(level=logging.INFO)
14
+ logger = logging.getLogger(__name__)
15
+
16
+ # Set cache directories for Hugging Face
17
+ os.environ["TRANSFORMERS_CACHE"] = "/tmp/transformers_cache"
18
+ os.environ["HF_HOME"] = "/tmp/hf_home"
19
+ os.environ["TORCH_HOME"] = "/tmp/torch_cache"
20
+
21
+ # Add src to path
22
+ sys.path.insert(0, os.path.join(os.path.dirname(__file__), "src"))
23
+
24
# Make sure the scispaCy model "en_core_sci_lg" is loadable before the API is
# imported. Everything here is best-effort: any failure is logged as a warning
# and startup continues.
try:
    import spacy
    try:
        # Probe load: spaCy raises OSError when the model package is missing.
        spacy.load("en_core_sci_lg")
        logger.info("spaCy model en_core_sci_lg loaded.")
    except OSError:
        # Model not installed: try to pip-install it into the running
        # interpreter's environment, then load it again to confirm.
        try:
            logger.info("Attempting to install scispacy model en_core_sci_lg...")
            subprocess.run([
                sys.executable, "-m", "pip", "install", "--quiet",
                "https://s3-us-west-2.amazonaws.com/ai2-s2-scispacy/releases/v0.5.4/en_core_sci_lg-0.5.4.tar.gz"
            ], check=True, timeout=300)  # check=True raises on a non-zero pip exit; 300 s cap
            spacy.load("en_core_sci_lg")
            logger.info("spaCy model installed and loaded.")
        except Exception as model_err:
            # Covers pip failure, timeout, and a second failed load.
            logger.warning(f"Could not install spaCy model: {model_err}. NER features will be limited.")
except ImportError:
    # spaCy itself is not installed; nothing to load or install.
    logger.warning("spacy/scispacy not installed. NER features will be limited but server will still start.")
44
+
45
+ # Download datasets using huggingface_hub
46
+ from huggingface_hub import hf_hub_download
47
+ import yaml
48
+ from pathlib import Path
49
+
50
# Decide which YAML config to read: the MEDIRAG_CONFIG environment variable
# wins, then config_local.yaml if it exists in the working directory,
# otherwise config.yaml.
config_path = os.environ.get(
    "MEDIRAG_CONFIG",
    "config_local.yaml" if Path("config_local.yaml").exists() else "config.yaml",
)
try:
    with open(config_path, "r", encoding="utf-8") as f:
        config_data = yaml.safe_load(f)
except Exception as cfg_err:
    # Best-effort: a missing or unparsable config must not stop startup,
    # but the fallback should be visible in the logs.
    logger.warning("Could not read config file %s: %s. Using defaults.", config_path, cfg_err)
    config_data = {}

# yaml.safe_load returns None for an empty document and may return a
# non-mapping for odd files; normalise so the lookups below cannot raise.
if not isinstance(config_data, dict):
    config_data = {}

# A bare "retrieval:" key parses to None, so guard the nested mapping as well.
retrieval_cfg = config_data.get("retrieval")
if not isinstance(retrieval_cfg, dict):
    retrieval_cfg = {}

# Local mode is enabled by the config flag or by USE_LOCAL_DATASET=true
# (case-insensitive); it makes download_dataset_files() skip HF downloads.
use_local_dataset = (
    bool(retrieval_cfg.get("use_local_dataset", False))
    or os.environ.get("USE_LOCAL_DATASET", "false").lower() == "true"
)
59
+
60
+ # Check and download index and data files
61
+ data_dir = os.path.join(os.path.dirname(__file__), "data")
62
+ index_dir = os.path.join(data_dir, "index")
63
+ os.makedirs(index_dir, exist_ok=True)
64
+
65
+ faiss_path = os.path.join(index_dir, "faiss.index")
66
+ metadata_path = os.path.join(index_dir, "metadata_store.pkl")
67
+ bm25_path = os.path.join(index_dir, "bm25_cache.pkl")
68
+ vocab_path = os.path.join(data_dir, "drugbank vocabulary.csv")
69
+ rxnorm_path = os.path.join(data_dir, "rxnorm_cache.csv")
70
+
71
def download_dataset_files() -> None:
    """Download any missing index/data files from the Hugging Face dataset repo.

    Does nothing when ``use_local_dataset`` is truthy. Otherwise each required
    file that is absent on disk is fetched with ``hf_hub_download`` into
    ``data_dir``. Files that already exist are left untouched.

    Failures are logged and never raised. Each file is attempted on its own,
    so one failed download does not prevent the remaining files from being
    fetched.
    """
    if use_local_dataset:
        logger.info("[LOCAL MODE] Bypassing Hugging Face repository download. Relying on local datasets in data/index/.")
        return

    repo_id = "joytheslothh/MediRAG-Index-Data"
    token = os.environ.get("HF_TOKEN")
    if not token:
        logger.warning("HF_TOKEN environment variable is not set. Dataset download might fail if repo is private.")

    # (expected local path, path of the file inside the dataset repo)
    required_files = (
        (faiss_path, "index/faiss.index"),
        (metadata_path, "index/metadata_store.pkl"),
        (bm25_path, "index/bm25_cache.pkl"),
        (vocab_path, "drugbank vocabulary.csv"),
        (rxnorm_path, "rxnorm_cache.csv"),
    )

    failed = []
    for local_path, repo_filename in required_files:
        if os.path.exists(local_path):
            continue
        logger.info("Downloading %s from HF dataset...", os.path.basename(repo_filename))
        try:
            hf_hub_download(
                repo_id=repo_id,
                filename=repo_filename,
                local_dir=data_dir,
                repo_type="dataset",
                token=token,
            )
        except Exception as e:
            # Keep going: the other files may still download successfully.
            logger.error("Failed to download %s: %s", repo_filename, e)
            failed.append(repo_filename)

    if failed:
        logger.warning("Backend may not start correctly or queries may fail.")
101
+
102
+ # Trigger download at startup
103
+ download_dataset_files()
104
+
105
+ # Import FastAPI app - this is the main backend for React frontend
106
+ from src.api.main import app
107
+
108
+ if __name__ == "__main__":
109
+ import uvicorn
110
+ # Get port from environment (Hugging Face uses 7860)
111
+ port = int(os.environ.get("PORT", 7860))
112
+
113
+ logger.info("Starting FastAPI backend on port {}".format(port))
114
+ uvicorn.run(app, host="0.0.0.0", port=port)
src/__init__.py CHANGED
@@ -15,10 +15,11 @@ def _setup_logging() -> None:
15
  log_format = "%(asctime)s [%(levelname)s] %(name)s: %(message)s"
16
  log_file = "logs/medirag.log"
17
 
18
- # Try to load level from config.yaml
19
  try:
20
  import yaml
21
- with open("config.yaml", "r") as f:
 
22
  cfg = yaml.safe_load(f)
23
  level_str = cfg.get("logging", {}).get("level", "INFO")
24
  log_level = getattr(logging, level_str.upper(), logging.INFO)
 
15
  log_format = "%(asctime)s [%(levelname)s] %(name)s: %(message)s"
16
  log_file = "logs/medirag.log"
17
 
18
+ # Try to load level from config_local.yaml or config.yaml
19
  try:
20
  import yaml
21
+ config_path = os.environ.get("MEDIRAG_CONFIG", "config_local.yaml" if os.path.exists("config_local.yaml") else "config.yaml")
22
+ with open(config_path, "r", encoding="utf-8") as f:
23
  cfg = yaml.safe_load(f)
24
  level_str = cfg.get("logging", {}).get("level", "INFO")
25
  log_level = getattr(logging, level_str.upper(), logging.INFO)
src/api/main.py CHANGED
@@ -32,6 +32,7 @@ from datetime import datetime
32
  from fastapi import FastAPI, HTTPException, File, UploadFile
33
  from fastapi.middleware.cors import CORSMiddleware
34
  from fastapi.responses import RedirectResponse
 
35
 
36
  import threading
37
  from src.api.schemas import (
@@ -54,7 +55,8 @@ from src.pipeline.retriever import Retriever
54
  # Logging
55
  # ---------------------------------------------------------------------------
56
  try:
57
- _cfg = yaml.safe_load(Path("config.yaml").read_text())
 
58
  _log_level = _cfg.get("logging", {}).get("level", "INFO")
59
  _ollama_base = _cfg.get("llm", {}).get("base_url", "http://localhost:11434")
60
  _api_cfg = _cfg.get("api", {})
@@ -354,6 +356,40 @@ def evaluate(req: EvaluateRequest) -> EvaluateResponse:
354
  )
355
 
356
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
357
  # ---------------------------------------------------------------------------
358
  # POST /query — end-to-end: question → retrieve → generate → evaluate
359
  # ---------------------------------------------------------------------------
@@ -373,20 +409,120 @@ def query(req: QueryRequest) -> QueryResponse:
373
  import time as _time
374
  t_total = _time.perf_counter()
375
 
376
- logger.info("POST /query question=%r, top_k=%d", req.question[:80], req.top_k)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
377
 
378
- # Step 1: Retrieve
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
379
  retriever: Optional[Retriever] = getattr(app.state, "retriever", None)
380
  if retriever is None:
381
- # Fallback: instantiate now (slower first call)
382
  try:
383
  retriever = Retriever(_cfg)
 
384
  except Exception as exc:
385
  raise HTTPException(status_code=503,
386
  detail=f"Retriever unavailable: {exc}") from exc
387
 
388
  try:
389
- raw_results = retriever.search(req.question, top_k=req.top_k)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
390
  except FileNotFoundError as exc:
391
  raise HTTPException(status_code=503,
392
  detail=f"FAISS index not found: {exc}") from exc
@@ -429,32 +565,16 @@ def query(req: QueryRequest) -> QueryResponse:
429
  top_faiss_cosine = (
430
  raw_results[0][1].get("_top_faiss_cosine", 0.0) if raw_results else 0.0
431
  )
432
-
433
- # Convert request overrides into a dict for generator
434
- llm_overrides = {}
435
- if req.llm_provider:
436
- llm_overrides["provider"] = req.llm_provider
437
- if req.llm_api_key:
438
- llm_overrides["api_key"] = req.llm_api_key
439
- if req.llm_model:
440
- llm_overrides["model"] = req.llm_model
441
- if req.ollama_url:
442
- llm_overrides["ollama_url"] = req.ollama_url
443
- if req.system_prompt:
444
- llm_overrides["system_prompt"] = req.system_prompt
445
- if req.persona:
446
- llm_overrides["persona"] = req.persona
447
-
448
  # =========================================================================
449
  # Step 2a: PRIVACY SHIELD — MediRAG redacts PHI (Option 1)
450
  # =========================================================================
451
  p_mapping = {}
452
  privacy_applied = False
453
- question_to_gen = req.question
454
 
455
  if req.use_privacy_shield:
456
  from src.pipeline.privacy import shield
457
- question_to_gen, p_mapping = shield.redact(req.question)
458
  if p_mapping:
459
  privacy_applied = True
460
  logger.info("PRIVACY INTERVENTION: Redacted %d items from question.", len(p_mapping))
@@ -489,7 +609,7 @@ def query(req: QueryRequest) -> QueryResponse:
489
  providers.append("ollama") # fallback to local if no second key
490
 
491
  logger.info("Running Consensus Layer with %s", providers)
492
- consensus_results = run_consensus_check(req.question, context_chunks, _cfg, providers=providers)
493
 
494
  # If consensus finds a safer merged answer, we promote it
495
  # and update the primary answer for the evaluation loop
@@ -505,7 +625,7 @@ def query(req: QueryRequest) -> QueryResponse:
505
  # Step 3: Evaluate
506
  try:
507
  eval_result = run_evaluation(
508
- question=req.question,
509
  answer=answer,
510
  context_chunks=context_chunks,
511
  run_ragas=req.run_ragas,
@@ -514,7 +634,7 @@ def query(req: QueryRequest) -> QueryResponse:
514
  except Exception as exc:
515
  logger.exception("Evaluation failed: %s", exc)
516
  try:
517
- log_audit("query", req.question, answer, 100, "EVAL_ERROR", 0.0,
518
  int((_time.perf_counter() - t_total) * 1000),
519
  False, {"error": str(exc), "error_type": "evaluation_failure"})
520
  except Exception:
@@ -538,6 +658,50 @@ def query(req: QueryRequest) -> QueryResponse:
538
  original_answer = None
539
  intervention_details = None
540
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
541
  faith_score = (mod_results.get("faithfulness") or {}).get("score", 1.0)
542
 
543
  # Source-credibility-aware faith threshold: high-credibility sources get more tolerance
@@ -570,7 +734,7 @@ def query(req: QueryRequest) -> QueryResponse:
570
  # FDA direct lookup can still retrieve the right data even when initial FAISS
571
  # retrieval missed it. Don't label those as coverage gaps — let intervention run.
572
  _ev_entities = (mod_results.get("entity_verifier") or {}).get("details", {}).get("entities", [])
573
- _q_lower_cg = req.question.lower()
574
  _drug_in_question = any(
575
  e.get("rxcui") and e.get("entity", "").lower() in _q_lower_cg
576
  for e in _ev_entities
@@ -606,17 +770,33 @@ def query(req: QueryRequest) -> QueryResponse:
606
  is_refusal_answer, top_faiss_cosine, faith_score,
607
  )
608
 
609
- # Tier 1: CRITICAL BLOCK (HRS ≥ 86) — response is too dangerous to show
610
  # Coverage gap: skip both tiers — regenerating from an empty DB won't help
611
- if coverage_gap:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
612
  logger.info("COVERAGE_GAP — skipping intervention (regeneration cannot add missing data).")
613
- elif hrs >= 86:
614
  original_answer = answer
615
  answer = (
616
  "⛔ UNSAFE RESPONSE BLOCKED by MediRAG Safety Gate.\n\n"
617
  "The generated answer was flagged as CRITICAL risk "
618
- f"(Health Risk Score: {hrs}/100). "
619
- "It showed signs of hallucination or contradiction with the retrieved evidence. "
620
  "Please consult a qualified medical professional or rephrase your question."
621
  )
622
  intervention_applied = True
@@ -624,12 +804,12 @@ def query(req: QueryRequest) -> QueryResponse:
624
  intervention_details = {
625
  "hrs_original": hrs,
626
  "faithfulness": faith_score,
627
- "message": "Response blocked: HRS 86 (CRITICAL risk band).",
628
  }
629
- logger.warning("INTERVENTION: CRITICAL_BLOCKED — HRS=%d", hrs)
630
 
631
  # Tier 2: HIGH RISK REGENERATION
632
- elif hrs >= 61 or faith_score < faith_threshold:
633
  original_answer = answer
634
  original_hrs = hrs
635
  logger.warning(
@@ -651,7 +831,7 @@ def query(req: QueryRequest) -> QueryResponse:
651
  e["entity"] for e in ev_details.get("entities", [])
652
  if e.get("status") == "VERIFIED" and e.get("rxcui")
653
  ]
654
- q_lower = req.question.lower()
655
  for drug in verified_drugs:
656
  if drug.lower() in q_lower:
657
  fda_direct += app.state.retriever.get_fda_chunks(drug)
@@ -668,7 +848,7 @@ def query(req: QueryRequest) -> QueryResponse:
668
  guideline_direct: list[dict] = []
669
  if top_faiss_cosine < 0.85:
670
  try:
671
- guideline_direct = app.state.retriever.get_guideline_chunks(req.question)
672
  if guideline_direct:
673
  logger.info("Direct guideline lookup found %d chunks", len(guideline_direct))
674
  except Exception as gl_exc:
@@ -682,11 +862,11 @@ def query(req: QueryRequest) -> QueryResponse:
682
  # For drug/clinical questions, expand query toward authoritative sources
683
  _drug_terms = ("contraindication", "dosage", "dose", "interaction",
684
  "warning", "adverse", "side effect", "mechanism")
685
- _q_lower = req.question.lower()
686
  retry_query = (
687
- f"FDA drug label clinical guideline {req.question}"
688
  if any(t in _q_lower for t in _drug_terms)
689
- else req.question
690
  )
691
  fresh_results = app.state.retriever.search(retry_query, top_k=req.top_k)
692
  fresh_chunks: list[dict] = []
@@ -705,10 +885,10 @@ def query(req: QueryRequest) -> QueryResponse:
705
  except Exception:
706
  retry_chunks = context_chunks
707
 
708
- answer = generate_strict_answer(req.question, retry_chunks, _cfg, overrides=llm_overrides)
709
  # Re-evaluate the corrected answer
710
  eval_result = run_evaluation(
711
- question=req.question,
712
  answer=answer,
713
  context_chunks=retry_chunks,
714
  run_ragas=False, # skip RAGAS on retry to reduce latency
@@ -740,15 +920,59 @@ def query(req: QueryRequest) -> QueryResponse:
740
  logger.info("POST /query → HRS=%d (%s) intervention=%s in %d ms total",
741
  hrs, details.get("risk_band", "?"), intervention_reason or "none", total_ms)
742
 
743
- log_audit("query", req.question, answer, hrs, details.get("risk_band", "UNKNOWN"), composite, total_ms, intervention_applied, {
744
  "module_results": mod_results,
745
  "confidence_level": details.get("confidence_level", "UNKNOWN"),
746
  "intervention_reason": intervention_reason,
747
  "original_answer": original_answer,
748
  })
749
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
750
  return QueryResponse(
751
- question=req.question,
752
  generated_answer=answer,
753
  retrieved_chunks=retrieved_chunks_out,
754
  composite_score=composite,
@@ -770,6 +994,7 @@ def query(req: QueryRequest) -> QueryResponse:
770
  consensus_results=consensus_results,
771
  privacy_applied=privacy_applied,
772
  privacy_details={"redacted_count": len(p_mapping)} if privacy_applied else None,
 
773
  coverage_gap=coverage_gap,
774
  coverage_gap_details=coverage_gap_details,
775
  )
 
32
  from fastapi import FastAPI, HTTPException, File, UploadFile
33
  from fastapi.middleware.cors import CORSMiddleware
34
  from fastapi.responses import RedirectResponse
35
+ from pydantic import BaseModel
36
 
37
  import threading
38
  from src.api.schemas import (
 
55
  # Logging
56
  # ---------------------------------------------------------------------------
57
  try:
58
+ _config_path = os.environ.get("MEDIRAG_CONFIG", "config_local.yaml" if Path("config_local.yaml").exists() else "config.yaml")
59
+ _cfg = yaml.safe_load(Path(_config_path).read_text())
60
  _log_level = _cfg.get("logging", {}).get("level", "INFO")
61
  _ollama_base = _cfg.get("llm", {}).get("base_url", "http://localhost:11434")
62
  _api_cfg = _cfg.get("api", {})
 
356
  )
357
 
358
 
359
+ # ---------------------------------------------------------------------------
360
+ # POST /translate — lightweight Hinglish to English translation route
361
+ # ---------------------------------------------------------------------------
362
class TranslateRequest(BaseModel):
    """Request body for POST /translate."""

    # Text to be translated (required).
    text: str
    # Optional per-request LLM settings. Each one is forwarded to the
    # translator as an override only when it is set to a non-empty value.
    llm_provider: Optional[str] = None
    llm_api_key: Optional[str] = None
    llm_model: Optional[str] = None
    ollama_url: Optional[str] = None
368
+
369
class TranslateResponse(BaseModel):
    """Response body for POST /translate."""

    # Value returned by the translation helper for the submitted text.
    translated_text: str
371
+
372
@app.post("/translate", response_model=TranslateResponse, tags=["translation"])
def translate(req: TranslateRequest) -> TranslateResponse:
    """Translate ``req.text`` via ``translate_hinglish_to_english``.

    Args:
        req: Text to translate plus optional LLM overrides (provider, API key,
            model, Ollama URL). Only overrides with a non-empty value are
            forwarded to the translator.

    Returns:
        A ``TranslateResponse`` wrapping the translator's output.

    Raises:
        HTTPException: 500 when the import or the translation fails; the
            original exception is chained as the cause.
    """
    requested_overrides = {
        "provider": req.llm_provider,
        "api_key": req.llm_api_key,
        "model": req.llm_model,
        "ollama_url": req.ollama_url,
    }
    # Drop unset/empty values so the translator falls back to its configured defaults.
    llm_overrides = {key: value for key, value in requested_overrides.items() if value}

    try:
        # Imported lazily inside the handler, as in the original code.
        from src.pipeline.generator import translate_hinglish_to_english
        translated = translate_hinglish_to_english(req.text, _cfg, overrides=llm_overrides)
        return TranslateResponse(translated_text=translated)
    except Exception as exc:
        logger.exception("Translation endpoint failed: %s", exc)
        # Chain the cause, matching the other handlers in this module.
        raise HTTPException(status_code=500, detail=str(exc)) from exc
391
+
392
+
393
  # ---------------------------------------------------------------------------
394
  # POST /query — end-to-end: question → retrieve → generate → evaluate
395
  # ---------------------------------------------------------------------------
 
409
  import time as _time
410
  t_total = _time.perf_counter()
411
 
412
+ # Extract request overrides into a dict for translator + generator
413
+ llm_overrides = {}
414
+ if req.llm_provider:
415
+ llm_overrides["provider"] = req.llm_provider
416
+ if req.llm_api_key:
417
+ llm_overrides["api_key"] = req.llm_api_key
418
+ if req.llm_model:
419
+ llm_overrides["model"] = req.llm_model
420
+ if req.ollama_url:
421
+ llm_overrides["ollama_url"] = req.ollama_url
422
+ if req.system_prompt:
423
+ llm_overrides["system_prompt"] = req.system_prompt
424
+ if req.persona:
425
+ llm_overrides["persona"] = req.persona
426
+
427
+ original_hinglish = None
428
+ question_to_use = req.question
429
 
430
+ if req.translate_hinglish:
431
+ if req.original_hinglish_query:
432
+ original_hinglish = req.original_hinglish_query
433
+ question_to_use = req.question
434
+ logger.info("PRE-TRANSLATED AUDIT SUBMISSION: %r -> %r", original_hinglish, question_to_use)
435
+ else:
436
+ try:
437
+ from src.pipeline.generator import translate_hinglish_to_english
438
+ translated_q = translate_hinglish_to_english(req.question, _cfg, overrides=llm_overrides)
439
+ if translated_q.strip().lower() != req.question.strip().lower():
440
+ original_hinglish = req.question
441
+ question_to_use = translated_q
442
+ logger.info("AUTO-TRANSLATED HINGLISH QUERY: %r -> %r", original_hinglish, question_to_use)
443
+ except Exception as exc:
444
+ logger.error("Hinglish translation module failed: %s", exc)
445
+
446
+ logger.info("POST /query — question=%r, processed_question=%r, top_k=%d",
447
+ req.question[:50], question_to_use[:50], req.top_k)
448
+
449
+ # Safe Semantic Cache lookup
450
+ q_vec = None
451
+ from src.pipeline.semantic_cache import SafeSemanticCache
452
+ semantic_cache = SafeSemanticCache()
453
+
454
+ # Check if retriever can encode
455
  retriever: Optional[Retriever] = getattr(app.state, "retriever", None)
456
  if retriever is None:
 
457
  try:
458
  retriever = Retriever(_cfg)
459
+ app.state.retriever = retriever
460
  except Exception as exc:
461
  raise HTTPException(status_code=503,
462
  detail=f"Retriever unavailable: {exc}") from exc
463
 
464
  try:
465
+ retriever._load_model()
466
+ if retriever._model:
467
+ q_vec = retriever._model.encode(
468
+ [question_to_use.strip()],
469
+ normalize_embeddings=True,
470
+ convert_to_numpy=True,
471
+ )[0].astype(np.float32)
472
+
473
+ cache_hit = semantic_cache.get(
474
+ query_emb=q_vec,
475
+ patient_allergies=req.patient_allergies or [],
476
+ department=req.department or "default",
477
+ overrides=llm_overrides
478
+ )
479
+ if cache_hit:
480
+ logger.info("SEMANTIC CACHE HIT: Returning safe cached response instantly.")
481
+ retrieved_chunks = [
482
+ RetrievedChunk(
483
+ chunk_id=c.get("chunk_id"),
484
+ text=c.get("text"),
485
+ source=c.get("source", ""),
486
+ pub_type=c.get("pub_type", ""),
487
+ pub_year=c.get("pub_year"),
488
+ title=c.get("title", ""),
489
+ similarity_score=c.get("similarity_score", 0.0)
490
+ ) for c in cache_hit.get("retrieved_chunks", [])
491
+ ]
492
+ mr_dict = cache_hit.get("module_results", {})
493
+ return QueryResponse(
494
+ question=cache_hit.get("question", question_to_use),
495
+ generated_answer=cache_hit.get("generated_answer"),
496
+ retrieved_chunks=retrieved_chunks,
497
+ composite_score=cache_hit.get("composite_score", 1.0),
498
+ hrs=cache_hit.get("hrs", 0),
499
+ confidence_level=cache_hit.get("confidence_level", "UNKNOWN"),
500
+ risk_band=cache_hit.get("risk_band", "UNKNOWN"),
501
+ module_results=ModuleResults(
502
+ faithfulness=_module_score(mr_dict, "faithfulness"),
503
+ entity_verifier=_module_score(mr_dict, "entity_verifier"),
504
+ source_credibility=_module_score(mr_dict, "source_credibility"),
505
+ contradiction=_module_score(mr_dict, "contradiction"),
506
+ ragas=_module_score(mr_dict, "ragas"),
507
+ ),
508
+ total_pipeline_ms=0,
509
+ intervention_applied=cache_hit.get("intervention_applied", False),
510
+ intervention_reason=cache_hit.get("intervention_reason"),
511
+ original_answer=cache_hit.get("original_answer"),
512
+ intervention_details=cache_hit.get("intervention_details"),
513
+ consensus_results=cache_hit.get("consensus_results"),
514
+ privacy_applied=cache_hit.get("privacy_applied", False),
515
+ privacy_details=cache_hit.get("privacy_details"),
516
+ original_hinglish_query=cache_hit.get("original_hinglish_query"),
517
+ coverage_gap=cache_hit.get("coverage_gap", False),
518
+ coverage_gap_details=cache_hit.get("coverage_gap_details"),
519
+ )
520
+ except Exception as exc:
521
+ logger.error("Failed semantic cache retrieval lookup: %s", exc)
522
+
523
+ # Step 1: Retrieve
524
+ try:
525
+ raw_results = retriever.search(question_to_use, top_k=req.top_k)
526
  except FileNotFoundError as exc:
527
  raise HTTPException(status_code=503,
528
  detail=f"FAISS index not found: {exc}") from exc
 
565
  top_faiss_cosine = (
566
  raw_results[0][1].get("_top_faiss_cosine", 0.0) if raw_results else 0.0
567
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
568
  # =========================================================================
569
  # Step 2a: PRIVACY SHIELD — MediRAG redacts PHI (Option 1)
570
  # =========================================================================
571
  p_mapping = {}
572
  privacy_applied = False
573
+ question_to_gen = question_to_use
574
 
575
  if req.use_privacy_shield:
576
  from src.pipeline.privacy import shield
577
+ question_to_gen, p_mapping = shield.redact(question_to_use)
578
  if p_mapping:
579
  privacy_applied = True
580
  logger.info("PRIVACY INTERVENTION: Redacted %d items from question.", len(p_mapping))
 
609
  providers.append("ollama") # fallback to local if no second key
610
 
611
  logger.info("Running Consensus Layer with %s", providers)
612
+ consensus_results = run_consensus_check(question_to_use, context_chunks, _cfg, providers=providers)
613
 
614
  # If consensus finds a safer merged answer, we promote it
615
  # and update the primary answer for the evaluation loop
 
625
  # Step 3: Evaluate
626
  try:
627
  eval_result = run_evaluation(
628
+ question=question_to_use,
629
  answer=answer,
630
  context_chunks=context_chunks,
631
  run_ragas=req.run_ragas,
 
634
  except Exception as exc:
635
  logger.exception("Evaluation failed: %s", exc)
636
  try:
637
+ log_audit("query", question_to_use, answer, 100, "EVAL_ERROR", 0.0,
638
  int((_time.perf_counter() - t_total) * 1000),
639
  False, {"error": str(exc), "error_type": "evaluation_failure"})
640
  except Exception:
 
658
  original_answer = None
659
  intervention_details = None
660
 
661
+ # Dynamic Department-specific Safety Policies & Patient Allergy Gates
662
+ hrs_block_threshold = 86
663
+ hrs_retry_threshold = 61
664
+
665
+ if req.department:
666
+ dept_lower = req.department.lower()
667
+ if "pediatric" in dept_lower:
668
+ hrs_block_threshold = 50
669
+ hrs_retry_threshold = 20
670
+ logger.info("DEPARTMENT SAFETY TUNING (Pediatrics): Block >= 50, Regenerate >= 20")
671
+ elif "oncology" in dept_lower:
672
+ hrs_block_threshold = 55
673
+ hrs_retry_threshold = 25
674
+ logger.info("DEPARTMENT SAFETY TUNING (Oncology): Block >= 55, Regenerate >= 25")
675
+ elif "cardiology" in dept_lower:
676
+ hrs_block_threshold = 60
677
+ hrs_retry_threshold = 30
678
+ logger.info("DEPARTMENT SAFETY TUNING (Cardiology): Block >= 60, Regenerate >= 30")
679
+ elif "emergency" in dept_lower or "er" in dept_lower:
680
+ hrs_block_threshold = 70
681
+ hrs_retry_threshold = 50
682
+ logger.info("DEPARTMENT SAFETY TUNING (ER): Block >= 70, Regenerate >= 50")
683
+ elif "opd" in dept_lower:
684
+ hrs_block_threshold = 80
685
+ hrs_retry_threshold = 60
686
+ logger.info("DEPARTMENT SAFETY TUNING (OPD): Block >= 80, Regenerate >= 60")
687
+
688
+ # Custom Admin Console Override
689
+ if req.custom_hrs_limit is not None:
690
+ hrs_retry_threshold = req.custom_hrs_limit
691
+ hrs_block_threshold = min(95, req.custom_hrs_limit + (30 if req.custom_hrs_limit <= 30 else 20))
692
+ logger.info(f"HOSPITAL CONSOLE OVERRIDE: Block >= {hrs_block_threshold}, Regenerate >= {hrs_retry_threshold}")
693
+
694
+ # Patient Allergy Safety Interception
695
+ allergy_intercepted = False
696
+ allergen_matched = None
697
+ if req.patient_allergies:
698
+ text_to_scan = (question_to_use + " " + answer).lower()
699
+ for allergen in req.patient_allergies:
700
+ if allergen.strip().lower() in text_to_scan:
701
+ allergy_intercepted = True
702
+ allergen_matched = allergen.strip().capitalize()
703
+ break
704
+
705
  faith_score = (mod_results.get("faithfulness") or {}).get("score", 1.0)
706
 
707
  # Source-credibility-aware faith threshold: high-credibility sources get more tolerance
 
734
  # FDA direct lookup can still retrieve the right data even when initial FAISS
735
  # retrieval missed it. Don't label those as coverage gaps — let intervention run.
736
  _ev_entities = (mod_results.get("entity_verifier") or {}).get("details", {}).get("entities", [])
737
+ _q_lower_cg = question_to_use.lower()
738
  _drug_in_question = any(
739
  e.get("rxcui") and e.get("entity", "").lower() in _q_lower_cg
740
  for e in _ev_entities
 
770
  is_refusal_answer, top_faiss_cosine, faith_score,
771
  )
772
 
773
+ # Tier 1: CRITICAL BLOCK (HRS ≥ hrs_block_threshold) — response is too dangerous to show
774
  # Coverage gap: skip both tiers — regenerating from an empty DB won't help
775
+ if allergy_intercepted:
776
+ original_answer = answer
777
+ answer = (
778
+ "⛔ PATIENT SAFETY SHIELD — ALLERGY CONTRADICTION BLOCKED\n\n"
779
+ f"Prescribing or recommending {allergen_matched} is STRICTLY CONTRAINDICATED "
780
+ f"because this patient is flagged as severely ALLERGIC to: {allergen_matched}.\n\n"
781
+ "Immediate Action: Cancel drug order and consult guidelines for safe alternative therapies (e.g. Paracetamol instead of NSAIDs)."
782
+ )
783
+ hrs = 100
784
+ intervention_applied = True
785
+ intervention_reason = "CRITICAL_ALLERGY_BLOCKED"
786
+ intervention_details = {
787
+ "hrs_original": 100,
788
+ "message": f"Response blocked: Patient has an active chart allergy to {allergen_matched}.",
789
+ }
790
+ logger.warning("INTERVENTION: CRITICAL_ALLERGY_BLOCKED — allergen=%s", allergen_matched)
791
+ elif coverage_gap:
792
  logger.info("COVERAGE_GAP — skipping intervention (regeneration cannot add missing data).")
793
+ elif hrs >= hrs_block_threshold:
794
  original_answer = answer
795
  answer = (
796
  "⛔ UNSAFE RESPONSE BLOCKED by MediRAG Safety Gate.\n\n"
797
  "The generated answer was flagged as CRITICAL risk "
798
+ f"(Health Risk Score: {hrs}/100, Ward Limit: {hrs_block_threshold}%).\n\n"
799
+ "It showed signs of clinical hallucination or contradiction with the retrieved evidence. "
800
  "Please consult a qualified medical professional or rephrase your question."
801
  )
802
  intervention_applied = True
 
804
  intervention_details = {
805
  "hrs_original": hrs,
806
  "faithfulness": faith_score,
807
+ "message": f"Response blocked: HRS >= {hrs_block_threshold} (Ward limit exceeded).",
808
  }
809
+ logger.warning("INTERVENTION: CRITICAL_BLOCKED — HRS=%d (limit=%d)", hrs, hrs_block_threshold)
810
 
811
  # Tier 2: HIGH RISK REGENERATION
812
+ elif hrs >= hrs_retry_threshold or faith_score < faith_threshold:
813
  original_answer = answer
814
  original_hrs = hrs
815
  logger.warning(
 
831
  e["entity"] for e in ev_details.get("entities", [])
832
  if e.get("status") == "VERIFIED" and e.get("rxcui")
833
  ]
834
+ q_lower = question_to_use.lower()
835
  for drug in verified_drugs:
836
  if drug.lower() in q_lower:
837
  fda_direct += app.state.retriever.get_fda_chunks(drug)
 
848
  guideline_direct: list[dict] = []
849
  if top_faiss_cosine < 0.85:
850
  try:
851
+ guideline_direct = app.state.retriever.get_guideline_chunks(question_to_use)
852
  if guideline_direct:
853
  logger.info("Direct guideline lookup found %d chunks", len(guideline_direct))
854
  except Exception as gl_exc:
 
862
  # For drug/clinical questions, expand query toward authoritative sources
863
  _drug_terms = ("contraindication", "dosage", "dose", "interaction",
864
  "warning", "adverse", "side effect", "mechanism")
865
+ _q_lower = question_to_use.lower()
866
  retry_query = (
867
+ f"FDA drug label clinical guideline {question_to_use}"
868
  if any(t in _q_lower for t in _drug_terms)
869
+ else question_to_use
870
  )
871
  fresh_results = app.state.retriever.search(retry_query, top_k=req.top_k)
872
  fresh_chunks: list[dict] = []
 
885
  except Exception:
886
  retry_chunks = context_chunks
887
 
888
+ answer = generate_strict_answer(question_to_use, retry_chunks, _cfg, overrides=llm_overrides)
889
  # Re-evaluate the corrected answer
890
  eval_result = run_evaluation(
891
+ question=question_to_use,
892
  answer=answer,
893
  context_chunks=retry_chunks,
894
  run_ragas=False, # skip RAGAS on retry to reduce latency
 
920
  logger.info("POST /query → HRS=%d (%s) intervention=%s in %d ms total",
921
  hrs, details.get("risk_band", "?"), intervention_reason or "none", total_ms)
922
 
923
+ log_audit("query", question_to_use, answer, hrs, details.get("risk_band", "UNKNOWN"), composite, total_ms, intervention_applied, {
924
  "module_results": mod_results,
925
  "confidence_level": details.get("confidence_level", "UNKNOWN"),
926
  "intervention_reason": intervention_reason,
927
  "original_answer": original_answer,
928
  })
929
 
930
+ # Save successful evaluation to Safe Semantic Cache
931
+ if q_vec is not None and not coverage_gap:
932
+ try:
933
+ response_dict = {
934
+ "question": question_to_use,
935
+ "generated_answer": answer,
936
+ "retrieved_chunks": [
937
+ {
938
+ "chunk_id": c.chunk_id,
939
+ "text": c.text,
940
+ "source": c.source,
941
+ "pub_type": c.pub_type,
942
+ "pub_year": c.pub_year,
943
+ "title": c.title,
944
+ "similarity_score": c.similarity_score
945
+ } for c in retrieved_chunks_out
946
+ ],
947
+ "composite_score": composite,
948
+ "hrs": hrs,
949
+ "confidence_level": details.get("confidence_level", "UNKNOWN"),
950
+ "risk_band": details.get("risk_band", "UNKNOWN"),
951
+ "module_results": mod_results,
952
+ "intervention_applied": intervention_applied,
953
+ "intervention_reason": intervention_reason,
954
+ "original_answer": original_answer,
955
+ "intervention_details": intervention_details,
956
+ "consensus_results": consensus_results,
957
+ "privacy_applied": privacy_applied,
958
+ "privacy_details": {"redacted_count": len(p_mapping)} if privacy_applied else None,
959
+ "original_hinglish_query": original_hinglish,
960
+ "coverage_gap": coverage_gap,
961
+ "coverage_gap_details": coverage_gap_details,
962
+ }
963
+ semantic_cache.store(
964
+ query_text=question_to_use,
965
+ query_emb=q_vec,
966
+ response=response_dict,
967
+ patient_allergies=req.patient_allergies or [],
968
+ department=req.department or "default",
969
+ overrides=llm_overrides
970
+ )
971
+ except Exception as exc:
972
+ logger.error("Failed to store in semantic cache: %s", exc)
973
+
974
  return QueryResponse(
975
+ question=question_to_use,
976
  generated_answer=answer,
977
  retrieved_chunks=retrieved_chunks_out,
978
  composite_score=composite,
 
994
  consensus_results=consensus_results,
995
  privacy_applied=privacy_applied,
996
  privacy_details={"redacted_count": len(p_mapping)} if privacy_applied else None,
997
+ original_hinglish_query=original_hinglish,
998
  coverage_gap=coverage_gap,
999
  coverage_gap_details=coverage_gap_details,
1000
  )
src/api/schemas.py CHANGED
@@ -206,6 +206,14 @@ class QueryRequest(BaseModel):
206
  default=False,
207
  description="Automatically redact PHI/PII (names, IDs) before external API calls.",
208
  )
 
 
 
 
 
 
 
 
209
  system_prompt: Optional[str] = Field(
210
  default=None,
211
  description="Custom system prompt to override the default clinical persona."
@@ -214,6 +222,22 @@ class QueryRequest(BaseModel):
214
  default="physician",
215
  description="The target audience for the response: 'physician' or 'patient'."
216
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
217
 
218
 
219
  class RetrievedChunk(BaseModel):
@@ -264,6 +288,7 @@ class QueryResponse(BaseModel):
264
  # Privacy Shield fields
265
  privacy_applied: bool = Field(default=False)
266
  privacy_details: Optional[Dict[str, Any]] = Field(default=None)
 
267
  # Coverage gap gate — distinguishes missing DB coverage from hallucination
268
  coverage_gap: bool = Field(
269
  default=False,
 
206
  default=False,
207
  description="Automatically redact PHI/PII (names, IDs) before external API calls.",
208
  )
209
+ translate_hinglish: bool = Field(
210
+ default=False,
211
+ description="Translate query from Hinglish to English before processing",
212
+ )
213
+ original_hinglish_query: Optional[str] = Field(
214
+ default=None,
215
+ description="[OPTIONAL] Pre-translated original Hinglish text from the front-end audit gate."
216
+ )
217
  system_prompt: Optional[str] = Field(
218
  default=None,
219
  description="Custom system prompt to override the default clinical persona."
 
222
  default="physician",
223
  description="The target audience for the response: 'physician' or 'patient'."
224
  )
225
+ department: Optional[str] = Field(
226
+ default=None,
227
+ description="[OPTIONAL] Active department (oncology, cardiology, pediatrics, opd, etc.) to trigger custom safety thresholds."
228
+ )
229
+ patient_allergies: Optional[list[str]] = Field(
230
+ default=None,
231
+ description="[OPTIONAL] List of patient drug allergies to scan against recommended medications."
232
+ )
233
+ custom_hrs_limit: Optional[int] = Field(
234
+ default=None,
235
+ description="[OPTIONAL] Custom HRS risk tolerance percentage (0-100) set by the hospital console."
236
+ )
237
+ custom_latency_limit: Optional[int] = Field(
238
+ default=None,
239
+ description="[OPTIONAL] Custom max allowed latency in ms."
240
+ )
241
 
242
 
243
  class RetrievedChunk(BaseModel):
 
288
  # Privacy Shield fields
289
  privacy_applied: bool = Field(default=False)
290
  privacy_details: Optional[Dict[str, Any]] = Field(default=None)
291
+ original_hinglish_query: Optional[str] = Field(default=None, description="The original Hinglish query before translation")
292
  # Coverage gap gate — distinguishes missing DB coverage from hallucination
293
  coverage_gap: bool = Field(
294
  default=False,
src/modules/entity_verifier.py CHANGED
@@ -172,6 +172,56 @@ def _lookup_rxnorm_api(drug_name: str, timeout: int = 4) -> Optional[str]:
172
  return None
173
 
174
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
175
  # ---------------------------------------------------------------------------
176
  # Public API
177
  # ---------------------------------------------------------------------------
@@ -306,12 +356,29 @@ def verify_entities(
306
 
307
  entity_results.append(result)
308
 
 
 
 
 
 
 
 
 
 
 
 
 
 
309
  # --- Score ---------------------------------------------------------------
310
  # Score is based on drug entities only (per SRS Section 6.2)
311
  if drug_total == 0:
312
  score = 0.5 # neutral — no drug entities to verify
313
  else:
314
- score = drug_verified / drug_total
 
 
 
 
315
 
316
  details = {
317
  "total_entities": len(raw_entities),
@@ -319,12 +386,13 @@ def verify_entities(
319
  "verified_count": drug_verified,
320
  "flagged_count": drug_flagged,
321
  "entities": entity_results,
 
322
  }
323
 
324
  latency_ms = int((time.perf_counter() - t0) * 1000)
325
  logger.info(
326
- "Entity verification: %.3f (%d/%d drugs verified) in %d ms",
327
- score, drug_verified, drug_total, latency_ms,
328
  )
329
  return EvalResult(
330
  module_name="entity_verifier",
 
172
  return None
173
 
174
 
175
+ @lru_cache(maxsize=1024)
176
+ def _cached_drug_interactions(rxcuis_tuple: tuple[str, ...], timeout: int) -> list[dict]:
177
+ """
178
+ Synchronous cached NIH REST request to resolve drug interactions.
179
+ """
180
+ rxcuis_str = "+".join(rxcuis_tuple)
181
+ url = f"https://rxnav.nlm.nih.gov/REST/interaction/list.json?rxcuis={rxcuis_str}"
182
+ try:
183
+ resp = requests.get(url, timeout=timeout)
184
+ if resp.status_code != 200:
185
+ return []
186
+
187
+ data = resp.json()
188
+ interactions = []
189
+
190
+ # ONCHigh returns fullInteractionTypeGroup
191
+ groups = data.get("fullInteractionTypeGroup", [])
192
+ for group in groups:
193
+ for fit in group.get("fullInteractionType", []):
194
+ for pair in fit.get("interactionPair", []):
195
+ # Extract concepts
196
+ concepts = pair.get("interactionConcept", [])
197
+ drugs_involved = [c.get("minConcept", {}).get("name", "Unknown") for c in concepts]
198
+ severity = pair.get("severity", "high").lower() # default to high since it's from ONCHigh
199
+ description = pair.get("description", "")
200
+ interactions.append({
201
+ "drugs": drugs_involved,
202
+ "severity": severity,
203
+ "description": description
204
+ })
205
+ return interactions
206
+ except Exception as e:
207
+ logger.error("Failed to fetch drug interactions from RxNav: %s", e)
208
+ return []
209
+
210
+
211
+ def check_drug_interactions(rxcuis: list[str], timeout: int = 5) -> list[dict]:
212
+ """
213
+ Query RxNav API for drug-drug interactions between a list of RxCUIs.
214
+ Uses sorted tuple transformation to enable efficient order-independent caching.
215
+ """
216
+ if len(rxcuis) < 2:
217
+ return []
218
+
219
+ # Sort RxCUIs to ensure cache consistency regardless of list order
220
+ rxcuis_tuple = tuple(sorted(rxcuis))
221
+ return _cached_drug_interactions(rxcuis_tuple, timeout)
222
+
223
+
224
+
225
  # ---------------------------------------------------------------------------
226
  # Public API
227
  # ---------------------------------------------------------------------------
 
356
 
357
  entity_results.append(result)
358
 
359
+ # --- Drug-Drug Interaction Check (DDI) -----------------------------------
360
+ # Gather standard RxCUIs for the verified drugs
361
+ rxcuis = [ent["rxcui"] for ent in entity_results if ent.get("rxcui")]
362
+ unique_rxcuis = list(set(rxcuis))
363
+ interactions = []
364
+
365
+ if len(unique_rxcuis) >= 2:
366
+ logger.info("Multiple drugs detected in answer (%s) — checking for interactions...", unique_rxcuis)
367
+ interactions = check_drug_interactions(unique_rxcuis)
368
+ if interactions:
369
+ logger.warning("DDI Check: Found %d drug interactions!", len(interactions))
370
+ drug_flagged += len(interactions)
371
+
372
  # --- Score ---------------------------------------------------------------
373
  # Score is based on drug entities only (per SRS Section 6.2)
374
  if drug_total == 0:
375
  score = 0.5 # neutral — no drug entities to verify
376
  else:
377
+ # Base score is drug_verified / drug_total
378
+ base_score = drug_verified / drug_total
379
+ # Deduct score for multi-drug interactions (0.2 deduction per interaction, cap at 0.0)
380
+ interaction_deduction = len(interactions) * 0.20
381
+ score = max(0.0, base_score - interaction_deduction)
382
 
383
  details = {
384
  "total_entities": len(raw_entities),
 
386
  "verified_count": drug_verified,
387
  "flagged_count": drug_flagged,
388
  "entities": entity_results,
389
+ "interactions": interactions,
390
  }
391
 
392
  latency_ms = int((time.perf_counter() - t0) * 1000)
393
  logger.info(
394
+ "Entity verification: %.3f (%d/%d drugs verified, %d DDI found) in %d ms",
395
+ score, drug_verified, drug_total, len(interactions), latency_ms,
396
  )
397
  return EvalResult(
398
  module_name="entity_verifier",
src/pipeline/generator.py CHANGED
@@ -50,7 +50,8 @@ _load_env()
50
 
51
  def _load_config() -> dict:
52
  try:
53
- return yaml.safe_load(Path("config.yaml").read_text())
 
54
  except Exception:
55
  return {}
56
 
@@ -582,3 +583,78 @@ def generate_strict_answer(
582
  return _generate_groq(prompt, effective_config)
583
  else:
584
  raise RuntimeError(f"Unknown LLM provider '{provider}'.")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
50
 
51
  def _load_config() -> dict:
52
  try:
53
+ config_path = os.environ.get("MEDIRAG_CONFIG", "config_local.yaml" if Path("config_local.yaml").exists() else "config.yaml")
54
+ return yaml.safe_load(Path(config_path).read_text())
55
  except Exception:
56
  return {}
57
 
 
583
  return _generate_groq(prompt, effective_config)
584
  else:
585
  raise RuntimeError(f"Unknown LLM provider '{provider}'.")
586
+
587
+
588
def generate_simple_prompt(
    prompt: str,
    config: Optional[dict] = None,
    overrides: Optional[dict] = None,
) -> str:
    """
    Execute a simple prompt on the active LLM provider without context formatting.

    Args:
        prompt: The full prompt text, sent to the provider as-is.
        config: Parsed configuration; loaded via ``_load_config()`` when None.
        overrides: Optional per-request settings. Recognised keys are
            ``provider``, ``api_key``, ``model`` and ``ollama_url``.

    Returns:
        The text produced by the selected provider's generate function.

    Raises:
        RuntimeError: If the effective provider is not one of gemini, openai,
            ollama, mistral or groq.
    """
    if config is None:
        config = _load_config()
    # Tolerate a loader that yields None (e.g. an empty config file).
    config = config or {}

    effective_llm = dict(config.get("llm") or {})
    if overrides:
        if overrides.get("provider"):
            effective_llm["provider"] = overrides["provider"]

        # Key/model overrides belong to the provider that will actually be
        # called: the override provider if given, else the configured one.
        # Falling back to "gemini" only when neither is set keeps an api_key
        # override from being filed under the wrong provider.
        pk = (effective_llm.get("provider") or "gemini").lower()

        if overrides.get("api_key"):
            key_map = {
                "gemini": "gemini_api_key",
                "openai": "openai_api_key",
                "mistral": "mistral_api_key",
                "groq": "groq_api_key",
            }
            effective_llm[key_map.get(pk, "gemini_api_key")] = overrides["api_key"]
        if overrides.get("model"):
            model_map = {
                "gemini": "gemini_model",
                "openai": "openai_model",
                "mistral": "model",
                "groq": "groq_model",
            }
            effective_llm[model_map.get(pk, "gemini_model")] = overrides["model"]
        if overrides.get("ollama_url"):
            effective_llm["base_url"] = overrides["ollama_url"]

    effective_config = {**config, "llm": effective_llm}
    provider = (effective_llm.get("provider") or "gemini").lower()

    if provider == "gemini":
        return _generate_gemini(prompt, effective_config)
    elif provider == "openai":
        return _generate_openai(prompt, effective_config)
    elif provider == "ollama":
        return _generate_ollama(prompt, effective_config)
    elif provider == "mistral":
        return _generate_mistral(prompt, effective_config)
    elif provider == "groq":
        return _generate_groq(prompt, effective_config)
    else:
        raise RuntimeError(f"Unknown LLM provider '{provider}'.")
637
+
638
+
639
def translate_hinglish_to_english(
    question: str,
    config: Optional[dict] = None,
    overrides: Optional[dict] = None,
) -> str:
    """
    Translate clinical query from Hinglish or standard Hindi to professional English.

    Args:
        question: The user's query, possibly in Hinglish or Hindi.
        config: Optional configuration forwarded to ``generate_simple_prompt``.
        overrides: Optional LLM overrides forwarded to ``generate_simple_prompt``.

    Returns:
        The translated query with surrounding whitespace and quotes removed.
        The original ``question`` is returned unchanged when it is blank, when
        the LLM call fails, or when the LLM returns nothing usable.
    """
    # Nothing to translate — avoid a pointless LLM round trip.
    if not question or not question.strip():
        return question

    prompt = (
        "You are an expert bilingual clinical query translator. You will receive a medical question "
        "written in Hinglish (a mixture of Hindi and English written in the Latin alphabet) or standard Hindi. "
        "Convert the Hinglish/Hindi question into a clear, professional, grammatically correct English clinical query. "
        "If the input query is already completely in English, return it exactly as it is with no edits. "
        "Do NOT add any conversational preamble, greetings, explanation, or formatting. Only return the translated English query.\n\n"
        f"Query: {question}\n"
        "English Translation:"
    )
    try:
        translated = generate_simple_prompt(prompt, config=config, overrides=overrides)
    except Exception as exc:
        logger.warning("Hinglish translation failed: %s. Using original query.", exc)
        return question

    cleaned = (translated or "").strip().strip('"').strip("'")
    if not cleaned:
        # An empty translation would otherwise replace the query with "".
        logger.warning("Hinglish translation returned empty output. Using original query.")
        return question
    return cleaned
660
+
src/pipeline/retriever.py CHANGED
@@ -433,7 +433,9 @@ class Retriever:
433
  # ---------------------------------------------------------------------------
434
 
435
  def _load_config() -> dict:
436
- with open("config.yaml", "r", encoding="utf-8") as f:
 
 
437
  return yaml.safe_load(f)
438
 
439
 
 
433
  # ---------------------------------------------------------------------------
434
 
435
  def _load_config() -> dict:
436
+ import os
437
+ config_path = os.environ.get("MEDIRAG_CONFIG", "config_local.yaml" if os.path.exists("config_local.yaml") else "config.yaml")
438
+ with open(config_path, "r", encoding="utf-8") as f:
439
  return yaml.safe_load(f)
440
 
441
 
src/pipeline/semantic_cache.py ADDED
@@ -0,0 +1,112 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import sqlite3
2
+ import json
3
+ import hashlib
4
+ import numpy as np
5
+ import logging
6
+ from pathlib import Path
7
+
8
+ logger = logging.getLogger(__name__)
9
+
10
+ class SafeSemanticCache:
11
+ def __init__(self, db_path="data/cache.db", threshold=0.97):
12
+ self.db_path = db_path
13
+ self.threshold = threshold
14
+ self._init_db()
15
+
16
+ def _init_db(self):
17
+ # Ensure containing directory exists
18
+ Path(self.db_path).parent.mkdir(parents=True, exist_ok=True)
19
+ conn = sqlite3.connect(self.db_path)
20
+ conn.execute("""
21
+ CREATE TABLE IF NOT EXISTS semantic_cache (
22
+ id INTEGER PRIMARY KEY,
23
+ query_text TEXT,
24
+ embedding BLOB,
25
+ cache_hash TEXT UNIQUE,
26
+ response_json TEXT,
27
+ timestamp DATETIME DEFAULT CURRENT_TIMESTAMP
28
+ )
29
+ """)
30
+ conn.commit()
31
+ conn.close()
32
+
33
+ def _generate_hash(self, query_emb: np.ndarray, patient_allergies: list[str], department: str, overrides: dict) -> str:
34
+ # Create a deterministic representation of the safety environment
35
+ allergies_str = ",".join(sorted([a.lower().strip() for a in patient_allergies]))
36
+ dept_str = department.lower().strip()
37
+ overrides_str = json.dumps(overrides, sort_keys=True)
38
+
39
+ # Round embedding to 4 decimals to ensure stability against float discrepancies
40
+ emb_str = np.round(query_emb, 4).tobytes()
41
+
42
+ hasher = hashlib.sha256()
43
+ hasher.update(emb_str)
44
+ hasher.update(allergies_str.encode('utf-8'))
45
+ hasher.update(dept_str.encode('utf-8'))
46
+ hasher.update(overrides_str.encode('utf-8'))
47
+ return hasher.hexdigest()
48
+
49
+ def get(self, query_emb: np.ndarray, patient_allergies: list[str], department: str, overrides: dict) -> dict | None:
50
+ target_hash = self._generate_hash(query_emb, patient_allergies, department, overrides)
51
+
52
+ conn = sqlite3.connect(self.db_path)
53
+ cursor = conn.cursor()
54
+ # Direct hash lookup first (O(1) fast path)
55
+ cursor.execute("SELECT response_json FROM semantic_cache WHERE cache_hash = ?", (target_hash,))
56
+ row = cursor.fetchone()
57
+
58
+ if row:
59
+ conn.close()
60
+ logger.info("Semantic Cache: Direct hash hit! Returning safe response.")
61
+ try:
62
+ return json.loads(row[0])
63
+ except Exception as e:
64
+ logger.error(f"Failed to parse cached JSON: {e}")
65
+ return None
66
+
67
+ # Fuzzy lookup (Cosine similarity fallback under identical safety settings)
68
+ cursor.execute("SELECT query_text, embedding, response_json, cache_hash FROM semantic_cache")
69
+ rows = cursor.fetchall()
70
+ conn.close()
71
+
72
+ for query, emb_bytes, response_json, cached_hash in rows:
73
+ saved_emb = np.frombuffer(emb_bytes, dtype=np.float32)
74
+ # Compute cosine similarity
75
+ norm_product = np.linalg.norm(query_emb) * np.linalg.norm(saved_emb)
76
+ if norm_product == 0:
77
+ continue
78
+ cosine = np.dot(query_emb, saved_emb) / norm_product
79
+
80
+ if cosine >= self.threshold:
81
+ # Re-verify that the safety hash matches (no allergy difference)
82
+ # To prevent cross-contamination, fuzzy match requires identical allergies/department config
83
+ candidate_hash = self._generate_hash(saved_emb, patient_allergies, department, overrides)
84
+ if candidate_hash == cached_hash:
85
+ logger.info(f"Semantic Cache: Fuzzy similarity hit! ({cosine:.4f})")
86
+ try:
87
+ return json.loads(response_json)
88
+ except Exception as e:
89
+ logger.error(f"Failed to parse cached JSON in fuzzy match: {e}")
90
+ continue
91
+
92
+ return None
93
+
94
+ def store(self, query_text: str, query_emb: np.ndarray, response: dict, patient_allergies: list[str], department: str, overrides: dict):
95
+ target_hash = self._generate_hash(query_emb, patient_allergies, department, overrides)
96
+ conn = sqlite3.connect(self.db_path)
97
+ try:
98
+ conn.execute("""
99
+ INSERT OR REPLACE INTO semantic_cache (query_text, embedding, cache_hash, response_json)
100
+ VALUES (?, ?, ?, ?)
101
+ """, (
102
+ query_text,
103
+ query_emb.tobytes(),
104
+ target_hash,
105
+ json.dumps(response)
106
+ ))
107
+ conn.commit()
108
+ logger.info("Saved successful evaluation to Semantic Cache.")
109
+ except Exception as e:
110
+ logger.error(f"Failed to store in cache: {e}")
111
+ finally:
112
+ conn.close()
tests/test_modules.py CHANGED
@@ -64,3 +64,27 @@ def test_aggregator_logic():
64
  assert abs(res.score - 0.9) < 0.01
65
  assert res.details["hrs"] == 10
66
  assert res.details["risk_band"] == "LOW"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
64
  assert abs(res.score - 0.9) < 0.01
65
  assert res.details["hrs"] == 10
66
  assert res.details["risk_band"] == "LOW"
67
+
68
+
69
def test_drug_interactions_and_entity_verifier():
    """Smoke-test the DDI lookup result shape and the verify_entities interface."""
    from src.modules.entity_verifier import check_drug_interactions, verify_entities

    # Part 1: direct DDI lookup for a known interacting pair
    # (Warfarin: 11289, Ibuprofen: 5640).
    ddi_results = check_drug_interactions(["11289", "5640"])

    # The result must be a list; entries are only inspected when present.
    assert isinstance(ddi_results, list)
    if ddi_results:
        first_hit = ddi_results[0]
        assert "drugs" in first_hit
        assert "severity" in first_hit
        assert "description" in first_hit

    # Part 2: verify_entities must return a scored result with a details dict.
    outcome = verify_entities(
        answer="Patient is taking Metformin and Lisinopril.",
        question="What medications is the patient on?",
        context_docs=["The patient is prescribed Metformin 500mg and Lisinopril 10mg."],
    )
    assert outcome.score is not None
    assert isinstance(outcome.details, dict)
90
+