antonbaumann committed on
Commit
d91e16d
·
verified ·
1 Parent(s): 8d909d1

Upload folder using huggingface_hub

Browse files
config.json ADDED
@@ -0,0 +1,26 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "_name_or_path": "laion/CLIP-ViT-B-32-laion2B-s34B-b79K",
3
+ "architectures": [
4
+ "BayesVLMModel"
5
+ ],
6
+ "auto_map": {
7
+ "AutoModel": "modeling_bayesvlm_clip.BayesVLMModel",
8
+ "AutoProcessor": "transformers.CLIPProcessor"
9
+ },
10
+ "initializer_factor": 1.0,
11
+ "logit_scale_init_value": 2.6592,
12
+ "model_type": "clip",
13
+ "projection_dim": 512,
14
+ "text_config": {
15
+ "dropout": 0.0,
16
+ "hidden_act": "gelu",
17
+ "model_type": "clip_text_model"
18
+ },
19
+ "torch_dtype": "float32",
20
+ "transformers_version": "4.40.2",
21
+ "vision_config": {
22
+ "dropout": 0.0,
23
+ "hidden_act": "gelu",
24
+ "model_type": "clip_vision_model"
25
+ }
26
+ }
merges.txt ADDED
The diff for this file is too large to render. See raw diff
 
modeling_bayesvlm_clip.py ADDED
@@ -0,0 +1,495 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ from dataclasses import dataclass
4
+ from typing import Optional, Tuple
5
+
6
+ import torch
7
+ from transformers import CLIPModel, CLIPTextModelWithProjection, CLIPVisionModelWithProjection
8
+ from transformers.modeling_outputs import ModelOutput
9
+
10
+
11
+ def _as_optional_tensor(tensor: torch.Tensor | None) -> torch.Tensor | None:
12
+ return tensor if tensor is not None else None
13
+
14
+
15
+ def _diag_cov(
16
+ activations: torch.Tensor,
17
+ a_inv: torch.Tensor,
18
+ b_diag: torch.Tensor,
19
+ add_bias: bool,
20
+ ) -> torch.Tensor | None:
21
+ if a_inv.numel() == 0 or b_diag.numel() == 0:
22
+ return None
23
+
24
+ if add_bias:
25
+ ones = torch.ones_like(activations[:, :1])
26
+ activations = torch.cat([activations, ones], dim=-1)
27
+
28
+ quad = torch.einsum("ij,jk,ik->i", activations, a_inv, activations)[:, None]
29
+ return quad * b_diag
30
+
31
+
32
+ def _std_from_var(var: torch.Tensor | None) -> torch.Tensor | None:
33
+ if var is None:
34
+ return None
35
+ return torch.sqrt(var)
36
+
37
+ def _get_output(outputs, name: str, index: int):
38
+ if hasattr(outputs, name):
39
+ return getattr(outputs, name)
40
+ if isinstance(outputs, (tuple, list)) and len(outputs) > index:
41
+ return outputs[index]
42
+ return None
43
+
44
def _normalize_mean_and_var(
    mean: torch.Tensor,
    var: torch.Tensor,
    eps: float = 1e-6,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """L2-normalize *mean* and propagate its diagonal variance through the map.

    Returns the unit-norm mean together with a first-order (delta-method)
    approximation of the variance of ``y = x / ||x||`` given per-dimension
    variances of ``x`` (off-diagonal input covariance is assumed zero).

    Args:
        mean: embedding means; the last dimension is the feature dimension.
        var: per-dimension variances, same shape as ``mean``.
        eps: floor on the squared norm to avoid division by zero.
    """
    # Squared norm floored at eps so a zero vector does not divide by zero.
    r2 = (mean**2).sum(dim=-1, keepdim=True).clamp_min(eps)
    r = torch.sqrt(r2)
    normalized = mean / r

    # Delta-method approximation with diagonal covariance.
    y2 = normalized**2
    sum_y2v = (y2 * var).sum(dim=-1, keepdim=True)
    norm_var = (var - 2 * y2 * var + y2 * sum_y2v) / r2
    # Clamp: the first-order expansion can produce tiny negative values.
    norm_var = norm_var.clamp_min(0)
    return normalized, norm_var
59
+
60
+
61
@dataclass
class BayesVLMEmbeddingOutput(ModelOutput):
    """Projected embedding together with its diagonal posterior uncertainty."""

    # Posterior mean of the projected embedding.
    mean: torch.FloatTensor | None = None
    # Per-dimension variance (same shape as ``mean``).
    var: torch.FloatTensor | None = None
    # Elementwise square root of ``var``.
    std: torch.FloatTensor | None = None
66
+
67
+
68
@dataclass
class BayesVLMTextModelOutput(ModelOutput):
    """Text-tower output extended with embedding uncertainty estimates."""

    # Projected text embedding (mean).
    text_embeds: torch.FloatTensor | None = None
    # Per-dimension variance / std of ``text_embeds``.
    text_embeds_var: torch.FloatTensor | None = None
    text_embeds_std: torch.FloatTensor | None = None
    # Pass-throughs from the underlying text encoder.
    last_hidden_state: torch.FloatTensor | None = None
    hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
    attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
76
+
77
+
78
@dataclass
class BayesVLMVisionModelOutput(ModelOutput):
    """Vision-tower output extended with embedding uncertainty estimates."""

    # Projected image embedding (mean).
    image_embeds: torch.FloatTensor | None = None
    # Per-dimension variance / std of ``image_embeds``.
    image_embeds_var: torch.FloatTensor | None = None
    image_embeds_std: torch.FloatTensor | None = None
    # Pass-throughs from the underlying vision encoder.
    last_hidden_state: torch.FloatTensor | None = None
    hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
    attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
86
+
87
+
88
@dataclass
class BayesVLMOutput(ModelOutput):
    """Full dual-encoder output: CLIP fields plus uncertainty estimates."""

    # Symmetric contrastive loss, present only when ``return_loss`` is set.
    loss: torch.FloatTensor | None = None
    # Expected image-text logits and their transposes.
    logits_per_image: torch.FloatTensor | None = None
    logits_per_text: torch.FloatTensor | None = None
    # Variance / std of the logits (zeros when no covariance is configured).
    logits_per_image_var: torch.FloatTensor | None = None
    logits_per_text_var: torch.FloatTensor | None = None
    logits_per_image_std: torch.FloatTensor | None = None
    logits_per_text_std: torch.FloatTensor | None = None
    # Normalized embeddings with their propagated variances / stds.
    text_embeds: torch.FloatTensor | None = None
    image_embeds: torch.FloatTensor | None = None
    text_embeds_var: torch.FloatTensor | None = None
    image_embeds_var: torch.FloatTensor | None = None
    text_embeds_std: torch.FloatTensor | None = None
    image_embeds_std: torch.FloatTensor | None = None
    # Raw encoder outputs for downstream inspection.
    text_model_output: Optional[ModelOutput] = None
    vision_model_output: Optional[ModelOutput] = None
105
+
106
+
107
class BayesVLMTextModel(CLIPTextModelWithProjection):
    """CLIP text encoder with projection, augmented with a Kronecker-factored
    posterior over the projection layer to produce embedding variances."""

    def __init__(self, config):
        super().__init__(config)
        hidden_dim = int(config.hidden_size)
        proj_dim = int(config.projection_dim)
        # Placeholder covariance factors; populated via ``set_covariance``.
        self.register_buffer("a_inv", torch.zeros(hidden_dim, hidden_dim))
        self.register_buffer("b_diag", torch.zeros(proj_dim))

    def set_covariance(self, a_inv: torch.Tensor, b_inv: torch.Tensor) -> None:
        """Install the posterior factors; only the diagonal of B is kept."""
        self.a_inv = a_inv
        self.b_diag = torch.diagonal(b_inv)

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ):
        if return_dict is None:
            return_dict = self.config.use_return_dict

        # Tuple-style outputs keep the stock HF behaviour (no variance fields).
        if not return_dict:
            return super().forward(
                input_ids=input_ids,
                attention_mask=attention_mask,
                position_ids=position_ids,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                return_dict=return_dict,
            )

        encoder_out = self.text_model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
        )
        pooled = _get_output(encoder_out, "pooler_output", 1)
        embeds = self.text_projection(pooled)

        variance = _diag_cov(
            pooled,
            self.a_inv,
            self.b_diag,
            add_bias=self.text_projection.bias is not None,
        )
        if variance is None:
            # No covariance configured: report zero uncertainty.
            variance = torch.zeros_like(embeds)

        return BayesVLMTextModelOutput(
            text_embeds=embeds,
            text_embeds_var=variance,
            text_embeds_std=_std_from_var(variance),
            last_hidden_state=_get_output(encoder_out, "last_hidden_state", 0),
            hidden_states=_get_output(encoder_out, "hidden_states", 2),
            attentions=_get_output(encoder_out, "attentions", 3),
        )
171
+
172
+
173
class BayesVLMVisionModel(CLIPVisionModelWithProjection):
    """CLIP vision encoder with projection, augmented with a Kronecker-factored
    posterior over the projection layer to produce embedding variances."""

    def __init__(self, config):
        super().__init__(config)
        hidden_dim = int(config.hidden_size)
        proj_dim = int(config.projection_dim)
        # Placeholder covariance factors; populated via ``set_covariance``.
        self.register_buffer("a_inv", torch.zeros(hidden_dim, hidden_dim))
        self.register_buffer("b_diag", torch.zeros(proj_dim))

    def set_covariance(self, a_inv: torch.Tensor, b_inv: torch.Tensor) -> None:
        """Install the posterior factors; only the diagonal of B is kept."""
        self.a_inv = a_inv
        self.b_diag = torch.diagonal(b_inv)

    def forward(
        self,
        pixel_values: Optional[torch.FloatTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ):
        if return_dict is None:
            return_dict = self.config.use_return_dict

        # Tuple-style outputs keep the stock HF behaviour (no variance fields).
        if not return_dict:
            return super().forward(
                pixel_values=pixel_values,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                return_dict=return_dict,
            )

        encoder_out = self.vision_model(
            pixel_values=pixel_values,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
        )
        pooled = _get_output(encoder_out, "pooler_output", 1)
        embeds = self.visual_projection(pooled)

        variance = _diag_cov(
            pooled,
            self.a_inv,
            self.b_diag,
            add_bias=self.visual_projection.bias is not None,
        )
        if variance is None:
            # No covariance configured: report zero uncertainty.
            variance = torch.zeros_like(embeds)

        return BayesVLMVisionModelOutput(
            image_embeds=embeds,
            image_embeds_var=variance,
            image_embeds_std=_std_from_var(variance),
            last_hidden_state=_get_output(encoder_out, "last_hidden_state", 0),
            hidden_states=_get_output(encoder_out, "hidden_states", 2),
            attentions=_get_output(encoder_out, "attentions", 3),
        )
231
+
232
+
233
class BayesVLMModel(CLIPModel):
    """CLIP dual encoder with projection-layer uncertainty (BayesVLM).

    Holds Kronecker-factored posterior covariance factors (``A^{-1}`` on the
    input side, ``diag(B)`` on the output side) for the text and vision
    projection heads. With ``return_dict`` enabled, ``forward`` returns
    expected image-text logits plus variance/std estimates for the logits and
    for both (L2-normalized) embeddings. All factors default to zeros, which
    yields zero variances until ``set_covariances`` is called.
    """

    def __init__(self, config):
        super().__init__(config)
        text_hidden = int(config.text_config.hidden_size)
        vision_hidden = int(config.vision_config.hidden_size)
        proj = int(config.projection_dim)
        # Buffers (not parameters): follow the model's device/dtype and
        # state_dict but are never trained. Zeros = "no covariance set".
        self.register_buffer("text_a_inv", torch.zeros(text_hidden, text_hidden))
        self.register_buffer("text_b_diag", torch.zeros(proj))
        self.register_buffer("image_a_inv", torch.zeros(vision_hidden, vision_hidden))
        self.register_buffer("image_b_diag", torch.zeros(proj))

    def set_covariances(
        self,
        image_a_inv: torch.Tensor,
        image_b_inv: torch.Tensor,
        text_a_inv: torch.Tensor,
        text_b_inv: torch.Tensor,
    ) -> None:
        """Install posterior factors; only the diagonal of each B is kept."""
        self.image_a_inv = image_a_inv
        self.image_b_diag = torch.diagonal(image_b_inv)
        self.text_a_inv = text_a_inv
        self.text_b_diag = torch.diagonal(text_b_inv)

    def _expected_logits_and_var(
        self,
        image_embeds: torch.Tensor,
        text_embeds: torch.Tensor,
        image_acts: torch.Tensor,
        text_acts: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor | None]:
        """Expected image-text logits and (optionally) their variance.

        Normalizes in expectation using ``E[||e||^2] = ||mu||^2 + sum(var)``
        and propagates the diagonal covariances to a first-order variance of
        the cosine logits. Falls back to deterministic CLIP logits (variance
        ``None``) when either covariance factor is empty.
        """
        scale = self.logit_scale.exp()

        if self.image_a_inv.numel() == 0 or self.text_a_inv.numel() == 0:
            # Deterministic CLIP path: plain cosine similarity.
            image_norm = image_embeds / image_embeds.norm(p=2, dim=-1, keepdim=True)
            text_norm = text_embeds / text_embeds.norm(p=2, dim=-1, keepdim=True)
            logits = image_norm @ text_norm.t()
            logits = logits * scale
            return logits, None

        image_diag_cov = _diag_cov(
            image_acts,
            self.image_a_inv,
            self.image_b_diag,
            add_bias=self.visual_projection.bias is not None,
        )
        text_diag_cov = _diag_cov(
            text_acts,
            self.text_a_inv,
            self.text_b_diag,
            add_bias=self.text_projection.bias is not None,
        )

        # Per-dimension second moments E[e_k^2] = mu_k^2 + var_k.
        norm_image = image_embeds**2 + image_diag_cov
        norm_text = text_embeds**2 + text_diag_cov
        expect_norm_image = norm_image.sum(dim=-1, keepdim=True)
        expect_norm_text = norm_text.sum(dim=-1, keepdim=True)

        expected_similarity = torch.matmul(
            image_embeds / torch.sqrt(expect_norm_image),
            (text_embeds / torch.sqrt(expect_norm_text)).t(),
        )

        # Var[x·y] ≈ Σ_k E[x_k^2] var(y_k) + var(x_k) mu(y_k)^2
        # for independent x, y with diagonal covariances.
        term1 = torch.matmul(norm_image, text_diag_cov.t())
        term2 = torch.matmul(image_diag_cov, (text_embeds**2).t())
        variance_similarity = (term1 + term2) / (expect_norm_image * expect_norm_text.t())

        logits_mean = expected_similarity * scale
        logits_var = variance_similarity * (scale**2)
        return logits_mean, logits_var

    def get_text_features(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        return_std: bool = False,
    ):
        """Text embeddings; with ``return_dict`` or ``return_std`` also
        returns their diagonal variance and std as a
        ``BayesVLMEmbeddingOutput``."""
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        text_outputs = self.text_model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
        )
        pooled_output = _get_output(text_outputs, "pooler_output", 1)
        text_embeds = self.text_projection(pooled_output)

        text_var = _diag_cov(
            pooled_output,
            self.text_a_inv,
            self.text_b_diag,
            add_bias=self.text_projection.bias is not None,
        )
        if text_var is None:
            text_var = torch.zeros_like(text_embeds)
        text_std = _std_from_var(text_var)

        # Plain tensor for stock-CLIP compatibility when neither flag is set.
        if not return_dict and not return_std:
            return text_embeds

        return BayesVLMEmbeddingOutput(mean=text_embeds, var=text_var, std=text_std)

    def get_image_features(
        self,
        pixel_values: Optional[torch.FloatTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        return_std: bool = False,
    ):
        """Image embeddings; with ``return_dict`` or ``return_std`` also
        returns their diagonal variance and std as a
        ``BayesVLMEmbeddingOutput``."""
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        vision_outputs = self.vision_model(
            pixel_values=pixel_values,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
        )
        pooled_output = _get_output(vision_outputs, "pooler_output", 1)
        image_embeds = self.visual_projection(pooled_output)

        image_var = _diag_cov(
            pooled_output,
            self.image_a_inv,
            self.image_b_diag,
            add_bias=self.visual_projection.bias is not None,
        )
        if image_var is None:
            image_var = torch.zeros_like(image_embeds)
        image_std = _std_from_var(image_var)

        # Plain tensor for stock-CLIP compatibility when neither flag is set.
        if not return_dict and not return_std:
            return image_embeds

        return BayesVLMEmbeddingOutput(mean=image_embeds, var=image_var, std=image_std)

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        pixel_values: Optional[torch.FloatTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        return_loss: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ):
        """Run both towers and return a ``BayesVLMOutput`` with uncertainties.

        With ``return_dict=False`` this defers entirely to the stock
        ``CLIPModel.forward`` (tuple output, no uncertainty fields).
        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if not return_dict:
            return super().forward(
                input_ids=input_ids,
                pixel_values=pixel_values,
                attention_mask=attention_mask,
                position_ids=position_ids,
                return_loss=return_loss,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                return_dict=return_dict,
            )

        text_outputs = self.text_model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
        )
        vision_outputs = self.vision_model(
            pixel_values=pixel_values,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
        )

        text_pooled = _get_output(text_outputs, "pooler_output", 1)
        image_pooled = _get_output(vision_outputs, "pooler_output", 1)

        text_embeds = self.text_projection(text_pooled)
        image_embeds = self.visual_projection(image_pooled)

        text_var = _diag_cov(
            text_pooled,
            self.text_a_inv,
            self.text_b_diag,
            add_bias=self.text_projection.bias is not None,
        )
        image_var = _diag_cov(
            image_pooled,
            self.image_a_inv,
            self.image_b_diag,
            add_bias=self.visual_projection.bias is not None,
        )
        if text_var is None:
            text_var = torch.zeros_like(text_embeds)
        if image_var is None:
            image_var = torch.zeros_like(image_embeds)

        # Logits are computed from the raw embeddings: normalization in
        # expectation happens inside _expected_logits_and_var.
        logits_mean, logits_var = self._expected_logits_and_var(
            image_embeds,
            text_embeds,
            image_pooled,
            text_pooled,
        )

        # Report embeddings on the unit sphere with variances propagated
        # through the normalization. (Fix: stds were previously computed from
        # the unnormalized variances as well and immediately discarded.)
        text_embeds, text_var = _normalize_mean_and_var(text_embeds, text_var)
        image_embeds, image_var = _normalize_mean_and_var(image_embeds, image_var)
        text_std = _std_from_var(text_var)
        image_std = _std_from_var(image_var)

        logits_per_image = logits_mean
        logits_per_text = logits_mean.t() if logits_mean is not None else None

        if logits_var is None and logits_mean is not None:
            logits_var = torch.zeros_like(logits_mean)
        logits_per_image_var = logits_var
        logits_per_text_var = logits_var.t() if logits_var is not None else None

        logits_per_image_std = _std_from_var(logits_per_image_var)
        logits_per_text_std = _std_from_var(logits_per_text_var)

        loss = None
        if return_loss and logits_per_image is not None and logits_per_text is not None:
            # Symmetric InfoNCE over the in-batch image/text pairs.
            labels = torch.arange(logits_per_image.shape[0], device=logits_per_image.device)
            loss_i = torch.nn.functional.cross_entropy(logits_per_image, labels)
            loss_t = torch.nn.functional.cross_entropy(logits_per_text, labels)
            loss = (loss_i + loss_t) / 2

        return BayesVLMOutput(
            loss=loss,
            logits_per_image=logits_per_image,
            logits_per_text=logits_per_text,
            logits_per_image_var=logits_per_image_var,
            logits_per_text_var=logits_per_text_var,
            logits_per_image_std=logits_per_image_std,
            logits_per_text_std=logits_per_text_std,
            text_embeds=text_embeds,
            image_embeds=image_embeds,
            text_embeds_var=text_var,
            image_embeds_var=image_var,
            text_embeds_std=text_std,
            image_embeds_std=image_std,
            text_model_output=text_outputs,
            vision_model_output=vision_outputs,
        )
preprocessor_config.json ADDED
@@ -0,0 +1,45 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "_valid_processor_keys": [
3
+ "images",
4
+ "do_resize",
5
+ "size",
6
+ "resample",
7
+ "do_center_crop",
8
+ "crop_size",
9
+ "do_rescale",
10
+ "rescale_factor",
11
+ "do_normalize",
12
+ "image_mean",
13
+ "image_std",
14
+ "do_convert_rgb",
15
+ "return_tensors",
16
+ "data_format",
17
+ "input_data_format"
18
+ ],
19
+ "crop_size": {
20
+ "height": 224,
21
+ "width": 224
22
+ },
23
+ "do_center_crop": true,
24
+ "do_convert_rgb": true,
25
+ "do_normalize": true,
26
+ "do_rescale": true,
27
+ "do_resize": true,
28
+ "image_mean": [
29
+ 0.48145466,
30
+ 0.4578275,
31
+ 0.40821073
32
+ ],
33
+ "image_processor_type": "CLIPImageProcessor",
34
+ "image_std": [
35
+ 0.26862954,
36
+ 0.26130258,
37
+ 0.27577711
38
+ ],
39
+ "processor_class": "CLIPProcessor",
40
+ "resample": 3,
41
+ "rescale_factor": 0.00392156862745098,
42
+ "size": {
43
+ "shortest_edge": 224
44
+ }
45
+ }
pytorch_model.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:77f85aa5629f2cb0abe974fd5da165f1fcce6823b4088c0ce5df845c2e89e2a1
3
+ size 610746962
special_tokens_map.json ADDED
@@ -0,0 +1,30 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "bos_token": {
3
+ "content": "<|startoftext|>",
4
+ "lstrip": false,
5
+ "normalized": true,
6
+ "rstrip": false,
7
+ "single_word": false
8
+ },
9
+ "eos_token": {
10
+ "content": "<|endoftext|>",
11
+ "lstrip": false,
12
+ "normalized": false,
13
+ "rstrip": false,
14
+ "single_word": false
15
+ },
16
+ "pad_token": {
17
+ "content": "<|endoftext|>",
18
+ "lstrip": false,
19
+ "normalized": false,
20
+ "rstrip": false,
21
+ "single_word": false
22
+ },
23
+ "unk_token": {
24
+ "content": "<|endoftext|>",
25
+ "lstrip": false,
26
+ "normalized": false,
27
+ "rstrip": false,
28
+ "single_word": false
29
+ }
30
+ }
text/config.json ADDED
@@ -0,0 +1,28 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "architectures": [
3
+ "BayesVLMTextModel"
4
+ ],
5
+ "attention_dropout": 0.0,
6
+ "auto_map": {
7
+ "AutoModel": "modeling_bayesvlm_clip.BayesVLMTextModel",
8
+ "AutoProcessor": "transformers.CLIPProcessor"
9
+ },
10
+ "bos_token_id": 49406,
11
+ "dropout": 0.0,
12
+ "eos_token_id": 49407,
13
+ "hidden_act": "gelu",
14
+ "hidden_size": 512,
15
+ "initializer_factor": 1.0,
16
+ "initializer_range": 0.02,
17
+ "intermediate_size": 2048,
18
+ "layer_norm_eps": 1e-05,
19
+ "max_position_embeddings": 77,
20
+ "model_type": "clip_text_model",
21
+ "num_attention_heads": 8,
22
+ "num_hidden_layers": 12,
23
+ "pad_token_id": 1,
24
+ "projection_dim": 512,
25
+ "torch_dtype": "float32",
26
+ "transformers_version": "4.40.2",
27
+ "vocab_size": 49408
28
+ }
text/modeling_bayesvlm_clip.py ADDED
@@ -0,0 +1,495 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ from dataclasses import dataclass
4
+ from typing import Optional, Tuple
5
+
6
+ import torch
7
+ from transformers import CLIPModel, CLIPTextModelWithProjection, CLIPVisionModelWithProjection
8
+ from transformers.modeling_outputs import ModelOutput
9
+
10
+
11
+ def _as_optional_tensor(tensor: torch.Tensor | None) -> torch.Tensor | None:
12
+ return tensor if tensor is not None else None
13
+
14
+
15
+ def _diag_cov(
16
+ activations: torch.Tensor,
17
+ a_inv: torch.Tensor,
18
+ b_diag: torch.Tensor,
19
+ add_bias: bool,
20
+ ) -> torch.Tensor | None:
21
+ if a_inv.numel() == 0 or b_diag.numel() == 0:
22
+ return None
23
+
24
+ if add_bias:
25
+ ones = torch.ones_like(activations[:, :1])
26
+ activations = torch.cat([activations, ones], dim=-1)
27
+
28
+ quad = torch.einsum("ij,jk,ik->i", activations, a_inv, activations)[:, None]
29
+ return quad * b_diag
30
+
31
+
32
+ def _std_from_var(var: torch.Tensor | None) -> torch.Tensor | None:
33
+ if var is None:
34
+ return None
35
+ return torch.sqrt(var)
36
+
37
+ def _get_output(outputs, name: str, index: int):
38
+ if hasattr(outputs, name):
39
+ return getattr(outputs, name)
40
+ if isinstance(outputs, (tuple, list)) and len(outputs) > index:
41
+ return outputs[index]
42
+ return None
43
+
44
def _normalize_mean_and_var(
    mean: torch.Tensor,
    var: torch.Tensor,
    eps: float = 1e-6,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """L2-normalize *mean* and propagate its diagonal variance through the map.

    Returns the unit-norm mean together with a first-order (delta-method)
    approximation of the variance of ``y = x / ||x||`` given per-dimension
    variances of ``x`` (off-diagonal input covariance is assumed zero).

    Args:
        mean: embedding means; the last dimension is the feature dimension.
        var: per-dimension variances, same shape as ``mean``.
        eps: floor on the squared norm to avoid division by zero.
    """
    # Squared norm floored at eps so a zero vector does not divide by zero.
    r2 = (mean**2).sum(dim=-1, keepdim=True).clamp_min(eps)
    r = torch.sqrt(r2)
    normalized = mean / r

    # Delta-method approximation with diagonal covariance.
    y2 = normalized**2
    sum_y2v = (y2 * var).sum(dim=-1, keepdim=True)
    norm_var = (var - 2 * y2 * var + y2 * sum_y2v) / r2
    # Clamp: the first-order expansion can produce tiny negative values.
    norm_var = norm_var.clamp_min(0)
    return normalized, norm_var
59
+
60
+
61
@dataclass
class BayesVLMEmbeddingOutput(ModelOutput):
    """Projected embedding together with its diagonal posterior uncertainty."""

    # Posterior mean of the projected embedding.
    mean: torch.FloatTensor | None = None
    # Per-dimension variance (same shape as ``mean``).
    var: torch.FloatTensor | None = None
    # Elementwise square root of ``var``.
    std: torch.FloatTensor | None = None
66
+
67
+
68
@dataclass
class BayesVLMTextModelOutput(ModelOutput):
    """Text-tower output extended with embedding uncertainty estimates."""

    # Projected text embedding (mean).
    text_embeds: torch.FloatTensor | None = None
    # Per-dimension variance / std of ``text_embeds``.
    text_embeds_var: torch.FloatTensor | None = None
    text_embeds_std: torch.FloatTensor | None = None
    # Pass-throughs from the underlying text encoder.
    last_hidden_state: torch.FloatTensor | None = None
    hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
    attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
76
+
77
+
78
@dataclass
class BayesVLMVisionModelOutput(ModelOutput):
    """Vision-tower output extended with embedding uncertainty estimates."""

    # Projected image embedding (mean).
    image_embeds: torch.FloatTensor | None = None
    # Per-dimension variance / std of ``image_embeds``.
    image_embeds_var: torch.FloatTensor | None = None
    image_embeds_std: torch.FloatTensor | None = None
    # Pass-throughs from the underlying vision encoder.
    last_hidden_state: torch.FloatTensor | None = None
    hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
    attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
86
+
87
+
88
@dataclass
class BayesVLMOutput(ModelOutput):
    """Full dual-encoder output: CLIP fields plus uncertainty estimates."""

    # Symmetric contrastive loss, present only when ``return_loss`` is set.
    loss: torch.FloatTensor | None = None
    # Expected image-text logits and their transposes.
    logits_per_image: torch.FloatTensor | None = None
    logits_per_text: torch.FloatTensor | None = None
    # Variance / std of the logits (zeros when no covariance is configured).
    logits_per_image_var: torch.FloatTensor | None = None
    logits_per_text_var: torch.FloatTensor | None = None
    logits_per_image_std: torch.FloatTensor | None = None
    logits_per_text_std: torch.FloatTensor | None = None
    # Normalized embeddings with their propagated variances / stds.
    text_embeds: torch.FloatTensor | None = None
    image_embeds: torch.FloatTensor | None = None
    text_embeds_var: torch.FloatTensor | None = None
    image_embeds_var: torch.FloatTensor | None = None
    text_embeds_std: torch.FloatTensor | None = None
    image_embeds_std: torch.FloatTensor | None = None
    # Raw encoder outputs for downstream inspection.
    text_model_output: Optional[ModelOutput] = None
    vision_model_output: Optional[ModelOutput] = None
105
+
106
+
107
class BayesVLMTextModel(CLIPTextModelWithProjection):
    """CLIP text encoder with projection, augmented with a Kronecker-factored
    posterior over the projection layer to produce embedding variances."""

    def __init__(self, config):
        super().__init__(config)
        hidden_dim = int(config.hidden_size)
        proj_dim = int(config.projection_dim)
        # Placeholder covariance factors; populated via ``set_covariance``.
        self.register_buffer("a_inv", torch.zeros(hidden_dim, hidden_dim))
        self.register_buffer("b_diag", torch.zeros(proj_dim))

    def set_covariance(self, a_inv: torch.Tensor, b_inv: torch.Tensor) -> None:
        """Install the posterior factors; only the diagonal of B is kept."""
        self.a_inv = a_inv
        self.b_diag = torch.diagonal(b_inv)

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ):
        if return_dict is None:
            return_dict = self.config.use_return_dict

        # Tuple-style outputs keep the stock HF behaviour (no variance fields).
        if not return_dict:
            return super().forward(
                input_ids=input_ids,
                attention_mask=attention_mask,
                position_ids=position_ids,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                return_dict=return_dict,
            )

        encoder_out = self.text_model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
        )
        pooled = _get_output(encoder_out, "pooler_output", 1)
        embeds = self.text_projection(pooled)

        variance = _diag_cov(
            pooled,
            self.a_inv,
            self.b_diag,
            add_bias=self.text_projection.bias is not None,
        )
        if variance is None:
            # No covariance configured: report zero uncertainty.
            variance = torch.zeros_like(embeds)

        return BayesVLMTextModelOutput(
            text_embeds=embeds,
            text_embeds_var=variance,
            text_embeds_std=_std_from_var(variance),
            last_hidden_state=_get_output(encoder_out, "last_hidden_state", 0),
            hidden_states=_get_output(encoder_out, "hidden_states", 2),
            attentions=_get_output(encoder_out, "attentions", 3),
        )
171
+
172
+
173
class BayesVLMVisionModel(CLIPVisionModelWithProjection):
    """CLIP vision encoder with projection, augmented with a Kronecker-factored
    posterior over the projection layer to produce embedding variances."""

    def __init__(self, config):
        super().__init__(config)
        hidden_dim = int(config.hidden_size)
        proj_dim = int(config.projection_dim)
        # Placeholder covariance factors; populated via ``set_covariance``.
        self.register_buffer("a_inv", torch.zeros(hidden_dim, hidden_dim))
        self.register_buffer("b_diag", torch.zeros(proj_dim))

    def set_covariance(self, a_inv: torch.Tensor, b_inv: torch.Tensor) -> None:
        """Install the posterior factors; only the diagonal of B is kept."""
        self.a_inv = a_inv
        self.b_diag = torch.diagonal(b_inv)

    def forward(
        self,
        pixel_values: Optional[torch.FloatTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ):
        if return_dict is None:
            return_dict = self.config.use_return_dict

        # Tuple-style outputs keep the stock HF behaviour (no variance fields).
        if not return_dict:
            return super().forward(
                pixel_values=pixel_values,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                return_dict=return_dict,
            )

        encoder_out = self.vision_model(
            pixel_values=pixel_values,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
        )
        pooled = _get_output(encoder_out, "pooler_output", 1)
        embeds = self.visual_projection(pooled)

        variance = _diag_cov(
            pooled,
            self.a_inv,
            self.b_diag,
            add_bias=self.visual_projection.bias is not None,
        )
        if variance is None:
            # No covariance configured: report zero uncertainty.
            variance = torch.zeros_like(embeds)

        return BayesVLMVisionModelOutput(
            image_embeds=embeds,
            image_embeds_var=variance,
            image_embeds_std=_std_from_var(variance),
            last_hidden_state=_get_output(encoder_out, "last_hidden_state", 0),
            hidden_states=_get_output(encoder_out, "hidden_states", 2),
            attentions=_get_output(encoder_out, "attentions", 3),
        )
231
+
232
+
233
class BayesVLMModel(CLIPModel):
    """CLIP dual encoder extended with closed-form predictive uncertainty.

    Four zero-initialized buffers hold factorized posterior covariance factors
    for the two projection heads (``{text,image}_a_inv`` over pooled
    activations, ``{text,image}_b_diag`` for the output-side diagonal); they
    are populated via :meth:`set_covariances`.  With the zero placeholders
    every reported variance is zero.
    """

    def __init__(self, config):
        super().__init__(config)
        text_hidden = int(config.text_config.hidden_size)
        vision_hidden = int(config.vision_config.hidden_size)
        proj = int(config.projection_dim)
        # Zero placeholders registered as buffers so they move with the model
        # and serialize; real factors arrive through set_covariances().
        self.register_buffer("text_a_inv", torch.zeros(text_hidden, text_hidden))
        self.register_buffer("text_b_diag", torch.zeros(proj))
        self.register_buffer("image_a_inv", torch.zeros(vision_hidden, vision_hidden))
        self.register_buffer("image_b_diag", torch.zeros(proj))

    def set_covariances(
        self,
        image_a_inv: torch.Tensor,
        image_b_inv: torch.Tensor,
        text_a_inv: torch.Tensor,
        text_b_inv: torch.Tensor,
    ) -> None:
        """Install posterior covariance factors for both towers.

        Only the diagonal of each ``b_inv`` factor is retained.
        """
        self.image_a_inv = image_a_inv
        self.image_b_diag = torch.diagonal(image_b_inv)
        self.text_a_inv = text_a_inv
        self.text_b_diag = torch.diagonal(text_b_inv)

    def _expected_logits_and_var(
        self,
        image_embeds: torch.Tensor,
        text_embeds: torch.Tensor,
        image_acts: torch.Tensor,
        text_acts: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor | None]:
        """Return (mean, variance) of the scaled image-text similarity logits.

        Normalization uses E[||e||^2] = ||mu||^2 + tr(diag-cov) per embedding;
        the variance follows the product rule for independent coordinates with
        diagonal covariance.  Returns ``(logits, None)`` when either ``a_inv``
        buffer is empty (no covariance installed).
        """
        scale = self.logit_scale.exp()

        if self.image_a_inv.numel() == 0 or self.text_a_inv.numel() == 0:
            # Deterministic fallback: ordinary CLIP cosine-similarity logits.
            # NOTE(review): the buffers registered in __init__ are non-empty
            # zeros, so this branch only triggers if they were replaced by
            # empty tensors.
            image_norm = image_embeds / image_embeds.norm(p=2, dim=-1, keepdim=True)
            text_norm = text_embeds / text_embeds.norm(p=2, dim=-1, keepdim=True)
            logits = image_norm @ text_norm.t()
            logits = logits * scale
            return logits, None

        image_diag_cov = _diag_cov(
            image_acts,
            self.image_a_inv,
            self.image_b_diag,
            add_bias=self.visual_projection.bias is not None,
        )
        text_diag_cov = _diag_cov(
            text_acts,
            self.text_a_inv,
            self.text_b_diag,
            add_bias=self.text_projection.bias is not None,
        )

        # Per-coordinate second moments E[e_i^2] and per-embedding E[||e||^2].
        norm_image = image_embeds**2 + image_diag_cov
        norm_text = text_embeds**2 + text_diag_cov
        expect_norm_image = norm_image.sum(dim=-1, keepdim=True)
        expect_norm_text = norm_text.sum(dim=-1, keepdim=True)

        expected_similarity = torch.matmul(
            image_embeds / torch.sqrt(expect_norm_image),
            (text_embeds / torch.sqrt(expect_norm_text)).t(),
        )

        # Var of a dot product of independent coordinates:
        # sum_i E[x_i^2] Var[y_i] + Var[x_i] E[y_i]^2.
        term1 = torch.matmul(norm_image, text_diag_cov.t())
        term2 = torch.matmul(image_diag_cov, (text_embeds**2).t())
        variance_similarity = (term1 + term2) / (expect_norm_image * expect_norm_text.t())

        logits_mean = expected_similarity * scale
        logits_var = variance_similarity * (scale**2)
        return logits_mean, logits_var

    def get_text_features(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        return_std: bool = False,
    ):
        """Return projected text embeddings.

        Returns the bare embedding tensor only when both ``return_dict`` and
        ``return_std`` are falsy; otherwise a :class:`BayesVLMEmbeddingOutput`
        carrying mean, variance, and std.
        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        text_outputs = self.text_model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
        )
        pooled_output = _get_output(text_outputs, "pooler_output", 1)
        text_embeds = self.text_projection(pooled_output)

        text_var = _diag_cov(
            pooled_output,
            self.text_a_inv,
            self.text_b_diag,
            add_bias=self.text_projection.bias is not None,
        )
        if text_var is None:
            # Empty covariance buffers => zero uncertainty.
            text_var = torch.zeros_like(text_embeds)
        text_std = _std_from_var(text_var)

        if not return_dict and not return_std:
            return text_embeds

        return BayesVLMEmbeddingOutput(mean=text_embeds, var=text_var, std=text_std)

    def get_image_features(
        self,
        pixel_values: Optional[torch.FloatTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        return_std: bool = False,
    ):
        """Return projected image embeddings; same return convention as
        :meth:`get_text_features`.
        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        vision_outputs = self.vision_model(
            pixel_values=pixel_values,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
        )
        pooled_output = _get_output(vision_outputs, "pooler_output", 1)
        image_embeds = self.visual_projection(pooled_output)

        image_var = _diag_cov(
            pooled_output,
            self.image_a_inv,
            self.image_b_diag,
            add_bias=self.visual_projection.bias is not None,
        )
        if image_var is None:
            # Empty covariance buffers => zero uncertainty.
            image_var = torch.zeros_like(image_embeds)
        image_std = _std_from_var(image_var)

        if not return_dict and not return_std:
            return image_embeds

        return BayesVLMEmbeddingOutput(mean=image_embeds, var=image_var, std=image_std)

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        pixel_values: Optional[torch.FloatTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        return_loss: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ):
        """Full CLIP forward pass returning a :class:`BayesVLMOutput` with
        expected logits plus their variance/std and normalized embeddings with
        propagated variance.  The tuple path defers entirely to the parent.
        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if not return_dict:
            # Tuple output requested: no uncertainty fields, use stock CLIP.
            return super().forward(
                input_ids=input_ids,
                pixel_values=pixel_values,
                attention_mask=attention_mask,
                position_ids=position_ids,
                return_loss=return_loss,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                return_dict=return_dict,
            )

        text_outputs = self.text_model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
        )
        vision_outputs = self.vision_model(
            pixel_values=pixel_values,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
        )

        text_pooled = _get_output(text_outputs, "pooler_output", 1)
        image_pooled = _get_output(vision_outputs, "pooler_output", 1)

        text_embeds = self.text_projection(text_pooled)
        image_embeds = self.visual_projection(image_pooled)

        text_var = _diag_cov(
            text_pooled,
            self.text_a_inv,
            self.text_b_diag,
            add_bias=self.text_projection.bias is not None,
        )
        image_var = _diag_cov(
            image_pooled,
            self.image_a_inv,
            self.image_b_diag,
            add_bias=self.visual_projection.bias is not None,
        )
        if text_var is None:
            text_var = torch.zeros_like(text_embeds)
        if image_var is None:
            image_var = torch.zeros_like(image_embeds)

        # NOTE(review): these two std values are dead — they are recomputed
        # below after normalization before first use.
        text_std = _std_from_var(text_var)
        image_std = _std_from_var(image_var)

        # Logits are computed from the *unnormalized* embeddings; the closed
        # form handles normalization via expected squared norms.
        logits_mean, logits_var = self._expected_logits_and_var(
            image_embeds,
            text_embeds,
            image_pooled,
            text_pooled,
        )

        # Normalize the reported embeddings and propagate the variance
        # through the normalization (delta method).
        text_embeds, text_var = _normalize_mean_and_var(text_embeds, text_var)
        image_embeds, image_var = _normalize_mean_and_var(image_embeds, image_var)
        text_std = _std_from_var(text_var)
        image_std = _std_from_var(image_var)

        logits_per_image = logits_mean
        logits_per_text = logits_mean.t() if logits_mean is not None else None

        if logits_var is None and logits_mean is not None:
            logits_var = torch.zeros_like(logits_mean)
        logits_per_image_var = _as_optional_tensor(logits_var)
        logits_per_text_var = logits_var.t() if logits_var is not None else None

        logits_per_image_std = _std_from_var(logits_per_image_var)
        logits_per_text_std = _std_from_var(logits_per_text_var)

        loss = None
        if return_loss and logits_per_image is not None and logits_per_text is not None:
            # Symmetric InfoNCE loss over the in-batch similarity matrix.
            labels = torch.arange(logits_per_image.shape[0], device=logits_per_image.device)
            loss_i = torch.nn.functional.cross_entropy(logits_per_image, labels)
            loss_t = torch.nn.functional.cross_entropy(logits_per_text, labels)
            loss = (loss_i + loss_t) / 2

        return BayesVLMOutput(
            loss=loss,
            logits_per_image=logits_per_image,
            logits_per_text=logits_per_text,
            logits_per_image_var=logits_per_image_var,
            logits_per_text_var=logits_per_text_var,
            logits_per_image_std=logits_per_image_std,
            logits_per_text_std=logits_per_text_std,
            text_embeds=text_embeds,
            image_embeds=image_embeds,
            text_embeds_var=text_var,
            image_embeds_var=image_var,
            text_embeds_std=text_std,
            image_embeds_std=image_std,
            text_model_output=text_outputs,
            vision_model_output=vision_outputs,
        )
text/pytorch_model.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:ff0e694ffe9a4f158d729c7d81291ef05c8db29960e7b4baab382845fefe5245
3
+ size 255875046
tokenizer.json ADDED
The diff for this file is too large to render. See raw diff
 
tokenizer_config.json ADDED
@@ -0,0 +1,31 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "add_prefix_space": false,
3
+ "added_tokens_decoder": {
4
+ "49406": {
5
+ "content": "<|startoftext|>",
6
+ "lstrip": false,
7
+ "normalized": true,
8
+ "rstrip": false,
9
+ "single_word": false,
10
+ "special": true
11
+ },
12
+ "49407": {
13
+ "content": "<|endoftext|>",
14
+ "lstrip": false,
15
+ "normalized": false,
16
+ "rstrip": false,
17
+ "single_word": false,
18
+ "special": true
19
+ }
20
+ },
21
+ "bos_token": "<|startoftext|>",
22
+ "clean_up_tokenization_spaces": true,
23
+ "do_lower_case": true,
24
+ "eos_token": "<|endoftext|>",
25
+ "errors": "replace",
26
+ "model_max_length": 77,
27
+ "pad_token": "<|endoftext|>",
28
+ "processor_class": "CLIPProcessor",
29
+ "tokenizer_class": "CLIPTokenizer",
30
+ "unk_token": "<|endoftext|>"
31
+ }
vision/config.json ADDED
@@ -0,0 +1,26 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "architectures": [
3
+ "BayesVLMVisionModel"
4
+ ],
5
+ "attention_dropout": 0.0,
6
+ "auto_map": {
7
+ "AutoModel": "modeling_bayesvlm_clip.BayesVLMVisionModel",
8
+ "AutoProcessor": "transformers.CLIPProcessor"
9
+ },
10
+ "dropout": 0.0,
11
+ "hidden_act": "gelu",
12
+ "hidden_size": 768,
13
+ "image_size": 224,
14
+ "initializer_factor": 1.0,
15
+ "initializer_range": 0.02,
16
+ "intermediate_size": 3072,
17
+ "layer_norm_eps": 1e-05,
18
+ "model_type": "clip_vision_model",
19
+ "num_attention_heads": 12,
20
+ "num_channels": 3,
21
+ "num_hidden_layers": 12,
22
+ "patch_size": 32,
23
+ "projection_dim": 512,
24
+ "torch_dtype": "float32",
25
+ "transformers_version": "4.40.2"
26
+ }
vision/modeling_bayesvlm_clip.py ADDED
@@ -0,0 +1,495 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ from dataclasses import dataclass
4
+ from typing import Optional, Tuple
5
+
6
+ import torch
7
+ from transformers import CLIPModel, CLIPTextModelWithProjection, CLIPVisionModelWithProjection
8
+ from transformers.modeling_outputs import ModelOutput
9
+
10
+
11
+ def _as_optional_tensor(tensor: torch.Tensor | None) -> torch.Tensor | None:
12
+ return tensor if tensor is not None else None
13
+
14
+
15
+ def _diag_cov(
16
+ activations: torch.Tensor,
17
+ a_inv: torch.Tensor,
18
+ b_diag: torch.Tensor,
19
+ add_bias: bool,
20
+ ) -> torch.Tensor | None:
21
+ if a_inv.numel() == 0 or b_diag.numel() == 0:
22
+ return None
23
+
24
+ if add_bias:
25
+ ones = torch.ones_like(activations[:, :1])
26
+ activations = torch.cat([activations, ones], dim=-1)
27
+
28
+ quad = torch.einsum("ij,jk,ik->i", activations, a_inv, activations)[:, None]
29
+ return quad * b_diag
30
+
31
+
32
+ def _std_from_var(var: torch.Tensor | None) -> torch.Tensor | None:
33
+ if var is None:
34
+ return None
35
+ return torch.sqrt(var)
36
+
37
+ def _get_output(outputs, name: str, index: int):
38
+ if hasattr(outputs, name):
39
+ return getattr(outputs, name)
40
+ if isinstance(outputs, (tuple, list)) and len(outputs) > index:
41
+ return outputs[index]
42
+ return None
43
+
44
+ def _normalize_mean_and_var(
45
+ mean: torch.Tensor,
46
+ var: torch.Tensor,
47
+ eps: float = 1e-6,
48
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
49
+ r2 = (mean**2).sum(dim=-1, keepdim=True).clamp_min(eps)
50
+ r = torch.sqrt(r2)
51
+ normalized = mean / r
52
+
53
+ # Delta-method approximation with diagonal covariance.
54
+ y2 = normalized**2
55
+ sum_y2v = (y2 * var).sum(dim=-1, keepdim=True)
56
+ norm_var = (var - 2 * y2 * var + y2 * sum_y2v) / r2
57
+ norm_var = norm_var.clamp_min(0)
58
+ return normalized, norm_var
59
+
60
+
61
@dataclass
class BayesVLMEmbeddingOutput(ModelOutput):
    """Projected embedding together with its diagonal predictive uncertainty."""

    # Embedding mean — presumably (batch, projection_dim); confirm at call site.
    mean: torch.FloatTensor | None = None
    # Diagonal of the predictive covariance, same shape as `mean`.
    var: torch.FloatTensor | None = None
    # Elementwise sqrt of `var`.
    std: torch.FloatTensor | None = None
66
+
67
+
68
@dataclass
class BayesVLMTextModelOutput(ModelOutput):
    """Text-tower output: CLIP fields plus embedding variance/std."""

    # Projected text embeddings (mean).
    text_embeds: torch.FloatTensor | None = None
    # Diagonal predictive variance of `text_embeds`.
    text_embeds_var: torch.FloatTensor | None = None
    # Elementwise sqrt of `text_embeds_var`.
    text_embeds_std: torch.FloatTensor | None = None
    # Standard encoder outputs, passed through from the text model.
    last_hidden_state: torch.FloatTensor | None = None
    hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
    attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
76
+
77
+
78
@dataclass
class BayesVLMVisionModelOutput(ModelOutput):
    """Vision-tower output: CLIP fields plus embedding variance/std."""

    # Projected image embeddings (mean).
    image_embeds: torch.FloatTensor | None = None
    # Diagonal predictive variance of `image_embeds`.
    image_embeds_var: torch.FloatTensor | None = None
    # Elementwise sqrt of `image_embeds_var`.
    image_embeds_std: torch.FloatTensor | None = None
    # Standard encoder outputs, passed through from the vision model.
    last_hidden_state: torch.FloatTensor | None = None
    hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
    attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
86
+
87
+
88
@dataclass
class BayesVLMOutput(ModelOutput):
    """Full dual-encoder output: CLIP fields plus uncertainty companions."""

    # Symmetric contrastive loss; only set when requested with labels available.
    loss: torch.FloatTensor | None = None
    # Expected similarity logits (image rows x text columns) and the transpose.
    logits_per_image: torch.FloatTensor | None = None
    logits_per_text: torch.FloatTensor | None = None
    # Variance / std of the logits, same shapes as the corresponding means.
    logits_per_image_var: torch.FloatTensor | None = None
    logits_per_text_var: torch.FloatTensor | None = None
    logits_per_image_std: torch.FloatTensor | None = None
    logits_per_text_std: torch.FloatTensor | None = None
    # Normalized embedding means with delta-method propagated variance/std.
    text_embeds: torch.FloatTensor | None = None
    image_embeds: torch.FloatTensor | None = None
    text_embeds_var: torch.FloatTensor | None = None
    image_embeds_var: torch.FloatTensor | None = None
    text_embeds_std: torch.FloatTensor | None = None
    image_embeds_std: torch.FloatTensor | None = None
    # Raw sub-model outputs for downstream inspection.
    text_model_output: Optional[ModelOutput] = None
    vision_model_output: Optional[ModelOutput] = None
105
+
106
+
107
class BayesVLMTextModel(CLIPTextModelWithProjection):
    """CLIP text tower with projection that also reports a diagonal
    predictive variance for the projected text embeddings.

    The variance comes from a factorized (Kronecker-style) posterior over the
    projection head: ``a_inv`` acts on the pooled activations and ``b_diag``
    holds the diagonal of the output-side factor.  Both buffers are
    zero-initialized and are expected to be filled via :meth:`set_covariance`;
    with the zero placeholders the reported variance is simply zero.
    """

    def __init__(self, config):
        super().__init__(config)
        hidden = int(config.hidden_size)
        proj = int(config.projection_dim)
        # Zero placeholders registered as buffers so they move with the model
        # and serialize; real factors arrive through set_covariance().
        self.register_buffer("a_inv", torch.zeros(hidden, hidden))
        self.register_buffer("b_diag", torch.zeros(proj))

    def set_covariance(self, a_inv: torch.Tensor, b_inv: torch.Tensor) -> None:
        """Install the posterior covariance factors.

        Only the diagonal of ``b_inv`` is retained.
        """
        self.a_inv = a_inv
        self.b_diag = torch.diagonal(b_inv)

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ):
        """Encode text; when ``return_dict`` is true, return a
        :class:`BayesVLMTextModelOutput` that adds variance/std fields to the
        usual CLIP outputs.  The tuple path defers entirely to the parent.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if not return_dict:
            # Tuple output requested: no uncertainty fields, use stock CLIP.
            return super().forward(
                input_ids=input_ids,
                attention_mask=attention_mask,
                position_ids=position_ids,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                return_dict=return_dict,
            )

        text_outputs = self.text_model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
        )
        # _get_output tolerates both attribute-style and tuple-shaped outputs.
        pooled_output = _get_output(text_outputs, "pooler_output", 1)
        last_hidden_state = _get_output(text_outputs, "last_hidden_state", 0)
        hidden_states = _get_output(text_outputs, "hidden_states", 2)
        attentions = _get_output(text_outputs, "attentions", 3)
        text_embeds = self.text_projection(pooled_output)

        text_var = _diag_cov(
            pooled_output,
            self.a_inv,
            self.b_diag,
            # CLIP projection heads are typically bias-free; when a bias exists,
            # _diag_cov appends a constant-one column to the activations.
            add_bias=self.text_projection.bias is not None,
        )
        if text_var is None:
            # Empty covariance buffers => report zero uncertainty.
            text_var = torch.zeros_like(text_embeds)
        text_std = _std_from_var(text_var)

        return BayesVLMTextModelOutput(
            text_embeds=text_embeds,
            text_embeds_var=text_var,
            text_embeds_std=text_std,
            last_hidden_state=last_hidden_state,
            hidden_states=hidden_states,
            attentions=attentions,
        )
171
+
172
+
173
class BayesVLMVisionModel(CLIPVisionModelWithProjection):
    """CLIP vision tower with projection that also reports a diagonal
    predictive variance for the projected image embeddings.

    The variance comes from a factorized (Kronecker-style) posterior over the
    projection head: ``a_inv`` acts on the pooled activations and ``b_diag``
    holds the diagonal of the output-side factor.  Both buffers are
    zero-initialized and are expected to be filled via :meth:`set_covariance`;
    with the zero placeholders the reported variance is simply zero.
    """

    def __init__(self, config):
        super().__init__(config)
        hidden = int(config.hidden_size)
        proj = int(config.projection_dim)
        # Zero placeholders registered as buffers so they move with the model
        # and serialize; real factors arrive through set_covariance().
        self.register_buffer("a_inv", torch.zeros(hidden, hidden))
        self.register_buffer("b_diag", torch.zeros(proj))

    def set_covariance(self, a_inv: torch.Tensor, b_inv: torch.Tensor) -> None:
        """Install the posterior covariance factors.

        Only the diagonal of ``b_inv`` is retained.
        """
        self.a_inv = a_inv
        self.b_diag = torch.diagonal(b_inv)

    def forward(
        self,
        pixel_values: Optional[torch.FloatTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ):
        """Encode images; when ``return_dict`` is true, return a
        :class:`BayesVLMVisionModelOutput` that adds variance/std fields to
        the usual CLIP outputs.  The tuple path defers entirely to the parent.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if not return_dict:
            # Tuple output requested: no uncertainty fields, use stock CLIP.
            return super().forward(
                pixel_values=pixel_values,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                return_dict=return_dict,
            )

        vision_outputs = self.vision_model(
            pixel_values=pixel_values,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
        )
        # _get_output tolerates both attribute-style and tuple-shaped outputs.
        pooled_output = _get_output(vision_outputs, "pooler_output", 1)
        last_hidden_state = _get_output(vision_outputs, "last_hidden_state", 0)
        hidden_states = _get_output(vision_outputs, "hidden_states", 2)
        attentions = _get_output(vision_outputs, "attentions", 3)
        image_embeds = self.visual_projection(pooled_output)

        image_var = _diag_cov(
            pooled_output,
            self.a_inv,
            self.b_diag,
            # CLIP projection heads are typically bias-free; when a bias exists,
            # _diag_cov appends a constant-one column to the activations.
            add_bias=self.visual_projection.bias is not None,
        )
        if image_var is None:
            # Empty covariance buffers => report zero uncertainty.
            image_var = torch.zeros_like(image_embeds)
        image_std = _std_from_var(image_var)

        return BayesVLMVisionModelOutput(
            image_embeds=image_embeds,
            image_embeds_var=image_var,
            image_embeds_std=image_std,
            last_hidden_state=last_hidden_state,
            hidden_states=hidden_states,
            attentions=attentions,
        )
231
+
232
+
233
class BayesVLMModel(CLIPModel):
    """CLIP dual encoder extended with closed-form predictive uncertainty.

    Four zero-initialized buffers hold factorized posterior covariance factors
    for the two projection heads (``{text,image}_a_inv`` over pooled
    activations, ``{text,image}_b_diag`` for the output-side diagonal); they
    are populated via :meth:`set_covariances`.  With the zero placeholders
    every reported variance is zero.
    """

    def __init__(self, config):
        super().__init__(config)
        text_hidden = int(config.text_config.hidden_size)
        vision_hidden = int(config.vision_config.hidden_size)
        proj = int(config.projection_dim)
        # Zero placeholders registered as buffers so they move with the model
        # and serialize; real factors arrive through set_covariances().
        self.register_buffer("text_a_inv", torch.zeros(text_hidden, text_hidden))
        self.register_buffer("text_b_diag", torch.zeros(proj))
        self.register_buffer("image_a_inv", torch.zeros(vision_hidden, vision_hidden))
        self.register_buffer("image_b_diag", torch.zeros(proj))

    def set_covariances(
        self,
        image_a_inv: torch.Tensor,
        image_b_inv: torch.Tensor,
        text_a_inv: torch.Tensor,
        text_b_inv: torch.Tensor,
    ) -> None:
        """Install posterior covariance factors for both towers.

        Only the diagonal of each ``b_inv`` factor is retained.
        """
        self.image_a_inv = image_a_inv
        self.image_b_diag = torch.diagonal(image_b_inv)
        self.text_a_inv = text_a_inv
        self.text_b_diag = torch.diagonal(text_b_inv)

    def _expected_logits_and_var(
        self,
        image_embeds: torch.Tensor,
        text_embeds: torch.Tensor,
        image_acts: torch.Tensor,
        text_acts: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor | None]:
        """Return (mean, variance) of the scaled image-text similarity logits.

        Normalization uses E[||e||^2] = ||mu||^2 + tr(diag-cov) per embedding;
        the variance follows the product rule for independent coordinates with
        diagonal covariance.  Returns ``(logits, None)`` when either ``a_inv``
        buffer is empty (no covariance installed).
        """
        scale = self.logit_scale.exp()

        if self.image_a_inv.numel() == 0 or self.text_a_inv.numel() == 0:
            # Deterministic fallback: ordinary CLIP cosine-similarity logits.
            image_norm = image_embeds / image_embeds.norm(p=2, dim=-1, keepdim=True)
            text_norm = text_embeds / text_embeds.norm(p=2, dim=-1, keepdim=True)
            logits = image_norm @ text_norm.t()
            logits = logits * scale
            return logits, None

        image_diag_cov = _diag_cov(
            image_acts,
            self.image_a_inv,
            self.image_b_diag,
            add_bias=self.visual_projection.bias is not None,
        )
        text_diag_cov = _diag_cov(
            text_acts,
            self.text_a_inv,
            self.text_b_diag,
            add_bias=self.text_projection.bias is not None,
        )

        # Per-coordinate second moments E[e_i^2] and per-embedding E[||e||^2].
        norm_image = image_embeds**2 + image_diag_cov
        norm_text = text_embeds**2 + text_diag_cov
        expect_norm_image = norm_image.sum(dim=-1, keepdim=True)
        expect_norm_text = norm_text.sum(dim=-1, keepdim=True)

        expected_similarity = torch.matmul(
            image_embeds / torch.sqrt(expect_norm_image),
            (text_embeds / torch.sqrt(expect_norm_text)).t(),
        )

        # Var of a dot product of independent coordinates:
        # sum_i E[x_i^2] Var[y_i] + Var[x_i] E[y_i]^2.
        term1 = torch.matmul(norm_image, text_diag_cov.t())
        term2 = torch.matmul(image_diag_cov, (text_embeds**2).t())
        variance_similarity = (term1 + term2) / (expect_norm_image * expect_norm_text.t())

        logits_mean = expected_similarity * scale
        logits_var = variance_similarity * (scale**2)
        return logits_mean, logits_var

    def get_text_features(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        return_std: bool = False,
    ):
        """Return projected text embeddings.

        Returns the bare embedding tensor only when both ``return_dict`` and
        ``return_std`` are falsy; otherwise a :class:`BayesVLMEmbeddingOutput`
        carrying mean, variance, and std.
        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        text_outputs = self.text_model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
        )
        pooled_output = _get_output(text_outputs, "pooler_output", 1)
        text_embeds = self.text_projection(pooled_output)

        text_var = _diag_cov(
            pooled_output,
            self.text_a_inv,
            self.text_b_diag,
            add_bias=self.text_projection.bias is not None,
        )
        if text_var is None:
            # Empty covariance buffers => zero uncertainty.
            text_var = torch.zeros_like(text_embeds)
        text_std = _std_from_var(text_var)

        if not return_dict and not return_std:
            return text_embeds

        return BayesVLMEmbeddingOutput(mean=text_embeds, var=text_var, std=text_std)

    def get_image_features(
        self,
        pixel_values: Optional[torch.FloatTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        return_std: bool = False,
    ):
        """Return projected image embeddings; same return convention as
        :meth:`get_text_features`.
        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        vision_outputs = self.vision_model(
            pixel_values=pixel_values,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
        )
        pooled_output = _get_output(vision_outputs, "pooler_output", 1)
        image_embeds = self.visual_projection(pooled_output)

        image_var = _diag_cov(
            pooled_output,
            self.image_a_inv,
            self.image_b_diag,
            add_bias=self.visual_projection.bias is not None,
        )
        if image_var is None:
            # Empty covariance buffers => zero uncertainty.
            image_var = torch.zeros_like(image_embeds)
        image_std = _std_from_var(image_var)

        if not return_dict and not return_std:
            return image_embeds

        return BayesVLMEmbeddingOutput(mean=image_embeds, var=image_var, std=image_std)

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        pixel_values: Optional[torch.FloatTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        return_loss: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ):
        """Full CLIP forward pass returning a :class:`BayesVLMOutput` with
        expected logits plus their variance/std and normalized embeddings with
        propagated variance.  The tuple path defers entirely to the parent.

        Fix vs. previous revision: the embedding std values were computed once
        before normalization and then unconditionally recomputed afterwards;
        the dead pre-normalization computation is removed.  The no-op
        ``_as_optional_tensor`` wrapper on the logits variance is dropped too.
        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if not return_dict:
            # Tuple output requested: no uncertainty fields, use stock CLIP.
            return super().forward(
                input_ids=input_ids,
                pixel_values=pixel_values,
                attention_mask=attention_mask,
                position_ids=position_ids,
                return_loss=return_loss,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                return_dict=return_dict,
            )

        text_outputs = self.text_model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
        )
        vision_outputs = self.vision_model(
            pixel_values=pixel_values,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
        )

        text_pooled = _get_output(text_outputs, "pooler_output", 1)
        image_pooled = _get_output(vision_outputs, "pooler_output", 1)

        text_embeds = self.text_projection(text_pooled)
        image_embeds = self.visual_projection(image_pooled)

        text_var = _diag_cov(
            text_pooled,
            self.text_a_inv,
            self.text_b_diag,
            add_bias=self.text_projection.bias is not None,
        )
        image_var = _diag_cov(
            image_pooled,
            self.image_a_inv,
            self.image_b_diag,
            add_bias=self.visual_projection.bias is not None,
        )
        if text_var is None:
            text_var = torch.zeros_like(text_embeds)
        if image_var is None:
            image_var = torch.zeros_like(image_embeds)

        # Logits are computed from the *unnormalized* embeddings; the closed
        # form handles normalization via expected squared norms.
        logits_mean, logits_var = self._expected_logits_and_var(
            image_embeds,
            text_embeds,
            image_pooled,
            text_pooled,
        )

        # Normalize the reported embeddings and propagate the variance
        # through the normalization (delta method); std follows from the
        # post-normalization variance.
        text_embeds, text_var = _normalize_mean_and_var(text_embeds, text_var)
        image_embeds, image_var = _normalize_mean_and_var(image_embeds, image_var)
        text_std = _std_from_var(text_var)
        image_std = _std_from_var(image_var)

        logits_per_image = logits_mean
        logits_per_text = logits_mean.t() if logits_mean is not None else None

        if logits_var is None and logits_mean is not None:
            logits_var = torch.zeros_like(logits_mean)
        logits_per_image_var = logits_var
        logits_per_text_var = logits_var.t() if logits_var is not None else None

        logits_per_image_std = _std_from_var(logits_per_image_var)
        logits_per_text_std = _std_from_var(logits_per_text_var)

        loss = None
        if return_loss and logits_per_image is not None and logits_per_text is not None:
            # Symmetric InfoNCE loss over the in-batch similarity matrix.
            labels = torch.arange(logits_per_image.shape[0], device=logits_per_image.device)
            loss_i = torch.nn.functional.cross_entropy(logits_per_image, labels)
            loss_t = torch.nn.functional.cross_entropy(logits_per_text, labels)
            loss = (loss_i + loss_t) / 2

        return BayesVLMOutput(
            loss=loss,
            logits_per_image=logits_per_image,
            logits_per_text=logits_per_text,
            logits_per_image_var=logits_per_image_var,
            logits_per_text_var=logits_per_text_var,
            logits_per_image_std=logits_per_image_std,
            logits_per_text_std=logits_per_text_std,
            text_embeds=text_embeds,
            image_embeds=image_embeds,
            text_embeds_var=text_var,
            image_embeds_var=image_var,
            text_embeds_std=text_std,
            image_embeds_std=image_std,
            text_model_output=text_outputs,
            vision_model_output=vision_outputs,
        )
vision/pytorch_model.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:6c6c33da51b2f2e691b8fb9a69da3525398a68942e523351fad39b19e449230f
3
+ size 354871602
vocab.json ADDED
The diff for this file is too large to render. See raw diff