import os
import copy
from dataclasses import dataclass, field
import logging
import pathlib
from typing import Dict, List, Optional, Sequence, Tuple

import torch
import glob
import transformers
import tokenizers

from blip3o.constants import IGNORE_INDEX, DEFAULT_IMAGE_TOKEN, IMAGE_TOKEN_IDX, DEFAULT_IM_START_TOKEN_IDX
from torch.utils.data import Dataset
from blip3o.train.blip3o_trainer import blip3oTrainer

from blip3o import conversation as conversation_lib
from blip3o.model import *
from blip3o.mm_utils import tokenizer_image_token
from PIL import Image, ImageFile
from datasets import load_dataset, concatenate_datasets
from pathlib import Path
from datasets.utils.logging import set_verbosity_info
from transformers import logging as tf_logging
import torchvision.transforms as T
from torchvision.transforms.functional import InterpolationMode
from transformers import AutoProcessor
import random
from blip3o.model.multimodal_encoder.eva_clip.eva_clip_processors import EvaClipImageTrainProcessor
# Kept after the star import above so name-resolution order matches the original script.
from transformers import TrainerCallback
from packaging import version

# Let PIL load truncated image files instead of raising on them.
ImageFile.LOAD_TRUNCATED_IMAGES = True

# Resize + center-crop to 448x448; not referenced elsewhere in this script.
transform_und_images = T.Compose(
    [T.Resize(448, interpolation=InterpolationMode.BICUBIC, antialias=True), T.CenterCrop(448)]
)

set_verbosity_info()
tf_logging.set_verbosity_info()

# Set from ``TrainingArguments.local_rank`` inside ``train()``; None until then.
local_rank = None


class GradCheckCallback(TrainerCallback):
    """Debug callback: after each step, print gradient stats for selected parameters.

    Only parameters whose name contains ``caption_embed`` or ``diffusion_connector``
    are reported. (Currently not registered with the trainer in ``train()``.)
    """

    def on_step_end(self, args, state, control, **kwargs):
        model = kwargs["model"]
        for name, param in model.named_parameters():
            if "caption_embed" in name or "diffusion_connector" in name:
                if param.grad is None:
                    print(f"{name} has NO gradient!")
                else:
                    print(f"{name} grad mean: {param.grad.abs().mean().item():.6f}")


def rank0_print(*args):
    """Print only on the main process.

    ``local_rank`` is 0 on the first process of a distributed run and -1 when
    running without distribution; the same convention is used when saving weights
    in ``safe_save_model_for_hf_trainer``.
    """
    if local_rank in (0, -1):
        print(*args)


@dataclass
class ModelArguments:
    """Model-construction options parsed from the command line."""

    model_name_or_path: Optional[str] = field(default="facebook/opt-125m")
    version: Optional[str] = field(default="v0")
    freeze_backbone: bool = field(default=True)
    tune_mm_mlp_adapter: bool = field(default=False)
    vision_tower: Optional[str] = field(default=None)
    gen_vision_tower: Optional[str] = field(default=None)
    mm_vision_select_layer: Optional[int] = field(default=-1)  # default to the last layer
    pretrain_mm_mlp_adapter: Optional[str] = field(default=None)
    pretrain_gen_mlp_adapter: Optional[str] = field(default=None)
    vision_tower_pretrained: Optional[str] = field(default=None)
    mm_projector_type: Optional[str] = field(default="linear")
    gen_projector_type: Optional[str] = field(default="linear")
    mm_use_im_start_end: bool = field(default=False)
    mm_use_im_patch_token: bool = field(default=True)
    mm_patch_merge_type: Optional[str] = field(default="flat")
    mm_vision_select_feature: Optional[str] = field(default="patch")
    n_query: Optional[int] = field(default=729)  # clip 576, siglip 729
    n_und_query: Optional[int] = field(default=729)  # clip 576, siglip 729
    gen_pooling: Optional[str] = field(default="all")  # options are: pool2d_3, pool2d_9, seq_3, seq_9, seq_27
    diffusion_name_or_path: Optional[str] = field(default="Efficient-Large-Model/Sana_600M_1024px_diffusers")


@dataclass
class DataArguments:
    """Dataset options parsed from the command line.

    ``train()`` additionally attaches ``image_processor``, ``gen_image_processor``,
    ``is_multimodal``, ``n_query``, ``n_und_query`` and ``mm_use_im_start_end``.
    """

    data_path: str = field(default=None, metadata={"help": "Path to the training data."})
    lazy_preprocess: bool = False
    is_multimodal: bool = False
    image_folder: Optional[str] = field(default=None)
    journeyDB_folder: Optional[str] = field(default=None)
    shortcaption_image_folder: Optional[str] = field(default=None)
    data_type: Optional[str] = field(default="mix")
    image_aspect_ratio: str = "square"


@dataclass
class TrainingArguments(transformers.TrainingArguments):
    """HF ``TrainingArguments`` extended with quantization/LoRA/projector options."""

    cache_dir: Optional[str] = field(default=None)
    optim: str = field(default="adamw_torch")
    remove_unused_columns: bool = field(default=False)
    freeze_mm_mlp_adapter: bool = field(default=False)
    mpt_attn_impl: Optional[str] = field(default="triton")
    model_max_length: int = field(
        default=512,
        metadata={"help": "Maximum sequence length. Sequences will be right padded (and possibly truncated)."},
    )
    double_quant: bool = field(
        default=True,
        metadata={"help": "Compress the quantization statistics through double quantization."},
    )
    quant_type: str = field(
        default="nf4",
        metadata={"help": "Quantization data type to use. Should be one of `fp4` or `nf4`."},
    )
    bits: int = field(default=16, metadata={"help": "How many bits to use."})
    lora_enable: bool = False
    lora_r: int = 64
    lora_alpha: int = 16
    lora_dropout: float = 0.05
    lora_weight_path: str = ""
    lora_bias: str = "none"
    mm_projector_lr: Optional[float] = None
    group_by_modality_length: bool = field(default=False)
    ddp_find_unused_parameters: bool = True


# Aspect ratio (height / width, as a string key) -> [height, width] target size.
# Note the "1.0" entry is 1024x1024 rather than 512x512.
ASPECT_RATIO_512 = {
    "0.25": [256.0, 1024.0],
    "0.26": [256.0, 992.0],
    "0.27": [256.0, 960.0],
    "0.28": [256.0, 928.0],
    "0.32": [288.0, 896.0],
    "0.33": [288.0, 864.0],
    "0.35": [288.0, 832.0],
    "0.4": [320.0, 800.0],
    "0.42": [320.0, 768.0],
    "0.48": [352.0, 736.0],
    "0.5": [352.0, 704.0],
    "0.52": [352.0, 672.0],
    "0.57": [384.0, 672.0],
    "0.6": [384.0, 640.0],
    "0.68": [416.0, 608.0],
    "0.72": [416.0, 576.0],
    "0.78": [448.0, 576.0],
    "0.82": [448.0, 544.0],
    "0.88": [480.0, 544.0],
    "0.94": [480.0, 512.0],
    "1.0": [1024.0, 1024.0],
    "1.07": [512.0, 480.0],
    "1.13": [544.0, 480.0],
    "1.21": [544.0, 448.0],
    "1.29": [576.0, 448.0],
    "1.38": [576.0, 416.0],
    "1.46": [608.0, 416.0],
    "1.67": [640.0, 384.0],
    "1.75": [672.0, 384.0],
    "2.0": [704.0, 352.0],
    "2.09": [736.0, 352.0],
    "2.4": [768.0, 320.0],
    "2.5": [800.0, 320.0],
    "2.89": [832.0, 288.0],
    "3.0": [864.0, 288.0],
    "3.11": [896.0, 288.0],
    "3.62": [928.0, 256.0],
    "3.75": [960.0, 256.0],
    "3.88": [992.0, 256.0],
    "4.0": [1024.0, 256.0],
}
print("Input size: ", ASPECT_RATIO_512["1.0"])


def maybe_zero_3(param, ignore_status=False, name=None):
    """Return a detached CPU copy of ``param``, gathering it first if DeepSpeed partitioned it.

    Parameters carrying a ``ds_id`` attribute are treated as DeepSpeed ZeRO-3
    parameters and materialised with ``zero.GatheredParameters`` before copying.
    """
    from deepspeed import zero
    from deepspeed.runtime.zero.partition_parameters import ZeroParamStatus

    if hasattr(param, "ds_id"):
        if param.ds_status == ZeroParamStatus.NOT_AVAILABLE:
            if not ignore_status:
                logging.warning(f"{name}: param.ds_status != ZeroParamStatus.NOT_AVAILABLE: {param.ds_status}")
        with zero.GatheredParameters([param]):
            param = param.data.detach().cpu().clone()
    else:
        param = param.detach().cpu().clone()
    return param


def get_mm_adapter_state_maybe_zero_3(named_params, keys_to_match):
    """Collect CPU copies of every named parameter whose name contains any of ``keys_to_match``."""
    to_return = {k: t for k, t in named_params if any(key_match in k for key_match in keys_to_match)}
    to_return = {k: maybe_zero_3(v, ignore_status=True).cpu() for k, v in to_return.items()}
    return to_return


def _save_adapter_weights(trainer: transformers.Trainer, output_dir: str, module_key: str) -> None:
    """Save parameters whose name contains ``module_key`` to ``<module_key>.bin``.

    For an ``output_dir`` named ``checkpoint-*`` the file goes to
    ``<parent>/<module_key>/<checkpoint-name>.bin``; otherwise it goes to
    ``<output_dir>/<module_key>.bin``. Parameters are gathered on every rank
    (``maybe_zero_3`` may need all ranks under ZeRO-3) but only rank 0 / -1 writes.
    """
    keys_to_match = [module_key]
    if getattr(trainer.args, "use_im_start_end", False):
        keys_to_match.extend(["embed_tokens", "embed_in"])
    weight_to_save = get_mm_adapter_state_maybe_zero_3(trainer.model.named_parameters(), keys_to_match)

    # normpath so a trailing "/" on output_dir does not yield an empty folder name.
    normalized_dir = os.path.normpath(output_dir)
    current_folder = os.path.basename(normalized_dir)
    parent_folder = os.path.dirname(normalized_dir)
    if trainer.args.local_rank in (0, -1):
        if current_folder.startswith("checkpoint-"):
            adapter_folder = os.path.join(parent_folder, module_key)
            os.makedirs(adapter_folder, exist_ok=True)
            torch.save(weight_to_save, os.path.join(adapter_folder, f"{current_folder}.bin"))
        else:
            torch.save(weight_to_save, os.path.join(output_dir, f"{module_key}.bin"))


def safe_save_model_for_hf_trainer(trainer: transformers.Trainer, output_dir: str, vision_tower: str):
    """Save projector weights separately, then save the full model.

    Args:
        trainer: the trainer whose model is saved.
        output_dir: destination directory.
        vision_tower: unused; kept for call-site compatibility.
    """
    if trainer.deepspeed:
        torch.cuda.synchronize()

    trainer.model.config.save_pretrained(output_dir)
    # Understanding-side and generation-side projectors each get their own file.
    for module_key in ("mm_projector", "gen_projector"):
        _save_adapter_weights(trainer, output_dir, module_key)

    if trainer.deepspeed:
        torch.cuda.synchronize()
        trainer.save_model(output_dir)
        return

    state_dict = trainer.model.state_dict()
    if trainer.args.should_save:
        cpu_state_dict = {key: value.cpu() for key, value in state_dict.items()}
        del state_dict
        trainer._save(output_dir, state_dict=cpu_state_dict)  # noqa


def smart_tokenizer_and_embedding_resize(
    special_tokens_dict: Dict,
    tokenizer: transformers.PreTrainedTokenizer,
    model: transformers.PreTrainedModel,
):
    """Add special tokens, resize the embeddings, and initialise the new input rows.

    New input-embedding rows are set to the mean of the pre-existing rows. Output
    embeddings are only resized, not re-initialised.
    """
    num_new_tokens = tokenizer.add_special_tokens(special_tokens_dict)
    model.resize_token_embeddings(len(tokenizer))

    if num_new_tokens > 0:
        input_embeddings = model.get_input_embeddings().weight.data
        input_embeddings_avg = input_embeddings[:-num_new_tokens].mean(dim=0, keepdim=True)
        input_embeddings[-num_new_tokens:] = input_embeddings_avg


def preprocess_multimodal(sources: Sequence, data_args: DataArguments) -> Tuple[Sequence, Optional[str]]:
    """Replace ``DEFAULT_IMAGE_TOKEN`` in each conversation with model-specific placeholders.

    Human turns get the Qwen-style ``<|vision_start|><|image_pad|>*n<|vision_end|>``
    understanding placeholder; assistant ("gpt") turns get an empty generation
    placeholder. ``sources`` is modified in place.

    Returns:
        ``(sources, inst_type)`` where ``inst_type`` is ``"und"``/``"gen"`` for the
        last turn that contained the image token, or ``None`` if none did (or if
        ``data_args.is_multimodal`` is false).
    """
    if not data_args.is_multimodal:
        # Callers always unpack two values.
        return sources, None

    und_placeholder = "<|vision_start|>" + "<|image_pad|>" * data_args.n_und_query + "<|vision_end|>"
    gen_placeholder = ""
    inst_type = None
    for source in sources:  # [instance]
        for sentence in source:
            if sentence["from"] == "human" and DEFAULT_IMAGE_TOKEN in sentence["value"]:
                sentence["value"] = sentence["value"].replace(DEFAULT_IMAGE_TOKEN, und_placeholder).strip()
                inst_type = "und"
            elif sentence["from"] == "gpt" and DEFAULT_IMAGE_TOKEN in sentence["value"]:
                sentence["value"] = sentence["value"].replace(DEFAULT_IMAGE_TOKEN, gen_placeholder).strip()
                inst_type = "gen"
    return sources, inst_type


# ChatML-style template applied to every turn in ``preprocess_qwen``.
QWEN_CHAT_TEMPLATE = "{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}"


def preprocess_qwen(sources, tokenizer: transformers.PreTrainedTokenizer, has_image: bool = False, max_len=2048, system_message: str = "You are a helpful assistant.") -> Dict:
    """Tokenize conversations with a ChatML template and build supervision labels.

    Each conversation is prefixed with ``system_message``. System/user turns are
    masked with ``IGNORE_INDEX``; assistant turns are supervised on their full
    templated token ids. ``has_image`` and ``max_len`` are accepted but unused.

    The template is passed per call rather than assigned to the tokenizer, so the
    shared tokenizer is never mutated and does not need to be deep-copied.

    Returns:
        ``dict(input_ids=..., labels=...)`` as ``torch.long`` tensors with one row
        per conversation. Conversations must tokenize to equal lengths for
        ``torch.tensor`` to succeed; this file only passes a single conversation.
    """
    roles = {"human": "user", "gpt": "assistant"}

    input_ids, targets = [], []
    for source in sources:
        # Drop a leading non-user turn so each conversation starts with the user.
        if roles[source[0]["from"]] != roles["human"]:
            source = source[1:]

        input_id, target = [], []
        input_id += tokenizer.apply_chat_template(
            [{"role": "system", "content": system_message}], chat_template=QWEN_CHAT_TEMPLATE
        )
        target += [IGNORE_INDEX] * len(input_id)

        for conv in source:
            # Accept both {"role", "content"} and LLaVA-style {"from", "value"} turns.
            try:
                role = conv["role"]
                content = conv["content"]
            except KeyError:
                role = conv["from"]
                content = conv["value"]

            role = roles.get(role, role)
            encode_id = tokenizer.apply_chat_template(
                [{"role": role, "content": content}], chat_template=QWEN_CHAT_TEMPLATE
            )
            input_id += encode_id
            if role in ["user", "system"]:
                target += [IGNORE_INDEX] * len(encode_id)
            else:
                target += encode_id

        assert len(input_id) == len(target), f"{len(input_id)} != {len(target)}"
        input_ids.append(input_id)
        targets.append(target)

    input_ids = torch.tensor(input_ids, dtype=torch.long)
    targets = torch.tensor(targets, dtype=torch.long)
    return dict(
        input_ids=input_ids,  # tensor(bs x seq_len)
        labels=targets,  # tensor(bs x seq_len)
    )


def get_closest_ratio(height: float, width: float, ratios: dict, force_ratio: Optional[str] = "1.0"):
    """Pick a target ``[height, width]`` from ``ratios``.

    Args:
        height, width: source image size.
        ratios: mapping of aspect-ratio string -> ``[height, width]``.
        force_ratio: key to use unconditionally (default ``"1.0"``, i.e. always
            square, matching the original behaviour). Pass ``None`` to choose the
            key whose ratio is closest to ``height / width``.

    Returns:
        ``(ratios[key], float(key))``.
    """
    if force_ratio is not None:
        closest_ratio = force_ratio
    else:
        aspect_ratio = height / width
        closest_ratio = min(ratios.keys(), key=lambda ratio: abs(float(ratio) - aspect_ratio))
    return ratios[closest_ratio], float(closest_ratio)


class LazySupervisedMixDataset(Dataset):
    """Text-to-image dataset built from the webdataset ``.tar`` shards in ``data_args.image_folder``.

    Each item is a caption prompt tokenized via ``preprocess_qwen`` plus the
    target image, transformed to a normalised tensor.
    """

    def __init__(
        self,
        data_path: str,
        tokenizer: transformers.PreTrainedTokenizer,
        data_args: DataArguments,
    ):
        super(LazySupervisedMixDataset, self).__init__()

        self.data_args = data_args
        list_data_dict = []

        ###################################### text to image #######################################
        data_files = glob.glob(os.path.join(self.data_args.image_folder, "*.tar"))
        train_dataset = load_dataset("webdataset", data_files=data_files, split="train", num_proc=32)
        train_dataset = train_dataset.rename_column("jpg", "image")
        train_dataset = train_dataset.add_column("type", len(train_dataset) * ["T2I"])
        train_dataset = train_dataset.add_column("image_path", len(train_dataset) * [None])
        keep_columns = ("image", "txt", "type", "image_path")
        train_dataset = train_dataset.remove_columns(
            [col for col in train_dataset.column_names if col not in keep_columns]
        )
        print(f"finish loading image {len(train_dataset)}")
        list_data_dict.append(train_dataset)

        if len(list_data_dict) > 1:
            list_data_dict = concatenate_datasets(list_data_dict)
        else:
            list_data_dict = list_data_dict[0]
        list_data_dict = list_data_dict.shuffle(seed=42)

        rank0_print(f"Total number of training instance: {len(list_data_dict)}")
        self.tokenizer = tokenizer
        self.list_data_dict = list_data_dict

    def __len__(self):
        return len(self.list_data_dict)

    # NOTE(review): the two properties below read sample["conversations"], but the
    # dataset built in __init__ only keeps the image/txt/type/image_path columns, so
    # they would raise KeyError if used (e.g. with group_by_modality_length) — confirm
    # whether they are still needed.
    @property
    def lengths(self) -> List[int]:
        length_list = []
        for sample in self.list_data_dict:
            img_tokens = 128 if "image" in sample else 0
            length_list.append(sum(len(conv["value"].split()) for conv in sample["conversations"]) + img_tokens)
        return length_list

    @property
    def modality_lengths(self) -> List[int]:
        length_list = []
        for sample in self.list_data_dict:
            cur_len = sum(len(conv["value"].split()) for conv in sample["conversations"])
            cur_len = cur_len if "image" in sample else -cur_len
            length_list.append(cur_len)
        return length_list

    def _safe_img_process(self, imgs):
        """Resize, center-crop, and normalise each PIL image to a tensor in [-1, 1].

        Returns a list of tensors, or ``None`` if any image fails to process.
        """
        try:
            out = []
            for img in imgs:
                ori_h, ori_w = img.height, img.width
                closest_size, _ = get_closest_ratio(ori_h, ori_w, ASPECT_RATIO_512)
                closest_size = [int(x) for x in closest_size]
                # Scale by the larger of the two factors so the resized image covers the
                # target on both axes; CenterCrop then trims the overflow.
                if closest_size[0] / ori_h > closest_size[1] / ori_w:
                    resize_size = closest_size[0], int(ori_w * closest_size[0] / ori_h)
                else:
                    resize_size = int(ori_h * closest_size[1] / ori_w), closest_size[1]

                transform = T.Compose([
                    T.Lambda(lambda img: img.convert("RGB")),
                    T.Resize(resize_size, interpolation=InterpolationMode.BICUBIC),  # Image.BICUBIC
                    T.CenterCrop(closest_size),
                    T.ToTensor(),
                    T.Normalize([0.5], [0.5]),
                ])
                out.append(transform(img))
            return out
        except Exception as e:
            print(f"Corrupted image during processing: {e}")
            return None

    def __getitem__(self, i) -> Dict[str, torch.Tensor]:
        """Return ``input_ids``/``labels``/``gen_image``/``ids`` for sample ``i``.

        On any failure, a random other index is tried instead (best-effort
        skipping of corrupted samples).
        """
        while True:
            try:
                # Each index into the HF dataset builds a fresh row (including the
                # image), so fetch it once and reuse it below.
                row = self.list_data_dict[i]
                row["conversations"] = [
                    {"from": "human", "value": f"Please generate image based on the following caption: {row['txt']}"},
                    {"from": "gpt", "value": ""},
                ]

                image_files = row["image"]
                if not isinstance(image_files, list):
                    image_files = [image_files]
                images = [img.convert("RGB") for img in image_files]

                processed_images = self._safe_img_process(images)
                if processed_images is None:
                    print("Corrupted image during transform, picking new sample.")
                    i = random.randint(0, len(self.list_data_dict) - 1)
                    continue

                sources, _inst_type = preprocess_multimodal(copy.deepcopy([row["conversations"]]), self.data_args)
                data_dict = preprocess_qwen(sources, self.tokenizer, has_image=("image" in row))
                if isinstance(i, int):
                    data_dict = dict(input_ids=data_dict["input_ids"][0], labels=data_dict["labels"][0])

                data_dict["gen_image"] = processed_images[0]
                data_dict["ids"] = row["id"] if "id" in row else "unk"
                return data_dict
            except Exception as e:
                print(f"[WARN] Skipping corrupted sample {i}: {e}")
                i = random.randint(0, len(self.list_data_dict) - 1)
                continue


@dataclass
class DataCollatorForSupervisedDataset(object):
    """Collate examples for supervised fine-tuning."""

    tokenizer: transformers.PreTrainedTokenizer

    def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]:
        """Pad text, stack images, and record the image-start position per sample.

        Text is truncated to ``model_max_length - 17`` tokens. The 17 positions are
        room for one image-start token plus 16 image tokens that are no longer
        appended here; presumably the model inserts them itself — TODO confirm.
        ``i_s_pos`` is the text length + 1 for each sample.
        """
        input_ids, labels, ids = tuple([instance[key] for instance in instances] for key in ("input_ids", "labels", "ids"))

        num_reserved_image_tokens = 17
        max_text_len = self.tokenizer.model_max_length - num_reserved_image_tokens
        input_ids = [input_id[:max_text_len] for input_id in input_ids]
        labels = [label[:max_text_len] for label in labels]
        i_s_pos = [input_id.shape[0] + 1 for input_id in input_ids]

        input_ids = torch.nn.utils.rnn.pad_sequence(input_ids, batch_first=True, padding_value=self.tokenizer.pad_token_id)
        labels = torch.nn.utils.rnn.pad_sequence(labels, batch_first=True, padding_value=IGNORE_INDEX)
        if input_ids.shape[1] > self.tokenizer.model_max_length:
            print(f"Warning input with length {input_ids.shape[1]} is longer than max length {self.tokenizer.model_max_length}")
        input_ids = input_ids[:, : self.tokenizer.model_max_length]
        labels = labels[:, : self.tokenizer.model_max_length]
        attention_mask = input_ids.ne(self.tokenizer.pad_token_id)
        batch = dict(
            input_ids=input_ids,
            labels=labels,
            attention_mask=attention_mask,
        )

        batch_gen_images = [instance["gen_image"] for instance in instances if "gen_image" in instance]
        if batch_gen_images:
            first = batch_gen_images[0]
            # Stack only when every image is present and shares one shape; otherwise
            # hand the list through unchanged.
            if all(x is not None and x.shape == first.shape for x in batch_gen_images):
                batch["gen_image"] = torch.cat([images.unsqueeze(0) for images in batch_gen_images], dim=0)
            else:
                batch["gen_image"] = batch_gen_images
        else:
            batch["gen_image"] = None

        batch_und_images = []
        batch_grid_thw = []
        for instance in instances:
            if "und_image" in instance:
                batch_und_images.append(instance["und_image"].unsqueeze(0))  ## 1*1024*1176
                batch_grid_thw.append(instance["grid_thw"])  ## 1*3
        if batch_und_images:
            batch["und_image"] = torch.cat(batch_und_images, dim=0)
            batch["grid_thw"] = torch.cat(batch_grid_thw, dim=0)
        else:
            batch["und_image"] = None
            batch["grid_thw"] = None

        batch["ids"] = ids
        batch["i_s_pos"] = i_s_pos
        return batch


def make_supervised_data_module(tokenizer: transformers.PreTrainedTokenizer, data_args) -> Dict:
    """Build the train dataset and collator as keyword arguments for the trainer."""
    train_dataset = LazySupervisedMixDataset(tokenizer=tokenizer, data_path=data_args.data_path, data_args=data_args)
    data_collator = DataCollatorForSupervisedDataset(tokenizer=tokenizer)
    return dict(train_dataset=train_dataset, eval_dataset=None, data_collator=data_collator)


def train(attn_implementation=None):
    """Parse CLI arguments, build model/tokenizer/data, train, and save."""
    global local_rank

    parser = transformers.HfArgumentParser((ModelArguments, DataArguments, TrainingArguments))
    model_args, data_args, training_args = parser.parse_args_into_dataclasses()
    print(model_args, data_args, training_args)
    local_rank = training_args.local_rank
    compute_dtype = torch.float16 if training_args.fp16 else (torch.bfloat16 if training_args.bf16 else torch.float32)

    bnb_model_from_pretrained_args = {}
    if training_args.bits in [4, 8]:
        from transformers import BitsAndBytesConfig

        bnb_model_from_pretrained_args.update(
            dict(
                device_map={"": training_args.device},
                load_in_4bit=training_args.bits == 4,
                load_in_8bit=training_args.bits == 8,
                quantization_config=BitsAndBytesConfig(
                    load_in_4bit=training_args.bits == 4,
                    load_in_8bit=training_args.bits == 8,
                    llm_int8_skip_modules=["mm_projector"],
                    llm_int8_threshold=6.0,
                    llm_int8_has_fp16_weight=False,
                    bnb_4bit_compute_dtype=compute_dtype,
                    bnb_4bit_use_double_quant=training_args.double_quant,
                    bnb_4bit_quant_type=training_args.quant_type,  # {'fp4', 'nf4'}
                ),
            )
        )

    model = blip3oFastForCausalLM.from_pretrained(
        model_args.model_name_or_path,
        cache_dir=training_args.cache_dir,
        # attn_implementation=attn_implementation,
        torch_dtype=(torch.bfloat16 if training_args.bf16 else None),
        **bnb_model_from_pretrained_args,
    )
    model.config.use_cache = False

    if model_args.freeze_backbone:
        for module in (model.get_model(), model.get_vision_tower(), model.lm_head):
            for p in module.parameters():
                p.requires_grad = False

    if training_args.gradient_checkpointing:
        if hasattr(model, "enable_input_require_grads"):
            model.enable_input_require_grads()
        else:

            def make_inputs_require_grad(module, input, output):
                output.requires_grad_(True)

            model.get_input_embeddings().register_forward_hook(make_inputs_require_grad)

    try:
        tokenizer = AutoProcessor.from_pretrained(model_args.model_name_or_path).tokenizer
    except Exception:
        # Fall back to whatever AutoProcessor returns directly.
        tokenizer = AutoProcessor.from_pretrained(model_args.model_name_or_path)

    tokenizer.model_max_length = training_args.model_max_length

    # NOTE(review): the "" literals below (pad token, third additional special token,
    # and the added-vocab check) look like token text that was lost; adding an empty
    # string as a special token is unlikely to be intended — confirm before relying on it.
    if tokenizer.pad_token is None:
        smart_tokenizer_and_embedding_resize(
            special_tokens_dict=dict(
                pad_token="",
                additional_special_tokens=["[IMG]", "[/IMG]", ""],
            ),
            tokenizer=tokenizer,
            model=model,
        )
    elif "" not in tokenizer.get_added_vocab():
        smart_tokenizer_and_embedding_resize(
            special_tokens_dict=dict(additional_special_tokens=["[IMG]", "[/IMG]", ""]),
            tokenizer=tokenizer,
            model=model,
        )

    if model_args.version in conversation_lib.conv_templates:
        conversation_lib.default_conversation = conversation_lib.conv_templates[model_args.version]
    else:
        conversation_lib.default_conversation = conversation_lib.conv_templates["llama3"]
    rank0_print(f"Using conversation format: {conversation_lib.default_conversation.version}")

    model.get_model().initialize_vision_modules(model_args=model_args, fsdp=training_args.fsdp)

    image_processor = model.get_model().get_vision_tower().image_processor
    data_args.gen_image_processor = image_processor
    data_args.image_processor = image_processor
    data_args.is_multimodal = True
    data_args.n_query = model_args.n_query
    data_args.n_und_query = model_args.n_und_query

    model.config.image_aspect_ratio = data_args.image_aspect_ratio
    model.config.tokenizer_padding_side = tokenizer.padding_side
    model.config.tokenizer_model_max_length = tokenizer.model_max_length
    model.config.tune_mm_mlp_adapter = training_args.tune_mm_mlp_adapter = model_args.tune_mm_mlp_adapter
    model.config.freeze_mm_mlp_adapter = training_args.freeze_mm_mlp_adapter

    # Calculate total parameters and trainable parameters
    total_params = sum(p.numel() for p in model.get_model().parameters())
    trainable_params = sum(p.numel() for p in model.get_model().parameters() if p.requires_grad)
    print(f"Total parameters: {total_params}")
    print(f"Trainable parameters: {trainable_params}")

    model.config.mm_use_im_start_end = data_args.mm_use_im_start_end = model_args.mm_use_im_start_end
    model.config.mm_projector_lr = training_args.mm_projector_lr
    # Read back by safe_save_model_for_hf_trainer via trainer.args.
    training_args.use_im_start_end = model_args.mm_use_im_start_end
    model.config.mm_use_im_patch_token = model_args.mm_use_im_patch_token
    model.initialize_vision_tokenizer(model_args, tokenizer=tokenizer)

    model.config.pad_token_id = tokenizer.pad_token_id

    data_module = make_supervised_data_module(tokenizer=tokenizer, data_args=data_args)
    trainer = blip3oTrainer(
        model=model,
        tokenizer=tokenizer,
        args=training_args,
        # callbacks=[GradCheckCallback],
        **data_module,
    )

    from tabulate import tabulate

    if trainer.is_world_process_zero():
        stat = [[i, n, p.shape, p.requires_grad] for i, (n, p) in enumerate(trainer.model.named_parameters())]
        print(tabulate(stat, headers=["idx", "name", "shape", "trainable"]))

    # Resume automatically when output_dir already holds checkpoints.
    if any(pathlib.Path(training_args.output_dir).glob("checkpoint-*")):
        trainer.train(resume_from_checkpoint=True)
    else:
        trainer.train()
    trainer.save_state()

    model.config.use_cache = True

    safe_save_model_for_hf_trainer(
        trainer=trainer,
        output_dir=training_args.output_dir,
        vision_tower=model_args.vision_tower,
    )


if __name__ == "__main__":
    train()