# File size: 3,493 Bytes
from typing import Dict, List, Any
from transformers import T5Tokenizer, T5ForConditionalGeneration
import torch
import time
from transformers import (
AutoTokenizer,
AutoModelForSequenceClassification,
RobertaTokenizer,
RobertaForSequenceClassification,
)
import nltk
from nltk.tokenize import sent_tokenize
# Download NLTK's 'punkt' resource at import time; presumably this is the data
# that `sent_tokenize` (imported above) needs for sentence splitting.
# NOTE(review): this runs on every import of the module, and newer NLTK releases
# may expect the 'punkt_tab' resource instead — confirm against the pinned version.
nltk.download('punkt')
class EndpointHandler:
    """Request handler that paraphrases text with a DIPPER-style T5 model.

    The handler loads a tokenizer and a ``T5ForConditionalGeneration`` model
    from ``path`` and exposes them through ``__call__``, which turns a request
    payload into a ``{"prediction": <paraphrased text>}`` response.
    """

    # The only diversity values accepted by ``paraphrase``.
    _DIVERSITY_LEVELS = (0, 20, 40, 60, 80, 100)

    def __init__(self, path=""):
        """Load the tokenizer and the 8-bit model stored at ``path``."""
        self.tokenizer = AutoTokenizer.from_pretrained(path)
        self.model = T5ForConditionalGeneration.from_pretrained(
            path, device_map="auto", load_in_8bit=True
        )

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Paraphrase the text carried by a request payload.

        Args:
            data: Request payload. Options are read from ``data["inputs"]``
                when that key holds a mapping, otherwise from ``data`` itself.
                When ``data["inputs"]`` is a plain string it is used as the
                text to paraphrase and the remaining options are read from
                ``data``. Recognised options:

                * ``input_text`` (str, default ``""``)
                * ``lex_diversity`` (int, default ``80``)
                * ``order_diversity`` (int, default ``20``)
                * ``prefix`` (str, default ``""``)

        Returns:
            ``{"prediction": text}`` where ``text`` is the paraphrase.
        """
        # Read without ``pop`` so the caller's payload is left untouched.
        inputs = data.get("inputs", data)
        if isinstance(inputs, str):
            # Bare-string form: {"inputs": "some text", ...}.
            options = {**data, "input_text": inputs}
        else:
            options = inputs

        prediction = self.paraphrase(
            options.get("input_text", ""),
            options.get("lex_diversity", 80),
            options.get("order_diversity", 20),
            prefix=options.get("prefix", ""),
            do_sample=True,
            top_p=0.75,
            max_length=512,
        )
        return {'prediction': prediction}

    def paraphrase(
        self,
        input_text: str,
        lex_diversity: int,
        order_diversity: int,
        prefix: str = "",
        sent_interval: int = 3,
        **kwargs: Any,
    ) -> str:
        """Paraphrase a text using the DIPPER model.

        The text is split into sentences and paraphrased in windows of
        ``sent_interval`` sentences. Each window's output is appended to the
        prefix, so later windows are conditioned on what was generated before.

        Args:
            input_text: The text to paraphrase.
            lex_diversity: Lexical diversity of the output; one of
                0, 20, 40, 60, 80, 100 (0 = no diversity, 100 = maximum).
            order_diversity: Order diversity of the output; one of
                0, 20, 40, 60, 80, 100 (0 = no diversity, 100 = maximum).
            prefix: Optional context placed before the window being paraphrased.
            sent_interval: Number of sentences paraphrased per generation call.
            **kwargs: Forwarded to ``self.model.generate`` (e.g. ``top_p``,
                ``top_k``, ``max_length``).

        Returns:
            The generated windows concatenated, each preceded by a single
            space; ``""`` when ``input_text`` has no non-whitespace content.

        Raises:
            AssertionError: If either diversity value is not an allowed level.
        """
        # Raised explicitly (not via ``assert``) so the checks survive
        # ``python -O``; the exception type is kept for backward compatibility.
        if lex_diversity not in self._DIVERSITY_LEVELS:
            raise AssertionError("Lexical diversity must be one of 0, 20, 40, 60, 80, 100.")
        if order_diversity not in self._DIVERSITY_LEVELS:
            raise AssertionError("Order diversity must be one of 0, 20, 40, 60, 80, 100.")

        # The control codes in the prompt are the complement of the diversity.
        lex_code = int(100 - lex_diversity)
        order_code = int(100 - order_diversity)

        # Collapse every whitespace run (including newlines) to one space.
        input_text = " ".join(input_text.split())
        prefix = " ".join(prefix.split())
        if not input_text:
            # Nothing to paraphrase: skip the sentence tokenizer and the model.
            return ""

        sentences = sent_tokenize(input_text)
        control = f"lexical = {lex_code}, order = {order_code}"
        # Put inputs where the model lives instead of assuming CUDA, so the
        # handler also works on CPU-only hosts.
        device = self.model.device

        pieces = []
        for start in range(0, len(sentences), sent_interval):
            window = " ".join(sentences[start:start + sent_interval])

            prompt = control
            if prefix:
                prompt += f" {prefix}"
            prompt += f" <sent> {window} </sent>"

            encoded = self.tokenizer([prompt], return_tensors="pt")
            encoded = {key: tensor.to(device) for key, tensor in encoded.items()}

            with torch.inference_mode():
                generated = self.model.generate(**encoded, **kwargs)
            decoded = self.tokenizer.batch_decode(generated, skip_special_tokens=True)

            # Feed this window's output back in as context for the next one.
            prefix += " " + decoded[0]
            pieces.append(" " + decoded[0])

        return "".join(pieces)