File size: 4,480 Bytes
7efee70
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
import yaml
import string
import secrets
import os

import torch
import wandb
from pytorch_lightning.callbacks import Callback, EarlyStopping, ModelCheckpoint
from torchdyn.core import NeuralODE

import torch

@torch.no_grad()
def gather_local_starts(x0s, X0_pool, N, k=64):
    """Sample a random local neighbourhood of pool points around each anchor.

    For every anchor row ``x0s[b]`` the ``k`` nearest rows of ``X0_pool``
    (Euclidean distance) are found, and ``N`` of them are drawn uniformly at
    random without replacement.

    Args:
        x0s: Anchor points, a 2-D tensor of shape ``(B, G)``.
        X0_pool: Candidate pool, shape ``(M0, G)``.
        N: Number of distinct neighbours to draw per anchor.
        k: Neighbourhood size; clipped to ``M0``.

    Returns:
        Tensor of shape ``(B, n, G)`` with ``n = min(N, k, M0)``. Fewer than
        ``N`` rows per anchor are returned when the neighbourhood is smaller
        than ``N``, because draws are without replacement.

    Raises:
        ValueError: if ``x0s`` is not 2-D.
    """
    if x0s.dim() != 2:
        raise ValueError(f"x0s must be 2-D (B, G), got shape {tuple(x0s.shape)}")
    k_eff = min(k, X0_pool.size(0))
    # Only the ranking matters, so plain (unsquared) distances suffice.
    dists = torch.cdist(x0s, X0_pool)                                   # (B, M0)
    knn_idx = dists.topk(k=k_eff, largest=False).indices                # (B, k_eff)
    # argsort of i.i.d. uniforms is an independent random permutation per row,
    # so this draws N distinct neighbours for every anchor in one shot.
    order = torch.rand(knn_idx.shape, device=knn_idx.device).argsort(dim=1)[:, :N]
    pick = knn_idx.gather(1, order)                                     # (B, n)
    return X0_pool[pick]                                                # (B, n, G)

def _plan_row_indices(probs_row, N, replace):
    """Draw N target indices from one (possibly unnormalised) coupling row."""
    probs = probs_row.clamp_min(0)
    probs = probs / probs.sum().clamp_min(1e-12)
    if replace:
        return torch.multinomial(probs, num_samples=N, replacement=True)
    # Without replacement, at most as many indices as there are entries with
    # positive mass can be drawn. NOTE(review): a row with no positive mass is
    # not handled specially; torch.multinomial is left to reject it.
    n_avail = min(N, int((probs > 0).sum().item()))
    j = torch.multinomial(probs, num_samples=n_avail, replacement=False)
    if n_avail < N:
        # Pad by repeating the last draw so the result still has N entries.
        j = torch.cat([j, j[-1:].expand(N - n_avail)], dim=0)
    return j


def _nearest_indices(points, pool):
    """Return, for each row of ``points``, the index of its nearest row in ``pool``."""
    return torch.cdist(points.to(pool), pool).argmin(dim=1)


def _sampler_row_targets(ot_sampler, x0_b, x1s, N, G, replace):
    """Draw N targets for one source row via ``ot_sampler.sample_plan``.

    Returns ``(x1_match, j)``: the ``(N, G)`` targets and their indices into
    ``x1s``, recovered by nearest-neighbour lookup since the sampler only
    returns points.
    """
    if hasattr(ot_sampler, "sample_plan"):
        try:
            # Try a single batched call with a pair-count keyword first.
            _, x1_match = ot_sampler.sample_plan(
                x0_b, x1s, replace=replace, n_pairs=N
            )
        except TypeError:
            pass  # no ``n_pairs`` support: fall through to one-pair-per-call
        else:
            x1_match = x1_match.reshape(N, G)
            return x1_match, _nearest_indices(x1_match, x1s)
    # Last resort: request one pair per call, N times.
    rows = []
    for _ in range(N):
        _, x1_one = ot_sampler.sample_plan(x0_b, x1s, replace=replace)
        rows.append(x1_one.reshape(1, G))
    x1_match = torch.cat(rows, dim=0)
    return x1_match, _nearest_indices(x1_match, x1s)


@torch.no_grad()
def make_aligned_clusters(ot_sampler, x0s, x1s, N, replace=True, k_local=128):
    """Build per-anchor source and target clusters of size N.

    For each anchor row ``b`` of ``x0s``:
      * ``x0_clusters[b]`` holds distinct nearest neighbours of ``x0s[b]``
        drawn from ``x0s`` itself (see ``gather_local_starts``);
      * ``x1_clusters[b]`` holds ``N`` rows of ``x1s`` drawn from row ``b``
        of the OT coupling, or through the sampler's ``sample_plan`` when
        no coupling is exposed;
      * ``idx1[b]`` holds the indices into ``x1s`` of those targets.

    The source and target clusters are sampled independently: both are
    associated with anchor ``b``, not paired element-wise.

    Args:
        ot_sampler: Object exposing ``coupling(x0s, x1s)`` or
            ``plan(x0s, x1s)`` (assumed to return ``(B, M)`` non-negative row
            weights, tensor or array — TODO confirm against samplers in use),
            or else ``sample_plan``.
        x0s: Source anchors, shape ``(B, G)``.
        x1s: Target pool, shape ``(M, G)``.
        N: Cluster size.
        replace: Whether targets may repeat within a cluster.
        k_local: Neighbourhood size passed to ``gather_local_starts``.

    Returns:
        ``(x0_clusters, x1_clusters, idx1)``. Note that ``x0_clusters`` has
        fewer than ``N`` rows per anchor when ``N > min(k_local, B)``.
    """
    device, dtype = x0s.device, x0s.dtype
    B, G = x0s.shape

    x0_clusters = gather_local_starts(x0s, x0s, N, k=k_local).to(device=device, dtype=dtype)
    x1_clusters = torch.empty((B, N, G), device=device, dtype=dtype)
    idx1 = torch.empty((B, N), device=device, dtype=torch.long)

    # Prefer a full coupling computed once over per-row sampling.
    P = None
    if hasattr(ot_sampler, "coupling"):
        P = ot_sampler.coupling(x0s, x1s)
    elif hasattr(ot_sampler, "plan"):
        P = ot_sampler.plan(x0s, x1s)
    if P is not None:
        # Accept NumPy plans too; an existing tensor is passed through as-is.
        P = torch.as_tensor(P)

    for b in range(B):
        if P is not None:
            j = _plan_row_indices(P[b], N, replace)
            x1_match = x1s[j]
        else:
            x1_match, j = _sampler_row_targets(
                ot_sampler, x0s[b:b + 1], x1s, N, G, replace
            )
        x1_clusters[b] = x1_match
        idx1[b] = j

    return x0_clusters, x1_clusters, idx1


def load_config(path):
    """Load a YAML configuration file and return the parsed document.

    The file is decoded as UTF-8 explicitly so the result does not depend on
    the platform's locale encoding. Parsing uses ``yaml.safe_load``, which
    does not construct arbitrary Python objects; the return value is whatever
    it produces (typically a dict for a mapping document).
    """
    with open(path, "r", encoding="utf-8") as file:
        return yaml.safe_load(file)


def merge_config(args, config_updates):
    """Overwrite attributes of ``args`` with values from ``config_updates``.

    Every key must already exist as an attribute of ``args``. All keys are
    validated before anything is written, so a bad key leaves ``args``
    untouched. ``None`` or an empty mapping is a no-op.

    Args:
        args: Attribute container to update in place (e.g. an argparse
            Namespace).
        config_updates: Mapping of attribute name to new value, or ``None``.

    Returns:
        The same ``args`` object.

    Raises:
        ValueError: if a key is not an existing attribute of ``args``.
    """
    if not config_updates:
        return args
    unknown = [key for key in config_updates if not hasattr(args, key)]
    if unknown:
        raise ValueError(
            f"Unknown configuration parameter '{unknown[0]}' found in the config file."
        )
    for key, value in config_updates.items():
        setattr(args, key, value)
    return args


def generate_group_string(length=16):
    """Return a random token of ``length`` ASCII letters and digits.

    Each character is chosen independently with the cryptographically
    secure ``secrets`` generator.
    """
    charset = string.ascii_letters + string.digits
    picks = [secrets.choice(charset) for _ in range(length)]
    return "".join(picks)