joytheslothh committed on
Commit
b6f9fa8
·
0 Parent(s):

deploy: clean build

Browse files
.dockerignore ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ # Override .gitignore for Docker builds
2
+ # Include necessary data files
3
+ !data/index/
4
+ !data/index/*
5
+
6
+ # Exclude raw data, processed data and runtime logs from the build context
7
+ data/raw/*
8
+ data/processed/*
9
+ logs/*
.gitattributes ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ # Hugging Face Spaces - Git Attributes
2
+ *.pkl filter=lfs diff=lfs merge=lfs -text
3
+ *.index filter=lfs diff=lfs merge=lfs -text
4
+ *.bin filter=lfs diff=lfs merge=lfs -text
5
+ *.pt filter=lfs diff=lfs merge=lfs -text
6
+ *.pth filter=lfs diff=lfs merge=lfs -text
7
+ *.h5 filter=lfs diff=lfs merge=lfs -text
8
+ *.onnx filter=lfs diff=lfs merge=lfs -text
9
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,39 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Ignore all large database files and directories
2
+ data/
3
+
4
+ # Python
5
+ __pycache__/
6
+ *.pyc
7
+ *.pyo
8
+ *.pyd
9
+ *.egg-info/
10
+ dist/
11
+ build/
12
+ .eggs/
13
+
14
+ # Environments
15
+ venv/
16
+ .venv/
17
+ env/
18
+ .env
19
+
20
+ # Logs (generated at runtime)
21
+ logs/
22
+
23
+ # IDE
24
+ .vscode/
25
+ .idea/
26
+ *.suo
27
+ *.user
28
+
29
+ # OS
30
+ .DS_Store
31
+ Thumbs.db
32
+
33
+ # Notebooks checkpoints
34
+ .ipynb_checkpoints/
35
+
36
+ # Temporary files
37
+ *.tmp
38
+ *.bak
39
+ .env
Dockerfile ADDED
@@ -0,0 +1,40 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # MediRAG Backend - Hugging Face Spaces Docker Deployment
2
+ # Optimized for faster builds
3
+ FROM python:3.10-slim
4
+
5
+ WORKDIR /app
6
+
7
+ # Install system dependencies (libmupdf deps bundled in pymupdf wheel, no extra needed)
8
+ RUN apt-get update && apt-get install -y \
9
+ git \
10
+ curl \
11
+ build-essential \
12
+ && rm -rf /var/lib/apt/lists/*
13
+
14
+ # Set environment variables
15
+ ENV PYTHONUNBUFFERED=1
16
+ ENV TRANSFORMERS_CACHE=/tmp/transformers_cache
17
+ ENV HF_HOME=/tmp/hf_home
18
+ ENV TORCH_HOME=/tmp/torch_cache
19
+ ENV PIP_NO_CACHE_DIR=1
20
+ ENV PIP_DISABLE_PIP_VERSION_CHECK=1
21
+
22
+ # Copy requirements first for better caching
23
+ COPY requirements_minimal.txt .
24
+
25
+ # Force pip re-run by busting the cache (update this date to force full reinstall)
26
+ ARG CACHE_BUST=2026-04-12-v3
27
+ RUN pip install --no-cache-dir -r requirements_minimal.txt
28
+
29
+ # Copy the rest of the application
30
+ COPY . .
31
+
32
+ # Create necessary directories
33
+ RUN mkdir -p data/processed data/raw logs
34
+
35
+ # Expose port (Hugging Face Spaces uses 7860)
36
+ EXPOSE 7860
37
+
38
+ # Run FastAPI backend directly
39
+ CMD ["uvicorn", "app:app", "--host", "0.0.0.0", "--port", "7860"]
40
+
README.md ADDED
@@ -0,0 +1,56 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ title: MediRAG API
3
+ emoji: 🏥
4
+ colorFrom: blue
5
+ colorTo: green
6
+ sdk: docker
7
+ pinned: false
8
+ ---
9
+
10
+ # MediRAG Backend - Hugging Face Spaces (Docker)
11
+
12
+ 🏥 **Medical RAG System with Hallucination Detection**
13
+
14
+ This is the **backend API** for MediRAG 2.0, designed to work with a **React frontend**.
15
+
16
+ ## 🐳 Docker Deployment
17
+
18
+ This Space provides the backend API. The React frontend connects to this backend.
19
+
20
+ ### Backend Features
21
+ - 🔍 **Hybrid Retrieval**: FAISS (BioBERT) + BM25 keyword search
22
+ - 🧠 **LLM Generation**: Mistral/Gemini for medical answer generation
23
+ - 🛡️ **4-Layer Audit**: Faithfulness, Entity Verification, Source Credibility, Contradiction Detection
24
+ - ⚠️ **Safety Interventions**: Auto-blocks high-risk responses
25
+ - 📊 **Health Risk Score (HRS)**: 0-100 composite safety metric
26
+ - 🔌 **REST API**: Full FastAPI endpoints for React frontend
27
+
28
+ ## 🚀 Usage
29
+
30
+ ### For React Frontend
31
+ Connect your React app to this backend:
32
+ ```javascript
33
+ const API_URL = "https://joytheslothh-medirag-api.hf.space";
34
+ ```
35
+
36
+ ### API Endpoints
37
+ - `GET /health` - Health check
38
+ - `POST /query` - Full RAG pipeline
39
+ - `POST /evaluate` - Evaluate answer
40
+ - `GET /docs` - Swagger API documentation
41
+
42
+ ### Environment Variables
43
+ Set in Hugging Face Space settings:
44
+ - `MISTRAL_API_KEY` - For Mistral LLM
45
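+ - `GOOGLE_API_KEY` - For Gemini LLM
+ - `HF_TOKEN` - Read by `app.py` when downloading the index files from the Hugging Face dataset repo at startup (needed if that repo is private)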
46
+
47
+ ## 🏗️ Architecture
48
+ ```
49
+ React Frontend → FastAPI Backend → RAG Pipeline → Response
50
+ ```
51
+
52
+ ## ⚠️ Disclaimer
53
+ **This system is for research purposes only. Always consult qualified medical professionals for health decisions.**
54
+
55
+ ## 📄 License
56
+ MIT License - See repository for details.
app.py ADDED
@@ -0,0 +1,98 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ MediRAG Backend - FastAPI only (No Gradio)
3
+ React frontend on Vercel, this is just the API backend
4
+ """
5
+
6
+ import os
7
+ import sys
8
+ import subprocess
9
+ import logging
10
+ import requests
11
+
12
+ # Configure logging
13
+ logging.basicConfig(level=logging.INFO)
14
+ logger = logging.getLogger(__name__)
15
+
16
+ # Set cache directories for Hugging Face
17
+ os.environ["TRANSFORMERS_CACHE"] = "/tmp/transformers_cache"
18
+ os.environ["HF_HOME"] = "/tmp/hf_home"
19
+ os.environ["TORCH_HOME"] = "/tmp/torch_cache"
20
+
21
+ # Add src to path
22
+ sys.path.insert(0, os.path.join(os.path.dirname(__file__), "src"))
23
+
24
+ # Install spaCy model if not present (optional — server starts without it)
25
+ try:
26
+ import spacy
27
+ try:
28
+ spacy.load("en_core_sci_lg")
29
+ logger.info("spaCy model en_core_sci_lg loaded.")
30
+ except OSError:
31
+ # Try installing the model at runtime
32
+ try:
33
+ logger.info("Attempting to install scispacy model en_core_sci_lg...")
34
+ subprocess.run([
35
+ sys.executable, "-m", "pip", "install", "--quiet",
36
+ "https://s3-us-west-2.amazonaws.com/ai2-s2-scispacy/releases/v0.5.4/en_core_sci_lg-0.5.4.tar.gz"
37
+ ], check=True, timeout=300)
38
+ spacy.load("en_core_sci_lg")
39
+ logger.info("spaCy model installed and loaded.")
40
+ except Exception as model_err:
41
+ logger.warning(f"Could not install spaCy model: {model_err}. NER features will be limited.")
42
+ except ImportError:
43
+ logger.warning("spacy/scispacy not installed. NER features will be limited but server will still start.")
44
+
45
+ # Download datasets using huggingface_hub
46
+ from huggingface_hub import hf_hub_download
47
+
48
+ # Check and download index and data files
49
+ data_dir = os.path.join(os.path.dirname(__file__), "data")
50
+ index_dir = os.path.join(data_dir, "index")
51
+ os.makedirs(index_dir, exist_ok=True)
52
+
53
+ faiss_path = os.path.join(index_dir, "faiss.index")
54
+ metadata_path = os.path.join(index_dir, "metadata_store.pkl")
55
+ bm25_path = os.path.join(index_dir, "bm25_cache.pkl")
56
+ vocab_path = os.path.join(data_dir, "drugbank vocabulary.csv")
57
+ rxnorm_path = os.path.join(data_dir, "rxnorm_cache.csv")
58
+
59
def download_dataset_files():
    """Download any missing index/data files from the Hugging Face dataset repo.

    A file is fetched only when its expected local path does not exist yet.
    Each file is attempted independently, so a failure on one file does not
    prevent the remaining files from being downloaded. Failures are logged
    and never raised, so the server can still start in a degraded state.
    """
    repo_id = "joytheslothh/MediRAG-Index-Data"
    token = os.environ.get("HF_TOKEN")
    if not token:
        logger.warning("HF_TOKEN environment variable is not set. Dataset download might fail if repo is private.")

    # (expected local path, filename inside the dataset repo).
    # Downloads use local_dir=data_dir, and every local path here is data_dir
    # joined with the same relative filename.
    wanted = (
        (faiss_path, "index/faiss.index"),
        (metadata_path, "index/metadata_store.pkl"),
        (bm25_path, "index/bm25_cache.pkl"),
        (vocab_path, "drugbank vocabulary.csv"),
        (rxnorm_path, "rxnorm_cache.csv"),
    )

    failed = []
    for local_path, repo_filename in wanted:
        if os.path.exists(local_path):
            continue

        display_name = os.path.basename(local_path)
        logger.info("Downloading %s from HF dataset...", display_name)
        try:
            hf_hub_download(
                repo_id=repo_id,
                filename=repo_filename,
                local_dir=data_dir,
                repo_type="dataset",
                token=token,
            )
        except Exception as e:
            # Best-effort boundary: record the failure and keep going so the
            # other files still get a chance to download.
            logger.error("Failed to download %s: %s", display_name, e)
            failed.append(display_name)

    if failed:
        logger.warning("Backend may not start correctly or queries may fail.")
85
+
86
+ # Trigger download at startup
87
+ download_dataset_files()
88
+
89
+ # Import FastAPI app - this is the main backend for React frontend
90
+ from src.api.main import app
91
+
92
+ if __name__ == "__main__":
93
+ import uvicorn
94
+ # Get port from environment (Hugging Face uses 7860)
95
+ port = int(os.environ.get("PORT", 7860))
96
+
97
+ logger.info("Starting FastAPI backend on port {}".format(port))
98
+ uvicorn.run(app, host="0.0.0.0", port=port)
app_demo.py ADDED
@@ -0,0 +1,172 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ MediRAG Backend - Local Demo Version
3
+ Simplified version for local testing without heavy models
4
+ """
5
+
6
+ import os
7
+ import gradio as gr
8
+
9
+ # Mock functions for demo
10
+ def health_check():
11
+ return {"status": "ok", "demo_mode": True}
12
+
13
+ def query_medical(question: str, top_k: int = 5, mistral_api_key: str = "", google_api_key: str = ""):
14
+ """Demo version - returns mock response"""
15
+
16
+ # Simulate processing
17
+ demo_answer = f"""
18
+ This is a DEMO response for: "{question}"
19
+
20
+ In the full version, this would:
21
+ 1. Retrieve relevant medical documents from FAISS index
22
+ 2. Generate answer using Mistral/Gemini LLM
23
+ 3. Evaluate with 4-layer audit system
24
+ 4. Return Health Risk Score (HRS)
25
+
26
+ **To run full version:**
27
+ - Deploy to Hugging Face Spaces (Docker)
28
+ - Or install all dependencies locally
29
+ """
30
+
31
+ demo_output = f"""
32
+ 🏥 **MEDICAL ANSWER (DEMO MODE)**
33
+
34
+ {demo_answer}
35
+
36
+ ---
37
+ 📊 **RISK ASSESSMENT**
38
+ • Health Risk Score (HRS): 25/100 (DEMO)
39
+ • Risk Band: LOW
40
+ • Confidence: MEDIUM
41
+
42
+ ---
43
+ 🧪 **MODULE SCORES (DEMO)**
44
+ ✓ Faithfulness: 0.85
45
+ ✓ Entity Accuracy: 0.90
46
+ ✓ Source Credibility: 0.88
47
+ ✓ Contradiction Risk: 0.95
48
+
49
+ ---
50
+ 📚 **TOP SOURCES (DEMO)**
51
+ 📄 Source 1: PubMed - Clinical Study (Score: 0.923)
52
+ This is a placeholder for retrieved medical literature...
53
+
54
+ 📄 Source 2: PMC - Systematic Review (Score: 0.891)
55
+ Another placeholder for medical evidence...
56
+
57
+ ---
58
+ ⏱️ Total Time: 1250ms (DEMO)
59
+
60
+ ---
61
+ ⚠️ **NOTE**: This is running in DEMO mode without the full ML models.
62
+ For full functionality, deploy to Hugging Face Spaces or install all dependencies.
63
+ """.strip()
64
+
65
+ return demo_output
66
+
67
+ # Create Gradio interface
68
+ with gr.Blocks(title="MediRAG - Medical AI Demo") as demo:
69
+ gr.Markdown("""
70
+ # 🏥 MediRAG 2.0 - DEMO MODE
71
+ ## Medical Question Answering with Hallucination Detection
72
+
73
+ **⚠️ This is a DEMO version for local testing.**
74
+
75
+ The full version includes:
76
+ - 107,425+ medical documents in FAISS index
77
+ - BioBERT embeddings for retrieval
78
+ - Mistral/Gemini LLM for generation
79
+ - 4-layer audit system (DeBERTa-v3, SciSpaCy)
80
+ - Health Risk Score calculation
81
+
82
+ **Deploy to Hugging Face Spaces for full functionality:**
83
+ https://huggingface.co/spaces/joytheslothh/MediRAG-API
84
+ """)
85
+
86
+ with gr.Accordion("⚙️ API Configuration (Optional)", open=False):
87
+ gr.Markdown("""
88
+ In the full version, provide your API keys for LLM generation:
89
+ - **Mistral API Key**: https://console.mistral.ai/
90
+ - **Google API Key**: https://makersuite.google.com/app/apikey
91
+ """)
92
+ with gr.Row():
93
+ mistral_key_input = gr.Textbox(
94
+ label="Mistral API Key",
95
+ placeholder="Enter your Mistral API key (full version only)",
96
+ type="password",
97
+ value=""
98
+ )
99
+ google_key_input = gr.Textbox(
100
+ label="Google API Key (Gemini)",
101
+ placeholder="Enter your Google API key (full version only)",
102
+ type="password",
103
+ value=""
104
+ )
105
+
106
+ with gr.Row():
107
+ with gr.Column():
108
+ question_input = gr.Textbox(
109
+ label="Your Medical Question",
110
+ placeholder="e.g., What are the side effects of metformin?",
111
+ lines=3
112
+ )
113
+ top_k_slider = gr.Slider(
114
+ minimum=1,
115
+ maximum=10,
116
+ value=5,
117
+ step=1,
118
+ label="Number of Sources to Retrieve"
119
+ )
120
+ submit_btn = gr.Button("🔍 Ask MediRAG (Demo)", variant="primary")
121
+
122
+ with gr.Column():
123
+ output_text = gr.Markdown(label="Response")
124
+
125
+ submit_btn.click(
126
+ fn=query_medical,
127
+ inputs=[question_input, top_k_slider, mistral_key_input, google_key_input],
128
+ outputs=output_text
129
+ )
130
+
131
+ gr.Markdown("""
132
+ ---
133
+ ### 🚀 How to Run Full Version
134
+
135
+ **Option 1: Hugging Face Spaces (Recommended)**
136
+ ```
137
+ 1. Visit: https://huggingface.co/spaces/joytheslothh/MediRAG-API
138
+ 2. The full app is already deployed there!
139
+ ```
140
+
141
+ **Option 2: Local with Docker**
142
+ ```bash
143
+ cd Backend
144
+ docker build -t medirag .
145
+ docker run -p 7860:7860 medirag
146
+ ```
147
+
148
+ **Option 3: Local with Virtual Environment**
149
+ ```bash
150
+ cd Backend
151
+ python -m venv venv
152
+ venv\Scripts\activate
153
+ pip install -r requirements_hf.txt
154
+ python -m spacy download en_core_sci_lg
155
+ python app.py
156
+ ```
157
+
158
+ ### 🔬 Full System Features
159
+ - **Faithfulness**: DeBERTa-v3 NLI model checks claim support
160
+ - **Entity Verification**: SciSpaCy + DrugBank for drug/dosage validation
161
+ - **Source Credibility**: Ranks evidence by publication tier
162
+ - **Contradiction Detection**: Internal NLI cross-check for self-contradictions
163
+ """)
164
+
165
+ if __name__ == "__main__":
166
+ port = int(os.environ.get("PORT", 7860))
167
+ demo.launch(
168
+ server_name="0.0.0.0",
169
+ server_port=port,
170
+ share=False,
171
+ show_error=True
172
+ )
config.yaml ADDED
@@ -0,0 +1,71 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ retrieval:
2
+ top_k: 5
3
+ chunk_size: 512
4
+ chunk_overlap: 50
5
+ embedding_model: dmis-lab/biobert-v1.1
6
+ index_path: data/index/faiss.index
7
+ metadata_path: data/index/metadata_store.pkl
8
+
9
+ modules:
10
+ faithfulness:
11
+ nli_model: cnut1648/biolinkbert-mednli
12
+ entailment_threshold: 0.75
13
+ max_nli_tokens: 510
14
+ truncate_side: left # keep END of context (clinical values appear last)
15
+ deberta_batch_size: 4 # Colab T4: safe at 8 | CPU 16GB: use 2 | OOM: system retries at 1
16
+ entity_verifier:
17
+ spacy_model: en_ner_bc5cdr_md
18
+ critical_entity_types: [DRUG, DOSAGE]
19
+ dosage_tolerance_pct: 10 # >10% numerical difference → CRITICAL
20
+ rxnorm_api_url: https://rxnav.nlm.nih.gov/REST/approximateTerm.json
21
+ rxnorm_api_timeout_s: 3
22
+ rxnorm_cache_path: data/rxnorm_cache.csv
23
+ source_credibility:
24
+ method: keyword # "keyword" = demo (FR-11a) | "metadata" = May (FR-11b)
25
+ # tier weights are defined by name in src/modules/source_credibility.py TIER_WEIGHTS dict
26
+ # clinical_guideline=1.0, drug_label=0.90, systematic_review=0.85,
27
+ # research_abstract=0.70, review_article=0.60, clinical_case=0.50, unknown=0.30
28
+ contradiction:
29
+ nli_model: cnut1648/biolinkbert-mednli # same model as faithfulness — load once
30
+ confidence_threshold: 0.75
31
+ max_sentence_pairs: 45 # skip if N > 10 sentences, check adjacent + (first,last)
32
+ deberta_batch_size: 4
33
+
34
+ aggregator:
35
+ weights:
36
+ faithfulness: 0.35
37
+ entity_accuracy: 0.20
38
+ source_credibility: 0.20
39
+ contradiction_risk: 0.15
40
+ ragas_composite: 0.10
41
+ risk_bands:
42
+ low: [0, 30]
43
+ moderate: [31, 60]
44
+ high: [61, 85]
45
+ critical: [86, 100]
46
+
47
+ llm:
48
+ provider: mistral
49
+ gemini_api_key: ${GEMINI_API_KEY}
50
+ mistral_api_key: ${MISTRAL_API_KEY}
51
+ groq_api_key: ${GROQ_API_KEY}
52
+ model: mistral-large-latest
53
+ gemini_model: gemini-2.0-flash
54
+ groq_model: llama-3.3-70b-versatile
55
+ base_url: http://localhost:11434
56
+ timeout_seconds: 120
57
+ judge_temperature: 0.0
58
+ generation_temperature: 0.7
59
+
60
+ api:
61
+ host: 0.0.0.0
62
+ port: 8000
63
+ max_query_length: 500
64
+ max_answer_length: 2000
65
+ max_chunks: 10
66
+ max_chunk_length: 2000
67
+
68
+ logging:
69
+ level: INFO # set to WARNING on demo day
70
+ file: logs/medirag.log
71
+ format: "%(asctime)s [%(levelname)s] %(name)s: %(message)s"
conftest.py ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ conftest.py — project root
3
+ Ensures src/ is on the Python path so all test files can import from src.*
4
+ without needing PYTHONPATH to be set manually. (SRS Section 17)
5
+ """
6
+ import sys
7
+ import os
8
+
9
+ # Add the src/ directory to path so `from modules.faithfulness import ...` works
10
+ sys.path.insert(0, os.path.join(os.path.dirname(__file__), "src"))
11
+ # Also add project root so `import src` works
12
+ sys.path.insert(0, os.path.dirname(__file__))
demo/.gitkeep ADDED
@@ -0,0 +1 @@
 
 
1
+ # placeholder — demo_fallback.json generated by scripts/warmup.py
pytest.ini ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ [pytest]
2
+ testpaths = tests
3
+ python_files = test_*.py
4
+ python_classes = Test*
5
+ python_functions = test_*
6
+ addopts = -v --tb=short
requirements.txt ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ langchain==0.1.20
2
+ langchain-community==0.0.38
3
+
4
+ # FIX 1: faiss-cpu 1.7.4 doesn't exist on PyPI — 1.9.0+ has a compatible API
5
+ faiss-cpu>=1.9.0
6
+
7
+ # FIX 2: torch 2.2.0 has no Python 3.13 wheels — 2.5.0+ supports Python 3.13
8
+ torch>=2.5.0
9
+
10
+ # FIX 3: transformers 4.40.0 may have issues on Python 3.13 — use 4.44+
11
+ transformers>=4.44.0
12
+ sentence-transformers>=2.7.0
13
+
14
+ # scispacy + en_core_sci_lg: installed via conda, NOT here (see setup commands below)
15
+ # scispacy 0.5.4 pins scipy<1.11 which has no Python 3.12 pip wheels.
16
+ # Conda has pre-built scipy binaries — use: conda install -c conda-forge scispacy
17
+
18
+ ragas==0.1.9
19
+ fastapi==0.110.0
20
+ uvicorn==0.27.0
21
+ # streamlit>=1.35.0 # Removed - using React frontend instead
22
+ pyyaml==6.0.1
23
+ pydantic>=2.9.0 # 2.6.0 has broken pydantic.v1 on Python 3.12 (ForwardRef bug); fixed in 2.9+
24
+ datasets==2.18.0
25
+ pytest==8.1.0
26
+ httpx>=0.27.0,<0.28.0 # starlette 0.36.3 TestClient breaks with httpx 0.28+ (removed app= kwarg)
27
+ pandas>=2.2.0 # 2.2.0 has Python 3.12 wheels (no longer need 2.2.3+)
28
+ numpy>=1.26.4,<2 # langchain 0.1.20 requires numpy<2; use conda env for Python 3.12 (conda pre-builds numpy 1.x)
29
+ requests==2.31.0
30
+ google-genai>=1.0.0 # New Google GenAI SDK (replaces deprecated google-generativeai)
31
+ pysbd>=0.3.4 # sentence boundary detection (faithfulness module)
32
+ pymupdf>=1.24.0 # fitz: extracted text from PDF
33
+ python-docx>=1.1.0 # extracted text from DOCX
34
+ rank-bm25>=0.2.2 # keyword search for retriever
35
+ python-multipart>=0.0.12 # handle form data in FastAPI
requirements_hf.txt ADDED
@@ -0,0 +1,46 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # MediRAG Backend - Hugging Face Spaces Requirements
2
+ # Optimized for faster builds - relaxed version constraints
3
+
4
+ # Core dependencies
5
+ langchain>=0.1.0
6
+ langchain-community>=0.0.30
7
+
8
+ # Vector search
9
+ faiss-cpu>=1.9.0
10
+
11
+ # ML/DL frameworks
12
+ torch>=2.0.0
13
+ transformers>=4.40.0
14
+ sentence-transformers>=2.5.0
15
+
16
+ # Medical NLP - installed in Dockerfile instead
17
+ # scispacy>=0.5.4
18
+ # https://s3-us-west-2.amazonaws.com/ai2-s2-scispacy/releases/v0.5.4/en_core_sci_lg-0.5.4.tar.gz
19
+ # https://s3-us-west-2.amazonaws.com/ai2-s2-scispacy/releases/v0.5.4/en_ner_bc5cdr_md-0.5.4.tar.gz
20
+
21
+ # Evaluation
22
+ ragas>=0.1.0
23
+
24
+ # API framework
25
+ fastapi>=0.110.0
26
+ uvicorn>=0.27.0
27
+
28
+ # Hugging Face Spaces - Gradio for API wrapper
29
+ gradio>=4.0.0,<5.0.0
30
+
31
+ # Utilities
32
+ pyyaml>=6.0.0
33
+ pydantic>=2.0.0
34
+ datasets>=2.18.0
35
+ pandas>=2.0.0
36
+ numpy>=1.26.0,<2
37
+ requests>=2.30.0
38
+ google-genai>=0.5.0
39
+ pysbd>=0.3.0
40
+ pymupdf>=1.24.0
41
+ python-docx>=1.1.0
42
+ rank-bm25>=0.2.0
43
+ python-multipart>=0.0.12
44
+
45
+ # Additional for Hugging Face
46
+ huggingface-hub>=0.20.0
requirements_minimal.txt ADDED
@@ -0,0 +1,43 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # MediRAG Backend - FastAPI only (no Gradio)
2
+ # React frontend on Vercel, this is just the API backend
3
+
4
+ # Core API
5
+ fastapi>=0.110.0
6
+ uvicorn>=0.27.0
7
+ python-multipart>=0.0.12
8
+
9
+ # Data handling
10
+ pydantic>=2.0.0
11
+ pyyaml>=6.0.0
12
+ numpy>=1.26.0,<2
13
+ pandas>=2.0.0
14
+ requests>=2.30.0
15
+
16
+ # Essential ML only
17
+ torch --index-url https://download.pytorch.org/whl/cpu
18
+ transformers>=4.40.0
19
+ sentence-transformers>=2.5.0
20
+ faiss-cpu>=1.9.0
21
+
22
+ # LLM integrations
23
+ langchain>=0.1.0
24
+ langchain-community>=0.0.30
25
+ google-genai>=0.5.0
26
+ ragas>=0.1.0
27
+
28
+ # Hugging Face Hub (for fetching FAISS index at runtime)
29
+ huggingface-hub>=0.20.0
30
+ datasets>=2.18.0
31
+
32
+ # File parsing (PDF, DOCX)
33
+ pymupdf>=1.24.0
34
+ python-docx>=1.1.0
35
+
36
+ # Medical NLP
37
+ spacy>=3.7.0
38
+ scispacy>=0.5.4
39
+ https://s3-us-west-2.amazonaws.com/ai2-s2-scispacy/releases/v0.5.4/en_core_sci_lg-0.5.4.tar.gz
40
+ https://s3-us-west-2.amazonaws.com/ai2-s2-scispacy/releases/v0.5.4/en_ner_bc5cdr_md-0.5.4.tar.gz
41
+ pysbd>=0.3.0
42
+ rank-bm25>=0.2.0
43
+
scripts/build_rxnorm_cache.py ADDED
@@ -0,0 +1,347 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ FR-20: build_rxnorm_cache.py — Offline Drug Name Normalisation Cache Builder
3
+ =============================================================================
4
+ Accepts EITHER:
5
+ A) DrugBank vocabulary CSV (--drugbank-csv) ← recommended, immediate
6
+ B) DrugBank Open Data XML (--drugbank-xml) ← requires registration at drugbank.com
7
+
8
+ DrugBank vocabulary CSV is freely downloadable (no account needed) from:
9
+ https://go.drugbank.com/releases/latest#open-data → "DrugBank Vocabulary"
10
+
11
+ Queries RxNorm REST API (single approximateTerm call per drug) and saves
12
+ results to data/rxnorm_cache.csv.
13
+
14
+ Runtime:
15
+ ~14,000 names × 0.1s delay × 1 API call ≈ 24 minutes
16
+
17
+ Usage:
18
+ python scripts/build_rxnorm_cache.py --drugbank-csv "data/drugbank vocabulary.csv"
19
+ python scripts/build_rxnorm_cache.py --drugbank-csv "data/drugbank vocabulary.csv" --dry-run 50
20
+ python scripts/build_rxnorm_cache.py --drugbank-xml data/raw/drugbank_open_data.xml
21
+ """
22
+ from __future__ import annotations
23
+
24
+ import argparse
25
+ import csv
26
+ import logging
27
+ import sys
28
+ import time
29
+ import xml.etree.ElementTree as ET
30
+ from pathlib import Path
31
+
32
+ import requests
33
+
34
+ logging.basicConfig(
35
+ level=logging.INFO,
36
+ format="%(asctime)s [%(levelname)s] %(message)s",
37
+ )
38
+ logger = logging.getLogger("build_rxnorm_cache")
39
+
40
+ # RxNorm approximateTerm endpoint — returns rxcui + name in ONE call (v1.4 fix)
41
+ RXNORM_APPROX_URL = "https://rxnav.nlm.nih.gov/REST/approximateTerm.json"
42
+
43
+ # DrugBank Open Data XML namespace (XML path only)
44
+ NS = {"db": "http://www.drugbank.ca"}
45
+
46
+
47
+ # ---------------------------------------------------------------------------
48
+ # Step 1A: Extract drug names from DrugBank Vocabulary CSV ← preferred
49
+ # ---------------------------------------------------------------------------
50
+
51
def extract_drug_names_from_csv(csv_path: str) -> list[str]:
    """
    Parse the DrugBank vocabulary CSV and return all drug name strings.

    CSV columns: DrugBank ID | Accession Numbers | Common name | CAS | UNII
                 | Synonyms | Standard InChI Key

    The "Common name" value and every pipe-separated entry of the "Synonyms"
    column (e.g. "Drug A | Alias B | Trade Name C") are collected. Values are
    whitespace-stripped and empty values are dropped. Rows that are shorter
    than the header are tolerated.

    Args:
        csv_path : path to the DrugBank vocabulary CSV file

    Returns:
        Sorted deduplicated list of drug name strings.

    Exits the process with status 1 (after logging) if the file is missing.
    """
    path = Path(csv_path)
    if not path.exists():
        logger.error(
            "DrugBank vocabulary CSV not found at '%s'. "
            "Download it from https://go.drugbank.com/releases/latest#open-data "
            "(look for 'DrugBank Vocabulary' — no account needed).",
            csv_path,
        )
        sys.exit(1)

    logger.info("Parsing DrugBank vocabulary CSV: %s", path)
    names: set[str] = set()

    # newline="" is what the csv module expects for file objects;
    # utf-8-sig reads plain UTF-8 and also drops a leading BOM if present.
    with open(path, "r", encoding="utf-8-sig", newline="") as f:
        for row in csv.DictReader(f):
            # DictReader fills columns missing from a short row with None,
            # so `or ""` is needed in addition to .get().
            common = (row.get("Common name") or "").strip()
            if common:
                names.add(common)

            synonyms_raw = row.get("Synonyms") or ""
            for piece in synonyms_raw.split("|"):
                synonym = piece.strip()
                if synonym:
                    names.add(synonym)

    result = sorted(names)
    logger.info("Extracted %d unique drug names/synonyms from CSV", len(result))
    return result
98
+
99
+
100
+ # ---------------------------------------------------------------------------
101
+ # Step 1B: Extract drug names from DrugBank Open Data XML ← needs account
102
+ # ---------------------------------------------------------------------------
103
+
104
def extract_drug_names_from_xml(xml_path: str) -> list[str]:
    """
    Parse DrugBank Open Data XML and extract all drug names + synonyms.

    For every <drug> element directly under the root, the text of its first
    <name> child, of each synonyms/synonym element, and of each
    international-brands/international-brand/name element is collected.
    Text is whitespace-stripped; empty or whitespace-only text is skipped.

    Args:
        xml_path : Path to drugbank_open_data.xml

    Returns:
        Sorted deduplicated list of drug name strings.

    Exits the process with status 1 (after logging) if the file is missing
    or cannot be parsed.
    """
    logger.info("Parsing DrugBank XML: %s", xml_path)
    try:
        tree = ET.parse(xml_path)
    except FileNotFoundError:
        logger.error(
            "DrugBank XML not found at '%s'. "
            "Download it from https://go.drugbank.com/releases/latest#open-data "
            "(free academic registration required), or use --drugbank-csv instead.",
            xml_path,
        )
        sys.exit(1)
    except ET.ParseError as exc:
        logger.error("Failed to parse DrugBank XML: %s", exc)
        sys.exit(1)

    names: set[str] = set()

    for drug in tree.getroot().findall("db:drug", NS):
        elements = []
        primary = drug.find("db:name", NS)
        if primary is not None:
            elements.append(primary)
        elements.extend(drug.findall("db:synonyms/db:synonym", NS))
        elements.extend(
            drug.findall("db:international-brands/db:international-brand/db:name", NS)
        )

        for element in elements:
            # Strip before testing so whitespace-only text is not stored as
            # an empty drug name.
            text = (element.text or "").strip()
            if text:
                names.add(text)

    result = sorted(names)
    logger.info("Extracted %d unique drug names/synonyms from XML", len(result))
    return result
148
+
149
+
150
+
151
+ # ---------------------------------------------------------------------------
152
+ # Step 2: Query RxNorm (single API call per drug — v1.4)
153
+ # ---------------------------------------------------------------------------
154
+
155
def query_rxnorm(drug_name: str, timeout: int = 5) -> tuple[str, str]:
    """
    Look up a drug name in RxNorm using the approximateTerm endpoint.

    A single HTTP GET is made with maxEntries=1, and the first candidate's
    "rxcui" and "name" are returned. If the candidate has no "name", the
    input drug_name is used as the canonical name.

    Args:
        drug_name : name to look up
        timeout   : request timeout in seconds

    Returns:
        (rxcui, canonical_name). Returns ("", "") when there is no candidate,
        on a non-200 response, or on any error. Errors and non-200 responses
        are logged as warnings so they can be told apart from "no match".
    """
    try:
        resp = requests.get(
            RXNORM_APPROX_URL,
            params={"term": drug_name, "maxEntries": "1", "option": "1"},
            timeout=timeout,
        )
        if resp.status_code != 200:
            logger.warning(
                "RxNorm lookup for %r returned HTTP %s", drug_name, resp.status_code
            )
            return "", ""

        # `or` guards against explicit nulls in the JSON payload.
        group = resp.json().get("approximateGroup") or {}
        candidates = group.get("candidate") or []
        if not candidates:
            return "", ""

        best = candidates[0]
        return best.get("rxcui", ""), best.get("name", drug_name)  # fallback to input

    except Exception as exc:
        # Deliberate best-effort: one failed lookup must not abort the whole
        # cache build, but it should be visible in the log.
        logger.warning("RxNorm lookup for %r failed: %s", drug_name, exc)
        return "", ""
186
+
187
+
188
+ # ---------------------------------------------------------------------------
189
+ # Main
190
+ # ---------------------------------------------------------------------------
191
+
192
def main() -> None:
    """Command-line entry point: build (or resume) the RxNorm cache CSV.

    Steps:
      1. Parse arguments and pick the DrugBank source (explicit flag, else
         auto-detect the default CSV, else the default XML).
      2. Optionally truncate the name list (--dry-run N).
      3. With --resume, read the existing output CSV and skip names already
         present in it.
      4. Query RxNorm once per remaining name and write one row per name
         (drug_name, rxcui, canonical_name), appending when resuming.

    Exits with status 1 when no DrugBank source is available and with
    status 0 when there is nothing left to process.
    """
    parser = argparse.ArgumentParser(
        description="Build offline RxNorm cache from DrugBank data (FR-20)"
    )
    source = parser.add_mutually_exclusive_group()
    source.add_argument(
        "--drugbank-csv",
        metavar="PATH",
        default=None,
        help=(
            "Path to DrugBank vocabulary CSV [RECOMMENDED — no account needed]. "
            "Download from https://go.drugbank.com/releases/latest#open-data"
        ),
    )
    source.add_argument(
        "--drugbank-xml",
        metavar="PATH",
        default=None,
        help="Path to DrugBank Open Data XML (requires free academic registration).",
    )
    parser.add_argument(
        "--output-csv",
        default="data/rxnorm_cache.csv",
        help="Path for output CSV",
    )
    parser.add_argument(
        "--delay",
        type=float,
        default=0.1,
        help="Seconds to wait between API calls (default 0.1 — ~24 min total)",
    )
    parser.add_argument(
        "--dry-run",
        type=int,
        default=0,
        metavar="N",
        help="Only process first N drug names (for testing)",
    )
    parser.add_argument(
        "--resume",
        action="store_true",
        help=(
            "Resume a previously interrupted run. Reads already-completed entries "
            "from --output-csv and skips them, appending only the missing ones."
        ),
    )
    args = parser.parse_args()

    # ------------------------------------------------------------------
    # Auto-detect source if neither flag was given
    # ------------------------------------------------------------------
    csv_default = "data/drugbank vocabulary.csv"
    xml_default = "data/raw/drugbank_open_data.xml"

    if args.drugbank_csv:
        drug_names = extract_drug_names_from_csv(args.drugbank_csv)
    elif args.drugbank_xml:
        drug_names = extract_drug_names_from_xml(args.drugbank_xml)
    elif Path(csv_default).exists():
        logger.info("Auto-detected DrugBank vocabulary CSV at '%s'", csv_default)
        drug_names = extract_drug_names_from_csv(csv_default)
    elif Path(xml_default).exists():
        logger.info("Auto-detected DrugBank XML at '%s'", xml_default)
        drug_names = extract_drug_names_from_xml(xml_default)
    else:
        logger.error(
            "No DrugBank source found. Pass --drugbank-csv or --drugbank-xml. "
            "See script docstring for download links."
        )
        sys.exit(1)

    if args.dry_run > 0:
        drug_names = drug_names[: args.dry_run]
        logger.info("Dry-run mode: processing %d names only", len(drug_names))

    # ------------------------------------------------------------------
    # Resume: skip names already in the output CSV
    # ------------------------------------------------------------------
    out_path = Path(args.output_csv)
    out_path.parent.mkdir(parents=True, exist_ok=True)

    already_done: set[str] = set()
    # Names whose cached row has a non-empty rxcui. Tracked separately from
    # already_done because unresolved names are cached too (with empty rxcui)
    # and must not be reported as resolved.
    resolved_before: set[str] = set()
    if args.resume and out_path.exists():
        try:
            with open(out_path, "r", encoding="utf-8", newline="") as f:
                for row in csv.DictReader(f):
                    name = (row.get("drug_name") or "").strip()
                    if not name:
                        continue
                    already_done.add(name)
                    if (row.get("rxcui") or "").strip():
                        resolved_before.add(name)
            logger.info(
                "Resume mode: %d entries already in cache — skipping these.",
                len(already_done),
            )
        except Exception as exc:
            logger.warning("Could not read existing cache for resume: %s", exc)
            already_done = set()
            resolved_before = set()

    remaining = [n for n in drug_names if n not in already_done]
    skipped = len(drug_names) - len(remaining)
    if skipped:
        logger.info("Skipping %d already-resolved names. %d remaining.", skipped, len(remaining))

    total = len(remaining)
    if total == 0:
        logger.info("Nothing to do — cache is already complete.")
        sys.exit(0)

    est_minutes = total * (args.delay + 0.05) / 60
    logger.info(
        "Starting cache build: %d names to process, delay=%.2fs, estimated %.0f minutes",
        total, args.delay, est_minutes,
    )

    # ------------------------------------------------------------------
    # Write CSV — append if resuming, overwrite otherwise
    # ------------------------------------------------------------------
    file_mode = "a" if args.resume and out_path.exists() and already_done else "w"
    write_header = file_mode == "w"

    # Previously resolved names that belong to this run's name list, so the
    # final "found / len(drug_names)" ratio compares like with like.
    found = len(resolved_before.intersection(drug_names))
    new_found = 0

    with open(out_path, file_mode, newline="", encoding="utf-8") as f:
        writer = csv.writer(f)
        if write_header:
            writer.writerow(["drug_name", "rxcui", "canonical_name"])

        for i, name in enumerate(remaining):
            rxcui, canonical = query_rxnorm(name)
            writer.writerow([name, rxcui, canonical])
            if rxcui:
                new_found += 1
                found += 1

            is_last = i == total - 1
            if i % 25 == 0 or is_last:
                # Push buffered rows to disk at every progress report so a
                # killed run leaves a cache that --resume can pick up.
                f.flush()
                pct = 100 * (i + 1) / total
                logger.info(
                    "Progress: %d/%d (%.1f%%) — %d resolved this run (%d total)",
                    i + 1, total, pct, new_found, found,
                )

            # No request follows the last name, so there is nothing to pace.
            if not is_last:
                time.sleep(args.delay)

    logger.info(
        "Cache saved to %s — %d/%d names resolved to RxNorm IDs (this run: +%d)",
        out_path, found, len(drug_names), new_found,
    )
    logger.info(
        "Commit this file to the repo: git add %s && git commit -m 'Add RxNorm cache'",
        out_path,
    )
344
+
345
+
346
+ if __name__ == "__main__":
347
+ main()
scripts/debug_pmc.py ADDED
@@ -0,0 +1,54 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import requests, re
2
+ from lxml import html
3
+
4
# Debug helper: download one ADA guideline article from PMC and print a few
# structural facts about its HTML (root element, sections, recommendation
# paragraphs) to the console.
r = requests.get(
    'https://www.ncbi.nlm.nih.gov/pmc/articles/PMC10725812/',
    headers={'User-Agent': 'Mozilla/5.0'},
    timeout=15
)
# Stop on HTTP errors instead of silently analysing an error page.
r.raise_for_status()
tree = html.fromstring(r.content)

# Find main article body — skip nav/header
article = tree.xpath('//article') or tree.xpath('//*[@role="main"]') or tree.xpath('//div[@class="article"]')
root = article[0] if article else tree
print('Using root:', root.tag, root.get('class','')[:40])

# Find all sections with their h2/h3 and paragraphs
sections = root.xpath('.//section')
print(f'\nTotal sections: {len(sections)}')

# Show first Recommendations section content
for sec in sections:
    h3 = sec.xpath('.//h3')
    if h3 and 'Recommendation' in h3[0].text_content():
        print('\n--- RECOMMENDATIONS SECTION ---')
        print('H3:', h3[0].text_content().strip())
        # Get all list items and paragraphs in this section
        items = sec.xpath('.//li | .//p')
        for item in items[:8]:
            t = item.text_content().strip()
            if t and len(t) > 20:
                print(' TEXT:', t[:200])
        # Only the first matching section is of interest.
        break

# Check how rec numbers look — find paragraphs starting with N.N pattern
all_p = root.xpath('.//p')
print('\n--- PARAGRAPHS WITH REC NUMBERS ---')
rec_re = re.compile(r'^\s*\d+\.\d+[a-z]?\s+\w')
count = 0
for p in all_p:
    t = p.text_content().strip()
    if rec_re.match(t):
        print(' REC:', t[:200])
        count += 1
        if count >= 5:
            break

# Show structure of first H2 section
print('\n--- FIRST H2 SECTION STRUCTURE ---')
h2_secs = root.xpath('.//section[.//h2]')
if h2_secs:
    sec = h2_secs[0]
    print('H2:', sec.xpath('.//h2')[0].text_content().strip()[:60])
    children = list(sec)
    print('Direct children tags:', [c.tag for c in children[:10]])
scripts/download_dailymed.py ADDED
@@ -0,0 +1,259 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ scripts/download_dailymed.py
3
+ ============================
4
+ Downloads FDA DailyMed drug labels for common clinical drugs via the
5
+ DailyMed API and saves them as chunks.jsonl ready for ingestion into
6
+ the MediRAG FAISS index.
7
+
8
+ Sections extracted per drug (main ones; see SECTION_CODES for the full list):
9
+ - DOSAGE AND ADMINISTRATION
10
+ - CONTRAINDICATIONS
11
+ - WARNINGS AND PRECAUTIONS
12
+ - INDICATIONS AND USAGE
13
+ - DRUG INTERACTIONS
14
+
15
+ Usage:
16
+ python scripts/download_dailymed.py
17
+ python scripts/download_dailymed.py --drugs metformin aspirin warfarin
18
+ python scripts/download_dailymed.py --output data/dailymed_chunks.jsonl
19
+ """
20
+ from __future__ import annotations
21
+
22
+ import argparse
23
+ import json
24
+ import logging
25
+ import time
26
+ import xml.etree.ElementTree as ET
27
+ from pathlib import Path
28
+
29
+ import requests
30
+
31
+ logging.basicConfig(level=logging.INFO, format="%(asctime)s [%(levelname)s] %(message)s")
32
+ logger = logging.getLogger(__name__)
33
+
34
+ # ---------------------------------------------------------------------------
35
+ # Roughly 200 common clinical drugs (priority list)
36
+ # ---------------------------------------------------------------------------
37
# Each name appears exactly once. Repeated entries (amlodipine, clopidogrel,
# atorvastatin, levothyroxine, sulfasalazine) were dropped at their later
# positions so that callers passing this list straight to the downloader do
# not fetch and write the same label twice. Order is otherwise unchanged.
TOP_DRUGS = [
    "metformin", "atorvastatin", "lisinopril", "levothyroxine", "amlodipine",
    "omeprazole", "metoprolol", "albuterol", "losartan", "gabapentin",
    "sertraline", "simvastatin", "montelukast", "pantoprazole", "alprazolam",
    "furosemide", "escitalopram", "rosuvastatin", "acetaminophen", "ibuprofen",
    "amoxicillin", "azithromycin", "doxycycline", "prednisone", "warfarin",
    "clopidogrel", "aspirin", "tamsulosin", "insulin glargine", "glipizide",
    "hydrochlorothiazide", "venlafaxine", "bupropion", "duloxetine",
    "clonazepam", "lorazepam", "zolpidem", "quetiapine", "aripiprazole",
    "olanzapine", "risperidone", "fluoxetine", "paroxetine", "citalopram",
    "tramadol", "oxycodone", "morphine", "fentanyl", "naloxone",
    "ciprofloxacin", "levofloxacin", "clindamycin", "metronidazole", "trimethoprim",
    "enalapril", "ramipril", "carvedilol", "bisoprolol", "digoxin",
    "spironolactone", "diltiazem", "verapamil", "nifedipine", "hydralazine",
    "nitroglycerin", "isosorbide", "apixaban", "rivaroxaban",
    "dabigatran", "heparin", "enoxaparin", "pravastatin",
    "ezetimibe", "fenofibrate", "niacin", "gemfibrozil", "cholestyramine",
    "allopurinol", "colchicine", "indomethacin", "naproxen", "celecoxib",
    "hydroxychloroquine", "methotrexate", "leflunomide", "sulfasalazine",
    "prednisolone", "dexamethasone", "budesonide", "fluticasone", "beclomethasone",
    "ipratropium", "tiotropium", "salmeterol", "formoterol", "theophylline",
    "insulin aspart", "insulin lispro", "sitagliptin", "saxagliptin", "empagliflozin",
    "canagliflozin", "dapagliflozin", "liraglutide", "exenatide", "pioglitazone",
    "acarbose", "repaglinide", "nateglinide", "glimepiride", "glyburide",
    "methimazole", "propylthiouracil", "calcitonin", "alendronate",
    "risedronate", "ibandronate", "denosumab", "teriparatide", "raloxifene",
    "tamoxifen", "letrozole", "anastrozole", "exemestane", "fulvestrant",
    "rituximab", "trastuzumab", "bevacizumab", "imatinib", "erlotinib",
    "ondansetron", "metoclopramide", "promethazine", "prochlorperazine",
    "loperamide", "bismuth subsalicylate", "lactulose", "polyethylene glycol",
    "docusate", "senna", "mesalamine", "infliximab",
    "adalimumab", "etanercept", "ustekinumab", "secukinumab",
    "acyclovir", "valacyclovir", "oseltamivir", "ribavirin", "sofosbuvir",
    "fluconazole", "itraconazole", "voriconazole", "amphotericin b",
    "vancomycin", "linezolid", "daptomycin", "meropenem", "piperacillin",
    "phenytoin", "valproic acid", "carbamazepine", "levetiracetam", "lamotrigine",
    "topiramate", "oxcarbazepine", "lacosamide", "brivaracetam",
    "donepezil", "memantine", "rivastigmine", "galantamine",
    "carbidopa levodopa", "pramipexole", "ropinirole", "rasagiline", "selegiline",
    "baclofen", "tizanidine", "cyclobenzaprine", "methocarbamol",
    "sildenafil", "tadalafil", "vardenafil", "finasteride", "dutasteride",
    "testosterone", "estradiol", "progesterone", "medroxyprogesterone",
    "methylphenidate", "amphetamine", "atomoxetine", "guanfacine", "clonidine",
]
81
+
82
+ # DailyMed sections we care about (LOINC codes)
83
+ SECTION_CODES = {
84
+ "34068-7": "DOSAGE AND ADMINISTRATION",
85
+ "34070-3": "CONTRAINDICATIONS",
86
+ "43685-7": "WARNINGS AND PRECAUTIONS",
87
+ "34067-9": "INDICATIONS AND USAGE",
88
+ "34073-7": "DRUG INTERACTIONS",
89
+ "34071-1": "WARNINGS",
90
+ "34084-4": "ADVERSE REACTIONS",
91
+ "34088-5": "OVERDOSAGE",
92
+ "34080-2": "USE IN SPECIFIC POPULATIONS",
93
+ }
94
+
95
+ DAILYMED_API = "https://dailymed.nlm.nih.gov/dailymed/services/v2"
96
+
97
+
98
def search_drug(drug_name: str) -> str | None:
    """Look up *drug_name* on DailyMed and return the set id of the first hit.

    Returns ``None`` when the search yields no results. Any exception raised
    while requesting or decoding the response is logged as a warning and also
    reported as ``None``.
    """
    query = {"drug_name": drug_name, "pagesize": 1}
    try:
        response = requests.get(f"{DAILYMED_API}/spls.json", params=query, timeout=10)
        response.raise_for_status()
        matches = response.json().get("data", [])
        if matches:
            return matches[0].get("setid")
    except Exception as exc:
        logger.warning("Search failed for '%s': %s", drug_name, exc)
    return None
114
+
115
+
116
def fetch_label_xml(set_id: str) -> str | None:
    """Fetch the SPL XML document for *set_id* and return it as text.

    Returns ``None`` (after logging a warning) if the request raises or the
    server answers with an error status.
    """
    url = f"{DAILYMED_API}/spls/{set_id}.xml"
    try:
        response = requests.get(url, timeout=15)
        response.raise_for_status()
        body = response.text
    except Exception as exc:
        logger.warning("XML fetch failed for set_id '%s': %s", set_id, exc)
        return None
    return body
128
+
129
+
130
def _split_prose_and_tables(el, prose: list, tables: list) -> None:
    """Walk *el* in document order, collecting text outside tables and the outermost tables.

    Table subtrees are not descended into (their cells are rendered separately),
    but the text that follows a table (its ``tail``) is kept as prose.
    """
    if el.text:
        prose.append(el.text)
    for child in el:
        if child.tag == "{urn:hl7-org:v3}table":
            tables.append(child)
        else:
            _split_prose_and_tables(child, prose, tables)
        if child.tail:
            prose.append(child.tail)


def extract_sections(
    xml_text: str,
    drug_name: str,
    set_id: str = "unknown",
    section_codes: dict[str, str] | None = None,
) -> list[dict]:
    """Parse SPL XML and extract clinical sections as chunk dicts.

    Args:
        xml_text: Full SPL label document as an XML string.
        drug_name: Drug name used in chunk ids and in the chunk text prefix.
        set_id: SPL set id embedded in ``chunk_id`` / ``doc_id``.
        section_codes: Optional mapping of section code -> section name to
            extract. Defaults to the module-level ``SECTION_CODES``.

    Returns:
        One dict per text segment of at most 1500 characters. Only the first
        6000 characters of a section are chunked, and segments shorter than
        50 characters are dropped. ``total_chunks`` is the number of segments
        actually emitted for that section. An empty list is returned (and a
        warning logged) when the XML cannot be parsed.
    """
    codes = SECTION_CODES if section_codes is None else section_codes
    chunk_chars = 1500   # segment size
    max_chars = 6000     # only this many leading characters of a section are used
    min_chars = 50       # shorter sections / segments are discarded

    chunks: list[dict] = []
    try:
        root = ET.fromstring(xml_text)
    except ET.ParseError as e:
        logger.warning("XML parse error for '%s': %s", drug_name, e)
        return chunks

    ns = {"hl7": "urn:hl7-org:v3"}

    # Get brand/generic name from XML
    title_el = root.find(".//hl7:title", ns)
    label_title = title_el.text.strip() if title_el is not None and title_el.text else drug_name.title()

    for section in root.findall(".//hl7:section", ns):
        code_el = section.find("hl7:code", ns)
        if code_el is None:
            continue
        code = code_el.get("code", "")
        section_name = codes.get(code)
        if not section_name:
            continue

        # Extract text — tables are rendered as "cell | cell ; cell | cell" rows
        # and everything else as plain text. The tree is not mutated, so text
        # following a table is preserved and nested tables are not counted twice.
        texts = []
        for el in section.iter("{urn:hl7-org:v3}text"):
            prose: list[str] = []
            tables: list = []
            _split_prose_and_tables(el, prose, tables)

            for table in tables:
                rows = []
                for tr in table.iter("{urn:hl7-org:v3}tr"):
                    cells = [" ".join(td.itertext()).strip()
                             for td in tr.iter("{urn:hl7-org:v3}td")]
                    if not cells:
                        # Header-only row: fall back to <th> cells.
                        cells = [" ".join(th.itertext()).strip()
                                 for th in tr.iter("{urn:hl7-org:v3}th")]
                    row = " | ".join(c for c in cells if c)
                    if row:
                        rows.append(row)
                if rows:
                    texts.append(" ; ".join(rows))

            # Non-table text
            text = " ".join(prose).strip()
            if text:
                texts.append(text)
        full_text = " ".join(texts).strip()

        if len(full_text) < min_chars:
            continue

        # Cut into fixed-size segments first so total_chunks reflects what is
        # really emitted (the previous formula over-counted at exact multiples
        # of the segment size and ignored dropped short segments).
        segments = []
        for offset in range(0, min(len(full_text), max_chars), chunk_chars):
            segment = full_text[offset:offset + chunk_chars].strip()
            if len(segment) >= min_chars:
                segments.append((offset, segment))

        slug = drug_name.replace(' ', '_')
        for offset, segment in segments:
            chunks.append({
                "chunk_id": f"fda_{slug}_{set_id}_{code}_{offset}",
                "doc_id": f"fda_{slug}_{set_id}",
                "chunk_text": f"[FDA DailyMed | {drug_name.title()} | {section_name}] {drug_name.title()} {section_name}: {segment}",
                "chunk_index": offset // chunk_chars,
                "total_chunks": len(segments),
                "pub_type": "drug_label",
                "source": "FDA DailyMed",
                "title": f"{label_title} — {section_name}",
                "pub_year": 2024,
                "journal": "FDA DailyMed",
                "drug_name": drug_name,
                "section": section_name,
            })
    return chunks
202
+
203
+
204
def download_dailymed(drug_list: list[str], output_path: str) -> None:
    """Fetch the DailyMed label for each drug and write its chunks as JSONL.

    The output file (and its parent directory) is created or overwritten.
    Drugs with no search hit or no retrievable XML are collected and reported
    in a single warning at the end.
    """
    destination = Path(output_path)
    destination.parent.mkdir(parents=True, exist_ok=True)

    written = 0
    misses = []

    with open(destination, "w", encoding="utf-8") as sink:
        for position, drug in enumerate(drug_list, start=1):
            logger.info("[%d/%d] Processing: %s", position, len(drug_list), drug)

            set_id = search_drug(drug)
            if not set_id:
                logger.warning(" No DailyMed entry found for '%s'", drug)

            # Only attempt the XML download when the search produced a set id.
            xml_text = fetch_label_xml(set_id) if set_id else None
            if not xml_text:
                misses.append(drug)
                time.sleep(0.3)
                continue

            label_chunks = extract_sections(xml_text, drug, set_id=set_id)
            sink.writelines(json.dumps(chunk) + "\n" for chunk in label_chunks)

            written += len(label_chunks)
            logger.info(" → %d chunks extracted (set_id: %s)", len(label_chunks), set_id)
            time.sleep(0.4)  # Be polite to the API

    logger.info("Done. %d total chunks written to %s", written, destination)
    if misses:
        logger.warning("Failed drugs (%d): %s", len(misses), ", ".join(misses))
239
+
240
+
241
if __name__ == "__main__":
    # Command-line entry point: pick the drug list, de-duplicate it, apply the
    # optional limit and run the download.
    cli = argparse.ArgumentParser()
    cli.add_argument("--drugs", nargs="*", default=None,
                     help="Specific drug names (default: full TOP_DRUGS list)")
    cli.add_argument("--output", default="data/dailymed_chunks.jsonl",
                     help="Output JSONL path")
    cli.add_argument("--limit", type=int, default=None,
                     help="Limit number of drugs to download")
    args = cli.parse_args()

    # An absent or empty --drugs falls back to the built-in list.
    requested = args.drugs or TOP_DRUGS
    # dict keys keep first-seen order, so this drops repeats without reordering.
    drug_list = list(dict.fromkeys(requested))
    if args.limit:
        drug_list = drug_list[:args.limit]

    logger.info("Downloading DailyMed labels for %d drugs...", len(drug_list))
    download_dailymed(drug_list, args.output)
scripts/download_guidelines.py ADDED
@@ -0,0 +1,399 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ scripts/download_guidelines.py
3
+ ================================
4
+ Downloads clinical guidelines from PubMed Central (PMC) open-access API
5
+ and chunks them for ingestion into the MediRAG FAISS index.
6
+
7
+ Sources:
8
+ - ADA Standards of Medical Care in Diabetes 2024 (12 of its sections, via PMC)
9
+ - More guidelines can be added to GUIDELINE_SOURCES below
10
+
11
+ Chunking strategy (based on structural analysis):
12
+ - Primary boundary: H2 clinical topic + Recommendations block + evidence narrative
13
+ - Never split a Recommendations block
14
+ - Store evidence grades (A/B/C/E) and recommendation numbers as metadata
15
+
16
+ Usage:
17
+ python scripts/download_guidelines.py
18
+ python scripts/download_guidelines.py --source ada_diabetes
19
+ python scripts/download_guidelines.py --dry-run
20
+ """
21
+ from __future__ import annotations
22
+
23
+ import argparse
24
+ import json
25
+ import logging
26
+ import re
27
+ import time
28
+ import uuid
29
+ from pathlib import Path
30
+
31
+ import requests
32
+
33
+ logging.basicConfig(level=logging.INFO, format="%(asctime)s [%(levelname)s] %(message)s")
34
+ logger = logging.getLogger(__name__)
35
+
36
+ # ---------------------------------------------------------------------------
37
+ # Guideline sources — PMC IDs for ADA 2024 Standards of Care
38
+ # ---------------------------------------------------------------------------
39
+ GUIDELINE_SOURCES = {
40
+ "ada_diabetes": {
41
+ "name": "ADA Standards of Medical Care in Diabetes 2024",
42
+ "key": "ada",
43
+ "pub_type": "clinical_guideline",
44
+ "source": "American Diabetes Association",
45
+ "pub_year": 2024,
46
+ "journal": "Diabetes Care",
47
+ "sections": [
48
+ {"pmcid": "PMC10725812", "section": "2", "title": "Diagnosis and Classification of Diabetes"},
49
+ {"pmcid": "PMC10725809", "section": "4", "title": "Comprehensive Medical Evaluation and Assessment of Comorbidities"},
50
+ {"pmcid": "PMC10725816", "section": "5", "title": "Facilitating Positive Health Behaviors and Well-being"},
51
+ {"pmcid": "PMC10725808", "section": "6", "title": "Glycemic Goals and Hypoglycemia"},
52
+ {"pmcid": "PMC10725813", "section": "7", "title": "Diabetes Technology"},
53
+ {"pmcid": "PMC10725806", "section": "8", "title": "Obesity and Weight Management for the Prevention and Treatment of Type 2 Diabetes"},
54
+ {"pmcid": "PMC10725810", "section": "9", "title": "Pharmacologic Approaches to Glycemic Treatment"},
55
+ {"pmcid": "PMC10725804", "section": "13", "title": "Older Adults"},
56
+ {"pmcid": "PMC10725814", "section": "14", "title": "Children and Adolescents"},
57
+ {"pmcid": "PMC10725801", "section": "15", "title": "Management of Diabetes in Pregnancy"},
58
+ {"pmcid": "PMC10725815", "section": "16", "title": "Diabetes Care in the Hospital"},
59
+ {"pmcid": "PMC10725798", "section": "1", "title": "Improving Care and Promoting Health in Populations"},
60
+ ],
61
+ },
62
+ "acc_aha_cholesterol": {
63
+ "name": "2018 ACC/AHA Guideline on Management of Blood Cholesterol",
64
+ "key": "acc_aha_chol",
65
+ "pub_type": "clinical_guideline",
66
+ "source": "American College of Cardiology/American Heart Association",
67
+ "pub_year": 2018,
68
+ "journal": "Circulation",
69
+ "sections": [
70
+ # PMC7403606: Grundy et al. 2018 executive summary, freely accessible full text
71
+ {"pmcid": "PMC7403606", "section": "1", "title": "Management of Blood Cholesterol — Statin Therapy and LDL Targets"},
72
+ ],
73
+ },
74
+ "acc_aha_prevention": {
75
+ "name": "2019 ACC/AHA Guideline on Primary Prevention of Cardiovascular Disease",
76
+ "key": "acc_aha_prev",
77
+ "pub_type": "clinical_guideline",
78
+ "source": "American College of Cardiology/American Heart Association",
79
+ "pub_year": 2019,
80
+ "journal": "Journal of the American College of Cardiology",
81
+ "sections": [
82
+ # PMC7685565: Arnett et al. 2019, full guideline open access
83
+ {"pmcid": "PMC7685565", "section": "1", "title": "Primary Prevention — Blood Pressure, Cholesterol, Aspirin, Lifestyle"},
84
+ ],
85
+ },
86
+ }
87
+
88
+ PMC_API = "https://www.ncbi.nlm.nih.gov/research/bionlp/RESTful/pmcoa.cgi/BioC_json/{pmcid}/unicode"
89
+ PMC_EFETCH = "https://eutils.ncbi.nlm.nih.gov/entrez/eutils/efetch.fcgi"
90
+
91
+ # Evidence grade pattern: single letter A/B/C/E at end of recommendation
92
+ _GRADE_RE = re.compile(r'\b([ABCE])\s*$')
93
+ # Recommendation number pattern: e.g. "9.18", "2.1a", "6.5b"
94
+ _REC_NUM_RE = re.compile(r'^(\d+\.\d+[a-z]?)\s+')
95
+
96
+
97
+ PMC_HTML_URL = "https://www.ncbi.nlm.nih.gov/pmc/articles/{pmcid}/"
98
+
99
+
100
def fetch_pmc_xml(pmcid: str) -> str | None:
    """Download the PMC article page for *pmcid* and return its extracted text.

    Despite the name, this fetches the article's HTML page and passes the
    parsed tree to ``_extract_pmc_html_text``. Any exception (including a
    missing ``lxml``) is logged as a warning and results in ``None``.
    """
    try:
        from lxml import html as lxml_html
        response = requests.get(
            PMC_HTML_URL.format(pmcid=pmcid),
            headers={"User-Agent": "Mozilla/5.0"},
            timeout=30,
        )
        response.raise_for_status()
        document = lxml_html.fromstring(response.content)
        return _extract_pmc_html_text(document)
    except Exception as exc:
        logger.warning("PMC HTML fetch failed for %s: %s", pmcid, exc)
        return None
111
+
112
+
113
def _extract_pmc_html_text(tree) -> str:
    """
    Extract clean structured text from PMC article HTML.

    Uses lxml XPath to navigate the first <article> element (or the whole
    tree when there is none). Headings become markdown-style lines
    ("##", "###", ... by nesting depth), paragraphs and list items become
    individual lines, and tables become pipe-separated rows. Lines of ten
    characters or fewer and exact repeats are dropped (headings are exempt
    from both filters). The result is the lines joined by blank lines.

    NOTE(review): tables of a section are emitted before that section's own
    heading line, so downstream heading-based chunking attaches them to the
    preceding heading — confirm whether that is intended.
    """
    # Get main article element
    articles = tree.xpath('//article')
    root = articles[0] if articles else tree

    lines = []
    seen_texts: set[str] = set()  # Deduplication for repeated elements

    def clean(el) -> str:
        # Collapse all runs of whitespace to single spaces.
        return " ".join(el.text_content().split()).strip()

    def add_line(text: str) -> None:
        if text and len(text) > 10 and text not in seen_texts:
            seen_texts.add(text)
            lines.append(text)

    def extract_table(table_el):
        """Extract a table element as readable pipe-separated rows."""
        caption = table_el.xpath('.//caption')
        if caption:
            add_line(f"[Table: {clean(caption[0])}]")
        for tr in table_el.xpath('.//tr'):
            cells = [" ".join(td.text_content().split()).strip()
                     for td in tr.xpath('.//td | .//th')]
            row = " | ".join(c for c in cells if c)
            if row:
                add_line(row)

    def process_section(sec, depth=0):
        # Deep-search for tables first (they may be nested inside divs/figures)
        for table in sec.xpath('.//table'):
            # Only process tables whose nearest section ancestor is this sec
            ancestors = table.xpath('ancestor::section')
            if not ancestors or ancestors[-1] == sec:
                extract_table(table)

        for child in sec:
            tag = child.tag.lower() if isinstance(child.tag, str) else ""

            if tag in ("h1", "h2", "h3", "h4"):
                text = clean(child)
                if text and text not in ("Abstract", "References", "Footnotes"):
                    lines.append(f"\n{'#' * (depth + 2)} {text}")

            elif tag == "p":
                text = clean(child)
                add_line(text)

            elif tag in ("ul", "ol"):
                for li in child.xpath('.//li'):
                    text = clean(li)
                    add_line(f"• {text}")

            elif tag == "section":
                process_section(child, depth + 1)

            elif tag == "table":
                pass  # Already handled above via deep-search

            elif tag == "div":
                # Recurse into divs that might contain content
                cls = child.get("class", "")
                if any(k in cls for k in ("content", "body", "text", "article")):
                    process_section(child, depth)

    for sec in root.xpath('.//section'):
        # Only process top-level sections (not deeply nested)
        parent = sec.getparent()
        if parent is not None and parent.tag.lower() not in ("section",):
            process_section(sec)

    # If no sections found, fall back to all paragraphs. The query is relative
    # to `root`, which already is the <article> (or the whole document when no
    # <article> exists), so a plain descendant search is what "all paragraphs"
    # means here; the former './/article//p' could only match nested articles.
    if len(lines) < 5:
        for p in root.xpath('.//p'):
            add_line(clean(p))

    return "\n\n".join(l for l in lines if l.strip())
195
+
196
+
197
def extract_recommendations(text: str) -> list[dict]:
    """Extract individual recommendations with their numbers and grades.

    A recommendation is a line that starts with a number matched by
    ``_REC_NUM_RE`` (e.g. "9.18", "2.1a"). Lines that were list items carry a
    leading "• " from ``_extract_pmc_html_text``; that bullet is stripped
    first so numbered recommendations inside lists are recognised as well.

    Returns:
        A list of ``{"number", "text", "grade"}`` dicts in order of
        appearance. ``grade`` is the trailing letter matched by ``_GRADE_RE``,
        or "E" when the line has no trailing grade letter.
    """
    recs = []
    for raw_line in text.split('\n'):
        line = raw_line.strip()
        if line.startswith("•"):
            line = line[1:].lstrip()
        m = _REC_NUM_RE.match(line)
        if not m:
            continue
        rec_text = line[m.end():].strip()
        grade_m = _GRADE_RE.search(rec_text)
        recs.append({
            "number": m.group(1),
            "text": rec_text,
            "grade": grade_m.group(1) if grade_m else "E",
        })
    return recs
210
+
211
+
212
def chunk_guideline_text(
    text: str,
    section_meta: dict,
    guideline_meta: dict,
    max_chunk_chars: int = 2000,
) -> list[dict]:
    """
    Chunk guideline text at heading boundaries produced by _extract_pmc_html_text.

    Each block is one heading plus the text up to the next heading; blocks
    under bookkeeping headings (Abstract, References, ...) are skipped and
    text before the first heading is not chunked. Blocks longer than
    ``max_chunk_chars`` are split at blank-line paragraph boundaries; a single
    paragraph is never split, so a paragraph that is itself longer than the
    limit becomes one oversize chunk.

    Args:
        text: Extracted article text with markdown-style ``#`` heading lines.
        section_meta: Dict with at least ``section`` and ``title``.
        guideline_meta: Dict with ``name``, ``source``, ``pub_year``,
            ``pub_type`` and optionally ``key`` and ``journal``.
        max_chunk_chars: Target upper bound for the content of one chunk.

    Returns:
        Chunk dicts in document order; ``total_chunks`` is filled in once all
        chunks of the section are known.
    """
    chunks = []
    section_num = section_meta["section"]
    section_title = section_meta["title"]
    guideline_name = guideline_meta["name"]
    source = guideline_meta["source"]
    pub_year = guideline_meta["pub_year"]
    pub_type = guideline_meta["pub_type"]
    source_key = guideline_meta.get("key", "ada")
    journal = guideline_meta.get("journal", "Diabetes Care")

    # Split text into blocks at any heading line. The extractor writes
    # depth + 2 '#' characters, so the number of hashes is not capped here.
    _HEADING_RE = re.compile(r'^(#+)\s+(.+)$', re.MULTILINE)

    # Find all heading positions
    heading_matches = list(_HEADING_RE.finditer(text))

    if not heading_matches:
        # No headings found — chunk by size
        blocks = [(section_title, text)]
    else:
        blocks = []
        for i, m in enumerate(heading_matches):
            heading_text = m.group(2).strip()
            # Skip metadata headings
            if heading_text in ("Abstract", "References", "Footnotes", "Author notes",
                                "Conflicts of interest", "Acknowledgments"):
                continue
            start = m.end()
            end = heading_matches[i + 1].start() if i + 1 < len(heading_matches) else len(text)
            content = text[start:end].strip()
            if content:
                blocks.append((heading_text, content))

    def make_chunk(heading: str, content: str, part_idx: int = 0) -> dict:
        recs = extract_recommendations(content)
        rec_nums = [r["number"] for r in recs]
        grades = {r["number"]: r["grade"] for r in recs}
        grade_summary = "/".join(sorted(set(r["grade"] for r in recs))) if recs else ""

        prefix = f"[{guideline_name} | Section {section_num}: {section_title} | {heading}]"
        if grade_summary:
            prefix += f" [Evidence: {grade_summary}]"

        return {
            # NOTE(review): the random component makes chunk ids differ between
            # runs over identical input — confirm downstream ingestion does not
            # rely on stable ids for de-duplication.
            "chunk_id": f"guideline_{source_key}_{section_num}_{uuid.uuid4().hex[:8]}_{part_idx}",
            "doc_id": f"guideline_{source_key}_section_{section_num}",
            "chunk_text": f"{prefix}\n{content}",
            "chunk_index": len(chunks),
            "total_chunks": 0,  # patched below once the final count is known
            "pub_type": pub_type,
            "source": source,
            "title": f"{guideline_name} — Section {section_num}: {heading}",
            "pub_year": pub_year,
            "journal": journal,
            "section_number": section_num,
            "section_title": section_title,
            "h2_heading": heading,
            "recommendation_numbers": rec_nums,
            "evidence_grades": grades,
        }

    for heading, content in blocks:
        if len(content) <= max_chunk_chars:
            chunks.append(make_chunk(heading, content))
            continue

        # Split long blocks at paragraph boundaries
        paras = [p.strip() for p in re.split(r'\n{2,}', content) if p.strip()]
        current: list[str] = []
        part = 0
        for para in paras:
            current.append(para)
            # Flush everything before `para` once the limit is reached. When
            # `para` is alone and already over the limit there is nothing to
            # flush yet; it is emitted on its own by the next flush or at the
            # end, instead of producing an empty chunk.
            if len(current) > 1 and len("\n\n".join(current)) >= max_chunk_chars:
                chunks.append(make_chunk(heading, "\n\n".join(current[:-1]), part))
                current = [para]
                part += 1
        if current:
            chunks.append(make_chunk(heading, "\n\n".join(current), part))

    for chunk in chunks:
        chunk["total_chunks"] = len(chunks)

    return chunks
305
+
306
+
307
def download_guidelines(
    source_key: str,
    output_path: str,
    dry_run: bool = False,
    append: bool = False,
) -> None:
    """Download, chunk and store every section of one guideline source.

    Args:
        source_key: Key into ``GUIDELINE_SOURCES``.
        output_path: JSONL file that receives one chunk per line.
        dry_run: Fetch and chunk only. The output file is neither created nor
            truncated; a sample of the first chunk of each section is logged.
        append: Add to an existing output file instead of overwriting it
            (used when several sources share one output file).
    """
    source = GUIDELINE_SOURCES[source_key]
    out = Path(output_path)

    total_chunks = 0
    failed_sections = []

    # The file is only opened when something will actually be written, so a
    # dry run can never wipe the result of an earlier real run.
    handle = None
    if not dry_run:
        out.parent.mkdir(parents=True, exist_ok=True)
        handle = open(out, "a" if append else "w", encoding="utf-8")

    try:
        for section in source["sections"]:
            pmcid = section["pmcid"]
            logger.info("Fetching %s — Section %s: %s", pmcid, section["section"], section["title"])

            text = fetch_pmc_xml(pmcid)

            if not text or len(text) < 200:
                logger.warning("No text retrieved for %s — skipping", pmcid)
                failed_sections.append(section["title"])
                time.sleep(0.5)
                continue

            logger.info(" Retrieved %d chars", len(text))

            chunks = chunk_guideline_text(text, section, source)
            logger.info(" → %d chunks extracted", len(chunks))

            if handle is None:
                if chunks:
                    logger.info(" Sample chunk:\n%s\n...", chunks[0]["chunk_text"][:300])
            else:
                for chunk in chunks:
                    handle.write(json.dumps(chunk) + "\n")
                total_chunks += len(chunks)

            time.sleep(0.5)  # Be polite to NCBI API
    finally:
        if handle is not None:
            handle.close()

    if not dry_run:
        logger.info("Done. %d total chunks written to %s", total_chunks, out)
    if failed_sections:
        logger.warning("Failed sections: %s", failed_sections)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--source", default=None,
                        choices=list(GUIDELINE_SOURCES.keys()),
                        help="Guideline source to download (default: all sources)")
    parser.add_argument("--all", action="store_true",
                        help="Download all guideline sources")
    parser.add_argument("--output", default="data/guidelines_chunks.jsonl")
    parser.add_argument("--dry-run", action="store_true",
                        help="Fetch and parse but don't write output")
    args = parser.parse_args()

    sources_to_run = list(GUIDELINE_SOURCES.keys()) if (args.all or args.source is None) else [args.source]

    for position, source_key in enumerate(sources_to_run):
        logger.info("Downloading: %s", GUIDELINE_SOURCES[source_key]["name"])
        # The first source starts a fresh output file; later sources append to it.
        download_guidelines(
            source_key,
            args.output,
            dry_run=args.dry_run,
            append=position > 0,
        )
scripts/fix_fda_chunk_text.py ADDED
@@ -0,0 +1,120 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ scripts/fix_fda_chunk_text.py
3
+ ==============================
4
+ One-time fix: replaces the verbose FDA boilerplate prefix in all FDA DailyMed
5
+ chunk_text entries in the metadata store with a clean, BM25-friendly prefix.
6
+
7
+ Before: [FDA DRUG LABEL — These highlights do not include all the information
8
+ needed to use WARFARIN SODIUM TABLETS safely and effectively...]
9
+ CONTRAINDICATIONS: actual content...
10
+
11
+ After: [FDA DailyMed | Warfarin | CONTRAINDICATIONS] actual content...
12
+
13
+ Usage:
14
+ python scripts/fix_fda_chunk_text.py
15
+ python scripts/fix_fda_chunk_text.py --dry-run
16
+ """
17
+ from __future__ import annotations
18
+
19
+ import argparse
20
+ import logging
21
+ import pickle
22
+ import re
23
+ import sys
24
+ from pathlib import Path
25
+
26
+ sys.path.insert(0, str(Path(__file__).resolve().parent.parent))
27
+ import yaml
28
+
29
+ logging.basicConfig(level=logging.INFO, format="%(asctime)s [%(levelname)s] %(message)s")
30
+ logger = logging.getLogger(__name__)
31
+
32
# Section codes that can appear in FDA chunk ids, mapped to their section names.
SECTION_CODES = {
    "34068-7": "DOSAGE AND ADMINISTRATION",
    "34070-3": "CONTRAINDICATIONS",
    "43685-7": "WARNINGS AND PRECAUTIONS",
    "34067-9": "INDICATIONS AND USAGE",
    "34073-7": "DRUG INTERACTIONS",
    "34071-1": "WARNINGS",
    "34084-4": "ADVERSE REACTIONS",
    "34088-5": "OVERDOSAGE",
    "34080-2": "USE IN SPECIFIC POPULATIONS",
}

# Matches both old boilerplate and previously-fixed format.
# Group 1: text inside the leading [FDA ...] bracket.
# Group 2: optional "LABEL" of a following "LABEL: " prefix.
_BOILERPLATE_RE = re.compile(r"^\[(FDA[^\]]*)\]\s*(?:([A-Za-z][^:]*):\s*)?", re.DOTALL)
# A set id looks like a UUID: eight hex digits followed by a hyphen.
_SET_ID_RE = re.compile(r'^[0-9a-f]{8}-', re.I)
# Shape of a section code such as "34070-3", used to spot codes missing from SECTION_CODES.
_SECTION_CODE_RE = re.compile(r'^\d{4,5}-\d$')


def fix_chunk_text(chunk_id: str, old_text: str) -> str:
    """Return cleaned chunk_text with a compact, keyword-rich prefix.

    Args:
        chunk_id: Expected as ``fda_{drug_name}_{set_id}_{code}_{offset}``;
            the drug name may itself contain underscores.
        old_text: Current chunk text, starting either with the old
            ``[FDA DRUG LABEL ...]`` boilerplate or an already-fixed
            ``[FDA DailyMed | Drug | SECTION]`` prefix.

    Returns:
        ``[FDA DailyMed | Drug | SECTION] Drug SECTION: content``. When the
        section code in the id is unknown (or absent) the section name is
        taken from the existing prefix, falling back to "DRUG INFORMATION";
        when no drug name can be determined it is "Unknown".
    """
    parts = chunk_id.split("_")

    # Everything between "fda" and the section code is drug name + set id.
    section_name = None
    drug_name_parts: list[str] = []
    for i, part in enumerate(parts[1:], 1):
        if part in SECTION_CODES or _SECTION_CODE_RE.match(part):
            section_name = SECTION_CODES.get(part)
            drug_name_parts = parts[1:i]
            break

    # Filter out the set id so only the drug name words remain.
    drug_name_parts = [p for p in drug_name_parts if not _SET_ID_RE.match(p)]
    drug_name = " ".join(drug_name_parts).title() if drug_name_parts else ""

    # Strip the old boilerplate prefix and get just the content
    m = _BOILERPLATE_RE.match(old_text)
    content = old_text[m.end():].strip() if m else old_text.strip()

    if m and not (section_name and drug_name):
        # Recover whatever the id did not provide from the existing prefix.
        fields = [f.strip() for f in m.group(1).split("|")]
        label = (m.group(2) or "").strip()
        if len(fields) == 3:
            # Already-fixed format: [FDA DailyMed | Drug | SECTION]
            drug_name = drug_name or fields[1]
            section_name = section_name or fields[2]
        elif not section_name and label and len(label) <= 60 and "\n" not in label:
            # Old format: "[FDA DRUG LABEL ...] SECTION: content"; only a short
            # single-line label is trusted to be a section heading.
            section_name = label

    drug_name = drug_name or "Unknown"
    section_name = section_name or "DRUG INFORMATION"

    # Prepend drug name into content so BM25 finds it even in continuation chunks
    # e.g. chunk starting "Bleeding tendencies..." now reads "Warfarin CONTRAINDICATIONS: Bleeding..."
    return f"[FDA DailyMed | {drug_name} | {section_name}] {drug_name} {section_name}: {content}"
75
+
76
+
77
def main() -> None:
    """Rewrite the chunk_text prefix of every FDA DailyMed entry in the metadata store.

    The store path is read from ``retrieval.metadata_path`` in ``config.yaml``.
    With ``--dry-run`` the first three before/after pairs are logged and
    nothing is saved; otherwise the updated store is pickled back in place.
    """
    cli = argparse.ArgumentParser()
    cli.add_argument("--dry-run", action="store_true")
    args = cli.parse_args()

    with open("config.yaml") as cfg_file:
        meta_path = yaml.safe_load(cfg_file)["retrieval"]["metadata_path"]

    logger.info("Loading metadata store from %s ...", meta_path)
    with open(meta_path, "rb") as store_file:
        store: dict = pickle.load(store_file)

    fda_keys = [key for key, value in store.items() if value.get("source") == "FDA DailyMed"]
    logger.info("Found %d FDA DailyMed entries to fix", len(fda_keys))

    # Re-run on both old boilerplate AND previously-fixed entries
    # (to fix UUID + add drug name to content).
    handled_prefixes = ("[FDA DRUG LABEL", "[FDA DailyMed |")

    fixed = 0
    for key in fda_keys:
        entry = store[key]
        old_text = entry.get("chunk_text", "")
        if not old_text.startswith(handled_prefixes):
            continue

        new_text = fix_chunk_text(entry.get("chunk_id", ""), old_text)
        if not args.dry_run:
            entry["chunk_text"] = new_text
        elif fixed < 3:
            # Show only a small sample in dry-run mode.
            logger.info("BEFORE: %s", old_text[:120])
            logger.info("AFTER: %s", new_text[:120])
            logger.info("---")
        fixed += 1

    outcome = "would be fixed (dry run)" if args.dry_run else "fixed"
    logger.info("%d entries %s", fixed, outcome)

    if args.dry_run:
        return

    with open(meta_path, "wb") as store_file:
        pickle.dump(store, store_file, protocol=pickle.HIGHEST_PROTOCOL)
    logger.info("Metadata store saved. Restart backend to rebuild BM25 index.")
117
+
118
+
119
+ if __name__ == "__main__":
120
+ main()
scripts/ingest_incremental.py ADDED
@@ -0,0 +1,192 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ scripts/ingest_incremental.py
3
+ ==============================
4
+ Adds new chunks to an EXISTING FAISS index without rebuilding from scratch.
5
+ Only the new chunks are embedded — existing vectors are untouched.
6
+
7
+ Usage:
8
+ python scripts/ingest_incremental.py --input data/dailymed_chunks.jsonl
9
+ python scripts/ingest_incremental.py --input data/dailymed_chunks.jsonl --dry-run
10
+ """
11
+ from __future__ import annotations
12
+
13
+ import argparse
14
+ import json
15
+ import logging
16
+ import pickle
17
+ import sys
18
+ from pathlib import Path
19
+
20
+ sys.path.insert(0, str(Path(__file__).resolve().parent.parent))
21
+
22
+ import faiss
23
+ import numpy as np
24
+ import yaml
25
+
26
+ logging.basicConfig(level=logging.INFO, format="%(asctime)s [%(levelname)s] %(message)s")
27
+ logger = logging.getLogger(__name__)
28
+
29
+
30
def load_config() -> dict:
    """Parse ``config.yaml`` from the current working directory and return the result."""
    with open("config.yaml", "r", encoding="utf-8") as config_file:
        parsed = yaml.safe_load(config_file)
    return parsed
33
+
34
+
35
def load_new_chunks(path: str) -> list[dict]:
    """Read a JSONL file and return one decoded object per non-blank line.

    Lines are stripped of surrounding whitespace first; lines that are empty
    after stripping are ignored. A malformed line raises the ``json`` error.
    """
    with open(path, "r", encoding="utf-8") as handle:
        stripped_lines = (raw.strip() for raw in handle)
        records = [json.loads(text) for text in stripped_lines if text]
    logger.info("Loaded %d new chunks from %s", len(records), path)
    return records
44
+
45
+
46
def embed_chunks(chunks: list[dict], model_name: str) -> np.ndarray:
    """Encode each chunk's ``chunk_text`` with the named SentenceTransformer model.

    The encoder is asked for normalized NumPy embeddings, and the result is
    returned as ``float32``. Every chunk must carry a ``chunk_text`` key.
    """
    # Imported lazily so the heavy dependency is only loaded when embedding runs.
    from sentence_transformers import SentenceTransformer

    encoder = SentenceTransformer(model_name)
    passages = [chunk["chunk_text"] for chunk in chunks]
    logger.info("Embedding %d new chunks with %s...", len(passages), model_name)
    encode_options = {
        "batch_size": 32,
        "show_progress_bar": True,
        "normalize_embeddings": True,
        "convert_to_numpy": True,
    }
    vectors = encoder.encode(passages, **encode_options)
    return vectors.astype(np.float32)
59
+
60
+
61
def _invalidate_bm25_cache(meta_path: str) -> None:
    """Delete the ``bm25_cache.pkl`` file stored next to the metadata pickle, if any."""
    bm25_cache = Path(meta_path).parent / "bm25_cache.pkl"
    if bm25_cache.exists():
        bm25_cache.unlink()
        logger.info("BM25 cache invalidated — will rebuild on next startup.")


def main() -> None:
    """Add new chunks to the existing FAISS index, or force-update stored text.

    Modes (selected by CLI flags):

    * default: embed the chunks from ``--input`` that are not already in the
      metadata store, append them to the index and metadata, and append them
      to ``data/processed/chunks.jsonl``.
    * ``--force-update-section KW``: overwrite ``chunk_text`` of already-stored
      chunks whose input text contains ``KW`` (case-insensitive); no vectors
      are added or re-embedded.
    * ``--dry-run``: report what would happen without writing anything.

    Exits with status 1 when the index or the metadata store is missing.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--input", required=True, help="JSONL file of new chunks")
    parser.add_argument("--dry-run", action="store_true",
                        help="Show what would be added without writing to disk")
    parser.add_argument("--force-update-section", default=None,
                        help="Force-update chunk_text for existing chunks matching this section keyword (e.g. 'ADVERSE REACTIONS')")
    args = parser.parse_args()

    cfg = load_config()
    idx_path = cfg["retrieval"]["index_path"]
    meta_path = cfg["retrieval"]["metadata_path"]
    model_name = cfg["retrieval"]["embedding_model"]

    if not Path(idx_path).exists():
        logger.error("FAISS index not found at %s. Run embedder.py first.", idx_path)
        sys.exit(1)
    # Same clean failure for the metadata store instead of a raw traceback.
    if not Path(meta_path).exists():
        logger.error("Metadata store not found at %s. Run embedder.py first.", meta_path)
        sys.exit(1)

    # Load existing index + metadata
    logger.info("Loading existing FAISS index from %s ...", idx_path)
    index = faiss.read_index(idx_path)
    existing_count = index.ntotal
    logger.info("Existing index: %d vectors", existing_count)

    # NOTE: pickle can execute arbitrary code while loading; the metadata file
    # must come from this project's own tooling.
    with open(meta_path, "rb") as f:
        metadata_store: dict[int, dict] = pickle.load(f)

    all_input_chunks = load_new_chunks(args.input)

    # Force-update existing chunk_text for a specific section (no new FAISS vectors needed)
    if args.force_update_section:
        section_kw = args.force_update_section.upper()
        # Primary lookup: chunk_id -> store key. Entries without a chunk_id are
        # left out so an input chunk that also lacks one cannot match them.
        id_to_meta = {
            v["chunk_id"]: k for k, v in metadata_store.items() if v.get("chunk_id")
        }
        # Secondary lookup: (doc_id, chunk_index) -> store key. Entries without
        # a doc_id are left out for the same reason (they would all collapse
        # onto the key ("", 0)).
        docidx_to_meta = {
            (v["doc_id"], v.get("chunk_index", 0)): k
            for k, v in metadata_store.items()
            if v.get("doc_id")
        }
        updated = 0
        for chunk in all_input_chunks:
            if section_kw not in chunk.get("chunk_text", "").upper():
                continue
            chunk_id = chunk.get("chunk_id")
            faiss_key = id_to_meta.get(chunk_id) if chunk_id else None
            if faiss_key is None:
                doc_id = chunk.get("doc_id", "")
                if doc_id:
                    faiss_key = docidx_to_meta.get((doc_id, chunk.get("chunk_index", 0)))
            if faiss_key is not None:
                metadata_store[faiss_key]["chunk_text"] = chunk["chunk_text"]
                updated += 1
        logger.info("Force-updated chunk_text for %d '%s' entries", updated, section_kw)
        if not args.dry_run:
            with open(meta_path, "wb") as f:
                pickle.dump(metadata_store, f, protocol=pickle.HIGHEST_PROTOCOL)
            logger.info("Metadata store saved.")
            _invalidate_bm25_cache(meta_path)
        return

    # Deduplicate — skip chunks already in the index.
    # Primary key: chunk_id. Secondary key: (doc_id, chunk_index) handles
    # re-ingestion of the same document with new UUIDs (e.g. FDA label updates).
    existing_ids = {v.get("chunk_id", "") for v in metadata_store.values()}
    existing_docidx = {
        (v.get("doc_id", ""), v.get("chunk_index", -1))
        for v in metadata_store.values()
        if v.get("doc_id") and v.get("chunk_index", -1) >= 0
    }

    def _is_duplicate(c: dict) -> bool:
        if c.get("chunk_id") in existing_ids:
            return True
        key = (c.get("doc_id", ""), c.get("chunk_index", -1))
        return key[0] != "" and key[1] >= 0 and key in existing_docidx

    new_chunks: list[dict] = []
    for c in all_input_chunks:
        if _is_duplicate(c):
            continue
        new_chunks.append(c)
        # Remember what was just accepted so a chunk repeated later in the
        # same input file is not embedded and stored twice.
        if c.get("chunk_id"):
            existing_ids.add(c["chunk_id"])
        doc_key = (c.get("doc_id", ""), c.get("chunk_index", -1))
        if doc_key[0] and doc_key[1] >= 0:
            existing_docidx.add(doc_key)

    if not new_chunks:
        logger.info("All chunks already in index. Nothing to add.")
        return

    logger.info("%d new chunks to add (%d duplicates skipped)",
                len(new_chunks), len(all_input_chunks) - len(new_chunks))

    if args.dry_run:
        logger.info("DRY RUN — no changes written.")
        for c in new_chunks[:5]:
            logger.info("  Would add: %s | %s", c.get("chunk_id"), c.get("title", "")[:60])
        return

    # Embed new chunks only
    embeddings = embed_chunks(new_chunks, model_name)

    # Add to existing FAISS index
    index.add(embeddings)
    logger.info("Index now has %d vectors (+%d)", index.ntotal, len(new_chunks))

    # Extend metadata store (new keys start from existing_count)
    for i, chunk in enumerate(new_chunks):
        metadata_store[existing_count + i] = {
            "chunk_id": chunk.get("chunk_id", f"chunk_{existing_count + i}"),
            "doc_id": chunk.get("doc_id", ""),
            "source": chunk.get("source", ""),
            "title": chunk.get("title", ""),
            "pub_type": chunk.get("pub_type", "unknown"),
            "pub_year": chunk.get("pub_year"),
            "journal": chunk.get("journal", ""),
            "chunk_index": chunk.get("chunk_index", 0),
            "total_chunks": chunk.get("total_chunks", 1),
            "chunk_text": chunk.get("chunk_text", ""),
        }

    # Save updated artifacts
    faiss.write_index(index, idx_path)
    logger.info("FAISS index saved to %s", idx_path)

    with open(meta_path, "wb") as f:
        pickle.dump(metadata_store, f, protocol=pickle.HIGHEST_PROTOCOL)
    logger.info("Metadata store saved (%d total entries)", len(metadata_store))

    # The store now holds chunks the cached BM25 index has never seen; drop the
    # cache here too, exactly as the force-update path does.
    _invalidate_bm25_cache(meta_path)

    # Also append to chunks.jsonl for future full rebuilds
    chunks_jsonl = Path("data/processed/chunks.jsonl")
    with open(chunks_jsonl, "a", encoding="utf-8") as f:
        for chunk in new_chunks:
            f.write(json.dumps(chunk) + "\n")
    logger.info("Appended %d chunks to %s", len(new_chunks), chunks_jsonl)

    logger.info("Done. Restart the backend to reload the updated index.")
189
+
190
+
191
+ if __name__ == "__main__":
192
+ main()
scripts/warmup.py ADDED
@@ -0,0 +1,58 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ scripts/warmup.py
3
+ =====================
4
+ Pre-loads heavy ML models (FAISS, DeBERTa, SciSpaCy) into memory
5
+ and guarantees instantaneous responses for the first API request during the live demo.
6
+
7
+ Usage:
8
+ python scripts/warmup.py
9
+ """
10
+
11
+ import sys
12
+ from pathlib import Path
13
+
14
+ sys.path.insert(0, str(Path(__file__).resolve().parent.parent))
15
+
16
+ import logging
17
+ import time
18
+ import requests
19
+
20
+ logging.basicConfig(level=logging.INFO, format="%(asctime)s [%(levelname)s] %(name)s: %(message)s")
21
+ logger = logging.getLogger("warmup")
22
+
23
def main():
    """Check that the local API is up, then send one query to warm its models.

    Exits with status 1 when the health check or the warmup request fails, so
    callers (shell scripts, CI steps) can tell the warmup did not happen.
    """
    api_url = "http://localhost:8000"

    logger.info("Verifying API is running...")
    try:
        health = requests.get(f"{api_url}/health", timeout=5)
        health.raise_for_status()
        logger.info(f"API Health: {health.json()}")
    except requests.exceptions.RequestException as e:
        logger.error(f"API is not running at {api_url}. Please start it with 'uvicorn src.api.main:app' first.")
        # Keep the underlying reason visible (connection refused, HTTP 5xx, ...).
        logger.error("Health check error: %s", e)
        sys.exit(1)

    logger.info("Sending WARMUP query to load DeBERTa, SciSpaCy, and FAISS into RAM... (This may take 15-25s)")
    t0 = time.time()

    # We send a basic query to force all models to initialize
    payload = {
        "question": "What is the recommended dosage of Metformin for elderly Type 2 Diabetes patients?",
        "top_k": 1,
        "run_ragas": False
    }

    try:
        resp = requests.post(f"{api_url}/query", json=payload, timeout=60)
        resp.raise_for_status()
    except requests.exceptions.RequestException as e:
        logger.error(f"Warmup failed: {e}")
        # HTTP errors carry the server's response; connection errors do not.
        if e.response is not None:
            logger.error(f"Response: {e.response.text}")
        # A failed warmup must not look like success to the caller.
        sys.exit(1)

    elapsed = time.time() - t0
    logger.info(f"Warmup successful in {elapsed:.1f}s!")
    logger.info("All machine learning models are now cached in RAM.")
    logger.info("The next API requests will be completely instantaneous.")
56
+
57
+ if __name__ == "__main__":
58
+ main()
setup.py ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from setuptools import setup, find_packages
2
+
3
# Runtime dependencies declared for the installed CLI package.
_runtime_requirements = ["typer>=0.9.0"]

# Installs a `medirag` command that dispatches to `app` in `src.cli`.
_entry_points = {"console_scripts": ["medirag=src.cli:app"]}

setup(
    name="medirag-cli",
    version="0.1.0",
    packages=find_packages(),
    install_requires=_runtime_requirements,
    entry_points=_entry_points,
)
src/__init__.py ADDED
@@ -0,0 +1,44 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ src/__init__.py — Package initializer and logging setup.
3
+ Runs once on first `import src`. Sets up logging from config.yaml.
4
+ (SRS Section 13)
5
+ """
6
+ import logging
7
+ import os
8
+
9
+
10
def _setup_logging() -> None:
    """Configure the root logger from ``config.yaml``, falling back to defaults.

    Always tries to create the ``logs`` directory (best effort). If the root
    logger already has handlers, nothing else is changed. Otherwise a stream
    handler is installed, plus a file handler when the log file can be opened.
    This function never raises for filesystem problems, because it runs as a
    side effect of ``import src``.
    """
    # Best effort only: a read-only or otherwise unusable working directory
    # must not make the package import fail.
    try:
        os.makedirs("logs", exist_ok=True)
    except OSError:
        pass

    log_level = logging.INFO
    log_format = "%(asctime)s [%(levelname)s] %(name)s: %(message)s"
    log_file = "logs/medirag.log"

    # Try to load level from config.yaml
    try:
        import yaml
        with open("config.yaml", "r", encoding="utf-8") as f:
            cfg = yaml.safe_load(f)
        log_cfg = cfg.get("logging", {})
        level_str = log_cfg.get("level", "INFO")
        log_level = getattr(logging, str(level_str).upper(), logging.INFO)
        log_file = log_cfg.get("file", log_file)
        log_format = log_cfg.get("format", log_format)
    except Exception:
        # Deliberately broad: a missing yaml package, a missing or malformed
        # config file, or an unexpected config shape all mean "use defaults"
        # (e.g., during tests).
        pass

    root = logging.getLogger()
    if root.handlers:
        return  # Already configured — don't add duplicate handlers

    handlers: list[logging.Handler] = [logging.StreamHandler()]
    try:
        log_dir = os.path.dirname(log_file)
        # A bare filename has no directory part; os.makedirs("") would raise
        # and silently disable file logging.
        if log_dir:
            os.makedirs(log_dir, exist_ok=True)
        handlers.append(logging.FileHandler(log_file, encoding="utf-8"))
    except (OSError, TypeError, ValueError):
        pass  # File logging optional — don't fail on permission errors

    logging.basicConfig(level=log_level, format=log_format, handlers=handlers)
42
+
43
+
44
+ _setup_logging()
src/api/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ # src/api/__init__.py
src/api/main.py ADDED
@@ -0,0 +1,933 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ src/api/main.py — MediRAG FastAPI Application
3
+ =============================================
4
+ FR-18: Two endpoints:
5
+ GET /health → liveness check + Ollama status
6
+ POST /evaluate → calls run_evaluation(), returns FR-17 JSON
7
+
8
+ Design decisions:
9
+ - DeBERTa model is loaded once at app startup (not per-request)
10
+ - If any module raises an exception, partial results are returned (no HTTP 500)
11
+ - HTTP 422 Pydantic validation errors are automatic
12
+ - RAGAS is disabled by default (run_ragas=False) — set to True only if
13
+ Ollama/OpenAI is available; the RAGAS module already fails gracefully.
14
+
15
+ To run:
16
+ uvicorn src.api.main:app --reload --host 0.0.0.0 --port 8000
17
+ """
18
+ from __future__ import annotations
19
+
20
+ import os
21
+ import logging
22
+ import time
23
+ from contextlib import asynccontextmanager
24
+ from pathlib import Path
25
+ from typing import Optional
26
+
27
+ import requests
28
+ import json
29
+ import sqlite3
30
+ import yaml
31
+ from datetime import datetime
32
+ from fastapi import FastAPI, HTTPException, File, UploadFile
33
+ from fastapi.middleware.cors import CORSMiddleware
34
+ from fastapi.responses import RedirectResponse
35
+
36
+ import threading
37
+ from src.api.schemas import (
38
+ HealthResponse,
39
+ EvaluateRequest,
40
+ EvaluateResponse,
41
+ QueryRequest,
42
+ QueryResponse,
43
+ RetrievedChunk,
44
+ IngestRequest,
45
+ ChatRequest,
46
+ ModuleScore,
47
+ ModuleResults,
48
+ )
49
+ from src.evaluate import run_evaluation
50
+ from src.pipeline.generator import generate_answer
51
+ from src.pipeline.retriever import Retriever
52
+
53
+ # ---------------------------------------------------------------------------
54
+ # Logging
55
+ # ---------------------------------------------------------------------------
56
# Load config.yaml once at import time. `_cfg` is ALWAYS bound afterwards
# (to an empty dict when the file is missing, unreadable, empty, or not a
# mapping) because the request handlers and the lifespan hook pass it on.
try:
    _cfg = yaml.safe_load(Path("config.yaml").read_text()) or {}
except Exception:
    # Deliberately broad: any failure to read or parse the file means defaults.
    _cfg = {}
if not isinstance(_cfg, dict):
    _cfg = {}

try:
    _log_level = _cfg.get("logging", {}).get("level", "INFO")
    _ollama_base = _cfg.get("llm", {}).get("base_url", "http://localhost:11434")
    _api_cfg = _cfg.get("api", {})
except (AttributeError, TypeError):
    # A section has an unexpected shape (e.g. `logging:` with no mapping).
    _log_level = "INFO"
    _ollama_base = "http://localhost:11434"
    _api_cfg = {}

logging.basicConfig(
    # Upper-cased so "info" and "INFO" are both accepted, as in src/__init__.py.
    level=getattr(logging, str(_log_level).upper(), logging.INFO),
    format="%(asctime)s [%(levelname)s] %(name)s: %(message)s",
)
logger = logging.getLogger(__name__)
71
+
72
+
73
+ # ---------------------------------------------------------------------------
74
+ # Database settings
75
+ # ---------------------------------------------------------------------------
76
+
77
def init_db():
    """Create ``data/logs.db`` and its ``audit_logs`` table if they do not exist.

    Safe to call repeatedly: the directory is created with ``exist_ok`` and the
    table with ``IF NOT EXISTS``. The connection is closed even when the
    statement fails.
    """
    Path("data").mkdir(exist_ok=True)
    conn = sqlite3.connect("data/logs.db")
    try:
        conn.execute("""
            CREATE TABLE IF NOT EXISTS audit_logs (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                timestamp TEXT,
                endpoint TEXT,
                question TEXT,
                answer TEXT,
                hrs INTEGER,
                risk_band TEXT,
                composite_score REAL,
                latency_ms INTEGER,
                intervention_applied BOOLEAN,
                details TEXT
            )
        """)
        conn.commit()
    finally:
        # Release the file handle on the error path as well.
        conn.close()
98
+
99
def log_audit(endpoint: str, question: str, answer: str, hrs: int, risk_band: str, composite: float, latency: int, intervention: bool, details: dict):
    """Append one row to the ``audit_logs`` table in ``data/logs.db``.

    Best effort by design: any failure is logged and swallowed so that audit
    logging can never break the request that triggered it. ``details`` is
    stored as JSON text; values JSON cannot encode are stored as ``str(value)``.
    The timestamp is the current UTC time formatted ``YYYY-MM-DDTHH:MM:SSZ``.
    """
    # Local import: the module header only imports `datetime` itself.
    from datetime import timezone

    conn = None
    try:
        conn = sqlite3.connect("data/logs.db")
        conn.execute("""
            INSERT INTO audit_logs (timestamp, endpoint, question, answer, hrs, risk_band, composite_score, latency_ms, intervention_applied, details)
            VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
        """, (
            # Aware UTC clock; datetime.utcnow() is deprecated since Python 3.12.
            datetime.now(timezone.utc).strftime("%Y-%m-%dT%H:%M:%SZ"),
            endpoint,
            question,
            answer,
            hrs,
            risk_band,
            composite,
            latency,
            intervention,
            # default=str: one exotic value inside the details must not cost
            # the whole audit record.
            json.dumps(details, default=str),
        ))
        conn.commit()
    except Exception as e:
        logger.error("Failed to save audit log to DB: %s", e)
    finally:
        # Close on the error path too, so failed writes do not leak handles.
        if conn is not None:
            conn.close()
122
+
123
+
124
+ # ---------------------------------------------------------------------------
125
+ # Lifespan: warm DeBERTa once at startup so the first request isn't slow
126
+ # ---------------------------------------------------------------------------
127
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Startup/shutdown hook: prepare the audit DB and pre-load the heavy models.

    Both warm-up steps are best effort. A failure is logged as a warning and
    the app still starts; when the retriever cannot be prepared,
    ``app.state.retriever`` is set to ``None``.
    """
    init_db()
    logger.info("MediRAG API starting — pre-warming models...")

    # --- faithfulness model ---------------------------------------------
    try:
        from src.modules.faithfulness import _get_model
        _get_model()
        logger.info("DeBERTa pre-warm complete.")
    except Exception as warm_err:
        logger.warning("DeBERTa pre-warm skipped: %s", warm_err)

    # --- retriever (kept on app.state for the request handlers) -----------
    try:
        warm_retriever = Retriever(_cfg)
        app.state.retriever = warm_retriever
        # Force the lazy loads now so the first /query request isn't slow.
        warm_retriever._load_model()
        warm_retriever._load_index()
        logger.info("Retriever pre-warm complete.")
    except Exception as warm_err:
        logger.warning("Retriever pre-warm skipped: %s", warm_err)
        app.state.retriever = None

    yield
    logger.info("MediRAG API shutting down.")
152
+
153
+
154
+
155
+ # ---------------------------------------------------------------------------
156
+ # App
157
+ # ---------------------------------------------------------------------------
158
# Text shown at the top of the generated API documentation.
_api_description = (
    "Evaluate LLM-generated medical answers against retrieved evidence. "
    "Returns faithfulness, entity accuracy, source credibility, "
    "contradiction risk, and a composite Health Risk Score (HRS)."
)

app = FastAPI(
    title="MediRAG Evaluation API",
    description=_api_description,
    version="0.1.0",
    lifespan=lifespan,
)

# Allow all origins for local dev / React frontend on same machine
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["GET", "POST"],
    allow_headers=["*"],
)
176
+
177
+
178
+ # ---------------------------------------------------------------------------
179
+ # Helper: check Ollama
180
+ # ---------------------------------------------------------------------------
181
def _check_ollama() -> bool:
    """Report whether ``GET {_ollama_base}/api/tags`` answers 200 within 2 seconds.

    Any exception raised while probing counts as "not reachable".
    """
    try:
        probe = requests.get(f"{_ollama_base}/api/tags", timeout=2)
    except Exception:
        return False
    return probe.status_code == 200
188
+
189
+
190
+ # ---------------------------------------------------------------------------
191
+ # Helper: convert EvalResult details → ModuleScore
192
+ # ---------------------------------------------------------------------------
193
def _module_score(module_results: dict, key: str) -> Optional[ModuleScore]:
    """Build a ``ModuleScore`` from ``module_results[key]``.

    Returns ``None`` when the key is absent or maps to ``None``. A missing
    ``score`` defaults to 0.0 and missing ``details`` to an empty dict;
    ``error`` and ``latency_ms`` are passed through as found.
    """
    entry = module_results.get(key)
    if entry is not None:
        fields = {
            "score": entry.get("score", 0.0),
            "details": entry.get("details", {}),
            "error": entry.get("error"),
            "latency_ms": entry.get("latency_ms"),
        }
        return ModuleScore(**fields)
    return None
203
+
204
+
205
+ # ---------------------------------------------------------------------------
206
+ # POST /project-guide → Groq chat proxy; GET / → redirect to /docs
207
+ # ---------------------------------------------------------------------------
208
@app.post("/project-guide")
def project_guide(req: ChatRequest):
    """
    Proxy endpoint for the Project Guide chatbot.
    Routes requests to Groq API using the local GROQ_API_KEY.
    """
    groq_url = "https://api.groq.com/openai/v1/chat/completions"

    api_key = os.getenv("GROQ_API_KEY")
    if not api_key:
        raise HTTPException(status_code=500, detail="GROQ_API_KEY not found in server environment")

    request_headers = {
        "Authorization": f"Bearer {api_key}",
        "Content-Type": "application/json"
    }

    # Optional system prompt first, then the conversation in its original order.
    chat_messages = []
    if req.system_prompt:
        chat_messages.append({"role": "system", "content": req.system_prompt})
    chat_messages.extend(
        {"role": turn.role, "content": turn.content} for turn in req.messages
    )

    request_body = {
        "model": "mixtral-8x7b-32768",
        "messages": chat_messages,
        "temperature": 0.5,
        "max_tokens": 1024
    }

    try:
        upstream = requests.post(groq_url, headers=request_headers, json=request_body, timeout=30)
        upstream.raise_for_status()
        return upstream.json()
    except Exception as proxy_err:
        # Any upstream or decoding failure is reported to the client as a 500.
        logger.error("Groq Proxy Error: %s", proxy_err)
        raise HTTPException(status_code=500, detail=str(proxy_err))
247
+
248
+
249
@app.get("/", include_in_schema=False)
def root():
    """Send visitors of the bare root URL to the interactive docs at /docs."""
    docs_redirect = RedirectResponse(url="/docs")
    return docs_redirect
252
+
253
+
254
+ # ---------------------------------------------------------------------------
255
+ # GET /health
256
+ # ---------------------------------------------------------------------------
257
@app.get("/health", response_model=HealthResponse, tags=["system"])
def health() -> HealthResponse:
    """
    Liveness check.

    Returns {"status": "ok", "ollama_available": true/false}.
    Always returns 200 — the caller decides what to do with `ollama_available`.
    """
    # Probe the LLM backend first, then report it alongside the fixed status.
    ollama_up = _check_ollama()
    return HealthResponse(status="ok", ollama_available=ollama_up)
269
+
270
+
271
+ # ---------------------------------------------------------------------------
272
+ # POST /evaluate
273
+ # ---------------------------------------------------------------------------
274
@app.post("/evaluate", response_model=EvaluateResponse, tags=["evaluation"])
def evaluate(req: EvaluateRequest) -> EvaluateResponse:
    """
    Run the full MediRAG evaluation pipeline on a question + answer + context.

    - Validates inputs (FR-18: length limits, chunk count)
    - Runs Faithfulness, Entity Verification, Source Credibility, Contradiction
    - Optionally runs RAGAS (set `run_ragas=true` if Ollama/OpenAI is available)
    - Returns composite Health Risk Score (HRS) + per-module breakdown

    **Note on `run_ragas`**: RAGAS requires a running LLM backend (Ollama or
    OpenAI). If unavailable, RAGAS will gracefully return score=0.5 as a
    neutral fallback — it will NOT crash the request.
    """
    logger.info(
        "POST /evaluate — question=%r, chunks=%d, run_ragas=%s",
        req.question[:80], len(req.context_chunks), req.run_ragas,
    )

    # The pipeline works on plain dicts rather than Pydantic models.
    context_dicts: list[dict] = []
    for chunk in req.context_chunks:
        context_dicts.append(chunk.model_dump(exclude_none=True))

    started = time.perf_counter()
    try:
        result = run_evaluation(
            question=req.question,
            answer=req.answer,
            context_chunks=context_dicts,
            rxnorm_cache_path=req.rxnorm_cache_path,
            run_ragas=req.run_ragas,
            config=_cfg,
        )
    except Exception as exc:
        logger.exception("run_evaluation raised an unhandled exception: %s", exc)
        failure_detail = f"Evaluation pipeline error: {type(exc).__name__}: {exc}"
        raise HTTPException(status_code=500, detail=failure_detail) from exc
    wall_clock_ms = int((time.perf_counter() - started) * 1000)

    # Composite score and the summary fields derived from the result details.
    composite = float(result.score)
    details = result.details or {}
    derived_hrs = int(round(100 * (1.0 - composite)))
    hrs = details.get("hrs", derived_hrs)
    hrs = max(0, min(100, hrs))  # keep the score inside 0..100

    confidence_level = details.get("confidence_level", "UNKNOWN")
    risk_band = details.get("risk_band", "UNKNOWN")
    # Prefer the pipeline's own timing; fall back to the time measured here.
    pipeline_ms = details.get("total_pipeline_ms", wall_clock_ms)

    # Per-module breakdown; modules missing from the result become None.
    mod_results: dict = details.get("module_results", {})
    module_names = (
        "faithfulness",
        "entity_verifier",
        "source_credibility",
        "contradiction",
        "ragas",
    )
    module_scores = ModuleResults(
        **{name: _module_score(mod_results, name) for name in module_names}
    )

    logger.info("POST /evaluate → HRS=%d (%s) in %d ms", hrs, risk_band, pipeline_ms)

    audit_details = {
        "module_results": mod_results,
        "confidence_level": confidence_level
    }
    log_audit("evaluate", req.question, req.answer, hrs, risk_band, composite,
              pipeline_ms, False, audit_details)

    return EvaluateResponse(
        composite_score=composite,
        hrs=hrs,
        confidence_level=confidence_level,
        risk_band=risk_band,
        module_results=module_scores,
        total_pipeline_ms=pipeline_ms,
    )
355
+
356
+
357
+ # ---------------------------------------------------------------------------
358
+ # POST /query — end-to-end: question → retrieve → generate → evaluate
359
+ # ---------------------------------------------------------------------------
360
+ @app.post("/query", response_model=QueryResponse, tags=["query"])
361
+ def query(req: QueryRequest) -> QueryResponse:
362
+ """
363
+ Full end-to-end MediRAG pipeline.
364
+
365
+ 1. Retrieves top-k context chunks from FAISS (BioBERT)
366
+ 2. Generates a grounded answer using Mistral (Ollama)
367
+ 3. Evaluates the answer with all 4 modules + aggregator
368
+ 4. Returns the answer, retrieved chunks, HRS score, and full breakdown
369
+
370
+ **Requires Ollama running locally with Mistral pulled.**
371
+ No fallback — returns 503 if Ollama is unavailable.
372
+ """
373
+ import time as _time
374
+ t_total = _time.perf_counter()
375
+
376
+ logger.info("POST /query — question=%r, top_k=%d", req.question[:80], req.top_k)
377
+
378
+ # Step 1: Retrieve
379
+ retriever: Optional[Retriever] = getattr(app.state, "retriever", None)
380
+ if retriever is None:
381
+ # Fallback: instantiate now (slower first call)
382
+ try:
383
+ retriever = Retriever(_cfg)
384
+ except Exception as exc:
385
+ raise HTTPException(status_code=503,
386
+ detail=f"Retriever unavailable: {exc}") from exc
387
+
388
+ try:
389
+ raw_results = retriever.search(req.question, top_k=req.top_k)
390
+ except FileNotFoundError as exc:
391
+ raise HTTPException(status_code=503,
392
+ detail=f"FAISS index not found: {exc}") from exc
393
+ except Exception as exc:
394
+ raise HTTPException(status_code=500,
395
+ detail=f"Retrieval error: {exc}") from exc
396
+
397
+ if not raw_results:
398
+ raise HTTPException(status_code=404,
399
+ detail="No relevant documents found for this question.")
400
+
401
+ # Convert retriever output → chunk dicts for generator + evaluate
402
+ context_chunks: list[dict] = []
403
+ retrieved_chunks_out: list[RetrievedChunk] = []
404
+ for chunk_text, meta, score in raw_results:
405
+ d = {
406
+ "text": chunk_text,
407
+ "chunk_id": meta.get("chunk_id"),
408
+ "source": meta.get("source", ""),
409
+ "pub_type": meta.get("pub_type", ""),
410
+ "pub_year": meta.get("pub_year"),
411
+ "title": meta.get("title", ""),
412
+ }
413
+ context_chunks.append(d)
414
+ retrieved_chunks_out.append(RetrievedChunk(
415
+ chunk_id=meta.get("chunk_id"),
416
+ text=chunk_text[:500], # truncate for response readability
417
+ source=meta.get("source", ""),
418
+ pub_type=meta.get("pub_type", ""),
419
+ pub_year=meta.get("pub_year"),
420
+ title=meta.get("title", ""),
421
+ similarity_score=round(score, 4),
422
+ ))
423
+
424
+ logger.info("Retrieved %d chunks (top score=%.4f)", len(context_chunks),
425
+ raw_results[0][2] if raw_results else 0.0)
426
+
427
+ # Raw FAISS cosine similarity for coverage gap gate.
428
+ # IndexFlatIP + L2-norm = cosine in [-1, 1]. < 0.60 means no close semantic match in DB.
429
+ top_faiss_cosine = (
430
+ raw_results[0][1].get("_top_faiss_cosine", 0.0) if raw_results else 0.0
431
+ )
432
+
433
+ # Convert request overrides into a dict for generator
434
+ llm_overrides = {}
435
+ if req.llm_provider:
436
+ llm_overrides["provider"] = req.llm_provider
437
+ if req.llm_api_key:
438
+ llm_overrides["api_key"] = req.llm_api_key
439
+ if req.llm_model:
440
+ llm_overrides["model"] = req.llm_model
441
+ if req.ollama_url:
442
+ llm_overrides["ollama_url"] = req.ollama_url
443
+ if req.system_prompt:
444
+ llm_overrides["system_prompt"] = req.system_prompt
445
+ if req.persona:
446
+ llm_overrides["persona"] = req.persona
447
+
448
+ # =========================================================================
449
+ # Step 2a: PRIVACY SHIELD — MediRAG redacts PHI (Option 1)
450
+ # =========================================================================
451
+ p_mapping = {}
452
+ privacy_applied = False
453
+ question_to_gen = req.question
454
+
455
+ if req.use_privacy_shield:
456
+ from src.pipeline.privacy import shield
457
+ question_to_gen, p_mapping = shield.redact(req.question)
458
+ if p_mapping:
459
+ privacy_applied = True
460
+ logger.info("PRIVACY INTERVENTION: Redacted %d items from question.", len(p_mapping))
461
+
462
+ # Step 2: Generate answer via LLM (Gemini or Ollama)
463
+ try:
464
+ # Use the potentially redacted question for generation
465
+ answer = generate_answer(question_to_gen, context_chunks, _cfg, overrides=llm_overrides)
466
+ except RuntimeError as exc:
467
+ raise HTTPException(status_code=503,
468
+ detail=f"LLM generation failed: {exc}") from exc
469
+
470
+ # Restore the PHI for the final display so the user sees the actual names
471
+ if privacy_applied:
472
+ from src.pipeline.privacy import shield
473
+ answer = shield.restore(answer, p_mapping)
474
+ # =========================================================================
475
+
476
+ # =========================================================================
477
+ # Step 2b: CONSENSUS CHECK — MediRAG compares multiple models (Option 2)
478
+ # =========================================================================
479
+ consensus_results = None
480
+ if req.use_consensus:
481
+ from src.pipeline.consensus import run_consensus_check
482
+ # Determine which providers to use based on available config/overrides
483
+ providers = ["gemini"]
484
+ if os.environ.get("GROQ_API_KEY"):
485
+ providers.append("groq")
486
+ elif os.environ.get("MISTRAL_API_KEY"):
487
+ providers.append("mistral")
488
+ else:
489
+ providers.append("ollama") # fallback to local if no second key
490
+
491
+ logger.info("Running Consensus Layer with %s", providers)
492
+ consensus_results = run_consensus_check(req.question, context_chunks, _cfg, providers=providers)
493
+
494
+ # If consensus finds a safer merged answer, we promote it
495
+ # and update the primary answer for the evaluation loop
496
+ answer = consensus_results.get("consensus_answer", answer)
497
+ # =========================================================================
498
+
499
+ # [DEMO MODE] Inject a false claim to demonstrate the intervention system
500
+ if req.inject_hallucination:
501
+ logger.warning("DEMO MODE: Injecting hallucinated claim into answer: '%s'",
502
+ req.inject_hallucination)
503
+ answer = answer + " " + req.inject_hallucination.strip()
504
+
505
+ # Step 3: Evaluate
506
+ try:
507
+ eval_result = run_evaluation(
508
+ question=req.question,
509
+ answer=answer,
510
+ context_chunks=context_chunks,
511
+ run_ragas=req.run_ragas,
512
+ config=_cfg,
513
+ )
514
+ except Exception as exc:
515
+ logger.exception("Evaluation failed: %s", exc)
516
+ try:
517
+ log_audit("query", req.question, answer, 100, "EVAL_ERROR", 0.0,
518
+ int((_time.perf_counter() - t_total) * 1000),
519
+ False, {"error": str(exc), "error_type": "evaluation_failure"})
520
+ except Exception:
521
+ pass
522
+ raise HTTPException(status_code=500,
523
+ detail=f"Evaluation error: {exc}") from exc
524
+
525
+ # =========================================================================
526
+ # Step 3b: INTERVENTION LOOP — MediRAG acts on evaluation results
527
+ # =========================================================================
528
+ from src.pipeline.generator import generate_strict_answer
529
+
530
+ details = eval_result.details or {}
531
+ composite = float(eval_result.score)
532
+ hrs = int(round(100 * (1.0 - composite)))
533
+ hrs = max(0, min(100, hrs))
534
+ mod_results: dict = details.get("module_results", {})
535
+
536
+ intervention_applied = False
537
+ intervention_reason = None
538
+ original_answer = None
539
+ intervention_details = None
540
+
541
+ faith_score = (mod_results.get("faithfulness") or {}).get("score", 1.0)
542
+
543
+ # Source-credibility-aware faith threshold: high-credibility sources get more tolerance
544
+ source_cred = float(details.get("component_scores", {}).get("source_credibility", 0.5))
545
+ faith_threshold = max(0.3, 0.7 - (source_cred * 0.4)) # 0.30 for cred=1.0, 0.66 for cred=0.3
546
+
547
+ # ── Coverage Gap Gate ────────────────────────────────────────────────────
548
+ # Two signals combined:
549
+ # 1. Refusal answer — LLM says "not in context / insufficient evidence"
550
+ # → LLM itself confirms the DB doesn't cover this topic.
551
+ # 2. FAISS cosine — genuinely poor semantic match vs. the query.
552
+ # BioBERT clusters medical dosing texts, so threshold must be high (0.75).
553
+ _REFUSAL_PATTERNS = (
554
+ "not mentioned in the provided context",
555
+ "not provided in the retrieved context",
556
+ "insufficient evidence in retrieved context",
557
+ "no information about",
558
+ "not in the provided context",
559
+ "cannot find information",
560
+ "the retrieved context does not contain",
561
+ "the context does not contain",
562
+ "not mentioned in the context",
563
+ "is not provided in the context",
564
+ )
565
+ _answer_lower = answer.lower()
566
+ is_refusal_answer = any(p in _answer_lower for p in _REFUSAL_PATTERNS)
567
+ is_low_faiss = top_faiss_cosine < 0.75
568
+
569
+ # If a verified drug with rxcui appears in the question, the intervention's
570
+ # FDA direct lookup can still retrieve the right data even when initial FAISS
571
+ # retrieval missed it. Don't label those as coverage gaps — let intervention run.
572
+ _ev_entities = (mod_results.get("entity_verifier") or {}).get("details", {}).get("entities", [])
573
+ _q_lower_cg = req.question.lower()
574
+ _drug_in_question = any(
575
+ e.get("rxcui") and e.get("entity", "").lower() in _q_lower_cg
576
+ for e in _ev_entities
577
+ )
578
+
579
+ # Refusal is a standalone COVERAGE_GAP signal — faith_score is unreliable here
580
+ # because NLI scores refusal sentences as NEUTRAL (0.5), not low.
581
+ # Exception: if a drug is named in the question, FDA lookup can still help.
582
+ # HALLUCINATION: specific claims made but not grounded in available context.
583
+ if is_refusal_answer and not _drug_in_question:
584
+ gap_type = "COVERAGE_GAP"
585
+ elif faith_score < faith_threshold and is_low_faiss and not _drug_in_question:
586
+ gap_type = "COVERAGE_GAP" # poor retrieval + low faith = DB lacks this topic
587
+ elif faith_score < faith_threshold:
588
+ gap_type = "HALLUCINATION" # relevant context exists but answer ignores it
589
+ else:
590
+ gap_type = None
591
+
592
+ coverage_gap = gap_type == "COVERAGE_GAP"
593
+ coverage_gap_details: dict | None = {
594
+ "gap_type": gap_type,
595
+ "top_faiss_cosine": round(top_faiss_cosine, 4),
596
+ "is_refusal_answer": is_refusal_answer,
597
+ "note": (
598
+ "Database coverage may be insufficient for this topic. "
599
+ "The answer could not be verified against retrieved evidence. "
600
+ "Consult primary medical literature or a specialist."
601
+ ) if coverage_gap else None,
602
+ } if gap_type else None
603
+ if coverage_gap:
604
+ logger.warning(
605
+ "COVERAGE_GAP detected — refusal=%s, faiss=%.4f, faith=%.2f",
606
+ is_refusal_answer, top_faiss_cosine, faith_score,
607
+ )
608
+
609
+ # Tier 1: CRITICAL BLOCK (HRS ≥ 86) — response is too dangerous to show
610
+ # Coverage gap: skip both tiers — regenerating from an empty DB won't help
611
+ if coverage_gap:
612
+ logger.info("COVERAGE_GAP — skipping intervention (regeneration cannot add missing data).")
613
+ elif hrs >= 86:
614
+ original_answer = answer
615
+ answer = (
616
+ "⛔ UNSAFE RESPONSE BLOCKED by MediRAG Safety Gate.\n\n"
617
+ "The generated answer was flagged as CRITICAL risk "
618
+ f"(Health Risk Score: {hrs}/100). "
619
+ "It showed signs of hallucination or contradiction with the retrieved evidence. "
620
+ "Please consult a qualified medical professional or rephrase your question."
621
+ )
622
+ intervention_applied = True
623
+ intervention_reason = "CRITICAL_BLOCKED"
624
+ intervention_details = {
625
+ "hrs_original": hrs,
626
+ "faithfulness": faith_score,
627
+ "message": "Response blocked: HRS ≥ 86 (CRITICAL risk band).",
628
+ }
629
+ logger.warning("INTERVENTION: CRITICAL_BLOCKED — HRS=%d", hrs)
630
+
631
+ # Tier 2: HIGH RISK REGENERATION
632
+ elif hrs >= 61 or faith_score < faith_threshold:
633
+ original_answer = answer
634
+ original_hrs = hrs
635
+ logger.warning(
636
+ "INTERVENTION: HIGH_RISK_REGENERATED — HRS=%d, faith=%.2f. Regenerating with strict prompt.",
637
+ hrs, faith_score
638
+ )
639
+ try:
640
+ # Re-retrieve from shared index — find better chunks than the ones that failed
641
+ try:
642
+ # Direct FDA lookup — only for drugs named in the question itself.
643
+ # Drugs found in the answer but NOT in the question (e.g. metformin
644
+ # mentioned incidentally in a general "first-line treatment" answer)
645
+ # should not trigger FDA lookup; that would replace relevant context
646
+ # with the wrong label sections (contraindications instead of treatment).
647
+ fda_direct: list[dict] = []
648
+ try:
649
+ ev_details = eval_result.details.get("module_results", {}).get("entity_verifier", {}).get("details", {})
650
+ verified_drugs = [
651
+ e["entity"] for e in ev_details.get("entities", [])
652
+ if e.get("status") == "VERIFIED" and e.get("rxcui")
653
+ ]
654
+ q_lower = req.question.lower()
655
+ for drug in verified_drugs:
656
+ if drug.lower() in q_lower:
657
+ fda_direct += app.state.retriever.get_fda_chunks(drug)
658
+ if fda_direct:
659
+ logger.info("Direct FDA lookup found %d chunks for drugs: %s",
660
+ len(fda_direct), [d for d in verified_drugs if d.lower() in q_lower])
661
+ except Exception as fda_exc:
662
+ logger.debug("Direct FDA lookup skipped: %s", fda_exc)
663
+
664
+ # Direct guideline lookup — only when original retrieval was poor.
665
+ # If FAISS cosine ≥ 0.85 the original chunks were already relevant;
666
+ # adding guideline sections here can pull in wrong topic areas
667
+ # (e.g., ADA Section 2 Diagnosis instead of Section 9 Treatment).
668
+ guideline_direct: list[dict] = []
669
+ if top_faiss_cosine < 0.85:
670
+ try:
671
+ guideline_direct = app.state.retriever.get_guideline_chunks(req.question)
672
+ if guideline_direct:
673
+ logger.info("Direct guideline lookup found %d chunks", len(guideline_direct))
674
+ except Exception as gl_exc:
675
+ logger.debug("Direct guideline lookup skipped: %s", gl_exc)
676
+ else:
677
+ logger.debug("Skipping guideline direct lookup (FAISS cosine=%.4f ≥ 0.85, original retrieval was high-quality)", top_faiss_cosine)
678
+
679
+ # Merge: guideline chunks + FDA chunks + fresh retrieval
680
+ fda_direct = guideline_direct + fda_direct
681
+
682
+ # For drug/clinical questions, expand query toward authoritative sources
683
+ _drug_terms = ("contraindication", "dosage", "dose", "interaction",
684
+ "warning", "adverse", "side effect", "mechanism")
685
+ _q_lower = req.question.lower()
686
+ retry_query = (
687
+ f"FDA drug label clinical guideline {req.question}"
688
+ if any(t in _q_lower for t in _drug_terms)
689
+ else req.question
690
+ )
691
+ fresh_results = app.state.retriever.search(retry_query, top_k=req.top_k)
692
+ fresh_chunks: list[dict] = []
693
+ for chunk_text, meta, score in fresh_results:
694
+ fresh_chunks.append({
695
+ "text": chunk_text, "chunk_id": meta.get("chunk_id"),
696
+ "source": meta.get("source", ""), "pub_type": meta.get("pub_type", ""),
697
+ "pub_year": meta.get("pub_year"), "title": meta.get("title", ""),
698
+ })
699
+ # Merge: direct lookups first (FDA/guidelines), then fresh retrieval
700
+ base_chunks = fresh_chunks if fresh_chunks else context_chunks
701
+ retry_chunks = (fda_direct + base_chunks)[:req.top_k] if fda_direct else base_chunks
702
+ logger.info("Re-retrieval for intervention: %d fresh chunks (top source: %s)",
703
+ len(retry_chunks),
704
+ retry_chunks[0].get("pub_type", "?") if retry_chunks else "none")
705
+ except Exception:
706
+ retry_chunks = context_chunks
707
+
708
+ answer = generate_strict_answer(req.question, retry_chunks, _cfg, overrides=llm_overrides)
709
+ # Re-evaluate the corrected answer
710
+ eval_result = run_evaluation(
711
+ question=req.question,
712
+ answer=answer,
713
+ context_chunks=retry_chunks,
714
+ run_ragas=False, # skip RAGAS on retry to reduce latency
715
+ config=_cfg,
716
+ )
717
+ details = eval_result.details or {}
718
+ composite = float(eval_result.score)
719
+ hrs = int(round(100 * (1.0 - composite)))
720
+ hrs = max(0, min(100, hrs))
721
+ mod_results = details.get("module_results", {})
722
+ except Exception as exc:
723
+ logger.error("Strict regeneration failed: %s — keeping original answer", exc)
724
+ answer = original_answer # fall back gracefully
725
+ original_answer = None
726
+
727
+ intervention_applied = True
728
+ intervention_reason = "HIGH_RISK_REGENERATED"
729
+ intervention_details = {
730
+ "hrs_original": original_hrs,
731
+ "hrs_corrected": hrs,
732
+ "faithfulness_original": faith_score,
733
+ "faithfulness_corrected": (mod_results.get("faithfulness") or {}).get("score", 0),
734
+ "message": "Response regenerated with strict context-only prompt due to high risk score.",
735
+ }
736
+ # =========================================================================
737
+
738
+ # Step 4: Build response
739
+ total_ms = int((_time.perf_counter() - t_total) * 1000)
740
+ logger.info("POST /query → HRS=%d (%s) intervention=%s in %d ms total",
741
+ hrs, details.get("risk_band", "?"), intervention_reason or "none", total_ms)
742
+
743
+ log_audit("query", req.question, answer, hrs, details.get("risk_band", "UNKNOWN"), composite, total_ms, intervention_applied, {
744
+ "module_results": mod_results,
745
+ "confidence_level": details.get("confidence_level", "UNKNOWN"),
746
+ "intervention_reason": intervention_reason,
747
+ "original_answer": original_answer,
748
+ })
749
+
750
+ return QueryResponse(
751
+ question=req.question,
752
+ generated_answer=answer,
753
+ retrieved_chunks=retrieved_chunks_out,
754
+ composite_score=composite,
755
+ hrs=hrs,
756
+ confidence_level=details.get("confidence_level", "UNKNOWN"),
757
+ risk_band=details.get("risk_band", "UNKNOWN"),
758
+ module_results=ModuleResults(
759
+ faithfulness=_module_score(mod_results, "faithfulness"),
760
+ entity_verifier=_module_score(mod_results, "entity_verifier"),
761
+ source_credibility=_module_score(mod_results, "source_credibility"),
762
+ contradiction=_module_score(mod_results, "contradiction"),
763
+ ragas=_module_score(mod_results, "ragas"),
764
+ ),
765
+ total_pipeline_ms=total_ms,
766
+ intervention_applied=intervention_applied,
767
+ intervention_reason=intervention_reason,
768
+ original_answer=original_answer,
769
+ intervention_details=intervention_details,
770
+ consensus_results=consensus_results,
771
+ privacy_applied=privacy_applied,
772
+ privacy_details={"redacted_count": len(p_mapping)} if privacy_applied else None,
773
+ coverage_gap=coverage_gap,
774
+ coverage_gap_details=coverage_gap_details,
775
+ )
776
+
777
+ # ---------------------------------------------------------------------------
778
+ # POST /ingest — dynamically append new documents to the FAISS index
779
+ # ---------------------------------------------------------------------------
780
+ _faiss_lock = threading.Lock()
781
+
782
@app.post("/ingest", tags=["ingestion"])
def ingest_document(req: IngestRequest):
    """
    Dynamically ingest a new document into the running FAISS index.
    Thread-safe implementation uses a lock to prevent concurrent write corruption.

    Steps: chunk the document, embed the chunks with the retriever's own
    model, append vectors + metadata in memory, persist both to disk via
    temp-file + rename, then rebuild the retriever's BM25 index.

    Raises:
        HTTPException(503): the retriever (or its FAISS index) is not loaded.
        HTTPException(400): chunking produced no chunks.
    """
    import os
    import pickle

    import faiss
    import numpy as np
    from src.pipeline.chunker import chunk_documents

    retriever = getattr(app.state, "retriever", None)
    if retriever is None or retriever._index is None:
        raise HTTPException(status_code=503, detail="Retriever not pre-warmed. Cannot ingest.")

    logger.info("POST /ingest — title='%s', size=%d chars", req.title, len(req.text))

    # 1. Chunk the document
    doc = {
        "text": req.text,
        # NOTE(review): titles sharing their first 10 characters produce the
        # same doc_id — confirm downstream code does not need it to be unique.
        "doc_id": "custom_" + req.title[:10],
        "title": req.title,
        "source": req.source,
        "pub_type": req.pub_type,
        # NOTE(review): hard-coded publication year — confirm this is intended.
        "pub_year": 2026,
    }
    chunks = chunk_documents([doc], _cfg)

    if not chunks:
        raise HTTPException(status_code=400, detail="Document produced 0 chunks.")

    # 2. Embed the chunks with the model the retriever already holds, so no
    #    second copy of the encoder is loaded into RAM.
    if retriever._model is None:
        retriever._load_model()
    st_model = retriever._model

    texts = [c["chunk_text"] for c in chunks]
    embeddings = np.array(st_model.encode(texts, show_progress_bar=False), dtype=np.float32)
    faiss.normalize_L2(embeddings)  # Required: index is IndexFlatIP = cosine sim

    # 3. Thread-safe index update with atomic disk writes
    with _faiss_lock:
        idx_path = Path(_cfg["retrieval"]["index_path"])
        meta_path = Path(_cfg["retrieval"]["metadata_path"])

        index = retriever._index
        metadata_store = retriever._metadata

        start_id = len(metadata_store)

        # Add to FAISS first: if the add fails (e.g. dimension mismatch) the
        # metadata store has not been touched and stays aligned with the index.
        index.add(embeddings)

        # presumably metadata_store is a dict keyed by vector position — TODO confirm
        for offset, chunk in enumerate(chunks):
            metadata_store[start_id + offset] = chunk

        # Atomic FAISS write: write to temp → rename (never leaves a half-written file)
        idx_tmp = str(idx_path) + ".tmp"
        faiss.write_index(index, idx_tmp)
        os.replace(idx_tmp, str(idx_path))

        # Atomic metadata write
        meta_tmp = str(meta_path) + ".tmp"
        with open(meta_tmp, "wb") as f:
            pickle.dump(metadata_store, f)
        os.replace(meta_tmp, str(meta_path))

        # 4. Rebuild BM25 for the running instance.
        # NOTE(review): kept under the lock so BM25 is rebuilt from a stable
        # metadata snapshot — confirm rebuild_bm25() does not take _faiss_lock.
        retriever.rebuild_bm25()

    logger.info("Successfully injected %d chunks for '%s' into FAISS and BM25.", len(chunks), req.title)
    return {"status": "success", "chunks_added": len(chunks), "title": req.title}
860
+
861
+ # ---------------------------------------------------------------------------
862
+ # GET /logs and /stats — fetch history for dashboard
863
+ # ---------------------------------------------------------------------------
864
@app.get("/logs", tags=["dashboard"])
def get_logs(limit: int = 50):
    """Return the most recent audit-log rows, newest first.

    Args:
        limit: Maximum number of rows to return; negative values are treated as 0.

    Returns:
        A list of dicts (one per ``audit_logs`` row), or ``[]`` if the
        database cannot be read — the dashboard is best-effort.
    """
    from contextlib import closing

    # SQLite interprets a negative LIMIT as "no upper bound"; clamp so a bad
    # query parameter cannot dump the entire table.
    limit = max(limit, 0)
    try:
        # closing() guarantees the connection is released even when the
        # query raises (e.g. the table does not exist yet).
        with closing(sqlite3.connect("data/logs.db")) as conn:
            conn.row_factory = sqlite3.Row
            rows = conn.execute(
                "SELECT * FROM audit_logs ORDER BY id DESC LIMIT ?", (limit,)
            ).fetchall()
        return [dict(row) for row in rows]
    except Exception:
        # Keep the empty-list fallback, but leave a trace of the failure.
        logger.exception("GET /logs failed; returning empty list")
        return []
876
+
877
@app.get("/stats", tags=["dashboard"])
def get_stats():
    """Return aggregate audit-log statistics for the dashboard.

    Returns:
        A dict with ``totalEvals``, ``avgHrs`` (rounded to 1 decimal),
        ``critAlerts``, ``interventions`` and ``monthly`` (up to the 12 most
        recent months, oldest first). All-zero values are returned if the
        database cannot be read.
    """
    from contextlib import closing

    try:
        # closing() releases the connection even when a query raises.
        with closing(sqlite3.connect("data/logs.db")) as conn:
            total_evals, avg_hrs, crit_alerts = conn.execute(
                "SELECT COUNT(*), AVG(hrs), SUM(CASE WHEN risk_band='CRITICAL' THEN 1 ELSE 0 END) FROM audit_logs"
            ).fetchone()

            interventions = conn.execute(
                "SELECT SUM(CASE WHEN intervention_applied=1 THEN 1 ELSE 0 END) FROM audit_logs"
            ).fetchone()[0]

            # Select the 12 *latest* months (DESC + LIMIT); ascending order with
            # LIMIT would freeze the chart on the oldest 12 months once the log
            # spans more than a year.
            monthly_query = (
                "SELECT SUBSTR(timestamp, 1, 7) as month, AVG(hrs) FROM audit_logs "
                "GROUP BY month ORDER BY month DESC LIMIT 12"
            )
            monthly_rows = conn.execute(monthly_query).fetchall()

        # Flip back to chronological order for display.
        monthly_data = [
            {"month": month, "avg_hrs": month_avg}
            for month, month_avg in reversed(monthly_rows)
        ]
        return {
            "totalEvals": total_evals or 0,
            "avgHrs": round(avg_hrs or 0, 1),
            "critAlerts": crit_alerts or 0,
            "interventions": interventions or 0,
            "monthly": monthly_data,
        }
    except Exception:
        # Dashboard is best-effort: keep the zeroed fallback but log the cause.
        logger.exception("GET /stats failed; returning zeroed statistics")
        return {
            "totalEvals": 0, "avgHrs": 0, "critAlerts": 0, "interventions": 0, "monthly": []
        }
905
+
906
+ # ---------------------------------------------------------------------------
907
+ # POST /parse_file — helper for frontend to extract PDF/DOCX text
908
+ # ---------------------------------------------------------------------------
909
@app.post("/parse_file", tags=["ingestion"])
async def parse_file(file: UploadFile = File(...)):
    """Extract text from uploaded txt, md, pdf, or docx files.

    The parser is chosen from the filename extension; anything that is not
    ``.pdf`` or ``.docx`` is decoded as UTF-8 with undecodable bytes replaced.

    Raises:
        HTTPException(400): the file could not be parsed.
    """
    content = await file.read()
    # The upload may arrive without a filename; fall back to "" so such
    # uploads take the plain-text branch instead of failing on None.lower().
    filename = (file.filename or "").lower()
    try:
        if filename.endswith(".pdf"):
            import fitz

            # Context manager closes the PyMuPDF document once text is extracted.
            with fitz.open(stream=content, filetype="pdf") as pdf:
                text = "\n".join(page.get_text() for page in pdf)
        elif filename.endswith(".docx"):
            import docx
            from io import BytesIO

            document = docx.Document(BytesIO(content))
            text = "\n".join(p.text for p in document.paragraphs)
        else:
            text = content.decode("utf-8", errors="replace")
        return {"status": "success", "text": text}
    except Exception as e:
        raise HTTPException(status_code=400, detail=f"Failed to parse file: {e}") from e
933
+
src/api/schemas.py ADDED
@@ -0,0 +1,276 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ src/api/schemas.py — Pydantic request/response models for MediRAG FastAPI
3
+ =========================================================================
4
+ FR-18: Input validation limits from config.yaml → api:
5
+ - max_query_length: 500
6
+ - max_answer_length: 2000
7
+ - max_chunks: 10
8
+ - max_chunk_length: 2000
9
+ """
10
+ from __future__ import annotations
11
+
12
+ from typing import Any, Dict, List, Optional
13
+ from pydantic import BaseModel, Field, field_validator
14
+
15
class IngestRequest(BaseModel):
    """POST /ingest — append a custom document to the FAISS index."""
    # `title` is required but has no length constraint; only `text` enforces
    # a minimum length (10 characters).
    title: str = Field(..., description="Document title")
    text: str = Field(..., min_length=10, description="Raw text of the document to ingest")
    # Both fields below are optional for the caller and fall back to defaults.
    pub_type: str = Field(default="clinical_guideline", description="Document type")
    source: str = Field(default="custom_upload", description="Source of the document")
21
+
22
+
23
+ # ---------------------------------------------------------------------------
24
+ # Request schemas
25
+ # ---------------------------------------------------------------------------
26
+
27
class ContextChunk(BaseModel):
    """A single retrieved context chunk passed to the evaluation pipeline."""
    # The chunk body is the only required field (1–2000 characters).
    text: str = Field(..., min_length=1, max_length=2000,
                      description="Chunk text (max 2000 chars)")
    # Optional metadata fields — all pass-through to the pipeline modules
    chunk_id: Optional[str] = None
    pub_type: Optional[str] = None
    pub_year: Optional[int] = None
    source: Optional[str] = None
    title: Optional[str] = None
    tier_type: Optional[str] = None   # pre-labelled evidence tier (optional)
    score: Optional[float] = None     # retrieval similarity score
39
+
40
+
41
class EvaluateRequest(BaseModel):
    """POST /evaluate — request body."""
    # Caller supplies the full (question, answer, context) triple; nothing is
    # retrieved or generated server-side for this endpoint.
    question: str = Field(
        ...,
        min_length=5,
        max_length=500,
        description="User question (5–500 chars)",
        examples=["What is the recommended dosage of Metformin for Type 2 Diabetes in elderly patients?"],
    )
    answer: str = Field(
        ...,
        min_length=1,
        max_length=2000,
        description="LLM-generated answer to evaluate (1–2000 chars)",
        examples=["Metformin is typically started at 500 mg twice daily with meals..."],
    )
    context_chunks: List[ContextChunk] = Field(
        ...,
        min_length=1,
        max_length=10,
        description="Retrieved context chunks (1–10 items)",
    )
    run_ragas: bool = Field(
        default=False,
        description="Run RAGAS evaluation (requires Ollama or OpenAI backend; slower)",
    )
    # Optional per-request LLM overrides; all default to None.
    llm_provider: Optional[str] = Field(
        default=None,
        description="LLM provider override: 'gemini' or 'ollama'"
    )
    llm_api_key: Optional[str] = Field(
        default=None,
        description="API Key if accessing Gemini"
    )
    llm_model: Optional[str] = Field(
        default=None,
        description="Specific model string if overriding defaults"
    )
    # NOTE(review): this is a client-supplied filesystem path — confirm the
    # handler restricts which locations may be read.
    rxnorm_cache_path: str = Field(
        default="data/rxnorm_cache.csv",
        description="Path to RxNorm cache CSV",
    )

    @field_validator("context_chunks")
    @classmethod
    def at_least_one_chunk(cls, v: list) -> list:
        # Mirrors the min_length=1 constraint declared on the field itself,
        # with a more explicit error message.
        if len(v) == 0:
            raise ValueError("At least one context chunk is required")
        return v
90
+
91
+
92
+ # ---------------------------------------------------------------------------
93
+ # Response schemas
94
+ # ---------------------------------------------------------------------------
95
+
96
class ModuleScore(BaseModel):
    """Score + details dict for a single evaluation module."""
    # `score` is validated to lie in [0, 1]; `details` defaults to a fresh
    # empty dict per instance via default_factory.
    score: float = Field(..., ge=0.0, le=1.0, description="Module score in [0, 1]")
    details: Dict[str, Any] = Field(default_factory=dict)
    error: Optional[str] = Field(None, description="Error message if module failed")
    latency_ms: Optional[int] = None
102
+
103
+
104
class ModuleResults(BaseModel):
    """All per-module scores bundled together."""
    # Every module is optional: a field stays None when no score is supplied
    # for that module.
    faithfulness: Optional[ModuleScore] = None
    entity_verifier: Optional[ModuleScore] = None
    source_credibility: Optional[ModuleScore] = None
    contradiction: Optional[ModuleScore] = None
    ragas: Optional[ModuleScore] = None
111
+
112
+
113
class EvaluateResponse(BaseModel):
    """POST /evaluate — response body (FR-17 format)."""
    # composite_score and hrs are range-validated here, so an out-of-range
    # value fails response construction.
    composite_score: float = Field(
        ..., ge=0.0, le=1.0,
        description="Weighted composite score in [0, 1]"
    )
    hrs: int = Field(
        ..., ge=0, le=100,
        description="Health Risk Score = round(100 × (1 - composite_score))"
    )
    # The two band fields are plain strings; the allowed values listed in the
    # descriptions are not enforced by the schema.
    confidence_level: str = Field(
        ...,
        description="HIGH / MODERATE / LOW",
    )
    risk_band: str = Field(
        ...,
        description="LOW / MODERATE / HIGH / CRITICAL",
    )
    module_results: ModuleResults
    total_pipeline_ms: int = Field(..., description="Total wall-clock time in ms")
133
+
134
+
135
class ChatMessage(BaseModel):
    # One turn of a chat transcript. Both fields are required, unconstrained
    # strings; the set of valid `role` values is not enforced here.
    role: str
    content: str
138
+
139
class ChatRequest(BaseModel):
    # Request body carrying a chat transcript plus optional prompt controls.
    messages: List[ChatMessage]
    # No custom system prompt unless the caller supplies one.
    system_prompt: Optional[str] = None
    # Defaults to the "physician" persona when omitted.
    persona: Optional[str] = "physician"
143
+
144
+
145
class HealthResponse(BaseModel):
    """GET /health — liveness and dependency status."""
    # `ollama_available` is the only field the handler must supply; the other
    # two carry defaults.
    status: str = Field(default="ok")
    ollama_available: bool
    version: str = Field(default="0.1.0")
150
+
151
+
152
+ # ---------------------------------------------------------------------------
153
+ # End-to-end query schemas (POST /query)
154
+ # ---------------------------------------------------------------------------
155
+
156
class QueryRequest(BaseModel):
    """POST /query — only a question needed; retrieval + generation happen server-side."""
    # Note the much larger question limit (8000) than EvaluateRequest (500):
    # the question may embed document context.
    question: str = Field(
        ...,
        min_length=5,
        max_length=8000,
        description="Medical question (5–8000 chars; may include doc context)",
        examples=["What is the recommended dosage of Metformin for elderly Type 2 Diabetes patients?"],
    )
    top_k: int = Field(
        default=5,
        ge=1,
        le=10,
        description="Number of context chunks to retrieve (1–10)",
    )
    run_ragas: bool = Field(
        default=False,
        description="Run RAGAS evaluation (requires LLM backend)",
    )
    # Per-request LLM overrides — if not set, server config.yaml values are used
    # This makes the eval engine portable: callers bring their own key + model
    llm_provider: Optional[str] = Field(
        default=None,
        description="LLM provider override: 'gemini' or 'ollama'"
    )
    llm_api_key: Optional[str] = Field(
        default=None,
        description="API key override (e.g. Gemini key). Not logged or stored."
    )
    llm_model: Optional[str] = Field(
        default=None,
        description="Model name override (e.g. 'gemini-2.5-flash-lite')"
    )
    ollama_url: Optional[str] = Field(
        default=None,
        description="Ollama base URL override (e.g. 'http://localhost:11434')"
    )
    # Demo/test only — injects a false claim into the LLM answer before evaluation
    # to demonstrate the intervention system catching hallucinations.
    inject_hallucination: Optional[str] = Field(
        default=None,
        description="[DEMO ONLY] Appends a false medical claim to the answer before evaluation."
    )
    # Consensus Engine (Option 2)
    use_consensus: bool = Field(
        default=False,
        description="Run multiple models and compare for clinical agreement."
    )
    # Privacy Shield (Option 1)
    use_privacy_shield: bool = Field(
        default=False,
        description="Automatically redact PHI/PII (names, IDs) before external API calls.",
    )
    # Prompt controls: an explicit system prompt and/or a target persona.
    system_prompt: Optional[str] = Field(
        default=None,
        description="Custom system prompt to override the default clinical persona."
    )
    persona: Optional[str] = Field(
        default="physician",
        description="The target audience for the response: 'physician' or 'patient'."
    )
217
+
218
+
219
class RetrievedChunk(BaseModel):
    """A single chunk returned alongside the query response for transparency."""
    # Only `text` is required; every metadata field may be absent (None).
    chunk_id: Optional[str] = None
    text: str
    source: Optional[str] = None
    pub_type: Optional[str] = None
    pub_year: Optional[int] = None
    title: Optional[str] = None
    similarity_score: Optional[float] = None
228
+
229
+
230
class QueryResponse(BaseModel):
    """POST /query — full end-to-end response."""
    question: str
    generated_answer: str
    retrieved_chunks: List[RetrievedChunk]
    # Evaluation fields (same as EvaluateResponse)
    composite_score: float = Field(..., ge=0.0, le=1.0)
    hrs: int = Field(..., ge=0, le=100)
    confidence_level: str
    risk_band: str
    module_results: ModuleResults
    total_pipeline_ms: int
    # Intervention fields (active safety gate)
    intervention_applied: bool = Field(
        default=False,
        description="True if the system modified or blocked the response for safety.",
    )
    intervention_reason: Optional[str] = Field(
        default=None,
        description="CRITICAL_BLOCKED | HIGH_RISK_REGENERATED | null",
    )
    original_answer: Optional[str] = Field(
        default=None,
        description="The original (unsafe) LLM answer before intervention, for transparency.",
    )
    intervention_details: Optional[Dict[str, Any]] = Field(
        default=None,
        description="Which modules triggered the intervention and their scores.",
    )
    # Consensus fields
    consensus_results: Optional[Dict[str, Any]] = Field(
        default=None,
        description="Results from the multi-model agreement check."
    )
    # Privacy Shield fields
    privacy_applied: bool = Field(default=False)
    privacy_details: Optional[Dict[str, Any]] = Field(default=None)
    # Coverage gap gate — distinguishes missing DB coverage from hallucination
    coverage_gap: bool = Field(
        default=False,
        description="True when retrieval quality is low — the database may lack coverage for this topic.",
    )
    # The description lists the keys the /query handler actually populates, so
    # the published API schema matches the payload clients receive.
    coverage_gap_details: Optional[Dict[str, Any]] = Field(
        default=None,
        description="gap_type (COVERAGE_GAP | HALLUCINATION), top_faiss_cosine, is_refusal_answer, note.",
    )
276
+
src/cli.py ADDED
@@ -0,0 +1,70 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import typer
2
+ import subprocess
3
+ import webbrowser
4
+ import time
5
+ import socket
6
+ import os
7
+ import sys
8
+
9
+ app = typer.Typer(help="MediRAG Command Line Interface")
10
+
11
def is_port_in_use(port: int) -> bool:
    """Return True when a TCP connection to localhost:``port`` is accepted."""
    probe = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    try:
        # connect_ex reports failure as an error code instead of raising;
        # 0 means something is listening on the port.
        status = probe.connect_ex(('localhost', port))
    finally:
        probe.close()
    return status == 0
14
+
15
@app.command()
def start():
    """Start the full MediRAG experience (Backend + Full Frontend)"""
    # NOTE(review): the docstring above presumably doubles as this command's
    # Typer help text — treat it as user-facing wording.
    typer.echo("Starting full MediRAG experience...")
    # practical_mode=False: run_servers opens the frontend root URL.
    run_servers(practical_mode=False)
20
+
21
@app.command()
def api():
    """Start the streamlined 'practical' UI"""
    # NOTE(review): the docstring above presumably doubles as this command's
    # Typer help text — treat it as user-facing wording.
    typer.echo("Starting streamlined MediRAG practical UI...")
    # practical_mode=True: run_servers opens the /cli-view route instead of
    # the frontend root. The same backend and frontend servers are started.
    run_servers(practical_mode=True)
26
+
27
def run_servers(practical_mode: bool):
    """Launch the backend and frontend servers, open the browser, and block.

    Args:
        practical_mode: When True the browser opens the ``/cli-view`` route;
            otherwise it opens the frontend root.

    Every child process that was started is terminated and reaped on the way
    out — on Ctrl-C, on normal exit of the children, and when a later startup
    step fails (for example ``npm`` is not installed), so the backend is never
    left running as an orphan.
    """
    # Port checks are advisory only: warn, but still attempt to start.
    if is_port_in_use(8000):
        typer.echo("Warning: Port 8000 (Backend) might already be in use.")
    if is_port_in_use(5173):
        typer.echo("Warning: Port 5173 (Frontend) might already be in use.")

    backend_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
    frontend_dir = os.path.join(os.path.dirname(backend_dir), "Frontend")

    processes = []
    try:
        # Start Backend
        typer.echo("Starting Backend server...")
        processes.append(subprocess.Popen(
            [sys.executable, "-m", "uvicorn", "src.api.main:app", "--host", "0.0.0.0", "--port", "8000"],
            cwd=backend_dir,
        ))

        # Start Frontend
        typer.echo("Starting Frontend server...")
        # On Windows, npm run dev needs shell=True or using cmd /c
        frontend_cmd = ["cmd", "/c", "npm", "run", "dev"] if os.name == 'nt' else ["npm", "run", "dev"]
        processes.append(subprocess.Popen(frontend_cmd, cwd=frontend_dir))

        typer.echo("Waiting for servers to start...")
        time.sleep(5)  # Basic wait for frontend to spin up

        url = "http://localhost:5173/cli-view" if practical_mode else "http://localhost:5173/"
        typer.echo(f"Opening browser at {url}...")
        webbrowser.open(url)

        # Keep this process alive until the children exit (backend first,
        # then frontend — same order as before).
        for proc in processes:
            proc.wait()
    except KeyboardInterrupt:
        typer.echo("\nShutting down servers...")
    finally:
        # Ask every still-running child to stop, then reap them all so no
        # zombie processes are left behind; escalate to kill() if needed.
        for proc in processes:
            if proc.poll() is None:
                proc.terminate()
        for proc in processes:
            try:
                proc.wait(timeout=5)
            except subprocess.TimeoutExpired:
                proc.kill()
                proc.wait()
        typer.echo("Servers stopped.")
68
+
69
+ if __name__ == "__main__":
70
+ app()
src/dashboard/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ # src/dashboard/__init__.py
src/evaluate.py ADDED
@@ -0,0 +1,289 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ FR-22: src/evaluate.py — MediRAG Evaluation Orchestrator
3
+ =========================================================
4
+ Top-level entry point for the evaluation pipeline.
5
+
6
+ Runs all 4 evaluation modules + RAGAS + aggregator for a given
7
+ (question, answer, context_docs) triple, returning a fully structured
8
+ composite EvalResult.
9
+
10
+ Usage as a module:
11
+ from src.evaluate import run_evaluation
12
+ result = run_evaluation(question, answer, context_docs)
13
+ print(f"Score: {result.score:.3f} ({result.details['confidence_level']})")
14
+
15
+ Usage from CLI:
16
+ python -m src.evaluate \\
17
+ --question "What is the recommended dosage of Metformin for Type 2 Diabetes?" \\
18
+ --answer "Metformin is typically started at 500mg twice daily..." \\
19
+ --context-file data/processed/chunks.jsonl \\
20
+ --top-k 5
21
+
22
+ SRS reference: FR-22, Section 7 (Evaluation Pipeline Overview)
23
+ """
24
+ from __future__ import annotations
25
+
26
+ import argparse
27
+ import json
28
+ import logging
29
+ import sys
30
+ import time
31
+ from pathlib import Path
32
+ from typing import Optional
33
+
34
+ from src.modules.base import EvalResult
35
+ from src.modules.faithfulness import score_faithfulness
36
+ from src.modules.entity_verifier import verify_entities
37
+ from src.modules.source_credibility import score_source_credibility
38
+ from src.modules.contradiction import score_contradiction
39
+ from src.evaluation.ragas_eval import score_ragas
40
+ from src.evaluation.aggregator import aggregate
41
+
42
+ logger = logging.getLogger(__name__)
43
+
44
+
45
+ # ---------------------------------------------------------------------------
46
+ # Main evaluation function
47
+ # ---------------------------------------------------------------------------
48
+
49
def run_evaluation(
    question: str,
    answer: str,
    context_chunks: list[dict],
    rxnorm_cache_path: str = "data/rxnorm_cache.csv",
    run_ragas: bool = True,
    weights: Optional[dict[str, float]] = None,
    config: Optional[dict] = None,
) -> EvalResult:
    """
    Run the full MediRAG evaluation pipeline for a single QA pair.

    Args:
        question : Original user question.
        answer : LLM-generated answer to evaluate.
        context_chunks : List of retrieved chunk dicts (from retriever.retrieve()).
            Each chunk must have at minimum {'text': str}.
        rxnorm_cache_path : Path to rxnorm_cache.csv for entity verification.
        run_ragas : Whether to run the RAGAS module (requires LLM backend).
        weights : Override default aggregation weights (optional).
        config : Optional configuration dict; it is forwarded to the
            faithfulness module only.

    Returns:
        EvalResult for the "aggregator" module containing:
            .score → composite score in [0, 1]
            .details → full breakdown per module
            .latency_ms → total wall-clock time

        When the retrieval quality gate trips, a fixed neutral result
        (score=0.5, hrs=50, empty "module_results") is returned instead and
        no evaluation module is run.
    """
    t_start = time.perf_counter()
    logger.info("=== MediRAG Evaluation START ===")
    logger.info("Question: %s", question[:120])
    logger.info("Answer : %s", answer[:120])
    logger.info("Chunks : %d context documents", len(context_chunks))

    # Extract text and ids for the modules that need them. A chunk whose
    # "text" or "metadata" is present but None is treated like a missing key,
    # so a sparse chunk cannot crash the pipeline.
    context_texts: list[str] = [c.get("text") or "" for c in context_chunks]
    chunk_ids: list[str] = []
    for i, c in enumerate(context_chunks):
        metadata = c.get("metadata") or {}
        chunk_ids.append(c.get("chunk_id") or metadata.get("chunk_id") or f"chunk_{i}")

    # -------------------------------------------------------------------------
    # Retrieval Quality Gate
    # If the retriever's absolute RRF score is too low, the chunks are likely
    # unrelated to the question — evaluation against them produces false HRS spikes.
    # Threshold: max raw RRF for top-1 in both sources = 2/(60+1) ≈ 0.0328
    # We flag as insufficient if the confidence is < 0.008 (only very weak BM25 or FAISS match)
    # Only the first chunk's "_retrieval_confidence" is consulted; chunks
    # without that key default to 1.0 and therefore pass the gate.
    # -------------------------------------------------------------------------
    RETRIEVAL_CONFIDENCE_THRESHOLD = 0.008
    retrieval_confidence = context_chunks[0].get("_retrieval_confidence", 1.0) if context_chunks else 0.0

    if context_chunks and retrieval_confidence < RETRIEVAL_CONFIDENCE_THRESHOLD:
        logger.warning(
            "Retrieval confidence %.6f below threshold %.3f — context likely irrelevant to question.",
            retrieval_confidence, RETRIEVAL_CONFIDENCE_THRESHOLD,
        )
        total_ms = int((time.perf_counter() - t_start) * 1000)
        return EvalResult(
            module_name="aggregator",
            score=0.5,
            details={
                "retrieval_insufficient": True,
                "retrieval_confidence": retrieval_confidence,
                "hrs": 50,
                "risk_band": "MODERATE",
                "confidence_level": "LOW",
                "total_pipeline_ms": total_ms,
                "module_results": {},
                "warning": (
                    "Retrieved context has very low relevance to the question "
                    f"(retrieval_confidence={retrieval_confidence:.4f}). "
                    "Evaluation scores would be meaningless. "
                    "Consider rephrasing the question or expanding the index."
                ),
            },
            latency_ms=total_ms,
        )

    # -------------------------------------------------------------------------
    # Module 1: Faithfulness (DeBERTa NLI)
    # -------------------------------------------------------------------------
    logger.info("--- Module 1: Faithfulness ---")
    faith_result = score_faithfulness(
        answer=answer,
        context_docs=context_texts,
        chunk_ids=chunk_ids,
        config=config,
    )

    # -------------------------------------------------------------------------
    # Module 2: Entity Verification (SciSpaCy + RxNorm)
    # -------------------------------------------------------------------------
    logger.info("--- Module 2: Entity Verification ---")
    entity_result = verify_entities(
        answer=answer,
        question=question,
        context_docs=context_texts,
        rxnorm_cache_path=rxnorm_cache_path,
    )

    # -------------------------------------------------------------------------
    # Module 3: Source Credibility (Evidence Tier)
    # -------------------------------------------------------------------------
    logger.info("--- Module 3: Source Credibility ---")
    source_result = score_source_credibility(retrieved_chunks=context_chunks)

    # -------------------------------------------------------------------------
    # Module 4: Contradiction Detection (DeBERTa NLI cross-check)
    # -------------------------------------------------------------------------
    logger.info("--- Module 4: Contradiction Detection ---")
    contra_result = score_contradiction(
        answer=answer,
        context_docs=context_texts,
    )

    # -------------------------------------------------------------------------
    # RAGAS (optional — requires LLM backend)
    # -------------------------------------------------------------------------
    ragas_result: Optional[EvalResult] = None
    if run_ragas:
        logger.info("--- RAGAS Evaluation ---")
        ragas_result = score_ragas(
            question=question,
            answer=answer,
            context_docs=context_texts,
        )

    # -------------------------------------------------------------------------
    # Aggregator: weighted composite
    # -------------------------------------------------------------------------
    logger.info("--- Aggregator ---")
    agg_result = aggregate(
        faithfulness_result=faith_result,
        entity_result=entity_result,
        source_result=source_result,
        contradiction_result=contra_result,
        ragas_result=ragas_result,
        weights=weights,
    )

    total_ms = int((time.perf_counter() - t_start) * 1000)
    agg_result.details["total_pipeline_ms"] = total_ms

    # Attach per-module results for API/dashboard access. "ragas" is None
    # when the RAGAS step was skipped.
    agg_result.details["module_results"] = {
        "faithfulness": {"score": faith_result.score, "details": faith_result.details},
        "entity_verifier": {"score": entity_result.score, "details": entity_result.details},
        "source_credibility": {"score": source_result.score, "details": source_result.details},
        "contradiction": {"score": contra_result.score, "details": contra_result.details},
        "ragas": (
            {"score": ragas_result.score, "details": ragas_result.details}
            if ragas_result is not None
            else None
        ),
    }

    logger.info(
        "=== MediRAG Evaluation DONE: score=%.3f (%s) in %d ms ===",
        agg_result.score,
        agg_result.details.get("confidence_level", "?"),
        total_ms,
    )
    return agg_result
207
+
208
+
209
+ # ---------------------------------------------------------------------------
210
+ # CLI entry point
211
+ # ---------------------------------------------------------------------------
212
+
213
def _build_parser() -> argparse.ArgumentParser:
    """Create the command-line argument parser for the evaluation CLI."""
    parser = argparse.ArgumentParser(description="MediRAG evaluation pipeline (FR-22)")

    # The QA pair under evaluation is mandatory.
    parser.add_argument("--question", required=True, help="User question")
    parser.add_argument("--answer", required=True, help="LLM answer to evaluate")

    # Where the context comes from and how much of it to use.
    parser.add_argument(
        "--context-file",
        default="data/processed/chunks.jsonl",
        help="JSONL file of chunks (output of ingest.py)",
    )
    parser.add_argument(
        "--top-k",
        type=int,
        default=5,
        help="Number of context chunks to use",
    )
    parser.add_argument(
        "--rxnorm-cache",
        default="data/rxnorm_cache.csv",
        help="Path to rxnorm_cache.csv",
    )

    # Behaviour switches (both default to off).
    parser.add_argument(
        "--no-ragas",
        action="store_true",
        help="Skip RAGAS evaluation (no LLM backend needed)",
    )
    parser.add_argument(
        "--json",
        action="store_true",
        help="Output result as JSON",
    )
    return parser
230
+
231
+
232
def _load_context_from_file(path: str, top_k: int) -> list[dict]:
    """
    Load at most ``top_k`` chunks from a JSONL file as simple dicts.

    Blank lines are skipped and do not count towards ``top_k``. A ``top_k``
    of zero or less yields an empty list.

    Args:
        path : Path to the JSONL file (one JSON object per line).
        top_k : Maximum number of chunks to return.

    Returns:
        The first ``top_k`` parsed chunks, in file order.

    Raises:
        json.JSONDecodeError : If a non-blank line read is not valid JSON.

    Exits the process with status 1 (after logging an error) when the file
    does not exist.
    """
    chunks: list[dict] = []
    try:
        with open(path, "r", encoding="utf-8") as f:
            for line in f:
                # Check the limit *before* parsing so that top_k <= 0 returns
                # nothing instead of one chunk.
                if len(chunks) >= top_k:
                    break
                line = line.strip()
                if line:
                    chunks.append(json.loads(line))
    except FileNotFoundError:
        logger.error("Context file not found: %s", path)
        sys.exit(1)
    return chunks
247
+
248
+
249
if __name__ == "__main__":
    import yaml

    # Load config.yaml for logging setup. This is deliberately best-effort:
    # a missing, unreadable or malformed config simply falls back to INFO.
    try:
        cfg = yaml.safe_load(Path("config.yaml").read_text())
        log_level = cfg.get("logging", {}).get("level", "INFO")
    except Exception:
        log_level = "INFO"

    # Resolve the level name case-insensitively and insist on a numeric level:
    # a lowercase name such as "info" would otherwise resolve to the function
    # logging.info, which basicConfig() rejects.
    resolved_level = getattr(logging, str(log_level).upper(), None)
    if not isinstance(resolved_level, int):
        resolved_level = logging.INFO

    logging.basicConfig(
        level=resolved_level,
        format="%(asctime)s [%(levelname)s] %(name)s: %(message)s",
    )

    args = _build_parser().parse_args()
    chunks = _load_context_from_file(args.context_file, args.top_k)

    result = run_evaluation(
        question=args.question,
        answer=args.answer,
        context_chunks=chunks,
        rxnorm_cache_path=args.rxnorm_cache,
        run_ragas=not args.no_ragas,
    )

    if args.json:
        # Machine-readable output: the whole EvalResult as a JSON object.
        import dataclasses
        print(json.dumps(dataclasses.asdict(result), indent=2))
    else:
        # Human-readable summary with one line per module that produced a result.
        print(f"\n{'='*60}")
        print(" MediRAG Evaluation Result")
        print(f"{'='*60}")
        print(f" Score : {result.score:.3f}")
        print(f" Confidence : {result.details.get('confidence_level', 'N/A')}")
        print(f" Pipeline time : {result.details.get('total_pipeline_ms', 0)} ms")
        print("\n Module Breakdown:")
        for mod, res in (result.details.get("module_results") or {}).items():
            if res:
                print(f" {mod:22s}: {res['score']:.3f}")
        print(f"{'='*60}\n")
src/evaluation/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ # src/evaluation/__init__.py
src/evaluation/aggregator.py ADDED
@@ -0,0 +1,173 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ FR-19: src/evaluation/aggregator.py — Weighted Score Aggregation
3
+ ================================================================
4
+ Combines scores from all evaluation modules into a single composite score
5
+ using the fixed weights defined in SRS Section 8.2.
6
+
7
+ Weights (must sum to 1.0):
8
+ faithfulness : 0.35 (primary signal — DeBERTa NLI)
9
+ entity_accuracy : 0.20 (SciSpaCy NER + RxNorm)
10
+ source_credibility : 0.20 (evidence tier)
11
+ contradiction_risk : 0.15 (1.0 - contradiction_score)
12
+ ragas_composite : 0.10 (optional — 0.5 neutral if unavailable)
13
+
14
+ Output:
15
+ EvalResult with:
16
+ module_name = "aggregator"
17
+ score = weighted composite in [0, 1]
18
+ details = {weights_used, weighted_composite, component_contributions}
19
+
20
+ Usage:
21
+ from src.evaluation.aggregator import aggregate
22
+ agg_result = aggregate(faith_res, entity_res, source_res, contra_res, ragas_res)
23
+ """
24
+ from __future__ import annotations
25
+
26
+ import logging
27
+ import time
28
+ from typing import Optional
29
+
30
+ from src.modules.base import EvalResult
31
+
32
+ logger = logging.getLogger(__name__)
33
+
34
+ # ---------------------------------------------------------------------------
35
+ # Default weights (SRS Section 8.2)
36
+ # ---------------------------------------------------------------------------
37
+
38
DEFAULT_WEIGHTS: dict[str, float] = {
    "faithfulness": 0.35,
    "entity_accuracy": 0.20,
    "source_credibility": 0.20,
    "contradiction_risk": 0.15,
    "ragas_composite": 0.10,
}


def aggregate(
    faithfulness_result: EvalResult,
    entity_result: EvalResult,
    source_result: EvalResult,
    contradiction_result: EvalResult,
    ragas_result: Optional[EvalResult] = None,
    weights: Optional[dict[str, float]] = None,
) -> EvalResult:
    """
    Aggregate all module scores into a single composite evaluation result.

    Args:
        faithfulness_result : Output from faithfulness.score_faithfulness()
        entity_result : Output from entity_verifier.verify_entities()
        source_result : Output from source_credibility.score_source_credibility()
        contradiction_result : Output from contradiction.score_contradiction()
        ragas_result : Output from ragas_eval.score_ragas() (optional)
        weights : Override default weights. Keys that are omitted keep their
            default value; the final set is normalised when it does not sum
            to 1.0 (tolerance 0.01). A non-positive sum falls back to the
            defaults.

    Returns:
        EvalResult with module_name="aggregator" and composite score.
        ``details`` carries the weights used, per-component scores and
        contributions, the composite, the HRS (0-100) with its risk band,
        the confidence level and the per-module latencies.
    """
    t0 = time.perf_counter()

    # Overlay the caller's overrides on the defaults so that a partial
    # override cannot raise KeyError when the contributions are computed.
    w = {**DEFAULT_WEIGHTS, **(weights or {})}

    # Validate weights sum to 1.0 (tolerance 0.01)
    weight_sum = sum(w.values())
    if weight_sum <= 0:
        # Normalising by a zero/negative sum is meaningless (and would divide
        # by zero), so ignore the override entirely.
        logger.warning(
            "Weights sum to %.4f (must be positive) — falling back to defaults.",
            weight_sum,
        )
        w = dict(DEFAULT_WEIGHTS)
    elif abs(weight_sum - 1.0) > 0.01:
        logger.warning(
            "Weights sum to %.4f (expected 1.0) — normalising.", weight_sum
        )
        w = {k: v / weight_sum for k, v in w.items()}

    # Extract scores. A module that reported an error contributes a neutral
    # 0.5, except contradiction, which falls back to 1.0.
    faith_score = faithfulness_result.score if not faithfulness_result.error else 0.5
    entity_score = entity_result.score if not entity_result.error else 0.5
    source_score = source_result.score if not source_result.error else 0.5
    contra_score = contradiction_result.score if not contradiction_result.error else 1.0
    ragas_score = (ragas_result.score if ragas_result and not ragas_result.error else 0.5)

    # Compute base weighted contributions
    contributions = {
        "faithfulness_contribution": round(faith_score * w["faithfulness"], 4),
        "entity_contribution": round(entity_score * w["entity_accuracy"], 4),
        "source_contribution": round(source_score * w["source_credibility"], 4),
        "contradiction_contribution": round(contra_score * w["contradiction_risk"], 4),
        "ragas_contribution": round(ragas_score * w["ragas_composite"], 4),
    }

    base_composite = sum(contributions.values())

    # --- Non-linear Safety Penalties ---
    # Faithfulness penalty: applies when answer is not grounded in context.
    # Contradiction penalty: only applies when actual contradictions are detected
    # (score < 0.3). Score = 0.5 means "neutral/cannot verify" (refusal answers,
    # no keyword overlap) — these should NOT be double-penalized.
    penalty_multiplier = 1.0
    if faith_score <= 0.6:
        penalty_multiplier *= 0.6  # 40% penalty for ungrounded claims
    if contra_score < 0.3:
        penalty_multiplier *= 0.6  # 40% penalty only for confirmed contradictions

    composite = base_composite * penalty_multiplier

    # HRS = round(100 × (1 - composite)), then map to risk band
    # Thresholds must match config.yaml aggregator.risk_bands
    _HRS_LOW = 30
    _HRS_MODERATE = 60
    _HRS_HIGH = 85

    hrs = int(round(100 * (1.0 - composite)))
    hrs = max(0, min(100, hrs))

    if hrs <= _HRS_LOW:
        risk_band = "LOW"
    elif hrs <= _HRS_MODERATE:
        risk_band = "MODERATE"
    elif hrs <= _HRS_HIGH:
        risk_band = "HIGH"
    else:
        risk_band = "CRITICAL"

    # Confidence level (based on composite, not HRS)
    if composite >= 0.80:
        confidence = "HIGH"
    elif composite >= 0.55:
        confidence = "MODERATE"
    else:
        confidence = "LOW"

    details = {
        "weights_used": {k: round(v, 4) for k, v in w.items()},
        "component_scores": {
            "faithfulness": round(faith_score, 4),
            "entity_accuracy": round(entity_score, 4),
            "source_credibility": round(source_score, 4),
            "contradiction_risk": round(contra_score, 4),
            "ragas_composite": round(ragas_score, 4),
        },
        "weighted_composite": round(composite, 4),
        "hrs": hrs,
        "risk_band": risk_band,
        "component_contributions": contributions,
        "confidence_level": confidence,
        "module_latencies_ms": {
            "faithfulness": faithfulness_result.latency_ms,
            "entity_verifier": entity_result.latency_ms,
            "source_credibility": source_result.latency_ms,
            "contradiction": contradiction_result.latency_ms,
            "ragas": ragas_result.latency_ms if ragas_result else 0,
        },
    }

    latency_ms = int((time.perf_counter() - t0) * 1000)
    logger.info(
        "Aggregated score: %.3f (%s confidence) — "
        "faith=%.2f entity=%.2f source=%.2f contra=%.2f ragas=%.2f",
        composite, confidence,
        faith_score, entity_score, source_score, contra_score, ragas_score,
    )

    return EvalResult(
        module_name="aggregator",
        score=composite,
        details=details,
        latency_ms=latency_ms,
    )
src/evaluation/ragas_eval.py ADDED
@@ -0,0 +1,177 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ FR-06: src/evaluation/ragas_eval.py — RAGAS Faithfulness + Answer Relevancy
3
+ =============================================================================
4
+ Wraps the ragas library to compute:
5
+ - faithfulness : context-grounded claim verification
6
+ - answer_relevancy : semantic similarity of answer to question
7
+
8
+ Requires an LLM backend. Supported backends (OpenAI is selected first when OPENAI_API_KEY is set; otherwise Ollama is probed):
9
+ 1. Ollama (local, free) — set OLLAMA_HOST env var or use default localhost:11434
10
+ 2. OpenAI API — set OPENAI_API_KEY env var
11
+ 3. Graceful degradation — returns a neutral score of 0.5 with an explanatory note if no LLM is available
12
+
13
+ Usage:
14
+ from src.evaluation.ragas_eval import score_ragas
15
+ result = score_ragas(question, answer, context_docs)
16
+
17
+ SRS reference: FR-06, Section 7 (Evaluation Pipeline)
18
+ """
19
+ from __future__ import annotations
20
+
21
+ import logging
22
+ import os
23
+ import time
24
+ from typing import Optional
25
+
26
+ from src.modules.base import EvalResult
27
+
28
+ logger = logging.getLogger(__name__)
29
+
30
+
31
+ # ---------------------------------------------------------------------------
32
+ # Backend detection
33
+ # ---------------------------------------------------------------------------
34
+
35
def _detect_llm_backend() -> Optional[str]:
    """Return 'ollama', 'openai', or None."""
    # A configured OpenAI key wins outright; no network probe is made then.
    if os.getenv("OPENAI_API_KEY"):
        return "openai"

    # Otherwise probe the Ollama HTTP endpoint. Any failure (requests not
    # installed, connection refused, timeout, ...) counts as "not available".
    try:
        import requests

        base_url = os.getenv("OLLAMA_HOST", "http://localhost:11434")
        reachable = requests.get(f"{base_url}/api/tags", timeout=2).status_code == 200
    except Exception:
        reachable = False

    return "ollama" if reachable else None
49
+
50
+
51
def _build_ragas_llm(backend: str):
    """
    Build a ragas-compatible LLM wrapper.

    Args:
        backend : "openai" or "ollama".

    Raises:
        ValueError : For any other backend name.
    """
    # Provider packages are imported lazily so only the selected one is needed.
    if backend == "openai":
        from langchain_openai import ChatOpenAI

        return ChatOpenAI(model="gpt-3.5-turbo", temperature=0)

    if backend == "ollama":
        from langchain_community.chat_models import ChatOllama

        # Host and model are both overridable through the environment.
        return ChatOllama(
            base_url=os.getenv("OLLAMA_HOST", "http://localhost:11434"),
            model=os.getenv("OLLAMA_MODEL", "mistral"),
        )

    raise ValueError(f"Unknown backend: {backend}")
62
+
63
+
64
def _build_ragas_embeddings(backend: str):
    """
    Build a ragas-compatible embeddings wrapper.

    Args:
        backend : "openai" or "ollama".

    Raises:
        ValueError : For any other backend name.
    """
    # Provider packages are imported lazily so only the selected one is needed.
    if backend == "openai":
        from langchain_openai import OpenAIEmbeddings

        return OpenAIEmbeddings()

    if backend == "ollama":
        from langchain_community.embeddings import OllamaEmbeddings

        # Host and embedding model are both overridable through the environment.
        return OllamaEmbeddings(
            base_url=os.getenv("OLLAMA_HOST", "http://localhost:11434"),
            model=os.getenv("OLLAMA_EMBED_MODEL", "nomic-embed-text"),
        )

    raise ValueError(f"Unknown backend: {backend}")
75
+
76
+
77
+ # ---------------------------------------------------------------------------
78
+ # Public API
79
+ # ---------------------------------------------------------------------------
80
+
81
def score_ragas(
    question: str,
    answer: str,
    context_docs: list[str],
    max_contexts: int = 3,
) -> EvalResult:
    """
    Compute RAGAS faithfulness and answer_relevancy scores.

    Args:
        question : Original user question.
        answer : LLM-generated answer.
        context_docs : Retrieved context passages.
        max_contexts : Max context chunks to pass to RAGAS (to limit token cost).

    Returns:
        EvalResult with module_name="ragas", score in [0,1].
        score = mean(faithfulness, answer_relevancy).
        Returns score=0.5 (neutral) if no LLM backend is available.
        Returns score=0.5 with ``error`` set if the evaluation raises or if
        either metric comes back as NaN.
    """
    t0 = time.perf_counter()

    backend = _detect_llm_backend()
    if backend is None:
        logger.warning(
            "No LLM backend available for RAGAS. "
            "Set OPENAI_API_KEY or start Ollama (ollama serve). "
            "Returning neutral score (0.5)."
        )
        return EvalResult(
            module_name="ragas",
            score=0.5,
            details={
                "backend": None,
                "faithfulness": None,
                "answer_relevancy": None,
                "note": "No LLM backend — set OPENAI_API_KEY or start Ollama",
            },
            latency_ms=int((time.perf_counter() - t0) * 1000),
        )

    try:
        import math

        from datasets import Dataset
        from ragas import evaluate
        from ragas.metrics import faithfulness, answer_relevancy

        llm = _build_ragas_llm(backend)
        embeddings = _build_ragas_embeddings(backend)

        # Configure metrics to use our chosen backend.
        # NOTE(review): these metric objects are imported from ragas.metrics,
        # so the assignments below persist beyond this call — confirm that is
        # acceptable if backends can change at runtime.
        faithfulness.llm = llm
        faithfulness.embeddings = embeddings
        answer_relevancy.llm = llm
        answer_relevancy.embeddings = embeddings

        contexts = context_docs[:max_contexts]
        dataset = Dataset.from_dict(
            {
                "question": [question],
                "answer": [answer],
                "contexts": [contexts],
            }
        )

        result = evaluate(dataset, metrics=[faithfulness, answer_relevancy])

        faith_score = float(result["faithfulness"])
        relevancy_score = float(result["answer_relevancy"])

        # A NaN metric must not reach EvalResult: its clip
        # max(0.0, min(1.0, nan)) evaluates to 1.0, which would turn a failed
        # metric into a perfect score. Route it through the failure path.
        if math.isnan(faith_score) or math.isnan(relevancy_score):
            raise ValueError(
                "RAGAS returned NaN "
                f"(faithfulness={faith_score}, answer_relevancy={relevancy_score})"
            )

        composite = (faith_score + relevancy_score) / 2.0

        details = {
            "backend": backend,
            "faithfulness": round(faith_score, 4),
            "answer_relevancy": round(relevancy_score, 4),
        }

        latency_ms = int((time.perf_counter() - t0) * 1000)
        logger.info(
            "RAGAS: faith=%.3f, relevancy=%.3f → composite=%.3f in %d ms",
            faith_score, relevancy_score, composite, latency_ms,
        )
        return EvalResult(
            module_name="ragas",
            score=composite,
            details=details,
            latency_ms=latency_ms,
        )

    except Exception as exc:
        # Best-effort module: any failure degrades to a neutral score with the
        # error recorded, so the rest of the pipeline can still run.
        logger.error("RAGAS evaluation failed: %s", exc)
        return EvalResult(
            module_name="ragas",
            score=0.5,
            details={"backend": backend, "error": str(exc)},
            error=str(exc),
            latency_ms=int((time.perf_counter() - t0) * 1000),
        )
src/modules/__init__.py ADDED
@@ -0,0 +1,127 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ src/modules/__init__.py — Shared EvalResult dataclass (re-exported by src/modules/base.py).
3
+ Used as the standard output schema by all 4 evaluation modules.
4
+ Details shape per module is fully specified here (SRS Section 5).
5
+ """
6
+ from __future__ import annotations
7
+
8
+ import logging
9
+ from dataclasses import dataclass, field
10
+ from typing import Any, Optional
11
+
12
logger = logging.getLogger(__name__)


@dataclass
class EvalResult:
    """
    Shared output schema for all evaluation modules.

    Attributes:
        module_name : Identifier string, e.g. "faithfulness"
        score : Module score in [0.0, 1.0] — clipped automatically;
            a NaN score is replaced by 0.0
        details : Module-specific dict (see DETAILS SHAPES below)
        error : None if successful; error message string if module failed
        latency_ms : Wall-clock milliseconds for this module's execution
    """

    module_name: str
    score: float
    details: dict[str, Any] = field(default_factory=dict)
    error: Optional[str] = None
    latency_ms: int = 0

    def __post_init__(self) -> None:
        """Clip score to [0.0, 1.0] as required by SRS 4.2."""
        # NaN is the only float that is not equal to itself. It has to be
        # handled before clipping: max(0.0, min(1.0, nan)) evaluates to 1.0,
        # which would silently turn an undefined score into a perfect one.
        # 0.0 is used instead so an undefined score never looks like a pass.
        if self.score != self.score:
            logger.warning(
                "%s: score is NaN, treating as 0.0.",
                self.module_name,
            )
            self.score = 0.0
        elif not (0.0 <= self.score <= 1.0):
            logger.warning(
                "%s: score %.4f out of [0,1], clipping.",
                self.module_name,
                self.score,
            )
            self.score = max(0.0, min(1.0, self.score))
43
+
44
+ # -------------------------------------------------------------------------
45
+ # DETAILS SHAPE REFERENCE (SRS Section 5)
46
+ # -------------------------------------------------------------------------
47
+ #
48
+ # faithfulness.details:
49
+ # {
50
+ # "total_claims": int,
51
+ # "entailed_count": int,
52
+ # "neutral_count": int,
53
+ # "contradicted_count": int,
54
+ # "claims": [
55
+ # {
56
+ # "claim": str,
57
+ # "status": "ENTAILED" | "NEUTRAL" | "CONTRADICTED",
58
+ # "best_chunk_id": str, # chunk with highest NLI score
59
+ # "nli_score": float
60
+ # }
61
+ # ]
62
+ # }
63
+ #
64
+ # entity_verifier.details:
65
+ # {
66
+ # "total_entities": int,
67
+ # "verified_count": int,
68
+ # "flagged_count": int,
69
+ # "entities": [
70
+ # {
71
+ # "entity": str,
72
+ # "type": "DRUG" | "DOSAGE" | "CONDITION" | "PROCEDURE",
73
+ # "status": "VERIFIED" | "FLAGGED" | "NOT_FOUND",
74
+ # "severity": "CRITICAL" | "MODERATE" | "MINOR" | null,
75
+ # "answer_value": str,
76
+ # "context_value": str | null,
77
+ # "rxcui": str | null
78
+ # }
79
+ # ]
80
+ # }
81
+ #
82
+ # source_credibility.details:
83
+ # {
84
+ # "method_used": "keyword" | "metadata",
85
+ # "chunks": [
86
+ # {
87
+ # "chunk_id": str,
88
+ # "tier": int, # 1–5
89
+ # "tier_weight": float,
90
+ # "pub_type": str,
91
+ # "title": str,
92
+ # "matched_keyword": str | null
93
+ # }
94
+ # ]
95
+ # }
96
+ #
97
+ # contradiction.details:
98
+ # {
99
+ # "total_sentences": int,
100
+ # "checked_pairs": int,
101
+ # "contradicted_pairs": int,
102
+ # "pairs": [
103
+ # {
104
+ # "sentence_a": str,
105
+ # "sentence_b": str,
106
+ # "contradiction_score": float,
107
+ # "flagged": bool
108
+ # }
109
+ # ]
110
+ # }
111
+ #
112
+ # aggregator.details:
113
+ # {
114
+ # "weights_used": {
115
+ # "faithfulness": float,
116
+ # "entity_accuracy": float,
117
+ # "source_credibility": float,
118
+ # "contradiction_risk": float
119
+ # },
120
+ # "weighted_composite": float,
121
+ # "component_contributions": {
122
+ # "faithfulness_contribution": float,
123
+ # "entity_contribution": float,
124
+ # "source_contribution": float,
125
+ # "contradiction_contribution": float
126
+ # }
127
+ # }
src/modules/base.py ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ """src/modules/base.py — see __init__.py of this package for EvalResult."""
2
+ from src.modules import EvalResult # re-export for convenience
3
+
4
+ __all__ = ["EvalResult"]
src/modules/contradiction.py ADDED
@@ -0,0 +1,259 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ FR-17: src/modules/contradiction.py — Module 4: Cross-Document Contradiction Detection
3
+ ========================================================================================
4
+ Uses the NLI cross-encoder shared with faithfulness.py (falling back to
5
+ cross-encoder/nli-deberta-v3-small when that module cannot be imported) to detect contradictions between the LLM answer and retrieved context passages.
6
+
7
+ Algorithm (SRS Section 6.4):
8
+ 1. Split answer into sentences (claims)
9
+ 2. Split each context chunk into sentences
10
+ 3. For each (answer_sentence, context_sentence) pair:
11
+ - Run NLI → get contradiction score
12
+ - If contradiction_score ≥ CONTRADICTION_THRESHOLD → flag pair
13
+ 4. score = 1.0 - (flagged_pairs / total_pairs)
14
+
15
+ This module shares the NLI model instance with faithfulness.py when both
16
+ run in the same process (the model is cached at the faithfulness module level).
17
+
18
+ Design note:
19
+ To keep latency manageable, context sentences are limited to
20
+ MAX_CONTEXT_SENTS per chunk and total pairs are capped at MAX_PAIRS.
21
+ """
22
+ from __future__ import annotations
23
+
24
+ import logging
25
+ import time
26
+
27
+ import pysbd
28
+
29
+ from src.modules.base import EvalResult
30
+
31
+ logger = logging.getLogger(__name__)
32
+
33
+ # ---------------------------------------------------------------------------
34
+ # Constants
35
+ # ---------------------------------------------------------------------------
36
+
37
+ CONTRADICTION_THRESHOLD = 0.50 # Balanced: catches real contradictions without over-flagging
38
+ MIN_KEYWORD_OVERLAP = 3 # At least 3 meaningful words in common before running NLI
39
+ MAX_CONTEXT_SENTS = 4 # top N sentences per context chunk
40
+ MAX_PAIRS = 200 # hard cap to keep latency bounded (~2-3s)
41
+
42
+ _segmenter = None
43
+
44
+ # Common stopwords to ignore in overlap check
45
_STOPWORDS = {
    "the", "a", "an", "is", "in", "of", "to", "for", "and", "or", "are",
    "be", "at", "by", "if", "it", "as", "on", "with", "this", "that",
    "was", "were", "not", "no", "have", "has", "had", "but", "so", "from",
    "should", "may", "can", "will", "than", "more", "when", "which", "who",
    "what", "all", "each", "after", "before", "been", "do", "does", "1",
    "2", "3", "mg", "iv", "od", "per", "day", "based", "using", "include",
}

# Punctuation stripped from both ends of a whitespace-delimited token so that
# "levels." and "levels," compare equal to "levels".
_EDGE_PUNCTUATION = ".,;:!?()[]{}\"'"


def _get_segmenter():
    """Lazily create and cache the pysbd sentence segmenter.

    Returns:
        A ``pysbd.Segmenter`` instance, or the string ``"stub"`` when pysbd
        cannot be imported or initialised; callers treat ``"stub"`` as
        "use naive period splitting".
    """
    # NOTE(review): this module also imports pysbd at top level, so the
    # ImportError branch below can only trigger if that import is guarded
    # or removed.
    global _segmenter
    if _segmenter is None:
        try:
            import pysbd
            _segmenter = pysbd.Segmenter(language="en", clean=False)
        except ImportError:
            logger.warning("pysbd not installed, falling back to naive sentence splitting.")
            _segmenter = "stub"  # Use a string to indicate stub mode
        except Exception as e:
            logger.error("Failed to initialize pysbd segmenter: %s", e)
            _segmenter = "stub"
    return _segmenter


def _content_tokens(sentence: str) -> set[str]:
    """Return the lowercase content words of ``sentence``.

    Tokens are split on whitespace, stripped of leading/trailing punctuation,
    and kept only if longer than two characters and not in ``_STOPWORDS``.
    """
    tokens = set()
    for raw in sentence.split():
        word = raw.strip(_EDGE_PUNCTUATION).lower()
        if len(word) > 2 and word not in _STOPWORDS:
            tokens.add(word)
    return tokens


def _keyword_overlap(sent_a: str, sent_b: str) -> int:
    """Count shared content words between two sentences.

    Punctuation attached to a word (e.g. a trailing period or comma) is
    ignored, so the same word matches regardless of where it sits in the
    sentence.

    Args:
        sent_a: First sentence.
        sent_b: Second sentence.

    Returns:
        Number of distinct content words present in both sentences.
    """
    return len(_content_tokens(sent_a) & _content_tokens(sent_b))
76
+
77
+
78
def _segment(text: str) -> list[str]:
    """Split ``text`` into stripped, non-empty sentences.

    Uses the cached pysbd segmenter when available. If the segmenter is the
    ``"stub"`` sentinel, or segmentation raises, the text is split on ``"."``
    instead.
    """
    def _split_on_periods() -> list[str]:
        stripped = (piece.strip() for piece in text.split("."))
        return [piece for piece in stripped if piece]

    segmenter = _get_segmenter()
    try:
        if segmenter != "stub":
            stripped = (piece.strip() for piece in segmenter.segment(text))
            return [piece for piece in stripped if piece]
        return _split_on_periods()
    except Exception:
        # Any segmentation failure degrades to the naive splitter.
        return _split_on_periods()
88
+
89
+
90
+ # ---------------------------------------------------------------------------
91
+ # Public API
92
+ # ---------------------------------------------------------------------------
93
+
94
def score_contradiction(
    answer: str,
    context_docs: list[str],
    max_chunks: int = 5,
) -> EvalResult:
    """
    Detect contradictions between the LLM answer and retrieved context.

    Answer sentences are paired with context sentences that share at least
    MIN_KEYWORD_OVERLAP content words; each pair is scored by the NLI model
    and flagged when its contradiction probability reaches
    CONTRADICTION_THRESHOLD.

    Args:
        answer       : LLM-generated answer text.
        context_docs : List of retrieved context passage strings.
        max_chunks   : Max number of context chunks to evaluate.

    Returns:
        EvalResult with module_name="contradiction", score in [0,1] where
        1.0 = no contradictions detected, 0.0 = all pairs contradicted.
        A score of 0.5 is returned when nothing could be checked (missing
        input, no sentences, or no topically overlapping pairs). Model
        load/inference failures return score 1.0 with ``error`` set.
    """
    t0 = time.perf_counter()

    def _elapsed_ms() -> int:
        return int((time.perf_counter() - t0) * 1000)

    def _neutral(total_sentences: int) -> EvalResult:
        # 0.5 = "could not verify", as opposed to 1.0 = "verified clean".
        return EvalResult(
            module_name="contradiction",
            score=0.5,
            details={
                "total_sentences": total_sentences,
                "checked_pairs": 0,
                "contradicted_pairs": 0,
                "pairs": [],
            },
            latency_ms=_elapsed_ms(),
        )

    def _failure(message: str) -> EvalResult:
        return EvalResult(
            module_name="contradiction",
            score=1.0,
            details={},
            error=message,
            latency_ms=_elapsed_ms(),
        )

    if not answer or not context_docs:
        return _neutral(0)

    # Import model via faithfulness module (shared cache)
    try:
        from src.modules.faithfulness import _get_model, LABEL_CONTRADICTION
    except ImportError:
        # (Lazy imports to prevent startup crashes when libraries aren't installed yet)
        try:
            from sentence_transformers import CrossEncoder
            fallback_model = CrossEncoder("cross-encoder/nli-deberta-v3-small")

            def _get_model():
                return fallback_model

            # Contradiction index used for the fallback model.
            LABEL_CONTRADICTION = 0
        except ImportError:
            logger.error("sentence-transformers not installed. Cannot run NLI model.")
            return _failure("NLI model (sentence-transformers) not installed.")
        except Exception as exc:
            logger.error("Failed to load NLI model: %s", exc)
            return _failure(f"Failed to load NLI model: {exc}")

    # Loading can fail (e.g. weights unavailable); report it instead of crashing.
    try:
        model = _get_model()
    except Exception as exc:
        logger.error("Failed to load NLI model: %s", exc)
        return _failure(f"Failed to load NLI model: {exc}")

    # faithfulness._get_model() returns the string "stub" when
    # sentence-transformers is missing; there is nothing to run in that case.
    if isinstance(model, str):
        logger.error("sentence-transformers not installed. Cannot run NLI model.")
        return _failure("NLI model (sentence-transformers) not installed.")

    # Strip markdown/citations from answer before NLI (same reason as faithfulness.py)
    import re as _re
    _MD = _re.compile(
        r'\[Source:[^\]]*\]|\[[^\]]{0,120}\]'   # citations
        r'|\*\*([^*]+)\*\*|\*([^*]+)\*'         # bold/italic → keep text
        r'|`[^`]+`'                              # code
    )
    answer = _MD.sub(lambda m: (m.group(1) or m.group(2) or ''), answer).strip()

    # Segment answer into claims
    answer_sents = _segment(answer)
    if not answer_sents:
        return _neutral(0)

    # Segment context chunks, keeping only the first MAX_CONTEXT_SENTS of each
    context_sents: list[str] = []
    for doc in context_docs[:max_chunks]:
        context_sents.extend(_segment(doc)[:MAX_CONTEXT_SENTS])

    if not context_sents:
        return _neutral(len(answer_sents))

    # Build pairs WITH topical pre-filter (skip unrelated sentence pairs entirely)
    all_pairs: list[tuple[str, str]] = []
    for a_sent in answer_sents:
        for c_sent in context_sents:
            if _keyword_overlap(a_sent, c_sent) >= MIN_KEYWORD_OVERLAP:
                all_pairs.append((a_sent, c_sent))
            if len(all_pairs) >= MAX_PAIRS:
                break
        if len(all_pairs) >= MAX_PAIRS:
            break

    if not all_pairs:
        # Topically unrelated — cannot check for contradictions
        return _neutral(len(answer_sents))

    # Batch NLI inference
    try:
        scores_raw = model.predict(all_pairs, apply_softmax=True)
    except Exception as exc:
        logger.error("Contradiction NLI inference failed: %s", exc)
        return _failure(f"Model inference error: {exc}")

    # Collect flagged pairs
    flagged_pairs: list[dict] = []
    total = len(all_pairs)

    for (a_sent, c_sent), score_vec in zip(all_pairs, scores_raw):
        con_score = float(score_vec[LABEL_CONTRADICTION])
        if con_score >= CONTRADICTION_THRESHOLD:
            flagged_pairs.append(
                {
                    "sentence_a": a_sent[:120],
                    "sentence_b": c_sent[:120],
                    "contradiction_score": round(con_score, 4),
                    "flagged": True,
                }
            )

    contradicted = len(flagged_pairs)

    # Score: 1.0 = clean, lower = more contradictions found
    score = 1.0 - (contradicted / total) if total > 0 else 1.0

    # Report the most severe contradictions first so the 20-item cap keeps
    # the strongest evidence rather than whichever pairs were built first.
    flagged_pairs.sort(key=lambda pair: pair["contradiction_score"], reverse=True)

    details = {
        "total_sentences": len(answer_sents),
        "checked_pairs": total,
        "contradicted_pairs": contradicted,
        "pairs": flagged_pairs[:20],  # cap output to top 20 flagged pairs
    }

    latency_ms = _elapsed_ms()
    logger.info(
        "Contradiction: %.3f (%d/%d pairs flagged) in %d ms",
        score, contradicted, total, latency_ms,
    )
    return EvalResult(
        module_name="contradiction",
        score=score,
        details=details,
        latency_ms=latency_ms,
    )
src/modules/entity_verifier.py ADDED
@@ -0,0 +1,334 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ FR-09: src/modules/entity_verifier.py — Module 2: Medical Entity Verification
3
+ ==============================================================================
4
+ Uses SciSpaCy NER (en_ner_bc5cdr_md, see NER_MODEL) to extract medical entities from the answer,
5
+ then verifies drug entities against the RxNorm cache and/or REST API.
6
+
7
+ Verification pipeline (SRS Section 6.2):
8
+ 1. NER: extract DRUG, DOSAGE, CONDITION, PROCEDURE entities from answer
9
+ 2. For each DRUG entity:
10
+ a. Look up in local rxnorm_cache.csv (fast, offline)
11
+ b. If not found, query RxNorm REST API /approximateTerm (live fallback)
12
+ c. If still not found, mark as NOT_FOUND
13
+ 3. Cross-check entity presence in context docs (optional validation)
14
+ 4. Score = verified_drug_count / total_drug_count (non-drug entities have no score impact)
15
+
16
+ Entity status values:
17
+ VERIFIED — drug found in RxNorm cache or API with rxcui
18
+ FLAGGED — entity found but has a known dangerous synonym conflict
19
+ NOT_FOUND — drug name not resolvable via any layer
20
+
21
+ Severity mapping (for FLAGGED):
22
+ brand ↔ generic mismatch → CRITICAL
23
+ dosage discrepancy → MODERATE
24
+ minor synonym variant → MINOR
25
+ """
26
+ from __future__ import annotations
27
+
28
+ import logging
29
+ import re
30
+ import time
31
+ from functools import lru_cache
32
+ from pathlib import Path
33
+ from typing import Optional
34
+
35
+ import pandas as pd
36
+ import requests
37
+
38
+ from src.modules.base import EvalResult
39
+
40
+ logger = logging.getLogger(__name__)
41
+
42
+ # ---------------------------------------------------------------------------
43
+ # Constants
44
+ # ---------------------------------------------------------------------------
45
+
46
+ RXNORM_APPROX_URL = "https://rxnav.nlm.nih.gov/REST/approximateTerm.json"
47
+ DEFAULT_CACHE_PATH = "data/rxnorm_cache.csv"
48
+ NER_MODEL = "en_ner_bc5cdr_md"
49
+ DOSAGE_TOLERANCE_PCT = 10 # flag if answer dose differs from context dose by > 10%
50
+
51
+ # Matches clinical dose values: "500 mg", "2.5 mcg/kg", "10 IU", etc.
52
+ _DOSE_RE = re.compile(
53
+ r'(\d+(?:\.\d+)?)\s*(?:mg|mcg|g\b|ml|iu|units?|mg/kg|mg/dl)',
54
+ re.IGNORECASE,
55
+ )
56
+
57
+ # Map spacy entity labels to our schema types
58
+ _ENTITY_TYPE_MAP = {
59
+ # en_core_sci_lg (CRAFT corpus) labels
60
+ "CHEBI": "DRUG", # Chemical Entities of Biological Interest — covers drugs
61
+ "GGP": "CONDITION", # Gene or Gene Product
62
+ "SO": "CONDITION", # Sequence Ontology
63
+ "TAXON": "CONDITION",
64
+ "GO": "CONDITION", # Gene Ontology
65
+ "CL": "CONDITION", # Cell Line
66
+ "DNA": "CONDITION",
67
+ "RNA": "CONDITION",
68
+ "CELL_TYPE": "CONDITION",
69
+ "CELL_LINE": "CONDITION",
70
+ "PROTEIN": "CONDITION",
71
+ # BC5CDR labels (used by some scispacy models)
72
+ "Chemical": "DRUG",
73
+ "Disease": "CONDITION",
74
+ # Generic / fallback labels
75
+ "CHEMICAL": "DRUG",
76
+ "DRUG": "DRUG",
77
+ "COMPOUND": "DRUG",
78
+ "DISEASE": "CONDITION",
79
+ "SYMPTOM": "CONDITION",
80
+ "PROCEDURE": "PROCEDURE",
81
+ "DOSAGE": "DOSAGE",
82
+ }
83
+ DRUG_TYPES = {"DRUG"} # only these get verified against RxNorm
84
+
85
+ # ---------------------------------------------------------------------------
86
+ # Module-level caches
87
+ # ---------------------------------------------------------------------------
88
+
89
+ _spacy_model = None
90
+ _rxnorm_cache: dict[str, str] | None = None # drug_name -> rxcui
91
+ _rxnorm_cache_path: str = DEFAULT_CACHE_PATH
92
+
93
+
94
def _get_spacy_model():
    """Return the cached SciSpaCy NER pipeline, loading it on first use.

    Raises:
        OSError: re-raised from ``spacy.load`` after logging install
            instructions when the model named by NER_MODEL cannot be loaded.
    """
    global _spacy_model
    if _spacy_model is not None:
        return _spacy_model

    import spacy

    logger.info("Loading SciSpaCy NER model: %s (first call only)", NER_MODEL)
    try:
        _spacy_model = spacy.load(NER_MODEL)
        logger.info("SciSpaCy model loaded.")
    except OSError as exc:
        logger.error(
            "Failed to load '%s'. Install with: "
            "pip install https://s3-us-west-2.amazonaws.com/ai2-s2-scispacy/"
            "releases/v0.5.4/en_ner_bc5cdr_md-0.5.4.tar.gz\nError: %s",
            NER_MODEL, exc,
        )
        raise
    return _spacy_model
111
+
112
+
113
def _load_rxnorm_cache(cache_path: str) -> dict[str, str]:
    """Read the RxNorm cache CSV into a ``{lowercase drug_name: rxcui}`` dict.

    Rows with a missing drug name, a missing rxcui, or a blank rxcui are
    skipped. A missing file or any read/parse failure is logged and yields
    an empty dict.
    """
    csv_file = Path(cache_path)
    if not csv_file.exists():
        logger.warning(
            "RxNorm cache not found at '%s'. Live API only will be used.", cache_path
        )
        return {}

    try:
        frame = pd.read_csv(csv_file, dtype=str)
        lookup: dict[str, str] = {}
        for _, record in frame.iterrows():
            name = record.get("drug_name")
            code = record.get("rxcui")
            if pd.isna(name) or pd.isna(code):
                continue
            code_text = str(code).strip()
            if not code_text:
                continue
            # Later rows overwrite earlier ones for the same drug name.
            lookup[str(name).strip().lower()] = code_text
        logger.info("RxNorm cache loaded: %d entries from %s", len(lookup), cache_path)
        return lookup
    except Exception as exc:
        logger.warning("Failed to load RxNorm cache: %s", exc)
        return {}
134
+
135
+
136
def _get_rxnorm_cache(cache_path: str) -> dict[str, str]:
    """Return the module-level RxNorm cache, (re)loading it when needed.

    The cache is loaded on first use and reloaded whenever ``cache_path``
    differs from the path it was last loaded from.
    """
    global _rxnorm_cache, _rxnorm_cache_path
    needs_load = _rxnorm_cache is None or cache_path != _rxnorm_cache_path
    if needs_load:
        _rxnorm_cache_path = cache_path
        _rxnorm_cache = _load_rxnorm_cache(cache_path)
    return _rxnorm_cache
142
+
143
+
144
def _extract_doses_near(text: str, drug_name: str, window: int = 180) -> list[float]:
    """Return the numeric dose values that appear close to ``drug_name``.

    Only the first case-insensitive occurrence of ``drug_name`` is used. The
    search region runs from ``window // 2`` characters before that
    occurrence to ``window`` characters after its end. Returns an empty list
    when the drug name is not found.
    """
    position = text.lower().find(drug_name.lower())
    if position == -1:
        return []

    start = max(0, position - window // 2)
    end = position + len(drug_name) + window
    doses: list[float] = []
    for match in _DOSE_RE.finditer(text[start:end]):
        doses.append(float(match.group(1)))
    return doses
151
+
152
+
153
def _lookup_rxnorm_api(drug_name: str, timeout: int = 4) -> Optional[str]:
    """Query the RxNorm approximate-term REST endpoint for ``drug_name``.

    This is a best-effort lookup: it never raises. Network errors and
    unexpected response shapes are logged and reported as "not found".

    Args:
        drug_name: Drug name to resolve.
        timeout: Request timeout in seconds passed to ``requests.get``.

    Returns:
        The rxcui of the first candidate as a string, or None when the
        request fails, the status is not 200, or no usable candidate exists.
    """
    try:
        resp = requests.get(
            RXNORM_APPROX_URL,
            params={"term": drug_name, "maxEntries": "1", "option": "1"},
            timeout=timeout,
        )
        if resp.status_code != 200:
            logger.debug(
                "RxNorm API returned HTTP %s for '%s'", resp.status_code, drug_name
            )
            return None
        candidates = (
            resp.json()
            .get("approximateGroup", {})
            .get("candidate", [])
        )
        if candidates:
            return str(candidates[0].get("rxcui", "")).strip() or None
    except requests.RequestException as exc:
        # Timeouts, connection errors, invalid JSON bodies, etc.
        logger.warning("RxNorm API request failed for '%s': %s", drug_name, exc)
    except Exception as exc:
        # Unexpected payload shape; keep the lookup non-fatal but visible.
        logger.warning("Unexpected RxNorm API response for '%s': %s", drug_name, exc)
    return None
173
+
174
+
175
+ # ---------------------------------------------------------------------------
176
+ # Public API
177
+ # ---------------------------------------------------------------------------
178
+
179
def verify_entities(
    answer: str,
    question: str = "",
    context_docs: list[str] | None = None,
    rxnorm_cache_path: str = DEFAULT_CACHE_PATH,
    use_api_fallback: bool = True,
) -> EvalResult:
    """
    Extract and verify medical entities from the LLM answer.

    Drug entities are resolved against the local RxNorm cache, then the
    RxNorm REST API (when enabled). A resolved drug whose nearby dose in the
    answer differs from the closest nearby dose in the context by more than
    DOSAGE_TOLERANCE_PCT percent is FLAGGED instead of VERIFIED.
    CONDITION/PROCEDURE entities are only checked for presence in the
    context and do not affect the score.

    Args:
        answer            : LLM-generated answer text.
        question          : Original question (NER'd alongside answer for richer entity set).
        context_docs      : Retrieved context passages (used for cross-checking).
        rxnorm_cache_path : Path to rxnorm_cache.csv.
        use_api_fallback  : Whether to call RxNorm REST API for cache misses.

    Returns:
        EvalResult with module_name="entity_verifier" and score in [0,1]:
        verified drugs / total drugs, or 0.5 when there are no entities, no
        drug entities, or NER could not be run (``error`` is set in that
        last case).
    """
    t0 = time.perf_counter()

    def _elapsed_ms() -> int:
        return int((time.perf_counter() - t0) * 1000)

    def _ner_failure(prefix: str, exc: Exception) -> EvalResult:
        # Neutral fallback — don't penalise the answer when NER cannot run.
        # The count keys keep `details` in the same shape as every other path.
        return EvalResult(
            module_name="entity_verifier",
            score=0.5,
            details={
                "error": str(exc),
                "total_entities": 0,
                "verified_count": 0,
                "flagged_count": 0,
                "entities": [],
            },
            error=f"{prefix}: {exc}",
            latency_ms=_elapsed_ms(),
        )

    # --- NER -----------------------------------------------------------------
    try:
        nlp = _get_spacy_model()
    except Exception as exc:
        return _ner_failure("NER model unavailable", exc)

    # Combine question + answer for richer entity extraction
    combined_text = f"{question} {answer}" if question else answer
    try:
        doc = nlp(combined_text)
    except Exception as exc:
        # Inference failures are handled like load failures instead of crashing.
        logger.error("NER inference failed: %s", exc)
        return _ner_failure("NER inference failed", exc)

    # Collect entities with deduplication (case-insensitive, first mention wins)
    seen: set[str] = set()
    raw_entities: list[tuple[str, str]] = []   # (text, type)
    for ent in doc.ents:
        key = ent.text.strip().lower()
        if not key or key in seen:
            continue
        seen.add(key)
        entity_type = _ENTITY_TYPE_MAP.get(ent.label_, "CONDITION")
        raw_entities.append((ent.text.strip(), entity_type))

    if not raw_entities:
        return EvalResult(
            module_name="entity_verifier",
            score=0.5,   # neutral — cannot verify what isn't there
            details={"total_entities": 0, "verified_count": 0, "flagged_count": 0, "entities": []},
            latency_ms=_elapsed_ms(),
        )

    # --- RxNorm verification for DRUG entities -------------------------------
    cache = _get_rxnorm_cache(rxnorm_cache_path)
    context_text = " ".join(context_docs or []).lower()

    entity_results: list[dict] = []
    drug_total = 0
    drug_verified = 0
    drug_flagged = 0

    for entity_text, entity_type in raw_entities:
        result = {
            "entity": entity_text,
            "type": entity_type,
            "status": "NOT_FOUND",
            "severity": None,
            "answer_value": entity_text,
            "context_value": None,
            "rxcui": None,
        }

        if entity_type in DRUG_TYPES:
            drug_total += 1
            key = entity_text.lower()

            # Layer 1: Local cache lookup
            rxcui = cache.get(key)

            # Layer 2: API fallback
            if not rxcui and use_api_fallback:
                rxcui = _lookup_rxnorm_api(entity_text)

            if rxcui:
                result["rxcui"] = rxcui

                # Check for dosage discrepancy before marking VERIFIED
                answer_doses = _extract_doses_near(answer, entity_text)
                context_doses = _extract_doses_near(context_text, entity_text)
                flagged_dose = False
                if answer_doses and context_doses:
                    a_dose = answer_doses[0]
                    # Compare against the context dose closest to the answer's.
                    c_dose = min(context_doses, key=lambda d: abs(d - a_dose))
                    pct_diff = abs(a_dose - c_dose) / max(c_dose, 1e-9) * 100
                    if pct_diff > DOSAGE_TOLERANCE_PCT:
                        result["status"] = "FLAGGED"
                        result["severity"] = "MODERATE"
                        result["answer_value"] = f"{a_dose} (answer)"
                        result["context_value"] = f"{c_dose} (context, Δ{pct_diff:.0f}%)"
                        drug_flagged += 1
                        flagged_dose = True
                        logger.warning(
                            "Dosage discrepancy for '%s': answer=%.1f context=%.1f (%.0f%%)",
                            entity_text, a_dose, c_dose, pct_diff,
                        )

                if not flagged_dose:
                    result["status"] = "VERIFIED"
                    drug_verified += 1
                    if key in context_text:
                        result["context_value"] = entity_text
            else:
                result["status"] = "NOT_FOUND"

        elif entity_type in ("CONDITION", "PROCEDURE"):
            # Non-drug entities: check presence in context only
            if entity_text.lower() in context_text:
                result["status"] = "VERIFIED"
                result["context_value"] = entity_text
            else:
                result["status"] = "NOT_FOUND"

        entity_results.append(result)

    # --- Score ---------------------------------------------------------------
    # Score is based on drug entities only (per SRS Section 6.2)
    if drug_total == 0:
        score = 0.5   # neutral — no drug entities to verify
    else:
        score = drug_verified / drug_total

    details = {
        "total_entities": len(raw_entities),
        "drug_total": drug_total,
        "verified_count": drug_verified,
        "flagged_count": drug_flagged,
        "entities": entity_results,
    }

    latency_ms = _elapsed_ms()
    logger.info(
        "Entity verification: %.3f (%d/%d drugs verified) in %d ms",
        score, drug_verified, drug_total, latency_ms,
    )
    return EvalResult(
        module_name="entity_verifier",
        score=score,
        details=details,
        latency_ms=latency_ms,
    )
src/modules/faithfulness.py ADDED
@@ -0,0 +1,302 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ FR-05: src/modules/faithfulness.py — Module 1: Faithfulness Scoring
3
+ =====================================================================
4
+ Uses the BioLinkBERT MedNLI cross-encoder (cnut1648/biolinkbert-mednli, see MODEL_NAME) to score how well the LLM answer
5
+ is entailed by the retrieved context chunks.
6
+
7
+ Architecture:
8
+ 1. Split answer into individual claims (sentences via pysbd)
9
+ 2. For each claim: compute NLI score against every context chunk
10
+ 3. Assign claim status: ENTAILED / NEUTRAL / CONTRADICTED
11
+ 4. score = max(0, (entailed_count - contradicted_count) / total_claims)
12
+
13
+ Thresholds (SRS Section 6.1):
14
+ entailment ≥ 0.50 → ENTAILED
15
+ contradiction ≥ 0.30 → CONTRADICTED
16
+ otherwise → NEUTRAL
17
+
18
+ Model loaded lazily and cached at module level (avoids double-loading
19
+ when called multiple times in same process).
20
+ """
21
+ from __future__ import annotations
22
+
23
+ import logging
24
+ import time
25
+ from functools import lru_cache
26
+ from typing import TYPE_CHECKING
27
+
28
+ from src.modules.base import EvalResult
29
+
30
+ if TYPE_CHECKING:
31
+ pass
32
+
33
+ logger = logging.getLogger(__name__)
34
+
35
+ # ---------------------------------------------------------------------------
36
+ # Constants
37
+ # ---------------------------------------------------------------------------
38
+
39
+ # BioLinkBERT fine-tuned on MedNLI (clinical notes, MIMIC-III)
40
+ # Paper 15 (Chen et al. SemEval-2023): best single model for biomedical NLI (F1=0.765)
41
+ # Faster on CPU than DeBERTa-large (BERT-base architecture)
42
+ MODEL_NAME = "cnut1648/biolinkbert-mednli"
43
+
44
+ # MedNLI label order (verified): {0: entailment, 1: neutral, 2: contradiction}
45
+ LABEL_ENTAILMENT = 0
46
+ LABEL_NEUTRAL = 1
47
+ LABEL_CONTRADICTION = 2
48
+
49
+ ENTAILMENT_THRESHOLD = 0.50
50
+ CONTRADICTION_THRESHOLD = 0.30
51
+
52
+ # ---------------------------------------------------------------------------
53
+ # Lazy model loader
54
+ # ---------------------------------------------------------------------------
55
+
56
+ _model = None
57
+ _segmenter = None
58
+
59
+
60
def _get_model():
    """Return the cached NLI cross-encoder, loading it on first use.

    When ``sentence_transformers`` cannot be imported the string ``"stub"``
    is cached and returned instead of a model.
    """
    global _model
    if _model is not None:
        return _model

    try:
        from sentence_transformers import CrossEncoder
        logger.info("Loading NLI model: %s (first call only)", MODEL_NAME)
        _model = CrossEncoder(MODEL_NAME)
        logger.info("NLI model loaded.")
    except ImportError:
        logger.error("sentence_transformers not installed. Faithfulness will be stubbed.")
        _model = "stub"
    return _model
72
+
73
def _get_segmenter():
    """Return the cached pysbd segmenter, creating it on first use.

    When pysbd cannot be imported the string ``"stub"`` is cached and
    returned instead.
    """
    global _segmenter
    if _segmenter is not None:
        return _segmenter

    try:
        import pysbd
        _segmenter = pysbd.Segmenter(language="en", clean=False)
    except ImportError:
        _segmenter = "stub"
    return _segmenter
82
+
83
+
84
+ # ---------------------------------------------------------------------------
85
+ # Public API
86
+ # ---------------------------------------------------------------------------
87
+
88
def score_faithfulness(
    answer: str,
    context_docs: list[str],
    chunk_ids: list[str] | None = None,
    max_chunks: int = 3,
    config: dict | None = None,
) -> EvalResult:
    """
    Score the faithfulness of an answer against retrieved context documents.

    The answer is split into claims (at most 12 are scored). Claims that
    contain a clinical measurement bypass NLI and are checked by matching
    their numbers against the numbers found in the context; all other claims
    are scored by the NLI model against every context chunk.

    Args:
        answer       : The LLM-generated answer text.
        context_docs : List of context passage strings (top-k retrieved chunks).
        chunk_ids    : Optional IDs matching context_docs for traceability.
                       Missing IDs are filled in as "chunk_<index>".
        max_chunks   : Maximum context chunks to consider.
        config       : Optional config dict; ``modules.faithfulness`` may set
                       ``entailment_threshold`` and ``contradiction_threshold``.

    Returns:
        EvalResult with module_name="faithfulness" and
        score = max(0, (entailed - contradicted) / total_claims).
        Empty input and model failures return score 0.0 with ``error`` set;
        an answer with no extractable claims returns 0.5 with ``error`` set.
    """
    import re as _re  # kept function-local, as in the original

    t0 = time.perf_counter()

    def _elapsed_ms() -> int:
        return int((time.perf_counter() - t0) * 1000)

    _faith_cfg = (config or {}).get("modules", {}).get("faithfulness", {})
    entailment_threshold = _faith_cfg.get("entailment_threshold", ENTAILMENT_THRESHOLD)
    # Configurable in the same way as the entailment threshold.
    contradiction_threshold = _faith_cfg.get("contradiction_threshold", CONTRADICTION_THRESHOLD)

    if not answer or not context_docs:
        return EvalResult(
            module_name="faithfulness",
            score=0.0,
            details={"error": "Empty answer or no context provided"},
            error="Empty answer or no context",
            latency_ms=_elapsed_ms(),
        )

    # Limit context size
    docs = context_docs[:max_chunks]
    # One id per doc: pad with generated ids if the caller supplied too few,
    # so looking up the best chunk's id can never go out of range.
    ids = list((chunk_ids or [])[:len(docs)])
    ids.extend(f"chunk_{i}" for i in range(len(ids), len(docs)))

    # Strip markdown formatting from guideline/structured chunks before NLI —
    # markdown is noise for a model applied to plain prose.
    _MD_CLEAN = _re.compile(r'\[([^\]]+)\]\n|#{1,6}\s+|•\s+|\*\*([^*]+)\*\*|\*([^*]+)\*|`[^`]+`')
    docs = [_MD_CLEAN.sub(lambda m: m.group(2) or m.group(3) or '', d) for d in docs]

    # Strip inline citations and markdown from the answer before claim splitting.
    # The bullet alternative comes first and uses MULTILINE so that "* item"
    # is removed at the start of every line, and is not swallowed by the
    # *italic* alternative pairing the stars of two consecutive bullets.
    _CITE_RE = _re.compile(
        r'^[ \t]*[*•][ \t]+'       # bullet marker at the start of a line
        r'|\[Source:[^\]]*\]'      # [Source: title] or [Source: *italic title*]
        r'|\[[^\]]{0,120}\]'       # other short bracket constructs
        r'|\*\*([^*]+)\*\*'        # **bold** → keep inner text
        r'|\*([^*]+)\*'            # *italic* → keep inner text
        r'|`[^`]+`',               # `code`
        _re.MULTILINE,
    )
    answer_clean = _CITE_RE.sub(lambda m: (m.group(1) or m.group(2) or ''), answer).strip()

    # Split answer into claims
    seg = _get_segmenter()
    try:
        if seg == "stub":
            claims = [s.strip() for s in answer_clean.split(".") if s.strip()]
        else:
            claims = [s.strip() for s in seg.segment(answer_clean) if s.strip()]
    except Exception:
        claims = [s.strip() for s in answer_clean.split(".") if s.strip()]

    if not claims:
        return EvalResult(
            module_name="faithfulness",
            score=0.5,
            details={"error": "Could not extract claims from answer"},
            error="No claims extracted",
            latency_ms=_elapsed_ms(),
        )

    # Loading can fail (e.g. weights unavailable); report it instead of crashing.
    try:
        model = _get_model()
    except Exception as exc:
        logger.error("Failed to load NLI model: %s", exc)
        return EvalResult(
            module_name="faithfulness",
            score=0.0,
            details={},
            error=f"Failed to load NLI model: {exc}",
            latency_ms=_elapsed_ms(),
        )

    # Limit claims to avoid O(claims×chunks) explosion with the large model
    claims = claims[:12]

    # ---------------------------------------------------------------------------
    # Numerical Bypass (Paper 14: non-optional for clinical NLI)
    # NLI models structurally cannot verify numerical comparisons (≥6.5%, 126 mg/dL).
    # Use direct number matching for claims containing clinical measurements.
    # ---------------------------------------------------------------------------
    _NUM_PATTERN = _re.compile(
        r'[\d]+[\s]*(mg|mcg|%|mL|mmol|IU|units?|g|kg|≥|≤|>|<|±|mg/dL|mmol/L|mg/kg)',
        _re.IGNORECASE,
    )
    _NUMBER_RE = _re.compile(r'\d+(?:\.\d+)?')

    def _numerical_match(claim: str, context_chunks: list[str]) -> str:
        """
        For claims with numerical clinical values, check whether the claim's
        numbers appear in any context chunk. Returns ENTAILED or NEUTRAL.

        Numbers are compared by value as whole tokens, so "5" is not treated
        as present merely because the context contains "15" or "0.5".
        """
        nums = [float(n) for n in _NUMBER_RE.findall(claim)]
        if not nums:
            return "NEUTRAL"
        context_values = {
            float(n) for n in _NUMBER_RE.findall(" ".join(context_chunks))
        }
        matched = sum(1 for n in nums if n in context_values)
        return "ENTAILED" if matched >= len(nums) * 0.6 else "NEUTRAL"

    # Separate numerical claims (bypass NLI) from textual claims (use NLI)
    numerical_results: dict[int, str] = {}   # claim_idx → status
    nli_claim_indices: list[int] = []

    for ci, claim in enumerate(claims):
        if _NUM_PATTERN.search(claim):
            numerical_results[ci] = _numerical_match(claim, docs)
        else:
            nli_claim_indices.append(ci)

    # Build NLI pairs only for non-numerical claims
    all_pairs = []
    pair_map: list[tuple[int, int]] = []   # (nli_claim_idx, doc_idx)
    for nci, ci in enumerate(nli_claim_indices):
        for di, doc in enumerate(docs):
            all_pairs.append((doc, claims[ci]))
            pair_map.append((nci, di))

    # Batch NLI inference
    try:
        if model == "stub":
            # Model unavailable: report every claim as NEUTRAL. The vector is
            # built from the label constants so it stays neutral under the
            # model's label order.
            stub_vec = [0.1, 0.1, 0.1]
            stub_vec[LABEL_NEUTRAL] = 0.8
            scores_raw = [stub_vec for _ in all_pairs]
        else:
            scores_raw = model.predict(all_pairs, apply_softmax=True)
    except Exception as exc:
        logger.error("NLI model inference failed: %s", exc)
        return EvalResult(
            module_name="faithfulness",
            score=0.0,
            details={},
            error=f"Model inference error: {exc}",
            latency_ms=_elapsed_ms(),
        )

    # Aggregate: for each claim find the context with the highest entailment
    nli_best: dict[int, tuple[float, float, int]] = {}  # nci → (best_ent, best_con, best_doc)
    for idx, (nci, d_i) in enumerate(pair_map):
        score_vec = scores_raw[idx]
        ent_score = float(score_vec[LABEL_ENTAILMENT])
        con_score = float(score_vec[LABEL_CONTRADICTION])
        if nci not in nli_best or ent_score > nli_best[nci][0]:
            nli_best[nci] = (ent_score, con_score, d_i)

    nci_by_claim = {ci: nci for nci, ci in enumerate(nli_claim_indices)}

    claim_results: list[dict] = []
    entailed = 0
    neutral = 0
    contradicted = 0

    for ci, claim in enumerate(claims):
        if ci in numerical_results:
            # Numerical bypass — number match result
            status = numerical_results[ci]
            nli_score = 1.0 if status == "ENTAILED" else 0.0
            best_doc_idx = 0
            method = "numerical_bypass"
        else:
            # NLI result
            nci = nci_by_claim.get(ci, -1)
            best_entailment, best_contradiction, best_doc_idx = nli_best.get(nci, (0.0, 0.0, 0))
            if best_entailment >= entailment_threshold:
                status = "ENTAILED"
                nli_score = best_entailment
            elif best_contradiction >= contradiction_threshold:
                status = "CONTRADICTED"
                nli_score = best_contradiction
            else:
                status = "NEUTRAL"
                nli_score = best_entailment
            method = "nli"

        if status == "ENTAILED":
            entailed += 1
        elif status == "CONTRADICTED":
            contradicted += 1
        else:
            neutral += 1

        claim_results.append({
            "claim": claim,
            "status": status,
            # None only when no context chunk was kept (e.g. max_chunks=0).
            "best_chunk_id": ids[best_doc_idx] if best_doc_idx < len(ids) else None,
            "nli_score": round(nli_score, 4),
            "method": method,
        })

    total = len(claims)
    # Contradicted claims cancel out entailed ones; the score is floored at 0.
    score = max(0.0, (entailed - contradicted) / total) if total > 0 else 0.0

    details = {
        "total_claims": total,
        "entailed_count": entailed,
        "neutral_count": neutral,
        "contradicted_count": contradicted,
        "claims": claim_results,
    }

    latency_ms = _elapsed_ms()
    logger.info(
        "Faithfulness: %.3f (%d/%d entailed) in %d ms",
        score, entailed, total, latency_ms,
    )
    return EvalResult(
        module_name="faithfulness",
        score=score,
        details=details,
        latency_ms=latency_ms,
    )
src/modules/source_credibility.py ADDED
@@ -0,0 +1,204 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ FR-14: src/modules/source_credibility.py — Module 3: Source Credibility Scoring
3
+ =================================================================================
4
+ Scores the credibility of retrieved source documents based on their publication
5
+ type / evidence tier.
6
+
7
+ Tier weights (SRS Section 6.3):
8
+ clinical_guideline → 1.00 (Tier 1 — highest authority)
9
+ systematic_review → 0.85 (Tier 2)
10
+ research_abstract → 0.70 (Tier 3 — PubMedQA default)
11
+ review_article → 0.60 (Tier 4)
12
+ clinical_case → 0.50 (Tier 5)
13
+ unknown / other → 0.30 (fallback)
14
+
15
+ Detection:
16
+ 1. Use 'tier_type' metadata field if present (set by embedder.py)
17
+ 2. Fall back to keyword matching in pub_type / title text
18
+
19
+ Score = unweighted arithmetic mean of the per-chunk tier weights across all retrieved chunks.
20
+
21
+ Each chunk must be a dict with at minimum:
22
+ {"text": str, "metadata": {"tier_type": str, "pub_type": str, "title": str}}
23
+ or the simpler form accepted by the retriever:
24
+ {"text": str, "source": str, "tier_type": str, "title": str}
25
+ """
26
+ from __future__ import annotations
27
+
28
+ import logging
29
+ import re
30
+ import time
31
+
32
+ from src.modules.base import EvalResult
33
+
34
+ logger = logging.getLogger(__name__)
35
+
36
+ # ---------------------------------------------------------------------------
37
+ # Evidence tier weights
38
+ # ---------------------------------------------------------------------------
39
+
40
+ TIER_WEIGHTS: dict[str, float] = {
41
+ "clinical_guideline": 1.00,
42
+ "systematic_review": 0.85,
43
+ "drug_label": 0.90, # FDA-approved drug labels — authoritative regulatory source
44
+ "research_abstract": 0.70,
45
+ "review_article": 0.60,
46
+ "clinical_case": 0.50,
47
+ "unknown": 0.30,
48
+ }
49
+
50
+ # Keyword → tier_type mapping for fallback text matching
51
+ _KEYWORD_MAP: list[tuple[re.Pattern, str]] = [
52
+ (re.compile(r"\b(guideline|clinical practice|recommendation|consensus)\b", re.I), "clinical_guideline"),
53
+ (re.compile(r"\b(systematic review|meta.?analysis)\b", re.I), "systematic_review"),
54
+ # RCT / controlled trial → highest single-study evidence tier
55
+ (re.compile(r"\b(randomized|randomised|controlled trial|rct|clinical trial)\b", re.I), "clinical_guideline"),
56
+ # FDA drug labels
57
+ (re.compile(r"\b(fda|drug label|prescribing information|package insert|dailymed)\b", re.I), "drug_label"),
58
+ (re.compile(r"\b(review|overview)\b", re.I), "review_article"),
59
+ (re.compile(r"\b(case report|case study|clinical case)\b", re.I), "clinical_case"),
60
+ (re.compile(r"\b(abstract|research article|original article|journal)\b", re.I), "research_abstract"),
61
+ ]
62
+
63
+
64
+ def _classify_tier(chunk: dict) -> tuple[str, str | None]:
65
+ """
66
+ Return (tier_type, matched_keyword) for a single retrieved chunk dict.
67
+
68
+ Priority 1: explicit tier_type field (set by embedder.py)
69
+ Priority 2: pub_type field directly maps to a known tier name
70
+ Priority 3: keyword regex on pub_type + title text
71
+ """
72
+ # Priority 1: explicit tier_type already set (e.g., by embedder.py)
73
+ tier = (
74
+ chunk.get("tier_type")
75
+ or chunk.get("metadata", {}).get("tier_type")
76
+ )
77
+ if tier and tier in TIER_WEIGHTS:
78
+ return tier, None
79
+
80
+ # Priority 2: direct pub_type value lookup
81
+ # Handles underscore-separated values like "research_abstract" which
82
+ # won't match word-boundary regex patterns
83
+ pub_type_raw = str(
84
+ chunk.get("pub_type") or chunk.get("metadata", {}).get("pub_type") or ""
85
+ ).strip().lower()
86
+
87
+ _PUB_TYPE_DIRECT: dict[str, str] = {
88
+ "research_abstract": "research_abstract",
89
+ "abstract": "research_abstract",
90
+ "systematic_review": "systematic_review",
91
+ "systematic review": "systematic_review",
92
+ "meta_analysis": "systematic_review",
93
+ "meta-analysis": "systematic_review",
94
+ "drug_label": "drug_label",
95
+ "drug label": "drug_label",
96
+ "clinical_guideline": "clinical_guideline",
97
+ "clinical guideline": "clinical_guideline",
98
+ "guideline": "clinical_guideline",
99
+ "review_article": "review_article",
100
+ "review article": "review_article",
101
+ "review": "review_article",
102
+ "clinical_case": "clinical_case",
103
+ "case_report": "clinical_case",
104
+ "case report": "clinical_case",
105
+ }
106
+ if pub_type_raw in _PUB_TYPE_DIRECT:
107
+ return _PUB_TYPE_DIRECT[pub_type_raw], None
108
+
109
+ # Priority 3: keyword regex on pub_type + title text
110
+ title = str(chunk.get("title") or chunk.get("metadata", {}).get("title") or "")
111
+ text_to_search = f"{pub_type_raw} {title}"
112
+
113
+ for pattern, matched_tier in _KEYWORD_MAP:
114
+ m = pattern.search(text_to_search)
115
+ if m:
116
+ return matched_tier, m.group(0)
117
+
118
+
119
+ return "unknown", None
120
+
121
+
122
+ # ---------------------------------------------------------------------------
123
+ # Public API
124
+ # ---------------------------------------------------------------------------
125
+
126
# Display tier number per tier_type. Anything not listed is shown as 6.
# NOTE(review): "drug_label" carries weight 0.90 in TIER_WEIGHTS but has no
# entry here, so it is displayed as tier 6 — confirm the intended tier number.
_TIER_DISPLAY_NUMBERS: dict[str, int] = {
    "clinical_guideline": 1,
    "systematic_review": 2,
    "research_abstract": 3,
    "review_article": 4,
    "clinical_case": 5,
}
_UNCLASSIFIED_TIER_NUMBER = 6


def score_source_credibility(
    retrieved_chunks: list[dict],
) -> EvalResult:
    """
    Score the credibility of a set of retrieved source documents.

    Each chunk is classified with ``_classify_tier`` and mapped to its weight
    in ``TIER_WEIGHTS``; the score is the arithmetic mean of those weights.

    Args:
        retrieved_chunks : List of chunk dicts. The fields 'tier_type',
                           'pub_type', 'title' and 'chunk_id' are read from
                           the chunk itself or, if absent there, from its
                           'metadata' dict. A 'metadata' value that is
                           missing, None or not a dict is ignored.

    Returns:
        EvalResult with module_name="source_credibility". For an empty input
        the score is 0.0 and ``error`` is "No chunks provided".
    """
    t0 = time.perf_counter()

    if not retrieved_chunks:
        return EvalResult(
            module_name="source_credibility",
            score=0.0,
            details={"chunks": [], "method_used": "none"},
            error="No chunks provided",
            latency_ms=0,
        )

    def _lookup(chunk: dict, key: str):
        # Top-level value first, then metadata; tolerate metadata=None.
        value = chunk.get(key)
        if value:
            return value
        metadata = chunk.get("metadata")
        if isinstance(metadata, dict):
            return metadata.get(key)
        return None

    chunk_details: list[dict] = []
    weights: list[float] = []
    method_used = "metadata"  # switches to "keyword" if any chunk needed regex matching

    for i, chunk in enumerate(retrieved_chunks):
        tier_type, matched_kw = _classify_tier(chunk)
        weight = TIER_WEIGHTS.get(tier_type, TIER_WEIGHTS["unknown"])
        weights.append(weight)

        if matched_kw:
            method_used = "keyword"

        chunk_details.append(
            {
                "chunk_id": _lookup(chunk, "chunk_id") or f"chunk_{i}",
                "tier": _TIER_DISPLAY_NUMBERS.get(tier_type, _UNCLASSIFIED_TIER_NUMBER),
                "tier_type": tier_type,
                "tier_weight": round(weight, 2),
                "pub_type": _lookup(chunk, "pub_type") or "",
                # str() so a non-string title cannot break the slice.
                "title": str(_lookup(chunk, "title") or "")[:80],
                "matched_keyword": matched_kw,
            }
        )

    score = sum(weights) / len(weights) if weights else 0.0

    details = {
        "method_used": method_used,
        "chunk_count": len(retrieved_chunks),
        "avg_tier_weight": round(score, 4),
        "chunks": chunk_details,
    }

    latency_ms = int((time.perf_counter() - t0) * 1000)
    logger.info(
        "Source credibility: %.3f (avg tier weight over %d chunks) in %d ms",
        score, len(retrieved_chunks), latency_ms,
    )
    return EvalResult(
        module_name="source_credibility",
        score=score,
        details=details,
        latency_ms=latency_ms,
    )
src/pipeline/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ # src/pipeline/__init__.py
src/pipeline/chunker.py ADDED
@@ -0,0 +1,82 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ FR-02: Document Chunking
3
+ ========================
4
+ LangChain RecursiveCharacterTextSplitter
5
+ chunk_size = 512 chars (config: retrieval.chunk_size)
6
+ overlap = 50 chars (config: retrieval.chunk_overlap)
7
+
8
+ Each chunk carries the full FR-03b metadata schema required by Module 3
9
+ (source credibility) and the FAISS metadata store.
10
+ """
11
+ from __future__ import annotations
12
+
13
+ import logging
14
+ import uuid
15
+ from typing import Any
16
+
17
+ logger = logging.getLogger(__name__)
18
+
19
+
20
def chunk_documents(
    documents: list[dict[str, Any]],
    config: dict,
) -> list[dict[str, Any]]:
    """
    Split a list of raw documents into overlapping text chunks.

    Documents whose ``text`` is missing, ``None`` or blank are skipped.
    Pieces that are empty after stripping are dropped *before* numbering, so
    ``chunk_index`` is contiguous and ``total_chunks`` equals the number of
    chunks actually emitted for the document.

    Args:
        documents : List of dicts with keys:
                    text, doc_id, source, title, pub_type, pub_year, journal
                    (doc_id, source, title and pub_type are required;
                    pub_year defaults to 0 and journal to "").
        config    : Loaded config dict; reads retrieval.chunk_size and
                    retrieval.chunk_overlap.

    Returns:
        List of chunk dicts with keys:
            chunk_id, chunk_text, doc_id, source, title,
            pub_type, pub_year, journal, chunk_index, total_chunks

    Raises:
        KeyError: if the config keys or a required document key is missing.
    """
    from langchain.text_splitter import RecursiveCharacterTextSplitter

    chunk_size = config["retrieval"]["chunk_size"]
    chunk_overlap = config["retrieval"]["chunk_overlap"]

    splitter = RecursiveCharacterTextSplitter(
        chunk_size=chunk_size,
        chunk_overlap=chunk_overlap,
        length_function=len,
        separators=["\n\n", "\n", ". ", " ", ""],
    )

    all_chunks: list[dict] = []

    for doc in documents:
        # `or ""` guards against an explicit text=None.
        text = (doc.get("text") or "").strip()
        if not text:
            logger.debug("Skipping empty document: doc_id=%s", doc.get("doc_id"))
            continue

        stripped_pieces = (piece.strip() for piece in splitter.split_text(text))
        pieces = [piece for piece in stripped_pieces if piece]
        total = len(pieces)

        for idx, chunk_text in enumerate(pieces):
            all_chunks.append({
                "chunk_id": str(uuid.uuid4()),
                "chunk_text": chunk_text,
                "doc_id": doc["doc_id"],
                "source": doc["source"],
                "title": doc["title"],
                "pub_type": doc["pub_type"],
                "pub_year": doc.get("pub_year", 0),
                "journal": doc.get("journal", ""),
                "chunk_index": idx,
                "total_chunks": total,
            })

    logger.info(
        "Chunked %d documents → %d chunks (size=%d, overlap=%d)",
        len(documents), len(all_chunks), chunk_size, chunk_overlap,
    )
    return all_chunks
src/pipeline/consensus.py ADDED
@@ -0,0 +1,111 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ src/pipeline/consensus.py — Multi-Model Consensus Engine
3
+ =========================================================
4
+ Implements the "Ensemble Judge" middleware feature.
5
+ Calls multiple LLMs and compares their answers for medical contradictions.
6
+ """
7
+ from __future__ import annotations
8
+ import logging
9
+ import concurrent.futures
10
+ from typing import List, Dict, Any, Optional
11
+ from src.pipeline.generator import generate_answer
12
+
13
+ logger = logging.getLogger(__name__)
14
+
15
def run_consensus_check(
    question: str,
    context_chunks: List[Dict[str, Any]],
    config: Dict[str, Any],
    providers: Optional[List[str]] = None,
) -> Dict[str, Any]:
    """
    Call several providers in parallel and have a judge compare their answers.

    Args:
        question: The user question sent to every provider.
        context_chunks: Retrieved chunks passed to ``generate_answer`` and,
            via their ``text`` field, into the judge prompt.
        config: Configuration dict forwarded to ``generate_answer``.
        providers: Provider names to query; defaults to ["gemini", "groq"].
            Duplicates are ignored. Only the first two successful answers
            (in this order) are compared.

    Returns:
        {
            "answers": { provider: answer or "ERROR: <reason>" },
            "agreement_score": float [0-1],
            "conflicts": List[str],
            "consensus_answer": str,
            "summary": str   # only when the judge ran successfully
        }

        If fewer than two providers return an answer, no judge is run:
        ``agreement_score`` is 1.0 when nothing failed (fewer than two
        providers were requested) and 0.5 when at least one provider failed.
        A provider error string is never used as the consensus answer.
    """
    import json
    import re

    # None default avoids sharing one mutable list across calls.
    requested = ["gemini", "groq"] if providers is None else providers
    provider_names = list(dict.fromkeys(requested))
    logger.info("Starting Consensus Check with providers: %s", provider_names)

    # 1. Generate answers in parallel.
    results: Dict[str, Any] = {}
    failed = set()
    with concurrent.futures.ThreadPoolExecutor() as executor:
        future_to_provider = {
            executor.submit(generate_answer, question, context_chunks, config, {"provider": p}): p
            for p in provider_names
        }
        for future in concurrent.futures.as_completed(future_to_provider):
            provider = future_to_provider[future]
            try:
                results[provider] = future.result()
            except Exception as exc:
                logger.error("Provider %s failed during consensus: %s", provider, exc)
                results[provider] = f"ERROR: {exc}"
                failed.add(provider)

    # Re-key in request order so "Answer A/B" and the fallback answer do not
    # depend on which provider happened to finish first.
    answers = {p: results[p] for p in provider_names}
    successful = [answers[p] for p in provider_names if p not in failed]

    if len(successful) < 2:
        return {
            "answers": answers,
            "agreement_score": 0.5 if failed else 1.0,
            "conflicts": ["Insufficient providers responded for a full consensus check."],
            "consensus_answer": successful[0] if successful else "Safety failure: No providers responded.",
        }

    # Compile context text to feed the judge.
    context_text = "\n\n".join(
        f"Source {i + 1}:\n{c.get('text', '')}" for i, c in enumerate(context_chunks)
    )

    # 2. Compare the first two successful answers using a "judge" call.
    comparison_prompt = (
        "\n"
        "You are a Medical Consensus Judge. Compare the following two medical answers provided by different AI models to the same question.\n"
        "CRITICAL INSTRUCTION: Your primary duty is to ensure the final answer is explicitly grounded in the provided MEDICAL CONTEXT.\n"
        "Identify any CLINICAL CONTRADICTIONS or significant discrepancies in drug names, dosages, or recommendations.\n"
        "If one model hallucinates outside the context, you must side with the model that stuck to the context.\n"
        "\n"
        f"QUESTION: {question}\n"
        "\n"
        "MEDICAL CONTEXT FROM DATASET:\n"
        f"{context_text}\n"
        "\n"
        "ANSWER A:\n"
        f"{successful[0]}\n"
        "\n"
        "ANSWER B:\n"
        f"{successful[1]}\n"
        "\n"
        "OUTPUT FORMAT (JSON ONLY):\n"
        "{\n"
        '"agreement_score": 0.0 to 1.0 (1.0 means perfect alignment, 0.0 means complete contradiction),\n'
        '"conflicts": ["list of specific medical discrepancies found"],\n'
        '"summary": "brief summary of how they differ and which one aligns better with the Medical Context",\n'
        '"recommended_consensus": "the most conservative and safe unified answer that strictly adheres to the Medical Context"\n'
        "}\n"
    )
    try:
        # The judge runs on the generator's default provider.
        judge_raw = generate_answer("Medical Consensus Judge Task", [{"text": comparison_prompt}], config)

        # Strip markdown fences, then fall back to the outermost {...} span
        # when the model wrapped the JSON in extra prose.
        clean_json = re.sub(r'```json\n?|\n?```', '', judge_raw).strip()
        try:
            judge_data = json.loads(clean_json)
        except json.JSONDecodeError:
            start, end = clean_json.find("{"), clean_json.rfind("}")
            if start == -1 or end <= start:
                raise
            judge_data = json.loads(clean_json[start:end + 1])

        # The judge is an LLM: coerce and clamp whatever it returned.
        try:
            agreement = float(judge_data.get("agreement_score", 0.5))
        except (TypeError, ValueError):
            agreement = 0.5
        agreement = min(1.0, max(0.0, agreement))

        return {
            "answers": answers,
            "agreement_score": agreement,
            "conflicts": judge_data.get("conflicts", []),
            "summary": judge_data.get("summary", ""),
            "consensus_answer": judge_data.get("recommended_consensus") or successful[0],
        }
    except Exception as e:
        # Boundary handler: a failing judge must not fail the whole request.
        logger.error("Consensus Judge failed: %s", e)
        return {
            "answers": answers,
            "agreement_score": 0.5,
            "conflicts": [f"Judge failed: {e}"],
            "consensus_answer": successful[0],
        }
src/pipeline/embedder.py ADDED
@@ -0,0 +1,163 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ FR-03 + FR-03b: Embedding Generation & FAISS Index Construction
3
+ ===============================================================
4
+ Model : dmis-lab/biobert-v1.1 (768-dim dense vectors, SentenceTransformer)
5
+ Index : FAISS IndexFlatIP with L2-normalized vectors (= cosine similarity)
6
+ Metadata: Parallel dict[int → dict] saved as pickle alongside index
7
+
8
+ Usage:
9
+ python src/pipeline/embedder.py
10
+ """
11
+ from __future__ import annotations
12
+
13
+ import sys
14
+ import os
15
+ from pathlib import Path
16
+
17
+ sys.path.insert(0, str(Path(__file__).resolve().parent.parent.parent))
18
+
19
+ import json
20
+ import logging
21
+ import pickle
22
+
23
+ import faiss
24
+ import numpy as np
25
+ import yaml
26
+
27
+ import src # noqa: F401 — logging setup
28
+
29
+ logger = logging.getLogger(__name__)
30
+
31
+
32
def _load_config() -> dict:
    """Parse ``config.yaml`` (resolved against the current working directory)."""
    with open("config.yaml", "r", encoding="utf-8") as config_file:
        parsed = yaml.safe_load(config_file)
    return parsed
35
+
36
+
37
def load_chunks(chunks_path: str = "data/processed/chunks.jsonl") -> list[dict]:
    """Read a JSONL file and return one parsed object per non-blank line.

    Raises:
        FileNotFoundError: if ``chunks_path`` does not exist.
    """
    source = Path(chunks_path)
    if not source.exists():
        raise FileNotFoundError(
            f"Chunks file not found: '{chunks_path}'. "
            "Run python src/pipeline/ingest.py first."
        )

    with open(source, "r", encoding="utf-8") as handle:
        stripped_lines = (raw.strip() for raw in handle)
        records = [json.loads(entry) for entry in stripped_lines if entry]

    logger.info("Loaded %d chunks from %s", len(records), chunks_path)
    return records
53
+
54
+
55
def encode_texts(
    texts: list[str],
    model_name: str,
    batch_size: int = 32,
) -> np.ndarray:
    """
    Embed ``texts`` with the SentenceTransformer model named ``model_name``.

    The encoder is asked for normalised embeddings as a NumPy array, and the
    result is cast to float32 before being returned. The model is loaded on
    every call.
    """
    from sentence_transformers import SentenceTransformer

    logger.info("Loading embedding model: %s", model_name)
    encoder = SentenceTransformer(model_name)

    logger.info("Encoding %d texts (batch_size=%d)...", len(texts), batch_size)
    encode_options = {
        "batch_size": batch_size,
        "show_progress_bar": True,
        "normalize_embeddings": True,  # unit vectors → inner product acts as cosine
        "convert_to_numpy": True,
    }
    vectors: np.ndarray = encoder.encode(texts, **encode_options)
    logger.info("Encoded shape: %s", vectors.shape)
    return vectors.astype(np.float32)
79
+
80
+
81
def build_faiss_index(embeddings: np.ndarray) -> faiss.IndexFlatIP:
    """
    Create an inner-product FAISS index and add every row of ``embeddings``.

    The index dimension is taken from ``embeddings.shape[1]``.
    With unit-length vectors the inner product equals cosine similarity.
    """
    vector_dim = embeddings.shape[1]
    ip_index = faiss.IndexFlatIP(vector_dim)
    ip_index.add(embeddings)
    logger.info(
        "FAISS IndexFlatIP built: %d vectors, dim=%d", ip_index.ntotal, vector_dim
    )
    return ip_index
93
+
94
+
95
def build_metadata_store(chunks: list[dict]) -> dict[int, dict]:
    """
    Map each chunk's 0-based position to a copy of its stored fields.

    Every chunk must provide all listed keys; a missing one raises KeyError.
    Keys not listed are dropped.
    """
    stored_fields = (
        "chunk_id",
        "doc_id",
        "source",
        "title",
        "pub_type",
        "pub_year",
        "journal",
        "chunk_index",
        "total_chunks",
        "chunk_text",  # kept for retrieval
    )
    return {
        position: {name: record[name] for name in stored_fields}
        for position, record in enumerate(chunks)
    }
115
+
116
+
117
def save_artifacts(
    index: faiss.IndexFlatIP,
    metadata_store: dict,
    config: dict,
) -> None:
    """Write the FAISS index and the pickled metadata store to the configured paths.

    Target paths come from ``config["retrieval"]["index_path"]`` and
    ``config["retrieval"]["metadata_path"]``; parent directories are created
    when missing.
    """
    retrieval_cfg = config["retrieval"]
    index_target = Path(retrieval_cfg["index_path"])
    meta_target = Path(retrieval_cfg["metadata_path"])

    for target in (index_target, meta_target):
        target.parent.mkdir(parents=True, exist_ok=True)

    faiss.write_index(index, str(index_target))
    logger.info("FAISS index written to %s", index_target)

    with open(meta_target, "wb") as sink:
        pickle.dump(metadata_store, sink, protocol=pickle.HIGHEST_PROTOCOL)
    logger.info(
        "Metadata store written to %s (%d entries)", meta_target, len(metadata_store)
    )
137
+
138
+
139
def main() -> None:
    """Load chunks, embed them, build the index and persist both artifacts.

    Exits with status 1 when the chunks file contains no chunks.
    """
    config = _load_config()
    chunk_records = load_chunks("data/processed/chunks.jsonl")

    if len(chunk_records) == 0:
        logger.error("No chunks to embed. Run python src/pipeline/ingest.py first.")
        sys.exit(1)

    vectors = encode_texts(
        [record["chunk_text"] for record in chunk_records],
        config["retrieval"]["embedding_model"],
        batch_size=32,
    )
    index = build_faiss_index(vectors)
    save_artifacts(index, build_metadata_store(chunk_records), config)

    logger.info(
        "Embedding complete. Index has %d vectors. "
        "Next: python scripts/warmup.py && streamlit run src/dashboard/app.py",
        index.ntotal,
    )


if __name__ == "__main__":
    main()
src/pipeline/generator.py ADDED
@@ -0,0 +1,584 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ src/pipeline/generator.py — LLM Answer Generation
3
+ ===================================================
4
+ Supports multiple providers based on config.yaml → llm.provider:
5
+ - "gemini" : Google Gemini API (recommended)
6
+ - "mistral" : Mistral AI API (api.mistral.ai)
7
+ - "groq" : Groq Cloud API (fast inference)
8
+ - "ollama" : Local Ollama/Mistral (requires Ollama running locally)
9
+
10
+ API Key setup:
11
+ Set env variables in Backend/.env:
12
+ GEMINI_API_KEY=your_key
13
+ MISTRAL_API_KEY=your_key
14
+ GROQ_API_KEY=your_key
15
+ """
16
+ from __future__ import annotations
17
+
18
+ import json
19
+ import logging
20
+ import os
21
+ import time
22
+ from pathlib import Path
23
+ from typing import Optional
24
+
25
+ import yaml
26
+
27
+ logger = logging.getLogger(__name__)
28
+
29
+ # Load .env file at module import time
30
+ def _load_env():
31
+ env_path = Path(".env")
32
+ if not env_path.exists():
33
+ # Try one level up
34
+ env_path = Path("../Backend/.env")
35
+ if env_path.exists():
36
+ for line in env_path.read_text().splitlines():
37
+ line = line.strip()
38
+ if line and not line.startswith("#") and "=" in line:
39
+ key, val = line.split("=", 1)
40
+ key = key.strip()
41
+ val = val.strip().strip('"').strip("'")
42
+ if key and val and key not in os.environ:
43
+ os.environ[key] = val
44
+
45
+ _load_env()
46
+
47
+ # ---------------------------------------------------------------------------
48
+ # Config loader
49
+ # ---------------------------------------------------------------------------
50
+
51
def _load_config() -> dict:
    """Load ``config.yaml`` from the current working directory.

    Best-effort: returns an empty dict when the file is missing, unreadable,
    not valid YAML, or does not contain a mapping at the top level. Failures
    are logged so a broken config is not silently treated as "no config".
    """
    try:
        loaded = yaml.safe_load(Path("config.yaml").read_text(encoding="utf-8"))
    except FileNotFoundError:
        logger.warning("config.yaml not found; using empty config.")
        return {}
    except (OSError, UnicodeDecodeError, yaml.YAMLError) as exc:
        logger.warning("Could not load config.yaml (%s); using empty config.", exc)
        return {}
    # safe_load returns None for an empty file and may return a list/scalar.
    return loaded if isinstance(loaded, dict) else {}
56
+
57
+
58
+ # ---------------------------------------------------------------------------
59
+ # Prompt builder (shared by both providers)
60
+ # ---------------------------------------------------------------------------
61
+
62
+ _PHYSICIAN_PROMPT = (
63
+ "You are MediRAG, a medical AI assistant tailored for clinicians and researchers. "
64
+ "You MUST answer ONLY using information explicitly stated in the CONTEXT provided below. "
65
+ "Use professional medical terminology, be concise, and cite specific details. "
66
+ "After each claim, cite it inline as [Source: <document title>]. "
67
+ "If the context does NOT contain sufficient information to answer safely, you MUST respond EXACTLY with: "
68
+ "'⚠️ The retrieved context does not contain enough information to answer this safely. "
69
+ "Please consult authoritative clinical guidelines or a specialist.' "
70
+ "NEVER use general knowledge, training data, or information outside the provided context."
71
+ )
72
+
73
+ _PATIENT_PROMPT = (
74
+ "You are MediRAG, a medical AI assistant tailored for patients and non-experts. "
75
+ "You MUST answer ONLY using information explicitly stated in the CONTEXT provided below. "
76
+ "Explain medical information in a clear, accessible, and empathetic way. "
77
+ "After each claim, cite it inline as [Source: <document title>]. "
78
+ "If the context does NOT contain sufficient information to answer safely, you MUST respond EXACTLY with: "
79
+ "'⚠️ The retrieved context does not contain enough information to answer this safely. "
80
+ "Please consult your doctor or a medical specialist.' "
81
+ "NEVER use general knowledge, training data, or information outside the provided context."
82
+ )
83
+
84
+ _SYSTEM_PROMPT = _PHYSICIAN_PROMPT # Default fallback
85
+
86
+
87
+ def _build_prompt(question: str, context_chunks: list[dict], system_prompt: Optional[str] = None, persona: str = "physician") -> str:
88
+ """Build the RAG prompt from the question + retrieved chunks.
89
+
90
+ Explicitly surfaces title and source for each chunk in the header so the LLM
91
+ can cite [Source: <title>] inline in its answer.
92
+ """
93
+ context_parts = []
94
+ for i, chunk in enumerate(context_chunks, 1):
95
+ text = chunk.get("text") or chunk.get("chunk_text", "")
96
+ title = chunk.get("title", "")
97
+ source = chunk.get("source", "")
98
+ pub_type = chunk.get("pub_type", "")
99
+ # Include title as the primary citation label
100
+ header_parts = [f"Source {i}"]
101
+ if title:
102
+ header_parts.append(f"Title: {title}")
103
+ if pub_type:
104
+ header_parts.append(pub_type)
105
+ if source and source != title:
106
+ header_parts.append(source)
107
+ header = "[" + " | ".join(header_parts) + "]"
108
+ context_parts.append(f"{header}\n{text.strip()}")
109
+
110
+ context_block = "\n\n".join(context_parts)
111
+
112
+ # Determine effective system prompt based on persona if no manual override
113
+ if system_prompt:
114
+ effective_system = system_prompt
115
+ else:
116
+ effective_system = _PATIENT_PROMPT if persona == "patient" else _PHYSICIAN_PROMPT
117
+
118
+ return (
119
+ f"{effective_system}\n\n"
120
+ f"CONTEXT:\n{context_block}\n\n"
121
+ f"QUESTION: {question}\n\n"
122
+ f"ANSWER (cite sources inline as [Source: document title]):"
123
+ )
124
+
125
+
126
+ # Strict prompt — used when first answer fails evaluation (HRS ≥ 60)
127
+ _STRICT_SYSTEM_PROMPT = (
128
+ "You are MediRAG, a clinical safety assistant under strict mode. "
129
+ "A previous response was flagged as potentially unsafe or inaccurate. "
130
+ "You MUST answer ONLY using the information explicitly stated in the CONTEXT below. "
131
+ "Do NOT use any general medical knowledge, training data, or outside information. "
132
+ "If the context is insufficient, you MUST say EXACTLY: "
133
+ "'⚠️ Insufficient evidence in retrieved context to answer safely. Please consult a clinical specialist.' "
134
+ "NEVER hallucinate drug names, dosages, or clinical recommendations."
135
+ )
136
+
137
+
138
+ def _build_strict_prompt(question: str, context_chunks: list[dict]) -> str:
139
+ """Strict prompt: context-only, used on regeneration after failed evaluation."""
140
+ context_parts = []
141
+ for i, chunk in enumerate(context_chunks, 1):
142
+ text = chunk.get("text") or chunk.get("chunk_text", "")
143
+ title = chunk.get("title", "")
144
+ source = chunk.get("source", "")
145
+ pub_type = chunk.get("pub_type", "")
146
+ header_parts = [f"Source {i}"]
147
+ if title:
148
+ header_parts.append(f"Title: {title}")
149
+ if pub_type:
150
+ header_parts.append(pub_type)
151
+ if source and source != title:
152
+ header_parts.append(source)
153
+ header = "[" + " | ".join(header_parts) + "]"
154
+ context_parts.append(f"{header}\n{text.strip()}")
155
+
156
+ context_block = "\n\n".join(context_parts)
157
+ return (
158
+ f"{_STRICT_SYSTEM_PROMPT}\n\n"
159
+ f"CONTEXT:\n{context_block}\n\n"
160
+ f"QUESTION: {question}\n\n"
161
+ f"SAFE ANSWER (context-only, cite [Source: title] for every claim):"
162
+ )
163
+
164
+
165
+ # ---------------------------------------------------------------------------
166
+ # OpenAI provider
167
+ # ---------------------------------------------------------------------------
168
+
169
def _generate_openai(prompt: str, config: dict) -> str:
    """Send ``prompt`` to the OpenAI chat completions API and return the answer.

    API key lookup order: ``config["llm"]["openai_api_key"]``, the
    ``OPENAI_API_KEY`` environment variable, then an ``OPENAI_API_KEY=`` line
    in ``.env`` in the current working directory.

    Model lookup order: ``llm.openai_model``, ``llm.model``, then "gpt-4o".
    Temperature comes from ``llm.generation_temperature`` (default 0.7).

    Raises:
        RuntimeError: if no API key is found, the ``openai`` package is not
            installed, the API call fails, or the response has no text.
    """
    llm_cfg = config.get("llm", {})

    # Override from frontend/config takes priority over system ENV
    api_key = llm_cfg.get("openai_api_key") or os.environ.get("OPENAI_API_KEY")
    if not api_key:
        env_file = Path(".env")
        if env_file.exists():
            for line in env_file.read_text(encoding="utf-8-sig").splitlines():
                if line.startswith("OPENAI_API_KEY="):
                    api_key = line.split("=", 1)[1].strip().strip('"').strip("'")
                    break

    if not api_key:
        raise RuntimeError("OpenAI API key not found. Set OPENAI_API_KEY env var or in .env.")

    try:
        from openai import OpenAI
    except ImportError as exc:
        raise RuntimeError("openai not installed. Run: pip install openai") from exc

    model_name = llm_cfg.get("openai_model") or llm_cfg.get("model") or "gpt-4o"
    client = OpenAI(api_key=api_key)

    logger.info("Calling OpenAI API (model=%s)...", model_name)
    t0 = time.perf_counter()

    try:
        response = client.chat.completions.create(
            model=model_name,
            messages=[{"role": "user", "content": prompt}],
            temperature=float(llm_cfg.get("generation_temperature", 0.7)),
            max_tokens=1024,
        )
    except Exception as exc:
        raise RuntimeError(f"OpenAI API error: {exc}") from exc

    elapsed = int((time.perf_counter() - t0) * 1000)

    # `choices` can be empty and `content` can be None; treat both as an
    # empty response instead of failing with IndexError/AttributeError.
    choices = response.choices or []
    content = choices[0].message.content if choices else None
    answer = (content or "").strip()

    if not answer:
        raise RuntimeError("OpenAI returned an empty response.")

    logger.info("OpenAI generated answer in %d ms (%d chars)", elapsed, len(answer))
    return answer
214
+
215
def _generate_gemini(prompt: str, config: dict) -> str:
    """Send ``prompt`` to the Gemini API and return the generated answer.

    API key lookup order: ``config["llm"]["gemini_api_key"]``, the
    ``GEMINI_API_KEY`` environment variable, then a ``GEMINI_API_KEY=`` line
    in ``.env`` in the current working directory.

    The model comes from ``llm.gemini_model`` (default "gemini-2.0-flash")
    and the temperature from ``llm.generation_temperature`` (default 0.7).

    Raises:
        RuntimeError: if no API key is found, the ``google-genai`` package is
            not installed, the API call fails, or the response has no text.
    """
    llm_cfg = config.get("llm", {})

    # Override from frontend/config takes priority over system ENV
    api_key = llm_cfg.get("gemini_api_key") or os.environ.get("GEMINI_API_KEY")
    if not api_key:
        # Try loading from .env file if present
        env_file = Path(".env")
        if env_file.exists():
            for line in env_file.read_text(encoding="utf-8-sig").splitlines():
                if line.startswith("GEMINI_API_KEY="):
                    api_key = line.split("=", 1)[1].strip().strip('"').strip("'")
                    break

    if not api_key:
        raise RuntimeError(
            "Gemini API key not found. "
            "Either: (1) set GEMINI_API_KEY=your_key in the same terminal as uvicorn, "
            "or (2) create a .env file with GEMINI_API_KEY=your_key in the project root."
        )

    try:
        from google import genai
        from google.genai import types
    except ImportError as exc:
        # Chain the cause so the real import failure stays visible.
        raise RuntimeError(
            "google-genai not installed. Run: pip install google-genai"
        ) from exc

    model_name = llm_cfg.get("gemini_model", "gemini-2.0-flash")
    client = genai.Client(api_key=api_key)

    logger.info("Calling Gemini API (model=%s)...", model_name)
    t0 = time.perf_counter()

    try:
        response = client.models.generate_content(
            model=model_name,
            contents=prompt,
            config=types.GenerateContentConfig(
                temperature=float(llm_cfg.get("generation_temperature", 0.7)),
                max_output_tokens=1024,
            ),
        )
    except Exception as exc:
        raise RuntimeError(f"Gemini API error: {exc}") from exc

    elapsed = int((time.perf_counter() - t0) * 1000)
    answer = response.text.strip() if response.text else ""

    if not answer:
        raise RuntimeError("Gemini returned an empty response.")

    logger.info("Gemini generated answer in %d ms (%d chars)", elapsed, len(answer))
    return answer
270
+
271
+
272
+ # ---------------------------------------------------------------------------
273
+ # Ollama provider (kept as fallback)
274
+ # ---------------------------------------------------------------------------
275
+
276
def _generate_ollama(prompt: str, config: dict) -> str:
    """Generate an answer for ``prompt`` with a local Ollama server.

    Args:
        prompt: Fully rendered prompt sent to ``/api/generate``.
        config: Application config; the ``llm`` section supplies ``base_url``,
            ``model``, ``timeout_seconds`` and ``generation_temperature``.

    Returns:
        The stripped ``response`` field of the Ollama reply.

    Raises:
        RuntimeError: If the server is unreachable, times out, the request
            fails in any other way, the HTTP status is not 200, the body is
            not the expected JSON object, or the answer is empty.
    """
    import requests as _requests

    llm_cfg = config.get("llm", {})
    # Drop a trailing slash so the endpoint never becomes "...//api/generate".
    base_url = (llm_cfg.get("base_url") or "http://localhost:11434").rstrip("/")
    model = llm_cfg.get("model", "mistral")
    timeout = llm_cfg.get("timeout_seconds", 120)
    temperature = llm_cfg.get("generation_temperature", 0.7)

    payload = {
        "model": model,
        "prompt": prompt,
        "stream": False,
        "options": {"temperature": temperature, "num_predict": 512},
    }

    url = f"{base_url}/api/generate"
    logger.info("Calling Ollama (%s @ %s)...", model, base_url)
    t0 = time.perf_counter()

    try:
        resp = _requests.post(url, json=payload, timeout=timeout)
    except _requests.exceptions.ConnectionError as exc:
        raise RuntimeError(
            f"Ollama is not running at {base_url}. Start with: ollama serve"
        ) from exc
    except _requests.exceptions.Timeout as exc:
        raise RuntimeError(
            f"Ollama timed out after {timeout}s. Increase llm.timeout_seconds in config.yaml."
        ) from exc
    except _requests.exceptions.RequestException as exc:
        # Any other transport problem (bad URL, too many redirects, ...) is
        # reported as RuntimeError too, matching the other providers.
        raise RuntimeError(f"Ollama request failed: {exc}") from exc

    if resp.status_code != 200:
        raise RuntimeError(f"Ollama HTTP {resp.status_code}: {resp.text[:300]}")

    try:
        data = resp.json()
    except ValueError as exc:
        # requests raises a ValueError subclass for bodies that are not JSON,
        # whichever JSON backend it was built against.
        raise RuntimeError(f"Unexpected Ollama response: {exc}") from exc

    if not isinstance(data, dict):
        raise RuntimeError(
            f"Unexpected Ollama response: expected a JSON object, got {type(data).__name__}"
        )

    raw_answer = data.get("response") or ""
    if not isinstance(raw_answer, str):
        raise RuntimeError(
            f"Unexpected Ollama response: 'response' is {type(raw_answer).__name__}, not text"
        )
    answer = raw_answer.strip()

    if not answer:
        raise RuntimeError("Ollama returned an empty response.")

    elapsed = int((time.perf_counter() - t0) * 1000)
    logger.info("Ollama generated answer in %d ms (%d chars)", elapsed, len(answer))
    return answer
322
+
323
+
324
+ # ---------------------------------------------------------------------------
325
+ # Mistral provider
326
+ # ---------------------------------------------------------------------------
327
+
328
def _generate_mistral(prompt: str, config: dict) -> str:
    """Generate an answer for ``prompt`` through the Mistral chat-completions API.

    Args:
        prompt: Fully rendered prompt, sent as a single user message.
        config: Application config; the ``llm`` section supplies
            ``mistral_api_key``, ``model``, ``timeout_seconds`` and
            ``generation_temperature``.

    Returns:
        The stripped content of the first completion choice.

    Raises:
        RuntimeError: If no API key is available, the request fails, the HTTP
            status is not 200, the body has an unexpected shape, or the
            answer is empty.
    """
    import requests as _requests

    settings = config.get("llm", {})

    # A configured key is used verbatim unless it is empty or still an
    # unexpanded "${...}" placeholder; then the environment supplies it.
    configured_key = settings.get("mistral_api_key", "")
    if configured_key and not configured_key.startswith("${"):
        token = configured_key
    else:
        token = os.environ.get("MISTRAL_API_KEY")
    if not token:
        raise RuntimeError(
            "Mistral API key not found. Set MISTRAL_API_KEY in Backend/.env"
        )

    chosen_model = settings.get("model", "mistral-large-latest")
    wait_seconds = settings.get("timeout_seconds", 120)

    request_body = {
        "model": chosen_model,
        "messages": [{"role": "user", "content": prompt}],
        "temperature": settings.get("generation_temperature", 0.7),
        "max_tokens": 1024,
    }
    auth_headers = {
        "Authorization": f"Bearer {token}",
        "Content-Type": "application/json"
    }
    endpoint = "https://api.mistral.ai/v1/chat/completions"

    logger.info("Calling Mistral API (model=%s, key=...***)", chosen_model)
    started = time.perf_counter()

    try:
        resp = _requests.post(
            endpoint, json=request_body, headers=auth_headers, timeout=wait_seconds
        )
    except Exception as exc:
        raise RuntimeError(f"Mistral API network error: {exc}") from exc

    if resp.status_code != 200:
        raise RuntimeError(f"Mistral HTTP {resp.status_code}: {resp.text[:300]}")

    try:
        answer = resp.json()["choices"][0]["message"]["content"].strip()
    except Exception as exc:
        raise RuntimeError(f"Unexpected Mistral response: {exc}") from exc

    if not answer:
        raise RuntimeError("Mistral returned an empty response.")

    elapsed_ms = int((time.perf_counter() - started) * 1000)
    logger.info("Mistral generated answer in %d ms (%d chars)", elapsed_ms, len(answer))
    return answer
380
+
381
+
382
+ # ---------------------------------------------------------------------------
383
+ # Groq provider
384
+ # ---------------------------------------------------------------------------
385
+
386
def _generate_groq(prompt: str, config: dict) -> str:
    """Generate an answer for ``prompt`` through Groq's OpenAI-compatible chat API.

    Args:
        prompt: Fully rendered prompt, sent as a single user message.
        config: Application config; the ``llm`` section supplies
            ``groq_api_key``, ``groq_model`` (falling back to ``model``),
            ``timeout_seconds`` and ``generation_temperature``.

    Returns:
        The stripped content of the first completion choice.

    Raises:
        RuntimeError: If no API key is available, the request fails, the HTTP
            status is not 200, the body has an unexpected shape, or the
            answer is empty.
    """
    import requests as _requests

    settings = config.get("llm", {})

    # A configured key is used verbatim unless it is empty or still an
    # unexpanded "${...}" placeholder; then the environment supplies it.
    configured_key = settings.get("groq_api_key", "")
    if configured_key and not configured_key.startswith("${"):
        token = configured_key
    else:
        token = os.environ.get("GROQ_API_KEY")
    if not token:
        raise RuntimeError(
            "Groq API key not found. Set GROQ_API_KEY in Backend/.env"
        )

    # A Groq-specific model name wins over the generic llm.model setting.
    chosen_model = settings.get("groq_model") or settings.get("model", "llama-3.3-70b-versatile")
    wait_seconds = settings.get("timeout_seconds", 120)

    request_body = {
        "model": chosen_model,
        "messages": [{"role": "user", "content": prompt}],
        "temperature": settings.get("generation_temperature", 0.7),
        "max_tokens": 1024,
    }
    auth_headers = {
        "Authorization": f"Bearer {token}",
        "Content-Type": "application/json"
    }
    endpoint = "https://api.groq.com/openai/v1/chat/completions"

    logger.info("Calling Groq API (model=%s, key=...***)", chosen_model)
    started = time.perf_counter()

    try:
        resp = _requests.post(
            endpoint, json=request_body, headers=auth_headers, timeout=wait_seconds
        )
    except Exception as exc:
        raise RuntimeError(f"Groq API network error: {exc}") from exc

    if resp.status_code != 200:
        raise RuntimeError(f"Groq HTTP {resp.status_code}: {resp.text[:300]}")

    try:
        answer = resp.json()["choices"][0]["message"]["content"].strip()
    except Exception as exc:
        raise RuntimeError(f"Unexpected Groq response: {exc}") from exc

    if not answer:
        raise RuntimeError("Groq returned an empty response.")

    elapsed_ms = int((time.perf_counter() - started) * 1000)
    logger.info("Groq generated answer in %d ms (%d chars)", elapsed_ms, len(answer))
    return answer
437
+
438
+
439
+ # ---------------------------------------------------------------------------
440
+ # Public API
441
+ # ---------------------------------------------------------------------------
442
+
443
# Providers that generate_answer / generate_strict_answer can dispatch to.
_SUPPORTED_PROVIDERS = ("gemini", "openai", "ollama", "mistral", "groq")

# llm-config field holding the API key of each key-based provider.
_API_KEY_FIELDS = {
    "gemini": "gemini_api_key",
    "openai": "openai_api_key",
    "mistral": "mistral_api_key",
    "groq": "groq_api_key",
}

# llm-config field each provider reads its model name from.
_MODEL_FIELDS = {
    "gemini": "gemini_model",
    "openai": "openai_model",
    "mistral": "model",
    "groq": "groq_model",
    "ollama": "model",
}


def _resolve_llm_settings(config: dict, overrides: Optional[dict]) -> tuple[dict, str]:
    """Merge per-request overrides into the server config.

    The ``api_key`` and ``model`` overrides are routed to the fields of the
    *effective* provider (override first, then ``llm.provider`` from config),
    so they still take effect when the request does not repeat the provider.

    Returns:
        ``(effective_config, provider)`` where ``provider`` is lower-cased.
    """
    effective_llm = dict(config.get("llm") or {})
    overrides = overrides or {}

    if overrides.get("provider"):
        effective_llm["provider"] = overrides["provider"]
    provider = str(effective_llm.get("provider") or "gemini").lower()

    if overrides.get("api_key"):
        effective_llm[_API_KEY_FIELDS.get(provider, "gemini_api_key")] = overrides["api_key"]
    if overrides.get("model"):
        effective_llm[_MODEL_FIELDS.get(provider, "gemini_model")] = overrides["model"]
    if overrides.get("ollama_url"):
        effective_llm["base_url"] = overrides["ollama_url"]

    return {**config, "llm": effective_llm}, provider


def _require_supported_provider(provider: str) -> None:
    """Raise RuntimeError unless ``provider`` names a known LLM backend."""
    if provider not in _SUPPORTED_PROVIDERS:
        raise RuntimeError(
            f"Unknown LLM provider '{provider}'. "
            "Set llm.provider to 'gemini', 'openai', 'ollama', 'mistral', or 'groq'."
        )


def _run_provider(provider: str, prompt: str, config: dict) -> str:
    """Send ``prompt`` to the backend selected by ``provider``."""
    _require_supported_provider(provider)
    # Resolved at call time so the table always points at the current functions.
    backends = {
        "gemini": _generate_gemini,
        "openai": _generate_openai,
        "ollama": _generate_ollama,
        "mistral": _generate_mistral,
        "groq": _generate_groq,
    }
    return backends[provider](prompt, config)


def generate_answer(
    question: str,
    context_chunks: list[dict],
    config: Optional[dict] = None,
    overrides: Optional[dict] = None,
) -> str:
    """
    Generate a grounded medical answer.

    Provider is selected from config.yaml → llm.provider, but can be
    overridden per-request via the `overrides` dict. This makes the eval
    engine portable — callers bring their own API key and model.

    Args:
        question       : User's medical question.
        context_chunks : Retrieved context chunks (dicts with 'text' key).
        config         : Config dict (loaded from config.yaml if None).
        overrides      : Per-request overrides. Supported keys:
                           provider      → "gemini", "openai", "ollama",
                                           "mistral" or "groq"
                           api_key       → API key for the effective provider
                           model         → model name for the effective provider
                           ollama_url    → Ollama base URL
                           system_prompt → replacement system prompt
                           persona       → prompt persona (default "physician")

    Returns:
        Generated answer string.

    Raises:
        RuntimeError : If the provider is unknown, unreachable or returns an error.
    """
    if config is None:
        config = _load_config()

    effective_config, provider = _resolve_llm_settings(config, overrides)
    # Fail fast: no point rendering a prompt for a provider we cannot call.
    _require_supported_provider(provider)

    system_prompt_override = overrides.get("system_prompt") if overrides else None
    persona = overrides.get("persona", "physician") if overrides else "physician"

    prompt = _build_prompt(
        question,
        context_chunks,
        system_prompt=system_prompt_override,
        persona=persona,
    )
    return _run_provider(provider, prompt, effective_config)


def generate_strict_answer(
    question: str,
    context_chunks: list[dict],
    config: Optional[dict] = None,
    overrides: Optional[dict] = None,
) -> str:
    """
    Generate a STRICT context-only answer.
    Called when initial answer fails evaluation (HRS >= 60).
    The LLM is forbidden from using any training knowledge.

    Accepts the same `config` and `overrides` (provider, api_key, model,
    ollama_url) as generate_answer; the prompt-shaping overrides are not used.

    Raises:
        RuntimeError : If the provider is unknown, unreachable or returns an error.
    """
    if config is None:
        config = _load_config()

    effective_config, provider = _resolve_llm_settings(config, overrides)
    _require_supported_provider(provider)

    prompt = _build_strict_prompt(question, context_chunks)
    return _run_provider(provider, prompt, effective_config)
src/pipeline/ingest.py ADDED
@@ -0,0 +1,250 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ FR-01: Document Ingestion
3
+ =========================
4
+ Loads documents from:
5
+ - PubMedQA (HuggingFace: pubmed_qa, pqa_labeled) — up to 500 samples
6
+ - MedQA-USMLE (local JSONL from jind11/MedQA) — up to 200 samples
7
+
8
+ Then calls chunker.py to split and saves chunks to data/processed/chunks.jsonl.
9
+
10
+ Usage:
11
+ python src/pipeline/ingest.py
12
+ python src/pipeline/ingest.py --pubmedqa 500 --medqa 200
13
+ """
14
+ from __future__ import annotations
15
+
16
+ import sys
17
+ import os
18
+ from pathlib import Path
19
+
20
+ # Make project root importable when running as a script
21
+ sys.path.insert(0, str(Path(__file__).resolve().parent.parent.parent))
22
+
23
+ import argparse
24
+ import json
25
+ import logging
26
+ import uuid
27
+ import yaml
28
+ from typing import Any
29
+
30
+ import src # noqa: F401 — triggers logging setup
31
+
32
+ logger = logging.getLogger(__name__)
33
+
34
+
35
+ # ---------------------------------------------------------------------------
36
+ # Config
37
+ # ---------------------------------------------------------------------------
38
+
39
def _load_config() -> dict:
    """Parse ``config.yaml`` from the current working directory and return it."""
    with Path("config.yaml").open("r", encoding="utf-8") as config_file:
        settings = yaml.safe_load(config_file)
    return settings
42
+
43
+
44
+ # ---------------------------------------------------------------------------
45
+ # PubMedQA Ingestion (FR-01)
46
+ # ---------------------------------------------------------------------------
47
+
48
def ingest_pubmedqa(max_samples: int = 500) -> list[dict[str, Any]]:
    """
    Load PubMedQA from HuggingFace datasets and turn it into documents.

    Every non-blank context passage of a QA item becomes one document, and
    the item's long_answer (when present) becomes one more. At most
    `max_samples` QA items are read. If the dataset cannot be loaded the
    error is logged and an empty list is returned.

    pub_type = "research_abstract" → Tier 3 (SRS FR-03b)
    """
    # 'pqa_labeled' only has 1000 rows, so bigger requests switch to
    # 'pqa_artificial' (211k rows).
    if max_samples > 1000:
        split_name = "pqa_artificial"
    else:
        split_name = "pqa_labeled"
    logger.info("Loading PubMedQA split='%s' (max %d QA pairs)...", split_name, max_samples)

    try:
        from datasets import load_dataset
        dataset = load_dataset(
            "pubmed_qa", split_name, split="train", trust_remote_code=True
        )
    except Exception as exc:
        logger.error("Failed to load PubMedQA from HuggingFace: %s", exc)
        logger.error("Ensure you have an internet connection and datasets>=2.18.0")
        return []

    def _as_document(text: str, doc_id: str, title: str) -> dict:
        # Shared record layout for passages and long answers.
        return {
            "text": text,
            "title": title,
            "doc_id": doc_id,
            "source": "pubmedqa",
            "pub_type": "research_abstract",
            "pub_year": 0,
            "journal": "",
        }

    documents: list[dict] = []
    for position, item in enumerate(dataset):
        if position >= max_samples:
            break

        pub_id = str(item.get("pubid", uuid.uuid4().hex[:8]))
        title = item.get("question", "")[:200]

        # One document per context passage (abstract section).
        passages: list[str] = item.get("context", {}).get("contexts", [])
        documents.extend(
            _as_document(passage.strip(), f"pubmedqa_{pub_id}", title)
            for passage in passages
            if passage and passage.strip()
        )

        # The long_answer is the gold-standard explanation; index it as well.
        long_answer: str = item.get("long_answer", "").strip()
        if long_answer:
            documents.append(_as_document(long_answer, f"pubmedqa_{pub_id}_ans", title))

    logger.info(
        "PubMedQA: %d documents loaded from %d QA items",
        len(documents),
        min(max_samples, len(dataset)),
    )
    return documents
111
+
112
+
113
+ # ---------------------------------------------------------------------------
114
+ # MedQA-USMLE Ingestion (FR-01)
115
+ # ---------------------------------------------------------------------------
116
+
117
def ingest_medqa(
    data_dir: str = "data/raw/medqa",
    max_samples: int = 200,
) -> list[dict[str, Any]]:
    """
    Load MedQA-USMLE from local JSONL files.

    To obtain the data:
      git clone https://github.com/jind11/MedQA
      Copy the JSONL files from data_clean/questions/US/ to data/raw/medqa/

    Every JSONL file under `data_dir` (searched recursively) is read in
    sorted path order until `max_samples` documents have been collected.
    Blank lines, malformed JSON and rows that are not JSON objects are
    skipped with a warning.

    pub_type = "exam_question" → Tier 5 (SRS FR-03b)

    Args:
        data_dir    : Directory holding the MedQA JSONL files.
        max_samples : Maximum number of documents to return.

    Returns:
        One document dict per question; empty if no JSONL file is found.
    """
    data_path = Path(data_dir)
    # rglob already includes files directly inside data_dir. Combining it with
    # a plain glob listed each top-level file twice and ingested its rows twice.
    jsonl_files = sorted(data_path.rglob("*.jsonl"))

    if not jsonl_files:
        logger.warning(
            "MedQA data not found at '%s'. "
            "To get it: git clone https://github.com/jind11/MedQA "
            "and copy JSONL files to %s/",
            data_dir, data_dir,
        )
        return []

    logger.info("Loading MedQA from '%s' (%d files)...", data_dir, len(jsonl_files))
    documents: list[dict] = []

    for jsonl_file in jsonl_files:
        if len(documents) >= max_samples:
            break
        with open(jsonl_file, "r", encoding="utf-8") as f:
            for raw_line in f:
                if len(documents) >= max_samples:
                    break
                raw_line = raw_line.strip()
                if not raw_line:
                    continue
                try:
                    item = json.loads(raw_line)
                except json.JSONDecodeError as exc:
                    logger.warning("Skipping malformed JSON in %s: %s", jsonl_file.name, exc)
                    continue
                if not isinstance(item, dict):
                    # A valid JSON scalar/array is not a question record.
                    logger.warning("Skipping non-object JSON row in %s", jsonl_file.name)
                    continue

                question: str = item.get("question", "")
                options: dict = item.get("options") or {}
                answer_key: str = item.get("answer", "")
                answer_text: str = options.get(answer_key, "")

                # Combine question + all options + correct answer as document text
                opts_text = " ".join(f"{k}: {v}" for k, v in options.items())
                text = f"Question: {question}\nOptions: {opts_text}"
                if answer_text:
                    text += f"\nAnswer ({answer_key}): {answer_text}"

                documents.append({
                    "text": text,
                    "title": question[:200],
                    "doc_id": f"medqa_{uuid.uuid4().hex[:10]}",
                    "source": "medqa",
                    "pub_type": "exam_question",
                    "pub_year": 0,
                    "journal": "",
                })

    logger.info("MedQA: %d documents loaded", len(documents))
    return documents
184
+
185
+
186
+ # ---------------------------------------------------------------------------
187
+ # Helpers
188
+ # ---------------------------------------------------------------------------
189
+
190
def _save_raw_documents(documents: list[dict], output_path: str) -> None:
    """Write ``documents`` to ``output_path`` as JSON Lines, creating parent directories."""
    destination = Path(output_path)
    destination.parent.mkdir(parents=True, exist_ok=True)
    with destination.open("w", encoding="utf-8") as sink:
        # One JSON object per line; non-ASCII text is kept readable.
        sink.writelines(
            json.dumps(record, ensure_ascii=False) + "\n" for record in documents
        )
    logger.info("Saved %d raw documents to %s", len(documents), output_path)
197
+
198
+
199
def _save_chunks(chunks: list[dict], output_path: str) -> None:
    """Write ``chunks`` to ``output_path`` as JSON Lines, creating parent directories."""
    destination = Path(output_path)
    destination.parent.mkdir(parents=True, exist_ok=True)
    with destination.open("w", encoding="utf-8") as sink:
        # One JSON object per line; non-ASCII text is kept readable.
        sink.writelines(
            json.dumps(piece, ensure_ascii=False) + "\n" for piece in chunks
        )
    logger.info("Saved %d chunks to %s", len(chunks), output_path)
206
+
207
+
208
+ # ---------------------------------------------------------------------------
209
+ # Main
210
+ # ---------------------------------------------------------------------------
211
+
212
def main() -> None:
    """Command-line entry point for document ingestion.

    Loads the config, ingests PubMedQA and MedQA, saves the raw documents,
    chunks them via ``src.pipeline.chunker`` and saves the chunks. Exits with
    status 1 when no document could be loaded.
    """
    cli = argparse.ArgumentParser(description="MediRAG-Eval Document Ingestion (FR-01)")
    cli.add_argument("--pubmedqa", type=int, default=500, help="Max PubMedQA samples")
    cli.add_argument("--medqa", type=int, default=200, help="Max MedQA-USMLE samples")
    cli.add_argument(
        "--medqa-dir",
        default="data/raw/medqa",
        help="Directory containing MedQA JSONL files",
    )
    options = cli.parse_args()

    config = _load_config()

    # --- Ingest: PubMedQA first, then MedQA ---
    all_docs = [
        *ingest_pubmedqa(max_samples=options.pubmedqa),
        *ingest_medqa(data_dir=options.medqa_dir, max_samples=options.medqa),
    ]
    logger.info("Total documents ingested: %d", len(all_docs))

    if not all_docs:
        logger.error("No documents loaded. Check internet for PubMedQA and/or data/raw/medqa/ for MedQA.")
        sys.exit(1)

    # --- Keep a raw copy for inspection ---
    _save_raw_documents(all_docs, "data/raw/documents.jsonl")

    # --- Chunk (imported lazily, only needed on this path) ---
    from src.pipeline.chunker import chunk_documents
    chunks = chunk_documents(all_docs, config)
    logger.info("Total chunks produced: %d", len(chunks))

    # --- Persist chunks for the embedder step ---
    _save_chunks(chunks, "data/processed/chunks.jsonl")

    logger.info("Ingestion complete. Now run: python src/pipeline/embedder.py")


if __name__ == "__main__":
    main()
src/pipeline/privacy.py ADDED
@@ -0,0 +1,65 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ src/pipeline/privacy.py — PHI/PII Privacy Shield (The Sanitizer)
3
+ ==============================================================
4
+ Detects and redacts sensitive patient information before external API calls.
5
+ Supports names, dates, contact info, and generic medical IDs.
6
+ """
7
+ from __future__ import annotations
8
+ import re
9
+ import logging
10
+ from typing import Dict, Tuple
11
+
12
+ logger = logging.getLogger(__name__)
13
+
14
class PrivacyShield:
    """Regex-based PHI/PII sanitizer.

    ``redact`` swaps sensitive values for numbered placeholders such as
    ``[EMAIL_1]`` and returns the placeholder → original mapping;
    ``restore`` puts the originals back into a text containing placeholders.
    """

    def __init__(self):
        # Basic patterns for common PII
        self.patterns = {
            "EMAIL": r'[a-zA-Z0-9_.+-]+@[a-zA-Z0-9-]+\.[a-zA-Z0-9-.]+',
            "PHONE": r'\(?\d{3}\)?[-.\s]?\d{3}[-.\s]?\d{4}',
            "SSN": r'\b\d{3}-\d{2}-\d{4}\b',
            "DOB": r'\b\d{2}/\d{2}/\d{4}\b|\b\d{4}-\d{2}-\d{2}\b',
            "ID": r'\bPT-\d{4,8}\b|\bID:\s?\d{4,8}\b'
        }
        # Names are harder without heavy NER, so we start with common indicators or capital patterns
        # In a production app, we would use a dedicated medical NER model.
        self.name_pattern = r'\b(?:Mr\.|Ms\.|Mrs\.|Dr\.)\s[A-Z][a-z]+(?:\s[A-Z][a-z]+)?\b'

    @staticmethod
    def _substitute(text: str, pattern: str, label: str, mapping: Dict[str, str]) -> str:
        """Replace every value matching ``pattern`` with a ``[label_N]`` placeholder.

        Placeholders are numbered in order of first appearance, so the result
        is reproducible across runs. New placeholder → value pairs are added
        to ``mapping`` and the rewritten text is returned.
        """
        # dict.fromkeys de-duplicates while keeping first-occurrence order.
        found = list(dict.fromkeys(
            hit.group(0) for hit in re.finditer(pattern, text) if hit.group(0)
        ))

        placeholders: Dict[str, str] = {}
        for number, value in enumerate(found, start=1):
            placeholder = f"[{label}_{number}]"
            placeholders[value] = placeholder
            mapping[placeholder] = value

        # Longest values first: replacing "Dr. John" before "Dr. John Smith"
        # would leave " Smith" behind in the redacted text.
        for value in sorted(found, key=len, reverse=True):
            text = text.replace(value, placeholders[value])
        return text

    def redact(self, text: str) -> Tuple[str, Dict[str, str]]:
        """
        Redacts PHI in text and returns (redacted_text, placeholder_map).

        The structured patterns (email, phone, SSN, date, ID) are applied
        first, then titled names. ``placeholder_map`` maps each placeholder
        to the original value so ``restore`` can undo the redaction.
        """
        mapping: Dict[str, str] = {}
        redacted = text

        # 1. Redact specific patterns
        for label, pattern in self.patterns.items():
            redacted = self._substitute(redacted, pattern, label, mapping)

        # 2. Redact potential names
        redacted = self._substitute(redacted, self.name_pattern, "PATIENT_NAME", mapping)

        if mapping:
            logger.info("Privacy Shield: Redacted %d sensitive items.", len(mapping))

        return redacted, mapping

    def restore(self, text: str, mapping: Dict[str, str]) -> str:
        """
        Replaces placeholders in the AI response with original values.

        All placeholders are substituted in a single pass, so a restored
        value is never rescanned for further placeholders.
        """
        keys = [key for key in mapping if key]
        if not keys:
            return text
        keys.sort(key=len, reverse=True)
        matcher = re.compile("|".join(re.escape(key) for key in keys))
        return matcher.sub(lambda hit: mapping[hit.group(0)], text)


# Singleton instance
shield = PrivacyShield()
src/pipeline/retriever.py ADDED
@@ -0,0 +1,463 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ FR-04: Vector Retrieval
3
+ =======================
4
+ FAISS IndexFlatIP with L2-normalised vectors (inner product = cosine similarity).
5
+ Returns top-k chunks as (chunk_text, metadata_dict, similarity_score) tuples.
6
+
7
+ Usage (as a module):
8
+ from src.pipeline.retriever import Retriever
9
+ r = Retriever(config)
10
+ results = r.search("What is the treatment for Type 2 Diabetes?")
11
+ for text, meta, score in results:
12
+ print(score, meta["pub_type"], text[:80])
13
+
14
+ Usage (smoke test):
15
+ python src/pipeline/retriever.py
16
+ """
17
+ from __future__ import annotations
18
+
19
+ import sys
20
+ from pathlib import Path
21
+
22
+ sys.path.insert(0, str(Path(__file__).resolve().parent.parent.parent))
23
+
24
+ import logging
25
+ import pickle
26
+ from typing import Any
27
+
28
+ try:
29
+ import faiss
30
+ except ImportError:
31
+ faiss = None
32
+
33
+ import numpy as np
34
+ import yaml
35
+
36
+ logger = logging.getLogger(__name__)
37
+
38
+
39
+ class Retriever:
40
+ """
41
+ Hybrid FAISS + BM25 document retriever.
42
+
43
+ On first search, lazily builds a BM25 index over all chunk texts.
44
+ Each search runs both FAISS (semantic) and BM25 (keyword) then merges
45
+ results using Reciprocal Rank Fusion (RRF) for best-of-both precision
46
+ and recall.
47
+ """
48
+
49
+ RERANKER_MODEL = "cross-encoder/ms-marco-MiniLM-L-6-v2"
50
+ RERANK_CANDIDATES = 60 # retrieve this many via RRF, then re-rank to top_k
51
+
52
def __init__(self, config: dict) -> None:
    """Store the retrieval settings; every heavy resource is loaded lazily.

    Args:
        config: Application config. ``config["retrieval"]`` must provide
            ``top_k``, ``embedding_model``, ``index_path`` and
            ``metadata_path``.

    Raises:
        KeyError: If the ``retrieval`` section or one of those keys is missing.
    """
    self.config = config
    retrieval_cfg = config["retrieval"]
    self.top_k: int = retrieval_cfg["top_k"]
    self.model_name: str = retrieval_cfg["embedding_model"]
    self.index_path: str = retrieval_cfg["index_path"]
    self.meta_path: str = retrieval_cfg["metadata_path"]

    self._model = None
    self._reranker = None          # cross-encoder re-ranker, loaded lazily
    self._index = None
    self._metadata: dict[int, dict] | None = None
    self._bm25 = None              # built lazily on first search
    self._bm25_ids: list[int] = [] # maps bm25 row → faiss_idx

    # Lookup tables filled by _load_index. Defined here so they exist even
    # when loading has not happened yet or has failed.
    self._fda_index: dict[str, list[int]] = {}
    self._guideline_index: dict[str, list[int]] = {}
65
+
66
+ # ------------------------------------------------------------------
67
+ # Private loaders (lazy)
68
+ # ------------------------------------------------------------------
69
+
70
def _load_model(self) -> None:
    """Load the sentence-transformers embedding model on first use.

    Does nothing when a model is already loaded. Any failure is logged and
    leaves ``self._model`` as ``None`` instead of raising.
    """
    if self._model is not None:
        return

    try:
        from sentence_transformers import SentenceTransformer
        logger.info("Loading BioBERT: %s", self.model_name)
        self._model = SentenceTransformer(self.model_name)
        logger.info("BioBERT model loaded successfully.")
    except ImportError as missing:
        logger.error("sentence_transformers not installed: %s", missing)
        self._model = None
    except Exception as failure:
        logger.error("Failed to load embedding model '%s': %s — FAISS search will be skipped, falling back to BM25.", self.model_name, failure)
        self._model = None
83
+
84
def _load_reranker(self) -> None:
    """Load the cross-encoder re-ranker on first use.

    Does nothing once ``self._reranker`` is set. If loading fails for any
    reason, the string ``"unavailable"`` is stored in its place and a
    warning is logged.
    """
    if self._reranker is not None:
        return

    try:
        from sentence_transformers import CrossEncoder
        logger.info("Loading re-ranker: %s", self.RERANKER_MODEL)
        self._reranker = CrossEncoder(self.RERANKER_MODEL)
        logger.info("Re-ranker loaded.")
    except Exception as failure:
        logger.warning("Re-ranker unavailable (%s) — falling back to RRF ranking.", failure)
        self._reranker = "unavailable"
94
+
95
def _load_index(self) -> None:
    """Load the FAISS index and metadata store, then build the lookup tables.

    After a successful load, ``self._fda_index`` maps a lower-cased drug key
    (second ``_``-separated part of the doc_id) to the FDA DailyMed chunk
    ids, and ``self._guideline_index`` maps each clinical keyword to the
    guideline chunks whose text or title contains it.

    Load errors are logged, not raised: the index is reset to ``None`` and
    the metadata / lookup tables are left as (possibly empty) dicts.

    Raises:
        FileNotFoundError: If the FAISS index file does not exist.
    """
    if self._index is not None:
        return
    # Without faiss the index can never be loaded, so once the metadata is in
    # memory there is nothing left to do. Returning here avoids re-reading
    # the pickle and rebuilding the lookups on every call.
    if faiss is None and self._metadata is not None:
        return

    idx_path = Path(self.index_path)
    meta_path = Path(self.meta_path)

    if not idx_path.exists():
        raise FileNotFoundError(
            f"FAISS index not found at '{idx_path}'. "
            "Run python src/pipeline/ingest.py && python src/pipeline/embedder.py first."
        )

    # Keywords used to bucket guideline chunks; built once, not per chunk.
    guideline_keywords = (
        # Diabetes / ADA
        "diagnosis", "diagnostic", "treatment", "pharmacologic",
        "glycemic", "insulin", "obesity", "hypoglycemia",
        "screening", "complication", "pregnancy",
        "children", "adolescent", "older adult", "hospital",
        # Cardiovascular / ACC-AHA
        "hypertension", "blood pressure", "antihypertensive",
        "statin", "cholesterol", "ldl", "lipid", "triglyceride",
        "cardiovascular", "coronary", "heart disease", "stroke",
        "aspirin", "antiplatelet", "anticoagulant",
        "prevention", "risk reduction", "atherosclerosis",
        "heart failure", "ejection fraction",
        "smoking", "exercise", "diet", "lifestyle",
    )

    try:
        logger.info("Loading FAISS index from %s", idx_path)
        if faiss is not None:
            self._index = faiss.read_index(str(idx_path))
        else:
            self._index = None
            logger.warning("FAISS not installed — FAISS search disabled.")

        logger.info("Loading metadata store from %s", meta_path)
        # NOTE(security): pickle.load runs code embedded in the file. The
        # metadata store must be a trusted, locally built artifact.
        with open(meta_path, "rb") as f:
            self._metadata = pickle.load(f)

        logger.info(
            "Retriever ready: %d vectors, %d metadata entries",
            self._index.ntotal if self._index is not None else 0, len(self._metadata),
        )

        # Build both lookups into locals and publish them together, so a
        # failure part-way through never leaves half-filled tables behind.
        fda_index: dict[str, list[int]] = {}
        guideline_index: dict[str, list[int]] = {}
        for idx, meta in self._metadata.items():
            if meta.get("source") == "FDA DailyMed":
                # doc_id format: fda_{drug_name}_{set_id}
                parts = meta.get("doc_id", "").split("_")
                drug_key = parts[1].lower() if len(parts) >= 2 else ""
                if drug_key:
                    fda_index.setdefault(drug_key, []).append(idx)

            if meta.get("pub_type") == "clinical_guideline":
                text = (meta.get("chunk_text", "") + " " + meta.get("title", "")).lower()
                for keyword in guideline_keywords:
                    if keyword in text:
                        guideline_index.setdefault(keyword, []).append(idx)

        self._fda_index = fda_index
        self._guideline_index = guideline_index
        logger.info("FDA drug index built: %d unique drugs", len(self._fda_index))
        logger.info("Guideline index built: %d keyword entries", len(self._guideline_index))
    except Exception as e:
        logger.error("Failed to load FAISS index or metadata: %s", e)
        self._index = None
        if self._metadata is None:
            self._metadata = {}
        # Keep the lookup attributes defined even when loading failed.
        if not hasattr(self, "_fda_index"):
            self._fda_index = {}
        if not hasattr(self, "_guideline_index"):
            self._guideline_index = {}
164
+
165
def _build_bm25(self) -> None:
    """Ensure a BM25 index exists, delegating to ``rebuild_bm25`` when it does not."""
    if self._bm25 is None:
        self.rebuild_bm25()
170
+
171
def rebuild_bm25(self) -> None:
    """Build BM25 index — loads from cache if available, otherwise builds and saves.

    The cache (``bm25_cache.pkl`` next to the metadata store) is used only
    when it is at least as new as the metadata file. If ``rank-bm25`` is
    not installed, or no chunk has any text, ``self._bm25`` is left
    untouched and a warning is logged.
    """
    try:
        from rank_bm25 import BM25Okapi
    except ImportError:
        logger.warning("rank-bm25 not installed — falling back to FAISS-only.")
        return

    if self._metadata is None:
        self._load_index()
    metadata = self._metadata or {}

    # Cache path: alongside the metadata store
    meta_file = Path(self.meta_path)
    bm25_cache = meta_file.parent / "bm25_cache.pkl"
    meta_mtime = meta_file.stat().st_mtime if meta_file.exists() else 0

    # Load from cache if it exists and is newer than the metadata store
    if bm25_cache.exists() and bm25_cache.stat().st_mtime >= meta_mtime:
        try:
            logger.info("Loading BM25 index from cache %s …", bm25_cache)
            # NOTE(security): pickle.load runs code embedded in the file.
            # The cache must only ever be a file this process wrote itself.
            with open(bm25_cache, "rb") as f:
                cached = pickle.load(f)
            self._bm25 = cached["bm25"]
            self._bm25_ids = cached["ids"]
            logger.info("BM25 cache loaded (%d docs).", len(self._bm25_ids))
            return
        except Exception as e:
            logger.warning("BM25 cache load failed (%s) — rebuilding.", e)

    logger.info("Rebuilding BM25 index over %d chunks…", len(metadata))
    corpus_ids: list[int] = []
    corpus_tokens: list[list[str]] = []
    for faiss_idx, meta in metadata.items():
        text = meta.get("chunk_text", "")
        if text:
            corpus_ids.append(faiss_idx)
            corpus_tokens.append(text.lower().split())

    if not corpus_tokens:
        # BM25Okapi divides by the corpus size, so an empty corpus (e.g.
        # after a failed metadata load) would raise ZeroDivisionError.
        logger.warning("No chunk text available, so the BM25 index was not built.")
        return

    self._bm25 = BM25Okapi(corpus_tokens)
    self._bm25_ids = corpus_ids
    logger.info("BM25 index built (%d docs). Saving cache…", len(corpus_ids))

    try:
        with open(bm25_cache, "wb") as f:
            pickle.dump({"bm25": self._bm25, "ids": self._bm25_ids}, f,
                        protocol=pickle.HIGHEST_PROTOCOL)
        logger.info("BM25 cache saved to %s", bm25_cache)
    except Exception as e:
        logger.warning("BM25 cache save failed: %s", e)
219
+
220
+ def get_fda_chunks(self, drug_name: str, section_priority: list[str] | None = None) -> list[dict]:
221
+ """
222
+ Directly return FDA DailyMed chunks for a specific drug by name.
223
+ Bypasses FAISS/BM25 ranking — O(1) lookup, always finds the drug's label.
224
+ Used during intervention re-retrieval when entity_verifier identifies a drug.
225
+ """
226
+ self._load_index()
227
+ key = drug_name.lower().strip()
228
+ indices = getattr(self, "_fda_index", {}).get(key, [])
229
+ if not indices:
230
+ # Try partial match (e.g. "warfarin sodium" → "warfarin")
231
+ indices = next(
232
+ (v for k, v in getattr(self, "_fda_index", {}).items() if key in k or k in key),
233
+ []
234
+ )
235
+ chunks = []
236
+ priority = section_priority or ["CONTRAINDICATIONS", "ADVERSE REACTIONS",
237
+ "DOSAGE AND ADMINISTRATION", "WARNINGS AND PRECAUTIONS",
238
+ "DRUG INTERACTIONS", "INDICATIONS AND USAGE",
239
+ "USE IN SPECIFIC POPULATIONS"]
240
+ for idx in indices:
241
+ meta = self._metadata.get(idx, {})
242
+ chunk_text = meta.get("chunk_text", "")
243
+ section = next((s for s in priority if s in chunk_text.upper()), "OTHER")
244
+ chunks.append({
245
+ "text": chunk_text, "chunk_id": meta.get("chunk_id"),
246
+ "source": meta.get("source", ""), "pub_type": meta.get("pub_type", ""),
247
+ "pub_year": meta.get("pub_year"), "title": meta.get("title", ""),
248
+ "_section": section, "_priority": priority.index(section) if section in priority else 99,
249
+ })
250
+ chunks.sort(key=lambda c: c["_priority"])
251
+ return chunks[:5]
252
+
253
def get_guideline_chunks(self, query: str, top_n: int = 5) -> list[dict]:
    """
    Look up clinical guideline chunks by keyword overlap with the query.

    Skips FAISS/BM25 entirely: every keyword of ``self._guideline_index``
    that occurs as a substring of the lower-cased query votes for the chunk
    indices listed under it. Chunks are returned in descending vote order,
    at most ``top_n`` of them. Returns [] when the index is empty or no
    keyword occurs in the query.
    """
    self._load_index()
    needle = query.lower()
    keyword_map = getattr(self, "_guideline_index", {})
    if not keyword_map:
        return []

    # chunk index → number of distinct keywords that pointed at it
    votes: dict[int, int] = {}
    for term, chunk_indices in keyword_map.items():
        if term not in needle:
            continue
        for chunk_index in chunk_indices:
            votes[chunk_index] = votes.get(chunk_index, 0) + 1

    if not votes:
        return []

    # Most keyword hits first; ties keep first-seen order (stable sort).
    ranked = sorted(votes, key=lambda i: votes[i], reverse=True)

    selected = []
    for chunk_index in ranked[:top_n]:
        record = self._metadata.get(chunk_index, {})
        selected.append({
            "text": record.get("chunk_text", ""),
            "chunk_id": record.get("chunk_id"),
            "source": record.get("source", ""),
            "pub_type": record.get("pub_type", "clinical_guideline"),
            "pub_year": record.get("pub_year"),
            "title": record.get("title", ""),
        })
    return selected
289
+
290
+ # ------------------------------------------------------------------
291
+ # Public API
292
+ # ------------------------------------------------------------------
293
+
294
+ def search(
295
+ self,
296
+ query: str,
297
+ top_k: int | None = None,
298
+ ) -> list[tuple[str, dict[str, Any], float]]:
299
+ """
300
+ Hybrid semantic + keyword search using Reciprocal Rank Fusion.
301
+
302
+ Args:
303
+ query : Natural language query
304
+ top_k : Override config top_k if provided
305
+
306
+ Returns:
307
+ List of (chunk_text, metadata_dict, rrf_score),
308
+ sorted by descending combined score.
309
+ """
310
+ if not query or not query.strip():
311
+ logger.warning("Retriever.search called with empty query — returning []")
312
+ return []
313
+
314
+ k = top_k or self.top_k
315
+ # Fetch RERANK_CANDIDATES via RRF, then re-rank to top-k
316
+ fetch_k = max(self.RERANK_CANDIDATES, k * 3)
317
+ RRF_K = 60 # standard RRF constant (higher = smoother rank blending)
318
+
319
+ self._load_model()
320
+ self._load_reranker()
321
+ self._load_index()
322
+ self._build_bm25()
323
+
324
+ # ── 1. FAISS semantic search ──────────────────────────────────
325
+ faiss_ranks: dict[int, int] = {}
326
+ if self._model is not None and self._index is not None and faiss is not None:
327
+ try:
328
+ q_vec: np.ndarray = self._model.encode(
329
+ [query.strip()],
330
+ normalize_embeddings=True,
331
+ convert_to_numpy=True,
332
+ ).astype(np.float32)
333
+
334
+ scores_arr, idx_arr = self._index.search(q_vec, fetch_k)
335
+ faiss_scores = scores_arr[0]
336
+ faiss_indices = idx_arr[0]
337
+
338
+ # Map faiss_idx → rank (1-indexed)
339
+ for rank, (faiss_idx, score) in enumerate(zip(faiss_indices, faiss_scores), 1):
340
+ if faiss_idx != -1:
341
+ faiss_ranks[int(faiss_idx)] = rank
342
+
343
+ # Raw top-1 cosine similarity (IndexFlatIP + L2-norm = cosine).
344
+ # Used by main.py for coverage-gap detection — a poor match here
345
+ # means the topic is genuinely absent from the database.
346
+ _top_faiss_cosine = float(faiss_scores[0]) if len(faiss_scores) > 0 else 0.0
347
+ except Exception as e:
348
+ logger.error("FAISS search failed: %s", e)
349
+
350
+ # If FAISS failed but BM25 is available, continue with BM25-only (no stub)
351
+ if not faiss_ranks and self._bm25 is not None:
352
+ _top_faiss_cosine = 0.0 # no FAISS score available
353
+ logger.warning("FAISS model unavailable — using BM25-only search for this query.")
354
+
355
+ # Only return empty if BOTH are completely unavailable
356
+ if not faiss_ranks and self._bm25 is None:
357
+ logger.error("Both FAISS and BM25 are unavailable. Cannot retrieve. Check that the index exists and dependencies are installed.")
358
+ return []
359
+
360
+ # ── 2. BM25 keyword search ────────────────────────────────────
361
+ bm25_ranks: dict[int, int] = {}
362
+ if self._bm25 is not None:
363
+ query_tokens = query.lower().split()
364
+ bm25_scores_arr = self._bm25.get_scores(query_tokens)
365
+ # Get top fetch_k indices by BM25 score
366
+ top_bm25 = np.argsort(bm25_scores_arr)[::-1][:fetch_k]
367
+ for rank, corpus_pos in enumerate(top_bm25, 1):
368
+ if bm25_scores_arr[corpus_pos] > 0:
369
+ faiss_idx = self._bm25_ids[corpus_pos]
370
+ bm25_ranks[faiss_idx] = rank
371
+
372
+ # ── 3. Reciprocal Rank Fusion ─────────────────────────────────
373
+ # Score = 1/(k+rank_faiss) + 1/(k+rank_bm25)
374
+ # A chunk only in FAISS gets 1/(60+rank); only in BM25 gets 1/(60+rank)
375
+ # A chunk in BOTH gets the sum — it floats to the top
376
+ all_ids = set(faiss_ranks.keys()) | set(bm25_ranks.keys())
377
+ rrf_scores: dict[int, float] = {}
378
+ for faiss_idx in all_ids:
379
+ score = 0.0
380
+ if faiss_idx in faiss_ranks:
381
+ score += 1.0 / (RRF_K + faiss_ranks[faiss_idx])
382
+ if faiss_idx in bm25_ranks:
383
+ score += 1.0 / (RRF_K + bm25_ranks[faiss_idx])
384
+ rrf_scores[faiss_idx] = score
385
+
386
+ # Capture absolute quality BEFORE normalising (used for retrieval confidence gate)
387
+ max_rrf_absolute = max(rrf_scores.values()) if rrf_scores else 0.0
388
+
389
+ # Normalise RRF scores to [0, 1] for display
390
+ if rrf_scores and max_rrf_absolute > 0:
391
+ rrf_scores = {k: v / max_rrf_absolute for k, v in rrf_scores.items()}
392
+
393
+ # Sort by RRF score descending — take RERANK_CANDIDATES (not just top-k)
394
+ candidate_ids = sorted(rrf_scores.keys(), key=lambda i: rrf_scores[i], reverse=True)[:self.RERANK_CANDIDATES]
395
+
396
+ candidates: list[tuple[str, dict, float]] = []
397
+ for faiss_idx in candidate_ids:
398
+ meta = self._metadata.get(faiss_idx, {})
399
+ text = meta.get("chunk_text", "")
400
+ meta["_retrieval_confidence"] = round(max_rrf_absolute, 6)
401
+ meta["_top_faiss_cosine"] = round(_top_faiss_cosine, 4)
402
+ candidates.append((text, meta, rrf_scores[faiss_idx]))
403
+
404
+ # ── Re-ranking ────────────────────────────────────────────────────
405
+ # Cross-encoder scores every (query, chunk) pair directly.
406
+ # No volume bias — the right chunk wins on relevance regardless of source.
407
+ if self._reranker and self._reranker != "unavailable" and len(candidates) > k:
408
+ pairs = [(query, text) for text, _, _ in candidates]
409
+ rerank_scores = self._reranker.predict(pairs)
410
+ ranked = sorted(
411
+ zip(rerank_scores, candidates),
412
+ key=lambda x: x[0],
413
+ reverse=True,
414
+ )
415
+ results = [item for _, item in ranked[:k]]
416
+ logger.debug("Re-ranked %d candidates → top-%d", len(candidates), k)
417
+ else:
418
+ results = candidates[:k]
419
+
420
+ logger.debug(
421
+ "Hybrid query '%s...' → %d results (top RRF=%.4f) "
422
+ "[FAISS candidates: %d, BM25 candidates: %d]",
423
+ query[:40], len(results),
424
+ results[0][2] if results else 0.0,
425
+ len(faiss_ranks), len(bm25_ranks),
426
+ )
427
+ return results
428
+
429
+
430
+
431
+ # ---------------------------------------------------------------------------
432
+ # CLI smoke test
433
+ # ---------------------------------------------------------------------------
434
+
435
def _load_config(path: str = "config.yaml") -> dict:
    """Load the YAML configuration used by the CLI smoke test.

    Args:
        path: Location of the YAML file. Defaults to ``config.yaml``,
            resolved against the current working directory.

    Returns:
        The parsed document, or an empty dict when the file is empty
        (``yaml.safe_load`` yields ``None`` for an empty document).
    """
    with open(path, "r", encoding="utf-8") as f:
        return yaml.safe_load(f) or {}
438
+
439
+
440
if __name__ == "__main__":
    # Manual smoke test: run a few representative queries through the
    # retriever and print the top three hits for each.
    import src  # noqa: F401 — logging

    retriever = Retriever(_load_config())

    test_queries = (
        "What is the recommended dosage of Metformin for Type 2 Diabetes in elderly patients?",
        "Contraindications of ibuprofen for patients with chronic kidney disease",
        "First-line treatment for hypertension according to clinical guidelines",
    )

    for query in test_queries:
        print(f"\n{'='*70}")
        print(f"QUERY: {query}")
        print("=" * 70)
        results = retriever.search(query, top_k=3)
        if results:
            for rank, (text, meta, score) in enumerate(results, 1):
                print(f"\n Rank {rank} | score={score:.4f} | source={meta.get('source')} | "
                      f"tier_type={meta.get('pub_type')}")
                print(f" Title: {meta.get('title', '')[:80]}")
                print(f" Text : {text[:200]}...")
        else:
            print(" No results — is the FAISS index built?")
tests/test_api.py ADDED
@@ -0,0 +1,51 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import pytest
2
+ from fastapi.testclient import TestClient
3
+ from src.api.main import app
4
+
5
+ client = TestClient(app)
6
+
7
def test_health_endpoint():
    """GET /health answers 200 with status "ok" and an ``ollama_available`` field."""
    resp = client.get("/health")
    assert resp.status_code == 200

    body = resp.json()
    assert body["status"] == "ok"
    assert "ollama_available" in body
14
+
15
def test_evaluate_endpoint():
    """POST /evaluate with one mock context chunk returns a scored report."""
    mock_chunk = {
        "chunk_id": "mock-1",
        "text": "Metformin is a first-line medication for the treatment of type 2 diabetes. It is safe.",
        "source": "mock_db",
        "pub_type": "research_abstract",
        "pub_year": 2024,
        "title": "Study on Metformin safety",
    }
    request_body = {
        "question": "Is Metformin safe?",
        "answer": "Metformin is a safe and effective drug. It is recommended.",
        "context_chunks": [mock_chunk],
        "run_ragas": False,
    }

    # Since the evaluation modules load heavy ML models,
    # the first test call might take 10-15s to run.
    resp = client.post("/evaluate", json=request_body)
    assert resp.status_code == 200

    report = resp.json()
    assert "composite_score" in report
    assert "hrs" in report
    assert report["risk_band"] in ["LOW", "MODERATE", "HIGH", "CRITICAL"]
    assert "faithfulness" in report["module_results"]
43
+
44
def test_query_invalid_params():
    """POST /query rejects a too-short question with a 422 validation error."""
    too_short = "Hi"  # 2 chars — below min_length=5, triggers 422
    resp = client.post("/query", json={"question": too_short, "top_k": 5})
    # Unprocessable Entity (Pydantic validation error)
    assert resp.status_code == 422
tests/test_modules.py ADDED
@@ -0,0 +1,66 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import pytest
2
+ from src.modules.faithfulness import score_faithfulness
3
+ from src.modules.source_credibility import score_source_credibility
4
+ from src.modules.contradiction import score_contradiction
5
+ from src.evaluation.aggregator import aggregate
6
+
7
def test_source_credibility():
    """A research abstract plus an exam question should score between 0.3 and 0.5."""
    outcome = score_source_credibility([
        {"chunk_id": "c1", "pub_type": "research_abstract", "title": "Mock Paper"},
        {"chunk_id": "c2", "pub_type": "exam_question", "title": "Mock Exam Q"},
    ])

    assert outcome.score > 0.0
    assert 0.3 <= outcome.score <= 0.5
    assert outcome.details["chunk_count"] == 2
16
+
17
def test_faithfulness_nli():
    """An entailed answer scores at least 0.8; a contradicted one at most 0.2."""
    cases = [
        # (answer, supporting context, acceptance check on the score)
        ("The sky is blue.", "The sky is colored blue today.", lambda s: s >= 0.8),
        ("The sky is red.", "The sky is completely blue and not red.", lambda s: s <= 0.2),
    ]
    for answer, evidence, acceptable in cases:
        outcome = score_faithfulness(answer=answer, context_docs=[evidence])
        assert acceptable(outcome.score)
29
+
30
def test_aggregator_logic():
    """Check the weighted composite, HRS and risk band produced by ``aggregate``.

    With the weights below the expected composite is
    0.4*1.0 + 0.2*1.0 + 0.2*0.5 + 0.2*1.0 = 0.9.
    """
    weights = {
        "faithfulness": 0.4,
        "entity_accuracy": 0.2,
        "source_credibility": 0.2,
        "contradiction_risk": 0.2,
        "ragas_composite": 0.0,
    }

    class MockResult:
        """Minimal stand-in for a module result: score, error and latency_ms."""

        def __init__(self, score, error=None):
            self.score = score
            self.error = error
            self.latency_ms = 10

    res = aggregate(
        faithfulness_result=MockResult(1.0),
        entity_result=MockResult(1.0),
        source_result=MockResult(0.5),
        contradiction_result=MockResult(1.0),
        weights=weights,
    )
    assert abs(res.score - 0.9) < 0.01
    assert res.details["hrs"] == 10
    assert res.details["risk_band"] == "LOW"