PIVOT / src /experiments /predictors.py
bryan7264's picture
pivot: code + trained checkpoints (norman, replogle k562)
3b4941f verified
Raw
History Blame
1.58 kB
"""Uniform predictor interface so that PIVOT and every baseline run through the same
forward/inverse code. A predictor maps (perturbation label, control-cell embeddings
c0) -> predicted post-perturbation population ĉ1 in embedding space.
"""
from __future__ import annotations
import numpy as np
import torch
from src.evaluation import inference as inf
class PivotPredictor:
    """Expose a PIVOT model through the shared predictor interface.

    Both public methods convert the control embeddings ``c0`` to a float32
    tensor on ``self.device``, obtain a perturbation encoding from
    ``src.evaluation.inference`` and return the result of
    ``inf.forward_predict`` as a NumPy array.
    """

    name = "PIVOT"

    def __init__(self, model, data, device):
        """Keep references to the model, dataset and device; put the model in eval mode."""
        self.model = model
        self.data = data
        self.device = device
        self.model.eval()

    def _control_tensor(self, c0: np.ndarray) -> torch.Tensor:
        """Return ``c0`` as a float32 tensor placed on ``self.device``."""
        return torch.as_tensor(c0, dtype=torch.float32, device=self.device)

    def _run_forward(self, control: torch.Tensor, encoding) -> np.ndarray:
        """Call ``inf.forward_predict`` and hand the output back as a CPU NumPy array."""
        predicted = inf.forward_predict(self.model, control, encoding)
        return predicted.cpu().numpy()

    def population(self, label: str, c0: np.ndarray) -> np.ndarray:
        """Predict the population for the perturbation identified by ``label``."""
        # Tensor conversion happens before encoding, as in the original ordering.
        control = self._control_tensor(c0)
        encoding = inf.encode_label(self.model, self.data, label, self.device)
        return self._run_forward(control, encoding)

    def population_from_genes(self, genes: list[str], c0: np.ndarray) -> np.ndarray:
        """Predict the population for a perturbation given as a list of gene names."""
        control = self._control_tensor(c0)
        encoding = inf.encode_gene_set(self.model, self.data, genes, self.device)
        return self._run_forward(control, encoding)
class BaselinePredictor:
    """Adapt a baseline object to the shared predictor interface.

    Every prediction is delegated to ``baseline.predict_endpoint(label, c0)``;
    the predictor's ``name`` is copied from the wrapped baseline.
    """

    def __init__(self, baseline):
        """Store the wrapped baseline as ``self.bl`` and mirror its ``name``."""
        self.bl = baseline
        self.name = baseline.name

    def population(self, label: str, c0: np.ndarray) -> np.ndarray:
        """Return the baseline's endpoint prediction for ``label`` given ``c0``."""
        baseline = self.bl
        return baseline.predict_endpoint(label, c0)

    def population_from_genes(self, genes: list[str], c0: np.ndarray) -> np.ndarray:
        """Build a label from ``genes`` and return the baseline's endpoint prediction.

        A non-empty gene list is joined with the dataset's ``sep``; an empty
        (falsy) one falls back to the dataset's ``control_label``.
        """
        dataset = self.bl.data
        if genes:
            label = dataset.sep.join(genes)
        else:
            label = dataset.control_label
        return self.bl.predict_endpoint(label, c0)