from typing import Any, Dict, List

import time

import torch
import nltk
from nltk.tokenize import sent_tokenize
from transformers import (
    AutoModelForSequenceClassification,
    AutoTokenizer,
    RobertaForSequenceClassification,
    RobertaTokenizer,
    T5ForConditionalGeneration,
    T5Tokenizer,
)

# The Punkt sentence tokenizer model is needed by sent_tokenize() below.
nltk.download('punkt')


class EndpointHandler:
    """Inference-endpoint handler for a DIPPER-style T5 paraphrasing model."""

    def __init__(self, path=""):
        """Load tokenizer and 8-bit quantized T5 model from *path*.

        `device_map="auto"` lets accelerate place the model on the available
        device(s); do not assume CUDA elsewhere — use `self.model.device`.
        """
        self.tokenizer = AutoTokenizer.from_pretrained(path)
        self.model = T5ForConditionalGeneration.from_pretrained(
            path, device_map="auto", load_in_8bit=True
        )

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Handle one inference request.

        data args:
            inputs (:obj:`dict`): expected to contain
                input_text (str): text to paraphrase
                lex_diversity (int): multiple of 20 in [0, 100], default 80
                order_diversity (int): multiple of 20 in [0, 100], default 20
                prefix (str): optional context prepended to the model input
        Return:
            A :obj:`dict` with key ``'prediction'`` holding the paraphrase;
            will be serialized and returned.
        """
        # Unwrap the standard {"inputs": {...}} envelope; fall back to the
        # raw payload if the "inputs" key is absent.
        data = data.pop("inputs", data)
        input_text = data.get("input_text", "")
        lex_diversity = data.get("lex_diversity", 80)
        order_diversity = data.get("order_diversity", 20)
        prefix = data.get("prefix", "")

        prediction = self.paraphrase(
            input_text,
            lex_diversity,
            order_diversity,
            prefix=prefix,
            do_sample=True,
            top_p=0.75,
            max_length=512,
        )
        return {'prediction': prediction}

    def paraphrase(self, input_text, lex_diversity, order_diversity,
                   prefix="", sent_interval=3, **kwargs):
        """Paraphrase *input_text* with the DIPPER model.

        The text is split into sentences and paraphrased in windows of
        ``sent_interval`` sentences; each window's output is appended to the
        running ``prefix`` so later windows stay coherent with earlier output.

        Args:
            input_text (str): The text to paraphrase.
            lex_diversity (int): Lexical diversity, a multiple of 20 in
                [0, 100]. 0 means no diversity, 100 means maximum diversity.
            order_diversity (int): Order diversity, a multiple of 20 in
                [0, 100]. 0 means no diversity, 100 means maximum diversity.
            prefix (str): Optional context prepended to each model input.
            sent_interval (int): Number of sentences paraphrased per window.
            **kwargs: Extra generation arguments (top_p, top_k, max_length...).

        Returns:
            str: The paraphrased text (note: carries a leading space).

        Raises:
            ValueError: If a diversity value is not a multiple of 20 in
                [0, 100].
        """
        # Validate with explicit exceptions rather than `assert`, which is
        # silently stripped when Python runs with -O.
        valid = (0, 20, 40, 60, 80, 100)
        if lex_diversity not in valid:
            raise ValueError("Lexical diversity must be one of 0, 20, 40, 60, 80, 100.")
        if order_diversity not in valid:
            raise ValueError("Order diversity must be one of 0, 20, 40, 60, 80, 100.")

        # DIPPER control codes are inverted: higher diversity -> lower code.
        lex_code = int(100 - lex_diversity)
        order_code = int(100 - order_diversity)

        # Normalize all whitespace runs to single spaces before tokenizing.
        input_text = " ".join(input_text.split())
        sentences = sent_tokenize(input_text)
        prefix = " ".join(prefix.replace("\n", " ").split())

        output_text = ""
        for sent_idx in range(0, len(sentences), sent_interval):
            curr_sent_window = " ".join(sentences[sent_idx:sent_idx + sent_interval])
            final_input_text = f"lexical = {lex_code}, order = {order_code}"
            if prefix:
                final_input_text += f" {prefix}"
            final_input_text += f" {curr_sent_window} "

            final_input = self.tokenizer([final_input_text], return_tensors="pt")
            # Move inputs to wherever accelerate placed the model instead of
            # hard-coding .cuda(): works on CPU-only hosts and matches
            # device_map="auto" placement.
            final_input = {k: v.to(self.model.device) for k, v in final_input.items()}

            with torch.inference_mode():
                outputs = self.model.generate(**final_input, **kwargs)
            outputs = self.tokenizer.batch_decode(outputs, skip_special_tokens=True)

            # Feed generated text back in as context for the next window.
            prefix += " " + outputs[0]
            output_text += " " + outputs[0]

        return output_text