File size: 6,624 Bytes
9e3cf15
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
import os
import json
import torch
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from transformers import AutoProcessor, AutoModelForImageTextToText
from fastapi.middleware.cors import CORSMiddleware

# Define the model ID
# MedGemma 1.5 4B fits in ~8GB RAM using bfloat16, perfect for HF CPU Spaces
MODEL_ID = "google/medgemma-1.5-4b-it"

# Get huggingface token for gated models
HF_TOKEN = os.environ.get("HF_TOKEN")

app = FastAPI(
    title="MedGemma Radiology API",
    description="FastAPI service for analyzing radiology reports using MedGemma.",
    version="1.0.0"
)

# Enable CORS
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
)

processor = None
model = None

@app.on_event("startup")
def load_model():
    global processor, model
    print(f"Loading processor and model {MODEL_ID}...")
    try:
        # Check deployment environment device
        device = "cuda" if torch.cuda.is_available() else "cpu"
        
        processor = AutoProcessor.from_pretrained(MODEL_ID, token=HF_TOKEN)
        model = AutoModelForImageTextToText.from_pretrained(
            MODEL_ID,
            torch_dtype=torch.bfloat16, # Optimized for reasonable RAM usage
            device_map=device,
            low_cpu_mem_usage=True,
            token=HF_TOKEN
        )
        model.eval()
        print(f"Model loaded successfully on {device}.")
    except Exception as e:
        print(f"Error loading model: {e}")
        print("Make sure you have set the HF_TOKEN environment variable correctly and accepted the model license.")

class RadiologyCase(BaseModel):
    case_description: str

class AnalysisResult(BaseModel):
    diagnosis: str
    recommendations: str
    urgency_level: str
    raw_response: str = None # Included internally for debugging

# The "dماغ" or System Prompt
SYSTEM_PROMPT = """أنت الآن "مساعد تشخيص إشعاعي ذكي" متطور. مهمتك هي تحليل النصوص الواردة إليك والتي تصف نتائج صور الأشعة (X-ray, CT, MRI).



قواعد العمل:

1. التخصص: ركز فقط على المصطلحات الطبية الإشعاعية (مثل Opacity, Radiolucency, Fracture, Lesion).

2. الهيكلية: يجب أن يكون ردك منظماً (النتائج الأساسية، التشخيص المحتمل، التوصيات).

3. الدقة: إذا كانت الحالة طارئة (مثل نزيف أو كسر مضاعف)، ابدأ بردك واجعل مستوى الحالة "حالة طارئة - Urgent".

4. التحذير: أضف دائماً في التوصيات أن هذا التحليل هو "رأي استشاري ذكي" ويجب مراجعته من قبل طبيب أشعة مختص.

5. اللغة: أجب باللغة العربية الطبية الرصينة.



مهم جداً: قم بالرد باستخدام صيغة JSON صحيحة تحتوي على المفاتيح التالية فقط:

{

  "diagnosis": "التشخيص المحتمل والنتائج الأساسية",

  "recommendations": "التوصيات والتحذير",

  "urgency_level": "مستوى الحالة (مثلاً: حالة طارئة - Urgent أو عادية - Normal)"

}"""

@app.post("/analyze-radiology", response_model=AnalysisResult)
async def analyze_report(case: RadiologyCase):
    if not model or not processor:
        raise HTTPException(status_code=503, detail="The AI model is currently loading or failed to load. Please try again later.")
        
    try:
        # Combine System prompt with user case
        user_text = f"{SYSTEM_PROMPT}\n\nنص التقرير أو الحالة:\n{case.case_description}"
        messages = [
            {
                "role": "user",
                "content": [
                    {"type": "text", "text": user_text}
                ]
            }
        ]
        
        # Format the prompt
        inputs = processor.apply_chat_template(
            messages, add_generation_prompt=True, tokenize=True,
            return_dict=True, return_tensors="pt"
        ).to(model.device, dtype=torch.bfloat16)

        input_len = inputs["input_ids"].shape[-1]
        
        # Generate with optimized settings
        with torch.inference_mode():
            generation = model.generate(
                **inputs, 
                max_new_tokens=1024, 
                do_sample=True, 
                temperature=0.2, # Conservative temp for medical accuracy
                top_p=0.9
            )
            # Exclude the input prompt from generation output
            generation_output = generation[0][input_len:]
            
        decoded = processor.decode(generation_output, skip_special_tokens=True)
        raw_output = decoded.strip()
        
        # Helper: Clean out markdown block delimiters if model generated them
        clean_json = raw_output
        if clean_json.startswith("```json"):
            clean_json = clean_json.replace("```json", "", 1)
        if clean_json.endswith("```"):
            clean_json = clean_json[:-3]
        clean_json = clean_json.strip()

        # Parse JSON
        try:
            parsed_data = json.loads(clean_json)
        except json.JSONDecodeError:
            # Fallback if model doesn't strictly adhere to JSON outline
            is_urgent = "Urgent" in raw_output or "طارئة" in raw_output
            parsed_data = {
                "diagnosis": raw_output[:500] + ("..." if len(raw_output)>500 else ""),
                "recommendations": "تنبيه: لم يقم الموديل بإرجاع هيكل JSON صحيح. هذا التحليل هو رأي استشاري ذكي ويجب مراجعته من قبل طبيب أشعة مختص.",
                "urgency_level": "حالة طارئة - Urgent" if is_urgent else "عادية - Normal"
            }

        return AnalysisResult(
            diagnosis=parsed_data.get("diagnosis", "غير محدد"),
            recommendations=parsed_data.get("recommendations", "غير محدد"),
            urgency_level=parsed_data.get("urgency_level", "غير محدد"),
            raw_response=raw_output
        )

    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Inference error: {str(e)}")

@app.get("/")
def health_check():
    return {
        "status": "Online",
        "model": MODEL_ID,
        "message": "Welcome to MedGemma Radiology API"
    }