# Source: uploaded by litwell to the Hugging Face Hub as
# models/src/training/train.py via huggingface_hub (commit 6f287f0, verified).
import os
import torch
from peft import LoraConfig, get_peft_model
import ast
from transformers import AutoProcessor, BitsAndBytesConfig, Qwen2VLForConditionalGeneration, HfArgumentParser, Qwen2_5_VLForConditionalGeneration
from training.trainer import QwenTrainer
from training.data import make_supervised_data_module
from training.params import DataArguments, ModelArguments, TrainingArguments
from training.train_utils import get_peft_state_maybe_zero_3, get_peft_state_non_lora_maybe_zero_3, safe_save_model_for_hf_trainer
import pathlib
from liger_kernel.transformers import apply_liger_kernel_to_qwen2_vl, apply_liger_kernel_to_qwen2_5_vl
from monkey_patch_forward import replace_qwen2_5_with_mixed_modality_forward, replace_qwen_2_with_mixed_modality_forward
# Local rank of this process; stays None until train() copies
# training_args.local_rank into it.
local_rank = None


def rank0_print(*args):
    """Print ``args`` only from the main (local rank 0) process.

    Output is emitted when ``local_rank`` is unset (``None``), is ``0``, or is
    ``-1``. A non-distributed launch reports ``-1``, and that process is the main
    one; previously it was silenced. Ranks may arrive as ``int`` or ``str``, so
    they are compared in string form.
    """
    if local_rank is None or str(local_rank) in ("0", "-1"):
        print(*args)
def find_target_linear_names(model, num_lora_modules=-1, lora_namespan_exclude=None, verbose=True):
    """Collect the names of ``Linear`` and ``Embedding`` submodules as LoRA targets.

    Args:
        model: Module whose ``named_modules()`` are scanned.
        num_lora_modules: If > 0, keep only the last ``num_lora_modules`` names
            (in ``named_modules()`` order); otherwise keep them all.
        lora_namespan_exclude: Iterable of substrings; a module whose qualified
            name contains any of them is skipped. ``None`` means no exclusions.
            This parameter used to default to a shared mutable list.
        verbose: If True, report the selected names via ``rank0_print``.

    Returns:
        List of fully-qualified module names.
    """
    excluded = list(lora_namespan_exclude or [])
    target_types = (torch.nn.Linear, torch.nn.Embedding)
    lora_module_names = [
        name
        for name, module in model.named_modules()
        if isinstance(module, target_types)
        and not any(keyword in name for keyword in excluded)
    ]
    if num_lora_modules > 0:
        lora_module_names = lora_module_names[-num_lora_modules:]
    if verbose:
        rank0_print(f"Found {len(lora_module_names)} lora modules: {lora_module_names}")
    return lora_module_names
def set_requires_grad(parameters, requires_grad):
    """Assign ``requires_grad`` on every tensor yielded by ``parameters``.

    ``parameters`` may be any iterable, including the one-shot generator
    returned by ``Module.parameters()``.
    """
    flag = requires_grad
    for tensor in parameters:
        tensor.requires_grad = flag
def configure_vision_tower(model, training_args, compute_dtype, device):
    """Move the vision tower to ``compute_dtype``/``device`` and set its trainability.

    The whole tower is trainable unless ``training_args.freeze_vision_tower``.
    The merger submodule is then set from ``training_args.tune_merger``,
    overriding the tower-wide setting for its own parameters.
    """
    visual = model.visual
    visual.to(dtype=compute_dtype, device=device)
    set_requires_grad(visual.parameters(), not training_args.freeze_vision_tower)
    # Order matters: the merger flag must be applied after the tower-wide flag.
    set_requires_grad(visual.merger.parameters(), training_args.tune_merger)
def configure_llm(model, training_args):
    """Make the language head and ``model.model`` trainable unless ``freeze_llm`` is set."""
    trainable = not training_args.freeze_llm
    for params in (model.lm_head.parameters(), model.model.parameters()):
        set_requires_grad(params, trainable)
def _apply_forward_patches(is_qwen2_5, use_liger):
    """Install the mixed-modality forward patch and, optionally, the Liger kernels.

    Args:
        is_qwen2_5: True to patch the Qwen2.5-VL classes, False for Qwen2-VL.
        use_liger: Whether to also apply the Liger kernels.
    """
    # The mixed-modality patch replaces the model's forward method. That is why
    # Liger's fused linear cross-entropy is turned off.
    if is_qwen2_5:
        replace_qwen2_5_with_mixed_modality_forward(use_liger=use_liger)
        if use_liger:
            apply_liger_kernel_to_qwen2_5_vl(fused_linear_cross_entropy=False)
    else:
        replace_qwen_2_with_mixed_modality_forward(use_liger=use_liger)
        if use_liger:
            apply_liger_kernel_to_qwen2_vl(fused_linear_cross_entropy=False)


def _validate_and_prepare_lora_args(training_args):
    """Check LoRA/freeze flag combinations and normalise ``lora_namespan_exclude``.

    On return, ``training_args.lora_namespan_exclude`` holds the value parsed
    with ``ast.literal_eval`` (or ``[]``). ``"visual"`` is appended when
    ``vision_lora`` is off.

    Raises:
        ValueError: on an inconsistent combination of LoRA / freeze flags.
    """
    if training_args.lora_enable and not training_args.freeze_llm:
        raise ValueError("If `lora_enable` is True, `freeze_llm` must also be True.")
    # This check used to be an ``assert``, which ``python -O`` strips.
    if not training_args.lora_enable and training_args.vision_lora:
        raise ValueError(
            "Error: training_args.lora_enable is not enabled, but training_args.vision_lora is enabled."
        )
    if training_args.vision_lora and not training_args.freeze_vision_tower:
        raise ValueError("If `vision_lora` is True, `freeze_vision_tower` must also be True.")

    # Presumably supplied as a Python-literal string (e.g. from the CLI); literal_eval is safe.
    if training_args.lora_namespan_exclude is not None:
        training_args.lora_namespan_exclude = ast.literal_eval(training_args.lora_namespan_exclude)
    else:
        training_args.lora_namespan_exclude = []
    # Without vision LoRA, keep adapters out of the vision tower entirely.
    if not training_args.vision_lora:
        training_args.lora_namespan_exclude += ["visual"]


def _load_model(model_args, training_args, compute_dtype, is_qwen2_5):
    """Load the Qwen2-VL or Qwen2.5-VL model, quantized with bitsandbytes when ``bits`` is 4 or 8."""
    quant_kwargs = {}
    if training_args.bits in [4, 8]:
        quant_kwargs.update(dict(
            device_map={"": training_args.device},
            quantization_config=BitsAndBytesConfig(
                load_in_4bit=training_args.bits == 4,
                load_in_8bit=training_args.bits == 8,
                # The vision tower is left unquantized.
                llm_int8_skip_modules=["visual"],
                llm_int8_threshold=6.0,
                llm_int8_has_fp16_weight=False,
                bnb_4bit_compute_dtype=compute_dtype,
                bnb_4bit_use_double_quant=training_args.double_quant,
                bnb_4bit_quant_type=training_args.quant_type,
            ),
        ))
    model_cls = Qwen2_5_VLForConditionalGeneration if is_qwen2_5 else Qwen2VLForConditionalGeneration
    return model_cls.from_pretrained(
        model_args.model_id,
        torch_dtype=compute_dtype,
        attn_implementation="sdpa" if training_args.disable_flash_attn2 else "flash_attention_2",
        **quant_kwargs,
    )


def _restore_vision_trainability(model, training_args):
    """Re-apply vision tower / merger trainability after ``get_peft_model``.

    ``get_peft_model`` marks only adapter parameters as trainable. Without this
    step, ``tune_merger=True`` or an unfrozen vision tower would silently stop
    training once LoRA is enabled.
    """
    if not training_args.freeze_vision_tower:
        # Validation guarantees vision_lora is False here. "visual" is therefore
        # excluded from LoRA, so the tower holds no adapters and can be toggled
        # as a whole.
        set_requires_grad(model.visual.parameters(), True)
        set_requires_grad(model.visual.merger.parameters(), training_args.tune_merger)
    elif training_args.tune_merger:
        # Frozen tower, but the merger is fully tuned (its adapters, if any, stay trainable).
        set_requires_grad(model.visual.merger.parameters(), True)


def _save_trained_model(model, trainer, training_args):
    """Write the final weights to ``training_args.output_dir``.

    With LoRA, this writes the adapter state dict, the model config and
    ``non_lora_state_dict.bin``. Otherwise it delegates to
    ``safe_save_model_for_hf_trainer``.
    """
    if not training_args.lora_enable:
        safe_save_model_for_hf_trainer(trainer, output_dir=training_args.output_dir)
        return
    # The state-dict helpers run on every process; only the writes below are gated.
    # NOTE(review): presumably they gather ZeRO-3 shards collectively, so they must not be rank-gated.
    state_dict = get_peft_state_maybe_zero_3(model.named_parameters(), training_args.lora_bias)
    non_lora_state_dict = get_peft_state_non_lora_maybe_zero_3(
        model.named_parameters(), require_grad_only=False
    )
    # should_save is True only on the process that should write (it honours
    # save_on_each_node). Gating on local_rank == 0 would let every node write
    # the same files in multi-node runs.
    if training_args.should_save:
        model.config.save_pretrained(training_args.output_dir)
        model.save_pretrained(training_args.output_dir, state_dict=state_dict)
        torch.save(non_lora_state_dict, os.path.join(training_args.output_dir, "non_lora_state_dict.bin"))


def train():
    """Parse CLI arguments, then build, train and save a Qwen2-VL / Qwen2.5-VL model.

    Steps:
      1. Install the forward patches.
      2. Validate the LoRA flags.
      3. Load the model, optionally quantized.
      4. Set trainability, optionally wrap with LoRA.
      5. Train, resuming if ``checkpoint-*`` dirs exist in ``output_dir``.
      6. Save.

    Raises:
        ValueError: on inconsistent LoRA / freeze flags.
    """
    global local_rank

    parser = HfArgumentParser((ModelArguments, DataArguments, TrainingArguments))
    model_args, data_args, training_args = parser.parse_args_into_dataclasses()
    local_rank = training_args.local_rank

    # One substring test selects both the forward patch and the model class.
    is_qwen2_5 = "Qwen2.5" in model_args.model_id
    _apply_forward_patches(is_qwen2_5, training_args.use_liger)
    _validate_and_prepare_lora_args(training_args)

    if training_args.fp16:
        compute_dtype = torch.float16
    elif training_args.bf16:
        compute_dtype = torch.bfloat16
    else:
        compute_dtype = torch.float32

    model = _load_model(model_args, training_args, compute_dtype, is_qwen2_5)
    # KV cache is off while training; it is re-enabled just before saving.
    model.config.use_cache = False

    configure_llm(model, training_args)
    configure_vision_tower(model, training_args, compute_dtype, training_args.device)

    if training_args.bits in [4, 8]:
        # NOTE(review): fp16 records float32 here (inherited behaviour) - confirm intent.
        model.config.torch_dtype = (torch.float32 if training_args.fp16 else (torch.bfloat16 if training_args.bf16 else torch.float32))
        from peft import prepare_model_for_kbit_training
        model = prepare_model_for_kbit_training(
            model,
            use_gradient_checkpointing=training_args.gradient_checkpointing,
            gradient_checkpointing_kwargs={"use_reentrant": True},
        )

    if training_args.gradient_checkpointing:
        model.enable_input_require_grads()
        training_args.gradient_checkpointing_kwargs = {"use_reentrant": True}

    if training_args.lora_enable:
        peft_config = LoraConfig(
            r=training_args.lora_rank,
            lora_alpha=training_args.lora_alpha,
            target_modules=find_target_linear_names(
                model,
                lora_namespan_exclude=training_args.lora_namespan_exclude,
                num_lora_modules=training_args.num_lora_modules,
            ),
            lora_dropout=training_args.lora_dropout,
            bias=training_args.lora_bias,
        )
        if training_args.bits == 16:
            if training_args.bf16:
                model.to(torch.bfloat16)
            if training_args.fp16:
                model.to(torch.float16)
        rank0_print("Adding LoRA to the model...")
        model = get_peft_model(model, peft_config)
        _restore_vision_trainability(model, training_args)

    # The processor default is padding_side="left"; right padding is more
    # efficient for training.
    processor = AutoProcessor.from_pretrained(model_args.model_id, padding_side="right")
    model.config.tokenizer_padding_side = processor.tokenizer.padding_side
    model.config.vision_lr = training_args.vision_lr

    if training_args.bits in [4, 8]:
        from peft.tuners.lora import LoraLayer
        # Module.to() works in place, so no rebinding is needed.
        for name, module in model.named_modules():
            if isinstance(module, LoraLayer) and training_args.bf16:
                module.to(torch.bfloat16)
            if 'norm' in name:
                module.to(torch.float32)
            if ('lm_head' in name or 'embed_token' in name) and hasattr(module, 'weight'):
                if training_args.bf16 and module.weight.dtype == torch.float32:
                    module.to(torch.bfloat16)

    data_module = make_supervised_data_module(
        model_id=model_args.model_id,
        processor=processor,
        data_args=data_args,
    )
    trainer = QwenTrainer(
        model=model,
        processor=processor,
        args=training_args,
        **data_module,
    )

    # Resume automatically when output_dir already contains checkpoints.
    if list(pathlib.Path(training_args.output_dir).glob("checkpoint-*")):
        trainer.train(resume_from_checkpoint=True)
    else:
        trainer.train()
    trainer.save_state()

    model.config.use_cache = True
    _save_trained_model(model, trainer, training_args)
# Script entry point: run training only when this file is executed directly.
if __name__ == "__main__":
    train()