"""Train and/or evaluate the BertForQASimple student model.

Training is driven by textbrewer's ``BasicTrainer`` (only ``model_S`` is
passed to it); prediction is delegated to ``train_eval.predict``.
"""
import logging

# Logging is configured before the remaining imports, as in the original file.
logging.basicConfig(
    format='%(asctime)s - %(levelname)s - %(name)s - %(message)s',
    datefmt='%Y/%m/%d %H:%M:%S',
    level=logging.INFO,
)
logger = logging.getLogger("Main")

import os
import random
from functools import partial

import numpy as np
import torch
from torch.utils.data import TensorDataset, DataLoader, RandomSampler
from textbrewer import DistillationConfig, TrainingConfig, BasicTrainer

from processing import convert_examples_to_features, read_squad_examples
from processing import ChineseFullTokenizer
from pytorch_pretrained_bert.my_modeling import BertConfig
from optimization import BERTAdam
import config
from utils import read_and_convert, divide_parameters
from modeling import BertForQASimple, BertForQASimpleAdaptorTraining
from train_eval import predict


def args_check(args):
    """Validate the parsed arguments and select the compute device.

    Side effects: stores the chosen device and GPU count on ``args.device``
    and ``args.n_gpu``; when ``args.local_rank != -1`` (and CUDA is not
    disabled) it also initialises the ``nccl`` process group.

    Args:
        args: Parsed argument namespace; reads ``output_dir``,
            ``gradient_accumulation_steps``, ``do_train``, ``do_predict``,
            ``local_rank`` and ``no_cuda``.

    Returns:
        Tuple ``(device, n_gpu)``.

    Raises:
        ValueError: If ``gradient_accumulation_steps < 1`` or neither
            ``do_train`` nor ``do_predict`` is set.
    """
    if os.path.exists(args.output_dir) and os.listdir(args.output_dir):
        # The original message had an empty "()" placeholder and never
        # reported which directory was meant.
        logger.warning("Output directory (%s) already exists and is not empty.", args.output_dir)
    if args.gradient_accumulation_steps < 1:
        raise ValueError("Invalid gradient_accumulation_steps parameter: {}, should be >= 1".format(
            args.gradient_accumulation_steps))
    if not args.do_train and not args.do_predict:
        raise ValueError("At least one of `do_train` or `do_predict` must be True.")

    if args.local_rank == -1 or args.no_cuda:
        use_cuda = torch.cuda.is_available() and not args.no_cuda
        device = torch.device("cuda" if use_cuda else "cpu")
        n_gpu = torch.cuda.device_count() if not args.no_cuda else 0
    else:
        device = torch.device("cuda", args.local_rank)
        n_gpu = 1
        torch.distributed.init_process_group(backend='nccl')
    logger.info("device %s n_gpu %d distributed training %r", device, n_gpu, bool(args.local_rank != -1))

    args.n_gpu = n_gpu
    args.device = device
    return device, n_gpu


def _set_seeds(seed):
    """Seed torch (CPU and all CUDA devices), numpy and the stdlib RNG."""
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    np.random.seed(seed)
    random.seed(seed)


def _load_train_data(args, convert_fn):
    """Read the training file plus the optional extra files into one set.

    ``args.fake_file_1`` and ``args.fake_file_2`` are appended (in that
    order) to the examples/features of ``args.train_file`` when set.
    """
    read_training = partial(read_and_convert, is_training=True,
                            do_lower_case=args.do_lower_case,
                            read_fn=read_squad_examples, convert_fn=convert_fn)
    examples, features = read_training(args.train_file)
    for extra_file in (args.fake_file_1, args.fake_file_2):
        if extra_file:
            extra_examples, extra_features = read_training(extra_file)
            examples += extra_examples
            features += extra_features
    return examples, features


def _load_student_weights(model_S, args):
    """Initialise ``model_S`` according to ``args.load_model_type``.

    ``'bert'`` loads only the ``bert.``-prefixed weights into
    ``model_S.bert``; ``'all'`` loads a full state dict into ``model_S``;
    any other value leaves the model as constructed.
    """
    # NOTE(review): depending on the torch version, torch.load may unpickle
    # arbitrary objects -- only point these options at trusted checkpoints.
    if args.load_model_type == 'bert':
        assert args.init_checkpoint_S is not None, \
            "init_checkpoint_S is required when load_model_type == 'bert'"
        state_dict_S = torch.load(args.init_checkpoint_S, map_location='cpu')
        prefix = 'bert.'
        # Strip the prefix so the keys match the names inside model_S.bert.
        state_weight = {k[len(prefix):]: v for k, v in state_dict_S.items()
                        if k.startswith(prefix)}
        missing_keys, _ = model_S.bert.load_state_dict(state_weight, strict=False)
        assert len(missing_keys) == 0, \
            "checkpoint is missing encoder weights: {}".format(missing_keys)
    elif args.load_model_type == 'all':
        assert args.tuned_checkpoint_S is not None, \
            "tuned_checkpoint_S is required when load_model_type == 'all'"
        state_dict_S = torch.load(args.tuned_checkpoint_S, map_location='cpu')
        model_S.load_state_dict(state_dict_S)
    else:
        logger.info("Model is randomly initialized.")


def _build_train_dataloader(train_features, args):
    """Pack the training features into tensors and wrap them in a DataLoader."""
    def column(attr, dtype):
        return torch.tensor([getattr(f, attr) for f in train_features], dtype=dtype)

    all_input_ids = column('input_ids', torch.long)
    all_input_mask = column('input_mask', torch.long)
    all_doc_mask = column('doc_mask', torch.float)
    all_segment_ids = column('segment_ids', torch.long)
    all_start_positions = column('start_position', torch.long)
    all_end_positions = column('end_position', torch.long)

    train_dataset = TensorDataset(all_input_ids, all_segment_ids, all_input_mask, all_doc_mask,
                                  all_start_positions, all_end_positions)
    if args.local_rank == -1:
        train_sampler = RandomSampler(train_dataset)
    else:
        raise NotImplementedError
    return DataLoader(train_dataset, sampler=train_sampler,
                      batch_size=args.forward_batch_size, drop_last=True)


def _run_training(model_S, args, train_examples, train_features,
                  eval_examples, eval_features, num_train_steps):
    """Build optimizer, trainer and dataloader, then train ``model_S``."""
    params = list(model_S.named_parameters())
    all_trainable_params = divide_parameters(params, lr=args.learning_rate)
    logger.info("Length of all_trainable_params: %d", len(all_trainable_params))

    optimizer = BERTAdam(all_trainable_params, lr=args.learning_rate,
                         warmup=args.warmup_proportion, t_total=num_train_steps,
                         schedule=args.schedule,
                         s_opt1=args.s_opt1, s_opt2=args.s_opt2, s_opt3=args.s_opt3)

    logger.info("***** Running training *****")
    logger.info(" Num orig examples = %d", len(train_examples))
    logger.info(" Num split examples = %d", len(train_features))
    logger.info(" Forward batch size = %d", args.forward_batch_size)
    logger.info(" Num backward steps = %d", num_train_steps)

    train_config = TrainingConfig(
        gradient_accumulation_steps=args.gradient_accumulation_steps,
        ckpt_frequency=args.ckpt_frequency,
        log_dir=args.output_dir,
        output_dir=args.output_dir,
        device=args.device)
    distiller = BasicTrainer(train_config=train_config,
                             model=model_S,
                             adaptor=BertForQASimpleAdaptorTraining)

    train_dataloader = _build_train_dataloader(train_features, args)

    # NOTE(review): eval_examples/eval_features are None unless
    # args.do_predict is set -- confirm `predict` tolerates that when it is
    # invoked as the training callback.
    callback_func = partial(predict, eval_examples=eval_examples,
                            eval_features=eval_features, args=args)
    with distiller:
        distiller.train(optimizer, scheduler=None, dataloader=train_dataloader,
                        num_epochs=args.num_train_epochs, callback=callback_func)


def main():
    """Parse arguments, prepare data and model, then train and/or predict."""
    # parse arguments
    config.parse()
    args = config.args
    for k, v in vars(args).items():
        logger.info("%s:%s", k, v)

    # set seeds
    _set_seeds(args.random_seed)

    # arguments check
    device, n_gpu = args_check(args)
    os.makedirs(args.output_dir, exist_ok=True)
    forward_batch_size = int(args.train_batch_size / args.gradient_accumulation_steps)
    args.forward_batch_size = forward_batch_size

    # load bert config
    bert_config_S = BertConfig.from_json_file(args.bert_config_file_S)
    assert args.max_seq_length <= bert_config_S.max_position_embeddings, \
        "max_seq_length exceeds the model's max_position_embeddings"

    # read data
    train_examples = None
    train_features = None
    eval_examples = None
    eval_features = None
    num_train_steps = None

    tokenizer = ChineseFullTokenizer(vocab_file=args.vocab_file, do_lower_case=args.do_lower_case)
    convert_fn = partial(convert_examples_to_features,
                         tokenizer=tokenizer,
                         max_seq_length=args.max_seq_length,
                         doc_stride=args.doc_stride,
                         max_query_length=args.max_query_length)
    if args.do_train:
        train_examples, train_features = _load_train_data(args, convert_fn)
        num_train_steps = int(len(train_features) / args.train_batch_size) * args.num_train_epochs

    if args.do_predict:
        eval_examples, eval_features = read_and_convert(
            args.predict_file, is_training=False, do_lower_case=args.do_lower_case,
            read_fn=read_squad_examples, convert_fn=convert_fn)

    # build model and load checkpoint
    model_S = BertForQASimple(bert_config_S, args)
    _load_student_weights(model_S, args)
    model_S.to(device)

    if args.local_rank != -1:
        # Distributed training is not supported by this script.
        raise NotImplementedError
    if n_gpu > 1:
        model_S = torch.nn.DataParallel(model_S)

    if args.do_train:
        _run_training(model_S, args, train_examples, train_features,
                      eval_examples, eval_features, num_train_steps)

    if not args.do_train and args.do_predict:
        res = predict(model_S, eval_examples, eval_features, step=0, args=args)
        print(res)


if __name__ == "__main__":
    main()