File size: 3,164 Bytes
16e8aa3
 
 
 
 
 
 
 
 
 
462cc5a
 
8ba4a66
462cc5a
 
 
 
8ba4a66
462cc5a
 
 
 
8ba4a66
462cc5a
 
 
 
 
8ba4a66
462cc5a
 
 
 
 
 
 
8ba4a66
462cc5a
 
 
 
 
 
 
16e8aa3
 
 
 
dae598b
 
08aa07f
dae598b
16e8aa3
 
08aa07f
 
16e8aa3
 
08aa07f
16e8aa3
 
7773eb5
08aa07f
7773eb5
b97d616
7773eb5
462cc5a
 
7773eb5
462cc5a
 
 
 
 
 
 
 
7773eb5
 
 
904cdd6
 
 
 
 
7773eb5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
import os
import torch
from huggingface_hub import login
from transformers import AutoProcessor, AutoModelForImageTextToText

# Authenticate with the Hugging Face Hub using the token stored in the
# HF_TOKEN environment variable. The variable is checked up front so that a
# missing or empty secret fails with an actionable message instead of a bare
# KeyError (or an invalid-token error raised from inside login()).
if not os.environ.get("HF_TOKEN"):
    raise RuntimeError(
        "HF_TOKEN environment variable is not set; it is required to "
        "log in to the Hugging Face Hub."
    )
login(token=os.environ["HF_TOKEN"])

# Hugging Face Hub identifier of the checkpoint; passed to both
# AutoProcessor.from_pretrained() and AutoModelForImageTextToText.from_pretrained().
MODEL_ID = "google/medgemma-1.5-4b-it"

# Instruction text that generate_report() appends after the images in the
# single user turn. It fixes the report layout (TECHNIQUE / FINDINGS /
# CONCLUSION) and the reporting rules.
# NOTE(review): the text states that 5 MRI sequences were provided, but
# generate_report() accepts any number of images — confirm callers always
# pass exactly those five sequences.
PROMPT = """You are a senior consultant radiologist reporting a brain MRI study.

You have been provided with 5 MRI sequences: T1, T2 axial, T2 FLAIR, DWI, and T1 with contrast.

Write a structured report using EXACTLY this format:

TECHNIQUE:
MRI of the brain was performed using T1, T2, T2 FLAIR, DWI and post-contrast T1 sequences.

FINDINGS:
- Cerebral parenchyma: [signal intensity, any focal or diffuse changes]
- Diffusion: [any restricted diffusion]
- Enhancement: [areas of abnormal enhancement on T1+contrast]
- Extra/Intra axial collections: [midline shift, fluid collections]
- Hippocampi: [signal, volume]
- Basal ganglia, thalami, brainstem and cerebellum: [appearance]
- Sellar/Parasellar region: [pituitary, cavernous sinuses]
- Ventricular system and subarachnoid spaces: [appearance]
- White matter: [FLAIR signal, any lesions — location and distribution]
- Cranial nerves and cerebellopontine angles: [appearance]
- Intracranial vasculature: [flow voids]
- Paranasal sinuses and mastoid air cells: [appearance]
- Orbits: [appearance]
- Calvarium: [marrow signal]

CONCLUSION:
[Clear summary,  e.g. 'No abnormality detected' or specific finding]

Rules:
- Never invent clinical history
- If a finding cannot be confidently assessed, say so explicitly
- Be specific about location using standard anatomical terms
- Keep language professional and concise"""

print("Loading MedGemma... this may take a few minutes")

# Choose the compute target: bfloat16 on the first CUDA device when one is
# available, otherwise float32 on the CPU.
use_cuda = torch.cuda.is_available()
if use_cuda:
    dtype = torch.bfloat16
    device = "cuda:0"
else:
    dtype = torch.float32
    device = "cpu"

# The processor builds the chat-template inputs and decodes generated tokens
# in generate_report().
processor = AutoProcessor.from_pretrained(MODEL_ID)

# Load the weights in the selected dtype, placed on the selected device.
model = AutoModelForImageTextToText.from_pretrained(
    MODEL_ID, torch_dtype=dtype, device_map=device
)

# Use the tokenizer's EOS token id as the padding id during generation.
model.generation_config.pad_token_id = processor.tokenizer.eos_token_id
# Put the module in evaluation mode for inference.
model.eval()

print("MedGemma loaded successfully!")
print(f"MedGemma loaded on: {device}")

def generate_report(images):
    """
    Generate a structured radiology report from a set of MRI images.

    Builds one user turn containing every image followed by PROMPT, runs
    greedy decoding on the module-level ``model``, and returns only the
    newly generated text.

    Args:
        images: Non-empty list of PIL Images (one per MRI sequence).

    Returns:
        The decoded report as a string, with special tokens stripped.

    Raises:
        ValueError: If ``images`` is empty or None. PROMPT tells the model
            that images were provided, so a text-only request would yield
            a report that is not based on any study.
    """
    if not images:
        raise ValueError("generate_report() requires at least one image")

    # One user message: all images first, then the reporting instructions.
    content = [{"type": "image", "image": img} for img in images]
    content.append({"type": "text", "text": PROMPT})
    messages = [{"role": "user", "content": content}]

    # Prepare inputs. return_dict=True yields a dict-like batch rather than a
    # raw tensor. Passing dtype casts only the floating-point tensors (the
    # image pixel values) to the model's dtype, so they match bfloat16
    # weights on GPU; integer tensors such as input_ids are only moved.
    inputs = processor.apply_chat_template(
        messages,
        add_generation_prompt=True,
        tokenize=True,
        return_dict=True,
        return_tensors="pt",
    ).to(model.device, dtype=model.dtype)

    # Generate the report with greedy decoding (do_sample=False).
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=512,
            do_sample=False,
        )

    # Decode only the newly generated tokens: generate() returns the prompt
    # tokens followed by the continuation, so skip the prompt length.
    input_length = inputs["input_ids"].shape[1]
    generated_tokens = outputs[0][input_length:]
    return processor.decode(generated_tokens, skip_special_tokens=True)