""" Phoenix Engine — Self-contained for HuggingFace Spaces deployment. Flow-matching model for predicting 377-gene expression from H&E patches. Reference: Tran, Gindra et al., "Pan-cancer virtual spatial transcriptomics from routine histology with Phoenix", bioRxiv (2026). """ import logging from pathlib import Path from typing import Union import numpy as np import torch from PIL import Image logger = logging.getLogger(__name__) IMAGE_MEAN = (0.707223, 0.578729, 0.703617) IMAGE_STD = (0.211883, 0.230117, 0.177517) FLOW_CONFIG = dict( d_genes=1, d_image=1536, d_model=512, d_cross=512, n_heads=8, n_layers=8, qkv_bias=False, ffn_bias=False, ffn_mult=4, attn_drop=0.0, proj_drop=0.0, n_classes=0, cls_drop=0.1, checkpoint=False, ) VISION_MODEL_NAME = "vit_giant_patch14_reg4_dinov2" MARKER_ANNOTATIONS = { "CD8A": "cytotoxic T-cells", "CD8B": "cytotoxic T-cells", "CD3D": "pan T-cells", "CD3E": "pan T-cells", "CD4": "helper T-cells", "MS4A1": "B-cells (CD20)", "CD19": "B-cells", "CD68": "macrophages", "CD163": "M2 macrophages", "PTPRC": "pan-immune (CD45)", "FOXP3": "regulatory T-cells", "EPCAM": "epithelial/tumor", "KRT18": "epithelial", "KRT7": "epithelial", "MKI67": "proliferation", "PCNA": "proliferation", "COL1A1": "fibroblasts/collagen", "VIM": "mesenchymal", "ACTA2": "smooth muscle/CAFs", "FAP": "cancer-associated fibroblasts", "VEGFA": "angiogenesis", "PDCD1": "immune checkpoint (PD-1)", "CD274": "immune checkpoint (PD-L1)", "CTLA4": "immune checkpoint", "HLA-A": "antigen presentation", } def classify_expression(value: float) -> str: if value <= 0.1: return "absent" if value <= 0.5: return "very_low" if value <= 1.0: return "low" if value <= 2.0: return "moderate" if value <= 4.0: return "high" return "very_high" def _build_image_transform(): from torchvision.transforms import v2, InterpolationMode return v2.Compose([ v2.Resize((224, 224), InterpolationMode.BICUBIC), v2.CenterCrop((224, 224)), v2.ToTensor(), v2.Normalize(IMAGE_MEAN, IMAGE_STD), ]) class PhoenixEngine: """ Phoenix flow-matching model for spatial gene expression prediction. Parameters ---------- model_dir : str | Path Directory containing flow_model.pth, stats_table.npz, xenium_human_multi.npy device : str CUDA device string num_samples : int Number of ODE samples to average (higher = more stable, slower) """ def __init__(self, model_dir: Union[str, Path], device: str = "cuda:0", num_samples: int = 5): self.model_dir = Path(model_dir) self.device = torch.device(device if torch.cuda.is_available() else "cpu") self.num_samples = num_samples self._loaded = False self._weight_path = self.model_dir / "flow_model.pth" self._stats_path = self.model_dir / "stats_table.npz" self._panel_path = self.model_dir / "xenium_human_multi.npy" missing = [p for p in [self._weight_path, self._stats_path, self._panel_path] if not p.exists()] if missing: raise FileNotFoundError(f"Missing Phoenix model files: {[p.name for p in missing]}") def _load(self): if self._loaded: return logger.info("Loading Phoenix engine...") self.gene_list = list(np.load(str(self._panel_path), allow_pickle=True)) self.n_genes = len(self.gene_list) stats = np.load(str(self._stats_path)) self.stats_mean = stats["mean"] self.stats_std = stats["std"] import timm self.vision_model = timm.create_model( VISION_MODEL_NAME, pretrained=True, img_size=224, num_classes=0, global_pool="token", init_values=1e-5, dynamic_img_size=False, ) self.vision_model = self.vision_model.eval().to(self.device) for p in self.vision_model.parameters(): p.requires_grad = False from phoenix_vendor import FlowTransformerModel, FlowTransformerConfig cfg = FlowTransformerConfig(**FLOW_CONFIG) self.flow_model = FlowTransformerModel(cfg, vision_model=None) state_dict = torch.load(str(self._weight_path), map_location=self.device, weights_only=False) flow_keys = {k: v for k, v in state_dict.items() if not k.startswith("vision_model.")} self.flow_model.load_state_dict(flow_keys, strict=False) self.flow_model = self.flow_model.eval().to(self.device) self.transform = _build_image_transform() from zuko.utils import odeint self._odeint = odeint self._loaded = True logger.info("Phoenix engine ready") @torch.no_grad() def _run_flow(self, image_tensor: torch.Tensor) -> np.ndarray: feats = self.vision_model.forward_features(image_tensor) all_preds = [] for _ in range(self.num_samples): x_0 = torch.randn(1, self.n_genes, 1, device=self.device) def velocity_fn(t_scalar: float, x: torch.Tensor) -> torch.Tensor: t_vec = torch.full((x.shape[0],), t_scalar, device=self.device) return self.flow_model(x, t_vec, feats) phi = self.flow_model.parameters() x_1 = self._odeint(velocity_fn, x_0, 0.0, 1.0, phi=phi, atol=1e-1, rtol=1e-1) pred = x_1.squeeze().cpu().numpy() all_preds.append(pred) pred_mean = np.mean(all_preds, axis=0) pred_mean = np.clip(pred_mean, 0, None) expression = pred_mean * self.stats_std + self.stats_mean return expression def predict(self, image: Union[str, Path, Image.Image]) -> dict: """Predict gene expression from H&E patch. Returns {gene: value}.""" self._load() if isinstance(image, (str, Path)): img = Image.open(image).convert("RGB") else: img = image.convert("RGB") tensor = self.transform(img).unsqueeze(0).to(self.device) expression = self._run_flow(tensor) return {gene: float(val) for gene, val in zip(self.gene_list, expression)} def predict_formatted(self, image: Union[str, Path, Image.Image], top_n: int = 30) -> dict: """Predict expression and return structured report.""" raw = self.predict(image) marker_results = {} for gene, desc in MARKER_ANNOTATIONS.items(): if gene in raw: marker_results[gene] = { "value": round(raw[gene], 3), "tier": classify_expression(raw[gene]), "annotation": desc, } sorted_genes = sorted(raw.items(), key=lambda x: -x[1]) top_genes = sorted_genes[:top_n] immune_genes = ["CD8A", "CD8B", "CD3D", "CD3E", "CD4", "PTPRC", "CD68", "CD163", "CD19", "MS4A1", "FOXP3"] tumor_genes = ["EPCAM", "KRT18", "KRT7", "MKI67", "PCNA"] stroma_genes = ["COL1A1", "VIM", "ACTA2", "FAP"] immune_score = sum(raw.get(g, 0) for g in immune_genes) tumor_score = sum(raw.get(g, 0) for g in tumor_genes) stroma_score = sum(raw.get(g, 0) for g in stroma_genes) scores = {"immune": immune_score, "tumor": tumor_score, "stroma": stroma_score} dominant = max(scores, key=scores.get) lines = [ f"**PHOENIX EXPRESSION PROFILE** ({len(raw)} genes, flow-matching model)", f"**Dominant tissue type:** {dominant.upper()} " f"(immune={immune_score:.1f}, tumor={tumor_score:.1f}, stroma={stroma_score:.1f})", "", "**Top expressed genes:**", ] for gene, val in top_genes[:15]: annot = MARKER_ANNOTATIONS.get(gene, "") tier = classify_expression(val) label = f" ({annot})" if annot else "" lines.append(f"- {gene}{label}: {val:.3f} [{tier}]") if marker_results: lines.append("") lines.append("**Key marker genes:**") for gene, info in sorted(marker_results.items(), key=lambda x: -x[1]["value"]): lines.append(f"- {gene} ({info['annotation']}): {info['value']} [{info['tier']}]") return { "expression": raw, "top_genes": top_genes, "marker_results": marker_results, "cell_type_scores": scores, "dominant_type": dominant, "summary_text": "\n".join(lines), "n_genes": len(raw), }