# PIVOT — src/models/pivot.py
# commit 3b4941f (bryan7264): "pivot: code + trained checkpoints (norman, replogle k562)"
"""the pivot model: perturbation encoder ψ_η + flow map X_θ, plus the
inverse-design jacobian used by reward guidance (Alg 4)."""
from __future__ import annotations
import numpy as np
import torch
import torch.nn as nn
from src.models.encoders import PerturbationEncoder, build_pert_tensors
from src.models.flow_map import FlowMap
class PIVOT(nn.Module):
    """Perturbation encoder ψ_η composed with a flow map X_θ.

    ``encode`` maps the perturbation tensors (``g``, ``o``, ``mask``, ``pid``,
    as produced by ``build_pert_tensors``) to embeddings ``e`` of width
    ``d_pert``; ``flow.endpoint`` maps start states ``c0`` (width ``d_state``)
    to their endpoint conditioned on ``e``.
    """

    def __init__(self, d_state: int, n_genes: int, n_ops: int, n_perts: int,
                 d_pert: int = 64, hidden: int = 512, depth: int = 4,
                 rep_mode: str = "gene_op", gene_pathway=None, n_pathways: int = 0,
                 dropout: float = 0.0):
        super().__init__()
        self.encoder = PerturbationEncoder(
            n_genes, n_ops, n_perts, emb_dim=d_pert, mode=rep_mode,
            gene_pathway=gene_pathway, n_pathways=n_pathways,
        )
        self.flow = FlowMap(d_state, d_pert, hidden=hidden, depth=depth, dropout=dropout)
        self.d_state, self.d_pert = d_state, d_pert

    @staticmethod
    def _broadcast_e(c0, e):
        """Expand a single-row embedding ``e`` to the batch size of ``c0``.

        Only acts when ``e`` has one row and ``c0`` has more than one; any
        other shapes are returned unchanged for ``flow.endpoint`` to handle.
        """
        if e.shape[0] == 1 and c0.shape[0] > 1:
            # expand() is a view: no copy of the embedding is made.
            e = e.expand(c0.shape[0], -1)
        return e

    def encode(self, g, o, mask, pid):
        """Embed perturbations with the encoder ψ_η."""
        return self.encoder(g, o, mask, pid)

    def endpoint_from_e(self, c0, e):
        """Endpoint X_{0,1}(c0, e) for precomputed embeddings ``e``.

        A single-row ``e`` is broadcast across a batch of ``c0``, the same
        way ``endpoint_from_pert`` treats a single encoded perturbation.
        """
        return self.flow.endpoint(c0, self._broadcast_e(c0, e))

    def endpoint_from_pert(self, c0, g, o, mask, pid):
        """Encode the perturbation, then return its endpoint X_{0,1}(c0, ψ(u)).

        A single encoded perturbation is broadcast across a batch of ``c0``.
        """
        e = self.encode(g, o, mask, pid)
        return self.flow.endpoint(c0, self._broadcast_e(c0, e))

    # --- inverse-design jacobian (Alg 4 building block) ---
    def endpoint_jacobian_e(self, c0_single, e_single):
        """∇_e X_{0,1}(c0, e) for a single (c0, e). returns (d_state, d_pert).

        ``c0_single`` and ``e_single`` are unbatched vectors; each is lifted
        to a batch of one for ``flow.endpoint`` and the output row is squeezed
        back out.  The flow runs in the module's current train/eval mode —
        NOTE(review): if FlowMap's dropout is active in training mode, call
        ``.eval()`` first for a deterministic Jacobian.
        """
        c0 = c0_single.unsqueeze(0)

        def f(ev):
            return self.flow.endpoint(c0, ev.unsqueeze(0)).squeeze(0)

        # jacobian() makes its own detached, grad-enabled copy of the input,
        # so no manual clone()/requires_grad_() is needed here.
        return torch.autograd.functional.jacobian(f, e_single.detach(), create_graph=False)
@torch.no_grad()
def candidate_embeddings(model: PIVOT, data, labels, device) -> np.ndarray:
    """Embed every candidate perturbation label with the model's encoder ψ.

    Builds the perturbation tensors for ``labels`` on ``device``, encodes them
    without autograd tracking, and copies the result to host memory.

    Returns:
        NumPy array of shape (n_cand, d_pert), one row per candidate label.
    """
    genes, ops, pert_mask, pert_ids = build_pert_tensors(data, labels, device=device)
    embeddings = model.encode(genes, ops, pert_mask, pert_ids)
    return embeddings.cpu().numpy()