| from transformers import PreTrainedModel, PretrainedConfig | |
| from torch import nn | |
class TestConfig(PretrainedConfig):
    """Configuration for ``TestModel``: holds the input and output widths.

    Args:
        input_dim: Number of input features (default 4).
        output_dim: Number of output features (default 16).
        **kwargs: Forwarded unchanged to ``PretrainedConfig.__init__``.
    """

    model_type = "test-model"

    def __init__(self, input_dim=4, output_dim=16, **kwargs):
        # Custom fields are stored before delegating to the base class,
        # keeping the same ordering as the base-class call below expects.
        self.input_dim, self.output_dim = input_dim, output_dim
        super().__init__(**kwargs)
class TestModel(PreTrainedModel):
    """Minimal ``PreTrainedModel`` consisting of a single linear layer.

    The layer maps ``config.input_dim`` features to ``config.output_dim``
    features.
    """

    config_class = TestConfig

    def __init__(self, config):
        super().__init__(config)
        # Layer sizes are read from the config object.
        in_features = config.input_dim
        out_features = config.output_dim
        self.layer = nn.Linear(in_features, out_features)

    def forward(self, input):  # noqa: A002 -- name kept for keyword callers
        """Apply the linear layer to ``input`` and return its output."""
        projected = self.layer(input)
        return projected