import torch
import torch.nn as nn
from transformers import PreTrainedModel, PretrainedConfig


# Define the model configuration
class SimpleNNConfig(PretrainedConfig):
    """Configuration for :class:`SimpleNN`.

    Args:
        hidden_size: Width of the single hidden layer.
        num_labels: Number of output features produced by the final layer.
        **kwargs: Forwarded unchanged to ``PretrainedConfig.__init__``.
    """

    model_type = "simple_nn"

    def __init__(self, hidden_size: int = 16, num_labels: int = 1, **kwargs):
        super().__init__(**kwargs)
        # Assigned after the base initializer, as in the original code, so
        # these explicit values are the ones that end up on the instance.
        self.hidden_size = hidden_size
        self.num_labels = num_labels


# Define the model architecture
class SimpleNN(PreTrainedModel):
    """Two-layer perceptron: ``Linear(1, hidden) -> ReLU -> Linear(hidden, labels)``."""

    config_class = SimpleNNConfig

    def __init__(self, config: SimpleNNConfig):
        super().__init__(config)
        # The input layer is fixed at one feature; only the hidden and output
        # widths come from the configuration.
        self.fc1 = nn.Linear(1, config.hidden_size)
        self.fc2 = nn.Linear(config.hidden_size, config.num_labels)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run the network on ``x`` and return the raw output of ``fc2``.

        ``fc1`` is ``nn.Linear(1, ...)``, so the last dimension of ``x`` must
        have size 1. No activation is applied to the output.
        """
        x = torch.relu(self.fc1(x))
        x = self.fc2(x)
        return x

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, *args, **kwargs):
        """Build a model and load weights from a ``state_dict`` file.

        This override passes ``pretrained_model_name_or_path`` straight to
        ``torch.load`` and feeds the result to ``load_state_dict``; it does
        not read any configuration from that path.

        Args:
            pretrained_model_name_or_path: Path (or file-like object) accepted
                by ``torch.load`` that contains the model's ``state_dict``.
            *args: Accepted for signature compatibility; ignored.
            **kwargs: Only ``config`` is consulted. When it is an instance of
                ``cls.config_class`` the model is built from it, so that
                checkpoints saved with non-default layer sizes can be loaded.
                Anything else is ignored and a default ``SimpleNNConfig()``
                is used.

        Returns:
            The model with the loaded weights. Tensors are mapped to the CPU.

        Raises:
            RuntimeError: From ``load_state_dict`` when the checkpoint does
                not match the architecture described by the configuration.
        """
        config = kwargs.get("config")
        if not isinstance(config, cls.config_class):
            # No usable config supplied: keep the historical behaviour of
            # building the model with default hyper-parameters.
            config = SimpleNNConfig()

        model = cls(config)

        # NOTE(security): torch.load unpickles the file, which can execute
        # arbitrary code. Only load checkpoints from trusted sources, or pass
        # weights_only=True on PyTorch versions that support it.
        state_dict = torch.load(
            pretrained_model_name_or_path,
            map_location=torch.device("cpu"),
        )
        model.load_state_dict(state_dict)
        return model