File size: 1,680 Bytes
f7009b3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
import sys, os
# Locate the repository root (the "Recurrent-Parameter-Generation" directory
# that contains this file), make it importable, and run from there so that
# relative paths such as "./checkpoint/..." resolve against it.
# abspath() makes this work when __file__ is relative, and splitting/joining
# the absolute path as-is (instead of prefixing os.sep) keeps Windows drive
# letters and UNC prefixes intact. On an absolute POSIX path the result is
# identical to the previous computation.
_repo_dir_name = "Recurrent-Parameter-Generation"
_path_parts = os.path.abspath(__file__).split(os.sep)
if _repo_dir_name not in _path_parts:
    raise ValueError(
        f"{__file__!r} is not inside a '{_repo_dir_name}' directory; "
        "cannot determine the repository root."
    )
root = os.sep.join(_path_parts[: _path_parts.index(_repo_dir_name) + 1])
sys.path.append(root)
os.chdir(root)

# torch
import torch
from torch import nn
# father
import importlib
# The first CLI argument is a dotted module path, imported via sys.path. The
# module must expose the objects below; this script reuses them rather than
# rebuilding the training setup itself.
if len(sys.argv) < 2:
    sys.exit(f"usage: python {sys.argv[0]} <module.to.import>")
item = importlib.import_module(sys.argv[1])
Dataset = item.Dataset
train_loader = item.train_loader
optimizer = item.optimizer
train_set = item.train_set
config = item.config
model = item.model
# Explicit check instead of `assert`, which is stripped under `python -O`.
# The tag is required: it names the checkpoint file loaded below.
if config.get("tag") is None:
    raise ValueError("Remember to set a tag.")




# Test-time settings layered on top of the imported experiment config:
# force the CUDA device and point at the checkpoint named after the run tag.
test_config = dict(
    device="cuda",
    checkpoint=f"./checkpoint/{config['tag']}.pth",
)
config.update(test_config)




# Model
print('==> Building model..')
# Load onto the CPU first. By default torch.load restores tensors to the
# device they were saved from, and `diction` stays referenced for the whole
# run. A GPU-saved checkpoint would therefore keep an extra copy resident on
# the GPU and inflate the memory usage this script reports.
diction = torch.load(config["checkpoint"], map_location="cpu")
# Rebuild the embedding with the checkpoint's weight shape
# (num_embeddings, embedding_dim) so load_state_dict's shape check passes.
# NOTE(review): presumably the experiment module builds it with a different
# size than the checkpoint — confirm.
permutation_shape = diction["to_permutation_state.weight"].shape
model.to_permutation_state = nn.Embedding(*permutation_shape)
model.load_state_dict(diction)
model = model.to(config["device"])


# test
print('==> Defining training..')
def memory_test():
    """Run a few bf16-autocast training steps, then report GPU memory usage.

    Trains ``model`` on at most 11 batches from ``train_loader`` (batch
    indices 0-10) and prints ``nvidia-smi`` output. It then blocks on
    ``input()``, keeping the process and its GPU allocations alive until
    Enter is pressed.
    """
    print("==> start training..")
    model.train()
    for batch_idx, (param, permutation_state) in enumerate(train_loader):
        optimizer.zero_grad()
        # noinspection PyArgumentList
        with torch.autocast(enabled=True, dtype=torch.bfloat16, device_type="cuda"):
            # `model.device` is not a standard nn.Module attribute; presumably
            # the project's model class defines it as its parameters' device.
            loss = model(output_shape=param.shape,
                         x_0=param.to(model.device),
                         permutation_state=permutation_state.to(model.device))
        loss.backward()
        optimizer.step()
        if batch_idx >= 10:
            break
    os.system("nvidia-smi")
    # CUDA_VISIBLE_DEVICES is frequently unset (all GPUs visible); don't crash
    # with a KeyError at the very end of the run just to print this banner.
    visible_gpus = os.environ.get("CUDA_VISIBLE_DEVICES", "<unset>")
    input(f"This program running on GPU:{visible_gpus}")




# Only start the benchmark when executed directly. The module-level setup
# (chdir, experiment import, checkpoint load) still runs on import.
if __name__ == "__main__":
    memory_test()