from transformers import PretrainedConfig
class SimpleMLPConfig(PretrainedConfig):
    """Configuration holding the hyperparameters of a simple MLP model.

    The four hyperparameters are validated, stored as instance attributes,
    and every remaining keyword argument is forwarded unchanged to
    ``PretrainedConfig.__init__``.

    Args:
        input_dim: Stored as ``self.input_dim``. Presumably the size of the
            input feature vector (the model code is not visible here).
        hidden_dim: Stored as ``self.hidden_dim``. Presumably the width of
            the hidden layer.
        num_classes: Stored as ``self.num_classes``. Presumably the number
            of output classes.
        dropout_rate: Stored as ``self.dropout_rate``. Presumably a dropout
            probability, hence the ``[0, 1]`` range check.
        **kwargs: Passed through to ``PretrainedConfig.__init__``.

    Raises:
        ValueError: If ``input_dim``, ``hidden_dim`` or ``num_classes`` is
            smaller than 1, or if ``dropout_rate`` is outside ``[0, 1]``.
    """

    model_type = "simple_mlp"

    def __init__(
        self,
        input_dim: int = 768,
        hidden_dim: int = 256,
        num_classes: int = 2,
        dropout_rate: float = 0.1,
        **kwargs,
    ):
        # Fail fast on impossible hyperparameters so the mistake is reported
        # where the config is created, not later when it is consumed.
        sizes = {
            "input_dim": input_dim,
            "hidden_dim": hidden_dim,
            "num_classes": num_classes,
        }
        for name, value in sizes.items():
            if value < 1:
                raise ValueError(
                    f"`{name}` must be a positive integer, got {value!r}."
                )
        if not 0.0 <= dropout_rate <= 1.0:
            raise ValueError(
                f"`dropout_rate` must be in [0, 1], got {dropout_rate!r}."
            )

        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.num_classes = num_classes
        self.dropout_rate = dropout_rate
        # The base initializer runs after the custom attributes are set,
        # exactly as in the original ordering.
        super().__init__(**kwargs)