TheTrueJard committed on
Commit
d28d295
·
verified ·
1 Parent(s): 8059958

Upload folder using huggingface_hub

Browse files
__pycache__/pfsq.cpython-310.pyc CHANGED
Binary files a/__pycache__/pfsq.cpython-310.pyc and b/__pycache__/pfsq.cpython-310.pyc differ
 
__pycache__/plpq.cpython-310.pyc CHANGED
Binary files a/__pycache__/plpq.cpython-310.pyc and b/__pycache__/plpq.cpython-310.pyc differ
 
pfsq.py CHANGED
@@ -12,9 +12,7 @@ import torch
12
  import torch.nn as nn
13
  from torch.nn import Module
14
  from torch import Tensor, int32
15
- from torch.cuda.amp import autocast
16
-
17
- from einops import rearrange, pack, unpack
18
 
19
  # helper functions
20
 
@@ -35,11 +33,22 @@ def maybe(fn):
35
  return fn(x, *args, **kwargs)
36
  return inner
37
 
38
- def pack_one(t, pattern):
39
- return pack([t], pattern)
40
-
41
- def unpack_one(t, ps, pattern):
42
- return unpack(t, ps, pattern)[0]
 
 
 
 
 
 
 
 
 
 
 
43
 
44
  # tensor helpers
45
 
@@ -137,7 +146,7 @@ class PFSQ(Module):
137
 
138
  def indices_to_level_indices(self, indices):
139
  """ Converts indices to indices at each level, perhaps needed for a transformer with factorized embeddings """
140
- indices = rearrange(indices, '... -> ... 1')
141
  codes_non_centered = (indices // self._basis) % self._levels
142
  return codes_non_centered
143
 
@@ -152,7 +161,7 @@ class PFSQ(Module):
152
  codes = self._indices_to_codes(indices)
153
 
154
  if self.keep_num_codebooks_dim:
155
- codes = rearrange(codes, '... c d -> ... (c d)')
156
 
157
  if n_codes == 1:
158
  return codes
@@ -160,11 +169,11 @@ class PFSQ(Module):
160
  codes = self.project_out(codes)
161
 
162
  if is_img_or_video or self.channel_first:
163
- codes = rearrange(codes, 'b ... d -> b d ...')
164
 
165
  return codes
166
 
167
- @autocast(enabled = False)
168
  def forward(self, z):
169
  """
170
  einstein notation
@@ -180,19 +189,19 @@ class PFSQ(Module):
180
  # standardize image or video into (batch, seq, dimension)
181
 
182
  if need_move_channel_last:
183
- z = rearrange(z, 'b d ... -> b ... d')
184
- z, ps = pack_one(z, 'b * d')
185
 
186
  assert z.shape[-1] == self.dim, f'expected dimension of {self.dim} but found dimension of {z.shape[-1]}'
187
 
188
  z = self.project_in(z)
189
 
190
- z = rearrange(z, 'b n (c d) -> b n c d', c = self.num_codebooks)
191
 
192
  # whether to force quantization step to be full precision or not
193
 
194
  force_f32 = self.force_quantization_f32
195
- quantization_context = partial(autocast, enabled = False) if force_f32 else nullcontext
196
 
197
  with quantization_context():
198
  orig_dtype = z.dtype
@@ -210,7 +219,7 @@ class PFSQ(Module):
210
  indices = self.codes_to_indices(codes)
211
 
212
  first_codes = codes[:, :, 0, :] # first codebook
213
- codes = rearrange(codes, 'b n c d -> b n (c d)')
214
 
215
  codes = codes.type(orig_dtype)
216
  first_codes = first_codes.type(orig_dtype)
@@ -221,13 +230,13 @@ class PFSQ(Module):
221
  # reconstitute image or video dimensions
222
 
223
  if need_move_channel_last:
224
- out = unpack_one(out, ps, 'b * d')
225
- out = rearrange(out, 'b ... d -> b d ...')
226
 
227
- indices = maybe(unpack_one)(indices, ps, 'b * c')
228
 
229
  if not self.keep_num_codebooks_dim and self.return_indices:
230
- indices = maybe(rearrange)(indices, '... 1 -> ...')
231
 
232
  # return quantized output and indices
233
 
 
12
  import torch.nn as nn
13
  from torch.nn import Module
14
  from torch import Tensor, int32
15
+ from torch.amp import autocast
 
 
16
 
17
  # helper functions
18
 
 
33
  return fn(x, *args, **kwargs)
34
  return inner
35
 
36
+ # einops version
37
+ #def pack_one(t, pattern):
38
+ # return pack([t], pattern)
39
+ def pack_one(t):
40
+ # pattern "b * d"
41
+ if t.ndim > 2:
42
+ ps = t.shape[1:-1]
43
+ return t.flatten(1,-2), ps
44
+ return t, tuple()
45
+
46
+ # einops version
47
+ #def unpack_one(t, ps, pattern):
48
+ # return unpack(t, ps, pattern)[0]
49
+ def unpack_one(t, ps):
50
+ # pattern "b * d"
51
+ return t.reshape(t.shape[0], ps, t.shape[-1])
52
 
53
  # tensor helpers
54
 
 
146
 
147
  def indices_to_level_indices(self, indices):
148
  """ Converts indices to indices at each level, perhaps needed for a transformer with factorized embeddings """
149
+ indices = indices.unsqueeze(-1)
150
  codes_non_centered = (indices // self._basis) % self._levels
151
  return codes_non_centered
152
 
 
161
  codes = self._indices_to_codes(indices)
162
 
163
  if self.keep_num_codebooks_dim:
164
+ codes = codes.flatten(start_dim=-2) # '... c d -> ... (c d)'
165
 
166
  if n_codes == 1:
167
  return codes
 
169
  codes = self.project_out(codes)
170
 
171
  if is_img_or_video or self.channel_first:
172
+ codes = codes.moveaxis(-1,1) # 'b ... d -> b d ...'
173
 
174
  return codes
175
 
176
+ @autocast('cuda', enabled = False)
177
  def forward(self, z):
178
  """
179
  einstein notation
 
189
  # standardize image or video into (batch, seq, dimension)
190
 
191
  if need_move_channel_last:
192
+ z = z.moveaxis(1,-1) # 'b d ... -> b ... d'
193
+ z, ps = pack_one(z)
194
 
195
  assert z.shape[-1] == self.dim, f'expected dimension of {self.dim} but found dimension of {z.shape[-1]}'
196
 
197
  z = self.project_in(z)
198
 
199
+ z = z.reshape(*z.shape[:2], self.num_codebooks, -1) # 'b n (c d) -> b n c d', c=self.num_codebooks
200
 
201
  # whether to force quantization step to be full precision or not
202
 
203
  force_f32 = self.force_quantization_f32
204
+ quantization_context = partial(autocast, device_type = 'cuda', enabled = False) if force_f32 else nullcontext
205
 
206
  with quantization_context():
207
  orig_dtype = z.dtype
 
219
  indices = self.codes_to_indices(codes)
220
 
221
  first_codes = codes[:, :, 0, :] # first codebook
222
+ codes = codes.flatten(start_dim=-2) # 'b n c d -> b n (c d)'
223
 
224
  codes = codes.type(orig_dtype)
225
  first_codes = first_codes.type(orig_dtype)
 
230
  # reconstitute image or video dimensions
231
 
232
  if need_move_channel_last:
233
+ out = unpack_one(out, ps)
234
+ out = out.moveaxis(-1,1) # 'b ... d -> b d ...'
235
 
236
+ indices = maybe(unpack_one)(indices, ps)
237
 
238
  if not self.keep_num_codebooks_dim and self.return_indices:
239
+ indices = (indices.squeeze(-1)) if indices is not None else None
240
 
241
  # return quantized output and indices
242
 
plpq.py CHANGED
@@ -11,9 +11,7 @@ from .config import PLPQConfig
11
 
12
 
13
  class PLPQ(PreTrainedModel):
14
- """
15
- Pyramidal Local Patch Quantizer
16
- """
17
  config_class = PLPQConfig
18
 
19
  def __init__(self, config):
@@ -58,12 +56,12 @@ class PLPQ(PreTrainedModel):
58
 
59
  # Pyramidal Quantizer
60
  self.quantizer = PFSQ(
61
- levels = config.levels, # number of levels for each codebook
62
- num_codebooks = config.num_quantizers, # number of quantizers
63
- dim = config.encoder_blocks[-1][2], # this is the input feature dimension, defaults to log2(codebook_size) if not defined
64
  )
65
 
66
- # coarse decoder output -> 32x32 supervision
67
  self.coarse_decoder = nn.Conv2d(len(config.levels), config.num_out_channels, kernel_size=1, stride=1)
68
 
69
  self.decoder = nn.Sequential(
@@ -76,9 +74,7 @@ class PLPQ(PreTrainedModel):
76
 
77
 
78
  def get_num_params(self) -> int:
79
- """
80
- Return the number of parameters in the model.
81
- """
82
  return sum(p.numel() for p in self.parameters())
83
 
84
 
@@ -87,19 +83,14 @@ class PLPQ(PreTrainedModel):
87
  """
88
  Quantize the input tensor
89
  Parameters:
90
- x (torch.Tensor): The input tensor. Size b, c, h, w
91
  Returns:
92
  torch.Tensor: The indices tensor. Size b, h, w
93
  """
94
- # encode the input
95
  z = self.encoder(x).permute(0, 2, 3, 1).contiguous()
96
- # reshape the input
97
  b, h, w, c = z.shape
98
  z = z.view(b, h * w, -1)
99
-
100
- # quantize the input
101
  quantized, coarse_quantized, all_codes = self.quantizer(z)
102
-
103
  return all_codes
104
 
105
 
@@ -114,25 +105,21 @@ class PLPQ(PreTrainedModel):
114
 
115
  ncodes = indices.shape[-1]
116
  emb = self.quantizer.indices_to_codes(indices).squeeze(-1)
117
-
118
  # reshape [b t c] -> [b c h w]
119
  b, h, w = emb.size(0), int(math.sqrt(emb.size(1))), int(math.sqrt(emb.size(1)))
120
  emb = emb.permute(0, 2, 1).view(b, -1, h, w).contiguous()
121
 
122
  if ncodes == 1:
123
- pred = self.coarse_decoder(emb)
124
- return pred
125
 
126
  # full decoder: full image prediction
127
- pred = self.decoder(emb)
128
-
129
- return pred
130
 
131
 
132
 
133
  class LayerNorm(nn.Module):
134
- """ LayerNorm but with an optional bias. PyTorch doesn't support simply bias=False """
135
-
136
  def __init__(self, ndim, bias):
137
  super().__init__()
138
  self.weight = nn.Parameter(torch.ones(ndim))
 
11
 
12
 
13
  class PLPQ(PreTrainedModel):
14
+ """Pyramidal Local Patch Quantizer"""
 
 
15
  config_class = PLPQConfig
16
 
17
  def __init__(self, config):
 
56
 
57
  # Pyramidal Quantizer
58
  self.quantizer = PFSQ(
59
+ levels = config.levels, # number of levels for each codebook
60
+ num_codebooks = config.num_quantizers, # number of quantizers
61
+ dim = config.encoder_blocks[-1][2], # this is the input feature dimension, defaults to log2(codebook_size) if not defined
62
  )
63
 
64
+ # Coarse decoder output -> 32x32 supervision
65
  self.coarse_decoder = nn.Conv2d(len(config.levels), config.num_out_channels, kernel_size=1, stride=1)
66
 
67
  self.decoder = nn.Sequential(
 
74
 
75
 
76
  def get_num_params(self) -> int:
77
+ """Return the number of parameters in the model."""
 
 
78
  return sum(p.numel() for p in self.parameters())
79
 
80
 
 
83
  """
84
  Quantize the input tensor
85
  Parameters:
86
+ x (Image or torch.Tensor): The input tensor. Size b, c, h, w
87
  Returns:
88
  torch.Tensor: The indices tensor. Size b, h, w
89
  """
 
90
  z = self.encoder(x).permute(0, 2, 3, 1).contiguous()
 
91
  b, h, w, c = z.shape
92
  z = z.view(b, h * w, -1)
 
 
93
  quantized, coarse_quantized, all_codes = self.quantizer(z)
 
94
  return all_codes
95
 
96
 
 
105
 
106
  ncodes = indices.shape[-1]
107
  emb = self.quantizer.indices_to_codes(indices).squeeze(-1)
 
108
  # reshape [b t c] -> [b c h w]
109
  b, h, w = emb.size(0), int(math.sqrt(emb.size(1))), int(math.sqrt(emb.size(1)))
110
  emb = emb.permute(0, 2, 1).view(b, -1, h, w).contiguous()
111
 
112
  if ncodes == 1:
113
+ return self.coarse_decoder(emb)
 
114
 
115
  # full decoder: full image prediction
116
+ return self.decoder(emb)
 
 
117
 
118
 
119
 
120
  class LayerNorm(nn.Module):
121
+ """LayerNorm but with an optional bias. PyTorch doesn't support simply bias=False"""
122
+
123
  def __init__(self, ndim, bias):
124
  super().__init__()
125
  self.weight = nn.Parameter(torch.ones(ndim))