File size: 8,281 Bytes
f17ae24
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
import torch
from torch import nn
from wm.model.interface import DIT_CLASS_MAP, VAE_CLASS_MAP

class DiffusionForcing_WM(nn.Module):
    """Diffusion-Forcing world model: a temporally-causal DiT denoiser over
    VAE latents in which every frame may carry an independent noise level.

    Components:
      * ``self.model``     -- DiT backbone (``temporal_causal`` is forced on),
      * ``self.vae``       -- pixel <-> latent codec,
      * ``self.scheduler`` -- flow-matching noise scheduler.
    """

    def __init__(self, model_name, model_config):
        """
        Args:
            model_name: key into ``DIT_CLASS_MAP`` selecting the DiT backbone.
            model_config: dict holding DiT, VAE and scheduler settings
                (must contain ``vae_name`` and ``training_timesteps``).

        Raises:
            ValueError: if ``model_config['scheduler']`` is an unknown
                scheduler name string.
        """
        super().__init__()
        self.model_name = model_name
        self.model_config = model_config

        # Filter DiT-specific keys to avoid passing VAE/Scheduler config to
        # the DiT constructor.
        dit_keys = [
            'in_channels', 'patch_size', 'dim', 'num_layers', 'num_heads',
            'action_dim', 'action_compress_rate', 'max_frames',
            'rope_config', 'action_dropout_prob', 'temporal_causal'
        ]
        dit_config = {k: v for k, v in model_config.items() if k in dit_keys}

        # Diffusion Forcing requires a causal temporal attention mask,
        # regardless of what the config says.
        dit_config['temporal_causal'] = True

        self.model = DIT_CLASS_MAP[model_name](**dit_config)
        self.vae = VAE_CLASS_MAP[model_config['vae_name']](*model_config.get('vae_config', []))

        # The scheduler may be given either as a string name or as a
        # ready-made scheduler instance.
        scheduler_config = model_config.get('scheduler')
        if isinstance(scheduler_config, str):
            if scheduler_config == "FlowMatch":
                from wm.model.diffusion.flow_matching import FlowMatchScheduler
                self.scheduler = FlowMatchScheduler()
            else:
                raise ValueError(f"Unknown scheduler type: {scheduler_config}")
        else:
            self.scheduler = scheduler_config

        # Initialize the scheduler on the training timestep grid.
        self.scheduler.set_timesteps(model_config['training_timesteps'], training=True)

    def encode_obs(self, o):
        """Encode pixel observations into DiT-ready (channels-last) latents.

        Args:
            o: video tensor, ``[B, T, 3, H, W]`` or ``[B, T, H, W, 3]``,
                values in ``[0, 1]``.

        Returns:
            Latents shaped ``[B, T_latent, H', W', D]``.

        Raises:
            ValueError: if the channel layout cannot be identified (no dim of
                size 3 at index 2 or -1).
        """
        with torch.no_grad():
            # 1. Normalize [0, 1] -> [-1, 1] (the VAE's input range).
            o = o * 2.0 - 1.0

            # 2. Ensure shape is [B, T, 3, H, W] for the VAE.
            if o.shape[-1] == 3:
                o = o.permute(0, 1, 4, 2, 3).contiguous()
            elif o.shape[2] == 3:
                # Already [B, T, 3, H, W]
                pass
            else:
                # Fail loudly here rather than letting the VAE choke on an
                # unrecognized layout with a cryptic downstream error.
                raise ValueError(
                    f"Cannot infer channel layout from shape {tuple(o.shape)}; "
                    "expected a dim of size 3 at index 2 or -1."
                )

            latent = self.vae.encode(o)  # [B, T_latent, 16, H/8, W/8]
            # To channels-last [B, T_latent, H/8, W/8, 16] for the DiT.
            latent = latent.permute(0, 1, 3, 4, 2).contiguous()
        return latent

    def training_loss(self, z, a):
        """Diffusion-Forcing loss on pre-encoded latents.

        Every frame (including the first) receives an independently sampled
        timestep, which is the core of the Diffusion Forcing objective.

        Args:
            z: latents ``[B, T', H', W', D]``.
            a: actions ``[B, T_pixel, C_a]``.

        Returns:
            Scalar weighted-MSE loss.
        """
        B, T = z.shape[0], z.shape[1]

        # Sample independent timesteps for ALL frames (including frame 0).
        t_indices = torch.randint(0, self.scheduler.timesteps.shape[0], (B, T), device=z.device)
        t_values = self.scheduler.timesteps[t_indices]  # [B, T]

        # Add per-frame noise via the FlowMatchScheduler helper.
        z_t, eps = self.scheduler.add_independent_noise(z, t_values)

        v_pred = self.model(z_t, t_values, a)
        v_target = self.scheduler.training_target(z, eps, t_values)

        # Per-timestep loss weights, broadcast over spatial/channel dims.
        weights = self.scheduler.training_weight(t_values)
        loss = (weights.view(B, T, 1, 1, 1) * (v_pred - v_target)**2).mean()
        return loss

    def full_train_loss(self, o_t, a):
        """End-to-end training loss from pixels.

        Args:
            o_t: observations ``[B, T_pixel, H, W, 3]``.
            a: actions ``[B, T_pixel, C_a]``.

        Returns:
            Scalar loss.
        """
        # Zero out the last action since it is not used for training.
        # Clone first so the caller's tensor is not mutated in place.
        a = a.clone()
        a[:, -1, :] = 0

        # Encode the observations to latents [B, T', H', W', D].
        z = self.encode_obs(o_t)

        # Add noise and compute the Diffusion-Forcing loss.
        loss = self.training_loss(z, a)
        return loss

    def generate(self, o_0, a, num_inference_steps=50, noise_level=0.0, mode="autoregressive"):
        """Roll out a video from a single conditioning frame and actions.

        Args:
            o_0: first observation ``[B, H, W, 3]``, values in ``[0, 1]``.
            a: actions ``[B, T_pixel, A]``.
            num_inference_steps: denoising steps per frame/sequence.
            noise_level: optional noise kept on the conditioning frame
                (0.0 keeps it perfectly clean).
            mode: "parallel" denoises the whole sequence at once (fastest,
                relies on the causal mask); "autoregressive" denoises one
                latent frame at a time (more stable for long horizons).

        Returns:
            Video ``[B, T_pixel, H, W, 3]`` clamped to ``[0, 1]``.

        Raises:
            ValueError: for an unknown ``mode``.
        """
        B = o_0.shape[0]
        T_pixel = a.shape[1]
        device = o_0.device

        # 1. Encode first frame -> [B, 1, H', W', 16].
        z_0 = self.encode_obs(o_0.unsqueeze(1))

        # 2. Determine the latent sequence shape (VAE compresses time 4x,
        #    with the first pixel frame mapping to the first latent frame).
        T_latent = (T_pixel - 1) // 4 + 1
        H_prime, W_prime = z_0.shape[2], z_0.shape[3]
        D = z_0.shape[4]  # 16

        # Save scheduler mode so training state can be restored afterwards.
        old_training = self.scheduler.training
        self.scheduler.set_timesteps(num_inference_steps=num_inference_steps, training=False)

        from tqdm import tqdm

        if mode == "parallel":
            # 3. Initialize the full latent sequence with noise.
            z = torch.randn(B, T_latent, H_prime, W_prime, D, device=device)

            # 4. Place the (optionally noised) conditioning frame at t=0.
            if noise_level > 0:
                t_val_0 = torch.full((B, 1), noise_level, device=device)
                z_0_noisy, _ = self.scheduler.add_independent_noise(z_0, t_val_0)
                z[:, 0] = z_0_noisy.squeeze(1)
            else:
                z[:, 0] = z_0.squeeze(1)

            # 5. Denoising loop over the whole sequence at once.
            for i in tqdm(range(len(self.scheduler.timesteps)), desc="Denoising (Parallel)"):
                t_val = self.scheduler.timesteps[i]
                t = torch.full((B, T_latent), t_val, device=device)

                # The conditioning frame stays at its own (lower) noise level.
                if noise_level > 0:
                    t[:, 0] = torch.where(t_val > noise_level, torch.tensor(noise_level, device=device), t_val)
                else:
                    t[:, 0] = 0

                with torch.no_grad():
                    v_pred = self.model(z, t, a)
                    z = self.scheduler.step(v_pred, t, z)

                    # Re-clamp the clean conditioning frame after each step.
                    if noise_level == 0:
                        z[:, 0] = z_0.squeeze(1)

        elif mode == "autoregressive":
            # 3. Start with only the (clean) first frame.
            z_all = z_0.clone()  # [B, 1, H', W', D]

            # 4. Roll the window forward one latent frame at a time.
            for t_idx in range(1, T_latent):
                # a. Append a fresh noisy frame to the end.
                z_next = torch.randn(B, 1, H_prime, W_prime, D, device=device)
                z_curr = torch.cat([z_all, z_next], dim=1)  # [B, t_idx+1, ...]

                # b. Denoise only the LAST frame of the current sequence.
                for i in range(len(self.scheduler.timesteps)):
                    t_val = self.scheduler.timesteps[i]
                    # History is clean (t=0); only the last frame is noisy.
                    t_seq = torch.zeros(B, t_idx + 1, device=device)
                    t_seq[:, -1] = t_val

                    # Truncate actions to the pixel horizon covered so far.
                    L_curr = self.model.action_compress_rate * t_idx + 1
                    a_curr = a[:, :L_curr]

                    with torch.no_grad():
                        v_pred = self.model(z_curr, t_seq, a_curr)
                        z_curr = self.scheduler.step(v_pred, t_seq, z_curr)

                        # Pin history frames back to their clean values
                        # (optional but recommended for stability).
                        z_curr[:, :-1] = z_all

                # c. Record the clean result for the next iteration.
                z_all = torch.cat([z_all, z_curr[:, -1:]], dim=1)

            z = z_all
        else:
            raise ValueError(f"Unknown generation mode: {mode}")

        # Restore the scheduler's training-mode timestep grid if needed.
        if old_training:
            self.scheduler.set_timesteps(self.model_config['training_timesteps'], training=True)

        # 6. Decode latents back to pixels in [0, 1].
        with torch.no_grad():
            z_for_vae = z.permute(0, 1, 4, 2, 3).contiguous()
            video_recon = self.vae.decode_to_pixel(z_for_vae)
            video_recon = (video_recon + 1.0) / 2.0
            video_recon = video_recon.permute(0, 1, 3, 4, 2).contiguous().clamp(0, 1)

        return video_recon