# -*- coding: utf-8 -*-
import sys

# The custom `fla2` model package lives in the legacy training tree; make it
# importable before the model imports below.
sys.path.append('/mnt/jfzn/msj/flash-linear-attention/legacy/training')
sys.path.append('/mnt/jfzn/msj/flash-linear-attention/legacy/training/fla2')

from datasets import concatenate_datasets, load_from_disk
from transformers import (AutoConfig, AutoModelForCausalLM, AutoTokenizer,
                          Trainer)

from flame.data import DataCollatorForLanguageModeling
from flame.logging import LogCallback, get_logger
from flame.parser import get_train_args

from fla.models import (DeltaNetConfig, DeltaNetForCausalLM,
                        GatedDeltaNetConfig, GatedDeltaNetForCausalLM)
from fla2.models import (emglaConfig, emglaForCausalLM, emlaConfig,
                         emlaForCausalLM, mask_deltanetConfig,
                         mask_deltanetForCausalLM)

logger = get_logger(__name__)
# Register every custom architecture with the transformers Auto* factories so
# that configs carrying these `model_type` strings resolve to the right classes.
AutoConfig.register("emla", emlaConfig)
AutoModelForCausalLM.register(emlaConfig, emlaForCausalLM)

AutoConfig.register("emgla", emglaConfig)
AutoModelForCausalLM.register(emglaConfig, emglaForCausalLM)

AutoConfig.register("mask_deltanet", mask_deltanetConfig)
AutoModelForCausalLM.register(mask_deltanetConfig, mask_deltanetForCausalLM)

AutoConfig.register("gated_deltanet", GatedDeltaNetConfig)
AutoModelForCausalLM.register(GatedDeltaNetConfig, GatedDeltaNetForCausalLM)

AutoConfig.register("delta_net", DeltaNetConfig)
AutoModelForCausalLM.register(DeltaNetConfig, DeltaNetForCausalLM)
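
# With the registrations above, the Auto* factories can resolve these
# architectures from a checkpoint's `model_type` string. A minimal sketch,
# assuming a hypothetical checkpoint directory `ckpt/` whose config.json
# carries `"model_type": "gated_deltanet"`:
#
#   config = AutoConfig.from_pretrained('ckpt/')      # -> GatedDeltaNetConfig
#   model = AutoModelForCausalLM.from_config(config)  # -> GatedDeltaNetForCausalLM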


def main():
    args = get_train_args()
    logger.info(args)

    tokenizer = AutoTokenizer.from_pretrained(
        args.tokenizer,
        use_fast=args.use_fast_tokenizer,
        trust_remote_code=True,
        add_bos_token=True,
        add_eos_token=False
    )
    if tokenizer.pad_token_id is None:
        # Fall back to the EOS token when the tokenizer defines no pad token.
        tokenizer.pad_token = tokenizer.eos_token
        logger.info(f"Added pad token: {tokenizer.pad_token}")

    if args.from_config:
        logger.info("All model params are randomly initialized for from-scratch training.")
        model = AutoModelForCausalLM.from_config(AutoConfig.from_pretrained(args.model_name_or_path))
    else:
        logger.info(f"Loading pretrained checkpoint {args.model_name_or_path}")
        model = AutoModelForCausalLM.from_pretrained(args.model_name_or_path)
    model.train()

    trainable_params, all_param = model.num_parameters(only_trainable=True), model.num_parameters()
    logger.info(f"% of trainable params: {trainable_params:d} / {all_param:d} = {trainable_params / all_param:.2%}")
    logger.info(f"{tokenizer}\n{model}\n{model.config}")

    logger.info(f"Loading the `{args.split}` split directly from the cache {args.cache_dir}...")
    cache_dir = args.cache_dir.split(',')
    if len(cache_dir) > 1:
        dataset = concatenate_datasets([load_from_disk(path) for path in cache_dir])
    else:
        dataset = load_from_disk(cache_dir[0])
    logger.info(f"{dataset}")

    logger.info(f"Shuffling the dataset with seed {args.seed}")
    dataset = dataset.shuffle(seed=args.seed)

    logger.info("Creating the data collator")
    # flame's collator; `varlen=True` presumably packs variable-length
    # sequences instead of padding every example to a fixed length.
    data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, varlen=args.varlen)
logger.info(f"{data_collator}")
if args.lr_scheduler_type == 'cosine_with_min_lr':
args.lr_scheduler_kwargs = {'min_lr_rate': 0.1}
if args.lr_scheduler_type == 'warmup_stable_decay':
args.lr_scheduler_kwargs = {
'num_stable_steps': args.max_steps * 0.9 - args.warmup_steps,
'num_decay_steps': args.max_steps * 0.1
}
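    # Worked example with hypothetical numbers: max_steps=10000 and
    # warmup_steps=1000 give 1000 warmup steps,
    # 10000 * 0.9 - 1000 = 8000 stable steps, and 10000 * 0.1 = 1000 decay steps.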
    trainer = Trainer(
        model=model,
        args=args,
        tokenizer=tokenizer,
        data_collator=data_collator,
        callbacks=[LogCallback()],
        train_dataset=dataset
    )
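    # Note: recent transformers releases deprecate passing `tokenizer=` to
    # Trainer in favor of `processing_class=`; `tokenizer=` is kept here to
    # match the original script.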
    results = trainer.train(resume_from_checkpoint=args.resume_from_checkpoint)

    # Persist everything needed to resume or evaluate: weights, tokenizer,
    # metrics, and trainer state.
    trainer.save_model()
    tokenizer.save_pretrained(trainer.args.output_dir)
    trainer.log_metrics("train", results.metrics)
    trainer.save_metrics("train", results.metrics)
    trainer.save_state()


if __name__ == "__main__":
    main()
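
# A minimal sketch of how this script might be launched; the exact flag names
# come from flame's `get_train_args`, and all paths below are hypothetical:
#
#   torchrun --nproc_per_node=8 train.py \
#     --model_name_or_path configs/gated_deltanet.json \
#     --tokenizer path/to/tokenizer \
#     --cache_dir /data/tokenized/slimpajama \
#     --from_config \
#     --lr_scheduler_type warmup_stable_decay \
#     --max_steps 10000 --warmup_steps 1000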