jaimin committed on
Commit
d75114e
·
1 Parent(s): 39afde2

Create new file

Browse files
Files changed (1) hide show
  1. app.py +70 -0
app.py ADDED
@@ -0,0 +1,70 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import gradio as gr
from gradio.mix import Parallel
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
import os
from transformers import T5TokenizerFast, T5ForConditionalGeneration
from transformers import PegasusForConditionalGeneration, PegasusTokenizer
import pytorch_lightning as pl
import torch
import itertools
import random
import nltk
from nltk.tokenize import sent_tokenize
import requests
import json

# The Punkt sentence-splitter data is required by sent_tokenize() below.
nltk.download('punkt')

from fastT5 import export_and_get_onnx_model

# Use the first GPU when one is visible; otherwise fall back to the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Paraphrasing model: tokenizer pulled from the HF hub, weights exported to
# ONNX via fastT5 (faster CPU inference than the plain PyTorch model).
# NOTE(review): the tokenizer repo id is "jaimin/T5-Large" but the model uses
# 'jaimin/T5-large' — hub ids are case-sensitive; confirm both resolve to the
# same repository.
T5_tokenizer = AutoTokenizer.from_pretrained("jaimin/T5-Large")
T5_model = export_and_get_onnx_model('jaimin/T5-large')
26
def get_paraphrases(text, n_predictions=3, top_k=50, max_length=256, device="cpu"):
    """Paraphrase *text* sentence by sentence and return combined variants.

    The input is first POSTed to the remote "jaimin/CWI" HF Space; each
    sentence of that response is then paraphrased with the local T5/ONNX
    model, and one alternative per sentence is combined into whole-text
    paraphrases.

    Parameters
    ----------
    text : str
        Text to paraphrase.
    n_predictions : int
        Number of candidate paraphrases generated per sentence.
    top_k, max_length : int
        Unused; kept for backward compatibility with existing callers.
    device : str
        Torch device the encoded tensors are moved to (default "cpu").

    Returns
    -------
    list[str]
        Shuffled whole-text paraphrases: the Cartesian product of the
        per-sentence alternatives, joined with spaces. Empty if any
        sentence produced no acceptable paraphrase.

    Raises
    ------
    requests.HTTPError
        If the CWI Space responds with an error status.
    requests.Timeout
        If the CWI Space does not answer within the timeout.
    """
    sentence_options = []

    # Run the text through the remote CWI Space first. Fail fast on a dead
    # or erroring Space instead of hanging / raising a JSON decode error.
    response = requests.post(
        url="https://hf.space/embed/jaimin/CWI/+/api/predict",
        json={"data": [text]},
        timeout=60,
    )
    response.raise_for_status()
    sentence = response.json()["data"][0]

    for sent in sent_tokenize(sentence):
        # T5 paraphrase prompt (use a new name — the original shadowed the
        # `text` parameter here). The explicit " </s>" follows the prompt
        # format this model expects.
        prompt = "paraphrase: " + sent + " </s>"
        encoding = T5_tokenizer.encode_plus(
            prompt, padding=True, return_tensors="pt", truncation=True
        )
        input_ids = encoding["input_ids"].to(device)
        attention_masks = encoding["attention_mask"].to(device)
        model_output = T5_model.generate(
            input_ids=input_ids,
            attention_mask=attention_masks,
            max_length=512,
            early_stopping=True,
            num_beams=15,
            num_beam_groups=3,
            num_return_sequences=n_predictions,
            diversity_penalty=0.70,
            temperature=0.7,
        )

        outputs = []
        for output in model_output:
            generated_sent = T5_tokenizer.decode(
                output, skip_special_tokens=True, clean_up_tokenization_spaces=True
            )
            cleaned = generated_sent.replace('paraphrasedoutput:', "")
            # BUG FIX: compare against the *current* sentence (`sent`), not
            # the whole input (`sentence`) — the original never filtered
            # trivial copies when the input had more than one sentence.
            # Dedup is also done on the cleaned string.
            if generated_sent.lower() != sent.lower() and cleaned not in outputs:
                outputs.append(cleaned)
        sentence_options.append(outputs)

    # Every combination of one paraphrase per sentence, in random order.
    # (Removed the per-iteration debug print from the original.)
    combos = list(itertools.product(*sentence_options))
    random.shuffle(combos)
    return [" ".join(combo) for combo in combos]
67
+
68
# Minimal Gradio UI: one multi-line textbox in, the paraphrase list out.
text_input = gr.inputs.Textbox(lines=5)
iface = gr.Interface(fn=get_paraphrases, inputs=[text_input], outputs=["text"])
iface.launch()