File size: 9,925 Bytes
28693e2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
cd66851
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Dict, Optional, Tuple, Union, Literal, List
import math
import copy

class CLIPLoss(nn.Module):
    """Bidirectional contrastive loss in the style of CLIP.

    A learnable logit scale (initialised from ``temperature`` and stored in
    log space) multiplies the cosine-similarity matrix; cross-entropy is
    applied in both directions with the diagonal as the positive pairs.
    """

    def __init__(self, temperature: float = 0.07, max_temperature: float = 100.0):
        super().__init__()
        self.temperature = temperature
        self.max_temperature = max_temperature
        # Learnable scale kept in log space, as in the original CLIP code.
        self.logit_scale = nn.Parameter(torch.ones([]) * math.log(1 / temperature))

    def forward(
        self,
        image_features: torch.Tensor,
        text_features: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Return (mean loss, image-to-text loss, text-to-image loss).

        Args:
            image_features: [B, D]
            text_features: [B, D]
        """
        img = F.normalize(image_features, dim=-1)
        txt = F.normalize(text_features, dim=-1)

        # Cap the exponentiated scale so logits cannot blow up numerically.
        scale = self.logit_scale.exp().clamp(max=self.max_temperature)
        sim_i2t = scale * img @ txt.T
        sim_t2i = sim_i2t.T

        # Matched pairs sit on the diagonal of the similarity matrix.
        targets = torch.arange(img.shape[0], device=img.device)

        loss_i2t = F.cross_entropy(sim_i2t, targets)
        loss_t2i = F.cross_entropy(sim_t2i, targets)

        return (loss_i2t + loss_t2i) / 2, loss_i2t, loss_t2i

class SigLIPLoss(nn.Module):
    """Pairwise sigmoid contrastive loss (SigLIP).

    Every image/text pair is scored independently with a sigmoid instead of
    a batch-wide softmax. The temperature (log space) and bias are learnable
    and follow the SigLIP initialisation defaults.
    """

    def __init__(self, init_temperature: float = 1.0, init_bias: float = -10.0):
        super().__init__()
        self.t_prime = nn.Parameter(torch.tensor(math.log(init_temperature)))
        self.b = nn.Parameter(torch.tensor(init_bias))

    def forward(
        self,
        image_features: torch.Tensor,
        text_features: torch.Tensor
    ) -> torch.Tensor:
        """Return the scalar SigLIP loss for [B, D] image/text batches."""
        img = F.normalize(image_features, dim=-1)
        txt = F.normalize(text_features, dim=-1)
        n = img.shape[0]

        # logits = exp(t') * <img, txt> + b
        logits = img @ txt.T * self.t_prime.exp() + self.b

        # Target matrix: +1 on the diagonal (matched pairs), -1 elsewhere.
        targets = 2 * torch.eye(n, device=img.device) - torch.ones(
            n, n, device=img.device
        )

        return -F.logsigmoid(targets * logits).sum() / n

class InfoNCELoss(nn.Module):
    """InfoNCE loss supporting either in-batch or explicit negatives."""

    def __init__(self, temperature: float = 0.07):
        super().__init__()
        self.temperature = temperature

    def forward(
        self,
        query: torch.Tensor,
        positive_key: torch.Tensor,
        negative_keys: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        """
        Args:
            query: [B, D]
            positive_key: [B, D]
            negative_keys: [B, N, D] or None. When None, the other in-batch
                positives act as negatives.
        """
        q = F.normalize(query, dim=-1)
        pos = F.normalize(positive_key, dim=-1)

        if negative_keys is None:
            # In-batch negatives: row i's positive sits in column i.
            logits = q @ pos.T / self.temperature
            targets = torch.arange(q.shape[0], dtype=torch.long, device=q.device)
        else:
            neg = F.normalize(negative_keys, dim=-1)
            pos_sim = torch.sum(q * pos, dim=-1, keepdim=True) / self.temperature  # [B, 1]
            neg_sim = torch.einsum('bd,bnd->bn', q, neg) / self.temperature        # [B, N]
            # Positive occupies index 0 of every row: [B, 1 + N].
            logits = torch.cat([pos_sim, neg_sim], dim=1)
            targets = torch.zeros(q.shape[0], dtype=torch.long, device=q.device)

        return F.cross_entropy(logits, targets)

class ProjectionHead(nn.Module):
    """Pools an optionally sequential feature tensor and projects it into the
    shared embedding space through a two-layer GELU MLP."""

    def __init__(
        self,
        input_dim: int,
        embed_dim: int,
        pooling_type: Literal['cls', 'mean', 'max', 'none'] = 'mean',
        exclude_first_token: bool = False
    ):
        super().__init__()
        self.pooling_type = pooling_type
        self.exclude_first_token = exclude_first_token

        self.net = nn.Sequential(
            nn.Linear(input_dim, embed_dim),
            nn.GELU(),
            nn.Linear(embed_dim, embed_dim)
        )

    def _pool(self, x: torch.Tensor) -> torch.Tensor:
        """Collapse [B, Seq, D] to [B, D] according to the configured pooling.

        'none' returns the sequence unchanged; 'cls' ignores
        ``exclude_first_token`` since it selects the first token itself.
        """
        if self.pooling_type == 'cls':
            return x[:, 0, :]
        if self.pooling_type in ('mean', 'max'):
            # Optionally drop the leading (e.g. CLS) token before pooling.
            tokens = (
                x[:, 1:, :]
                if self.exclude_first_token and x.shape[1] > 1
                else x
            )
            if self.pooling_type == 'mean':
                return tokens.mean(dim=1)
            return tokens.max(dim=1)[0]
        return x  # 'none'

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # 2D inputs [B, D] skip pooling entirely.
        if x.dim() == 3:
            x = self._pool(x)
        return self.net(x)

class MultiModalContrastiveLoss(nn.Module):
    """Contrastive objective over an arbitrary set of modalities.

    Each configured modality gets its own ProjectionHead; pairs of modalities
    are scored with a CLIP, SigLIP, or InfoNCE loss on the projected
    embeddings.
    """

    def __init__(
        self,
        embed_dim: int = 512,
        input_dims: Union[int, Dict[str, int]] = 2048,
        temperature: float = 0.07,
        loss_type: str = 'clip',
        modality_config: Optional[Dict[str, str]] = None
    ):
        super().__init__()
        self.embed_dim = embed_dim
        self.loss_type = loss_type

        # Any loss_type other than 'clip'/'siglip' falls back to InfoNCE.
        if loss_type == 'clip':
            self.loss_fn = CLIPLoss(temperature)
        elif loss_type == 'siglip':
            self.loss_fn = SigLIPLoss()
        else:
            self.loss_fn = InfoNCELoss(temperature)

        if modality_config is None:
            modality_config = {
                'text': 'cls',
                'image': 'cls',
                'audio': 'mean',
                'video': 'mean'
            }
        self.modality_config = modality_config

        # One projection head per modality with a known input width.
        self.projectors = nn.ModuleDict()
        for mod_name, pool_type in modality_config.items():
            if isinstance(input_dims, dict):
                dim = input_dims.get(mod_name)
                # Skip modalities whose width was not supplied instead of
                # crashing on a missing key.
                if dim is None:
                    continue
            else:
                dim = input_dims

            # Token-based modalities carry a leading CLS token that should
            # not participate in mean/max pooling.
            exclude_first = (
                mod_name in ('image', 'text') and pool_type in ('mean', 'max')
            )

            self.projectors[mod_name] = ProjectionHead(
                input_dim=dim,
                embed_dim=embed_dim,
                pooling_type=pool_type,
                exclude_first_token=exclude_first
            )

    def forward(
        self,
        features: Dict[str, torch.Tensor],
        modality_pairs: Optional[List[Tuple[str, str]]] = None
    ) -> Dict[str, torch.Tensor]:
        """Compute one contrastive loss per modality pair.

        When ``modality_pairs`` is None, every non-text modality is paired
        with 'text'; if 'text' is absent, an empty dict is returned.
        """
        if modality_pairs is None:
            if 'text' not in features:
                return {}
            modality_pairs = [(m, 'text') for m in features if m != 'text']

        losses: Dict[str, torch.Tensor] = {}
        for mod_a, mod_b in modality_pairs:
            # Both the features and a projector must exist for each side.
            if any(
                m not in features or m not in self.projectors
                for m in (mod_a, mod_b)
            ):
                continue

            emb_a = self.projectors[mod_a](features[mod_a])
            emb_b = self.projectors[mod_b](features[mod_b])

            # CLIPLoss returns a (total, i2t, t2i) tuple; keep the total.
            if self.loss_type == 'clip':
                pair_loss = self.loss_fn(emb_a, emb_b)[0]
            else:
                pair_loss = self.loss_fn(emb_a, emb_b)

            losses[f'{mod_a}_{mod_b}_loss'] = pair_loss

        return losses

class MomentumEncoder(nn.Module):
    """Wraps an encoder with a frozen EMA (momentum) copy, MoCo-style."""

    def __init__(self, encoder: nn.Module, momentum: float = 0.999):
        super().__init__()
        self.encoder = encoder
        self.momentum_encoder = self._build_momentum_encoder(encoder)
        self.momentum = momentum

    def _build_momentum_encoder(self, encoder: nn.Module) -> nn.Module:
        """Deep-copy the encoder and freeze the copy's parameters."""
        shadow = copy.deepcopy(encoder)
        for p in shadow.parameters():
            p.requires_grad = False
        return shadow

    @torch.no_grad()
    def _update_momentum_encoder(self):
        """EMA update: k <- m*k + (1-m)*q; buffers are copied verbatim."""
        m = self.momentum
        for q, k in zip(
            self.encoder.parameters(),
            self.momentum_encoder.parameters()
        ):
            k.data.mul_(m).add_(q.data, alpha=1.0 - m)

        for src, dst in zip(
            self.encoder.buffers(),
            self.momentum_encoder.buffers()
        ):
            dst.data.copy_(src.data)

    def forward(self, x: torch.Tensor, use_momentum: bool = False) -> torch.Tensor:
        """Encode ``x`` with the live encoder, or (grad-free) with the EMA copy.

        The momentum path first refreshes the EMA weights, then runs the
        momentum encoder in eval mode under ``torch.no_grad``.
        """
        if not use_momentum:
            return self.encoder(x)

        with torch.no_grad():
            self._update_momentum_encoder()
            self.momentum_encoder.eval()
            return self.momentum_encoder(x)