Spaces:
Sleeping
Sleeping
| import os | |
| import json | |
| import torch | |
| import gradio as gr | |
| from PIL import Image | |
| from transformers import SiglipProcessor, SiglipModel | |
# ==========================================
# MODEL LOADING (happens once at startup)
# ==========================================
# Checkpoint identifier and the access token read from the environment
# (None when the HF_TOKEN variable is unset).
MODEL_ID = "google/medsiglip-448"
HF_TOKEN = os.getenv("HF_TOKEN")

print("Loading MedSigLIP model...")

# Run on the GPU when torch reports one, otherwise stay on the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Load the weights, move them to the chosen device and switch to eval mode.
model = SiglipModel.from_pretrained(MODEL_ID, token=HF_TOKEN).to(device).eval()
processor = SiglipProcessor.from_pretrained(MODEL_ID, token=HF_TOKEN)

print(f"Model loaded on {device}.")
# ==========================================
# INFERENCE FUNCTION
# ==========================================
def classify(image: Image.Image, candidate_labels: str) -> str:
    """Run zero-shot image classification with MedSigLIP.

    Args:
        image: Input image (PIL). ``None`` is accepted and reported as an
            error, since no image may have been supplied.
        candidate_labels: Candidate labels, one per line. Blank lines and
            surrounding whitespace are ignored.

    Returns:
        A JSON string. On success it encodes a list of
        ``{"label": ..., "score": ...}`` objects sorted by descending score,
        where the scores are a softmax over the supplied labels. When the
        image or the labels are missing it encodes ``{"error": ...}``.
    """
    if image is None:
        return json.dumps({"error": "No image provided"})

    # Stripping each line and dropping empties also discards leading/trailing
    # blank lines, so no outer strip is needed.
    raw_lines = (candidate_labels or "").split("\n")
    labels = [line.strip() for line in raw_lines if line.strip()]
    if not labels:
        return json.dumps({"error": "No labels provided"})

    # Encode the image once: logits_per_image already holds one column per
    # label, so repeating the image per label only multiplies the vision cost.
    inputs = processor(
        text=labels,
        images=image.convert("RGB"),
        padding="max_length",
        # Clip over-long labels to the tokenizer's maximum length instead of
        # letting them overflow the text encoder.
        truncation=True,
        return_tensors="pt",
    ).to(device)

    with torch.no_grad():
        outputs = model(**inputs)

    # Row 0 is the (only) image; normalise its logits across the labels.
    logits = outputs.logits_per_image[0]
    probs = torch.softmax(logits, dim=0)

    # A single tolist() moves all scores to the host at once.
    results = sorted(
        [
            {"label": label, "score": round(score, 5)}
            for label, score in zip(labels, probs.tolist())
        ],
        key=lambda item: item["score"],
        reverse=True,
    )
    return json.dumps(results)
# ==========================================
# GRADIO INTERFACE
# ==========================================
# Image upload, delivered to classify() as a PIL image.
_image_input = gr.Image(type="pil", label="Medical Image")

# Label box pre-filled with example prompts, one label per line.
_labels_input = gr.Textbox(
    label="Candidate Labels (one per line)",
    lines=7,
    value=(
        "chest x-ray of a healthy lung\n"
        "chest x-ray showing pneumonia\n"
        "chest x-ray showing pneumothorax\n"
        "chest x-ray showing pleural effusion\n"
        "chest x-ray showing cardiomegaly\n"
        "chest x-ray showing tuberculosis\n"
        "chest x-ray showing a fracture"
    ),
)

# The JSON string returned by classify() is shown verbatim.
_results_output = gr.Textbox(label="Results (JSON)")

demo = gr.Interface(
    fn=classify,
    inputs=[_image_input, _labels_input],
    outputs=_results_output,
    title="MedSigLIP Zero-Shot Classifier",
    description="Upload a medical image and provide candidate labels to classify it using MedSigLIP.",
)

# Start the web app only when this file is run as a script.
if __name__ == "__main__":
    demo.launch()