Spaces:
Build error
Build error
| # -*- coding: utf-8 -*- | |
| ''' | |
| @Author : Jiangjie Chen | |
| @Time : 2020/9/20 11:42 | |
| @Contact : jjchen19@fudan.edu.cn | |
| @Description: | |
| ''' | |
| import torch | |
| from transformers import BertTokenizer | |
| from .retrieval_model.bert_model import BertForSequenceEncoder | |
| from .retrieval_model.models import inference_model | |
| from .retrieval_model.data_loader import DataLoaderTest | |
class SentSelector:
    '''Rank candidate evidence sentences with a BERT-based ranking model.

    Builds a `BertForSequenceEncoder`, wraps it in `inference_model`, restores
    the ranker weights from a checkpoint, and scores the inputs yielded by
    `DataLoaderTest`.
    '''

    def __init__(self, pretrained_bert_path, select_model_path, args,
                 tokenizer_path='bert-base-cased'):
        '''
        :param pretrained_bert_path: name or path passed to
            `BertForSequenceEncoder.from_pretrained` for the encoder weights.
        :param select_model_path: checkpoint file; its 'model' entry is loaded
            as the state dict of the ranking model.
        :param args: config namespace. This class reads `use_cuda` and
            `evi_num` from it and forwards it to `inference_model` and
            `DataLoaderTest`.
        :param tokenizer_path: name or path for `BertTokenizer.from_pretrained`;
            defaults to the previously hard-coded 'bert-base-cased'.
        '''
        self.args = args
        # Use the GPU only when it is requested AND actually available.
        self.use_cuda = self.args.use_cuda and torch.cuda.is_available()
        self.tokenizer = BertTokenizer.from_pretrained(tokenizer_path)
        self.bert_model = BertForSequenceEncoder.from_pretrained(pretrained_bert_path)
        self.rank_model = inference_model(self.bert_model, self.args)
        # Deserialize onto the CPU. Without map_location, a checkpoint saved
        # from a GPU run fails to load on a machine without CUDA. The model
        # has not been moved to the GPU yet at this point; that happens below
        # when enabled.
        checkpoint = torch.load(select_model_path, map_location='cpu')
        self.rank_model.load_state_dict(checkpoint['model'])
        if self.use_cuda:
            self.bert_model = self.bert_model.cuda()
            self.rank_model.cuda()

    def rank_sentences(self, js: list):
        '''Score all candidate evidence and keep the best `args.evi_num` per id.

        :param js: [{'claim': xxx, 'id': xx, 'evidence': xxx}]
        :return: dict mapping each id yielded by the data loader to a list of
            at most `args.evi_num` entries, sorted by score in descending order.
            Each entry is `tuple(evidence_item) + (prob,)`; according to the
            original docstring this is (ent, num, sent, prob).
        '''
        data_reader = DataLoaderTest(js, self.tokenizer, self.args, self.use_cuda)
        self.rank_model.eval()
        all_predict = {}
        # Pure inference: disable autograd so no computation graph is kept.
        with torch.no_grad():
            for inp_tensor, msk_tensor, seg_tensor, ids, evi_list in data_reader:
                probs = self.rank_model(inp_tensor, msk_tensor, seg_tensor).tolist()
                assert len(probs) == len(evi_list)
                for sample_id, evidence, prob in zip(ids, evi_list, probs):
                    all_predict.setdefault(sample_id, []).append(tuple(evidence) + (prob,))
        # sorted() is stable, so candidates with equal scores keep loader order.
        return {
            sample_id: sorted(candidates, key=lambda x: x[-1], reverse=True)[:self.args.evi_num]
            for sample_id, candidates in all_predict.items()
        }