File size: 3,618 Bytes
9cf390a
 
 
 
 
38e5619
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9cf390a
 
38e5619
9cf390a
 
38e5619
 
 
 
 
 
 
 
 
9cf390a
 
 
38e5619
9cf390a
38e5619
 
 
 
9cf390a
38e5619
9cf390a
 
 
 
 
 
 
 
 
 
 
 
38e5619
9cf390a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
import torch
import torch.nn as nn
import torch.nn.functional as F
from basicsr.utils.registry import ARCH_REGISTRY

class VectorQuantizer(nn.Module):
    """Nearest-neighbour vector quantizer with a straight-through estimator.

    Holds a learnable codebook of ``n_e`` entries of dimension ``e_dim``.
    ``forward`` snaps every spatial feature vector of a (B, e_dim, H, W)
    input to its nearest codebook entry (squared-L2 distance) and returns
    the VQ-VAE codebook + commitment loss.

    Fix over the previous version: the old ``forward`` indexed the codebook
    directly with ``z`` (``self.embedding[z]``), which fails for the float
    feature maps produced by the encoder and compared embeddings against
    raw indices in the loss. It now performs a proper nearest-neighbour
    lookup.
    """

    def __init__(self, n_e, e_dim):
        super().__init__()
        self.n_e = n_e      # number of codebook entries
        self.e_dim = e_dim  # dimensionality of each entry
        self.embedding = nn.Parameter(torch.randn(n_e, e_dim))

    def get_codebook_feat(self, indices, shape):
        """Look up codebook entries by index and return a (B, C, H, W) map.

        Args:
            indices: LongTensor of flat codebook indices.
            shape: 4-sequence (B, H, W, C) used for the intermediate view
                before permuting channels to the front.
        """
        feat = self.embedding[indices]
        feat = feat.view(shape[0], shape[1], shape[2], shape[3]).permute(0, 3, 1, 2)
        return feat

    def forward(self, z):
        """Quantize a (B, e_dim, H, W) feature map.

        Returns:
            z_q: quantized tensor, same shape as ``z``; its forward value is
                the codebook entries, while gradients flow straight through
                to ``z``.
            loss: codebook loss + commitment loss (equal weighting).
            indices: (B, H, W) LongTensor of the selected codebook entries.
        """
        # (B, C, H, W) -> (B, H, W, C) -> (B*H*W, C): one row per feature vector.
        z_perm = z.permute(0, 2, 3, 1).contiguous()
        z_flat = z_perm.view(-1, self.e_dim)
        # Squared L2 distance to every entry: ||z||^2 - 2 z.e + ||e||^2.
        dist = (z_flat.pow(2).sum(dim=1, keepdim=True)
                - 2 * z_flat @ self.embedding.t()
                + self.embedding.pow(2).sum(dim=1))
        indices = torch.argmin(dist, dim=1)
        z_q = self.embedding[indices].view(z_perm.shape).permute(0, 3, 1, 2)
        # Codebook term pulls embeddings toward encoder outputs; commitment
        # term pulls encoder outputs toward their chosen embeddings.
        loss = torch.mean((z_q.detach() - z) ** 2) + torch.mean((z_q - z.detach()) ** 2)
        # Straight-through: forward value is z_q, gradient bypasses the argmin.
        z_q = z + (z_q - z).detach()
        return z_q, loss, indices.view(z.shape[0], z.shape[2], z.shape[3])

class ResBlock(nn.Module):
    """Two-convolution residual block with optional 2x spatial downsampling.

    The skip connection is projected with a 1x1 conv whenever the channel
    count changes, and average-pooled whenever the main path is downsampled,
    so both branches always agree in shape before the addition.
    """

    def __init__(self, in_channel, out_channel, downsample=False, downsample_method='nearest'):
        super().__init__()
        self.downsample = downsample
        self.downsample_method = downsample_method
        layers = [
            nn.Conv2d(in_channel, out_channel, 3, 1, 1),
            nn.ReLU(True),
            nn.Conv2d(out_channel, out_channel, 3, 1, 1),
        ]
        self.conv = nn.Sequential(*layers)
        # 1x1 projection only when the channel count changes; otherwise pass-through.
        if in_channel == out_channel:
            self.proj = nn.Identity()
        else:
            self.proj = nn.Conv2d(in_channel, out_channel, 1, 1, 0)
        # Average pooling keeps the skip path aligned with the halved main path.
        if downsample:
            self.downsample_identity = nn.AvgPool2d(kernel_size=2, stride=2)
        else:
            self.downsample_identity = nn.Identity()

    def forward(self, x):
        residual = self.conv(x)
        if self.downsample:
            residual = F.interpolate(residual, scale_factor=0.5, mode=self.downsample_method)
        # Bring the skip branch to the same channel count and resolution.
        skip = self.downsample_identity(self.proj(x))
        return F.relu(residual + skip)

@ARCH_REGISTRY.register()
class VQAutoEncoder(nn.Module):
    """Convolutional autoencoder with a vector-quantized latent bottleneck.

    The encoder stacks ``downsample_steps`` ResBlocks (widths taken from
    ``channel * down_factor[i]``) followed by a 3x3 conv into ``z_channels``;
    the decoder mirrors it back out to ``in_channel``. The latent map is
    quantized against a codebook of ``codebook_size`` entries.
    """

    def __init__(self, in_channel, channel, down_factor, downsample_method, downsample_steps, z_channels, codebook_size):
        super().__init__()
        self.encoder = nn.ModuleList()
        self.decoder = nn.ModuleList()

        # --- encoder ---
        width = in_channel
        for step in range(downsample_steps):
            target = channel * down_factor[step]
            # NOTE(review): downsampling is gated on len(down_factor) - 1, not
            # downsample_steps - 1 — confirm this asymmetry is intended.
            self.encoder.append(ResBlock(
                width, target,
                downsample=step < len(down_factor) - 1,
                downsample_method=downsample_method,
            ))
            width = target
        self.encoder.append(nn.Conv2d(width, z_channels, 3, 1, 1))

        self.quantize = VectorQuantizer(codebook_size, z_channels)

        # --- decoder (mirror of the encoder) ---
        self.decoder.append(nn.Conv2d(z_channels, width, 3, 1, 1))
        for step in reversed(range(downsample_steps)):
            target = channel * down_factor[step]
            self.decoder.append(ResBlock(width, target, downsample=False))
            if step > 0:
                # NOTE(review): the channel multiplier down_factor[step] is also
                # used as the spatial upsample factor here — verify this matches
                # the fixed 0.5 downsampling done inside the encoder ResBlocks.
                self.decoder.append(nn.Upsample(scale_factor=down_factor[step], mode=downsample_method))
            width = target
        self.decoder.append(nn.Conv2d(width, in_channel, 3, 1, 1))

    def encode(self, x):
        """Apply every encoder stage in order; returns the latent map."""
        for stage in self.encoder:
            x = stage(x)
        return x

    def decode(self, z):
        """Apply every decoder stage in order; returns the reconstruction."""
        for stage in self.decoder:
            z = stage(z)
        return z

    def forward(self, x):
        """Encode, quantize, decode; returns (reconstruction, quant_loss)."""
        latent = self.encode(x)
        quant, quant_loss, _ = self.quantize(latent)
        return self.decode(quant), quant_loss