"""Gradio demo: zero-shot classification of radiology images with a fine-tuned CLIP.

Loads a CLIP checkpoint fine-tuned on ROCO, then serves a small Blocks UI where
the user uploads a scan and picks (or types) candidate labels; the model scores
each label via CLIP's image-text similarity softmax.

For research/demo purposes only — not for clinical use.
"""

import os

import gradio as gr
import numpy as np
import torch
from datasets import load_dataset
from PIL import Image
from transformers import CLIPModel, CLIPProcessor

MODEL_ID = "spicy03/CLIP-ROCO-v1"
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

print(f" Loading Model: {MODEL_ID}...")
try:
    model = CLIPModel.from_pretrained(MODEL_ID).to(DEVICE)
    # NOTE: the processor comes from the base OpenAI checkpoint, not MODEL_ID —
    # presumably the fine-tune kept the original tokenizer/image transforms.
    processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
    model.eval()
    print(" Model loaded successfully!")
except Exception as e:
    print(f" Error: {e}")
    # Fail fast: the rest of the script (and every request) needs `model` and
    # `processor`; continuing would only defer the crash to a NameError.
    raise

# Preset candidate-label sets selectable in the UI; "Custom" lets the user
# type their own labels, one per line.
LABEL_PRESETS = {
    "Imaging Modalities": [
        "chest x-ray", "brain MRI scan", "spine MRI scan", "abdominal CT scan",
        "ultrasound", "mammography", "knee x-ray", "dental x-ray", "hand x-ray",
    ],
    "Anatomical Regions": [
        "chest", "brain", "abdomen", "spine", "pelvis", "knee", "dental",
        "hand", "leg",
    ],
    "Pathologies": ["normal", "pneumonia", "fracture", "tumor", "edema"],
}


def classify_image(image, label_text, preset_choice):
    """Score ``image`` against candidate labels with CLIP.

    Args:
        image: PIL image uploaded by the user (``None`` if nothing uploaded).
        label_text: newline-separated custom labels (used when
            ``preset_choice == "Custom"``).
        preset_choice: key into ``LABEL_PRESETS`` or ``"Custom"``.

    Returns:
        ``(results, interpretation)`` where ``results`` maps each label to its
        softmax probability (for ``gr.Label``) and ``interpretation`` is a
        Markdown summary — or ``(None, <error message>)`` on bad input/failure.
    """
    if image is None:
        return None, " Please upload an image."

    if preset_choice != "Custom":
        labels = LABEL_PRESETS[preset_choice]
    else:
        labels = [l.strip() for l in label_text.split("\n") if l.strip()]
    if not labels:
        return None, " Enter at least one label."

    try:
        inputs = processor(
            text=labels, images=image, return_tensors="pt", padding=True
        ).to(DEVICE)
        with torch.no_grad():
            outputs = model(**inputs)
        # logits_per_image: (1, num_labels) -> softmax over labels.
        probs = outputs.logits_per_image.softmax(dim=1)[0].cpu().numpy()
        results = {label: float(prob) for label, prob in zip(labels, probs)}
        top_lbl = max(results, key=results.get)
        interpretation = (
            f"**Top Prediction:** {top_lbl}\n**Confidence:** {results[top_lbl]:.1%}"
        )
        return results, interpretation
    except Exception as e:
        return None, f" Error: {str(e)}"


with gr.Blocks(title="MedCLIP AI", theme=gr.themes.Soft()) as demo:
    gr.Markdown("# ROCO-Radiology AI Assistant")
    gr.Markdown(f"**Model:** `{MODEL_ID}` | **Status:** Live on {DEVICE.upper()}")

    with gr.Row():
        with gr.Column(scale=1):
            image_input = gr.Image(type="pil", label="Upload Scan")
            preset_radio = gr.Radio(
                choices=["Custom"] + list(LABEL_PRESETS.keys()),
                value="Imaging Modalities",
                label="Select Candidates",
            )
            # Only shown when the "Custom" preset is selected (see update_vis).
            custom_labels = gr.Textbox(
                label="Custom Labels (One per line)",
                placeholder="pneumonia\nnormal",
                visible=False,
            )
            classify_btn = gr.Button(" Analyze Image", variant="primary")
        with gr.Column(scale=1):
            output_label = gr.Label(num_top_classes=5, label="Confidence Scores")
            interpretation = gr.Markdown(label="Interpretation")

    def update_vis(choice):
        """Show the custom-labels textbox only for the "Custom" preset."""
        return gr.update(visible=(choice == "Custom"))

    preset_radio.change(fn=update_vis, inputs=[preset_radio], outputs=[custom_labels])

    classify_btn.click(
        fn=classify_image,
        inputs=[image_input, custom_labels, preset_radio],
        outputs=[output_label, interpretation],
    )

    gr.Markdown("### Try an Example (Click one to run)")
    gr.Examples(
        examples=[
            ["example_0.jpg", "", "Imaging Modalities"],
            ["example_1.jpg", "", "Anatomical Regions"],
            ["example_2.jpg", "chest x-ray\nbrain MRI\nknee scan", "Custom"],
        ],
        inputs=[image_input, custom_labels, preset_radio],
        outputs=[output_label, interpretation],
        fn=classify_image,
        cache_examples=False,
    )

    gr.Markdown("---")
    gr.Markdown(" **Disclaimer:** For research/demo purposes only. Not for clinical use.")

print(" Launching App...")
demo.launch(share=True, debug=True)