mpolacek commited on
Commit
0d38908
·
verified ·
1 Parent(s): a81a35c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +4 -3
app.py CHANGED
@@ -27,9 +27,10 @@ def load_model():
27
  encoder = model.get_encoder()
28
  decoder = model.model.decoder
29
  lm_head = model.lm_head
30
- return encoder, decoder, lm_head, tokenizer
 
31
 
32
- encoder, decoder, lm_head, tokenizer = load_model()
33
 
34
  # Title
35
  st.title("Czech punctuation and capitalization restoration (BART-Small)")
@@ -56,7 +57,7 @@ if submit_button:
56
  max_length = 50
57
 
58
  # Empty input for generated text
59
- generated_ids = torch.tensor([[model.config.decoder_start_token_id]])
60
 
61
  output_placeholder = st.empty()
62
 
 
27
  encoder = model.get_encoder()
28
  decoder = model.model.decoder
29
  lm_head = model.lm_head
30
+ start_token = model.config.decoder_start_token_id
31
+ return encoder, decoder, lm_head, tokenizer, start_token
32
 
33
+ encoder, decoder, lm_head, tokenizer, start_token = load_model()
34
 
35
  # Title
36
  st.title("Czech punctuation and capitalization restoration (BART-Small)")
 
57
  max_length = 50
58
 
59
  # Empty input for generated text
60
+ generated_ids = torch.tensor([[start_token]])
61
 
62
  output_placeholder = st.empty()
63