|
|
"""Radiology image retrieval + captioning demo.

Embeds a query image with CLIP, retrieves the most similar indexed images
via FAISS, and generates a radiology-style caption with a BLIP model, using
the neighbors' captions as context. Research demo; not for clinical use.
"""

import re
from typing import List, Tuple

import faiss
import numpy as np
import pandas as pd
from PIL import Image

import torch
import gradio as gr
from huggingface_hub import hf_hub_download
from transformers import (
    CLIPModel,
    CLIPProcessor,
    AutoProcessor,
    BlipForConditionalGeneration,
)

# Hub repos holding the prebuilt index, metadata, and images.
DATASET_REPO = "saad003/Dataset_final"
IMAGES_REPO = "saad003/images"

CLIP_MODEL_ID = "openai/clip-vit-base-patch32"
CAPTION_MODEL_ID = "WafaaFraih/blip-roco-radiology-captioning"

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

print("Loading CLIP model...")
clip_model = CLIPModel.from_pretrained(CLIP_MODEL_ID).to(DEVICE)
clip_processor = CLIPProcessor.from_pretrained(CLIP_MODEL_ID)
clip_model.eval()

print("Loading caption model...")
caption_processor = AutoProcessor.from_pretrained(CAPTION_MODEL_ID)
caption_model = BlipForConditionalGeneration.from_pretrained(
    CAPTION_MODEL_ID
).to(DEVICE)
caption_model.eval()

print("Loading FAISS index + embeddings + metadata...")

# NOTE: hf_hub_download defaults to repo_type="model"; if DATASET_REPO is a
# dataset repo on the Hub, these calls need repo_type="dataset".
embeddings_path = hf_hub_download(DATASET_REPO, "embeddings.npy")
index_path = hf_hub_download(DATASET_REPO, "image_index.faiss")

EMBEDDINGS = np.load(embeddings_path).astype("float32")
INDEX = faiss.read_index(index_path)

# Prefer parquet metadata; fall back to CSV if it is missing.
try:
    meta_path = hf_hub_download(DATASET_REPO, "metadata.parquet")
    METADATA = pd.read_parquet(meta_path)
    print("Loaded metadata.parquet")
except Exception:
    meta_path = hf_hub_download(DATASET_REPO, "metadata.csv")
    METADATA = pd.read_csv(meta_path)
    print("Loaded metadata.csv")

print("Metadata columns:", list(METADATA.columns))
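
# Sanity checks (assumptions about how the artifacts were built): retrieval
# maps FAISS ids to METADATA rows by position, and the index should share the
# embeddings' dimensionality. Warn rather than crash if either fails to hold.
if INDEX.d != EMBEDDINGS.shape[1]:
    print(f"WARNING: index dim {INDEX.d} != embedding dim {EMBEDDINGS.shape[1]}")
if INDEX.ntotal != len(METADATA):
    print(
        f"WARNING: {INDEX.ntotal} indexed vectors vs {len(METADATA)} metadata "
        "rows; neighbor lookups may be misaligned."
    )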


def pick_column(candidates: List[str]) -> str:
    """Pick the first candidate column name that exists in METADATA."""
    for c in candidates:
        if c in METADATA.columns:
            return c
    raise RuntimeError(
        f"None of {candidates} found in metadata columns: {list(METADATA.columns)}"
    )


IMAGE_COL = pick_column(
    ["image_path", "img_path", "filepath", "image", "image_file", "path"]
)
CAPTION_COL = pick_column(["caption", "report", "text", "caption_text"])

print("Using IMAGE_COL =", IMAGE_COL)
print("Using CAPTION_COL =", CAPTION_COL)


def load_image_for_row(row: pd.Series) -> Image.Image:
    """
    Load one image given a metadata row.

    Assumes metadata[IMAGE_COL] is a path relative to the root of the
    IMAGES_REPO ("saad003/images") repo.
    """
    rel_path = str(row[IMAGE_COL])
    local_path = hf_hub_download(IMAGES_REPO, rel_path)
    return Image.open(local_path).convert("RGB")


@torch.no_grad()
def embed_query_image(image: Image.Image) -> np.ndarray:
    """Embed the query image with the same CLIP model used during indexing."""
    inputs = clip_processor(images=image, return_tensors="pt").to(DEVICE)
    features = clip_model.get_image_features(**inputs)
    # L2-normalize so inner-product search behaves like cosine similarity.
    features = features / features.norm(dim=-1, keepdim=True)
    return features.cpu().numpy().astype("float32")
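
# Note (assumption about the prebuilt index): with L2-normalized embeddings,
# an inner-product index returns cosine similarity (higher = more similar),
# while an L2 index returns squared distance (lower = more similar). The
# "distance" column surfaced below follows whichever metric image_index.faiss
# was built with.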


def retrieve_similar(image: Image.Image, k: int = 5) -> pd.DataFrame:
    """Return the top-k most similar rows from METADATA."""
    query_emb = embed_query_image(image)
    D, I = INDEX.search(query_emb, k)
    # FAISS pads results with -1 when the index holds fewer than k vectors.
    valid = I[0] >= 0
    rows = METADATA.iloc[I[0][valid]].copy()
    rows["distance"] = D[0][valid]
    return rows


@torch.no_grad()
def generate_caption(image: Image.Image, neighbors: pd.DataFrame) -> str:
    """Generate a caption for the query image, using neighbor captions as context."""
    neighbor_captions = neighbors[CAPTION_COL].astype(str).tolist()
    context = " | ".join(neighbor_captions[:3])

    prompt = (
        "Radiology image. Similar case descriptions: "
        f"{context}. Generate a concise radiology-style caption for this new image."
    )

    inputs = caption_processor(
        images=image,
        text=prompt,
        return_tensors="pt",
    ).to(DEVICE)

    out = caption_model.generate(
        **inputs,
        max_new_tokens=64,
        num_beams=3,
        do_sample=False,
    )

    caption = caption_processor.decode(out[0], skip_special_tokens=True).strip()
    # BLIP echoes the conditioning text as a prefix of its output; drop it so
    # only the newly generated caption is returned.
    if caption.lower().startswith(prompt.lower()):
        caption = caption[len(prompt):].strip()
    return caption


def detect_modality(text: str) -> str:
    """Heuristically detect the imaging modality from caption text."""
    t = text.lower()
    modalities = {
        "CT": ["ct", "computed tomography"],
        "X-ray": ["x-ray", "xray", "radiograph", "chest x-ray", "cxr"],
        "MRI": ["mri", "magnetic resonance"],
        "Ultrasound": ["ultrasound", "sonography", "usg"],
        "PET": ["pet scan", "pet-ct", "positron emission tomography"],
        "Mammography": ["mammogram", "mammography"],
    }

    for name, kws in modalities.items():
        # Match whole words only, so "ct" does not fire on e.g. "aspect".
        if any(re.search(rf"\b{re.escape(kw)}\b", t) for kw in kws):
            return name
    return "Unknown"
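
# Illustrative behavior of the whole-word matching above:
#   detect_modality("axial ct of the chest")  -> "CT"
#   detect_modality("no significant aspect")  -> "Unknown"  (no false "ct" hit)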


def run_pipeline(
    query_image: Image.Image, k: int = 5
) -> Tuple[List[Tuple[Image.Image, str]], str, str]:
    """
    Full pipeline:
      - retrieve the k nearest neighbors
      - load their images
      - generate a caption for the query image
      - detect the modality
    """
    neighbors = retrieve_similar(query_image, k=k)

    neighbor_images = [load_image_for_row(row) for _, row in neighbors.iterrows()]
    neighbor_captions = neighbors[CAPTION_COL].astype(str).tolist()

    gallery = list(zip(neighbor_images, neighbor_captions))

    generated_caption = generate_caption(query_image, neighbors)

    modality = detect_modality(
        generated_caption + " " + " ".join(neighbor_captions)
    )

    return gallery, generated_caption, modality
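
# Note: load_image_for_row() performs one hf_hub_download per neighbor, but
# downloads are cached locally by huggingface_hub, so repeat queries over the
# same neighborhood are cheap after the first fetch.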


def gradio_infer(image, k):
    if image is None:
        return [], "No image provided", ""
    return run_pipeline(image, k=int(k))


demo = gr.Interface(
    fn=gradio_infer,
    inputs=[
        gr.Image(type="pil", label="Query radiology image"),
        gr.Slider(1, 12, value=5, step=1, label="Number of similar images"),
    ],
    outputs=[
        # Gallery.style() was removed in Gradio 4.x; preview is a constructor arg.
        gr.Gallery(label="Similar images (with captions)", preview=True),
        gr.Textbox(label="Generated caption for query image"),
        gr.Textbox(label="Detected modality"),
    ],
    title="Radiology Image Retrieval + Captioning",
    description="Research demo. Not for clinical use.",
)

if __name__ == "__main__":
    demo.launch()