Delete contrastive_loss.py
Browse files- contrastive_loss.py +0 -140
contrastive_loss.py
DELETED
|
@@ -1,140 +0,0 @@
|
|
| 1 |
-
import torch
|
| 2 |
-
import torch.nn as nn
|
| 3 |
-
import torch.nn.functional as F
|
| 4 |
-
|
| 5 |
-
|
| 6 |
-
class ContrastiveLoss(nn.Module):
    """InfoNCE-style contrastive loss over two concatenated frame groups.

    Each batch item carries 2T frame embeddings — presumably two groups of T
    frames concatenated along time (e.g. two speakers; TODO confirm with
    callers). Same-group pairs are treated as positives and pulled together;
    cross-group pairs are negatives and pushed apart.
    """

    def __init__(self, temperature=0.25, distance_metric='cosine'):
        """
        Args:
            temperature: softmax temperature dividing all similarities.
            distance_metric: only 'cosine' is supported.
        """
        super().__init__()
        self.temperature = temperature
        self.distance_metric = distance_metric

    def compute_similarity(self, embeddings):
        """Return the temperature-scaled pairwise similarity logits.

        Args:
            embeddings: [B, 2T, D] float tensor.

        Returns:
            [B, 2T, 2T] tensor of cosine similarities divided by temperature.

        Raises:
            ValueError: if distance_metric is anything other than 'cosine'.
        """
        if self.distance_metric == 'cosine':
            embeddings = F.normalize(embeddings, p=2, dim=-1)  # [B, 2T, D]
            sim = torch.matmul(embeddings, embeddings.transpose(-1, -2))  # [B, 2T, 2T]
        else:
            raise ValueError(f"Unsupported distance metric: {self.distance_metric}")
        return sim / self.temperature

    def pairwise_and_no_diag(self, m):
        """Pairwise AND of a boolean vector against itself, diagonal cleared.

        Args:
            m: [B, T] boolean tensor.

        Returns:
            [B, T, T] boolean tensor with out[b, i, j] = m[b, i] & m[b, j]
            and out[b, i, i] forced to False (no self-pairs).
        """
        m_i = m.unsqueeze(2)  # [B, T, 1]
        m_j = m.unsqueeze(1)  # [B, 1, T]
        out = m_i & m_j       # [B, T, T]
        diag = torch.eye(m.size(1), dtype=torch.bool, device=m.device).unsqueeze(0)
        return out & ~diag

    def forward(self, embeddings, pos_indicator_mask):
        """Compute the contrastive loss.

        Args:
            embeddings: [B, 2T, D] float tensor, the two groups' frames
                concatenated along the time axis.
            pos_indicator_mask: [B, 2T] boolean tensor marking the positions
                that belong to each group (first T = group 1, last T = group 2).

        Returns:
            Scalar contrastive loss (0.0 when no positive pair exists).
        """
        B, two_T, D = embeddings.shape
        T = two_T // 2
        sim = self.compute_similarity(embeddings)  # [B, 2T, 2T]

        # Split the indicator into the two concatenated halves.
        m1 = pos_indicator_mask[:, :T]  # [B, T]
        m2 = pos_indicator_mask[:, T:]  # [B, T]

        # Positives: same-group pairs, assembled block-diagonally so each
        # half's pairs only index into its own T columns. Diagonal excluded.
        pos_block1 = self.pairwise_and_no_diag(m1)  # [B, T, T]
        pos_block2 = self.pairwise_and_no_diag(m2)  # [B, T, T]
        pos_mask = torch.cat([
            torch.cat([pos_block1, torch.zeros_like(pos_block1)], dim=2),  # [B, T, 2T]
            torch.cat([torch.zeros_like(pos_block2), pos_block2], dim=2)   # [B, T, 2T]
        ], dim=1)  # [B, 2T, 2T]

        # Negatives: cross-group pairs where both positions are active,
        # placed in the off-diagonal blocks (and mirrored for symmetry).
        cross = m1.unsqueeze(2) & m2.unsqueeze(1)  # [B, T, T]
        neg_mask = torch.cat([
            torch.cat([torch.zeros_like(cross), cross], dim=2),                 # [B, T, 2T]
            torch.cat([cross.transpose(1, 2), torch.zeros_like(cross)], dim=2)  # [B, T, 2T]
        ], dim=1)  # [B, 2T, 2T]

        # Defensive: both masks are already diagonal-free by construction
        # (pairwise_and_no_diag / off-diagonal blocks), but clear [i, i]
        # self-pairs explicitly anyway.
        identity_mask = torch.eye(two_T, dtype=torch.bool, device=embeddings.device).unsqueeze(0)  # [1, 2T, 2T]
        pos_mask &= ~identity_mask
        neg_mask &= ~identity_mask

        if not pos_mask.any():
            # No positive pairs found: return a zero that still participates
            # in autograd so callers can backprop unconditionally.
            return torch.tensor(0.0, device=embeddings.device, requires_grad=True)

        # InfoNCE, computed entirely in log space. The previous exp-space
        # version (exp(sim) then ratios) overflowed at low temperature and
        # produced inf*0 = NaN on masked entries; log-sum-exp is exactly
        # equivalent — -log(e^p / (e^p + A)) = logaddexp(p, log A) - p —
        # but numerically stable.
        #
        # Per-anchor negative term A = (10 / 2T) * sum(exp(neg sims)),
        # matching the original "10 * mean over the full row" scaling.
        # NOTE(review): the 10/(2T) factor looks like a tuned heuristic —
        # confirm it is intentional before changing.
        neg_logits = sim.masked_fill(~neg_mask, float('-inf'))  # [B, 2T, 2T]
        log_neg = torch.logsumexp(neg_logits, dim=2) + sim.new_full((), 10.0 / two_T).log()  # [B, 2T]

        # Gather the positive logits and each one's anchor-row negative term.
        pos_logits = sim[pos_mask]  # [num_pos_pairs]
        pos_indices = torch.nonzero(pos_mask, as_tuple=False)  # [num_pos_pairs, 3]
        log_neg_for_pos = log_neg[pos_indices[:, 0], pos_indices[:, 1]]  # [num_pos_pairs]

        # -log(exp(pos) / (exp(pos) + A)); rows with no negatives give
        # log_neg = -inf and hence a loss of exactly 0, as before.
        loss = torch.logaddexp(pos_logits, log_neg_for_pos) - pos_logits
        return loss.mean()
|
| 99 |
-
|
| 100 |
-
|
| 101 |
-
|
| 102 |
-
# Example usage and testing
|
| 103 |
-
def create_example_data():
    """Create random example inputs matching ContrastiveLoss.forward's contract.

    The loss expects embeddings of shape [B, 2T, D] (two groups of T frames
    concatenated along time) and a boolean activity indicator of shape
    [B, 2T]. The previous version built [B, T, B, T] pair masks stacked into
    a 5-D tensor, which the loss cannot consume and which crashed the demo.

    Returns:
        embeddings: [B, 2T, D] float tensor of random frame embeddings.
        pos_indicator_mask: [B, 2T] boolean tensor of active positions.
    """
    B, T, D = 2, 3, 64

    # Random frame embeddings for both concatenated groups.
    embeddings = torch.randn(B, 2 * T, D)

    # Mark every position active, then clear one frame so the demo also
    # exercises the masking path.
    pos_indicator_mask = torch.ones(B, 2 * T, dtype=torch.bool)
    pos_indicator_mask[0, T - 1] = False

    return embeddings, pos_indicator_mask
|
| 129 |
-
|
| 130 |
-
|
| 131 |
-
if __name__ == "__main__":
    # Smoke test with inputs matching ContrastiveLoss.forward's contract:
    # embeddings [B, 2T, D] plus a boolean activity mask [B, 2T]. (The old
    # demo passed a stacked [2, B, T, B, T] pair-mask tensor, which the loss
    # cannot consume.) Data is built inline and seeded for reproducibility.
    B, T, D = 2, 3, 64
    torch.manual_seed(0)
    embeddings = torch.randn(B, 2 * T, D)
    pos_indicator_mask = torch.ones(B, 2 * T, dtype=torch.bool)

    # Initialize loss function and compute the loss.
    contrastive_loss = ContrastiveLoss(temperature=0.07, distance_metric='cosine')
    loss = contrastive_loss(embeddings, pos_indicator_mask)
    print(f"Contrastive Loss: {loss.item():.4f}")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|