Spaces:
Runtime error
Runtime error
| import numpy as np | |
| import gradio as gr | |
| import torch | |
| from datasets import load_dataset | |
| from transformers import CLIPProcessor, CLIPModel | |
MODEL_ID = "openai/clip-vit-base-patch32"
EMB_PATH = "mri_clip_embeddings.npz"
TOP_K = 3

device = "cuda" if torch.cuda.is_available() else "cpu"
model = CLIPModel.from_pretrained(MODEL_ID).to(device)
processor = CLIPProcessor.from_pretrained(MODEL_ID)
model.eval()

ds = load_dataset("AIOmarRehan/Brain_Tumor_MRI_Dataset")["train"]

# Precomputed image embeddings; row i is looked up as ds[i] by recommend(),
# so the rows must be in dataset order. An .npz load returns an NpzFile that
# keeps the archive open, so read the array inside a context manager
# (astype copies it into memory before the file is closed).
with np.load(EMB_PATH) as data:
    X = data["X"].astype(np.float32)

# Fail at startup with a clear message instead of on the first query:
# a non-2-D array, a width that differs from the model's projection size,
# or more rows than dataset items would all crash inside recommend().
if (
    X.ndim != 2
    or X.shape[1] != model.config.projection_dim
    or X.shape[0] > len(ds)
):
    raise ValueError(
        f"{EMB_PATH}: expected an (N <= {len(ds)}, "
        f"{model.config.projection_dim}) embedding matrix, got shape {X.shape}"
    )

# The query vector is unit-normalised in embed_image(), so X @ q is a cosine
# similarity only if every row of X is unit length too. Normalise here
# (a no-op up to rounding if the file already stores normalised rows);
# all-zero rows are left as zeros instead of becoming NaN.
_row_norms = np.linalg.norm(X, axis=1, keepdims=True)
X = X / np.where(_row_norms > 0, _row_norms, 1.0).astype(np.float32)

# Display names for the integer labels in ds.
# NOTE(review): assumes this order matches the dataset's label encoding —
# confirm against ds.features["label"].
LABEL_NAMES = {
    0: "Glioma",
    1: "Meningioma",
    2: "No Tumor",
    3: "Pituitary",
}
def embed_image(pil_img):
    """Encode one image with CLIP and return it as a unit-length vector.

    The image is preprocessed by the module-level ``processor``, the
    resulting tensors are moved to ``device``, and
    ``model.get_image_features`` produces the embedding. The embedding is
    divided by its L2 norm along the last axis, the leading batch axis is
    squeezed away, and the result is returned as a float32 NumPy array on
    the CPU. All of this runs under ``torch.inference_mode()``.
    """
    with torch.inference_mode():
        encoded = processor(images=pil_img, return_tensors="pt")
        model_inputs = dict(
            (name, tensor.to(device)) for name, tensor in encoded.items()
        )
        features = model.get_image_features(**model_inputs)
        length = features.norm(dim=-1, keepdim=True)
        unit_features = features / length
        # Convert while still inside inference mode, as the original does.
        return unit_features.squeeze(0).cpu().numpy().astype(np.float32)
def recommend(pil_img):
    """Retrieve the dataset images most similar to an uploaded MRI scan.

    The query is embedded with ``embed_image`` and scored against every row
    of the precomputed matrix ``X`` with a dot product (a cosine similarity
    when the rows of ``X`` are unit length, since the query already is).
    The ``TOP_K`` highest-scoring rows are looked up in ``ds`` by position.

    Args:
        pil_img: The uploaded image as a PIL image, or ``None`` when the
            form is submitted without an image.

    Returns:
        A ``(gallery, details)`` tuple: a list of the matched dataset images
        (best first) and a newline-joined string with one
        ``"<rank>) label=<name> | similarity=<score>"`` line per match.
        For a ``None`` input the gallery is empty and ``details`` asks the
        user to upload an image.
    """
    # Without this guard the CLIP processor receives None and the request
    # fails with an opaque error in the UI.
    if pil_img is None:
        return [], "Please upload an MRI image."

    q = embed_image(pil_img)
    sims = X @ q
    # Indices of the best scores, highest first; tolist() yields plain ints
    # suitable for indexing the dataset.
    top_idx = np.argsort(-sims)[:TOP_K].tolist()

    gallery = []
    lines = []
    for rank, idx in enumerate(top_idx, start=1):
        item = ds[idx]
        label = int(item["label"])
        sim = float(sims[idx])
        gallery.append(item["image"])
        lines.append(
            f"{rank}) label={LABEL_NAMES.get(label, label)} | similarity={sim:.3f}"
        )
    return gallery, "\n".join(lines)
# Gradio UI: one image input; outputs are a gallery of the TOP_K matches and
# a text summary of their labels and similarity scores.
demo = gr.Interface(
    fn=recommend,
    inputs=gr.Image(type="pil", label="Upload an MRI image"),
    outputs=[
        # One gallery column per retrieved match, so it tracks TOP_K.
        gr.Gallery(label="Top similar images", columns=TOP_K, height=200),
        gr.Textbox(label="Details", lines=6),
    ],
    title="Brain MRI Similarity Recommender",
    description=(
        f"Upload an MRI image and retrieve the top-{TOP_K} most similar images "
        "from the dataset using CLIP embeddings and cosine similarity.\n\n"
        "Disclaimer: For educational purposes only — not for medical diagnosis."
    ),
)

if __name__ == "__main__":
    demo.launch()