import torch
from torch import nn

from wm.model.interface import DIT_CLASS_MAP, VAE_CLASS_MAP


class DiffusionForcing_WM(nn.Module):
    """Diffusion Forcing world model: a temporally-causal DiT denoiser over VAE latents.

    Each latent frame carries its own (independent) diffusion timestep, so the
    model is trained with per-frame noise levels and can be rolled out either
    in parallel (whole sequence at once, relying on the causal mask) or
    autoregressively (one frame at a time).
    """

    # Keys forwarded to the DiT constructor; everything else in model_config
    # (VAE / scheduler settings) is filtered out so it never reaches the DiT.
    _DIT_KEYS = (
        'in_channels', 'patch_size', 'dim', 'num_layers', 'num_heads',
        'action_dim', 'action_compress_rate', 'max_frames', 'rope_config',
        'action_dropout_prob', 'temporal_causal',
    )

    def __init__(self, model_name, model_config):
        """Build the DiT backbone, the VAE, and the diffusion scheduler.

        Args:
            model_name: key into ``DIT_CLASS_MAP`` selecting the DiT class.
            model_config: dict with DiT kwargs plus ``vae_name``, optional
                ``vae_config`` (positional args list), ``scheduler`` (string
                name or ready instance) and ``training_timesteps``.
        """
        super().__init__()
        self.model_name = model_name
        self.model_config = model_config

        # Keep only DiT-specific keys to avoid passing VAE/Scheduler config
        # to the DiT constructor.
        dit_config = {k: v for k, v in model_config.items() if k in self._DIT_KEYS}
        # Diffusion Forcing requires a causal temporal attention mask.
        dit_config['temporal_causal'] = True
        self.model = DIT_CLASS_MAP[model_name](**dit_config)

        # NOTE(review): ``vae_config`` is splatted as positional args —
        # confirm configs store it as a list, not a dict of kwargs.
        self.vae = VAE_CLASS_MAP[model_config['vae_name']](*model_config.get('vae_config', []))

        # Scheduler may be given as a string name or as a ready instance.
        scheduler_config = model_config.get('scheduler')
        if isinstance(scheduler_config, str):
            if scheduler_config == "FlowMatch":
                from wm.model.diffusion.flow_matching import FlowMatchScheduler
                self.scheduler = FlowMatchScheduler()
            else:
                raise ValueError(f"Unknown scheduler type: {scheduler_config}")
        else:
            self.scheduler = scheduler_config

        # Initialize the scheduler in training mode.
        self.scheduler.set_timesteps(model_config['training_timesteps'], training=True)

    def encode_obs(self, o):
        """Encode pixel observations into DiT-ready latents.

        Args:
            o: video tensor with values in [0, 1], either [B, T, 3, H, W]
               or [B, T, H, W, 3] (channels-last is detected and permuted).

        Returns:
            Latents of shape [B, T_latent, H/8, W/8, D], channels-last.
        """
        with torch.no_grad():
            # Normalize [0, 1] -> [-1, 1] for the VAE.
            o = o * 2.0 - 1.0
            # Ensure channels-first [B, T, 3, H, W] layout for the VAE.
            # NOTE(review): detecting layout by trailing dim == 3 is
            # ambiguous when W == 3 — confirm inputs never hit that case.
            if o.shape[-1] == 3:
                o = o.permute(0, 1, 4, 2, 3).contiguous()
            latent = self.vae.encode(o)  # [B, T_latent, D, H/8, W/8]
            # Move channels last for the DiT.
            return latent.permute(0, 1, 3, 4, 2).contiguous()

    def training_loss(self, z, a):
        """Diffusion Forcing loss: an independent timestep per latent frame.

        Args:
            z: clean latents [B, T', H', W', D].
            a: actions [B, T_pixel, C_a].

        Returns:
            Scalar weighted MSE between predicted and target velocity.
        """
        B, T = z.shape[0], z.shape[1]
        # Sample independent timesteps for ALL frames (including the first).
        t_indices = torch.randint(0, self.scheduler.timesteps.shape[0], (B, T), device=z.device)
        t_values = self.scheduler.timesteps[t_indices]  # [B, T]
        # Noise each frame at its own timestep.
        z_t, eps = self.scheduler.add_independent_noise(z, t_values)
        v_pred = self.model(z_t, t_values, a)
        v_target = self.scheduler.training_target(z, eps, t_values)
        # Per-timestep loss weighting, broadcast over spatial/channel dims.
        weights = self.scheduler.training_weight(t_values)
        return (weights.view(B, T, 1, 1, 1) * (v_pred - v_target) ** 2).mean()

    def full_train_loss(self, o_t, a):
        """End-to-end training loss from pixels.

        Args:
            o_t: observations [B, T_pixel, H, W, 3] in [0, 1].
            a: actions [B, T_pixel, C_a].

        Returns:
            Scalar training loss.
        """
        # Zero out the last action since it is not used for training.
        a = a.clone()
        a[:, -1, :] = 0
        # Encode observations to latents: [B, T', H', W', D].
        z = self.encode_obs(o_t)
        return self.training_loss(z, a)

    def generate(self, o_0, a, num_inference_steps=50, noise_level=0.0, mode="autoregressive"):
        """Roll out a video conditioned on a first frame and an action sequence.

        Args:
            o_0: first frame [B, H, W, 3], values in [0, 1].
            a: actions [B, T_pixel, A].
            num_inference_steps: denoising steps per frame/sequence.
            noise_level: if > 0, the conditioning frame is held at this
                (small) noise level instead of being kept perfectly clean.
            mode: "parallel" denoises the whole sequence at once (fastest,
                uses the causal mask); "autoregressive" denoises one frame at
                a time (more stable for long horizons).

        Returns:
            Generated video [B, T_pixel, H, W, 3] in [0, 1].

        Raises:
            ValueError: for an unknown ``mode``.
        """
        # Validate up front so an invalid mode cannot leave the scheduler
        # switched into inference state.
        if mode not in ("parallel", "autoregressive"):
            raise ValueError(f"Unknown generation mode: {mode}")

        B = o_0.shape[0]
        T_pixel = a.shape[1]

        # Encode the conditioning frame: [B, 1, H', W', D].
        z_0 = self.encode_obs(o_0.unsqueeze(1))

        # Latent sequence length under the VAE's 4x temporal compression.
        T_latent = (T_pixel - 1) // 4 + 1

        # Switch the scheduler to inference timesteps, remembering whether it
        # was in training mode so we can restore it afterwards.
        old_training = self.scheduler.training
        self.scheduler.set_timesteps(num_inference_steps=num_inference_steps, training=False)

        if mode == "parallel":
            z = self._generate_parallel(z_0, a, B, T_latent, noise_level)
        else:
            z = self._generate_autoregressive(z_0, a, B, T_latent)

        # Restore the training schedule if we interrupted it.
        # NOTE(review): if the scheduler was already in inference mode with a
        # different step count, it is left at ``num_inference_steps`` here.
        if old_training:
            self.scheduler.set_timesteps(self.model_config['training_timesteps'], training=True)

        # Decode latents back to pixels in [0, 1], channels-last.
        with torch.no_grad():
            z_for_vae = z.permute(0, 1, 4, 2, 3).contiguous()
            video_recon = self.vae.decode_to_pixel(z_for_vae)
            video_recon = (video_recon + 1.0) / 2.0
            return video_recon.permute(0, 1, 3, 4, 2).contiguous().clamp(0, 1)

    def _generate_parallel(self, z_0, a, B, T_latent, noise_level):
        """Denoise all T_latent frames jointly while pinning the first frame."""
        from tqdm import tqdm

        device = z_0.device
        H_prime, W_prime, D = z_0.shape[2], z_0.shape[3], z_0.shape[4]

        # Start the whole sequence from pure noise.
        z = torch.randn(B, T_latent, H_prime, W_prime, D, device=device)

        # Conditioning frame: either lightly noised or perfectly clean.
        if noise_level > 0:
            t_val_0 = torch.full((B, 1), noise_level, device=device)
            z_0_noisy, _ = self.scheduler.add_independent_noise(z_0, t_val_0)
            z[:, 0] = z_0_noisy.squeeze(1)
        else:
            z[:, 0] = z_0.squeeze(1)

        for i in tqdm(range(len(self.scheduler.timesteps)), desc="Denoising (Parallel)"):
            t_val = self.scheduler.timesteps[i]
            t = torch.full((B, T_latent), t_val, device=device)
            # The first frame's timestep is clamped to noise_level (or held
            # clean at t=0) so it is never treated as fully noisy.
            if noise_level > 0:
                t[:, 0] = torch.where(t_val > noise_level, torch.tensor(noise_level, device=device), t_val)
            else:
                t[:, 0] = 0
            with torch.no_grad():
                v_pred = self.model(z, t, a)
                z = self.scheduler.step(v_pred, t, z)
            # Re-pin the clean conditioning frame after every step.
            if noise_level == 0:
                z[:, 0] = z_0.squeeze(1)
        return z

    def _generate_autoregressive(self, z_0, a, B, T_latent):
        """Generate one latent frame at a time, denoising only the newest."""
        device = z_0.device
        H_prime, W_prime, D = z_0.shape[2], z_0.shape[3], z_0.shape[4]

        z_all = z_0.clone()  # [B, 1, H', W', D]
        for t_idx in range(1, T_latent):
            # Append one pure-noise frame to the end of the window.
            z_next = torch.randn(B, 1, H_prime, W_prime, D, device=device)
            z_curr = torch.cat([z_all, z_next], dim=1)  # [B, t_idx + 1, ...]

            # Denoise only the LAST frame; history stays clean at t=0.
            for i in range(len(self.scheduler.timesteps)):
                t_val = self.scheduler.timesteps[i]
                t_seq = torch.zeros(B, t_idx + 1, device=device)
                t_seq[:, -1] = t_val
                # Truncate actions to the pixel horizon of the current window
                # (4x temporal compression => compress_rate * t_idx + 1 frames).
                L_curr = self.model.action_compress_rate * t_idx + 1
                a_curr = a[:, :L_curr]
                with torch.no_grad():
                    v_pred = self.model(z_curr, t_seq, a_curr)
                    z_curr = self.scheduler.step(v_pred, t_seq, z_curr)
                # Re-pin history frames for stability.
                z_curr[:, :-1] = z_all

            # Record the clean result for the next iteration.
            z_all = torch.cat([z_all, z_curr[:, -1:]], dim=1)
        return z_all