# Hugging Face Space — the deployed app previously failed with "Runtime error"
# at startup (deprecated tokenizer API; see fix in the code below).
import streamlit as st
from transformers import PegasusForConditionalGeneration, PegasusTokenizer
import torch

st.title("Pegasus Generative QA")

# Checkpoint and device selection. NOTE(review): 'google/pegasus-large' is an
# abstractive-summarization model, so the output is a summary of the input
# rather than true question/answer generation — confirm this is intended.
model_name = 'google/pegasus-large'
torch_device = 'cuda' if torch.cuda.is_available() else 'cpu'


@st.cache_resource
def _load_pegasus():
    """Load the Pegasus tokenizer and model once per process.

    Streamlit reruns the whole script on every interaction; without caching,
    the multi-GB model would be reloaded on each button click.
    """
    tok = PegasusTokenizer.from_pretrained(model_name)
    mdl = PegasusForConditionalGeneration.from_pretrained(model_name).to(torch_device)
    return tok, mdl


tokenizer, model = _load_pegasus()

# Text input box for the source document.
input_text = st.text_area("Enter your input text:")

if st.button("Generate QA"):
    if not input_text.strip():
        # Avoid feeding an empty batch to the model.
        st.warning("Please enter some input text first.")
    else:
        # `prepare_seq2seq_batch` was deprecated and removed from recent
        # transformers releases (it crashed this app). Calling the tokenizer
        # directly with return_tensors="pt" is the supported API and already
        # yields torch tensors, so no manual torch.tensor() wrapping is needed.
        batch = tokenizer(
            [input_text],
            truncation=True,
            padding='longest',
            return_tensors='pt',
        ).to(torch_device)

        # Inference only — disable autograd bookkeeping to save memory/time.
        with torch.no_grad():
            translated = model.generate(**batch)

        tgt_text = tokenizer.batch_decode(translated, skip_special_tokens=True)

        # Show the generated text to the user.
        st.write("Generated QA:")
        st.write(tgt_text[0])