huu-ontocord committed (verified)
Commit b6b800b · Parent: c2b09ed

Update seed2_tokenizer.py

Files changed (1):
  1. seed2_tokenizer.py +443 -126
seed2_tokenizer.py CHANGED
@@ -36,7 +36,8 @@ from diffusers import DiffusionPipeline
 from PIL import Image
 from torchvision import transforms
 
-WEIGHTS_NAME = 'seed_quantizer.pt'
+import torch.utils.checkpoint as checkpoint
+
 DIFFUSION_NAME = 'stabilityai/stable-diffusion-2-1-unclip'
 
 # from qformer.qformer_quantizer import Blip2QformerQuantizer
@@ -66,8 +67,8 @@ import torch.nn as nn
 import torch.distributed as dist
 import torch.nn.functional as F
 
-
-from .eva_vit import create_eva_vit_g, VisionTransformerEvaClip
+from timm.models.layers import drop_path, to_2tuple, trunc_normal_
+
 from transformers import BertTokenizer
 
 import math
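The new `torch.utils.checkpoint` import backs the `use_checkpoint` path in the EVA ViT code added below. A minimal sketch of what that call trades, on recent PyTorch; the stand-in block and tensor sizes are illustrative, not from this commit:

```python
import torch
import torch.nn as nn
import torch.utils.checkpoint as checkpoint

# A stand-in transformer block; the real code wraps BlockEvaVit this way.
blk = nn.Sequential(nn.Linear(1408, 1408), nn.GELU(), nn.Linear(1408, 1408))
x = torch.randn(2, 257, 1408, requires_grad=True)

# Activations inside blk are recomputed during backward instead of stored,
# cutting peak memory at the cost of a second forward pass through blk.
y = checkpoint.checkpoint(blk, x, use_reentrant=False)
y.sum().backward()
```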
 
 
@@ -130,6 +131,420 @@ from transformers.models.bert.configuration_bert import BertConfig
 
 #torch.set_printoptions(profile="full")
 
+class DropPathEvaVit(nn.Module):
+    """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
+    """
+    def __init__(self, drop_prob=None):
+        super(DropPathEvaVit, self).__init__()
+        self.drop_prob = drop_prob
+
+    def forward(self, x):
+        return drop_path(x, self.drop_prob, self.training)
+
+    def extra_repr(self) -> str:
+        return 'p={}'.format(self.drop_prob)
+
+
+class MlpEvaVit(nn.Module):
+    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
+        super().__init__()
+        out_features = out_features or in_features
+        hidden_features = hidden_features or in_features
+        self.fc1 = nn.Linear(in_features, hidden_features)
+        self.act = act_layer()
+        self.fc2 = nn.Linear(hidden_features, out_features)
+        self.drop = nn.Dropout(drop)
+
+    def forward(self, x):
+        x = self.fc1(x)
+        x = self.act(x)
+        # x = self.drop(x)
+        # commented out to match the original BERT implementation
+        x = self.fc2(x)
+        x = self.drop(x)
+        return x
+
+
+class AttentionEvaVit(nn.Module):
+    def __init__(self,
+                 dim,
+                 num_heads=8,
+                 qkv_bias=False,
+                 qk_scale=None,
+                 attn_drop=0.,
+                 proj_drop=0.,
+                 window_size=None,
+                 attn_head_dim=None):
+        super().__init__()
+        self.num_heads = num_heads
+        head_dim = dim // num_heads
+        if attn_head_dim is not None:
+            head_dim = attn_head_dim
+        all_head_dim = head_dim * self.num_heads
+        self.scale = qk_scale or head_dim**-0.5
+
+        self.qkv = nn.Linear(dim, all_head_dim * 3, bias=False)
+        if qkv_bias:
+            self.q_bias = nn.Parameter(torch.zeros(all_head_dim))
+            self.v_bias = nn.Parameter(torch.zeros(all_head_dim))
+        else:
+            self.q_bias = None
+            self.v_bias = None
+
+        if window_size:
+            self.window_size = window_size
+            self.num_relative_distance = (2 * window_size[0] - 1) * (2 * window_size[1] - 1) + 3
+            self.relative_position_bias_table = nn.Parameter(torch.zeros(self.num_relative_distance,
+                                                                         num_heads))  # 2*Wh-1 * 2*Ww-1, nH
+            # cls to token & token to cls & cls to cls
+
+            # get pair-wise relative position index for each token inside the window
+            coords_h = torch.arange(window_size[0])
+            coords_w = torch.arange(window_size[1])
+            coords = torch.stack(torch.meshgrid([coords_h, coords_w]))  # 2, Wh, Ww
+            coords_flatten = torch.flatten(coords, 1)  # 2, Wh*Ww
+            relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]  # 2, Wh*Ww, Wh*Ww
+            relative_coords = relative_coords.permute(1, 2, 0).contiguous()  # Wh*Ww, Wh*Ww, 2
+            relative_coords[:, :, 0] += window_size[0] - 1  # shift to start from 0
+            relative_coords[:, :, 1] += window_size[1] - 1
+            relative_coords[:, :, 0] *= 2 * window_size[1] - 1
+            relative_position_index = \
+                torch.zeros(size=(window_size[0] * window_size[1] + 1, ) * 2, dtype=relative_coords.dtype)
+            relative_position_index[1:, 1:] = relative_coords.sum(-1)  # Wh*Ww, Wh*Ww
+            relative_position_index[0, 0:] = self.num_relative_distance - 3
+            relative_position_index[0:, 0] = self.num_relative_distance - 2
+            relative_position_index[0, 0] = self.num_relative_distance - 1
+
+            self.register_buffer("relative_position_index", relative_position_index)
+        else:
+            self.window_size = None
+            self.relative_position_bias_table = None
+            self.relative_position_index = None
+
+        self.attn_drop = nn.Dropout(attn_drop)
+        self.proj = nn.Linear(all_head_dim, dim)
+        self.proj_drop = nn.Dropout(proj_drop)
+
+    def forward(self, x, rel_pos_bias=None):
+        B, N, C = x.shape
+        qkv_bias = None
+        if self.q_bias is not None:
+            qkv_bias = torch.cat((self.q_bias, torch.zeros_like(self.v_bias, requires_grad=False), self.v_bias))
+        # qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
+        qkv = F.linear(input=x, weight=self.qkv.weight, bias=qkv_bias)
+        qkv = qkv.reshape(B, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
+        q, k, v = qkv[0], qkv[1], qkv[2]  # make torchscript happy (cannot use tensor as tuple)
+
+        q = q * self.scale
+        attn = (q @ k.transpose(-2, -1))
+
+        if self.relative_position_bias_table is not None:
+            relative_position_bias = \
+                self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
+                    self.window_size[0] * self.window_size[1] + 1,
+                    self.window_size[0] * self.window_size[1] + 1, -1)  # Wh*Ww,Wh*Ww,nH
+            relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()  # nH, Wh*Ww, Wh*Ww
+            attn = attn + relative_position_bias.unsqueeze(0)
+
+        if rel_pos_bias is not None:
+            attn = attn + rel_pos_bias
+
+        attn = attn.softmax(dim=-1)
+        attn = self.attn_drop(attn)
+
+        x = (attn @ v).transpose(1, 2).reshape(B, N, -1)
+        x = self.proj(x)
+        x = self.proj_drop(x)
+        return x
+
+
+class BlockEvaVit(nn.Module):
+    def __init__(self,
+                 dim,
+                 num_heads,
+                 mlp_ratio=4.,
+                 qkv_bias=False,
+                 qk_scale=None,
+                 drop=0.,
+                 attn_drop=0.,
+                 drop_path=0.,
+                 init_values=None,
+                 act_layer=nn.GELU,
+                 norm_layer=nn.LayerNorm,
+                 window_size=None,
+                 attn_head_dim=None):
+        super().__init__()
+        self.norm1 = norm_layer(dim)
+        self.attn = AttentionEvaVit(dim,
+                                    num_heads=num_heads,
+                                    qkv_bias=qkv_bias,
+                                    qk_scale=qk_scale,
+                                    attn_drop=attn_drop,
+                                    proj_drop=drop,
+                                    window_size=window_size,
+                                    attn_head_dim=attn_head_dim)
+        # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
+        self.drop_path = DropPathEvaVit(drop_path) if drop_path > 0. else nn.Identity()
+        self.norm2 = norm_layer(dim)
+        mlp_hidden_dim = int(dim * mlp_ratio)
+        self.mlp = MlpEvaVit(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
+
+        if init_values is not None and init_values > 0:
+            self.gamma_1 = nn.Parameter(init_values * torch.ones((dim)), requires_grad=True)
+            self.gamma_2 = nn.Parameter(init_values * torch.ones((dim)), requires_grad=True)
+        else:
+            self.gamma_1, self.gamma_2 = None, None
+
+    def forward(self, x, rel_pos_bias=None):
+        if self.gamma_1 is None:
+            x = x + self.drop_path(self.attn(self.norm1(x), rel_pos_bias=rel_pos_bias))
+            x = x + self.drop_path(self.mlp(self.norm2(x)))
+        else:
+            x = x + self.drop_path(self.gamma_1 * self.attn(self.norm1(x), rel_pos_bias=rel_pos_bias))
+            x = x + self.drop_path(self.gamma_2 * self.mlp(self.norm2(x)))
+        return x
+
+
+class PatchEmbedEvaVit(nn.Module):
+    """ Image to Patch Embedding
+    """
+    def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768):
+        super().__init__()
+        img_size = to_2tuple(img_size)
+        patch_size = to_2tuple(patch_size)
+        num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0])
+        self.patch_shape = (img_size[0] // patch_size[0], img_size[1] // patch_size[1])
+        self.img_size = img_size
+        self.patch_size = patch_size
+        self.num_patches = num_patches
+
+        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
+
+    def forward(self, x, **kwargs):
+        B, C, H, W = x.shape
+        # FIXME look at relaxing size constraints
+        assert H == self.img_size[0] and W == self.img_size[1], \
+            f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
+        x = self.proj(x).flatten(2).transpose(1, 2)
+        return x
+
+
+class RelativePositionBiasEvaVit(nn.Module):
+    def __init__(self, window_size, num_heads):
+        super().__init__()
+        self.window_size = window_size
+        self.num_relative_distance = (2 * window_size[0] - 1) * (2 * window_size[1] - 1) + 3
+        self.relative_position_bias_table = nn.Parameter(torch.zeros(self.num_relative_distance,
+                                                                     num_heads))  # 2*Wh-1 * 2*Ww-1, nH
+        # cls to token & token to cls & cls to cls
+
+        # get pair-wise relative position index for each token inside the window
+        coords_h = torch.arange(window_size[0])
+        coords_w = torch.arange(window_size[1])
+        coords = torch.stack(torch.meshgrid([coords_h, coords_w]))  # 2, Wh, Ww
+        coords_flatten = torch.flatten(coords, 1)  # 2, Wh*Ww
+        relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]  # 2, Wh*Ww, Wh*Ww
+        relative_coords = relative_coords.permute(1, 2, 0).contiguous()  # Wh*Ww, Wh*Ww, 2
+        relative_coords[:, :, 0] += window_size[0] - 1  # shift to start from 0
+        relative_coords[:, :, 1] += window_size[1] - 1
+        relative_coords[:, :, 0] *= 2 * window_size[1] - 1
+        relative_position_index = \
+            torch.zeros(size=(window_size[0] * window_size[1] + 1,) * 2, dtype=relative_coords.dtype)
+        relative_position_index[1:, 1:] = relative_coords.sum(-1)  # Wh*Ww, Wh*Ww
+        relative_position_index[0, 0:] = self.num_relative_distance - 3
+        relative_position_index[0:, 0] = self.num_relative_distance - 2
+        relative_position_index[0, 0] = self.num_relative_distance - 1
+
+        self.register_buffer("relative_position_index", relative_position_index)
+
+        # trunc_normal_(self.relative_position_bias_table, std=.02)
+
+    def forward(self):
+        relative_position_bias = \
+            self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
+                self.window_size[0] * self.window_size[1] + 1,
+                self.window_size[0] * self.window_size[1] + 1, -1)  # Wh*Ww,Wh*Ww,nH
+        return relative_position_bias.permute(2, 0, 1).contiguous()  # nH, Wh*Ww, Wh*Ww
+
+
+class VisionTransformerEvaVit(nn.Module):
+    """ Vision Transformer with support for patch or hybrid CNN input stage
+    """
+    def __init__(self,
+                 img_size=224,
+                 patch_size=16,
+                 in_chans=3,
+                 num_classes=1000,
+                 embed_dim=768,
+                 depth=12,
+                 num_heads=12,
+                 mlp_ratio=4.,
+                 qkv_bias=False,
+                 qk_scale=None,
+                 drop_rate=0.,
+                 attn_drop_rate=0.,
+                 drop_path_rate=0.,
+                 norm_layer=nn.LayerNorm,
+                 init_values=None,
+                 use_abs_pos_emb=True,
+                 use_rel_pos_bias=False,
+                 use_shared_rel_pos_bias=False,
+                 use_mean_pooling=True,
+                 init_scale=0.001,
+                 use_checkpoint=False):
+        super().__init__()
+        self.image_size = img_size
+        self.num_classes = num_classes
+        self.num_features = self.embed_dim = embed_dim  # num_features for consistency with other models
+
+        self.patch_embed = PatchEmbedEvaVit(img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
+        num_patches = self.patch_embed.num_patches
+
+        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
+        if use_abs_pos_emb:
+            self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
+        else:
+            self.pos_embed = None
+        self.pos_drop = nn.Dropout(p=drop_rate)
+
+        if use_shared_rel_pos_bias:
+            self.rel_pos_bias = RelativePositionBiasEvaVit(window_size=self.patch_embed.patch_shape, num_heads=num_heads)
+        else:
+            self.rel_pos_bias = None
+        self.use_checkpoint = use_checkpoint
+
+        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]  # stochastic depth decay rule
+        self.use_rel_pos_bias = use_rel_pos_bias
+        self.blocks = nn.ModuleList([
+            BlockEvaVit(dim=embed_dim,
+                        num_heads=num_heads,
+                        mlp_ratio=mlp_ratio,
+                        qkv_bias=qkv_bias,
+                        qk_scale=qk_scale,
+                        drop=drop_rate,
+                        attn_drop=attn_drop_rate,
+                        drop_path=dpr[i],
+                        norm_layer=norm_layer,
+                        init_values=init_values,
+                        window_size=self.patch_embed.patch_shape if use_rel_pos_bias else None) for i in range(depth)
+        ])
+        # self.norm = nn.Identity() if use_mean_pooling else norm_layer(embed_dim)
+        # self.fc_norm = norm_layer(embed_dim) if use_mean_pooling else None
+        # self.head = nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity()
+
+        if self.pos_embed is not None:
+            trunc_normal_(self.pos_embed, std=.02)
+        trunc_normal_(self.cls_token, std=.02)
+        # trunc_normal_(self.mask_token, std=.02)
+        # if isinstance(self.head, nn.Linear):
+        #     trunc_normal_(self.head.weight, std=.02)
+        self.apply(self._init_weights)
+        self.fix_init_weight()
+        self.ln_vision = nn.LayerNorm(self.num_features)
+
+    def fix_init_weight(self):
+        def rescale(param, layer_id):
+            param.div_(math.sqrt(2.0 * layer_id))
+
+        for layer_id, layer in enumerate(self.blocks):
+            rescale(layer.attn.proj.weight.data, layer_id + 1)
+            rescale(layer.mlp.fc2.weight.data, layer_id + 1)
+
+    def _init_weights(self, m):
+        if isinstance(m, nn.Linear):
+            trunc_normal_(m.weight, std=.02)
+            if isinstance(m, nn.Linear) and m.bias is not None:
+                nn.init.constant_(m.bias, 0)
+        elif isinstance(m, nn.LayerNorm):
+            nn.init.constant_(m.bias, 0)
+            nn.init.constant_(m.weight, 1.0)
+
+    _initialize_weights = _init_weights
+
+    def get_classifier(self):
+        return self.head
+
+    def reset_classifier(self, num_classes, global_pool=''):
+        self.num_classes = num_classes
+        self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
+
+    def forward_features(self, x):
+        x = self.patch_embed(x)
+        batch_size, seq_len, _ = x.size()
+
+        cls_tokens = self.cls_token.expand(batch_size, -1, -1)  # stole cls_tokens impl from Phil Wang, thanks
+        x = torch.cat((cls_tokens, x), dim=1)
+        if self.pos_embed is not None:
+            x = x + self.pos_embed
+        x = self.pos_drop(x)
+
+        rel_pos_bias = self.rel_pos_bias() if self.rel_pos_bias is not None else None
+        for blk in self.blocks:
+            if self.use_checkpoint:
+                x = checkpoint.checkpoint(blk, x, rel_pos_bias)
+            else:
+                x = blk(x, rel_pos_bias)
+        return x
+
+    def forward(self, x):
+        x = self.forward_features(x)
+        # x = self.head(x)
+        return x
+
+    def get_intermediate_layers(self, x):
+        x = self.patch_embed(x)
+        batch_size, seq_len, _ = x.size()
+
+        cls_tokens = self.cls_token.expand(batch_size, -1, -1)  # stole cls_tokens impl from Phil Wang, thanks
+        x = torch.cat((cls_tokens, x), dim=1)
+        if self.pos_embed is not None:
+            x = x + self.pos_embed
+        x = self.pos_drop(x)
+
+        features = []
+        rel_pos_bias = self.rel_pos_bias() if self.rel_pos_bias is not None else None
+        for blk in self.blocks:
+            x = blk(x, rel_pos_bias)
+            features.append(x)
+
+        return features
+
+    def get_num_layer(self, var_name=""):
+        if var_name in ("cls_token", "mask_token", "pos_embed"):
+            return 0
+        elif var_name.startswith("patch_embed"):
+            return 0
+        elif var_name.startswith("rel_pos_bias"):
+            return len(self.blocks) - 1
+        elif var_name.startswith("blocks"):
+            layer_id = int(var_name.split('.')[1])
+            return layer_id + 1
+        else:
+            return len(self.blocks)
+
+
+def create_eva_vit_g(img_size=224, drop_path_rate=0.4, use_checkpoint=False, precision="fp16", cache_dir="./"):
+    model = VisionTransformerEvaVit(
+        img_size=img_size,
+        patch_size=14,
+        use_mean_pooling=False,
+        embed_dim=1408,
+        depth=39,
+        num_heads=1408 // 88,
+        mlp_ratio=4.3637,
+        qkv_bias=True,
+        drop_path_rate=drop_path_rate,
+        norm_layer=partial(nn.LayerNorm, eps=1e-6),
+        use_checkpoint=use_checkpoint,
+    )
+    cache_path = cache_dir
+    state_dict = torch.load(cache_path + "/eva_vit_g.pth", map_location="cpu")
+    interpolate_pos_embed(model, state_dict)
+
+    incompatible_keys = model.load_state_dict(state_dict, strict=False)
+    print(incompatible_keys)
+
+    return model
 
 class BertEmbeddings(nn.Module):
     """Construct the embeddings from word and position embeddings."""
 
@@ -1491,10 +1906,6 @@ class VisionTransformer(nn.Module):
 
         return x
 
-    @torch.jit.ignore()
-    def load_pretrained(self, checkpoint_path, prefix=""):
-        _load_weights(self, checkpoint_path, prefix)
-
 
 @torch.no_grad()
 def _load_weights(model: VisionTransformer, checkpoint_path: str, prefix: str = ""):
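The deleted `load_pretrained` was a one-line wrapper; the module-level `_load_weights` helper survives, so a caller needing the old behavior can shim it. A hypothetical stand-in:

```python
import torch

@torch.no_grad()
def load_pretrained_compat(model: "VisionTransformer", checkpoint_path: str, prefix: str = ""):
    # Hypothetical replacement for the removed method; defers to the kept helper.
    _load_weights(model, checkpoint_path, prefix)
```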
 
@@ -1642,19 +2053,9 @@ class Blip2Base(PreTrainedModel):
         tokenizer.add_special_tokens({"bos_token": "[DEC]"})
         return tokenizer
 
-    def maybe_autocast(self, dtype=torch.float16):
-        # if on cpu, don't use autocast
-        # if on gpu, use autocast with dtype if provided, otherwise use torch.float16
-        enable_autocast = self.device != torch.device("cpu")
-
-        if enable_autocast:
-            return torch.cuda.amp.autocast(dtype=dtype)
-        else:
-            return contextlib.nullcontext()
-
     @classmethod
     def init_Qformer(cls, encoder_config, num_query_token, vision_width, cross_attention_freq=2, cache_dir=""):
-        print ("loading")
+        #print ("loading")
         encoder_config = BertConfig.from_pretrained("bert-base-uncased")
         encoder_config.encoder_width = vision_width
         # insert cross-attention layer every other block
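With `maybe_autocast` gone, mixed precision is no longer applied implicitly (the `get_codebook_indices` hunk further down drops the wrapper too). A caller who wants the old behavior back can reproduce it explicitly; this sketch mirrors the deleted method:

```python
import contextlib
import torch

def maybe_autocast(device: torch.device, dtype=torch.float16):
    # Autocast on GPU, no-op on CPU - the contract the removed method had.
    if device != torch.device("cpu"):
        return torch.cuda.amp.autocast(dtype=dtype)
    return contextlib.nullcontext()
```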
 
@@ -1666,74 +2067,6 @@ class Blip2Base(PreTrainedModel):
         query_tokens.data.normal_(mean=0.0, std=encoder_config.initializer_range)
         return Qformer, query_tokens
 
-    def load_from_pretrained(self, url_or_filename):
-        if is_url(url_or_filename):
-            cached_file = download_cached_file(url_or_filename, check_hash=False, progress=True)
-            checkpoint = torch.load(cached_file, map_location="cpu")
-        elif os.path.isfile(url_or_filename):
-            checkpoint = torch.load(url_or_filename, map_location="cpu")
-        else:
-            raise RuntimeError("checkpoint url or path is invalid")
-
-        state_dict = checkpoint["model"]
-
-        msg = self.load_state_dict(state_dict, strict=False)
-
-        # logging.info("Missing keys {}".format(msg.missing_keys))
-        logging.info("load checkpoint from %s" % url_or_filename)
-
-        return msg
-
-    def _lemmatize(self, answers):
-        def apply(answer):
-            doc = self.lemmatizer(answer)
-
-            words = []
-            for token in doc:
-                if token.pos_ in ["NOUN", "VERB"]:
-                    words.append(token.lemma_)
-                else:
-                    words.append(token.text)
-            answer = " ".join(words)
-
-            return answer
-
-        return [apply(answer) for answer in answers]
-
-    @property
-    def lemmatizer(self):
-        if self._lemmatizer is None:
-            try:
-                import spacy
-
-                self._lemmatizer = spacy.load("en_core_web_sm")
-            except ImportError:
-                logging.error("""
-                    Please install spacy and en_core_web_sm model to apply lemmatization.
-                    python -m spacy download en_core_web_sm
-                    OR
-                    import spacy.cli
-                    spacy.cli.download("en_core_web_sm")
-                    """)
-                exit(1)
-
-        return self._lemmatizer
-
-
-def disabled_train(self, mode=True):
-    """Overwrite model.train with this function to make sure train/eval mode
-    does not change anymore."""
-    return self
-
-
-class LayerNorm(nn.LayerNorm):
-    """Subclass torch's LayerNorm to handle fp16."""
-    def forward(self, x: torch.Tensor):
-        orig_type = x.dtype
-        ret = super().forward(x.type(torch.float32))
-        return ret.type(orig_type)
-
-
 
 
 class VectorQuantizer2(nn.Module):
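The deleted `LayerNorm` subclass upcast to fp32 before normalizing, a guard for fp16 activations; a later hunk swaps it for plain `nn.LayerNorm`, and `Seed2Tokenizer.encode` now casts inputs to a matching dtype instead. For reference, the retired pattern as a standalone sketch:

```python
import torch
import torch.nn as nn

class Fp32LayerNorm(nn.LayerNorm):
    """Run LayerNorm in fp32 and cast back, protecting fp16 activations."""
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        orig_type = x.dtype
        out = super().forward(x.to(torch.float32))
        return out.to(orig_type)
```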
 
@@ -1977,43 +2310,28 @@ class Blip2QformerQuantizer(Blip2Base):
         )
         self.distill_image_proj = nn.Linear(num_query_token * 32, image_features_dim)
 
-    @classmethod
-    def load_from_pretrained(cls, config, pretrained_model_path, **kwargs):
-        img_size = kwargs.get("image_size", 224)
-        num_query_token = kwargs.get("num_query_token", 32)
-        cross_attention_freq = kwargs.get("cross_attention_freq", 2)
-
-        drop_path_rate = kwargs.get("drop_path_rate", 0)
-        use_grad_checkpoint = kwargs.get("use_grad_checkpoint", False)
-        freeze_vit = kwargs.get("freeze_vit", True)
-        cache_dir = kwargs.get("cache_dir", "./")
-
-        max_txt_len = kwargs.get("max_txt_len", 32)
-
-        model = cls(config,
-                    img_size=img_size,
-                    drop_path_rate=drop_path_rate,
-                    use_grad_checkpoint=use_grad_checkpoint,
-                    freeze_vit=freeze_vit,
-                    num_query_token=num_query_token,
-                    cross_attention_freq=cross_attention_freq,
-                    max_txt_len=max_txt_len,
-                    cache_dir=cache_dir,
-                    )
-
-        ckpt = torch.load(cache_dir+pretrained_model_path, map_location="cpu")
-        missing, unexcepted = model.load_state_dict(ckpt, strict=False)
-        #print('**** missing keys: ', missing)
-        #print('***unexpected keys:', unexcepted)
-        return model
+    def get_codebook_indices(self, visual_encoder, image):
+        with torch.no_grad():
+            image_embeds = visual_encoder.ln_vision(visual_encoder(image))
+            image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to(image.device)
+            query_tokens = self.query_tokens.expand(image_embeds.shape[0], -1, -1)
+            query_output = self.Qformer.bert(
+                query_embeds=query_tokens,
+                encoder_hidden_states=image_embeds,
+                encoder_attention_mask=image_atts,
+                return_dict=True,
+            )
 
+            query_output_down = self.encode_task_layer(query_output.last_hidden_state)
+            quant, loss_embed, embed_ind = self.quantize(query_output_down)
+            embed_ind = embed_ind.reshape(quant.shape[0], -1)
 
+            return embed_ind
 
     def get_codebook_indices(self, visual_encoder, image):
         with torch.no_grad():
-            with self.maybe_autocast():
-                image_embeds = visual_encoder.ln_vision(visual_encoder(image))
-                image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to(image.device)
+            image_embeds = visual_encoder.ln_vision(visual_encoder(image))
+            image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to(image.device)
             query_tokens = self.query_tokens.expand(image_embeds.shape[0], -1, -1)
             query_output = self.Qformer.bert(
                 query_embeds=query_tokens,
@@ -2029,7 +2347,7 @@ class Blip2QformerQuantizer(Blip2Base):
             query_output_up = self.decode_task_layer(quant)
 
         return embed_ind, query_output_up
-
+
     def get_codebook_entry(self, indices):
         with torch.no_grad():
             quant_embedding = self.quantize.get_codebook_entry(indices)
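Note the hunk leaves two `get_codebook_indices` definitions in the class; Python keeps the later one, which returns `(embed_ind, query_output_up)`. A hedged usage sketch of the quantization path, with instances assumed constructed elsewhere in the file:

```python
import torch

def quantize_image(model, visual_encoder, img: torch.Tensor):
    # model: Blip2QformerQuantizer; visual_encoder: frozen EVA ViT with ln_vision.
    # Pipeline: ViT features -> ln_vision -> Q-Former queries -> VQ codebook ids.
    ids, query_up = model.get_codebook_indices(visual_encoder, img)
    embeds = model.get_codebook_entry(ids)  # ids -> codebook embeddings
    return ids, embeds
```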
 
@@ -2058,12 +2376,10 @@ class Blip2QformerQuantizer(Blip2Base):
                              precision="fp32",
                              cache_dir="./"):
        visual_encoder = create_eva_vit_g(img_size, drop_path_rate, use_grad_checkpoint, precision, cache_dir=cache_dir)
-       visual_encoder.ln_vision = LayerNorm(visual_encoder.num_features)
+       visual_encoder.ln_vision = nn.LayerNorm(visual_encoder.num_features)
        for name, param in visual_encoder.named_parameters():
            param.requires_grad = False
        visual_encoder = visual_encoder.eval()
-       visual_encoder.train = disabled_train
-       logging.info("freeze vision encoder")
        visual_encoder.ln_vision.weight.requires_grad = False
        visual_encoder.ln_vision.bias.requires_grad = False
        return visual_encoder
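With `visual_encoder.train = disabled_train` removed, nothing stops a later `.train()` call from flipping the frozen encoder back to training mode, so callers must keep it in eval themselves. The freeze recipe the hunk retains, distilled:

```python
import torch.nn as nn

def freeze(module: nn.Module) -> nn.Module:
    # Disable gradients and pin eval mode, mirroring the hunk above.
    for p in module.parameters():
        p.requires_grad = False
    return module.eval()
```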
 
@@ -2094,12 +2410,10 @@ class Seed2Tokenizer(PreTrainedModel):
        self.register_buffer("latents",torch.randn(shape_latents, generator=None, layout=torch.strided))
 
 
-       shape_noise = torch.Size([1, 1024])
-       self.register_buffer("noise",torch.randn(shape_noise, generator=None, layout=torch.strided))
-
+
        self.model = model
        self.processor = processor
-       self.visual_encoder = VisionTransformerEvaClip(
+       self.visual_encoder = VisionTransformerEvaVit(
            img_size=image_size,
            patch_size=14,
            use_mean_pooling=False,
 
@@ -2128,10 +2442,13 @@ class Seed2Tokenizer(PreTrainedModel):
        if len(image_torch.shape) == 3:
            image_torch = image_torch.unsqueeze(0)
 
+       image_torch = image_torch.to(dtype=self.latents.dtype)
+       image_torch = image_torch.to(self.device)
        # img = image_torch.to(self.device)
        img = image_torch
        #if self.fp16:
        #    img = img.half()
+       print (img.dtype)
        with torch.no_grad():
            id, _ = self.model.get_codebook_indices(visual_encoder, img)
        return id.view(img.shape[0], -1)
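The two added casts align the input with the dtype of the registered `latents` buffer, which tracks whatever precision the module was moved to; `print (img.dtype)` reads like leftover debugging. The buffer-dtype pattern in isolation:

```python
import torch
import torch.nn as nn

class AlignsInput(nn.Module):
    def __init__(self):
        super().__init__()
        # Buffers follow .half()/.to() calls, so their dtype mirrors the module's.
        self.register_buffer("latents", torch.randn(1, 4))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x.to(dtype=self.latents.dtype)

m = AlignsInput().half()
assert m(torch.randn(2, 3)).dtype == torch.float16
```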
 
@@ -2146,7 +2463,6 @@ class Seed2Tokenizer(PreTrainedModel):
 
        image = diffusion_model(
            image_embeds=image_embeds,
-           negative_image_embeds=negative_image_embeds,
            guidance_scale=guidance_scale,
            noise_level=noise_level,
            num_inference_steps=num_inference_steps,
 
@@ -2178,5 +2494,6 @@ class Seed2Tokenizer(PreTrainedModel):
        image_torch = self.processor(image_pil)
 
        image_torch = image_torch.to(self.device)
+       image_torch = image_torch.to(dtype=self.latents.dtype)
        return self.encode(image_torch, visual_encoder)
 
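Putting the hunks together, a heavily hedged end-to-end sketch; the construction of `Seed2Tokenizer` is not shown in this diff, so every attribute used here is assumed from the fragments above:

```python
import torch
from PIL import Image
from diffusers import DiffusionPipeline

def image_to_seed_ids(tokenizer, visual_encoder, path: str) -> torch.Tensor:
    # Mirrors the final hunk: preprocess, move to device, match buffer dtype, encode.
    image_torch = tokenizer.processor(Image.open(path).convert("RGB"))
    image_torch = image_torch.to(tokenizer.device).to(dtype=tokenizer.latents.dtype)
    return tokenizer.encode(image_torch, visual_encoder)

def seed_ids_to_image(tokenizer, ids: torch.Tensor, diffusion_model=None):
    # Assumed decode path: ids -> codebook embeds -> unCLIP variation pipeline.
    image_embeds = tokenizer.model.get_codebook_entry(ids)
    if diffusion_model is None:
        diffusion_model = DiffusionPipeline.from_pretrained(DIFFUSION_NAME)
    return diffusion_model(image_embeds=image_embeds, guidance_scale=10,
                           noise_level=0, num_inference_steps=20).images[0]
```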