import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import BertModel
from transformers.models.clip.modeling_clip import CLIPTextModel
from transformers.models.mpnet.modeling_mpnet import MPNetModel
from transformers.trainer import logger

from .align_transformers import build_align_transformer
from .common_layers import BasePreTrainedModel
from .configuration_radzero import CxrAlignConfig
from .losses import KeyPhraseAlignmentLoss
from .text_encoders import build_text_encoder
from .vision_encoders import Dinov2Model, build_vision_encoder


class CxrAlignModel(BasePreTrainedModel):
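    """Image-text alignment model for chest X-ray key-phrase grounding.

    Pairs a vision encoder (currently DINOv2) and an alignment transformer
    with a text encoder (currently MPNet at inference time), trained with the
    key-phrase alignment losses configured in ``config.kwargs["loss"]``.
    """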

    config_class = CxrAlignConfig

    def build_vision_model(self, config: CxrAlignConfig):
        vision_config = config.vision_config
        vision_config.pretrained_dir = config.pretrained_dir
        vision_model = build_vision_encoder(vision_config)
        return vision_model

    def build_text_model(self, config: CxrAlignConfig):
        text_config = config.text_config
        text_model = build_text_encoder(text_config)
        return text_model

    def build_align_transformer_model(self, config: CxrAlignConfig):
        align_transformer_config = config.align_transformer_config
        align_transformer = build_align_transformer(align_transformer_config)

        return align_transformer

    def __init__(self, config: CxrAlignConfig):
        super().__init__(config)

        logger.info("Build vision model ...")
        self.vision_model = self.build_vision_model(config)

        logger.info("Build text model ...")
        self.text_model = self.build_text_model(config)

        if isinstance(self.text_model, (CLIPTextModel, MPNetModel, BertModel)):
            text_dim = self.text_model.config.hidden_size
        else:
            raise NotImplementedError(
                f"Unsupported text model type: {type(self.text_model).__name__}"
            )

        self.hidden_size = config.align_transformer_config.hidden_size

        if config.text_config.use_text_projection:
            self.text_projector = nn.Linear(text_dim, 2 * self.hidden_size)
        else:
            self.text_projector = None

        logger.info("Build align transformer model ...")
        self.align_transformer = self.build_align_transformer_model(config)

        logger.info("Build loss functions ...")
        loss_cfg = config.kwargs["loss"]
        self.loss_ratio = dict()
        self.loss_fns = nn.ModuleDict()
        for loss_type, ratio in zip(loss_cfg["apply"], loss_cfg["ratio"]):
            logger.info(f"Build {loss_type} loss function ...")
            if loss_cfg[loss_type] is None:
                loss_cfg[loss_type] = dict()
            if torch.distributed.is_available() and torch.distributed.is_initialized():
                loss_cfg[loss_type]["rank"] = torch.distributed.get_rank()
                loss_cfg[loss_type]["world_size"] = torch.distributed.get_world_size()
            loss_cls = globals()[loss_type]
            self.loss_fns[loss_type] = loss_cls(**loss_cfg[loss_type])
            self.loss_ratio[loss_type] = ratio

        self.compute_logits_type = config.kwargs.get("compute_logits_type")
        self.use_negative_logits = config.kwargs.get("use_negative_logits")

        self.module_to_update = config.kwargs.get("module_to_update")

    def forward_vision_model(self, pixel_values):
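        """Encode images into align-transformer vision tokens and a global image feature."""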

        if isinstance(self.vision_model, Dinov2Model):
            vision_tokens = self.vision_model(pixel_values)["last_hidden_state"]

        else:
            raise NotImplementedError

        vision_tokens = self.align_transformer(vision_tokens)

        cls_token = vision_tokens[:, 0]
        patch_tokens = vision_tokens[:, 1:]
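        # global image feature: CLS token concatenated with the mean-pooled
        # patch tokens, L2-normalized below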
        image_features = torch.cat([cls_token, patch_tokens.mean(dim=1)], dim=1)
        image_features = F.normalize(image_features, p=2, dim=1)

        outputs = {}
        outputs["vision_tokens"] = vision_tokens
        outputs["image_cls_token"] = cls_token
        outputs["image_patch_tokens"] = patch_tokens
        outputs["image_features"] = image_features

        return outputs

    def forward_text_model(self, encoded_input):
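        """Encode tokenized text into pooled, L2-normalized text features."""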
        text_outputs = {}

        if isinstance(self.text_model, MPNetModel):
            model_output = self.text_model(
                input_ids=encoded_input["input_ids"],
                attention_mask=encoded_input["attention_mask"],
            )

            token_embeddings = model_output[
                0
            ]  # First element of model_output contains all token embeddings

            # text embedding projection
            if self.text_projector is not None:
                token_embeddings = self.text_projector(token_embeddings)

            if self.config.text_config.use_cls_token:
                text_features = token_embeddings[:, 0, :]
            else:
                # mean pooling: average token embeddings, weighted by the
                # attention mask so padding tokens are ignored
                input_mask_expanded = (
                    encoded_input["attention_mask"]
                    .unsqueeze(-1)
                    .expand(token_embeddings.size())
                    .float()
                )
                text_features = torch.sum(
                    token_embeddings * input_mask_expanded, 1
                ) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)

        else:
            raise NotImplementedError

        text_outputs["text_features_wo_l2_norm"] = text_features
        text_outputs["text_features"] = F.normalize(text_features, p=2, dim=1)

        return text_outputs

    def forward(
        self,
        pixel_values,
        encoded_key_phrases=None,
        return_loss=True,
        **kwargs,
    ):
        vision_outputs = self.forward_vision_model(pixel_values)

        outputs = {}
        outputs.update(vision_outputs)

        # transformers.Trainer sets self.can_return_loss to True when
        # 'return_loss' appears in the model's forward signature
        if return_loss:
            loss = 0
            losses = {}

            for loss_type, loss_fn in self.loss_fns.items():
                if isinstance(loss_fn, KeyPhraseAlignmentLoss):
                    loss_outputs = loss_fn(
                        encoded_key_phrases,
                        outputs["vision_tokens"],
                        self.forward_text_model,
                    )
                    key_phrase_alignment_losses = loss_outputs["losses"]
                    losses["key_phrase_alignment_loss"] = (
                        key_phrase_alignment_losses.pop("loss")
                    )
                    for loss_name, loss_value in key_phrase_alignment_losses.items():
                        losses[loss_name] = loss_value
                    loop_loss = losses["key_phrase_alignment_loss"]
                else:
                    raise NotImplementedError

                loss += loop_loss * self.loss_ratio[loss_type]

            losses["loss"] = loss

            outputs["losses"] = losses

        return outputs

    def compute_logits(
        self,
        pixel_values,
        encoded_key_phrases,
        **kwargs,
    ):
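        """Score images against key phrases at inference, returning logits and similarity maps."""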
        vision_outputs = self.forward_vision_model(pixel_values)

        outputs = {}

        if self.compute_logits_type == "key_phrase_alignment":

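            # split the first batch of encoded key phrases into per-phrase
            # inputs so each phrase is scored against the images independently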
            split_key_phrases = [
                {
                    "input_ids": encoded_key_phrases[0]["input_ids"][i : i + 1],
                    "attention_mask": encoded_key_phrases[0]["attention_mask"][
                        i : i + 1
                    ],
                }
                for i in range(encoded_key_phrases[0]["input_ids"].size(0))
            ]

            loss_outputs = self.loss_fns["KeyPhraseAlignmentLoss"](
                split_key_phrases,
                vision_outputs["vision_tokens"],
                self.forward_text_model,
                ddp_gather=False,
                need_attn_weights=True,
                compute_loss=False,
            )
            outputs.update(loss_outputs)

            # average the text-to-image attention weights across all layers
            outputs["similarity_scores"] = torch.mean(
                torch.stack(loss_outputs["t2i_attn_weights"]), dim=0
            )

            # drop the attention score on the vision CLS token so only patch positions remain
            if self.loss_fns["KeyPhraseAlignmentLoss"].use_vision_cls_token:
                outputs["similarity_scores"] = outputs["similarity_scores"][:, :, 1:]

            # transpose text-to-image logits and scale by the learned temperature
            logits = loss_outputs["t2i_logits"]
            logits = logits.T

            logits = (
                logits / self.loss_fns["KeyPhraseAlignmentLoss"].loss_temperature.exp()
            )
        else:
            raise NotImplementedError(
                f"Unsupported compute_logits_type: {self.compute_logits_type}"
            )

        outputs["logits"] = logits
        return outputs
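

# ---------------------------------------------------------------------------
# Usage sketch (not part of the model definition). A minimal, hedged example
# of how a checkpoint exported with this custom code might be loaded and
# queried; the repo id "radzero-checkpoint", the tokenizer pairing, and the
# exact layout of `encoded_key_phrases` are assumptions, not confirmed here.
#
#   from transformers import AutoModel, AutoTokenizer
#
#   model = AutoModel.from_pretrained("radzero-checkpoint", trust_remote_code=True)
#   tokenizer = AutoTokenizer.from_pretrained("radzero-checkpoint")
#
#   phrases = ["pleural effusion", "cardiomegaly"]
#   encoded = tokenizer(phrases, padding=True, return_tensors="pt")
#
#   # compute_logits expects a list of encoded phrase batches; it splits the
#   # first entry into per-phrase inputs (see compute_logits above).
#   outputs = model.compute_logits(
#       pixel_values=pixel_values,        # image batch, e.g. (B, 3, H, W)
#       encoded_key_phrases=[encoded],
#   )
#   image_to_phrase_logits = outputs["logits"]
# ---------------------------------------------------------------------------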