Spaces:
Runtime error
Runtime error
File size: 5,025 Bytes
831f30b 3c22ee7 94e3c71 3c22ee7 6531e0c 3c22ee7 823a4de 94e3c71 e9f8bf5 6531e0c e9f8bf5 94e3c71 3c22ee7 6531e0c 3c22ee7 e9f8bf5 3c22ee7 831f30b 6531e0c 94e3c71 831f30b 94e3c71 3c22ee7 94e3c71 6531e0c e9f8bf5 6531e0c 94e3c71 3c22ee7 6531e0c 94e3c71 3c22ee7 94e3c71 3c22ee7 6531e0c 3c22ee7 823a4de 3c22ee7 6531e0c 3c22ee7 6531e0c 3c22ee7 6531e0c | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 | import sys
import os
import json
import base64
import io

import torch
from fastapi import FastAPI, File, UploadFile, Form
from fastapi import HTTPException
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import RedirectResponse
from PIL import Image
from huggingface_hub import snapshot_download, login
from transformers import AutoProcessor, AutoModelForImageTextToText
# Optional Hugging Face Hub authentication, driven by the HF_TOKEN env var.
# Without a token we only warn: startup continues, but gated repos will not load.
HF_TOKEN = os.getenv("HF_TOKEN")
if not HF_TOKEN:
    print("WARNING: HF_TOKEN not set — gated models will fail.", flush=True)
else:
    login(token=HF_TOKEN)
    print("Authenticated with HF token.", flush=True)
# ---------------------------------------------------------------------------
# MedImageInsight — CLIP-style encoder for zero-shot label scoring
# ---------------------------------------------------------------------------
# Download the whole model repo. It provides both the model directory used
# below and the Python module (medimageinsightmodel) imported from it.
print("Downloading MedImageInsights repo...", flush=True)
repo_path = snapshot_download("lion-ai/MedImageInsights")
print(f"Downloaded to: {repo_path}", flush=True)
# Make the downloaded repo importable. This must run before the import below,
# which is why that import lives here instead of at the top of the file.
sys.path.insert(0, repo_path)
from medimageinsightmodel import MedImageInsight  # noqa: E402
# Model files are read from a dated subdirectory inside the downloaded repo.
model_dir = os.path.join(repo_path, "2024.09.27")
print("Loading MedImageInsight...", flush=True)
classifier = MedImageInsight(
    model_dir=model_dir,
    # NOTE(review): "insigt" looks like a typo, but presumably matches the
    # filename shipped in the repo — verify before "fixing" it.
    vision_model_name="medimageinsigt-v1.0.0.pt",
    language_model_name="language_model.pth",
)
# Load once at import time, before the FastAPI app is created further down.
classifier.load_model()
print("MedImageInsight ready.", flush=True)
# ---------------------------------------------------------------------------
# MedGemma — generative VLM for free-text image description
# ---------------------------------------------------------------------------
MEDGEMMA_ID = "google/medgemma-1.5-4b-it"

print("Loading MedGemma processor...", flush=True)
gemma_processor = AutoProcessor.from_pretrained(MEDGEMMA_ID)

# Weights are loaded in bfloat16 and placed automatically across available
# devices. nn.Module.eval() returns the module itself, so it is chained here
# to switch to inference mode (dropout etc. disabled) right after loading.
print("Loading MedGemma model (bfloat16)...", flush=True)
gemma_model = AutoModelForImageTextToText.from_pretrained(
    MEDGEMMA_ID,
    device_map="auto",
    torch_dtype=torch.bfloat16,
).eval()
print("MedGemma ready.", flush=True)
# ---------------------------------------------------------------------------
# FastAPI app
# ---------------------------------------------------------------------------
# Created only after both models above have loaded, so a served request
# implies startup finished.
app = FastAPI(title="Medical Image Analysis API")

# Fully permissive CORS: any origin, any method, any header.
app.add_middleware(
    CORSMiddleware,
    allow_headers=["*"],
    allow_methods=["*"],
    allow_origins=["*"],
)
def _encode_image(data: bytes) -> str:
    """Re-encode raw image bytes as a base64 PNG string for MedImageInsight.

    The bytes are decoded with PIL, converted to RGB, re-saved as PNG and
    base64-encoded. This is the format this module passes to
    ``classifier.predict``.

    Args:
        data: Raw bytes of an image file in any format PIL can open.

    Returns:
        Base64 text of the PNG. ``base64.encodebytes`` inserts a newline every
        76 output characters. Presumably the consumer's decoder tolerates
        that; it is kept unchanged from the original behavior.

    Raises:
        PIL.UnidentifiedImageError: if ``data`` is not a recognised image
            (a subclass of ``OSError``).
    """
    # Close the decoded source image as soon as the RGB copy exists rather
    # than leaving it for the garbage collector. convert() always returns a
    # new image, so ``rgb`` stays valid after the source is closed.
    with Image.open(io.BytesIO(data)) as src:
        rgb = src.convert("RGB")
    buf = io.BytesIO()
    rgb.save(buf, format="PNG")
    return base64.encodebytes(buf.getvalue()).decode("utf-8")
def _scores_to_list(scores: dict) -> list:
return [{"label": k, "score": round(float(v), 6)} for k, v in scores.items()]
# ---------------------------------------------------------------------------
# Endpoints
# ---------------------------------------------------------------------------
@app.get("/")
def root():
    """Redirect the bare service URL to the health-check endpoint."""
    return RedirectResponse("/health")
@app.get("/health")
def health():
    """Liveness probe that always reports ``{"status": "ok"}``.

    The app object exists only after both models have loaded at import
    time, so any response here means startup completed.
    """
    return dict(status="ok")
@app.post("/classify")
async def classify(
    image: UploadFile = File(...),
    labels: str = Form(...),
):
    """Zero-shot classification via MedImageInsight. Scores sum to ~1 (softmax).

    Form fields:
        image: The uploaded image file (any format PIL can decode).
        labels: A JSON array of candidate label strings, e.g. '["a", "b"]'.

    Returns:
        ``{"results": [{"label": str, "score": float}, ...]}``.

    Raises:
        HTTPException(400): if ``labels`` is not a non-empty JSON array of
            strings, or the upload cannot be decoded as an image.
    """
    # Validate client input up front so bad requests get a 400 instead of an
    # unhandled exception (500) from json.loads / PIL / the model.
    try:
        labels_list = json.loads(labels)
    except json.JSONDecodeError as err:
        raise HTTPException(
            status_code=400, detail=f"'labels' is not valid JSON: {err}"
        ) from err
    if (
        not isinstance(labels_list, list)
        or not labels_list
        or not all(isinstance(label, str) for label in labels_list)
    ):
        raise HTTPException(
            status_code=400,
            detail="'labels' must be a non-empty JSON array of strings.",
        )

    data = await image.read()
    try:
        img_b64 = _encode_image(data)
    except (OSError, ValueError) as err:
        # PIL.UnidentifiedImageError subclasses OSError.
        raise HTTPException(
            status_code=400, detail="Uploaded file is not a readable image."
        ) from err

    # NOTE: predict() is a synchronous call inside an async handler, so it
    # blocks the event loop (and every other request) while it runs.
    results = classifier.predict([img_b64], labels_list, multilabel=False)
    return {"results": _scores_to_list(results[0])}
@app.post("/multilabel")
async def multilabel(
    image: UploadFile = File(...),
    labels: str = Form(...),
):
    """Multi-label classification via MedImageInsight. Each score is independent (sigmoid).

    Form fields:
        image: The uploaded image file (any format PIL can decode).
        labels: A JSON array of label strings, e.g. '["a", "b"]'.

    Returns:
        ``{"results": [{"label": str, "score": float}, ...]}``.

    Raises:
        HTTPException(400): if ``labels`` is not a non-empty JSON array of
            strings, or the upload cannot be decoded as an image.
    """
    # Validate client input up front so bad requests get a 400 instead of an
    # unhandled exception (500) from json.loads / PIL / the model.
    try:
        labels_list = json.loads(labels)
    except json.JSONDecodeError as err:
        raise HTTPException(
            status_code=400, detail=f"'labels' is not valid JSON: {err}"
        ) from err
    if (
        not isinstance(labels_list, list)
        or not labels_list
        or not all(isinstance(label, str) for label in labels_list)
    ):
        raise HTTPException(
            status_code=400,
            detail="'labels' must be a non-empty JSON array of strings.",
        )

    data = await image.read()
    try:
        img_b64 = _encode_image(data)
    except (OSError, ValueError) as err:
        # PIL.UnidentifiedImageError subclasses OSError.
        raise HTTPException(
            status_code=400, detail="Uploaded file is not a readable image."
        ) from err

    # NOTE: predict() is a synchronous call inside an async handler, so it
    # blocks the event loop (and every other request) while it runs.
    results = classifier.predict([img_b64], labels_list, multilabel=True)
    return {"results": _scores_to_list(results[0])}
@app.post("/describe")
async def describe(
    image: UploadFile = File(...),
    prompt: str = Form(default="Describe the medical findings visible in this image."),
):
    """Free-text image description via MedGemma 1.5-4B.

    Form fields:
        image: The uploaded image file (any format PIL can decode).
        prompt: Instruction sent alongside the image. Defaults to a generic
            request to describe the visible medical findings.

    Returns:
        ``{"description": str}`` containing only the newly generated text;
        the prompt tokens are stripped before decoding.

    Raises:
        HTTPException(400): if the upload cannot be decoded as an image.
    """
    data = await image.read()
    try:
        # Close the decoded source image once the RGB copy exists.
        with Image.open(io.BytesIO(data)) as src:
            img = src.convert("RGB")
    except (OSError, ValueError) as err:
        # PIL.UnidentifiedImageError subclasses OSError.
        raise HTTPException(
            status_code=400, detail="Uploaded file is not a readable image."
        ) from err

    # Single-turn chat: one user message carrying the image and the prompt.
    messages = [
        {
            "role": "user",
            "content": [
                {"type": "image", "image": img},
                {"type": "text", "text": prompt},
            ],
        }
    ]
    # Tokenize the chat plus the assistant-turn prefix and move the tensors to
    # the model's device. BatchFeature.to applies the bfloat16 dtype only to
    # floating-point tensors (e.g. pixel values), not to integer token ids.
    inputs = gemma_processor.apply_chat_template(
        messages,
        add_generation_prompt=True,
        tokenize=True,
        return_dict=True,
        return_tensors="pt",
    ).to(gemma_model.device, dtype=torch.bfloat16)
    input_len = inputs["input_ids"].shape[-1]

    # NOTE: generate() is a synchronous call inside an async handler, so it
    # blocks the event loop (and every other request) until it finishes.
    with torch.inference_mode():
        generation = gemma_model.generate(
            **inputs,
            max_new_tokens=512,
            do_sample=False,  # greedy decoding, no sampling
        )
    # The output sequence starts with the prompt tokens; keep only new tokens.
    generation = generation[0][input_len:]
    description = gemma_processor.decode(generation, skip_special_tokens=True)
    return {"description": description}
|