"""Gradio app for zero-shot medical image classification with MedSigLIP."""

import json
import os

import gradio as gr
import torch
from PIL import Image
from transformers import SiglipModel, SiglipProcessor

# ==========================================
# MODEL LOADING (happens once at startup)
# ==========================================
MODEL_ID = "google/medsiglip-448"
# Access token read from the environment; None when HF_TOKEN is not set.
HF_TOKEN = os.environ.get("HF_TOKEN")

# Pre-filled contents of the labels textbox (one candidate label per line).
DEFAULT_LABELS = "\n".join(
    (
        "chest x-ray of a healthy lung",
        "chest x-ray showing pneumonia",
        "chest x-ray showing pneumothorax",
        "chest x-ray showing pleural effusion",
        "chest x-ray showing cardiomegaly",
        "chest x-ray showing tuberculosis",
        "chest x-ray showing a fracture",
    )
)

print("Loading MedSigLIP model...")
device = "cuda" if torch.cuda.is_available() else "cpu"
model = SiglipModel.from_pretrained(MODEL_ID, token=HF_TOKEN).to(device)
processor = SiglipProcessor.from_pretrained(MODEL_ID, token=HF_TOKEN)
model.eval()
print(f"Model loaded on {device}.")


# ==========================================
# INFERENCE FUNCTION
# ==========================================
def _parse_labels(candidate_labels: str) -> list[str]:
    """Split the textbox contents into stripped, non-empty labels."""
    # `or ""` guards against a None value so the caller can report
    # "No labels provided" instead of raising AttributeError.
    stripped = (line.strip() for line in (candidate_labels or "").splitlines())
    return [label for label in stripped if label]


def classify(image: Image.Image, candidate_labels: str) -> str:
    """
    Zero-shot image classification with MedSigLIP.

    Args:
        image: Input image (PIL). May be None when nothing was uploaded.
        candidate_labels: One label per line; blank lines are ignored.

    Returns:
        JSON string. On success, a list of {"label", "score"} objects sorted
        by descending score, where the scores are a softmax over the
        candidate labels. On invalid input, an object with an "error" key.
    """
    if image is None:
        return json.dumps({"error": "No image provided"})

    labels = _parse_labels(candidate_labels)
    if not labels:
        return json.dumps({"error": "No labels provided"})

    # The image is passed once rather than duplicated per label: only the
    # first row of logits_per_image is used, so extra copies would just
    # repeat the vision-encoder work len(labels) times.
    inputs = processor(
        text=labels,
        images=image.convert("RGB"),
        padding="max_length",
        return_tensors="pt",
    ).to(device)

    with torch.no_grad():
        outputs = model(**inputs)

    # Row 0 holds this image's logit against every candidate label.
    logits = outputs.logits_per_image[0]
    # Normalise across the labels so the reported scores sum to 1.
    probs = torch.softmax(logits, dim=0).tolist()

    results = [
        {"label": label, "score": round(prob, 5)}
        for label, prob in zip(labels, probs)
    ]
    # Stable sort: labels with equal rounded scores keep their input order.
    results.sort(key=lambda item: item["score"], reverse=True)
    return json.dumps(results)


# ==========================================
# GRADIO INTERFACE
# ==========================================
demo = gr.Interface(
    fn=classify,
    inputs=[
        gr.Image(type="pil", label="Medical Image"),
        gr.Textbox(
            label="Candidate Labels (one per line)",
            lines=7,
            value=DEFAULT_LABELS,
        ),
    ],
    outputs=gr.Textbox(label="Results (JSON)"),
    title="MedSigLIP Zero-Shot Classifier",
    description="Upload a medical image and provide candidate labels to classify it using MedSigLIP.",
)

if __name__ == "__main__":
    demo.launch()