import torch
import torch.nn as nn
import torch.nn.functional as F


def infoNCE_loss1(mol_features, ms_features, temperature=0.1, norm=True):
    """Symmetric InfoNCE loss; row i of each batch forms the positive pair."""
    if norm:
        mol_features = F.normalize(mol_features, p=2, dim=1)
        ms_features = F.normalize(ms_features, p=2, dim=1)

    # Pairwise similarity matrix, scaled by the temperature.
    logits = torch.mm(mol_features, ms_features.T) / temperature

    # Matching pairs sit on the diagonal, so the target for row i is i.
    batch_size = mol_features.size(0)
    labels = torch.arange(batch_size, device=mol_features.device)

    # Cross-entropy in both retrieval directions, then average.
    loss_mol = F.cross_entropy(logits, labels)
    loss_trans = F.cross_entropy(logits.T, labels)
    loss = (loss_mol + loss_trans) / 2

    return loss

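# Usage sketch (hypothetical shapes; assumes row-aligned pair embeddings from
# two encoders, e.g. a molecule encoder and a mass-spectrum encoder):
#
#   mol = torch.randn(32, 256)
#   ms = torch.randn(32, 256)
#   loss = infoNCE_loss1(mol, ms, temperature=0.1)
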
def infoNCE_loss2(mol_features, ms_features, temperature=0.1, alpha=0.75, norm=True):
    """
    Weighted InfoNCE variant. `temperature` may be tuned (0.07 is the value
    commonly used in CLIP), and L2 normalization keeps the logits bounded,
    which helps numerical stability. `alpha` weights the molecule->spectrum
    direction against the reverse direction.
    """
    if norm:
        mol_features = F.normalize(mol_features, p=2, dim=1)
        ms_features = F.normalize(ms_features, p=2, dim=1)

    batch_size = mol_features.size(0)

    # Similarity logits for both retrieval directions.
    logits_ab = torch.matmul(mol_features, ms_features.T) / temperature
    logits_ba = torch.matmul(ms_features, mol_features.T) / temperature

    labels = torch.arange(batch_size, device=mol_features.device)

    loss_ab = F.cross_entropy(logits_ab, labels)
    loss_ba = F.cross_entropy(logits_ba, labels)

    # Asymmetric weighting of the two directions.
    return alpha * loss_ab + (1 - alpha) * loss_ba

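# Note: with alpha=0.5 this reduces to the symmetric average computed by
# infoNCE_loss1 (logits_ba is logits_ab transposed); alpha > 0.5 up-weights
# the molecule->spectrum direction.
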
def contrastive_loss_with_hard_negatives(features1, features2, margin=1.0, hard_negative_ratio=0.3):
    """
    Contrastive loss with hard negative mining: only the top-k most similar
    negatives per anchor contribute to the push-apart term. Requires
    batch_size >= 2 so that at least one negative exists per row.
    """
    batch_size = features1.shape[0]

    # Pairwise (unnormalized) dot-product similarity.
    similarity = torch.matmul(features1, features2.t())

    # Positives are the diagonal entries.
    positive_similarity = torch.diag(similarity)

    # Drop the diagonal so only negative pairs remain; the mask must live on
    # the same device as the features.
    mask = ~torch.eye(batch_size, dtype=torch.bool, device=similarity.device)
    negative_similarities = similarity[mask].view(batch_size, batch_size - 1)

    # Keep the k hardest (most similar) negatives per row; clamp k to at
    # least 1 so torch.topk is well defined for small batches.
    k = max(1, int(batch_size * hard_negative_ratio))
    hard_negatives, _ = torch.topk(negative_similarities, k=k, dim=1)

    # Pull positives toward similarity 1; push hard negatives below `margin`.
    pos_loss = 1 - positive_similarity
    neg_loss = torch.clamp(hard_negatives - margin, min=0).mean(dim=1)

    return (pos_loss + neg_loss).mean()
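

# Minimal smoke test, assuming random tensors stand in for real encoder
# outputs; all shapes and values below are illustrative only.
if __name__ == "__main__":
    torch.manual_seed(0)
    mol = torch.randn(8, 64)
    ms = torch.randn(8, 64)

    print("infoNCE_loss1:", infoNCE_loss1(mol, ms).item())
    print("infoNCE_loss2:", infoNCE_loss2(mol, ms, alpha=0.75).item())
    print("hard-negative loss:",
          contrastive_loss_with_hard_negatives(mol, ms).item())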