import os
import copy
from dataclasses import dataclass, field
import logging
import pathlib
from typing import Dict, Optional, Sequence
import torch
import glob
import transformers
import tokenizers
from blip3o.constants import IGNORE_INDEX, DEFAULT_IMAGE_TOKEN, IMAGE_TOKEN_IDX, DEFAULT_IM_START_TOKEN_IDX
from torch.utils.data import Dataset
from blip3o.train.blip3o_trainer import blip3oTrainer
from blip3o import conversation as conversation_lib
from blip3o.model import *
from blip3o.mm_utils import tokenizer_image_token
from PIL import Image, ImageFile
from datasets import load_dataset, concatenate_datasets
from pathlib import Path
from datasets.utils.logging import set_verbosity_info
from transformers import logging as tf_logging
import torchvision.transforms as T
from torchvision.transforms.functional import InterpolationMode
from transformers import AutoProcessor
import random
from blip3o.model.multimodal_encoder.eva_clip.eva_clip_processors import EvaClipImageTrainProcessor
ImageFile.LOAD_TRUNCATED_IMAGES = True
transform_und_images = T.Compose([T.Resize(448, interpolation=InterpolationMode.BICUBIC, antialias=True), T.CenterCrop(448)])
set_verbosity_info()
tf_logging.set_verbosity_info()
local_rank = None
from transformers import TrainerCallback
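# Debug helper: after every optimizer step, report whether the caption_embed and
# diffusion_connector parameters actually received gradients (and their mean magnitude).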
class GradCheckCallback(TrainerCallback):
def on_step_end(self, args, state, control, **kwargs):
model = kwargs["model"]
for name, param in model.named_parameters():
if "caption_embed" in name or "diffusion_connector" in name:
if param.grad is None:
print(f"{name} has NO gradient!")
else:
print(f"{name} grad mean: {param.grad.abs().mean().item():.6f}")
def rank0_print(*args):
if local_rank == 0:
print(*args)
from packaging import version
@dataclass
class ModelArguments:
model_name_or_path: Optional[str] = field(default="facebook/opt-125m")
version: Optional[str] = field(default="v0")
freeze_backbone: bool = field(default=True)
tune_mm_mlp_adapter: bool = field(default=False)
vision_tower: Optional[str] = field(default=None)
gen_vision_tower: Optional[str] = field(default=None)
mm_vision_select_layer: Optional[int] = field(default=-1) # default to the last layer
pretrain_mm_mlp_adapter: Optional[str] = field(default=None)
pretrain_gen_mlp_adapter: Optional[str] = field(default=None)
vision_tower_pretrained: Optional[str] = field(default=None)
mm_projector_type: Optional[str] = field(default="linear")
gen_projector_type: Optional[str] = field(default="linear")
mm_use_im_start_end: bool = field(default=False)
mm_use_im_patch_token: bool = field(default=True)
mm_patch_merge_type: Optional[str] = field(default="flat")
mm_vision_select_feature: Optional[str] = field(default="patch")
n_query: Optional[int] = field(default=729) # clip 576, siglip 729
n_und_query: Optional[int] = field(default=729) # clip 576, siglip 729
gen_pooling: Optional[str] = field(default="all") # options are: pool2d_3, pool2d_9, seq_3, seq_9, seq_27
diffusion_name_or_path: Optional[str] = field(default="Efficient-Large-Model/Sana_600M_1024px_diffusers")
@dataclass
class DataArguments:
data_path: str = field(default=None, metadata={"help": "Path to the training data."})
lazy_preprocess: bool = False
is_multimodal: bool = False
image_folder: Optional[str] = field(default=None)
journeyDB_folder: Optional[str] = field(default=None)
shortcaption_image_folder: Optional[str] = field(default=None)
data_type: Optional[str] = field(default="mix")
image_aspect_ratio: str = "square"
@dataclass
class TrainingArguments(transformers.TrainingArguments):
cache_dir: Optional[str] = field(default=None)
optim: str = field(default="adamw_torch")
remove_unused_columns: bool = field(default=False)
freeze_mm_mlp_adapter: bool = field(default=False)
mpt_attn_impl: Optional[str] = field(default="triton")
model_max_length: int = field(
default=512,
metadata={"help": "Maximum sequence length. Sequences will be right padded (and possibly truncated)."},
)
double_quant: bool = field(
default=True,
metadata={"help": "Compress the quantization statistics through double quantization."},
)
quant_type: str = field(
default="nf4",
metadata={"help": "Quantization data type to use. Should be one of `fp4` or `nf4`."},
)
bits: int = field(default=16, metadata={"help": "How many bits to use."})
lora_enable: bool = False
lora_r: int = 64
lora_alpha: int = 16
lora_dropout: float = 0.05
lora_weight_path: str = ""
lora_bias: str = "none"
mm_projector_lr: Optional[float] = None
group_by_modality_length: bool = field(default=False)
    ddp_find_unused_parameters: bool = True
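# Aspect-ratio buckets: each key is a target height/width ratio mapped to a (height, width)
# resolution pair. Note that get_closest_ratio below is pinned to the "1.0" bucket, so in
# practice every training image is resized/cropped to 1024x1024.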
ASPECT_RATIO_512 = {
"0.25": [256.0, 1024.0],
"0.26": [256.0, 992.0],
"0.27": [256.0, 960.0],
"0.28": [256.0, 928.0],
"0.32": [288.0, 896.0],
"0.33": [288.0, 864.0],
"0.35": [288.0, 832.0],
"0.4": [320.0, 800.0],
"0.42": [320.0, 768.0],
"0.48": [352.0, 736.0],
"0.5": [352.0, 704.0],
"0.52": [352.0, 672.0],
"0.57": [384.0, 672.0],
"0.6": [384.0, 640.0],
"0.68": [416.0, 608.0],
"0.72": [416.0, 576.0],
"0.78": [448.0, 576.0],
"0.82": [448.0, 544.0],
"0.88": [480.0, 544.0],
"0.94": [480.0, 512.0],
"1.0": [1024.0, 1024.0],
"1.07": [512.0, 480.0],
"1.13": [544.0, 480.0],
"1.21": [544.0, 448.0],
"1.29": [576.0, 448.0],
"1.38": [576.0, 416.0],
"1.46": [608.0, 416.0],
"1.67": [640.0, 384.0],
"1.75": [672.0, 384.0],
"2.0": [704.0, 352.0],
"2.09": [736.0, 352.0],
"2.4": [768.0, 320.0],
"2.5": [800.0, 320.0],
"2.89": [832.0, 288.0],
"3.0": [864.0, 288.0],
"3.11": [896.0, 288.0],
"3.62": [928.0, 256.0],
"3.75": [960.0, 256.0],
"3.88": [992.0, 256.0],
"4.0": [1024.0, 256.0],
}
print("Input size: ", ASPECT_RATIO_512["1.0"])
def maybe_zero_3(param, ignore_status=False, name=None):
from deepspeed import zero
from deepspeed.runtime.zero.partition_parameters import ZeroParamStatus
if hasattr(param, "ds_id"):
if param.ds_status == ZeroParamStatus.NOT_AVAILABLE:
if not ignore_status:
logging.warning(f"{name}: param.ds_status != ZeroParamStatus.NOT_AVAILABLE: {param.ds_status}")
with zero.GatheredParameters([param]):
param = param.data.detach().cpu().clone()
else:
param = param.detach().cpu().clone()
return param
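# Collect the named parameters whose names contain any of keys_to_match (e.g. projector weights), gathered to CPU.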
def get_mm_adapter_state_maybe_zero_3(named_params, keys_to_match):
to_return = {k: t for k, t in named_params if any(key_match in k for key_match in keys_to_match)}
to_return = {k: maybe_zero_3(v, ignore_status=True).cpu() for k, v in to_return.items()}
return to_return
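# Saving: always dump the mm_projector and gen_projector weights as separate .bin files,
# then save the full model either through DeepSpeed (trainer.save_model) or a plain CPU state dict.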
def safe_save_model_for_hf_trainer(trainer: transformers.Trainer, output_dir: str, vision_tower: str):
if trainer.deepspeed:
torch.cuda.synchronize()
keys_to_match = ["mm_projector"]
if getattr(trainer.args, "use_im_start_end", False):
keys_to_match.extend(["embed_tokens", "embed_in"])
weight_to_save = get_mm_adapter_state_maybe_zero_3(trainer.model.named_parameters(), keys_to_match)
trainer.model.config.save_pretrained(output_dir)
current_folder = output_dir.split("/")[-1]
parent_folder = os.path.dirname(output_dir)
if trainer.args.local_rank == 0 or trainer.args.local_rank == -1:
if current_folder.startswith("checkpoint-"):
mm_projector_folder = os.path.join(parent_folder, "mm_projector")
os.makedirs(mm_projector_folder, exist_ok=True)
torch.save(
weight_to_save,
os.path.join(mm_projector_folder, f"{current_folder}.bin"),
)
else:
            torch.save(weight_to_save, os.path.join(output_dir, "mm_projector.bin"))
keys_to_match = ["gen_projector"]
if getattr(trainer.args, "use_im_start_end", False):
keys_to_match.extend(["embed_tokens", "embed_in"])
weight_to_save = get_mm_adapter_state_maybe_zero_3(trainer.model.named_parameters(), keys_to_match)
trainer.model.config.save_pretrained(output_dir)
current_folder = output_dir.split("/")[-1]
parent_folder = os.path.dirname(output_dir)
if trainer.args.local_rank == 0 or trainer.args.local_rank == -1:
if current_folder.startswith("checkpoint-"):
mm_projector_folder = os.path.join(parent_folder, "gen_projector")
os.makedirs(mm_projector_folder, exist_ok=True)
torch.save(
weight_to_save,
os.path.join(mm_projector_folder, f"{current_folder}.bin"),
)
else:
            torch.save(weight_to_save, os.path.join(output_dir, "gen_projector.bin"))
if trainer.deepspeed:
torch.cuda.synchronize()
trainer.save_model(output_dir)
return
state_dict = trainer.model.state_dict()
if trainer.args.should_save:
cpu_state_dict = {key: value.cpu() for key, value in state_dict.items()}
del state_dict
trainer._save(output_dir, state_dict=cpu_state_dict) # noqa
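# Add special tokens, resize the embedding matrix, and initialize the new rows to the mean of the existing input embeddings.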
def smart_tokenizer_and_embedding_resize(
special_tokens_dict: Dict,
tokenizer: transformers.PreTrainedTokenizer,
model: transformers.PreTrainedModel,
):
num_new_tokens = tokenizer.add_special_tokens(special_tokens_dict)
model.resize_token_embeddings(len(tokenizer))
if num_new_tokens > 0:
input_embeddings = model.get_input_embeddings().weight.data
input_embeddings_avg = input_embeddings[:-num_new_tokens].mean(dim=0, keepdim=True)
input_embeddings[-num_new_tokens:] = input_embeddings_avg
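# Rewrite the generic <image> placeholder per conversation turn: understanding turns ("human") get the
# Qwen-style <|vision_start|>...<|vision_end|> span, generation turns ("gpt") have it stripped.
# Returns the edited sources together with the instance type ("und" or "gen").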
def preprocess_multimodal(sources: Sequence[str], data_args: DataArguments):
    is_multimodal = data_args.is_multimodal
    if not is_multimodal:
        return sources, None
und_placeholder = "<|vision_start|>" + "<|image_pad|>" * data_args.n_und_query + "<|vision_end|>"
gen_placeholder = ""
# "[IMG]" + "<image>" * data_args.n_query + "[/IMG]"
inst_type = None
for source in sources: # [instance]
for sentence in source:
if sentence["from"] == "human" and "<image>" in sentence["value"]:
sentence["value"] = sentence["value"].replace(DEFAULT_IMAGE_TOKEN, und_placeholder).strip()
inst_type = "und"
elif sentence["from"] == "gpt" and "<image>" in sentence["value"]:
sentence["value"] = sentence["value"].replace(DEFAULT_IMAGE_TOKEN, gen_placeholder).strip()
inst_type = "gen"
return sources, inst_type
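# Tokenize each conversation with the Qwen chat template; system/user tokens are masked to IGNORE_INDEX
# in the labels so the loss is only computed on assistant tokens.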
def preprocess_qwen(sources, tokenizer: transformers.PreTrainedTokenizer, has_image: bool = False, max_len=2048, system_message: str = "You are a helpful assistant.") -> Dict:
roles = {"human": "user", "gpt": "assistant"}
tokenizer = copy.deepcopy(tokenizer)
chat_template = "{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}"
tokenizer.chat_template = chat_template
# Apply prompt templates
input_ids, targets = [], []
for i, source in enumerate(sources):
if roles[source[0]["from"]] != roles["human"]:
source = source[1:]
input_id, target = [], []
# New version, use apply chat template
# Build system message for each sentence
input_id += tokenizer.apply_chat_template([{"role" : "system", "content" : system_message}])
target += [IGNORE_INDEX] * len(input_id)
for conv in source:
try:
role = conv["role"]
content = conv["content"]
            except KeyError:
role = conv["from"]
content = conv["value"]
role = roles.get(role, role)
conv = [{"role" : role, "content" : content}]
encode_id = tokenizer.apply_chat_template(conv)
input_id += encode_id
if role in ["user", "system"]:
target += [IGNORE_INDEX] * len(encode_id)
else:
target += encode_id
assert len(input_id) == len(target), f"{len(input_id)} != {len(target)}"
input_ids.append(input_id)
targets.append(target)
input_ids = torch.tensor(input_ids, dtype=torch.long)
targets = torch.tensor(targets, dtype=torch.long)
return dict(
input_ids=input_ids, # tensor(bs x seq_len)
labels=targets, # tensor(bs x seq_len)
)
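# Map an image to its resolution bucket. The nearest-ratio search is commented out, so every
# image currently falls into the "1.0" (1024x1024) bucket.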
def get_closest_ratio(height: float, width: float, ratios: dict):
aspect_ratio = height / width
closest_ratio = "1.0" #min(ratios.keys(), key=lambda ratio: abs(float(ratio) - aspect_ratio))
return ratios[closest_ratio], float(closest_ratio)
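# Text-to-image dataset: loads (image, caption) pairs from webdataset .tar shards and wraps each
# caption as a single-turn "generate an image from this caption" conversation.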
class LazySupervisedMixDataset(Dataset):
def __init__(
self,
data_path: str,
tokenizer: transformers.PreTrainedTokenizer,
data_args: DataArguments,
):
super(LazySupervisedMixDataset, self).__init__()
self.data_args = data_args
list_data_dict = []
###################################### text to image #######################################
data_files = glob.glob(os.path.join(self.data_args.image_folder, "*.tar"))
#data_files = glob.glob(os.path.join('/proj/cvl/users/x_fahkh2/BLIP3o/dataset/BLIP3o-Pretrain-Long-Caption', "*.tar")) + glob.glob(os.path.join('/proj/cvl/users/x_fahkh2/BLIP3o/dataset/BLIP3o-Pretrain-Short-Caption', "*.tar")) + glob.glob(os.path.join('/proj/cvl/users/x_fahkh2/BLIP3o/dataset/BLIP3o-Pretrain-JourneyDB', "*.tar"))
train_dataset = load_dataset("webdataset", data_files=data_files, split="train", num_proc=32)
train_dataset = train_dataset.rename_column("jpg", "image")
train_dataset = train_dataset.add_column('type', len(train_dataset) * ['T2I'])
train_dataset = train_dataset.add_column('image_path', len(train_dataset) * [None])
        train_dataset = train_dataset.remove_columns(
            [col for col in train_dataset.column_names if col not in ("image", "txt", "type", "image_path")]
        )
        print(f"Finished loading {len(train_dataset)} text-to-image samples")
list_data_dict.append(train_dataset)
if len(list_data_dict) > 1:
list_data_dict = concatenate_datasets(list_data_dict)
else:
list_data_dict = list_data_dict[0]
list_data_dict = list_data_dict.shuffle(seed=42)
rank0_print(f"Total number of training instance: {len(list_data_dict)}")
self.tokenizer = tokenizer
self.list_data_dict = list_data_dict
def __len__(self):
return len(self.list_data_dict)
@property
def lengths(self):
length_list = []
for sample in self.list_data_dict:
img_tokens = 128 if "image" in sample else 0
length_list.append(sum(len(conv["value"].split()) for conv in sample["conversations"]) + img_tokens)
return length_list
@property
def modality_lengths(self):
length_list = []
for sample in self.list_data_dict:
cur_len = sum(len(conv["value"].split()) for conv in sample["conversations"])
cur_len = cur_len if "image" in sample else -cur_len
length_list.append(cur_len)
return length_list
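    # Resize and center-crop each PIL image to its resolution bucket and normalize to [-1, 1]
    # for the diffusion branch; returns None if any image fails to process.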
def _safe_img_process(self, imgs):
try:
out = []
for img in imgs:
ori_h, ori_w = img.height, img.width
closest_size, closest_ratio = get_closest_ratio(ori_h, ori_w, ASPECT_RATIO_512)
closest_size = [int(x) for x in closest_size]
if closest_size[0] / ori_h > closest_size[1] / ori_w:
resize_size = closest_size[0], int(ori_w * closest_size[0] / ori_h)
else:
resize_size = int(ori_h * closest_size[1] / ori_w), closest_size[1]
transform = T.Compose([
T.Lambda(lambda img: img.convert("RGB")),
T.Resize(resize_size, interpolation=InterpolationMode.BICUBIC), # Image.BICUBIC
T.CenterCrop(closest_size),
T.ToTensor(),
T.Normalize([0.5], [0.5]),
])
out.append(transform(img))
return out
except Exception as e:
print(f"Corrupted image during processing: {e}")
return None
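    # Fetch one training sample; if the image is corrupted or processing fails, resample a random index and retry.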
def __getitem__(self, i) -> Dict[str, torch.Tensor]:
while True:
try:
sources = self.list_data_dict[i]
sources["conversations"] = [
{"from": "human", "value": f"Please generate image based on the following caption: {sources['txt']}"},
{"from": "gpt", "value": "<image>"},
]
image_files = self.list_data_dict[i]["image"]
if not isinstance(image_files, list):
image_files = [image_files]
images = []
for img in image_files:
img = img.convert("RGB")
images.append(img)
processed_images = self._safe_img_process(images)
if processed_images is None:
print("Corrupted image during transform, picking new sample.")
i = random.randint(0, len(self.list_data_dict) - 1)
continue
# just replace <image> with "" in generation tasks
sources, inst_type = preprocess_multimodal(copy.deepcopy([sources["conversations"]]), self.data_args)
data_dict = preprocess_qwen(sources, self.tokenizer, has_image=("image" in self.list_data_dict[i]))
if isinstance(i, int):
data_dict = dict(input_ids=data_dict["input_ids"][0], labels=data_dict["labels"][0])
data_dict["gen_image"] = processed_images[0]
data_dict["ids"] = self.list_data_dict[i]["id"] if "id" in self.list_data_dict[i] else "unk"
return data_dict
except Exception as e:
print(f"[WARN] Skipping corrupted sample {i}: {e}")
i = random.randint(0, len(self.list_data_dict) - 1)
continue
@dataclass
class DataCollatorForSupervisedDataset(object):
"""Collate examples for supervised fine-tuning."""
tokenizer: transformers.PreTrainedTokenizer
def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]:
input_ids, labels, ids = tuple([instance[key] for instance in instances] for key in ("input_ids", "labels", "ids"))
multi_input_ids = []
multi_labels = []
i_s_pos = []
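        # Each prompt is truncated to leave room for a 17-token image block (1 start token + 16 image tokens).
        # In this "noqueries" variant the block itself is not appended (see the commented torch.cat calls below),
        # but i_s_pos still records the position right after the prompt, presumably where image features are
        # injected downstream by the model/trainer.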
for input_id, label in zip(input_ids, labels):
input_id = input_id[: self.tokenizer.model_max_length - 17]
label = label[: self.tokenizer.model_max_length - 17]
i_s_pos.append(input_id.shape[0]+1)
img_id = torch.full((17,), IMAGE_TOKEN_IDX, dtype=input_id.dtype, device=input_id.device)
img_id[0] = DEFAULT_IM_START_TOKEN_IDX
# input_id = torch.cat([input_id, img_id])
img_label = torch.full((17,), IMAGE_TOKEN_IDX, dtype=label.dtype, device=label.device)
img_label[0] = DEFAULT_IM_START_TOKEN_IDX
# label = torch.cat([label, img_label])
multi_input_ids.append(input_id)
multi_labels.append(label)
input_ids = multi_input_ids
labels = multi_labels
input_ids = torch.nn.utils.rnn.pad_sequence(input_ids, batch_first=True, padding_value=self.tokenizer.pad_token_id)
labels = torch.nn.utils.rnn.pad_sequence(labels, batch_first=True, padding_value=IGNORE_INDEX)
if input_ids.shape[1] > self.tokenizer.model_max_length:
print(f"Warning input with length {input_ids.shape[1]} is longer than max length {self.tokenizer.model_max_length}")
input_ids = input_ids[:, : self.tokenizer.model_max_length]
labels = labels[:, : self.tokenizer.model_max_length]
attention_mask = input_ids.ne(self.tokenizer.pad_token_id)
batch = dict(
input_ids=input_ids,
labels=labels,
attention_mask=attention_mask,
)
batch_gen_images = []
batch_und_images = []
batch_grid_thw = []
for instance in instances:
if "gen_image" in instance:
batch_gen_images.append(instance["gen_image"])
if len(batch_gen_images) > 0:
            if all(x is not None and x.shape == batch_gen_images[0].shape for x in batch_gen_images):
                batch["gen_image"] = torch.stack(batch_gen_images, dim=0)
else:
batch["gen_image"] = batch_gen_images
else:
batch["gen_image"] = None
for instance in instances:
if "und_image" in instance:
batch_und_images.append(instance["und_image"].unsqueeze(0)) ## 1*1024*1176
batch_grid_thw.append(instance["grid_thw"]) ## 1*3
# print(f"batch_und_images {batch_und_images}")
if len(batch_und_images) > 0:
batch["und_image"] = torch.cat([images for images in batch_und_images], dim=0)
batch["grid_thw"] = torch.cat([images for images in batch_grid_thw], dim=0)
else:
batch["und_image"] = None
batch["grid_thw"] = None
batch["ids"] = ids
batch["i_s_pos"] = i_s_pos
return batch
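# Build the training dataset and collator passed to the trainer (no eval dataset is used).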
def make_supervised_data_module(tokenizer: transformers.PreTrainedTokenizer, data_args) -> Dict:
train_dataset = LazySupervisedMixDataset(tokenizer=tokenizer, data_path=data_args.data_path, data_args=data_args)
data_collator = DataCollatorForSupervisedDataset(tokenizer=tokenizer)
return dict(train_dataset=train_dataset, eval_dataset=None, data_collator=data_collator)
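# Entry point: parse arguments, load the blip3o model, freeze the LLM backbone / vision tower / lm_head,
# set up the tokenizer, vision modules and data, then train with blip3oTrainer (resuming from the latest
# checkpoint in output_dir if one exists).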
def train(attn_implementation=None):
global local_rank
parser = transformers.HfArgumentParser((ModelArguments, DataArguments, TrainingArguments))
model_args, data_args, training_args = parser.parse_args_into_dataclasses()
print(model_args, data_args, training_args)
local_rank = training_args.local_rank
compute_dtype = torch.float16 if training_args.fp16 else (torch.bfloat16 if training_args.bf16 else torch.float32)
bnb_model_from_pretrained_args = {}
if training_args.bits in [4, 8]:
from transformers import BitsAndBytesConfig
bnb_model_from_pretrained_args.update(
dict(
device_map={"": training_args.device},
load_in_4bit=training_args.bits == 4,
load_in_8bit=training_args.bits == 8,
quantization_config=BitsAndBytesConfig(
load_in_4bit=training_args.bits == 4,
load_in_8bit=training_args.bits == 8,
llm_int8_skip_modules=["mm_projector"],
llm_int8_threshold=6.0,
llm_int8_has_fp16_weight=False,
bnb_4bit_compute_dtype=compute_dtype,
bnb_4bit_use_double_quant=training_args.double_quant,
bnb_4bit_quant_type=training_args.quant_type, # {'fp4', 'nf4'}
),
)
)
model = blip3oFastForCausalLM.from_pretrained(
model_args.model_name_or_path,
cache_dir=training_args.cache_dir,
# attn_implementation=attn_implementation,
torch_dtype=(torch.bfloat16 if training_args.bf16 else None),
**bnb_model_from_pretrained_args,
)
model.config.use_cache = False
if model_args.freeze_backbone:
for (n, p) in model.get_model().named_parameters():
p.requires_grad = False
for (n, p) in model.get_vision_tower().named_parameters():
p.requires_grad = False
for (n, p) in model.lm_head.named_parameters():
p.requires_grad = False
#for (n, p) in model.get_model().named_parameters():
# p.requires_grad = True
#for (n, p) in model.get_vision_tower().named_parameters():
# p.requires_grad = False
#for (n, p) in model.get_model().embed_tokens.named_parameters():
# p.requires_grad=True
if training_args.gradient_checkpointing:
if hasattr(model, "enable_input_require_grads"):
model.enable_input_require_grads()
else:
def make_inputs_require_grad(module, input, output):
output.requires_grad_(True)
model.get_input_embeddings().register_forward_hook(make_inputs_require_grad)
try:
tokenizer = AutoProcessor.from_pretrained(model_args.model_name_or_path).tokenizer
    except Exception:
tokenizer = AutoProcessor.from_pretrained(model_args.model_name_or_path)
tokenizer.model_max_length = training_args.model_max_length
# tokenizer.pad_token = tokenizer.unk_token
if tokenizer.pad_token is None:
smart_tokenizer_and_embedding_resize(
special_tokens_dict=dict(
pad_token="<pad>",
additional_special_tokens=["[IMG]", "[/IMG]", "<image>"],
),
tokenizer=tokenizer,
model=model,
)
elif not "<image>" in tokenizer.get_added_vocab():
smart_tokenizer_and_embedding_resize(
special_tokens_dict=dict(additional_special_tokens=["[IMG]", "[/IMG]", "<image>"]),
tokenizer=tokenizer,
model=model,
)
if model_args.version in conversation_lib.conv_templates:
conversation_lib.default_conversation = conversation_lib.conv_templates[model_args.version]
else:
conversation_lib.default_conversation = conversation_lib.conv_templates["llama3"]
rank0_print(f"Using conversation format: {conversation_lib.default_conversation.version}")
# if model_args.vision_tower is not None:
model.get_model().initialize_vision_modules(model_args=model_args, fsdp=training_args.fsdp)
image_processor = model.get_model().get_vision_tower().image_processor
data_args.gen_image_processor = image_processor
data_args.image_processor = image_processor
data_args.is_multimodal = True
data_args.n_query = model_args.n_query
data_args.n_und_query = model_args.n_und_query
model.config.image_aspect_ratio = data_args.image_aspect_ratio
model.config.tokenizer_padding_side = tokenizer.padding_side
model.config.tokenizer_model_max_length = tokenizer.model_max_length
model.config.tune_mm_mlp_adapter = training_args.tune_mm_mlp_adapter = model_args.tune_mm_mlp_adapter
model.config.freeze_mm_mlp_adapter = training_args.freeze_mm_mlp_adapter
# Calculate total parameters and trainable parameters
total_params = sum(p.numel() for p in model.get_model().parameters())
trainable_params = sum(p.numel() for p in model.get_model().parameters() if p.requires_grad)
print(f"Total parameters: {total_params}")
print(f"Trainable parameters: {trainable_params}")
model.config.mm_use_im_start_end = data_args.mm_use_im_start_end = model_args.mm_use_im_start_end
model.config.mm_projector_lr = training_args.mm_projector_lr
training_args.use_im_start_end = model_args.mm_use_im_start_end
model.config.mm_use_im_patch_token = model_args.mm_use_im_patch_token
# TODO: what is this?
model.initialize_vision_tokenizer(model_args, tokenizer=tokenizer)
model.config.pad_token_id = tokenizer.pad_token_id
data_module = make_supervised_data_module(tokenizer=tokenizer, data_args=data_args)
trainer = blip3oTrainer(
model=model,
tokenizer=tokenizer,
args=training_args,
#callbacks=[GradCheckCallback],
**data_module,
)
from tabulate import tabulate
if trainer.is_world_process_zero():
stat = []
for i, (n, p) in enumerate(trainer.model.named_parameters()):
stat.append([i, n, p.shape, p.requires_grad])
print(tabulate(stat, headers=["idx", "name", "shape", "trainable"]))
'''
from safetensors.torch import load_file
import json
import pathlib
# ---- Load model.safetensors if it exists ----
checkpoint_dir = pathlib.Path(training_args.output_dir)
safetensor_path = checkpoint_dir / "model.safetensors"
trainer_state_path = checkpoint_dir / "trainer_state.json"
if safetensor_path.exists():
print(f"Loading weights from {safetensor_path}")
state_dict = load_file(safetensor_path)
new_state_dict = {}
for k, v in state_dict.items():
new_key = k.replace("model.", "", 1) if k.startswith("model.") else k
new_state_dict[new_key] = v
# print all keys
#print("🔑 Keys in checkpoint:")
#for k in state_dict.keys():
# print(k, state_dict[k].shape)
missing, unexpected = model.get_model().load_state_dict(new_state_dict, strict=False)
print("✅ Loaded parameters:")
for k in new_state_dict.keys():
if k not in missing:
print(f" {k} {tuple(new_state_dict[k].shape)}")
# Restore last global step
if trainer_state_path.exists():
with open(trainer_state_path, "r") as f:
trainer_state = json.load(f)
last_global_step = trainer_state.get("global_step", 0)
last_lr = trainer_state.get("learning_rate", trainer.args.learning_rate)
trainer.state.global_step = last_global_step
# Reset optimizer with last learning rate
trainer.create_optimizer_and_scheduler(num_training_steps=trainer.args.max_steps)
optimizer = trainer.optimizer
#lr_scheduler = trainer.lr_scheduler
for param_group in optimizer.param_groups:
param_group['lr'] = last_lr
trainer.optimizer = optimizer
print(f"✅ Restored global step: {last_global_step}, learning rate: {last_lr}")
'''
if list(pathlib.Path(training_args.output_dir).glob("checkpoint-*")):
trainer.train(resume_from_checkpoint=True)
else:
trainer.train()
trainer.save_state()
model.config.use_cache = True
safe_save_model_for_hf_trainer(
trainer=trainer,
output_dir=training_args.output_dir,
vision_tower=model_args.vision_tower,
)
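# Illustrative launch command (a sketch: the model name and paths below are placeholders, not values from this repo):
#   torchrun --nproc_per_node=8 train.py \
#       --model_name_or_path <hf_model_or_local_path> \
#       --image_folder /path/to/t2i_tar_shards \
#       --diffusion_name_or_path Efficient-Large-Model/Sana_600M_1024px_diffusers \
#       --output_dir ./checkpoints --bf16 True --model_max_length 512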
if __name__ == "__main__":
train()