Commit Β·
345576d
0
Parent(s):
Deploy MedRAG to Hugging Face Space v4
Browse files- .dockerignore +10 -0
- .gitignore +18 -0
- .streamlit/config.toml +5 -0
- Dockerfile +30 -0
- MedRAG.ipynb +0 -0
- README.md +116 -0
- app.py +316 -0
- data_downloader.py +259 -0
- download_assets.py +121 -0
- gallery_builder.py +406 -0
- render.yaml +20 -0
- requirements-space.txt +11 -0
- requirements.txt +31 -0
- rewrite_metadata.py +43 -0
- start.sh +9 -0
- test_visual_search.py +308 -0
- visual_search.py +358 -0
.dockerignore
ADDED
|
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
.git
|
| 2 |
+
__pycache__/
|
| 3 |
+
*.pyc
|
| 4 |
+
*.pyo
|
| 5 |
+
*.pyd
|
| 6 |
+
.DS_Store
|
| 7 |
+
data/
|
| 8 |
+
index/
|
| 9 |
+
*.zip
|
| 10 |
+
render.yaml
|
.gitignore
ADDED
|
@@ -0,0 +1,18 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
.DS_Store
|
| 2 |
+
__pycache__/
|
| 3 |
+
*.pyc
|
| 4 |
+
*.pyo
|
| 5 |
+
*.pyd
|
| 6 |
+
.Python
|
| 7 |
+
*.ipynb_checkpoints
|
| 8 |
+
|
| 9 |
+
# Data and indexes
|
| 10 |
+
data/
|
| 11 |
+
index/
|
| 12 |
+
*.zip
|
| 13 |
+
embeddings_heatmap.png
|
| 14 |
+
embeddings_pca.png
|
| 15 |
+
embeddings_raw.png
|
| 16 |
+
|
| 17 |
+
# Large datasets
|
| 18 |
+
chexpert_full/
|
.streamlit/config.toml
ADDED
|
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
[server]
|
| 2 |
+
enableCORS = false
|
| 3 |
+
enableXsrfProtection = false
|
| 4 |
+
maxUploadSize = 200
|
| 5 |
+
headless = true
|
Dockerfile
ADDED
|
@@ -0,0 +1,30 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
FROM python:3.10-slim
|
| 2 |
+
|
| 3 |
+
ENV PYTHONDONTWRITEBYTECODE=1
|
| 4 |
+
ENV PYTHONUNBUFFERED=1
|
| 5 |
+
ENV PIP_NO_CACHE_DIR=1
|
| 6 |
+
ENV DATA_DIR=/tmp/medrag_data
|
| 7 |
+
ENV HF_HOME=/tmp/hf_cache
|
| 8 |
+
ENV PREFETCH_MODEL=1
|
| 9 |
+
|
| 10 |
+
WORKDIR /app
|
| 11 |
+
|
| 12 |
+
RUN apt-get update && apt-get install -y --no-install-recommends \
|
| 13 |
+
git \
|
| 14 |
+
libglib2.0-0 \
|
| 15 |
+
libsm6 \
|
| 16 |
+
libxext6 \
|
| 17 |
+
libxrender1 \
|
| 18 |
+
&& rm -rf /var/lib/apt/lists/*
|
| 19 |
+
|
| 20 |
+
COPY requirements-space.txt ./
|
| 21 |
+
RUN pip install --index-url https://download.pytorch.org/whl/cpu torch torchvision \
|
| 22 |
+
&& pip install -r requirements-space.txt
|
| 23 |
+
|
| 24 |
+
COPY . .
|
| 25 |
+
|
| 26 |
+
RUN chmod +x /app/start.sh
|
| 27 |
+
|
| 28 |
+
EXPOSE 7860
|
| 29 |
+
|
| 30 |
+
CMD ["/app/start.sh"]
|
MedRAG.ipynb
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
README.md
ADDED
|
@@ -0,0 +1,116 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
---
|
| 2 |
+
title: MedRAG Diagnostic Assistant
|
| 3 |
+
emoji: π©Ί
|
| 4 |
+
colorFrom: blue
|
| 5 |
+
colorTo: gray
|
| 6 |
+
sdk: docker
|
| 7 |
+
app_port: 7860
|
| 8 |
+
pinned: false
|
| 9 |
+
---
|
| 10 |
+
|
| 11 |
+
# MedRAG
|
| 12 |
+
|
| 13 |
+
MedRAG is a multimodal chest X-ray retrieval and diagnostic-assistance app built on:
|
| 14 |
+
- BiomedCLIP for image embeddings and zero-shot disease scoring
|
| 15 |
+
- FAISS for similar-case retrieval
|
| 16 |
+
- a crosscheck layer that combines classifier output with retrieved case evidence
|
| 17 |
+
- Streamlit for the application UI
|
| 18 |
+
|
| 19 |
+
The current app supports:
|
| 20 |
+
- chest X-ray upload
|
| 21 |
+
- sample-image testing
|
| 22 |
+
- similar-case retrieval from the indexed gallery
|
| 23 |
+
- zero-shot disease probability ranking
|
| 24 |
+
- retrieval-supported clinical assessment text
|
| 25 |
+
- Hugging Face Spaces deployment through Docker
|
| 26 |
+
|
| 27 |
+
## Current App Flow
|
| 28 |
+
|
| 29 |
+
1. The user uploads a chest X-ray or selects a sample image.
|
| 30 |
+
2. The app encodes the image with BiomedCLIP.
|
| 31 |
+
3. FAISS retrieves the most visually similar historical cases.
|
| 32 |
+
4. BiomedCLIP scores 14 CheXpert disease prompts.
|
| 33 |
+
5. A crosscheck step combines retrieval agreement with classifier confidence.
|
| 34 |
+
6. The app renders:
|
| 35 |
+
- generated clinical assessment
|
| 36 |
+
- ranked diagnoses
|
| 37 |
+
- top disease probabilities
|
| 38 |
+
- similar historical cases
|
| 39 |
+
|
| 40 |
+
## Project Files
|
| 41 |
+
|
| 42 |
+
Core app:
|
| 43 |
+
- `app.py` - Streamlit UI and diagnosis pipeline
|
| 44 |
+
- `visual_search.py` - FAISS-backed visual search engine
|
| 45 |
+
- `download_assets.py` - downloads demo index/images and prefetches BiomedCLIP
|
| 46 |
+
|
| 47 |
+
Index/data tooling:
|
| 48 |
+
- `gallery_builder.py` - build FAISS index from chest X-ray images
|
| 49 |
+
- `data_downloader.py` - download source datasets
|
| 50 |
+
- `rewrite_metadata.py` - rewrite metadata filepaths for deployment
|
| 51 |
+
|
| 52 |
+
Research/demo:
|
| 53 |
+
- `MedRAG.ipynb` - notebook containing the retrieval, zero-shot classification, and crosscheck logic that the app was ported from
|
| 54 |
+
|
| 55 |
+
Deployment:
|
| 56 |
+
- `Dockerfile` - Hugging Face Spaces container build
|
| 57 |
+
- `start.sh` - startup entrypoint for Spaces
|
| 58 |
+
- `requirements-space.txt` - CPU-friendly dependencies for Spaces
|
| 59 |
+
- `render.yaml` - older Render deployment config
|
| 60 |
+
|
| 61 |
+
## Hugging Face Spaces
|
| 62 |
+
|
| 63 |
+
This repo is configured for a Docker Space.
|
| 64 |
+
|
| 65 |
+
### Deploy steps
|
| 66 |
+
|
| 67 |
+
1. Create a new Hugging Face Space.
|
| 68 |
+
2. Choose `Docker`.
|
| 69 |
+
3. Push this repo to the Space remote.
|
| 70 |
+
4. Let the Space build and start.
|
| 71 |
+
|
| 72 |
+
The Space startup does the following:
|
| 73 |
+
- installs CPU-only PyTorch
|
| 74 |
+
- downloads the public `index.zip` and `images.zip`
|
| 75 |
+
- prefetches the BiomedCLIP model
|
| 76 |
+
- starts Streamlit on port `7860`
|
| 77 |
+
|
| 78 |
+
## Local Run
|
| 79 |
+
|
| 80 |
+
Install dependencies:
|
| 81 |
+
|
| 82 |
+
```bash
|
| 83 |
+
pip install --index-url https://download.pytorch.org/whl/cpu torch torchvision
|
| 84 |
+
pip install -r requirements-space.txt
|
| 85 |
+
```
|
| 86 |
+
|
| 87 |
+
Run the app:
|
| 88 |
+
|
| 89 |
+
```bash
|
| 90 |
+
python download_assets.py
|
| 91 |
+
streamlit run app.py
|
| 92 |
+
```
|
| 93 |
+
|
| 94 |
+
## Data Notes
|
| 95 |
+
|
| 96 |
+
The deployed demo uses a reduced subset of CheXpert so it can run on free CPU infrastructure.
|
| 97 |
+
|
| 98 |
+
Assets are pulled from public Google Drive links by default:
|
| 99 |
+
- FAISS index archive
|
| 100 |
+
- subset image archive
|
| 101 |
+
|
| 102 |
+
If needed, override them with:
|
| 103 |
+
- `GDRIVE_INDEX_URL`
|
| 104 |
+
- `GDRIVE_IMAGES_URL`
|
| 105 |
+
|
| 106 |
+
Optional environment variables:
|
| 107 |
+
- `DATA_DIR`
|
| 108 |
+
- `HF_HOME`
|
| 109 |
+
- `PREFETCH_MODEL`
|
| 110 |
+
|
| 111 |
+
## Limitations
|
| 112 |
+
|
| 113 |
+
- The app is a diagnostic aid, not a clinical decision system.
|
| 114 |
+
- Free-tier hosting will have slow cold starts.
|
| 115 |
+
- The generated assessment is rule-based synthesis from model scores and retrieval support, not a physician-grade interpretation.
|
| 116 |
+
- The original project plan referenced a larger multi-agent/LLM flow; the current deployed app implements the retrieval + classifier + crosscheck path from the notebook.
|
app.py
ADDED
|
@@ -0,0 +1,316 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import random
|
| 3 |
+
import shutil
|
| 4 |
+
from collections import Counter
|
| 5 |
+
from pathlib import Path
|
| 6 |
+
|
| 7 |
+
import streamlit as st
|
| 8 |
+
import torch
|
| 9 |
+
from PIL import Image
|
| 10 |
+
|
| 11 |
+
from visual_search import VisualSearchEngine
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
# Page title used by Streamlit (browser tab + main header).
APP_TITLE = "Multimodal Medical RAG Diagnostic Assistant"
# open_clip hub identifier for the BiomedCLIP vision/language checkpoint.
MODEL_ID = "hf-hub:microsoft/BiomedCLIP-PubMedBERT_256-vit_base_patch16_224"
# Zero-shot prompts: one natural-language description per CheXpert label.
# Insertion order matters — it fixes the row order of the encoded text-feature
# matrix and of the disease-name list returned alongside it.
DISEASE_PROMPTS = {
    "No Finding": "Chest X-ray with no abnormality, normal findings",
    "Enlarged Cardiomediastinum": "Chest X-ray showing enlarged cardiomediastinum",
    "Cardiomegaly": "Chest X-ray showing cardiomegaly, enlarged heart",
    "Lung Opacity": "Chest X-ray showing lung opacity",
    "Lung Lesion": "Chest X-ray showing lung lesion or mass",
    "Edema": "Chest X-ray showing pulmonary edema, fluid in lungs",
    "Consolidation": "Chest X-ray showing consolidation in lung",
    "Pneumonia": "Chest X-ray showing pneumonia, lung infection",
    "Atelectasis": "Chest X-ray showing atelectasis, collapsed lung",
    "Pneumothorax": "Chest X-ray showing pneumothorax, air in pleural space",
    "Pleural Effusion": "Chest X-ray showing pleural effusion, fluid around lung",
    "Pleural Other": "Chest X-ray showing pleural abnormality",
    "Fracture": "Chest X-ray showing rib fracture or bone fracture",
    "Support Devices": "Chest X-ray showing support devices, tubes or lines",
}
# Guardrail prompts: an upload is scored against every description and is
# accepted only when "Chest X-ray" wins (see _validate_input_image).
INPUT_GUARDRAIL_PROMPTS = {
    "Chest X-ray": "A diagnostic chest X-ray radiograph showing the thorax and lungs",
    "Portrait Photo": "A portrait photograph of a person or celebrity",
    "Animal Photo": "A natural photograph of an animal or pet",
    "Document Screenshot": "A screenshot of a document, website, or computer interface",
    "Natural Image": "A normal everyday color photograph of a scene or object",
}
# Alternative phrasings per disease, used for substring matching against
# retrieved-case label strings (see _labels_match).
SYNONYMS = {
    "Pleural Effusion": ["pleural fluid", "fluid around lung", "effusion"],
    "Cardiomegaly": ["enlarged heart", "cardiac enlargement"],
    "Pneumonia": ["lung infection", "consolidation"],
    "Edema": ["fluid in lungs", "pulmonary edema"],
    "Atelectasis": ["collapsed lung", "lung collapse"],
    "Lung Opacity": ["opacity", "haziness", "infiltrate"],
    "No Finding": ["normal", "no abnormality", "clear"],
}
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
def _get_paths() -> tuple[Path, Path]:
|
| 51 |
+
repo_index = Path("index").resolve()
|
| 52 |
+
data_dir = Path(os.getenv("DATA_DIR", "/tmp/medrag_data")).resolve()
|
| 53 |
+
index_dir = Path(os.getenv("INDEX_DIR", data_dir / "index")).resolve()
|
| 54 |
+
return repo_index, index_dir
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
def _ensure_index_available() -> Path:
    """Return a usable FAISS index directory, materialising it if needed.

    Prefers an already-populated runtime index dir; otherwise copies the
    index bundled with the repo into place.

    Raises:
        FileNotFoundError: if neither location contains an index.
    """
    bundled, runtime = _get_paths()
    if runtime.exists():
        return runtime
    if not bundled.exists():
        raise FileNotFoundError("FAISS index not found. Expected at DATA_DIR/index or ./index")
    runtime.parent.mkdir(parents=True, exist_ok=True)
    shutil.copytree(bundled, runtime)
    return runtime
|
| 66 |
+
|
| 67 |
+
|
| 68 |
+
@st.cache_resource(show_spinner=True)
def _load_engine() -> VisualSearchEngine:
    """Construct the visual search engine once per process (Streamlit resource cache)."""
    return VisualSearchEngine(
        index_dir=_ensure_index_available(),
        device="auto",
        top_k=5,
    )
|
| 72 |
+
|
| 73 |
+
|
| 74 |
+
@st.cache_resource(show_spinner=False)
def _load_text_features() -> tuple[list[str], torch.Tensor]:
    """Encode the disease prompts once and cache the normalised text features.

    Returns:
        ``(disease_names, text_features)`` where each row of ``text_features``
        is an L2-normalised BiomedCLIP text embedding aligned with the
        corresponding entry of ``disease_names``.
    """
    # Plain import instead of the __import__ builtin hack: same effect,
    # idiomatic, and visible to tooling. Kept function-local so the heavy
    # open_clip package is only loaded when features are first requested.
    import open_clip

    engine = _load_engine()
    tokenizer = open_clip.get_tokenizer(MODEL_ID)
    with torch.no_grad():
        tokens = tokenizer(list(DISEASE_PROMPTS.values())).to(engine.device)
        text_features = engine._model.encode_text(tokens)
        text_features = text_features / text_features.norm(dim=-1, keepdim=True)
    return list(DISEASE_PROMPTS.keys()), text_features
|
| 83 |
+
|
| 84 |
+
|
| 85 |
+
@st.cache_resource(show_spinner=False)
def _load_guardrail_features() -> tuple[list[str], torch.Tensor]:
    """Encode the input-guardrail prompts once and cache the text features.

    Returns:
        ``(labels, text_features)`` with rows of ``text_features`` being
        L2-normalised embeddings aligned with ``labels``.
    """
    # Plain import instead of __import__("open_clip") — idiomatic equivalent.
    import open_clip

    engine = _load_engine()
    tokenizer = open_clip.get_tokenizer(MODEL_ID)
    with torch.no_grad():
        tokens = tokenizer(list(INPUT_GUARDRAIL_PROMPTS.values())).to(engine.device)
        text_features = engine._model.encode_text(tokens)
        text_features = text_features / text_features.norm(dim=-1, keepdim=True)
    return list(INPUT_GUARDRAIL_PROMPTS.keys()), text_features
|
| 94 |
+
|
| 95 |
+
|
| 96 |
+
def _pick_sample_image(data_dir: Path) -> Path | None:
|
| 97 |
+
images_dir = data_dir / "images"
|
| 98 |
+
if not images_dir.exists():
|
| 99 |
+
return None
|
| 100 |
+
candidates = list(images_dir.glob("*.jpg")) + list(images_dir.glob("*.png")) + list(images_dir.glob("*.jpeg"))
|
| 101 |
+
if not candidates:
|
| 102 |
+
return None
|
| 103 |
+
return random.choice(candidates)
|
| 104 |
+
|
| 105 |
+
|
| 106 |
+
@torch.no_grad()
def _predict_diseases(image: Image.Image) -> dict[str, float]:
    """Zero-shot score the CheXpert disease prompts for one image.

    Returns:
        Mapping of disease name -> probability (percent, 2 d.p.),
        sorted descending by probability.
    """
    engine = _load_engine()
    names, text_feats = _load_text_features()
    pixel_batch = engine._transform(image.convert("RGB")).unsqueeze(0).to(engine.device)
    img_feats = engine._model.encode_image(pixel_batch)
    img_feats = img_feats / img_feats.norm(dim=-1, keepdim=True)
    logits = (img_feats @ text_feats.T).squeeze(0)
    # x100 sharpens cosine similarities before softmax (CLIP-style logit
    # scaling — presumably matching the model's learned scale; TODO confirm).
    probs = torch.softmax(logits * 100, dim=0).detach().cpu().tolist()
    scored = {name: round(float(p) * 100, 2) for name, p in zip(names, probs)}
    return dict(sorted(scored.items(), key=lambda kv: kv[1], reverse=True))
|
| 120 |
+
|
| 121 |
+
|
| 122 |
+
@torch.no_grad()
def _validate_input_image(image: Image.Image) -> tuple[bool, dict[str, float]]:
    """Guardrail: decide whether the upload looks like a chest X-ray.

    Scores the image against all guardrail prompts; the upload is accepted
    when the "Chest X-ray" score is at least 55% and strictly beats every
    other category.

    Returns:
        ``(is_valid, scores)`` — ``scores`` maps label -> percent, sorted
        descending.
    """
    engine = _load_engine()
    labels, text_feats = _load_guardrail_features()
    pixel_batch = engine._transform(image.convert("RGB")).unsqueeze(0).to(engine.device)
    img_feats = engine._model.encode_image(pixel_batch)
    img_feats = img_feats / img_feats.norm(dim=-1, keepdim=True)
    logits = (img_feats @ text_feats.T).squeeze(0)
    probs = torch.softmax(logits * 100, dim=0).detach().cpu().tolist()
    scores = {name: round(float(p) * 100, 2) for name, p in zip(labels, probs)}
    chest_score = scores["Chest X-ray"]
    next_best = max(score for label, score in scores.items() if label != "Chest X-ray")
    is_valid = chest_score >= 55 and chest_score > next_best
    ranked = dict(sorted(scores.items(), key=lambda kv: kv[1], reverse=True))
    return is_valid, ranked
|
| 136 |
+
|
| 137 |
+
|
| 138 |
+
def _labels_match(disease: str, label_str: str) -> bool:
    """True if ``label_str`` mentions ``disease`` directly or via a known synonym."""
    haystack = label_str.lower()
    needles = [disease] + SYNONYMS.get(disease, [])
    return any(term.lower() in haystack for term in needles)
|
| 143 |
+
|
| 144 |
+
|
| 145 |
+
def _crosscheck(similar_cases, disease_probs: dict[str, float]) -> list[dict]:
    """Fuse classifier probabilities with retrieval agreement.

    For the top-5 classifier diseases, counts how many retrieved cases carry
    a matching label, blends the two signals 50/50 into a confidence score,
    and buckets each disease into HIGH / MEDIUM / LOW support.

    Returns:
        Per-disease dicts sorted descending by blended confidence.
    """
    case_count = max(len(similar_cases), 1)  # avoid division by zero
    ranked = []

    for disease in list(disease_probs)[:5]:
        prob = disease_probs[disease]
        hits = sum(1 for case in similar_cases if _labels_match(disease, case.labels))
        support = hits / case_count
        blended = (prob / 100 * 0.5) + (support * 0.5)
        if support >= 0.6 and prob >= 20:
            status = "HIGH"
        elif support >= 0.3 or prob >= 15:
            status = "MEDIUM"
        else:
            status = "LOW"
        ranked.append({
            "disease": disease,
            "llm_probability": prob,
            "matching_cases": hits,
            "total_cases": case_count,
            "gallery_support": f"{hits}/{case_count} cases",
            "confidence": round(blended * 100, 1),
            "status": status,
        })

    ranked.sort(key=lambda row: row["confidence"], reverse=True)
    return ranked
|
| 171 |
+
|
| 172 |
+
|
| 173 |
+
def _positive_labels(label_str: str) -> list[str]:
|
| 174 |
+
positives = []
|
| 175 |
+
for part in label_str.split(" | "):
|
| 176 |
+
if ": Positive" in part:
|
| 177 |
+
positives.append(part.split(":")[0])
|
| 178 |
+
return positives
|
| 179 |
+
|
| 180 |
+
|
| 181 |
+
def _generate_assessment(diagnosis: list[dict], similar_cases) -> str:
    """Render a Markdown clinical-assessment report from crosscheck output.

    Args:
        diagnosis: ``_crosscheck`` output, already sorted by confidence;
            must be non-empty (``diagnosis[0]`` is the primary impression).
        similar_cases: retrieved cases; each is assumed to expose a
            ``labels`` string (see ``_positive_labels``).

    Returns:
        A Markdown string — rule-based text synthesis, not an LLM call.
    """
    primary = diagnosis[0]
    # Tally positive findings across all retrieved cases to surface the
    # most frequently repeated ones as supporting evidence.
    top_positive_labels = Counter()
    for case in similar_cases:
        top_positive_labels.update(_positive_labels(case.labels))

    supporting_findings = ", ".join(label for label, _ in top_positive_labels.most_common(3)) or "no repeated positive findings"
    # Diseases ranked 2-4 become the differential list.
    differential = ", ".join(item["disease"] for item in diagnosis[1:4])

    return f"""
## Primary Clinical Impression

Based on visual similarity retrieval and zero-shot disease classification, the leading impression is **{primary["disease"]}** with a combined confidence of **{primary["confidence"]}%**.

## Evidence Summary

- The classifier estimated **{primary["llm_probability"]}%** probability for {primary["disease"]}.
- The retrieval engine found **{primary["gallery_support"]}** similar cases supporting this diagnosis.
- The most repeated positive findings among retrieved cases were: **{supporting_findings}**.

## Differential Diagnosis

Alternative conditions to consider are **{differential}**. These remain relevant because visually similar cases include overlapping thoracic findings common across chest X-ray pathology.

## Clinical Note

This is a retrieval-supported decision aid, not a definitive medical diagnosis. Final interpretation should be confirmed by a radiologist or clinician.
""".strip()
|
| 209 |
+
|
| 210 |
+
|
| 211 |
+
def _run_analysis(image: Image.Image, top_k: int):
    """Run the full pipeline: retrieval -> zero-shot scoring -> crosscheck -> report.

    Returns:
        ``(similar_cases, disease_probs, diagnosis, assessment)``.
    """
    engine = _load_engine()
    cases = engine.search(image, top_k=top_k, load_images=False)
    probs = _predict_diseases(image)
    ranked = _crosscheck(cases, probs)
    report = _generate_assessment(ranked, cases)
    return cases, probs, ranked, report
|
| 218 |
+
|
| 219 |
+
|
| 220 |
+
def _render_similar_cases(similar_cases):
    """Render each retrieved case as a thumbnail column plus a metadata column."""
    st.markdown("### Similar Historical Cases")
    for rank, case in enumerate(similar_cases, start=1):
        thumb_col, info_col = st.columns([1, 3])
        with thumb_col:
            if case.filepath and Path(case.filepath).exists():
                try:
                    preview = Image.open(case.filepath).convert("RGB")
                    st.image(preview, use_container_width=True)
                except Exception:
                    # Unreadable/corrupt file on disk — keep the row, drop the thumbnail.
                    st.caption("Preview unavailable")
        with info_col:
            st.markdown(f"**#{rank} {case.filename}**")
            st.write(f"Similarity: {case.similarity:.3f}")
            findings = _positive_labels(case.labels)
            st.write(f"Confirmed findings: {', '.join(findings) if findings else 'None'}")
|
| 235 |
+
|
| 236 |
+
|
| 237 |
+
def main():
    """Streamlit entry point: sidebar controls, image input, and analysis output."""
    st.set_page_config(page_title=APP_TITLE, layout="wide")
    st.title(APP_TITLE)
    st.caption(
        "Upload a chest X-ray. The system retrieves similar historical cases and generates a retrieval-supported differential diagnosis."
    )

    with st.sidebar:
        st.markdown("**Index Status**")
        try:
            index_dir = _ensure_index_available()
            st.write(f"Index dir: `{index_dir}`")
            # Sample images are expected as a sibling of the index dir.
            data_dir = index_dir.parent
        except FileNotFoundError as exc:
            # No index anywhere — the app cannot do anything useful; bail out.
            st.error(str(exc))
            return

        top_k = st.slider("Retrieved Cases", min_value=3, max_value=20, value=5, step=1)
        if st.button("Use Sample Image"):
            # Store "" when no sample is available so the falsy check below skips it.
            st.session_state["sample_path"] = str(_pick_sample_image(data_dir) or "")
        if st.button("Clear"):
            st.session_state.pop("sample_path", None)
            st.session_state.pop("analysis_ready", None)
            st.rerun()
        # NOTE(review): caption mentions the Render free tier, but the commit
        # deploys to a Hugging Face Space — likely stale copy; confirm.
        st.caption("First analysis can still be slow on Render free tier.")

    uploaded = st.file_uploader("Upload Patient Chest X-Ray", type=["png", "jpg", "jpeg"])
    sample_path = st.session_state.get("sample_path")

    # An upload takes precedence over a previously selected sample image.
    query_image = None
    if uploaded is not None:
        query_image = Image.open(uploaded).convert("RGB")
        st.session_state["analysis_ready"] = True
    elif sample_path:
        query_image = Image.open(sample_path).convert("RGB")
        st.session_state["analysis_ready"] = True

    left, right = st.columns([1.05, 1.25])

    with left:
        st.markdown("### Input X-Ray")
        if query_image is not None:
            st.image(query_image, use_container_width=True)
        else:
            st.info("Upload an image or use the sample button.")

    with right:
        st.markdown("### Generated Clinical Assessment")
        if query_image is None:
            st.info("Run an analysis to generate the assessment.")
            return

        # "analysis_ready" lets the run survive Streamlit's rerun after widget
        # interaction, so results appear without pressing Submit twice.
        if st.button("Submit", type="primary") or st.session_state.get("analysis_ready"):
            with st.spinner("Running retrieval, classification, and crosscheck..."):
                # Guardrail first: reject non-chest-X-ray inputs before the
                # expensive retrieval/classification work.
                is_valid_xray, input_scores = _validate_input_image(query_image)
                if not is_valid_xray:
                    st.error("This tool only supports chest X-ray images. Please upload a chest radiograph.")
                    st.markdown("### Input Validation")
                    for label, score in list(input_scores.items())[:3]:
                        st.write(f"{label}: {score}%")
                    st.session_state["analysis_ready"] = False
                    return
                similar_cases, disease_probs, diagnosis, assessment = _run_analysis(query_image, top_k)

            st.markdown(assessment)
            st.markdown("### Ranked Diagnoses")
            for item in diagnosis:
                st.write(
                    f"**{item['disease']}** | classifier {item['llm_probability']}% | "
                    f"gallery {item['gallery_support']} | confidence {item['confidence']}% [{item['status']}]"
                )
            st.markdown("### Top Disease Probabilities")
            for disease, prob in list(disease_probs.items())[:5]:
                st.write(f"{disease}: {prob}%")
            _render_similar_cases(similar_cases)
            # Consume the flag so the analysis doesn't re-run on the next rerun.
            st.session_state["analysis_ready"] = False
|
data_downloader.py
ADDED
|
@@ -0,0 +1,259 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
data_downloader.py
|
| 3 |
+
ββββββββββββββββββ
|
| 4 |
+
Downloads the NIH ChestX-ray14 dataset sample (5,606 images, ~1.2 GB).
|
| 5 |
+
This is the public domain dataset used to build the visual_db.index.
|
| 6 |
+
|
| 7 |
+
The NIH dataset contains 14 disease labels per image in the CSV metadata:
|
| 8 |
+
Atelectasis, Cardiomegaly, Effusion, Infiltration, Mass, Nodule,
|
| 9 |
+
Pneumonia, Pneumothorax, Consolidation, Edema, Emphysema, Fibrosis,
|
| 10 |
+
Pleural_Thickening, Hernia (plus "No Finding")
|
| 11 |
+
|
| 12 |
+
Usage:
|
| 13 |
+
python data_downloader.py --output_dir ./data
|
| 14 |
+
"""
|
| 15 |
+
|
| 16 |
+
import os
|
| 17 |
+
import sys
|
| 18 |
+
import time
|
| 19 |
+
import zipfile
|
| 20 |
+
import argparse
|
| 21 |
+
import requests
|
| 22 |
+
import pandas as pd
|
| 23 |
+
from pathlib import Path
|
| 24 |
+
from tqdm import tqdm
|
| 25 |
+
|
| 26 |
+
# ── NIH ChestX-ray14 public download URLs ─────────────────────────────────────
# Source: https://nihcc.app.box.com/v/ChestXray-NIHCC
# The NIH provides 12 batch ZIPs + 1 metadata CSV.
# We use only the FIRST batch (images_001.tar.gz, ~1.1 GB, 4,999 images)
# for a fast bootstrap. Add more batches for larger gallery.

# NOTE(review): this points at the covid-chestxray-dataset metadata, not NIH
# ChestX-ray14 — an acknowledged placeholder; appears unused in this view.
NIH_METADATA_URL = (
    "https://raw.githubusercontent.com/ieee8023/covid-chestxray-dataset/"
    "master/metadata.csv"  # placeholder - real URL below
)

# Real NIH metadata (hosted on Kaggle mirror for convenience)
NIH_KAGGLE_METADATA = "https://raw.githubusercontent.com/mlmed/torchxrayvision/master/torchxrayvision/data_dicts/nih_chest_xray_dict.json"

# ── Open-I (Indiana University) — ALWAYS freely available, no login ───────────
# 7,470 frontal X-rays ~900 MB
OPENI_BASE = "https://openi.nlm.nih.gov/imgs/collections/"
OPENI_ARCHIVE = "NLMCXR_png.tgz"  # full archive
OPENI_METADATA_URL = "https://openi.nlm.nih.gov/api/search?q=&it=x&m=1&n=500"

# ── Lightweight fallback: Kaggle chest-xray-pneumonia (1.15 GB) ───────────────
# https://www.kaggle.com/datasets/paultimothymooney/chest-xray-pneumonia
# Requires kaggle CLI auth token.

# Dataset identifiers accepted by this script's CLI (presumably consumed by an
# argparse choice outside this view — confirm against the main() below).
SUPPORTED_SOURCES = ["openi", "nih_sample", "local"]
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
def download_with_progress(url: str, dest_path: Path, chunk_size: int = 8192) -> bool:
    """Stream-download ``url`` to ``dest_path`` with a tqdm progress bar.

    Args:
        url: HTTP(S) URL to fetch.
        dest_path: destination file; parent directories are created.
        chunk_size: bytes per streamed chunk.

    Returns:
        True on success, False on any network/HTTP/filesystem error.
        Errors are printed rather than raised so callers can fall back.
    """
    try:
        resp = requests.get(url, stream=True, timeout=60)
        resp.raise_for_status()
        # Servers may omit Content-Length; tqdm then shows an open-ended bar.
        total = int(resp.headers.get("content-length", 0))
        dest_path.parent.mkdir(parents=True, exist_ok=True)
        with open(dest_path, "wb") as f, tqdm(
            total=total, unit="B", unit_scale=True,
            desc=dest_path.name, ncols=80
        ) as bar:
            for chunk in resp.iter_content(chunk_size=chunk_size):
                # iter_content can yield empty keep-alive chunks; skip them so
                # writes and the progress bar reflect real payload bytes only.
                if chunk:
                    f.write(chunk)
                    bar.update(len(chunk))
        return True
    except Exception as e:
        print(f"[ERROR] Download failed: {e}")
        return False
|
| 71 |
+
|
| 72 |
+
|
| 73 |
+
def download_openi(output_dir: Path) -> Path:
    """
    Download Open-I Indiana University chest X-ray PNG collection.
    Returns the directory containing .png images.
    """
    import tarfile

    output_dir.mkdir(parents=True, exist_ok=True)
    archive_path = output_dir / OPENI_ARCHIVE
    images_dir = output_dir / "openi_images"

    # Idempotency: skip the ~900 MB download if images are already on disk.
    if images_dir.exists() and any(images_dir.glob("*.png")):
        print(f"[SKIP] Open-I images already present at {images_dir}")
        return images_dir

    print("=" * 60)
    print("Downloading Open-I Indiana X-ray dataset (~900 MB)...")
    print("Source: National Library of Medicine (public domain)")
    print("=" * 60)

    url = OPENI_BASE + OPENI_ARCHIVE
    if not download_with_progress(url, archive_path):
        raise RuntimeError("Failed to download Open-I archive.")

    print(f"Extracting to {images_dir}...")
    images_dir.mkdir(exist_ok=True)
    with tarfile.open(archive_path, "r:gz") as tar:
        # NOTE(review): extractall without a member filter trusts archive
        # paths (tar path-traversal risk). Acceptable for this NLM-hosted
        # archive, but consider filter="data" once on Python 3.12+.
        tar.extractall(path=images_dir)

    archive_path.unlink()  # free disk space
    print(f"[OK] Open-I images extracted → {images_dir}")
    return images_dir
|
| 105 |
+
|
| 106 |
+
|
| 107 |
+
def download_nih_sample(output_dir: Path, max_images: int = 5000) -> Path:
    """Fetch NIH ChestX-ray14 batch_01 (~4,999 images, ~1.1 GB).

    Uses the direct Box.com link published by NIH; extracts at most
    ``max_images`` archive members, then removes the archive.
    """
    import tarfile

    NIH_BATCH1_URL = (
        "https://nihcc.box.com/shared/static/"
        "vfk49d74nhbxq3nqjg0900w5nvkorp5c.gz"
    )

    output_dir.mkdir(parents=True, exist_ok=True)
    archive = output_dir / "nih_images_001.tar.gz"
    images_dir = output_dir / "nih_images"

    # Reuse a prior extraction if any PNGs are already in place.
    if images_dir.exists() and any(images_dir.glob("*.png")):
        print(f"[SKIP] NIH images already present at {images_dir}")
        return images_dir

    separator = "=" * 60
    print(separator)
    print("Downloading NIH ChestX-ray14 Batch 1 (~1.1 GB)...")
    print("Source: NIH Clinical Center (CC0 license)")
    print(separator)

    if not download_with_progress(NIH_BATCH1_URL, archive):
        raise RuntimeError(
            "Failed to download NIH batch. "
            "Try manual download from: https://nihcc.app.box.com/v/ChestXray-NIHCC"
        )

    print(f"Extracting to {images_dir}...")
    images_dir.mkdir(exist_ok=True)
    with tarfile.open(archive, "r:gz") as tar:
        # Cap the extraction at the first max_images members.
        tar.extractall(path=images_dir, members=tar.getmembers()[:max_images])

    archive.unlink()
    print(f"[OK] NIH images extracted → {images_dir}")
    return images_dir
|
| 147 |
+
|
| 148 |
+
|
| 149 |
+
def download_nih_metadata(output_dir: Path) -> Path:
    """Download the NIH ChestX-ray14 labels CSV (Data_Entry_2017, 108,948 rows).

    Reuses an existing local copy when present.

    Returns:
        Path to the downloaded (or cached) CSV.

    Raises:
        RuntimeError: if the download fails. The original ignored the
        success flag of ``download_with_progress`` and returned the path
        anyway, so a failure only surfaced later as a confusing pandas
        read error on a missing/partial file.
    """
    # Full metadata (108,948 rows). The unused small test-CSV URL that
    # previously sat here was dead code and has been removed.
    FULL_META_URL = (
        "https://raw.githubusercontent.com/ieee8023/chexnet-dataset/"
        "master/Data_Entry_2017.csv"
    )
    dest = output_dir / "nih_metadata.csv"
    if dest.exists():
        return dest
    print("Downloading NIH metadata CSV...")
    if not download_with_progress(FULL_META_URL, dest):
        raise RuntimeError("Failed to download NIH metadata CSV.")
    return dest
|
| 166 |
+
|
| 167 |
+
|
| 168 |
+
def scan_local_images(image_dir: Path) -> list[Path]:
    """Recursively collect every PNG/JPG/JPEG file under ``image_dir``."""
    wanted = {".png", ".jpg", ".jpeg"}
    found = [f for f in image_dir.rglob("*") if f.suffix.lower() in wanted]
    print(f"[SCAN] Found {len(found):,} images in {image_dir}")
    return found
|
| 177 |
+
|
| 178 |
+
|
| 179 |
+
def build_metadata_csv(
    image_dir: Path,
    nih_csv_path: Path | None,
    output_path: Path
) -> pd.DataFrame:
    """Write a unified metadata CSV: filename | filepath | labels | source.

    When an NIH labels CSV is available, "Finding Labels" are joined by
    filename; otherwise every image is labelled "Unknown". Works either way.
    """
    label_lookup: dict = {}
    if nih_csv_path and nih_csv_path.exists():
        df_nih = pd.read_csv(nih_csv_path)
        # NIH CSV cols: Image Index, Finding Labels, Patient ID, ...
        label_lookup = dict(zip(df_nih["Image Index"], df_nih["Finding Labels"]))

    # Same source tag for every row, matching the original behavior:
    # "NIH" whenever any labels were loaded, else "Unknown".
    source_tag = "NIH" if label_lookup else "Unknown"
    records = [
        {
            "filename": img.name,
            "filepath": str(img.resolve()),
            "labels": label_lookup.get(img.name, "Unknown"),
            "source": source_tag,
        }
        for img in scan_local_images(image_dir)
    ]

    df = pd.DataFrame(records)
    df.to_csv(output_path, index=False)
    print(f"[OK] Metadata saved → {output_path} ({len(df):,} rows)")
    return df
|
| 214 |
+
|
| 215 |
+
|
| 216 |
+
def main():
    """CLI entry point: download the chosen dataset and write metadata.csv.

    Sources:
        openi       public-domain Open-I collection (no login required)
        nih_sample  NIH ChestX-ray14 batch 1 + labels CSV
        local       an existing local image folder (--local_dir)
    """
    parser = argparse.ArgumentParser(
        description="Download chest X-ray dataset for gallery builder"
    )
    parser.add_argument(
        "--source", choices=SUPPORTED_SOURCES, default="openi",
        help="Dataset source (default: openi — no login required)"
    )
    parser.add_argument(
        "--output_dir", type=Path, default=Path("./data"),
        help="Directory to save images and metadata"
    )
    parser.add_argument(
        "--local_dir", type=Path, default=None,
        help="Path to existing local image folder (use with --source local)"
    )
    args = parser.parse_args()

    output_dir: Path = args.output_dir.resolve()
    output_dir.mkdir(parents=True, exist_ok=True)

    if args.source == "openi":
        images_dir = download_openi(output_dir)
        build_metadata_csv(images_dir, None, output_dir / "metadata.csv")
    elif args.source == "nih_sample":
        images_dir = download_nih_sample(output_dir)
        nih_meta = download_nih_metadata(output_dir)
        build_metadata_csv(images_dir, nih_meta, output_dir / "metadata.csv")
    elif args.source == "local":
        if not args.local_dir:
            print("[ERROR] --local_dir is required when --source=local")
            sys.exit(1)
        images_dir = args.local_dir.resolve()
        build_metadata_csv(images_dir, None, output_dir / "metadata.csv")
    else:
        print(f"[ERROR] Unknown source: {args.source}")
        sys.exit(1)

    # BUGFIX: the nih_sample branch previously `return`ed before this
    # guidance, so only openi/local users saw the next-step hint. Every
    # successful path now prints it.
    print("\n✅ Dataset ready. Next step:")
    print(f"  python gallery_builder.py --image_dir {images_dir} --output_dir ./index")
|
| 256 |
+
|
| 257 |
+
|
| 258 |
+
# Run the downloader CLI when executed as a script.
if __name__ == "__main__":
    main()
|
download_assets.py
ADDED
|
@@ -0,0 +1,121 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
download_assets.py
|
| 3 |
+
------------------
|
| 4 |
+
Downloads index/ and image assets from Google Drive into /var/data.
|
| 5 |
+
|
| 6 |
+
Env vars:
|
| 7 |
+
GDRIVE_INDEX_URL - share link or direct download url for a zip/tar of index/
|
| 8 |
+
GDRIVE_IMAGES_URL - share link or direct download url for a zip/tar of images/
|
| 9 |
+
DATA_DIR - base path (default: /var/data)
|
| 10 |
+
"""
|
| 11 |
+
|
| 12 |
+
import os
|
| 13 |
+
import shutil
|
| 14 |
+
import tarfile
|
| 15 |
+
import zipfile
|
| 16 |
+
from pathlib import Path
|
| 17 |
+
|
| 18 |
+
import gdown
|
| 19 |
+
from huggingface_hub import snapshot_download
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
BIOMEDCLIP_REPO = "microsoft/BiomedCLIP-PubMedBERT_256-vit_base_patch16_224"
|
| 23 |
+
DEFAULT_INDEX_URL = "https://drive.google.com/uc?id=1NwEac0s_qah8L27RO-aFIz2PRfvXC9j0"
|
| 24 |
+
DEFAULT_IMAGES_URL = "https://drive.google.com/uc?id=1LAMNffnw3kFHZXvY9ySR62VxlaChRXyv"
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
def _download(url: str, dest: Path) -> Path:
|
| 28 |
+
dest.parent.mkdir(parents=True, exist_ok=True)
|
| 29 |
+
if dest.exists():
|
| 30 |
+
return dest
|
| 31 |
+
gdown.download(url, str(dest), quiet=False)
|
| 32 |
+
return dest
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
def _extract(archive: Path, target_dir: Path) -> None:
|
| 36 |
+
target_dir.mkdir(parents=True, exist_ok=True)
|
| 37 |
+
if zipfile.is_zipfile(archive):
|
| 38 |
+
with zipfile.ZipFile(archive, "r") as zf:
|
| 39 |
+
zf.extractall(target_dir)
|
| 40 |
+
elif tarfile.is_tarfile(archive) or archive.name.endswith((".tgz", ".tar.gz", ".gz")):
|
| 41 |
+
with tarfile.open(archive, "r:*") as tf:
|
| 42 |
+
tf.extractall(target_dir)
|
| 43 |
+
else:
|
| 44 |
+
raise ValueError(f"Unsupported archive: {archive}")
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
def _ensure_dir(path: Path) -> None:
|
| 48 |
+
path.mkdir(parents=True, exist_ok=True)
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
def _pick_data_dir() -> Path:
|
| 52 |
+
env_dir = os.getenv("DATA_DIR")
|
| 53 |
+
if env_dir:
|
| 54 |
+
return Path(env_dir).resolve()
|
| 55 |
+
for candidate in (Path("/var/data"), Path("/tmp/medrag_data")):
|
| 56 |
+
try:
|
| 57 |
+
candidate.mkdir(parents=True, exist_ok=True)
|
| 58 |
+
return candidate
|
| 59 |
+
except Exception:
|
| 60 |
+
continue
|
| 61 |
+
return Path("/tmp/medrag_data").resolve()
|
| 62 |
+
|
| 63 |
+
|
| 64 |
+
def _prefetch_biomedclip() -> None:
    """Warm the local HF cache with BiomedCLIP so first app start is fast."""
    hf_cache = Path(os.getenv("HF_HOME", "/tmp/hf_cache")).resolve()
    hf_cache.mkdir(parents=True, exist_ok=True)
    snapshot_download(
        repo_id=BIOMEDCLIP_REPO,
        cache_dir=str(hf_cache),
        local_dir_use_symlinks=False,
    )
    print(f"BiomedCLIP cached in {hf_cache}")
|
| 73 |
+
|
| 74 |
+
|
| 75 |
+
def main():
    """Populate DATA_DIR with the index and image assets, then prefetch the model.

    Each asset is skipped if already populated, downloaded + extracted when a
    URL is configured, and reported as missing otherwise.
    """
    data_dir = _pick_data_dir()
    _ensure_dir(data_dir)

    def _sync(label: str, target: Path, url: str, archive_name: str, env_var: str) -> None:
        # Download + extract one asset unless its target is already non-empty.
        if target.exists() and any(target.iterdir()):
            print(f"{label} already present at {target}")
        elif url:
            archive = _download(url, data_dir / archive_name)
            _extract(archive, target)
            print(f"{label} extracted to {target}")
        else:
            print(f"{env_var} not set; {label.lower()} not downloaded.")

    _sync(
        "Index", data_dir / "index",
        os.getenv("GDRIVE_INDEX_URL", DEFAULT_INDEX_URL),
        "index_archive.zip", "GDRIVE_INDEX_URL",
    )
    _sync(
        "Images", data_dir / "images",
        os.getenv("GDRIVE_IMAGES_URL", DEFAULT_IMAGES_URL),
        "images_archive.zip", "GDRIVE_IMAGES_URL",
    )

    # cleanup: remove leftover archives, best-effort only
    for leftover in (data_dir / "index_archive.zip", data_dir / "images_archive.zip"):
        if leftover.exists():
            try:
                if leftover.is_file():
                    leftover.unlink()
                else:
                    shutil.rmtree(leftover)
            except Exception:
                pass

    if os.getenv("PREFETCH_MODEL", "1") == "1":
        _prefetch_biomedclip()
|
| 118 |
+
|
| 119 |
+
|
| 120 |
+
# Run the asset downloader when executed as a script.
if __name__ == "__main__":
    main()
|
gallery_builder.py
ADDED
|
@@ -0,0 +1,406 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
gallery_builder.py
|
| 3 |
+
ββββββββββββββββββ
|
| 4 |
+
Builds the visual search database for Medical X-ray RAG.
|
| 5 |
+
|
| 6 |
+
Pipeline:
|
| 7 |
+
1. Load all X-ray images from --image_dir
|
| 8 |
+
2. Encode each image β 512-dim vector via BiomedCLIP
|
| 9 |
+
3. Normalize + store in FAISS IndexFlatIP (cosine similarity via dot product)
|
| 10 |
+
4. Save: visual_db.index (FAISS binary)
|
| 11 |
+
metadata.json (filename β {path, labels, idx})
|
| 12 |
+
embeddings.npy (raw numpy array, optional backup)
|
| 13 |
+
|
| 14 |
+
BiomedCLIP:
|
| 15 |
+
microsoft/BiomedCLIP-PubMedBERT_256-vit_base_patch16_224
|
| 16 |
+
Trained on 15M biomedical image-caption pairs from PubMed Central.
|
| 17 |
+
Zero-shot performance on CheXpert = 0.85+ AUC (no fine-tuning needed).
|
| 18 |
+
|
| 19 |
+
Usage:
|
| 20 |
+
python gallery_builder.py \
|
| 21 |
+
--image_dir ./data/openi_images \
|
| 22 |
+
--output_dir ./index \
|
| 23 |
+
--batch_size 64 \
|
| 24 |
+
--device cpu
|
| 25 |
+
|
| 26 |
+
# Resume interrupted build:
|
| 27 |
+
python gallery_builder.py --image_dir ./data/openi_images --resume
|
| 28 |
+
|
| 29 |
+
Output files:
|
| 30 |
+
./index/visual_db.index β FAISS binary index
|
| 31 |
+
./index/metadata.json β id β {filename, filepath, labels}
|
| 32 |
+
./index/embeddings.npy β (N, 512) float32 array
|
| 33 |
+
./index/build_stats.json β timing + counts
|
| 34 |
+
"""
|
| 35 |
+
|
| 36 |
+
import os
|
| 37 |
+
import sys
|
| 38 |
+
import json
|
| 39 |
+
import time
|
| 40 |
+
import argparse
|
| 41 |
+
import logging
|
| 42 |
+
import numpy as np
|
| 43 |
+
from pathlib import Path
|
| 44 |
+
from typing import Optional
|
| 45 |
+
|
| 46 |
+
import torch
|
| 47 |
+
from torch.utils.data import Dataset, DataLoader
|
| 48 |
+
from PIL import Image, UnidentifiedImageError
|
| 49 |
+
import faiss
|
| 50 |
+
import open_clip
|
| 51 |
+
from tqdm import tqdm
|
| 52 |
+
|
| 53 |
+
logging.basicConfig(
|
| 54 |
+
level=logging.INFO,
|
| 55 |
+
format="%(asctime)s %(levelname)-7s %(message)s",
|
| 56 |
+
datefmt="%H:%M:%S",
|
| 57 |
+
)
|
| 58 |
+
log = logging.getLogger(__name__)
|
| 59 |
+
|
| 60 |
+
# ββ Constants ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 61 |
+
BIOMEDCLIP_MODEL = "hf-hub:microsoft/BiomedCLIP-PubMedBERT_256-vit_base_patch16_224"
|
| 62 |
+
EMBED_DIM = 512
|
| 63 |
+
SUPPORTED_EXTS = {".png", ".jpg", ".jpeg", ".dcm"}
|
| 64 |
+
INDEX_FILE = "visual_db.index"
|
| 65 |
+
METADATA_FILE = "metadata.json"
|
| 66 |
+
EMBEDDINGS_FILE = "embeddings.npy"
|
| 67 |
+
STATS_FILE = "build_stats.json"
|
| 68 |
+
|
| 69 |
+
|
| 70 |
+
# ββ Dataset ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 71 |
+
class XRayDataset(Dataset):
    """
    Lazy-loading dataset for chest X-ray images.

    Applies the supplied BiomedCLIP preprocessing transform (resize 224,
    normalize) and resolves per-file labels from an optional metadata CSV.
    Corrupt/unreadable files are flagged rather than raised so the
    DataLoader's collate step can drop them.
    """

    def __init__(
        self,
        image_paths: list[Path],
        transform,
        metadata_csv_path: Optional[Path] = None,
    ):
        # transform: callable mapping a PIL RGB image to a tensor.
        self.paths = image_paths
        self.transform = transform
        # filename -> label string; empty when no CSV was provided.
        self.label_map: dict[str, str] = {}

        # Optional: load NIH/CheXpert labels CSV (columns: filename, labels)
        if metadata_csv_path and metadata_csv_path.exists():
            import pandas as pd
            df = pd.read_csv(metadata_csv_path)
            if "filename" in df.columns and "labels" in df.columns:
                self.label_map = dict(zip(df["filename"], df["labels"].fillna("Unknown")))

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, idx: int):
        """Return (tensor, path, label, valid); valid=False flags a corrupt file."""
        path = self.paths[idx]
        try:
            img = Image.open(path).convert("RGB")
            tensor = self.transform(img)
            label = self.label_map.get(path.name, "Unknown")
            return tensor, str(path), label, True
        # FIX: the original caught (UnidentifiedImageError, OSError, Exception);
        # listing Exception makes the narrower entries dead — a single
        # `except Exception` is equivalent and honest about the intent
        # ("skip anything that fails to load").
        except Exception as e:
            log.warning(f"Skipping corrupt image: {path.name} ({e})")
            # Return a zero tensor so DataLoader batch stays uniform
            dummy = torch.zeros(3, 224, 224)
            return dummy, str(path), "CORRUPT", False
|
| 110 |
+
|
| 111 |
+
|
| 112 |
+
def collate_skip_corrupt(batch):
    """Custom collate: drop samples flagged invalid, then batch the rest.

    Returns None when every sample in the batch was corrupt, so callers
    must check before unpacking.
    """
    kept = [item[:3] for item in batch if item[3]]
    if not kept:
        return None
    tensors, paths, labels = zip(*kept)
    return torch.stack(tensors), list(paths), list(labels)
|
| 119 |
+
|
| 120 |
+
|
| 121 |
+
# ββ Model loader βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 122 |
+
def load_biomedclip(device: str):
    """
    Load the BiomedCLIP vision encoder from the HuggingFace hub.

    Returns (model, transform); the model maps images to 512-dim embeddings.
    Failures are logged with an install hint, then re-raised.
    """
    log.info("Loading BiomedCLIP from HuggingFace hub (first run downloads ~350 MB)...")
    try:
        model, _, transform = open_clip.create_model_and_transforms(BIOMEDCLIP_MODEL)
        model = model.to(device).eval()
        log.info(f"BiomedCLIP loaded → device={device}")
        return model, transform
    except Exception as e:
        log.error(f"Failed to load BiomedCLIP: {e}")
        log.error("Ensure open-clip-torch is installed: pip install open-clip-torch")
        raise
|
| 139 |
+
|
| 140 |
+
|
| 141 |
+
# ββ Embedding engine βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 142 |
+
@torch.no_grad()
def encode_batch(model, image_tensors: torch.Tensor, device: str) -> np.ndarray:
    """Encode one batch of images into L2-normalized float32 embeddings (N, 512)."""
    feats = model.encode_image(image_tensors.to(device))
    # Unit-normalize so inner product == cosine similarity downstream.
    feats = feats / feats.norm(dim=-1, keepdim=True)
    return feats.cpu().numpy().astype(np.float32)
|
| 150 |
+
|
| 151 |
+
|
| 152 |
+
# ββ FAISS index builder ββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 153 |
+
def build_faiss_index(embeddings: np.ndarray) -> faiss.IndexFlatIP:
    """
    Build a FAISS inner-product index over L2-normalized embeddings
    (inner product == cosine similarity after normalization).

    Small galleries (< 10K vectors) get an exact IndexFlatIP; larger
    ones get an approximate IVF index for much faster search.
    """
    n, d = embeddings.shape
    log.info(f"Building FAISS index ({n:,} vectors × {d} dims)")

    if n >= 10_000:
        # Approximate search — needed for large galleries.
        nlist = min(256, n // 39)  # IVF rule of thumb: nlist ≈ sqrt(N)
        coarse = faiss.IndexFlatIP(d)
        index = faiss.IndexIVFFlat(coarse, d, nlist, faiss.METRIC_INNER_PRODUCT)
        log.info(f"Training IVF index with nlist={nlist}...")
        index.train(embeddings)
        index.nprobe = 16  # cells probed per query (accuracy vs speed)
    else:
        # Exact search — best quality for small collections.
        index = faiss.IndexFlatIP(d)

    index.add(embeddings)
    log.info(f"FAISS index built → total vectors: {index.ntotal:,}")
    return index
|
| 176 |
+
|
| 177 |
+
|
| 178 |
+
# ββ Resume support βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 179 |
+
def load_checkpoint(output_dir: Path) -> tuple[np.ndarray | None, dict | None, int]:
|
| 180 |
+
"""Load partial embeddings + metadata if build was interrupted."""
|
| 181 |
+
emb_ckpt = output_dir / "embeddings_checkpoint.npy"
|
| 182 |
+
meta_ckpt = output_dir / "metadata_checkpoint.json"
|
| 183 |
+
|
| 184 |
+
if emb_ckpt.exists() and meta_ckpt.exists():
|
| 185 |
+
embeddings = np.load(emb_ckpt)
|
| 186 |
+
with open(meta_ckpt) as f:
|
| 187 |
+
metadata = json.load(f)
|
| 188 |
+
start_idx = len(metadata)
|
| 189 |
+
log.info(f"[RESUME] Found checkpoint with {start_idx:,} images. Continuing...")
|
| 190 |
+
return embeddings, metadata, start_idx
|
| 191 |
+
|
| 192 |
+
return None, None, 0
|
| 193 |
+
|
| 194 |
+
|
| 195 |
+
def save_checkpoint(output_dir: Path, embeddings: np.ndarray, metadata: dict):
    """Persist in-progress embeddings + metadata so the build can resume."""
    np.save(output_dir / "embeddings_checkpoint.npy", embeddings)
    (output_dir / "metadata_checkpoint.json").write_text(json.dumps(metadata))
|
| 200 |
+
|
| 201 |
+
|
| 202 |
+
# ββ Main pipeline ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 203 |
+
def build_gallery(
    image_dir: Path,
    output_dir: Path,
    batch_size: int = 64,
    device: str = "auto",
    metadata_csv: Optional[Path] = None,
    resume: bool = False,
    checkpoint_every: int = 500,
):
    """
    Full pipeline: images → BiomedCLIP embeddings → FAISS index.

    Args:
        image_dir: Directory containing X-ray images (scanned recursively)
        output_dir: Where to save visual_db.index + metadata.json
        batch_size: Images per GPU/CPU batch (lower if OOM)
        device: "cuda", "cpu", or "auto"
        metadata_csv: Optional CSV with columns: filename, labels
        resume: Resume from last checkpoint if available
        checkpoint_every: Save checkpoint every N images

    Returns:
        (faiss index, metadata dict keyed by stringified embedding index).
    """
    t_start = time.time()
    output_dir.mkdir(parents=True, exist_ok=True)

    # ── Resolve device ─────────────────────────────────────────────────────────
    if device == "auto":
        device = "cuda" if torch.cuda.is_available() else (
            "mps" if torch.backends.mps.is_available() else "cpu"
        )
    log.info(f"Device: {device}")

    # ── Collect image paths ────────────────────────────────────────────────────
    # Sorted so the ordering is deterministic, which the resume logic relies on
    # (the checkpoint stores only a count, not filenames).
    all_images = sorted([
        p for p in image_dir.rglob("*")
        if p.suffix.lower() in SUPPORTED_EXTS
    ])
    if not all_images:
        raise FileNotFoundError(f"No images found in {image_dir}")
    log.info(f"Found {len(all_images):,} images in {image_dir}")

    # ── Resume checkpoint ──────────────────────────────────────────────────────
    existing_emb, existing_meta, start_idx = (None, None, 0)
    if resume:
        existing_emb, existing_meta, start_idx = load_checkpoint(output_dir)

    images_to_process = all_images[start_idx:]
    log.info(f"Images to process: {len(images_to_process):,}")

    # ── Load BiomedCLIP ────────────────────────────────────────────────────────
    model, transform = load_biomedclip(device)

    # ── Dataset + DataLoader ───────────────────────────────────────────────────
    dataset = XRayDataset(images_to_process, transform, metadata_csv)
    loader = DataLoader(
        dataset,
        batch_size=batch_size,
        num_workers=min(4, os.cpu_count() or 1),
        pin_memory=(device == "cuda"),
        collate_fn=collate_skip_corrupt,  # drops corrupt samples; may yield None
        prefetch_factor=2 if device == "cuda" else None,
    )

    # ── Accumulate embeddings ──────────────────────────────────────────────────
    all_embeddings: list[np.ndarray] = []
    all_metadata: dict = existing_meta or {}  # id (int) → {filename, filepath, labels}
    global_idx = start_idx
    # NOTE(review): `skipped` is never incremented anywhere below, so
    # build_stats.json always reports "skipped": 0 even when corrupt
    # images were dropped by the collate function.
    skipped = 0

    log.info("Encoding images with BiomedCLIP...")
    for batch in tqdm(loader, desc="Encoding", unit="batch", ncols=80):
        if batch is None:
            # Whole batch was corrupt; collate_skip_corrupt returned None.
            continue
        tensors, paths, labels = batch
        batch_emb = encode_batch(model, tensors, device)

        for i, (path, label) in enumerate(zip(paths, labels)):
            all_embeddings.append(batch_emb[i])
            all_metadata[str(global_idx)] = {
                "filename": Path(path).name,
                "filepath": path,
                "labels": label,
                "idx": global_idx,
            }
            global_idx += 1

        # Periodic checkpoint: the `< batch_size` window fires roughly once
        # every checkpoint_every images regardless of batch alignment.
        if global_idx % checkpoint_every < batch_size:
            combined_emb = np.vstack(
                [existing_emb] + all_embeddings
                if existing_emb is not None else all_embeddings
            )
            save_checkpoint(output_dir, combined_emb, all_metadata)
            log.info(f"  Checkpoint saved at {global_idx:,} images")

    if not all_embeddings:
        raise RuntimeError("No valid images were encoded. Check image directory.")

    # ── Stack all embeddings ───────────────────────────────────────────────────
    new_embeddings = np.vstack(all_embeddings)
    if existing_emb is not None:
        final_embeddings = np.vstack([existing_emb, new_embeddings])
    else:
        final_embeddings = new_embeddings

    log.info(f"Embeddings shape: {final_embeddings.shape}")

    # ── Build + save FAISS index ───────────────────────────────────────────────
    index = build_faiss_index(final_embeddings)
    index_path = output_dir / INDEX_FILE
    faiss.write_index(index, str(index_path))
    log.info(f"FAISS index saved → {index_path} ({index_path.stat().st_size / 1e6:.1f} MB)")

    # ── Save metadata ──────────────────────────────────────────────────────────
    meta_path = output_dir / METADATA_FILE
    with open(meta_path, "w") as f:
        json.dump(all_metadata, f, indent=2)
    log.info(f"Metadata saved → {meta_path}")

    # ── Save raw embeddings (optional, useful for offline analysis) ────────────
    emb_path = output_dir / EMBEDDINGS_FILE
    np.save(emb_path, final_embeddings)
    log.info(f"Embeddings saved → {emb_path} ({emb_path.stat().st_size / 1e6:.1f} MB)")

    # ── Clean up checkpoints ───────────────────────────────────────────────────
    for ckpt in ["embeddings_checkpoint.npy", "metadata_checkpoint.json"]:
        ckpt_path = output_dir / ckpt
        if ckpt_path.exists():
            ckpt_path.unlink()

    # ── Build stats ────────────────────────────────────────────────────────────
    elapsed = time.time() - t_start
    stats = {
        "total_images": index.ntotal,
        "skipped": skipped,
        "embed_dim": EMBED_DIM,
        "model": BIOMEDCLIP_MODEL,
        "index_type": type(index).__name__,
        "build_time_sec": round(elapsed, 1),
        "throughput_img_per_sec": round(index.ntotal / elapsed, 1),
        "index_size_mb": round(index_path.stat().st_size / 1e6, 2),
        "device": device,
    }
    with open(output_dir / STATS_FILE, "w") as f:
        json.dump(stats, f, indent=2)

    log.info("=" * 55)
    log.info(f"✅ Gallery build complete!")
    log.info(f"  Images indexed : {index.ntotal:,}")
    log.info(f"  Build time : {elapsed:.0f}s ({stats['throughput_img_per_sec']} img/s)")
    log.info(f"  Index size : {stats['index_size_mb']} MB")
    log.info(f"  Output dir : {output_dir.resolve()}")
    log.info("=" * 55)

    return index, all_metadata
|
| 357 |
+
|
| 358 |
+
|
| 359 |
+
# ββ CLI ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 360 |
+
def main():
    """CLI entry point: parse arguments, then run build_gallery."""
    cli = argparse.ArgumentParser(
        description="Build FAISS visual search index from chest X-ray images"
    )
    cli.add_argument(
        "--image_dir", type=Path, required=True,
        help="Root directory containing X-ray images (searched recursively)"
    )
    cli.add_argument(
        "--output_dir", type=Path, default=Path("./index"),
        help="Where to save visual_db.index + metadata.json (default: ./index)"
    )
    cli.add_argument(
        "--batch_size", type=int, default=64,
        help="Batch size for encoding. Reduce to 16 if CPU RAM < 8 GB (default: 64)"
    )
    cli.add_argument(
        "--device", choices=["auto", "cuda", "cpu", "mps"], default="auto",
        help="Compute device (default: auto-detect)"
    )
    cli.add_argument(
        "--metadata_csv", type=Path, default=None,
        help="Optional CSV with columns: filename, labels"
    )
    cli.add_argument(
        "--resume", action="store_true",
        help="Resume from last checkpoint if build was interrupted"
    )
    cli.add_argument(
        "--checkpoint_every", type=int, default=500,
        help="Save checkpoint every N images (default: 500)"
    )
    opts = cli.parse_args()

    build_gallery(
        image_dir=opts.image_dir.resolve(),
        output_dir=opts.output_dir.resolve(),
        batch_size=opts.batch_size,
        device=opts.device,
        metadata_csv=opts.metadata_csv,
        resume=opts.resume,
        checkpoint_every=opts.checkpoint_every,
    )
|
| 403 |
+
|
| 404 |
+
|
| 405 |
+
# Run the gallery-builder CLI when executed as a script.
if __name__ == "__main__":
    main()
|
render.yaml
ADDED
|
@@ -0,0 +1,20 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
services:
|
| 2 |
+
- type: web
|
| 3 |
+
name: medrag-app
|
| 4 |
+
env: python
|
| 5 |
+
plan: free
|
| 6 |
+
buildCommand: pip install --index-url https://download.pytorch.org/whl/cpu torch torchvision && pip install -r requirements.txt
|
| 7 |
+
startCommand: python download_assets.py && streamlit run app.py --server.port $PORT --server.address 0.0.0.0
|
| 8 |
+
envVars:
|
| 9 |
+
- key: PYTHONUNBUFFERED
|
| 10 |
+
value: "1"
|
| 11 |
+
- key: DATA_DIR
|
| 12 |
+
value: /tmp/medrag_data
|
| 13 |
+
- key: HF_HOME
|
| 14 |
+
value: /tmp/hf_cache
|
| 15 |
+
- key: PREFETCH_MODEL
|
| 16 |
+
value: "1"
|
| 17 |
+
- key: GDRIVE_INDEX_URL
|
| 18 |
+
value: ""
|
| 19 |
+
- key: GDRIVE_IMAGES_URL
|
| 20 |
+
value: ""
|
requirements-space.txt
ADDED
|
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
open-clip-torch>=2.24.0
|
| 2 |
+
faiss-cpu>=1.7.4
|
| 3 |
+
Pillow>=10.0.0
|
| 4 |
+
numpy>=1.24.0
|
| 5 |
+
tqdm>=4.66.0
|
| 6 |
+
requests>=2.31.0
|
| 7 |
+
pandas>=2.0.0
|
| 8 |
+
streamlit>=1.31.0
|
| 9 |
+
gdown>=5.1.0
|
| 10 |
+
huggingface-hub>=0.28.0
|
| 11 |
+
transformers>=4.30.0,<5
|
requirements.txt
ADDED
|
@@ -0,0 +1,31 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Gallery Builder β Python Dependencies
|
| 2 |
+
# Install with: pip install -r requirements.txt
|
| 3 |
+
#
|
| 4 |
+
# GPU support (recommended for faster encoding):
|
| 5 |
+
# pip install torch torchvision --index-url https://download.pytorch.org/whl/cu121
|
| 6 |
+
#
|
| 7 |
+
# CPU only:
|
| 8 |
+
# pip install torch torchvision --index-url https://download.pytorch.org/whl/cpu
|
| 9 |
+
|
| 10 |
+
# Core ML
|
| 11 |
+
torch>=2.1.0
|
| 12 |
+
torchvision>=0.16.0
|
| 13 |
+
open-clip-torch>=2.24.0 # BiomedCLIP lives here
|
| 14 |
+
|
| 15 |
+
# Vector database
|
| 16 |
+
faiss-cpu>=1.7.4 # swap for faiss-gpu if CUDA available
|
| 17 |
+
|
| 18 |
+
# Image processing
|
| 19 |
+
Pillow>=10.0.0
|
| 20 |
+
numpy>=1.24.0
|
| 21 |
+
|
| 22 |
+
# Utilities
|
| 23 |
+
tqdm>=4.66.0
|
| 24 |
+
requests>=2.31.0
|
| 25 |
+
pandas>=2.0.0
|
| 26 |
+
streamlit>=1.31.0
|
| 27 |
+
gdown>=5.1.0
|
| 28 |
+
|
| 29 |
+
# Testing
|
| 30 |
+
pytest>=7.4.0
|
| 31 |
+
pytest-cov>=4.1.0
|
rewrite_metadata.py
ADDED
|
@@ -0,0 +1,43 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
rewrite_metadata.py
|
| 3 |
+
-------------------
|
| 4 |
+
Utility to rewrite metadata.json filepaths for deployment.
|
| 5 |
+
|
| 6 |
+
Example:
|
| 7 |
+
python rewrite_metadata.py \
|
| 8 |
+
--index_dir ./index \
|
| 9 |
+
--from_prefix "/Users/you/MedRAG/data/train" \
|
| 10 |
+
--to_prefix "/var/data/images"
|
| 11 |
+
"""
|
| 12 |
+
|
| 13 |
+
import argparse
|
| 14 |
+
import json
|
| 15 |
+
from pathlib import Path
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
def main():
    """Rewrite the leading path prefix of every ``filepath`` in metadata.json.

    Reads ``<index_dir>/metadata.json``, replaces ``from_prefix`` with
    ``to_prefix`` at the start of each entry's ``filepath``, writes the file
    back in place, and reports how many entries changed.

    Raises:
        FileNotFoundError: If metadata.json is missing from --index_dir.
    """
    parser = argparse.ArgumentParser(description="Rewrite metadata.json filepaths")
    parser.add_argument("--index_dir", type=Path, default=Path("./index"))
    parser.add_argument("--from_prefix", required=True)
    parser.add_argument("--to_prefix", required=True)
    args = parser.parse_args()

    meta_path = args.index_dir / "metadata.json"
    if not meta_path.exists():
        raise FileNotFoundError(f"metadata.json not found: {meta_path}")

    data = json.loads(meta_path.read_text())
    updated = 0

    # Keys are FAISS row ids and irrelevant here — iterate values only.
    for entry in data.values():
        fp = entry.get("filepath", "")
        if fp.startswith(args.from_prefix):
            # Replace only the first (leading) occurrence of the prefix.
            entry["filepath"] = fp.replace(args.from_prefix, args.to_prefix, 1)
            updated += 1

    meta_path.write_text(json.dumps(data, indent=2))
    print(f"Rewrote {updated} filepaths in {meta_path}")


if __name__ == "__main__":
    main()
|
start.sh
ADDED
|
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env bash
# Container entry point: fetch runtime assets, then launch the Streamlit UI.
set -euo pipefail

# Default to /tmp-backed locations; allow the platform to override via env.
export DATA_DIR="${DATA_DIR:-/tmp/medrag_data}"
export HF_HOME="${HF_HOME:-/tmp/hf_cache}"
# Presumably toggles model prefetching inside download_assets.py — TODO confirm.
export PREFETCH_MODEL="${PREFETCH_MODEL:-1}"

# Download index/model assets before serving so the app never starts cold.
python download_assets.py
# exec replaces the shell so Streamlit receives container signals (e.g. SIGTERM) directly.
exec streamlit run app.py --server.port 7860 --server.address 0.0.0.0
|
test_visual_search.py
ADDED
|
@@ -0,0 +1,308 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
test_visual_search.py
|
| 3 |
+
βββββββββββββββββββββ
|
| 4 |
+
Unit + integration tests for the gallery builder pipeline.
|
| 5 |
+
|
| 6 |
+
Run:
|
| 7 |
+
# Fast unit tests (no model needed):
|
| 8 |
+
pytest test_visual_search.py -v -m "not integration"
|
| 9 |
+
|
| 10 |
+
# Full integration test (requires built index):
|
| 11 |
+
pytest test_visual_search.py -v --index_dir ./index --image_dir ./data
|
| 12 |
+
"""
|
| 13 |
+
|
| 14 |
+
import json
|
| 15 |
+
import tempfile
|
| 16 |
+
import numpy as np
|
| 17 |
+
import pytest
|
| 18 |
+
from pathlib import Path
|
| 19 |
+
from unittest.mock import patch, MagicMock
|
| 20 |
+
from PIL import Image
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
# ββ Fixtures βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 24 |
+
@pytest.fixture
def dummy_index_dir(tmp_path):
    """Create a minimal fake FAISS index + metadata for unit tests."""
    import faiss

    dim, count = 512, 20
    vecs = np.random.randn(count, dim).astype(np.float32)
    # Unit-length rows so inner product == cosine similarity.
    vecs /= np.linalg.norm(vecs, axis=1, keepdims=True)

    db = faiss.IndexFlatIP(dim)
    db.add(vecs)
    faiss.write_index(db, str(tmp_path / "visual_db.index"))

    meta = {}
    for i in range(count):
        name = f"image_{i:04d}.png"
        meta[str(i)] = {
            "filename": name,
            "filepath": str(tmp_path / name),
            "labels": "Pneumonia" if i % 3 == 0 else "No Finding",
            "idx": i,
        }
    (tmp_path / "metadata.json").write_text(json.dumps(meta))

    # Matching dummy PNGs so tests that open files on disk succeed.
    for i in range(count):
        pixels = np.random.randint(0, 255, (64, 64, 3), dtype=np.uint8)
        Image.fromarray(pixels).save(tmp_path / f"image_{i:04d}.png")

    return tmp_path, vecs
|
| 58 |
+
|
| 59 |
+
|
| 60 |
+
@pytest.fixture
def dummy_xray_image(tmp_path) -> Path:
    """Create a fake grayscale X-ray image."""
    out = tmp_path / "test_xray.png"
    gray = np.random.randint(0, 255, (224, 224), dtype=np.uint8)
    Image.fromarray(gray, mode="L").convert("RGB").save(out)
    return out
|
| 68 |
+
|
| 69 |
+
|
| 70 |
+
# ββ Unit tests βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 71 |
+
class TestSearchResult:
    """Unit tests for the SearchResult dataclass serialization."""

    def test_to_dict(self):
        from visual_search import SearchResult

        result = SearchResult(
            rank=1,
            idx=5,
            filename="img.png",
            filepath="/data/img.png",
            labels="Pneumonia",
            similarity=0.87654,
        )
        payload = result.to_dict()
        assert payload["rank"] == 1
        assert payload["similarity"] == 0.8765  # rounded to 4 decimal places
        assert payload["labels"] == "Pneumonia"
        assert "image" not in payload  # PIL image not serialized
|
| 82 |
+
|
| 83 |
+
|
| 84 |
+
class TestFAISSIndex:
    """Test FAISS index properties independent of BiomedCLIP."""

    def test_build_flat_index(self):
        """IndexFlatIP accepts a batch of vectors and tracks the count in ntotal."""
        import faiss
        d, n = 512, 100
        emb = np.random.randn(n, d).astype(np.float32)
        emb /= np.linalg.norm(emb, axis=1, keepdims=True)  # unit rows: IP == cosine

        index = faiss.IndexFlatIP(d)
        index.add(emb)
        assert index.ntotal == n

    def test_search_returns_correct_k(self):
        """search() returns (1, k) score/index arrays; self-match ranks first."""
        import faiss
        d, n = 512, 50
        emb = np.random.randn(n, d).astype(np.float32)
        emb /= np.linalg.norm(emb, axis=1, keepdims=True)

        index = faiss.IndexFlatIP(d)
        index.add(emb)

        query = emb[0:1]  # use first vector as query
        sims, idxs = index.search(query, k=5)
        assert sims.shape == (1, 5)
        assert idxs.shape == (1, 5)
        # Self-match should be first with similarity ≈ 1.0
        assert abs(sims[0][0] - 1.0) < 1e-5
        assert idxs[0][0] == 0

    def test_cosine_similarity_via_dot_product(self):
        """L2-normalized dot product = cosine similarity."""
        import faiss
        d = 512
        # Two identical vectors should have similarity 1.0
        v = np.random.randn(1, d).astype(np.float32)
        v /= np.linalg.norm(v)

        index = faiss.IndexFlatIP(d)
        index.add(v)

        sims, _ = index.search(v, k=1)
        assert abs(sims[0][0] - 1.0) < 1e-5

    def test_ivf_index_for_large_gallery(self):
        """IVF index works for large galleries (>10K vectors)."""
        import faiss
        d, n = 512, 10_000
        emb = np.random.randn(n, d).astype(np.float32)
        emb /= np.linalg.norm(emb, axis=1, keepdims=True)

        nlist = 64  # number of coarse clusters
        quantizer = faiss.IndexFlatIP(d)
        index = faiss.IndexIVFFlat(quantizer, d, nlist, faiss.METRIC_INNER_PRODUCT)
        index.train(emb)  # IVF requires training before add()
        index.add(emb)
        index.nprobe = 8  # clusters visited per query: recall/speed tradeoff

        assert index.ntotal == n
        # Check that search still works
        sims, idxs = index.search(emb[0:1], k=5)
        assert idxs[0][0] == 0  # self should be top result
|
| 146 |
+
|
| 147 |
+
|
| 148 |
+
class TestMetadataBuilding:
    """Validate the metadata.json sidecar written next to the FAISS index."""

    def test_metadata_keys(self, dummy_index_dir):
        """Every entry carries the fields the search engine reads back."""
        # Fix: the original bound `embeddings` via tuple-unpack and never used
        # it, then indexed the fixture again — unpack once, consistently.
        index_dir, _ = dummy_index_dir
        meta_path = index_dir / "metadata.json"
        with open(meta_path) as f:
            meta = json.load(f)
        assert "0" in meta
        entry = meta["0"]
        assert "filename" in entry
        assert "filepath" in entry
        assert "labels" in entry
        assert "idx" in entry

    def test_metadata_count_matches_index(self, dummy_index_dir):
        """FAISS ntotal and metadata entry count must stay in lockstep."""
        import faiss
        index_dir, _ = dummy_index_dir
        index = faiss.read_index(str(index_dir / "visual_db.index"))
        with open(index_dir / "metadata.json") as f:
            meta = json.load(f)
        assert index.ntotal == len(meta)
|
| 168 |
+
|
| 169 |
+
|
| 170 |
+
class TestVisualSearchEngine:
    """Tests using mocked BiomedCLIP to avoid model download."""

    def _get_engine_with_mock_model(self, index_dir):
        """Create engine with BiomedCLIP mocked out."""
        from visual_search import VisualSearchEngine
        import faiss  # NOTE(review): imported but unused in this helper

        # Patch model creation so VisualSearchEngine.__init__ never downloads
        # the real BiomedCLIP weights.
        with patch("visual_search.open_clip.create_model_and_transforms") as mock_create:
            mock_model = MagicMock()
            # Mimics the chain transform(img).unsqueeze(0).to(device) used by
            # the engine's _embed_image (each step returns another MagicMock).
            mock_transform = MagicMock(return_value=MagicMock(
                unsqueeze=lambda _: MagicMock(to=lambda _: MagicMock())
            ))
            mock_create.return_value = (mock_model, None, mock_transform)

            # Constructed inside the patch so the mocked factory is active
            # during eager model loading.
            engine = VisualSearchEngine(index_dir=index_dir, device="cpu")

            # Mock the embed function to return a random normalized vector
            def fake_embed(img):
                v = np.random.randn(1, 512).astype(np.float32)
                v /= np.linalg.norm(v, axis=1, keepdims=True)
                return v

            engine._embed_image = fake_embed
            return engine

    def test_search_returns_k_results(self, dummy_index_dir):
        """search(top_k=k) yields exactly k results against a 20-image index."""
        index_dir = dummy_index_dir[0]
        engine = self._get_engine_with_mock_model(index_dir)

        dummy_img = Image.fromarray(np.zeros((224, 224, 3), dtype=np.uint8))
        results = engine.search(dummy_img, top_k=5)
        assert len(results) == 5

    def test_results_sorted_by_similarity(self, dummy_index_dir):
        """Results come back in descending-similarity order."""
        index_dir = dummy_index_dir[0]
        engine = self._get_engine_with_mock_model(index_dir)

        dummy_img = Image.fromarray(np.zeros((224, 224, 3), dtype=np.uint8))
        results = engine.search(dummy_img, top_k=5)
        sims = [r.similarity for r in results]
        assert sims == sorted(sims, reverse=True)

    def test_results_have_required_fields(self, dummy_index_dir):
        """Every result exposes the full SearchResult contract with sane values."""
        index_dir = dummy_index_dir[0]
        engine = self._get_engine_with_mock_model(index_dir)

        dummy_img = Image.fromarray(np.zeros((224, 224, 3), dtype=np.uint8))
        results = engine.search(dummy_img, top_k=3)
        for r in results:
            assert hasattr(r, "rank")
            assert hasattr(r, "filename")
            assert hasattr(r, "filepath")
            assert hasattr(r, "labels")
            assert hasattr(r, "similarity")
            assert 0.0 <= r.similarity <= 1.0

    def test_ranks_are_sequential(self, dummy_index_dir):
        """Ranks are 1..k with no gaps."""
        index_dir = dummy_index_dir[0]
        engine = self._get_engine_with_mock_model(index_dir)

        dummy_img = Image.fromarray(np.zeros((224, 224, 3), dtype=np.uint8))
        results = engine.search(dummy_img, top_k=5)
        for i, r in enumerate(results, start=1):
            assert r.rank == i

    def test_file_not_found_raises(self, dummy_index_dir):
        """A missing query path raises before any encoding happens."""
        index_dir = dummy_index_dir[0]
        engine = self._get_engine_with_mock_model(index_dir)
        with pytest.raises(FileNotFoundError):
            engine.search("/nonexistent/image.png")

    def test_batch_search(self, dummy_index_dir):
        """search_batch returns one k-sized result list per query image."""
        index_dir = dummy_index_dir[0]
        engine = self._get_engine_with_mock_model(index_dir)

        imgs = [
            Image.fromarray(np.zeros((224, 224, 3), dtype=np.uint8))
            for _ in range(3)
        ]
        batch_results = engine.search_batch(imgs, top_k=5)
        assert len(batch_results) == 3
        assert all(len(r) == 5 for r in batch_results)

    def test_get_stats(self, dummy_index_dir):
        """get_stats reflects the dummy index: 20 images, 512-dim embeddings."""
        index_dir = dummy_index_dir[0]
        engine = self._get_engine_with_mock_model(index_dir)
        stats = engine.get_stats()
        assert "total_images" in stats
        assert stats["total_images"] == 20
        assert stats["embed_dim"] == 512

    def test_to_dict_serializable(self, dummy_index_dir):
        """Search results must be JSON serializable for API responses."""
        index_dir = dummy_index_dir[0]
        engine = self._get_engine_with_mock_model(index_dir)

        dummy_img = Image.fromarray(np.zeros((224, 224, 3), dtype=np.uint8))
        results = engine.search(dummy_img, top_k=3)
        payload = [r.to_dict() for r in results]
        assert json.dumps(payload)  # raises if not serializable
|
| 271 |
+
|
| 272 |
+
|
| 273 |
+
# ββ Integration tests (require real index) βββββββββββββββββββββββββββββββββββββ
|
| 274 |
+
@pytest.mark.integration
class TestIntegration:
    """Run with: pytest -m integration --index_dir ./index --image_dir ./data"""

    @pytest.fixture(autouse=True)
    def setup(self, request):
        # Pull CLI options with fallbacks. NOTE(review): if the options were
        # never registered (see pytest_addoption note at the bottom of this
        # file) the defaults below are silently used.
        self.index_dir = Path(request.config.getoption("--index_dir", default="./index"))
        self.image_dir = Path(request.config.getoption("--image_dir", default="./data"))

    def test_real_search(self):
        """Engine loads a real on-disk index and reports a non-empty gallery."""
        from visual_search import VisualSearchEngine
        engine = VisualSearchEngine(self.index_dir, device="cpu")
        stats = engine.get_stats()
        assert stats["total_images"] > 0
        print(f"\nIndex contains {stats['total_images']:,} images")

    def test_search_with_real_image(self):
        """End-to-end search using the first gallery image as the query."""
        from visual_search import VisualSearchEngine
        engine = VisualSearchEngine(self.index_dir, device="cpu")

        # Find first image in data dir
        images = list(self.image_dir.rglob("*.png"))[:1]
        if not images:
            pytest.skip("No test images found")

        # exclude_perfect_match drops the query's own gallery entry.
        results = engine.search(images[0], top_k=5, exclude_perfect_match=True)
        assert len(results) > 0
        assert results[0].similarity <= 1.0
        print(f"\nTop result: {results[0].filename} sim={results[0].similarity:.3f}")
|
| 303 |
+
|
| 304 |
+
|
| 305 |
+
# ββ Pytest config ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 306 |
+
def pytest_addoption(parser):
    """Register --index_dir / --image_dir CLI flags for the integration tests.

    NOTE(review): pytest only discovers ``pytest_addoption`` in conftest.py or
    installed plugins — defined in a test module it is silently ignored, and
    the integration tests then fall back to the ``getoption(..., default=...)``
    defaults. Move this hook to conftest.py for the flags to take effect.
    """
    parser.addoption("--index_dir", action="store", default="./index")
    parser.addoption("--image_dir", action="store", default="./data")
|
visual_search.py
ADDED
|
@@ -0,0 +1,358 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
visual_search.py
|
| 3 |
+
ββββββββββββββββ
|
| 4 |
+
Search function for the Medical X-ray RAG system.
|
| 5 |
+
|
| 6 |
+
Input: A chest X-ray image (file path or PIL Image or numpy array)
|
| 7 |
+
Output: Top-K most similar cases from the gallery database
|
| 8 |
+
|
| 9 |
+
This is the module imported by your web app and RAG pipeline.
|
| 10 |
+
|
| 11 |
+
Usage:
|
| 12 |
+
from visual_search import VisualSearchEngine
|
| 13 |
+
|
| 14 |
+
engine = VisualSearchEngine(
|
| 15 |
+
index_dir="./index",
|
| 16 |
+
device="auto"
|
| 17 |
+
)
|
| 18 |
+
|
| 19 |
+
results = engine.search("./query_xray.png", top_k=5)
|
| 20 |
+
# returns List[SearchResult]
|
| 21 |
+
for r in results:
|
| 22 |
+
print(f"{r.rank}. {r.filename} sim={r.similarity:.3f} labels={r.labels}")
|
| 23 |
+
"""
|
| 24 |
+
|
| 25 |
+
import json
|
| 26 |
+
import time
|
| 27 |
+
import logging
|
| 28 |
+
import numpy as np
|
| 29 |
+
from pathlib import Path
|
| 30 |
+
from dataclasses import dataclass, field
|
| 31 |
+
from typing import Union, Optional
|
| 32 |
+
|
| 33 |
+
import faiss
|
| 34 |
+
import torch
|
| 35 |
+
import open_clip
|
| 36 |
+
from PIL import Image, UnidentifiedImageError
|
| 37 |
+
|
| 38 |
+
log = logging.getLogger(__name__)
|
| 39 |
+
|
| 40 |
+
# ββ Constants (must match gallery_builder.py) ββββββββββββββββββββββββββββββββββ
|
| 41 |
+
BIOMEDCLIP_MODEL = "hf-hub:microsoft/BiomedCLIP-PubMedBERT_256-vit_base_patch16_224"
|
| 42 |
+
INDEX_FILE = "visual_db.index"
|
| 43 |
+
METADATA_FILE = "metadata.json"
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
# ββ Result dataclass βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 47 |
+
@dataclass
class SearchResult:
    """One similar case returned by the search engine."""

    rank: int                 # 1 = most similar
    idx: int                  # Internal FAISS index ID
    filename: str             # Image filename
    filepath: str             # Absolute path to the image
    labels: str               # Diagnosis labels (from metadata)
    similarity: float         # Cosine similarity [0, 1]
    # Optionally loaded PIL Image (set load_images=True in search())
    image: Optional[object] = field(default=None, repr=False)

    def to_dict(self) -> dict:
        """JSON-safe view of the result; the PIL image is deliberately omitted."""
        payload = {
            "rank": self.rank,
            "idx": self.idx,
            "filename": self.filename,
            "filepath": self.filepath,
            "labels": self.labels,
        }
        payload["similarity"] = round(float(self.similarity), 4)
        return payload
|
| 68 |
+
|
| 69 |
+
|
| 70 |
+
# ββ Search Engine ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 71 |
+
class VisualSearchEngine:
|
| 72 |
+
"""
|
| 73 |
+
Thread-safe visual search engine for chest X-ray similarity retrieval.
|
| 74 |
+
|
| 75 |
+
Architecture:
|
| 76 |
+
Query image
|
| 77 |
+
β
|
| 78 |
+
βΌ
|
| 79 |
+
BiomedCLIP vision encoder β 512-dim embedding (L2 normalized)
|
| 80 |
+
β
|
| 81 |
+
βΌ
|
| 82 |
+
FAISS IndexFlatIP β cosine similarity search
|
| 83 |
+
β
|
| 84 |
+
βΌ
|
| 85 |
+
Top-K results + metadata
|
| 86 |
+
|
| 87 |
+
Attributes:
|
| 88 |
+
index_dir (Path): Directory containing visual_db.index + metadata.json
|
| 89 |
+
device (str): Compute device for BiomedCLIP
|
| 90 |
+
top_k (int): Default number of results to return
|
| 91 |
+
"""
|
| 92 |
+
|
| 93 |
+
def __init__(
|
| 94 |
+
self,
|
| 95 |
+
index_dir: Union[str, Path],
|
| 96 |
+
device: str = "auto",
|
| 97 |
+
top_k: int = 5,
|
| 98 |
+
):
|
| 99 |
+
self.index_dir = Path(index_dir).resolve()
|
| 100 |
+
self.top_k = top_k
|
| 101 |
+
self._model = None
|
| 102 |
+
self._transform = None
|
| 103 |
+
self._index = None
|
| 104 |
+
self._metadata: dict = {}
|
| 105 |
+
|
| 106 |
+
# Resolve device
|
| 107 |
+
if device == "auto":
|
| 108 |
+
if torch.cuda.is_available():
|
| 109 |
+
self.device = "cuda"
|
| 110 |
+
elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
|
| 111 |
+
self.device = "mps"
|
| 112 |
+
else:
|
| 113 |
+
self.device = "cpu"
|
| 114 |
+
else:
|
| 115 |
+
self.device = device
|
| 116 |
+
|
| 117 |
+
# Eager load
|
| 118 |
+
self._load_index()
|
| 119 |
+
self._load_model()
|
| 120 |
+
log.info(f"VisualSearchEngine ready (index={self._index.ntotal:,} images, device={self.device})")
|
| 121 |
+
|
| 122 |
+
# ββ Private loaders ββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 123 |
+
def _load_index(self):
|
| 124 |
+
"""Load FAISS index + metadata from disk."""
|
| 125 |
+
index_path = self.index_dir / INDEX_FILE
|
| 126 |
+
meta_path = self.index_dir / METADATA_FILE
|
| 127 |
+
|
| 128 |
+
if not index_path.exists():
|
| 129 |
+
raise FileNotFoundError(
|
| 130 |
+
f"FAISS index not found: {index_path}\n"
|
| 131 |
+
"Run: python gallery_builder.py --image_dir ./data --output_dir ./index"
|
| 132 |
+
)
|
| 133 |
+
if not meta_path.exists():
|
| 134 |
+
raise FileNotFoundError(f"Metadata file not found: {meta_path}")
|
| 135 |
+
|
| 136 |
+
log.info(f"Loading FAISS index from {index_path}...")
|
| 137 |
+
self._index = faiss.read_index(str(index_path))
|
| 138 |
+
|
| 139 |
+
# For IVF indexes, set nprobe for recall/speed tradeoff
|
| 140 |
+
if hasattr(self._index, "nprobe"):
|
| 141 |
+
self._index.nprobe = 16
|
| 142 |
+
|
| 143 |
+
log.info(f"Index loaded ({self._index.ntotal:,} vectors, dim={self._index.d})")
|
| 144 |
+
|
| 145 |
+
with open(meta_path) as f:
|
| 146 |
+
self._metadata = json.load(f)
|
| 147 |
+
|
| 148 |
+
def _load_model(self):
|
| 149 |
+
"""Load BiomedCLIP vision encoder."""
|
| 150 |
+
log.info("Loading BiomedCLIP encoder...")
|
| 151 |
+
model, _, transform = open_clip.create_model_and_transforms(BIOMEDCLIP_MODEL)
|
| 152 |
+
self._model = model.to(self.device).eval()
|
| 153 |
+
self._transform = transform
|
| 154 |
+
log.info("BiomedCLIP loaded β")
|
| 155 |
+
|
| 156 |
+
# ββ Embedding ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 157 |
+
@torch.no_grad()
|
| 158 |
+
def _embed_image(self, image: Image.Image) -> np.ndarray:
|
| 159 |
+
"""
|
| 160 |
+
Encode a single PIL image β L2-normalized 512-dim embedding.
|
| 161 |
+
Returns shape (1, 512) float32 numpy array.
|
| 162 |
+
"""
|
| 163 |
+
tensor = self._transform(image).unsqueeze(0).to(self.device)
|
| 164 |
+
features = self._model.encode_image(tensor)
|
| 165 |
+
features = features / features.norm(dim=-1, keepdim=True)
|
| 166 |
+
return features.cpu().numpy().astype(np.float32)
|
| 167 |
+
|
| 168 |
+
# ββ Public API βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 169 |
+
    def search(
        self,
        query: Union[str, Path, Image.Image, np.ndarray],
        top_k: Optional[int] = None,
        load_images: bool = False,
        exclude_perfect_match: bool = False,
    ) -> list[SearchResult]:
        """
        Find the top-K most similar X-ray images to a query.

        Args:
            query: File path, PIL Image, or RGB numpy array
            top_k: Number of results (overrides default)
            load_images: Load PIL Images into SearchResult.image
            exclude_perfect_match: Skip results with similarity ≥ 0.9999
                (use when query is in the gallery itself)

        Returns:
            List[SearchResult] ordered by descending similarity

        Raises:
            FileNotFoundError: If a path query does not exist.
            ValueError: If the file is not a decodable image.
            TypeError: For unsupported query types.
        """
        t0 = time.perf_counter()
        k = top_k or self.top_k

        # ── Load query image: normalize all accepted input types to RGB PIL ────
        if isinstance(query, (str, Path)):
            query_path = Path(query)
            if not query_path.exists():
                raise FileNotFoundError(f"Query image not found: {query_path}")
            try:
                img = Image.open(query_path).convert("RGB")
            except (UnidentifiedImageError, OSError) as e:
                raise ValueError(f"Cannot open image: {query_path} ({e})")

        elif isinstance(query, np.ndarray):
            img = Image.fromarray(query.astype(np.uint8))

        elif isinstance(query, Image.Image):
            img = query.convert("RGB")

        else:
            raise TypeError(f"Unsupported query type: {type(query)}")

        # ── Encode query into the shared BiomedCLIP embedding space ────────────
        query_emb = self._embed_image(img)  # (1, 512)

        # ── FAISS search ───────────────────────────────────────────────────────
        # NOTE(review): only one extra candidate is fetched; if the gallery
        # contains duplicates, several ≥0.9999 hits may be filtered below and
        # fewer than k results returned — consider over-fetching more.
        search_k = k + 1 if exclude_perfect_match else k
        similarities, indices = self._index.search(query_emb, search_k)
        similarities = similarities[0]  # (k,) — single-query batch
        indices = indices[0]            # (k,)

        # ── Build results, skipping padding and (optionally) self-matches ──────
        results: list[SearchResult] = []
        rank = 1
        for sim, idx in zip(similarities, indices):
            if idx < 0:  # FAISS returns -1 for empty slots
                continue
            if exclude_perfect_match and float(sim) >= 0.9999:
                continue  # skip exact self-match

            meta = self._metadata.get(str(idx), {})
            filepath = meta.get("filepath", "")

            result = SearchResult(
                rank=rank,
                idx=int(idx),
                filename=meta.get("filename", f"image_{idx}"),
                filepath=filepath,
                labels=meta.get("labels", "Unknown"),
                similarity=float(sim),
            )

            if load_images and filepath and Path(filepath).exists():
                try:
                    result.image = Image.open(filepath).convert("RGB")
                except Exception:
                    pass  # image loading is best-effort

            results.append(result)
            rank += 1
            if len(results) >= k:
                break  # cap at k even when search_k fetched one extra

        elapsed_ms = (time.perf_counter() - t0) * 1000
        log.debug(f"Search completed in {elapsed_ms:.1f} ms → {len(results)} results")
        return results
|
| 255 |
+
|
| 256 |
+
def search_batch(
    self,
    queries: list[Union[str, Path, Image.Image]],
    top_k: Optional[int] = None,
) -> list[list[SearchResult]]:
    """
    Batch search for multiple query images.

    More efficient than calling search() in a loop because all query
    embeddings are submitted to FAISS in a single search call.

    Args:
        queries: Query images given as file paths, PIL images, or
            HxWxC uint8 numpy arrays.
        top_k: Number of results per query; defaults to ``self.top_k``.

    Returns:
        One list of SearchResult per query, ordered by similarity.
        Note: unlike search(), images are not loaded and no perfect-match
        filtering is applied.
    """
    k = top_k or self.top_k

    # np.stack() raises ValueError on an empty sequence — handle it up front.
    if not queries:
        return []

    embeddings = []
    for q in queries:
        if isinstance(q, (str, Path)):
            img = Image.open(q).convert("RGB")
        elif isinstance(q, np.ndarray):
            img = Image.fromarray(q.astype(np.uint8))
        else:
            img = q.convert("RGB")
        embeddings.append(self._embed_image(img)[0])

    batch_emb = np.stack(embeddings)  # (N, embed_dim)
    sims_batch, idxs_batch = self._index.search(batch_emb, k)

    all_results = []
    for sims, idxs in zip(sims_batch, idxs_batch):
        results = []
        # Bug fix: rank is advanced only for kept hits. The previous
        # enumerate(..., start=1) left gaps in rank numbering whenever
        # FAISS padded the result row with -1 (index smaller than k),
        # which was inconsistent with search().
        rank = 1
        for sim, idx in zip(sims, idxs):
            if idx < 0:  # FAISS returns -1 for empty slots
                continue
            meta = self._metadata.get(str(idx), {})
            results.append(SearchResult(
                rank=rank,
                idx=int(idx),
                filename=meta.get("filename", f"image_{idx}"),
                filepath=meta.get("filepath", ""),
                labels=meta.get("labels", "Unknown"),
                similarity=float(sim),
            ))
            rank += 1
        all_results.append(results)

    return all_results
|
| 298 |
+
|
| 299 |
+
def get_stats(self) -> dict:
    """Summarize the loaded index: size, dimensionality, type, and location."""
    index = self._index
    stats = {}
    stats["total_images"] = index.ntotal
    stats["embed_dim"] = index.d
    stats["index_type"] = type(index).__name__
    stats["device"] = self.device
    stats["index_dir"] = str(self.index_dir)
    return stats
|
| 308 |
+
|
| 309 |
+
def __repr__(self) -> str:
    """Concise debug representation: image count, device, and index path."""
    details = ", ".join([
        f"images={self._index.ntotal:,}",
        f"device={self.device}",
        f"index_dir={self.index_dir}",
    ])
    return "VisualSearchEngine(" + details + ")"
|
| 316 |
+
|
| 317 |
+
|
| 318 |
+
# ββ Standalone CLI βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 319 |
+
def main():
    """Command-line entry point.

    Parses a query image path plus index options, runs a similarity
    search against the FAISS index, and prints ranked matches with a
    simple similarity bar, followed by index statistics.
    """
    import argparse  # local import: only needed when run as a script

    parser = argparse.ArgumentParser(
        description="Search for similar X-ray images"
    )
    parser.add_argument("query_image", type=Path, help="Path to query X-ray image")
    parser.add_argument(
        "--index_dir", type=Path, default=Path("./index"),
        help="Directory with visual_db.index (default: ./index)"
    )
    parser.add_argument("--top_k", type=int, default=5)
    parser.add_argument("--device", default="auto")
    args = parser.parse_args()

    logging.basicConfig(level=logging.INFO, format="%(levelname)s %(message)s")

    engine = VisualSearchEngine(
        index_dir=args.index_dir,
        device=args.device,
        top_k=args.top_k,
    )

    print(f"\nπ Query: {args.query_image}")
    print("=" * 60)
    # Exclude the query itself if it happens to be in the index.
    results = engine.search(args.query_image, exclude_perfect_match=True)

    for r in results:
        # Scale similarity (cosine, ~[0, 1]) to a 30-char bar.
        bar = "β" * int(r.similarity * 30)
        print(f" #{r.rank} {r.similarity:.3f} {bar}")
        print(f" {r.filename}")
        print(f" Labels: {r.labels}")
        print()

    print(f"Index stats: {engine.get_stats()}")
|
| 355 |
+
|
| 356 |
+
|
| 357 |
+
# Allow running this module directly as a CLI: python visual_search.py <image>
if __name__ == "__main__":
    main()
|