File size: 2,139 Bytes
3b4941f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
"""the pivot model: perturbation encoder ψ_η + flow map X_θ, plus the
inverse-design jacobian used by reward guidance (Alg 4)."""
from __future__ import annotations

import numpy as np
import torch
import torch.nn as nn

from src.models.encoders import PerturbationEncoder, build_pert_tensors
from src.models.flow_map import FlowMap


class PIVOT(nn.Module):
    """PIVOT model: a perturbation encoder ψ_η feeding a flow map X_θ.

    The encoder (built with ``emb_dim=d_pert``) turns a perturbation
    description ``(g, o, mask, pid)`` into an embedding. The flow map
    (built with state width ``d_state`` and conditioning width ``d_pert``)
    carries a starting state ``c0`` to its endpoint given that embedding.
    """

    def __init__(self, d_state: int, n_genes: int, n_ops: int, n_perts: int,
                 d_pert: int = 64, hidden: int = 512, depth: int = 4,
                 rep_mode: str = "gene_op", gene_pathway=None, n_pathways: int = 0,
                 dropout: float = 0.0):
        super().__init__()
        # Perturbation encoder ψ_η; registered first so submodule order is
        # encoder -> flow.
        self.encoder = PerturbationEncoder(
            n_genes,
            n_ops,
            n_perts,
            emb_dim=d_pert,
            mode=rep_mode,
            gene_pathway=gene_pathway,
            n_pathways=n_pathways,
        )
        # Flow map X_θ conditioned on the d_pert-wide perturbation embedding.
        self.flow = FlowMap(d_state, d_pert, hidden=hidden, depth=depth, dropout=dropout)
        self.d_state = d_state
        self.d_pert = d_pert

    def encode(self, g, o, mask, pid):
        """Return the encoder output ψ_η(g, o, mask, pid)."""
        return self.encoder(g, o, mask, pid)

    def endpoint_from_e(self, c0, e):
        """Push ``c0`` through the flow map using a precomputed embedding ``e``."""
        return self.flow.endpoint(c0, e)

    def endpoint_from_pert(self, c0, g, o, mask, pid):
        """Encode the perturbation, then push ``c0`` through the flow map.

        When the encoder yields a single row but ``c0`` holds several
        starting states, that row is broadcast (via ``expand``, no copy)
        across the batch dimension of ``c0``.
        """
        emb = self.encode(g, o, mask, pid)
        single_row = emb.shape[0] == 1
        if single_row and c0.shape[0] > 1:
            emb = emb.expand(c0.shape[0], -1)
        return self.flow.endpoint(c0, emb)

    # --- inverse-design jacobian (Alg 4 building block) ---
    def endpoint_jacobian_e(self, c0_single, e_single):
        """Jacobian ∇_e X_{0,1}(c0, e) for one unbatched (c0, e) pair.

        A leading batch axis of size 1 is added to both inputs before the
        flow is evaluated and removed from its output again, so for a 1-D
        ``e_single`` the result has shape (d_state, d_pert).

        NOTE(review): if FlowMap applies dropout and this module is in
        training mode, the Jacobian will be stochastic; consider calling
        under ``eval()``. Confirm against FlowMap.
        """
        batched_c0 = c0_single.unsqueeze(0)
        e_leaf = e_single.detach().clone().requires_grad_(True)

        def endpoint_of(e_vec):
            # Re-batch the embedding, evaluate the flow, then drop the batch axis.
            return self.flow.endpoint(batched_c0, e_vec.unsqueeze(0)).squeeze(0)

        return torch.autograd.functional.jacobian(endpoint_of, e_leaf, create_graph=False)


@torch.no_grad()
def candidate_embeddings(model: PIVOT, data, labels, device) -> np.ndarray:
    """Encode every candidate perturbation label with ψ.

    ``build_pert_tensors`` converts ``labels`` (with ``data``) into the
    ``(g, o, mask, pid)`` tensors on ``device``. The model's encoder embeds
    them and the result is moved to the CPU and returned as a NumPy array,
    expected to be (n_cand, d_pert). Autograd tracking is disabled by the
    ``torch.no_grad`` decorator.
    """
    genes, ops, mask, pert_ids = build_pert_tensors(data, labels, device=device)
    embedded = model.encode(genes, ops, mask, pert_ids)
    return embedded.cpu().numpy()