"""Smoke test for the custom ``TestModel`` / ``TestConfig`` pair.

When run as a script this builds a model from a default config, runs one
forward pass on a random input, prints the result, registers the config
class with ``AutoConfig`` and saves the config and the model to disk.
"""
import torch
# WEIGHTS_NAME / CONFIG_NAME are file names, not directories; they stay
# imported but must not be passed to ``save_pretrained`` (see ``main``).
from transformers import (WEIGHTS_NAME, CONFIG_NAME, AutoConfig)  # noqa: F401

from model import (TestConfig, TestModel)

# Single directory that receives both the saved config and the saved model.
OUTPUT_DIR = "test-model"


def main():
    """Build the model, run one forward pass, then register and save it."""
    config = TestConfig()
    model = TestModel(config)

    # One forward pass on a random (16, 4) tensor as a quick sanity check.
    x = torch.rand(16, 4)
    pred = model(x)
    print(pred)

    # Make the custom config class resolvable through AutoConfig.
    AutoConfig.register("test-model", TestConfig)

    # ``save_pretrained`` expects a directory path. Passing CONFIG_NAME or
    # WEIGHTS_NAME would create two separate directories named like files,
    # which cannot be loaded back as one checkpoint. Save both artifacts
    # into the same directory instead.
    config.save_pretrained(OUTPUT_DIR)
    model.save_pretrained(OUTPUT_DIR)


if __name__ == "__main__":
    main()