test-model / model.py
copper-light's picture
Upload model
fc0237e verified
from transformers import PreTrainedModel, PretrainedConfig
from torch import nn
class TestConfig(PretrainedConfig):
    """Configuration for ``TestModel``.

    Stores the feature sizes used to build the model's single linear layer.
    Any extra keyword arguments are forwarded unchanged to
    ``PretrainedConfig.__init__``.

    Args:
        input_dim: Number of input features (``in_features`` of the layer).
        output_dim: Number of output features (``out_features`` of the layer).
    """

    model_type = "test-model"

    def __init__(self, input_dim: int = 4, output_dim: int = 16, **kwargs) -> None:
        # Model-specific fields are set before delegating to the base class,
        # matching the subclassing pattern shown in the transformers docs.
        self.input_dim, self.output_dim = input_dim, output_dim
        super().__init__(**kwargs)
class TestModel(PreTrainedModel):
    """Minimal model made of one fully connected layer.

    The layer maps ``config.input_dim`` features to ``config.output_dim``
    features.
    """

    config_class = TestConfig

    def __init__(self, config):
        super().__init__(config)
        in_width, out_width = config.input_dim, config.output_dim
        self.layer = nn.Linear(in_features=in_width, out_features=out_width)

    def forward(self, input):  # noqa: A002 - name kept so `input=` keyword callers still work
        """Apply the linear layer to ``input``.

        ``nn.Linear`` operates on the last dimension, so ``input`` must have a
        trailing size of ``config.input_dim``. The result has the same leading
        dimensions and a trailing size of ``config.output_dim``.
        """
        projection = self.layer
        return projection(input)