Eaz123 committed on
Commit
7e38706
·
verified ·
1 Parent(s): 47e5004

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +35 -71
app.py CHANGED
@@ -1,79 +1,43 @@
1
-
2
  import gradio as gr
3
- from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
4
  import nltk
5
- import torch
6
- from textblob import TextBlob
7
-
8
- nltk.download("punkt")
9
-
10
- model = AutoModelForSeq2SeqLM.from_pretrained("ramsrigouthamg/t5_paraphraser")
11
- tokenizer = AutoTokenizer.from_pretrained("ramsrigouthamg/t5_paraphraser", use_fast=False)
12
-
13
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
14
- model = model.to(device)
15
-
16
- def split_into_sentences(text):
17
- return nltk.tokenize.sent_tokenize(text)
18
 
19
- def paraphrase_sentence(sentence, creativity):
20
- text = "paraphrase: " + sentence + " </s>"
21
- encoding = tokenizer.encode_plus(text, padding="max_length", return_tensors="pt", max_length=256, truncation=True)
22
- input_ids, attention_mask = encoding["input_ids"].to(device), encoding["attention_mask"].to(device)
23
 
24
- output = model.generate(
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
25
  input_ids=input_ids,
26
  attention_mask=attention_mask,
27
  max_length=256,
28
- do_sample=True,
29
- top_k=120,
30
- top_p=creativity,
31
- early_stopping=True,
32
- num_return_sequences=1
33
- )
34
- paraphrased = tokenizer.decode(output[0], skip_special_tokens=True, clean_up_tokenization_spaces=True)
35
- return paraphrased
36
-
37
- def correct_grammar(text):
38
- blob = TextBlob(text)
39
- return str(blob.correct())
40
-
41
- def paraphrase_text(text, creativity=0.9, tone="neutral", improve_grammar=True, batch_size=5):
42
- sentences = split_into_sentences(text)
43
- results = []
44
-
45
- for sentence in sentences:
46
- para = paraphrase_sentence(sentence, creativity)
47
- if improve_grammar:
48
- para = correct_grammar(para)
49
- results.append(para)
50
-
51
- return " ".join(results)
52
-
53
- with gr.Blocks() as demo:
54
- gr.Markdown("<h1>Paraphrasing Tool</h1><p>AI-powered rewriting with grammar improvement and creativity controls.</p>")
55
-
56
- with gr.Row():
57
- input_text = gr.Textbox(lines=8, label="Enter Text")
58
- output_text = gr.Textbox(lines=8, label="Paraphrased Output")
59
-
60
- creativity = gr.Slider(0.1, 1.0, value=0.9, step=0.05, label="Creativity (top_p)")
61
- improve_grammar = gr.Checkbox(value=True, label="Improve Grammar")
62
- tone = gr.Radio(["neutral", "formal", "casual"], label="Tone (placeholder)", value="neutral")
63
- batch_size = gr.Slider(1, 10, value=5, step=1, label="Batch Size")
64
-
65
- run_button = gr.Button("Paraphrase")
66
-
67
- run_button.click(
68
- paraphrase_text,
69
- inputs=[input_text, creativity, tone, improve_grammar, batch_size],
70
- outputs=output_text
71
- )
72
-
73
- if __name__ == "__main__":
74
- demo.launch(
75
- server_name="0.0.0.0",
76
- server_port=7860,
77
- show_api=True,
78
- favicon_path="favicon.ico"
79
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import gradio as gr
 
2
  import nltk
3
+ nltk.download('punkt') # Download necessary data
 
 
 
 
 
 
 
 
 
 
 
 
4
 
5
+ from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
 
 
 
6
 
7
+ # Load model and tokenizer
8
+ model_name = "Vamsi/T5_Paraphrase_Paws"
9
+ tokenizer = AutoTokenizer.from_pretrained(model_name)
10
+ model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
11
+
12
+ def paraphrase(text):
13
+ if not text.strip():
14
+ return "Please enter some text to paraphrase."
15
+ input_text = f"paraphrase: {text} </s>"
16
+ encoding = tokenizer.encode_plus(
17
+ input_text,
18
+ max_length=256,
19
+ padding="max_length",
20
+ return_tensors="pt",
21
+ truncation=True
22
+ )
23
+ input_ids, attention_mask = encoding["input_ids"], encoding["attention_mask"]
24
+ outputs = model.generate(
25
  input_ids=input_ids,
26
  attention_mask=attention_mask,
27
  max_length=256,
28
+ num_beams=5,
29
+ num_return_sequences=1,
30
+ temperature=1.5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
31
  )
32
+ return tokenizer.decode(outputs[0], skip_special_tokens=True)
33
+
34
+ # Gradio Interface
35
+ demo = gr.Interface(
36
+ fn=paraphrase,
37
+ inputs=gr.Textbox(label="Enter your text to paraphrase"),
38
+ outputs=gr.Textbox(label="Paraphrased text"),
39
+ title="AI Paraphrasing Tool",
40
+ description="Enter your sentence or paragraph, and the model will return a paraphrased version."
41
+ )
42
+
43
+ demo.launch()