# app.py — last updated by MayaKitzis in commit 28ef677 ("Update app.py (#1)", verified); file size 2 kB.
import numpy as np
import gradio as gr
import torch
from datasets import load_dataset
from transformers import CLIPProcessor, CLIPModel
# --- Configuration ---------------------------------------------------------
MODEL_ID = "openai/clip-vit-base-patch32"
EMB_PATH = "mri_clip_embeddings.npz"
TOP_K = 3  # number of nearest neighbours returned per query

# --- Model -----------------------------------------------------------------
device = "cuda" if torch.cuda.is_available() else "cpu"
model = CLIPModel.from_pretrained(MODEL_ID).to(device)
processor = CLIPProcessor.from_pretrained(MODEL_ID)
model.eval()  # inference only: disable dropout etc.

# --- Retrieval corpus ------------------------------------------------------
ds = load_dataset("AIOmarRehan/Brain_Tumor_MRI_Dataset")["train"]

# NpzFile holds an open archive handle; the context manager closes it once the
# "X" array has been materialised into memory.
with np.load(EMB_PATH) as data:
    X = data["X"].astype(np.float32)

# Row i of X is looked up as ds[i] at query time, so fail fast at startup
# rather than with an IndexError on the first query.
if X.ndim != 2:
    raise ValueError(f"expected a 2-D embedding matrix in {EMB_PATH}, got shape {X.shape}")
if X.shape[0] > len(ds):
    raise ValueError(
        f"{EMB_PATH} has {X.shape[0]} embeddings but the dataset has only {len(ds)} rows"
    )

# The query vector is L2-normalised, so `X @ q` is cosine similarity only if
# every row of X is unit-norm too. Normalise here (a no-op if already done);
# the floor avoids division by zero for an all-zero row.
X /= np.maximum(np.linalg.norm(X, axis=1, keepdims=True), 1e-12)
# Integer class label -> display name shown in the results text.
LABEL_NAMES = dict(enumerate(("Glioma", "Meningioma", "No Tumor", "Pituitary")))
def embed_image(pil_img):
    """Encode a single image with CLIP and return its unit-length feature vector.

    Args:
        pil_img: The image to embed, passed straight to the CLIP processor.

    Returns:
        A 1-D float32 NumPy array: the image's CLIP feature, L2-normalised.
    """
    with torch.inference_mode():
        batch = processor(images=pil_img, return_tensors="pt")
        on_device = {name: tensor.to(device) for name, tensor in batch.items()}
        features = model.get_image_features(**on_device)
        # Scale each row to unit L2 norm so a dot product equals cosine similarity.
        unit = features / features.norm(dim=-1, keepdim=True)
    # Drop the batch dimension of size 1 and move to host memory as NumPy.
    return unit.squeeze(0).cpu().numpy().astype(np.float32)
def recommend(pil_img):
    """Retrieve the TOP_K dataset images most similar to an uploaded image.

    Args:
        pil_img: The uploaded image, or ``None`` when the user submitted the
            form without uploading anything.

    Returns:
        A ``(gallery, details)`` tuple:
        - gallery: the dataset ``"image"`` field of each hit, best first
          (presumably PIL images per the dataset's features; confirm).
        - details: one line per hit with rank, label name and similarity score.
        When no image is given, returns an empty gallery and a prompt message.
    """
    if pil_img is None:
        # Gradio passes None for an empty Image input; without this guard the
        # CLIP processor raises on the first line of embed_image.
        return [], "Please upload an MRI image first."

    query = embed_image(pil_img)
    sims = X @ query  # one score per corpus row
    # Highest similarity first; slicing already clamps to the corpus size.
    top_idx = np.argsort(-sims)[:TOP_K]

    gallery = []
    lines = []
    for rank, i in enumerate(top_idx, start=1):
        idx = int(i)
        item = ds[idx]
        lbl = int(item["label"])
        gallery.append(item["image"])
        lines.append(
            f"{rank}) label={LABEL_NAMES.get(lbl, lbl)} | similarity={float(sims[idx]):.3f}"
        )
    return gallery, "\n".join(lines)
# Gradio UI: one image input, a gallery of the nearest neighbours plus a text
# summary of their labels and similarity scores.
demo = gr.Interface(
    fn=recommend,
    inputs=gr.Image(type="pil", label="Upload an MRI image"),
    outputs=[
        gr.Gallery(label="Top similar images", columns=3, height=200),
        gr.Textbox(label="Details", lines=6),
    ],
    title="Brain MRI Similarity Recommender",
    # Derive the count from TOP_K so the text cannot drift from the behaviour.
    description=(
        f"Upload an MRI image and retrieve the top-{TOP_K} most similar images from the dataset "
        "using CLIP embeddings and cosine similarity.\n\n"
        "Disclaimer: For educational purposes only — not for medical diagnosis."
    ),
)

if __name__ == "__main__":
    demo.launch()