# BFS / app.py
# VladyslavSh's picture
# Update app.py
# cb6ba48 verified
import os
import json
import torch
import gradio as gr
from PIL import Image
from transformers import SiglipProcessor, SiglipModel
# ==========================================
# MODEL LOADING (happens once at startup)
# ==========================================
# Checkpoint identifier passed to both ``from_pretrained`` calls below.
MODEL_ID = "google/medsiglip-448"
# Access token read from the environment; ``None`` when the variable is unset.
HF_TOKEN = os.getenv("HF_TOKEN")

print("Loading MedSigLIP model...")

# Use the GPU whenever CUDA is available, otherwise stay on the CPU.
device = "cpu"
if torch.cuda.is_available():
    device = "cuda"

model = SiglipModel.from_pretrained(MODEL_ID, token=HF_TOKEN)
model = model.to(device)
processor = SiglipProcessor.from_pretrained(MODEL_ID, token=HF_TOKEN)

# Switch to evaluation mode: this app only runs inference.
model.eval()
print(f"Model loaded on {device}.")
# ==========================================
# INFERENCE FUNCTION
# ==========================================
def classify(image: Image.Image, candidate_labels: str) -> str:
    """Run zero-shot image classification with MedSigLIP.

    Every non-blank line of ``candidate_labels`` is one candidate text label.
    The image is scored against each label and the scores are normalised
    with a softmax across the labels, so they sum to 1.

    Args:
        image: Input image (PIL). ``None`` is reported as an error result.
        candidate_labels: Candidate labels, one per line. Blank lines and
            surrounding whitespace are ignored; ``None`` is treated as empty.

    Returns:
        A JSON string. On success it encodes a list of
        ``{"label": ..., "score": ...}`` objects sorted by descending score,
        with scores rounded to 5 decimals. On invalid input it encodes
        ``{"error": "<message>"}``.
    """
    if image is None:
        return json.dumps({"error": "No image provided"})

    # ``or ""`` keeps a missing (None) textbox value on the JSON error path
    # instead of raising AttributeError.
    labels = [
        line.strip()
        for line in (candidate_labels or "").split("\n")
        if line.strip()
    ]
    if not labels:
        return json.dumps({"error": "No labels provided"})

    # Encode the image ONCE. Repeating it per label only produced identical
    # rows in ``logits_per_image`` (of which row 0 alone was used) at the
    # cost of one extra image-encoder pass per label.
    inputs = processor(
        text=labels,
        images=image.convert("RGB"),
        # Fixed-length padding, as in the Transformers SigLIP usage examples.
        padding="max_length",
        # Clip over-long labels to the tokenizer's maximum length rather
        # than feeding the model a longer sequence than it was padded for.
        truncation=True,
        return_tensors="pt",
    ).to(device)

    with torch.no_grad():
        outputs = model(**inputs)

    # One image in the batch -> row 0 holds one logit per candidate label.
    logits = outputs.logits_per_image[0]
    probs = torch.softmax(logits, dim=0)

    results = sorted(
        (
            {"label": label, "score": round(prob.item(), 5)}
            for label, prob in zip(labels, probs)
        ),
        key=lambda entry: entry["score"],
        reverse=True,
    )
    return json.dumps(results)
# ==========================================
# GRADIO INTERFACE
# ==========================================
# Input components, in the order of ``classify``'s parameters.
_image_input = gr.Image(type="pil", label="Medical Image")
_labels_input = gr.Textbox(
    label="Candidate Labels (one per line)",
    lines=7,
    value="chest x-ray of a healthy lung\nchest x-ray showing pneumonia\nchest x-ray showing pneumothorax\nchest x-ray showing pleural effusion\nchest x-ray showing cardiomegaly\nchest x-ray showing tuberculosis\nchest x-ray showing a fracture",
)

# ``classify`` returns a JSON string, shown verbatim in a text box.
_results_output = gr.Textbox(label="Results (JSON)")

demo = gr.Interface(
    fn=classify,
    inputs=[_image_input, _labels_input],
    outputs=_results_output,
    title="MedSigLIP Zero-Shot Classifier",
    description="Upload a medical image and provide candidate labels to classify it using MedSigLIP.",
)

# Start the web UI only when this file is executed as a script.
if __name__ == "__main__":
    demo.launch()