berkamphoon commited on
Commit
a75a576
·
verified ·
1 Parent(s): 79a35b5

Training in progress, epoch 6

Browse files
adapter_model.safetensors CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:35e16b795535298f0a3ed8b3ed4a1367d9eca1b89855759fafd6526288bb86ef
3
  size 6127553104
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:8b47a69a2a5f6aae43a6f6092b082ca0afcb41231d53e7ffd363c9e163d04746
3
  size 6127553104
log.txt CHANGED
@@ -1,29 +1,29 @@
1
- AUC: 0.8361
2
- Sensitivity at Specificity 80%: 0.6733
3
- Sensitivity at Specificity 85%: 0.5966
4
- Sensitivity at Specificity 90%: 0.5227
5
- Sensitivity at Specificity 95%: 0.4119
6
  ##############################
7
  Sex Group AUC
8
- Sex 0: 0.8212
9
- Sex 1: 0.8581
10
- ES-AUC Sex: 0.8358
11
  ##############################
12
  Race Group AUC
13
- Race Asian: 0.6105
14
- Race Black or African American: 0.9798
15
- Race White: 0.8419
16
- Race Other or Unknown: 0.7765
17
- ES-AUC Race: 0.8325
18
  ##############################
19
  Ethnic Group AUC
20
- Ethnic 0: 0.8374
21
- Ethnic 1: 0.7869
22
- Ethnic Unknown or Not Reported: 0.8220
23
- ES-AUC Ethnic: 0.8355
24
  ##############################
25
  Language Group AUC
26
- Language English: 0.8354
27
- Language Spanish: 0.8283
28
- Language Other or Unknown: 0.8521
29
- ES-AUC Language: 0.8359
 
1
+ AUC: 0.8415
2
+ Sensitivity at Specificity 80%: 0.7188
3
+ Sensitivity at Specificity 85%: 0.6278
4
+ Sensitivity at Specificity 90%: 0.5256
5
+ Sensitivity at Specificity 95%: 0.4148
6
  ##############################
7
  Sex Group AUC
8
+ Sex 0: 0.8426
9
+ Sex 1: 0.8472
10
+ ES-AUC Sex: 0.8414
11
  ##############################
12
  Race Group AUC
13
+ Race Asian: 0.7143
14
+ Race Black or African American: 1.0000
15
+ Race White: 0.8407
16
+ Race Other or Unknown: 0.8420
17
+ ES-AUC Race: 0.8391
18
  ##############################
19
  Ethnic Group AUC
20
+ Ethnic 0: 0.8457
21
+ Ethnic 1: 0.8248
22
+ Ethnic Unknown or Not Reported: 0.7278
23
+ ES-AUC Ethnic: 0.8404
24
  ##############################
25
  Language Group AUC
26
+ Language English: 0.8373
27
+ Language Spanish: 0.9066
28
+ Language Other or Unknown: 0.8520
29
+ ES-AUC Language: 0.8408
runs/Aug21_10-48-46_meedgxh100a/events.out.tfevents.1755787728.meedgxh100a.2323190.0 CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:c3fc5cb272d60e081d8ae134b3416e75e6c9a30b17b56ce9982ac33e8fc81b69
3
- size 19441
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:c93ad5672908e684ae1945d49192862d333e26d69c9806267dec22484de2f2de
3
+ size 21751
train_medgemma_focalft_final_amd_copy.py ADDED
@@ -0,0 +1,695 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # 특정 체크포인트로 inference
2
+ # python train.py --task amd --name my_exp --checkpoint checkpoint-938
3
+
4
+ # 평가만 실행 (최신 체크포인트 자동 선택)
5
+ # python train.py --task amd --name my_exp --eval_only
6
+
7
+ # 특정 체크포인트로 평가만 실행
8
+ # python train.py --task amd --name my_exp --checkpoint checkpoint-500 --eval_only
9
+
10
+ # 기존 방식 (훈련 후 최신 체크포인트로 평가)
11
+ # python train.py --task amd --name my_exp
12
+
13
+ from __future__ import division, print_function
14
+
15
+ # Standard library imports
16
+ import os
17
+ import os.path as osp
18
+ import random
19
+ import argparse
20
+ import logging
21
+ import shutil
22
+
23
+ # Third-party imports
24
+ from tqdm import tqdm
25
+ from PIL import Image
26
+ import numpy as np
27
+ import torch
28
+ import torch.backends.cudnn as cudnn
29
+ import torch.nn as nn
30
+ import torch.nn.functional as F
31
+ from sklearn.metrics import roc_auc_score
32
+ from transformers import AutoProcessor, AutoModelForCausalLM, BitsAndBytesConfig
33
+ from peft import LoraConfig, get_peft_model, PeftModel
34
+ from trl import SFTConfig, SFTTrainer
35
+ from torch.utils.data import Subset
36
+ import wandb
37
+
38
+ # Local imports
39
+ from utils import compute_es_auc, compute_group_auc, compute_es_auc_multi
40
+
41
# ==================== CONSTANTS ====================
# Global RNG seed consumed by setup_reproducibility().
SEED = 42

# Group categories for bias analysis
GROUPS = [
    ['0', '1'],  # Sex
    ["Asian", "Black or African American", "White", "Other or Unknown"],  # Race
    ["0", "1", "Unknown or Not Reported"],  # Ethnicity
    ["English", "Spanish", "Other or Unknown"]  # Language
]

# Mapping dictionaries for demographic data
# NOTE(review): RACEMAP / ETHNICMAP / LANGUAGEMAP are not referenced anywhere
# in this file's visible code — presumably consumed by the `utils` helpers;
# confirm before removing.
RACEMAP = {
    "Asian": 1,
    "White": 2,
    "Other or Unknown": 3,
    "Black or African American": 4
}

ETHNICMAP = {
    "0": 0,
    "1": 1,
    "Unknown or Not Reported": 2
}

LANGUAGEMAP = {
    "English": 0,
    "Spanish": 1,
    "Other or Unknown": 2,
}

# ==================== TASK-SPECIFIC CONFIGURATIONS ====================
# 'task_idx' is a NEGATIVE column index into each raw data row (see
# create_subset / run_inference); 'pos_weight'/'neg_weight' feed
# WeightedCELossFromCausalLM; the rest are SFTConfig hyper-parameters.
TASK_CONFIGS = {
    'dr': {
        'task_idx': -3,
        'disease_name': 'Diabetic Retinopathy',
        'num_epochs': 15,
        'learning_rate': 5e-4,
        'pos_weight': 0.75,
        'neg_weight': 0.25,
        'batch_size': 8,
        'lr_scheduler': 'linear'
    },
    'amd': {
        'task_idx': -2,
        'disease_name': 'Aged Macular Degeneration',
        'num_epochs': 8,
        'learning_rate': 5e-4,
        'pos_weight': 0.75,
        'neg_weight': 0.25,
        'batch_size': 8,
        'lr_scheduler': 'linear'
    },
    'glaucoma': {
        'task_idx': -1,
        'disease_name': 'Glaucoma',
        'num_epochs': 12,
        'learning_rate': 7e-4,
        'pos_weight': 0.8,
        'neg_weight': 0.2,
        'batch_size': 6,
        'lr_scheduler': 'cosine'
    }
}
105
+
106
+ # ==================== SETUP FUNCTIONS ====================
107
def setup_reproducibility():
    """Seed every RNG used by the pipeline so runs are repeatable."""
    # One seed for Python, NumPy, and all torch CPU/CUDA generators.
    for seeder in (random.seed, np.random.seed,
                   torch.manual_seed, torch.cuda.manual_seed_all):
        seeder(SEED)

    # Force deterministic cuDNN kernels and disable autotuning so repeated
    # runs pick identical convolution algorithms.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
117
+
118
+
119
def setup_logging(exp_name):
    """Create a logger that writes to the console and to ``<exp_name>/log.txt``.

    Any pre-existing log file is deleted so each run starts fresh.

    Args:
        exp_name: experiment directory (must already exist).

    Returns:
        A ``logging.Logger`` at INFO level with one stream and one file handler.
    """
    log_path = os.path.join(exp_name, "log.txt")

    if osp.isfile(log_path):
        os.remove(log_path)

    logger = logging.getLogger(__name__)
    logger.setLevel(logging.INFO)

    # Fix: the original attached a fresh console + file handler on every call,
    # so calling setup_logging() more than once (e.g. train then eval in one
    # process) duplicated every log line. Drop stale handlers first.
    for handler in list(logger.handlers):
        logger.removeHandler(handler)
        handler.close()

    # Console handler
    logger.addHandler(logging.StreamHandler())

    # File handler
    logger.addHandler(logging.FileHandler(log_path))

    return logger
138
+
139
+
140
+ # ==================== DATA PROCESSING ====================
141
def mask_until_after_assistant(labels: torch.Tensor, tokenizer, assistant_token_ids: list):
    """Mask every token up to and including the assistant marker with -100.

    For each row of ``labels``, find the first occurrence of the
    ``assistant_token_ids`` subsequence and set all positions through the end
    of that subsequence to -100 (the cross-entropy ignore index), so loss is
    only computed on the assistant's reply. Rows without the marker are left
    unchanged. ``tokenizer`` is unused but kept for interface compatibility.

    Args:
        labels: (B, L) integer tensor of token ids; modified in place.
        tokenizer: unused.
        assistant_token_ids: token-id sequence that marks the assistant turn.

    Returns:
        The (modified) ``labels`` tensor.
    """
    marker_len = len(assistant_token_ids)
    # Fix: hoisted out of the loops — the original rebuilt this tensor for
    # every (row, offset) pair, i.e. O(B*L) redundant allocations.
    marker = torch.tensor(assistant_token_ids, device=labels.device)
    for i in range(labels.size(0)):
        for j in range(labels.size(1) - marker_len + 1):
            if torch.equal(labels[i, j:j + marker_len], marker):
                labels[i, :j + marker_len] = -100  # Mask through the marker
                break
    return labels
150
+
151
+
152
def collate_fn(examples):
    """Collate image+chat examples into a padded batch with loss-masked labels."""
    # Render each chat transcript to plain text and resize each RGB image.
    chat_texts = [
        processor.apply_chat_template(
            ex["messages"], add_generation_prompt=False, tokenize=False
        ).strip()
        for ex in examples
    ]
    image_lists = [
        [ex["image"].convert("RGB").resize((IM_SIZE, IM_SIZE))]
        for ex in examples
    ]

    # Tokenize text and preprocess images together, padding to batch max.
    batch = processor(text=chat_texts, images=image_lists, return_tensors="pt", padding=True)

    # Labels start as a copy of the inputs; ignored positions become -100.
    labels = batch["input_ids"].clone()

    boi_token = processor.tokenizer.special_tokens_map["boi_token"]
    image_token_id = [processor.tokenizer.convert_tokens_to_ids(boi_token)]

    # Never compute loss on padding, the begin-of-image marker, or token 262144.
    for ignored in (processor.tokenizer.pad_token_id, image_token_id, 262144):
        labels[labels == ignored] = -100

    # Also mask the prompt (everything through the assistant marker) and the
    # final token of every row.
    labels = mask_until_after_assistant(labels, processor.tokenizer, ASST_ID)
    labels[:, -1] = -100

    batch["labels"] = labels
    return batch
193
+
194
+
195
def format_data(sample, task_idx, disease_name, system_message, img_root_path):
    """Build one supervised chat example (system/user/assistant) from a raw row."""
    # Column task_idx holds the binary label string; '0.0' means negative.
    is_negative = sample[task_idx] == '0.0'
    answer = 'negative' if is_negative else 'positive'
    prompt = f"Please diagnose whether the {disease_name} exist or not based on the given image.\n"

    return {
        # sample[1] is the image path relative to img_root_path.
        "image": Image.open(os.path.join(img_root_path, sample[1])),
        "label": 0 if is_negative else 1,
        "messages": [
            {"role": "system", "content": [{"type": "text", "text": system_message}]},
            {"role": "user", "content": [
                {"type": "image"},
                {"type": "text", "text": prompt},
            ]},
            {"role": "assistant", "content": [{"type": "text", "text": answer}]}
        ]
    }
213
+
214
+
215
def format_data_for_inference(sample, task_idx, disease_name, system_message, img_root_path):
    """Build an inference-time chat example (no assistant turn) plus group metadata.

    ``task_idx`` is unused here but kept so the signature mirrors format_data.
    """
    prompt = f"Please diagnose whether the {disease_name} exist or not based on the given image."

    system_turn = {"role": "system", "content": [{"type": "text", "text": system_message}]}
    user_turn = {
        "role": "user",
        "content": [
            {"type": "image"},
            # Trailing newline matches the training prompt formatting.
            {"type": "text", "text": prompt + "\n"},
        ],
    }

    return {
        # sample[1] is the image path relative to img_root_path.
        "image": Image.open(os.path.join(img_root_path, sample[1])),
        "messages": [system_turn, user_turn],
        # Trailing columns (demographics + labels) ride along for evaluation.
        "groups": sample[2:],
    }
231
+
232
+
233
def create_subset(data, task_idx, train=True):
    """Split rows into (negatives, positives), downsampling the negatives.

    The negative:positive ratio is task- and split-specific. For AMD the raw
    grade column is rewritten IN PLACE to a binary '0.0'/'1.0' label (only
    grade '3.0' counts as positive).

    Args:
        data: iterable of mutable rows; column ``task_idx`` holds the label.
        task_idx: -1 glaucoma, -2 AMD, -3 DR.
        train: selects the training vs validation sampling ratio.

    Returns:
        (sampled_negatives, positives)

    Raises:
        ValueError: for any other ``task_idx``.
    """
    if task_idx == -1:  # Glaucoma: anything != '0.0' is positive
        pos = [s for s in data if s[task_idx] != '0.0']
        neg = [s for s in data if s[task_idx] == '0.0']
        ratio = 10 if train else 5
        return random.sample(neg, ratio * len(pos)), pos

    if task_idx == -2:  # AMD: binarize grades; only '3.0' is positive
        neg, pos = [], []
        for s in data:
            if s[task_idx] in ['3.0']:
                s[task_idx] = '1.0'
                pos.append(s)
            else:
                s[task_idx] = '0.0'
                neg.append(s)

        num_sample = len(pos)
        if train:
            print(f"AMD - Number of positive samples: {num_sample}")
            return random.sample(neg, 10 * num_sample), pos
        return random.sample(neg, 1 * num_sample), pos

    if task_idx == -3:  # DR: anything != '0.0' is positive
        pos = [s for s in data if s[task_idx] != '0.0']
        neg = [s for s in data if s[task_idx] == '0.0']
        ratio = 5 if train else 10
        return random.sample(neg, ratio * len(pos)), pos

    raise ValueError(f"Unsupported task_idx: {task_idx}")
278
+
279
+
280
+ # ==================== MODEL COMPONENTS ====================
281
class WeightedCELossFromCausalLM(nn.Module):
    """Token-level cross-entropy with extra weight on the answer tokens.

    Up-weights the first 'positive' sub-token (module-level POS_ID) and
    down-weights the 'negative' one (NEG_ID) to counter class imbalance.
    """

    def __init__(self, pos_weight=1.5, neg_weight=0.5, ignore_index=-100):
        super().__init__()
        self.pos_weight = pos_weight
        self.neg_weight = neg_weight
        self.ignore_index = ignore_index

    def forward(self, logits, labels):
        """Return the mean weighted CE over next-token predictions.

        Args:
            logits: (B, L, V) model outputs.
            labels: (B, L) targets; positions equal to ignore_index are skipped.
        """
        # Standard causal shift: position t predicts token t+1.
        vocab = logits.size(-1)
        flat_logits = logits[..., :-1, :].contiguous().view(-1, vocab)
        flat_labels = labels[..., 1:].contiguous().view(-1)

        # Per-token loss; reduction happens after weighting.
        per_token = F.cross_entropy(
            flat_logits, flat_labels,
            ignore_index=self.ignore_index, reduction='none'
        )

        # Re-weight only the 'positive'/'negative' answer tokens.
        token_weights = torch.ones_like(per_token)
        token_weights[flat_labels == POS_ID[0]] = self.pos_weight
        token_weights[flat_labels == NEG_ID[0]] = self.neg_weight

        # Average only over non-ignored positions.
        keep = flat_labels != self.ignore_index
        return (per_token[keep] * token_weights[keep]).mean()
324
+
325
+
326
class CustomSFTTrainer(SFTTrainer):
    """SFTTrainer variant: task-weighted CE loss plus token-accuracy logging."""

    def __init__(self, task_config, *args, **kwargs):
        # Stash pos/neg loss weights before trl's own initialization runs.
        self.task_config = task_config
        super().__init__(*args, **kwargs)

    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
        """Forward pass + weighted loss; also records token counts and accuracy."""
        mode = "train" if self.model.training else "eval"
        outputs = model(**inputs)
        logits = outputs.logits
        labels = inputs["labels"]

        # Class-imbalance-aware loss with task-specific weights.
        weighted_ce = WeightedCELossFromCausalLM(
            pos_weight=self.task_config['pos_weight'],
            neg_weight=self.task_config['neg_weight']
        )
        loss = weighted_ce(logits, labels)

        # Accumulate the global number of training tokens across processes.
        if mode == "train":
            if "attention_mask" in inputs:
                num_tokens_in_batch = self.accelerator.gather_for_metrics(
                    inputs["attention_mask"].sum()
                ).sum().item()
            elif "position_ids" in inputs:
                local_num_tokens = torch.tensor(
                    inputs["position_ids"].size(1),
                    device=inputs["position_ids"].device
                )
                num_tokens_in_batch = self.accelerator.gather_for_metrics(
                    local_num_tokens
                ).sum().item()
            else:
                raise ValueError("Expected 'attention_mask' or 'position_ids' in inputs.")

            self._total_train_tokens += num_tokens_in_batch

        self._metrics[mode]["num_tokens"] = [self._total_train_tokens]

        # Mean token accuracy over non-masked label positions.
        if "labels" in inputs and not self.args.use_liger_kernel:
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = inputs["labels"][..., 1:].contiguous()

            predictions = shift_logits.argmax(dim=-1)
            mask = shift_labels != -100
            correct_predictions = (predictions == shift_labels) & mask

            # Gather across processes before reducing.
            correct_tokens = self.accelerator.gather_for_metrics(correct_predictions.sum())
            total_tokens = self.accelerator.gather_for_metrics(mask.sum())

            total_sum = total_tokens.sum()
            accuracy = (correct_tokens.sum() / total_sum).item() if total_sum > 0 else 0.0
            self._metrics[mode]["mean_token_accuracy"].append(accuracy)

        return (loss, outputs) if return_outputs else loss
385
+
386
+
387
+ # ==================== MAIN EXECUTION ====================
388
def setup_model_and_processor(model_id):
    """Load the causal-LM in 4-bit NF4 quantization together with its processor.

    The tokenizer pads on the right so answer tokens line up at the sequence
    end during training.

    Args:
        model_id: Hugging Face model identifier.

    Returns:
        (model, processor) tuple.
    """
    quant_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_use_double_quant=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.bfloat16,
        bnb_4bit_quant_storage=torch.bfloat16,
    )

    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        attn_implementation="eager",
        torch_dtype=torch.bfloat16,
        device_map="auto",
        quantization_config=quant_config,
    )

    processor = AutoProcessor.from_pretrained(model_id)
    processor.tokenizer.padding_side = "right"

    return model, processor
409
+
410
+
411
def run_inference(model, processor, val_dataset, task_idx, logger):
    """Evaluate the model on ``val_dataset`` and log overall/group-wise AUCs.

    Runs single-example batches; each image is scored as
    sigmoid(logit('positive') - logit('negative')) at the final position.
    Logs overall AUC, equity-scaled AUC, and per-group AUCs for sex, race,
    ethnicity and language.

    Args:
        model: causal-LM (already on device, eval-ready).
        processor: matching processor/tokenizer.
        val_dataset: list of examples from format_data_for_inference.
        task_idx: negative label-column index into example['groups'].
        logger: logger from setup_logging.
    """
    batch_size = 1
    model.eval()

    preds, targets = [], []
    infos = {'sex': [], 'race': [], 'ethnic': [], 'language': []}

    for i in tqdm(range(0, len(val_dataset), batch_size), desc="Running inference"):
        batch = val_dataset[i:i + batch_size]

        # Render prompts and resize images.
        texts, images = [], []
        for example in batch:
            text = processor.apply_chat_template(
                example["messages"], add_generation_prompt=True, tokenize=False
            ).strip()
            texts.append(text)
            images.append([example["image"].convert("RGB").resize((IM_SIZE, IM_SIZE))])

        with torch.no_grad():
            # Match training-time formatting, where the prompt ends in a newline.
            texts[0] += "\n"
            inputs = processor(
                text=texts, images=images,
                return_tensors="pt", padding=True
            ).to(model.device)

            outputs = model(**inputs, output_hidden_states=False, return_dict=True)
            logits = outputs.logits

            # Score = sigmoid of the (positive - negative) logit margin at the
            # last position. (Removed the unused decoded-token local.)
            probs = torch.sigmoid(logits[0, -1, POS_ID] - logits[0, -1, NEG_ID])

        # Binary target: anything other than '0.0' is positive. For AMD the
        # grade column was already binarized in create_subset.
        target_value = batch[0]['groups'][task_idx]
        target = 0.0 if target_value == '0.0' else 1.0

        preds.append(probs.detach().cpu().item())
        targets.append(target)
        # groups[0..3] = sex, race, ethnicity, language (scalar .item() values).
        infos['sex'].append(batch[0]['groups'][0].item())
        infos['race'].append(batch[0]['groups'][1].item())
        infos['ethnic'].append(batch[0]['groups'][2].item())
        infos['language'].append(batch[0]['groups'][3].item())

    # Compute and log metrics
    targets, preds = np.array(targets), np.array(preds)
    auc_score = roc_auc_score(targets, preds)
    logger.info(f"AUC: {auc_score:.4f}")

    compute_es_auc(targets, preds, logger)

    # Consistency fix: reuse the module-level GROUPS constant instead of a
    # duplicated inline copy of the same category lists.
    for group, labels in zip(['Sex', 'Race', 'Ethnic', 'Language'], GROUPS):
        compute_group_auc(
            targets, preds, infos[group.lower()], labels,
            group, logger, auc_score, False
        )
491
+
492
+
493
if __name__ == '__main__':
    # ==================== ARGUMENT PARSING ====================
    parser = argparse.ArgumentParser(description='Medical Image Classification Training')
    parser.add_argument("--task", required=True, choices=['amd', 'dr', 'glaucoma'],
                        help='Medical task: amd, dr, or glaucoma')
    parser.add_argument("--name", required=True, help='Experiment name')
    parser.add_argument("--use_subset", action='store_true',
                        help='Use balanced subset of data')
    parser.add_argument("--checkpoint", type=str, default=None,
                        help='Specific checkpoint to use for inference (e.g., checkpoint-938)')
    parser.add_argument("--eval_only", action='store_true',
                        help='Only run evaluation without training')
    args = parser.parse_args()

    # ==================== SETUP ====================
    setup_reproducibility()

    # Get task-specific configuration
    task_config = TASK_CONFIGS[args.task]
    task_idx = task_config['task_idx']
    disease_name = task_config['disease_name']

    print(f"Task: {args.task.upper()}")
    print(f"Disease: {disease_name}")
    print(f"Epochs: {task_config['num_epochs']}")
    print(f"Learning Rate: {task_config['learning_rate']}")
    print(f"Batch Size: {task_config['batch_size']}")
    print(f"LR Scheduler: {task_config['lr_scheduler']}")
    print("=" * 50)

    # System message for the model
    system_message = f"""You are an expert AI in ophthalmology.
Your primary role is to provide accurate, reliable, and up-to-date medical knowledge based on credible sources.
You must follow these guidelines:
1. Be accurate, concise, and clinically relevant.
2. Use proper medical terms.
3. Avoid overexplaining unless requested.
4. Tone: confident, professional, precise.
Do not include any explanation or thought.
If {disease_name} is present, answer exactly 'positive'. Otherwise answer 'negative'."""

    # ==================== DATA LOADING ====================
    img_root_path = '/PHShome/sy1081/exeye/data'
    train_dataset_raw = np.load('/PHShome/sy1081/exeye/data/train_final.npy')
    val_dataset_raw = np.load('/PHShome/sy1081/exeye/data/val_final.npy')

    # Create subsets; sum(..., []) flattens the (negatives, positives) pair
    # returned by create_subset into one list.
    train_dataset_raw = sum(create_subset(train_dataset_raw, task_idx, train=True), [])
    val_dataset_raw = sum(create_subset(val_dataset_raw, task_idx, train=False), [])

    # Format datasets
    train_dataset = [
        format_data(s, task_idx, disease_name, system_message, img_root_path)
        for s in tqdm(train_dataset_raw, desc="Formatting training data")
    ]
    random.shuffle(train_dataset)

    val_dataset = [
        format_data_for_inference(s, task_idx, disease_name, system_message, img_root_path)
        for s in tqdm(val_dataset_raw, desc="Formatting validation data")
    ]

    print("=" * 50)
    print(f"Dataset sizes | Train: {len(train_dataset)} | Val: {len(val_dataset)}")
    print("=" * 50)

    # ==================== MODEL SETUP ====================
    model_id = "google/medgemma-27b-it"
    model, processor = setup_model_and_processor(model_id)

    # Get token IDs. These are module-level globals consumed by
    # WeightedCELossFromCausalLM, collate_fn and run_inference.
    POS_ID = processor.tokenizer.convert_tokens_to_ids(processor.tokenizer.tokenize("positive"))
    NEG_ID = processor.tokenizer.convert_tokens_to_ids(processor.tokenizer.tokenize("negative"))
    ASST_ID = processor.tokenizer.convert_tokens_to_ids(processor.tokenizer.tokenize("model\n"))

    # Image side length used by collate_fn and run_inference.
    IM_SIZE = 512

    # LoRA configuration
    peft_config = LoraConfig(
        lora_alpha=8,
        lora_dropout=0.05,
        r=16,
        bias="none",
        target_modules="all-linear",
        task_type="CAUSAL_LM",
        modules_to_save=["lm_head", "embed_tokens"],
    )

    # ==================== EXPERIMENT SETUP ====================
    exp_name = f"{model_id.split('/')[-1]}-{args.name}"

    # Determine phase and load model if exists
    if args.eval_only or args.checkpoint:
        # Evaluation mode or specific checkpoint specified
        if args.checkpoint:
            checkpoint_path = os.path.join(exp_name, args.checkpoint)
            if not os.path.exists(checkpoint_path):
                raise ValueError(f"Specified checkpoint {checkpoint_path} does not exist")
            print(f"Loading specified checkpoint: {args.checkpoint}")
        else:
            # Find the latest checkpoint automatically for eval_only
            if not os.path.exists(exp_name):
                raise ValueError(f"Experiment directory {exp_name} does not exist")
            checkpoints = [d for d in os.listdir(exp_name) if d.startswith("checkpoint-")]
            if not checkpoints:
                print("No checkpoint found, loading base experiment...")
                checkpoint_path = exp_name
            else:
                # Sort by checkpoint number
                latest_checkpoint = sorted(checkpoints, key=lambda x: int(x.split("-")[1]))[-1]
                checkpoint_path = os.path.join(exp_name, latest_checkpoint)
                print(f"Loading latest checkpoint: {latest_checkpoint}")

        model = PeftModel.from_pretrained(model, checkpoint_path)
        phase = "eval"
        logger = setup_logging(exp_name)

    elif os.path.exists(exp_name):
        print("Loading trained PEFT weights...")
        # Find the latest checkpoint automatically
        checkpoints = [d for d in os.listdir(exp_name) if d.startswith("checkpoint-")]
        if checkpoints:
            # Sort by checkpoint number
            latest_checkpoint = sorted(checkpoints, key=lambda x: int(x.split("-")[1]))[-1]
            print(f"Loading from {latest_checkpoint}")
            model = PeftModel.from_pretrained(model, exp_name + f"/{latest_checkpoint}")
        else:
            print("No checkpoint found, loading base experiment...")
            model = PeftModel.from_pretrained(model, exp_name)
        phase = "val"
        logger = setup_logging(exp_name)
    else:
        print("Initializing new LoRA model...")
        model = get_peft_model(model, peft_config)
        model.print_trainable_parameters()
        phase = "train"
        os.makedirs(exp_name, exist_ok=True)

    # Task-specific training configuration
    # NOTE(review): push_to_hub=True requires Hugging Face auth — confirm in CI.
    training_args = SFTConfig(
        output_dir=exp_name,
        num_train_epochs=task_config['num_epochs'],
        per_device_train_batch_size=task_config['batch_size'],
        per_device_eval_batch_size=4,
        gradient_accumulation_steps=8,
        gradient_checkpointing=True,
        optim="adamw_torch_fused",
        logging_steps=10,
        save_strategy="epoch",
        eval_strategy="steps",
        eval_steps=10000,
        learning_rate=task_config['learning_rate'],
        bf16=True,
        max_grad_norm=1.0,
        warmup_ratio=0.03,
        lr_scheduler_type=task_config['lr_scheduler'],
        push_to_hub=True,
        report_to="tensorboard",
        gradient_checkpointing_kwargs={"use_reentrant": False},
        dataset_kwargs={"skip_prepare_dataset": True},
        remove_unused_columns=False,
        label_names=["labels"],
    )

    # Initialize wandb with task-specific project name
    wandb.init(
        project=f"{exp_name}-{args.task.upper()}-Project",
        name=f"{exp_name}-{args.task}",
        config=dict(training_args.to_dict(), **task_config)
    )

    # ==================== TRAINER SETUP ====================
    trainer = CustomSFTTrainer(
        task_config=task_config,
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=val_dataset,
        data_collator=collate_fn,
        peft_config=peft_config,
        processing_class=processor.tokenizer,
    )

    # Copy source code for reproducibility
    # NOTE(review): copies the canonical script path, not necessarily this
    # file — confirm the source path stays in sync.
    if phase in ["val", "eval"]:
        shutil.copy(
            "/PHShome/sy1081/exeye/train_medgemma_focalft_final.py",
            os.path.join(exp_name, f"train_medgemma_focalft_final_{args.task}_copy.py")
        )

    # ==================== TRAINING ====================
    if phase == 'train':
        print(f"Starting {args.task.upper()} training with task-specific configuration...")
        trainer.train()
        trainer.save_model(training_args.output_dir)
        logger = setup_logging(exp_name)

    # ==================== EVALUATION ====================
    # NOTE(review): a fresh 'train' run never reaches this branch (phase stays
    # 'train'); per the header usage notes, re-run the script to evaluate.
    if phase in ["val", "eval"]:
        print(f"Starting {args.task.upper()} evaluation...")
        if args.checkpoint:
            print(f"Using checkpoint: {args.checkpoint}")
        run_inference(model, processor, val_dataset, task_idx, logger)