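# Further pre-training script for LlamaWithRQ, a LLaMA model augmented with an
# RQ-VAE quantizer (defined in rq_llama). The script loads a checkpoint,
# optionally restores previously saved PEFT/LoRA adapter weights, and trains
# with the Hugging Face Trainer.
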
import os
import sys
from typing import List
import argparse

import wandb
import torch
import transformers
from transformers import LlamaForCausalLM, LlamaTokenizer, LlamaConfig

from peft import (
    TaskType,
    LoraConfig,
    get_peft_model,
    get_peft_model_state_dict,
    set_peft_model_state_dict,
)

from collator import VanillaCollator
from rq_llama import *
from utils import *

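# Assemble command-line options from the project's shared argument groups
# (global, training, dataset, RQ-VAE, and pre-training options).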
parser = argparse.ArgumentParser(description = 'rqllama-pretrain-more')
parser = parse_global_args(parser)
parser = parse_train_args(parser)
parser = parse_dataset_args(parser)
parser = parse_rqvae_args(parser)
parser = parse_pretrain_args(parser)
args = parser.parse_args()
wandb.init(config = args, reinit = True)

set_seed(args.seed)
ensure_dir(args.output_dir)

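# Detect a distributed (DDP) launch from the torchrun/launcher environment
# variables; otherwise fall back to automatic device placement.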
device_map = "auto"
world_size = int(os.environ.get("WORLD_SIZE", 1))
ddp = world_size != 1
local_rank = int(os.environ.get("LOCAL_RANK") or 0)
if local_rank == 0:
    print(vars(args))
if ddp:
    device_map = {"": local_rank}

train_data, valid_data = load_datasets(args)

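# Load the RQ-VAE-augmented LLaMA checkpoint in fp16 with low CPU memory usage.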
rqllama = LlamaWithRQ.from_pretrained(args.ckpt_path, torch_dtype = torch.float16, low_cpu_mem_usage = True, device_map = device_map)

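# Mark every residual-quantization codebook as already initialized so that
# training does not re-trigger its data-dependent initialization.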
for i in range(len(args.num_emb_list)):
    rqllama.rqvae.rq.vq_layers[i].initted = True

if local_rank == 0:
    print("token num:", len(rqllama.tokenizer))
    print("data num:", len(train_data))
    rqllama.tokenizer.save_pretrained(args.output_dir)
    rqllama.config.save_pretrained(args.output_dir)

if args.resume_from_checkpoint:
    checkpoint_name = os.path.join(args.resume_from_checkpoint, "adapter_model.bin")
    # Only adapter weights are restored here; clear the flag so the HF Trainer
    # does not also try to resume optimizer/scheduler state from this path.
    args.resume_from_checkpoint = False
    if os.path.exists(checkpoint_name):
        if local_rank == 0:
            print(f"Restarting from {checkpoint_name}")
        adapters_weights = torch.load(checkpoint_name)
        # set_peft_model_state_dict loads the weights in place; newer peft versions
        # return a load-result object rather than the model, so the return value is
        # deliberately not assigned back to rqllama.model.
        set_peft_model_state_dict(rqllama.model, adapters_weights)
    else:
        if local_rank == 0:
            print(f"Checkpoint {checkpoint_name} not found")

if local_rank == 0:
    rqllama.model.print_trainable_parameters()

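# With several visible GPUs but no DDP, tell the Trainer the model is split
# across devices (naive model parallelism) instead of wrapping it in DataParallel.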
if not ddp and torch.cuda.device_count() > 1:
    rqllama.is_parallelizable = True
    rqllama.model_parallel = True

collator = VanillaCollator(args, rqllama.tokenizer)

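# Standard Hugging Face Trainer setup; all schedule/optimization knobs come
# from the command-line arguments parsed above.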
trainer = transformers.Trainer(
    model = rqllama,
    train_dataset = train_data,
    eval_dataset = valid_data,
    args = transformers.TrainingArguments(
        seed = args.seed,
        per_device_train_batch_size = args.per_device_batch_size,
        per_device_eval_batch_size = args.per_device_batch_size,
        gradient_accumulation_steps = args.gradient_accumulation_steps,
        warmup_ratio = args.warmup_ratio,
        num_train_epochs = args.epochs,
        learning_rate = args.learning_rate,
        weight_decay = args.weight_decay,
        lr_scheduler_type = args.lr_scheduler_type,
        fp16 = args.fp16,
        bf16 = args.bf16,
        logging_steps = args.logging_step,
        optim = args.optim,
        gradient_checkpointing = True,
        evaluation_strategy = args.save_and_eval_strategy,
        save_strategy = args.save_and_eval_strategy,
        eval_steps = args.save_and_eval_steps,
        save_steps = args.save_and_eval_steps,
        output_dir = args.output_dir,
        save_total_limit = 5,
        load_best_model_at_end = True,
        deepspeed = args.deepspeed,
        ddp_find_unused_parameters = False if ddp else None,
        report_to = None,
        eval_delay = 1 if args.save_and_eval_strategy=="epoch" else 2000,
        dataloader_num_workers = args.dataloader_num_workers,
        dataloader_prefetch_factor = args.dataloader_prefetch_factor,
        remove_unused_columns = args.remove_unused_columns,
    ),
    tokenizer = rqllama.tokenizer,
    data_collator = collator,
)
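# Disable the KV cache: it is only needed for generation and is incompatible
# with gradient checkpointing.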
rqllama.config.use_cache = False

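# On PyTorch 2.x (outside Windows), compile the model for faster training.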
if torch.__version__ >= "2" and sys.platform != "win32":
    rqllama = torch.compile(rqllama)

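# Note: when an adapter checkpoint was restored above, resume_from_checkpoint
# has been reset to False, so optimizer/scheduler state starts fresh here.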
trainer.train(resume_from_checkpoint = args.resume_from_checkpoint)

trainer.save_state()
trainer.save_model(output_dir = args.output_dir)

if local_rank == 0:
    print('rqllama pre-train finished.')