| import os |
| import torch |
| from huggingface_hub import login |
| from transformers import AutoProcessor, AutoModelForImageTextToText |
|
|
| |
# Authenticate against the Hugging Face Hub using the token in HF_TOKEN.
# Fail early with an actionable message (same exception type as a plain
# os.environ lookup) instead of a bare ``KeyError: 'HF_TOKEN'``.
if "HF_TOKEN" not in os.environ:
    raise KeyError(
        "HF_TOKEN environment variable is not set; export a Hugging Face "
        "access token before running this script"
    )
login(token=os.environ["HF_TOKEN"])
|
|
# Hugging Face Hub identifier of the checkpoint loaded by this script.
MODEL_ID: str = "google/medgemma-1.5-4b-it"

# Instruction text sent to the model together with the images. It fixes the
# report layout (TECHNIQUE / FINDINGS / CONCLUSION) and the reporting rules.
# NOTE(review): the text states that 5 sequences are provided; callers are
# presumably expected to pass exactly those five images - confirm.
PROMPT: str = """You are a senior consultant radiologist reporting a brain MRI study.

You have been provided with 5 MRI sequences: T1, T2 axial, T2 FLAIR, DWI, and T1 with contrast.

Write a structured report using EXACTLY this format:

TECHNIQUE:
MRI of the brain was performed using T1, T2, T2 FLAIR, DWI and post-contrast T1 sequences.

FINDINGS:
- Cerebral parenchyma: [signal intensity, any focal or diffuse changes]
- Diffusion: [any restricted diffusion]
- Enhancement: [areas of abnormal enhancement on T1+contrast]
- Extra/Intra axial collections: [midline shift, fluid collections]
- Hippocampi: [signal, volume]
- Basal ganglia, thalami, brainstem and cerebellum: [appearance]
- Sellar/Parasellar region: [pituitary, cavernous sinuses]
- Ventricular system and subarachnoid spaces: [appearance]
- White matter: [FLAIR signal, any lesions — location and distribution]
- Cranial nerves and cerebellopontine angles: [appearance]
- Intracranial vasculature: [flow voids]
- Paranasal sinuses and mastoid air cells: [appearance]
- Orbits: [appearance]
- Calvarium: [marrow signal]

CONCLUSION:
[Clear summary, e.g. 'No abnormality detected' or specific finding]

Rules:
- Never invent clinical history
- If a finding cannot be confidently assessed, say so explicitly
- Be specific about location using standard anatomical terms
- Keep language professional and concise"""
|
|
print("Loading MedGemma... this may take a few minutes")

# The processor is loaded first; it supplies the chat template and tokenizer
# used later when building model inputs.
processor = AutoProcessor.from_pretrained(MODEL_ID)

# Choose precision and placement together: bfloat16 on the first CUDA device
# when one is available, otherwise float32 on the CPU.
use_cuda = torch.cuda.is_available()
if use_cuda:
    dtype, device = torch.bfloat16, "cuda:0"
else:
    dtype, device = torch.float32, "cpu"

model = AutoModelForImageTextToText.from_pretrained(
    MODEL_ID,
    device_map=device,
    torch_dtype=dtype,
)

# Reuse the tokenizer's EOS id as the padding id for generation.
# NOTE(review): presumably done to avoid a missing-pad-token warning in
# generate() - confirm.
model.generation_config.pad_token_id = processor.tokenizer.eos_token_id

# Put the module in evaluation mode; this script only runs generation.
model.eval()

print("MedGemma loaded successfully!")
print(f"MedGemma loaded on: {device}")
|
|
def generate_report(images, *, max_new_tokens=512):
    """Generate a structured radiology report from MRI sequence images.

    A single user chat message is built containing one image entry per
    element of ``images`` followed by the module-level ``PROMPT`` text. The
    message is rendered and tokenized with the processor's chat template,
    passed to ``model.generate`` with sampling disabled, and only the newly
    generated tokens are decoded.

    Args:
        images: Iterable of PIL images, one per MRI sequence. Any iterable
            is accepted; it is materialized into a list once.
        max_new_tokens: Upper bound on the number of generated tokens.
            Defaults to 512, the value previously hard-coded.

    Returns:
        The decoded report text, without the prompt tokens and with special
        tokens stripped.

    Raises:
        ValueError: If ``images`` is empty. Without images the prompt would
            still claim that sequences were provided, and the model would
            produce a report that is not based on any study.
    """
    # Materialize first so generators work and emptiness can be checked.
    images = list(images)
    if not images:
        raise ValueError("generate_report() requires at least one image")

    # Images come first, then the instruction text, all in one user turn.
    content = [{"type": "image", "image": img} for img in images]
    content.append({"type": "text", "text": PROMPT})
    messages = [{"role": "user", "content": content}]

    # Move inputs to the model's device and cast floating-point tensors to
    # the model's dtype (integer token ids are left untouched), following the
    # usage shown on the MedGemma model card.
    inputs = processor.apply_chat_template(
        messages,
        add_generation_prompt=True,
        tokenize=True,
        return_dict=True,
        return_tensors="pt",
    ).to(model.device, dtype=model.dtype)

    # No gradients are needed for generation.
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            do_sample=False,
        )

    # generate() returns prompt + continuation; keep only the continuation.
    input_length = inputs["input_ids"].shape[1]
    generated_tokens = outputs[0][input_length:]
    return processor.decode(generated_tokens, skip_special_tokens=True)