Spaces:
Runtime error
Runtime error
| import numpy as np | |
| from math import exp | |
| import torch | |
| from torch import nn | |
| from transformers import BertTokenizer, BertForNextSentencePrediction | |
| import utils | |
| from maddog import Extractor | |
| import spacy | |
| import constant | |
# Shared resources, created once at import time and reused by every call below.
# spaCy pipeline used only to split input sentences into tokens.
nlp = spacy.load(name="en_core_web_sm")
# Rule-based extractor; called with constant.RULES to pair acronyms with in-text long forms.
ruleExtractor = Extractor()
# Acronym knowledge base consulted by utils.get_candidate for candidate expansions.
kb = utils.load_acronym_kb('acronym_kb.json')
# Path of the state dict loaded into the global AcronymBERT model.
model_path = "acrobert.pt"
class AcronymBERT(nn.Module):
    """Bundle a BERT next-sentence-prediction model with its tokenizer.

    Args:
        model_name: Hugging Face model id used for both model and tokenizer.
        device: device string that tokenized inputs are moved to in ``forward``.
            The module's own parameters are NOT moved here; callers do that
            with ``.to(device)``.
    """

    def __init__(self, model_name="bert-base-uncased", device='cpu'):
        super().__init__()
        self.device = device
        self.model = BertForNextSentencePrediction.from_pretrained(model_name)
        self.tokenizer = BertTokenizer.from_pretrained(model_name)

    def forward(self, sentence):
        """Score *sentence* (a string or list of strings) with the NSP head.

        Returns:
            1-D tensor holding the softmax probability of logit column 0 for
            each input. In Hugging Face's NSP head, column 0 is the
            "is next sentence" class; presumably the loaded checkpoint keeps
            that convention.
        """
        encoding = self.tokenizer(sentence, padding=True, return_tensors='pt', truncation=True)
        encoding = encoding.to(self.device)
        # Pass the full encoding, not just input_ids. Without attention_mask the
        # model attends to [PAD] tokens, which corrupts scores for the shorter
        # sequences in a padded batch.
        logits = self.model(**encoding).logits
        return torch.softmax(logits, dim=1)[:, 0]
# Global model used by acronym_linker's default argument: build the wrapper,
# then overwrite its weights with the state dict stored at model_path.
model = AcronymBERT(device='cpu')
# NOTE(review): torch.load unpickles the file; only point model_path at trusted checkpoints.
model.load_state_dict(torch.load(model_path, map_location='cpu'))
# The model is only used for inference, so switch every submodule to eval mode
# explicitly (disables dropout). Don't rely on from_pretrained() having
# implicitly put the inner BERT in eval mode.
model.eval()
def softmax(elements):
    """Return the softmax probability of the FIRST entry of *elements*.

    The max is subtracted before exponentiating. That leaves the result
    mathematically unchanged but avoids math.exp overflow on large logits.

    Args:
        elements: non-empty indexable sequence of real numbers (e.g. one row
            of logits).

    Raises:
        ValueError: if *elements* is empty.
    """
    # len() rather than truthiness so 1-D numpy rows work too.
    if len(elements) == 0:
        raise ValueError("softmax() requires at least one element")
    shift = max(elements)
    total = sum(exp(e - shift) for e in elements)
    return exp(elements[0] - shift) / total
def predict(topk, model, short_form, context, batch_size, acronym_kb, device):
    """Rank knowledge-base expansions of *short_form* against *context*.

    Candidates come from ``utils.get_candidate(..., can_num=20)``, presumably
    capped at 20. They are lower-cased for scoring and scored with
    ``cal_score`` using the wrapper's inner model and tokenizer.

    Returns:
        ``(names, confidences)``: up to ``topk`` candidates in their original
        casing, best score first, and their scores rounded to 3 decimals.
    """
    candidates = utils.get_candidate(acronym_kb, short_form, can_num=20)
    lowered = list(map(str.lower, candidates))
    scores = cal_score(model.model, model.tokenizer, lowered, context, batch_size, device)
    k = min(len(scores), topk)
    best_first = np.argsort(scores)[::-1][:k]
    names, confidences = [], []
    for idx in best_first:
        names.append(candidates[idx])
        confidences.append(round(scores[idx], 3))
    return names, confidences
def cal_score(model, tokenizer, long_forms, contexts, batch_size, device):
    """Score each candidate long form against a context with an NSP model.

    Each long form is paired with the same context string (long form first,
    context second). Pairs are tokenized in batches of *batch_size*, and each
    row of logits goes through ``softmax``, which keeps column 0's probability.

    Args:
        model: callable returning an object with ``.logits``, such as a
            Hugging Face NSP model.
        tokenizer: tokenizer matching *model*.
        long_forms: candidate expansions to score.
        contexts: the single context string shared by every pair.
        batch_size: number of pairs per forward pass.
        device: device the encoded batch is moved to.

    Returns:
        list of floats, one per entry of *long_forms*, in input order.
    """
    ps = []
    # Pure inference: skip autograd bookkeeping (saves memory and time).
    with torch.no_grad():
        for start in range(0, len(long_forms), batch_size):
            batch_lf = long_forms[start:start + batch_size]
            batch_ctx = [contexts] * len(batch_lf)
            encoding = tokenizer(batch_lf, batch_ctx, return_tensors="pt",
                                 padding=True, truncation=True, max_length=400).to(device)
            logits = model(**encoding).logits.cpu().numpy()
            ps.extend(softmax(row) for row in logits)
    return ps
def dog_extract(sentence):
    """Tokenize *sentence* with spaCy and run the rule-based extractor on it.

    Whitespace-only tokens are dropped before extraction. The return value is
    whatever ``ruleExtractor.extract`` produces. Elsewhere in this file it is
    used as a mapping from acronym to a sequence whose first item is the
    in-text long form, or ``''`` if none was found.
    """
    non_blank = []
    for token in nlp(sentence):
        if token.text.strip():
            non_blank.append(token.text)
    return ruleExtractor.extract(non_blank, constant.RULES)
def acrobert(sentence, model, device, topk=5, batch_size=10):
    """Link every acronym in *sentence*, falling back to the model when needed.

    An acronym whose long form the rule-based extractor found in the text
    keeps that long form. Otherwise the knowledge-base candidates are ranked
    with the model via ``predict``.

    Args:
        sentence: raw input text.
        model: AcronymBERT-style wrapper exposing ``.model`` and ``.tokenizer``.
        device: device the model is moved to before scoring.
        topk: maximum number of ranked candidates kept per acronym.
        batch_size: candidate pairs scored per forward pass.

    Returns:
        dict mapping each acronym either to a long-form string (rule-based
        hit) or to a list of ``(long_form, score)`` tuples, best first.
        NOTE(review): the value type differs between the two cases; callers
        must handle both.
    """
    model.to(device)
    results = {}
    for acronym, pair in dog_extract(sentence).items():
        long_form = pair[0]
        if long_form != '':
            results[acronym] = long_form
            continue
        names, confidences = predict(topk, model, acronym, sentence,
                                     batch_size=batch_size, acronym_kb=kb, device=device)
        results[acronym] = list(zip(names, confidences))
    return results
def popularity(sentence):
    """Link acronyms in *sentence* using the first knowledge-base candidate.

    Rule-based long forms found in the text win. Otherwise the first entry
    returned by ``utils.get_candidate(kb, acronym, can_num=1)`` is used,
    presumably the most popular expansion.

    Returns:
        list of ``(acronym, long_form)`` tuples in extraction order. The long
        form is ``''`` when neither the text nor the knowledge base provides
        one; before, that case raised IndexError.
    """
    results = []
    for acronym, pair in dog_extract(sentence).items():
        long_form = pair[0]
        if long_form == '':
            candidates = utils.get_candidate(kb, acronym, can_num=1)
            # Unknown acronyms may yield no candidates; don't crash the whole call.
            long_form = candidates[0] if candidates else ''
        results.append((acronym, long_form))
    return results
def acronym_linker(sentence, mode='acrobert', model=model, device='cpu'):
    """Public entry point: link acronyms in *sentence* with the chosen strategy.

    Args:
        sentence: raw input text.
        mode: ``'acrobert'`` for model-based ranking, or ``'pop'`` for the
            first knowledge-base candidate.
        model: model used in ``'acrobert'`` mode. Defaults to the module-level
            model, bound when this function is defined.
        device: device used in ``'acrobert'`` mode.

    Raises:
        ValueError: if *mode* is not one of the supported names.
    """
    if mode == 'acrobert':
        return acrobert(sentence, model, device)
    if mode == 'pop':
        return popularity(sentence)
    # ValueError subclasses Exception, so existing `except Exception` callers still work.
    raise ValueError(f"mode name should be in this list [acrobert, pop], got {mode!r}")
if __name__ == '__main__':
    # Manual smoke test: link the acronyms in one example sentence with the
    # default (model-based) mode and print the result.
    demo_sentence = """
AI is a wide-ranging branch of computer science concerned with building smart machines capable of performing tasks that typically require human intelligence.
"""
    linked = acronym_linker(demo_sentence)
    print(linked)