lprimeau commited on
Commit
cbabe4d
·
verified ·
1 Parent(s): a13ea26

Upload BasicLinear

Browse files
Files changed (4) hide show
  1. config.json +13 -3
  2. custom_config.py +18 -0
  3. custom_net.py +21 -0
  4. model.safetensors +2 -2
config.json CHANGED
@@ -1,5 +1,15 @@
1
  {
 
 
 
 
 
 
 
2
  "bias": true,
3
- "in_features": 3,
4
- "out_features": 32
5
- }
 
 
 
 
1
  {
2
+ "architectures": [
3
+ "BasicLinear"
4
+ ],
5
+ "auto_map": {
6
+ "AutoConfig": "custom_config.LinearConfig",
7
+ "AutoModel": "custom_net.BasicLinear"
8
+ },
9
  "bias": true,
10
+ "dtype": "float32",
11
+ "in_features": 10,
12
+ "model_type": "linear",
13
+ "out_features": 1,
14
+ "transformers_version": "4.57.1"
15
+ }
custom_config.py ADDED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import PretrainedConfig
2
+
3
+ class LinearConfig(PretrainedConfig):
4
+ model_type = 'linear'
5
+
6
+ def __init__(self,
7
+ in_features=10,
8
+ out_features=1,
9
+ bias=True,
10
+ **kwargs):
11
+
12
+ self.in_features = in_features
13
+ self.out_features = out_features
14
+ self.bias = bias
15
+
16
+ super().__init__(**kwargs)
17
+
18
+
custom_net.py ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import PreTrainedModel
2
+ from .custom_config import LinearConfig
3
+ import torch.nn as nn
4
+ import torch
5
+
6
+ class BasicLinear(PreTrainedModel):
7
+ config_class = LinearConfig
8
+
9
+ def __init__(self, config):
10
+ super().__init__(config)
11
+ self.weight = nn.Parameter(torch.randn(config.out_features, config.in_features) * 0.01)
12
+ if config.bias:
13
+ self.bias = nn.Parameter(torch.zeros(config.out_features))
14
+ else:
15
+ self.bias = None
16
+
17
+ def forward(self, x):
18
+ out = x @ self.weight.T
19
+ if self.bias is not None:
20
+ out = out + self.bias
21
+ return out
model.safetensors CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:fc7f778d43fa531ede820babf86b9d90ed089361ba9992e09b60265741ab0035
3
- size 648
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:a3638ec69400cb6873d0c92abed84a49010d46d848f98857945787e072c17fc2
3
+ size 204