R-RAY-Ultra-5.1 / app.py
iraqigold's picture
Upload 3 files
9e3cf15 verified
import os
import json
import torch
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from transformers import AutoProcessor, AutoModelForImageTextToText
from fastapi.middleware.cors import CORSMiddleware
# Define the model ID
# MedGemma 1.5 4B fits in ~8GB RAM using bfloat16, perfect for HF CPU Spaces
MODEL_ID = "google/medgemma-1.5-4b-it"
# Get huggingface token for gated models
HF_TOKEN = os.environ.get("HF_TOKEN")
app = FastAPI(
title="MedGemma Radiology API",
description="FastAPI service for analyzing radiology reports using MedGemma.",
version="1.0.0"
)
# Enable CORS
app.add_middleware(
CORSMiddleware,
allow_origins=["*"],
allow_methods=["*"],
allow_headers=["*"],
)
processor = None
model = None
@app.on_event("startup")
def load_model():
global processor, model
print(f"Loading processor and model {MODEL_ID}...")
try:
# Check deployment environment device
device = "cuda" if torch.cuda.is_available() else "cpu"
processor = AutoProcessor.from_pretrained(MODEL_ID, token=HF_TOKEN)
model = AutoModelForImageTextToText.from_pretrained(
MODEL_ID,
torch_dtype=torch.bfloat16, # Optimized for reasonable RAM usage
device_map=device,
low_cpu_mem_usage=True,
token=HF_TOKEN
)
model.eval()
print(f"Model loaded successfully on {device}.")
except Exception as e:
print(f"Error loading model: {e}")
print("Make sure you have set the HF_TOKEN environment variable correctly and accepted the model license.")
class RadiologyCase(BaseModel):
case_description: str
class AnalysisResult(BaseModel):
diagnosis: str
recommendations: str
urgency_level: str
raw_response: str = None # Included internally for debugging
# The "dماغ" or System Prompt
SYSTEM_PROMPT = """أنت الآن "مساعد تشخيص إشعاعي ذكي" متطور. مهمتك هي تحليل النصوص الواردة إليك والتي تصف نتائج صور الأشعة (X-ray, CT, MRI).
قواعد العمل:
1. التخصص: ركز فقط على المصطلحات الطبية الإشعاعية (مثل Opacity, Radiolucency, Fracture, Lesion).
2. الهيكلية: يجب أن يكون ردك منظماً (النتائج الأساسية، التشخيص المحتمل، التوصيات).
3. الدقة: إذا كانت الحالة طارئة (مثل نزيف أو كسر مضاعف)، ابدأ بردك واجعل مستوى الحالة "حالة طارئة - Urgent".
4. التحذير: أضف دائماً في التوصيات أن هذا التحليل هو "رأي استشاري ذكي" ويجب مراجعته من قبل طبيب أشعة مختص.
5. اللغة: أجب باللغة العربية الطبية الرصينة.
مهم جداً: قم بالرد باستخدام صيغة JSON صحيحة تحتوي على المفاتيح التالية فقط:
{
"diagnosis": "التشخيص المحتمل والنتائج الأساسية",
"recommendations": "التوصيات والتحذير",
"urgency_level": "مستوى الحالة (مثلاً: حالة طارئة - Urgent أو عادية - Normal)"
}"""
@app.post("/analyze-radiology", response_model=AnalysisResult)
async def analyze_report(case: RadiologyCase):
if not model or not processor:
raise HTTPException(status_code=503, detail="The AI model is currently loading or failed to load. Please try again later.")
try:
# Combine System prompt with user case
user_text = f"{SYSTEM_PROMPT}\n\nنص التقرير أو الحالة:\n{case.case_description}"
messages = [
{
"role": "user",
"content": [
{"type": "text", "text": user_text}
]
}
]
# Format the prompt
inputs = processor.apply_chat_template(
messages, add_generation_prompt=True, tokenize=True,
return_dict=True, return_tensors="pt"
).to(model.device, dtype=torch.bfloat16)
input_len = inputs["input_ids"].shape[-1]
# Generate with optimized settings
with torch.inference_mode():
generation = model.generate(
**inputs,
max_new_tokens=1024,
do_sample=True,
temperature=0.2, # Conservative temp for medical accuracy
top_p=0.9
)
# Exclude the input prompt from generation output
generation_output = generation[0][input_len:]
decoded = processor.decode(generation_output, skip_special_tokens=True)
raw_output = decoded.strip()
# Helper: Clean out markdown block delimiters if model generated them
clean_json = raw_output
if clean_json.startswith("```json"):
clean_json = clean_json.replace("```json", "", 1)
if clean_json.endswith("```"):
clean_json = clean_json[:-3]
clean_json = clean_json.strip()
# Parse JSON
try:
parsed_data = json.loads(clean_json)
except json.JSONDecodeError:
# Fallback if model doesn't strictly adhere to JSON outline
is_urgent = "Urgent" in raw_output or "طارئة" in raw_output
parsed_data = {
"diagnosis": raw_output[:500] + ("..." if len(raw_output)>500 else ""),
"recommendations": "تنبيه: لم يقم الموديل بإرجاع هيكل JSON صحيح. هذا التحليل هو رأي استشاري ذكي ويجب مراجعته من قبل طبيب أشعة مختص.",
"urgency_level": "حالة طارئة - Urgent" if is_urgent else "عادية - Normal"
}
return AnalysisResult(
diagnosis=parsed_data.get("diagnosis", "غير محدد"),
recommendations=parsed_data.get("recommendations", "غير محدد"),
urgency_level=parsed_data.get("urgency_level", "غير محدد"),
raw_response=raw_output
)
except Exception as e:
raise HTTPException(status_code=500, detail=f"Inference error: {str(e)}")
@app.get("/")
def health_check():
return {
"status": "Online",
"model": MODEL_ID,
"message": "Welcome to MedGemma Radiology API"
}