# pegasus-test / app.py
# Author: varun500
# Last commit: "Update app.py" (256dc7a)
import streamlit as st
from transformers import PegasusForConditionalGeneration, PegasusTokenizer
import torch
st.title("Pegasus Generative QA")
# NOTE(review): google/pegasus-large appears to be a general pre-trained
# summarization checkpoint rather than a question-answering model; confirm the
# "QA" wording in the UI matches the intended use.

# Load the Pegasus model and tokenizer
model_name = 'google/pegasus-large'
torch_device = 'cuda' if torch.cuda.is_available() else 'cpu'


@st.cache_resource
def _load_model(name, device):
    """Load the Pegasus tokenizer and model once and reuse them across reruns.

    Streamlit re-executes this whole script on every widget interaction
    (typing in the text area, pressing the button). Without caching, the
    tokenizer and model weights would be reloaded via ``from_pretrained`` on
    every one of those reruns.

    Args:
        name: Hugging Face model identifier passed to ``from_pretrained``.
        device: torch device string ('cuda' or 'cpu') the model is moved to.

    Returns:
        A ``(tokenizer, model)`` tuple.
    """
    tok = PegasusTokenizer.from_pretrained(name)
    mdl = PegasusForConditionalGeneration.from_pretrained(name).to(device)
    return tok, mdl


tokenizer, model = _load_model(model_name, torch_device)
# Create a text input box for the user to enter their input text
input_text = st.text_area("Enter your input text:")
if st.button("Generate QA"):
    if not input_text.strip():
        # Nothing to condition generation on; skip the model call entirely.
        st.warning("Please enter some input text first.")
    else:
        # Prepare the input text for the Pegasus model. Calling the tokenizer
        # directly replaces the deprecated ``prepare_seq2seq_batch`` and
        # returns PyTorch tensors (input_ids, attention_mask) that are moved
        # to the same device as the model.
        batch = tokenizer(
            [input_text],
            truncation=True,
            padding='longest',
            return_tensors='pt',
        ).to(torch_device)
        # Generate the output token ids using the Pegasus model
        translated = model.generate(**batch)
        tgt_text = tokenizer.batch_decode(translated, skip_special_tokens=True)
        # Display the generated output text to the user (single-item batch)
        st.write("Generated QA:")
        st.write(tgt_text[0])