# NOTE(review): removed non-Python extraction artifacts that preceded the
# imports (a "File size" header, a stray number, and a column ruler) —
# they were not part of the module and would be syntax errors.
import torch
import torch.nn as nn
import torch.nn.functional as F
def infoNCE_loss1(mol_features, ms_features, temperature=0.1, norm=True):
    """Symmetric InfoNCE (contrastive) loss between two feature batches.

    Matching rows of ``mol_features`` and ``ms_features`` are treated as
    positive pairs (the diagonal of the similarity matrix); every other row
    is a negative. Cross-entropy is applied in both directions
    (mol -> ms and ms -> mol) and the two losses are averaged.

    Args:
        mol_features: (batch, dim) molecule embeddings.
        ms_features: (batch, dim) spectrum embeddings, row-aligned with
            ``mol_features``.
        temperature: softmax temperature dividing the similarity scores.
        norm: if True, L2-normalize each row before computing similarities.

    Returns:
        Scalar tensor with the averaged bidirectional loss.
    """
    if norm:
        mol_features = F.normalize(mol_features, p=2, dim=1)
        ms_features = F.normalize(ms_features, p=2, dim=1)

    # Pairwise similarity scores, scaled by the temperature.
    scores = torch.mm(mol_features, ms_features.T) / temperature

    # Row i should match column i, so the target class for row i is i.
    targets = torch.arange(scores.size(0), device=scores.device)

    # Average the two directional cross-entropy losses.
    forward_loss = F.cross_entropy(scores, targets)
    backward_loss = F.cross_entropy(scores.T, targets)
    return 0.5 * (forward_loss + backward_loss)
def infoNCE_loss2(mol_features, ms_features, temperature=0.1, alpha=0.75, norm=True):
    """Weighted bidirectional InfoNCE loss.

    Like :func:`infoNCE_loss1` but the two directional losses are mixed with
    weight ``alpha`` instead of a plain average, letting the mol -> ms
    direction dominate by default.

    NOTE(review): the original (Chinese) docstring remarked that 0.07 — the
    value commonly used in CLIP — may be a more suitable temperature, yet the
    default here is 0.1; left unchanged to preserve behavior.

    Args:
        mol_features: (batch, dim) molecule embeddings.
        ms_features: (batch, dim) spectrum embeddings, row-aligned with
            ``mol_features``.
        temperature: softmax temperature dividing the similarity scores.
        alpha: weight of the mol -> ms loss; the ms -> mol loss gets 1-alpha.
        norm: if True, L2-normalize each row before computing similarities.

    Returns:
        Scalar tensor: ``alpha * loss_ab + (1 - alpha) * loss_ba``.
    """
    if norm:
        mol_features = F.normalize(mol_features, p=2, dim=1)
        ms_features = F.normalize(ms_features, p=2, dim=1)

    targets = torch.arange(mol_features.size(0), device=mol_features.device)

    # Scaled similarity scores for each direction. (Both matmuls are kept —
    # rather than transposing one — to match the original computation exactly.)
    scores_ab = torch.matmul(mol_features, ms_features.T) / temperature
    scores_ba = torch.matmul(ms_features, mol_features.T) / temperature

    loss_ab = F.cross_entropy(scores_ab, targets)
    loss_ba = F.cross_entropy(scores_ba, targets)
    return alpha * loss_ab + (1 - alpha) * loss_ba
# 在对比损失函数中增加困难负样本挖掘
def contrastive_loss_with_hard_negatives(features1, features2, margin=1.0, hard_negative_ratio=0.3):
    """Contrastive loss with hard-negative mining.

    Diagonal entries of ``features1 @ features2.T`` are positive pairs; for
    each row only the top-k most similar off-diagonal entries (the "hardest"
    negatives) contribute a hinge penalty ``max(0, sim - margin)``.

    Fixes over the original implementation:
      * the boolean off-diagonal mask is built on the features' device
        (the original CPU mask crashed boolean indexing on CUDA tensors);
      * ``k`` is clamped to ``[1, batch_size - 1]`` — the original
        ``int(batch_size * ratio)`` was 0 for small batches, so ``topk``
        returned an empty tensor and ``mean`` produced NaN;
      * a batch of size 1 (no negatives available) falls back to the
        positive term only instead of returning NaN;
      * the per-sample Python loop is replaced by an equivalent
        vectorized mean.

    Args:
        features1: (batch, dim) embeddings; assumed already normalized if
            cosine similarity is intended — TODO confirm with callers.
        features2: (batch, dim) embeddings, row-aligned with ``features1``.
        margin: hinge margin subtracted from hard-negative similarities.
        hard_negative_ratio: fraction of the batch used as hard negatives
            per row (at least one is always kept).

    Returns:
        Scalar tensor with the batch-averaged loss.
    """
    batch_size = features1.shape[0]
    similarity = torch.matmul(features1, features2.t())

    # Positive pairs sit on the diagonal; their loss pulls similarity to 1.
    positive_similarity = torch.diag(similarity)
    pos_loss = 1 - positive_similarity

    # With a single sample there are no negatives to mine.
    if batch_size < 2:
        return pos_loss.mean()

    # Off-diagonal entries are the negatives; mask must live on the same
    # device as the similarity matrix for boolean indexing to work on GPU.
    mask = ~torch.eye(batch_size, dtype=torch.bool, device=similarity.device)
    negative_similarities = similarity[mask].view(batch_size, batch_size - 1)

    # Keep the k hardest (most similar) negatives per row; clamp k so it is
    # never 0 (NaN mean) and never exceeds the available negatives.
    k = max(1, min(batch_size - 1, int(batch_size * hard_negative_ratio)))
    hard_negatives, _ = torch.topk(negative_similarities, k=k, dim=1)

    # Hinge penalty: only negatives above the margin contribute.
    neg_loss = torch.clamp(hard_negatives - margin, min=0).mean(dim=1)

    return (pos_loss + neg_loss).mean()