|
|
import streamlit as st
import torch
from kobart import get_kobart_tokenizer
from transformers.models.bart import BartForConditionalGeneration
|
|
|
|
|
@st.cache(allow_output_mutation=True)
def load_model():
    """Load and cache the pretrained KoBART model.

    Returns:
        BartForConditionalGeneration: model loaded from the
        "gogamza/kobart-base-v1" checkpoint on the Hugging Face hub.
    """
    # allow_output_mutation=True: without it st.cache hashes the returned
    # torch model on every rerun (expensive) and raises a
    # CachedObjectMutationWarning when generate() mutates internal state.
    model = BartForConditionalGeneration.from_pretrained("gogamza/kobart-base-v1")
    return model
|
|
|
|
|
# --- Streamlit app body: KoBART news summarization demo ---
model = load_model()
tokenizer = get_kobart_tokenizer()

st.title("KoBART 요약 Test")
text = st.text_area("뉴스 입력:")

st.markdown("## 뉴스 원문")
st.write(text)

if text:
    # NOTE(review): replacing '\n' with '' glues the last word of one line to
    # the first word of the next; kept as-is to preserve existing behavior.
    text = text.replace('\n', '')
    st.markdown("## KoBART 요약 결과")
    with st.spinner('processing..'):
        input_ids = tokenizer.encode(text)
        # generate() expects a batch dimension: shape (1, seq_len).
        input_ids = torch.tensor(input_ids).unsqueeze(0)
        # Inference only — no_grad() skips building the autograd graph,
        # cutting memory use during beam search.
        with torch.no_grad():
            # eos_token_id=1 — presumably the KoBART tokenizer's </s> id;
            # TODO confirm against tokenizer.eos_token_id.
            output = model.generate(input_ids, eos_token_id=1, max_length=512, num_beams=5)
        output = tokenizer.decode(output[0], skip_special_tokens=True)
        st.write(output)
|
|
|