from typing import Dict, List, Optional

import numpy as np
import torch
import torch.nn.functional as F


# ---------- Centroid / Cosine ----------
def pairwise_centroid_distances(centroids: torch.Tensor) -> torch.Tensor:
    """Return the [D, D] matrix of Euclidean distances between centroid rows.

    Args:
        centroids: Tensor of shape [D, d], one centroid per row.
    """
    return torch.cdist(centroids, centroids, p=2)


def pairwise_cosine_similarity_distances(centroids: torch.Tensor) -> torch.Tensor:
    """Return the [D, D] matrix of cosine distances (1 - cosine similarity).

    Rows are L2-normalised with ``F.normalize`` before the Gram matrix is
    formed. The result is not clamped, so floating-point error may leave
    entries marginally outside [0, 2].
    """
    normed = F.normalize(centroids, dim=1)
    similarity = torch.matmul(normed, normed.t())  # cosine similarity
    return 1.0 - similarity  # turn into distance


# ---------- Sliced Wasserstein ----------
def _project(X: torch.Tensor, dirs: torch.Tensor) -> torch.Tensor:
    """Project samples onto directions.

    Args:
        X: [n, d] samples.
        dirs: [P, d] projection directions.

    Returns:
        [P, n] tensor; row p holds every sample projected onto ``dirs[p]``.
    """
    return dirs @ X.t()


def _wasserstein_1d(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    """Return mean |a - b| over all elements, as a scalar tensor.

    When ``a`` and ``b`` have the same shape and are sorted along their last
    axis, this is the 1-D Wasserstein-1 distance between equal-size empirical
    distributions, averaged over any leading (e.g. projection) axis.
    """
    return torch.mean(torch.abs(a - b))


def _random_unit_directions(
    d: int,
    P: int,
    device: torch.device,
    generator: Optional[torch.Generator] = None,
) -> torch.Tensor:
    """Sample ``P`` random unit vectors in R^d (normalised Gaussian draws).

    Args:
        d: Ambient dimension.
        P: Number of directions.
        device: Device on which to create the directions.
        generator: Optional RNG for reproducible draws; it must be on ``device``.

    Returns:
        [P, d] tensor whose rows have (approximately) unit L2 norm.
    """
    g = torch.randn(P, d, device=device, generator=generator)
    # The epsilon only guards against division by an all-zero draw.
    g = g / (g.norm(dim=1, keepdim=True) + 1e-12)
    return g


def _label_indices(labels: torch.Tensor) -> List[torch.Tensor]:
    """Return, for each unique label (ascending), the indices where it occurs."""
    uniq = torch.unique(labels)
    return [torch.nonzero(labels == u, as_tuple=True)[0] for u in uniq]


def _class_index_map(labels: torch.Tensor) -> Dict[object, torch.Tensor]:
    """Map each unique label value (ascending order) to the indices where it occurs.

    Keys are the raw Python values from ``Tensor.tolist()``. They are
    deliberately *not* truncated with ``int()``, so non-integer labels such as
    0.2 and 0.8 stay distinct. Integer-valued labels still match across int
    and float tensors, because ``1 == 1.0`` and both hash equally.
    """
    return {
        value: torch.nonzero(labels == value, as_tuple=True)[0]
        for value in torch.unique(labels).tolist()
    }


def _label_aware_w1(
    proj_i: torch.Tensor,
    proj_j: torch.Tensor,
    classes_i: Dict[object, torch.Tensor],
    classes_j: Dict[object, torch.Tensor],
    label_weighting: str,
    label_max_per_class: int,
) -> Optional[float]:
    """Weighted mean of per-class sliced W1 over classes present in both datasets.

    Args:
        proj_i, proj_j: [P, n] projections of the two datasets.
        classes_i, classes_j: Label -> sample-index maps for each dataset.
        label_weighting: ``"support"`` weights a class by the number of samples
            compared; any other value gives every class weight 1.
        label_max_per_class: Cap on samples used per class (first indices kept).

    Returns:
        The weighted mean as a Python float, or ``None`` when no class could be
        compared, so the caller can fall back to the label-agnostic distance.
    """
    ws: List[torch.Tensor] = []
    weights: List[float] = []
    for label, idxs_i in classes_i.items():
        idxs_j = classes_j.get(label)
        if idxs_j is None or idxs_i.numel() == 0 or idxs_j.numel() == 0:
            continue
        # Optionally cap the per-class sample count (keeps the first indices).
        ni = min(idxs_i.numel(), label_max_per_class)
        nj = min(idxs_j.numel(), label_max_per_class)
        Pi = proj_i[:, idxs_i[:ni]]  # [P, ni]
        Pj = proj_j[:, idxs_j[:nj]]  # [P, nj]
        # Sorted-difference W1 needs equal sample counts. Keep the first m
        # samples (in index order) of each class: a simple deterministic
        # subsample, not a quantile interpolation.
        m = min(Pi.shape[1], Pj.shape[1])
        if m == 0:
            continue
        Pi_sorted = torch.sort(Pi[:, :m], dim=1).values
        Pj_sorted = torch.sort(Pj[:, :m], dim=1).values
        ws.append(_wasserstein_1d(Pi_sorted, Pj_sorted))  # scalar tensor
        weights.append(float(m) if label_weighting == "support" else 1.0)
    if not ws:
        return None
    ws_t = torch.stack(ws)
    weights_t = torch.tensor(weights, device=ws_t.device)
    return float((ws_t * weights_t).sum() / weights_t.sum())


def sliced_wasserstein_distance_matrix(
    embs: torch.Tensor,  # [D, n, d]
    num_projections: int = 64,
    labels_per_ds: Optional[List[torch.Tensor]] = None,
    label_aware: bool = True,
    label_weighting: str = "uniform",
    label_max_per_class: int = int(1e10),
    *,
    generator: Optional[torch.Generator] = None,
) -> torch.Tensor:
    """Return the [D, D] sliced Wasserstein-1 distance matrix between datasets.

    All datasets are projected onto the same ``num_projections`` random unit
    directions. The 1-D W1 distance between the projected empirical
    distributions is then averaged over directions.

    Args:
        embs: [D, n, d] tensor holding D datasets of n samples in d dimensions.
        num_projections: Number of random projection directions P.
        labels_per_ds: Optional per-dataset label tensors, indexed like the
            first axis of ``embs``. They are presumably 1-D of length n; this
            is not checked.
        label_aware: If True and labels are given, compare only samples that
            share a label value, then combine the per-class distances.
        label_weighting: ``"support"`` weights each class by the number of
            samples compared; any other value weights classes equally.
        label_max_per_class: Cap on samples per class (first indices kept).
        generator: Optional ``torch.Generator`` (on ``embs.device``) that makes
            the random projection directions reproducible.

    Returns:
        Symmetric [D, D] tensor (default float dtype) on ``embs.device``.

    Fallback: a pair uses the label-agnostic distance when either dataset has
    missing or empty labels, or when the two datasets share no comparable
    class. Unless a seeded ``generator`` is supplied, results vary between
    calls because the directions are random.
    """
    device = embs.device
    D, n, d = embs.shape
    dirs = _random_unit_directions(
        d, num_projections, device=device, generator=generator
    )  # [P, d]

    # Project every dataset once; each entry is [P, n].
    proj_all = [_project(embs[k], dirs) for k in range(D)]
    # The label-agnostic distance always compares fully sorted projections,
    # so sort each dataset once rather than twice for every pair.
    sorted_all = [torch.sort(p, dim=1).values for p in proj_all]

    use_labels = label_aware and labels_per_ds is not None
    class_maps: List[Optional[Dict[object, torch.Tensor]]] = [None] * D
    if use_labels:
        for k in range(D):
            y = labels_per_ds[k] if k < len(labels_per_ds) else None
            if y is not None and y.numel() > 0:
                class_maps[k] = _class_index_map(y)

    W = torch.zeros(D, D, device=device)
    for i in range(D):
        for j in range(i, D):
            w: Optional[float] = None
            if class_maps[i] is not None and class_maps[j] is not None:
                w = _label_aware_w1(
                    proj_all[i],
                    proj_all[j],
                    class_maps[i],
                    class_maps[j],
                    label_weighting,
                    label_max_per_class,
                )
            if w is None:
                # Aggregate across all samples: W1 averaged over projections.
                w = _wasserstein_1d(sorted_all[i], sorted_all[j]).item()
            W[i, j] = W[j, i] = w
    return W