# -*- coding: utf-8 -*-
import sys

# Make the legacy fla training packages importable before anything below imports them.
sys.path.append('/mnt/jfzn/msj/flash-linear-attention/legacy/training')
sys.path.append('/mnt/jfzn/msj/flash-linear-attention/legacy/training/fla2')

from datasets import concatenate_datasets, load_dataset, load_from_disk
from transformers import (AutoConfig, AutoModelForCausalLM, AutoTokenizer,
                          Trainer)

# import fla  # noqa
from flame.data import DataCollatorForLanguageModeling
from flame.logging import LogCallback, get_logger
from flame.parser import get_train_args
# logger = get_logger(__name__)
# from fla1.models import GLAForCausalLM, GLAConfig, DENDATTNConfig, DENDForCausalLM, MomGLAConfig, MomGLAForCausalLM, GatedDeltaNetForCausalLM, GatedDeltaNetConfig
# from fla1.models import MomGatedDeltaNetForCausalLM,MomGatedDeltaNetConfig,MobGatedDeltaNetForCausalLM,MobGatedDeltaNetConfig
# from fla1.models import Mamba2Config,Mamba2ForCausalLM,MobMamba2Config,MobMamba2ForCausalLM,DeltaNetConfig,DeltaNetForCausalLM,DeltaNetModel
from fla2.models import (emlaConfig, emlaForCausalLM, emglaConfig,
                         emglaForCausalLM, mask_deltanetConfig,
                         mask_deltanetForCausalLM)
# # breakpoint()
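
# Register each custom architecture with the Hugging Face Auto* registries so
# that AutoConfig / AutoModelForCausalLM can instantiate these models from
# their `model_type` strings.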
print(emlaConfig.model_type)
AutoConfig.register("emla", emlaConfig)
AutoModelForCausalLM.register(emlaConfig, emlaForCausalLM)

print(emglaConfig.model_type)
AutoConfig.register("emgla", emglaConfig)
AutoModelForCausalLM.register(emglaConfig, emglaForCausalLM)

print(mask_deltanetConfig.model_type)
AutoConfig.register("mask_deltanet", mask_deltanetConfig)
AutoModelForCausalLM.register(mask_deltanetConfig, mask_deltanetForCausalLM)
from fla.models import GatedDeltaNetConfig, GatedDeltaNetForCausalLM, GatedDeltaNetModel
logger = get_logger(__name__)
# breakpoint()
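# GatedDeltaNet and DeltaNet ship with the installed `fla` package and are
# registered the same way.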
print(GatedDeltaNetConfig.model_type)
AutoConfig.register("gated_deltanet",GatedDeltaNetConfig)
AutoModelForCausalLM.register(GatedDeltaNetConfig,GatedDeltaNetForCausalLM)
from fla.models import DeltaNetConfig, DeltaNetForCausalLM, DeltaNetModel
print(DeltaNetConfig.model_type)
AutoConfig.register("delta_net",DeltaNetConfig)
AutoModelForCausalLM.register(DeltaNetConfig,DeltaNetForCausalLM)
# from fla.models import Mamba2Config, Mamba2ForCausalLM, Mamba2Model
# print(Mamba2Config.model_type)
# AutoConfig.register("mamba2",Mamba2Config)
# AutoModelForCausalLM.register(Mamba2Config,Mamba2ForCausalLM)
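

# Training entry point: parse CLI arguments, build the tokenizer, model,
# dataset and collator, then hand everything to the HF Trainer.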
def main():
    args = get_train_args()
    print(args)
    logger.info(args)

    tokenizer = AutoTokenizer.from_pretrained(
        args.tokenizer,
        use_fast=args.use_fast_tokenizer,
        trust_remote_code=True,
        add_bos_token=True,
        add_eos_token=False
    )
    if tokenizer.pad_token_id is None:
        tokenizer.pad_token = tokenizer.eos_token
        logger.info("Add pad token: {}".format(tokenizer.pad_token))
    if args.from_config:
        logger.info("All model params are randomly initialized for from-scratch training.")
        model = AutoModelForCausalLM.from_config(AutoConfig.from_pretrained(args.model_name_or_path))
        # print(model)
        # model = AutoModelForCausalLM.from_pretrained('/mnt/jfzn/msj/delta_net-1.3B-100B_cp')
    else:
        logger.info(f"Loading pretrained checkpoint {args.model_name_or_path}")
        model = AutoModelForCausalLM.from_pretrained(args.model_name_or_path)
    model.train()

    trainable_params, all_param = model.num_parameters(only_trainable=True), model.num_parameters()
    logger.info(f"% of trainable params: {trainable_params:d} / {all_param:d} = {trainable_params / all_param:.2%}")
    logger.info(f"{tokenizer}\n{model}\n{model.config}")
    print(f"% of trainable params: {trainable_params:d} / {all_param:d} = {trainable_params / all_param:.2%}")
logger.info(f"Loading the `{args.split}` split directly from the cache {args.cache_dir}...")
cache_dir = args.cache_dir.split(',')
if len(cache_dir)>1:
dataset = [load_from_disk(path) for path in cache_dir]
dataset = concatenate_datasets(dataset)
else:
dataset = load_from_disk(cache_dir[0])
logger.info(f"{dataset}")
logger.info(f"Shuffling the dataset with seed {args.seed}")
dataset = dataset.shuffle(seed=args.seed)
logger.info("Creating the data collator")
data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, varlen=args.varlen)
logger.info(f"{data_collator}")
    if args.lr_scheduler_type == 'cosine_with_min_lr':
        args.lr_scheduler_kwargs = {'min_lr_rate': 0.1}
    if args.lr_scheduler_type == 'warmup_stable_decay':
        args.lr_scheduler_kwargs = {
            'num_stable_steps': args.max_steps * 0.9 - args.warmup_steps,
            'num_decay_steps': args.max_steps * 0.1
        }
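
    # Wire everything into the HF Trainer and launch training, optionally
    # resuming from an existing checkpoint.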
    trainer = Trainer(
        model=model,
        args=args,
        tokenizer=tokenizer,
        data_collator=data_collator,
        callbacks=[LogCallback()],
        train_dataset=dataset
    )
    results = trainer.train(resume_from_checkpoint=args.resume_from_checkpoint)

    # Persist the final model, tokenizer, metrics and trainer state.
    trainer.save_model()
    tokenizer.save_pretrained(trainer.args.output_dir)
    trainer.log_metrics("train", results.metrics)
    trainer.save_metrics("train", results.metrics)
    trainer.save_state()


if __name__ == "__main__":
    main()