File size: 2,622 Bytes
cb6ba48
 
 
ebff197
 
cb6ba48
ebff197
cb6ba48
 
 
 
 
ebff197
cb6ba48
 
ebff197
cb6ba48
 
 
 
ebff197
cb6ba48
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ebff197
cb6ba48
 
 
 
 
 
 
 
ebff197
cb6ba48
 
 
 
ebff197
 
 
 
 
cb6ba48
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
import os
import json
import torch
import gradio as gr
from PIL import Image
from transformers import SiglipProcessor, SiglipModel

# ==========================================
# MODEL LOADING (happens once at startup)
# ==========================================
MODEL_ID = "google/medsiglip-448"
# Hugging Face access token taken from the environment; None when unset.
# NOTE(review): the model repo is presumably gated, in which case loading
# needs a valid token — confirm.
HF_TOKEN = os.environ.get("HF_TOKEN")

print("Loading MedSigLIP model...")
# Use the GPU whenever PyTorch reports one is available.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Loaded once at import time and reused by every inference call.
# eval() switches the module to inference mode (training=False) and returns
# the module itself, so it can be chained.
model = SiglipModel.from_pretrained(MODEL_ID, token=HF_TOKEN).to(device).eval()
processor = SiglipProcessor.from_pretrained(MODEL_ID, token=HF_TOKEN)
print(f"Model loaded on {device}.")


# ==========================================
# INFERENCE FUNCTION
# ==========================================
def classify(image: Image.Image, candidate_labels: str) -> str:
    """
    Zero-shot image classification with MedSigLIP.

    Scores ``image`` against every candidate text label and returns the
    labels ranked from most to least likely.

    Args:
        image: Input image (PIL). Converted to RGB before preprocessing.
        candidate_labels: Labels separated by line breaks, one per line.
            Surrounding whitespace and blank lines are ignored; ``None`` is
            treated as empty.

    Returns:
        JSON string. On success, a list of ``{"label": ..., "score": ...}``
        objects sorted by descending score. Scores are a softmax across the
        labels (summing to 1 before rounding), rounded to 5 decimals.
        On invalid input, an object ``{"error": "<message>"}``.
    """
    if image is None:
        return json.dumps({"error": "No image provided"})

    # splitlines() accepts "\n", "\r\n" and lone "\r" line endings.
    labels = [
        line.strip()
        for line in (candidate_labels or "").splitlines()
        if line.strip()
    ]
    if not labels:
        return json.dumps({"error": "No labels provided"})

    image = image.convert("RGB")

    # The image is passed once: the vision tower runs a single forward pass no
    # matter how many labels are scored, and the batch of N texts is compared
    # against that one embedding, giving logits_per_image of shape (1, N).
    inputs = processor(
        text=labels,
        images=[image],
        padding="max_length",
        return_tensors="pt",
    ).to(device)

    with torch.no_grad():
        outputs = model(**inputs)

    # Row 0 is the single image; one logit per candidate label.
    logits = outputs.logits_per_image[0]
    # NOTE(review): SigLIP is trained with a sigmoid loss; softmax here turns
    # the logits into a distribution over the given candidates (relative
    # ranking) rather than independent per-label probabilities — confirm this
    # is the intended scoring.
    probs = torch.softmax(logits, dim=0)

    # tolist() moves all scores to Python floats in one transfer instead of
    # one .item() device sync per label.
    results = sorted(
        (
            {"label": label, "score": round(prob, 5)}
            for label, prob in zip(labels, probs.tolist())
        ),
        key=lambda entry: entry["score"],
        reverse=True,
    )

    return json.dumps(results)


# ==========================================
# GRADIO INTERFACE
# ==========================================
# Example prompts pre-filled in the labels textbox, one per line.
_DEFAULT_LABELS = (
    "chest x-ray of a healthy lung",
    "chest x-ray showing pneumonia",
    "chest x-ray showing pneumothorax",
    "chest x-ray showing pleural effusion",
    "chest x-ray showing cardiomegaly",
    "chest x-ray showing tuberculosis",
    "chest x-ray showing a fracture",
)

# Inputs map positionally onto classify(image, candidate_labels); the JSON
# string it returns is shown verbatim in a textbox.
demo = gr.Interface(
    fn=classify,
    inputs=[
        gr.Image(type="pil", label="Medical Image"),
        gr.Textbox(
            label="Candidate Labels (one per line)",
            lines=7,
            value="\n".join(_DEFAULT_LABELS),
        ),
    ],
    outputs=gr.Textbox(label="Results (JSON)"),
    title="MedSigLIP Zero-Shot Classifier",
    description=(
        "Upload a medical image and provide candidate labels "
        "to classify it using MedSigLIP."
    ),
)

if __name__ == "__main__":
    # Serve the UI only when run as a script, not when imported.
    demo.launch()