Spaces:
Runtime error
Runtime error
| import gradio as gr | |
| import torch | |
| import time | |
| from transformers import T5Tokenizer, T5ForConditionalGeneration | |
| from nltk.tokenize import sent_tokenize | |
class DipperParaphraser(object):
    """Paraphrase text with a DIPPER T5 checkpoint, a few sentences at a time.

    The input is split into sentences and paraphrased ``sent_interval``
    sentences per ``generate`` call. Each generated window is appended to the
    running prefix, so later windows are conditioned on earlier output.
    """

    # Diversity settings accepted by ``paraphrase``; the model is prompted
    # with ``100 - value`` as its control code.
    VALID_DIVERSITIES = (0, 20, 40, 60, 80, 100)

    def __init__(self, model="kalpeshk2011/dipper-paraphraser-xxl", verbose=True, device=None):
        """Load tokenizer and model, move the model to ``device``, set eval mode.

        Args:
            model: Hugging Face model id (or local path) of the DIPPER weights.
            verbose: If true, print how long loading took.
            device: torch device (or device string) to run on. ``None`` picks
                ``"cuda"`` when CUDA is available and ``"cpu"`` otherwise;
                previously the model was always moved with ``.cuda()``, which
                fails on CPU-only machines.
        """
        started = time.time()
        # The tokenizer is loaded from the base T5 v1.1 XXL checkpoint rather
        # than from ``model``.
        self.tokenizer = T5Tokenizer.from_pretrained('google/t5-v1_1-xxl')
        self.model = T5ForConditionalGeneration.from_pretrained(model)
        if verbose:
            print(f"{model} model loaded in {time.time() - started}")
        if device is None:
            device = "cuda" if torch.cuda.is_available() else "cpu"
        self.device = torch.device(device)
        self.model.to(self.device)
        self.model.eval()

    @staticmethod
    def _diversity_code(value, name):
        """Validate one diversity setting and return its control code (100 - value).

        Raises:
            ValueError: If ``value`` is not one of ``VALID_DIVERSITIES``.
        """
        if value not in DipperParaphraser.VALID_DIVERSITIES:
            raise ValueError(f"{name} must be one of 0, 20, 40, 60, 80, 100.")
        return int(100 - value)

    @staticmethod
    def _build_input(lex_code, order_code, prefix, window):
        """Assemble the control-prefixed model input for one sentence window."""
        parts = [f"lexical = {lex_code}, order = {order_code}"]
        if prefix:
            parts.append(prefix)
        parts.append(f"<sent> {window} </sent>")
        return " ".join(parts)

    def paraphrase(self, input_text, lex_diversity, order_diversity, prefix="", sent_interval=3, **kwargs):
        """Paraphrase ``input_text`` window by window.

        Args:
            input_text: Text to paraphrase; runs of whitespace are collapsed.
            lex_diversity: Lexical diversity, one of 0, 20, 40, 60, 80, 100.
            order_diversity: Order diversity, one of 0, 20, 40, 60, 80, 100.
            prefix: Optional context placed before every window; newlines are
                flattened. ``None`` is treated as empty.
            sent_interval: Number of sentences per ``generate`` call; must be
                at least 1.
            **kwargs: Forwarded unchanged to ``self.model.generate``.

        Returns:
            The generated windows joined by single spaces, stripped.

        Raises:
            ValueError: If a diversity value or ``sent_interval`` is invalid.
        """
        # Validate everything before doing any tokenization or generation.
        lex_code = self._diversity_code(lex_diversity, "Lexical diversity")
        order_code = self._diversity_code(order_diversity, "Order diversity")
        if sent_interval < 1:
            raise ValueError("Sentence interval must be a positive integer.")

        input_text = " ".join(input_text.split())
        # NOTE(review): sent_tokenize needs NLTK's punkt data to be installed
        # in the runtime — confirm it is provisioned.
        sentences = sent_tokenize(input_text)
        prefix = " ".join((prefix or "").replace("\n", " ").split())

        pieces = []
        for start in range(0, len(sentences), sent_interval):
            window = " ".join(sentences[start:start + sent_interval])
            model_text = self._build_input(lex_code, order_code, prefix, window)
            encoded = self.tokenizer([model_text], return_tensors="pt")
            # Inputs must live on the same device as the model.
            encoded = {key: tensor.to(self.device) for key, tensor in encoded.items()}
            with torch.inference_mode():
                generated = self.model.generate(**encoded, **kwargs)
            paraphrased = self.tokenizer.batch_decode(generated, skip_special_tokens=True)[0]
            # Feed each paraphrase back as context for the following window.
            prefix += " " + paraphrased
            pieces.append(paraphrased)
        return " ".join(pieces).strip()
# Lazily created on the first request so the app starts without loading
# the model.
dp = None


def _build_generate_kwargs(top_p, top_k, max_length, do_sample):
    """Translate the UI decoding controls into ``model.generate`` kwargs.

    ``top_p``/``top_k`` only affect sampling, so they are forwarded only when
    ``do_sample`` is set. A blank or zero ``top_k`` is passed as ``None``,
    exactly as before.
    """
    kwargs = {"do_sample": do_sample, "max_length": max_length}
    if do_sample:
        kwargs["top_p"] = top_p
        kwargs["top_k"] = top_k if top_k else None
    return kwargs


def paraphrase_interface(prompt, input_text, lex_diversity, order_diversity, sent_interval, top_p, top_k, max_length, do_sample):
    """Gradio click handler: normalize UI values, load the model once, paraphrase.

    Inputs are normalized before the (very expensive) first model load, so
    malformed values fail fast instead of after loading the checkpoint.

    Returns:
        The paraphrased text produced by ``DipperParaphraser.paraphrase``.

    Raises:
        ValueError: If a diversity value cannot be interpreted as an integer
            (or is rejected by ``DipperParaphraser.paraphrase``).
    """
    global dp
    # NOTE(review): depending on the Gradio version, dropdown values may
    # arrive as strings — coerce so "60" and 60 behave the same.
    if isinstance(lex_diversity, str):
        lex_diversity = int(lex_diversity)
    if isinstance(order_diversity, str):
        order_diversity = int(order_diversity)
    # A cleared Number field yields None; fall back to the UI default.
    sent_interval = 3 if sent_interval is None else int(sent_interval)
    kwargs = _build_generate_kwargs(top_p, top_k, max_length, do_sample)

    if dp is None:
        dp = DipperParaphraser(verbose=False)
    return dp.paraphrase(
        input_text or "",
        lex_diversity=lex_diversity,
        order_diversity=order_diversity,
        prefix=prompt or "",
        sent_interval=sent_interval,
        **kwargs
    )
# Allowed values for both diversity dropdowns.
DIVERSITY_CHOICES = [0, 20, 40, 60, 80, 100]

# Default contents pre-filled into the prompt and text boxes.
EXAMPLE_PROMPT = "In a shocking finding, scientist discovered a herd of unicorns living in a remote valley."
EXAMPLE_TEXT = (
    "They have never been known to mingle with humans. "
    "Today, it is believed these unicorns live in an unspoilt environment "
    "which is surrounded by mountains. "
    "Its edge is protected by a thick wattle of wattle trees, giving it a "
    "majestic appearance. "
    "Along with their so-called miracle of multicolored coat, their golden "
    "coloured feather makes them look like mirages. "
    "Some of them are rumored to be capable of speaking a large amount of "
    "different languages. "
    "They feed on elk and goats as they were selected from those animals that "
    'possess a fierceness to them, and can "eat" them with their long horns.'
)

with gr.Blocks() as demo:
    gr.Markdown("# DIPPER Paraphraser XXL")
    with gr.Row():
        # Left column: every control feeding paraphrase_interface.
        with gr.Column():
            prompt = gr.Textbox(label="Prompt (Optional)", value=EXAMPLE_PROMPT)
            input_text = gr.Textbox(label="Text to Paraphrase", lines=8, value=EXAMPLE_TEXT)
            lex_diversity = gr.Dropdown(label="Lexical Diversity", choices=DIVERSITY_CHOICES, value=60)
            order_diversity = gr.Dropdown(label="Order Diversity", choices=DIVERSITY_CHOICES, value=0)
            sent_interval = gr.Number(label="Sentence Interval", value=3, precision=0)
            top_p = gr.Number(label="Top P (sampling)", value=0.75)
            top_k = gr.Number(label="Top K (sampling, None for default)", value=None, precision=0)
            max_length = gr.Number(label="Max Length", value=512, precision=0)
            do_sample = gr.Checkbox(label="Enable Sampling", value=True)
            btn = gr.Button("Paraphrase")
        # Right column: the result.
        with gr.Column():
            output = gr.Textbox(label="Paraphrased Output", lines=8)

    # Order must match paraphrase_interface's positional parameters.
    handler_inputs = [
        prompt,
        input_text,
        lex_diversity,
        order_diversity,
        sent_interval,
        top_p,
        top_k,
        max_length,
        do_sample,
    ]
    btn.click(fn=paraphrase_interface, inputs=handler_inputs, outputs=output)
def main():
    """Start the Gradio server for the ``demo`` app."""
    demo.launch()


if __name__ == "__main__":
    main()