Lakoc committed on
Commit
5eb8d6a
·
verified ·
1 Parent(s): e4d2de3

Delete contrastive_loss.py

Browse files
Files changed (1) hide show
  1. contrastive_loss.py +0 -140
contrastive_loss.py DELETED
@@ -1,140 +0,0 @@
1
- import torch
2
- import torch.nn as nn
3
- import torch.nn.functional as F
4
-
5
-
6
class ContrastiveLoss(nn.Module):
    """InfoNCE-style contrastive loss over two concatenated frame halves.

    The 2T-frame input is treated as two halves of T frames each
    (presumably two speaker groups — confirm with the caller). Pairs of
    active frames inside the same half are positives; active cross-half
    pairs are negatives.
    """

    def __init__(self, temperature=0.25, distance_metric='cosine'):
        """
        Args:
            temperature: Divisor applied to raw similarities (softmax temperature).
            distance_metric: Only 'cosine' is supported.
        """
        super(ContrastiveLoss, self).__init__()
        self.temperature = temperature
        self.distance_metric = distance_metric

    def compute_similarity(self, embeddings):
        """Return the temperature-scaled pairwise similarity matrix [B, 2T, 2T].

        Raises:
            ValueError: If ``self.distance_metric`` is not 'cosine'.
        """
        if self.distance_metric == 'cosine':
            embeddings = F.normalize(embeddings, p=2, dim=-1)  # [B, 2T, D]
            sim = torch.matmul(embeddings, embeddings.transpose(-1, -2))  # [B, 2T, 2T]
        else:
            raise ValueError(f"Unsupported distance metric: {self.distance_metric}")
        return sim / self.temperature

    def pairwise_and_no_diag(self, m):
        """Pairwise AND of a boolean mask [B, T] with self-pairs (diagonal) removed."""
        m_i = m.unsqueeze(2)  # [B, T, 1]
        m_j = m.unsqueeze(1)  # [B, 1, T]
        out = m_i & m_j  # [B, T, T]
        diag = torch.eye(m.size(1), dtype=torch.bool, device=m.device).unsqueeze(0)
        return out & ~diag

    def forward(self, embeddings, pos_indicator_mask):
        """
        Args:
            embeddings: [B, 2T, D] frame embeddings (2T assumed even).
            pos_indicator_mask: [B, 2T] boolean, positions that belong to each
                speaker group.
        Returns:
            Scalar contrastive loss (0 when no positive pair exists).
        """
        B, two_T, D = embeddings.shape
        T = two_T // 2
        sim = self.compute_similarity(embeddings)  # [B, 2T, 2T]

        # Split input mask into the two halves
        m1 = pos_indicator_mask[:, :T]  # [B, T]
        m2 = pos_indicator_mask[:, T:]  # [B, T]

        # Positive mask: same-half active pairs, diagonal excluded, placed in
        # the two diagonal blocks of the [B, 2T, 2T] layout.
        pos_block1 = self.pairwise_and_no_diag(m1)  # [B, T, T]
        pos_block2 = self.pairwise_and_no_diag(m2)  # [B, T, T]
        pos_mask = torch.cat([
            torch.cat([pos_block1, torch.zeros_like(pos_block1)], dim=2),  # [B, T, 2T]
            torch.cat([torch.zeros_like(pos_block2), pos_block2], dim=2)  # [B, T, 2T]
        ], dim=1)  # [B, 2T, 2T]

        # Negative mask: cross-half pairs where both frames are active, placed
        # in the two off-diagonal blocks.
        cross = m1.unsqueeze(2) & m2.unsqueeze(1)  # [B, T, T]
        neg_mask = torch.cat([
            torch.cat([torch.zeros_like(cross), cross], dim=2),  # [B, T, 2T]
            torch.cat([cross.transpose(1, 2), torch.zeros_like(cross)], dim=2)  # [B, T, 2T]
        ], dim=1)  # [B, 2T, 2T]

        # Exclude [i, i] self-pairs (defensive; the block construction above
        # already avoids them).
        identity_mask = torch.eye(two_T, dtype=torch.bool, device=embeddings.device).unsqueeze(0)  # [1, 2T, 2T]
        pos_mask &= ~identity_mask
        neg_mask &= ~identity_mask

        # Fully vectorized InfoNCE computation
        if pos_mask.any():
            # Subtract the per-anchor row max before exponentiating. The
            # pos / (pos + neg_sum) ratio is invariant to a per-row shift,
            # so this changes nothing mathematically and prevents overflow
            # for large sim / small temperature.
            stable_sim = sim - sim.max(dim=2, keepdim=True).values.detach()
            exp_sim = torch.exp(stable_sim)  # [B, 2T, 2T]

            # BUGFIX: the denominator previously used
            # `10 * torch.mean(exp_sim * neg_mask.float(), dim=2)`, which
            # equals the true masked sum only when 2T == 10. Use the exact
            # sum over negatives, as the original comment intended.
            neg_exp_sum = (exp_sim * neg_mask.float()).sum(dim=2)  # [B, 2T]

            # Positive terms and their anchor (batch, row) coordinates
            pos_exp = exp_sim[pos_mask]  # [num_pos_pairs]
            pos_indices = torch.nonzero(pos_mask, as_tuple=False)  # [num_pos_pairs, 3]
            batch_idx = pos_indices[:, 0]  # [num_pos_pairs]
            row_idx = pos_indices[:, 1]  # [num_pos_pairs]
            neg_sums_for_pos = neg_exp_sum[batch_idx, row_idx]  # [num_pos_pairs]

            # InfoNCE loss: -log(exp(pos) / (exp(pos) + sum(exp(neg))))
            loss = -torch.log(pos_exp / (pos_exp + neg_sums_for_pos))
            total_loss = loss.mean()
        else:
            # No positive pairs found; return a zero that still participates
            # in autograd so callers can backprop unconditionally.
            total_loss = torch.tensor(0.0, device=embeddings.device, requires_grad=True)
        return total_loss
99
-
100
-
101
-
102
- # Example usage and testing
103
def create_example_data():
    """Create example data matching ContrastiveLoss.forward's expected inputs.

    Returns:
        embeddings: [B, 2T, D] random frame embeddings.
        pos_indicator_mask: [B, 2T] boolean mask of active frames per half.
    """
    B, T, D = 2, 3, 64

    # BUGFIX: the previous version returned [B, T, D] embeddings together
    # with stacked [B, T, B, T] pair masks, which do not match forward()'s
    # interface ([B, 2T, D] embeddings plus a [B, 2T] boolean indicator)
    # and crashed with a broadcast error at runtime.
    embeddings = torch.randn(B, 2 * T, D)

    pos_indicator_mask = torch.zeros(B, 2 * T, dtype=torch.bool)
    # Mark the first two frames of each half as active so both positive
    # (same-half) and negative (cross-half) pairs exist.
    pos_indicator_mask[:, 0:2] = True
    pos_indicator_mask[:, T:T + 2] = True

    return embeddings, pos_indicator_mask
129
-
130
-
131
- if __name__ == "__main__":
132
- # Test the implementation
133
- embeddings, pair_masks = create_example_data()
134
-
135
- # Initialize loss function
136
- contrastive_loss = ContrastiveLoss(temperature=0.07, distance_metric='cosine')
137
-
138
- # Compute loss
139
- loss = contrastive_loss(embeddings, pair_masks)
140
- print(f"Contrastive Loss: {loss.item():.4f}")