Sushil Singh commited on
Commit
7a09e94
·
1 Parent(s): 77f799f

changed name from simple-linear-model to simple_linear_model

Browse files
Files changed (3) hide show
  1. README.md +1 -1
  2. config.json +2 -2
  3. model.py +6 -6
README.md CHANGED
@@ -52,7 +52,7 @@ from simple_linear_model.model import *
52
  from transformers import AutoConfig, AutoModel
53
 
54
  # Load the model
55
- model = AutoModel.from_pretrained("sushilks/simple-linear-model")
56
 
57
  # Forward pass
58
  import torch
 
52
  from transformers import AutoConfig, AutoModel
53
 
54
  # Load the model
55
+ model = AutoModel.from_pretrained("sushilks/simple_linear_model")
56
 
57
  # Forward pass
58
  import torch
config.json CHANGED
@@ -1,10 +1,10 @@
1
  {
2
  "architectures": [
3
- "SimpleLinearPreTrainedModel"
4
  ],
5
  "input_dim": 768,
6
  "input_size": 768,
7
- "model_type": "simple-linear-model",
8
  "output_dim": 512,
9
  "output_size": 512,
10
  "torch_dtype": "float32",
 
1
  {
2
  "architectures": [
3
+ "SimpleLinearModel"
4
  ],
5
  "input_dim": 768,
6
  "input_size": 768,
7
+ "model_type": "simple_linear_model",
8
  "output_dim": 512,
9
  "output_size": 512,
10
  "torch_dtype": "float32",
model.py CHANGED
@@ -4,16 +4,16 @@ from transformers import PreTrainedModel
4
  from transformers import AutoConfig, AutoModel, PretrainedConfig
5
 
6
  class SimpleLinearConfig(PretrainedConfig):
7
- model_type = "simple-linear-model"
8
-
9
  def __init__(self, input_dim=768, output_dim=512, **kwargs):
10
  super().__init__(**kwargs)
11
  self.input_dim = input_dim
12
  self.output_dim = output_dim
13
 
14
- class SimpleLinearPreTrainedModel(PreTrainedModel):
15
  config_class = SimpleLinearConfig
16
-
17
  def __init__(self, config: SimpleLinearConfig):
18
  super().__init__(config)
19
  self.linear = nn.Linear(config.input_dim, config.output_dim)
@@ -32,7 +32,7 @@ class SimpleLinearPreTrainedModel(PreTrainedModel):
32
  nn.init.zeros_(param)
33
 
34
  # Register our config class with AutoConfig
35
- AutoConfig.register("simple-linear-model", SimpleLinearConfig)
36
 
37
  # Register our model class with AutoModel
38
- AutoModel.register(SimpleLinearConfig, SimpleLinearPreTrainedModel)
 
4
  from transformers import AutoConfig, AutoModel, PretrainedConfig
5
 
6
  class SimpleLinearConfig(PretrainedConfig):
7
+ model_type = "simple_linear_model"
8
+ _no_split_modules = ["linear"]
9
  def __init__(self, input_dim=768, output_dim=512, **kwargs):
10
  super().__init__(**kwargs)
11
  self.input_dim = input_dim
12
  self.output_dim = output_dim
13
 
14
+ class SimpleLinearModel(PreTrainedModel):
15
  config_class = SimpleLinearConfig
16
+ _no_split_modules = []
17
  def __init__(self, config: SimpleLinearConfig):
18
  super().__init__(config)
19
  self.linear = nn.Linear(config.input_dim, config.output_dim)
 
32
  nn.init.zeros_(param)
33
 
34
  # Register our config class with AutoConfig
35
+ AutoConfig.register("simple_linear_model", SimpleLinearConfig)
36
 
37
  # Register our model class with AutoModel
38
+ AutoModel.register(SimpleLinearConfig, SimpleLinearModel)