File size: 627 Bytes
fc0237e |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 |
from transformers import PreTrainedModel, PretrainedConfig
from torch import nn
class TestConfig(PretrainedConfig):
    """Configuration for ``TestModel``: stores the linear layer's dimensions."""

    # Identifier string for this configuration class.
    model_type = "test-model"

    def __init__(self, input_dim: int = 4, output_dim: int = 16, **kwargs) -> None:
        """Record the layer dimensions and forward the rest to the base class.

        Args:
            input_dim: Stored as ``self.input_dim``.
            output_dim: Stored as ``self.output_dim``.
            **kwargs: Passed through unchanged to ``PretrainedConfig.__init__``.
        """
        # Both fields are assigned before the base initializer runs,
        # matching the original statement order.
        self.input_dim = input_dim
        self.output_dim = output_dim
        super().__init__(**kwargs)
class TestModel(PreTrainedModel):
    """Minimal model made of a single ``nn.Linear`` layer sized by its config."""

    config_class = TestConfig

    def __init__(self, config):
        """Build the linear layer from ``config.input_dim`` and ``config.output_dim``."""
        super().__init__(config)
        in_features = config.input_dim
        out_features = config.output_dim
        self.layer = nn.Linear(in_features, out_features)

    def forward(self, input):
        """Apply the linear layer to ``input`` and return the result.

        ``input`` is presumably a tensor whose last dimension equals
        ``config.input_dim`` -- TODO confirm against callers.
        """
        # The parameter name shadows the builtin ``input``; it is kept as-is
        # so that callers passing it by keyword continue to work.
        projected = self.layer(input)
        return projected
|