gberton committed on
Commit
cfbf0ec
·
verified ·
1 Parent(s): 8a5436b

Upload image_encoder.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. image_encoder.py +1002 -0
image_encoder.py ADDED
@@ -0,0 +1,1002 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2025 Google LLC
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+
16
+ """Vision encoder implementation in PyTorch."""
17
+
18
+ import functools
19
+ import math
20
+ import os
21
+ from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union
22
+ import warnings
23
+ import torch
24
+ from torch import nn
25
+ import torch.nn.functional as F
26
+ import torch.utils.checkpoint
27
+
28
+
29
class Mlp(nn.Module):
  """Transformer MLP (linear -> activation -> dropout -> linear -> dropout).

  Follows the DINOv2 implementation. Unspecified hidden/output widths default
  to the input width.
  """

  def __init__(
      self,
      in_features: int,
      hidden_features: Optional[int] = None,
      out_features: Optional[int] = None,
      act_layer: Callable[..., nn.Module] = nn.GELU,
      drop: float = 0.0,
      bias: bool = True,
  ) -> None:
    super().__init__()
    hidden_dim = hidden_features or in_features
    out_dim = out_features or in_features
    self.fc1 = nn.Linear(in_features, hidden_dim, bias=bias)
    self.act = act_layer()
    self.fc2 = nn.Linear(hidden_dim, out_dim, bias=bias)
    # A single Dropout module is shared by both drop sites, as in DINOv2.
    self.drop = nn.Dropout(drop)

  def forward(self, x: torch.Tensor) -> torch.Tensor:
    hidden = self.drop(self.act(self.fc1(x)))
    return self.drop(self.fc2(hidden))
56
+
57
+
58
def make_2tuple(x):
  """Return x as an (h, w) pair: ints are duplicated, 2-tuples pass through."""
  if isinstance(x, int):
    return (x, x)
  assert isinstance(x, tuple) and len(x) == 2
  return x
65
+
66
+
67
class PatchEmbed(nn.Module):
  """2D image to patch embedding: (B,C,H,W) -> (B,N,D).

  Splits the image into non-overlapping patches with a strided convolution and
  linearly projects each patch to embed_dim channels.
  """

  def __init__(
      self,
      img_size: Union[int, Tuple[int, int]] = 224,
      patch_size: Union[int, Tuple[int, int]] = 16,
      in_chans: int = 3,
      embed_dim: int = 768,
      norm_layer: Optional[Callable] = None,  # pylint: disable=g-bare-generic
      flatten_embedding: bool = True,
  ) -> None:
    super().__init__()

    # Normalize sizes to (h, w) pairs.
    if isinstance(img_size, int):
      image_hw = (img_size, img_size)
    else:
      assert isinstance(img_size, tuple) and len(img_size) == 2
      image_hw = img_size
    if isinstance(patch_size, int):
      patch_hw = (patch_size, patch_size)
    else:
      assert isinstance(patch_size, tuple) and len(patch_size) == 2
      patch_hw = patch_size

    self.img_size = image_hw
    self.patch_size = patch_hw
    # Number of patches along each spatial dimension (floor division).
    self.patches_resolution = (
        image_hw[0] // patch_hw[0],
        image_hw[1] // patch_hw[1],
    )
    self.num_patches = self.patches_resolution[0] * self.patches_resolution[1]

    self.in_chans = in_chans
    self.embed_dim = embed_dim
    self.flatten_embedding = flatten_embedding

    # Patch extraction + projection in a single non-overlapping convolution.
    self.proj = nn.Conv2d(
        in_chans, embed_dim, kernel_size=patch_hw, stride=patch_hw
    )
    self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()

  def forward(self, x: torch.Tensor) -> torch.Tensor:
    _, _, h, w = x.shape
    patch_h, patch_w = self.patch_size

    assert (
        h % patch_h == 0
    ), f"Input image height {h} is not a multiple of patch height {patch_h}"
    assert (
        w % patch_w == 0
    ), f"Input image width {w} is not a multiple of patch width: {patch_w}"

    x = self.proj(x)  # B C H' W'
    grid_h, grid_w = x.size(2), x.size(3)
    x = self.norm(x.flatten(2).transpose(1, 2))  # B H'W' C
    if self.flatten_embedding:
      return x
    return x.reshape(-1, grid_h, grid_w, self.embed_dim)  # B H' W' C

  def flops(self) -> float:
    """Approximate MAC count: one projection per patch plus the norm."""
    grid_h, grid_w = self.patches_resolution
    total = (
        grid_h
        * grid_w
        * self.embed_dim
        * self.in_chans
        * (self.patch_size[0] * self.patch_size[1])
    )
    if self.norm is not None:
      total += grid_h * grid_w * self.embed_dim
    return total
134
+
135
+
136
class SwiGLUFFN(nn.Module):
  """SwiGLU feed-forward network, following the DINOv2 implementation.

  Computes w3(silu(x1) * x2) where x1 and x2 are the two halves of a single
  fused projection w12(x).
  """

  def __init__(
      self,
      in_features: int,
      hidden_features: Optional[int] = None,
      out_features: Optional[int] = None,
      act_layer: Callable[..., nn.Module] = None,
      drop: float = 0.0,
      bias: bool = True,
  ) -> None:
    super().__init__()
    hidden_dim = hidden_features or in_features
    out_dim = out_features or in_features
    # Fused projection produces both the gate and the value halves at once.
    self.w12 = nn.Linear(in_features, 2 * hidden_dim, bias=bias)
    self.w3 = nn.Linear(hidden_dim, out_dim, bias=bias)

  def forward(self, x: torch.Tensor) -> torch.Tensor:
    gate, value = self.w12(x).chunk(2, dim=-1)
    return self.w3(F.silu(gate) * value)
159
+
160
+
161
# xFormers is an optional dependency: it is used only when importable and not
# explicitly disabled via the XFORMERS_DISABLED environment variable.
XFORMERS_ENABLED = os.environ.get("XFORMERS_DISABLED") is None
try:
  if XFORMERS_ENABLED:
    from xformers.ops import SwiGLU, memory_efficient_attention, unbind, fmha, scaled_index_add, index_select_cat  # pylint: disable=g-multiple-import, g-import-not-at-top

    XFORMERS_AVAILABLE = True
    warnings.warn("xFormers is available (SwiGLU)")
  else:
    # Disabled by the environment: take the same fallback path as a failed
    # import.
    warnings.warn("xFormers is disabled (SwiGLU)")
    raise ImportError
except ImportError:
  # Fall back to the pure-PyTorch SwiGLU implementation defined above; the
  # other xFormers names stay undefined and must not be reached when
  # XFORMERS_AVAILABLE is False.
  SwiGLU = SwiGLUFFN
  XFORMERS_AVAILABLE = False

  warnings.warn("xFormers is not available (SwiGLU)")
176
+
177
+
178
class SwiGLUFFNFused(SwiGLU):
  """SwiGLU FFN layer, following the DINOv2 implementation.

  Scales the hidden width by 2/3 (keeping parameter count comparable to a
  plain MLP with the same nominal width) and rounds it up to a multiple of 8
  for kernel efficiency.
  """

  def __init__(
      self,
      in_features: int,
      hidden_features: Optional[int] = None,
      out_features: Optional[int] = None,
      act_layer: Callable[..., nn.Module] = None,  # pylint: disable=unused-argument
      drop: float = 0.0,  # pylint: disable=unused-argument
      bias: bool = True,
  ) -> None:
    fused_hidden = hidden_features or in_features
    # 2/3 scaling, then round up to the next multiple of 8.
    fused_hidden = (int(fused_hidden * 2 / 3) + 7) // 8 * 8
    super().__init__(
        in_features=in_features,
        hidden_features=fused_hidden,
        out_features=out_features or in_features,
        bias=bias,
    )
199
+
200
+
201
class Attention(nn.Module):
  """Multi-head self-attention layer, following the DINOv2 implementation."""

  def __init__(
      self,
      dim: int,
      num_heads: int = 8,
      qkv_bias: bool = False,
      proj_bias: bool = True,
      attn_drop: float = 0.0,
      proj_drop: float = 0.0,
  ) -> None:
    super().__init__()
    self.num_heads = num_heads
    # 1/sqrt(head_dim) softmax temperature.
    self.scale = (dim // num_heads) ** -0.5

    self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
    self.attn_drop = nn.Dropout(attn_drop)
    self.proj = nn.Linear(dim, dim, bias=proj_bias)
    self.proj_drop = nn.Dropout(proj_drop)

  def forward(self, x: torch.Tensor) -> torch.Tensor:
    batch, tokens, channels = x.shape
    head_dim = channels // self.num_heads
    # (3, B, heads, N, head_dim) so q/k/v can be peeled off the first axis.
    qkv = (
        self.qkv(x)
        .reshape(batch, tokens, 3, self.num_heads, head_dim)
        .permute(2, 0, 3, 1, 4)
    )
    queries = qkv[0] * self.scale
    keys, values = qkv[1], qkv[2]

    scores = (queries @ keys.transpose(-2, -1)).softmax(dim=-1)
    scores = self.attn_drop(scores)

    out = (scores @ values).transpose(1, 2).reshape(batch, tokens, channels)
    return self.proj_drop(self.proj(out))
241
+
242
+
243
class MemEffAttention(Attention):
  """Memory Efficient Attention layer, following DINOv2 implementation.

  Uses xFormers' fused memory_efficient_attention kernel when available;
  otherwise falls back to the plain Attention forward, which cannot apply an
  attn_bias.
  """

  def forward(self, x: torch.Tensor, attn_bias=None) -> torch.Tensor:
    if not XFORMERS_AVAILABLE:
      # Nested-tensor (block-diagonal) biases require xFormers.
      if attn_bias is not None:
        raise AssertionError("xFormers is required for using nested tensors")
      return super().forward(x)

    b_dim, n_dim, c_dim = x.shape
    # (B, N, 3, heads, head_dim) — xFormers expects this layout unpermuted.
    qkv = self.qkv(x).reshape(
        b_dim, n_dim, 3, self.num_heads, c_dim // self.num_heads
    )

    q, k, v = unbind(qkv, 2)

    # Fused kernel; query scaling is handled inside xFormers, so q is not
    # pre-multiplied by self.scale here (unlike the base class).
    x = memory_efficient_attention(q, k, v, attn_bias=attn_bias)
    x = x.reshape([b_dim, n_dim, c_dim])

    x = self.proj(x)
    x = self.proj_drop(x)
    return x
265
+
266
+
267
class LayerScale(nn.Module):
  """Per-channel learnable scaling (LayerScale), following DINOv2."""

  def __init__(
      self,
      dim: int,
      init_values: Union[float, torch.Tensor] = 1e-5,
      inplace: bool = False,
  ) -> None:
    super().__init__()
    self.inplace = inplace
    # gamma starts at init_values for every channel.
    self.gamma = nn.Parameter(init_values * torch.ones(dim))

  def forward(self, x: torch.Tensor) -> torch.Tensor:
    if self.inplace:
      return x.mul_(self.gamma)
    return x * self.gamma
282
+
283
+
284
def drop_path_impl(x, drop_prob: float = 0.0, training: bool = False):
  """Per-sample stochastic depth: zero whole samples, rescale the survivors.

  Identity at eval time or when drop_prob is 0. Otherwise each sample in the
  batch is kept with probability (1 - drop_prob) and kept samples are divided
  by the keep probability so the expectation is unchanged.
  """
  if not training or drop_prob == 0.0:
    return x
  keep_prob = 1 - drop_prob
  # One Bernoulli draw per sample, broadcast over all remaining dims.
  mask_shape = (x.shape[0],) + (1,) * (x.ndim - 1)
  mask = x.new_empty(mask_shape).bernoulli_(keep_prob)
  if keep_prob > 0.0:
    mask.div_(keep_prob)
  return x * mask
296
+
297
+
298
class DropPath(nn.Module):
  """Drop paths (stochastic depth) per sample in the main residual path.

  Thin module wrapper around drop_path_impl that forwards its own training
  flag.
  """

  def __init__(self, drop_prob=None):
    super().__init__()
    self.drop_prob = drop_prob

  def forward(self, x):
    return drop_path_impl(x, self.drop_prob, self.training)
307
+
308
+
309
class Block(nn.Module):
  """Transformer Block Implementation, following DINOv2 implementation.

  Pre-norm residual block: x + ls1(attn(norm1(x))) then x + ls2(mlp(norm2(x))),
  with optional LayerScale and stochastic depth on both branches.
  """

  def __init__(
      self,
      dim: int,
      num_heads: int,
      mlp_ratio: float = 4.0,
      qkv_bias: bool = False,
      proj_bias: bool = True,
      ffn_bias: bool = True,
      drop: float = 0.0,
      attn_drop: float = 0.0,
      init_values=None,
      drop_path: float = 0.0,
      act_layer: Callable[..., nn.Module] = nn.GELU,
      norm_layer: Callable[..., nn.Module] = nn.LayerNorm,
      attn_class: Callable[..., nn.Module] = Attention,
      ffn_layer: Callable[..., nn.Module] = Mlp,
  ) -> None:
    super().__init__()
    self.norm1 = norm_layer(dim)
    self.attn = attn_class(
        dim,
        num_heads=num_heads,
        qkv_bias=qkv_bias,
        proj_bias=proj_bias,
        attn_drop=attn_drop,
        proj_drop=drop,
    )
    # LayerScale only when init_values is truthy (None or 0 => identity).
    self.ls1 = (
        LayerScale(dim, init_values=init_values)
        if init_values
        else nn.Identity()
    )
    self.drop_path1 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()

    self.norm2 = norm_layer(dim)
    mlp_hidden_dim = int(dim * mlp_ratio)
    self.mlp = ffn_layer(
        in_features=dim,
        hidden_features=mlp_hidden_dim,
        act_layer=act_layer,
        drop=drop,
        bias=ffn_bias,
    )
    self.ls2 = (
        LayerScale(dim, init_values=init_values)
        if init_values
        else nn.Identity()
    )
    self.drop_path2 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()

    # Batch-level drop ratio used by the stochastic-depth fast path below.
    self.sample_drop_ratio = drop_path

  def forward(self, x: torch.Tensor) -> torch.Tensor:
    def attn_residual_func(x: torch.Tensor) -> torch.Tensor:
      return self.ls1(self.attn(self.norm1(x)))

    def ffn_residual_func(x: torch.Tensor) -> torch.Tensor:
      return self.ls2(self.mlp(self.norm2(x)))

    if self.training and self.sample_drop_ratio > 0.1:
      # the overhead is compensated only for a drop path rate larger than 0.1:
      # compute residuals only for a surviving subset of the batch.
      x = drop_add_residual_stochastic_depth(
          x,
          residual_func=attn_residual_func,
          sample_drop_ratio=self.sample_drop_ratio,
      )
      x = drop_add_residual_stochastic_depth(
          x,
          residual_func=ffn_residual_func,
          sample_drop_ratio=self.sample_drop_ratio,
      )
    elif self.training and self.sample_drop_ratio > 0.0:
      # NOTE(review): drop_path1 is applied to both branches and drop_path2 is
      # unused here; this matches upstream DINOv2 (both share the same drop
      # probability) — confirm intentional.
      x = x + self.drop_path1(attn_residual_func(x))
      x = x + self.drop_path1(ffn_residual_func(x))
    else:
      # Eval or no stochastic depth: plain residual additions.
      x = x + attn_residual_func(x)
      x = x + ffn_residual_func(x)
    return x
390
+
391
+
392
def drop_add_residual_stochastic_depth(
    x: torch.Tensor,
    residual_func: Callable[[torch.Tensor], torch.Tensor],
    sample_drop_ratio: float = 0.0,
) -> torch.Tensor:
  """Stochastic depth from DINOv2: compute the residual on a random subset.

  A random subset of the batch is selected, the residual function is applied
  only to that subset, and the result is added back (scaled to keep the
  expectation unchanged).
  """
  # 1) pick a random surviving subset of the batch (at least one sample).
  batch = x.shape[0]
  subset_size = max(int(batch * (1 - sample_drop_ratio)), 1)
  subset_idx = torch.randperm(batch, device=x.device)[:subset_size]

  # 2) residual for the surviving samples only.
  residual = residual_func(x[subset_idx]).flatten(1)

  # 3) scatter-add the scaled residual back into the full batch.
  scale = batch / subset_size
  result = torch.index_add(
      x.flatten(1), 0, subset_idx, residual.to(dtype=x.dtype), alpha=scale
  )
  return result.view_as(x)
417
+
418
+
419
def get_branges_scales(x, sample_drop_ratio=0.0):
  """Pick a random surviving batch subset and its residual rescale factor."""
  batch = x.shape[0]
  # Keep at least one sample no matter how high the drop ratio is.
  subset_size = max(int(batch * (1 - sample_drop_ratio)), 1)
  subset_idx = torch.randperm(batch, device=x.device)[:subset_size]
  return subset_idx, batch / subset_size
425
+
426
+
427
def add_residual(
    x, brange, residual, residual_scale_factor, scaling_vector=None
):
  """Implement residual addition in the image encoder.

  Scatter-adds the (scaled) residual rows at the batch indices in brange.
  With a scaling_vector, uses xFormers' fused scaled_index_add; otherwise the
  plain path operates on tensors flattened to 2D and returns the flat result.
  """
  if scaling_vector is None:
    return torch.index_add(
        x.flatten(1),
        0,
        brange,
        residual.flatten(1).to(dtype=x.dtype),
        alpha=residual_scale_factor,
    )
  # Fused LayerScale + index_add path (requires xFormers).
  return scaled_index_add(
      x,
      brange,
      residual.to(dtype=x.dtype),
      scaling=scaling_vector,
      alpha=residual_scale_factor,
  )
450
+
451
+
452
# Module-level cache of block-diagonal attention biases used by the nested
# tensor path, keyed by the tuple of per-tensor (batch, seq_len) shapes.
attn_bias_cache: Dict[Tuple, Any] = {}  # pylint: disable=g-bare-generic
453
+
454
+
455
def get_attn_bias_and_cat(x_list, branges=None):
  """Performs the index select, cats the tensors, and provides the attn_bias from cache."""
  # Effective batch size per tensor: subset sizes if branges is given,
  # otherwise the full batch of each tensor.
  batch_sizes = (
      [b.shape[0] for b in branges]
      if branges is not None
      else [x.shape[0] for x in x_list]
  )
  all_shapes = tuple((b, x.shape[1]) for b, x in zip(batch_sizes, x_list))
  if all_shapes not in attn_bias_cache.keys():
    # Build (once per shape signature) a block-diagonal mask so sequences from
    # different source tensors cannot attend to each other after concatenation.
    seqlens = []
    for b, x in zip(batch_sizes, x_list):
      for _ in range(b):
        seqlens.append(x.shape[1])
    attn_bias = fmha.BlockDiagonalMask.from_seqlens(seqlens)
    attn_bias._batch_sizes = batch_sizes  # pylint: disable=protected-access
    attn_bias_cache[all_shapes] = attn_bias

  if branges is not None:
    # Gather the selected rows of every tensor and pack them into one
    # (1, total_tokens, C) sequence (xFormers fused kernel).
    cat_tensors = index_select_cat(
        [x.flatten(1) for x in x_list], branges
    ).view(1, -1, x_list[0].shape[-1])
  else:
    # No subsetting: fold each batch into the sequence dim and concatenate.
    tensors_bs1 = tuple(x.reshape([1, -1, *x.shape[2:]]) for x in x_list)
    cat_tensors = torch.cat(tensors_bs1, dim=1)

  return attn_bias_cache[all_shapes], cat_tensors
481
+
482
+
483
def drop_add_residual_stochastic_depth_list(
    x_list: List[torch.Tensor],
    residual_func: Callable[[torch.Tensor, Any], torch.Tensor],
    sample_drop_ratio: float = 0.0,
    scaling_vector=None,
) -> List[torch.Tensor]:
  """Adds residuals, with stochastic depth, to a list of tensors.

  List-variant of drop_add_residual_stochastic_depth: all surviving subsets
  are concatenated into a single sequence (with a block-diagonal attention
  bias) so the residual function runs once for the whole list.
  """
  # 1) generate random set of indices for dropping samples in the batch.
  branges_scales = [
      get_branges_scales(x, sample_drop_ratio=sample_drop_ratio) for x in x_list
  ]
  branges = [s[0] for s in branges_scales]
  residual_scale_factors = [s[1] for s in branges_scales]

  # 2) get attention bias and index+concat the tensors.
  attn_bias, x_cat = get_attn_bias_and_cat(x_list, branges)

  # 3) apply residual_func to get residual, and split the result back into
  # per-tensor chunks.
  residual_list = attn_bias.split(residual_func(x_cat, attn_bias=attn_bias))  # type: ignore

  outputs = []
  for x, brange, residual, residual_scale_factor in zip(
      x_list, branges, residual_list, residual_scale_factors
  ):
    outputs.append(
        add_residual(
            x, brange, residual, residual_scale_factor, scaling_vector
        ).view_as(x)
    )
  return outputs
513
+
514
+
515
class NestedTensorBlock(Block):
  """Nested tensor block implementation.

  Runs the Block logic over a *list* of token tensors by concatenating them
  into one sequence with a block-diagonal attention bias (requires xFormers),
  so differently-shaped batches share a single kernel launch.
  """

  def forward_nested(self, x_list: List[torch.Tensor]) -> List[torch.Tensor]:
    """x_list contains a list of tensors to nest together and run."""
    # The attn_bias keyword below only exists on the xFormers attention.
    assert isinstance(self.attn, MemEffAttention)

    if self.training and self.sample_drop_ratio > 0.0:

      def attn_residual_func(x: torch.Tensor, attn_bias=None) -> torch.Tensor:
        return self.attn(self.norm1(x), attn_bias=attn_bias)

      def ffn_residual_func(x: torch.Tensor, attn_bias=None) -> torch.Tensor:
        del attn_bias
        return self.mlp(self.norm2(x))

      # LayerScale is folded into the residual addition via scaling_vector
      # (the residual funcs above intentionally skip ls1/ls2).
      x_list = drop_add_residual_stochastic_depth_list(
          x_list,
          residual_func=attn_residual_func,
          sample_drop_ratio=self.sample_drop_ratio,
          scaling_vector=self.ls1.gamma
          if isinstance(self.ls1, LayerScale)
          else None,
      )
      x_list = drop_add_residual_stochastic_depth_list(
          x_list,
          residual_func=ffn_residual_func,
          sample_drop_ratio=self.sample_drop_ratio,
          # NOTE(review): this checks isinstance on ls1 while passing
          # ls2.gamma; harmless because ls1/ls2 are created under the same
          # condition, but confirm it matches upstream intent.
          scaling_vector=self.ls2.gamma
          if isinstance(self.ls1, LayerScale)
          else None,
      )
      return x_list
    else:

      def attn_residual_func(x: torch.Tensor, attn_bias=None) -> torch.Tensor:
        return self.ls1(self.attn(self.norm1(x), attn_bias=attn_bias))

      def ffn_residual_func(x: torch.Tensor, attn_bias=None) -> torch.Tensor:
        del attn_bias
        return self.ls2(self.mlp(self.norm2(x)))

      attn_bias, x = get_attn_bias_and_cat(x_list)
      x = x + attn_residual_func(x, attn_bias=attn_bias)
      x = x + ffn_residual_func(x)
      # Split the concatenated sequence back into the original list layout.
      return attn_bias.split(x)

  def forward(self, x):
    # Accepts either a single tensor (plain Block behavior) or a list of
    # tensors (nested path, xFormers only).
    if isinstance(x, torch.Tensor):
      return super().forward(x)
    elif isinstance(x, list):
      if not XFORMERS_AVAILABLE:
        raise AssertionError("xFormers is required for using nested tensors")
      return self.forward_nested(x)
    else:
      raise AssertionError
571
+
572
+
573
def named_apply(
    fn: Callable,  # pylint: disable=g-bare-generic
    module: nn.Module,
    name="",
    depth_first=True,
    include_root=False,
) -> nn.Module:
  """Recursively apply fn(module=..., name=...) over a module tree.

  Children are always visited; the root is visited only when include_root is
  True — before its children when depth_first is False, after them otherwise.
  Returns the module for chaining.
  """
  if include_root and not depth_first:
    fn(module=module, name=name)
  for child_name, child in module.named_children():
    # Dotted path relative to the original root.
    qualified = f"{name}.{child_name}" if name else child_name
    named_apply(
        fn=fn,
        module=child,
        name=qualified,
        depth_first=depth_first,
        include_root=True,
    )
  if include_root and depth_first:
    fn(module=module, name=name)
  return module
595
+
596
+
597
class BlockChunk(nn.ModuleList):
  """A ModuleList that applies its members sequentially, like nn.Sequential.

  Used to group transformer blocks into chunks for FSDP wrapping.
  """

  def forward(self, x):
    return functools.reduce(lambda acc, module: module(acc), self, x)
603
+
604
+
605
class VisionTransformer(nn.Module):
  """Vision Transformer implementation.

  DINOv2-style ViT: patch embedding + [CLS] token + optional register tokens,
  a stack of pre-norm transformer blocks (optionally chunked for FSDP), and a
  final LayerNorm. forward() returns normalized cls/register/patch tokens.
  """

  def __init__(
      self,
      img_size=224,
      patch_size=16,
      in_chans=3,
      embed_dim=768,
      depth=12,
      num_heads=12,
      mlp_ratio=4.0,
      qkv_bias=True,
      ffn_bias=True,
      proj_bias=True,
      drop_path_rate=0.0,
      drop_path_uniform=False,
      init_values=None,  # for layerscale: None or 0 => no layerscale
      embed_layer=PatchEmbed,
      act_layer=nn.GELU,
      block_fn=Block,
      ffn_layer="mlp",
      block_chunks=1,
      num_register_tokens=0,
      interpolate_antialias=False,
      interpolate_offset=0.1,
  ):
    """Defines the Vision Transformer model.

    Args:
      img_size (int, tuple): input image size
      patch_size (int, tuple): patch size
      in_chans (int): number of input channels
      embed_dim (int): embedding dimension
      depth (int): depth of transformer
      num_heads (int): number of attention heads
      mlp_ratio (int): ratio of mlp hidden dim to embedding dim
      qkv_bias (bool): enable bias for qkv if True
      ffn_bias (bool): enable bias for ffn if True
      proj_bias (bool): enable bias for proj in attn if True
      drop_path_rate (float): stochastic depth rate
      drop_path_uniform (bool): apply uniform drop rate across blocks
      init_values (float): layer-scale init values
      embed_layer (nn.Module): patch embedding layer
      act_layer (nn.Module): MLP activation layer
      block_fn (nn.Module): transformer block class
      ffn_layer (str): "mlp", "swiglu", "swiglufused" or "identity"
      block_chunks: (int) split block sequence into block_chunks units for FSDP
        wrap
      num_register_tokens: (int) number of extra cls tokens (so-called
        "registers")
      interpolate_antialias: (str) flag to apply anti-aliasing when
        interpolating positional embeddings
      interpolate_offset: (float) work-around offset to apply when interpolating
        positional embeddings
    """
    super().__init__()
    norm_layer = functools.partial(nn.LayerNorm, eps=1e-6)

    self.num_features = self.embed_dim = (
        embed_dim  # num_features for consistency with other models
    )
    # One [CLS] token; register tokens are counted separately.
    self.num_tokens = 1
    self.n_blocks = depth
    self.num_heads = num_heads
    self.patch_size = patch_size
    self.num_register_tokens = num_register_tokens
    self.interpolate_antialias = interpolate_antialias
    self.interpolate_offset = interpolate_offset

    self.patch_embed = embed_layer(
        img_size=img_size,
        patch_size=patch_size,
        in_chans=in_chans,
        embed_dim=embed_dim,
    )
    num_patches = self.patch_embed.num_patches

    self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
    # Positional embedding covers the [CLS] token plus all patch tokens;
    # register tokens (added below) get no positional embedding.
    self.pos_embed = nn.Parameter(
        torch.zeros(1, num_patches + self.num_tokens, embed_dim)
    )
    assert num_register_tokens >= 0
    self.register_tokens = (
        nn.Parameter(torch.zeros(1, num_register_tokens, embed_dim))
        if num_register_tokens
        else None
    )

    if drop_path_uniform:
      dpr = [drop_path_rate] * depth
    else:
      dpr = [
          drop_path_rate * i / max(depth - 1, 1) for i in range(depth)
      ]  # stochastic depth decay rule

    if ffn_layer == "mlp":
      ffn_layer = Mlp
    elif ffn_layer == "swiglufused" or ffn_layer == "swiglu":
      ffn_layer = SwiGLUFFNFused
    else:
      raise NotImplementedError

    blocks_list = [
        block_fn(
            dim=embed_dim,
            num_heads=num_heads,
            mlp_ratio=mlp_ratio,
            qkv_bias=qkv_bias,
            proj_bias=proj_bias,
            ffn_bias=ffn_bias,
            drop_path=dpr[i],
            norm_layer=norm_layer,
            act_layer=act_layer,
            ffn_layer=ffn_layer,
            init_values=init_values,
        )
        for i in range(depth)
    ]
    if block_chunks > 0:
      self.chunked_blocks = True
      chunked_blocks = []
      chunksize = depth // block_chunks
      for i in range(0, depth, chunksize):
        # this is to keep the block index consistent if we chunk the block
        # list: each chunk is left-padded with nn.Identity placeholders.
        chunked_blocks.append(
            [nn.Identity()] * i + blocks_list[i : i + chunksize]
        )
      self.blocks = nn.ModuleList([BlockChunk(p) for p in chunked_blocks])
    else:
      self.chunked_blocks = False
      self.blocks = nn.ModuleList(blocks_list)

    self.norm = norm_layer(embed_dim)
    # Identity head: forward() returns raw normalized tokens.
    self.head = nn.Identity()

    # Learned token substituted for masked patch embeddings in
    # prepare_tokens_with_masks (masked-image-modeling support).
    self.mask_token = nn.Parameter(torch.zeros(1, embed_dim))

    self.init_weights()

  def init_weights(self):
    """Initializes pos/cls/register embeddings and all Linear layers."""
    nn.init.trunc_normal_(self.pos_embed, std=0.02)
    nn.init.normal_(self.cls_token, std=1e-6)
    if self.register_tokens is not None:
      nn.init.normal_(self.register_tokens, std=1e-6)
    named_apply(init_weights_vit_timm, self)

  def interpolate_pos_encoding(self, x, w, h):
    """Resizes the stored patch positional embedding to the input resolution.

    Returns the (possibly interpolated) positional embedding, cast back to
    x's dtype. The [CLS] position is passed through untouched.
    """
    previous_dtype = x.dtype
    npatch = x.shape[1] - 1
    num_patches = self.pos_embed.shape[1] - 1
    # Fast path: token count matches and the input is square.
    if npatch == num_patches and w == h:
      return self.pos_embed
    pos_embed = self.pos_embed.float()
    class_pos_embed = pos_embed[:, 0]
    patch_pos_embed = pos_embed[:, 1:]
    dim = x.shape[-1]
    w0 = w // self.patch_size
    h0 = h // self.patch_size
    num_patches_dim = int(
        math.sqrt(num_patches)
    )  # Recover the number of patches in each dimension
    assert num_patches == num_patches_dim * num_patches_dim
    kwargs = {}
    if self.interpolate_offset:
      # Historical DINO work-around: nudge the scale factor slightly to avoid
      # floating-point rounding issues in interpolate.
      sx = float(w0 + self.interpolate_offset) / num_patches_dim
      sy = float(h0 + self.interpolate_offset) / num_patches_dim
      kwargs["scale_factor"] = (sx, sy)
    else:
      # Simply specify an output size instead of a scale factor
      kwargs["size"] = (w0, h0)
    # NOTE(review): w0 is used as the first spatial dim here, mirroring the
    # (w, h) ordering used in prepare_tokens_with_masks — consistent
    # internally, but confirm for non-square inputs.
    patch_pos_embed = nn.functional.interpolate(
        patch_pos_embed.reshape(
            1, num_patches_dim, num_patches_dim, dim
        ).permute(0, 3, 1, 2),
        mode="bilinear",
        antialias=self.interpolate_antialias,
        **kwargs,
    )
    assert (w0, h0) == patch_pos_embed.shape[-2:]
    patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
    return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1).to(
        previous_dtype
    )

  def prepare_tokens_with_masks(self, x, masks=None):
    """Patch-embeds x, applies masks, and prepends [CLS]/register tokens."""
    # NOTE(review): dims 2 and 3 of a (B, C, H, W) input are unpacked as
    # (w, h) here; the naming is swapped relative to convention but is used
    # consistently with interpolate_pos_encoding.
    _, _, w, h = x.shape
    x = self.patch_embed(x)
    if masks is not None:
      # Replace masked patch embeddings with the learned mask token.
      x = torch.where(
          masks.unsqueeze(-1), self.mask_token.to(x.dtype).unsqueeze(0), x
      )

    x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1)
    x = x + self.interpolate_pos_encoding(x, w, h)

    if self.register_tokens is not None:
      # Insert register tokens between [CLS] and the patch tokens; they do not
      # receive positional embeddings.
      x = torch.cat(
          (
              x[:, :1],
              self.register_tokens.expand(x.shape[0], -1, -1),
              x[:, 1:],
          ),
          dim=1,
      )

    return x

  def forward_features_list(self, x_list, masks_list):
    """Runs the backbone over a list of (input, mask) pairs via nested blocks."""
    x = [
        self.prepare_tokens_with_masks(x, masks)
        for x, masks in zip(x_list, masks_list)
    ]
    for blk in self.blocks:
      x = blk(x)

    all_x = x
    output = []
    for x, masks in zip(all_x, masks_list):
      x_norm = self.norm(x)
      # Token layout: [CLS] | register tokens | patch tokens.
      output.append({
          "x_norm_1st_clstoken": x_norm[:, :1],
          "x_norm_2nd_clstoken": x_norm[:, 1 : self.num_register_tokens + 1],
          "x_norm_patchtokens": x_norm[:, self.num_register_tokens + 1 :],
          "x_prenorm": x,
          "masks": masks,
      })
    return output

  def forward_features(self, x, masks=None):
    """Runs the backbone; returns a dict of normalized token groups."""
    if isinstance(x, list):
      return self.forward_features_list(x, masks)

    x = self.prepare_tokens_with_masks(x, masks)

    for blk in self.blocks:
      x = blk(x)

    x_norm = self.norm(x)
    # Token layout: [CLS] | register tokens | patch tokens.
    return {
        "x_norm_1st_clstoken": x_norm[:, :1],
        "x_norm_2nd_clstoken": x_norm[:, 1 : self.num_register_tokens + 1],
        "x_norm_patchtokens": x_norm[:, self.num_register_tokens + 1 :],
        "x_prenorm": x,
        "masks": masks,
    }

  def _get_intermediate_layers_not_chunked(self, x, n=1):
    """Collects outputs of selected blocks when blocks are a flat list."""
    x = self.prepare_tokens_with_masks(x)
    # If n is an int, take the n last blocks. If it's a list, take them
    output, total_block_len = [], len(self.blocks)
    blocks_to_take = (
        range(total_block_len - n, total_block_len) if isinstance(n, int) else n
    )
    for i, blk in enumerate(self.blocks):
      x = blk(x)
      if i in blocks_to_take:
        output.append(x)
    assert len(output) == len(
        blocks_to_take
    ), f"only {len(output)} / {len(blocks_to_take)} blocks found"
    return output

  def _get_intermediate_layers_chunked(self, x, n=1):
    """Collects outputs of selected blocks when blocks are chunked for FSDP."""
    x = self.prepare_tokens_with_masks(x)
    # The last chunk is padded with nn.Identity up to the global block index,
    # so its length equals the total number of blocks.
    output, i, total_block_len = [], 0, len(self.blocks[-1])
    # If n is an int, take the n last blocks. If it's a list, take them
    blocks_to_take = (
        range(total_block_len - n, total_block_len) if isinstance(n, int) else n
    )
    for block_chunk in self.blocks:
      for blk in block_chunk[i:]:  # Passing the nn.Identity()
        x = blk(x)
        if i in blocks_to_take:
          output.append(x)
        i += 1
    assert len(output) == len(
        blocks_to_take
    ), f"only {len(output)} / {len(blocks_to_take)} blocks found"
    return output

  def get_intermediate_layers(
      self,
      x: torch.Tensor,
      n: Union[int, Sequence] = 1,  # Layers or n last layers to take # pylint: disable=g-bare-generic
      reshape: bool = False,
      return_class_token: bool = False,
      norm=True,
  ) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]]]:  # pylint: disable=g-one-element-tuple
    """Returns patch tokens (and optionally class tokens) of selected blocks.

    With reshape=True, patch tokens are returned as (B, C, H', W') feature
    maps instead of (B, N, C) sequences.
    """
    if self.chunked_blocks:
      outputs = self._get_intermediate_layers_chunked(x, n)
    else:
      outputs = self._get_intermediate_layers_not_chunked(x, n)
    if norm:
      outputs = [self.norm(out) for out in outputs]
    class_tokens = [out[:, 0] for out in outputs]
    # Drop [CLS] and register tokens, keeping only patch tokens.
    outputs = [out[:, 1 + self.num_register_tokens :] for out in outputs]
    if reshape:
      batch_size, _, w, h = x.shape
      outputs = [
          out.reshape(
              batch_size, w // self.patch_size, h // self.patch_size, -1
          )
          .permute(0, 3, 1, 2)
          .contiguous()
          for out in outputs
      ]
    if return_class_token:
      return tuple(zip(outputs, class_tokens))
    return tuple(outputs)

  def forward(self, *args, is_training=False, **kwargs):
    """Returns the raw feature dict when training, else (cls1, cls2, patches)."""
    ret = self.forward_features(*args, **kwargs)
    if is_training:
      return ret
    else:
      # self.head is nn.Identity, so these are the normalized tokens.
      return self.head(ret["x_norm_1st_clstoken"]), self.head(
          ret["x_norm_2nd_clstoken"]
      ), ret["x_norm_patchtokens"]
924
+
925
+
926
def init_weights_vit_timm(module: nn.Module, name: str = ""):  # pylint: disable=unused-argument
  """ViT weight initialization, original timm impl (for reproducibility).

  Only nn.Linear modules are touched: truncated-normal weights (std 0.02) and
  zero biases. All other module types are left unchanged.
  """
  if not isinstance(module, nn.Linear):
    return
  nn.init.trunc_normal_(module.weight, std=0.02)
  if module.bias is not None:
    nn.init.zeros_(module.bias)
932
+
933
+
934
def vit_small(patch_size=14, **kwargs):
  """ViT-S encoder (384-dim, 12 blocks, 6 heads) with mem-efficient attention
  and one register token. Extra kwargs are forwarded to VisionTransformer."""
  return VisionTransformer(
      patch_size=patch_size,
      num_register_tokens=1,
      block_fn=functools.partial(Block, attn_class=MemEffAttention),
      embed_dim=384,
      num_heads=6,
      depth=12,
      mlp_ratio=4,
      **kwargs,
  )
946
+
947
+
948
def vit_base(patch_size=14, **kwargs):
  """ViT-B encoder (768-dim, 12 blocks, 12 heads) with mem-efficient attention
  and one register token. Extra kwargs are forwarded to VisionTransformer."""
  return VisionTransformer(
      patch_size=patch_size,
      num_register_tokens=1,
      block_fn=functools.partial(Block, attn_class=MemEffAttention),
      embed_dim=768,
      num_heads=12,
      depth=12,
      mlp_ratio=4,
      **kwargs,
  )
960
+
961
+
962
def vit_large(patch_size=14, **kwargs):
  """ViT-L encoder (1024-dim, 24 blocks, 16 heads) with mem-efficient attention
  and one register token. Extra kwargs are forwarded to VisionTransformer."""
  return VisionTransformer(
      patch_size=patch_size,
      num_register_tokens=1,
      block_fn=functools.partial(Block, attn_class=MemEffAttention),
      embed_dim=1024,
      num_heads=16,
      depth=24,
      mlp_ratio=4,
      **kwargs,
  )
974
+
975
+
976
def vit_so400m(patch_size=14, **kwargs):
  """SoViT 400M model (https://arxiv.org/abs/2305.13035).

  Shape-optimized ViT: 1152-dim, 27 blocks, 16 heads, with the paper's
  non-integer MLP ratio (4304/1152). Uses mem-efficient attention and one
  register token; extra kwargs are forwarded to VisionTransformer.
  """
  return VisionTransformer(
      patch_size=patch_size,
      num_register_tokens=1,
      block_fn=functools.partial(Block, attn_class=MemEffAttention),
      embed_dim=1152,
      num_heads=16,
      depth=27,
      mlp_ratio=4304 / 1152,
      **kwargs,
  )
989
+
990
+
991
def vit_giant2(patch_size=14, **kwargs):
  """ViT-g/2 encoder (1536-dim, 40 blocks, 24 heads) with mem-efficient
  attention and one register token. Extra kwargs are forwarded to
  VisionTransformer."""
  return VisionTransformer(
      patch_size=patch_size,
      num_register_tokens=1,
      block_fn=functools.partial(Block, attn_class=MemEffAttention),
      embed_dim=1536,
      num_heads=24,
      depth=40,
      mlp_ratio=4,
      **kwargs,
  )