Rihong committed on
Commit 4e5ef76 · verified · 1 Parent(s): 698c03f

Upload folder using huggingface_hub

Qformer.py ADDED
@@ -0,0 +1,1272 @@
1
+ """
2
+ * Copyright (c) 2023, salesforce.com, inc.
3
+ * All rights reserved.
4
+ * SPDX-License-Identifier: BSD-3-Clause
5
+ * For full license text, see LICENSE.txt file in the repo root or https://opensource.org/licenses/BSD-3-Clause
6
+ * By Junnan Li
7
+ * Based on huggingface code base
8
+ * https://github.com/huggingface/transformers/blob/v4.15.0/src/transformers/models/bert
9
+ """
10
+
11
+ import math
12
+ from typing import Tuple
13
+
14
+ import torch
15
+ from torch import Tensor, device, nn
16
+ import torch.utils.checkpoint
17
+ from torch import nn
18
+ from torch.nn import CrossEntropyLoss
19
+
20
+ # from timm.layers import drop_path
21
+ from transformers.activations import ACT2FN
22
+ from transformers.modeling_outputs import (
23
+ BaseModelOutputWithPastAndCrossAttentions,
24
+ BaseModelOutputWithPoolingAndCrossAttentions,
25
+ CausalLMOutputWithCrossAttentions,
26
+ MaskedLMOutput,
27
+ )
28
+ from transformers.modeling_utils import (
29
+ PreTrainedModel,
30
+ # apply_chunking_to_forward,
31
+ # find_pruneable_heads_and_indices,
32
+ # prune_linear_layer,
33
+ )
34
+ from transformers.pytorch_utils import (
35
+ # PreTrainedModel,
36
+ apply_chunking_to_forward,
37
+ find_pruneable_heads_and_indices,
38
+ prune_linear_layer,
39
+ )
40
+ from transformers.utils import logging
41
+ from transformers.models.bert.configuration_bert import BertConfig
42
+
43
+ from functools import partial
44
+ from .ltm.long_term_attention_gibbs import LongTermAttention
45
+
46
+ logger = logging.get_logger(__name__)
47
+
48
+
49
+ class BertEmbeddings(nn.Module):
50
+ """Construct the embeddings from word and position embeddings."""
51
+
52
+ def __init__(self, config):
53
+ super().__init__()
54
+ self.word_embeddings = nn.Embedding(
55
+ config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id
56
+ )
57
+ self.position_embeddings = nn.Embedding(
58
+ config.max_position_embeddings, config.hidden_size
59
+ )
60
+
61
+ # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load
62
+ # any TensorFlow checkpoint file
63
+ self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
64
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
65
+
66
+ # position_ids (1, len position emb) is contiguous in memory and exported when serialized
67
+ self.register_buffer(
68
+ "position_ids", torch.arange(config.max_position_embeddings).expand((1, -1))
69
+ )
70
+ self.position_embedding_type = getattr(
71
+ config, "position_embedding_type", "absolute"
72
+ )
73
+
74
+ self.config = config
75
+
76
+ def forward(
77
+ self,
78
+ input_ids=None,
79
+ position_ids=None,
80
+ query_embeds=None,
81
+ past_key_values_length=0,
82
+ ):
83
+ if input_ids is not None:
84
+ seq_length = input_ids.size()[1]
85
+ else:
86
+ seq_length = 0
87
+
88
+ if position_ids is None:
89
+ position_ids = self.position_ids[
90
+ :, past_key_values_length : seq_length + past_key_values_length
91
+ ].clone()
92
+
93
+ if input_ids is not None:
94
+ embeddings = self.word_embeddings(input_ids)
95
+ if self.position_embedding_type == "absolute":
96
+ position_embeddings = self.position_embeddings(position_ids)
97
+ embeddings = embeddings + position_embeddings
98
+
99
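+ # Learned query tokens are prepended to the text embeddings; with no input_ids the queries alone form the sequence.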
+ if query_embeds is not None:
100
+ embeddings = torch.cat((query_embeds, embeddings), dim=1)
101
+ else:
102
+ embeddings = query_embeds
103
+
104
+ embeddings = self.LayerNorm(embeddings)
105
+ embeddings = self.dropout(embeddings)
106
+ return embeddings
107
+
108
+
109
+ class BertSelfAttention(nn.Module):
110
+ def __init__(self, config, is_cross_attention):
111
+ super().__init__()
112
+ self.config = config
113
+ if config.hidden_size % config.num_attention_heads != 0 and not hasattr(
114
+ config, "embedding_size"
115
+ ):
116
+ raise ValueError(
117
+ "The hidden size (%d) is not a multiple of the number of attention "
118
+ "heads (%d)" % (config.hidden_size, config.num_attention_heads)
119
+ )
120
+ self.num_attention_heads = config.num_attention_heads
121
+ self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
122
+ self.all_head_size = self.num_attention_heads * self.attention_head_size
123
+ self.is_cross_attention = is_cross_attention
124
+ self.alpha = config.alpha
125
+ self.query = nn.Linear(config.hidden_size, self.all_head_size)
126
+ if is_cross_attention:
127
+ self.key = nn.Linear(config.encoder_width, self.all_head_size)
128
+ self.value = nn.Linear(config.encoder_width, self.all_head_size)
129
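+ # Cross-attention layers additionally build a continuous long-term memory attention that reuses the key/value projections defined above.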
+ long_term_attn_mechanism = partial(LongTermAttention,
130
+ attn_num_basis=config.num_basis,
131
+ head_size=self.attention_head_size,
132
+ length=config.encoder_width,
133
+ target_len=config.encoder_width,
134
+ attn_func="softmax",
135
+ infinite_memory=True,
136
+ n_layers=2,
137
+ attn_drop=0.1,
138
+ n_heads=self.num_attention_heads,
139
+ d_model=self.all_head_size,
140
+ affines=True,
141
+ mask=True,
142
+ mask_type="cnn",
143
+ kl_regularizer=False,
144
+ sigma_0=None,
145
+ mu_0=None,
146
+ sticky_memories=config.sticky,
147
+ continuous=True,
148
+ sigmas=1,
149
+ tau=config.tau,
150
+ proj_key=self.key,
151
+ proj_value=self.value
152
+ )
153
+ self.long_term_attention = long_term_attn_mechanism()
154
+ if not is_cross_attention:
155
+ self.key = nn.Linear(config.hidden_size, self.all_head_size)
156
+ self.value = nn.Linear(config.hidden_size, self.all_head_size)
157
+ self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
158
+ self.position_embedding_type = getattr(
159
+ config, "position_embedding_type", "absolute"
160
+ )
161
+ if (
162
+ self.position_embedding_type == "relative_key"
163
+ or self.position_embedding_type == "relative_key_query"
164
+ ):
165
+ self.max_position_embeddings = config.max_position_embeddings
166
+ self.distance_embedding = nn.Embedding(
167
+ 2 * config.max_position_embeddings - 1, self.attention_head_size
168
+ )
169
+ self.save_attention = False
170
+
171
+ def save_attn_gradients(self, attn_gradients):
172
+ self.attn_gradients = attn_gradients
173
+
174
+ def get_attn_gradients(self):
175
+ return self.attn_gradients
176
+
177
+ def save_attention_map(self, attention_map):
178
+ self.attention_map = attention_map
179
+
180
+ def get_attention_map(self):
181
+ return self.attention_map
182
+
183
+ def transpose_for_scores(self, x):
184
+ new_x_shape = x.size()[:-1] + (
185
+ self.num_attention_heads,
186
+ self.attention_head_size,
187
+ )
188
+ x = x.view(*new_x_shape)
189
+ return x.permute(0, 2, 1, 3)
190
+
191
+ def forward(
192
+ self,
193
+ hidden_states,
194
+ position_embedding_ext,
195
+ layer,
196
+ attention_mask=None,
197
+ head_mask=None,
198
+ encoder_hidden_states=None,
199
+ encoder_attention_mask=None,
200
+ past_key_value=None,
201
+ output_attentions=False,
202
+ new_video=False,
203
+ ):
204
+
205
+ mixed_query_layer = self.query(hidden_states) #[1, 32, 768]
206
+ # If this is instantiated as a cross-attention module, the keys
207
+ # and values come from an encoder; the attention mask needs to be
208
+ # such that the encoder's padding tokens are not attended to.
209
+ is_cross_attention = self.is_cross_attention
210
+ if is_cross_attention:
211
+ bsz, p, h = encoder_hidden_states.shape
212
+ self.long_term_attention.length = p
213
+ self.long_term_attention.target_len = p
214
+ if self.alpha != 1.0:
215
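+ # Query the long-term memory with the current queries; .detach() keeps gradients from flowing back through this path.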
+ a_long_term = self.long_term_attention(encoder_hidden_states, mixed_query_layer, new_doc=new_video, layer_n=layer).detach()
216
+ else:
217
+ a_long_term = 0
218
+ if is_cross_attention:
219
+ key_layer = self.transpose_for_scores(self.key(encoder_hidden_states))
220
+ value_layer = self.transpose_for_scores(self.value(encoder_hidden_states))
221
+ attention_mask = encoder_attention_mask
222
+ elif past_key_value is not None:
223
+ key_layer = self.transpose_for_scores(self.key(hidden_states))
224
+ value_layer = self.transpose_for_scores(self.value(hidden_states))
225
+ key_layer = torch.cat([past_key_value[0], key_layer], dim=2)
226
+ value_layer = torch.cat([past_key_value[1], value_layer], dim=2)
227
+ else:
228
+ key_layer = self.transpose_for_scores(self.key(hidden_states))
229
+ value_layer = self.transpose_for_scores(self.value(hidden_states))
230
+
231
+
232
+ query_layer = self.transpose_for_scores(mixed_query_layer)#[1,12,32,64]
233
+ past_key_value = (key_layer, value_layer)
234
+
235
+ # Take the dot product between "query" and "key" to get the raw attention scores.
236
+ attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
237
+
238
+ if (
239
+ self.position_embedding_type == "relative_key"
240
+ or self.position_embedding_type == "relative_key_query"
241
+ ):
242
+ seq_length = hidden_states.size()[1]
243
+ position_ids_l = torch.arange(
244
+ seq_length, dtype=torch.long, device=hidden_states.device
245
+ ).view(-1, 1)
246
+ position_ids_r = torch.arange(
247
+ seq_length, dtype=torch.long, device=hidden_states.device
248
+ ).view(1, -1)
249
+ distance = position_ids_l - position_ids_r
250
+ positional_embedding = self.distance_embedding(
251
+ distance + self.max_position_embeddings - 1
252
+ )
253
+ positional_embedding = positional_embedding.to(
254
+ dtype=query_layer.dtype
255
+ ) # fp16 compatibility
256
+
257
+ if self.position_embedding_type == "relative_key":
258
+ relative_position_scores = torch.einsum(
259
+ "bhld,lrd->bhlr", query_layer, positional_embedding
260
+ )
261
+ attention_scores = attention_scores + relative_position_scores
262
+ elif self.position_embedding_type == "relative_key_query":
263
+ relative_position_scores_query = torch.einsum(
264
+ "bhld,lrd->bhlr", query_layer, positional_embedding
265
+ )
266
+ relative_position_scores_key = torch.einsum(
267
+ "bhrd,lrd->bhlr", key_layer, positional_embedding
268
+ )
269
+ attention_scores = (
270
+ attention_scores
271
+ + relative_position_scores_query
272
+ + relative_position_scores_key
273
+ )
274
+ attention_scores = attention_scores / math.sqrt(self.attention_head_size)
275
+ if attention_mask is not None:
276
+ # Apply the attention mask is (precomputed for all layers in BertModel forward() function)
277
+ attention_scores = attention_scores + attention_mask
278
+
279
+ # Normalize the attention scores to probabilities.
280
+ attention_probs = nn.Softmax(dim=-1)(attention_scores)
281
+
282
+ if is_cross_attention and self.save_attention:
283
+ self.save_attention_map(attention_probs)
284
+ attention_probs.register_hook(self.save_attn_gradients)
285
+
286
+ # This is actually dropping out entire tokens to attend to, which might
287
+ # seem a bit unusual, but is taken from the original Transformer paper.
288
+ attention_probs_dropped = self.dropout(attention_probs)
289
+ # Mask heads if we want to
290
+ if head_mask is not None:
291
+ attention_probs_dropped = attention_probs_dropped * head_mask
292
+
293
+ context_layer = torch.matmul(attention_probs_dropped, value_layer) #[1, 12, 32, 64]
294
+ context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
295
+ new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
296
+ context_layer = context_layer.view(*new_context_layer_shape)
297
+ if is_cross_attention:
298
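+ # Convex blend of standard cross-attention and the long-term memory readout; alpha = 1.0 disables the memory path.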
+ context_layer = self.alpha*context_layer + (1-self.alpha)*a_long_term
299
+ outputs = (
300
+ (context_layer, attention_probs) if output_attentions else (context_layer,)
301
+ )
302
+
303
+ outputs = outputs + (past_key_value,)
304
+ return outputs
305
+
306
+
307
+ class BertSelfOutput(nn.Module):
308
+ def __init__(self, config):
309
+ super().__init__()
310
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size)
311
+ self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
312
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
313
+
314
+ def forward(self, hidden_states, input_tensor):
315
+ hidden_states = self.dense(hidden_states)
316
+ hidden_states = self.dropout(hidden_states)
317
+ hidden_states = self.LayerNorm(hidden_states + input_tensor)
318
+ return hidden_states
319
+
320
+
321
+ class BertAttention(nn.Module):
322
+ def __init__(self, config, is_cross_attention=False):
323
+ super().__init__()
324
+ self.self = BertSelfAttention(config, is_cross_attention)
325
+ self.output = BertSelfOutput(config)
326
+ self.pruned_heads = set()
327
+
328
+ def prune_heads(self, heads):
329
+ if len(heads) == 0:
330
+ return
331
+ heads, index = find_pruneable_heads_and_indices(
332
+ heads,
333
+ self.self.num_attention_heads,
334
+ self.self.attention_head_size,
335
+ self.pruned_heads,
336
+ )
337
+
338
+ # Prune linear layers
339
+ self.self.query = prune_linear_layer(self.self.query, index)
340
+ self.self.key = prune_linear_layer(self.self.key, index)
341
+ self.self.value = prune_linear_layer(self.self.value, index)
342
+ self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)
343
+
344
+ # Update hyper params and store pruned heads
345
+ self.self.num_attention_heads = self.self.num_attention_heads - len(heads)
346
+ self.self.all_head_size = (
347
+ self.self.attention_head_size * self.self.num_attention_heads
348
+ )
349
+ self.pruned_heads = self.pruned_heads.union(heads)
350
+
351
+ def forward(
352
+ self,
353
+ hidden_states,
354
+ position_embedding_ext,
355
+ layer,
356
+ attention_mask=None,
357
+ head_mask=None,
358
+ encoder_hidden_states=None,
359
+ encoder_attention_mask=None,
360
+ past_key_value=None,
361
+ output_attentions=False,
362
+ new_video=False,
363
+ ):
364
+ self_outputs = self.self(
365
+ hidden_states,
366
+ position_embedding_ext,
367
+ layer,
368
+ attention_mask,
369
+ head_mask,
370
+ encoder_hidden_states,
371
+ encoder_attention_mask,
372
+ past_key_value,
373
+ output_attentions,
374
+ new_video=new_video,
375
+ )
376
+ attention_output = self.output(self_outputs[0], hidden_states)
377
+
378
+ outputs = (attention_output,) + self_outputs[
379
+ 1:
380
+ ] # add attentions if we output them
381
+ return outputs
382
+
383
+
384
+ class BertIntermediate(nn.Module):
385
+ def __init__(self, config):
386
+ super().__init__()
387
+ self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
388
+ if isinstance(config.hidden_act, str):
389
+ self.intermediate_act_fn = ACT2FN[config.hidden_act]
390
+ else:
391
+ self.intermediate_act_fn = config.hidden_act
392
+
393
+ def forward(self, hidden_states):
394
+ hidden_states = self.dense(hidden_states)
395
+ hidden_states = self.intermediate_act_fn(hidden_states)
396
+ return hidden_states
397
+
398
+
399
+ class BertOutput(nn.Module):
400
+ def __init__(self, config):
401
+ super().__init__()
402
+ self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
403
+ self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
404
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
405
+
406
+ def forward(self, hidden_states, input_tensor):
407
+ hidden_states = self.dense(hidden_states)
408
+ hidden_states = self.dropout(hidden_states)
409
+ hidden_states = self.LayerNorm(hidden_states + input_tensor)
410
+ return hidden_states
411
+
412
+
413
+ class BertLayer(nn.Module):
414
+ def __init__(self, config, layer_num):
415
+ super().__init__()
416
+ self.config = config
417
+ self.chunk_size_feed_forward = config.chunk_size_feed_forward
418
+ self.seq_len_dim = 1
419
+ self.attention = BertAttention(config)
420
+ self.layer_num = layer_num
421
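+ # Cross-attention into the encoder features is only inserted every `cross_attention_freq`-th layer.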
+ if (
422
+ self.config.add_cross_attention
423
+ and layer_num % self.config.cross_attention_freq == 0
424
+ ):
425
+ self.crossattention = BertAttention(
426
+ config, is_cross_attention=self.config.add_cross_attention
427
+ )
428
+ self.has_cross_attention = True
429
+ else:
430
+ self.has_cross_attention = False
431
+ self.intermediate = BertIntermediate(config)
432
+ self.output = BertOutput(config)
433
+
434
+ self.intermediate_query = BertIntermediate(config)
435
+ self.output_query = BertOutput(config)
436
+
437
+ def forward(
438
+ self,
439
+ hidden_states,
440
+ position_embedding_ext,
441
+ layer,
442
+ attention_mask=None,
443
+ head_mask=None,
444
+ encoder_hidden_states=None,
445
+ encoder_attention_mask=None,
446
+ past_key_value=None,
447
+ output_attentions=False,
448
+ query_length=0,
449
+ new_video=False,
450
+ ):
451
+ # decoder uni-directional self-attention cached key/values tuple is at positions 1,2
452
+ self_attn_past_key_value = (
453
+ past_key_value[:2] if past_key_value is not None else None
454
+ )
455
+ self_attention_outputs = self.attention(
456
+ hidden_states,
457
+ position_embedding_ext,
458
+ layer,
459
+ attention_mask,
460
+ head_mask,
461
+ output_attentions=output_attentions,
462
+ past_key_value=self_attn_past_key_value,
463
+ new_video=new_video,
464
+ )
465
+ attention_output = self_attention_outputs[0]
466
+ outputs = self_attention_outputs[1:-1]
467
+
468
+ present_key_value = self_attention_outputs[-1]
469
+
470
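+ # The first `query_length` positions are the learned query tokens: only they go through cross-attention and the query-specific FFN; any trailing text tokens use the standard FFN.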
+ if query_length > 0:
471
+ query_attention_output = attention_output[:, :query_length, :]
472
+
473
+ if self.has_cross_attention:
474
+ assert (
475
+ encoder_hidden_states is not None
476
+ ), "encoder_hidden_states must be given for cross-attention layers"
477
+ cross_attention_outputs = self.crossattention(
478
+ query_attention_output,
479
+ position_embedding_ext,
480
+ layer,
481
+ attention_mask,
482
+ head_mask,
483
+ encoder_hidden_states,
484
+ encoder_attention_mask,
485
+ output_attentions=output_attentions,
486
+ new_video=new_video,
487
+ )
488
+ query_attention_output = cross_attention_outputs[0]
489
+ outputs = (
490
+ outputs + cross_attention_outputs[1:-1]
491
+ ) # add cross attentions if we output attention weights
492
+
493
+ layer_output = apply_chunking_to_forward(
494
+ self.feed_forward_chunk_query,
495
+ self.chunk_size_feed_forward,
496
+ self.seq_len_dim,
497
+ query_attention_output,
498
+ )
499
+ if attention_output.shape[1] > query_length:
500
+ layer_output_text = apply_chunking_to_forward(
501
+ self.feed_forward_chunk,
502
+ self.chunk_size_feed_forward,
503
+ self.seq_len_dim,
504
+ attention_output[:, query_length:, :],
505
+ )
506
+ layer_output = torch.cat([layer_output, layer_output_text], dim=1)
507
+ else:
508
+ layer_output = apply_chunking_to_forward(
509
+ self.feed_forward_chunk,
510
+ self.chunk_size_feed_forward,
511
+ self.seq_len_dim,
512
+ attention_output,
513
+ )
514
+ outputs = (layer_output,) + outputs
515
+
516
+ outputs = outputs + (present_key_value,)
517
+
518
+ return outputs
519
+
520
+ def feed_forward_chunk(self, attention_output):
521
+ intermediate_output = self.intermediate(attention_output)
522
+ layer_output = self.output(intermediate_output, attention_output)
523
+ return layer_output
524
+
525
+ def feed_forward_chunk_query(self, attention_output):
526
+ intermediate_output = self.intermediate_query(attention_output)
527
+ layer_output = self.output_query(intermediate_output, attention_output)
528
+ return layer_output
529
+
530
+
531
+ class BertEncoder(nn.Module):
532
+ def __init__(self, config):
533
+ super().__init__()
534
+ self.config = config
535
+ self.layer = nn.ModuleList(
536
+ [BertLayer(config, i) for i in range(config.num_hidden_layers)]
537
+ )
538
+
539
+ def forward(
540
+ self,
541
+ hidden_states,
542
+ position_embedding_ext,
543
+ attention_mask=None,
544
+ head_mask=None,
545
+ encoder_hidden_states=None,
546
+ encoder_attention_mask=None,
547
+ past_key_values=None,
548
+ use_cache=None,
549
+ output_attentions=False,
550
+ output_hidden_states=False,
551
+ return_dict=True,
552
+ query_length=0,
553
+ new_video=False,
554
+ ):
555
+ all_hidden_states = () if output_hidden_states else None
556
+ all_self_attentions = () if output_attentions else None
557
+ all_cross_attentions = (
558
+ () if output_attentions and self.config.add_cross_attention else None
559
+ )
560
+
561
+ next_decoder_cache = () if use_cache else None
562
+
563
+ for i in range(self.config.num_hidden_layers):
564
+ layer_module = self.layer[i]
565
+ if output_hidden_states:
566
+ all_hidden_states = all_hidden_states + (hidden_states,)
567
+
568
+ layer_head_mask = head_mask[i] if head_mask is not None else None
569
+ past_key_value = past_key_values[i] if past_key_values is not None else None
570
+
571
+ if getattr(self.config, "gradient_checkpointing", False) and self.training:
572
+
573
+ if use_cache:
574
+ logger.warn(
575
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
576
+ )
577
+ use_cache = False
578
+
579
+ def create_custom_forward(module):
580
+ def custom_forward(*inputs):
581
+ return module(
582
+ *inputs, past_key_value, output_attentions, query_length
583
+ )
584
+
585
+ return custom_forward
586
+
587
+ layer_outputs = torch.utils.checkpoint.checkpoint(
588
+ create_custom_forward(layer_module),
589
+ hidden_states,
590
+ position_embedding_ext,
591
+ i,
592
+ attention_mask,
593
+ layer_head_mask,
594
+ encoder_hidden_states,
595
+ encoder_attention_mask,
596
+ new_video=new_video
597
+ )
598
+ else:
599
+ layer_outputs = layer_module(
600
+ hidden_states,
601
+ position_embedding_ext,
602
+ i,
603
+ attention_mask,
604
+ layer_head_mask,
605
+ encoder_hidden_states,
606
+ encoder_attention_mask,
607
+ past_key_value,
608
+ output_attentions,
609
+ query_length,
610
+ new_video=new_video,
611
+ )
612
+
613
+ hidden_states = layer_outputs[0]
614
+ if use_cache:
615
+ next_decoder_cache += (layer_outputs[-1],)
616
+ if output_attentions:
617
+ all_self_attentions = all_self_attentions + (layer_outputs[1],)
618
+ all_cross_attentions = all_cross_attentions + (layer_outputs[2],)
619
+
620
+ if output_hidden_states:
621
+ all_hidden_states = all_hidden_states + (hidden_states,)
622
+
623
+ if not return_dict:
624
+ return tuple(
625
+ v
626
+ for v in [
627
+ hidden_states,
628
+ next_decoder_cache,
629
+ all_hidden_states,
630
+ all_self_attentions,
631
+ all_cross_attentions,
632
+ ]
633
+ if v is not None
634
+ )
635
+ return BaseModelOutputWithPastAndCrossAttentions(
636
+ last_hidden_state=hidden_states,
637
+ past_key_values=next_decoder_cache,
638
+ hidden_states=all_hidden_states,
639
+ attentions=all_self_attentions,
640
+ cross_attentions=all_cross_attentions,
641
+ )
642
+
643
+
644
+ class BertPooler(nn.Module):
645
+ def __init__(self, config):
646
+ super().__init__()
647
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size)
648
+ self.activation = nn.Tanh()
649
+
650
+ def forward(self, hidden_states):
651
+ # We "pool" the model by simply taking the hidden state corresponding
652
+ # to the first token.
653
+ first_token_tensor = hidden_states[:, 0]
654
+ pooled_output = self.dense(first_token_tensor)
655
+ pooled_output = self.activation(pooled_output)
656
+ return pooled_output
657
+
658
+
659
+ class BertPredictionHeadTransform(nn.Module):
660
+ def __init__(self, config):
661
+ super().__init__()
662
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size)
663
+ if isinstance(config.hidden_act, str):
664
+ self.transform_act_fn = ACT2FN[config.hidden_act]
665
+ else:
666
+ self.transform_act_fn = config.hidden_act
667
+ self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
668
+
669
+ def forward(self, hidden_states):
670
+ hidden_states = self.dense(hidden_states)
671
+ hidden_states = self.transform_act_fn(hidden_states)
672
+ hidden_states = self.LayerNorm(hidden_states)
673
+ return hidden_states
674
+
675
+
676
+ class BertLMPredictionHead(nn.Module):
677
+ def __init__(self, config):
678
+ super().__init__()
679
+ self.transform = BertPredictionHeadTransform(config)
680
+
681
+ # The output weights are the same as the input embeddings, but there is
682
+ # an output-only bias for each token.
683
+ self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
684
+
685
+ self.bias = nn.Parameter(torch.zeros(config.vocab_size))
686
+
687
+ # Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings`
688
+ self.decoder.bias = self.bias
689
+
690
+ def forward(self, hidden_states):
691
+ hidden_states = self.transform(hidden_states)
692
+ hidden_states = self.decoder(hidden_states)
693
+ return hidden_states
694
+
695
+
696
+ class BertOnlyMLMHead(nn.Module):
697
+ def __init__(self, config):
698
+ super().__init__()
699
+ self.predictions = BertLMPredictionHead(config)
700
+
701
+ def forward(self, sequence_output):
702
+ prediction_scores = self.predictions(sequence_output)
703
+ return prediction_scores
704
+
705
+
706
+ class BertPreTrainedModel(PreTrainedModel):
707
+ """
708
+ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
709
+ models.
710
+ """
711
+
712
+ config_class = BertConfig
713
+ base_model_prefix = "bert"
714
+ _keys_to_ignore_on_load_missing = [r"position_ids"]
715
+
716
+ def _init_weights(self, module):
717
+ """Initialize the weights"""
718
+ if isinstance(module, (nn.Linear, nn.Embedding)):
719
+ # Slightly different from the TF version which uses truncated_normal for initialization
720
+ # cf https://github.com/pytorch/pytorch/pull/5617
721
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
722
+ elif isinstance(module, nn.LayerNorm):
723
+ module.bias.data.zero_()
724
+ module.weight.data.fill_(1.0)
725
+ if isinstance(module, nn.Linear) and module.bias is not None:
726
+ module.bias.data.zero_()
727
+
728
+
729
+ class BertModel(BertPreTrainedModel):
730
+ """
731
+ The model can behave as an encoder (with only self-attention) as well as a decoder, in which case a layer of
732
+ cross-attention is added between the self-attention layers, following the architecture described in `Attention is
733
+ all you need <https://arxiv.org/abs/1706.03762>`__ by Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit,
734
+ Llion Jones, Aidan N. Gomez, Lukasz Kaiser and Illia Polosukhin.
735
+ To be used in a Seq2Seq model, the model needs to be initialized with both the :obj:`is_decoder` argument and :obj:`add_cross_attention` set to :obj:`True`; an :obj:`encoder_hidden_states` is then expected as an
736
+ input to the forward pass.
737
+ """
738
+
739
+ def __init__(self, config, add_pooling_layer=False):
740
+ super().__init__(config)
741
+ self.config = config
742
+
743
+ self.embeddings = BertEmbeddings(config)
744
+
745
+ self.encoder = BertEncoder(config)
746
+
747
+ self.pooler = BertPooler(config) if add_pooling_layer else None
748
+
749
+ self.init_weights()
750
+
751
+ def get_input_embeddings(self):
752
+ return self.embeddings.word_embeddings
753
+
754
+ def set_input_embeddings(self, value):
755
+ self.embeddings.word_embeddings = value
756
+
757
+ def _prune_heads(self, heads_to_prune):
758
+ """
759
+ Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
760
+ class PreTrainedModel
761
+ """
762
+ for layer, heads in heads_to_prune.items():
763
+ self.encoder.layer[layer].attention.prune_heads(heads)
764
+
765
+ def get_extended_attention_mask(
766
+ self,
767
+ attention_mask: Tensor,
768
+ input_shape: Tuple[int],
769
+ device: device,
770
+ is_decoder: bool,
771
+ has_query: bool = False,
772
+ ) -> Tensor:
773
+ """
774
+ Makes broadcastable attention and causal masks so that future and masked tokens are ignored.
775
+
776
+ Arguments:
777
+ attention_mask (:obj:`torch.Tensor`):
778
+ Mask with ones indicating tokens to attend to, zeros for tokens to ignore.
779
+ input_shape (:obj:`Tuple[int]`):
780
+ The shape of the input to the model.
781
+ device: (:obj:`torch.device`):
782
+ The device of the input to the model.
783
+
784
+ Returns:
785
+ :obj:`torch.Tensor` The extended attention mask, with the same dtype as :obj:`attention_mask.dtype`.
786
+ """
787
+ # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
788
+ # ourselves in which case we just need to make it broadcastable to all heads.
789
+ if attention_mask.dim() == 3:
790
+ extended_attention_mask = attention_mask[:, None, :, :]
791
+ elif attention_mask.dim() == 2:
792
+ # Provided a padding mask of dimensions [batch_size, seq_length]
793
+ # - if the model is a decoder, apply a causal mask in addition to the padding mask
794
+ # - if the model is an encoder, make the mask broadcastable to [batch_size, num_heads, seq_length, seq_length]
795
+ if is_decoder:
796
+ batch_size, seq_length = input_shape
797
+
798
+ seq_ids = torch.arange(seq_length, device=device)
799
+ causal_mask = (
800
+ seq_ids[None, None, :].repeat(batch_size, seq_length, 1)
801
+ <= seq_ids[None, :, None]
802
+ )
803
+
804
+ # add a prefix ones mask to the causal mask
805
+ # causal and attention masks must have same type with pytorch version < 1.3
806
+ causal_mask = causal_mask.to(attention_mask.dtype)
807
+
808
+ if causal_mask.shape[1] < attention_mask.shape[1]:
809
+ prefix_seq_len = attention_mask.shape[1] - causal_mask.shape[1]
810
+ if has_query: # UniLM style attention mask
811
+ causal_mask = torch.cat(
812
+ [
813
+ torch.zeros(
814
+ (batch_size, prefix_seq_len, seq_length),
815
+ device=device,
816
+ dtype=causal_mask.dtype,
817
+ ),
818
+ causal_mask,
819
+ ],
820
+ axis=1,
821
+ )
822
+ causal_mask = torch.cat(
823
+ [
824
+ torch.ones(
825
+ (batch_size, causal_mask.shape[1], prefix_seq_len),
826
+ device=device,
827
+ dtype=causal_mask.dtype,
828
+ ),
829
+ causal_mask,
830
+ ],
831
+ axis=-1,
832
+ )
833
+ extended_attention_mask = (
834
+ causal_mask[:, None, :, :] * attention_mask[:, None, None, :]
835
+ )
836
+ else:
837
+ extended_attention_mask = attention_mask[:, None, None, :]
838
+ else:
839
+ raise ValueError(
840
+ "Wrong shape for input_ids (shape {}) or attention_mask (shape {})".format(
841
+ input_shape, attention_mask.shape
842
+ )
843
+ )
844
+
845
+ # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
846
+ # masked positions, this operation will create a tensor which is 0.0 for
847
+ # positions we want to attend and -10000.0 for masked positions.
848
+ # Since we are adding it to the raw scores before the softmax, this is
849
+ # effectively the same as removing these entirely.
850
+ extended_attention_mask = extended_attention_mask.to(
851
+ dtype=self.dtype
852
+ ) # fp16 compatibility
853
+ extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
854
+ return extended_attention_mask
855
+
856
+ def forward(
857
+ self,
858
+ input_ids=None,
859
+ position_embedding_ext=None,
860
+ attention_mask=None,
861
+ position_ids=None,
862
+ head_mask=None,
863
+ query_embeds=None,
864
+ encoder_hidden_states=None,
865
+ encoder_attention_mask=None,
866
+ past_key_values=None,
867
+ use_cache=None,
868
+ output_attentions=None,
869
+ output_hidden_states=None,
870
+ return_dict=None,
871
+ is_decoder=False,
872
+ new_video=False,
873
+ ):
874
+ r"""
875
+ encoder_hidden_states (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
876
+ Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if
877
+ the model is configured as a decoder.
878
+ encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
879
+ Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in
880
+ the cross-attention if the model is configured as a decoder. Mask values selected in ``[0, 1]``:
881
+ - 1 for tokens that are **not masked**,
882
+ - 0 for tokens that are **masked**.
883
+ past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
884
+ Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.
885
+ If :obj:`past_key_values` are used, the user can optionally input only the last :obj:`decoder_input_ids`
886
+ (those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)`
887
+ instead of all :obj:`decoder_input_ids` of shape :obj:`(batch_size, sequence_length)`.
888
+ use_cache (:obj:`bool`, `optional`):
889
+ If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up
890
+ decoding (see :obj:`past_key_values`).
891
+ """
892
+ output_attentions = (
893
+ output_attentions
894
+ if output_attentions is not None
895
+ else self.config.output_attentions
896
+ )
897
+ output_hidden_states = (
898
+ output_hidden_states
899
+ if output_hidden_states is not None
900
+ else self.config.output_hidden_states
901
+ )
902
+ return_dict = (
903
+ return_dict if return_dict is not None else self.config.use_return_dict
904
+ )
905
+
906
+ # use_cache = use_cache if use_cache is not None else self.config.use_cache
907
+
908
+ if input_ids is None:
909
+ assert (
910
+ query_embeds is not None
911
+ ), "You have to specify query_embeds when input_ids is None"
912
+
913
+ # past_key_values_length
914
+ past_key_values_length = (
915
+ past_key_values[0][0].shape[2] - self.config.query_length
916
+ if past_key_values is not None
917
+ else 0
918
+ )
919
+
920
+ query_length = query_embeds.shape[1] if query_embeds is not None else 0
921
+
922
+ embedding_output = self.embeddings(
923
+ input_ids=input_ids,
924
+ position_ids=position_ids,
925
+ query_embeds=query_embeds,
926
+ past_key_values_length=past_key_values_length,
927
+ )
928
+
929
+ input_shape = embedding_output.size()[:-1]
930
+ batch_size, seq_length = input_shape
931
+ device = embedding_output.device
932
+
933
+ if attention_mask is None:
934
+ attention_mask = torch.ones(
935
+ ((batch_size, seq_length + past_key_values_length)), device=device
936
+ )
937
+
938
+ # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
939
+ # ourselves in which case we just need to make it broadcastable to all heads.
940
+ if is_decoder:
941
+ extended_attention_mask = self.get_extended_attention_mask(
942
+ attention_mask,
943
+ input_ids.shape,
944
+ device,
945
+ is_decoder,
946
+ has_query=(query_embeds is not None),
947
+ )
948
+ else:
949
+ extended_attention_mask = self.get_extended_attention_mask(
950
+ attention_mask, input_shape, device, is_decoder
951
+ )
952
+
953
+ # If a 2D or 3D attention mask is provided for the cross-attention
954
+ # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
955
+ if encoder_hidden_states is not None:
956
+ if type(encoder_hidden_states) == list:
957
+ encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states[
958
+ 0
959
+ ].size()
960
+ else:
961
+ (
962
+ encoder_batch_size,
963
+ encoder_sequence_length,
964
+ _,
965
+ ) = encoder_hidden_states.size()
966
+ encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
967
+
968
+ if type(encoder_attention_mask) == list:
969
+ encoder_extended_attention_mask = [
970
+ self.invert_attention_mask(mask) for mask in encoder_attention_mask
971
+ ]
972
+ elif encoder_attention_mask is None:
973
+ encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device)
974
+ encoder_extended_attention_mask = self.invert_attention_mask(
975
+ encoder_attention_mask
976
+ )
977
+ else:
978
+ encoder_extended_attention_mask = self.invert_attention_mask(
979
+ encoder_attention_mask
980
+ )
981
+ else:
982
+ encoder_extended_attention_mask = None
983
+
984
+ # Prepare head mask if needed
985
+ # 1.0 in head_mask indicate we keep the head
986
+ # attention_probs has shape bsz x n_heads x N x N
987
+ # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
988
+ # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
989
+ head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
990
+
991
+ encoder_outputs = self.encoder(
992
+ embedding_output,
993
+ position_embedding_ext,
994
+ attention_mask=extended_attention_mask,
995
+ head_mask=head_mask,
996
+ encoder_hidden_states=encoder_hidden_states,
997
+ encoder_attention_mask=encoder_extended_attention_mask,
998
+ past_key_values=past_key_values,
999
+ use_cache=use_cache,
1000
+ output_attentions=output_attentions,
1001
+ output_hidden_states=output_hidden_states,
1002
+ return_dict=return_dict,
1003
+ query_length=query_length,
1004
+ new_video=new_video
1005
+ )
1006
+ sequence_output = encoder_outputs[0]
1007
+ pooled_output = (
1008
+ self.pooler(sequence_output) if self.pooler is not None else None
1009
+ )
1010
+
1011
+ if not return_dict:
1012
+ return (sequence_output, pooled_output) + encoder_outputs[1:]
1013
+
1014
+ return BaseModelOutputWithPoolingAndCrossAttentions(
1015
+ last_hidden_state=sequence_output,
1016
+ pooler_output=pooled_output,
1017
+ past_key_values=encoder_outputs.past_key_values,
1018
+ hidden_states=encoder_outputs.hidden_states,
1019
+ attentions=encoder_outputs.attentions,
1020
+ cross_attentions=encoder_outputs.cross_attentions,
1021
+ )
1022
+
1023
+
1024
+ class BertLMHeadModel(BertPreTrainedModel):
1025
+
1026
+ _keys_to_ignore_on_load_unexpected = [r"pooler"]
1027
+ _keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"]
1028
+
1029
+ def __init__(self, config):
1030
+ super().__init__(config)
1031
+
1032
+ self.bert = BertModel(config, add_pooling_layer=False)
1033
+ self.cls = BertOnlyMLMHead(config)
1034
+
1035
+ self.init_weights()
1036
+
1037
+ def get_output_embeddings(self):
1038
+ return self.cls.predictions.decoder
1039
+
1040
+ def set_output_embeddings(self, new_embeddings):
1041
+ self.cls.predictions.decoder = new_embeddings
1042
+
1043
+ def forward(
1044
+ self,
1045
+ input_ids=None,
1046
+ attention_mask=None,
1047
+ position_ids=None,
1048
+ head_mask=None,
1049
+ query_embeds=None,
1050
+ encoder_hidden_states=None,
1051
+ encoder_attention_mask=None,
1052
+ labels=None,
1053
+ past_key_values=None,
1054
+ use_cache=True,
1055
+ output_attentions=None,
1056
+ output_hidden_states=None,
1057
+ return_dict=None,
1058
+ return_logits=False,
1059
+ is_decoder=True,
1060
+ reduction="mean",
1061
+ ):
1062
+ r"""
1063
+ encoder_hidden_states (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
1064
+ Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if
1065
+ the model is configured as a decoder.
1066
+ encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
1067
+ Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in
1068
+ the cross-attention if the model is configured as a decoder. Mask values selected in ``[0, 1]``:
1069
+ - 1 for tokens that are **not masked**,
1070
+ - 0 for tokens that are **masked**.
1071
+ labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
1072
+ Labels for computing the left-to-right language modeling loss (next word prediction). Indices should be in
1073
+ ``[-100, 0, ..., config.vocab_size]`` (see ``input_ids`` docstring) Tokens with indices set to ``-100`` are
1074
+ ignored (masked), the loss is only computed for the tokens with labels in ``[0, ..., config.vocab_size]``
1075
+ past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
1076
+ Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.
1077
+ If :obj:`past_key_values` are used, the user can optionally input only the last :obj:`decoder_input_ids`
1078
+ (those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)`
1079
+ instead of all :obj:`decoder_input_ids` of shape :obj:`(batch_size, sequence_length)`.
1080
+ use_cache (:obj:`bool`, `optional`):
1081
+ If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up
1082
+ decoding (see :obj:`past_key_values`).
1083
+ Returns:
1084
+ Example::
1085
+ >>> from transformers import BertTokenizer, BertLMHeadModel, BertConfig
1086
+ >>> import torch
1087
+ >>> tokenizer = BertTokenizer.from_pretrained('bert-base-cased')
1088
+ >>> config = BertConfig.from_pretrained("bert-base-cased")
1089
+ >>> model = BertLMHeadModel.from_pretrained('bert-base-cased', config=config)
1090
+ >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
1091
+ >>> outputs = model(**inputs)
1092
+ >>> prediction_logits = outputs.logits
1093
+ """
1094
+ return_dict = (
1095
+ return_dict if return_dict is not None else self.config.use_return_dict
1096
+ )
1097
+ if labels is not None:
1098
+ use_cache = False
1099
+ if past_key_values is not None:
1100
+ query_embeds = None
1101
+
1102
+ outputs = self.bert(
1103
+ input_ids,
1104
+ attention_mask=attention_mask,
1105
+ position_ids=position_ids,
1106
+ head_mask=head_mask,
1107
+ query_embeds=query_embeds,
1108
+ encoder_hidden_states=encoder_hidden_states,
1109
+ encoder_attention_mask=encoder_attention_mask,
1110
+ past_key_values=past_key_values,
1111
+ use_cache=use_cache,
1112
+ output_attentions=output_attentions,
1113
+ output_hidden_states=output_hidden_states,
1114
+ return_dict=return_dict,
1115
+ is_decoder=is_decoder,
1116
+ )
1117
+
1118
+ sequence_output = outputs[0]
1119
+ if query_embeds is not None:
1120
+ sequence_output = outputs[0][:, query_embeds.shape[1] :, :]
1121
+
1122
+ prediction_scores = self.cls(sequence_output)
1123
+
1124
+ if return_logits:
1125
+ return prediction_scores[:, :-1, :].contiguous()
1126
+
1127
+ lm_loss = None
1128
+ if labels is not None:
1129
+ # we are doing next-token prediction; shift prediction scores and input ids by one
1130
+ shifted_prediction_scores = prediction_scores[:, :-1, :].contiguous()
1131
+ labels = labels[:, 1:].contiguous()
1132
+ loss_fct = CrossEntropyLoss(reduction=reduction, label_smoothing=0.1)
1133
+ lm_loss = loss_fct(
1134
+ shifted_prediction_scores.view(-1, self.config.vocab_size),
1135
+ labels.view(-1),
1136
+ )
1137
+ if reduction == "none":
1138
+ lm_loss = lm_loss.view(prediction_scores.size(0), -1).sum(1)
1139
+
1140
+ if not return_dict:
1141
+ output = (prediction_scores,) + outputs[2:]
1142
+ return ((lm_loss,) + output) if lm_loss is not None else output
1143
+
1144
+ return CausalLMOutputWithCrossAttentions(
1145
+ loss=lm_loss,
1146
+ logits=prediction_scores,
1147
+ past_key_values=outputs.past_key_values,
1148
+ hidden_states=outputs.hidden_states,
1149
+ attentions=outputs.attentions,
1150
+ cross_attentions=outputs.cross_attentions,
1151
+ )
1152
+
1153
+ def prepare_inputs_for_generation(
1154
+ self, input_ids, query_embeds, past=None, attention_mask=None, **model_kwargs
1155
+ ):
1156
+ # if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly
1157
+ if attention_mask is None:
1158
+ attention_mask = input_ids.new_ones(input_ids.shape)
1159
+ query_mask = input_ids.new_ones(query_embeds.shape[:-1])
1160
+ attention_mask = torch.cat([query_mask, attention_mask], dim=-1)
1161
+
1162
+ # cut decoder_input_ids if past is used
1163
+ if past is not None:
1164
+ input_ids = input_ids[:, -1:]
1165
+
1166
+ return {
1167
+ "input_ids": input_ids,
1168
+ "query_embeds": query_embeds,
1169
+ "attention_mask": attention_mask,
1170
+ "past_key_values": past,
1171
+ "encoder_hidden_states": model_kwargs.get("encoder_hidden_states", None),
1172
+ "encoder_attention_mask": model_kwargs.get("encoder_attention_mask", None),
1173
+ "is_decoder": True,
1174
+ }
1175
+
1176
+ def _reorder_cache(self, past, beam_idx):
1177
+ reordered_past = ()
1178
+ for layer_past in past:
1179
+ reordered_past += (
1180
+ tuple(
1181
+ past_state.index_select(0, beam_idx) for past_state in layer_past
1182
+ ),
1183
+ )
1184
+ return reordered_past
1185
+
1186
+
1187
+ class BertForMaskedLM(BertPreTrainedModel):
1188
+
1189
+ _keys_to_ignore_on_load_unexpected = [r"pooler"]
1190
+ _keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"]
1191
+
1192
+ def __init__(self, config):
1193
+ super().__init__(config)
1194
+
1195
+ self.bert = BertModel(config, add_pooling_layer=False)
1196
+ self.cls = BertOnlyMLMHead(config)
1197
+
1198
+ self.init_weights()
1199
+
1200
+ def get_output_embeddings(self):
1201
+ return self.cls.predictions.decoder
1202
+
1203
+ def set_output_embeddings(self, new_embeddings):
1204
+ self.cls.predictions.decoder = new_embeddings
1205
+
1206
+ def forward(
1207
+ self,
1208
+ input_ids=None,
1209
+ attention_mask=None,
1210
+ position_ids=None,
1211
+ head_mask=None,
1212
+ query_embeds=None,
1213
+ encoder_hidden_states=None,
1214
+ encoder_attention_mask=None,
1215
+ labels=None,
1216
+ output_attentions=None,
1217
+ output_hidden_states=None,
1218
+ return_dict=None,
1219
+ return_logits=False,
1220
+ is_decoder=False,
1221
+ ):
1222
+ r"""
1223
+ labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
1224
+ Labels for computing the masked language modeling loss. Indices should be in ``[-100, 0, ...,
1225
+ config.vocab_size]`` (see ``input_ids`` docstring) Tokens with indices set to ``-100`` are ignored
1226
+ (masked), the loss is only computed for the tokens with labels in ``[0, ..., config.vocab_size]``
1227
+ """
1228
+
1229
+ return_dict = (
1230
+ return_dict if return_dict is not None else self.config.use_return_dict
1231
+ )
1232
+
1233
+ outputs = self.bert(
1234
+ input_ids,
1235
+ attention_mask=attention_mask,
1236
+ position_ids=position_ids,
1237
+ head_mask=head_mask,
1238
+ query_embeds=query_embeds,
1239
+ encoder_hidden_states=encoder_hidden_states,
1240
+ encoder_attention_mask=encoder_attention_mask,
1241
+ output_attentions=output_attentions,
1242
+ output_hidden_states=output_hidden_states,
1243
+ return_dict=return_dict,
1244
+ is_decoder=is_decoder,
1245
+ )
1246
+
1247
+ if query_embeds is not None:
1248
+ sequence_output = outputs[0][:, query_embeds.shape[1] :, :]
1249
+ prediction_scores = self.cls(sequence_output)
1250
+
1251
+ if return_logits:
1252
+ return prediction_scores
1253
+
1254
+ masked_lm_loss = None
1255
+ if labels is not None:
1256
+ loss_fct = CrossEntropyLoss() # -100 index = padding token
1257
+ masked_lm_loss = loss_fct(
1258
+ prediction_scores.view(-1, self.config.vocab_size), labels.view(-1)
1259
+ )
1260
+
1261
+ if not return_dict:
1262
+ output = (prediction_scores,) + outputs[2:]
1263
+ return (
1264
+ ((masked_lm_loss,) + output) if masked_lm_loss is not None else output
1265
+ )
1266
+
1267
+ return MaskedLMOutput(
1268
+ loss=masked_lm_loss,
1269
+ logits=prediction_scores,
1270
+ hidden_states=outputs.hidden_states,
1271
+ attentions=outputs.attentions,
1272
+ )
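
The Q-Former above is a BERT encoder/decoder extended with cross-attention into a visual encoder plus an auxiliary long-term-memory attention path, so its config needs several attributes that a stock BertConfig does not carry (encoder_width, cross_attention_freq, query_length, alpha, num_basis, sticky, tau). A minimal wiring sketch; the numeric values are illustrative assumptions, not values taken from this repo:

import torch
from transformers import BertConfig
# BertLMHeadModel is the class defined in Qformer.py above

cfg = BertConfig.from_pretrained("bert-base-uncased")
cfg.encoder_width = 1024        # width of the visual encoder features
cfg.add_cross_attention = True
cfg.cross_attention_freq = 2    # cross-attention every other layer
cfg.query_length = 32           # number of learned query tokens
cfg.alpha = 0.9                 # weight of standard cross-attention vs. long-term memory
cfg.num_basis = 64              # basis functions of the continuous memory (assumed)
cfg.sticky = False              # sticky memories off (assumed)
cfg.tau = 0.5                   # memory temperature (assumed)

qformer = BertLMHeadModel(config=cfg)
query_tokens = torch.zeros(1, cfg.query_length, cfg.hidden_size)
visual_feats = torch.randn(1, 256, cfg.encoder_width)
out = qformer.bert(query_embeds=query_tokens, encoder_hidden_states=visual_feats, return_dict=True)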
__init__.py ADDED
File without changes
blip2.py ADDED
@@ -0,0 +1,105 @@
1
+ """
2
+ Copyright (c) 2023, salesforce.com, inc.
3
+ All rights reserved.
4
+ SPDX-License-Identifier: BSD-3-Clause
5
+ For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause
6
+ """
7
+ import contextlib
8
+ import os
9
+ import logging
10
+
11
+ import torch
12
+ import torch.nn as nn
13
+
14
+ from .Qformer import BertConfig, BertLMHeadModel
15
+ from .vit import build_vit
16
+ from transformers import BertTokenizer
17
+
18
+ logger = logging.getLogger(__name__)
19
+
20
+ from transformers import PreTrainedModel, PretrainedConfig, AutoConfig
21
+ # class Blip2Base(nn.Module):
22
+ class Blip2Base(PreTrainedModel):
23
+ def __init__(self, config={}):
24
+ cfg = PretrainedConfig()
25
+ if isinstance(config, (PretrainedConfig, AutoConfig)):
26
+ cfg.update(config.to_dict())
27
+ else:
28
+ cfg.update(dict(config))
29
+ super().__init__(cfg)
30
+
31
+ @classmethod
32
+ def init_tokenizer(cls, truncation_side="right"):
33
+ tokenizer = BertTokenizer.from_pretrained("bert-base-uncased", truncation_side=truncation_side, local_files_only=True)
34
+ tokenizer.add_special_tokens({"bos_token": "[DEC]"})
35
+ return tokenizer
36
+
37
+ @property
38
+ def device(self):
39
+ return list(self.parameters())[0].device
40
+
41
+ def maybe_autocast(self, dtype=torch.float16):
42
+ # if on cpu, don't use autocast
43
+ # if on gpu, use autocast with dtype if provided, otherwise use torch.float16
44
+ enable_autocast = self.device != torch.device("cpu")
45
+
46
+ if enable_autocast:
47
+ return torch.cuda.amp.autocast(dtype=dtype)
48
+ else:
49
+ return contextlib.nullcontext()
50
+
51
+ @classmethod
52
+ def init_Qformer(
53
+ cls,
54
+ num_query_token, vision_width,
55
+ qformer_hidden_dropout_prob=0.1,
56
+ qformer_attention_probs_dropout_prob=0.1,
57
+ qformer_drop_path_rate=0.,
58
+ ):
59
+ encoder_config = BertConfig.from_pretrained("bert-base-uncased", local_files_only=True)
60
+ encoder_config.encoder_width = vision_width
61
+ # insert cross-attention layer every other block
62
+ encoder_config.add_cross_attention = True
63
+ encoder_config.cross_attention_freq = 2
64
+ encoder_config.query_length = num_query_token
65
+ encoder_config.hidden_dropout_prob = qformer_hidden_dropout_prob
66
+ encoder_config.attention_probs_dropout_prob = qformer_attention_probs_dropout_prob
67
+ encoder_config.drop_path_list = [x.item() for x in torch.linspace(0, qformer_drop_path_rate, encoder_config.num_hidden_layers)]
68
+ logger.info(f"Drop_path:{encoder_config.drop_path_list}")
69
+ logger.info(encoder_config)
70
+ Qformer = BertLMHeadModel(config=encoder_config)
71
+ query_tokens = nn.Parameter(
72
+ torch.zeros(1, num_query_token, encoder_config.hidden_size)
73
+ )
74
+ query_tokens.data.normal_(mean=0.0, std=encoder_config.initializer_range)
75
+ return Qformer, query_tokens
76
+
77
+ @classmethod
78
+ def init_vision_encoder_umt(cls, config):
79
+ """build vision encoder
80
+ Returns: (vision_encoder, vision_layernorm). Each is a `nn.Module`.
81
+
82
+ """
83
+ vision_encoder = build_vit(config)
84
+
85
+ if config.vision_encoder.vit_add_ln:
86
+ vision_layernorm = nn.LayerNorm(config.vision_encoder.encoder_embed_dim, eps=1e-12)
87
+ else:
88
+ vision_layernorm = nn.Identity()
89
+
90
+ return vision_encoder, vision_layernorm
91
+
92
+
93
+ def disabled_train(self, mode=True):
94
+ """Overwrite model.train with this function to make sure train/eval mode
95
+ does not change anymore."""
96
+ return self
97
+
98
+
99
+ class LayerNorm(nn.LayerNorm):
100
+ """Subclass torch's LayerNorm to handle fp16."""
101
+
102
+ def forward(self, x: torch.Tensor):
103
+ orig_type = x.dtype
104
+ ret = super().forward(x.type(torch.float32))
105
+ return ret.type(orig_type)
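
A short sketch of how the two class helpers above are typically combined; the query count and vision width mirror the config.json that follows rather than anything hard-coded in this file:

tokenizer = Blip2Base.init_tokenizer()   # BERT tokenizer with "[DEC]" added as bos_token
qformer, query_tokens = Blip2Base.init_Qformer(
    num_query_token=32,                  # config.json: num_query_token
    vision_width=1024,                   # config.json: vision_encoder.encoder_embed_dim
    qformer_hidden_dropout_prob=0.1,
    qformer_attention_probs_dropout_prob=0.1,
    qformer_drop_path_rate=0.2,
)
# query_tokens is an nn.Parameter of shape (1, num_query_token, hidden_size),
# expanded along the batch dimension at run time.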
config.json ADDED
@@ -0,0 +1,62 @@
1
+ {
2
+ "add_second_msg": true,
3
+ "architectures": [
4
+ "VideoChat2_it_hd_mistral"
5
+ ],
6
+ "auto_map": {
7
+ "AutoConfig": "configuration_videochat2.Config",
8
+ "AutoModel": "videochat2_it_hd_mistral.VideoChat2_it_hd_mistral"
9
+ },
10
+ "dynamic_config": {
11
+ "add_global": true,
12
+ "hd_num": 6,
13
+ "local_size": 224,
14
+ "padding": false
15
+ },
16
+ "end_token": "</Video>",
17
+ "extra_num_query_token": 64,
18
+ "freeze_qformer": false,
19
+ "freeze_vit": false,
20
+ "img_end_token": "</Image>",
21
+ "img_start_token": "<Image>",
22
+ "lora_alpha": 32,
23
+ "lora_dropout": 0.1,
24
+ "lora_r": 16,
25
+ "low_resource": false,
26
+ "max_txt_len": 512,
27
+ "mistral_model_path": "mistralai/Mistral-7B-Instruct-v0.2",
28
+ "model_cls": "VideoChat2_it_hd_mistral",
29
+ "num_query_token": 32,
30
+ "qformer_attention_probs_dropout_prob": 0.1,
31
+ "qformer_drop_path_rate": 0.2,
32
+ "qformer_hidden_dropout_prob": 0.1,
33
+ "qformer_text_input": true,
34
+ "random_shuffle": true,
35
+ "return_question_instruction": false,
36
+ "start_token": "<Video>",
37
+ "system": "",
38
+ "torch_dtype": "float32",
39
+ "transformers_version": "4.44.2",
40
+ "use_flash_attention": false,
41
+ "use_lora": false,
42
+ "videochat2_model_path": "",
43
+ "vision_encoder": {
44
+ "checkpoint_num": 18,
45
+ "ckpt_num_frame": 4,
46
+ "d_model": 1024,
47
+ "drop_path_rate": 0.0,
48
+ "encoder_depth": 24,
49
+ "encoder_embed_dim": 1024,
50
+ "encoder_num_heads": 16,
51
+ "img_size": 224,
52
+ "name": "vit_l14",
53
+ "num_frames": 4,
54
+ "patch_size": 16,
55
+ "pretrained": "",
56
+ "return_index": -2,
57
+ "tubelet_size": 1,
58
+ "use_checkpoint": true,
59
+ "vit_add_ln": true
60
+ },
61
+ "vit_blip_model_path": ""
62
+ }
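
Because auto_map routes AutoConfig and AutoModel to the bundled configuration_videochat2.Config and videochat2_it_hd_mistral.VideoChat2_it_hd_mistral, the checkpoint loads through transformers' remote-code path. A minimal sketch, with the repo id left as a placeholder:

from transformers import AutoConfig, AutoModel

config = AutoConfig.from_pretrained("<repo-id-or-local-path>", trust_remote_code=True)
model = AutoModel.from_pretrained("<repo-id-or-local-path>", config=config, trust_remote_code=True)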
configuration_videochat2.py ADDED
@@ -0,0 +1,453 @@
1
+ from __future__ import annotations
2
+
3
+ import argparse
4
+ import ast
5
+ import json
6
+ import os
7
+ import os.path as osp
8
+ import re
9
+ import shutil
10
+ import sys
11
+ import tempfile
12
+ from copy import deepcopy
13
+ from importlib import import_module
14
+
15
+ import yaml
16
+
17
+
18
+ __all__ = ["Config", "pretty_text"]
19
+
20
+
21
+ BASE_KEY = "_base_"
22
+ # BASE_CONFIG = {"OUTPUT_DIR": "./workspace", "SESSION": "base", "LOG_FILE": "log.txt"}
23
+ BASE_CONFIG = {}
24
+
25
+ cfg = None
26
+
27
+ class EasyDict(dict):
28
+ """
29
+ Get attributes
30
+
31
+ >>> d = EasyDict({'foo':3})
32
+ >>> d['foo']
33
+ 3
34
+ >>> d.foo
35
+ 3
36
+ >>> d.bar
37
+ Traceback (most recent call last):
38
+ ...
39
+ AttributeError: 'EasyDict' object has no attribute 'bar'
40
+
41
+ Works recursively
42
+
43
+ >>> d = EasyDict({'foo':3, 'bar':{'x':1, 'y':2}})
44
+ >>> isinstance(d.bar, dict)
45
+ True
46
+ >>> d.bar.x
47
+ 1
48
+
49
+ Bullet-proof
50
+
51
+ >>> EasyDict({})
52
+ {}
53
+ >>> EasyDict(d={})
54
+ {}
55
+ >>> EasyDict(None)
56
+ {}
57
+ >>> d = {'a': 1}
58
+ >>> EasyDict(**d)
59
+ {'a': 1}
60
+
61
+ Set attributes
62
+
63
+ >>> d = EasyDict()
64
+ >>> d.foo = 3
65
+ >>> d.foo
66
+ 3
67
+ >>> d.bar = {'prop': 'value'}
68
+ >>> d.bar.prop
69
+ 'value'
70
+ >>> d
71
+ {'foo': 3, 'bar': {'prop': 'value'}}
72
+ >>> d.bar.prop = 'newer'
73
+ >>> d.bar.prop
74
+ 'newer'
75
+
76
+
77
+ Values extraction
78
+
79
+ >>> d = EasyDict({'foo':0, 'bar':[{'x':1, 'y':2}, {'x':3, 'y':4}]})
80
+ >>> isinstance(d.bar, list)
81
+ True
82
+ >>> from operator import attrgetter
83
+ >>> map(attrgetter('x'), d.bar)
84
+ [1, 3]
85
+ >>> map(attrgetter('y'), d.bar)
86
+ [2, 4]
87
+ >>> d = EasyDict()
88
+ >>> d.keys()
89
+ []
90
+ >>> d = EasyDict(foo=3, bar=dict(x=1, y=2))
91
+ >>> d.foo
92
+ 3
93
+ >>> d.bar.x
94
+ 1
95
+
96
+ Still like a dict though
97
+
98
+ >>> o = EasyDict({'clean':True})
99
+ >>> o.items()
100
+ [('clean', True)]
101
+
102
+ And like a class
103
+
104
+ >>> class Flower(EasyDict):
105
+ ... power = 1
106
+ ...
107
+ >>> f = Flower()
108
+ >>> f.power
109
+ 1
110
+ >>> f = Flower({'height': 12})
111
+ >>> f.height
112
+ 12
113
+ >>> f['power']
114
+ 1
115
+ >>> sorted(f.keys())
116
+ ['height', 'power']
117
+
118
+ update and pop items
119
+ >>> d = EasyDict(a=1, b='2')
120
+ >>> e = EasyDict(c=3.0, a=9.0)
121
+ >>> d.update(e)
122
+ >>> d.c
123
+ 3.0
124
+ >>> d['c']
125
+ 3.0
126
+ >>> d.get('c')
127
+ 3.0
128
+ >>> d.update(a=4, b=4)
129
+ >>> d.b
130
+ 4
131
+ >>> d.pop('a')
132
+ 4
133
+ >>> d.a
134
+ Traceback (most recent call last):
135
+ ...
136
+ AttributeError: 'EasyDict' object has no attribute 'a'
137
+ """
138
+
139
+ def __init__(self, d=None, **kwargs):
140
+ if d is None:
141
+ d = {}
142
+ if kwargs:
143
+ d.update(**kwargs)
144
+ for k, v in d.items():
145
+ setattr(self, k, v)
146
+ # Class attributes
147
+ for k in self.__class__.__dict__.keys():
148
+ if not (k.startswith("__") and k.endswith("__")) and not k in ("update", "pop"):
149
+ setattr(self, k, getattr(self, k))
150
+
151
+ def __setattr__(self, name, value):
152
+ if isinstance(value, (list, tuple)):
153
+ value = [self.__class__(x) if isinstance(x, dict) else x for x in value]
154
+ elif isinstance(value, dict) and not isinstance(value, self.__class__):
155
+ value = self.__class__(value)
156
+ super(EasyDict, self).__setattr__(name, value)
157
+ super(EasyDict, self).__setitem__(name, value)
158
+
159
+ __setitem__ = __setattr__
160
+
161
+ def update(self, e=None, **f):
162
+ d = e or dict()
163
+ d.update(f)
164
+ for k in d:
165
+ setattr(self, k, d[k])
166
+
167
+ def pop(self, k, d=None):
168
+ if hasattr(self, k):
169
+ delattr(self, k)
170
+ return super(EasyDict, self).pop(k, d)
171
+
172
+ from transformers import PretrainedConfig
173
+ class Config(PretrainedConfig):
174
+ _auto_class = "AutoConfig"
175
+ """config"""
176
+ def __init__(self, **kwargs):
177
+ super().__init__(**kwargs)
178
+ self.cfg=EasyDict(kwargs)
179
+
180
+ @classmethod
181
+ def pretty_text(cls, cfg: dict, indent=2) -> str:
182
+ """format dict to a string
183
+
184
+ Args:
185
+ cfg (EasyDict): the params.
186
+
187
+ Returns: The string to display.
188
+
189
+ """
190
+ msg = "{\n"
191
+ for i, (k, v) in enumerate(cfg.items()):
192
+ if isinstance(v, dict):
193
+ v = cls.pretty_text(v, indent + 4)
194
+ spaces = " " * indent
195
+ msg += spaces + "{}: {}".format(k, v)
196
+ if i == len(cfg) - 1:
197
+ msg += " }"
198
+ else:
199
+ msg += "\n"
200
+ return msg
201
+
202
+ @classmethod
203
+ def dump(cls, cfg, savepath=None):
204
+ """dump cfg to `json` file.
205
+
206
+ Args:
207
+ cfg (dict): The dict to dump.
208
+ savepath (str): The filepath to save the dumped dict.
209
+
210
+ Returns: TODO
211
+
212
+ """
213
+ if savepath is None:
214
+ savepath = osp.join(cfg.WORKSPACE, "config.json")
215
+ json.dump(cfg, open(savepath, "w"), indent=2)
216
+
217
+ @classmethod
218
+ def get_config(cls, default_config: dict = None, config_file: str=''):
219
+ """get a `Config` instance.
220
+
221
+ Args:
222
+ default_config (dict): The default config. `default_config` will be overridden
223
+ by the config file `--cfg`; `--cfg` will be overridden by command-line args.
224
+
225
+ Returns: an EasyDict.
226
+ """
227
+ global cfg
228
+ if cfg is not None:
229
+ return cfg
230
+
231
+ # define arg parser.
232
+ parser = argparse.ArgumentParser()
233
+ # parser.add_argument("--cfg", help="load configs from yaml file", default="", type=str)
234
+ parser.add_argument(
235
+ "--config_file", default='your config file', help="the configuration file to load. support: .yaml, .json, .py"
236
+ )
237
+ parser.add_argument(
238
+ "--opts",
239
+ default=None,
240
+ nargs="*",
241
+ help="overridden configs. List. Format: 'key1 value1 key2 value2'",
242
+ )
243
+ # args = parser.parse_args()
244
+ args = parser.parse_known_args()[0] # parse_known_args so this also works inside Jupyter notebooks
245
+ args.config_file="/mnt/petrelfs/shiyansong/WEIGHT/UMT/l16_25m.py"
246
+
247
+ if config_file:
248
+ args.config_file=config_file
249
+
250
+ cfg = EasyDict(BASE_CONFIG)
251
+ # if default_config: # new------------------------------------
252
+ # cfg = merge_a_into_b(default_config, cfg)
253
+ if osp.isfile(args.config_file):
254
+ cfg_from_file = cls.from_file(args.config_file)
255
+ cfg = merge_a_into_b(cfg_from_file, cfg)
256
+ if args.opts:
257
+ cfg = cls.merge_list(cfg, args.opts)
258
+ cfg = eval_dict_leaf(cfg)
259
+
260
+ # update some keys to make them show at the last
261
+ for k in BASE_CONFIG:
262
+ cfg[k] = cfg.pop(k)
263
+ return cfg
264
+
265
+ @classmethod
266
+ def from_file(cls, filepath: str) -> EasyDict:
267
+ """Build config from file. Supported filetypes: `.py`,`.yaml`,`.json`.
268
+
269
+ Args:
270
+ filepath (str): The config file path.
271
+
272
+ Returns: TODO
273
+
274
+ """
275
+ filepath = osp.abspath(osp.expanduser(filepath))
276
+ if not osp.isfile(filepath):
277
+ raise IOError(f"File does not exist: {filepath}")
278
+ if filepath.endswith(".py"):
279
+ sys.path.insert(0, osp.dirname(filepath))
280
+ mod = import_module(osp.splitext(osp.basename(filepath))[0])
281
+ cfg_dict = {
282
+ name: value
283
+ for name, value in mod.__dict__.items()
284
+ if not name.startswith("__")
285
+ }
286
+
287
+ # The temporary-directory import path below is not needed; it is kept commented out for reference.
288
+ # with tempfile.TemporaryDirectory() as temp_config_dir:
289
+ # print(temp_config_dir, filepath)
290
+
291
+ # print(f"Copying {osp.dirname(filepath)} to {osp.join(temp_config_dir, 'tmp_config')}")
292
+ # shutil.copytree(osp.dirname(filepath), osp.join(temp_config_dir, "tmp_config"))
293
+ # sys.path.insert(0, temp_config_dir)
294
+ # mod = import_module("tmp_config." + osp.splitext(osp.basename(filepath))[0])
295
+ # # mod = import_module(temp_module_name)
296
+ # sys.path.pop(0)
297
+ # cfg_dict = {
298
+ # name: value
299
+ # for name, value in mod.__dict__.items()
300
+ # if not name.startswith("__")
301
+ # }
302
+ # print("Removing")
303
+ # for k in list(sys.modules.keys()):
304
+ # if "tmp_config" in k:
305
+ # del sys.modules[k]
306
+ elif filepath.endswith((".yml", ".yaml")):
307
+ cfg_dict = yaml.load(open(filepath, "r"), Loader=yaml.Loader)
308
+ elif filepath.endswith(".json"):
309
+ cfg_dict = json.load(open(filepath, "r"))
310
+ else:
311
+ raise IOError("Only .py/.yml/.yaml/.json config files are supported.")
312
+
313
+ cfg_text = filepath + "\n"
314
+ with open(filepath, "r") as f:
315
+ cfg_text += f.read()
316
+
317
+ if BASE_KEY in cfg_dict: # load configs in `BASE_KEY`
318
+ cfg_dir = osp.dirname(filepath)
319
+ base_filename = cfg_dict.pop(BASE_KEY)
320
+ base_filename = (
321
+ base_filename if isinstance(base_filename, list) else [base_filename]
322
+ )
323
+
324
+ cfg_dict_list = list()
325
+ for f in base_filename:
326
+ _cfg_dict = Config.from_file(osp.join(cfg_dir, f))
327
+ cfg_dict_list.append(_cfg_dict)
328
+
329
+ base_cfg_dict = dict()
330
+ for c in cfg_dict_list:
331
+ if len(base_cfg_dict.keys() & c.keys()) > 0:
332
+ raise KeyError("Duplicate key is not allowed among bases")
333
+ base_cfg_dict.update(c)
334
+
335
+ cfg_dict = merge_a_into_b(cfg_dict, base_cfg_dict)
336
+
337
+ return EasyDict(cfg_dict)
338
+
339
+ @classmethod
340
+ def merge_list(cls, cfg, opts: list):
341
+ """merge commandline opts.
342
+
343
+ Args:
344
+ cfg: (dict): The config to be merged.
345
+ opts (list): The list to merge. Format: [key1, name1, key2, name2,...].
346
+ The keys can be nested. For example, ["a.b", v] will be considered
347
+ as `dict(a=dict(b=v))`.
348
+
349
+ Returns: dict.
350
+
351
+ """
352
+ assert len(opts) % 2 == 0, f"length of opts must be even. Got: {opts}"
353
+ for i in range(0, len(opts), 2):
354
+ full_k, v = opts[i], opts[i + 1]
355
+ keys = full_k.split(".")
356
+ sub_d = cfg
357
+ for i, k in enumerate(keys):
358
+ if not hasattr(sub_d, k):
359
+ raise ValueError(f"The key {k} does not exist in the config. Full key: {full_k}")
360
+ if i != len(keys) - 1:
361
+ sub_d = sub_d[k]
362
+ else:
363
+ sub_d[k] = v
364
+ return cfg
365
+
366
+
367
+ def merge_a_into_b(a, b, inplace=False):
368
+ """The values in a will override values in b.
369
+
370
+ Args:
371
+ a (dict): source dict.
372
+ b (dict): target dict.
373
+
374
+ Returns: dict. recursively merge dict a into dict b.
375
+
376
+ """
377
+ if not inplace:
378
+ b = deepcopy(b)
379
+ for key in a:
380
+ if key in b:
381
+ if isinstance(a[key], dict) and isinstance(b[key], dict):
382
+ b[key] = merge_a_into_b(a[key], b[key], inplace=True)
383
+ else:
384
+ b[key] = a[key]
385
+ else:
386
+ b[key] = a[key]
387
+ return b
388
+
389
+
390
+ def eval_dict_leaf(d, orig_dict=None):
391
+ """eval values of dict leaf.
392
+
393
+ Args:
394
+ d (dict): The dict to eval.
395
+
396
+ Returns: dict.
397
+
398
+ """
399
+ if orig_dict is None:
400
+ orig_dict = d
401
+ for k, v in d.items():
402
+ if not isinstance(v, dict):
403
+ d[k] = eval_string(v, orig_dict)
404
+ else:
405
+ eval_dict_leaf(v, orig_dict)
406
+ return d
407
+
408
+
409
+ def eval_string(string, d):
410
+ """automatically evaluate string to corresponding types.
411
+
412
+ For example:
413
+ not a string -> return the original input
414
+ '0' -> 0
415
+ '0.2' -> 0.2
416
+ '[0, 1, 2]' -> [0,1,2]
417
+ 'eval(1+2)' -> 3
418
+ 'eval(range(5))' -> [0,1,2,3,4]
419
+ '${a}' -> d.a
420
+
421
+
422
+
423
+ Args:
424
+ string (str): The value to evaluate.
425
+ d (dict): The dict used to resolve `${...}` references.
426
+
427
+ Returns: the corresponding type
428
+
429
+ """
430
+ if not isinstance(string, str):
431
+ return string
432
+ # if len(string) > 1 and string[0] == "[" and string[-1] == "]":
433
+ # return eval(string)
434
+ if string[0:5] == "eval(":
435
+ return eval(string[5:-1])
436
+
437
+ s0 = string
438
+ s1 = re.sub(r"\${(.*)}", r"d.\1", s0)
439
+ if s1 != s0:
440
+ while s1 != s0:
441
+ s0 = s1
442
+ s1 = re.sub(r"\${(.*)}", r"d.\1", s0)
443
+ return eval(s1)
444
+
445
+ try:
446
+ v = ast.literal_eval(string)
447
+ except Exception:
448
+ v = string
449
+ return v
450
+
451
+ if __name__=="__main__":
452
+ d=EasyDict({"1":2,"2":3})
453
+ cfg=Config({"1":2,"2":3})
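A small usage sketch for the helpers above (assumption: "my_config.json" is a placeholder for a local config file, not one shipped in this repo):

from configuration_videochat2 import Config, eval_dict_leaf

cfg = Config.from_file("my_config.json")   # returns an EasyDict
cfg = eval_dict_leaf(cfg)                  # resolves 'eval(...)' and '${...}' leaves
print(Config.pretty_text(cfg))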
ltm/basis_functions.py ADDED
@@ -0,0 +1,266 @@
1
+ import torch
2
+ import math
3
+
4
+
5
+ class BasisFunctions(object):
6
+ def __init__(self):
7
+ pass
8
+
9
+ def __len__(self):
10
+ """Number of basis functions."""
11
+ pass
12
+
13
+ def evaluate(self, t):
14
+ pass
15
+
16
+ def integrate_t2_times_psi(self, a, b):
17
+ """Compute integral int_a^b (t**2) * psi(t)."""
18
+ pass
19
+
20
+ def integrate_t_times_psi(self, a, b):
21
+ """Compute integral int_a^b t * psi(t)."""
22
+ pass
23
+
24
+ def integrate_psi(self, a, b):
25
+ """Compute integral int_a^b psi(t)."""
26
+ pass
27
+
28
+
29
+ class PowerBasisFunctions(BasisFunctions):
30
+ """Function phi(t) = t**degree."""
31
+ def __init__(self, degree):
32
+ self.degree = degree.unsqueeze(0)
33
+
34
+ def __len__(self):
35
+ """Number of basis functions."""
36
+ return self.degree.size(1)
37
+
38
+ def evaluate(self, t):
39
+ return t**self.degree
40
+
41
+ def integrate_t2_times_psi(self, a, b):
42
+ """Compute integral int_a^b (t**2) * psi(t)."""
43
+ return (b**(self.degree + 3) - a**(self.degree + 3)) / (self.degree + 3)
44
+
45
+ def integrate_t_times_psi(self, a, b):
46
+ """Compute integral int_a^b t * psi(t)."""
47
+ return (b**(self.degree + 2) - a**(self.degree + 2)) / (self.degree + 2)
48
+
49
+ def integrate_psi(self, a, b):
50
+ """Compute integral int_a^b psi(t)."""
51
+ return (b**(self.degree + 1) - a**(self.degree + 1)) / (self.degree + 1)
52
+
53
+ def __repr__(self):
54
+ return f"PowerBasisFunction(degree={self.degree})"
55
+
56
+
57
+ class SineBasisFunctions(BasisFunctions):
58
+ """Function phi(t) = sin(omega*t)."""
59
+ def __init__(self, omega):
60
+ self.omega = omega.unsqueeze(0)
61
+
62
+ def __repr__(self):
63
+ return f"SineBasisFunction(omega={self.omega})"
64
+
65
+ def __len__(self):
66
+ """Number of basis functions."""
67
+ return self.omega.size(1)
68
+
69
+ def evaluate(self, t):
70
+ return torch.sin(self.omega*t)
71
+
72
+ def integrate_t2_times_psi(self, a, b):
73
+ """Compute integral int_a^b (t**2) * psi(t)."""
74
+ # The antiderivative of (t**2)*sin(omega*t) is
75
+ # ((2-(t**2)*(omega**2))*cos(omega*t) + 2*omega*t*sin(omega*t)) / omega**3. # noqa
76
+ return ((2-(b**2)*(self.omega**2))*torch.cos(self.omega*b)
77
+ + 2*self.omega*b*torch.sin(self.omega*b)
78
+ - (2-(a**2)*(self.omega**2))*torch.cos(self.omega*a)
79
+ - 2*self.omega*a*torch.sin(self.omega*a)
80
+ ) / (self.omega**3)
81
+
82
+ def integrate_t_times_psi(self, a, b):
83
+ """Compute integral int_a^b t * psi(t)."""
84
+ # The antiderivative of t*sin(omega*t) is
85
+ # (sin(omega*t) - omega*t*cos(omega*t)) / omega**2.
86
+ return (torch.sin(self.omega*b) - self.omega*b*torch.cos(self.omega*b)
87
+ - torch.sin(self.omega*a) + self.omega*a*torch.cos(self.omega*a)
88
+ ) / (self.omega**2)
89
+
90
+ def integrate_psi(self, a, b):
91
+ """Compute integral int_a^b psi(t)."""
92
+ # The antiderivative of sin(omega*t) is -cos(omega*t)/omega.
93
+ return (-torch.cos(self.omega*b) + torch.cos(self.omega*a)) / self.omega
94
+
95
+
96
+ class CosineBasisFunctions(BasisFunctions):
97
+ """Function phi(t) = cos(omega*t)."""
98
+ def __init__(self, omega):
99
+ self.omega = omega.unsqueeze(0)
100
+
101
+ def __repr__(self):
102
+ return f"CosineBasisFunction(omega={self.omega})"
103
+
104
+ def __len__(self):
105
+ """Number of basis functions."""
106
+ return self.omega.size(1)
107
+
108
+ def evaluate(self, t):
109
+ return torch.cos(self.omega*t)
110
+
111
+ def integrate_t2_times_psi(self, a, b):
112
+ """Compute integral int_a^b (t**2) * psi(t)."""
113
+ # The antiderivative of (t**2)*cos(omega*t) is
114
+ # (((t**2)*(omega**2)-2)*sin(omega*t) + 2*omega*t*cos(omega*t)) / omega**3. # noqa
115
+ return (((b**2)*(self.omega**2)-2)*torch.sin(self.omega*b)
116
+ + 2*self.omega*b*torch.cos(self.omega*b)
117
+ - ((a**2)*(self.omega**2)-2)*torch.sin(self.omega*a)
118
+ - 2*self.omega*a*torch.cos(self.omega*a)
119
+ ) / (self.omega**3)
120
+
121
+ def integrate_t_times_psi(self, a, b):
122
+ """Compute integral int_a^b t * psi(t)."""
123
+ # The antiderivative of t*cos(omega*t) is
124
+ # (cos(omega*t) + omega*t*sin(omega*t)) / omega**2.
125
+ return (torch.cos(self.omega*b) + self.omega*b*torch.sin(self.omega*b)
126
+ - torch.cos(self.omega*a) - self.omega*a*torch.sin(self.omega*a)
127
+ ) / (self.omega**2)
128
+
129
+ def integrate_psi(self, a, b):
130
+ """Compute integral int_a^b psi(t)."""
131
+ # The antiderivative of cos(omega*t) is sin(omega*t)/omega.
132
+ return (torch.sin(self.omega*b) - torch.sin(self.omega*a)) / self.omega
133
+
134
+
135
+ class GaussianBasisFunctions(BasisFunctions):
136
+ """Function phi(t) = Gaussian(t; mu, sigma_sq)."""
137
+ def __init__(self, mu, sigma):
138
+ self.mu = mu.unsqueeze(0)
139
+ self.sigma = sigma.unsqueeze(0)
140
+
141
+ def __repr__(self):
142
+ return f"GaussianBasisFunction(mu={self.mu}, sigma={self.sigma})"
143
+
144
+ def __len__(self):
145
+ """Number of basis functions."""
146
+ return self.mu.size(1)
147
+
148
+ def _phi(self, t):
149
+ return 1. / math.sqrt(2 * math.pi) * torch.exp(-.5 * t**2)
150
+
151
+ def _Phi(self, t):
152
+ return .5 * (1 + torch.erf(t / math.sqrt(2)))
153
+
154
+ def _integrate_product_of_gaussians(self, mu, sigma_sq):
155
+ sigma = torch.sqrt(self.sigma ** 2 + sigma_sq)
156
+ return self._phi((mu - self.mu) / sigma) / sigma
157
+
158
+ def evaluate(self, t):
159
+ return self._phi((t - self.mu) / self.sigma) / self.sigma
160
+
161
+ def batch_evaluate(self, t):
162
+ t_ = t.repeat(self.mu.size(0),1) - self.mu.repeat(t.size(0),1).transpose(1,0)
163
+ t_ = t_ / self.sigma.repeat((t.size(0),1)).transpose(1,0)
164
+ return (self._phi(t_) / self.sigma.repeat((t.size(0),1)).transpose(1,0)).transpose(0,1)
165
+
166
+ def integrate_t2_times_psi(self, a, b):
167
+ """Compute integral int_a^b (t**2) * psi(t)."""
168
+ return (self.mu**2 + self.sigma**2) * (
169
+ self._Phi((b - self.mu) / self.sigma) - self._Phi((a - self.mu) / self.sigma)
170
+ ) - (
171
+ self.sigma * (b + self.mu) * self._phi((b - self.mu) / self.sigma)
172
+ ) + (
173
+ self.sigma * (a + self.mu) * self._phi((a - self.mu) / self.sigma)
174
+ )
175
+
176
+ def integrate_t_times_psi(self, a, b):
177
+ """Compute integral int_a^b t * psi(t)."""
178
+ return self.mu * (
179
+ self._Phi((b - self.mu) / self.sigma) - self._Phi((a - self.mu) / self.sigma)
180
+ ) - self.sigma * (
181
+ self._phi((b - self.mu) / self.sigma) - self._phi((a - self.mu) / self.sigma)
182
+ )
183
+
184
+ def integrate_psi(self, a, b):
185
+ """Compute integral int_a^b psi(t)."""
186
+ return self._Phi((b - self.mu) / self.sigma) - self._Phi((a - self.mu) / self.sigma)
187
+
188
+ def integrate_t2_times_psi_gaussian(self, mu, sigma_sq):
189
+ """Compute integral int N(t; mu, sigma_sq) * t**2 * psi(t)."""
190
+ S_tilde = self._integrate_product_of_gaussians(mu, sigma_sq)
191
+ mu_tilde = (
192
+ self.mu * sigma_sq + mu * self.sigma ** 2
193
+ ) / (
194
+ self.sigma ** 2 + sigma_sq
195
+ )
196
+ sigma_sq_tilde = ((self.sigma ** 2) * sigma_sq) / (self.sigma ** 2 + sigma_sq)
197
+ return S_tilde * (mu_tilde ** 2 + sigma_sq_tilde)
198
+
199
+ def integrate_t_times_psi_gaussian(self, mu, sigma_sq):
200
+ """Compute integral int N(t; mu, sigma_sq) * t * psi(t)."""
201
+ S_tilde = self._integrate_product_of_gaussians(mu, sigma_sq)
202
+ mu_tilde = (
203
+ self.mu * sigma_sq + mu * self.sigma ** 2
204
+ ) / (
205
+ self.sigma ** 2 + sigma_sq
206
+ )
207
+ return S_tilde * mu_tilde
208
+
209
+ def integrate_psi_gaussian(self, mu, sigma_sq):
210
+ """Compute integral int N(t; mu, sigma_sq) * psi(t)."""
211
+ return self._integrate_product_of_gaussians(mu, sigma_sq)
212
+
213
+
214
+ class RetangularBasisFunctions(BasisFunctions):
215
+ """Rectangular (box) function: phi(t) = 1 if |t - mu| < width/2 else 0."""
216
+ def __init__(self, mu, sigma):
217
+ self.mu = mu.unsqueeze(0)
218
+ self.width = sigma.unsqueeze(0)
219
+
220
+ def __repr__(self):
221
+ return f"RetangularBasisFunction(mu={self.mu}, width={self.width})"
222
+
223
+ def __len__(self):
224
+ """Number of basis functions."""
225
+ return self.mu.size(1)
226
+
227
+ def batch_evaluate(self, t):
228
+ """
229
+ Evaluate multiple time points against all rectangular basis functions.
230
+ Args:
231
+ t: Tensor of time values to evaluate, shape (num_points,).
232
+ Returns:
233
+ Tensor of evaluations, shape (num_basis, num_points).
234
+ """
235
+ t = t.repeat(self.mu.size(0),1) # Shape: (1, num_points)
236
+ mu = self.mu.repeat(t.size(0),1).transpose(1,0) # Shape: (num_basis, 1)
237
+ width = self.width.repeat(t.size(0),1).transpose(1,0) # Shape: (num_basis, 1)
238
+ return ((t >= (mu - width / 2)) & (t < (mu + width / 2))).float().transpose(0,1)
239
+
240
+ def _Phi(self, t):
241
+ """
242
+ Compute the step function for a single value of t.
243
+ Args:
244
+ t: A scalar or tensor of time values.
245
+ Returns:
246
+ Tensor of values indicating presence in each basis function's range.
247
+ """
248
+ lower_bounds = self.mu - self.width / 2
249
+ upper_bounds = self.mu + self.width / 2
250
+ return ((t >= lower_bounds) & (t < upper_bounds)).float()
251
+
252
+ def evaluate(self, t):
253
+ """
254
+ Evaluate the rectangular basis functions at a single point or array of points.
255
+ Args:
256
+ t: A scalar or 1D tensor of time values.
257
+ Returns:
258
+ Tensor of shape (num_basis,) for scalar input, or (num_basis, num_points) for tensor input.
259
+ """
260
+ if t.ndim == 0: # Scalar input
261
+ return self._Phi(t)
262
+ else: # Tensor input
263
+ # Shape: (1, num_points)
264
+ lower_bounds = (self.mu - self.width / 2) # Shape: (num_basis, 1)
265
+ upper_bounds = (self.mu + self.width / 2) # Shape: (num_basis, 1)
266
+ return ((t >= lower_bounds) & (t < upper_bounds)).float()
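A quick usage sketch (assumptions: mu and sigma are 1-D tensors with one entry per basis function, matching the constructors above; the numbers are purely illustrative):

import torch
from ltm.basis_functions import GaussianBasisFunctions

mu = torch.linspace(0, 1, 8)        # 8 Gaussian centres in [0, 1]
sigma = torch.full((8,), 0.1)       # shared bandwidth
psi = GaussianBasisFunctions(mu, sigma)

t = torch.tensor([[0.25], [0.50]])  # evaluation points, shape [2, 1]
vals = psi.evaluate(t)              # broadcasts to shape [2, 8]
mass = psi.integrate_psi(torch.tensor(0.0), torch.tensor(1.0))  # close to 1 for interior centres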
ltm/long_term_attention_gibbs.py ADDED
@@ -0,0 +1,315 @@
1
+ # coding: utf-8
2
+ """
3
+ Attention modules
4
+ """
5
+ import torch
6
+ import torch.nn as nn
7
+ import torch.distributions as dist
8
+
9
+ from .basis_functions import (
10
+ PowerBasisFunctions,
11
+ SineBasisFunctions,
12
+ CosineBasisFunctions,
13
+ GaussianBasisFunctions,
14
+ RetangularBasisFunctions
15
+ )
16
+
17
+ import numpy as np
18
+
19
+
20
+
21
+ class LongTermAttention(nn.Module):
22
+ def __init__(self, head_size:int , length: int, target_len:int, attn_func: str, attn_num_basis: int,
23
+ continuous: bool, attn_drop: float, infinite_memory: bool, n_layers: int,
24
+ n_heads: int, affines: bool, mask: bool, mask_type: str, kl_regularizer: bool, proj_key, proj_value, sigma_0, mu_0, sticky_memories, sigmas, tau, **kwargs):
25
+
26
+ super(LongTermAttention, self).__init__()
27
+
28
+ self.device = 'cuda'
29
+ self.length = length #memory length
30
+ self.target_len = target_len #target length / transformer length
31
+ self.head_size = head_size
32
+ self.attn_num_basis = attn_num_basis
33
+ self.continuous = continuous # whether attention over memory vectors is continuous
34
+ self.attn_func = attn_func # normalizing function
35
+ self.n_head = n_heads
36
+ self.sigmas = sigmas
37
+ self.kl_regularizer = kl_regularizer
38
+ self.sigma_0 = sigma_0
39
+ self.mu_0 = mu_0
40
+ self.proj_key = proj_key
41
+ self.proj_value = proj_value
42
+
43
+ self.affines=affines # whether mu, sigma should be computed using affine transformations
44
+
45
+
46
+ self.sticky_memories=sticky_memories
47
+
48
+ self.mem_threshold=2048
49
+ self.infinite_memory = infinite_memory # whether the memory is infinite
50
+
51
+ self.nb_samples=512 # number of samples used for update
52
+ self.tau = tau #compressing factor
53
+ self.count = 0
54
+
55
+ self.x_past=None # previous memory vectors
56
+ self.B_past=None # previous coefficient matrix
57
+
58
+ self.ridge_penalty=0.5 # ridge penalty
59
+ self.padding = True
60
+
61
+ self.spacing='linear'
62
+
63
+ def get_basis(self, length, target_len):
64
+ def compute_G(l, psi, positions, padding=True):
65
+
66
+ F = torch.zeros(self.attn_num_basis, positions.size(0))
67
+
68
+ basis_functions = psi
69
+ F[:, :] = basis_functions.evaluate(positions.unsqueeze(1)).t()
70
+
71
+ I = torch.eye(self.attn_num_basis)
72
+ G = F.t().matmul((F.matmul(F.t()) + self.ridge_penalty * I).inverse())
73
+
74
+ if padding:
75
+ if l % 2:
76
+ G = G[((l-1)//2):(-(l-1)//2), :]
77
+ else:
78
+ G = G[(l//2):-(l//2), :]
79
+
80
+ return G.to(self.device)
81
+ padding = self.padding
82
+ attn_num_basis = self.attn_num_basis
83
+ if self.continuous:
84
+
85
+ self.psi=[None]
86
+ self.Gs=[None for _ in range(length+1)]
87
+ lengths=[]
88
+ for i in range(length):
89
+ self.psi.append([])
90
+ if (i+1)%target_len==0:
91
+ lengths.append(i+1)
92
+ if length not in lengths:
93
+ lengths.append(length)
94
+ for l in lengths:
95
+ # get positions for memory vectors
96
+ self.add_retangular_basis_functions(self.psi[l], attn_num_basis, device=self.device)
97
+
98
+ if self.spacing=='linear':
99
+ if padding:
100
+ if l % 2:
101
+ shift = 1 / float(l)
102
+ positions = torch.linspace(-.5+shift, 1.5-shift, 2*l-1).to(self.device)
103
+ else:
104
+ shift = 1 / float(2*l)
105
+ positions = torch.linspace(-.5+shift, 1.5-shift, 2*l).to(self.device)
106
+ else:
107
+ shift = 1 / float(2*l)
108
+ positions = torch.linspace(shift, 1-shift, l).to(self.device)
109
+ elif self.spacing=='log':
110
+ if padding:
111
+ if l % 2:
112
+ shift = 1 / float(l)
113
+ positions = torch.linspace(-.5+shift, 1.5-shift, 2*l-1).to(self.device)
114
+ else:
115
+ shift = 1 / float(2*l)
116
+ positions = torch.linspace(-.5+shift, 1.5-shift, 2*l).to(self.device)
117
+
118
+ pos = np.e**(np.log(1+1)*torch.arange(1,length+1)/length)-1
119
+ positions = torch.cat([positions[:int(l/2)],pos.to(self.device),positions[-int(l/2):]])
120
+
121
+ else:
122
+ positions = np.e**(np.log(1+1)*torch.arange(1,length+1)/length)-1
123
+
124
+ # compute basis functions
125
+ self.Gs[l]=compute_G(l, self.psi[l][0], positions, padding=padding) # [L,N]
126
+ self.positions = positions[int(l/2):-int(l/2)]
127
+
128
+ # compute samples for memory update
129
+ if self.infinite_memory:
130
+ tm_tau = torch.arange(1,self.nb_samples+1).float()
131
+ tm_l = torch.arange(self.nb_samples+1,length+self.nb_samples+1).float()
132
+ tm_tau = tm_tau*self.tau/self.nb_samples # positions of old vectors
133
+ tm_l = self.tau + (1-self.tau)*(tm_l-self.nb_samples)/length # positions of new vectors
134
+ positions_inf = torch.cat([tm_tau, tm_l],0).to(self.device) # positions
135
+
136
+ if padding:
137
+ if l % 2:
138
+ shift = 1 / float(length+self.nb_samples)
139
+ positions_pad = torch.linspace(-.5+shift, 1.5-shift, 2*(length+self.nb_samples)-1).to(self.device)
140
+ else:
141
+ shift = 1 / float(2*length+self.nb_samples)
142
+ positions_pad = torch.linspace(-.5+shift, 1.5-shift, 2*(length+self.nb_samples)).to(self.device)
143
+ positions_pad_ = torch.FloatTensor([i for i in positions_pad if i<0]).to(self.device)
144
+ positions_pad__ = torch.FloatTensor([i for i in positions_pad if i>1]).to(self.device)
145
+ positions_inf = torch.cat([positions_pad_,positions_inf,positions_pad__], dim=0)
146
+
147
+ self.samples=None
148
+ for t in tm_tau:
149
+ if self.samples is None:
150
+ self.samples = self.psi[l][0].evaluate(t/self.tau)
151
+ else:
152
+ self.samples = torch.cat([self.samples,self.psi[l][0].evaluate(t/self.tau)], dim=0)
153
+
154
+ # compute G for the infinite case
155
+ self.G_inf = compute_G(self.nb_samples+length, self.psi[l][0], positions_inf, padding=padding) #[L+nb_samples,N]
156
+
157
+ if self.sticky_memories:
158
+ self.bins = torch.linspace(0,1,129).to(device=self.device) #self.positions
159
+ self.nb_bins_cat=1
160
+ self.bins_cat = dist.Categorical(torch.ones(self.nb_bins_cat))
161
+
162
+ def add_gaussian_basis_functions(self, psi, nb_basis, sigmas, device):
163
+ mu, sigma = torch.meshgrid(torch.linspace(0, 1, nb_basis // len(sigmas)), torch.Tensor(sigmas))
164
+ mu = mu.flatten().to(device)
165
+ sigma = sigma.flatten().to(device)
166
+ self.basis_mu=mu
167
+ self.basis_sigma=sigma
168
+ assert mu.size(0) == nb_basis
169
+ psi.append(GaussianBasisFunctions(mu=mu, sigma=sigma))
170
+
171
+ def add_retangular_basis_functions(self, psi, nb_basis, device):
172
+ width = torch.ones(nb_basis, device=device) / nb_basis
173
+
174
+ # Compute the centers (midpoints) of each bin
175
+ edges = torch.linspace(0, 1, nb_basis + 1, device=device)
176
+ mu = (edges[:-1] + edges[1:]) / 2
177
+ psi.append(RetangularBasisFunctions(mu=mu, sigma=width))
178
+
179
+ def value_function(self, x, inf=False):
180
+ if inf:
181
+ G = self.G_inf # [nb_sample+L,N]
182
+ else:
183
+ G = self.Gs[x.size(-1)] # [L,N]
184
+ B = torch.matmul(x, G) # [B,e,N]
185
+ B = B.permute(0,2,1) # [B,N,e]
186
+
187
+ return B
188
+
189
+ def update_inf(self, x):
190
+ if self.B_past is not None:
191
+ if self.sticky_memories:
192
+ bins = self.bins.clone()
193
+ bins[0]=-.000001
194
+ bins[-1]=1.000001
195
+ prob_density = self.compute_probability(self.score, t=bins)
196
+ cum_prob = torch.cumulative_trapezoid(prob_density, bins, dim=-1).to(self.device)
197
+ p = (cum_prob[..., 1:] - cum_prob[..., :-1]).sum(dim=(1, 2))
198
+ p = p / p.sum(-1, keepdim=True) # Normalize over the last dimension (bins)
199
+ p = dist.Categorical(p)
200
+ b = p.sample((self.nb_samples,))
201
+ t = self.bins_cat.sample((self.nb_samples, 1)).to(device=self.device)
202
+ ts = (t*(self.bins[b+1]-self.bins[b])/self.nb_bins_cat +self.bins[b]).transpose(1,0)
203
+ samples = self.psi[self.length][0].batch_evaluate(ts[0]).contiguous()
204
+
205
+ xm_tau = self.B_past.transpose(-1,-2).matmul(samples.transpose(-1,-2)) # [B,e,nb_samples]
206
+ else:
207
+ xm_tau = self.B_past.transpose(-1,-2).matmul(self.samples.transpose(-1,-2)) # [B,e,nb_samples]
208
+
209
+
210
+ x = torch.cat([xm_tau,x], dim=2) # [B,e,nb_samples+L]
211
+ B = self.value_function(x, inf=True) # [B,N,e]
212
+ else:
213
+ B = self.value_function(x)
214
+
215
+ self.B_past=B.detach()
216
+ self.x_past=x
217
+ return B
218
+
219
+ def score(self, t):
220
+ psis = self.psis[0].batch_evaluate(t)
221
+ query = self.queries/ (self.d_head ** 0.5) # divide by sqrt(d_head) [B,h,q,d]
222
+ keys = self.keys.transpose(-1, -2)
223
+ keys = torch.matmul(keys, psis.T) #[B,h,d,1]
224
+ scores = torch.matmul(query, keys) #[B,h,q,1]
225
+ return scores
226
+
227
+ def compute_probability(self, score_fn, num_points=1000, t=None):
228
+ """
229
+ Compute probability distribution p(t).
230
+
231
+ Args:
232
+ score_fn (callable): Function that computes z(t)
233
+ num_points (int): Number of points for numerical integration
234
+
235
+ Returns:
236
+ tuple: (probabilities, normalization constant)
237
+ """
238
+ if t is None:
239
+ # Create integration points
240
+ t = torch.linspace(0, 1, num_points).to(self.device)
241
+
242
+ scores = score_fn(t)
243
+ prob = torch.exp(scores) / torch.trapz(torch.exp(scores), t, dim=-1).unsqueeze(-1)
244
+ return prob
245
+
246
+ def expected_value(self, score_fn, num_points=1000):
247
+ """
248
+ Compute expected value E_p[V(t)] using nested integration.
249
+
250
+ Args:
251
+ score_fn (callable): Function that computes z(t)
252
+ value_fn (callable): Function that computes v(t)
253
+ num_points (int): Number of points for numerical integration
254
+
255
+ Returns:
256
+ torch.Tensor: Expected value
257
+ """
258
+ # Create integration points
259
+ t = torch.linspace(0, 1, num_points).to(self.device)
260
+
261
+ # Compute basis functions
262
+ self.psis = []
263
+ self.add_retangular_basis_functions(self.psis, self.attn_num_basis, self.device)
264
+ psi = self.psis[0].batch_evaluate(t)
265
+ # Compute probability distribution
266
+ prob = self.compute_probability(score_fn, num_points)
267
+ # Compute values at integration points
268
+ values = self.values
269
+ # Compute p(t) * psi(t)
270
+ # Reshape psi for broadcasting to match the shape of prob
271
+ psi_broadcasted = psi.unsqueeze(1).unsqueeze(2).unsqueeze(3)
272
+
273
+ # Expand psi to match the dimensions of prob (num_points, batch_size, n_head, qlen, 256)
274
+ psi_broadcasted = psi_broadcasted.expand(num_points, self.batch_size, self.n_head, self.qlen, self.attn_num_basis)
275
+ integrand = torch.matmul(prob.permute(3,0,1,2).unsqueeze(-1).unsqueeze(-1), psi_broadcasted.unsqueeze(-2)).permute(1, 2, 3, 4, 5, 0).squeeze(-3)
276
+
277
+ integral = torch.trapz(integrand, t, dim=-1)
278
+ # Matrix multiply with values
279
+ expected_value = torch.matmul(integral, values) # [B, h, q, d]
280
+
281
+ return expected_value
282
+
283
+ def forward(self, k, q, new_doc, layer_n):
284
+ self.device = k.device
285
+ if self.continuous:
286
+ klen = int(k.size(1)/(14*14))
287
+ self.length = klen
288
+ batch_size = k.size(0) #batch size
289
+ qlen = q.size(1) #query length
290
+ self.qlen = qlen
291
+ self.batch_size = batch_size
292
+ self.d_head = self.head_size #head size
293
+ self.get_basis(klen, klen)
294
+ # clean memory if going through different document
295
+ if new_doc:
296
+ self.B_past=None
297
+ self.x_past=None
298
+
299
+ k = k.reshape(batch_size, klen, 14, 14, 1024).mean(dim=(2, 3))
300
+ k = k.transpose(1,2)
301
+ # perform memory update
302
+ if self.infinite_memory:
303
+ B = self.update_inf(k)
304
+ else: # compute input continuous approximation
305
+ B = self.value_function(k) # [B,N,e]
306
+ keys = self.proj_key(B)
307
+ values = self.proj_value(B)
308
+ query = q
309
+ self.queries = query.view(batch_size,qlen,self.n_head,self.d_head).transpose(1,2) # [B,h,q,d]
310
+ self.keys = keys.view(batch_size,self.attn_num_basis,self.n_head,self.d_head).transpose(1,2) # [B,h,N,d]
311
+ self.values = values.view(batch_size,self.attn_num_basis,self.n_head,self.d_head).transpose(1,2) # [B, h, q, N]
312
+ context = self.expected_value(self.score) # Shape [1, 32, 768]
313
+
314
+ return context.contiguous().transpose(1,2).reshape(1, qlen, -1)
315
+
model-00001-of-00004.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:ecf6a804c5af89465362453e591d8c3358cd97ad48247baabfc5b070edad2e07
3
+ size 4971600800
model-00002-of-00004.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:bf5745ae7b321d884e62f74589758abee57e79c6ae138e1b1f6877b5cad20565
3
+ size 4915917440
model-00003-of-00004.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:475cbd791fe87314409771c7f9651e5f7237c43e8eb5d9662714ff1d3d4fbc04
3
+ size 4999820720
model-00004-of-00004.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:7648e67deaaa08ce6f73df0f96963c62dba9702927390a73c69bdc328d6f5d27
3
+ size 1499540784
model.safetensors.index.json ADDED
@@ -0,0 +1,934 @@
1
+ {
2
+ "metadata": {
3
+ "total_size": 16386763776
4
+ },
5
+ "weight_map": {
6
+ "extra_query_tokens": "model-00001-of-00004.safetensors",
7
+ "mistral_model.lm_head.weight": "model-00004-of-00004.safetensors",
8
+ "mistral_model.model.embed_tokens.weight": "model-00001-of-00004.safetensors",
9
+ "mistral_model.model.layers.0.input_layernorm.weight": "model-00001-of-00004.safetensors",
10
+ "mistral_model.model.layers.0.mlp.down_proj.weight": "model-00001-of-00004.safetensors",
11
+ "mistral_model.model.layers.0.mlp.gate_proj.weight": "model-00001-of-00004.safetensors",
12
+ "mistral_model.model.layers.0.mlp.up_proj.weight": "model-00001-of-00004.safetensors",
13
+ "mistral_model.model.layers.0.post_attention_layernorm.weight": "model-00001-of-00004.safetensors",
14
+ "mistral_model.model.layers.0.self_attn.k_proj.weight": "model-00001-of-00004.safetensors",
15
+ "mistral_model.model.layers.0.self_attn.o_proj.weight": "model-00001-of-00004.safetensors",
16
+ "mistral_model.model.layers.0.self_attn.q_proj.weight": "model-00001-of-00004.safetensors",
17
+ "mistral_model.model.layers.0.self_attn.v_proj.weight": "model-00001-of-00004.safetensors",
18
+ "mistral_model.model.layers.1.input_layernorm.weight": "model-00001-of-00004.safetensors",
19
+ "mistral_model.model.layers.1.mlp.down_proj.weight": "model-00001-of-00004.safetensors",
20
+ "mistral_model.model.layers.1.mlp.gate_proj.weight": "model-00001-of-00004.safetensors",
21
+ "mistral_model.model.layers.1.mlp.up_proj.weight": "model-00001-of-00004.safetensors",
22
+ "mistral_model.model.layers.1.post_attention_layernorm.weight": "model-00001-of-00004.safetensors",
23
+ "mistral_model.model.layers.1.self_attn.k_proj.weight": "model-00001-of-00004.safetensors",
24
+ "mistral_model.model.layers.1.self_attn.o_proj.weight": "model-00001-of-00004.safetensors",
25
+ "mistral_model.model.layers.1.self_attn.q_proj.weight": "model-00001-of-00004.safetensors",
26
+ "mistral_model.model.layers.1.self_attn.v_proj.weight": "model-00001-of-00004.safetensors",
27
+ "mistral_model.model.layers.10.input_layernorm.weight": "model-00002-of-00004.safetensors",
28
+ "mistral_model.model.layers.10.mlp.down_proj.weight": "model-00002-of-00004.safetensors",
29
+ "mistral_model.model.layers.10.mlp.gate_proj.weight": "model-00002-of-00004.safetensors",
30
+ "mistral_model.model.layers.10.mlp.up_proj.weight": "model-00002-of-00004.safetensors",
31
+ "mistral_model.model.layers.10.post_attention_layernorm.weight": "model-00002-of-00004.safetensors",
32
+ "mistral_model.model.layers.10.self_attn.k_proj.weight": "model-00002-of-00004.safetensors",
33
+ "mistral_model.model.layers.10.self_attn.o_proj.weight": "model-00002-of-00004.safetensors",
34
+ "mistral_model.model.layers.10.self_attn.q_proj.weight": "model-00002-of-00004.safetensors",
35
+ "mistral_model.model.layers.10.self_attn.v_proj.weight": "model-00002-of-00004.safetensors",
36
+ "mistral_model.model.layers.11.input_layernorm.weight": "model-00002-of-00004.safetensors",
37
+ "mistral_model.model.layers.11.mlp.down_proj.weight": "model-00002-of-00004.safetensors",
38
+ "mistral_model.model.layers.11.mlp.gate_proj.weight": "model-00002-of-00004.safetensors",
39
+ "mistral_model.model.layers.11.mlp.up_proj.weight": "model-00002-of-00004.safetensors",
40
+ "mistral_model.model.layers.11.post_attention_layernorm.weight": "model-00002-of-00004.safetensors",
41
+ "mistral_model.model.layers.11.self_attn.k_proj.weight": "model-00002-of-00004.safetensors",
42
+ "mistral_model.model.layers.11.self_attn.o_proj.weight": "model-00002-of-00004.safetensors",
43
+ "mistral_model.model.layers.11.self_attn.q_proj.weight": "model-00002-of-00004.safetensors",
44
+ "mistral_model.model.layers.11.self_attn.v_proj.weight": "model-00002-of-00004.safetensors",
45
+ "mistral_model.model.layers.12.input_layernorm.weight": "model-00002-of-00004.safetensors",
46
+ "mistral_model.model.layers.12.mlp.down_proj.weight": "model-00002-of-00004.safetensors",
47
+ "mistral_model.model.layers.12.mlp.gate_proj.weight": "model-00002-of-00004.safetensors",
48
+ "mistral_model.model.layers.12.mlp.up_proj.weight": "model-00002-of-00004.safetensors",
49
+ "mistral_model.model.layers.12.post_attention_layernorm.weight": "model-00002-of-00004.safetensors",
50
+ "mistral_model.model.layers.12.self_attn.k_proj.weight": "model-00002-of-00004.safetensors",
51
+ "mistral_model.model.layers.12.self_attn.o_proj.weight": "model-00002-of-00004.safetensors",
52
+ "mistral_model.model.layers.12.self_attn.q_proj.weight": "model-00002-of-00004.safetensors",
53
+ "mistral_model.model.layers.12.self_attn.v_proj.weight": "model-00002-of-00004.safetensors",
54
+ "mistral_model.model.layers.13.input_layernorm.weight": "model-00002-of-00004.safetensors",
55
+ "mistral_model.model.layers.13.mlp.down_proj.weight": "model-00002-of-00004.safetensors",
56
+ "mistral_model.model.layers.13.mlp.gate_proj.weight": "model-00002-of-00004.safetensors",
57
+ "mistral_model.model.layers.13.mlp.up_proj.weight": "model-00002-of-00004.safetensors",
58
+ "mistral_model.model.layers.13.post_attention_layernorm.weight": "model-00002-of-00004.safetensors",
59
+ "mistral_model.model.layers.13.self_attn.k_proj.weight": "model-00002-of-00004.safetensors",
60
+ "mistral_model.model.layers.13.self_attn.o_proj.weight": "model-00002-of-00004.safetensors",
61
+ "mistral_model.model.layers.13.self_attn.q_proj.weight": "model-00002-of-00004.safetensors",
62
+ "mistral_model.model.layers.13.self_attn.v_proj.weight": "model-00002-of-00004.safetensors",
63
+ "mistral_model.model.layers.14.input_layernorm.weight": "model-00002-of-00004.safetensors",
64
+ "mistral_model.model.layers.14.mlp.down_proj.weight": "model-00002-of-00004.safetensors",
65
+ "mistral_model.model.layers.14.mlp.gate_proj.weight": "model-00002-of-00004.safetensors",
66
+ "mistral_model.model.layers.14.mlp.up_proj.weight": "model-00002-of-00004.safetensors",
67
+ "mistral_model.model.layers.14.post_attention_layernorm.weight": "model-00002-of-00004.safetensors",
68
+ "mistral_model.model.layers.14.self_attn.k_proj.weight": "model-00002-of-00004.safetensors",
69
+ "mistral_model.model.layers.14.self_attn.o_proj.weight": "model-00002-of-00004.safetensors",
70
+ "mistral_model.model.layers.14.self_attn.q_proj.weight": "model-00002-of-00004.safetensors",
71
+ "mistral_model.model.layers.14.self_attn.v_proj.weight": "model-00002-of-00004.safetensors",
72
+ "mistral_model.model.layers.15.input_layernorm.weight": "model-00002-of-00004.safetensors",
73
+ "mistral_model.model.layers.15.mlp.down_proj.weight": "model-00002-of-00004.safetensors",
74
+ "mistral_model.model.layers.15.mlp.gate_proj.weight": "model-00002-of-00004.safetensors",
75
+ "mistral_model.model.layers.15.mlp.up_proj.weight": "model-00002-of-00004.safetensors",
76
+ "mistral_model.model.layers.15.post_attention_layernorm.weight": "model-00002-of-00004.safetensors",
77
+ "mistral_model.model.layers.15.self_attn.k_proj.weight": "model-00002-of-00004.safetensors",
78
+ "mistral_model.model.layers.15.self_attn.o_proj.weight": "model-00002-of-00004.safetensors",
79
+ "mistral_model.model.layers.15.self_attn.q_proj.weight": "model-00002-of-00004.safetensors",
80
+ "mistral_model.model.layers.15.self_attn.v_proj.weight": "model-00002-of-00004.safetensors",
81
+ "mistral_model.model.layers.16.input_layernorm.weight": "model-00002-of-00004.safetensors",
82
+ "mistral_model.model.layers.16.mlp.down_proj.weight": "model-00002-of-00004.safetensors",
83
+ "mistral_model.model.layers.16.mlp.gate_proj.weight": "model-00002-of-00004.safetensors",
84
+ "mistral_model.model.layers.16.mlp.up_proj.weight": "model-00002-of-00004.safetensors",
85
+ "mistral_model.model.layers.16.post_attention_layernorm.weight": "model-00002-of-00004.safetensors",
86
+ "mistral_model.model.layers.16.self_attn.k_proj.weight": "model-00002-of-00004.safetensors",
87
+ "mistral_model.model.layers.16.self_attn.o_proj.weight": "model-00002-of-00004.safetensors",
88
+ "mistral_model.model.layers.16.self_attn.q_proj.weight": "model-00002-of-00004.safetensors",
89
+ "mistral_model.model.layers.16.self_attn.v_proj.weight": "model-00002-of-00004.safetensors",
90
+ "mistral_model.model.layers.17.input_layernorm.weight": "model-00003-of-00004.safetensors",
91
+ "mistral_model.model.layers.17.mlp.down_proj.weight": "model-00003-of-00004.safetensors",
92
+ "mistral_model.model.layers.17.mlp.gate_proj.weight": "model-00002-of-00004.safetensors",
93
+ "mistral_model.model.layers.17.mlp.up_proj.weight": "model-00002-of-00004.safetensors",
94
+ "mistral_model.model.layers.17.post_attention_layernorm.weight": "model-00003-of-00004.safetensors",
95
+ "mistral_model.model.layers.17.self_attn.k_proj.weight": "model-00002-of-00004.safetensors",
96
+ "mistral_model.model.layers.17.self_attn.o_proj.weight": "model-00002-of-00004.safetensors",
97
+ "mistral_model.model.layers.17.self_attn.q_proj.weight": "model-00002-of-00004.safetensors",
98
+ "mistral_model.model.layers.17.self_attn.v_proj.weight": "model-00002-of-00004.safetensors",
99
+ "mistral_model.model.layers.18.input_layernorm.weight": "model-00003-of-00004.safetensors",
100
+ "mistral_model.model.layers.18.mlp.down_proj.weight": "model-00003-of-00004.safetensors",
101
+ "mistral_model.model.layers.18.mlp.gate_proj.weight": "model-00003-of-00004.safetensors",
102
+ "mistral_model.model.layers.18.mlp.up_proj.weight": "model-00003-of-00004.safetensors",
103
+ "mistral_model.model.layers.18.post_attention_layernorm.weight": "model-00003-of-00004.safetensors",
104
+ "mistral_model.model.layers.18.self_attn.k_proj.weight": "model-00003-of-00004.safetensors",
105
+ "mistral_model.model.layers.18.self_attn.o_proj.weight": "model-00003-of-00004.safetensors",
106
+ "mistral_model.model.layers.18.self_attn.q_proj.weight": "model-00003-of-00004.safetensors",
107
+ "mistral_model.model.layers.18.self_attn.v_proj.weight": "model-00003-of-00004.safetensors",
108
+ "mistral_model.model.layers.19.input_layernorm.weight": "model-00003-of-00004.safetensors",
109
+ "mistral_model.model.layers.19.mlp.down_proj.weight": "model-00003-of-00004.safetensors",
110
+ "mistral_model.model.layers.19.mlp.gate_proj.weight": "model-00003-of-00004.safetensors",
111
+ "mistral_model.model.layers.19.mlp.up_proj.weight": "model-00003-of-00004.safetensors",
112
+ "mistral_model.model.layers.19.post_attention_layernorm.weight": "model-00003-of-00004.safetensors",
113
+ "mistral_model.model.layers.19.self_attn.k_proj.weight": "model-00003-of-00004.safetensors",
114
+ "mistral_model.model.layers.19.self_attn.o_proj.weight": "model-00003-of-00004.safetensors",
115
+ "mistral_model.model.layers.19.self_attn.q_proj.weight": "model-00003-of-00004.safetensors",
116
+ "mistral_model.model.layers.19.self_attn.v_proj.weight": "model-00003-of-00004.safetensors",
117
+ "mistral_model.model.layers.2.input_layernorm.weight": "model-00001-of-00004.safetensors",
118
+ "mistral_model.model.layers.2.mlp.down_proj.weight": "model-00001-of-00004.safetensors",
119
+ "mistral_model.model.layers.2.mlp.gate_proj.weight": "model-00001-of-00004.safetensors",
120
+ "mistral_model.model.layers.2.mlp.up_proj.weight": "model-00001-of-00004.safetensors",
121
+ "mistral_model.model.layers.2.post_attention_layernorm.weight": "model-00001-of-00004.safetensors",
122
+ "mistral_model.model.layers.2.self_attn.k_proj.weight": "model-00001-of-00004.safetensors",
123
+ "mistral_model.model.layers.2.self_attn.o_proj.weight": "model-00001-of-00004.safetensors",
124
+ "mistral_model.model.layers.2.self_attn.q_proj.weight": "model-00001-of-00004.safetensors",
125
+ "mistral_model.model.layers.2.self_attn.v_proj.weight": "model-00001-of-00004.safetensors",
126
+ "mistral_model.model.layers.20.input_layernorm.weight": "model-00003-of-00004.safetensors",
127
+ "mistral_model.model.layers.20.mlp.down_proj.weight": "model-00003-of-00004.safetensors",
128
+ "mistral_model.model.layers.20.mlp.gate_proj.weight": "model-00003-of-00004.safetensors",
129
+ "mistral_model.model.layers.20.mlp.up_proj.weight": "model-00003-of-00004.safetensors",
130
+ "mistral_model.model.layers.20.post_attention_layernorm.weight": "model-00003-of-00004.safetensors",
131
+ "mistral_model.model.layers.20.self_attn.k_proj.weight": "model-00003-of-00004.safetensors",
132
+ "mistral_model.model.layers.20.self_attn.o_proj.weight": "model-00003-of-00004.safetensors",
133
+ "mistral_model.model.layers.20.self_attn.q_proj.weight": "model-00003-of-00004.safetensors",
134
+ "mistral_model.model.layers.20.self_attn.v_proj.weight": "model-00003-of-00004.safetensors",
135
+ "mistral_model.model.layers.21.input_layernorm.weight": "model-00003-of-00004.safetensors",
136
+ "mistral_model.model.layers.21.mlp.down_proj.weight": "model-00003-of-00004.safetensors",
137
+ "mistral_model.model.layers.21.mlp.gate_proj.weight": "model-00003-of-00004.safetensors",
138
+ "mistral_model.model.layers.21.mlp.up_proj.weight": "model-00003-of-00004.safetensors",
139
+ "mistral_model.model.layers.21.post_attention_layernorm.weight": "model-00003-of-00004.safetensors",
140
+ "mistral_model.model.layers.21.self_attn.k_proj.weight": "model-00003-of-00004.safetensors",
141
+ "mistral_model.model.layers.21.self_attn.o_proj.weight": "model-00003-of-00004.safetensors",
142
+ "mistral_model.model.layers.21.self_attn.q_proj.weight": "model-00003-of-00004.safetensors",
143
+ "mistral_model.model.layers.21.self_attn.v_proj.weight": "model-00003-of-00004.safetensors",
144
+ "mistral_model.model.layers.22.input_layernorm.weight": "model-00003-of-00004.safetensors",
145
+ "mistral_model.model.layers.22.mlp.down_proj.weight": "model-00003-of-00004.safetensors",
146
+ "mistral_model.model.layers.22.mlp.gate_proj.weight": "model-00003-of-00004.safetensors",
147
+ "mistral_model.model.layers.22.mlp.up_proj.weight": "model-00003-of-00004.safetensors",
148
+ "mistral_model.model.layers.22.post_attention_layernorm.weight": "model-00003-of-00004.safetensors",
149
+ "mistral_model.model.layers.22.self_attn.k_proj.weight": "model-00003-of-00004.safetensors",
150
+ "mistral_model.model.layers.22.self_attn.o_proj.weight": "model-00003-of-00004.safetensors",
151
+ "mistral_model.model.layers.22.self_attn.q_proj.weight": "model-00003-of-00004.safetensors",
152
+ "mistral_model.model.layers.22.self_attn.v_proj.weight": "model-00003-of-00004.safetensors",
153
+ "mistral_model.model.layers.23.input_layernorm.weight": "model-00003-of-00004.safetensors",
154
+ "mistral_model.model.layers.23.mlp.down_proj.weight": "model-00003-of-00004.safetensors",
155
+ "mistral_model.model.layers.23.mlp.gate_proj.weight": "model-00003-of-00004.safetensors",
156
+ "mistral_model.model.layers.23.mlp.up_proj.weight": "model-00003-of-00004.safetensors",
157
+ "mistral_model.model.layers.23.post_attention_layernorm.weight": "model-00003-of-00004.safetensors",
158
+ "mistral_model.model.layers.23.self_attn.k_proj.weight": "model-00003-of-00004.safetensors",
159
+ "mistral_model.model.layers.23.self_attn.o_proj.weight": "model-00003-of-00004.safetensors",
160
+ "mistral_model.model.layers.23.self_attn.q_proj.weight": "model-00003-of-00004.safetensors",
161
+ "mistral_model.model.layers.23.self_attn.v_proj.weight": "model-00003-of-00004.safetensors",
162
+ "mistral_model.model.layers.24.input_layernorm.weight": "model-00003-of-00004.safetensors",
163
+ "mistral_model.model.layers.24.mlp.down_proj.weight": "model-00003-of-00004.safetensors",
164
+ "mistral_model.model.layers.24.mlp.gate_proj.weight": "model-00003-of-00004.safetensors",
165
+ "mistral_model.model.layers.24.mlp.up_proj.weight": "model-00003-of-00004.safetensors",
166
+ "mistral_model.model.layers.24.post_attention_layernorm.weight": "model-00003-of-00004.safetensors",
167
+ "mistral_model.model.layers.24.self_attn.k_proj.weight": "model-00003-of-00004.safetensors",
168
+ "mistral_model.model.layers.24.self_attn.o_proj.weight": "model-00003-of-00004.safetensors",
169
+ "mistral_model.model.layers.24.self_attn.q_proj.weight": "model-00003-of-00004.safetensors",
170
+ "mistral_model.model.layers.24.self_attn.v_proj.weight": "model-00003-of-00004.safetensors",
171
+ "mistral_model.model.layers.25.input_layernorm.weight": "model-00003-of-00004.safetensors",
172
+ "mistral_model.model.layers.25.mlp.down_proj.weight": "model-00003-of-00004.safetensors",
173
+ "mistral_model.model.layers.25.mlp.gate_proj.weight": "model-00003-of-00004.safetensors",
174
+ "mistral_model.model.layers.25.mlp.up_proj.weight": "model-00003-of-00004.safetensors",
175
+ "mistral_model.model.layers.25.post_attention_layernorm.weight": "model-00003-of-00004.safetensors",
176
+ "mistral_model.model.layers.25.self_attn.k_proj.weight": "model-00003-of-00004.safetensors",
177
+ "mistral_model.model.layers.25.self_attn.o_proj.weight": "model-00003-of-00004.safetensors",
178
+ "mistral_model.model.layers.25.self_attn.q_proj.weight": "model-00003-of-00004.safetensors",
179
+ "mistral_model.model.layers.25.self_attn.v_proj.weight": "model-00003-of-00004.safetensors",
180
+ "mistral_model.model.layers.26.input_layernorm.weight": "model-00003-of-00004.safetensors",
181
+ "mistral_model.model.layers.26.mlp.down_proj.weight": "model-00003-of-00004.safetensors",
182
+ "mistral_model.model.layers.26.mlp.gate_proj.weight": "model-00003-of-00004.safetensors",
183
+ "mistral_model.model.layers.26.mlp.up_proj.weight": "model-00003-of-00004.safetensors",
184
+ "mistral_model.model.layers.26.post_attention_layernorm.weight": "model-00003-of-00004.safetensors",
185
+ "mistral_model.model.layers.26.self_attn.k_proj.weight": "model-00003-of-00004.safetensors",
186
+ "mistral_model.model.layers.26.self_attn.o_proj.weight": "model-00003-of-00004.safetensors",
187
+ "mistral_model.model.layers.26.self_attn.q_proj.weight": "model-00003-of-00004.safetensors",
188
+ "mistral_model.model.layers.26.self_attn.v_proj.weight": "model-00003-of-00004.safetensors",
189
+ "mistral_model.model.layers.27.input_layernorm.weight": "model-00003-of-00004.safetensors",
190
+ "mistral_model.model.layers.27.mlp.down_proj.weight": "model-00003-of-00004.safetensors",
191
+ "mistral_model.model.layers.27.mlp.gate_proj.weight": "model-00003-of-00004.safetensors",
192
+ "mistral_model.model.layers.27.mlp.up_proj.weight": "model-00003-of-00004.safetensors",
193
+ "mistral_model.model.layers.27.post_attention_layernorm.weight": "model-00003-of-00004.safetensors",
194
+ "mistral_model.model.layers.27.self_attn.k_proj.weight": "model-00003-of-00004.safetensors",
195
+ "mistral_model.model.layers.27.self_attn.o_proj.weight": "model-00003-of-00004.safetensors",
196
+ "mistral_model.model.layers.27.self_attn.q_proj.weight": "model-00003-of-00004.safetensors",
197
+ "mistral_model.model.layers.27.self_attn.v_proj.weight": "model-00003-of-00004.safetensors",
198
+ "mistral_model.model.layers.28.input_layernorm.weight": "model-00003-of-00004.safetensors",
199
+ "mistral_model.model.layers.28.mlp.down_proj.weight": "model-00003-of-00004.safetensors",
200
+ "mistral_model.model.layers.28.mlp.gate_proj.weight": "model-00003-of-00004.safetensors",
201
+ "mistral_model.model.layers.28.mlp.up_proj.weight": "model-00003-of-00004.safetensors",
202
+ "mistral_model.model.layers.28.post_attention_layernorm.weight": "model-00003-of-00004.safetensors",
203
+ "mistral_model.model.layers.28.self_attn.k_proj.weight": "model-00003-of-00004.safetensors",
204
+ "mistral_model.model.layers.28.self_attn.o_proj.weight": "model-00003-of-00004.safetensors",
205
+ "mistral_model.model.layers.28.self_attn.q_proj.weight": "model-00003-of-00004.safetensors",
206
+ "mistral_model.model.layers.28.self_attn.v_proj.weight": "model-00003-of-00004.safetensors",
207
+ "mistral_model.model.layers.29.input_layernorm.weight": "model-00004-of-00004.safetensors",
208
+ "mistral_model.model.layers.29.mlp.down_proj.weight": "model-00004-of-00004.safetensors",
209
+ "mistral_model.model.layers.29.mlp.gate_proj.weight": "model-00004-of-00004.safetensors",
210
+ "mistral_model.model.layers.29.mlp.up_proj.weight": "model-00004-of-00004.safetensors",
211
+ "mistral_model.model.layers.29.post_attention_layernorm.weight": "model-00004-of-00004.safetensors",
212
+ "mistral_model.model.layers.29.self_attn.k_proj.weight": "model-00003-of-00004.safetensors",
213
+ "mistral_model.model.layers.29.self_attn.o_proj.weight": "model-00003-of-00004.safetensors",
214
+ "mistral_model.model.layers.29.self_attn.q_proj.weight": "model-00003-of-00004.safetensors",
215
+ "mistral_model.model.layers.29.self_attn.v_proj.weight": "model-00003-of-00004.safetensors",
216
+ "mistral_model.model.layers.3.input_layernorm.weight": "model-00001-of-00004.safetensors",
217
+ "mistral_model.model.layers.3.mlp.down_proj.weight": "model-00001-of-00004.safetensors",
218
+ "mistral_model.model.layers.3.mlp.gate_proj.weight": "model-00001-of-00004.safetensors",
219
+ "mistral_model.model.layers.3.mlp.up_proj.weight": "model-00001-of-00004.safetensors",
220
+ "mistral_model.model.layers.3.post_attention_layernorm.weight": "model-00001-of-00004.safetensors",
221
+ "mistral_model.model.layers.3.self_attn.k_proj.weight": "model-00001-of-00004.safetensors",
222
+ "mistral_model.model.layers.3.self_attn.o_proj.weight": "model-00001-of-00004.safetensors",
223
+ "mistral_model.model.layers.3.self_attn.q_proj.weight": "model-00001-of-00004.safetensors",
224
+ "mistral_model.model.layers.3.self_attn.v_proj.weight": "model-00001-of-00004.safetensors",
225
+ "mistral_model.model.layers.30.input_layernorm.weight": "model-00004-of-00004.safetensors",
226
+ "mistral_model.model.layers.30.mlp.down_proj.weight": "model-00004-of-00004.safetensors",
227
+ "mistral_model.model.layers.30.mlp.gate_proj.weight": "model-00004-of-00004.safetensors",
228
+ "mistral_model.model.layers.30.mlp.up_proj.weight": "model-00004-of-00004.safetensors",
229
+ "mistral_model.model.layers.30.post_attention_layernorm.weight": "model-00004-of-00004.safetensors",
230
+ "mistral_model.model.layers.30.self_attn.k_proj.weight": "model-00004-of-00004.safetensors",
231
+ "mistral_model.model.layers.30.self_attn.o_proj.weight": "model-00004-of-00004.safetensors",
232
+ "mistral_model.model.layers.30.self_attn.q_proj.weight": "model-00004-of-00004.safetensors",
233
+ "mistral_model.model.layers.30.self_attn.v_proj.weight": "model-00004-of-00004.safetensors",
234
+ "mistral_model.model.layers.31.input_layernorm.weight": "model-00004-of-00004.safetensors",
235
+ "mistral_model.model.layers.31.mlp.down_proj.weight": "model-00004-of-00004.safetensors",
236
+ "mistral_model.model.layers.31.mlp.gate_proj.weight": "model-00004-of-00004.safetensors",
237
+ "mistral_model.model.layers.31.mlp.up_proj.weight": "model-00004-of-00004.safetensors",
238
+ "mistral_model.model.layers.31.post_attention_layernorm.weight": "model-00004-of-00004.safetensors",
239
+ "mistral_model.model.layers.31.self_attn.k_proj.weight": "model-00004-of-00004.safetensors",
240
+ "mistral_model.model.layers.31.self_attn.o_proj.weight": "model-00004-of-00004.safetensors",
241
+ "mistral_model.model.layers.31.self_attn.q_proj.weight": "model-00004-of-00004.safetensors",
242
+ "mistral_model.model.layers.31.self_attn.v_proj.weight": "model-00004-of-00004.safetensors",
243
+ "mistral_model.model.layers.4.input_layernorm.weight": "model-00001-of-00004.safetensors",
244
+ "mistral_model.model.layers.4.mlp.down_proj.weight": "model-00001-of-00004.safetensors",
245
+ "mistral_model.model.layers.4.mlp.gate_proj.weight": "model-00001-of-00004.safetensors",
246
+ "mistral_model.model.layers.4.mlp.up_proj.weight": "model-00001-of-00004.safetensors",
247
+ "mistral_model.model.layers.4.post_attention_layernorm.weight": "model-00001-of-00004.safetensors",
248
+ "mistral_model.model.layers.4.self_attn.k_proj.weight": "model-00001-of-00004.safetensors",
249
+ "mistral_model.model.layers.4.self_attn.o_proj.weight": "model-00001-of-00004.safetensors",
250
+ "mistral_model.model.layers.4.self_attn.q_proj.weight": "model-00001-of-00004.safetensors",
251
+ "mistral_model.model.layers.4.self_attn.v_proj.weight": "model-00001-of-00004.safetensors",
252
+ "mistral_model.model.layers.5.input_layernorm.weight": "model-00001-of-00004.safetensors",
253
+ "mistral_model.model.layers.5.mlp.down_proj.weight": "model-00001-of-00004.safetensors",
254
+ "mistral_model.model.layers.5.mlp.gate_proj.weight": "model-00001-of-00004.safetensors",
255
+ "mistral_model.model.layers.5.mlp.up_proj.weight": "model-00001-of-00004.safetensors",
256
+ "mistral_model.model.layers.5.post_attention_layernorm.weight": "model-00001-of-00004.safetensors",
257
+ "mistral_model.model.layers.5.self_attn.k_proj.weight": "model-00001-of-00004.safetensors",
258
+ "mistral_model.model.layers.5.self_attn.o_proj.weight": "model-00001-of-00004.safetensors",
259
+ "mistral_model.model.layers.5.self_attn.q_proj.weight": "model-00001-of-00004.safetensors",
260
+ "mistral_model.model.layers.5.self_attn.v_proj.weight": "model-00001-of-00004.safetensors",
261
+ "mistral_model.model.layers.6.input_layernorm.weight": "model-00002-of-00004.safetensors",
262
+ "mistral_model.model.layers.6.mlp.down_proj.weight": "model-00002-of-00004.safetensors",
263
+ "mistral_model.model.layers.6.mlp.gate_proj.weight": "model-00001-of-00004.safetensors",
264
+ "mistral_model.model.layers.6.mlp.up_proj.weight": "model-00002-of-00004.safetensors",
265
+ "mistral_model.model.layers.6.post_attention_layernorm.weight": "model-00002-of-00004.safetensors",
266
+ "mistral_model.model.layers.6.self_attn.k_proj.weight": "model-00001-of-00004.safetensors",
267
+ "mistral_model.model.layers.6.self_attn.o_proj.weight": "model-00001-of-00004.safetensors",
268
+ "mistral_model.model.layers.6.self_attn.q_proj.weight": "model-00001-of-00004.safetensors",
269
+ "mistral_model.model.layers.6.self_attn.v_proj.weight": "model-00001-of-00004.safetensors",
270
+ "mistral_model.model.layers.7.input_layernorm.weight": "model-00002-of-00004.safetensors",
271
+ "mistral_model.model.layers.7.mlp.down_proj.weight": "model-00002-of-00004.safetensors",
272
+ "mistral_model.model.layers.7.mlp.gate_proj.weight": "model-00002-of-00004.safetensors",
273
+ "mistral_model.model.layers.7.mlp.up_proj.weight": "model-00002-of-00004.safetensors",
274
+ "mistral_model.model.layers.7.post_attention_layernorm.weight": "model-00002-of-00004.safetensors",
275
+ "mistral_model.model.layers.7.self_attn.k_proj.weight": "model-00002-of-00004.safetensors",
276
+ "mistral_model.model.layers.7.self_attn.o_proj.weight": "model-00002-of-00004.safetensors",
277
+ "mistral_model.model.layers.7.self_attn.q_proj.weight": "model-00002-of-00004.safetensors",
278
+ "mistral_model.model.layers.7.self_attn.v_proj.weight": "model-00002-of-00004.safetensors",
279
+ "mistral_model.model.layers.8.input_layernorm.weight": "model-00002-of-00004.safetensors",
280
+ "mistral_model.model.layers.8.mlp.down_proj.weight": "model-00002-of-00004.safetensors",
281
+ "mistral_model.model.layers.8.mlp.gate_proj.weight": "model-00002-of-00004.safetensors",
282
+ "mistral_model.model.layers.8.mlp.up_proj.weight": "model-00002-of-00004.safetensors",
283
+ "mistral_model.model.layers.8.post_attention_layernorm.weight": "model-00002-of-00004.safetensors",
284
+ "mistral_model.model.layers.8.self_attn.k_proj.weight": "model-00002-of-00004.safetensors",
285
+ "mistral_model.model.layers.8.self_attn.o_proj.weight": "model-00002-of-00004.safetensors",
286
+ "mistral_model.model.layers.8.self_attn.q_proj.weight": "model-00002-of-00004.safetensors",
287
+ "mistral_model.model.layers.8.self_attn.v_proj.weight": "model-00002-of-00004.safetensors",
288
+ "mistral_model.model.layers.9.input_layernorm.weight": "model-00002-of-00004.safetensors",
289
+ "mistral_model.model.layers.9.mlp.down_proj.weight": "model-00002-of-00004.safetensors",
290
+ "mistral_model.model.layers.9.mlp.gate_proj.weight": "model-00002-of-00004.safetensors",
291
+ "mistral_model.model.layers.9.mlp.up_proj.weight": "model-00002-of-00004.safetensors",
292
+ "mistral_model.model.layers.9.post_attention_layernorm.weight": "model-00002-of-00004.safetensors",
293
+ "mistral_model.model.layers.9.self_attn.k_proj.weight": "model-00002-of-00004.safetensors",
294
+ "mistral_model.model.layers.9.self_attn.o_proj.weight": "model-00002-of-00004.safetensors",
295
+ "mistral_model.model.layers.9.self_attn.q_proj.weight": "model-00002-of-00004.safetensors",
296
+ "mistral_model.model.layers.9.self_attn.v_proj.weight": "model-00002-of-00004.safetensors",
297
+ "mistral_model.model.norm.weight": "model-00004-of-00004.safetensors",
298
+ "mistral_proj.bias": "model-00004-of-00004.safetensors",
299
+ "mistral_proj.weight": "model-00004-of-00004.safetensors",
300
+ "qformer.bert.embeddings.LayerNorm.bias": "model-00001-of-00004.safetensors",
301
+ "qformer.bert.embeddings.LayerNorm.weight": "model-00001-of-00004.safetensors",
302
+ "qformer.bert.embeddings.position_embeddings.weight": "model-00001-of-00004.safetensors",
303
+ "qformer.bert.embeddings.position_ids": "model-00001-of-00004.safetensors",
304
+ "qformer.bert.embeddings.word_embeddings.weight": "model-00001-of-00004.safetensors",
305
+ "qformer.bert.encoder.layer.0.attention.output.LayerNorm.bias": "model-00001-of-00004.safetensors",
306
+ "qformer.bert.encoder.layer.0.attention.output.LayerNorm.weight": "model-00001-of-00004.safetensors",
307
+ "qformer.bert.encoder.layer.0.attention.output.dense.bias": "model-00001-of-00004.safetensors",
308
+ "qformer.bert.encoder.layer.0.attention.output.dense.weight": "model-00001-of-00004.safetensors",
309
+ "qformer.bert.encoder.layer.0.attention.self.key.bias": "model-00001-of-00004.safetensors",
310
+ "qformer.bert.encoder.layer.0.attention.self.key.weight": "model-00001-of-00004.safetensors",
311
+ "qformer.bert.encoder.layer.0.attention.self.query.bias": "model-00001-of-00004.safetensors",
312
+ "qformer.bert.encoder.layer.0.attention.self.query.weight": "model-00001-of-00004.safetensors",
313
+ "qformer.bert.encoder.layer.0.attention.self.value.bias": "model-00001-of-00004.safetensors",
314
+ "qformer.bert.encoder.layer.0.attention.self.value.weight": "model-00001-of-00004.safetensors",
315
+ "qformer.bert.encoder.layer.0.crossattention.output.LayerNorm.bias": "model-00001-of-00004.safetensors",
316
+ "qformer.bert.encoder.layer.0.crossattention.output.LayerNorm.weight": "model-00001-of-00004.safetensors",
317
+ "qformer.bert.encoder.layer.0.crossattention.output.dense.bias": "model-00001-of-00004.safetensors",
318
+ "qformer.bert.encoder.layer.0.crossattention.output.dense.weight": "model-00001-of-00004.safetensors",
319
+ "qformer.bert.encoder.layer.0.crossattention.self.key.bias": "model-00001-of-00004.safetensors",
320
+ "qformer.bert.encoder.layer.0.crossattention.self.key.weight": "model-00001-of-00004.safetensors",
321
+ "qformer.bert.encoder.layer.0.crossattention.self.query.bias": "model-00001-of-00004.safetensors",
322
+ "qformer.bert.encoder.layer.0.crossattention.self.query.weight": "model-00001-of-00004.safetensors",
323
+ "qformer.bert.encoder.layer.0.crossattention.self.value.bias": "model-00001-of-00004.safetensors",
324
+ "qformer.bert.encoder.layer.0.crossattention.self.value.weight": "model-00001-of-00004.safetensors",
325
+ "qformer.bert.encoder.layer.0.intermediate.dense.bias": "model-00001-of-00004.safetensors",
326
+ "qformer.bert.encoder.layer.0.intermediate.dense.weight": "model-00001-of-00004.safetensors",
327
+ "qformer.bert.encoder.layer.0.intermediate_query.dense.bias": "model-00001-of-00004.safetensors",
328
+ "qformer.bert.encoder.layer.0.intermediate_query.dense.weight": "model-00001-of-00004.safetensors",
329
+ "qformer.bert.encoder.layer.0.output.LayerNorm.bias": "model-00001-of-00004.safetensors",
330
+ "qformer.bert.encoder.layer.0.output.LayerNorm.weight": "model-00001-of-00004.safetensors",
331
+ "qformer.bert.encoder.layer.0.output.dense.bias": "model-00001-of-00004.safetensors",
332
+ "qformer.bert.encoder.layer.0.output.dense.weight": "model-00001-of-00004.safetensors",
333
+ "qformer.bert.encoder.layer.0.output_query.LayerNorm.bias": "model-00001-of-00004.safetensors",
334
+ "qformer.bert.encoder.layer.0.output_query.LayerNorm.weight": "model-00001-of-00004.safetensors",
335
+ "qformer.bert.encoder.layer.0.output_query.dense.bias": "model-00001-of-00004.safetensors",
336
+ "qformer.bert.encoder.layer.0.output_query.dense.weight": "model-00001-of-00004.safetensors",
337
+ "qformer.bert.encoder.layer.1.attention.output.LayerNorm.bias": "model-00001-of-00004.safetensors",
338
+ "qformer.bert.encoder.layer.1.attention.output.LayerNorm.weight": "model-00001-of-00004.safetensors",
339
+ "qformer.bert.encoder.layer.1.attention.output.dense.bias": "model-00001-of-00004.safetensors",
340
+ "qformer.bert.encoder.layer.1.attention.output.dense.weight": "model-00001-of-00004.safetensors",
341
+ "qformer.bert.encoder.layer.1.attention.self.key.bias": "model-00001-of-00004.safetensors",
342
+ "qformer.bert.encoder.layer.1.attention.self.key.weight": "model-00001-of-00004.safetensors",
343
+ "qformer.bert.encoder.layer.1.attention.self.query.bias": "model-00001-of-00004.safetensors",
344
+ "qformer.bert.encoder.layer.1.attention.self.query.weight": "model-00001-of-00004.safetensors",
345
+ "qformer.bert.encoder.layer.1.attention.self.value.bias": "model-00001-of-00004.safetensors",
346
+ "qformer.bert.encoder.layer.1.attention.self.value.weight": "model-00001-of-00004.safetensors",
347
+ "qformer.bert.encoder.layer.1.intermediate.dense.bias": "model-00001-of-00004.safetensors",
348
+ "qformer.bert.encoder.layer.1.intermediate.dense.weight": "model-00001-of-00004.safetensors",
349
+ "qformer.bert.encoder.layer.1.intermediate_query.dense.bias": "model-00001-of-00004.safetensors",
350
+ "qformer.bert.encoder.layer.1.intermediate_query.dense.weight": "model-00001-of-00004.safetensors",
351
+ "qformer.bert.encoder.layer.1.output.LayerNorm.bias": "model-00001-of-00004.safetensors",
352
+ "qformer.bert.encoder.layer.1.output.LayerNorm.weight": "model-00001-of-00004.safetensors",
353
+ "qformer.bert.encoder.layer.1.output.dense.bias": "model-00001-of-00004.safetensors",
354
+ "qformer.bert.encoder.layer.1.output.dense.weight": "model-00001-of-00004.safetensors",
355
+ "qformer.bert.encoder.layer.1.output_query.LayerNorm.bias": "model-00001-of-00004.safetensors",
356
+ "qformer.bert.encoder.layer.1.output_query.LayerNorm.weight": "model-00001-of-00004.safetensors",
357
+ "qformer.bert.encoder.layer.1.output_query.dense.bias": "model-00001-of-00004.safetensors",
358
+ "qformer.bert.encoder.layer.1.output_query.dense.weight": "model-00001-of-00004.safetensors",
359
+ "qformer.bert.encoder.layer.10.attention.output.LayerNorm.bias": "model-00001-of-00004.safetensors",
360
+ "qformer.bert.encoder.layer.10.attention.output.LayerNorm.weight": "model-00001-of-00004.safetensors",
361
+ "qformer.bert.encoder.layer.10.attention.output.dense.bias": "model-00001-of-00004.safetensors",
362
+ "qformer.bert.encoder.layer.10.attention.output.dense.weight": "model-00001-of-00004.safetensors",
363
+ "qformer.bert.encoder.layer.10.attention.self.key.bias": "model-00001-of-00004.safetensors",
364
+ "qformer.bert.encoder.layer.10.attention.self.key.weight": "model-00001-of-00004.safetensors",
365
+ "qformer.bert.encoder.layer.10.attention.self.query.bias": "model-00001-of-00004.safetensors",
366
+ "qformer.bert.encoder.layer.10.attention.self.query.weight": "model-00001-of-00004.safetensors",
367
+ "qformer.bert.encoder.layer.10.attention.self.value.bias": "model-00001-of-00004.safetensors",
368
+ "qformer.bert.encoder.layer.10.attention.self.value.weight": "model-00001-of-00004.safetensors",
369
+ "qformer.bert.encoder.layer.10.crossattention.output.LayerNorm.bias": "model-00001-of-00004.safetensors",
370
+ "qformer.bert.encoder.layer.10.crossattention.output.LayerNorm.weight": "model-00001-of-00004.safetensors",
371
+ "qformer.bert.encoder.layer.10.crossattention.output.dense.bias": "model-00001-of-00004.safetensors",
372
+ "qformer.bert.encoder.layer.10.crossattention.output.dense.weight": "model-00001-of-00004.safetensors",
373
+ "qformer.bert.encoder.layer.10.crossattention.self.key.bias": "model-00001-of-00004.safetensors",
374
+ "qformer.bert.encoder.layer.10.crossattention.self.key.weight": "model-00001-of-00004.safetensors",
375
+ "qformer.bert.encoder.layer.10.crossattention.self.query.bias": "model-00001-of-00004.safetensors",
376
+ "qformer.bert.encoder.layer.10.crossattention.self.query.weight": "model-00001-of-00004.safetensors",
377
+ "qformer.bert.encoder.layer.10.crossattention.self.value.bias": "model-00001-of-00004.safetensors",
378
+ "qformer.bert.encoder.layer.10.crossattention.self.value.weight": "model-00001-of-00004.safetensors",
379
+ "qformer.bert.encoder.layer.10.intermediate.dense.bias": "model-00001-of-00004.safetensors",
380
+ "qformer.bert.encoder.layer.10.intermediate.dense.weight": "model-00001-of-00004.safetensors",
381
+ "qformer.bert.encoder.layer.10.intermediate_query.dense.bias": "model-00001-of-00004.safetensors",
382
+ "qformer.bert.encoder.layer.10.intermediate_query.dense.weight": "model-00001-of-00004.safetensors",
383
+ "qformer.bert.encoder.layer.10.output.LayerNorm.bias": "model-00001-of-00004.safetensors",
384
+ "qformer.bert.encoder.layer.10.output.LayerNorm.weight": "model-00001-of-00004.safetensors",
385
+ "qformer.bert.encoder.layer.10.output.dense.bias": "model-00001-of-00004.safetensors",
386
+ "qformer.bert.encoder.layer.10.output.dense.weight": "model-00001-of-00004.safetensors",
387
+ "qformer.bert.encoder.layer.10.output_query.LayerNorm.bias": "model-00001-of-00004.safetensors",
388
+ "qformer.bert.encoder.layer.10.output_query.LayerNorm.weight": "model-00001-of-00004.safetensors",
389
+ "qformer.bert.encoder.layer.10.output_query.dense.bias": "model-00001-of-00004.safetensors",
390
+ "qformer.bert.encoder.layer.10.output_query.dense.weight": "model-00001-of-00004.safetensors",
391
+ "qformer.bert.encoder.layer.11.attention.output.LayerNorm.bias": "model-00001-of-00004.safetensors",
392
+ "qformer.bert.encoder.layer.11.attention.output.LayerNorm.weight": "model-00001-of-00004.safetensors",
393
+ "qformer.bert.encoder.layer.11.attention.output.dense.bias": "model-00001-of-00004.safetensors",
394
+ "qformer.bert.encoder.layer.11.attention.output.dense.weight": "model-00001-of-00004.safetensors",
395
+ "qformer.bert.encoder.layer.11.attention.self.key.bias": "model-00001-of-00004.safetensors",
396
+ "qformer.bert.encoder.layer.11.attention.self.key.weight": "model-00001-of-00004.safetensors",
397
+ "qformer.bert.encoder.layer.11.attention.self.query.bias": "model-00001-of-00004.safetensors",
398
+ "qformer.bert.encoder.layer.11.attention.self.query.weight": "model-00001-of-00004.safetensors",
399
+ "qformer.bert.encoder.layer.11.attention.self.value.bias": "model-00001-of-00004.safetensors",
400
+ "qformer.bert.encoder.layer.11.attention.self.value.weight": "model-00001-of-00004.safetensors",
401
+ "qformer.bert.encoder.layer.11.intermediate.dense.bias": "model-00001-of-00004.safetensors",
402
+ "qformer.bert.encoder.layer.11.intermediate.dense.weight": "model-00001-of-00004.safetensors",
403
+ "qformer.bert.encoder.layer.11.intermediate_query.dense.bias": "model-00001-of-00004.safetensors",
404
+ "qformer.bert.encoder.layer.11.intermediate_query.dense.weight": "model-00001-of-00004.safetensors",
405
+ "qformer.bert.encoder.layer.11.output.LayerNorm.bias": "model-00001-of-00004.safetensors",
406
+ "qformer.bert.encoder.layer.11.output.LayerNorm.weight": "model-00001-of-00004.safetensors",
407
+ "qformer.bert.encoder.layer.11.output.dense.bias": "model-00001-of-00004.safetensors",
408
+ "qformer.bert.encoder.layer.11.output.dense.weight": "model-00001-of-00004.safetensors",
409
+ "qformer.bert.encoder.layer.11.output_query.LayerNorm.bias": "model-00001-of-00004.safetensors",
410
+ "qformer.bert.encoder.layer.11.output_query.LayerNorm.weight": "model-00001-of-00004.safetensors",
411
+ "qformer.bert.encoder.layer.11.output_query.dense.bias": "model-00001-of-00004.safetensors",
412
+ "qformer.bert.encoder.layer.11.output_query.dense.weight": "model-00001-of-00004.safetensors",
413
+ "qformer.bert.encoder.layer.2.attention.output.LayerNorm.bias": "model-00001-of-00004.safetensors",
414
+ "qformer.bert.encoder.layer.2.attention.output.LayerNorm.weight": "model-00001-of-00004.safetensors",
415
+ "qformer.bert.encoder.layer.2.attention.output.dense.bias": "model-00001-of-00004.safetensors",
416
+ "qformer.bert.encoder.layer.2.attention.output.dense.weight": "model-00001-of-00004.safetensors",
417
+ "qformer.bert.encoder.layer.2.attention.self.key.bias": "model-00001-of-00004.safetensors",
418
+ "qformer.bert.encoder.layer.2.attention.self.key.weight": "model-00001-of-00004.safetensors",
419
+ "qformer.bert.encoder.layer.2.attention.self.query.bias": "model-00001-of-00004.safetensors",
420
+ "qformer.bert.encoder.layer.2.attention.self.query.weight": "model-00001-of-00004.safetensors",
421
+ "qformer.bert.encoder.layer.2.attention.self.value.bias": "model-00001-of-00004.safetensors",
422
+ "qformer.bert.encoder.layer.2.attention.self.value.weight": "model-00001-of-00004.safetensors",
423
+ "qformer.bert.encoder.layer.2.crossattention.output.LayerNorm.bias": "model-00001-of-00004.safetensors",
424
+ "qformer.bert.encoder.layer.2.crossattention.output.LayerNorm.weight": "model-00001-of-00004.safetensors",
425
+ "qformer.bert.encoder.layer.2.crossattention.output.dense.bias": "model-00001-of-00004.safetensors",
426
+ "qformer.bert.encoder.layer.2.crossattention.output.dense.weight": "model-00001-of-00004.safetensors",
427
+ "qformer.bert.encoder.layer.2.crossattention.self.key.bias": "model-00001-of-00004.safetensors",
428
+ "qformer.bert.encoder.layer.2.crossattention.self.key.weight": "model-00001-of-00004.safetensors",
429
+ "qformer.bert.encoder.layer.2.crossattention.self.query.bias": "model-00001-of-00004.safetensors",
430
+ "qformer.bert.encoder.layer.2.crossattention.self.query.weight": "model-00001-of-00004.safetensors",
431
+ "qformer.bert.encoder.layer.2.crossattention.self.value.bias": "model-00001-of-00004.safetensors",
432
+ "qformer.bert.encoder.layer.2.crossattention.self.value.weight": "model-00001-of-00004.safetensors",
433
+ "qformer.bert.encoder.layer.2.intermediate.dense.bias": "model-00001-of-00004.safetensors",
434
+ "qformer.bert.encoder.layer.2.intermediate.dense.weight": "model-00001-of-00004.safetensors",
435
+ "qformer.bert.encoder.layer.2.intermediate_query.dense.bias": "model-00001-of-00004.safetensors",
436
+ "qformer.bert.encoder.layer.2.intermediate_query.dense.weight": "model-00001-of-00004.safetensors",
437
+ "qformer.bert.encoder.layer.2.output.LayerNorm.bias": "model-00001-of-00004.safetensors",
438
+ "qformer.bert.encoder.layer.2.output.LayerNorm.weight": "model-00001-of-00004.safetensors",
439
+ "qformer.bert.encoder.layer.2.output.dense.bias": "model-00001-of-00004.safetensors",
440
+ "qformer.bert.encoder.layer.2.output.dense.weight": "model-00001-of-00004.safetensors",
441
+ "qformer.bert.encoder.layer.2.output_query.LayerNorm.bias": "model-00001-of-00004.safetensors",
442
+ "qformer.bert.encoder.layer.2.output_query.LayerNorm.weight": "model-00001-of-00004.safetensors",
443
+ "qformer.bert.encoder.layer.2.output_query.dense.bias": "model-00001-of-00004.safetensors",
444
+ "qformer.bert.encoder.layer.2.output_query.dense.weight": "model-00001-of-00004.safetensors",
445
+ "qformer.bert.encoder.layer.3.attention.output.LayerNorm.bias": "model-00001-of-00004.safetensors",
446
+ "qformer.bert.encoder.layer.3.attention.output.LayerNorm.weight": "model-00001-of-00004.safetensors",
447
+ "qformer.bert.encoder.layer.3.attention.output.dense.bias": "model-00001-of-00004.safetensors",
448
+ "qformer.bert.encoder.layer.3.attention.output.dense.weight": "model-00001-of-00004.safetensors",
449
+ "qformer.bert.encoder.layer.3.attention.self.key.bias": "model-00001-of-00004.safetensors",
450
+ "qformer.bert.encoder.layer.3.attention.self.key.weight": "model-00001-of-00004.safetensors",
451
+ "qformer.bert.encoder.layer.3.attention.self.query.bias": "model-00001-of-00004.safetensors",
452
+ "qformer.bert.encoder.layer.3.attention.self.query.weight": "model-00001-of-00004.safetensors",
453
+ "qformer.bert.encoder.layer.3.attention.self.value.bias": "model-00001-of-00004.safetensors",
454
+ "qformer.bert.encoder.layer.3.attention.self.value.weight": "model-00001-of-00004.safetensors",
455
+ "qformer.bert.encoder.layer.3.intermediate.dense.bias": "model-00001-of-00004.safetensors",
456
+ "qformer.bert.encoder.layer.3.intermediate.dense.weight": "model-00001-of-00004.safetensors",
457
+ "qformer.bert.encoder.layer.3.intermediate_query.dense.bias": "model-00001-of-00004.safetensors",
458
+ "qformer.bert.encoder.layer.3.intermediate_query.dense.weight": "model-00001-of-00004.safetensors",
459
+ "qformer.bert.encoder.layer.3.output.LayerNorm.bias": "model-00001-of-00004.safetensors",
460
+ "qformer.bert.encoder.layer.3.output.LayerNorm.weight": "model-00001-of-00004.safetensors",
461
+ "qformer.bert.encoder.layer.3.output.dense.bias": "model-00001-of-00004.safetensors",
462
+ "qformer.bert.encoder.layer.3.output.dense.weight": "model-00001-of-00004.safetensors",
463
+ "qformer.bert.encoder.layer.3.output_query.LayerNorm.bias": "model-00001-of-00004.safetensors",
464
+ "qformer.bert.encoder.layer.3.output_query.LayerNorm.weight": "model-00001-of-00004.safetensors",
465
+ "qformer.bert.encoder.layer.3.output_query.dense.bias": "model-00001-of-00004.safetensors",
466
+ "qformer.bert.encoder.layer.3.output_query.dense.weight": "model-00001-of-00004.safetensors",
467
+ "qformer.bert.encoder.layer.4.attention.output.LayerNorm.bias": "model-00001-of-00004.safetensors",
468
+ "qformer.bert.encoder.layer.4.attention.output.LayerNorm.weight": "model-00001-of-00004.safetensors",
469
+ "qformer.bert.encoder.layer.4.attention.output.dense.bias": "model-00001-of-00004.safetensors",
470
+ "qformer.bert.encoder.layer.4.attention.output.dense.weight": "model-00001-of-00004.safetensors",
471
+ "qformer.bert.encoder.layer.4.attention.self.key.bias": "model-00001-of-00004.safetensors",
472
+ "qformer.bert.encoder.layer.4.attention.self.key.weight": "model-00001-of-00004.safetensors",
473
+ "qformer.bert.encoder.layer.4.attention.self.query.bias": "model-00001-of-00004.safetensors",
474
+ "qformer.bert.encoder.layer.4.attention.self.query.weight": "model-00001-of-00004.safetensors",
475
+ "qformer.bert.encoder.layer.4.attention.self.value.bias": "model-00001-of-00004.safetensors",
476
+ "qformer.bert.encoder.layer.4.attention.self.value.weight": "model-00001-of-00004.safetensors",
477
+ "qformer.bert.encoder.layer.4.crossattention.output.LayerNorm.bias": "model-00001-of-00004.safetensors",
478
+ "qformer.bert.encoder.layer.4.crossattention.output.LayerNorm.weight": "model-00001-of-00004.safetensors",
479
+ "qformer.bert.encoder.layer.4.crossattention.output.dense.bias": "model-00001-of-00004.safetensors",
480
+ "qformer.bert.encoder.layer.4.crossattention.output.dense.weight": "model-00001-of-00004.safetensors",
481
+ "qformer.bert.encoder.layer.4.crossattention.self.key.bias": "model-00001-of-00004.safetensors",
482
+ "qformer.bert.encoder.layer.4.crossattention.self.key.weight": "model-00001-of-00004.safetensors",
483
+ "qformer.bert.encoder.layer.4.crossattention.self.query.bias": "model-00001-of-00004.safetensors",
484
+ "qformer.bert.encoder.layer.4.crossattention.self.query.weight": "model-00001-of-00004.safetensors",
485
+ "qformer.bert.encoder.layer.4.crossattention.self.value.bias": "model-00001-of-00004.safetensors",
486
+ "qformer.bert.encoder.layer.4.crossattention.self.value.weight": "model-00001-of-00004.safetensors",
487
+ "qformer.bert.encoder.layer.4.intermediate.dense.bias": "model-00001-of-00004.safetensors",
488
+ "qformer.bert.encoder.layer.4.intermediate.dense.weight": "model-00001-of-00004.safetensors",
489
+ "qformer.bert.encoder.layer.4.intermediate_query.dense.bias": "model-00001-of-00004.safetensors",
490
+ "qformer.bert.encoder.layer.4.intermediate_query.dense.weight": "model-00001-of-00004.safetensors",
491
+ "qformer.bert.encoder.layer.4.output.LayerNorm.bias": "model-00001-of-00004.safetensors",
492
+ "qformer.bert.encoder.layer.4.output.LayerNorm.weight": "model-00001-of-00004.safetensors",
493
+ "qformer.bert.encoder.layer.4.output.dense.bias": "model-00001-of-00004.safetensors",
494
+ "qformer.bert.encoder.layer.4.output.dense.weight": "model-00001-of-00004.safetensors",
495
+ "qformer.bert.encoder.layer.4.output_query.LayerNorm.bias": "model-00001-of-00004.safetensors",
496
+ "qformer.bert.encoder.layer.4.output_query.LayerNorm.weight": "model-00001-of-00004.safetensors",
497
+ "qformer.bert.encoder.layer.4.output_query.dense.bias": "model-00001-of-00004.safetensors",
498
+ "qformer.bert.encoder.layer.4.output_query.dense.weight": "model-00001-of-00004.safetensors",
499
+ "qformer.bert.encoder.layer.5.attention.output.LayerNorm.bias": "model-00001-of-00004.safetensors",
500
+ "qformer.bert.encoder.layer.5.attention.output.LayerNorm.weight": "model-00001-of-00004.safetensors",
501
+ "qformer.bert.encoder.layer.5.attention.output.dense.bias": "model-00001-of-00004.safetensors",
502
+ "qformer.bert.encoder.layer.5.attention.output.dense.weight": "model-00001-of-00004.safetensors",
503
+ "qformer.bert.encoder.layer.5.attention.self.key.bias": "model-00001-of-00004.safetensors",
504
+ "qformer.bert.encoder.layer.5.attention.self.key.weight": "model-00001-of-00004.safetensors",
505
+ "qformer.bert.encoder.layer.5.attention.self.query.bias": "model-00001-of-00004.safetensors",
506
+ "qformer.bert.encoder.layer.5.attention.self.query.weight": "model-00001-of-00004.safetensors",
507
+ "qformer.bert.encoder.layer.5.attention.self.value.bias": "model-00001-of-00004.safetensors",
508
+ "qformer.bert.encoder.layer.5.attention.self.value.weight": "model-00001-of-00004.safetensors",
509
+ "qformer.bert.encoder.layer.5.intermediate.dense.bias": "model-00001-of-00004.safetensors",
510
+ "qformer.bert.encoder.layer.5.intermediate.dense.weight": "model-00001-of-00004.safetensors",
511
+ "qformer.bert.encoder.layer.5.intermediate_query.dense.bias": "model-00001-of-00004.safetensors",
512
+ "qformer.bert.encoder.layer.5.intermediate_query.dense.weight": "model-00001-of-00004.safetensors",
513
+ "qformer.bert.encoder.layer.5.output.LayerNorm.bias": "model-00001-of-00004.safetensors",
514
+ "qformer.bert.encoder.layer.5.output.LayerNorm.weight": "model-00001-of-00004.safetensors",
515
+ "qformer.bert.encoder.layer.5.output.dense.bias": "model-00001-of-00004.safetensors",
516
+ "qformer.bert.encoder.layer.5.output.dense.weight": "model-00001-of-00004.safetensors",
517
+ "qformer.bert.encoder.layer.5.output_query.LayerNorm.bias": "model-00001-of-00004.safetensors",
518
+ "qformer.bert.encoder.layer.5.output_query.LayerNorm.weight": "model-00001-of-00004.safetensors",
519
+ "qformer.bert.encoder.layer.5.output_query.dense.bias": "model-00001-of-00004.safetensors",
520
+ "qformer.bert.encoder.layer.5.output_query.dense.weight": "model-00001-of-00004.safetensors",
521
+ "qformer.bert.encoder.layer.6.attention.output.LayerNorm.bias": "model-00001-of-00004.safetensors",
522
+ "qformer.bert.encoder.layer.6.attention.output.LayerNorm.weight": "model-00001-of-00004.safetensors",
523
+ "qformer.bert.encoder.layer.6.attention.output.dense.bias": "model-00001-of-00004.safetensors",
524
+ "qformer.bert.encoder.layer.6.attention.output.dense.weight": "model-00001-of-00004.safetensors",
525
+ "qformer.bert.encoder.layer.6.attention.self.key.bias": "model-00001-of-00004.safetensors",
526
+ "qformer.bert.encoder.layer.6.attention.self.key.weight": "model-00001-of-00004.safetensors",
527
+ "qformer.bert.encoder.layer.6.attention.self.query.bias": "model-00001-of-00004.safetensors",
528
+ "qformer.bert.encoder.layer.6.attention.self.query.weight": "model-00001-of-00004.safetensors",
529
+ "qformer.bert.encoder.layer.6.attention.self.value.bias": "model-00001-of-00004.safetensors",
530
+ "qformer.bert.encoder.layer.6.attention.self.value.weight": "model-00001-of-00004.safetensors",
531
+ "qformer.bert.encoder.layer.6.crossattention.output.LayerNorm.bias": "model-00001-of-00004.safetensors",
532
+ "qformer.bert.encoder.layer.6.crossattention.output.LayerNorm.weight": "model-00001-of-00004.safetensors",
533
+ "qformer.bert.encoder.layer.6.crossattention.output.dense.bias": "model-00001-of-00004.safetensors",
534
+ "qformer.bert.encoder.layer.6.crossattention.output.dense.weight": "model-00001-of-00004.safetensors",
535
+ "qformer.bert.encoder.layer.6.crossattention.self.key.bias": "model-00001-of-00004.safetensors",
536
+ "qformer.bert.encoder.layer.6.crossattention.self.key.weight": "model-00001-of-00004.safetensors",
537
+ "qformer.bert.encoder.layer.6.crossattention.self.query.bias": "model-00001-of-00004.safetensors",
538
+ "qformer.bert.encoder.layer.6.crossattention.self.query.weight": "model-00001-of-00004.safetensors",
539
+ "qformer.bert.encoder.layer.6.crossattention.self.value.bias": "model-00001-of-00004.safetensors",
540
+ "qformer.bert.encoder.layer.6.crossattention.self.value.weight": "model-00001-of-00004.safetensors",
541
+ "qformer.bert.encoder.layer.6.intermediate.dense.bias": "model-00001-of-00004.safetensors",
542
+ "qformer.bert.encoder.layer.6.intermediate.dense.weight": "model-00001-of-00004.safetensors",
543
+ "qformer.bert.encoder.layer.6.intermediate_query.dense.bias": "model-00001-of-00004.safetensors",
544
+ "qformer.bert.encoder.layer.6.intermediate_query.dense.weight": "model-00001-of-00004.safetensors",
545
+ "qformer.bert.encoder.layer.6.output.LayerNorm.bias": "model-00001-of-00004.safetensors",
546
+ "qformer.bert.encoder.layer.6.output.LayerNorm.weight": "model-00001-of-00004.safetensors",
547
+ "qformer.bert.encoder.layer.6.output.dense.bias": "model-00001-of-00004.safetensors",
548
+ "qformer.bert.encoder.layer.6.output.dense.weight": "model-00001-of-00004.safetensors",
549
+ "qformer.bert.encoder.layer.6.output_query.LayerNorm.bias": "model-00001-of-00004.safetensors",
550
+ "qformer.bert.encoder.layer.6.output_query.LayerNorm.weight": "model-00001-of-00004.safetensors",
551
+ "qformer.bert.encoder.layer.6.output_query.dense.bias": "model-00001-of-00004.safetensors",
552
+ "qformer.bert.encoder.layer.6.output_query.dense.weight": "model-00001-of-00004.safetensors",
553
+ "qformer.bert.encoder.layer.7.attention.output.LayerNorm.bias": "model-00001-of-00004.safetensors",
554
+ "qformer.bert.encoder.layer.7.attention.output.LayerNorm.weight": "model-00001-of-00004.safetensors",
555
+ "qformer.bert.encoder.layer.7.attention.output.dense.bias": "model-00001-of-00004.safetensors",
556
+ "qformer.bert.encoder.layer.7.attention.output.dense.weight": "model-00001-of-00004.safetensors",
557
+ "qformer.bert.encoder.layer.7.attention.self.key.bias": "model-00001-of-00004.safetensors",
558
+ "qformer.bert.encoder.layer.7.attention.self.key.weight": "model-00001-of-00004.safetensors",
559
+ "qformer.bert.encoder.layer.7.attention.self.query.bias": "model-00001-of-00004.safetensors",
560
+ "qformer.bert.encoder.layer.7.attention.self.query.weight": "model-00001-of-00004.safetensors",
561
+ "qformer.bert.encoder.layer.7.attention.self.value.bias": "model-00001-of-00004.safetensors",
562
+ "qformer.bert.encoder.layer.7.attention.self.value.weight": "model-00001-of-00004.safetensors",
563
+ "qformer.bert.encoder.layer.7.intermediate.dense.bias": "model-00001-of-00004.safetensors",
564
+ "qformer.bert.encoder.layer.7.intermediate.dense.weight": "model-00001-of-00004.safetensors",
565
+ "qformer.bert.encoder.layer.7.intermediate_query.dense.bias": "model-00001-of-00004.safetensors",
566
+ "qformer.bert.encoder.layer.7.intermediate_query.dense.weight": "model-00001-of-00004.safetensors",
567
+ "qformer.bert.encoder.layer.7.output.LayerNorm.bias": "model-00001-of-00004.safetensors",
568
+ "qformer.bert.encoder.layer.7.output.LayerNorm.weight": "model-00001-of-00004.safetensors",
569
+ "qformer.bert.encoder.layer.7.output.dense.bias": "model-00001-of-00004.safetensors",
570
+ "qformer.bert.encoder.layer.7.output.dense.weight": "model-00001-of-00004.safetensors",
571
+ "qformer.bert.encoder.layer.7.output_query.LayerNorm.bias": "model-00001-of-00004.safetensors",
572
+ "qformer.bert.encoder.layer.7.output_query.LayerNorm.weight": "model-00001-of-00004.safetensors",
573
+ "qformer.bert.encoder.layer.7.output_query.dense.bias": "model-00001-of-00004.safetensors",
574
+ "qformer.bert.encoder.layer.7.output_query.dense.weight": "model-00001-of-00004.safetensors",
575
+ "qformer.bert.encoder.layer.8.attention.output.LayerNorm.bias": "model-00001-of-00004.safetensors",
576
+ "qformer.bert.encoder.layer.8.attention.output.LayerNorm.weight": "model-00001-of-00004.safetensors",
577
+ "qformer.bert.encoder.layer.8.attention.output.dense.bias": "model-00001-of-00004.safetensors",
578
+ "qformer.bert.encoder.layer.8.attention.output.dense.weight": "model-00001-of-00004.safetensors",
579
+ "qformer.bert.encoder.layer.8.attention.self.key.bias": "model-00001-of-00004.safetensors",
580
+ "qformer.bert.encoder.layer.8.attention.self.key.weight": "model-00001-of-00004.safetensors",
581
+ "qformer.bert.encoder.layer.8.attention.self.query.bias": "model-00001-of-00004.safetensors",
582
+ "qformer.bert.encoder.layer.8.attention.self.query.weight": "model-00001-of-00004.safetensors",
583
+ "qformer.bert.encoder.layer.8.attention.self.value.bias": "model-00001-of-00004.safetensors",
584
+ "qformer.bert.encoder.layer.8.attention.self.value.weight": "model-00001-of-00004.safetensors",
585
+ "qformer.bert.encoder.layer.8.crossattention.output.LayerNorm.bias": "model-00001-of-00004.safetensors",
586
+ "qformer.bert.encoder.layer.8.crossattention.output.LayerNorm.weight": "model-00001-of-00004.safetensors",
587
+ "qformer.bert.encoder.layer.8.crossattention.output.dense.bias": "model-00001-of-00004.safetensors",
588
+ "qformer.bert.encoder.layer.8.crossattention.output.dense.weight": "model-00001-of-00004.safetensors",
589
+ "qformer.bert.encoder.layer.8.crossattention.self.key.bias": "model-00001-of-00004.safetensors",
590
+ "qformer.bert.encoder.layer.8.crossattention.self.key.weight": "model-00001-of-00004.safetensors",
591
+ "qformer.bert.encoder.layer.8.crossattention.self.query.bias": "model-00001-of-00004.safetensors",
592
+ "qformer.bert.encoder.layer.8.crossattention.self.query.weight": "model-00001-of-00004.safetensors",
593
+ "qformer.bert.encoder.layer.8.crossattention.self.value.bias": "model-00001-of-00004.safetensors",
594
+ "qformer.bert.encoder.layer.8.crossattention.self.value.weight": "model-00001-of-00004.safetensors",
595
+ "qformer.bert.encoder.layer.8.intermediate.dense.bias": "model-00001-of-00004.safetensors",
596
+ "qformer.bert.encoder.layer.8.intermediate.dense.weight": "model-00001-of-00004.safetensors",
597
+ "qformer.bert.encoder.layer.8.intermediate_query.dense.bias": "model-00001-of-00004.safetensors",
598
+ "qformer.bert.encoder.layer.8.intermediate_query.dense.weight": "model-00001-of-00004.safetensors",
599
+ "qformer.bert.encoder.layer.8.output.LayerNorm.bias": "model-00001-of-00004.safetensors",
600
+ "qformer.bert.encoder.layer.8.output.LayerNorm.weight": "model-00001-of-00004.safetensors",
601
+ "qformer.bert.encoder.layer.8.output.dense.bias": "model-00001-of-00004.safetensors",
602
+ "qformer.bert.encoder.layer.8.output.dense.weight": "model-00001-of-00004.safetensors",
603
+ "qformer.bert.encoder.layer.8.output_query.LayerNorm.bias": "model-00001-of-00004.safetensors",
604
+ "qformer.bert.encoder.layer.8.output_query.LayerNorm.weight": "model-00001-of-00004.safetensors",
605
+ "qformer.bert.encoder.layer.8.output_query.dense.bias": "model-00001-of-00004.safetensors",
606
+ "qformer.bert.encoder.layer.8.output_query.dense.weight": "model-00001-of-00004.safetensors",
607
+ "qformer.bert.encoder.layer.9.attention.output.LayerNorm.bias": "model-00001-of-00004.safetensors",
608
+ "qformer.bert.encoder.layer.9.attention.output.LayerNorm.weight": "model-00001-of-00004.safetensors",
609
+ "qformer.bert.encoder.layer.9.attention.output.dense.bias": "model-00001-of-00004.safetensors",
610
+ "qformer.bert.encoder.layer.9.attention.output.dense.weight": "model-00001-of-00004.safetensors",
611
+ "qformer.bert.encoder.layer.9.attention.self.key.bias": "model-00001-of-00004.safetensors",
612
+ "qformer.bert.encoder.layer.9.attention.self.key.weight": "model-00001-of-00004.safetensors",
613
+ "qformer.bert.encoder.layer.9.attention.self.query.bias": "model-00001-of-00004.safetensors",
614
+ "qformer.bert.encoder.layer.9.attention.self.query.weight": "model-00001-of-00004.safetensors",
615
+ "qformer.bert.encoder.layer.9.attention.self.value.bias": "model-00001-of-00004.safetensors",
616
+ "qformer.bert.encoder.layer.9.attention.self.value.weight": "model-00001-of-00004.safetensors",
617
+ "qformer.bert.encoder.layer.9.intermediate.dense.bias": "model-00001-of-00004.safetensors",
618
+ "qformer.bert.encoder.layer.9.intermediate.dense.weight": "model-00001-of-00004.safetensors",
619
+ "qformer.bert.encoder.layer.9.intermediate_query.dense.bias": "model-00001-of-00004.safetensors",
620
+ "qformer.bert.encoder.layer.9.intermediate_query.dense.weight": "model-00001-of-00004.safetensors",
621
+ "qformer.bert.encoder.layer.9.output.LayerNorm.bias": "model-00001-of-00004.safetensors",
622
+ "qformer.bert.encoder.layer.9.output.LayerNorm.weight": "model-00001-of-00004.safetensors",
623
+ "qformer.bert.encoder.layer.9.output.dense.bias": "model-00001-of-00004.safetensors",
624
+ "qformer.bert.encoder.layer.9.output.dense.weight": "model-00001-of-00004.safetensors",
625
+ "qformer.bert.encoder.layer.9.output_query.LayerNorm.bias": "model-00001-of-00004.safetensors",
626
+ "qformer.bert.encoder.layer.9.output_query.LayerNorm.weight": "model-00001-of-00004.safetensors",
627
+ "qformer.bert.encoder.layer.9.output_query.dense.bias": "model-00001-of-00004.safetensors",
628
+ "qformer.bert.encoder.layer.9.output_query.dense.weight": "model-00001-of-00004.safetensors",
629
+ "query_tokens": "model-00001-of-00004.safetensors",
630
+ "vision_encoder.encoder.blocks.0.attn.proj.bias": "model-00001-of-00004.safetensors",
631
+ "vision_encoder.encoder.blocks.0.attn.proj.weight": "model-00001-of-00004.safetensors",
632
+ "vision_encoder.encoder.blocks.0.attn.q_bias": "model-00001-of-00004.safetensors",
633
+ "vision_encoder.encoder.blocks.0.attn.qkv.weight": "model-00001-of-00004.safetensors",
634
+ "vision_encoder.encoder.blocks.0.attn.v_bias": "model-00001-of-00004.safetensors",
635
+ "vision_encoder.encoder.blocks.0.mlp.fc1.bias": "model-00001-of-00004.safetensors",
636
+ "vision_encoder.encoder.blocks.0.mlp.fc1.weight": "model-00001-of-00004.safetensors",
637
+ "vision_encoder.encoder.blocks.0.mlp.fc2.bias": "model-00001-of-00004.safetensors",
638
+ "vision_encoder.encoder.blocks.0.mlp.fc2.weight": "model-00001-of-00004.safetensors",
639
+ "vision_encoder.encoder.blocks.0.norm1.bias": "model-00001-of-00004.safetensors",
640
+ "vision_encoder.encoder.blocks.0.norm1.weight": "model-00001-of-00004.safetensors",
641
+ "vision_encoder.encoder.blocks.0.norm2.bias": "model-00001-of-00004.safetensors",
642
+ "vision_encoder.encoder.blocks.0.norm2.weight": "model-00001-of-00004.safetensors",
643
+ "vision_encoder.encoder.blocks.1.attn.proj.bias": "model-00001-of-00004.safetensors",
644
+ "vision_encoder.encoder.blocks.1.attn.proj.weight": "model-00001-of-00004.safetensors",
645
+ "vision_encoder.encoder.blocks.1.attn.q_bias": "model-00001-of-00004.safetensors",
646
+ "vision_encoder.encoder.blocks.1.attn.qkv.weight": "model-00001-of-00004.safetensors",
647
+ "vision_encoder.encoder.blocks.1.attn.v_bias": "model-00001-of-00004.safetensors",
648
+ "vision_encoder.encoder.blocks.1.mlp.fc1.bias": "model-00001-of-00004.safetensors",
649
+ "vision_encoder.encoder.blocks.1.mlp.fc1.weight": "model-00001-of-00004.safetensors",
650
+ "vision_encoder.encoder.blocks.1.mlp.fc2.bias": "model-00001-of-00004.safetensors",
651
+ "vision_encoder.encoder.blocks.1.mlp.fc2.weight": "model-00001-of-00004.safetensors",
652
+ "vision_encoder.encoder.blocks.1.norm1.bias": "model-00001-of-00004.safetensors",
653
+ "vision_encoder.encoder.blocks.1.norm1.weight": "model-00001-of-00004.safetensors",
654
+ "vision_encoder.encoder.blocks.1.norm2.bias": "model-00001-of-00004.safetensors",
655
+ "vision_encoder.encoder.blocks.1.norm2.weight": "model-00001-of-00004.safetensors",
656
+ "vision_encoder.encoder.blocks.10.attn.proj.bias": "model-00001-of-00004.safetensors",
657
+ "vision_encoder.encoder.blocks.10.attn.proj.weight": "model-00001-of-00004.safetensors",
658
+ "vision_encoder.encoder.blocks.10.attn.q_bias": "model-00001-of-00004.safetensors",
659
+ "vision_encoder.encoder.blocks.10.attn.qkv.weight": "model-00001-of-00004.safetensors",
660
+ "vision_encoder.encoder.blocks.10.attn.v_bias": "model-00001-of-00004.safetensors",
661
+ "vision_encoder.encoder.blocks.10.mlp.fc1.bias": "model-00001-of-00004.safetensors",
662
+ "vision_encoder.encoder.blocks.10.mlp.fc1.weight": "model-00001-of-00004.safetensors",
663
+ "vision_encoder.encoder.blocks.10.mlp.fc2.bias": "model-00001-of-00004.safetensors",
664
+ "vision_encoder.encoder.blocks.10.mlp.fc2.weight": "model-00001-of-00004.safetensors",
665
+ "vision_encoder.encoder.blocks.10.norm1.bias": "model-00001-of-00004.safetensors",
666
+ "vision_encoder.encoder.blocks.10.norm1.weight": "model-00001-of-00004.safetensors",
667
+ "vision_encoder.encoder.blocks.10.norm2.bias": "model-00001-of-00004.safetensors",
668
+ "vision_encoder.encoder.blocks.10.norm2.weight": "model-00001-of-00004.safetensors",
669
+ "vision_encoder.encoder.blocks.11.attn.proj.bias": "model-00001-of-00004.safetensors",
670
+ "vision_encoder.encoder.blocks.11.attn.proj.weight": "model-00001-of-00004.safetensors",
671
+ "vision_encoder.encoder.blocks.11.attn.q_bias": "model-00001-of-00004.safetensors",
672
+ "vision_encoder.encoder.blocks.11.attn.qkv.weight": "model-00001-of-00004.safetensors",
673
+ "vision_encoder.encoder.blocks.11.attn.v_bias": "model-00001-of-00004.safetensors",
674
+ "vision_encoder.encoder.blocks.11.mlp.fc1.bias": "model-00001-of-00004.safetensors",
675
+ "vision_encoder.encoder.blocks.11.mlp.fc1.weight": "model-00001-of-00004.safetensors",
676
+ "vision_encoder.encoder.blocks.11.mlp.fc2.bias": "model-00001-of-00004.safetensors",
677
+ "vision_encoder.encoder.blocks.11.mlp.fc2.weight": "model-00001-of-00004.safetensors",
678
+ "vision_encoder.encoder.blocks.11.norm1.bias": "model-00001-of-00004.safetensors",
679
+ "vision_encoder.encoder.blocks.11.norm1.weight": "model-00001-of-00004.safetensors",
680
+ "vision_encoder.encoder.blocks.11.norm2.bias": "model-00001-of-00004.safetensors",
681
+ "vision_encoder.encoder.blocks.11.norm2.weight": "model-00001-of-00004.safetensors",
682
+ "vision_encoder.encoder.blocks.12.attn.proj.bias": "model-00001-of-00004.safetensors",
683
+ "vision_encoder.encoder.blocks.12.attn.proj.weight": "model-00001-of-00004.safetensors",
684
+ "vision_encoder.encoder.blocks.12.attn.q_bias": "model-00001-of-00004.safetensors",
685
+ "vision_encoder.encoder.blocks.12.attn.qkv.weight": "model-00001-of-00004.safetensors",
686
+ "vision_encoder.encoder.blocks.12.attn.v_bias": "model-00001-of-00004.safetensors",
687
+ "vision_encoder.encoder.blocks.12.mlp.fc1.bias": "model-00001-of-00004.safetensors",
688
+ "vision_encoder.encoder.blocks.12.mlp.fc1.weight": "model-00001-of-00004.safetensors",
689
+ "vision_encoder.encoder.blocks.12.mlp.fc2.bias": "model-00001-of-00004.safetensors",
690
+ "vision_encoder.encoder.blocks.12.mlp.fc2.weight": "model-00001-of-00004.safetensors",
691
+ "vision_encoder.encoder.blocks.12.norm1.bias": "model-00001-of-00004.safetensors",
692
+ "vision_encoder.encoder.blocks.12.norm1.weight": "model-00001-of-00004.safetensors",
693
+ "vision_encoder.encoder.blocks.12.norm2.bias": "model-00001-of-00004.safetensors",
694
+ "vision_encoder.encoder.blocks.12.norm2.weight": "model-00001-of-00004.safetensors",
695
+ "vision_encoder.encoder.blocks.13.attn.proj.bias": "model-00001-of-00004.safetensors",
696
+ "vision_encoder.encoder.blocks.13.attn.proj.weight": "model-00001-of-00004.safetensors",
697
+ "vision_encoder.encoder.blocks.13.attn.q_bias": "model-00001-of-00004.safetensors",
698
+ "vision_encoder.encoder.blocks.13.attn.qkv.weight": "model-00001-of-00004.safetensors",
699
+ "vision_encoder.encoder.blocks.13.attn.v_bias": "model-00001-of-00004.safetensors",
700
+ "vision_encoder.encoder.blocks.13.mlp.fc1.bias": "model-00001-of-00004.safetensors",
701
+ "vision_encoder.encoder.blocks.13.mlp.fc1.weight": "model-00001-of-00004.safetensors",
702
+ "vision_encoder.encoder.blocks.13.mlp.fc2.bias": "model-00001-of-00004.safetensors",
703
+ "vision_encoder.encoder.blocks.13.mlp.fc2.weight": "model-00001-of-00004.safetensors",
704
+ "vision_encoder.encoder.blocks.13.norm1.bias": "model-00001-of-00004.safetensors",
705
+ "vision_encoder.encoder.blocks.13.norm1.weight": "model-00001-of-00004.safetensors",
706
+ "vision_encoder.encoder.blocks.13.norm2.bias": "model-00001-of-00004.safetensors",
707
+ "vision_encoder.encoder.blocks.13.norm2.weight": "model-00001-of-00004.safetensors",
708
+ "vision_encoder.encoder.blocks.14.attn.proj.bias": "model-00001-of-00004.safetensors",
709
+ "vision_encoder.encoder.blocks.14.attn.proj.weight": "model-00001-of-00004.safetensors",
710
+ "vision_encoder.encoder.blocks.14.attn.q_bias": "model-00001-of-00004.safetensors",
711
+ "vision_encoder.encoder.blocks.14.attn.qkv.weight": "model-00001-of-00004.safetensors",
712
+ "vision_encoder.encoder.blocks.14.attn.v_bias": "model-00001-of-00004.safetensors",
713
+ "vision_encoder.encoder.blocks.14.mlp.fc1.bias": "model-00001-of-00004.safetensors",
714
+ "vision_encoder.encoder.blocks.14.mlp.fc1.weight": "model-00001-of-00004.safetensors",
715
+ "vision_encoder.encoder.blocks.14.mlp.fc2.bias": "model-00001-of-00004.safetensors",
716
+ "vision_encoder.encoder.blocks.14.mlp.fc2.weight": "model-00001-of-00004.safetensors",
717
+ "vision_encoder.encoder.blocks.14.norm1.bias": "model-00001-of-00004.safetensors",
718
+ "vision_encoder.encoder.blocks.14.norm1.weight": "model-00001-of-00004.safetensors",
719
+ "vision_encoder.encoder.blocks.14.norm2.bias": "model-00001-of-00004.safetensors",
720
+ "vision_encoder.encoder.blocks.14.norm2.weight": "model-00001-of-00004.safetensors",
721
+ "vision_encoder.encoder.blocks.15.attn.proj.bias": "model-00001-of-00004.safetensors",
722
+ "vision_encoder.encoder.blocks.15.attn.proj.weight": "model-00001-of-00004.safetensors",
723
+ "vision_encoder.encoder.blocks.15.attn.q_bias": "model-00001-of-00004.safetensors",
724
+ "vision_encoder.encoder.blocks.15.attn.qkv.weight": "model-00001-of-00004.safetensors",
725
+ "vision_encoder.encoder.blocks.15.attn.v_bias": "model-00001-of-00004.safetensors",
726
+ "vision_encoder.encoder.blocks.15.mlp.fc1.bias": "model-00001-of-00004.safetensors",
727
+ "vision_encoder.encoder.blocks.15.mlp.fc1.weight": "model-00001-of-00004.safetensors",
728
+ "vision_encoder.encoder.blocks.15.mlp.fc2.bias": "model-00001-of-00004.safetensors",
729
+ "vision_encoder.encoder.blocks.15.mlp.fc2.weight": "model-00001-of-00004.safetensors",
730
+ "vision_encoder.encoder.blocks.15.norm1.bias": "model-00001-of-00004.safetensors",
731
+ "vision_encoder.encoder.blocks.15.norm1.weight": "model-00001-of-00004.safetensors",
732
+ "vision_encoder.encoder.blocks.15.norm2.bias": "model-00001-of-00004.safetensors",
733
+ "vision_encoder.encoder.blocks.15.norm2.weight": "model-00001-of-00004.safetensors",
734
+ "vision_encoder.encoder.blocks.16.attn.proj.bias": "model-00001-of-00004.safetensors",
735
+ "vision_encoder.encoder.blocks.16.attn.proj.weight": "model-00001-of-00004.safetensors",
736
+ "vision_encoder.encoder.blocks.16.attn.q_bias": "model-00001-of-00004.safetensors",
737
+ "vision_encoder.encoder.blocks.16.attn.qkv.weight": "model-00001-of-00004.safetensors",
738
+ "vision_encoder.encoder.blocks.16.attn.v_bias": "model-00001-of-00004.safetensors",
739
+ "vision_encoder.encoder.blocks.16.mlp.fc1.bias": "model-00001-of-00004.safetensors",
740
+ "vision_encoder.encoder.blocks.16.mlp.fc1.weight": "model-00001-of-00004.safetensors",
741
+ "vision_encoder.encoder.blocks.16.mlp.fc2.bias": "model-00001-of-00004.safetensors",
742
+ "vision_encoder.encoder.blocks.16.mlp.fc2.weight": "model-00001-of-00004.safetensors",
743
+ "vision_encoder.encoder.blocks.16.norm1.bias": "model-00001-of-00004.safetensors",
744
+ "vision_encoder.encoder.blocks.16.norm1.weight": "model-00001-of-00004.safetensors",
745
+ "vision_encoder.encoder.blocks.16.norm2.bias": "model-00001-of-00004.safetensors",
746
+ "vision_encoder.encoder.blocks.16.norm2.weight": "model-00001-of-00004.safetensors",
747
+ "vision_encoder.encoder.blocks.17.attn.proj.bias": "model-00001-of-00004.safetensors",
748
+ "vision_encoder.encoder.blocks.17.attn.proj.weight": "model-00001-of-00004.safetensors",
749
+ "vision_encoder.encoder.blocks.17.attn.q_bias": "model-00001-of-00004.safetensors",
750
+ "vision_encoder.encoder.blocks.17.attn.qkv.weight": "model-00001-of-00004.safetensors",
751
+ "vision_encoder.encoder.blocks.17.attn.v_bias": "model-00001-of-00004.safetensors",
752
+ "vision_encoder.encoder.blocks.17.mlp.fc1.bias": "model-00001-of-00004.safetensors",
753
+ "vision_encoder.encoder.blocks.17.mlp.fc1.weight": "model-00001-of-00004.safetensors",
754
+ "vision_encoder.encoder.blocks.17.mlp.fc2.bias": "model-00001-of-00004.safetensors",
755
+ "vision_encoder.encoder.blocks.17.mlp.fc2.weight": "model-00001-of-00004.safetensors",
756
+ "vision_encoder.encoder.blocks.17.norm1.bias": "model-00001-of-00004.safetensors",
757
+ "vision_encoder.encoder.blocks.17.norm1.weight": "model-00001-of-00004.safetensors",
758
+ "vision_encoder.encoder.blocks.17.norm2.bias": "model-00001-of-00004.safetensors",
759
+ "vision_encoder.encoder.blocks.17.norm2.weight": "model-00001-of-00004.safetensors",
760
+ "vision_encoder.encoder.blocks.18.attn.proj.bias": "model-00001-of-00004.safetensors",
761
+ "vision_encoder.encoder.blocks.18.attn.proj.weight": "model-00001-of-00004.safetensors",
762
+ "vision_encoder.encoder.blocks.18.attn.q_bias": "model-00001-of-00004.safetensors",
763
+ "vision_encoder.encoder.blocks.18.attn.qkv.weight": "model-00001-of-00004.safetensors",
764
+ "vision_encoder.encoder.blocks.18.attn.v_bias": "model-00001-of-00004.safetensors",
765
+ "vision_encoder.encoder.blocks.18.mlp.fc1.bias": "model-00001-of-00004.safetensors",
766
+ "vision_encoder.encoder.blocks.18.mlp.fc1.weight": "model-00001-of-00004.safetensors",
767
+ "vision_encoder.encoder.blocks.18.mlp.fc2.bias": "model-00001-of-00004.safetensors",
768
+ "vision_encoder.encoder.blocks.18.mlp.fc2.weight": "model-00001-of-00004.safetensors",
769
+ "vision_encoder.encoder.blocks.18.norm1.bias": "model-00001-of-00004.safetensors",
770
+ "vision_encoder.encoder.blocks.18.norm1.weight": "model-00001-of-00004.safetensors",
771
+ "vision_encoder.encoder.blocks.18.norm2.bias": "model-00001-of-00004.safetensors",
772
+ "vision_encoder.encoder.blocks.18.norm2.weight": "model-00001-of-00004.safetensors",
773
+ "vision_encoder.encoder.blocks.19.attn.proj.bias": "model-00001-of-00004.safetensors",
774
+ "vision_encoder.encoder.blocks.19.attn.proj.weight": "model-00001-of-00004.safetensors",
775
+ "vision_encoder.encoder.blocks.19.attn.q_bias": "model-00001-of-00004.safetensors",
776
+ "vision_encoder.encoder.blocks.19.attn.qkv.weight": "model-00001-of-00004.safetensors",
777
+ "vision_encoder.encoder.blocks.19.attn.v_bias": "model-00001-of-00004.safetensors",
778
+ "vision_encoder.encoder.blocks.19.mlp.fc1.bias": "model-00001-of-00004.safetensors",
779
+ "vision_encoder.encoder.blocks.19.mlp.fc1.weight": "model-00001-of-00004.safetensors",
780
+ "vision_encoder.encoder.blocks.19.mlp.fc2.bias": "model-00001-of-00004.safetensors",
781
+ "vision_encoder.encoder.blocks.19.mlp.fc2.weight": "model-00001-of-00004.safetensors",
782
+ "vision_encoder.encoder.blocks.19.norm1.bias": "model-00001-of-00004.safetensors",
783
+ "vision_encoder.encoder.blocks.19.norm1.weight": "model-00001-of-00004.safetensors",
784
+ "vision_encoder.encoder.blocks.19.norm2.bias": "model-00001-of-00004.safetensors",
785
+ "vision_encoder.encoder.blocks.19.norm2.weight": "model-00001-of-00004.safetensors",
786
+ "vision_encoder.encoder.blocks.2.attn.proj.bias": "model-00001-of-00004.safetensors",
787
+ "vision_encoder.encoder.blocks.2.attn.proj.weight": "model-00001-of-00004.safetensors",
788
+ "vision_encoder.encoder.blocks.2.attn.q_bias": "model-00001-of-00004.safetensors",
789
+ "vision_encoder.encoder.blocks.2.attn.qkv.weight": "model-00001-of-00004.safetensors",
790
+ "vision_encoder.encoder.blocks.2.attn.v_bias": "model-00001-of-00004.safetensors",
791
+ "vision_encoder.encoder.blocks.2.mlp.fc1.bias": "model-00001-of-00004.safetensors",
792
+ "vision_encoder.encoder.blocks.2.mlp.fc1.weight": "model-00001-of-00004.safetensors",
793
+ "vision_encoder.encoder.blocks.2.mlp.fc2.bias": "model-00001-of-00004.safetensors",
794
+ "vision_encoder.encoder.blocks.2.mlp.fc2.weight": "model-00001-of-00004.safetensors",
795
+ "vision_encoder.encoder.blocks.2.norm1.bias": "model-00001-of-00004.safetensors",
796
+ "vision_encoder.encoder.blocks.2.norm1.weight": "model-00001-of-00004.safetensors",
797
+ "vision_encoder.encoder.blocks.2.norm2.bias": "model-00001-of-00004.safetensors",
798
+ "vision_encoder.encoder.blocks.2.norm2.weight": "model-00001-of-00004.safetensors",
799
+ "vision_encoder.encoder.blocks.20.attn.proj.bias": "model-00001-of-00004.safetensors",
800
+ "vision_encoder.encoder.blocks.20.attn.proj.weight": "model-00001-of-00004.safetensors",
801
+ "vision_encoder.encoder.blocks.20.attn.q_bias": "model-00001-of-00004.safetensors",
802
+ "vision_encoder.encoder.blocks.20.attn.qkv.weight": "model-00001-of-00004.safetensors",
803
+ "vision_encoder.encoder.blocks.20.attn.v_bias": "model-00001-of-00004.safetensors",
804
+ "vision_encoder.encoder.blocks.20.mlp.fc1.bias": "model-00001-of-00004.safetensors",
805
+ "vision_encoder.encoder.blocks.20.mlp.fc1.weight": "model-00001-of-00004.safetensors",
806
+ "vision_encoder.encoder.blocks.20.mlp.fc2.bias": "model-00001-of-00004.safetensors",
807
+ "vision_encoder.encoder.blocks.20.mlp.fc2.weight": "model-00001-of-00004.safetensors",
808
+ "vision_encoder.encoder.blocks.20.norm1.bias": "model-00001-of-00004.safetensors",
809
+ "vision_encoder.encoder.blocks.20.norm1.weight": "model-00001-of-00004.safetensors",
810
+ "vision_encoder.encoder.blocks.20.norm2.bias": "model-00001-of-00004.safetensors",
811
+ "vision_encoder.encoder.blocks.20.norm2.weight": "model-00001-of-00004.safetensors",
812
+ "vision_encoder.encoder.blocks.21.attn.proj.bias": "model-00001-of-00004.safetensors",
813
+ "vision_encoder.encoder.blocks.21.attn.proj.weight": "model-00001-of-00004.safetensors",
814
+ "vision_encoder.encoder.blocks.21.attn.q_bias": "model-00001-of-00004.safetensors",
815
+ "vision_encoder.encoder.blocks.21.attn.qkv.weight": "model-00001-of-00004.safetensors",
816
+ "vision_encoder.encoder.blocks.21.attn.v_bias": "model-00001-of-00004.safetensors",
817
+ "vision_encoder.encoder.blocks.21.mlp.fc1.bias": "model-00001-of-00004.safetensors",
818
+ "vision_encoder.encoder.blocks.21.mlp.fc1.weight": "model-00001-of-00004.safetensors",
819
+ "vision_encoder.encoder.blocks.21.mlp.fc2.bias": "model-00001-of-00004.safetensors",
820
+ "vision_encoder.encoder.blocks.21.mlp.fc2.weight": "model-00001-of-00004.safetensors",
821
+ "vision_encoder.encoder.blocks.21.norm1.bias": "model-00001-of-00004.safetensors",
822
+ "vision_encoder.encoder.blocks.21.norm1.weight": "model-00001-of-00004.safetensors",
823
+ "vision_encoder.encoder.blocks.21.norm2.bias": "model-00001-of-00004.safetensors",
824
+ "vision_encoder.encoder.blocks.21.norm2.weight": "model-00001-of-00004.safetensors",
825
+ "vision_encoder.encoder.blocks.22.attn.proj.bias": "model-00001-of-00004.safetensors",
826
+ "vision_encoder.encoder.blocks.22.attn.proj.weight": "model-00001-of-00004.safetensors",
827
+ "vision_encoder.encoder.blocks.22.attn.q_bias": "model-00001-of-00004.safetensors",
828
+ "vision_encoder.encoder.blocks.22.attn.qkv.weight": "model-00001-of-00004.safetensors",
829
+ "vision_encoder.encoder.blocks.22.attn.v_bias": "model-00001-of-00004.safetensors",
830
+ "vision_encoder.encoder.blocks.22.mlp.fc1.bias": "model-00001-of-00004.safetensors",
831
+ "vision_encoder.encoder.blocks.22.mlp.fc1.weight": "model-00001-of-00004.safetensors",
832
+ "vision_encoder.encoder.blocks.22.mlp.fc2.bias": "model-00001-of-00004.safetensors",
833
+ "vision_encoder.encoder.blocks.22.mlp.fc2.weight": "model-00001-of-00004.safetensors",
834
+ "vision_encoder.encoder.blocks.22.norm1.bias": "model-00001-of-00004.safetensors",
835
+ "vision_encoder.encoder.blocks.22.norm1.weight": "model-00001-of-00004.safetensors",
836
+ "vision_encoder.encoder.blocks.22.norm2.bias": "model-00001-of-00004.safetensors",
837
+ "vision_encoder.encoder.blocks.22.norm2.weight": "model-00001-of-00004.safetensors",
838
+ "vision_encoder.encoder.blocks.3.attn.proj.bias": "model-00001-of-00004.safetensors",
839
+ "vision_encoder.encoder.blocks.3.attn.proj.weight": "model-00001-of-00004.safetensors",
840
+ "vision_encoder.encoder.blocks.3.attn.q_bias": "model-00001-of-00004.safetensors",
841
+ "vision_encoder.encoder.blocks.3.attn.qkv.weight": "model-00001-of-00004.safetensors",
842
+ "vision_encoder.encoder.blocks.3.attn.v_bias": "model-00001-of-00004.safetensors",
843
+ "vision_encoder.encoder.blocks.3.mlp.fc1.bias": "model-00001-of-00004.safetensors",
844
+ "vision_encoder.encoder.blocks.3.mlp.fc1.weight": "model-00001-of-00004.safetensors",
845
+ "vision_encoder.encoder.blocks.3.mlp.fc2.bias": "model-00001-of-00004.safetensors",
846
+ "vision_encoder.encoder.blocks.3.mlp.fc2.weight": "model-00001-of-00004.safetensors",
847
+ "vision_encoder.encoder.blocks.3.norm1.bias": "model-00001-of-00004.safetensors",
848
+ "vision_encoder.encoder.blocks.3.norm1.weight": "model-00001-of-00004.safetensors",
849
+ "vision_encoder.encoder.blocks.3.norm2.bias": "model-00001-of-00004.safetensors",
850
+ "vision_encoder.encoder.blocks.3.norm2.weight": "model-00001-of-00004.safetensors",
851
+ "vision_encoder.encoder.blocks.4.attn.proj.bias": "model-00001-of-00004.safetensors",
852
+ "vision_encoder.encoder.blocks.4.attn.proj.weight": "model-00001-of-00004.safetensors",
853
+ "vision_encoder.encoder.blocks.4.attn.q_bias": "model-00001-of-00004.safetensors",
854
+ "vision_encoder.encoder.blocks.4.attn.qkv.weight": "model-00001-of-00004.safetensors",
855
+ "vision_encoder.encoder.blocks.4.attn.v_bias": "model-00001-of-00004.safetensors",
856
+ "vision_encoder.encoder.blocks.4.mlp.fc1.bias": "model-00001-of-00004.safetensors",
857
+ "vision_encoder.encoder.blocks.4.mlp.fc1.weight": "model-00001-of-00004.safetensors",
858
+ "vision_encoder.encoder.blocks.4.mlp.fc2.bias": "model-00001-of-00004.safetensors",
859
+ "vision_encoder.encoder.blocks.4.mlp.fc2.weight": "model-00001-of-00004.safetensors",
860
+ "vision_encoder.encoder.blocks.4.norm1.bias": "model-00001-of-00004.safetensors",
861
+ "vision_encoder.encoder.blocks.4.norm1.weight": "model-00001-of-00004.safetensors",
862
+ "vision_encoder.encoder.blocks.4.norm2.bias": "model-00001-of-00004.safetensors",
863
+ "vision_encoder.encoder.blocks.4.norm2.weight": "model-00001-of-00004.safetensors",
864
+ "vision_encoder.encoder.blocks.5.attn.proj.bias": "model-00001-of-00004.safetensors",
865
+ "vision_encoder.encoder.blocks.5.attn.proj.weight": "model-00001-of-00004.safetensors",
866
+ "vision_encoder.encoder.blocks.5.attn.q_bias": "model-00001-of-00004.safetensors",
867
+ "vision_encoder.encoder.blocks.5.attn.qkv.weight": "model-00001-of-00004.safetensors",
868
+ "vision_encoder.encoder.blocks.5.attn.v_bias": "model-00001-of-00004.safetensors",
869
+ "vision_encoder.encoder.blocks.5.mlp.fc1.bias": "model-00001-of-00004.safetensors",
870
+ "vision_encoder.encoder.blocks.5.mlp.fc1.weight": "model-00001-of-00004.safetensors",
871
+ "vision_encoder.encoder.blocks.5.mlp.fc2.bias": "model-00001-of-00004.safetensors",
872
+ "vision_encoder.encoder.blocks.5.mlp.fc2.weight": "model-00001-of-00004.safetensors",
873
+ "vision_encoder.encoder.blocks.5.norm1.bias": "model-00001-of-00004.safetensors",
874
+ "vision_encoder.encoder.blocks.5.norm1.weight": "model-00001-of-00004.safetensors",
875
+ "vision_encoder.encoder.blocks.5.norm2.bias": "model-00001-of-00004.safetensors",
876
+ "vision_encoder.encoder.blocks.5.norm2.weight": "model-00001-of-00004.safetensors",
877
+ "vision_encoder.encoder.blocks.6.attn.proj.bias": "model-00001-of-00004.safetensors",
878
+ "vision_encoder.encoder.blocks.6.attn.proj.weight": "model-00001-of-00004.safetensors",
879
+ "vision_encoder.encoder.blocks.6.attn.q_bias": "model-00001-of-00004.safetensors",
880
+ "vision_encoder.encoder.blocks.6.attn.qkv.weight": "model-00001-of-00004.safetensors",
881
+ "vision_encoder.encoder.blocks.6.attn.v_bias": "model-00001-of-00004.safetensors",
882
+ "vision_encoder.encoder.blocks.6.mlp.fc1.bias": "model-00001-of-00004.safetensors",
883
+ "vision_encoder.encoder.blocks.6.mlp.fc1.weight": "model-00001-of-00004.safetensors",
884
+ "vision_encoder.encoder.blocks.6.mlp.fc2.bias": "model-00001-of-00004.safetensors",
885
+ "vision_encoder.encoder.blocks.6.mlp.fc2.weight": "model-00001-of-00004.safetensors",
886
+ "vision_encoder.encoder.blocks.6.norm1.bias": "model-00001-of-00004.safetensors",
887
+ "vision_encoder.encoder.blocks.6.norm1.weight": "model-00001-of-00004.safetensors",
888
+ "vision_encoder.encoder.blocks.6.norm2.bias": "model-00001-of-00004.safetensors",
889
+ "vision_encoder.encoder.blocks.6.norm2.weight": "model-00001-of-00004.safetensors",
890
+ "vision_encoder.encoder.blocks.7.attn.proj.bias": "model-00001-of-00004.safetensors",
891
+ "vision_encoder.encoder.blocks.7.attn.proj.weight": "model-00001-of-00004.safetensors",
892
+ "vision_encoder.encoder.blocks.7.attn.q_bias": "model-00001-of-00004.safetensors",
893
+ "vision_encoder.encoder.blocks.7.attn.qkv.weight": "model-00001-of-00004.safetensors",
894
+ "vision_encoder.encoder.blocks.7.attn.v_bias": "model-00001-of-00004.safetensors",
895
+ "vision_encoder.encoder.blocks.7.mlp.fc1.bias": "model-00001-of-00004.safetensors",
896
+ "vision_encoder.encoder.blocks.7.mlp.fc1.weight": "model-00001-of-00004.safetensors",
897
+ "vision_encoder.encoder.blocks.7.mlp.fc2.bias": "model-00001-of-00004.safetensors",
898
+ "vision_encoder.encoder.blocks.7.mlp.fc2.weight": "model-00001-of-00004.safetensors",
899
+ "vision_encoder.encoder.blocks.7.norm1.bias": "model-00001-of-00004.safetensors",
900
+ "vision_encoder.encoder.blocks.7.norm1.weight": "model-00001-of-00004.safetensors",
901
+ "vision_encoder.encoder.blocks.7.norm2.bias": "model-00001-of-00004.safetensors",
902
+ "vision_encoder.encoder.blocks.7.norm2.weight": "model-00001-of-00004.safetensors",
903
+ "vision_encoder.encoder.blocks.8.attn.proj.bias": "model-00001-of-00004.safetensors",
904
+ "vision_encoder.encoder.blocks.8.attn.proj.weight": "model-00001-of-00004.safetensors",
905
+ "vision_encoder.encoder.blocks.8.attn.q_bias": "model-00001-of-00004.safetensors",
906
+ "vision_encoder.encoder.blocks.8.attn.qkv.weight": "model-00001-of-00004.safetensors",
907
+ "vision_encoder.encoder.blocks.8.attn.v_bias": "model-00001-of-00004.safetensors",
908
+ "vision_encoder.encoder.blocks.8.mlp.fc1.bias": "model-00001-of-00004.safetensors",
909
+ "vision_encoder.encoder.blocks.8.mlp.fc1.weight": "model-00001-of-00004.safetensors",
910
+ "vision_encoder.encoder.blocks.8.mlp.fc2.bias": "model-00001-of-00004.safetensors",
911
+ "vision_encoder.encoder.blocks.8.mlp.fc2.weight": "model-00001-of-00004.safetensors",
912
+ "vision_encoder.encoder.blocks.8.norm1.bias": "model-00001-of-00004.safetensors",
913
+ "vision_encoder.encoder.blocks.8.norm1.weight": "model-00001-of-00004.safetensors",
914
+ "vision_encoder.encoder.blocks.8.norm2.bias": "model-00001-of-00004.safetensors",
915
+ "vision_encoder.encoder.blocks.8.norm2.weight": "model-00001-of-00004.safetensors",
916
+ "vision_encoder.encoder.blocks.9.attn.proj.bias": "model-00001-of-00004.safetensors",
917
+ "vision_encoder.encoder.blocks.9.attn.proj.weight": "model-00001-of-00004.safetensors",
918
+ "vision_encoder.encoder.blocks.9.attn.q_bias": "model-00001-of-00004.safetensors",
919
+ "vision_encoder.encoder.blocks.9.attn.qkv.weight": "model-00001-of-00004.safetensors",
920
+ "vision_encoder.encoder.blocks.9.attn.v_bias": "model-00001-of-00004.safetensors",
921
+ "vision_encoder.encoder.blocks.9.mlp.fc1.bias": "model-00001-of-00004.safetensors",
922
+ "vision_encoder.encoder.blocks.9.mlp.fc1.weight": "model-00001-of-00004.safetensors",
923
+ "vision_encoder.encoder.blocks.9.mlp.fc2.bias": "model-00001-of-00004.safetensors",
924
+ "vision_encoder.encoder.blocks.9.mlp.fc2.weight": "model-00001-of-00004.safetensors",
925
+ "vision_encoder.encoder.blocks.9.norm1.bias": "model-00001-of-00004.safetensors",
926
+ "vision_encoder.encoder.blocks.9.norm1.weight": "model-00001-of-00004.safetensors",
927
+ "vision_encoder.encoder.blocks.9.norm2.bias": "model-00001-of-00004.safetensors",
928
+ "vision_encoder.encoder.blocks.9.norm2.weight": "model-00001-of-00004.safetensors",
929
+ "vision_encoder.encoder.patch_embed.proj.bias": "model-00001-of-00004.safetensors",
930
+ "vision_encoder.encoder.patch_embed.proj.weight": "model-00001-of-00004.safetensors",
931
+ "vision_layernorm.bias": "model-00001-of-00004.safetensors",
932
+ "vision_layernorm.weight": "model-00001-of-00004.safetensors"
933
+ }
934
+ }
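The weight_map above simply records which shard file each tensor lives in. A minimal sketch of how such an index can be used to pull a single tensor, assuming the mapping is saved under the usual name model.safetensors.index.json next to the shards (the tensor key and shard names come from the listing above; the loading code itself is illustrative and not part of this repo):

import json
from safetensors import safe_open

# Load the shard index and look up which file holds a given tensor.
with open("model.safetensors.index.json") as f:
    weight_map = json.load(f)["weight_map"]

name = "vision_encoder.encoder.patch_embed.proj.weight"
shard_file = weight_map[name]  # e.g. "model-00001-of-00004.safetensors"

# Read just that tensor from its shard without loading the whole checkpoint.
with safe_open(shard_file, framework="pt", device="cpu") as shard:
    tensor = shard.get_tensor(name)
print(tensor.shape)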
videochat2_it_hd_mistral.py ADDED
@@ -0,0 +1,418 @@
1
+
2
+ import logging
3
+
4
+ import torch
5
+ from torch.cuda.amp import autocast as autocast
6
+ import torch.nn as nn
7
+ import torch.nn.functional as F
8
+ from peft import get_peft_model, LoraConfig, TaskType
9
+
10
+ from .blip2 import Blip2Base, disabled_train
11
+ from transformers import AutoModelForCausalLM, AutoTokenizer, AutoConfig, PretrainedConfig
12
+
13
+ logger = logging.getLogger(__name__)
14
+
15
+ from easydict import EasyDict
16
+ from .configuration_videochat2 import Config
17
+
18
+ class VideoChat2_it_hd_mistral(Blip2Base):
19
+ _auto_class='AutoModel'
20
+ config_class=Config
21
+ """
22
+ VideoChat2 model.
23
+ """
24
+ def __init__(self, config):
25
+ super().__init__()
26
+ # pretrained_path
27
+ self.config=config
28
+ if isinstance(config,(PretrainedConfig,AutoConfig)):
29
+ if hasattr(config,'cfg'): # my own cfg
30
+ config=EasyDict(config.cfg)
31
+ else:
32
+ config=EasyDict(config.to_dict())
33
+ pc=PretrainedConfig()
34
+ pc.update(config)
35
+ vit_blip_model_path = config.get("vit_blip_model_path", None)
36
+ mistral_model_path = config.get("mistral_model_path")
37
+ videochat2_model_path = config.get("videochat2_model_path", "")
38
+ freeze_vit = config.get("freeze_vit", True)
39
+ freeze_qformer = config.get("freeze_qformer", True)
40
+ freeze_llm = config.get("freeze_llm", True)
41
+ # vit
42
+ low_resource = config.get("low_resource", False) # use 8-bit and put the ViT on CPU
43
+ # qformer
44
+ num_query_token = config.get("num_query_token")
45
+ qformer_hidden_dropout_prob = config.get("qformer_hidden_dropout_prob", 0.1)
46
+ qformer_attention_probs_dropout_prob = config.get("qformer_attention_probs_dropout_prob", 0.1)
47
+ qformer_drop_path_rate = config.get("qformer_drop_path_rate", 0.1)
48
+ extra_num_query_token = config.get("extra_num_query_token", 32)
49
+ self.qformer_text_input = config.get("qformer_text_input", False)
50
+
51
+ # Infinite-Video related hyperparameters
52
+ num_basis = config.get("num_basis", 256)
53
+ sticky = config.get("sticky", True)
54
+ tau = config.get("tau", 0.75)
55
+ alpha = config.get("alpha", 0.75)
56
+
57
+ # prompt
58
+ max_txt_len = config.get("max_txt_len", 32)
59
+ self.human_start = "[INST]"
60
+ self.human_end = "[/INST]"
61
+ self.assist_end = "</s>"
62
+ self.start_token = config.get("start_token", "<Video>")
63
+ self.end_token = config.get("end_token", "</Video>")
64
+ self.img_start_token = config.get("img_start_token", "<Image>")
65
+ self.img_end_token = config.get("img_end_token", "</Image>")
66
+ logger.info(f"Add instruction in qformer: {self.qformer_text_input}")
67
+ # debug
68
+ self.debug = config.get("debug", False)
69
+ self.llm_bf16 = config.get("llm_bf16", False)
70
+ use_flash_attention = config.get("use_flash_attention", False)
71
+ self.use_lora = config.get("use_lora", False)
72
+ lora_r = config.get("lora_r", 8)
73
+ lora_alpha = config.get("lora_alpha", 32)
74
+ lora_dropout = config.get("lora_dropout", 0.05)
75
+ # dynamic resolution
76
+ self.local_size = config.dynamic_config.get("local_size", 224)
77
+ self.add_global = config.dynamic_config.get("add_global", True)
78
+
79
+ self.tokenizer = self.init_tokenizer(truncation_side="left")
80
+ self.tokenizer.padding_side = "left"
81
+ self.low_resource = low_resource
82
+ self.vision_encoder, self.vision_layernorm = self.init_vision_encoder_umt(config)
83
+ self.qformer, self.query_tokens = self.init_Qformer(
84
+ num_query_token, config.vision_encoder.encoder_embed_dim,
85
+ qformer_hidden_dropout_prob=qformer_hidden_dropout_prob,
86
+ qformer_attention_probs_dropout_prob=qformer_attention_probs_dropout_prob,
87
+ qformer_drop_path_rate=qformer_drop_path_rate,
88
+ num_basis=num_basis, alpha=alpha, tau=tau, sticky=sticky,
89
+ )
90
+
91
+ if not self.qformer_text_input:
92
+ self.qformer.bert.embeddings.word_embeddings = None
93
+ self.qformer.bert.embeddings.position_embeddings = None
94
+ for layer in self.qformer.bert.encoder.layer:
95
+ layer.output = None
96
+ layer.intermediate = None
97
+ else:
98
+ self.qformer.resize_token_embeddings(len(self.tokenizer))
99
+ self.qformer.cls = None
100
+
101
+ if vit_blip_model_path:
102
+ logger.info(f"Load ViT and QFormer from {vit_blip_model_path}")
103
+ state_dict = torch.load(vit_blip_model_path, map_location="cpu")
104
+ msg = self.load_state_dict(state_dict, strict=False)
105
+ logger.info(msg)
106
+ logger.info('Loading ViT and Q-Former Done')
107
+
108
+ self.extra_num_query_token = extra_num_query_token
109
+ if extra_num_query_token > 0:
110
+ logger.info(f"Add extra {extra_num_query_token} tokens in QFormer")
111
+ self.extra_query_tokens = nn.Parameter(
112
+ torch.zeros(1, extra_num_query_token, self.query_tokens.shape[-1])
113
+ )
114
+
115
+ if freeze_vit:
116
+ logger.info("freeze vision encoder")
117
+ for _, param in self.vision_encoder.named_parameters():
118
+ param.requires_grad = False
119
+ self.vision_encoder = self.vision_encoder.eval()
120
+ self.vision_encoder.train = disabled_train
121
+ for _, param in self.vision_layernorm.named_parameters():
122
+ param.requires_grad = False
123
+ self.vision_layernorm = self.vision_layernorm.eval()
124
+ self.vision_layernorm.train = disabled_train
125
+
126
+ if freeze_qformer:
127
+ logger.info("freeze Qformer")
128
+ for _, param in self.qformer.named_parameters():
129
+ param.requires_grad = False
130
+ self.qformer = self.qformer.eval()
131
+ self.qformer.train = disabled_train
132
+ self.query_tokens.requires_grad = False
133
+
134
+ logger.info('Loading Mistral')
135
+ self.mistral_tokenizer = AutoTokenizer.from_pretrained(mistral_model_path)
136
+ self.mistral_tokenizer.padding_side = "left"
137
+ if not self.mistral_tokenizer.pad_token:
138
+ logger.info("Set pad_token")
139
+ self.mistral_tokenizer.pad_token = self.mistral_tokenizer.eos_token
140
+
141
+ if self.debug:
142
+ logger.info("Debug mode, build small Mistral")
143
+ mistral_config = AutoConfig.from_pretrained(mistral_model_path)
144
+ mistral_config.hidden_size = 512
145
+ mistral_config.intermediate_size = 2048
146
+ mistral_config.num_attention_heads = 8
147
+ mistral_config.num_hidden_layers = 12
148
+ mistral_config.torch_dtype = torch.float16
149
+ self.mistral_model = AutoModelForCausalLM.from_config(mistral_config)
150
+ else:
151
+ if use_flash_attention:
152
+ self.mistral_model = AutoModelForCausalLM.from_pretrained(
153
+ mistral_model_path,
154
+ torch_dtype=torch.bfloat16 if self.llm_bf16 else torch.float16,
155
+ # use_flash_attention_2=True,
156
+ attn_implementation="flash_attention_2",
157
+ )
158
+ else:
159
+ self.mistral_model = AutoModelForCausalLM.from_pretrained(
160
+ mistral_model_path,
161
+ torch_dtype=torch.bfloat16 if self.llm_bf16 else torch.float16,
162
+ )
163
+
164
+ if freeze_llm:
165
+ logger.info("freeze Mistral")
166
+ for _, param in self.mistral_model.named_parameters():
167
+ param.requires_grad = False
168
+ logger.info('Loading Mistral Done')
169
+
170
+ if self.use_lora:
171
+ logger.info("Use lora")
172
+ peft_config = LoraConfig(
173
+ task_type=TaskType.CAUSAL_LM, inference_mode=False,
174
+ r=lora_r, lora_alpha=lora_alpha, lora_dropout=lora_dropout,
175
+ target_modules=["q_proj", "k_proj", "v_proj", "o_proj",
176
+ "gate_proj", "up_proj", "down_proj", "lm_head"]
177
+ )
178
+ self.mistral_model = get_peft_model(self.mistral_model, peft_config)
179
+ if not freeze_llm:
180
+ logger.info("Unfreeze Mistral")
181
+ for _, param in self.mistral_model.base_model.named_parameters():
182
+ param.requires_grad = True
183
+ self.mistral_model.print_trainable_parameters()
184
+
185
+ self.mistral_proj = nn.Linear(
186
+ self.qformer.config.hidden_size, self.mistral_model.config.hidden_size
187
+ )
188
+ self.max_txt_len = max_txt_len
189
+
190
+ # load weights of VideoChat2
191
+ if videochat2_model_path:
192
+ logger.info(f"Load VideoChat2 from: {videochat2_model_path}")
193
+ ckpt = torch.load(videochat2_model_path, map_location="cpu")
194
+ if 'model' in ckpt.keys():
195
+ msg = self.load_state_dict(ckpt['model'], strict=False)
196
+ else:
197
+ msg = self.load_state_dict(ckpt, strict=False)
198
+ logger.info(msg)
199
+ self.config=pc
200
+
201
+ def vit_to_cpu(self):
202
+ self.vision_layernorm.to("cpu")
203
+ self.vision_layernorm.float()
204
+ self.vision_encoder.to("cpu")
205
+ self.vision_encoder.float()
206
+
207
+ def encode_img(self, image, instruction, new_video=False):
208
+ device = image[0].device
209
+ if self.low_resource:
210
+ self.vit_to_cpu()
211
+ image = [img.to("cpu") for img in image]
212
+
213
+ with self.maybe_autocast():
214
+ # split the image or video according to the shape
215
+ shapes = []
216
+ input_imgs = []
217
+ input_instructions = []
218
+ for idx, img in enumerate(image):
219
+ # logger.info(f"Input shape: {img.shape}")
220
+ T, C, H, W = img.shape
221
+ shapes.append([H//self.local_size, W//self.local_size])
222
+ sub_img = img.reshape(
223
+ 1, T, 3, H//self.local_size, self.local_size, W//self.local_size, self.local_size
224
+ ).permute(0, 3, 5, 1, 2, 4, 6).reshape(-1, T, 3, self.local_size, self.local_size).contiguous()
225
+ input_imgs.append(sub_img)
226
+ input_instructions.extend([instruction[idx]] * len(sub_img))
227
+ if self.add_global:
228
+ glb_img = F.interpolate(
229
+ img.float(), size=(self.local_size, self.local_size), mode='bicubic', align_corners=False
230
+ ).to(sub_img.dtype)
231
+ input_imgs.append(glb_img.unsqueeze(0))
232
+ input_instructions.append(instruction[idx])
233
+ input_imgs = torch.cat(input_imgs, dim=0)
234
+
235
+ T = input_imgs.shape[1]
236
+ use_image = True if T == 1 else False
237
+ input_imgs = input_imgs.permute(0, 2, 1, 3, 4) # [B,T,C,H,W] -> [B,C,T,H,W]
238
+
239
+ image_embeds = self.vision_encoder(input_imgs, use_image)
240
+ B, T, L, C = image_embeds.shape
241
+ image_embeds = image_embeds.reshape(B, -1, C)
242
+ image_embeds = self.vision_layernorm(image_embeds).to(device) # [B, T*L, C]
243
+
244
+ image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to(device)
245
+
246
+ if self.extra_num_query_token > 0:
247
+ query_tokens = torch.cat([self.query_tokens, self.extra_query_tokens], dim=1)
248
+ else:
249
+ query_tokens = self.query_tokens
250
+ query_tokens = query_tokens.expand(image_embeds.shape[0], -1, -1)
251
+ if self.qformer_text_input:
252
+ text_Qformer = self.tokenizer(
253
+ input_instructions,
254
+ padding='longest',
255
+ truncation=True,
256
+ max_length=self.max_txt_len,
257
+ return_tensors="pt",
258
+ ).to(image_embeds.device)
259
+ query_atts = torch.ones(query_tokens.size()[:-1], dtype=torch.long).to(image_embeds.device)
260
+ Qformer_atts = torch.cat([query_atts, text_Qformer.attention_mask], dim=1)
261
+
262
+ query_output = self.qformer.bert(
263
+ text_Qformer.input_ids,
264
+ attention_mask=Qformer_atts,
265
+ query_embeds=query_tokens,
266
+ encoder_hidden_states=image_embeds,
267
+ encoder_attention_mask=image_atts,
268
+ return_dict=True,
269
+ new_video=new_video,
270
+ )
271
+ else:
272
+ query_output = self.qformer.bert(
273
+ query_embeds=query_tokens,
274
+ encoder_hidden_states=image_embeds,
275
+ encoder_attention_mask=image_atts,
276
+ return_dict=True,
277
+ new_video=new_video
278
+ )
279
+
280
+ qformer_features = self.mistral_proj(query_output.last_hidden_state[:, :query_tokens.size(1), :])
281
+ q_C = qformer_features.shape[-1]
282
+
283
+ # merge the features from the different splits
284
+ # stolen from https://huggingface.co/internlm/internlm-xcomposer2-4khd-7b/blob/main/build_mlp.py#L97-L115
285
+ output_imgs = []
286
+ output_len = []
287
+ for [h, w] in shapes:
288
+ B_ = h * w
289
+ if self.add_global:
290
+ output_imgs.append(qformer_features[:B_+1].view(1, -1, q_C))
291
+ qformer_features = qformer_features[B_+1:]
292
+ else:
293
+ output_imgs.append(qformer_features[:B_].view(1, -1, q_C))
294
+ qformer_features = qformer_features[B_:]
295
+ # logger.info(f"Features shape: {output_imgs[-1].shape}")
296
+ output_len.append(output_imgs[-1].shape[1])
297
+
298
+ return output_imgs, output_len, use_image
299
+
300
+ def _get_text_len(self, text):
301
+ return self.mistral_tokenizer(text, return_tensors="pt", add_special_tokens=False).input_ids.shape[1]
302
+
303
+ def forward(self, image, text_input, instruction):
304
+ if len(image[0].shape) == 1:
305
+ use_text = True
306
+ device = image[0].device
307
+ batch_size = len(image)
308
+ img_lens = [0] * batch_size
309
+ else:
310
+ use_text = False
311
+ img_embeds, img_lens, use_image = self.encode_img(image, instruction)
312
+ device = img_embeds[0].device
313
+ batch_size = len(img_embeds)
314
+
315
+ # mark the largest length
316
+ # when padding, the attention mask will be 0
317
+ max_len = 0
318
+ input_embed_list = []
319
+ p_before_len_list = []
320
+ target_list = []
321
+ # handle each prompt individually
322
+ for idx, prompt in enumerate(text_input):
323
+ if use_text:
324
+ p_after = prompt
325
+ p_after_tokens = self.mistral_tokenizer(p_after, return_tensors="pt", add_special_tokens=False).to(device)
326
+ if self.use_lora:
327
+ p_after_embeds = self.mistral_model.base_model.model.model.embed_tokens(p_after_tokens.input_ids)
328
+ else:
329
+ p_after_embeds = self.mistral_model.model.embed_tokens(p_after_tokens.input_ids)
330
+ input_embeds = p_after_embeds
331
+ else:
332
+ tmp_img_embeds = img_embeds[idx]
333
+ # split the prompt via END_TOKEN
334
+ end_token = self.img_end_token if use_image else self.end_token
335
+ p_before, p_after = prompt.split(end_token)
336
+ p_after = end_token + p_after
337
+ p_before_tokens = self.mistral_tokenizer(p_before, return_tensors="pt", add_special_tokens=False).to(tmp_img_embeds.device)
338
+ p_after_tokens = self.mistral_tokenizer(p_after, return_tensors="pt", add_special_tokens=False).to(tmp_img_embeds.device)
339
+ if self.use_lora:
340
+ p_before_embeds = self.mistral_model.base_model.model.model.embed_tokens(p_before_tokens.input_ids)
341
+ p_after_embeds = self.mistral_model.base_model.model.model.embed_tokens(p_after_tokens.input_ids)
342
+ else:
343
+ p_before_embeds = self.mistral_model.model.embed_tokens(p_before_tokens.input_ids)
344
+ p_after_embeds = self.mistral_model.model.embed_tokens(p_after_tokens.input_ids)
345
+ input_embeds = torch.cat([p_before_embeds, tmp_img_embeds, p_after_embeds], dim=1)
346
+
347
+ # extract the answers and mask the target
348
+ # the answers are only in the p_after
349
+ sep1 = self.human_start + " "
350
+ sep2 = " " + self.human_end + " "
351
+ raw_text = p_after.split(sep2)
352
+ for idx in range(0, len(raw_text) - 1):
353
+ raw_text[idx] = raw_text[idx] + sep2
354
+ # the first raw_text contains system and question
355
+ # the last raw_text only contains answer
356
+ # rstrip() for the extra " "
357
+ answer_targets = p_after_tokens.input_ids.clone()
358
+ # [target] "xxxxx. </s>"
359
+ cur_len = self._get_text_len(raw_text[0].rstrip())
360
+ answer_targets[:, :cur_len] = -100
361
+ for text in raw_text[1:-1]:
362
+ total_len = self._get_text_len(text.rstrip())
363
+ ans_len = self._get_text_len((text.split(sep1)[0]).rstrip())
364
+ answer_targets[:, (cur_len+ans_len):(cur_len+total_len)] = -100
365
+ cur_len += total_len
366
+ cur_len += self._get_text_len(raw_text[-1].rstrip())
367
+
368
+ if self.debug: # Inspect and check the correctness of masking
369
+ z = answer_targets[0].clone()
370
+ z = torch.where(z == -100, self.mistral_tokenizer.unk_token_id, z)
371
+ logger.info(self.mistral_tokenizer.decode(z))
372
+
373
+ assert cur_len == answer_targets.shape[1], f"The final length ({cur_len}) is not equal to the original prompt ({answer_targets.shape[1]}): {prompt}"
374
+
375
+ max_len = max(max_len, input_embeds.shape[1])
376
+ input_embed_list.append(input_embeds)
377
+ if use_text:
378
+ p_before_len_list.append(0)
379
+ else:
380
+ p_before_len_list.append(p_before_tokens.input_ids.shape[1])
381
+ target_list.append(answer_targets)
382
+
383
+ # plus one for bos
384
+ # max_txt_len plus num_query_token is the max len
385
+ txt_len = min(max_len + 1, self.max_txt_len + max(img_lens))
386
+ inputs_embeds = torch.ones([batch_size, txt_len], dtype=torch.long).to(device) * self.mistral_tokenizer.pad_token_id
387
+ if self.use_lora:
388
+ inputs_embeds = self.mistral_model.base_model.model.model.embed_tokens(inputs_embeds)
389
+ else:
390
+ inputs_embeds = self.mistral_model.model.embed_tokens(inputs_embeds)
391
+ attention_mask = torch.zeros([batch_size, txt_len], dtype=torch.long).to(device)
392
+ targets = torch.ones([batch_size, txt_len], dtype=torch.long).to(device).fill_(-100)
393
+ # set bos_token
394
+ inputs_embeds[:, :1] = self.mistral_tokenizer.bos_token_id
395
+
396
+ for idx in range(batch_size):
397
+ input_len = min(input_embed_list[idx].shape[1], txt_len - 1)
398
+ # if less than txt_len, the input will be padding
399
+ # if more than txt_len, the input will be truncated
400
+ inputs_embeds[idx, 1:(input_len+1)] = input_embed_list[idx][:, :input_len]
401
+ # the attention_mask is 0 when padding
402
+ attention_mask[idx, :(input_len+1)] = 1
403
+ # the target is -100 when padding
404
+ p_before_len = p_before_len_list[idx]
405
+ targets[idx, (p_before_len+img_lens[idx]+1):(input_len+1)] = target_list[idx][0, :(input_len-p_before_len-img_lens[idx])]
406
+
407
+ with self.maybe_autocast():
408
+ outputs = self.mistral_model(
409
+ inputs_embeds=inputs_embeds,
410
+ attention_mask=attention_mask,
411
+ return_dict=True,
412
+ labels=targets,
413
+ use_cache=False, # current flash_attn2 does not support right padding for Mistral
414
+ )
415
+
416
+ return dict(
417
+ loss=outputs.loss,
418
+ )
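Since the class sets _auto_class='AutoModel' and config_class=Config, it is intended to be loaded through transformers' remote-code mechanism. A minimal usage sketch, assuming the uploaded config exposes the class via auto_map (the repo id and dtype below are placeholders, not confirmed by this commit):

import torch
from transformers import AutoModel

# trust_remote_code=True is required so transformers runs the modeling code in this repo.
model = AutoModel.from_pretrained(
    "your-namespace/videochat2-hd-mistral",  # hypothetical repo id
    trust_remote_code=True,
    torch_dtype=torch.float16,
)
model.eval()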
vit.py ADDED
@@ -0,0 +1,472 @@
1
+ import logging
2
+ import numpy as np
3
+ import torch
4
+ import torch.nn as nn
5
+ import torch.nn.functional as F
6
+ import torch.utils.checkpoint as checkpoint
7
+ from functools import partial
8
+
9
+ from timm.models.layers import drop_path, to_2tuple, trunc_normal_
10
+
11
+ logger = logging.getLogger(__name__)
12
+
13
+
14
+ def _cfg(url='', **kwargs):
15
+ return {
16
+ 'url': url,
17
+ 'num_classes': 400, 'input_size': (3, 224, 224), 'pool_size': None,
18
+ 'crop_pct': .9, 'interpolation': 'bicubic',
19
+ 'mean': (0.5, 0.5, 0.5), 'std': (0.5, 0.5, 0.5),
20
+ **kwargs
21
+ }
22
+
23
+
24
+ class DropPath(nn.Module):
25
+ """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
26
+ """
27
+ def __init__(self, drop_prob=None):
28
+ super(DropPath, self).__init__()
29
+ self.drop_prob = drop_prob
30
+
31
+ def forward(self, x):
32
+ return drop_path(x, self.drop_prob, self.training)
33
+
34
+ def extra_repr(self) -> str:
35
+ return 'p={}'.format(self.drop_prob)
36
+
37
+
38
+ class Mlp(nn.Module):
39
+ def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
40
+ super().__init__()
41
+ out_features = out_features or in_features
42
+ hidden_features = hidden_features or in_features
43
+ self.fc1 = nn.Linear(in_features, hidden_features)
44
+ self.act = act_layer()
45
+ self.fc2 = nn.Linear(hidden_features, out_features)
46
+ self.drop = nn.Dropout(drop)
47
+
48
+ def forward(self, x):
49
+ x = self.fc1(x)
50
+ x = self.act(x)
51
+ x = self.drop(x)
52
+ x = self.fc2(x)
53
+ x = self.drop(x)
54
+ return x
55
+
56
+
57
+ class Attention(nn.Module):
58
+ def __init__(
59
+ self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0.,
60
+ proj_drop=0., attn_head_dim=None):
61
+ super().__init__()
62
+ self.num_heads = num_heads
63
+ head_dim = dim // num_heads
64
+ if attn_head_dim is not None:
65
+ head_dim = attn_head_dim
66
+ all_head_dim = head_dim * self.num_heads
67
+ self.scale = qk_scale or head_dim ** -0.5
68
+
69
+ self.qkv = nn.Linear(dim, all_head_dim * 3, bias=False)
70
+ if qkv_bias:
71
+ self.q_bias = nn.Parameter(torch.zeros(all_head_dim))
72
+ self.v_bias = nn.Parameter(torch.zeros(all_head_dim))
73
+ else:
74
+ self.q_bias = None
75
+ self.v_bias = None
76
+
77
+ self.attn_drop = nn.Dropout(attn_drop)
78
+ self.proj = nn.Linear(all_head_dim, dim)
79
+ self.proj_drop = nn.Dropout(proj_drop)
80
+
81
+ def forward(self, x):
82
+ B, N, C = x.shape
83
+ qkv_bias = None
84
+ if self.q_bias is not None:
85
+ qkv_bias = torch.cat((self.q_bias, torch.zeros_like(self.v_bias, requires_grad=False), self.v_bias))
86
+ # qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
87
+ qkv = F.linear(input=x, weight=self.qkv.weight, bias=qkv_bias)
88
+ qkv = qkv.reshape(B, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
89
+ q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple)
90
+
91
+ q = q * self.scale
92
+ attn = (q @ k.transpose(-2, -1))
93
+
94
+ attn = attn.softmax(dim=-1)
95
+ attn = self.attn_drop(attn)
96
+
97
+ x = (attn @ v).transpose(1, 2).reshape(B, N, -1)
98
+ x = self.proj(x)
99
+ x = self.proj_drop(x)
100
+ return x
101
+
102
+
103
+ class Block(nn.Module):
104
+ def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
105
+ drop_path=0., init_values=None, act_layer=nn.GELU, norm_layer=nn.LayerNorm,
106
+ attn_head_dim=None):
107
+ super().__init__()
108
+ self.norm1 = norm_layer(dim)
109
+ self.attn = Attention(
110
+ dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale,
111
+ attn_drop=attn_drop, proj_drop=drop, attn_head_dim=attn_head_dim)
112
+ # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
113
+ self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
114
+ self.norm2 = norm_layer(dim)
115
+ mlp_hidden_dim = int(dim * mlp_ratio)
116
+ self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
117
+
118
+ if init_values is not None and init_values > 0:
119
+ self.gamma_1 = nn.Parameter(init_values * torch.ones((dim)),requires_grad=True)
120
+ self.gamma_2 = nn.Parameter(init_values * torch.ones((dim)),requires_grad=True)
121
+ else:
122
+ self.gamma_1, self.gamma_2 = None, None
123
+
124
+ def forward(self, x):
125
+ if self.gamma_1 is None:
126
+ x = x + self.drop_path(self.attn(self.norm1(x)))
127
+ x = x + self.drop_path(self.mlp(self.norm2(x)))
128
+ else:
129
+ x = x + self.drop_path(self.gamma_1 * self.attn(self.norm1(x)))
130
+ x = x + self.drop_path(self.gamma_2 * self.mlp(self.norm2(x)))
131
+ return x
132
+
133
+
134
+ class PatchEmbed(nn.Module):
135
+ """ Image to Patch Embedding
136
+ """
137
+ def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768, num_frames=16, tubelet_size=2):
138
+ super().__init__()
139
+ img_size = to_2tuple(img_size)
140
+ patch_size = to_2tuple(patch_size)
141
+ self.tubelet_size = int(tubelet_size)
142
+ num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0]) * (num_frames // self.tubelet_size)
143
+ self.img_size = img_size
144
+ self.patch_size = patch_size
145
+ self.num_patches = num_patches
146
+ self.proj = nn.Conv3d(
147
+ in_channels=in_chans, out_channels=embed_dim,
148
+ kernel_size=(self.tubelet_size, patch_size[0], patch_size[1]),
149
+ stride=(self.tubelet_size, patch_size[0], patch_size[1])
150
+ )
151
+ logger.info(f'Num of patches: {num_patches}')
152
+
153
+ def forward(self, x, **kwargs):
154
+ B, C, T, H, W = x.shape
155
+ # FIXME look at relaxing size constraints
156
+ # assert H == self.img_size[0] and W == self.img_size[1], \
157
+ # f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
158
+ x = self.proj(x).flatten(2).transpose(1, 2)
159
+ return x
160
+
161
+ # sin-cos position encoding
162
+ # https://github.com/jadore801120/attention-is-all-you-need-pytorch/blob/master/transformer/Models.py#L31
163
+ def get_sinusoid_encoding_table(n_position, d_hid, ckpt_num_frame=-1, cur_frame=12):
164
+ ''' Sinusoid position encoding table '''
165
+ # TODO: make it with torch instead of numpy
166
+ def get_position_angle_vec(position):
167
+ return [position / np.power(10000, 2 * (hid_j // 2) / d_hid) for hid_j in range(d_hid)]
168
+
169
+ if ckpt_num_frame != -1 and ckpt_num_frame != cur_frame:
170
+ logger.info(f"Interpolate position embedding")
171
+ logger.info(f"Testing frame: {cur_frame}")
172
+ logger.info(f"Checkpoint frame: {ckpt_num_frame}")
173
+
174
+ T = ckpt_num_frame # checkpoint frame
175
+ new_T = cur_frame # testing frame
176
+ n_position = n_position // new_T * T # generate checkpoint position embedding
177
+ sinusoid_table = np.array([get_position_angle_vec(pos_i) for pos_i in range(n_position)])
178
+ sinusoid_table[:, 0::2] = np.sin(sinusoid_table[:, 0::2]) # dim 2i
179
+ sinusoid_table[:, 1::2] = np.cos(sinusoid_table[:, 1::2]) # dim 2i+1
180
+ sinusoid_table = torch.tensor(sinusoid_table, dtype=torch.float, requires_grad=False).unsqueeze(0)
181
+ # interpolate
182
+ P = int((n_position // T) ** 0.5)
183
+ C = d_hid
184
+ sinusoid_table = sinusoid_table.reshape(-1, T, P, P, C)
185
+ sinusoid_table = sinusoid_table.permute(0, 2, 3, 4, 1).reshape(-1, C, T) # BHW, C, T
186
+ sinusoid_table = torch.nn.functional.interpolate(sinusoid_table, size=new_T, mode='linear')
187
+ sinusoid_table = sinusoid_table.reshape(1, P, P, C, new_T).permute(0, 4, 1, 2, 3) # B, T, H, W, C
188
+ sinusoid_table = sinusoid_table.flatten(1, 3)
189
+ return sinusoid_table
190
+ else:
191
+ sinusoid_table = np.array([get_position_angle_vec(pos_i) for pos_i in range(n_position)])
192
+ sinusoid_table[:, 0::2] = np.sin(sinusoid_table[:, 0::2]) # dim 2i
193
+ sinusoid_table[:, 1::2] = np.cos(sinusoid_table[:, 1::2]) # dim 2i+1
194
+ return torch.tensor(sinusoid_table, dtype=torch.float, requires_grad=False).unsqueeze(0)
195
+
196
+
197
+ def get_sinusoid_encoding_table2(n_position=784, d_hid=1024, cur_frame=8, ckpt_num_frame=4, pre_n_position=784):
198
+ ''' Sinusoid position encoding table '''
199
+ # TODO: make it with torch instead of numpy
200
+ def get_position_angle_vec(position):
201
+ return [position / np.power(10000, 2 * (hid_j // 2) / d_hid) for hid_j in range(d_hid)]
202
+
203
+ # generate checkpoint position embedding
204
+ sinusoid_table = np.array([get_position_angle_vec(pos_i) for pos_i in range(pre_n_position)])
205
+ sinusoid_table[:, 0::2] = np.sin(sinusoid_table[:, 0::2]) # dim 2i
206
+ sinusoid_table[:, 1::2] = np.cos(sinusoid_table[:, 1::2]) # dim 2i+1
207
+ sinusoid_table = torch.tensor(sinusoid_table, dtype=torch.float, requires_grad=False).unsqueeze(0)
208
+
209
+ print(f"n_position: {n_position}")
210
+ print(f"pre_n_position: {pre_n_position}")
211
+
212
+ if n_position != pre_n_position:
213
+ T = ckpt_num_frame # checkpoint frame
214
+ P = 14 # checkpoint size
215
+ C = d_hid
216
+ new_P = int((n_position // cur_frame) ** 0.5) # testing size
217
+ print(f'Pretraining uses 14x14, but current version is {new_P}x{new_P}')
218
+ print(f'Interpolate the position embedding')
219
+ sinusoid_table = sinusoid_table.reshape(-1, T, P, P, C)
220
+ sinusoid_table = sinusoid_table.reshape(-1, P, P, C).permute(0, 3, 1, 2)
221
+ sinusoid_table = torch.nn.functional.interpolate(
222
+ sinusoid_table, size=(new_P, new_P), mode='bicubic', align_corners=False)
223
+ # BT, C, H, W -> BT, H, W, C -> B, T, H, W, C
224
+ sinusoid_table = sinusoid_table.permute(0, 2, 3, 1).reshape(-1, T, new_P, new_P, C)
225
+ sinusoid_table = sinusoid_table.flatten(1, 3) # B, THW, C
226
+
227
+ if cur_frame != ckpt_num_frame:
228
+ print(f'Pretraining uses 4 frames, but current frame is {cur_frame}')
229
+ print(f'Interpolate the position embedding')
230
+ T = ckpt_num_frame # checkpoint frame
231
+ new_T = cur_frame # testing frame
232
+ # interpolate
233
+ P = int((n_position // cur_frame) ** 0.5) # testing size
234
+ C = d_hid
235
+ sinusoid_table = sinusoid_table.reshape(-1, T, P, P, C)
236
+ sinusoid_table = sinusoid_table.permute(0, 2, 3, 4, 1).reshape(-1, C, T) # BHW, C, T
237
+ sinusoid_table = torch.nn.functional.interpolate(sinusoid_table, size=new_T, mode='linear')
238
+ sinusoid_table = sinusoid_table.reshape(1, P, P, C, new_T).permute(0, 4, 1, 2, 3) # B, T, H, W, C
239
+ sinusoid_table = sinusoid_table.flatten(1, 3) # B, THW, C
240
+
241
+ return sinusoid_table
242
+
243
+
244
+ class PretrainVisionTransformerEncoder(nn.Module):
245
+ """ Vision Transformer with support for patch or hybrid CNN input stage
246
+ """
247
+ def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768, depth=12,
248
+ num_heads=12, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop_rate=0., attn_drop_rate=0.,
249
+ drop_path_rate=0., norm_layer=nn.LayerNorm, init_values=None, num_frames=8, tubelet_size=1,
250
+ use_learnable_pos_emb=False,
251
+ use_checkpoint=False, checkpoint_num=0,
252
+ ckpt_num_frame=-1, with_ln=True, return_index=-1
253
+ ):
254
+ super().__init__()
255
+ self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models
256
+ self.patch_embed = PatchEmbed(
257
+ img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim,
258
+ num_frames=num_frames, tubelet_size=tubelet_size
259
+ )
260
+ num_patches = self.patch_embed.num_patches
261
+ self.depth = depth + return_index + 1
262
+ self.use_checkpoint = use_checkpoint
263
+ self.checkpoint_num = checkpoint_num
264
+ logger.info(f"Use checkpoint: {use_checkpoint}")
265
+ logger.info(f"Checkpoint number: {checkpoint_num}")
266
+ logger.info(f"Real running depth: {self.depth}")
267
+
268
+ # TODO: Add the cls token
269
+ if use_learnable_pos_emb:
270
+ self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
271
+ self.img_pos_embed = nn.Parameter(torch.zeros(1, num_patches//(num_frames//tubelet_size) + 1, embed_dim))
272
+ else:
273
+ # sine-cosine positional embeddings
274
+ if img_size != 224:
275
+ self.pos_embed = get_sinusoid_encoding_table2(num_patches, embed_dim, ckpt_num_frame=ckpt_num_frame, cur_frame=num_frames//tubelet_size)
276
+ self.img_pos_embed = get_sinusoid_encoding_table2(num_patches//(num_frames//tubelet_size), embed_dim, cur_frame=1, ckpt_num_frame=1, pre_n_position=14*14)
277
+ else:
278
+ self.pos_embed = get_sinusoid_encoding_table(num_patches, embed_dim, ckpt_num_frame=ckpt_num_frame, cur_frame=num_frames//tubelet_size)
279
+ self.img_pos_embed = get_sinusoid_encoding_table(num_patches//(num_frames//tubelet_size), embed_dim)
280
+
281
+ dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule
282
+ self.blocks = nn.ModuleList([
283
+ Block(
284
+ dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
285
+ drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer,
286
+ init_values=init_values)
287
+ for i in range(self.depth)])
288
+
289
+ if with_ln:
290
+ self.norm = norm_layer(embed_dim)
291
+ else:
292
+ self.norm = nn.Identity()
293
+
294
+ if use_learnable_pos_emb:
295
+ trunc_normal_(self.pos_embed, std=.02)
296
+
297
+ @torch.jit.ignore
298
+ def no_weight_decay(self):
299
+ return {'pos_embed', 'cls_token'}
300
+
301
+ def forward_features(self, x, use_image=False):
302
+ x = self.patch_embed(x)
303
+
304
+ if use_image:
305
+ x = x + self.img_pos_embed.type_as(x).to(x.device).clone().detach()
306
+ else:
307
+ x = x + self.pos_embed.type_as(x).to(x.device).clone().detach()
308
+
309
+ B, _, C = x.shape
310
+ x_vis = x
311
+
312
+ for idx, blk in enumerate(self.blocks):
313
+ if self.use_checkpoint and idx < self.checkpoint_num:
314
+ x_vis = checkpoint.checkpoint(blk, x_vis)
315
+ else:
316
+ x_vis = blk(x_vis)
317
+
318
+ # apply the final LayerNorm (Identity when with_ln=False)
319
+ x_vis = self.norm(x_vis)
320
+ return x_vis
321
+
322
+ def forward(self, x, use_image=False):
323
+ x_vis = self.forward_features(x, use_image)
324
+ return x_vis
325
+
326
+
327
+ class PretrainVisionTransformer(nn.Module):
328
+ """ Vision Transformer with support for patch or hybrid CNN input stage
329
+ """
330
+ def __init__(self,
331
+ img_size=224,
332
+ patch_size=16,
333
+ encoder_in_chans=3,
334
+ encoder_embed_dim=768,
335
+ encoder_depth=12,
336
+ encoder_num_heads=12,
337
+ mlp_ratio=4.,
338
+ qkv_bias=True,
339
+ qk_scale=None,
340
+ drop_rate=0.,
341
+ attn_drop_rate=0.,
342
+ drop_path_rate=0.,
343
+ norm_layer=partial(nn.LayerNorm, eps=1e-6),
344
+ init_values=0.,
345
+ use_learnable_pos_emb=False,
346
+ num_frames=8,
347
+ tubelet_size=1,
348
+ use_checkpoint=False,
349
+ checkpoint_num=0,
350
+ ckpt_num_frame=4, # the pretrained model uses 4 frames
351
+ return_index=-1,
352
+ with_ln=False
353
+ ):
354
+ super().__init__()
355
+
356
+ self.encoder = PretrainVisionTransformerEncoder(
357
+ img_size=img_size,
358
+ patch_size=patch_size,
359
+ in_chans=encoder_in_chans,
360
+ embed_dim=encoder_embed_dim,
361
+ depth=encoder_depth,
362
+ num_heads=encoder_num_heads,
363
+ mlp_ratio=mlp_ratio,
364
+ qkv_bias=qkv_bias,
365
+ qk_scale=qk_scale,
366
+ drop_rate=drop_rate,
367
+ attn_drop_rate=attn_drop_rate,
368
+ drop_path_rate=drop_path_rate,
369
+ norm_layer=norm_layer,
370
+ init_values=init_values,
371
+ num_frames=num_frames,
372
+ tubelet_size=tubelet_size,
373
+ use_learnable_pos_emb=use_learnable_pos_emb,
374
+ use_checkpoint=use_checkpoint,
375
+ checkpoint_num=checkpoint_num,
376
+ ckpt_num_frame=ckpt_num_frame,
377
+ with_ln=with_ln,
378
+ return_index=return_index
379
+ )
380
+ logger.info(f'With LN: {with_ln}')
381
+ logger.info(f'Total {encoder_depth} layers')
382
+ logger.info(f'Return {encoder_depth+return_index+1}-th layer')
383
+
384
+ self.apply(self._init_weights)
385
+
386
+ def _init_weights(self, m):
387
+ if isinstance(m, nn.Linear):
388
+ nn.init.xavier_uniform_(m.weight)
389
+ if isinstance(m, nn.Linear) and m.bias is not None:
390
+ nn.init.constant_(m.bias, 0)
391
+ elif isinstance(m, nn.LayerNorm):
392
+ nn.init.constant_(m.bias, 0)
393
+ nn.init.constant_(m.weight, 1.0)
394
+
395
+ @torch.jit.ignore
396
+ def no_weight_decay(self):
397
+ return {'pos_embed', 'cls_token', 'clip_pos_embed'}
398
+
399
+ def forward(self, x, use_image=False):
400
+ T = x.shape[2]
401
+ x_vis = self.encoder(x, use_image) # [B, N_vis, C_e]
402
+ B, TL, C = x_vis.shape
403
+ x_vis = x_vis.view(B, T, TL // T, C)
404
+
405
+ return x_vis
406
+
407
+
408
+ def build_vit(config):
409
+ model = PretrainVisionTransformer(
410
+ img_size=config.vision_encoder.img_size,
411
+ patch_size=config.vision_encoder.patch_size,
412
+ encoder_embed_dim=config.vision_encoder.encoder_embed_dim,
413
+ encoder_depth=config.vision_encoder.encoder_depth,
414
+ encoder_num_heads=config.vision_encoder.encoder_num_heads,
415
+ drop_path_rate=config.vision_encoder.drop_path_rate,
416
+ num_frames=config.vision_encoder.num_frames,
417
+ tubelet_size=config.vision_encoder.tubelet_size,
418
+ use_checkpoint=config.vision_encoder.use_checkpoint,
419
+ checkpoint_num=config.vision_encoder.checkpoint_num,
420
+ return_index=config.vision_encoder.get('return_index', -1),
421
+ with_ln=config.vision_encoder.get('with_ln', False),
422
+ )
423
+ model.default_cfg = _cfg()
424
+ if config.vision_encoder.pretrained:
425
+ logger.info(f"Loading pretrained weights from {config.vision_encoder.pretrained}")
426
+ state_dict = torch.load(config.vision_encoder.pretrained, map_location='cpu')
427
+ model.load_state_dict(state_dict, strict=False)
428
+ else:
429
+ logger.info("No pretrained weights!!!")
430
+ return model
431
+
432
+
433
+ if __name__ == '__main__':
434
+ import time
435
+ from fvcore.nn import FlopCountAnalysis
436
+ from fvcore.nn import flop_count_table
437
+ import numpy as np
438
+
439
+ seed = 4217
440
+ np.random.seed(seed)
441
+ torch.manual_seed(seed)
442
+ torch.cuda.manual_seed(seed)
443
+ torch.cuda.manual_seed_all(seed)
444
+ num_frames = 4
445
+
446
+ config = {
447
+ 'vision_encoder':
448
+ {
449
+ 'img_size': 224,
450
+ 'patch_size': 16,
451
+ 'encoder_embed_dim': 768,
452
+ 'encoder_depth': 12,
453
+ 'encoder_num_heads': 12,
454
+ 'drop_path_rate': 0.1,
455
+ 'num_frames': num_frames,
456
+ 'tubelet_size': 1,
457
+ 'use_checkpoint': False,
458
+ 'checkpoint_num': 0,
459
+ 'pretrained': 'your_model_path/l16_25m.pth',
460
+ 'ckpt_num_frame': 8,
461
+ 'return_index': -1,
462
+ 'with_ln': False,
463
+ }
464
+ }
465
+ from easydict import EasyDict
466
+ model = build_vit(EasyDict(config))
467
+
468
+ # flops = FlopCountAnalysis(model, torch.rand(1, 3, num_frames, 224, 224))
469
+ # s = time.time()
470
+ # print(flop_count_table(flops, max_depth=1))
471
+ # print(time.time()-s)
472
+ print(model(torch.rand(1, 3, num_frames, 224, 224), False).shape)
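For reference, a quick sanity check on the expected output shape, assuming the script above runs (i.e. the pretrained path resolves): with 224x224 inputs, 16x16 patches, tubelet_size=1 and num_frames=4, each frame yields 14*14 = 196 tokens, and the encoder output is reshaped to [B, T, L, C]:

out = model(torch.rand(1, 3, num_frames, 224, 224), False)
assert out.shape == (1, 4, 14 * 14, 768)  # torch.Size([1, 4, 196, 768])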