Spaces:
Sleeping
Sleeping
Create distances_common.py
Browse files- distances_common.py +127 -0
distances_common.py
ADDED
|
@@ -0,0 +1,127 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import List, Optional
|
| 2 |
+
import numpy as np
|
| 3 |
+
import torch
|
| 4 |
+
import torch.nn.functional as F
|
| 5 |
+
|
| 6 |
+
# ---------- Centroid / Cosine ----------
|
| 7 |
+
|
| 8 |
+
def pairwise_centroid_distances(centroids: torch.Tensor) -> torch.Tensor:
    """Euclidean distance between every pair of centroids.

    centroids: [D, d] (one row per dataset) -> [D, D] symmetric matrix.
    """
    return torch.cdist(centroids, centroids, p=2)
|
| 11 |
+
|
| 12 |
+
def pairwise_cosine_similarity_distances(centroids: torch.Tensor) -> torch.Tensor:
    """Cosine *distance* (1 - cosine similarity) between all centroid pairs.

    centroids: [D, d] -> [D, D]; zeros on the diagonal for unit-norm rows.
    """
    unit = F.normalize(centroids, dim=1)
    similarity = unit @ unit.t()
    return 1.0 - similarity
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
# ---------- Sliced Wasserstein ----------
|
| 19 |
+
|
| 20 |
+
def _project(X: torch.Tensor, dirs: torch.Tensor) -> torch.Tensor:
|
| 21 |
+
# X: [n, d], dirs: [P, d] -> [P, n]
|
| 22 |
+
return (dirs @ X.t())
|
| 23 |
+
|
| 24 |
+
def _wasserstein_1d(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
|
| 25 |
+
# a,b: [n] sorted
|
| 26 |
+
# W1 for equal-size empirical distributions = mean|sorted(a)-sorted(b)|
|
| 27 |
+
return torch.mean(torch.abs(a - b))
|
| 28 |
+
|
| 29 |
+
def _random_unit_directions(d: int, P: int, device: torch.device) -> torch.Tensor:
|
| 30 |
+
g = torch.randn(P, d, device=device)
|
| 31 |
+
g = g / (g.norm(dim=1, keepdim=True) + 1e-12)
|
| 32 |
+
return g
|
| 33 |
+
|
| 34 |
+
def _label_indices(labels: torch.Tensor) -> List[torch.Tensor]:
|
| 35 |
+
uniq = torch.unique(labels)
|
| 36 |
+
return [torch.nonzero(labels == u, as_tuple=True)[0] for u in uniq]
|
| 37 |
+
|
| 38 |
+
def sliced_wasserstein_distance_matrix(
    embs: torch.Tensor,  # [D, n, d]: D datasets, n samples each, d dims
    num_projections: int = 64,
    labels_per_ds: Optional[List[torch.Tensor]] = None,
    label_aware: bool = True,
    label_weighting: str = "uniform",
    label_max_per_class: int = int(1e10),
) -> torch.Tensor:
    """Pairwise sliced-Wasserstein distance matrix between D embedding sets.

    Every dataset is projected onto `num_projections` random unit directions
    (fresh random draws each call); for each dataset pair, the 1-D W1
    distance (mean |sorted(a) - sorted(b)|, valid for equal-size empirical
    samples) is averaged over projections.

    Args:
        embs: [D, n, d] stacked embeddings.
        num_projections: number of random 1-D projections.
        labels_per_ds: optional per-dataset label tensors (length n each).
        label_aware: when True and labels are supplied, compute W1 per
            class shared by both datasets and average those, instead of
            pooling all samples.
        label_weighting: "support" weights each class by its (capped,
            truncated) sample count; any other value weights uniformly.
        label_max_per_class: cap on samples used per class per side.

    Returns:
        [D, D] symmetric distance matrix with a zero diagonal.
    """
    device = embs.device
    D, n, d = embs.shape
    dirs = _random_unit_directions(d, num_projections, device=device)  # [P, d]

    # Project every dataset once up front; each entry is [P, n].
    proj_all = [_project(embs[k], dirs) for k in range(D)]

    def _pooled_w1(Pa: torch.Tensor, Pb: torch.Tensor) -> float:
        # Sliced W1 over equal-width projection matrices: sort each row,
        # average the absolute differences. (The original duplicated this
        # computation verbatim in three places.)
        a_sorted, _ = torch.sort(Pa, dim=1)
        b_sorted, _ = torch.sort(Pb, dim=1)
        return torch.mean(torch.abs(a_sorted - b_sorted)).item()

    def _labelwise_w1(i: int, j: int, yi: torch.Tensor, yj: torch.Tensor) -> Optional[float]:
        # Weighted average of per-class W1 over classes present in BOTH
        # datasets; returns None when no class overlaps (caller falls back).
        class_map_j = {
            int(v): torch.nonzero(yj == v, as_tuple=True)[0]
            for v in torch.unique(yj).tolist()
        }
        ws: List[float] = []
        weights: List[float] = []
        for idxs_i in _label_indices(yi):
            if idxs_i.numel() == 0:
                continue
            labval = int(yi[idxs_i[0]].item())
            idxs_j = class_map_j.get(labval)
            if idxs_j is None or idxs_j.numel() == 0:
                continue
            # Cap per-class counts, then truncate both sides to a common
            # length m so the sorted-difference W1 is well defined.
            # (Equivalent to the original's cap-then-trim-to-min sequence.)
            m = min(idxs_i.numel(), idxs_j.numel(), label_max_per_class)
            if m == 0:
                continue
            Pa = proj_all[i][:, idxs_i[:m]]  # [P, m]
            Pb = proj_all[j][:, idxs_j[:m]]  # [P, m]
            ws.append(_pooled_w1(Pa, Pb))
            weights.append(float(m) if label_weighting == "support" else 1.0)
        if not ws:
            return None
        return sum(wt * val for wt, val in zip(weights, ws)) / sum(weights)

    W = torch.zeros(D, D, device=device)
    for i in range(D):
        for j in range(i, D):
            w: Optional[float] = None
            if label_aware and labels_per_ds is not None:
                yi = labels_per_ds[i] if i < len(labels_per_ds) else None
                yj = labels_per_ds[j] if j < len(labels_per_ds) else None
                if yi is not None and yj is not None and yi.numel() > 0 and yj.numel() > 0:
                    w = _labelwise_w1(i, j, yi, yj)
            if w is None:
                # Pooled (label-agnostic) distance: used when label-aware
                # mode is off, labels are missing/empty, or no class overlaps.
                w = _pooled_w1(proj_all[i], proj_all[j])
            W[i, j] = W[j, i] = w
    return W
|