semasg commited on
Commit
d73d3e5
·
verified ·
1 Parent(s): 7d82d7d

Upload custom_model.py

Browse files
Files changed (1) hide show
  1. custom_model.py +33 -0
custom_model.py ADDED
@@ -0,0 +1,33 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ from transformers import PreTrainedModel, PretrainedConfig
4
+
5
+ # Define the model configuration
6
+ class SimpleNNConfig(PretrainedConfig):
7
+ model_type = "simple_nn"
8
+
9
+ def __init__(self, hidden_size=16, num_labels=1, **kwargs):
10
+ super().__init__(**kwargs)
11
+ self.hidden_size = hidden_size
12
+ self.num_labels = num_labels
13
+
14
+ # Define the model architecture
15
+ class SimpleNN(PreTrainedModel):
16
+ config_class = SimpleNNConfig
17
+
18
+ def __init__(self, config):
19
+ super().__init__(config)
20
+ self.fc1 = nn.Linear(1, config.hidden_size)
21
+ self.fc2 = nn.Linear(config.hidden_size, config.num_labels)
22
+
23
+ def forward(self, x):
24
+ x = torch.relu(self.fc1(x))
25
+ x = self.fc2(x)
26
+ return x
27
+
28
+ @classmethod
29
+ def from_pretrained(cls, pretrained_model_name_or_path, *args, **kwargs):
30
+ config = SimpleNNConfig()
31
+ model = cls(config)
32
+ model.load_state_dict(torch.load(pretrained_model_name_or_path, map_location=torch.device("cpu")))
33
+ return model