Commit
·
0bd2cee
1
Parent(s):
ae0c2b6
new version
Browse files- config.json +1 -1
- config.py +1 -1
- model.py +2 -3
config.json
CHANGED
|
@@ -3,7 +3,7 @@
|
|
| 3 |
"MNIST_Classifier"
|
| 4 |
],
|
| 5 |
"auto_map": {
|
| 6 |
-
"AutoConfig": "config.
|
| 7 |
"AutoModel": "model.MNIST_Classifier"
|
| 8 |
},
|
| 9 |
"hidden_size1": 128,
|
|
|
|
| 3 |
"MNIST_Classifier"
|
| 4 |
],
|
| 5 |
"auto_map": {
|
| 6 |
+
"AutoConfig": "config.MNIST_config",
|
| 7 |
"AutoModel": "model.MNIST_Classifier"
|
| 8 |
},
|
| 9 |
"hidden_size1": 128,
|
config.py
CHANGED
|
@@ -1,5 +1,5 @@
|
|
| 1 |
from transformers import PretrainedConfig
|
| 2 |
-
class
|
| 3 |
model_type = "MNIST_Classifier"
|
| 4 |
def __init__(self, **kwargs):
|
| 5 |
super().__init__(**kwargs)
|
|
|
|
| 1 |
from transformers import PretrainedConfig
|
| 2 |
+
class MNIST_config(PretrainedConfig):
|
| 3 |
model_type = "MNIST_Classifier"
|
| 4 |
def __init__(self, **kwargs):
|
| 5 |
super().__init__(**kwargs)
|
model.py
CHANGED
|
@@ -1,13 +1,12 @@
|
|
| 1 |
###import os,sys
|
| 2 |
###sys.path.insert(1,os.path.join(sys.path[0],".."))
|
| 3 |
from network import Net
|
| 4 |
-
from config import
|
| 5 |
from transformers import PreTrainedModel
|
| 6 |
# utils not used but importing it forces the upload to huggingface hub to include it
|
| 7 |
|
| 8 |
-
|
| 9 |
class MNIST_Classifier(PreTrainedModel):
|
| 10 |
-
config_class =
|
| 11 |
def __init__(self, config):
|
| 12 |
super().__init__(config)
|
| 13 |
self.classifier=Net(config.input_size,config.hidden_size1,config.hidden_size2,
|
|
|
|
| 1 |
###import os,sys
|
| 2 |
###sys.path.insert(1,os.path.join(sys.path[0],".."))
|
| 3 |
from network import Net
|
| 4 |
+
from config import MNIST_config
|
| 5 |
from transformers import PreTrainedModel
|
| 6 |
# utils not used but importing it forces the upload to huggingface hub to include it
|
| 7 |
|
|
|
|
| 8 |
class MNIST_Classifier(PreTrainedModel):
|
| 9 |
+
config_class = MNIST_config
|
| 10 |
def __init__(self, config):
|
| 11 |
super().__init__(config)
|
| 12 |
self.classifier=Net(config.input_size,config.hidden_size1,config.hidden_size2,
|