import re
from typing import List, Tuple

import faiss
import numpy as np
import pandas as pd
import torch
import gradio as gr
from PIL import Image
from huggingface_hub import hf_hub_download
from transformers import (
    CLIPModel,
    CLIPProcessor,
    AutoProcessor,
    BlipForConditionalGeneration,
)

# =========================
# CONFIG
# =========================
DATASET_REPO = "saad003/Dataset_final"  # where embeddings + faiss + metadata live
IMAGES_REPO = "saad003/images"  # where the radiology images live

CLIP_MODEL_ID = "openai/clip-vit-base-patch32"
CAPTION_MODEL_ID = "WafaaFraih/blip-roco-radiology-captioning"  # BLIP radiology

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# =========================
# LOAD MODELS
# =========================
print("Loading CLIP model...")
clip_model = CLIPModel.from_pretrained(CLIP_MODEL_ID).to(DEVICE)
clip_processor = CLIPProcessor.from_pretrained(CLIP_MODEL_ID)
clip_model.eval()

print("Loading caption model...")
caption_processor = AutoProcessor.from_pretrained(CAPTION_MODEL_ID)
caption_model = BlipForConditionalGeneration.from_pretrained(
    CAPTION_MODEL_ID
).to(DEVICE)
caption_model.eval()

# =========================
# LOAD INDEX + METADATA
# =========================
# NOTE: hf_hub_download defaults to repo_type="model"; if these repos are
# published as datasets on the Hub, pass repo_type="dataset" to every call.
print("Loading FAISS index + embeddings + metadata...")
embeddings_path = hf_hub_download(DATASET_REPO, "embeddings.npy")
index_path = hf_hub_download(DATASET_REPO, "image_index.faiss")

# EMBEDDINGS is loaded for completeness; the search path below only uses INDEX.
EMBEDDINGS = np.load(embeddings_path).astype("float32")
INDEX = faiss.read_index(index_path)

# metadata: parquet preferred, else csv
try:
    meta_path = hf_hub_download(DATASET_REPO, "metadata.parquet")
    METADATA = pd.read_parquet(meta_path)
    print("Loaded metadata.parquet")
except Exception:
    meta_path = hf_hub_download(DATASET_REPO, "metadata.csv")
    METADATA = pd.read_csv(meta_path)
    print("Loaded metadata.csv")

print("Metadata columns:", list(METADATA.columns))


def pick_column(candidates: List[str]) -> str:
    """Pick the first candidate that exists as a column in METADATA."""
    for c in candidates:
        if c in METADATA.columns:
            return c
    raise RuntimeError(
        f"None of {candidates} found in metadata columns: {list(METADATA.columns)}"
    )


# Adjust these if my guesses are wrong; check your metadata file on HF
IMAGE_COL = pick_column(
    ["image_path", "img_path", "filepath", "image", "image_file", "path"]
)
CAPTION_COL = pick_column(["caption", "report", "text", "caption_text"])

print("Using IMAGE_COL =", IMAGE_COL)
print("Using CAPTION_COL =", CAPTION_COL)


# =========================
# HELPER FUNCTIONS
# =========================
def load_image_for_row(row: pd.Series) -> Image.Image:
    """
    Load one image given a metadata row.
    Assumes metadata[IMAGE_COL] is a relative path inside the saad003/images repo.
""" rel_path = str(row[IMAGE_COL]) local_path = hf_hub_download(IMAGES_REPO, rel_path) img = Image.open(local_path).convert("RGB") return img @torch.no_grad() def embed_query_image(image: Image.Image) -> np.ndarray: """Embed query image with the same CLIP model used during indexing.""" inputs = clip_processor(images=image, return_tensors="pt").to(DEVICE) features = clip_model.get_image_features(**inputs) # normalize for cosine similarity features = features / features.norm(dim=-1, keepdim=True) return features.cpu().numpy().astype("float32") def retrieve_similar(image: Image.Image, k: int = 5) -> pd.DataFrame: """Return top-k similar rows from METADATA.""" query_emb = embed_query_image(image) D, I = INDEX.search(query_emb, k) rows = METADATA.iloc[I[0]].copy() rows["distance"] = D[0] return rows @torch.no_grad() def generate_caption(image: Image.Image, neighbors: pd.DataFrame) -> str: """Generate caption for query image, using neighbors' captions as context.""" neighbor_captions = neighbors[CAPTION_COL].astype(str).tolist() context = " | ".join(neighbor_captions[:3]) prompt = ( "Radiology image. Similar case descriptions: " f"{context}. Generate a concise radiology-style caption for this new image." ) inputs = caption_processor( images=image, text=prompt, return_tensors="pt", ).to(DEVICE) out = caption_model.generate( **inputs, max_new_tokens=64, num_beams=3, do_sample=False, ) caption = caption_processor.decode(out[0], skip_special_tokens=True).strip() return caption def detect_modality(text: str) -> str: t = text.lower() modalities = { "CT": ["ct", "computed tomography"], "X-ray": ["x-ray", "xray", "radiograph", "chest x-ray", "cxr"], "MRI": ["mri", "magnetic resonance"], "Ultrasound": ["ultrasound", "sonography", "usg"], "PET": ["pet scan", "pet-ct", "positron emission tomography"], "Mammography": ["mammogram", "mammography"], } for name, kws in modalities.items(): if any(kw in t for kw in kws): return name return "Unknown" def run_pipeline( query_image: Image.Image, k: int = 5 ) -> Tuple[List[Tuple[Image.Image, str]], str, str]: """ Full pipeline: - retrieve neighbors - load their images - generate caption for query - detect modality """ neighbors = retrieve_similar(query_image, k=k) neighbor_images = [load_image_for_row(row) for _, row in neighbors.iterrows()] neighbor_captions = neighbors[CAPTION_COL].astype(str).tolist() gallery = [(img, cap) for img, cap in zip(neighbor_images, neighbor_captions)] generated_caption = generate_caption(query_image, neighbors) modality = detect_modality( generated_caption + " " + " ".join(neighbor_captions) ) return gallery, generated_caption, modality # ========================= # GRADIO APP # ========================= def gradio_infer(image, k): if image is None: return [], "No image provided", "" k = int(k) gallery, caption, modality = run_pipeline(image, k=k) return gallery, caption, modality demo = gr.Interface( fn=gradio_infer, inputs=[ gr.Image(type="pil", label="Query radiology image"), gr.Slider(1, 12, value=5, step=1, label="Number of similar images"), ], outputs=[ gr.Gallery(label="Similar images (with captions)").style(preview=True), gr.Textbox(label="Generated caption for query image"), gr.Textbox(label="Detected modality"), ], title="Radiology Image Retrieval + Captioning", description="Research demo. Not for clinical use.", ) if __name__ == "__main__": demo.launch()