# TunaDance/train_seq.py
import multiprocessing
import os
# os.environ["WANDB_API_KEY"] = "your WANDB_API_KEY" #
# os.environ["WANDB_MODE"] = "online"
# os.environ["CUDA_VISIBLE_DEVICES"] = "4,5,6,7"
import pickle
from functools import partial
from pathlib import Path
from args import FineDance_parse_train_opt, save_arguments_to_yaml
import sys
import torch
import torch.nn.functional as F
import wandb
from accelerate import Accelerator, DistributedDataParallelKwargs
from accelerate.state import AcceleratorState
from torch.utils.data import DataLoader
from tqdm import tqdm
from dataset.FineDance_dataset import FineDance_Smpl
from dataset.preprocess import increment_path
from dataset.preprocess import My_Normalizer as Normalizer  # use My_Normalizer here, not the plain Normalizer
from model.adan import Adan
from model.diffusion import GaussianDiffusion
from model.model import DanceDecoder, SeqModel
from vis import SMPLX_Skeleton, SMPLSkeleton
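# DistributedDataParallel prefixes parameter names with "module."; wrap/maybe_wrap
# below add that prefix so single-process checkpoints load cleanly in multi-GPU runs.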
def wrap(x):
return {f"module.{key}": value for key, value in x.items()}
def maybe_wrap(x, num):
return x if num == 1 else wrap(x)
class EDGE:
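    """Trainer that wraps the SeqModel denoiser, its GaussianDiffusion process,
    and the Adan optimizer, and drives the Accelerate-based training loop."""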
def __init__(
self,
opt,
feature_type,
checkpoint_path="",
normalizer=None,
EMA=True,
learning_rate=4e-4,
weight_decay=0.02,
):
self.opt = opt
ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)
self.accelerator = Accelerator(kwargs_handlers=[ddp_kwargs])
state = AcceleratorState()
num_processes = state.num_processes
self.repr_dim = repr_dim = opt.nfeats
        feature_dim = 35  # dimensionality of the music conditioning features (cond_feature_dim)
self.horizon = horizon = opt.full_seq_len
self.accelerator.wait_for_everyone()
self.resume_num = 0
checkpoint = None
self.normalizer = None
if checkpoint_path != "":
checkpoint = torch.load(
checkpoint_path, map_location=self.accelerator.device
)
            self.resume_num = int(os.path.basename(checkpoint_path).split("-")[1].split(".")[0])  # parse the epoch number from a checkpoint named like "train-{epoch}.pt"
model = SeqModel(
nfeats=repr_dim,
seq_len=horizon,
latent_dim=512,
ff_size=1024,
num_layers=8,
num_heads=8,
dropout=0.1,
cond_feature_dim=feature_dim,
activation=F.gelu,
)
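        # 139/135-dim pose representations use the SMPL skeleton for forward
        # kinematics; other representations use the SMPL-X skeleton.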
if opt.nfeats == 139 or opt.nfeats == 135:
smplx_fk = SMPLSkeleton(device=self.accelerator.device)
else:
smplx_fk = SMPLX_Skeleton(device=self.accelerator.device, batch=512000)
diffusion = GaussianDiffusion(
model,
opt,
horizon,
repr_dim,
            smplx_model=smplx_fk,
schedule="cosine",
n_timestep=1000,
predict_epsilon=False,
loss_type="l2",
use_p2=False,
cond_drop_prob=0.25,
guidance_weight=2,
            do_normalize=opt.do_normalize,
)
print(
"Model has {} parameters".format(sum(y.numel() for y in model.parameters()))
)
self.model = self.accelerator.prepare(model)
        self.diffusion = diffusion.to(self.accelerator.device)  # why is accelerator.prepare() not needed here?
self.smplx_fk = smplx_fk # to(self.accelerator.device)
optim = Adan(model.parameters(), lr=learning_rate, weight_decay=weight_decay)
self.optim = self.accelerator.prepare(optim)
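        # when resuming, load the EMA weights by default (raw model weights if
        # EMA=False), re-adding the "module." prefix for multi-process runs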
if checkpoint_path != "":
self.model.load_state_dict(
maybe_wrap(
checkpoint["ema_state_dict" if EMA else "model_state_dict"],
num_processes,
)
)
def eval(self):
self.diffusion.eval()
def train(self):
self.diffusion.train()
def prepare(self, objects):
return self.accelerator.prepare(*objects)
def train_loop(self, opt):
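        """Build the dataloaders, run the training epochs with EMA updates, and
        periodically checkpoint and render samples on the main process."""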
print("train_dataset = FineDance_Dataset ")
train_dataset = FineDance_Smpl(
args=opt, # data/
istrain=True,
)
test_dataset = FineDance_Smpl(
args=opt,
istrain=False,
)
num_cpus = multiprocessing.cpu_count()
print("batchsize=:", opt.batch_size)
train_data_loader = DataLoader(
train_dataset,
batch_size=opt.batch_size,
shuffle=True,
num_workers=min(int(num_cpus * 0.5), 40), # num_workers=min(int(num_cpus * 0.75), 32),
pin_memory=True,
drop_last=True,
)
test_data_loader = DataLoader(
test_dataset,
batch_size=opt.batch_size,
shuffle=True,
num_workers=2,
pin_memory=True,
drop_last=True,
)
train_data_loader = self.accelerator.prepare(train_data_loader)
# boot up multi-gpu training. test dataloader is only on main process
load_loop = (
partial(tqdm, position=1, desc="Batch")
if self.accelerator.is_main_process
else lambda x: x
)
if self.accelerator.is_main_process:
save_dir = str(increment_path(Path(opt.project) / opt.exp_name))
opt.exp_name = save_dir.split("/")[-1]
wandb.init(project=opt.wandb_pj_name, name=opt.exp_name)
save_dir = Path(save_dir)
wdir = save_dir / "weights"
wdir.mkdir(parents=True, exist_ok=True)
wandb.save("params.yaml") # 保存wandb配置到文件
yaml_path = os.path.join(wdir, 'parameters.yaml')
save_arguments_to_yaml(opt, yaml_path)
self.accelerator.wait_for_everyone()
for epoch in range(1, opt.epochs + 1):
print("epoch:", epoch+self.resume_num)
avg_loss = 0
avg_vloss = 0
avg_fkloss = 0
avg_footloss = 0
# train
self.train()
for step, (x, cond, filename) in enumerate(
load_loop(train_data_loader)
):
                if opt.nfeats == 139 or opt.nfeats == 135:
x = x[:, :, :139]
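                # the diffusion forward pass returns the combined training loss and
                # its individual terms (loss, v_loss, fk_loss, foot_loss)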
total_loss, (loss, v_loss, fk_loss, foot_loss) = self.diffusion(
x, cond, t_override=None
)
# print("3")
self.optim.zero_grad()
self.accelerator.backward(total_loss)
self.optim.step()
# ema update and train loss update only on main
if self.accelerator.is_main_process:
avg_loss += loss.detach().cpu().numpy()
avg_vloss += v_loss.detach().cpu().numpy()
avg_fkloss += fk_loss.detach().cpu().numpy()
avg_footloss += foot_loss.detach().cpu().numpy()
if step % opt.ema_interval == 0:
self.diffusion.ema.update_model_average(
self.diffusion.master_model, self.diffusion.model
)
#-----------------------------------------------------------------------------------------------------------
# test
# Save model
            if ((epoch + self.resume_num) % opt.save_interval) == 0 or epoch <= 1:
# everyone waits here for the val loop to finish ( don't start next train epoch early)
self.accelerator.wait_for_everyone()
self.eval() # debug!
# save only if on main thread
if self.accelerator.is_main_process:
# self.eval()
# log
avg_loss /= len(train_data_loader)
avg_vloss /= len(train_data_loader)
avg_fkloss /= len(train_data_loader)
avg_footloss /= len(train_data_loader)
log_dict = {
"Train Loss": avg_loss,
"V Loss": avg_vloss,
"FK Loss": avg_fkloss,
"Foot Loss": avg_footloss,
}
wandb.log(log_dict)
ckpt = {
"ema_state_dict": self.diffusion.master_model.state_dict(), # 经过accelerate prepare的模型,在保存时需要unwrap,反之不需要
"model_state_dict": self.accelerator.unwrap_model(
self.model
).state_dict(),
"optimizer_state_dict": self.optim.state_dict(),
"normalizer": self.normalizer,
}
torch.save(ckpt, os.path.join(wdir, f"train-{epoch+self.resume_num}.pt"))
print(f"[MODEL SAVED at Epoch {epoch+self.resume_num}]")
# generate a sample
render_count = 2
shape = (render_count, self.horizon, self.opt.nfeats)
print("Generating Sample")
# draw a music from the test dataset
(x, cond, filename) = next(iter(test_data_loader))
# if opt.do_normalize:
# x = self.normalizer.normalize(x)
                    if opt.nfeats == 139 or opt.nfeats == 135:
x = x[:, :, :139]
cond = cond.to(self.accelerator.device)
# name_iter = name_iter+1
self.diffusion.render_sample(
shape,
cond[:render_count],
self.normalizer,
epoch+self.resume_num,
                        render_out=os.path.join(opt.render_dir, "train_" + opt.exp_name),
                        fk_out=os.path.join(opt.render_dir, "train_" + opt.exp_name),
name=filename[:render_count],
# name = str(epoch) + str(name_iter).zfill(3)
sound=True,
)
#-----------------------------------------------------------------------------------------------------------
if self.accelerator.is_main_process:
wandb.run.finish()
def render_sample(
self, data_tuple, label, render_dir, render_count=-1, mode='normal', fk_out=None, render=True,
):
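        """Render dance samples conditioned on the music features in data_tuple,
        optionally saving forward-kinematics outputs to fk_out."""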
_, cond, wavname = data_tuple
assert len(cond.shape) == 3
if render_count < 0:
render_count = len(cond)
shape = (render_count, self.horizon, self.repr_dim)
cond = cond.to(self.accelerator.device).float()
self.diffusion.render_sample(
shape,
cond[:render_count],
self.normalizer,
label,
render_dir,
name=wavname[:render_count],
sound=True,
mode=mode,
fk_out=fk_out,
render=render
)
def train(opt):
model = EDGE(opt, opt.feature_type)
model.train_loop(opt)
if __name__ == "__main__":
opt = FineDance_parse_train_opt()
command = ' '.join(sys.argv)
if not os.path.exists(os.path.join(opt.project, opt.exp_name)):
os.makedirs(os.path.join(opt.project, opt.exp_name), exist_ok=False)
with open(os.path.join(opt.project, opt.exp_name, 'command.txt'), 'w') as f:
f.write(command)
yaml_path = os.path.join(opt.project, opt.exp_name, 'parameters.yaml')
save_arguments_to_yaml(opt, yaml_path)
train(opt)