---
language: en
license: mit
datasets:
- glue/rte
tags:
- text-classification
- glue
- bert
- recognizing textual entailment
- assignment
- mean-pooling
metrics:
- accuracy
---
# BERT + Mean Pooling + MLP for RTE (EEE 486/586 Assignment - Part 2)
This model is a fine-tuned version of `bert-base-uncased` on the RTE (Recognizing Textual Entailment) task from the GLUE benchmark. It was developed as part of the EEE 486/586 Statistical Foundations of Natural Language Processing course assignment (Part 2).
## Model Architecture
This model explores an alternative to the standard `BertForSequenceClassification` architecture:
- Uses the standard `bert-base-uncased` model to obtain token embeddings (`last_hidden_state`).
- **Mean Pooling:** Instead of using the [CLS] token's pooler output, it calculates the mean of the `last_hidden_state` across all non-padding tokens (using the attention mask) to get a single sequence representation vector.
- **MLP Classifier Head:** The mean-pooled representation is passed through dropout and then a multi-layer perceptron (MLP) head for classification. The MLP structure was determined by the hyperparameter search (`hidden_size_multiplier=4`).
- The final layer outputs logits for the 2 classes (`entailment` / `not_entailment`).
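The pooling and classifier head described above can be sketched as follows. This is a minimal illustration, not the assignment's actual code: the exact MLP depth, activation function, and layer names are assumptions; only the mean-pooling-over-the-attention-mask logic and the `hidden_size_multiplier=4` expansion are taken from this card.

```python
import torch
import torch.nn as nn

def mean_pool(last_hidden_state, attention_mask):
    """Average token embeddings over non-padding positions only."""
    mask = attention_mask.unsqueeze(-1).float()       # (B, T) -> (B, T, 1)
    summed = (last_hidden_state * mask).sum(dim=1)    # sum of real tokens, (B, H)
    counts = mask.sum(dim=1).clamp(min=1e-9)          # number of real tokens, (B, 1)
    return summed / counts                            # mean-pooled vector, (B, H)

class MLPHead(nn.Module):
    """Dropout + MLP classifier over the pooled vector (structure assumed)."""
    def __init__(self, hidden_size=768, multiplier=4, dropout=0.4, num_labels=2):
        super().__init__()
        self.net = nn.Sequential(
            nn.Dropout(dropout),
            nn.Linear(hidden_size, hidden_size * multiplier),  # 768 -> 3072 with multiplier=4
            nn.GELU(),
            nn.Linear(hidden_size * multiplier, num_labels),   # logits for 2 classes
        )

    def forward(self, pooled):
        return self.net(pooled)
```

In practice `last_hidden_state` and `attention_mask` would come from a `bert-base-uncased` forward pass; padding tokens are excluded from the average so sequence length does not skew the representation.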
**Note:** Because this uses a custom architecture (`BertMeanPoolClassifier`), it cannot be loaded directly using `AutoModelForSequenceClassification.from_pretrained()`. You need the model's class definition (provided in the assignment code/report) and then load the `state_dict` (`pytorch_model.bin`) into an instance of that class.
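The loading pattern looks roughly like the sketch below. The stub class here is a stand-in for the real `BertMeanPoolClassifier` (whose definition ships with the assignment code/report), and the checkpoint is simulated with a local save; with the real class and the `pytorch_model.bin` from this repository, the `load_state_dict` step is the same.

```python
import os
import tempfile

import torch
import torch.nn as nn

# Stand-in for the real BertMeanPoolClassifier; only the load pattern matters here.
class BertMeanPoolClassifierStub(nn.Module):
    def __init__(self, hidden_size=8, num_labels=2):
        super().__init__()
        self.classifier = nn.Linear(hidden_size, num_labels)

    def forward(self, pooled):
        return self.classifier(pooled)

# Simulate the checkpoint on disk (in practice, download pytorch_model.bin
# from this repository instead).
ckpt = os.path.join(tempfile.mkdtemp(), "pytorch_model.bin")
torch.save(BertMeanPoolClassifierStub().state_dict(), ckpt)

# 1) Instantiate the class yourself -- AutoModelForSequenceClassification
#    cannot construct a custom architecture.
model = BertMeanPoolClassifierStub()

# 2) Load the saved weights into that instance and switch to eval mode.
model.load_state_dict(torch.load(ckpt, map_location="cpu"))
model.eval()
```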
## Performance
The model was trained with hyperparameters found via Optuna. The final run trained for 5 epochs, keeping the checkpoint with the highest validation accuracy, and achieved the following:
- **Best Validation Accuracy:** **0.6931** (achieved at Epoch 3)
- Final Validation Accuracy (Epoch 5): 0.6823
- Final Validation Loss (Epoch 5): 1.4258
- Final Training Loss (Epoch 5): 0.0797
The model fit the training data well but began overfitting after epoch 3, as the rising validation loss indicates. The checkpoint with the best validation accuracy was saved.
## Best Hyperparameters (from Optuna)
| Hyperparameter | Value |
|--------------------------|-----------------------|
| Learning Rate | 3.518e-05 |
| Max Sequence Length | 128 |
| Dropout Rate (Classifier)| 0.4 |
| Batch Size | 16 |
| Hidden Size Multiplier | 4 |
| Epochs (Optuna Best Trial) | 3 |
## Intended Use & Limitations
This model is intended solely for the RTE task in the context of this course assignment and has not been evaluated on other datasets. Due to its custom architecture, it cannot be loaded directly with `AutoModelForSequenceClassification`; see the note under Model Architecture.