Commit
·
b8ddcf4
1
Parent(s):
af52488
Fix an error
Browse files- distill_bert_to_lstm.py +1 -1
distill_bert_to_lstm.py
CHANGED
|
@@ -147,7 +147,7 @@ def main():
|
|
| 147 |
bert_model.load_state_dict(torch.load(args.bert_model_path, map_location=device))
|
| 148 |
logger.info(f"Loaded teacher model from {args.bert_model_path}")
|
| 149 |
|
| 150 |
-
vocab_size =
|
| 151 |
|
| 152 |
logger.info(f"LSTM Vocabulary size: {vocab_size}")
|
| 153 |
print("LSTM Vocabulary size: ", vocab_size)
|
|
|
|
| 147 |
bert_model.load_state_dict(torch.load(args.bert_model_path, map_location=device))
|
| 148 |
logger.info(f"Loaded teacher model from {args.bert_model_path}")
|
| 149 |
|
| 150 |
+
vocab_size = bert_train_dataset.tokenizer.vocab_size
|
| 151 |
|
| 152 |
logger.info(f"LSTM Vocabulary size: {vocab_size}")
|
| 153 |
print("LSTM Vocabulary size: ", vocab_size)
|