| import torch | |
| from transformers import (WEIGHTS_NAME, CONFIG_NAME, AutoConfig) | |
| from model import (TestConfig, TestModel) | |
def main(output_dir: str = "test-model") -> None:
    """Smoke-test ``TestModel`` and save it in Hugging Face format.

    Builds a default ``TestConfig`` and ``TestModel``, runs one forward pass
    on random input, prints the prediction, registers ``TestConfig`` with
    ``AutoConfig`` and writes the config and weights into ``output_dir``.

    Args:
        output_dir: Directory passed to ``save_pretrained`` for both the
            config and the model, so the result is a single checkpoint
            directory.
    """
    config = TestConfig()
    model = TestModel(config)

    # NOTE(review): assumes TestModel.forward accepts a float tensor of
    # shape (16, 4) -- confirm against model.py.
    x = torch.rand(16, 4)
    with torch.no_grad():  # inference only; no autograd graph is needed
        pred = model(x)
    print(pred)

    # Register the custom config so AutoConfig can resolve the saved
    # "model_type" when loading. transformers rejects this call unless
    # TestConfig.model_type == "test-model".
    AutoConfig.register("test-model", TestConfig)

    # save_pretrained takes a *directory*, not a file name. Passing
    # CONFIG_NAME / WEIGHTS_NAME here would create directories literally
    # named "config.json" / "pytorch_model.bin" and split the checkpoint
    # across them.
    config.save_pretrained(output_dir)
    model.save_pretrained(output_dir)


if __name__ == "__main__":
    main()