szxllm commited on
Commit
28693e2
·
verified ·
1 Parent(s): 4bfa065

Update contrastive_learning.py

Browse files
Files changed (1) hide show
  1. contrastive_learning.py +289 -338
contrastive_learning.py CHANGED
@@ -1,339 +1,290 @@
1
- import torch
2
- import torch.nn as nn
3
- import torch.nn.functional as F
4
- from typing import Dict, Optional, Tuple, Union, Literal, List
5
- import math
6
- import copy
7
-
8
- class CLIPLoss(nn.Module):
9
- """CLIP风格的对比学习损失"""
10
- def __init__(self, temperature: float = 0.07, max_temperature: float = 100.0):
11
- super().__init__()
12
- self.temperature = temperature
13
- self.max_temperature = max_temperature
14
- # 初始化 logit_scale
15
- self.logit_scale = nn.Parameter(torch.ones([]) * math.log(1 / temperature))
16
-
17
- def forward(
18
- self,
19
- image_features: torch.Tensor,
20
- text_features: torch.Tensor
21
- ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
22
- """
23
- Args:
24
- image_features: [B, D]
25
- text_features: [B, D]
26
- """
27
- # 归一化
28
- image_features = F.normalize(image_features, dim=-1)
29
- text_features = F.normalize(text_features, dim=-1)
30
-
31
- # 限制 logit_scale 防止数值不稳定
32
- logit_scale = self.logit_scale.exp().clamp(max=self.max_temperature)
33
-
34
- # 计算相似度矩阵 [B, B]
35
- # 注意:在 DDP 环境下,这里计算的是局部 Batch 的 Loss。
36
- # 完整的 DDP 实现需要 gather 所有 GPU 的 features。
37
- logits_per_image = logit_scale * image_features @ text_features.T
38
- logits_per_text = logits_per_image.T
39
-
40
- # 标签: 对角线为正样本
41
- batch_size = image_features.shape[0]
42
- labels = torch.arange(batch_size, device=image_features.device)
43
-
44
- # 双向交叉熵
45
- loss_i2t = F.cross_entropy(logits_per_image, labels)
46
- loss_t2i = F.cross_entropy(logits_per_text, labels)
47
-
48
- total_loss = (loss_i2t + loss_t2i) / 2
49
-
50
- return total_loss, loss_i2t, loss_t2i
51
-
52
- class SigLIPLoss(nn.Module):
53
- """
54
- SigLIP损失 - 包含可学习的 Bias 和 Scale
55
- Paper: Sigmoid Loss for Language Image Pre-Training
56
- """
57
- def __init__(self, init_temperature: float = 1.0, init_bias: float = -10.0):
58
- super().__init__()
59
- self.t_prime = nn.Parameter(torch.tensor(math.log(init_temperature)))
60
- self.b = nn.Parameter(torch.tensor(init_bias))
61
-
62
- def forward(
63
- self,
64
- image_features: torch.Tensor,
65
- text_features: torch.Tensor
66
- ) -> torch.Tensor:
67
- """
68
- 注意:SigLIP 的标准实现不需要 Gather 全局负样本即可收敛,
69
- 但这里实现的是 dense pair loss。对于超大 Batch (如 8k+)
70
- 构造 [B, B] 的 labels 矩阵会导致显存爆炸,生产环境建议使用 custom kernel block chunking。
71
- """
72
- # 归一化
73
- image_features = F.normalize(image_features, dim=-1)
74
- text_features = F.normalize(text_features, dim=-1)
75
-
76
- batch_size = image_features.shape[0]
77
-
78
- # Logits = exp(t) * (x @ yT) + b
79
- logits = image_features @ text_features.T * self.t_prime.exp() + self.b
80
-
81
- # 构造标签: 对角线为1,其余为-1
82
- labels = -torch.ones(batch_size, batch_size, device=image_features.device)
83
- labels += 2 * torch.eye(batch_size, device=image_features.device)
84
-
85
- # Sigmoid Loss: -log(sigmoid(label * logits))
86
- # label=1: -log(sigmoid(z))
87
- # 当 label=-1: -log(sigmoid(-z)) = -log(1 - sigmoid(z))
88
- # 这就是标准的 Binary Cross Entropy (Summed)
89
-
90
- # SigLIP 论文中通常建议除以 batch_size (或正样本数量) 进行归一化
91
- loss = -F.logsigmoid(labels * logits).sum() / batch_size
92
-
93
- return loss
94
-
95
- class InfoNCELoss(nn.Module):
96
- """InfoNCE损失 - 支持显式负样本或 Batch 内负样本"""
97
- def __init__(self, temperature: float = 0.07):
98
- super().__init__()
99
- self.temperature = temperature
100
-
101
- def forward(
102
- self,
103
- query: torch.Tensor,
104
- positive_key: torch.Tensor,
105
- negative_keys: Optional[torch.Tensor] = None
106
- ) -> torch.Tensor:
107
- """
108
- Args:
109
- query: [B, D]
110
- positive_key: [B, D]
111
- negative_keys: [B, N, D] or None.
112
- """
113
- query = F.normalize(query, dim=-1)
114
- positive_key = F.normalize(positive_key, dim=-1)
115
-
116
- if negative_keys is not None:
117
- # 显式负样本
118
- # pos_sim: [B]
119
- pos_sim = (query * positive_key).sum(dim=-1) / self.temperature
120
-
121
- negative_keys = F.normalize(negative_keys, dim=-1)
122
- # neg_sim: [B, N]
123
- neg_sim = (query.unsqueeze(1) * negative_keys).sum(dim=-1) / self.temperature
124
-
125
- # [B, 1 + N]
126
- logits = torch.cat([pos_sim.unsqueeze(1), neg_sim], dim=1)
127
- # 正样本在索引0
128
- labels = torch.zeros(query.shape[0], dtype=torch.long, device=query.device)
129
- else:
130
- # Batch内负样本 (类似于 CLIP 的单向 Loss)
131
- logits = query @ positive_key.T / self.temperature
132
- labels = torch.arange(query.shape[0], dtype=torch.long, device=query.device)
133
-
134
- loss = F.cross_entropy(logits, labels)
135
- return loss
136
-
137
- class ProjectionHead(nn.Module):
138
- """
139
- 投影头:处理特征维度变换和形状适配
140
- 针对 Transformer 输出 (Sequence) 提供了更精细的 Pooling 控制。
141
- """
142
- def __init__(
143
- self,
144
- input_dim: int,
145
- embed_dim: int,
146
- pooling_type: Literal['cls', 'mean', 'max', 'none'] = 'mean',
147
- exclude_first_token: bool = False
148
- ):
149
- super().__init__()
150
- self.pooling_type = pooling_type
151
- self.exclude_first_token = exclude_first_token
152
-
153
- self.net = nn.Sequential(
154
- nn.Linear(input_dim, embed_dim),
155
- nn.GELU(),
156
- nn.Linear(embed_dim, embed_dim)
157
- )
158
-
159
- def forward(self, x: torch.Tensor) -> torch.Tensor:
160
- # 适配 3D 张量 [B, Seq, D] -> [B, D]
161
- if x.dim() == 3:
162
- if self.pooling_type == 'cls':
163
- # 假设索引0是CLS token (Standard ViT / BERT)
164
- x = x[:, 0, :]
165
-
166
- elif self.pooling_type == 'mean':
167
- if self.exclude_first_token and x.shape[1] > 1:
168
- # 对于 ViT,如果使用 mean pooling,通常需要排除 CLS token
169
- x = x[:, 1:, :].mean(dim=1)
170
- else:
171
- x = x.mean(dim=1)
172
-
173
- elif self.pooling_type == 'max':
174
- if self.exclude_first_token and x.shape[1] > 1:
175
- x = x[:, 1:, :].max(dim=1)[0]
176
- else:
177
- x = x.max(dim=1)[0]
178
-
179
- elif self.pooling_type == 'none':
180
- # 保留序列维度,适用于 Dense Prediction 或细粒度对比
181
- # 此时输出为 [B, Seq, embed_dim]
182
- pass
183
-
184
- return self.net(x)
185
-
186
- class MultiModalContrastiveLoss(nn.Module):
187
- """多模态对比学习损失 - 支持动态模态和异构维度"""
188
- def __init__(
189
- self,
190
- embed_dim: int = 512,
191
- input_dims: Union[int, Dict[str, int]] = 2048,
192
- temperature: float = 0.07,
193
- loss_type: str = 'clip',
194
- modality_config: Optional[Dict[str, str]] = None
195
- ):
196
- super().__init__()
197
- self.embed_dim = embed_dim
198
- self.loss_type = loss_type
199
-
200
- if loss_type == 'clip':
201
- self.loss_fn = CLIPLoss(temperature)
202
- elif loss_type == 'siglip':
203
- self.loss_fn = SigLIPLoss()
204
- else:
205
- self.loss_fn = InfoNCELoss(temperature)
206
-
207
- self.projectors = nn.ModuleDict()
208
-
209
- if modality_config is None:
210
- # 默认常用模态配置
211
- # 注意:ImprovedVisionTransformer 输出带 CLS,所以图像推荐用 'cls' 或带排除的 'mean'
212
- modality_config = {
213
- 'text': 'cls',
214
- 'image': 'cls',
215
- 'audio': 'mean', # AudioEncoder 的双流输出已经是 2D,但如果是纯 Transformer 输出则是 3D
216
- 'video': 'mean' # VideoEncoder 输出通常是 [B, T, D]
217
- }
218
-
219
- self.modality_config = modality_config
220
-
221
- # 初始化投影头
222
- for mod_name, pool_type in modality_config.items():
223
- dim = 0
224
- if isinstance(input_dims, dict):
225
- dim = input_dims.get(mod_name)
226
- # 如果字典里没给这个模态的维度,跳过初始化,避免 crash
227
- if dim is None:
228
- continue
229
- else:
230
- dim = input_dims
231
-
232
- # 特殊处理:如果是 'mean' 或 'max' 且是 image/text,可能需要排除 CLS
233
- # 这里做一个启发式判断,用户也可以手动修改
234
- exclude_first = False
235
- if mod_name in ['image', 'text'] and pool_type in ['mean', 'max']:
236
- exclude_first = True
237
-
238
- self.projectors[mod_name] = ProjectionHead(
239
- input_dim=dim,
240
- embed_dim=embed_dim,
241
- pooling_type=pool_type,
242
- exclude_first_token=exclude_first
243
- )
244
-
245
- def forward(
246
- self,
247
- features: Dict[str, torch.Tensor],
248
- modality_pairs: Optional[List[Tuple[str, str]]] = None
249
- ) -> Dict[str, torch.Tensor]:
250
-
251
- # 自动生成对比对:将所有非Text模态与Text对比
252
- if modality_pairs is None:
253
- if 'text' in features:
254
- modality_pairs = [
255
- (mod, 'text') for mod in features.keys() if mod != 'text'
256
- ]
257
- else:
258
- return {}
259
-
260
- losses = {}
261
-
262
- for mod_a, mod_b in modality_pairs:
263
- if mod_a not in features or mod_b not in features:
264
- continue
265
-
266
- if mod_a not in self.projectors or mod_b not in self.projectors:
267
- # 记录警告或跳过
268
- continue
269
-
270
- feat_a = self.projectors[mod_a](features[mod_a])
271
- feat_b = self.projectors[mod_b](features[mod_b])
272
-
273
- # 计算损失
274
- loss_key = f'{mod_a}_{mod_b}_loss'
275
-
276
- if self.loss_type == 'clip':
277
- loss, _, _ = self.loss_fn(feat_a, feat_b)
278
- else:
279
- loss = self.loss_fn(feat_a, feat_b)
280
-
281
- losses[loss_key] = loss
282
-
283
- return losses
284
-
285
- class MomentumEncoder(nn.Module):
286
- """
287
- 动量编码器 - 用于MoCo风格的对比学习
288
- 支持参数和 Buffer (如 BatchNorm stats) 的动量更新
289
- """
290
- def __init__(self, encoder: nn.Module, momentum: float = 0.999):
291
- super().__init__()
292
- self.encoder = encoder
293
- self.momentum_encoder = self._build_momentum_encoder(encoder)
294
- self.momentum = momentum
295
-
296
- def _build_momentum_encoder(self, encoder: nn.Module) -> nn.Module:
297
- """构建动量编码器"""
298
- momentum_encoder = copy.deepcopy(encoder)
299
-
300
- # 冻结动量编码器参数
301
- for param in momentum_encoder.parameters():
302
- param.requires_grad = False
303
-
304
- return momentum_encoder
305
-
306
- @torch.no_grad()
307
- def _update_momentum_encoder(self):
308
- """更新动量编码器 (In-place update)"""
309
- # 更新参数
310
- for param_q, param_k in zip(
311
- self.encoder.parameters(),
312
- self.momentum_encoder.parameters()
313
- ):
314
- # EMA Update: k = m * k + (1 - m) * q
315
- param_k.data.mul_(self.momentum).add_(param_q.data, alpha=1.0 - self.momentum)
316
-
317
- # 更新 Buffers (如 BatchNorm running mean/var)
318
- # 简单的策略是直接覆盖,或者同样使用 EMA。通常直接覆盖即可,
319
- # 因为 Key Encoder 处于 Eval 模式,不追踪 batch stats。
320
- for buffer_q, buffer_k in zip(
321
- self.encoder.buffers(),
322
- self.momentum_encoder.buffers()
323
- ):
324
- buffer_k.data.copy_(buffer_q.data)
325
-
326
- def forward(self, x: torch.Tensor, use_momentum: bool = False) -> torch.Tensor:
327
- """
328
- Args:
329
- x: 输入数据
330
- use_momentum: 如果为 True,使用动量编码器 (通常用于生成 Key/Target)
331
- """
332
- if use_momentum:
333
- with torch.no_grad():
334
- self._update_momentum_encoder()
335
- # 动量编码器始终处于 eval 模式
336
- self.momentum_encoder.eval()
337
- return self.momentum_encoder(x)
338
- else:
339
  return self.encoder(x)
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ from typing import Dict, Optional, Tuple, Union, Literal, List
5
+ import math
6
+ import copy
7
+
8
class CLIPLoss(nn.Module):
    """CLIP-style symmetric contrastive loss with a learnable logit scale."""

    def __init__(self, temperature: float = 0.07, max_temperature: float = 100.0):
        super().__init__()
        self.temperature = temperature
        self.max_temperature = max_temperature
        # Learnable inverse temperature, kept in log space as in the CLIP paper.
        self.logit_scale = nn.Parameter(torch.ones([]) * math.log(1 / temperature))

    def forward(
        self,
        image_features: torch.Tensor,
        text_features: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Compute the bidirectional InfoNCE loss over a batch of pairs.

        Args:
            image_features: [B, D] image embeddings (unnormalized).
            text_features: [B, D] text embeddings (unnormalized).

        Returns:
            (total_loss, image->text loss, text->image loss).
        """
        img = F.normalize(image_features, dim=-1)
        txt = F.normalize(text_features, dim=-1)

        # Cap the exponentiated scale to keep the logits numerically stable.
        scale = self.logit_scale.exp().clamp(max=self.max_temperature)

        # Pairwise cosine-similarity logits: row i pairs image i with every text.
        logits_i2t = scale * torch.matmul(img, txt.t())
        logits_t2i = logits_i2t.t()

        # Matched pairs sit on the diagonal.
        targets = torch.arange(img.size(0), device=img.device)

        loss_i2t = F.cross_entropy(logits_i2t, targets)
        loss_t2i = F.cross_entropy(logits_t2i, targets)

        return 0.5 * (loss_i2t + loss_t2i), loss_i2t, loss_t2i
47
+
48
class SigLIPLoss(nn.Module):
    """Sigmoid pairwise loss ("Sigmoid Loss for Language Image Pre-Training").

    Both the temperature (stored in log space) and the bias are learnable.
    """

    def __init__(self, init_temperature: float = 1.0, init_bias: float = -10.0):
        super().__init__()
        self.t_prime = nn.Parameter(torch.tensor(math.log(init_temperature)))
        self.b = nn.Parameter(torch.tensor(init_bias))

    def forward(
        self,
        image_features: torch.Tensor,
        text_features: torch.Tensor
    ) -> torch.Tensor:
        """Return the SigLIP loss for a batch of pairs, normalized by batch size."""
        img = F.normalize(image_features, dim=-1)
        txt = F.normalize(text_features, dim=-1)
        n = img.size(0)

        # logits = exp(t') * <img, txt> + b
        logits = torch.matmul(img, txt.t()) * self.t_prime.exp() + self.b

        # Sign matrix: +1 on the diagonal (matched pairs), -1 elsewhere.
        signs = 2.0 * torch.eye(n, device=img.device) - torch.ones(n, n, device=img.device)

        # Binary cross-entropy written via log-sigmoid, averaged per example.
        return -F.logsigmoid(signs * logits).sum() / n
75
+
76
class InfoNCELoss(nn.Module):
    """InfoNCE loss supporting explicit negatives or in-batch negatives."""

    def __init__(self, temperature: float = 0.07):
        super().__init__()
        self.temperature = temperature

    def forward(
        self,
        query: torch.Tensor,
        positive_key: torch.Tensor,
        negative_keys: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        """
        Args:
            query: [B, D]
            positive_key: [B, D]
            negative_keys: [B, N, D], or None to use the rest of the batch
                as negatives.
        """
        q = F.normalize(query, dim=-1)
        pos = F.normalize(positive_key, dim=-1)

        if negative_keys is None:
            # In-batch negatives: every other positive serves as a negative
            # (equivalent to the one-directional CLIP loss).
            logits = torch.matmul(q, pos.t()) / self.temperature
            targets = torch.arange(q.size(0), dtype=torch.long, device=q.device)
        else:
            neg = F.normalize(negative_keys, dim=-1)
            # Positive similarity per query: [B, 1].
            pos_logit = (q * pos).sum(dim=-1, keepdim=True) / self.temperature
            # Similarity to each explicit negative: [B, N].
            neg_logits = torch.einsum('bd,bnd->bn', q, neg) / self.temperature
            # [B, 1 + N], with the positive always in column 0.
            logits = torch.cat((pos_logit, neg_logits), dim=1)
            targets = torch.zeros(q.size(0), dtype=torch.long, device=q.device)

        return F.cross_entropy(logits, targets)
114
+
115
class ProjectionHead(nn.Module):
    """Two-layer MLP projection with optional sequence pooling.

    3D inputs [B, Seq, D] are first pooled to [B, D] according to
    ``pooling_type`` ('none' keeps the sequence axis); 2D inputs go
    straight through to the MLP.
    """

    def __init__(
        self,
        input_dim: int,
        embed_dim: int,
        pooling_type: Literal['cls', 'mean', 'max', 'none'] = 'mean',
        exclude_first_token: bool = False
    ):
        super().__init__()
        self.pooling_type = pooling_type
        self.exclude_first_token = exclude_first_token
        self.net = nn.Sequential(
            nn.Linear(input_dim, embed_dim),
            nn.GELU(),
            nn.Linear(embed_dim, embed_dim)
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if x.dim() == 3:
            x = self._pool(x)
        return self.net(x)

    def _pool(self, x: torch.Tensor) -> torch.Tensor:
        """Collapse [B, Seq, D] to [B, D] (or keep the sequence for 'none')."""
        if self.pooling_type == 'cls':
            # Index 0 is assumed to be the CLS token (standard ViT / BERT).
            return x[:, 0, :]
        # Optionally drop the leading token (e.g. ViT CLS) before reducing.
        tokens = x[:, 1:, :] if self.exclude_first_token and x.shape[1] > 1 else x
        if self.pooling_type == 'mean':
            return tokens.mean(dim=1)
        if self.pooling_type == 'max':
            return tokens.max(dim=1).values
        # 'none': keep the sequence dimension for dense / fine-grained use.
        return x
155
+
156
class MultiModalContrastiveLoss(nn.Module):
    """Contrastive loss across an arbitrary set of modalities.

    Each modality gets its own ProjectionHead (possibly with a different
    input dimension), and a pairwise loss is computed per modality pair.
    """

    def __init__(
        self,
        embed_dim: int = 512,
        input_dims: Union[int, Dict[str, int]] = 2048,
        temperature: float = 0.07,
        loss_type: str = 'clip',
        modality_config: Optional[Dict[str, str]] = None
    ):
        super().__init__()
        self.embed_dim = embed_dim
        self.loss_type = loss_type

        # Select the pairwise loss implementation.
        if loss_type == 'clip':
            self.loss_fn = CLIPLoss(temperature)
        elif loss_type == 'siglip':
            self.loss_fn = SigLIPLoss()
        else:
            self.loss_fn = InfoNCELoss(temperature)

        self.projectors = nn.ModuleDict()

        if modality_config is None:
            # Defaults for common encoder output layouts.
            modality_config = {
                'text': 'cls',
                'image': 'cls',
                'audio': 'mean',
                'video': 'mean'
            }

        self.modality_config = modality_config

        # Build one projection head per configured modality.
        for modality, pooling in modality_config.items():
            if isinstance(input_dims, dict):
                in_dim = input_dims.get(modality)
                # No dimension supplied for this modality: leave it without
                # a projector instead of crashing.
                if in_dim is None:
                    continue
            else:
                in_dim = input_dims

            # Heuristic: mean/max pooling over image or text sequences
            # usually should skip the leading CLS token.
            drop_cls = modality in ('image', 'text') and pooling in ('mean', 'max')

            self.projectors[modality] = ProjectionHead(
                input_dim=in_dim,
                embed_dim=embed_dim,
                pooling_type=pooling,
                exclude_first_token=drop_cls
            )

    def forward(
        self,
        features: Dict[str, torch.Tensor],
        modality_pairs: Optional[List[Tuple[str, str]]] = None
    ) -> Dict[str, torch.Tensor]:
        """Return a dict mapping '<a>_<b>_loss' to the pairwise loss tensor."""
        if modality_pairs is None:
            # Default pairing: contrast every non-text modality against text.
            if 'text' not in features:
                return {}
            modality_pairs = [(m, 'text') for m in features if m != 'text']

        losses: Dict[str, torch.Tensor] = {}
        for mod_a, mod_b in modality_pairs:
            # Skip pairs with missing features or missing projectors.
            if mod_a not in features or mod_b not in features:
                continue
            if mod_a not in self.projectors or mod_b not in self.projectors:
                continue

            emb_a = self.projectors[mod_a](features[mod_a])
            emb_b = self.projectors[mod_b](features[mod_b])

            # CLIPLoss also returns the two directional terms; keep only the total.
            if self.loss_type == 'clip':
                pair_loss, _, _ = self.loss_fn(emb_a, emb_b)
            else:
                pair_loss = self.loss_fn(emb_a, emb_b)

            losses[f'{mod_a}_{mod_b}_loss'] = pair_loss

        return losses
249
+
250
class MomentumEncoder(nn.Module):
    """MoCo-style momentum (EMA) wrapper around an encoder.

    Keeps a frozen deep copy of ``encoder`` whose parameters track the live
    encoder via an exponential moving average; buffers (e.g. BatchNorm
    running stats) are mirrored verbatim.
    """

    def __init__(self, encoder: nn.Module, momentum: float = 0.999):
        super().__init__()
        self.encoder = encoder
        self.momentum_encoder = self._build_momentum_encoder(encoder)
        self.momentum = momentum

    def _build_momentum_encoder(self, encoder: nn.Module) -> nn.Module:
        """Deep-copy the encoder and detach the copy from gradient updates."""
        shadow = copy.deepcopy(encoder)
        for p in shadow.parameters():
            p.requires_grad = False
        return shadow

    @torch.no_grad()
    def _update_momentum_encoder(self):
        """EMA-update parameters in place; copy buffers directly."""
        m = self.momentum
        for q, k in zip(self.encoder.parameters(),
                        self.momentum_encoder.parameters()):
            # k <- m * k + (1 - m) * q
            k.data.mul_(m).add_(q.data, alpha=1.0 - m)

        # Buffers (e.g. BN running mean/var) are simply mirrored: the key
        # encoder runs in eval mode and never tracks batch statistics itself.
        for q_buf, k_buf in zip(self.encoder.buffers(),
                                self.momentum_encoder.buffers()):
            k_buf.data.copy_(q_buf.data)

    def forward(self, x: torch.Tensor, use_momentum: bool = False) -> torch.Tensor:
        """Encode ``x``.

        Args:
            x: Input tensor.
            use_momentum: If True, first EMA-update the momentum encoder, then
                encode with it (gradient-free, eval mode) — typically used to
                produce keys/targets.
        """
        if not use_momentum:
            return self.encoder(x)
        with torch.no_grad():
            self._update_momentum_encoder()
            # The momentum encoder always runs in eval mode.
            self.momentum_encoder.eval()
            return self.momentum_encoder(x)