"""Uniform predictor interface so that PIVOT and every baseline run through the
same forward/inverse code. A predictor maps (perturbation label, control-cell
embeddings c0) -> predicted post-perturbation population ĉ1 in embedding space.
"""
| from __future__ import annotations |
|
|
| import numpy as np |
| import torch |
|
|
| from src.evaluation import inference as inf |
|
|
|
|
class PivotPredictor:
    """Predictor adapter around a trained PIVOT model.

    Exposes ``population`` / ``population_from_genes`` so evaluation code can
    drive PIVOT through the same interface as the baselines. The perturbation
    is encoded with the helpers in ``src.evaluation.inference`` and the control
    embeddings ``c0`` are pushed through ``inf.forward_predict``.

    Args:
        model: trained PIVOT model; put into eval mode on construction.
        data: dataset object forwarded to the ``inf.encode_*`` helpers.
        device: torch device on which the control embeddings are placed.
    """

    name = "PIVOT"

    def __init__(self, model, data, device):
        self.model, self.data, self.device = model, data, device
        # Switch to inference behaviour once, up front, so callers do not
        # have to remember to do it before every prediction.
        model.eval()

    def population(self, label: str, c0: np.ndarray) -> np.ndarray:
        """Predict the post-perturbation population for a perturbation label.

        Args:
            label: perturbation label, encoded via ``inf.encode_label``.
            c0: control-cell embeddings (anything ``torch.as_tensor`` accepts).

        Returns:
            The predicted population as a NumPy array.
        """
        c0t = self._as_controls(c0)
        e = inf.encode_label(self.model, self.data, label, self.device)
        return self._forward(c0t, e)

    def population_from_genes(self, genes: list[str], c0: np.ndarray) -> np.ndarray:
        """Predict the post-perturbation population for an explicit gene set.

        Args:
            genes: perturbed gene names, encoded via ``inf.encode_gene_set``.
            c0: control-cell embeddings (anything ``torch.as_tensor`` accepts).

        Returns:
            The predicted population as a NumPy array.
        """
        c0t = self._as_controls(c0)
        e = inf.encode_gene_set(self.model, self.data, genes, self.device)
        return self._forward(c0t, e)

    def _as_controls(self, c0: np.ndarray) -> torch.Tensor:
        """Convert control embeddings to a float32 tensor on ``self.device``."""
        return torch.as_tensor(c0, dtype=torch.float32, device=self.device)

    def _forward(self, c0t: torch.Tensor, e) -> np.ndarray:
        """Run ``inf.forward_predict`` and return the result as a NumPy array."""
        # detach() before numpy(): Tensor.numpy() raises on tensors that
        # require grad, and forward_predict is not guaranteed to return a
        # grad-free tensor. Grad mode itself is left untouched in case the
        # forward pass relies on autograd internally.
        return inf.forward_predict(self.model, c0t, e).detach().cpu().numpy()
|
|
|
|
class BaselinePredictor:
    """Adapter exposing a baseline model through the predictor interface.

    The wrapped ``baseline`` is expected to provide ``name``, ``data`` (with
    ``sep`` and ``control_label`` attributes) and
    ``predict_endpoint(label, c0)``; every prediction is forwarded to
    ``predict_endpoint``.
    """

    def __init__(self, baseline):
        self.bl = baseline
        self.name = baseline.name

    def population(self, label: str, c0: np.ndarray) -> np.ndarray:
        """Return the baseline's predicted population for ``label`` given controls ``c0``."""
        return self.bl.predict_endpoint(label, c0)

    def population_from_genes(self, genes: list[str], c0: np.ndarray) -> np.ndarray:
        """Predict from a gene list by turning it into a combined label first.

        The genes are joined with the dataset separator; an empty gene list
        maps to the dataset's control label.
        """
        meta = self.bl.data
        if genes:
            combo = meta.sep.join(genes)
        else:
            combo = meta.control_label
        return self.bl.predict_endpoint(combo, c0)
|
|