# Triplet mining utilities: semi-hard, hardest, and random strategies.
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| from typing import Tuple, List, Optional | |
| import numpy as np | |
class SemiHardTripletMiner:
    """Semi-hard negative mining for triplet loss training.

    A semi-hard negative lies farther from the anchor than the positive but
    within ``margin`` of the positive distance (FaceNet-style mining):
    ``d(a, p) < d(a, n) < d(a, p) + margin``.
    """

    def __init__(self, margin: float = 0.2):
        # Width of the semi-hard band beyond the anchor-positive distance.
        self.margin = margin

    def mine_triplets(
        self,
        embeddings: torch.Tensor,
        labels: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Mine semi-hard triplets from embeddings.

        Args:
            embeddings: (N, D) tensor of embeddings (normalized internally
                before distances are computed).
            labels: (N,) tensor of labels.

        Returns:
            anchors, positives, negatives: (K, D) tensors of embedding rows,
            where K is the number of valid triplets. Falls back to random
            triplets when no semi-hard triplet exists.
        """
        # Pairwise cosine distances over L2-normalized embeddings.
        dist_matrix = self._compute_distance_matrix(embeddings)
        anchor_idx, positive_idx, negative_idx = self._find_semi_hard_triplets(
            dist_matrix, labels
        )
        if len(anchor_idx) == 0:
            # Fallback to random triplets if no semi-hard ones were found.
            return self._random_triplets(embeddings, labels)
        return embeddings[anchor_idx], embeddings[positive_idx], embeddings[negative_idx]

    def _compute_distance_matrix(self, embeddings: torch.Tensor) -> torch.Tensor:
        """Compute the (N, N) pairwise cosine-distance matrix (1 - cosine similarity)."""
        embeddings = F.normalize(embeddings, p=2, dim=1)
        similarity_matrix = torch.mm(embeddings, embeddings.t())
        return 1 - similarity_matrix

    def _find_semi_hard_triplets(
        self,
        dist_matrix: torch.Tensor,
        labels: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Return (anchors, positives, negatives) index tensors of semi-hard triplets."""
        anchors: List[int] = []
        positives: List[int] = []
        negatives: List[int] = []
        n = len(labels)
        device = labels.device
        arange = torch.arange(n, device=device)
        for i in range(n):
            anchor_label = labels[i]
            # Positives: same label, excluding the anchor itself.
            positive_indices = torch.where((labels == anchor_label) & (arange != i))[0]
            if len(positive_indices) == 0:
                continue
            # Negatives: any sample with a different label.
            negative_indices = torch.where(labels != anchor_label)[0]
            if len(negative_indices) == 0:
                continue
            for pos_idx in positive_indices:
                pos_dist = dist_matrix[i, pos_idx]
                # Semi-hard band: pos_dist < neg_dist < pos_dist + margin.
                neg_dists = dist_matrix[i, negative_indices]
                semi_hard_mask = (neg_dists > pos_dist) & (neg_dists < pos_dist + self.margin)
                semi_hard_indices = torch.where(semi_hard_mask)[0]
                if len(semi_hard_indices) > 0:
                    # The hardest semi-hard negative is the one CLOSEST to the
                    # anchor (smallest distance). The original used argmax,
                    # which picked the easiest negative in the band and
                    # contradicted its own comment.
                    hardest_idx = semi_hard_indices[torch.argmin(neg_dists[semi_hard_indices])]
                    anchors.append(i)
                    positives.append(int(pos_idx))
                    negatives.append(int(negative_indices[hardest_idx]))
        # Build index tensors on the labels' device so GPU inputs keep working.
        return (
            torch.tensor(anchors, dtype=torch.long, device=device),
            torch.tensor(positives, dtype=torch.long, device=device),
            torch.tensor(negatives, dtype=torch.long, device=device),
        )

    def _random_triplets(
        self,
        embeddings: torch.Tensor,
        labels: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Generate random triplets as a fallback.

        Returns (K, D) embedding rows — the original returned raw index
        tensors here, which did not match the contract of ``mine_triplets``
        (and its own "last resort" branch, which returned embeddings).
        """
        anchors: List[int] = []
        positives: List[int] = []
        negatives: List[int] = []
        n = len(labels)
        device = labels.device
        arange = torch.arange(n, device=device)
        max_triplets = min(1000, n // 3)  # Limit number of random triplets.
        for _ in range(max_triplets):
            # Random anchor.
            anchor_idx = torch.randint(0, n, (1,)).item()
            anchor_label = labels[anchor_idx]
            # Random positive (same label, not the anchor).
            positive_indices = torch.where((labels == anchor_label) & (arange != anchor_idx))[0]
            if len(positive_indices) == 0:
                continue
            pos_idx = positive_indices[torch.randint(0, len(positive_indices), (1,))].item()
            # Random negative (different label).
            negative_indices = torch.where(labels != anchor_label)[0]
            if len(negative_indices) == 0:
                continue
            neg_idx = negative_indices[torch.randint(0, len(negative_indices), (1,))].item()
            anchors.append(anchor_idx)
            positives.append(pos_idx)
            negatives.append(neg_idx)
        if len(anchors) == 0:
            # Last resort: duplicate the first sample so callers always get data.
            return embeddings[:1], embeddings[:1], embeddings[:1]
        anchor_t = torch.tensor(anchors, dtype=torch.long, device=device)
        pos_t = torch.tensor(positives, dtype=torch.long, device=device)
        neg_t = torch.tensor(negatives, dtype=torch.long, device=device)
        return embeddings[anchor_t], embeddings[pos_t], embeddings[neg_t]
class OnlineTripletMiner:
    """Online triplet mining for batch training.

    Dispatches to one of three strategies: ``semi_hard`` (FaceNet-style band),
    ``hardest`` (closest negative per anchor/positive pair), or ``random``.
    All strategies return (K, D) embedding rows.
    """

    def __init__(self, margin: float = 0.2, mining_strategy: str = "semi_hard"):
        # Margin used both by the semi-hard band and the hardest-negative filter.
        self.margin = margin
        self.mining_strategy = mining_strategy
        self.semi_hard_miner = SemiHardTripletMiner(margin)

    def mine_batch_triplets(
        self,
        embeddings: torch.Tensor,
        labels: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Mine triplets from a batch of embeddings.

        Args:
            embeddings: (B, D) tensor of embeddings.
            labels: (B,) tensor of labels.

        Returns:
            anchors, positives, negatives: (K, D) tensors of embedding rows.

        Raises:
            ValueError: if ``self.mining_strategy`` is not one of
                "semi_hard", "hardest", or "random".
        """
        if self.mining_strategy == "semi_hard":
            return self.semi_hard_miner.mine_triplets(embeddings, labels)
        elif self.mining_strategy == "hardest":
            return self._hardest_triplets(embeddings, labels)
        elif self.mining_strategy == "random":
            return self._random_batch_triplets(embeddings, labels)
        else:
            raise ValueError(f"Unknown mining strategy: {self.mining_strategy}")

    def _hardest_triplets(
        self,
        embeddings: torch.Tensor,
        labels: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Find hardest-negative triplets; returns (K, D) embedding rows.

        The original returned raw index tensors here, so the return type of
        ``mine_batch_triplets`` silently depended on the strategy; it is now
        consistent with the semi-hard path and the documented contract.
        """
        dist_matrix = self._compute_distance_matrix(embeddings)
        anchors: List[int] = []
        positives: List[int] = []
        negatives: List[int] = []
        n = len(labels)
        device = labels.device
        arange = torch.arange(n, device=device)
        for i in range(n):
            anchor_label = labels[i]
            # Positives: same label, excluding the anchor itself.
            positive_indices = torch.where((labels == anchor_label) & (arange != i))[0]
            if len(positive_indices) == 0:
                continue
            # Negatives: any sample with a different label.
            negative_indices = torch.where(labels != anchor_label)[0]
            if len(negative_indices) == 0:
                continue
            for pos_idx in positive_indices:
                pos_dist = dist_matrix[i, pos_idx]
                # Hardest negative = the one closest to the anchor.
                neg_dists = dist_matrix[i, negative_indices]
                hardest_idx = torch.argmin(neg_dists)
                # Keep only triplets that violate (or nearly violate) the margin;
                # easier ones produce zero loss and are skipped.
                if neg_dists[hardest_idx] < pos_dist + self.margin:
                    anchors.append(i)
                    positives.append(int(pos_idx))
                    negatives.append(int(negative_indices[hardest_idx]))
        if len(anchors) == 0:
            return self._random_batch_triplets(embeddings, labels)
        anchor_t = torch.tensor(anchors, dtype=torch.long, device=device)
        pos_t = torch.tensor(positives, dtype=torch.long, device=device)
        neg_t = torch.tensor(negatives, dtype=torch.long, device=device)
        return embeddings[anchor_t], embeddings[pos_t], embeddings[neg_t]

    def _random_batch_triplets(
        self,
        embeddings: torch.Tensor,
        labels: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Generate random triplets from the batch (delegates to the semi-hard miner)."""
        return self.semi_hard_miner._random_triplets(embeddings, labels)

    def _compute_distance_matrix(self, embeddings: torch.Tensor) -> torch.Tensor:
        """Compute the (N, N) pairwise cosine-distance matrix (1 - cosine similarity)."""
        embeddings = F.normalize(embeddings, p=2, dim=1)
        similarity_matrix = torch.mm(embeddings, embeddings.t())
        return 1 - similarity_matrix
def create_triplet_miner(
    strategy: str = "semi_hard",
    margin: float = 0.2
) -> OnlineTripletMiner:
    """Build an OnlineTripletMiner configured with the given strategy and margin."""
    miner = OnlineTripletMiner(mining_strategy=strategy, margin=margin)
    return miner
# Example usage
if __name__ == "__main__":
    # Smoke-test the miner with random data.
    batch_size = 32
    embed_dim = 128
    num_classes = 8

    # Generate dummy embeddings and labels.
    embeddings = torch.randn(batch_size, embed_dim)
    labels = torch.randint(0, num_classes, (batch_size,))

    # Create miner and mine triplets.
    miner = create_triplet_miner(strategy="semi_hard", margin=0.2)
    anchors, positives, negatives = miner.mine_batch_triplets(embeddings, labels)

    # mine_batch_triplets returns (K, D) embedding rows, not indices, so label
    # the printout accordingly (the original said "indices" for tensors of
    # embedding vectors).
    print(f"Generated {len(anchors)} triplets from batch of {batch_size}")
    print(f"Anchor embeddings (first 5): {anchors[:5]}")
    print(f"Positive embeddings (first 5): {positives[:5]}")
    print(f"Negative embeddings (first 5): {negatives[:5]}")