|
|
"""Train streaming motion generation model (MotionStreamer) with llama blocks, Two-Forward strategy and QK-Norm, using the motion latents encoded by the Causal TAE (trained in the first stage).""" |
|
|
|
|
|
import os |
|
|
import torch |
|
|
import numpy as np |
|
|
import random |
|
|
from torch.utils.tensorboard import SummaryWriter |
|
|
import json |
|
|
from accelerate import Accelerator |
|
|
from models.llama_model import LLaMAHF, LLaMAHFConfig |
|
|
import options.option_transformer as option_trans |
|
|
import utils.utils_model as utils_model |
|
|
import warnings |
|
|
from torch.optim.lr_scheduler import LambdaLR, CosineAnnealingLR |
|
|
warnings.filterwarnings('ignore') |
|
|
|
|
|
os.environ["TOKENIZERS_PARALLELISM"] = "false" |
|
|
|
|
|
|
|
|
args = option_trans.get_args_parser() |
|
|
torch.manual_seed(args.seed) |
|
|
|
|
|
def unwrap(m): |
|
|
return m.module if hasattr(m, 'module') else m |
|
|
|
|
|
|
|
|
class WarmupCosineDecayScheduler: |
|
|
def __init__(self, optimizer, warmup_iters, total_iters, min_lr=0): |
|
|
self.optimizer = optimizer |
|
|
self.warmup_iters = warmup_iters |
|
|
self.total_iters = total_iters |
|
|
self.min_lr = min_lr |
|
|
self.warmup_scheduler = LambdaLR(optimizer, lr_lambda=self.warmup_lambda) |
|
|
self.cosine_scheduler = CosineAnnealingLR(optimizer, T_max=total_iters - warmup_iters, eta_min=min_lr) |
|
|
|
|
|
def warmup_lambda(self, current_iter): |
|
|
if current_iter < self.warmup_iters: |
|
|
return float(current_iter) / float(max(1, self.warmup_iters)) |
|
|
return 1.0 |
|
|
|
|
|
def step(self, current_iter): |
|
|
if current_iter < self.warmup_iters: |
|
|
self.warmup_scheduler.step() |
|
|
else: |
|
|
self.cosine_scheduler.step() |
|
|
|
|
|
def state_dict(self): |
|
|
return {'warmup_iters': self.warmup_iters, 'total_iters': self.total_iters, 'min_lr': self.min_lr} |
|
|
|
|
|
def load_state_dict(self, state_dict): |
|
|
self.warmup_iters = state_dict['warmup_iters'] |
|
|
self.total_iters = state_dict['total_iters'] |
|
|
self.min_lr = state_dict['min_lr'] |
|
|
|
|
|
args.out_dir = os.path.join(args.out_dir, f'{args.exp_name}') |
|
|
os.makedirs(args.out_dir, exist_ok=True) |
|
|
|
|
|
|
|
|
accelerator = Accelerator() |
|
|
comp_device = accelerator.device |
|
|
|
|
|
|
|
|
logger = utils_model.get_logger(args.out_dir) |
|
|
writer = SummaryWriter(args.out_dir) |
|
|
logger.info(json.dumps(vars(args), indent=4, sort_keys=True)) |
|
|
|
|
|
|
|
|
from humanml3d_272 import dataset_TM_train_motionstreamer |
|
|
train_loader = dataset_TM_train_motionstreamer.DATALoader( |
|
|
args.dataname, args.batch_size, unit_length=2**args.down_t, latent_dir=args.latent_dir |
|
|
) |
|
|
|
|
|
|
|
|
from sentence_transformers import SentenceTransformer |
|
|
t5_model = SentenceTransformer("sentence-t5-xl", device=comp_device) |
|
|
t5_model.half() |
|
|
t5_model.eval() |
|
|
for p in t5_model.parameters(): |
|
|
p.requires_grad = False |
|
|
|
|
|
|
|
|
config = LLaMAHFConfig.from_name('Normal_size') |
|
|
|
|
|
|
|
|
|
|
|
trans_encoder = LLaMAHF( |
|
|
config=config, |
|
|
num_diffusion_head_layers=args.num_diffusion_head_layers, |
|
|
input_token_dim=args.latent_dim, |
|
|
device=comp_device, |
|
|
|
|
|
|
|
|
) |
|
|
|
|
|
if args.resume_trans is not None: |
|
|
print('loading transformer checkpoint from {}'.format(args.resume_trans)) |
|
|
ckpt = torch.load(args.resume_trans, map_location='cpu') |
|
|
new_ckpt_trans = {} |
|
|
for key in ckpt['trans'].keys(): |
|
|
new_key = '.'.join(key.split('.')[1:]) if key.split('.')[0]=='module' else key |
|
|
new_ckpt_trans[new_key] = ckpt['trans'][key] |
|
|
trans_encoder.load_state_dict(new_ckpt_trans, strict=True) |
|
|
|
|
|
trans_encoder.train() |
|
|
trans_encoder.to(comp_device) |
|
|
|
|
|
|
|
|
optimizer = utils_model.initial_optim(args.decay_option, args.lr, args.weight_decay, trans_encoder, args.optimizer) |
|
|
scheduler = WarmupCosineDecayScheduler(optimizer, args.total_iter//10, args.total_iter) |
|
|
|
|
|
t5_model, trans_encoder, optimizer, train_loader = accelerator.prepare( |
|
|
t5_model, trans_encoder, optimizer, train_loader |
|
|
) |
|
|
base = accelerator.unwrap_model(trans_encoder) |
|
|
train_loader_iter = dataset_TM_train_motionstreamer.cycle(train_loader) |
|
|
|
|
|
args.dit_window = 2 |
|
|
|
|
|
def lengths_to_mask(lengths, max_len): |
|
|
return torch.arange(max_len, device=lengths.device).expand(len(lengths), max_len) < lengths.unsqueeze(1) |
|
|
|
|
|
import math |
|
|
def cosine_decay(step, total_steps, start_value=1.0, end_value=0.0): |
|
|
step = torch.tensor(step, dtype=torch.float32) |
|
|
total_steps = torch.tensor(total_steps, dtype=torch.float32) |
|
|
cosine_factor = 0.5 * (1 + torch.cos(torch.pi * step / total_steps)) |
|
|
return start_value + (end_value - start_value) * cosine_factor |
|
|
|
|
|
def replace_with_pred(latents, pred_xstart, step, total_steps): |
|
|
decay_factor = cosine_decay(step, total_steps).to(latents.device) |
|
|
b, l, d = latents.shape |
|
|
num_replace = int(l * decay_factor) |
|
|
replace_indices = torch.randperm(l, device=latents.device)[:num_replace] |
|
|
replace_mask = torch.zeros(b, l, dtype=torch.bool, device=latents.device) |
|
|
replace_mask[:, replace_indices] = 1 |
|
|
updated_latents = latents.clone() |
|
|
updated_latents[replace_mask] = pred_xstart[replace_mask] |
|
|
return updated_latents |
|
|
|
|
|
|
|
|
def forward_loss_withmask_2_forward_streaming(latents, trans, m_lens, feat_text, |
|
|
step, total_steps, A_token_length, K=None): |
|
|
""" |
|
|
Two-Forward with a *windowed* Temporal-DiT: |
|
|
- AR sees full sequence. |
|
|
- Diffusion head sees only last K positions (causal). |
|
|
""" |
|
|
K = K or getattr(args, "dit_window", 2) |
|
|
|
|
|
latents = latents.to(comp_device) |
|
|
feat_text = feat_text.to(comp_device) |
|
|
A_token_length = A_token_length.to(comp_device) |
|
|
|
|
|
B, L, D = latents.shape |
|
|
L_eff = L - 1 |
|
|
if L_eff <= 0: |
|
|
raise ValueError("Sequence too short for next-token training.") |
|
|
|
|
|
base.set_prompt(feat_text) |
|
|
|
|
|
|
|
|
conditions = trans(latents, feature=None) |
|
|
|
|
|
z_full = conditions[:, 1:-1, :] |
|
|
target_full = latents[:, 1:, :] |
|
|
|
|
|
|
|
|
eff_lens = (m_lens - 1).clamp(min=0) |
|
|
full_mask = torch.arange(L_eff, device=latents.device).unsqueeze(0).expand(B, L_eff) < eff_lens.unsqueeze(1) |
|
|
|
|
|
for b in range(B): |
|
|
a_excl = max(0, A_token_length[b].item() - 1) |
|
|
if a_excl > 0: |
|
|
full_mask[b, :a_excl] = False |
|
|
|
|
|
|
|
|
W = min(K, L_eff) |
|
|
tail_start = L_eff - W |
|
|
z = z_full[:, tail_start:, :] |
|
|
target = target_full[:, tail_start:, :] |
|
|
mask = full_mask[:, tail_start:] |
|
|
mask_flat = mask.reshape(B * W).float() |
|
|
|
|
|
|
|
|
base.diff_loss.set_sequence_layout(B, W) |
|
|
|
|
|
|
|
|
with torch.no_grad(): |
|
|
|
|
|
loss0, pred_xstart_full = base.diff_loss( |
|
|
target=target.reshape(B * W, D), |
|
|
z=z.reshape(B * W, -1), |
|
|
mask=None |
|
|
) |
|
|
pred_xstart = pred_xstart_full.view(B, W, D) |
|
|
|
|
|
|
|
|
for b in range(B): |
|
|
a_excl = max(0, A_token_length[b].item() - 1) |
|
|
|
|
|
|
|
|
cut = max(0, min(W, a_excl - tail_start)) |
|
|
if cut > 0: |
|
|
pred_xstart[b, :cut, :] = target[b, :cut, :] |
|
|
|
|
|
|
|
|
decay_ratio = 0.5 * (1.0 + torch.cos( |
|
|
torch.pi * torch.tensor(step, dtype=torch.float32, device=latents.device) |
|
|
/ torch.tensor(total_steps, dtype=torch.float32, device=latents.device) |
|
|
)).item() |
|
|
k = int(W * decay_ratio) |
|
|
|
|
|
updated_latents = latents.clone() |
|
|
if k > 0: |
|
|
replace_idx = torch.randperm(W, device=latents.device)[:k] |
|
|
|
|
|
raw_positions = 1 + tail_start + replace_idx |
|
|
|
|
|
updated_latents[:, raw_positions, :] = pred_xstart[:, replace_idx, :] |
|
|
|
|
|
|
|
|
updated_conditions = trans(updated_latents, feature=None) |
|
|
updated_z_full = updated_conditions[:, 1:-1, :] |
|
|
updated_z = updated_z_full[:, tail_start:, :] |
|
|
|
|
|
updated_loss, _ = base.diff_loss( |
|
|
target=target.reshape(B * W, D), |
|
|
z=updated_z.reshape(B * W, -1), |
|
|
mask=mask_flat |
|
|
) |
|
|
return updated_loss |
|
|
|
|
|
|
|
|
nb_iter, avg_loss_cls = 0, 0.0 |
|
|
|
|
|
while nb_iter <= args.total_iter: |
|
|
batch = next(train_loader_iter) |
|
|
caption, m_tokens, m_tokens_len, A_token_length = batch |
|
|
caption = list(caption) |
|
|
m_tokens, m_tokens_len = m_tokens.to(comp_device), m_tokens_len.to(comp_device) |
|
|
A_token_length = A_token_length.to(comp_device) |
|
|
|
|
|
|
|
|
bs = len(caption) |
|
|
num_masked = int(bs * 0.1) |
|
|
if num_masked > 0: |
|
|
for idx in random.sample(range(bs), num_masked): |
|
|
caption[idx] = '' |
|
|
|
|
|
|
|
|
feat_text = torch.from_numpy(t5_model.encode(caption)).float().to(comp_device) |
|
|
|
|
|
|
|
|
input_latent = m_tokens[:, :-1, :] |
|
|
|
|
|
loss_cls = forward_loss_withmask_2_forward_streaming( |
|
|
latents=input_latent, |
|
|
trans=trans_encoder, |
|
|
m_lens=m_tokens_len, |
|
|
feat_text=feat_text, |
|
|
step=nb_iter, |
|
|
total_steps=args.total_iter, |
|
|
A_token_length=A_token_length, |
|
|
K=args.dit_window, |
|
|
) |
|
|
|
|
|
|
|
|
optimizer.zero_grad() |
|
|
accelerator.backward(loss_cls) |
|
|
optimizer.step() |
|
|
scheduler.step(nb_iter) |
|
|
|
|
|
avg_loss_cls += loss_cls.item() |
|
|
nb_iter += 1 |
|
|
|
|
|
|
|
|
args.print_iter = 100 |
|
|
if nb_iter % args.print_iter == 0: |
|
|
if accelerator.is_main_process: |
|
|
avg_loss_cls = avg_loss_cls / args.print_iter |
|
|
writer.add_scalar('./Loss/train', avg_loss_cls, nb_iter) |
|
|
writer.add_scalar('./LR/train', optimizer.param_groups[0]['lr'], nb_iter) |
|
|
logger.info(f"Train. Iter {nb_iter} : Loss. {avg_loss_cls:.5f}") |
|
|
avg_loss_cls = 0.0 |
|
|
|
|
|
|
|
|
args.save_iter = 10000 |
|
|
if nb_iter % args.save_iter == 0: |
|
|
if accelerator.is_main_process: |
|
|
torch.save({'trans': unwrap(trans_encoder).state_dict()}, |
|
|
os.path.join(args.out_dir, f'latest.pth')) |
|
|
|
|
|
accelerator.wait_for_everyone() |
|
|
|