# File size: 7,988 Bytes
# 9ad5b1d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
import torch
import torch.nn as nn
import torch.nn.functional as F
from model.lightning.base_modules import BaseModule
from omegaconf import DictConfig
from typing import Any, Dict, Tuple
from utils import instantiate
import cv2
from PIL import Image
import numpy as np


class VQAutoEncoder(BaseModule):
    """VQ-VAE model: encoder -> nearest-codebook quantization -> decoder.

    The encoder and decoder are built from ``config.model.encoder`` /
    ``config.model.decoder`` via ``instantiate``; the codebook is an
    ``nn.Embedding`` with ``config.model.n_embedding`` entries of size
    ``config.model.latent_dim``.  Loss weights are read from ``config.loss``.
    """

    def __init__(
        self,
        config: DictConfig,
    ) -> None:
        super().__init__(config)

        self.config = config

        # Weights of the three loss terms combined in `_step`.
        self.l_w_recon = config.loss.l_w_recon
        self.l_w_embedding = config.loss.l_w_embedding
        self.l_w_commitment = config.loss.l_w_commitment
        self.mse_loss = nn.MSELoss()

    def _get_scheduler(self) -> Any:
        """Return ``None``; this model does not use a scheduler."""
        # this function is for diffusion model
        return None

    def configure_model(self) -> None:
        """Create the encoder, decoder and VQ codebook (only once).

        NOTE(review): assumes BaseModule is a LightningModule, where this hook
        can be invoked several times in one process (fit, then validate/test).
        Rebuilding the sub-modules on every call would silently discard the
        trained weights, so later calls are no-ops.
        """
        if getattr(self, "vq_embedding", None) is not None:
            return

        config = self.config
        self.encoder = instantiate(config.model.encoder)
        self.decoder = instantiate(config.model.decoder)

        # VQ Embedding (Vector Quantization) layer, uniformly initialised in
        # [-1 / latent_dim, 1 / latent_dim].
        self.vq_embedding = nn.Embedding(config.model.n_embedding, config.model.latent_dim)
        bound = 1.0 / config.model.latent_dim
        self.vq_embedding.weight.data.uniform_(-bound, bound)

    def configure_optimizers(self) -> Dict[str, Any]:
        """Build an AdamW optimizer over all trainable parameters.

        Hyper-parameters (lr, weight decay, betas, eps) come from
        ``config.optimizer``.
        """
        params_to_update = [p for p in self.parameters() if p.requires_grad]
        optimizer = torch.optim.AdamW(
            params_to_update,
            lr=self.config.optimizer.lr,
            weight_decay=self.config.optimizer.weight_decay,
            betas=(self.config.optimizer.adam_beta1, self.config.optimizer.adam_beta2),
            eps=self.config.optimizer.adam_epsilon,
        )

        return {"optimizer": optimizer}

    def encode(self, image: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Encode `image` and quantize the features against the codebook.

        Args:
            image: Input passed straight to the encoder.

        Returns:
            ``(ze, zq)``: the encoder output ``[B, C, H, W]`` and, with the
            same shape, the codebook vectors nearest (squared L2) to each
            spatial feature.
        """
        ze = self.encoder(image)
        B, C, H, W = ze.shape

        # The index search is not differentiable, so run it without autograd
        # bookkeeping; this avoids keeping a graph over the [B, K, C, H, W]
        # difference tensor.
        with torch.no_grad():
            codebook = self.vq_embedding.weight
            K = codebook.shape[0]
            diff = codebook.reshape(1, K, C, 1, 1) - ze.reshape(B, 1, C, H, W)
            distance = diff.pow(2).sum(dim=2)             # [B, K, H, W]
            nearest_neighbor = distance.argmin(dim=1)     # [B, H, W]

        # Looked up outside no_grad so gradients reach the codebook entries.
        zq = self.vq_embedding(nearest_neighbor).permute(0, 3, 1, 2)  # [B, H, W, C] -> [B, C, H, W]

        return ze, zq

    def decode(self, quantized_fea: torch.Tensor) -> torch.Tensor:
        """Run the decoder on (quantized) latent features."""
        return self.decoder(quantized_fea)

    def _step(self, batch, return_loss=True):
        """Shared train / val / inference step.

        Args:
            batch: Mapping with key ``'pixel_values_vid'``; per the original
                author's note this is a video batch ``[B, T, C, H, W]``.
            return_loss: If True, log the three loss terms and return the
                weighted total loss; otherwise return
                ``(reconstruction, flattened_input)``.
        """
        pixel_values_vid = batch['pixel_values_vid']
        # Fold every leading dim into the batch: [B, T, C, H, W] -> [B*T, C, H, W].
        # reshape (unlike view) also accepts non-contiguous input, and the
        # channel count is taken from the tensor instead of being fixed to 3.
        pixel_values_vid = pixel_values_vid.reshape(-1, *pixel_values_vid.shape[-3:])

        # Encoding (previously called as `self.encode(self, ...)`, which raised
        # TypeError because `self` was passed twice).
        hidden_fea, quantized_fea = self.encode(pixel_values_vid)

        # Straight-through estimator: the forward value equals quantized_fea,
        # while the gradient flows to the encoder as if no quantization happened.
        decoder_input = hidden_fea + (quantized_fea - hidden_fea).detach()

        # Decoding
        x_hat = self.decode(decoder_input)

        if not return_loss:
            return x_hat, pixel_values_vid

        # Reconstruction Loss
        l_reconstruct = self.mse_loss(x_hat, pixel_values_vid)

        # Embedding Loss: moves codebook vectors toward the (frozen) encoder output.
        l_embedding = self.mse_loss(hidden_fea.detach(), quantized_fea)

        # Commitment Loss: moves the encoder output toward the (frozen) codebook vectors.
        l_commitment = self.mse_loss(hidden_fea, quantized_fea.detach())

        # Total Loss; the configured reconstruction weight is now applied too
        # (it was read in __init__ but never used).
        total_loss = (
            self.l_w_recon * l_reconstruct
            + self.l_w_embedding * l_embedding
            + self.l_w_commitment * l_commitment
        )

        self.log('recon_loss', l_reconstruct, on_step=True, on_epoch=True, prog_bar=True)
        self.log('emb_loss', l_embedding, on_step=True, on_epoch=True, prog_bar=True)
        self.log('commit_loss', l_commitment, on_step=True, on_epoch=True, prog_bar=True)

        return total_loss

    def training_step(self, batch):
        """Return the total loss for one training batch."""
        return self._step(batch)

    def validation_step(self, batch):
        """Return the total loss for one validation batch."""
        return self._step(batch)

    def forward(self, batch):
        """Reconstruct a batch; returns ``(prediction, ground_truth)``."""
        x_pred, x_gt = self._step(batch, return_loss=False)

        return x_pred, x_gt
    
    
class ResBlock(nn.Module):
    """Residual block: ReLU -> 3x3 conv -> [BatchNorm] -> ReLU -> 1x1 conv, plus skip.

    Args:
        in_channels: Channels of the input feature map.
        out_channels: Channels produced by the final 1x1 convolution.  The
            skip connection adds the input to this output, so the two must be
            broadcast-compatible (normally ``in_channels == out_channels``).
        mid_channels: Channels between the two convolutions; defaults to
            ``out_channels``.
        bn: If True, insert a BatchNorm2d after the first convolution.
    """

    def __init__(self, in_channels, out_channels, mid_channels=None, bn=False):
        super(ResBlock, self).__init__()

        if mid_channels is None:
            mid_channels = out_channels

        layers = [
            nn.ReLU(),
            nn.Conv2d(in_channels, mid_channels,
                      kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.Conv2d(mid_channels, out_channels,
                      kernel_size=1, stride=1, padding=0),
        ]
        if bn:
            # The norm sits directly after the first conv, so it must be sized
            # to that conv's output (mid_channels), not out_channels.
            layers.insert(2, nn.BatchNorm2d(mid_channels))
        self.convs = nn.Sequential(*layers)

    def forward(self, x):
        """Return ``x + convs(x)``."""
        return x + self.convs(x)


class ResidualBlock(nn.Module):
    """Channel-preserving residual unit.

    Computes ``x + conv2(relu(conv1(relu(x))))`` where ``conv1`` is a 3x3
    convolution (stride 1, padding 1) and ``conv2`` is a 1x1 convolution,
    both mapping ``dim`` channels to ``dim`` channels.
    """

    def __init__(self, dim):
        super().__init__()
        self.relu = nn.ReLU()
        self.conv1 = nn.Conv2d(dim, dim, kernel_size=3, stride=1, padding=1)
        self.conv2 = nn.Conv2d(dim, dim, kernel_size=1)

    def forward(self, x):
        """Apply the two-conv branch and add it back onto the input."""
        branch = self.conv1(self.relu(x))
        branch = self.conv2(self.relu(branch))
        return x + branch


class Encoder(nn.Module):
    """Convolutional encoder for 3-channel images.

    Five 4x4 stride-2 convolutions (each followed by ReLU), then a 3x3
    stride-1 convolution and two ``ResidualBlock`` layers.  Every layer after
    the first works on ``output_channels`` channels.
    """

    def __init__(self, output_channels=512):
        super(Encoder, self).__init__()

        stages = []
        in_ch = 3  # image input channels
        # Five stride-2 stages; only the first one changes the channel count.
        for _ in range(5):
            stages.append(
                nn.Conv2d(in_ch, output_channels, kernel_size=4, stride=2, padding=1)
            )
            stages.append(nn.ReLU())
            in_ch = output_channels

        # Size-preserving conv followed by the residual stack.
        stages.append(
            nn.Conv2d(output_channels, output_channels, kernel_size=3, stride=1, padding=1)
        )
        stages.append(ResidualBlock(output_channels))
        stages.append(ResidualBlock(output_channels))

        self.block = nn.Sequential(*stages)

    def forward(self, x):
        """Return the encoded feature map for `x`."""
        return self.block(x)
    

class Decoder(nn.Module):
    """Convolutional decoder producing a 3-channel output.

    A 3x3 stride-1 convolution and two ``ResidualBlock`` layers, followed by
    five 4x4 stride-2 transposed convolutions; the first four keep
    ``input_dim`` channels and are followed by ReLU, the last one outputs
    3 channels with no activation.
    """

    def __init__(self, input_dim=512):
        super(Decoder, self).__init__()

        self.fea_map_size = 16

        layers = [
            nn.Conv2d(input_dim, input_dim, kernel_size=3, stride=1, padding=1),
            ResidualBlock(input_dim),
            ResidualBlock(input_dim),
        ]
        # Four upsampling stages that keep the channel count.
        for _ in range(4):
            layers.append(
                nn.ConvTranspose2d(input_dim, input_dim, kernel_size=4, stride=2, padding=1)
            )
            layers.append(nn.ReLU())
        # Output layer: last upsampling step, down to 3 channels.
        layers.append(
            nn.ConvTranspose2d(input_dim, 3, kernel_size=4, stride=2, padding=1)
        )

        self.block = nn.Sequential(*layers)

    def forward(self, x):
        """Return the decoded output for latent features `x`."""
        return self.block(x)