Spaces:
Runtime error
Runtime error
| from transformers import PreTrainedTokenizerFast | |
| from tokenizers import SentencePieceBPETokenizer | |
| from transformers import BartForConditionalGeneration | |
| import streamlit as st | |
| import torch | |
def tokenizer():
    """Load the fast tokenizer of the fine-tuned KoBART review summarizer.

    Returns:
        A ``PreTrainedTokenizerFast`` loaded from the
        ``dnrso/koBART_Sum_Review_finetuning`` checkpoint.
    """
    checkpoint = 'dnrso/koBART_Sum_Review_finetuning'
    return PreTrainedTokenizerFast.from_pretrained(checkpoint)
def get_model():
    """Load the fine-tuned KoBART summarization model, ready for inference.

    Returns:
        A ``BartForConditionalGeneration`` loaded from the
        ``dnrso/koBART_Sum_Review_finetuning`` checkpoint and switched to
        evaluation mode (dropout disabled).
    """
    # NOTE(review): Streamlit re-executes the whole script on every widget
    # interaction, so this reload happens on each rerun; consider caching
    # (e.g. st.cache_resource) if the installed Streamlit version supports it.
    checkpoint = 'dnrso/koBART_Sum_Review_finetuning'
    # nn.Module.eval() switches the module to eval mode and returns it.
    return BartForConditionalGeneration.from_pretrained(checkpoint).eval()
# Sample review pre-filled into the input box.
default_text = '''게임을 하면서 사용하기 좋아요 음질도 괜찮고 착용감도 좋고 이어컵측면에 불빛도 이뻐요 가성비 정말 좋은 제품입니다
'''

model = get_model()
# NOTE: this rebinds the module-level name ``tokenizer`` from the loader
# function to the tokenizer object it returns.
tokenizer = tokenizer()

st.title("Review Summarization Test")
text = st.text_area("Input:", value=default_text)
st.markdown("Review Data")
st.write(text)

# Skip generation for empty or whitespace-only input: nothing to summarize.
if text.strip():
    st.markdown("## Predict Summary")
    with st.spinner('processing..'):
        raw_input_ids = tokenizer.encode(text)
        # Reserve two positions for the BOS/EOS tokens added below, so that a
        # long review cannot exceed the model's learned position embeddings
        # (which would otherwise raise an indexing error inside the encoder).
        max_body_len = model.config.max_position_embeddings - 2
        input_ids = (
            [tokenizer.bos_token_id]
            + raw_input_ids[:max_body_len]
            + [tokenizer.eos_token_id]
        )
        summary_ids = model.generate(
            torch.tensor([input_ids]),
            max_length=256,
            # Only affects beam-based generation (num_beams > 1); with the
            # default greedy decoding this flag has no effect.
            early_stopping=True,
            repetition_penalty=2.0,
        )
        # Take the single batch row explicitly; squeeze() would collapse a
        # length-1 sequence into a scalar.
        summ = tokenizer.decode(summary_ids[0].tolist(), skip_special_tokens=True)
    st.write(summ)