Spaces:
Build error
Build error
| # -*- coding: utf-8 -*- | |
| ''' | |
| @Author : Jiangjie Chen | |
| @Time : 2020/8/12 14:44 | |
| @Contact : jjchen19@fudan.edu.cn | |
| @Description: | |
| ''' | |
| import re | |
| import time | |
| from pathlib import Path | |
| from typing import Dict, List | |
| import torch | |
| from logging import getLogger | |
| from tqdm import tqdm | |
| from transformers import AutoModelForSeq2SeqLM, AutoTokenizer | |
| import ujson as json | |
| import random | |
| try: | |
| from .seq2seq.seq2seq_utils import ( | |
| use_task_specific_params, | |
| calculate_rouge, | |
| chunks, | |
| Seq2SeqDataset, | |
| lmap, | |
| load_json, | |
| save_json, | |
| ) | |
| except ImportError: | |
| import cjjpy as cjj | |
| import sys | |
| sys.path.append(cjj.AbsParentDir(__file__, '.')) | |
| from seq2seq.seq2seq_utils import ( | |
| use_task_specific_params, | |
| calculate_rouge, | |
| chunks, | |
| Seq2SeqDataset, | |
| lmap, | |
| load_json, | |
| save_json, | |
| ) | |
logger = getLogger(__name__)

# Run generation on GPU when available, otherwise fall back to CPU.
DEFAULT_DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Fixed seed so the random keep/mask decision in `assemble_answers_to_one`
# is reproducible across runs.
random.seed(1111)
def assemble_answers_to_one(js, k=5, mask_token='<mask>', mask_rate=0.):
    '''
    Fill each cloze question's mask slot with its top-k evidential answer
    spans and join the filled questions into one string per answer.

    :param js: example dict (or its JSON string) with keys 'cloze_qs',
               'evidential' and 'answers'; 'evidential' holds, per question,
               a list of candidate answer spans.
    :param k: max number of candidate spans substituted per question.
    :param mask_token: placeholder token expected inside each cloze question.
    :param mask_rate: probability of keeping the mask token instead of the
                      real answers (masking-style augmentation).
    :return: js with 'evidential_assembled' rebuilt, one entry per answer.
    '''
    if isinstance(js, str):
        js = json.loads(js)
    # One draw per example: either substitute real answers or keep the mask.
    should_keep = random.random() > mask_rate
    # Drop any stale assembly; the original `pop` without a default raised
    # KeyError when the key was absent.
    js.pop('evidential_assembled', None)
    for q, answers in zip(js['cloze_qs'], js['evidential']):
        if mask_token in q:
            s = q.find(mask_token)
            e = s + len(mask_token)
            nq_list = []
            # min() guards against fewer than k candidate spans (IndexError
            # in the original when answers was short).
            for i in range(min(k, len(answers))):
                answer_span = answers[i] if should_keep else mask_token
                nq_list.append(q[:s] + answer_span + q[e:])
            ev_nqs = ' '.join(nq_list)
            if js.get('evidential_assembled') is None:
                js['evidential_assembled'] = [ev_nqs]
            else:
                js['evidential_assembled'].append(ev_nqs)
    assert len(js['evidential_assembled']) == len(js['answers'])
    return js
class AnswerGenerator():
    """Seq2seq MRC answer generator wrapping a HuggingFace model.

    The tokenizer is loaded eagerly; the (heavy) model is loaded lazily on
    first use via `init_model`.
    """

    def __init__(self, model_name, device=DEFAULT_DEVICE):
        self.model_name = str(model_name)
        self.device = device
        self.model = None  # populated lazily by init_model()
        self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)

    def init_model(self):
        """Load the seq2seq model onto `self.device` (idempotent)."""
        if self.model is None:
            self.model = AutoModelForSeq2SeqLM.from_pretrained(self.model_name).to(self.device)

    def assemble(self, question, context):
        """Join question and context with the model-appropriate separator.

        unifiedqa checkpoints expect a newline separator; other models use
        the tokenizer's sep token.
        """
        sep = '\n' if 'unifiedqa' in self.tokenizer.name_or_path else self.tokenizer.sep_token
        return f'{question} {sep} {context}'

    def generate(self, examples, out_file=None, batch_size=16, verbose=True,
                 max_length=20, min_length=1, num_beams=4, num_return_sequences=4,
                 prefix=None, fp16=False, task='summarization', **generate_kwargs):
        '''
        Generate answers for a list of assembled inputs.

        :param examples: [N] input strings
        :param out_file: if given, hypotheses are streamed to this file and
                         the returned list is empty.
        :return: [N x num_return_seq] decoded hypotheses, chunked per input
        '''
        self.init_model()
        if fp16:
            self.model = self.model.half()
        # update config with summarization specific params
        use_task_specific_params(self.model, task)
        fout = None if out_file is None else Path(out_file).open("w", encoding="utf-8")
        generated = []
        # Renamed from `iter`, which shadowed the builtin.
        batches = list(chunks(examples, batch_size))
        if verbose:
            batches = tqdm(batches, desc="MRC")
        if prefix is None:
            # Fall back to the model's configured prefix, if any.
            prefix = getattr(self.model.config, "prefix", "") or ""
        try:
            for examples_chunk in batches:
                examples_chunk = [prefix + text for text in examples_chunk]
                batch = self.tokenizer(examples_chunk, return_tensors="pt", truncation=True,
                                       padding="longest").to(self.device)
                # Inference only: skip autograd bookkeeping.
                with torch.no_grad():
                    summaries = self.model.generate(
                        input_ids=batch.input_ids,
                        attention_mask=batch.attention_mask,
                        max_length=max_length,
                        min_length=min_length,
                        num_beams=num_beams,
                        num_return_sequences=num_return_sequences,
                        length_penalty=1.2,
                        repetition_penalty=1.2,
                        **generate_kwargs,
                    )
                dec = self.tokenizer.batch_decode(summaries, skip_special_tokens=True,
                                                  clean_up_tokenization_spaces=False)
                if fout is not None:
                    for hypothesis in dec:
                        fout.write(hypothesis.strip() + "\n")
                    fout.flush()
                else:
                    generated += dec
        finally:
            # Close the output file even if generation raises (the original
            # leaked the handle on error).
            if fout is not None:
                fout.close()
        generated = [g.strip() for g in generated]
        return list(chunks(generated, num_return_sequences))