testmodel / register.py
Darius-H
First model version
d761eb1
raw
history blame contribute delete
418 Bytes
import torch
from test_model import MyModelConfig,MyModel
# Register the custom classes with the transformers Auto* API so that
# save_pretrained()/push_to_hub() record them in the saved checkpoint;
# MyModel is registered under "AutoModel".
MyModelConfig.register_for_auto_class()
MyModel.register_for_auto_class("AutoModel")

# Build the model configuration and write it to the local output directory.
save_config = MyModelConfig(input_dim=10, layers_num=3)
save_config.save_pretrained("custom-mymodel")

# Instantiate the model and restore previously trained weights.
mymodel = MyModel(save_config)
# map_location="cpu": a checkpoint saved from a GPU would otherwise try to
# restore onto that device and fail on a CPU-only machine; load_state_dict
# copies the tensors into the model's own parameters regardless.
# weights_only=True: torch.load unpickles the file, so restrict it to tensors
# and plain containers to avoid executing arbitrary code from an untrusted
# checkpoint (this is already the default from PyTorch 2.6 on).
state_dict = torch.load(
    "pytorch_model.bin", map_location="cpu", weights_only=True
)
mymodel.load_state_dict(state_dict)

# Save the model locally, then upload it to the Hugging Face Hub.
mymodel.save_pretrained("custom-mymodel")
mymodel.push_to_hub("custom-mymodel_v1")