| import sys, os, json |
| root = os.sep + os.sep.join(__file__.split(os.sep)[1:__file__.split(os.sep).index("Recurrent-Parameter-Generation")+1]) |
| sys.path.append(root) |
| os.chdir(root) |
| with open("./workspace/config.json", "r") as f: |
| additional_config = json.load(f) |
| USE_WANDB = additional_config["use_wandb"] |
|
|
| |
| import random |
| import numpy as np |
| import torch |
| seed = SEED = 999 |
| torch.manual_seed(seed) |
| torch.cuda.manual_seed(seed) |
| torch.cuda.manual_seed_all(seed) |
| torch.backends.cudnn.deterministic = True |
| torch.backends.cudnn.benchmark = True |
| np.random.seed(seed) |
| random.seed(seed) |
|
|
| |
| import math |
| import random |
| import warnings |
| from _thread import start_new_thread |
| warnings.filterwarnings("ignore", category=UserWarning) |
| if USE_WANDB: import wandb |
| |
| import torch |
| import torch.nn as nn |
| import torch.optim as optim |
| from torch.nn import functional as F |
| from torch.cuda.amp import autocast |
| |
| from model.pdiff import PDiff as Model |
| from model.diffusion import DDPMSampler, DDIMSampler |
| from torch.optim.lr_scheduler import CosineAnnealingLR, LinearLR, SequentialLR |
| from accelerate.utils import DistributedDataParallelKwargs |
| from accelerate.utils import AutocastKwargs |
| from accelerate import Accelerator |
| |
| from dataset import Cifar10_CNNSmall as Dataset |
| from torch.utils.data import DataLoader |
|
|
|
|
|
|
|
|
| config = { |
| "seed": SEED, |
| |
| "dataset": Dataset, |
| "sequence_length": 'auto', |
| |
| "batch_size": 16, |
| "num_workers": 16, |
| "total_steps": 50000, |
| "learning_rate": 0.0002, |
| "weight_decay": 0.0, |
| "save_every": 50000//2, |
| "print_every": 50, |
| "autocast": lambda i: 5000 < i < 45000, |
| "checkpoint_save_path": "./checkpoint", |
| |
| "test_batch_size": 1, |
| "generated_path": Dataset.generated_path, |
| "test_command": Dataset.test_command, |
| |
| "model_config": { |
| |
| "layer_channels": [1, 64, 128, 256, 512, 256, 128, 64, 1], |
| "model_dim": "auto", |
| "kernel_size": 7, |
| "sample_mode": DDPMSampler, |
| "beta": (0.0001, 0.02), |
| "T": 1000, |
| }, |
| "tag": "compare_pdiff_cnnsmall", |
| } |
|
|
|
|
|
|
|
|
| |
| divide_slice_length = 64 |
| print('==> Preparing data..') |
| train_set = config["dataset"](dim_per_token=divide_slice_length, |
| granularity=0, |
| pe_granularity=0) |
| print("Dataset length:", train_set.real_length) |
| print("input shape:", train_set[0][0].flatten().shape) |
| if config["sequence_length"] == "auto": |
| config["sequence_length"] = train_set.sequence_length * divide_slice_length |
| print(f"sequence length: {config['sequence_length']}") |
| if config["model_config"]["model_dim"] == "auto": |
| config["model_config"]["model_dim"] = config["sequence_length"] |
| train_loader = DataLoader( |
| dataset=train_set, |
| batch_size=config["batch_size"], |
| num_workers=config["num_workers"], |
| persistent_workers=True, |
| drop_last=True, |
| shuffle=True, |
| ) |
|
|
| |
| print('==> Building model..') |
| Model.config = config["model_config"] |
| model = Model(sequence_length=config["sequence_length"]) |
|
|
| |
| print('==> Building optimizer..') |
| optimizer = optim.AdamW( |
| params=model.parameters(), |
| lr=config["learning_rate"], |
| weight_decay=config["weight_decay"], |
| ) |
| scheduler = CosineAnnealingLR( |
| optimizer=optimizer, |
| T_max=config["total_steps"], |
| ) |
|
|
| |
| if __name__ == "__main__": |
| kwargs = DistributedDataParallelKwargs(find_unused_parameters=True) |
| accelerator = Accelerator(kwargs_handlers=[kwargs,]) |
| model, optimizer, train_loader = accelerator.prepare(model, optimizer, train_loader) |
|
|
|
|
| |
| if __name__ == "__main__" and USE_WANDB and accelerator.is_main_process: |
| wandb.login(key=additional_config["wandb_api_key"]) |
| wandb.init(project="Recurrent-Parameter-Generation", name=config['tag'], config=config,) |
|
|
|
|
|
|
|
|
| |
| print('==> Defining training..') |
| def train(): |
| if not USE_WANDB: |
| train_loss = 0 |
| this_steps = 0 |
| print("==> Start training..") |
| model.train() |
| for batch_idx, (param, _) in enumerate(train_loader): |
| optimizer.zero_grad() |
| |
| |
| with accelerator.autocast(autocast_handler=AutocastKwargs(enabled=config["autocast"](batch_idx))): |
| param = param.flatten(start_dim=1) |
| loss = model(x=param) |
| accelerator.backward(loss) |
| optimizer.step() |
| if accelerator.is_main_process: |
| scheduler.step() |
| |
| if USE_WANDB and accelerator.is_main_process: |
| wandb.log({"train_loss": loss.item()}) |
| elif USE_WANDB: |
| pass |
| else: |
| train_loss += loss.item() |
| this_steps += 1 |
| if this_steps % config["print_every"] == 0: |
| print('Loss: %.6f' % (train_loss/this_steps)) |
| this_steps = 0 |
| train_loss = 0 |
| if batch_idx % config["save_every"] == 0 and accelerator.is_main_process: |
| os.makedirs(config["checkpoint_save_path"], exist_ok=True) |
| state = accelerator.unwrap_model(model).state_dict() |
| torch.save(state, os.path.join(config["checkpoint_save_path"], config["tag"]+".pth")) |
| generate(save_path=config["generated_path"], need_test=True) |
| if batch_idx >= config["total_steps"]: |
| break |
|
|
|
|
| def generate(save_path=config["generated_path"], need_test=True): |
| print("\n==> Generating..") |
| model.eval() |
| with torch.no_grad(): |
| prediction = model(sample=True) |
| generated_norm = prediction.abs().mean() |
| print("Generated_norm:", generated_norm.item()) |
| if USE_WANDB: |
| wandb.log({"generated_norm": generated_norm.item()}) |
| prediction = prediction.view(-1, divide_slice_length) |
| train_set.save_params(prediction, save_path=save_path) |
| if need_test: |
| start_new_thread(os.system, (config["test_command"],)) |
| model.train() |
| return prediction |
|
|
|
|
|
|
|
|
| if __name__ == '__main__': |
| train() |
| del train_loader |
| print("Finished Training!") |
| exit(0) |