File size: 9,549 Bytes
a62f296
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

from .tools.wan_vae import WanVAE_


class VAEWanModel(nn.Module):
    """Feature-sequence VAE built on ``WanVAE_``, trained with masked losses.

    Inputs are padded, last-channel tensors of shape ``(B, T, C)``,
    ``(B, T, H, C)`` or ``(B, T, H, W, C)`` together with per-sample valid
    lengths along ``T``. Features are standardized with the ``mean``/``std``
    buffers before encoding and de-standardized after decoding.
    """

    def __init__(
        self,
        input_dim,
        mean_path=None,
        std_path=None,
        z_dim=256,
        dim=160,
        dec_dim=512,
        num_res_blocks=1,
        dropout=0.0,
        dim_mult=[1, 1, 1],
        temperal_downsample=[True, True],
        spatial_downsample=[False, False],
        spatial_dim=0,
        input_keys={
            "feature": "feature",
            "feature_length": "feature_length",
        },
        **kwargs,
    ):
        """Build the wrapper and the underlying ``WanVAE_``.

        Args:
            input_dim: number of feature channels C (last axis of the input).
            mean_path: optional ``.npy`` file with per-channel means
                (default: zeros).
            std_path: optional ``.npy`` file with per-channel stds
                (default: ones).
            z_dim, dim, dec_dim, num_res_blocks, dropout, dim_mult,
            temperal_downsample, spatial_downsample, spatial_dim: forwarded
                unchanged to ``WanVAE_``.
            input_keys: maps the internal names ``"feature"`` and
                ``"feature_length"`` to the keys looked up in batch dicts.
            **kwargs: optional ``LAMBDA_FEATURE`` (default 1.0),
                ``LAMBDA_KL`` (default ``10e-6``, i.e. 1e-5) and
                ``recons_weights`` (per-channel reconstruction weights).
        """
        super().__init__()
        # Copy container arguments: the defaults are shared mutable objects,
        # and copying keeps instances (and callers) from aliasing each other.
        self.input_keys = dict(input_keys)

        self.mean_path = mean_path
        self.std_path = std_path
        self.input_dim = input_dim
        self.z_dim = z_dim
        self.dim = dim
        self.dec_dim = dec_dim
        self.num_res_blocks = num_res_blocks
        self.dropout = dropout
        self.dim_mult = list(dim_mult)
        self.temperal_downsample = list(temperal_downsample)
        self.spatial_downsample = list(spatial_downsample)
        self.spatial_dim = spatial_dim
        self.RECONS_LOSS = nn.SmoothL1Loss(reduction="none")
        self.LAMBDA_FEATURE = kwargs.get("LAMBDA_FEATURE", 1.0)
        # NOTE(review): 10e-6 == 1e-5; confirm that (not 1e-6) is intended.
        self.LAMBDA_KL = kwargs.get("LAMBDA_KL", 10e-6)

        # Per-channel reconstruction weights (default: all ones). A shorter
        # list is padded with 1s at the end; a longer one is truncated.
        recons_weights = kwargs.get("recons_weights", None)
        if recons_weights is None:
            w = torch.ones(input_dim, dtype=torch.float32)
        else:
            w = torch.tensor(recons_weights, dtype=torch.float32)
            if w.numel() < input_dim:
                w = torch.cat([w, torch.ones(input_dim - w.numel())])
            w = w[:input_dim]
        self.register_buffer("recons_weights", w, persistent=False)

        if self.mean_path is not None:
            self.register_buffer(
                "mean", torch.from_numpy(np.load(self.mean_path)).float()
            )
        else:
            self.register_buffer("mean", torch.zeros(input_dim))

        if self.std_path is not None:
            self.register_buffer(
                "std", torch.from_numpy(np.load(self.std_path)).float()
            )
        else:
            self.register_buffer("std", torch.ones(input_dim))

        self.model = WanVAE_(
            input_dim=self.input_dim,
            dim=self.dim,
            dec_dim=self.dec_dim,
            z_dim=self.z_dim,
            dim_mult=self.dim_mult,
            num_res_blocks=self.num_res_blocks,
            temperal_downsample=self.temperal_downsample,
            spatial_downsample=self.spatial_downsample,
            spatial_dim=self.spatial_dim,
            dropout=self.dropout,
        )

        # Each enabled temporal-downsample stage halves T.
        self.downsample_factor = 2 ** sum(1 for flag in self.temperal_downsample if flag)

    def _extract_inputs(self, x):
        """Return ``{internal_key: x[external_key]}`` for the keys present in ``x``."""
        inputs = {}
        for internal_key, external_key in self.input_keys.items():
            if external_key in x:
                inputs[internal_key] = x[external_key]
        return inputs

    @staticmethod
    def _lengths_to_mask(lengths, max_len, device):
        """Return a ``(B, max_len)`` bool mask, True where position < ``lengths[b]``.

        ``lengths`` may be a list of ints or a 1-D tensor; it is moved to
        ``device``. Vectorized to avoid a per-sample Python loop.
        """
        lengths = torch.as_tensor(lengths, device=device).reshape(-1, 1)
        return torch.arange(max_len, device=device).unsqueeze(0) < lengths

    def preprocess(self, x):
        """Convert last-channel batched format to channel-first, padding to 5D (B, C, T, H, W).
        (B, T, C) -> (B, C, T, 1, 1)
        (B, T, H, C) -> (B, C, T, H, 1)
        (B, T, H, W, C) -> (B, C, T, H, W)
        Tensors of any other rank are returned unchanged.
        """
        ndim = x.ndim
        if ndim == 3:  # (B, T, C)
            x = x.permute(0, 2, 1)[:, :, :, None, None]
        elif ndim == 4:  # (B, T, H, C)
            x = x.permute(0, 3, 1, 2)[:, :, :, :, None]
        elif ndim == 5:  # (B, T, H, W, C)
            x = x.permute(0, 4, 1, 2, 3)
        return x

    def postprocess(self, x, ndim=None):
        """Reverse of preprocess: channel-first 5D back to last-channel.

        (B, C, T, 1, 1) -> (B, T, C)        when ndim == 3
        (B, C, T, H, 1) -> (B, T, H, C)     when ndim == 4
        (B, C, T, H, W) -> (B, T, H, W, C)  otherwise

        Args:
            x: 5D channel-first tensor.
            ndim: rank (3, 4 or 5) of the original last-channel tensor. Pass
                it whenever known: it makes the output rank exact, so genuine
                singleton H/W axes are not mistaken for padding. When None,
                the rank is inferred from which trailing axes have size 1.
        """
        if ndim is None:
            if x.shape[3] == 1 and x.shape[4] == 1:
                ndim = 3
            elif x.shape[4] == 1:
                ndim = 4
            else:
                ndim = 5
        if ndim == 3:
            return x[:, :, :, 0, 0].permute(0, 2, 1)
        if ndim == 4:
            return x[:, :, :, :, 0].permute(0, 2, 3, 1)
        return x.permute(0, 2, 3, 4, 1)

    def forward(self, x):
        """Compute masked reconstruction and KL losses for one batch.

        Args:
            x: batch dict; ``input_keys`` selects the padded last-channel
                feature tensor and the per-sample valid lengths along T.

        Returns:
            dict with scalar tensors ``"total"``, ``"recons"`` and ``"kl"``.
        """
        x = self._extract_inputs(x)
        features = x["feature"]
        feature_length = x["feature_length"]
        ndim = features.ndim
        features = (features - self.mean) / self.std

        batch_size, seq_len = features.shape[:2]
        device = features.device
        mask = self._lengths_to_mask(feature_length, seq_len, device)  # (B, T)

        x_in = self.preprocess(features)  # (B, C, T, H, W)
        mu, log_var = self.model.encode(x_in, scale=[0, 1], return_dist=True)
        z = self.model.reparameterize(mu, log_var)
        x_decoder = self.model.decode(z, scale=[0, 1])  # (B, C, T', H, W)
        x_out = self.postprocess(x_decoder, ndim=ndim)  # same layout as features

        # The decoder may return a different number of frames; compare the
        # overlapping prefix only.
        if x_out.size(1) != features.size(1):
            min_len = min(x_out.size(1), features.size(1))
            x_out = x_out[:, :min_len]
            features = features[:, :min_len]
            mask = mask[:, :min_len]

        # Broadcast the (B, T) mask over all trailing axes.
        mask_expanded = mask.reshape(*mask.shape, *([1] * (features.ndim - 2)))
        loss_per_element = self.RECONS_LOSS(x_out, features)
        # Weighted mean over valid positions: count every valid (b, t, h, w)
        # position, not just valid (b, t), so spatial inputs are averaged
        # (consistent with the KL normalization below) rather than summed.
        num_valid = mask_expanded.expand(*features.shape[:-1], 1).sum()
        loss_recons = (
            (loss_per_element * mask_expanded * self.recons_weights).sum()
            / num_valid
            / self.recons_weights.sum()
        )

        # KL(N(mu, sigma) || N(0, 1)) = -0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2)
        # log_var = log(sigma^2), so it is used directly.

        # Latent-time mask: each sample keeps ceil(length / downsample_factor)
        # latent steps.
        T_latent = mu.size(2)
        latent_length = (
            torch.as_tensor(feature_length, device=device)
            + self.downsample_factor
            - 1
        ) // self.downsample_factor
        mask_downsampled = self._lengths_to_mask(latent_length, T_latent, device)
        mask_latent = mask_downsampled[:, None, :, None, None]  # (B, 1, T_latent, 1, 1)

        kl_per_element = -0.5 * (1 + log_var - mu.pow(2) - log_var.exp())
        kl_masked = kl_per_element * mask_latent
        num_latent_elements = mu.size(1) * mu.size(3) * mu.size(4)  # C * H * W
        # Normalize by valid latent timesteps * (C * H * W).
        kl_loss = torch.sum(kl_masked) / (
            torch.sum(mask_downsampled) * num_latent_elements
        )

        total_loss = self.LAMBDA_FEATURE * loss_recons + self.LAMBDA_KL * kl_loss

        loss_dict = {}
        loss_dict["total"] = total_loss
        loss_dict["recons"] = loss_recons
        loss_dict["kl"] = kl_loss
        return loss_dict

    def encode(self, x):
        """Standardize ``x`` and return the encoder mean in last-channel layout.

        (B, T, C) -> (B, T', z_dim); higher-rank inputs keep their rank.
        """
        ndim = x.ndim
        x = (x - self.mean) / self.std
        x_in = self.preprocess(x)
        mu = self.model.encode(x_in, scale=[0, 1])
        return self.postprocess(mu, ndim=ndim)

    def decode(self, mu):
        """Decode last-channel latents and undo the feature standardization.

        (B, T', z_dim) -> (B, T, input_dim); higher-rank inputs keep their rank.
        """
        ndim = mu.ndim
        mu_in = self.preprocess(mu)
        x_decoder = self.model.decode(mu_in, scale=[0, 1])  # (B, input_dim, T, H, W)
        x_out = self.postprocess(x_decoder, ndim=ndim)
        return x_out * self.std + self.mean

    @torch.no_grad()
    def stream_encode(self, x, first_chunk=True):
        """Chunked variant of ``encode`` via ``WanVAE_.stream_encode`` (no grad).

        ``first_chunk`` is forwarded unchanged to ``WanVAE_.stream_encode``.
        """
        ndim = x.ndim
        x = (x - self.mean) / self.std
        x_in = self.preprocess(x)
        mu = self.model.stream_encode(x_in, first_chunk=first_chunk, scale=[0, 1])
        return self.postprocess(mu, ndim=ndim)

    @torch.no_grad()
    def stream_decode(self, mu, first_chunk=True):
        """Chunked variant of ``decode`` via ``WanVAE_.stream_decode`` (no grad).

        ``first_chunk`` is forwarded unchanged to ``WanVAE_.stream_decode``.
        """
        ndim = mu.ndim
        mu_in = self.preprocess(mu)
        x_decoder = self.model.stream_decode(
            mu_in, first_chunk=first_chunk, scale=[0, 1]
        )
        x_out = self.postprocess(x_decoder, ndim=ndim)
        return x_out * self.std + self.mean

    def clear_cache(self):
        """Delegate to ``WanVAE_.clear_cache``."""
        self.model.clear_cache()

    def generate(self, x):
        """Reconstruct each sample (encode, then decode) and trim its padding.

        Returns:
            ``{"generated": [tensor_i, ...]}``, one tensor per sample, each
            trimmed to ``1 + df * ((length - 1) // df)`` frames
            (0 frames for length 0), where df is ``downsample_factor``.
        """
        x = self._extract_inputs(x)
        features = x["feature"]
        feature_length = x["feature_length"]
        y_hat = self.decode(self.encode(features))

        df = self.downsample_factor
        y_hat_out = []
        for i in range(y_hat.shape[0]):
            length = int(feature_length[i])
            # Presumably the frame count a causal temporal VAE reconstructs
            # for a clip of this length -- TODO confirm against WanVAE_.
            # Guard length 0: the formula would go negative for df > 1.
            valid_len = (length - 1) // df * df + 1 if length > 0 else 0
            y_hat_out.append(y_hat[i, :valid_len])

        return {"generated": y_hat_out}