Continual-Mega commited on
Commit
7caf841
·
verified ·
1 Parent(s): f41cee4

Upload CLIP/transformer.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. CLIP/transformer.py +760 -0
CLIP/transformer.py ADDED
@@ -0,0 +1,760 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from collections import OrderedDict
2
+ import math
3
+ from typing import Callable, Optional, Sequence, Tuple
4
+ from itertools import repeat
5
+ import collections.abc
6
+ import torch
7
+ from torch import nn
8
+ from torch.nn import functional as F
9
+ from torch.utils.checkpoint import checkpoint
10
+
11
+ # From PyTorch internals
12
+ def _ntuple(n):
13
+ def parse(x):
14
+ if isinstance(x, collections.abc.Iterable):
15
+ return x
16
+ return tuple(repeat(x, n))
17
+ return parse
18
+ to_2tuple = _ntuple(2)
19
+
20
+
21
+ class LayerNormFp32(nn.LayerNorm):
22
+ """Subclass torch's LayerNorm to handle fp16 (by casting to float32 and back)."""
23
+
24
+ def forward(self, x: torch.Tensor):
25
+ orig_type = x.dtype
26
+ x = F.layer_norm(x.to(torch.float32), self.normalized_shape, self.weight, self.bias, self.eps)
27
+ return x.to(orig_type)
28
+
29
+
30
+ class LayerNorm(nn.LayerNorm):
31
+ """Subclass torch's LayerNorm (with cast back to input dtype)."""
32
+
33
+ def forward(self, x: torch.Tensor):
34
+ orig_type = x.dtype
35
+ x = F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
36
+ return x.to(orig_type)
37
+
38
+
39
+ class QuickGELU(nn.Module):
40
+ # NOTE This is slower than nn.GELU or nn.SiLU and uses more GPU memory
41
+ def forward(self, x: torch.Tensor):
42
+ return x * torch.sigmoid(1.702 * x)
43
+
44
+
45
+ class LayerScale(nn.Module):
46
+ def __init__(self, dim, init_values=1e-5, inplace=False):
47
+ super().__init__()
48
+ self.inplace = inplace
49
+ self.gamma = nn.Parameter(init_values * torch.ones(dim))
50
+
51
+ def forward(self, x):
52
+ return x.mul_(self.gamma) if self.inplace else x * self.gamma
53
+
54
+
55
+ class PatchDropout(nn.Module):
56
+ """
57
+ https://arxiv.org/abs/2212.00794
58
+ """
59
+
60
+ def __init__(self, prob, exclude_first_token=True):
61
+ super().__init__()
62
+ assert 0 <= prob < 1.
63
+ self.prob = prob
64
+ self.exclude_first_token = exclude_first_token # exclude CLS token
65
+
66
+ def forward(self, x):
67
+ if not self.training or self.prob == 0.:
68
+ return x
69
+
70
+ if self.exclude_first_token:
71
+ cls_tokens, x = x[:, :1], x[:, 1:]
72
+ else:
73
+ cls_tokens = torch.jit.annotate(torch.Tensor, x[:, :1])
74
+
75
+ batch = x.size()[0]
76
+ num_tokens = x.size()[1]
77
+
78
+ batch_indices = torch.arange(batch)
79
+ batch_indices = batch_indices[..., None]
80
+
81
+ keep_prob = 1 - self.prob
82
+ num_patches_keep = max(1, int(num_tokens * keep_prob))
83
+
84
+ rand = torch.randn(batch, num_tokens)
85
+ patch_indices_keep = rand.topk(num_patches_keep, dim=-1).indices
86
+
87
+ x = x[batch_indices, patch_indices_keep]
88
+
89
+ if self.exclude_first_token:
90
+ x = torch.cat((cls_tokens, x), dim=1)
91
+
92
+ return x
93
+
94
+
95
+ class Attention(nn.Module):
96
+ def __init__(
97
+ self,
98
+ dim,
99
+ num_heads=8,
100
+ qkv_bias=True,
101
+ scaled_cosine=False,
102
+ scale_heads=False,
103
+ logit_scale_max=math.log(1. / 0.01),
104
+ attn_drop=0.,
105
+ proj_drop=0.
106
+ ):
107
+ super().__init__()
108
+ self.scaled_cosine = scaled_cosine
109
+ self.scale_heads = scale_heads
110
+ assert dim % num_heads == 0, 'dim should be divisible by num_heads'
111
+ self.num_heads = num_heads
112
+ self.head_dim = dim // num_heads
113
+ self.scale = self.head_dim ** -0.5
114
+ self.logit_scale_max = logit_scale_max
115
+
116
+ # keeping in_proj in this form (instead of nn.Linear) to match weight scheme of original
117
+ self.in_proj_weight = nn.Parameter(torch.randn((dim * 3, dim)) * self.scale)
118
+ if qkv_bias:
119
+ self.in_proj_bias = nn.Parameter(torch.zeros(dim * 3))
120
+ else:
121
+ self.in_proj_bias = None
122
+
123
+ if self.scaled_cosine:
124
+ self.logit_scale = nn.Parameter(torch.log(10 * torch.ones((num_heads, 1, 1))))
125
+ else:
126
+ self.logit_scale = None
127
+ self.attn_drop = nn.Dropout(attn_drop)
128
+ if self.scale_heads:
129
+ self.head_scale = nn.Parameter(torch.ones((num_heads, 1, 1)))
130
+ else:
131
+ self.head_scale = None
132
+ self.out_proj = nn.Linear(dim, dim)
133
+ self.out_drop = nn.Dropout(proj_drop)
134
+
135
+ def forward(self, x, attn_mask: Optional[torch.Tensor] = None):
136
+ L, N, C = x.shape
137
+ q, k, v = F.linear(x, self.in_proj_weight, self.in_proj_bias).chunk(3, dim=-1)
138
+ q = q.contiguous().view(L, N * self.num_heads, -1).transpose(0, 1)
139
+ k = k.contiguous().view(L, N * self.num_heads, -1).transpose(0, 1)
140
+ v = v.contiguous().view(L, N * self.num_heads, -1).transpose(0, 1)
141
+
142
+ if self.logit_scale is not None:
143
+ attn = torch.bmm(F.normalize(q, dim=-1), F.normalize(k, dim=-1).transpose(-1, -2))
144
+ logit_scale = torch.clamp(self.logit_scale, max=self.logit_scale_max).exp()
145
+ attn = attn.view(N, self.num_heads, L, L) * logit_scale
146
+ attn = attn.view(-1, L, L)
147
+ else:
148
+ q = q * self.scale
149
+ attn = torch.bmm(q, k.transpose(-1, -2))
150
+
151
+ if attn_mask is not None:
152
+ if attn_mask.dtype == torch.bool:
153
+ new_attn_mask = torch.zeros_like(attn_mask, dtype=q.dtype)
154
+ new_attn_mask.masked_fill_(attn_mask, float("-inf"))
155
+ attn_mask = new_attn_mask
156
+ attn += attn_mask
157
+
158
+ attn = attn.softmax(dim=-1)
159
+ attn = self.attn_drop(attn)
160
+
161
+ x = torch.bmm(attn, v)
162
+ if self.head_scale is not None:
163
+ x = x.view(N, self.num_heads, L, C) * self.head_scale
164
+ x = x.view(-1, L, C)
165
+ x = x.transpose(0, 1).reshape(L, N, C)
166
+ x = self.out_proj(x)
167
+ x = self.out_drop(x)
168
+ return x
169
+
170
+
171
+ class AttentionalPooler(nn.Module):
172
+ def __init__(
173
+ self,
174
+ d_model: int,
175
+ context_dim: int,
176
+ n_head: int = 8,
177
+ n_queries: int = 256,
178
+ norm_layer: Callable = LayerNorm
179
+ ):
180
+ super().__init__()
181
+ self.query = nn.Parameter(torch.randn(n_queries, d_model))
182
+ self.attn = nn.MultiheadAttention(d_model, n_head, kdim=context_dim, vdim=context_dim)
183
+ self.ln_q = norm_layer(d_model)
184
+ self.ln_k = norm_layer(context_dim)
185
+
186
+ def forward(self, x: torch.Tensor):
187
+ x = self.ln_k(x).permute(1, 0, 2) # NLD -> LND
188
+ N = x.shape[1]
189
+ q = self.ln_q(self.query)
190
+ out = self.attn(self._repeat(q, N), x, x, need_weights=False)[0]
191
+ return out.permute(1, 0, 2) # LND -> NLD
192
+
193
+ def _repeat(self, query, N: int):
194
+ return query.unsqueeze(1).repeat(1, N, 1)
195
+
196
+
197
+ class ResidualAttentionBlock(nn.Module):
198
+ def __init__(
199
+ self,
200
+ d_model: int,
201
+ n_head: int,
202
+ mlp_ratio: float = 4.0,
203
+ ls_init_value: float = None,
204
+ act_layer: Callable = nn.GELU,
205
+ norm_layer: Callable = LayerNorm,
206
+ is_cross_attention: bool = False,
207
+ idx: int = 12,
208
+ ):
209
+ super().__init__()
210
+
211
+ self.idx = idx
212
+
213
+ self.ln_1 = norm_layer(d_model)
214
+ self.attn = nn.MultiheadAttention(d_model, n_head)
215
+ self.ls_1 = LayerScale(d_model, ls_init_value) if ls_init_value is not None else nn.Identity()
216
+ if is_cross_attention:
217
+ self.ln_1_kv = norm_layer(d_model)
218
+
219
+ self.ln_2 = norm_layer(d_model)
220
+ mlp_width = int(d_model * mlp_ratio)
221
+ self.mlp = nn.Sequential(OrderedDict([
222
+ ("c_fc", nn.Linear(d_model, mlp_width)),
223
+ ("gelu", act_layer()),
224
+ ("c_proj", nn.Linear(mlp_width, d_model))
225
+ ]))
226
+ self.ls_2 = LayerScale(d_model, ls_init_value) if ls_init_value is not None else nn.Identity()
227
+
228
+ def attention(
229
+ self,
230
+ q_x: torch.Tensor,
231
+ k_x: Optional[torch.Tensor] = None,
232
+ v_x: Optional[torch.Tensor] = None,
233
+ attn_mask: Optional[torch.Tensor] = None,
234
+ ):
235
+ k_x = k_x if k_x is not None else q_x
236
+ v_x = v_x if v_x is not None else q_x
237
+
238
+ attn_mask = attn_mask.to(q_x.dtype) if attn_mask is not None else None
239
+ return self.attn(
240
+ q_x, k_x, v_x, need_weights=True, attn_mask=attn_mask
241
+ )
242
+
243
+ def forward(
244
+ self,
245
+ q_x: torch.Tensor,
246
+ k_x: Optional[torch.Tensor] = None,
247
+ v_x: Optional[torch.Tensor] = None,
248
+ attn_mask: Optional[torch.Tensor] = None,
249
+ ):
250
+ k_x = self.ln_1_kv(k_x) if hasattr(self, "ln_1_kv") and k_x is not None else None
251
+ v_x = self.ln_1_kv(v_x) if hasattr(self, "ln_1_kv") and v_x is not None else None
252
+
253
+ tmp, attn = self.attention(q_x=self.ln_1(q_x), k_x=k_x, v_x=v_x, attn_mask=attn_mask)
254
+ x = q_x + self.ls_1(tmp)
255
+ x = x + self.ls_2(self.mlp(self.ln_2(x)))
256
+ return x, attn
257
+
258
+
259
+ class CustomResidualAttentionBlock(nn.Module):
260
+ def __init__(
261
+ self,
262
+ d_model: int,
263
+ n_head: int,
264
+ mlp_ratio: float = 4.0,
265
+ ls_init_value: float = None,
266
+ act_layer: Callable = nn.GELU,
267
+ norm_layer: Callable = LayerNorm,
268
+ scale_cosine_attn: bool = False,
269
+ scale_heads: bool = False,
270
+ scale_attn: bool = False,
271
+ scale_fc: bool = False,
272
+ ):
273
+ super().__init__()
274
+
275
+ self.ln_1 = norm_layer(d_model)
276
+ self.attn = Attention(
277
+ d_model, n_head,
278
+ scaled_cosine=scale_cosine_attn,
279
+ scale_heads=scale_heads,
280
+ )
281
+ self.ln_attn = norm_layer(d_model) if scale_attn else nn.Identity()
282
+ self.ls_1 = LayerScale(d_model, ls_init_value) if ls_init_value is not None else nn.Identity()
283
+
284
+ self.ln_2 = norm_layer(d_model)
285
+ mlp_width = int(d_model * mlp_ratio)
286
+ self.mlp = nn.Sequential(OrderedDict([
287
+ ("c_fc", nn.Linear(d_model, mlp_width)),
288
+ ('ln', norm_layer(mlp_width) if scale_fc else nn.Identity()),
289
+ ("gelu", act_layer()),
290
+ ("c_proj", nn.Linear(mlp_width, d_model))
291
+ ]))
292
+ self.ls_2 = LayerScale(d_model, ls_init_value) if ls_init_value is not None else nn.Identity()
293
+
294
+ def forward(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None):
295
+ x = x + self.ls_1(self.ln_attn(self.attn(self.ln_1(x), attn_mask=attn_mask)))
296
+ x = x + self.ls_2(self.mlp(self.ln_2(x)))
297
+ return x
298
+
299
+
300
+ class Transformer(nn.Module):
301
+ def __init__(
302
+ self,
303
+ width: int,
304
+ layers: int,
305
+ heads: int,
306
+ mlp_ratio: float = 4.0,
307
+ ls_init_value: float = None,
308
+ act_layer: Callable = nn.GELU,
309
+ norm_layer: Callable = LayerNorm,
310
+ ):
311
+ super().__init__()
312
+ self.width = width
313
+ self.layers = layers
314
+ self.grad_checkpointing = False
315
+
316
+ self.resblocks = nn.ModuleList([
317
+ ResidualAttentionBlock(
318
+ width, heads, mlp_ratio, ls_init_value=ls_init_value, act_layer=act_layer, norm_layer=norm_layer,
319
+ idx=idx)
320
+ for idx in range(layers)
321
+ ])
322
+
323
+ def get_cast_dtype(self) -> torch.dtype:
324
+ return self.resblocks[0].mlp.c_fc.weight.dtype
325
+
326
+ def forward(self, x: torch.Tensor, out_layers: list = [3, 6, 9],
327
+ attn_mask: Optional[torch.Tensor] = None):
328
+ idx = 0
329
+ out_attn = []
330
+ # out_tokens = x
331
+ out_tokens = []
332
+ for r in self.resblocks:
333
+ idx += 1
334
+ if self.grad_checkpointing and not torch.jit.is_scripting():
335
+ # TODO: handle kwargs https://github.com/pytorch/pytorch/issues/79887#issuecomment-1161758372
336
+ x = checkpoint(r, x, None, None, attn_mask)
337
+ else:
338
+ if idx == 12:
339
+ x, attn = r(x, attn_mask=attn_mask)
340
+ out_attn.append(attn)
341
+ else:
342
+ x, attn_tmp = r(x, attn_mask=attn_mask)
343
+ if idx in out_layers:
344
+ out_tokens.append(x)
345
+ # out_tokens = x
346
+ return x, out_attn, out_tokens
347
+
348
+
349
+
350
+ class VisionTransformer(nn.Module):
351
+ output_tokens: torch.jit.Final[bool]
352
+
353
+ def __init__(
354
+ self,
355
+ image_size: int,
356
+ patch_size: int,
357
+ width: int,
358
+ layers: int,
359
+ heads: int,
360
+ mlp_ratio: float,
361
+ ls_init_value: float = None,
362
+ global_average_pool: bool = False,
363
+ attentional_pool: bool = False,
364
+ n_queries: int = 256,
365
+ attn_pooler_heads: int = 8,
366
+ output_dim: int = 512,
367
+ patch_dropout: float = 0.4,
368
+ input_patchnorm: bool = False,
369
+ act_layer: Callable = nn.GELU,
370
+ norm_layer: Callable = LayerNorm,
371
+ output_tokens: bool = False
372
+ ):
373
+ super().__init__()
374
+ self.output_tokens = output_tokens
375
+ image_height, image_width = self.image_size = to_2tuple(image_size)
376
+ patch_height, patch_width = self.patch_size = to_2tuple(patch_size)
377
+ self.grid_size = (image_height // patch_height, image_width // patch_width)
378
+ self.output_dim = output_dim
379
+
380
+ # whether to layernorm each patch, as done in dual patchnorm paper - https://arxiv.org/abs/2302.01327v1
381
+ self.input_patchnorm = input_patchnorm
382
+
383
+ if input_patchnorm:
384
+ patch_input_dim = patch_height * patch_width * 3
385
+ self.patchnorm_pre_ln = LayerNorm(patch_input_dim)
386
+ self.conv1 = nn.Linear(patch_input_dim, width)
387
+ else:
388
+ self.patchnorm_pre_ln = nn.Identity()
389
+ self.conv1 = nn.Conv2d(in_channels=3, out_channels=width, kernel_size=patch_size, stride=patch_size, bias=False)
390
+
391
+ # class embeddings and positional embeddings
392
+ scale = width ** -0.5
393
+ self.class_embedding = nn.Parameter(scale * torch.randn(width))
394
+ self.positional_embedding = nn.Parameter(scale * torch.randn(self.grid_size[0] * self.grid_size[1] + 1, width))
395
+
396
+ # setting a patch_dropout of 0. would mean it is disabled and this function would be the identity fn
397
+ self.patch_dropout = PatchDropout(patch_dropout) if patch_dropout > 0. else nn.Identity()
398
+
399
+ self.ln_pre = norm_layer(width)
400
+ self.transformer = Transformer(
401
+ width,
402
+ layers,
403
+ heads,
404
+ mlp_ratio,
405
+ ls_init_value=ls_init_value,
406
+ act_layer=act_layer,
407
+ norm_layer=norm_layer,
408
+ )
409
+
410
+ self.global_average_pool = global_average_pool
411
+ if attentional_pool:
412
+ self.attn_pool = AttentionalPooler(output_dim, width, n_head=attn_pooler_heads, n_queries=n_queries)
413
+ self.ln_post = norm_layer(output_dim)
414
+ self.proj = nn.Parameter(scale * torch.randn(output_dim, output_dim))
415
+ else:
416
+ self.attn_pool = None
417
+ self.ln_post = norm_layer(width)
418
+ self.proj = nn.Parameter(scale * torch.randn(width, output_dim))
419
+
420
+ self.init_parameters()
421
+
422
+ def lock(self, unlocked_groups=0, freeze_bn_stats=False):
423
+ for param in self.parameters():
424
+ param.requires_grad = False
425
+
426
+ if unlocked_groups != 0:
427
+ groups = [
428
+ [
429
+ self.conv1,
430
+ self.class_embedding,
431
+ self.positional_embedding,
432
+ self.ln_pre,
433
+ ],
434
+ *self.transformer.resblocks[:-1],
435
+ [
436
+ self.transformer.resblocks[-1],
437
+ self.ln_post,
438
+ ],
439
+ self.proj,
440
+ ]
441
+
442
+ def _unlock(x):
443
+ if isinstance(x, Sequence):
444
+ for g in x:
445
+ _unlock(g)
446
+ else:
447
+ if isinstance(x, torch.nn.Parameter):
448
+ x.requires_grad = True
449
+ else:
450
+ for p in x.parameters():
451
+ p.requires_grad = True
452
+
453
+ _unlock(groups[-unlocked_groups:])
454
+
455
+ def init_parameters(self):
456
+ # FIXME OpenAI CLIP did not define an init for the VisualTransformer
457
+ # TODO experiment if default PyTorch init, below, or alternate init is best.
458
+
459
+ # nn.init.normal_(self.class_embedding, std=self.scale)
460
+ # nn.init.normal_(self.positional_embedding, std=self.scale)
461
+ #
462
+ # proj_std = (self.transformer.width ** -0.5) * ((2 * self.transformer.layers) ** -0.5)
463
+ # attn_std = self.transformer.width ** -0.5
464
+ # fc_std = (2 * self.transformer.width) ** -0.5
465
+ # for block in self.transformer.resblocks:
466
+ # nn.init.normal_(block.attn.in_proj_weight, std=attn_std)
467
+ # nn.init.normal_(block.attn.out_proj.weight, std=proj_std)
468
+ # nn.init.normal_(block.mlp.c_fc.weight, std=fc_std)
469
+ # nn.init.normal_(block.mlp.c_proj.weight, std=proj_std)
470
+ #
471
+ # if self.text_projection is not None:
472
+ # nn.init.normal_(self.text_projection, std=self.scale)
473
+ pass
474
+
475
+ @torch.jit.ignore
476
+ def set_grad_checkpointing(self, enable=True):
477
+ self.transformer.grad_checkpointing = enable
478
+
479
+ def _global_pool(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
480
+ if self.global_average_pool:
481
+ return x.mean(dim=1), x
482
+ else:
483
+ return x[:, 0], x[:, 1:]
484
+
485
+ def forward(self, x: torch.Tensor, out_layers: list):
486
+
487
+ # to patches - whether to use dual patchnorm - https://arxiv.org/abs/2302.01327v1
488
+ if self.input_patchnorm:
489
+ # einops - rearrange(x, 'b c (h p1) (w p2) -> b (h w) (c p1 p2)')
490
+ x = x.reshape(x.shape[0], x.shape[1], self.grid_size[0], self.patch_size[0], self.grid_size[1], self.patch_size[1])
491
+ x = x.permute(0, 2, 4, 1, 3, 5)
492
+ x = x.reshape(x.shape[0], self.grid_size[0] * self.grid_size[1], -1)
493
+ x = self.patchnorm_pre_ln(x)
494
+ x = self.conv1(x)
495
+ else:
496
+ x = self.conv1(x) # shape = [*, width, grid, grid]
497
+ x = x.reshape(x.shape[0], x.shape[1], -1) # shape = [*, width, grid ** 2]
498
+ x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width]
499
+
500
+ # class embeddings and positional embeddings
501
+ x = torch.cat(
502
+ [self.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device),
503
+ x], dim=1) # shape = [*, grid ** 2 + 1, width]
504
+ x = x + self.positional_embedding.to(x.dtype)
505
+
506
+ # a patch_dropout of 0. would mean it is disabled and this function would do nothing but return what was passed in
507
+ x = self.patch_dropout(x)
508
+ x = self.ln_pre(x)
509
+
510
+ x = x.permute(1, 0, 2) # NLD -> LND
511
+
512
+ x, attn, patch_tokens = self.transformer(x, out_layers)
513
+
514
+ # attn = attn[0, 0, 1:].view(14, 14) # 49
515
+ B, C, L = attn[0].shape
516
+ H = int(math.sqrt(L-1))
517
+ out_attn = torch.zeros([H, H]).to('cuda')
518
+ for i in range(len(attn)):
519
+ out_attn += attn[i][0, 0, 1:].view(H, H)
520
+ x = x.permute(1, 0, 2) # LND -> NLD
521
+ patch_tokens = [patch_tokens[t].permute(1, 0, 2) for t in range(len(patch_tokens))] # LND -> NLD
522
+
523
+ if self.attn_pool is not None:
524
+ x = self.attn_pool(x)
525
+ x = self.ln_post(x)
526
+ pooled, tokens = self._global_pool(x)
527
+ else:
528
+ pooled, tokens = self._global_pool(x)
529
+ pooled = self.ln_post(pooled)
530
+
531
+ if self.proj is not None:
532
+ pooled = pooled @ self.proj
533
+
534
+ if self.output_tokens:
535
+ return pooled, patch_tokens
536
+
537
+ return pooled, patch_tokens
538
+
539
+
540
+ class TextTransformer(nn.Module):
541
+ output_tokens: torch.jit.Final[bool]
542
+
543
+ def __init__(
544
+ self,
545
+ context_length: int = 77,
546
+ vocab_size: int = 49408,
547
+ width: int = 512,
548
+ heads: int = 8,
549
+ layers: int = 12,
550
+ ls_init_value: float = None,
551
+ output_dim: int = 512,
552
+ act_layer: Callable = nn.GELU,
553
+ norm_layer: Callable = LayerNorm,
554
+ embed_cls: bool = False,
555
+ pad_id: int = 0,
556
+ output_tokens: bool = False,
557
+ ):
558
+ super().__init__()
559
+ self.output_tokens = output_tokens
560
+ self.num_pos = self.context_length = context_length
561
+ self.vocab_size = vocab_size
562
+ self.width = width
563
+ self.output_dim = output_dim
564
+ self.heads = heads
565
+ self.pad_id = pad_id
566
+
567
+ self.text_projection = nn.Parameter(torch.empty(width, output_dim))
568
+
569
+ if embed_cls:
570
+ self.cls_emb = nn.Parameter(torch.empty(width))
571
+ self.num_pos += 1
572
+ else:
573
+ self.cls_emb = None
574
+
575
+ self.token_embedding = nn.Embedding(vocab_size, width)
576
+ self.positional_embedding = nn.Parameter(torch.empty(self.num_pos, width))
577
+ self.transformer = Transformer(
578
+ width=width,
579
+ layers=layers,
580
+ heads=heads,
581
+ ls_init_value=ls_init_value,
582
+ act_layer=act_layer,
583
+ norm_layer=norm_layer,
584
+ )
585
+ self.ln_final = norm_layer(width)
586
+
587
+ self.register_buffer('attn_mask', self.build_attention_mask(), persistent=False)
588
+
589
+ self.init_parameters()
590
+
591
+ def init_parameters(self):
592
+ nn.init.normal_(self.token_embedding.weight, std=0.02)
593
+ nn.init.normal_(self.positional_embedding, std=0.01)
594
+ if self.cls_emb is not None:
595
+ nn.init.normal_(self.cls_emb, std=0.01)
596
+
597
+ proj_std = (self.transformer.width ** -0.5) * ((2 * self.transformer.layers) ** -0.5)
598
+ attn_std = self.transformer.width ** -0.5
599
+ fc_std = (2 * self.transformer.width) ** -0.5
600
+ for block in self.transformer.resblocks:
601
+ nn.init.normal_(block.attn.in_proj_weight, std=attn_std)
602
+ nn.init.normal_(block.attn.out_proj.weight, std=proj_std)
603
+ nn.init.normal_(block.mlp.c_fc.weight, std=fc_std)
604
+ nn.init.normal_(block.mlp.c_proj.weight, std=proj_std)
605
+
606
+ if self.text_projection is not None:
607
+ nn.init.normal_(self.text_projection, std=self.transformer.width ** -0.5)
608
+
609
+ @torch.jit.ignore
610
+ def set_grad_checkpointing(self, enable=True):
611
+ self.transformer.grad_checkpointing = enable
612
+
613
+ def build_attention_mask(self):
614
+ # lazily create causal attention mask, with full attention between the tokens
615
+ # pytorch uses additive attention mask; fill with -inf
616
+ mask = torch.empty(self.num_pos, self.num_pos)
617
+ mask.fill_(float("-inf"))
618
+ mask.triu_(1) # zero out the lower diagonal
619
+ return mask
620
+
621
+ def build_cls_mask(self, text, cast_dtype: torch.dtype):
622
+ cls_mask = (text != self.pad_id).unsqueeze(1)
623
+ cls_mask = F.pad(cls_mask, (1, 0, cls_mask.shape[2], 0), value=1.0)
624
+ additive_mask = torch.empty(cls_mask.shape, dtype=cast_dtype, device=cls_mask.device)
625
+ additive_mask.fill_(0)
626
+ additive_mask.masked_fill_(~cls_mask, float("-inf"))
627
+ additive_mask = torch.repeat_interleave(additive_mask, self.heads, 0)
628
+ return additive_mask
629
+
630
+ def _repeat(self, t, N: int):
631
+ return t.reshape(1, 1, -1).repeat(N, 1, 1)
632
+
633
+ def forward(self, text):
634
+ cast_dtype = self.transformer.get_cast_dtype()
635
+ seq_len = text.shape[1]
636
+
637
+ x = self.token_embedding(text).to(cast_dtype) # [batch_size, n_ctx, d_model]
638
+ attn_mask = self.attn_mask
639
+ if self.cls_emb is not None:
640
+ seq_len += 1
641
+ x = torch.cat([x, self._repeat(self.cls_emb, x.shape[0])], dim=1)
642
+ cls_mask = self.build_cls_mask(text, cast_dtype)
643
+ attn_mask = attn_mask[None, :seq_len, :seq_len] + cls_mask[:, :seq_len, :seq_len]
644
+
645
+ x = x + self.positional_embedding[:seq_len].to(cast_dtype)
646
+ x = x.permute(1, 0, 2) # NLD -> LND
647
+ x, attn, patch_tokens = self.transformer(x, attn_mask=attn_mask)
648
+ x = x.permute(1, 0, 2) # LND -> NLD
649
+
650
+ # x.shape = [batch_size, n_ctx, transformer.width]
651
+ # take features from the eot embedding (eot_token is the highest number in each sequence)
652
+ if self.cls_emb is not None:
653
+ pooled, tokens = x[:, -1], x[:, :-1]
654
+ pooled = self.ln_final(pooled)
655
+ else:
656
+ x = self.ln_final(x)
657
+ pooled, tokens = x[torch.arange(x.shape[0]), text.argmax(dim=-1)], x
658
+
659
+ if self.text_projection is not None:
660
+ pooled = pooled @ self.text_projection
661
+
662
+ if self.output_tokens:
663
+ return pooled, tokens
664
+
665
+ return pooled
666
+
667
+
668
+ class MultimodalTransformer(Transformer):
669
+ def __init__(
670
+ self,
671
+ width: int,
672
+ layers: int,
673
+ heads: int,
674
+ context_length: int = 77,
675
+ mlp_ratio: float = 4.0,
676
+ ls_init_value: float = None,
677
+ act_layer: Callable = nn.GELU,
678
+ norm_layer: Callable = LayerNorm,
679
+ output_dim: int = 512,
680
+ ):
681
+
682
+ super().__init__(
683
+ width=width,
684
+ layers=layers,
685
+ heads=heads,
686
+ mlp_ratio=mlp_ratio,
687
+ ls_init_value=ls_init_value,
688
+ act_layer=act_layer,
689
+ norm_layer=norm_layer,
690
+ )
691
+ self.context_length = context_length
692
+ self.cross_attn = nn.ModuleList([
693
+ ResidualAttentionBlock(
694
+ width,
695
+ heads,
696
+ mlp_ratio,
697
+ ls_init_value=ls_init_value,
698
+ act_layer=act_layer,
699
+ norm_layer=norm_layer,
700
+ is_cross_attention=True,
701
+ )
702
+ for _ in range(layers)
703
+ ])
704
+
705
+ self.register_buffer('attn_mask', self.build_attention_mask(), persistent=False)
706
+
707
+ self.ln_final = norm_layer(width)
708
+ self.text_projection = nn.Parameter(torch.empty(width, output_dim))
709
+
710
+ def init_parameters(self):
711
+ proj_std = (self.transformer.width ** -0.5) * ((2 * self.transformer.layers) ** -0.5)
712
+ attn_std = self.transformer.width ** -0.5
713
+ fc_std = (2 * self.transformer.width) ** -0.5
714
+ for block in self.transformer.resblocks:
715
+ nn.init.normal_(block.attn.in_proj_weight, std=attn_std)
716
+ nn.init.normal_(block.attn.out_proj.weight, std=proj_std)
717
+ nn.init.normal_(block.mlp.c_fc.weight, std=fc_std)
718
+ nn.init.normal_(block.mlp.c_proj.weight, std=proj_std)
719
+ for block in self.transformer.cross_attn:
720
+ nn.init.normal_(block.attn.in_proj_weight, std=attn_std)
721
+ nn.init.normal_(block.attn.out_proj.weight, std=proj_std)
722
+ nn.init.normal_(block.mlp.c_fc.weight, std=fc_std)
723
+ nn.init.normal_(block.mlp.c_proj.weight, std=proj_std)
724
+
725
+ if self.text_projection is not None:
726
+ nn.init.normal_(self.text_projection, std=self.transformer.width ** -0.5)
727
+
728
+ def build_attention_mask(self):
729
+ # lazily create causal attention mask, with full attention between the tokens
730
+ # pytorch uses additive attention mask; fill with -inf
731
+ mask = torch.empty(self.context_length, self.context_length)
732
+ mask.fill_(float("-inf"))
733
+ mask.triu_(1) # zero out the lower diagonal
734
+ return mask
735
+
736
+ def forward(self, image_embs, text_embs):
737
+ text_embs = text_embs.permute(1, 0, 2) # NLD -> LNDsq
738
+ image_embs = image_embs.permute(1, 0, 2) # NLD -> LND
739
+ seq_len = text_embs.shape[0]
740
+
741
+ for resblock, cross_attn in zip(self.resblocks, self.cross_attn):
742
+ if self.grad_checkpointing and not torch.jit.is_scripting():
743
+ # TODO: handle kwargs https://github.com/pytorch/pytorch/issues/79887#issuecomment-1161758372
744
+ text_embs = checkpoint(resblock, text_embs, None, None, self.attn_mask[:seq_len, :seq_len])
745
+ text_embs = checkpoint(cross_attn, text_embs, image_embs, image_embs, None)
746
+ else:
747
+ text_embs = resblock(text_embs, attn_mask=self.attn_mask[:seq_len, :seq_len])
748
+ text_embs = cross_attn(text_embs, k_x=image_embs, v_x=image_embs)
749
+
750
+ x = text_embs.permute(1, 0, 2) # LND -> NLD
751
+ x = self.ln_final(x)
752
+
753
+ if self.text_projection is not None:
754
+ x = x @ self.text_projection
755
+
756
+ return x
757
+
758
+ @torch.jit.ignore
759
+ def set_grad_checkpointing(self, enable=True):
760
+ self.grad_checkpointing = enable