"""uniform predictor interface so pivot and every baseline run through the same forward/inverse code. a predictor maps (perturbation label, control-cell embeddings c0) -> predicted post-perturbation population ĉ1 in embedding space. """ from __future__ import annotations import numpy as np import torch from src.evaluation import inference as inf class PivotPredictor: name = "PIVOT" def __init__(self, model, data, device): self.model, self.data, self.device = model, data, device model.eval() def population(self, label: str, c0: np.ndarray) -> np.ndarray: c0t = torch.as_tensor(c0, dtype=torch.float32, device=self.device) e = inf.encode_label(self.model, self.data, label, self.device) return inf.forward_predict(self.model, c0t, e).cpu().numpy() def population_from_genes(self, genes: list[str], c0: np.ndarray) -> np.ndarray: c0t = torch.as_tensor(c0, dtype=torch.float32, device=self.device) e = inf.encode_gene_set(self.model, self.data, genes, self.device) return inf.forward_predict(self.model, c0t, e).cpu().numpy() class BaselinePredictor: def __init__(self, baseline): self.bl = baseline self.name = baseline.name def population(self, label: str, c0: np.ndarray) -> np.ndarray: return self.bl.predict_endpoint(label, c0) def population_from_genes(self, genes: list[str], c0: np.ndarray) -> np.ndarray: label = self.bl.data.sep.join(genes) if genes else self.bl.data.control_label return self.bl.predict_endpoint(label, c0)