# VictorM-Coder — Update app.py (commit 8fdcc72, verified)
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
import torch, gradio as gr
import re
# --- Load Model ---
# Seq2seq paraphraser fine-tuned on PAWS; downloaded from the HF Hub on first run.
model_name = "Vamsi/T5_Paraphrase_Paws" # switched to Vamsi model
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
# Prefer GPU when available; all input tensors are moved to this device too.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)
# Inference mode: disables dropout/batch-norm training behavior.
model.eval()
# --- Helpers ---
def split_paragraphs(text):
    """Split *text* into paragraphs on newlines, dropping blank lines.

    Each returned paragraph is stripped of surrounding whitespace.
    """
    paragraphs = []
    for raw_line in text.split("\n"):
        stripped = raw_line.strip()
        if stripped:
            paragraphs.append(stripped)
    return paragraphs
def split_sentences(text):
    """Split a paragraph into sentences.

    A boundary is any whitespace run preceded by '.', '!' or '?'.
    Empty fragments are discarded.
    """
    fragments = re.split(r'(?<=[.!?])\s+', text.strip())
    return list(filter(None, fragments))
def clean_sentence(sent):
    """Collapse internal whitespace and guarantee terminal punctuation."""
    normalized = re.sub(r'\s+', ' ', sent).strip()
    if normalized.endswith(('.', '!', '?')):
        return normalized
    return normalized + "."
# --- Main function ---
def paraphrase_fn(text, num_return_sequences=1, temperature=1.2, top_p=0.92):
    """Paraphrase *text* sentence-by-sentence while preserving paragraphs.

    Parameters
    ----------
    text : str
        Input text; paragraphs are separated by newlines.
    num_return_sequences : int
        Candidates sampled per sentence. Only the first unique candidate
        is kept, so values > 1 merely widen the sampling pool.
    temperature, top_p : float
        Sampling hyperparameters forwarded to ``model.generate``.

    Returns
    -------
    str
        Paraphrased text with paragraphs re-joined by blank lines, or a
        prompt message when the input is empty/whitespace.
    """
    if not text.strip():
        return "Enter some text"
    num_return_sequences = int(num_return_sequences)
    paraphrased_paragraphs = []
    for para in split_paragraphs(text):
        paraphrased_sentences = []
        for sent in split_sentences(para):
            # Prompt format from the Vamsi/T5_Paraphrase_Paws model card,
            # including the explicit " </s>" terminator.
            input_text = "paraphrase: " + sent + " </s>"
            inputs = tokenizer([input_text], return_tensors="pt", truncation=True, padding=True).to(device)
            # no_grad: model.eval() alone does NOT disable autograd, so
            # without this generate() tracks gradients and wastes memory.
            with torch.no_grad():
                outputs = model.generate(
                    **inputs,
                    max_new_tokens=128,
                    num_return_sequences=num_return_sequences,
                    do_sample=True,
                    top_p=float(top_p),
                    temperature=float(temperature),
                )
            decoded = tokenizer.batch_decode(outputs, skip_special_tokens=True)
            # De-duplicate candidates while preserving generation order.
            seen, unique = set(), []
            for candidate in decoded:
                candidate = clean_sentence(candidate)
                if candidate not in seen:
                    unique.append(candidate)
                    seen.add(candidate)
            # Robustness: fall back to the original sentence if decoding
            # somehow produced no usable candidate.
            paraphrased_sentences.append(unique[0] if unique else clean_sentence(sent))
        # Join sentences for this paragraph
        paraphrased_paragraphs.append(" ".join(paraphrased_sentences))
    # Join paragraphs with double line breaks to preserve paragraphing
    return "\n\n".join(paraphrased_paragraphs)
# --- Gradio Interface ---
# Build the UI components up-front so the Interface call stays compact.
input_widgets = [
    gr.Textbox(lines=12, placeholder="Paste text here..."),
    gr.Slider(1, 3, step=1, value=1, label="Variants"),
    gr.Slider(0.5, 2.0, step=0.1, value=1.2, label="Temperature"),
    gr.Slider(0.6, 1.0, step=0.01, value=0.92, label="Top-p"),
]
output_widget = gr.Textbox(label="Output")

iface = gr.Interface(
    fn=paraphrase_fn,
    inputs=input_widgets,
    outputs=output_widget,
    title="📝 Writenix API",
    description="This Space provides a UI *and* an API for paraphrasing text while preserving paragraphs.",
)
iface.launch()