|
|
import os |
|
|
import torch |
|
|
from peft import LoraConfig, get_peft_model |
|
|
import ast |
|
|
from transformers import AutoProcessor, BitsAndBytesConfig, Qwen2VLForConditionalGeneration, HfArgumentParser, Qwen2_5_VLForConditionalGeneration |
|
|
from training.trainer import QwenTrainer |
|
|
from training.data import make_supervised_data_module |
|
|
from training.params import DataArguments, ModelArguments, TrainingArguments |
|
|
from training.train_utils import get_peft_state_maybe_zero_3, get_peft_state_non_lora_maybe_zero_3, safe_save_model_for_hf_trainer |
|
|
import pathlib |
|
|
from liger_kernel.transformers import apply_liger_kernel_to_qwen2_vl, apply_liger_kernel_to_qwen2_5_vl |
|
|
from monkey_patch_forward import replace_qwen2_5_with_mixed_modality_forward, replace_qwen_2_with_mixed_modality_forward |
|
|
|
|
|
# Rank of this process as seen by the launcher. `train()` overwrites it with
# `training_args.local_rank` once the command-line arguments are parsed.
local_rank = None


def rank0_print(*args):
    """Print ``args`` only when this process is the main one.

    The process counts as "main" when the module-level ``local_rank`` is
    ``0`` (as an int or the string ``"0"``), ``None`` (not assigned yet), or
    ``-1``. The ``-1`` case mirrors the check ``train()`` uses before saving
    LoRA weights, so logging and saving agree on which process is "main".

    Args:
        *args: Forwarded unchanged to ``print``.
    """
    if local_rank in (0, "0", -1, None):
        print(*args)
|
|
|
|
|
def find_target_linear_names(model, num_lora_modules=-1, lora_namespan_exclude=None, verbose=True):
    """Collect the names of the modules that LoRA should target.

    Walks ``model.named_modules()`` and keeps every ``torch.nn.Linear`` or
    ``torch.nn.Embedding`` whose qualified name contains none of the excluded
    substrings.

    Args:
        model: Object exposing ``named_modules()`` (e.g. a ``torch.nn.Module``).
        num_lora_modules: When positive, keep only the last ``N`` matches in
            ``named_modules()`` order; any other value keeps all of them.
        lora_namespan_exclude: Iterable of substrings; a module whose name
            contains any of them is skipped. ``None`` (the default) or an
            empty iterable excludes nothing.
        verbose: When true, report the selection through ``rank0_print``.

    Returns:
        list[str]: The selected module names.
    """
    # Materialise once so a one-shot iterable can be re-scanned per module.
    excluded = tuple(lora_namespan_exclude) if lora_namespan_exclude else ()
    target_types = (torch.nn.modules.Linear, torch.nn.modules.Embedding)

    lora_module_names = [
        name
        for name, module in model.named_modules()
        if not any(keyword in name for keyword in excluded)
        and isinstance(module, target_types)
    ]

    if num_lora_modules > 0:
        lora_module_names = lora_module_names[-num_lora_modules:]
    if verbose:
        rank0_print(f"Found {len(lora_module_names)} lora modules: {lora_module_names}")
    return lora_module_names
|
|
|
|
|
def set_requires_grad(parameters, requires_grad):
    """Assign ``requires_grad`` to the attribute of that name on every item.

    Args:
        parameters: Iterable of objects with a writable ``requires_grad``
            attribute; it is consumed once.
        requires_grad: Value stored on each item.
    """
    for param in parameters:
        param.requires_grad = requires_grad
|
|
|
|
|
def configure_vision_tower(model, training_args, compute_dtype, device):
    """Place the vision tower and choose which of its parameters train.

    ``model.visual`` is moved to ``device`` and cast to ``compute_dtype``.
    Every parameter of the tower then gets ``requires_grad`` set to the
    inverse of ``training_args.freeze_vision_tower``. Finally the parameters
    of ``model.visual.merger`` are set to ``training_args.tune_merger``;
    because this runs last, the merger flag wins for those parameters.

    Args:
        model: Model exposing a ``visual`` sub-module with a ``merger``.
        training_args: Provides ``freeze_vision_tower`` and ``tune_merger``.
        compute_dtype: dtype passed to ``visual.to``.
        device: Device passed to ``visual.to``.
    """
    model.visual.to(dtype=compute_dtype, device=device)

    # Whole tower first ...
    train_tower = not training_args.freeze_vision_tower
    set_requires_grad(model.visual.parameters(), train_tower)

    # ... then the merger, overriding the tower-wide setting just applied.
    set_requires_grad(model.visual.merger.parameters(), training_args.tune_merger)
|
|
|
|
|
def configure_llm(model, training_args):
    """Freeze or unfreeze the language-model part of ``model``.

    The parameters of ``model.lm_head`` and then of ``model.model`` get
    ``requires_grad`` set to the inverse of ``training_args.freeze_llm``.

    Args:
        model: Model exposing ``lm_head`` and ``model`` sub-modules.
        training_args: Provides the ``freeze_llm`` flag.
    """
    for attr_name in ("lm_head", "model"):
        submodule = getattr(model, attr_name)
        set_requires_grad(submodule.parameters(), not training_args.freeze_llm)
|
|
|
|
|
|
|
|
def _validate_lora_args(training_args):
    """Raise ``ValueError`` for inconsistent LoRA / freezing flag combinations."""
    if training_args.lora_enable and not training_args.freeze_llm:
        raise ValueError("If `lora_enable` is True, `freeze_llm` must also be True.")

    # Raised explicitly rather than asserted: `assert` is removed under
    # `python -O`, which would let this misconfiguration through silently.
    if training_args.vision_lora and not training_args.lora_enable:
        raise ValueError(
            "Error: training_args.lora_enable is not enabled, but training_args.vision_lora is enabled."
        )

    if training_args.vision_lora and not training_args.freeze_vision_tower:
        raise ValueError("If `vision_lora` is True, `freeze_vision_tower` must also be True.")


def _normalize_lora_namespan_exclude(training_args):
    """Parse ``lora_namespan_exclude`` in place; add "visual" when vision LoRA is off."""
    if training_args.lora_namespan_exclude is not None:
        # The option arrives as the text of a Python literal, e.g. "['a', 'b']".
        training_args.lora_namespan_exclude = ast.literal_eval(training_args.lora_namespan_exclude)
    else:
        training_args.lora_namespan_exclude = []

    if not training_args.vision_lora:
        training_args.lora_namespan_exclude += ["visual"]


def _quantization_kwargs(training_args, compute_dtype):
    """Return the extra ``from_pretrained`` kwargs for 4-/8-bit loading, else ``{}``."""
    if training_args.bits not in (4, 8):
        return {}
    return {
        "device_map": {"": training_args.device},
        "quantization_config": BitsAndBytesConfig(
            load_in_4bit=training_args.bits == 4,
            load_in_8bit=training_args.bits == 8,
            llm_int8_skip_modules=["visual"],
            llm_int8_threshold=6.0,
            llm_int8_has_fp16_weight=False,
            bnb_4bit_compute_dtype=compute_dtype,
            bnb_4bit_use_double_quant=training_args.double_quant,
            bnb_4bit_quant_type=training_args.quant_type,
        ),
    }


def _cast_modules_for_kbit(model, training_args):
    """Adjust per-module dtypes of a 4-/8-bit model by module type and name."""
    from peft.tuners.lora import LoraLayer

    # `nn.Module.to` casts the module in place, so the result is not rebound.
    for name, module in model.named_modules():
        if isinstance(module, LoraLayer) and training_args.bf16:
            module.to(torch.bfloat16)
        if "norm" in name:
            module.to(torch.float32)
        if ("lm_head" in name or "embed_token" in name) and hasattr(module, "weight"):
            if training_args.bf16 and module.weight.dtype == torch.float32:
                module.to(torch.bfloat16)


def train():
    """Run one fine-tuning job configured entirely from the command line.

    Steps, in order:

    1. Parse ``ModelArguments``, ``DataArguments`` and ``TrainingArguments``.
    2. Call the forward-replacement hook (and, when ``use_liger`` is set, the
       Liger kernel hook) for the model family selected by whether
       ``"Qwen2.5"`` occurs in ``model_id``.
    3. Validate the LoRA / freeze flags and normalise
       ``lora_namespan_exclude``.
    4. Load the model (optionally 4-/8-bit quantised), set which parameters
       train, and optionally wrap the model with LoRA.
    5. Build the processor, data module and ``QwenTrainer``; train, resuming
       when ``output_dir`` already contains a ``checkpoint-*`` entry.
    6. Save the trainer state and the resulting weights.

    Side effects:
        Sets the module-level ``local_rank`` from ``training_args.local_rank``.

    Raises:
        ValueError: If the LoRA / freeze flags are inconsistent.
    """
    global local_rank

    parser = HfArgumentParser((ModelArguments, DataArguments, TrainingArguments))
    model_args, data_args, training_args = parser.parse_args_into_dataclasses()

    use_liger = training_args.use_liger

    # The model family is decided once and reused for both patching and loading.
    if "Qwen2.5" in model_args.model_id:
        replace_qwen2_5_with_mixed_modality_forward(use_liger=use_liger)
        if use_liger:
            apply_liger_kernel_to_qwen2_5_vl(fused_linear_cross_entropy=False)
        model_cls = Qwen2_5_VLForConditionalGeneration
    else:
        replace_qwen_2_with_mixed_modality_forward(use_liger=use_liger)
        if use_liger:
            apply_liger_kernel_to_qwen2_vl(fused_linear_cross_entropy=False)
        model_cls = Qwen2VLForConditionalGeneration

    _validate_lora_args(training_args)
    _normalize_lora_namespan_exclude(training_args)

    local_rank = training_args.local_rank

    # fp16 takes precedence over bf16 when both flags are set.
    if training_args.fp16:
        compute_dtype = torch.float16
    elif training_args.bf16:
        compute_dtype = torch.bfloat16
    else:
        compute_dtype = torch.float32

    quantized = training_args.bits in (4, 8)

    model = model_cls.from_pretrained(
        model_args.model_id,
        torch_dtype=compute_dtype,
        attn_implementation="sdpa" if training_args.disable_flash_attn2 else "flash_attention_2",
        **_quantization_kwargs(training_args, compute_dtype),
    )

    model.config.use_cache = False
    configure_llm(model, training_args)
    configure_vision_tower(model, training_args, compute_dtype, training_args.device)

    if quantized:
        # Recorded config dtype: bfloat16 only for a bf16 (non-fp16) run,
        # float32 in every other case, including fp16.
        if training_args.bf16 and not training_args.fp16:
            model.config.torch_dtype = torch.bfloat16
        else:
            model.config.torch_dtype = torch.float32

        from peft import prepare_model_for_kbit_training

        model = prepare_model_for_kbit_training(
            model,
            use_gradient_checkpointing=training_args.gradient_checkpointing,
            gradient_checkpointing_kwargs={"use_reentrant": True},
        )

    if training_args.gradient_checkpointing:
        model.enable_input_require_grads()
        training_args.gradient_checkpointing_kwargs = {"use_reentrant": True}

    if training_args.lora_enable:
        peft_config = LoraConfig(
            r=training_args.lora_rank,
            lora_alpha=training_args.lora_alpha,
            target_modules=find_target_linear_names(
                model,
                lora_namespan_exclude=training_args.lora_namespan_exclude,
                num_lora_modules=training_args.num_lora_modules,
            ),
            lora_dropout=training_args.lora_dropout,
            bias=training_args.lora_bias,
        )
        if training_args.bits == 16:
            if training_args.bf16:
                model.to(torch.bfloat16)
            if training_args.fp16:
                model.to(torch.float16)
        rank0_print("Adding LoRA to the model...")
        model = get_peft_model(model, peft_config)

    processor = AutoProcessor.from_pretrained(model_args.model_id, padding_side="right")

    model.config.tokenizer_padding_side = processor.tokenizer.padding_side
    model.config.vision_lr = training_args.vision_lr

    if quantized:
        _cast_modules_for_kbit(model, training_args)

    data_module = make_supervised_data_module(
        model_id=model_args.model_id,
        processor=processor,
        data_args=data_args,
    )

    trainer = QwenTrainer(
        model=model,
        processor=processor,
        args=training_args,
        **data_module,
    )

    # Resume when a previous run left at least one checkpoint behind.
    if any(pathlib.Path(training_args.output_dir).glob("checkpoint-*")):
        trainer.train(resume_from_checkpoint=True)
    else:
        trainer.train()

    trainer.save_state()

    model.config.use_cache = True

    if training_args.lora_enable:
        # Both state dicts are gathered on every rank; only the main process
        # (rank 0, or -1) writes files.
        state_dict = get_peft_state_maybe_zero_3(
            model.named_parameters(), training_args.lora_bias
        )
        non_lora_state_dict = get_peft_state_non_lora_maybe_zero_3(
            model.named_parameters(), require_grad_only=False
        )

        if local_rank == 0 or local_rank == -1:
            model.config.save_pretrained(training_args.output_dir)
            model.save_pretrained(training_args.output_dir, state_dict=state_dict)
            torch.save(non_lora_state_dict, os.path.join(training_args.output_dir, "non_lora_state_dict.bin"))
    else:
        safe_save_model_for_hf_trainer(trainer, output_dir=training_args.output_dir)
|
|
|
|
|
|
|
|
|
|
|
# Script entry point: start training only when this file is executed directly.
if __name__ == "__main__":
    train()