zirobtc committed on
Commit
a811761
·
verified ·
1 Parent(s): 7b6b374

Upload train_motionstreamer.py

Browse files
Files changed (1) hide show
  1. train_motionstreamer.py +293 -0
train_motionstreamer.py ADDED
@@ -0,0 +1,293 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
"""Train streaming motion generation model (MotionStreamer) with llama blocks, Two-Forward strategy and QK-Norm, using the motion latents encoded by the Causal TAE (trained in the first stage)."""

import os
import torch
import numpy as np
import random
from torch.utils.tensorboard import SummaryWriter
import json
from accelerate import Accelerator
from models.llama_model import LLaMAHF, LLaMAHFConfig
import options.option_transformer as option_trans
import utils.utils_model as utils_model
import warnings
from torch.optim.lr_scheduler import LambdaLR, CosineAnnealingLR
warnings.filterwarnings('ignore')

# Silence/avoid HF tokenizer fork warnings when DataLoader workers are used.
os.environ["TOKENIZERS_PARALLELISM"] = "false"

##### ---- Args / Exp dirs ---- #####
args = option_trans.get_args_parser()
# NOTE(review): only torch's RNG is seeded; `random` (used below for caption
# dropout) and numpy are left unseeded — confirm this is intentional.
torch.manual_seed(args.seed)
22
+
23
def unwrap(m):
    """Return the wrapped model when `m` is a DDP-style wrapper (has `.module`), else `m` itself."""
    return getattr(m, 'module', m)
25
+
26
# ---- warm-up + cosine decay scheduler ----
class WarmupCosineDecayScheduler:
    """Linear warm-up for `warmup_iters` steps, then cosine annealing to `min_lr`.

    Wraps a LambdaLR (warm-up phase) and a CosineAnnealingLR (decay phase) on the
    same optimizer; `step(current_iter)` advances whichever phase is active.
    """

    def __init__(self, optimizer, warmup_iters, total_iters, min_lr=0):
        self.optimizer = optimizer
        self.warmup_iters = warmup_iters
        self.total_iters = total_iters
        self.min_lr = min_lr
        self.warmup_scheduler = LambdaLR(optimizer, lr_lambda=self.warmup_lambda)
        self.cosine_scheduler = CosineAnnealingLR(optimizer, T_max=total_iters - warmup_iters, eta_min=min_lr)

    def warmup_lambda(self, current_iter):
        """LR multiplier: ramps linearly 0 -> 1 over the warm-up phase, then stays 1."""
        if current_iter < self.warmup_iters:
            return float(current_iter) / float(max(1, self.warmup_iters))
        return 1.0

    def step(self, current_iter):
        """Advance the scheduler for iteration `current_iter` (warm-up first, then cosine)."""
        if current_iter < self.warmup_iters:
            self.warmup_scheduler.step()
        else:
            self.cosine_scheduler.step()

    def state_dict(self):
        # Fix: also persist the inner schedulers' states; the original only saved
        # the constructor arguments, so a resumed run restarted the LR schedule
        # from scratch instead of from the checkpointed position.
        return {
            'warmup_iters': self.warmup_iters,
            'total_iters': self.total_iters,
            'min_lr': self.min_lr,
            'warmup_scheduler': self.warmup_scheduler.state_dict(),
            'cosine_scheduler': self.cosine_scheduler.state_dict(),
        }

    def load_state_dict(self, state_dict):
        self.warmup_iters = state_dict['warmup_iters']
        self.total_iters = state_dict['total_iters']
        self.min_lr = state_dict['min_lr']
        # Backward-compatible: old checkpoints lack the inner scheduler states.
        if 'warmup_scheduler' in state_dict:
            self.warmup_scheduler.load_state_dict(state_dict['warmup_scheduler'])
        if 'cosine_scheduler' in state_dict:
            self.cosine_scheduler.load_state_dict(state_dict['cosine_scheduler'])
54
+
55
# Experiment directory: <out_dir>/<exp_name>
args.out_dir = os.path.join(args.out_dir, f'{args.exp_name}')
os.makedirs(args.out_dir, exist_ok=True)

##### ---- Accelerator ---- #####
accelerator = Accelerator()
comp_device = accelerator.device

##### ---- Logger ---- #####
logger = utils_model.get_logger(args.out_dir)
writer = SummaryWriter(args.out_dir)
logger.info(json.dumps(vars(args), indent=4, sort_keys=True))

##### ---- Dataloader ---- #####
# Mid-file import: dataset module yields (caption, m_tokens, m_tokens_len, A_token_length)
# batches — see the training loop below for how each field is consumed.
from humanml3d_272 import dataset_TM_train_motionstreamer
train_loader = dataset_TM_train_motionstreamer.DATALoader(
    args.dataname, args.batch_size, unit_length=2**args.down_t, latent_dir=args.latent_dir
)

##### ---- Text encoder (frozen) ---- #####
from sentence_transformers import SentenceTransformer
t5_model = SentenceTransformer("sentence-t5-xl", device=comp_device)
t5_model.half()  # if GPU supports fp16/bf16
t5_model.eval()
# Freeze: text encoder provides conditioning only, never trained here.
for p in t5_model.parameters():
    p.requires_grad = False

##### ---- Network ---- #####
config = LLaMAHFConfig.from_name('Normal_size')
# Optional: set a tighter block size if you know max tokens per seq; otherwise leave default.
# config.block_size = 78

trans_encoder = LLaMAHF(
    config=config,
    num_diffusion_head_layers=args.num_diffusion_head_layers,
    input_token_dim=args.latent_dim,
    device=comp_device,
    # width defaults to 1792; override via args if you want:
    # width=args.diff_width
)

if args.resume_trans is not None:
    print('loading transformer checkpoint from {}'.format(args.resume_trans))
    ckpt = torch.load(args.resume_trans, map_location='cpu')
    # Strip a leading 'module.' prefix (checkpoints saved from a DDP-wrapped model).
    new_ckpt_trans = {}
    for key in ckpt['trans'].keys():
        new_key = '.'.join(key.split('.')[1:]) if key.split('.')[0]=='module' else key
        new_ckpt_trans[new_key] = ckpt['trans'][key]
    trans_encoder.load_state_dict(new_ckpt_trans, strict=True)

trans_encoder.train()
trans_encoder.to(comp_device)

##### ---- Optimizer & Scheduler ---- #####
optimizer = utils_model.initial_optim(args.decay_option, args.lr, args.weight_decay, trans_encoder, args.optimizer)
# Warm-up for the first 10% of total iterations, cosine decay after.
scheduler = WarmupCosineDecayScheduler(optimizer, args.total_iter//10, args.total_iter)

# NOTE(review): the frozen t5_model is also passed through prepare() — presumably
# just for device placement; confirm it is not wrapped for gradient sync.
t5_model, trans_encoder, optimizer, train_loader = accelerator.prepare(
    t5_model, trans_encoder, optimizer, train_loader
)
# `base`: unwrapped model, used for prompt caching / diffusion-head access below.
base = accelerator.unwrap_model(trans_encoder)
train_loader_iter = dataset_TM_train_motionstreamer.cycle(train_loader)

# Diffusion-head tail window size (positions fed to the DiT head per step).
args.dit_window = 2
119
def lengths_to_mask(lengths, max_len):
    """Boolean mask of shape [len(lengths), max_len]: True where position < length."""
    positions = torch.arange(max_len, device=lengths.device)
    return positions.unsqueeze(0) < lengths.unsqueeze(1)
121
+
122
import math  # NOTE(review): unused here (torch.pi is used below); kept since later code may rely on it.

def cosine_decay(step, total_steps, start_value=1.0, end_value=0.0):
    """Cosine schedule from `start_value` (at step 0) down to `end_value` (at step == total_steps).

    Fix: the original returned `start + (end - start) * cosine_factor`, which ran
    end -> start (i.e. it *rose* over training), contradicting both the function
    name and the inline schedule in the Two-Forward loss
    (`decay_ratio = 0.5 * (1 + cos(pi * step / total))`, which decays 1 -> 0).

    Returns a scalar torch tensor.
    """
    step = torch.tensor(step, dtype=torch.float32)
    total_steps = torch.tensor(total_steps, dtype=torch.float32)
    # cosine_factor: 1 at step 0, 0 at step == total_steps.
    cosine_factor = 0.5 * (1 + torch.cos(torch.pi * step / total_steps))
    return end_value + (start_value - end_value) * cosine_factor
128
+
129
def replace_with_pred(latents, pred_xstart, step, total_steps):
    """Swap a scheduled fraction of time-steps in `latents` with teacher predictions.

    The fraction follows `cosine_decay(step, total_steps)`, and the same randomly
    chosen positions are replaced for every sequence in the batch.
    NOTE: appears unused by the current training loop, which inlines its own schedule.
    """
    fraction = cosine_decay(step, total_steps).to(latents.device)
    batch, seq_len, _ = latents.shape
    n_swap = int(seq_len * fraction)
    chosen = torch.randperm(seq_len, device=latents.device)[:n_swap]
    swap_mask = torch.zeros(batch, seq_len, dtype=torch.bool, device=latents.device)
    swap_mask[:, chosen] = True
    out = latents.clone()
    out[swap_mask] = pred_xstart[swap_mask]
    return out
139
+
140
# ---- Two-Forward with cached prompt + temporal DiT head ----
def forward_loss_withmask_2_forward_streaming(latents, trans, m_lens, feat_text,
                                              step, total_steps, A_token_length, K=None):
    """
    Two-Forward with a *windowed* Temporal-DiT:
    - AR sees full sequence.
    - Diffusion head sees only last K positions (causal).

    Parameters (shapes per the inline comments; confirm against the dataloader):
        latents         [B, L, D] ground-truth motion latent stream.
        trans           transformer (possibly accelerate/DDP-wrapped) — used for forwards.
        m_lens          [B] valid lengths of each sequence.
        feat_text       [B, Dtxt] frozen sentence-T5 embeddings, cached as the prompt.
        step, total_steps  current / total iteration; drives the teacher-mix schedule.
        A_token_length  [B] length of the leading "A-motion" prefix excluded from the loss.
        K               diffusion tail-window size; defaults to args.dit_window.

    Returns:
        The diffusion loss of the second (refined) forward pass.

    Relies on module-level globals: `args`, `base` (unwrapped model), `comp_device`.
    """
    # NOTE(review): `K or ...` also falls back to the default when K == 0 is
    # passed explicitly — confirm that is acceptable.
    K = K or getattr(args, "dit_window", 2)  # default to 2 if not provided

    latents = latents.to(comp_device)      # [B, L, D]
    feat_text = feat_text.to(comp_device)  # [B, Dtxt]
    A_token_length = A_token_length.to(comp_device)

    B, L, D = latents.shape
    L_eff = L - 1  # number of next-token prediction positions after the shift
    if L_eff <= 0:
        raise ValueError("Sequence too short for next-token training.")

    base.set_prompt(feat_text)  # cache text once

    # --- AR forward (full) ---
    conditions = trans(latents, feature=None)  # [B, L, C] (BOS already added inside)
    # shift for next-token training (BOS-aware):
    # NOTE(review): [:, 1:-1, :] yields L-1 positions only if `trans` returns
    # L+1 steps (BOS prepended); otherwise the shape comment above is off by
    # one — confirm against the model's forward().
    z_full = conditions[:, 1:-1, :]   # [B, L-1, C]
    target_full = latents[:, 1:, :]   # [B, L-1, D]

    # --- build full mask on shifted axis, then tail-slice to K ---
    eff_lens = (m_lens - 1).clamp(min=0)  # lengths in shifted space
    full_mask = torch.arange(L_eff, device=latents.device).unsqueeze(0).expand(B, L_eff) < eff_lens.unsqueeze(1)
    # exclude A-motion in shifted space: [0 .. A_len-2]
    for b in range(B):
        a_excl = max(0, A_token_length[b].item() - 1)
        if a_excl > 0:
            full_mask[b, :a_excl] = False

    # --- restrict to last K positions for diffusion ---
    W = min(K, L_eff)
    tail_start = L_eff - W
    z = z_full[:, tail_start:, :]            # [B, W, C]
    target = target_full[:, tail_start:, :]  # [B, W, D]
    mask = full_mask[:, tail_start:]         # [B, W]
    mask_flat = mask.reshape(B * W).float()

    # Tell DiT we are a (B, W) sequence
    base.diff_loss.set_sequence_layout(B, W)

    # ================= First pass (teacher) =================
    with torch.no_grad():
        # flatten for diffusion loss API
        # loss0 is discarded; only the teacher's x0 prediction is used below.
        loss0, pred_xstart_full = base.diff_loss(
            target=target.reshape(B * W, D),
            z=z.reshape(B * W, -1),
            mask=None  # teacher doesn't need a mask
        )
    pred_xstart = pred_xstart_full.view(B, W, D)

    # keep GT for A-motion region if the tail overlaps A
    for b in range(B):
        a_excl = max(0, A_token_length[b].item() - 1)
        # in shifted axis, A spans [:a_excl]; in tail window that corresponds to indices < a_excl - tail_start
        # so clamp to [0, W)
        cut = max(0, min(W, a_excl - tail_start))
        if cut > 0:
            pred_xstart[b, :cut, :] = target[b, :cut, :]

    # cosine-decayed teacher mixing, but only inside the tail window
    # (ratio decays 1 -> 0 over training: fewer teacher replacements later on)
    decay_ratio = 0.5 * (1.0 + torch.cos(
        torch.pi * torch.tensor(step, dtype=torch.float32, device=latents.device)
        / torch.tensor(total_steps, dtype=torch.float32, device=latents.device)
    )).item()
    k = int(W * decay_ratio)

    updated_latents = latents.clone()
    if k > 0:
        replace_idx = torch.randperm(W, device=latents.device)[:k]  # local indices in [0..W-1]
        # map tail-window indices (shifted space) back to raw latents positions (+1 for next-token position)
        raw_positions = 1 + tail_start + replace_idx
        # write teacher predictions into raw stream at those positions
        # (same positions replaced for every sequence in the batch)
        updated_latents[:, raw_positions, :] = pred_xstart[:, replace_idx, :]

    # ================= Second pass (refined) =================
    updated_conditions = trans(updated_latents, feature=None)  # [B, L, C]
    updated_z_full = updated_conditions[:, 1:-1, :]  # [B, L-1, C]
    updated_z = updated_z_full[:, tail_start:, :]    # [B, W, C]

    updated_loss, _ = base.diff_loss(
        target=target.reshape(B * W, D),
        z=updated_z.reshape(B * W, -1),
        mask=mask_flat
    )
    return updated_loss
232
+
233
##### ---- Training Loop ---- #####
# Logging/checkpoint cadence: hoisted out of the loop (the original re-assigned
# these constants on every single iteration).
args.print_iter = 100
args.save_iter = 10000

nb_iter, avg_loss_cls = 0, 0.0

# NOTE(review): `<=` runs total_iter + 1 iterations; kept to preserve the
# original schedule.
while nb_iter <= args.total_iter:
    caption, m_tokens, m_tokens_len, A_token_length = next(train_loader_iter)
    caption = list(caption)
    m_tokens, m_tokens_len = m_tokens.to(comp_device), m_tokens_len.to(comp_device)
    A_token_length = A_token_length.to(comp_device)

    # 10% empty captions for CFG-style robustness
    bs = len(caption)
    num_masked = int(bs * 0.1)
    if num_masked > 0:
        for idx in random.sample(range(bs), num_masked):
            caption[idx] = ''

    # Text features (sentence-t5-xl embeddings; encode() returns a numpy array)
    feat_text = torch.from_numpy(t5_model.encode(caption)).float().to(comp_device)

    # Ground truth latents (AR next-token: we predict t+1 from up to t);
    # drop the last step so a target exists for every input position.
    input_latent = m_tokens[:, :-1, :]  # [B, L, D]

    loss_cls = forward_loss_withmask_2_forward_streaming(
        latents=input_latent,
        trans=trans_encoder,
        m_lens=m_tokens_len,
        feat_text=feat_text,
        step=nb_iter,
        total_steps=args.total_iter,
        A_token_length=A_token_length,
        K=args.dit_window,
    )

    # backward & step
    optimizer.zero_grad()
    accelerator.backward(loss_cls)
    optimizer.step()
    scheduler.step(nb_iter)

    avg_loss_cls += loss_cls.item()
    nb_iter += 1

    # Logs: only the main process writes, but every rank resets its running
    # average (the original reset only on main, so other ranks accumulated
    # the sum forever).
    if nb_iter % args.print_iter == 0:
        if accelerator.is_main_process:
            mean_loss = avg_loss_cls / args.print_iter
            writer.add_scalar('./Loss/train', mean_loss, nb_iter)
            writer.add_scalar('./LR/train', optimizer.param_groups[0]['lr'], nb_iter)
            logger.info(f"Train. Iter {nb_iter} : Loss. {mean_loss:.5f}")
        avg_loss_cls = 0.0

    # Checkpoint (overwrites the same 'latest.pth' each time)
    if nb_iter % args.save_iter == 0:
        if accelerator.is_main_process:
            torch.save({'trans': unwrap(trans_encoder).state_dict()},
                       os.path.join(args.out_dir, 'latest.pth'))

accelerator.wait_for_everyone()