| """perturbation encoder ψ_η plus a builder that turns perturbation labels into |
| padded (gene_id, op_id, mask) tensors. |
| |
| ψ separates gene identity from operation and is permutation-invariant over the |
| gene set (mean pooling), so one encoder handles single-gene and combinatorial |
| perturbations: |
| |
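    e_u = (1/M) Σ_{j=1..M} ( W_g[g_j] + W_o[o_j] )

where M is the number of genes in perturbation u (control rows, M = 0, map to
the zero vector). This is the default gene_op form; the other rep modes change
the per-gene term or bypass pooling entirely.

rep modes: gene_op (default), op_only, gene_only, random_id, gene_pathway_op.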
| """ |
| from __future__ import annotations |
|
|
| import numpy as np |
| import torch |
| import torch.nn as nn |
|
|
# Perturbation representations accepted by PerturbationEncoder(mode=...).
REP_MODES = (
    "gene_op",  # W_g[gene] + W_o[op] per gene slot, mean-pooled (default)
    "op_only",  # W_o[op] per gene slot, mean-pooled
    "gene_only",  # W_g[gene] per gene slot, mean-pooled
    "random_id",  # one learned vector per perturbation label (W_id[pert_id])
    "gene_pathway_op",  # gene_op plus W_path[pathway of the gene]
)
|
|
|
|
| class PerturbationEncoder(nn.Module): |
| def __init__( |
| self, |
| n_genes: int, |
| n_ops: int, |
| n_perts: int, |
| emb_dim: int = 64, |
| mode: str = "gene_op", |
| gene_pathway: np.ndarray | None = None, |
| n_pathways: int = 0, |
| ): |
| super().__init__() |
| assert mode in REP_MODES, mode |
| self.mode = mode |
| self.emb_dim = emb_dim |
| self.W_g = nn.Embedding(n_genes, emb_dim) |
| self.W_o = nn.Embedding(n_ops, emb_dim) |
| nn.init.normal_(self.W_g.weight, std=0.02) |
| nn.init.normal_(self.W_o.weight, std=0.02) |
| if mode == "random_id": |
| self.W_id = nn.Embedding(n_perts + 1, emb_dim) |
| nn.init.normal_(self.W_id.weight, std=0.02) |
| if mode == "gene_pathway_op": |
| assert gene_pathway is not None and n_pathways > 0 |
| self.register_buffer("gene_pathway", torch.as_tensor(gene_pathway, dtype=torch.long)) |
| self.W_path = nn.Embedding(n_pathways, emb_dim) |
| nn.init.normal_(self.W_path.weight, std=0.02) |
|
|
| def forward(self, gene_ids: torch.Tensor, op_ids: torch.Tensor, mask: torch.Tensor, |
| pert_ids: torch.Tensor | None = None) -> torch.Tensor: |
| """gene_ids/op_ids/mask: (B, Lmax) padded; mask 1 for valid gene slots. |
| |
| returns e_u: (B, emb_dim). control rows (mask all zero) map to the zero vector. |
| """ |
| if self.mode == "random_id": |
| return self.W_id(pert_ids) |
|
|
| m = mask.unsqueeze(-1).float() |
| if self.mode == "op_only": |
| per_gene = self.W_o(op_ids) |
| elif self.mode == "gene_only": |
| per_gene = self.W_g(gene_ids) |
| elif self.mode == "gene_pathway_op": |
| path_ids = self.gene_pathway[gene_ids.clamp(min=0)] |
| per_gene = self.W_g(gene_ids) + self.W_o(op_ids) + self.W_path(path_ids) |
| else: |
| per_gene = self.W_g(gene_ids) + self.W_o(op_ids) |
| per_gene = per_gene * m |
| denom = m.sum(dim=1).clamp(min=1.0) |
| return per_gene.sum(dim=1) / denom |
|
|
|
|
def build_pert_tensors(data, labels, device="cpu"):
    """Convert perturbation labels into padded (gene_ids, op_ids, mask, pert_ids).

    Args:
        data: dataset-like object exposing ``perturbations`` (sequence of known
            labels) and ``pert_gene_op_ids(label) -> (gene_ids, op_ids)``.
        labels: iterable of perturbation labels; one output row per label.
        device: device on which the returned tensors are created.

    Returns:
        A 4-tuple of tensors:
          gene_ids, op_ids -- int64, shape (B, L), zero-padded;
          mask             -- float32, shape (B, L), 1.0 on valid gene slots;
          pert_ids         -- int64, shape (B,), the label's index in
                              ``data.perturbations``, or
                              ``len(data.perturbations)`` (control/unknown
                              sentinel, used by random_id mode) if absent.
        L is the longest gene list among ``labels`` and at least 1, so empty
        or all-control batches still get a valid (B, 1) layout.
    """
    labels = list(labels)  # allow any iterable; it is traversed twice below
    pert_index = {p: i for i, p in enumerate(data.perturbations)}
    sentinel = len(data.perturbations)

    # Resolve each label exactly once and size the padding from the ids that
    # are actually written, so a row can never be wider than the buffer
    # (a width derived from a different helper could disagree and fail to broadcast).
    resolved = [data.pert_gene_op_ids(p) for p in labels]
    width = max(1, max((len(gids) for gids, _ in resolved), default=0))

    n_rows = len(labels)
    gene_ids = np.zeros((n_rows, width), dtype=np.int64)
    op_ids = np.zeros((n_rows, width), dtype=np.int64)
    mask = np.zeros((n_rows, width), dtype=np.float32)
    pert_ids = np.full(n_rows, sentinel, dtype=np.int64)

    for i, (label, (gids, oids)) in enumerate(zip(labels, resolved)):
        n = len(gids)
        if n:
            gene_ids[i, :n] = gids
            op_ids[i, :n] = oids
            mask[i, :n] = 1.0
        pert_ids[i] = pert_index.get(label, sentinel)

    return tuple(
        torch.as_tensor(arr, device=device)
        for arr in (gene_ids, op_ids, mask, pert_ids)
    )
|
|