| import os
|
| import json
|
| import torch
|
| from fastapi import FastAPI, HTTPException
|
| from pydantic import BaseModel
|
| from transformers import AutoProcessor, AutoModelForImageTextToText
|
| from fastapi.middleware.cors import CORSMiddleware
|
|
|
|
|
|
|
| MODEL_ID = "google/medgemma-1.5-4b-it"
|
|
|
|
|
| HF_TOKEN = os.environ.get("HF_TOKEN")
|
|
|
| app = FastAPI(
|
| title="MedGemma Radiology API",
|
| description="FastAPI service for analyzing radiology reports using MedGemma.",
|
| version="1.0.0"
|
| )
|
|
|
|
|
| app.add_middleware(
|
| CORSMiddleware,
|
| allow_origins=["*"],
|
| allow_methods=["*"],
|
| allow_headers=["*"],
|
| )
|
|
|
| processor = None
|
| model = None
|
|
|
| @app.on_event("startup")
|
| def load_model():
|
| global processor, model
|
| print(f"Loading processor and model {MODEL_ID}...")
|
| try:
|
|
|
| device = "cuda" if torch.cuda.is_available() else "cpu"
|
|
|
| processor = AutoProcessor.from_pretrained(MODEL_ID, token=HF_TOKEN)
|
| model = AutoModelForImageTextToText.from_pretrained(
|
| MODEL_ID,
|
| torch_dtype=torch.bfloat16,
|
| device_map=device,
|
| low_cpu_mem_usage=True,
|
| token=HF_TOKEN
|
| )
|
| model.eval()
|
| print(f"Model loaded successfully on {device}.")
|
| except Exception as e:
|
| print(f"Error loading model: {e}")
|
| print("Make sure you have set the HF_TOKEN environment variable correctly and accepted the model license.")
|
|
|
| class RadiologyCase(BaseModel):
|
| case_description: str
|
|
|
| class AnalysisResult(BaseModel):
|
| diagnosis: str
|
| recommendations: str
|
| urgency_level: str
|
| raw_response: str = None
|
|
|
|
|
| SYSTEM_PROMPT = """أنت الآن "مساعد تشخيص إشعاعي ذكي" متطور. مهمتك هي تحليل النصوص الواردة إليك والتي تصف نتائج صور الأشعة (X-ray, CT, MRI).
|
|
|
| قواعد العمل:
|
| 1. التخصص: ركز فقط على المصطلحات الطبية الإشعاعية (مثل Opacity, Radiolucency, Fracture, Lesion).
|
| 2. الهيكلية: يجب أن يكون ردك منظماً (النتائج الأساسية، التشخيص المحتمل، التوصيات).
|
| 3. الدقة: إذا كانت الحالة طارئة (مثل نزيف أو كسر مضاعف)، ابدأ بردك واجعل مستوى الحالة "حالة طارئة - Urgent".
|
| 4. التحذير: أضف دائماً في التوصيات أن هذا التحليل هو "رأي استشاري ذكي" ويجب مراجعته من قبل طبيب أشعة مختص.
|
| 5. اللغة: أجب باللغة العربية الطبية الرصينة.
|
|
|
| مهم جداً: قم بالرد باستخدام صيغة JSON صحيحة تحتوي على المفاتيح التالية فقط:
|
| {
|
| "diagnosis": "التشخيص المحتمل والنتائج الأساسية",
|
| "recommendations": "التوصيات والتحذير",
|
| "urgency_level": "مستوى الحالة (مثلاً: حالة طارئة - Urgent أو عادية - Normal)"
|
| }"""
|
|
|
| @app.post("/analyze-radiology", response_model=AnalysisResult)
|
| async def analyze_report(case: RadiologyCase):
|
| if not model or not processor:
|
| raise HTTPException(status_code=503, detail="The AI model is currently loading or failed to load. Please try again later.")
|
|
|
| try:
|
|
|
| user_text = f"{SYSTEM_PROMPT}\n\nنص التقرير أو الحالة:\n{case.case_description}"
|
| messages = [
|
| {
|
| "role": "user",
|
| "content": [
|
| {"type": "text", "text": user_text}
|
| ]
|
| }
|
| ]
|
|
|
|
|
| inputs = processor.apply_chat_template(
|
| messages, add_generation_prompt=True, tokenize=True,
|
| return_dict=True, return_tensors="pt"
|
| ).to(model.device, dtype=torch.bfloat16)
|
|
|
| input_len = inputs["input_ids"].shape[-1]
|
|
|
|
|
| with torch.inference_mode():
|
| generation = model.generate(
|
| **inputs,
|
| max_new_tokens=1024,
|
| do_sample=True,
|
| temperature=0.2,
|
| top_p=0.9
|
| )
|
|
|
| generation_output = generation[0][input_len:]
|
|
|
| decoded = processor.decode(generation_output, skip_special_tokens=True)
|
| raw_output = decoded.strip()
|
|
|
|
|
| clean_json = raw_output
|
| if clean_json.startswith("```json"):
|
| clean_json = clean_json.replace("```json", "", 1)
|
| if clean_json.endswith("```"):
|
| clean_json = clean_json[:-3]
|
| clean_json = clean_json.strip()
|
|
|
|
|
| try:
|
| parsed_data = json.loads(clean_json)
|
| except json.JSONDecodeError:
|
|
|
| is_urgent = "Urgent" in raw_output or "طارئة" in raw_output
|
| parsed_data = {
|
| "diagnosis": raw_output[:500] + ("..." if len(raw_output)>500 else ""),
|
| "recommendations": "تنبيه: لم يقم الموديل بإرجاع هيكل JSON صحيح. هذا التحليل هو رأي استشاري ذكي ويجب مراجعته من قبل طبيب أشعة مختص.",
|
| "urgency_level": "حالة طارئة - Urgent" if is_urgent else "عادية - Normal"
|
| }
|
|
|
| return AnalysisResult(
|
| diagnosis=parsed_data.get("diagnosis", "غير محدد"),
|
| recommendations=parsed_data.get("recommendations", "غير محدد"),
|
| urgency_level=parsed_data.get("urgency_level", "غير محدد"),
|
| raw_response=raw_output
|
| )
|
|
|
| except Exception as e:
|
| raise HTTPException(status_code=500, detail=f"Inference error: {str(e)}")
|
|
|
| @app.get("/")
|
| def health_check():
|
| return {
|
| "status": "Online",
|
| "model": MODEL_ID,
|
| "message": "Welcome to MedGemma Radiology API"
|
| }
|
|
|