# -*- coding: utf-8 -*-

import sys

from datasets import concatenate_datasets, load_from_disk
from transformers import (AutoConfig, AutoModelForCausalLM, AutoTokenizer,
                          Trainer)

from flame.data import DataCollatorForLanguageModeling
from flame.logging import LogCallback, get_logger
from flame.parser import get_train_args

# Make the legacy fla2 model definitions importable before the imports below.
sys.path.append('/mnt/jfzn/msj/flash-linear-attention/legacy/training')
sys.path.append('/mnt/jfzn/msj/flash-linear-attention/legacy/training/fla2')
logger = get_logger(__name__)

# Custom architectures from the legacy fla2 package used in these experiments.
from fla2.models import (emlaConfig, emlaForCausalLM,
                         emglaConfig, emglaForCausalLM,
                         mask_deltanetConfig, mask_deltanetForCausalLM)
# Reference architectures shipped with fla.
from fla.models import (DeltaNetConfig, DeltaNetForCausalLM,
                        GatedDeltaNetConfig, GatedDeltaNetForCausalLM)

# Register each architecture with the transformers Auto* factories so that
# AutoConfig / AutoModelForCausalLM can resolve them by their `model_type`.
AutoConfig.register("emla", emlaConfig)
AutoModelForCausalLM.register(emlaConfig, emlaForCausalLM)

AutoConfig.register("emgla", emglaConfig)
AutoModelForCausalLM.register(emglaConfig, emglaForCausalLM)

AutoConfig.register("mask_deltanet", mask_deltanetConfig)
AutoModelForCausalLM.register(mask_deltanetConfig, mask_deltanetForCausalLM)

AutoConfig.register("gated_deltanet", GatedDeltaNetConfig)
AutoModelForCausalLM.register(GatedDeltaNetConfig, GatedDeltaNetForCausalLM)

AutoConfig.register("delta_net", DeltaNetConfig)
AutoModelForCausalLM.register(DeltaNetConfig, DeltaNetForCausalLM)

# Other legacy architectures (e.g. the GLA, Mamba2, and MoM variants in fla1)
# can be registered the same way when needed.
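# Illustrative sketch only (the path is hypothetical): once registered, any of
# these models can be instantiated from a config whose `model_type` matches, e.g.
#   config = AutoConfig.from_pretrained('/path/to/delta_net/config')
#   model = AutoModelForCausalLM.from_config(config)

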
def main():
    args = get_train_args()
    logger.info(args)

    tokenizer = AutoTokenizer.from_pretrained(
        args.tokenizer,
        use_fast=args.use_fast_tokenizer,
        trust_remote_code=True,
        add_bos_token=True,
        add_eos_token=False
    )
    if tokenizer.pad_token_id is None:
        tokenizer.pad_token = tokenizer.eos_token
        logger.info(f"No pad token found; using the EOS token for padding: {tokenizer.pad_token}")
    if args.from_config:
        logger.info("All model params are randomly initialized for from-scratch training.")
        model = AutoModelForCausalLM.from_config(AutoConfig.from_pretrained(args.model_name_or_path))
    else:
        logger.info(f"Loading pretrained checkpoint {args.model_name_or_path}")
        model = AutoModelForCausalLM.from_pretrained(args.model_name_or_path)
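    # Either way, `args.model_name_or_path` should point at a local directory
    # (or hub repo) whose config.json declares one of the `model_type` values
    # registered above.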
    model.train()

    trainable_params, all_param = model.num_parameters(only_trainable=True), model.num_parameters()
    logger.info(f"% of trainable params: {trainable_params:d} / {all_param:d} = {trainable_params / all_param:.2%}")
    logger.info(f"{tokenizer}\n{model}\n{model.config}")
    

    print(f"% of trainable params: {trainable_params:d} / {all_param:d} = {trainable_params / all_param:.2%}")
    logger.info(f"Loading the `{args.split}` split directly from the cache {args.cache_dir}...")
    
    cache_dir = args.cache_dir.split(',')
    if len(cache_dir) > 1:
        dataset = [load_from_disk(path) for path in cache_dir]
        dataset = concatenate_datasets(dataset)
    else:
        dataset = load_from_disk(cache_dir[0])
    logger.info(f"{dataset}")
    logger.info(f"Shuffling the dataset with seed {args.seed}")
    dataset = dataset.shuffle(seed=args.seed)
    logger.info("Creating the data collator")
    data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, varlen=args.varlen)
    logger.info(f"{data_collator}")

    if args.lr_scheduler_type == 'cosine_with_min_lr':
        args.lr_scheduler_kwargs = {'min_lr_rate': 0.1}
    if args.lr_scheduler_type == 'warmup_stable_decay':
        args.lr_scheduler_kwargs = {
            'num_stable_steps': args.max_steps * 0.9 - args.warmup_steps,
            'num_decay_steps': args.max_steps * 0.1
        }
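    # The defaults above give cosine a floor of 10% of the peak LR, and give WSD
    # a stable phase up to 90% of max_steps followed by decay over the final 10%.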

    trainer = Trainer(
        model=model,
        args=args,
        tokenizer=tokenizer,
        data_collator=data_collator,
        callbacks=[LogCallback()],
        train_dataset=dataset
    )

    results = trainer.train(resume_from_checkpoint=args.resume_from_checkpoint)
    trainer.save_model()
    tokenizer.save_pretrained(trainer.args.output_dir)

    trainer.log_metrics("train", results.metrics)
    trainer.save_metrics("train", results.metrics)
    trainer.save_state()


if __name__ == "__main__":
    main()