import os
import torch
from peft import LoraConfig, get_peft_model
import ast
from transformers import AutoProcessor, BitsAndBytesConfig, Qwen2VLForConditionalGeneration, HfArgumentParser, Qwen2_5_VLForConditionalGeneration
from training.trainer import QwenTrainer
from training.data import make_supervised_data_module
from training.params import DataArguments, ModelArguments, TrainingArguments
from training.train_utils import get_peft_state_maybe_zero_3, get_peft_state_non_lora_maybe_zero_3, safe_save_model_for_hf_trainer
import pathlib
from liger_kernel.transformers import apply_liger_kernel_to_qwen2_vl, apply_liger_kernel_to_qwen2_5_vl
from monkey_patch_forward import replace_qwen2_5_with_mixed_modality_forward, replace_qwen_2_with_mixed_modality_forward
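# Fine-tunes Qwen2-VL / Qwen2.5-VL on mixed-modality supervised data, with
# optional LoRA adapters, 4/8-bit (QLoRA-style) quantization, and Liger kernels.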
local_rank = None
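# Print only from the main process; local_rank may be an int, the string "0"
# (when read from the environment), or None in single-process runs.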
def rank0_print(*args):
if local_rank == 0 or local_rank == '0' or local_rank is None:
print(*args)
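# Collect the qualified names of Linear/Embedding submodules to target with LoRA,
# skipping any whose name contains an excluded keyword. On Qwen2-VL this typically
# yields names like "model.layers.0.self_attn.q_proj"; excluding "visual" keeps
# the vision tower out of the target set.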
def find_target_linear_names(model, num_lora_modules=-1, lora_namespan_exclude=None, verbose=True):
    # Avoid a mutable default argument for the exclusion list.
    lora_namespan_exclude = lora_namespan_exclude or []
    linear_cls = torch.nn.modules.Linear
    embedding_cls = torch.nn.modules.Embedding
    lora_module_names = []
for name, module in model.named_modules():
if any(ex_keyword in name for ex_keyword in lora_namespan_exclude):
continue
if isinstance(module, (linear_cls, embedding_cls)):
lora_module_names.append(name)
if num_lora_modules > 0:
lora_module_names = lora_module_names[-num_lora_modules:]
if verbose:
rank0_print(f"Found {len(lora_module_names)} lora modules: {lora_module_names}")
return lora_module_names
def set_requires_grad(parameters, requires_grad):
for p in parameters:
p.requires_grad = requires_grad
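# Move the vision tower to the training dtype/device and set gradient flags:
# the tower follows freeze_vision_tower, while the merger follows tune_merger.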
def configure_vision_tower(model, training_args, compute_dtype, device):
vision_tower = model.visual
vision_tower.to(dtype=compute_dtype, device=device)
vision_model_params = model.visual.parameters()
set_requires_grad(vision_model_params, not training_args.freeze_vision_tower)
    # The merger (the vision-to-LLM projector) is toggled separately via
    # tune_merger, so it can train even when the rest of the tower is frozen.
merger_params = model.visual.merger.parameters()
set_requires_grad(merger_params, training_args.tune_merger)
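# Toggle gradients for the language model and lm_head together via freeze_llm.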
def configure_llm(model, training_args):
lm_head = model.lm_head.parameters()
set_requires_grad(lm_head, not training_args.freeze_llm)
llm_params = model.model.parameters()
set_requires_grad(llm_params, not training_args.freeze_llm)
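# End-to-end entry point: parse arguments, patch and load the model, build the
# data module, and run the Trainer loop.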
def train():
global local_rank
parser = HfArgumentParser(
(ModelArguments, DataArguments, TrainingArguments))
model_args, data_args, training_args = parser.parse_args_into_dataclasses()
use_liger = training_args.use_liger
if "Qwen2.5" in model_args.model_id:
# It monkey patches the forward to handle mixed modality inputs.
replace_qwen2_5_with_mixed_modality_forward(use_liger=use_liger)
# This is becuase mixed-modality training monkey-patches the model forward method.
if use_liger:
apply_liger_kernel_to_qwen2_5_vl(fused_linear_cross_entropy=False)
else:
# It monkey patches the forward to handle mixed modality inputs.
replace_qwen_2_with_mixed_modality_forward(use_liger=use_liger)
# This is becuase mixed-modality training monkey-patches the model forward method.
if use_liger:
apply_liger_kernel_to_qwen2_vl(fused_linear_cross_entropy=False)
    if training_args.lora_enable and not training_args.freeze_llm:
        raise ValueError("If `lora_enable` is True, `freeze_llm` must also be True.")

    if not training_args.lora_enable:
        assert not training_args.vision_lora, \
            "Error: training_args.lora_enable is not enabled, but training_args.vision_lora is enabled."
    else:
        if training_args.vision_lora and not training_args.freeze_vision_tower:
            raise ValueError("If `vision_lora` is True, `freeze_vision_tower` must also be True.")
        # lora_namespan_exclude arrives as a string (e.g. "['lm_head', 'embed_tokens']"),
        # so parse it into a real list.
        if training_args.lora_namespan_exclude is not None:
            training_args.lora_namespan_exclude = ast.literal_eval(training_args.lora_namespan_exclude)
        else:
            training_args.lora_namespan_exclude = []
        # Without vision LoRA, keep the vision tower out of the LoRA target modules.
        if not training_args.vision_lora:
            training_args.lora_namespan_exclude += ["visual"]
local_rank = training_args.local_rank
compute_dtype = (torch.float16 if training_args.fp16 else (torch.bfloat16 if training_args.bf16 else torch.float32))
bnb_model_from_pretrained_args = {}
    if training_args.bits in [4, 8]:
        bnb_model_from_pretrained_args.update(dict(
            device_map={"": training_args.device},
            quantization_config=BitsAndBytesConfig(
                load_in_4bit=training_args.bits == 4,
                load_in_8bit=training_args.bits == 8,
                # Keep the vision tower un-quantized.
                llm_int8_skip_modules=["visual"],
                llm_int8_threshold=6.0,
                llm_int8_has_fp16_weight=False,
                bnb_4bit_compute_dtype=compute_dtype,
                bnb_4bit_use_double_quant=training_args.double_quant,
                bnb_4bit_quant_type=training_args.quant_type,
            )
        ))
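    # Load the model class matching the checkpoint; fall back to SDPA attention
    # when FlashAttention-2 is disabled.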
if "Qwen2.5" in model_args.model_id:
model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
model_args.model_id,
torch_dtype=compute_dtype,
attn_implementation="flash_attention_2" if not training_args.disable_flash_attn2 else "sdpa",
**bnb_model_from_pretrained_args
)
else:
model = Qwen2VLForConditionalGeneration.from_pretrained(
model_args.model_id,
torch_dtype=compute_dtype,
attn_implementation="flash_attention_2" if not training_args.disable_flash_attn2 else "sdpa",
**bnb_model_from_pretrained_args
)
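    # The KV cache is only useful for generation; disable it for training
    # (it also conflicts with gradient checkpointing).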
model.config.use_cache = False
model_to_configure = model
configure_llm(model_to_configure, training_args)
configure_vision_tower(model_to_configure, training_args, compute_dtype, training_args.device)
    if training_args.bits in [4, 8]:
        # Keep the config dtype consistent with the dtype the weights were loaded in.
        model.config.torch_dtype = compute_dtype
        from peft import prepare_model_for_kbit_training
        model = prepare_model_for_kbit_training(
            model,
            use_gradient_checkpointing=training_args.gradient_checkpointing,
            gradient_checkpointing_kwargs={"use_reentrant": True},
        )

    if training_args.gradient_checkpointing:
        model.enable_input_require_grads()
        training_args.gradient_checkpointing_kwargs = {"use_reentrant": True}
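    # Build the LoRA config over the discovered target modules and wrap the model.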
if training_args.lora_enable:
lora_namespan_exclude = training_args.lora_namespan_exclude
peft_config = LoraConfig(
r=training_args.lora_rank,
lora_alpha=training_args.lora_alpha,
target_modules=find_target_linear_names(model, lora_namespan_exclude=lora_namespan_exclude, num_lora_modules=training_args.num_lora_modules),
lora_dropout=training_args.lora_dropout,
bias=training_args.lora_bias
)
if training_args.bits == 16:
if training_args.bf16:
model.to(torch.bfloat16)
if training_args.fp16:
model.to(torch.float16)
rank0_print("Adding LoRA to the model...")
model = get_peft_model(model, peft_config)
    processor = AutoProcessor.from_pretrained(
        model_args.model_id,
        # The tokenizer defaults to padding_side="left" (the generation-friendly
        # setting); right-side padding is more efficient for training.
        padding_side="right",
    )
# model.config.tokenizer_model_max_length = processor.tokenizer.model_max_length
model.config.tokenizer_padding_side = processor.tokenizer.padding_side
model.config.vision_lr = training_args.vision_lr
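    # k-bit housekeeping: keep LoRA layers in bf16, norms in fp32 for stability,
    # and cast lm_head/embedding weights back to bf16 if they ended up in fp32.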
if training_args.bits in [4, 8]:
from peft.tuners.lora import LoraLayer
for name, module in model.named_modules():
if isinstance(module, LoraLayer):
if training_args.bf16:
module = module.to(torch.bfloat16)
if 'norm' in name:
module = module.to(torch.float32)
if 'lm_head' in name or 'embed_token' in name:
if hasattr(module, 'weight'):
if training_args.bf16 and module.weight.dtype == torch.float32:
module = module.to(torch.bfloat16)
data_module = make_supervised_data_module(model_id=model_args.model_id,
processor=processor,
data_args=data_args)
trainer = QwenTrainer(
model=model,
processor=processor,
args=training_args,
**data_module
)
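    # Resume automatically if the output directory already holds a checkpoint.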
if list(pathlib.Path(training_args.output_dir).glob("checkpoint-*")):
trainer.train(resume_from_checkpoint=True)
else:
trainer.train()
trainer.save_state()
model.config.use_cache = True
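    # With LoRA, save adapter weights and the remaining trainable (non-LoRA)
    # weights separately; otherwise defer to the standard HF Trainer save path.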
if training_args.lora_enable:
state_dict = get_peft_state_maybe_zero_3(
model.named_parameters(), training_args.lora_bias
)
non_lora_state_dict = get_peft_state_non_lora_maybe_zero_3(
model.named_parameters(), require_grad_only=False
)
if local_rank == 0 or local_rank == -1:
model.config.save_pretrained(training_args.output_dir)
model.save_pretrained(training_args.output_dir, state_dict=state_dict)
torch.save(non_lora_state_dict, os.path.join(training_args.output_dir, "non_lora_state_dict.bin"))
else:
safe_save_model_for_hf_trainer(trainer, output_dir=training_args.output_dir)
if __name__ == "__main__":
train() |
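
# Example launch (a sketch, not the authoritative command: the real flag set
# lives in training/params.py, and the model id, paths, and world size below
# are placeholders):
#
#   torchrun --nproc_per_node=2 train.py \
#       --model_id Qwen/Qwen2.5-VL-7B-Instruct \
#       --data_path /path/to/train.json \
#       --image_folder /path/to/images \
#       --output_dir output/qwen2_5_vl_lora \
#       --lora_enable True --freeze_llm True \
#       --vision_lora False --freeze_vision_tower True \
#       --bf16 True --use_liger True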