PIVOT / src /models /encoders.py
bryan7264's picture
pivot: code + trained checkpoints (norman, replogle k562)
3b4941f verified
Raw
History Blame
3.79 kB
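"""Perturbation encoder ψ_η, plus a builder that turns perturbation labels into
padded (gene_id, op_id, mask, pert_id) tensors.

ψ separates gene identity from the operation applied to it and is
permutation-invariant over the gene set (mean pooling), so a single encoder
handles both single-gene and combinatorial perturbations. For a perturbation u
that targets M genes g_1..g_M with operations o_1..o_M:

    e_u = (1/M) Σ_j ( W_g[g_j] + W_o[o_j] )

Controls (M = 0) map to the zero vector. The formula above is the default
gene_op mode. op_only and gene_only drop one term, and gene_pathway_op adds a
pathway embedding of each gene. random_id ignores the gene structure and looks
up one embedding per perturbation label.

Representation (rep) modes: gene_op (default), op_only, gene_only, random_id,
gene_pathway_op.
"""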
from __future__ import annotations
import numpy as np
import torch
import torch.nn as nn
# Representation modes accepted by ``PerturbationEncoder(mode=...)``.
REP_MODES = ("gene_op", "op_only", "gene_only", "random_id", "gene_pathway_op")


class PerturbationEncoder(nn.Module):
    """Perturbation encoder ψ_η: mean-pools per-gene embeddings over valid slots.

    Per-gene vector by ``mode``:
      * ``gene_op`` (default): ``W_g[g] + W_o[o]``
      * ``op_only``: ``W_o[o]``
      * ``gene_only``: ``W_g[g]``
      * ``gene_pathway_op``: ``W_g[g] + W_o[o] + W_path[gene_pathway[g]]``
      * ``random_id``: no pooling; one learned row per perturbation id
        (``W_id[pert_id]``).
    """

    def __init__(
        self,
        n_genes: int,
        n_ops: int,
        n_perts: int,
        emb_dim: int = 64,
        mode: str = "gene_op",
        gene_pathway: np.ndarray | None = None,  # (n_genes,) cluster id, for gene_pathway_op mode
        n_pathways: int = 0,
    ):
        """Allocate the embedding tables used by ``mode``.

        Args:
            n_genes: number of rows in the gene-identity table ``W_g``.
            n_ops: number of rows in the operation table ``W_o``.
            n_perts: number of known perturbation ids. ``random_id`` mode
                allocates ``n_perts + 1`` rows; the extra last row is the
                control/unknown slot.
            emb_dim: embedding width of every table (and of the output).
            mode: one of ``REP_MODES``.
            gene_pathway: pathway/cluster id per gene id; required for
                ``gene_pathway_op``.
            n_pathways: number of rows in the pathway table ``W_path``; must
                be positive for ``gene_pathway_op``.

        Raises:
            ValueError: if ``mode`` is not in ``REP_MODES``, or if
                ``gene_pathway_op`` is requested without ``gene_pathway`` or
                with ``n_pathways <= 0``.
        """
        super().__init__()
        # Explicit raises instead of ``assert`` so validation survives ``python -O``.
        if mode not in REP_MODES:
            raise ValueError(f"unknown rep mode {mode!r}; expected one of {REP_MODES}")
        if mode == "gene_pathway_op" and (gene_pathway is None or n_pathways <= 0):
            raise ValueError("gene_pathway_op mode requires gene_pathway and n_pathways > 0")
        self.mode = mode
        self.emb_dim = emb_dim
        # W_g / W_o are allocated in every mode; creation order is kept fixed so
        # seeded initialisation is reproducible.
        self.W_g = nn.Embedding(n_genes, emb_dim)
        self.W_o = nn.Embedding(n_ops, emb_dim)
        nn.init.normal_(self.W_g.weight, std=0.02)
        nn.init.normal_(self.W_o.weight, std=0.02)
        if mode == "random_id":
            self.W_id = nn.Embedding(n_perts + 1, emb_dim)  # +1 for control/unknown
            nn.init.normal_(self.W_id.weight, std=0.02)
        if mode == "gene_pathway_op":
            # A buffer rather than a parameter: it is not trained, but it follows
            # the module across devices and into state_dict.
            self.register_buffer("gene_pathway", torch.as_tensor(gene_pathway, dtype=torch.long))
            self.W_path = nn.Embedding(n_pathways, emb_dim)
            nn.init.normal_(self.W_path.weight, std=0.02)

    def forward(self, gene_ids: torch.Tensor, op_ids: torch.Tensor, mask: torch.Tensor,
                pert_ids: torch.Tensor | None = None) -> torch.Tensor:
        """Encode a batch of perturbations.

        Args:
            gene_ids: ``(B, Lmax)`` padded gene ids.
            op_ids: ``(B, Lmax)`` padded operation ids.
            mask: ``(B, Lmax)``; 1 for valid gene slots, 0 for padding.
            pert_ids: perturbation ids, ``(B,)`` as produced by
                ``build_pert_tensors``; only read (and then required) in
                ``random_id`` mode.

        Returns:
            ``e_u`` of shape ``(B, emb_dim)``. In the pooled (non-``random_id``)
            modes, control rows (mask all zero) map to the zero vector.

        Raises:
            ValueError: in ``random_id`` mode when ``pert_ids`` is None.
        """
        if self.mode == "random_id":
            if pert_ids is None:
                raise ValueError("random_id mode requires pert_ids")
            return self.W_id(pert_ids)
        m = mask.unsqueeze(-1).float()  # (B, Lmax, 1)
        if self.mode == "op_only":
            per_gene = self.W_o(op_ids)
        elif self.mode == "gene_only":
            per_gene = self.W_g(gene_ids)
        elif self.mode == "gene_pathway_op":
            # NOTE(review): the clamp only guards this buffer lookup; W_g (an
            # nn.Embedding) would still reject negative ids.
            path_ids = self.gene_pathway[gene_ids.clamp(min=0)]
            per_gene = self.W_g(gene_ids) + self.W_o(op_ids) + self.W_path(path_ids)
        else:  # gene_op
            per_gene = self.W_g(gene_ids) + self.W_o(op_ids)
        per_gene = per_gene * m  # zero out padded slots
        # Clamp so all-padding (control) rows compute 0 / 1 rather than 0 / 0.
        denom = m.sum(dim=1).clamp(min=1.0)  # (B, 1)
        return per_gene.sum(dim=1) / denom
def build_pert_tensors(data, labels, device="cpu"):
    """Convert perturbation labels into padded encoder inputs.

    Args:
        data: object exposing ``perturbations`` (sequence of known labels) and
            ``pert_gene_op_ids(label) -> (gene_ids, op_ids)``.
        labels: perturbation labels to encode, one output row each.
        device: torch device for the returned tensors.

    Returns:
        ``(gene_ids, op_ids, mask, pert_ids)``:
          * ``gene_ids`` / ``op_ids``: int64 ``(B, L)``, zero-padded.
          * ``mask``: float32 ``(B, L)``, 1.0 on valid gene slots.
          * ``pert_ids``: int64 ``(B,)`` indices into ``data.perturbations``.
            Labels not found there get ``len(data.perturbations)``, the
            control/unknown sentinel used by ``random_id`` mode.
        ``L`` is the longest resolved gene list (at least 1). Rows with no
        genes (e.g. controls) therefore have an all-zero mask.

    Raises:
        ValueError: if a label yields gene-id and op-id lists of different lengths.
    """
    sentinel = len(data.perturbations)
    pert_index = {p: i for i, p in enumerate(data.perturbations)}
    # Resolve ids once per label and size the padding from the ids themselves,
    # so the width always matches what is written below (sizing from a separate
    # parse could undercount and overflow the row slices).
    resolved = [data.pert_gene_op_ids(p) for p in labels]
    width = max(max((len(gids) for gids, _ in resolved), default=0), 1)
    B = len(labels)
    g = np.zeros((B, width), dtype=np.int64)
    o = np.zeros((B, width), dtype=np.int64)
    mask = np.zeros((B, width), dtype=np.float32)
    pid = np.full(B, sentinel, dtype=np.int64)
    for i, (p, (gids, oids)) in enumerate(zip(labels, resolved)):
        n = len(gids)
        # Guard against numpy silently broadcasting a shorter op list.
        if len(oids) != n:
            raise ValueError(
                f"perturbation {p!r}: {n} gene ids but {len(oids)} op ids"
            )
        if n:
            g[i, :n] = gids
            o[i, :n] = oids
            mask[i, :n] = 1.0
        pid[i] = pert_index.get(p, sentinel)
    return tuple(torch.as_tensor(a, device=device) for a in (g, o, mask, pid))