Spaces:
Running
Running
Commit ·
b6f9fa8
0
Parent(s):
deploy: clean build
Browse files- .dockerignore +9 -0
- .gitattributes +9 -0
- .gitignore +39 -0
- Dockerfile +40 -0
- README.md +56 -0
- app.py +98 -0
- app_demo.py +172 -0
- config.yaml +71 -0
- conftest.py +12 -0
- demo/.gitkeep +1 -0
- pytest.ini +6 -0
- requirements.txt +35 -0
- requirements_hf.txt +46 -0
- requirements_minimal.txt +43 -0
- scripts/build_rxnorm_cache.py +347 -0
- scripts/debug_pmc.py +54 -0
- scripts/download_dailymed.py +259 -0
- scripts/download_guidelines.py +399 -0
- scripts/fix_fda_chunk_text.py +120 -0
- scripts/ingest_incremental.py +192 -0
- scripts/warmup.py +58 -0
- setup.py +15 -0
- src/__init__.py +44 -0
- src/api/__init__.py +1 -0
- src/api/main.py +933 -0
- src/api/schemas.py +276 -0
- src/cli.py +70 -0
- src/dashboard/__init__.py +1 -0
- src/evaluate.py +289 -0
- src/evaluation/__init__.py +1 -0
- src/evaluation/aggregator.py +173 -0
- src/evaluation/ragas_eval.py +177 -0
- src/modules/__init__.py +127 -0
- src/modules/base.py +4 -0
- src/modules/contradiction.py +259 -0
- src/modules/entity_verifier.py +334 -0
- src/modules/faithfulness.py +302 -0
- src/modules/source_credibility.py +204 -0
- src/pipeline/__init__.py +1 -0
- src/pipeline/chunker.py +82 -0
- src/pipeline/consensus.py +111 -0
- src/pipeline/embedder.py +163 -0
- src/pipeline/generator.py +584 -0
- src/pipeline/ingest.py +250 -0
- src/pipeline/privacy.py +65 -0
- src/pipeline/retriever.py +463 -0
- tests/test_api.py +51 -0
- tests/test_modules.py +66 -0
.dockerignore
ADDED
|
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Override .gitignore for Docker builds
|
| 2 |
+
# Include necessary data files
|
| 3 |
+
!data/index/
|
| 4 |
+
!data/index/*
|
| 5 |
+
|
| 6 |
+
# Exclude everything else from data
|
| 7 |
+
data/raw/*
|
| 8 |
+
data/processed/*
|
| 9 |
+
logs/*
|
.gitattributes
ADDED
|
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Hugging Face Spaces - Git Attributes
|
| 2 |
+
*.pkl filter=lfs diff=lfs merge=lfs -text
|
| 3 |
+
*.index filter=lfs diff=lfs merge=lfs -text
|
| 4 |
+
*.bin filter=lfs diff=lfs merge=lfs -text
|
| 5 |
+
*.pt filter=lfs diff=lfs merge=lfs -text
|
| 6 |
+
*.pth filter=lfs diff=lfs merge=lfs -text
|
| 7 |
+
*.h5 filter=lfs diff=lfs merge=lfs -text
|
| 8 |
+
*.onnx filter=lfs diff=lfs merge=lfs -text
|
| 9 |
+
*.safetensors filter=lfs diff=lfs merge=lfs -text
|
.gitignore
ADDED
|
@@ -0,0 +1,39 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Ignore all large database files and directories
|
| 2 |
+
data/
|
| 3 |
+
|
| 4 |
+
# Python
|
| 5 |
+
__pycache__/
|
| 6 |
+
*.pyc
|
| 7 |
+
*.pyo
|
| 8 |
+
*.pyd
|
| 9 |
+
*.egg-info/
|
| 10 |
+
dist/
|
| 11 |
+
build/
|
| 12 |
+
.eggs/
|
| 13 |
+
|
| 14 |
+
# Environments
|
| 15 |
+
venv/
|
| 16 |
+
.venv/
|
| 17 |
+
env/
|
| 18 |
+
.env
|
| 19 |
+
|
| 20 |
+
# Logs (generated at runtime)
|
| 21 |
+
logs/
|
| 22 |
+
|
| 23 |
+
# IDE
|
| 24 |
+
.vscode/
|
| 25 |
+
.idea/
|
| 26 |
+
*.suo
|
| 27 |
+
*.user
|
| 28 |
+
|
| 29 |
+
# OS
|
| 30 |
+
.DS_Store
|
| 31 |
+
Thumbs.db
|
| 32 |
+
|
| 33 |
+
# Notebooks checkpoints
|
| 34 |
+
.ipynb_checkpoints/
|
| 35 |
+
|
| 36 |
+
# Temporary files
|
| 37 |
+
*.tmp
|
| 38 |
+
*.bak
|
| 39 |
+
.env
|
Dockerfile
ADDED
|
@@ -0,0 +1,40 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# MediRAG Backend - Hugging Face Spaces Docker Deployment
|
| 2 |
+
# Optimized for faster builds
|
| 3 |
+
FROM python:3.10-slim
|
| 4 |
+
|
| 5 |
+
WORKDIR /app
|
| 6 |
+
|
| 7 |
+
# Install system dependencies (libmupdf deps bundled in pymupdf wheel, no extra needed)
|
| 8 |
+
RUN apt-get update && apt-get install -y \
|
| 9 |
+
git \
|
| 10 |
+
curl \
|
| 11 |
+
build-essential \
|
| 12 |
+
&& rm -rf /var/lib/apt/lists/*
|
| 13 |
+
|
| 14 |
+
# Set environment variables
|
| 15 |
+
ENV PYTHONUNBUFFERED=1
|
| 16 |
+
ENV TRANSFORMERS_CACHE=/tmp/transformers_cache
|
| 17 |
+
ENV HF_HOME=/tmp/hf_home
|
| 18 |
+
ENV TORCH_HOME=/tmp/torch_cache
|
| 19 |
+
ENV PIP_NO_CACHE_DIR=1
|
| 20 |
+
ENV PIP_DISABLE_PIP_VERSION_CHECK=1
|
| 21 |
+
|
| 22 |
+
# Copy requirements first for better caching
|
| 23 |
+
COPY requirements_minimal.txt .
|
| 24 |
+
|
| 25 |
+
# Force pip re-run by busting the cache (update this date to force full reinstall)
|
| 26 |
+
ARG CACHE_BUST=2026-04-12-v3
|
| 27 |
+
RUN pip install --no-cache-dir -r requirements_minimal.txt
|
| 28 |
+
|
| 29 |
+
# Copy the rest of the application
|
| 30 |
+
COPY . .
|
| 31 |
+
|
| 32 |
+
# Create necessary directories
|
| 33 |
+
RUN mkdir -p data/processed data/raw logs
|
| 34 |
+
|
| 35 |
+
# Expose port (Hugging Face Spaces uses 7860)
|
| 36 |
+
EXPOSE 7860
|
| 37 |
+
|
| 38 |
+
# Run FastAPI backend directly
|
| 39 |
+
CMD ["uvicorn", "app:app", "--host", "0.0.0.0", "--port", "7860"]
|
| 40 |
+
|
README.md
ADDED
|
@@ -0,0 +1,56 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
---
|
| 2 |
+
title: MediRAG API
|
| 3 |
+
emoji: 🏥
|
| 4 |
+
colorFrom: blue
|
| 5 |
+
colorTo: green
|
| 6 |
+
sdk: docker
|
| 7 |
+
pinned: false
|
| 8 |
+
---
|
| 9 |
+
|
| 10 |
+
# MediRAG Backend - Hugging Face Spaces (Docker)
|
| 11 |
+
|
| 12 |
+
🏥 **Medical RAG System with Hallucination Detection**
|
| 13 |
+
|
| 14 |
+
This is the **backend API** for MediRAG 2.0, designed to work with a **React frontend**.
|
| 15 |
+
|
| 16 |
+
## 🐳 Docker Deployment
|
| 17 |
+
|
| 18 |
+
This Space provides the backend API. The React frontend connects to this backend.
|
| 19 |
+
|
| 20 |
+
### Backend Features
|
| 21 |
+
- 🔍 **Hybrid Retrieval**: FAISS (BioBERT) + BM25 keyword search
|
| 22 |
+
- 🧠 **LLM Generation**: Mistral/Gemini for medical answer generation
|
| 23 |
+
- 🛡️ **4-Layer Audit**: Faithfulness, Entity Verification, Source Credibility, Contradiction Detection
|
| 24 |
+
- ⚠️ **Safety Interventions**: Auto-blocks high-risk responses
|
| 25 |
+
- 📊 **Health Risk Score (HRS)**: 0-100 composite safety metric
|
| 26 |
+
- 🔌 **REST API**: Full FastAPI endpoints for React frontend
|
| 27 |
+
|
| 28 |
+
## 🚀 Usage
|
| 29 |
+
|
| 30 |
+
### For React Frontend
|
| 31 |
+
Connect your React app to this backend:
|
| 32 |
+
```javascript
|
| 33 |
+
const API_URL = "https://joytheslothh-medirag-api.hf.space";
|
| 34 |
+
```
|
| 35 |
+
|
| 36 |
+
### API Endpoints
|
| 37 |
+
- `GET /health` - Health check
|
| 38 |
+
- `POST /query` - Full RAG pipeline
|
| 39 |
+
- `POST /evaluate` - Evaluate answer
|
| 40 |
+
- `GET /docs` - Swagger API documentation
|
| 41 |
+
|
| 42 |
+
### Environment Variables
|
| 43 |
+
Set in Hugging Face Space settings:
|
| 44 |
+
- `MISTRAL_API_KEY` - For Mistral LLM
|
| 45 |
+
- `GOOGLE_API_KEY` - For Gemini LLM
|
| 46 |
+
|
| 47 |
+
## 🏗️ Architecture
|
| 48 |
+
```
|
| 49 |
+
React Frontend → FastAPI Backend → RAG Pipeline → Response
|
| 50 |
+
```
|
| 51 |
+
|
| 52 |
+
## ⚠️ Disclaimer
|
| 53 |
+
**This system is for research purposes only. Always consult qualified medical professionals for health decisions.**
|
| 54 |
+
|
| 55 |
+
## 📄 License
|
| 56 |
+
MIT License - See repository for details.
|
app.py
ADDED
|
@@ -0,0 +1,98 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
MediRAG Backend - FastAPI only (No Gradio)
|
| 3 |
+
React frontend on Vercel, this is just the API backend
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
import os
|
| 7 |
+
import sys
|
| 8 |
+
import subprocess
|
| 9 |
+
import logging
|
| 10 |
+
import requests
|
| 11 |
+
|
| 12 |
+
# Configure logging
|
| 13 |
+
logging.basicConfig(level=logging.INFO)
|
| 14 |
+
logger = logging.getLogger(__name__)
|
| 15 |
+
|
| 16 |
+
# Set cache directories for Hugging Face
|
| 17 |
+
os.environ["TRANSFORMERS_CACHE"] = "/tmp/transformers_cache"
|
| 18 |
+
os.environ["HF_HOME"] = "/tmp/hf_home"
|
| 19 |
+
os.environ["TORCH_HOME"] = "/tmp/torch_cache"
|
| 20 |
+
|
| 21 |
+
# Add src to path
|
| 22 |
+
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "src"))
|
| 23 |
+
|
| 24 |
+
# Install spaCy model if not present (optional — server starts without it)
|
| 25 |
+
try:
|
| 26 |
+
import spacy
|
| 27 |
+
try:
|
| 28 |
+
spacy.load("en_core_sci_lg")
|
| 29 |
+
logger.info("spaCy model en_core_sci_lg loaded.")
|
| 30 |
+
except OSError:
|
| 31 |
+
# Try installing the model at runtime
|
| 32 |
+
try:
|
| 33 |
+
logger.info("Attempting to install scispacy model en_core_sci_lg...")
|
| 34 |
+
subprocess.run([
|
| 35 |
+
sys.executable, "-m", "pip", "install", "--quiet",
|
| 36 |
+
"https://s3-us-west-2.amazonaws.com/ai2-s2-scispacy/releases/v0.5.4/en_core_sci_lg-0.5.4.tar.gz"
|
| 37 |
+
], check=True, timeout=300)
|
| 38 |
+
spacy.load("en_core_sci_lg")
|
| 39 |
+
logger.info("spaCy model installed and loaded.")
|
| 40 |
+
except Exception as model_err:
|
| 41 |
+
logger.warning(f"Could not install spaCy model: {model_err}. NER features will be limited.")
|
| 42 |
+
except ImportError:
|
| 43 |
+
logger.warning("spacy/scispacy not installed. NER features will be limited but server will still start.")
|
| 44 |
+
|
| 45 |
+
# Download datasets using huggingface_hub
|
| 46 |
+
from huggingface_hub import hf_hub_download
|
| 47 |
+
|
| 48 |
+
# Check and download index and data files
|
| 49 |
+
data_dir = os.path.join(os.path.dirname(__file__), "data")
|
| 50 |
+
index_dir = os.path.join(data_dir, "index")
|
| 51 |
+
os.makedirs(index_dir, exist_ok=True)
|
| 52 |
+
|
| 53 |
+
faiss_path = os.path.join(index_dir, "faiss.index")
|
| 54 |
+
metadata_path = os.path.join(index_dir, "metadata_store.pkl")
|
| 55 |
+
bm25_path = os.path.join(index_dir, "bm25_cache.pkl")
|
| 56 |
+
vocab_path = os.path.join(data_dir, "drugbank vocabulary.csv")
|
| 57 |
+
rxnorm_path = os.path.join(data_dir, "rxnorm_cache.csv")
|
| 58 |
+
|
| 59 |
+
def download_dataset_files():
|
| 60 |
+
"""Download FAISS index and other core data from Hugging Face Dataset"""
|
| 61 |
+
repo_id = "joytheslothh/MediRAG-Index-Data"
|
| 62 |
+
token = os.environ.get("HF_TOKEN")
|
| 63 |
+
if not token:
|
| 64 |
+
logger.warning("HF_TOKEN environment variable is not set. Dataset download might fail if repo is private.")
|
| 65 |
+
|
| 66 |
+
try:
|
| 67 |
+
if not os.path.exists(faiss_path):
|
| 68 |
+
logger.info("Downloading faiss.index from HF dataset...")
|
| 69 |
+
hf_hub_download(repo_id=repo_id, filename="index/faiss.index", local_dir=data_dir, repo_type="dataset", token=token)
|
| 70 |
+
if not os.path.exists(metadata_path):
|
| 71 |
+
logger.info("Downloading metadata_store.pkl from HF dataset...")
|
| 72 |
+
hf_hub_download(repo_id=repo_id, filename="index/metadata_store.pkl", local_dir=data_dir, repo_type="dataset", token=token)
|
| 73 |
+
if not os.path.exists(bm25_path):
|
| 74 |
+
logger.info("Downloading bm25_cache.pkl from HF dataset...")
|
| 75 |
+
hf_hub_download(repo_id=repo_id, filename="index/bm25_cache.pkl", local_dir=data_dir, repo_type="dataset", token=token)
|
| 76 |
+
if not os.path.exists(vocab_path):
|
| 77 |
+
logger.info("Downloading drugbank vocabulary.csv from HF dataset...")
|
| 78 |
+
hf_hub_download(repo_id=repo_id, filename="drugbank vocabulary.csv", local_dir=data_dir, repo_type="dataset", token=token)
|
| 79 |
+
if not os.path.exists(rxnorm_path):
|
| 80 |
+
logger.info("Downloading rxnorm_cache.csv from HF dataset...")
|
| 81 |
+
hf_hub_download(repo_id=repo_id, filename="rxnorm_cache.csv", local_dir=data_dir, repo_type="dataset", token=token)
|
| 82 |
+
except Exception as e:
|
| 83 |
+
logger.error(f"Failed to download dataset files: {e}")
|
| 84 |
+
logger.warning("Backend may not start correctly or queries may fail.")
|
| 85 |
+
|
| 86 |
+
# Trigger download at startup
|
| 87 |
+
download_dataset_files()
|
| 88 |
+
|
| 89 |
+
# Import FastAPI app - this is the main backend for React frontend
|
| 90 |
+
from src.api.main import app
|
| 91 |
+
|
| 92 |
+
if __name__ == "__main__":
|
| 93 |
+
import uvicorn
|
| 94 |
+
# Get port from environment (Hugging Face uses 7860)
|
| 95 |
+
port = int(os.environ.get("PORT", 7860))
|
| 96 |
+
|
| 97 |
+
logger.info("Starting FastAPI backend on port {}".format(port))
|
| 98 |
+
uvicorn.run(app, host="0.0.0.0", port=port)
|
app_demo.py
ADDED
|
@@ -0,0 +1,172 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
MediRAG Backend - Local Demo Version
|
| 3 |
+
Simplified version for local testing without heavy models
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
import os
|
| 7 |
+
import gradio as gr
|
| 8 |
+
|
| 9 |
+
# Mock functions for demo
|
| 10 |
+
def health_check():
|
| 11 |
+
return {"status": "ok", "demo_mode": True}
|
| 12 |
+
|
| 13 |
+
def query_medical(question: str, top_k: int = 5, mistral_api_key: str = "", google_api_key: str = ""):
|
| 14 |
+
"""Demo version - returns mock response"""
|
| 15 |
+
|
| 16 |
+
# Simulate processing
|
| 17 |
+
demo_answer = f"""
|
| 18 |
+
This is a DEMO response for: "{question}"
|
| 19 |
+
|
| 20 |
+
In the full version, this would:
|
| 21 |
+
1. Retrieve relevant medical documents from FAISS index
|
| 22 |
+
2. Generate answer using Mistral/Gemini LLM
|
| 23 |
+
3. Evaluate with 4-layer audit system
|
| 24 |
+
4. Return Health Risk Score (HRS)
|
| 25 |
+
|
| 26 |
+
**To run full version:**
|
| 27 |
+
- Deploy to Hugging Face Spaces (Docker)
|
| 28 |
+
- Or install all dependencies locally
|
| 29 |
+
"""
|
| 30 |
+
|
| 31 |
+
demo_output = f"""
|
| 32 |
+
🏥 **MEDICAL ANSWER (DEMO MODE)**
|
| 33 |
+
|
| 34 |
+
{demo_answer}
|
| 35 |
+
|
| 36 |
+
---
|
| 37 |
+
📊 **RISK ASSESSMENT**
|
| 38 |
+
• Health Risk Score (HRS): 25/100 (DEMO)
|
| 39 |
+
• Risk Band: LOW
|
| 40 |
+
• Confidence: MEDIUM
|
| 41 |
+
|
| 42 |
+
---
|
| 43 |
+
🧪 **MODULE SCORES (DEMO)**
|
| 44 |
+
✓ Faithfulness: 0.85
|
| 45 |
+
✓ Entity Accuracy: 0.90
|
| 46 |
+
✓ Source Credibility: 0.88
|
| 47 |
+
✓ Contradiction Risk: 0.95
|
| 48 |
+
|
| 49 |
+
---
|
| 50 |
+
📚 **TOP SOURCES (DEMO)**
|
| 51 |
+
📄 Source 1: PubMed - Clinical Study (Score: 0.923)
|
| 52 |
+
This is a placeholder for retrieved medical literature...
|
| 53 |
+
|
| 54 |
+
📄 Source 2: PMC - Systematic Review (Score: 0.891)
|
| 55 |
+
Another placeholder for medical evidence...
|
| 56 |
+
|
| 57 |
+
---
|
| 58 |
+
⏱️ Total Time: 1250ms (DEMO)
|
| 59 |
+
|
| 60 |
+
---
|
| 61 |
+
⚠️ **NOTE**: This is running in DEMO mode without the full ML models.
|
| 62 |
+
For full functionality, deploy to Hugging Face Spaces or install all dependencies.
|
| 63 |
+
""".strip()
|
| 64 |
+
|
| 65 |
+
return demo_output
|
| 66 |
+
|
| 67 |
+
# Create Gradio interface
|
| 68 |
+
with gr.Blocks(title="MediRAG - Medical AI Demo") as demo:
|
| 69 |
+
gr.Markdown("""
|
| 70 |
+
# 🏥 MediRAG 2.0 - DEMO MODE
|
| 71 |
+
## Medical Question Answering with Hallucination Detection
|
| 72 |
+
|
| 73 |
+
**⚠️ This is a DEMO version for local testing.**
|
| 74 |
+
|
| 75 |
+
The full version includes:
|
| 76 |
+
- 107,425+ medical documents in FAISS index
|
| 77 |
+
- BioBERT embeddings for retrieval
|
| 78 |
+
- Mistral/Gemini LLM for generation
|
| 79 |
+
- 4-layer audit system (DeBERTa-v3, SciSpaCy)
|
| 80 |
+
- Health Risk Score calculation
|
| 81 |
+
|
| 82 |
+
**Deploy to Hugging Face Spaces for full functionality:**
|
| 83 |
+
https://huggingface.co/spaces/joytheslothh/MediRAG-API
|
| 84 |
+
""")
|
| 85 |
+
|
| 86 |
+
with gr.Accordion("⚙️ API Configuration (Optional)", open=False):
|
| 87 |
+
gr.Markdown("""
|
| 88 |
+
In the full version, provide your API keys for LLM generation:
|
| 89 |
+
- **Mistral API Key**: https://console.mistral.ai/
|
| 90 |
+
- **Google API Key**: https://makersuite.google.com/app/apikey
|
| 91 |
+
""")
|
| 92 |
+
with gr.Row():
|
| 93 |
+
mistral_key_input = gr.Textbox(
|
| 94 |
+
label="Mistral API Key",
|
| 95 |
+
placeholder="Enter your Mistral API key (full version only)",
|
| 96 |
+
type="password",
|
| 97 |
+
value=""
|
| 98 |
+
)
|
| 99 |
+
google_key_input = gr.Textbox(
|
| 100 |
+
label="Google API Key (Gemini)",
|
| 101 |
+
placeholder="Enter your Google API key (full version only)",
|
| 102 |
+
type="password",
|
| 103 |
+
value=""
|
| 104 |
+
)
|
| 105 |
+
|
| 106 |
+
with gr.Row():
|
| 107 |
+
with gr.Column():
|
| 108 |
+
question_input = gr.Textbox(
|
| 109 |
+
label="Your Medical Question",
|
| 110 |
+
placeholder="e.g., What are the side effects of metformin?",
|
| 111 |
+
lines=3
|
| 112 |
+
)
|
| 113 |
+
top_k_slider = gr.Slider(
|
| 114 |
+
minimum=1,
|
| 115 |
+
maximum=10,
|
| 116 |
+
value=5,
|
| 117 |
+
step=1,
|
| 118 |
+
label="Number of Sources to Retrieve"
|
| 119 |
+
)
|
| 120 |
+
submit_btn = gr.Button("🔍 Ask MediRAG (Demo)", variant="primary")
|
| 121 |
+
|
| 122 |
+
with gr.Column():
|
| 123 |
+
output_text = gr.Markdown(label="Response")
|
| 124 |
+
|
| 125 |
+
submit_btn.click(
|
| 126 |
+
fn=query_medical,
|
| 127 |
+
inputs=[question_input, top_k_slider, mistral_key_input, google_key_input],
|
| 128 |
+
outputs=output_text
|
| 129 |
+
)
|
| 130 |
+
|
| 131 |
+
gr.Markdown("""
|
| 132 |
+
---
|
| 133 |
+
### 🚀 How to Run Full Version
|
| 134 |
+
|
| 135 |
+
**Option 1: Hugging Face Spaces (Recommended)**
|
| 136 |
+
```
|
| 137 |
+
1. Visit: https://huggingface.co/spaces/joytheslothh/MediRAG-API
|
| 138 |
+
2. The full app is already deployed there!
|
| 139 |
+
```
|
| 140 |
+
|
| 141 |
+
**Option 2: Local with Docker**
|
| 142 |
+
```bash
|
| 143 |
+
cd Backend
|
| 144 |
+
docker build -t medirag .
|
| 145 |
+
docker run -p 7860:7860 medirag
|
| 146 |
+
```
|
| 147 |
+
|
| 148 |
+
**Option 3: Local with Virtual Environment**
|
| 149 |
+
```bash
|
| 150 |
+
cd Backend
|
| 151 |
+
python -m venv venv
|
| 152 |
+
venv\Scripts\activate
|
| 153 |
+
pip install -r requirements_hf.txt
|
| 154 |
+
python -m spacy download en_core_sci_lg
|
| 155 |
+
python app.py
|
| 156 |
+
```
|
| 157 |
+
|
| 158 |
+
### 🔬 Full System Features
|
| 159 |
+
- **Faithfulness**: DeBERTa-v3 NLI model checks claim support
|
| 160 |
+
- **Entity Verification**: SciSpaCy + DrugBank for drug/dosage validation
|
| 161 |
+
- **Source Credibility**: Ranks evidence by publication tier
|
| 162 |
+
- **Contradiction Detection**: Internal NLI cross-check for self-contradictions
|
| 163 |
+
""")
|
| 164 |
+
|
| 165 |
+
if __name__ == "__main__":
|
| 166 |
+
port = int(os.environ.get("PORT", 7860))
|
| 167 |
+
demo.launch(
|
| 168 |
+
server_name="0.0.0.0",
|
| 169 |
+
server_port=port,
|
| 170 |
+
share=False,
|
| 171 |
+
show_error=True
|
| 172 |
+
)
|
config.yaml
ADDED
|
@@ -0,0 +1,71 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
retrieval:
|
| 2 |
+
top_k: 5
|
| 3 |
+
chunk_size: 512
|
| 4 |
+
chunk_overlap: 50
|
| 5 |
+
embedding_model: dmis-lab/biobert-v1.1
|
| 6 |
+
index_path: data/index/faiss.index
|
| 7 |
+
metadata_path: data/index/metadata_store.pkl
|
| 8 |
+
|
| 9 |
+
modules:
|
| 10 |
+
faithfulness:
|
| 11 |
+
nli_model: cnut1648/biolinkbert-mednli
|
| 12 |
+
entailment_threshold: 0.75
|
| 13 |
+
max_nli_tokens: 510
|
| 14 |
+
truncate_side: left # keep END of context (clinical values appear last)
|
| 15 |
+
deberta_batch_size: 4 # Colab T4: safe at 8 | CPU 16GB: use 2 | OOM: system retries at 1
|
| 16 |
+
entity_verifier:
|
| 17 |
+
spacy_model: en_ner_bc5cdr_md
|
| 18 |
+
critical_entity_types: [DRUG, DOSAGE]
|
| 19 |
+
dosage_tolerance_pct: 10 # >10% numerical difference → CRITICAL
|
| 20 |
+
rxnorm_api_url: https://rxnav.nlm.nih.gov/REST/approximateTerm.json
|
| 21 |
+
rxnorm_api_timeout_s: 3
|
| 22 |
+
rxnorm_cache_path: data/rxnorm_cache.csv
|
| 23 |
+
source_credibility:
|
| 24 |
+
method: keyword # "keyword" = demo (FR-11a) | "metadata" = May (FR-11b)
|
| 25 |
+
# tier weights are defined by name in src/modules/source_credibility.py TIER_WEIGHTS dict
|
| 26 |
+
# clinical_guideline=1.0, drug_label=0.90, systematic_review=0.85,
|
| 27 |
+
# research_abstract=0.70, review_article=0.60, clinical_case=0.50, unknown=0.30
|
| 28 |
+
contradiction:
|
| 29 |
+
nli_model: cnut1648/biolinkbert-mednli # same model as faithfulness — load once
|
| 30 |
+
confidence_threshold: 0.75
|
| 31 |
+
max_sentence_pairs: 45 # skip if N > 10 sentences, check adjacent + (first,last)
|
| 32 |
+
deberta_batch_size: 4
|
| 33 |
+
|
| 34 |
+
aggregator:
|
| 35 |
+
weights:
|
| 36 |
+
faithfulness: 0.35
|
| 37 |
+
entity_accuracy: 0.20
|
| 38 |
+
source_credibility: 0.20
|
| 39 |
+
contradiction_risk: 0.15
|
| 40 |
+
ragas_composite: 0.10
|
| 41 |
+
risk_bands:
|
| 42 |
+
low: [0, 30]
|
| 43 |
+
moderate: [31, 60]
|
| 44 |
+
high: [61, 85]
|
| 45 |
+
critical: [86, 100]
|
| 46 |
+
|
| 47 |
+
llm:
|
| 48 |
+
provider: mistral
|
| 49 |
+
gemini_api_key: ${GEMINI_API_KEY}
|
| 50 |
+
mistral_api_key: ${MISTRAL_API_KEY}
|
| 51 |
+
groq_api_key: ${GROQ_API_KEY}
|
| 52 |
+
model: mistral-large-latest
|
| 53 |
+
gemini_model: gemini-2.0-flash
|
| 54 |
+
groq_model: llama-3.3-70b-versatile
|
| 55 |
+
base_url: http://localhost:11434
|
| 56 |
+
timeout_seconds: 120
|
| 57 |
+
judge_temperature: 0.0
|
| 58 |
+
generation_temperature: 0.7
|
| 59 |
+
|
| 60 |
+
api:
|
| 61 |
+
host: 0.0.0.0
|
| 62 |
+
port: 8000
|
| 63 |
+
max_query_length: 500
|
| 64 |
+
max_answer_length: 2000
|
| 65 |
+
max_chunks: 10
|
| 66 |
+
max_chunk_length: 2000
|
| 67 |
+
|
| 68 |
+
logging:
|
| 69 |
+
level: INFO # set to WARNING on demo day
|
| 70 |
+
file: logs/medirag.log
|
| 71 |
+
format: "%(asctime)s [%(levelname)s] %(name)s: %(message)s"
|
conftest.py
ADDED
|
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
conftest.py — project root
|
| 3 |
+
Ensures src/ is on the Python path so all test files can import from src.*
|
| 4 |
+
without needing PYTHONPATH to be set manually. (SRS Section 17)
|
| 5 |
+
"""
|
| 6 |
+
import sys
|
| 7 |
+
import os
|
| 8 |
+
|
| 9 |
+
# Add the src/ directory to path so `from modules.faithfulness import ...` works
|
| 10 |
+
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "src"))
|
| 11 |
+
# Also add project root so `import src` works
|
| 12 |
+
sys.path.insert(0, os.path.dirname(__file__))
|
demo/.gitkeep
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
# placeholder — demo_fallback.json generated by scripts/warmup.py
|
pytest.ini
ADDED
|
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
[pytest]
|
| 2 |
+
testpaths = tests
|
| 3 |
+
python_files = test_*.py
|
| 4 |
+
python_classes = Test*
|
| 5 |
+
python_functions = test_*
|
| 6 |
+
addopts = -v --tb=short
|
requirements.txt
ADDED
|
@@ -0,0 +1,35 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
langchain==0.1.20
|
| 2 |
+
langchain-community==0.0.38
|
| 3 |
+
|
| 4 |
+
# FIX 1: faiss-cpu 1.7.4 doesn't exist on PyPI — 1.9.0+ has a compatible API
|
| 5 |
+
faiss-cpu>=1.9.0
|
| 6 |
+
|
| 7 |
+
# FIX 2: torch 2.2.0 has no Python 3.13 wheels — 2.5.0+ supports Python 3.13
|
| 8 |
+
torch>=2.5.0
|
| 9 |
+
|
| 10 |
+
# FIX 3: transformers 4.40.0 may have issues on Python 3.13 — use 4.44+
|
| 11 |
+
transformers>=4.44.0
|
| 12 |
+
sentence-transformers>=2.7.0
|
| 13 |
+
|
| 14 |
+
# scispacy + en_core_sci_lg: installed via conda, NOT here (see setup commands below)
|
| 15 |
+
# scispacy 0.5.4 pins scipy<1.11 which has no Python 3.12 pip wheels.
|
| 16 |
+
# Conda has pre-built scipy binaries — use: conda install -c conda-forge scispacy
|
| 17 |
+
|
| 18 |
+
ragas==0.1.9
|
| 19 |
+
fastapi==0.110.0
|
| 20 |
+
uvicorn==0.27.0
|
| 21 |
+
# streamlit>=1.35.0 # Removed - using React frontend instead
|
| 22 |
+
pyyaml==6.0.1
|
| 23 |
+
pydantic>=2.9.0 # 2.6.0 has broken pydantic.v1 on Python 3.12 (ForwardRef bug); fixed in 2.9+
|
| 24 |
+
datasets==2.18.0
|
| 25 |
+
pytest==8.1.0
|
| 26 |
+
httpx>=0.27.0,<0.28.0 # starlette 0.36.3 TestClient breaks with httpx 0.28+ (removed app= kwarg)
|
| 27 |
+
pandas>=2.2.0 # 2.2.0 has Python 3.12 wheels (no longer need 2.2.3+)
|
| 28 |
+
numpy>=1.26.4,<2 # langchain 0.1.20 requires numpy<2; use conda env for Python 3.12 (conda pre-builds numpy 1.x)
|
| 29 |
+
requests==2.31.0
|
| 30 |
+
google-genai>=1.0.0 # New Google GenAI SDK (replaces deprecated google-generativeai)
|
| 31 |
+
pysbd>=0.3.4 # sentence boundary detection (faithfulness module)
|
| 32 |
+
pymupdf>=1.24.0 # fitz: extracted text from PDF
|
| 33 |
+
python-docx>=1.1.0 # extracted text from DOCX
|
| 34 |
+
rank-bm25>=0.2.2 # keyword search for retriever
|
| 35 |
+
python-multipart>=0.0.12 # handle form data in FastAPI
|
requirements_hf.txt
ADDED
|
@@ -0,0 +1,46 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# MediRAG Backend - Hugging Face Spaces Requirements
|
| 2 |
+
# Optimized for faster builds - relaxed version constraints
|
| 3 |
+
|
| 4 |
+
# Core dependencies
|
| 5 |
+
langchain>=0.1.0
|
| 6 |
+
langchain-community>=0.0.30
|
| 7 |
+
|
| 8 |
+
# Vector search
|
| 9 |
+
faiss-cpu>=1.9.0
|
| 10 |
+
|
| 11 |
+
# ML/DL frameworks
|
| 12 |
+
torch>=2.0.0
|
| 13 |
+
transformers>=4.40.0
|
| 14 |
+
sentence-transformers>=2.5.0
|
| 15 |
+
|
| 16 |
+
# Medical NLP - installed in Dockerfile instead
|
| 17 |
+
# scispacy>=0.5.4
|
| 18 |
+
# https://s3-us-west-2.amazonaws.com/ai2-s2-scispacy/releases/v0.5.4/en_core_sci_lg-0.5.4.tar.gz
|
| 19 |
+
# https://s3-us-west-2.amazonaws.com/ai2-s2-scispacy/releases/v0.5.4/en_ner_bc5cdr_md-0.5.4.tar.gz
|
| 20 |
+
|
| 21 |
+
# Evaluation
|
| 22 |
+
ragas>=0.1.0
|
| 23 |
+
|
| 24 |
+
# API framework
|
| 25 |
+
fastapi>=0.110.0
|
| 26 |
+
uvicorn>=0.27.0
|
| 27 |
+
|
| 28 |
+
# Hugging Face Spaces - Gradio for API wrapper
|
| 29 |
+
gradio>=4.0.0,<5.0.0
|
| 30 |
+
|
| 31 |
+
# Utilities
|
| 32 |
+
pyyaml>=6.0.0
|
| 33 |
+
pydantic>=2.0.0
|
| 34 |
+
datasets>=2.18.0
|
| 35 |
+
pandas>=2.0.0
|
| 36 |
+
numpy>=1.26.0,<2
|
| 37 |
+
requests>=2.30.0
|
| 38 |
+
google-genai>=0.5.0
|
| 39 |
+
pysbd>=0.3.0
|
| 40 |
+
pymupdf>=1.24.0
|
| 41 |
+
python-docx>=1.1.0
|
| 42 |
+
rank-bm25>=0.2.0
|
| 43 |
+
python-multipart>=0.0.12
|
| 44 |
+
|
| 45 |
+
# Additional for Hugging Face
|
| 46 |
+
huggingface-hub>=0.20.0
|
requirements_minimal.txt
ADDED
|
@@ -0,0 +1,43 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# MediRAG Backend - FastAPI only (no Gradio)
|
| 2 |
+
# React frontend on Vercel, this is just the API backend
|
| 3 |
+
|
| 4 |
+
# Core API
|
| 5 |
+
fastapi>=0.110.0
|
| 6 |
+
uvicorn>=0.27.0
|
| 7 |
+
python-multipart>=0.0.12
|
| 8 |
+
|
| 9 |
+
# Data handling
|
| 10 |
+
pydantic>=2.0.0
|
| 11 |
+
pyyaml>=6.0.0
|
| 12 |
+
numpy>=1.26.0,<2
|
| 13 |
+
pandas>=2.0.0
|
| 14 |
+
requests>=2.30.0
|
| 15 |
+
|
| 16 |
+
# Essential ML only
|
| 17 |
+
torch --index-url https://download.pytorch.org/whl/cpu
|
| 18 |
+
transformers>=4.40.0
|
| 19 |
+
sentence-transformers>=2.5.0
|
| 20 |
+
faiss-cpu>=1.9.0
|
| 21 |
+
|
| 22 |
+
# LLM integrations
|
| 23 |
+
langchain>=0.1.0
|
| 24 |
+
langchain-community>=0.0.30
|
| 25 |
+
google-genai>=0.5.0
|
| 26 |
+
ragas>=0.1.0
|
| 27 |
+
|
| 28 |
+
# Hugging Face Hub (for fetching FAISS index at runtime)
|
| 29 |
+
huggingface-hub>=0.20.0
|
| 30 |
+
datasets>=2.18.0
|
| 31 |
+
|
| 32 |
+
# File parsing (PDF, DOCX)
|
| 33 |
+
pymupdf>=1.24.0
|
| 34 |
+
python-docx>=1.1.0
|
| 35 |
+
|
| 36 |
+
# Medical NLP
|
| 37 |
+
spacy>=3.7.0
|
| 38 |
+
scispacy>=0.5.4
|
| 39 |
+
https://s3-us-west-2.amazonaws.com/ai2-s2-scispacy/releases/v0.5.4/en_core_sci_lg-0.5.4.tar.gz
|
| 40 |
+
https://s3-us-west-2.amazonaws.com/ai2-s2-scispacy/releases/v0.5.4/en_ner_bc5cdr_md-0.5.4.tar.gz
|
| 41 |
+
pysbd>=0.3.0
|
| 42 |
+
rank-bm25>=0.2.0
|
| 43 |
+
|
scripts/build_rxnorm_cache.py
ADDED
|
@@ -0,0 +1,347 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
FR-20: build_rxnorm_cache.py — Offline Drug Name Normalisation Cache Builder
|
| 3 |
+
=============================================================================
|
| 4 |
+
Accepts EITHER:
|
| 5 |
+
A) DrugBank vocabulary CSV (--drugbank-csv) ← recommended, immediate
|
| 6 |
+
B) DrugBank Open Data XML (--drugbank-xml) ← requires registration at drugbank.com
|
| 7 |
+
|
| 8 |
+
DrugBank vocabulary CSV is freely downloadable (no account needed) from:
|
| 9 |
+
https://go.drugbank.com/releases/latest#open-data → "DrugBank Vocabulary"
|
| 10 |
+
|
| 11 |
+
Queries RxNorm REST API (single approximateTerm call per drug) and saves
|
| 12 |
+
results to data/rxnorm_cache.csv.
|
| 13 |
+
|
| 14 |
+
Runtime:
|
| 15 |
+
~14,000 names × 0.1s delay × 1 API call ≈ 24 minutes
|
| 16 |
+
|
| 17 |
+
Usage:
|
| 18 |
+
python scripts/build_rxnorm_cache.py --drugbank-csv "data/drugbank vocabulary.csv"
|
| 19 |
+
python scripts/build_rxnorm_cache.py --drugbank-csv "data/drugbank vocabulary.csv" --dry-run 50
|
| 20 |
+
python scripts/build_rxnorm_cache.py --drugbank-xml data/raw/drugbank_open_data.xml
|
| 21 |
+
"""
|
| 22 |
+
from __future__ import annotations
|
| 23 |
+
|
| 24 |
+
import argparse
|
| 25 |
+
import csv
|
| 26 |
+
import logging
|
| 27 |
+
import sys
|
| 28 |
+
import time
|
| 29 |
+
import xml.etree.ElementTree as ET
|
| 30 |
+
from pathlib import Path
|
| 31 |
+
|
| 32 |
+
import requests
|
| 33 |
+
|
| 34 |
+
logging.basicConfig(
|
| 35 |
+
level=logging.INFO,
|
| 36 |
+
format="%(asctime)s [%(levelname)s] %(message)s",
|
| 37 |
+
)
|
| 38 |
+
logger = logging.getLogger("build_rxnorm_cache")
|
| 39 |
+
|
| 40 |
+
# RxNorm approximateTerm endpoint — returns rxcui + name in ONE call (v1.4 fix)
|
| 41 |
+
RXNORM_APPROX_URL = "https://rxnav.nlm.nih.gov/REST/approximateTerm.json"
|
| 42 |
+
|
| 43 |
+
# DrugBank Open Data XML namespace (XML path only)
|
| 44 |
+
NS = {"db": "http://www.drugbank.ca"}
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
# ---------------------------------------------------------------------------
|
| 48 |
+
# Step 1A: Extract drug names from DrugBank Vocabulary CSV ← preferred
|
| 49 |
+
# ---------------------------------------------------------------------------
|
| 50 |
+
|
| 51 |
+
def extract_drug_names_from_csv(csv_path: str) -> list[str]:
|
| 52 |
+
"""
|
| 53 |
+
Parse the DrugBank vocabulary CSV and return all drug name strings.
|
| 54 |
+
|
| 55 |
+
CSV columns: DrugBank ID | Accession Numbers | Common name | CAS | UNII
|
| 56 |
+
| Synonyms | Standard InChI Key
|
| 57 |
+
|
| 58 |
+
Synonyms column is pipe-separated (e.g. "Drug A | Alias B | Trade Name C").
|
| 59 |
+
|
| 60 |
+
Args:
|
| 61 |
+
csv_path : path to the DrugBank vocabulary CSV file
|
| 62 |
+
|
| 63 |
+
Returns:
|
| 64 |
+
Sorted deduplicated list of drug name strings.
|
| 65 |
+
"""
|
| 66 |
+
path = Path(csv_path)
|
| 67 |
+
if not path.exists():
|
| 68 |
+
logger.error(
|
| 69 |
+
"DrugBank vocabulary CSV not found at '%s'. "
|
| 70 |
+
"Download it from https://go.drugbank.com/releases/latest#open-data "
|
| 71 |
+
"(look for 'DrugBank Vocabulary' — no account needed).",
|
| 72 |
+
csv_path,
|
| 73 |
+
)
|
| 74 |
+
sys.exit(1)
|
| 75 |
+
|
| 76 |
+
logger.info("Parsing DrugBank vocabulary CSV: %s", path)
|
| 77 |
+
names: set[str] = set()
|
| 78 |
+
|
| 79 |
+
with open(path, "r", encoding="utf-8") as f:
|
| 80 |
+
reader = csv.DictReader(f)
|
| 81 |
+
for row in reader:
|
| 82 |
+
# Common name
|
| 83 |
+
common = row.get("Common name", "").strip()
|
| 84 |
+
if common:
|
| 85 |
+
names.add(common)
|
| 86 |
+
|
| 87 |
+
# Pipe-separated synonyms
|
| 88 |
+
synonyms_raw = row.get("Synonyms", "")
|
| 89 |
+
if synonyms_raw:
|
| 90 |
+
for syn in synonyms_raw.split("|"):
|
| 91 |
+
syn = syn.strip()
|
| 92 |
+
if syn:
|
| 93 |
+
names.add(syn)
|
| 94 |
+
|
| 95 |
+
result = sorted(names)
|
| 96 |
+
logger.info("Extracted %d unique drug names/synonyms from CSV", len(result))
|
| 97 |
+
return result
|
| 98 |
+
|
| 99 |
+
|
| 100 |
+
# ---------------------------------------------------------------------------
|
| 101 |
+
# Step 1B: Extract drug names from DrugBank Open Data XML ← needs account
|
| 102 |
+
# ---------------------------------------------------------------------------
|
| 103 |
+
|
| 104 |
+
def extract_drug_names_from_xml(xml_path: str) -> list[str]:
|
| 105 |
+
"""
|
| 106 |
+
Parse DrugBank Open Data XML and extract all drug names + synonyms.
|
| 107 |
+
|
| 108 |
+
Args:
|
| 109 |
+
xml_path : Path to drugbank_open_data.xml
|
| 110 |
+
|
| 111 |
+
Returns:
|
| 112 |
+
Sorted deduplicated list of drug name strings.
|
| 113 |
+
"""
|
| 114 |
+
logger.info("Parsing DrugBank XML: %s", xml_path)
|
| 115 |
+
try:
|
| 116 |
+
tree = ET.parse(xml_path)
|
| 117 |
+
except FileNotFoundError:
|
| 118 |
+
logger.error(
|
| 119 |
+
"DrugBank XML not found at '%s'. "
|
| 120 |
+
"Download it from https://go.drugbank.com/releases/latest#open-data "
|
| 121 |
+
"(free academic registration required), or use --drugbank-csv instead.",
|
| 122 |
+
xml_path,
|
| 123 |
+
)
|
| 124 |
+
sys.exit(1)
|
| 125 |
+
except ET.ParseError as exc:
|
| 126 |
+
logger.error("Failed to parse DrugBank XML: %s", exc)
|
| 127 |
+
sys.exit(1)
|
| 128 |
+
|
| 129 |
+
root = tree.getroot()
|
| 130 |
+
names: set[str] = set()
|
| 131 |
+
|
| 132 |
+
for drug in root.findall("db:drug", NS):
|
| 133 |
+
name_el = drug.find("db:name", NS)
|
| 134 |
+
if name_el is not None and name_el.text:
|
| 135 |
+
names.add(name_el.text.strip())
|
| 136 |
+
for syn in drug.findall("db:synonyms/db:synonym", NS):
|
| 137 |
+
if syn.text:
|
| 138 |
+
names.add(syn.text.strip())
|
| 139 |
+
for brand in drug.findall(
|
| 140 |
+
"db:international-brands/db:international-brand/db:name", NS
|
| 141 |
+
):
|
| 142 |
+
if brand.text:
|
| 143 |
+
names.add(brand.text.strip())
|
| 144 |
+
|
| 145 |
+
result = sorted(names)
|
| 146 |
+
logger.info("Extracted %d unique drug names/synonyms from XML", len(result))
|
| 147 |
+
return result
|
| 148 |
+
|
| 149 |
+
|
| 150 |
+
|
| 151 |
+
# ---------------------------------------------------------------------------
|
| 152 |
+
# Step 2: Query RxNorm (single API call per drug — v1.4)
|
| 153 |
+
# ---------------------------------------------------------------------------
|
| 154 |
+
|
| 155 |
+
def query_rxnorm(drug_name: str, timeout: int = 5) -> tuple[str, str]:
|
| 156 |
+
"""
|
| 157 |
+
Look up a drug name in RxNorm using approximateTerm endpoint.
|
| 158 |
+
Returns (rxcui, canonical_name). Returns ("", "") on any failure.
|
| 159 |
+
|
| 160 |
+
Uses /approximateTerm — single HTTP call returning both rxcui and name.
|
| 161 |
+
(Previous 2-call approach was replaced in v1.4, cutting runtime by ~50%.)
|
| 162 |
+
"""
|
| 163 |
+
try:
|
| 164 |
+
resp = requests.get(
|
| 165 |
+
RXNORM_APPROX_URL,
|
| 166 |
+
params={"term": drug_name, "maxEntries": "1", "option": "1"},
|
| 167 |
+
timeout=timeout,
|
| 168 |
+
)
|
| 169 |
+
if resp.status_code != 200:
|
| 170 |
+
return "", ""
|
| 171 |
+
|
| 172 |
+
candidates: list[dict] = (
|
| 173 |
+
resp.json()
|
| 174 |
+
.get("approximateGroup", {})
|
| 175 |
+
.get("candidate", [])
|
| 176 |
+
)
|
| 177 |
+
if not candidates:
|
| 178 |
+
return "", ""
|
| 179 |
+
|
| 180 |
+
rxcui = candidates[0].get("rxcui", "")
|
| 181 |
+
name = candidates[0].get("name", drug_name) # fallback to input
|
| 182 |
+
return rxcui, name
|
| 183 |
+
|
| 184 |
+
except Exception:
|
| 185 |
+
return "", ""
|
| 186 |
+
|
| 187 |
+
|
| 188 |
+
# ---------------------------------------------------------------------------
|
| 189 |
+
# Main
|
| 190 |
+
# ---------------------------------------------------------------------------
|
| 191 |
+
|
| 192 |
+
def main() -> None:
|
| 193 |
+
parser = argparse.ArgumentParser(
|
| 194 |
+
description="Build offline RxNorm cache from DrugBank data (FR-20)"
|
| 195 |
+
)
|
| 196 |
+
source = parser.add_mutually_exclusive_group()
|
| 197 |
+
source.add_argument(
|
| 198 |
+
"--drugbank-csv",
|
| 199 |
+
metavar="PATH",
|
| 200 |
+
default=None,
|
| 201 |
+
help=(
|
| 202 |
+
"Path to DrugBank vocabulary CSV [RECOMMENDED — no account needed]. "
|
| 203 |
+
"Download from https://go.drugbank.com/releases/latest#open-data"
|
| 204 |
+
),
|
| 205 |
+
)
|
| 206 |
+
source.add_argument(
|
| 207 |
+
"--drugbank-xml",
|
| 208 |
+
metavar="PATH",
|
| 209 |
+
default=None,
|
| 210 |
+
help="Path to DrugBank Open Data XML (requires free academic registration).",
|
| 211 |
+
)
|
| 212 |
+
parser.add_argument(
|
| 213 |
+
"--output-csv",
|
| 214 |
+
default="data/rxnorm_cache.csv",
|
| 215 |
+
help="Path for output CSV",
|
| 216 |
+
)
|
| 217 |
+
parser.add_argument(
|
| 218 |
+
"--delay",
|
| 219 |
+
type=float,
|
| 220 |
+
default=0.1,
|
| 221 |
+
help="Seconds to wait between API calls (default 0.1 — ~24 min total)",
|
| 222 |
+
)
|
| 223 |
+
parser.add_argument(
|
| 224 |
+
"--dry-run",
|
| 225 |
+
type=int,
|
| 226 |
+
default=0,
|
| 227 |
+
metavar="N",
|
| 228 |
+
help="Only process first N drug names (for testing)",
|
| 229 |
+
)
|
| 230 |
+
parser.add_argument(
|
| 231 |
+
"--resume",
|
| 232 |
+
action="store_true",
|
| 233 |
+
help=(
|
| 234 |
+
"Resume a previously interrupted run. Reads already-completed entries "
|
| 235 |
+
"from --output-csv and skips them, appending only the missing ones."
|
| 236 |
+
),
|
| 237 |
+
)
|
| 238 |
+
args = parser.parse_args()
|
| 239 |
+
|
| 240 |
+
# ------------------------------------------------------------------
|
| 241 |
+
# Auto-detect source if neither flag was given
|
| 242 |
+
# ------------------------------------------------------------------
|
| 243 |
+
csv_default = "data/drugbank vocabulary.csv"
|
| 244 |
+
xml_default = "data/raw/drugbank_open_data.xml"
|
| 245 |
+
|
| 246 |
+
if args.drugbank_csv:
|
| 247 |
+
drug_names = extract_drug_names_from_csv(args.drugbank_csv)
|
| 248 |
+
elif args.drugbank_xml:
|
| 249 |
+
drug_names = extract_drug_names_from_xml(args.drugbank_xml)
|
| 250 |
+
elif Path(csv_default).exists():
|
| 251 |
+
logger.info("Auto-detected DrugBank vocabulary CSV at '%s'", csv_default)
|
| 252 |
+
drug_names = extract_drug_names_from_csv(csv_default)
|
| 253 |
+
elif Path(xml_default).exists():
|
| 254 |
+
logger.info("Auto-detected DrugBank XML at '%s'", xml_default)
|
| 255 |
+
drug_names = extract_drug_names_from_xml(xml_default)
|
| 256 |
+
else:
|
| 257 |
+
logger.error(
|
| 258 |
+
"No DrugBank source found. Pass --drugbank-csv or --drugbank-xml. "
|
| 259 |
+
"See script docstring for download links."
|
| 260 |
+
)
|
| 261 |
+
sys.exit(1)
|
| 262 |
+
|
| 263 |
+
if args.dry_run > 0:
|
| 264 |
+
drug_names = drug_names[: args.dry_run]
|
| 265 |
+
logger.info("Dry-run mode: processing %d names only", len(drug_names))
|
| 266 |
+
|
| 267 |
+
# ------------------------------------------------------------------
|
| 268 |
+
# Resume: skip names already in the output CSV
|
| 269 |
+
# ------------------------------------------------------------------
|
| 270 |
+
out_path = Path(args.output_csv)
|
| 271 |
+
out_path.parent.mkdir(parents=True, exist_ok=True)
|
| 272 |
+
|
| 273 |
+
already_done: set[str] = set()
|
| 274 |
+
if args.resume and out_path.exists():
|
| 275 |
+
try:
|
| 276 |
+
with open(out_path, "r", encoding="utf-8") as f:
|
| 277 |
+
reader = csv.DictReader(f)
|
| 278 |
+
for row in reader:
|
| 279 |
+
name = row.get("drug_name", "").strip()
|
| 280 |
+
if name:
|
| 281 |
+
already_done.add(name)
|
| 282 |
+
logger.info(
|
| 283 |
+
"Resume mode: %d entries already in cache — skipping these.",
|
| 284 |
+
len(already_done),
|
| 285 |
+
)
|
| 286 |
+
except Exception as exc:
|
| 287 |
+
logger.warning("Could not read existing cache for resume: %s", exc)
|
| 288 |
+
already_done = set()
|
| 289 |
+
|
| 290 |
+
remaining = [n for n in drug_names if n not in already_done]
|
| 291 |
+
skipped = len(drug_names) - len(remaining)
|
| 292 |
+
if skipped:
|
| 293 |
+
logger.info("Skipping %d already-resolved names. %d remaining.", skipped, len(remaining))
|
| 294 |
+
|
| 295 |
+
total = len(remaining)
|
| 296 |
+
if total == 0:
|
| 297 |
+
logger.info("Nothing to do — cache is already complete.")
|
| 298 |
+
sys.exit(0)
|
| 299 |
+
|
| 300 |
+
est_minutes = total * (args.delay + 0.05) / 60
|
| 301 |
+
logger.info(
|
| 302 |
+
"Starting cache build: %d names to process, delay=%.2fs, estimated %.0f minutes",
|
| 303 |
+
total, args.delay, est_minutes,
|
| 304 |
+
)
|
| 305 |
+
|
| 306 |
+
# ------------------------------------------------------------------
|
| 307 |
+
# Write CSV — append if resuming, overwrite otherwise
|
| 308 |
+
# ------------------------------------------------------------------
|
| 309 |
+
file_mode = "a" if args.resume and out_path.exists() and already_done else "w"
|
| 310 |
+
write_header = file_mode == "w"
|
| 311 |
+
|
| 312 |
+
found = len(already_done) # count previously resolved entries too
|
| 313 |
+
new_found = 0
|
| 314 |
+
|
| 315 |
+
with open(out_path, file_mode, newline="", encoding="utf-8") as f:
|
| 316 |
+
writer = csv.writer(f)
|
| 317 |
+
if write_header:
|
| 318 |
+
writer.writerow(["drug_name", "rxcui", "canonical_name"])
|
| 319 |
+
|
| 320 |
+
for i, name in enumerate(remaining):
|
| 321 |
+
rxcui, canonical = query_rxnorm(name)
|
| 322 |
+
writer.writerow([name, rxcui, canonical])
|
| 323 |
+
if rxcui:
|
| 324 |
+
new_found += 1
|
| 325 |
+
found += 1
|
| 326 |
+
|
| 327 |
+
if i % 25 == 0 or i == total - 1:
|
| 328 |
+
pct = 100 * (i + 1) / total
|
| 329 |
+
logger.info(
|
| 330 |
+
"Progress: %d/%d (%.1f%%) — %d resolved this run (%d total)",
|
| 331 |
+
i + 1, total, pct, new_found, found,
|
| 332 |
+
)
|
| 333 |
+
|
| 334 |
+
time.sleep(args.delay)
|
| 335 |
+
|
| 336 |
+
logger.info(
|
| 337 |
+
"Cache saved to %s — %d/%d names resolved to RxNorm IDs (this run: +%d)",
|
| 338 |
+
out_path, found, len(drug_names), new_found,
|
| 339 |
+
)
|
| 340 |
+
logger.info(
|
| 341 |
+
"Commit this file to the repo: git add %s && git commit -m 'Add RxNorm cache'",
|
| 342 |
+
out_path,
|
| 343 |
+
)
|
| 344 |
+
|
| 345 |
+
|
| 346 |
+
if __name__ == "__main__":
|
| 347 |
+
main()
|
scripts/debug_pmc.py
ADDED
|
@@ -0,0 +1,54 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import requests, re
|
| 2 |
+
from lxml import html
|
| 3 |
+
|
| 4 |
+
r = requests.get(
|
| 5 |
+
'https://www.ncbi.nlm.nih.gov/pmc/articles/PMC10725812/',
|
| 6 |
+
headers={'User-Agent': 'Mozilla/5.0'},
|
| 7 |
+
timeout=15
|
| 8 |
+
)
|
| 9 |
+
tree = html.fromstring(r.content)
|
| 10 |
+
|
| 11 |
+
# Find main article body — skip nav/header
|
| 12 |
+
article = tree.xpath('//article') or tree.xpath('//*[@role="main"]') or tree.xpath('//div[@class="article"]')
|
| 13 |
+
root = article[0] if article else tree
|
| 14 |
+
print('Using root:', root.tag, root.get('class','')[:40])
|
| 15 |
+
|
| 16 |
+
# Find all sections with their h2/h3 and paragraphs
|
| 17 |
+
sections = root.xpath('.//section')
|
| 18 |
+
print(f'\nTotal sections: {len(sections)}')
|
| 19 |
+
|
| 20 |
+
# Show first Recommendations section content
|
| 21 |
+
for sec in sections:
|
| 22 |
+
h3 = sec.xpath('.//h3')
|
| 23 |
+
if h3 and 'Recommendation' in h3[0].text_content():
|
| 24 |
+
print('\n--- RECOMMENDATIONS SECTION ---')
|
| 25 |
+
print('H3:', h3[0].text_content().strip())
|
| 26 |
+
# Get all list items and paragraphs in this section
|
| 27 |
+
items = sec.xpath('.//li | .//p')
|
| 28 |
+
for item in items[:8]:
|
| 29 |
+
t = item.text_content().strip()
|
| 30 |
+
if t and len(t) > 20:
|
| 31 |
+
print(' TEXT:', t[:200])
|
| 32 |
+
break
|
| 33 |
+
|
| 34 |
+
# Check how rec numbers look — find paragraphs starting with N.N pattern
|
| 35 |
+
all_p = root.xpath('.//p')
|
| 36 |
+
print('\n--- PARAGRAPHS WITH REC NUMBERS ---')
|
| 37 |
+
rec_re = re.compile(r'^\s*\d+\.\d+[a-z]?\s+\w')
|
| 38 |
+
count = 0
|
| 39 |
+
for p in all_p:
|
| 40 |
+
t = p.text_content().strip()
|
| 41 |
+
if rec_re.match(t):
|
| 42 |
+
print(' REC:', t[:200])
|
| 43 |
+
count += 1
|
| 44 |
+
if count >= 5:
|
| 45 |
+
break
|
| 46 |
+
|
| 47 |
+
# Show structure of first H2 section
|
| 48 |
+
print('\n--- FIRST H2 SECTION STRUCTURE ---')
|
| 49 |
+
h2_secs = root.xpath('.//section[.//h2]')
|
| 50 |
+
if h2_secs:
|
| 51 |
+
sec = h2_secs[0]
|
| 52 |
+
print('H2:', sec.xpath('.//h2')[0].text_content().strip()[:60])
|
| 53 |
+
children = list(sec)
|
| 54 |
+
print('Direct children tags:', [c.tag for c in children[:10]])
|
scripts/download_dailymed.py
ADDED
|
@@ -0,0 +1,259 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
scripts/download_dailymed.py
|
| 3 |
+
============================
|
| 4 |
+
Downloads FDA DailyMed drug labels for common clinical drugs via the
|
| 5 |
+
DailyMed API and saves them as chunks.jsonl ready for ingestion into
|
| 6 |
+
the MediRAG FAISS index.
|
| 7 |
+
|
| 8 |
+
Sections extracted per drug:
|
| 9 |
+
- DOSAGE AND ADMINISTRATION
|
| 10 |
+
- CONTRAINDICATIONS
|
| 11 |
+
- WARNINGS AND PRECAUTIONS
|
| 12 |
+
- INDICATIONS AND USAGE
|
| 13 |
+
- DRUG INTERACTIONS
|
| 14 |
+
|
| 15 |
+
Usage:
|
| 16 |
+
python scripts/download_dailymed.py
|
| 17 |
+
python scripts/download_dailymed.py --drugs metformin aspirin warfarin
|
| 18 |
+
python scripts/download_dailymed.py --output data/dailymed_chunks.jsonl
|
| 19 |
+
"""
|
| 20 |
+
from __future__ import annotations
|
| 21 |
+
|
| 22 |
+
import argparse
|
| 23 |
+
import json
|
| 24 |
+
import logging
|
| 25 |
+
import time
|
| 26 |
+
import xml.etree.ElementTree as ET
|
| 27 |
+
from pathlib import Path
|
| 28 |
+
|
| 29 |
+
import requests
|
| 30 |
+
|
| 31 |
+
logging.basicConfig(level=logging.INFO, format="%(asctime)s [%(levelname)s] %(message)s")
|
| 32 |
+
logger = logging.getLogger(__name__)
|
| 33 |
+
|
| 34 |
+
# ---------------------------------------------------------------------------
|
| 35 |
+
# Top 200 common clinical drugs (priority list)
|
| 36 |
+
# ---------------------------------------------------------------------------
|
| 37 |
+
TOP_DRUGS = [
|
| 38 |
+
"metformin", "atorvastatin", "lisinopril", "levothyroxine", "amlodipine",
|
| 39 |
+
"omeprazole", "metoprolol", "albuterol", "losartan", "gabapentin",
|
| 40 |
+
"sertraline", "simvastatin", "montelukast", "pantoprazole", "alprazolam",
|
| 41 |
+
"furosemide", "escitalopram", "rosuvastatin", "acetaminophen", "ibuprofen",
|
| 42 |
+
"amoxicillin", "azithromycin", "doxycycline", "prednisone", "warfarin",
|
| 43 |
+
"clopidogrel", "aspirin", "tamsulosin", "insulin glargine", "glipizide",
|
| 44 |
+
"hydrochlorothiazide", "amlodipine", "venlafaxine", "bupropion", "duloxetine",
|
| 45 |
+
"clonazepam", "lorazepam", "zolpidem", "quetiapine", "aripiprazole",
|
| 46 |
+
"olanzapine", "risperidone", "fluoxetine", "paroxetine", "citalopram",
|
| 47 |
+
"tramadol", "oxycodone", "morphine", "fentanyl", "naloxone",
|
| 48 |
+
"ciprofloxacin", "levofloxacin", "clindamycin", "metronidazole", "trimethoprim",
|
| 49 |
+
"enalapril", "ramipril", "carvedilol", "bisoprolol", "digoxin",
|
| 50 |
+
"spironolactone", "diltiazem", "verapamil", "nifedipine", "hydralazine",
|
| 51 |
+
"nitroglycerin", "isosorbide", "clopidogrel", "apixaban", "rivaroxaban",
|
| 52 |
+
"dabigatran", "heparin", "enoxaparin", "atorvastatin", "pravastatin",
|
| 53 |
+
"ezetimibe", "fenofibrate", "niacin", "gemfibrozil", "cholestyramine",
|
| 54 |
+
"allopurinol", "colchicine", "indomethacin", "naproxen", "celecoxib",
|
| 55 |
+
"hydroxychloroquine", "methotrexate", "leflunomide", "sulfasalazine",
|
| 56 |
+
"prednisolone", "dexamethasone", "budesonide", "fluticasone", "beclomethasone",
|
| 57 |
+
"ipratropium", "tiotropium", "salmeterol", "formoterol", "theophylline",
|
| 58 |
+
"insulin aspart", "insulin lispro", "sitagliptin", "saxagliptin", "empagliflozin",
|
| 59 |
+
"canagliflozin", "dapagliflozin", "liraglutide", "exenatide", "pioglitazone",
|
| 60 |
+
"acarbose", "repaglinide", "nateglinide", "glimepiride", "glyburide",
|
| 61 |
+
"levothyroxine", "methimazole", "propylthiouracil", "calcitonin", "alendronate",
|
| 62 |
+
"risedronate", "ibandronate", "denosumab", "teriparatide", "raloxifene",
|
| 63 |
+
"tamoxifen", "letrozole", "anastrozole", "exemestane", "fulvestrant",
|
| 64 |
+
"rituximab", "trastuzumab", "bevacizumab", "imatinib", "erlotinib",
|
| 65 |
+
"ondansetron", "metoclopramide", "promethazine", "prochlorperazine",
|
| 66 |
+
"loperamide", "bismuth subsalicylate", "lactulose", "polyethylene glycol",
|
| 67 |
+
"docusate", "senna", "mesalamine", "sulfasalazine", "infliximab",
|
| 68 |
+
"adalimumab", "etanercept", "ustekinumab", "secukinumab",
|
| 69 |
+
"acyclovir", "valacyclovir", "oseltamivir", "ribavirin", "sofosbuvir",
|
| 70 |
+
"fluconazole", "itraconazole", "voriconazole", "amphotericin b",
|
| 71 |
+
"vancomycin", "linezolid", "daptomycin", "meropenem", "piperacillin",
|
| 72 |
+
"phenytoin", "valproic acid", "carbamazepine", "levetiracetam", "lamotrigine",
|
| 73 |
+
"topiramate", "oxcarbazepine", "lacosamide", "brivaracetam",
|
| 74 |
+
"donepezil", "memantine", "rivastigmine", "galantamine",
|
| 75 |
+
"carbidopa levodopa", "pramipexole", "ropinirole", "rasagiline", "selegiline",
|
| 76 |
+
"baclofen", "tizanidine", "cyclobenzaprine", "methocarbamol",
|
| 77 |
+
"sildenafil", "tadalafil", "vardenafil", "finasteride", "dutasteride",
|
| 78 |
+
"testosterone", "estradiol", "progesterone", "medroxyprogesterone",
|
| 79 |
+
"methylphenidate", "amphetamine", "atomoxetine", "guanfacine", "clonidine",
|
| 80 |
+
]
|
| 81 |
+
|
| 82 |
+
# DailyMed sections we care about (LOINC codes)
|
| 83 |
+
SECTION_CODES = {
|
| 84 |
+
"34068-7": "DOSAGE AND ADMINISTRATION",
|
| 85 |
+
"34070-3": "CONTRAINDICATIONS",
|
| 86 |
+
"43685-7": "WARNINGS AND PRECAUTIONS",
|
| 87 |
+
"34067-9": "INDICATIONS AND USAGE",
|
| 88 |
+
"34073-7": "DRUG INTERACTIONS",
|
| 89 |
+
"34071-1": "WARNINGS",
|
| 90 |
+
"34084-4": "ADVERSE REACTIONS",
|
| 91 |
+
"34088-5": "OVERDOSAGE",
|
| 92 |
+
"34080-2": "USE IN SPECIFIC POPULATIONS",
|
| 93 |
+
}
|
| 94 |
+
|
| 95 |
+
DAILYMED_API = "https://dailymed.nlm.nih.gov/dailymed/services/v2"
|
| 96 |
+
|
| 97 |
+
|
| 98 |
+
def search_drug(drug_name: str) -> str | None:
|
| 99 |
+
"""Return the SPL set_id for the first matching drug label."""
|
| 100 |
+
try:
|
| 101 |
+
r = requests.get(
|
| 102 |
+
f"{DAILYMED_API}/spls.json",
|
| 103 |
+
params={"drug_name": drug_name, "pagesize": 1},
|
| 104 |
+
timeout=10,
|
| 105 |
+
)
|
| 106 |
+
r.raise_for_status()
|
| 107 |
+
data = r.json()
|
| 108 |
+
results = data.get("data", [])
|
| 109 |
+
if results:
|
| 110 |
+
return results[0].get("setid")
|
| 111 |
+
except Exception as e:
|
| 112 |
+
logger.warning("Search failed for '%s': %s", drug_name, e)
|
| 113 |
+
return None
|
| 114 |
+
|
| 115 |
+
|
| 116 |
+
def fetch_label_xml(set_id: str) -> str | None:
|
| 117 |
+
"""Download the full SPL XML for a given set_id."""
|
| 118 |
+
try:
|
| 119 |
+
r = requests.get(
|
| 120 |
+
f"{DAILYMED_API}/spls/{set_id}.xml",
|
| 121 |
+
timeout=15,
|
| 122 |
+
)
|
| 123 |
+
r.raise_for_status()
|
| 124 |
+
return r.text
|
| 125 |
+
except Exception as e:
|
| 126 |
+
logger.warning("XML fetch failed for set_id '%s': %s", set_id, e)
|
| 127 |
+
return None
|
| 128 |
+
|
| 129 |
+
|
| 130 |
+
def extract_sections(xml_text: str, drug_name: str, set_id: str = "unknown") -> list[dict]:
|
| 131 |
+
"""Parse SPL XML and extract clinical sections as chunk dicts."""
|
| 132 |
+
chunks = []
|
| 133 |
+
try:
|
| 134 |
+
root = ET.fromstring(xml_text)
|
| 135 |
+
ns = {"hl7": "urn:hl7-org:v3"}
|
| 136 |
+
|
| 137 |
+
# Get brand/generic name from XML
|
| 138 |
+
title_el = root.find(".//hl7:title", ns)
|
| 139 |
+
label_title = title_el.text.strip() if title_el is not None and title_el.text else drug_name.title()
|
| 140 |
+
|
| 141 |
+
for section in root.findall(".//hl7:section", ns):
|
| 142 |
+
code_el = section.find("hl7:code", ns)
|
| 143 |
+
if code_el is None:
|
| 144 |
+
continue
|
| 145 |
+
code = code_el.get("code", "")
|
| 146 |
+
section_name = SECTION_CODES.get(code)
|
| 147 |
+
if not section_name:
|
| 148 |
+
continue
|
| 149 |
+
|
| 150 |
+
# Extract text — handle tables specially so row data isn't lost
|
| 151 |
+
texts = []
|
| 152 |
+
for el in section.iter("{urn:hl7-org:v3}text"):
|
| 153 |
+
# Extract tables as readable rows before falling back to itertext
|
| 154 |
+
for table in el.findall(".//{urn:hl7-org:v3}table"):
|
| 155 |
+
rows = []
|
| 156 |
+
for tr in table.iter("{urn:hl7-org:v3}tr"):
|
| 157 |
+
cells = [" ".join(td.itertext()).strip()
|
| 158 |
+
for td in tr.iter("{urn:hl7-org:v3}td")]
|
| 159 |
+
if not cells:
|
| 160 |
+
cells = [" ".join(th.itertext()).strip()
|
| 161 |
+
for th in tr.iter("{urn:hl7-org:v3}th")]
|
| 162 |
+
row = " | ".join(c for c in cells if c)
|
| 163 |
+
if row:
|
| 164 |
+
rows.append(row)
|
| 165 |
+
if rows:
|
| 166 |
+
texts.append(" ; ".join(rows))
|
| 167 |
+
# Remove table from tree to avoid double-counting via itertext
|
| 168 |
+
el.remove(table) if table in list(el) else None
|
| 169 |
+
|
| 170 |
+
# Non-table text
|
| 171 |
+
text = " ".join(el.itertext()).strip()
|
| 172 |
+
if text:
|
| 173 |
+
texts.append(text)
|
| 174 |
+
full_text = " ".join(texts).strip()
|
| 175 |
+
|
| 176 |
+
if len(full_text) < 50:
|
| 177 |
+
continue
|
| 178 |
+
|
| 179 |
+
# Truncate to 1500 chars per chunk (BioBERT max ~512 tokens)
|
| 180 |
+
for i in range(0, min(len(full_text), 6000), 1500):
|
| 181 |
+
segment = full_text[i:i+1500].strip()
|
| 182 |
+
if len(segment) < 50:
|
| 183 |
+
continue
|
| 184 |
+
chunk_id = f"fda_{drug_name.replace(' ', '_')}_{set_id}_{code}_{i}"
|
| 185 |
+
chunks.append({
|
| 186 |
+
"chunk_id": chunk_id,
|
| 187 |
+
"doc_id": f"fda_{drug_name.replace(' ', '_')}_{set_id}",
|
| 188 |
+
"chunk_text": f"[FDA DailyMed | {drug_name.title()} | {section_name}] {drug_name.title()} {section_name}: {segment}",
|
| 189 |
+
"chunk_index": i // 1500,
|
| 190 |
+
"total_chunks": max(1, min(4, len(full_text) // 1500 + 1)),
|
| 191 |
+
"pub_type": "drug_label",
|
| 192 |
+
"source": "FDA DailyMed",
|
| 193 |
+
"title": f"{label_title} — {section_name}",
|
| 194 |
+
"pub_year": 2024,
|
| 195 |
+
"journal": "FDA DailyMed",
|
| 196 |
+
"drug_name": drug_name,
|
| 197 |
+
"section": section_name,
|
| 198 |
+
})
|
| 199 |
+
except ET.ParseError as e:
|
| 200 |
+
logger.warning("XML parse error for '%s': %s", drug_name, e)
|
| 201 |
+
return chunks
|
| 202 |
+
|
| 203 |
+
|
| 204 |
+
def download_dailymed(drug_list: list[str], output_path: str) -> None:
|
| 205 |
+
out = Path(output_path)
|
| 206 |
+
out.parent.mkdir(parents=True, exist_ok=True)
|
| 207 |
+
|
| 208 |
+
total_chunks = 0
|
| 209 |
+
failed = []
|
| 210 |
+
|
| 211 |
+
with open(out, "w", encoding="utf-8") as f:
|
| 212 |
+
for i, drug in enumerate(drug_list):
|
| 213 |
+
logger.info("[%d/%d] Processing: %s", i + 1, len(drug_list), drug)
|
| 214 |
+
|
| 215 |
+
set_id = search_drug(drug)
|
| 216 |
+
if not set_id:
|
| 217 |
+
logger.warning(" No DailyMed entry found for '%s'", drug)
|
| 218 |
+
failed.append(drug)
|
| 219 |
+
time.sleep(0.3)
|
| 220 |
+
continue
|
| 221 |
+
|
| 222 |
+
xml_text = fetch_label_xml(set_id)
|
| 223 |
+
if not xml_text:
|
| 224 |
+
failed.append(drug)
|
| 225 |
+
time.sleep(0.3)
|
| 226 |
+
continue
|
| 227 |
+
|
| 228 |
+
chunks = extract_sections(xml_text, drug, set_id=set_id)
|
| 229 |
+
for chunk in chunks:
|
| 230 |
+
f.write(json.dumps(chunk) + "\n")
|
| 231 |
+
|
| 232 |
+
total_chunks += len(chunks)
|
| 233 |
+
logger.info(" → %d chunks extracted (set_id: %s)", len(chunks), set_id)
|
| 234 |
+
time.sleep(0.4) # Be polite to the API
|
| 235 |
+
|
| 236 |
+
logger.info("Done. %d total chunks written to %s", total_chunks, out)
|
| 237 |
+
if failed:
|
| 238 |
+
logger.warning("Failed drugs (%d): %s", len(failed), ", ".join(failed))
|
| 239 |
+
|
| 240 |
+
|
| 241 |
+
if __name__ == "__main__":
|
| 242 |
+
parser = argparse.ArgumentParser()
|
| 243 |
+
parser.add_argument("--drugs", nargs="*", default=None,
|
| 244 |
+
help="Specific drug names (default: full TOP_DRUGS list)")
|
| 245 |
+
parser.add_argument("--output", default="data/dailymed_chunks.jsonl",
|
| 246 |
+
help="Output JSONL path")
|
| 247 |
+
parser.add_argument("--limit", type=int, default=None,
|
| 248 |
+
help="Limit number of drugs to download")
|
| 249 |
+
args = parser.parse_args()
|
| 250 |
+
|
| 251 |
+
drug_list = args.drugs or TOP_DRUGS
|
| 252 |
+
# Deduplicate while preserving order
|
| 253 |
+
seen: set[str] = set()
|
| 254 |
+
drug_list = [d for d in drug_list if not (d in seen or seen.add(d))]
|
| 255 |
+
if args.limit:
|
| 256 |
+
drug_list = drug_list[:args.limit]
|
| 257 |
+
|
| 258 |
+
logger.info("Downloading DailyMed labels for %d drugs...", len(drug_list))
|
| 259 |
+
download_dailymed(drug_list, args.output)
|
scripts/download_guidelines.py
ADDED
|
@@ -0,0 +1,399 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
scripts/download_guidelines.py
|
| 3 |
+
================================
|
| 4 |
+
Downloads clinical guidelines from PubMed Central (PMC) open-access API
|
| 5 |
+
and chunks them for ingestion into the MediRAG FAISS index.
|
| 6 |
+
|
| 7 |
+
Sources:
|
| 8 |
+
- ADA Standards of Medical Care in Diabetes 2024 (16 sections via PMC)
|
| 9 |
+
- More guidelines can be added to GUIDELINE_SOURCES below
|
| 10 |
+
|
| 11 |
+
Chunking strategy (based on structural analysis):
|
| 12 |
+
- Primary boundary: H2 clinical topic + Recommendations block + evidence narrative
|
| 13 |
+
- Never split a Recommendations block
|
| 14 |
+
- Store evidence grades (A/B/C/E) and recommendation numbers as metadata
|
| 15 |
+
|
| 16 |
+
Usage:
|
| 17 |
+
python scripts/download_guidelines.py
|
| 18 |
+
python scripts/download_guidelines.py --source ada_diabetes
|
| 19 |
+
python scripts/download_guidelines.py --dry-run
|
| 20 |
+
"""
|
| 21 |
+
from __future__ import annotations
|
| 22 |
+
|
| 23 |
+
import argparse
|
| 24 |
+
import json
|
| 25 |
+
import logging
|
| 26 |
+
import re
|
| 27 |
+
import time
|
| 28 |
+
import uuid
|
| 29 |
+
from pathlib import Path
|
| 30 |
+
|
| 31 |
+
import requests
|
| 32 |
+
|
| 33 |
+
logging.basicConfig(level=logging.INFO, format="%(asctime)s [%(levelname)s] %(message)s")
|
| 34 |
+
logger = logging.getLogger(__name__)
|
| 35 |
+
|
| 36 |
+
# ---------------------------------------------------------------------------
|
| 37 |
+
# Guideline sources — PMC IDs for ADA 2024 Standards of Care
|
| 38 |
+
# ---------------------------------------------------------------------------
|
| 39 |
+
GUIDELINE_SOURCES = {
|
| 40 |
+
"ada_diabetes": {
|
| 41 |
+
"name": "ADA Standards of Medical Care in Diabetes 2024",
|
| 42 |
+
"key": "ada",
|
| 43 |
+
"pub_type": "clinical_guideline",
|
| 44 |
+
"source": "American Diabetes Association",
|
| 45 |
+
"pub_year": 2024,
|
| 46 |
+
"journal": "Diabetes Care",
|
| 47 |
+
"sections": [
|
| 48 |
+
{"pmcid": "PMC10725812", "section": "2", "title": "Diagnosis and Classification of Diabetes"},
|
| 49 |
+
{"pmcid": "PMC10725809", "section": "4", "title": "Comprehensive Medical Evaluation and Assessment of Comorbidities"},
|
| 50 |
+
{"pmcid": "PMC10725816", "section": "5", "title": "Facilitating Positive Health Behaviors and Well-being"},
|
| 51 |
+
{"pmcid": "PMC10725808", "section": "6", "title": "Glycemic Goals and Hypoglycemia"},
|
| 52 |
+
{"pmcid": "PMC10725813", "section": "7", "title": "Diabetes Technology"},
|
| 53 |
+
{"pmcid": "PMC10725806", "section": "8", "title": "Obesity and Weight Management for the Prevention and Treatment of Type 2 Diabetes"},
|
| 54 |
+
{"pmcid": "PMC10725810", "section": "9", "title": "Pharmacologic Approaches to Glycemic Treatment"},
|
| 55 |
+
{"pmcid": "PMC10725804", "section": "13", "title": "Older Adults"},
|
| 56 |
+
{"pmcid": "PMC10725814", "section": "14", "title": "Children and Adolescents"},
|
| 57 |
+
{"pmcid": "PMC10725801", "section": "15", "title": "Management of Diabetes in Pregnancy"},
|
| 58 |
+
{"pmcid": "PMC10725815", "section": "16", "title": "Diabetes Care in the Hospital"},
|
| 59 |
+
{"pmcid": "PMC10725798", "section": "1", "title": "Improving Care and Promoting Health in Populations"},
|
| 60 |
+
],
|
| 61 |
+
},
|
| 62 |
+
"acc_aha_cholesterol": {
|
| 63 |
+
"name": "2018 ACC/AHA Guideline on Management of Blood Cholesterol",
|
| 64 |
+
"key": "acc_aha_chol",
|
| 65 |
+
"pub_type": "clinical_guideline",
|
| 66 |
+
"source": "American College of Cardiology/American Heart Association",
|
| 67 |
+
"pub_year": 2018,
|
| 68 |
+
"journal": "Circulation",
|
| 69 |
+
"sections": [
|
| 70 |
+
# PMC7403606: Grundy et al. 2018 executive summary, freely accessible full text
|
| 71 |
+
{"pmcid": "PMC7403606", "section": "1", "title": "Management of Blood Cholesterol — Statin Therapy and LDL Targets"},
|
| 72 |
+
],
|
| 73 |
+
},
|
| 74 |
+
"acc_aha_prevention": {
|
| 75 |
+
"name": "2019 ACC/AHA Guideline on Primary Prevention of Cardiovascular Disease",
|
| 76 |
+
"key": "acc_aha_prev",
|
| 77 |
+
"pub_type": "clinical_guideline",
|
| 78 |
+
"source": "American College of Cardiology/American Heart Association",
|
| 79 |
+
"pub_year": 2019,
|
| 80 |
+
"journal": "Journal of the American College of Cardiology",
|
| 81 |
+
"sections": [
|
| 82 |
+
# PMC7685565: Arnett et al. 2019, full guideline open access
|
| 83 |
+
{"pmcid": "PMC7685565", "section": "1", "title": "Primary Prevention — Blood Pressure, Cholesterol, Aspirin, Lifestyle"},
|
| 84 |
+
],
|
| 85 |
+
},
|
| 86 |
+
}
|
| 87 |
+
|
| 88 |
+
PMC_API = "https://www.ncbi.nlm.nih.gov/research/bionlp/RESTful/pmcoa.cgi/BioC_json/{pmcid}/unicode"
|
| 89 |
+
PMC_EFETCH = "https://eutils.ncbi.nlm.nih.gov/entrez/eutils/efetch.fcgi"
|
| 90 |
+
|
| 91 |
+
# Evidence grade pattern: single letter A/B/C/E at end of recommendation
|
| 92 |
+
_GRADE_RE = re.compile(r'\b([ABCE])\s*$')
|
| 93 |
+
# Recommendation number pattern: e.g. "9.18", "2.1a", "6.5b"
|
| 94 |
+
_REC_NUM_RE = re.compile(r'^(\d+\.\d+[a-z]?)\s+')
|
| 95 |
+
|
| 96 |
+
|
| 97 |
+
PMC_HTML_URL = "https://www.ncbi.nlm.nih.gov/pmc/articles/{pmcid}/"
|
| 98 |
+
|
| 99 |
+
|
| 100 |
+
def fetch_pmc_xml(pmcid: str) -> str | None:
|
| 101 |
+
"""Fetch PMC article HTML page and extract clean structured text."""
|
| 102 |
+
try:
|
| 103 |
+
from lxml import html as lxml_html
|
| 104 |
+
url = PMC_HTML_URL.format(pmcid=pmcid)
|
| 105 |
+
r = requests.get(url, headers={"User-Agent": "Mozilla/5.0"}, timeout=30)
|
| 106 |
+
r.raise_for_status()
|
| 107 |
+
return _extract_pmc_html_text(lxml_html.fromstring(r.content))
|
| 108 |
+
except Exception as e:
|
| 109 |
+
logger.warning("PMC HTML fetch failed for %s: %s", pmcid, e)
|
| 110 |
+
return None
|
| 111 |
+
|
| 112 |
+
|
| 113 |
+
def _extract_pmc_html_text(tree) -> str:
|
| 114 |
+
"""
|
| 115 |
+
Extract clean structured text from PMC article HTML.
|
| 116 |
+
Uses lxml XPath to navigate the <article> element.
|
| 117 |
+
Deduplicates recommendation paragraphs (PMC renders them twice).
|
| 118 |
+
"""
|
| 119 |
+
# Get main article element
|
| 120 |
+
articles = tree.xpath('//article')
|
| 121 |
+
root = articles[0] if articles else tree
|
| 122 |
+
|
| 123 |
+
lines = []
|
| 124 |
+
seen_texts: set[str] = set() # Deduplication for repeated elements
|
| 125 |
+
|
| 126 |
+
def clean(el) -> str:
|
| 127 |
+
return " ".join(el.text_content().split()).strip()
|
| 128 |
+
|
| 129 |
+
def add_line(text: str) -> None:
|
| 130 |
+
if text and len(text) > 10 and text not in seen_texts:
|
| 131 |
+
seen_texts.add(text)
|
| 132 |
+
lines.append(text)
|
| 133 |
+
|
| 134 |
+
def extract_table(table_el):
|
| 135 |
+
"""Extract a table element as readable pipe-separated rows."""
|
| 136 |
+
caption = table_el.xpath('.//caption')
|
| 137 |
+
if caption:
|
| 138 |
+
add_line(f"[Table: {clean(caption[0])}]")
|
| 139 |
+
for tr in table_el.xpath('.//tr'):
|
| 140 |
+
cells = [" ".join(td.text_content().split()).strip()
|
| 141 |
+
for td in tr.xpath('.//td | .//th')]
|
| 142 |
+
row = " | ".join(c for c in cells if c)
|
| 143 |
+
if row:
|
| 144 |
+
add_line(row)
|
| 145 |
+
|
| 146 |
+
def process_section(sec, depth=0):
|
| 147 |
+
# Deep-search for tables first (they may be nested inside divs/figures)
|
| 148 |
+
for table in sec.xpath('.//table'):
|
| 149 |
+
# Only process tables whose nearest section ancestor is this sec
|
| 150 |
+
ancestors = table.xpath('ancestor::section')
|
| 151 |
+
if not ancestors or ancestors[-1] == sec:
|
| 152 |
+
extract_table(table)
|
| 153 |
+
|
| 154 |
+
for child in sec:
|
| 155 |
+
tag = child.tag.lower() if isinstance(child.tag, str) else ""
|
| 156 |
+
|
| 157 |
+
if tag in ("h1", "h2", "h3", "h4"):
|
| 158 |
+
text = clean(child)
|
| 159 |
+
if text and text not in ("Abstract", "References", "Footnotes"):
|
| 160 |
+
lines.append(f"\n{'#' * (depth + 2)} {text}")
|
| 161 |
+
|
| 162 |
+
elif tag == "p":
|
| 163 |
+
text = clean(child)
|
| 164 |
+
add_line(text)
|
| 165 |
+
|
| 166 |
+
elif tag in ("ul", "ol"):
|
| 167 |
+
for li in child.xpath('.//li'):
|
| 168 |
+
text = clean(li)
|
| 169 |
+
add_line(f"• {text}")
|
| 170 |
+
|
| 171 |
+
elif tag == "section":
|
| 172 |
+
process_section(child, depth + 1)
|
| 173 |
+
|
| 174 |
+
elif tag == "table":
|
| 175 |
+
pass # Already handled above via deep-search
|
| 176 |
+
|
| 177 |
+
elif tag == "div":
|
| 178 |
+
# Recurse into divs that might contain content
|
| 179 |
+
cls = child.get("class", "")
|
| 180 |
+
if any(k in cls for k in ("content", "body", "text", "article")):
|
| 181 |
+
process_section(child, depth)
|
| 182 |
+
|
| 183 |
+
for sec in root.xpath('.//section'):
|
| 184 |
+
# Only process top-level sections (not deeply nested)
|
| 185 |
+
parent = sec.getparent()
|
| 186 |
+
if parent is not None and parent.tag.lower() not in ("section",):
|
| 187 |
+
process_section(sec)
|
| 188 |
+
|
| 189 |
+
# If no sections found, fall back to all paragraphs
|
| 190 |
+
if len(lines) < 5:
|
| 191 |
+
for p in root.xpath('.//article//p | .//p[@class]'):
|
| 192 |
+
add_line(clean(p))
|
| 193 |
+
|
| 194 |
+
return "\n\n".join(l for l in lines if l.strip())
|
| 195 |
+
|
| 196 |
+
|
| 197 |
+
def extract_recommendations(text: str) -> list[dict]:
|
| 198 |
+
"""Extract individual recommendations with their numbers and grades."""
|
| 199 |
+
recs = []
|
| 200 |
+
for line in text.split('\n'):
|
| 201 |
+
line = line.strip()
|
| 202 |
+
m = _REC_NUM_RE.match(line)
|
| 203 |
+
if m:
|
| 204 |
+
rec_num = m.group(1)
|
| 205 |
+
rec_text = line[m.end():].strip()
|
| 206 |
+
grade_m = _GRADE_RE.search(rec_text)
|
| 207 |
+
grade = grade_m.group(1) if grade_m else "E"
|
| 208 |
+
recs.append({"number": rec_num, "text": rec_text, "grade": grade})
|
| 209 |
+
return recs
|
| 210 |
+
|
| 211 |
+
|
| 212 |
+
def chunk_guideline_text(
|
| 213 |
+
text: str,
|
| 214 |
+
section_meta: dict,
|
| 215 |
+
guideline_meta: dict,
|
| 216 |
+
max_chunk_chars: int = 2000,
|
| 217 |
+
) -> list[dict]:
|
| 218 |
+
"""
|
| 219 |
+
Chunk guideline text at ## heading boundaries produced by _extract_pmc_html_text.
|
| 220 |
+
Each chunk = H2/H3 topic + its paragraphs/recommendations.
|
| 221 |
+
"""
|
| 222 |
+
chunks = []
|
| 223 |
+
section_num = section_meta["section"]
|
| 224 |
+
section_title = section_meta["title"]
|
| 225 |
+
guideline_name = guideline_meta["name"]
|
| 226 |
+
source = guideline_meta["source"]
|
| 227 |
+
pub_year = guideline_meta["pub_year"]
|
| 228 |
+
pub_type = guideline_meta["pub_type"]
|
| 229 |
+
source_key = guideline_meta.get("key", "ada")
|
| 230 |
+
journal = guideline_meta.get("journal", "Diabetes Care")
|
| 231 |
+
|
| 232 |
+
# Split text into blocks at any ## heading
|
| 233 |
+
# Each block starts with a heading line and contains the following paragraphs
|
| 234 |
+
_HEADING_RE = re.compile(r'^(#{1,4})\s+(.+)$', re.MULTILINE)
|
| 235 |
+
|
| 236 |
+
# Find all heading positions
|
| 237 |
+
heading_matches = list(_HEADING_RE.finditer(text))
|
| 238 |
+
|
| 239 |
+
if not heading_matches:
|
| 240 |
+
# No headings found — chunk by size
|
| 241 |
+
blocks = [(section_title, text)]
|
| 242 |
+
else:
|
| 243 |
+
blocks = []
|
| 244 |
+
for i, m in enumerate(heading_matches):
|
| 245 |
+
heading_text = m.group(2).strip()
|
| 246 |
+
# Skip metadata headings
|
| 247 |
+
if heading_text in ("Abstract", "References", "Footnotes", "Author notes",
|
| 248 |
+
"Conflicts of interest", "Acknowledgments"):
|
| 249 |
+
continue
|
| 250 |
+
start = m.end()
|
| 251 |
+
end = heading_matches[i + 1].start() if i + 1 < len(heading_matches) else len(text)
|
| 252 |
+
content = text[start:end].strip()
|
| 253 |
+
if content:
|
| 254 |
+
blocks.append((heading_text, content))
|
| 255 |
+
|
| 256 |
+
def make_chunk(heading: str, content: str, part_idx: int = 0) -> dict:
|
| 257 |
+
recs = extract_recommendations(content)
|
| 258 |
+
rec_nums = [r["number"] for r in recs]
|
| 259 |
+
grades = {r["number"]: r["grade"] for r in recs}
|
| 260 |
+
grade_summary = "/".join(sorted(set(r["grade"] for r in recs))) if recs else ""
|
| 261 |
+
|
| 262 |
+
prefix = f"[{guideline_name} | Section {section_num}: {section_title} | {heading}]"
|
| 263 |
+
if grade_summary:
|
| 264 |
+
prefix += f" [Evidence: {grade_summary}]"
|
| 265 |
+
|
| 266 |
+
return {
|
| 267 |
+
"chunk_id": f"guideline_{source_key}_{section_num}_{uuid.uuid4().hex[:8]}_{part_idx}",
|
| 268 |
+
"doc_id": f"guideline_{source_key}_section_{section_num}",
|
| 269 |
+
"chunk_text": f"{prefix}\n{content}",
|
| 270 |
+
"chunk_index": len(chunks),
|
| 271 |
+
"total_chunks": 0,
|
| 272 |
+
"pub_type": pub_type,
|
| 273 |
+
"source": source,
|
| 274 |
+
"title": f"{guideline_name} — Section {section_num}: {heading}",
|
| 275 |
+
"pub_year": pub_year,
|
| 276 |
+
"journal": journal,
|
| 277 |
+
"section_number": section_num,
|
| 278 |
+
"section_title": section_title,
|
| 279 |
+
"h2_heading": heading,
|
| 280 |
+
"recommendation_numbers": rec_nums,
|
| 281 |
+
"evidence_grades": grades,
|
| 282 |
+
}
|
| 283 |
+
|
| 284 |
+
for heading, content in blocks:
|
| 285 |
+
if len(content) <= max_chunk_chars:
|
| 286 |
+
chunks.append(make_chunk(heading, content))
|
| 287 |
+
else:
|
| 288 |
+
# Split long blocks at paragraph boundaries
|
| 289 |
+
paras = [p.strip() for p in re.split(r'\n{2,}', content) if p.strip()]
|
| 290 |
+
current: list[str] = []
|
| 291 |
+
part = 0
|
| 292 |
+
for para in paras:
|
| 293 |
+
current.append(para)
|
| 294 |
+
if len("\n\n".join(current)) >= max_chunk_chars:
|
| 295 |
+
chunks.append(make_chunk(heading, "\n\n".join(current[:-1]), part))
|
| 296 |
+
current = [para]
|
| 297 |
+
part += 1
|
| 298 |
+
if current:
|
| 299 |
+
chunks.append(make_chunk(heading, "\n\n".join(current), part))
|
| 300 |
+
|
| 301 |
+
for chunk in chunks:
|
| 302 |
+
chunk["total_chunks"] = len(chunks)
|
| 303 |
+
|
| 304 |
+
return chunks
|
| 305 |
+
|
| 306 |
+
|
| 307 |
+
def download_guidelines(source_key: str, output_path: str, dry_run: bool = False) -> None:
|
| 308 |
+
source = GUIDELINE_SOURCES[source_key]
|
| 309 |
+
out = Path(output_path)
|
| 310 |
+
out.parent.mkdir(parents=True, exist_ok=True)
|
| 311 |
+
|
| 312 |
+
total_chunks = 0
|
| 313 |
+
failed_sections = []
|
| 314 |
+
|
| 315 |
+
with open(out, "w", encoding="utf-8") as f:
|
| 316 |
+
for section in source["sections"]:
|
| 317 |
+
pmcid = section["pmcid"]
|
| 318 |
+
logger.info("Fetching %s — Section %s: %s", pmcid, section["section"], section["title"])
|
| 319 |
+
|
| 320 |
+
text = fetch_pmc_xml(pmcid)
|
| 321 |
+
|
| 322 |
+
if not text or len(text) < 200:
|
| 323 |
+
logger.warning("No text retrieved for %s — skipping", pmcid)
|
| 324 |
+
failed_sections.append(section["title"])
|
| 325 |
+
time.sleep(0.5)
|
| 326 |
+
continue
|
| 327 |
+
|
| 328 |
+
logger.info(" Retrieved %d chars", len(text))
|
| 329 |
+
|
| 330 |
+
chunks = chunk_guideline_text(text, section, source)
|
| 331 |
+
logger.info(" → %d chunks extracted", len(chunks))
|
| 332 |
+
|
| 333 |
+
if dry_run:
|
| 334 |
+
if chunks:
|
| 335 |
+
logger.info(" Sample chunk:\n%s\n...", chunks[0]["chunk_text"][:300])
|
| 336 |
+
continue
|
| 337 |
+
|
| 338 |
+
for chunk in chunks:
|
| 339 |
+
f.write(json.dumps(chunk) + "\n")
|
| 340 |
+
|
| 341 |
+
total_chunks += len(chunks)
|
| 342 |
+
time.sleep(0.5) # Be polite to NCBI API
|
| 343 |
+
|
| 344 |
+
if not dry_run:
|
| 345 |
+
logger.info("Done. %d total chunks written to %s", total_chunks, out)
|
| 346 |
+
if failed_sections:
|
| 347 |
+
logger.warning("Failed sections: %s", failed_sections)
|
| 348 |
+
|
| 349 |
+
|
| 350 |
+
if __name__ == "__main__":
|
| 351 |
+
parser = argparse.ArgumentParser()
|
| 352 |
+
parser.add_argument("--source", default=None,
|
| 353 |
+
choices=list(GUIDELINE_SOURCES.keys()),
|
| 354 |
+
help="Guideline source to download (default: all sources)")
|
| 355 |
+
parser.add_argument("--all", action="store_true",
|
| 356 |
+
help="Download all guideline sources")
|
| 357 |
+
parser.add_argument("--output", default="data/guidelines_chunks.jsonl")
|
| 358 |
+
parser.add_argument("--dry-run", action="store_true",
|
| 359 |
+
help="Fetch and parse but don't write output")
|
| 360 |
+
args = parser.parse_args()
|
| 361 |
+
|
| 362 |
+
sources_to_run = list(GUIDELINE_SOURCES.keys()) if (args.all or args.source is None) else [args.source]
|
| 363 |
+
|
| 364 |
+
for source_key in sources_to_run:
|
| 365 |
+
logger.info("Downloading: %s", GUIDELINE_SOURCES[source_key]["name"])
|
| 366 |
+
# For multi-source runs, append non-ada sources to the same output file
|
| 367 |
+
if source_key == sources_to_run[0]:
|
| 368 |
+
download_guidelines(source_key, args.output, dry_run=args.dry_run)
|
| 369 |
+
else:
|
| 370 |
+
# Append to existing file by re-opening in append mode
|
| 371 |
+
out = Path(args.output)
|
| 372 |
+
source = GUIDELINE_SOURCES[source_key]
|
| 373 |
+
total_chunks = 0
|
| 374 |
+
failed_sections = []
|
| 375 |
+
with open(out, "a", encoding="utf-8") as f:
|
| 376 |
+
for section in source["sections"]:
|
| 377 |
+
pmcid = section["pmcid"]
|
| 378 |
+
logger.info("Fetching %s — Section %s: %s", pmcid, section["section"], section["title"])
|
| 379 |
+
text = fetch_pmc_xml(pmcid)
|
| 380 |
+
if not text or len(text) < 200:
|
| 381 |
+
logger.warning("No text retrieved for %s — skipping", pmcid)
|
| 382 |
+
failed_sections.append(section["title"])
|
| 383 |
+
time.sleep(0.5)
|
| 384 |
+
continue
|
| 385 |
+
logger.info(" Retrieved %d chars", len(text))
|
| 386 |
+
chunks = chunk_guideline_text(text, section, source)
|
| 387 |
+
logger.info(" → %d chunks extracted", len(chunks))
|
| 388 |
+
if args.dry_run:
|
| 389 |
+
if chunks:
|
| 390 |
+
logger.info(" Sample chunk:\n%s\n...", chunks[0]["chunk_text"][:300])
|
| 391 |
+
continue
|
| 392 |
+
for chunk in chunks:
|
| 393 |
+
f.write(json.dumps(chunk) + "\n")
|
| 394 |
+
total_chunks += len(chunks)
|
| 395 |
+
time.sleep(0.5)
|
| 396 |
+
if not args.dry_run:
|
| 397 |
+
logger.info("Done. %d total chunks written for %s", total_chunks, source_key)
|
| 398 |
+
if failed_sections:
|
| 399 |
+
logger.warning("Failed sections: %s", failed_sections)
|
scripts/fix_fda_chunk_text.py
ADDED
|
@@ -0,0 +1,120 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
scripts/fix_fda_chunk_text.py
|
| 3 |
+
==============================
|
| 4 |
+
One-time fix: replaces the verbose FDA boilerplate prefix in all FDA DailyMed
|
| 5 |
+
chunk_text entries in the metadata store with a clean, BM25-friendly prefix.
|
| 6 |
+
|
| 7 |
+
Before: [FDA DRUG LABEL — These highlights do not include all the information
|
| 8 |
+
needed to use WARFARIN SODIUM TABLETS safely and effectively...]
|
| 9 |
+
CONTRAINDICATIONS: actual content...
|
| 10 |
+
|
| 11 |
+
After: [FDA DailyMed | Warfarin | CONTRAINDICATIONS] actual content...
|
| 12 |
+
|
| 13 |
+
Usage:
|
| 14 |
+
python scripts/fix_fda_chunk_text.py
|
| 15 |
+
python scripts/fix_fda_chunk_text.py --dry-run
|
| 16 |
+
"""
|
| 17 |
+
from __future__ import annotations
|
| 18 |
+
|
| 19 |
+
import argparse
|
| 20 |
+
import logging
|
| 21 |
+
import pickle
|
| 22 |
+
import re
|
| 23 |
+
import sys
|
| 24 |
+
from pathlib import Path
|
| 25 |
+
|
| 26 |
+
sys.path.insert(0, str(Path(__file__).resolve().parent.parent))
|
| 27 |
+
import yaml
|
| 28 |
+
|
| 29 |
+
logging.basicConfig(level=logging.INFO, format="%(asctime)s [%(levelname)s] %(message)s")
|
| 30 |
+
logger = logging.getLogger(__name__)
|
| 31 |
+
|
| 32 |
+
SECTION_CODES = {
|
| 33 |
+
"34068-7": "DOSAGE AND ADMINISTRATION",
|
| 34 |
+
"34070-3": "CONTRAINDICATIONS",
|
| 35 |
+
"43685-7": "WARNINGS AND PRECAUTIONS",
|
| 36 |
+
"34067-9": "INDICATIONS AND USAGE",
|
| 37 |
+
"34073-7": "DRUG INTERACTIONS",
|
| 38 |
+
"34071-1": "WARNINGS",
|
| 39 |
+
}
|
| 40 |
+
|
| 41 |
+
# Matches both old boilerplate and previously-fixed format
|
| 42 |
+
_BOILERPLATE_RE = re.compile(r"^\[FDA[^\]]*\]\s*(?:[A-Za-z][^:]*:\s*)?", re.DOTALL)
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
def fix_chunk_text(chunk_id: str, old_text: str) -> str:
|
| 46 |
+
"""Return cleaned chunk_text with a compact, keyword-rich prefix."""
|
| 47 |
+
# Extract drug name from chunk_id: fda_{drug_name}_{set_id}_{code}_{offset}
|
| 48 |
+
parts = chunk_id.split("_")
|
| 49 |
+
# parts[0] = "fda", parts[1] = drug_name (may be multi-word), then UUID parts, then code, then offset
|
| 50 |
+
# Find the section code in parts
|
| 51 |
+
section_name = None
|
| 52 |
+
drug_name_parts = []
|
| 53 |
+
for i, part in enumerate(parts[1:], 1):
|
| 54 |
+
if part in SECTION_CODES:
|
| 55 |
+
section_name = SECTION_CODES[part]
|
| 56 |
+
drug_name_parts = parts[1:i]
|
| 57 |
+
break
|
| 58 |
+
|
| 59 |
+
# Filter out UUID parts (set_id format: 8hex-4hex-...) from drug name
|
| 60 |
+
_UUID_RE = re.compile(r'^[0-9a-f]{8}-', re.I)
|
| 61 |
+
drug_name_parts = [p for p in drug_name_parts if not _UUID_RE.match(p)]
|
| 62 |
+
drug_name = " ".join(drug_name_parts).replace("_", " ").title() if drug_name_parts else "Unknown"
|
| 63 |
+
|
| 64 |
+
if not section_name:
|
| 65 |
+
m = _BOILERPLATE_RE.match(old_text)
|
| 66 |
+
section_name = m.group(1).strip() if m else "DRUG INFORMATION"
|
| 67 |
+
|
| 68 |
+
# Strip the old boilerplate prefix and get just the content
|
| 69 |
+
m = _BOILERPLATE_RE.match(old_text)
|
| 70 |
+
content = old_text[m.end():].strip() if m else old_text.strip()
|
| 71 |
+
|
| 72 |
+
# Prepend drug name into content so BM25 finds it even in continuation chunks
|
| 73 |
+
# e.g. chunk starting "Bleeding tendencies..." now reads "Warfarin CONTRAINDICATIONS: Bleeding..."
|
| 74 |
+
return f"[FDA DailyMed | {drug_name} | {section_name}] {drug_name} {section_name}: {content}"
|
| 75 |
+
|
| 76 |
+
|
| 77 |
+
def main() -> None:
|
| 78 |
+
parser = argparse.ArgumentParser()
|
| 79 |
+
parser.add_argument("--dry-run", action="store_true")
|
| 80 |
+
args = parser.parse_args()
|
| 81 |
+
|
| 82 |
+
with open("config.yaml") as f:
|
| 83 |
+
cfg = yaml.safe_load(f)
|
| 84 |
+
meta_path = cfg["retrieval"]["metadata_path"]
|
| 85 |
+
|
| 86 |
+
logger.info("Loading metadata store from %s ...", meta_path)
|
| 87 |
+
with open(meta_path, "rb") as f:
|
| 88 |
+
store: dict = pickle.load(f)
|
| 89 |
+
|
| 90 |
+
fda_keys = [k for k, v in store.items() if v.get("source") == "FDA DailyMed"]
|
| 91 |
+
logger.info("Found %d FDA DailyMed entries to fix", len(fda_keys))
|
| 92 |
+
|
| 93 |
+
fixed = 0
|
| 94 |
+
for key in fda_keys:
|
| 95 |
+
entry = store[key]
|
| 96 |
+
old_text = entry.get("chunk_text", "")
|
| 97 |
+
# Re-run on both old boilerplate AND previously-fixed entries (to fix UUID + add drug name to content)
|
| 98 |
+
if not (old_text.startswith("[FDA DRUG LABEL") or old_text.startswith("[FDA DailyMed |")):
|
| 99 |
+
continue
|
| 100 |
+
new_text = fix_chunk_text(entry.get("chunk_id", ""), old_text)
|
| 101 |
+
if args.dry_run:
|
| 102 |
+
if fixed < 3:
|
| 103 |
+
logger.info("BEFORE: %s", old_text[:120])
|
| 104 |
+
logger.info("AFTER: %s", new_text[:120])
|
| 105 |
+
logger.info("---")
|
| 106 |
+
else:
|
| 107 |
+
store[key]["chunk_text"] = new_text
|
| 108 |
+
fixed += 1
|
| 109 |
+
|
| 110 |
+
logger.info("%d entries %s", fixed,
|
| 111 |
+
"would be fixed (dry run)" if args.dry_run else "fixed")
|
| 112 |
+
|
| 113 |
+
if not args.dry_run:
|
| 114 |
+
with open(meta_path, "wb") as f:
|
| 115 |
+
pickle.dump(store, f, protocol=pickle.HIGHEST_PROTOCOL)
|
| 116 |
+
logger.info("Metadata store saved. Restart backend to rebuild BM25 index.")
|
| 117 |
+
|
| 118 |
+
|
| 119 |
+
if __name__ == "__main__":
|
| 120 |
+
main()
|
scripts/ingest_incremental.py
ADDED
|
@@ -0,0 +1,192 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
scripts/ingest_incremental.py
|
| 3 |
+
==============================
|
| 4 |
+
Adds new chunks to an EXISTING FAISS index without rebuilding from scratch.
|
| 5 |
+
Only the new chunks are embedded — existing vectors are untouched.
|
| 6 |
+
|
| 7 |
+
Usage:
|
| 8 |
+
python scripts/ingest_incremental.py --input data/dailymed_chunks.jsonl
|
| 9 |
+
python scripts/ingest_incremental.py --input data/dailymed_chunks.jsonl --dry-run
|
| 10 |
+
"""
|
| 11 |
+
from __future__ import annotations
|
| 12 |
+
|
| 13 |
+
import argparse
|
| 14 |
+
import json
|
| 15 |
+
import logging
|
| 16 |
+
import pickle
|
| 17 |
+
import sys
|
| 18 |
+
from pathlib import Path
|
| 19 |
+
|
| 20 |
+
sys.path.insert(0, str(Path(__file__).resolve().parent.parent))
|
| 21 |
+
|
| 22 |
+
import faiss
|
| 23 |
+
import numpy as np
|
| 24 |
+
import yaml
|
| 25 |
+
|
| 26 |
+
logging.basicConfig(level=logging.INFO, format="%(asctime)s [%(levelname)s] %(message)s")
|
| 27 |
+
logger = logging.getLogger(__name__)
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
def load_config() -> dict:
|
| 31 |
+
with open("config.yaml", "r", encoding="utf-8") as f:
|
| 32 |
+
return yaml.safe_load(f)
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
def load_new_chunks(path: str) -> list[dict]:
|
| 36 |
+
chunks = []
|
| 37 |
+
with open(path, "r", encoding="utf-8") as f:
|
| 38 |
+
for line in f:
|
| 39 |
+
line = line.strip()
|
| 40 |
+
if line:
|
| 41 |
+
chunks.append(json.loads(line))
|
| 42 |
+
logger.info("Loaded %d new chunks from %s", len(chunks), path)
|
| 43 |
+
return chunks
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
def embed_chunks(chunks: list[dict], model_name: str) -> np.ndarray:
|
| 47 |
+
from sentence_transformers import SentenceTransformer
|
| 48 |
+
model = SentenceTransformer(model_name)
|
| 49 |
+
texts = [c["chunk_text"] for c in chunks]
|
| 50 |
+
logger.info("Embedding %d new chunks with %s...", len(texts), model_name)
|
| 51 |
+
embeddings = model.encode(
|
| 52 |
+
texts,
|
| 53 |
+
batch_size=32,
|
| 54 |
+
show_progress_bar=True,
|
| 55 |
+
normalize_embeddings=True,
|
| 56 |
+
convert_to_numpy=True,
|
| 57 |
+
)
|
| 58 |
+
return embeddings.astype(np.float32)
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
def main() -> None:
|
| 62 |
+
parser = argparse.ArgumentParser()
|
| 63 |
+
parser.add_argument("--input", required=True, help="JSONL file of new chunks")
|
| 64 |
+
parser.add_argument("--dry-run", action="store_true",
|
| 65 |
+
help="Show what would be added without writing to disk")
|
| 66 |
+
parser.add_argument("--force-update-section", default=None,
|
| 67 |
+
help="Force-update chunk_text for existing chunks matching this section keyword (e.g. 'ADVERSE REACTIONS')")
|
| 68 |
+
args = parser.parse_args()
|
| 69 |
+
|
| 70 |
+
cfg = load_config()
|
| 71 |
+
idx_path = cfg["retrieval"]["index_path"]
|
| 72 |
+
meta_path = cfg["retrieval"]["metadata_path"]
|
| 73 |
+
model_name = cfg["retrieval"]["embedding_model"]
|
| 74 |
+
|
| 75 |
+
if not Path(idx_path).exists():
|
| 76 |
+
logger.error("FAISS index not found at %s. Run embedder.py first.", idx_path)
|
| 77 |
+
sys.exit(1)
|
| 78 |
+
|
| 79 |
+
# Load existing index + metadata
|
| 80 |
+
logger.info("Loading existing FAISS index from %s ...", idx_path)
|
| 81 |
+
index = faiss.read_index(idx_path)
|
| 82 |
+
existing_count = index.ntotal
|
| 83 |
+
logger.info("Existing index: %d vectors", existing_count)
|
| 84 |
+
|
| 85 |
+
with open(meta_path, "rb") as f:
|
| 86 |
+
metadata_store: dict[int, dict] = pickle.load(f)
|
| 87 |
+
|
| 88 |
+
# Force-update existing chunk_text for a specific section (no new FAISS vectors needed)
|
| 89 |
+
all_input_chunks = load_new_chunks(args.input)
|
| 90 |
+
if args.force_update_section:
|
| 91 |
+
section_kw = args.force_update_section.upper()
|
| 92 |
+
# Primary lookup: chunk_id → FAISS key (works for FDA with deterministic IDs)
|
| 93 |
+
id_to_meta = {v.get("chunk_id"): k for k, v in metadata_store.items()}
|
| 94 |
+
# Secondary lookup: (doc_id, chunk_index) → FAISS key (works for guidelines with random UUID IDs)
|
| 95 |
+
docidx_to_meta = {(v.get("doc_id", ""), v.get("chunk_index", 0)): k
|
| 96 |
+
for k, v in metadata_store.items()}
|
| 97 |
+
updated = 0
|
| 98 |
+
for chunk in all_input_chunks:
|
| 99 |
+
if section_kw in chunk.get("chunk_text", "").upper():
|
| 100 |
+
# Try primary match first
|
| 101 |
+
faiss_key = id_to_meta.get(chunk.get("chunk_id"))
|
| 102 |
+
# Fallback to (doc_id, chunk_index) match
|
| 103 |
+
if faiss_key is None:
|
| 104 |
+
faiss_key = docidx_to_meta.get((chunk.get("doc_id", ""), chunk.get("chunk_index", 0)))
|
| 105 |
+
if faiss_key is not None:
|
| 106 |
+
metadata_store[faiss_key]["chunk_text"] = chunk["chunk_text"]
|
| 107 |
+
updated += 1
|
| 108 |
+
logger.info("Force-updated chunk_text for %d '%s' entries", updated, section_kw)
|
| 109 |
+
if not args.dry_run:
|
| 110 |
+
with open(meta_path, "wb") as f:
|
| 111 |
+
pickle.dump(metadata_store, f, protocol=pickle.HIGHEST_PROTOCOL)
|
| 112 |
+
logger.info("Metadata store saved.")
|
| 113 |
+
# Invalidate BM25 cache
|
| 114 |
+
bm25_cache = Path(meta_path).parent / "bm25_cache.pkl"
|
| 115 |
+
if bm25_cache.exists():
|
| 116 |
+
bm25_cache.unlink()
|
| 117 |
+
logger.info("BM25 cache invalidated — will rebuild on next startup.")
|
| 118 |
+
return
|
| 119 |
+
|
| 120 |
+
# Deduplicate — skip chunks already in the index.
|
| 121 |
+
# Primary key: chunk_id. Secondary key: (doc_id, chunk_index) handles
|
| 122 |
+
# re-ingestion of the same document with new UUIDs (e.g. FDA label updates).
|
| 123 |
+
existing_ids = {v.get("chunk_id", "") for v in metadata_store.values()}
|
| 124 |
+
existing_docidx = {
|
| 125 |
+
(v.get("doc_id", ""), v.get("chunk_index", -1))
|
| 126 |
+
for v in metadata_store.values()
|
| 127 |
+
if v.get("doc_id") and v.get("chunk_index", -1) >= 0
|
| 128 |
+
}
|
| 129 |
+
|
| 130 |
+
def _is_duplicate(c: dict) -> bool:
|
| 131 |
+
if c.get("chunk_id") in existing_ids:
|
| 132 |
+
return True
|
| 133 |
+
key = (c.get("doc_id", ""), c.get("chunk_index", -1))
|
| 134 |
+
return key[0] != "" and key[1] >= 0 and key in existing_docidx
|
| 135 |
+
|
| 136 |
+
new_chunks = [c for c in all_input_chunks if not _is_duplicate(c)]
|
| 137 |
+
|
| 138 |
+
if not new_chunks:
|
| 139 |
+
logger.info("All chunks already in index. Nothing to add.")
|
| 140 |
+
return
|
| 141 |
+
|
| 142 |
+
logger.info("%d new chunks to add (%d duplicates skipped)",
|
| 143 |
+
len(new_chunks), len(all_input_chunks) - len(new_chunks))
|
| 144 |
+
|
| 145 |
+
if args.dry_run:
|
| 146 |
+
logger.info("DRY RUN — no changes written.")
|
| 147 |
+
for c in new_chunks[:5]:
|
| 148 |
+
logger.info(" Would add: %s | %s", c.get("chunk_id"), c.get("title", "")[:60])
|
| 149 |
+
return
|
| 150 |
+
|
| 151 |
+
# Embed new chunks only
|
| 152 |
+
embeddings = embed_chunks(new_chunks, model_name)
|
| 153 |
+
|
| 154 |
+
# Add to existing FAISS index
|
| 155 |
+
index.add(embeddings)
|
| 156 |
+
logger.info("Index now has %d vectors (+%d)", index.ntotal, len(new_chunks))
|
| 157 |
+
|
| 158 |
+
# Extend metadata store (new keys start from existing_count)
|
| 159 |
+
for i, chunk in enumerate(new_chunks):
|
| 160 |
+
metadata_store[existing_count + i] = {
|
| 161 |
+
"chunk_id": chunk.get("chunk_id", f"chunk_{existing_count + i}"),
|
| 162 |
+
"doc_id": chunk.get("doc_id", ""),
|
| 163 |
+
"source": chunk.get("source", ""),
|
| 164 |
+
"title": chunk.get("title", ""),
|
| 165 |
+
"pub_type": chunk.get("pub_type", "unknown"),
|
| 166 |
+
"pub_year": chunk.get("pub_year"),
|
| 167 |
+
"journal": chunk.get("journal", ""),
|
| 168 |
+
"chunk_index": chunk.get("chunk_index", 0),
|
| 169 |
+
"total_chunks": chunk.get("total_chunks", 1),
|
| 170 |
+
"chunk_text": chunk.get("chunk_text", ""),
|
| 171 |
+
}
|
| 172 |
+
|
| 173 |
+
# Save updated artifacts
|
| 174 |
+
faiss.write_index(index, idx_path)
|
| 175 |
+
logger.info("FAISS index saved to %s", idx_path)
|
| 176 |
+
|
| 177 |
+
with open(meta_path, "wb") as f:
|
| 178 |
+
pickle.dump(metadata_store, f, protocol=pickle.HIGHEST_PROTOCOL)
|
| 179 |
+
logger.info("Metadata store saved (%d total entries)", len(metadata_store))
|
| 180 |
+
|
| 181 |
+
# Also append to chunks.jsonl for future full rebuilds
|
| 182 |
+
chunks_jsonl = Path("data/processed/chunks.jsonl")
|
| 183 |
+
with open(chunks_jsonl, "a", encoding="utf-8") as f:
|
| 184 |
+
for chunk in new_chunks:
|
| 185 |
+
f.write(json.dumps(chunk) + "\n")
|
| 186 |
+
logger.info("Appended %d chunks to %s", len(new_chunks), chunks_jsonl)
|
| 187 |
+
|
| 188 |
+
logger.info("Done. Restart the backend to reload the updated index.")
|
| 189 |
+
|
| 190 |
+
|
| 191 |
+
if __name__ == "__main__":
|
| 192 |
+
main()
|
scripts/warmup.py
ADDED
|
@@ -0,0 +1,58 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
src/scripts/warmup.py
|
| 3 |
+
=====================
|
| 4 |
+
Pre-loads heavy ML models (FAISS, DeBERTa, SciSpaCy) into memory
|
| 5 |
+
and guarantees instantaneous responses for the first API request during the live demo.
|
| 6 |
+
|
| 7 |
+
Usage:
|
| 8 |
+
python scripts/warmup.py
|
| 9 |
+
"""
|
| 10 |
+
|
| 11 |
+
import sys
|
| 12 |
+
from pathlib import Path
|
| 13 |
+
|
| 14 |
+
sys.path.insert(0, str(Path(__file__).resolve().parent.parent))
|
| 15 |
+
|
| 16 |
+
import logging
|
| 17 |
+
import time
|
| 18 |
+
import requests
|
| 19 |
+
|
| 20 |
+
logging.basicConfig(level=logging.INFO, format="%(asctime)s [%(levelname)s] %(name)s: %(message)s")
|
| 21 |
+
logger = logging.getLogger("warmup")
|
| 22 |
+
|
| 23 |
+
def main():
|
| 24 |
+
api_url = "http://localhost:8000"
|
| 25 |
+
|
| 26 |
+
logger.info("Verifying API is running...")
|
| 27 |
+
try:
|
| 28 |
+
health = requests.get(f"{api_url}/health", timeout=5)
|
| 29 |
+
health.raise_for_status()
|
| 30 |
+
logger.info(f"API Health: {health.json()}")
|
| 31 |
+
except requests.exceptions.RequestException as e:
|
| 32 |
+
logger.error(f"API is not running at {api_url}. Please start it with 'uvicorn src.api.main:app' first.")
|
| 33 |
+
sys.exit(1)
|
| 34 |
+
|
| 35 |
+
logger.info("Sending WARMUP query to load DeBERTa, SciSpaCy, and FAISS into RAM... (This may take 15-25s)")
|
| 36 |
+
t0 = time.time()
|
| 37 |
+
|
| 38 |
+
# We send a basic query to force all models to initialize
|
| 39 |
+
payload = {
|
| 40 |
+
"question": "What is the recommended dosage of Metformin for elderly Type 2 Diabetes patients?",
|
| 41 |
+
"top_k": 1,
|
| 42 |
+
"run_ragas": False
|
| 43 |
+
}
|
| 44 |
+
|
| 45 |
+
try:
|
| 46 |
+
resp = requests.post(f"{api_url}/query", json=payload, timeout=60)
|
| 47 |
+
resp.raise_for_status()
|
| 48 |
+
elapsed = time.time() - t0
|
| 49 |
+
logger.info(f"Warmup successful in {elapsed:.1f}s!")
|
| 50 |
+
logger.info("All machine learning models are now cached in RAM.")
|
| 51 |
+
logger.info("The next API requests will be completely instantaneous.")
|
| 52 |
+
except Exception as e:
|
| 53 |
+
logger.error(f"Warmup failed: {e}")
|
| 54 |
+
if hasattr(e, "response") and e.response is not None:
|
| 55 |
+
logger.error(f"Response: {e.response.text}")
|
| 56 |
+
|
| 57 |
+
if __name__ == "__main__":
|
| 58 |
+
main()
|
setup.py
ADDED
|
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from setuptools import setup, find_packages
|
| 2 |
+
|
| 3 |
+
setup(
|
| 4 |
+
name="medirag-cli",
|
| 5 |
+
version="0.1.0",
|
| 6 |
+
packages=find_packages(),
|
| 7 |
+
install_requires=[
|
| 8 |
+
"typer>=0.9.0",
|
| 9 |
+
],
|
| 10 |
+
entry_points={
|
| 11 |
+
"console_scripts": [
|
| 12 |
+
"medirag=src.cli:app",
|
| 13 |
+
],
|
| 14 |
+
},
|
| 15 |
+
)
|
src/__init__.py
ADDED
|
@@ -0,0 +1,44 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
src/__init__.py — Package initializer and logging setup.
|
| 3 |
+
Runs once on first `import src`. Sets up logging from config.yaml.
|
| 4 |
+
(SRS Section 13)
|
| 5 |
+
"""
|
| 6 |
+
import logging
|
| 7 |
+
import os
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
def _setup_logging() -> None:
|
| 11 |
+
"""Configure root logger. No-op if handlers already exist."""
|
| 12 |
+
os.makedirs("logs", exist_ok=True)
|
| 13 |
+
|
| 14 |
+
log_level = logging.INFO
|
| 15 |
+
log_format = "%(asctime)s [%(levelname)s] %(name)s: %(message)s"
|
| 16 |
+
log_file = "logs/medirag.log"
|
| 17 |
+
|
| 18 |
+
# Try to load level from config.yaml
|
| 19 |
+
try:
|
| 20 |
+
import yaml
|
| 21 |
+
with open("config.yaml", "r") as f:
|
| 22 |
+
cfg = yaml.safe_load(f)
|
| 23 |
+
level_str = cfg.get("logging", {}).get("level", "INFO")
|
| 24 |
+
log_level = getattr(logging, level_str.upper(), logging.INFO)
|
| 25 |
+
log_file = cfg.get("logging", {}).get("file", log_file)
|
| 26 |
+
log_format = cfg.get("logging", {}).get("format", log_format)
|
| 27 |
+
except Exception:
|
| 28 |
+
pass # Use defaults if config not found (e.g., during tests)
|
| 29 |
+
|
| 30 |
+
root = logging.getLogger()
|
| 31 |
+
if root.handlers:
|
| 32 |
+
return # Already configured — don't add duplicate handlers
|
| 33 |
+
|
| 34 |
+
handlers: list[logging.Handler] = [logging.StreamHandler()]
|
| 35 |
+
try:
|
| 36 |
+
os.makedirs(os.path.dirname(log_file), exist_ok=True)
|
| 37 |
+
handlers.append(logging.FileHandler(log_file, encoding="utf-8"))
|
| 38 |
+
except Exception:
|
| 39 |
+
pass # File logging optional — don't fail on permission errors
|
| 40 |
+
|
| 41 |
+
logging.basicConfig(level=log_level, format=log_format, handlers=handlers)
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
_setup_logging()
|
src/api/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
# src/api/__init__.py
|
src/api/main.py
ADDED
|
@@ -0,0 +1,933 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
src/api/main.py — MediRAG FastAPI Application
|
| 3 |
+
=============================================
|
| 4 |
+
FR-18: Two endpoints:
|
| 5 |
+
GET /health → liveness check + Ollama status
|
| 6 |
+
POST /evaluate → calls run_evaluation(), returns FR-17 JSON
|
| 7 |
+
|
| 8 |
+
Design decisions:
|
| 9 |
+
- DeBERTa model is loaded once at app startup (not per-request)
|
| 10 |
+
- If any module raises an exception, partial results are returned (no HTTP 500)
|
| 11 |
+
- HTTP 422 Pydantic validation errors are automatic
|
| 12 |
+
- RAGAS is disabled by default (run_ragas=False) — set to True only if
|
| 13 |
+
Ollama/OpenAI is available; the RAGAS module already fails gracefully.
|
| 14 |
+
|
| 15 |
+
To run:
|
| 16 |
+
uvicorn src.api.main:app --reload --host 0.0.0.0 --port 8000
|
| 17 |
+
"""
|
| 18 |
+
from __future__ import annotations
|
| 19 |
+
|
| 20 |
+
import os
|
| 21 |
+
import logging
|
| 22 |
+
import time
|
| 23 |
+
from contextlib import asynccontextmanager
|
| 24 |
+
from pathlib import Path
|
| 25 |
+
from typing import Optional
|
| 26 |
+
|
| 27 |
+
import requests
|
| 28 |
+
import json
|
| 29 |
+
import sqlite3
|
| 30 |
+
import yaml
|
| 31 |
+
from datetime import datetime
|
| 32 |
+
from fastapi import FastAPI, HTTPException, File, UploadFile
|
| 33 |
+
from fastapi.middleware.cors import CORSMiddleware
|
| 34 |
+
from fastapi.responses import RedirectResponse
|
| 35 |
+
|
| 36 |
+
import threading
|
| 37 |
+
from src.api.schemas import (
|
| 38 |
+
HealthResponse,
|
| 39 |
+
EvaluateRequest,
|
| 40 |
+
EvaluateResponse,
|
| 41 |
+
QueryRequest,
|
| 42 |
+
QueryResponse,
|
| 43 |
+
RetrievedChunk,
|
| 44 |
+
IngestRequest,
|
| 45 |
+
ChatRequest,
|
| 46 |
+
ModuleScore,
|
| 47 |
+
ModuleResults,
|
| 48 |
+
)
|
| 49 |
+
from src.evaluate import run_evaluation
|
| 50 |
+
from src.pipeline.generator import generate_answer
|
| 51 |
+
from src.pipeline.retriever import Retriever
|
| 52 |
+
|
| 53 |
+
# ---------------------------------------------------------------------------
|
| 54 |
+
# Logging
|
| 55 |
+
# ---------------------------------------------------------------------------
|
| 56 |
+
try:
|
| 57 |
+
_cfg = yaml.safe_load(Path("config.yaml").read_text())
|
| 58 |
+
_log_level = _cfg.get("logging", {}).get("level", "INFO")
|
| 59 |
+
_ollama_base = _cfg.get("llm", {}).get("base_url", "http://localhost:11434")
|
| 60 |
+
_api_cfg = _cfg.get("api", {})
|
| 61 |
+
except Exception:
|
| 62 |
+
_log_level = "INFO"
|
| 63 |
+
_ollama_base = "http://localhost:11434"
|
| 64 |
+
_api_cfg = {}
|
| 65 |
+
|
| 66 |
+
logging.basicConfig(
|
| 67 |
+
level=getattr(logging, _log_level, logging.INFO),
|
| 68 |
+
format="%(asctime)s [%(levelname)s] %(name)s: %(message)s",
|
| 69 |
+
)
|
| 70 |
+
logger = logging.getLogger(__name__)
|
| 71 |
+
|
| 72 |
+
|
| 73 |
+
# ---------------------------------------------------------------------------
|
| 74 |
+
# Database settings
|
| 75 |
+
# ---------------------------------------------------------------------------
|
| 76 |
+
|
| 77 |
+
def init_db():
|
| 78 |
+
Path("data").mkdir(exist_ok=True)
|
| 79 |
+
conn = sqlite3.connect("data/logs.db")
|
| 80 |
+
c = conn.cursor()
|
| 81 |
+
c.execute("""
|
| 82 |
+
CREATE TABLE IF NOT EXISTS audit_logs (
|
| 83 |
+
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
| 84 |
+
timestamp TEXT,
|
| 85 |
+
endpoint TEXT,
|
| 86 |
+
question TEXT,
|
| 87 |
+
answer TEXT,
|
| 88 |
+
hrs INTEGER,
|
| 89 |
+
risk_band TEXT,
|
| 90 |
+
composite_score REAL,
|
| 91 |
+
latency_ms INTEGER,
|
| 92 |
+
intervention_applied BOOLEAN,
|
| 93 |
+
details TEXT
|
| 94 |
+
)
|
| 95 |
+
""")
|
| 96 |
+
conn.commit()
|
| 97 |
+
conn.close()
|
| 98 |
+
|
| 99 |
+
def log_audit(endpoint: str, question: str, answer: str, hrs: int, risk_band: str, composite: float, latency: int, intervention: bool, details: dict):
|
| 100 |
+
try:
|
| 101 |
+
conn = sqlite3.connect("data/logs.db")
|
| 102 |
+
c = conn.cursor()
|
| 103 |
+
c.execute("""
|
| 104 |
+
INSERT INTO audit_logs (timestamp, endpoint, question, answer, hrs, risk_band, composite_score, latency_ms, intervention_applied, details)
|
| 105 |
+
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
|
| 106 |
+
""", (
|
| 107 |
+
datetime.utcnow().strftime("%Y-%m-%dT%H:%M:%SZ"),
|
| 108 |
+
endpoint,
|
| 109 |
+
question,
|
| 110 |
+
answer,
|
| 111 |
+
hrs,
|
| 112 |
+
risk_band,
|
| 113 |
+
composite,
|
| 114 |
+
latency,
|
| 115 |
+
intervention,
|
| 116 |
+
json.dumps(details)
|
| 117 |
+
))
|
| 118 |
+
conn.commit()
|
| 119 |
+
conn.close()
|
| 120 |
+
except Exception as e:
|
| 121 |
+
logger.error(f"Failed to save audit log to DB: {e}")
|
| 122 |
+
|
| 123 |
+
|
| 124 |
+
# ---------------------------------------------------------------------------
|
| 125 |
+
# Lifespan: warm DeBERTa once at startup so the first request isn't slow
|
| 126 |
+
# ---------------------------------------------------------------------------
|
| 127 |
+
@asynccontextmanager
|
| 128 |
+
async def lifespan(app: FastAPI):
|
| 129 |
+
"""Pre-warm DeBERTa and Retriever at startup."""
|
| 130 |
+
init_db()
|
| 131 |
+
logger.info("MediRAG API starting — pre-warming models...")
|
| 132 |
+
try:
|
| 133 |
+
from src.modules.faithfulness import _get_model
|
| 134 |
+
_get_model()
|
| 135 |
+
logger.info("DeBERTa pre-warm complete.")
|
| 136 |
+
except Exception as exc:
|
| 137 |
+
logger.warning("DeBERTa pre-warm skipped: %s", exc)
|
| 138 |
+
|
| 139 |
+
# Pre-load the retriever (BioBERT + FAISS index) into app state
|
| 140 |
+
try:
|
| 141 |
+
app.state.retriever = Retriever(_cfg)
|
| 142 |
+
# Trigger lazy load now so first /query request isn't slow
|
| 143 |
+
app.state.retriever._load_model()
|
| 144 |
+
app.state.retriever._load_index()
|
| 145 |
+
logger.info("Retriever pre-warm complete.")
|
| 146 |
+
except Exception as exc:
|
| 147 |
+
logger.warning("Retriever pre-warm skipped: %s", exc)
|
| 148 |
+
app.state.retriever = None
|
| 149 |
+
|
| 150 |
+
yield
|
| 151 |
+
logger.info("MediRAG API shutting down.")
|
| 152 |
+
|
| 153 |
+
|
| 154 |
+
|
| 155 |
+
# ---------------------------------------------------------------------------
|
| 156 |
+
# App
|
| 157 |
+
# ---------------------------------------------------------------------------
|
| 158 |
+
app = FastAPI(
|
| 159 |
+
title="MediRAG Evaluation API",
|
| 160 |
+
description=(
|
| 161 |
+
"Evaluate LLM-generated medical answers against retrieved evidence. "
|
| 162 |
+
"Returns faithfulness, entity accuracy, source credibility, "
|
| 163 |
+
"contradiction risk, and a composite Health Risk Score (HRS)."
|
| 164 |
+
),
|
| 165 |
+
version="0.1.0",
|
| 166 |
+
lifespan=lifespan,
|
| 167 |
+
)
|
| 168 |
+
|
| 169 |
+
# Allow all origins for local dev / React frontend on same machine
|
| 170 |
+
app.add_middleware(
|
| 171 |
+
CORSMiddleware,
|
| 172 |
+
allow_origins=["*"],
|
| 173 |
+
allow_methods=["GET", "POST"],
|
| 174 |
+
allow_headers=["*"],
|
| 175 |
+
)
|
| 176 |
+
|
| 177 |
+
|
| 178 |
+
# ---------------------------------------------------------------------------
|
| 179 |
+
# Helper: check Ollama
|
| 180 |
+
# ---------------------------------------------------------------------------
|
| 181 |
+
def _check_ollama() -> bool:
|
| 182 |
+
"""Return True if Ollama API is reachable."""
|
| 183 |
+
try:
|
| 184 |
+
resp = requests.get(f"{_ollama_base}/api/tags", timeout=2)
|
| 185 |
+
return resp.status_code == 200
|
| 186 |
+
except Exception:
|
| 187 |
+
return False
|
| 188 |
+
|
| 189 |
+
|
| 190 |
+
# ---------------------------------------------------------------------------
|
| 191 |
+
# Helper: convert EvalResult details → ModuleScore
|
| 192 |
+
# ---------------------------------------------------------------------------
|
| 193 |
+
def _module_score(module_results: dict, key: str) -> Optional[ModuleScore]:
|
| 194 |
+
data = module_results.get(key)
|
| 195 |
+
if data is None:
|
| 196 |
+
return None
|
| 197 |
+
return ModuleScore(
|
| 198 |
+
score=data.get("score", 0.0),
|
| 199 |
+
details=data.get("details", {}),
|
| 200 |
+
error=data.get("error"),
|
| 201 |
+
latency_ms=data.get("latency_ms"),
|
| 202 |
+
)
|
| 203 |
+
|
| 204 |
+
|
| 205 |
+
# ---------------------------------------------------------------------------
|
| 206 |
+
# GET / → redirect to /docs
|
| 207 |
+
# ---------------------------------------------------------------------------
|
| 208 |
+
@app.post("/project-guide")
|
| 209 |
+
def project_guide(req: ChatRequest):
|
| 210 |
+
"""
|
| 211 |
+
Proxy endpoint for the Project Guide chatbot.
|
| 212 |
+
Routes requests to Groq API using the local GROQ_API_KEY.
|
| 213 |
+
"""
|
| 214 |
+
groq_url = "https://api.groq.com/openai/v1/chat/completions"
|
| 215 |
+
api_key = os.getenv("GROQ_API_KEY")
|
| 216 |
+
|
| 217 |
+
if not api_key:
|
| 218 |
+
raise HTTPException(status_code=500, detail="GROQ_API_KEY not found in server environment")
|
| 219 |
+
|
| 220 |
+
headers = {
|
| 221 |
+
"Authorization": f"Bearer {api_key}",
|
| 222 |
+
"Content-Type": "application/json"
|
| 223 |
+
}
|
| 224 |
+
|
| 225 |
+
# Format messages for Groq
|
| 226 |
+
messages = []
|
| 227 |
+
if req.system_prompt:
|
| 228 |
+
messages.append({"role": "system", "content": req.system_prompt})
|
| 229 |
+
|
| 230 |
+
for m in req.messages:
|
| 231 |
+
messages.append({"role": m.role, "content": m.content})
|
| 232 |
+
|
| 233 |
+
payload = {
|
| 234 |
+
"model": "mixtral-8x7b-32768",
|
| 235 |
+
"messages": messages,
|
| 236 |
+
"temperature": 0.5,
|
| 237 |
+
"max_tokens": 1024
|
| 238 |
+
}
|
| 239 |
+
|
| 240 |
+
try:
|
| 241 |
+
resp = requests.post(groq_url, headers=headers, json=payload, timeout=30)
|
| 242 |
+
resp.raise_for_status()
|
| 243 |
+
return resp.json()
|
| 244 |
+
except Exception as e:
|
| 245 |
+
logger.error(f"Groq Proxy Error: {e}")
|
| 246 |
+
raise HTTPException(status_code=500, detail=str(e))
|
| 247 |
+
|
| 248 |
+
|
| 249 |
+
@app.get("/", include_in_schema=False)
|
| 250 |
+
def root():
|
| 251 |
+
return RedirectResponse(url="/docs")
|
| 252 |
+
|
| 253 |
+
|
| 254 |
+
# ---------------------------------------------------------------------------
|
| 255 |
+
# GET /health
|
| 256 |
+
# ---------------------------------------------------------------------------
|
| 257 |
+
@app.get("/health", response_model=HealthResponse, tags=["system"])
|
| 258 |
+
def health() -> HealthResponse:
|
| 259 |
+
"""
|
| 260 |
+
Liveness check.
|
| 261 |
+
|
| 262 |
+
Returns {"status": "ok", "ollama_available": true/false}.
|
| 263 |
+
Always returns 200 — the caller decides what to do with `ollama_available`.
|
| 264 |
+
"""
|
| 265 |
+
return HealthResponse(
|
| 266 |
+
status="ok",
|
| 267 |
+
ollama_available=_check_ollama(),
|
| 268 |
+
)
|
| 269 |
+
|
| 270 |
+
|
| 271 |
+
# ---------------------------------------------------------------------------
|
| 272 |
+
# POST /evaluate
|
| 273 |
+
# ---------------------------------------------------------------------------
|
| 274 |
+
@app.post("/evaluate", response_model=EvaluateResponse, tags=["evaluation"])
|
| 275 |
+
def evaluate(req: EvaluateRequest) -> EvaluateResponse:
|
| 276 |
+
"""
|
| 277 |
+
Run the full MediRAG evaluation pipeline on a question + answer + context.
|
| 278 |
+
|
| 279 |
+
- Validates inputs (FR-18: length limits, chunk count)
|
| 280 |
+
- Runs Faithfulness, Entity Verification, Source Credibility, Contradiction
|
| 281 |
+
- Optionally runs RAGAS (set `run_ragas=true` if Ollama/OpenAI is available)
|
| 282 |
+
- Returns composite Health Risk Score (HRS) + per-module breakdown
|
| 283 |
+
|
| 284 |
+
**Note on `run_ragas`**: RAGAS requires a running LLM backend (Ollama or
|
| 285 |
+
OpenAI). If unavailable, RAGAS will gracefully return score=0.5 as a
|
| 286 |
+
neutral fallback — it will NOT crash the request.
|
| 287 |
+
"""
|
| 288 |
+
logger.info(
|
| 289 |
+
"POST /evaluate — question=%r, chunks=%d, run_ragas=%s",
|
| 290 |
+
req.question[:80],
|
| 291 |
+
len(req.context_chunks),
|
| 292 |
+
req.run_ragas,
|
| 293 |
+
)
|
| 294 |
+
|
| 295 |
+
# Convert Pydantic ContextChunk → plain dicts for the pipeline
|
| 296 |
+
context_dicts: list[dict] = [chunk.model_dump(exclude_none=True) for chunk in req.context_chunks]
|
| 297 |
+
|
| 298 |
+
t0 = time.perf_counter()
|
| 299 |
+
try:
|
| 300 |
+
result = run_evaluation(
|
| 301 |
+
question=req.question,
|
| 302 |
+
answer=req.answer,
|
| 303 |
+
context_chunks=context_dicts,
|
| 304 |
+
rxnorm_cache_path=req.rxnorm_cache_path,
|
| 305 |
+
run_ragas=req.run_ragas,
|
| 306 |
+
config=_cfg,
|
| 307 |
+
)
|
| 308 |
+
except Exception as exc:
|
| 309 |
+
logger.exception("run_evaluation raised an unhandled exception: %s", exc)
|
| 310 |
+
raise HTTPException(
|
| 311 |
+
status_code=500,
|
| 312 |
+
detail=f"Evaluation pipeline error: {type(exc).__name__}: {exc}",
|
| 313 |
+
) from exc
|
| 314 |
+
|
| 315 |
+
total_ms = int((time.perf_counter() - t0) * 1000)
|
| 316 |
+
|
| 317 |
+
# Extract composite score + details
|
| 318 |
+
composite = float(result.score)
|
| 319 |
+
details = result.details or {}
|
| 320 |
+
hrs = details.get("hrs", int(round(100 * (1.0 - composite))))
|
| 321 |
+
hrs = max(0, min(100, hrs))
|
| 322 |
+
|
| 323 |
+
confidence_level = details.get("confidence_level", "UNKNOWN")
|
| 324 |
+
risk_band = details.get("risk_band", "UNKNOWN")
|
| 325 |
+
pipeline_ms = details.get("total_pipeline_ms", total_ms)
|
| 326 |
+
|
| 327 |
+
# Build per-module scores
|
| 328 |
+
mod_results: dict = details.get("module_results", {})
|
| 329 |
+
module_scores = ModuleResults(
|
| 330 |
+
faithfulness=_module_score(mod_results, "faithfulness"),
|
| 331 |
+
entity_verifier=_module_score(mod_results, "entity_verifier"),
|
| 332 |
+
source_credibility=_module_score(mod_results, "source_credibility"),
|
| 333 |
+
contradiction=_module_score(mod_results, "contradiction"),
|
| 334 |
+
ragas=_module_score(mod_results, "ragas"),
|
| 335 |
+
)
|
| 336 |
+
|
| 337 |
+
logger.info(
|
| 338 |
+
"POST /evaluate → HRS=%d (%s) in %d ms",
|
| 339 |
+
hrs, risk_band, pipeline_ms,
|
| 340 |
+
)
|
| 341 |
+
|
| 342 |
+
log_audit("evaluate", req.question, req.answer, hrs, risk_band, composite, pipeline_ms, False, {
|
| 343 |
+
"module_results": mod_results,
|
| 344 |
+
"confidence_level": confidence_level
|
| 345 |
+
})
|
| 346 |
+
|
| 347 |
+
return EvaluateResponse(
|
| 348 |
+
composite_score=composite,
|
| 349 |
+
hrs=hrs,
|
| 350 |
+
confidence_level=confidence_level,
|
| 351 |
+
risk_band=risk_band,
|
| 352 |
+
module_results=module_scores,
|
| 353 |
+
total_pipeline_ms=pipeline_ms,
|
| 354 |
+
)
|
| 355 |
+
|
| 356 |
+
|
| 357 |
+
# ---------------------------------------------------------------------------
|
| 358 |
+
# POST /query — end-to-end: question → retrieve → generate → evaluate
|
| 359 |
+
# ---------------------------------------------------------------------------
|
| 360 |
+
@app.post("/query", response_model=QueryResponse, tags=["query"])
|
| 361 |
+
def query(req: QueryRequest) -> QueryResponse:
|
| 362 |
+
"""
|
| 363 |
+
Full end-to-end MediRAG pipeline.
|
| 364 |
+
|
| 365 |
+
1. Retrieves top-k context chunks from FAISS (BioBERT)
|
| 366 |
+
2. Generates a grounded answer using Mistral (Ollama)
|
| 367 |
+
3. Evaluates the answer with all 4 modules + aggregator
|
| 368 |
+
4. Returns the answer, retrieved chunks, HRS score, and full breakdown
|
| 369 |
+
|
| 370 |
+
**Requires Ollama running locally with Mistral pulled.**
|
| 371 |
+
No fallback — returns 503 if Ollama is unavailable.
|
| 372 |
+
"""
|
| 373 |
+
import time as _time
|
| 374 |
+
t_total = _time.perf_counter()
|
| 375 |
+
|
| 376 |
+
logger.info("POST /query — question=%r, top_k=%d", req.question[:80], req.top_k)
|
| 377 |
+
|
| 378 |
+
# Step 1: Retrieve
|
| 379 |
+
retriever: Optional[Retriever] = getattr(app.state, "retriever", None)
|
| 380 |
+
if retriever is None:
|
| 381 |
+
# Fallback: instantiate now (slower first call)
|
| 382 |
+
try:
|
| 383 |
+
retriever = Retriever(_cfg)
|
| 384 |
+
except Exception as exc:
|
| 385 |
+
raise HTTPException(status_code=503,
|
| 386 |
+
detail=f"Retriever unavailable: {exc}") from exc
|
| 387 |
+
|
| 388 |
+
try:
|
| 389 |
+
raw_results = retriever.search(req.question, top_k=req.top_k)
|
| 390 |
+
except FileNotFoundError as exc:
|
| 391 |
+
raise HTTPException(status_code=503,
|
| 392 |
+
detail=f"FAISS index not found: {exc}") from exc
|
| 393 |
+
except Exception as exc:
|
| 394 |
+
raise HTTPException(status_code=500,
|
| 395 |
+
detail=f"Retrieval error: {exc}") from exc
|
| 396 |
+
|
| 397 |
+
if not raw_results:
|
| 398 |
+
raise HTTPException(status_code=404,
|
| 399 |
+
detail="No relevant documents found for this question.")
|
| 400 |
+
|
| 401 |
+
# Convert retriever output → chunk dicts for generator + evaluate
|
| 402 |
+
context_chunks: list[dict] = []
|
| 403 |
+
retrieved_chunks_out: list[RetrievedChunk] = []
|
| 404 |
+
for chunk_text, meta, score in raw_results:
|
| 405 |
+
d = {
|
| 406 |
+
"text": chunk_text,
|
| 407 |
+
"chunk_id": meta.get("chunk_id"),
|
| 408 |
+
"source": meta.get("source", ""),
|
| 409 |
+
"pub_type": meta.get("pub_type", ""),
|
| 410 |
+
"pub_year": meta.get("pub_year"),
|
| 411 |
+
"title": meta.get("title", ""),
|
| 412 |
+
}
|
| 413 |
+
context_chunks.append(d)
|
| 414 |
+
retrieved_chunks_out.append(RetrievedChunk(
|
| 415 |
+
chunk_id=meta.get("chunk_id"),
|
| 416 |
+
text=chunk_text[:500], # truncate for response readability
|
| 417 |
+
source=meta.get("source", ""),
|
| 418 |
+
pub_type=meta.get("pub_type", ""),
|
| 419 |
+
pub_year=meta.get("pub_year"),
|
| 420 |
+
title=meta.get("title", ""),
|
| 421 |
+
similarity_score=round(score, 4),
|
| 422 |
+
))
|
| 423 |
+
|
| 424 |
+
logger.info("Retrieved %d chunks (top score=%.4f)", len(context_chunks),
|
| 425 |
+
raw_results[0][2] if raw_results else 0.0)
|
| 426 |
+
|
| 427 |
+
# Raw FAISS cosine similarity for coverage gap gate.
|
| 428 |
+
# IndexFlatIP + L2-norm = cosine in [-1, 1]. < 0.60 means no close semantic match in DB.
|
| 429 |
+
top_faiss_cosine = (
|
| 430 |
+
raw_results[0][1].get("_top_faiss_cosine", 0.0) if raw_results else 0.0
|
| 431 |
+
)
|
| 432 |
+
|
| 433 |
+
# Convert request overrides into a dict for generator
|
| 434 |
+
llm_overrides = {}
|
| 435 |
+
if req.llm_provider:
|
| 436 |
+
llm_overrides["provider"] = req.llm_provider
|
| 437 |
+
if req.llm_api_key:
|
| 438 |
+
llm_overrides["api_key"] = req.llm_api_key
|
| 439 |
+
if req.llm_model:
|
| 440 |
+
llm_overrides["model"] = req.llm_model
|
| 441 |
+
if req.ollama_url:
|
| 442 |
+
llm_overrides["ollama_url"] = req.ollama_url
|
| 443 |
+
if req.system_prompt:
|
| 444 |
+
llm_overrides["system_prompt"] = req.system_prompt
|
| 445 |
+
if req.persona:
|
| 446 |
+
llm_overrides["persona"] = req.persona
|
| 447 |
+
|
| 448 |
+
# =========================================================================
|
| 449 |
+
# Step 2a: PRIVACY SHIELD — MediRAG redacts PHI (Option 1)
|
| 450 |
+
# =========================================================================
|
| 451 |
+
p_mapping = {}
|
| 452 |
+
privacy_applied = False
|
| 453 |
+
question_to_gen = req.question
|
| 454 |
+
|
| 455 |
+
if req.use_privacy_shield:
|
| 456 |
+
from src.pipeline.privacy import shield
|
| 457 |
+
question_to_gen, p_mapping = shield.redact(req.question)
|
| 458 |
+
if p_mapping:
|
| 459 |
+
privacy_applied = True
|
| 460 |
+
logger.info("PRIVACY INTERVENTION: Redacted %d items from question.", len(p_mapping))
|
| 461 |
+
|
| 462 |
+
# Step 2: Generate answer via LLM (Gemini or Ollama)
|
| 463 |
+
try:
|
| 464 |
+
# Use the potentially redacted question for generation
|
| 465 |
+
answer = generate_answer(question_to_gen, context_chunks, _cfg, overrides=llm_overrides)
|
| 466 |
+
except RuntimeError as exc:
|
| 467 |
+
raise HTTPException(status_code=503,
|
| 468 |
+
detail=f"LLM generation failed: {exc}") from exc
|
| 469 |
+
|
| 470 |
+
# Restore the PHI for the final display so the user sees the actual names
|
| 471 |
+
if privacy_applied:
|
| 472 |
+
from src.pipeline.privacy import shield
|
| 473 |
+
answer = shield.restore(answer, p_mapping)
|
| 474 |
+
# =========================================================================
|
| 475 |
+
|
| 476 |
+
# =========================================================================
|
| 477 |
+
# Step 2b: CONSENSUS CHECK — MediRAG compares multiple models (Option 2)
|
| 478 |
+
# =========================================================================
|
| 479 |
+
consensus_results = None
|
| 480 |
+
if req.use_consensus:
|
| 481 |
+
from src.pipeline.consensus import run_consensus_check
|
| 482 |
+
# Determine which providers to use based on available config/overrides
|
| 483 |
+
providers = ["gemini"]
|
| 484 |
+
if os.environ.get("GROQ_API_KEY"):
|
| 485 |
+
providers.append("groq")
|
| 486 |
+
elif os.environ.get("MISTRAL_API_KEY"):
|
| 487 |
+
providers.append("mistral")
|
| 488 |
+
else:
|
| 489 |
+
providers.append("ollama") # fallback to local if no second key
|
| 490 |
+
|
| 491 |
+
logger.info("Running Consensus Layer with %s", providers)
|
| 492 |
+
consensus_results = run_consensus_check(req.question, context_chunks, _cfg, providers=providers)
|
| 493 |
+
|
| 494 |
+
# If consensus finds a safer merged answer, we promote it
|
| 495 |
+
# and update the primary answer for the evaluation loop
|
| 496 |
+
answer = consensus_results.get("consensus_answer", answer)
|
| 497 |
+
# =========================================================================
|
| 498 |
+
|
| 499 |
+
# [DEMO MODE] Inject a false claim to demonstrate the intervention system
|
| 500 |
+
if req.inject_hallucination:
|
| 501 |
+
logger.warning("DEMO MODE: Injecting hallucinated claim into answer: '%s'",
|
| 502 |
+
req.inject_hallucination)
|
| 503 |
+
answer = answer + " " + req.inject_hallucination.strip()
|
| 504 |
+
|
| 505 |
+
# Step 3: Evaluate
|
| 506 |
+
try:
|
| 507 |
+
eval_result = run_evaluation(
|
| 508 |
+
question=req.question,
|
| 509 |
+
answer=answer,
|
| 510 |
+
context_chunks=context_chunks,
|
| 511 |
+
run_ragas=req.run_ragas,
|
| 512 |
+
config=_cfg,
|
| 513 |
+
)
|
| 514 |
+
except Exception as exc:
|
| 515 |
+
logger.exception("Evaluation failed: %s", exc)
|
| 516 |
+
try:
|
| 517 |
+
log_audit("query", req.question, answer, 100, "EVAL_ERROR", 0.0,
|
| 518 |
+
int((_time.perf_counter() - t_total) * 1000),
|
| 519 |
+
False, {"error": str(exc), "error_type": "evaluation_failure"})
|
| 520 |
+
except Exception:
|
| 521 |
+
pass
|
| 522 |
+
raise HTTPException(status_code=500,
|
| 523 |
+
detail=f"Evaluation error: {exc}") from exc
|
| 524 |
+
|
| 525 |
+
# =========================================================================
|
| 526 |
+
# Step 3b: INTERVENTION LOOP — MediRAG acts on evaluation results
|
| 527 |
+
# =========================================================================
|
| 528 |
+
from src.pipeline.generator import generate_strict_answer
|
| 529 |
+
|
| 530 |
+
details = eval_result.details or {}
|
| 531 |
+
composite = float(eval_result.score)
|
| 532 |
+
hrs = int(round(100 * (1.0 - composite)))
|
| 533 |
+
hrs = max(0, min(100, hrs))
|
| 534 |
+
mod_results: dict = details.get("module_results", {})
|
| 535 |
+
|
| 536 |
+
intervention_applied = False
|
| 537 |
+
intervention_reason = None
|
| 538 |
+
original_answer = None
|
| 539 |
+
intervention_details = None
|
| 540 |
+
|
| 541 |
+
faith_score = (mod_results.get("faithfulness") or {}).get("score", 1.0)
|
| 542 |
+
|
| 543 |
+
# Source-credibility-aware faith threshold: high-credibility sources get more tolerance
|
| 544 |
+
source_cred = float(details.get("component_scores", {}).get("source_credibility", 0.5))
|
| 545 |
+
faith_threshold = max(0.3, 0.7 - (source_cred * 0.4)) # 0.30 for cred=1.0, 0.66 for cred=0.3
|
| 546 |
+
|
| 547 |
+
# ── Coverage Gap Gate ────────────────────────────────────────────────────
|
| 548 |
+
# Two signals combined:
|
| 549 |
+
# 1. Refusal answer — LLM says "not in context / insufficient evidence"
|
| 550 |
+
# → LLM itself confirms the DB doesn't cover this topic.
|
| 551 |
+
# 2. FAISS cosine — genuinely poor semantic match vs. the query.
|
| 552 |
+
# BioBERT clusters medical dosing texts, so threshold must be high (0.75).
|
| 553 |
+
_REFUSAL_PATTERNS = (
|
| 554 |
+
"not mentioned in the provided context",
|
| 555 |
+
"not provided in the retrieved context",
|
| 556 |
+
"insufficient evidence in retrieved context",
|
| 557 |
+
"no information about",
|
| 558 |
+
"not in the provided context",
|
| 559 |
+
"cannot find information",
|
| 560 |
+
"the retrieved context does not contain",
|
| 561 |
+
"the context does not contain",
|
| 562 |
+
"not mentioned in the context",
|
| 563 |
+
"is not provided in the context",
|
| 564 |
+
)
|
| 565 |
+
_answer_lower = answer.lower()
|
| 566 |
+
is_refusal_answer = any(p in _answer_lower for p in _REFUSAL_PATTERNS)
|
| 567 |
+
is_low_faiss = top_faiss_cosine < 0.75
|
| 568 |
+
|
| 569 |
+
# If a verified drug with rxcui appears in the question, the intervention's
|
| 570 |
+
# FDA direct lookup can still retrieve the right data even when initial FAISS
|
| 571 |
+
# retrieval missed it. Don't label those as coverage gaps — let intervention run.
|
| 572 |
+
_ev_entities = (mod_results.get("entity_verifier") or {}).get("details", {}).get("entities", [])
|
| 573 |
+
_q_lower_cg = req.question.lower()
|
| 574 |
+
_drug_in_question = any(
|
| 575 |
+
e.get("rxcui") and e.get("entity", "").lower() in _q_lower_cg
|
| 576 |
+
for e in _ev_entities
|
| 577 |
+
)
|
| 578 |
+
|
| 579 |
+
# Refusal is a standalone COVERAGE_GAP signal — faith_score is unreliable here
|
| 580 |
+
# because NLI scores refusal sentences as NEUTRAL (0.5), not low.
|
| 581 |
+
# Exception: if a drug is named in the question, FDA lookup can still help.
|
| 582 |
+
# HALLUCINATION: specific claims made but not grounded in available context.
|
| 583 |
+
if is_refusal_answer and not _drug_in_question:
|
| 584 |
+
gap_type = "COVERAGE_GAP"
|
| 585 |
+
elif faith_score < faith_threshold and is_low_faiss and not _drug_in_question:
|
| 586 |
+
gap_type = "COVERAGE_GAP" # poor retrieval + low faith = DB lacks this topic
|
| 587 |
+
elif faith_score < faith_threshold:
|
| 588 |
+
gap_type = "HALLUCINATION" # relevant context exists but answer ignores it
|
| 589 |
+
else:
|
| 590 |
+
gap_type = None
|
| 591 |
+
|
| 592 |
+
coverage_gap = gap_type == "COVERAGE_GAP"
|
| 593 |
+
coverage_gap_details: dict | None = {
|
| 594 |
+
"gap_type": gap_type,
|
| 595 |
+
"top_faiss_cosine": round(top_faiss_cosine, 4),
|
| 596 |
+
"is_refusal_answer": is_refusal_answer,
|
| 597 |
+
"note": (
|
| 598 |
+
"Database coverage may be insufficient for this topic. "
|
| 599 |
+
"The answer could not be verified against retrieved evidence. "
|
| 600 |
+
"Consult primary medical literature or a specialist."
|
| 601 |
+
) if coverage_gap else None,
|
| 602 |
+
} if gap_type else None
|
| 603 |
+
if coverage_gap:
|
| 604 |
+
logger.warning(
|
| 605 |
+
"COVERAGE_GAP detected — refusal=%s, faiss=%.4f, faith=%.2f",
|
| 606 |
+
is_refusal_answer, top_faiss_cosine, faith_score,
|
| 607 |
+
)
|
| 608 |
+
|
| 609 |
+
# Tier 1: CRITICAL BLOCK (HRS ≥ 86) — response is too dangerous to show
|
| 610 |
+
# Coverage gap: skip both tiers — regenerating from an empty DB won't help
|
| 611 |
+
if coverage_gap:
|
| 612 |
+
logger.info("COVERAGE_GAP — skipping intervention (regeneration cannot add missing data).")
|
| 613 |
+
elif hrs >= 86:
|
| 614 |
+
original_answer = answer
|
| 615 |
+
answer = (
|
| 616 |
+
"⛔ UNSAFE RESPONSE BLOCKED by MediRAG Safety Gate.\n\n"
|
| 617 |
+
"The generated answer was flagged as CRITICAL risk "
|
| 618 |
+
f"(Health Risk Score: {hrs}/100). "
|
| 619 |
+
"It showed signs of hallucination or contradiction with the retrieved evidence. "
|
| 620 |
+
"Please consult a qualified medical professional or rephrase your question."
|
| 621 |
+
)
|
| 622 |
+
intervention_applied = True
|
| 623 |
+
intervention_reason = "CRITICAL_BLOCKED"
|
| 624 |
+
intervention_details = {
|
| 625 |
+
"hrs_original": hrs,
|
| 626 |
+
"faithfulness": faith_score,
|
| 627 |
+
"message": "Response blocked: HRS ≥ 86 (CRITICAL risk band).",
|
| 628 |
+
}
|
| 629 |
+
logger.warning("INTERVENTION: CRITICAL_BLOCKED — HRS=%d", hrs)
|
| 630 |
+
|
| 631 |
+
# Tier 2: HIGH RISK REGENERATION
|
| 632 |
+
elif hrs >= 61 or faith_score < faith_threshold:
|
| 633 |
+
original_answer = answer
|
| 634 |
+
original_hrs = hrs
|
| 635 |
+
logger.warning(
|
| 636 |
+
"INTERVENTION: HIGH_RISK_REGENERATED — HRS=%d, faith=%.2f. Regenerating with strict prompt.",
|
| 637 |
+
hrs, faith_score
|
| 638 |
+
)
|
| 639 |
+
try:
|
| 640 |
+
# Re-retrieve from shared index — find better chunks than the ones that failed
|
| 641 |
+
try:
|
| 642 |
+
# Direct FDA lookup — only for drugs named in the question itself.
|
| 643 |
+
# Drugs found in the answer but NOT in the question (e.g. metformin
|
| 644 |
+
# mentioned incidentally in a general "first-line treatment" answer)
|
| 645 |
+
# should not trigger FDA lookup; that would replace relevant context
|
| 646 |
+
# with the wrong label sections (contraindications instead of treatment).
|
| 647 |
+
fda_direct: list[dict] = []
|
| 648 |
+
try:
|
| 649 |
+
ev_details = eval_result.details.get("module_results", {}).get("entity_verifier", {}).get("details", {})
|
| 650 |
+
verified_drugs = [
|
| 651 |
+
e["entity"] for e in ev_details.get("entities", [])
|
| 652 |
+
if e.get("status") == "VERIFIED" and e.get("rxcui")
|
| 653 |
+
]
|
| 654 |
+
q_lower = req.question.lower()
|
| 655 |
+
for drug in verified_drugs:
|
| 656 |
+
if drug.lower() in q_lower:
|
| 657 |
+
fda_direct += app.state.retriever.get_fda_chunks(drug)
|
| 658 |
+
if fda_direct:
|
| 659 |
+
logger.info("Direct FDA lookup found %d chunks for drugs: %s",
|
| 660 |
+
len(fda_direct), [d for d in verified_drugs if d.lower() in q_lower])
|
| 661 |
+
except Exception as fda_exc:
|
| 662 |
+
logger.debug("Direct FDA lookup skipped: %s", fda_exc)
|
| 663 |
+
|
| 664 |
+
# Direct guideline lookup — only when original retrieval was poor.
|
| 665 |
+
# If FAISS cosine ≥ 0.85 the original chunks were already relevant;
|
| 666 |
+
# adding guideline sections here can pull in wrong topic areas
|
| 667 |
+
# (e.g., ADA Section 2 Diagnosis instead of Section 9 Treatment).
|
| 668 |
+
guideline_direct: list[dict] = []
|
| 669 |
+
if top_faiss_cosine < 0.85:
|
| 670 |
+
try:
|
| 671 |
+
guideline_direct = app.state.retriever.get_guideline_chunks(req.question)
|
| 672 |
+
if guideline_direct:
|
| 673 |
+
logger.info("Direct guideline lookup found %d chunks", len(guideline_direct))
|
| 674 |
+
except Exception as gl_exc:
|
| 675 |
+
logger.debug("Direct guideline lookup skipped: %s", gl_exc)
|
| 676 |
+
else:
|
| 677 |
+
logger.debug("Skipping guideline direct lookup (FAISS cosine=%.4f ≥ 0.85, original retrieval was high-quality)", top_faiss_cosine)
|
| 678 |
+
|
| 679 |
+
# Merge: guideline chunks + FDA chunks + fresh retrieval
|
| 680 |
+
fda_direct = guideline_direct + fda_direct
|
| 681 |
+
|
| 682 |
+
# For drug/clinical questions, expand query toward authoritative sources
|
| 683 |
+
_drug_terms = ("contraindication", "dosage", "dose", "interaction",
|
| 684 |
+
"warning", "adverse", "side effect", "mechanism")
|
| 685 |
+
_q_lower = req.question.lower()
|
| 686 |
+
retry_query = (
|
| 687 |
+
f"FDA drug label clinical guideline {req.question}"
|
| 688 |
+
if any(t in _q_lower for t in _drug_terms)
|
| 689 |
+
else req.question
|
| 690 |
+
)
|
| 691 |
+
fresh_results = app.state.retriever.search(retry_query, top_k=req.top_k)
|
| 692 |
+
fresh_chunks: list[dict] = []
|
| 693 |
+
for chunk_text, meta, score in fresh_results:
|
| 694 |
+
fresh_chunks.append({
|
| 695 |
+
"text": chunk_text, "chunk_id": meta.get("chunk_id"),
|
| 696 |
+
"source": meta.get("source", ""), "pub_type": meta.get("pub_type", ""),
|
| 697 |
+
"pub_year": meta.get("pub_year"), "title": meta.get("title", ""),
|
| 698 |
+
})
|
| 699 |
+
# Merge: direct lookups first (FDA/guidelines), then fresh retrieval
|
| 700 |
+
base_chunks = fresh_chunks if fresh_chunks else context_chunks
|
| 701 |
+
retry_chunks = (fda_direct + base_chunks)[:req.top_k] if fda_direct else base_chunks
|
| 702 |
+
logger.info("Re-retrieval for intervention: %d fresh chunks (top source: %s)",
|
| 703 |
+
len(retry_chunks),
|
| 704 |
+
retry_chunks[0].get("pub_type", "?") if retry_chunks else "none")
|
| 705 |
+
except Exception:
|
| 706 |
+
retry_chunks = context_chunks
|
| 707 |
+
|
| 708 |
+
answer = generate_strict_answer(req.question, retry_chunks, _cfg, overrides=llm_overrides)
|
| 709 |
+
# Re-evaluate the corrected answer
|
| 710 |
+
eval_result = run_evaluation(
|
| 711 |
+
question=req.question,
|
| 712 |
+
answer=answer,
|
| 713 |
+
context_chunks=retry_chunks,
|
| 714 |
+
run_ragas=False, # skip RAGAS on retry to reduce latency
|
| 715 |
+
config=_cfg,
|
| 716 |
+
)
|
| 717 |
+
details = eval_result.details or {}
|
| 718 |
+
composite = float(eval_result.score)
|
| 719 |
+
hrs = int(round(100 * (1.0 - composite)))
|
| 720 |
+
hrs = max(0, min(100, hrs))
|
| 721 |
+
mod_results = details.get("module_results", {})
|
| 722 |
+
except Exception as exc:
|
| 723 |
+
logger.error("Strict regeneration failed: %s — keeping original answer", exc)
|
| 724 |
+
answer = original_answer # fall back gracefully
|
| 725 |
+
original_answer = None
|
| 726 |
+
|
| 727 |
+
intervention_applied = True
|
| 728 |
+
intervention_reason = "HIGH_RISK_REGENERATED"
|
| 729 |
+
intervention_details = {
|
| 730 |
+
"hrs_original": original_hrs,
|
| 731 |
+
"hrs_corrected": hrs,
|
| 732 |
+
"faithfulness_original": faith_score,
|
| 733 |
+
"faithfulness_corrected": (mod_results.get("faithfulness") or {}).get("score", 0),
|
| 734 |
+
"message": "Response regenerated with strict context-only prompt due to high risk score.",
|
| 735 |
+
}
|
| 736 |
+
# =========================================================================
|
| 737 |
+
|
| 738 |
+
# Step 4: Build response
|
| 739 |
+
total_ms = int((_time.perf_counter() - t_total) * 1000)
|
| 740 |
+
logger.info("POST /query → HRS=%d (%s) intervention=%s in %d ms total",
|
| 741 |
+
hrs, details.get("risk_band", "?"), intervention_reason or "none", total_ms)
|
| 742 |
+
|
| 743 |
+
log_audit("query", req.question, answer, hrs, details.get("risk_band", "UNKNOWN"), composite, total_ms, intervention_applied, {
|
| 744 |
+
"module_results": mod_results,
|
| 745 |
+
"confidence_level": details.get("confidence_level", "UNKNOWN"),
|
| 746 |
+
"intervention_reason": intervention_reason,
|
| 747 |
+
"original_answer": original_answer,
|
| 748 |
+
})
|
| 749 |
+
|
| 750 |
+
return QueryResponse(
|
| 751 |
+
question=req.question,
|
| 752 |
+
generated_answer=answer,
|
| 753 |
+
retrieved_chunks=retrieved_chunks_out,
|
| 754 |
+
composite_score=composite,
|
| 755 |
+
hrs=hrs,
|
| 756 |
+
confidence_level=details.get("confidence_level", "UNKNOWN"),
|
| 757 |
+
risk_band=details.get("risk_band", "UNKNOWN"),
|
| 758 |
+
module_results=ModuleResults(
|
| 759 |
+
faithfulness=_module_score(mod_results, "faithfulness"),
|
| 760 |
+
entity_verifier=_module_score(mod_results, "entity_verifier"),
|
| 761 |
+
source_credibility=_module_score(mod_results, "source_credibility"),
|
| 762 |
+
contradiction=_module_score(mod_results, "contradiction"),
|
| 763 |
+
ragas=_module_score(mod_results, "ragas"),
|
| 764 |
+
),
|
| 765 |
+
total_pipeline_ms=total_ms,
|
| 766 |
+
intervention_applied=intervention_applied,
|
| 767 |
+
intervention_reason=intervention_reason,
|
| 768 |
+
original_answer=original_answer,
|
| 769 |
+
intervention_details=intervention_details,
|
| 770 |
+
consensus_results=consensus_results,
|
| 771 |
+
privacy_applied=privacy_applied,
|
| 772 |
+
privacy_details={"redacted_count": len(p_mapping)} if privacy_applied else None,
|
| 773 |
+
coverage_gap=coverage_gap,
|
| 774 |
+
coverage_gap_details=coverage_gap_details,
|
| 775 |
+
)
|
| 776 |
+
|
| 777 |
+
# ---------------------------------------------------------------------------
|
| 778 |
+
# POST /ingest — dynamically append new documents to the FAISS index
|
| 779 |
+
# ---------------------------------------------------------------------------
|
| 780 |
+
_faiss_lock = threading.Lock()
|
| 781 |
+
|
| 782 |
+
@app.post("/ingest", tags=["ingestion"])
|
| 783 |
+
def ingest_document(req: IngestRequest):
|
| 784 |
+
"""
|
| 785 |
+
Dynamically ingest a new document into the running FAISS index.
|
| 786 |
+
Thread-safe implementation uses a lock to prevent concurrent write corruption.
|
| 787 |
+
"""
|
| 788 |
+
import pickle
|
| 789 |
+
import faiss
|
| 790 |
+
from src.pipeline.chunker import chunk_documents
|
| 791 |
+
|
| 792 |
+
|
| 793 |
+
retriever = getattr(app.state, "retriever", None)
|
| 794 |
+
if retriever is None or retriever._index is None:
|
| 795 |
+
raise HTTPException(status_code=503, detail="Retriever not pre-warmed. Cannot ingest.")
|
| 796 |
+
|
| 797 |
+
logger.info("POST /ingest — title='%s', size=%d chars", req.title, len(req.text))
|
| 798 |
+
|
| 799 |
+
# 1. Chunk the document
|
| 800 |
+
doc = {
|
| 801 |
+
"text": req.text,
|
| 802 |
+
"doc_id": "custom_" + req.title[:10],
|
| 803 |
+
"title": req.title,
|
| 804 |
+
"source": req.source,
|
| 805 |
+
"pub_type": req.pub_type,
|
| 806 |
+
"pub_year": 2026,
|
| 807 |
+
}
|
| 808 |
+
chunks = chunk_documents([doc], _cfg)
|
| 809 |
+
|
| 810 |
+
if not chunks:
|
| 811 |
+
raise HTTPException(status_code=400, detail="Document produced 0 chunks.")
|
| 812 |
+
|
| 813 |
+
# 2. Embed the chunks using the same BioBERT model as the retriever
|
| 814 |
+
from src.pipeline.embedder import encode_texts
|
| 815 |
+
import numpy as np
|
| 816 |
+
|
| 817 |
+
# Reuse already-loaded SentenceTransformer from the retriever to avoid double RAM load
|
| 818 |
+
if retriever._model is None:
|
| 819 |
+
retriever._load_model()
|
| 820 |
+
st_model = retriever._model
|
| 821 |
+
|
| 822 |
+
texts = [c["chunk_text"] for c in chunks]
|
| 823 |
+
embeddings = np.array(st_model.encode(texts, show_progress_bar=False), dtype=np.float32)
|
| 824 |
+
faiss.normalize_L2(embeddings) # Required: index is IndexFlatIP = cosine sim
|
| 825 |
+
|
| 826 |
+
# 3. Thread-safe Index Update with atomic disk writes
|
| 827 |
+
with _faiss_lock:
|
| 828 |
+
import os
|
| 829 |
+
idx_path = Path(_cfg["retrieval"]["index_path"])
|
| 830 |
+
meta_path = Path(_cfg["retrieval"]["metadata_path"])
|
| 831 |
+
|
| 832 |
+
index = retriever._index
|
| 833 |
+
metadata_store = retriever._metadata
|
| 834 |
+
|
| 835 |
+
start_id = len(metadata_store)
|
| 836 |
+
|
| 837 |
+
# Add to in-memory structures
|
| 838 |
+
for i, chunk in enumerate(chunks):
|
| 839 |
+
metadata_store[start_id + i] = chunk
|
| 840 |
+
|
| 841 |
+
# Add to FAISS in memory
|
| 842 |
+
index.add(embeddings)
|
| 843 |
+
|
| 844 |
+
# Atomic FAISS write: write to temp → rename (never leaves a half-written file)
|
| 845 |
+
idx_tmp = str(idx_path) + ".tmp"
|
| 846 |
+
faiss.write_index(index, idx_tmp)
|
| 847 |
+
os.replace(idx_tmp, str(idx_path))
|
| 848 |
+
|
| 849 |
+
# Atomic metadata write
|
| 850 |
+
meta_tmp = str(meta_path) + ".tmp"
|
| 851 |
+
with open(meta_tmp, "wb") as f:
|
| 852 |
+
pickle.dump(metadata_store, f)
|
| 853 |
+
os.replace(meta_tmp, str(meta_path))
|
| 854 |
+
|
| 855 |
+
# 4. Rebuild BM25 for the running instance
|
| 856 |
+
retriever.rebuild_bm25()
|
| 857 |
+
|
| 858 |
+
logger.info("Successfully injected %d chunks for '%s' into FAISS and BM25.", len(chunks), req.title)
|
| 859 |
+
return {"status": "success", "chunks_added": len(chunks), "title": req.title}
|
| 860 |
+
|
| 861 |
+
# ---------------------------------------------------------------------------
|
| 862 |
+
# GET /logs and /stats — fetch history for dashboard
|
| 863 |
+
# ---------------------------------------------------------------------------
|
| 864 |
+
@app.get("/logs", tags=["dashboard"])
|
| 865 |
+
def get_logs(limit: int = 50):
|
| 866 |
+
try:
|
| 867 |
+
conn = sqlite3.connect("data/logs.db")
|
| 868 |
+
conn.row_factory = sqlite3.Row
|
| 869 |
+
c = conn.cursor()
|
| 870 |
+
c.execute("SELECT * FROM audit_logs ORDER BY id DESC LIMIT ?", (limit,))
|
| 871 |
+
rows = c.fetchall()
|
| 872 |
+
conn.close()
|
| 873 |
+
return [dict(ix) for ix in rows]
|
| 874 |
+
except Exception as e:
|
| 875 |
+
return []
|
| 876 |
+
|
| 877 |
+
@app.get("/stats", tags=["dashboard"])
|
| 878 |
+
def get_stats():
|
| 879 |
+
try:
|
| 880 |
+
conn = sqlite3.connect("data/logs.db")
|
| 881 |
+
c = conn.cursor()
|
| 882 |
+
c.execute("SELECT COUNT(*), AVG(hrs), SUM(CASE WHEN risk_band='CRITICAL' THEN 1 ELSE 0 END) FROM audit_logs")
|
| 883 |
+
total_evals, avg_hrs, crit_alerts = c.fetchone()
|
| 884 |
+
|
| 885 |
+
c.execute("SELECT SUM(CASE WHEN intervention_applied=1 THEN 1 ELSE 0 END) FROM audit_logs")
|
| 886 |
+
interventions = c.fetchone()[0]
|
| 887 |
+
|
| 888 |
+
# Monthly data example
|
| 889 |
+
monthly_query = "SELECT SUBSTR(timestamp, 1, 7) as month, AVG(hrs) FROM audit_logs GROUP BY month ORDER BY month LIMIT 12"
|
| 890 |
+
c.execute(monthly_query)
|
| 891 |
+
monthly_data = [{"month": row[0], "avg_hrs": row[1]} for row in c.fetchall()]
|
| 892 |
+
|
| 893 |
+
conn.close()
|
| 894 |
+
return {
|
| 895 |
+
"totalEvals": total_evals or 0,
|
| 896 |
+
"avgHrs": round(avg_hrs or 0, 1),
|
| 897 |
+
"critAlerts": crit_alerts or 0,
|
| 898 |
+
"interventions": interventions or 0,
|
| 899 |
+
"monthly": monthly_data
|
| 900 |
+
}
|
| 901 |
+
except Exception as e:
|
| 902 |
+
return {
|
| 903 |
+
"totalEvals": 0, "avgHrs": 0, "critAlerts": 0, "interventions": 0, "monthly": []
|
| 904 |
+
}
|
| 905 |
+
|
| 906 |
+
# ---------------------------------------------------------------------------
|
| 907 |
+
# POST /parse_file — helper for frontend to extract PDF/DOCX text
|
| 908 |
+
# ---------------------------------------------------------------------------
|
| 909 |
+
@app.post("/parse_file", tags=["ingestion"])
|
| 910 |
+
async def parse_file(file: UploadFile = File(...)):
|
| 911 |
+
"""Extract text from uploaded txt, md, pdf, or docx files."""
|
| 912 |
+
content = await file.read()
|
| 913 |
+
filename = file.filename.lower()
|
| 914 |
+
text = ""
|
| 915 |
+
try:
|
| 916 |
+
if filename.endswith(".pdf"):
|
| 917 |
+
import fitz
|
| 918 |
+
doc = fitz.open(stream=content, filetype="pdf")
|
| 919 |
+
msgs = []
|
| 920 |
+
for page in doc:
|
| 921 |
+
msgs.append(page.get_text())
|
| 922 |
+
text = "\n".join(msgs)
|
| 923 |
+
elif filename.endswith(".docx"):
|
| 924 |
+
import docx
|
| 925 |
+
from io import BytesIO
|
| 926 |
+
doc = docx.Document(BytesIO(content))
|
| 927 |
+
text = "\n".join([p.text for p in doc.paragraphs])
|
| 928 |
+
else:
|
| 929 |
+
text = content.decode("utf-8", errors="replace")
|
| 930 |
+
return {"status": "success", "text": text}
|
| 931 |
+
except Exception as e:
|
| 932 |
+
raise HTTPException(status_code=400, detail=f"Failed to parse file: {e}")
|
| 933 |
+
|
src/api/schemas.py
ADDED
|
@@ -0,0 +1,276 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
src/api/schemas.py — Pydantic request/response models for MediRAG FastAPI
|
| 3 |
+
=========================================================================
|
| 4 |
+
FR-18: Input validation limits from config.yaml → api:
|
| 5 |
+
- max_query_length: 500
|
| 6 |
+
- max_answer_length: 2000
|
| 7 |
+
- max_chunks: 10
|
| 8 |
+
- max_chunk_length: 2000
|
| 9 |
+
"""
|
| 10 |
+
from __future__ import annotations
|
| 11 |
+
|
| 12 |
+
from typing import Any, Dict, List, Optional
|
| 13 |
+
from pydantic import BaseModel, Field, field_validator
|
| 14 |
+
|
| 15 |
+
class IngestRequest(BaseModel):
|
| 16 |
+
"""POST /ingest — append a custom document to the FAISS index."""
|
| 17 |
+
title: str = Field(..., description="Document title")
|
| 18 |
+
text: str = Field(..., min_length=10, description="Raw text of the document to ingest")
|
| 19 |
+
pub_type: str = Field(default="clinical_guideline", description="Document type")
|
| 20 |
+
source: str = Field(default="custom_upload", description="Source of the document")
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
# ---------------------------------------------------------------------------
|
| 24 |
+
# Request schemas
|
| 25 |
+
# ---------------------------------------------------------------------------
|
| 26 |
+
|
| 27 |
+
class ContextChunk(BaseModel):
|
| 28 |
+
"""A single retrieved context chunk passed to the evaluation pipeline."""
|
| 29 |
+
text: str = Field(..., min_length=1, max_length=2000,
|
| 30 |
+
description="Chunk text (max 2000 chars)")
|
| 31 |
+
# Optional metadata fields — all pass-through to the pipeline modules
|
| 32 |
+
chunk_id: Optional[str] = None
|
| 33 |
+
pub_type: Optional[str] = None
|
| 34 |
+
pub_year: Optional[int] = None
|
| 35 |
+
source: Optional[str] = None
|
| 36 |
+
title: Optional[str] = None
|
| 37 |
+
tier_type: Optional[str] = None # pre-labelled evidence tier (optional)
|
| 38 |
+
score: Optional[float] = None # retrieval similarity score
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
class EvaluateRequest(BaseModel):
|
| 42 |
+
"""POST /evaluate — request body."""
|
| 43 |
+
question: str = Field(
|
| 44 |
+
...,
|
| 45 |
+
min_length=5,
|
| 46 |
+
max_length=500,
|
| 47 |
+
description="User question (5–500 chars)",
|
| 48 |
+
examples=["What is the recommended dosage of Metformin for Type 2 Diabetes in elderly patients?"],
|
| 49 |
+
)
|
| 50 |
+
answer: str = Field(
|
| 51 |
+
...,
|
| 52 |
+
min_length=1,
|
| 53 |
+
max_length=2000,
|
| 54 |
+
description="LLM-generated answer to evaluate (1–2000 chars)",
|
| 55 |
+
examples=["Metformin is typically started at 500 mg twice daily with meals..."],
|
| 56 |
+
)
|
| 57 |
+
context_chunks: List[ContextChunk] = Field(
|
| 58 |
+
...,
|
| 59 |
+
min_length=1,
|
| 60 |
+
max_length=10,
|
| 61 |
+
description="Retrieved context chunks (1–10 items)",
|
| 62 |
+
)
|
| 63 |
+
run_ragas: bool = Field(
|
| 64 |
+
default=False,
|
| 65 |
+
description="Run RAGAS evaluation (requires Ollama or OpenAI backend; slower)",
|
| 66 |
+
)
|
| 67 |
+
llm_provider: Optional[str] = Field(
|
| 68 |
+
default=None,
|
| 69 |
+
description="LLM provider override: 'gemini' or 'ollama'"
|
| 70 |
+
)
|
| 71 |
+
llm_api_key: Optional[str] = Field(
|
| 72 |
+
default=None,
|
| 73 |
+
description="API Key if accessing Gemini"
|
| 74 |
+
)
|
| 75 |
+
llm_model: Optional[str] = Field(
|
| 76 |
+
default=None,
|
| 77 |
+
description="Specific model string if overriding defaults"
|
| 78 |
+
)
|
| 79 |
+
rxnorm_cache_path: str = Field(
|
| 80 |
+
default="data/rxnorm_cache.csv",
|
| 81 |
+
description="Path to RxNorm cache CSV",
|
| 82 |
+
)
|
| 83 |
+
|
| 84 |
+
@field_validator("context_chunks")
|
| 85 |
+
@classmethod
|
| 86 |
+
def at_least_one_chunk(cls, v: list) -> list:
|
| 87 |
+
if len(v) == 0:
|
| 88 |
+
raise ValueError("At least one context chunk is required")
|
| 89 |
+
return v
|
| 90 |
+
|
| 91 |
+
|
| 92 |
+
# ---------------------------------------------------------------------------
|
| 93 |
+
# Response schemas
|
| 94 |
+
# ---------------------------------------------------------------------------
|
| 95 |
+
|
| 96 |
+
class ModuleScore(BaseModel):
|
| 97 |
+
"""Score + details dict for a single evaluation module."""
|
| 98 |
+
score: float = Field(..., ge=0.0, le=1.0, description="Module score in [0, 1]")
|
| 99 |
+
details: Dict[str, Any] = Field(default_factory=dict)
|
| 100 |
+
error: Optional[str] = Field(None, description="Error message if module failed")
|
| 101 |
+
latency_ms: Optional[int] = None
|
| 102 |
+
|
| 103 |
+
|
| 104 |
+
class ModuleResults(BaseModel):
|
| 105 |
+
"""All per-module scores bundled together."""
|
| 106 |
+
faithfulness: Optional[ModuleScore] = None
|
| 107 |
+
entity_verifier: Optional[ModuleScore] = None
|
| 108 |
+
source_credibility: Optional[ModuleScore] = None
|
| 109 |
+
contradiction: Optional[ModuleScore] = None
|
| 110 |
+
ragas: Optional[ModuleScore] = None
|
| 111 |
+
|
| 112 |
+
|
| 113 |
+
class EvaluateResponse(BaseModel):
|
| 114 |
+
"""POST /evaluate — response body (FR-17 format)."""
|
| 115 |
+
composite_score: float = Field(
|
| 116 |
+
..., ge=0.0, le=1.0,
|
| 117 |
+
description="Weighted composite score in [0, 1]"
|
| 118 |
+
)
|
| 119 |
+
hrs: int = Field(
|
| 120 |
+
..., ge=0, le=100,
|
| 121 |
+
description="Health Risk Score = round(100 × (1 - composite_score))"
|
| 122 |
+
)
|
| 123 |
+
confidence_level: str = Field(
|
| 124 |
+
...,
|
| 125 |
+
description="HIGH / MODERATE / LOW",
|
| 126 |
+
)
|
| 127 |
+
risk_band: str = Field(
|
| 128 |
+
...,
|
| 129 |
+
description="LOW / MODERATE / HIGH / CRITICAL",
|
| 130 |
+
)
|
| 131 |
+
module_results: ModuleResults
|
| 132 |
+
total_pipeline_ms: int = Field(..., description="Total wall-clock time in ms")
|
| 133 |
+
|
| 134 |
+
|
| 135 |
+
class ChatMessage(BaseModel):
|
| 136 |
+
role: str
|
| 137 |
+
content: str
|
| 138 |
+
|
| 139 |
+
class ChatRequest(BaseModel):
|
| 140 |
+
messages: List[ChatMessage]
|
| 141 |
+
system_prompt: Optional[str] = None
|
| 142 |
+
persona: Optional[str] = "physician"
|
| 143 |
+
|
| 144 |
+
|
| 145 |
+
class HealthResponse(BaseModel):
|
| 146 |
+
"""GET /health — liveness and dependency status."""
|
| 147 |
+
status: str = Field(default="ok")
|
| 148 |
+
ollama_available: bool
|
| 149 |
+
version: str = Field(default="0.1.0")
|
| 150 |
+
|
| 151 |
+
|
| 152 |
+
# ---------------------------------------------------------------------------
|
| 153 |
+
# End-to-end query schemas (POST /query)
|
| 154 |
+
# ---------------------------------------------------------------------------
|
| 155 |
+
|
| 156 |
+
class QueryRequest(BaseModel):
|
| 157 |
+
"""POST /query — only a question needed; retrieval + generation happen server-side."""
|
| 158 |
+
question: str = Field(
|
| 159 |
+
...,
|
| 160 |
+
min_length=5,
|
| 161 |
+
max_length=8000,
|
| 162 |
+
description="Medical question (5–8000 chars; may include doc context)",
|
| 163 |
+
examples=["What is the recommended dosage of Metformin for elderly Type 2 Diabetes patients?"],
|
| 164 |
+
)
|
| 165 |
+
top_k: int = Field(
|
| 166 |
+
default=5,
|
| 167 |
+
ge=1,
|
| 168 |
+
le=10,
|
| 169 |
+
description="Number of context chunks to retrieve (1–10)",
|
| 170 |
+
)
|
| 171 |
+
run_ragas: bool = Field(
|
| 172 |
+
default=False,
|
| 173 |
+
description="Run RAGAS evaluation (requires LLM backend)",
|
| 174 |
+
)
|
| 175 |
+
# Per-request LLM overrides — if not set, server config.yaml values are used
|
| 176 |
+
# This makes the eval engine portable: callers bring their own key + model
|
| 177 |
+
llm_provider: Optional[str] = Field(
|
| 178 |
+
default=None,
|
| 179 |
+
description="LLM provider override: 'gemini' or 'ollama'"
|
| 180 |
+
)
|
| 181 |
+
llm_api_key: Optional[str] = Field(
|
| 182 |
+
default=None,
|
| 183 |
+
description="API key override (e.g. Gemini key). Not logged or stored."
|
| 184 |
+
)
|
| 185 |
+
llm_model: Optional[str] = Field(
|
| 186 |
+
default=None,
|
| 187 |
+
description="Model name override (e.g. 'gemini-2.5-flash-lite')"
|
| 188 |
+
)
|
| 189 |
+
ollama_url: Optional[str] = Field(
|
| 190 |
+
default=None,
|
| 191 |
+
description="Ollama base URL override (e.g. 'http://localhost:11434')"
|
| 192 |
+
)
|
| 193 |
+
# Demo/test only — injects a false claim into the LLM answer before evaluation
|
| 194 |
+
# to demonstrate the intervention system catching hallucinations.
|
| 195 |
+
inject_hallucination: Optional[str] = Field(
|
| 196 |
+
default=None,
|
| 197 |
+
description="[DEMO ONLY] Appends a false medical claim to the answer before evaluation."
|
| 198 |
+
)
|
| 199 |
+
# Consensus Engine (Option 2)
|
| 200 |
+
use_consensus: bool = Field(
|
| 201 |
+
default=False,
|
| 202 |
+
description="Run multiple models and compare for clinical agreement."
|
| 203 |
+
)
|
| 204 |
+
# Privacy Shield (Option 1)
|
| 205 |
+
use_privacy_shield: bool = Field(
|
| 206 |
+
default=False,
|
| 207 |
+
description="Automatically redact PHI/PII (names, IDs) before external API calls.",
|
| 208 |
+
)
|
| 209 |
+
system_prompt: Optional[str] = Field(
|
| 210 |
+
default=None,
|
| 211 |
+
description="Custom system prompt to override the default clinical persona."
|
| 212 |
+
)
|
| 213 |
+
persona: Optional[str] = Field(
|
| 214 |
+
default="physician",
|
| 215 |
+
description="The target audience for the response: 'physician' or 'patient'."
|
| 216 |
+
)
|
| 217 |
+
|
| 218 |
+
|
| 219 |
+
class RetrievedChunk(BaseModel):
|
| 220 |
+
"""A single chunk returned alongside the query response for transparency."""
|
| 221 |
+
chunk_id: Optional[str] = None
|
| 222 |
+
text: str
|
| 223 |
+
source: Optional[str] = None
|
| 224 |
+
pub_type: Optional[str] = None
|
| 225 |
+
pub_year: Optional[int] = None
|
| 226 |
+
title: Optional[str] = None
|
| 227 |
+
similarity_score: Optional[float] = None
|
| 228 |
+
|
| 229 |
+
|
| 230 |
+
class QueryResponse(BaseModel):
|
| 231 |
+
"""POST /query — full end-to-end response."""
|
| 232 |
+
question: str
|
| 233 |
+
generated_answer: str
|
| 234 |
+
retrieved_chunks: List[RetrievedChunk]
|
| 235 |
+
# Evaluation fields (same as EvaluateResponse)
|
| 236 |
+
composite_score: float = Field(..., ge=0.0, le=1.0)
|
| 237 |
+
hrs: int = Field(..., ge=0, le=100)
|
| 238 |
+
confidence_level: str
|
| 239 |
+
risk_band: str
|
| 240 |
+
module_results: ModuleResults
|
| 241 |
+
total_pipeline_ms: int
|
| 242 |
+
# Intervention fields (active safety gate)
|
| 243 |
+
intervention_applied: bool = Field(
|
| 244 |
+
default=False,
|
| 245 |
+
description="True if the system modified or blocked the response for safety.",
|
| 246 |
+
)
|
| 247 |
+
intervention_reason: Optional[str] = Field(
|
| 248 |
+
default=None,
|
| 249 |
+
description="CRITICAL_BLOCKED | HIGH_RISK_REGENERATED | null",
|
| 250 |
+
)
|
| 251 |
+
original_answer: Optional[str] = Field(
|
| 252 |
+
default=None,
|
| 253 |
+
description="The original (unsafe) LLM answer before intervention, for transparency.",
|
| 254 |
+
)
|
| 255 |
+
intervention_details: Optional[Dict[str, Any]] = Field(
|
| 256 |
+
default=None,
|
| 257 |
+
description="Which modules triggered the intervention and their scores.",
|
| 258 |
+
)
|
| 259 |
+
# Consensus fields
|
| 260 |
+
consensus_results: Optional[Dict[str, Any]] = Field(
|
| 261 |
+
default=None,
|
| 262 |
+
description="Results from the multi-model agreement check."
|
| 263 |
+
)
|
| 264 |
+
# Privacy Shield fields
|
| 265 |
+
privacy_applied: bool = Field(default=False)
|
| 266 |
+
privacy_details: Optional[Dict[str, Any]] = Field(default=None)
|
| 267 |
+
# Coverage gap gate — distinguishes missing DB coverage from hallucination
|
| 268 |
+
coverage_gap: bool = Field(
|
| 269 |
+
default=False,
|
| 270 |
+
description="True when retrieval quality is low — the database may lack coverage for this topic.",
|
| 271 |
+
)
|
| 272 |
+
coverage_gap_details: Optional[Dict[str, Any]] = Field(
|
| 273 |
+
default=None,
|
| 274 |
+
description="gap_type (COVERAGE_GAP | HALLUCINATION), retrieval_confidence, threshold.",
|
| 275 |
+
)
|
| 276 |
+
|
src/cli.py
ADDED
|
@@ -0,0 +1,70 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import typer
|
| 2 |
+
import subprocess
|
| 3 |
+
import webbrowser
|
| 4 |
+
import time
|
| 5 |
+
import socket
|
| 6 |
+
import os
|
| 7 |
+
import sys
|
| 8 |
+
|
| 9 |
+
app = typer.Typer(help="MediRAG Command Line Interface")
|
| 10 |
+
|
| 11 |
+
def is_port_in_use(port: int) -> bool:
|
| 12 |
+
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
|
| 13 |
+
return s.connect_ex(('localhost', port)) == 0
|
| 14 |
+
|
| 15 |
+
@app.command()
|
| 16 |
+
def start():
|
| 17 |
+
"""Start the full MediRAG experience (Backend + Full Frontend)"""
|
| 18 |
+
typer.echo("Starting full MediRAG experience...")
|
| 19 |
+
run_servers(practical_mode=False)
|
| 20 |
+
|
| 21 |
+
@app.command()
|
| 22 |
+
def api():
|
| 23 |
+
"""Start the streamlined 'practical' UI"""
|
| 24 |
+
typer.echo("Starting streamlined MediRAG practical UI...")
|
| 25 |
+
run_servers(practical_mode=True)
|
| 26 |
+
|
| 27 |
+
def run_servers(practical_mode: bool):
|
| 28 |
+
# Check ports
|
| 29 |
+
if is_port_in_use(8000):
|
| 30 |
+
typer.echo("Warning: Port 8000 (Backend) might already be in use.")
|
| 31 |
+
if is_port_in_use(5173):
|
| 32 |
+
typer.echo("Warning: Port 5173 (Frontend) might already be in use.")
|
| 33 |
+
|
| 34 |
+
backend_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
| 35 |
+
frontend_dir = os.path.join(os.path.dirname(backend_dir), "Frontend")
|
| 36 |
+
|
| 37 |
+
# Start Backend
|
| 38 |
+
typer.echo("Starting Backend server...")
|
| 39 |
+
backend_process = subprocess.Popen(
|
| 40 |
+
[sys.executable, "-m", "uvicorn", "src.api.main:app", "--host", "0.0.0.0", "--port", "8000"],
|
| 41 |
+
cwd=backend_dir
|
| 42 |
+
)
|
| 43 |
+
|
| 44 |
+
# Start Frontend
|
| 45 |
+
typer.echo("Starting Frontend server...")
|
| 46 |
+
# On Windows, npm run dev needs shell=True or using cmd /c
|
| 47 |
+
frontend_process = subprocess.Popen(
|
| 48 |
+
["cmd", "/c", "npm", "run", "dev"] if os.name == 'nt' else ["npm", "run", "dev"],
|
| 49 |
+
cwd=frontend_dir
|
| 50 |
+
)
|
| 51 |
+
|
| 52 |
+
typer.echo("Waiting for servers to start...")
|
| 53 |
+
time.sleep(5) # Basic wait for frontend to spin up
|
| 54 |
+
|
| 55 |
+
url = "http://localhost:5173/cli-view" if practical_mode else "http://localhost:5173/"
|
| 56 |
+
typer.echo(f"Opening browser at {url}...")
|
| 57 |
+
webbrowser.open(url)
|
| 58 |
+
|
| 59 |
+
try:
|
| 60 |
+
# Keep process alive
|
| 61 |
+
backend_process.wait()
|
| 62 |
+
frontend_process.wait()
|
| 63 |
+
except KeyboardInterrupt:
|
| 64 |
+
typer.echo("\nShutting down servers...")
|
| 65 |
+
backend_process.terminate()
|
| 66 |
+
frontend_process.terminate()
|
| 67 |
+
typer.echo("Servers stopped.")
|
| 68 |
+
|
| 69 |
+
if __name__ == "__main__":
|
| 70 |
+
app()
|
src/dashboard/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
# src/dashboard/__init__.py
|
src/evaluate.py
ADDED
|
@@ -0,0 +1,289 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
FR-22: src/evaluate.py — MediRAG Evaluation Orchestrator
|
| 3 |
+
=========================================================
|
| 4 |
+
Top-level entry point for the evaluation pipeline.
|
| 5 |
+
|
| 6 |
+
Runs all 4 evaluation modules + RAGAS + aggregator for a given
|
| 7 |
+
(question, answer, context_docs) triple, returning a fully structured
|
| 8 |
+
composite EvalResult.
|
| 9 |
+
|
| 10 |
+
Usage as a module:
|
| 11 |
+
from src.evaluate import run_evaluation
|
| 12 |
+
result = run_evaluation(question, answer, context_docs)
|
| 13 |
+
print(f"Score: {result.score:.3f} ({result.details['confidence_level']})")
|
| 14 |
+
|
| 15 |
+
Usage from CLI:
|
| 16 |
+
python -m src.evaluate \\
|
| 17 |
+
--question "What is the recommended dosage of Metformin for Type 2 Diabetes?" \\
|
| 18 |
+
--answer "Metformin is typically started at 500mg twice daily..." \\
|
| 19 |
+
--context-file data/processed/chunks.jsonl \\
|
| 20 |
+
--top-k 5
|
| 21 |
+
|
| 22 |
+
SRS reference: FR-22, Section 7 (Evaluation Pipeline Overview)
|
| 23 |
+
"""
|
| 24 |
+
from __future__ import annotations
|
| 25 |
+
|
| 26 |
+
import argparse
|
| 27 |
+
import json
|
| 28 |
+
import logging
|
| 29 |
+
import sys
|
| 30 |
+
import time
|
| 31 |
+
from pathlib import Path
|
| 32 |
+
from typing import Optional
|
| 33 |
+
|
| 34 |
+
from src.modules.base import EvalResult
|
| 35 |
+
from src.modules.faithfulness import score_faithfulness
|
| 36 |
+
from src.modules.entity_verifier import verify_entities
|
| 37 |
+
from src.modules.source_credibility import score_source_credibility
|
| 38 |
+
from src.modules.contradiction import score_contradiction
|
| 39 |
+
from src.evaluation.ragas_eval import score_ragas
|
| 40 |
+
from src.evaluation.aggregator import aggregate
|
| 41 |
+
|
| 42 |
+
logger = logging.getLogger(__name__)
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
# ---------------------------------------------------------------------------
|
| 46 |
+
# Main evaluation function
|
| 47 |
+
# ---------------------------------------------------------------------------
|
| 48 |
+
|
| 49 |
+
def run_evaluation(
|
| 50 |
+
question: str,
|
| 51 |
+
answer: str,
|
| 52 |
+
context_chunks: list[dict],
|
| 53 |
+
rxnorm_cache_path: str = "data/rxnorm_cache.csv",
|
| 54 |
+
run_ragas: bool = True,
|
| 55 |
+
weights: Optional[dict[str, float]] = None,
|
| 56 |
+
config: Optional[dict] = None,
|
| 57 |
+
) -> EvalResult:
|
| 58 |
+
"""
|
| 59 |
+
Run the full MediRAG evaluation pipeline for a single QA pair.
|
| 60 |
+
|
| 61 |
+
Args:
|
| 62 |
+
question : Original user question.
|
| 63 |
+
answer : LLM-generated answer to evaluate.
|
| 64 |
+
context_chunks : List of retrieved chunk dicts (from retriever.retrieve()).
|
| 65 |
+
Each chunk must have at minimum {'text': str}.
|
| 66 |
+
rxnorm_cache_path : Path to rxnorm_cache.csv for entity verification.
|
| 67 |
+
run_ragas : Whether to run the RAGAS module (requires LLM backend).
|
| 68 |
+
weights : Override default aggregation weights (optional).
|
| 69 |
+
|
| 70 |
+
Returns:
|
| 71 |
+
EvalResult for the "aggregator" module containing:
|
| 72 |
+
.score → composite score in [0, 1]
|
| 73 |
+
.details → full breakdown per module
|
| 74 |
+
.latency_ms → total wall-clock time
|
| 75 |
+
"""
|
| 76 |
+
t_start = time.perf_counter()
|
| 77 |
+
logger.info("=== MediRAG Evaluation START ===")
|
| 78 |
+
logger.info("Question: %s", question[:120])
|
| 79 |
+
logger.info("Answer : %s", answer[:120])
|
| 80 |
+
logger.info("Chunks : %d context documents", len(context_chunks))
|
| 81 |
+
|
| 82 |
+
# Extract text and metadata for modules that need it
|
| 83 |
+
context_texts: list[str] = [c.get("text", "") for c in context_chunks]
|
| 84 |
+
chunk_ids: list[str] = [
|
| 85 |
+
c.get("chunk_id") or c.get("metadata", {}).get("chunk_id") or f"chunk_{i}"
|
| 86 |
+
for i, c in enumerate(context_chunks)
|
| 87 |
+
]
|
| 88 |
+
|
| 89 |
+
# -------------------------------------------------------------------------
|
| 90 |
+
# Retrieval Quality Gate
|
| 91 |
+
# If the retriever's absolute RRF score is too low, the chunks are likely
|
| 92 |
+
# unrelated to the question — evaluation against them produces false HRS spikes.
|
| 93 |
+
# Threshold: max raw RRF for top-1 in both sources = 2/(60+1) ≈ 0.0328
|
| 94 |
+
# We flag as insufficient if max_rrf < 0.008 (only very weak BM25 or FAISS match)
|
| 95 |
+
# -------------------------------------------------------------------------
|
| 96 |
+
RETRIEVAL_CONFIDENCE_THRESHOLD = 0.008
|
| 97 |
+
retrieval_confidence = context_chunks[0].get("_retrieval_confidence", 1.0) if context_chunks else 0.0
|
| 98 |
+
|
| 99 |
+
if context_chunks and retrieval_confidence < RETRIEVAL_CONFIDENCE_THRESHOLD:
|
| 100 |
+
logger.warning(
|
| 101 |
+
"Retrieval confidence %.6f below threshold %.3f — context likely irrelevant to question.",
|
| 102 |
+
retrieval_confidence, RETRIEVAL_CONFIDENCE_THRESHOLD,
|
| 103 |
+
)
|
| 104 |
+
total_ms = int((time.perf_counter() - t_start) * 1000)
|
| 105 |
+
return EvalResult(
|
| 106 |
+
module_name="aggregator",
|
| 107 |
+
score=0.5,
|
| 108 |
+
details={
|
| 109 |
+
"retrieval_insufficient": True,
|
| 110 |
+
"retrieval_confidence": retrieval_confidence,
|
| 111 |
+
"hrs": 50,
|
| 112 |
+
"risk_band": "MODERATE",
|
| 113 |
+
"confidence_level": "LOW",
|
| 114 |
+
"total_pipeline_ms": total_ms,
|
| 115 |
+
"module_results": {},
|
| 116 |
+
"warning": (
|
| 117 |
+
"Retrieved context has very low relevance to the question "
|
| 118 |
+
f"(retrieval_confidence={retrieval_confidence:.4f}). "
|
| 119 |
+
"Evaluation scores would be meaningless. "
|
| 120 |
+
"Consider rephrasing the question or expanding the index."
|
| 121 |
+
),
|
| 122 |
+
},
|
| 123 |
+
latency_ms=total_ms,
|
| 124 |
+
)
|
| 125 |
+
|
| 126 |
+
# -------------------------------------------------------------------------
|
| 127 |
+
# Module 1: Faithfulness (DeBERTa NLI)
|
| 128 |
+
# -------------------------------------------------------------------------
|
| 129 |
+
logger.info("--- Module 1: Faithfulness ---")
|
| 130 |
+
faith_result = score_faithfulness(
|
| 131 |
+
answer=answer,
|
| 132 |
+
context_docs=context_texts,
|
| 133 |
+
chunk_ids=chunk_ids,
|
| 134 |
+
config=config,
|
| 135 |
+
)
|
| 136 |
+
|
| 137 |
+
# -------------------------------------------------------------------------
|
| 138 |
+
# Module 2: Entity Verification (SciSpaCy + RxNorm)
|
| 139 |
+
# -------------------------------------------------------------------------
|
| 140 |
+
logger.info("--- Module 2: Entity Verification ---")
|
| 141 |
+
entity_result = verify_entities(
|
| 142 |
+
answer=answer,
|
| 143 |
+
question=question,
|
| 144 |
+
context_docs=context_texts,
|
| 145 |
+
rxnorm_cache_path=rxnorm_cache_path,
|
| 146 |
+
)
|
| 147 |
+
|
| 148 |
+
# -------------------------------------------------------------------------
|
| 149 |
+
# Module 3: Source Credibility (Evidence Tier)
|
| 150 |
+
# -------------------------------------------------------------------------
|
| 151 |
+
logger.info("--- Module 3: Source Credibility ---")
|
| 152 |
+
source_result = score_source_credibility(retrieved_chunks=context_chunks)
|
| 153 |
+
|
| 154 |
+
# -------------------------------------------------------------------------
|
| 155 |
+
# Module 4: Contradiction Detection (DeBERTa NLI cross-check)
|
| 156 |
+
# -------------------------------------------------------------------------
|
| 157 |
+
logger.info("--- Module 4: Contradiction Detection ---")
|
| 158 |
+
contra_result = score_contradiction(
|
| 159 |
+
answer=answer,
|
| 160 |
+
context_docs=context_texts,
|
| 161 |
+
)
|
| 162 |
+
|
| 163 |
+
# -------------------------------------------------------------------------
|
| 164 |
+
# RAGAS (optional — requires LLM backend)
|
| 165 |
+
# -------------------------------------------------------------------------
|
| 166 |
+
ragas_result: Optional[EvalResult] = None
|
| 167 |
+
if run_ragas:
|
| 168 |
+
logger.info("--- RAGAS Evaluation ---")
|
| 169 |
+
ragas_result = score_ragas(
|
| 170 |
+
question=question,
|
| 171 |
+
answer=answer,
|
| 172 |
+
context_docs=context_texts,
|
| 173 |
+
)
|
| 174 |
+
|
| 175 |
+
# -------------------------------------------------------------------------
|
| 176 |
+
# Aggregator: weighted composite
|
| 177 |
+
# -------------------------------------------------------------------------
|
| 178 |
+
logger.info("--- Aggregator ---")
|
| 179 |
+
agg_result = aggregate(
|
| 180 |
+
faithfulness_result=faith_result,
|
| 181 |
+
entity_result=entity_result,
|
| 182 |
+
source_result=source_result,
|
| 183 |
+
contradiction_result=contra_result,
|
| 184 |
+
ragas_result=ragas_result,
|
| 185 |
+
weights=weights,
|
| 186 |
+
)
|
| 187 |
+
|
| 188 |
+
total_ms = int((time.perf_counter() - t_start) * 1000)
|
| 189 |
+
agg_result.details["total_pipeline_ms"] = total_ms
|
| 190 |
+
|
| 191 |
+
# Attach per-module results for API/dashboard access
|
| 192 |
+
agg_result.details["module_results"] = {
|
| 193 |
+
"faithfulness": {"score": faith_result.score, "details": faith_result.details},
|
| 194 |
+
"entity_verifier": {"score": entity_result.score, "details": entity_result.details},
|
| 195 |
+
"source_credibility": {"score": source_result.score, "details": source_result.details},
|
| 196 |
+
"contradiction": {"score": contra_result.score, "details": contra_result.details},
|
| 197 |
+
"ragas": {"score": ragas_result.score, "details": ragas_result.details} if ragas_result else None,
|
| 198 |
+
}
|
| 199 |
+
|
| 200 |
+
logger.info(
|
| 201 |
+
"=== MediRAG Evaluation DONE: score=%.3f (%s) in %d ms ===",
|
| 202 |
+
agg_result.score,
|
| 203 |
+
agg_result.details.get("confidence_level", "?"),
|
| 204 |
+
total_ms,
|
| 205 |
+
)
|
| 206 |
+
return agg_result
|
| 207 |
+
|
| 208 |
+
|
| 209 |
+
# ---------------------------------------------------------------------------
|
| 210 |
+
# CLI entry point
|
| 211 |
+
# ---------------------------------------------------------------------------
|
| 212 |
+
|
| 213 |
+
def _build_parser() -> argparse.ArgumentParser:
|
| 214 |
+
p = argparse.ArgumentParser(
|
| 215 |
+
description="MediRAG evaluation pipeline (FR-22)"
|
| 216 |
+
)
|
| 217 |
+
p.add_argument("--question", required=True, help="User question")
|
| 218 |
+
p.add_argument("--answer", required=True, help="LLM answer to evaluate")
|
| 219 |
+
p.add_argument("--context-file", default="data/processed/chunks.jsonl",
|
| 220 |
+
help="JSONL file of chunks (output of ingest.py)")
|
| 221 |
+
p.add_argument("--top-k", type=int, default=5,
|
| 222 |
+
help="Number of context chunks to use")
|
| 223 |
+
p.add_argument("--rxnorm-cache", default="data/rxnorm_cache.csv",
|
| 224 |
+
help="Path to rxnorm_cache.csv")
|
| 225 |
+
p.add_argument("--no-ragas", action="store_true",
|
| 226 |
+
help="Skip RAGAS evaluation (no LLM backend needed)")
|
| 227 |
+
p.add_argument("--json", action="store_true",
|
| 228 |
+
help="Output result as JSON")
|
| 229 |
+
return p
|
| 230 |
+
|
| 231 |
+
|
| 232 |
+
def _load_context_from_file(path: str, top_k: int) -> list[dict]:
|
| 233 |
+
"""Load top-k chunks from a JSONL file as simple dicts."""
|
| 234 |
+
chunks = []
|
| 235 |
+
try:
|
| 236 |
+
with open(path, "r", encoding="utf-8") as f:
|
| 237 |
+
for line in f:
|
| 238 |
+
line = line.strip()
|
| 239 |
+
if line:
|
| 240 |
+
chunks.append(json.loads(line))
|
| 241 |
+
if len(chunks) >= top_k:
|
| 242 |
+
break
|
| 243 |
+
except FileNotFoundError:
|
| 244 |
+
logger.error("Context file not found: %s", path)
|
| 245 |
+
sys.exit(1)
|
| 246 |
+
return chunks
|
| 247 |
+
|
| 248 |
+
|
| 249 |
+
if __name__ == "__main__":
|
| 250 |
+
import yaml
|
| 251 |
+
|
| 252 |
+
# Load config.yaml for logging setup
|
| 253 |
+
try:
|
| 254 |
+
cfg = yaml.safe_load(Path("config.yaml").read_text())
|
| 255 |
+
log_level = cfg.get("logging", {}).get("level", "INFO")
|
| 256 |
+
except Exception:
|
| 257 |
+
log_level = "INFO"
|
| 258 |
+
|
| 259 |
+
logging.basicConfig(
|
| 260 |
+
level=getattr(logging, log_level, logging.INFO),
|
| 261 |
+
format="%(asctime)s [%(levelname)s] %(name)s: %(message)s",
|
| 262 |
+
)
|
| 263 |
+
|
| 264 |
+
args = _build_parser().parse_args()
|
| 265 |
+
chunks = _load_context_from_file(args.context_file, args.top_k)
|
| 266 |
+
|
| 267 |
+
result = run_evaluation(
|
| 268 |
+
question=args.question,
|
| 269 |
+
answer=args.answer,
|
| 270 |
+
context_chunks=chunks,
|
| 271 |
+
rxnorm_cache_path=args.rxnorm_cache,
|
| 272 |
+
run_ragas=not args.no_ragas,
|
| 273 |
+
)
|
| 274 |
+
|
| 275 |
+
if args.json:
|
| 276 |
+
import dataclasses
|
| 277 |
+
print(json.dumps(dataclasses.asdict(result), indent=2))
|
| 278 |
+
else:
|
| 279 |
+
print(f"\n{'='*60}")
|
| 280 |
+
print(f" MediRAG Evaluation Result")
|
| 281 |
+
print(f"{'='*60}")
|
| 282 |
+
print(f" Score : {result.score:.3f}")
|
| 283 |
+
print(f" Confidence : {result.details.get('confidence_level', 'N/A')}")
|
| 284 |
+
print(f" Pipeline time : {result.details.get('total_pipeline_ms', 0)} ms")
|
| 285 |
+
print(f"\n Module Breakdown:")
|
| 286 |
+
for mod, res in (result.details.get("module_results") or {}).items():
|
| 287 |
+
if res:
|
| 288 |
+
print(f" {mod:22s}: {res['score']:.3f}")
|
| 289 |
+
print(f"{'='*60}\n")
|
src/evaluation/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
# src/evaluation/__init__.py
|
src/evaluation/aggregator.py
ADDED
|
@@ -0,0 +1,173 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
FR-19: src/evaluation/aggregator.py — Weighted Score Aggregation
|
| 3 |
+
================================================================
|
| 4 |
+
Combines scores from all evaluation modules into a single composite score
|
| 5 |
+
using the fixed weights defined in SRS Section 8.2.
|
| 6 |
+
|
| 7 |
+
Weights (must sum to 1.0):
|
| 8 |
+
faithfulness : 0.35 (primary signal — DeBERTa NLI)
|
| 9 |
+
entity_accuracy : 0.20 (SciSpaCy NER + RxNorm)
|
| 10 |
+
source_credibility : 0.20 (evidence tier)
|
| 11 |
+
contradiction_risk : 0.15 (1.0 - contradiction_score)
|
| 12 |
+
ragas_composite : 0.10 (optional — 0.5 neutral if unavailable)
|
| 13 |
+
|
| 14 |
+
Output:
|
| 15 |
+
EvalResult with:
|
| 16 |
+
module_name = "aggregator"
|
| 17 |
+
score = weighted composite in [0, 1]
|
| 18 |
+
details = {weights_used, weighted_composite, component_contributions}
|
| 19 |
+
|
| 20 |
+
Usage:
|
| 21 |
+
from src.evaluation.aggregator import aggregate
|
| 22 |
+
agg_result = aggregate(faith_res, entity_res, source_res, contra_res, ragas_res)
|
| 23 |
+
"""
|
| 24 |
+
from __future__ import annotations
|
| 25 |
+
|
| 26 |
+
import logging
|
| 27 |
+
import time
|
| 28 |
+
from typing import Optional
|
| 29 |
+
|
| 30 |
+
from src.modules.base import EvalResult
|
| 31 |
+
|
| 32 |
+
logger = logging.getLogger(__name__)
|
| 33 |
+
|
| 34 |
+
# ---------------------------------------------------------------------------
|
| 35 |
+
# Default weights (SRS Section 8.2)
|
| 36 |
+
# ---------------------------------------------------------------------------
|
| 37 |
+
|
| 38 |
+
DEFAULT_WEIGHTS: dict[str, float] = {
|
| 39 |
+
"faithfulness": 0.35,
|
| 40 |
+
"entity_accuracy": 0.20,
|
| 41 |
+
"source_credibility": 0.20,
|
| 42 |
+
"contradiction_risk": 0.15,
|
| 43 |
+
"ragas_composite": 0.10,
|
| 44 |
+
}
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
def aggregate(
|
| 48 |
+
faithfulness_result: EvalResult,
|
| 49 |
+
entity_result: EvalResult,
|
| 50 |
+
source_result: EvalResult,
|
| 51 |
+
contradiction_result: EvalResult,
|
| 52 |
+
ragas_result: Optional[EvalResult] = None,
|
| 53 |
+
weights: Optional[dict[str, float]] = None,
|
| 54 |
+
) -> EvalResult:
|
| 55 |
+
"""
|
| 56 |
+
Aggregate all module scores into a single composite evaluation result.
|
| 57 |
+
|
| 58 |
+
Args:
|
| 59 |
+
faithfulness_result : Output from faithfulness.score_faithfulness()
|
| 60 |
+
entity_result : Output from entity_verifier.verify_entities()
|
| 61 |
+
source_result : Output from source_credibility.score_source_credibility()
|
| 62 |
+
contradiction_result : Output from contradiction.score_contradiction()
|
| 63 |
+
ragas_result : Output from ragas_eval.score_ragas() (optional)
|
| 64 |
+
weights : Override default weights (must sum to 1.0)
|
| 65 |
+
|
| 66 |
+
Returns:
|
| 67 |
+
EvalResult with module_name="aggregator" and composite score.
|
| 68 |
+
"""
|
| 69 |
+
t0 = time.perf_counter()
|
| 70 |
+
w = weights or DEFAULT_WEIGHTS
|
| 71 |
+
|
| 72 |
+
# Validate weights sum to 1.0 (tolerance 0.01)
|
| 73 |
+
weight_sum = sum(w.values())
|
| 74 |
+
if abs(weight_sum - 1.0) > 0.01:
|
| 75 |
+
logger.warning(
|
| 76 |
+
"Weights sum to %.4f (expected 1.0) — normalising.", weight_sum
|
| 77 |
+
)
|
| 78 |
+
w = {k: v / weight_sum for k, v in w.items()}
|
| 79 |
+
|
| 80 |
+
# Extract scores — use 0.5 neutral for any unavailable module
|
| 81 |
+
faith_score = faithfulness_result.score if not faithfulness_result.error else 0.5
|
| 82 |
+
entity_score = entity_result.score if not entity_result.error else 0.5
|
| 83 |
+
source_score = source_result.score if not source_result.error else 0.5
|
| 84 |
+
contra_score = contradiction_result.score if not contradiction_result.error else 1.0
|
| 85 |
+
ragas_score = (ragas_result.score if ragas_result and not ragas_result.error else 0.5)
|
| 86 |
+
|
| 87 |
+
# Compute base weighted contributions
|
| 88 |
+
contributions = {
|
| 89 |
+
"faithfulness_contribution": round(faith_score * w["faithfulness"], 4),
|
| 90 |
+
"entity_contribution": round(entity_score * w["entity_accuracy"], 4),
|
| 91 |
+
"source_contribution": round(source_score * w["source_credibility"], 4),
|
| 92 |
+
"contradiction_contribution": round(contra_score * w["contradiction_risk"], 4),
|
| 93 |
+
"ragas_contribution": round(ragas_score * w["ragas_composite"], 4),
|
| 94 |
+
}
|
| 95 |
+
|
| 96 |
+
base_composite = sum(contributions.values())
|
| 97 |
+
|
| 98 |
+
# --- Non-linear Safety Penalties ---
|
| 99 |
+
# Faithfulness penalty: applies when answer is not grounded in context.
|
| 100 |
+
# Contradiction penalty: only applies when actual contradictions are detected
|
| 101 |
+
# (score < 0.3). Score = 0.5 means "neutral/cannot verify" (refusal answers,
|
| 102 |
+
# no keyword overlap) — these should NOT be double-penalized.
|
| 103 |
+
penalty_multiplier = 1.0
|
| 104 |
+
if faith_score <= 0.6:
|
| 105 |
+
penalty_multiplier *= 0.6 # 40% penalty for ungrounded claims
|
| 106 |
+
if contra_score < 0.3:
|
| 107 |
+
penalty_multiplier *= 0.6 # 40% penalty only for confirmed contradictions
|
| 108 |
+
|
| 109 |
+
composite = base_composite * penalty_multiplier
|
| 110 |
+
|
| 111 |
+
# HRS = round(100 × (1 - composite)), then map to risk band
|
| 112 |
+
# Thresholds must match config.yaml aggregator.risk_bands
|
| 113 |
+
_HRS_LOW = 30
|
| 114 |
+
_HRS_MODERATE = 60
|
| 115 |
+
_HRS_HIGH = 85
|
| 116 |
+
|
| 117 |
+
hrs = int(round(100 * (1.0 - composite)))
|
| 118 |
+
hrs = max(0, min(100, hrs))
|
| 119 |
+
|
| 120 |
+
if hrs <= _HRS_LOW:
|
| 121 |
+
risk_band = "LOW"
|
| 122 |
+
elif hrs <= _HRS_MODERATE:
|
| 123 |
+
risk_band = "MODERATE"
|
| 124 |
+
elif hrs <= _HRS_HIGH:
|
| 125 |
+
risk_band = "HIGH"
|
| 126 |
+
else:
|
| 127 |
+
risk_band = "CRITICAL"
|
| 128 |
+
|
| 129 |
+
# Confidence level (based on composite, not HRS)
|
| 130 |
+
if composite >= 0.80:
|
| 131 |
+
confidence = "HIGH"
|
| 132 |
+
elif composite >= 0.55:
|
| 133 |
+
confidence = "MODERATE"
|
| 134 |
+
else:
|
| 135 |
+
confidence = "LOW"
|
| 136 |
+
|
| 137 |
+
details = {
|
| 138 |
+
"weights_used": {k: round(v, 4) for k, v in w.items()},
|
| 139 |
+
"component_scores": {
|
| 140 |
+
"faithfulness": round(faith_score, 4),
|
| 141 |
+
"entity_accuracy": round(entity_score, 4),
|
| 142 |
+
"source_credibility": round(source_score, 4),
|
| 143 |
+
"contradiction_risk": round(contra_score, 4),
|
| 144 |
+
"ragas_composite": round(ragas_score, 4),
|
| 145 |
+
},
|
| 146 |
+
"weighted_composite": round(composite, 4),
|
| 147 |
+
"hrs": hrs,
|
| 148 |
+
"risk_band": risk_band,
|
| 149 |
+
"component_contributions": contributions,
|
| 150 |
+
"confidence_level": confidence,
|
| 151 |
+
"module_latencies_ms": {
|
| 152 |
+
"faithfulness": faithfulness_result.latency_ms,
|
| 153 |
+
"entity_verifier": entity_result.latency_ms,
|
| 154 |
+
"source_credibility": source_result.latency_ms,
|
| 155 |
+
"contradiction": contradiction_result.latency_ms,
|
| 156 |
+
"ragas": ragas_result.latency_ms if ragas_result else 0,
|
| 157 |
+
},
|
| 158 |
+
}
|
| 159 |
+
|
| 160 |
+
latency_ms = int((time.perf_counter() - t0) * 1000)
|
| 161 |
+
logger.info(
|
| 162 |
+
"Aggregated score: %.3f (%s confidence) — "
|
| 163 |
+
"faith=%.2f entity=%.2f source=%.2f contra=%.2f ragas=%.2f",
|
| 164 |
+
composite, confidence,
|
| 165 |
+
faith_score, entity_score, source_score, contra_score, ragas_score,
|
| 166 |
+
)
|
| 167 |
+
|
| 168 |
+
return EvalResult(
|
| 169 |
+
module_name="aggregator",
|
| 170 |
+
score=composite,
|
| 171 |
+
details=details,
|
| 172 |
+
latency_ms=latency_ms,
|
| 173 |
+
)
|
src/evaluation/ragas_eval.py
ADDED
|
@@ -0,0 +1,177 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
FR-06: src/evaluation/ragas_eval.py — RAGAS Faithfulness + Answer Relevancy
|
| 3 |
+
=============================================================================
|
| 4 |
+
Wraps the ragas library to compute:
|
| 5 |
+
- faithfulness : context-grounded claim verification
|
| 6 |
+
- answer_relevancy : semantic similarity of answer to question
|
| 7 |
+
|
| 8 |
+
Requires an LLM backend. Supported backends (in priority order):
|
| 9 |
+
1. Ollama (local, free) — set OLLAMA_HOST env var or use default localhost:11434
|
| 10 |
+
2. OpenAI API — set OPENAI_API_KEY env var
|
| 11 |
+
3. Graceful degradation — returns score=None with explanation if no LLM available
|
| 12 |
+
|
| 13 |
+
Usage:
|
| 14 |
+
from src.evaluation.ragas_eval import score_ragas
|
| 15 |
+
result = score_ragas(question, answer, context_docs)
|
| 16 |
+
|
| 17 |
+
SRS reference: FR-06, Section 7 (Evaluation Pipeline)
|
| 18 |
+
"""
|
| 19 |
+
from __future__ import annotations
|
| 20 |
+
|
| 21 |
+
import logging
|
| 22 |
+
import os
|
| 23 |
+
import time
|
| 24 |
+
from typing import Optional
|
| 25 |
+
|
| 26 |
+
from src.modules.base import EvalResult
|
| 27 |
+
|
| 28 |
+
logger = logging.getLogger(__name__)
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
# ---------------------------------------------------------------------------
|
| 32 |
+
# Backend detection
|
| 33 |
+
# ---------------------------------------------------------------------------
|
| 34 |
+
|
| 35 |
+
def _detect_llm_backend() -> Optional[str]:
|
| 36 |
+
"""Return 'ollama', 'openai', or None."""
|
| 37 |
+
if os.getenv("OPENAI_API_KEY"):
|
| 38 |
+
return "openai"
|
| 39 |
+
# Check if Ollama is running locally
|
| 40 |
+
try:
|
| 41 |
+
import requests
|
| 42 |
+
host = os.getenv("OLLAMA_HOST", "http://localhost:11434")
|
| 43 |
+
resp = requests.get(f"{host}/api/tags", timeout=2)
|
| 44 |
+
if resp.status_code == 200:
|
| 45 |
+
return "ollama"
|
| 46 |
+
except Exception:
|
| 47 |
+
pass
|
| 48 |
+
return None
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
def _build_ragas_llm(backend: str):
|
| 52 |
+
"""Build a ragas-compatible LLM wrapper."""
|
| 53 |
+
if backend == "openai":
|
| 54 |
+
from langchain_openai import ChatOpenAI
|
| 55 |
+
return ChatOpenAI(model="gpt-3.5-turbo", temperature=0)
|
| 56 |
+
elif backend == "ollama":
|
| 57 |
+
from langchain_community.chat_models import ChatOllama
|
| 58 |
+
host = os.getenv("OLLAMA_HOST", "http://localhost:11434")
|
| 59 |
+
model = os.getenv("OLLAMA_MODEL", "mistral")
|
| 60 |
+
return ChatOllama(base_url=host, model=model)
|
| 61 |
+
raise ValueError(f"Unknown backend: {backend}")
|
| 62 |
+
|
| 63 |
+
|
| 64 |
+
def _build_ragas_embeddings(backend: str):
|
| 65 |
+
"""Build a ragas-compatible embeddings wrapper."""
|
| 66 |
+
if backend == "openai":
|
| 67 |
+
from langchain_openai import OpenAIEmbeddings
|
| 68 |
+
return OpenAIEmbeddings()
|
| 69 |
+
elif backend == "ollama":
|
| 70 |
+
from langchain_community.embeddings import OllamaEmbeddings
|
| 71 |
+
host = os.getenv("OLLAMA_HOST", "http://localhost:11434")
|
| 72 |
+
model = os.getenv("OLLAMA_EMBED_MODEL", "nomic-embed-text")
|
| 73 |
+
return OllamaEmbeddings(base_url=host, model=model)
|
| 74 |
+
raise ValueError(f"Unknown backend: {backend}")
|
| 75 |
+
|
| 76 |
+
|
| 77 |
+
# ---------------------------------------------------------------------------
|
| 78 |
+
# Public API
|
| 79 |
+
# ---------------------------------------------------------------------------
|
| 80 |
+
|
| 81 |
+
def score_ragas(
|
| 82 |
+
question: str,
|
| 83 |
+
answer: str,
|
| 84 |
+
context_docs: list[str],
|
| 85 |
+
max_contexts: int = 3,
|
| 86 |
+
) -> EvalResult:
|
| 87 |
+
"""
|
| 88 |
+
Compute RAGAS faithfulness and answer_relevancy scores.
|
| 89 |
+
|
| 90 |
+
Args:
|
| 91 |
+
question : Original user question.
|
| 92 |
+
answer : LLM-generated answer.
|
| 93 |
+
context_docs : Retrieved context passages.
|
| 94 |
+
max_contexts : Max context chunks to pass to RAGAS (to limit token cost).
|
| 95 |
+
|
| 96 |
+
Returns:
|
| 97 |
+
EvalResult with module_name="ragas", score in [0,1].
|
| 98 |
+
score = mean(faithfulness, answer_relevancy).
|
| 99 |
+
Returns score=0.5 (neutral) if no LLM backend is available.
|
| 100 |
+
"""
|
| 101 |
+
t0 = time.perf_counter()
|
| 102 |
+
|
| 103 |
+
backend = _detect_llm_backend()
|
| 104 |
+
if backend is None:
|
| 105 |
+
logger.warning(
|
| 106 |
+
"No LLM backend available for RAGAS. "
|
| 107 |
+
"Set OPENAI_API_KEY or start Ollama (ollama serve). "
|
| 108 |
+
"Returning neutral score (0.5)."
|
| 109 |
+
)
|
| 110 |
+
return EvalResult(
|
| 111 |
+
module_name="ragas",
|
| 112 |
+
score=0.5,
|
| 113 |
+
details={
|
| 114 |
+
"backend": None,
|
| 115 |
+
"faithfulness": None,
|
| 116 |
+
"answer_relevancy": None,
|
| 117 |
+
"note": "No LLM backend — set OPENAI_API_KEY or start Ollama",
|
| 118 |
+
},
|
| 119 |
+
latency_ms=int((time.perf_counter() - t0) * 1000),
|
| 120 |
+
)
|
| 121 |
+
|
| 122 |
+
try:
|
| 123 |
+
from datasets import Dataset
|
| 124 |
+
from ragas import evaluate
|
| 125 |
+
from ragas.metrics import faithfulness, answer_relevancy
|
| 126 |
+
|
| 127 |
+
llm = _build_ragas_llm(backend)
|
| 128 |
+
embeddings = _build_ragas_embeddings(backend)
|
| 129 |
+
|
| 130 |
+
# Configure metrics to use our chosen backend
|
| 131 |
+
faithfulness.llm = llm
|
| 132 |
+
faithfulness.embeddings = embeddings
|
| 133 |
+
answer_relevancy.llm = llm
|
| 134 |
+
answer_relevancy.embeddings = embeddings
|
| 135 |
+
|
| 136 |
+
contexts = context_docs[:max_contexts]
|
| 137 |
+
dataset = Dataset.from_dict(
|
| 138 |
+
{
|
| 139 |
+
"question": [question],
|
| 140 |
+
"answer": [answer],
|
| 141 |
+
"contexts": [contexts],
|
| 142 |
+
}
|
| 143 |
+
)
|
| 144 |
+
|
| 145 |
+
result = evaluate(dataset, metrics=[faithfulness, answer_relevancy])
|
| 146 |
+
|
| 147 |
+
faith_score = float(result["faithfulness"])
|
| 148 |
+
relevancy_score = float(result["answer_relevancy"])
|
| 149 |
+
composite = (faith_score + relevancy_score) / 2.0
|
| 150 |
+
|
| 151 |
+
details = {
|
| 152 |
+
"backend": backend,
|
| 153 |
+
"faithfulness": round(faith_score, 4),
|
| 154 |
+
"answer_relevancy": round(relevancy_score, 4),
|
| 155 |
+
}
|
| 156 |
+
|
| 157 |
+
latency_ms = int((time.perf_counter() - t0) * 1000)
|
| 158 |
+
logger.info(
|
| 159 |
+
"RAGAS: faith=%.3f, relevancy=%.3f → composite=%.3f in %d ms",
|
| 160 |
+
faith_score, relevancy_score, composite, latency_ms,
|
| 161 |
+
)
|
| 162 |
+
return EvalResult(
|
| 163 |
+
module_name="ragas",
|
| 164 |
+
score=composite,
|
| 165 |
+
details=details,
|
| 166 |
+
latency_ms=latency_ms,
|
| 167 |
+
)
|
| 168 |
+
|
| 169 |
+
except Exception as exc:
|
| 170 |
+
logger.error("RAGAS evaluation failed: %s", exc)
|
| 171 |
+
return EvalResult(
|
| 172 |
+
module_name="ragas",
|
| 173 |
+
score=0.5,
|
| 174 |
+
details={"backend": backend, "error": str(exc)},
|
| 175 |
+
error=str(exc),
|
| 176 |
+
latency_ms=int((time.perf_counter() - t0) * 1000),
|
| 177 |
+
)
|
src/modules/__init__.py
ADDED
|
@@ -0,0 +1,127 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
src/modules/base.py — Shared EvalResult dataclass.
|
| 3 |
+
Used as the standard output schema by all 4 evaluation modules.
|
| 4 |
+
Details shape per module is fully specified here (SRS Section 5).
|
| 5 |
+
"""
|
| 6 |
+
from __future__ import annotations
|
| 7 |
+
|
| 8 |
+
import logging
|
| 9 |
+
from dataclasses import dataclass, field
|
| 10 |
+
from typing import Any, Optional
|
| 11 |
+
|
| 12 |
+
logger = logging.getLogger(__name__)
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
@dataclass
|
| 16 |
+
class EvalResult:
|
| 17 |
+
"""
|
| 18 |
+
Shared output schema for all evaluation modules.
|
| 19 |
+
|
| 20 |
+
Attributes:
|
| 21 |
+
module_name : Identifier string, e.g. "faithfulness"
|
| 22 |
+
score : Module score in [0.0, 1.0] — clipped automatically
|
| 23 |
+
details : Module-specific dict (see DETAILS SHAPES below)
|
| 24 |
+
error : None if successful; error message string if module failed
|
| 25 |
+
latency_ms : Wall-clock milliseconds for this module's execution
|
| 26 |
+
"""
|
| 27 |
+
|
| 28 |
+
module_name: str
|
| 29 |
+
score: float
|
| 30 |
+
details: dict[str, Any] = field(default_factory=dict)
|
| 31 |
+
error: Optional[str] = None
|
| 32 |
+
latency_ms: int = 0
|
| 33 |
+
|
| 34 |
+
def __post_init__(self) -> None:
|
| 35 |
+
"""Clip score to [0.0, 1.0] as required by SRS 4.2."""
|
| 36 |
+
if not (0.0 <= self.score <= 1.0):
|
| 37 |
+
logger.warning(
|
| 38 |
+
"%s: score %.4f out of [0,1], clipping.",
|
| 39 |
+
self.module_name,
|
| 40 |
+
self.score,
|
| 41 |
+
)
|
| 42 |
+
self.score = max(0.0, min(1.0, self.score))
|
| 43 |
+
|
| 44 |
+
# -------------------------------------------------------------------------
|
| 45 |
+
# DETAILS SHAPE REFERENCE (SRS Section 5)
|
| 46 |
+
# -------------------------------------------------------------------------
|
| 47 |
+
#
|
| 48 |
+
# faithfulness.details:
|
| 49 |
+
# {
|
| 50 |
+
# "total_claims": int,
|
| 51 |
+
# "entailed_count": int,
|
| 52 |
+
# "neutral_count": int,
|
| 53 |
+
# "contradicted_count": int,
|
| 54 |
+
# "claims": [
|
| 55 |
+
# {
|
| 56 |
+
# "claim": str,
|
| 57 |
+
# "status": "ENTAILED" | "NEUTRAL" | "CONTRADICTED",
|
| 58 |
+
# "best_chunk_id": str, # chunk with highest NLI score
|
| 59 |
+
# "nli_score": float
|
| 60 |
+
# }
|
| 61 |
+
# ]
|
| 62 |
+
# }
|
| 63 |
+
#
|
| 64 |
+
# entity_verifier.details:
|
| 65 |
+
# {
|
| 66 |
+
# "total_entities": int,
|
| 67 |
+
# "verified_count": int,
|
| 68 |
+
# "flagged_count": int,
|
| 69 |
+
# "entities": [
|
| 70 |
+
# {
|
| 71 |
+
# "entity": str,
|
| 72 |
+
# "type": "DRUG" | "DOSAGE" | "CONDITION" | "PROCEDURE",
|
| 73 |
+
# "status": "VERIFIED" | "FLAGGED" | "NOT_FOUND",
|
| 74 |
+
# "severity": "CRITICAL" | "MODERATE" | "MINOR" | null,
|
| 75 |
+
# "answer_value": str,
|
| 76 |
+
# "context_value": str | null,
|
| 77 |
+
# "rxcui": str | null
|
| 78 |
+
# }
|
| 79 |
+
# ]
|
| 80 |
+
# }
|
| 81 |
+
#
|
| 82 |
+
# source_credibility.details:
|
| 83 |
+
# {
|
| 84 |
+
# "method_used": "keyword" | "metadata",
|
| 85 |
+
# "chunks": [
|
| 86 |
+
# {
|
| 87 |
+
# "chunk_id": str,
|
| 88 |
+
# "tier": int, # 1–5
|
| 89 |
+
# "tier_weight": float,
|
| 90 |
+
# "pub_type": str,
|
| 91 |
+
# "title": str,
|
| 92 |
+
# "matched_keyword": str | null
|
| 93 |
+
# }
|
| 94 |
+
# ]
|
| 95 |
+
# }
|
| 96 |
+
#
|
| 97 |
+
# contradiction.details:
|
| 98 |
+
# {
|
| 99 |
+
# "total_sentences": int,
|
| 100 |
+
# "checked_pairs": int,
|
| 101 |
+
# "contradicted_pairs": int,
|
| 102 |
+
# "pairs": [
|
| 103 |
+
# {
|
| 104 |
+
# "sentence_a": str,
|
| 105 |
+
# "sentence_b": str,
|
| 106 |
+
# "contradiction_score": float,
|
| 107 |
+
# "flagged": bool
|
| 108 |
+
# }
|
| 109 |
+
# ]
|
| 110 |
+
# }
|
| 111 |
+
#
|
| 112 |
+
# aggregator.details:
|
| 113 |
+
# {
|
| 114 |
+
# "weights_used": {
|
| 115 |
+
# "faithfulness": float,
|
| 116 |
+
# "entity_accuracy": float,
|
| 117 |
+
# "source_credibility": float,
|
| 118 |
+
# "contradiction_risk": float
|
| 119 |
+
# },
|
| 120 |
+
# "weighted_composite": float,
|
| 121 |
+
# "component_contributions": {
|
| 122 |
+
# "faithfulness_contribution": float,
|
| 123 |
+
# "entity_contribution": float,
|
| 124 |
+
# "source_contribution": float,
|
| 125 |
+
# "contradiction_contribution": float
|
| 126 |
+
# }
|
| 127 |
+
# }
|
src/modules/base.py
ADDED
|
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""src/modules/base.py — see __init__.py of this package for EvalResult."""
|
| 2 |
+
from src.modules import EvalResult # re-export for convenience
|
| 3 |
+
|
| 4 |
+
__all__ = ["EvalResult"]
|
src/modules/contradiction.py
ADDED
|
@@ -0,0 +1,259 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
FR-17: src/modules/contradiction.py — Module 4: Cross-Document Contradiction Detection
|
| 3 |
+
========================================================================================
|
| 4 |
+
Uses the same DeBERTa NLI cross-encoder (cross-encoder/nli-deberta-v3-small) to
|
| 5 |
+
detect contradictions between the LLM answer and retrieved context passages.
|
| 6 |
+
|
| 7 |
+
Algorithm (SRS Section 6.4):
|
| 8 |
+
1. Split answer into sentences (claims)
|
| 9 |
+
2. Split each context chunk into sentences
|
| 10 |
+
3. For each (answer_sentence, context_sentence) pair:
|
| 11 |
+
- Run NLI → get contradiction score
|
| 12 |
+
- If contradiction_score ≥ CONTRADICTION_THRESHOLD → flag pair
|
| 13 |
+
4. score = 1.0 - (flagged_pairs / total_pairs)
|
| 14 |
+
|
| 15 |
+
This module shares the NLI model instance with faithfulness.py when both
|
| 16 |
+
run in the same process (the model is cached at the faithfulness module level).
|
| 17 |
+
|
| 18 |
+
Design note:
|
| 19 |
+
To keep latency manageable, context sentences are limited to
|
| 20 |
+
MAX_CONTEXT_SENTS per chunk and total pairs are capped at MAX_PAIRS.
|
| 21 |
+
"""
|
| 22 |
+
from __future__ import annotations
|
| 23 |
+
|
| 24 |
+
import logging
|
| 25 |
+
import time
|
| 26 |
+
|
| 27 |
+
import pysbd
|
| 28 |
+
|
| 29 |
+
from src.modules.base import EvalResult
|
| 30 |
+
|
| 31 |
+
logger = logging.getLogger(__name__)
|
| 32 |
+
|
| 33 |
+
# ---------------------------------------------------------------------------
|
| 34 |
+
# Constants
|
| 35 |
+
# ---------------------------------------------------------------------------
|
| 36 |
+
|
| 37 |
+
CONTRADICTION_THRESHOLD = 0.50 # Balanced: catches real contradictions without over-flagging
|
| 38 |
+
MIN_KEYWORD_OVERLAP = 3 # At least 3 meaningful words in common before running NLI
|
| 39 |
+
MAX_CONTEXT_SENTS = 4 # top N sentences per context chunk
|
| 40 |
+
MAX_PAIRS = 200 # hard cap to keep latency bounded (~2-3s)
|
| 41 |
+
|
| 42 |
+
_segmenter = None
|
| 43 |
+
|
| 44 |
+
# Common stopwords to ignore in overlap check
|
| 45 |
+
_STOPWORDS = {
|
| 46 |
+
"the", "a", "an", "is", "in", "of", "to", "for", "and", "or", "are",
|
| 47 |
+
"be", "at", "by", "if", "it", "as", "on", "with", "this", "that",
|
| 48 |
+
"was", "were", "not", "no", "have", "has", "had", "but", "so", "from",
|
| 49 |
+
"should", "may", "can", "will", "than", "more", "when", "which", "who",
|
| 50 |
+
"what", "all", "each", "after", "before", "been", "do", "does", "1",
|
| 51 |
+
"2", "3", "mg", "iv", "od", "per", "day", "based", "using", "include",
|
| 52 |
+
}
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
def _get_segmenter():
|
| 56 |
+
"""Lazily load and return the pysbd segmenter."""
|
| 57 |
+
global _segmenter
|
| 58 |
+
if _segmenter is None:
|
| 59 |
+
try:
|
| 60 |
+
import pysbd
|
| 61 |
+
_segmenter = pysbd.Segmenter(language="en", clean=False)
|
| 62 |
+
except ImportError:
|
| 63 |
+
logger.warning("pysbd not installed, falling back to naive sentence splitting.")
|
| 64 |
+
_segmenter = "stub" # Use a string to indicate stub mode
|
| 65 |
+
except Exception as e:
|
| 66 |
+
logger.error("Failed to initialize pysbd segmenter: %s", e)
|
| 67 |
+
_segmenter = "stub"
|
| 68 |
+
return _segmenter
|
| 69 |
+
|
| 70 |
+
|
| 71 |
+
def _keyword_overlap(sent_a: str, sent_b: str) -> int:
|
| 72 |
+
"""Count shared content words between two sentences."""
|
| 73 |
+
tokens_a = {w.lower() for w in sent_a.split() if w.lower() not in _STOPWORDS and len(w) > 2}
|
| 74 |
+
tokens_b = {w.lower() for w in sent_b.split() if w.lower() not in _STOPWORDS and len(w) > 2}
|
| 75 |
+
return len(tokens_a & tokens_b)
|
| 76 |
+
|
| 77 |
+
|
| 78 |
+
def _segment(text: str) -> list[str]:
|
| 79 |
+
"""Segment text into sentences using pysbd or a fallback."""
|
| 80 |
+
seg = _get_segmenter()
|
| 81 |
+
try:
|
| 82 |
+
if seg == "stub":
|
| 83 |
+
return [s.strip() for s in text.split(".") if s.strip()]
|
| 84 |
+
else:
|
| 85 |
+
return [s.strip() for s in seg.segment(text) if s.strip()]
|
| 86 |
+
except Exception:
|
| 87 |
+
return [s.strip() for s in text.split(".") if s.strip()]
|
| 88 |
+
|
| 89 |
+
|
| 90 |
+
# ---------------------------------------------------------------------------
|
| 91 |
+
# Public API
|
| 92 |
+
# ---------------------------------------------------------------------------
|
| 93 |
+
|
| 94 |
+
def score_contradiction(
|
| 95 |
+
answer: str,
|
| 96 |
+
context_docs: list[str],
|
| 97 |
+
max_chunks: int = 5,
|
| 98 |
+
) -> EvalResult:
|
| 99 |
+
"""
|
| 100 |
+
Detect contradictions between the LLM answer and retrieved context.
|
| 101 |
+
|
| 102 |
+
Args:
|
| 103 |
+
answer : LLM-generated answer text.
|
| 104 |
+
context_docs : List of retrieved context passage strings.
|
| 105 |
+
max_chunks : Max number of context chunks to evaluate.
|
| 106 |
+
|
| 107 |
+
Returns:
|
| 108 |
+
EvalResult with module_name="contradiction", score in [0,1] where
|
| 109 |
+
1.0 = no contradictions detected, 0.0 = all pairs contradicted.
|
| 110 |
+
"""
|
| 111 |
+
t0 = time.perf_counter()
|
| 112 |
+
|
| 113 |
+
if not answer or not context_docs:
|
| 114 |
+
return EvalResult(
|
| 115 |
+
module_name="contradiction",
|
| 116 |
+
score=0.5, # neutral — cannot verify with missing input
|
| 117 |
+
details={"total_sentences": 0, "checked_pairs": 0, "contradicted_pairs": 0, "pairs": []},
|
| 118 |
+
latency_ms=0,
|
| 119 |
+
)
|
| 120 |
+
|
| 121 |
+
# Import model via faithfulness module (shared cache)
|
| 122 |
+
try:
|
| 123 |
+
from src.modules.faithfulness import _get_model, LABEL_CONTRADICTION
|
| 124 |
+
except ImportError:
|
| 125 |
+
# (Lazy imports to prevent startup crashes when libraries aren't installed yet)
|
| 126 |
+
try:
|
| 127 |
+
from sentence_transformers import CrossEncoder
|
| 128 |
+
_model = CrossEncoder("cross-encoder/nli-deberta-v3-small")
|
| 129 |
+
_get_model = lambda: _model # noqa: E731
|
| 130 |
+
LABEL_CONTRADICTION = 0
|
| 131 |
+
except ImportError:
|
| 132 |
+
logger.error("sentence-transformers not installed. Cannot run NLI model.")
|
| 133 |
+
return EvalResult(
|
| 134 |
+
module_name="contradiction",
|
| 135 |
+
score=1.0,
|
| 136 |
+
details={},
|
| 137 |
+
error="NLI model (sentence-transformers) not installed.",
|
| 138 |
+
latency_ms=int((time.perf_counter() - t0) * 1000),
|
| 139 |
+
)
|
| 140 |
+
except Exception as exc:
|
| 141 |
+
logger.error("Failed to load NLI model: %s", exc)
|
| 142 |
+
return EvalResult(
|
| 143 |
+
module_name="contradiction",
|
| 144 |
+
score=1.0,
|
| 145 |
+
details={},
|
| 146 |
+
error=f"Failed to load NLI model: {exc}",
|
| 147 |
+
latency_ms=int((time.perf_counter() - t0) * 1000),
|
| 148 |
+
)
|
| 149 |
+
|
| 150 |
+
model = _get_model()
|
| 151 |
+
|
| 152 |
+
# Strip markdown/citations from answer before NLI (same reason as faithfulness.py)
|
| 153 |
+
import re as _re
|
| 154 |
+
_MD = _re.compile(
|
| 155 |
+
r'\[Source:[^\]]*\]|\[[^\]]{0,120}\]' # citations
|
| 156 |
+
r'|\*\*([^*]+)\*\*|\*([^*]+)\*' # bold/italic → keep text
|
| 157 |
+
r'|`[^`]+`' # code
|
| 158 |
+
)
|
| 159 |
+
answer = _MD.sub(lambda m: (m.group(1) or m.group(2) or ''), answer).strip()
|
| 160 |
+
|
| 161 |
+
# Segment answer into claims
|
| 162 |
+
answer_sents = _segment(answer)
|
| 163 |
+
if not answer_sents:
|
| 164 |
+
return EvalResult(
|
| 165 |
+
module_name="contradiction",
|
| 166 |
+
score=0.5, # neutral — cannot verify with no sentences
|
| 167 |
+
details={"total_sentences": 0, "checked_pairs": 0, "contradicted_pairs": 0, "pairs": []},
|
| 168 |
+
latency_ms=0,
|
| 169 |
+
)
|
| 170 |
+
|
| 171 |
+
# Segment context chunks
|
| 172 |
+
docs = context_docs[:max_chunks]
|
| 173 |
+
context_sents: list[str] = []
|
| 174 |
+
for doc in docs:
|
| 175 |
+
sents = _segment(doc)[:MAX_CONTEXT_SENTS]
|
| 176 |
+
context_sents.extend(sents)
|
| 177 |
+
|
| 178 |
+
if not context_sents:
|
| 179 |
+
return EvalResult(
|
| 180 |
+
module_name="contradiction",
|
| 181 |
+
score=0.5, # neutral — cannot verify with no context sentences
|
| 182 |
+
details={"total_sentences": len(answer_sents), "checked_pairs": 0, "contradicted_pairs": 0, "pairs": []},
|
| 183 |
+
latency_ms=0,
|
| 184 |
+
)
|
| 185 |
+
|
| 186 |
+
# Build pairs WITH topical pre-filter (skip unrelated sentence pairs entirely)
|
| 187 |
+
all_pairs: list[tuple[str, str]] = []
|
| 188 |
+
for a_sent in answer_sents:
|
| 189 |
+
for c_sent in context_sents:
|
| 190 |
+
if _keyword_overlap(a_sent, c_sent) >= MIN_KEYWORD_OVERLAP:
|
| 191 |
+
all_pairs.append((a_sent, c_sent))
|
| 192 |
+
if len(all_pairs) >= MAX_PAIRS:
|
| 193 |
+
break
|
| 194 |
+
if len(all_pairs) >= MAX_PAIRS:
|
| 195 |
+
break
|
| 196 |
+
|
| 197 |
+
if not all_pairs:
|
| 198 |
+
# Topically unrelated — cannot check for contradictions
|
| 199 |
+
return EvalResult(
|
| 200 |
+
module_name="contradiction",
|
| 201 |
+
score=0.5, # neutral — no overlapping pairs to evaluate
|
| 202 |
+
details={"total_sentences": len(answer_sents), "checked_pairs": 0, "contradicted_pairs": 0, "pairs": []},
|
| 203 |
+
latency_ms=int((time.perf_counter() - t0) * 1000),
|
| 204 |
+
)
|
| 205 |
+
|
| 206 |
+
# Batch NLI inference
|
| 207 |
+
try:
|
| 208 |
+
scores_raw = model.predict(all_pairs, apply_softmax=True)
|
| 209 |
+
except Exception as exc:
|
| 210 |
+
logger.error("Contradiction NLI inference failed: %s", exc)
|
| 211 |
+
return EvalResult(
|
| 212 |
+
module_name="contradiction",
|
| 213 |
+
score=1.0,
|
| 214 |
+
details={},
|
| 215 |
+
error=f"Model inference error: {exc}",
|
| 216 |
+
latency_ms=int((time.perf_counter() - t0) * 1000),
|
| 217 |
+
)
|
| 218 |
+
|
| 219 |
+
# Collect flagged pairs
|
| 220 |
+
pair_details: list[dict] = []
|
| 221 |
+
contradicted = 0
|
| 222 |
+
total = len(all_pairs)
|
| 223 |
+
|
| 224 |
+
for i, (a_sent, c_sent) in enumerate(all_pairs):
|
| 225 |
+
con_score = float(scores_raw[i][LABEL_CONTRADICTION])
|
| 226 |
+
flagged = con_score >= CONTRADICTION_THRESHOLD
|
| 227 |
+
if flagged:
|
| 228 |
+
contradicted += 1
|
| 229 |
+
# Only log the most severe contradictions to keep details manageable
|
| 230 |
+
pair_details.append(
|
| 231 |
+
{
|
| 232 |
+
"sentence_a": a_sent[:120],
|
| 233 |
+
"sentence_b": c_sent[:120],
|
| 234 |
+
"contradiction_score": round(con_score, 4),
|
| 235 |
+
"flagged": True,
|
| 236 |
+
}
|
| 237 |
+
)
|
| 238 |
+
|
| 239 |
+
# Score: 1.0 = clean, lower = more contradictions found
|
| 240 |
+
score = 1.0 - (contradicted / total) if total > 0 else 1.0
|
| 241 |
+
|
| 242 |
+
details = {
|
| 243 |
+
"total_sentences": len(answer_sents),
|
| 244 |
+
"checked_pairs": total,
|
| 245 |
+
"contradicted_pairs": contradicted,
|
| 246 |
+
"pairs": pair_details[:20], # cap output to top 20 flagged pairs
|
| 247 |
+
}
|
| 248 |
+
|
| 249 |
+
latency_ms = int((time.perf_counter() - t0) * 1000)
|
| 250 |
+
logger.info(
|
| 251 |
+
"Contradiction: %.3f (%d/%d pairs flagged) in %d ms",
|
| 252 |
+
score, contradicted, total, latency_ms,
|
| 253 |
+
)
|
| 254 |
+
return EvalResult(
|
| 255 |
+
module_name="contradiction",
|
| 256 |
+
score=score,
|
| 257 |
+
details=details,
|
| 258 |
+
latency_ms=latency_ms,
|
| 259 |
+
)
|
src/modules/entity_verifier.py
ADDED
|
@@ -0,0 +1,334 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
FR-09: src/modules/entity_verifier.py — Module 2: Medical Entity Verification
|
| 3 |
+
==============================================================================
|
| 4 |
+
Uses SciSpaCy NER (en_core_sci_lg) to extract medical entities from the answer,
|
| 5 |
+
then verifies drug entities against the RxNorm cache and/or REST API.
|
| 6 |
+
|
| 7 |
+
Verification pipeline (SRS Section 6.2):
|
| 8 |
+
1. NER: extract DRUG, DOSAGE, CONDITION, PROCEDURE entities from answer
|
| 9 |
+
2. For each DRUG entity:
|
| 10 |
+
a. Look up in local rxnorm_cache.csv (fast, offline)
|
| 11 |
+
b. If not found, query RxNorm REST API /approximateTerm (live fallback)
|
| 12 |
+
c. If still not found, mark as NOT_FOUND
|
| 13 |
+
3. Cross-check entity presence in context docs (optional validation)
|
| 14 |
+
4. Score = verified_drug_count / total_drug_count (non-drug entities have no score impact)
|
| 15 |
+
|
| 16 |
+
Entity status values:
|
| 17 |
+
VERIFIED — drug found in RxNorm cache or API with rxcui
|
| 18 |
+
FLAGGED — entity found but has a known dangerous synonym conflict
|
| 19 |
+
NOT_FOUND — drug name not resolvable via any layer
|
| 20 |
+
|
| 21 |
+
Severity mapping (for FLAGGED):
|
| 22 |
+
brand ↔ generic mismatch → CRITICAL
|
| 23 |
+
dosage discrepancy → MODERATE
|
| 24 |
+
minor synonym variant → MINOR
|
| 25 |
+
"""
|
| 26 |
+
from __future__ import annotations
|
| 27 |
+
|
| 28 |
+
import logging
|
| 29 |
+
import re
|
| 30 |
+
import time
|
| 31 |
+
from functools import lru_cache
|
| 32 |
+
from pathlib import Path
|
| 33 |
+
from typing import Optional
|
| 34 |
+
|
| 35 |
+
import pandas as pd
|
| 36 |
+
import requests
|
| 37 |
+
|
| 38 |
+
from src.modules.base import EvalResult
|
| 39 |
+
|
| 40 |
+
logger = logging.getLogger(__name__)
|
| 41 |
+
|
| 42 |
+
# ---------------------------------------------------------------------------
|
| 43 |
+
# Constants
|
| 44 |
+
# ---------------------------------------------------------------------------
|
| 45 |
+
|
| 46 |
+
RXNORM_APPROX_URL = "https://rxnav.nlm.nih.gov/REST/approximateTerm.json"
|
| 47 |
+
DEFAULT_CACHE_PATH = "data/rxnorm_cache.csv"
|
| 48 |
+
NER_MODEL = "en_ner_bc5cdr_md"
|
| 49 |
+
DOSAGE_TOLERANCE_PCT = 10 # flag if answer dose differs from context dose by > 10%
|
| 50 |
+
|
| 51 |
+
# Matches clinical dose values: "500 mg", "2.5 mcg/kg", "10 IU", etc.
|
| 52 |
+
_DOSE_RE = re.compile(
|
| 53 |
+
r'(\d+(?:\.\d+)?)\s*(?:mg|mcg|g\b|ml|iu|units?|mg/kg|mg/dl)',
|
| 54 |
+
re.IGNORECASE,
|
| 55 |
+
)
|
| 56 |
+
|
| 57 |
+
# Map spacy entity labels to our schema types
|
| 58 |
+
_ENTITY_TYPE_MAP = {
|
| 59 |
+
# en_core_sci_lg (CRAFT corpus) labels
|
| 60 |
+
"CHEBI": "DRUG", # Chemical Entities of Biological Interest — covers drugs
|
| 61 |
+
"GGP": "CONDITION", # Gene or Gene Product
|
| 62 |
+
"SO": "CONDITION", # Sequence Ontology
|
| 63 |
+
"TAXON": "CONDITION",
|
| 64 |
+
"GO": "CONDITION", # Gene Ontology
|
| 65 |
+
"CL": "CONDITION", # Cell Line
|
| 66 |
+
"DNA": "CONDITION",
|
| 67 |
+
"RNA": "CONDITION",
|
| 68 |
+
"CELL_TYPE": "CONDITION",
|
| 69 |
+
"CELL_LINE": "CONDITION",
|
| 70 |
+
"PROTEIN": "CONDITION",
|
| 71 |
+
# BC5CDR labels (used by some scispacy models)
|
| 72 |
+
"Chemical": "DRUG",
|
| 73 |
+
"Disease": "CONDITION",
|
| 74 |
+
# Generic / fallback labels
|
| 75 |
+
"CHEMICAL": "DRUG",
|
| 76 |
+
"DRUG": "DRUG",
|
| 77 |
+
"COMPOUND": "DRUG",
|
| 78 |
+
"DISEASE": "CONDITION",
|
| 79 |
+
"SYMPTOM": "CONDITION",
|
| 80 |
+
"PROCEDURE": "PROCEDURE",
|
| 81 |
+
"DOSAGE": "DOSAGE",
|
| 82 |
+
}
|
| 83 |
+
DRUG_TYPES = {"DRUG"} # only these get verified against RxNorm
|
| 84 |
+
|
| 85 |
+
# ---------------------------------------------------------------------------
|
| 86 |
+
# Module-level caches
|
| 87 |
+
# ---------------------------------------------------------------------------
|
| 88 |
+
|
| 89 |
+
_spacy_model = None
|
| 90 |
+
_rxnorm_cache: dict[str, str] | None = None # drug_name -> rxcui
|
| 91 |
+
_rxnorm_cache_path: str = DEFAULT_CACHE_PATH
|
| 92 |
+
|
| 93 |
+
|
| 94 |
+
def _get_spacy_model():
|
| 95 |
+
global _spacy_model
|
| 96 |
+
if _spacy_model is None:
|
| 97 |
+
import spacy
|
| 98 |
+
logger.info("Loading SciSpaCy NER model: %s (first call only)", NER_MODEL)
|
| 99 |
+
try:
|
| 100 |
+
_spacy_model = spacy.load(NER_MODEL)
|
| 101 |
+
logger.info("SciSpaCy model loaded.")
|
| 102 |
+
except OSError as exc:
|
| 103 |
+
logger.error(
|
| 104 |
+
"Failed to load '%s'. Install with: "
|
| 105 |
+
"pip install https://s3-us-west-2.amazonaws.com/ai2-s2-scispacy/"
|
| 106 |
+
"releases/v0.5.4/en_ner_bc5cdr_md-0.5.4.tar.gz\nError: %s",
|
| 107 |
+
NER_MODEL, exc,
|
| 108 |
+
)
|
| 109 |
+
raise
|
| 110 |
+
return _spacy_model
|
| 111 |
+
|
| 112 |
+
|
| 113 |
+
def _load_rxnorm_cache(cache_path: str) -> dict[str, str]:
|
| 114 |
+
"""Load the RxNorm cache CSV into a lowercase drug_name → rxcui dict."""
|
| 115 |
+
path = Path(cache_path)
|
| 116 |
+
if not path.exists():
|
| 117 |
+
logger.warning(
|
| 118 |
+
"RxNorm cache not found at '%s'. Live API only will be used.", cache_path
|
| 119 |
+
)
|
| 120 |
+
return {}
|
| 121 |
+
try:
|
| 122 |
+
df = pd.read_csv(path, dtype=str)
|
| 123 |
+
cache = {
|
| 124 |
+
str(row["drug_name"]).strip().lower(): str(row["rxcui"]).strip()
|
| 125 |
+
for _, row in df.iterrows()
|
| 126 |
+
if pd.notna(row.get("drug_name")) and pd.notna(row.get("rxcui"))
|
| 127 |
+
and str(row.get("rxcui", "")).strip()
|
| 128 |
+
}
|
| 129 |
+
logger.info("RxNorm cache loaded: %d entries from %s", len(cache), cache_path)
|
| 130 |
+
return cache
|
| 131 |
+
except Exception as exc:
|
| 132 |
+
logger.warning("Failed to load RxNorm cache: %s", exc)
|
| 133 |
+
return {}
|
| 134 |
+
|
| 135 |
+
|
| 136 |
+
def _get_rxnorm_cache(cache_path: str) -> dict[str, str]:
|
| 137 |
+
global _rxnorm_cache, _rxnorm_cache_path
|
| 138 |
+
if _rxnorm_cache is None or cache_path != _rxnorm_cache_path:
|
| 139 |
+
_rxnorm_cache_path = cache_path
|
| 140 |
+
_rxnorm_cache = _load_rxnorm_cache(cache_path)
|
| 141 |
+
return _rxnorm_cache
|
| 142 |
+
|
| 143 |
+
|
| 144 |
+
def _extract_doses_near(text: str, drug_name: str, window: int = 180) -> list[float]:
|
| 145 |
+
"""Return numeric dose values found within `window` chars of `drug_name` in `text`."""
|
| 146 |
+
idx = text.lower().find(drug_name.lower())
|
| 147 |
+
if idx == -1:
|
| 148 |
+
return []
|
| 149 |
+
vicinity = text[max(0, idx - window // 2): idx + len(drug_name) + window]
|
| 150 |
+
return [float(m.group(1)) for m in _DOSE_RE.finditer(vicinity)]
|
| 151 |
+
|
| 152 |
+
|
| 153 |
+
def _lookup_rxnorm_api(drug_name: str, timeout: int = 4) -> Optional[str]:
|
| 154 |
+
"""Query RxNorm REST API. Returns rxcui string or None."""
|
| 155 |
+
try:
|
| 156 |
+
resp = requests.get(
|
| 157 |
+
RXNORM_APPROX_URL,
|
| 158 |
+
params={"term": drug_name, "maxEntries": "1", "option": "1"},
|
| 159 |
+
timeout=timeout,
|
| 160 |
+
)
|
| 161 |
+
if resp.status_code != 200:
|
| 162 |
+
return None
|
| 163 |
+
candidates = (
|
| 164 |
+
resp.json()
|
| 165 |
+
.get("approximateGroup", {})
|
| 166 |
+
.get("candidate", [])
|
| 167 |
+
)
|
| 168 |
+
if candidates:
|
| 169 |
+
return str(candidates[0].get("rxcui", "")).strip() or None
|
| 170 |
+
except Exception:
|
| 171 |
+
pass
|
| 172 |
+
return None
|
| 173 |
+
|
| 174 |
+
|
| 175 |
+
# ---------------------------------------------------------------------------
|
| 176 |
+
# Public API
|
| 177 |
+
# ---------------------------------------------------------------------------
|
| 178 |
+
|
| 179 |
+
def verify_entities(
|
| 180 |
+
answer: str,
|
| 181 |
+
question: str = "",
|
| 182 |
+
context_docs: list[str] | None = None,
|
| 183 |
+
rxnorm_cache_path: str = DEFAULT_CACHE_PATH,
|
| 184 |
+
use_api_fallback: bool = True,
|
| 185 |
+
) -> EvalResult:
|
| 186 |
+
"""
|
| 187 |
+
Extract and verify medical entities from the LLM answer.
|
| 188 |
+
|
| 189 |
+
Args:
|
| 190 |
+
answer : LLM-generated answer text.
|
| 191 |
+
question : Original question (NER'd alongside answer for richer entity set).
|
| 192 |
+
context_docs : Retrieved context passages (used for cross-checking).
|
| 193 |
+
rxnorm_cache_path : Path to rxnorm_cache.csv.
|
| 194 |
+
use_api_fallback : Whether to call RxNorm REST API for cache misses.
|
| 195 |
+
|
| 196 |
+
Returns:
|
| 197 |
+
EvalResult with module_name="entity_verifier", score in [0,1], and
|
| 198 |
+
details matching the shape from src/modules/__init__.py.
|
| 199 |
+
"""
|
| 200 |
+
t0 = time.perf_counter()
|
| 201 |
+
|
| 202 |
+
# --- NER -----------------------------------------------------------------
|
| 203 |
+
try:
|
| 204 |
+
nlp = _get_spacy_model()
|
| 205 |
+
except Exception as exc:
|
| 206 |
+
return EvalResult(
|
| 207 |
+
module_name="entity_verifier",
|
| 208 |
+
score=0.5, # neutral fallback — don't penalise if model not available
|
| 209 |
+
details={"error": str(exc), "entities": []},
|
| 210 |
+
error=f"NER model unavailable: {exc}",
|
| 211 |
+
latency_ms=int((time.perf_counter() - t0) * 1000),
|
| 212 |
+
)
|
| 213 |
+
|
| 214 |
+
# Combine question + answer for richer entity extraction
|
| 215 |
+
combined_text = f"{question} {answer}" if question else answer
|
| 216 |
+
doc = nlp(combined_text)
|
| 217 |
+
|
| 218 |
+
# Collect entities with deduplication
|
| 219 |
+
seen: set[str] = set()
|
| 220 |
+
raw_entities: list[tuple[str, str]] = [] # (text, type)
|
| 221 |
+
for ent in doc.ents:
|
| 222 |
+
key = ent.text.strip().lower()
|
| 223 |
+
if not key or key in seen:
|
| 224 |
+
continue
|
| 225 |
+
seen.add(key)
|
| 226 |
+
entity_type = _ENTITY_TYPE_MAP.get(ent.label_, "CONDITION")
|
| 227 |
+
raw_entities.append((ent.text.strip(), entity_type))
|
| 228 |
+
|
| 229 |
+
if not raw_entities:
|
| 230 |
+
return EvalResult(
|
| 231 |
+
module_name="entity_verifier",
|
| 232 |
+
score=0.5, # neutral — cannot verify what isn't there
|
| 233 |
+
details={"total_entities": 0, "verified_count": 0, "flagged_count": 0, "entities": []},
|
| 234 |
+
latency_ms=int((time.perf_counter() - t0) * 1000),
|
| 235 |
+
)
|
| 236 |
+
|
| 237 |
+
# --- RxNorm verification for DRUG entities -------------------------------
|
| 238 |
+
cache = _get_rxnorm_cache(rxnorm_cache_path)
|
| 239 |
+
context_text = " ".join(context_docs or []).lower()
|
| 240 |
+
|
| 241 |
+
entity_results: list[dict] = []
|
| 242 |
+
drug_total = 0
|
| 243 |
+
drug_verified = 0
|
| 244 |
+
drug_flagged = 0
|
| 245 |
+
|
| 246 |
+
for entity_text, entity_type in raw_entities:
|
| 247 |
+
result = {
|
| 248 |
+
"entity": entity_text,
|
| 249 |
+
"type": entity_type,
|
| 250 |
+
"status": "NOT_FOUND",
|
| 251 |
+
"severity": None,
|
| 252 |
+
"answer_value": entity_text,
|
| 253 |
+
"context_value": None,
|
| 254 |
+
"rxcui": None,
|
| 255 |
+
}
|
| 256 |
+
|
| 257 |
+
if entity_type in DRUG_TYPES:
|
| 258 |
+
drug_total += 1
|
| 259 |
+
key = entity_text.lower()
|
| 260 |
+
|
| 261 |
+
# Layer 1: Local cache lookup
|
| 262 |
+
rxcui = cache.get(key)
|
| 263 |
+
|
| 264 |
+
# Layer 2: API fallback
|
| 265 |
+
if not rxcui and use_api_fallback:
|
| 266 |
+
rxcui = _lookup_rxnorm_api(entity_text)
|
| 267 |
+
|
| 268 |
+
if rxcui:
|
| 269 |
+
result["rxcui"] = rxcui
|
| 270 |
+
|
| 271 |
+
# Check for dosage discrepancy before marking VERIFIED
|
| 272 |
+
answer_doses = _extract_doses_near(answer, entity_text)
|
| 273 |
+
context_doses = _extract_doses_near(context_text, entity_text)
|
| 274 |
+
flagged_dose = False
|
| 275 |
+
if answer_doses and context_doses:
|
| 276 |
+
a_dose = answer_doses[0]
|
| 277 |
+
c_dose = min(context_doses, key=lambda d: abs(d - a_dose))
|
| 278 |
+
pct_diff = abs(a_dose - c_dose) / max(c_dose, 1e-9) * 100
|
| 279 |
+
if pct_diff > DOSAGE_TOLERANCE_PCT:
|
| 280 |
+
result["status"] = "FLAGGED"
|
| 281 |
+
result["severity"] = "MODERATE"
|
| 282 |
+
result["answer_value"] = f"{a_dose} (answer)"
|
| 283 |
+
result["context_value"] = f"{c_dose} (context, Δ{pct_diff:.0f}%)"
|
| 284 |
+
drug_flagged += 1
|
| 285 |
+
flagged_dose = True
|
| 286 |
+
logger.warning(
|
| 287 |
+
"Dosage discrepancy for '%s': answer=%.1f context=%.1f (%.0f%%)",
|
| 288 |
+
entity_text, a_dose, c_dose, pct_diff,
|
| 289 |
+
)
|
| 290 |
+
|
| 291 |
+
if not flagged_dose:
|
| 292 |
+
result["status"] = "VERIFIED"
|
| 293 |
+
drug_verified += 1
|
| 294 |
+
if key in context_text:
|
| 295 |
+
result["context_value"] = entity_text
|
| 296 |
+
else:
|
| 297 |
+
result["status"] = "NOT_FOUND"
|
| 298 |
+
|
| 299 |
+
elif entity_type in ("CONDITION", "PROCEDURE"):
|
| 300 |
+
# Non-drug entities: check presence in context only
|
| 301 |
+
if entity_text.lower() in context_text:
|
| 302 |
+
result["status"] = "VERIFIED"
|
| 303 |
+
result["context_value"] = entity_text
|
| 304 |
+
else:
|
| 305 |
+
result["status"] = "NOT_FOUND"
|
| 306 |
+
|
| 307 |
+
entity_results.append(result)
|
| 308 |
+
|
| 309 |
+
# --- Score ---------------------------------------------------------------
|
| 310 |
+
# Score is based on drug entities only (per SRS Section 6.2)
|
| 311 |
+
if drug_total == 0:
|
| 312 |
+
score = 0.5 # neutral — no drug entities to verify
|
| 313 |
+
else:
|
| 314 |
+
score = drug_verified / drug_total
|
| 315 |
+
|
| 316 |
+
details = {
|
| 317 |
+
"total_entities": len(raw_entities),
|
| 318 |
+
"drug_total": drug_total,
|
| 319 |
+
"verified_count": drug_verified,
|
| 320 |
+
"flagged_count": drug_flagged,
|
| 321 |
+
"entities": entity_results,
|
| 322 |
+
}
|
| 323 |
+
|
| 324 |
+
latency_ms = int((time.perf_counter() - t0) * 1000)
|
| 325 |
+
logger.info(
|
| 326 |
+
"Entity verification: %.3f (%d/%d drugs verified) in %d ms",
|
| 327 |
+
score, drug_verified, drug_total, latency_ms,
|
| 328 |
+
)
|
| 329 |
+
return EvalResult(
|
| 330 |
+
module_name="entity_verifier",
|
| 331 |
+
score=score,
|
| 332 |
+
details=details,
|
| 333 |
+
latency_ms=latency_ms,
|
| 334 |
+
)
|
src/modules/faithfulness.py
ADDED
|
@@ -0,0 +1,302 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
FR-05: src/modules/faithfulness.py — Module 1: Faithfulness Scoring
|
| 3 |
+
=====================================================================
|
| 4 |
+
Uses cross-encoder/nli-deberta-v3-small to score how well the LLM answer
|
| 5 |
+
is entailed by the retrieved context chunks.
|
| 6 |
+
|
| 7 |
+
Architecture:
|
| 8 |
+
1. Split answer into individual claims (sentences via pysbd)
|
| 9 |
+
2. For each claim: compute NLI score against every context chunk
|
| 10 |
+
3. Assign claim status: ENTAILED / NEUTRAL / CONTRADICTED
|
| 11 |
+
4. score = entailed_count / total_claims
|
| 12 |
+
|
| 13 |
+
Thresholds (SRS Section 6.1):
|
| 14 |
+
entailment ≥ 0.50 → ENTAILED
|
| 15 |
+
contradiction ≥ 0.30 → CONTRADICTED
|
| 16 |
+
otherwise → NEUTRAL
|
| 17 |
+
|
| 18 |
+
Model loaded lazily and cached at module level (avoids double-loading
|
| 19 |
+
when called multiple times in same process).
|
| 20 |
+
"""
|
| 21 |
+
from __future__ import annotations
|
| 22 |
+
|
| 23 |
+
import logging
|
| 24 |
+
import time
|
| 25 |
+
from functools import lru_cache
|
| 26 |
+
from typing import TYPE_CHECKING
|
| 27 |
+
|
| 28 |
+
from src.modules.base import EvalResult
|
| 29 |
+
|
| 30 |
+
if TYPE_CHECKING:
|
| 31 |
+
pass
|
| 32 |
+
|
| 33 |
+
logger = logging.getLogger(__name__)
|
| 34 |
+
|
| 35 |
+
# ---------------------------------------------------------------------------
|
| 36 |
+
# Constants
|
| 37 |
+
# ---------------------------------------------------------------------------
|
| 38 |
+
|
| 39 |
+
# BioLinkBERT fine-tuned on MedNLI (clinical notes, MIMIC-III)
|
| 40 |
+
# Paper 15 (Chen et al. SemEval-2023): best single model for biomedical NLI (F1=0.765)
|
| 41 |
+
# Faster on CPU than DeBERTa-large (BERT-base architecture)
|
| 42 |
+
MODEL_NAME = "cnut1648/biolinkbert-mednli"
|
| 43 |
+
|
| 44 |
+
# MedNLI label order (verified): {0: entailment, 1: neutral, 2: contradiction}
|
| 45 |
+
LABEL_ENTAILMENT = 0
|
| 46 |
+
LABEL_NEUTRAL = 1
|
| 47 |
+
LABEL_CONTRADICTION = 2
|
| 48 |
+
|
| 49 |
+
ENTAILMENT_THRESHOLD = 0.50
|
| 50 |
+
CONTRADICTION_THRESHOLD = 0.30
|
| 51 |
+
|
| 52 |
+
# ---------------------------------------------------------------------------
|
| 53 |
+
# Lazy model loader
|
| 54 |
+
# ---------------------------------------------------------------------------
|
| 55 |
+
|
| 56 |
+
_model = None
|
| 57 |
+
_segmenter = None
|
| 58 |
+
|
| 59 |
+
|
| 60 |
+
def _get_model():
|
| 61 |
+
global _model
|
| 62 |
+
if _model is None:
|
| 63 |
+
try:
|
| 64 |
+
from sentence_transformers import CrossEncoder
|
| 65 |
+
logger.info("Loading NLI model: %s (first call only)", MODEL_NAME)
|
| 66 |
+
_model = CrossEncoder(MODEL_NAME)
|
| 67 |
+
logger.info("NLI model loaded.")
|
| 68 |
+
except ImportError:
|
| 69 |
+
logger.error("sentence_transformers not installed. Faithfulness will be stubbed.")
|
| 70 |
+
_model = "stub"
|
| 71 |
+
return _model
|
| 72 |
+
|
| 73 |
+
def _get_segmenter():
|
| 74 |
+
global _segmenter
|
| 75 |
+
if _segmenter is None:
|
| 76 |
+
try:
|
| 77 |
+
import pysbd
|
| 78 |
+
_segmenter = pysbd.Segmenter(language="en", clean=False)
|
| 79 |
+
except ImportError:
|
| 80 |
+
_segmenter = "stub"
|
| 81 |
+
return _segmenter
|
| 82 |
+
|
| 83 |
+
|
| 84 |
+
# ---------------------------------------------------------------------------
|
| 85 |
+
# Public API
|
| 86 |
+
# ---------------------------------------------------------------------------
|
| 87 |
+
|
| 88 |
+
def score_faithfulness(
|
| 89 |
+
answer: str,
|
| 90 |
+
context_docs: list[str],
|
| 91 |
+
chunk_ids: list[str] | None = None,
|
| 92 |
+
max_chunks: int = 3,
|
| 93 |
+
config: dict | None = None,
|
| 94 |
+
) -> EvalResult:
|
| 95 |
+
"""
|
| 96 |
+
Score the faithfulness of an answer against retrieved context documents.
|
| 97 |
+
|
| 98 |
+
Args:
|
| 99 |
+
answer : The LLM-generated answer text.
|
| 100 |
+
context_docs : List of context passage strings (top-k retrieved chunks).
|
| 101 |
+
chunk_ids : Optional IDs matching context_docs for traceability.
|
| 102 |
+
max_chunks : Maximum context chunks to consider (to limit API calls).
|
| 103 |
+
|
| 104 |
+
Returns:
|
| 105 |
+
EvalResult with module_name="faithfulness", score in [0,1], and details
|
| 106 |
+
dict matching the shape defined in src/modules/__init__.py.
|
| 107 |
+
"""
|
| 108 |
+
t0 = time.perf_counter()
|
| 109 |
+
|
| 110 |
+
_faith_cfg = (config or {}).get("modules", {}).get("faithfulness", {})
|
| 111 |
+
entailment_threshold = _faith_cfg.get("entailment_threshold", ENTAILMENT_THRESHOLD)
|
| 112 |
+
contradiction_threshold = CONTRADICTION_THRESHOLD
|
| 113 |
+
|
| 114 |
+
if not answer or not context_docs:
|
| 115 |
+
return EvalResult(
|
| 116 |
+
module_name="faithfulness",
|
| 117 |
+
score=0.0,
|
| 118 |
+
details={"error": "Empty answer or no context provided"},
|
| 119 |
+
error="Empty answer or no context",
|
| 120 |
+
latency_ms=0,
|
| 121 |
+
)
|
| 122 |
+
|
| 123 |
+
# Limit context size
|
| 124 |
+
docs = context_docs[:max_chunks]
|
| 125 |
+
ids = (chunk_ids or [f"chunk_{i}" for i in range(len(docs))])[:max_chunks]
|
| 126 |
+
|
| 127 |
+
# Strip markdown formatting from guideline/structured chunks before NLI
|
| 128 |
+
# DeBERTa NLI was trained on clean prose — markdown confuses it
|
| 129 |
+
import re as _re
|
| 130 |
+
_MD_CLEAN = _re.compile(r'\[([^\]]+)\]\n|#{1,6}\s+|•\s+|\*\*([^*]+)\*\*|\*([^*]+)\*|`[^`]+`')
|
| 131 |
+
docs = [_MD_CLEAN.sub(lambda m: m.group(2) or m.group(3) or '', d) for d in docs]
|
| 132 |
+
|
| 133 |
+
# Strip inline citations and markdown from the answer before claim splitting.
|
| 134 |
+
# LLM answers often include [Source: *title*] citations and **bold** text that
|
| 135 |
+
# confuse BioLinkBERT NLI — the model was trained on clean prose.
|
| 136 |
+
_CITE_RE = _re.compile(
|
| 137 |
+
r'\[Source:[^\]]*\]' # [Source: title] or [Source: *italic title*]
|
| 138 |
+
r'|\[[^\]]{0,120}\]' # other short bracket constructs
|
| 139 |
+
r'|\*\*([^*]+)\*\*' # **bold** → keep inner text
|
| 140 |
+
r'|\*([^*]+)\*' # *italic* → keep inner text
|
| 141 |
+
r'|`[^`]+`' # `code`
|
| 142 |
+
r'|^\s*[*•]\s+' # bullet points at line start
|
| 143 |
+
)
|
| 144 |
+
answer_clean = _CITE_RE.sub(lambda m: (m.group(1) or m.group(2) or ''), answer).strip()
|
| 145 |
+
|
| 146 |
+
# Split answer into claims
|
| 147 |
+
seg = _get_segmenter()
|
| 148 |
+
try:
|
| 149 |
+
if seg == "stub":
|
| 150 |
+
claims = [s.strip() for s in answer_clean.split(".") if s.strip()]
|
| 151 |
+
else:
|
| 152 |
+
claims = [s.strip() for s in seg.segment(answer_clean) if s.strip()]
|
| 153 |
+
except Exception:
|
| 154 |
+
claims = [s.strip() for s in answer_clean.split(".") if s.strip()]
|
| 155 |
+
|
| 156 |
+
if not claims:
|
| 157 |
+
return EvalResult(
|
| 158 |
+
module_name="faithfulness",
|
| 159 |
+
score=0.5,
|
| 160 |
+
details={"error": "Could not extract claims from answer"},
|
| 161 |
+
error="No claims extracted",
|
| 162 |
+
latency_ms=0,
|
| 163 |
+
)
|
| 164 |
+
|
| 165 |
+
model = _get_model()
|
| 166 |
+
|
| 167 |
+
# Limit claims to avoid O(claims×chunks) explosion with the large model
|
| 168 |
+
claims = claims[:12]
|
| 169 |
+
|
| 170 |
+
# ---------------------------------------------------------------------------
|
| 171 |
+
# Numerical Bypass (Paper 14: non-optional for clinical NLI)
|
| 172 |
+
# NLI models structurally cannot verify numerical comparisons (≥6.5%, 126 mg/dL).
|
| 173 |
+
# Use direct string/lexical matching for claims containing clinical measurements.
|
| 174 |
+
# ---------------------------------------------------------------------------
|
| 175 |
+
import re as _re2
|
| 176 |
+
_NUM_PATTERN = _re2.compile(
|
| 177 |
+
r'[\d]+[\s]*(mg|mcg|%|mL|mmol|IU|units?|g|kg|≥|≤|>|<|±|mg/dL|mmol/L|mg/kg)',
|
| 178 |
+
_re2.IGNORECASE,
|
| 179 |
+
)
|
| 180 |
+
|
| 181 |
+
def _numerical_match(claim: str, context_chunks: list[str]) -> str:
|
| 182 |
+
"""
|
| 183 |
+
For claims with numerical clinical values, check if the key numbers
|
| 184 |
+
appear in any context chunk. Returns ENTAILED or NEUTRAL.
|
| 185 |
+
"""
|
| 186 |
+
nums = _re2.findall(r'[\d]+\.?[\d]*', claim)
|
| 187 |
+
if not nums:
|
| 188 |
+
return "NEUTRAL"
|
| 189 |
+
combined = " ".join(context_chunks).lower()
|
| 190 |
+
matched = sum(1 for n in nums if n in combined)
|
| 191 |
+
return "ENTAILED" if matched >= len(nums) * 0.6 else "NEUTRAL"
|
| 192 |
+
|
| 193 |
+
# Separate numerical claims (bypass NLI) from textual claims (use NLI)
|
| 194 |
+
numerical_results: dict[int, str] = {} # claim_idx → status
|
| 195 |
+
nli_claim_indices: list[int] = []
|
| 196 |
+
|
| 197 |
+
for ci, claim in enumerate(claims):
|
| 198 |
+
if _NUM_PATTERN.search(claim):
|
| 199 |
+
numerical_results[ci] = _numerical_match(claim, docs)
|
| 200 |
+
else:
|
| 201 |
+
nli_claim_indices.append(ci)
|
| 202 |
+
|
| 203 |
+
# Build NLI pairs only for non-numerical claims
|
| 204 |
+
nli_claims = [claims[ci] for ci in nli_claim_indices]
|
| 205 |
+
all_pairs = []
|
| 206 |
+
pair_map: list[tuple[int, int]] = [] # (nli_claim_idx, doc_idx)
|
| 207 |
+
for nci, claim in enumerate(nli_claims):
|
| 208 |
+
for di, doc in enumerate(docs):
|
| 209 |
+
all_pairs.append((doc, claim))
|
| 210 |
+
pair_map.append((nci, di))
|
| 211 |
+
|
| 212 |
+
# Batch NLI inference
|
| 213 |
+
try:
|
| 214 |
+
if model == "stub":
|
| 215 |
+
# Provide dummy scores if model is unavailable
|
| 216 |
+
scores_raw = [[0.1, 0.1, 0.8] for _ in all_pairs]
|
| 217 |
+
else:
|
| 218 |
+
scores_raw = model.predict(all_pairs, apply_softmax=True)
|
| 219 |
+
except Exception as exc:
|
| 220 |
+
logger.error("NLI model inference failed: %s", exc)
|
| 221 |
+
return EvalResult(
|
| 222 |
+
module_name="faithfulness",
|
| 223 |
+
score=0.0,
|
| 224 |
+
details={},
|
| 225 |
+
error=f"Model inference error: {exc}",
|
| 226 |
+
latency_ms=int((time.perf_counter() - t0) * 1000),
|
| 227 |
+
)
|
| 228 |
+
|
| 229 |
+
# Aggregate: for each claim find the context with the highest entailment
|
| 230 |
+
claim_results: list[dict] = []
|
| 231 |
+
entailed = 0
|
| 232 |
+
neutral = 0
|
| 233 |
+
contradicted = 0
|
| 234 |
+
|
| 235 |
+
# Build per-NLI-claim best scores from batch results
|
| 236 |
+
nli_best: dict[int, tuple[float, float, int]] = {} # nci → (best_ent, best_con, best_doc)
|
| 237 |
+
for idx, (nci, d_i) in enumerate(pair_map):
|
| 238 |
+
score_vec = scores_raw[idx]
|
| 239 |
+
ent_score = float(score_vec[LABEL_ENTAILMENT])
|
| 240 |
+
con_score = float(score_vec[LABEL_CONTRADICTION])
|
| 241 |
+
if nci not in nli_best or ent_score > nli_best[nci][0]:
|
| 242 |
+
nli_best[nci] = (ent_score, con_score, d_i)
|
| 243 |
+
|
| 244 |
+
for ci, claim in enumerate(claims):
|
| 245 |
+
if ci in numerical_results:
|
| 246 |
+
# Numerical bypass — lexical match result
|
| 247 |
+
status = numerical_results[ci]
|
| 248 |
+
nli_score = 1.0 if status == "ENTAILED" else 0.0
|
| 249 |
+
best_doc_idx = 0
|
| 250 |
+
method = "numerical_bypass"
|
| 251 |
+
else:
|
| 252 |
+
# NLI result
|
| 253 |
+
nci = nli_claim_indices.index(ci) if ci in nli_claim_indices else -1
|
| 254 |
+
best_entailment, best_contradiction, best_doc_idx = nli_best.get(nci, (0.0, 0.0, 0))
|
| 255 |
+
if best_entailment >= entailment_threshold:
|
| 256 |
+
status = "ENTAILED"
|
| 257 |
+
nli_score = best_entailment
|
| 258 |
+
elif best_contradiction >= contradiction_threshold:
|
| 259 |
+
status = "CONTRADICTED"
|
| 260 |
+
nli_score = best_contradiction
|
| 261 |
+
else:
|
| 262 |
+
status = "NEUTRAL"
|
| 263 |
+
nli_score = best_entailment
|
| 264 |
+
method = "nli"
|
| 265 |
+
|
| 266 |
+
if status == "ENTAILED":
|
| 267 |
+
entailed += 1
|
| 268 |
+
elif status == "CONTRADICTED":
|
| 269 |
+
contradicted += 1
|
| 270 |
+
else:
|
| 271 |
+
neutral += 1
|
| 272 |
+
|
| 273 |
+
claim_results.append({
|
| 274 |
+
"claim": claim,
|
| 275 |
+
"status": status,
|
| 276 |
+
"best_chunk_id": ids[best_doc_idx],
|
| 277 |
+
"nli_score": round(nli_score, 4),
|
| 278 |
+
"method": method,
|
| 279 |
+
})
|
| 280 |
+
|
| 281 |
+
total = len(claims)
|
| 282 |
+
score = max(0.0, (entailed - contradicted) / total) if total > 0 else 0.0
|
| 283 |
+
|
| 284 |
+
details = {
|
| 285 |
+
"total_claims": total,
|
| 286 |
+
"entailed_count": entailed,
|
| 287 |
+
"neutral_count": neutral,
|
| 288 |
+
"contradicted_count": contradicted,
|
| 289 |
+
"claims": claim_results,
|
| 290 |
+
}
|
| 291 |
+
|
| 292 |
+
latency_ms = int((time.perf_counter() - t0) * 1000)
|
| 293 |
+
logger.info(
|
| 294 |
+
"Faithfulness: %.3f (%d/%d entailed) in %d ms",
|
| 295 |
+
score, entailed, total, latency_ms,
|
| 296 |
+
)
|
| 297 |
+
return EvalResult(
|
| 298 |
+
module_name="faithfulness",
|
| 299 |
+
score=score,
|
| 300 |
+
details=details,
|
| 301 |
+
latency_ms=latency_ms,
|
| 302 |
+
)
|
src/modules/source_credibility.py
ADDED
|
@@ -0,0 +1,204 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
FR-14: src/modules/source_credibility.py — Module 3: Source Credibility Scoring
|
| 3 |
+
=================================================================================
|
| 4 |
+
Scores the credibility of retrieved source documents based on their publication
|
| 5 |
+
type / evidence tier.
|
| 6 |
+
|
| 7 |
+
Tier weights (SRS Section 6.3):
|
| 8 |
+
clinical_guideline → 1.00 (Tier 1 — highest authority)
|
| 9 |
+
systematic_review → 0.85 (Tier 2)
|
| 10 |
+
research_abstract → 0.70 (Tier 3 — PubMedQA default)
|
| 11 |
+
review_article → 0.60 (Tier 4)
|
| 12 |
+
clinical_case → 0.50 (Tier 5)
|
| 13 |
+
unknown / other → 0.30 (fallback)
|
| 14 |
+
|
| 15 |
+
Detection:
|
| 16 |
+
1. Use 'tier_type' metadata field if present (set by embedder.py)
|
| 17 |
+
2. Fall back to keyword matching in pub_type / title text
|
| 18 |
+
|
| 19 |
+
Score = weighted mean of tier weights across all retrieved chunks.
|
| 20 |
+
|
| 21 |
+
Each chunk must be a dict with at minimum:
|
| 22 |
+
{"text": str, "metadata": {"tier_type": str, "pub_type": str, "title": str}}
|
| 23 |
+
or the simpler form accepted by the retriever:
|
| 24 |
+
{"text": str, "source": str, "tier_type": str, "title": str}
|
| 25 |
+
"""
|
| 26 |
+
from __future__ import annotations
|
| 27 |
+
|
| 28 |
+
import logging
|
| 29 |
+
import re
|
| 30 |
+
import time
|
| 31 |
+
|
| 32 |
+
from src.modules.base import EvalResult
|
| 33 |
+
|
| 34 |
+
logger = logging.getLogger(__name__)
|
| 35 |
+
|
| 36 |
+
# ---------------------------------------------------------------------------
|
| 37 |
+
# Evidence tier weights
|
| 38 |
+
# ---------------------------------------------------------------------------
|
| 39 |
+
|
| 40 |
+
TIER_WEIGHTS: dict[str, float] = {
|
| 41 |
+
"clinical_guideline": 1.00,
|
| 42 |
+
"systematic_review": 0.85,
|
| 43 |
+
"drug_label": 0.90, # FDA-approved drug labels — authoritative regulatory source
|
| 44 |
+
"research_abstract": 0.70,
|
| 45 |
+
"review_article": 0.60,
|
| 46 |
+
"clinical_case": 0.50,
|
| 47 |
+
"unknown": 0.30,
|
| 48 |
+
}
|
| 49 |
+
|
| 50 |
+
# Keyword → tier_type mapping for fallback text matching
|
| 51 |
+
_KEYWORD_MAP: list[tuple[re.Pattern, str]] = [
|
| 52 |
+
(re.compile(r"\b(guideline|clinical practice|recommendation|consensus)\b", re.I), "clinical_guideline"),
|
| 53 |
+
(re.compile(r"\b(systematic review|meta.?analysis)\b", re.I), "systematic_review"),
|
| 54 |
+
# RCT / controlled trial → highest single-study evidence tier
|
| 55 |
+
(re.compile(r"\b(randomized|randomised|controlled trial|rct|clinical trial)\b", re.I), "clinical_guideline"),
|
| 56 |
+
# FDA drug labels
|
| 57 |
+
(re.compile(r"\b(fda|drug label|prescribing information|package insert|dailymed)\b", re.I), "drug_label"),
|
| 58 |
+
(re.compile(r"\b(review|overview)\b", re.I), "review_article"),
|
| 59 |
+
(re.compile(r"\b(case report|case study|clinical case)\b", re.I), "clinical_case"),
|
| 60 |
+
(re.compile(r"\b(abstract|research article|original article|journal)\b", re.I), "research_abstract"),
|
| 61 |
+
]
|
| 62 |
+
|
| 63 |
+
|
| 64 |
+
def _classify_tier(chunk: dict) -> tuple[str, str | None]:
|
| 65 |
+
"""
|
| 66 |
+
Return (tier_type, matched_keyword) for a single retrieved chunk dict.
|
| 67 |
+
|
| 68 |
+
Priority 1: explicit tier_type field (set by embedder.py)
|
| 69 |
+
Priority 2: pub_type field directly maps to a known tier name
|
| 70 |
+
Priority 3: keyword regex on pub_type + title text
|
| 71 |
+
"""
|
| 72 |
+
# Priority 1: explicit tier_type already set (e.g., by embedder.py)
|
| 73 |
+
tier = (
|
| 74 |
+
chunk.get("tier_type")
|
| 75 |
+
or chunk.get("metadata", {}).get("tier_type")
|
| 76 |
+
)
|
| 77 |
+
if tier and tier in TIER_WEIGHTS:
|
| 78 |
+
return tier, None
|
| 79 |
+
|
| 80 |
+
# Priority 2: direct pub_type value lookup
|
| 81 |
+
# Handles underscore-separated values like "research_abstract" which
|
| 82 |
+
# won't match word-boundary regex patterns
|
| 83 |
+
pub_type_raw = str(
|
| 84 |
+
chunk.get("pub_type") or chunk.get("metadata", {}).get("pub_type") or ""
|
| 85 |
+
).strip().lower()
|
| 86 |
+
|
| 87 |
+
_PUB_TYPE_DIRECT: dict[str, str] = {
|
| 88 |
+
"research_abstract": "research_abstract",
|
| 89 |
+
"abstract": "research_abstract",
|
| 90 |
+
"systematic_review": "systematic_review",
|
| 91 |
+
"systematic review": "systematic_review",
|
| 92 |
+
"meta_analysis": "systematic_review",
|
| 93 |
+
"meta-analysis": "systematic_review",
|
| 94 |
+
"drug_label": "drug_label",
|
| 95 |
+
"drug label": "drug_label",
|
| 96 |
+
"clinical_guideline": "clinical_guideline",
|
| 97 |
+
"clinical guideline": "clinical_guideline",
|
| 98 |
+
"guideline": "clinical_guideline",
|
| 99 |
+
"review_article": "review_article",
|
| 100 |
+
"review article": "review_article",
|
| 101 |
+
"review": "review_article",
|
| 102 |
+
"clinical_case": "clinical_case",
|
| 103 |
+
"case_report": "clinical_case",
|
| 104 |
+
"case report": "clinical_case",
|
| 105 |
+
}
|
| 106 |
+
if pub_type_raw in _PUB_TYPE_DIRECT:
|
| 107 |
+
return _PUB_TYPE_DIRECT[pub_type_raw], None
|
| 108 |
+
|
| 109 |
+
# Priority 3: keyword regex on pub_type + title text
|
| 110 |
+
title = str(chunk.get("title") or chunk.get("metadata", {}).get("title") or "")
|
| 111 |
+
text_to_search = f"{pub_type_raw} {title}"
|
| 112 |
+
|
| 113 |
+
for pattern, matched_tier in _KEYWORD_MAP:
|
| 114 |
+
m = pattern.search(text_to_search)
|
| 115 |
+
if m:
|
| 116 |
+
return matched_tier, m.group(0)
|
| 117 |
+
|
| 118 |
+
|
| 119 |
+
return "unknown", None
|
| 120 |
+
|
| 121 |
+
|
| 122 |
+
# ---------------------------------------------------------------------------
|
| 123 |
+
# Public API
|
| 124 |
+
# ---------------------------------------------------------------------------
|
| 125 |
+
|
| 126 |
+
def score_source_credibility(
|
| 127 |
+
retrieved_chunks: list[dict],
|
| 128 |
+
) -> EvalResult:
|
| 129 |
+
"""
|
| 130 |
+
Score the credibility of a set of retrieved source documents.
|
| 131 |
+
|
| 132 |
+
Args:
|
| 133 |
+
retrieved_chunks : List of chunk dicts as returned by retriever.retrieve().
|
| 134 |
+
Each must contain at minimum 'text' and ideally
|
| 135 |
+
'tier_type', 'pub_type', 'title', 'chunk_id' fields.
|
| 136 |
+
|
| 137 |
+
Returns:
|
| 138 |
+
EvalResult with module_name="source_credibility", score in [0,1], and
|
| 139 |
+
details matching the shape from src/modules/__init__.py.
|
| 140 |
+
"""
|
| 141 |
+
t0 = time.perf_counter()
|
| 142 |
+
|
| 143 |
+
if not retrieved_chunks:
|
| 144 |
+
return EvalResult(
|
| 145 |
+
module_name="source_credibility",
|
| 146 |
+
score=0.0,
|
| 147 |
+
details={"chunks": [], "method_used": "none"},
|
| 148 |
+
error="No chunks provided",
|
| 149 |
+
latency_ms=0,
|
| 150 |
+
)
|
| 151 |
+
|
| 152 |
+
chunk_details: list[dict] = []
|
| 153 |
+
weights: list[float] = []
|
| 154 |
+
method_used = "metadata" # assume metadata-first; may switch to keyword
|
| 155 |
+
|
| 156 |
+
for i, chunk in enumerate(retrieved_chunks):
|
| 157 |
+
tier_type, matched_kw = _classify_tier(chunk)
|
| 158 |
+
weight = TIER_WEIGHTS.get(tier_type, TIER_WEIGHTS["unknown"])
|
| 159 |
+
weights.append(weight)
|
| 160 |
+
|
| 161 |
+
if matched_kw:
|
| 162 |
+
method_used = "keyword"
|
| 163 |
+
|
| 164 |
+
# Compute tier number (1-5) for display
|
| 165 |
+
tier_num = {
|
| 166 |
+
"clinical_guideline": 1,
|
| 167 |
+
"systematic_review": 2,
|
| 168 |
+
"research_abstract": 3,
|
| 169 |
+
"review_article": 4,
|
| 170 |
+
"clinical_case": 5,
|
| 171 |
+
}.get(tier_type, 6) # 6 = unknown/unclassified
|
| 172 |
+
|
| 173 |
+
chunk_details.append(
|
| 174 |
+
{
|
| 175 |
+
"chunk_id": chunk.get("chunk_id") or chunk.get("metadata", {}).get("chunk_id") or f"chunk_{i}",
|
| 176 |
+
"tier": tier_num,
|
| 177 |
+
"tier_type": tier_type,
|
| 178 |
+
"tier_weight": round(weight, 2),
|
| 179 |
+
"pub_type": chunk.get("pub_type") or chunk.get("metadata", {}).get("pub_type") or "",
|
| 180 |
+
"title": (chunk.get("title") or chunk.get("metadata", {}).get("title") or "")[:80],
|
| 181 |
+
"matched_keyword": matched_kw,
|
| 182 |
+
}
|
| 183 |
+
)
|
| 184 |
+
|
| 185 |
+
score = sum(weights) / len(weights) if weights else 0.0
|
| 186 |
+
|
| 187 |
+
details = {
|
| 188 |
+
"method_used": method_used,
|
| 189 |
+
"chunk_count": len(retrieved_chunks),
|
| 190 |
+
"avg_tier_weight": round(score, 4),
|
| 191 |
+
"chunks": chunk_details,
|
| 192 |
+
}
|
| 193 |
+
|
| 194 |
+
latency_ms = int((time.perf_counter() - t0) * 1000)
|
| 195 |
+
logger.info(
|
| 196 |
+
"Source credibility: %.3f (avg tier weight over %d chunks) in %d ms",
|
| 197 |
+
score, len(retrieved_chunks), latency_ms,
|
| 198 |
+
)
|
| 199 |
+
return EvalResult(
|
| 200 |
+
module_name="source_credibility",
|
| 201 |
+
score=score,
|
| 202 |
+
details=details,
|
| 203 |
+
latency_ms=latency_ms,
|
| 204 |
+
)
|
src/pipeline/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
# src/pipeline/__init__.py
|
src/pipeline/chunker.py
ADDED
|
@@ -0,0 +1,82 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
FR-02: Document Chunking
|
| 3 |
+
========================
|
| 4 |
+
LangChain RecursiveCharacterTextSplitter
|
| 5 |
+
chunk_size = 512 chars (config: retrieval.chunk_size)
|
| 6 |
+
overlap = 50 chars (config: retrieval.chunk_overlap)
|
| 7 |
+
|
| 8 |
+
Each chunk carries the full FR-03b metadata schema required by Module 3
|
| 9 |
+
(source credibility) and the FAISS metadata store.
|
| 10 |
+
"""
|
| 11 |
+
from __future__ import annotations
|
| 12 |
+
|
| 13 |
+
import logging
|
| 14 |
+
import uuid
|
| 15 |
+
from typing import Any
|
| 16 |
+
|
| 17 |
+
logger = logging.getLogger(__name__)
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
def chunk_documents(
|
| 21 |
+
documents: list[dict[str, Any]],
|
| 22 |
+
config: dict,
|
| 23 |
+
) -> list[dict[str, Any]]:
|
| 24 |
+
"""
|
| 25 |
+
Split a list of raw documents into overlapping text chunks.
|
| 26 |
+
|
| 27 |
+
Args:
|
| 28 |
+
documents : List of dicts with keys:
|
| 29 |
+
text, doc_id, source, title, pub_type, pub_year, journal
|
| 30 |
+
config : Loaded config.yaml dict
|
| 31 |
+
|
| 32 |
+
Returns:
|
| 33 |
+
List of chunk dicts (FR-03b metadata schema):
|
| 34 |
+
chunk_id, chunk_text, doc_id, source, title,
|
| 35 |
+
pub_type, pub_year, journal, chunk_index, total_chunks
|
| 36 |
+
"""
|
| 37 |
+
from langchain.text_splitter import RecursiveCharacterTextSplitter
|
| 38 |
+
|
| 39 |
+
chunk_size = config["retrieval"]["chunk_size"] # 512
|
| 40 |
+
chunk_overlap = config["retrieval"]["chunk_overlap"] # 50
|
| 41 |
+
|
| 42 |
+
splitter = RecursiveCharacterTextSplitter(
|
| 43 |
+
chunk_size=chunk_size,
|
| 44 |
+
chunk_overlap=chunk_overlap,
|
| 45 |
+
length_function=len,
|
| 46 |
+
separators=["\n\n", "\n", ". ", " ", ""],
|
| 47 |
+
)
|
| 48 |
+
|
| 49 |
+
all_chunks: list[dict] = []
|
| 50 |
+
|
| 51 |
+
for doc in documents:
|
| 52 |
+
text = doc.get("text", "").strip()
|
| 53 |
+
if not text:
|
| 54 |
+
logger.debug("Skipping empty document: doc_id=%s", doc.get("doc_id"))
|
| 55 |
+
continue
|
| 56 |
+
|
| 57 |
+
raw_chunks = splitter.split_text(text)
|
| 58 |
+
total = len(raw_chunks)
|
| 59 |
+
|
| 60 |
+
for idx, chunk_text in enumerate(raw_chunks):
|
| 61 |
+
chunk_text = chunk_text.strip()
|
| 62 |
+
if not chunk_text:
|
| 63 |
+
continue
|
| 64 |
+
all_chunks.append({
|
| 65 |
+
# FR-03b schema
|
| 66 |
+
"chunk_id": str(uuid.uuid4()),
|
| 67 |
+
"chunk_text": chunk_text,
|
| 68 |
+
"doc_id": doc["doc_id"],
|
| 69 |
+
"source": doc["source"],
|
| 70 |
+
"title": doc["title"],
|
| 71 |
+
"pub_type": doc["pub_type"],
|
| 72 |
+
"pub_year": doc.get("pub_year", 0),
|
| 73 |
+
"journal": doc.get("journal", ""),
|
| 74 |
+
"chunk_index": idx,
|
| 75 |
+
"total_chunks": total,
|
| 76 |
+
})
|
| 77 |
+
|
| 78 |
+
logger.info(
|
| 79 |
+
"Chunked %d documents → %d chunks (size=%d, overlap=%d)",
|
| 80 |
+
len(documents), len(all_chunks), chunk_size, chunk_overlap,
|
| 81 |
+
)
|
| 82 |
+
return all_chunks
|
src/pipeline/consensus.py
ADDED
|
@@ -0,0 +1,111 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
src/pipeline/consensus.py — Multi-Model Consensus Engine
|
| 3 |
+
=========================================================
|
| 4 |
+
Implements the "Ensemble Judge" middleware feature.
|
| 5 |
+
Calls multiple LLMs and compares their answers for medical contradictions.
|
| 6 |
+
"""
|
| 7 |
+
from __future__ import annotations
|
| 8 |
+
import logging
|
| 9 |
+
import concurrent.futures
|
| 10 |
+
from typing import List, Dict, Any, Optional
|
| 11 |
+
from src.pipeline.generator import generate_answer
|
| 12 |
+
|
| 13 |
+
logger = logging.getLogger(__name__)
|
| 14 |
+
|
| 15 |
+
def run_consensus_check(
|
| 16 |
+
question: str,
|
| 17 |
+
context_chunks: List[Dict[str, Any]],
|
| 18 |
+
config: Dict[str, Any],
|
| 19 |
+
providers: List[str] = ["gemini", "groq"]
|
| 20 |
+
) -> Dict[str, Any]:
|
| 21 |
+
"""
|
| 22 |
+
Calls multiple providers in parallel and compares outcomes.
|
| 23 |
+
Returns: {
|
| 24 |
+
"answers": { provider: answer },
|
| 25 |
+
"agreement_score": float [0-1],
|
| 26 |
+
"conflicts": List[str],
|
| 27 |
+
"consensus_answer": str
|
| 28 |
+
}
|
| 29 |
+
"""
|
| 30 |
+
logger.info("Starting Consensus Check with providers: %s", providers)
|
| 31 |
+
|
| 32 |
+
# 1. Generate answers in parallel
|
| 33 |
+
answers = {}
|
| 34 |
+
with concurrent.futures.ThreadPoolExecutor() as executor:
|
| 35 |
+
future_to_provider = {
|
| 36 |
+
executor.submit(generate_answer, question, context_chunks, config, {"provider": p}): p
|
| 37 |
+
for p in providers
|
| 38 |
+
}
|
| 39 |
+
for future in concurrent.futures.as_completed(future_to_provider):
|
| 40 |
+
provider = future_to_provider[future]
|
| 41 |
+
try:
|
| 42 |
+
answers[provider] = future.result()
|
| 43 |
+
except Exception as exc:
|
| 44 |
+
logger.error("Provider %s failed during consensus: %s", provider, exc)
|
| 45 |
+
answers[provider] = f"ERROR: {exc}"
|
| 46 |
+
|
| 47 |
+
if len(answers) < 2:
|
| 48 |
+
return {
|
| 49 |
+
"answers": answers,
|
| 50 |
+
"agreement_score": 1.0,
|
| 51 |
+
"conflicts": ["Insufficient providers responded for a full consensus check."],
|
| 52 |
+
"consensus_answer": list(answers.values())[0] if answers else "Safety failure: No providers responded."
|
| 53 |
+
}
|
| 54 |
+
|
| 55 |
+
# Compile context text to feed the Judge
|
| 56 |
+
context_text = "\n\n".join([f"Source {i+1}:\n{c.get('text', '')}" for i, c in enumerate(context_chunks)])
|
| 57 |
+
|
| 58 |
+
# 2. Compare answers using a "Judge" Agent
|
| 59 |
+
# We use Gemini (or the primary provider) as the judge
|
| 60 |
+
comparison_prompt = f"""
|
| 61 |
+
You are a Medical Consensus Judge. Compare the following two medical answers provided by different AI models to the same question.
|
| 62 |
+
CRITICAL INSTRUCTION: Your primary duty is to ensure the final answer is explicitly grounded in the provided MEDICAL CONTEXT.
|
| 63 |
+
Identify any CLINICAL CONTRADICTIONS or significant discrepancies in drug names, dosages, or recommendations.
|
| 64 |
+
If one model hallucinates outside the context, you must side with the model that stuck to the context.
|
| 65 |
+
|
| 66 |
+
QUESTION: {question}
|
| 67 |
+
|
| 68 |
+
MEDICAL CONTEXT FROM DATASET:
|
| 69 |
+
{context_text}
|
| 70 |
+
|
| 71 |
+
ANSWER A:
|
| 72 |
+
{list(answers.values())[0]}
|
| 73 |
+
|
| 74 |
+
ANSWER B:
|
| 75 |
+
{list(answers.values())[1] if len(answers) > 1 else "N/A"}
|
| 76 |
+
|
| 77 |
+
OUTPUT FORMAT (JSON ONLY):
|
| 78 |
+
{{
|
| 79 |
+
"agreement_score": 0.0 to 1.0 (1.0 means perfect alignment, 0.0 means complete contradiction),
|
| 80 |
+
"conflicts": ["list of specific medical discrepancies found"],
|
| 81 |
+
"summary": "brief summary of how they differ and which one aligns better with the Medical Context",
|
| 82 |
+
"recommended_consensus": "the most conservative and safe unified answer that strictly adheres to the Medical Context"
|
| 83 |
+
}}
|
| 84 |
+
"""
|
| 85 |
+
try:
|
| 86 |
+
# Use the generator's default to run the judge
|
| 87 |
+
judge_raw = generate_answer("Medical Consensus Judge Task", [{"text": comparison_prompt}], config)
|
| 88 |
+
# Attempt to parse JSON from the judge's response
|
| 89 |
+
# (A real implementation would use structured output, but we use a robust parse for now)
|
| 90 |
+
import json
|
| 91 |
+
import re
|
| 92 |
+
|
| 93 |
+
# Clean potential markdown
|
| 94 |
+
clean_json = re.sub(r'```json\n?|\n?```', '', judge_raw).strip()
|
| 95 |
+
judge_data = json.loads(clean_json)
|
| 96 |
+
|
| 97 |
+
return {
|
| 98 |
+
"answers": answers,
|
| 99 |
+
"agreement_score": judge_data.get("agreement_score", 0.5),
|
| 100 |
+
"conflicts": judge_data.get("conflicts", []),
|
| 101 |
+
"summary": judge_data.get("summary", ""),
|
| 102 |
+
"consensus_answer": judge_data.get("recommended_consensus", list(answers.values())[0])
|
| 103 |
+
}
|
| 104 |
+
except Exception as e:
|
| 105 |
+
logger.error("Consensus Judge failed: %s", e)
|
| 106 |
+
return {
|
| 107 |
+
"answers": answers,
|
| 108 |
+
"agreement_score": 0.5,
|
| 109 |
+
"conflicts": [f"Judge failed: {e}"],
|
| 110 |
+
"consensus_answer": list(answers.values())[0]
|
| 111 |
+
}
|
src/pipeline/embedder.py
ADDED
|
@@ -0,0 +1,163 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
FR-03 + FR-03b: Embedding Generation & FAISS Index Construction
|
| 3 |
+
===============================================================
|
| 4 |
+
Model : dmis-lab/biobert-v1.1 (768-dim dense vectors, SentenceTransformer)
|
| 5 |
+
Index : FAISS IndexFlatIP with L2-normalized vectors (= cosine similarity)
|
| 6 |
+
Metadata: Parallel dict[int → dict] saved as pickle alongside index
|
| 7 |
+
|
| 8 |
+
Usage:
|
| 9 |
+
python src/pipeline/embedder.py
|
| 10 |
+
"""
|
| 11 |
+
from __future__ import annotations
|
| 12 |
+
|
| 13 |
+
import sys
|
| 14 |
+
import os
|
| 15 |
+
from pathlib import Path
|
| 16 |
+
|
| 17 |
+
sys.path.insert(0, str(Path(__file__).resolve().parent.parent.parent))
|
| 18 |
+
|
| 19 |
+
import json
|
| 20 |
+
import logging
|
| 21 |
+
import pickle
|
| 22 |
+
|
| 23 |
+
import faiss
|
| 24 |
+
import numpy as np
|
| 25 |
+
import yaml
|
| 26 |
+
|
| 27 |
+
import src # noqa: F401 — logging setup
|
| 28 |
+
|
| 29 |
+
logger = logging.getLogger(__name__)
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
def _load_config() -> dict:
|
| 33 |
+
with open("config.yaml", "r", encoding="utf-8") as f:
|
| 34 |
+
return yaml.safe_load(f)
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
def load_chunks(chunks_path: str = "data/processed/chunks.jsonl") -> list[dict]:
|
| 38 |
+
"""Load chunks from JSONL produced by ingest.py."""
|
| 39 |
+
path = Path(chunks_path)
|
| 40 |
+
if not path.exists():
|
| 41 |
+
raise FileNotFoundError(
|
| 42 |
+
f"Chunks file not found: '{chunks_path}'. "
|
| 43 |
+
"Run python src/pipeline/ingest.py first."
|
| 44 |
+
)
|
| 45 |
+
chunks = []
|
| 46 |
+
with open(path, "r", encoding="utf-8") as f:
|
| 47 |
+
for line in f:
|
| 48 |
+
line = line.strip()
|
| 49 |
+
if line:
|
| 50 |
+
chunks.append(json.loads(line))
|
| 51 |
+
logger.info("Loaded %d chunks from %s", len(chunks), chunks_path)
|
| 52 |
+
return chunks
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
def encode_texts(
|
| 56 |
+
texts: list[str],
|
| 57 |
+
model_name: str,
|
| 58 |
+
batch_size: int = 32,
|
| 59 |
+
) -> np.ndarray:
|
| 60 |
+
"""
|
| 61 |
+
Encode texts using BioBERT via SentenceTransformer.
|
| 62 |
+
Returns L2-normalized float32 array of shape (N, 768).
|
| 63 |
+
"""
|
| 64 |
+
from sentence_transformers import SentenceTransformer
|
| 65 |
+
|
| 66 |
+
logger.info("Loading embedding model: %s", model_name)
|
| 67 |
+
model = SentenceTransformer(model_name)
|
| 68 |
+
|
| 69 |
+
logger.info("Encoding %d texts (batch_size=%d)...", len(texts), batch_size)
|
| 70 |
+
embeddings: np.ndarray = model.encode(
|
| 71 |
+
texts,
|
| 72 |
+
batch_size=batch_size,
|
| 73 |
+
show_progress_bar=True,
|
| 74 |
+
normalize_embeddings=True, # L2-normalise → cosine via IndexFlatIP
|
| 75 |
+
convert_to_numpy=True,
|
| 76 |
+
)
|
| 77 |
+
logger.info("Encoded shape: %s", embeddings.shape)
|
| 78 |
+
return embeddings.astype(np.float32)
|
| 79 |
+
|
| 80 |
+
|
| 81 |
+
def build_faiss_index(embeddings: np.ndarray) -> faiss.IndexFlatIP:
|
| 82 |
+
"""
|
| 83 |
+
Build FAISS IndexFlatIP.
|
| 84 |
+
Because vectors are L2-normalised, inner product == cosine similarity.
|
| 85 |
+
"""
|
| 86 |
+
dim = embeddings.shape[1] # 768 for BioBERT
|
| 87 |
+
index = faiss.IndexFlatIP(dim)
|
| 88 |
+
index.add(embeddings)
|
| 89 |
+
logger.info(
|
| 90 |
+
"FAISS IndexFlatIP built: %d vectors, dim=%d", index.ntotal, dim
|
| 91 |
+
)
|
| 92 |
+
return index
|
| 93 |
+
|
| 94 |
+
|
| 95 |
+
def build_metadata_store(chunks: list[dict]) -> dict[int, dict]:
|
| 96 |
+
"""
|
| 97 |
+
Build parallel metadata dict → key = FAISS integer index (0-based).
|
| 98 |
+
Stores the full FR-03b schema plus chunk_text for retrieval.
|
| 99 |
+
"""
|
| 100 |
+
store: dict[int, dict] = {}
|
| 101 |
+
for i, chunk in enumerate(chunks):
|
| 102 |
+
store[i] = {
|
| 103 |
+
"chunk_id": chunk["chunk_id"],
|
| 104 |
+
"doc_id": chunk["doc_id"],
|
| 105 |
+
"source": chunk["source"],
|
| 106 |
+
"title": chunk["title"],
|
| 107 |
+
"pub_type": chunk["pub_type"],
|
| 108 |
+
"pub_year": chunk["pub_year"],
|
| 109 |
+
"journal": chunk["journal"],
|
| 110 |
+
"chunk_index": chunk["chunk_index"],
|
| 111 |
+
"total_chunks": chunk["total_chunks"],
|
| 112 |
+
"chunk_text": chunk["chunk_text"], # kept for retrieval
|
| 113 |
+
}
|
| 114 |
+
return store
|
| 115 |
+
|
| 116 |
+
|
| 117 |
+
def save_artifacts(
|
| 118 |
+
index: faiss.IndexFlatIP,
|
| 119 |
+
metadata_store: dict,
|
| 120 |
+
config: dict,
|
| 121 |
+
) -> None:
|
| 122 |
+
"""Persist FAISS index and metadata pickle to disk."""
|
| 123 |
+
index_path = Path(config["retrieval"]["index_path"])
|
| 124 |
+
meta_path = Path(config["retrieval"]["metadata_path"])
|
| 125 |
+
|
| 126 |
+
index_path.parent.mkdir(parents=True, exist_ok=True)
|
| 127 |
+
meta_path.parent.mkdir(parents=True, exist_ok=True)
|
| 128 |
+
|
| 129 |
+
faiss.write_index(index, str(index_path))
|
| 130 |
+
logger.info("FAISS index written to %s", index_path)
|
| 131 |
+
|
| 132 |
+
with open(meta_path, "wb") as f:
|
| 133 |
+
pickle.dump(metadata_store, f, protocol=pickle.HIGHEST_PROTOCOL)
|
| 134 |
+
logger.info(
|
| 135 |
+
"Metadata store written to %s (%d entries)", meta_path, len(metadata_store)
|
| 136 |
+
)
|
| 137 |
+
|
| 138 |
+
|
| 139 |
+
def main() -> None:
|
| 140 |
+
config = _load_config()
|
| 141 |
+
chunks = load_chunks("data/processed/chunks.jsonl")
|
| 142 |
+
|
| 143 |
+
if not chunks:
|
| 144 |
+
logger.error("No chunks to embed. Run python src/pipeline/ingest.py first.")
|
| 145 |
+
sys.exit(1)
|
| 146 |
+
|
| 147 |
+
texts = [c["chunk_text"] for c in chunks]
|
| 148 |
+
model_name = config["retrieval"]["embedding_model"]
|
| 149 |
+
embeddings = encode_texts(texts, model_name, batch_size=32)
|
| 150 |
+
index = build_faiss_index(embeddings)
|
| 151 |
+
metadata_store = build_metadata_store(chunks)
|
| 152 |
+
|
| 153 |
+
save_artifacts(index, metadata_store, config)
|
| 154 |
+
|
| 155 |
+
logger.info(
|
| 156 |
+
"Embedding complete. Index has %d vectors. "
|
| 157 |
+
"Next: python scripts/warmup.py && streamlit run src/dashboard/app.py",
|
| 158 |
+
index.ntotal,
|
| 159 |
+
)
|
| 160 |
+
|
| 161 |
+
|
| 162 |
+
if __name__ == "__main__":
|
| 163 |
+
main()
|
src/pipeline/generator.py
ADDED
|
@@ -0,0 +1,584 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
src/pipeline/generator.py — LLM Answer Generation
|
| 3 |
+
===================================================
|
| 4 |
+
Supports multiple providers based on config.yaml → llm.provider:
|
| 5 |
+
- "gemini" : Google Gemini API (recommended)
|
| 6 |
+
- "mistral" : Mistral AI API (api.mistral.ai)
|
| 7 |
+
- "groq" : Groq Cloud API (fast inference)
|
| 8 |
+
- "ollama" : Local Ollama/Mistral (requires Ollama running locally)
|
| 9 |
+
|
| 10 |
+
API Key setup:
|
| 11 |
+
Set env variables in Backend/.env:
|
| 12 |
+
GEMINI_API_KEY=your_key
|
| 13 |
+
MISTRAL_API_KEY=your_key
|
| 14 |
+
GROQ_API_KEY=your_key
|
| 15 |
+
"""
|
| 16 |
+
from __future__ import annotations
|
| 17 |
+
|
| 18 |
+
import json
|
| 19 |
+
import logging
|
| 20 |
+
import os
|
| 21 |
+
import time
|
| 22 |
+
from pathlib import Path
|
| 23 |
+
from typing import Optional
|
| 24 |
+
|
| 25 |
+
import yaml
|
| 26 |
+
|
| 27 |
+
logger = logging.getLogger(__name__)
|
| 28 |
+
|
| 29 |
+
# Load .env file at module import time
|
| 30 |
+
def _load_env():
|
| 31 |
+
env_path = Path(".env")
|
| 32 |
+
if not env_path.exists():
|
| 33 |
+
# Try one level up
|
| 34 |
+
env_path = Path("../Backend/.env")
|
| 35 |
+
if env_path.exists():
|
| 36 |
+
for line in env_path.read_text().splitlines():
|
| 37 |
+
line = line.strip()
|
| 38 |
+
if line and not line.startswith("#") and "=" in line:
|
| 39 |
+
key, val = line.split("=", 1)
|
| 40 |
+
key = key.strip()
|
| 41 |
+
val = val.strip().strip('"').strip("'")
|
| 42 |
+
if key and val and key not in os.environ:
|
| 43 |
+
os.environ[key] = val
|
| 44 |
+
|
| 45 |
+
_load_env()
|
| 46 |
+
|
| 47 |
+
# ---------------------------------------------------------------------------
|
| 48 |
+
# Config loader
|
| 49 |
+
# ---------------------------------------------------------------------------
|
| 50 |
+
|
| 51 |
+
def _load_config() -> dict:
|
| 52 |
+
try:
|
| 53 |
+
return yaml.safe_load(Path("config.yaml").read_text())
|
| 54 |
+
except Exception:
|
| 55 |
+
return {}
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
# ---------------------------------------------------------------------------
|
| 59 |
+
# Prompt builder (shared by both providers)
|
| 60 |
+
# ---------------------------------------------------------------------------
|
| 61 |
+
|
| 62 |
+
_PHYSICIAN_PROMPT = (
|
| 63 |
+
"You are MediRAG, a medical AI assistant tailored for clinicians and researchers. "
|
| 64 |
+
"You MUST answer ONLY using information explicitly stated in the CONTEXT provided below. "
|
| 65 |
+
"Use professional medical terminology, be concise, and cite specific details. "
|
| 66 |
+
"After each claim, cite it inline as [Source: <document title>]. "
|
| 67 |
+
"If the context does NOT contain sufficient information to answer safely, you MUST respond EXACTLY with: "
|
| 68 |
+
"'⚠️ The retrieved context does not contain enough information to answer this safely. "
|
| 69 |
+
"Please consult authoritative clinical guidelines or a specialist.' "
|
| 70 |
+
"NEVER use general knowledge, training data, or information outside the provided context."
|
| 71 |
+
)
|
| 72 |
+
|
| 73 |
+
_PATIENT_PROMPT = (
|
| 74 |
+
"You are MediRAG, a medical AI assistant tailored for patients and non-experts. "
|
| 75 |
+
"You MUST answer ONLY using information explicitly stated in the CONTEXT provided below. "
|
| 76 |
+
"Explain medical information in a clear, accessible, and empathetic way. "
|
| 77 |
+
"After each claim, cite it inline as [Source: <document title>]. "
|
| 78 |
+
"If the context does NOT contain sufficient information to answer safely, you MUST respond EXACTLY with: "
|
| 79 |
+
"'⚠️ The retrieved context does not contain enough information to answer this safely. "
|
| 80 |
+
"Please consult your doctor or a medical specialist.' "
|
| 81 |
+
"NEVER use general knowledge, training data, or information outside the provided context."
|
| 82 |
+
)
|
| 83 |
+
|
| 84 |
+
_SYSTEM_PROMPT = _PHYSICIAN_PROMPT # Default fallback
|
| 85 |
+
|
| 86 |
+
|
| 87 |
+
def _build_prompt(question: str, context_chunks: list[dict], system_prompt: Optional[str] = None, persona: str = "physician") -> str:
|
| 88 |
+
"""Build the RAG prompt from the question + retrieved chunks.
|
| 89 |
+
|
| 90 |
+
Explicitly surfaces title and source for each chunk in the header so the LLM
|
| 91 |
+
can cite [Source: <title>] inline in its answer.
|
| 92 |
+
"""
|
| 93 |
+
context_parts = []
|
| 94 |
+
for i, chunk in enumerate(context_chunks, 1):
|
| 95 |
+
text = chunk.get("text") or chunk.get("chunk_text", "")
|
| 96 |
+
title = chunk.get("title", "")
|
| 97 |
+
source = chunk.get("source", "")
|
| 98 |
+
pub_type = chunk.get("pub_type", "")
|
| 99 |
+
# Include title as the primary citation label
|
| 100 |
+
header_parts = [f"Source {i}"]
|
| 101 |
+
if title:
|
| 102 |
+
header_parts.append(f"Title: {title}")
|
| 103 |
+
if pub_type:
|
| 104 |
+
header_parts.append(pub_type)
|
| 105 |
+
if source and source != title:
|
| 106 |
+
header_parts.append(source)
|
| 107 |
+
header = "[" + " | ".join(header_parts) + "]"
|
| 108 |
+
context_parts.append(f"{header}\n{text.strip()}")
|
| 109 |
+
|
| 110 |
+
context_block = "\n\n".join(context_parts)
|
| 111 |
+
|
| 112 |
+
# Determine effective system prompt based on persona if no manual override
|
| 113 |
+
if system_prompt:
|
| 114 |
+
effective_system = system_prompt
|
| 115 |
+
else:
|
| 116 |
+
effective_system = _PATIENT_PROMPT if persona == "patient" else _PHYSICIAN_PROMPT
|
| 117 |
+
|
| 118 |
+
return (
|
| 119 |
+
f"{effective_system}\n\n"
|
| 120 |
+
f"CONTEXT:\n{context_block}\n\n"
|
| 121 |
+
f"QUESTION: {question}\n\n"
|
| 122 |
+
f"ANSWER (cite sources inline as [Source: document title]):"
|
| 123 |
+
)
|
| 124 |
+
|
| 125 |
+
|
| 126 |
+
# Strict prompt — used when first answer fails evaluation (HRS ≥ 60)
|
| 127 |
+
_STRICT_SYSTEM_PROMPT = (
|
| 128 |
+
"You are MediRAG, a clinical safety assistant under strict mode. "
|
| 129 |
+
"A previous response was flagged as potentially unsafe or inaccurate. "
|
| 130 |
+
"You MUST answer ONLY using the information explicitly stated in the CONTEXT below. "
|
| 131 |
+
"Do NOT use any general medical knowledge, training data, or outside information. "
|
| 132 |
+
"If the context is insufficient, you MUST say EXACTLY: "
|
| 133 |
+
"'⚠️ Insufficient evidence in retrieved context to answer safely. Please consult a clinical specialist.' "
|
| 134 |
+
"NEVER hallucinate drug names, dosages, or clinical recommendations."
|
| 135 |
+
)
|
| 136 |
+
|
| 137 |
+
|
| 138 |
+
def _build_strict_prompt(question: str, context_chunks: list[dict]) -> str:
|
| 139 |
+
"""Strict prompt: context-only, used on regeneration after failed evaluation."""
|
| 140 |
+
context_parts = []
|
| 141 |
+
for i, chunk in enumerate(context_chunks, 1):
|
| 142 |
+
text = chunk.get("text") or chunk.get("chunk_text", "")
|
| 143 |
+
title = chunk.get("title", "")
|
| 144 |
+
source = chunk.get("source", "")
|
| 145 |
+
pub_type = chunk.get("pub_type", "")
|
| 146 |
+
header_parts = [f"Source {i}"]
|
| 147 |
+
if title:
|
| 148 |
+
header_parts.append(f"Title: {title}")
|
| 149 |
+
if pub_type:
|
| 150 |
+
header_parts.append(pub_type)
|
| 151 |
+
if source and source != title:
|
| 152 |
+
header_parts.append(source)
|
| 153 |
+
header = "[" + " | ".join(header_parts) + "]"
|
| 154 |
+
context_parts.append(f"{header}\n{text.strip()}")
|
| 155 |
+
|
| 156 |
+
context_block = "\n\n".join(context_parts)
|
| 157 |
+
return (
|
| 158 |
+
f"{_STRICT_SYSTEM_PROMPT}\n\n"
|
| 159 |
+
f"CONTEXT:\n{context_block}\n\n"
|
| 160 |
+
f"QUESTION: {question}\n\n"
|
| 161 |
+
f"SAFE ANSWER (context-only, cite [Source: title] for every claim):"
|
| 162 |
+
)
|
| 163 |
+
|
| 164 |
+
|
| 165 |
+
# ---------------------------------------------------------------------------
|
| 166 |
+
# OpenAI provider
|
| 167 |
+
# ---------------------------------------------------------------------------
|
| 168 |
+
|
| 169 |
+
def _generate_openai(prompt: str, config: dict) -> str:
|
| 170 |
+
llm_cfg = config.get("llm", {})
|
| 171 |
+
|
| 172 |
+
# Override from frontend/config takes priority over system ENV
|
| 173 |
+
api_key = llm_cfg.get("openai_api_key") or os.environ.get("OPENAI_API_KEY")
|
| 174 |
+
if not api_key:
|
| 175 |
+
env_file = Path(".env")
|
| 176 |
+
if env_file.exists():
|
| 177 |
+
for line in env_file.read_text().splitlines():
|
| 178 |
+
if line.startswith("OPENAI_API_KEY="):
|
| 179 |
+
api_key = line.split("=", 1)[1].strip().strip('"').strip("'")
|
| 180 |
+
break
|
| 181 |
+
|
| 182 |
+
if not api_key:
|
| 183 |
+
raise RuntimeError("OpenAI API key not found. Set OPENAI_API_KEY env var or in .env.")
|
| 184 |
+
|
| 185 |
+
try:
|
| 186 |
+
from openai import OpenAI
|
| 187 |
+
except ImportError:
|
| 188 |
+
raise RuntimeError("openai not installed. Run: pip install openai")
|
| 189 |
+
|
| 190 |
+
model_name = llm_cfg.get("openai_model") or llm_cfg.get("model") or "gpt-4o"
|
| 191 |
+
client = OpenAI(api_key=api_key)
|
| 192 |
+
|
| 193 |
+
logger.info("Calling OpenAI API (model=%s)...", model_name)
|
| 194 |
+
t0 = time.perf_counter()
|
| 195 |
+
|
| 196 |
+
try:
|
| 197 |
+
response = client.chat.completions.create(
|
| 198 |
+
model=model_name,
|
| 199 |
+
messages=[{"role": "user", "content": prompt}],
|
| 200 |
+
temperature=float(llm_cfg.get("generation_temperature", 0.7)),
|
| 201 |
+
max_tokens=1024,
|
| 202 |
+
)
|
| 203 |
+
except Exception as exc:
|
| 204 |
+
raise RuntimeError(f"OpenAI API error: {exc}") from exc
|
| 205 |
+
|
| 206 |
+
elapsed = int((time.perf_counter() - t0) * 1000)
|
| 207 |
+
answer = response.choices[0].message.content.strip()
|
| 208 |
+
|
| 209 |
+
if not answer:
|
| 210 |
+
raise RuntimeError("OpenAI returned an empty response.")
|
| 211 |
+
|
| 212 |
+
logger.info("OpenAI generated answer in %d ms (%d chars)", elapsed, len(answer))
|
| 213 |
+
return answer
|
| 214 |
+
|
| 215 |
+
def _generate_gemini(prompt: str, config: dict) -> str:
|
| 216 |
+
llm_cfg = config.get("llm", {})
|
| 217 |
+
|
| 218 |
+
# Override from frontend/config takes priority over system ENV
|
| 219 |
+
api_key = llm_cfg.get("gemini_api_key") or os.environ.get("GEMINI_API_KEY")
|
| 220 |
+
if not api_key:
|
| 221 |
+
# Try loading from .env file if present
|
| 222 |
+
env_file = Path(".env")
|
| 223 |
+
if env_file.exists():
|
| 224 |
+
for line in env_file.read_text().splitlines():
|
| 225 |
+
if line.startswith("GEMINI_API_KEY="):
|
| 226 |
+
api_key = line.split("=", 1)[1].strip().strip('"').strip("'")
|
| 227 |
+
break
|
| 228 |
+
|
| 229 |
+
if not api_key:
|
| 230 |
+
raise RuntimeError(
|
| 231 |
+
"Gemini API key not found. "
|
| 232 |
+
"Either: (1) set GEMINI_API_KEY=your_key in the same terminal as uvicorn, "
|
| 233 |
+
"or (2) create a .env file with GEMINI_API_KEY=your_key in the project root."
|
| 234 |
+
)
|
| 235 |
+
|
| 236 |
+
try:
|
| 237 |
+
from google import genai
|
| 238 |
+
from google.genai import types
|
| 239 |
+
except ImportError:
|
| 240 |
+
raise RuntimeError(
|
| 241 |
+
"google-genai not installed. Run: pip install google-genai"
|
| 242 |
+
)
|
| 243 |
+
|
| 244 |
+
model_name = llm_cfg.get("gemini_model", "gemini-2.0-flash")
|
| 245 |
+
client = genai.Client(api_key=api_key)
|
| 246 |
+
|
| 247 |
+
logger.info("Calling Gemini API (model=%s)...", model_name)
|
| 248 |
+
t0 = time.perf_counter()
|
| 249 |
+
|
| 250 |
+
try:
|
| 251 |
+
response = client.models.generate_content(
|
| 252 |
+
model=model_name,
|
| 253 |
+
contents=prompt,
|
| 254 |
+
config=types.GenerateContentConfig(
|
| 255 |
+
temperature=float(llm_cfg.get("generation_temperature", 0.7)),
|
| 256 |
+
max_output_tokens=1024,
|
| 257 |
+
),
|
| 258 |
+
)
|
| 259 |
+
except Exception as exc:
|
| 260 |
+
raise RuntimeError(f"Gemini API error: {exc}") from exc
|
| 261 |
+
|
| 262 |
+
elapsed = int((time.perf_counter() - t0) * 1000)
|
| 263 |
+
answer = response.text.strip() if response.text else ""
|
| 264 |
+
|
| 265 |
+
if not answer:
|
| 266 |
+
raise RuntimeError("Gemini returned an empty response.")
|
| 267 |
+
|
| 268 |
+
logger.info("Gemini generated answer in %d ms (%d chars)", elapsed, len(answer))
|
| 269 |
+
return answer
|
| 270 |
+
|
| 271 |
+
|
| 272 |
+
# ---------------------------------------------------------------------------
|
| 273 |
+
# Ollama provider (kept as fallback)
|
| 274 |
+
# ---------------------------------------------------------------------------
|
| 275 |
+
|
| 276 |
+
def _generate_ollama(prompt: str, config: dict) -> str:
|
| 277 |
+
import requests as _requests
|
| 278 |
+
|
| 279 |
+
llm_cfg = config.get("llm", {})
|
| 280 |
+
base_url = llm_cfg.get("base_url", "http://localhost:11434")
|
| 281 |
+
model = llm_cfg.get("model", "mistral")
|
| 282 |
+
timeout = llm_cfg.get("timeout_seconds", 120)
|
| 283 |
+
temperature = llm_cfg.get("generation_temperature", 0.7)
|
| 284 |
+
|
| 285 |
+
payload = {
|
| 286 |
+
"model": model,
|
| 287 |
+
"prompt": prompt,
|
| 288 |
+
"stream": False,
|
| 289 |
+
"options": {"temperature": temperature, "num_predict": 512},
|
| 290 |
+
}
|
| 291 |
+
|
| 292 |
+
url = f"{base_url}/api/generate"
|
| 293 |
+
logger.info("Calling Ollama (%s @ %s)...", model, base_url)
|
| 294 |
+
t0 = time.perf_counter()
|
| 295 |
+
|
| 296 |
+
try:
|
| 297 |
+
resp = _requests.post(url, json=payload, timeout=timeout)
|
| 298 |
+
except _requests.exceptions.ConnectionError as exc:
|
| 299 |
+
raise RuntimeError(
|
| 300 |
+
f"Ollama is not running at {base_url}. Start with: ollama serve"
|
| 301 |
+
) from exc
|
| 302 |
+
except _requests.exceptions.Timeout as exc:
|
| 303 |
+
raise RuntimeError(
|
| 304 |
+
f"Ollama timed out after {timeout}s. Increase llm.timeout_seconds in config.yaml."
|
| 305 |
+
) from exc
|
| 306 |
+
|
| 307 |
+
if resp.status_code != 200:
|
| 308 |
+
raise RuntimeError(f"Ollama HTTP {resp.status_code}: {resp.text[:300]}")
|
| 309 |
+
|
| 310 |
+
try:
|
| 311 |
+
data = resp.json()
|
| 312 |
+
answer = data.get("response", "").strip()
|
| 313 |
+
except (json.JSONDecodeError, KeyError) as exc:
|
| 314 |
+
raise RuntimeError(f"Unexpected Ollama response: {exc}") from exc
|
| 315 |
+
|
| 316 |
+
if not answer:
|
| 317 |
+
raise RuntimeError("Ollama returned an empty response.")
|
| 318 |
+
|
| 319 |
+
elapsed = int((time.perf_counter() - t0) * 1000)
|
| 320 |
+
logger.info("Ollama generated answer in %d ms (%d chars)", elapsed, len(answer))
|
| 321 |
+
return answer
|
| 322 |
+
|
| 323 |
+
|
| 324 |
+
# ---------------------------------------------------------------------------
|
| 325 |
+
# Mistral provider
|
| 326 |
+
# ---------------------------------------------------------------------------
|
| 327 |
+
|
| 328 |
+
def _generate_mistral(prompt: str, config: dict) -> str:
|
| 329 |
+
import requests as _requests
|
| 330 |
+
|
| 331 |
+
llm_cfg = config.get("llm", {})
|
| 332 |
+
# Resolve placeholder or direct value
|
| 333 |
+
_raw_key = llm_cfg.get("mistral_api_key", "")
|
| 334 |
+
api_key = os.environ.get("MISTRAL_API_KEY") if (not _raw_key or _raw_key.startswith("${")) else _raw_key
|
| 335 |
+
if not api_key:
|
| 336 |
+
raise RuntimeError(
|
| 337 |
+
"Mistral API key not found. Set MISTRAL_API_KEY in Backend/.env"
|
| 338 |
+
)
|
| 339 |
+
|
| 340 |
+
model = llm_cfg.get("model", "mistral-large-latest")
|
| 341 |
+
timeout = llm_cfg.get("timeout_seconds", 120)
|
| 342 |
+
temperature = llm_cfg.get("generation_temperature", 0.7)
|
| 343 |
+
|
| 344 |
+
payload = {
|
| 345 |
+
"model": model,
|
| 346 |
+
"messages": [{"role": "user", "content": prompt}],
|
| 347 |
+
"temperature": temperature,
|
| 348 |
+
"max_tokens": 1024,
|
| 349 |
+
}
|
| 350 |
+
|
| 351 |
+
headers = {
|
| 352 |
+
"Authorization": f"Bearer {api_key}",
|
| 353 |
+
"Content-Type": "application/json"
|
| 354 |
+
}
|
| 355 |
+
|
| 356 |
+
url = "https://api.mistral.ai/v1/chat/completions"
|
| 357 |
+
logger.info("Calling Mistral API (model=%s, key=...***)", model)
|
| 358 |
+
t0 = time.perf_counter()
|
| 359 |
+
|
| 360 |
+
try:
|
| 361 |
+
resp = _requests.post(url, json=payload, headers=headers, timeout=timeout)
|
| 362 |
+
except Exception as exc:
|
| 363 |
+
raise RuntimeError(f"Mistral API network error: {exc}") from exc
|
| 364 |
+
|
| 365 |
+
if resp.status_code != 200:
|
| 366 |
+
raise RuntimeError(f"Mistral HTTP {resp.status_code}: {resp.text[:300]}")
|
| 367 |
+
|
| 368 |
+
try:
|
| 369 |
+
data = resp.json()
|
| 370 |
+
answer = data["choices"][0]["message"]["content"].strip()
|
| 371 |
+
except Exception as exc:
|
| 372 |
+
raise RuntimeError(f"Unexpected Mistral response: {exc}") from exc
|
| 373 |
+
|
| 374 |
+
if not answer:
|
| 375 |
+
raise RuntimeError("Mistral returned an empty response.")
|
| 376 |
+
|
| 377 |
+
elapsed = int((time.perf_counter() - t0) * 1000)
|
| 378 |
+
logger.info("Mistral generated answer in %d ms (%d chars)", elapsed, len(answer))
|
| 379 |
+
return answer
|
| 380 |
+
|
| 381 |
+
|
| 382 |
+
# ---------------------------------------------------------------------------
|
| 383 |
+
# Groq provider
|
| 384 |
+
# ---------------------------------------------------------------------------
|
| 385 |
+
|
| 386 |
+
def _generate_groq(prompt: str, config: dict) -> str:
|
| 387 |
+
import requests as _requests
|
| 388 |
+
|
| 389 |
+
llm_cfg = config.get("llm", {})
|
| 390 |
+
_raw_key = llm_cfg.get("groq_api_key", "")
|
| 391 |
+
api_key = os.environ.get("GROQ_API_KEY") if (not _raw_key or _raw_key.startswith("${")) else _raw_key
|
| 392 |
+
if not api_key:
|
| 393 |
+
raise RuntimeError(
|
| 394 |
+
"Groq API key not found. Set GROQ_API_KEY in Backend/.env"
|
| 395 |
+
)
|
| 396 |
+
|
| 397 |
+
model = llm_cfg.get("groq_model") or llm_cfg.get("model", "llama-3.3-70b-versatile")
|
| 398 |
+
timeout = llm_cfg.get("timeout_seconds", 120)
|
| 399 |
+
temperature = llm_cfg.get("generation_temperature", 0.7)
|
| 400 |
+
|
| 401 |
+
payload = {
|
| 402 |
+
"model": model,
|
| 403 |
+
"messages": [{"role": "user", "content": prompt}],
|
| 404 |
+
"temperature": temperature,
|
| 405 |
+
"max_tokens": 1024,
|
| 406 |
+
}
|
| 407 |
+
|
| 408 |
+
headers = {
|
| 409 |
+
"Authorization": f"Bearer {api_key}",
|
| 410 |
+
"Content-Type": "application/json"
|
| 411 |
+
}
|
| 412 |
+
|
| 413 |
+
url = "https://api.groq.com/openai/v1/chat/completions"
|
| 414 |
+
logger.info("Calling Groq API (model=%s, key=...***)", model)
|
| 415 |
+
t0 = time.perf_counter()
|
| 416 |
+
|
| 417 |
+
try:
|
| 418 |
+
resp = _requests.post(url, json=payload, headers=headers, timeout=timeout)
|
| 419 |
+
except Exception as exc:
|
| 420 |
+
raise RuntimeError(f"Groq API network error: {exc}") from exc
|
| 421 |
+
|
| 422 |
+
if resp.status_code != 200:
|
| 423 |
+
raise RuntimeError(f"Groq HTTP {resp.status_code}: {resp.text[:300]}")
|
| 424 |
+
|
| 425 |
+
try:
|
| 426 |
+
data = resp.json()
|
| 427 |
+
answer = data["choices"][0]["message"]["content"].strip()
|
| 428 |
+
except Exception as exc:
|
| 429 |
+
raise RuntimeError(f"Unexpected Groq response: {exc}") from exc
|
| 430 |
+
|
| 431 |
+
if not answer:
|
| 432 |
+
raise RuntimeError("Groq returned an empty response.")
|
| 433 |
+
|
| 434 |
+
elapsed = int((time.perf_counter() - t0) * 1000)
|
| 435 |
+
logger.info("Groq generated answer in %d ms (%d chars)", elapsed, len(answer))
|
| 436 |
+
return answer
|
| 437 |
+
|
| 438 |
+
|
| 439 |
+
# ---------------------------------------------------------------------------
|
| 440 |
+
# Public API
|
| 441 |
+
# ---------------------------------------------------------------------------
|
| 442 |
+
|
| 443 |
+
def generate_answer(
|
| 444 |
+
question: str,
|
| 445 |
+
context_chunks: list[dict],
|
| 446 |
+
config: Optional[dict] = None,
|
| 447 |
+
overrides: Optional[dict] = None,
|
| 448 |
+
) -> str:
|
| 449 |
+
"""
|
| 450 |
+
Generate a grounded medical answer.
|
| 451 |
+
|
| 452 |
+
Provider is selected from config.yaml → llm.provider, but can be
|
| 453 |
+
overridden per-request via the `overrides` dict. This makes the eval
|
| 454 |
+
engine portable — callers bring their own API key and model.
|
| 455 |
+
|
| 456 |
+
Args:
|
| 457 |
+
question : User's medical question.
|
| 458 |
+
context_chunks : Retrieved context chunks (dicts with 'text' key).
|
| 459 |
+
config : Config dict (loaded from config.yaml if None).
|
| 460 |
+
overrides : Per-request overrides. Supported keys:
|
| 461 |
+
provider → "gemini" or "ollama"
|
| 462 |
+
api_key → Gemini API key
|
| 463 |
+
model → model name (e.g. "gemini-2.5-flash-lite")
|
| 464 |
+
ollama_url → Ollama base URL
|
| 465 |
+
|
| 466 |
+
Returns:
|
| 467 |
+
Generated answer string.
|
| 468 |
+
|
| 469 |
+
Raises:
|
| 470 |
+
RuntimeError : If the provider is unreachable or returns an error.
|
| 471 |
+
"""
|
| 472 |
+
if config is None:
|
| 473 |
+
config = _load_config()
|
| 474 |
+
|
| 475 |
+
# Build effective config: server config as base, overrides win
|
| 476 |
+
effective_llm = dict(config.get("llm", {}))
|
| 477 |
+
if overrides:
|
| 478 |
+
if overrides.get("provider"):
|
| 479 |
+
effective_llm["provider"] = overrides["provider"]
|
| 480 |
+
if overrides.get("api_key"):
|
| 481 |
+
pk = (overrides.get("provider") or "gemini").lower()
|
| 482 |
+
key_map = {
|
| 483 |
+
"gemini": "gemini_api_key",
|
| 484 |
+
"openai": "openai_api_key",
|
| 485 |
+
"mistral": "mistral_api_key",
|
| 486 |
+
"groq": "groq_api_key",
|
| 487 |
+
}
|
| 488 |
+
effective_llm[key_map.get(pk, "gemini_api_key")] = overrides["api_key"]
|
| 489 |
+
if overrides.get("model"):
|
| 490 |
+
pk = (overrides.get("provider") or "gemini").lower()
|
| 491 |
+
model_map = {
|
| 492 |
+
"gemini": "gemini_model",
|
| 493 |
+
"openai": "openai_model",
|
| 494 |
+
"mistral": "model",
|
| 495 |
+
"groq": "groq_model",
|
| 496 |
+
}
|
| 497 |
+
effective_llm[model_map.get(pk, "gemini_model")] = overrides["model"]
|
| 498 |
+
if overrides.get("ollama_url"):
|
| 499 |
+
effective_llm["base_url"] = overrides["ollama_url"]
|
| 500 |
+
|
| 501 |
+
effective_config = {**config, "llm": effective_llm}
|
| 502 |
+
provider = effective_llm.get("provider", "gemini").lower()
|
| 503 |
+
system_prompt_override = overrides.get("system_prompt") if overrides else None
|
| 504 |
+
persona = overrides.get("persona", "physician") if overrides else "physician"
|
| 505 |
+
|
| 506 |
+
prompt = _build_prompt(
|
| 507 |
+
question,
|
| 508 |
+
context_chunks,
|
| 509 |
+
system_prompt=system_prompt_override,
|
| 510 |
+
persona=persona
|
| 511 |
+
)
|
| 512 |
+
|
| 513 |
+
if provider == "gemini":
|
| 514 |
+
return _generate_gemini(prompt, effective_config)
|
| 515 |
+
elif provider == "openai":
|
| 516 |
+
return _generate_openai(prompt, effective_config)
|
| 517 |
+
elif provider == "ollama":
|
| 518 |
+
return _generate_ollama(prompt, effective_config)
|
| 519 |
+
elif provider == "mistral":
|
| 520 |
+
return _generate_mistral(prompt, effective_config)
|
| 521 |
+
elif provider == "groq":
|
| 522 |
+
return _generate_groq(prompt, effective_config)
|
| 523 |
+
else:
|
| 524 |
+
raise RuntimeError(
|
| 525 |
+
f"Unknown LLM provider '{provider}'. "
|
| 526 |
+
"Set llm.provider to 'gemini', 'mistral', 'groq', or 'ollama'."
|
| 527 |
+
)
|
| 528 |
+
|
| 529 |
+
|
| 530 |
+
def generate_strict_answer(
|
| 531 |
+
question: str,
|
| 532 |
+
context_chunks: list[dict],
|
| 533 |
+
config: Optional[dict] = None,
|
| 534 |
+
overrides: Optional[dict] = None,
|
| 535 |
+
) -> str:
|
| 536 |
+
"""
|
| 537 |
+
Generate a STRICT context-only answer.
|
| 538 |
+
Called when initial answer fails evaluation (HRS >= 60).
|
| 539 |
+
The LLM is forbidden from using any training knowledge.
|
| 540 |
+
"""
|
| 541 |
+
if config is None:
|
| 542 |
+
config = _load_config()
|
| 543 |
+
|
| 544 |
+
effective_llm = dict(config.get("llm", {}))
|
| 545 |
+
if overrides:
|
| 546 |
+
if overrides.get("provider"):
|
| 547 |
+
effective_llm["provider"] = overrides["provider"]
|
| 548 |
+
if overrides.get("api_key"):
|
| 549 |
+
pk = (overrides.get("provider") or "gemini").lower()
|
| 550 |
+
key_map = {
|
| 551 |
+
"gemini": "gemini_api_key",
|
| 552 |
+
"openai": "openai_api_key",
|
| 553 |
+
"mistral": "mistral_api_key",
|
| 554 |
+
"groq": "groq_api_key",
|
| 555 |
+
}
|
| 556 |
+
effective_llm[key_map.get(pk, "gemini_api_key")] = overrides["api_key"]
|
| 557 |
+
if overrides.get("model"):
|
| 558 |
+
pk = (overrides.get("provider") or "gemini").lower()
|
| 559 |
+
model_map = {
|
| 560 |
+
"gemini": "gemini_model",
|
| 561 |
+
"openai": "openai_model",
|
| 562 |
+
"mistral": "model",
|
| 563 |
+
"groq": "groq_model",
|
| 564 |
+
}
|
| 565 |
+
effective_llm[model_map.get(pk, "gemini_model")] = overrides["model"]
|
| 566 |
+
if overrides.get("ollama_url"):
|
| 567 |
+
effective_llm["base_url"] = overrides["ollama_url"]
|
| 568 |
+
|
| 569 |
+
effective_config = {**config, "llm": effective_llm}
|
| 570 |
+
provider = effective_llm.get("provider", "gemini").lower()
|
| 571 |
+
prompt = _build_strict_prompt(question, context_chunks)
|
| 572 |
+
|
| 573 |
+
if provider == "gemini":
|
| 574 |
+
return _generate_gemini(prompt, effective_config)
|
| 575 |
+
elif provider == "openai":
|
| 576 |
+
return _generate_openai(prompt, effective_config)
|
| 577 |
+
elif provider == "ollama":
|
| 578 |
+
return _generate_ollama(prompt, effective_config)
|
| 579 |
+
elif provider == "mistral":
|
| 580 |
+
return _generate_mistral(prompt, effective_config)
|
| 581 |
+
elif provider == "groq":
|
| 582 |
+
return _generate_groq(prompt, effective_config)
|
| 583 |
+
else:
|
| 584 |
+
raise RuntimeError(f"Unknown LLM provider '{provider}'.")
|
src/pipeline/ingest.py
ADDED
|
@@ -0,0 +1,250 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
FR-01: Document Ingestion
|
| 3 |
+
=========================
|
| 4 |
+
Loads documents from:
|
| 5 |
+
- PubMedQA (HuggingFace: pubmed_qa, pqa_labeled) — up to 500 samples
|
| 6 |
+
- MedQA-USMLE (local JSONL from jind11/MedQA) — up to 200 samples
|
| 7 |
+
|
| 8 |
+
Then calls chunker.py to split and saves chunks to data/processed/chunks.jsonl.
|
| 9 |
+
|
| 10 |
+
Usage:
|
| 11 |
+
python src/pipeline/ingest.py
|
| 12 |
+
python src/pipeline/ingest.py --pubmedqa 500 --medqa 200
|
| 13 |
+
"""
|
| 14 |
+
from __future__ import annotations
|
| 15 |
+
|
| 16 |
+
import sys
|
| 17 |
+
import os
|
| 18 |
+
from pathlib import Path
|
| 19 |
+
|
| 20 |
+
# Make project root importable when running as a script
|
| 21 |
+
sys.path.insert(0, str(Path(__file__).resolve().parent.parent.parent))
|
| 22 |
+
|
| 23 |
+
import argparse
|
| 24 |
+
import json
|
| 25 |
+
import logging
|
| 26 |
+
import uuid
|
| 27 |
+
import yaml
|
| 28 |
+
from typing import Any
|
| 29 |
+
|
| 30 |
+
import src # noqa: F401 — triggers logging setup
|
| 31 |
+
|
| 32 |
+
logger = logging.getLogger(__name__)
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
# ---------------------------------------------------------------------------
|
| 36 |
+
# Config
|
| 37 |
+
# ---------------------------------------------------------------------------
|
| 38 |
+
|
| 39 |
+
def _load_config() -> dict:
|
| 40 |
+
with open("config.yaml", "r", encoding="utf-8") as f:
|
| 41 |
+
return yaml.safe_load(f)
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
# ---------------------------------------------------------------------------
|
| 45 |
+
# PubMedQA Ingestion (FR-01)
|
| 46 |
+
# ---------------------------------------------------------------------------
|
| 47 |
+
|
| 48 |
+
def ingest_pubmedqa(max_samples: int = 500) -> list[dict[str, Any]]:
|
| 49 |
+
"""
|
| 50 |
+
Load PubMedQA from HuggingFace datasets.
|
| 51 |
+
Each QA item contributes its context passages (abstracts) as documents,
|
| 52 |
+
plus its long_answer if available.
|
| 53 |
+
|
| 54 |
+
pub_type = "research_abstract" → Tier 3 (SRS FR-03b)
|
| 55 |
+
"""
|
| 56 |
+
# Use 'pqa_artificial' (211k rows) if asking for more than 1000,
|
| 57 |
+
# as 'pqa_labeled' only has 1000 rows.
|
| 58 |
+
split_name = "pqa_artificial" if max_samples > 1000 else "pqa_labeled"
|
| 59 |
+
logger.info("Loading PubMedQA split='%s' (max %d QA pairs)...", split_name, max_samples)
|
| 60 |
+
try:
|
| 61 |
+
from datasets import load_dataset
|
| 62 |
+
dataset = load_dataset(
|
| 63 |
+
"pubmed_qa", split_name, split="train", trust_remote_code=True
|
| 64 |
+
)
|
| 65 |
+
except Exception as exc:
|
| 66 |
+
logger.error("Failed to load PubMedQA from HuggingFace: %s", exc)
|
| 67 |
+
logger.error("Ensure you have an internet connection and datasets>=2.18.0")
|
| 68 |
+
return []
|
| 69 |
+
|
| 70 |
+
documents: list[dict] = []
|
| 71 |
+
for i, item in enumerate(dataset):
|
| 72 |
+
if i >= max_samples:
|
| 73 |
+
break
|
| 74 |
+
|
| 75 |
+
pub_id = str(item.get("pubid", uuid.uuid4().hex[:8]))
|
| 76 |
+
question = item.get("question", "")[:200]
|
| 77 |
+
|
| 78 |
+
# Index each context passage as a separate document
|
| 79 |
+
contexts: list[str] = item.get("context", {}).get("contexts", [])
|
| 80 |
+
for ctx in contexts:
|
| 81 |
+
if ctx and ctx.strip():
|
| 82 |
+
documents.append({
|
| 83 |
+
"text": ctx.strip(),
|
| 84 |
+
"title": question,
|
| 85 |
+
"doc_id": f"pubmedqa_{pub_id}",
|
| 86 |
+
"source": "pubmedqa",
|
| 87 |
+
"pub_type": "research_abstract",
|
| 88 |
+
"pub_year": 0,
|
| 89 |
+
"journal": "",
|
| 90 |
+
})
|
| 91 |
+
|
| 92 |
+
# Also index the long_answer (gold-standard explanation)
|
| 93 |
+
long_ans: str = item.get("long_answer", "").strip()
|
| 94 |
+
if long_ans:
|
| 95 |
+
documents.append({
|
| 96 |
+
"text": long_ans,
|
| 97 |
+
"title": question,
|
| 98 |
+
"doc_id": f"pubmedqa_{pub_id}_ans",
|
| 99 |
+
"source": "pubmedqa",
|
| 100 |
+
"pub_type": "research_abstract",
|
| 101 |
+
"pub_year": 0,
|
| 102 |
+
"journal": "",
|
| 103 |
+
})
|
| 104 |
+
|
| 105 |
+
logger.info(
|
| 106 |
+
"PubMedQA: %d documents loaded from %d QA items",
|
| 107 |
+
len(documents),
|
| 108 |
+
min(max_samples, len(dataset)),
|
| 109 |
+
)
|
| 110 |
+
return documents
|
| 111 |
+
|
| 112 |
+
|
| 113 |
+
# ---------------------------------------------------------------------------
|
| 114 |
+
# MedQA-USMLE Ingestion (FR-01)
|
| 115 |
+
# ---------------------------------------------------------------------------
|
| 116 |
+
|
| 117 |
+
def ingest_medqa(
|
| 118 |
+
data_dir: str = "data/raw/medqa",
|
| 119 |
+
max_samples: int = 200,
|
| 120 |
+
) -> list[dict[str, Any]]:
|
| 121 |
+
"""
|
| 122 |
+
Load MedQA-USMLE from local JSONL files.
|
| 123 |
+
|
| 124 |
+
To obtain the data:
|
| 125 |
+
git clone https://github.com/jind11/MedQA
|
| 126 |
+
Copy the JSONL files from data_clean/questions/US/ to data/raw/medqa/
|
| 127 |
+
|
| 128 |
+
pub_type = "exam_question" → Tier 5 (SRS FR-03b)
|
| 129 |
+
"""
|
| 130 |
+
data_path = Path(data_dir)
|
| 131 |
+
jsonl_files = sorted(list(data_path.glob("*.jsonl")) + list(data_path.glob("**/*.jsonl")))
|
| 132 |
+
|
| 133 |
+
if not jsonl_files:
|
| 134 |
+
logger.warning(
|
| 135 |
+
"MedQA data not found at '%s'. "
|
| 136 |
+
"To get it: git clone https://github.com/jind11/MedQA "
|
| 137 |
+
"and copy JSONL files to %s/",
|
| 138 |
+
data_dir, data_dir,
|
| 139 |
+
)
|
| 140 |
+
return []
|
| 141 |
+
|
| 142 |
+
logger.info("Loading MedQA from '%s' (%d files)...", data_dir, len(jsonl_files))
|
| 143 |
+
documents: list[dict] = []
|
| 144 |
+
|
| 145 |
+
for jsonl_file in jsonl_files:
|
| 146 |
+
if len(documents) >= max_samples:
|
| 147 |
+
break
|
| 148 |
+
with open(jsonl_file, "r", encoding="utf-8") as f:
|
| 149 |
+
for raw_line in f:
|
| 150 |
+
if len(documents) >= max_samples:
|
| 151 |
+
break
|
| 152 |
+
raw_line = raw_line.strip()
|
| 153 |
+
if not raw_line:
|
| 154 |
+
continue
|
| 155 |
+
try:
|
| 156 |
+
item = json.loads(raw_line)
|
| 157 |
+
except json.JSONDecodeError as exc:
|
| 158 |
+
logger.warning("Skipping malformed JSON in %s: %s", jsonl_file.name, exc)
|
| 159 |
+
continue
|
| 160 |
+
|
| 161 |
+
question: str = item.get("question", "")
|
| 162 |
+
options: dict = item.get("options", {})
|
| 163 |
+
answer_key: str = item.get("answer", "")
|
| 164 |
+
answer_text: str = options.get(answer_key, "")
|
| 165 |
+
|
| 166 |
+
# Combine question + all options + correct answer as document text
|
| 167 |
+
opts_text = " ".join(f"{k}: {v}" for k, v in options.items())
|
| 168 |
+
text = f"Question: {question}\nOptions: {opts_text}"
|
| 169 |
+
if answer_text:
|
| 170 |
+
text += f"\nAnswer ({answer_key}): {answer_text}"
|
| 171 |
+
|
| 172 |
+
documents.append({
|
| 173 |
+
"text": text,
|
| 174 |
+
"title": question[:200],
|
| 175 |
+
"doc_id": f"medqa_{uuid.uuid4().hex[:10]}",
|
| 176 |
+
"source": "medqa",
|
| 177 |
+
"pub_type": "exam_question",
|
| 178 |
+
"pub_year": 0,
|
| 179 |
+
"journal": "",
|
| 180 |
+
})
|
| 181 |
+
|
| 182 |
+
logger.info("MedQA: %d documents loaded", len(documents))
|
| 183 |
+
return documents
|
| 184 |
+
|
| 185 |
+
|
| 186 |
+
# ---------------------------------------------------------------------------
|
| 187 |
+
# Helpers
|
| 188 |
+
# ---------------------------------------------------------------------------
|
| 189 |
+
|
| 190 |
+
def _save_raw_documents(documents: list[dict], output_path: str) -> None:
|
| 191 |
+
out = Path(output_path)
|
| 192 |
+
out.parent.mkdir(parents=True, exist_ok=True)
|
| 193 |
+
with open(out, "w", encoding="utf-8") as f:
|
| 194 |
+
for doc in documents:
|
| 195 |
+
f.write(json.dumps(doc, ensure_ascii=False) + "\n")
|
| 196 |
+
logger.info("Saved %d raw documents to %s", len(documents), output_path)
|
| 197 |
+
|
| 198 |
+
|
| 199 |
+
def _save_chunks(chunks: list[dict], output_path: str) -> None:
|
| 200 |
+
out = Path(output_path)
|
| 201 |
+
out.parent.mkdir(parents=True, exist_ok=True)
|
| 202 |
+
with open(out, "w", encoding="utf-8") as f:
|
| 203 |
+
for chunk in chunks:
|
| 204 |
+
f.write(json.dumps(chunk, ensure_ascii=False) + "\n")
|
| 205 |
+
logger.info("Saved %d chunks to %s", len(chunks), output_path)
|
| 206 |
+
|
| 207 |
+
|
| 208 |
+
# ---------------------------------------------------------------------------
|
| 209 |
+
# Main
|
| 210 |
+
# ---------------------------------------------------------------------------
|
| 211 |
+
|
| 212 |
+
def main() -> None:
|
| 213 |
+
parser = argparse.ArgumentParser(description="MediRAG-Eval Document Ingestion (FR-01)")
|
| 214 |
+
parser.add_argument("--pubmedqa", type=int, default=500, help="Max PubMedQA samples")
|
| 215 |
+
parser.add_argument("--medqa", type=int, default=200, help="Max MedQA-USMLE samples")
|
| 216 |
+
parser.add_argument(
|
| 217 |
+
"--medqa-dir", default="data/raw/medqa",
|
| 218 |
+
help="Directory containing MedQA JSONL files",
|
| 219 |
+
)
|
| 220 |
+
args = parser.parse_args()
|
| 221 |
+
|
| 222 |
+
config = _load_config()
|
| 223 |
+
|
| 224 |
+
# --- Ingest ---
|
| 225 |
+
pubmedqa_docs = ingest_pubmedqa(max_samples=args.pubmedqa)
|
| 226 |
+
medqa_docs = ingest_medqa(data_dir=args.medqa_dir, max_samples=args.medqa)
|
| 227 |
+
all_docs = pubmedqa_docs + medqa_docs
|
| 228 |
+
|
| 229 |
+
logger.info("Total documents ingested: %d", len(all_docs))
|
| 230 |
+
|
| 231 |
+
if not all_docs:
|
| 232 |
+
logger.error("No documents loaded. Check internet for PubMedQA and/or data/raw/medqa/ for MedQA.")
|
| 233 |
+
sys.exit(1)
|
| 234 |
+
|
| 235 |
+
# --- Save raw documents (for inspection) ---
|
| 236 |
+
_save_raw_documents(all_docs, "data/raw/documents.jsonl")
|
| 237 |
+
|
| 238 |
+
# --- Chunk ---
|
| 239 |
+
from src.pipeline.chunker import chunk_documents
|
| 240 |
+
chunks = chunk_documents(all_docs, config)
|
| 241 |
+
logger.info("Total chunks produced: %d", len(chunks))
|
| 242 |
+
|
| 243 |
+
# --- Save chunks for embedder ---
|
| 244 |
+
_save_chunks(chunks, "data/processed/chunks.jsonl")
|
| 245 |
+
|
| 246 |
+
logger.info("Ingestion complete. Now run: python src/pipeline/embedder.py")
|
| 247 |
+
|
| 248 |
+
|
| 249 |
+
if __name__ == "__main__":
|
| 250 |
+
main()
|
src/pipeline/privacy.py
ADDED
|
@@ -0,0 +1,65 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
src/pipeline/privacy.py — PHI/PII Privacy Shield (The Sanitizer)
|
| 3 |
+
==============================================================
|
| 4 |
+
Detects and redacts sensitive patient information before external API calls.
|
| 5 |
+
Supports names, dates, contact info, and generic medical IDs.
|
| 6 |
+
"""
|
| 7 |
+
from __future__ import annotations
|
| 8 |
+
import re
|
| 9 |
+
import logging
|
| 10 |
+
from typing import Dict, Tuple
|
| 11 |
+
|
| 12 |
+
logger = logging.getLogger(__name__)
|
| 13 |
+
|
| 14 |
+
class PrivacyShield:
|
| 15 |
+
def __init__(self):
|
| 16 |
+
# Basic patterns for common PII
|
| 17 |
+
self.patterns = {
|
| 18 |
+
"EMAIL": r'[a-zA-Z0-9_.+-]+@[a-zA-Z0-9-]+\.[a-zA-Z0-9-.]+',
|
| 19 |
+
"PHONE": r'\(?\d{3}\)?[-.\s]?\d{3}[-.\s]?\d{4}',
|
| 20 |
+
"SSN": r'\b\d{3}-\d{2}-\d{4}\b',
|
| 21 |
+
"DOB": r'\b\d{2}/\d{2}/\d{4}\b|\b\d{4}-\d{2}-\d{2}\b',
|
| 22 |
+
"ID": r'\bPT-\d{4,8}\b|\bID:\s?\d{4,8}\b'
|
| 23 |
+
}
|
| 24 |
+
# Names are harder without heavy NER, so we start with common indicators or capital patterns
|
| 25 |
+
# In a production app, we would use a dedicated medical NER model.
|
| 26 |
+
self.name_pattern = r'\b(?:Mr\.|Ms\.|Mrs\.|Dr\.)\s[A-Z][a-z]+(?:\s[A-Z][a-z]+)?\b'
|
| 27 |
+
|
| 28 |
+
def redact(self, text: str) -> Tuple[str, Dict[str, str]]:
|
| 29 |
+
"""
|
| 30 |
+
Redacts PHI in text and returns (redacted_text, placeholder_map).
|
| 31 |
+
"""
|
| 32 |
+
mapping = {}
|
| 33 |
+
redacted = text
|
| 34 |
+
|
| 35 |
+
# 1. Redact specific patterns
|
| 36 |
+
for label, pattern in self.patterns.items():
|
| 37 |
+
matches = re.findall(pattern, redacted)
|
| 38 |
+
for i, match in enumerate(set(matches)):
|
| 39 |
+
placeholder = f"[{label}_{i+1}]"
|
| 40 |
+
mapping[placeholder] = match
|
| 41 |
+
redacted = redacted.replace(match, placeholder)
|
| 42 |
+
|
| 43 |
+
# 2. Redact potential names
|
| 44 |
+
name_matches = re.findall(self.name_pattern, redacted)
|
| 45 |
+
for i, match in enumerate(set(name_matches)):
|
| 46 |
+
placeholder = f"[PATIENT_NAME_{i+1}]"
|
| 47 |
+
mapping[placeholder] = match
|
| 48 |
+
redacted = redacted.replace(match, placeholder)
|
| 49 |
+
|
| 50 |
+
if mapping:
|
| 51 |
+
logger.info("Privacy Shield: Redacted %d sensitive items.", len(mapping))
|
| 52 |
+
|
| 53 |
+
return redacted, mapping
|
| 54 |
+
|
| 55 |
+
def restore(self, text: str, mapping: Dict[str, str]) -> str:
|
| 56 |
+
"""
|
| 57 |
+
Replaces placeholders in the AI response with original values.
|
| 58 |
+
"""
|
| 59 |
+
restored = text
|
| 60 |
+
for placeholder, original in mapping.items():
|
| 61 |
+
restored = restored.replace(placeholder, original)
|
| 62 |
+
return restored
|
| 63 |
+
|
| 64 |
+
# Singleton instance
|
| 65 |
+
shield = PrivacyShield()
|
src/pipeline/retriever.py
ADDED
|
@@ -0,0 +1,463 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
FR-04: Vector Retrieval
|
| 3 |
+
=======================
|
| 4 |
+
FAISS IndexFlatIP with L2-normalised vectors (inner product = cosine similarity).
|
| 5 |
+
Returns top-k chunks as (chunk_text, metadata_dict, similarity_score) tuples.
|
| 6 |
+
|
| 7 |
+
Usage (as a module):
|
| 8 |
+
from src.pipeline.retriever import Retriever
|
| 9 |
+
r = Retriever(config)
|
| 10 |
+
results = r.search("What is the treatment for Type 2 Diabetes?")
|
| 11 |
+
for text, meta, score in results:
|
| 12 |
+
print(score, meta["pub_type"], text[:80])
|
| 13 |
+
|
| 14 |
+
Usage (smoke test):
|
| 15 |
+
python src/pipeline/retriever.py
|
| 16 |
+
"""
|
| 17 |
+
from __future__ import annotations
|
| 18 |
+
|
| 19 |
+
import sys
|
| 20 |
+
from pathlib import Path
|
| 21 |
+
|
| 22 |
+
sys.path.insert(0, str(Path(__file__).resolve().parent.parent.parent))
|
| 23 |
+
|
| 24 |
+
import logging
|
| 25 |
+
import pickle
|
| 26 |
+
from typing import Any
|
| 27 |
+
|
| 28 |
+
try:
|
| 29 |
+
import faiss
|
| 30 |
+
except ImportError:
|
| 31 |
+
faiss = None
|
| 32 |
+
|
| 33 |
+
import numpy as np
|
| 34 |
+
import yaml
|
| 35 |
+
|
| 36 |
+
logger = logging.getLogger(__name__)
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
class Retriever:
|
| 40 |
+
"""
|
| 41 |
+
Hybrid FAISS + BM25 document retriever.
|
| 42 |
+
|
| 43 |
+
On first search, lazily builds a BM25 index over all chunk texts.
|
| 44 |
+
Each search runs both FAISS (semantic) and BM25 (keyword) then merges
|
| 45 |
+
results using Reciprocal Rank Fusion (RRF) for best-of-both precision
|
| 46 |
+
and recall.
|
| 47 |
+
"""
|
| 48 |
+
|
| 49 |
+
RERANKER_MODEL = "cross-encoder/ms-marco-MiniLM-L-6-v2"
|
| 50 |
+
RERANK_CANDIDATES = 60 # retrieve this many via RRF, then re-rank to top_k
|
| 51 |
+
|
| 52 |
+
def __init__(self, config: dict) -> None:
|
| 53 |
+
self.config = config
|
| 54 |
+
self.top_k: int = config["retrieval"]["top_k"]
|
| 55 |
+
self.model_name: str = config["retrieval"]["embedding_model"]
|
| 56 |
+
self.index_path: str = config["retrieval"]["index_path"]
|
| 57 |
+
self.meta_path: str = config["retrieval"]["metadata_path"]
|
| 58 |
+
|
| 59 |
+
self._model = None
|
| 60 |
+
self._reranker = None # cross-encoder re-ranker, loaded lazily
|
| 61 |
+
self._index = None
|
| 62 |
+
self._metadata: dict[int, dict] | None = None
|
| 63 |
+
self._bm25 = None # built lazily on first search
|
| 64 |
+
self._bm25_ids: list[int] = [] # maps bm25 row → faiss_idx
|
| 65 |
+
|
| 66 |
+
# ------------------------------------------------------------------
|
| 67 |
+
# Private loaders (lazy)
|
| 68 |
+
# ------------------------------------------------------------------
|
| 69 |
+
|
| 70 |
+
def _load_model(self) -> None:
|
| 71 |
+
if self._model is None:
|
| 72 |
+
try:
|
| 73 |
+
from sentence_transformers import SentenceTransformer
|
| 74 |
+
logger.info("Loading BioBERT: %s", self.model_name)
|
| 75 |
+
self._model = SentenceTransformer(self.model_name)
|
| 76 |
+
logger.info("BioBERT model loaded successfully.")
|
| 77 |
+
except ImportError as e:
|
| 78 |
+
logger.error("sentence_transformers not installed: %s", e)
|
| 79 |
+
self._model = None
|
| 80 |
+
except Exception as e:
|
| 81 |
+
logger.error("Failed to load embedding model '%s': %s — FAISS search will be skipped, falling back to BM25.", self.model_name, e)
|
| 82 |
+
self._model = None
|
| 83 |
+
|
| 84 |
+
def _load_reranker(self) -> None:
|
| 85 |
+
if self._reranker is None:
|
| 86 |
+
try:
|
| 87 |
+
from sentence_transformers import CrossEncoder
|
| 88 |
+
logger.info("Loading re-ranker: %s", self.RERANKER_MODEL)
|
| 89 |
+
self._reranker = CrossEncoder(self.RERANKER_MODEL)
|
| 90 |
+
logger.info("Re-ranker loaded.")
|
| 91 |
+
except Exception as e:
|
| 92 |
+
logger.warning("Re-ranker unavailable (%s) — falling back to RRF ranking.", e)
|
| 93 |
+
self._reranker = "unavailable"
|
| 94 |
+
|
| 95 |
+
def _load_index(self) -> None:
|
| 96 |
+
if self._index is not None:
|
| 97 |
+
return
|
| 98 |
+
|
| 99 |
+
idx_path = Path(self.index_path)
|
| 100 |
+
meta_path = Path(self.meta_path)
|
| 101 |
+
|
| 102 |
+
if not idx_path.exists():
|
| 103 |
+
raise FileNotFoundError(
|
| 104 |
+
f"FAISS index not found at '{idx_path}'. "
|
| 105 |
+
"Run python src/pipeline/ingest.py && python src/pipeline/embedder.py first."
|
| 106 |
+
)
|
| 107 |
+
|
| 108 |
+
try:
|
| 109 |
+
logger.info("Loading FAISS index from %s", idx_path)
|
| 110 |
+
if faiss is not None:
|
| 111 |
+
self._index = faiss.read_index(str(idx_path))
|
| 112 |
+
else:
|
| 113 |
+
self._index = None
|
| 114 |
+
logger.warning("FAISS not installed — FAISS search disabled.")
|
| 115 |
+
|
| 116 |
+
logger.info("Loading metadata store from %s", meta_path)
|
| 117 |
+
with open(meta_path, "rb") as f:
|
| 118 |
+
self._metadata = pickle.load(f)
|
| 119 |
+
|
| 120 |
+
logger.info(
|
| 121 |
+
"Retriever ready: %d vectors, %d metadata entries",
|
| 122 |
+
self._index.ntotal if self._index is not None else 0, len(self._metadata),
|
| 123 |
+
)
|
| 124 |
+
# Build drug→FDA chunks lookup (O(1) at query time)
|
| 125 |
+
self._fda_index: dict[str, list[int]] = {}
|
| 126 |
+
for idx, meta in self._metadata.items():
|
| 127 |
+
if meta.get("source") == "FDA DailyMed":
|
| 128 |
+
doc_id = meta.get("doc_id", "")
|
| 129 |
+
# doc_id format: fda_{drug_name}_{set_id}
|
| 130 |
+
parts = doc_id.split("_")
|
| 131 |
+
drug_key = parts[1].lower() if len(parts) >= 2 else ""
|
| 132 |
+
if drug_key:
|
| 133 |
+
self._fda_index.setdefault(drug_key, []).append(idx)
|
| 134 |
+
logger.info("FDA drug index built: %d unique drugs", len(self._fda_index))
|
| 135 |
+
|
| 136 |
+
# Build keyword→guideline chunks lookup for clinical guidelines
|
| 137 |
+
self._guideline_index: dict[str, list[int]] = {}
|
| 138 |
+
for idx, meta in self._metadata.items():
|
| 139 |
+
if meta.get("pub_type") == "clinical_guideline":
|
| 140 |
+
text = (meta.get("chunk_text", "") + " " + meta.get("title", "")).lower()
|
| 141 |
+
for keyword in [
|
| 142 |
+
# Diabetes / ADA
|
| 143 |
+
"diagnosis", "diagnostic", "treatment", "pharmacologic",
|
| 144 |
+
"glycemic", "insulin", "obesity", "hypoglycemia",
|
| 145 |
+
"screening", "complication", "pregnancy",
|
| 146 |
+
"children", "adolescent", "older adult", "hospital",
|
| 147 |
+
# Cardiovascular / ACC-AHA
|
| 148 |
+
"hypertension", "blood pressure", "antihypertensive",
|
| 149 |
+
"statin", "cholesterol", "ldl", "lipid", "triglyceride",
|
| 150 |
+
"cardiovascular", "coronary", "heart disease", "stroke",
|
| 151 |
+
"aspirin", "antiplatelet", "anticoagulant",
|
| 152 |
+
"prevention", "risk reduction", "atherosclerosis",
|
| 153 |
+
"heart failure", "ejection fraction",
|
| 154 |
+
"smoking", "exercise", "diet", "lifestyle",
|
| 155 |
+
]:
|
| 156 |
+
if keyword in text:
|
| 157 |
+
self._guideline_index.setdefault(keyword, []).append(idx)
|
| 158 |
+
logger.info("Guideline index built: %d keyword entries", len(self._guideline_index))
|
| 159 |
+
except Exception as e:
|
| 160 |
+
logger.error("Failed to load FAISS index or metadata: %s", e)
|
| 161 |
+
self._index = None
|
| 162 |
+
if self._metadata is None:
|
| 163 |
+
self._metadata = {}
|
| 164 |
+
|
| 165 |
+
def _build_bm25(self) -> None:
|
| 166 |
+
"""Build BM25 index from the loaded metadata store (called once)."""
|
| 167 |
+
if self._bm25 is not None:
|
| 168 |
+
return
|
| 169 |
+
self.rebuild_bm25()
|
| 170 |
+
|
| 171 |
+
def rebuild_bm25(self) -> None:
|
| 172 |
+
"""Build BM25 index — loads from cache if available, otherwise builds and saves."""
|
| 173 |
+
try:
|
| 174 |
+
from rank_bm25 import BM25Okapi
|
| 175 |
+
except ImportError:
|
| 176 |
+
logger.warning("rank-bm25 not installed — falling back to FAISS-only.")
|
| 177 |
+
return
|
| 178 |
+
|
| 179 |
+
if self._metadata is None:
|
| 180 |
+
self._load_index()
|
| 181 |
+
|
| 182 |
+
# Cache path: alongside the metadata store
|
| 183 |
+
bm25_cache = Path(self.meta_path).parent / "bm25_cache.pkl"
|
| 184 |
+
meta_mtime = Path(self.meta_path).stat().st_mtime if Path(self.meta_path).exists() else 0
|
| 185 |
+
|
| 186 |
+
# Load from cache if it exists and is newer than the metadata store
|
| 187 |
+
if bm25_cache.exists() and bm25_cache.stat().st_mtime >= meta_mtime:
|
| 188 |
+
try:
|
| 189 |
+
logger.info("Loading BM25 index from cache %s …", bm25_cache)
|
| 190 |
+
with open(bm25_cache, "rb") as f:
|
| 191 |
+
cached = pickle.load(f)
|
| 192 |
+
self._bm25 = cached["bm25"]
|
| 193 |
+
self._bm25_ids = cached["ids"]
|
| 194 |
+
logger.info("BM25 cache loaded (%d docs).", len(self._bm25_ids))
|
| 195 |
+
return
|
| 196 |
+
except Exception as e:
|
| 197 |
+
logger.warning("BM25 cache load failed (%s) — rebuilding.", e)
|
| 198 |
+
|
| 199 |
+
logger.info("Rebuilding BM25 index over %d chunks…", len(self._metadata))
|
| 200 |
+
corpus_ids: list[int] = []
|
| 201 |
+
corpus_tokens: list[list[str]] = []
|
| 202 |
+
for faiss_idx, meta in self._metadata.items():
|
| 203 |
+
text = meta.get("chunk_text", "")
|
| 204 |
+
if text:
|
| 205 |
+
corpus_ids.append(faiss_idx)
|
| 206 |
+
corpus_tokens.append(text.lower().split())
|
| 207 |
+
|
| 208 |
+
self._bm25 = BM25Okapi(corpus_tokens)
|
| 209 |
+
self._bm25_ids = corpus_ids
|
| 210 |
+
logger.info("BM25 index built (%d docs). Saving cache…", len(corpus_ids))
|
| 211 |
+
|
| 212 |
+
try:
|
| 213 |
+
with open(bm25_cache, "wb") as f:
|
| 214 |
+
pickle.dump({"bm25": self._bm25, "ids": self._bm25_ids}, f,
|
| 215 |
+
protocol=pickle.HIGHEST_PROTOCOL)
|
| 216 |
+
logger.info("BM25 cache saved to %s", bm25_cache)
|
| 217 |
+
except Exception as e:
|
| 218 |
+
logger.warning("BM25 cache save failed: %s", e)
|
| 219 |
+
|
| 220 |
+
def get_fda_chunks(self, drug_name: str, section_priority: list[str] | None = None) -> list[dict]:
|
| 221 |
+
"""
|
| 222 |
+
Directly return FDA DailyMed chunks for a specific drug by name.
|
| 223 |
+
Bypasses FAISS/BM25 ranking — O(1) lookup, always finds the drug's label.
|
| 224 |
+
Used during intervention re-retrieval when entity_verifier identifies a drug.
|
| 225 |
+
"""
|
| 226 |
+
self._load_index()
|
| 227 |
+
key = drug_name.lower().strip()
|
| 228 |
+
indices = getattr(self, "_fda_index", {}).get(key, [])
|
| 229 |
+
if not indices:
|
| 230 |
+
# Try partial match (e.g. "warfarin sodium" → "warfarin")
|
| 231 |
+
indices = next(
|
| 232 |
+
(v for k, v in getattr(self, "_fda_index", {}).items() if key in k or k in key),
|
| 233 |
+
[]
|
| 234 |
+
)
|
| 235 |
+
chunks = []
|
| 236 |
+
priority = section_priority or ["CONTRAINDICATIONS", "ADVERSE REACTIONS",
|
| 237 |
+
"DOSAGE AND ADMINISTRATION", "WARNINGS AND PRECAUTIONS",
|
| 238 |
+
"DRUG INTERACTIONS", "INDICATIONS AND USAGE",
|
| 239 |
+
"USE IN SPECIFIC POPULATIONS"]
|
| 240 |
+
for idx in indices:
|
| 241 |
+
meta = self._metadata.get(idx, {})
|
| 242 |
+
chunk_text = meta.get("chunk_text", "")
|
| 243 |
+
section = next((s for s in priority if s in chunk_text.upper()), "OTHER")
|
| 244 |
+
chunks.append({
|
| 245 |
+
"text": chunk_text, "chunk_id": meta.get("chunk_id"),
|
| 246 |
+
"source": meta.get("source", ""), "pub_type": meta.get("pub_type", ""),
|
| 247 |
+
"pub_year": meta.get("pub_year"), "title": meta.get("title", ""),
|
| 248 |
+
"_section": section, "_priority": priority.index(section) if section in priority else 99,
|
| 249 |
+
})
|
| 250 |
+
chunks.sort(key=lambda c: c["_priority"])
|
| 251 |
+
return chunks[:5]
|
| 252 |
+
|
| 253 |
+
def get_guideline_chunks(self, query: str, top_n: int = 5) -> list[dict]:
|
| 254 |
+
"""
|
| 255 |
+
Return clinical guideline chunks relevant to the query via keyword matching.
|
| 256 |
+
Bypasses FAISS/BM25 ranking — used during intervention when retrieval fails.
|
| 257 |
+
"""
|
| 258 |
+
self._load_index()
|
| 259 |
+
query_lower = query.lower()
|
| 260 |
+
guideline_idx = getattr(self, "_guideline_index", {})
|
| 261 |
+
if not guideline_idx:
|
| 262 |
+
return []
|
| 263 |
+
|
| 264 |
+
# Find matching indices — union of all matching keyword lists
|
| 265 |
+
matched: dict[int, int] = {} # idx → match count
|
| 266 |
+
for keyword, indices in guideline_idx.items():
|
| 267 |
+
if keyword in query_lower:
|
| 268 |
+
for idx in indices:
|
| 269 |
+
matched[idx] = matched.get(idx, 0) + 1
|
| 270 |
+
|
| 271 |
+
if not matched:
|
| 272 |
+
return []
|
| 273 |
+
|
| 274 |
+
# Sort by match count (most keyword hits first), take top_n
|
| 275 |
+
top_indices = sorted(matched, key=lambda i: matched[i], reverse=True)[:top_n]
|
| 276 |
+
|
| 277 |
+
chunks = []
|
| 278 |
+
for idx in top_indices:
|
| 279 |
+
meta = self._metadata.get(idx, {})
|
| 280 |
+
chunks.append({
|
| 281 |
+
"text": meta.get("chunk_text", ""),
|
| 282 |
+
"chunk_id": meta.get("chunk_id"),
|
| 283 |
+
"source": meta.get("source", ""),
|
| 284 |
+
"pub_type": meta.get("pub_type", "clinical_guideline"),
|
| 285 |
+
"pub_year": meta.get("pub_year"),
|
| 286 |
+
"title": meta.get("title", ""),
|
| 287 |
+
})
|
| 288 |
+
return chunks
|
| 289 |
+
|
| 290 |
+
# ------------------------------------------------------------------
|
| 291 |
+
# Public API
|
| 292 |
+
# ------------------------------------------------------------------
|
| 293 |
+
|
| 294 |
+
def search(
|
| 295 |
+
self,
|
| 296 |
+
query: str,
|
| 297 |
+
top_k: int | None = None,
|
| 298 |
+
) -> list[tuple[str, dict[str, Any], float]]:
|
| 299 |
+
"""
|
| 300 |
+
Hybrid semantic + keyword search using Reciprocal Rank Fusion.
|
| 301 |
+
|
| 302 |
+
Args:
|
| 303 |
+
query : Natural language query
|
| 304 |
+
top_k : Override config top_k if provided
|
| 305 |
+
|
| 306 |
+
Returns:
|
| 307 |
+
List of (chunk_text, metadata_dict, rrf_score),
|
| 308 |
+
sorted by descending combined score.
|
| 309 |
+
"""
|
| 310 |
+
if not query or not query.strip():
|
| 311 |
+
logger.warning("Retriever.search called with empty query — returning []")
|
| 312 |
+
return []
|
| 313 |
+
|
| 314 |
+
k = top_k or self.top_k
|
| 315 |
+
# Fetch RERANK_CANDIDATES via RRF, then re-rank to top-k
|
| 316 |
+
fetch_k = max(self.RERANK_CANDIDATES, k * 3)
|
| 317 |
+
RRF_K = 60 # standard RRF constant (higher = smoother rank blending)
|
| 318 |
+
|
| 319 |
+
self._load_model()
|
| 320 |
+
self._load_reranker()
|
| 321 |
+
self._load_index()
|
| 322 |
+
self._build_bm25()
|
| 323 |
+
|
| 324 |
+
# ── 1. FAISS semantic search ──────────────────────────────────
|
| 325 |
+
faiss_ranks: dict[int, int] = {}
|
| 326 |
+
if self._model is not None and self._index is not None and faiss is not None:
|
| 327 |
+
try:
|
| 328 |
+
q_vec: np.ndarray = self._model.encode(
|
| 329 |
+
[query.strip()],
|
| 330 |
+
normalize_embeddings=True,
|
| 331 |
+
convert_to_numpy=True,
|
| 332 |
+
).astype(np.float32)
|
| 333 |
+
|
| 334 |
+
scores_arr, idx_arr = self._index.search(q_vec, fetch_k)
|
| 335 |
+
faiss_scores = scores_arr[0]
|
| 336 |
+
faiss_indices = idx_arr[0]
|
| 337 |
+
|
| 338 |
+
# Map faiss_idx → rank (1-indexed)
|
| 339 |
+
for rank, (faiss_idx, score) in enumerate(zip(faiss_indices, faiss_scores), 1):
|
| 340 |
+
if faiss_idx != -1:
|
| 341 |
+
faiss_ranks[int(faiss_idx)] = rank
|
| 342 |
+
|
| 343 |
+
# Raw top-1 cosine similarity (IndexFlatIP + L2-norm = cosine).
|
| 344 |
+
# Used by main.py for coverage-gap detection — a poor match here
|
| 345 |
+
# means the topic is genuinely absent from the database.
|
| 346 |
+
_top_faiss_cosine = float(faiss_scores[0]) if len(faiss_scores) > 0 else 0.0
|
| 347 |
+
except Exception as e:
|
| 348 |
+
logger.error("FAISS search failed: %s", e)
|
| 349 |
+
|
| 350 |
+
# If FAISS failed but BM25 is available, continue with BM25-only (no stub)
|
| 351 |
+
if not faiss_ranks and self._bm25 is not None:
|
| 352 |
+
_top_faiss_cosine = 0.0 # no FAISS score available
|
| 353 |
+
logger.warning("FAISS model unavailable — using BM25-only search for this query.")
|
| 354 |
+
|
| 355 |
+
# Only return empty if BOTH are completely unavailable
|
| 356 |
+
if not faiss_ranks and self._bm25 is None:
|
| 357 |
+
logger.error("Both FAISS and BM25 are unavailable. Cannot retrieve. Check that the index exists and dependencies are installed.")
|
| 358 |
+
return []
|
| 359 |
+
|
| 360 |
+
# ── 2. BM25 keyword search ────────────────────────────────────
|
| 361 |
+
bm25_ranks: dict[int, int] = {}
|
| 362 |
+
if self._bm25 is not None:
|
| 363 |
+
query_tokens = query.lower().split()
|
| 364 |
+
bm25_scores_arr = self._bm25.get_scores(query_tokens)
|
| 365 |
+
# Get top fetch_k indices by BM25 score
|
| 366 |
+
top_bm25 = np.argsort(bm25_scores_arr)[::-1][:fetch_k]
|
| 367 |
+
for rank, corpus_pos in enumerate(top_bm25, 1):
|
| 368 |
+
if bm25_scores_arr[corpus_pos] > 0:
|
| 369 |
+
faiss_idx = self._bm25_ids[corpus_pos]
|
| 370 |
+
bm25_ranks[faiss_idx] = rank
|
| 371 |
+
|
| 372 |
+
# ── 3. Reciprocal Rank Fusion ─────────────────────────────────
|
| 373 |
+
# Score = 1/(k+rank_faiss) + 1/(k+rank_bm25)
|
| 374 |
+
# A chunk only in FAISS gets 1/(60+rank); only in BM25 gets 1/(60+rank)
|
| 375 |
+
# A chunk in BOTH gets the sum — it floats to the top
|
| 376 |
+
all_ids = set(faiss_ranks.keys()) | set(bm25_ranks.keys())
|
| 377 |
+
rrf_scores: dict[int, float] = {}
|
| 378 |
+
for faiss_idx in all_ids:
|
| 379 |
+
score = 0.0
|
| 380 |
+
if faiss_idx in faiss_ranks:
|
| 381 |
+
score += 1.0 / (RRF_K + faiss_ranks[faiss_idx])
|
| 382 |
+
if faiss_idx in bm25_ranks:
|
| 383 |
+
score += 1.0 / (RRF_K + bm25_ranks[faiss_idx])
|
| 384 |
+
rrf_scores[faiss_idx] = score
|
| 385 |
+
|
| 386 |
+
# Capture absolute quality BEFORE normalising (used for retrieval confidence gate)
|
| 387 |
+
max_rrf_absolute = max(rrf_scores.values()) if rrf_scores else 0.0
|
| 388 |
+
|
| 389 |
+
# Normalise RRF scores to [0, 1] for display
|
| 390 |
+
if rrf_scores and max_rrf_absolute > 0:
|
| 391 |
+
rrf_scores = {k: v / max_rrf_absolute for k, v in rrf_scores.items()}
|
| 392 |
+
|
| 393 |
+
# Sort by RRF score descending — take RERANK_CANDIDATES (not just top-k)
|
| 394 |
+
candidate_ids = sorted(rrf_scores.keys(), key=lambda i: rrf_scores[i], reverse=True)[:self.RERANK_CANDIDATES]
|
| 395 |
+
|
| 396 |
+
candidates: list[tuple[str, dict, float]] = []
|
| 397 |
+
for faiss_idx in candidate_ids:
|
| 398 |
+
meta = self._metadata.get(faiss_idx, {})
|
| 399 |
+
text = meta.get("chunk_text", "")
|
| 400 |
+
meta["_retrieval_confidence"] = round(max_rrf_absolute, 6)
|
| 401 |
+
meta["_top_faiss_cosine"] = round(_top_faiss_cosine, 4)
|
| 402 |
+
candidates.append((text, meta, rrf_scores[faiss_idx]))
|
| 403 |
+
|
| 404 |
+
# ── Re-ranking ────────────────────────────────────────────────────
|
| 405 |
+
# Cross-encoder scores every (query, chunk) pair directly.
|
| 406 |
+
# No volume bias — the right chunk wins on relevance regardless of source.
|
| 407 |
+
if self._reranker and self._reranker != "unavailable" and len(candidates) > k:
|
| 408 |
+
pairs = [(query, text) for text, _, _ in candidates]
|
| 409 |
+
rerank_scores = self._reranker.predict(pairs)
|
| 410 |
+
ranked = sorted(
|
| 411 |
+
zip(rerank_scores, candidates),
|
| 412 |
+
key=lambda x: x[0],
|
| 413 |
+
reverse=True,
|
| 414 |
+
)
|
| 415 |
+
results = [item for _, item in ranked[:k]]
|
| 416 |
+
logger.debug("Re-ranked %d candidates → top-%d", len(candidates), k)
|
| 417 |
+
else:
|
| 418 |
+
results = candidates[:k]
|
| 419 |
+
|
| 420 |
+
logger.debug(
|
| 421 |
+
"Hybrid query '%s...' → %d results (top RRF=%.4f) "
|
| 422 |
+
"[FAISS candidates: %d, BM25 candidates: %d]",
|
| 423 |
+
query[:40], len(results),
|
| 424 |
+
results[0][2] if results else 0.0,
|
| 425 |
+
len(faiss_ranks), len(bm25_ranks),
|
| 426 |
+
)
|
| 427 |
+
return results
|
| 428 |
+
|
| 429 |
+
|
| 430 |
+
|
| 431 |
+
# ---------------------------------------------------------------------------
|
| 432 |
+
# CLI smoke test
|
| 433 |
+
# ---------------------------------------------------------------------------
|
| 434 |
+
|
| 435 |
+
def _load_config() -> dict:
|
| 436 |
+
with open("config.yaml", "r", encoding="utf-8") as f:
|
| 437 |
+
return yaml.safe_load(f)
|
| 438 |
+
|
| 439 |
+
|
| 440 |
+
if __name__ == "__main__":
|
| 441 |
+
import src # noqa: F401 — logging
|
| 442 |
+
config = _load_config()
|
| 443 |
+
retriever = Retriever(config)
|
| 444 |
+
|
| 445 |
+
test_queries = [
|
| 446 |
+
"What is the recommended dosage of Metformin for Type 2 Diabetes in elderly patients?",
|
| 447 |
+
"Contraindications of ibuprofen for patients with chronic kidney disease",
|
| 448 |
+
"First-line treatment for hypertension according to clinical guidelines",
|
| 449 |
+
]
|
| 450 |
+
|
| 451 |
+
for query in test_queries:
|
| 452 |
+
print(f"\n{'='*70}")
|
| 453 |
+
print(f"QUERY: {query}")
|
| 454 |
+
print("=" * 70)
|
| 455 |
+
results = retriever.search(query, top_k=3)
|
| 456 |
+
if not results:
|
| 457 |
+
print(" No results — is the FAISS index built?")
|
| 458 |
+
continue
|
| 459 |
+
for rank, (text, meta, score) in enumerate(results, 1):
|
| 460 |
+
print(f"\n Rank {rank} | score={score:.4f} | source={meta.get('source')} | "
|
| 461 |
+
f"tier_type={meta.get('pub_type')}")
|
| 462 |
+
print(f" Title: {meta.get('title', '')[:80]}")
|
| 463 |
+
print(f" Text : {text[:200]}...")
|
tests/test_api.py
ADDED
|
@@ -0,0 +1,51 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import pytest
|
| 2 |
+
from fastapi.testclient import TestClient
|
| 3 |
+
from src.api.main import app
|
| 4 |
+
|
| 5 |
+
client = TestClient(app)
|
| 6 |
+
|
| 7 |
+
def test_health_endpoint():
|
| 8 |
+
"""Test that the /health endpoint correctly reports system status."""
|
| 9 |
+
response = client.get("/health")
|
| 10 |
+
assert response.status_code == 200
|
| 11 |
+
data = response.json()
|
| 12 |
+
assert data["status"] == "ok"
|
| 13 |
+
assert "ollama_available" in data
|
| 14 |
+
|
| 15 |
+
def test_evaluate_endpoint():
|
| 16 |
+
"""Test the /evaluate endpoint with mock claims."""
|
| 17 |
+
payload = {
|
| 18 |
+
"question": "Is Metformin safe?",
|
| 19 |
+
"answer": "Metformin is a safe and effective drug. It is recommended.",
|
| 20 |
+
"context_chunks": [
|
| 21 |
+
{
|
| 22 |
+
"chunk_id": "mock-1",
|
| 23 |
+
"text": "Metformin is a first-line medication for the treatment of type 2 diabetes. It is safe.",
|
| 24 |
+
"source": "mock_db",
|
| 25 |
+
"pub_type": "research_abstract",
|
| 26 |
+
"pub_year": 2024,
|
| 27 |
+
"title": "Study on Metformin safety"
|
| 28 |
+
}
|
| 29 |
+
],
|
| 30 |
+
"run_ragas": False
|
| 31 |
+
}
|
| 32 |
+
|
| 33 |
+
# Since the evaluation modules load heavy ML models,
|
| 34 |
+
# the first test call might take 10-15s to run.
|
| 35 |
+
response = client.post("/evaluate", json=payload)
|
| 36 |
+
assert response.status_code == 200
|
| 37 |
+
|
| 38 |
+
data = response.json()
|
| 39 |
+
assert "composite_score" in data
|
| 40 |
+
assert "hrs" in data
|
| 41 |
+
assert data["risk_band"] in ["LOW", "MODERATE", "HIGH", "CRITICAL"]
|
| 42 |
+
assert "faithfulness" in data["module_results"]
|
| 43 |
+
|
| 44 |
+
def test_query_invalid_params():
|
| 45 |
+
"""Test the /query validation rules."""
|
| 46 |
+
payload = {
|
| 47 |
+
"question": "Hi", # 2 chars — below min_length=5, triggers 422
|
| 48 |
+
"top_k": 5
|
| 49 |
+
}
|
| 50 |
+
response = client.post("/query", json=payload)
|
| 51 |
+
assert response.status_code == 422 # Unprocessable Entity (Pydantic validation error)
|
tests/test_modules.py
ADDED
|
@@ -0,0 +1,66 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import pytest
|
| 2 |
+
from src.modules.faithfulness import score_faithfulness
|
| 3 |
+
from src.modules.source_credibility import score_source_credibility
|
| 4 |
+
from src.modules.contradiction import score_contradiction
|
| 5 |
+
from src.evaluation.aggregator import aggregate
|
| 6 |
+
|
| 7 |
+
def test_source_credibility():
|
| 8 |
+
chunks = [
|
| 9 |
+
{"chunk_id": "c1", "pub_type": "research_abstract", "title": "Mock Paper"},
|
| 10 |
+
{"chunk_id": "c2", "pub_type": "exam_question", "title": "Mock Exam Q"}
|
| 11 |
+
]
|
| 12 |
+
results = score_source_credibility(chunks)
|
| 13 |
+
assert results.score > 0.0
|
| 14 |
+
assert 0.3 <= results.score <= 0.5
|
| 15 |
+
assert results.details["chunk_count"] == 2
|
| 16 |
+
|
| 17 |
+
def test_faithfulness_nli():
|
| 18 |
+
res_entail = score_faithfulness(
|
| 19 |
+
answer="The sky is blue.",
|
| 20 |
+
context_docs=["The sky is colored blue today."]
|
| 21 |
+
)
|
| 22 |
+
assert res_entail.score >= 0.8
|
| 23 |
+
|
| 24 |
+
res_contra = score_faithfulness(
|
| 25 |
+
answer="The sky is red.",
|
| 26 |
+
context_docs=["The sky is completely blue and not red."]
|
| 27 |
+
)
|
| 28 |
+
assert res_contra.score <= 0.2
|
| 29 |
+
|
| 30 |
+
def test_aggregator_logic():
|
| 31 |
+
# Mock config
|
| 32 |
+
test_cfg = {
|
| 33 |
+
"evaluation": {
|
| 34 |
+
"weights": {
|
| 35 |
+
"faithfulness": 0.4,
|
| 36 |
+
"entity_accuracy": 0.2,
|
| 37 |
+
"source_credibility": 0.2,
|
| 38 |
+
"contradiction_risk": 0.2,
|
| 39 |
+
"ragas_composite": 0.0
|
| 40 |
+
}
|
| 41 |
+
}
|
| 42 |
+
}
|
| 43 |
+
|
| 44 |
+
module_results = {
|
| 45 |
+
"faithfulness": {"score": 1.0},
|
| 46 |
+
"entity_verifier": {"score": 1.0},
|
| 47 |
+
"source_credibility": {"score": 0.5},
|
| 48 |
+
"contradiction": {"score": 1.0},
|
| 49 |
+
}
|
| 50 |
+
|
| 51 |
+
class MockResult:
|
| 52 |
+
def __init__(self, score, error=None):
|
| 53 |
+
self.score = score
|
| 54 |
+
self.error = error
|
| 55 |
+
self.latency_ms = 10
|
| 56 |
+
|
| 57 |
+
res = aggregate(
|
| 58 |
+
faithfulness_result=MockResult(1.0),
|
| 59 |
+
entity_result=MockResult(1.0),
|
| 60 |
+
source_result=MockResult(0.5),
|
| 61 |
+
contradiction_result=MockResult(1.0),
|
| 62 |
+
weights=test_cfg["evaluation"]["weights"]
|
| 63 |
+
)
|
| 64 |
+
assert abs(res.score - 0.9) < 0.01
|
| 65 |
+
assert res.details["hrs"] == 10
|
| 66 |
+
assert res.details["risk_band"] == "LOW"
|