huu-ontocord committed (verified)
Commit c56cca2
1 Parent(s): 268f5a3

Create seed2_tokenizer.py

Files changed (1)
  1. seed2_tokenizer.py +2190 -0
seed2_tokenizer.py ADDED
@@ -0,0 +1,2190 @@
+ """
+ Copyright (c) 2022, salesforce.com, inc.
+ All rights reserved.
+ SPDX-License-Identifier: BSD-3-Clause
+ For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
+ 
+ Based on timm code base
+ https://github.com/rwightman/pytorch-image-models/tree/master/timm
+ """
+ 
+ """
+ Copyright (c) 2023, salesforce.com, inc.
+ All rights reserved.
+ SPDX-License-Identifier: BSD-3-Clause
+ For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
+ """
+ 
+ import torch.nn as nn
+ import torch
+ # import math
+ # from torchvision import transforms
+ import os
+ # from timm.models import create_model
+ from typing import Any, Dict, List, Optional, Union
+ from transformers import LlamaTokenizer
+ from diffusers import DiffusionPipeline
+ # from torchvision.transforms.functional import pil_to_tensor
+ 
+ # import torch
+ from PIL import Image
+ from torchvision import transforms
+ 
+ WEIGHTS_NAME = 'seed_quantizer.pt'
+ DIFFUSION_NAME = 'stabilityai/stable-diffusion-2-1-unclip'
+ 
+ # from qformer.qformer_quantizer import Blip2QformerQuantizer
+ # from diffusers import StableUnCLIPImg2ImgPipeline
+ 
+ from pipeline_stable_unclip_img2img import StableUnCLIPImg2ImgPipeline
+ 
+ import logging
+ 
+ import torch
+ import torch.distributed as dist
+ import torch.nn as nn
+ from torch.cuda.amp import autocast as autocast
+ from torch.nn import functional as F
+ import numpy as np
+ from functools import partial
+ from einops import rearrange
+ 
+ import contextlib
+ import logging
+ import os
+ import time
+ import datetime
+ 
+ import torch
+ import torch.nn as nn
+ import torch.distributed as dist
+ import torch.nn.functional as F
+ 
+ 
+ from eva_vit import create_eva_vit_g, VisionTransformerEvaClip
+ from transformers import BertTokenizer
+ 
+ import math
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ from functools import partial
+ 
+ from timm.models.vision_transformer import _cfg, PatchEmbed
+ from timm.models.registry import register_model
+ from timm.models.layers import trunc_normal_, DropPath
+ from timm.models.helpers import named_apply, adapt_input_conv
+ 
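# Note: WEIGHTS_NAME points at the SEED quantizer checkpoint and DIFFUSION_NAME at the
# unCLIP decoder used to turn image tokens back into pixels. A minimal, illustrative
# loading sketch (not part of the committed code; it assumes the checkpoint exists in a
# local directory, written here as the placeholder "<ckpt_dir>", and that the bundled
# pipeline_stable_unclip_img2img module mirrors the diffusers API):
#
#   pipe = StableUnCLIPImg2ImgPipeline.from_pretrained(DIFFUSION_NAME, torch_dtype=torch.float16)
#   pipe = pipe.to("cuda")
#   quantizer_state = torch.load(os.path.join("<ckpt_dir>", WEIGHTS_NAME), map_location="cpu")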
+ """
+ * Copyright (c) 2023, salesforce.com, inc.
+ * All rights reserved.
+ * SPDX-License-Identifier: BSD-3-Clause
+ * For full license text, see LICENSE.txt file in the repo root or https://opensource.org/licenses/BSD-3-Clause
+ * By Junnan Li
+ * Based on huggingface code base
+ * https://github.com/huggingface/transformers/blob/v4.15.0/src/transformers/models/bert
+ """
+ 
+ import math
+ import os
+ import warnings
+ from dataclasses import dataclass
+ from typing import Optional, Tuple, Dict, Any
+ 
+ import torch
+ from torch import Tensor, device, dtype, nn
+ import torch.utils.checkpoint
+ from torch.nn import CrossEntropyLoss
+ import torch.nn.functional as F
+ import numpy as np
+ 
+ 
+ from transformers.activations import ACT2FN
+ from transformers.file_utils import (
+     ModelOutput, )
+ from transformers.modeling_outputs import (
+     BaseModelOutputWithPastAndCrossAttentions,
+     BaseModelOutputWithPoolingAndCrossAttentions,
+     CausalLMOutputWithCrossAttentions,
+     MaskedLMOutput,
+     MultipleChoiceModelOutput,
+     NextSentencePredictorOutput,
+     QuestionAnsweringModelOutput,
+     SequenceClassifierOutput,
+     TokenClassifierOutput,
+ )
+ from transformers.modeling_utils import (
+     PreTrainedModel,
+     apply_chunking_to_forward,
+     find_pruneable_heads_and_indices,
+     prune_linear_layer,
+ )
+ from transformers.models.bert.configuration_bert import BertConfig
+ 
+ #torch.set_printoptions(profile="full")
+ 
+ 
+ class BertEmbeddings(nn.Module):
+     """Construct the embeddings from word and position embeddings."""
+     def __init__(self, config):
+         super().__init__()
+         self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
+         self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)
+ 
+         # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load
+         # any TensorFlow checkpoint file
+         self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+         self.dropout = nn.Dropout(config.hidden_dropout_prob)
+ 
+         # position_ids (1, len position emb) is contiguous in memory and exported when serialized
+         self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))
+         self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
+ 
+         self.config = config
+ 
+     def forward(
+         self,
+         input_ids=None,
+         position_ids=None,
+         query_embeds=None,
+         past_key_values_length=0,
+     ):
+         if input_ids is not None:
+             seq_length = input_ids.size()[1]
+         else:
+             seq_length = 0
+ 
+         if position_ids is None:
+             position_ids = self.position_ids[:, past_key_values_length:seq_length + past_key_values_length].clone()
+ 
+         if input_ids is not None:
+             embeddings = self.word_embeddings(input_ids)
+             if self.position_embedding_type == "absolute":
+                 position_embeddings = self.position_embeddings(position_ids)
+                 embeddings = embeddings + position_embeddings
+ 
+             if query_embeds is not None:
+                 embeddings = torch.cat((query_embeds, embeddings), dim=1)
+                 #print(query_embeds.shape, embeddings.shape)
+         else:
+             embeddings = query_embeds
+ 
+         embeddings = self.LayerNorm(embeddings)
+         embeddings = self.dropout(embeddings)
+         return embeddings
+ 
+ 
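# Illustrative shape walk-through for BertEmbeddings.forward (comment only; the numbers
# are example values, not fixed by the code): with 32 learned query embeddings of hidden
# size 768 and a tokenized caption of length 10,
#
#   query_embeds: (batch, 32, 768)
#   word + position embeddings of input_ids: (batch, 10, 768)
#   torch.cat((query_embeds, embeddings), dim=1) -> (batch, 42, 768)
#
# so the query tokens always occupy the first `query_length` positions of the sequence
# the encoder sees; when input_ids is None the sequence is the queries alone.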
184
+ class BertSelfAttention(nn.Module):
185
+ def __init__(self, config, is_cross_attention):
186
+ super().__init__()
187
+ self.config = config
188
+ if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
189
+ raise ValueError("The hidden size (%d) is not a multiple of the number of attention "
190
+ "heads (%d)" % (config.hidden_size, config.num_attention_heads))
191
+
192
+ self.num_attention_heads = config.num_attention_heads
193
+ self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
194
+ self.all_head_size = self.num_attention_heads * self.attention_head_size
195
+
196
+ self.query = nn.Linear(config.hidden_size, self.all_head_size)
197
+ if is_cross_attention:
198
+ self.key = nn.Linear(config.encoder_width, self.all_head_size)
199
+ self.value = nn.Linear(config.encoder_width, self.all_head_size)
200
+ else:
201
+ self.key = nn.Linear(config.hidden_size, self.all_head_size)
202
+ self.value = nn.Linear(config.hidden_size, self.all_head_size)
203
+
204
+ self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
205
+ self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
206
+ if (self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query"):
207
+ self.max_position_embeddings = config.max_position_embeddings
208
+ self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size)
209
+ self.save_attention = False
210
+
211
+ def save_attn_gradients(self, attn_gradients):
212
+ self.attn_gradients = attn_gradients
213
+
214
+ def get_attn_gradients(self):
215
+ return self.attn_gradients
216
+
217
+ def save_attention_map(self, attention_map):
218
+ self.attention_map = attention_map
219
+
220
+ def get_attention_map(self):
221
+ return self.attention_map
222
+
223
+ def transpose_for_scores(self, x):
224
+ new_x_shape = x.size()[:-1] + (
225
+ self.num_attention_heads,
226
+ self.attention_head_size,
227
+ )
228
+ x = x.view(*new_x_shape)
229
+ return x.permute(0, 2, 1, 3)
230
+
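# Shape note for transpose_for_scores (descriptive comment): an input of
# (batch, seq_len, all_head_size) is reshaped to
# (batch, seq_len, num_attention_heads, attention_head_size) and permuted to
# (batch, num_attention_heads, seq_len, attention_head_size), so the per-head attention
# in forward() below can be computed as a single batched matmul over all heads.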
231
+ def forward(
232
+ self,
233
+ hidden_states,
234
+ attention_mask=None,
235
+ head_mask=None,
236
+ encoder_hidden_states=None,
237
+ encoder_attention_mask=None,
238
+ past_key_value=None,
239
+ output_attentions=False,
240
+ ):
241
+
242
+ # If this is instantiated as a cross-attention module, the keys
243
+ # and values come from an encoder; the attention mask needs to be
244
+ # such that the encoder's padding tokens are not attended to.
245
+ is_cross_attention = encoder_hidden_states is not None
246
+
247
+ if is_cross_attention:
248
+ key_layer = self.transpose_for_scores(self.key(encoder_hidden_states))
249
+ value_layer = self.transpose_for_scores(self.value(encoder_hidden_states))
250
+ #print(key_layer.shape, value_layer.shape)
251
+ attention_mask = encoder_attention_mask
252
+ elif past_key_value is not None:
253
+ key_layer = self.transpose_for_scores(self.key(hidden_states))
254
+ value_layer = self.transpose_for_scores(self.value(hidden_states))
255
+ key_layer = torch.cat([past_key_value[0], key_layer], dim=2)
256
+ value_layer = torch.cat([past_key_value[1], value_layer], dim=2)
257
+ #print(past_key_value[0].shape, key_layer.shape)
258
+ else:
259
+ key_layer = self.transpose_for_scores(self.key(hidden_states))
260
+ value_layer = self.transpose_for_scores(self.value(hidden_states))
261
+
262
+ mixed_query_layer = self.query(hidden_states)
263
+
264
+ query_layer = self.transpose_for_scores(mixed_query_layer)
265
+ # if past_key_value is not None:
266
+ # print(query_layer.shape)
267
+
268
+ past_key_value = (key_layer, value_layer)
269
+ #print(key_layer.shape, value_layer.shape)
270
+
271
+ # Take the dot product between "query" and "key" to get the raw attention scores.
272
+ attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
273
+ #if is_cross_attention:
274
+ # if attention_scores.shape[2] == 32:
275
+ # attention_scores_save = attention_scores[0].detach().cpu().numpy()
276
+ # print(attention_scores_save.shape)
277
+ # np.save('attention_scores_causal_text_child.npy', attention_scores_save)
278
+
279
+ if (self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query"):
280
+ seq_length = hidden_states.size()[1]
281
+ position_ids_l = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(-1, 1)
282
+ position_ids_r = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(1, -1)
283
+ distance = position_ids_l - position_ids_r
284
+ positional_embedding = self.distance_embedding(distance + self.max_position_embeddings - 1)
285
+ positional_embedding = positional_embedding.to(dtype=query_layer.dtype) # fp16 compatibility
286
+
287
+ if self.position_embedding_type == "relative_key":
288
+ relative_position_scores = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
289
+ attention_scores = attention_scores + relative_position_scores
290
+ elif self.position_embedding_type == "relative_key_query":
291
+ relative_position_scores_query = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
292
+ relative_position_scores_key = torch.einsum("bhrd,lrd->bhlr", key_layer, positional_embedding)
293
+ attention_scores = (attention_scores + relative_position_scores_query + relative_position_scores_key)
294
+
295
+ attention_scores = attention_scores / math.sqrt(self.attention_head_size)
296
+ if attention_mask is not None:
297
+ # Apply the attention mask is (precomputed for all layers in BertModel forward() function)
298
+ attention_scores = attention_scores + attention_mask
299
+
300
+ # Normalize the attention scores to probabilities.
301
+ attention_probs = nn.Softmax(dim=-1)(attention_scores)
302
+
303
+ if is_cross_attention and self.save_attention:
304
+ self.save_attention_map(attention_probs)
305
+ attention_probs.register_hook(self.save_attn_gradients)
306
+
307
+ # This is actually dropping out entire tokens to attend to, which might
308
+ # seem a bit unusual, but is taken from the original Transformer paper.
309
+ attention_probs_dropped = self.dropout(attention_probs)
310
+
311
+ # Mask heads if we want to
312
+ if head_mask is not None:
313
+ attention_probs_dropped = attention_probs_dropped * head_mask
314
+
315
+ context_layer = torch.matmul(attention_probs_dropped, value_layer)
316
+
317
+ context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
318
+ new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size, )
319
+ context_layer = context_layer.view(*new_context_layer_shape)
320
+
321
+ outputs = ((context_layer, attention_probs) if output_attentions else (context_layer, ))
322
+
323
+ outputs = outputs + (past_key_value, )
324
+ return outputs
325
+
326
+
327
+ class BertSelfOutput(nn.Module):
328
+ def __init__(self, config):
329
+ super().__init__()
330
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size)
331
+ self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
332
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
333
+
334
+ def forward(self, hidden_states, input_tensor):
335
+ hidden_states = self.dense(hidden_states)
336
+ hidden_states = self.dropout(hidden_states)
337
+ hidden_states = self.LayerNorm(hidden_states + input_tensor)
338
+ return hidden_states
339
+
340
+
341
+ class BertAttention(nn.Module):
342
+ def __init__(self, config, is_cross_attention=False):
343
+ super().__init__()
344
+ self.self = BertSelfAttention(config, is_cross_attention)
345
+ self.output = BertSelfOutput(config)
346
+ self.pruned_heads = set()
347
+
348
+ def prune_heads(self, heads):
349
+ if len(heads) == 0:
350
+ return
351
+ heads, index = find_pruneable_heads_and_indices(
352
+ heads,
353
+ self.self.num_attention_heads,
354
+ self.self.attention_head_size,
355
+ self.pruned_heads,
356
+ )
357
+
358
+ # Prune linear layers
359
+ self.self.query = prune_linear_layer(self.self.query, index)
360
+ self.self.key = prune_linear_layer(self.self.key, index)
361
+ self.self.value = prune_linear_layer(self.self.value, index)
362
+ self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)
363
+
364
+ # Update hyper params and store pruned heads
365
+ self.self.num_attention_heads = self.self.num_attention_heads - len(heads)
366
+ self.self.all_head_size = (self.self.attention_head_size * self.self.num_attention_heads)
367
+ self.pruned_heads = self.pruned_heads.union(heads)
368
+
369
+ def forward(
370
+ self,
371
+ hidden_states,
372
+ attention_mask=None,
373
+ head_mask=None,
374
+ encoder_hidden_states=None,
375
+ encoder_attention_mask=None,
376
+ past_key_value=None,
377
+ output_attentions=False,
378
+ ):
379
+ self_outputs = self.self(
380
+ hidden_states,
381
+ attention_mask,
382
+ head_mask,
383
+ encoder_hidden_states,
384
+ encoder_attention_mask,
385
+ past_key_value,
386
+ output_attentions,
387
+ )
388
+ attention_output = self.output(self_outputs[0], hidden_states)
389
+
390
+ outputs = (attention_output, ) + self_outputs[1:] # add attentions if we output them
391
+ return outputs
392
+
393
+
394
+ class BertIntermediate(nn.Module):
395
+ def __init__(self, config):
396
+ super().__init__()
397
+ self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
398
+ if isinstance(config.hidden_act, str):
399
+ self.intermediate_act_fn = ACT2FN[config.hidden_act]
400
+ else:
401
+ self.intermediate_act_fn = config.hidden_act
402
+
403
+ def forward(self, hidden_states):
404
+ hidden_states = self.dense(hidden_states)
405
+ hidden_states = self.intermediate_act_fn(hidden_states)
406
+ return hidden_states
407
+
408
+
409
+ class BertOutput(nn.Module):
410
+ def __init__(self, config):
411
+ super().__init__()
412
+ self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
413
+ self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
414
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
415
+
416
+ def forward(self, hidden_states, input_tensor):
417
+ hidden_states = self.dense(hidden_states)
418
+ hidden_states = self.dropout(hidden_states)
419
+ hidden_states = self.LayerNorm(hidden_states + input_tensor)
420
+ return hidden_states
421
+
422
+
423
+ class BertLayer(nn.Module):
424
+ def __init__(self, config, layer_num):
425
+ super().__init__()
426
+ self.config = config
427
+ self.chunk_size_feed_forward = config.chunk_size_feed_forward
428
+ self.seq_len_dim = 1
429
+ self.attention = BertAttention(config)
430
+ self.layer_num = layer_num
431
+ if (self.config.add_cross_attention and layer_num % self.config.cross_attention_freq == 0):
432
+ self.crossattention = BertAttention(config, is_cross_attention=self.config.add_cross_attention)
433
+ self.has_cross_attention = True
434
+ else:
435
+ self.has_cross_attention = False
436
+ self.intermediate = BertIntermediate(config)
437
+ self.output = BertOutput(config)
438
+
439
+ self.intermediate_query = BertIntermediate(config)
440
+ self.output_query = BertOutput(config)
441
+
442
+ def forward(
443
+ self,
444
+ hidden_states,
445
+ attention_mask=None,
446
+ head_mask=None,
447
+ encoder_hidden_states=None,
448
+ encoder_attention_mask=None,
449
+ past_key_value=None,
450
+ output_attentions=False,
451
+ query_length=0,
452
+ ):
453
+ # decoder uni-directional self-attention cached key/values tuple is at positions 1,2
454
+ self_attn_past_key_value = (past_key_value[:2] if past_key_value is not None else None)
455
+ # if past_key_value is not None:
456
+ # print(hidden_states.shape, attention_mask.shape)
457
+ #print(hidden_states.shape, attention_mask.shape)
458
+ # causal attention for query embeds with self attention
459
+ self_attention_outputs = self.attention(
460
+ hidden_states,
461
+ attention_mask,
462
+ head_mask,
463
+ output_attentions=output_attentions,
464
+ past_key_value=self_attn_past_key_value,
465
+ )
466
+ #print('attention_mask', attention_mask.shape)
467
+ # if attention_mask.shape[-1] == 77:
468
+ # print('attention_mask', attention_mask[0])
469
+ attention_output = self_attention_outputs[0]
470
+ outputs = self_attention_outputs[1:-1]
471
+
472
+ present_key_value = self_attention_outputs[-1]
473
+ #print(present_key_value[0].shape)
474
+
475
+ if query_length > 0:
476
+ query_attention_output = attention_output[:, :query_length, :]
477
+
478
+ if self.has_cross_attention:
479
+ assert (encoder_hidden_states is not None), "encoder_hidden_states must be given for cross-attention layers"
480
+ #print(attention_mask.shape)
481
+ cross_attention_outputs = self.crossattention(
482
+ query_attention_output,
483
+ attention_mask,
484
+ head_mask,
485
+ encoder_hidden_states,
486
+ encoder_attention_mask,
487
+ output_attentions=output_attentions,
488
+ )
489
+ query_attention_output = cross_attention_outputs[0]
490
+ outputs = (outputs + cross_attention_outputs[1:-1]) # add cross attentions if we output attention weights
491
+
492
+ layer_output = apply_chunking_to_forward(
493
+ self.feed_forward_chunk_query,
494
+ self.chunk_size_feed_forward,
495
+ self.seq_len_dim,
496
+ query_attention_output,
497
+ )
498
+ if attention_output.shape[1] > query_length:
499
+ layer_output_text = apply_chunking_to_forward(
500
+ self.feed_forward_chunk,
501
+ self.chunk_size_feed_forward,
502
+ self.seq_len_dim,
503
+ attention_output[:, query_length:, :],
504
+ )
505
+ layer_output = torch.cat([layer_output, layer_output_text], dim=1)
506
+ else:
507
+ layer_output = apply_chunking_to_forward(
508
+ self.feed_forward_chunk,
509
+ self.chunk_size_feed_forward,
510
+ self.seq_len_dim,
511
+ attention_output,
512
+ )
513
+ outputs = (layer_output, ) + outputs
514
+
515
+ outputs = outputs + (present_key_value, )
516
+
517
+ return outputs
518
+
519
+ def feed_forward_chunk(self, attention_output):
520
+ intermediate_output = self.intermediate(attention_output)
521
+ layer_output = self.output(intermediate_output, attention_output)
522
+ return layer_output
523
+
524
+ def feed_forward_chunk_query(self, attention_output):
525
+ intermediate_output = self.intermediate_query(attention_output)
526
+ layer_output = self.output_query(intermediate_output, attention_output)
527
+ return layer_output
528
+
529
+
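# Reading guide for BertLayer.forward (descriptive comment, not original code): when
# query_length > 0, the first `query_length` positions are the learned queries. Only that
# slice is passed through cross-attention (a crossattention block exists every
# config.cross_attention_freq layers) and through the query-specific FFN
# (intermediate_query / output_query); any remaining text positions go through the
# standard FFN (intermediate / output), and the two halves are concatenated back along
# dim=1 before being returned.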
530
+ class BertEncoder(nn.Module):
531
+ def __init__(self, config):
532
+ super().__init__()
533
+ self.config = config
534
+ self.layer = nn.ModuleList([BertLayer(config, i) for i in range(config.num_hidden_layers)])
535
+
536
+ def forward(
537
+ self,
538
+ hidden_states,
539
+ attention_mask=None,
540
+ head_mask=None,
541
+ encoder_hidden_states=None,
542
+ encoder_attention_mask=None,
543
+ past_key_values=None,
544
+ use_cache=None,
545
+ output_attentions=False,
546
+ output_hidden_states=False,
547
+ return_dict=True,
548
+ query_length=0,
549
+ ):
550
+ all_hidden_states = () if output_hidden_states else None
551
+ all_self_attentions = () if output_attentions else None
552
+ all_cross_attentions = (() if output_attentions and self.config.add_cross_attention else None)
553
+
554
+ next_decoder_cache = () if use_cache else None
555
+
556
+ for i in range(self.config.num_hidden_layers):
557
+ layer_module = self.layer[i]
558
+ if output_hidden_states:
559
+ all_hidden_states = all_hidden_states + (hidden_states, )
560
+
561
+ layer_head_mask = head_mask[i] if head_mask is not None else None
562
+ past_key_value = past_key_values[i] if past_key_values is not None else None
563
+ # if past_key_value is not None:
564
+ # print(past_key_value[0].shape, past_key_value[1].shape)
565
+
566
+ if getattr(self.config, "gradient_checkpointing", False) and self.training:
567
+
568
+ if use_cache:
569
+ logging.warning("`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...")
570
+ use_cache = False
571
+
572
+ def create_custom_forward(module):
573
+ def custom_forward(*inputs):
574
+ return module(*inputs, past_key_value, output_attentions, query_length)
575
+
576
+ return custom_forward
577
+
578
+ layer_outputs = torch.utils.checkpoint.checkpoint(
579
+ create_custom_forward(layer_module),
580
+ hidden_states,
581
+ attention_mask,
582
+ layer_head_mask,
583
+ encoder_hidden_states,
584
+ encoder_attention_mask,
585
+ )
586
+ else:
587
+ layer_outputs = layer_module(
588
+ hidden_states,
589
+ attention_mask,
590
+ layer_head_mask,
591
+ encoder_hidden_states,
592
+ encoder_attention_mask,
593
+ past_key_value,
594
+ output_attentions,
595
+ query_length,
596
+ )
597
+ # if past_key_value is not None:
598
+ # print(hidden_states.shape, attention_mask.shape)
599
+ # print(len(past_key_value))
600
+
601
+ hidden_states = layer_outputs[0]
602
+ if use_cache:
603
+ next_decoder_cache += (layer_outputs[-1], )
604
+ #print(layer_outputs[-1][0].shape)
605
+ if output_attentions:
606
+ all_self_attentions = all_self_attentions + (layer_outputs[1], )
607
+ all_cross_attentions = all_cross_attentions + (layer_outputs[2], )
608
+
609
+ if output_hidden_states:
610
+ all_hidden_states = all_hidden_states + (hidden_states, )
611
+
612
+ if not return_dict:
613
+ return tuple(v for v in [
614
+ hidden_states,
615
+ next_decoder_cache,
616
+ all_hidden_states,
617
+ all_self_attentions,
618
+ all_cross_attentions,
619
+ ] if v is not None)
620
+ return BaseModelOutputWithPastAndCrossAttentions(
621
+ last_hidden_state=hidden_states,
622
+ past_key_values=next_decoder_cache,
623
+ hidden_states=all_hidden_states,
624
+ attentions=all_self_attentions,
625
+ cross_attentions=all_cross_attentions,
626
+ )
627
+
628
+
629
+ class BertPooler(nn.Module):
630
+ def __init__(self, config):
631
+ super().__init__()
632
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size)
633
+ self.activation = nn.Tanh()
634
+
635
+ def forward(self, hidden_states):
636
+ # We "pool" the model by simply taking the hidden state corresponding
637
+ # to the first token.
638
+ first_token_tensor = hidden_states[:, 0]
639
+ pooled_output = self.dense(first_token_tensor)
640
+ pooled_output = self.activation(pooled_output)
641
+ return pooled_output
642
+
643
+
644
+ class BertPredictionHeadTransform(nn.Module):
645
+ def __init__(self, config):
646
+ super().__init__()
647
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size)
648
+ if isinstance(config.hidden_act, str):
649
+ self.transform_act_fn = ACT2FN[config.hidden_act]
650
+ else:
651
+ self.transform_act_fn = config.hidden_act
652
+ self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
653
+
654
+ def forward(self, hidden_states):
655
+ hidden_states = self.dense(hidden_states)
656
+ hidden_states = self.transform_act_fn(hidden_states)
657
+ hidden_states = self.LayerNorm(hidden_states)
658
+ return hidden_states
659
+
660
+
661
+ class BertLMPredictionHead(nn.Module):
662
+ def __init__(self, config):
663
+ super().__init__()
664
+ self.transform = BertPredictionHeadTransform(config)
665
+
666
+ # The output weights are the same as the input embeddings, but there is
667
+ # an output-only bias for each token.
668
+ self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
669
+
670
+ self.bias = nn.Parameter(torch.zeros(config.vocab_size))
671
+
672
+ # Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings`
673
+ self.decoder.bias = self.bias
674
+
675
+ def forward(self, hidden_states):
676
+ hidden_states = self.transform(hidden_states)
677
+ hidden_states = self.decoder(hidden_states)
678
+ return hidden_states
679
+
680
+
681
+ class BertOnlyMLMHead(nn.Module):
682
+ def __init__(self, config):
683
+ super().__init__()
684
+ self.predictions = BertLMPredictionHead(config)
685
+
686
+ def forward(self, sequence_output):
687
+ prediction_scores = self.predictions(sequence_output)
688
+ return prediction_scores
689
+
690
+
691
+ class BertPreTrainedModel(PreTrainedModel):
692
+ """
693
+ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
694
+ models.
695
+ """
696
+
697
+ config_class = BertConfig
698
+ base_model_prefix = "bert"
699
+ _keys_to_ignore_on_load_missing = [r"position_ids"]
700
+
701
+ def _init_weights(self, module):
702
+ """Initialize the weights"""
703
+ if isinstance(module, (nn.Linear, nn.Embedding)):
704
+ # Slightly different from the TF version which uses truncated_normal for initialization
705
+ # cf https://github.com/pytorch/pytorch/pull/5617
706
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
707
+ elif isinstance(module, nn.LayerNorm):
708
+ module.bias.data.zero_()
709
+ module.weight.data.fill_(1.0)
710
+ if isinstance(module, nn.Linear) and module.bias is not None:
711
+ module.bias.data.zero_()
712
+
713
+
714
+ class BertModel(BertPreTrainedModel):
715
+ """
716
+ The model can behave as an encoder (with only self-attention) as well as a decoder, in which case a layer of
717
+ cross-attention is added between the self-attention layers, following the architecture described in `Attention is
718
+ all you need <https://arxiv.org/abs/1706.03762>`__ by Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit,
719
+ Llion Jones, Aidan N. Gomez, Lukasz Kaiser and Illia Polosukhin. To behave as a decoder the model needs to be
+ initialized with the :obj:`is_decoder` argument of the configuration set to :obj:`True`. To be used in a Seq2Seq
+ model, the model needs to be initialized with both the :obj:`is_decoder`
+ argument and :obj:`add_cross_attention` set to :obj:`True`; an :obj:`encoder_hidden_states` is then expected as an
721
+ input to the forward pass.
722
+ """
723
+ def __init__(self, config, add_pooling_layer=False):
724
+ super().__init__(config)
725
+ self.config = config
726
+
727
+ self.embeddings = BertEmbeddings(config)
728
+
729
+ self.encoder = BertEncoder(config)
730
+
731
+ self.pooler = BertPooler(config) if add_pooling_layer else None
732
+
733
+ self.init_weights()
734
+
735
+ def get_input_embeddings(self):
736
+ return self.embeddings.word_embeddings
737
+
738
+ def set_input_embeddings(self, value):
739
+ self.embeddings.word_embeddings = value
740
+
741
+ def _prune_heads(self, heads_to_prune):
742
+ """
743
+ Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
744
+ class PreTrainedModel
745
+ """
746
+ for layer, heads in heads_to_prune.items():
747
+ self.encoder.layer[layer].attention.prune_heads(heads)
748
+
749
+ def get_extended_attention_mask(
750
+ self,
751
+ attention_mask: Tensor,
752
+ input_shape: Tuple[int],
753
+ device: device,
754
+ is_decoder: bool,
755
+ is_casual: bool,
756
+ has_query: bool = False,
757
+ ) -> Tensor:
758
+ """
759
+ Makes broadcastable attention and causal masks so that future and masked tokens are ignored.
760
+
761
+ Arguments:
762
+ attention_mask (:obj:`torch.Tensor`):
763
+ Mask with ones indicating tokens to attend to, zeros for tokens to ignore.
764
+ input_shape (:obj:`Tuple[int]`):
765
+ The shape of the input to the model.
766
+ device: (:obj:`torch.device`):
767
+ The device of the input to the model.
768
+
769
+ Returns:
770
+ :obj:`torch.Tensor` The extended attention mask, with the same dtype as :obj:`attention_mask.dtype`.
771
+ """
772
+ # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
773
+ # ourselves in which case we just need to make it broadcastable to all heads.
774
+ #print(attention_mask.dim())
775
+ if attention_mask.dim() == 3:
776
+ extended_attention_mask = attention_mask[:, None, :, :]
777
+ elif attention_mask.dim() == 2:
778
+ # Provided a padding mask of dimensions [batch_size, seq_length]
779
+ # - if the model is a decoder, apply a causal mask in addition to the padding mask
780
+ # - if the model is an encoder, make the mask broadcastable to [batch_size, num_heads, seq_length, seq_length]
781
+ if is_decoder or is_casual:
782
+ batch_size, seq_length = input_shape
783
+ #print(input_shape)
784
+ if not is_decoder and seq_length > 32:
785
+ query_length = 32
786
+ text_length = seq_length - query_length
787
+ query_ids = torch.arange(query_length, device=device)
788
+ query_causal_mask = (query_ids[None, None, :].repeat(batch_size, query_length, 1) <= query_ids[None, :,
789
+ None])
790
+ causal_mask = torch.ones((batch_size, seq_length, seq_length), device=device)
791
+ causal_mask[:, :query_length, :query_length] = query_causal_mask
792
+ # print(query_causal_mask.shape, causal_mask.shape)
793
+ #print(causal_mask[0])
794
+
795
+ else:
796
+ seq_ids = torch.arange(seq_length, device=device)
797
+ causal_mask = (seq_ids[None, None, :].repeat(batch_size, seq_length, 1) <= seq_ids[None, :, None])
798
+
799
+ # add a prefix ones mask to the causal mask
800
+ # causal and attention masks must have same type with pytorch version < 1.3
801
+ causal_mask = causal_mask.to(attention_mask.dtype)
802
+ # if is_decoder:
803
+ # print(causal_mask.shape, attention_mask.shape)
804
+ #print(causal_mask.shape, attention_mask.shape)
805
+
806
+ if causal_mask.shape[1] < attention_mask.shape[1]:
807
+ prefix_seq_len = attention_mask.shape[1] - causal_mask.shape[1]
808
+ if has_query: # UniLM style attention mask
809
+ causal_mask = torch.cat(
810
+ [
811
+ torch.zeros(
812
+ (batch_size, prefix_seq_len, seq_length),
813
+ device=device,
814
+ dtype=causal_mask.dtype,
815
+ ),
816
+ causal_mask,
817
+ ],
818
+ axis=1,
819
+ )
820
+ causal_mask = torch.cat(
821
+ [
822
+ torch.ones(
823
+ (batch_size, causal_mask.shape[1], prefix_seq_len),
824
+ device=device,
825
+ dtype=causal_mask.dtype,
826
+ ),
827
+ causal_mask,
828
+ ],
829
+ axis=-1,
830
+ )
831
+ #print(has_query, causal_mask.shape)
832
+ #print(causal_mask[0])
833
+ extended_attention_mask = (causal_mask[:, None, :, :] * attention_mask[:, None, None, :])
834
+ #print(extended_attention_mask[0])
835
+ #print('extended_attention_mask', extended_attention_mask.shape)
836
+ else:
837
+ extended_attention_mask = attention_mask[:, None, None, :]
838
+ #print(attention_mask.shape, extended_attention_mask.shape)
839
+ else:
840
+ raise ValueError("Wrong shape for input_ids (shape {}) or attention_mask (shape {})".format(
841
+ input_shape, attention_mask.shape))
842
+
843
+ # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
844
+ # masked positions, this operation will create a tensor which is 0.0 for
845
+ # positions we want to attend and -10000.0 for masked positions.
846
+ # Since we are adding it to the raw scores before the softmax, this is
847
+ # effectively the same as removing these entirely.
848
+ extended_attention_mask = extended_attention_mask.to(dtype=self.dtype) # fp16 compatibility
849
+ extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
850
+ return extended_attention_mask
851
+
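# Worked example for get_extended_attention_mask (illustrative, comment only): for a
# padding mask [1, 1, 0] the broadcastable mask becomes [[[[1., 1., 0.]]]] and
# (1.0 - mask) * -10000.0 yields [[[[0., 0., -10000.]]]], i.e. masked positions receive a
# large negative bias that the softmax turns into ~0 probability. In the decoder/causal
# branch, the causal (or UniLM-style prefix) mask is multiplied in before this step, so
# future positions are suppressed in exactly the same way.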
852
+ def forward(
853
+ self,
854
+ input_ids=None,
855
+ attention_mask=None,
856
+ position_ids=None,
857
+ head_mask=None,
858
+ query_embeds=None,
859
+ encoder_hidden_states=None,
860
+ encoder_attention_mask=None,
861
+ past_key_values=None,
862
+ use_cache=None,
863
+ output_attentions=None,
864
+ output_hidden_states=None,
865
+ return_dict=None,
866
+ is_decoder=False,
867
+ ):
868
+ r"""
869
+ encoder_hidden_states (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
870
+ Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if
871
+ the model is configured as a decoder.
872
+ encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
873
+ Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in
874
+ the cross-attention if the model is configured as a decoder. Mask values selected in ``[0, 1]``:
875
+ - 1 for tokens that are **not masked**,
876
+ - 0 for tokens that are **masked**.
877
+ past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
878
+ Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.
879
+ If :obj:`past_key_values` are used, the user can optionally input only the last :obj:`decoder_input_ids`
880
+ (those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)`
881
+ instead of all :obj:`decoder_input_ids` of shape :obj:`(batch_size, sequence_length)`.
882
+ use_cache (:obj:`bool`, `optional`):
883
+ If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up
884
+ decoding (see :obj:`past_key_values`).
885
+ """
886
+ output_attentions = (output_attentions if output_attentions is not None else self.config.output_attentions)
887
+ output_hidden_states = (output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states)
888
+ return_dict = (return_dict if return_dict is not None else self.config.use_return_dict)
889
+
890
+ # use_cache = use_cache if use_cache is not None else self.config.use_cache
891
+
892
+ if input_ids is None:
893
+ assert (query_embeds is not None), "You have to specify query_embeds when input_ids is None"
894
+
895
+ #if query_embeds is not None:
896
+ if query_embeds is not None and query_embeds.shape[1] == 32:
897
+ is_casual = True
898
+ else:
899
+ is_casual = False
900
+ past_key_values_length = (past_key_values[0][0].shape[2] -
901
+ self.config.query_length if past_key_values is not None else 0)
902
+
903
+ query_length = query_embeds.shape[1] if query_embeds is not None else 0
904
+
905
+ embedding_output = self.embeddings(
906
+ input_ids=input_ids,
907
+ position_ids=position_ids,
908
+ query_embeds=query_embeds,
909
+ past_key_values_length=past_key_values_length,
910
+ )
911
+
912
+ input_shape = embedding_output.size()[:-1]
913
+ batch_size, seq_length = input_shape
914
+ device = embedding_output.device
915
+
916
+ #print('attention_mask', attention_mask)
917
+ if attention_mask is None:
918
+ attention_mask = torch.ones(((batch_size, seq_length + past_key_values_length)), device=device)
919
+ #print(seq_length, past_key_values_length)
920
+
921
+ # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
922
+ # ourselves in which case we just need to make it broadcastable to all heads.
923
+ if is_decoder:
924
+ #print(attention_mask.shape, input_ids.shape)
925
+ extended_attention_mask = self.get_extended_attention_mask(
926
+ attention_mask,
927
+ input_ids.shape,
928
+ device,
929
+ is_decoder,
930
+ is_casual,
931
+ has_query=(query_embeds is not None),
932
+ )
933
+ else:
934
+ extended_attention_mask = self.get_extended_attention_mask(
935
+ attention_mask,
936
+ input_shape,
937
+ device,
938
+ is_decoder,
939
+ is_casual,
940
+ )
941
+ #print(is_decoder, extended_attention_mask.shape)
942
+ # if is_decoder:
943
+ # print(extended_attention_mask[0,0,:,32:])
944
+ # if attention_mask is not None:
945
+ # print(input_ids, embedding_output.shape, extended_attention_mask.shape)
946
+
947
+ # If a 2D or 3D attention mask is provided for the cross-attention
948
+ # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
949
+ if encoder_hidden_states is not None:
950
+ if type(encoder_hidden_states) == list:
951
+ encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states[0].size()
952
+ else:
953
+ (
954
+ encoder_batch_size,
955
+ encoder_sequence_length,
956
+ _,
957
+ ) = encoder_hidden_states.size()
958
+ encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
959
+
960
+ if type(encoder_attention_mask) == list:
961
+ encoder_extended_attention_mask = [self.invert_attention_mask(mask) for mask in encoder_attention_mask]
962
+ elif encoder_attention_mask is None:
963
+ encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device)
964
+ encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
965
+ else:
966
+ encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
967
+ #print(is_casual, extended_attention_mask.shape, encoder_attention_mask.shape, encoder_extended_attention_mask.shape)
968
+ else:
969
+ encoder_extended_attention_mask = None
970
+
971
+ # if input_ids is not None and query_embeds is not None:
972
+ # print(extended_attention_mask.shape, encoder_extended_attention_mask.shape)
973
+ # Prepare head mask if needed
974
+ # 1.0 in head_mask indicate we keep the head
975
+ # attention_probs has shape bsz x n_heads x N x N
976
+ # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
977
+ # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
978
+ head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
979
+ #print(head_mask)
980
+
981
+ encoder_outputs = self.encoder(
982
+ embedding_output,
983
+ attention_mask=extended_attention_mask,
984
+ head_mask=head_mask,
985
+ encoder_hidden_states=encoder_hidden_states,
986
+ encoder_attention_mask=encoder_extended_attention_mask,
987
+ past_key_values=past_key_values,
988
+ use_cache=use_cache,
989
+ output_attentions=output_attentions,
990
+ output_hidden_states=output_hidden_states,
991
+ return_dict=return_dict,
992
+ query_length=query_length,
993
+ )
994
+ # if is_decoder:
995
+ # print(embedding_output.shape, attention_mask.shape, len(past_key_values))
996
+ #print(embedding_output.shape, extended_attention_mask.shape, encoder_hidden_states.shape, encoder_extended_attention_mask.shape)
997
+ #print(extended_attention_mask[0], encoder_extended_attention_mask[0])
998
+
999
+ #print(query_embeds.shape, encoder_hidden_states.shape)
1000
+
1001
+ sequence_output = encoder_outputs[0]
1002
+ pooled_output = (self.pooler(sequence_output) if self.pooler is not None else None)
1003
+
1004
+ if not return_dict:
1005
+ return (sequence_output, pooled_output) + encoder_outputs[1:]
1006
+
1007
+ return BaseModelOutputWithPoolingAndCrossAttentions(
1008
+ last_hidden_state=sequence_output,
1009
+ pooler_output=pooled_output,
1010
+ past_key_values=encoder_outputs.past_key_values,
1011
+ hidden_states=encoder_outputs.hidden_states,
1012
+ attentions=encoder_outputs.attentions,
1013
+ cross_attentions=encoder_outputs.cross_attentions,
1014
+ )
1015
+
1016
+
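# Minimal usage sketch for BertModel as a querying transformer over frozen image features
# (illustrative only; the config values and tensor shapes below are assumptions, not taken
# from this file):
#
#   config = BertConfig.from_pretrained("bert-base-uncased")
#   config.add_cross_attention = True
#   config.cross_attention_freq = 2
#   config.encoder_width = 1408           # width of the frozen vision encoder features
#   config.query_length = 32
#   qformer = BertModel(config)
#   query_tokens = torch.zeros(1, 32, config.hidden_size)    # learned parameters in practice
#   image_embeds = torch.randn(1, 257, config.encoder_width)  # frozen ViT output (example)
#   out = qformer(query_embeds=query_tokens,
#                 encoder_hidden_states=image_embeds,
#                 encoder_attention_mask=torch.ones(1, 257),
#                 return_dict=True)
#   # out.last_hidden_state: (1, 32, hidden_size) image-conditioned query features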
1017
+ class BertLMHeadModel(BertPreTrainedModel):
1018
+
1019
+ _keys_to_ignore_on_load_unexpected = [r"pooler"]
1020
+ _keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"]
1021
+
1022
+ def __init__(self, config):
1023
+ super().__init__(config)
1024
+
1025
+ self.bert = BertModel(config, add_pooling_layer=False)
1026
+ self.cls = BertOnlyMLMHead(config)
1027
+
1028
+ self.init_weights()
1029
+
1030
+ def get_output_embeddings(self):
1031
+ return self.cls.predictions.decoder
1032
+
1033
+ def set_output_embeddings(self, new_embeddings):
1034
+ self.cls.predictions.decoder = new_embeddings
1035
+
1036
+ def forward(
1037
+ self,
1038
+ input_ids=None,
1039
+ attention_mask=None,
1040
+ position_ids=None,
1041
+ head_mask=None,
1042
+ query_embeds=None,
1043
+ encoder_hidden_states=None,
1044
+ encoder_attention_mask=None,
1045
+ labels=None,
1046
+ past_key_values=None,
1047
+ use_cache=True,
1048
+ output_attentions=None,
1049
+ output_hidden_states=None,
1050
+ return_dict=None,
1051
+ return_logits=False,
1052
+ is_decoder=True,
1053
+ reduction="mean",
1054
+ ):
1055
+ r"""
1056
+ encoder_hidden_states (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
1057
+ Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if
1058
+ the model is configured as a decoder.
1059
+ encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
1060
+ Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in
1061
+ the cross-attention if the model is configured as a decoder. Mask values selected in ``[0, 1]``:
1062
+ - 1 for tokens that are **not masked**,
1063
+ - 0 for tokens that are **masked**.
1064
+ labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
1065
+ Labels for computing the left-to-right language modeling loss (next word prediction). Indices should be in
1066
+ ``[-100, 0, ..., config.vocab_size]`` (see ``input_ids`` docstring) Tokens with indices set to ``-100`` are
1067
+ ignored (masked), the loss is only computed for the tokens with labels in ``[0, ..., config.vocab_size]``
1068
+ past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
1069
+ Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.
1070
+ If :obj:`past_key_values` are used, the user can optionally input only the last :obj:`decoder_input_ids`
1071
+ (those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)`
1072
+ instead of all :obj:`decoder_input_ids` of shape :obj:`(batch_size, sequence_length)`.
1073
+ use_cache (:obj:`bool`, `optional`):
1074
+ If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up
1075
+ decoding (see :obj:`past_key_values`).
1076
+ Returns:
1077
+ Example::
1078
+ >>> from transformers import BertTokenizer, BertLMHeadModel, BertConfig
1079
+ >>> import torch
1080
+ >>> tokenizer = BertTokenizer.from_pretrained('bert-base-cased')
1081
+ >>> config = BertConfig.from_pretrained("bert-base-cased")
1082
+ >>> model = BertLMHeadModel.from_pretrained('bert-base-cased', config=config)
1083
+ >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
1084
+ >>> outputs = model(**inputs)
1085
+ >>> prediction_logits = outputs.logits
1086
+ """
1087
+ return_dict = (return_dict if return_dict is not None else self.config.use_return_dict)
1088
+ if labels is not None:
1089
+ use_cache = False
1090
+ if past_key_values is not None:
1091
+ query_embeds = None
1092
+ #print(len(past_key_values))
1093
+ #print('attention_mask', attention_mask)
1094
+ outputs = self.bert(
1095
+ input_ids,
1096
+ attention_mask=attention_mask,
1097
+ position_ids=position_ids,
1098
+ head_mask=head_mask,
1099
+ query_embeds=query_embeds,
1100
+ encoder_hidden_states=encoder_hidden_states,
1101
+ encoder_attention_mask=encoder_attention_mask,
1102
+ past_key_values=past_key_values,
1103
+ use_cache=use_cache,
1104
+ output_attentions=output_attentions,
1105
+ output_hidden_states=output_hidden_states,
1106
+ return_dict=return_dict,
1107
+ is_decoder=is_decoder,
1108
+ )
1109
+
1110
+ sequence_output = outputs[0]
1111
+ if query_embeds is not None:
1112
+ sequence_output = outputs[0][:, query_embeds.shape[1]:, :]
1113
+
1114
+ prediction_scores = self.cls(sequence_output)
1115
+
1116
+ if return_logits:
1117
+ return prediction_scores[:, :-1, :].contiguous()
1118
+
1119
+ lm_loss = None
1120
+ if labels is not None:
1121
+ # we are doing next-token prediction; shift prediction scores and input ids by one
1122
+ shifted_prediction_scores = prediction_scores[:, :-1, :].contiguous()
1123
+ labels = labels[:, 1:].contiguous()
1124
+ loss_fct = CrossEntropyLoss(reduction=reduction, label_smoothing=0.1)
1125
+ lm_loss = loss_fct(
1126
+ shifted_prediction_scores.view(-1, self.config.vocab_size),
1127
+ labels.view(-1),
1128
+ )
1129
+ if reduction == "none":
1130
+ lm_loss = lm_loss.view(prediction_scores.size(0), -1).sum(1)
1131
+
1132
+ if not return_dict:
1133
+ output = (prediction_scores, ) + outputs[2:]
1134
+ return ((lm_loss, ) + output) if lm_loss is not None else output
1135
+
1136
+ return CausalLMOutputWithCrossAttentions(
1137
+ loss=lm_loss,
1138
+ logits=prediction_scores,
1139
+ past_key_values=outputs.past_key_values,
1140
+ hidden_states=outputs.hidden_states,
1141
+ attentions=outputs.attentions,
1142
+ cross_attentions=outputs.cross_attentions,
1143
+ )
1144
+
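# Next-token loss sketch (illustrative): the LM head predicts token t+1 from position t,
# so scores and labels are shifted by one before the cross-entropy, as in forward() above:
#
#   prediction_scores: (batch, seq_len, vocab);  labels: (batch, seq_len)
#   shifted_scores = prediction_scores[:, :-1, :]   # predictions for positions 0..seq_len-2
#   shifted_labels = labels[:, 1:]                  # targets are the next tokens
#   loss = CrossEntropyLoss(label_smoothing=0.1)(shifted_scores.reshape(-1, vocab),
#                                                 shifted_labels.reshape(-1))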
1145
+ def prepare_inputs_for_generation(self, input_ids, query_embeds, past=None, attention_mask=None, **model_kwargs):
1146
+ # if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly
1147
+ if attention_mask is None:
1148
+ attention_mask = input_ids.new_ones(input_ids.shape)
1149
+ query_mask = input_ids.new_ones(query_embeds.shape[:-1])
1150
+ attention_mask = torch.cat([query_mask, attention_mask], dim=-1)
1151
+
1152
+ # cut decoder_input_ids if past is used
1153
+ if past is not None:
1154
+ input_ids = input_ids[:, -1:]
1155
+
1156
+ return {
1157
+ "input_ids": input_ids,
1158
+ "query_embeds": query_embeds,
1159
+ "attention_mask": attention_mask,
1160
+ "past_key_values": past,
1161
+ "encoder_hidden_states": model_kwargs.get("encoder_hidden_states", None),
1162
+ "encoder_attention_mask": model_kwargs.get("encoder_attention_mask", None),
1163
+ "is_decoder": True,
1164
+ }
1165
+
1166
+ def _reorder_cache(self, past, beam_idx):
1167
+ reordered_past = ()
1168
+ for layer_past in past:
1169
+ reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past), )
1170
+ return reordered_past
1171
+
1172
+
1173
+ class BertForMaskedLM(BertPreTrainedModel):
1174
+
1175
+ _keys_to_ignore_on_load_unexpected = [r"pooler"]
1176
+ _keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"]
1177
+
1178
+ def __init__(self, config):
1179
+ super().__init__(config)
1180
+
1181
+ self.bert = BertModel(config, add_pooling_layer=False)
1182
+ self.cls = BertOnlyMLMHead(config)
1183
+
1184
+ self.init_weights()
1185
+
1186
+ def get_output_embeddings(self):
1187
+ return self.cls.predictions.decoder
1188
+
1189
+ def set_output_embeddings(self, new_embeddings):
1190
+ self.cls.predictions.decoder = new_embeddings
1191
+
1192
+ def forward(
1193
+ self,
1194
+ input_ids=None,
1195
+ attention_mask=None,
1196
+ position_ids=None,
1197
+ head_mask=None,
1198
+ query_embeds=None,
1199
+ encoder_hidden_states=None,
1200
+ encoder_attention_mask=None,
1201
+ labels=None,
1202
+ output_attentions=None,
1203
+ output_hidden_states=None,
1204
+ return_dict=None,
1205
+ return_logits=False,
1206
+ is_decoder=False,
1207
+ ):
1208
+ r"""
1209
+ labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
1210
+ Labels for computing the masked language modeling loss. Indices should be in ``[-100, 0, ...,
1211
+ config.vocab_size]`` (see ``input_ids`` docstring) Tokens with indices set to ``-100`` are ignored
1212
+ (masked), the loss is only computed for the tokens with labels in ``[0, ..., config.vocab_size]``
1213
+ """
1214
+
1215
+ return_dict = (return_dict if return_dict is not None else self.config.use_return_dict)
1216
+
1217
+ outputs = self.bert(
1218
+ input_ids,
1219
+ attention_mask=attention_mask,
1220
+ position_ids=position_ids,
1221
+ head_mask=head_mask,
1222
+ query_embeds=query_embeds,
1223
+ encoder_hidden_states=encoder_hidden_states,
1224
+ encoder_attention_mask=encoder_attention_mask,
1225
+ output_attentions=output_attentions,
1226
+ output_hidden_states=output_hidden_states,
1227
+ return_dict=return_dict,
1228
+ is_decoder=is_decoder,
1229
+ )
1230
+
1231
+ if query_embeds is not None:
1232
+ sequence_output = outputs[0][:, query_embeds.shape[1]:, :]
1233
+ prediction_scores = self.cls(sequence_output)
1234
+
1235
+ if return_logits:
1236
+ return prediction_scores
1237
+
1238
+ masked_lm_loss = None
1239
+ if labels is not None:
1240
+ loss_fct = CrossEntropyLoss() # -100 index = padding token
1241
+ masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))
1242
+
1243
+ if not return_dict:
1244
+ output = (prediction_scores, ) + outputs[2:]
1245
+ return (((masked_lm_loss, ) + output) if masked_lm_loss is not None else output)
1246
+
1247
+ return MaskedLMOutput(
1248
+ loss=masked_lm_loss,
1249
+ logits=prediction_scores,
1250
+ hidden_states=outputs.hidden_states,
1251
+ attentions=outputs.attentions,
1252
+ )
1253
+
1254
+ class Mlp(nn.Module):
1255
+ """MLP as used in Vision Transformer, MLP-Mixer and related networks"""
1256
+ def __init__(
1257
+ self,
1258
+ in_features,
1259
+ hidden_features=None,
1260
+ out_features=None,
1261
+ act_layer=nn.GELU,
1262
+ drop=0.0,
1263
+ ):
1264
+ super().__init__()
1265
+ out_features = out_features or in_features
1266
+ hidden_features = hidden_features or in_features
1267
+ self.fc1 = nn.Linear(in_features, hidden_features)
1268
+ self.act = act_layer()
1269
+ self.fc2 = nn.Linear(hidden_features, out_features)
1270
+ self.drop = nn.Dropout(drop)
1271
+
1272
+ def forward(self, x):
1273
+ x = self.fc1(x)
1274
+ x = self.act(x)
1275
+ x = self.drop(x)
1276
+ x = self.fc2(x)
1277
+ x = self.drop(x)
1278
+ return x
1279
+
1280
+
1281
+ class Attention(nn.Module):
1282
+ def __init__(
1283
+ self,
1284
+ dim,
1285
+ num_heads=8,
1286
+ qkv_bias=False,
1287
+ qk_scale=None,
1288
+ attn_drop=0.0,
1289
+ proj_drop=0.0,
1290
+ ):
1291
+ super().__init__()
1292
+ self.num_heads = num_heads
1293
+ head_dim = dim // num_heads
1294
+ # NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights
1295
+ self.scale = qk_scale or head_dim**-0.5
1296
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
1297
+ self.attn_drop = nn.Dropout(attn_drop)
1298
+ self.proj = nn.Linear(dim, dim)
1299
+ self.proj_drop = nn.Dropout(proj_drop)
1300
+ self.attn_gradients = None
1301
+ self.attention_map = None
1302
+
1303
+ def save_attn_gradients(self, attn_gradients):
1304
+ self.attn_gradients = attn_gradients
1305
+
1306
+ def get_attn_gradients(self):
1307
+ return self.attn_gradients
1308
+
1309
+ def save_attention_map(self, attention_map):
1310
+ self.attention_map = attention_map
1311
+
1312
+ def get_attention_map(self):
1313
+ return self.attention_map
1314
+
1315
+ def forward(self, x, register_hook=False):
1316
+ B, N, C = x.shape
1317
+ qkv = (self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4))
1318
+ q, k, v = (
1319
+ qkv[0],
1320
+ qkv[1],
1321
+ qkv[2],
1322
+ ) # make torchscript happy (cannot use tensor as tuple)
1323
+
1324
+ attn = (q @ k.transpose(-2, -1)) * self.scale
1325
+ attn = attn.softmax(dim=-1)
1326
+ attn = self.attn_drop(attn)
1327
+
1328
+ if register_hook:
1329
+ self.save_attention_map(attn)
1330
+ attn.register_hook(self.save_attn_gradients)
1331
+
1332
+ x = (attn @ v).transpose(1, 2).reshape(B, N, C)
1333
+ x = self.proj(x)
1334
+ x = self.proj_drop(x)
1335
+ return x
1336
+
1337
+
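+ # Illustrative sketch (shapes assumed, not part of the original module): for x of shape
+ # (B, N, C) the qkv projection gives (B, N, 3*C), reshaped and permuted to
+ # (3, B, num_heads, N, C // num_heads) so q, k, v are each (B, heads, N, head_dim);
+ # attn @ v is transposed back to (B, N, C), e.g.
+ #     attn = Attention(dim=768, num_heads=12, qkv_bias=True)
+ #     y = attn(torch.randn(2, 197, 768))  # y.shape == (2, 197, 768)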
1338
+ class Block(nn.Module):
1339
+ def __init__(
1340
+ self,
1341
+ dim,
1342
+ num_heads,
1343
+ mlp_ratio=4.0,
1344
+ qkv_bias=False,
1345
+ qk_scale=None,
1346
+ drop=0.0,
1347
+ attn_drop=0.0,
1348
+ drop_path=0.0,
1349
+ act_layer=nn.GELU,
1350
+ norm_layer=nn.LayerNorm,
1351
+ use_grad_checkpointing=False,
1352
+ ):
1353
+ super().__init__()
1354
+ self.norm1 = norm_layer(dim)
1355
+ self.attn = Attention(
1356
+ dim,
1357
+ num_heads=num_heads,
1358
+ qkv_bias=qkv_bias,
1359
+ qk_scale=qk_scale,
1360
+ attn_drop=attn_drop,
1361
+ proj_drop=drop,
1362
+ )
1363
+ # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
1364
+ self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
1365
+ self.norm2 = norm_layer(dim)
1366
+ mlp_hidden_dim = int(dim * mlp_ratio)
1367
+ self.mlp = Mlp(
1368
+ in_features=dim,
1369
+ hidden_features=mlp_hidden_dim,
1370
+ act_layer=act_layer,
1371
+ drop=drop,
1372
+ )
1373
+
1374
+ # if use_grad_checkpointing:
1375
+ # self.attn = checkpoint_wrapper(self.attn)
1376
+ # self.mlp = checkpoint_wrapper(self.mlp)
1377
+
1378
+ def forward(self, x, register_hook=False):
1379
+ x = x + self.drop_path(self.attn(self.norm1(x), register_hook=register_hook))
1380
+ x = x + self.drop_path(self.mlp(self.norm2(x)))
1381
+ return x
1382
+
1383
+
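+ # Illustrative sketch (not part of the original module): each Block is a pre-norm residual
+ # unit, x + DropPath(Attn(LN(x))) followed by x + DropPath(Mlp(LN(x))), e.g.
+ #     blk = Block(dim=768, num_heads=12, mlp_ratio=4.0, qkv_bias=True)
+ #     y = blk(torch.randn(2, 197, 768))  # y.shape == (2, 197, 768)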
1384
+ class VisionTransformer(nn.Module):
1385
+ """Vision Transformer
1386
+    A PyTorch impl of: `An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale` -
1387
+ https://arxiv.org/abs/2010.11929
1388
+ """
1389
+ def __init__(
1390
+ self,
1391
+ img_size=224,
1392
+ patch_size=16,
1393
+ in_chans=3,
1394
+ num_classes=1000,
1395
+ embed_dim=768,
1396
+ depth=12,
1397
+ num_heads=12,
1398
+ mlp_ratio=4.0,
1399
+ qkv_bias=True,
1400
+ qk_scale=None,
1401
+ representation_size=None,
1402
+ drop_rate=0.0,
1403
+ attn_drop_rate=0.0,
1404
+ drop_path_rate=0.0,
1405
+ norm_layer=None,
1406
+ use_grad_checkpointing=False,
1407
+ ckpt_layer=0,
1408
+ ):
1409
+ """
1410
+ Args:
1411
+ img_size (int, tuple): input image size
1412
+ patch_size (int, tuple): patch size
1413
+ in_chans (int): number of input channels
1414
+ num_classes (int): number of classes for classification head
1415
+ embed_dim (int): embedding dimension
1416
+ depth (int): depth of transformer
1417
+ num_heads (int): number of attention heads
1418
+ mlp_ratio (int): ratio of mlp hidden dim to embedding dim
1419
+ qkv_bias (bool): enable bias for qkv if True
1420
+ qk_scale (float): override default qk scale of head_dim ** -0.5 if set
1421
+ representation_size (Optional[int]): enable and set representation layer (pre-logits) to this value if set
1422
+ drop_rate (float): dropout rate
1423
+ attn_drop_rate (float): attention dropout rate
1424
+ drop_path_rate (float): stochastic depth rate
1425
+ norm_layer: (nn.Module): normalization layer
1426
+ """
1427
+ super().__init__()
1428
+        self.num_features = self.embed_dim = embed_dim  # num_features for consistency with other models
1429
+ norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6)
1430
+
1431
+ self.patch_embed = PatchEmbed(
1432
+ img_size=img_size,
1433
+ patch_size=patch_size,
1434
+ in_chans=in_chans,
1435
+ embed_dim=embed_dim,
1436
+ )
1437
+
1438
+ num_patches = self.patch_embed.num_patches
1439
+
1440
+ self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
1441
+ self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
1442
+ self.pos_drop = nn.Dropout(p=drop_rate)
1443
+
1444
+ dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule
1445
+ self.blocks = nn.ModuleList([
1446
+ Block(
1447
+ dim=embed_dim,
1448
+ num_heads=num_heads,
1449
+ mlp_ratio=mlp_ratio,
1450
+ qkv_bias=qkv_bias,
1451
+ qk_scale=qk_scale,
1452
+ drop=drop_rate,
1453
+ attn_drop=attn_drop_rate,
1454
+ drop_path=dpr[i],
1455
+ norm_layer=norm_layer,
1456
+ use_grad_checkpointing=(use_grad_checkpointing and i >= depth - ckpt_layer),
1457
+ ) for i in range(depth)
1458
+ ])
1459
+ self.norm = norm_layer(embed_dim)
1460
+
1461
+ trunc_normal_(self.pos_embed, std=0.02)
1462
+ trunc_normal_(self.cls_token, std=0.02)
1463
+ self.apply(self._init_weights)
1464
+
1465
+ def _init_weights(self, m):
1466
+ if isinstance(m, nn.Linear):
1467
+ trunc_normal_(m.weight, std=0.02)
1468
+ if isinstance(m, nn.Linear) and m.bias is not None:
1469
+ nn.init.constant_(m.bias, 0)
1470
+ elif isinstance(m, nn.LayerNorm):
1471
+ nn.init.constant_(m.bias, 0)
1472
+ nn.init.constant_(m.weight, 1.0)
1473
+
1474
+ @torch.jit.ignore
1475
+ def no_weight_decay(self):
1476
+ return {"pos_embed", "cls_token"}
1477
+
1478
+ def forward(self, x, register_blk=-1):
1479
+ B = x.shape[0]
1480
+ x = self.patch_embed(x)
1481
+
1482
+ cls_tokens = self.cls_token.expand(B, -1, -1) # stole cls_tokens impl from Phil Wang, thanks
1483
+ x = torch.cat((cls_tokens, x), dim=1)
1484
+
1485
+ x = x + self.pos_embed[:, :x.size(1), :]
1486
+ x = self.pos_drop(x)
1487
+
1488
+ for i, blk in enumerate(self.blocks):
1489
+ x = blk(x, register_blk == i)
1490
+ x = self.norm(x)
1491
+
1492
+ return x
1493
+
1494
+ @torch.jit.ignore()
1495
+ def load_pretrained(self, checkpoint_path, prefix=""):
1496
+ _load_weights(self, checkpoint_path, prefix)
1497
+
1498
+
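+ # Illustrative sketch (default hyper-parameters assumed, not part of the original module):
+ # a 224x224 image with patch_size=16 gives 14*14 = 196 patch tokens plus one [CLS] token,
+ # so the encoder output has 197 tokens, e.g.
+ #     vit = VisionTransformer(img_size=224, patch_size=16, embed_dim=768, depth=12, num_heads=12)
+ #     feats = vit(torch.randn(1, 3, 224, 224))  # feats.shape == (1, 197, 768)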
1499
+ @torch.no_grad()
1500
+ def _load_weights(model: VisionTransformer, checkpoint_path: str, prefix: str = ""):
1501
+ """Load weights from .npz checkpoints for official Google Brain Flax implementation"""
1502
+ import numpy as np
1503
+
1504
+ def _n2p(w, t=True):
1505
+ if w.ndim == 4 and w.shape[0] == w.shape[1] == w.shape[2] == 1:
1506
+ w = w.flatten()
1507
+ if t:
1508
+ if w.ndim == 4:
1509
+ w = w.transpose([3, 2, 0, 1])
1510
+ elif w.ndim == 3:
1511
+ w = w.transpose([2, 0, 1])
1512
+ elif w.ndim == 2:
1513
+ w = w.transpose([1, 0])
1514
+ return torch.from_numpy(w)
1515
+
1516
+ w = np.load(checkpoint_path)
1517
+ if not prefix and "opt/target/embedding/kernel" in w:
1518
+ prefix = "opt/target/"
1519
+
1520
+ if hasattr(model.patch_embed, "backbone"):
1521
+ # hybrid
1522
+ backbone = model.patch_embed.backbone
1523
+ stem_only = not hasattr(backbone, "stem")
1524
+ stem = backbone if stem_only else backbone.stem
1525
+ stem.conv.weight.copy_(adapt_input_conv(stem.conv.weight.shape[1], _n2p(w[f"{prefix}conv_root/kernel"])))
1526
+ stem.norm.weight.copy_(_n2p(w[f"{prefix}gn_root/scale"]))
1527
+ stem.norm.bias.copy_(_n2p(w[f"{prefix}gn_root/bias"]))
1528
+ if not stem_only:
1529
+ for i, stage in enumerate(backbone.stages):
1530
+ for j, block in enumerate(stage.blocks):
1531
+ bp = f"{prefix}block{i + 1}/unit{j + 1}/"
1532
+ for r in range(3):
1533
+ getattr(block, f"conv{r + 1}").weight.copy_(_n2p(w[f"{bp}conv{r + 1}/kernel"]))
1534
+ getattr(block, f"norm{r + 1}").weight.copy_(_n2p(w[f"{bp}gn{r + 1}/scale"]))
1535
+ getattr(block, f"norm{r + 1}").bias.copy_(_n2p(w[f"{bp}gn{r + 1}/bias"]))
1536
+ if block.downsample is not None:
1537
+ block.downsample.conv.weight.copy_(_n2p(w[f"{bp}conv_proj/kernel"]))
1538
+ block.downsample.norm.weight.copy_(_n2p(w[f"{bp}gn_proj/scale"]))
1539
+ block.downsample.norm.bias.copy_(_n2p(w[f"{bp}gn_proj/bias"]))
1540
+ embed_conv_w = _n2p(w[f"{prefix}embedding/kernel"])
1541
+ else:
1542
+ embed_conv_w = adapt_input_conv(model.patch_embed.proj.weight.shape[1], _n2p(w[f"{prefix}embedding/kernel"]))
1543
+ model.patch_embed.proj.weight.copy_(embed_conv_w)
1544
+ model.patch_embed.proj.bias.copy_(_n2p(w[f"{prefix}embedding/bias"]))
1545
+ model.cls_token.copy_(_n2p(w[f"{prefix}cls"], t=False))
1546
+ pos_embed_w = _n2p(w[f"{prefix}Transformer/posembed_input/pos_embedding"], t=False)
1547
+ if pos_embed_w.shape != model.pos_embed.shape:
1548
+ pos_embed_w = resize_pos_embed( # resize pos embedding when different size from pretrained weights
1549
+ pos_embed_w,
1550
+ model.pos_embed,
1551
+ getattr(model, "num_tokens", 1),
1552
+ model.patch_embed.grid_size,
1553
+ )
1554
+ model.pos_embed.copy_(pos_embed_w)
1555
+ model.norm.weight.copy_(_n2p(w[f"{prefix}Transformer/encoder_norm/scale"]))
1556
+ model.norm.bias.copy_(_n2p(w[f"{prefix}Transformer/encoder_norm/bias"]))
1557
+ # if isinstance(model.head, nn.Linear) and model.head.bias.shape[0] == w[f'{prefix}head/bias'].shape[-1]:
1558
+ # model.head.weight.copy_(_n2p(w[f'{prefix}head/kernel']))
1559
+ # model.head.bias.copy_(_n2p(w[f'{prefix}head/bias']))
1560
+ # if isinstance(getattr(model.pre_logits, 'fc', None), nn.Linear) and f'{prefix}pre_logits/bias' in w:
1561
+ # model.pre_logits.fc.weight.copy_(_n2p(w[f'{prefix}pre_logits/kernel']))
1562
+ # model.pre_logits.fc.bias.copy_(_n2p(w[f'{prefix}pre_logits/bias']))
1563
+ for i, block in enumerate(model.blocks.children()):
1564
+ block_prefix = f"{prefix}Transformer/encoderblock_{i}/"
1565
+ mha_prefix = block_prefix + "MultiHeadDotProductAttention_1/"
1566
+ block.norm1.weight.copy_(_n2p(w[f"{block_prefix}LayerNorm_0/scale"]))
1567
+ block.norm1.bias.copy_(_n2p(w[f"{block_prefix}LayerNorm_0/bias"]))
1568
+ block.attn.qkv.weight.copy_(
1569
+ torch.cat([_n2p(w[f"{mha_prefix}{n}/kernel"], t=False).flatten(1).T for n in ("query", "key", "value")]))
1570
+ block.attn.qkv.bias.copy_(
1571
+ torch.cat([_n2p(w[f"{mha_prefix}{n}/bias"], t=False).reshape(-1) for n in ("query", "key", "value")]))
1572
+ block.attn.proj.weight.copy_(_n2p(w[f"{mha_prefix}out/kernel"]).flatten(1))
1573
+ block.attn.proj.bias.copy_(_n2p(w[f"{mha_prefix}out/bias"]))
1574
+ for r in range(2):
1575
+ getattr(block.mlp, f"fc{r + 1}").weight.copy_(_n2p(w[f"{block_prefix}MlpBlock_3/Dense_{r}/kernel"]))
1576
+ getattr(block.mlp, f"fc{r + 1}").bias.copy_(_n2p(w[f"{block_prefix}MlpBlock_3/Dense_{r}/bias"]))
1577
+ block.norm2.weight.copy_(_n2p(w[f"{block_prefix}LayerNorm_2/scale"]))
1578
+ block.norm2.bias.copy_(_n2p(w[f"{block_prefix}LayerNorm_2/bias"]))
1579
+
1580
+
1581
+ def resize_pos_embed(posemb, posemb_new, num_tokens=1, gs_new=()):
+     # Rescale the grid of position embeddings when loading from state_dict. Adapted from
+     # https://github.com/google-research/vision_transformer/blob/00883dd691c63a6830751563748663526e811cee/vit_jax/checkpoint.py#L224
+     logging.info("Resized position embedding: %s to %s", posemb.shape, posemb_new.shape)
+     ntok_new = posemb_new.shape[1]
+     if num_tokens:
+         posemb_tok, posemb_grid = posemb[:, :num_tokens], posemb[0, num_tokens:]
+         ntok_new -= num_tokens
+     else:
+         posemb_tok, posemb_grid = posemb[:, :0], posemb[0]
+     gs_old = int(math.sqrt(len(posemb_grid)))
+     if not len(gs_new):  # backwards compatibility
+         gs_new = [int(math.sqrt(ntok_new))] * 2
+     assert len(gs_new) >= 2
+     logging.info("Position embedding grid-size from %s to %s", [gs_old, gs_old], gs_new)
+     posemb_grid = posemb_grid.reshape(1, gs_old, gs_old, -1).permute(0, 3, 1, 2)
+     posemb_grid = F.interpolate(posemb_grid, size=gs_new, mode="bicubic", align_corners=False)
+     posemb_grid = posemb_grid.permute(0, 2, 3, 1).reshape(1, gs_new[0] * gs_new[1], -1)
+     posemb = torch.cat([posemb_tok, posemb_grid], dim=1)
+     return posemb
1601
+
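+ # Illustrative sketch (not part of the original module): going from 224px to 384px input
+ # with 16px patches resizes the 14x14 grid (196 tokens + 1 class token) to a 24x24 grid
+ # (576 tokens + 1 class token); the class/extra tokens are carried over unchanged and only
+ # the grid part is bicubically interpolated, e.g.
+ #     new_posemb = resize_pos_embed(torch.randn(1, 197, 768), torch.zeros(1, 577, 768))
+ #     # new_posemb.shape == (1, 577, 768)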
1602
+
1603
+ def interpolate_pos_embed(pos_embed_checkpoint, visual_encoder):
1604
+ # interpolate position embedding
1605
+ embedding_size = pos_embed_checkpoint.shape[-1]
1606
+ num_patches = visual_encoder.patch_embed.num_patches
1607
+ num_extra_tokens = visual_encoder.pos_embed.shape[-2] - num_patches
1608
+ # height (== width) for the checkpoint position embedding
1609
+ orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens)**0.5)
1610
+ # height (== width) for the new position embedding
1611
+ new_size = int(num_patches**0.5)
1612
+
1613
+ if orig_size != new_size:
1614
+ # class_token and dist_token are kept unchanged
1615
+ extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens]
1616
+ # only the position tokens are interpolated
1617
+ pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:]
1618
+ pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2)
1619
+ pos_tokens = torch.nn.functional.interpolate(pos_tokens, size=(new_size, new_size), mode="bicubic", align_corners=False)
1620
+ pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2)
1621
+ new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1)
1622
+ print("reshape position embedding from %d to %d" % (orig_size**2, new_size**2))
1623
+
1624
+ return new_pos_embed
1625
+ else:
1626
+ return pos_embed_checkpoint
1627
+
1628
+ # class Blip2Base(BaseModel):
1629
+ class Blip2Base(PreTrainedModel):
1630
+ config_class = BertConfig
1631
+
1632
+ def __init__(self, config):
1633
+ super().__init__(config)
1634
+
1635
+ @property
1636
+ def device(self):
1637
+ return list(self.parameters())[0].device
1638
+
1639
+ @classmethod
1640
+ def init_tokenizer(cls, truncation_side="right"):
1641
+ tokenizer = BertTokenizer.from_pretrained("bert-base-uncased", truncation_side=truncation_side)
1642
+ tokenizer.add_special_tokens({"bos_token": "[DEC]"})
1643
+ return tokenizer
1644
+
1645
+ def maybe_autocast(self, dtype=torch.float16):
1646
+ # if on cpu, don't use autocast
1647
+ # if on gpu, use autocast with dtype if provided, otherwise use torch.float16
1648
+ enable_autocast = self.device != torch.device("cpu")
1649
+
1650
+ if enable_autocast:
1651
+ return torch.cuda.amp.autocast(dtype=dtype)
1652
+ else:
1653
+ return contextlib.nullcontext()
1654
+
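+     # Illustrative usage sketch (not part of the original module): maybe_autocast enables
+     # mixed precision only when the module lives on GPU; on CPU it degrades to a no-op
+     # null context, e.g.
+     #     with self.maybe_autocast():
+     #         image_embeds = visual_encoder.ln_vision(visual_encoder(image))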
1655
+ @classmethod
1656
+ def init_Qformer(cls, encoder_config, num_query_token, vision_width, cross_attention_freq=2, cache_dir=""):
1657
+ print ("loading")
1658
+ encoder_config = BertConfig.from_pretrained("bert-base-uncased")
1659
+ encoder_config.encoder_width = vision_width
1660
+ # insert cross-attention layer every other block
1661
+ encoder_config.add_cross_attention = True
1662
+ encoder_config.cross_attention_freq = cross_attention_freq
1663
+ encoder_config.query_length = num_query_token
1664
+ Qformer = BertLMHeadModel(encoder_config) # .from_pretrained("bert-base-uncased", config=encoder_config, cache_dir=cache_dir)
1665
+ query_tokens = nn.Parameter(torch.zeros(1, num_query_token, encoder_config.hidden_size))
1666
+ query_tokens.data.normal_(mean=0.0, std=encoder_config.initializer_range)
1667
+ return Qformer, query_tokens
1668
+
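+     # Illustrative sketch (not part of the original module): the Q-Former is a BERT encoder
+     # with cross-attention inserted every `cross_attention_freq` layers, and `query_tokens`
+     # is a learned (1, num_query_token, hidden_size) parameter expanded per batch, e.g.
+     #     Qformer, query_tokens = Blip2Base.init_Qformer(None, 32, vision_width=1408)
+     #     # query_tokens.shape == (1, 32, 768) for bert-base-uncased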
1669
+ def load_from_pretrained(self, url_or_filename):
1670
+ if is_url(url_or_filename):
1671
+ cached_file = download_cached_file(url_or_filename, check_hash=False, progress=True)
1672
+ checkpoint = torch.load(cached_file, map_location="cpu")
1673
+ elif os.path.isfile(url_or_filename):
1674
+ checkpoint = torch.load(url_or_filename, map_location="cpu")
1675
+ else:
1676
+ raise RuntimeError("checkpoint url or path is invalid")
1677
+
1678
+ state_dict = checkpoint["model"]
1679
+
1680
+ msg = self.load_state_dict(state_dict, strict=False)
1681
+
1682
+ # logging.info("Missing keys {}".format(msg.missing_keys))
1683
+ logging.info("load checkpoint from %s" % url_or_filename)
1684
+
1685
+ return msg
1686
+
1687
+ def _lemmatize(self, answers):
1688
+ def apply(answer):
1689
+ doc = self.lemmatizer(answer)
1690
+
1691
+ words = []
1692
+ for token in doc:
1693
+ if token.pos_ in ["NOUN", "VERB"]:
1694
+ words.append(token.lemma_)
1695
+ else:
1696
+ words.append(token.text)
1697
+ answer = " ".join(words)
1698
+
1699
+ return answer
1700
+
1701
+ return [apply(answer) for answer in answers]
1702
+
1703
+ @property
1704
+ def lemmatizer(self):
1705
+ if self._lemmatizer is None:
1706
+ try:
1707
+ import spacy
1708
+
1709
+ self._lemmatizer = spacy.load("en_core_web_sm")
1710
+ except ImportError:
1711
+ logging.error("""
1712
+ Please install spacy and en_core_web_sm model to apply lemmatization.
1713
+ python -m spacy download en_core_web_sm
1714
+ OR
1715
+ import spacy.cli
1716
+ spacy.cli.download("en_core_web_sm")
1717
+ """)
1718
+ exit(1)
1719
+
1720
+ return self._lemmatizer
1721
+
1722
+
1723
+ def disabled_train(self, mode=True):
1724
+ """Overwrite model.train with this function to make sure train/eval mode
1725
+ does not change anymore."""
1726
+ return self
1727
+
1728
+
1729
+ class LayerNorm(nn.LayerNorm):
1730
+ """Subclass torch's LayerNorm to handle fp16."""
1731
+ def forward(self, x: torch.Tensor):
1732
+ orig_type = x.dtype
1733
+ ret = super().forward(x.type(torch.float32))
1734
+ return ret.type(orig_type)
1735
+
1736
+
1737
+
1738
+
1739
+ class VectorQuantizer2(nn.Module):
1740
+ """
1741
+ Improved version over VectorQuantizer, can be used as a drop-in replacement. Mostly
1742
+ avoids costly matrix multiplications and allows for post-hoc remapping of indices.
1743
+ """
1744
+
1745
+ # NOTE: due to a bug the beta term was applied to the wrong term. for
1746
+ # backwards compatibility we use the buggy version by default, but you can
1747
+ # specify legacy=False to fix it.
1748
+ def __init__(self, n_e, e_dim, beta, remap=None, unknown_index="random", sane_index_shape=False, legacy=True):
1749
+ super().__init__()
1750
+ self.n_e = n_e
1751
+ self.e_dim = e_dim
1752
+ self.beta = beta
1753
+ self.legacy = legacy
1754
+
1755
+ self.embedding = nn.Embedding(self.n_e, self.e_dim)
1756
+ self.embedding.weight.data.uniform_(-1.0 / self.n_e, 1.0 / self.n_e)
1757
+
1758
+ self.remap = remap
1759
+ if self.remap is not None:
1760
+ self.register_buffer("used", torch.tensor(np.load(self.remap)))
1761
+ self.re_embed = self.used.shape[0]
1762
+ self.unknown_index = unknown_index # "random" or "extra" or integer
1763
+ if self.unknown_index == "extra":
1764
+ self.unknown_index = self.re_embed
1765
+ self.re_embed = self.re_embed + 1
1766
+ print(f"Remapping {self.n_e} indices to {self.re_embed} indices. "
1767
+ f"Using {self.unknown_index} for unknown indices.")
1768
+ else:
1769
+ self.re_embed = n_e
1770
+
1771
+ self.sane_index_shape = sane_index_shape
1772
+
1773
+ def remap_to_used(self, inds):
1774
+ ishape = inds.shape
1775
+ assert len(ishape) > 1
1776
+ inds = inds.reshape(ishape[0], -1)
1777
+ used = self.used.to(inds)
1778
+ match = (inds[:, :, None] == used[None, None, ...]).long()
1779
+ new = match.argmax(-1)
1780
+ unknown = match.sum(2) < 1
1781
+ if self.unknown_index == "random":
1782
+ new[unknown] = torch.randint(0, self.re_embed, size=new[unknown].shape).to(device=new.device)
1783
+ else:
1784
+ new[unknown] = self.unknown_index
1785
+ return new.reshape(ishape)
1786
+
1787
+ def unmap_to_all(self, inds):
1788
+ ishape = inds.shape
1789
+ assert len(ishape) > 1
1790
+ inds = inds.reshape(ishape[0], -1)
1791
+ used = self.used.to(inds)
1792
+ if self.re_embed > self.used.shape[0]: # extra token
1793
+ inds[inds >= self.used.shape[0]] = 0 # simply set to zero
1794
+ back = torch.gather(used[None, :][inds.shape[0] * [0], :], 1, inds)
1795
+ return back.reshape(ishape)
1796
+
1797
+ # def l2norm(self, t):
1798
+ # return F.normalize(t, p = 2, dim = -1)
1799
+
1800
+ def forward(self, z, temp=None, rescale_logits=False, return_logits=False):
1801
+ assert temp is None or temp == 1.0, "Only for interface compatible with Gumbel"
1802
+ assert rescale_logits is False, "Only for interface compatible with Gumbel"
1803
+ assert return_logits is False, "Only for interface compatible with Gumbel"
1804
+ # reshape z -> (batch, height, width, channel) and flatten
1805
+ #z = rearrange(z, 'b c h w -> b h w c').contiguous()
1806
+ bz = z.shape[0]
1807
+ z_flattened = z.view(-1, self.e_dim)
1808
+ #print('z_flattened', z_flattened.shape)
1809
+ # distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z
1810
+
1811
+ d = torch.sum(z_flattened ** 2, dim=1, keepdim=True) + \
1812
+ torch.sum(self.embedding.weight**2, dim=1) - 2 * \
1813
+ torch.einsum('bd,dn->bn', z_flattened, rearrange(self.embedding.weight, 'n d -> d n'))
1814
+
1815
+ min_encoding_indices = torch.argmin(d, dim=1)
1816
+ z_q = self.embedding(min_encoding_indices).view(z.shape)
1817
+ perplexity = None
1818
+ min_encodings = None
1819
+
1820
+ # compute loss for embedding
1821
+ if not self.legacy:
1822
+ loss = self.beta * torch.mean((z_q.detach() - z)**2) + torch.mean((z_q - z.detach())**2)
1823
+ else:
1824
+ loss = torch.mean((z_q.detach() - z)**2) + self.beta * torch.mean((z_q - z.detach())**2)
1825
+
1826
+ # preserve gradients
1827
+ z_q = z + (z_q - z).detach()
1828
+
1829
+ # reshape back to match original input shape
1830
+ #z_q = rearrange(z_q, 'b h w c -> b c h w').contiguous()
1831
+ z_q = z_q.reshape(bz, -1, z_q.shape[-1])
1832
+ if self.remap is not None:
1833
+ min_encoding_indices = min_encoding_indices.reshape(z.shape[0], -1) # add batch axis
1834
+ min_encoding_indices = self.remap_to_used(min_encoding_indices)
1835
+ min_encoding_indices = min_encoding_indices.reshape(-1, 1) # flatten
1836
+
1837
+ if self.sane_index_shape:
1838
+ min_encoding_indices = min_encoding_indices.reshape(z_q.shape[0], z_q.shape[2], z_q.shape[3])
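+             # NOTE: this branch assumes a 4-D (b, h, w, c)-shaped z_q; with the (b, n, d)
+             # reshape above it is not exercised here, since sane_index_shape defaults to
+             # False and is left False by Blip2QformerQuantizer.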
1839
+
1840
+ return z_q, loss, min_encoding_indices
1841
+
1842
+ def get_codebook_entry(self, indices, shape=None):
1843
+ # shape specifying (batch, height, width, channel)
1844
+ if self.remap is not None:
1845
+ indices = indices.reshape(shape[0], -1) # add batch axis
1846
+ indices = self.unmap_to_all(indices)
1847
+ indices = indices.reshape(-1) # flatten again
1848
+
1849
+ # get quantized latent vectors
1850
+ z_q = self.embedding(indices)
1851
+
1852
+ if shape is not None:
1853
+ z_q = z_q.view(shape)
1854
+ # reshape back to match original input shape
1855
+ z_q = z_q.permute(0, 3, 1, 2).contiguous()
1856
+
1857
+ return z_q
1858
+
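+ # Illustrative sketch (not part of the original module): the quantizer snaps each input
+ # vector to its nearest codebook embedding and uses a straight-through estimator
+ # (z_q = z + (z_q - z).detach()) so gradients flow back to the encoder, e.g.
+ #     vq = VectorQuantizer2(n_e=8192, e_dim=32, beta=0.25)
+ #     z_q, loss, idx = vq(torch.randn(2, 32, 32))  # z_q.shape == (2, 32, 32); idx in [0, 8191]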
1859
+
1860
+ class Blip2QformerQuantizer(Blip2Base):
1861
+ """
1862
+ BLIP2 first-stage model with Q-former and ViT.
1863
+ Supported model types:
1864
+ - pretrained: pretrained model with vit-g
1865
+ - pretrain_vitL: pretrained model with vit-large
1866
+        - coco: finetuned model on coco
1867
+ Usage:
1868
+ >>> from lavis.models import load_model
1869
+ >>> model = load_model("blip2", "pretrain")
1870
+ """
1871
+
1872
+ PRETRAINED_MODEL_CONFIG_DICT = {
1873
+ "pretrain": "configs/models/blip2/blip2_pretrain.yaml",
1874
+ "pretrain_vitL": "configs/models/blip2/blip2_pretrain_vitL.yaml",
1875
+ "coco": "configs/models/blip2/blip2_coco.yaml",
1876
+ }
1877
+
1878
+ def __init__(self,
1879
+ config,
1880
+ img_size=224,
1881
+ drop_path_rate=0,
1882
+ use_grad_checkpoint=False,
1883
+ freeze_vit=True,
1884
+ num_query_token=32,
1885
+ cross_attention_freq=2,
1886
+ embed_dim=256,
1887
+ max_txt_len=32,
1888
+ codebook_embed_dim=32,
1889
+ n_embed=8192,
1890
+ recon_s=True,
1891
+ blocks_for_image=True,
1892
+ decode_depth=4,
1893
+ use_recon_s_for_image=False,
1894
+ image_features_dim=1024,
1895
+ visual_encoder_num_features=1408,
1896
+ cache_dir="./"):
1897
+ super().__init__(config)
1898
+
1899
+ self.tokenizer = self.init_tokenizer()
1900
+
1901
+ self.codebook_embed_dim = codebook_embed_dim
1902
+ self.n_embed = n_embed
1903
+ self.recon_s = recon_s
1904
+ self.blocks_for_image = blocks_for_image
1905
+ self.use_recon_s_for_image = use_recon_s_for_image
1906
+ self.depth = decode_depth
1907
+ self.image_features_dim = image_features_dim
1908
+
1909
+ self.Qformer, self.query_tokens = self.init_Qformer(config, num_query_token, visual_encoder_num_features, cache_dir=cache_dir)
1910
+
1911
+ self.Qformer.cls = None
1912
+ self.Qformer.bert.embeddings.word_embeddings = None
1913
+ self.Qformer.bert.embeddings.position_embeddings = None
1914
+ for layer in self.Qformer.bert.encoder.layer:
1915
+ layer.output = None
1916
+ layer.intermediate = None
1917
+
1918
+ for name, param in self.Qformer.named_parameters():
1919
+ param.requires_grad = False
1920
+ self.query_tokens.requires_grad = False
1921
+
1922
+ self.quantize = VectorQuantizer2(n_embed, codebook_embed_dim, beta=0.25, remap=None, sane_index_shape=False)
1923
+
1924
+ self.encode_task_layer = nn.Sequential(
1925
+ nn.Linear(self.Qformer.config.hidden_size, self.Qformer.config.hidden_size),
1926
+ nn.Tanh(),
1927
+ nn.Linear(self.Qformer.config.hidden_size, codebook_embed_dim) # for quantize
1928
+ )
1929
+
1930
+ self.decode_task_layer = nn.Sequential(
1931
+ nn.Linear(codebook_embed_dim, codebook_embed_dim),
1932
+ nn.Tanh(),
1933
+ nn.Linear(codebook_embed_dim, self.Qformer.config.hidden_size) # for quantize
1934
+ )
1935
+
1936
+ self.quantize = self.quantize.eval()
1937
+ self.quantize.training = False
1938
+ for name, param in self.named_parameters():
1939
+ if 'quantize' in name or 'encode_task_layer' in name or 'decode_task_layer' in name:
1940
+ #print('freeze params', name)
1941
+ param.requires_grad = False
1942
+
1943
+ if self.recon_s:
1944
+ self.pos_embed = nn.Parameter(torch.zeros(1, num_query_token, self.Qformer.config.hidden_size))
1945
+ self.blocks = nn.ModuleList([
1946
+ Block(dim=self.Qformer.config.hidden_size,
1947
+ num_heads=12,
1948
+ mlp_ratio=4.0,
1949
+ qkv_bias=True,
1950
+ qk_scale=None,
1951
+ drop=0.0,
1952
+ attn_drop=0.0,
1953
+ drop_path=0.0,
1954
+ norm_layer=partial(nn.LayerNorm, eps=1e-6)) for i in range(self.depth)
1955
+ ])
1956
+
1957
+ if self.blocks_for_image:
1958
+ self.pos_embed_image = nn.Parameter(torch.zeros(1, num_query_token, self.Qformer.config.hidden_size))
1959
+ self.blocks_image = nn.ModuleList([
1960
+ Block(dim=self.Qformer.config.hidden_size,
1961
+ num_heads=12,
1962
+ mlp_ratio=4.0,
1963
+ qkv_bias=True,
1964
+ qk_scale=None,
1965
+ drop=0.0,
1966
+ attn_drop=0.0,
1967
+ drop_path=0.0,
1968
+ norm_layer=partial(nn.LayerNorm, eps=1e-6)) for i in range(self.depth)
1969
+ ])
1970
+
1971
+ self.image_down = nn.Sequential(
1972
+ nn.Linear(self.Qformer.config.hidden_size, 256, bias=False),
1973
+ nn.ReLU(),
1974
+ nn.Linear(256, 128, bias=False),
1975
+ nn.ReLU(),
1976
+ nn.Linear(128, 32, bias=False),
1977
+ )
1978
+ self.distill_image_proj = nn.Linear(num_query_token * 32, image_features_dim)
1979
+
1980
+ @classmethod
1981
+ def load_from_pretrained(cls, config, pretrained_model_path, **kwargs):
1982
+ img_size = kwargs.get("image_size", 224)
1983
+ num_query_token = kwargs.get("num_query_token", 32)
1984
+ cross_attention_freq = kwargs.get("cross_attention_freq", 2)
1985
+
1986
+ drop_path_rate = kwargs.get("drop_path_rate", 0)
1987
+ use_grad_checkpoint = kwargs.get("use_grad_checkpoint", False)
1988
+ freeze_vit = kwargs.get("freeze_vit", True)
1989
+ cache_dir = kwargs.get("cache_dir", "./")
1990
+
1991
+ max_txt_len = kwargs.get("max_txt_len", 32)
1992
+
1993
+ model = cls(config,
1994
+ img_size=img_size,
1995
+ drop_path_rate=drop_path_rate,
1996
+ use_grad_checkpoint=use_grad_checkpoint,
1997
+ freeze_vit=freeze_vit,
1998
+ num_query_token=num_query_token,
1999
+ cross_attention_freq=cross_attention_freq,
2000
+ max_txt_len=max_txt_len,
2001
+ cache_dir=cache_dir,
2002
+ )
2003
+
2004
+ ckpt = torch.load(cache_dir+pretrained_model_path, map_location="cpu")
2005
+        missing, unexpected = model.load_state_dict(ckpt, strict=False)
+        # print('**** missing keys: ', missing)
+        # print('**** unexpected keys:', unexpected)
2008
+ return model
2009
+
2010
+
2011
+
2012
+ def get_codebook_indices(self, visual_encoder, image):
2013
+ with torch.no_grad():
2014
+ with self.maybe_autocast():
2015
+ image_embeds = visual_encoder.ln_vision(visual_encoder(image))
2016
+ image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to(image.device)
2017
+ query_tokens = self.query_tokens.expand(image_embeds.shape[0], -1, -1)
2018
+ query_output = self.Qformer.bert(
2019
+ query_embeds=query_tokens,
2020
+ encoder_hidden_states=image_embeds,
2021
+ encoder_attention_mask=image_atts,
2022
+ return_dict=True,
2023
+ )
2024
+
2025
+ query_output_down = self.encode_task_layer(query_output.last_hidden_state)
2026
+ quant, loss_embed, embed_ind = self.quantize(query_output_down)
2027
+ embed_ind = embed_ind.reshape(quant.shape[0], -1)
2028
+
2029
+ query_output_up = self.decode_task_layer(quant)
2030
+
2031
+ return embed_ind, query_output_up
2032
+
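+     # Illustrative sketch (shapes assumed from the defaults, not part of the original module):
+     # a 224x224 image -> EVA-ViT features (B, 257, 1408) -> 32 Q-Former query outputs
+     # (B, 32, 768) -> encode_task_layer (B, 32, 32) -> nearest-codebook ids, i.e. 32 discrete
+     # tokens per image drawn from an 8192-entry codebook.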
2033
+ def get_codebook_entry(self, indices):
2034
+ with torch.no_grad():
2035
+ quant_embedding = self.quantize.get_codebook_entry(indices)
2036
+ # print('quant_embedding_shape: ', quant_embedding.shape)
2037
+ # print(self.decode_task_layer)
2038
+ # exit()
2039
+ query_output_up = self.decode_task_layer(quant_embedding)
2040
+
2041
+ pos_embed_image = self.pos_embed_image.repeat(query_output_up.shape[0], 1, 1)
2042
+ query_output_up_pos_image = query_output_up + pos_embed_image
2043
+ for blk in self.blocks_image:
2044
+ query_output_up_pos_image = blk(query_output_up_pos_image)
2045
+ query_output_up = query_output_up_pos_image
2046
+
2047
+ reverse_output = self.image_down(query_output_up)
2048
+ reverse_output = reverse_output.reshape(reverse_output.shape[0], -1)
2049
+ reverse_output_proj = self.distill_image_proj(reverse_output)
2050
+
2051
+ return reverse_output_proj
2052
+
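+     # Illustrative sketch (not part of the original module): decoding reverses the path,
+     # codebook ids -> 32-dim code vectors -> decode_task_layer back to the Q-Former hidden
+     # size -> blocks_image (with pos_embed_image added) -> image_down to 32 dims per query
+     # -> distill_image_proj to a single image_features_dim vector that conditions the
+     # diffusion decoder.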
2053
+ @classmethod
2054
+    def get_vision_encoder(cls, model_name="eva_vit_g",
2055
+ img_size=224,
2056
+ drop_path_rate=0,
2057
+ use_grad_checkpoint=False,
2058
+ precision="fp32",
2059
+ cache_dir="./"):
2060
+ visual_encoder = create_eva_vit_g(img_size, drop_path_rate, use_grad_checkpoint, precision, cache_dir=cache_dir)
2061
+ visual_encoder.ln_vision = LayerNorm(visual_encoder.num_features)
2062
+ for name, param in visual_encoder.named_parameters():
2063
+ param.requires_grad = False
2064
+ visual_encoder = visual_encoder.eval()
2065
+ visual_encoder.train = disabled_train
2066
+ logging.info("freeze vision encoder")
2067
+ visual_encoder.ln_vision.weight.requires_grad = False
2068
+ visual_encoder.ln_vision.bias.requires_grad = False
2069
+ return visual_encoder
2070
+
2071
+ class Seed2Tokenizer(PreTrainedModel):
2072
+ config_class = BertConfig
2073
+ base_model_prefix = "model"
2074
+ def __init__(self,
2075
+ config,
2076
+ image_size=224,
2077
+ drop_path_rate=0.4):
2078
+ super().__init__(config)
2079
+
2080
+ model = Blip2QformerQuantizer(config) # .from_pretrained(pretrained_model_path=model_path,
2081
+ # cache_dir=cache_dir,
2082
+ # **kwargs).eval()
2083
+ #model = model.to(device)
2084
+
2085
+ processor = transforms.Compose([
2086
+            transforms.Resize((image_size, image_size), interpolation=3),  # 3 == PIL.Image.BICUBIC
2087
+ # transforms.Resize(image_size, interpolation=3),
2088
+ # transforms.CenterCrop(image_size),
2089
+ transforms.ToTensor(),
2090
+ transforms.Normalize(mean=(0.48145466, 0.4578275, 0.40821073), std=(0.26862954, 0.26130258, 0.27577711))
2091
+ ])
2092
+
2093
+ shape_latents = torch.Size([1, 4, 96, 96])
2094
+ self.latents = torch.randn(shape_latents, generator=None, layout=torch.strided)
2095
+
2096
+ shape_noise = torch.Size([1, 1024])
2097
+ self.noise = torch.randn(shape_noise, generator=None, layout=torch.strided)
2098
+
2099
+ self.model = model
2100
+ self.processor = processor
2101
+ self.visual_encoder = VisionTransformerEvaClip(
2102
+ img_size=image_size,
2103
+ patch_size=14,
2104
+ use_mean_pooling=False,
2105
+ embed_dim=1408,
2106
+ depth=39,
2107
+ num_heads=1408 // 88,
2108
+ mlp_ratio=4.3637,
2109
+ qkv_bias=True,
2110
+ drop_path_rate=drop_path_rate,
2111
+ norm_layer=partial(nn.LayerNorm, eps=1e-6),
2112
+ use_checkpoint=False,
2113
+ )
2114
+
2115
+
2116
+ def __len__(self):
2117
+ return self.model.n_embed
2118
+
2119
+ def encode(self, visual_encoder, image_torch):
2120
+        '''Convert a batch of images to codebook indices.
+        Args:
+            visual_encoder: the frozen vision encoder used to produce image features.
+            image_torch: [b, c, h, w] image tensor (a single [c, h, w] image is also accepted).
+        '''
2125
+ if len(image_torch.shape) == 3:
2126
+ image_torch = image_torch.unsqueeze(0)
2127
+
2128
+ # img = image_torch.to(self.device)
2129
+ img = image_torch
2130
+ #if self.fp16:
2131
+ # img = img.half()
2132
+ with torch.no_grad():
2133
+ id, _ = self.model.get_codebook_indices(visual_encoder, img)
2134
+ return id.view(img.shape[0], -1)
2135
+
2136
+ def decode(self, diffusion_model, indices, negative_indices=None, guidance_scale=10, num_inference_steps=20):
2137
+ image_embeds = self.model.get_codebook_entry(indices)
2138
+ # image = self.diffusion_model(image_embeds=image_embed,
2139
+ # noise_level=0,
2140
+ # num_inference_steps=20,
2141
+ # latents=self.latents,
2142
+ # noise=self.noise).images
2143
+ if negative_indices is not None:
2144
+            assert indices.shape == negative_indices.shape, 'Negative indices must have the same shape as indices'
2145
+ negative_image_embeds = self.model.get_codebook_entry(negative_indices)
2146
+ else:
2147
+ negative_image_embeds = None
2148
+
2149
+ image = diffusion_model(
2150
+ image_embeds=image_embeds,
2151
+ negative_image_embeds=negative_image_embeds,
2152
+ guidance_scale=guidance_scale,
2153
+ noise_level=0,
2154
+ num_inference_steps=num_inference_steps,
2155
+ latents=self.latents,
2156
+ ).images
2157
+ return image
2158
+
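+     # Illustrative usage sketch (the diffusion decoder is an assumption, not defined in this
+     # module): any unCLIP-style pipeline that accepts `image_embeds`, `negative_image_embeds`
+     # and `noise_level` can consume the output of get_codebook_entry, e.g.
+     #     images = tokenizer.decode(diffusion_model, tokens, guidance_scale=10, num_inference_steps=20)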
2159
+ @property
2160
+ def num_image_tokens(self):
2161
+ return 8192 # self.image_tokenizer.num_tokens # allow not load
2162
+
2163
+ def encode_image(
2164
+ self,
2165
+ visual_encoder,
2166
+ image_path=None,
2167
+ image_pil=None,
2168
+ image_torch=None,
2169
+ image_size: int = 224,
2170
+ ):
2171
+        assert (image_path is None) + (image_pil is None) + (image_torch is None) == 2, "provide exactly one of image_path, image_pil or image_torch"
2172
+
2173
+ # need_norm_to_1 = False
2174
+ if image_path is not None:
2175
+ image_pil = Image.open(image_path).convert('RGB')
2176
+
2177
+ if image_pil is not None:
2178
+ image_torch = self.processor(image_pil)
2179
+
2180
+ image_torch = image_torch.to(self.device)
2181
+ return self.encode(visual_encoder, image_torch)
2182
+
2183
+ if __name__ == "__main__":
2184
+ tokenizer = Seed2Tokenizer.from_pretrained("ontocord/seed2")
2185
+ print (tokenizer)
2186
+ tokens = tokenizer.encode_image(tokenizer.visual_encoder, "../dog3.jpg")
2187
+ print (tokens)
2188
+ image_embeds = tokenizer.model.get_codebook_entry(tokens)
2189
+ print (image_embeds)
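+     # With the default configuration, `tokens` is a (1, 32) LongTensor of codebook ids and
+     # `image_embeds` is a (1, 1024) tensor (32 query tokens, image_features_dim=1024).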
2190
+