"""inference procedures, algorithms 2-6.

Alg 2  forward_predict       predict endpoint, aggregate population
Alg 3  endpoint_ranking      score = reward of predicted endpoint, return topk
Alg 4  reward_guidance       jacobian pull-back ascent in embedding space
Alg 5  project_and_rerank    project e* to admissible candidates, rerank k-nn
Alg 6  greedy_combinatorial  sequentially add the best gene (mean-pool embed)
"""
from __future__ import annotations

import numpy as np
import torch

from src.models.encoders import build_pert_tensors
from src.models.pivot import PIVOT, candidate_embeddings


def encode_label(model: PIVOT, data, label: str, device) -> torch.Tensor:
    """encode a single perturbation label.

    builds the perturbation tensors for `[label]` with build_pert_tensors
    (on `device`) and passes them to model.encode.
    """
    g, o, mask, pid = build_pert_tensors(data, [label], device=device)
    return model.encode(g, o, mask, pid)  # (1, m)


def encode_gene_set(model: PIVOT, data, genes: list[str], device) -> torch.Tensor:
    """encode an arbitrary gene set via the dataset operation (mean-pool).

    the genes are joined with data.sep into one label and encoded with
    encode_label; an empty gene list falls back to data.control_label.
    """
    label = data.sep.join(genes) if genes else data.control_label
    return encode_label(model, data, label, device)


@torch.no_grad()
def forward_predict(model: PIVOT, c0: torch.Tensor, e_u: torch.Tensor) -> torch.Tensor:
    """Alg 2. c0: (N,d) control embeddings; e_u: (1,m) or (N,m). returns endpoint (N,d).

    a single-row e_u is expanded (no copy) to one row per control cell before
    being passed to model.endpoint_from_e. runs without gradient tracking.
    """
    if e_u.shape[0] == 1 and c0.shape[0] > 1:
        e_u = e_u.expand(c0.shape[0], -1)
    return model.endpoint_from_e(c0, e_u)


@torch.no_grad()
def endpoint_ranking(model: PIVOT, data, candidate_labels, c0: torch.Tensor, reward,
                     topk: int | None = None, device="cpu", chunk: int = 256):
    """Alg 3. score each candidate by reward of its predicted endpoint population.

    batched: all candidates x control cells go through the flow map in one
    (chunked) forward pass, then per-row rewards are averaged per candidate.
    works for any per-row reward (point/classifier).

    args:
        candidate_labels: indexable sequence of labels; embedded with
            candidate_embeddings.
        c0: control embeddings, one row per control cell.
        reward: callable mapping the stacked endpoints to one value per row.
        topk: keep only the first `topk` entries of the ranking (None = all).
        device: device the candidate embedding matrix is placed on.
        chunk: memory knob; roughly chunk*1024 (candidate, cell) rows are
            pushed through the model per forward pass.

    returns list of (label, score) desc. an empty candidate list yields [].
    """
    # nothing to rank: skip the embedding call entirely instead of handing an
    # empty label list to candidate_embeddings
    if len(candidate_labels) == 0:
        return []

    E = torch.as_tensor(candidate_embeddings(model, data, candidate_labels, device),
                        device=device)
    C, N = E.shape[0], c0.shape[0]
    scores = np.empty(C, dtype=np.float64)

    # chunk over candidates to bound memory (c*n rows per chunk)
    cands_per_chunk = max(1, chunk * 1024 // max(N, 1))
    for s in range(0, C, cands_per_chunk):
        e = E[s: s + cands_per_chunk]
        c = e.shape[0]
        # pair every candidate in the chunk with every control cell:
        # row (i*N + j) holds candidate i together with control cell j
        c0_rep = c0.unsqueeze(0).expand(c, N, -1).reshape(c * N, -1)
        e_rep = e.unsqueeze(1).expand(c, N, -1).reshape(c * N, -1)
        chat = model.endpoint_from_e(c0_rep, e_rep)
        # fold the per-row rewards back to (c, N) and average over the cells
        r = reward(chat).view(c, N).mean(1)
        scores[s: s + c] = r.detach().cpu().numpy()

    order = np.argsort(-scores)  # descending by score
    ranked = [(candidate_labels[i], float(scores[i])) for i in order]
    return ranked if topk is None else ranked[:topk]


def reward_guidance(model: PIVOT, c0: torch.Tensor, reward, e_init: torch.Tensor,
                    steps: int = 25, step_size: float = 0.5, eps: float = 1e-8,
                    normalize: bool = True) -> torch.Tensor:
    """Alg 4. gradient ascent on reward in embedding space via the flow-map jacobian.

    e_{l+1} = e_l + gamma * g_e/(norm(g_e)+eps), g_e from autograd.

    args:
        c0: control embeddings; e is expanded to one row per control cell.
        reward: callable on the predicted endpoints; its mean is ascended.
        e_init: starting embedding (not modified; a detached clone is used).
        steps: number of ascent steps.
        step_size: gamma in the update rule.
        eps: added to the gradient norm to avoid division by zero.
        normalize: if False the raw gradient is used as the step direction.

    returns the detached embedding after `steps` updates.
    """
    e = e_init.detach().clone()
    for _ in range(steps):
        e.requires_grad_(True)
        # force gradient tracking for the forward pass so the ascent also
        # works when the caller is inside a torch.no_grad() context
        with torch.enable_grad():
            chat = model.endpoint_from_e(c0, e.expand(c0.shape[0], -1))
            r = reward(chat).mean()
            (g_e,) = torch.autograd.grad(r, e)
        with torch.no_grad():
            step = g_e / (g_e.norm() + eps) if normalize else g_e
            e = e + step_size * step
        e = e.detach()
    return e


@torch.no_grad()
def project_and_rerank(model: PIVOT, data, candidate_labels, e_star: torch.Tensor,
                       c0: torch.Tensor, reward, k_nearest: int = 10, topk: int = 5,
                       device="cpu"):
    """Alg 5. project e* to the k nearest admissible candidates, rerank by endpoint reward.

    distances are computed with torch.cdist between e* and the candidate
    embeddings; the `k_nearest` closest labels are then scored with
    endpoint_ranking and the first `topk` (label, score) pairs are returned.
    """
    E = torch.as_tensor(candidate_embeddings(model, data, candidate_labels, device),
                        device=device)
    # flatten e* to one row and align its dtype/device with the candidate
    # matrix so cdist accepts the pair (e* may come from another device)
    query = e_star.reshape(1, -1).to(E)
    d = torch.cdist(query, E).squeeze(0)
    knn = torch.argsort(d)[:k_nearest].tolist()
    sub = [candidate_labels[i] for i in knn]
    ranked = endpoint_ranking(model, data, sub, c0, reward, device=device)
    return ranked[:topk]


@torch.no_grad()
def greedy_combinatorial(model: PIVOT, data, gene_pool, c0: torch.Tensor, reward,
                         max_size: int = 2, min_gain: float = 1e-4, device="cpu"):
    """Alg 6. greedily add the gene that most improves the reward of the mean-pooled endpoint.

    each round scores every set {chosen + g} for the genes g not yet chosen
    and accepts the best one if it improves the current score by at least
    `min_gain`; stops at `max_size` genes, when the pool is exhausted, when
    the gain is too small, or when the best score is nan.

    returns (chosen, best_score, history): the selected genes in order, the
    score of the final set (-inf if no gene was accepted), and a list of
    (gene list, score) snapshots, one per accepted gene.
    """
    chosen: list[str] = []
    best_score = -np.inf
    history = []
    while len(chosen) < max_size:
        remaining = [g for g in gene_pool if g not in chosen]
        if not remaining:
            break
        # score all candidate gene-sets {chosen + g} in one batched pass
        labels = [data.sep.join(chosen + [g]) for g in remaining]
        ranked = endpoint_ranking(model, data, labels, c0, reward, device=device)
        best_label, best_local = ranked[0]
        # a nan score compares False against min_gain, which would accept the
        # gene and poison best_score for every later round; stop instead
        if np.isnan(best_local) or best_local - best_score < min_gain:
            break
        chosen.append(remaining[labels.index(best_label)])
        best_score = best_local
        history.append((list(chosen), best_score))
    return chosen, best_score, history