"""Reward definitions for desired-state nomination.

All rewards take predicted endpoints ``chat`` (B, d) and return a per-row
reward (B,) that is differentiable in ``chat`` (so the flow-map Jacobian can
pull the gradient back into perturbation-embedding space, Alg 4).
Higher = better.

Definitions:
    target       r = -dist(endpoint, c*)^2            (point target)
    centroid     r = -dist(endpoint, centroid)^2      (target-distribution centroid)
    cosine       r = cos(endpoint - ref, c* - ref)    (direction-aware)
    nn_target    r = -min_j dist(endpoint, c*_j)^2    (nearest target sample)
    mmd          r = -mmd^2(pred batch, target)       (distributional, batch-level)
    wasserstein  r = -sinkhorn(pred batch, target)    (distributional, batch-level)
    classifier   r = log p(y* | endpoint)             (target-state classifier)
    combined     r = alpha_dist*r_target + alpha_clf*r_classifier
"""
from __future__ import annotations

import numpy as np
import torch
import torch.nn as nn


class TargetStateClassifier(nn.Module):
    """Small MLP classifying whether an embedding belongs to the target state.

    Trained per target: positives = target-perturbation cells, negatives =
    control + random other-perturbation cells. Used for r_clf and the
    target-clf metric.

    Args:
        d: input embedding dimension.
        hidden: width of the two hidden layers.
    """

    def __init__(self, d: int, hidden: int = 128):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(d, hidden),
            nn.ReLU(),
            nn.Linear(hidden, hidden),
            nn.ReLU(),
            nn.Linear(hidden, 1),
        )

    def forward(self, x):
        """Return one raw logit per row of ``x`` (trailing unit dim squeezed)."""
        return self.net(x).squeeze(-1)

    def fit(self, pos: np.ndarray, neg: np.ndarray, device,
            epochs: int = 200, lr: float = 1e-3):
        """Train with full-batch Adam on a BCE-with-logits loss.

        Args:
            pos: positive examples, labelled 1.
            neg: negative examples, labelled 0.
            device: device the training data and this module are placed on.
            epochs: number of full-batch gradient steps.
            lr: Adam learning rate.

        Returns:
            ``self``, left in eval mode.
        """
        # The data is created on `device`, so the parameters have to live
        # there as well; without this, fitting on a non-default device fails
        # with a device mismatch unless the caller moved the model beforehand.
        self.to(device)
        features = torch.as_tensor(
            np.concatenate([pos, neg]), dtype=torch.float32, device=device
        )
        labels = torch.as_tensor(
            np.concatenate([np.ones(len(pos)), np.zeros(len(neg))]),
            dtype=torch.float32,
            device=device,
        )
        optimizer = torch.optim.Adam(self.parameters(), lr=lr)
        criterion = nn.BCEWithLogitsLoss()
        self.train()
        for _ in range(epochs):
            optimizer.zero_grad()
            loss = criterion(self(features), labels)
            loss.backward()
            optimizer.step()
        self.eval()
        return self


def _rbf_mmd2(x, y, gamma=None):
    """Biased RBF MMD^2 between point sets x (n, d) and y (m, d).

    When ``gamma`` is None the bandwidth comes from the median heuristic:
    ``gamma = 1 / median(squared distance)`` over the first (at most) 200 rows
    of each set, computed without tracking gradients.
    """
    if gamma is None:
        with torch.no_grad():
            sq_dists = torch.cdist(x[:200], y[:200]).pow(2)
            # Clamp so identical point sets do not produce an infinite gamma.
            gamma = 1.0 / sq_dists.median().clamp(min=1e-6)

    def kernel(a, b):
        return torch.exp(-gamma * torch.cdist(a, b).pow(2))

    return kernel(x, x).mean() + kernel(y, y).mean() - 2 * kernel(x, y).mean()


def _sinkhorn(x, y, eps=0.1, iters=50):
    """Entropic OT (Sinkhorn) cost between empirical x (n, d) and y (m, d).

    Uses the squared Euclidean ground cost and uniform marginals, and returns
    the transport cost ``sum(P * C)`` of the plan reached after ``iters``
    alternating scaling updates.

    The updates are carried out in the log domain. Forming ``exp(-C / eps)``
    directly underflows to exactly zero as soon as the squared distances are
    moderately large relative to ``eps``; the plan then collapses to zero and
    the cost (and its gradient) silently becomes 0, i.e. far-apart point
    clouds would look perfectly matched.
    """
    cost = torch.cdist(x, y).pow(2)
    n, m = cost.shape
    # Build the marginals in the cost's dtype/device so non-float32 inputs
    # do not hit a dtype mismatch.
    log_a = torch.full((n,), 1.0 / n, dtype=cost.dtype, device=cost.device).log()
    log_b = torch.full((m,), 1.0 / m, dtype=cost.dtype, device=cost.device).log()
    log_kernel = -cost / eps

    # log_u / log_v are the logs of the usual Sinkhorn scaling vectors;
    # starting log_u at zero corresponds to u = 1.
    log_u = torch.zeros_like(log_a)
    log_v = torch.zeros_like(log_b)
    for _ in range(iters):
        log_v = log_b - torch.logsumexp(log_kernel + log_u[:, None], dim=0)
        log_u = log_a - torch.logsumexp(log_kernel + log_v[None, :], dim=1)

    # Broadcasting replaces diag(u) @ K @ diag(v) without building dense
    # n x n and m x m diagonal matrices.
    plan = torch.exp(log_kernel + log_u[:, None] + log_v[None, :])
    return (plan * cost).sum()


class Reward:
    """Configurable reward; calling it maps endpoints (B, d) to rewards (B,).

    Args:
        kind: one of "target", "centroid", "cosine", "nn_target", "mmd",
            "wasserstein", "classifier", "combined".
        target_c: target point / centroid; read by "target", "centroid",
            "cosine" and "combined".
        target_sample: target point set; read by "nn_target", "mmd" and
            "wasserstein".
        classifier: target-state classifier; read by "classifier" and
            "combined".
        alpha_dist: weight of the distance term in "combined".
        alpha_clf: weight of the classifier term in "combined".
        device: device the reference tensors are placed on.
        control_ref: control reference (mean control embedding) for the
            direction-aware "cosine" reward; when omitted, the batch mean of
            the predicted endpoints is used instead.
    """

    def __init__(self, kind: str = "centroid", target_c=None, target_sample=None,
                 classifier: TargetStateClassifier | None = None,
                 alpha_dist: float = 1.0, alpha_clf: float = 1.0,
                 device="cpu", control_ref=None):

        def to_tensor(value):
            # Optional inputs stay None; everything else becomes float32 on `device`.
            if value is None:
                return None
            return torch.as_tensor(value, dtype=torch.float32, device=device)

        self.kind = kind
        self.device = device
        self.alpha_dist = alpha_dist
        self.alpha_clf = alpha_clf
        self.classifier = classifier
        self.target_c = to_tensor(target_c)
        self.target_sample = to_tensor(target_sample)
        # control reference (mean control embedding) for direction-aware rewards
        self.control_ref = to_tensor(control_ref)

    def _neg_sq_dist_to_target(self, chat: torch.Tensor) -> torch.Tensor:
        """Negative squared Euclidean distance of each row to ``target_c``."""
        return -((chat - self.target_c) ** 2).sum(-1)

    def __call__(self, chat: torch.Tensor) -> torch.Tensor:
        """Evaluate the configured reward on predicted endpoints ``chat``.

        Batch-level rewards ("mmd", "wasserstein") produce one scalar that is
        broadcast to every row.

        Raises:
            ValueError: if ``kind`` is not a known reward.
        """
        kind = self.kind
        if kind in ("target", "centroid"):
            return self._neg_sq_dist_to_target(chat)
        if kind == "cosine":
            # direction-aware: cosine between predicted effect and target effect
            if self.control_ref is not None:
                reference = self.control_ref
            else:
                reference = chat.mean(0, keepdim=True)
            pred_effect = chat - reference
            target_effect = (self.target_c - reference).view(1, -1)
            # The small epsilon guards against division by a zero-length effect.
            pred_effect = pred_effect / (pred_effect.norm(dim=-1, keepdim=True) + 1e-8)
            target_effect = target_effect / (target_effect.norm(dim=-1, keepdim=True) + 1e-8)
            return (pred_effect * target_effect).sum(-1)
        if kind == "nn_target":
            sq_dists = torch.cdist(chat, self.target_sample).pow(2)
            return -sq_dists.min(dim=1).values
        if kind == "mmd":
            return -_rbf_mmd2(chat, self.target_sample).expand(chat.shape[0])
        if kind == "wasserstein":
            return -_sinkhorn(chat, self.target_sample).expand(chat.shape[0])
        if kind == "classifier":
            return torch.nn.functional.logsigmoid(self.classifier(chat))
        if kind == "combined":
            r_target = self._neg_sq_dist_to_target(chat)
            r_clf = torch.nn.functional.logsigmoid(self.classifier(chat))
            return self.alpha_dist * r_target + self.alpha_clf * r_clf
        raise ValueError(kind)