File size: 754 Bytes
5775e05 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 | from transformers import PretrainedConfig
class BertFFNNConfig(PretrainedConfig):
    """Configuration for a pretrained text encoder topped by a feed-forward head.

    Every hyper-parameter is stored as an instance attribute so that it is
    saved and restored together with the rest of the ``PretrainedConfig``
    state. Unrecognised keyword arguments are forwarded unchanged to
    ``PretrainedConfig.__init__``.

    Args:
        bert_model_name: Name or path of the pretrained encoder checkpoint.
            Note that the default is a DeBERTa-v3 checkpoint, despite the
            ``bert`` naming of this class and argument.
        hidden_dims: Sequence of hidden sizes for the feed-forward head.
            It presumably has one entry per hidden layer; confirm against the
            model code. ``None`` (the default) means ``[192, 96]``. A new list
            is always stored, so instances never share or alias a list.
        output_dim: Output size of the head (default ``5``).
        dropout: Dropout value (default ``0.2``); presumably a probability.
        pooling: Name of the pooling strategy (default ``"attention"``). The
            accepted values are defined by the consuming model, not here.
        freeze_bert: Flag stored for the model. Presumably it freezes the
            encoder weights; confirm against the model code.
        freeze_layers: Integer stored for the model (default ``0``). Its exact
            meaning (e.g. number of bottom encoder layers to freeze) is not
            visible here; TODO confirm.
        use_layer_norm: Flag stored for the model. Presumably it enables
            layer normalisation in the head.
        **kwargs: Forwarded to ``PretrainedConfig.__init__``.
    """

    # Identifier transformers uses for this config type
    # (e.g. when serialising or registering with Auto* classes).
    model_type = "bert_ffnn"

    def __init__(
        self,
        bert_model_name: str = "microsoft/deberta-v3-base",
        hidden_dims=None,
        output_dim: int = 5,
        dropout: float = 0.2,
        pooling: str = "attention",
        freeze_bert: bool = False,
        freeze_layers: int = 0,
        use_layer_norm: bool = True,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.bert_model_name = bert_model_name
        # A literal list default would be created once and then shared by
        # every instance. Build a fresh list per instance instead, and copy
        # caller-supplied sequences so later mutation by the caller (or of
        # this config) cannot leak across objects.
        self.hidden_dims = [192, 96] if hidden_dims is None else list(hidden_dims)
        self.output_dim = output_dim
        self.dropout = dropout
        self.pooling = pooling
        self.freeze_bert = freeze_bert
        self.freeze_layers = freeze_layers
        self.use_layer_norm = use_layer_norm
|