# NOTE: the lines that stood here were web file-viewer residue, not source code:
# a "Spaces: Paused" status banner, "File size: 9,362 Bytes", revision "34c9227",
# and a line-number gutter running from 1 to 225.
import os
import glob
from dataclasses import dataclass, field
from typing import List, Literal, Optional
import safetensors
import torch
from transformers import TrainingArguments
########## Dataclasses for configuration ##########
@dataclass
class TrainingConfig(TrainingArguments):
    """``TrainingArguments`` extended with this project's extra options.

    Every field declared here has a default, so the class can be built with
    the same arguments as ``TrainingArguments``. How each option is consumed
    is decided by trainer code that is not part of this file; the comments
    below are therefore hedged and limited to what the names and defaults
    show.
    """

    # Data preprocessing options -- presumably a token-length cap and a worker
    # count for dataset mapping; confirm against the data pipeline.
    max_length: Optional[int] = None
    dataset_num_proc: Optional[int] = None

    # NOTE(review): looks like the weight of a reward-centering loss term.
    center_rewards_coefficient: Optional[float] = None

    disable_flash_attn2: bool = False

    # Optional per-component learning rates (names suggest vision tower,
    # merger and special tokens -- TODO confirm in the optimizer setup).
    vision_lr: Optional[float] = None
    merger_lr: Optional[float] = None
    special_token_lr: Optional[float] = None

    conduct_eval: Optional[bool] = True

    # Both default to None, so they are annotated Optional (previously bare
    # ``str`` / ``int``, which contradicted the default).
    load_from_pretrained: Optional[str] = None
    load_from_pretrained_step: Optional[int] = None

    # Epoch-based schedules -- presumably fractions/multiples of an epoch that
    # the training script converts to steps; verify against the caller.
    logging_epochs: Optional[float] = None
    eval_epochs: Optional[float] = None
    save_epochs: Optional[float] = None

    # Declared here with a default of False.
    remove_unused_columns: Optional[bool] = False

    save_full_model: Optional[bool] = False
@dataclass
class PEFTLoraConfig:
    """Options describing how LoRA adapters are configured.

    The two list-valued options ``lora_target_modules`` and
    ``lora_namespan_exclude`` are collapsed to their single element when they
    arrive as a one-item list (see ``__post_init__``), so after construction
    they may hold a plain value instead of a list.
    """

    lora_enable: bool = False
    vision_lora: bool = False
    lora_r: int = 16
    lora_alpha: int = 32
    lora_dropout: float = 0.05
    lora_target_modules: Optional[List[str]] = None
    lora_namespan_exclude: Optional[List[str]] = None
    lora_modules_to_save: Optional[List[str]] = None
    lora_task_type: str = "CAUSAL_LM"
    use_rslora: bool = False
    num_lora_modules: int = -1

    def __post_init__(self):
        """Replace a one-element list option with that sole element."""
        for option in ("lora_target_modules", "lora_namespan_exclude"):
            current = getattr(self, option)
            if isinstance(current, list) and len(current) == 1:
                setattr(self, option, current[0])
@dataclass
class ModelConfig:
    """Options selecting the base model and how it is loaded and trained.

    Raises:
        ValueError: from ``__post_init__`` when both ``load_in_8bit`` and
            ``load_in_4bit`` are requested.
    """

    model_name_or_path: Optional[str] = None
    model_revision: str = "main"
    output_dim: int = 1

    use_special_tokens: bool = False

    # Which parts of the model are frozen / tuned (consumed by the training
    # script, which is not part of this file).
    freeze_vision_tower: bool = False
    freeze_llm: bool = False
    tune_merger: bool = False

    torch_dtype: Optional[Literal["auto", "bfloat16", "float16", "float32"]] = None
    trust_remote_code: bool = False
    attn_implementation: Optional[str] = None

    # Quantized loading; the 8-bit and 4-bit switches are mutually exclusive.
    load_in_8bit: bool = False
    load_in_4bit: bool = False
    bnb_4bit_quant_type: Literal["fp4", "nf4"] = "nf4"
    use_bnb_nested_quant: bool = False

    reward_token: Literal["last", "mean", "special"] = "last"
    # "regular" is the default, so it is listed among the permitted choices;
    # previously the default was not one of the declared literals.
    # NOTE(review): confirm whether "regular" and "reg" are meant to be
    # distinct loss types in the consumer.
    loss_type: Literal[
        "regular", "bt", "reg", "btt", "margin", "constant_margin", "scaled"
    ] = "regular"

    def __post_init__(self):
        """Reject a configuration that asks for 8-bit and 4-bit loading together."""
        if self.load_in_8bit and self.load_in_4bit:
            raise ValueError("You can't use 8 bit and 4 bit precision at the same time")
########## Functions for collecting trainable modules' parameters ##########
def maybe_zero_3(param, ignore_status=False, name=None):
    """Return a detached CPU clone of ``param``, gathering it first if needed.

    A parameter carrying a ``ds_id`` attribute is treated as DeepSpeed-managed
    and is cloned inside ``zero.GatheredParameters``. Any other tensor is
    cloned directly, without importing DeepSpeed at all.

    Args:
        param: Tensor or parameter to copy.
        ignore_status: When False, log a warning if the DeepSpeed status of
            the parameter is ``ZeroParamStatus.NOT_AVAILABLE``.
        name: Label used only in that warning message.

    Returns:
        A detached clone of the parameter data on the CPU.
    """
    if not hasattr(param, "ds_id"):
        return param.detach().cpu().clone()

    # Imported lazily: only the ZeRO path needs DeepSpeed, and ``logging`` is
    # not imported at module level in this file.
    import logging

    from deepspeed import zero
    from deepspeed.runtime.zero.partition_parameters import ZeroParamStatus

    if param.ds_status == ZeroParamStatus.NOT_AVAILABLE and not ignore_status:
        logging.warning(f"{name}: param.ds_status != ZeroParamStatus.NOT_AVAILABLE: {param.ds_status}")
    with zero.GatheredParameters([param]):
        param = param.data.detach().cpu().clone()
    return param
# Borrowed from peft.utils.get_peft_model_state_dict
def get_peft_state_maybe_zero_3(named_params, bias):
    """Collect LoRA (and optionally bias) parameters as CPU copies.

    Args:
        named_params: Iterable of ``(name, parameter)`` pairs; it is consumed
            in a single pass, so a generator is fine.
        bias: Which bias parameters accompany the LoRA weights:
            ``"none"`` keeps only names containing ``"lora_"``;
            ``"all"`` also keeps every name containing ``"bias"``;
            ``"lora_only"`` keeps a bias only when its name equals the part of
            a LoRA parameter name before ``"lora_"`` followed by ``"bias"``.

    Returns:
        Dict mapping each selected name to ``maybe_zero_3(param,
        ignore_status=True)``.

    Raises:
        NotImplementedError: for any other ``bias`` value.
    """
    if bias == "none":
        to_return = {k: t for k, t in named_params if "lora_" in k}
    elif bias == "all":
        to_return = {k: t for k, t in named_params if "lora_" in k or "bias" in k}
    elif bias == "lora_only":
        to_return = {}
        maybe_lora_bias = {}
        lora_bias_names = set()
        for k, t in named_params:
            if "lora_" in k:
                to_return[k] = t
                lora_bias_names.add(k.split("lora_")[0] + "bias")
            elif "bias" in k:
                maybe_lora_bias[k] = t
        # Match every candidate bias by its own name. The earlier loop
        # iterated the dict without ``.items()`` and tested a stale
        # ``bias_name`` left over from the loop above.
        for k, t in maybe_lora_bias.items():
            if k in lora_bias_names:
                to_return[k] = t
    else:
        raise NotImplementedError
    return {k: maybe_zero_3(v, ignore_status=True) for k, v in to_return.items()}
def get_peft_state_non_lora_maybe_zero_3(named_params, require_grad_only=True):
    """Collect the non-LoRA parameters as CPU copies.

    Args:
        named_params: Iterable of ``(name, parameter)`` pairs.
        require_grad_only: When True, keep only parameters whose
            ``requires_grad`` is set.

    Returns:
        Dict mapping every kept name (one that does not contain ``"lora_"``)
        to ``maybe_zero_3(param, ignore_status=True).cpu()``.
    """
    selected = {}
    for param_name, tensor in named_params:
        if "lora_" not in param_name:
            selected[param_name] = tensor

    if require_grad_only:
        selected = {
            param_name: tensor
            for param_name, tensor in selected.items()
            if tensor.requires_grad
        }

    gathered = {}
    for param_name, tensor in selected.items():
        gathered[param_name] = maybe_zero_3(tensor, ignore_status=True).cpu()
    return gathered
########## Load Models From Folder ##########
def _insert_adapter_name_into_state_dict(
    state_dict: dict[str, torch.Tensor], adapter_name: str, parameter_prefix: str
) -> dict[str, torch.Tensor]:
    """Return a copy of ``state_dict`` whose adapter keys carry ``adapter_name``.

    Keys that do not contain ``parameter_prefix`` are copied unchanged. For the
    others, the text following the first ``parameter_prefix`` is inspected:
    if it contains a dot, everything after its first dot is replaced in the
    key by ``"<adapter_name>.<that text>"``; otherwise ``".<adapter_name>"``
    is appended to the key.
    """
    renamed = {}
    for original_key, tensor in state_dict.items():
        if parameter_prefix not in original_key:
            renamed[original_key] = tensor
            continue

        after_prefix = original_key.split(parameter_prefix)[1]
        if "." in after_prefix:
            # Text after the first dot, e.g. "weight" for "A.weight".
            tail = after_prefix.split(".", 1)[1]
            new_key = original_key.replace(tail, f"{adapter_name}.{tail}")
        else:
            new_key = f"{original_key}.{adapter_name}"
        renamed[new_key] = tensor
    return renamed
def save_video(tensor, path):
    """Write ``tensor`` to ``path`` as an H.264 video at 4 frames per second.

    The values are multiplied by 255, dimensions are reordered with
    ``permute(0, 2, 3, 1)``, and the result is clamped to ``[0, 255]`` and
    cast to bytes before being handed to ``torchvision.io.write_video``.

    Args:
        tensor: Video frames -- assumes a 4-D tensor laid out as
            (frames, channels, height, width) with values in [0, 1];
            TODO confirm against the caller.
        path: Destination file name.
    """
    from torchvision.io import write_video

    frames = (tensor * 255.0).permute(0, 2, 3, 1).clamp(0, 255).byte()
    write_video(path, frames, 4, video_codec="h264")
def _list_checkpoints_newest_first(checkpoint_dir):
    """Return the ``checkpoint-<int>`` paths in ``checkpoint_dir``, highest step first."""
    numbered = []
    for path in glob.glob(os.path.join(checkpoint_dir, "checkpoint-*")):
        try:
            step = int(path.split("-")[-1])
        except ValueError:
            # Not a numbered checkpoint (e.g. "checkpoint-tmp"); skip it
            # rather than failing the whole lookup.
            continue
        numbered.append((step, path))
    numbered.sort(key=lambda item: item[0], reverse=True)
    return [path for _, path in numbered]


def _remap_full_checkpoint_keys(state_dict):
    """Rename full-checkpoint keys to the ``model.language_model`` / ``model.visual`` layout."""
    llm_prefix = "base_model.model.model"
    visual_prefix = "base_model.model.visual"
    remapped = {}
    for key, value in state_dict.items():
        if key.startswith(llm_prefix):
            key = "base_model.model.model.language_model" + key[len(llm_prefix):]
        elif key.startswith(visual_prefix):
            key = "base_model.model.model.visual" + key[len(visual_prefix):]
        remapped[key] = value
    return remapped


def load_model_from_checkpoint(
    model, checkpoint_dir, checkpoint_step
):
    """Load weights from a ``checkpoint-<step>`` folder into ``model``.

    The checkpoint is chosen as follows: when ``checkpoint_step`` is ``None``
    or ``-1`` the highest-numbered checkpoint is used; otherwise
    ``checkpoint-<checkpoint_step>`` is used if present, falling back to the
    highest-numbered one if it is not.

    Inside the chosen folder, ``model.pth`` (a full state dict) takes
    precedence; its keys are renamed by ``_remap_full_checkpoint_keys`` before
    loading. Without it, ``adapter_model.safetensors`` and
    ``non_lora_state_dict.pth`` are merged on top of the model's current state
    dict, with the adapter name ``"default"`` inserted into the LoRA keys.

    Args:
        model: Object exposing ``load_state_dict`` (and ``state_dict`` for the
            LoRA path).
        checkpoint_dir: Directory containing the ``checkpoint-<step>`` folders.
        checkpoint_step: Requested step, or ``None`` / ``-1`` for the latest.

    Returns:
        ``(model, checkpoint_step)`` where ``checkpoint_step`` is the step of
        the checkpoint actually used, as a string.

    Raises:
        FileNotFoundError: if ``checkpoint_dir`` holds no numbered checkpoint.
    """
    checkpoint_paths = _list_checkpoints_newest_first(checkpoint_dir)
    if not checkpoint_paths:
        raise FileNotFoundError(
            f"No 'checkpoint-<step>' directories found in {checkpoint_dir!r}"
        )

    if checkpoint_step is None or checkpoint_step == -1:
        # get the latest checkpoint
        checkpoint_path = checkpoint_paths[0]
        print(f"===> Checkpoint step is not provided, using the latest checkpoint: {checkpoint_path}")
    else:
        checkpoint_path = os.path.join(checkpoint_dir, f"checkpoint-{checkpoint_step}")
        if checkpoint_path not in checkpoint_paths:
            checkpoint_path = checkpoint_paths[0]
            print(f"===> Checkpoint step {checkpoint_step} not found, using the latest checkpoint: {checkpoint_path}")
        else:
            print(f"===> Checkpoint step {checkpoint_step} found, using the specified checkpoint: {checkpoint_path}")

    # Report the step that was really loaded, which may differ from the request.
    checkpoint_step = checkpoint_path.split("checkpoint-")[-1].split("/")[0]

    full_ckpt = os.path.join(checkpoint_path, "model.pth")
    lora_ckpt = os.path.join(checkpoint_path, "adapter_model.safetensors")
    non_lora_ckpt = os.path.join(checkpoint_path, "non_lora_state_dict.pth")

    if os.path.exists(full_ckpt):
        model_state_dict = torch.load(full_ckpt, map_location="cpu", weights_only=True)
        # NOTE(review): the renaming presumably adapts checkpoints saved with
        # an older module layout -- confirm against the model definition.
        model.load_state_dict(_remap_full_checkpoint_keys(model_state_dict))
    else:
        # Import the submodule explicitly: ``import safetensors`` alone does
        # not guarantee that ``safetensors.torch`` has been loaded.
        from safetensors.torch import load_file

        lora_state_dict = load_file(lora_ckpt)
        # weights_only=True matches the full-checkpoint load above and avoids
        # unpickling arbitrary objects from the file.
        non_lora_state_dict = torch.load(non_lora_ckpt, map_location="cpu", weights_only=True)

        lora_state_dict = _insert_adapter_name_into_state_dict(
            lora_state_dict, adapter_name="default", parameter_prefix="lora_"
        )

        model_state_dict = model.state_dict()
        model_state_dict.update(non_lora_state_dict)
        model_state_dict.update(lora_state_dict)
        model.load_state_dict(model_state_dict)

    return model, checkpoint_step