luantber commited on
Commit
71b5fcd
·
1 Parent(s): b69b23b

add model

Browse files
Files changed (2) hide show
  1. config.json +1 -1
  2. model.py +1 -16
config.json CHANGED
@@ -4,7 +4,7 @@
4
  ],
5
  "auto_map": {
6
  "AutoConfig": "config.CNNConfig",
7
- "AutoModel": "__main__.CNNModel"
8
  },
9
  "num_classes": 10,
10
  "torch_dtype": "float32",
 
4
  ],
5
  "auto_map": {
6
  "AutoConfig": "config.CNNConfig",
7
+ "AutoModel": "model.CNNModel"
8
  },
9
  "num_classes": 10,
10
  "torch_dtype": "float32",
model.py CHANGED
@@ -1,4 +1,4 @@
1
- from config import CNNConfig
2
  from transformers import PreTrainedModel
3
 
4
 
@@ -71,20 +71,5 @@ class CNNModel(PreTrainedModel):
71
  def forward(self, tensor):
72
  return self.model(tensor)
73
 
74
- if __name__ == "__main__":
75
- weights = "k_cifar10/2wye5gyi/checkpoints/epoch=599-step=187800.ckpt"
76
-
77
- CNNConfig.register_for_auto_class()
78
- CNNModel.register_for_auto_class("AutoModel")
79
-
80
- pretrained = CNN.load_from_checkpoint(weights)
81
-
82
- config = CNNConfig.from_pretrained("custom_cnn_10")
83
- my_cnn = CNNModel(config)
84
- my_cnn.model.load_state_dict(pretrained.state_dict())
85
-
86
- my_cnn.push_to_hub("k_cnn_cifar10")
87
-
88
-
89
 
90
 
 
1
+ from .config import CNNConfig
2
  from transformers import PreTrainedModel
3
 
4
 
 
71
  def forward(self, tensor):
72
  return self.model(tensor)
73
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
74
 
75