Sushil Singh commited on
Commit
0ab8f3a
·
1 Parent(s): b2fc459

added model file

Browse files
Files changed (1) hide show
  1. model.py +38 -0
model.py ADDED
@@ -0,0 +1,38 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ from transformers import PreTrainedModel
4
+ from transformers import AutoConfig, AutoModel, PretrainedConfig
5
+
6
+ class SimpleLinearConfig(PretrainedConfig):
7
+ model_type = "simple-linear-model"
8
+
9
+ def __init__(self, input_dim=768, output_dim=512, **kwargs):
10
+ super().__init__(**kwargs)
11
+ self.input_dim = input_dim
12
+ self.output_dim = output_dim
13
+
14
+ class SimpleLinearPreTrainedModel(PreTrainedModel):
15
+ config_class = SimpleLinearConfig
16
+
17
+ def __init__(self, config: SimpleLinearConfig):
18
+ super().__init__(config)
19
+ self.linear = nn.Linear(config.input_dim, config.output_dim)
20
+ self.post_init() # This calls init_weights internally
21
+
22
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
23
+ return self.linear(x)
24
+
25
+ def init_weights(self):
26
+ # Standard weight init
27
+ for name, param in self.named_parameters():
28
+ if param.requires_grad:
29
+ if "weight" in name:
30
+ nn.init.xavier_uniform_(param)
31
+ elif "bias" in name:
32
+ nn.init.zeros_(param)
33
+
34
+ # Register our config class with AutoConfig
35
+ AutoConfig.register("simple-linear-model", SimpleLinearConfig)
36
+
37
+ # Register our model class with AutoModel
38
+ AutoModel.register(SimpleLinearConfig, SimpleLinearPreTrainedModel)