| """inference procedures, algorithms 2-6. |
| |
| Alg 2 forward_predict predict endpoint, aggregate population |
| Alg 3 endpoint_ranking score = reward of predicted endpoint, return topk |
| Alg 4 reward_guidance jacobian pull-back ascent in embedding space |
| Alg 5 project_and_rerank project e* to admissible candidates, rerank k-nn |
| Alg 6 greedy_combinatorial sequentially add the best gene (mean-pool embed) |
| """ |
| from __future__ import annotations |
|
|
| import numpy as np |
| import torch |
|
|
| from src.models.encoders import build_pert_tensors |
| from src.models.pivot import PIVOT, candidate_embeddings |
|
|
|
|
def encode_label(model: PIVOT, data, label: str, device) -> torch.Tensor:
    """encode a single perturbation label with the model's encoder.

    the label is wrapped in a one-element list, turned into the four input
    tensors by ``build_pert_tensors`` (created on ``device``), and the result
    of ``model.encode`` on those tensors is returned unchanged.
    """
    g_in, o_in, mask_in, pid_in = build_pert_tensors(data, [label], device=device)
    embedding = model.encode(g_in, o_in, mask_in, pid_in)
    return embedding
|
|
|
|
def encode_gene_set(model: PIVOT, data, genes: list[str], device) -> torch.Tensor:
    """encode an arbitrary gene set via the dataset operation (mean-pool).

    the genes are joined with ``data.sep`` into one label; an empty gene list
    falls back to ``data.control_label``. the label is then encoded with
    ``encode_label``.
    """
    if genes:
        label = data.sep.join(genes)
    else:
        # no genes given: use the dataset's control label instead
        label = data.control_label
    return encode_label(model, data, label, device)
|
|
|
|
@torch.no_grad()
def forward_predict(model: PIVOT, c0: torch.Tensor, e_u: torch.Tensor) -> torch.Tensor:
    """Alg 2. map control embeddings to their predicted endpoint.

    c0:  (N,d) control embeddings.
    e_u: (1,m) or (N,m) perturbation embedding; a single row is broadcast
         across all N control rows before the model call.
    returns whatever ``model.endpoint_from_e`` produces, documented as (N,d).
    runs with autograd disabled.
    """
    n_cells = c0.shape[0]
    single_embedding = e_u.shape[0] == 1
    if single_embedding and n_cells > 1:
        # expand is a view: the one embedding row is shared by every cell
        e_u = e_u.expand(n_cells, -1)
    return model.endpoint_from_e(c0, e_u)
|
|
|
|
@torch.no_grad()
def endpoint_ranking(model: PIVOT, data, candidate_labels, c0: torch.Tensor, reward,
                     topk: int | None = None, device="cpu", chunk: int = 256):
    """Alg 3. rank candidates by the mean reward of their predicted endpoints.

    every candidate embedding is paired with every control row of ``c0``; the
    pairs are pushed through ``model.endpoint_from_e`` in chunks, ``reward`` is
    applied per row, and the per-row rewards are averaged per candidate.
    ``reward`` must therefore return one value per input row.

    chunk: sizes the batches; each forward pass holds roughly ``chunk * 1024``
           (candidate, cell) rows, and always at least one candidate.
    topk:  if given, only the first ``topk`` entries are returned.
    returns a list of (label, score) sorted by score, highest first.
    """
    cand_emb = torch.as_tensor(
        candidate_embeddings(model, data, candidate_labels, device), device=device
    )
    n_cand = cand_emb.shape[0]
    n_cells = c0.shape[0]
    scores = np.empty(n_cand, dtype=np.float64)

    # number of candidates that fit into one forward pass (never below 1)
    batch = max(1, chunk * 1024 // max(n_cells, 1))
    start = 0
    while start < n_cand:
        block = cand_emb[start: start + batch]
        b = block.shape[0]
        rows = b * n_cells
        # row layout is candidate-major: all cells for candidate 0, then 1, ...
        cells = c0.unsqueeze(0).expand(b, n_cells, -1).reshape(rows, -1)
        embs = block.unsqueeze(1).expand(b, n_cells, -1).reshape(rows, -1)
        endpoints = model.endpoint_from_e(cells, embs)
        per_candidate = reward(endpoints).view(b, n_cells).mean(1)
        scores[start: start + b] = per_candidate.detach().cpu().numpy()
        start += batch

    # negate so that argsort's ascending order yields highest score first
    ranked = [(candidate_labels[i], float(scores[i])) for i in np.argsort(-scores)]
    if topk is None:
        return ranked
    return ranked[:topk]
|
|
|
|
def reward_guidance(model: PIVOT, c0: torch.Tensor, reward, e_init: torch.Tensor,
                    steps: int = 25, step_size: float = 0.5, eps: float = 1e-8,
                    normalize: bool = True) -> torch.Tensor:
    """Alg 4. gradient ascent on reward in embedding space via the flow-map jacobian.

    update rule: e_{l+1} = e_l + gamma * g_e/(norm(g_e)+eps), with g_e the
    autograd gradient of the mean per-row reward w.r.t. the embedding.

    model:     object exposing ``endpoint_from_e(c0, e)``.
    c0:        control embeddings; its first dimension N sets how many rows
               the embedding is broadcast to.
    reward:    callable applied to the predicted endpoints; its output is
               averaged with ``.mean()`` to get the scalar objective.
    e_init:    starting embedding, broadcast over the rows of ``c0`` with
               ``expand`` (e.g. shape (1,m)). it is not modified.
    steps:     number of ascent steps; 0 returns a detached copy of ``e_init``.
    step_size: gamma in the update rule.
    eps:       added to the gradient norm to avoid division by zero.
    normalize: if False, the raw gradient is used instead of the unit direction.

    returns the optimised embedding, detached from the autograd graph.

    the forward pass and gradient run under ``torch.enable_grad()`` so the
    function also works when the caller is inside a ``torch.no_grad()`` block;
    otherwise no graph would be recorded and ``autograd.grad`` would raise.
    """
    e = e_init.detach().clone()
    n_cells = c0.shape[0]
    for _ in range(steps):
        with torch.enable_grad():
            e.requires_grad_(True)
            chat = model.endpoint_from_e(c0, e.expand(n_cells, -1))
            r = reward(chat).mean()
            # autograd.grad returns the gradient directly and leaves the
            # .grad fields of model parameters untouched
            (g_e,) = torch.autograd.grad(r, e)
        with torch.no_grad():
            step = g_e / (g_e.norm() + eps) if normalize else g_e
            e = e + step_size * step
        # start the next iteration from a fresh leaf tensor
        e = e.detach()
    return e
|
|
|
|
@torch.no_grad()
def project_and_rerank(model: PIVOT, data, candidate_labels, e_star: torch.Tensor,
                       c0: torch.Tensor, reward, k_nearest: int = 10, topk: int = 5,
                       device="cpu"):
    """Alg 5. snap e* onto admissible candidates, then rerank them by endpoint reward.

    the candidate embeddings are compared to ``e_star`` (flattened to a single
    row) with ``torch.cdist``; the ``k_nearest`` closest candidates are scored
    with ``endpoint_ranking`` and the best ``topk`` (label, score) pairs are
    returned, highest score first.
    """
    cand_emb = torch.as_tensor(
        candidate_embeddings(model, data, candidate_labels, device), device=device
    )
    query = e_star.view(1, -1)
    # cdist yields a (1, C) matrix; take its only row
    dist = torch.cdist(query, cand_emb)[0]
    nearest_idx = torch.argsort(dist)[:k_nearest].cpu().numpy()
    shortlist = [candidate_labels[i] for i in nearest_idx]
    return endpoint_ranking(model, data, shortlist, c0, reward, topk=topk, device=device)
|
|
|
|
@torch.no_grad()
def greedy_combinatorial(model: PIVOT, data, gene_pool, c0: torch.Tensor, reward,
                         max_size: int = 2, min_gain: float = 1e-4, device="cpu"):
    """Alg 6. grow a gene set greedily, one gene per round.

    each round builds one label per unused gene (current set + that gene,
    joined with ``data.sep``), scores all of them with ``endpoint_ranking`` and
    keeps the top one. the search stops when ``max_size`` genes are chosen, the
    pool is exhausted, or the best score improves on the previous round by
    less than ``min_gain``. the running score starts at -inf, so the first
    round always accepts its best gene.

    returns (chosen genes, final score, history), where history holds a
    (gene-set copy, score) tuple for every accepted round.
    """
    selected: list[str] = []
    score_so_far = -np.inf
    trace = []
    while len(selected) < max_size:
        pairs = [(data.sep.join(selected + [g]), g) for g in gene_pool if g not in selected]
        if not pairs:
            break

        labels = [label for label, _ in pairs]
        # first occurrence wins, so duplicate labels resolve to the earliest gene
        gene_for_label = {}
        for label, gene in pairs:
            gene_for_label.setdefault(label, gene)

        top_label, top_score = endpoint_ranking(
            model, data, labels, c0, reward, device=device
        )[0]
        if top_score - score_so_far < min_gain:
            break
        selected.append(gene_for_label[top_label])
        score_so_far = top_score
        # store a copy: `selected` keeps growing in later rounds
        trace.append((list(selected), score_so_far))
    return selected, score_so_far, trace
|
|