import os
import copy
from dataclasses import dataclass, field
import logging
import pathlib
from typing import Dict, Optional, Sequence
import torch
import glob
import transformers
import tokenizers
from blip3o.constants import IGNORE_INDEX, DEFAULT_IMAGE_TOKEN, IMAGE_TOKEN_IDX, DEFAULT_IM_START_TOKEN_IDX
from torch.utils.data import Dataset
from blip3o.train.blip3o_trainer import blip3oTrainer
from blip3o import conversation as conversation_lib
from blip3o.model import *
from blip3o.mm_utils import tokenizer_image_token
from PIL import Image, ImageFile
from datasets import load_dataset, concatenate_datasets
from pathlib import Path
from datasets.utils.logging import set_verbosity_info
from transformers import logging as tf_logging
import torchvision.transforms as T
from torchvision.transforms.functional import InterpolationMode
from transformers import AutoProcessor
import random
from blip3o.model.multimodal_encoder.eva_clip.eva_clip_processors import EvaClipImageTrainProcessor


ImageFile.LOAD_TRUNCATED_IMAGES = True

transform_und_images = T.Compose([T.Resize(448, interpolation=InterpolationMode.BICUBIC, antialias=True), T.CenterCrop(448)])

set_verbosity_info()
tf_logging.set_verbosity_info()

local_rank = None

from transformers import TrainerCallback


class GradCheckCallback(TrainerCallback):
    def on_step_end(self, args, state, control, **kwargs):
        model = kwargs["model"]
        for name, param in model.named_parameters():
            if "caption_embed" in name or "diffusion_connector" in name:
                if param.grad is None:
                    print(f"{name} has NO gradient!")
                else:
                    print(f"{name} grad mean: {param.grad.abs().mean().item():.6f}")


def rank0_print(*args):
    if local_rank == 0:
        print(*args)


from packaging import version


@dataclass
class ModelArguments:
    model_name_or_path: Optional[str] = field(default="facebook/opt-125m")
    version: Optional[str] = field(default="v0")
    freeze_backbone: bool = field(default=True)
    tune_mm_mlp_adapter: bool = field(default=False)
    vision_tower: Optional[str] = field(default=None)
    gen_vision_tower: Optional[str] = field(default=None)
    mm_vision_select_layer: Optional[int] = field(default=-1)
    pretrain_mm_mlp_adapter: Optional[str] = field(default=None)
    pretrain_gen_mlp_adapter: Optional[str] = field(default=None)
    vision_tower_pretrained: Optional[str] = field(default=None)
    mm_projector_type: Optional[str] = field(default="linear")
    gen_projector_type: Optional[str] = field(default="linear")
    mm_use_im_start_end: bool = field(default=False)
    mm_use_im_patch_token: bool = field(default=True)
    mm_patch_merge_type: Optional[str] = field(default="flat")
    mm_vision_select_feature: Optional[str] = field(default="patch")
    n_query: Optional[int] = field(default=729)
    n_und_query: Optional[int] = field(default=729)
    gen_pooling: Optional[str] = field(default="all")
    diffusion_name_or_path: Optional[str] = field(default="Efficient-Large-Model/Sana_600M_1024px_diffusers")


@dataclass
class DataArguments:
    data_path: Optional[str] = field(default=None, metadata={"help": "Path to the training data."})
    lazy_preprocess: bool = False
    is_multimodal: bool = False
    image_folder: Optional[str] = field(default=None)
    journeyDB_folder: Optional[str] = field(default=None)
    shortcaption_image_folder: Optional[str] = field(default=None)
    data_type: Optional[str] = field(default="mix")
    image_aspect_ratio: str = "square"


@dataclass
class TrainingArguments(transformers.TrainingArguments):
    cache_dir: Optional[str] = field(default=None)
    optim: str = field(default="adamw_torch")
    remove_unused_columns: bool = field(default=False)
    freeze_mm_mlp_adapter: bool = field(default=False)
    mpt_attn_impl: Optional[str] = field(default="triton")
    model_max_length: int = field(
        default=512,
        metadata={"help": "Maximum sequence length. Sequences will be right padded (and possibly truncated)."},
    )
    double_quant: bool = field(
        default=True,
        metadata={"help": "Compress the quantization statistics through double quantization."},
    )
    quant_type: str = field(
        default="nf4",
        metadata={"help": "Quantization data type to use. Should be one of `fp4` or `nf4`."},
    )
    bits: int = field(default=16, metadata={"help": "How many bits to use."})
    lora_enable: bool = False
    lora_r: int = 64
    lora_alpha: int = 16
    lora_dropout: float = 0.05
    lora_weight_path: str = ""
    lora_bias: str = "none"
    mm_projector_lr: Optional[float] = None
    group_by_modality_length: bool = field(default=False)
    ddp_find_unused_parameters: bool = True


# Candidate resolution buckets keyed by aspect ratio (height / width); each value
# is the target [height, width]. Note the square "1.0" bucket targets 1024x1024.
ASPECT_RATIO_512 = {
    "0.25": [256.0, 1024.0],
    "0.26": [256.0, 992.0],
    "0.27": [256.0, 960.0],
    "0.28": [256.0, 928.0],
    "0.32": [288.0, 896.0],
    "0.33": [288.0, 864.0],
    "0.35": [288.0, 832.0],
    "0.4": [320.0, 800.0],
    "0.42": [320.0, 768.0],
    "0.48": [352.0, 736.0],
    "0.5": [352.0, 704.0],
    "0.52": [352.0, 672.0],
    "0.57": [384.0, 672.0],
    "0.6": [384.0, 640.0],
    "0.68": [416.0, 608.0],
    "0.72": [416.0, 576.0],
    "0.78": [448.0, 576.0],
    "0.82": [448.0, 544.0],
    "0.88": [480.0, 544.0],
    "0.94": [480.0, 512.0],
    "1.0": [1024.0, 1024.0],
    "1.07": [512.0, 480.0],
    "1.13": [544.0, 480.0],
    "1.21": [544.0, 448.0],
    "1.29": [576.0, 448.0],
    "1.38": [576.0, 416.0],
    "1.46": [608.0, 416.0],
    "1.67": [640.0, 384.0],
    "1.75": [672.0, 384.0],
    "2.0": [704.0, 352.0],
    "2.09": [736.0, 352.0],
    "2.4": [768.0, 320.0],
    "2.5": [800.0, 320.0],
    "2.89": [832.0, 288.0],
    "3.0": [864.0, 288.0],
    "3.11": [896.0, 288.0],
    "3.62": [928.0, 256.0],
    "3.75": [960.0, 256.0],
    "3.88": [992.0, 256.0],
    "4.0": [1024.0, 256.0],
}

print("Input size: ", ASPECT_RATIO_512["1.0"])


def maybe_zero_3(param, ignore_status=False, name=None):
    from deepspeed import zero
    from deepspeed.runtime.zero.partition_parameters import ZeroParamStatus

    if hasattr(param, "ds_id"):
        if param.ds_status == ZeroParamStatus.NOT_AVAILABLE:
            if not ignore_status:
                logging.warning(f"{name}: param.ds_status != ZeroParamStatus.NOT_AVAILABLE: {param.ds_status}")
        with zero.GatheredParameters([param]):
            param = param.data.detach().cpu().clone()
    else:
        param = param.detach().cpu().clone()
    return param


def get_mm_adapter_state_maybe_zero_3(named_params, keys_to_match):
    to_return = {k: t for k, t in named_params if any(key_match in k for key_match in keys_to_match)}
    to_return = {k: maybe_zero_3(v, ignore_status=True).cpu() for k, v in to_return.items()}
    return to_return


def safe_save_model_for_hf_trainer(trainer: transformers.Trainer, output_dir: str, vision_tower: str):
    if trainer.deepspeed:
        torch.cuda.synchronize()

    keys_to_match = ["mm_projector"]
    if getattr(trainer.args, "use_im_start_end", False):
        keys_to_match.extend(["embed_tokens", "embed_in"])

    weight_to_save = get_mm_adapter_state_maybe_zero_3(trainer.model.named_parameters(), keys_to_match)
    trainer.model.config.save_pretrained(output_dir)

    current_folder = output_dir.split("/")[-1]
    parent_folder = os.path.dirname(output_dir)
    if trainer.args.local_rank == 0 or trainer.args.local_rank == -1:
        if current_folder.startswith("checkpoint-"):
            mm_projector_folder = os.path.join(parent_folder, "mm_projector")
            os.makedirs(mm_projector_folder, exist_ok=True)
            torch.save(
                weight_to_save,
                os.path.join(mm_projector_folder, f"{current_folder}.bin"),
            )
        else:
            torch.save(weight_to_save, os.path.join(output_dir, "mm_projector.bin"))

    keys_to_match = ["gen_projector"]
    if getattr(trainer.args, "use_im_start_end", False):
        keys_to_match.extend(["embed_tokens", "embed_in"])

    weight_to_save = get_mm_adapter_state_maybe_zero_3(trainer.model.named_parameters(), keys_to_match)
    trainer.model.config.save_pretrained(output_dir)

    current_folder = output_dir.split("/")[-1]
    parent_folder = os.path.dirname(output_dir)
    if trainer.args.local_rank == 0 or trainer.args.local_rank == -1:
        if current_folder.startswith("checkpoint-"):
            mm_projector_folder = os.path.join(parent_folder, "gen_projector")
            os.makedirs(mm_projector_folder, exist_ok=True)
            torch.save(
                weight_to_save,
                os.path.join(mm_projector_folder, f"{current_folder}.bin"),
            )
        else:
            torch.save(weight_to_save, os.path.join(output_dir, "gen_projector.bin"))

    if trainer.deepspeed:
        torch.cuda.synchronize()
        trainer.save_model(output_dir)
        return

    state_dict = trainer.model.state_dict()
    if trainer.args.should_save:
        cpu_state_dict = {key: value.cpu() for key, value in state_dict.items()}
        del state_dict
        trainer._save(output_dir, state_dict=cpu_state_dict)


def smart_tokenizer_and_embedding_resize(
    special_tokens_dict: Dict,
    tokenizer: transformers.PreTrainedTokenizer,
    model: transformers.PreTrainedModel,
):
    """Add special tokens and resize the token embeddings.

    Only the input embeddings of newly added tokens are re-initialized (to the
    mean of the existing embeddings); the output embeddings are left untouched.
    """
    num_new_tokens = tokenizer.add_special_tokens(special_tokens_dict)
    model.resize_token_embeddings(len(tokenizer))

    if num_new_tokens > 0:
        input_embeddings = model.get_input_embeddings().weight.data
        input_embeddings_avg = input_embeddings[:-num_new_tokens].mean(dim=0, keepdim=True)
        input_embeddings[-num_new_tokens:] = input_embeddings_avg


def preprocess_multimodal(sources: Sequence[str], data_args: DataArguments):
    """Replace the <image> placeholder in each turn and return (sources, inst_type)."""
    is_multimodal = data_args.is_multimodal
    if not is_multimodal:
        return sources, None
    und_placeholder = "<|vision_start|>" + "<|image_pad|>" * data_args.n_und_query + "<|vision_end|>"
    gen_placeholder = ""

    inst_type = None
    for source in sources:
        for sentence in source:
            if sentence["from"] == "human" and "<image>" in sentence["value"]:
                sentence["value"] = sentence["value"].replace(DEFAULT_IMAGE_TOKEN, und_placeholder).strip()
                inst_type = "und"
            elif sentence["from"] == "gpt" and "<image>" in sentence["value"]:
                sentence["value"] = sentence["value"].replace(DEFAULT_IMAGE_TOKEN, gen_placeholder).strip()
                inst_type = "gen"
    return sources, inst_type


def preprocess_qwen(sources, tokenizer: transformers.PreTrainedTokenizer, has_image: bool = False, max_len=2048, system_message: str = "You are a helpful assistant.") -> Dict:
    roles = {"human": "user", "gpt": "assistant"}

    # Work on a local copy of the tokenizer with a plain ChatML template.
    tokenizer = copy.deepcopy(tokenizer)
    chat_template = "{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}"
    tokenizer.chat_template = chat_template

    input_ids, targets = [], []
    for i, source in enumerate(sources):
        if roles[source[0]["from"]] != roles["human"]:
            source = source[1:]

        input_id, target = [], []

        input_id += tokenizer.apply_chat_template([{"role": "system", "content": system_message}])
        target += [IGNORE_INDEX] * len(input_id)

        for conv in source:
            try:
                role = conv["role"]
                content = conv["content"]
            except KeyError:
                role = conv["from"]
                content = conv["value"]

            role = roles.get(role, role)

            conv = [{"role": role, "content": content}]
            encode_id = tokenizer.apply_chat_template(conv)
            input_id += encode_id
            if role in ["user", "system"]:
                # Mask prompt tokens; only assistant tokens contribute to the loss.
                target += [IGNORE_INDEX] * len(encode_id)
            else:
                target += encode_id

        assert len(input_id) == len(target), f"{len(input_id)} != {len(target)}"

        input_ids.append(input_id)
        targets.append(target)
    input_ids = torch.tensor(input_ids, dtype=torch.long)
    targets = torch.tensor(targets, dtype=torch.long)

    return dict(
        input_ids=input_ids,
        labels=targets,
    )


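# For reference, the ChatML template used above renders each turn as
# "<|im_start|>{role}\n{content}<|im_end|>\n", e.g. (illustrative):
#
#   <|im_start|>system
#   You are a helpful assistant.<|im_end|>
#   <|im_start|>user
#   Please generate image based on the following caption: ...<|im_end|>
#   <|im_start|>assistant
#   ...<|im_end|>
#
# Only assistant turns keep their token ids in `labels`; system and user turns
# are masked with IGNORE_INDEX so they do not contribute to the loss.

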
def get_closest_ratio(height: float, width: float, ratios: dict):
    # The closest-ratio search is currently short-circuited: every image is
    # assigned the "1.0" (square) bucket regardless of its true aspect ratio.
    aspect_ratio = height / width
    closest_ratio = "1.0"
    return ratios[closest_ratio], float(closest_ratio)


class LazySupervisedMixDataset(Dataset):
    def __init__(
        self,
        data_path: str,
        tokenizer: transformers.PreTrainedTokenizer,
        data_args: DataArguments,
    ):
        super(LazySupervisedMixDataset, self).__init__()

        self.data_args = data_args
        list_data_dict = []

        data_files = glob.glob(os.path.join(self.data_args.image_folder, "*.tar"))

        train_dataset = load_dataset("webdataset", data_files=data_files, split="train", num_proc=32)
        train_dataset = train_dataset.rename_column("jpg", "image")
        train_dataset = train_dataset.add_column('type', len(train_dataset) * ['T2I'])
        train_dataset = train_dataset.add_column('image_path', len(train_dataset) * [None])
        train_dataset = train_dataset.remove_columns([col for col in train_dataset.column_names if col not in ["image", "txt", "type", "image_path"]])
        print(f"Finished loading {len(train_dataset)} images")
        list_data_dict.append(train_dataset)

        if len(list_data_dict) > 1:
            list_data_dict = concatenate_datasets(list_data_dict)
        else:
            list_data_dict = list_data_dict[0]
        list_data_dict = list_data_dict.shuffle(seed=42)

        rank0_print(f"Total number of training instances: {len(list_data_dict)}")
        self.tokenizer = tokenizer
        self.list_data_dict = list_data_dict

    def __len__(self):
        return len(self.list_data_dict)

    @property
    def lengths(self):
        length_list = []
        for sample in self.list_data_dict:
            img_tokens = 128 if "image" in sample else 0
            length_list.append(sum(len(conv["value"].split()) for conv in sample["conversations"]) + img_tokens)
        return length_list

    @property
    def modality_lengths(self):
        length_list = []
        for sample in self.list_data_dict:
            cur_len = sum(len(conv["value"].split()) for conv in sample["conversations"])
            cur_len = cur_len if "image" in sample else -cur_len
            length_list.append(cur_len)
        return length_list

    def _safe_img_process(self, imgs):
        try:
            out = []
            for img in imgs:
                ori_h, ori_w = img.height, img.width
                closest_size, closest_ratio = get_closest_ratio(ori_h, ori_w, ASPECT_RATIO_512)
                closest_size = [int(x) for x in closest_size]
                if closest_size[0] / ori_h > closest_size[1] / ori_w:
                    resize_size = closest_size[0], int(ori_w * closest_size[0] / ori_h)
                else:
                    resize_size = int(ori_h * closest_size[1] / ori_w), closest_size[1]
                transform = T.Compose([
                    T.Lambda(lambda img: img.convert("RGB")),
                    T.Resize(resize_size, interpolation=InterpolationMode.BICUBIC),
                    T.CenterCrop(closest_size),
                    T.ToTensor(),
                    T.Normalize([0.5], [0.5]),
                ])
                out.append(transform(img))
            return out
        except Exception as e:
            print(f"Corrupted image during processing: {e}")
            return None

    def __getitem__(self, i) -> Dict[str, torch.Tensor]:
        while True:
            try:
                sources = self.list_data_dict[i]
                sources["conversations"] = [
                    {"from": "human", "value": f"Please generate image based on the following caption: {sources['txt']}"},
                    {"from": "gpt", "value": "<image>"},
                ]
                image_files = self.list_data_dict[i]["image"]
                if not isinstance(image_files, list):
                    image_files = [image_files]

                images = []
                for img in image_files:
                    img = img.convert("RGB")
                    images.append(img)

                processed_images = self._safe_img_process(images)
                if processed_images is None:
                    print("Corrupted image during transform, picking new sample.")
                    i = random.randint(0, len(self.list_data_dict) - 1)
                    continue

                sources, inst_type = preprocess_multimodal(copy.deepcopy([sources["conversations"]]), self.data_args)
                data_dict = preprocess_qwen(sources, self.tokenizer, has_image=("image" in self.list_data_dict[i]))
                if isinstance(i, int):
                    data_dict = dict(input_ids=data_dict["input_ids"][0], labels=data_dict["labels"][0])

                data_dict["gen_image"] = processed_images[0]
                data_dict["ids"] = self.list_data_dict[i]["id"] if "id" in self.list_data_dict[i] else "unk"
                return data_dict
            except Exception as e:
                print(f"[WARN] Skipping corrupted sample {i}: {e}")
                i = random.randint(0, len(self.list_data_dict) - 1)
                continue


@dataclass
class DataCollatorForSupervisedDataset(object):
    """Collate examples for supervised fine-tuning."""

    tokenizer: transformers.PreTrainedTokenizer

    def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]:
        input_ids, labels, ids = tuple([instance[key] for instance in instances] for key in ("input_ids", "labels", "ids"))
        multi_input_ids = []
        multi_labels = []
        i_s_pos = []
        for input_id, label in zip(input_ids, labels):
            # Leave room for a 17-token image span (an image-start token followed
            # by 16 image tokens) and record the position just past the text tokens.
            input_id = input_id[: self.tokenizer.model_max_length - 17]
            label = label[: self.tokenizer.model_max_length - 17]
            i_s_pos.append(input_id.shape[0] + 1)
            img_id = torch.full((17,), IMAGE_TOKEN_IDX, dtype=input_id.dtype, device=input_id.device)
            img_id[0] = DEFAULT_IM_START_TOKEN_IDX

            img_label = torch.full((17,), IMAGE_TOKEN_IDX, dtype=label.dtype, device=label.device)
            img_label[0] = DEFAULT_IM_START_TOKEN_IDX

            # NOTE: img_id / img_label are constructed but not appended to the
            # sequences in this collator; only the truncated text is collected.
            multi_input_ids.append(input_id)
            multi_labels.append(label)

        input_ids = multi_input_ids
        labels = multi_labels

        input_ids = torch.nn.utils.rnn.pad_sequence(input_ids, batch_first=True, padding_value=self.tokenizer.pad_token_id)
        labels = torch.nn.utils.rnn.pad_sequence(labels, batch_first=True, padding_value=IGNORE_INDEX)
        if input_ids.shape[1] > self.tokenizer.model_max_length:
            print(f"Warning: input with length {input_ids.shape[1]} is longer than max length {self.tokenizer.model_max_length}")
        input_ids = input_ids[:, : self.tokenizer.model_max_length]
        labels = labels[:, : self.tokenizer.model_max_length]
        attention_mask = input_ids.ne(self.tokenizer.pad_token_id)
        batch = dict(
            input_ids=input_ids,
            labels=labels,
            attention_mask=attention_mask,
        )

        batch_gen_images = []
        batch_und_images = []
        batch_grid_thw = []

        for instance in instances:
            if "gen_image" in instance:
                batch_gen_images.append(instance["gen_image"])

        if len(batch_gen_images) > 0:
            # Stack into a single tensor only when all images share the same shape.
            if all(x is not None and x.shape == batch_gen_images[0].shape for x in batch_gen_images):
                batch["gen_image"] = torch.cat([images.unsqueeze(0) for images in batch_gen_images], dim=0)
            else:
                batch["gen_image"] = batch_gen_images
        else:
            batch["gen_image"] = None

        for instance in instances:
            if "und_image" in instance:
                batch_und_images.append(instance["und_image"].unsqueeze(0))
                batch_grid_thw.append(instance["grid_thw"])

        if len(batch_und_images) > 0:
            batch["und_image"] = torch.cat([images for images in batch_und_images], dim=0)
            batch["grid_thw"] = torch.cat([images for images in batch_grid_thw], dim=0)
        else:
            batch["und_image"] = None
            batch["grid_thw"] = None

        batch["ids"] = ids
        batch["i_s_pos"] = i_s_pos
        return batch


def make_supervised_data_module(tokenizer: transformers.PreTrainedTokenizer, data_args) -> Dict:
    train_dataset = LazySupervisedMixDataset(tokenizer=tokenizer, data_path=data_args.data_path, data_args=data_args)
    data_collator = DataCollatorForSupervisedDataset(tokenizer=tokenizer)
    return dict(train_dataset=train_dataset, eval_dataset=None, data_collator=data_collator)


def train(attn_implementation=None):
    global local_rank

    parser = transformers.HfArgumentParser((ModelArguments, DataArguments, TrainingArguments))
    model_args, data_args, training_args = parser.parse_args_into_dataclasses()
    print(model_args, data_args, training_args)
    local_rank = training_args.local_rank
    compute_dtype = torch.float16 if training_args.fp16 else (torch.bfloat16 if training_args.bf16 else torch.float32)

    bnb_model_from_pretrained_args = {}
    if training_args.bits in [4, 8]:
        from transformers import BitsAndBytesConfig

        bnb_model_from_pretrained_args.update(
            dict(
                device_map={"": training_args.device},
                load_in_4bit=training_args.bits == 4,
                load_in_8bit=training_args.bits == 8,
                quantization_config=BitsAndBytesConfig(
                    load_in_4bit=training_args.bits == 4,
                    load_in_8bit=training_args.bits == 8,
                    llm_int8_skip_modules=["mm_projector"],
                    llm_int8_threshold=6.0,
                    llm_int8_has_fp16_weight=False,
                    bnb_4bit_compute_dtype=compute_dtype,
                    bnb_4bit_use_double_quant=training_args.double_quant,
                    bnb_4bit_quant_type=training_args.quant_type,
                ),
            )
        )

    model = blip3oFastForCausalLM.from_pretrained(
        model_args.model_name_or_path,
        cache_dir=training_args.cache_dir,
        torch_dtype=(torch.bfloat16 if training_args.bf16 else None),
        **bnb_model_from_pretrained_args,
    )

    model.config.use_cache = False

    if model_args.freeze_backbone:
        for (n, p) in model.get_model().named_parameters():
            p.requires_grad = False
        for (n, p) in model.get_vision_tower().named_parameters():
            p.requires_grad = False
        for (n, p) in model.lm_head.named_parameters():
            p.requires_grad = False

    if training_args.gradient_checkpointing:
        if hasattr(model, "enable_input_require_grads"):
            model.enable_input_require_grads()
        else:

            def make_inputs_require_grad(module, input, output):
                output.requires_grad_(True)

            model.get_input_embeddings().register_forward_hook(make_inputs_require_grad)

    try:
        tokenizer = AutoProcessor.from_pretrained(model_args.model_name_or_path).tokenizer
    except Exception:
        tokenizer = AutoProcessor.from_pretrained(model_args.model_name_or_path)

    tokenizer.model_max_length = training_args.model_max_length

    if tokenizer.pad_token is None:
        smart_tokenizer_and_embedding_resize(
            special_tokens_dict=dict(
                pad_token="<pad>",
                additional_special_tokens=["[IMG]", "[/IMG]", "<image>"],
            ),
            tokenizer=tokenizer,
            model=model,
        )
    elif "<image>" not in tokenizer.get_added_vocab():
        smart_tokenizer_and_embedding_resize(
            special_tokens_dict=dict(additional_special_tokens=["[IMG]", "[/IMG]", "<image>"]),
            tokenizer=tokenizer,
            model=model,
        )
    if model_args.version in conversation_lib.conv_templates:
        conversation_lib.default_conversation = conversation_lib.conv_templates[model_args.version]
    else:
        conversation_lib.default_conversation = conversation_lib.conv_templates["llama3"]
    rank0_print(f"Using conversation format: {conversation_lib.default_conversation.version}")

    model.get_model().initialize_vision_modules(model_args=model_args, fsdp=training_args.fsdp)
    image_processor = model.get_model().get_vision_tower().image_processor
    data_args.gen_image_processor = image_processor
    data_args.image_processor = image_processor

    data_args.is_multimodal = True
    data_args.n_query = model_args.n_query
    data_args.n_und_query = model_args.n_und_query

    model.config.image_aspect_ratio = data_args.image_aspect_ratio
    model.config.tokenizer_padding_side = tokenizer.padding_side
    model.config.tokenizer_model_max_length = tokenizer.model_max_length

    model.config.tune_mm_mlp_adapter = training_args.tune_mm_mlp_adapter = model_args.tune_mm_mlp_adapter

    model.config.freeze_mm_mlp_adapter = training_args.freeze_mm_mlp_adapter

    total_params = sum(p.numel() for p in model.get_model().parameters())
    trainable_params = sum(p.numel() for p in model.get_model().parameters() if p.requires_grad)

    print(f"Total parameters: {total_params}")
    print(f"Trainable parameters: {trainable_params}")

    model.config.mm_use_im_start_end = data_args.mm_use_im_start_end = model_args.mm_use_im_start_end
    model.config.mm_projector_lr = training_args.mm_projector_lr
    training_args.use_im_start_end = model_args.mm_use_im_start_end
    model.config.mm_use_im_patch_token = model_args.mm_use_im_patch_token

    model.initialize_vision_tokenizer(model_args, tokenizer=tokenizer)
    model.config.pad_token_id = tokenizer.pad_token_id

    data_module = make_supervised_data_module(tokenizer=tokenizer, data_args=data_args)

    trainer = blip3oTrainer(
        model=model,
        tokenizer=tokenizer,
        args=training_args,
        **data_module,
    )
    from tabulate import tabulate

    if trainer.is_world_process_zero():
        stat = []
        for i, (n, p) in enumerate(trainer.model.named_parameters()):
            stat.append([i, n, p.shape, p.requires_grad])
        print(tabulate(stat, headers=["idx", "name", "shape", "trainable"]))

    # Disabled: a manual resume path that reloads model weights and trainer state
    # from output_dir instead of relying on trainer.train(resume_from_checkpoint=True).
    '''
    from safetensors.torch import load_file
    import json
    import pathlib

    # ---- Load model.safetensors if it exists ----
    checkpoint_dir = pathlib.Path(training_args.output_dir)
    safetensor_path = checkpoint_dir / "model.safetensors"
    trainer_state_path = checkpoint_dir / "trainer_state.json"

    if safetensor_path.exists():
        print(f"Loading weights from {safetensor_path}")
        state_dict = load_file(safetensor_path)
        new_state_dict = {}
        for k, v in state_dict.items():
            new_key = k.replace("model.", "", 1) if k.startswith("model.") else k
            new_state_dict[new_key] = v

        # print all keys
        # print("🔑 Keys in checkpoint:")
        # for k in state_dict.keys():
        #     print(k, state_dict[k].shape)

        missing, unexpected = model.get_model().load_state_dict(new_state_dict, strict=False)
        print("✅ Loaded parameters:")
        for k in new_state_dict.keys():
            if k not in missing:
                print(f"  {k} {tuple(new_state_dict[k].shape)}")

    # Restore last global step
    if trainer_state_path.exists():
        with open(trainer_state_path, "r") as f:
            trainer_state = json.load(f)
        last_global_step = trainer_state.get("global_step", 0)
        last_lr = trainer_state.get("learning_rate", trainer.args.learning_rate)
        trainer.state.global_step = last_global_step
        # Reset optimizer with last learning rate
        trainer.create_optimizer_and_scheduler(num_training_steps=trainer.args.max_steps)
        optimizer = trainer.optimizer
        # lr_scheduler = trainer.lr_scheduler

        for param_group in optimizer.param_groups:
            param_group['lr'] = last_lr
        trainer.optimizer = optimizer
        print(f"✅ Restored global step: {last_global_step}, learning rate: {last_lr}")
    '''

    if list(pathlib.Path(training_args.output_dir).glob("checkpoint-*")):
        trainer.train(resume_from_checkpoint=True)
    else:
        trainer.train()
    trainer.save_state()

    model.config.use_cache = True
    safe_save_model_for_hf_trainer(
        trainer=trainer,
        output_dir=training_args.output_dir,
        vision_tower=model_args.vision_tower,
    )


if __name__ == "__main__":
    train()
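

# Illustrative launch command (not part of the original script; the script filename,
# paths, and hyperparameter values below are assumptions). HfArgumentParser exposes
# each ModelArguments / DataArguments / TrainingArguments field as a CLI flag:
#
#   torchrun --nproc_per_node=8 train.py \
#       --model_name_or_path <base-model-path-or-hub-id> \
#       --image_folder <dir-with-*.tar-webdataset-shards> \
#       --output_dir ./checkpoints/blip3o_t2i \
#       --bf16 True \
#       --model_max_length 512 \
#       --per_device_train_batch_size 4 \
#       --learning_rate 1e-4 \
#       --num_train_epochs 1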