import transformers
from transformers import Trainer

from llm_finetune.arguments import (
    ModelArguments,
    DataArguments,
    TrainingArguments,
)
from llm_finetune.dataset import make_supervised_data_module


def train():
    # Parse CLI flags into the model/data/training dataclasses.
    parser = transformers.HfArgumentParser(
        (ModelArguments, DataArguments, TrainingArguments)
    )
    model_args, data_args, training_args = parser.parse_args_into_dataclasses()

    # Load the base causal language model.
    model = transformers.AutoModelForCausalLM.from_pretrained(
        model_args.model_name_or_path,
        cache_dir=training_args.cache_dir,
    )

    # Load the matching tokenizer; pad on the right for causal-LM training.
    tokenizer = transformers.AutoTokenizer.from_pretrained(
        model_args.model_name_or_path,
        cache_dir=training_args.cache_dir,
        model_max_length=training_args.model_max_length,
        padding_side="right",
        use_fast=False,
    )
    # Many causal LMs ship without a pad token; reuse EOS so padding works.
    tokenizer.pad_token = tokenizer.eos_token

    # Build the dataset(s) and collator that Trainer expects.
    data_module = make_supervised_data_module(
        tokenizer=tokenizer,
        data_args=data_args,
    )

    trainer = Trainer(
        model=model, tokenizer=tokenizer, args=training_args, **data_module
    )
    # Name the argument: Trainer.train's first parameter is resume_from_checkpoint.
    trainer.train(resume_from_checkpoint=training_args.checkpoint)
    trainer.save_state()
    trainer.save_model(output_dir=training_args.output_dir)


if __name__ == "__main__":
    train()
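
# Example launch (a sketch: --model_name_or_path, --output_dir, and
# --model_max_length mirror fields used above; the script name and
# --data_path are placeholders, since llm_finetune.arguments and
# llm_finetune.dataset are not shown here):
#
#   python train.py \
#       --model_name_or_path <hf-model-or-path> \
#       --data_path <path/to/data> \
#       --output_dir <output-dir> \
#       --model_max_length <max-len>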