File size: 3,164 Bytes
16e8aa3 462cc5a 8ba4a66 462cc5a 8ba4a66 462cc5a 8ba4a66 462cc5a 8ba4a66 462cc5a 8ba4a66 462cc5a 16e8aa3 dae598b 08aa07f dae598b 16e8aa3 08aa07f 16e8aa3 08aa07f 16e8aa3 7773eb5 08aa07f 7773eb5 b97d616 7773eb5 462cc5a 7773eb5 462cc5a 7773eb5 904cdd6 7773eb5 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 | import os
import torch
from huggingface_hub import login
from transformers import AutoProcessor, AutoModelForImageTextToText
# Authenticate with the Hugging Face Hub before any model files are fetched.
# The access token is read from the HF_TOKEN environment variable; fail fast
# with an actionable message instead of a bare KeyError('HF_TOKEN') when it is
# missing or empty (an empty token would otherwise only fail inside login()).
_hf_token = os.environ.get("HF_TOKEN")
if not _hf_token:
    raise KeyError(
        "HF_TOKEN environment variable is not set or empty; export a Hugging "
        "Face access token that can download the model before running."
    )
login(token=_hf_token)
# Do not keep the secret around as a module attribute.
del _hf_token
# Hugging Face Hub repository id; passed to from_pretrained() for both the
# processor and the model.
MODEL_ID = "google/medgemma-1.5-4b-it"
# Instruction text placed after the images in the single user message that
# generate_report() sends to the model. It asks for a report with fixed
# TECHNIQUE / FINDINGS / CONCLUSION sections.
# NOTE(review): the prompt tells the model that exactly 5 sequences were
# provided (T1, T2 axial, T2 FLAIR, DWI, T1 with contrast), but the code does
# not check how many images are passed or in what order. Confirm that callers
# always supply those 5 images in this order; otherwise the model is told
# about images it never received.
PROMPT = """You are a senior consultant radiologist reporting a brain MRI study.
You have been provided with 5 MRI sequences: T1, T2 axial, T2 FLAIR, DWI, and T1 with contrast.
Write a structured report using EXACTLY this format:
TECHNIQUE:
MRI of the brain was performed using T1, T2, T2 FLAIR, DWI and post-contrast T1 sequences.
FINDINGS:
- Cerebral parenchyma: [signal intensity, any focal or diffuse changes]
- Diffusion: [any restricted diffusion]
- Enhancement: [areas of abnormal enhancement on T1+contrast]
- Extra/Intra axial collections: [midline shift, fluid collections]
- Hippocampi: [signal, volume]
- Basal ganglia, thalami, brainstem and cerebellum: [appearance]
- Sellar/Parasellar region: [pituitary, cavernous sinuses]
- Ventricular system and subarachnoid spaces: [appearance]
- White matter: [FLAIR signal, any lesions — location and distribution]
- Cranial nerves and cerebellopontine angles: [appearance]
- Intracranial vasculature: [flow voids]
- Paranasal sinuses and mastoid air cells: [appearance]
- Orbits: [appearance]
- Calvarium: [marrow signal]
CONCLUSION:
[Clear summary, e.g. 'No abnormality detected' or specific finding]
Rules:
- Never invent clinical history
- If a finding cannot be confidently assessed, say so explicitly
- Be specific about location using standard anatomical terms
- Keep language professional and concise"""
# Load the processor and model once, at import time, so that every call to
# generate_report() reuses the same `processor` and `model` globals.
# On a CUDA GPU the weights are loaded in bfloat16 on "cuda:0"; otherwise
# they are loaded in float32 on the CPU.
print("Loading MedGemma... this may take a few minutes")
processor = AutoProcessor.from_pretrained(MODEL_ID)

use_cuda = torch.cuda.is_available()
if use_cuda:
    dtype, device = torch.bfloat16, "cuda:0"
else:
    dtype, device = torch.float32, "cpu"

model = AutoModelForImageTextToText.from_pretrained(
    MODEL_ID,
    torch_dtype=dtype,
    device_map=device,
)
# Use the tokenizer's EOS id as the padding id during generation.
model.generation_config.pad_token_id = processor.tokenizer.eos_token_id
# Switch to evaluation mode: the model is only used for inference here.
model.eval()

print("MedGemma loaded successfully!", f"MedGemma loaded on: {device}", sep="\n")
def generate_report(images, *, prompt=None, max_new_tokens=512):
    """Generate a structured radiology report from a set of MRI images.

    Builds one user chat message that contains every image followed by the
    report instructions. It runs non-sampling decoding with the loaded model
    and returns only the newly generated text; the echoed prompt is stripped.

    Args:
        images: Iterable of PIL images, one per MRI sequence, inserted into
            the message in iteration order. The default prompt lists five
            sequences (T1, T2 axial, T2 FLAIR, DWI, T1 with contrast), so
            callers presumably should pass images in that order.
        prompt: Instruction text placed after the images. Defaults to the
            module-level ``PROMPT``.
        max_new_tokens: Maximum number of tokens to generate (default 512,
            the previously hard-coded value).

    Returns:
        The decoded report text, with special tokens removed.

    Raises:
        ValueError: If ``images`` is empty. Without any image, the model
            would be asked to describe a study it never saw.
    """
    # Materialize once so generators work and emptiness can be checked.
    images = list(images)
    if not images:
        raise ValueError("generate_report() requires at least one image")
    if prompt is None:
        prompt = PROMPT

    content = [{"type": "image", "image": img} for img in images]
    content.append({"type": "text", "text": prompt})
    messages = [{"role": "user", "content": content}]

    # return_dict=True returns a dict-like batch (input_ids, attention_mask,
    # image tensors, ...) rather than a bare tensor of token ids.
    # Passing dtype=model.dtype to .to() casts only the floating-point
    # entries (the pixel values) to the model's dtype, e.g. bfloat16 on GPU.
    # Integer token ids are left untouched.
    inputs = processor.apply_chat_template(
        messages,
        add_generation_prompt=True,
        tokenize=True,
        return_dict=True,
        return_tensors="pt",
    ).to(model.device, dtype=model.dtype)

    # Non-sampling decoding, so the same inputs give the same report.
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            do_sample=False,
        )

    # generate() returns the prompt tokens followed by the continuation;
    # keep only the continuation of the first (only) sequence.
    input_length = inputs["input_ids"].shape[1]
    generated_tokens = outputs[0][input_length:]
    return processor.decode(generated_tokens, skip_special_tokens=True)