| """the pivot model: perturbation encoder ψ_η + flow map X_θ, plus the |
| inverse-design jacobian used by reward guidance (Alg 4).""" |
| from __future__ import annotations |
|
|
| import numpy as np |
| import torch |
| import torch.nn as nn |
|
|
| from src.models.encoders import PerturbationEncoder, build_pert_tensors |
| from src.models.flow_map import FlowMap |
|
|
|
|
class PIVOT(nn.Module):
    """perturbation encoder + flow map bundled as a single module.

    ``encoder`` turns a perturbation description ``(g, o, mask, pid)`` into an
    embedding ``e``; ``flow`` maps a start state ``c0`` and ``e`` to an
    endpoint state via ``flow.endpoint``.
    """

    def __init__(self, d_state: int, n_genes: int, n_ops: int, n_perts: int,
                 d_pert: int = 64, hidden: int = 512, depth: int = 4,
                 rep_mode: str = "gene_op", gene_pathway=None, n_pathways: int = 0,
                 dropout: float = 0.0):
        """Build the two sub-modules.

        ``n_genes``, ``n_ops``, ``n_perts``, ``rep_mode``, ``gene_pathway`` and
        ``n_pathways`` are forwarded to ``PerturbationEncoder`` (``d_pert`` is
        passed as its ``emb_dim``); ``d_state``, ``d_pert``, ``hidden``,
        ``depth`` and ``dropout`` are forwarded to ``FlowMap``.
        """
        super().__init__()
        self.encoder = PerturbationEncoder(
            n_genes, n_ops, n_perts, emb_dim=d_pert, mode=rep_mode,
            gene_pathway=gene_pathway, n_pathways=n_pathways,
        )
        self.flow = FlowMap(d_state, d_pert, hidden=hidden, depth=depth, dropout=dropout)
        self.d_state, self.d_pert = d_state, d_pert

    def encode(self, g, o, mask, pid):
        """Return the encoder's embedding for the given perturbation tensors."""
        return self.encoder(g, o, mask, pid)

    def endpoint_from_e(self, c0, e):
        """Flow endpoint for start states ``c0`` under embedding ``e``.

        When ``e`` carries a single row (leading dim 1) and ``c0`` has more
        than one row, ``e`` is expanded across the ``c0`` batch, so one
        perturbation can be applied to many start states. This is the same
        broadcast ``endpoint_from_pert`` has always performed; doing it here
        keeps both entry points consistent.
        """
        if e.shape[0] == 1 and c0.shape[0] > 1:
            # expand() returns a view, so no data is copied
            e = e.expand(c0.shape[0], -1)
        return self.flow.endpoint(c0, e)

    def endpoint_from_pert(self, c0, g, o, mask, pid):
        """Encode the perturbation, then return the flow endpoint for ``c0``.

        A single encoded perturbation is broadcast across the ``c0`` batch
        (see ``endpoint_from_e``).
        """
        return self.endpoint_from_e(c0, self.encode(g, o, mask, pid))

    def endpoint_jacobian_e(self, c0_single, e_single):
        """∇_e X_{0,1}(c0, e) for a single (c0, e). returns (d_state, d_pert).

        Both inputs are unbatched; a leading batch axis of size 1 is added
        before calling the flow map and removed from its output.
        """
        # fresh leaf copy: the caller's tensor (and any graph attached to it)
        # is left untouched
        e_leaf = e_single.detach().clone().requires_grad_(True)
        c0_batched = c0_single.unsqueeze(0)

        def _endpoint_of_e(ev):
            # function of the embedding alone; c0 is held fixed via closure
            return self.flow.endpoint(c0_batched, ev.unsqueeze(0)).squeeze(0)

        return torch.autograd.functional.jacobian(
            _endpoint_of_e, e_leaf, create_graph=False
        )
|
|
|
|
def candidate_embeddings(model: PIVOT, data, labels, device) -> np.ndarray:
    """Embed every candidate perturbation label with the model's encoder.

    The perturbation tensors are built from ``data`` and ``labels`` on
    ``device``, encoded with gradient tracking disabled, and returned as a
    NumPy array on the host -> (n_cand, d_pert).
    """
    with torch.no_grad():
        genes, ops, pad_mask, pert_ids = build_pert_tensors(data, labels, device=device)
        embeddings = model.encode(genes, ops, pad_mask, pert_ids)
        host_embeddings = embeddings.cpu()
        return host_embeddings.numpy()
|
|