import gradio as gr
import torch
from transformers import T5ForConditionalGeneration, T5Tokenizer


def set_seed(seed):
    """Seed torch RNGs (CPU and all CUDA devices) for reproducible sampling."""
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)


set_seed(42)

# Load model and tokenizer once at import time.
model_name = 'ramsrigouthamg/t5_paraphraser'
tokenizer = T5Tokenizer.from_pretrained(model_name)
model = T5ForConditionalGeneration.from_pretrained(model_name)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)
model.eval()  # inference only — disables dropout


def paraphrase(text, num_return_sequences=5, top_k=120, top_p=0.98):
    """Generate paraphrases of *text* using top-k / top-p (nucleus) sampling.

    Args:
        text: Sentence to paraphrase.
        num_return_sequences: How many sampled paraphrases to request.
        top_k: Restrict sampling to the k most likely tokens per step.
        top_p: Nucleus-sampling cumulative-probability cutoff.

    Returns:
        The unique paraphrases joined by blank lines, or a user-facing
        message when the input is empty or nothing usable was generated.
    """
    sentence = text.strip()
    if not sentence:
        return "Please enter valid text."

    # Gradio sliders can deliver floats; generate() requires ints here.
    num_return_sequences = int(num_return_sequences)
    top_k = int(top_k)

    input_text = "paraphrase: " + sentence + " "
    encoding = tokenizer.encode_plus(
        input_text, padding="longest", return_tensors="pt"
    )
    input_ids = encoding["input_ids"].to(device)
    attention_masks = encoding["attention_mask"].to(device)

    # no_grad: pure inference — skip building the autograd graph.
    # NOTE: early_stopping was removed; it only applies to beam search,
    # not sampling, and emits a transformers warning when combined with
    # do_sample=True.
    with torch.no_grad():
        outputs = model.generate(
            input_ids=input_ids,
            attention_mask=attention_masks,
            do_sample=True,
            max_length=256,
            top_k=top_k,
            top_p=top_p,
            num_return_sequences=num_return_sequences
        )

    final_outputs = []
    for beam_output in outputs:
        decoded = tokenizer.decode(
            beam_output,
            skip_special_tokens=True,
            clean_up_tokenization_spaces=True
        )
        # Keep only genuine paraphrases: drop echoes of the input and dupes.
        if decoded.lower() != sentence.lower() and decoded not in final_outputs:
            final_outputs.append(decoded)

    return "\n\n".join(final_outputs) if final_outputs else "No paraphrases generated."
# Gradio UI
with gr.Blocks() as demo:
    gr.Markdown("## 📝 T5 Paraphraser with Sampling (Top-k & Top-p)")
    input_text = gr.Textbox(label="Enter Text to Paraphrase", lines=3, placeholder="Type here...")
    # BUG FIX: the original slider was (1, 500, step=10, value=5) — the
    # default 5 is not reachable on a step-10 grid starting at 1, and 500
    # return sequences is far beyond anything generate() can usefully
    # produce. Use a sane 1-20 range stepped by 1.
    num_return = gr.Slider(1, 20, step=1, value=5, label="Number of Paraphrases")
    top_k_slider = gr.Slider(10, 200, step=10, value=120, label="Top-k Sampling")
    top_p_slider = gr.Slider(0.5, 1.0, step=0.01, value=0.98, label="Top-p Nucleus Sampling")
    output = gr.Textbox(label="Paraphrased Outputs", lines=10)
    btn = gr.Button("Generate Paraphrases")
    btn.click(
        fn=paraphrase,
        inputs=[input_text, num_return, top_k_slider, top_p_slider],
        outputs=output
    )

demo.launch()