codewraith / data /source_files /clean /06e16975d4ac.py
slenk's picture
Upload folder using huggingface_hub
eeef81e verified
import logging
logging.basicConfig(
format='%(asctime)s - %(levelname)s - %(name)s - %(message)s',
datefmt='%Y/%m/%d %H:%M:%S',
level=logging.INFO,
)
logger = logging.getLogger("Main")
import os,random
import numpy as np
import torch
from processing import convert_examples_to_features, read_squad_examples
from processing import ChineseFullTokenizer
from pytorch_pretrained_bert.my_modeling import BertConfig
from optimization import BERTAdam
import config
from utils import read_and_convert, divide_parameters
from modeling import BertForQASimple, BertForQASimpleAdaptorTraining
from textbrewer import DistillationConfig, TrainingConfig, BasicTrainer
from torch.utils.data import TensorDataset, DataLoader, RandomSampler
from functools import partial
from train_eval import predict
def args_check(args):
if os.path.exists(args.output_dir) and os.listdir(args.output_dir):
logger.warning("Output directory () already exists and is not empty.")
if args.gradient_accumulation_steps < 1:
raise ValueError("Invalid gradient_accumulation_steps parameter: {}, should be >= 1".format(
args.gradient_accumulation_steps))
if not args.do_train and not args.do_predict:
raise ValueError("At least one of `do_train` or `do_predict` must be True.")
if args.local_rank == -1 or args.no_cuda:
device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu")
n_gpu = torch.cuda.device_count() if not args.no_cuda else 0
else:
device = torch.device("cuda", args.local_rank)
n_gpu = 1
torch.distributed.init_process_group(backend='nccl')
logger.info("device %s n_gpu %d distributed training %r", device, n_gpu, bool(args.local_rank != -1))
args.n_gpu = n_gpu
args.device = device
return device, n_gpu
def main():
#parse arguments
config.parse()
args = config.args
for k,v in vars(args).items():
logger.info(f"{k}:{v}")
#set seeds
torch.manual_seed(args.random_seed)
torch.cuda.manual_seed_all(args.random_seed)
np.random.seed(args.random_seed)
random.seed(args.random_seed)
#arguments check
device, n_gpu = args_check(args)
os.makedirs(args.output_dir, exist_ok=True)
forward_batch_size = int(args.train_batch_size / args.gradient_accumulation_steps)
args.forward_batch_size = forward_batch_size
#load bert config
bert_config_S = BertConfig.from_json_file(args.bert_config_file_S)
assert args.max_seq_length <= bert_config_S.max_position_embeddings
#read data
train_examples = None
train_features = None
eval_examples = None
eval_features = None
num_train_steps = None
tokenizer = ChineseFullTokenizer(vocab_file=args.vocab_file, do_lower_case=args.do_lower_case)
convert_fn = partial(convert_examples_to_features,
tokenizer=tokenizer,
max_seq_length=args.max_seq_length,
doc_stride=args.doc_stride,
max_query_length=args.max_query_length)
if args.do_train:
train_examples,train_features = read_and_convert(args.train_file,is_training=True, do_lower_case=args.do_lower_case,
read_fn=read_squad_examples,convert_fn=convert_fn)
if args.fake_file_1:
fake_examples1,fake_features1 = read_and_convert(args.fake_file_1,is_training=True, do_lower_case=args.do_lower_case,
read_fn=read_squad_examples,convert_fn=convert_fn)
train_examples += fake_examples1
train_features += fake_features1
if args.fake_file_2:
fake_examples2, fake_features2 = read_and_convert(args.fake_file_2,is_training=True, do_lower_case=args.do_lower_case,
read_fn=read_squad_examples,convert_fn=convert_fn)
train_examples += fake_examples2
train_features += fake_features2
num_train_steps = int(len(train_features)/args.train_batch_size) * args.num_train_epochs
if args.do_predict:
eval_examples,eval_features = read_and_convert(args.predict_file,is_training=False, do_lower_case=args.do_lower_case,
read_fn=read_squad_examples,convert_fn=convert_fn)
#Build Model and load checkpoint
model_S = BertForQASimple(bert_config_S,args)
#Load student
if args.load_model_type=='bert':
assert args.init_checkpoint_S is not None
state_dict_S = torch.load(args.init_checkpoint_S, map_location='cpu')
state_weight = {k[5:]:v for k,v in state_dict_S.items() if k.startswith('bert.')}
missing_keys,_ = model_S.bert.load_state_dict(state_weight,strict=False)
assert len(missing_keys)==0
elif args.load_model_type=='all':
assert args.tuned_checkpoint_S is not None
state_dict_S = torch.load(args.tuned_checkpoint_S,map_location='cpu')
model_S.load_state_dict(state_dict_S)
else:
logger.info("Model is randomly initialized.")
model_S.to(device)
if args.local_rank != -1 or n_gpu > 1:
if args.local_rank != -1:
raise NotImplementedError
elif n_gpu > 1:
model_S = torch.nn.DataParallel(model_S) #,output_device=n_gpu-1)
if args.do_train:
#parameters
params = list(model_S.named_parameters())
all_trainable_params = divide_parameters(params, lr=args.learning_rate)
logger.info("Length of all_trainable_params: %d", len(all_trainable_params))
optimizer = BERTAdam(all_trainable_params,lr=args.learning_rate,
warmup=args.warmup_proportion,t_total=num_train_steps,schedule=args.schedule,
s_opt1=args.s_opt1, s_opt2=args.s_opt2, s_opt3=args.s_opt3)
logger.info("***** Running training *****")
logger.info(" Num orig examples = %d", len(train_examples))
logger.info(" Num split examples = %d", len(train_features))
logger.info(" Forward batch size = %d", forward_batch_size)
logger.info(" Num backward steps = %d", num_train_steps)
########### DISTILLATION ###########
train_config = TrainingConfig(
gradient_accumulation_steps = args.gradient_accumulation_steps,
ckpt_frequency = args.ckpt_frequency,
log_dir = args.output_dir,
output_dir = args.output_dir,
device = args.device)
distiller = BasicTrainer(train_config = train_config,
model = model_S,
adaptor = BertForQASimpleAdaptorTraining)
all_input_ids = torch.tensor([f.input_ids for f in train_features], dtype=torch.long)
all_input_mask = torch.tensor([f.input_mask for f in train_features], dtype=torch.long)
all_doc_mask = torch.tensor([f.doc_mask for f in train_features], dtype=torch.float)
all_segment_ids = torch.tensor([f.segment_ids for f in train_features], dtype=torch.long)
all_start_positions = torch.tensor([f.start_position for f in train_features], dtype=torch.long)
all_end_positions = torch.tensor([f.end_position for f in train_features], dtype=torch.long)
train_dataset = TensorDataset(all_input_ids, all_segment_ids, all_input_mask, all_doc_mask,
all_start_positions, all_end_positions)
if args.local_rank == -1:
train_sampler = RandomSampler(train_dataset)
else:
raise NotImplementedError
train_dataloader = DataLoader(train_dataset, sampler=train_sampler, batch_size=args.forward_batch_size,drop_last=True)
callback_func = partial(predict,
eval_examples=eval_examples,
eval_features=eval_features,
args=args)
with distiller:
distiller.train(optimizer, scheduler=None, dataloader=train_dataloader,
num_epochs=args.num_train_epochs, callback=callback_func)
if not args.do_train and args.do_predict:
res = predict(model_S,eval_examples,eval_features,step=0,args=args)
print (res)
if __name__ == "__main__":
main()