mingdali committed on
Commit 2e488d0 · 1 Parent(s): 271378d

Delete visual.py

Files changed (1)
  1. visual.py +0 -429
visual.py DELETED
@@ -1,429 +0,0 @@
- # Copyright (c) Alibaba Cloud.
- #
- # This source code is licensed under the license found in the
- # LICENSE file in the root directory of this source tree.
-
- from collections import OrderedDict
- import math
- import requests
- from io import BytesIO
- from functools import partial
- from PIL import Image
- from typing import Callable, Optional, Sequence, Tuple, List
- import numpy as np
- import torch
- from torch import nn
- from torch.nn import functional as F
- from torch.nn.init import trunc_normal_
- from torchvision import transforms
- from torchvision.transforms import InterpolationMode
-
-
- def get_abs_pos(abs_pos, tgt_size):
-     # abs_pos: L, C
-     # tgt_size: M
-     # return: M, C
-     src_size = int(math.sqrt(abs_pos.size(0)))
-     tgt_size = int(math.sqrt(tgt_size))
-     dtype = abs_pos.dtype
-
-     if src_size != tgt_size:
-         return F.interpolate(
-             abs_pos.float().reshape(1, src_size, src_size, -1).permute(0, 3, 1, 2),
-             size=(tgt_size, tgt_size),
-             mode="bicubic",
-             align_corners=False,
-         ).permute(0, 2, 3, 1).flatten(0, 2).to(dtype=dtype)
-     else:
-         return abs_pos
-
- # https://github.com/facebookresearch/mae/blob/efb2a8062c206524e35e47d04501ed4f544c0ae8/util/pos_embed.py#L20
- def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False):
-     """
-     grid_size: int of the grid height and width
-     return:
-     pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
-     """
-     grid_h = np.arange(grid_size, dtype=np.float32)
-     grid_w = np.arange(grid_size, dtype=np.float32)
-     grid = np.meshgrid(grid_w, grid_h)  # here w goes first
-     grid = np.stack(grid, axis=0)
-
-     grid = grid.reshape([2, 1, grid_size, grid_size])
-     pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
-     if cls_token:
-         pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0)
-     return pos_embed
-
-
- def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
-     assert embed_dim % 2 == 0
-
-     # use half of dimensions to encode grid_h
-     emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0])  # (H*W, D/2)
-     emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1])  # (H*W, D/2)
-
-     emb = np.concatenate([emb_h, emb_w], axis=1)  # (H*W, D)
-     return emb
-
-
- def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
-     """
-     embed_dim: output dimension for each position
-     pos: a list of positions to be encoded: size (M,)
-     out: (M, D)
-     """
-     assert embed_dim % 2 == 0
-     omega = np.arange(embed_dim // 2, dtype=np.float32)
-     omega /= embed_dim / 2.
-     omega = 1. / 10000**omega  # (D/2,)
-
-     pos = pos.reshape(-1)  # (M,)
-     out = np.einsum('m,d->md', pos, omega)  # (M, D/2), outer product
-
-     emb_sin = np.sin(out)  # (M, D/2)
-     emb_cos = np.cos(out)  # (M, D/2)
-
-     emb = np.concatenate([emb_sin, emb_cos], axis=1)  # (M, D)
-     return emb
-
-
- class Resampler(nn.Module):
-     """
-     A 2D perceiver-resampler network with a single cross-attention layer over
-     (grid_size**2) learnable queries and a 2D sincos positional embedding.
-     Outputs:
-         A tensor with the shape of (grid_size**2, embed_dim)
-     """
-     def __init__(
-             self,
-             grid_size,
-             embed_dim,
-             num_heads,
-             kv_dim=None,
-             norm_layer=nn.LayerNorm
-     ):
-         super().__init__()
-         self.num_queries = grid_size ** 2
-         self.embed_dim = embed_dim
-         self.num_heads = num_heads
-
-         self.pos_embed = nn.Parameter(
-             torch.from_numpy(get_2d_sincos_pos_embed(embed_dim, grid_size)).float()
-         ).requires_grad_(False)
-
-         self.query = nn.Parameter(torch.zeros(self.num_queries, embed_dim))
-         trunc_normal_(self.query, std=.02)
-
-         if kv_dim is not None and kv_dim != embed_dim:
-             self.kv_proj = nn.Linear(kv_dim, embed_dim, bias=False)
-         else:
-             self.kv_proj = nn.Identity()
-
-         self.attn = nn.MultiheadAttention(embed_dim, num_heads)
-         self.ln_q = norm_layer(embed_dim)
-         self.ln_kv = norm_layer(embed_dim)
-
-         self.apply(self._init_weights)
-
-     def _init_weights(self, m):
-         if isinstance(m, nn.Linear):
-             trunc_normal_(m.weight, std=.02)
-             if isinstance(m, nn.Linear) and m.bias is not None:
-                 nn.init.constant_(m.bias, 0)
-         elif isinstance(m, nn.LayerNorm):
-             nn.init.constant_(m.bias, 0)
-             nn.init.constant_(m.weight, 1.0)
-
-     def forward(self, x, attn_mask=None):
-
-         pos_embed = get_abs_pos(self.pos_embed, x.size(1))
-
-         x = self.kv_proj(x)
-         x = self.ln_kv(x).permute(1, 0, 2)
-
-         N = x.shape[1]
-         q = self.ln_q(self.query)
-         out = self.attn(
-             self._repeat(q, N) + self.pos_embed.unsqueeze(1),
-             x + pos_embed.unsqueeze(1),
-             x,
-             attn_mask=attn_mask)[0]
-         return out.permute(1, 0, 2)
-
-     def _repeat(self, query, N: int):
-         return query.unsqueeze(1).repeat(1, N, 1)
-
-
- class VisualAttention(nn.Module):
-     """Self-attention layer class.
-
-     The self-attention layer takes input with size [s, b, h]
-     and returns output of the same size.
-     """
-
-     def __init__(self, embed_dim, num_heads,
-                  bias=True, kdim=None, vdim=None):
-         super(VisualAttention, self).__init__()
-         self.embed_dim = embed_dim
-         self.kdim = kdim if kdim is not None else embed_dim
-         self.vdim = vdim if vdim is not None else embed_dim
-         self._qkv_same_embed_dim = self.kdim == embed_dim and self.vdim == embed_dim
-
-         self.num_heads = num_heads
-
-         # Per attention head and per partition values.
-         assert embed_dim % num_heads == 0
-         self.hidden_size_per_attention_head = embed_dim // num_heads
-         self.num_attention_heads_per_partition = num_heads
-         self.hidden_size_per_partition = embed_dim
-
-         # Strided linear layer.
-         assert self._qkv_same_embed_dim, 'Only Support SelfAttention Currently'
-         self.in_proj = nn.Linear(embed_dim, 3 * embed_dim)
-         self.out_proj = nn.Linear(embed_dim, embed_dim)
-         self.norm_factor = math.sqrt(self.hidden_size_per_attention_head)
-
-     def forward(self, query, key, value, attn_mask=None):
-         # query/key/value: [sq, b, h]
-         sq, b, _ = query.size()
-
-         assert torch.allclose(query, key), 'Only Support Self-Attention Currently'
-         sk = sq
-         mixed_x_layer = self.in_proj(query)
-
-         # [sq, b, (np * 3 * hn)] --> [sq, b, np, 3 * hn]
-         new_tensor_shape = mixed_x_layer.size()[:-1] + \
-             (self.num_attention_heads_per_partition,
-              3 * self.hidden_size_per_attention_head)
-         mixed_x_layer = mixed_x_layer.view(*new_tensor_shape)
-
-         # [sq, b, np, 3 * hn] --> 3 [sq, b, np, hn]
-         query_layer, key_layer, value_layer = mixed_x_layer.split(
-             self.hidden_size_per_attention_head, dim=-1)
-
-         # [sq, b, np, hn] -> [sq, b * np, hn]
-         query_layer = query_layer.view(sq,
-                                        b * self.num_attention_heads_per_partition,
-                                        self.hidden_size_per_attention_head).transpose(0, 1)
-         # [sk, b, np, hn] -> [sk, b * np, hn]
-         key_layer = key_layer.view(sk,
-                                    b * self.num_attention_heads_per_partition,
-                                    self.hidden_size_per_attention_head).transpose(0, 1)
-
-         q_scaled = query_layer / self.norm_factor
-         if attn_mask is not None:
-             attention_probs = torch.baddbmm(attn_mask, q_scaled, key_layer.transpose(-2, -1))
-         else:
-             attention_probs = torch.bmm(q_scaled, key_layer.transpose(-2, -1))
-         attention_probs = attention_probs.softmax(dim=-1)
-
-         value_layer = value_layer.view(sk,
-                                        b * self.num_attention_heads_per_partition,
-                                        self.hidden_size_per_attention_head).transpose(0, 1)
-
-         # matmul: [b * np, sq, hn]
-         context_layer = torch.bmm(attention_probs, value_layer)
-
-         # change view [b, np, sq, hn]
-         context_layer = context_layer.view(b,
-                                            self.num_attention_heads_per_partition,
-                                            sq, self.hidden_size_per_attention_head)
-
-         # [b, np, sq, hn] --> [sq, b, np, hn]
-         context_layer = context_layer.permute(2, 0, 1, 3).contiguous()
-
-         # [sq, b, np, hn] --> [sq, b, hp]
-         new_context_layer_shape = context_layer.size()[:-2] + \
-             (self.hidden_size_per_partition,)
-         context_layer = context_layer.view(*new_context_layer_shape)
-
-         output = self.out_proj(context_layer)
-
-         return output
-
-
- class VisualAttentionBlock(nn.Module):
-     def __init__(
-             self,
-             d_model: int,
-             n_head: int,
-             mlp_ratio: float = 4.0,
-             act_layer: Callable = nn.GELU,
-             norm_layer: Callable = nn.LayerNorm,
-             is_cross_attention: bool = False,
-     ):
-         super().__init__()
-
-         self.ln_1 = norm_layer(d_model)
-         if is_cross_attention:
-             self.ln_1_kv = norm_layer(d_model)
-
-         self.ln_2 = norm_layer(d_model)
-         mlp_width = int(d_model * mlp_ratio)
-         self.attn = VisualAttention(d_model, n_head)
-         self.mlp = nn.Sequential(OrderedDict([
-             ("c_fc", nn.Linear(d_model, mlp_width)),
-             ("gelu", act_layer()),
-             ("c_proj", nn.Linear(mlp_width, d_model))
-         ]))
-
-     def attention(
-             self,
-             q_x: torch.Tensor,
-             k_x: Optional[torch.Tensor] = None,
-             v_x: Optional[torch.Tensor] = None,
-             attn_mask: Optional[torch.Tensor] = None,
-     ):
-         k_x = k_x if k_x is not None else q_x
-         v_x = v_x if v_x is not None else q_x
-
-         attn_mask = attn_mask.to(q_x.dtype) if attn_mask is not None else None
-         return self.attn(q_x, k_x, v_x, attn_mask=attn_mask)
-
-     def forward(
-             self,
-             q_x: torch.Tensor,
-             k_x: Optional[torch.Tensor] = None,
-             v_x: Optional[torch.Tensor] = None,
-             attn_mask: Optional[torch.Tensor] = None,
-     ):
-         k_x = self.ln_1_kv(k_x) if hasattr(self, "ln_1_kv") and k_x is not None else None
-         v_x = self.ln_1_kv(v_x) if hasattr(self, "ln_1_kv") and v_x is not None else None
-
-         x = q_x + self.attention(q_x=self.ln_1(q_x), k_x=k_x, v_x=v_x, attn_mask=attn_mask)
-         x = x + self.mlp(self.ln_2(x))
-         return x
-
-
- class TransformerBlock(nn.Module):
-     def __init__(
-             self,
-             width: int,
-             layers: int,
-             heads: int,
-             mlp_ratio: float = 4.0,
-             act_layer: Callable = nn.GELU,
-             norm_layer: Callable = nn.LayerNorm,
-     ):
-         super().__init__()
-         self.width = width
-         self.layers = layers
-
-         self.resblocks = nn.ModuleList([
-             VisualAttentionBlock(
-                 width, heads, mlp_ratio, act_layer=act_layer, norm_layer=norm_layer)
-             for _ in range(layers)
-         ])
-
-     def get_cast_dtype(self) -> torch.dtype:
-         return self.resblocks[0].mlp.c_fc.weight.dtype
-
-     def get_cast_device(self) -> torch.device:
-         return self.resblocks[0].mlp.c_fc.weight.device
-
-     def forward(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None):
-         for r in self.resblocks:
-             x = r(x, attn_mask=attn_mask)
-         return x
-
-
- class VisionTransformer(nn.Module):
-
-     def __init__(
-             self,
-             image_size: int,
-             patch_size: int,
-             width: int,
-             layers: int,
-             heads: int,
-             mlp_ratio: float,
-             n_queries: int = 256,
-             output_dim: int = 512,
-             **kwargs
-     ):
-         super().__init__()
-         image_height, image_width = self.image_size = (image_size, image_size)
-         patch_height, patch_width = self.patch_size = (patch_size, patch_size)
-         self.grid_size = (image_height // patch_height, image_width // patch_width)
-         self.output_dim = output_dim
-
-         mean = (0.48145466, 0.4578275, 0.40821073)
-         std = (0.26862954, 0.26130258, 0.27577711)
-         self.image_transform = transforms.Compose([
-             transforms.Resize(
-                 (image_size, image_size),
-                 interpolation=InterpolationMode.BICUBIC
-             ),
-             transforms.ToTensor(),
-             transforms.Normalize(mean=mean, std=std),
-         ])
-
-         self.conv1 = nn.Conv2d(in_channels=3, out_channels=width, kernel_size=patch_size, stride=patch_size, bias=False)
-
-         # class embeddings and positional embeddings
-         scale = width ** -0.5
-         self.positional_embedding = nn.Parameter(scale * torch.randn(256, width))
-
-         norm_layer = partial(nn.LayerNorm, eps=1e-6)
-         act_layer = nn.GELU
-
-         self.ln_pre = norm_layer(width)
-         self.transformer = TransformerBlock(
-             width,
-             layers,
-             heads,
-             mlp_ratio,
-             act_layer=act_layer,
-             norm_layer=norm_layer,
-         )
-
-         self.attn_pool = Resampler(
-             grid_size=int(math.sqrt(n_queries)),
-             embed_dim=output_dim,
-             num_heads=output_dim // 128,
-             kv_dim=width,
-             norm_layer=norm_layer,
-         )
-         self.ln_post = norm_layer(output_dim)
-         self.proj = nn.Parameter((output_dim ** -0.5) * torch.randn(output_dim, output_dim))
-
-     def forward(self, x: torch.Tensor):
-         x = x.to(
-             dtype=self.transformer.get_cast_dtype(),
-             device=self.transformer.get_cast_device(),
-         )
-         # to patches
-         x = self.conv1(x)  # shape = [*, width, grid, grid]
-         x = x.reshape(x.shape[0], x.shape[1], -1)  # shape = [*, width, grid ** 2]
-         x = x.permute(0, 2, 1)  # shape = [*, grid ** 2, width]
-
-         x = x + get_abs_pos(self.positional_embedding, x.size(1))
-
-         x = self.ln_pre(x)
-
-         x = x.permute(1, 0, 2)  # NLD -> LND
-         x = self.transformer(x)
-         x = x.permute(1, 0, 2)  # LND -> NLD
-
-         x = self.attn_pool(x)
-         x = self.ln_post(x)
-         x = x @ self.proj
-
-         return x
-
-     def encode(self, image_paths: List[str]):
-         images = []
-         for image_path in image_paths:
-             try:
-                 if image_path.startswith("http://") or image_path.startswith("https://"):
-                     image = Image.open(requests.get(image_path, stream=True).raw)
-                 else:
-                     image = Image.open(image_path)
-                 image = self.image_transform(image.convert("RGB"))
-             except Exception:
-                 image = torch.zeros((3, 448, 448))
-             images.append(image)
-         images = torch.stack(images, dim=0)
-
-         return self(images)
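
For context, the deleted visual.py exposed a CLIP-style VisionTransformer whose encode() takes image paths or URLs, preprocesses them, and returns features pooled by the Resampler. A minimal usage sketch, assuming the file is still importable as `visual`; the constructor values below are illustrative placeholders and are not specified anywhere in this commit:

import torch
from visual import VisionTransformer  # the module removed by this commit

# Hypothetical configuration for illustration only.
vit = VisionTransformer(
    image_size=448,
    patch_size=28,
    width=1024,
    layers=24,
    heads=16,
    mlp_ratio=4.0,
    n_queries=256,
    output_dim=4096,
)

# encode() loads and preprocesses each path/URL, then returns a
# (batch, n_queries, output_dim) tensor of resampled visual features.
with torch.no_grad():
    feats = vit.encode(["/path/to/example.jpg"])
print(feats.shape)  # torch.Size([1, 256, 4096])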