# dataset-distancing-lab / distances_common.py
# Author: wi-lab — "Create distances_common.py" (commit cea1e88)
from typing import List, Optional
import numpy as np
import torch
import torch.nn.functional as F
# ---------- Centroid / Cosine ----------
def pairwise_centroid_distances(centroids: torch.Tensor) -> torch.Tensor:
    """All-pairs Euclidean (L2) distances between dataset centroids.

    centroids: [D, d] — one d-dimensional centroid per dataset.
    Returns a symmetric [D, D] distance matrix with a zero diagonal.
    """
    dist = torch.cdist(centroids, centroids, p=2)
    return dist
def pairwise_cosine_similarity_distances(centroids: torch.Tensor) -> torch.Tensor:
    """Cosine *distance* (1 - cosine similarity) between all centroid pairs.

    centroids: [D, d]. Rows are L2-normalized first, so the result is a
    [D, D] matrix in [0, 2] with a ~zero diagonal.
    """
    unit = F.normalize(centroids, dim=1)
    sim = unit @ unit.t()  # cosine similarity of the normalized rows
    return 1.0 - sim
# ---------- Sliced Wasserstein ----------
def _project(X: torch.Tensor, dirs: torch.Tensor) -> torch.Tensor:
# X: [n, d], dirs: [P, d] -> [P, n]
return (dirs @ X.t())
def _wasserstein_1d(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
# a,b: [n] sorted
# W1 for equal-size empirical distributions = mean|sorted(a)-sorted(b)|
return torch.mean(torch.abs(a - b))
def _random_unit_directions(d: int, P: int, device: torch.device) -> torch.Tensor:
g = torch.randn(P, d, device=device)
g = g / (g.norm(dim=1, keepdim=True) + 1e-12)
return g
def _label_indices(labels: torch.Tensor) -> List[torch.Tensor]:
uniq = torch.unique(labels)
return [torch.nonzero(labels == u, as_tuple=True)[0] for u in uniq]
def sliced_wasserstein_distance_matrix(
    embs: torch.Tensor,  # [D, n, d] — D datasets, n samples each, d-dim embeddings
    num_projections: int = 64,
    labels_per_ds: Optional[List[torch.Tensor]] = None,
    label_aware: bool = True,
    label_weighting: str = "uniform",
    label_max_per_class: int = int(1e10),
) -> torch.Tensor:
    """Pairwise sliced-Wasserstein (SW1) distances between D datasets.

    Samples of every dataset are projected onto `num_projections` shared
    random unit directions; for each pair of datasets the 1-D W1 distance
    (mean |difference| of sorted projections) is averaged over directions.

    When `label_aware` is True and labels are supplied, W1 is computed per
    class shared by both datasets and combined either uniformly
    (`label_weighting="uniform"`) or weighted by per-class support
    (`label_weighting="support"`). Classes present in only one dataset are
    skipped; each class is capped at `label_max_per_class` samples and both
    sides are truncated to a common length so the sorted-difference W1 is
    defined. If labels are missing/empty for a pair, or no class overlaps,
    the computation falls back to the pooled (non-label-aware) SW1.

    Returns a symmetric [D, D] tensor (zero diagonal) on `embs.device`.
    Directions are drawn with torch's global RNG, so results vary per call
    unless the caller seeds it.
    """

    def _sorted_w1(Pa: torch.Tensor, Pb: torch.Tensor) -> torch.Tensor:
        # W1 averaged over projections: sort each row of the [P, m] blocks,
        # then take the mean absolute difference.
        return torch.mean(
            torch.abs(torch.sort(Pa, dim=1)[0] - torch.sort(Pb, dim=1)[0])
        )

    def _class_index_map(labels: torch.Tensor) -> dict:
        # label value (int) -> 1-D index tensor of samples with that label;
        # insertion order follows torch.unique's ascending label order.
        return {
            int(v): torch.nonzero(labels == v, as_tuple=True)[0]
            for v in torch.unique(labels).tolist()
        }

    device = embs.device
    D, n, d = embs.shape

    # Shared random unit directions [P, d]; projections [P, n] precomputed
    # per dataset so each pair only sorts/indexes.
    dirs = torch.randn(num_projections, d, device=device)
    dirs = dirs / (dirs.norm(dim=1, keepdim=True) + 1e-12)
    proj_all = [dirs @ embs[i].t() for i in range(D)]

    W = torch.zeros(D, D, device=device)
    for i in range(D):
        for j in range(i, D):
            use_labels = (
                label_aware
                and labels_per_ds is not None
                and i < len(labels_per_ds)
                and j < len(labels_per_ds)
                and labels_per_ds[i] is not None
                and labels_per_ds[j] is not None
                and labels_per_ds[i].numel() > 0
                and labels_per_ds[j].numel() > 0
            )
            w = None
            if use_labels:
                map_i = _class_index_map(labels_per_ds[i])
                map_j = _class_index_map(labels_per_ds[j])
                ws: List[torch.Tensor] = []
                weights: List[float] = []
                for labval, idxs_i in map_i.items():
                    idxs_j = map_j.get(labval)
                    if idxs_j is None or idxs_j.numel() == 0:
                        continue  # class not shared by both datasets — skip
                    # Cap per-class count, then truncate both sides to the
                    # common length m (simple subsampling, as in the pooled path).
                    ni = min(idxs_i.numel(), label_max_per_class)
                    nj = min(idxs_j.numel(), label_max_per_class)
                    m = min(ni, nj)
                    if m == 0:
                        continue
                    Pi = proj_all[i][:, idxs_i[:m]]  # [P, m]
                    Pj = proj_all[j][:, idxs_j[:m]]  # [P, m]
                    ws.append(_sorted_w1(Pi, Pj))
                    weights.append(float(m) if label_weighting == "support" else 1.0)
                if ws:
                    wt = torch.tensor(weights, device=device)
                    w = float((torch.stack(ws) * wt).sum() / wt.sum())
            if w is None:
                # Pooled SW1 over all samples (non-label-aware or fallback).
                w = _sorted_w1(proj_all[i], proj_all[j]).item()
            W[i, j] = W[j, i] = w
    return W