# File size: 477 bytes (revision ed52679)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
from transformers import PretrainedConfig

class SimpleMLPConfig(PretrainedConfig):
    """Configuration holder for the ``simple_mlp`` model type.

    Stores four hyperparameters as instance attributes and forwards every
    other keyword argument to ``PretrainedConfig.__init__`` unchanged.

    Args:
        input_dim: Saved as ``self.input_dim`` (default 768). Presumably the
            width of the input feature vector; confirm against the model.
        hidden_dim: Saved as ``self.hidden_dim`` (default 256). Presumably
            the width of the hidden layer; confirm against the model.
        num_classes: Saved as ``self.num_classes`` (default 2). Presumably
            the number of output classes; confirm against the model.
        dropout_rate: Saved as ``self.dropout_rate`` (default 0.1).
        **kwargs: Passed through verbatim to the parent constructor.
    """

    # Identifier string associated with this configuration class.
    model_type = "simple_mlp"

    def __init__(
        self,
        input_dim: int = 768,
        hidden_dim: int = 256,
        num_classes: int = 2,
        dropout_rate: float = 0.1,
        **kwargs,
    ) -> None:
        # Record the model-specific values first; the parent initializer is
        # invoked last and receives only the remaining keyword arguments.
        self.input_dim, self.hidden_dim = input_dim, hidden_dim
        self.num_classes, self.dropout_rate = num_classes, dropout_rate
        super().__init__(**kwargs)