| | import os |
| | import copy |
| | from dataclasses import dataclass, field |
| | import logging |
| | import pathlib |
| | from typing import Dict, Optional, Sequence |
| | import torch |
| | import glob |
| | import transformers |
| | from blip3o.constants import IGNORE_INDEX, DEFAULT_IMAGE_TOKEN, IMAGE_TOKEN_IDX, DEFAULT_IM_START_TOKEN_IDX |
| | from torch.utils.data import Dataset |
| | from blip3o.train.blip3o_trainer import blip3oTrainer |
| | from blip3o import conversation as conversation_lib |
| | from blip3o.model import * |
| | from PIL import ImageFile |
| | from datasets import load_dataset, concatenate_datasets |
| | from pathlib import Path |
| | from datasets.utils.logging import set_verbosity_info |
| | from transformers import logging as tf_logging |
| | import torchvision.transforms as T |
| | from torchvision.transforms.functional import InterpolationMode |
| | from transformers import AutoProcessor, AutoTokenizer |
| | import random |
| | from datasets import load_dataset, concatenate_datasets |
| | import wandb |
| |
|
# Tolerate truncated image files instead of raising during PIL decode.
ImageFile.LOAD_TRUNCATED_IMAGES = True
# Square 448x448 resize+center-crop used for the understanding ("und") branch.
transform_und_images = T.Compose([T.Resize(448, interpolation=InterpolationMode.BICUBIC, antialias=True), T.CenterCrop(448)])


# Verbose logging from `datasets` and `transformers` to ease debugging.
set_verbosity_info()
tf_logging.set_verbosity_info()


# Set in train() from TrainingArguments.local_rank; gates rank0_print().
local_rank = None
from transformers import TrainerCallback
| |
|
class GradCheckCallback(TrainerCallback):
    """Debug callback: after every optimizer step, report the gradient status
    of the caption-embed / diffusion-connector parameters so frozen-by-accident
    modules are easy to spot."""

    _WATCHED = ("caption_embed", "diffusion_connector")

    def on_step_end(self, args, state, control, **kwargs):
        for name, param in kwargs["model"].named_parameters():
            if not any(key in name for key in self._WATCHED):
                continue
            if param.grad is None:
                print(f"{name} has NO gradient!")
            else:
                print(f"{name} grad mean: {param.grad.abs().mean().item():.6f}")
| |
|
def rank0_print(*args):
    """Print *args* only on the main (rank-0) process; no-op on other ranks."""
    if local_rank != 0:
        return
    print(*args)
| |
|
| |
|
| | from packaging import version |
| |
|
| |
|
@dataclass
class ModelArguments:
    """CLI arguments selecting the base LLM, vision towers, projectors, and
    the diffusion student/teacher checkpoints."""

    model_name_or_path: Optional[str] = "facebook/opt-125m"
    version: Optional[str] = "v0"  # conversation-template version key
    freeze_backbone: bool = True
    tune_mm_mlp_adapter: bool = False
    vision_tower: Optional[str] = "mobileclip_l_1024"
    gen_vision_tower: Optional[str] = None
    mm_vision_select_layer: Optional[int] = -1  # which vision layer to tap
    pretrain_mm_mlp_adapter: Optional[str] = None
    pretrain_gen_mlp_adapter: Optional[str] = None
    vision_tower_pretrained: Optional[str] = None
    mm_projector_type: Optional[str] = "linear"
    gen_projector_type: Optional[str] = "linear"
    mm_use_im_start_end: bool = False
    mm_use_im_patch_token: bool = True
    mm_patch_merge_type: Optional[str] = "flat"
    mm_vision_select_feature: Optional[str] = "patch"
    n_query: Optional[int] = 729
    n_und_query: Optional[int] = 729
    gen_pooling: Optional[str] = "all"
    diffusion_name_or_path: Optional[str] = "Efficient-Large-Model/Sana_600M_512px_diffusers"
    teacher_model_name: Optional[str] = "Efficient-Large-Model/SANA1.5_4.8B_1024px_diffusers"
| |
|
| |
|
@dataclass
class DataArguments:
    """CLI arguments describing where the training data lives and how images
    should be treated."""

    # Kept as an explicit field() so HfArgumentParser surfaces the help text.
    data_path: str = field(default=None, metadata={"help": "Path to the training data."})
    lazy_preprocess: bool = False
    is_multimodal: bool = False  # flipped to True in train()
    image_folder: Optional[str] = None
    journeyDB_folder: Optional[str] = None
    shortcaption_image_folder: Optional[str] = None
    data_type: Optional[str] = "mix"
    image_aspect_ratio: str = "square"
| |
|
| |
|
@dataclass
class TrainingArguments(transformers.TrainingArguments):
    """HF TrainingArguments extended with quantization, LoRA, and multimodal
    projector knobs used by this training script."""

    cache_dir: Optional[str] = None
    optim: str = "adamw_torch"
    remove_unused_columns: bool = False
    freeze_mm_mlp_adapter: bool = False
    mpt_attn_impl: Optional[str] = "triton"
    model_max_length: int = field(
        default=512,
        metadata={"help": "Maximum sequence length. Sequences will be right padded (and possibly truncated)."},
    )
    double_quant: bool = field(
        default=True,
        metadata={"help": "Compress the quantization statistics through double quantization."},
    )
    quant_type: str = field(
        default="nf4",
        metadata={"help": "Quantization data type to use. Should be one of `fp4` or `nf4`."},
    )
    bits: int = field(default=16, metadata={"help": "How many bits to use."})
    # LoRA settings (consulted only when lora_enable is True).
    lora_enable: bool = False
    lora_r: int = 64
    lora_alpha: int = 16
    lora_dropout: float = 0.05
    lora_weight_path: str = ""
    lora_bias: str = "none"
    mm_projector_lr: Optional[float] = None
    group_by_modality_length: bool = False
    ddp_find_unused_parameters: bool = True
| |
|
# Multi-aspect-ratio buckets for 512px training: maps a height/width ratio
# (string key) to the target [height, width] in pixels. Consumed by
# get_closest_ratio() / LazySupervisedMixDataset._safe_img_process().
ASPECT_RATIO_512 = {
    "0.25": [256.0, 1024.0],
    "0.26": [256.0, 992.0],
    "0.27": [256.0, 960.0],
    "0.28": [256.0, 928.0],
    "0.32": [288.0, 896.0],
    "0.33": [288.0, 864.0],
    "0.35": [288.0, 832.0],
    "0.4": [320.0, 800.0],
    "0.42": [320.0, 768.0],
    "0.48": [352.0, 736.0],
    "0.5": [352.0, 704.0],
    "0.52": [352.0, 672.0],
    "0.57": [384.0, 672.0],
    "0.6": [384.0, 640.0],
    "0.68": [416.0, 608.0],
    "0.72": [416.0, 576.0],
    "0.78": [448.0, 576.0],
    "0.82": [448.0, 544.0],
    "0.88": [480.0, 544.0],
    "0.94": [480.0, 512.0],
    "1.0": [512.0, 512.0],
    "1.07": [512.0, 480.0],
    "1.13": [544.0, 480.0],
    "1.21": [544.0, 448.0],
    "1.29": [576.0, 448.0],
    "1.38": [576.0, 416.0],
    "1.46": [608.0, 416.0],
    "1.67": [640.0, 384.0],
    "1.75": [672.0, 384.0],
    "2.0": [704.0, 352.0],
    "2.09": [736.0, 352.0],
    "2.4": [768.0, 320.0],
    "2.5": [800.0, 320.0],
    "2.89": [832.0, 288.0],
    "3.0": [864.0, 288.0],
    "3.11": [896.0, 288.0],
    "3.62": [928.0, 256.0],
    "3.75": [960.0, 256.0],
    "3.88": [992.0, 256.0],
    "4.0": [1024.0, 256.0],
}
| |
|
| |
|
def maybe_zero_3(param, ignore_status=False, name=None):
    """Materialize a (possibly ZeRO-3 partitioned) parameter as a detached CPU clone.

    DeepSpeed ZeRO-3 parameters carry a ``ds_id``; those are gathered first so
    the full tensor is available on this rank. Plain parameters are cloned
    directly. ``ignore_status`` silences the NOT_AVAILABLE warning.
    """
    from deepspeed import zero
    from deepspeed.runtime.zero.partition_parameters import ZeroParamStatus

    if not hasattr(param, "ds_id"):
        return param.detach().cpu().clone()

    if param.ds_status == ZeroParamStatus.NOT_AVAILABLE and not ignore_status:
        logging.warning(f"{name}: param.ds_status != ZeroParamStatus.NOT_AVAILABLE: {param.ds_status}")
    with zero.GatheredParameters([param]):
        return param.data.detach().cpu().clone()
| |
|
| |
|
| |
|
def get_mm_adapter_state_maybe_zero_3(named_params, keys_to_match):
    """Collect the adapter parameters whose names contain any of
    ``keys_to_match``, gathered (ZeRO-3 safe) and moved to CPU."""
    matched = ((name, tensor) for name, tensor in named_params
               if any(key in name for key in keys_to_match))
    return {name: maybe_zero_3(tensor, ignore_status=True).cpu() for name, tensor in matched}
| |
|
| |
|
| |
|
| |
|
def safe_save_model_for_hf_trainer(trainer: transformers.Trainer, output_dir: str, vision_tower: str):
    """Save the mm/gen projector weights and then the model itself.

    Under DeepSpeed the full model is saved via ``trainer.save_model``;
    otherwise the state dict is moved to CPU and written on the rank that
    ``should_save``. ``vision_tower`` is accepted for interface compatibility
    but is not used.

    The original mm_projector / gen_projector sections were copy-pasted
    duplicates; they are factored into ``_save_adapter_weights`` below.
    """
    if trainer.deepspeed:
        torch.cuda.synchronize()

    _save_adapter_weights(trainer, output_dir, "mm_projector")
    _save_adapter_weights(trainer, output_dir, "gen_projector")

    if trainer.deepspeed:
        torch.cuda.synchronize()
        trainer.save_model(output_dir)
        return

    state_dict = trainer.model.state_dict()
    if trainer.args.should_save:
        cpu_state_dict = {key: value.cpu() for key, value in state_dict.items()}
        del state_dict  # free the GPU-side copy before writing
        trainer._save(output_dir, state_dict=cpu_state_dict)


def _save_adapter_weights(trainer, output_dir: str, adapter_name: str):
    """Gather and save one adapter's weights (ZeRO-3 safe) on rank 0.

    For ``checkpoint-*`` dirs the weights go to ``<parent>/<adapter_name>/
    <checkpoint>.bin``; otherwise to ``<output_dir>/<adapter_name>.bin``.
    Also writes the model config alongside, matching the original behavior.
    """
    keys_to_match = [adapter_name]
    if getattr(trainer.args, "use_im_start_end", False):
        # Token embeddings were resized for the extra image tokens; save them too.
        keys_to_match.extend(["embed_tokens", "embed_in"])

    weight_to_save = get_mm_adapter_state_maybe_zero_3(trainer.model.named_parameters(), keys_to_match)
    trainer.model.config.save_pretrained(output_dir)

    current_folder = output_dir.split("/")[-1]
    parent_folder = os.path.dirname(output_dir)
    if trainer.args.local_rank == 0 or trainer.args.local_rank == -1:
        if current_folder.startswith("checkpoint-"):
            adapter_folder = os.path.join(parent_folder, adapter_name)
            os.makedirs(adapter_folder, exist_ok=True)
            torch.save(
                weight_to_save,
                os.path.join(adapter_folder, f"{current_folder}.bin"),
            )
        else:
            torch.save(weight_to_save, os.path.join(output_dir, f"{adapter_name}.bin"))
| |
|
| |
|
def smart_tokenizer_and_embedding_resize(
    special_tokens_dict: Dict,
    tokenizer: transformers.PreTrainedTokenizer,
    model: transformers.PreTrainedModel,
):
    """Add special tokens to the tokenizer and resize model embeddings,
    initializing the new input-embedding rows to the mean of the old ones.

    NOTE(review): only the *input* embeddings are re-initialized; newly added
    output-embedding rows keep their default init — confirm this is intended
    (e.g. tied embeddings or a frozen lm_head).
    """
    num_new_tokens = tokenizer.add_special_tokens(special_tokens_dict)
    model.resize_token_embeddings(len(tokenizer))

    if num_new_tokens <= 0:
        return

    embeddings = model.get_input_embeddings().weight.data
    mean_embedding = embeddings[:-num_new_tokens].mean(dim=0, keepdim=True)
    embeddings[-num_new_tokens:] = mean_embedding
| |
|
| |
|
| |
|
| |
|
def preprocess_multimodal(sources: Sequence[str], data_args: "DataArguments"):
    """Expand ``<image>`` placeholders in each conversation.

    Human-turn images become the understanding placeholder
    (``<|vision_start|>`` + ``n_und_query`` pads + ``<|vision_end|>``);
    gpt-turn images are removed (generation is handled via appended image
    tokens in the collator).

    Returns:
        (sources, inst_type) where inst_type is "und", "gen", or None.

    BUG FIX: the non-multimodal early path previously returned bare
    ``sources``, while every caller unpacks two values
    (``sources, inst_type = preprocess_multimodal(...)``); it now returns
    ``(sources, None)`` consistently.
    """
    if not data_args.is_multimodal:
        return sources, None

    und_placeholder = "<|vision_start|>" + "<|image_pad|>" * data_args.n_und_query + "<|vision_end|>"
    gen_placeholder = ""

    inst_type = None
    for source in sources:
        for sentence in source:
            # NOTE(review): the membership test uses the literal "<image>"
            # while the replacement uses DEFAULT_IMAGE_TOKEN — presumably the
            # same string; confirm against blip3o.constants.
            if sentence["from"] == "human" and "<image>" in sentence["value"]:
                sentence["value"] = sentence["value"].replace(DEFAULT_IMAGE_TOKEN, und_placeholder).strip()
                inst_type = "und"
            elif sentence["from"] == "gpt" and "<image>" in sentence["value"]:
                sentence["value"] = sentence["value"].replace(DEFAULT_IMAGE_TOKEN, gen_placeholder).strip()
                inst_type = "gen"
    return sources, inst_type
| |
|
| |
|
| |
|
| |
|
def preprocess_qwen(sources, tokenizer: transformers.PreTrainedTokenizer, has_image: bool = False, max_len=2048, system_message: str = "You are a helpful assistant.") -> Dict:
    """Tokenize conversations with the Qwen chat template.

    Builds ``input_ids`` for each conversation and ``labels`` where system and
    user turns are masked with IGNORE_INDEX so only assistant tokens are
    supervised.

    Fix: the bare ``except:`` on the role/content lookup is narrowed to
    ``except KeyError`` — the only exception a missing dict key raises — so
    unrelated errors are no longer swallowed.

    NOTE(review): ``has_image`` and ``max_len`` are accepted but unused here;
    kept for interface compatibility with callers.
    """
    roles = {"human": "user", "gpt": "assistant"}

    # Work on a deep copy so the caller's tokenizer keeps its own chat template.
    tokenizer = copy.deepcopy(tokenizer)
    chat_template = "{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}"
    tokenizer.chat_template = chat_template

    input_ids, targets = [], []
    for i, source in enumerate(sources):
        # Ensure the conversation starts with a human turn.
        if roles[source[0]["from"]] != roles["human"]:
            source = source[1:]

        input_id, target = [], []

        # System prompt is context only: fully masked out of the loss.
        input_id += tokenizer.apply_chat_template([{"role" : "system", "content" : system_message}])
        target += [IGNORE_INDEX] * len(input_id)

        for conv in source:
            try:
                role = conv["role"]
                content = conv["content"]
            except KeyError:
                # Fall back to the ShareGPT-style schema: {"from": ..., "value": ...}.
                role = conv["from"]
                content = conv["value"]

            role = roles.get(role, role)

            conv = [{"role" : role, "content" : content}]
            encode_id = tokenizer.apply_chat_template(conv)
            input_id += encode_id
            if role in ["user", "system"]:
                target += [IGNORE_INDEX] * len(encode_id)
            else:
                target += encode_id

        assert len(input_id) == len(target), f"{len(input_id)} != {len(target)}"

        input_ids.append(input_id)
        targets.append(target)
    # NOTE(review): torch.tensor requires all conversations in `sources` to
    # tokenize to the same length; callers pass one conversation at a time,
    # so this holds in practice — confirm before batching here.
    input_ids = torch.tensor(input_ids, dtype=torch.long)
    targets = torch.tensor(targets, dtype=torch.long)

    return dict(
        input_ids=input_ids,
        labels=targets,
    )
| |
|
def get_closest_ratio(height: float, width: float, ratios: dict):
    """Return the ``(target_size, ratio)`` bucket for an image.

    NOTE(review): aspect-ratio bucket selection is currently *disabled* —
    every image is pinned to the square ``"1.0"`` bucket regardless of
    height/width (the original computed ``height / width`` and never used it;
    that dead computation is removed here). To re-enable multi-aspect
    training, restore a search such as
    ``min(ratios, key=lambda r: abs(float(r) - height / width))``.
    """
    closest_ratio = "1.0"
    return ratios[closest_ratio], float(closest_ratio)
| |
|
| |
|
| |
|
class LazySupervisedMixDataset(Dataset):
    """Lazily-processed text-to-image (T2I) training dataset.

    Loads webdataset ``*.tar`` shards from ``data_args.image_folder``,
    normalizes them to (image, txt, type, image_path) rows, and on access
    converts each row into a caption-to-image example: tokenized student
    prompt, aspect-bucketed image tensor (``gen_image``), and fixed-length
    teacher tokenizer inputs for knowledge distillation.
    """

    def __init__(
        self,
        data_path: str,
        tokenizer: transformers.PreTrainedTokenizer,
        teacher_tokenizer: transformers.PreTrainedTokenizer,
        data_args: DataArguments,
    ):
        super(LazySupervisedMixDataset, self).__init__()

        self.data_args = data_args
        list_data_dict = []

        # Dead code kept for reference: an earlier variant of the shard
        # loading below, with hardcoded multi-folder paths.
        '''
        data_files = glob.glob(os.path.join(self.data_args.image_folder, "*.tar"))
        #data_files = glob.glob(os.path.join('/proj/cvl/users/x_fahkh2/BLIP3o/dataset/BLIP3o-Pretrain-Long-Caption', "*.tar")) + glob.glob(os.path.join('/proj/cvl/users/x_fahkh2/BLIP3o/dataset/BLIP3o-Pretrain-Short-Caption', "*.tar")) + glob.glob(os.path.join('/proj/cvl/users/x_fahkh2/BLIP3o/dataset/BLIP3o-Pretrain-JourneyDB', "*.tar"))
        train_dataset = load_dataset("webdataset", data_files=data_files, split="train", num_proc=32)
        train_dataset = train_dataset.rename_column("jpg", "image")
        train_dataset = train_dataset.add_column('type', len(train_dataset) * ['T2I'])
        train_dataset = train_dataset.add_column('image_path', len(train_dataset) * [None])
        train_dataset = train_dataset.remove_columns([col for col in train_dataset.column_names if not col in (
            ["image", "txt", "type", "image_path"])])
        print(f"finish loading image {len(train_dataset)}")
        list_data_dict.append(train_dataset)
        '''

        # Load every *.tar shard under image_folder as a webdataset and keep
        # only the columns this pipeline consumes.
        data_files = glob.glob(os.path.join(self.data_args.image_folder, "*.tar"))

        train_dataset = load_dataset("webdataset", data_files=data_files, split="train", num_proc=32)
        train_dataset = train_dataset.rename_column("jpg", "image")
        train_dataset = train_dataset.add_column('type', len(train_dataset) * ['T2I'])
        train_dataset = train_dataset.add_column('image_path', len(train_dataset) * [None])
        train_dataset = train_dataset.remove_columns([col for col in train_dataset.column_names if not col in (
            ["image", "txt", "type", "image_path"])])
        print(f"finish loading image {len(train_dataset)}")
        list_data_dict.append(train_dataset)

        # Merge all sources (currently a single one) and shuffle with a fixed
        # seed so every rank sees the same order.
        if len(list_data_dict) > 1:
            list_data_dict = concatenate_datasets(list_data_dict)
        else:
            list_data_dict = list_data_dict[0]
        list_data_dict = list_data_dict.shuffle(seed=42)

        rank0_print(f"Total number of training instance: {len(list_data_dict)}")

        # Drop rows whose image failed to decode.
        list_data_dict = list_data_dict.filter(lambda x: x["image"] is not None, num_proc=8)
        self.tokenizer = tokenizer
        self.teacher_tokenizer = teacher_tokenizer
        self.list_data_dict = list_data_dict

    def __len__(self) -> int:
        return len(self.list_data_dict)

    @property
    def lengths(self):
        # Approximate token length per sample (word count + flat 128-token
        # image budget), for length-grouped samplers.
        # NOTE(review): rows built in __init__ carry "txt", not
        # "conversations" — confirm this property is actually exercised.
        length_list = []
        for sample in self.list_data_dict:
            img_tokens = 128 if "image" in sample else 0
            length_list.append(sum(len(conv["value"].split()) for conv in sample["conversations"]) + img_tokens)
        return length_list

    @property
    def modality_lengths(self):
        # Signed length: positive for samples with an image, negative for
        # text-only ones (modality-grouped sampler convention).
        length_list = []
        for sample in self.list_data_dict:
            cur_len = sum(len(conv["value"].split()) for conv in sample["conversations"])
            cur_len = cur_len if "image" in sample else -cur_len
            length_list.append(cur_len)
        return length_list

    def _safe_img_process(self, imgs):
        # Resize each PIL image into its ASPECT_RATIO_512 bucket, center-crop,
        # and normalize to [-1, 1]. Returns None on any failure so the caller
        # can resample a different index.
        try:
            out = []
            for img in imgs:
                ori_h, ori_w = img.height, img.width
                closest_size, closest_ratio = get_closest_ratio(ori_h, ori_w, ASPECT_RATIO_512)
                closest_size = [int(x) for x in closest_size]
                # Scale so the subsequent center-crop is fully covered in both
                # dimensions.
                if closest_size[0] / ori_h > closest_size[1] / ori_w:
                    resize_size = closest_size[0], int(ori_w * closest_size[0] / ori_h)
                else:
                    resize_size = int(ori_h * closest_size[1] / ori_w), closest_size[1]
                transform = T.Compose([
                    T.Lambda(lambda img: img.convert("RGB")),
                    T.Resize(resize_size, interpolation=InterpolationMode.BICUBIC),
                    T.CenterCrop(closest_size),
                    T.ToTensor(),
                    T.Normalize([0.5], [0.5]),  # [0, 1] -> [-1, 1]
                ])
                out.append(transform(img))
            return out
        except Exception as e:
            print(f"Corrupted image during processing: {e}")
            return None

    def __getitem__(self, i) -> Dict[str, torch.Tensor]:

        # Resample random indices until a clean sample is produced, so a
        # corrupted row never aborts the training run.
        while True:
            try:
                sources = self.list_data_dict[i]
                teacher_prompt = sources['txt'].lower().strip()
                teacher_prompt = f"Please generate image based on the following caption: {teacher_prompt}"

                # Rebuild the row as a single-turn caption -> image conversation.
                sources["conversations"] = [
                    {"from": "human", "value": f"Please generate image based on the following caption: {sources['txt']}"},
                    {"from": "gpt", "value": "<image>"},
                ]
                image_files = self.list_data_dict[i]["image"]
                if not isinstance(image_files, list):
                    image_files = [image_files]

                is_corrupt = False  # NOTE(review): never read — leftover flag.
                images = []
                for img in image_files:
                    img = img.convert("RGB")
                    images.append(img)

                processed_images = self._safe_img_process(images)
                if processed_images is None:
                    print("Corrupted image during transform, picking new sample.")
                    i = random.randint(0, len(self.list_data_dict) - 1)
                    continue

                sources, inst_type = preprocess_multimodal(copy.deepcopy([sources["conversations"]]), self.data_args)
                data_dict = preprocess_qwen(sources, self.tokenizer, has_image=("image" in self.list_data_dict[i]))
                if isinstance(i, int):
                    # Unwrap the singleton batch dimension added by preprocess_qwen.
                    data_dict = dict(input_ids=data_dict["input_ids"][0], labels=data_dict["labels"][0])

                data_dict["gen_image"] = processed_images[0]
                data_dict["ids"] = self.list_data_dict[i]["id"] if "id" in self.list_data_dict[i] else "unk"

                # Teacher prompt tokens, right-padded to a fixed 300 tokens
                # so they can be concatenated across the batch for KD.
                teacher_inputs = self.teacher_tokenizer(
                    teacher_prompt,
                    padding="max_length",
                    max_length=300,
                    truncation=True,
                    add_special_tokens=True,
                    return_tensors="pt",
                )
                data_dict['teacher_token_ids'] = teacher_inputs.input_ids
                data_dict['teacher_attention_mask'] = teacher_inputs.attention_mask
                data_dict['teacher_prompt'] = teacher_prompt
                return data_dict
            except Exception as e:
                print(f"[WARN] Skipping corrupted sample {i}: {e}")
                i = random.randint(0, len(self.list_data_dict) - 1)
                continue
| |
|
@dataclass
class DataCollatorForSupervisedDataset(object):
    """Collate examples for supervised fine-tuning."""

    tokenizer: transformers.PreTrainedTokenizer

    def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]:
        """Batch per-sample dicts: append a 9-token image slot to every
        sequence, right-pad, and gather image / teacher-KD side inputs."""
        input_ids, labels, ids = tuple([instance[key] for instance in instances] for key in ("input_ids", "labels", "ids"))
        multi_input_ids = []
        multi_labels = []
        i_s_pos = []  # per-sample position just past <im_start> (image-token start)
        for input_id, label in zip(input_ids, labels):
            # Reserve 9 trailing slots: 1 <im_start> marker + 8 image tokens.
            input_id = input_id[: self.tokenizer.model_max_length - 9]
            label = label[: self.tokenizer.model_max_length - 9]
            i_s_pos.append(input_id.shape[0]+1)
            img_id = torch.full((9,), IMAGE_TOKEN_IDX, dtype=input_id.dtype, device=input_id.device)
            img_id[0] = DEFAULT_IM_START_TOKEN_IDX
            input_id = torch.cat([input_id, img_id])
            # Image-slot labels mirror the ids (these positions are supervised).
            img_label = torch.full((9,), IMAGE_TOKEN_IDX, dtype=label.dtype, device=label.device)
            img_label[0] = DEFAULT_IM_START_TOKEN_IDX
            label = torch.cat([label, img_label])
            multi_input_ids.append(input_id)
            multi_labels.append(label)

        input_ids = multi_input_ids
        labels = multi_labels

        # Right-pad to the longest sequence in the batch, then hard-truncate
        # anything that still exceeds model_max_length.
        input_ids = torch.nn.utils.rnn.pad_sequence(input_ids, batch_first=True, padding_value=self.tokenizer.pad_token_id)
        labels = torch.nn.utils.rnn.pad_sequence(labels, batch_first=True, padding_value=IGNORE_INDEX)
        if input_ids.shape[1] > self.tokenizer.model_max_length:
            print(f"Warning input with length {input_ids.shape[1]} is longer than max length {self.tokenizer.model_max_length}")
            input_ids = input_ids[:, : self.tokenizer.model_max_length]
            labels = labels[:, : self.tokenizer.model_max_length]
        batch = dict(
            input_ids=input_ids,
            labels=labels,
            attention_mask=input_ids.ne(self.tokenizer.pad_token_id),
            teacher_prompts = [instance['teacher_prompt'] for instance in instances],
            teacher_input_ids=torch.cat([instance['teacher_token_ids'] for instance in instances], dim=0),
            teacher_attention_mask=torch.cat([instance['teacher_attention_mask'] for instance in instances], dim=0)
        )

        batch_gen_images = []
        batch_und_images = []
        batch_grid_thw = []

        for instance in instances:
            if "gen_image" in instance:
                batch_gen_images.append(instance["gen_image"])

        # Stack generation targets only when every image shares one shape;
        # otherwise hand the model a ragged list.
        if len(batch_gen_images) > 0:
            if all(x is not None and y.shape == batch_gen_images[0][0].shape for x in batch_gen_images for y in x):
                batch["gen_image"] = torch.cat([images.unsqueeze(0) for images in batch_gen_images], dim=0)
            else:
                batch["gen_image"] = batch_gen_images
        else:
            batch["gen_image"] = None

        # Understanding-branch images (absent in pure T2I batches).
        for instance in instances:
            if "und_image" in instance:
                batch_und_images.append(instance["und_image"].unsqueeze(0))
                batch_grid_thw.append(instance["grid_thw"])

        if len(batch_und_images) > 0:
            batch["und_image"] = torch.cat([images for images in batch_und_images], dim=0)
            batch["grid_thw"] = torch.cat([images for images in batch_grid_thw], dim=0)
        else:
            batch["und_image"] = None
            batch["grid_thw"] = None

        batch["ids"] = ids

        batch["i_s_pos"] = i_s_pos

        return batch
| |
|
| |
|
def make_supervised_data_module(tokenizer: transformers.PreTrainedTokenizer, teacher_tokenizer: transformers.PreTrainedTokenizer, data_args) -> Dict:
    """Build the train dataset + collator bundle the Trainer expects
    (``eval_dataset`` is intentionally None)."""
    dataset = LazySupervisedMixDataset(
        tokenizer=tokenizer,
        teacher_tokenizer=teacher_tokenizer,
        data_path=data_args.data_path,
        data_args=data_args,
    )
    collator = DataCollatorForSupervisedDataset(tokenizer=tokenizer)
    return {"train_dataset": dataset, "eval_dataset": None, "data_collator": collator}
| |
|
def train(attn_implementation=None):
    """Entry point: parse CLI args, build model/tokenizers/data, run training,
    and save the result.

    NOTE(review): ``attn_implementation`` is accepted but never used here.
    """
    global local_rank

    parser = transformers.HfArgumentParser((ModelArguments, DataArguments, TrainingArguments))
    model_args, data_args, training_args = parser.parse_args_into_dataclasses()

    local_rank = training_args.local_rank
    compute_dtype = torch.float16 if training_args.fp16 else (torch.bfloat16 if training_args.bf16 else torch.float32)

    # Optional 4/8-bit quantized loading via bitsandbytes.
    bnb_model_from_pretrained_args = {}
    if training_args.bits in [4, 8]:
        from transformers import BitsAndBytesConfig

        bnb_model_from_pretrained_args.update(
            dict(
                device_map={"": training_args.device},
                load_in_4bit=training_args.bits == 4,
                load_in_8bit=training_args.bits == 8,
                quantization_config=BitsAndBytesConfig(
                    load_in_4bit=training_args.bits == 4,
                    load_in_8bit=training_args.bits == 8,
                    llm_int8_skip_modules=["mm_projector"],
                    llm_int8_threshold=6.0,
                    llm_int8_has_fp16_weight=False,
                    bnb_4bit_compute_dtype=compute_dtype,
                    bnb_4bit_use_double_quant=training_args.double_quant,
                    bnb_4bit_quant_type=training_args.quant_type,
                ),
            )
        )

    model = blip3oFastForCausalLM.from_pretrained(
        model_args.model_name_or_path,
        cache_dir=training_args.cache_dir,
        torch_dtype=(torch.bfloat16 if training_args.bf16 else None),
        **bnb_model_from_pretrained_args,
    )

    # KV-cache is for inference; disable during training.
    model.config.use_cache = False

    # Freeze the LLM backbone, vision tower, and LM head so only newly added
    # modules (projectors/connectors) receive gradients.
    if model_args.freeze_backbone:
        for (n, p) in model.get_model().named_parameters():
            p.requires_grad = False
        for (n, p) in model.get_vision_tower().named_parameters():
            p.requires_grad = False
        for (n, p) in model.lm_head.named_parameters():
            p.requires_grad = False

    # Gradient checkpointing needs inputs that require grad; use the built-in
    # hook when available, otherwise register one on the input embeddings.
    if training_args.gradient_checkpointing:
        if hasattr(model, "enable_input_require_grads"):
            model.enable_input_require_grads()
        else:

            def make_inputs_require_grad(module, input, output):
                output.requires_grad_(True)

            model.get_input_embeddings().register_forward_hook(make_inputs_require_grad)

    # Some checkpoints ship a full processor, others only a tokenizer.
    try:
        tokenizer = AutoProcessor.from_pretrained(model_args.model_name_or_path).tokenizer
    except Exception as e:
        tokenizer = AutoProcessor.from_pretrained(model_args.model_name_or_path)

    teacher_tokenizer = AutoTokenizer.from_pretrained(model_args.teacher_model_name, subfolder="tokenizer")

    tokenizer.model_max_length = training_args.model_max_length
    teacher_tokenizer.padding_side = "right"

    # Ensure pad and image special tokens exist; new input-embedding rows get
    # mean initialization via smart_tokenizer_and_embedding_resize.
    if tokenizer.pad_token is None:
        smart_tokenizer_and_embedding_resize(
            special_tokens_dict=dict(
                pad_token="<pad>",
                additional_special_tokens=["[IMG]", "[/IMG]", "<image>"],
            ),
            tokenizer=tokenizer,
            model=model,
        )
    elif not "<image>" in tokenizer.get_added_vocab():
        smart_tokenizer_and_embedding_resize(
            special_tokens_dict=dict(additional_special_tokens=["[IMG]", "[/IMG]", "<image>"]),
            tokenizer=tokenizer,
            model=model,
        )
    # Pick the conversation template; fall back to "llama3" for unknown versions.
    if model_args.version in conversation_lib.conv_templates:
        conversation_lib.default_conversation = conversation_lib.conv_templates[model_args.version]
    else:
        conversation_lib.default_conversation = conversation_lib.conv_templates["llama3"]
    rank0_print(f"Using conversation format: {conversation_lib.default_conversation.version}")

    # Build vision modules and share the image processor with the data pipeline.
    model.get_model().initialize_vision_modules(model_args=model_args, fsdp=training_args.fsdp)
    image_processor = model.get_model().get_vision_tower().image_processor
    data_args.gen_image_processor = image_processor
    data_args.image_processor = image_processor

    data_args.is_multimodal = True
    data_args.n_query = model_args.n_query
    data_args.n_und_query = model_args.n_und_query

    model.config.image_aspect_ratio = data_args.image_aspect_ratio
    model.config.tokenizer_padding_side = tokenizer.padding_side
    model.config.tokenizer_model_max_length = tokenizer.model_max_length

    model.config.tune_mm_mlp_adapter = training_args.tune_mm_mlp_adapter = model_args.tune_mm_mlp_adapter

    model.config.freeze_mm_mlp_adapter = training_args.freeze_mm_mlp_adapter

    # Sanity report of how much of the model is actually trainable.
    total_params = sum(p.numel() for p in model.get_model().parameters())
    trainable_params = sum(p.numel() for p in model.get_model().parameters() if p.requires_grad)

    print(f"Total parameters: {total_params}")
    print(f"Trainable parameters: {trainable_params}")

    model.config.mm_use_im_start_end = data_args.mm_use_im_start_end = model_args.mm_use_im_start_end
    model.config.mm_projector_lr = training_args.mm_projector_lr
    training_args.use_im_start_end = model_args.mm_use_im_start_end
    model.config.mm_use_im_patch_token = model_args.mm_use_im_patch_token

    model.initialize_vision_tokenizer(model_args, tokenizer=tokenizer)
    model.config.pad_token_id = tokenizer.pad_token_id

    data_module = make_supervised_data_module(tokenizer=tokenizer, teacher_tokenizer=teacher_tokenizer, data_args=data_args)

    class WandbLossLogger(TrainerCallback):
        # Mirrors auxiliary losses the model stashes on itself (as _last_*
        # attributes) into Weights & Biases, on the main process only.
        def on_step_end(self, args, state, control, **kwargs):
            if not state.is_world_process_zero:
                return

            model = kwargs["model"]
            logs = {}

            if hasattr(model, "_last_diff_loss"):
                logs["diff_loss"] = float(model._last_diff_loss)
            if hasattr(model, "_last_kd_loss"):
                logs["kd_loss"] = float(model._last_kd_loss)
            if hasattr(model, "_last_pred_divergence"):
                logs["pred_divergence"] = float(model._last_pred_divergence)
            if hasattr(model, "kd_weight"):
                logs["kd_weight"] = float(model.kd_weight)
            if logs:
                wandb.log(logs, step=state.global_step)

    class ProgressiveKDCallback(TrainerCallback):
        # Linearly anneals the knowledge-distillation weight from 10 down to a
        # floor of 0.1 over the course of training.
        # NOTE(review): divides by total_steps — confirm it is nonzero when
        # neither args.max_steps nor state.max_steps is set.
        def on_step_begin(self, args, state, control, **kwargs):
            model = kwargs["model"]
            current_step = state.global_step
            total_steps = args.max_steps if args.max_steps > 0 else state.max_steps

            if hasattr(model, 'kd_weight'):
                model.kd_weight = max(0.1, 10 * (1 - current_step / total_steps))

    trainer = blip3oTrainer(
        model=model,
        tokenizer=tokenizer,
        args=training_args,
        callbacks=[WandbLossLogger, ProgressiveKDCallback],
        **data_module,
    )
    from tabulate import tabulate

    # Dead code kept for reference: manual safetensors checkpoint restore and
    # optimizer/learning-rate restoration, superseded by
    # trainer.train(resume_from_checkpoint=True) below.
    '''
    from safetensors.torch import load_file
    import json
    import pathlib

    # ---- Load model.safetensors if it exists ----
    checkpoint_dir = pathlib.Path(training_args.output_dir)
    safetensor_path = checkpoint_dir / "model.safetensors"
    trainer_state_path = checkpoint_dir / "trainer_state.json"

    if safetensor_path.exists():
        print(f"Loading weights from {safetensor_path}")
        state_dict = load_file(safetensor_path)
        new_state_dict = {}
        for k, v in state_dict.items():
            new_key = k.replace("model.", "", 1) if k.startswith("model.") else k
            new_state_dict[new_key] = v

        # print all keys
        #print("🔑 Keys in checkpoint:")
        #for k in state_dict.keys():
        #    print(k, state_dict[k].shape)

        missing, unexpected = model.get_model().load_state_dict(new_state_dict, strict=False)
        print("✅ Loaded parameters:")
        for k in new_state_dict.keys():
            if k not in missing:
                print(f"   {k} {tuple(new_state_dict[k].shape)}")


    # Restore last global step
    if trainer_state_path.exists():
        with open(trainer_state_path, "r") as f:
            trainer_state = json.load(f)
        last_global_step = trainer_state.get("global_step", 0)
        last_lr = trainer_state.get("learning_rate", trainer.args.learning_rate)
        trainer.state.global_step = last_global_step
        # Reset optimizer with last learning rate
        trainer.create_optimizer_and_scheduler(num_training_steps=trainer.args.max_steps)
        optimizer = trainer.optimizer
        #lr_scheduler = trainer.lr_scheduler

        for param_group in optimizer.param_groups:
            param_group['lr'] = last_lr
        trainer.optimizer = optimizer
        print(f"✅ Restored global step: {last_global_step}, learning rate: {last_lr}")


    '''
    # Resume automatically when any checkpoint-* directory exists.
    if list(pathlib.Path(training_args.output_dir).glob("checkpoint-*")):
        trainer.train(resume_from_checkpoint=True)
    else:
        trainer.train()
    trainer.save_state()

    # Re-enable KV-cache for inference before the final save.
    model.config.use_cache = True
    safe_save_model_for_hf_trainer(
        trainer=trainer,
        output_dir=training_args.output_dir,
        vision_tower=model_args.vision_tower,
    )
| |
|
| |
|
# Script entry point: run training when invoked directly.
if __name__ == "__main__":
    train()
| |
|
| |
|