# File size: 4,235 Bytes
# 811e03d
import os
import sys
from typing import List
import argparse
import wandb
import torch
import transformers
from transformers import LlamaForCausalLM, LlamaTokenizer, LlamaConfig
from peft import (
TaskType,
LoraConfig,
get_peft_model,
get_peft_model_state_dict,
set_peft_model_state_dict,
)
from collator import VanillaCollator
from rq_llama import *
from utils import *
# Build the command-line interface. Each parse_*_args helper takes the parser
# and returns it, so the helpers are applied one after another in this order.
parser = argparse.ArgumentParser(description='rqllama-pretrain-more')
for add_arguments in (
    parse_global_args,
    parse_train_args,
    parse_dataset_args,
    parse_rqvae_args,
    parse_pretrain_args,
):
    parser = add_arguments(parser)
args = parser.parse_args()

# Open a wandb run whose config is the parsed argument namespace.
wandb.init(config=args, reinit=True)

# set_seed / ensure_dir are star-imported project helpers; presumably they
# seed the RNGs and create the output directory — confirm in utils.
set_seed(args.seed)
ensure_dir(args.output_dir)
# Distributed-launch settings, read from the process environment.
# An unset or empty LOCAL_RANK counts as rank 0.
local_rank = int(os.environ.get("LOCAL_RANK") or 0)
# WORLD_SIZE falls back to 1 when the variable is not set.
world_size = int(os.getenv("WORLD_SIZE", 1))
# Any world size other than 1 is treated as a DDP run.
ddp = not (world_size == 1)
# Default placement value later passed to LlamaWithRQ.from_pretrained.
device_map = "auto"
# Only the rank-0 process echoes the parsed arguments.
if local_rank == 0:
    print(vars(args))

# Under DDP the placement becomes a mapping from the root key "" to this
# process's local rank; otherwise the previously chosen value is kept.
device_map = {"": local_rank} if ddp else device_map
# Load the train/validation datasets, then the LlamaWithRQ checkpoint in
# float16 using the placement chosen above.
train_data, valid_data = load_datasets(args)
rqllama = LlamaWithRQ.from_pretrained(
    args.ckpt_path,
    torch_dtype=torch.float16,
    low_cpu_mem_usage=True,
    device_map=device_map,
)

# Mark one VQ layer per entry of args.num_emb_list as already initialised.
# NOTE(review): presumably this prevents the codebooks from being
# re-initialised on the first forward pass — confirm in rq_llama.
for level, _ in enumerate(args.num_emb_list):
    rqllama.rqvae.rq.vq_layers[level].initted = True

if local_rank == 0:
    print("token num:", len(rqllama.tokenizer))
    print("data num:", len(train_data))
    # NOTE(review): the original indentation of this file was lost; the two
    # save_pretrained calls are assumed to belong to the rank-0 branch so
    # that only one process writes to the output directory — confirm.
    rqllama.tokenizer.save_pretrained(args.output_dir)
    rqllama.config.save_pretrained(args.output_dir)
# Optionally warm-start the PEFT adapter from the weights saved by an earlier
# run (expects an "adapter_model.bin" file inside the given directory).
if args.resume_from_checkpoint:
    checkpoint_name = os.path.join(args.resume_from_checkpoint, "adapter_model.bin")
    # Only adapter weights are restored here. The flag is cleared so that the
    # later trainer.train(resume_from_checkpoint=...) call receives False and
    # does not additionally try to resume from that directory.
    args.resume_from_checkpoint = False
    if os.path.exists(checkpoint_name):
        if local_rank == 0:
            print(f"Restarting from {checkpoint_name}")
        # Deserialize onto the CPU: without map_location the tensors are
        # restored onto the device they were saved from, which under DDP would
        # be the same GPU for every rank. Loading into the model copies them
        # to wherever the parameters actually live.
        # NOTE(review): torch.load unpickles the file — only point this at
        # checkpoints from a trusted source.
        adapters_weights = torch.load(checkpoint_name, map_location="cpu")
        # set_peft_model_state_dict loads the weights into the model in place.
        # Its return value must not be assigned back to rqllama.model: current
        # PEFT releases return the load result rather than the model, which
        # would replace the model with a non-module object.
        set_peft_model_state_dict(rqllama.model, adapters_weights)
    elif local_rank == 0:
        print(f"Checkpoint {checkpoint_name} not found")
# Rank 0 reports the trainable-parameter summary of the wrapped model.
if local_rank == 0:
    rqllama.model.print_trainable_parameters()

# Single process with more than one CUDA device visible: raise both
# model-parallel flags on the top-level model.
# NOTE(review): presumably these tell transformers.Trainer that the model is
# already spread across devices so it is not wrapped in DataParallel — confirm.
if not ddp and torch.cuda.device_count() > 1:
    for flag_name in ("is_parallelizable", "model_parallel"):
        setattr(rqllama, flag_name, True)
# Batch builder: constructed from the parsed arguments and the model's tokenizer.
collator = VanillaCollator(args, rqllama.tokenizer)

# Checkpointing and evaluation share a single strategy and a single interval.
schedule_strategy = args.save_and_eval_strategy
schedule_steps = args.save_and_eval_steps

training_args = transformers.TrainingArguments(
    # Reproducibility and output location.
    seed=args.seed,
    output_dir=args.output_dir,
    # Batch sizes (the same per-device size for training and evaluation).
    per_device_train_batch_size=args.per_device_batch_size,
    per_device_eval_batch_size=args.per_device_batch_size,
    gradient_accumulation_steps=args.gradient_accumulation_steps,
    # Optimisation schedule.
    num_train_epochs=args.epochs,
    learning_rate=args.learning_rate,
    weight_decay=args.weight_decay,
    warmup_ratio=args.warmup_ratio,
    lr_scheduler_type=args.lr_scheduler_type,
    optim=args.optim,
    # Precision and memory.
    fp16=args.fp16,
    bf16=args.bf16,
    gradient_checkpointing=True,
    # Logging.
    logging_steps=args.logging_step,
    # NOTE(review): None is not the string "none"; depending on the installed
    # transformers version this may enable every available integration — confirm.
    report_to=None,
    # Evaluation and checkpointing run on the same schedule.
    evaluation_strategy=schedule_strategy,
    save_strategy=schedule_strategy,
    eval_steps=schedule_steps,
    save_steps=schedule_steps,
    # Evaluation is delayed by 1 under the "epoch" strategy, by 2000 otherwise.
    eval_delay=1 if schedule_strategy == "epoch" else 2000,
    save_total_limit=5,
    load_best_model_at_end=True,
    # Distributed back-ends.
    deepspeed=args.deepspeed,
    ddp_find_unused_parameters=False if ddp else None,
    # Data loading.
    dataloader_num_workers=args.dataloader_num_workers,
    dataloader_prefetch_factor=args.dataloader_prefetch_factor,
    remove_unused_columns=args.remove_unused_columns,
)

trainer = transformers.Trainer(
    model=rqllama,
    args=training_args,
    train_dataset=train_data,
    eval_dataset=valid_data,
    tokenizer=rqllama.tokenizer,
    data_collator=collator,
)
# Turn off the generation cache on the model config before training starts.
rqllama.config.use_cache = False

# torch.compile is applied only on PyTorch 2+ and never on Windows.
# NOTE(review): the trainer was built before this point and keeps its own
# reference to the model, so the compiled wrapper bound to `rqllama` here does
# not appear to be what trainer.train() runs — confirm whether that is intended.
compile_supported = torch.__version__ >= "2" and sys.platform != "win32"
if compile_supported:
    rqllama = torch.compile(rqllama)

# Train, then persist the trainer state and the final model to the output dir.
# args.resume_from_checkpoint has been reset to False by the adapter-restore
# step above whenever a resume directory was supplied.
trainer.train(resume_from_checkpoint=args.resume_from_checkpoint)
trainer.save_state()
trainer.save_model(output_dir=args.output_dir)

if local_rank == 0:
    print('rqllama pre-train finished.')