"""Generate structured brain-MRI radiology reports with MedGemma.

Importing this module logs in to the Hugging Face Hub, loads the MedGemma
processor and model once, and exposes ``generate_report`` for inference.
"""

import os

import torch
from huggingface_hub import login
from transformers import AutoModelForImageTextToText, AutoProcessor

# Login with your secret token (raises KeyError if HF_TOKEN is not set).
login(token=os.environ["HF_TOKEN"])

MODEL_ID = "google/medgemma-1.5-4b-it"

# Instruction text sent after the images. The line layout matters: the model
# is told to reproduce this template exactly.
PROMPT = """You are a senior consultant radiologist reporting a brain MRI study.
You have been provided with 5 MRI sequences: T1, T2 axial, T2 FLAIR, DWI, and T1 with contrast.

Write a structured report using EXACTLY this format:

TECHNIQUE:
MRI of the brain was performed using T1, T2, T2 FLAIR, DWI and post-contrast T1 sequences.

FINDINGS:
- Cerebral parenchyma: [signal intensity, any focal or diffuse changes]
- Diffusion: [any restricted diffusion]
- Enhancement: [areas of abnormal enhancement on T1+contrast]
- Extra/Intra axial collections: [midline shift, fluid collections]
- Hippocampi: [signal, volume]
- Basal ganglia, thalami, brainstem and cerebellum: [appearance]
- Sellar/Parasellar region: [pituitary, cavernous sinuses]
- Ventricular system and subarachnoid spaces: [appearance]
- White matter: [FLAIR signal, any lesions — location and distribution]
- Cranial nerves and cerebellopontine angles: [appearance]
- Intracranial vasculature: [flow voids]
- Paranasal sinuses and mastoid air cells: [appearance]
- Orbits: [appearance]
- Calvarium: [marrow signal]

CONCLUSION:
[Clear summary, e.g. 'No abnormality detected' or specific finding]

Rules:
- Never invent clinical history
- If a finding cannot be confidently assessed, say so explicitly
- Be specific about location using standard anatomical terms
- Keep language professional and concise"""

print("Loading MedGemma... this may take a few minutes")
processor = AutoProcessor.from_pretrained(MODEL_ID)

# bfloat16 on GPU; float32 on CPU.
use_cuda = torch.cuda.is_available()
dtype = torch.bfloat16 if use_cuda else torch.float32
device = "cuda:0" if use_cuda else "cpu"

model = AutoModelForImageTextToText.from_pretrained(
    MODEL_ID,
    torch_dtype=dtype,
    device_map=device,
)
# Reuse the tokenizer's EOS token id as the pad token id for generation.
model.generation_config.pad_token_id = processor.tokenizer.eos_token_id
model.eval()
print("MedGemma loaded successfully!")
print(f"MedGemma loaded on: {device}")


def generate_report(images, *, max_new_tokens: int = 512) -> str:
    """Return a structured radiology report for a set of MRI images.

    Args:
        images: Iterable of PIL Images, one per MRI sequence. ``PROMPT``
            tells the model that five sequences were supplied, but the
            number of images is not enforced here.
        max_new_tokens: Upper bound on the number of generated tokens.
            Raise it if reports come back truncated.

    Returns:
        The decoded text of the newly generated tokens only (the prompt is
        not echoed back).

    Raises:
        ValueError: If ``images`` is empty. Without any image the model
            would still be told it had been given a study and would write
            a report with nothing to base it on.
    """
    images = list(images)
    if not images:
        raise ValueError("generate_report() requires at least one image")

    # All images come first, followed by the single text instruction.
    content = [{"type": "image", "image": img} for img in images]
    content.append({"type": "text", "text": PROMPT})
    messages = [{"role": "user", "content": content}]

    # Prepare inputs. return_dict=True returns a dict, not a raw Tensor.
    # Floating-point tensors are cast to the model's dtype so they match the
    # weights (bfloat16 on GPU).
    inputs = processor.apply_chat_template(
        messages,
        add_generation_prompt=True,
        tokenize=True,
        return_dict=True,
        return_tensors="pt",
    ).to(model.device, dtype=model.dtype)

    # Generate the report; do_sample=False selects greedy decoding.
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            do_sample=False,
        )

    # Decode only the newly generated tokens: the output sequence starts
    # with the prompt tokens, so slice them off.
    input_length = inputs["input_ids"].shape[1]
    generated_tokens = outputs[0][input_length:]
    return processor.decode(generated_tokens, skip_special_tokens=True)