import torch
from torch import nn
from wm.model.interface import DIT_CLASS_MAP, VAE_CLASS_MAP
class DiffusionForcing_WM(nn.Module):
    """Diffusion Forcing world model: a temporally-causal DiT denoiser over VAE latents.

    Wires together three components resolved from project registries:
      - a DiT backbone (``DIT_CLASS_MAP[model_name]``) forced into causal mode,
      - a VAE (``VAE_CLASS_MAP[model_config['vae_name']]``) for pixel <-> latent,
      - a flow-matching scheduler used for both training and sampling.
    """

    def __init__(self, model_name, model_config):
        """Build the DiT, VAE and scheduler from ``model_config``.

        Args:
            model_name: key into ``DIT_CLASS_MAP`` selecting the DiT class.
            model_config: dict containing the DiT kwargs plus ``vae_name``,
                optional ``vae_config`` (positional args for the VAE
                constructor), ``scheduler`` (either the string ``"FlowMatch"``
                or an already-constructed scheduler object) and
                ``training_timesteps``.

        Raises:
            ValueError: if the scheduler spec is missing or unrecognized.
        """
        super().__init__()
        self.model_name = model_name
        self.model_config = model_config

        # Filter DiT-specific keys so VAE/scheduler entries in the shared
        # config dict are not passed to the DiT constructor.
        dit_keys = [
            'in_channels', 'patch_size', 'dim', 'num_layers', 'num_heads',
            'action_dim', 'action_compress_rate', 'max_frames',
            'rope_config', 'action_dropout_prob', 'temporal_causal',
        ]
        dit_config = {k: v for k, v in model_config.items() if k in dit_keys}
        # Diffusion Forcing requires a causal temporal attention mask.
        dit_config['temporal_causal'] = True
        self.model = DIT_CLASS_MAP[model_name](**dit_config)

        self.vae = VAE_CLASS_MAP[model_config['vae_name']](*model_config.get('vae_config', []))

        # Scheduler: accept either a string spec or a pre-built object.
        scheduler_config = model_config.get('scheduler')
        if isinstance(scheduler_config, str):
            if scheduler_config == "FlowMatch":
                from wm.model.diffusion.flow_matching import FlowMatchScheduler
                self.scheduler = FlowMatchScheduler()
            else:
                raise ValueError(f"Unknown scheduler type: {scheduler_config}")
        elif scheduler_config is not None:
            self.scheduler = scheduler_config
        else:
            # Previously a missing key slipped through as None and surfaced
            # later as an opaque AttributeError on set_timesteps; fail fast.
            raise ValueError("model_config['scheduler'] must be provided")

        # Put the scheduler into training mode with the configured step count.
        self.scheduler.set_timesteps(model_config['training_timesteps'], training=True)
def encode_obs(self, o):
# o can be [B, T, 3, H, W] or [B, T, H, W, 3], values in [0, 1]
# return: B, T_latent, H', W', D
with torch.no_grad():
# 1. Normalize [0, 1] -> [-1, 1]
o = o * 2.0 - 1.0
# 2. Ensure shape is [B, T, 3, H, W] for WanVAEWrapper
if o.shape[-1] == 3:
o = o.permute(0, 1, 4, 2, 3).contiguous()
elif o.shape[2] == 3:
# Already [B, T, 3, H, W]
pass
latent = self.vae.encode(o) # [B, T_latent, 16, H/8, W/8]
# To [B, T_latent, H/8, W/8, 16] for DiT
latent = latent.permute(0, 1, 3, 4, 2).contiguous()
return latent
def training_loss(self, z, a):
# z: B, T', H', W', D
# a: B, T_pixel, C_a
# return: loss
B, T = z.shape[0], z.shape[1]
# Sample independent timesteps for ALL frames (including first frame)
t_indices = torch.randint(0, self.scheduler.timesteps.shape[0], (B, T), device=z.device)
t_values = self.scheduler.timesteps[t_indices] # [B, T]
# Add independent noise using the helper we added to FlowMatchScheduler
z_t, eps = self.scheduler.add_independent_noise(z, t_values)
v_pred = self.model(z_t, t_values, a)
v_target = self.scheduler.training_target(z, eps, t_values)
# Apply training weights
weights = self.scheduler.training_weight(t_values)
loss = (weights.view(B, T, 1, 1, 1) * (v_pred - v_target)**2).mean()
return loss
def full_train_loss(self, o_t, a):
# o_t: B, T_pixel, H, W, 3
# a: B, T_pixel, C_a
# return: loss
# zero out the last action since it's not used for training
a = a.clone()
a[:, -1, :] = 0
# encode the obs
z = self.encode_obs(o_t) # B, T', H', W', D
# add noise and get training loss
loss = self.training_loss(z, a)
return loss
    def generate(self, o_0, a, num_inference_steps=50, noise_level=0.0, mode="autoregressive"):
        """Sample a video rollout conditioned on a first frame and actions.

        o_0: B, H, W, 3 (pixels in [0, 1])
        a: B, T_pixel, A
        return: B, T_pixel, H, W, 3 (pixels clamped to [0, 1])
        """
        # Diffusion Forcing (Causal) can be run in two modes:
        # 1. "parallel": Denoise the whole sequence at once (fastest, uses causal mask)
        # 2. "autoregressive": Denoise one frame at a time (more stable for long horizons)
        B = o_0.shape[0]
        T_pixel = a.shape[1]
        device = o_0.device
        # 1. Encode first frame
        z_0 = self.encode_obs(o_0.unsqueeze(1)) # [B, 1, H', W', 16]
        # 2. Determine latent shape
        # NOTE(review): assumes 4x temporal compression with the first frame
        # kept uncompressed -- confirm against the VAE wrapper.
        T_latent = (T_pixel - 1) // 4 + 1
        H_prime, W_prime = z_0.shape[2], z_0.shape[3]
        D = z_0.shape[4] # 16
        # Save old scheduler state so training mode can be restored afterwards
        old_training = self.scheduler.training
        self.scheduler.set_timesteps(num_inference_steps=num_inference_steps, training=False)
        from tqdm import tqdm
        if mode == "parallel":
            # 3. Initialize latent sequence with noise
            z = torch.randn(B, T_latent, H_prime, W_prime, D, device=device)
            # 4. Handle first frame noise level
            if noise_level > 0:
                # Partially noise the conditioning frame up to noise_level
                t_val_0 = torch.full((B, 1), noise_level, device=device)
                z_0_noisy, _ = self.scheduler.add_independent_noise(z_0, t_val_0)
                z[:, 0] = z_0_noisy.squeeze(1)
            else:
                # Keep the conditioning frame clean
                z[:, 0] = z_0.squeeze(1)
            # 5. Denoising loop over the whole sequence at once
            for i in tqdm(range(len(self.scheduler.timesteps)), desc="Denoising (Parallel)"):
                t_val = self.scheduler.timesteps[i]
                t = torch.full((B, T_latent), t_val, device=device)
                if noise_level > 0:
                    # First frame's timestep is capped at its injected noise level
                    t[:, 0] = torch.where(t_val > noise_level, torch.tensor(noise_level, device=device), t_val)
                else:
                    # First frame is treated as already clean (t = 0)
                    t[:, 0] = 0
                with torch.no_grad():
                    v_pred = self.model(z, t, a)
                z = self.scheduler.step(v_pred, t, z)
            if noise_level == 0:
                # Re-pin the conditioning frame to the exact encoded latent
                z[:, 0] = z_0.squeeze(1)
        elif mode == "autoregressive":
            # 3. Start with only the first frame
            z_all = z_0.clone() # [B, 1, H', W', D]
            # 4. Roll the window frame by frame
            for t_idx in range(1, T_latent):
                # a. Add a new noisy frame to the end
                z_next = torch.randn(B, 1, H_prime, W_prime, D, device=device)
                z_curr = torch.cat([z_all, z_next], dim=1) # [B, t+1, ...]
                # b. Denoise only the LAST frame in the current sequence
                for i in range(len(self.scheduler.timesteps)):
                    t_val = self.scheduler.timesteps[i]
                    # History is clean (t=0), last frame is noisy (t=t_val)
                    t_seq = torch.zeros(B, t_idx + 1, device=device)
                    t_seq[:, -1] = t_val
                    # Correct actions for current length
                    # NOTE(review): assumes actions map onto latent frames via
                    # action_compress_rate -- verify against the DiT's contract.
                    L_curr = self.model.action_compress_rate * t_idx + 1
                    a_curr = a[:, :L_curr]
                    with torch.no_grad():
                        v_pred = self.model(z_curr, t_seq, a_curr)
                    z_curr = self.scheduler.step(v_pred, t_seq, z_curr)
                    # Fix history frames (optional but recommended for stability)
                    z_curr[:, :-1] = z_all
                # c. Record the clean result for the next iteration
                z_all = torch.cat([z_all, z_curr[:, -1:]], dim=1)
            z = z_all
        else:
            raise ValueError(f"Unknown generation mode: {mode}")
        # Restore scheduler state (only if it was in training mode before)
        if old_training:
            self.scheduler.set_timesteps(self.model_config['training_timesteps'], training=True)
        # 6. Decode back to pixels
        with torch.no_grad():
            # Channels-last latents -> channels-first for the VAE decoder
            z_for_vae = z.permute(0, 1, 4, 2, 3).contiguous()
            video_recon = self.vae.decode_to_pixel(z_for_vae)
            # [-1, 1] -> [0, 1]
            video_recon = (video_recon + 1.0) / 2.0
            video_recon = video_recon.permute(0, 1, 3, 4, 2).contiguous().clamp(0, 1)
        return video_recon
|