File size: 2,798 Bytes
5946936
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
import torch
import torch.nn as nn
import torch.nn.functional as F

def infoNCE_loss1(mol_features, ms_features, temperature=0.1, norm=True):
    """Symmetric InfoNCE (NT-Xent) loss between two feature batches.

    Row i of ``mol_features`` and row i of ``ms_features`` form the
    positive pair; every other row in the batch acts as a negative.

    Args:
        mol_features: (B, D) tensor of embeddings.
        ms_features: (B, D) tensor of embeddings.
        temperature: softmax temperature applied to the similarity logits.
        norm: if True, L2-normalize both feature sets first.

    Returns:
        Scalar tensor: mean of the mol->ms and ms->mol cross-entropy losses.
    """
    if norm:
        mol_features = F.normalize(mol_features, p=2, dim=1)
        ms_features = F.normalize(ms_features, p=2, dim=1)

    # Pairwise similarity logits, scaled by temperature.
    sim = torch.mm(mol_features, ms_features.T) / temperature

    # Row i should match column i: diagonal entries are the positives.
    targets = torch.arange(sim.size(0), device=mol_features.device)

    forward = F.cross_entropy(sim, targets)
    backward = F.cross_entropy(sim.T, targets)
    return (forward + backward) / 2

def infoNCE_loss2(mol_features, ms_features, temperature=0.1, alpha=0.75, norm=True):
    """Weighted bidirectional InfoNCE loss.

    Like ``infoNCE_loss1`` but the two directions (mol->ms and ms->mol)
    are combined with weight ``alpha`` instead of a plain average.

    Args:
        mol_features: (B, D) tensor of embeddings.
        ms_features: (B, D) tensor of embeddings.
        temperature: softmax temperature. Default is 0.1; CLIP commonly
            uses 0.07 — tune for your setup.
        alpha: weight of the mol->ms loss; (1 - alpha) weighs ms->mol.
        norm: if True, L2-normalize both feature sets first.

    Returns:
        Scalar tensor: alpha-weighted sum of both directional losses.
    """
    if norm:
        mol_features = F.normalize(mol_features, p=2, dim=1)
        ms_features = F.normalize(ms_features, p=2, dim=1)

    batch_size = mol_features.size(0)

    # One matmul suffices: the ms->mol logits are exactly the transpose of
    # the mol->ms logits (the original computed both products separately).
    logits_ab = torch.matmul(mol_features, ms_features.T) / temperature
    logits_ba = logits_ab.T

    # Positive pairs sit on the diagonal.
    labels = torch.arange(batch_size, device=mol_features.device)

    loss_ab = F.cross_entropy(logits_ab, labels)
    loss_ba = F.cross_entropy(logits_ba, labels)

    return alpha * loss_ab + (1 - alpha) * loss_ba

# Contrastive loss with hard-negative mining added.
def contrastive_loss_with_hard_negatives(features1, features2, margin=1.0, hard_negative_ratio=0.3):
    """Contrastive loss that only penalizes the hardest negatives.

    Positive pairs are the matching rows of ``features1``/``features2``;
    for each anchor, only the top-k most similar negatives ("hard"
    negatives) contribute a hinge penalty above ``margin``.

    Args:
        features1: (B, D) tensor of embeddings. Similarity is a raw dot
            product — normalize beforehand if cosine similarity is intended.
        features2: (B, D) tensor of embeddings.
        margin: negatives are penalized only above this similarity.
        hard_negative_ratio: fraction of the batch mined per anchor.

    Returns:
        Scalar tensor: batch mean of (1 - positive similarity) plus the
        mean hinge penalty of the mined hard negatives.
    """
    batch_size = features1.shape[0]

    similarity = torch.matmul(features1, features2.t())

    # Positive pairs sit on the diagonal.
    positive_similarity = torch.diag(similarity)

    # No negatives exist for a single-sample batch; fall back to the
    # positive term only (the original crashed here).
    if batch_size < 2:
        return (1 - positive_similarity).mean()

    # Off-diagonal entries are the candidate negatives. The mask must live
    # on the same device as the similarity matrix (the original built it
    # on the CPU, which fails for CUDA inputs).
    mask = ~torch.eye(batch_size, dtype=torch.bool, device=similarity.device)
    negative_similarities = similarity[mask].view(batch_size, batch_size - 1)

    # Keep the k most similar (hardest) negatives per anchor. Clamp k into
    # [1, batch_size - 1]: the original int(batch_size * ratio) could be 0
    # for small batches, making the mean below produce NaN.
    k = int(batch_size * hard_negative_ratio)
    k = max(1, min(k, batch_size - 1))
    hard_negatives, _ = torch.topk(negative_similarities, k=k, dim=1)

    # Vectorized form of the original per-sample Python loop.
    pos_loss = 1 - positive_similarity
    neg_loss = torch.clamp(hard_negatives - margin, min=0).mean(dim=1)
    return (pos_loss + neg_loss).mean()