VictorM-Coder committed on
Commit
fedaceb
·
verified ·
1 Parent(s): c417d0a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +83 -44
app.py CHANGED
@@ -1,49 +1,88 @@
1
- import gradio as gr
2
  from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
 
3
  import re
4
 
5
- # Load model
6
- tokenizer = AutoTokenizer.from_pretrained("Vamsi/T5_Paraphrase_Paws")
7
- model = AutoModelForSeq2SeqLM.from_pretrained("Vamsi/T5_Paraphrase_Paws")
8
-
9
- # Function to paraphrase a single chunk
10
- def paraphrase_text(text):
11
- input_text = f"paraphrase: {text} </s>"
12
- input_ids = tokenizer.encode(input_text, return_tensors="pt", truncation=True)
13
- output_ids = model.generate(
14
- input_ids,
15
- max_length=256,
16
- do_sample=True,
17
- top_k=120,
18
- top_p=0.95,
19
- temperature=1.3
20
- )
21
- return tokenizer.decode(output_ids[0], skip_special_tokens=True)
22
-
23
- # Split text into chunks (4 sentences each)
24
- def chunk_text(text, max_sentences=4):
25
- sentences = re.split(r'(?<=[.!?]) +', text.strip())
26
- return [' '.join(sentences[i:i+max_sentences]) for i in range(0, len(sentences), max_sentences)]
27
-
28
- # Paraphrase the full text
29
- def full_article_paraphrase(text):
30
- chunks = chunk_text(text)
31
- return "\n\n".join(paraphrase_text(chunk.strip()) for chunk in chunks if chunk.strip())
32
-
33
- # Gradio pipeline
34
- def paraphrase_pipeline(input_text):
35
- if not input_text or len(input_text.strip()) < 10:
36
- return "Please enter valid text."
37
- return full_article_paraphrase(input_text)
38
-
39
- # Gradio interface
40
- demo = gr.Interface(
41
- fn=paraphrase_pipeline,
42
- inputs=gr.Textbox(label="Paste Text Here", lines=20, placeholder="Enter your text..."),
43
- outputs=gr.Textbox(label="Paraphrased Text"),
44
- title="Smart Paraphraser",
45
- description="Paste your text and get paraphrased output instantly."
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
46
  )
47
 
48
- if __name__ == "__main__":
49
- demo.launch()
 
 
1
  from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
2
+ import torch, gradio as gr
3
  import re
4
 
5
# --- Load Model ---
model_name = "prithivida/parrot_paraphraser_on_T5"

# Tokenizer and seq2seq weights are fetched from the Hugging Face hub once
# at startup and then reused for every request.
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSeq2SeqLM.from_pretrained(model_name)

# Prefer the GPU when one is present; this app only runs inference,
# so the model is switched to eval mode.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
model = model.to(device)
model.eval()
13
+
14
+ # --- Helpers ---
15
def split_paragraphs(text):
    """Return the non-empty, whitespace-trimmed paragraphs of *text*.

    A paragraph is anything separated by a single newline; blank lines
    are dropped entirely.
    """
    trimmed = (line.strip() for line in text.split("\n"))
    return [line for line in trimmed if line]
19
+
20
def split_sentences(text):
    """Split a paragraph into sentences on terminal punctuation.

    Splits after '.', '!' or '?' followed by whitespace, keeping the
    punctuation attached to its sentence; empty pieces are discarded.
    """
    pieces = re.split(r'(?<=[.!?])\s+', text.strip())
    return list(filter(None, pieces))
24
+
25
def clean_sentence(sent):
    """Collapse internal whitespace and guarantee terminal punctuation.

    Runs of whitespace become single spaces, surrounding whitespace is
    removed, and a trailing '.' is appended unless the sentence already
    ends with '.', '!' or '?'.
    """
    normalized = " ".join(sent.split())
    if normalized.endswith(('.', '!', '?')):
        return normalized
    return normalized + "."
31
+
32
+ # --- Main function ---
33
def paraphrase_fn(text, num_return_sequences=1, temperature=1.2, top_p=0.92):
    """Paraphrase *text* sentence by sentence, preserving paragraph breaks.

    Each paragraph is split into sentences; every sentence is sent through
    the T5 paraphraser with sampling, and the first distinct candidate is
    kept. Paragraphs are re-joined with blank lines so the input layout
    survives.

    Args:
        text: the raw input text to paraphrase.
        num_return_sequences: how many candidates to sample per sentence.
        temperature: sampling temperature passed to `generate`.
        top_p: nucleus-sampling cutoff passed to `generate`.

    Returns:
        The paraphrased text as a single string, or a short prompt message
        when the input is blank. Output is stochastic (do_sample=True).
    """
    if not text.strip():
        return "Enter some text"

    n_variants = int(num_return_sequences)
    rewritten_paragraphs = []

    for paragraph in split_paragraphs(text):
        rewritten_sentences = []

        for sentence in split_sentences(paragraph):
            # Parrot-style prompt; the literal </s> marker matches the
            # model's expected input format.
            prompt = "paraphrase: " + sentence + " </s>"
            encoded = tokenizer([prompt], return_tensors="pt", truncation=True, padding=True).to(device)

            generated = model.generate(
                **encoded,
                max_new_tokens=128,
                num_return_sequences=n_variants,
                do_sample=True,
                top_p=float(top_p),
                temperature=float(temperature),
            )
            candidates = tokenizer.batch_decode(generated, skip_special_tokens=True)

            # De-duplicate candidates in generation order after cleaning;
            # only the first distinct one ends up in the output.
            seen = set()
            ordered_unique = []
            for cand in map(clean_sentence, candidates):
                if cand in seen:
                    continue
                seen.add(cand)
                ordered_unique.append(cand)

            rewritten_sentences.append(ordered_unique[0])

        # Re-assemble the paragraph from its paraphrased sentences.
        rewritten_paragraphs.append(" ".join(rewritten_sentences))

    # Double line breaks restore the original paragraph structure.
    return "\n\n".join(rewritten_paragraphs)
73
+
74
# --- Gradio Interface ---
# UI definition: one textbox in, one textbox out, plus sampling controls.
# The same endpoint doubles as an API for programmatic callers.
iface = gr.Interface(
    fn=paraphrase_fn,
    inputs=[
        gr.Textbox(lines=12, placeholder="Paste text here..."),
        gr.Slider(1, 3, step=1, value=1, label="Variants"),
        gr.Slider(0.5, 2.0, step=0.1, value=1.2, label="Temperature"),
        gr.Slider(0.6, 1.0, step=0.01, value=0.92, label="Top-p"),
    ],
    outputs=gr.Textbox(label="Output"),
    title="📝 Writenix API",
    description="This Space provides a UI *and* an API for paraphrasing text while preserving paragraphs."
)

# Guard the launch so importing this module (e.g. from tests or another
# script) does not start a server; running the file directly behaves as
# before. This restores the entry-point guard the earlier version had.
if __name__ == "__main__":
    iface.launch()