0xZohar commited on
Commit
7ca8123
·
verified ·
1 Parent(s): 43f661a

Add code/cube3d/model/autoencoder/spherical_vq.py

Browse files
code/cube3d/model/autoencoder/spherical_vq.py ADDED
@@ -0,0 +1,159 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import sys
2
+ from typing import Literal, Optional
3
+
4
+ import torch
5
+ import torch.nn as nn
6
+ import torch.nn.functional as F
7
+
8
+ from cube3d.model.transformers.norm import RMSNorm
9
+
10
+
11
class SphericalVectorQuantizer(nn.Module):
    def __init__(
        self,
        embed_dim: int,
        num_codes: int,
        width: Optional[int] = None,
        codebook_regularization: Literal["batch_norm", "kl"] = "batch_norm",
    ):
        """
        Initializes the SphericalVectorQuantizer module.

        Args:
            embed_dim (int): Dimensionality of each codebook embedding.
            num_codes (int): Number of codes in the codebook.
            width (Optional[int]): Width of the input features. When it differs
                from ``embed_dim``, linear projections are inserted on the way
                into and out of the quantizer. Defaults to ``embed_dim``.
            codebook_regularization (Literal["batch_norm", "kl"]): How codebook
                weights are regularized before use. ``"batch_norm"`` applies a
                non-tracking ``BatchNorm1d``; any other value applies a
                learnable per-channel affine transform.
        """
        super().__init__()

        self.num_codes = num_codes

        self.codebook = nn.Embedding(num_codes, embed_dim)
        self.codebook.weight.data.uniform_(-1.0 / num_codes, 1.0 / num_codes)

        width = width or embed_dim
        if width != embed_dim:
            self.c_in = nn.Linear(width, embed_dim)
            self.c_x = nn.Linear(width, embed_dim)  # shortcut
            self.c_out = nn.Linear(embed_dim, width)
        else:
            self.c_in = self.c_out = self.c_x = nn.Identity()

        self.norm = RMSNorm(embed_dim, elementwise_affine=False)
        self.cb_reg = codebook_regularization
        if self.cb_reg == "batch_norm":
            # track_running_stats=False: always normalize with batch statistics.
            self.cb_norm = nn.BatchNorm1d(embed_dim, track_running_stats=False)
        else:
            # NOTE(review): a lambda attribute is not a registered submodule and
            # will not survive pickling/scripting; the two Parameters themselves
            # are registered and trained as usual.
            self.cb_weight = nn.Parameter(torch.ones([embed_dim]))
            self.cb_bias = nn.Parameter(torch.zeros([embed_dim]))
            self.cb_norm = lambda x: x.mul(self.cb_weight).add_(self.cb_bias)

    def get_codebook(self):
        """
        Retrieves the normalized codebook weights.

        Applies the configured codebook regularization followed by RMS
        normalization, so entries are properly scaled before lookup/search.

        Returns:
            torch.Tensor: The normalized codebook weights of shape
            (num_codes, embed_dim).
        """
        return self.norm(self.cb_norm(self.codebook.weight))

    @torch.no_grad()
    def lookup_codebook(self, q: torch.Tensor):
        """
        Looks up codebook entries for the given indices and projects them out.

        Args:
            q (torch.Tensor): Integer tensor of codebook indices.

        Returns:
            torch.Tensor: The output-projected embeddings for ``q``.
        """
        # normalize codebook, then gather entries for the requested indices
        z_q = F.embedding(q, self.get_codebook())
        z_q = self.c_out(z_q)
        return z_q

    @torch.no_grad()
    def lookup_codebook_latents(self, q: torch.Tensor):
        """
        Retrieves the latent codebook vectors for the given indices, without
        the output projection.

        Args:
            q (torch.Tensor): Integer tensor of codebook indices.

        Returns:
            torch.Tensor: The normalized codebook latents for ``q``; the shape
            follows ``q`` with an extra trailing ``embed_dim`` dimension.
        """
        # normalize codebook, then gather entries for the requested indices
        z_q = F.embedding(q, self.get_codebook())
        return z_q

    def quantize(self, z: torch.Tensor):
        """
        Quantizes the latent codes ``z`` against the codebook.

        Args:
            z (torch.Tensor): Latents of shape (B, ..., F).

        Returns:
            tuple[torch.Tensor, dict]: The quantized latents with the same
            leading shape as ``z``, and a dict with ``"z"`` (detached input)
            and ``"q"`` (selected code indices).
        """
        # normalize codebook
        codebook = self.get_codebook()
        # the process of finding the nearest code is non-differentiable
        with torch.no_grad():
            # flatten z to (num_elements, F) for the pairwise distance search
            z_flat = z.view(-1, z.shape[-1])
            # calculate distances and pick the closest code per element
            d = torch.cdist(z_flat, codebook)
            q = torch.argmin(d, dim=1)  # num_ele

        z_q = codebook[q, :].reshape(*z.shape[:-1], -1)
        q = q.view(*z.shape[:-1])

        return z_q, {"z": z.detach(), "q": q}

    def straight_through_approximation(self, z, z_q):
        """passes gradient from z_q to z (straight-through estimator)"""
        z_q = z + (z_q - z).detach()
        return z_q

    def forward(self, z: torch.Tensor):
        """
        Forward pass of the spherical vector quantizer.

        Args:
            z (torch.Tensor): Input tensor of shape (batch_size, ..., feature_dim).

        Returns:
            Tuple[torch.Tensor, Dict[str, Any]]:
                - z_q (torch.Tensor): The quantized output after the
                  straight-through approximation and output projection.
                - ret_dict (Dict[str, Any]): Additional information:
                    - "z" (torch.Tensor): Detached normalized encoder latents.
                    - "q" (torch.Tensor): Indices of the selected codes.
                    - "z_q" (torch.Tensor): Detached quantized latents
                      (before the straight-through pass and output projection).
        """
        with torch.autocast(device_type=z.device.type, enabled=False):
            # work in full precision
            z = z.float()

            # project and normalize
            z_e = self.norm(self.c_in(z))
            z_q, ret_dict = self.quantize(z_e)

            ret_dict["z_q"] = z_q.detach()
            z_q = self.straight_through_approximation(z_e, z_q)
            z_q = self.c_out(z_q)

        return z_q, ret_dict
+ return z_q, ret_dict