# app.py
import io
import os

import faiss
import torch
import pandas as pd
from PIL import Image
from fastapi import FastAPI, File, UploadFile
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
from huggingface_hub import hf_hub_download
from transformers import CLIPProcessor, CLIPModel
from transformers import BlipForConditionalGeneration, AutoProcessor
# ---------- FastAPI app ----------
app = FastAPI()

# Allow your React app to call this API
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # You can later restrict this to your domain
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
# ---------- Config ----------
# Dataset with the FAISS index + radiology_metadata.csv
EMBED_REPO_ID = "saad003/Red01"
# Dataset with all radiology images (you uploaded here)
IMAGE_REPO_ID = "saad003/images02"
BASE_IMAGE_URL = f"https://huggingface.co/datasets/{IMAGE_REPO_ID}/resolve/main"

# Optional: token if Red01 is private (set the HF_TOKEN secret in the Space)
HF_TOKEN = os.environ.get("HF_TOKEN")

# ---------- Download index + metadata ----------
print("Downloading FAISS index & metadata from Hugging Face...")
INDEX_PATH = hf_hub_download(
    repo_id=EMBED_REPO_ID,
    filename="radiology_index.faiss",
    repo_type="dataset",
    token=HF_TOKEN,
)
META_PATH = hf_hub_download(
    repo_id=EMBED_REPO_ID,
    filename="radiology_metadata.csv",
    repo_type="dataset",
    token=HF_TOKEN,
)
print("Loading FAISS index...")
index = faiss.read_index(INDEX_PATH)

print("Loading metadata CSV...")
metadata = pd.read_csv(META_PATH)

# Sanity check
assert index.ntotal == len(metadata), "Index size and metadata rows mismatch!"

# ---------- Load CLIP (retrieval) ----------
# IMPORTANT: must match the model you used to build the index.
print("Loading PubMedCLIP model for retrieval...")
CLIP_MODEL_NAME = "flaviagiammarino/pubmed-clip-vit-base-patch32"
device = "cuda" if torch.cuda.is_available() else "cpu"
print("Using device:", device)

clip_model = CLIPModel.from_pretrained(CLIP_MODEL_NAME).to(device)
clip_processor = CLIPProcessor.from_pretrained(CLIP_MODEL_NAME)
clip_model.eval()

# ---------- Load BLIP (captioning) ----------
print("Loading BLIP radiology captioning model...")
CAPTION_MODEL_ID = "WafaaFraih/blip-roco-radiology-captioning"
caption_processor = AutoProcessor.from_pretrained(CAPTION_MODEL_ID)
caption_model = BlipForConditionalGeneration.from_pretrained(CAPTION_MODEL_ID).to(device)
caption_model.eval()

print("Backend ready ✅")
# ---------- Helper: build image URL from img_path ----------
def img_path_to_image_url(img_path: str) -> str:
    """
    Take the original img_path from Kaggle and map it to your HF dataset.

    Example img_path in the CSV:
        /kaggle/input/radiology/8333645/train_images/train/ROCOv2_2023_train_000001.jpg

    If you uploaded the folders train_images/..., test_images/..., valid_images/...
    into saad003/images02, the relative path after '8333645/' is what we want,
    so the URL becomes:
        https://huggingface.co/datasets/saad003/images02/resolve/main/train_images/train/ROCOv2_2023_train_000001.jpg
    """
    if not isinstance(img_path, str):
        return None
    # Try to cut everything up to the Kaggle dataset root
    marker = "8333645/"
    if marker in img_path:
        rel = img_path.split(marker, 1)[1]
    else:
        # Fallback: just take the filename
        rel = os.path.basename(img_path)
    rel = rel.lstrip("/")  # safety
    return f"{BASE_IMAGE_URL}/{rel}"
# ---------- Helper: search by image ----------
def search_similar_by_image(image: Image.Image, k: int = 5) -> pd.DataFrame:
    """
    Encode the query image with CLIP, search FAISS, and return the top-k rows
    containing ID, split, caption, concepts, score, and image_url.
    """
    # Encode the image
    inputs = clip_processor(images=image, return_tensors="pt").to(device)
    with torch.no_grad():
        feats = clip_model.get_image_features(**inputs)

    # Normalize (very important: must match how the index was built)
    feats = feats / feats.norm(p=2, dim=-1, keepdim=True)
    feats = feats.cpu().numpy().astype("float32")

    # Search FAISS
    D, I = index.search(feats, k)  # D: distances/similarities, I: indices

    # Get metadata rows for the top-k indices
    rows = metadata.iloc[I[0]].copy()
    rows["score"] = D[0].astype(float)  # cast to float64 so the scores serialize to JSON cleanly
    # Add image_url using the original img_path column
    rows["image_url"] = rows["img_path"].apply(img_path_to_image_url)
    return rows[["ID", "split", "caption", "concepts_manual", "score", "image_url"]]
# ---------- Helper: generate caption for query image ----------
def generate_query_caption(image: Image.Image) -> str:
    """
    Generate a medical radiology caption for the query image using BLIP
    fine-tuned on ROCO.
    """
    inputs = caption_processor(images=image, return_tensors="pt").to(device)
    with torch.no_grad():
        out = caption_model.generate(**inputs, max_new_tokens=64)
    caption = caption_processor.batch_decode(out, skip_special_tokens=True)[0]
    return caption.strip()
# ---------- Helper: infer modality from caption ----------
def infer_modality_from_caption(caption: str) -> str:
    """
    Simple keyword heuristic to map a caption to an imaging modality.
    """
    if not caption:
        return "Unknown"
    text = caption.lower()
    # CT
    if any(word in text for word in ["ct scan", "computed tomography", "ct of", "ct image"]):
        return "CT"
    # MRI
    if any(word in text for word in ["mri", "magnetic resonance"]):
        return "MRI"
    # X-ray / radiograph
    if any(word in text for word in ["x-ray", "x ray", "radiograph", "chest xray"]):
        return "X-ray"
    # Ultrasound
    if any(word in text for word in ["ultrasound", "sonography", "sonogram"]):
        return "Ultrasound"
    # PET / PET-CT
    if any(word in text for word in ["pet-ct", "pet ct", "pet scan", "positron emission tomography"]):
        return "PET/CT"
    return "Unknown"
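# Quick illustration of the heuristic above (made-up captions, not from the dataset):
#   infer_modality_from_caption("Axial CT scan of the abdomen showing a lesion")   -> "CT"
#   infer_modality_from_caption("Chest radiograph with bilateral infiltrates")     -> "X-ray"
#   infer_modality_from_caption("Abdominal ultrasound of the right kidney")        -> "Ultrasound"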
# ---------- Routes ----------
@app.get("/")
def root():
    return {"status": "ok", "message": "Radiology retrieval + captioning API"}


# NOTE: the endpoint path is assumed here; adjust it to whatever your React app calls.
@app.post("/search_by_image")
async def search_by_image(file: UploadFile = File(...), k: int = 5):
    """
    Upload a radiology image.

    Returns:
        - query_caption: caption generated for the query image (BLIP)
        - modality: imaging modality inferred from the caption
        - results: list of similar images with their captions, concepts, score, and image_url
    """
    content = await file.read()
    image = Image.open(io.BytesIO(content)).convert("RGB")

    # 1) Retrieval
    results_df = search_similar_by_image(image, k=k)
    results = results_df.to_dict(orient="records")

    # 2) Captioning for the query image
    try:
        query_caption = generate_query_caption(image)
    except Exception as e:
        print("Error generating caption:", e)
        query_caption = None

    # 3) Infer modality
    modality = infer_modality_from_caption(query_caption or "")

    return JSONResponse(
        {
            "query_caption": query_caption,
            "modality": modality,
            "results": results,
        }
    )
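

# ---------- Local entry point (sketch) ----------
# Convenience for running the API outside the Space, assuming uvicorn is installed
# and port 7860 (the port HF Spaces usually expects) is free; a Docker Space would
# normally start uvicorn from its CMD instead.
#
# Example request once the server is up (hypothetical local URL and file name):
#   curl -X POST "http://localhost:7860/search_by_image?k=5" \
#        -F "file=@example_scan.jpg"
if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=7860)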