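# Memory test: load a trained checkpoint, run a handful of training steps under
# bf16 autocast, then dump `nvidia-smi` so the GPU memory footprint can be read off.
# Usage: python <this script> <config module>, where the module named in argv[1]
# must expose Dataset, train_loader, optimizer, train_set, config, and model.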
import sys, os
# Locate the repository root ("Recurrent-Parameter-Generation") from this file's
# absolute path, then make it importable and switch the working directory to it.
parts = __file__.split(os.sep)
root = os.sep + os.sep.join(parts[1:parts.index("Recurrent-Parameter-Generation") + 1])
sys.path.append(root)
os.chdir(root)
# torch
import torch
from torch import nn
# Import the experiment module given on the command line; it supplies the
# dataset, loader, optimizer, config, and model used below.
import importlib
item = importlib.import_module(f"{sys.argv[1]}")
Dataset = item.Dataset
train_loader = item.train_loader
optimizer = item.optimizer
train_set = item.train_set
config = item.config
model = item.model
assert config.get("tag") is not None, "Remember to set a tag."
test_config = {
    "device": "cuda",
    "checkpoint": f"./checkpoint/{config['tag']}.pth",
}
config.update(test_config)
# Model
print('==> Building model..')
diction = torch.load(config["checkpoint"])
# The permutation-state embedding in the checkpoint may have a different shape
# than the freshly built model, so rebuild it to match before loading weights.
permutation_shape = diction["to_permutation_state.weight"].shape
model.to_permutation_state = nn.Embedding(*permutation_shape)
model.load_state_dict(diction)
model = model.to(config["device"])
# test
print('==> Defining training..')
def memory_test():
    print("==> start training..")
    model.train()
    for batch_idx, (param, permutation_state) in enumerate(train_loader):
        optimizer.zero_grad()
        # noinspection PyArgumentList
        with torch.autocast(enabled=True, dtype=torch.bfloat16, device_type="cuda"):
            loss = model(output_shape=param.shape,
                         x_0=param.to(model.device),
                         permutation_state=permutation_state.to(model.device))
        loss.backward()
        optimizer.step()
        # A few batches are enough to reach the peak memory footprint.
        if batch_idx >= 10:
            break
    os.system("nvidia-smi")
    input(f"This program is running on GPU: {os.environ.get('CUDA_VISIBLE_DEVICES', 'unset')}")
if __name__ == "__main__":
    memory_test()