File size: 7,289 Bytes
e86746e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
import numpy as np
import torch
import torch.nn as nn

from .tools.wan_vae_1d import WanVAE_


class VAEWanModel(nn.Module):
    """Wrapper around ``WanVAE_`` that adds normalisation, losses and helpers.

    Inputs and outputs of the public methods are laid out as
    ``(batch, time, channels)``; the wrapped model is fed
    ``(batch, channels, time)``.  Features are normalised with the ``mean`` and
    ``std`` buffers before encoding and de-normalised after decoding.

    Args:
        input_dim: Number of feature channels.
        mean_path: Optional path to a NumPy file holding the feature mean.
            When ``None`` a zero mean of size ``input_dim`` is used.
        std_path: Optional path to a NumPy file holding the feature std.
            When ``None`` a unit std of size ``input_dim`` is used.
        z_dim: Latent channel count, forwarded to ``WanVAE_``.
        dim: Forwarded to ``WanVAE_``.
        dec_dim: Forwarded to ``WanVAE_``.
        num_res_blocks: Forwarded to ``WanVAE_``.
        dropout: Forwarded to ``WanVAE_``.
        dim_mult: Forwarded to ``WanVAE_``; defaults to ``[1, 1, 1]``.
        temperal_downsample: One flag per stage, forwarded to ``WanVAE_``;
            every truthy flag doubles ``downsample_factor``.  Defaults to
            ``[True, True]``.
        vel_window: ``[start, end]`` channel slice used for the extra
            "velocity" reconstruction term.  Defaults to ``[0, 0]`` (empty
            window, i.e. the term is zero).
        **kwargs: Optional loss weights ``LAMBDA_FEATURE`` (default 1.0),
            ``LAMBDA_VELOCITY`` (default 0.5) and ``LAMBDA_KL``
            (default 10e-6); other keys are ignored.
    """

    def __init__(
        self,
        input_dim,
        mean_path=None,
        std_path=None,
        z_dim=256,
        dim=160,
        dec_dim=512,
        num_res_blocks=1,
        dropout=0.0,
        dim_mult=None,
        temperal_downsample=None,
        vel_window=None,
        **kwargs,
    ):
        super().__init__()

        self.mean_path = mean_path
        self.std_path = std_path
        self.input_dim = input_dim
        self.z_dim = z_dim
        self.dim = dim
        self.dec_dim = dec_dim
        self.num_res_blocks = num_res_blocks
        self.dropout = dropout
        # Store per-instance copies so that neither the defaults nor a
        # caller-owned list can be mutated through this module's attributes.
        self.dim_mult = [1, 1, 1] if dim_mult is None else list(dim_mult)
        self.temperal_downsample = (
            [True, True]
            if temperal_downsample is None
            else list(temperal_downsample)
        )
        self.vel_window = [0, 0] if vel_window is None else list(vel_window)
        self.RECONS_LOSS = nn.SmoothL1Loss()
        self.LAMBDA_FEATURE = kwargs.get("LAMBDA_FEATURE", 1.0)
        self.LAMBDA_VELOCITY = kwargs.get("LAMBDA_VELOCITY", 0.5)
        self.LAMBDA_KL = kwargs.get("LAMBDA_KL", 10e-6)  # 10e-6 == 1e-5

        # Normalisation statistics are buffers so they follow the module
        # across devices and are saved in the state dict.
        if self.mean_path is not None:
            mean = torch.from_numpy(np.load(self.mean_path)).float()
        else:
            mean = torch.zeros(input_dim)
        self.register_buffer("mean", mean)

        if self.std_path is not None:
            std = torch.from_numpy(np.load(self.std_path)).float()
        else:
            std = torch.ones(input_dim)
        self.register_buffer("std", std)

        self.model = WanVAE_(
            input_dim=self.input_dim,
            dim=self.dim,
            dec_dim=self.dec_dim,
            z_dim=self.z_dim,
            dim_mult=self.dim_mult,
            num_res_blocks=self.num_res_blocks,
            temperal_downsample=self.temperal_downsample,
            dropout=self.dropout,
        )

        # Each enabled stage contributes a factor of two.
        self.downsample_factor = 2 ** sum(
            1 for flag in self.temperal_downsample if flag
        )

    @staticmethod
    def _length_mask(lengths, max_len):
        """Return a ``(B, max_len)`` bool mask that is True where ``t < lengths[b]``."""
        steps = torch.arange(max_len, device=lengths.device)
        return steps.unsqueeze(0) < lengths.unsqueeze(1)

    def preprocess(self, x):
        """Swap the last two axes: (bs, T, C) -> (bs, C, T)."""
        return x.permute(0, 2, 1)

    def postprocess(self, x):
        """Swap the last two axes: (bs, C, T) -> (bs, T, C)."""
        return x.permute(0, 2, 1)

    def forward(self, x):
        """Compute the training losses for one padded batch.

        Args:
            x: Mapping with ``"feature"`` (a ``(bs, T, input_dim)`` tensor) and
                ``"feature_length"`` (one valid length per sample; a tensor or
                a sequence of ints).

        Returns:
            Dict with scalar tensors ``"total"``, ``"recons"``, ``"velocity"``
            and ``"kl"``.
        """
        features = x["feature"]
        feature_length = x["feature_length"]
        features = (features - self.mean) / self.std

        # Boolean padding mask built in one vectorised comparison.
        batch_size, seq_len = features.shape[:2]
        lengths = torch.as_tensor(
            feature_length, device=features.device
        ).reshape(-1)[:batch_size]
        mask = self._length_mask(lengths, seq_len)

        x_in = self.preprocess(features)  # (bs, input_dim, T)
        mu, log_var = self.model.encode(
            x_in, scale=[0, 1], return_dist=True
        )  # (bs, z_dim, T_latent)
        z = self.model.reparameterize(mu, log_var)
        x_decoder = self.model.decode(z, scale=[0, 1])  # (bs, input_dim, T)
        x_out = self.postprocess(x_decoder)  # (bs, T, input_dim)

        # The decoder may return a different number of frames than it was
        # given; compare only the overlapping prefix.
        if x_out.size(1) != features.size(1):
            min_len = min(x_out.size(1), features.size(1))
            x_out = x_out[:, :min_len, :]
            features = features[:, :min_len, :]
            mask = mask[:, :min_len]

        mask_expanded = mask.unsqueeze(-1)
        x_out_masked = x_out * mask_expanded
        features_masked = features * mask_expanded
        # NOTE: the mean is taken over every element, padded positions
        # included (they are zero on both sides and add nothing to the sum).
        loss_recons = self.RECONS_LOSS(x_out_masked, features_masked)

        vel_start = self.vel_window[0]
        vel_end = self.vel_window[1]
        vel_pred = x_out_masked[..., vel_start:vel_end]
        vel_target = features_masked[..., vel_start:vel_end]
        if vel_pred.numel() == 0:
            # A mean-reduced loss over zero elements is NaN and would poison
            # the total; an empty window (the default) means "no velocity term".
            loss_vel = x_out.new_zeros(())
        else:
            loss_vel = self.RECONS_LOSS(vel_pred, vel_target)

        # KL(N(mu, sigma) || N(0, 1)) = -0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2)
        # log_var = log(sigma^2), so it is used directly.

        # Latent-space mask: ceil(length / downsample_factor) valid steps.
        latent_lengths = torch.div(
            lengths + self.downsample_factor - 1,
            self.downsample_factor,
            rounding_mode="floor",
        )
        mask_downsampled = self._length_mask(latent_lengths, mu.size(2))
        mask_latent = mask_downsampled.unsqueeze(1)  # (B, 1, T_latent)

        kl_per_element = -0.5 * (1 + log_var - mu.pow(2) - log_var.exp())
        kl_masked = kl_per_element * mask_latent
        # Normalise by valid timesteps * latent_dim; the clamp avoids 0 / 0
        # when no sample in the batch has a valid timestep.
        valid_elements = (mask_downsampled.sum() * mu.size(1)).clamp(min=1)
        kl_loss = kl_masked.sum() / valid_elements

        total_loss = (
            self.LAMBDA_FEATURE * loss_recons
            + self.LAMBDA_VELOCITY * loss_vel
            + self.LAMBDA_KL * kl_loss
        )

        return {
            "total": total_loss,
            "recons": loss_recons,
            "velocity": loss_vel,
            "kl": kl_loss,
        }

    def encode(self, x):
        """Normalise ``x`` (bs, T, input_dim) and return the encoder output as (bs, T, z_dim)."""
        x = (x - self.mean) / self.std
        x_in = self.preprocess(x)  # (bs, T, input_dim) -> (bs, input_dim, T)
        mu = self.model.encode(x_in, scale=[0, 1])  # (bs, z_dim, T)
        return self.postprocess(mu)  # (bs, T, z_dim)

    def decode(self, mu):
        """Decode latents (bs, T, z_dim) and return de-normalised features (bs, T, input_dim)."""
        mu_in = self.preprocess(mu)  # (bs, T, z_dim) -> (bs, z_dim, T)
        x_decoder = self.model.decode(mu_in, scale=[0, 1])  # (bs, input_dim, T)
        x_out = self.postprocess(x_decoder)  # (bs, T, input_dim)
        return x_out * self.std + self.mean

    @torch.no_grad()
    def stream_encode(self, x, first_chunk=True):
        """Gradient-free chunked encode; ``first_chunk`` is passed to the wrapped model."""
        x = (x - self.mean) / self.std
        x_in = self.preprocess(x)  # (bs, input_dim, T)
        mu = self.model.stream_encode(x_in, first_chunk=first_chunk, scale=[0, 1])
        return self.postprocess(mu)  # (bs, T, z_dim)

    @torch.no_grad()
    def stream_decode(self, mu, first_chunk=True):
        """Gradient-free chunked decode; ``first_chunk`` is passed to the wrapped model."""
        mu_in = self.preprocess(mu)  # (bs, z_dim, T)
        x_decoder = self.model.stream_decode(
            mu_in, first_chunk=first_chunk, scale=[0, 1]
        )
        x_out = self.postprocess(x_decoder)  # (bs, T, input_dim)
        return x_out * self.std + self.mean

    def clear_cache(self):
        """Delegate to the wrapped model's ``clear_cache``."""
        self.model.clear_cache()

    def generate(self, x):
        """Reconstruct a padded batch and trim each sample to its valid length.

        Args:
            x: Mapping with ``"feature"`` (bs, T, input_dim) and
                ``"feature_length"`` (one valid length per sample).

        Returns:
            Dict with ``"generated"``: a list holding one
            ``(valid_len, input_dim)`` tensor per sample.
        """
        features = x["feature"]
        feature_length = x["feature_length"]
        y_hat = self.decode(self.encode(features))

        factor = self.downsample_factor
        y_hat_out = []
        for i in range(y_hat.shape[0]):
            # Drop the padding: keep the largest length of the form
            # k * factor + 1 that does not exceed the sample's length.
            # NOTE(review): presumably this matches the decoder's frame
            # count for that length - confirm against WanVAE_.
            valid_len = (int(feature_length[i]) - 1) // factor * factor + 1
            y_hat_out.append(y_hat[i, :valid_len, :])

        return {"generated": y_hat_out}