File size: 417 Bytes
c580b09
55413d4
c580b09
 
 
 
 
 
 
 
 
 
 
 
 
 
55413d4
 
c580b09
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
import torch
from transformers import (WEIGHTS_NAME, CONFIG_NAME, AutoConfig)

from model import (TestConfig, TestModel)



def main(save_dir: str = "test-model") -> None:
    """Smoke-test ``TestModel`` and save it, with its config, to *save_dir*.

    Builds a default ``TestConfig`` / ``TestModel`` and runs one forward pass
    on random input, printing the prediction. It then registers ``TestConfig``
    with ``AutoConfig`` under the "test-model" type and writes the config and
    model into ``save_dir``.

    Args:
        save_dir: Directory to save into. ``save_pretrained`` creates it if it
            does not already exist.
    """
    config = TestConfig()
    model = TestModel(config)

    # NOTE(review): assumes TestModel accepts a (batch, 4) float tensor, i.e.
    # the input size configured in TestConfig is 4 — confirm against model.py.
    x = torch.rand(16, 4)

    # Inference only: put layers such as dropout (if any) into eval mode and
    # skip autograd bookkeeping for the forward pass.
    model.eval()
    with torch.no_grad():
        pred = model(x)
    print(pred)

    # Registering lets AutoConfig resolve the "test-model" model_type when
    # loading from save_dir later. NOTE(review): transformers raises
    # ValueError here unless TestConfig.model_type == "test-model" — confirm.
    AutoConfig.register("test-model", TestConfig)

    # save_pretrained takes a *directory*, not a file name. The previous
    # save_pretrained(CONFIG_NAME) / save_pretrained(WEIGHTS_NAME) calls
    # created directories literally named "config.json" and
    # "pytorch_model.bin". Save both artifacts into one directory instead.
    config.save_pretrained(save_dir)
    model.save_pretrained(save_dir)


if __name__ == "__main__":
    main()