File size: 8,795 Bytes
f17ae24
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
import torch
import math

class FlowMatchScheduler(torch.nn.Module):
    """
    A simplified Flow Matching scheduler specifically for the Wan template.

    Timestep inputs may be scalars, [B], [B, T], or higher-dimensional
    tensors; sigma lookups broadcast against samples of any rank.
    """

    def __init__(self):
        super().__init__()
        self.num_train_timesteps = 1000
        # Buffers start as None and are populated by set_timesteps().
        # persistent=False keeps them out of state_dict().
        self.register_buffer("sigmas", None, persistent=False)
        self.register_buffer("timesteps", None, persistent=False)
        self.register_buffer("linear_timesteps_weights", None, persistent=False)
        # NOTE: this deliberately overwrites nn.Module's built-in `training`
        # flag (set to True by nn.Module.__init__); train()/eval() calls will
        # also toggle this attribute.
        self.training = False

    @property
    def device(self):
        """Device of the scheduler's buffers (CPU before set_timesteps())."""
        if self.timesteps is not None:
            return self.timesteps.device
        return torch.device('cpu')

    def set_timesteps(self, num_inference_steps=100, denoising_strength=1.0, shift=5.0, training=False):
        """
        Set the timesteps and sigmas for the Wan template.

        Args:
            num_inference_steps: number of denoising steps in the schedule.
            denoising_strength: fraction of the noise range to traverse
                (1.0 starts from pure noise, i.e. sigma = 1).
            shift: Wan shift factor, applied as s -> shift*s / (1+(shift-1)*s).
            training: if True, also precompute per-timestep loss weights.
        """
        sigma_min = 0.0
        sigma_max = 1.0
        sigma_start = sigma_min + (sigma_max - sigma_min) * denoising_strength

        # Descending sigmas ending at 0.0 so the final sample is clean.
        sigmas = torch.linspace(sigma_start, sigma_min, num_inference_steps)

        # Apply shift (default is 5 for Wan); maps 0 -> 0 and 1 -> 1.
        sigmas = shift * sigmas / (1 + (shift - 1) * sigmas)

        # Keep buffers on the module's current device.
        device = self.device
        sigmas = sigmas.to(device)
        timesteps = (sigmas * self.num_train_timesteps).to(device)

        self.register_buffer("sigmas", sigmas, persistent=False)
        self.register_buffer("timesteps", timesteps, persistent=False)

        if training:
            self.set_training_weight()
            self.training = True
        else:
            self.training = False

    def set_training_weight(self):
        """Precompute bell-shaped loss weights over the current timesteps."""
        steps = 1000
        x = self.timesteps
        y = torch.exp(-2 * ((x - steps / 2) / steps) ** 2)
        y_shifted = y - y.min()
        bsmntw_weighing = y_shifted * (steps / y_shifted.sum())
        if len(self.timesteps) != 1000:
            # This is an empirical formula for schedules shorter than 1000 steps.
            bsmntw_weighing = bsmntw_weighing * (len(self.timesteps) / steps)
            # Guard against IndexError for single-step schedules (the
            # unconditional [1] index crashed when num_inference_steps == 1).
            offset_idx = 1 if len(bsmntw_weighing) > 1 else 0
            bsmntw_weighing = bsmntw_weighing + bsmntw_weighing[offset_idx]

        # Move to the current device of the module
        self.register_buffer("linear_timesteps_weights", bsmntw_weighing.to(self.device), persistent=False)

    def _get_timestep_indices(self, timestep: torch.Tensor):
        """
        Efficiently find the nearest indices in self.timesteps for input timesteps.
        Supports any input shape by flattening, computing, and reshaping.
        """
        if not isinstance(timestep, torch.Tensor):
            timestep = torch.tensor(timestep, device=self.device)

        t_input = timestep.to(self.device)
        orig_shape = t_input.shape

        # Flatten input to handle any shape (B, T, ...)
        t_flat = t_input.reshape(-1, 1)

        # Broadcast against self.timesteps [N] -> [len(t_flat), N] and take
        # the nearest schedule entry per input timestep.
        diff = (t_flat - self.timesteps.unsqueeze(0)).abs()
        indices = torch.argmin(diff, dim=-1)

        return indices.view(orig_shape)

    def step(self, model_output, timestep, sample, to_final=False):
        """
        One Euler step of the flow ODE: x <- x + v * (sigma_next - sigma).

        Args:
            model_output: predicted velocity, same shape as `sample`.
            timestep: current timestep(s); any shape broadcastable to sample's
                leading dims.
            sample: current noisy sample.
            to_final: if True, integrate directly to sigma = 0.
        """
        indices = self._get_timestep_indices(timestep)
        sigma = self.sigmas[indices]

        if to_final:
            sigma_next = torch.zeros_like(sigma)
        else:
            # Get next sigma, clamping to avoid out of bounds
            next_indices = (indices + 1).clamp(max=len(self.sigmas) - 1)
            sigma_next = self.sigmas[next_indices]
            # If we were already at the last step, next sigma is 0
            sigma_next = torch.where(indices + 1 >= len(self.sigmas), torch.zeros_like(sigma), sigma_next)

        # Broadcast sigma diff to match sample shape (e.g. [B, T, C, H, W] or [B, C, H, W])
        sigma_diff = (sigma_next - sigma).view(*sigma.shape, *([1] * (sample.ndim - sigma.ndim)))
        sigma_diff = sigma_diff.to(sample.device)
        return sample + model_output * sigma_diff

    def return_to_timestep(self, timestep, sample, sample_stablized):
        """Recover the implied velocity from a stabilized sample at `timestep`."""
        indices = self._get_timestep_indices(timestep)
        sigma = self.sigmas[indices]
        sigma_view = sigma.view(*sigma.shape, *([1] * (sample.ndim - sigma.ndim)))
        sigma_view = sigma_view.to(sample.device)
        model_output = (sample - sample_stablized) / sigma_view
        return model_output

    def add_noise(self, original_samples, noise, timestep):
        """
        Linear interpolation x_t = (1 - sigma) * x_0 + sigma * noise.

        `timestep` may be a scalar or match any prefix of the sample's shape;
        sigma is broadcast over the remaining dims (e.g. [B, T, 1, 1, 1]).
        """
        indices = self._get_timestep_indices(timestep)
        sigma = self.sigmas[indices]

        sigma_view = sigma.view(*sigma.shape, *([1] * (original_samples.ndim - sigma.ndim)))
        sigma_view = sigma_view.to(original_samples.device)

        return (1 - sigma_view) * original_samples + sigma_view * noise

    def add_independent_noise(self, original_samples, timestep):
        """
        Helper that samples noise independently for each element in original_samples
        and applies it based on the provided timestep (which should match the leading dims).

        Returns:
            (noised_samples, noise) tuple.
        """
        noise = torch.randn_like(original_samples)
        return self.add_noise(original_samples, noise, timestep), noise

    def training_target(self, sample, noise, timestep):
        """Flow-matching regression target: velocity v = noise - sample."""
        return noise - sample

    def training_weight(self, timestep):
        """Look up the precomputed loss weight(s) for `timestep`."""
        indices = self._get_timestep_indices(timestep)
        return self.linear_timesteps_weights[indices]


if __name__ == "__main__":
    import matplotlib.pyplot as plt
    import numpy as np
    import os

    # All generated figures go under this directory.
    out_dir = "results/test_flow_matching"
    os.makedirs(out_dir, exist_ok=True)

    # Build the scheduler with a 50-step schedule and training weights.
    sched = FlowMatchScheduler()
    step_count = 50
    sched.set_timesteps(num_inference_steps=step_count, training=True)

    # Exercise batched (B, T) timesteps against a video-shaped tensor.
    batch, frames = 2, 4
    rand_idx = torch.randint(0, step_count, (batch, frames))
    t_bt = sched.timesteps[rand_idx]
    print(f"Testing with (B, T) shape: {t_bt.shape}")

    clean = torch.randn(batch, frames, 3, 64, 64)
    eps = torch.randn_like(clean)
    noised = sched.add_noise(clean, eps, t_bt)
    print(f"xt shape: {noised.shape}")
    assert noised.shape == clean.shape

    # Plot the timestep mapping and the training weights side by side.
    fig, (ax_map, ax_w) = plt.subplots(1, 2, figsize=(12, 5))

    ax_map.plot(range(len(sched.timesteps)), sched.timesteps.numpy(), marker='.', color='blue', label='Timesteps')
    ax_map.set_title("Timestep Mapping (Wan Shift=5)")
    ax_map.set_xlabel("Inference Step Index")
    ax_map.set_ylabel("Training Timestep (0-1000)")
    ax_map.grid(True)
    ax_map.legend()

    ax_w.plot(sched.timesteps.numpy(), sched.linear_timesteps_weights.numpy(), marker='.', color='red', label='Weights')
    ax_w.set_title("Training Weights vs Timestep")
    ax_w.set_xlabel("Training Timestep")
    ax_w.set_ylabel("Weight Value")
    ax_w.grid(True)
    ax_w.legend()

    plt.tight_layout()
    plt.savefig(f"{out_dir}/scheduler_curves.png")
    print(f"Saved curves to {out_dir}/scheduler_curves.png")

    # Visualize the x_t interpolation on a synthetic grid image.
    size = 256
    canvas = np.zeros((size, size, 3), dtype=np.float32)
    canvas[::32, :] = 1.0
    canvas[:, ::32] = 1.0
    base_img = torch.from_numpy(canvas).permute(2, 0, 1).unsqueeze(0)  # [1, 3, 256, 256]

    img_noise = torch.randn_like(base_img)

    # Sample five schedule positions from the start to the end of the schedule.
    picks = [0, step_count // 4, step_count // 2, 3 * step_count // 4, step_count - 1]

    fig_xt, axes_xt = plt.subplots(1, len(picks), figsize=(15, 3))
    for ax, idx in zip(axes_xt, picks):
        t = sched.timesteps[idx]
        blended = sched.add_noise(base_img, img_noise, t)

        # Clip to [0, 1] and move channels last for imshow.
        shown = np.clip(blended.squeeze(0).permute(1, 2, 0).numpy(), 0, 1)
        ax.imshow(shown)
        ax.set_title(f"t={t:.1f}")
        ax.axis('off')

    plt.suptitle("Flow Matching Interpolation (x_t) from Data (left) to Noise (right)")
    plt.tight_layout()
    plt.savefig(f"{out_dir}/xt_interpolation.png")
    print(f"Saved x_t interpolation to {out_dir}/xt_interpolation.png")