Spaces:
Sleeping
Sleeping
File size: 7,487 Bytes
50e19d2 8a35df3 50e19d2 8a35df3 50e19d2 8a35df3 50e19d2 8a35df3 50e19d2 8a35df3 50e19d2 8a35df3 50e19d2 8a35df3 50e19d2 8a35df3 50e19d2 8a35df3 50e19d2 8a35df3 50e19d2 8a35df3 50e19d2 8a35df3 50e19d2 8a35df3 50e19d2 8a35df3 50e19d2 8a35df3 50e19d2 8a35df3 50e19d2 8a35df3 50e19d2 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 | import os
import io
import json
import torch
from fastapi import FastAPI, HTTPException, UploadFile, File, Form
from pydantic import BaseModel
from transformers import AutoProcessor, AutoModelForImageTextToText
from fastapi.middleware.cors import CORSMiddleware
from PIL import Image
# Define the model ID
MODEL_ID = "google/medgemma-1.5-4b-it"
# Get huggingface token for gated models
HF_TOKEN = os.environ.get("HF_TOKEN")
app = FastAPI(
title="MedGemma Radiology API",
description="FastAPI service for analyzing multimodal radiology cases (Image + Text) using MedGemma.",
version="1.0.0"
)
# Enable CORS
app.add_middleware(
CORSMiddleware,
allow_origins=["*"],
allow_methods=["*"],
allow_headers=["*"],
)
processor = None
model = None
@app.on_event("startup")
def load_model():
global processor, model
print(f"Loading processor and model {MODEL_ID}...")
try:
device = "cuda" if torch.cuda.is_available() else "cpu"
processor = AutoProcessor.from_pretrained(MODEL_ID, token=HF_TOKEN)
model = AutoModelForImageTextToText.from_pretrained(
MODEL_ID,
torch_dtype=torch.bfloat16,
device_map=device,
low_cpu_mem_usage=True,
token=HF_TOKEN
)
model.eval()
print(f"Model loaded successfully on {device}.")
except Exception as e:
print(f"Error loading model: {e}")
print("Make sure you have set the HF_TOKEN environment variable correctly and accepted the model license.")
class AnalysisResult(BaseModel):
diagnosis: str
recommendations: str
urgency_level: str
raw_response: str = None
# The "dماغ" or System Prompt
SYSTEM_PROMPT = """أنت الآن "مساعد تشخيص إشعاعي ذكي" متطور. مهمتك هي تحليل الصور والفحوصات الطبية المرفقة بالإضافة إلى النصوص الواردة والتي تصف حالة المريض.
قواعد العمل:
1. التخصص: ركز فقط على المصطلحات الطبية الإشعاعية (مثل Opacity, Radiolucency, Fracture, Lesion) عند وصف الصورة.
2. الهيكلية: يجب أن يكون ردك منظماً (النتائج الأساسية للصورة، التشخيص المحتمل، التوصيات).
3. الدقة: إذا كانت الحالة طارئة بناءً على الصورة (مثل كسر مضاعف أو استرواح الصدر)، اجعل مستوى الحالة "حالة طارئة - Urgent".
4. التحذير: أضف دائماً في التوصيات أن هذا التحليل هو "رأي استشاري ذكي" ويجب مراجعته من قبل طبيب أشعة مختص.
5. اللغة: أجب باللغة العربية الطبية الرصينة.
مهم جداً: قم بالرد باستخدام صيغة JSON صحيحة تحتوي على المفاتيح التالية فقط:
{
"diagnosis": "نتائج تحليل الصورة والتشخيص المحتمل",
"recommendations": "التوصيات والتحذير",
"urgency_level": "مستوى الحالة (مثلاً: حالة طارئة - Urgent أو عادية - Normal)"
}"""
@app.post("/analyze-radiology", response_model=AnalysisResult)
async def analyze_report(
case_description: str = Form(""),
image: UploadFile = File(None)
):
"""
Analyzes a radiology case. Accepts an optional text description and an optional image (X-Ray, MRI, etc).
At least one of them must be provided.
"""
if not model or not processor:
raise HTTPException(status_code=503, detail="The AI model is currently loading or failed to load. Please try again later.")
if not case_description and not image:
raise HTTPException(status_code=400, detail="يجب إرفاق صورة أو كتابة وصف للحالة على الأقل.")
try:
content = []
# 1. Process Image if provided
if image:
image_data = await image.read()
pil_image = Image.open(io.BytesIO(image_data)).convert("RGB")
content.append({"type": "image", "image": pil_image})
# 2. Process Text
user_text = SYSTEM_PROMPT + "\n\n"
if case_description:
user_text += f"وصف الحالة السريرية أو الأعراض:\n{case_description}\n\n"
if image:
user_text += "الرجاء تحليل الصورة الطبية المرفقة بناءً على القواعد أعلاه."
else:
user_text += "الرجاء تحليل الوصف الطبي أعلاه بناءً على القواعد أعلاه."
content.append({"type": "text", "text": user_text})
# 3. Create messages format
messages = [
{
"role": "user",
"content": content
}
]
# Format the prompt
inputs = processor.apply_chat_template(
messages, add_generation_prompt=True, tokenize=True,
return_dict=True, return_tensors="pt"
).to(model.device, dtype=torch.bfloat16)
input_len = inputs["input_ids"].shape[-1]
# Generate
with torch.inference_mode():
generation = model.generate(
**inputs,
max_new_tokens=1024,
do_sample=True,
temperature=0.2,
top_p=0.9
)
generation_output = generation[0][input_len:]
decoded = processor.decode(generation_output, skip_special_tokens=True)
raw_output = decoded.strip()
# Clean JSON markdown blocks
clean_json = raw_output
if clean_json.startswith("```json"):
clean_json = clean_json.replace("```json", "", 1)
if clean_json.endswith("```"):
clean_json = clean_json[:-3]
clean_json = clean_json.strip()
# Parse JSON
try:
parsed_data = json.loads(clean_json)
except json.JSONDecodeError:
is_urgent = "Urgent" in raw_output or "طارئة" in raw_output
parsed_data = {
"diagnosis": raw_output[:500] + ("..." if len(raw_output)>500 else ""),
"recommendations": "تنبيه: لم يقم الموديل بإرجاع هيكل JSON صحيح. هذا التحليل هو رأي استشاري ذكي ويجب مراجعته من قبل طبيب أشعة مختص.",
"urgency_level": "حالة طارئة - Urgent" if is_urgent else "عادية - Normal"
}
return AnalysisResult(
diagnosis=parsed_data.get("diagnosis", "غير محدد"),
recommendations=parsed_data.get("recommendations", "غير محدد"),
urgency_level=parsed_data.get("urgency_level", "غير محدد"),
raw_response=raw_output
)
except Exception as e:
raise HTTPException(status_code=500, detail=f"Inference error: {str(e)}")
@app.get("/")
def health_check():
return {
"status": "Online",
"model": MODEL_ID,
"vision_enabled": True,
"message": "Welcome to Multimodal MedGemma Radiology API"
}
|