from typing import List, Optional

import torch
import torch.nn.functional as F


# ---------- Centroid / Cosine ----------
def pairwise_centroid_distances(centroids: torch.Tensor) -> torch.Tensor:
    # centroids: [D, d] -> [D, D] matrix of pairwise Euclidean distances
    return torch.cdist(centroids, centroids, p=2)


def pairwise_cosine_similarity_distances(centroids: torch.Tensor) -> torch.Tensor:
    # centroids: [D, d] -> [D, D] matrix of cosine distances
    C = F.normalize(centroids, dim=1)
    S = torch.matmul(C, C.t())  # cosine similarity in [-1, 1]
    return 1.0 - S  # turn similarity into a distance
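

# Illustrative usage sketch (assumption, not from the original: embeddings come
# as a [D, n, d] tensor, one [n, d] matrix per dataset, and centroids are taken
# as the per-dataset means).
def _demo_centroid_distances() -> None:
    embs = torch.randn(3, 100, 16)  # [D, n, d] stand-in embeddings
    centroids = embs.mean(dim=1)    # [D, d]
    print(pairwise_centroid_distances(centroids))           # [D, D] Euclidean
    print(pairwise_cosine_similarity_distances(centroids))  # [D, D] cosine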


# ---------- Sliced Wasserstein ----------
def _project(X: torch.Tensor, dirs: torch.Tensor) -> torch.Tensor:
    # X: [n, d], dirs: [P, d] -> projections [P, n]
    return dirs @ X.t()


def _wasserstein_1d(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    # a, b: [n], both sorted ascending.
    # W1 between equal-size empirical distributions = mean |sorted(a) - sorted(b)|
    return torch.mean(torch.abs(a - b))


def _random_unit_directions(d: int, P: int, device: torch.device) -> torch.Tensor:
    # Sample P directions uniformly on the unit sphere in R^d
    # (normalized Gaussian vectors are uniform on the sphere).
    g = torch.randn(P, d, device=device)
    g = g / (g.norm(dim=1, keepdim=True) + 1e-12)
    return g


def _label_indices(labels: torch.Tensor) -> List[torch.Tensor]:
    # For each unique label value, collect the indices of its samples.
    uniq = torch.unique(labels)
    return [torch.nonzero(labels == u, as_tuple=True)[0] for u in uniq]
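

# Sanity check for the 1-D reduction (illustrative, hypothetical data): for a
# pure location shift between two Gaussians, W1 equals the shift, so the
# sorted-sample estimator above should land close to it.
def _demo_wasserstein_1d() -> None:
    a, _ = torch.sort(torch.randn(10_000))
    b, _ = torch.sort(torch.randn(10_000) + 2.0)  # N(2, 1): shifted by 2
    print(float(_wasserstein_1d(a, b)))  # ~2.0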
def sliced_wasserstein_distance_matrix(
    embs: torch.Tensor,  # [D, n, d]
    num_projections: int = 64,
    labels_per_ds: Optional[List[torch.Tensor]] = None,
    label_aware: bool = True,
    label_weighting: str = "uniform",  # "uniform" or "support"
    label_max_per_class: int = int(1e10),
) -> torch.Tensor:
    """Approximate pairwise sliced-Wasserstein distances between D embedding sets.

    Returns a symmetric [D, D] matrix. With `label_aware=True` and labels given,
    W1 is computed per class shared by both datasets and averaged (uniformly, or
    weighted by class support); otherwise all samples are pooled.
    """
    device = embs.device
    D, n, d = embs.shape
    dirs = _random_unit_directions(d, num_projections, device=device)  # [P, d]
    # Precompute projections once; label-aware mode re-indexes them per class.
    proj_all = []
    for i in range(D):
        Xi = embs[i]             # [n, d]
        Pi = _project(Xi, dirs)  # [P, n]
        proj_all.append(Pi)
    W = torch.zeros(D, D, device=device)
    for i in range(D):
        for j in range(i, D):
            if not label_aware or labels_per_ds is None:
                # Pool all samples: sort each projection, then average |diff|
                # over samples and projections (per-projection 1-D W1).
                Pi, Pj = proj_all[i], proj_all[j]  # [P, n]
                Pi_sorted, _ = torch.sort(Pi, dim=1)
                Pj_sorted, _ = torch.sort(Pj, dim=1)
                w = torch.mean(torch.abs(Pi_sorted - Pj_sorted)).item()
            else:
                # Label-aware: average per-class W1 over classes shared by i and j.
                yi = labels_per_ds[i] if i < len(labels_per_ds) else None
                yj = labels_per_ds[j] if j < len(labels_per_ds) else None
                if yi is None or yj is None or yi.numel() == 0 or yj.numel() == 0:
                    # Labels missing for either dataset: fall back to pooling.
                    Pi, Pj = proj_all[i], proj_all[j]
                    Pi_sorted, _ = torch.sort(Pi, dim=1)
                    Pj_sorted, _ = torch.sort(Pj, dim=1)
                    w = torch.mean(torch.abs(Pi_sorted - Pj_sorted)).item()
                else:
                    inds_i = _label_indices(yi)
                    # Map each label value in dataset j to its sample indices.
                    uniq_j = torch.unique(yj).tolist()
                    class_map_j = {int(v): torch.nonzero(yj == v, as_tuple=True)[0] for v in uniq_j}
                    ws = []
                    weights = []
                    for idxs_i in inds_i:
                        if idxs_i.numel() == 0:
                            continue
                        labval = int(yi[idxs_i[0]].item())
                        idxs_j = class_map_j.get(labval, None)
                        if idxs_j is None or idxs_j.numel() == 0:
                            continue  # class absent in dataset j
                        # Optionally cap the per-class sample count.
                        ni = min(idxs_i.numel(), label_max_per_class)
                        nj = min(idxs_j.numel(), label_max_per_class)
                        idxs_i_use = idxs_i[:ni]
                        idxs_j_use = idxs_j[:nj]
                        Pi = proj_all[i][:, idxs_i_use]  # [P, ni]
                        Pj = proj_all[j][:, idxs_j_use]  # [P, nj]
                        # Equalize sample counts by truncating both classes to the
                        # smaller size (a simple stand-in for quantile interpolation).
                        m = min(Pi.shape[1], Pj.shape[1])
                        if m == 0:
                            continue
                        Pi = Pi[:, :m]
                        Pj = Pj[:, :m]
                        Pi_sorted, _ = torch.sort(Pi, dim=1)
                        Pj_sorted, _ = torch.sort(Pj, dim=1)
                        w_ij = torch.mean(torch.abs(Pi_sorted - Pj_sorted))  # scalar tensor
                        ws.append(w_ij)
                        # "support" weights each class by its sample count.
                        weights.append(float(m) if label_weighting == "support" else 1.0)
                    if len(ws) == 0:
                        # No shared classes: fall back to pooling all samples.
                        Pi, Pj = proj_all[i], proj_all[j]
                        Pi_sorted, _ = torch.sort(Pi, dim=1)
                        Pj_sorted, _ = torch.sort(Pj, dim=1)
                        w = torch.mean(torch.abs(Pi_sorted - Pj_sorted)).item()
                    else:
                        ws_t = torch.stack(ws)
                        wt = torch.tensor(weights, device=ws_t.device)
                        w = float((ws_t * wt).sum() / wt.sum())
            W[i, j] = W[j, i] = w
    return W
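

# Hypothetical end-to-end run on synthetic data (stand-in shapes and labels,
# not part of the original pipeline): dataset 1 is shifted, so its row/column
# of W should stand out from those of datasets 0 and 2.
if __name__ == "__main__":
    torch.manual_seed(0)
    embs = torch.randn(3, 200, 32)  # [D, n, d]
    embs[1] += 1.0                  # shift dataset 1 away from 0 and 2
    labels = [torch.randint(0, 5, (200,)) for _ in range(3)]
    W = sliced_wasserstein_distance_matrix(
        embs, num_projections=64, labels_per_ds=labels, label_aware=True
    )
    print(W)  # symmetric [3, 3] distance matrix, zeros on the diagonal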