xray / app.py
fokan's picture
Upload 3 files
ea5ac7a verified
import json
import os
from functools import lru_cache
from pathlib import Path
from typing import Dict, List, Tuple
import psutil
import torch
import gradio as gr
from transformers import AutoModelForZeroShotImageClassification, AutoProcessor
from utils.cache_manager import cached_inference, configure_cache
from utils.modality_router import detect_modality
# Directory containing this file; label files are looked up in its "labels" subfolder.
BASE_DIR = Path(__file__).resolve().parent
LABEL_DIR = BASE_DIR / "labels"
# Hugging Face model identifier used for both the processor and the model.
MODEL_ID = "google/medsiglip-448"
# Access token read from the environment; None when HF_TOKEN is unset.
HF_TOKEN = os.getenv("HF_TOKEN")
# Prefer the physical core count, fall back to the logical count, then to 1
# (the `or` chain skips a None/0 result from psutil).
physical_cores = psutil.cpu_count(logical=False) or psutil.cpu_count() or 1
# Cap the number of torch threads at 4, even on machines with more cores.
torch.set_num_threads(min(physical_cores, 4))
# Run on CUDA when torch reports it as available, otherwise on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Processor and model are loaded once at import time from MODEL_ID.
processor = AutoProcessor.from_pretrained(
MODEL_ID,
token=HF_TOKEN,
use_fast=True,
)
model = AutoModelForZeroShotImageClassification.from_pretrained(
MODEL_ID,
token=HF_TOKEN,
dtype=torch.float32,
).to(device)
# Put the model in evaluation mode; this app only runs inference.
model.eval()
# Hand the loaded model and processor to the cache helper module.
# NOTE(review): configure_cache is defined in utils.cache_manager and is not
# visible here — presumably it stores them for cached_inference; confirm.
configure_cache(model, processor)
# Maps a modality name to the label file used when "<modality>_labels.json"
# is not present in the labels directory.
LABEL_OVERRIDES = dict(
    xray="chest_labels.json",
    mri="brain_labels.json",
)
@lru_cache(maxsize=None)
def load_labels(file_name: str) -> List[str]:
    """Read and parse a label file from the labels directory.

    Results are memoized per ``file_name`` with an unbounded cache, so each
    file is read from disk at most once and repeated calls return the same
    parsed object.

    Args:
        file_name: Name of a JSON file inside ``LABEL_DIR``.

    Returns:
        The parsed JSON content of the file. The content is not validated
        here; the annotation reflects the expected list of label strings.
    """
    raw_text = (LABEL_DIR / file_name).read_text(encoding="utf-8")
    return json.loads(raw_text)
def get_candidate_labels(image_path: str) -> Tuple[str, ...]:
    """Pick the label set that matches the modality detected for an image.

    Lookup order inside ``LABEL_DIR``:

    1. ``"<modality>_labels.json"``, if that file exists;
    2. the file named in ``LABEL_OVERRIDES`` for the modality, if there is
       such an entry and the file exists;
    3. ``"general_labels.json"`` otherwise.

    Args:
        image_path: Path of the image, passed to ``detect_modality``.

    Returns:
        The labels of the selected file as a tuple.
    """
    modality = detect_modality(image_path)

    # First choice: a label file named directly after the modality.
    primary = LABEL_DIR / f"{modality}_labels.json"
    if primary.exists():
        return tuple(load_labels(primary.name))

    # Second choice: an explicitly mapped file, only if it is really there.
    mapped_name = LABEL_OVERRIDES.get(modality)
    if mapped_name:
        mapped = LABEL_DIR / mapped_name
        if mapped.exists():
            return tuple(load_labels(mapped.name))

    # Last resort: the general-purpose label set.
    return tuple(load_labels("general_labels.json"))
def classify_medical_image(image_path: str) -> Dict[str, float]:
    """Score an image against its candidate labels and keep the best five.

    Args:
        image_path: Path of the image to classify. A falsy value (for
            example ``None`` or ``""``) yields an empty result.

    Returns:
        A mapping from label to score for at most five labels, inserted in
        order of descending score. Empty when there is no image path or
        when ``cached_inference`` returns a falsy value.
    """
    if not image_path:
        return {}

    labels = get_candidate_labels(image_path)
    scores = cached_inference(image_path, labels)
    # NOTE(review): this relies on the truthiness of `scores`; that is fine
    # for a list/tuple, but would raise for a multi-element array or tensor —
    # confirm what cached_inference returns.
    if not scores:
        return {}

    # Pair each label with its score, then order the pairs best-first.
    ranked = list(zip(labels, scores))
    ranked.sort(key=lambda pair: pair[1], reverse=True)
    return {name: float(value) for name, value in ranked[:5]}
# Gradio UI: a single image upload wired to classify_medical_image, with the
# returned label/score mapping shown in a Label component.
# The image input uses type="filepath", matching the image_path parameter of
# classify_medical_image; the Label output is limited to 5 classes.
demo = gr.Interface(
fn=classify_medical_image,
inputs=gr.Image(type="filepath", label="Upload Medical Image"),
outputs=gr.Label(num_top_classes=5, label="Top Predictions"),
title="MedSigLIP Smart Medical Classifier",
description="Zero-shot model with automatic label filtering for different modalities.",
)
if __name__ == "__main__":
    # Launch settings come from the environment:
    #   SERVER_NAME              bind address (default "0.0.0.0")
    #   SERVER_PORT, then PORT   listening port (default "7860")
    #   GRADIO_SHARE             "1"/"true"/"yes" enables a share link
    #   GRADIO_QUEUE             "1"/"true"/"yes" enables the request queue
    truthy = {"1", "true", "yes"}
    env = os.environ

    host = env.get("SERVER_NAME", "0.0.0.0")
    # An unset or empty SERVER_PORT falls through to PORT, then to the default.
    port_text = env.get("SERVER_PORT") or env.get("PORT") or "7860"
    share_enabled = env.get("GRADIO_SHARE", "false").lower() in truthy
    queue_enabled = env.get("GRADIO_QUEUE", "false").lower() in truthy

    # Queueing is opt-in; without it the interface is launched directly.
    app_to_launch = demo
    if queue_enabled:
        app_to_launch = demo.queue()

    app_to_launch.launch(
        server_name=host,
        server_port=int(port_text),
        share=share_enabled,
        show_api=False,
    )