"""perturbation encoder ψ_η plus a builder that turns perturbation labels into padded (gene_id, op_id, mask) tensors. ψ separates gene identity from operation and is permutation-invariant over the gene set (mean pooling), so one encoder handles single-gene and combinatorial perturbations: e_u = (1/M) Σ_j ( W_g[g_j] + W_o[o_j] ) rep modes: gene_op (default), op_only, gene_only, random_id, gene_pathway_op. """ from __future__ import annotations import numpy as np import torch import torch.nn as nn REP_MODES = ("gene_op", "op_only", "gene_only", "random_id", "gene_pathway_op") class PerturbationEncoder(nn.Module): def __init__( self, n_genes: int, n_ops: int, n_perts: int, emb_dim: int = 64, mode: str = "gene_op", gene_pathway: np.ndarray | None = None, # (n_genes,) cluster id, for gene_pathway_op mode n_pathways: int = 0, ): super().__init__() assert mode in REP_MODES, mode self.mode = mode self.emb_dim = emb_dim self.W_g = nn.Embedding(n_genes, emb_dim) self.W_o = nn.Embedding(n_ops, emb_dim) nn.init.normal_(self.W_g.weight, std=0.02) nn.init.normal_(self.W_o.weight, std=0.02) if mode == "random_id": self.W_id = nn.Embedding(n_perts + 1, emb_dim) # +1 for control/unknown nn.init.normal_(self.W_id.weight, std=0.02) if mode == "gene_pathway_op": assert gene_pathway is not None and n_pathways > 0 self.register_buffer("gene_pathway", torch.as_tensor(gene_pathway, dtype=torch.long)) self.W_path = nn.Embedding(n_pathways, emb_dim) nn.init.normal_(self.W_path.weight, std=0.02) def forward(self, gene_ids: torch.Tensor, op_ids: torch.Tensor, mask: torch.Tensor, pert_ids: torch.Tensor | None = None) -> torch.Tensor: """gene_ids/op_ids/mask: (B, Lmax) padded; mask 1 for valid gene slots. returns e_u: (B, emb_dim). control rows (mask all zero) map to the zero vector. """ if self.mode == "random_id": return self.W_id(pert_ids) m = mask.unsqueeze(-1).float() # (b, l, 1) if self.mode == "op_only": per_gene = self.W_o(op_ids) elif self.mode == "gene_only": per_gene = self.W_g(gene_ids) elif self.mode == "gene_pathway_op": path_ids = self.gene_pathway[gene_ids.clamp(min=0)] per_gene = self.W_g(gene_ids) + self.W_o(op_ids) + self.W_path(path_ids) else: # gene_op per_gene = self.W_g(gene_ids) + self.W_o(op_ids) per_gene = per_gene * m denom = m.sum(dim=1).clamp(min=1.0) # (b,1); control -> 1 to avoid /0 (sum is 0 anyway) return per_gene.sum(dim=1) / denom def build_pert_tensors(data, labels, device="cpu"): """convert a list of perturbation labels into padded (gene_ids, op_ids, mask, pert_ids). pert_ids index into ``data.perturbations`` (len = control/unknown sentinel) for random_id mode. """ pert_index = {p: i for i, p in enumerate(data.perturbations)} L = max((len(data.parse(p)) for p in labels), default=1) L = max(L, 1) B = len(labels) g = np.zeros((B, L), dtype=np.int64) o = np.zeros((B, L), dtype=np.int64) mask = np.zeros((B, L), dtype=np.float32) pid = np.full(B, len(data.perturbations), dtype=np.int64) # sentinel for i, p in enumerate(labels): gids, oids = data.pert_gene_op_ids(p) n = len(gids) if n: g[i, :n] = gids o[i, :n] = oids mask[i, :n] = 1.0 if p in pert_index: pid[i] = pert_index[p] t = lambda a: torch.as_tensor(a, device=device) return t(g), t(o), t(mask), t(pid)