import torch
import torch.nn as nn
import torch.nn.functional as F


def infoNCE_loss1(mol_features, ms_features, temperature=0.1, norm=True):
    """Symmetric InfoNCE (NT-Xent) loss between two feature batches.

    Positive pairs are the matching rows of the two batches (the diagonal
    of the similarity matrix); every other row is a negative.

    Args:
        mol_features: (B, D) tensor of molecule embeddings.
        ms_features: (B, D) tensor of spectrum embeddings (same B).
        temperature: softmax temperature; lower -> sharper distribution.
        norm: if True, L2-normalize both feature sets first (cosine similarity).

    Returns:
        Scalar tensor: mean of the two directional cross-entropy losses.
    """
    if norm:
        mol_features = F.normalize(mol_features, p=2, dim=1)
        ms_features = F.normalize(ms_features, p=2, dim=1)

    # Pairwise similarity; positives lie on the diagonal.
    logits = torch.mm(mol_features, ms_features.T) / temperature

    batch_size = mol_features.size(0)
    labels = torch.arange(batch_size, device=mol_features.device)

    # Symmetric cross entropy: mol -> ms and ms -> mol (logits.T).
    loss_mol = F.cross_entropy(logits, labels)
    loss_trans = F.cross_entropy(logits.T, labels)
    return (loss_mol + loss_trans) / 2


def infoNCE_loss2(mol_features, ms_features, temperature=0.1, alpha=0.75, norm=True):
    """Weighted bidirectional InfoNCE loss.

    Like ``infoNCE_loss1`` but the two directions are combined with weight
    ``alpha`` instead of an even average. (A temperature around 0.07, as used
    in CLIP, is often a good choice.)

    Args:
        mol_features: (B, D) tensor of molecule embeddings.
        ms_features: (B, D) tensor of spectrum embeddings (same B).
        temperature: softmax temperature.
        alpha: weight of the mol->ms direction; (1 - alpha) goes to ms->mol.
        norm: if True, L2-normalize both feature sets first.

    Returns:
        Scalar tensor: alpha-weighted sum of the two directional losses.
    """
    if norm:
        mol_features = F.normalize(mol_features, p=2, dim=1)
        ms_features = F.normalize(ms_features, p=2, dim=1)

    batch_size = mol_features.size(0)

    # Similarity matrices for both directions (logits_ba == logits_ab.T).
    logits_ab = torch.matmul(mol_features, ms_features.T) / temperature
    logits_ba = torch.matmul(ms_features, mol_features.T) / temperature

    labels = torch.arange(batch_size, device=mol_features.device)

    loss_ab = F.cross_entropy(logits_ab, labels)
    loss_ba = F.cross_entropy(logits_ba, labels)

    return alpha * loss_ab + (1 - alpha) * loss_ba


def contrastive_loss_with_hard_negatives(features1, features2, margin=1.0,
                                         hard_negative_ratio=0.3):
    """Contrastive loss with hard-negative mining.

    Positive pairs (diagonal of the similarity matrix) are pulled toward
    similarity 1; for each anchor only the top-k most similar negatives
    ("hard" negatives) are penalized with a hinge at ``margin``.

    Note: features are NOT normalized here; callers control the similarity
    scale. ``margin`` is therefore interpreted in the raw dot-product scale.

    Args:
        features1: (B, D) tensor of anchor embeddings.
        features2: (B, D) tensor of paired embeddings (same B).
        margin: hinge threshold; negatives with similarity above it are penalized.
        hard_negative_ratio: fraction of the batch used as hard negatives per anchor.

    Returns:
        Scalar tensor: mean over the batch of (positive loss + hard-negative loss).
    """
    batch_size = features1.shape[0]

    similarity = torch.matmul(features1, features2.t())

    # Positive pairs sit on the diagonal.
    positive_similarity = torch.diag(similarity)

    # Off-diagonal entries are the negatives. Build the mask on the same
    # device as `similarity` so this works on CUDA tensors too.
    mask = ~torch.eye(batch_size, dtype=torch.bool, device=similarity.device)
    negative_similarities = similarity[mask].view(batch_size, batch_size - 1)

    # Keep the k most similar (hardest) negatives per anchor. Clamp k into
    # [1, batch_size - 1]: int(B * ratio) can be 0 for small batches, and
    # topk over an empty slice would otherwise yield a NaN mean.
    k = int(batch_size * hard_negative_ratio)
    k = max(1, min(k, batch_size - 1))
    hard_negatives, _ = torch.topk(negative_similarities, k=k, dim=1)

    # Vectorized equivalent of the per-sample loop:
    #   pos: pull positive similarity toward 1
    #   neg: hinge — penalize hard negatives exceeding the margin
    pos_loss = 1 - positive_similarity
    neg_loss = torch.clamp(hard_negatives - margin, min=0).mean(dim=1)

    return (pos_loss + neg_loss).mean()