import glob
import logging
import os
from dataclasses import dataclass, field
from typing import List, Literal, Optional

import safetensors.torch
import torch
from transformers import TrainingArguments

########## Dataclasses for Configuration ##########
@dataclass
class TrainingConfig(TrainingArguments):
max_length: Optional[int] = None
dataset_num_proc: Optional[int] = None
center_rewards_coefficient: Optional[float] = None
disable_flash_attn2: bool = field(default=False)
vision_lr: Optional[float] = None
merger_lr: Optional[float] = None
special_token_lr: Optional[float] = None
conduct_eval: Optional[bool] = True
    load_from_pretrained: Optional[str] = None
    load_from_pretrained_step: Optional[int] = None
logging_epochs: Optional[float] = None
eval_epochs: Optional[float] = None
save_epochs: Optional[float] = None
remove_unused_columns: Optional[bool] = False
save_full_model: Optional[bool] = False
@dataclass
class PEFTLoraConfig:
lora_enable: bool = False
vision_lora: bool = False
lora_r: int = 16
lora_alpha: int = 32
lora_dropout: float = 0.05
lora_target_modules: Optional[List[str]] = None
lora_namespan_exclude: Optional[List[str]] = None
lora_modules_to_save: Optional[List[str]] = None
lora_task_type: str = "CAUSAL_LM"
use_rslora: bool = False
num_lora_modules: int = -1
    def __post_init__(self):
        # Unwrap single-element lists into plain strings for downstream use.
        if isinstance(self.lora_target_modules, list) and len(self.lora_target_modules) == 1:
            self.lora_target_modules = self.lora_target_modules[0]
        if isinstance(self.lora_namespan_exclude, list) and len(self.lora_namespan_exclude) == 1:
            self.lora_namespan_exclude = self.lora_namespan_exclude[0]
@dataclass
class ModelConfig:
model_name_or_path: Optional[str] = None
model_revision: str = "main"
output_dim: int = 1
use_special_tokens: bool = False
freeze_vision_tower: bool = field(default=False)
freeze_llm: bool = field(default=False)
tune_merger: bool = field(default=False)
torch_dtype: Optional[Literal["auto", "bfloat16", "float16", "float32"]] = None
trust_remote_code: bool = False
attn_implementation: Optional[str] = None
load_in_8bit: bool = False
load_in_4bit: bool = False
bnb_4bit_quant_type: Literal["fp4", "nf4"] = "nf4"
use_bnb_nested_quant: bool = False
reward_token: Literal["last", "mean", "special"] = "last"
    loss_type: Literal["bt", "reg", "btt", "margin", "constant_margin", "scaled"] = "reg"
def __post_init__(self):
if self.load_in_8bit and self.load_in_4bit:
raise ValueError("You can't use 8 bit and 4 bit precision at the same time")
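
# The dataclasses above are typically filled from the command line. A minimal
# sketch (not part of the original file) using transformers.HfArgumentParser,
# whose flags are derived from the dataclass field names (e.g.
# --model_name_or_path, --lora_enable). The helper name is illustrative only.
def _example_parse_configs():
    from transformers import HfArgumentParser
    parser = HfArgumentParser((ModelConfig, PEFTLoraConfig, TrainingConfig))
    model_config, lora_config, training_config = parser.parse_args_into_dataclasses()
    return model_config, lora_config, training_config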
########## Functions for Gathering Trainable Module Parameters ##########
def maybe_zero_3(param, ignore_status=False, name=None):
    # Gather a (possibly ZeRO-3 partitioned) parameter and return a CPU copy.
    from deepspeed import zero
    from deepspeed.runtime.zero.partition_parameters import ZeroParamStatus
    if hasattr(param, "ds_id"):
        if param.ds_status == ZeroParamStatus.NOT_AVAILABLE:
            if not ignore_status:
                logging.warning(f"{name}: param.ds_status == ZeroParamStatus.NOT_AVAILABLE: {param.ds_status}")
        with zero.GatheredParameters([param]):
            param = param.data.detach().cpu().clone()
    else:
        param = param.detach().cpu().clone()
    return param
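
# Usage sketch: under DeepSpeed ZeRO-3, parameters are partitioned across ranks
# and must be gathered (as above) before a full CPU copy can be saved, e.g.
#   cpu_param = maybe_zero_3(param, ignore_status=True, name=param_name)
# where `param_name` is whatever identifier you want echoed in warnings.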
# Borrowed from peft.utils.get_peft_model_state_dict
def get_peft_state_maybe_zero_3(named_params, bias):
if bias == "none":
to_return = {k: t for k, t in named_params if "lora_" in k}
elif bias == "all":
to_return = {k: t for k, t in named_params if "lora_" in k or "bias" in k}
elif bias == "lora_only":
to_return = {}
maybe_lora_bias = {}
lora_bias_names = set()
for k, t in named_params:
if "lora_" in k:
to_return[k] = t
bias_name = k.split("lora_")[0] + "bias"
lora_bias_names.add(bias_name)
elif "bias" in k:
maybe_lora_bias[k] = t
        # Keep only the biases that belong to LoRA-adapted modules.
        for k, t in maybe_lora_bias.items():
            if k in lora_bias_names:
                to_return[k] = t
else:
raise NotImplementedError
to_return = {k: maybe_zero_3(v, ignore_status=True) for k, v in to_return.items()}
return to_return
def get_peft_state_non_lora_maybe_zero_3(named_params, require_grad_only=True):
to_return = {k: t for k, t in named_params if "lora_" not in k}
if require_grad_only:
to_return = {k: t for k, t in to_return.items() if t.requires_grad}
to_return = {k: maybe_zero_3(v, ignore_status=True).cpu() for k, v in to_return.items()}
return to_return
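
# Usage sketch (assumes `model` is a peft.PeftModel trained with LoRA bias="none");
# this is one way the checkpoint files consumed by load_model_from_checkpoint
# below could be produced, not necessarily the original trainer code:
#   lora_state_dict = get_peft_state_maybe_zero_3(model.named_parameters(), bias="none")
#   non_lora_state_dict = get_peft_state_non_lora_maybe_zero_3(model.named_parameters())
#   safetensors.torch.save_file(lora_state_dict, "adapter_model.safetensors")
#   torch.save(non_lora_state_dict, "non_lora_state_dict.pth")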
########## Load Models from Checkpoint Folders ##########
def _insert_adapter_name_into_state_dict(
    state_dict: dict[str, torch.Tensor], adapter_name: str, parameter_prefix: str
) -> dict[str, torch.Tensor]:
    """Utility function to remap the state_dict keys to fit the PEFT model by inserting the adapter name."""
    peft_model_state_dict = {}
    for key, val in state_dict.items():
        if parameter_prefix in key:
            suffix = key.split(parameter_prefix)[1]
            if "." in suffix:
                # Insert the adapter name before the trailing parameter name,
                # e.g. "...lora_A.weight" -> "...lora_A.<adapter_name>.weight".
                suffix_to_replace = ".".join(suffix.split(".")[1:])
                key = key.replace(suffix_to_replace, f"{adapter_name}.{suffix_to_replace}")
            else:
                key = f"{key}.{adapter_name}"
        peft_model_state_dict[key] = val
    return peft_model_state_dict
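
# Worked example of the remapping for adapter_name="default",
# parameter_prefix="lora_":
#   "base_model.model.q_proj.lora_A.weight"
#     -> "base_model.model.q_proj.lora_A.default.weight"
# matching the key layout PEFT uses once an adapter is registered under a name.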
def save_video(tensor, path):
    from torchvision.io import write_video
    # Expects a float tensor in [0, 1] of shape (T, C, H, W).
    tensor = tensor * 255.0
    tensor = tensor.permute(0, 2, 3, 1)  # (T, C, H, W) -> (T, H, W, C), as write_video expects
    tensor = tensor.clamp(0, 255).byte()
    write_video(path, tensor, 4, video_codec='h264')
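
# Usage sketch: save_video expects a float tensor in [0, 1] of shape
# (num_frames, channels, height, width) and writes at a fixed 4 fps, e.g.
#   clip = torch.rand(16, 3, 224, 224)  # hypothetical 16-frame RGB clip
#   save_video(clip, "sample.mp4")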
def load_model_from_checkpoint(
model, checkpoint_dir, checkpoint_step
):
checkpoint_paths = glob.glob(os.path.join(checkpoint_dir, "checkpoint-*"))
checkpoint_paths.sort(key=lambda x: int(x.split("-")[-1]), reverse=True)
if checkpoint_step is None or checkpoint_step == -1:
# get the latest checkpoint
checkpoint_path = checkpoint_paths[0]
print(f"===> Checkpoint step is not provided, using the latest checkpoint: {checkpoint_path}")
else:
checkpoint_path = os.path.join(checkpoint_dir, f"checkpoint-{checkpoint_step}")
if checkpoint_path not in checkpoint_paths:
checkpoint_path = checkpoint_paths[0]
print(f"===> Checkpoint step {checkpoint_step} not found, using the latest checkpoint: {checkpoint_path}")
else:
print(f"===> Checkpoint step {checkpoint_step} found, using the specified checkpoint: {checkpoint_path}")
    # Recover the step number from the resolved path (as an int, matching the parameter type).
    checkpoint_step = int(checkpoint_path.split("checkpoint-")[-1].split("/")[0])
full_ckpt = os.path.join(checkpoint_path, "model.pth")
lora_ckpt = os.path.join(checkpoint_path, "adapter_model.safetensors")
non_lora_ckpt = os.path.join(checkpoint_path, "non_lora_state_dict.pth")
if os.path.exists(full_ckpt):
model_state_dict = torch.load(full_ckpt, map_location="cpu", weights_only=True)
        # Remap keys saved under an older module layout: the language model now
        # lives under `.model.language_model` and the vision tower under
        # `.model.visual`. The renamed entries go into a fresh state_dict.
        new_state_dict = {}
for key, value in model_state_dict.items():
if key.startswith("base_model.model.model"):
new_key = "base_model.model.model.language_model" + key[len("base_model.model.model"):]
new_state_dict[new_key] = value
elif key.startswith("base_model.model.visual"):
new_key = "base_model.model.model.visual" + key[len("base_model.model.visual"):]
new_state_dict[new_key] = value
else:
new_state_dict[key] = value
# Load the modified state_dict into the model
model.load_state_dict(new_state_dict)
    else:
        # LoRA checkpoint: load the adapter weights and the separately saved
        # non-LoRA trainable weights, then merge both into the current state.
        lora_state_dict = safetensors.torch.load_file(lora_ckpt)
        non_lora_state_dict = torch.load(non_lora_ckpt, map_location="cpu")
        lora_state_dict = _insert_adapter_name_into_state_dict(lora_state_dict, adapter_name="default", parameter_prefix="lora_")
model_state_dict = model.state_dict()
model_state_dict.update(non_lora_state_dict)
model_state_dict.update(lora_state_dict)
model.load_state_dict(model_state_dict)
return model, checkpoint_step
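
# Usage sketch (directory name is hypothetical; `model` is assumed to be the same
# PEFT-wrapped model the checkpoints were saved from):
#   model, step = load_model_from_checkpoint(model, "outputs/run1", checkpoint_step=None)
#   print(f"Resumed from checkpoint step {step}")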