haydenCho commited on
Commit
633b064
·
1 Parent(s): 7b2286b

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +29 -0
app.py ADDED
@@ -0,0 +1,29 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import streamlit as st
3
+ from kobart import get_kobart_tokenizer
4
+ from transformers.models.bart import BartForConditionalGeneration
5
+
6
+ @st.cache
7
+ def load_model():
8
+ model = BartForConditionalGeneration.from_pretrained("gogamza/kobart-base-v1")
9
+ # tokenizer = get_kobart_tokenizer()
10
+ return model
11
+
12
+ model = load_model()
13
+ tokenizer = get_kobart_tokenizer()
14
+ st.title("KoBART 요약 Test")
15
+ text = st.text_area("뉴스 입력:")
16
+
17
+ st.markdown("## 뉴스 원문")
18
+ st.write(text)
19
+
20
+ if text:
21
+ text = text.replace('\n', '')
22
+ st.markdown("## KoBART 요약 결과")
23
+ with st.spinner('processing..'):
24
+ input_ids = tokenizer.encode(text)
25
+ input_ids = torch.tensor(input_ids)
26
+ input_ids = input_ids.unsqueeze(0)
27
+ output = model.generate(input_ids, eos_token_id=1, max_length=512, num_beams=5)
28
+ output = tokenizer.decode(output[0], skip_special_tokens=True)
29
+ st.write(output)