File size: 4,583 Bytes
909aaea
44e5fb9
 
 
 
909aaea
44e5fb9
 
 
 
 
 
 
 
 
909aaea
44e5fb9
 
 
909aaea
44e5fb9
 
909aaea
44e5fb9
 
 
 
909aaea
44e5fb9
 
 
 
 
 
909aaea
44e5fb9
 
909aaea
44e5fb9
 
 
 
 
909aaea
44e5fb9
909aaea
44e5fb9
909aaea
44e5fb9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
909aaea
 
44e5fb9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
909aaea
44e5fb9
 
 
 
 
909aaea
 
44e5fb9
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
import gradio as gr
import torch
import time
from transformers import T5Tokenizer, T5ForConditionalGeneration
from nltk.tokenize import sent_tokenize

class DipperParaphraser:
    """Paraphrase text with the DIPPER T5 model, a window of sentences at a time.

    The input is split into sentences and paraphrased in windows of
    ``sent_interval`` sentences. Each window is conditioned on the optional
    ``prefix`` plus every paraphrase generated so far. The per-window outputs
    are joined with single spaces.
    """

    # Diversity levels accepted by ``paraphrase``; each level is turned into a
    # control code of ``100 - level`` in the model prompt.
    _DIVERSITY_LEVELS = (0, 20, 40, 60, 80, 100)

    def __init__(self, model="kalpeshk2011/dipper-paraphraser-xxl", verbose=True, device="cuda"):
        """Load the tokenizer and model weights, then move the model to *device*.

        Args:
            model: Hugging Face model id or local path of the DIPPER checkpoint.
            verbose: If true, print how long loading took.
            device: Torch device the model runs on (default ``"cuda"``).
        """
        start = time.perf_counter()
        # The tokenizer comes from the base T5 v1.1 XXL checkpoint, not from *model*.
        self.tokenizer = T5Tokenizer.from_pretrained('google/t5-v1_1-xxl')
        self.model = T5ForConditionalGeneration.from_pretrained(model)
        if verbose:
            print(f"{model} model loaded in {time.perf_counter() - start}")
        self.model.to(device)
        self.model.eval()

    @staticmethod
    def _format_input(lex_code, order_code, prefix, window):
        """Build the model prompt for one sentence window."""
        text = f"lexical = {lex_code}, order = {order_code}"
        if prefix:
            text += f" {prefix}"
        return f"{text} <sent> {window} </sent>"

    def paraphrase(self, input_text, lex_diversity, order_diversity, prefix="", sent_interval=3, **kwargs):
        """Paraphrase *input_text*, ``sent_interval`` sentences per generate call.

        Args:
            input_text: Text to paraphrase; whitespace runs are collapsed.
            lex_diversity: Lexical diversity, one of 0, 20, 40, 60, 80, 100.
            order_diversity: Order diversity, one of 0, 20, 40, 60, 80, 100.
            prefix: Optional context placed before every window; newlines and
                whitespace runs are flattened to single spaces.
            sent_interval: Number of sentences per window; must be >= 1.
            **kwargs: Forwarded unchanged to ``self.model.generate``.

        Returns:
            The window paraphrases joined with single spaces, or ``""`` when
            *input_text* is blank.

        Raises:
            ValueError: If a diversity value is not an allowed level, or
                *sent_interval* is smaller than 1.
        """
        # Explicit checks rather than ``assert`` so they survive ``python -O``.
        if lex_diversity not in self._DIVERSITY_LEVELS:
            raise ValueError("Lexical diversity must be one of 0, 20, 40, 60, 80, 100.")
        if order_diversity not in self._DIVERSITY_LEVELS:
            raise ValueError("Order diversity must be one of 0, 20, 40, 60, 80, 100.")
        if sent_interval < 1:
            # range() would raise on 0 and silently yield nothing when negative.
            raise ValueError(f"sent_interval must be a positive integer, got {sent_interval!r}.")

        lex_code = int(100 - lex_diversity)
        order_code = int(100 - order_diversity)

        input_text = " ".join(input_text.split())
        if not input_text:
            # Nothing to paraphrase: skip sentence splitting and generation.
            return ""
        sentences = sent_tokenize(input_text)
        prefix = " ".join(prefix.replace("\n", " ").split())
        pieces = []

        for start in range(0, len(sentences), sent_interval):
            window = " ".join(sentences[start:start + sent_interval])
            final_input_text = self._format_input(lex_code, order_code, prefix, window)

            final_input = self.tokenizer([final_input_text], return_tensors="pt")
            # Place the encoded tensors on the same device as the model weights.
            final_input = {k: v.to(self.model.device) for k, v in final_input.items()}

            with torch.inference_mode():
                generated = self.model.generate(**final_input, **kwargs)
            paraphrased = self.tokenizer.batch_decode(generated, skip_special_tokens=True)[0]

            # Later windows are conditioned on everything paraphrased so far.
            # Joining this way avoids a leading space when the prefix starts empty.
            prefix = f"{prefix} {paraphrased}" if prefix else paraphrased
            pieces.append(paraphrased)

        return " ".join(pieces).strip()

# Shared DipperParaphraser, created on the first request because loading the
# XXL model is expensive.
dp = None


def _build_generate_kwargs(top_p, top_k, max_length, do_sample):
    """Translate the UI generation controls into ``model.generate`` keyword args.

    Gradio ``Number`` fields may deliver floats or ``None``, so integer settings
    are coerced to ``int``. An empty or zero Top K is passed as ``None``, as
    before. ``max_length`` is omitted when empty so the model's own default
    applies. ``top_p``/``top_k`` are only sent when sampling is enabled, because
    greedy decoding ignores them.
    """
    kwargs = {"do_sample": bool(do_sample)}
    if max_length is not None:
        kwargs["max_length"] = int(max_length)
    if do_sample:
        kwargs["top_p"] = top_p
        kwargs["top_k"] = int(top_k) if top_k else None
    return kwargs


def paraphrase_interface(prompt, input_text, lex_diversity, order_diversity, sent_interval, top_p, top_k, max_length, do_sample):
    """Gradio callback: paraphrase *input_text* using the shared model.

    Args:
        prompt: Optional context passed to the paraphraser as ``prefix``.
        input_text: Text to paraphrase.
        lex_diversity: Lexical diversity level from the dropdown.
        order_diversity: Order diversity level from the dropdown.
        sent_interval: Sentences per window; coerced to ``int``, and 3 if empty.
        top_p, top_k, max_length, do_sample: Generation controls; see
            ``_build_generate_kwargs``.

    Returns:
        The paraphrased text.
    """
    global dp
    # Build the arguments before the (slow) first-time model load so that bad
    # UI values fail fast.
    generate_kwargs = _build_generate_kwargs(top_p, top_k, max_length, do_sample)
    interval = int(sent_interval) if sent_interval is not None else 3
    if dp is None:
        dp = DipperParaphraser(verbose=False)
    return dp.paraphrase(
        input_text,
        lex_diversity=lex_diversity,
        order_diversity=order_diversity,
        prefix=prompt or "",
        sent_interval=interval,
        **generate_kwargs,
    )

# Diversity levels offered by both dropdowns.
_DIVERSITY_LEVELS = (0, 20, 40, 60, 80, 100)

# Example context and passage pre-filled in the form.
_EXAMPLE_PROMPT = "In a shocking finding, scientist discovered a herd of unicorns living in a remote valley."
_EXAMPLE_TEXT = "They have never been known to mingle with humans. Today, it is believed these unicorns live in an unspoilt environment which is surrounded by mountains. Its edge is protected by a thick wattle of wattle trees, giving it a majestic appearance. Along with their so-called miracle of multicolored coat, their golden coloured feather makes them look like mirages. Some of them are rumored to be capable of speaking a large amount of different languages. They feed on elk and goats as they were selected from those animals that possess a fierceness to them, and can \"eat\" them with their long horns."

with gr.Blocks() as demo:
    gr.Markdown("# DIPPER Paraphraser XXL")
    with gr.Row():
        # Left column: text inputs and generation settings.
        with gr.Column():
            prompt = gr.Textbox(label="Prompt (Optional)", value=_EXAMPLE_PROMPT)
            input_text = gr.Textbox(label="Text to Paraphrase", lines=8, value=_EXAMPLE_TEXT)
            lex_diversity = gr.Dropdown(label="Lexical Diversity", choices=list(_DIVERSITY_LEVELS), value=60)
            order_diversity = gr.Dropdown(label="Order Diversity", choices=list(_DIVERSITY_LEVELS), value=0)
            sent_interval = gr.Number(label="Sentence Interval", value=3, precision=0)
            top_p = gr.Number(label="Top P (sampling)", value=0.75)
            top_k = gr.Number(label="Top K (sampling, None for default)", value=None, precision=0)
            max_length = gr.Number(label="Max Length", value=512, precision=0)
            do_sample = gr.Checkbox(label="Enable Sampling", value=True)
            btn = gr.Button("Paraphrase")
        # Right column: the paraphrased result.
        with gr.Column():
            output = gr.Textbox(label="Paraphrased Output", lines=8)

    # Order must match paraphrase_interface's positional parameters.
    paraphrase_inputs = [
        prompt,
        input_text,
        lex_diversity,
        order_diversity,
        sent_interval,
        top_p,
        top_k,
        max_length,
        do_sample,
    ]
    btn.click(paraphrase_interface, inputs=paraphrase_inputs, outputs=output)

if __name__ == "__main__":
    # Start the Gradio server only when run as a script, not on import.
    demo.launch()