"""Streamlit front end that feeds user-entered text to a Pegasus model.

The user types text into a text area; pressing the button runs
``model.generate`` on it and shows the first decoded output.

NOTE(review): ``google/pegasus-large`` is a general Pegasus checkpoint, not one
fine-tuned for question answering, despite the "QA" labels in the UI — confirm
this is the intended model.
"""
import streamlit as st
import torch
from transformers import PegasusForConditionalGeneration, PegasusTokenizer

st.title("Pegasus Generative QA")

# Load the Pegasus model and tokenizer
model_name = 'google/pegasus-large'
torch_device = 'cuda' if torch.cuda.is_available() else 'cpu'


@st.cache_resource
def _load_model(name: str, device: str):
    """Load the Pegasus tokenizer and model once and reuse them across reruns.

    Streamlit re-executes this whole script on every widget interaction, so
    without caching the (large) model would be re-read from disk and re-moved
    to ``device`` each time.

    Args:
        name: Hugging Face model identifier passed to ``from_pretrained``.
        device: torch device string the model is moved to ("cuda" or "cpu").

    Returns:
        A ``(tokenizer, model)`` tuple.
    """
    tok = PegasusTokenizer.from_pretrained(name)
    mdl = PegasusForConditionalGeneration.from_pretrained(name).to(device)
    return tok, mdl


def _generate(text: str, tok, mdl, device: str) -> str:
    """Run ``mdl.generate`` on a single input string and return the decoded text.

    Args:
        text: Raw input text; truncated to the tokenizer's max length.
        tok: The Pegasus tokenizer.
        mdl: The Pegasus model, already on ``device``.
        device: Device the input tensors are moved to.

    Returns:
        The first (and only) decoded output sequence, special tokens removed.
    """
    # Calling the tokenizer directly replaces the deprecated
    # ``prepare_seq2seq_batch``; ``return_tensors="pt"`` yields tensors, so no
    # manual ``torch.tensor`` wrapping is needed.
    batch = tok([text], truncation=True, padding='longest', return_tensors='pt').to(device)
    with torch.no_grad():  # inference only; no gradients needed
        generated = mdl.generate(**batch)
    return tok.batch_decode(generated, skip_special_tokens=True)[0]


tokenizer, model = _load_model(model_name, torch_device)

# Create a text input box for the user to enter their input text
input_text = st.text_area("Enter your input text:")

if st.button("Generate QA"):
    if not input_text.strip():
        # Avoid running generation on an empty prompt.
        st.warning("Please enter some input text first.")
    else:
        tgt_text = _generate(input_text, tokenizer, model, torch_device)
        # Display the generated output text to the user
        st.write("Generated QA:")
        st.write(tgt_text)