File size: 417 Bytes
c580b09 55413d4 c580b09 55413d4 c580b09 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 |
import torch
from transformers import (WEIGHTS_NAME, CONFIG_NAME, AutoConfig)
from model import (TestConfig, TestModel)
def main():
    """Smoke-test TestModel, then export its config and weights.

    Builds a TestModel from a default TestConfig, runs one forward pass on a
    random tensor of shape (16, 4) and prints the result, registers
    TestConfig with AutoConfig under the "test-model" key, and finally saves
    the config and the model into the current working directory.
    """
    config = TestConfig()
    model = TestModel(config)

    # One forward pass on random input, just to check that the model runs.
    x = torch.rand(16, 4)
    pred = model(x)
    print(pred)

    AutoConfig.register("test-model", TestConfig)

    # save_pretrained() takes a *directory* and chooses the file names
    # itself. Passing CONFIG_NAME / WEIGHTS_NAME (which are file names)
    # would create directories called "config.json" and "pytorch_model.bin"
    # instead of writing those files.
    # NOTE(review): assumes TestModel follows the transformers
    # save_pretrained(save_directory) contract like TestConfig -- confirm.
    save_directory = "."
    config.save_pretrained(save_directory)
    model.save_pretrained(save_directory)


if __name__ == "__main__":
    main()