Fahad-S committed on
Commit f2ba706 · verified · 1 Parent(s): 951a8f6

Upload noqueries_code/train.py with huggingface_hub

Files changed (1)
  1. noqueries_code/train.py +754 -0
noqueries_code/train.py ADDED
@@ -0,0 +1,754 @@
+ import os
+ import copy
+ from dataclasses import dataclass, field
+ import logging
+ import pathlib
+ from typing import Dict, Optional, Sequence
+ import torch
+ import glob
+ import transformers
+ import tokenizers
+ from blip3o.constants import IGNORE_INDEX, DEFAULT_IMAGE_TOKEN, IMAGE_TOKEN_IDX, DEFAULT_IM_START_TOKEN_IDX
+ from torch.utils.data import Dataset
+ from blip3o.train.blip3o_trainer import blip3oTrainer
+ from blip3o import conversation as conversation_lib
+ from blip3o.model import *
+ from blip3o.mm_utils import tokenizer_image_token
+ from PIL import Image, ImageFile
+ from datasets import load_dataset, concatenate_datasets
+ from pathlib import Path
+ from datasets.utils.logging import set_verbosity_info
+ from transformers import logging as tf_logging
+ import torchvision.transforms as T
+ from torchvision.transforms.functional import InterpolationMode
+ from transformers import AutoProcessor
+ import random
+ from blip3o.model.multimodal_encoder.eva_clip.eva_clip_processors import EvaClipImageTrainProcessor
+
+ ImageFile.LOAD_TRUNCATED_IMAGES = True
+ transform_und_images = T.Compose([T.Resize(448, interpolation=InterpolationMode.BICUBIC, antialias=True), T.CenterCrop(448)])
+
+ set_verbosity_info()
+ tf_logging.set_verbosity_info()
+
+ local_rank = None
+ from transformers import TrainerCallback
+
+ class GradCheckCallback(TrainerCallback):
+     def on_step_end(self, args, state, control, **kwargs):
+         model = kwargs["model"]
+         for name, param in model.named_parameters():
+             if "caption_embed" in name or "diffusion_connector" in name:
+                 if param.grad is None:
+                     print(f"{name} has NO gradient!")
+                 else:
+                     print(f"{name} grad mean: {param.grad.abs().mean().item():.6f}")
+
+ def rank0_print(*args):
+     if local_rank == 0:
+         print(*args)
+
+
+ from packaging import version
+
+
+ @dataclass
+ class ModelArguments:
+     model_name_or_path: Optional[str] = field(default="facebook/opt-125m")
+     version: Optional[str] = field(default="v0")
+     freeze_backbone: bool = field(default=True)
+     tune_mm_mlp_adapter: bool = field(default=False)
+     vision_tower: Optional[str] = field(default=None)
+     gen_vision_tower: Optional[str] = field(default=None)
+     mm_vision_select_layer: Optional[int] = field(default=-1)  # default to the last layer
+     pretrain_mm_mlp_adapter: Optional[str] = field(default=None)
+     pretrain_gen_mlp_adapter: Optional[str] = field(default=None)
+     vision_tower_pretrained: Optional[str] = field(default=None)
+     mm_projector_type: Optional[str] = field(default="linear")
+     gen_projector_type: Optional[str] = field(default="linear")
+     mm_use_im_start_end: bool = field(default=False)
+     mm_use_im_patch_token: bool = field(default=True)
+     mm_patch_merge_type: Optional[str] = field(default="flat")
+     mm_vision_select_feature: Optional[str] = field(default="patch")
+     n_query: Optional[int] = field(default=729)  # clip 576, siglip 729
+     n_und_query: Optional[int] = field(default=729)  # clip 576, siglip 729
+     gen_pooling: Optional[str] = field(default="all")  # options: pool2d_3, pool2d_9, seq_3, seq_9, seq_27
+     diffusion_name_or_path: Optional[str] = field(default="Efficient-Large-Model/Sana_600M_1024px_diffusers")
+
+
+ @dataclass
+ class DataArguments:
+     data_path: Optional[str] = field(default=None, metadata={"help": "Path to the training data."})
+     lazy_preprocess: bool = False
+     is_multimodal: bool = False
+     image_folder: Optional[str] = field(default=None)
+     journeyDB_folder: Optional[str] = field(default=None)
+     shortcaption_image_folder: Optional[str] = field(default=None)
+     data_type: Optional[str] = field(default="mix")
+     image_aspect_ratio: str = "square"
+
+
+ @dataclass
+ class TrainingArguments(transformers.TrainingArguments):
+     cache_dir: Optional[str] = field(default=None)
+     optim: str = field(default="adamw_torch")
+     remove_unused_columns: bool = field(default=False)
+     freeze_mm_mlp_adapter: bool = field(default=False)
+     mpt_attn_impl: Optional[str] = field(default="triton")
+     model_max_length: int = field(
+         default=512,
+         metadata={"help": "Maximum sequence length. Sequences will be right padded (and possibly truncated)."},
+     )
+     double_quant: bool = field(
+         default=True,
+         metadata={"help": "Compress the quantization statistics through double quantization."},
+     )
+     quant_type: str = field(
+         default="nf4",
+         metadata={"help": "Quantization data type to use. Should be one of `fp4` or `nf4`."},
+     )
+     bits: int = field(default=16, metadata={"help": "How many bits to use."})
+     lora_enable: bool = False
+     lora_r: int = 64
+     lora_alpha: int = 16
+     lora_dropout: float = 0.05
+     lora_weight_path: str = ""
+     lora_bias: str = "none"
+     mm_projector_lr: Optional[float] = None
+     group_by_modality_length: bool = field(default=False)
+     ddp_find_unused_parameters: bool = True
+
+ ASPECT_RATIO_512 = {
+     "0.25": [256.0, 1024.0],
+     "0.26": [256.0, 992.0],
+     "0.27": [256.0, 960.0],
+     "0.28": [256.0, 928.0],
+     "0.32": [288.0, 896.0],
+     "0.33": [288.0, 864.0],
+     "0.35": [288.0, 832.0],
+     "0.4": [320.0, 800.0],
+     "0.42": [320.0, 768.0],
+     "0.48": [352.0, 736.0],
+     "0.5": [352.0, 704.0],
+     "0.52": [352.0, 672.0],
+     "0.57": [384.0, 672.0],
+     "0.6": [384.0, 640.0],
+     "0.68": [416.0, 608.0],
+     "0.72": [416.0, 576.0],
+     "0.78": [448.0, 576.0],
+     "0.82": [448.0, 544.0],
+     "0.88": [480.0, 544.0],
+     "0.94": [480.0, 512.0],
+     "1.0": [1024.0, 1024.0],
+     "1.07": [512.0, 480.0],
+     "1.13": [544.0, 480.0],
+     "1.21": [544.0, 448.0],
+     "1.29": [576.0, 448.0],
+     "1.38": [576.0, 416.0],
+     "1.46": [608.0, 416.0],
+     "1.67": [640.0, 384.0],
+     "1.75": [672.0, 384.0],
+     "2.0": [704.0, 352.0],
+     "2.09": [736.0, 352.0],
+     "2.4": [768.0, 320.0],
+     "2.5": [800.0, 320.0],
+     "2.89": [832.0, 288.0],
+     "3.0": [864.0, 288.0],
+     "3.11": [896.0, 288.0],
+     "3.62": [928.0, 256.0],
+     "3.75": [960.0, 256.0],
+     "3.88": [992.0, 256.0],
+     "4.0": [1024.0, 256.0],
+ }
+ print("Input size: ", ASPECT_RATIO_512["1.0"])
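+ # Each entry maps an aspect ratio to a [height, width] bucket of roughly 512x512
+ # pixels, except "1.0", which has been bumped to 1024x1024. Because
+ # get_closest_ratio below is pinned to the "1.0" key, every image currently ends
+ # up resized and center-cropped to a 1024x1024 square (hence the print above).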
+
+ def maybe_zero_3(param, ignore_status=False, name=None):
+     from deepspeed import zero
+     from deepspeed.runtime.zero.partition_parameters import ZeroParamStatus
+
+     if hasattr(param, "ds_id"):
+         if param.ds_status == ZeroParamStatus.NOT_AVAILABLE:
+             if not ignore_status:
+                 logging.warning(f"{name}: param.ds_status != ZeroParamStatus.NOT_AVAILABLE: {param.ds_status}")
+         with zero.GatheredParameters([param]):
+             param = param.data.detach().cpu().clone()
+     else:
+         param = param.detach().cpu().clone()
+     return param
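+ # Under DeepSpeed ZeRO-3 a parameter's storage is sharded across ranks, so its
+ # local .data can be empty; maybe_zero_3 gathers the full tensor first
+ # (zero.GatheredParameters) and only then detaches a CPU copy. A minimal usage
+ # sketch with names from this file:
+ #
+ #     projector_weights = {n: maybe_zero_3(p, ignore_status=True)
+ #                          for n, p in model.named_parameters() if "mm_projector" in n}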
+
+
+ def get_mm_adapter_state_maybe_zero_3(named_params, keys_to_match):
+     to_return = {k: t for k, t in named_params if any(key_match in k for key_match in keys_to_match)}
+     to_return = {k: maybe_zero_3(v, ignore_status=True).cpu() for k, v in to_return.items()}
+     return to_return
+
+
+ def safe_save_model_for_hf_trainer(trainer: transformers.Trainer, output_dir: str, vision_tower: str):
+     if trainer.deepspeed:
+         torch.cuda.synchronize()
+     # Save the multimodal (understanding) projector weights
+     keys_to_match = ["mm_projector"]
+     if getattr(trainer.args, "use_im_start_end", False):
+         keys_to_match.extend(["embed_tokens", "embed_in"])
+
+     weight_to_save = get_mm_adapter_state_maybe_zero_3(trainer.model.named_parameters(), keys_to_match)
+     trainer.model.config.save_pretrained(output_dir)
+
+     current_folder = output_dir.split("/")[-1]
+     parent_folder = os.path.dirname(output_dir)
+     if trainer.args.local_rank == 0 or trainer.args.local_rank == -1:
+         if current_folder.startswith("checkpoint-"):
+             mm_projector_folder = os.path.join(parent_folder, "mm_projector")
+             os.makedirs(mm_projector_folder, exist_ok=True)
+             torch.save(
+                 weight_to_save,
+                 os.path.join(mm_projector_folder, f"{current_folder}.bin"),
+             )
+         else:
+             torch.save(weight_to_save, os.path.join(output_dir, "mm_projector.bin"))
+
+     # Repeat for the generation projector weights
+     keys_to_match = ["gen_projector"]
+     if getattr(trainer.args, "use_im_start_end", False):
+         keys_to_match.extend(["embed_tokens", "embed_in"])
+
+     weight_to_save = get_mm_adapter_state_maybe_zero_3(trainer.model.named_parameters(), keys_to_match)
+     trainer.model.config.save_pretrained(output_dir)
+
+     current_folder = output_dir.split("/")[-1]
+     parent_folder = os.path.dirname(output_dir)
+     if trainer.args.local_rank == 0 or trainer.args.local_rank == -1:
+         if current_folder.startswith("checkpoint-"):
+             gen_projector_folder = os.path.join(parent_folder, "gen_projector")
+             os.makedirs(gen_projector_folder, exist_ok=True)
+             torch.save(
+                 weight_to_save,
+                 os.path.join(gen_projector_folder, f"{current_folder}.bin"),
+             )
+         else:
+             torch.save(weight_to_save, os.path.join(output_dir, "gen_projector.bin"))
+
+     if trainer.deepspeed:
+         torch.cuda.synchronize()
+         trainer.save_model(output_dir)
+         return
+
+     state_dict = trainer.model.state_dict()
+     if trainer.args.should_save:
+         cpu_state_dict = {key: value.cpu() for key, value in state_dict.items()}
+         del state_dict
+         trainer._save(output_dir, state_dict=cpu_state_dict)  # noqa
+
+
+ def smart_tokenizer_and_embedding_resize(
+     special_tokens_dict: Dict,
+     tokenizer: transformers.PreTrainedTokenizer,
+     model: transformers.PreTrainedModel,
+ ):
+     """Add special tokens and resize embeddings, initializing new rows to the mean embedding."""
+     num_new_tokens = tokenizer.add_special_tokens(special_tokens_dict)
+     model.resize_token_embeddings(len(tokenizer))
+
+     if num_new_tokens > 0:
+         input_embeddings = model.get_input_embeddings().weight.data
+         input_embeddings_avg = input_embeddings[:-num_new_tokens].mean(dim=0, keepdim=True)
+         input_embeddings[-num_new_tokens:] = input_embeddings_avg
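+ # Note: only the *input* embedding rows are re-initialized to the mean of the
+ # pre-existing vocabulary; the output embedding (lm_head) rows are left as
+ # resize_token_embeddings created them, which is presumably acceptable here
+ # since train() freezes lm_head whenever freeze_backbone is set (the default).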
+
+
+ def preprocess_multimodal(sources: Sequence[str], data_args: DataArguments) -> Dict:
+     is_multimodal = data_args.is_multimodal
+     if not is_multimodal:
+         return sources, None  # keep the two-value contract expected by callers
+     und_placeholder = "<|vision_start|>" + "<|image_pad|>" * data_args.n_und_query + "<|vision_end|>"
+     gen_placeholder = ""
+     # "[IMG]" + "<image>" * data_args.n_query + "[/IMG]"
+     inst_type = None
+     for source in sources:  # [instance]
+         for sentence in source:
+             if sentence["from"] == "human" and "<image>" in sentence["value"]:
+                 sentence["value"] = sentence["value"].replace(DEFAULT_IMAGE_TOKEN, und_placeholder).strip()
+                 inst_type = "und"
+             elif sentence["from"] == "gpt" and "<image>" in sentence["value"]:
+                 sentence["value"] = sentence["value"].replace(DEFAULT_IMAGE_TOKEN, gen_placeholder).strip()
+                 inst_type = "gen"
+     return sources, inst_type
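+ # Worked example (with the default n_und_query = 729): a human turn
+ #     "Describe this image. <image>"
+ # becomes
+ #     "Describe this image. <|vision_start|><|image_pad|>...x729...<|vision_end|>"
+ # and sets inst_type="und", while a gpt turn "<image>" becomes the empty string
+ # with inst_type="gen" -- for generation samples the image block is presumably
+ # appended downstream (see the 17 reserved positions in the collator below).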
+
+
+ def preprocess_qwen(sources, tokenizer: transformers.PreTrainedTokenizer, has_image: bool = False, max_len=2048, system_message: str = "You are a helpful assistant.") -> Dict:
+     roles = {"human": "user", "gpt": "assistant"}
+
+     tokenizer = copy.deepcopy(tokenizer)
+     chat_template = "{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}"
+     tokenizer.chat_template = chat_template
+
+     # Apply prompt templates
+     input_ids, targets = [], []
+     for i, source in enumerate(sources):
+         if roles[source[0]["from"]] != roles["human"]:
+             source = source[1:]
+
+         input_id, target = [], []
+
+         # Build the system message for each conversation via apply_chat_template
+         input_id += tokenizer.apply_chat_template([{"role": "system", "content": system_message}])
+         target += [IGNORE_INDEX] * len(input_id)
+
+         for conv in source:
+             try:
+                 role = conv["role"]
+                 content = conv["content"]
+             except KeyError:
+                 role = conv["from"]
+                 content = conv["value"]
+
+             role = roles.get(role, role)
+
+             conv = [{"role": role, "content": content}]
+             encode_id = tokenizer.apply_chat_template(conv)
+             input_id += encode_id
+             if role in ["user", "system"]:
+                 target += [IGNORE_INDEX] * len(encode_id)
+             else:
+                 target += encode_id
+
+         assert len(input_id) == len(target), f"{len(input_id)} != {len(target)}"
+
+         input_ids.append(input_id)
+         targets.append(target)
+     input_ids = torch.tensor(input_ids, dtype=torch.long)
+     targets = torch.tensor(targets, dtype=torch.long)
+
+     return dict(
+         input_ids=input_ids,  # tensor(bs x seq_len)
+         labels=targets,  # tensor(bs x seq_len)
+     )
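+ # With the ChatML-style template above, one round of a T2I sample tokenizes to
+ #     <|im_start|>system\nYou are a helpful assistant.<|im_end|>\n
+ #     <|im_start|>user\nPlease generate image based on ...<|im_end|>\n
+ #     <|im_start|>assistant\n<|im_end|>\n
+ # and labels mirror input_ids on assistant spans while system/user spans are
+ # filled with IGNORE_INDEX, so the loss is computed only on assistant tokens.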
+
+ def get_closest_ratio(height: float, width: float, ratios: dict):
+     aspect_ratio = height / width
+     # Pinned to the square bucket here; the nearest-ratio lookup is kept for reference:
+     closest_ratio = "1.0"  # min(ratios.keys(), key=lambda ratio: abs(float(ratio) - aspect_ratio))
+     return ratios[closest_ratio], float(closest_ratio)
+
+
+ class LazySupervisedMixDataset(Dataset):
+     def __init__(
+         self,
+         data_path: str,
+         tokenizer: transformers.PreTrainedTokenizer,
+         data_args: DataArguments,
+     ):
+         super(LazySupervisedMixDataset, self).__init__()
+
+         self.data_args = data_args
+         list_data_dict = []
+
+         ###################################### text to image #######################################
+         data_files = glob.glob(os.path.join(self.data_args.image_folder, "*.tar"))
+         # data_files = glob.glob(os.path.join('/proj/cvl/users/x_fahkh2/BLIP3o/dataset/BLIP3o-Pretrain-Long-Caption', "*.tar")) + glob.glob(os.path.join('/proj/cvl/users/x_fahkh2/BLIP3o/dataset/BLIP3o-Pretrain-Short-Caption', "*.tar")) + glob.glob(os.path.join('/proj/cvl/users/x_fahkh2/BLIP3o/dataset/BLIP3o-Pretrain-JourneyDB', "*.tar"))
+         train_dataset = load_dataset("webdataset", data_files=data_files, split="train", num_proc=32)
+         train_dataset = train_dataset.rename_column("jpg", "image")
+         train_dataset = train_dataset.add_column("type", len(train_dataset) * ["T2I"])
+         train_dataset = train_dataset.add_column("image_path", len(train_dataset) * [None])
+         train_dataset = train_dataset.remove_columns([col for col in train_dataset.column_names if col not in ["image", "txt", "type", "image_path"]])
+         print(f"finish loading image {len(train_dataset)}")
+         list_data_dict.append(train_dataset)
+
+         if len(list_data_dict) > 1:
+             list_data_dict = concatenate_datasets(list_data_dict)
+         else:
+             list_data_dict = list_data_dict[0]
+         list_data_dict = list_data_dict.shuffle(seed=42)
+
+         rank0_print(f"Total number of training instances: {len(list_data_dict)}")
+         self.tokenizer = tokenizer
+         self.list_data_dict = list_data_dict
+
+     def __len__(self):
+         return len(self.list_data_dict)
+
+     @property
+     def lengths(self):
+         length_list = []
+         for sample in self.list_data_dict:
+             img_tokens = 128 if "image" in sample else 0
+             length_list.append(sum(len(conv["value"].split()) for conv in sample["conversations"]) + img_tokens)
+         return length_list
+
+     @property
+     def modality_lengths(self):
+         length_list = []
+         for sample in self.list_data_dict:
+             cur_len = sum(len(conv["value"].split()) for conv in sample["conversations"])
+             cur_len = cur_len if "image" in sample else -cur_len
+             length_list.append(cur_len)
+         return length_list
+
+     def _safe_img_process(self, imgs):
+         try:
+             out = []
+             for img in imgs:
+                 ori_h, ori_w = img.height, img.width
+                 closest_size, closest_ratio = get_closest_ratio(ori_h, ori_w, ASPECT_RATIO_512)
+                 closest_size = [int(x) for x in closest_size]
+                 if closest_size[0] / ori_h > closest_size[1] / ori_w:
+                     resize_size = closest_size[0], int(ori_w * closest_size[0] / ori_h)
+                 else:
+                     resize_size = int(ori_h * closest_size[1] / ori_w), closest_size[1]
+                 transform = T.Compose([
+                     T.Lambda(lambda img: img.convert("RGB")),
+                     T.Resize(resize_size, interpolation=InterpolationMode.BICUBIC),  # Image.BICUBIC
+                     T.CenterCrop(closest_size),
+                     T.ToTensor(),
+                     T.Normalize([0.5], [0.5]),
+                 ])
+                 out.append(transform(img))
+             return out
+         except Exception as e:
+             print(f"Corrupted image during processing: {e}")
+             return None
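+     # Resize math, e.g. a 768x1024 (h x w) input against the square bucket
+     # [1024, 1024]: 1024/768 > 1024/1024, so resize to (1024, 1365) -- the side
+     # that is relatively shorter is matched first -- then CenterCrop to
+     # 1024x1024, so the crop never needs padding. Normalize([0.5], [0.5]) maps
+     # pixels from [0, 1] to [-1, 1], presumably the range the diffusion branch expects.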
+
+     def __getitem__(self, i) -> Dict[str, torch.Tensor]:
+         while True:
+             try:
+                 sources = self.list_data_dict[i]
+                 sources["conversations"] = [
+                     {"from": "human", "value": f"Please generate image based on the following caption: {sources['txt']}"},
+                     {"from": "gpt", "value": "<image>"},
+                 ]
+                 image_files = self.list_data_dict[i]["image"]
+                 if not isinstance(image_files, list):
+                     image_files = [image_files]
+
+                 images = []
+                 for img in image_files:
+                     img = img.convert("RGB")
+                     images.append(img)
+
+                 processed_images = self._safe_img_process(images)
+                 if processed_images is None:
+                     print("Corrupted image during transform, picking new sample.")
+                     i = random.randint(0, len(self.list_data_dict) - 1)
+                     continue
+                 # just replace <image> with "" in generation tasks
+                 sources, inst_type = preprocess_multimodal(copy.deepcopy([sources["conversations"]]), self.data_args)
+                 data_dict = preprocess_qwen(sources, self.tokenizer, has_image=("image" in self.list_data_dict[i]))
+                 if isinstance(i, int):
+                     data_dict = dict(input_ids=data_dict["input_ids"][0], labels=data_dict["labels"][0])
+
+                 data_dict["gen_image"] = processed_images[0]
+                 data_dict["ids"] = self.list_data_dict[i]["id"] if "id" in self.list_data_dict[i] else "unk"
+                 return data_dict
+             except Exception as e:
+                 print(f"[WARN] Skipping corrupted sample {i}: {e}")
+                 i = random.randint(0, len(self.list_data_dict) - 1)
+                 continue
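+ # Neither a decode failure nor a transform failure can crash an epoch:
+ # LazySupervisedMixDataset.__getitem__ resamples a fresh random index and
+ # retries until it can return a valid dict, silently skipping bad records.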
+
+
+ @dataclass
+ class DataCollatorForSupervisedDataset(object):
+     """Collate examples for supervised fine-tuning."""
+
+     tokenizer: transformers.PreTrainedTokenizer
+
+     def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]:
+         input_ids, labels, ids = tuple([instance[key] for instance in instances] for key in ("input_ids", "labels", "ids"))
+         multi_input_ids = []
+         multi_labels = []
+         i_s_pos = []
+         for input_id, label in zip(input_ids, labels):
+             input_id = input_id[: self.tokenizer.model_max_length - 17]
+             label = label[: self.tokenizer.model_max_length - 17]
+             i_s_pos.append(input_id.shape[0] + 1)
+             img_id = torch.full((17,), IMAGE_TOKEN_IDX, dtype=input_id.dtype, device=input_id.device)
+             img_id[0] = DEFAULT_IM_START_TOKEN_IDX
+             # input_id = torch.cat([input_id, img_id])
+             img_label = torch.full((17,), IMAGE_TOKEN_IDX, dtype=label.dtype, device=label.device)
+             img_label[0] = DEFAULT_IM_START_TOKEN_IDX
+             # label = torch.cat([label, img_label])
+             multi_input_ids.append(input_id)
+             multi_labels.append(label)
+
+         input_ids = multi_input_ids
+         labels = multi_labels
+
+         input_ids = torch.nn.utils.rnn.pad_sequence(input_ids, batch_first=True, padding_value=self.tokenizer.pad_token_id)
+         labels = torch.nn.utils.rnn.pad_sequence(labels, batch_first=True, padding_value=IGNORE_INDEX)
+         if input_ids.shape[1] > self.tokenizer.model_max_length:
+             print(f"Warning: input with length {input_ids.shape[1]} is longer than max length {self.tokenizer.model_max_length}")
+             input_ids = input_ids[:, : self.tokenizer.model_max_length]
+             labels = labels[:, : self.tokenizer.model_max_length]
+         attention_mask = input_ids.ne(self.tokenizer.pad_token_id)
+         batch = dict(
+             input_ids=input_ids,
+             labels=labels,
+             attention_mask=attention_mask,
+         )
+
+         batch_gen_images = []
+         batch_und_images = []
+         batch_grid_thw = []
+
+         for instance in instances:
+             if "gen_image" in instance:
+                 batch_gen_images.append(instance["gen_image"])
+
+         if len(batch_gen_images) > 0:
+             # Stack into one tensor when every image shares the same shape
+             if all(x is not None and x.shape == batch_gen_images[0].shape for x in batch_gen_images):
+                 batch["gen_image"] = torch.cat([image.unsqueeze(0) for image in batch_gen_images], dim=0)
+             else:
+                 batch["gen_image"] = batch_gen_images
+         else:
+             batch["gen_image"] = None
+
+         for instance in instances:
+             if "und_image" in instance:
+                 batch_und_images.append(instance["und_image"].unsqueeze(0))  # 1 x 1024 x 1176
+                 batch_grid_thw.append(instance["grid_thw"])  # 1 x 3
+
+         # print(f"batch_und_images {batch_und_images}")
+         if len(batch_und_images) > 0:
+             batch["und_image"] = torch.cat(batch_und_images, dim=0)
+             batch["grid_thw"] = torch.cat(batch_grid_thw, dim=0)
+         else:
+             batch["und_image"] = None
+             batch["grid_thw"] = None
+
+         batch["ids"] = ids
+         batch["i_s_pos"] = i_s_pos
+         return batch
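+ # The "- 17" truncation above reserves room for one image-start token plus 16
+ # image positions (img_id / img_label). The actual concatenation is commented
+ # out, so the image block is presumably inserted downstream by the model using
+ # i_s_pos, which records where each sequence's image embeddings should begin.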
+
+
+ def make_supervised_data_module(tokenizer: transformers.PreTrainedTokenizer, data_args) -> Dict:
+     train_dataset = LazySupervisedMixDataset(tokenizer=tokenizer, data_path=data_args.data_path, data_args=data_args)
+     data_collator = DataCollatorForSupervisedDataset(tokenizer=tokenizer)
+     return dict(train_dataset=train_dataset, eval_dataset=None, data_collator=data_collator)
+
+
+ def train(attn_implementation=None):
+     global local_rank
+
+     parser = transformers.HfArgumentParser((ModelArguments, DataArguments, TrainingArguments))
+     model_args, data_args, training_args = parser.parse_args_into_dataclasses()
+     print(model_args, data_args, training_args)
+     local_rank = training_args.local_rank
+     compute_dtype = torch.float16 if training_args.fp16 else (torch.bfloat16 if training_args.bf16 else torch.float32)
+
+     bnb_model_from_pretrained_args = {}
+     if training_args.bits in [4, 8]:
+         from transformers import BitsAndBytesConfig
+
+         bnb_model_from_pretrained_args.update(
+             dict(
+                 device_map={"": training_args.device},
+                 load_in_4bit=training_args.bits == 4,
+                 load_in_8bit=training_args.bits == 8,
+                 quantization_config=BitsAndBytesConfig(
+                     load_in_4bit=training_args.bits == 4,
+                     load_in_8bit=training_args.bits == 8,
+                     llm_int8_skip_modules=["mm_projector"],
+                     llm_int8_threshold=6.0,
+                     llm_int8_has_fp16_weight=False,
+                     bnb_4bit_compute_dtype=compute_dtype,
+                     bnb_4bit_use_double_quant=training_args.double_quant,
+                     bnb_4bit_quant_type=training_args.quant_type,  # {'fp4', 'nf4'}
+                 ),
+             )
+         )
+
+     model = blip3oFastForCausalLM.from_pretrained(
+         model_args.model_name_or_path,
+         cache_dir=training_args.cache_dir,
+         # attn_implementation=attn_implementation,
+         torch_dtype=(torch.bfloat16 if training_args.bf16 else None),
+         **bnb_model_from_pretrained_args,
+     )
+
+     model.config.use_cache = False
+
+     if model_args.freeze_backbone:
+         for (n, p) in model.get_model().named_parameters():
+             p.requires_grad = False
+         for (n, p) in model.get_vision_tower().named_parameters():
+             p.requires_grad = False
+         for (n, p) in model.lm_head.named_parameters():
+             p.requires_grad = False
+
+     # for (n, p) in model.get_model().named_parameters():
+     #     p.requires_grad = True
+     # for (n, p) in model.get_vision_tower().named_parameters():
+     #     p.requires_grad = False
+
+     # for (n, p) in model.get_model().embed_tokens.named_parameters():
+     #     p.requires_grad = True
+
+     if training_args.gradient_checkpointing:
+         if hasattr(model, "enable_input_require_grads"):
+             model.enable_input_require_grads()
+         else:
+
+             def make_inputs_require_grad(module, input, output):
+                 output.requires_grad_(True)
+
+             model.get_input_embeddings().register_forward_hook(make_inputs_require_grad)
+
+     try:
+         tokenizer = AutoProcessor.from_pretrained(model_args.model_name_or_path).tokenizer
+     except Exception:
+         tokenizer = AutoProcessor.from_pretrained(model_args.model_name_or_path)
+
+     tokenizer.model_max_length = training_args.model_max_length
+
+     # tokenizer.pad_token = tokenizer.unk_token
+     if tokenizer.pad_token is None:
+         smart_tokenizer_and_embedding_resize(
+             special_tokens_dict=dict(
+                 pad_token="<pad>",
+                 additional_special_tokens=["[IMG]", "[/IMG]", "<image>"],
+             ),
+             tokenizer=tokenizer,
+             model=model,
+         )
+     elif "<image>" not in tokenizer.get_added_vocab():
+         smart_tokenizer_and_embedding_resize(
+             special_tokens_dict=dict(additional_special_tokens=["[IMG]", "[/IMG]", "<image>"]),
+             tokenizer=tokenizer,
+             model=model,
+         )
+     if model_args.version in conversation_lib.conv_templates:
+         conversation_lib.default_conversation = conversation_lib.conv_templates[model_args.version]
+     else:
+         conversation_lib.default_conversation = conversation_lib.conv_templates["llama3"]
+     rank0_print(f"Using conversation format: {conversation_lib.default_conversation.version}")
+
+     # if model_args.vision_tower is not None:
+     model.get_model().initialize_vision_modules(model_args=model_args, fsdp=training_args.fsdp)
+     image_processor = model.get_model().get_vision_tower().image_processor
+     data_args.gen_image_processor = image_processor
+     data_args.image_processor = image_processor
+
+     data_args.is_multimodal = True
+     data_args.n_query = model_args.n_query
+     data_args.n_und_query = model_args.n_und_query
+
+     model.config.image_aspect_ratio = data_args.image_aspect_ratio
+     model.config.tokenizer_padding_side = tokenizer.padding_side
+     model.config.tokenizer_model_max_length = tokenizer.model_max_length
+
+     model.config.tune_mm_mlp_adapter = training_args.tune_mm_mlp_adapter = model_args.tune_mm_mlp_adapter
+     model.config.freeze_mm_mlp_adapter = training_args.freeze_mm_mlp_adapter
+
+     # Calculate total parameters and trainable parameters
+     total_params = sum(p.numel() for p in model.get_model().parameters())
+     trainable_params = sum(p.numel() for p in model.get_model().parameters() if p.requires_grad)
+     print(f"Total parameters: {total_params}")
+     print(f"Trainable parameters: {trainable_params}")
+
+     model.config.mm_use_im_start_end = data_args.mm_use_im_start_end = model_args.mm_use_im_start_end
+     model.config.mm_projector_lr = training_args.mm_projector_lr
+     training_args.use_im_start_end = model_args.mm_use_im_start_end
+     model.config.mm_use_im_patch_token = model_args.mm_use_im_patch_token
+     # TODO: what is this?
+     model.initialize_vision_tokenizer(model_args, tokenizer=tokenizer)
+     model.config.pad_token_id = tokenizer.pad_token_id
+
+     data_module = make_supervised_data_module(tokenizer=tokenizer, data_args=data_args)
+
+     trainer = blip3oTrainer(
+         model=model,
+         tokenizer=tokenizer,
+         args=training_args,
+         # callbacks=[GradCheckCallback],
+         **data_module,
+     )
+     from tabulate import tabulate
+
+     if trainer.is_world_process_zero():
+         stat = []
+         for i, (n, p) in enumerate(trainer.model.named_parameters()):
+             stat.append([i, n, p.shape, p.requires_grad])
+         print(tabulate(stat, headers=["idx", "name", "shape", "trainable"]))
+
+     '''
+     from safetensors.torch import load_file
+     import json
+     import pathlib
+
+     # ---- Load model.safetensors if it exists ----
+     checkpoint_dir = pathlib.Path(training_args.output_dir)
+     safetensor_path = checkpoint_dir / "model.safetensors"
+     trainer_state_path = checkpoint_dir / "trainer_state.json"
+
+     if safetensor_path.exists():
+         print(f"Loading weights from {safetensor_path}")
+         state_dict = load_file(safetensor_path)
+         new_state_dict = {}
+         for k, v in state_dict.items():
+             new_key = k.replace("model.", "", 1) if k.startswith("model.") else k
+             new_state_dict[new_key] = v
+
+         # print all keys
+         # print("🔑 Keys in checkpoint:")
+         # for k in state_dict.keys():
+         #     print(k, state_dict[k].shape)
+
+         missing, unexpected = model.get_model().load_state_dict(new_state_dict, strict=False)
+         print("✅ Loaded parameters:")
+         for k in new_state_dict.keys():
+             if k not in missing:
+                 print(f"  {k} {tuple(new_state_dict[k].shape)}")
+
+         # Restore last global step
+         if trainer_state_path.exists():
+             with open(trainer_state_path, "r") as f:
+                 trainer_state = json.load(f)
+             last_global_step = trainer_state.get("global_step", 0)
+             last_lr = trainer_state.get("learning_rate", trainer.args.learning_rate)
+             trainer.state.global_step = last_global_step
+             # Reset optimizer with last learning rate
+             trainer.create_optimizer_and_scheduler(num_training_steps=trainer.args.max_steps)
+             optimizer = trainer.optimizer
+             # lr_scheduler = trainer.lr_scheduler
+
+             for param_group in optimizer.param_groups:
+                 param_group['lr'] = last_lr
+             trainer.optimizer = optimizer
+             print(f"✅ Restored global step: {last_global_step}, learning rate: {last_lr}")
+     '''
+     if list(pathlib.Path(training_args.output_dir).glob("checkpoint-*")):
+         trainer.train(resume_from_checkpoint=True)
+     else:
+         trainer.train()
+     trainer.save_state()
+
+     model.config.use_cache = True
+     safe_save_model_for_hf_trainer(
+         trainer=trainer,
+         output_dir=training_args.output_dir,
+         vision_tower=model_args.vision_tower,
+     )
+
+
+ if __name__ == "__main__":
+     train()
+
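+ # Illustrative launch command (a sketch: the model id, paths, and hyperparameters
+ # below are placeholders, not taken from this file; the custom flags map to the
+ # ModelArguments/DataArguments/TrainingArguments dataclasses above):
+ #
+ #     torchrun --nproc_per_node=8 noqueries_code/train.py \
+ #         --model_name_or_path <qwen-vl-style-checkpoint> \
+ #         --image_folder /path/to/tar_shards \
+ #         --output_dir ./checkpoints/run1 \
+ #         --bf16 True \
+ #         --model_max_length 512 \
+ #         --per_device_train_batch_size 4 \
+ #         --learning_rate 1e-4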