File size: 3,122 Bytes
6fbc978 d672951 8e99010 d672951 f91a057 1739e75 d672951 bafe93c f91a057 8e99010 a9e2b5c d672951 1739e75 a9e2b5c d672951 f91a057 8e99010 d672951 f91a057 1739e75 d672951 6fbc978 1739e75 f91a057 1739e75 8e99010 d672951 1739e75 8e99010 1739e75 8e99010 1739e75 d672951 8e99010 f91a057 8e99010 d672951 8e99010 d672951 8e99010 1739e75 a9e2b5c bafe93c 6fbc978 d672951 a9e2b5c d672951 a9e2b5c 6fbc978 ea5ac7a 6fbc978 ea5ac7a 6fbc978 ea5ac7a 6fbc978 ea5ac7a | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 | import json
import os
from functools import lru_cache
from pathlib import Path
from typing import Dict, List, Tuple
import psutil
import torch
import gradio as gr
from transformers import AutoModelForZeroShotImageClassification, AutoProcessor
from utils.cache_manager import cached_inference, configure_cache
from utils.modality_router import detect_modality
# Directory containing this module; label files live in its "labels" subfolder.
BASE_DIR = Path(__file__).resolve().parent
LABEL_DIR = BASE_DIR / "labels"

# Model identifier and optional access token passed to from_pretrained below.
MODEL_ID = "google/medsiglip-448"
HF_TOKEN = os.getenv("HF_TOKEN")  # None when the variable is unset

# Limit torch to at most 4 threads. The physical core count is preferred;
# the logical count, then 1, are fallbacks when psutil returns None/0.
physical_cores = psutil.cpu_count(logical=False) or psutil.cpu_count() or 1
torch.set_num_threads(min(physical_cores, 4))

# Run on CUDA when it is available, otherwise on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# The processor and model are loaded once, at import time.
processor = AutoProcessor.from_pretrained(
    MODEL_ID,
    token=HF_TOKEN,
    use_fast=True,
)
model = AutoModelForZeroShotImageClassification.from_pretrained(
    MODEL_ID,
    token=HF_TOKEN,
    dtype=torch.float32,
).to(device)
model.eval()  # evaluation mode; this module only runs inference

# Hand the loaded objects to the cache manager.
# NOTE(review): presumably cached_inference() relies on this call having
# happened first — confirm in utils.cache_manager.
configure_cache(model, processor)
# Alternate label file (inside LABEL_DIR) for a detected modality, consulted
# by get_candidate_labels when "<modality>_labels.json" does not exist.
LABEL_OVERRIDES = {
    "xray": "chest_labels.json",
    "mri": "brain_labels.json",
}
@lru_cache(maxsize=None)
def load_labels(file_name: str) -> List[str]:
    """Parse a JSON label file located in ``LABEL_DIR``.

    Successful results are memoized per ``file_name`` with no size limit, so
    repeated calls for the same name return the same cached object without
    touching the disk again.

    Args:
        file_name: Name of the JSON file, resolved relative to ``LABEL_DIR``.

    Returns:
        The decoded JSON document. It is annotated as a list of label
        strings, but the content is not validated here.

    Raises:
        OSError: If the file cannot be opened (e.g. it does not exist).
        json.JSONDecodeError: If the file is not valid JSON.
    """
    raw_text = (LABEL_DIR / file_name).read_text(encoding="utf-8")
    return json.loads(raw_text)
def get_candidate_labels(image_path: str) -> Tuple[str, ...]:
    """Pick the label set that matches the modality detected for an image.

    Label files are tried in this order, and the first one that exists in
    ``LABEL_DIR`` is used:

    1. ``<modality>_labels.json``
    2. the file named in ``LABEL_OVERRIDES`` for that modality, if any
    3. ``general_labels.json``

    Args:
        image_path: Path of the image, passed to ``detect_modality``.

    Returns:
        The labels from the selected file as a tuple.
    """
    modality = detect_modality(image_path)
    general_file = LABEL_DIR / "general_labels.json"
    modality_file = LABEL_DIR / f"{modality}_labels.json"

    if modality_file.exists():
        selected = modality_file
    else:
        # The override table is consulted only when the modality-named file
        # is missing.
        override_name = LABEL_OVERRIDES.get(modality)
        fallback = LABEL_DIR / override_name if override_name else modality_file
        selected = fallback if fallback.exists() else general_file

    # load_labels resolves names against LABEL_DIR itself, so pass the bare
    # file name rather than the full path.
    return tuple(load_labels(selected.name))
def classify_medical_image(image_path: str) -> Dict[str, float]:
    """Score an image against its candidate labels and keep the best five.

    Args:
        image_path: Path to the uploaded image. A falsy value (``None`` or
            an empty string) means no image was supplied.

    Returns:
        A mapping from label to score as a Python float, holding at most
        five entries in descending score order. Empty when no image was
        given or when inference produced no scores.
    """
    top_predictions: Dict[str, float] = {}
    if not image_path:
        return top_predictions

    labels = get_candidate_labels(image_path)
    label_scores = cached_inference(image_path, labels)
    if not label_scores:
        return top_predictions

    # Pair each label with its score, then order by score, highest first.
    # The sort is stable, so tied labels keep their original relative order.
    ranked = list(zip(labels, label_scores))
    ranked.sort(key=lambda pair: pair[1], reverse=True)

    for label, score in ranked[:5]:
        top_predictions[label] = float(score)
    return top_predictions
# Gradio UI: one image upload, handed to classify_medical_image as a file
# path (type="filepath"), with the returned scores shown in a Label widget.
demo = gr.Interface(
    fn=classify_medical_image,
    inputs=gr.Image(type="filepath", label="Upload Medical Image"),
    outputs=gr.Label(num_top_classes=5, label="Top Predictions"),
    title="MedSigLIP Smart Medical Classifier",
    description="Zero-shot model with automatic label filtering for different modalities.",
)
if __name__ == "__main__":
    # Launch settings are read from the environment, with defaults for each.
    enabled_values = {"1", "true", "yes"}

    host = os.getenv("SERVER_NAME", "0.0.0.0")
    # SERVER_PORT wins over PORT; an unset or empty value falls through to
    # the next candidate and finally to 7860.
    port_text = os.getenv("SERVER_PORT") or os.getenv("PORT") or "7860"
    share_requested = os.getenv("GRADIO_SHARE", "false").lower() in enabled_values
    queue_requested = os.getenv("GRADIO_QUEUE", "false").lower() in enabled_values

    # Queueing is opt-in; otherwise the interface is launched as-is.
    launch_target = demo
    if queue_requested:
        launch_target = demo.queue()

    launch_target.launch(
        server_name=host,
        server_port=int(port_text),
        share=share_requested,
        show_api=False,
    )
|