File size: 627 Bytes
fc0237e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
from transformers import PreTrainedModel, PretrainedConfig
from torch import nn

class TestConfig(PretrainedConfig):
    """Configuration holding the layer sizes for ``TestModel``.

    Args:
        input_dim: Number of input features; ``TestModel`` uses it as the
            ``in_features`` of its ``nn.Linear`` layer.
        output_dim: Number of output features; ``TestModel`` uses it as the
            ``out_features`` of its ``nn.Linear`` layer.
        **kwargs: Any remaining options, forwarded unchanged to
            ``PretrainedConfig.__init__``.
    """

    # Identifier transformers associates with this configuration class.
    model_type = "test-model"

    def __init__(self, input_dim: int = 4, output_dim: int = 16, **kwargs):
        # Custom fields are stored before delegating to the base class,
        # mirroring the documented PretrainedConfig subclassing pattern.
        self.input_dim, self.output_dim = input_dim, output_dim
        super().__init__(**kwargs)
    
    
class TestModel(PreTrainedModel):
    """A single linear projection exposed as a ``PreTrainedModel``.

    The layer maps ``config.input_dim`` features to ``config.output_dim``
    features, both read from the ``TestConfig`` passed at construction.
    """

    config_class = TestConfig

    def __init__(self, config):
        super().__init__(config)
        in_features = config.input_dim
        out_features = config.output_dim
        self.layer = nn.Linear(in_features, out_features)

    def forward(self, input):
        """Project ``input`` through the linear layer and return the result.

        ``nn.Linear`` acts on the last dimension, so ``input`` is expected to
        have a trailing size of ``config.input_dim``.
        """
        # The parameter keeps the name ``input`` (shadowing the builtin) so
        # that callers passing it by keyword keep working.
        projection = self.layer
        return projection(input)