Spaces:
Build error
Build error
| import logging | |
| import argparse | |
| import os | |
| import json | |
| import torch | |
| from tqdm import tqdm | |
| from transformers import BertTokenizer | |
| from .models import inference_model | |
| from .data_loader import DataLoaderTest | |
| from .bert_model import BertForSequenceEncoder | |
# Module-level logger; handlers/format are configured in the __main__ block via basicConfig.
logger = logging.getLogger(__name__)
def save_to_file(all_predict, outpath, evi_num):
    """Write the top-`evi_num` evidence entries per claim to `outpath` as JSON lines.

    Args:
        all_predict: dict mapping claim id -> list of evidence entries, where each
            entry is a list whose LAST element is the model's score (appended by
            eval_model).
        outpath: path of the output file (one JSON object per line).
        evi_num: number of highest-scoring evidence entries to keep per claim.
    """
    # Explicit UTF-8: default encoding is platform-dependent and can break
    # on non-ASCII evidence text under some locales.
    with open(outpath, "w", encoding="utf-8") as out:
        for key, values in all_predict.items():
            # Sort by the trailing score, highest first, then truncate.
            sorted_values = sorted(values, key=lambda x: x[-1], reverse=True)
            data = json.dumps({"id": key, "evidence": sorted_values[:evi_num]})
            out.write(data + "\n")
def eval_model(model, validset_reader):
    """Score every evidence candidate in the test reader with `model`.

    Args:
        model: scoring model; called as model(inp, msk, seg) and expected to
            return a tensor of per-example probabilities/scores.
        validset_reader: iterable yielding
            (inp_tensor, msk_tensor, seg_tensor, ids, evi_list) batches.

    Returns:
        dict mapping claim id -> list of `evidence + [score]` entries,
        later sorted/truncated by save_to_file.
    """
    model.eval()
    all_predict = {}
    # Inference only: no_grad() avoids building autograd graphs, cutting
    # memory use and speeding up the forward passes.
    with torch.no_grad():
        for inp_tensor, msk_tensor, seg_tensor, ids, evi_list in tqdm(validset_reader):
            probs = model(inp_tensor, msk_tensor, seg_tensor).tolist()
            assert len(probs) == len(evi_list)
            for claim_id, evidence, prob in zip(ids, evi_list, probs):
                # Group scored evidence under its claim id; score goes last so
                # save_to_file can sort on x[-1].
                all_predict.setdefault(claim_id, []).append(evidence + [prob])
    return all_predict
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument('--test_path', help='path to the test data file')
    parser.add_argument('--name', help='name of the output prediction file')
    parser.add_argument("--batch_size", default=32, type=int, help="Total batch size for evaluation.")
    parser.add_argument('--outdir', required=True, help='path to output directory')
    parser.add_argument('--bert_pretrain', required=True)
    parser.add_argument('--checkpoint', required=True)
    parser.add_argument('--dropout', type=float, default=0.6, help='Dropout.')
    parser.add_argument('--no-cuda', action='store_true', default=False, help='Disables CUDA training.')
    parser.add_argument("--bert_hidden_dim", default=768, type=int, help="Hidden dimension of the BERT encoder.")
    parser.add_argument("--layer", type=int, default=1, help='Graph Layer.')
    parser.add_argument("--num_labels", type=int, default=3)
    parser.add_argument("--evi_num", type=int, default=5, help='Evidence num.')
    parser.add_argument("--threshold", type=float, default=0.0, help='Evidence score threshold.')
    parser.add_argument("--max_len", default=120, type=int,
                        help="The maximum total input sequence length after WordPiece tokenization. Sequences "
                             "longer than this will be truncated, and sequences shorter than this will be padded.")
    args = parser.parse_args()

    # makedirs also creates missing parent directories (os.mkdir would raise).
    if not os.path.exists(args.outdir):
        os.makedirs(args.outdir)
    args.cuda = not args.no_cuda and torch.cuda.is_available()

    handlers = [logging.FileHandler(os.path.join(os.path.abspath(args.outdir), 'train_log.txt')),
                logging.StreamHandler()]
    logging.basicConfig(format='[%(asctime)s] %(levelname)s: %(message)s', level=logging.DEBUG,
                        datefmt='%d-%m-%Y %H:%M:%S', handlers=handlers)
    logger.info(args)
    logger.info('Start evaluation setup!')

    tokenizer = BertTokenizer.from_pretrained(args.bert_pretrain, do_lower_case=False)
    logger.info("loading test set")
    validset_reader = DataLoaderTest(args.test_path, tokenizer, args, batch_size=args.batch_size)

    logger.info('initializing estimator model')
    bert_model = BertForSequenceEncoder.from_pretrained(args.bert_pretrain)
    # Respect --no-cuda / missing GPU: the original called .cuda() unconditionally,
    # which crashes on CPU-only machines even though args.cuda was computed above.
    if args.cuda:
        bert_model = bert_model.cuda()
    model = inference_model(bert_model, args)
    # map_location lets a checkpoint saved on GPU be loaded on a CPU-only host.
    state = torch.load(args.checkpoint, map_location=None if args.cuda else 'cpu')
    model.load_state_dict(state['model'])
    if args.cuda:
        model = model.cuda()

    logger.info('Start eval!')
    save_path = os.path.join(args.outdir, args.name)
    predict_dict = eval_model(model, validset_reader)
    save_to_file(predict_dict, save_path, args.evi_num)