# Scraped from a Hugging Face Space file view (revision eeef81e, file size 8,392 bytes);
# the page status header and the line-number gutter have been removed.
import logging
# Configure logging once at import time: every record is emitted at INFO level
# or above with a "YYYY/MM/DD HH:MM:SS - LEVEL - logger name - message" layout.
logging.basicConfig(
format='%(asctime)s - %(levelname)s - %(name)s - %(message)s',
datefmt='%Y/%m/%d %H:%M:%S',
level=logging.INFO,
)
# Module-wide logger used by the functions in this script; it is named "Main".
logger = logging.getLogger("Main")
import os,random
import numpy as np
import torch
from processing import convert_examples_to_features, read_squad_examples
from processing import ChineseFullTokenizer
from pytorch_pretrained_bert.my_modeling import BertConfig
from optimization import BERTAdam
import config
from utils import read_and_convert, divide_parameters
from modeling import BertForQASimple, BertForQASimpleAdaptorTraining
from textbrewer import DistillationConfig, TrainingConfig, BasicTrainer
from torch.utils.data import TensorDataset, DataLoader, RandomSampler
from functools import partial
from train_eval import predict
def args_check(args):
    """Validate the run arguments and select the compute device.

    Reads ``output_dir``, ``gradient_accumulation_steps``, ``do_train``,
    ``do_predict``, ``local_rank`` and ``no_cuda`` from ``args``. As a side
    effect, stores the chosen device and GPU count on ``args`` as
    ``args.device`` and ``args.n_gpu``.

    Args:
        args: Namespace-like object holding the parsed command-line options.

    Returns:
        A ``(device, n_gpu)`` tuple: the ``torch.device`` to run on and the
        number of GPUs to use.

    Raises:
        ValueError: If ``gradient_accumulation_steps`` is smaller than 1, or
            if neither ``do_train`` nor ``do_predict`` is set.
    """
    # A non-empty output directory is only worth a warning: the run may
    # overwrite earlier results, but that can be intentional.
    if os.path.exists(args.output_dir) and os.listdir(args.output_dir):
        logger.warning("Output directory (%s) already exists and is not empty.",
                       args.output_dir)

    if args.gradient_accumulation_steps < 1:
        raise ValueError("Invalid gradient_accumulation_steps parameter: {}, should be >= 1".format(
            args.gradient_accumulation_steps))
    if not args.do_train and not args.do_predict:
        raise ValueError("At least one of `do_train` or `do_predict` must be True.")

    if args.local_rank == -1 or args.no_cuda:
        # Single-process mode: use every visible GPU unless CUDA is disabled
        # or unavailable.
        use_cuda = torch.cuda.is_available() and not args.no_cuda
        device = torch.device("cuda" if use_cuda else "cpu")
        n_gpu = torch.cuda.device_count() if not args.no_cuda else 0
    else:
        # Distributed mode: this process is bound to the single GPU whose
        # index is its local rank.
        device = torch.device("cuda", args.local_rank)
        n_gpu = 1
        torch.distributed.init_process_group(backend='nccl')

    logger.info("device %s n_gpu %d distributed training %r",
                device, n_gpu, bool(args.local_rank != -1))
    args.n_gpu = n_gpu
    args.device = device
    return device, n_gpu
def _seed_everything(seed):
    """Seed the PyTorch (CPU and CUDA), NumPy and ``random`` generators with ``seed``."""
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    np.random.seed(seed)
    random.seed(seed)


def _load_training_data(args, convert_fn):
    """Read ``args.train_file`` and append the optional ``fake_file_1``/``fake_file_2`` data.

    Returns:
        The ``(examples, features)`` pair for the combined training set.
    """
    read_train = partial(read_and_convert,
                         is_training=True,
                         do_lower_case=args.do_lower_case,
                         read_fn=read_squad_examples,
                         convert_fn=convert_fn)
    train_examples, train_features = read_train(args.train_file)
    # The extra files are optional; an unset (falsy) option is skipped.
    for option in ("fake_file_1", "fake_file_2"):
        extra_file = getattr(args, option)
        if extra_file:
            extra_examples, extra_features = read_train(extra_file)
            train_examples += extra_examples
            train_features += extra_features
    return train_examples, train_features


def _load_student_weights(model_S, args):
    """Initialise ``model_S`` from a checkpoint according to ``args.load_model_type``.

    * ``'bert'``: load only the entries of ``args.init_checkpoint_S`` whose key
      starts with ``'bert.'`` into ``model_S.bert``.
    * ``'all'``: load ``args.tuned_checkpoint_S`` into the whole model.
    * anything else: leave the model as constructed.

    Raises:
        ValueError: If the checkpoint path required by the chosen mode is None.
        RuntimeError: If the ``'bert'`` checkpoint lacks weights the encoder needs.
    """
    # NOTE(security): torch.load unpickles the file; only load checkpoints
    # that come from a trusted source.
    if args.load_model_type == 'bert':
        if args.init_checkpoint_S is None:
            raise ValueError("`init_checkpoint_S` must be set when load_model_type is 'bert'.")
        state_dict_S = torch.load(args.init_checkpoint_S, map_location='cpu')
        prefix = 'bert.'
        # Strip the prefix so the keys match the names inside `model_S.bert`.
        state_weight = {k[len(prefix):]: v for k, v in state_dict_S.items()
                        if k.startswith(prefix)}
        missing_keys, _ = model_S.bert.load_state_dict(state_weight, strict=False)
        if missing_keys:
            raise RuntimeError(
                "Checkpoint {} is missing encoder weights: {}".format(
                    args.init_checkpoint_S, list(missing_keys)))
    elif args.load_model_type == 'all':
        if args.tuned_checkpoint_S is None:
            raise ValueError("`tuned_checkpoint_S` must be set when load_model_type is 'all'.")
        state_dict_S = torch.load(args.tuned_checkpoint_S, map_location='cpu')
        model_S.load_state_dict(state_dict_S)
    else:
        logger.info("Model is randomly initialized.")


def _build_train_dataloader(train_features, batch_size):
    """Pack ``train_features`` into tensors and return a randomly sampled DataLoader.

    Each dataset item is ordered (input_ids, segment_ids, input_mask, doc_mask,
    start_position, end_position). The last incomplete batch is dropped.
    """
    def column(attr, dtype):
        return torch.tensor([getattr(f, attr) for f in train_features], dtype=dtype)

    train_dataset = TensorDataset(
        column("input_ids", torch.long),
        column("segment_ids", torch.long),
        column("input_mask", torch.long),
        column("doc_mask", torch.float),
        column("start_position", torch.long),
        column("end_position", torch.long),
    )
    train_sampler = RandomSampler(train_dataset)
    return DataLoader(train_dataset, sampler=train_sampler,
                      batch_size=batch_size, drop_last=True)


def main():
    """Train and/or evaluate the student QA model as directed by the parsed arguments.

    Steps, in order: parse and log the arguments, seed the random generators,
    validate the arguments and pick the device, read the training and/or
    evaluation data, build the model and load its checkpoint, then either
    train with ``BasicTrainer`` (when ``do_train`` is set) or run a single
    prediction pass (when only ``do_predict`` is set) and print its result.

    Raises:
        ValueError: If ``max_seq_length`` exceeds the model's
            ``max_position_embeddings`` or a required checkpoint path is missing.
        RuntimeError: If the encoder checkpoint lacks required weights.
        NotImplementedError: If distributed mode (``local_rank != -1``) is requested.
    """
    # Parse the arguments and record them in the log.
    config.parse()
    args = config.args
    for k, v in vars(args).items():
        logger.info("%s:%s", k, v)

    _seed_everything(args.random_seed)

    # Validate the arguments and pick the device.
    device, n_gpu = args_check(args)
    os.makedirs(args.output_dir, exist_ok=True)
    # Batch size of one forward pass: the training batch is split across the
    # gradient accumulation steps.
    forward_batch_size = int(args.train_batch_size / args.gradient_accumulation_steps)
    args.forward_batch_size = forward_batch_size

    # Load the student's BERT configuration.
    bert_config_S = BertConfig.from_json_file(args.bert_config_file_S)
    if args.max_seq_length > bert_config_S.max_position_embeddings:
        raise ValueError(
            "max_seq_length ({}) exceeds the model's max_position_embeddings ({})".format(
                args.max_seq_length, bert_config_S.max_position_embeddings))

    # Read the data.
    tokenizer = ChineseFullTokenizer(vocab_file=args.vocab_file,
                                     do_lower_case=args.do_lower_case)
    convert_fn = partial(convert_examples_to_features,
                         tokenizer=tokenizer,
                         max_seq_length=args.max_seq_length,
                         doc_stride=args.doc_stride,
                         max_query_length=args.max_query_length)

    train_examples = train_features = None
    eval_examples = eval_features = None
    num_train_steps = None
    if args.do_train:
        train_examples, train_features = _load_training_data(args, convert_fn)
        num_train_steps = int(len(train_features) / args.train_batch_size) * args.num_train_epochs
    if args.do_predict:
        eval_examples, eval_features = read_and_convert(
            args.predict_file, is_training=False, do_lower_case=args.do_lower_case,
            read_fn=read_squad_examples, convert_fn=convert_fn)

    # Build the student model and load its checkpoint.
    model_S = BertForQASimple(bert_config_S, args)
    _load_student_weights(model_S, args)
    model_S.to(device)

    if args.local_rank != -1:
        raise NotImplementedError("Distributed training (local_rank != -1) is not supported.")
    if n_gpu > 1:
        model_S = torch.nn.DataParallel(model_S)

    if args.do_train:
        # Parameters and optimizer.
        params = list(model_S.named_parameters())
        all_trainable_params = divide_parameters(params, lr=args.learning_rate)
        logger.info("Length of all_trainable_params: %d", len(all_trainable_params))

        optimizer = BERTAdam(all_trainable_params, lr=args.learning_rate,
                             warmup=args.warmup_proportion, t_total=num_train_steps,
                             schedule=args.schedule,
                             s_opt1=args.s_opt1, s_opt2=args.s_opt2, s_opt3=args.s_opt3)

        logger.info("***** Running training *****")
        logger.info(" Num orig examples = %d", len(train_examples))
        logger.info(" Num split examples = %d", len(train_features))
        logger.info(" Forward batch size = %d", forward_batch_size)
        logger.info(" Num backward steps = %d", num_train_steps)

        # Trainer setup.
        train_config = TrainingConfig(
            gradient_accumulation_steps=args.gradient_accumulation_steps,
            ckpt_frequency=args.ckpt_frequency,
            log_dir=args.output_dir,
            output_dir=args.output_dir,
            device=args.device)
        distiller = BasicTrainer(train_config=train_config,
                                 model=model_S,
                                 adaptor=BertForQASimpleAdaptorTraining)

        # Distributed runs were rejected above, so a plain random sampler
        # is always the right choice here.
        train_dataloader = _build_train_dataloader(train_features, args.forward_batch_size)

        # Evaluation data is only loaded under `do_predict`; without it there
        # is nothing to evaluate, so no callback is registered.
        callback_func = None
        if args.do_predict:
            callback_func = partial(predict,
                                    eval_examples=eval_examples,
                                    eval_features=eval_features,
                                    args=args)
        with distiller:
            distiller.train(optimizer, scheduler=None, dataloader=train_dataloader,
                            num_epochs=args.num_train_epochs, callback=callback_func)

    if not args.do_train and args.do_predict:
        res = predict(model_S, eval_examples, eval_features, step=0, args=args)
        print(res)
# Run the pipeline only when this file is executed as a script, not when it
# is imported as a module.
if __name__ == "__main__":
    main()
|