import sys, os
# Locate the repository root (the "Recurrent-Parameter-Generation" directory
# somewhere above this file), make it importable, and cd into it so that
# relative paths like "./checkpoint/..." resolve against the repo root.
# NOTE(review): .index() raises ValueError if this file does not live under a
# directory with that exact name — intentional fail-fast.
_path_parts = __file__.split(os.sep)
root = os.sep + os.sep.join(_path_parts[1:_path_parts.index("Recurrent-Parameter-Generation") + 1])
sys.path.append(root)
os.chdir(root)
| |
|
| | |
import torch
from torch import nn

# The training configuration is itself a Python module, selected at runtime:
# the first CLI argument names a module that must expose Dataset,
# train_loader, optimizer, train_set, config, and model.
import importlib
item = importlib.import_module(sys.argv[1])  # argv[1] is already a str; the f-string wrapper was redundant
Dataset = item.Dataset
train_loader = item.train_loader
optimizer = item.optimizer
train_set = item.train_set
config = item.config
model = item.model
# Validate with an explicit raise: an assert would be silently stripped
# under `python -O`, letting an untagged run overwrite "./checkpoint/None.pth".
if config.get("tag") is None:
    raise ValueError("Remember to set a tag.")
| |
|
| |
|
| |
|
| |
|
# Evaluation-time settings layered on top of the imported training config.
test_config = dict(
    device="cuda",
    checkpoint=f"./checkpoint/{config['tag']}.pth",
)
config.update(test_config)
| |
|
| |
|
| |
|
| |
|
| | |
print('==> Building model..')
# Load the checkpoint onto the CPU first so deserialization does not require
# the exact GPU it was saved from; the model is moved to the target device
# explicitly afterwards.
diction = torch.load(config["checkpoint"], map_location="cpu")
# The permutation-state embedding table's size is checkpoint-dependent, so it
# must be rebuilt to the saved shape before load_state_dict can succeed.
permutation_shape = diction["to_permutation_state.weight"].shape
model.to_permutation_state = nn.Embedding(*permutation_shape)
model.load_state_dict(diction)
model = model.to(config["device"])
| |
|
| |
|
| | |
print('==> Defining training..')
def memory_test():
    """Run a few real training steps to observe peak GPU memory usage.

    Trains on at most 11 batches (indices 0..10), prints ``nvidia-smi``
    output, then blocks on ``input()`` so memory can be inspected while the
    process — and its CUDA allocations — are still alive.
    """
    print("==> start training..")
    model.train()
    for batch_idx, (param, permutation_state) in enumerate(train_loader):
        optimizer.zero_grad()
        # bf16 autocast mirrors the real training setup so the measured
        # memory footprint is representative.
        with torch.autocast(enabled=True, dtype=torch.bfloat16, device_type="cuda"):
            loss = model(output_shape=param.shape,
                         x_0=param.to(model.device),
                         permutation_state=permutation_state.to(model.device))
        loss.backward()
        optimizer.step()
        if batch_idx >= 10:
            break
    os.system("nvidia-smi")
    # CUDA_VISIBLE_DEVICES may be unset — don't let the final prompt KeyError.
    input(f"This program running on GPU:{os.environ.get('CUDA_VISIBLE_DEVICES', 'unset')}")
| |
|
| |
|
| |
|
| |
|
if __name__ == "__main__":
    # Entry point: run the short memory-profiling training loop.
    memory_test()