PIVOT / src /evaluation /inference.py
bryan7264's picture
pivot: code + trained checkpoints (norman, replogle k562)
3b4941f verified
Raw
History Blame
5.27 kB
"""inference procedures, algorithms 2-6.
Alg 2 forward_predict predict endpoint, aggregate population
Alg 3 endpoint_ranking score = reward of predicted endpoint, return topk
Alg 4 reward_guidance jacobian pull-back ascent in embedding space
Alg 5 project_and_rerank project e* to admissible candidates, rerank k-nn
Alg 6 greedy_combinatorial sequentially add the best gene (mean-pool embed)
"""
from __future__ import annotations
import numpy as np
import torch
from src.models.encoders import build_pert_tensors
from src.models.pivot import PIVOT, candidate_embeddings
def encode_label(model: PIVOT, data, label: str, device) -> torch.Tensor:
    """embed a single perturbation label with the model's perturbation encoder.

    the label is packed into a batch of one by ``build_pert_tensors`` and the
    resulting four tensors are handed unchanged to ``model.encode``.
    returns the (1, m) embedding.
    """
    g_t, o_t, mask_t, pid_t = build_pert_tensors(data, [label], device=device)
    embedding = model.encode(g_t, o_t, mask_t, pid_t)
    return embedding  # (1, m)
def encode_gene_set(model: PIVOT, data, genes: list[str], device) -> torch.Tensor:
    """encode an arbitrary gene set via the dataset operation (mean-pool).

    the genes are joined with ``data.sep`` into one combinatorial label; an empty
    gene list falls back to ``data.control_label``. the label is then embedded
    with ``encode_label``.
    """
    if not genes:
        return encode_label(model, data, data.control_label, device)
    return encode_label(model, data, data.sep.join(genes), device)
@torch.no_grad()
def forward_predict(model: PIVOT, c0: torch.Tensor, e_u: torch.Tensor) -> torch.Tensor:
    """Alg 2. map control cells to their predicted perturbed endpoint.

    c0: (N, d) control embeddings. e_u: either a single (1, m) perturbation
    embedding, which is broadcast (via ``expand``, no copy) to every control cell
    when N > 1, or a per-cell (N, m) tensor used as-is.
    returns the (N, d) endpoint from ``model.endpoint_from_e``. runs without grad.
    """
    n_cells = c0.shape[0]
    needs_broadcast = e_u.shape[0] == 1 and n_cells > 1
    e_rows = e_u.expand(n_cells, -1) if needs_broadcast else e_u
    return model.endpoint_from_e(c0, e_rows)
@torch.no_grad()
def endpoint_ranking(model: PIVOT, data, candidate_labels, c0: torch.Tensor, reward,
                     topk: int | None = None, device="cpu", chunk: int = 256):
    """Alg 3. score each candidate by reward of its predicted endpoint population.

    batched: all candidates x control cells go through the flow map in one
    (chunked) forward pass, then per-row rewards are averaged per candidate. works
    for any per-row reward (point/classifier).

    args:
        candidate_labels: indexable sequence of perturbation labels.
        c0: (N, d) control cells shared by every candidate.
        reward: callable mapping (rows, d) endpoints to one score per row.
        topk: if given, only the best ``topk`` entries are returned.
        chunk: memory budget in units of 1024 (candidate x cell) rows per forward
            pass; at least one candidate is always processed per pass.

    returns:
        list of (label, score) sorted by score descending. ties keep the input
        order (stable sort), so downstream greedy selection is deterministic.
        an empty candidate list returns [] without touching the model.
    """
    # nothing to rank (e.g. project_and_rerank with k_nearest=0)
    if len(candidate_labels) == 0:
        return []
    E = torch.as_tensor(candidate_embeddings(model, data, candidate_labels, device), device=device)
    C, N = E.shape[0], c0.shape[0]
    scores = np.empty(C, dtype=np.float64)
    # chunk over candidates to bound memory (c*n rows per chunk)
    cands_per_chunk = max(1, chunk * 1024 // max(N, 1))
    for s in range(0, C, cands_per_chunk):
        e = E[s: s + cands_per_chunk]
        c = e.shape[0]
        # row k*N + j pairs candidate k with control cell j
        c0_rep = c0.unsqueeze(0).expand(c, N, -1).reshape(c * N, -1)
        e_rep = e.unsqueeze(1).expand(c, N, -1).reshape(c * N, -1)
        chat = model.endpoint_from_e(c0_rep, e_rep)
        # reshape (not view): tolerates non-contiguous reward outputs
        r = reward(chat).reshape(c, N).mean(1)
        # move to cpu before casting (float64 is unsupported on some backends,
        # e.g. mps); the cast makes .numpy() work for bfloat16 rewards
        scores[s: s + c] = r.detach().cpu().double().numpy()
    order = np.argsort(-scores, kind="stable")
    ranked = [(candidate_labels[i], float(scores[i])) for i in order]
    return ranked if topk is None else ranked[:topk]
def reward_guidance(model: PIVOT, c0: torch.Tensor, reward, e_init: torch.Tensor,
                    steps: int = 25, step_size: float = 0.5, eps: float = 1e-8,
                    normalize: bool = True) -> torch.Tensor:
    """Alg 4. gradient ascent on reward in embedding space via the flow-map jacobian.
    e_{l+1} = e_l + gamma * g_e/(norm(g_e)+eps), g_e from autograd.

    args:
        c0: control cells; ``e`` is expanded to ``c0.shape[0]`` rows each step.
        reward: callable mapping predicted endpoints to per-row scores; the mean
            over rows is maximised.
        e_init: starting embedding (typically (1, m)); it is not modified.
        steps: number of ascent steps (0 returns a detached copy of e_init).
        step_size: gamma in the update rule.
        eps: added to the gradient norm to avoid division by zero.
        normalize: if False, take raw gradient steps instead of unit-norm steps.

    returns:
        the optimised embedding, detached from any graph.

    gradients are explicitly enabled, so this also works when the caller is inside
    a ``torch.no_grad()`` block (as every other helper in this module is); only
    ``e`` receives a gradient, model parameter ``.grad`` fields are left untouched.
    """
    e = e_init.detach().clone()
    n_cells = c0.shape[0]
    with torch.enable_grad():
        for _ in range(steps):
            e.requires_grad_(True)
            chat = model.endpoint_from_e(c0, e.expand(n_cells, -1))
            r = reward(chat).mean()
            (g_e,) = torch.autograd.grad(r, e)
            with torch.no_grad():
                step = g_e / (g_e.norm() + eps) if normalize else g_e
                e = e + step_size * step
    return e.detach()
@torch.no_grad()
def project_and_rerank(model: PIVOT, data, candidate_labels, e_star: torch.Tensor,
                       c0: torch.Tensor, reward, k_nearest: int = 10, topk: int = 5,
                       device="cpu"):
    """Alg 5. project e* onto the admissible candidates, then rerank by endpoint reward.

    every candidate embedding is compared to ``e_star`` by euclidean distance
    (``torch.cdist``); the ``k_nearest`` closest labels are rescored with
    ``endpoint_ranking`` (mean reward of their predicted endpoints from ``c0``)
    and the best ``topk`` (label, score) pairs are returned.
    """
    cand_emb = torch.as_tensor(candidate_embeddings(model, data, candidate_labels, device), device=device)
    query = e_star.view(1, -1)
    dists = torch.cdist(query, cand_emb).squeeze(0)
    nearest_idx = torch.argsort(dists)[:k_nearest].cpu().numpy()
    shortlist = [candidate_labels[j] for j in nearest_idx]
    reranked = endpoint_ranking(model, data, shortlist, c0, reward, device=device)
    return reranked[:topk]
@torch.no_grad()
def greedy_combinatorial(model: PIVOT, data, gene_pool, c0: torch.Tensor, reward,
                         max_size: int = 2, min_gain: float = 1e-4, device="cpu"):
    """Alg 6. greedily add the gene that most improves the reward of the mean-pooled endpoint.

    each round scores every set ``picked + [gene]`` (joined with ``data.sep``) in
    one ``endpoint_ranking`` call and keeps the winner. stops when ``max_size``
    genes are chosen, the pool is exhausted, or the best improvement is below
    ``min_gain``. the first gene is always accepted because the running score
    starts at -inf.

    returns (chosen genes, final score, history), where history holds one
    (genes-so-far copy, score) entry per accepted step.
    """
    picked: list[str] = []
    current = -np.inf
    trace = []
    while len(picked) < max_size:
        candidates = [gene for gene in gene_pool if gene not in picked]
        if not candidates:
            break
        # score all candidate gene-sets {picked + gene} in one batched pass
        combo_labels = [data.sep.join([*picked, gene]) for gene in candidates]
        top_label, top_score = endpoint_ranking(model, data, combo_labels, c0, reward, device=device)[0]
        if top_score - current < min_gain:
            break
        picked.append(candidates[combo_labels.index(top_label)])
        current = top_score
        trace.append((list(picked), current))
    return picked, current, trace