jesse-tong committed
Commit 8e3d6fe · 1 Parent(s): 8a33e9c

Reduce max sequence length to 250, as PhoBERT allows a maximum of 256 tokens

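The arithmetic behind the new default, as a minimal sketch (not part of this commit): PhoBERT is RoBERTa-based, so two of its 258 position-embedding slots are reserved for the padding offset, leaving 256 usable tokens; 250 keeps some headroom for special tokens. The transformers library and the public checkpoint vinai/phobert-base are assumed here for illustration.

    from transformers import AutoConfig

    config = AutoConfig.from_pretrained("vinai/phobert-base")
    # RoBERTa-style models reserve two position slots for the padding offset,
    # so the usable sequence length is max_position_embeddings - 2.
    usable = config.max_position_embeddings - 2  # 258 - 2 = 256
    default_max_seq_length = min(250, usable)    # this commit's new default
    print(usable, default_max_seq_length)        # 256 250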
distill_bert_to_lstm.py CHANGED
@@ -44,7 +44,7 @@ def main():
     # BERT model arguments
     parser.add_argument("--bert_model", type=str, default="bert-base-uncased", help="BERT model to use")
     parser.add_argument("--bert_model_path", type=str, required=True, help="Path to saved BERT model weights")
-    parser.add_argument("--max_seq_length", type=int, default=512, help="Maximum sequence length")
+    parser.add_argument("--max_seq_length", type=int, default=250, help="Maximum sequence length (e.g., 250 for PhoBERT, whose max_position_embeddings is 258)")
 
     # LSTM model arguments
     parser.add_argument("--embedding_dim", type=int, default=300, help="Dimension of word embeddings in LSTM")
inference_example.py CHANGED
@@ -11,7 +11,7 @@ if __name__ == "__main__":
     parser.add_argument("--data_path", type=str, required=True, help="Path to the dataset")
     parser.add_argument("--bert_model", type=str, default="bert-base-uncased", help="Pre-trained BERT model name")
     parser.add_argument("--model_path", type=str, required=True, help="Path to the trained model")
-    parser.add_argument("--max_seq_length", type=int, default=512, help="Maximum sequence length for BERT")
+    parser.add_argument("--max_seq_length", type=int, default=250, help="Maximum sequence length for BERT (e.g., 250 for PhoBERT, whose max_position_embeddings is 258)")
     parser.add_argument("--batch_size", type=int, default=32, help="Batch size for training and evaluation")
     parser.add_argument("--num_classes", type=int, required=True, help="Number of classes for classification")
     parser.add_argument("--text_column", type=str, default="text", help="Column name for text data")
inference_lstm.py CHANGED
@@ -16,7 +16,7 @@ if __name__ == "__main__":
     parser.add_argument("--data_path", type=str, required=True, help="Path to the dataset")
     parser.add_argument("--bert_model", type=str, default="bert-base-uncased", help="BERT model name or path used for distillation (as we'll use its tokenizer)")
     parser.add_argument("--model_path", type=str, required=True, help="Path to the trained model")
-    parser.add_argument("--max_seq_length", type=int, default=512, help="Maximum sequence length for LSTM")
+    parser.add_argument("--max_seq_length", type=int, default=250, help="Maximum sequence length for LSTM")
     parser.add_argument("--batch_size", type=int, default=32, help="Batch size for training and evaluation")
     parser.add_argument("--num_classes", type=int, required=True, help="Number of classes for classification")
     parser.add_argument("--text_column", type=str, default="text", help="Column name for text data")
train.py CHANGED
@@ -40,7 +40,7 @@ def main():
     parser.add_argument("--bert_model", type=str, default="bert-base-uncased",
                         help="BERT model to use (e.g., bert-base-uncased, bert-large-uncased)")
     parser.add_argument("--num_classes", type=int, required=True, help="Number of classes to predict")
-    parser.add_argument("--max_length", type=int, default=512, help="Maximum sequence length")
+    parser.add_argument("--max_length", type=int, default=250, help="Maximum sequence length (PhoBERT has 258 max_position_embeddings, so 250 is a safe choice)")
     parser.add_argument("--dropout", type=float, default=0.1, help="Dropout probability")
 
     # Training arguments
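Not part of this commit, but a natural follow-up for train.py: a startup guard (sketch; check_max_length is a hypothetical helper) that fails fast when --max_length exceeds what the chosen checkpoint supports, instead of crashing later inside the embedding lookup.

    from transformers import AutoConfig

    def check_max_length(bert_model: str, max_length: int) -> None:
        # Hypothetical helper, not in the repo: fail fast on oversized limits.
        config = AutoConfig.from_pretrained(bert_model)
        # RoBERTa-style models such as PhoBERT reserve two position slots for
        # the padding offset; plain BERT has no such offset.
        offset = 2 if config.model_type == "roberta" else 0
        limit = config.max_position_embeddings - offset
        if max_length > limit:
            raise ValueError(
                f"--max_length {max_length} exceeds the {limit}-token limit of {bert_model}"
            )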