leoyinn committed
Commit db00ba8 · verified · 1 Parent(s): f1ed8e5

Upload fine-tuned MedGemma model for FLARE 2025 medical image analysis
.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  *.zip filter=lfs diff=lfs merge=lfs -text
  *.zst filter=lfs diff=lfs merge=lfs -text
  *tfevents* filter=lfs diff=lfs merge=lfs -text
+ tokenizer.json filter=lfs diff=lfs merge=lfs -text
README.md ADDED
@@ -0,0 +1,335 @@
+ ---
+ license: apache-2.0
+ base_model: google/medgemma-4b-it
+ tags:
+ - vision-language
+ - medical-imaging
+ - radiology
+ - medgemma
+ - gemma
+ - flare2025
+ - peft
+ - lora
+ - multimodal
+ - medical-ai
+ datasets:
+ - FLARE2025
+ pipeline_tag: image-text-to-text
+ library_name: transformers
+ ---
+
+ # MedGemma Fine-tuned for FLARE 2025 Medical Image Analysis
+
+ This model is a fine-tuned version of [google/medgemma-4b-it](https://huggingface.co/google/medgemma-4b-it)
+ specifically optimized for medical image analysis tasks in the FLARE 2025 2D Medical Multimodal Dataset challenge.
+
+ ## Model Description
+
+ - **Base Model**: MedGemma-4B-IT (Google's medical-specialized Gemma model)
+ - **Fine-tuning Method**: QLoRA (Quantized Low-Rank Adaptation)
+ - **Target Domain**: Medical imaging across 7 modalities (CT, MRI, X-ray, Ultrasound, Fundus, Pathology, Endoscopy)
+ - **Tasks**: Medical image captioning, visual question answering, report generation, diagnosis support
+ - **Training Data**: 19 FLARE 2025 datasets with comprehensive medical annotations
+
+ ## Training Details
+
+ ### Training Data
+ The model was fine-tuned on 19 diverse medical imaging datasets from FLARE 2025, including:
+ - **Classification**: Disease diagnosis with balanced accuracy optimization
+ - **Multi-label Classification**: Multi-pathology identification
+ - **Detection**: Anatomical structure and pathology detection
+ - **Instance Detection**: Identity-aware detection (e.g., chromosome analysis)
+ - **Counting**: Cell counting and quantitative analysis
+ - **Regression**: Continuous medical measurements
+ - **Report Generation**: Comprehensive medical report writing
+
+ Details available at: https://huggingface.co/datasets/FLARE-MedFM/FLARE-Task5-MLLM-2D
+
+ ### Training Configuration
+ ```yaml
+ # LoRA Configuration
+ lora_r: 16
+ lora_alpha: 32
+ lora_dropout: 0.1
+ target_modules: ['gate_proj', 'up_proj', 'o_proj', 'down_proj', 'v_proj', 'q_proj', 'k_proj']
+ task_type: CAUSAL_LM
+ bias: none
+ ```
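+
+ For reference, the same configuration expressed as a `peft.LoraConfig` (a minimal sketch; values match the YAML above and `adapter_config.json` in this repo):
+ ```python
+ from peft import LoraConfig
+
+ # Mirrors the YAML above / adapter_config.json
+ lora_config = LoraConfig(
+     r=16,
+     lora_alpha=32,
+     lora_dropout=0.1,
+     target_modules=["gate_proj", "up_proj", "o_proj", "down_proj", "v_proj", "q_proj", "k_proj"],
+     task_type="CAUSAL_LM",
+     bias="none",
+ )
+ ```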
+
+ ### Training Procedure
+ - **Base Architecture**: MedGemma-4B with medical domain pre-training
+ - **Optimization**: 4-bit quantization with BitsAndBytesConfig (see the sketch after this list)
+ - **LoRA Configuration**:
+   - r=16, alpha=32, dropout=0.1
+   - Target modules: All attention and MLP layers
+ - **Memory Optimization**: Gradient checkpointing, flash attention
+ - **Batch Size**: Dynamic based on image resolution and GPU memory
+ - **Learning Rate**: 1e-4 with cosine scheduling
+ - **Training Steps**: 4000 steps with evaluation every 500 steps
+ - **Chat Template**: Gemma-style chat formatting for medical conversations
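+
+ A sketch of the quantization and trainer settings implied by the bullets above (hypothetical variable names and output path; the exact training script is not included in this repo):
+ ```python
+ import torch
+ from transformers import BitsAndBytesConfig, TrainingArguments
+
+ # 4-bit NF4 quantization used for QLoRA fine-tuning
+ bnb_config = BitsAndBytesConfig(
+     load_in_4bit=True,
+     bnb_4bit_quant_type="nf4",
+     bnb_4bit_compute_dtype=torch.bfloat16,
+     bnb_4bit_use_double_quant=True,
+ )
+
+ # Schedule and checkpointing settings from the list above
+ training_args = TrainingArguments(
+     output_dir="./flare25-medgemma",  # hypothetical path
+     learning_rate=1e-4,
+     lr_scheduler_type="cosine",
+     max_steps=4000,
+     eval_strategy="steps",
+     eval_steps=500,
+     gradient_checkpointing=True,
+     bf16=True,
+ )
+ ```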
+
+ ## Model Performance
+
+ This model has been evaluated across multiple medical imaging tasks using FLARE 2025 evaluation metrics:
+
+ ### Evaluation Metrics by Task Type
+
+ **Classification Tasks (Disease Diagnosis):**
+ - **Balanced Accuracy** (PRIMARY): Handles class imbalance in medical diagnosis
+ - **Accuracy**: Standard classification accuracy
+ - **F1 Score**: Weighted F1 for multi-class scenarios
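+
+ Balanced accuracy is the unweighted mean of per-class recalls, so majority-class guessing scores poorly. A quick illustration with `scikit-learn` (toy labels, not FLARE data):
+ ```python
+ from sklearn.metrics import balanced_accuracy_score, f1_score
+
+ y_true = [0, 0, 0, 0, 1, 2]  # imbalanced toy labels
+ y_pred = [0, 0, 0, 0, 2, 2]
+
+ print(balanced_accuracy_score(y_true, y_pred))       # mean of per-class recalls ≈ 0.67
+ print(f1_score(y_true, y_pred, average="weighted"))  # weighted multi-class F1
+ ```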
+
+ **Multi-label Classification (Multi-pathology):**
+ - **F1 Score** (PRIMARY): Sample-wise F1 across multiple medical conditions
+ - **Precision**: Label prediction precision
+ - **Recall**: Medical condition coverage recall
+
+ **Detection Tasks (Anatomical/Pathological):**
+ - **F1 Score @ IoU > 0.5** (PRIMARY): Standard computer vision detection metric
+ - **Precision**: Detection precision at IoU threshold
+ - **Recall**: Detection recall at IoU threshold
+
+ **Instance Detection (Identity-Aware Detection):**
+ - **F1 Score @ IoU > 0.3** (PRIMARY): Medical imaging standard (e.g., chromosome detection)
+ - **F1 Score @ IoU > 0.5**: Computer vision standard
+ - **Average F1**: COCO-style average across IoU thresholds (0.3-0.7)
+ - **Per-instance metrics**: Detailed breakdown by object identity
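+
+ For intuition, a predicted box counts as a true positive when its IoU with an unmatched ground-truth box exceeds the threshold; precision, recall, and F1 follow from the match counts. A minimal sketch (not the official FLARE matching code):
+ ```python
+ def iou(a, b):
+     """IoU of two boxes in (x1, y1, x2, y2) format."""
+     ix1, iy1 = max(a[0], b[0]), max(a[1], b[1])
+     ix2, iy2 = min(a[2], b[2]), min(a[3], b[3])
+     inter = max(0, ix2 - ix1) * max(0, iy2 - iy1)
+     union = (a[2] - a[0]) * (a[3] - a[1]) + (b[2] - b[0]) * (b[3] - b[1]) - inter
+     return inter / union if union else 0.0
+
+ def detection_f1(preds, gts, thr=0.5):
+     """Greedy one-to-one matching of predictions to ground truth at an IoU threshold."""
+     used, tp = set(), 0
+     for p in preds:
+         candidates = [(iou(p, g), i) for i, g in enumerate(gts) if i not in used]
+         if candidates:
+             best_iou, best_i = max(candidates)
+             if best_iou > thr:
+                 tp += 1
+                 used.add(best_i)
+     precision = tp / len(preds) if preds else 0.0
+     recall = tp / len(gts) if gts else 0.0
+     return 2 * precision * recall / (precision + recall) if tp else 0.0
+ ```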
+
+ **Counting Tasks (Cell/Structure Counting):**
+ - **Mean Absolute Error** (PRIMARY): Cell counting accuracy
+ - **Root Mean Squared Error**: Additional counting precision metric
+
+ **Regression Tasks (Medical Measurements):**
+ - **Mean Absolute Error** (PRIMARY): Continuous value prediction accuracy
+ - **Root Mean Squared Error**: Regression precision metric
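+
+ Both reduce to elementwise errors; RMSE penalizes large miscounts more heavily than MAE. A toy example:
+ ```python
+ import numpy as np
+
+ true_counts = np.array([12, 30, 7])
+ pred_counts = np.array([10, 33, 7])
+
+ mae = np.mean(np.abs(pred_counts - true_counts))           # (2 + 3 + 0) / 3 ≈ 1.67
+ rmse = np.sqrt(np.mean((pred_counts - true_counts) ** 2))  # sqrt((4 + 9 + 0) / 3) ≈ 2.08
+ ```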
+
+ **Report Generation (Medical Reports):**
+ - **GREEN Score** (PRIMARY): Comprehensive medical report evaluation with 7 components:
+   - Entity matching with severity assessment (30%)
+   - Location accuracy with laterality (20%)
+   - Negation and uncertainty handling (15%)
+   - Temporal accuracy (10%)
+   - Size/measurement accuracy (10%)
+   - Clinical significance weighting (10%)
+   - Report structure completeness (5%)
+ - **BLEU Score**: Text generation quality
+ - **Clinical Efficacy**: Medical relevance scoring
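+
+ The composite GREEN score is a weighted sum of the seven components; a sketch with hypothetical per-component values in [0, 1] (the full scorer involves entity extraction and clinical rules, not shown here):
+ ```python
+ GREEN_WEIGHTS = {
+     "entity_matching": 0.30,
+     "location_accuracy": 0.20,
+     "negation_uncertainty": 0.15,
+     "temporal_accuracy": 0.10,
+     "size_measurement": 0.10,
+     "clinical_significance": 0.10,
+     "structure_completeness": 0.05,
+ }
+
+ def green_score(components):
+     """Weighted sum of component scores, each in [0, 1]."""
+     return sum(GREEN_WEIGHTS[k] * components[k] for k in GREEN_WEIGHTS)
+ ```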
+
+ ## Usage
+
+ ### Installation
+ ```bash
+ pip install transformers torch peft accelerate bitsandbytes
+ ```
+
+ ### Basic Usage
+ ```python
+ import torch
+ from transformers import AutoTokenizer, AutoProcessor, AutoModelForImageTextToText
+ from peft import PeftModel
+ from PIL import Image
+
+ # Load the fine-tuned model
+ base_model_name = "google/medgemma-4b-it"
+ adapter_model_name = "leoyinn/flare25-medgemma"
+
+ # Load tokenizer and processor
+ tokenizer = AutoTokenizer.from_pretrained(base_model_name, trust_remote_code=True)
+ processor = AutoProcessor.from_pretrained(base_model_name, trust_remote_code=True)
+
+ # Load base model
+ base_model = AutoModelForImageTextToText.from_pretrained(
+     base_model_name,
+     torch_dtype=torch.bfloat16,
+     device_map="auto",
+     trust_remote_code=True,
+     attn_implementation="eager"
+ )
+
+ # Load the fine-tuned adapter
+ model = PeftModel.from_pretrained(base_model, adapter_model_name)
+
+ # Prepare input with MedGemma chat format
+ image = Image.open("medical_image.jpg").convert("RGB")
+ image = image.resize((448, 448))  # MedGemma standard size
+
+ # Create proper message format
+ messages = [
+     {
+         "role": "system",
+         "content": [{
+             "type": "text",
+             "text": "You are an expert medical AI assistant specialized in analyzing medical images and providing accurate diagnostic insights."
+         }]
+     },
+     {
+         "role": "user",
+         "content": [
+             {"type": "image"},
+             {"type": "text", "text": "Describe the medical findings in this image and provide a diagnostic assessment."}
+         ]
+     }
+ ]
+
+ # Apply chat template
+ full_text = tokenizer.apply_chat_template(
+     messages,
+     tokenize=False,
+     add_generation_prompt=True
+ )
+
+ # Process and generate
+ inputs = processor(
+     images=[image],
+     text=full_text,
+     return_tensors="pt",
+     padding=True,
+     truncation=False
+ ).to(model.device, dtype=torch.bfloat16)
+
+ # Generate medical response
+ with torch.inference_mode():
+     outputs = model.generate(
+         **inputs,
+         max_new_tokens=300,
+         do_sample=False,  # Deterministic for medical applications
+         use_cache=True,
+         cache_implementation="dynamic"
+     )
+
+ # Decode response
+ input_len = inputs["input_ids"].shape[-1]
+ response = processor.decode(outputs[0][input_len:], skip_special_tokens=True)
+ print(response)
+ ```
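+
+ On memory-constrained GPUs, the base model can instead be loaded in 4-bit, mirroring the training-time quantization (a sketch; requires `bitsandbytes` and reuses the names from the snippet above):
+ ```python
+ from transformers import BitsAndBytesConfig
+
+ base_model = AutoModelForImageTextToText.from_pretrained(
+     base_model_name,
+     quantization_config=BitsAndBytesConfig(
+         load_in_4bit=True,
+         bnb_4bit_quant_type="nf4",
+         bnb_4bit_compute_dtype=torch.bfloat16,
+     ),
+     device_map="auto",
+     trust_remote_code=True,
+     attn_implementation="eager",
+ )
+ model = PeftModel.from_pretrained(base_model, adapter_model_name)
+ ```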
+
+ ### Advanced Usage for Specific Medical Tasks
+
+ ```python
+ # For medical report generation
+ def generate_medical_report(image_path, model, processor, tokenizer):
+     image = Image.open(image_path).convert("RGB").resize((448, 448))
+
+     messages = [
+         {
+             "role": "system",
+             "content": [{
+                 "type": "text",
+                 "text": "You are an expert medical AI assistant specialized in analyzing medical images and providing accurate diagnostic insights."
+             }]
+         },
+         {
+             "role": "user",
+             "content": [
+                 {"type": "image"},
+                 {"type": "text", "text": "Generate a comprehensive medical report for this image, including findings, impressions, and recommendations."}
+             ]
+         }
+     ]
+
+     full_text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
+     inputs = processor(images=[image], text=full_text, return_tensors="pt").to(model.device)
+
+     with torch.inference_mode():
+         # Low-temperature sampling; do_sample=True is required for temperature to take effect
+         outputs = model.generate(**inputs, max_new_tokens=400, do_sample=True, temperature=0.1)
+
+     return processor.decode(outputs[0][inputs["input_ids"].shape[1]:], skip_special_tokens=True)
+
+ # For medical VQA
+ def medical_vqa(image_path, question, model, processor, tokenizer):
+     image = Image.open(image_path).convert("RGB").resize((448, 448))
+
+     instruction = "Look at the image carefully and answer the medical question accurately based on what you observe."
+     full_question = f"{instruction}\n\n{question}"
+
+     messages = [
+         {
+             "role": "system",
+             "content": [{
+                 "type": "text",
+                 "text": "You are an expert medical AI assistant specialized in analyzing medical images and providing accurate diagnostic insights."
+             }]
+         },
+         {
+             "role": "user",
+             "content": [
+                 {"type": "image"},
+                 {"type": "text", "text": full_question}
+             ]
+         }
+     ]
+
+     full_text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
+     inputs = processor(images=[image], text=full_text, return_tensors="pt").to(model.device)
+
+     with torch.inference_mode():
+         outputs = model.generate(**inputs, max_new_tokens=200, do_sample=False)
+
+     return processor.decode(outputs[0][inputs["input_ids"].shape[1]:], skip_special_tokens=True)
+ ```
+
+ ## Limitations and Ethical Considerations
+
+ ### Limitations
+ - Model outputs may contain inaccuracies and should be verified by medical professionals
+ - Performance may vary across different medical imaging modalities and populations
+ - Training data may contain biases present in medical literature and datasets
+ - The model has not been validated in clinical settings
+ - Designed for research and educational purposes, not clinical decision-making
+
+ ### Intended Use
+ - Medical education and training
+ - Research in medical AI and computer vision
+ - Development of clinical decision support tools (with proper validation)
+ - Academic research in multimodal medical AI
+ - Medical image analysis prototyping
+
+ ### Out-of-Scope Use
+ - Direct clinical diagnosis without physician oversight
+ - Treatment recommendations without medical professional validation
+ - Use in emergency medical situations
+ - Deployment in production clinical systems without extensive validation
+ - Patient-facing applications without proper medical supervision
+
+ ## Citation
+
+ If you use this model in your research, please cite:
+
+ ```bibtex
+ @misc{medgemma-flare2025,
+   title={MedGemma Fine-tuned for FLARE 2025 Medical Image Analysis},
+   author={Your Name},
+   year={2025},
+   publisher={Hugging Face},
+   url={https://huggingface.co/leoyinn/flare25-medgemma}
+ }
+
+ @misc{medgemma-base,
+   title={MedGemma: Medical Gemma Models for Healthcare},
+   author={Google Research},
+   year={2024},
+   publisher={Hugging Face},
+   url={https://huggingface.co/google/medgemma-4b-it}
+ }
+
+ @misc{flare2025,
+   title={FLARE 2025: A Multi-Modal Foundation Model Challenge for Medical AI},
+   year={2025},
+   url={https://huggingface.co/datasets/FLARE-MedFM/FLARE-Task5-MLLM-2D}
+ }
+ ```
+
+ ## Model Details
+ - **Model Type**: Vision-Language Model (VLM) specialized for medical applications
+ - **Architecture**: MedGemma (Gemma-based) with LoRA adapters
+ - **Parameters**: ~4B base parameters + LoRA adapters
+ - **Precision**: bfloat16 base model + full-precision adapters
+ - **Framework**: PyTorch, Transformers, PEFT
+ - **Input Resolution**: 448x448 pixels (standard for MedGemma)
+ - **Context Length**: Supports long medical reports and conversations
+
+ ## Technical Specifications
+ - **Base Model**: google/medgemma-4b-it
+ - **Adapter Type**: LoRA (Low-Rank Adaptation)
+ - **Target Modules**: All attention projection layers and MLP layers
+ - **Chat Template**: Gemma-style with medical system prompts
+ - **Attention Implementation**: Eager attention for stability
+ - **Cache Implementation**: Dynamic caching for efficient inference
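+
+ For deployment without the PEFT wrapper, the adapter can optionally be folded into the base weights via PEFT's `merge_and_unload` (a sketch; merge against a non-quantized base model loaded in bfloat16, and the output path below is hypothetical):
+ ```python
+ # Merge LoRA weights into the base model and drop the adapter wrappers
+ merged = model.merge_and_unload()
+ merged.save_pretrained("./flare25-medgemma-merged")  # hypothetical output path
+ processor.save_pretrained("./flare25-medgemma-merged")
+ ```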
+
+ ## Contact
+ For questions or issues, please open an issue in the model repository or contact the authors.
+
+ ---
+ **Disclaimer**: This model is for research and educational purposes only. Always consult qualified medical professionals for clinical decisions.
adapter_config.json ADDED
@@ -0,0 +1,39 @@
+ {
+   "alpha_pattern": {},
+   "auto_mapping": null,
+   "base_model_name_or_path": "google/medgemma-4b-it",
+   "bias": "none",
+   "corda_config": null,
+   "eva_config": null,
+   "exclude_modules": null,
+   "fan_in_fan_out": false,
+   "inference_mode": true,
+   "init_lora_weights": true,
+   "layer_replication": null,
+   "layers_pattern": null,
+   "layers_to_transform": null,
+   "loftq_config": {},
+   "lora_alpha": 32,
+   "lora_bias": false,
+   "lora_dropout": 0.1,
+   "megatron_config": null,
+   "megatron_core": "megatron.core",
+   "modules_to_save": null,
+   "peft_type": "LORA",
+   "r": 16,
+   "rank_pattern": {},
+   "revision": null,
+   "target_modules": [
+     "gate_proj",
+     "up_proj",
+     "o_proj",
+     "down_proj",
+     "v_proj",
+     "q_proj",
+     "k_proj"
+   ],
+   "task_type": "CAUSAL_LM",
+   "trainable_token_indices": null,
+   "use_dora": false,
+   "use_rslora": false
+ }
adapter_model.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:f57b648f2342820a11e8f7043c67e8dd7de460a66ab20b79a6e88299a508c6c6
+ size 131252288
chat_template.jinja ADDED
@@ -0,0 +1,47 @@
+ {{ bos_token }}
+ {%- if messages[0]['role'] == 'system' -%}
+ {%- if messages[0]['content'] is string -%}
+ {%- set first_user_prefix = messages[0]['content'] + '
+
+ ' -%}
+ {%- else -%}
+ {%- set first_user_prefix = messages[0]['content'][0]['text'] + '
+
+ ' -%}
+ {%- endif -%}
+ {%- set loop_messages = messages[1:] -%}
+ {%- else -%}
+ {%- set first_user_prefix = "" -%}
+ {%- set loop_messages = messages -%}
+ {%- endif -%}
+ {%- for message in loop_messages -%}
+ {%- if (message['role'] == 'user') != (loop.index0 % 2 == 0) -%}
+ {{ raise_exception("Conversation roles must alternate user/assistant/user/assistant/...") }}
+ {%- endif -%}
+ {%- if (message['role'] == 'assistant') -%}
+ {%- set role = "model" -%}
+ {%- else -%}
+ {%- set role = message['role'] -%}
+ {%- endif -%}
+ {{ '<start_of_turn>' + role + '
+ ' + (first_user_prefix if loop.first else "") }}
+ {%- if message['content'] is string -%}
+ {{ message['content'] | trim }}
+ {%- elif message['content'] is iterable -%}
+ {%- for item in message['content'] -%}
+ {%- if item['type'] == 'image' -%}
+ {{ '<start_of_image>' }}
+ {%- elif item['type'] == 'text' -%}
+ {{ item['text'] | trim }}
+ {%- endif -%}
+ {%- endfor -%}
+ {%- else -%}
+ {{ raise_exception("Invalid content type") }}
+ {%- endif -%}
+ {{ '<end_of_turn>
+ ' }}
+ {%- endfor -%}
+ {%- if add_generation_prompt -%}
+ {{'<start_of_turn>model
+ '}}
+ {%- endif -%}
config.json ADDED
@@ -0,0 +1,14 @@
+ {
+   "architectures": [
+     "Gemma3ForConditionalGeneration"
+   ],
+   "auto_map": {
+     "AutoConfig": "google/medgemma-4b-it--configuration_gemma.GemmaConfig",
+     "AutoModel": "google/medgemma-4b-it--modeling_gemma.GemmaForCausalLM",
+     "AutoModelForImageTextToText": "google/medgemma-4b-it--modeling_gemma.Gemma3ForConditionalGeneration"
+   },
+   "model_type": "gemma",
+   "transformers_version": "4.45.0",
+   "base_model_name_or_path": "google/medgemma-4b-it",
+   "peft_type": "LORA"
+ }
special_tokens_map.json ADDED
@@ -0,0 +1,33 @@
+ {
+   "boi_token": "<start_of_image>",
+   "bos_token": {
+     "content": "<bos>",
+     "lstrip": false,
+     "normalized": false,
+     "rstrip": false,
+     "single_word": false
+   },
+   "eoi_token": "<end_of_image>",
+   "eos_token": {
+     "content": "<eos>",
+     "lstrip": false,
+     "normalized": false,
+     "rstrip": false,
+     "single_word": false
+   },
+   "image_token": "<image_soft_token>",
+   "pad_token": {
+     "content": "<pad>",
+     "lstrip": false,
+     "normalized": false,
+     "rstrip": false,
+     "single_word": false
+   },
+   "unk_token": {
+     "content": "<unk>",
+     "lstrip": false,
+     "normalized": false,
+     "rstrip": false,
+     "single_word": false
+   }
+ }
tokenizer.json ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:4667f2089529e8e7657cfb6d1c19910ae71ff5f28aa7ab2ff2763330affad795
+ size 33384568
tokenizer_config.json ADDED
The diff for this file is too large to render. See raw diff
 
training_args.bin ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:40e151d89e2b2298440e00f62d75307e3fc674849f15739bfe36ee186454ef1e
+ size 6161