"""Minimal custom ``transformers`` config/model pair: one linear projection."""

from torch import nn
from transformers import PretrainedConfig, PreTrainedModel


class TestConfig(PretrainedConfig):
    """Configuration holding the input and output widths of :class:`TestModel`.

    Args:
        input_dim: Number of input features fed to the linear layer.
        output_dim: Number of output features produced by the linear layer.
        **kwargs: Forwarded unchanged to ``PretrainedConfig.__init__``.
    """

    model_type = "test-model"

    def __init__(self, input_dim=4, output_dim=16, **kwargs):
        # Store the custom fields first, then let the base class consume
        # the remaining generic config options.
        self.input_dim = input_dim
        self.output_dim = output_dim
        super().__init__(**kwargs)


class TestModel(PreTrainedModel):
    """Model consisting of a single ``nn.Linear`` projection.

    The projection is kept under the attribute name ``layer``; that name
    determines the parameter keys (``layer.weight`` / ``layer.bias``) seen in
    the module's state dict.
    """

    config_class = TestConfig

    def __init__(self, config):
        super().__init__(config)
        in_features, out_features = config.input_dim, config.output_dim
        self.layer = nn.Linear(in_features, out_features)

    def forward(self, input):
        """Apply the linear projection to ``input`` and return the result.

        The parameter is named ``input`` (shadowing the builtin) to stay
        compatible with keyword callers.
        """
        projected = self.layer(input)
        return projected