# -*- coding: utf-8 -*-

import sys

# `fla2` and the legacy training utilities are not installed as packages, so
# they must be reachable on sys.path before anything below imports them.
sys.path.append('/mnt/jfzn/msj/flash-linear-attention/legacy/training')
sys.path.append('/mnt/jfzn/msj/flash-linear-attention/legacy/training/fla2')

from datasets import concatenate_datasets, load_from_disk
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer, Trainer

from fla.models import (DeltaNetConfig, DeltaNetForCausalLM,
                        GatedDeltaNetConfig, GatedDeltaNetForCausalLM)
from fla2.models import (emglaConfig, emglaForCausalLM,
                         emlaConfig, emlaForCausalLM,
                         mask_deltanetConfig, mask_deltanetForCausalLM)
# Legacy fla1 variants, kept for reference:
# from fla1.models import GLAForCausalLM, GLAConfig, DENDATTNConfig, DENDForCausalLM, \
#     MomGLAConfig, MomGLAForCausalLM, GatedDeltaNetForCausalLM, GatedDeltaNetConfig
# from fla1.models import MomGatedDeltaNetForCausalLM, MomGatedDeltaNetConfig, \
#     MobGatedDeltaNetForCausalLM, MobGatedDeltaNetConfig
# from fla1.models import Mamba2Config, Mamba2ForCausalLM, MobMamba2Config, \
#     MobMamba2ForCausalLM, DeltaNetConfig, DeltaNetForCausalLM, DeltaNetModel
from flame.data import DataCollatorForLanguageModeling
from flame.logging import LogCallback, get_logger
from flame.parser import get_train_args

logger = get_logger(__name__)

# Register the custom architectures with the Auto* factories so that
# `AutoConfig`/`AutoModelForCausalLM` can resolve their `model_type` strings.
for model_type, config_cls, model_cls in [
    ("emla", emlaConfig, emlaForCausalLM),
    ("emgla", emglaConfig, emglaForCausalLM),
    ("mask_deltanet", mask_deltanetConfig, mask_deltanetForCausalLM),
    ("gated_deltanet", GatedDeltaNetConfig, GatedDeltaNetForCausalLM),
    ("delta_net", DeltaNetConfig, DeltaNetForCausalLM),
    # ("mamba2", Mamba2Config, Mamba2ForCausalLM),  # import from fla.models to enable
]:
    logger.info(f"Registering model type `{model_type}`")
    AutoConfig.register(model_type, config_cls)
    AutoModelForCausalLM.register(config_cls, model_cls)


def main():
    args = get_train_args()
    logger.info(args)

    tokenizer = AutoTokenizer.from_pretrained(
        args.tokenizer,
        use_fast=args.use_fast_tokenizer,
        trust_remote_code=True,
        add_bos_token=True,
        add_eos_token=False
    )
    if tokenizer.pad_token_id is None:
        tokenizer.pad_token = tokenizer.eos_token
        logger.info(f"Added pad token: {tokenizer.pad_token}")

    if args.from_config:
        logger.info("All model params are randomly initialized for from-scratch training.")
        model = AutoModelForCausalLM.from_config(AutoConfig.from_pretrained(args.model_name_or_path))
    else:
        logger.info(f"Loading pretrained checkpoint {args.model_name_or_path}")
        model = AutoModelForCausalLM.from_pretrained(args.model_name_or_path)
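
    # Illustrative addition (not part of the original flow): confirm that the
    # module-level registrations resolved the checkpoint's `model_type` to one
    # of the custom classes registered above.
    logger.info(f"Resolved {type(model).__name__} for model_type `{model.config.model_type}`")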
    model.train()

    trainable_params = model.num_parameters(only_trainable=True)
    all_param = model.num_parameters()
    logger.info(f"% of trainable params: {trainable_params:d} / {all_param:d} = {trainable_params / all_param:.2%}")
    logger.info(f"{tokenizer}\n{model}\n{model.config}")

    logger.info(f"Loading the `{args.split}` split directly from the cache {args.cache_dir}...")
    # `args.cache_dir` may name several pre-tokenized datasets separated by
    # commas; they are concatenated into a single training set.
    cache_dirs = args.cache_dir.split(',')
    if len(cache_dirs) > 1:
        dataset = concatenate_datasets([load_from_disk(path) for path in cache_dirs])
    else:
        dataset = load_from_disk(cache_dirs[0])
    logger.info(f"{dataset}")

    logger.info(f"Shuffling the dataset with seed {args.seed}")
    dataset = dataset.shuffle(seed=args.seed)

    logger.info("Creating the data collator")
    data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, varlen=args.varlen)
    logger.info(f"{data_collator}")

    if args.lr_scheduler_type == 'cosine_with_min_lr':
        args.lr_scheduler_kwargs = {'min_lr_rate': 0.1}
    if args.lr_scheduler_type == 'warmup_stable_decay':
        # 90% of the run (minus warmup) stays at the peak LR; the final 10% decays.
        args.lr_scheduler_kwargs = {
            'num_stable_steps': args.max_steps * 0.9 - args.warmup_steps,
            'num_decay_steps': args.max_steps * 0.1
        }

    trainer = Trainer(
        model=model,
        args=args,
        tokenizer=tokenizer,
        data_collator=data_collator,
        callbacks=[LogCallback()],
        train_dataset=dataset
    )

    results = trainer.train(resume_from_checkpoint=args.resume_from_checkpoint)
    trainer.save_model()
    tokenizer.save_pretrained(trainer.args.output_dir)
    trainer.log_metrics("train", results.metrics)
    trainer.save_metrics("train", results.metrics)
    trainer.save_state()


if __name__ == "__main__":
    main()
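
# Example launch, for illustration only. `flame.parser.get_train_args` defines
# the real CLI; the flag names below are assumed from the `args.*` attributes
# read above (e.g. `args.cache_dir` -> `--cache_dir`) and may differ, and the
# script/file names are placeholders.
#
#   torchrun --nproc_per_node=8 train.py \
#       --model_name_or_path <config-or-checkpoint> \
#       --from_config \
#       --tokenizer <tokenizer-path> \
#       --cache_dir /data/tokenized_a,/data/tokenized_b \
#       --lr_scheduler_type warmup_stable_decay \
#       --warmup_steps 1024 \
#       --max_steps 20480 \
#       --output_dir exp/delta_net \
#       --seed 42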