# Source snapshot: commit d761eb1 (418 bytes).
import torch
from test_model import MyModelConfig,MyModel
# Register the custom classes with the transformers Auto* machinery so that
# the saved/pushed checkpoint can later be loaded via AutoConfig / AutoModel
# (per the "Sharing custom models" docs this typically also needs
# trust_remote_code=True on the loading side — confirm for your version).
MyModelConfig.register_for_auto_class()
MyModel.register_for_auto_class("AutoModel")

# Local output directory, Hub repository id, and the raw state-dict file to
# import. Named once here so the paths are not repeated as bare literals.
SAVE_DIR = "custom-mymodel"
HUB_REPO_ID = "custom-mymodel_v1"
CHECKPOINT_PATH = "pytorch_model.bin"

save_config = MyModelConfig(input_dim=10, layers_num=3)
save_config.save_pretrained(SAVE_DIR)

mymodel = MyModel(save_config)
# map_location="cpu" lets a checkpoint that was saved from GPU tensors be
# loaded on a machine without CUDA; load_state_dict() then copies the values
# into the freshly constructed model's own parameters.
# NOTE(review): torch.load unpickles the file, so only load trusted
# checkpoints. On torch >= 1.13, passing weights_only=True restricts
# unpickling to tensors and plain containers.
state_dict = torch.load(CHECKPOINT_PATH, map_location="cpu")
mymodel.load_state_dict(state_dict)
mymodel.save_pretrained(SAVE_DIR)
mymodel.push_to_hub(HUB_REPO_ID)