zhihanyang commited on
Commit
31e551d
·
verified ·
1 Parent(s): 53756d1

Upload MLP

Browse files
Files changed (4) hide show
  1. config.json +16 -0
  2. config.py +16 -0
  3. model.py +32 -0
  4. model.safetensors +3 -0
config.json ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "_name_or_path": "zhihanyang/mlp",
3
+ "architectures": [
4
+ "MLP"
5
+ ],
6
+ "auto_map": {
7
+ "AutoConfig": "config.MLPConfig",
8
+ "AutoModelForMaskedLM": "model.MLP"
9
+ },
10
+ "hidden_dim": 32,
11
+ "input_dim": 16,
12
+ "model_type": "MLP",
13
+ "output_dim": 16,
14
+ "torch_dtype": "float32",
15
+ "transformers_version": "4.49.0"
16
+ }
config.py ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import transformers
2
+
3
+
4
+ class MLPConfig(transformers.PretrainedConfig):
5
+ model_type = 'MLP'
6
+
7
+ def __init__(
8
+ self,
9
+ input_dim: int = 16,
10
+ hidden_dim: int = 32,
11
+ output_dim: int = 16,
12
+ **kwargs):
13
+ super().__init__(**kwargs)
14
+ self.input_dim = input_dim
15
+ self.hidden_dim = hidden_dim
16
+ self.output_dim = output_dim
model.py ADDED
@@ -0,0 +1,32 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import typing
4
+ import transformers
5
+
6
+ from .config import MLPConfig
7
+
8
+
9
+ class Backbone(nn.Module):
10
+ def __init__(self, config):
11
+ super().__init__()
12
+ self.model = nn.Sequential(
13
+ nn.Linear(config.input_dim, config.hidden_dim),
14
+ nn.ReLU(),
15
+ nn.Linear(config.hidden_dim, config.hidden_dim),
16
+ nn.ReLU(),
17
+ nn.Linear(config.hidden_dim, config.output_dim)
18
+ )
19
+
20
+ def forward(self, x):
21
+ return self.model(x)
22
+
23
+
24
+ class MLP(transformers.PreTrainedModel):
25
+ """HF-compatible model."""
26
+ config_class = MLPConfig
27
+ base_model_prefix = 'mlp'
28
+
29
+ def __init__(self, config: MLPConfig):
30
+ super().__init__(config)
31
+ self.config = config
32
+ self.backbone = Backbone(config)
model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:aba70e689e01bd5aec5926f6dbf032ae12b984380fa47684108b607047a9ef44
3
+ size 9048