add model
Browse files- config.json +1 -1
- model.py +1 -16
config.json
CHANGED
|
@@ -4,7 +4,7 @@
|
|
| 4 |
],
|
| 5 |
"auto_map": {
|
| 6 |
"AutoConfig": "config.CNNConfig",
|
| 7 |
-
"AutoModel": "
|
| 8 |
},
|
| 9 |
"num_classes": 10,
|
| 10 |
"torch_dtype": "float32",
|
|
|
|
| 4 |
],
|
| 5 |
"auto_map": {
|
| 6 |
"AutoConfig": "config.CNNConfig",
|
| 7 |
+
"AutoModel": "model.CNNModel"
|
| 8 |
},
|
| 9 |
"num_classes": 10,
|
| 10 |
"torch_dtype": "float32",
|
model.py
CHANGED
|
@@ -1,4 +1,4 @@
|
|
| 1 |
-
from config import CNNConfig
|
| 2 |
from transformers import PreTrainedModel
|
| 3 |
|
| 4 |
|
|
@@ -71,20 +71,5 @@ class CNNModel(PreTrainedModel):
|
|
| 71 |
def forward(self, tensor):
|
| 72 |
return self.model(tensor)
|
| 73 |
|
| 74 |
-
if __name__ == "__main__":
|
| 75 |
-
weights = "k_cifar10/2wye5gyi/checkpoints/epoch=599-step=187800.ckpt"
|
| 76 |
-
|
| 77 |
-
CNNConfig.register_for_auto_class()
|
| 78 |
-
CNNModel.register_for_auto_class("AutoModel")
|
| 79 |
-
|
| 80 |
-
pretrained = CNN.load_from_checkpoint(weights)
|
| 81 |
-
|
| 82 |
-
config = CNNConfig.from_pretrained("custom_cnn_10")
|
| 83 |
-
my_cnn = CNNModel(config)
|
| 84 |
-
my_cnn.model.load_state_dict(pretrained.state_dict())
|
| 85 |
-
|
| 86 |
-
my_cnn.push_to_hub("k_cnn_cifar10")
|
| 87 |
-
|
| 88 |
-
|
| 89 |
|
| 90 |
|
|
|
|
| 1 |
+
from .config import CNNConfig
|
| 2 |
from transformers import PreTrainedModel
|
| 3 |
|
| 4 |
|
|
|
|
| 71 |
def forward(self, tensor):
|
| 72 |
return self.model(tensor)
|
| 73 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 74 |
|
| 75 |
|