|
|
import torch |
|
|
import torch.nn as nn |
|
|
from transformers import PreTrainedModel, PretrainedConfig |
|
|
|
|
|
|
|
|
class SimpleNNConfig(PretrainedConfig):
    """Hyper-parameters for :class:`SimpleNN`.

    Args:
        hidden_size: Number of units in the single hidden layer.
        num_labels: Number of output features of the final linear layer.
        **kwargs: Passed through unchanged to ``PretrainedConfig.__init__``.
    """

    model_type = "simple_nn"

    def __init__(self, hidden_size=16, num_labels=1, **kwargs):
        # Let the base class populate its generic fields first; the model-specific
        # values are assigned afterwards so the base-class defaults (e.g. its
        # default label mapping) do not overwrite ``num_labels``.
        super().__init__(**kwargs)
        self.hidden_size, self.num_labels = hidden_size, num_labels
|
|
|
|
|
|
|
|
class SimpleNN(PreTrainedModel):
    """Two-layer MLP (1 -> hidden_size -> num_labels) exposed as a Hugging Face model.

    One input feature is passed through a ReLU hidden layer of width
    ``config.hidden_size`` and then projected to ``config.num_labels`` outputs.
    """

    config_class = SimpleNNConfig

    def __init__(self, config):
        """Create the layers described by *config*.

        Args:
            config: Configuration whose ``hidden_size`` and ``num_labels``
                determine the layer widths.
        """
        super().__init__(config)
        self.fc1 = nn.Linear(1, config.hidden_size)
        self.fc2 = nn.Linear(config.hidden_size, config.num_labels)

    def forward(self, x):
        """Apply the network to *x*.

        Args:
            x: Tensor whose last dimension has size 1, as required by ``fc1``
                (``nn.Linear(1, hidden_size)``).

        Returns:
            Tensor with the same leading dimensions as *x* and a last dimension
            of size ``config.num_labels``. These are raw outputs; no final
            activation is applied.
        """
        x = torch.relu(self.fc1(x))
        return self.fc2(x)

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, *args, **kwargs):
        """Build a model and load weights from a ``torch.save``-d ``state_dict``.

        This replaces ``PreTrainedModel.from_pretrained`` with a minimal loader.
        *pretrained_model_name_or_path* must be something ``torch.load`` accepts
        (a file path or file-like object). It is not a Hub id or a
        ``save_pretrained`` directory.

        Args:
            pretrained_model_name_or_path: Location of the serialized ``state_dict``.
            *args: Accepted for signature compatibility; ignored.
            **kwargs: Recognized keys:
                ``config``: a ``PretrainedConfig`` instance to build the model
                with. It must match the checkpoint's layer sizes. When it is
                absent or not a config instance, ``cls.config_class()``
                (default hyper-parameters) is used, as before.
                ``map_location``: forwarded to ``torch.load``; defaults to CPU.
                All other keys are ignored.

        Returns:
            A new model instance with the loaded weights.

        Raises:
            RuntimeError: Raised by ``load_state_dict`` when the checkpoint's
                keys or tensor shapes do not match the model built from the config.
        """
        config = kwargs.pop("config", None)
        if not isinstance(config, PretrainedConfig):
            # Without an explicit config, fall back to the default
            # hyper-parameters. Checkpoints trained with a non-default
            # hidden_size/num_labels need ``config=`` to load.
            config = cls.config_class()
        map_location = kwargs.pop("map_location", torch.device("cpu"))

        model = cls(config)
        # NOTE(security): torch.load unpickles the file and can execute arbitrary
        # code. Only load checkpoints from trusted sources, or consider
        # ``weights_only=True`` on torch versions that support it.
        state_dict = torch.load(pretrained_model_name_or_path, map_location=map_location)
        model.load_state_dict(state_dict)
        return model
|
|
|