import torch
from torch import nn
from wm.model.interface import DIT_CLASS_MAP, VAE_CLASS_MAP

class DiffusionForcing_WM(nn.Module):
    """Diffusion-Forcing world model.

    Wraps a temporally-causal DiT denoiser that operates on VAE video
    latents. Training draws an independent noise level for every latent
    frame (the "diffusion forcing" recipe); generation can run either
    fully in parallel over the sequence or frame-by-frame
    autoregressively.
    """

    def __init__(self, model_name, model_config):
        """Build the DiT denoiser, VAE and noise scheduler from config.

        Args:
            model_name: key into ``DIT_CLASS_MAP`` selecting the denoiser.
            model_config: dict carrying the DiT hyper-parameters plus
                ``vae_name`` / ``vae_config``, ``scheduler`` (a name or a
                pre-built scheduler object) and ``training_timesteps``.
        """
        super().__init__()
        self.model_name = model_name
        self.model_config = model_config

        # Forward only the DiT-relevant keys out of the larger config dict.
        dit_keys = [
            'in_channels', 'patch_size', 'dim', 'num_layers', 'num_heads',
            'action_dim', 'action_compress_rate', 'max_frames',
            'rope_config', 'action_dropout_prob', 'temporal_causal',
        ]
        dit_config = {k: v for k, v in model_config.items() if k in dit_keys}

        # Diffusion forcing requires a temporally causal denoiser, so this
        # deliberately overrides whatever the config happened to say.
        dit_config['temporal_causal'] = True

        self.model = DIT_CLASS_MAP[model_name](**dit_config)
        self.vae = VAE_CLASS_MAP[model_config['vae_name']](
            *model_config.get('vae_config', []))

        # The scheduler may be given by name (currently only "FlowMatch")
        # or as an already-constructed scheduler object.
        scheduler_config = model_config.get('scheduler')
        if isinstance(scheduler_config, str):
            if scheduler_config == "FlowMatch":
                from wm.model.diffusion.flow_matching import FlowMatchScheduler
                self.scheduler = FlowMatchScheduler()
            else:
                raise ValueError(f"Unknown scheduler type: {scheduler_config}")
        else:
            self.scheduler = scheduler_config

        # Start in training mode with the configured timestep count.
        self.scheduler.set_timesteps(model_config['training_timesteps'], training=True)

    def encode_obs(self, o):
        """Encode pixel observations into channels-last VAE latents.

        Args:
            o: video tensor in [0, 1]; channels-last (B, T, H, W, 3) input
               is permuted to the VAE's channels-first layout, anything
               else is assumed to already be channels-first
               (B, T, 3, H, W) and passed through unchanged.

        Returns:
            Latents permuted to channels-last (B, T', H', W', D) for the
            DiT. No gradients flow through the VAE encoder.
        """
        with torch.no_grad():
            # [0, 1] -> [-1, 1], the value range the VAE is fed here.
            o = o * 2.0 - 1.0

            # Channels-last input becomes channels-first for the VAE.
            if o.shape[-1] == 3:
                o = o.permute(0, 1, 4, 2, 3).contiguous()

            latent = self.vae.encode(o)
            # Back to channels-last for the transformer.
            latent = latent.permute(0, 1, 3, 4, 2).contiguous()
        return latent

    def training_loss(self, z, a):
        """Diffusion-forcing loss on a latent sequence.

        Every frame gets its own independently sampled timestep, the
        sequence is noised accordingly, and the model's prediction is
        regressed against the scheduler's training target with per-frame
        weights.

        Args:
            z: latents, (B, T, H', W', D).
            a: action conditioning passed straight through to the model.

        Returns:
            Scalar weighted MSE loss.
        """
        B, T = z.shape[0], z.shape[1]

        # Independent timestep per (batch, frame) — the core of
        # diffusion forcing.
        t_indices = torch.randint(0, self.scheduler.timesteps.shape[0], (B, T), device=z.device)
        t_values = self.scheduler.timesteps[t_indices]

        # Noise each frame to its own level.
        z_t, eps = self.scheduler.add_independent_noise(z, t_values)

        v_pred = self.model(z_t, t_values, a)
        v_target = self.scheduler.training_target(z, eps, t_values)

        # Per-frame loss weights broadcast over the spatial/channel dims.
        weights = self.scheduler.training_weight(t_values)
        loss = (weights.view(B, T, 1, 1, 1) * (v_pred - v_target) ** 2).mean()
        return loss

    def full_train_loss(self, o_t, a):
        """End-to-end training loss from pixels: encode, then denoise.

        Args:
            o_t: pixel observations in [0, 1] (see ``encode_obs``).
            a: actions, (B, T, action_dim).

        Returns:
            Scalar training loss.
        """
        # Zero out the final action without mutating the caller's tensor.
        # NOTE(review): presumably the last action has no following frame
        # to supervise — confirm against the data pipeline.
        a = a.clone()
        a[:, -1, :] = 0

        z = self.encode_obs(o_t)
        loss = self.training_loss(z, a)
        return loss

    def generate(self, o_0, a, num_inference_steps=50, noise_level=0.0, mode="autoregressive"):
        """Roll out a video from one conditioning frame and an action plan.

        Args:
            o_0: first observation, (B, H, W, 3)-style pixels in [0, 1].
            a: actions, (B, T_pixel, action_dim); T_pixel fixes the length.
            num_inference_steps: denoising steps used by the scheduler.
            noise_level: if > 0 the conditioning frame is kept at this
                noise level during denoising instead of held clean.
            mode: "parallel" denoises all frames jointly; "autoregressive"
                generates one latent frame at a time with clean context.

        Returns:
            Decoded video in [0, 1], channels-last, clamped.

        Raises:
            ValueError: on an unknown ``mode``.
        """
        B = o_0.shape[0]
        T_pixel = a.shape[1]
        device = o_0.device

        # Encode the conditioning frame as a length-1 latent sequence.
        z_0 = self.encode_obs(o_0.unsqueeze(1))

        # NOTE(review): the temporal compression factor 4 is hard-coded;
        # presumably it matches the VAE's temporal downsampling — confirm.
        T_latent = (T_pixel - 1) // 4 + 1
        H_prime, W_prime = z_0.shape[2], z_0.shape[3]
        D = z_0.shape[4]

        # Switch the scheduler to inference timesteps, remembering whether
        # it was in training mode so that can be restored afterwards.
        old_training = self.scheduler.training
        self.scheduler.set_timesteps(num_inference_steps=num_inference_steps, training=False)

        from tqdm import tqdm

        if mode == "parallel":
            # Start every frame from pure noise.
            z = torch.randn(B, T_latent, H_prime, W_prime, D, device=device)

            # Seed frame 0 with the (optionally noised) conditioning latent.
            if noise_level > 0:
                t_val_0 = torch.full((B, 1), noise_level, device=device)
                z_0_noisy, _ = self.scheduler.add_independent_noise(z_0, t_val_0)
                z[:, 0] = z_0_noisy.squeeze(1)
            else:
                z[:, 0] = z_0.squeeze(1)

            for i in tqdm(range(len(self.scheduler.timesteps)), desc="Denoising (Parallel)"):
                t_val = self.scheduler.timesteps[i]
                t = torch.full((B, T_latent), t_val, device=device)

                # The context frame's timestep is clamped so it is never
                # noisier than `noise_level` (or kept fully clean at 0).
                if noise_level > 0:
                    t[:, 0] = torch.where(t_val > noise_level, torch.tensor(noise_level, device=device), t_val)
                else:
                    t[:, 0] = 0

                with torch.no_grad():
                    v_pred = self.model(z, t, a)
                    z = self.scheduler.step(v_pred, t, z)

            # Re-clamp the clean conditioning frame after denoising.
            if noise_level == 0:
                z[:, 0] = z_0.squeeze(1)

        elif mode == "autoregressive":
            # Generated-so-far latent sequence, starting with the context.
            z_all = z_0.clone()

            for t_idx in range(1, T_latent):
                # New frame from pure noise, appended to the clean context.
                z_next = torch.randn(B, 1, H_prime, W_prime, D, device=device)
                z_curr = torch.cat([z_all, z_next], dim=1)

                for i in range(len(self.scheduler.timesteps)):
                    t_val = self.scheduler.timesteps[i]

                    # Context frames at timestep 0, new frame at t_val.
                    t_seq = torch.zeros(B, t_idx + 1, device=device)
                    t_seq[:, -1] = t_val

                    # Only feed the actions covering the frames generated
                    # so far (pixel-rate actions per latent frame).
                    L_curr = self.model.action_compress_rate * t_idx + 1
                    a_curr = a[:, :L_curr]

                    with torch.no_grad():
                        v_pred = self.model(z_curr, t_seq, a_curr)
                        z_curr = self.scheduler.step(v_pred, t_seq, z_curr)

                # Restore the clean context frames after denoising.
                # NOTE(review): this clamp happens once after the inner loop,
                # so context frames may drift during denoising — confirm the
                # scheduler's step at t=0 is the identity.
                z_curr[:, :-1] = z_all

                z_all = torch.cat([z_all, z_curr[:, -1:]], dim=1)

            z = z_all
        else:
            raise ValueError(f"Unknown generation mode: {mode}")

        # Restore the scheduler's training timesteps if it was training.
        if old_training:
            self.scheduler.set_timesteps(self.model_config['training_timesteps'], training=True)

        # Decode latents back to pixels and map [-1, 1] -> [0, 1].
        with torch.no_grad():
            z_for_vae = z.permute(0, 1, 4, 2, 3).contiguous()
            video_recon = self.vae.decode_to_pixel(z_for_vae)
            video_recon = (video_recon + 1.0) / 2.0
            video_recon = video_recon.permute(0, 1, 3, 4, 2).contiguous().clamp(0, 1)

        return video_recon