gal-lardo committed on
Commit 91dc934 · verified · 1 Parent(s): d9f25b8

Upload BERT-RTE-LinearClassifier for EEE 486/586 Assignment

Files changed (2)
  1. README.md +19 -1
  2. config.json +5 -5
README.md CHANGED
@@ -23,11 +23,29 @@ Unlike the standard BERT classification approach, this model implements a custom
 
 - Uses BERT base model as the encoder for feature extraction
 - Replaces the standard single linear classification head with **multiple linear layers**:
-  - First expansion layer: hidden_size → hidden_size*2
+  - First expansion layer: hidden_size → hidden_size
   - Intermediate layer with ReLU activation and dropout
   - Final classification layer
 - Uses label smoothing of 0.1 in the loss function for better generalization
 
+## Performance
+
+The model achieves **69.31%** accuracy on the RTE validation set, with the following training dynamics:
+- Best validation accuracy: 69.31% (epoch 4)
+- Final validation accuracy: 69.31% (with early stopping)
+
+## Hyperparameters
+
+The model was optimized using Optuna hyperparameter search:
+
+| Hyperparameter | Value |
+|----------------|-------|
+| Learning rate | 1.304e-05 |
+| Max sequence length | 128 |
+| Dropout rate | 0.1 |
+| Hidden size multiplier | 1 |
+| Batch size | 16 |
+| Training epochs | 4 (early stopping) |
 
 ## Usage
 
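The multi-layer head described in the README diff can be sketched as follows. This is a minimal PyTorch sketch, not the code from this commit: the class name `MultiLayerClassifier` and its argument names are assumptions, while the structure (expansion layer, ReLU with dropout, final layer, label smoothing of 0.1) follows the README.

```python
import torch
import torch.nn as nn


class MultiLayerClassifier(nn.Module):
    """Sketch of the custom head: expansion layer, ReLU + dropout, final layer."""

    def __init__(self, hidden_size=768, multiplier=1, dropout=0.1, num_labels=2):
        super().__init__()
        # Expansion layer: hidden_size -> hidden_size * multiplier
        # (multiplier is 1 after this commit; it was 2 before)
        self.expand = nn.Linear(hidden_size, hidden_size * multiplier)
        # Intermediate non-linearity with dropout
        self.act = nn.ReLU()
        self.dropout = nn.Dropout(dropout)
        # Final classification layer over the two RTE labels
        self.out = nn.Linear(hidden_size * multiplier, num_labels)

    def forward(self, pooled):
        # pooled: [batch, hidden_size] features from the BERT encoder
        return self.out(self.dropout(self.act(self.expand(pooled))))


# Loss with label smoothing of 0.1, as stated in the README
loss_fn = nn.CrossEntropyLoss(label_smoothing=0.1)
```

In use, the `[CLS]` (pooled) representation from the BERT encoder would be fed to this head in place of the standard single linear layer of `BertForSequenceClassification`.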
config.json CHANGED
@@ -3,13 +3,13 @@
     "BertForSequenceClassification"
   ],
   "attention_probs_dropout_prob": 0.1,
-  "classifier_dropout": 0.2,
+  "classifier_dropout": 0.1,
   "custom_params": {
     "batch_size": 16,
-    "hidden_size_multiplier": 2,
-    "learning_rate": 1.7166350301570613e-05,
-    "max_sequence_length": 128,
-    "weight_decay": 0.04
+    "dropout_rate": 0.1,
+    "hidden_size_multiplier": 1,
+    "learning_rate": 1.304261063040958e-05,
+    "max_sequence_length": 128
   },
   "gradient_checkpointing": false,
   "hidden_act": "gelu",
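Since `custom_params` is a non-standard key, it can be read back from `config.json` with plain `json`. A sketch using the values from this commit (the fragment below inlines only the changed keys; the standard BERT config keys are omitted):

```python
import json

# Fragment of config.json after this commit (standard BERT keys omitted)
config_text = """
{
  "classifier_dropout": 0.1,
  "custom_params": {
    "batch_size": 16,
    "dropout_rate": 0.1,
    "hidden_size_multiplier": 1,
    "learning_rate": 1.304261063040958e-05,
    "max_sequence_length": 128
  }
}
"""

config = json.loads(config_text)
params = config["custom_params"]
print(params["hidden_size_multiplier"])   # 1
print(f"{params['learning_rate']:.3e}")   # 1.304e-05
```

Keeping the tuned hyperparameters inside `config.json` this way means the training setup ships with the model weights, at the cost of an extra key that `transformers` itself ignores.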