bigshanedogg committed on
Commit 296b2af · verified · parent 4afba87

Delete modeling_tatok.py

Files changed (1):
  modeling_tatok.py  +0 −468
modeling_tatok.py DELETED
@@ -1,468 +0,0 @@
- # from . import models
- # from .utils import ScalingLayer
- import copy
- import inspect
-
- import torch
- import torch.nn as nn
- import torch.nn.functional as F
- from einops import rearrange
- from torchvision.transforms import Resize
- from transformers import (
-     AutoConfig,
-     AutoModel,
-     PreTrainedModel,
-     Siglip2VisionConfig,
-     Siglip2VisionModel,
- )
-
- from configuration_tatok import TATokConfig
-
-
- # def make(model_spec, args=None, load_sd=False) -> torch.nn.Module:
- def models_make(model_spec, args=None, load_sd=False) -> torch.nn.Module:
-     if args is not None:
-         model_args = copy.deepcopy(model_spec["args"])
-         model_args.update(args)
-     else:
-         model_args = model_spec["args"]
-     model_params = inspect.signature(models[model_spec["name"]]).parameters
-     if "kwargs" not in model_params:
-         model_args = {k: v for k, v in model_args.items() if k in model_params}
-     model = models[model_spec["name"]](**model_args)
-     if load_sd:
-         if (
-             ("abs_pe" in model_spec["sd"])
-             and hasattr(model, "abs_pe")
-             and model_spec["sd"]["abs_pe"].shape != model.abs_pe.shape
-         ):
-             del model_spec["sd"]["abs_pe"]
-         msg = model.load_state_dict(model_spec["sd"], strict=False)
-         print(msg)
-     return model
-
-
- class Bottleneck(nn.Module):
-     def __init__(
-         self, bottleneck_dim: int, input_dim: int, output_dim: int, token_nums: int, regularizer=None, **kwargs
-     ):
-         super().__init__()
-         self.token_nums = token_nums
-         self.input_dim = input_dim
-         self.output_dim = output_dim
-         if bottleneck_dim > 0:
-             self.bottleneck_dim = bottleneck_dim
-         else:
-             assert (
-                 self.input_dim == self.output_dim
-             ), "input_dim and output_dim must be the same when bottleneck_dim is not specified"
-             self.bottleneck_dim = self.input_dim
-
-         self.project_dim = self.bottleneck_dim
-
-         if self.bottleneck_dim > 0:
-             self.in_linear = nn.Linear(self.input_dim, self.project_dim)
-             self.out_linear = nn.Linear(self.bottleneck_dim, self.output_dim)
-         else:
-             self.in_linear = self.out_linear = lambda x: x
-
-         regularizer["args"]["dim"] = self.bottleneck_dim
-         regularizer["args"]["token_nums"] = self.token_nums
-         self.regularizer = models_make(regularizer)
-
-     def project_in(self, x):
-         assert len(x.shape) == 3, "Input shape must be (batch, n_tokens, e_dim)"
-         z = self.in_linear(x)
-         return z
-
-     def project_out(self, z_cat):
-         z = self.out_linear(z_cat)
-         return z
-
-     def decode(self, bottleneck_rep):
-         regularized_z = self.regularizer.decode(bottleneck_rep)
-         return self.project_out(regularized_z)
-
-     def forward(self, x):
-         z = self.project_in(x)
-         projected_z = z
-         regularized_output = self.regularizer(z)
-         x_hat = self.project_out(regularized_output["regularized_z"])
-         bottleneck_rep = regularized_output.pop("bottleneck_rep")
-         return {
-             "output": x_hat,
-             "bottleneck_rep": bottleneck_rep,
-             "projected_z": projected_z,
-             **regularized_output,
-         }
-
-
- class SimVectorQuantizer(nn.Module):
-     def __init__(
-         self,
-         dim,
-         codebook_size,
-         l2_normalized=False,
-         same_index_shape=True,
-         stochastic=False,
-         stochastic_temperature=1.0,
-         **kwargs,
-     ):
-         super().__init__()
-         self.codebook_size = codebook_size
-         self.dim = dim
-         assert isinstance(l2_normalized, bool)
-         self.l2_normalized = l2_normalized
-         self.stochastic = stochastic
-         self.eval_deterministic = False
-         self.default_stochastic_temperature = stochastic_temperature
-
-         if self.stochastic:
-             if stochastic_temperature > 0:  # fixed temperature
-                 self.stochastic_temperature_inv = 1 / stochastic_temperature
-             else:  # set stochastic_temperature < 0 to use learnable temperature
-                 self.stochastic_temperature_inv = nn.Parameter(torch.tensor(10.0))
-
-         # for clear inference code, we remove the codebook init from LLM's embedding
-         self.embedding = nn.Embedding(self.codebook_size, self.dim)
-         self.embedding_proj = nn.Linear(self.dim, self.dim)
-
-         self.same_index_shape = same_index_shape
-
-     def set_eval_deterministic(self, deterministic=True):
-         self.eval_deterministic = deterministic
-
-     def set_stochastic_temperature(self, temperature):
-         self.stochastic_temperature_inv = 1 / temperature
-
-     @torch.autocast(device_type="cuda", enabled=False)
-     def get_emb(self):
-         emb = self.embedding_proj(self.embedding.weight)
-         if self.l2_normalized:
-             emb = F.normalize(emb, p=2, dim=-1)
-         # assert emb.dtype == torch.float32, f"Embedding weight dtype is {emb.dtype}, expected float32"
-         return emb
-
-     @torch.autocast(device_type="cuda", enabled=False)
-     def forward(self, z):
-         emb = self.get_emb()
-         z = z.to(emb)
-         # z = z.float()
-         assert len(z.shape) == 3, "Input shape must be (batch, n_tokens, e_dim)"
-         if self.l2_normalized:
-             z = F.normalize(z, p=2, dim=-1)
-
-         z_flattened = rearrange(z, "b n d -> (b n) d")
-
-         if self.stochastic:
-             # sample the softmaxed cosine similarity
-             assert self.l2_normalized, "Stochastic sampling requires l2 normalization"
-             cos_sim = torch.einsum("bd,nd->bn", z_flattened, emb)
-             probs = F.softmax(cos_sim * self.stochastic_temperature_inv, dim=-1)
-             if self.eval_deterministic and not self.training:
-                 q_indices = torch.argmax(probs, dim=-1)
-             else:
-                 q_indices = torch.multinomial(probs, 1).squeeze(-1)
-         else:
-             d = (
-                 torch.sum(z_flattened**2, dim=1, keepdim=True)
-                 + torch.sum(emb**2, dim=1)
-                 - 2 * torch.einsum("bd,dn->bn", z_flattened, rearrange(emb, "n d -> d n"))
-             )
-             q_indices = torch.argmin(d, dim=1)
-
-         quantized = F.embedding(
-             q_indices,
-             emb,
-             self.embedding.padding_idx,
-             self.embedding.max_norm,
-             self.embedding.norm_type,
-             self.embedding.scale_grad_by_freq,
-             self.embedding.sparse,
-         ).view(z.shape)  # (b, n, d)
-
-         # preserve gradients
-         quantized = z + (quantized - z).detach()
-
-         if self.same_index_shape:
-             q_indices = q_indices.reshape(quantized.shape[0], quantized.shape[1])
-
-         return_dict = {
-             "unregularized_z": z,  # but l2 normalized if l2_normalized=True
-             "emb": emb,  # but l2 normalized if l2_normalized=True
-             "regularized_z": quantized,
-             "bottleneck_rep": q_indices,
-         }
-         return return_dict
-
-     def get_codebook_entry(self, indices, shape=None):
-         # shape specifying (batch, height, width, channel)
-         indices_shape = indices.shape
-         indices_flatten = rearrange(indices, "... -> (...)")
-
-         # get quantized latent vectors
-         emb = self.get_emb()
-         z_q = F.embedding(indices_flatten, emb)
-         # z_q = self.embedding(indices_flatten)
-         if self.l2_normalized:
-             z_q = F.normalize(z_q, p=2, dim=-1)
-
-         if shape is not None:
-             z_q = z_q.reshape(shape)
-         else:
-             z_q = z_q.reshape([*indices_shape, self.dim])
-         return z_q
-
-     def decode(self, indices):
-         return self.get_codebook_entry(indices)
-
-
- models = {"simvq": SimVectorQuantizer, "bottleneck": Bottleneck}
-
-
- class ScalingLayer(nn.Module):
-     def __init__(self, mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]):
-         super().__init__()
-         self.register_buffer("shift", torch.Tensor(mean)[None, :, None, None])
-         self.register_buffer("scale", torch.Tensor(std)[None, :, None, None])
-
-     def forward(self, inp):
-         return (inp - self.shift) / self.scale
-
-     def inv(self, inp):
-         return inp * self.scale + self.shift
-
-
- # Register discrete vision models with AutoModel
- # TATok (TextAlignedTokenizer) Model
- class TATokModel(PreTrainedModel):
-     config_class = TATokConfig
-     base_model_prefix = "discrete_vision_encoder"
-
-     def _init_weights(self, module):
-         """
-         Initialize the weights of the module.
-         """
-         if (
-             isinstance(module, nn.Conv2d)  # noqa: SIM101
-             or isinstance(module, nn.Embedding)
-             or isinstance(module, nn.Linear)
-         ):
-             module.weight.data.normal_(mean=0.0, std=0.02)
-             if hasattr(module, "bias") and module.bias is not None:
-                 module.bias.data.zero_()
-
-         elif isinstance(module, nn.LayerNorm):
-             module.bias.data.zero_()
-             module.weight.data.fill_(1.0)
-         elif isinstance(module, nn.Parameter):
-             embed_std = 1 / torch.sqrt(torch.tensor(module.size(0), dtype=torch.float)).to(module.dtype)
-             module.data.normal_(mean=0.0, std=embed_std)
-
-     def get_input_embeddings(self):
-         """
-         Returns the input embeddings. For vision models, this typically returns None.
-         """
-         return None
-
-     def set_input_embeddings(self, value):
-         # Vision models don't have input embeddings in the traditional sense
-         pass
-
-     def get_output_embeddings(self):
-         """
-         Returns the output embeddings. For vision models, this typically returns None.
-         """
-         return None
-
-     def set_output_embeddings(self, new_embeddings):
-         # Vision models don't have output embeddings in the traditional sense
-         pass
-
-     def __init__(
-         self,
-         config: TATokConfig,
-         bottleneck=None,
-         bottleneck_token_num=None,
-         input_size=None,
-         teacher=None,
-         input_type=None,  # choose from ['quant', 'rec', 'indices']
-         pool_scale=None,  # choose from [1, 2, 3]
-         decoder_depth=None,
-         select_layer_id=None,
-         **kwargs,
-     ):
-         super().__init__(config)
-         # Get values from config if not provided directly
-         bottleneck = bottleneck if bottleneck is not None else config.bottleneck
-         bottleneck_token_num = (
-             bottleneck_token_num
-             if bottleneck_token_num is not None
-             else (
-                 config.bottleneck_token_num
-                 if hasattr(config, "bottleneck_token_num") and config.bottleneck_token_num is not None
-                 else 256
-             )
-         )
-         input_size = (
-             input_size
-             if input_size is not None
-             else (config.input_size if hasattr(config, "input_size") and config.input_size is not None else 384)
-         )
-         teacher = (
-             teacher
-             if teacher is not None
-             else (
-                 config.teacher
-                 if hasattr(config, "teacher") and config.teacher is not None
-                 else "google/siglip2-so400m-patch14-384"
-             )
-         )
-         input_type = (
-             input_type
-             if input_type is not None
-             else (config.input_type if hasattr(config, "input_type") and config.input_type is not None else "quant")
-         )
-         pool_scale = (
-             pool_scale
-             if pool_scale is not None
-             else (config.pool_scale if hasattr(config, "pool_scale") and config.pool_scale is not None else 1)
-         )
-         decoder_depth = (
-             decoder_depth
-             if decoder_depth is not None
-             else (config.decoder_depth if hasattr(config, "decoder_depth") and config.decoder_depth is not None else 3)
-         )
-         select_layer_id = (
-             select_layer_id
-             if select_layer_id is not None
-             else (
-                 config.select_layer_id
-                 if hasattr(config, "select_layer_id") and config.select_layer_id is not None
-                 else -2
-             )
-         )
-
-         self.input_size = input_size
-         self.bottleneck_token_num = bottleneck_token_num
-         self.teacher = teacher
-         self.input_type = input_type
-         self.pool_scale = pool_scale
-         self.decoder_depth = decoder_depth
-         self.select_layer_id = select_layer_id
-
-         self.bottleneck_dim = bottleneck["args"]["bottleneck_dim"]
-
-         self.encoder_config = AutoConfig.from_pretrained(teacher)
-         self.encoder = AutoModel.from_config(self.encoder_config).vision_model
-
-         self.encoder_hidden_dim = self.encoder.config.hidden_size
-
-         self.decoder_config = Siglip2VisionConfig()
-         self.decoder_config.update(
-             {
-                 "patch_size": 1,
-                 "num_hidden_layers": self.decoder_depth,
-                 "num_channels": self.bottleneck_dim,
-                 "hidden_size": self.encoder_hidden_dim,
-             }
-         )
-         self.decoder = Siglip2VisionModel(self.decoder_config)
-
-         self.encode_task_layer = nn.Sequential(nn.Linear(self.encoder_hidden_dim, self.encoder_hidden_dim), nn.Tanh())
-         self.decode_task_layer = nn.Sequential(
-             nn.Linear(self.encoder_hidden_dim, self.encoder_hidden_dim),
-             nn.Tanh(),
-             nn.Linear(self.encoder_hidden_dim, self.encoder_hidden_dim),
-         )
-
-         bottleneck_args = {
-             "token_nums": self.bottleneck_token_num,
-             "input_dim": self.encoder_hidden_dim,
-             "output_dim": self.bottleneck_dim,
-         }
-         self.bottleneck = models_make(bottleneck, args=bottleneck_args)
-
-         self.scale_layer = ScalingLayer(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
-         self.image_resize = Resize((self.input_size, self.input_size))
-
-     def set_vq_eval_deterministic(self, deterministic=True):
-         self.bottleneck.regularizer.set_eval_deterministic(deterministic)
-
-     @property
-     def device(self):
-         return next(self.parameters()).device
-
-     @property
-     def dtype(self):
-         return next(self.parameters()).dtype
-
-     def encode(self, x, **kwargs):
-         if x.ndim == 5:
-             x = rearrange(x, "b c t h w -> (b t) c h w")
-         x = self.scale_layer(x)
-         if tuple(x.shape[-2:]) != (self.input_size, self.input_size):
-             x = self.image_resize(x)
-         vq_feats = self.encoder(x, output_hidden_states=True).hidden_states[self.select_layer_id]
-
-         pool_scale = self.pool_scale
-         pool_scale = kwargs.get("pool_scale", pool_scale)
-         if pool_scale != 1:
-             vq_feats = self.avg_pool(vq_feats, pool_scale)
-         vq_feats = self.encode_task_layer(vq_feats.to(x))
-
-         bottleneck_out = self.bottleneck(vq_feats)
-         z = bottleneck_out.pop("output")
-
-         return {"encoded": z, "pool_scale": pool_scale, "vq_feats": vq_feats, **bottleneck_out}
-
-     def avg_pool(self, z, pool_scale=1):
-         if z.ndim == 3:
-             b, n, c = z.shape
-             p = int(n**0.5)
-             z = rearrange(z, "b (p1 p2) c -> b c p1 p2", p1=p, p2=p)
-         else:
-             b, c, p, _ = z.shape
-         p_s = int(p // pool_scale)
-         z = F.avg_pool2d(z, kernel_size=(pool_scale, pool_scale), stride=(pool_scale, pool_scale)).contiguous()
-         z = rearrange(z, "b c p1 p2 -> b (p1 p2) c")
-         return z
-
-     def decode(self, z):
-         if z.ndim == 4:
-             z = rearrange(z, "b c p1 p2 -> b (p1 p2) c")
-         attention_mask = torch.ones(z.shape[:2], dtype=torch.int, device=z.device)
-         p = int(z.shape[1] ** 0.5)
-         spatial_shape = torch.tensor([[p, p]] * z.shape[0], device=self.device)
-         z = self.decoder(z, attention_mask, spatial_shape, output_hidden_states=True).last_hidden_state
-         z = self.decode_task_layer(z)
-         return z
-
-     def decode_from_bottleneck(self, bottleneck_rep):
-         z = self.bottleneck.decode(bottleneck_rep)  # (b, n, c)
-         p = int(z.shape[1] ** 0.5)
-         z = rearrange(z, "b (p1 p2) c -> b c p1 p2", p1=p, p2=p)
-         return self.decode(z)
-
-     def forward(self, data, **kwargs):
-         # data: video in shape (b, c, t, h, w)
-         encode_output = self.encode(data, **kwargs)
-         vq_feats = encode_output["encoded"]
-         p = int(vq_feats.shape[1] ** 0.5)
-         vq_feats = rearrange(vq_feats, "b (h w) c -> b c h w", h=p, w=p)
-         pred_feats = self.decode(vq_feats)
-
-         if self.input_type == "quant":
-             z = encode_output["regularized_z"]  # [b, n, c]
-         elif self.input_type == "indices":
-             z = encode_output["bottleneck_rep"]  # [b, n]
-         elif self.input_type == "rec":
-             z = pred_feats  # [b, n, c]
-         encode_output["encoded"] = z
-         return encode_output
-
-
- # Register with AutoModel
- AutoModel.register(TATokConfig, TATokModel)
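
The deleted module registers TATokModel with AutoModel on import (last line of the diff), which is what allowed checkpoints shipping these files to be loaded via trust_remote_code. A minimal usage sketch of the pre-deletion API, under stated assumptions: the repo id is a placeholder, the shape comments assume the default 384px SigLIP2 teacher, and set_vq_eval_deterministic only changes behavior when the SimVQ quantizer was built with stochastic=True.

import torch
from transformers import AutoModel

# Placeholder repo id; any checkpoint that still ships configuration_tatok.py
# and modeling_tatok.py alongside its weights would load the same way.
repo_id = "bigshanedogg/TA-Tok"
model = AutoModel.from_pretrained(repo_id, trust_remote_code=True).eval()
model.set_vq_eval_deterministic(True)  # argmax codebook lookup instead of sampling

images = torch.rand(2, 3, 384, 384)  # (b, c, h, w) in [0, 1]; encode() rescales and resizes internally
with torch.no_grad():
    enc = model.encode(images)
    tokens = enc["bottleneck_rep"]                # (b, n) discrete token indices from SimVQ
    feats = model.decode_from_bottleneck(tokens)  # (b, n, hidden) reconstructed teacher-space features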