import re
from typing import List, Tuple
import faiss
import numpy as np
import pandas as pd
from PIL import Image
import torch
import gradio as gr
from huggingface_hub import hf_hub_download
from transformers import (
    CLIPModel,
    CLIPProcessor,
    AutoProcessor,
    BlipForConditionalGeneration,
)
# =========================
# CONFIG
# =========================
DATASET_REPO = "saad003/Dataset_final" # where embeddings + faiss + metadata live
IMAGES_REPO = "saad003/images" # where the radiology images live
CLIP_MODEL_ID = "openai/clip-vit-base-patch32"
CAPTION_MODEL_ID = "WafaaFraih/blip-roco-radiology-captioning" # BLIP radiology
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# =========================
# LOAD MODELS
# =========================
print("Loading CLIP model...")
clip_model = CLIPModel.from_pretrained(CLIP_MODEL_ID).to(DEVICE)
clip_processor = CLIPProcessor.from_pretrained(CLIP_MODEL_ID)
clip_model.eval()
print("Loading caption model...")
caption_processor = AutoProcessor.from_pretrained(CAPTION_MODEL_ID)
caption_model = BlipForConditionalGeneration.from_pretrained(
    CAPTION_MODEL_ID
).to(DEVICE)
caption_model.eval()
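# Optional (assumption: running on CUDA): loading the caption model in half
# precision roughly halves GPU memory, e.g.
#
#     BlipForConditionalGeneration.from_pretrained(
#         CAPTION_MODEL_ID, torch_dtype=torch.float16
#     ).to(DEVICE)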
# =========================
# LOAD INDEX + METADATA
# =========================
print("Loading FAISS index + embeddings + metadata...")
# NOTE: if Dataset_final is a dataset-type repo on the Hub, these downloads
# need repo_type="dataset"; as written they assume a model-type repo.
embeddings_path = hf_hub_download(DATASET_REPO, "embeddings.npy")
index_path = hf_hub_download(DATASET_REPO, "image_index.faiss")
EMBEDDINGS = np.load(embeddings_path).astype("float32")
INDEX = faiss.read_index(index_path)
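# The index is assumed to hold the same L2-normalized CLIP vectors as
# embeddings.npy, so inner-product scores equal cosine similarity. For
# reference, a hypothetical rebuild from the raw embeddings would be:
#
#     index = faiss.IndexFlatIP(EMBEDDINGS.shape[1])  # inner-product index
#     index.add(EMBEDDINGS)                           # vectors already normalized
#     faiss.write_index(index, "image_index.faiss")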
# metadata: parquet preferred, else csv
try:
    meta_path = hf_hub_download(DATASET_REPO, "metadata.parquet")
    METADATA = pd.read_parquet(meta_path)
    print("Loaded metadata.parquet")
except Exception:
    meta_path = hf_hub_download(DATASET_REPO, "metadata.csv")
    METADATA = pd.read_csv(meta_path)
    print("Loaded metadata.csv")
print("Metadata columns:", list(METADATA.columns))

def pick_column(candidates: List[str]) -> str:
    """Pick the first existing column name from candidates."""
    for c in candidates:
        if c in METADATA.columns:
            return c
    raise RuntimeError(
        f"None of {candidates} found in metadata columns: {list(METADATA.columns)}"
    )
# Adjust these if my guesses are wrong; check your metadata file on HF
IMAGE_COL = pick_column(
    ["image_path", "img_path", "filepath", "image", "image_file", "path"]
)
CAPTION_COL = pick_column(["caption", "report", "text", "caption_text"])
print("Using IMAGE_COL =", IMAGE_COL)
print("Using CAPTION_COL =", CAPTION_COL)
# =========================
# HELPER FUNCTIONS
# =========================
def load_image_for_row(row: pd.Series) -> Image.Image:
    """
    Load one image given a metadata row.
    Assumes metadata[IMAGE_COL] is a relative path inside the saad003/images repo.
    """
    rel_path = str(row[IMAGE_COL])
    local_path = hf_hub_download(IMAGES_REPO, rel_path)
    img = Image.open(local_path).convert("RGB")
    return img
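# hf_hub_download caches files locally, so repeatedly retrieved neighbor
# images are read from disk rather than re-fetched from the Hub.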

@torch.no_grad()
def embed_query_image(image: Image.Image) -> np.ndarray:
    """Embed the query image with the same CLIP model used during indexing."""
    inputs = clip_processor(images=image, return_tensors="pt").to(DEVICE)
    features = clip_model.get_image_features(**inputs)
    # normalize for cosine similarity
    features = features / features.norm(dim=-1, keepdim=True)
    return features.cpu().numpy().astype("float32")

def retrieve_similar(image: Image.Image, k: int = 5) -> pd.DataFrame:
    """Return the top-k most similar rows from METADATA."""
    query_emb = embed_query_image(image)
    # D holds raw FAISS scores (inner-product similarities for an IP index,
    # L2 distances for an L2 index); I holds the matching row indices.
    D, I = INDEX.search(query_emb, k)
    rows = METADATA.iloc[I[0]].copy()
    rows["distance"] = D[0]
    return rows
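# Example usage (hypothetical local file):
#
#     img = Image.open("example_cxr.png").convert("RGB")
#     print(retrieve_similar(img, k=3)[[IMAGE_COL, CAPTION_COL, "distance"]])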

@torch.no_grad()
def generate_caption(image: Image.Image, neighbors: pd.DataFrame) -> str:
    """Generate a caption for the query image, using neighbors' captions as context."""
    neighbor_captions = neighbors[CAPTION_COL].astype(str).tolist()
    context = " | ".join(neighbor_captions[:3])
    prompt = (
        "Radiology image. Similar case descriptions: "
        f"{context}. Generate a concise radiology-style caption for this new image."
    )
    inputs = caption_processor(
        images=image,
        text=prompt,
        return_tensors="pt",
    ).to(DEVICE)
    out = caption_model.generate(
        **inputs,
        max_new_tokens=64,
        num_beams=3,
        do_sample=False,
    )
    caption = caption_processor.decode(out[0], skip_special_tokens=True).strip()
    # BLIP conditions on `text` as a decoding prefix, so the decoded string
    # typically echoes the prompt; drop that prefix when present (the
    # comparison is lowercased because BLIP's tokenizer is uncased).
    if caption.lower().startswith(prompt.lower()):
        caption = caption[len(prompt):].strip()
    return caption
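# Unconditional variant (no neighbor context), kept here as a baseline sketch:
#
#     inputs = caption_processor(images=image, return_tensors="pt").to(DEVICE)
#     out = caption_model.generate(**inputs, max_new_tokens=64)
#     caption_processor.decode(out[0], skip_special_tokens=True)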

def detect_modality(text: str) -> str:
    """Heuristic keyword match over the generated and neighbor captions."""
    t = text.lower()
    modalities = {
        "CT": ["ct", "computed tomography"],
        "X-ray": ["x-ray", "xray", "radiograph", "chest x-ray", "cxr"],
        "MRI": ["mri", "magnetic resonance"],
        "Ultrasound": ["ultrasound", "sonography", "usg"],
        "PET": ["pet scan", "pet-ct", "positron emission tomography"],
        "Mammography": ["mammogram", "mammography"],
    }
    for name, kws in modalities.items():
        # match whole words only, so e.g. "ct" does not fire on "infarct"
        if any(re.search(rf"\b{re.escape(kw)}\b", t) for kw in kws):
            return name
    return "Unknown"

def run_pipeline(
    query_image: Image.Image, k: int = 5
) -> Tuple[List[Tuple[Image.Image, str]], str, str]:
    """
    Full pipeline:
    - retrieve neighbors
    - load their images
    - generate caption for query
    - detect modality
    """
    neighbors = retrieve_similar(query_image, k=k)
    neighbor_images = [load_image_for_row(row) for _, row in neighbors.iterrows()]
    neighbor_captions = neighbors[CAPTION_COL].astype(str).tolist()
    gallery = list(zip(neighbor_images, neighbor_captions))
    generated_caption = generate_caption(query_image, neighbors)
    modality = detect_modality(generated_caption + " " + " ".join(neighbor_captions))
    return gallery, generated_caption, modality
# =========================
# GRADIO APP
# =========================
def gradio_infer(image, k):
    if image is None:
        return [], "No image provided", ""
    gallery, caption, modality = run_pipeline(image, k=int(k))
    return gallery, caption, modality

demo = gr.Interface(
    fn=gradio_infer,
    inputs=[
        gr.Image(type="pil", label="Query radiology image"),
        gr.Slider(1, 12, value=5, step=1, label="Number of similar images"),
    ],
    outputs=[
        # `.style()` was removed in Gradio 4; `preview` is a constructor kwarg now
        gr.Gallery(label="Similar images (with captions)", preview=True),
        gr.Textbox(label="Generated caption for query image"),
        gr.Textbox(label="Detected modality"),
    ],
    title="Radiology Image Retrieval + Captioning",
    description="Research demo. Not for clinical use.",
)
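# Optional (untested here): on Spaces under concurrent load, enabling the
# request queue with `demo.queue()` before launch can help.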

if __name__ == "__main__":
    demo.launch()