# PARAGPT / app.py
# Dedeep Vasireddy
# Update app.py
# 74c29d8 verified
import gradio as gr
import torch
from transformers import T5ForConditionalGeneration, T5Tokenizer
def set_seed(seed):
    """Seed PyTorch's global RNG, plus every CUDA device's RNG when a GPU is present.

    Args:
        seed: Integer seed forwarded to ``torch.manual_seed`` and, if CUDA is
            available, ``torch.cuda.manual_seed_all``.
    """
    torch.manual_seed(seed)
    if not torch.cuda.is_available():
        return
    torch.cuda.manual_seed_all(seed)


# Seed once at import time so sampling starts from a fixed RNG state on each start.
set_seed(42)
# Load the pretrained T5 paraphrasing checkpoint and its tokenizer, placing the
# model on the GPU when CUDA is available and on the CPU otherwise.
model_name = 'ramsrigouthamg/t5_paraphraser'
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
tokenizer = T5Tokenizer.from_pretrained(model_name)
model = T5ForConditionalGeneration.from_pretrained(model_name).to(device)
# Paraphrasing function using top-k/top-p sampling
def _unique_paraphrases(candidates, source):
    """Reduce decoded candidates to distinct paraphrases that differ from *source*.

    Each candidate is whitespace-stripped. A candidate is dropped if it is blank,
    equal to *source* ignoring case, or a case-insensitive repeat of an earlier
    candidate. The first occurrence of each kept candidate stays in order.
    """
    source_key = source.strip().lower()
    seen = set()
    unique = []
    for candidate in candidates:
        text = candidate.strip()
        key = text.lower()
        if not text or key == source_key or key in seen:
            continue
        seen.add(key)
        unique.append(text)
    return unique


def paraphrase(text, num_return_sequences=5, top_k=120, top_p=0.98):
    """Generate sampled paraphrases of *text* with the module-level T5 model.

    Args:
        text: Input text. ``None`` or whitespace-only input is rejected.
        num_return_sequences: Number of samples to draw. The result may hold
            fewer, because echoes of the input and duplicates are removed.
        top_k: Top-k sampling cutoff.
        top_p: Nucleus (top-p) sampling threshold.

    Returns:
        The distinct paraphrases joined by blank lines, or a user-facing message
        when the input is empty or no usable paraphrase was produced.
    """
    sentence = (text or "").strip()
    if not sentence:
        return "Please enter valid text."

    # UI widgets may deliver numbers as floats; generate() needs integer counts.
    num_return_sequences = max(1, int(num_return_sequences))
    top_k = int(top_k)
    top_p = float(top_p)

    # Prompt format unchanged (task prefix plus an explicit " </s>"); presumably
    # this mirrors how the checkpoint was trained -- verify against its model card.
    input_text = "paraphrase: " + sentence + " </s>"
    encoding = tokenizer.encode_plus(input_text, padding="longest", return_tensors="pt")
    input_ids = encoding["input_ids"].to(device)
    attention_mask = encoding["attention_mask"].to(device)

    # This is sampling with no num_beams set, so the beam-search-only
    # ``early_stopping`` flag is deliberately not passed (recent transformers
    # versions only emit a warning for it in single-beam generation).
    outputs = model.generate(
        input_ids=input_ids,
        attention_mask=attention_mask,
        do_sample=True,
        max_length=256,
        top_k=top_k,
        top_p=top_p,
        num_return_sequences=num_return_sequences,
    )

    decoded = tokenizer.batch_decode(
        outputs, skip_special_tokens=True, clean_up_tokenization_spaces=True
    )
    final_outputs = _unique_paraphrases(decoded, sentence)
    return "\n\n".join(final_outputs) if final_outputs else "No paraphrases generated."
# Gradio UI: a text box plus sampling controls wired to `paraphrase`.
with gr.Blocks() as demo:
    gr.Markdown("## 📝 T5 Paraphraser with Sampling (Top-k & Top-p)")
    input_text = gr.Textbox(label="Enter Text to Paraphrase", lines=3, placeholder="Type here...")
    # step=1 so every count in [1, 500] -- including the default of 5 -- can be
    # selected; with step=10 from a minimum of 1 the slider snapped to 1, 11, 21, ...
    num_return = gr.Slider(1, 500, step=1, value=5, label="Number of Paraphrases")
    top_k_slider = gr.Slider(10, 200, step=10, value=120, label="Top-k Sampling")
    top_p_slider = gr.Slider(0.5, 1.0, step=0.01, value=0.98, label="Top-p Nucleus Sampling")
    output = gr.Textbox(label="Paraphrased Outputs", lines=10)
    btn = gr.Button("Generate Paraphrases")
    # Argument order matches paraphrase(text, num_return_sequences, top_k, top_p).
    btn.click(fn=paraphrase, inputs=[input_text, num_return, top_k_slider, top_p_slider], outputs=output)
demo.launch()