Eaz123 committed on
Commit
47e5004
·
verified ·
1 Parent(s): 55d951f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +72 -16
app.py CHANGED
@@ -1,23 +1,79 @@
 
1
  import gradio as gr
2
  from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
 
 
 
3
 
4
- model_name = "ramsrigouthamg/t5_paraphraser"
5
- tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=False)
6
- model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
7
 
8
- def paraphrase_text(text):
9
- input_text = "paraphrase: " + text + " </s>"
10
- encoding = tokenizer.encode_plus(
11
- input_text, padding="max_length", return_tensors="pt", max_length=256, truncation=True
12
- )
13
- outputs = model.generate(
14
- input_ids=encoding["input_ids"],
15
- attention_mask=encoding["attention_mask"],
 
 
 
 
 
 
 
 
 
16
  max_length=256,
17
- num_return_sequences=1,
18
- temperature=1.5,
 
 
 
19
  )
20
- return tokenizer.decode(outputs[0], skip_special_tokens=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
21
 
22
- demo = gr.Interface(fn=paraphrase_text, inputs="text", outputs="text", title="Free Paraphraser")
23
- demo.launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
  import gradio as gr
3
  from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
4
+ import nltk
5
+ import torch
6
+ from textblob import TextBlob
7
 
8
+ nltk.download("punkt")
 
 
9
 
10
+ model = AutoModelForSeq2SeqLM.from_pretrained("ramsrigouthamg/t5_paraphraser")
11
+ tokenizer = AutoTokenizer.from_pretrained("ramsrigouthamg/t5_paraphraser", use_fast=False)
12
+
13
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
14
+ model = model.to(device)
15
+
16
+ def split_into_sentences(text):
17
+ return nltk.tokenize.sent_tokenize(text)
18
+
19
+ def paraphrase_sentence(sentence, creativity):
20
+ text = "paraphrase: " + sentence + " </s>"
21
+ encoding = tokenizer.encode_plus(text, padding="max_length", return_tensors="pt", max_length=256, truncation=True)
22
+ input_ids, attention_mask = encoding["input_ids"].to(device), encoding["attention_mask"].to(device)
23
+
24
+ output = model.generate(
25
+ input_ids=input_ids,
26
+ attention_mask=attention_mask,
27
  max_length=256,
28
+ do_sample=True,
29
+ top_k=120,
30
+ top_p=creativity,
31
+ early_stopping=True,
32
+ num_return_sequences=1
33
  )
34
+ paraphrased = tokenizer.decode(output[0], skip_special_tokens=True, clean_up_tokenization_spaces=True)
35
+ return paraphrased
36
+
37
+ def correct_grammar(text):
38
+ blob = TextBlob(text)
39
+ return str(blob.correct())
40
+
41
+ def paraphrase_text(text, creativity=0.9, tone="neutral", improve_grammar=True, batch_size=5):
42
+ sentences = split_into_sentences(text)
43
+ results = []
44
+
45
+ for sentence in sentences:
46
+ para = paraphrase_sentence(sentence, creativity)
47
+ if improve_grammar:
48
+ para = correct_grammar(para)
49
+ results.append(para)
50
 
51
+ return " ".join(results)
52
+
53
+ with gr.Blocks() as demo:
54
+ gr.Markdown("<h1>Paraphrasing Tool</h1><p>AI-powered rewriting with grammar improvement and creativity controls.</p>")
55
+
56
+ with gr.Row():
57
+ input_text = gr.Textbox(lines=8, label="Enter Text")
58
+ output_text = gr.Textbox(lines=8, label="Paraphrased Output")
59
+
60
+ creativity = gr.Slider(0.1, 1.0, value=0.9, step=0.05, label="Creativity (top_p)")
61
+ improve_grammar = gr.Checkbox(value=True, label="Improve Grammar")
62
+ tone = gr.Radio(["neutral", "formal", "casual"], label="Tone (placeholder)", value="neutral")
63
+ batch_size = gr.Slider(1, 10, value=5, step=1, label="Batch Size")
64
+
65
+ run_button = gr.Button("Paraphrase")
66
+
67
+ run_button.click(
68
+ paraphrase_text,
69
+ inputs=[input_text, creativity, tone, improve_grammar, batch_size],
70
+ outputs=output_text
71
+ )
72
+
73
+ if __name__ == "__main__":
74
+ demo.launch(
75
+ server_name="0.0.0.0",
76
+ server_port=7860,
77
+ show_api=True,
78
+ favicon_path="favicon.ico"
79
+ )