nobottle committed
Commit 4333430 · Parent: d76ef05

add modeling files

align_transformers.py ADDED
import torch
from torch import nn
from transformers import PreTrainedModel
from transformers.models.dinov2.modeling_dinov2 import Dinov2Encoder

from .configuration import AlignTransformerConfig


def build_align_transformer(config):
    if config.model_type == "align_transformer":
        model = AlignTransformer(config)
    else:
        raise NotImplementedError()

    return model


class AlignTransformer(PreTrainedModel):
    def __init__(self, config: AlignTransformerConfig):
        super().__init__(config)

        self.projector = None

        if config.num_hidden_layers:
            self.transformer_layers = Dinov2Encoder(config)
        else:
            self.transformer_layers = None

        if config.use_layer_norm:
            self.layer_norm = nn.LayerNorm(config.hidden_size)
        else:
            self.layer_norm = None

    def forward(self, vision_tokens):

        if self.projector is not None:

            cls_token = vision_tokens[:, :1]
            patch_tokens = vision_tokens[:, 1:]

            patch_tokens = self.projector(patch_tokens)["last_hidden_state"]
            vision_tokens = torch.cat([cls_token, patch_tokens], dim=1)

        if self.transformer_layers is not None:
            vision_tokens = self.transformer_layers(vision_tokens)["last_hidden_state"]

        if self.layer_norm is not None:
            vision_tokens = self.layer_norm(vision_tokens)

        return vision_tokens
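A minimal usage sketch (not part of the commit) for the module above, assuming align_transformers.py and configuration.py are importable as plain local modules; the hidden size, layer count, and token count below are illustrative, not values taken from this repository:

import torch
from configuration import AlignTransformerConfig
from align_transformers import build_align_transformer

# hypothetical experiment config: a 2-layer DINOv2-style refiner over 768-d vision tokens
config = AlignTransformerConfig.from_exp_config(
    {
        "model_type": "align_transformer",
        "hidden_size": 768,
        "num_hidden_layers": 2,
        "use_layer_norm": True,
    }
)
model = build_align_transformer(config)

tokens = torch.randn(2, 1 + 256, 768)  # [CLS] + 16x16 patch tokens from a ViT-style backbone
refined = model(tokens)                # same shape as the input: (2, 257, 768)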
common_layers.py ADDED
from torch import nn
from transformers.modeling_utils import PreTrainedModel


class BasePreTrainedModel(PreTrainedModel):
    """
    An abstract class to handle weights initialization and
    a simple interface for downloading and loading pretrained models.
    """

    supports_gradient_checkpointing = True

    def _init_weights(self, module):
        """Initialize the weights"""
        if (
            isinstance(module, nn.Conv2d)  # noqa: SIM101
            or isinstance(module, nn.Embedding)
            or isinstance(module, nn.Linear)
        ):
            module.weight.data.normal_(mean=0.0, std=0.02)
            if hasattr(module, "bias") and module.bias is not None:
                module.bias.data.zero_()

        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)
        elif isinstance(module, nn.Parameter):
            raise ValueError()
configuration.py ADDED
from transformers import AutoConfig
from transformers.configuration_utils import PretrainedConfig
from transformers.models.dinov2.configuration_dinov2 import Dinov2Config


class VisionConfig(PretrainedConfig):
    def __init__(
        self,
        **kwargs,
    ):
        super().__init__(**kwargs)

    @staticmethod
    def from_exp_config(vision_config: dict):

        model_type = vision_config["model_type"]

        if model_type in [
            "siglip_vision_model",
            "clip_vision_model",
            "dinov2",
            "sam",
            "raddino",
        ]:
            config = AutoConfig.from_pretrained(
                vision_config["pretrained_name_or_path"]
            )
            config = config.to_dict()
            vision_config.update(config)
        elif model_type == "xrayclip":
            config = AutoConfig.from_pretrained(
                vision_config["pretrained_name_or_path"]
            )
            config = config.to_dict()
            config["model_type"] = "xrayclip"
            vision_config.update(config)
        elif model_type == "biomedclip":
            pass
        elif model_type == "m3ae":
            pass

        else:
            raise NotImplementedError()

        vision_config = VisionConfig(**vision_config)

        return vision_config


class TextConfig(PretrainedConfig):
    def __init__(
        self,
        model_type,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.model_type = model_type

    @staticmethod
    def from_exp_config(
        text_config: dict,
    ):
        model_type = text_config["model_type"]

        if model_type in [
            "siglip_text_model",
            "clip_text_model",
            "mpnet",
            "biomedclip",
            "bioclinicalmpbert",
        ]:
            text_config = TextConfig(**text_config)
        else:
            raise NotImplementedError()

        return text_config


class AlignTransformerConfig(PretrainedConfig):
    def __init__(
        self,
        model_type: str = "align_transformer",
        projector_config=None,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.model_type = model_type
        self.projector_config = projector_config

    @staticmethod
    def from_exp_config(
        align_transformer_config: dict,
    ):
        projector_config = align_transformer_config.pop("projector_config", None)

        config = Dinov2Config(**align_transformer_config)
        config = config.to_dict()

        align_transformer_config = AlignTransformerConfig(
            **(config | align_transformer_config),
            projector_config=projector_config,
        )

        return align_transformer_config


class CxrAlignConfig(PretrainedConfig):
    is_composition = True

    def __init__(
        self,
        vision_config: dict,
        text_config: dict,
        align_transformer_config: dict,
        **kwargs,
    ):
        super().__init__(**kwargs)

        # Vision config
        self.vision_config = VisionConfig.from_exp_config(vision_config)

        # text config
        self.text_config = TextConfig.from_exp_config(text_config)

        self.align_transformer_config = AlignTransformerConfig.from_exp_config(
            align_transformer_config
        )

        self.kwargs = kwargs
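A hedged sketch of how the composite config might be assembled from experiment dicts; each block is routed through its from_exp_config helper. The checkpoint names below are illustrative placeholders, not necessarily the ones this repository was trained with:

from configuration import CxrAlignConfig

config = CxrAlignConfig(
    vision_config={
        "model_type": "dinov2",
        "pretrained_name_or_path": "facebook/dinov2-base",  # placeholder checkpoint
    },
    text_config={
        "model_type": "mpnet",
        "pretrained_name_or_path": "sentence-transformers/all-mpnet-base-v2",  # placeholder
    },
    align_transformer_config={
        "model_type": "align_transformer",
        "hidden_size": 768,
        "num_hidden_layers": 2,
        "use_layer_norm": True,
    },
)
print(config.align_transformer_config.hidden_size)  # 768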
losses.py ADDED
import math

import numpy as np
import torch
import torch.distributed as dist
import torch.distributed.nn  # makes dist.nn.all_gather resolvable; torch.distributed does not import it by default
import torch.nn as nn
import torch.nn.functional as F


class KeyPhraseAlignmentLoss(nn.Module):

    def __init__(
        self,
        hidden_dim=768,
        use_vision_cls_token=True,
        attn_temperature=None,
        loss_temperature=0.07,
        text_features_l2_norm=False,
        mpnce_row_sum=False,
        mpnce_col_sum=False,
        sim_op="cos",
        use_layer_norm=True,
        **kwargs,
    ):
        super().__init__()

        self.hidden_dim = hidden_dim
        self.layer_norm = nn.LayerNorm(hidden_dim) if use_layer_norm else None

        self.use_vision_cls_token = use_vision_cls_token
        self.loss_temperature = nn.Parameter(
            torch.FloatTensor([np.log(loss_temperature)])
        )
        if attn_temperature is not None:
            self.attn_temperature = nn.Parameter(
                torch.FloatTensor([np.log(attn_temperature)])
            )
        else:
            self.attn_temperature = None
        self.text_features_l2_norm = text_features_l2_norm
        self.sim_op = sim_op

        self.similarity_logit = SimilarityLogit(sim_op)

        self.mpnce_row_sum = mpnce_row_sum
        self.mpnce_col_sum = mpnce_col_sum

    def forward(
        self,
        key_phrases,
        vision_tokens,
        forward_text_model,
        ddp_gather=True,
        need_attn_weights=False,
        compute_loss=True,
        **kwargs,
    ):
        outputs = {}

        text_features, group_map = self.compute_text_features(
            key_phrases, forward_text_model, ddp_gather
        )

        if ddp_gather and dist.is_initialized():
            vision_tokens = torch.cat(dist.nn.all_gather(vision_tokens), dim=0)

        if self.layer_norm is not None:
            vision_tokens = self.layer_norm(vision_tokens)

        vision_patch_tokens = vision_tokens[:, 1:]

        # text to image cross-attention
        if not self.use_vision_cls_token:
            vision_attn_tokens = vision_patch_tokens
        else:
            vision_attn_tokens = vision_tokens

        t2i_logits, t2i_attn_weights_list = self.compute_t2i_logits(
            text_features, vision_attn_tokens, need_attn_weights
        )
        outputs["t2i_logits"] = t2i_logits
        outputs["t2i_attn_weights"] = t2i_attn_weights_list

        if compute_loss:
            losses = {}
            loss = 0

            # compute t2i loss
            t2i_loss = multi_positive_nce_loss(
                t2i_logits,
                group_map,
                temperature=self.loss_temperature.exp(),
                row_sum=self.mpnce_row_sum,
                col_sum=self.mpnce_col_sum,
            )
            loss += t2i_loss
            losses["t2i_loss"] = t2i_loss

            losses["loss"] = loss
            outputs["losses"] = losses
        return outputs

    def compute_text_features(self, key_phrases, forward_text_model, ddp_gather=True):

        key_text_features_list = list()
        group_list = list()

        B_local = len(key_phrases)
        # Calculate offset by getting the rank of the current process when using DDP
        local_rank = dist.get_rank() if (ddp_gather and dist.is_initialized()) else 0

        for i, kp in enumerate(key_phrases):
            feats = forward_text_model(kp)

            # (N_i, D)
            if self.text_features_l2_norm:
                feat = feats["text_features"]
            else:
                feat = feats["text_features_wo_l2_norm"]

            if feat.shape[-1] == 2 * self.hidden_dim:
                feat = feat[:, self.hidden_dim :]

            key_text_features_list.append(feat)

            # Add local_rank * B_local offset to local index i
            global_index = i + local_rank * B_local
            group_list.extend([global_index] * feat.size(0))

        text_features = torch.cat(key_text_features_list, dim=0)
        group_map = torch.tensor(group_list, device=text_features.device)

        if ddp_gather and dist.is_initialized():
            # Gather text_features and image_features and group_map
            text_features = pad_and_gather(text_features)

            group_map = pad_and_gather(group_map)
            group_map = group_map.long()

        if self.layer_norm is not None:
            text_features = self.layer_norm(text_features)

        return text_features, group_map

    def compute_t2i_logits(
        self, text_features, vision_attn_tokens, need_attn_weights, repeat=True
    ):

        t2i_logits, t2i_attn_weights_list = self.similarity_logit(
            text_features,
            vision_attn_tokens,
            need_attn_weights,
            repeat=repeat,
            temperature=(
                self.attn_temperature.exp()
                if self.attn_temperature is not None
                else self.loss_temperature.exp()
            ),
        )

        return t2i_logits, t2i_attn_weights_list


class SimilarityLogit(nn.Module):
    def __init__(self, sim_op="dot", **kwargs):
        super().__init__()
        self.sim_op = sim_op

    def forward(
        self,
        queries: torch.Tensor,
        local_tokens: torch.Tensor,
        need_attn_weights: bool = False,
        repeat: bool = True,
        **kwargs,
    ):
        if repeat:
            query_attn_features = queries.unsqueeze(0).expand(
                local_tokens.shape[0], queries.shape[0], queries.shape[1]
            )
        else:
            assert queries.dim() == 3
            query_attn_features = queries

        if self.sim_op == "cos":
            temperature = kwargs.get("temperature")
            assert temperature is not None
            denominator = temperature
            query_attn_features = F.normalize(query_attn_features, p=2, dim=-1)
            local_tokens = F.normalize(local_tokens, p=2, dim=-1)
        elif self.sim_op == "dot":
            denominator = math.sqrt(local_tokens.size(-1))
        else:
            raise NotImplementedError

        scores = (
            torch.bmm(query_attn_features, local_tokens.permute(0, 2, 1)) / denominator
        )
        attn_weights = F.softmax(scores, dim=-1)

        aggregated = torch.matmul(attn_weights, local_tokens)

        query_attn_features = F.normalize(query_attn_features, p=2, dim=-1)
        aggregated = F.normalize(aggregated, p=2, dim=-1)

        logits = torch.matmul(
            query_attn_features.unsqueeze(2), aggregated.unsqueeze(-1)
        ).squeeze()

        logits = logits.T

        if need_attn_weights:
            attn_scores = [scores]
        else:
            attn_scores = None

        return logits, attn_scores


def multi_positive_nce_loss(
    logits: torch.Tensor,
    group_map: torch.Tensor,
    temperature: float = 1.0,
    eps: float = 1e-8,
    row_sum: bool = False,
    col_sum: bool = False,
):
    """
    Args:
        logits: tensor of shape (N_total, B_global), each row is a logit between a key phrase and each candidate image.
        group_map: tensor of shape (N_total,), source image index of each key phrase.
        temperature: scaling factor.

    For each key phrase row i, the positive is the candidate image index == group_map[i],
    and the rest are treated as negatives.

    For each column j, each positive for image j is considered independently.

    Returns:
        loss: scalar tensor.
    """
    scaled_logits = torch.exp(logits / temperature)  # (N_total, B_global)

    pos_logits = scaled_logits[
        torch.arange(scaled_logits.size(0)), group_map
    ]  # (N_total,)

    row_loss = get_row_loss(
        scaled_logits,
        pos_logits,
        group_map,
        eps,
        row_sum,
    )

    neg_mask = torch.ones_like(scaled_logits)
    neg_mask[torch.arange(scaled_logits.size(0)), group_map] = 0  # (N_total, B_global)

    column_loss = get_col_loss(
        scaled_logits,
        pos_logits,
        neg_mask,
        group_map,
        eps,
        col_sum,
    )

    loss = (row_loss.mean() + column_loss.mean()) / 2

    return loss


def get_row_loss(
    logits: torch.Tensor,
    pos_logits: torch.Tensor,
    group_map: torch.Tensor,
    eps: float = 1e-8,
    row_sum: bool = False,
):
    if row_sum:
        # Create a tensor to hold the summed values
        row_sum_logits = torch.zeros(
            logits.shape[-1], device=logits.device
        )  # (B_global)
        row_pos_sum_logits = torch.zeros(
            logits.shape[-1], device=logits.device
        )  # (B_global)

        # Use scatter_add to sum values based on group_map
        row_sum_logits.scatter_add_(0, group_map, logits.sum(dim=1))  # (B_global)
        row_pos_sum_logits.scatter_add_(0, group_map, pos_logits)  # (B_global)
        p_row = row_pos_sum_logits / (row_sum_logits + eps)  # (B_global)
    else:
        row_sum_logits = logits.sum(dim=1)  # (N_total)
        p_row = pos_logits / (row_sum_logits + eps)  # (N_total)

    return -torch.log(p_row + eps)


def get_col_loss(
    logits: torch.Tensor,
    pos_logits: torch.Tensor,
    neg_mask: torch.Tensor,
    group_map: torch.Tensor,
    eps: float = 1e-8,
    col_sum: bool = False,
):
    if col_sum:
        # MIL-NCE loss
        column_sum_logits = logits.sum(dim=0)  # (B_global,)
        pos_mask = torch.ones_like(logits) - neg_mask  # (N_total, B_global)
        column_pos_logits = (logits * pos_mask).sum(dim=0)  # (B_global,)
        p_column = column_pos_logits / (column_sum_logits + eps)  # (B_global,)
    else:
        # MP-NCE loss (UniCLIP)
        neg_logits = logits * neg_mask  # (N_total, B_global)
        sum_neg_logits = neg_logits.sum(dim=0)  # (B_global,)
        sum_neg_logits = sum_neg_logits[group_map]  # (N_total)
        p_column = pos_logits / (pos_logits + sum_neg_logits + eps)  # (N_total)

    return -torch.log(p_column + eps)


def pad_and_gather(tensor):
    # Determine the size of the tensor
    local_size = torch.tensor(tensor.size(), device=tensor.device)

    # Gather all sizes
    all_sizes = [torch.zeros_like(local_size) for _ in range(dist.get_world_size())]
    dist.all_gather(all_sizes, local_size)

    # Determine the maximum size
    max_size = torch.stack(all_sizes).max(dim=0)[0]

    # Pad the tensor to the maximum size
    padded_tensor = torch.zeros(max_size.tolist(), device=tensor.device)
    padded_tensor[: local_size[0]] = tensor

    # Gather all padded tensors
    gathered_tensors = dist.nn.all_gather(padded_tensor)

    # Trim the gathered tensors to their original sizes
    gathered_tensors = [g[: s[0]] for g, s in zip(gathered_tensors, all_sizes)]

    gathered_tensors = torch.cat(gathered_tensors, dim=0)

    return gathered_tensors
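A small self-contained sketch of multi_positive_nce_loss using toy logits for five key phrases drawn from two images; the import path assumes losses.py is on the Python path, and no distributed setup is required:

import torch
from losses import multi_positive_nce_loss  # assumes losses.py is importable as a module

# phrase-to-image similarity logits: phrases 0-2 come from image 0, phrases 3-4 from image 1
logits = torch.tensor([
    [0.9, 0.1],
    [0.8, 0.2],
    [0.7, 0.3],
    [0.2, 0.6],
    [0.1, 0.9],
])
group_map = torch.tensor([0, 0, 0, 1, 1])  # source image index of each phrase

loss = multi_positive_nce_loss(logits, group_map, temperature=0.07)
print(loss)  # scalar; shrinks as each positive dominates its row and its column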
radzero_modeling.py ADDED
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import AutoTokenizer, BertModel
from transformers.models.clip.modeling_clip import CLIPTextModel
from transformers.models.mpnet.modeling_mpnet import MPNetModel
from transformers.trainer import logger

from .align_transformers import build_align_transformer
from .common_layers import BasePreTrainedModel
from .configuration import CxrAlignConfig
from .losses import KeyPhraseAlignmentLoss
from .text_encoders import aggregate_tokens, build_text_encoder
from .vision_encoders import MRM, Dinov2Model, build_vision_encoder


class CxrAlignModel(BasePreTrainedModel):

    config_class = CxrAlignConfig

    def build_vision_model(self, config: CxrAlignConfig):
        vision_config = config.vision_config
        vision_config.pretrained_dir = config.pretrained_dir
        vision_model = build_vision_encoder(vision_config)
        return vision_model

    def build_text_model(self, config: CxrAlignConfig):
        text_config = config.text_config
        text_model = build_text_encoder(text_config)

        if text_config.model_type == "bioclinicalmpbert":
            self.tokenizer = AutoTokenizer.from_pretrained(
                text_config.pretrained_tokenizer_name_or_path
            )
            self.idxtoword = {v: k for k, v in self.tokenizer.get_vocab().items()}

        return text_model

    def build_align_transformer_model(self, config: CxrAlignConfig):
        align_transformer_config = config.align_transformer_config
        align_transformer = build_align_transformer(align_transformer_config)

        return align_transformer

    def __init__(self, config: CxrAlignConfig):
        super().__init__(config)

        logger.info("Build vision model ...")
        self.vision_model = self.build_vision_model(config)

        logger.info("Build text model ...")
        self.text_model = self.build_text_model(config)

        if (
            isinstance(self.text_model, CLIPTextModel)
            or isinstance(self.text_model, MPNetModel)
            or isinstance(self.text_model, BertModel)
        ):
            text_dim = self.text_model.config.hidden_size

        self.hidden_size = config.align_transformer_config.hidden_size

        if config.text_config.use_text_projection:
            self.text_projector = nn.Linear(text_dim, 2 * self.hidden_size)
        else:
            self.text_projector = None

        logger.info("Build align transformer model ...")
        self.align_transformer = self.build_align_transformer_model(config)

        logger.info("Build loss functions ...")
        loss_cfg = config.kwargs["loss"]
        self.loss_ratio = dict()
        self.loss_fns = nn.ModuleDict()
        for loss_type, ratio in zip(loss_cfg["apply"], loss_cfg["ratio"]):
            logger.info(f"Build {loss_type} loss function ...")
            if loss_cfg[loss_type] is None:
                loss_cfg[loss_type] = dict()
            if torch.distributed.is_available() and torch.distributed.is_initialized():
                loss_cfg[loss_type]["rank"] = torch.distributed.get_rank()
                loss_cfg[loss_type]["world_size"] = torch.distributed.get_world_size()
            self.loss_fns[loss_type] = eval(loss_type)(**loss_cfg[loss_type])
            self.loss_ratio[loss_type] = ratio

        self.compute_logits_type = config.kwargs.get("compute_logits_type")
        self.use_negative_logits = config.kwargs.get("use_negative_logits")

        self.module_to_update = config.kwargs.get("module_to_update")

        self.post_init()

    def forward_vision_model(self, pixel_values):

        if isinstance(self.vision_model, Dinov2Model):
            vision_tokens = self.vision_model(pixel_values)["last_hidden_state"]
        elif isinstance(self.vision_model, MRM):
            img_emb_g, img_emb_l = self.vision_model(pixel_values)
            img_emb_g = img_emb_g.unsqueeze(1)
            img_emb_l = img_emb_l.view(img_emb_l.size(0), img_emb_l.size(1), -1)
            img_emb_l = img_emb_l.permute(0, 2, 1)

            vision_tokens = torch.cat([img_emb_g, img_emb_l], dim=1)
        else:
            raise NotImplementedError

        vision_tokens = self.align_transformer(vision_tokens)

        cls_token = vision_tokens[:, 0]
        patch_tokens = vision_tokens[:, 1:]
        image_features = torch.cat([cls_token, patch_tokens.mean(dim=1)], dim=1)
        image_features = F.normalize(image_features, p=2, dim=1)

        outputs = {}
        outputs["vision_tokens"] = vision_tokens
        outputs["image_cls_token"] = cls_token
        outputs["image_patch_tokens"] = patch_tokens
        outputs["image_features"] = image_features

        return outputs

    def forward_text_model(self, encoded_input):
        text_outputs = {}

        if isinstance(self.text_model, MPNetModel):
            model_output = self.text_model(
                input_ids=encoded_input["input_ids"],
                attention_mask=encoded_input["attention_mask"],
            )

            token_embeddings = model_output[
                0
            ]  # First element of model_output contains all token embeddings

            # text embedding projection
            if self.text_projector is not None:
                token_embeddings = self.text_projector(token_embeddings)

            if self.config.text_config.use_cls_token:
                text_features = token_embeddings[:, 0, :]

            else:
                # mean pooling
                input_mask_expanded = (
                    encoded_input["attention_mask"]
                    .unsqueeze(-1)
                    .expand(token_embeddings.size())
                    .float()
                )
                text_features = torch.sum(
                    token_embeddings * input_mask_expanded, 1
                ) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)

        elif isinstance(self.text_model, BertModel):
            # BioClinicalMPBERT

            model_output = self.text_model(
                input_ids=encoded_input["input_ids"],
                attention_mask=encoded_input["attention_mask"],
                token_type_ids=encoded_input.get("token_type_ids", None),
            )

            if self.config.text_config.use_cls_token:
                text_features = model_output.last_hidden_state[:, 0, :]

            elif self.config.text_config.use_aggregate_tokens:

                all_embeddings = model_output[2]
                embeddings = torch.stack(
                    all_embeddings[-self.config.text_config.last_n_layers :]
                )
                embeddings = embeddings.permute(1, 0, 2, 3)

                embeddings, sents = aggregate_tokens(
                    embeddings, encoded_input["input_ids"], self.idxtoword
                )
                sent_embeddings = embeddings.mean(axis=2)

                if self.config.text_config.aggregate_method == "sum":
                    word_embeddings = embeddings.sum(axis=1)
                    sent_embeddings = sent_embeddings.sum(axis=1)
                elif self.config.text_config.aggregate_method == "mean":
                    word_embeddings = embeddings.mean(axis=1)
                    sent_embeddings = sent_embeddings.mean(axis=1)

                word_embeddings = word_embeddings.permute(0, 2, 1)

                text_features = sent_embeddings
                text_outputs["word_embeddings"] = word_embeddings

            else:
                text_features = model_output.last_hidden_state
                mask = encoded_input["attention_mask"].unsqueeze(-1).float()
                text_features = torch.sum(text_features * mask, dim=1) / torch.clamp(
                    mask.sum(dim=1), min=1e-9
                )

            if self.text_projector is not None:
                text_features = self.text_projector(text_features)

        else:
            raise NotImplementedError

        text_outputs["text_features_wo_l2_norm"] = text_features
        text_outputs["text_features"] = F.normalize(text_features, p=2, dim=1)

        return text_outputs

    def forward(
        self,
        pixel_values,
        encoded_key_phrases=None,
        return_loss=True,
        **kwargs,
    ):
        vision_outputs = self.forward_vision_model(pixel_values)

        outputs = {}
        outputs.update(vision_outputs)

        # Trainer's self.can_return_loss is True if 'return_loss' is in model's forward function
        if return_loss:
            loss = 0
            losses = {}

            for loss_type, loss_fn in self.loss_fns.items():
                if isinstance(loss_fn, KeyPhraseAlignmentLoss):
                    loss_outputs = loss_fn(
                        encoded_key_phrases,
                        outputs["vision_tokens"],
                        self.forward_text_model,
                    )
                    key_phrase_alignment_losses = loss_outputs["losses"]
                    losses["key_phrase_alignment_loss"] = (
                        key_phrase_alignment_losses.pop("loss")
                    )
                    for loss_name, loss_value in key_phrase_alignment_losses.items():
                        losses[loss_name] = loss_value
                    loop_loss = losses["key_phrase_alignment_loss"]
                else:
                    raise NotImplementedError

                loss += loop_loss * self.loss_ratio[loss_type]

            losses["loss"] = loss

            outputs["losses"] = losses

        return outputs

    def compute_logits(
        self,
        pixel_values,
        encoded_key_phrases,
        **kwargs,
    ):
        vision_outputs = self.forward_vision_model(pixel_values)

        outputs = {}

        if self.compute_logits_type == "key_phrase_alignment":

            splited_key_phrases = [
                {
                    "input_ids": encoded_key_phrases[0]["input_ids"][i : i + 1],
                    "attention_mask": encoded_key_phrases[0]["attention_mask"][
                        i : i + 1
                    ],
                }
                for i in range(encoded_key_phrases[0]["input_ids"].size(0))
            ]

            loss_outputs = self.loss_fns["KeyPhraseAlignmentLoss"](
                splited_key_phrases,
                vision_outputs["vision_tokens"],
                self.forward_text_model,
                ddp_gather=False,
                need_attn_weights=True,
                compute_loss=False,
            )
            outputs.update(loss_outputs)

            # mean attention weights from all layers
            outputs["similarity_scores"] = torch.mean(
                torch.stack(loss_outputs["t2i_attn_weights"]), dim=0
            )

            # remove attention score for cls token
            if self.loss_fns["KeyPhraseAlignmentLoss"].use_vision_cls_token:
                outputs["similarity_scores"] = outputs["similarity_scores"][:, :, 1:]

            # compute logits
            logits = loss_outputs["t2i_logits"]
            logits = logits.T

            logits = (
                logits / self.loss_fns["KeyPhraseAlignmentLoss"].loss_temperature.exp()
            )

            outputs["logits"] = logits
        return outputs
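A hedged zero-shot scoring sketch for compute_logits. It assumes `model` is an already-loaded CxrAlignModel whose config enables a KeyPhraseAlignmentLoss head with compute_logits_type set to "key_phrase_alignment", and `tokenizer` matches its text encoder (for instance an MPNet tokenizer); the phrases, image size, and batch size below are placeholders:

import torch

phrases = ["cardiomegaly", "pleural effusion", "no acute findings"]
encoded = tokenizer(phrases, padding=True, return_tensors="pt")  # assumed matching tokenizer
pixel_values = torch.randn(2, 3, 518, 518)                       # placeholder image batch

with torch.no_grad():
    out = model.compute_logits(pixel_values, [encoded])

print(out["logits"].shape)             # (2, 3): image-to-phrase similarity logits
print(out["similarity_scores"].shape)  # (2, 3, num_patches): per-phrase patch attention maps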
text_encoders.py ADDED
import open_clip
import torch
from transformers import AutoModel
from transformers.models.clip.modeling_clip import CLIPTextModel
from transformers.models.siglip.modeling_siglip import SiglipTextModel


def build_text_encoder(config):
    if config.model_type == "mpnet":
        model = AutoModel.from_pretrained(config.pretrained_name_or_path)
    else:
        raise NotImplementedError()

    return model


# Mean Pooling - Take attention mask into account for correct averaging
def mean_pooling(model_output, attention_mask):
    token_embeddings = model_output[
        0
    ]  # First element of model_output contains all token embeddings
    input_mask_expanded = (
        attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
    )
    return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(
        input_mask_expanded.sum(1), min=1e-9
    )
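For reference, a minimal sketch of the mean_pooling helper above with an MPNet encoder; the checkpoint name is an assumed sentence-transformers model, not necessarily the text encoder used in this repository:

import torch
from transformers import AutoModel, AutoTokenizer

name = "sentence-transformers/all-mpnet-base-v2"  # assumed MPNet checkpoint
tokenizer = AutoTokenizer.from_pretrained(name)
model = AutoModel.from_pretrained(name)

enc = tokenizer(["pleural effusion", "cardiomegaly"], padding=True, return_tensors="pt")
with torch.no_grad():
    out = model(**enc)

embeddings = mean_pooling(out, enc["attention_mask"])  # (2, 768) sentence embeddings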
vision_encoders.py ADDED
from transformers import Dinov2Model


def build_vision_encoder(config):
    if config.model_type == "dinov2":
        model = Dinov2Model.from_pretrained(config.pretrained_name_or_path)

    else:
        raise NotImplementedError()

    return model