berkamphoon commited on
Commit
e35b676
·
verified ·
1 Parent(s): d31464d

Training in progress, epoch 0

Browse files
.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ tokenizer.json filter=lfs diff=lfs merge=lfs -text
README.md ADDED
@@ -0,0 +1,58 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ base_model: google/medgemma-4b-it
3
+ library_name: transformers
4
+ model_name: medgemma-4b-it-dr5
5
+ tags:
6
+ - generated_from_trainer
7
+ - sft
8
+ - trl
9
+ licence: license
10
+ ---
11
+
12
+ # Model Card for medgemma-4b-it-dr5
13
+
14
+ This model is a fine-tuned version of [google/medgemma-4b-it](https://huggingface.co/google/medgemma-4b-it).
15
+ It has been trained using [TRL](https://github.com/huggingface/trl).
16
+
17
+ ## Quick start
18
+
19
+ ```python
20
+ from transformers import pipeline
21
+
22
+ question = "If you had a time machine, but could only go to the past or the future once and never return, which would you choose and why?"
23
+ generator = pipeline("text-generation", model="berkamphoon/medgemma-4b-it-dr5", device="cuda")
24
+ output = generator([{"role": "user", "content": question}], max_new_tokens=128, return_full_text=False)[0]
25
+ print(output["generated_text"])
26
+ ```
27
+
28
+ ## Training procedure
29
+
30
+ [<img src="https://raw.githubusercontent.com/wandb/assets/main/wandb-github-badge-28.svg" alt="Visualize in Weights & Biases" width="150" height="24"/>](https://wandb.ai/yoon307-kaist/medgemma-4b-it-dr5-Project/runs/gtg6ozbb)
31
+
32
+
33
+ This model was trained with SFT.
34
+
35
+ ### Framework versions
36
+
37
+ - TRL: 0.19.0
38
+ - Transformers: 4.51.3
39
+ - Pytorch: 2.5.0
40
+ - Datasets: 3.6.0
41
+ - Tokenizers: 0.21.1
42
+
43
+ ## Citations
44
+
45
+
46
+
47
+ Cite TRL as:
48
+
49
+ ```bibtex
50
+ @misc{vonwerra2022trl,
51
+ title = {{TRL: Transformer Reinforcement Learning}},
52
+ author = {Leandro von Werra and Younes Belkada and Lewis Tunstall and Edward Beeching and Tristan Thrush and Nathan Lambert and Shengyi Huang and Kashif Rasul and Quentin Gallou{\'e}dec},
53
+ year = 2020,
54
+ journal = {GitHub repository},
55
+ publisher = {GitHub},
56
+ howpublished = {\url{https://github.com/huggingface/trl}}
57
+ }
58
+ ```
adapter_config.json ADDED
@@ -0,0 +1,47 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "alpha_pattern": {},
3
+ "auto_mapping": null,
4
+ "base_model_name_or_path": "google/medgemma-4b-it",
5
+ "bias": "none",
6
+ "corda_config": null,
7
+ "eva_config": null,
8
+ "exclude_modules": null,
9
+ "fan_in_fan_out": false,
10
+ "inference_mode": true,
11
+ "init_lora_weights": true,
12
+ "layer_replication": null,
13
+ "layers_pattern": null,
14
+ "layers_to_transform": null,
15
+ "loftq_config": {},
16
+ "lora_alpha": 16,
17
+ "lora_bias": false,
18
+ "lora_dropout": 0.05,
19
+ "megatron_config": null,
20
+ "megatron_core": "megatron.core",
21
+ "modules_to_save": [
22
+ "lm_head",
23
+ "embed_tokens"
24
+ ],
25
+ "peft_type": "LORA",
26
+ "qalora_group_size": 16,
27
+ "r": 16,
28
+ "rank_pattern": {},
29
+ "revision": null,
30
+ "target_modules": [
31
+ "q_proj",
32
+ "k_proj",
33
+ "fc1",
34
+ "out_proj",
35
+ "up_proj",
36
+ "down_proj",
37
+ "gate_proj",
38
+ "v_proj",
39
+ "fc2",
40
+ "o_proj"
41
+ ],
42
+ "task_type": "CAUSAL_LM",
43
+ "trainable_token_indices": null,
44
+ "use_dora": false,
45
+ "use_qalora": false,
46
+ "use_rslora": false
47
+ }
adapter_model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:d42fc8e20b081bdc403e82f0abfbfc247cfa3b40dc005dad28b8ca245f84feb9
3
+ size 2839124552
added_tokens.json ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ {
2
+ "<image_soft_token>": 262144
3
+ }
runs/Jul15_11-02-42_seribizon/events.out.tfevents.1752591767.seribizon.4176318.0 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:ba5dfd3fd51e2c00b8b34e24788fd50aae2bfeea01ec7fdbc9b0e9191445b793
3
+ size 8559
special_tokens_map.json ADDED
@@ -0,0 +1,33 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "boi_token": "<start_of_image>",
3
+ "bos_token": {
4
+ "content": "<bos>",
5
+ "lstrip": false,
6
+ "normalized": false,
7
+ "rstrip": false,
8
+ "single_word": false
9
+ },
10
+ "eoi_token": "<end_of_image>",
11
+ "eos_token": {
12
+ "content": "<eos>",
13
+ "lstrip": false,
14
+ "normalized": false,
15
+ "rstrip": false,
16
+ "single_word": false
17
+ },
18
+ "image_token": "<image_soft_token>",
19
+ "pad_token": {
20
+ "content": "<pad>",
21
+ "lstrip": false,
22
+ "normalized": false,
23
+ "rstrip": false,
24
+ "single_word": false
25
+ },
26
+ "unk_token": {
27
+ "content": "<unk>",
28
+ "lstrip": false,
29
+ "normalized": false,
30
+ "rstrip": false,
31
+ "single_word": false
32
+ }
33
+ }
tokenizer.json ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:1ebf1915455f8237564395182c49e3c685cfe3533b3d50ec6d49ce65ec43c32e
3
+ size 33384723
tokenizer.model ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:1299c11d7cf632ef3b4e11937501358ada021bbdf7c47638d13c0ee982f2e79c
3
+ size 4689074
tokenizer_config.json ADDED
The diff for this file is too large to render. See raw diff
 
train_medgemma_ft_copy.py ADDED
@@ -0,0 +1,420 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import division, print_function
2
+
3
+ # === Base ===
4
+ import os
5
+ import os.path as osp
6
+ import random
7
+ import argparse
8
+ import logging
9
+ from tqdm import tqdm
10
+ from matplotlib import pyplot as plt
11
+ import pdb
12
+ from PIL import Image
13
+ import shutil
14
+ import os
15
+
16
+ # === DL ===
17
+ import numpy as np
18
+ import torch
19
+ import torch.backends.cudnn as cudnn
20
+ from torch.utils.data import DataLoader
21
+ from torch.utils.tensorboard import SummaryWriter
22
+
23
+ # === Custom ===
24
+ import tools.imutils as imutils
25
+ import tools.utils as utils
26
+ import tools.pyutils as pyutils
27
+ from tools.utils import compute_es_auc, compute_group_auc, ImprovedBalancedBatchSampler, compute_es_auc_multi
28
+
29
+ # === Evaluation ===
30
+ from sklearn.metrics import roc_curve, accuracy_score, roc_auc_score
31
+
32
+ # === Transformers ===
33
+ from transformers import AutoProcessor, AutoModelForImageTextToText, BitsAndBytesConfig, pipeline
34
+ from peft import LoraConfig, get_peft_model
35
+ from trl import SFTTrainer, SFTConfig
36
+ import wandb
37
+
38
+ # === Label Masking Function ===
39
def mask_until_after_assistant(labels: torch.Tensor, tokenizer, assistant_token_ids: list):
    """Mask every token up to and including the assistant marker with -100.

    For each row of `labels`, find the first occurrence of the token-id
    sequence `assistant_token_ids` and set every position up to the end of
    that occurrence to -100 so it is ignored by the loss.

    Args:
        labels: (batch, seq_len) tensor of token ids; modified in place.
        tokenizer: unused; kept for interface compatibility with callers.
        assistant_token_ids: token-id sequence marking the assistant turn.

    Returns:
        The same `labels` tensor (also mutated in place).
    """
    n = len(assistant_token_ids)
    # Fix: build the comparison tensor once instead of re-allocating it on
    # every window of the inner loop (the original rebuilt it per iteration).
    target = torch.tensor(assistant_token_ids, device=labels.device)
    for i in range(labels.size(0)):
        for j in range(labels.size(1) - n + 1):
            if torch.equal(labels[i, j:j + n], target):
                labels[i, :j + n] = -100  # mask prompt up to "ASSISTANT:" marker
                break
    return labels
46
+
47
+
48
+ # === Collate Function ===
49
def collate_fn(examples):
    """Batch collator: build processor inputs plus loss-masked labels.

    Renders each example's chat messages to text, resizes its image to
    512x512, runs the (global) processor over the batch, then derives
    `labels` from `input_ids` with padding, image-marker, and prompt
    tokens masked to -100 so only the assistant answer contributes loss.
    """
    texts, images = [], []
    for ex in examples:
        rgb = ex["image"].convert("RGB").resize((512, 512))
        images.append([rgb])
        rendered = processor.apply_chat_template(
            ex["messages"], add_generation_prompt=False, tokenize=False
        )
        texts.append(rendered.strip())

    # Tokenize the texts and preprocess the images in a single call.
    batch = processor(text=texts, images=images, return_tensors="pt", padding=True)

    # Labels start as a copy of input_ids; non-target tokens get -100.
    labels = batch["input_ids"].clone()

    # Id of the begin-of-image marker token (kept wrapped in a list,
    # matching how the comparison below broadcasts).
    image_token_id = [
        processor.tokenizer.convert_tokens_to_ids(
            processor.tokenizer.special_tokens_map["boi_token"]
        )
    ]

    # Exclude padding, the image marker, and the soft-image token from loss.
    labels[labels == processor.tokenizer.pad_token_id] = -100
    labels[labels == image_token_id] = -100
    labels[labels == 262144] = -100  # <image_soft_token> id (see added_tokens)

    # Everything up to the assistant-turn marker is prompt, not target;
    # the final position is masked as well.
    labels = mask_until_after_assistant(labels, processor.tokenizer, ASST_ID)
    labels[:, -1] = -100

    batch["labels"] = labels
    return batch
84
+
85
def format_data(sample):
    """Build one training example: image, chat messages, and binary label.

    `sample[1]` is the image path relative to the global `img_root_path`;
    `sample[task_idx]` is the disease label string, where '0.0' denotes a
    negative case (any other value is treated as positive).
    """
    # Fix: the original compared against '0,0' (comma typo) when computing
    # the numeric label while the text label used '0.0', so
    # example["label"] was always 1. Use one comparison for both.
    is_negative = sample[task_idx] == '0.0'
    label = 'negative' if is_negative else 'positive'
    prompt = f"Please diagnose whether the {disease_name} exist or not based on the given image.\n"

    example = {}
    example["image"] = Image.open(os.path.join(img_root_path, sample[1]))
    example["label"] = 0 if is_negative else 1
    example["messages"] = [
        {"role": "system", "content": [{"type": "text", "text": system_message}]},
        {"role": "user", "content": [
            # Image is passed separately by the collator; only a marker here.
            {"type": "image"},
            {"type": "text", "text": prompt},
        ]},
        {"role": "assistant", "content": [{"type": "text", "text": str(label)}]}
    ]

    return example
104
+
105
def format_data_for_inference(sample):
    """Build one inference example: image plus prompt-only chat messages.

    Same structure as `format_data` but without the assistant turn (the
    model generates the answer) and without a "label" field.
    """
    question = f"Please diagnose whether the {disease_name} exist or not based on the given image.\n"

    return {
        "image": Image.open(os.path.join(img_root_path, sample[1])),
        "messages": [
            {"role": "system", "content": [{"type": "text", "text": system_message}]},
            {"role": "user", "content": [
                # Image is supplied separately at inference time.
                {"type": "image"},
                # An extra newline is appended on top of the prompt's own
                # trailing "\n" — kept to match the original formatting.
                {"type": "text", "text": question + "\n"},
            ]},
        ],
    }
130
+
131
+ # === Logit Preprocessing ===
132
def slice_logits(logits, labels):
    """Reduce trainer logits to a detached CPU tensor.

    HF trainers may hand logits over as a tuple/list whose first element
    is the actual logits tensor; unwrap it, then detach and move to CPU so
    metric accumulation does not pin GPU memory. `labels` is unused (kept
    for the `preprocess_logits_for_metrics` signature).
    """
    tensor = logits[0] if isinstance(logits, (tuple, list)) else logits
    return tensor.detach().cpu()
136
+
137
def compute_metrics(eval_pred):
    """Compute ROC-AUC and threshold-tuned accuracy from eval predictions.

    The class probability is sigmoid(logit[positive] - logit[negative]) at
    the final sequence position; ground truth is the first unmasked label
    token compared against the positive-class token id.
    """
    logits = torch.tensor(eval_pred.predictions)

    token_ids = logits.argmax(dim=-1)  # (B, L): predicted token at each position

    # NOTE(review): `batch_logits` and `pred_texts` below are computed but
    # never used by the returned metrics; this scan also raises if neither
    # class token appears anywhere in a sequence. Kept as-is to preserve
    # that fail-fast behavior.
    batch_logits = []
    for b in range(logits.size(0)):
        seq = token_ids[b]  # (L,)
        idxs = torch.where((seq == POS_ID[0]) | (seq == NEG_ID[0]))[0]
        if len(idxs) == 0:
            raise ValueError(f"Neither pos_id nor neg_id found in sequence {b}")
        t = idxs[0].item()  # first position where pos or neg appears
        tok_id = seq[t].item()  # should be either pos_id or neg_id
        batch_logits.append(logits[b, t, tok_id])  # scalar

    batch_logits = torch.stack(batch_logits)  # shape: [B]
    pred_texts = processor.tokenizer.batch_decode(token_ids[:,-1], skip_special_tokens=True)

    # Positive-class probability from the logit margin at the last position.
    probs = torch.sigmoid(logits[:,-1, POS_ID[0]] - logits[:,-1, NEG_ID[0]]).numpy()

    labels = torch.tensor(eval_pred.label_ids)
    # First non-masked label token per row is the ground-truth answer token.
    gt_ids = labels[labels != -100].view(logits.size(0), -1)[:, 0]
    y_true = (gt_ids == POS_ID[0]).int().cpu().numpy()
    auc_val = roc_auc_score(y_true, probs)
    # Pick the Youden-J optimal threshold, then report accuracy at it.
    fpr, tpr, thr = roc_curve(y_true, probs)
    best = thr[np.argmax(tpr - fpr)]
    acc = accuracy_score(y_true, probs >= best)
    return {"roc_auc": auc_val, "accuracy": acc}
168
+
169
def run_custom_evaluation(trainer, val_dataset, val_labels):
    """Run `trainer.predict` on the validation set and report ROC-AUC.

    Probabilities are derived from the positive/negative logit margin at
    the final sequence position; `val_labels` supplies the binary ground
    truth. Prints and returns the AUC.
    """
    prediction_output = trainer.predict(val_dataset)
    logits = torch.from_numpy(prediction_output.predictions)  # (B, S, V)
    # Positive-class probability from the logit margin at the last position.
    probs = torch.sigmoid(logits[:, -1, POS_ID[0]] - logits[:, -1, NEG_ID[0]]).numpy()

    auc_val = roc_auc_score(val_labels, probs)
    print(f"[Custom Eval] AUC: {auc_val:.4f}")
    return {"auc": auc_val}
183
+
184
+ # === Main ===
185
# Entry point: fine-tune MedGemma-4B (4-bit QLoRA + SFT) on a binary
# ophthalmology classification task, then run a logit-based AUC evaluation
# on the validation split. Paths are machine-specific (see img_root_path).
if __name__ == '__main__':
    # --- CLI ---
    parser = argparse.ArgumentParser()
    parser.add_argument("--task", required=True, help='amd, dr, glaucoma')
    parser.add_argument("--name", required=True)
    parser.add_argument("--use_subset", action='store_true')
    args = parser.parse_args()

    # Seed all RNGs for reproducibility (project helper).
    pyutils.same_seeds(0)

    # Map task name -> (column index into the sample array, display name).
    # Indices are negative: labels sit in the last three columns.
    task_map = {'dr': (-3, 'Diabetic Retinopathy'), 'amd': (-2, 'Aged Macular Degeneration'), 'glaucoma': (-1, 'Glaucoma')}
    task_idx, disease_name = task_map[args.task]
    system_message = f"""You are an expert AI in ophthalmology.\n
    Your primary role is to provide accurate, reliable, and up-to-date medical knowledge based on credible sources.\n
    "You must follow these guidelines:\n"
    "1. Be accurate, concise, and clinically relevant.\n"
    "2. Use proper medical terms.\n"
    "3. Avoid overexplaining unless requested.\n"
    "4. Tone: confident, professional, precise.\n"
    "Do not include any explanation or thought."
    If {disease_name} is present, answer exactly 'positive'. Otherwise answer 'negative'."""

    # --- Data loading (machine-specific absolute paths) ---
    cudnn.benchmark = True
    img_root_path = '/shared/ssd_30T/yoon/exEYE/Eyeproject/data'
    train_dataset = np.load('/shared/ssd_30T/yoon/exEYE/datasplit/train_final.npy')
    val_dataset_raw = np.load('/shared/ssd_30T/yoon/exEYE/datasplit/val_final.npy')

    if args.use_subset:
        # Rebalance by subsampling: 5:1 negative:positive for training,
        # 1:1 for validation.
        def subset(data,train=True):
            neg = [s for s in data if s[task_idx] == '0.0']
            pos = [s for s in data if s[task_idx] != '0.0']
            num_sample = len(pos)
            if train:
                return random.sample(neg, 5*num_sample), random.sample(pos, num_sample)
            else:
                return random.sample(neg, num_sample), random.sample(pos, num_sample)
        train_dataset = sum(subset(train_dataset,train=True), [])
        val_dataset_raw = sum(subset(val_dataset_raw,train=False), [])

    # Convert raw rows into chat-formatted examples.
    train_dataset = [format_data(s) for s in tqdm(train_dataset)]
    random.shuffle(train_dataset)
    val_dataset = [format_data_for_inference(s) for s in tqdm(val_dataset_raw)]
    val_labels = [1 if s[task_idx] != '0.0' else 0 for s in val_dataset_raw]
    print("="*50)
    print(f"Total number of Data| Train: {len(train_dataset)} | Val : {len(val_dataset)}")
    print("="*50)

    # --- Model: MedGemma-4B in 4-bit NF4 quantization (QLoRA setup) ---
    model_id = "google/medgemma-4b-it"
    model_kwargs = dict(
        attn_implementation="eager",
        torch_dtype=torch.bfloat16,
        device_map="auto",
    )

    model_kwargs["quantization_config"] = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_use_double_quant=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=model_kwargs["torch_dtype"],
        bnb_4bit_quant_storage=model_kwargs["torch_dtype"],
    )

    model = AutoModelForImageTextToText.from_pretrained(model_id, **model_kwargs)
    processor = AutoProcessor.from_pretrained(model_id)

    # Use right padding to avoid issues during training
    processor.tokenizer.padding_side = "right"

    # Token ids for the answer words and the assistant-turn marker.
    # NOTE(review): each of these is a *list* of ids (tokenize may split
    # a word into several pieces); downstream code indexes [0].
    POS_ID = processor.tokenizer.convert_tokens_to_ids(processor.tokenizer.tokenize("positive")) #30558
    NEG_ID = processor.tokenizer.convert_tokens_to_ids(processor.tokenizer.tokenize("negative")) #27851
    ASST_ID = processor.tokenizer.convert_tokens_to_ids(processor.tokenizer.tokenize("model\n"))


    # --- LoRA configuration (adapters on all linear layers) ---
    peft_config = LoraConfig(
        lora_alpha=16,
        lora_dropout=0.05,
        r=16,
        bias="none",
        target_modules="all-linear",
        task_type="CAUSAL_LM",
        modules_to_save=[
            "lm_head",
            "embed_tokens",
        ],
    )


    exp_name = f"{model_id.split('/')[-1]}-{args.name}"

    # Resume/eval if an experiment directory already exists; otherwise
    # start a fresh LoRA run.
    if os.path.exists(exp_name):
        from peft import PeftModel
        print("🔁 Loading trained PEFT weights...")
        model = PeftModel.from_pretrained(model, exp_name)
        phase= "val"
    else:
        print("🚀 Initializing new LoRA model...")
        model = get_peft_model(model, peft_config)
        model.print_trainable_parameters()
        phase= "train"


    training_args = SFTConfig(
        output_dir=exp_name,
        num_train_epochs=15,                        # Number of training epochs
        per_device_train_batch_size=4,              # Batch size per device during training
        per_device_eval_batch_size=4,               # Batch size per device during evaluation
        gradient_accumulation_steps=8,              # Number of steps before performing a backward/update pass
        gradient_checkpointing=True,                # Enable gradient checkpointing to reduce memory usage
        optim="adamw_torch_fused",                  # Use fused AdamW optimizer for better performance
        logging_steps=10,                           # Number of steps between logs
        save_strategy="epoch",                      # Save checkpoint every epoch
        eval_strategy="steps",                      # Evaluate every `eval_steps`
        eval_steps=10000,                           # Number of steps between evaluations
        learning_rate=3e-4,                         # Learning rate based on QLoRA paper
        bf16=True,                                  # Use bfloat16 precision
        max_grad_norm=0.3,                          # Max gradient norm based on QLoRA paper
        warmup_ratio=0.03,                          # Warmup ratio based on QLoRA paper
        lr_scheduler_type="linear",                 # Use linear learning rate scheduler
        push_to_hub=True,                           # Push model to Hub
        report_to="tensorboard",                    # Report metrics to tensorboard
        gradient_checkpointing_kwargs={"use_reentrant": False},  # Set gradient checkpointing to non-reentrant to avoid issues
        dataset_kwargs={"skip_prepare_dataset": True},  # Skip default dataset preparation to preprocess manually
        remove_unused_columns = False,              # Columns are unused for training but needed for data collator
        label_names=["labels"],
    )

    wandb.init(project=f"{exp_name}-Project", name=exp_name, config=training_args)

    trainer = SFTTrainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=val_dataset,
        data_collator=collate_fn,
        peft_config=peft_config,
        processing_class=processor.tokenizer,
    )

    # Snapshot the training script into the experiment directory for
    # provenance (source path is hard-coded to this machine).
    shutil.copy("/shared/ssd_30T/yoon/exEYE/Eyeproject/train_medgemma_ft.py",os.path.join(".",exp_name,"train_medgemma_ft_copy.py"))

    if phase == 'train':
        trainer.train()
        trainer.save_model(training_args.output_dir)

    # --- Manual logit-based validation pass ---
    # Runs the model forward (no generation) one example at a time and
    # scores positive/negative from the last-position logit margin.
    batch_size = 1
    model.eval()
    all_logits = []

    for i in tqdm(range(0, len(val_dataset), batch_size), desc="Running inference with logits"):
        batch = val_dataset[i:i + batch_size]

        # prepare inputs
        texts = []
        images = []
        for example in batch:
            text = processor.apply_chat_template(
                example["messages"], add_generation_prompt=True, tokenize=False
            ).strip()
            texts.append(text)
            image = example["image"].convert("RGB").resize((512, 512))
            images.append([image])

        # tokenizer & image processor
        with torch.no_grad():
            # Append a newline so the next predicted token is the answer
            # word itself (mirrors the training-time formatting).
            texts[0] += "\n"
            inputs = processor(
                text=texts,
                images=images,
                return_tensors="pt",
                padding=True
            ).to(model.device)

            outputs = model(**inputs, output_hidden_states=False, return_dict=True)

            # Debug: print the greedy token predicted at the last position.
            print(processor.tokenizer.decode(outputs.logits[0].argmax(-1)[-1]))

            # logits: (B, L, V)
            all_logits.append(outputs.logits.to(torch.float32).detach().cpu().numpy())


    # Stack per-example logits; squeeze the batch-of-1 dimension.
    # NOTE(review): np.stack requires every sequence length to match —
    # works here only because batch_size == 1 would still differ across
    # examples if prompts tokenize to different lengths; confirm.
    logits= torch.from_numpy(np.stack(all_logits,axis=0)).squeeze(1)

    probs = torch.sigmoid(logits[:,-1, POS_ID] - logits[:,-1, NEG_ID])

    auc_val = roc_auc_score(val_labels, probs)
    print(auc_val)
training_args.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:917843d0db4b87699709b89729f6dbbf7627e023b2ea7d95950d17712c751c5e
3
+ size 5752