import transformers
from transformers import Trainer

from xtuner.apis import DefaultTrainingArguments, build_model
from xtuner.apis.datasets import alpaca_data_collator, alpaca_dataset
|
|
|
|
def train():
    """Fine-tune a causal LM on the Alpaca dataset with a HF ``Trainer``.

    Parses ``DefaultTrainingArguments`` from the command line, builds the
    model/tokenizer pair from ``model_name_or_path``, tokenizes the Alpaca
    dataset, runs training, and saves both the trainer state and the final
    model weights to ``training_args.output_dir``.
    """
    # Parse CLI flags into xtuner's default training-argument dataclass;
    # parse_args_into_dataclasses returns a tuple, we only need the first.
    parser = transformers.HfArgumentParser(DefaultTrainingArguments)
    training_args = parser.parse_args_into_dataclasses()[0]

    # Build model and tokenizer from the same checkpoint so the dataset
    # tokenization below matches the model's vocabulary.
    model, tokenizer = build_model(
        model_name_or_path=training_args.model_name_or_path,
        return_tokenizer=True)
    train_dataset = alpaca_dataset(
        tokenizer=tokenizer, path=training_args.dataset_name_or_path)
    # return_hf_format=True yields batches in the layout HF Trainer expects.
    data_collator = alpaca_data_collator(return_hf_format=True)

    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        data_collator=data_collator)

    trainer.train()

    # Persist trainer state (optimizer/scheduler/RNG) and the final weights.
    trainer.save_state()
    trainer.save_model(output_dir=training_args.output_dir)
|
|
|
|
if __name__ == '__main__':
    # Script entry point: run the full fine-tuning pipeline.
    train()
|
|