import os import json import torch from fastapi import FastAPI, HTTPException from pydantic import BaseModel from transformers import AutoProcessor, AutoModelForImageTextToText from fastapi.middleware.cors import CORSMiddleware # Define the model ID # MedGemma 1.5 4B fits in ~8GB RAM using bfloat16, perfect for HF CPU Spaces MODEL_ID = "google/medgemma-1.5-4b-it" # Get huggingface token for gated models HF_TOKEN = os.environ.get("HF_TOKEN") app = FastAPI( title="MedGemma Radiology API", description="FastAPI service for analyzing radiology reports using MedGemma.", version="1.0.0" ) # Enable CORS app.add_middleware( CORSMiddleware, allow_origins=["*"], allow_methods=["*"], allow_headers=["*"], ) processor = None model = None @app.on_event("startup") def load_model(): global processor, model print(f"Loading processor and model {MODEL_ID}...") try: # Check deployment environment device device = "cuda" if torch.cuda.is_available() else "cpu" processor = AutoProcessor.from_pretrained(MODEL_ID, token=HF_TOKEN) model = AutoModelForImageTextToText.from_pretrained( MODEL_ID, torch_dtype=torch.bfloat16, # Optimized for reasonable RAM usage device_map=device, low_cpu_mem_usage=True, token=HF_TOKEN ) model.eval() print(f"Model loaded successfully on {device}.") except Exception as e: print(f"Error loading model: {e}") print("Make sure you have set the HF_TOKEN environment variable correctly and accepted the model license.") class RadiologyCase(BaseModel): case_description: str class AnalysisResult(BaseModel): diagnosis: str recommendations: str urgency_level: str raw_response: str = None # Included internally for debugging # The "dماغ" or System Prompt SYSTEM_PROMPT = """أنت الآن "مساعد تشخيص إشعاعي ذكي" متطور. مهمتك هي تحليل النصوص الواردة إليك والتي تصف نتائج صور الأشعة (X-ray, CT, MRI). قواعد العمل: 1. التخصص: ركز فقط على المصطلحات الطبية الإشعاعية (مثل Opacity, Radiolucency, Fracture, Lesion). 2. الهيكلية: يجب أن يكون ردك منظماً (النتائج الأساسية، التشخيص المحتمل، التوصيات). 3. الدقة: إذا كانت الحالة طارئة (مثل نزيف أو كسر مضاعف)، ابدأ بردك واجعل مستوى الحالة "حالة طارئة - Urgent". 4. التحذير: أضف دائماً في التوصيات أن هذا التحليل هو "رأي استشاري ذكي" ويجب مراجعته من قبل طبيب أشعة مختص. 5. اللغة: أجب باللغة العربية الطبية الرصينة. مهم جداً: قم بالرد باستخدام صيغة JSON صحيحة تحتوي على المفاتيح التالية فقط: { "diagnosis": "التشخيص المحتمل والنتائج الأساسية", "recommendations": "التوصيات والتحذير", "urgency_level": "مستوى الحالة (مثلاً: حالة طارئة - Urgent أو عادية - Normal)" }""" @app.post("/analyze-radiology", response_model=AnalysisResult) async def analyze_report(case: RadiologyCase): if not model or not processor: raise HTTPException(status_code=503, detail="The AI model is currently loading or failed to load. Please try again later.") try: # Combine System prompt with user case user_text = f"{SYSTEM_PROMPT}\n\nنص التقرير أو الحالة:\n{case.case_description}" messages = [ { "role": "user", "content": [ {"type": "text", "text": user_text} ] } ] # Format the prompt inputs = processor.apply_chat_template( messages, add_generation_prompt=True, tokenize=True, return_dict=True, return_tensors="pt" ).to(model.device, dtype=torch.bfloat16) input_len = inputs["input_ids"].shape[-1] # Generate with optimized settings with torch.inference_mode(): generation = model.generate( **inputs, max_new_tokens=1024, do_sample=True, temperature=0.2, # Conservative temp for medical accuracy top_p=0.9 ) # Exclude the input prompt from generation output generation_output = generation[0][input_len:] decoded = processor.decode(generation_output, skip_special_tokens=True) raw_output = decoded.strip() # Helper: Clean out markdown block delimiters if model generated them clean_json = raw_output if clean_json.startswith("```json"): clean_json = clean_json.replace("```json", "", 1) if clean_json.endswith("```"): clean_json = clean_json[:-3] clean_json = clean_json.strip() # Parse JSON try: parsed_data = json.loads(clean_json) except json.JSONDecodeError: # Fallback if model doesn't strictly adhere to JSON outline is_urgent = "Urgent" in raw_output or "طارئة" in raw_output parsed_data = { "diagnosis": raw_output[:500] + ("..." if len(raw_output)>500 else ""), "recommendations": "تنبيه: لم يقم الموديل بإرجاع هيكل JSON صحيح. هذا التحليل هو رأي استشاري ذكي ويجب مراجعته من قبل طبيب أشعة مختص.", "urgency_level": "حالة طارئة - Urgent" if is_urgent else "عادية - Normal" } return AnalysisResult( diagnosis=parsed_data.get("diagnosis", "غير محدد"), recommendations=parsed_data.get("recommendations", "غير محدد"), urgency_level=parsed_data.get("urgency_level", "غير محدد"), raw_response=raw_output ) except Exception as e: raise HTTPException(status_code=500, detail=f"Inference error: {str(e)}") @app.get("/") def health_check(): return { "status": "Online", "model": MODEL_ID, "message": "Welcome to MedGemma Radiology API" }