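# Fine-tune a causal language model with LoRA adapters via the Hugging Face
# Trainer. Supports DeepSpeed ZeRO-3: partitioned LoRA weights are gathered
# before the adapter is saved.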
from dataclasses import dataclass, field
import logging
import pathlib
import typing

from deepspeed import zero
from deepspeed.runtime.zero.partition_parameters import ZeroParamStatus
from peft import LoraConfig, get_peft_model
import transformers
from transformers import Trainer

from fastchat.train.train import (
    DataArguments,
    ModelArguments,
    TrainingArguments,
    make_supervised_data_module,
)

from fastchat.train.llama_flash_attn_monkey_patch import (
    replace_llama_attn_with_flash_attn,
)

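# Patch LLaMA's attention with the FlashAttention implementation before the
# model is instantiated so the patched forward is used during training.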
replace_llama_attn_with_flash_attn()

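# LoRA-specific hyperparameters, parsed alongside FastChat's model, data,
# and training arguments.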
@dataclass
class LoraArguments:
    lora_r: int = 8
    lora_alpha: int = 16
    lora_dropout: float = 0.05
    lora_target_modules: typing.List[str] = field(
        default_factory=lambda: ["q_proj", "v_proj"]
    )
    lora_weight_path: str = ""
    bias: str = "none"

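# Under DeepSpeed ZeRO-3 each rank holds only a partitioned placeholder of a
# parameter; gather the full tensor before copying it to CPU so the saved
# adapter contains real weights.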
def maybe_zero_3(param):
    if hasattr(param, "ds_id"):
        assert param.ds_status == ZeroParamStatus.NOT_AVAILABLE
        with zero.GatheredParameters([param]):
            param = param.data.cpu().clone().detach()
    return param

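# Extract only the LoRA weights (plus bias terms, depending on `bias`) from
# the full state dict and materialize any ZeRO-3 partitioned tensors.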
def get_peft_state_maybe_zero_3(state_dict, bias):
    if bias == "none":
        to_return = {
            k: state_dict[k].cpu().clone().detach() for k in state_dict if "lora_" in k
        }
    elif bias == "all":
        to_return = {
            k: state_dict[k] for k in state_dict if "lora_" in k or "bias" in k
        }
    elif bias == "lora_only":
        to_return = {}
        for k in state_dict:
            if "lora_" in k:
                to_return[k] = state_dict[k]
                bias_name = k.split("lora_")[0] + "bias"
                if bias_name in state_dict:
                    to_return[bias_name] = state_dict[bias_name]
    else:
        raise NotImplementedError
    to_return = {k: maybe_zero_3(v) for k, v in to_return.items()}
    return to_return

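# Parse CLI arguments, attach LoRA adapters to the base model, and run
# supervised fine-tuning with the Hugging Face Trainer.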
def train():
    parser = transformers.HfArgumentParser(
        (ModelArguments, DataArguments, TrainingArguments, LoraArguments)
    )
    (
        model_args,
        data_args,
        training_args,
        lora_args,
    ) = parser.parse_args_into_dataclasses()

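    # Load the base model and wrap it with trainable LoRA adapters; the base
    # weights stay frozen.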
    model = transformers.AutoModelForCausalLM.from_pretrained(
        model_args.model_name_or_path,
        cache_dir=training_args.cache_dir,
    )
    lora_config = LoraConfig(
        r=lora_args.lora_r,
        lora_alpha=lora_args.lora_alpha,
        target_modules=lora_args.lora_target_modules,
        lora_dropout=lora_args.lora_dropout,
        bias=lora_args.bias,
        task_type="CAUSAL_LM",
    )
    model = get_peft_model(model, lora_config)
    if training_args.deepspeed is not None and training_args.local_rank == 0:
        model.print_trainable_parameters()

    if training_args.gradient_checkpointing:
        logging.warning(
            "gradient checkpointing with lora makes requires_grad "
            "incorrect and needs a monkey patch in Trainer or the "
            "wrapped model's forward. ref: "
            "https://github.com/lm-sys/FastChat/pull/138#issuecomment-1509172198"
        )

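    # The base tokenizer may not define a pad token (LLaMA's does not); fall
    # back to the unk token so right-padded batches can be built.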
    tokenizer = transformers.AutoTokenizer.from_pretrained(
        model_args.model_name_or_path,
        cache_dir=training_args.cache_dir,
        model_max_length=training_args.model_max_length,
        padding_side="right",
        use_fast=False,
    )
    tokenizer.pad_token = tokenizer.unk_token

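    # Build the supervised fine-tuning dataset and collator, then hand
    # everything to the Hugging Face Trainer.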
    data_module = make_supervised_data_module(tokenizer=tokenizer, data_args=data_args)
    trainer = Trainer(
        model=model, tokenizer=tokenizer, args=training_args, **data_module
    )

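    # The KV cache is only needed for generation; disable it during training.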
    model.config.use_cache = False

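    # Resume from the latest checkpoint if the output directory already has one.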
    if list(pathlib.Path(training_args.output_dir).glob("checkpoint-*")):
        trainer.train(resume_from_checkpoint=True)
    else:
        trainer.train()
    trainer.save_state()

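    # LoRA weights may still be ZeRO-3 placeholders here; gather them into a
    # full state dict and save the adapter from rank 0 only.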
    state_dict = get_peft_state_maybe_zero_3(model.state_dict(), lora_args.bias)
    if training_args.local_rank == 0:
        model.save_pretrained(training_args.output_dir, state_dict=state_dict)

if __name__ == "__main__":
    train()