# recommendation/utils/triplet_mining.py
# Author: Ali Mohsin
# Status: final prod
# Commit: 25bdf34
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Tuple, List, Optional
import numpy as np
class SemiHardTripletMiner:
    """Semi-hard negative mining for triplet loss training.

    A triplet (anchor, positive, negative) is "semi-hard" when the negative
    is farther from the anchor than the positive, but still within ``margin``
    of the positive distance:

        d(a, p) < d(a, n) < d(a, p) + margin

    Distances are cosine distances (1 - cosine similarity) computed over
    L2-normalized embeddings.
    """

    def __init__(self, margin: float = 0.2):
        # Width of the semi-hard band beyond the anchor-positive distance.
        self.margin = margin

    def mine_triplets(
        self,
        embeddings: torch.Tensor,
        labels: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Mine semi-hard triplets from embeddings.

        Args:
            embeddings: (N, D) tensor of embeddings (normalized internally
                for the distance computation; returned rows are the originals).
            labels: (N,) tensor of labels.

        Returns:
            anchors, positives, negatives: (K, D) embedding tensors, where K
            is the number of valid triplets. Falls back to random triplets
            (same (K, D) contract) if no semi-hard triplet exists.
        """
        dist_matrix = self._compute_distance_matrix(embeddings)
        anchor_idx, pos_idx, neg_idx = self._find_semi_hard_triplets(
            dist_matrix, labels
        )
        if len(anchor_idx) == 0:
            # Fallback keeps the contract: (K, D) embedding tensors.
            return self._random_triplets(embeddings, labels)
        return embeddings[anchor_idx], embeddings[pos_idx], embeddings[neg_idx]

    def _compute_distance_matrix(self, embeddings: torch.Tensor) -> torch.Tensor:
        """Return the (N, N) pairwise cosine-distance matrix (1 - cos sim)."""
        # Normalize to unit length so the dot product is cosine similarity.
        embeddings = F.normalize(embeddings, p=2, dim=1)
        similarity_matrix = torch.mm(embeddings, embeddings.t())
        return 1 - similarity_matrix

    def _find_semi_hard_triplets(
        self,
        dist_matrix: torch.Tensor,
        labels: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Return index tensors (anchors, positives, negatives) of semi-hard triplets.

        Indices are long tensors on ``labels.device``; all three are empty
        when no semi-hard triplet exists.
        """
        anchors: List[int] = []
        positives: List[int] = []
        negatives: List[int] = []
        n = len(labels)
        device = labels.device
        arange = torch.arange(n, device=device)
        for i in range(n):
            anchor_label = labels[i]
            # Positives: same label, excluding the anchor itself.
            positive_indices = torch.where((labels == anchor_label) & (arange != i))[0]
            if len(positive_indices) == 0:
                continue
            # Negatives: any different label.
            negative_indices = torch.where(labels != anchor_label)[0]
            if len(negative_indices) == 0:
                continue
            for pos_idx in positive_indices:
                pos_dist = dist_matrix[i, pos_idx]
                neg_dists = dist_matrix[i, negative_indices]
                # Semi-hard band: pos_dist < neg_dist < pos_dist + margin.
                semi_hard = torch.where(
                    (neg_dists > pos_dist) & (neg_dists < pos_dist + self.margin)
                )[0]
                if len(semi_hard) > 0:
                    # Choose the hardest (largest-distance) semi-hard negative.
                    hardest = semi_hard[torch.argmax(neg_dists[semi_hard])]
                    anchors.append(i)
                    # Store plain ints so the tensors below are built cleanly
                    # (avoids torch.tensor() over a list of 0-dim tensors).
                    positives.append(int(pos_idx))
                    negatives.append(int(negative_indices[hardest]))
        # Long index tensors on the labels' device; covers the empty case too.
        def to_idx(xs: List[int]) -> torch.Tensor:
            return torch.tensor(xs, dtype=torch.long, device=device)
        return to_idx(anchors), to_idx(positives), to_idx(negatives)

    def _random_triplets(
        self,
        embeddings: torch.Tensor,
        labels: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Generate random triplets as a fallback.

        Bug fix: previously returned 1-D *index* tensors, inconsistent with
        the (K, D) embedding contract of ``mine_triplets``. Now returns
        (K, D) embedding tensors; as a last resort duplicates the first
        sample once.
        """
        anchors: List[int] = []
        positives: List[int] = []
        negatives: List[int] = []
        n = len(labels)
        device = labels.device
        arange = torch.arange(n, device=device)
        max_triplets = min(1000, n // 3)  # Cap the number of random attempts.
        for _ in range(max_triplets):
            # Random anchor.
            anchor_idx = torch.randint(0, n, (1,)).item()
            anchor_label = labels[anchor_idx]
            # Random positive (same label, not the anchor).
            positive_indices = torch.where(
                (labels == anchor_label) & (arange != anchor_idx)
            )[0]
            if len(positive_indices) == 0:
                continue
            pos_idx = positive_indices[torch.randint(0, len(positive_indices), (1,))].item()
            # Random negative (different label).
            negative_indices = torch.where(labels != anchor_label)[0]
            if len(negative_indices) == 0:
                continue
            neg_idx = negative_indices[torch.randint(0, len(negative_indices), (1,))].item()
            anchors.append(anchor_idx)
            positives.append(pos_idx)
            negatives.append(neg_idx)
        if len(anchors) == 0:
            # Last resort: degenerate triplet duplicating the first sample.
            return embeddings[:1], embeddings[:1], embeddings[:1]
        def to_idx(xs: List[int]) -> torch.Tensor:
            return torch.tensor(xs, dtype=torch.long, device=device)
        return (
            embeddings[to_idx(anchors)],
            embeddings[to_idx(positives)],
            embeddings[to_idx(negatives)],
        )
class OnlineTripletMiner:
    """Online triplet mining for batch training.

    Dispatches to one of three strategies:
      * ``"semi_hard"`` -- delegates to :class:`SemiHardTripletMiner`.
      * ``"hardest"``   -- for each (anchor, positive) pair, the negative
                           closest to the anchor, kept only if it violates
                           the margin.
      * ``"random"``    -- uniformly sampled valid triplets.

    All strategies return (K, D) embedding tensors.
    """

    def __init__(self, margin: float = 0.2, mining_strategy: str = "semi_hard"):
        self.margin = margin
        self.mining_strategy = mining_strategy
        self.semi_hard_miner = SemiHardTripletMiner(margin)

    def mine_batch_triplets(
        self,
        embeddings: torch.Tensor,
        labels: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Mine triplets from a batch of embeddings.

        Args:
            embeddings: (B, D) tensor of embeddings.
            labels: (B,) tensor of labels.

        Returns:
            anchors, positives, negatives: (K, D) embedding tensors.

        Raises:
            ValueError: if ``mining_strategy`` is not one of
                "semi_hard", "hardest", "random".
        """
        if self.mining_strategy == "semi_hard":
            return self.semi_hard_miner.mine_triplets(embeddings, labels)
        elif self.mining_strategy == "hardest":
            return self._hardest_triplets(embeddings, labels)
        elif self.mining_strategy == "random":
            return self._random_batch_triplets(embeddings, labels)
        else:
            raise ValueError(f"Unknown mining strategy: {self.mining_strategy}")

    def _hardest_triplets(
        self,
        embeddings: torch.Tensor,
        labels: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Find hardest-negative triplets.

        Bug fix: previously returned 1-D *index* tensors, inconsistent with
        the "semi_hard" path and the documented (K, D) contract. Now returns
        embedding tensors.
        """
        dist_matrix = self._compute_distance_matrix(embeddings)
        anchors: List[int] = []
        positives: List[int] = []
        negatives: List[int] = []
        n = len(labels)
        device = labels.device
        arange = torch.arange(n, device=device)
        for i in range(n):
            anchor_label = labels[i]
            # Positives: same label, excluding the anchor itself.
            positive_indices = torch.where((labels == anchor_label) & (arange != i))[0]
            if len(positive_indices) == 0:
                continue
            # Negatives: any different label.
            negative_indices = torch.where(labels != anchor_label)[0]
            if len(negative_indices) == 0:
                continue
            for pos_idx in positive_indices:
                pos_dist = dist_matrix[i, pos_idx]
                neg_dists = dist_matrix[i, negative_indices]
                # Hardest negative = closest to the anchor.
                hardest = torch.argmin(neg_dists)
                # Keep only margin-violating negatives (the ones the triplet
                # loss actually learns from).
                if neg_dists[hardest] < pos_dist + self.margin:
                    anchors.append(i)
                    positives.append(int(pos_idx))
                    negatives.append(int(negative_indices[hardest]))
        if len(anchors) == 0:
            return self._random_batch_triplets(embeddings, labels)
        def to_idx(xs: List[int]) -> torch.Tensor:
            return torch.tensor(xs, dtype=torch.long, device=device)
        return (
            embeddings[to_idx(anchors)],
            embeddings[to_idx(positives)],
            embeddings[to_idx(negatives)],
        )

    def _random_batch_triplets(
        self,
        embeddings: torch.Tensor,
        labels: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Generate random valid triplets; returns (K, D) embedding tensors.

        Implemented locally (rather than delegating to the semi-hard miner's
        fallback) so this class guarantees the (K, D) return contract on its
        own. Duplicates the first sample as a last resort.
        """
        anchors: List[int] = []
        positives: List[int] = []
        negatives: List[int] = []
        n = len(labels)
        device = labels.device
        arange = torch.arange(n, device=device)
        for _ in range(min(1000, n // 3)):  # Cap the number of random attempts.
            a = torch.randint(0, n, (1,)).item()
            pos_pool = torch.where((labels == labels[a]) & (arange != a))[0]
            if len(pos_pool) == 0:
                continue
            neg_pool = torch.where(labels != labels[a])[0]
            if len(neg_pool) == 0:
                continue
            anchors.append(a)
            positives.append(pos_pool[torch.randint(0, len(pos_pool), (1,))].item())
            negatives.append(neg_pool[torch.randint(0, len(neg_pool), (1,))].item())
        if len(anchors) == 0:
            # Last resort: degenerate triplet duplicating the first sample.
            return embeddings[:1], embeddings[:1], embeddings[:1]
        def to_idx(xs: List[int]) -> torch.Tensor:
            return torch.tensor(xs, dtype=torch.long, device=device)
        return (
            embeddings[to_idx(anchors)],
            embeddings[to_idx(positives)],
            embeddings[to_idx(negatives)],
        )

    def _compute_distance_matrix(self, embeddings: torch.Tensor) -> torch.Tensor:
        """Return the (N, N) pairwise cosine-distance matrix (1 - cos sim)."""
        embeddings = F.normalize(embeddings, p=2, dim=1)
        similarity_matrix = torch.mm(embeddings, embeddings.t())
        return 1 - similarity_matrix
def create_triplet_miner(
    strategy: str = "semi_hard",
    margin: float = 0.2
) -> OnlineTripletMiner:
    """Build an :class:`OnlineTripletMiner` for the given strategy and margin.

    Args:
        strategy: mining strategy name ("semi_hard", "hardest", or "random").
        margin: triplet margin passed through to the miner.

    Returns:
        A configured OnlineTripletMiner instance.
    """
    miner = OnlineTripletMiner(margin=margin, mining_strategy=strategy)
    return miner
# Example usage / smoke test
if __name__ == "__main__":
    # Demo parameters for a dummy batch.
    batch_size = 32
    embed_dim = 128
    num_classes = 8
    torch.manual_seed(0)  # reproducible demo output
    # Random embeddings and labels standing in for a real batch.
    embeddings = torch.randn(batch_size, embed_dim)
    labels = torch.randint(0, num_classes, (batch_size,))
    # Create miner and mine triplets.
    miner = create_triplet_miner(strategy="semi_hard", margin=0.2)
    anchors, positives, negatives = miner.mine_batch_triplets(embeddings, labels)
    # Bug fix: the miner returns (K, D) *embedding* tensors, not index
    # vectors, so report shapes instead of labeling them "indices".
    print(f"Generated {len(anchors)} triplets from batch of {batch_size}")
    print(f"Anchor embeddings shape: {tuple(anchors.shape)}")
    print(f"Positive embeddings shape: {tuple(positives.shape)}")
    print(f"Negative embeddings shape: {tuple(negatives.shape)}")