File size: 4,122 Bytes
22cfe7b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
import torch
import torch.nn as nn
import torch.nn.functional as F
from functools import reduce

# Default device name: use the GPU when PyTorch can see one, otherwise the CPU.
# NOTE(review): the classes in this module never read this; presumably callers
# use it to place models/tensors — confirm.
device = 'cpu'
if torch.cuda.is_available():
    device = 'cuda'

class TransformerVQVAE(nn.Module):
    """VQ-VAE with a conv encoder/decoder and a Transformer over the latent grid.

    The encoder applies three stride-2 convolutions, giving a ``codebook_dim``-channel
    map of size ``ceil(H/8) x ceil(W/8)``. Each spatial position of that map is
    treated as one token and refined by a stack of ``nn.TransformerEncoderLayer``s.
    The result is quantized by ``self.vq`` and decoded by three stride-2 transposed
    convolutions followed by ``Tanh``.

    Args:
        input_shape: ``(channels, height, width)`` of one sample. Only the channel
            count (index 0) is used to build the layers; the spatial size is taken
            from the actual input at call time.
        num_codebook_vectors: number of codebook entries.
        codebook_dim: channel width of the latent map and the Transformer ``d_model``.
            It must be divisible by ``num_heads``.
        num_layers: number of Transformer encoder layers.
        num_heads: attention heads per Transformer layer.
        hidden_dim: feed-forward width inside each Transformer layer.
    """

    def __init__(self, input_shape=(1, 128, 216), num_codebook_vectors=512,
                 codebook_dim=64, num_layers=4, num_heads=8, hidden_dim=256):
        super().__init__()

        self.input_shape = input_shape
        self.num_codebook_vectors = num_codebook_vectors
        self.codebook_dim = codebook_dim
        # Derived from input_shape instead of being hard-coded to 1, so that
        # multi-channel inputs work. The default (1, ...) builds identical layers.
        in_channels = input_shape[0]

        # Encoder: each conv (k=3, s=2, p=1) maps a spatial size s -> ceil(s / 2).
        self.encoder = nn.Sequential(
            nn.Conv2d(in_channels, 32, kernel_size=3, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv2d(32, 64, kernel_size=3, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv2d(64, codebook_dim, kernel_size=3, stride=2, padding=1)
        )

        # Transformer over the flattened latent grid. batch_first is left at its
        # default (False), so encode() feeds sequences as (seq_len, batch, d_model).
        # NOTE(review): no positional encoding is added, so self-attention sees the
        # latent positions as an unordered set — confirm this is intended.
        encoder_layer = nn.TransformerEncoderLayer(d_model=codebook_dim, nhead=num_heads, dim_feedforward=hidden_dim)
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)

        # Vector quantizer over codebook_dim-channel latents.
        self.vq = VectorQuantizer(num_codebook_vectors, codebook_dim)

        # Decoder: each transposed conv (k=3, s=2, p=1, output_padding=1) exactly
        # doubles the spatial size. Tanh bounds the output to (-1, 1).
        # NOTE(review): this assumes inputs are scaled to [-1, 1] — confirm.
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(codebook_dim, 64, kernel_size=3, stride=2, padding=1, output_padding=1),
            nn.ReLU(),
            nn.ConvTranspose2d(64, 32, kernel_size=3, stride=2, padding=1, output_padding=1),
            nn.ReLU(),
            nn.ConvTranspose2d(32, in_channels, kernel_size=3, stride=2, padding=1, output_padding=1),
            nn.Tanh()
        )

    def encode(self, x):
        """Encode ``x`` of shape (B, C, H, W) into a latent map (B, codebook_dim, h, w).

        ``h`` and ``w`` are the encoder's output sizes, i.e. ``ceil(H/8)`` and ``ceil(W/8)``.
        """
        z = self.encoder(x)
        batch, channels, height, width = z.shape

        # (B, C, h, w) -> (h*w, B, C): one token per spatial position, sequence-first.
        z = z.permute(2, 3, 0, 1).reshape(height * width, batch, channels)

        z = self.transformer(z)

        # Back to (B, C, h, w). Use the encoder's real output size. The old
        # x.shape // 8 disagrees with it, and crashes the reshape, whenever H or W
        # is not a multiple of 8.
        z = z.reshape(height, width, batch, channels).permute(2, 3, 0, 1).contiguous()

        return z

    def decode(self, z):
        """Decode a latent map (B, codebook_dim, h, w) to (B, C, 8*h, 8*w)."""
        return self.decoder(z)

    def forward(self, x):
        """Encode, quantize and decode ``x`` of shape (B, C, H, W).

        Returns:
            ``(x_recon, indices, vq_loss)``. ``x_recon`` has the same spatial size
            as ``x``. ``indices`` and ``vq_loss`` are passed through unchanged from
            ``self.vq``.
        """
        z = self.encode(x)
        z_q, indices, vq_loss = self.vq(z)
        x_recon = self.decode(z_q)
        # The decoder emits 8*ceil(H/8) x 8*ceil(W/8), which exceeds H x W by up to
        # 7 pixels when H or W is not a multiple of 8. Crop so the reconstruction
        # lines up with x. This is a no-op for multiples of 8.
        x_recon = x_recon[..., :x.shape[-2], :x.shape[-1]]
        return x_recon, indices, vq_loss

class VectorQuantizer(nn.Module):
    """Nearest-neighbour vector quantizer with a straight-through gradient.

    Each ``embedding_dim``-channel vector of the input map is replaced by the
    closest codebook row, measured by squared Euclidean distance, from a learned
    codebook of ``num_embeddings`` entries.

    Args:
        num_embeddings: number of codebook entries K.
        embedding_dim: size D of each codebook entry. It must equal the channel
            dimension (dim 1) of the tensor passed to ``forward``.
        commitment_cost: weight applied to the commitment term of ``vq_loss``.
    """

    def __init__(self, num_embeddings, embedding_dim, commitment_cost=0.25):
        super().__init__()
        self.num_embeddings = num_embeddings
        self.embedding_dim = embedding_dim
        self.commitment_cost = commitment_cost
        self.embedding = nn.Embedding(num_embeddings, embedding_dim)
        # Small symmetric init of the codebook, done in place outside autograd.
        nn.init.uniform_(self.embedding.weight, -1 / num_embeddings, 1 / num_embeddings)

    def forward(self, z):
        """Quantize ``z`` of shape (B, D, H, W).

        Returns:
            z_q: tensor of the same shape as ``z``. Its forward value is the
                selected codebook rows; its gradient passes straight through to ``z``.
            indices: (B*H*W, 1) int64 codebook index per position, ordered
                (batch, height, width) row-major.
            vq_loss: scalar ``codebook_loss + commitment_cost * commitment_loss``.

        Raises:
            ValueError: if ``z``'s channel dimension (dim 1) is not ``embedding_dim``.
        """
        # (B, D, H, W) -> (B, H, W, D) so that each row of the flattened view is one latent vector.
        z = z.permute(0, 2, 3, 1).contiguous()
        if z.shape[-1] != self.embedding_dim:
            # Without this check, view(-1, D) silently regroups the elements
            # whenever numel happens to be divisible by D.
            raise ValueError(
                f"expected {self.embedding_dim} channels in dim 1, got {z.shape[-1]}"
            )
        z_flattened = z.view(-1, self.embedding_dim)
        codebook = self.embedding.weight

        # Squared distances (z - e)^2 = z^2 + e^2 - 2 z.e, shape (N, K).
        distances = (
            torch.sum(z_flattened ** 2, dim=1, keepdim=True)
            + torch.sum(codebook ** 2, dim=1)
            - 2 * torch.matmul(z_flattened, codebook.t())
        )

        # Index of the closest codebook entry for each latent vector, shape (N, 1).
        min_encoding_indices = torch.argmin(distances, dim=1).unsqueeze(1)

        # Look up the selected rows directly. This selects the same rows as the old
        # one-hot @ codebook product, but without the (N, K) buffer, and it keeps the
        # codebook's dtype (the old float32 one-hot broke .double()/.half() models).
        z_q = self.embedding(min_encoding_indices.squeeze(1)).view(z.shape)

        # Commitment term pulls the encoder output toward the (frozen) codes.
        # Codebook term pulls the codes toward the (frozen) encoder output.
        commitment_loss = F.mse_loss(z_q.detach(), z)
        codebook_loss = F.mse_loss(z_q, z.detach())
        vq_loss = codebook_loss + self.commitment_cost * commitment_loss

        # Straight-through estimator: forward uses z_q, backward is the identity into z.
        z_q = z + (z_q - z).detach()

        # (B, H, W, D) -> (B, D, H, W) to match the input layout.
        z_q = z_q.permute(0, 3, 1, 2).contiguous()

        return z_q, min_encoding_indices, vq_loss