import os import copy from dataclasses import dataclass, field import logging import pathlib from typing import Dict, Optional, Sequence import torch import glob import transformers from blip3o.constants import IGNORE_INDEX, DEFAULT_IMAGE_TOKEN, IMAGE_TOKEN_IDX, DEFAULT_IM_START_TOKEN_IDX from torch.utils.data import Dataset from blip3o.train.blip3o_trainer import blip3oTrainer from blip3o import conversation as conversation_lib from blip3o.model import * from PIL import ImageFile from datasets import load_dataset, concatenate_datasets from pathlib import Path from datasets.utils.logging import set_verbosity_info from transformers import logging as tf_logging import torchvision.transforms as T from torchvision.transforms.functional import InterpolationMode from transformers import AutoProcessor, AutoTokenizer import random from datasets import load_dataset, concatenate_datasets import wandb ImageFile.LOAD_TRUNCATED_IMAGES = True transform_und_images = T.Compose([T.Resize(448, interpolation=InterpolationMode.BICUBIC, antialias=True), T.CenterCrop(448)]) set_verbosity_info() tf_logging.set_verbosity_info() local_rank = None from transformers import TrainerCallback class GradCheckCallback(TrainerCallback): def on_step_end(self, args, state, control, **kwargs): model = kwargs["model"] for name, param in model.named_parameters(): if "caption_embed" in name or "diffusion_connector" in name: if param.grad is None: print(f"{name} has NO gradient!") else: print(f"{name} grad mean: {param.grad.abs().mean().item():.6f}") def rank0_print(*args): if local_rank == 0: print(*args) from packaging import version @dataclass class ModelArguments: model_name_or_path: Optional[str] = field(default="facebook/opt-125m") version: Optional[str] = field(default="v0") freeze_backbone: bool = field(default=True) tune_mm_mlp_adapter: bool = field(default=False) vision_tower: Optional[str] = field(default="mobileclip_l_1024") gen_vision_tower: Optional[str] = field(default=None) mm_vision_select_layer: Optional[int] = field(default=-1) # default to the last layer pretrain_mm_mlp_adapter: Optional[str] = field(default=None) pretrain_gen_mlp_adapter: Optional[str] = field(default=None) vision_tower_pretrained: Optional[str] = field(default=None) mm_projector_type: Optional[str] = field(default="linear") gen_projector_type: Optional[str] = field(default="linear") mm_use_im_start_end: bool = field(default=False) mm_use_im_patch_token: bool = field(default=True) mm_patch_merge_type: Optional[str] = field(default="flat") mm_vision_select_feature: Optional[str] = field(default="patch") n_query: Optional[int] = field(default=729) # clip 576, siglip 729 n_und_query: Optional[int] = field(default=729) # clip 576, siglip 729 gen_pooling: Optional[str] = field(default="all") # options are: pool2d_3, pool2d_9, seq_3, seq_9, seq_27 diffusion_name_or_path: Optional[str] = field(default="Efficient-Large-Model/Sana_600M_512px_diffusers") teacher_model_name: Optional[str] = field(default="Efficient-Large-Model/SANA1.5_4.8B_1024px_diffusers") @dataclass class DataArguments: data_path: str = field(default=None, metadata={"help": "Path to the training data."}) lazy_preprocess: bool = False is_multimodal: bool = False image_folder: Optional[str] = field(default=None) journeyDB_folder: Optional[str] = field(default=None) shortcaption_image_folder: Optional[str] = field(default=None) data_type: Optional[str] = field(default="mix") image_aspect_ratio: str = "square" @dataclass class TrainingArguments(transformers.TrainingArguments): cache_dir: Optional[str] = field(default=None) optim: str = field(default="adamw_torch") remove_unused_columns: bool = field(default=False) freeze_mm_mlp_adapter: bool = field(default=False) mpt_attn_impl: Optional[str] = field(default="triton") model_max_length: int = field( default=512, metadata={"help": "Maximum sequence length. Sequences will be right padded (and possibly truncated)."}, ) double_quant: bool = field( default=True, metadata={"help": "Compress the quantization statistics through double quantization."}, ) quant_type: str = field( default="nf4", metadata={"help": "Quantization data type to use. Should be one of `fp4` or `nf4`."}, ) bits: int = field(default=16, metadata={"help": "How many bits to use."}) lora_enable: bool = False lora_r: int = 64 lora_alpha: int = 16 lora_dropout: float = 0.05 lora_weight_path: str = "" lora_bias: str = "none" mm_projector_lr: Optional[float] = None group_by_modality_length: bool = field(default=False) ddp_find_unused_parameters: bool =True ASPECT_RATIO_512 = { "0.25": [256.0, 1024.0], "0.26": [256.0, 992.0], "0.27": [256.0, 960.0], "0.28": [256.0, 928.0], "0.32": [288.0, 896.0], "0.33": [288.0, 864.0], "0.35": [288.0, 832.0], "0.4": [320.0, 800.0], "0.42": [320.0, 768.0], "0.48": [352.0, 736.0], "0.5": [352.0, 704.0], "0.52": [352.0, 672.0], "0.57": [384.0, 672.0], "0.6": [384.0, 640.0], "0.68": [416.0, 608.0], "0.72": [416.0, 576.0], "0.78": [448.0, 576.0], "0.82": [448.0, 544.0], "0.88": [480.0, 544.0], "0.94": [480.0, 512.0], "1.0": [512.0, 512.0], "1.07": [512.0, 480.0], "1.13": [544.0, 480.0], "1.21": [544.0, 448.0], "1.29": [576.0, 448.0], "1.38": [576.0, 416.0], "1.46": [608.0, 416.0], "1.67": [640.0, 384.0], "1.75": [672.0, 384.0], "2.0": [704.0, 352.0], "2.09": [736.0, 352.0], "2.4": [768.0, 320.0], "2.5": [800.0, 320.0], "2.89": [832.0, 288.0], "3.0": [864.0, 288.0], "3.11": [896.0, 288.0], "3.62": [928.0, 256.0], "3.75": [960.0, 256.0], "3.88": [992.0, 256.0], "4.0": [1024.0, 256.0], } def maybe_zero_3(param, ignore_status=False, name=None): from deepspeed import zero from deepspeed.runtime.zero.partition_parameters import ZeroParamStatus if hasattr(param, "ds_id"): if param.ds_status == ZeroParamStatus.NOT_AVAILABLE: if not ignore_status: logging.warning(f"{name}: param.ds_status != ZeroParamStatus.NOT_AVAILABLE: {param.ds_status}") with zero.GatheredParameters([param]): param = param.data.detach().cpu().clone() else: param = param.detach().cpu().clone() return param def get_mm_adapter_state_maybe_zero_3(named_params, keys_to_match): to_return = {k: t for k, t in named_params if any(key_match in k for key_match in keys_to_match)} to_return = {k: maybe_zero_3(v, ignore_status=True).cpu() for k, v in to_return.items()} return to_return def safe_save_model_for_hf_trainer(trainer: transformers.Trainer, output_dir: str, vision_tower: str): if trainer.deepspeed: torch.cuda.synchronize() keys_to_match = ["mm_projector"] if getattr(trainer.args, "use_im_start_end", False): keys_to_match.extend(["embed_tokens", "embed_in"]) weight_to_save = get_mm_adapter_state_maybe_zero_3(trainer.model.named_parameters(), keys_to_match) trainer.model.config.save_pretrained(output_dir) current_folder = output_dir.split("/")[-1] parent_folder = os.path.dirname(output_dir) if trainer.args.local_rank == 0 or trainer.args.local_rank == -1: if current_folder.startswith("checkpoint-"): mm_projector_folder = os.path.join(parent_folder, "mm_projector") os.makedirs(mm_projector_folder, exist_ok=True) torch.save( weight_to_save, os.path.join(mm_projector_folder, f"{current_folder}.bin"), ) else: torch.save(weight_to_save, os.path.join(output_dir, f"mm_projector.bin")) keys_to_match = ["gen_projector"] if getattr(trainer.args, "use_im_start_end", False): keys_to_match.extend(["embed_tokens", "embed_in"]) weight_to_save = get_mm_adapter_state_maybe_zero_3(trainer.model.named_parameters(), keys_to_match) trainer.model.config.save_pretrained(output_dir) current_folder = output_dir.split("/")[-1] parent_folder = os.path.dirname(output_dir) if trainer.args.local_rank == 0 or trainer.args.local_rank == -1: if current_folder.startswith("checkpoint-"): mm_projector_folder = os.path.join(parent_folder, "gen_projector") os.makedirs(mm_projector_folder, exist_ok=True) torch.save( weight_to_save, os.path.join(mm_projector_folder, f"{current_folder}.bin"), ) else: torch.save(weight_to_save, os.path.join(output_dir, f"gen_projector.bin")) if trainer.deepspeed: torch.cuda.synchronize() trainer.save_model(output_dir) return state_dict = trainer.model.state_dict() if trainer.args.should_save: cpu_state_dict = {key: value.cpu() for key, value in state_dict.items()} del state_dict trainer._save(output_dir, state_dict=cpu_state_dict) # noqa def smart_tokenizer_and_embedding_resize( special_tokens_dict: Dict, tokenizer: transformers.PreTrainedTokenizer, model: transformers.PreTrainedModel, ): num_new_tokens = tokenizer.add_special_tokens(special_tokens_dict) model.resize_token_embeddings(len(tokenizer)) if num_new_tokens > 0: input_embeddings = model.get_input_embeddings().weight.data input_embeddings_avg = input_embeddings[:-num_new_tokens].mean(dim=0, keepdim=True) input_embeddings[-num_new_tokens:] = input_embeddings_avg def preprocess_multimodal(sources: Sequence[str], data_args: DataArguments) -> Dict: is_multimodal = data_args.is_multimodal if not is_multimodal: return sources und_placeholder = "<|vision_start|>" + "<|image_pad|>" * data_args.n_und_query + "<|vision_end|>" gen_placeholder = "" # "[IMG]" + "" * data_args.n_query + "[/IMG]" inst_type = None for source in sources: # [instance] for sentence in source: if sentence["from"] == "human" and "" in sentence["value"]: sentence["value"] = sentence["value"].replace(DEFAULT_IMAGE_TOKEN, und_placeholder).strip() inst_type = "und" elif sentence["from"] == "gpt" and "" in sentence["value"]: sentence["value"] = sentence["value"].replace(DEFAULT_IMAGE_TOKEN, gen_placeholder).strip() inst_type = "gen" return sources, inst_type def preprocess_qwen(sources, tokenizer: transformers.PreTrainedTokenizer, has_image: bool = False, max_len=2048, system_message: str = "You are a helpful assistant.") -> Dict: roles = {"human": "user", "gpt": "assistant"} tokenizer = copy.deepcopy(tokenizer) chat_template = "{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}" tokenizer.chat_template = chat_template # Apply prompt templates input_ids, targets = [], [] for i, source in enumerate(sources): if roles[source[0]["from"]] != roles["human"]: source = source[1:] input_id, target = [], [] # New version, use apply chat template # Build system message for each sentence input_id += tokenizer.apply_chat_template([{"role" : "system", "content" : system_message}]) target += [IGNORE_INDEX] * len(input_id) for conv in source: try: role = conv["role"] content = conv["content"] except: role = conv["from"] content = conv["value"] role = roles.get(role, role) conv = [{"role" : role, "content" : content}] encode_id = tokenizer.apply_chat_template(conv) input_id += encode_id if role in ["user", "system"]: target += [IGNORE_INDEX] * len(encode_id) else: target += encode_id assert len(input_id) == len(target), f"{len(input_id)} != {len(target)}" input_ids.append(input_id) targets.append(target) input_ids = torch.tensor(input_ids, dtype=torch.long) targets = torch.tensor(targets, dtype=torch.long) return dict( input_ids=input_ids, # tensor(bs x seq_len) labels=targets, # tensor(bs x seq_len) ) def get_closest_ratio(height: float, width: float, ratios: dict): aspect_ratio = height / width closest_ratio = "1.0" #min(ratios.keys(), key=lambda ratio: abs(float(ratio) - aspect_ratio)) return ratios[closest_ratio], float(closest_ratio) class LazySupervisedMixDataset(Dataset): def __init__( self, data_path: str, tokenizer: transformers.PreTrainedTokenizer, teacher_tokenizer: transformers.PreTrainedTokenizer, data_args: DataArguments, ): super(LazySupervisedMixDataset, self).__init__() self.data_args = data_args list_data_dict = [] ###################################### text to image ####################################### ''' data_files = glob.glob(os.path.join(self.data_args.image_folder, "*.tar")) #data_files = glob.glob(os.path.join('/proj/cvl/users/x_fahkh2/BLIP3o/dataset/BLIP3o-Pretrain-Long-Caption', "*.tar")) + glob.glob(os.path.join('/proj/cvl/users/x_fahkh2/BLIP3o/dataset/BLIP3o-Pretrain-Short-Caption', "*.tar")) + glob.glob(os.path.join('/proj/cvl/users/x_fahkh2/BLIP3o/dataset/BLIP3o-Pretrain-JourneyDB', "*.tar")) train_dataset = load_dataset("webdataset", data_files=data_files, split="train", num_proc=32) train_dataset = train_dataset.rename_column("jpg", "image") train_dataset = train_dataset.add_column('type', len(train_dataset) * ['T2I']) train_dataset = train_dataset.add_column('image_path', len(train_dataset) * [None]) train_dataset = train_dataset.remove_columns([col for col in train_dataset.column_names if not col in ( ["image", "txt", "type", "image_path"])]) print(f"finish loading image {len(train_dataset)}") list_data_dict.append(train_dataset) ''' # sharegpt_files = sorted(glob.glob("/proj/cvl/users/x_fahkh2/BLIP3o/dataset/ShareGPT-4o-Image/flat/*.tar")) # blip3o_files = sorted(glob.glob("/proj/cvl/users/x_fahkh2/BLIP3o/dataset/BLIP3o-60k/*.tar")) # datasets_to_combine = [] # sharegpt_dataset = load_dataset("webdataset", data_files=sharegpt_files, split="train", num_proc=16) # blip3o_dataset = load_dataset("webdataset", data_files=blip3o_files, split="train", num_proc=16) # for possible_col in ["jpg", "jpeg", "png"]: # if possible_col in sharegpt_dataset.column_names: # sharegpt_dataset = sharegpt_dataset.rename_column(possible_col, "image") # break # datasets_to_combine.append(sharegpt_dataset) # for possible_col in ["jpg", "jpeg", "png"]: # if possible_col in blip3o_dataset.column_names: # blip3o_dataset = blip3o_dataset.rename_column(possible_col, "image") # break # datasets_to_combine.append(blip3o_dataset) # train_dataset = concatenate_datasets(datasets_to_combine) # train_dataset = train_dataset.add_column('type', len(train_dataset) * ['T2I']) # train_dataset = train_dataset.add_column('image_path', len(train_dataset) * [None]) # train_dataset = train_dataset.remove_columns([col for col in train_dataset.column_names if not col in (["image", "txt", "type", "image_path"])]) # print(f"finish loading image {len(train_dataset)}") # list_data_dict.append(train_dataset) data_files = glob.glob(os.path.join(self.data_args.image_folder, "*.tar")) # data_files = glob.glob(os.path.join('/proj/cvl/users/x_fahkh2/BLIP3o/dataset/BLIP3o-Pretrain-Long-Caption', "*.tar")) + glob.glob(os.path.join('/proj/cvl/users/x_fahkh2/BLIP3o/dataset/BLIP3o-Pretrain-Short-Caption', "*.tar")) + glob.glob(os.path.join('/proj/cvl/users/x_fahkh2/BLIP3o/dataset/BLIP3o-Pretrain-JourneyDB', "*.tar")) train_dataset = load_dataset("webdataset", data_files=data_files, split="train", num_proc=32) train_dataset = train_dataset.rename_column("jpg", "image") train_dataset = train_dataset.add_column('type', len(train_dataset) * ['T2I']) train_dataset = train_dataset.add_column('image_path', len(train_dataset) * [None]) train_dataset = train_dataset.remove_columns([col for col in train_dataset.column_names if not col in ( ["image", "txt", "type", "image_path"])]) print(f"finish loading image {len(train_dataset)}") list_data_dict.append(train_dataset) if len(list_data_dict) > 1: list_data_dict = concatenate_datasets(list_data_dict) else: list_data_dict = list_data_dict[0] list_data_dict = list_data_dict.shuffle(seed=42) rank0_print(f"Total number of training instance: {len(list_data_dict)}") #self.tokenizer = tokenizer #self.list_data_dict = list_data_dict list_data_dict = list_data_dict.filter(lambda x: x["image"] is not None, num_proc=8) self.tokenizer = tokenizer self.teacher_tokenizer = teacher_tokenizer self.list_data_dict = list_data_dict def __len__(self): return len(self.list_data_dict) @property def lengths(self): length_list = [] for sample in self.list_data_dict: img_tokens = 128 if "image" in sample else 0 length_list.append(sum(len(conv["value"].split()) for conv in sample["conversations"]) + img_tokens) return length_list @property def modality_lengths(self): length_list = [] for sample in self.list_data_dict: cur_len = sum(len(conv["value"].split()) for conv in sample["conversations"]) cur_len = cur_len if "image" in sample else -cur_len length_list.append(cur_len) return length_list def _safe_img_process(self, imgs): try: out = [] for img in imgs: ori_h, ori_w = img.height, img.width closest_size, closest_ratio = get_closest_ratio(ori_h, ori_w, ASPECT_RATIO_512) closest_size = [int(x) for x in closest_size] if closest_size[0] / ori_h > closest_size[1] / ori_w: resize_size = closest_size[0], int(ori_w * closest_size[0] / ori_h) else: resize_size = int(ori_h * closest_size[1] / ori_w), closest_size[1] transform = T.Compose([ T.Lambda(lambda img: img.convert("RGB")), T.Resize(resize_size, interpolation=InterpolationMode.BICUBIC), # Image.BICUBIC T.CenterCrop(closest_size), T.ToTensor(), T.Normalize([0.5], [0.5]), ]) out.append(transform(img)) return out except Exception as e: print(f"Corrupted image during processing: {e}") return None def __getitem__(self, i) -> Dict[str, torch.Tensor]: while True: try: sources = self.list_data_dict[i] teacher_prompt = sources['txt'].lower().strip() teacher_prompt = f"Please generate image based on the following caption: {teacher_prompt}" sources["conversations"] = [ {"from": "human", "value": f"Please generate image based on the following caption: {sources['txt']}"}, {"from": "gpt", "value": ""}, ] image_files = self.list_data_dict[i]["image"] if not isinstance(image_files, list): image_files = [image_files] is_corrupt = False images = [] for img in image_files: img = img.convert("RGB") images.append(img) processed_images = self._safe_img_process(images) if processed_images is None: print("Corrupted image during transform, picking new sample.") i = random.randint(0, len(self.list_data_dict) - 1) continue # just replace with "" in generation tasks sources, inst_type = preprocess_multimodal(copy.deepcopy([sources["conversations"]]), self.data_args) data_dict = preprocess_qwen(sources, self.tokenizer, has_image=("image" in self.list_data_dict[i])) if isinstance(i, int): data_dict = dict(input_ids=data_dict["input_ids"][0], labels=data_dict["labels"][0]) data_dict["gen_image"] = processed_images[0] data_dict["ids"] = self.list_data_dict[i]["id"] if "id" in self.list_data_dict[i] else "unk" # input to teacher tokenizer is sources['txt'] teacher_inputs = self.teacher_tokenizer( teacher_prompt, padding="max_length", max_length=300, truncation=True, add_special_tokens=True, return_tensors="pt", ) data_dict['teacher_token_ids'] = teacher_inputs.input_ids data_dict['teacher_attention_mask'] = teacher_inputs.attention_mask data_dict['teacher_prompt'] = teacher_prompt return data_dict except Exception as e: print(f"[WARN] Skipping corrupted sample {i}: {e}") i = random.randint(0, len(self.list_data_dict) - 1) continue @dataclass class DataCollatorForSupervisedDataset(object): """Collate examples for supervised fine-tuning.""" tokenizer: transformers.PreTrainedTokenizer def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]: input_ids, labels, ids = tuple([instance[key] for instance in instances] for key in ("input_ids", "labels", "ids")) multi_input_ids = [] multi_labels = [] i_s_pos = [] for input_id, label in zip(input_ids, labels): input_id = input_id[: self.tokenizer.model_max_length - 9] label = label[: self.tokenizer.model_max_length - 9] i_s_pos.append(input_id.shape[0]+1) img_id = torch.full((9,), IMAGE_TOKEN_IDX, dtype=input_id.dtype, device=input_id.device) img_id[0] = DEFAULT_IM_START_TOKEN_IDX input_id = torch.cat([input_id, img_id]) img_label = torch.full((9,), IMAGE_TOKEN_IDX, dtype=label.dtype, device=label.device) img_label[0] = DEFAULT_IM_START_TOKEN_IDX label = torch.cat([label, img_label]) multi_input_ids.append(input_id) multi_labels.append(label) input_ids = multi_input_ids labels = multi_labels input_ids = torch.nn.utils.rnn.pad_sequence(input_ids, batch_first=True, padding_value=self.tokenizer.pad_token_id) labels = torch.nn.utils.rnn.pad_sequence(labels, batch_first=True, padding_value=IGNORE_INDEX) if input_ids.shape[1] > self.tokenizer.model_max_length: print(f"Warning input with length {input_ids.shape[1]} is longer than max length {self.tokenizer.model_max_length}") input_ids = input_ids[:, : self.tokenizer.model_max_length] labels = labels[:, : self.tokenizer.model_max_length] batch = dict( input_ids=input_ids, labels=labels, attention_mask=input_ids.ne(self.tokenizer.pad_token_id), teacher_prompts = [instance['teacher_prompt'] for instance in instances], teacher_input_ids=torch.cat([instance['teacher_token_ids'] for instance in instances], dim=0), teacher_attention_mask=torch.cat([instance['teacher_attention_mask'] for instance in instances], dim=0) ) batch_gen_images = [] batch_und_images = [] batch_grid_thw = [] for instance in instances: if "gen_image" in instance: batch_gen_images.append(instance["gen_image"]) if len(batch_gen_images) > 0: if all(x is not None and y.shape == batch_gen_images[0][0].shape for x in batch_gen_images for y in x): batch["gen_image"] = torch.cat([images.unsqueeze(0) for images in batch_gen_images], dim=0) else: batch["gen_image"] = batch_gen_images else: batch["gen_image"] = None for instance in instances: if "und_image" in instance: batch_und_images.append(instance["und_image"].unsqueeze(0)) ## 1*1024*1176 batch_grid_thw.append(instance["grid_thw"]) ## 1*3 # print(f"batch_und_images {batch_und_images}") if len(batch_und_images) > 0: batch["und_image"] = torch.cat([images for images in batch_und_images], dim=0) batch["grid_thw"] = torch.cat([images for images in batch_grid_thw], dim=0) else: batch["und_image"] = None batch["grid_thw"] = None batch["ids"] = ids batch["i_s_pos"] = i_s_pos return batch def make_supervised_data_module(tokenizer: transformers.PreTrainedTokenizer, teacher_tokenizer: transformers.PreTrainedTokenizer, data_args) -> Dict: train_dataset = LazySupervisedMixDataset(tokenizer=tokenizer, teacher_tokenizer=teacher_tokenizer, data_path=data_args.data_path, data_args=data_args) data_collator = DataCollatorForSupervisedDataset(tokenizer=tokenizer) return dict(train_dataset=train_dataset, eval_dataset=None, data_collator=data_collator) def train(attn_implementation=None): global local_rank parser = transformers.HfArgumentParser((ModelArguments, DataArguments, TrainingArguments)) model_args, data_args, training_args = parser.parse_args_into_dataclasses() # print(model_args, data_args, training_args) local_rank = training_args.local_rank compute_dtype = torch.float16 if training_args.fp16 else (torch.bfloat16 if training_args.bf16 else torch.float32) bnb_model_from_pretrained_args = {} if training_args.bits in [4, 8]: from transformers import BitsAndBytesConfig bnb_model_from_pretrained_args.update( dict( device_map={"": training_args.device}, load_in_4bit=training_args.bits == 4, load_in_8bit=training_args.bits == 8, quantization_config=BitsAndBytesConfig( load_in_4bit=training_args.bits == 4, load_in_8bit=training_args.bits == 8, llm_int8_skip_modules=["mm_projector"], llm_int8_threshold=6.0, llm_int8_has_fp16_weight=False, bnb_4bit_compute_dtype=compute_dtype, bnb_4bit_use_double_quant=training_args.double_quant, bnb_4bit_quant_type=training_args.quant_type, # {'fp4', 'nf4'} ), ) ) model = blip3oFastForCausalLM.from_pretrained( model_args.model_name_or_path, cache_dir=training_args.cache_dir, # attn_implementation=attn_implementation, torch_dtype=(torch.bfloat16 if training_args.bf16 else None), **bnb_model_from_pretrained_args, ) model.config.use_cache = False if model_args.freeze_backbone: for (n, p) in model.get_model().named_parameters(): p.requires_grad = False for (n, p) in model.get_vision_tower().named_parameters(): p.requires_grad = False for (n, p) in model.lm_head.named_parameters(): p.requires_grad = False #for (n, p) in model.get_model().named_parameters(): # p.requires_grad = True #for (n, p) in model.get_vision_tower().named_parameters(): # p.requires_grad = False #for (n, p) in model.get_model().embed_tokens.named_parameters(): # p.requires_grad=True if training_args.gradient_checkpointing: if hasattr(model, "enable_input_require_grads"): model.enable_input_require_grads() else: def make_inputs_require_grad(module, input, output): output.requires_grad_(True) model.get_input_embeddings().register_forward_hook(make_inputs_require_grad) try: tokenizer = AutoProcessor.from_pretrained(model_args.model_name_or_path).tokenizer except Exception as e: tokenizer = AutoProcessor.from_pretrained(model_args.model_name_or_path) teacher_tokenizer = AutoTokenizer.from_pretrained(model_args.teacher_model_name, subfolder="tokenizer") tokenizer.model_max_length = training_args.model_max_length teacher_tokenizer.padding_side = "right" # tokenizer.pad_token = tokenizer.unk_token if tokenizer.pad_token is None: smart_tokenizer_and_embedding_resize( special_tokens_dict=dict( pad_token="", additional_special_tokens=["[IMG]", "[/IMG]", ""], ), tokenizer=tokenizer, model=model, ) elif not "" in tokenizer.get_added_vocab(): smart_tokenizer_and_embedding_resize( special_tokens_dict=dict(additional_special_tokens=["[IMG]", "[/IMG]", ""]), tokenizer=tokenizer, model=model, ) if model_args.version in conversation_lib.conv_templates: conversation_lib.default_conversation = conversation_lib.conv_templates[model_args.version] else: conversation_lib.default_conversation = conversation_lib.conv_templates["llama3"] rank0_print(f"Using conversation format: {conversation_lib.default_conversation.version}") # if model_args.vision_tower is not None: model.get_model().initialize_vision_modules(model_args=model_args, fsdp=training_args.fsdp) image_processor = model.get_model().get_vision_tower().image_processor data_args.gen_image_processor = image_processor data_args.image_processor = image_processor data_args.is_multimodal = True data_args.n_query = model_args.n_query data_args.n_und_query = model_args.n_und_query model.config.image_aspect_ratio = data_args.image_aspect_ratio model.config.tokenizer_padding_side = tokenizer.padding_side model.config.tokenizer_model_max_length = tokenizer.model_max_length model.config.tune_mm_mlp_adapter = training_args.tune_mm_mlp_adapter = model_args.tune_mm_mlp_adapter model.config.freeze_mm_mlp_adapter = training_args.freeze_mm_mlp_adapter # Calculate total parameters and trainable parameters total_params = sum(p.numel() for p in model.get_model().parameters()) trainable_params = sum(p.numel() for p in model.get_model().parameters() if p.requires_grad) print(f"Total parameters: {total_params}") print(f"Trainable parameters: {trainable_params}") model.config.mm_use_im_start_end = data_args.mm_use_im_start_end = model_args.mm_use_im_start_end model.config.mm_projector_lr = training_args.mm_projector_lr training_args.use_im_start_end = model_args.mm_use_im_start_end model.config.mm_use_im_patch_token = model_args.mm_use_im_patch_token # TODO: what is this? model.initialize_vision_tokenizer(model_args, tokenizer=tokenizer) model.config.pad_token_id = tokenizer.pad_token_id data_module = make_supervised_data_module(tokenizer=tokenizer, teacher_tokenizer=teacher_tokenizer, data_args=data_args) class WandbLossLogger(TrainerCallback): def on_step_end(self, args, state, control, **kwargs): if not state.is_world_process_zero: return model = kwargs["model"] logs = {} if hasattr(model, "_last_diff_loss"): logs["diff_loss"] = float(model._last_diff_loss) if hasattr(model, "_last_kd_loss"): logs["kd_loss"] = float(model._last_kd_loss) if hasattr(model, "_last_pred_divergence"): logs["pred_divergence"] = float(model._last_pred_divergence) if hasattr(model, "kd_weight"): logs["kd_weight"] = float(model.kd_weight) if logs: wandb.log(logs, step=state.global_step) class ProgressiveKDCallback(TrainerCallback): def on_step_begin(self, args, state, control, **kwargs): model = kwargs["model"] current_step = state.global_step total_steps = args.max_steps if args.max_steps > 0 else state.max_steps # Update KD weight progressively if hasattr(model, 'kd_weight'): model.kd_weight = max(0.1, 10 * (1 - current_step / total_steps)) trainer = blip3oTrainer( model=model, tokenizer=tokenizer, args=training_args, callbacks=[WandbLossLogger, ProgressiveKDCallback], **data_module, ) from tabulate import tabulate # if trainer.is_world_process_zero(): # stat = [] # for i, (n, p) in enumerate(trainer.model.named_parameters()): # stat.append([i, n, p.shape, p.requires_grad]) # print(tabulate(stat, headers=["idx", "name", "shape", "trainable"])) ''' from safetensors.torch import load_file import json import pathlib # ---- Load model.safetensors if it exists ---- checkpoint_dir = pathlib.Path(training_args.output_dir) safetensor_path = checkpoint_dir / "model.safetensors" trainer_state_path = checkpoint_dir / "trainer_state.json" if safetensor_path.exists(): print(f"Loading weights from {safetensor_path}") state_dict = load_file(safetensor_path) new_state_dict = {} for k, v in state_dict.items(): new_key = k.replace("model.", "", 1) if k.startswith("model.") else k new_state_dict[new_key] = v # print all keys #print("🔑 Keys in checkpoint:") #for k in state_dict.keys(): # print(k, state_dict[k].shape) missing, unexpected = model.get_model().load_state_dict(new_state_dict, strict=False) print("✅ Loaded parameters:") for k in new_state_dict.keys(): if k not in missing: print(f" {k} {tuple(new_state_dict[k].shape)}") # Restore last global step if trainer_state_path.exists(): with open(trainer_state_path, "r") as f: trainer_state = json.load(f) last_global_step = trainer_state.get("global_step", 0) last_lr = trainer_state.get("learning_rate", trainer.args.learning_rate) trainer.state.global_step = last_global_step # Reset optimizer with last learning rate trainer.create_optimizer_and_scheduler(num_training_steps=trainer.args.max_steps) optimizer = trainer.optimizer #lr_scheduler = trainer.lr_scheduler for param_group in optimizer.param_groups: param_group['lr'] = last_lr trainer.optimizer = optimizer print(f"✅ Restored global step: {last_global_step}, learning rate: {last_lr}") ''' if list(pathlib.Path(training_args.output_dir).glob("checkpoint-*")): trainer.train(resume_from_checkpoint=True) else: trainer.train() trainer.save_state() model.config.use_cache = True safe_save_model_for_hf_trainer( trainer=trainer, output_dir=training_args.output_dir, vision_tower=model_args.vision_tower, ) if __name__ == "__main__": train()