File size: 3,122 Bytes
6fbc978
d672951
 
 
8e99010
d672951
f91a057
1739e75
d672951
 
bafe93c
f91a057
8e99010
 
a9e2b5c
d672951
 
 
1739e75
a9e2b5c
d672951
 
f91a057
 
8e99010
d672951
 
f91a057
 
 
 
 
1739e75
d672951
 
6fbc978
1739e75
 
 
f91a057
 
1739e75
8e99010
 
 
 
d672951
 
 
 
 
 
 
 
1739e75
8e99010
 
 
 
 
 
 
 
 
1739e75
8e99010
1739e75
 
d672951
 
 
 
8e99010
f91a057
8e99010
 
d672951
 
8e99010
 
d672951
8e99010
1739e75
a9e2b5c
 
bafe93c
6fbc978
 
 
d672951
a9e2b5c
 
d672951
a9e2b5c
6fbc978
 
 
ea5ac7a
6fbc978
ea5ac7a
 
 
 
 
 
6fbc978
 
ea5ac7a
6fbc978
ea5ac7a
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
import json
import os
from functools import lru_cache
from pathlib import Path
from typing import Dict, List, Tuple

import psutil
import torch
import gradio as gr
from transformers import AutoModelForZeroShotImageClassification, AutoProcessor

from utils.cache_manager import cached_inference, configure_cache
from utils.modality_router import detect_modality


# --- Paths and model configuration -----------------------------------------
# Label JSON files live in a "labels" directory next to this script.
BASE_DIR = Path(__file__).resolve().parent
LABEL_DIR = BASE_DIR / "labels"
MODEL_ID = "google/medsiglip-448"


# Hugging Face access token for the gated MedSigLIP checkpoint.
# NOTE(review): if unset this is None; from_pretrained will then fail on a
# gated repo — confirm deployment always provides HF_TOKEN.
HF_TOKEN = os.getenv("HF_TOKEN")

# Cap torch intra-op threads at the physical core count (max 4) to avoid
# oversubscription on small CPU instances. Falls back to 1 if psutil
# cannot determine the core count.
physical_cores = psutil.cpu_count(logical=False) or psutil.cpu_count() or 1
torch.set_num_threads(min(physical_cores, 4))

# Prefer GPU when available; otherwise run inference on CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load the processor (fast tokenizer/image-processor variant) and the
# zero-shot classification model once at import time so every request
# reuses the same instances.
processor = AutoProcessor.from_pretrained(
    MODEL_ID,
    token=HF_TOKEN,
    use_fast=True,
)
model = AutoModelForZeroShotImageClassification.from_pretrained(
    MODEL_ID,
    token=HF_TOKEN,
    dtype=torch.float32,
).to(device)
model.eval()  # inference only — disable dropout/batch-norm training behavior

# Hand the loaded model/processor to the project's inference cache layer.
configure_cache(model, processor)


# Modality -> label-file overrides, used when no "<modality>_labels.json"
# file exists for the detected modality (e.g. "xray" maps to the chest
# label set instead of a nonexistent "xray_labels.json").
LABEL_OVERRIDES = {
    "xray": "chest_labels.json",
    "mri": "brain_labels.json",
}


@lru_cache(maxsize=None)
def load_labels(file_name: str) -> List[str]:
    """Parse and memoize the JSON label list stored under ``LABEL_DIR``.

    Cached per file name, so each label file is read from disk at most once
    per process. The key space is bounded by the handful of label files.
    """
    raw_text = (LABEL_DIR / file_name).read_text(encoding="utf-8")
    return json.loads(raw_text)


def get_candidate_labels(image_path: str) -> Tuple[str, ...]:
    """Return the candidate label set matching the image's detected modality.

    Resolution order: ``<modality>_labels.json`` → a configured override in
    ``LABEL_OVERRIDES`` → the catch-all ``general_labels.json``.
    """
    modality = detect_modality(image_path)

    chosen = LABEL_DIR / f"{modality}_labels.json"
    if not chosen.exists():
        # No direct file for this modality; check the override table.
        override_name = LABEL_OVERRIDES.get(modality)
        if override_name:
            chosen = LABEL_DIR / override_name
    if not chosen.exists():
        # Last resort: the generic label set.
        chosen = LABEL_DIR / "general_labels.json"

    # A tuple is hashable, which keeps the result usable as a cache key.
    return tuple(load_labels(chosen.name))


def classify_medical_image(image_path: str) -> Dict[str, float]:
    """Run zero-shot classification and return the top five label scores.

    Returns an empty dict when no image path is given or when inference
    yields no scores.
    """
    if not image_path:
        return {}

    labels = get_candidate_labels(image_path)
    raw_scores = cached_inference(image_path, labels)
    if not raw_scores:
        return {}

    # Pair each label with its score, rank descending, and keep the top 5.
    ranked = sorted(zip(labels, raw_scores), key=lambda pair: pair[1], reverse=True)
    return {name: float(value) for name, value in ranked[:5]}


# Gradio UI: file-path image input -> label widget showing the top five
# predictions from classify_medical_image.
demo = gr.Interface(
    fn=classify_medical_image,
    inputs=gr.Image(type="filepath", label="Upload Medical Image"),
    outputs=gr.Label(num_top_classes=5, label="Top Predictions"),
    title="MedSigLIP Smart Medical Classifier",
    description="Zero-shot model with automatic label filtering for different modalities.",
)


if __name__ == "__main__":
    # All launch settings come from the environment so the same script runs
    # locally and inside containers (e.g. Hugging Face Spaces).
    truthy = {"1", "true", "yes"}

    server_name = os.getenv("SERVER_NAME", "0.0.0.0")
    # SERVER_PORT wins over PORT; default to Gradio's conventional 7860.
    server_port = int(os.getenv("SERVER_PORT") or os.getenv("PORT") or "7860")
    share_enabled = os.getenv("GRADIO_SHARE", "false").lower() in truthy
    queue_enabled = os.getenv("GRADIO_QUEUE", "false").lower() in truthy

    # Enable request queueing only when explicitly requested.
    app_to_launch = demo.queue() if queue_enabled else demo
    app_to_launch.launch(
        server_name=server_name,
        server_port=server_port,
        share=share_enabled,
        show_api=False,
    )