Sushil Singh commited on
Commit ·
7a09e94
1
Parent(s): 77f799f
changed name from simple-linear-model to simple_linear_model
Browse files- README.md +1 -1
- config.json +2 -2
- model.py +6 -6
README.md
CHANGED
|
@@ -52,7 +52,7 @@ from simple_linear_model.model import *
|
|
| 52 |
from transformers import AutoConfig, AutoModel
|
| 53 |
|
| 54 |
# Load the model
|
| 55 |
-
model = AutoModel.from_pretrained("sushilks/
|
| 56 |
|
| 57 |
# Forward pass
|
| 58 |
import torch
|
|
|
|
| 52 |
from transformers import AutoConfig, AutoModel
|
| 53 |
|
| 54 |
# Load the model
|
| 55 |
+
model = AutoModel.from_pretrained("sushilks/simple_linear_model")
|
| 56 |
|
| 57 |
# Forward pass
|
| 58 |
import torch
|
config.json
CHANGED
|
@@ -1,10 +1,10 @@
|
|
| 1 |
{
|
| 2 |
"architectures": [
|
| 3 |
-
"
|
| 4 |
],
|
| 5 |
"input_dim": 768,
|
| 6 |
"input_size": 768,
|
| 7 |
-
"model_type": "
|
| 8 |
"output_dim": 512,
|
| 9 |
"output_size": 512,
|
| 10 |
"torch_dtype": "float32",
|
|
|
|
| 1 |
{
|
| 2 |
"architectures": [
|
| 3 |
+
"SimpleLinearModel"
|
| 4 |
],
|
| 5 |
"input_dim": 768,
|
| 6 |
"input_size": 768,
|
| 7 |
+
"model_type": "simple_linear_model",
|
| 8 |
"output_dim": 512,
|
| 9 |
"output_size": 512,
|
| 10 |
"torch_dtype": "float32",
|
model.py
CHANGED
|
@@ -4,16 +4,16 @@ from transformers import PreTrainedModel
|
|
| 4 |
from transformers import AutoConfig, AutoModel, PretrainedConfig
|
| 5 |
|
| 6 |
class SimpleLinearConfig(PretrainedConfig):
|
| 7 |
-
model_type = "
|
| 8 |
-
|
| 9 |
def __init__(self, input_dim=768, output_dim=512, **kwargs):
|
| 10 |
super().__init__(**kwargs)
|
| 11 |
self.input_dim = input_dim
|
| 12 |
self.output_dim = output_dim
|
| 13 |
|
| 14 |
-
class
|
| 15 |
config_class = SimpleLinearConfig
|
| 16 |
-
|
| 17 |
def __init__(self, config: SimpleLinearConfig):
|
| 18 |
super().__init__(config)
|
| 19 |
self.linear = nn.Linear(config.input_dim, config.output_dim)
|
|
@@ -32,7 +32,7 @@ class SimpleLinearPreTrainedModel(PreTrainedModel):
|
|
| 32 |
nn.init.zeros_(param)
|
| 33 |
|
| 34 |
# Register our config class with AutoConfig
|
| 35 |
-
AutoConfig.register("
|
| 36 |
|
| 37 |
# Register our model class with AutoModel
|
| 38 |
-
AutoModel.register(SimpleLinearConfig,
|
|
|
|
| 4 |
from transformers import AutoConfig, AutoModel, PretrainedConfig
|
| 5 |
|
| 6 |
class SimpleLinearConfig(PretrainedConfig):
|
| 7 |
+
model_type = "simple_linear_model"
|
| 8 |
+
_no_split_modules = ["linear"]
|
| 9 |
def __init__(self, input_dim=768, output_dim=512, **kwargs):
|
| 10 |
super().__init__(**kwargs)
|
| 11 |
self.input_dim = input_dim
|
| 12 |
self.output_dim = output_dim
|
| 13 |
|
| 14 |
+
class SimpleLinearModel(PreTrainedModel):
|
| 15 |
config_class = SimpleLinearConfig
|
| 16 |
+
_no_split_modules = []
|
| 17 |
def __init__(self, config: SimpleLinearConfig):
|
| 18 |
super().__init__(config)
|
| 19 |
self.linear = nn.Linear(config.input_dim, config.output_dim)
|
|
|
|
| 32 |
nn.init.zeros_(param)
|
| 33 |
|
| 34 |
# Register our config class with AutoConfig
|
| 35 |
+
AutoConfig.register("simple_linear_model", SimpleLinearConfig)
|
| 36 |
|
| 37 |
# Register our model class with AutoModel
|
| 38 |
+
AutoModel.register(SimpleLinearConfig, SimpleLinearModel)
|