wi-lab committed on
Commit
cea1e88
·
1 Parent(s): 3c27f51

Create distances_common.py

Browse files
Files changed (1) hide show
  1. distances_common.py +127 -0
distances_common.py ADDED
@@ -0,0 +1,127 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import List, Optional
2
+ import numpy as np
3
+ import torch
4
+ import torch.nn.functional as F
5
+
6
+ # ---------- Centroid / Cosine ----------
7
+
8
def pairwise_centroid_distances(centroids: torch.Tensor) -> torch.Tensor:
    """Return the [D, D] matrix of Euclidean distances between rows of `centroids` ([D, d])."""
    dist_matrix = torch.cdist(centroids, centroids, p=2)
    return dist_matrix
11
+
12
def pairwise_cosine_similarity_distances(centroids: torch.Tensor) -> torch.Tensor:
    """Return the [D, D] cosine-distance matrix (1 - cosine similarity) between rows of `centroids` ([D, d])."""
    unit_rows = F.normalize(centroids, dim=1)
    similarity = unit_rows @ unit_rows.t()  # cosine similarity in [-1, 1]
    return 1.0 - similarity                 # 0 = identical direction, 2 = opposite
16
+
17
+
18
+ # ---------- Sliced Wasserstein ----------
19
+
20
+ def _project(X: torch.Tensor, dirs: torch.Tensor) -> torch.Tensor:
21
+ # X: [n, d], dirs: [P, d] -> [P, n]
22
+ return (dirs @ X.t())
23
+
24
+ def _wasserstein_1d(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
25
+ # a,b: [n] sorted
26
+ # W1 for equal-size empirical distributions = mean|sorted(a)-sorted(b)|
27
+ return torch.mean(torch.abs(a - b))
28
+
29
+ def _random_unit_directions(d: int, P: int, device: torch.device) -> torch.Tensor:
30
+ g = torch.randn(P, d, device=device)
31
+ g = g / (g.norm(dim=1, keepdim=True) + 1e-12)
32
+ return g
33
+
34
+ def _label_indices(labels: torch.Tensor) -> List[torch.Tensor]:
35
+ uniq = torch.unique(labels)
36
+ return [torch.nonzero(labels == u, as_tuple=True)[0] for u in uniq]
37
+
38
def sliced_wasserstein_distance_matrix(
    embs: torch.Tensor,  # [D, n, d]: D datasets, n samples each, d-dim embeddings
    num_projections: int = 64,
    labels_per_ds: Optional[List[torch.Tensor]] = None,
    label_aware: bool = True,
    label_weighting: str = "uniform",
    label_max_per_class: int = int(1e10),
) -> torch.Tensor:
    """Pairwise sliced-Wasserstein (SW1) distance matrix between D embedding sets.

    Projects every dataset onto `num_projections` shared random unit directions and
    averages the 1-D W1 distances (sorted-sample mean-absolute-difference) over
    projections. When `label_aware` and `labels_per_ds` are given, W1 is computed
    per shared class and averaged ("uniform" weights, or "support" to weight by
    the per-class sample count actually compared); pairs with no usable labels or
    no class overlap fall back to the pooled (label-agnostic) distance.

    Args:
        embs: [D, n, d] stacked embeddings; all datasets share sample count n.
        num_projections: number of random projection directions P.
        labels_per_ds: optional per-dataset label tensors (length-n each).
        label_aware: enable the per-class averaging described above.
        label_weighting: "uniform" or "support".
        label_max_per_class: cap on samples compared per class (takes the first
            indices; NOTE(review): this truncation, and trimming both classes to
            a common length, subsample rather than resample — presumably an
            accepted approximation).

    Returns:
        [D, D] symmetric distance matrix (zero diagonal). Non-deterministic
        across calls unless the torch RNG is seeded (random directions).
    """
    device = embs.device
    D, n, d = embs.shape

    # Shared random unit directions [P, d]; 1e-12 guards a zero draw.
    dirs = torch.randn(num_projections, d, device=device)
    dirs = dirs / (dirs.norm(dim=1, keepdim=True) + 1e-12)

    # Precompute all projections once: proj_all[i] is [P, n].
    proj_all = [dirs @ embs[i].t() for i in range(D)]

    def _pooled_sw(Pa: torch.Tensor, Pb: torch.Tensor) -> torch.Tensor:
        # Sliced W1 for equal-size samples: sort along the sample axis, then
        # average |difference| over both samples and projections.
        a_sorted = torch.sort(Pa, dim=1).values
        b_sorted = torch.sort(Pb, dim=1).values
        return torch.mean(torch.abs(a_sorted - b_sorted))

    def _class_map(y: torch.Tensor) -> dict:
        # label value -> index tensor into that dataset's samples
        return {int(v): torch.nonzero(y == v, as_tuple=True)[0]
                for v in torch.unique(y).tolist()}

    def _label_aware_sw(i: int, j: int,
                        yi: torch.Tensor, yj: torch.Tensor) -> Optional[torch.Tensor]:
        # Weighted average of per-class SW1 over labels present in both datasets.
        # Returns None when no class overlaps so the caller can fall back.
        map_j = _class_map(yj)
        ws: List[torch.Tensor] = []
        weights: List[float] = []
        for labval, idxs_i in _class_map(yi).items():
            idxs_j = map_j.get(labval)
            if idxs_j is None or idxs_j.numel() == 0:
                continue
            # Cap per-class count, then trim both sides to a common length m so
            # the equal-size 1-D quantile matching is well-defined.
            m = min(idxs_i.numel(), idxs_j.numel(), label_max_per_class)
            if m == 0:
                continue
            ws.append(_pooled_sw(proj_all[i][:, idxs_i[:m]],
                                 proj_all[j][:, idxs_j[:m]]))
            weights.append(float(m) if label_weighting == "support" else 1.0)
        if not ws:
            return None
        w_t = torch.tensor(weights, device=device)  # built once, reused for sum
        return (torch.stack(ws) * w_t).sum() / w_t.sum()

    W = torch.zeros(D, D, device=device)
    for i in range(D):
        for j in range(i, D):
            w: Optional[torch.Tensor] = None
            if label_aware and labels_per_ds is not None:
                yi = labels_per_ds[i] if i < len(labels_per_ds) else None
                yj = labels_per_ds[j] if j < len(labels_per_ds) else None
                if (yi is not None and yj is not None
                        and yi.numel() > 0 and yj.numel() > 0):
                    w = _label_aware_sw(i, j, yi, yj)
            if w is None:
                # Pooled distance: label-agnostic mode, missing/empty labels,
                # or no overlapping classes.
                w = _pooled_sw(proj_all[i], proj_all[j])
            W[i, j] = W[j, i] = w

    return W