Sukrati committed on
Commit 345576d · 0 Parent(s):

Deploy MedRAG to Hugging Face Space v4

.dockerignore ADDED
```
.git
__pycache__/
*.pyc
*.pyo
*.pyd
.DS_Store
data/
index/
*.zip
render.yaml
```
.gitignore ADDED
```
.DS_Store
__pycache__/
*.pyc
*.pyo
*.pyd
.Python
*.ipynb_checkpoints

# Data and indexes
data/
index/
*.zip
embeddings_heatmap.png
embeddings_pca.png
embeddings_raw.png

# Large datasets
chexpert_full/
```
.streamlit/config.toml ADDED
```toml
[server]
enableCORS = false
enableXsrfProtection = false
maxUploadSize = 200
headless = true
```
Dockerfile ADDED
```dockerfile
FROM python:3.10-slim

ENV PYTHONDONTWRITEBYTECODE=1
ENV PYTHONUNBUFFERED=1
ENV PIP_NO_CACHE_DIR=1
ENV DATA_DIR=/tmp/medrag_data
ENV HF_HOME=/tmp/hf_cache
ENV PREFETCH_MODEL=1

WORKDIR /app

RUN apt-get update && apt-get install -y --no-install-recommends \
    git \
    libglib2.0-0 \
    libsm6 \
    libxext6 \
    libxrender1 \
    && rm -rf /var/lib/apt/lists/*

COPY requirements-space.txt ./
RUN pip install --index-url https://download.pytorch.org/whl/cpu torch torchvision \
    && pip install -r requirements-space.txt

COPY . .

RUN chmod +x /app/start.sh

EXPOSE 7860

CMD ["/app/start.sh"]
```
MedRAG.ipynb ADDED
The diff for this file is too large to render.
README.md ADDED
---
title: MedRAG Diagnostic Assistant
emoji: 🩺
colorFrom: blue
colorTo: gray
sdk: docker
app_port: 7860
pinned: false
---

# MedRAG

MedRAG is a multimodal chest X-ray retrieval and diagnostic-assistance app built on:
- BiomedCLIP for image embeddings and zero-shot disease scoring
- FAISS for similar-case retrieval
- a crosscheck layer that combines classifier output with retrieved case evidence
- Streamlit for the application UI

The current app supports:
- chest X-ray upload
- sample-image testing
- similar-case retrieval from the indexed gallery
- zero-shot disease probability ranking
- retrieval-supported clinical assessment text
- Hugging Face Spaces deployment through Docker
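The zero-shot disease probability ranking above reduces to a softmax over scaled image-text cosine similarities. A minimal sketch of that scoring step, in pure Python with made-up similarity values standing in for real BiomedCLIP outputs:

```python
import math

def zero_shot_probs(similarities: list[float], scale: float = 100.0) -> list[float]:
    """Softmax over scaled cosine similarities, one score per disease prompt.

    `scale` mirrors the *100 temperature applied in app.py before softmax.
    """
    m = max(similarities)
    # Subtract the max before exponentiating for numerical stability.
    exps = [math.exp(scale * (s - m)) for s in similarities]
    total = sum(exps)
    return [e / total for e in exps]

# Illustrative similarities for three prompts; the outputs sum to 1.
probs = zero_shot_probs([0.31, 0.28, 0.25])
```

Because the similarities are multiplied by 100, even small gaps in cosine similarity turn into sharply peaked probabilities.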

## Current App Flow

1. The user uploads a chest X-ray or selects a sample image.
2. The app encodes the image with BiomedCLIP.
3. FAISS retrieves the most visually similar historical cases.
4. BiomedCLIP scores 14 CheXpert disease prompts.
5. A crosscheck step combines retrieval agreement with classifier confidence.
6. The app renders:
   - generated clinical assessment
   - ranked diagnoses
   - top disease probabilities
   - similar historical cases
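The crosscheck in step 5 can be sketched as an even blend of classifier probability and gallery agreement; this mirrors the 0.5/0.5 weighting used in `app.py`, though the standalone function name here is illustrative:

```python
def crosscheck_confidence(classifier_prob: float, matching_cases: int, total_cases: int) -> float:
    """Blend a 0-100 classifier probability with the fraction of retrieved
    cases whose labels agree, weighting each half equally."""
    gallery_support = matching_cases / max(total_cases, 1)  # avoid divide-by-zero
    return round((classifier_prob / 100 * 0.5 + gallery_support * 0.5) * 100, 1)

# e.g. classifier says 40%, 3 of 5 similar cases agree -> 50.0
```

A diagnosis the classifier is unsure about can therefore still rank highly if most retrieved neighbours carry the same label, and vice versa.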

## Project Files

Core app:
- `app.py` - Streamlit UI and diagnosis pipeline
- `visual_search.py` - FAISS-backed visual search engine
- `download_assets.py` - downloads demo index/images and prefetches BiomedCLIP

Index/data tooling:
- `gallery_builder.py` - build a FAISS index from chest X-ray images
- `data_downloader.py` - download source datasets
- `rewrite_metadata.py` - rewrite metadata filepaths for deployment

Research/demo:
- `MedRAG.ipynb` - notebook containing the retrieval, zero-shot classification, and crosscheck logic the app was ported from

Deployment:
- `Dockerfile` - Hugging Face Spaces container build
- `start.sh` - startup entrypoint for Spaces
- `requirements-space.txt` - CPU-friendly dependencies for Spaces
- `render.yaml` - older Render deployment config

## Hugging Face Spaces

This repo is configured for a Docker Space.

### Deploy steps

1. Create a new Hugging Face Space.
2. Choose `Docker`.
3. Push this repo to the Space remote.
4. Let the Space build and start.

The Space startup does the following:
- installs CPU-only PyTorch
- downloads the public `index.zip` and `images.zip`
- prefetches the BiomedCLIP model
- starts Streamlit on port `7860`

## Local Run

Install dependencies:

```bash
pip install --index-url https://download.pytorch.org/whl/cpu torch torchvision
pip install -r requirements-space.txt
```

Run the app:

```bash
python download_assets.py
streamlit run app.py
```

## Data Notes

The deployed demo uses a reduced subset of CheXpert so it can run on free CPU infrastructure.

Assets are pulled from public Google Drive links by default:
- FAISS index archive
- subset image archive

If needed, override them with:
- `GDRIVE_INDEX_URL`
- `GDRIVE_IMAGES_URL`

Optional environment variables:
- `DATA_DIR`
- `HF_HOME`
- `PREFETCH_MODEL`
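For example, to point the app at a custom data directory and skip the model prefetch (the values below are illustrative, not required settings):

```shell
export DATA_DIR=/tmp/medrag_data   # where index/ and images/ get unpacked
export HF_HOME=/tmp/hf_cache       # Hugging Face model cache location
export PREFETCH_MODEL=0            # skip the BiomedCLIP prefetch in download_assets.py
```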

## Limitations

- The app is a diagnostic aid, not a clinical decision system.
- Free-tier hosting will have slow cold starts.
- The generated assessment is a rule-based synthesis of model scores and retrieval support, not a physician-grade interpretation.
- The original project plan referenced a larger multi-agent/LLM flow; the current deployed app implements the retrieval + classifier + crosscheck path from the notebook.
app.py ADDED
```python
import os
import random
import shutil
from collections import Counter
from pathlib import Path

import streamlit as st
import torch
from PIL import Image

from visual_search import VisualSearchEngine


APP_TITLE = "Multimodal Medical RAG Diagnostic Assistant"
MODEL_ID = "hf-hub:microsoft/BiomedCLIP-PubMedBERT_256-vit_base_patch16_224"
DISEASE_PROMPTS = {
    "No Finding": "Chest X-ray with no abnormality, normal findings",
    "Enlarged Cardiomediastinum": "Chest X-ray showing enlarged cardiomediastinum",
    "Cardiomegaly": "Chest X-ray showing cardiomegaly, enlarged heart",
    "Lung Opacity": "Chest X-ray showing lung opacity",
    "Lung Lesion": "Chest X-ray showing lung lesion or mass",
    "Edema": "Chest X-ray showing pulmonary edema, fluid in lungs",
    "Consolidation": "Chest X-ray showing consolidation in lung",
    "Pneumonia": "Chest X-ray showing pneumonia, lung infection",
    "Atelectasis": "Chest X-ray showing atelectasis, collapsed lung",
    "Pneumothorax": "Chest X-ray showing pneumothorax, air in pleural space",
    "Pleural Effusion": "Chest X-ray showing pleural effusion, fluid around lung",
    "Pleural Other": "Chest X-ray showing pleural abnormality",
    "Fracture": "Chest X-ray showing rib fracture or bone fracture",
    "Support Devices": "Chest X-ray showing support devices, tubes or lines",
}
INPUT_GUARDRAIL_PROMPTS = {
    "Chest X-ray": "A diagnostic chest X-ray radiograph showing the thorax and lungs",
    "Portrait Photo": "A portrait photograph of a person or celebrity",
    "Animal Photo": "A natural photograph of an animal or pet",
    "Document Screenshot": "A screenshot of a document, website, or computer interface",
    "Natural Image": "A normal everyday color photograph of a scene or object",
}
SYNONYMS = {
    "Pleural Effusion": ["pleural fluid", "fluid around lung", "effusion"],
    "Cardiomegaly": ["enlarged heart", "cardiac enlargement"],
    "Pneumonia": ["lung infection", "consolidation"],
    "Edema": ["fluid in lungs", "pulmonary edema"],
    "Atelectasis": ["collapsed lung", "lung collapse"],
    "Lung Opacity": ["opacity", "haziness", "infiltrate"],
    "No Finding": ["normal", "no abnormality", "clear"],
}


def _get_paths() -> tuple[Path, Path]:
    repo_index = Path("index").resolve()
    data_dir = Path(os.getenv("DATA_DIR", "/tmp/medrag_data")).resolve()
    index_dir = Path(os.getenv("INDEX_DIR", data_dir / "index")).resolve()
    return repo_index, index_dir


def _ensure_index_available() -> Path:
    repo_index, index_dir = _get_paths()
    if index_dir.exists():
        return index_dir
    if repo_index.exists():
        index_dir.parent.mkdir(parents=True, exist_ok=True)
        shutil.copytree(repo_index, index_dir)
        return index_dir
    raise FileNotFoundError("FAISS index not found. Expected at DATA_DIR/index or ./index")


@st.cache_resource(show_spinner=True)
def _load_engine() -> VisualSearchEngine:
    index_dir = _ensure_index_available()
    return VisualSearchEngine(index_dir=index_dir, device="auto", top_k=5)


@st.cache_resource(show_spinner=False)
def _load_text_features() -> tuple[list[str], torch.Tensor]:
    engine = _load_engine()
    tokenizer = __import__("open_clip").get_tokenizer(MODEL_ID)
    with torch.no_grad():
        tokens = tokenizer(list(DISEASE_PROMPTS.values())).to(engine.device)
        text_features = engine._model.encode_text(tokens)
        text_features = text_features / text_features.norm(dim=-1, keepdim=True)
    return list(DISEASE_PROMPTS.keys()), text_features


@st.cache_resource(show_spinner=False)
def _load_guardrail_features() -> tuple[list[str], torch.Tensor]:
    engine = _load_engine()
    tokenizer = __import__("open_clip").get_tokenizer(MODEL_ID)
    with torch.no_grad():
        tokens = tokenizer(list(INPUT_GUARDRAIL_PROMPTS.values())).to(engine.device)
        text_features = engine._model.encode_text(tokens)
        text_features = text_features / text_features.norm(dim=-1, keepdim=True)
    return list(INPUT_GUARDRAIL_PROMPTS.keys()), text_features


def _pick_sample_image(data_dir: Path) -> Path | None:
    images_dir = data_dir / "images"
    if not images_dir.exists():
        return None
    candidates = list(images_dir.glob("*.jpg")) + list(images_dir.glob("*.png")) + list(images_dir.glob("*.jpeg"))
    if not candidates:
        return None
    return random.choice(candidates)


@torch.no_grad()
def _predict_diseases(image: Image.Image) -> dict[str, float]:
    engine = _load_engine()
    disease_names, text_features = _load_text_features()
    tensor = engine._transform(image.convert("RGB")).unsqueeze(0).to(engine.device)
    image_features = engine._model.encode_image(tensor)
    image_features = image_features / image_features.norm(dim=-1, keepdim=True)
    similarities = (image_features @ text_features.T).squeeze(0)
    probs = torch.softmax(similarities * 100, dim=0).detach().cpu().tolist()
    results = {
        disease_names[i]: round(float(probs[i]) * 100, 2)
        for i in range(len(disease_names))
    }
    return dict(sorted(results.items(), key=lambda item: item[1], reverse=True))


@torch.no_grad()
def _validate_input_image(image: Image.Image) -> tuple[bool, dict[str, float]]:
    engine = _load_engine()
    labels, text_features = _load_guardrail_features()
    tensor = engine._transform(image.convert("RGB")).unsqueeze(0).to(engine.device)
    image_features = engine._model.encode_image(tensor)
    image_features = image_features / image_features.norm(dim=-1, keepdim=True)
    similarities = (image_features @ text_features.T).squeeze(0)
    probs = torch.softmax(similarities * 100, dim=0).detach().cpu().tolist()
    scores = {labels[i]: round(float(probs[i]) * 100, 2) for i in range(len(labels))}
    chest_score = scores["Chest X-ray"]
    next_best = max(score for label, score in scores.items() if label != "Chest X-ray")
    is_valid = chest_score >= 55 and chest_score > next_best
    return is_valid, dict(sorted(scores.items(), key=lambda item: item[1], reverse=True))


def _labels_match(disease: str, label_str: str) -> bool:
    label_lower = label_str.lower()
    if disease.lower() in label_lower:
        return True
    return any(syn.lower() in label_lower for syn in SYNONYMS.get(disease, []))


def _crosscheck(similar_cases, disease_probs: dict[str, float]) -> list[dict]:
    top_diseases = list(disease_probs.keys())[:5]
    diagnosis = []
    total_cases = max(len(similar_cases), 1)

    for disease in top_diseases:
        llm_prob = disease_probs[disease]
        matching_cases = sum(1 for case in similar_cases if _labels_match(disease, case.labels))
        gallery_support = matching_cases / total_cases
        confidence = (llm_prob / 100 * 0.5) + (gallery_support * 0.5)
        if gallery_support >= 0.6 and llm_prob >= 20:
            status = "HIGH"
        elif gallery_support >= 0.3 or llm_prob >= 15:
            status = "MEDIUM"
        else:
            status = "LOW"
        diagnosis.append({
            "disease": disease,
            "llm_probability": llm_prob,
            "matching_cases": matching_cases,
            "total_cases": total_cases,
            "gallery_support": f"{matching_cases}/{total_cases} cases",
            "confidence": round(confidence * 100, 1),
            "status": status,
        })
    return sorted(diagnosis, key=lambda item: item["confidence"], reverse=True)


def _positive_labels(label_str: str) -> list[str]:
    positives = []
    for part in label_str.split(" | "):
        if ": Positive" in part:
            positives.append(part.split(":")[0])
    return positives


def _generate_assessment(diagnosis: list[dict], similar_cases) -> str:
    primary = diagnosis[0]
    top_positive_labels = Counter()
    for case in similar_cases:
        top_positive_labels.update(_positive_labels(case.labels))

    supporting_findings = ", ".join(label for label, _ in top_positive_labels.most_common(3)) or "no repeated positive findings"
    differential = ", ".join(item["disease"] for item in diagnosis[1:4])

    return f"""
## Primary Clinical Impression

Based on visual similarity retrieval and zero-shot disease classification, the leading impression is **{primary["disease"]}** with a combined confidence of **{primary["confidence"]}%**.

## Evidence Summary

- The classifier estimated **{primary["llm_probability"]}%** probability for {primary["disease"]}.
- The retrieval engine found **{primary["gallery_support"]}** similar cases supporting this diagnosis.
- The most repeated positive findings among retrieved cases were: **{supporting_findings}**.

## Differential Diagnosis

Alternative conditions to consider are **{differential}**. These remain relevant because visually similar cases include overlapping thoracic findings common across chest X-ray pathology.

## Clinical Note

This is a retrieval-supported decision aid, not a definitive medical diagnosis. Final interpretation should be confirmed by a radiologist or clinician.
""".strip()


def _run_analysis(image: Image.Image, top_k: int):
    engine = _load_engine()
    similar_cases = engine.search(image, top_k=top_k, load_images=False)
    disease_probs = _predict_diseases(image)
    diagnosis = _crosscheck(similar_cases, disease_probs)
    assessment = _generate_assessment(diagnosis, similar_cases)
    return similar_cases, disease_probs, diagnosis, assessment


def _render_similar_cases(similar_cases):
    st.markdown("### Similar Historical Cases")
    for idx, case in enumerate(similar_cases, start=1):
        cols = st.columns([1, 3])
        with cols[0]:
            if case.filepath and Path(case.filepath).exists():
                try:
                    st.image(Image.open(case.filepath).convert("RGB"), use_container_width=True)
                except Exception:
                    st.caption("Preview unavailable")
        with cols[1]:
            st.markdown(f"**#{idx} {case.filename}**")
            st.write(f"Similarity: {case.similarity:.3f}")
            positives = _positive_labels(case.labels)
            st.write(f"Confirmed findings: {', '.join(positives) if positives else 'None'}")


def main():
    st.set_page_config(page_title=APP_TITLE, layout="wide")
    st.title(APP_TITLE)
    st.caption(
        "Upload a chest X-ray. The system retrieves similar historical cases and generates a retrieval-supported differential diagnosis."
    )

    with st.sidebar:
        st.markdown("**Index Status**")
        try:
            index_dir = _ensure_index_available()
            st.write(f"Index dir: `{index_dir}`")
            data_dir = index_dir.parent
        except FileNotFoundError as exc:
            st.error(str(exc))
            return

        top_k = st.slider("Retrieved Cases", min_value=3, max_value=20, value=5, step=1)
        if st.button("Use Sample Image"):
            st.session_state["sample_path"] = str(_pick_sample_image(data_dir) or "")
        if st.button("Clear"):
            st.session_state.pop("sample_path", None)
            st.session_state.pop("analysis_ready", None)
            st.rerun()
        st.caption("First analysis can still be slow on Render free tier.")

    uploaded = st.file_uploader("Upload Patient Chest X-Ray", type=["png", "jpg", "jpeg"])
    sample_path = st.session_state.get("sample_path")

    query_image = None
    if uploaded is not None:
        query_image = Image.open(uploaded).convert("RGB")
        st.session_state["analysis_ready"] = True
    elif sample_path:
        query_image = Image.open(sample_path).convert("RGB")
        st.session_state["analysis_ready"] = True

    left, right = st.columns([1.05, 1.25])

    with left:
        st.markdown("### Input X-Ray")
        if query_image is not None:
            st.image(query_image, use_container_width=True)
        else:
            st.info("Upload an image or use the sample button.")

    with right:
        st.markdown("### Generated Clinical Assessment")
        if query_image is None:
            st.info("Run an analysis to generate the assessment.")
            return

        if st.button("Submit", type="primary") or st.session_state.get("analysis_ready"):
            with st.spinner("Running retrieval, classification, and crosscheck..."):
                is_valid_xray, input_scores = _validate_input_image(query_image)
                if not is_valid_xray:
                    st.error("This tool only supports chest X-ray images. Please upload a chest radiograph.")
                    st.markdown("### Input Validation")
                    for label, score in list(input_scores.items())[:3]:
                        st.write(f"{label}: {score}%")
                    st.session_state["analysis_ready"] = False
                    return
                similar_cases, disease_probs, diagnosis, assessment = _run_analysis(query_image, top_k)

            st.markdown(assessment)
            st.markdown("### Ranked Diagnoses")
            for item in diagnosis:
                st.write(
                    f"**{item['disease']}** | classifier {item['llm_probability']}% | "
                    f"gallery {item['gallery_support']} | confidence {item['confidence']}% [{item['status']}]"
                )
            st.markdown("### Top Disease Probabilities")
            for disease, prob in list(disease_probs.items())[:5]:
                st.write(f"{disease}: {prob}%")
            _render_similar_cases(similar_cases)
            st.session_state["analysis_ready"] = False


if __name__ == "__main__":
    main()
```
data_downloader.py ADDED
```python
"""
data_downloader.py
──────────────────
Downloads the NIH ChestX-ray14 dataset sample (5,606 images, ~1.2 GB).
This is the public domain dataset used to build the visual_db.index.

The NIH dataset contains 14 disease labels per image in the CSV metadata:
    Atelectasis, Cardiomegaly, Effusion, Infiltration, Mass, Nodule,
    Pneumonia, Pneumothorax, Consolidation, Edema, Emphysema, Fibrosis,
    Pleural_Thickening, Hernia (plus "No Finding")

Usage:
    python data_downloader.py --output_dir ./data
"""

import os
import sys
import time
import zipfile
import argparse
import requests
import pandas as pd
from pathlib import Path
from tqdm import tqdm

# ── NIH ChestX-ray14 public download URLs ─────────────────────────────────────
# Source: https://nihcc.app.box.com/v/ChestXray-NIHCC
# The NIH provides 12 batch ZIPs + 1 metadata CSV.
# We use only the FIRST batch (images_001.tar.gz → ~1.1 GB, 4,999 images)
# for a fast bootstrap. Add more batches for a larger gallery.

NIH_METADATA_URL = (
    "https://raw.githubusercontent.com/ieee8023/covid-chestxray-dataset/"
    "master/metadata.csv"  # placeholder – real URL below
)

# Real NIH metadata (hosted on Kaggle mirror for convenience)
NIH_KAGGLE_METADATA = "https://raw.githubusercontent.com/mlmed/torchxrayvision/master/torchxrayvision/data_dicts/nih_chest_xray_dict.json"

# ── Open-I (Indiana University) – ALWAYS freely available, no login ───────────
# 7,470 frontal X-rays ~900 MB
OPENI_BASE = "https://openi.nlm.nih.gov/imgs/collections/"
OPENI_ARCHIVE = "NLMCXR_png.tgz"  # full archive
OPENI_METADATA_URL = "https://openi.nlm.nih.gov/api/search?q=&it=x&m=1&n=500"

# ── Lightweight fallback: Kaggle chest-xray-pneumonia (1.15 GB) ───────────────
# https://www.kaggle.com/datasets/paultimothymooney/chest-xray-pneumonia
# Requires kaggle CLI auth token.

SUPPORTED_SOURCES = ["openi", "nih_sample", "local"]


def download_with_progress(url: str, dest_path: Path, chunk_size: int = 8192) -> bool:
    """Stream-download a file with a tqdm progress bar."""
    try:
        resp = requests.get(url, stream=True, timeout=60)
        resp.raise_for_status()
        total = int(resp.headers.get("content-length", 0))
        dest_path.parent.mkdir(parents=True, exist_ok=True)
        with open(dest_path, "wb") as f, tqdm(
            total=total, unit="B", unit_scale=True,
            desc=dest_path.name, ncols=80
        ) as bar:
            for chunk in resp.iter_content(chunk_size=chunk_size):
                f.write(chunk)
                bar.update(len(chunk))
        return True
    except Exception as e:
        print(f"[ERROR] Download failed: {e}")
        return False


def download_openi(output_dir: Path) -> Path:
    """
    Download Open-I Indiana University chest X-ray PNG collection.
    Returns the directory containing .png images.
    """
    import tarfile

    output_dir.mkdir(parents=True, exist_ok=True)
    archive_path = output_dir / OPENI_ARCHIVE
    images_dir = output_dir / "openi_images"

    if images_dir.exists() and any(images_dir.glob("*.png")):
        print(f"[SKIP] Open-I images already present at {images_dir}")
        return images_dir

    print("=" * 60)
    print("Downloading Open-I Indiana X-ray dataset (~900 MB)...")
    print("Source: National Library of Medicine (public domain)")
    print("=" * 60)

    url = OPENI_BASE + OPENI_ARCHIVE
    if not download_with_progress(url, archive_path):
        raise RuntimeError("Failed to download Open-I archive.")

    print(f"Extracting to {images_dir}...")
    images_dir.mkdir(exist_ok=True)
    with tarfile.open(archive_path, "r:gz") as tar:
        tar.extractall(path=images_dir)

    archive_path.unlink()  # free disk space
    print(f"[OK] Open-I images extracted → {images_dir}")
    return images_dir


def download_nih_sample(output_dir: Path, max_images: int = 5000) -> Path:
    """
    Download NIH ChestX-ray14 batch_01 (~4,999 images, ~1.1 GB).
    Uses direct Box.com links published by NIH.
    """
    import tarfile

    NIH_BATCH1_URL = (
        "https://nihcc.box.com/shared/static/"
        "vfk49d74nhbxq3nqjg0900w5nvkorp5c.gz"
    )

    output_dir.mkdir(parents=True, exist_ok=True)
    archive_path = output_dir / "nih_images_001.tar.gz"
    images_dir = output_dir / "nih_images"

    if images_dir.exists() and any(images_dir.glob("*.png")):
        print(f"[SKIP] NIH images already present at {images_dir}")
        return images_dir

    print("=" * 60)
    print("Downloading NIH ChestX-ray14 Batch 1 (~1.1 GB)...")
    print("Source: NIH Clinical Center (CC0 license)")
    print("=" * 60)

    if not download_with_progress(NIH_BATCH1_URL, archive_path):
        raise RuntimeError(
            "Failed to download NIH batch. "
            "Try manual download from: https://nihcc.app.box.com/v/ChestXray-NIHCC"
        )

    print(f"Extracting to {images_dir}...")
    images_dir.mkdir(exist_ok=True)
    with tarfile.open(archive_path, "r:gz") as tar:
        members = tar.getmembers()[:max_images]
        tar.extractall(path=images_dir, members=members)

    archive_path.unlink()
    print(f"[OK] NIH images extracted → {images_dir}")
    return images_dir


def download_nih_metadata(output_dir: Path) -> Path:
    """Download the NIH ChestX-ray14 labels CSV."""
    META_URL = (
        "https://raw.githubusercontent.com/mlmed/torchxrayvision/"
        "master/tests/test_data/nih_data_entry_small.csv"
    )
    # Full metadata (108,948 rows):
    FULL_META_URL = (
        "https://raw.githubusercontent.com/ieee8023/chexnet-dataset/"
        "master/Data_Entry_2017.csv"
    )
    dest = output_dir / "nih_metadata.csv"
    if dest.exists():
        return dest
    print("Downloading NIH metadata CSV...")
    download_with_progress(FULL_META_URL, dest)
    return dest


def scan_local_images(image_dir: Path) -> list[Path]:
    """Return all PNG/JPG images in a directory (recursive)."""
    extensions = {".png", ".jpg", ".jpeg"}
    images = [
        p for p in image_dir.rglob("*")
        if p.suffix.lower() in extensions
    ]
    print(f"[SCAN] Found {len(images):,} images in {image_dir}")
    return images


def build_metadata_csv(
    image_dir: Path,
    nih_csv_path: Path | None,
    output_path: Path
) -> pd.DataFrame:
    """
    Build a unified metadata CSV:
        filename | filepath | labels | source
    Works whether or not the NIH labels CSV is available.
    """
    images = scan_local_images(image_dir)

    rows = []
    label_lookup = {}

    if nih_csv_path and nih_csv_path.exists():
        df_nih = pd.read_csv(nih_csv_path)
        # NIH CSV cols: Image Index, Finding Labels, Patient ID, ...
        for _, row in df_nih.iterrows():
            label_lookup[row["Image Index"]] = row["Finding Labels"]

    for img_path in images:
        fname = img_path.name
        labels = label_lookup.get(fname, "Unknown")
        rows.append({
            "filename": fname,
            "filepath": str(img_path.resolve()),
            "labels": labels,
            "source": "NIH" if label_lookup else "Unknown",
        })

    df = pd.DataFrame(rows)
    df.to_csv(output_path, index=False)
    print(f"[OK] Metadata saved → {output_path} ({len(df):,} rows)")
    return df


def main():
    parser = argparse.ArgumentParser(
        description="Download chest X-ray dataset for gallery builder"
    )
    parser.add_argument(
        "--source", choices=SUPPORTED_SOURCES, default="openi",
        help="Dataset source (default: openi – no login required)"
    )
    parser.add_argument(
        "--output_dir", type=Path, default=Path("./data"),
        help="Directory to save images and metadata"
    )
    parser.add_argument(
        "--local_dir", type=Path, default=None,
        help="Path to existing local image folder (use with --source local)"
    )
    args = parser.parse_args()

    output_dir: Path = args.output_dir.resolve()
    output_dir.mkdir(parents=True, exist_ok=True)

    if args.source == "openi":
        images_dir = download_openi(output_dir)
    elif args.source == "nih_sample":
        images_dir = download_nih_sample(output_dir)
        nih_meta = download_nih_metadata(output_dir)
        build_metadata_csv(images_dir, nih_meta, output_dir / "metadata.csv")
        return
    elif args.source == "local":
        if not args.local_dir:
            print("[ERROR] --local_dir is required when --source=local")
            sys.exit(1)
        images_dir = args.local_dir.resolve()
    else:
        print(f"[ERROR] Unknown source: {args.source}")
        sys.exit(1)

    build_metadata_csv(images_dir, None, output_dir / "metadata.csv")
    print("\n✅ Dataset ready. Next step:")
    print(f"  python gallery_builder.py --image_dir {images_dir} --output_dir ./index")


if __name__ == "__main__":
    main()
```
download_assets.py ADDED
@@ -0,0 +1,121 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
+ """
+ download_assets.py
+ ------------------
+ Downloads index/ and image assets from Google Drive into DATA_DIR
+ (default /var/data, falling back to /tmp/medrag_data if that is not writable).
+
+ Env vars:
+     GDRIVE_INDEX_URL  - share link or direct download url for a zip/tar of index/
+     GDRIVE_IMAGES_URL - share link or direct download url for a zip/tar of images/
+     DATA_DIR          - base path (default: /var/data)
+ """
+
+ import os
+ import shutil
+ import tarfile
+ import zipfile
+ from pathlib import Path
+
+ import gdown
+ from huggingface_hub import snapshot_download
+
+
+ BIOMEDCLIP_REPO = "microsoft/BiomedCLIP-PubMedBERT_256-vit_base_patch16_224"
+ DEFAULT_INDEX_URL = "https://drive.google.com/uc?id=1NwEac0s_qah8L27RO-aFIz2PRfvXC9j0"
+ DEFAULT_IMAGES_URL = "https://drive.google.com/uc?id=1LAMNffnw3kFHZXvY9ySR62VxlaChRXyv"
+
+
+ def _download(url: str, dest: Path) -> Path:
+     dest.parent.mkdir(parents=True, exist_ok=True)
+     if dest.exists():
+         return dest
+     gdown.download(url, str(dest), quiet=False)
+     return dest
+
+
+ def _extract(archive: Path, target_dir: Path) -> None:
+     target_dir.mkdir(parents=True, exist_ok=True)
+     if zipfile.is_zipfile(archive):
+         with zipfile.ZipFile(archive, "r") as zf:
+             zf.extractall(target_dir)
+     elif tarfile.is_tarfile(archive) or archive.name.endswith((".tgz", ".tar.gz", ".gz")):
+         with tarfile.open(archive, "r:*") as tf:
+             tf.extractall(target_dir)
+     else:
+         raise ValueError(f"Unsupported archive: {archive}")
+
+
+ def _ensure_dir(path: Path) -> None:
+     path.mkdir(parents=True, exist_ok=True)
+
+
+ def _pick_data_dir() -> Path:
+     env_dir = os.getenv("DATA_DIR")
+     if env_dir:
+         return Path(env_dir).resolve()
+     for candidate in (Path("/var/data"), Path("/tmp/medrag_data")):
+         try:
+             candidate.mkdir(parents=True, exist_ok=True)
+             return candidate
+         except Exception:
+             continue
+     return Path("/tmp/medrag_data").resolve()
+
+
+ def _prefetch_biomedclip() -> None:
+     cache_dir = Path(os.getenv("HF_HOME", "/tmp/hf_cache")).resolve()
+     cache_dir.mkdir(parents=True, exist_ok=True)
+     snapshot_download(
+         repo_id=BIOMEDCLIP_REPO,
+         cache_dir=str(cache_dir),
+         local_dir_use_symlinks=False,
+     )
+     print(f"BiomedCLIP cached in {cache_dir}")
+
+
+ def main():
+     data_dir = _pick_data_dir()
+     index_dir = data_dir / "index"
+     images_dir = data_dir / "images"
+
+     index_url = os.getenv("GDRIVE_INDEX_URL", DEFAULT_INDEX_URL)
+     images_url = os.getenv("GDRIVE_IMAGES_URL", DEFAULT_IMAGES_URL)
+
+     _ensure_dir(data_dir)
+
+     if index_dir.exists() and any(index_dir.iterdir()):
+         print(f"Index already present at {index_dir}")
+     elif index_url:
+         archive = data_dir / "index_archive.zip"
+         archive = _download(index_url, archive)
+         _extract(archive, index_dir)
+         print(f"Index extracted to {index_dir}")
+     else:
+         print("GDRIVE_INDEX_URL not set; index not downloaded.")
+
+     if images_dir.exists() and any(images_dir.iterdir()):
+         print(f"Images already present at {images_dir}")
+     elif images_url:
+         archive = data_dir / "images_archive.zip"
+         archive = _download(images_url, archive)
+         _extract(archive, images_dir)
+         print(f"Images extracted to {images_dir}")
+     else:
+         print("GDRIVE_IMAGES_URL not set; images not downloaded.")
+
+     # Clean up downloaded archives once extracted
+     for f in [data_dir / "index_archive.zip", data_dir / "images_archive.zip"]:
+         if f.exists():
+             try:
+                 if f.is_file():
+                     f.unlink()
+                 else:
+                     shutil.rmtree(f)
+             except Exception:
+                 pass
+
+     if os.getenv("PREFETCH_MODEL", "1") == "1":
+         _prefetch_biomedclip()
+
+
+ if __name__ == "__main__":
+     main()
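Editor's note: the docstring says the env vars may hold either a share link or a direct download URL, but `gdown.download` without `fuzzy=True` expects the `uc?id=` form. A minimal sketch of normalizing a share link before calling gdown (the function name `normalize_gdrive_url` is a hypothetical helper, not part of this repo; passing `fuzzy=True` to gdown achieves the same internally):

```python
import re

def normalize_gdrive_url(url: str) -> str:
    # Turn https://drive.google.com/file/d/<ID>/view?usp=sharing
    # into the uc?id=<ID> form that gdown downloads directly.
    m = re.search(r"/file/d/([\w-]+)", url)
    if m:
        return f"https://drive.google.com/uc?id={m.group(1)}"
    return url  # already a direct uc?id= link (or a non-Drive URL)
```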
gallery_builder.py ADDED
@@ -0,0 +1,406 @@
+ """
+ gallery_builder.py
+ ──────────────────
+ Builds the visual search database for Medical X-ray RAG.
+
+ Pipeline:
+     1. Load all X-ray images from --image_dir
+     2. Encode each image β†’ 512-dim vector via BiomedCLIP
+     3. Normalize + store in FAISS IndexFlatIP (cosine similarity via dot product)
+     4. Save: visual_db.index (FAISS binary)
+              metadata.json   (id β†’ {filename, filepath, labels, idx})
+              embeddings.npy  (raw numpy array, optional backup)
+
+ BiomedCLIP:
+     microsoft/BiomedCLIP-PubMedBERT_256-vit_base_patch16_224
+     Trained on 15M biomedical image-caption pairs from PubMed Central.
+     Zero-shot performance on CheXpert = 0.85+ AUC (no fine-tuning needed).
+
+ Usage:
+     python gallery_builder.py \
+         --image_dir ./data/openi_images \
+         --output_dir ./index \
+         --batch_size 64 \
+         --device cpu
+
+     # Resume interrupted build:
+     python gallery_builder.py --image_dir ./data/openi_images --resume
+
+ Output files:
+     ./index/visual_db.index   ← FAISS binary index
+     ./index/metadata.json     ← id β†’ {filename, filepath, labels}
+     ./index/embeddings.npy    ← (N, 512) float32 array
+     ./index/build_stats.json  ← timing + counts
+ """
+
+ import os
+ import json
+ import time
+ import argparse
+ import logging
+ import numpy as np
+ from pathlib import Path
+ from typing import Optional
+
+ import torch
+ from torch.utils.data import Dataset, DataLoader
+ from PIL import Image
+ import faiss
+ import open_clip
+ from tqdm import tqdm
+
+ logging.basicConfig(
+     level=logging.INFO,
+     format="%(asctime)s %(levelname)-7s %(message)s",
+     datefmt="%H:%M:%S",
+ )
+ log = logging.getLogger(__name__)
+
+ # ── Constants ──────────────────────────────────────────────────────────────────
+ BIOMEDCLIP_MODEL = "hf-hub:microsoft/BiomedCLIP-PubMedBERT_256-vit_base_patch16_224"
+ EMBED_DIM = 512
+ SUPPORTED_EXTS = {".png", ".jpg", ".jpeg", ".dcm"}  # PIL cannot decode DICOM natively; .dcm files are skipped as unreadable
+ INDEX_FILE = "visual_db.index"
+ METADATA_FILE = "metadata.json"
+ EMBEDDINGS_FILE = "embeddings.npy"
+ STATS_FILE = "build_stats.json"
+
+
+ # ── Dataset ────────────────────────────────────────────────────────────────────
+ class XRayDataset(Dataset):
+     """
+     Lazy-loading dataset for chest X-ray images.
+     Applies BiomedCLIP preprocessing (resize 224, normalize).
+     Skips corrupt/unreadable files gracefully.
+     """
+
+     def __init__(
+         self,
+         image_paths: list[Path],
+         transform,
+         metadata_csv_path: Optional[Path] = None,
+     ):
+         self.paths = image_paths
+         self.transform = transform
+         self.label_map: dict[str, str] = {}
+
+         # Optional: load NIH/CheXpert labels CSV
+         if metadata_csv_path and metadata_csv_path.exists():
+             import pandas as pd
+             df = pd.read_csv(metadata_csv_path)
+             if "filename" in df.columns and "labels" in df.columns:
+                 self.label_map = dict(zip(df["filename"], df["labels"].fillna("Unknown")))
+
+     def __len__(self):
+         return len(self.paths)
+
+     def __getitem__(self, idx: int):
+         path = self.paths[idx]
+         try:
+             img = Image.open(path).convert("RGB")
+             tensor = self.transform(img)
+             label = self.label_map.get(path.name, "Unknown")
+             return tensor, str(path), label, True  # (tensor, path, label, valid)
+         except Exception as e:  # UnidentifiedImageError, OSError, transform failures
+             log.warning(f"Skipping corrupt image: {path.name} ({e})")
+             # Return a zero tensor so DataLoader batch stays uniform
+             dummy = torch.zeros(3, 224, 224)
+             return dummy, str(path), "CORRUPT", False
+
+
+ def collate_skip_corrupt(batch):
+     """Custom collate: filter out corrupt images before batching."""
+     valid = [(t, p, l) for t, p, l, ok in batch if ok]
+     if not valid:
+         return None
+     tensors, paths, labels = zip(*valid)
+     return torch.stack(tensors), list(paths), list(labels)
+
+
+ # ── Model loader ───────────────────────────────────────────────────────────────
+ def load_biomedclip(device: str):
+     """
+     Load BiomedCLIP vision encoder from HuggingFace hub.
+     Returns (model, transform) where model outputs 512-dim image embeddings.
+     """
+     log.info("Loading BiomedCLIP from HuggingFace hub (first run downloads ~350 MB)...")
+     try:
+         model, _, transform = open_clip.create_model_and_transforms(
+             BIOMEDCLIP_MODEL
+         )
+         model = model.to(device).eval()
+         log.info(f"BiomedCLIP loaded βœ“ device={device}")
+         return model, transform
+     except Exception as e:
+         log.error(f"Failed to load BiomedCLIP: {e}")
+         log.error("Ensure open-clip-torch is installed: pip install open-clip-torch")
+         raise
+
+
+ # ── Embedding engine ───────────────────────────────────────────────────────────
+ @torch.no_grad()
+ def encode_batch(model, image_tensors: torch.Tensor, device: str) -> np.ndarray:
+     """Encode a batch of image tensors β†’ L2-normalized embeddings (N, 512)."""
+     image_tensors = image_tensors.to(device)
+     features = model.encode_image(image_tensors)
+     # L2 normalize β†’ cosine similarity = dot product
+     features = features / features.norm(dim=-1, keepdim=True)
+     return features.cpu().numpy().astype(np.float32)
+
+
+ # ── FAISS index builder ────────────────────────────────────────────────────────
+ def build_faiss_index(embeddings: np.ndarray) -> faiss.Index:
+     """
+     Build FAISS IndexFlatIP (inner product = cosine similarity after L2-norm).
+     For galleries > 10K images, switches to IndexIVFFlat for faster search.
+     """
+     n, d = embeddings.shape
+     log.info(f"Building FAISS index ({n:,} vectors Γ— {d} dims)")
+
+     if n < 10_000:
+         # Exact search β€” best for < 10K images
+         index = faiss.IndexFlatIP(d)
+     else:
+         # Approximate search β€” needed for large galleries
+         nlist = min(256, n // 39)  # FAISS heuristic: ~39 training points per cell
+         quantizer = faiss.IndexFlatIP(d)
+         index = faiss.IndexIVFFlat(quantizer, d, nlist, faiss.METRIC_INNER_PRODUCT)
+         log.info(f"Training IVF index with nlist={nlist}...")
+         index.train(embeddings)
+         index.nprobe = 16  # search 16 cells at query time (accuracy vs speed)
+
+     index.add(embeddings)
+     log.info(f"FAISS index built βœ“ total vectors: {index.ntotal:,}")
+     return index
+
+
+ # ── Resume support ─────────────────────────────────────────────────────────────
+ def load_checkpoint(output_dir: Path) -> tuple[np.ndarray | None, dict | None, int]:
+     """Load partial embeddings + metadata if build was interrupted."""
+     emb_ckpt = output_dir / "embeddings_checkpoint.npy"
+     meta_ckpt = output_dir / "metadata_checkpoint.json"
+
+     if emb_ckpt.exists() and meta_ckpt.exists():
+         embeddings = np.load(emb_ckpt)
+         with open(meta_ckpt) as f:
+             metadata = json.load(f)
+         start_idx = len(metadata)
+         log.info(f"[RESUME] Found checkpoint with {start_idx:,} images. Continuing...")
+         return embeddings, metadata, start_idx
+
+     return None, None, 0
+
+
+ def save_checkpoint(output_dir: Path, embeddings: np.ndarray, metadata: dict):
+     """Save incremental checkpoint."""
+     np.save(output_dir / "embeddings_checkpoint.npy", embeddings)
+     with open(output_dir / "metadata_checkpoint.json", "w") as f:
+         json.dump(metadata, f)
+
+
+ # ── Main pipeline ──────────────────────────────────────────────────────────────
+ def build_gallery(
+     image_dir: Path,
+     output_dir: Path,
+     batch_size: int = 64,
+     device: str = "auto",
+     metadata_csv: Optional[Path] = None,
+     resume: bool = False,
+     checkpoint_every: int = 500,
+ ):
+     """
+     Full pipeline: images β†’ BiomedCLIP embeddings β†’ FAISS index.
+
+     Args:
+         image_dir: Directory containing X-ray images (scanned recursively)
+         output_dir: Where to save visual_db.index + metadata.json
+         batch_size: Images per GPU/CPU batch (lower if OOM)
+         device: "cuda", "mps", "cpu", or "auto"
+         metadata_csv: Optional CSV with columns: filename, labels
+         resume: Resume from last checkpoint if available
+         checkpoint_every: Save checkpoint every N images
+     """
+     t_start = time.time()
+     output_dir.mkdir(parents=True, exist_ok=True)
+
+     # ── Resolve device ─────────────────────────────────────────────────────────
+     if device == "auto":
+         device = "cuda" if torch.cuda.is_available() else (
+             "mps" if torch.backends.mps.is_available() else "cpu"
+         )
+     log.info(f"Device: {device}")
+
+     # ── Collect image paths ────────────────────────────────────────────────────
+     all_images = sorted([
+         p for p in image_dir.rglob("*")
+         if p.suffix.lower() in SUPPORTED_EXTS
+     ])
+     if not all_images:
+         raise FileNotFoundError(f"No images found in {image_dir}")
+     log.info(f"Found {len(all_images):,} images in {image_dir}")
+
+     # ── Resume checkpoint ──────────────────────────────────────────────────────
+     existing_emb, existing_meta, start_idx = (None, None, 0)
+     if resume:
+         existing_emb, existing_meta, start_idx = load_checkpoint(output_dir)
+
+     images_to_process = all_images[start_idx:]
+     log.info(f"Images to process: {len(images_to_process):,}")
+
+     # ── Load BiomedCLIP ────────────────────────────────────────────────────────
+     model, transform = load_biomedclip(device)
+
+     # ── Dataset + DataLoader ───────────────────────────────────────────────────
+     dataset = XRayDataset(images_to_process, transform, metadata_csv)
+     loader = DataLoader(
+         dataset,
+         batch_size=batch_size,
+         num_workers=min(4, os.cpu_count() or 1),
+         pin_memory=(device == "cuda"),
+         collate_fn=collate_skip_corrupt,
+         prefetch_factor=2 if device == "cuda" else None,
+     )
+
+     # ── Accumulate embeddings ──────────────────────────────────────────────────
+     all_embeddings: list[np.ndarray] = []
+     all_metadata: dict = existing_meta or {}  # id (str) β†’ {filename, filepath, labels, idx}
+     global_idx = start_idx
+
+     log.info("Encoding images with BiomedCLIP...")
+     for batch in tqdm(loader, desc="Encoding", unit="batch", ncols=80):
+         if batch is None:
+             continue
+         tensors, paths, labels = batch
+         batch_emb = encode_batch(model, tensors, device)
+
+         for i, (path, label) in enumerate(zip(paths, labels)):
+             all_embeddings.append(batch_emb[i])
+             all_metadata[str(global_idx)] = {
+                 "filename": Path(path).name,
+                 "filepath": path,
+                 "labels": label,
+                 "idx": global_idx,
+             }
+             global_idx += 1
+
+         # Periodic checkpoint (once per batch, when a checkpoint boundary is crossed)
+         if global_idx % checkpoint_every < batch_size:
+             combined_emb = np.vstack(
+                 [existing_emb] + all_embeddings
+                 if existing_emb is not None else all_embeddings
+             )
+             save_checkpoint(output_dir, combined_emb, all_metadata)
+             log.info(f"  Checkpoint saved at {global_idx:,} images")
+
+     if not all_embeddings:
+         raise RuntimeError("No valid images were encoded. Check image directory.")
+
+     # ── Stack all embeddings ───────────────────────────────────────────────────
+     new_embeddings = np.vstack(all_embeddings)
+     if existing_emb is not None:
+         final_embeddings = np.vstack([existing_emb, new_embeddings])
+     else:
+         final_embeddings = new_embeddings
+
+     log.info(f"Embeddings shape: {final_embeddings.shape}")
+
+     # ── Build + save FAISS index ───────────────────────────────────────────────
+     index = build_faiss_index(final_embeddings)
+     index_path = output_dir / INDEX_FILE
+     faiss.write_index(index, str(index_path))
+     log.info(f"FAISS index saved β†’ {index_path} ({index_path.stat().st_size / 1e6:.1f} MB)")
+
+     # ── Save metadata ──────────────────────────────────────────────────────────
+     meta_path = output_dir / METADATA_FILE
+     with open(meta_path, "w") as f:
+         json.dump(all_metadata, f, indent=2)
+     log.info(f"Metadata saved β†’ {meta_path}")
+
+     # ── Save raw embeddings (optional, useful for offline analysis) ────────────
+     emb_path = output_dir / EMBEDDINGS_FILE
+     np.save(emb_path, final_embeddings)
+     log.info(f"Embeddings saved β†’ {emb_path} ({emb_path.stat().st_size / 1e6:.1f} MB)")
+
+     # ── Clean up checkpoints ───────────────────────────────────────────────────
+     for ckpt in ["embeddings_checkpoint.npy", "metadata_checkpoint.json"]:
+         ckpt_path = output_dir / ckpt
+         if ckpt_path.exists():
+             ckpt_path.unlink()
+
+     # ── Build stats ────────────────────────────────────────────────────────────
+     elapsed = time.time() - t_start
+     stats = {
+         "total_images": index.ntotal,
+         "skipped": len(all_images) - global_idx,  # corrupt/unreadable files
+         "embed_dim": EMBED_DIM,
+         "model": BIOMEDCLIP_MODEL,
+         "index_type": type(index).__name__,
+         "build_time_sec": round(elapsed, 1),
+         "throughput_img_per_sec": round(index.ntotal / elapsed, 1),
+         "index_size_mb": round(index_path.stat().st_size / 1e6, 2),
+         "device": device,
+     }
+     with open(output_dir / STATS_FILE, "w") as f:
+         json.dump(stats, f, indent=2)
+
+     log.info("=" * 55)
+     log.info("βœ… Gallery build complete!")
+     log.info(f"   Images indexed : {index.ntotal:,}")
+     log.info(f"   Build time     : {elapsed:.0f}s ({stats['throughput_img_per_sec']} img/s)")
+     log.info(f"   Index size     : {stats['index_size_mb']} MB")
+     log.info(f"   Output dir     : {output_dir.resolve()}")
+     log.info("=" * 55)
+
+     return index, all_metadata
+
+
+ # ── CLI ────────────────────────────────────────────────────────────────────────
+ def main():
+     parser = argparse.ArgumentParser(
+         description="Build FAISS visual search index from chest X-ray images"
+     )
+     parser.add_argument(
+         "--image_dir", type=Path, required=True,
+         help="Root directory containing X-ray images (searched recursively)"
+     )
+     parser.add_argument(
+         "--output_dir", type=Path, default=Path("./index"),
+         help="Where to save visual_db.index + metadata.json (default: ./index)"
+     )
+     parser.add_argument(
+         "--batch_size", type=int, default=64,
+         help="Batch size for encoding. Reduce to 16 if CPU RAM < 8 GB (default: 64)"
+     )
+     parser.add_argument(
+         "--device", choices=["auto", "cuda", "cpu", "mps"], default="auto",
+         help="Compute device (default: auto-detect)"
+     )
+     parser.add_argument(
+         "--metadata_csv", type=Path, default=None,
+         help="Optional CSV with columns: filename, labels"
+     )
+     parser.add_argument(
+         "--resume", action="store_true",
+         help="Resume from last checkpoint if build was interrupted"
+     )
+     parser.add_argument(
+         "--checkpoint_every", type=int, default=500,
+         help="Save checkpoint every N images (default: 500)"
+     )
+     args = parser.parse_args()
+
+     build_gallery(
+         image_dir=args.image_dir.resolve(),
+         output_dir=args.output_dir.resolve(),
+         batch_size=args.batch_size,
+         device=args.device,
+         metadata_csv=args.metadata_csv,
+         resume=args.resume,
+         checkpoint_every=args.checkpoint_every,
+     )
+
+
+ if __name__ == "__main__":
+     main()
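Editor's note: the builder stores L2-normalized vectors in `IndexFlatIP`, relying on the identity that the inner product of unit vectors equals their cosine similarity. A minimal numpy-only check of that identity (a sketch, no FAISS required; `cosine_via_dot` is an illustrative helper, not part of this repo):

```python
import numpy as np

def cosine_via_dot(a: np.ndarray, b: np.ndarray) -> float:
    # After L2 normalization the inner product IS the cosine similarity,
    # which is why IndexFlatIP over normalized vectors does cosine search.
    a = a / np.linalg.norm(a)
    b = b / np.linalg.norm(b)
    return float(a @ b)

rng = np.random.default_rng(0)
x, y = rng.standard_normal(512), rng.standard_normal(512)
reference = float(x @ y / (np.linalg.norm(x) * np.linalg.norm(y)))
assert abs(cosine_via_dot(x, y) - reference) < 1e-9
```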
render.yaml ADDED
@@ -0,0 +1,20 @@
+ services:
+   - type: web
+     name: medrag-app
+     env: python
+     plan: free
+     buildCommand: pip install --index-url https://download.pytorch.org/whl/cpu torch torchvision && pip install -r requirements.txt
+     startCommand: python download_assets.py && streamlit run app.py --server.port $PORT --server.address 0.0.0.0
+     envVars:
+       - key: PYTHONUNBUFFERED
+         value: "1"
+       - key: DATA_DIR
+         value: /tmp/medrag_data
+       - key: HF_HOME
+         value: /tmp/hf_cache
+       - key: PREFETCH_MODEL
+         value: "1"
+       - key: GDRIVE_INDEX_URL
+         value: ""
+       - key: GDRIVE_IMAGES_URL
+         value: ""
requirements-space.txt ADDED
@@ -0,0 +1,11 @@
+ open-clip-torch>=2.24.0
+ faiss-cpu>=1.7.4
+ Pillow>=10.0.0
+ numpy>=1.24.0
+ tqdm>=4.66.0
+ requests>=2.31.0
+ pandas>=2.0.0
+ streamlit>=1.31.0
+ gdown>=5.1.0
+ huggingface-hub>=0.28.0
+ transformers>=4.30.0,<5
requirements.txt ADDED
@@ -0,0 +1,31 @@
+ # Gallery Builder – Python Dependencies
+ # Install with: pip install -r requirements.txt
+ #
+ # GPU support (recommended for faster encoding):
+ #   pip install torch torchvision --index-url https://download.pytorch.org/whl/cu121
+ #
+ # CPU only:
+ #   pip install torch torchvision --index-url https://download.pytorch.org/whl/cpu
+
+ # Core ML
+ torch>=2.1.0
+ torchvision>=0.16.0
+ open-clip-torch>=2.24.0    # BiomedCLIP lives here
+
+ # Vector database
+ faiss-cpu>=1.7.4           # swap for faiss-gpu if CUDA available
+
+ # Image processing
+ Pillow>=10.0.0
+ numpy>=1.24.0
+
+ # Utilities
+ tqdm>=4.66.0
+ requests>=2.31.0
+ pandas>=2.0.0
+ streamlit>=1.31.0
+ gdown>=5.1.0
+
+ # Testing
+ pytest>=7.4.0
+ pytest-cov>=4.1.0
rewrite_metadata.py ADDED
@@ -0,0 +1,43 @@
+ """
+ rewrite_metadata.py
+ -------------------
+ Utility to rewrite metadata.json filepaths for deployment.
+
+ Example:
+     python rewrite_metadata.py \
+         --index_dir ./index \
+         --from_prefix "/Users/you/MedRAG/data/train" \
+         --to_prefix "/var/data/images"
+ """
+
+ import argparse
+ import json
+ from pathlib import Path
+
+
+ def main():
+     parser = argparse.ArgumentParser(description="Rewrite metadata.json filepaths")
+     parser.add_argument("--index_dir", type=Path, default=Path("./index"))
+     parser.add_argument("--from_prefix", required=True)
+     parser.add_argument("--to_prefix", required=True)
+     args = parser.parse_args()
+
+     meta_path = args.index_dir / "metadata.json"
+     if not meta_path.exists():
+         raise FileNotFoundError(f"metadata.json not found: {meta_path}")
+
+     data = json.loads(meta_path.read_text())
+     updated = 0
+
+     for _, entry in data.items():
+         fp = entry.get("filepath", "")
+         if fp.startswith(args.from_prefix):
+             entry["filepath"] = fp.replace(args.from_prefix, args.to_prefix, 1)
+             updated += 1
+
+     meta_path.write_text(json.dumps(data, indent=2))
+     print(f"Rewrote {updated} filepaths in {meta_path}")
+
+
+ if __name__ == "__main__":
+     main()
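Editor's note: the script replaces only the first occurrence of the prefix, and only when the path actually starts with it, so a string that happens to recur deeper in the path is left alone. A minimal standalone check of that rule (`rewrite_path` is an illustrative helper mirroring the loop body above, not part of this repo):

```python
def rewrite_path(fp: str, from_prefix: str, to_prefix: str) -> str:
    # Mirrors the loop body: startswith guard + replace(..., 1)
    # so only the leading prefix is rewritten.
    if fp.startswith(from_prefix):
        return fp.replace(from_prefix, to_prefix, 1)
    return fp
```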
start.sh ADDED
@@ -0,0 +1,9 @@
+ #!/usr/bin/env bash
+ set -euo pipefail
+
+ export DATA_DIR="${DATA_DIR:-/tmp/medrag_data}"
+ export HF_HOME="${HF_HOME:-/tmp/hf_cache}"
+ export PREFETCH_MODEL="${PREFETCH_MODEL:-1}"
+
+ python download_assets.py
+ exec streamlit run app.py --server.port 7860 --server.address 0.0.0.0
test_visual_search.py ADDED
@@ -0,0 +1,308 @@
+ """
+ test_visual_search.py
+ ─────────────────────
+ Unit + integration tests for the gallery builder pipeline.
+
+ Run:
+     # Fast unit tests (no model needed):
+     pytest test_visual_search.py -v -m "not integration"
+
+     # Full integration test (requires built index):
+     pytest test_visual_search.py -v --index_dir ./index --image_dir ./data
+ """
+
+ import json
+ import numpy as np
+ import pytest
+ from pathlib import Path
+ from unittest.mock import patch, MagicMock
+ from PIL import Image
+
+
+ # ── Fixtures ───────────────────────────────────────────────────────────────────
+ @pytest.fixture
+ def dummy_index_dir(tmp_path):
+     """Create a minimal fake FAISS index + metadata for unit tests."""
+     import faiss
+
+     d = 512
+     n = 20
+     embeddings = np.random.randn(n, d).astype(np.float32)
+     # L2 normalize
+     norms = np.linalg.norm(embeddings, axis=1, keepdims=True)
+     embeddings /= norms
+
+     index = faiss.IndexFlatIP(d)
+     index.add(embeddings)
+     faiss.write_index(index, str(tmp_path / "visual_db.index"))
+
+     metadata = {
+         str(i): {
+             "filename": f"image_{i:04d}.png",
+             "filepath": str(tmp_path / f"image_{i:04d}.png"),
+             "labels": "Pneumonia" if i % 3 == 0 else "No Finding",
+             "idx": i,
+         }
+         for i in range(n)
+     }
+     with open(tmp_path / "metadata.json", "w") as f:
+         json.dump(metadata, f)
+
+     # Create dummy PNG files
+     for i in range(n):
+         img = Image.fromarray(np.random.randint(0, 255, (64, 64, 3), dtype=np.uint8))
+         img.save(tmp_path / f"image_{i:04d}.png")
+
+     return tmp_path, embeddings
+
+
+ @pytest.fixture
+ def dummy_xray_image(tmp_path) -> Path:
+     """Create a fake grayscale X-ray image."""
+     img_array = np.random.randint(0, 255, (224, 224), dtype=np.uint8)
+     img = Image.fromarray(img_array, mode="L").convert("RGB")
+     path = tmp_path / "test_xray.png"
+     img.save(path)
+     return path
+
+
+ # ── Unit tests ─────────────────────────────────────────────────────────────────
+ class TestSearchResult:
+     def test_to_dict(self):
+         from visual_search import SearchResult
+         r = SearchResult(rank=1, idx=5, filename="img.png",
+                          filepath="/data/img.png", labels="Pneumonia",
+                          similarity=0.87654)
+         d = r.to_dict()
+         assert d["rank"] == 1
+         assert d["similarity"] == 0.8765  # rounded to 4 decimal places
+         assert d["labels"] == "Pneumonia"
+         assert "image" not in d  # PIL image not serialized
+
+
+ class TestFAISSIndex:
+     """Test FAISS index properties independent of BiomedCLIP."""
+
+     def test_build_flat_index(self):
+         import faiss
+         d, n = 512, 100
+         emb = np.random.randn(n, d).astype(np.float32)
+         emb /= np.linalg.norm(emb, axis=1, keepdims=True)
+
+         index = faiss.IndexFlatIP(d)
+         index.add(emb)
+         assert index.ntotal == n
+
+     def test_search_returns_correct_k(self):
+         import faiss
+         d, n = 512, 50
+         emb = np.random.randn(n, d).astype(np.float32)
+         emb /= np.linalg.norm(emb, axis=1, keepdims=True)
+
+         index = faiss.IndexFlatIP(d)
+         index.add(emb)
+
+         query = emb[0:1]  # use first vector as query
+         sims, idxs = index.search(query, k=5)
+         assert sims.shape == (1, 5)
+         assert idxs.shape == (1, 5)
+         # Self-match should be first with similarity β‰ˆ 1.0
+         assert abs(sims[0][0] - 1.0) < 1e-5
+         assert idxs[0][0] == 0
+
+     def test_cosine_similarity_via_dot_product(self):
+         """L2-normalized dot product = cosine similarity."""
+         import faiss
+         d = 512
+         # Two identical vectors should have similarity 1.0
+         v = np.random.randn(1, d).astype(np.float32)
+         v /= np.linalg.norm(v)
+
+         index = faiss.IndexFlatIP(d)
+         index.add(v)
+
+         sims, _ = index.search(v, k=1)
+         assert abs(sims[0][0] - 1.0) < 1e-5
+
+     def test_ivf_index_for_large_gallery(self):
+         """IVF index works for large galleries (>10K vectors)."""
+         import faiss
+         d, n = 512, 10_000
+         emb = np.random.randn(n, d).astype(np.float32)
+         emb /= np.linalg.norm(emb, axis=1, keepdims=True)
+
+         nlist = 64
+         quantizer = faiss.IndexFlatIP(d)
+         index = faiss.IndexIVFFlat(quantizer, d, nlist, faiss.METRIC_INNER_PRODUCT)
+         index.train(emb)
+         index.add(emb)
+         index.nprobe = 8
+
+         assert index.ntotal == n
+         # Check that search still works
+         sims, idxs = index.search(emb[0:1], k=5)
+         assert idxs[0][0] == 0  # self should be top result
+
+
+ class TestMetadataBuilding:
+     def test_metadata_keys(self, dummy_index_dir):
+         index_dir, _ = dummy_index_dir
+         meta_path = index_dir / "metadata.json"
+         with open(meta_path) as f:
+             meta = json.load(f)
+         assert "0" in meta
+         entry = meta["0"]
+         assert "filename" in entry
+         assert "filepath" in entry
+         assert "labels" in entry
+         assert "idx" in entry
+
+     def test_metadata_count_matches_index(self, dummy_index_dir):
+         import faiss
+         index_dir = dummy_index_dir[0]
+         index = faiss.read_index(str(index_dir / "visual_db.index"))
+         with open(index_dir / "metadata.json") as f:
+             meta = json.load(f)
+         assert index.ntotal == len(meta)
+
+
+ class TestVisualSearchEngine:
+     """Tests using mocked BiomedCLIP to avoid model download."""
+
+     def _get_engine_with_mock_model(self, index_dir):
+         """Create engine with BiomedCLIP mocked out."""
+         from visual_search import VisualSearchEngine
+
+         with patch("visual_search.open_clip.create_model_and_transforms") as mock_create:
+             mock_model = MagicMock()
+             mock_transform = MagicMock(return_value=MagicMock(
+                 unsqueeze=lambda _: MagicMock(to=lambda _: MagicMock())
+             ))
+             mock_create.return_value = (mock_model, None, mock_transform)
+
+             engine = VisualSearchEngine(index_dir=index_dir, device="cpu")
+
+         # Mock the embed function to return a random normalized vector
+         def fake_embed(img):
+             v = np.random.randn(1, 512).astype(np.float32)
+             v /= np.linalg.norm(v, axis=1, keepdims=True)
+             return v
+
+         engine._embed_image = fake_embed
+         return engine
+
+     def test_search_returns_k_results(self, dummy_index_dir):
+         index_dir = dummy_index_dir[0]
+         engine = self._get_engine_with_mock_model(index_dir)
+
+         dummy_img = Image.fromarray(np.zeros((224, 224, 3), dtype=np.uint8))
+         results = engine.search(dummy_img, top_k=5)
+         assert len(results) == 5
+
+     def test_results_sorted_by_similarity(self, dummy_index_dir):
+         index_dir = dummy_index_dir[0]
+         engine = self._get_engine_with_mock_model(index_dir)
+
+         dummy_img = Image.fromarray(np.zeros((224, 224, 3), dtype=np.uint8))
+         results = engine.search(dummy_img, top_k=5)
+         sims = [r.similarity for r in results]
+         assert sims == sorted(sims, reverse=True)
+
+     def test_results_have_required_fields(self, dummy_index_dir):
+         index_dir = dummy_index_dir[0]
+         engine = self._get_engine_with_mock_model(index_dir)
+
+         dummy_img = Image.fromarray(np.zeros((224, 224, 3), dtype=np.uint8))
+         results = engine.search(dummy_img, top_k=3)
+         for r in results:
+             assert hasattr(r, "rank")
+             assert hasattr(r, "filename")
+             assert hasattr(r, "filepath")
+             assert hasattr(r, "labels")
+             assert hasattr(r, "similarity")
+             assert 0.0 <= r.similarity <= 1.0
+
+     def test_ranks_are_sequential(self, dummy_index_dir):
+         index_dir = dummy_index_dir[0]
+         engine = self._get_engine_with_mock_model(index_dir)
+
+         dummy_img = Image.fromarray(np.zeros((224, 224, 3), dtype=np.uint8))
+         results = engine.search(dummy_img, top_k=5)
+         for i, r in enumerate(results, start=1):
+             assert r.rank == i
+
+     def test_file_not_found_raises(self, dummy_index_dir):
+         index_dir = dummy_index_dir[0]
+         engine = self._get_engine_with_mock_model(index_dir)
+         with pytest.raises(FileNotFoundError):
+             engine.search("/nonexistent/image.png")
+
+     def test_batch_search(self, dummy_index_dir):
+         index_dir = dummy_index_dir[0]
+         engine = self._get_engine_with_mock_model(index_dir)
+
+         imgs = [
+             Image.fromarray(np.zeros((224, 224, 3), dtype=np.uint8))
+             for _ in range(3)
+         ]
+         batch_results = engine.search_batch(imgs, top_k=5)
+         assert len(batch_results) == 3
+         assert all(len(r) == 5 for r in batch_results)
+
+     def test_get_stats(self, dummy_index_dir):
+         index_dir = dummy_index_dir[0]
+         engine = self._get_engine_with_mock_model(index_dir)
+         stats = engine.get_stats()
+         assert "total_images" in stats
+         assert stats["total_images"] == 20
+         assert stats["embed_dim"] == 512
+
+     def test_to_dict_serializable(self, dummy_index_dir):
+         """Search results must be JSON serializable for API responses."""
+         index_dir = dummy_index_dir[0]
+         engine = self._get_engine_with_mock_model(index_dir)
+
+         dummy_img = Image.fromarray(np.zeros((224, 224, 3), dtype=np.uint8))
+         results = engine.search(dummy_img, top_k=3)
+         payload = [r.to_dict() for r in results]
+         assert json.dumps(payload)  # raises if not serializable
+
+
+ # ── Integration tests (require real index) ─────────────────────────────────────
+ @pytest.mark.integration
+ class TestIntegration:
+     """Run with: pytest -m integration --index_dir ./index --image_dir ./data"""
+
+     @pytest.fixture(autouse=True)
+     def setup(self, request):
+         self.index_dir = Path(request.config.getoption("--index_dir", default="./index"))
+         self.image_dir = Path(request.config.getoption("--image_dir", default="./data"))
+
+     def test_real_search(self):
+         from visual_search import VisualSearchEngine
+         engine = VisualSearchEngine(self.index_dir, device="cpu")
+         stats = engine.get_stats()
+         assert stats["total_images"] > 0
+         print(f"\nIndex contains {stats['total_images']:,} images")
+
+     def test_search_with_real_image(self):
+         from visual_search import VisualSearchEngine
+         engine = VisualSearchEngine(self.index_dir, device="cpu")
+
+         # Find first image in data dir
+         images = list(self.image_dir.rglob("*.png"))[:1]
+         if not images:
+             pytest.skip("No test images found")
+
+         results = engine.search(images[0], top_k=5, exclude_perfect_match=True)
+         assert len(results) > 0
+         assert results[0].similarity <= 1.0
+         print(f"\nTop result: {results[0].filename} sim={results[0].similarity:.3f}")
+
+
+ # ── Pytest config ──────────────────────────────────────────────────────────────
+ # NOTE: pytest only collects pytest_addoption from conftest.py (or a plugin);
+ # copy these lines into conftest.py for the --index_dir/--image_dir flags to work.
+ def pytest_addoption(parser):
+     parser.addoption("--index_dir", action="store", default="./index")
+     parser.addoption("--image_dir", action="store", default="./data")
visual_search.py ADDED
@@ -0,0 +1,358 @@
+ """
2
+ visual_search.py
3
+ ────────────────
4
+ Search function for the Medical X-ray RAG system.
5
+
6
+ Input: A chest X-ray image (file path or PIL Image or numpy array)
7
+ Output: Top-K most similar cases from the gallery database
8
+
9
+ This is the module imported by your web app and RAG pipeline.
10
+
11
+ Usage:
12
+ from visual_search import VisualSearchEngine
13
+
14
+ engine = VisualSearchEngine(
15
+ index_dir="./index",
16
+ device="auto"
17
+ )
18
+
19
+ results = engine.search("./query_xray.png", top_k=5)
20
+ # returns List[SearchResult]
21
+ for r in results:
22
+ print(f"{r.rank}. {r.filename} sim={r.similarity:.3f} labels={r.labels}")
23
+ """
24
+
25
+ import json
26
+ import time
27
+ import logging
28
+ import numpy as np
29
+ from pathlib import Path
30
+ from dataclasses import dataclass, field
31
+ from typing import Union, Optional
32
+
33
+ import faiss
34
+ import torch
35
+ import open_clip
36
+ from PIL import Image, UnidentifiedImageError
37
+
38
+ log = logging.getLogger(__name__)
39
+
40
+ # ── Constants (must match gallery_builder.py) ──────────────────────────────────
41
+ BIOMEDCLIP_MODEL = "hf-hub:microsoft/BiomedCLIP-PubMedBERT_256-vit_base_patch16_224"
42
+ INDEX_FILE = "visual_db.index"
43
+ METADATA_FILE = "metadata.json"
44
+
45
+
46
+ # ── Result dataclass ───────────────────────────────────────────────────────────
47
+ @dataclass
48
+ class SearchResult:
49
+ """One similar case returned by the search engine."""
50
+ rank: int # 1 = most similar
51
+ idx: int # Internal FAISS index ID
52
+ filename: str # Image filename
53
+ filepath: str # Absolute path to the image
54
+ labels: str # Diagnosis labels (from metadata)
55
+ similarity: float # Cosine similarity [0, 1]
56
+ image: Optional[object] = field(default=None, repr=False)
57
+ # ↑ Optionally loaded PIL Image (set load_images=True in search())
58
+
59
+ def to_dict(self) -> dict:
60
+ return {
61
+ "rank": self.rank,
62
+ "idx": self.idx,
63
+ "filename": self.filename,
64
+ "filepath": self.filepath,
65
+ "labels": self.labels,
66
+ "similarity": round(float(self.similarity), 4),
67
+ }
68
+
69
+
70
+ # ── Search Engine ──────────────────────────────────────────────────────────────
71
+ class VisualSearchEngine:
72
+ """
73
+ Thread-safe visual search engine for chest X-ray similarity retrieval.
74
+
75
+ Architecture:
76
+ Query image
77
+ β”‚
78
+ β–Ό
79
+ BiomedCLIP vision encoder β†’ 512-dim embedding (L2 normalized)
80
+ β”‚
81
+ β–Ό
82
+ FAISS IndexFlatIP β†’ cosine similarity search
83
+ β”‚
84
+ β–Ό
85
+ Top-K results + metadata
86
+
87
+ Attributes:
88
+ index_dir (Path): Directory containing visual_db.index + metadata.json
89
+ device (str): Compute device for BiomedCLIP
90
+ top_k (int): Default number of results to return
91
+ """
92
+
93
+ def __init__(
94
+ self,
95
+ index_dir: Union[str, Path],
96
+ device: str = "auto",
97
+ top_k: int = 5,
98
+ ):
99
+ self.index_dir = Path(index_dir).resolve()
100
+ self.top_k = top_k
101
+ self._model = None
102
+ self._transform = None
103
+ self._index = None
104
+ self._metadata: dict = {}
105
+
106
+ # Resolve device
107
+ if device == "auto":
108
+ if torch.cuda.is_available():
109
+ self.device = "cuda"
110
+ elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
111
+ self.device = "mps"
112
+ else:
113
+ self.device = "cpu"
114
+ else:
115
+ self.device = device
116
+
117
+ # Eager load
118
+ self._load_index()
119
+ self._load_model()
120
+ log.info(f"VisualSearchEngine ready (index={self._index.ntotal:,} images, device={self.device})")
121
+
122
+ # ── Private loaders ────────────────────────────────────────────────────────
123
+ def _load_index(self):
124
+ """Load FAISS index + metadata from disk."""
125
+ index_path = self.index_dir / INDEX_FILE
126
+ meta_path = self.index_dir / METADATA_FILE
127
+
128
+ if not index_path.exists():
129
+ raise FileNotFoundError(
130
+ f"FAISS index not found: {index_path}\n"
131
+ "Run: python gallery_builder.py --image_dir ./data --output_dir ./index"
132
+ )
133
+ if not meta_path.exists():
134
+ raise FileNotFoundError(f"Metadata file not found: {meta_path}")
135
+
136
+ log.info(f"Loading FAISS index from {index_path}...")
137
+ self._index = faiss.read_index(str(index_path))
138
+
139
+ # For IVF indexes, set nprobe for recall/speed tradeoff
140
+ if hasattr(self._index, "nprobe"):
141
+ self._index.nprobe = 16
142
+
143
+ log.info(f"Index loaded ({self._index.ntotal:,} vectors, dim={self._index.d})")
144
+
145
+ with open(meta_path) as f:
146
+ self._metadata = json.load(f)
147
+
148
+ def _load_model(self):
149
+ """Load BiomedCLIP vision encoder."""
150
+ log.info("Loading BiomedCLIP encoder...")
151
+ model, _, transform = open_clip.create_model_and_transforms(BIOMEDCLIP_MODEL)
152
+ self._model = model.to(self.device).eval()
153
+ self._transform = transform
154
+ log.info("BiomedCLIP loaded βœ“")
155
+
156
+ # ── Embedding ──────────────────────────────────────────────────────────────
157
+ @torch.no_grad()
158
+ def _embed_image(self, image: Image.Image) -> np.ndarray:
159
+ """
160
+ Encode a single PIL image β†’ L2-normalized 512-dim embedding.
161
+ Returns shape (1, 512) float32 numpy array.
162
+ """
163
+ tensor = self._transform(image).unsqueeze(0).to(self.device)
164
+ features = self._model.encode_image(tensor)
165
+ features = features / features.norm(dim=-1, keepdim=True)
166
+ return features.cpu().numpy().astype(np.float32)
167
+
168
+ # ── Public API ─────────────────────────────────────────────────────────────
169
+ def search(
170
+ self,
171
+ query: Union[str, Path, Image.Image, np.ndarray],
172
+ top_k: Optional[int] = None,
173
+ load_images: bool = False,
174
+ exclude_perfect_match: bool = False,
175
+ ) -> list[SearchResult]:
176
+ """
177
+ Find the top-K most similar X-ray images to a query.
178
+
179
+ Args:
180
+ query: File path, PIL Image, or RGB numpy array
181
+ top_k: Number of results (overrides default)
182
+ load_images: Load PIL Images into SearchResult.image
183
+ exclude_perfect_match: Skip results with similarity β‰₯ 0.9999
184
+ (use when query is in the gallery itself)
185
+
186
+ Returns:
187
+ List[SearchResult] ordered by descending similarity
188
+ """
189
+ t0 = time.perf_counter()
190
+ k = top_k or self.top_k
191
+
192
+ # ── Load query image ───────────────────────────────────────────────────
193
+ if isinstance(query, (str, Path)):
194
+ query_path = Path(query)
195
+ if not query_path.exists():
196
+ raise FileNotFoundError(f"Query image not found: {query_path}")
197
+ try:
198
+ img = Image.open(query_path).convert("RGB")
199
+ except (UnidentifiedImageError, OSError) as e:
200
+ raise ValueError(f"Cannot open image: {query_path} ({e})")
201
+
202
+ elif isinstance(query, np.ndarray):
203
+ img = Image.fromarray(query.astype(np.uint8))
204
+
205
+ elif isinstance(query, Image.Image):
206
+ img = query.convert("RGB")
207
+
208
+ else:
209
+ raise TypeError(f"Unsupported query type: {type(query)}")
210
+
211
+ # ── Encode ─────────────────────────────────────────────────────────────
212
+ query_emb = self._embed_image(img) # (1, 512)
213
+
214
+ # ── FAISS search ───────────────────────────────────────────────────────
215
+ search_k = k + 1 if exclude_perfect_match else k
216
+ similarities, indices = self._index.search(query_emb, search_k)
217
+ similarities = similarities[0] # (k,)
218
+ indices = indices[0] # (k,)
219
+
220
+ # ── Build results ──────────────────────────────────────────────────────
221
+ results: list[SearchResult] = []
222
+ rank = 1
223
+ for sim, idx in zip(similarities, indices):
224
+ if idx < 0: # FAISS returns -1 for empty slots
225
+ continue
226
+ if exclude_perfect_match and float(sim) >= 0.9999:
227
+ continue # skip exact self-match
228
+
229
+ meta = self._metadata.get(str(idx), {})
230
+ filepath = meta.get("filepath", "")
231
+
232
+ result = SearchResult(
233
+ rank=rank,
234
+ idx=int(idx),
235
+ filename=meta.get("filename", f"image_{idx}"),
236
+ filepath=filepath,
237
+ labels=meta.get("labels", "Unknown"),
238
+ similarity=float(sim),
239
+ )
240
+
241
+ if load_images and filepath and Path(filepath).exists():
242
+ try:
243
+ result.image = Image.open(filepath).convert("RGB")
244
+ except Exception:
245
+ pass # image loading is best-effort
246
+
247
+ results.append(result)
248
+ rank += 1
249
+ if len(results) >= k:
250
+ break
251
+
252
+ elapsed_ms = (time.perf_counter() - t0) * 1000
253
+ log.debug(f"Search completed in {elapsed_ms:.1f} ms β†’ {len(results)} results")
254
+ return results
255
+
256
+ def search_batch(
257
+ self,
258
+ queries: list[Union[str, Path, Image.Image]],
259
+ top_k: Optional[int] = None,
260
+ ) -> list[list[SearchResult]]:
261
+ """
262
+ Batch search for multiple query images.
263
+ More efficient than calling search() in a loop.
264
+ """
265
+ k = top_k or self.top_k
266
+ embeddings = []
267
+
268
+ for q in queries:
269
+ if isinstance(q, (str, Path)):
270
+ img = Image.open(q).convert("RGB")
271
+ elif isinstance(q, np.ndarray):
272
+ img = Image.fromarray(q.astype(np.uint8))
273
+ else:
274
+ img = q.convert("RGB")
275
+ embeddings.append(self._embed_image(img)[0])
276
+
277
+ batch_emb = np.stack(embeddings) # (N, 512)
278
+ sims_batch, idxs_batch = self._index.search(batch_emb, k)
279
+
280
+ all_results = []
281
+ for sims, idxs in zip(sims_batch, idxs_batch):
282
+ results = []
283
+ for rank, (sim, idx) in enumerate(zip(sims, idxs), start=1):
284
+ if idx < 0:
285
+ continue
286
+ meta = self._metadata.get(str(idx), {})
287
+ results.append(SearchResult(
288
+ rank=rank,
289
+ idx=int(idx),
290
+ filename=meta.get("filename", f"image_{idx}"),
291
+ filepath=meta.get("filepath", ""),
292
+ labels=meta.get("labels", "Unknown"),
293
+ similarity=float(sim),
294
+ ))
295
+ all_results.append(results)
296
+
297
+ return all_results
298
+
299
+ def get_stats(self) -> dict:
300
+ """Return index statistics."""
301
+ return {
302
+ "total_images": self._index.ntotal,
303
+ "embed_dim": self._index.d,
304
+ "index_type": type(self._index).__name__,
305
+ "device": self.device,
306
+ "index_dir": str(self.index_dir),
307
+ }
308
+
309
+ def __repr__(self) -> str:
310
+ return (
311
+ f"VisualSearchEngine("
312
+ f"images={self._index.ntotal:,}, "
313
+ f"device={self.device}, "
314
+ f"index_dir={self.index_dir})"
315
+ )
316
+
317
+
318
+ # ── Standalone CLI ─────────────────────────────────────────────────────────────
319
+ def main():
320
+ import argparse
321
+ from pprint import pprint
322
+
323
+ parser = argparse.ArgumentParser(
324
+ description="Search for similar X-ray images"
325
+ )
326
+ parser.add_argument("query_image", type=Path, help="Path to query X-ray image")
327
+ parser.add_argument(
328
+ "--index_dir", type=Path, default=Path("./index"),
329
+ help="Directory with visual_db.index (default: ./index)"
330
+ )
331
+ parser.add_argument("--top_k", type=int, default=5)
332
+ parser.add_argument("--device", default="auto")
333
+ args = parser.parse_args()
334
+
335
+ logging.basicConfig(level=logging.INFO, format="%(levelname)s %(message)s")
336
+
337
+ engine = VisualSearchEngine(
338
+ index_dir=args.index_dir,
339
+ device=args.device,
340
+ top_k=args.top_k,
341
+ )
342
+
343
+ print(f"\nπŸ” Query: {args.query_image}")
344
+ print("=" * 60)
345
+ results = engine.search(args.query_image, exclude_perfect_match=True)
346
+
347
+ for r in results:
348
+ bar = "β–ˆ" * int(r.similarity * 30)
349
+ print(f" #{r.rank} {r.similarity:.3f} {bar}")
350
+ print(f" {r.filename}")
351
+ print(f" Labels: {r.labels}")
352
+ print()
353
+
354
+ print(f"Index stats: {engine.get_stats()}")
355
+
356
+
357
+ if __name__ == "__main__":
358
+ main()
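Why `IndexFlatIP` returns cosine similarity here: both the gallery embeddings (built by `gallery_builder.py`) and the query embedding from `_embed_image` are L2-normalized, so the inner product equals the cosine. A minimal numpy sketch of that invariant (no FAISS required; the sizes and random data are illustrative, not from the repo):

```python
import numpy as np

rng = np.random.default_rng(0)
gallery = rng.normal(size=(20, 512)).astype(np.float32)
gallery /= np.linalg.norm(gallery, axis=1, keepdims=True)  # L2-normalize rows

# Query: a slightly perturbed copy of gallery image 3
query = gallery[3] + 0.001 * rng.normal(size=512).astype(np.float32)
query /= np.linalg.norm(query)

sims = gallery @ query        # inner product == cosine similarity
top5 = np.argsort(-sims)[:5]  # descending order, like IndexFlatIP

assert top5[0] == 3           # the near-duplicate ranks first
```

The same property is what makes `exclude_perfect_match` workable: a self-match scores essentially 1.0, so a fixed threshold like 0.9999 cleanly separates it from genuinely similar neighbors.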