| """reward definitions for desired-state nomination. |
| |
| all rewards take predicted endpoints chat (B, d) and return a per-row reward |
| (B,) that is differentiable in chat (so the flow-map jacobian can pull the |
| gradient back into perturbation-embedding space, Alg 4). higher = better. |
| |
| definitions: |
| target r = -dist(endpoint, c*)^2 (point target) |
| centroid r = -dist(endpoint, centroid)^2 (target-distribution centroid) |
| nn_target r = -min_j dist(endpoint, c*_j)^2 (nearest target sample) |
| mmd r = -mmd^2(pred batch, target) (distributional, batch-level) |
| wasserstein r = -sinkhorn(pred batch, target) (distributional, batch-level) |
| classifier r = log p(y* | endpoint) (target-state classifier) |
| combined r = alpha_dist*r_target + alpha_clf*r_classifier |
| cosine r = cos(endpoint - ref, c* - ref), ref = control reference if given, else batch mean |
| """ |
| from __future__ import annotations |
|
|
| import numpy as np |
| import torch |
| import torch.nn as nn |
|
|
|
|
class TargetStateClassifier(nn.Module):
    """Small MLP scoring whether an embedding belongs to the target state.

    Trained per target: positives = target-perturbation cells, negatives =
    control + random other-perturbation cells. Used for r_clf and the
    target-clf metric.

    Architecture: ``Linear(d, hidden) -> ReLU -> Linear(hidden, hidden) -> ReLU
    -> Linear(hidden, 1)``. The output is a raw logit (no sigmoid applied).
    """

    def __init__(self, d: int, hidden: int = 128):
        """Build the network.

        Args:
            d: input embedding dimension.
            hidden: width of both hidden layers.
        """
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(d, hidden), nn.ReLU(), nn.Linear(hidden, hidden), nn.ReLU(),
            nn.Linear(hidden, 1),
        )

    def forward(self, x):
        """Return one logit per row of ``x``; the trailing size-1 dim is squeezed."""
        return self.net(x).squeeze(-1)

    def fit(self, pos: np.ndarray, neg: np.ndarray, device, epochs: int = 200, lr: float = 1e-3):
        """Train with full-batch Adam on binary cross-entropy with logits.

        Args:
            pos: positive examples (label 1), one row per example.
            neg: negative examples (label 0), one row per example.
            device: device on which training runs; the module is moved there.
            epochs: number of full-batch gradient steps.
            lr: Adam learning rate.

        Returns:
            ``self``, left in eval mode.
        """
        # The data below is created on `device`; the parameters must live on the
        # same device or the forward pass fails with a device mismatch.
        if device is not None:
            self.to(device)

        features = torch.as_tensor(np.concatenate([pos, neg]),
                                   dtype=torch.float32, device=device)
        labels = torch.as_tensor(np.concatenate([np.ones(len(pos)), np.zeros(len(neg))]),
                                 dtype=torch.float32, device=device)

        opt = torch.optim.Adam(self.parameters(), lr=lr)
        lossf = nn.BCEWithLogitsLoss()

        self.train()
        for _ in range(epochs):
            opt.zero_grad()
            loss = lossf(self(features), labels)
            loss.backward()
            opt.step()
        self.eval()
        return self
|
|
|
|
def _rbf_mmd2(x, y, gamma=None):
    """Biased RBF-kernel MMD^2 between point sets x (n, d) and y (m, d).

    When ``gamma`` is None the bandwidth comes from a median heuristic:
    ``gamma = 1 / median(squared cross distances)``, computed without gradient
    tracking on at most the first 200 rows of each set, with the median
    clamped to be at least 1e-6.

    Returns a scalar tensor ``mean k(x, x) + mean k(y, y) - 2 * mean k(x, y)``.
    """
    if gamma is None:
        # Bandwidth selection is treated as a constant: no gradient flows through it.
        with torch.no_grad():
            x_head, y_head = x[:200], y[:200]
            cross_sq = torch.cdist(x_head, y_head).pow(2)
            gamma = 1.0 / cross_sq.median().clamp(min=1e-6)

    def kernel_mean(a, b):
        # Mean of the RBF Gram matrix between a and b.
        sq_dist = torch.cdist(a, b).pow(2)
        return torch.exp(-gamma * sq_dist).mean()

    within_x = kernel_mean(x, x)
    within_y = kernel_mean(y, y)
    across = kernel_mean(x, y)
    return within_x + within_y - 2 * across
|
|
|
|
def _sinkhorn(x, y, eps=0.1, iters=50):
    """Entropic OT (Sinkhorn) transport cost between empirical x (n, d), y (m, d).

    Both point sets carry uniform weights. The ground cost is squared Euclidean
    distance. The scaling iterations run in the log domain: exponentiating
    ``-C / eps`` directly underflows to zero in float32 once squared distances
    exceed roughly 10 at the default ``eps``, which would zero out the plan and
    its gradient.

    Args:
        x: (n, d) source points.
        y: (m, d) target points.
        eps: entropic regularisation strength.
        iters: number of Sinkhorn iterations (each updates both scalings).

    Returns:
        Scalar tensor ``sum(P * C)`` for the resulting plan ``P``; the entropy
        term is not included. Differentiable with respect to ``x`` and ``y``.
    """
    C = torch.cdist(x, y).pow(2)
    n, m = C.shape

    # Uniform marginals in log space, matching the cost matrix dtype/device so
    # float64 inputs do not hit a dtype mismatch.
    log_a = torch.full((n,), 1.0 / n, device=C.device, dtype=C.dtype).log()
    log_b = torch.full((m,), 1.0 / m, device=C.device, dtype=C.dtype).log()

    log_K = -C / eps
    log_u = torch.zeros_like(log_a)
    log_v = torch.zeros_like(log_b)
    for _ in range(iters):
        # Log-domain form of v = b / (K^T u) followed by u = a / (K v).
        log_v = log_b - torch.logsumexp(log_K + log_u[:, None], dim=0)
        log_u = log_a - torch.logsumexp(log_K + log_v[None, :], dim=1)

    # P = diag(u) K diag(v), assembled by broadcasting instead of dense diag matmuls.
    P = torch.exp(log_u[:, None] + log_K + log_v[None, :])
    return (P * C).sum()
|
|
|
|
class Reward:
    """Configurable per-row reward over predicted endpoints (higher = better).

    ``kind`` selects the definition evaluated by ``__call__``:

    - ``"target"`` / ``"centroid"``: negative squared distance to ``target_c``.
    - ``"cosine"``: cosine similarity between ``chat - ref`` and
      ``target_c - ref``, where ``ref`` is ``control_ref`` if provided and the
      batch mean of ``chat`` otherwise.
    - ``"nn_target"``: negative squared distance to the nearest row of
      ``target_sample``.
    - ``"mmd"`` / ``"wasserstein"``: negative batch-level RBF MMD^2 / Sinkhorn
      cost against ``target_sample``, repeated for every row.
    - ``"classifier"``: log-sigmoid of the classifier logit.
    - ``"combined"``: ``alpha_dist * r_target + alpha_clf * r_classifier``.
    """

    def __init__(self, kind: str = "centroid", target_c=None, target_sample=None,
                 classifier: TargetStateClassifier | None = None,
                 alpha_dist: float = 1.0, alpha_clf: float = 1.0, device="cpu",
                 control_ref=None):
        """Store the configuration; array-like references become float32 tensors on ``device``."""

        def as_float_tensor(value):
            # Optional references stay None when not supplied.
            if value is None:
                return None
            return torch.as_tensor(value, dtype=torch.float32, device=device)

        self.kind = kind
        self.device = device
        self.alpha_dist = alpha_dist
        self.alpha_clf = alpha_clf
        self.classifier = classifier
        self.target_c = as_float_tensor(target_c)
        self.target_sample = as_float_tensor(target_sample)
        self.control_ref = as_float_tensor(control_ref)

    def _neg_sq_dist_to_target(self, chat):
        # Negative squared Euclidean distance from each row to target_c.
        diff = chat - self.target_c
        return -(diff ** 2).sum(-1)

    def _classifier_log_prob(self, chat):
        # log sigmoid(logit): log-probability of the target class.
        logits = self.classifier(chat)
        return torch.nn.functional.logsigmoid(logits)

    def _cosine(self, chat):
        # Directions are measured relative to the control reference, falling
        # back to the batch mean when none was configured.
        if self.control_ref is not None:
            ref = self.control_ref
        else:
            ref = chat.mean(0, keepdim=True)
        pred_dir = chat - ref
        target_dir = (self.target_c - ref).view(1, -1)
        pred_dir = pred_dir / (pred_dir.norm(dim=-1, keepdim=True) + 1e-8)
        target_dir = target_dir / (target_dir.norm(dim=-1, keepdim=True) + 1e-8)
        return (pred_dir * target_dir).sum(-1)

    def __call__(self, chat: torch.Tensor) -> torch.Tensor:
        """Evaluate the configured reward on ``chat``.

        Raises:
            ValueError: if ``kind`` is not one of the supported names.
        """
        k = self.kind
        if k == "target" or k == "centroid":
            return self._neg_sq_dist_to_target(chat)
        if k == "cosine":
            return self._cosine(chat)
        if k == "nn_target":
            sq_dists = torch.cdist(chat, self.target_sample).pow(2)
            nearest, _ = sq_dists.min(dim=1)
            return -nearest
        if k == "mmd":
            batch_mmd = _rbf_mmd2(chat, self.target_sample)
            return -batch_mmd.expand(chat.shape[0])
        if k == "wasserstein":
            batch_cost = _sinkhorn(chat, self.target_sample)
            return -batch_cost.expand(chat.shape[0])
        if k == "classifier":
            return self._classifier_log_prob(chat)
        if k == "combined":
            r_dist = self._neg_sq_dist_to_target(chat)
            r_clf = self._classifier_log_prob(chat)
            return self.alpha_dist * r_dist + self.alpha_clf * r_clf
        raise ValueError(k)
|
|