VictorLJZ committed on
Commit
b298259
·
1 Parent(s): 0502778

medgemma script

Browse files
Files changed (1) hide show
  1. medgemma_example.py +365 -0
medgemma_example.py ADDED
@@ -0,0 +1,365 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ MedGemma VQA Inference Script
3
+ This script performs Visual Question Answering on medical images using Google's MedGemma model.
4
+ """
5
+
6
+ import os
7
+ import json
8
+ import torch
9
+ from tqdm import tqdm
10
+ from PIL import Image
11
+ from pathlib import Path
12
+ from transformers import AutoProcessor, AutoModelForImageTextToText
13
+ from transformers import __version__ as transformers_version
14
+
15
+ # Suppress torch dynamo errors to fall back to eager execution.
+ # This is set once, at import time, before any model is loaded.
16
+ import torch._dynamo
17
+ torch._dynamo.config.suppress_errors = True
18
+
19
+ # Print the installed transformers version (useful when reporting compatibility issues).
+ print(f"Transformers version: {transformers_version}")
20
+
21
+
22
def apply_transformers_workarounds():
    """
    Apply best-effort compatibility patches to the installed transformers package.

    Three independent steps are attempted, in order:

    1. Give ``transformers.modeling_utils.ALL_PARALLEL_STYLES`` a default list
       when the attribute is missing or ``None``.
    2. Register the Gemma3 attention classes under the "eager", "sdpa" and
       "flash_attention_2" keys of ``ALL_ATTENTION_FUNCTIONS._global_mapping``
       for any key that is not registered yet.
    3. Set the ``TRANSFORMERS_ATTENTION_TYPE`` environment variable to "eager".

    Import problems in steps 1 and 2 are tolerated (step 2 also tolerates
    ``AttributeError`` and reports what went wrong); step 3 always runs.

    Returns:
        None
    """
    # Workaround 1: ALL_PARALLEL_STYLES issue
    try:
        from transformers import modeling_utils

        # getattr(..., None) covers both "attribute missing" and "attribute is None"
        if getattr(modeling_utils, "ALL_PARALLEL_STYLES", None) is None:
            modeling_utils.ALL_PARALLEL_STYLES = ["tp", "none", "colwise", "rowwise"]
            print("Applied ALL_PARALLEL_STYLES workaround")
    except ImportError:
        pass

    # Workaround 2: Attention implementation mapping issue
    try:
        from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS
        # NOTE(review): not every transformers release may export all three of
        # these classes; when one is missing the ImportError below is reported
        # and this step is skipped -- confirm against the pinned version.
        from transformers.models.gemma3.modeling_gemma3 import (
            Gemma3Attention,
            Gemma3SdpaAttention,
            Gemma3FlashAttention2,
        )

        registry = ALL_ATTENTION_FUNCTIONS._global_mapping
        implementations = (
            ("eager", Gemma3Attention),
            ("sdpa", Gemma3SdpaAttention),
            ("flash_attention_2", Gemma3FlashAttention2),
        )
        # Only fill in keys that are absent; existing registrations are kept
        for impl_name, impl in implementations:
            registry.setdefault(impl_name, impl)

        print("Applied attention functions workaround")

    except (ImportError, AttributeError) as e:
        print(f"Could not apply attention workaround: {e}")

    # Workaround 3: Force specific attention implementation
    # NOTE(review): nothing in this file reads this variable; presumably it is
    # meant for transformers itself -- confirm that it is actually honoured.
    os.environ["TRANSFORMERS_ATTENTION_TYPE"] = "eager"
61
+
62
+
63
+ # Apply all workarounds at import time, before any model is loaded
64
+ apply_transformers_workarounds()
65
+
66
+
67
class MedGemmaVQAInference:
    """
    MedGemma Visual Question Answering inference engine.

    Loads an image-text-to-text model together with its processor and answers
    multiple-choice questions about medical images, either one case at a time
    (``process_single_case``) or as a resumable batch that checkpoints its
    results to a JSON file after every case (``process_batch``).
    """

    # Maximum number of images forwarded to the model for a single case.
    MAX_IMAGES = 2

    def __init__(self, model_name="google/medgemma-4b-it", device="auto"):
        """
        Initialize the MedGemma model and processor for VQA tasks.

        Args:
            model_name (str): Name or path of the model
            device (str): Device to run inference on ("auto", "cuda", or "cpu").
                The value is stored on ``self.device`` and printed; the model
                itself is placed with ``device_map="auto"`` and inputs are
                moved to ``self.model.device``.
        """
        self.model_name = model_name
        if device == "auto":
            self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        else:
            self.device = device
        print(f"Using device: {self.device}")

        # Load model and processor
        self._load_model()

    def _load_model(self):
        """
        Load the model and processor, retrying without a forced attention backend.

        The first attempt forces eager attention. If anything in that attempt
        fails, the error is printed and both model and processor are loaded
        again without ``attn_implementation``; a failure of the second attempt
        propagates to the caller.
        """
        print('Loading Model and Processor...')

        # NOTE: trust_remote_code=True lets the model repository supply code
        # that runs locally -- only point model_name at sources you trust.
        shared_kwargs = {
            "torch_dtype": torch.bfloat16,
            "device_map": "auto",
            "trust_remote_code": True,
        }

        try:
            self.model = AutoModelForImageTextToText.from_pretrained(
                self.model_name,
                attn_implementation="eager",  # Force eager attention to avoid compatibility issues
                **shared_kwargs,
            )
            self.processor = AutoProcessor.from_pretrained(self.model_name, trust_remote_code=True)
            print("Model and processor loaded successfully")

        except Exception as e:
            print(f"Error loading model with eager attention: {e}")
            print("Trying alternative loading method...")

            # Fallback: try loading without specific attention implementation
            self.model = AutoModelForImageTextToText.from_pretrained(
                self.model_name,
                **shared_kwargs,
            )
            self.processor = AutoProcessor.from_pretrained(self.model_name, trust_remote_code=True)
            print("Model loaded with fallback method")

    def load_images(self, image_paths, base_path=""):
        """
        Load up to ``MAX_IMAGES`` images as RGB PIL images.

        Paths ending in ``.dcm`` (any letter case) are redirected to a ``.png``
        file of the same name. Images that fail to load are reported and
        skipped, so the result holds the first ``MAX_IMAGES`` images that
        loaded successfully.

        Args:
            image_paths (list): List of image paths
            base_path (str): Base path to prepend to image paths

        Returns:
            list: List of loaded PIL images (at most ``MAX_IMAGES``)
        """
        images = []
        for img_path in image_paths:
            # Stop as soon as enough images are loaded instead of decoding
            # files that would be thrown away afterwards.
            if len(images) >= self.MAX_IMAGES:
                break

            # Leading slashes are stripped so the path stays relative to base_path
            full_path = Path(base_path) / img_path.lstrip('/')

            # Handle both .dcm and .png formats
            if full_path.suffix.lower() == '.dcm':
                full_path = full_path.with_suffix('.png')

            try:
                # convert() returns an independent image, so the underlying
                # file handle can be closed right away.
                with Image.open(str(full_path)) as handle:
                    images.append(handle.convert('RGB'))
            except Exception as e:
                # Best effort: one unreadable image must not abort the case
                print(f"Error loading image {full_path}: {e}")

        return images

    @staticmethod
    def _format_options(options):
        """Map option labels to option texts, lettering A, B, C... when labels are unusable."""
        labelled = {}
        for opt in options:
            label, dot, text = str(opt).strip().partition('.')
            label = label.strip()
            # A usable label is a short alphanumeric token ("A", "b", "12")
            # before the first '.', and must not repeat.
            usable = bool(dot) and label.isalnum() and len(label) <= 3 and label not in labelled
            if not usable:
                # Keep every option: enumerate all of them with their full text
                return {chr(65 + i): str(item).strip() for i, item in enumerate(options)}
            labelled[label] = text.strip()
        return labelled

    def generate_prompt(self, question: str, options: list, context: str = "") -> str:
        """
        Generate a prompt for the medical VQA model.

        Options written as ``"<label>. <text>"`` are keyed by their label. If
        any option does not follow that format (or a label repeats), all
        options are kept and lettered A, B, C, ... in input order, so that no
        option is ever dropped from the prompt.

        Args:
            question (str): The medical question to be answered
            options (list): List of option strings
            context (str, optional): Additional context or patient information

        Returns:
            str: Formatted prompt string ready for model input
        """
        question_data = {
            "Question": question.strip(),
            "Options": self._format_options(options),
        }

        context_text = f"Patient information: {context}. " if context else ""

        # Assembled line by line so the prompt text does not depend on how
        # this source file happens to be indented.
        return (
            f"{context_text}You are a medical expert assistant. Please analyze the provided medical image(s) and answer the multiple-choice question.\n"
            "\n"
            'Please provide your response in JSON format as follows: {"answer": "letter_of_correct_option", "explanation": "brief explanation of your choice"}\n'
            "\n"
            "Question and Options:\n"
            f"{json.dumps(question_data, indent=2)}\n"
            "\n"
            "Please select the most appropriate answer and provide a brief medical explanation."
        )

    def process_single_case(self, case_data, base_path=""):
        """
        Process a single VQA case.

        Args:
            case_data (dict): Case data. Reads the required keys
                'image_path_list', 'question' and 'options', and the optional
                keys 'context', 'study_id', 'task_name', 'correct_answer',
                'category' and 'subcategory'.
            base_path (str): Base path for image loading

        Returns:
            dict: The model response plus case metadata, or a dict with a
            single "error" key when no image could be loaded or when
            inference failed.

        Raises:
            KeyError: If a required key is missing from ``case_data``.
        """
        # Load images
        images = self.load_images(case_data['image_path_list'], base_path)
        if not images:
            return {"error": "No images could be loaded"}

        # Generate prompt
        prompt = self.generate_prompt(
            case_data['question'],
            case_data['options'],
            case_data.get('context', '')
        )

        # Create messages for MedGemma chat template
        messages = [
            {
                "role": "system",
                "content": [{"type": "text", "text": "You are an expert medical AI assistant specializing in medical image analysis and diagnosis."}]
            },
            {
                "role": "user",
                "content": [
                    {"type": "text", "text": prompt}
                ] + [{"type": "image", "image": img} for img in images]
            }
        ]

        # Pre-bind so the cleanup in `finally` is valid even if an early step fails
        inputs = None
        generation = None
        try:
            # Process inputs using chat template
            inputs = self.processor.apply_chat_template(
                messages,
                add_generation_prompt=True,
                tokenize=True,
                return_dict=True,
                return_tensors="pt"
            ).to(self.model.device, dtype=torch.bfloat16)

            input_len = inputs["input_ids"].shape[-1]

            # Generate response. Decoding is greedy (do_sample=False), so no
            # sampling temperature is passed.
            with torch.inference_mode():
                generation = self.model.generate(
                    **inputs,
                    max_new_tokens=300,
                    do_sample=False,
                    pad_token_id=self.processor.tokenizer.eos_token_id,
                )

            # Keep only the newly generated tokens, then decode them
            new_tokens = generation[0][input_len:]
            response = self.processor.decode(new_tokens, skip_special_tokens=True).strip()

        except Exception as e:
            return {"error": f"Processing error: {str(e)}"}

        finally:
            # Clean up to free GPU memory, on success and on failure alike
            del inputs, generation
            if torch.cuda.is_available():
                torch.cuda.empty_cache()

        return {
            "model_response": response,
            # f-string so non-string ids (e.g. integers) cannot break the result
            "question_id": f"{case_data.get('study_id', '')}_{case_data.get('task_name', '')}",
            "question": case_data.get('question', ''),
            "options": case_data.get('options', []),
            "correct_answer": case_data.get('correct_answer', ''),
            "category": case_data.get('category', ''),
            "subcategory": case_data.get('subcategory', ''),
            "context": case_data.get('context', '')
        }

    @staticmethod
    def _load_checkpoint(output_file):
        """Return previously saved successful results from output_file, or {} if unusable."""
        if not os.path.exists(output_file):
            return {}

        try:
            with open(output_file, 'r', encoding='utf-8') as f:
                saved = json.load(f)
        except json.JSONDecodeError:
            print(f"Error loading existing results from {output_file}, starting fresh")
            return {}

        if not isinstance(saved, dict):
            print(f"Error loading existing results from {output_file}, starting fresh")
            return {}

        # Remove all items with errors (or of unexpected shape) to retry them
        kept = {k: v for k, v in saved.items() if isinstance(v, dict) and 'error' not in v}
        print(f"Loaded {len(kept)} existing results from {output_file}")
        return kept

    @staticmethod
    def _save_results(results, output_file):
        """Write results to output_file via a temporary file and an atomic rename."""
        tmp_path = f"{output_file}.tmp"
        with open(tmp_path, 'w', encoding='utf-8') as f:
            json.dump(results, f, indent=2)
        # An interrupted write leaves the previous checkpoint intact
        os.replace(tmp_path, output_file)

    def process_batch(self, json_data, base_path="", output_file="results.json"):
        """
        Process multiple cases with progress bar and checkpointing.

        Successful results already present in ``output_file`` are reused;
        entries that recorded an error are processed again. The results file
        is rewritten after every processed case.

        Args:
            json_data (dict): Dictionary mapping case ids to case data
            base_path (str): Base path for image loading
            output_file (str): Path to save results

        Returns:
            dict: Results for all cases (reused and newly processed)

        Raises:
            OSError: If the results file cannot be written.
        """
        # Load existing results if available
        results = self._load_checkpoint(output_file)

        # Only cases that belong to this input count as already done, so the
        # progress bar cannot start beyond what will actually be skipped.
        already_done = sum(1 for case_id in json_data if case_id in results)

        with tqdm(total=len(json_data), desc="Processing VQA cases") as pbar:
            pbar.update(already_done)

            # Process remaining cases
            for case_id, case_data in json_data.items():
                # Skip if already processed
                if case_id in results:
                    continue

                try:
                    outcome = self.process_single_case(case_data, base_path)
                except Exception as e:
                    outcome = {"error": str(e)}
                    print(f"\nException processing {case_id}: {str(e)}")
                else:
                    # Print errors for debugging
                    if "error" in outcome:
                        print(f"\nError processing {case_id}: {outcome['error']}")

                results[case_id] = outcome

                # Save results after each case (checkpointing), errors included
                self._save_results(results, output_file)

                pbar.update(1)

        return results
324
+
325
+
326
def main(argv=None):
    """
    Main function to run MedGemma VQA inference.

    Parses the command line, loads the VQA cases, runs the batch inference
    with checkpointing and prints a short summary.

    Args:
        argv (list, optional): Command-line arguments to parse. Defaults to
            ``None``, in which case argparse reads ``sys.argv[1:]``.
    """
    import argparse

    parser = argparse.ArgumentParser(description='MedGemma VQA Inference')
    parser.add_argument('--input_file', type=str, required=True,
                        help='Input JSON file with VQA cases')
    parser.add_argument('--output_file', type=str, required=True,
                        help='Output JSON file for results')
    parser.add_argument('--base_path', type=str, default="",
                        help='Base path for image loading')
    parser.add_argument('--model_name', type=str,
                        default='google/medgemma-4b-it',
                        help='Model name or path')
    args = parser.parse_args(argv)

    # Create output directory if it doesn't exist. A bare file name has an
    # empty dirname, and os.makedirs("") raises, so only create real directories.
    output_dir = os.path.dirname(args.output_file)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)

    # Load JSON data first so a bad input path fails before the model load
    print(f"Loading VQA cases from {args.input_file}")
    with open(args.input_file, 'r', encoding='utf-8') as f:
        cases = json.load(f)

    print(f"Found {len(cases)} VQA cases to process")

    # Initialize model
    print("Initializing MedGemma VQA model...")
    inferencer = MedGemmaVQAInference(model_name=args.model_name)

    # Process all cases with progress bar and checkpointing
    results = inferencer.process_batch(cases, args.base_path, args.output_file)

    failed = sum(1 for r in results.values() if 'error' in r)
    succeeded = len(results) - failed

    print(f"\nProcessing complete. Results saved to {args.output_file}")
    print(f"Successfully processed {succeeded} cases")
    print(f"Errors in {failed} cases")
362
+
363
+
364
+ # Run the command-line entry point only when executed as a script, not on import.
+ if __name__ == "__main__":
365
+ main()