sedtha commited on
Commit
13abb53
·
verified ·
1 Parent(s): edd7bc2

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +1 -1
main.py CHANGED
@@ -97,7 +97,7 @@ def predict(text: str) -> str:
97
  pred = torch.argmax(out, dim=-1)[0]
98
 
99
  # Keep the prediction same length as input
100
- pred = pred[:input_len]
101
 
102
  return "".join(idx_to_char[i.item()] for i in pred)
103
 
 
97
  pred = torch.argmax(out, dim=-1)[0]
98
 
99
  # Keep the prediction same length as input
100
+ pred = pred[:input_len+1]
101
 
102
  return "".join(idx_to_char[i.item()] for i in pred)
103