| from typing import Dict, List, Any | |
| from transformers import T5Tokenizer, T5ForConditionalGeneration | |
| import torch | |
| import time | |
| from transformers import ( | |
| AutoTokenizer, | |
| AutoModelForSequenceClassification, | |
| RobertaTokenizer, | |
| RobertaForSequenceClassification, | |
| ) | |
| import nltk | |
| from nltk.tokenize import sent_tokenize | |
| nltk.download('punkt') | |
class EndpointHandler:
    """Hugging Face Inference Endpoints handler for the DIPPER paraphraser.

    Loads a T5-based paraphrase model (8-bit, device placement delegated to
    accelerate via ``device_map="auto"``) and paraphrases input text one
    sentence window at a time, feeding each generated window back in as
    context for the next.
    """

    def __init__(self, path=""):
        # NOTE(review): tokenizer is loaded generically but the model class is
        # pinned to T5 — assumes `path` is a T5-compatible DIPPER checkpoint.
        self.tokenizer = AutoTokenizer.from_pretrained(path)
        # device_map="auto" lets accelerate choose placement; inputs are later
        # moved to self.model.device rather than a hard-coded GPU.
        self.model = T5ForConditionalGeneration.from_pretrained(
            path, device_map="auto", load_in_8bit=True
        )

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Inference entry point.

        data args (under the ``inputs`` key, or at the top level):
            input_text (:obj:`str`): text to paraphrase.
            lex_diversity (:obj:`int`, defaults to 80): lexical diversity code.
            order_diversity (:obj:`int`, defaults to 20): order diversity code.
            prefix (:obj:`str`, defaults to ""): optional leading context.
        Return:
            A :obj:`dict` with a single ``prediction`` key; will be serialized
            and returned.
        """
        data = data.pop("inputs", data)
        input_text = data.get("input_text", "")
        lex_diversity = data.get("lex_diversity", 80)
        order_diversity = data.get("order_diversity", 20)
        prefix = data.get("prefix", "")
        prediction = self.paraphrase(
            input_text,
            lex_diversity,
            order_diversity,
            prefix=prefix,
            do_sample=True,
            top_p=0.75,
            max_length=512,
        )
        return {"prediction": prediction}

    def paraphrase(self, input_text, lex_diversity, order_diversity,
                   prefix="", sent_interval=3, **kwargs):
        """Paraphrase a text using the DIPPER model.

        Args:
            input_text (str): The text to paraphrase. It is sentence-tokenized
                and processed ``sent_interval`` sentences at a time; each
                window is wrapped in ``<sent> ... </sent>`` markers.
            lex_diversity (int): The lexical diversity of the output, choose
                multiples of 20 from 0 to 100. 0 means no diversity, 100 means
                maximum diversity.
            order_diversity (int): The order diversity of the output, choose
                multiples of 20 from 0 to 100. 0 means no diversity, 100 means
                maximum diversity.
            prefix (str): Optional context prepended to the control codes;
                grows with each generated window so later windows see earlier
                output.
            sent_interval (int): Number of sentences paraphrased per model
                call.
            **kwargs: Additional keyword arguments like top_p, top_k,
                max_length, forwarded to ``model.generate``.

        Returns:
            str: The paraphrased text (windows joined with leading spaces,
            matching the original accumulation behavior).

        Raises:
            ValueError: If a diversity value is not a multiple of 20 in
                [0, 100].
        """
        # Explicit raises instead of `assert`: asserts are stripped under -O,
        # which would silently disable input validation.
        valid_codes = (0, 20, 40, 60, 80, 100)
        if lex_diversity not in valid_codes:
            raise ValueError("Lexical diversity must be one of 0, 20, 40, 60, 80, 100.")
        if order_diversity not in valid_codes:
            raise ValueError("Order diversity must be one of 0, 20, 40, 60, 80, 100.")

        # DIPPER control codes are inverted: 0 = maximum diversity, 100 = none.
        lex_code = int(100 - lex_diversity)
        order_code = int(100 - order_diversity)

        # Normalize whitespace before sentence tokenization.
        input_text = " ".join(input_text.split())
        sentences = sent_tokenize(input_text)
        prefix = " ".join(prefix.replace("\n", " ").split())

        output_text = ""
        for sent_idx in range(0, len(sentences), sent_interval):
            curr_sent_window = " ".join(sentences[sent_idx:sent_idx + sent_interval])
            final_input_text = f"lexical = {lex_code}, order = {order_code}"
            if prefix:
                final_input_text += f" {prefix}"
            final_input_text += f" <sent> {curr_sent_window} </sent>"

            final_input = self.tokenizer([final_input_text], return_tensors="pt")
            # Move inputs to wherever accelerate placed the model instead of a
            # hard-coded .cuda() call (fixes a crash on CPU-only hosts and
            # device mismatches under device_map="auto").
            device = self.model.device
            final_input = {k: v.to(device) for k, v in final_input.items()}

            with torch.inference_mode():
                outputs = self.model.generate(**final_input, **kwargs)
            outputs = self.tokenizer.batch_decode(outputs, skip_special_tokens=True)
            # Feed each generated window back as context for the next one.
            prefix += " " + outputs[0]
            output_text += " " + outputs[0]

        return output_text