import argparse
import os
import sys
from typing import List

import torch
import transformers
from peft import (
    TaskType,
    LoraConfig,
    get_peft_model,
    get_peft_model_state_dict,
    set_peft_model_state_dict,
)
from transformers import LlamaForCausalLM, LlamaTokenizer, LlamaConfig

from utils import *
from collator import Collator


def train(args):
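    """Fine-tune the base LLaMA model with LoRA adapters on the datasets returned by load_datasets."""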

    set_seed(args.seed)
    ensure_dir(args.output_dir)

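    # Detect a distributed (torchrun) launch via WORLD_SIZE / LOCAL_RANK and pin each
    # process to its own GPU instead of letting device_map="auto" shard the model.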
    device_map = "auto"
    world_size = int(os.environ.get("WORLD_SIZE", 1))
    ddp = world_size != 1
    local_rank = int(os.environ.get("LOCAL_RANK") or 0)
    if local_rank == 0:
        print(vars(args))

    if ddp:
        device_map = {"": local_rank}

    config = LlamaConfig.from_pretrained(args.base_model)
    tokenizer = LlamaTokenizer.from_pretrained(
        args.base_model,
        model_max_length=args.model_max_length,
        padding_side="right",
    )
    # The LLaMA tokenizer ships without a pad token; reuse id 0 for padding.
    tokenizer.pad_token_id = 0

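    # Build the train/valid datasets, extend the tokenizer with the new tokens the
    # training data introduces, and persist the updated tokenizer/config for inference.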
    train_data, valid_data = load_datasets(args)
    add_num = tokenizer.add_tokens(train_data.datasets[0].get_new_tokens())
    config.vocab_size = len(tokenizer)
    if local_rank == 0:
        print("added {} new tokens.".format(add_num))
        print("data num:", len(train_data))
    tokenizer.save_pretrained(args.output_dir)
    config.save_pretrained(args.output_dir)

    collator = Collator(args, tokenizer)

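    # Load the base LLaMA weights and grow the embedding matrix to cover the newly added tokens.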
    model = LlamaForCausalLM.from_pretrained(
        args.base_model,
        device_map=device_map,
    )
    model.resize_token_embeddings(len(tokenizer))

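    # Wrap the model with LoRA adapters: modules listed in lora_target_modules get
    # low-rank adapters, while modules in lora_modules_to_save are trained (and saved) in full.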
    lora_config = LoraConfig(
        r=args.lora_r,
        lora_alpha=args.lora_alpha,
        target_modules=args.lora_target_modules.split(","),
        modules_to_save=args.lora_modules_to_save.split(","),
        lora_dropout=args.lora_dropout,
        bias="none",
        inference_mode=False,
        task_type=TaskType.CAUSAL_LM,
    )
    model = get_peft_model(model, lora_config)

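    # Optionally warm-start from a previously saved adapter checkpoint.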
    if args.resume_from_checkpoint:
        checkpoint_name = os.path.join(
            args.resume_from_checkpoint, "adapter_model.bin"
        )
        # Only the adapter weights are loaded below; clearing the flag keeps the
        # Trainer from also trying to resume optimizer/scheduler state.
        args.resume_from_checkpoint = False

        if os.path.exists(checkpoint_name):
            if local_rank == 0:
                print(f"Restarting from {checkpoint_name}")
            adapters_weights = torch.load(checkpoint_name)
            # set_peft_model_state_dict updates the model in place; do not rebind `model`
            # to its return value (newer peft versions return a load result, not the model).
            set_peft_model_state_dict(model, adapters_weights)
        else:
            if local_rank == 0:
                print(f"Checkpoint {checkpoint_name} not found")

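    # PEFT keeps a frozen copy of each modules_to_save module as "original_module";
    # make sure those copies stay out of the gradient computation.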
    for n, p in model.named_parameters():
        if "original_module" in n and any(module_name in n for module_name in lora_config.modules_to_save):
            p.requires_grad = False

    if local_rank == 0:
        model.print_trainable_parameters()

    # Naive model parallelism when several GPUs are visible but DDP is not used.
    if not ddp and torch.cuda.device_count() > 1:
        model.is_parallelizable = True
        model.model_parallel = True

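    # Standard HF Trainer run; batch size, precision, optimizer and save/eval cadence
    # all come from the command-line arguments.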
    trainer = transformers.Trainer(
        model=model,
        train_dataset=train_data,
        eval_dataset=valid_data,
        args=transformers.TrainingArguments(
            seed=args.seed,
            per_device_train_batch_size=args.per_device_batch_size,
            per_device_eval_batch_size=args.per_device_batch_size,
            gradient_accumulation_steps=args.gradient_accumulation_steps,
            warmup_ratio=args.warmup_ratio,
            num_train_epochs=args.epochs,
            learning_rate=args.learning_rate,
            weight_decay=args.weight_decay,
            lr_scheduler_type=args.lr_scheduler_type,
            fp16=args.fp16,
            bf16=args.bf16,
            logging_steps=args.logging_step,
            optim=args.optim,
            gradient_checkpointing=True,
            evaluation_strategy=args.save_and_eval_strategy,
            save_strategy=args.save_and_eval_strategy,
            eval_steps=args.save_and_eval_steps,
            save_steps=args.save_and_eval_steps,
            output_dir=args.output_dir,
            save_total_limit=5,
            load_best_model_at_end=True,
            deepspeed=args.deepspeed,
            ddp_find_unused_parameters=False if ddp else None,
            report_to=None,
            eval_delay=1 if args.save_and_eval_strategy == "epoch" else 2000,
        ),
        tokenizer=tokenizer,
        data_collator=collator,
    )
    # Disable the KV cache during training; it is incompatible with gradient checkpointing.
    model.config.use_cache = False

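    # On PyTorch 2.x (except Windows), compile the model for faster training.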
    if torch.__version__ >= "2" and sys.platform != "win32":
        model = torch.compile(model)

    trainer.train(
        resume_from_checkpoint=args.resume_from_checkpoint,
    )

    trainer.save_state()
    trainer.save_model(output_dir=args.output_dir)

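# Illustrative launch command (a sketch, not the project's documented usage): the flag
# names below are assumptions inferred from the args.* attributes used above, and the
# real definitions live in parse_global_args / parse_train_args / parse_dataset_args.
#
#   torchrun --nproc_per_node=4 train.py \
#       --base_model <path-to-llama> \
#       --output_dir <output-dir> \
#       --per_device_batch_size 4 \
#       --gradient_accumulation_steps 8 \
#       --epochs 3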
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description='LLMRec')
    parser = parse_global_args(parser)
    parser = parse_train_args(parser)
    parser = parse_dataset_args(parser)

    args = parser.parse_args()

    train(args)