import argparse
import os
import sys

import torch
import transformers
from peft import PeftModel
from transformers import LlamaForCausalLM, LlamaTokenizer

from utils import *
from collator import Collator
from rq_llama import *

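# Parse the fine-tuning arguments, fix the random seed, and make sure the
# output directory exists.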
parser = argparse.ArgumentParser(description = 'rqllama-finetune')
parser = parse_finetune_args(parser)
args = parser.parse_args()

set_seed(args.seed)
ensure_dir(args.output_dir)

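# Detect a distributed (DDP) run from the launcher environment variables; in
# that case each process loads the model onto its own local GPU instead of
# letting device_map="auto" spread it across devices.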
device_map = "auto"
world_size = int(os.environ.get("WORLD_SIZE", 1))
ddp = world_size != 1
local_rank = int(os.environ.get("LOCAL_RANK") or 0)
if local_rank == 0:
    print(vars(args))

if ddp:
    device_map = {"": local_rank}

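# Load the training and validation datasets for fine-tuning.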
train_data, valid_data = load_finetune_datasets(args)

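# Load the tokenizer saved with the pretrained checkpoint, the fp16 base LLaMA
# model, and the LoRA adapter from ckpt_path. The embedding matrix is resized
# first so it matches the (possibly extended) tokenizer vocabulary.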
tokenizer = LlamaTokenizer.from_pretrained(args.ckpt_path)
base_model = LlamaForCausalLM.from_pretrained(args.base_model, torch_dtype=torch.float16, low_cpu_mem_usage = True, device_map = device_map)
base_model.resize_token_embeddings(len(tokenizer))
rqllama = PeftModel.from_pretrained(base_model, args.ckpt_path, torch_dtype = torch.float16, device_map = device_map)

if local_rank == 0:
    print("token num:", len(tokenizer))
    print("data num:", len(train_data))

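# Data collator passed to the Trainer to build batches from raw examples.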
collator = Collator(args, tokenizer)

rqllama.train()

if local_rank == 0:
    rqllama.print_trainable_parameters()

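# Hugging Face Trainer wired up with the hyperparameters parsed from the
# command line.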
trainer = transformers.Trainer(
    model = rqllama,
    train_dataset = train_data,
    eval_dataset = valid_data,
    args = transformers.TrainingArguments(
        seed = args.seed,
        per_device_train_batch_size = args.per_device_batch_size,
        per_device_eval_batch_size = args.per_device_batch_size,
        gradient_accumulation_steps = args.gradient_accumulation_steps,
        warmup_ratio = args.warmup_ratio,
        num_train_epochs = args.epochs,
        learning_rate = args.learning_rate,
        weight_decay = args.weight_decay,
        lr_scheduler_type = args.lr_scheduler_type,
        fp16 = args.fp16,
        bf16 = args.bf16,
        logging_steps = args.logging_step,
        optim = args.optim,
        gradient_checkpointing = True,
        evaluation_strategy = args.save_and_eval_strategy,
        save_strategy = args.save_and_eval_strategy,
        eval_steps = args.save_and_eval_steps,
        save_steps = args.save_and_eval_steps,
        output_dir = args.output_dir,
        save_total_limit = 5,
        load_best_model_at_end = True,
        deepspeed = args.deepspeed,
        ddp_find_unused_parameters = False if ddp else None,
        report_to = None,
        eval_delay = 1 if args.save_and_eval_strategy=="epoch" else 2000,
        dataloader_num_workers = args.dataloader_num_workers,
        dataloader_prefetch_factor = args.dataloader_prefetch_factor,
        remove_unused_columns = args.remove_unused_columns,
    ),
    tokenizer = tokenizer,
    data_collator = collator,
)
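# Disable the KV cache: it is not needed during training and conflicts with
# gradient checkpointing.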
rqllama.config.use_cache = False

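# torch.compile requires PyTorch 2.x and is not supported on Windows.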
if torch.__version__ >= "2" and sys.platform != "win32":
    rqllama = torch.compile(rqllama)

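# Run fine-tuning, optionally resuming from an existing checkpoint.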
trainer.train(resume_from_checkpoint = args.resume_from_checkpoint)

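# Save the trainer state and the fine-tuned (LoRA adapter) weights to output_dir.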
trainer.save_state()
trainer.save_model(output_dir = args.output_dir)

if local_rank == 0:
    print('rqllama fine-tune finished.')