|
|
import multiprocessing |
|
|
import os |
|
|
from zlib import Z_FULL_FLUSH |
|
|
|
|
|
|
|
|
|
|
|
import pickle |
|
|
from functools import partial |
|
|
from pathlib import Path |
|
|
from args import FineDance_parse_train_opt, save_arguments_to_yaml |
|
|
import sys |
|
|
|
|
|
import torch |
|
|
import torch.nn.functional as F |
|
|
import wandb |
|
|
from accelerate import Accelerator, DistributedDataParallelKwargs |
|
|
from accelerate.state import AcceleratorState |
|
|
from torch.utils.data import DataLoader |
|
|
from tqdm import tqdm |
|
|
|
|
|
from dataset.FineDance_dataset import FineDance_Smpl |
|
|
from dataset.preprocess import increment_path |
|
|
from dataset.preprocess import My_Normalizer as Normalizer |
|
|
from model.adan import Adan |
|
|
from model.diffusion import GaussianDiffusion |
|
|
from model.model import DanceDecoder, SeqModel |
|
|
from vis import SMPLX_Skeleton, SMPLSkeleton |
|
|
|
|
|
|
|
|
def wrap(x): |
|
|
return {f"module.{key}": value for key, value in x.items()} |
|
|
|
|
|
|
|
|
def maybe_wrap(x, num): |
|
|
return x if num == 1 else wrap(x) |
|
|
|
|
|
|
|
|
class EDGE: |
|
|
def __init__( |
|
|
self, |
|
|
opt, |
|
|
feature_type, |
|
|
checkpoint_path="", |
|
|
normalizer=None, |
|
|
EMA=True, |
|
|
learning_rate=4e-4, |
|
|
weight_decay=0.02, |
|
|
): |
|
|
self.opt = opt |
|
|
ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=True) |
|
|
self.accelerator = Accelerator(kwargs_handlers=[ddp_kwargs]) |
|
|
state = AcceleratorState() |
|
|
num_processes = state.num_processes |
|
|
|
|
|
self.repr_dim = repr_dim = opt.nfeats |
|
|
feature_dim = 35 |
|
|
|
|
|
self.horizon = horizon = opt.full_seq_len |
|
|
|
|
|
self.accelerator.wait_for_everyone() |
|
|
|
|
|
self.resume_num = 0 |
|
|
checkpoint = None |
|
|
self.normalizer = None |
|
|
if checkpoint_path != "": |
|
|
checkpoint = torch.load( |
|
|
checkpoint_path, map_location=self.accelerator.device |
|
|
) |
|
|
self.resume_num = int(os.path.basename(checkpoint_path).split("-")[1].split(".")[0]) |
|
|
|
|
|
model = SeqModel( |
|
|
nfeats=repr_dim, |
|
|
seq_len=horizon, |
|
|
latent_dim=512, |
|
|
ff_size=1024, |
|
|
num_layers=8, |
|
|
num_heads=8, |
|
|
dropout=0.1, |
|
|
cond_feature_dim=feature_dim, |
|
|
activation=F.gelu, |
|
|
) |
|
|
if opt.nfeats == 139 or opt.nfeats == 135: |
|
|
smplx_fk = SMPLSkeleton(device=self.accelerator.device) |
|
|
else: |
|
|
smplx_fk = SMPLX_Skeleton(device=self.accelerator.device, batch=512000) |
|
|
diffusion = GaussianDiffusion( |
|
|
model, |
|
|
opt, |
|
|
horizon, |
|
|
repr_dim, |
|
|
smplx_model = smplx_fk, |
|
|
schedule="cosine", |
|
|
n_timestep=1000, |
|
|
predict_epsilon=False, |
|
|
loss_type="l2", |
|
|
use_p2=False, |
|
|
cond_drop_prob=0.25, |
|
|
guidance_weight=2, |
|
|
do_normalize = opt.do_normalize |
|
|
) |
|
|
|
|
|
print( |
|
|
"Model has {} parameters".format(sum(y.numel() for y in model.parameters())) |
|
|
) |
|
|
|
|
|
self.model = self.accelerator.prepare(model) |
|
|
self.diffusion = diffusion.to(self.accelerator.device) |
|
|
self.smplx_fk = smplx_fk |
|
|
optim = Adan(model.parameters(), lr=learning_rate, weight_decay=weight_decay) |
|
|
self.optim = self.accelerator.prepare(optim) |
|
|
|
|
|
if checkpoint_path != "": |
|
|
self.model.load_state_dict( |
|
|
maybe_wrap( |
|
|
checkpoint["ema_state_dict" if EMA else "model_state_dict"], |
|
|
num_processes, |
|
|
) |
|
|
) |
|
|
|
|
|
def eval(self): |
|
|
self.diffusion.eval() |
|
|
|
|
|
def train(self): |
|
|
self.diffusion.train() |
|
|
|
|
|
def prepare(self, objects): |
|
|
return self.accelerator.prepare(*objects) |
|
|
|
|
|
def train_loop(self, opt): |
|
|
print("train_dataset = FineDance_Dataset ") |
|
|
train_dataset = FineDance_Smpl( |
|
|
args=opt, |
|
|
istrain=True, |
|
|
) |
|
|
test_dataset = FineDance_Smpl( |
|
|
args=opt, |
|
|
istrain=False, |
|
|
) |
|
|
|
|
|
num_cpus = multiprocessing.cpu_count() |
|
|
print("batchsize=:", opt.batch_size) |
|
|
train_data_loader = DataLoader( |
|
|
train_dataset, |
|
|
batch_size=opt.batch_size, |
|
|
shuffle=True, |
|
|
num_workers=min(int(num_cpus * 0.5), 40), |
|
|
pin_memory=True, |
|
|
drop_last=True, |
|
|
) |
|
|
test_data_loader = DataLoader( |
|
|
test_dataset, |
|
|
batch_size=opt.batch_size, |
|
|
shuffle=True, |
|
|
num_workers=2, |
|
|
pin_memory=True, |
|
|
drop_last=True, |
|
|
) |
|
|
|
|
|
train_data_loader = self.accelerator.prepare(train_data_loader) |
|
|
|
|
|
load_loop = ( |
|
|
partial(tqdm, position=1, desc="Batch") |
|
|
if self.accelerator.is_main_process |
|
|
else lambda x: x |
|
|
) |
|
|
if self.accelerator.is_main_process: |
|
|
save_dir = str(increment_path(Path(opt.project) / opt.exp_name)) |
|
|
opt.exp_name = save_dir.split("/")[-1] |
|
|
wandb.init(project=opt.wandb_pj_name, name=opt.exp_name) |
|
|
save_dir = Path(save_dir) |
|
|
wdir = save_dir / "weights" |
|
|
wdir.mkdir(parents=True, exist_ok=True) |
|
|
wandb.save("params.yaml") |
|
|
yaml_path = os.path.join(wdir, 'parameters.yaml') |
|
|
save_arguments_to_yaml(opt, yaml_path) |
|
|
|
|
|
|
|
|
self.accelerator.wait_for_everyone() |
|
|
for epoch in range(1, opt.epochs + 1): |
|
|
print("epoch:", epoch+self.resume_num) |
|
|
avg_loss = 0 |
|
|
avg_vloss = 0 |
|
|
avg_fkloss = 0 |
|
|
avg_footloss = 0 |
|
|
|
|
|
|
|
|
self.train() |
|
|
for step, (x, cond, filename) in enumerate( |
|
|
load_loop(train_data_loader) |
|
|
): |
|
|
if opt.nfeats == 139 or opt.nfeats==135: |
|
|
x = x[:, :, :139] |
|
|
|
|
|
total_loss, (loss, v_loss, fk_loss, foot_loss) = self.diffusion( |
|
|
x, cond, t_override=None |
|
|
) |
|
|
|
|
|
self.optim.zero_grad() |
|
|
self.accelerator.backward(total_loss) |
|
|
self.optim.step() |
|
|
|
|
|
|
|
|
if self.accelerator.is_main_process: |
|
|
avg_loss += loss.detach().cpu().numpy() |
|
|
avg_vloss += v_loss.detach().cpu().numpy() |
|
|
avg_fkloss += fk_loss.detach().cpu().numpy() |
|
|
avg_footloss += foot_loss.detach().cpu().numpy() |
|
|
if step % opt.ema_interval == 0: |
|
|
self.diffusion.ema.update_model_average( |
|
|
self.diffusion.master_model, self.diffusion.model |
|
|
) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if ((epoch+self.resume_num) % opt.save_interval) == 0 or epoch<=1: |
|
|
|
|
|
self.accelerator.wait_for_everyone() |
|
|
self.eval() |
|
|
|
|
|
if self.accelerator.is_main_process: |
|
|
|
|
|
|
|
|
avg_loss /= len(train_data_loader) |
|
|
avg_vloss /= len(train_data_loader) |
|
|
avg_fkloss /= len(train_data_loader) |
|
|
avg_footloss /= len(train_data_loader) |
|
|
log_dict = { |
|
|
"Train Loss": avg_loss, |
|
|
"V Loss": avg_vloss, |
|
|
"FK Loss": avg_fkloss, |
|
|
"Foot Loss": avg_footloss, |
|
|
} |
|
|
|
|
|
wandb.log(log_dict) |
|
|
|
|
|
ckpt = { |
|
|
"ema_state_dict": self.diffusion.master_model.state_dict(), |
|
|
"model_state_dict": self.accelerator.unwrap_model( |
|
|
self.model |
|
|
).state_dict(), |
|
|
"optimizer_state_dict": self.optim.state_dict(), |
|
|
"normalizer": self.normalizer, |
|
|
} |
|
|
|
|
|
torch.save(ckpt, os.path.join(wdir, f"train-{epoch+self.resume_num}.pt")) |
|
|
print(f"[MODEL SAVED at Epoch {epoch+self.resume_num}]") |
|
|
|
|
|
|
|
|
render_count = 2 |
|
|
shape = (render_count, self.horizon, self.opt.nfeats) |
|
|
print("Generating Sample") |
|
|
|
|
|
(x, cond, filename) = next(iter(test_data_loader)) |
|
|
|
|
|
|
|
|
|
|
|
if opt.nfeats == 139 or opt.nfeats==135: |
|
|
x = x[:, :, :139] |
|
|
|
|
|
cond = cond.to(self.accelerator.device) |
|
|
|
|
|
self.diffusion.render_sample( |
|
|
shape, |
|
|
cond[:render_count], |
|
|
self.normalizer, |
|
|
epoch+self.resume_num, |
|
|
render_out = os.path.join(opt.render_dir, "train_" + opt.exp_name), |
|
|
fk_out = os.path.join(opt.render_dir, "train_" + opt.exp_name), |
|
|
name=filename[:render_count], |
|
|
|
|
|
sound=True, |
|
|
) |
|
|
|
|
|
|
|
|
|
|
|
if self.accelerator.is_main_process: |
|
|
wandb.run.finish() |
|
|
|
|
|
def render_sample( |
|
|
self, data_tuple, label, render_dir, render_count=-1, mode='normal', fk_out=None, render=True, |
|
|
): |
|
|
_, cond, wavname = data_tuple |
|
|
assert len(cond.shape) == 3 |
|
|
if render_count < 0: |
|
|
render_count = len(cond) |
|
|
shape = (render_count, self.horizon, self.repr_dim) |
|
|
cond = cond.to(self.accelerator.device).float() |
|
|
self.diffusion.render_sample( |
|
|
shape, |
|
|
cond[:render_count], |
|
|
self.normalizer, |
|
|
label, |
|
|
render_dir, |
|
|
name=wavname[:render_count], |
|
|
sound=True, |
|
|
mode=mode, |
|
|
fk_out=fk_out, |
|
|
render=render |
|
|
) |
|
|
|
|
|
def train(opt): |
|
|
model = EDGE(opt, opt.feature_type) |
|
|
model.train_loop(opt) |
|
|
|
|
|
if __name__ == "__main__": |
|
|
opt = FineDance_parse_train_opt() |
|
|
command = ' '.join(sys.argv) |
|
|
if not os.path.exists(os.path.join(opt.project, opt.exp_name)): |
|
|
os.makedirs(os.path.join(opt.project, opt.exp_name), exist_ok=False) |
|
|
with open(os.path.join(opt.project, opt.exp_name, 'command.txt'), 'w') as f: |
|
|
f.write(command) |
|
|
|
|
|
yaml_path = os.path.join(opt.project, opt.exp_name, 'parameters.yaml') |
|
|
save_arguments_to_yaml(opt, yaml_path) |
|
|
|
|
|
train(opt) |
|
|
|