"""The PIVOT model: perturbation encoder ψ_η + flow map X_θ, plus the
inverse-design jacobian used by reward guidance (Alg 4)."""

from __future__ import annotations

import numpy as np
import torch
import torch.nn as nn

from src.models.encoders import PerturbationEncoder, build_pert_tensors
from src.models.flow_map import FlowMap


def _match_batch(e: torch.Tensor, batch: int) -> torch.Tensor:
    """Broadcast a single perturbation embedding across ``batch`` start states.

    When ``e`` has batch size 1 and ``batch > 1`` it is expanded (a view, no
    copy) to ``(batch, -1)``; otherwise ``e`` is returned unchanged so the
    flow sees exactly what the caller passed.
    """
    if e.shape[0] == 1 and batch > 1:
        return e.expand(batch, -1)
    return e


class PIVOT(nn.Module):
    """Perturbation encoder ψ_η composed with a flow map X_θ.

    ``encoder`` turns perturbation tensors into ``d_pert``-dim embeddings;
    ``flow`` maps a ``d_state``-dim start state, conditioned on such an
    embedding, to its endpoint X_{0,1}(c0, e).
    """

    def __init__(
        self,
        d_state: int,
        n_genes: int,
        n_ops: int,
        n_perts: int,
        d_pert: int = 64,
        hidden: int = 512,
        depth: int = 4,
        rep_mode: str = "gene_op",
        gene_pathway=None,
        n_pathways: int = 0,
        dropout: float = 0.0,
    ):
        super().__init__()
        # Perturbation encoder ψ_η; `rep_mode` is forwarded as its `mode`
        # and the optional gene->pathway info is passed through unchanged.
        self.encoder = PerturbationEncoder(
            n_genes,
            n_ops,
            n_perts,
            emb_dim=d_pert,
            mode=rep_mode,
            gene_pathway=gene_pathway,
            n_pathways=n_pathways,
        )
        # Flow map X_θ over d_state-dim states, conditioned on d_pert-dim
        # embeddings. Dropout lives here (not in the encoder).
        self.flow = FlowMap(d_state, d_pert, hidden=hidden, depth=depth, dropout=dropout)
        self.d_state, self.d_pert = d_state, d_pert

    def encode(self, g, o, mask, pid) -> torch.Tensor:
        """Return ψ_η(u) for the perturbation(s) described by the inputs.

        The inputs are the tensors produced by ``build_pert_tensors``; the
        result has one row per perturbation, i.e. ``(n, d_pert)``.
        """
        return self.encoder(g, o, mask, pid)

    def endpoint_from_e(self, c0: torch.Tensor, e: torch.Tensor) -> torch.Tensor:
        """Return X_{0,1}(c0, e) for precomputed embeddings ``e``.

        A single embedding (batch size 1) is broadcast across every row of
        ``c0``, exactly as in ``endpoint_from_pert``.
        """
        return self.flow.endpoint(c0, _match_batch(e, c0.shape[0]))

    def endpoint_from_pert(self, c0: torch.Tensor, g, o, mask, pid) -> torch.Tensor:
        """Encode the perturbation, then return X_{0,1}(c0, ψ_η(u))."""
        e = self.encode(g, o, mask, pid)
        # One perturbation applied to many start states: share its embedding.
        return self.flow.endpoint(c0, _match_batch(e, c0.shape[0]))

    # --- inverse-design jacobian (Alg 4 building block) ---
    def endpoint_jacobian_e(
        self, c0_single: torch.Tensor, e_single: torch.Tensor
    ) -> torch.Tensor:
        """∇_e X_{0,1}(c0, e) for a single (c0, e). returns (d_state, d_pert).

        Both inputs are unbatched; a batch dim of 1 is added before calling
        the flow and removed from its output.

        NOTE(review): ``flow`` is constructed with ``dropout``; if the module
        is in training mode with dropout > 0 the jacobian is presumably
        stochastic — call ``eval()`` first if a deterministic one is needed.
        """
        # Detached copy so the jacobian is taken w.r.t. e alone and nothing
        # flows back into whatever produced e_single.
        e = e_single.detach().clone().requires_grad_(True)
        c0 = c0_single.unsqueeze(0)

        def f(ev: torch.Tensor) -> torch.Tensor:
            return self.flow.endpoint(c0, ev.unsqueeze(0)).squeeze(0)

        return torch.autograd.functional.jacobian(f, e, create_graph=False)


@torch.no_grad()
def candidate_embeddings(model: PIVOT, data, labels, device) -> np.ndarray:
    """ψ(u) for every candidate perturbation label -> (n_cand, d_pert).

    Runs without autograd and returns a CPU numpy array. The model's
    train/eval mode is left as-is.
    """
    g, o, mask, pid = build_pert_tensors(data, labels, device=device)
    return model.encode(g, o, mask, pid).cpu().numpy()