import json
import os
from functools import lru_cache
from pathlib import Path
from typing import Dict, List, Tuple

import gradio as gr
import psutil
import torch
from transformers import AutoModelForZeroShotImageClassification, AutoProcessor

from utils.cache_manager import cached_inference, configure_cache
from utils.modality_router import detect_modality

BASE_DIR = Path(__file__).resolve().parent
LABEL_DIR = BASE_DIR / "labels"
MODEL_ID = "google/medsiglip-448"
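
# google/medsiglip-448 is gated on the Hugging Face Hub, so an access token
# is read from the environment; without one the download may fail.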
HF_TOKEN = os.getenv("HF_TOKEN")
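
# Cap PyTorch's intra-op thread pool at the number of physical cores (at
# most 4) to avoid oversubscribing the CPU on small shared hosts.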
physical_cores = psutil.cpu_count(logical=False) or psutil.cpu_count() or 1
torch.set_num_threads(min(physical_cores, 4))
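
# Prefer the GPU when available; the model also runs on CPU, just more slowly.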
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
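
# Load the processor and model once at startup, keeping the weights in
# float32 on the selected device.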
processor = AutoProcessor.from_pretrained(
    MODEL_ID,
    token=HF_TOKEN,
    use_fast=True,
)
model = AutoModelForZeroShotImageClassification.from_pretrained(
    MODEL_ID,
    token=HF_TOKEN,
    dtype=torch.float32,
).to(device)
model.eval()
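
# Hand the loaded model and processor to the project-local cache layer so
# cached_inference() can reuse them (see utils/cache_manager).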
configure_cache(model, processor)
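
# Fallback label files for modalities whose default "<modality>_labels.json"
# file is missing from the labels directory.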
LABEL_OVERRIDES = {
    "xray": "chest_labels.json",
    "mri": "brain_labels.json",
}
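

# Each label file is expected to hold a flat JSON array of candidate label
# strings, e.g. ["pneumonia", "pleural effusion", "normal"] (illustrative
# values, not taken from the repository's actual label files).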
@lru_cache(maxsize=None)
def load_labels(file_name: str) -> List[str]:
    """Load a candidate-label list from LABEL_DIR, caching per file name."""
    label_path = LABEL_DIR / file_name
    with label_path.open("r", encoding="utf-8") as handle:
        return json.load(handle)


def get_candidate_labels(image_path: str) -> Tuple[str, ...]:
    """Choose the label set for the image's detected modality.

    Falls back to a per-modality override file, then to the general
    label file; returns a tuple so the label set is hashable.
    """
    modality = detect_modality(image_path)
    candidate_path = LABEL_DIR / f"{modality}_labels.json"
    if not candidate_path.exists():
        override = LABEL_OVERRIDES.get(modality)
        if override:
            candidate_path = LABEL_DIR / override
    if not candidate_path.exists():
        candidate_path = LABEL_DIR / "general_labels.json"

    return tuple(load_labels(candidate_path.name))


def classify_medical_image(image_path: str) -> Dict[str, float]:
    """Run zero-shot classification and return the top five label scores."""
    if not image_path:
        return {}

    candidate_labels = get_candidate_labels(image_path)
    scores = cached_inference(image_path, candidate_labels)

    if not scores:
        return {}

    # Pair each label with its score and keep the five highest.
    results = sorted(zip(candidate_labels, scores), key=lambda x: x[1], reverse=True)
    top_results = results[:5]

    return {label: float(score) for label, score in top_results}


demo = gr.Interface(
    fn=classify_medical_image,
    inputs=gr.Image(type="filepath", label="Upload Medical Image"),
    outputs=gr.Label(num_top_classes=5, label="Top Predictions"),
    title="MedSigLIP Smart Medical Classifier",
    description="Zero-shot classifier with automatic label filtering for different imaging modalities.",
)
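

# Server settings are environment-driven so the same script can run locally,
# in a container, or on hosts that inject a PORT variable.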
if __name__ == "__main__":
    server_name = os.getenv("SERVER_NAME", "0.0.0.0")
    port_env = os.getenv("SERVER_PORT") or os.getenv("PORT") or "7860"
    share_env = os.getenv("GRADIO_SHARE", "false").lower()
    queue_env = os.getenv("GRADIO_QUEUE", "false").lower()

    share_enabled = share_env in {"1", "true", "yes"}
    queue_enabled = queue_env in {"1", "true", "yes"}

    # demo.queue() configures Gradio's request queue and returns the app,
    # so this toggles queueing without changing the launch call below.
    app_to_launch = demo.queue() if queue_enabled else demo

    app_to_launch.launch(
        server_name=server_name,
        server_port=int(port_env),
        share=share_enabled,
        show_api=False,
    )