"""Triplet mining utilities for metric-learning training.

Implements semi-hard, hardest, and random triplet mining over a batch of
embeddings, using cosine distance (1 - cosine similarity).
"""

import torch
import torch.nn.functional as F
from typing import Tuple


class SemiHardTripletMiner:
    """Semi-hard negative mining for triplet loss training.

    A negative ``n`` is "semi-hard" for an anchor/positive pair ``(a, p)``
    when ``d(a, p) < d(a, n) < d(a, p) + margin`` (FaceNet-style mining).
    """

    def __init__(self, margin: float = 0.2):
        # Margin defines the upper edge of the semi-hard window.
        self.margin = margin

    def mine_triplets(
        self, embeddings: torch.Tensor, labels: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Mine semi-hard triplets from embeddings.

        Args:
            embeddings: (N, D) tensor of embeddings (normalized internally
                before distances are computed).
            labels: (N,) tensor of integer class labels.

        Returns:
            anchors, positives, negatives: (K, D) tensors of embedding rows
            where K is the number of valid triplets. Falls back to random
            triplets if no semi-hard triplet exists in the batch.
        """
        # Compute pairwise cosine distances.
        dist_matrix = self._compute_distance_matrix(embeddings)

        # Find valid semi-hard triplets (as row indices).
        anchors, positives, negatives = self._find_semi_hard_triplets(
            dist_matrix, labels
        )

        if len(anchors) == 0:
            # Fallback to random triplets if no semi-hard ones were found.
            return self._random_triplets(embeddings, labels)

        return embeddings[anchors], embeddings[positives], embeddings[negatives]

    def _compute_distance_matrix(self, embeddings: torch.Tensor) -> torch.Tensor:
        """Return the (N, N) pairwise cosine-distance matrix (1 - cos sim)."""
        # Normalize embeddings to unit length so the dot product is cosine
        # similarity.
        embeddings = F.normalize(embeddings, p=2, dim=1)
        similarity_matrix = torch.mm(embeddings, embeddings.t())
        return 1 - similarity_matrix

    def _find_semi_hard_triplets(
        self, dist_matrix: torch.Tensor, labels: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Return long index tensors (anchors, positives, negatives) of
        semi-hard triplets; all three are empty when none exist."""
        anchors = []
        positives = []
        negatives = []
        n = len(labels)
        device = labels.device
        # Hoisted out of the loop: same arange every iteration.
        index_range = torch.arange(n, device=device)

        for i in range(n):
            anchor_label = labels[i]

            # Positive candidates: same label, excluding the anchor itself.
            positive_indices = torch.where(
                (labels == anchor_label) & (index_range != i)
            )[0]
            if len(positive_indices) == 0:
                continue

            # Negative candidates: any different label.
            negative_indices = torch.where(labels != anchor_label)[0]
            if len(negative_indices) == 0:
                continue

            # For each positive, pick one semi-hard negative.
            for pos_idx in positive_indices:
                pos_dist = dist_matrix[i, pos_idx]
                neg_dists = dist_matrix[i, negative_indices]

                # Semi-hard window: pos_dist < neg_dist < pos_dist + margin.
                semi_hard_mask = (neg_dists > pos_dist) & (
                    neg_dists < pos_dist + self.margin
                )
                semi_hard_indices = torch.where(semi_hard_mask)[0]

                if len(semi_hard_indices) > 0:
                    # BUG FIX: the hardest negative is the one CLOSEST to the
                    # anchor (smallest distance). The original used argmax,
                    # which selected the easiest semi-hard negative instead.
                    hardest = semi_hard_indices[
                        torch.argmin(neg_dists[semi_hard_indices])
                    ]
                    anchors.append(i)
                    # Store plain ints so torch.tensor() below receives a
                    # homogeneous list (the original mixed ints and tensors).
                    positives.append(int(pos_idx))
                    negatives.append(int(negative_indices[hardest]))

        if len(anchors) == 0:
            empty = torch.tensor([], dtype=torch.long, device=device)
            return empty, empty.clone(), empty.clone()

        return (
            torch.tensor(anchors, dtype=torch.long, device=device),
            torch.tensor(positives, dtype=torch.long, device=device),
            torch.tensor(negatives, dtype=torch.long, device=device),
        )

    def _random_triplets(
        self, embeddings: torch.Tensor, labels: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Generate random valid triplets as a fallback.

        Returns (K, D) embedding tensors. BUG FIX: the original returned raw
        index tensors on the normal path but embedding rows on the
        last-resort path, so callers received inconsistent shapes.
        """
        anchors = []
        positives = []
        negatives = []
        n = len(labels)
        index_range = torch.arange(n, device=labels.device)
        max_triplets = min(1000, n // 3)  # Limit number of random triplets.

        for _ in range(max_triplets):
            # Random anchor.
            anchor_idx = torch.randint(0, n, (1,)).item()
            anchor_label = labels[anchor_idx]

            # Random positive with the same label (anchor excluded).
            positive_indices = torch.where(
                (labels == anchor_label) & (index_range != anchor_idx)
            )[0]
            if len(positive_indices) == 0:
                continue
            pos_idx = positive_indices[
                torch.randint(0, len(positive_indices), (1,))
            ].item()

            # Random negative with a different label.
            negative_indices = torch.where(labels != anchor_label)[0]
            if len(negative_indices) == 0:
                continue
            neg_idx = negative_indices[
                torch.randint(0, len(negative_indices), (1,))
            ].item()

            anchors.append(anchor_idx)
            positives.append(pos_idx)
            negatives.append(neg_idx)

        if len(anchors) == 0:
            # Last resort: duplicate the first sample (empty if batch is empty).
            return embeddings[:1], embeddings[:1], embeddings[:1]

        return embeddings[anchors], embeddings[positives], embeddings[negatives]


class OnlineTripletMiner:
    """Online triplet mining for batch training."""

    def __init__(self, margin: float = 0.2, mining_strategy: str = "semi_hard"):
        self.margin = margin
        # One of: "semi_hard", "hardest", "random".
        self.mining_strategy = mining_strategy
        self.semi_hard_miner = SemiHardTripletMiner(margin)

    def mine_batch_triplets(
        self, embeddings: torch.Tensor, labels: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Mine triplets from a batch of embeddings.

        Args:
            embeddings: (B, D) tensor of embeddings.
            labels: (B,) tensor of integer class labels.

        Returns:
            anchors, positives, negatives: (K, D) tensors of embedding rows.

        Raises:
            ValueError: if the configured mining strategy is unknown.
        """
        if self.mining_strategy == "semi_hard":
            return self.semi_hard_miner.mine_triplets(embeddings, labels)
        elif self.mining_strategy == "hardest":
            return self._hardest_triplets(embeddings, labels)
        elif self.mining_strategy == "random":
            return self._random_batch_triplets(embeddings, labels)
        else:
            raise ValueError(f"Unknown mining strategy: {self.mining_strategy}")

    def _hardest_triplets(
        self, embeddings: torch.Tensor, labels: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Find hardest-negative triplets; returns (K, D) embedding tensors.

        A triplet is kept only when its hardest (closest) negative violates
        the margin, i.e. d(a, n) < d(a, p) + margin.
        """
        dist_matrix = self._compute_distance_matrix(embeddings)

        anchors = []
        positives = []
        negatives = []
        n = len(labels)
        index_range = torch.arange(n, device=labels.device)

        for i in range(n):
            anchor_label = labels[i]

            # Positive candidates: same label, excluding the anchor itself.
            positive_indices = torch.where(
                (labels == anchor_label) & (index_range != i)
            )[0]
            if len(positive_indices) == 0:
                continue

            # Negative candidates: any different label.
            negative_indices = torch.where(labels != anchor_label)[0]
            if len(negative_indices) == 0:
                continue

            # For each positive, take the hardest negative.
            for pos_idx in positive_indices:
                pos_dist = dist_matrix[i, pos_idx]

                # Hardest negative: the one closest to the anchor.
                neg_dists = dist_matrix[i, negative_indices]
                hardest_idx = torch.argmin(neg_dists)

                # Only include if the negative is closer than positive + margin.
                if neg_dists[hardest_idx] < pos_dist + self.margin:
                    anchors.append(i)
                    positives.append(int(pos_idx))
                    negatives.append(int(negative_indices[hardest_idx]))

        if len(anchors) == 0:
            return self._random_batch_triplets(embeddings, labels)

        # BUG FIX: return embedding rows (K, D) — the original returned raw
        # index tensors here, inconsistent with every other mining path and
        # with this class's documented return contract.
        return embeddings[anchors], embeddings[positives], embeddings[negatives]

    def _random_batch_triplets(
        self, embeddings: torch.Tensor, labels: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Generate random triplets from the batch (delegates to fallback miner)."""
        return self.semi_hard_miner._random_triplets(embeddings, labels)

    def _compute_distance_matrix(self, embeddings: torch.Tensor) -> torch.Tensor:
        """Compute the (N, N) pairwise cosine-distance matrix."""
        embeddings = F.normalize(embeddings, p=2, dim=1)
        similarity_matrix = torch.mm(embeddings, embeddings.t())
        return 1 - similarity_matrix


def create_triplet_miner(
    strategy: str = "semi_hard", margin: float = 0.2
) -> OnlineTripletMiner:
    """Factory function to create a triplet miner."""
    return OnlineTripletMiner(margin=margin, mining_strategy=strategy)


# Example usage
if __name__ == "__main__":
    # Test with dummy data.
    batch_size = 32
    embed_dim = 128
    num_classes = 8

    # Generate dummy embeddings and labels.
    embeddings = torch.randn(batch_size, embed_dim)
    labels = torch.randint(0, num_classes, (batch_size,))

    # Create miner and mine triplets.
    miner = create_triplet_miner(strategy="semi_hard", margin=0.2)
    anchors, positives, negatives = miner.mine_batch_triplets(embeddings, labels)

    print(f"Generated {len(anchors)} triplets from batch of {batch_size}")
    # All three outputs are (K, D) embedding tensors (not indices).
    print(f"Anchor embeddings (first 5): {anchors[:5]}")
    print(f"Positive embeddings (first 5): {positives[:5]}")
    print(f"Negative embeddings (first 5): {negatives[:5]}")