File size: 5,265 Bytes
3b4941f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
"""inference procedures, algorithms 2-6.

  Alg 2  forward_predict       predict endpoint, aggregate population
  Alg 3  endpoint_ranking      score = reward of predicted endpoint, return topk
  Alg 4  reward_guidance       jacobian pull-back ascent in embedding space
  Alg 5  project_and_rerank    project e* to admissible candidates, rerank k-nn
  Alg 6  greedy_combinatorial  sequentially add the best gene (mean-pool embed)
"""
from __future__ import annotations

import numpy as np
import torch

from src.models.encoders import build_pert_tensors
from src.models.pivot import PIVOT, candidate_embeddings


def encode_label(model: PIVOT, data, label: str, device) -> torch.Tensor:
    """Embed a single perturbation label with the model's perturbation encoder.

    The label is turned into encoder inputs by ``build_pert_tensors`` as a
    one-element batch, so the returned embedding has shape ``(1, m)``.
    """
    encoder_inputs = build_pert_tensors(data, [label], device=device)
    t_g, t_o, t_mask, t_pid = encoder_inputs
    embedding = model.encode(t_g, t_o, t_mask, t_pid)
    return embedding  # (1, m)


def encode_gene_set(model: PIVOT, data, genes: list[str], device) -> torch.Tensor:
    """Embed an arbitrary gene set as a single combined perturbation label.

    The genes are joined with ``data.sep`` into one label, and an empty gene set
    falls back to ``data.control_label``. The label is then encoded through
    ``encode_label``, i.e. via the dataset's own combination operation
    (mean-pool, per the module docs).
    """
    if not genes:
        return encode_label(model, data, data.control_label, device)
    combined = data.sep.join(genes)
    return encode_label(model, data, combined, device)


@torch.no_grad()
def forward_predict(model: PIVOT, c0: torch.Tensor, e_u: torch.Tensor) -> torch.Tensor:
    """Alg 2: push control cells through the flow map under perturbation ``e_u``.

    Args:
        model: model exposing ``endpoint_from_e(c0, e)``.
        c0: control-cell embeddings, (N, d).
        e_u: perturbation embedding; either (1, m), shared by every cell, or
            (N, m), one row per cell.

    Returns:
        Predicted endpoint population, (N, d).
    """
    n_cells = c0.shape[0]
    # A single shared embedding is broadcast (as a view) to one row per cell.
    shared = e_u.shape[0] == 1 and n_cells > 1
    per_cell = e_u.expand(n_cells, -1) if shared else e_u
    return model.endpoint_from_e(c0, per_cell)


@torch.no_grad()
def endpoint_ranking(model: PIVOT, data, candidate_labels, c0: torch.Tensor, reward,
                     topk: int | None = None, device="cpu", chunk: int = 256):
    """Alg 3. score each candidate by the mean reward of its predicted endpoint population.

    batched: candidates x control cells go through the flow map in chunked forward
    passes, then per-row rewards are averaged per candidate. works for any per-row
    reward (point/classifier).

    Args:
        model: PIVOT model; ``endpoint_from_e(c, e)`` maps row-aligned control and
            perturbation embeddings to endpoints.
        data: dataset handle forwarded to ``candidate_embeddings``.
        candidate_labels: perturbation labels (any iterable; materialized once).
        c0: control embeddings, (N, d).
        reward: callable returning one reward per endpoint row (c * N values).
        topk: if given, keep only the best ``topk`` entries.
        device: device the candidate embeddings are placed on.
        chunk: memory knob; each forward pass covers roughly ``chunk * 1024`` rows
            (always at least one candidate per pass).

    Returns:
        list of (label, score) sorted by score descending. Ties keep input order;
        NaN scores sort last.
    """
    # materialize once: the labels are both embedded and indexed below
    labels = list(candidate_labels)
    E = torch.as_tensor(candidate_embeddings(model, data, labels, device), device=device)
    C, N = E.shape[0], c0.shape[0]
    scores = np.empty(C, dtype=np.float64)
    # chunk over candidates to bound memory (c*n rows per chunk)
    cands_per_chunk = max(1, chunk * 1024 // max(N, 1))
    for s in range(0, C, cands_per_chunk):
        e = E[s: s + cands_per_chunk]
        c = e.shape[0]
        c0_rep = c0.unsqueeze(0).expand(c, N, -1).reshape(c * N, -1)
        e_rep = e.unsqueeze(1).expand(c, N, -1).reshape(c * N, -1)
        chat = model.endpoint_from_e(c0_rep, e_rep)
        # reshape (not view): the reward may return a non-contiguous tensor.
        r = reward(chat).reshape(c, N).mean(1)
        # cast before .numpy(): numpy has no bfloat16 dtype.
        scores[s: s + c] = r.detach().to(torch.float64).cpu().numpy()
    # stable sort so equal scores keep candidate order (deterministic rankings)
    order = np.argsort(-scores, kind="stable")
    ranked = [(labels[i], float(scores[i])) for i in order]
    return ranked if topk is None else ranked[:topk]


def reward_guidance(model: PIVOT, c0: torch.Tensor, reward, e_init: torch.Tensor,
                    steps: int = 25, step_size: float = 0.5, eps: float = 1e-8,
                    normalize: bool = True) -> torch.Tensor:
    """Alg 4. gradient ascent on reward in embedding space via the flow-map jacobian.

    e_{l+1} = e_l + gamma * g_e/(norm(g_e)+eps), g_e from autograd.

    Args:
        model: exposes ``endpoint_from_e(c0, e)``; must be differentiable w.r.t. ``e``.
        c0: control embeddings, (N, d).
        reward: callable on the predicted endpoints; its mean is maximized.
        e_init: starting embedding. It is expanded to N rows, so (1, m) or (m,)
            is expected. It is not modified.
        steps: number of ascent steps.
        step_size: gamma, the step length.
        eps: added to the gradient norm to avoid division by zero.
        normalize: if True, take unit-norm steps; otherwise use the raw gradient.

    Returns:
        The optimized embedding, detached, with the same shape as ``e_init``.
    """
    e = e_init.detach().clone()
    # Force autograd on: callers often run inference under torch.no_grad(), like
    # the @no_grad helpers in this module. Without this, r has no grad_fn and
    # torch.autograd.grad raises.
    with torch.enable_grad():
        for _ in range(steps):
            e.requires_grad_(True)
            chat = model.endpoint_from_e(c0, e.expand(c0.shape[0], -1))
            r = reward(chat).mean()
            # grad w.r.t. e only; model parameters' .grad is left untouched
            (g_e,) = torch.autograd.grad(r, e)
            with torch.no_grad():
                step = g_e / (g_e.norm() + eps) if normalize else g_e
                e = e + step_size * step
            e = e.detach()
    return e


@torch.no_grad()
def project_and_rerank(model: PIVOT, data, candidate_labels, e_star: torch.Tensor,
                       c0: torch.Tensor, reward, k_nearest: int = 10, topk: int = 5,
                       device="cpu"):
    """Alg 5. project e* to the k nearest admissible candidates, rerank by endpoint reward.

    Args:
        model, data, device: forwarded to ``candidate_embeddings`` / ``endpoint_ranking``.
        candidate_labels: admissible perturbation labels (any iterable).
        e_star: target embedding, e.g. the output of ``reward_guidance``. It is
            flattened to a single query row.
        c0: control embeddings used for reranking.
        reward: per-row reward used by ``endpoint_ranking``.
        k_nearest: number of nearest candidates (Euclidean) to rerank.
        topk: number of reranked entries returned.

    Returns:
        list of (label, score), best first, at most ``topk`` long. The list is
        empty when there are no candidates.
    """
    labels = list(candidate_labels)
    if not labels:
        return []
    E = torch.as_tensor(candidate_embeddings(model, data, labels, device), device=device)
    # reshape (not view) tolerates a non-contiguous e*. .to(E) puts the query on
    # the same device and dtype as the candidate embeddings before computing distances.
    query = e_star.reshape(1, -1).to(E)
    d = torch.cdist(query, E).squeeze(0)
    knn = torch.argsort(d)[:k_nearest].cpu().numpy()
    sub = [labels[i] for i in knn]
    if not sub:
        return []
    return endpoint_ranking(model, data, sub, c0, reward, topk=topk, device=device)


@torch.no_grad()
def greedy_combinatorial(model: PIVOT, data, gene_pool, c0: torch.Tensor, reward,
                         max_size: int = 2, min_gain: float = 1e-4, device="cpu"):
    """Alg 6. greedily add the gene that most improves the reward of the mean-pooled endpoint.

    Each round scores every set ``chosen + [g]`` (as a ``data.sep``-joined label)
    with ``endpoint_ranking`` and keeps the best one if it improves the current
    score by at least ``min_gain``.

    Args:
        gene_pool: candidate genes (any iterable; materialized once).
        max_size: maximum number of genes to select.
        min_gain: minimum score improvement needed to accept another gene.
        model, data, c0, reward, device: forwarded to ``endpoint_ranking``.

    Returns:
        (chosen, best_score, history): the selected genes in the order they were
        added, the score of the final set (-inf if nothing was chosen), and a
        list of (genes_so_far, score) recorded after each accepted addition.
    """
    # materialize once: the pool is re-scanned every round
    pool = list(gene_pool)
    chosen: list[str] = []
    best_score = -np.inf
    history = []
    while len(chosen) < max_size:
        remaining = [g for g in pool if g not in chosen]
        if not remaining:
            break
        # score all candidate gene-sets {chosen + g} in one batched pass
        label_to_gene = {data.sep.join(chosen + [g]): g for g in remaining}
        ranked = endpoint_ranking(model, data, list(label_to_gene), c0, reward, device=device)
        best_label, best_local = ranked[0]
        # Written as `not (gain >= min_gain)` so that a NaN gain (a NaN score, or
        # -inf - -inf) stops the search instead of being accepted as an improvement.
        if not (best_local - best_score >= min_gain):
            break
        chosen.append(label_to_gene[best_label])
        best_score = best_local
        history.append((list(chosen), best_score))
    return chosen, best_score, history