| | import sys, os, json |
| | root = os.sep + os.sep.join(__file__.split(os.sep)[1:__file__.split(os.sep).index("Recurrent-Parameter-Generation")+1]) |
| | sys.path.append(root) |
| | os.chdir(root) |
| | with open("./workspace/config.json", "r") as f: |
| | additional_config = json.load(f) |
| | USE_WANDB = additional_config["use_wandb"] |
| |
|
| | |
| | import random |
| | import numpy as np |
| | import torch |
| | seed = SEED = 999 |
| | torch.manual_seed(seed) |
| | torch.cuda.manual_seed(seed) |
| | torch.cuda.manual_seed_all(seed) |
| | torch.backends.cudnn.deterministic = True |
| | torch.backends.cudnn.benchmark = True |
| | np.random.seed(seed) |
| | random.seed(seed) |
| |
|
| | |
| | import math |
| | import random |
| | import warnings |
| | from _thread import start_new_thread |
| | warnings.filterwarnings("ignore", category=UserWarning) |
| | if USE_WANDB: import wandb |
| | |
| | import torch |
| | import torch.nn as nn |
| | import torch.optim as optim |
| | from torch.nn import functional as F |
| | from torch.cuda.amp import autocast |
| | |
| | from model import MambaDiffusion as Model |
| | from model.diffusion import DDPMSampler, DDIMSampler |
| | from torch.optim.lr_scheduler import CosineAnnealingLR, LinearLR, SequentialLR |
| | from accelerate.utils import DistributedDataParallelKwargs |
| | from accelerate.utils import AutocastKwargs |
| | from accelerate import Accelerator |
| | |
| | from dataset import Cifar10_ResNet18 as Dataset |
| | from torch.utils.data import DataLoader |
| |
|
| |
|
| |
|
| |
|
| | config = { |
| | "seed": SEED, |
| | |
| | "dataset": Dataset, |
| | "dim_per_token": 8192, |
| | "sequence_length": 'auto', |
| | |
| | "batch_size": 8, |
| | "num_workers": 16, |
| | "total_steps": 80000, |
| | "learning_rate": 0.00004, |
| | "weight_decay": 0.0, |
| | "save_every": 80000//30, |
| | "print_every": 50, |
| | "autocast": lambda i: 5000 < i < 45000, |
| | "checkpoint_save_path": "./checkpoint", |
| | |
| | "test_batch_size": 1, |
| | "generated_path": Dataset.generated_path, |
| | "test_command": Dataset.test_command, |
| | |
| | "model_config": { |
| | "num_permutation": 'auto', |
| | |
| | "d_condition": 1, |
| | "d_model": 8192, |
| | "d_state": 128, |
| | "d_conv": 4, |
| | "expand": 2, |
| | "num_layers": 2, |
| | |
| | "diffusion_batch": 1024, |
| | "layer_channels": [1, 32, 64, 128, 64, 32, 1], |
| | "model_dim": "auto", |
| | "condition_dim": "auto", |
| | "kernel_size": 7, |
| | "sample_mode": DDPMSampler, |
| | "beta": (0.0001, 0.02), |
| | "T": 1000, |
| | "forward_once": True, |
| | }, |
| | "tag": "compare_ours_resnet18", |
| | } |
| |
|
| |
|
| |
|
| |
|
| | |
| | print('==> Preparing data..') |
| | train_set = config["dataset"](dim_per_token=config["dim_per_token"]) |
| | print("Dataset length:", train_set.real_length) |
| | print("input shape:", train_set[0][0].shape) |
| | if config["model_config"]["num_permutation"] == "auto": |
| | config["model_config"]["num_permutation"] = train_set.max_permutation_state |
| | if config["model_config"]["condition_dim"] == "auto": |
| | config["model_config"]["condition_dim"] = config["model_config"]["d_model"] |
| | if config["model_config"]["model_dim"] == "auto": |
| | config["model_config"]["model_dim"] = config["dim_per_token"] |
| | if config["sequence_length"] == "auto": |
| | config["sequence_length"] = train_set.sequence_length |
| | print(f"sequence length: {config['sequence_length']}") |
| | else: |
| | assert train_set.sequence_length == config["sequence_length"], f"sequence_length={train_set.sequence_length}" |
| | train_loader = DataLoader( |
| | dataset=train_set, |
| | batch_size=config["batch_size"], |
| | num_workers=config["num_workers"], |
| | persistent_workers=True, |
| | drop_last=True, |
| | shuffle=True, |
| | ) |
| |
|
| | |
| | print('==> Building model..') |
| | Model.config = config["model_config"] |
| | model = Model( |
| | sequence_length=config["sequence_length"], |
| | positional_embedding=train_set.get_position_embedding( |
| | positional_embedding_dim=config["model_config"]["d_model"] |
| | ) |
| | ) |
| |
|
| | |
| | print('==> Building optimizer..') |
| | optimizer = optim.AdamW( |
| | params=model.parameters(), |
| | lr=config["learning_rate"], |
| | weight_decay=config["weight_decay"], |
| | ) |
| | scheduler = CosineAnnealingLR( |
| | optimizer=optimizer, |
| | T_max=config["total_steps"], |
| | ) |
| |
|
| | |
| | if __name__ == "__main__": |
| | kwargs = DistributedDataParallelKwargs(find_unused_parameters=True) |
| | accelerator = Accelerator(kwargs_handlers=[kwargs,]) |
| | if config["dim_per_token"] > 12288 and accelerator.state.num_processes == 1: |
| | print(f"\033[91mWARNING: With token size {config['dim_per_token']}, we suggest to train on multiple GPUs.\033[0m") |
| | model, optimizer, train_loader = accelerator.prepare(model, optimizer, train_loader) |
| |
|
| |
|
| | |
| | if __name__ == "__main__" and USE_WANDB and accelerator.is_main_process: |
| | wandb.login(key=additional_config["wandb_api_key"]) |
| | wandb.init(project="Recurrent-Parameter-Generation", name=config['tag'], config=config,) |
| |
|
| |
|
| |
|
| |
|
| | |
| | print('==> Defining training..') |
| | def train(): |
| | if not USE_WANDB: |
| | train_loss = 0 |
| | this_steps = 0 |
| | print("==> Start training..") |
| | model.train() |
| | for batch_idx, (param, permutation_state) in enumerate(train_loader): |
| | optimizer.zero_grad() |
| | |
| | |
| | with accelerator.autocast(autocast_handler=AutocastKwargs(enabled=config["autocast"](batch_idx))): |
| | loss = model(output_shape=param.shape, x_0=param, permutation_state=permutation_state) |
| | accelerator.backward(loss) |
| | optimizer.step() |
| | if accelerator.is_main_process: |
| | scheduler.step() |
| | |
| | if USE_WANDB and accelerator.is_main_process: |
| | wandb.log({"train_loss": loss.item()}) |
| | elif USE_WANDB: |
| | pass |
| | else: |
| | train_loss += loss.item() |
| | this_steps += 1 |
| | if this_steps % config["print_every"] == 0: |
| | print('Loss: %.6f' % (train_loss/this_steps)) |
| | this_steps = 0 |
| | train_loss = 0 |
| | if batch_idx % config["save_every"] == 0 and accelerator.is_main_process: |
| | os.makedirs(config["checkpoint_save_path"], exist_ok=True) |
| | state = accelerator.unwrap_model(model).state_dict() |
| | torch.save(state, os.path.join(config["checkpoint_save_path"], config["tag"]+".pth")) |
| | generate(save_path=config["generated_path"], need_test=True) |
| | if batch_idx >= config["total_steps"]: |
| | break |
| |
|
| |
|
| | def generate(save_path=config["generated_path"], need_test=True): |
| | print("\n==> Generating..") |
| | model.eval() |
| | with torch.no_grad(): |
| | prediction = model(sample=True) |
| | generated_norm = prediction.abs().mean() |
| | print("Generated_norm:", generated_norm.item()) |
| | if USE_WANDB: |
| | wandb.log({"generated_norm": generated_norm.item()}) |
| | train_set.save_params(prediction, save_path=save_path) |
| | if need_test: |
| | start_new_thread(os.system, (config["test_command"],)) |
| | model.train() |
| | return prediction |
| |
|
| |
|
| |
|
| |
|
| | if __name__ == '__main__': |
| | train() |
| | del train_loader |
| | print("Finished Training!") |
| | exit(0) |