# Source: spicy03 — "Upload folder using huggingface_hub" (commit 4c3000e, verified)
import gradio as gr
import torch
from transformers import CLIPProcessor, CLIPModel
from datasets import load_dataset
from PIL import Image
import numpy as np
import os
MODEL_ID = "spicy03/CLIP-ROCO-v1"
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

print(f" Loading Model: {MODEL_ID}...")
try:
    # Fine-tuned CLIP weights; the stock OpenAI processor is used for
    # tokenization/image preprocessing — presumably the fine-tune kept the
    # base patch32 preprocessing config (TODO confirm against model card).
    model = CLIPModel.from_pretrained(MODEL_ID).to(DEVICE)
    processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
    model.eval()  # inference only: disable dropout / train-mode behavior
    print(" Model loaded successfully!")
except Exception as e:
    print(f" Error: {e}")
    # Fail fast: the original swallowed this error, leaving `model` and
    # `processor` undefined so every later classify call crashed with an
    # unrelated NameError. Re-raise so the real loading failure surfaces.
    raise
# Candidate label sets for zero-shot classification, offered in the UI
# alongside a free-text "Custom" mode (one label per line).
LABEL_PRESETS = {
    "Imaging Modalities": [
        "chest x-ray",
        "brain MRI scan",
        "spine MRI scan",
        "abdominal CT scan",
        "ultrasound",
        "mammography",
        "knee x-ray",
        "dental x-ray",
        "hand x-ray",
    ],
    "Anatomical Regions": [
        "chest",
        "brain",
        "abdomen",
        "spine",
        "pelvis",
        "knee",
        "dental",
        "hand",
        "leg",
    ],
    "Pathologies": [
        "normal",
        "pneumonia",
        "fracture",
        "tumor",
        "edema",
    ],
}
def classify_image(image, label_text, preset_choice):
    """Zero-shot classify a radiology image against candidate text labels.

    Args:
        image: PIL image from the Gradio upload widget, or None.
        label_text: newline-separated labels (only used when
            preset_choice is "Custom").
        preset_choice: a key of LABEL_PRESETS, or "Custom".

    Returns:
        A tuple (scores, markdown): `scores` maps each label to its
        softmax probability (None on error); `markdown` is a short
        human-readable summary or an error message.
    """
    if image is None:
        return None, " Please upload an image."

    # Resolve the candidate label set for this request.
    if preset_choice == "Custom":
        labels = [line.strip() for line in label_text.split("\n") if line.strip()]
    else:
        labels = LABEL_PRESETS[preset_choice]
    if not labels:
        return None, " Enter at least one label."

    try:
        batch = processor(
            text=labels, images=image, return_tensors="pt", padding=True
        ).to(DEVICE)
        with torch.no_grad():
            image_logits = model(**batch).logits_per_image
        scores = image_logits.softmax(dim=1)[0].cpu().numpy()

        results = dict(zip(labels, (float(score) for score in scores)))
        best = max(results, key=results.get)
        summary = f"**Top Prediction:** {best}\n**Confidence:** {results[best]:.1%}"
        return results, summary
    except Exception as e:
        return None, f" Error: {str(e)}"
# UI layout: left column holds the inputs (image, preset selector, optional
# custom labels, run button); right column shows the ranked scores and a
# markdown interpretation.
with gr.Blocks(title="MedCLIP AI", theme=gr.themes.Soft()) as demo:
    gr.Markdown("# ROCO-Radiology AI Assistant")
    gr.Markdown(f"**Model:** `{MODEL_ID}` | **Status:** Live on {DEVICE.upper()}")
    with gr.Row():
        with gr.Column(scale=1):
            image_input = gr.Image(type="pil", label="Upload Scan")
            # Preset label sets plus a "Custom" free-text mode.
            preset_radio = gr.Radio(
                choices=["Custom"] + list(LABEL_PRESETS.keys()),
                value="Imaging Modalities",
                label="Select Candidates"
            )
            # Hidden unless the "Custom" preset is selected (see update_vis).
            custom_labels = gr.Textbox(
                label="Custom Labels (One per line)",
                placeholder="pneumonia\nnormal",
                visible=False
            )
            classify_btn = gr.Button(" Analyze Image", variant="primary")
        with gr.Column(scale=1):
            output_label = gr.Label(num_top_classes=5, label="Confidence Scores")
            interpretation = gr.Markdown(label="Interpretation")

    def update_vis(choice):
        # Show the custom-labels textbox only in "Custom" mode.
        return gr.update(visible=(choice == "Custom"))

    preset_radio.change(fn=update_vis, inputs=[preset_radio], outputs=[custom_labels])
    classify_btn.click(
        fn=classify_image,
        inputs=[image_input, custom_labels, preset_radio],
        outputs=[output_label, interpretation]
    )
    gr.Markdown("### Try an Example (Click one to run)")
    # Example images are expected to live next to this script (example_*.jpg);
    # caching is off so each click runs a live inference.
    gr.Examples(
        examples=[
            ["example_0.jpg", "", "Imaging Modalities"],
            ["example_1.jpg", "", "Anatomical Regions"],
            ["example_2.jpg", "chest x-ray\nbrain MRI\nknee scan", "Custom"]
        ],
        inputs=[image_input, custom_labels, preset_radio],
        outputs=[output_label, interpretation],
        fn=classify_image,
        cache_examples=False,
    )
    gr.Markdown("---")
    gr.Markdown(" **Disclaimer:** For research/demo purposes only. Not for clinical use.")
# Launch only when executed as a script, so importing this module (e.g. for
# tests or reuse of classify_image) does not start the web server.
if __name__ == "__main__":
    print(" Launching App...")
    # share=True requests a public tunnel link (ignored on HF Spaces);
    # debug=True blocks the process and streams server logs to the console.
    demo.launch(share=True, debug=True)