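# test-model / train.py
# Author: copper-light; last commit 55413d4 ("update"); listed size 417 bytes.
# (Header recovered from the file-view page this script was copied from and
# kept as a comment so the file remains valid Python.)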
import torch
from transformers import (WEIGHTS_NAME, CONFIG_NAME, AutoConfig)
from model import (TestConfig, TestModel)
if __name__ == "__main__":
    # Smoke-test the custom model, register its config class with AutoConfig,
    # and write the config and weights to disk.
    #
    # save_pretrained() expects a *directory*. Passing the file-name constants
    # CONFIG_NAME ("config.json") / WEIGHTS_NAME ("pytorch_model.bin") instead
    # created two directories with those names rather than writing the files
    # side by side. Save both into one directory; "." is the current working
    # directory.
    output_dir = "."

    config = TestConfig()
    model = TestModel(config)

    # Random batch of 16 rows; assumes TestModel accepts 4 input features per
    # row -- TODO confirm against model.py.
    x = torch.rand(16, 4)
    # Inference only: no autograd graph is needed just to print the output.
    with torch.no_grad():
        pred = model(x)
    print(pred)

    # NOTE(review): AutoConfig.register raises ValueError unless
    # TestConfig.model_type == "test-model" -- verify in model.py. The
    # registration only affects this process; any later process that calls
    # AutoConfig.from_pretrained must import and register TestConfig again.
    AutoConfig.register("test-model", TestConfig)
    config.save_pretrained(output_dir)
    model.save_pretrained(output_dir)