Update app.py
Browse files
app.py
CHANGED
|
@@ -27,9 +27,10 @@ def load_model():
|
|
| 27 |
encoder = model.get_encoder()
|
| 28 |
decoder = model.model.decoder
|
| 29 |
lm_head = model.lm_head
|
| 30 |
-
|
|
|
|
| 31 |
|
| 32 |
-
encoder, decoder, lm_head, tokenizer = load_model()
|
| 33 |
|
| 34 |
# Title
|
| 35 |
st.title("Czech punctuation and capitalization restoration (BART-Small)")
|
|
@@ -56,7 +57,7 @@ if submit_button:
|
|
| 56 |
max_length = 50
|
| 57 |
|
| 58 |
# Empty input for generated text
|
| 59 |
-
generated_ids = torch.tensor([[
|
| 60 |
|
| 61 |
output_placeholder = st.empty()
|
| 62 |
|
|
|
|
| 27 |
encoder = model.get_encoder()
|
| 28 |
decoder = model.model.decoder
|
| 29 |
lm_head = model.lm_head
|
| 30 |
+
start_token = model.config.decoder_start_token_id
|
| 31 |
+
return encoder, decoder, lm_head, tokenizer, start_token
|
| 32 |
|
| 33 |
+
encoder, decoder, lm_head, tokenizer, start_token = load_model()
|
| 34 |
|
| 35 |
# Title
|
| 36 |
st.title("Czech punctuation and capitalization restoration (BART-Small)")
|
|
|
|
| 57 |
max_length = 50
|
| 58 |
|
| 59 |
# Empty input for generated text
|
| 60 |
+
generated_ids = torch.tensor([[start_token]])
|
| 61 |
|
| 62 |
output_placeholder = st.empty()
|
| 63 |
|