# PIVOT / src/evaluation/rewards.py
# author: bryan7264 — "pivot: code + trained checkpoints (norman, replogle k562)"
# commit 3b4941f (verified) · 5.45 kB
"""reward definitions for desired-state nomination.
all rewards take predicted endpoints chat (B, d) and return a per-row reward
(B,) that is differentiable in chat (so the flow-map jacobian can pull the
gradient back into perturbation-embedding space, Alg 4). higher = better.
definitions:
target r = -dist(endpoint, c*)^2 (point target)
centroid r = -dist(endpoint, centroid)^2 (target-distribution centroid)
nn_target r = -min_j dist(endpoint, c*_j)^2 (nearest target sample)
mmd r = -mmd^2(pred batch, target) (distributional, batch-level)
wasserstein r = -sinkhorn(pred batch, target) (distributional, batch-level)
classifier r = log p(y* | endpoint) (target-state classifier)
combined r = alpha_dist*r_target + alpha_clf*r_classifier
"""
from __future__ import annotations
import numpy as np
import torch
import torch.nn as nn
class TargetStateClassifier(nn.Module):
    """small mlp classifying whether an embedding belongs to the target state.

    trained per target: positives = target-perturbation cells, negatives =
    control + random other-perturbation cells. used for r_clf and the target-clf
    metric. ``forward`` returns raw logits (no sigmoid applied).
    """

    def __init__(self, d: int, hidden: int = 128):
        """d: embedding dimension; hidden: width of both hidden layers."""
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(d, hidden), nn.ReLU(),
            nn.Linear(hidden, hidden), nn.ReLU(),
            nn.Linear(hidden, 1),  # one logit per embedding
        )

    def forward(self, x):
        """map embeddings (..., d) to one logit each, shape (...)."""
        return self.net(x).squeeze(-1)

    def fit(self, pos: np.ndarray, neg: np.ndarray, device, epochs: int = 200,
            lr: float = 1e-3):
        """train in place with full-batch adam on binary cross-entropy.

        pos, neg: (n_pos, d) / (n_neg, d) embeddings, labelled 1 / 0.
        device: device for both the training tensors and the model parameters.
        epochs: number of full-batch gradient steps; lr: adam learning rate.
        returns self, left in eval mode, so construction and fitting can chain.
        """
        # the training tensors are created on `device`; move the parameters there
        # too, otherwise a model built on cpu and fit with a gpu device fails with
        # a device-mismatch error on the first forward pass
        self.to(device)
        X = torch.as_tensor(np.concatenate([pos, neg]), dtype=torch.float32, device=device)
        y = torch.as_tensor(np.concatenate([np.ones(len(pos)), np.zeros(len(neg))]),
                            dtype=torch.float32, device=device)
        opt = torch.optim.Adam(self.parameters(), lr=lr)
        # the sigmoid is folded into the loss, so forward can stay logit-valued
        lossf = nn.BCEWithLogitsLoss()
        self.train()
        for _ in range(epochs):
            opt.zero_grad()
            loss = lossf(self(X), y)
            loss.backward()
            opt.step()
        self.eval()
        return self
def _rbf_mmd2(x, y, gamma=None):
    """Biased squared MMD between point sets x (n, d) and y (m, d), RBF kernel.

    The kernel is exp(-gamma * ||a - b||^2). Kernel means include the diagonal
    terms, which is why the estimate is biased. When gamma is not supplied it is
    set to 1 / median squared cross-distance between the first (up to) 200 rows
    of x and of y. The median is clamped to at least 1e-6 and is computed
    outside autograd, so the bandwidth carries no gradient.
    """
    if gamma is None:
        with torch.no_grad():
            x_head = x[: min(200, len(x))]
            y_head = y[: min(200, len(y))]
            median_sq = torch.cdist(x_head, y_head).pow(2).median().clamp(min=1e-6)
            gamma = 1.0 / median_sq

    def kernel_mean(p, q):
        sq_dist = torch.cdist(p, q).pow(2)
        return torch.exp(-gamma * sq_dist).mean()

    within_x = kernel_mean(x, x)
    within_y = kernel_mean(y, y)
    across = kernel_mean(x, y)
    return within_x + within_y - 2 * across
def _sinkhorn(x, y, eps=0.1, iters=50):
    """entropic ot (sinkhorn) cost between empirical x (n,d), y (m,d).

    both point sets get uniform weights and the ground cost is the squared
    euclidean distance C. the scaling iterations (update v, then u, starting
    from u = 1) run in the log domain with logsumexp. the exponentiated kernel
    exp(-C / eps) would underflow to all zeros once C / eps exceeds roughly 100
    in float32; in that regime a plan built from it is all zeros, so the cost is
    0 and carries no gradient. working with log-kernels keeps the plan
    well-defined at any cost scale.

    returns the scalar <P, C> for the final transport plan P, in x's dtype.
    """
    C = torch.cdist(x, y).pow(2)
    n, m = C.shape
    # uniform marginals, built in C's dtype/device so float64 inputs also work
    log_a = torch.full((n,), 1.0 / n, dtype=C.dtype, device=C.device).log()
    log_b = torch.full((m,), 1.0 / m, dtype=C.dtype, device=C.device).log()
    log_K = -C / eps
    log_u = torch.zeros_like(log_a)  # u = 1
    log_v = torch.zeros_like(log_b)  # defined even when iters == 0
    for _ in range(iters):
        # v = b / (K^T u)
        log_v = log_b - torch.logsumexp(log_K + log_u[:, None], dim=0)
        # u = a / (K v)
        log_u = log_a - torch.logsumexp(log_K + log_v[None, :], dim=1)
    # P = diag(u) K diag(v), formed by broadcasting instead of dense diagonal matmuls
    P = torch.exp(log_u[:, None] + log_K + log_v[None, :])
    return (P * C).sum()
class Reward:
    """configurable reward on predicted endpoints.

    ``Reward(kind, ...)(chat)`` maps chat (B, d) to a per-row reward (B,),
    higher = better, differentiable in chat. array-like inputs are converted to
    float32 tensors on ``device`` at construction.

    kinds and the inputs they read:
      "target", "centroid"  -||chat - c*||^2; c* = target_c, or the mean of
                            target_sample when target_c is None
      "cosine"              cos(chat - ref, c* - ref); ref = control_ref, or the
                            batch mean of chat when control_ref is None
      "nn_target"           -min_j ||chat - target_sample_j||^2
      "mmd"                 -biased rbf mmd^2(chat batch, target_sample), the
                            same value for every row
      "wasserstein"         -sinkhorn cost(chat batch, target_sample), the same
                            value for every row
      "classifier"          logsigmoid(classifier(chat))
      "combined"            alpha_dist * r_target + alpha_clf * r_classifier

    calling raises ValueError for an unknown kind, or when an input the kind
    needs was not supplied.
    """

    def __init__(self, kind: str = "centroid", target_c=None, target_sample=None,
                 classifier: TargetStateClassifier | None = None,
                 alpha_dist: float = 1.0, alpha_clf: float = 1.0, device="cpu",
                 control_ref=None):
        self.kind = kind
        self.device = device
        self.alpha_dist = alpha_dist
        self.alpha_clf = alpha_clf
        self.classifier = classifier
        self.target_c = self._as_float_tensor(target_c, device)
        self.target_sample = self._as_float_tensor(target_sample, device)
        # control reference (mean control embedding) for direction-aware rewards
        self.control_ref = self._as_float_tensor(control_ref, device)

    @staticmethod
    def _as_float_tensor(value, device):
        """float32 tensor of ``value`` on ``device``; None passes through."""
        if value is None:
            return None
        return torch.as_tensor(value, dtype=torch.float32, device=device)

    def _target_point(self) -> torch.Tensor:
        """c*: target_c if set, otherwise the centroid (row mean) of target_sample."""
        if self.target_c is not None:
            return self.target_c
        if self.target_sample is not None:
            return self.target_sample.mean(0)
        raise ValueError(f"reward kind {self.kind!r} needs target_c or target_sample")

    def _require_sample(self) -> torch.Tensor:
        """target_sample, with a clear error when it was not supplied."""
        if self.target_sample is None:
            raise ValueError(f"reward kind {self.kind!r} needs target_sample")
        return self.target_sample

    def _dist_reward(self, chat: torch.Tensor) -> torch.Tensor:
        """negative squared euclidean distance of each row of chat to c*."""
        return -((chat - self._target_point()) ** 2).sum(-1)

    def _clf_reward(self, chat: torch.Tensor) -> torch.Tensor:
        """log sigmoid of the classifier logit, i.e. log p(y* | endpoint)."""
        if self.classifier is None:
            raise ValueError(f"reward kind {self.kind!r} needs a classifier")
        return torch.nn.functional.logsigmoid(self.classifier(chat))

    def __call__(self, chat: torch.Tensor) -> torch.Tensor:
        k = self.kind
        if k in ("target", "centroid"):
            return self._dist_reward(chat)
        if k == "cosine":
            # direction-aware: cosine between predicted effect and target effect,
            # both measured from the same reference point
            cref = self.control_ref if self.control_ref is not None else chat.mean(0, keepdim=True)
            pe = chat - cref
            te = (self._target_point() - cref).reshape(1, -1)
            # the 1e-8 keeps the normalisation finite for a zero effect vector
            pe = pe / (pe.norm(dim=-1, keepdim=True) + 1e-8)
            te = te / (te.norm(dim=-1, keepdim=True) + 1e-8)
            return (pe * te).sum(-1)
        if k == "nn_target":
            d2 = torch.cdist(chat, self._require_sample()).pow(2)
            return -d2.min(dim=1).values
        if k == "mmd":
            # batch-level scalar, broadcast so every row receives the same reward
            return -_rbf_mmd2(chat, self._require_sample()).expand(chat.shape[0])
        if k == "wasserstein":
            return -_sinkhorn(chat, self._require_sample()).expand(chat.shape[0])
        if k == "classifier":
            return self._clf_reward(chat)
        if k == "combined":
            return self.alpha_dist * self._dist_reward(chat) + self.alpha_clf * self._clf_reward(chat)
        raise ValueError(k)