Safetensors
custom_code
File size: 10,706 Bytes
4333430
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
import math

import numpy as np
import torch
import torch.distributed as dist
import torch.nn as nn
import torch.nn.functional as F


class KeyPhraseAlignmentLoss(nn.Module):
    """Contrastive loss aligning key-phrase text features with vision tokens.

    Every key phrase is scored against every image through
    ``SimilarityLogit`` and the resulting ``(phrases, images)`` logit matrix
    is fed to ``multi_positive_nce_loss``, where the positives of a phrase
    row are given by the index of the image it belongs to.

    Args:
        hidden_dim: Feature width; also the size of the optional LayerNorm.
        use_vision_cls_token: If False, the first vision token is dropped
            before computing similarities.
        attn_temperature: Optional initial temperature handed to the
            similarity module. When None the loss temperature is used there
            as well.
        loss_temperature: Initial temperature of the NCE loss.
        text_features_l2_norm: Selects ``"text_features"`` (True) or
            ``"text_features_wo_l2_norm"`` (False) from the text model output.
        mpnce_row_sum: Forwarded to the loss as ``row_sum``.
        mpnce_col_sum: Forwarded to the loss as ``col_sum``.
        sim_op: Similarity operator name passed to ``SimilarityLogit``.
        use_layer_norm: If True, one shared LayerNorm is applied to both the
            text features and the vision tokens.
        **kwargs: Accepted and ignored.
    """

    def __init__(
        self,
        hidden_dim=768,
        use_vision_cls_token=True,
        attn_temperature=None,
        loss_temperature=0.07,
        text_features_l2_norm=False,
        mpnce_row_sum=False,
        mpnce_col_sum=False,
        sim_op="cos",
        use_layer_norm=True,
        **kwargs,
    ):
        super().__init__()

        # Plain (non-learnable) configuration.
        self.hidden_dim = hidden_dim
        self.use_vision_cls_token = use_vision_cls_token
        self.text_features_l2_norm = text_features_l2_norm
        self.sim_op = sim_op
        self.mpnce_row_sum = mpnce_row_sum
        self.mpnce_col_sum = mpnce_col_sum

        # Learnable pieces. Registration order (layer norm, loss temperature,
        # attention temperature, similarity module) is kept stable on purpose.
        if use_layer_norm:
            self.layer_norm = nn.LayerNorm(hidden_dim)
        else:
            self.layer_norm = None

        # Temperatures are stored in log space and exponentiated on use.
        self.loss_temperature = nn.Parameter(
            torch.FloatTensor([np.log(loss_temperature)])
        )
        self.attn_temperature = (
            None
            if attn_temperature is None
            else nn.Parameter(torch.FloatTensor([np.log(attn_temperature)]))
        )

        self.similarity_logit = SimilarityLogit(sim_op)

    def forward(
        self,
        key_phrases,
        vision_tokens,
        forward_text_model,
        ddp_gather=True,
        need_attn_weights=False,
        compute_loss=True,
        **kwargs,
    ):
        """Compute text-to-image logits and, optionally, the alignment loss.

        Args:
            key_phrases: One entry per local image; each entry is passed
                as-is to ``forward_text_model``.
            vision_tokens: Vision token tensor indexed as
                ``(image, token, feature)``; token 0 is treated as the CLS
                token.
            forward_text_model: Callable mapping one ``key_phrases`` entry to
                a dict of text features.
            ddp_gather: Gather features from all ranks when a process group
                is initialised.
            need_attn_weights: Forwarded to the similarity module.
            compute_loss: If False, only logits/attention are returned.
            **kwargs: Accepted and ignored.

        Returns:
            Dict with ``"t2i_logits"``, ``"t2i_attn_weights"`` and, when
            ``compute_loss`` is True, ``"losses"`` (holding ``"t2i_loss"``
            and ``"loss"``).
        """
        text_features, group_map = self.compute_text_features(
            key_phrases, forward_text_model, ddp_gather
        )

        if ddp_gather and dist.is_initialized():
            gathered_vision = dist.nn.all_gather(vision_tokens)
            vision_tokens = torch.cat(gathered_vision, dim=0)

        if self.layer_norm is not None:
            vision_tokens = self.layer_norm(vision_tokens)

        # Text-to-image cross-attention runs over all tokens, or over the
        # patch tokens only when the CLS token is excluded.
        patch_tokens = vision_tokens[:, 1:]
        attended_tokens = vision_tokens if self.use_vision_cls_token else patch_tokens

        t2i_logits, t2i_attn_weights_list = self.compute_t2i_logits(
            text_features, attended_tokens, need_attn_weights
        )
        outputs = {
            "t2i_logits": t2i_logits,
            "t2i_attn_weights": t2i_attn_weights_list,
        }
        if not compute_loss:
            return outputs

        t2i_loss = multi_positive_nce_loss(
            t2i_logits,
            group_map,
            temperature=self.loss_temperature.exp(),
            row_sum=self.mpnce_row_sum,
            col_sum=self.mpnce_col_sum,
        )
        # The total is an accumulator so further terms can be added later.
        total_loss = 0 + t2i_loss
        outputs["losses"] = {"t2i_loss": t2i_loss, "loss": total_loss}
        return outputs

    def compute_text_features(self, key_phrases, forward_text_model, ddp_gather=True):
        """Encode all key phrases and record which image each one belongs to.

        Args:
            key_phrases: One entry per local image, each forwarded to
                ``forward_text_model``.
            forward_text_model: Callable returning a dict containing
                ``"text_features"`` and/or ``"text_features_wo_l2_norm"``
                (one feature row per phrase).
            ddp_gather: Gather features and group ids from all ranks when a
                process group is initialised.

        Returns:
            ``(text_features, group_map)``: the concatenated phrase features
            (layer-normalised if configured) and, per phrase, the global
            index of its source image.
        """
        use_gather = ddp_gather and dist.is_initialized()
        num_local_images = len(key_phrases)
        # Global image index = local index + rank * local batch size; this
        # presumes every rank holds the same number of images.
        rank = dist.get_rank() if use_gather else 0
        index_offset = rank * num_local_images

        feature_key = (
            "text_features"
            if self.text_features_l2_norm
            else "text_features_wo_l2_norm"
        )

        per_image_features = []
        phrase_to_image = []
        for local_idx, phrases in enumerate(key_phrases):
            phrase_features = forward_text_model(phrases)[feature_key]  # (N_i, D)

            # Features twice as wide as hidden_dim keep only their second half.
            if phrase_features.shape[-1] == 2 * self.hidden_dim:
                phrase_features = phrase_features[:, self.hidden_dim :]

            per_image_features.append(phrase_features)
            phrase_to_image += [local_idx + index_offset] * phrase_features.size(0)

        text_features = torch.cat(per_image_features, dim=0)
        group_map = torch.tensor(phrase_to_image, device=text_features.device)

        if use_gather:
            # Ranks may hold different numbers of phrases, hence the padded
            # gather for both the features and their group ids.
            text_features = pad_and_gather(text_features)
            group_map = pad_and_gather(group_map).long()

        if self.layer_norm is not None:
            text_features = self.layer_norm(text_features)

        return text_features, group_map

    def compute_t2i_logits(
        self, text_features, vision_attn_tokens, need_attn_weights, repeat=True
    ):
        """Run the similarity module with the configured temperature.

        The attention temperature is used when it was configured; otherwise
        the loss temperature is reused.

        Returns:
            ``(t2i_logits, t2i_attn_weights_list)`` exactly as produced by
            the similarity module.
        """
        if self.attn_temperature is not None:
            log_temperature = self.attn_temperature
        else:
            log_temperature = self.loss_temperature

        t2i_logits, t2i_attn_weights_list = self.similarity_logit(
            text_features,
            vision_attn_tokens,
            need_attn_weights,
            repeat=repeat,
            temperature=log_temperature.exp(),
        )
        return t2i_logits, t2i_attn_weights_list


class SimilarityLogit(nn.Module):
    """Score queries against sets of local tokens via attention pooling.

    For every (token set, query) pair the query attends over the local
    tokens, the tokens are averaged with the resulting softmax weights, and
    the logit is the cosine similarity between the query and that pooled
    token.

    Args:
        sim_op: ``"cos"`` (cosine attention scores divided by a temperature
            supplied at call time) or ``"dot"`` (dot-product scores divided
            by ``sqrt(D)``).
        **kwargs: Accepted and ignored.
    """

    def __init__(self, sim_op="dot", **kwargs):
        super().__init__()
        self.sim_op = sim_op

    def forward(
        self,
        queries: torch.Tensor,
        local_tokens: torch.Tensor,
        need_attn_weights: bool = False,
        repeat: bool = True,
        **kwargs,
    ):
        """Compute query-to-token-set logits.

        Args:
            queries: ``(N, D)`` when ``repeat`` is True (the same queries are
                scored against every token set), otherwise ``(B, N, D)``.
            local_tokens: ``(B, T, D)`` token sets.
            need_attn_weights: If True, also return the attention scores.
            repeat: Whether ``queries`` must be broadcast over the ``B`` axis.
            **kwargs: Must contain ``temperature`` when ``sim_op == "cos"``.

        Returns:
            ``(logits, attn_scores)`` where ``logits`` is always 2-D with
            shape ``(N, B)`` and ``attn_scores`` is either None or a
            one-element list holding the pre-softmax ``(B, N, T)`` scores.

        Raises:
            NotImplementedError: If ``sim_op`` is neither ``"cos"`` nor
                ``"dot"``.
        """
        if repeat:
            # Broadcast the shared queries over the token-set axis (a view).
            query_attn_features = queries.unsqueeze(0).expand(
                local_tokens.shape[0], queries.shape[0], queries.shape[1]
            )
        else:
            assert queries.dim() == 3
            query_attn_features = queries

        if self.sim_op == "cos":
            temperature = kwargs.get("temperature")
            assert temperature is not None
            denominator = temperature
            query_attn_features = F.normalize(query_attn_features, p=2, dim=-1)
            local_tokens = F.normalize(local_tokens, p=2, dim=-1)
        elif self.sim_op == "dot":
            denominator = math.sqrt(local_tokens.size(-1))
        else:
            raise NotImplementedError(f"Unsupported sim_op: {self.sim_op!r}")

        # (B, N, D) x (B, D, T) -> (B, N, T)
        scores = (
            torch.bmm(query_attn_features, local_tokens.permute(0, 2, 1)) / denominator
        )
        attn_weights = F.softmax(scores, dim=-1)

        # Attention-pooled token per (token set, query): (B, N, D)
        aggregated = torch.matmul(attn_weights, local_tokens)

        query_attn_features = F.normalize(query_attn_features, p=2, dim=-1)
        aggregated = F.normalize(aggregated, p=2, dim=-1)

        # (B, N, 1, D) x (B, N, D, 1) -> (B, N, 1, 1)
        logits = torch.matmul(
            query_attn_features.unsqueeze(2), aggregated.unsqueeze(-1)
        )
        # Drop only the two trailing singleton dims. A bare ``squeeze()``
        # would also remove the B or N axis whenever it has size 1, leaving a
        # 1-D tensor that can no longer be indexed as (N, B) downstream.
        logits = logits.squeeze(-1).squeeze(-1)  # (B, N)

        logits = logits.transpose(0, 1)  # (N, B)

        if need_attn_weights:
            attn_scores = [scores]
        else:
            attn_scores = None

        return logits, attn_scores


def multi_positive_nce_loss(
    logits: torch.Tensor,
    group_map: torch.Tensor,
    temperature: float = 1.0,
    eps: float = 1e-8,
    row_sum: bool = False,
    col_sum: bool = False,
):
    """Symmetric multi-positive NCE loss between key phrases and images.

    Args:
        logits: tensor of shape (N_total, B_global); entry (i, j) is the
            logit between key phrase i and candidate image j.
        group_map: tensor of shape (N_total,) holding, for each key phrase,
            the index of the image it comes from.
        temperature: divisor applied to the logits before exponentiation.
        eps: small constant forwarded to the row/column terms.
        row_sum: forwarded to ``get_row_loss``.
        col_sum: forwarded to ``get_col_loss``.

    Row direction: for key phrase i the positive is image ``group_map[i]``
    and every other image is a negative.

    Column direction: every key phrase of image j is a positive for column j.

    Returns:
        loss: scalar tensor, the average of the mean row term and the mean
        column term.
    """
    exp_logits = (logits / temperature).exp()  # (N_total, B_global)

    # (row, positive column) coordinates of every key phrase.
    phrase_rows = torch.arange(exp_logits.size(0))

    positives = exp_logits[phrase_rows, group_map]  # (N_total,)

    # 1 everywhere except at each phrase's own image.
    negatives_mask = torch.ones_like(exp_logits)
    negatives_mask[phrase_rows, group_map] = 0  # (N_total, B_global)

    row_term = get_row_loss(exp_logits, positives, group_map, eps, row_sum)
    column_term = get_col_loss(
        exp_logits, positives, negatives_mask, group_map, eps, col_sum
    )

    return (row_term.mean() + column_term.mean()) / 2


def get_row_loss(
    logits: torch.Tensor,
    pos_logits: torch.Tensor,
    group_map: torch.Tensor,
    eps: float = 1e-8,
    row_sum: bool = False,
):
    """Phrase-to-image (row direction) negative log-likelihood.

    Args:
        logits: (N_total, B_global) non-negative scores; the caller in this
            file passes already exponentiated logits.
        pos_logits: (N_total,) score of each row at its positive column.
        group_map: (N_total,) integer index of the positive column
            (source image) of each row.
        eps: small constant guarding the division and the logarithm.
        row_sum: If True, numerators and denominators are first summed over
            all rows sharing the same ``group_map`` value, giving one term
            per image; otherwise one term per row is returned.

    Returns:
        Tensor of shape (B_global,) when ``row_sum`` is True, else
        (N_total,), holding ``-log(p + eps)``.
    """
    if row_sum:
        num_groups = logits.shape[-1]
        # Allocate the accumulators from the inputs so they share dtype and
        # device: ``scatter_add_`` rejects a source whose dtype differs from
        # the destination, which broke non-float32 (fp16/bf16/fp64) logits.
        row_sum_logits = logits.new_zeros(num_groups)  # (B_global)
        row_pos_sum_logits = pos_logits.new_zeros(num_groups)  # (B_global)

        # Sum the per-row totals and positives into their image's slot.
        row_sum_logits.scatter_add_(0, group_map, logits.sum(dim=1))  # (B_global)
        row_pos_sum_logits.scatter_add_(0, group_map, pos_logits)  # (B_global)
        p_row = row_pos_sum_logits / (row_sum_logits + eps)  # (B_global)
    else:
        row_sum_logits = logits.sum(dim=1)  # (N_total)
        p_row = pos_logits / (row_sum_logits + eps)  # (N_total)

    return -torch.log(p_row + eps)


def get_col_loss(
    logits: torch.Tensor,
    pos_logits: torch.Tensor,
    neg_mask: torch.Tensor,
    group_map: torch.Tensor,
    eps: float = 1e-8,
    col_sum: bool = False,
):
    """Image-to-phrase (column direction) negative log-likelihood.

    Args:
        logits: (N_total, B_global) non-negative scores.
        pos_logits: (N_total,) score of each row at its positive column.
        neg_mask: (N_total, B_global) mask that is 0 at each row's positive
            column and 1 elsewhere.
        group_map: (N_total,) index of the positive column of each row.
        eps: small constant guarding the division and the logarithm.
        col_sum: If True, all positives of a column are summed into a single
            numerator (MIL-NCE style, one term per column); otherwise every
            positive is contrasted on its own against the negatives of its
            column (MP-NCE style, one term per row).

    Returns:
        Tensor of shape (B_global,) when ``col_sum`` is True, else
        (N_total,), holding ``-log(p + eps)``.
    """
    if not col_sum:
        # MP-NCE loss (UniCLIP): each positive vs. the negatives of its column.
        masked_negatives = logits * neg_mask  # (N_total, B_global)
        negatives_per_column = masked_negatives.sum(dim=0)  # (B_global,)
        negatives_per_row = negatives_per_column[group_map]  # (N_total)
        probability = pos_logits / (pos_logits + negatives_per_row + eps)  # (N_total)
        return -torch.log(probability + eps)

    # MIL-NCE loss: summed positives of a column vs. the whole column.
    column_totals = logits.sum(dim=0)  # (B_global,)
    positives_mask = torch.ones_like(logits) - neg_mask  # (N_total, B_global)
    column_positives = (logits * positives_mask).sum(dim=0)  # (B_global,)
    probability = column_positives / (column_totals + eps)  # (B_global,)
    return -torch.log(probability + eps)


def pad_and_gather(tensor):
    """Gather tensors whose first dimension differs across ranks.

    Each rank zero-pads its tensor along dim 0 to the largest size found on
    any rank, the padded tensors are exchanged with the autograd-aware
    ``all_gather``, and every received tensor is trimmed back to its true
    length before concatenation.

    Requires an initialised default process group.

    Args:
        tensor: Local tensor. Trailing dimensions are expected to match
            across ranks; only dim 0 is trimmed after the gather.

    Returns:
        The concatenation, in rank order, of every rank's tensor along
        dim 0, with the same dtype and device as ``tensor``.
    """
    # Import the autograd-aware collective explicitly so this function does
    # not rely on ``torch.distributed.nn`` having been imported elsewhere.
    from torch.distributed.nn.functional import all_gather as autograd_all_gather

    # Exchange shapes first so every rank knows how much to pad and trim.
    local_size = torch.tensor(tensor.size(), device=tensor.device)
    all_sizes = [torch.zeros_like(local_size) for _ in range(dist.get_world_size())]
    dist.all_gather(all_sizes, local_size)

    # Read all sizes back to the host once, rather than indexing with
    # device scalars for every slice below.
    size_table = torch.stack(all_sizes)
    max_size = size_table.max(dim=0)[0].tolist()
    lengths = size_table[:, 0].tolist()

    # Pad with the input's own dtype: a default float32 buffer would
    # silently convert integer ids and half-precision features.
    padded_tensor = tensor.new_zeros(max_size)
    padded_tensor[: tensor.size(0)] = tensor

    # Gather all padded tensors (gradients flow back to the local shard).
    gathered_tensors = autograd_all_gather(padded_tensor)

    # Trim the gathered tensors to their original sizes.
    trimmed = [g[:length] for g, length in zip(gathered_tensors, lengths)]

    return torch.cat(trimmed, dim=0)