Spaces:
Sleeping
Sleeping
| """ | |
| Phoenix Engine — Self-contained for HuggingFace Spaces deployment. | |
| Flow-matching model for predicting 377-gene expression from H&E patches. | |
| Reference: | |
| Tran, Gindra et al., "Pan-cancer virtual spatial transcriptomics | |
| from routine histology with Phoenix", bioRxiv (2026). | |
| """ | |
| import logging | |
| from pathlib import Path | |
| from typing import Union | |
| import numpy as np | |
| import torch | |
| from PIL import Image | |
# Module-level logger; records are namespaced under this module's import name.
logger = logging.getLogger(__name__)

# Per-channel normalization statistics, consumed by ``v2.Normalize`` in
# ``_build_image_transform``. Channel order follows the RGB conversion done in
# ``PhoenixEngine.predict``.
IMAGE_MEAN = (
    0.707223,  # R
    0.578729,  # G
    0.703617,  # B
)
IMAGE_STD = (
    0.211883,  # R
    0.230117,  # G
    0.177517,  # B
)
# Keyword arguments unpacked into ``FlowTransformerConfig`` in
# ``PhoenixEngine._load``. They describe the architecture stored in
# ``flow_model.pth``; presumably they must match the training configuration
# for the checkpoint weights to load correctly.
FLOW_CONFIG = {
    "d_genes": 1,  # trailing dim of the per-gene noise sampled in _run_flow
    "d_image": 1536,  # presumably the token width of VISION_MODEL_NAME features
    "d_model": 512,
    "d_cross": 512,
    "n_heads": 8,
    "n_layers": 8,
    "qkv_bias": False,
    "ffn_bias": False,
    "ffn_mult": 4,
    "attn_drop": 0.0,
    "proj_drop": 0.0,
    "n_classes": 0,
    "cls_drop": 0.1,
    "checkpoint": False,
}

# timm model identifier for the image encoder built in ``PhoenixEngine._load``.
VISION_MODEL_NAME = "vit_giant_patch14_reg4_dinov2"
# Gene symbol -> short human-readable description, used by
# ``PhoenixEngine.predict_formatted`` to label marker genes present in the
# model's panel. Insertion order is significant: it fixes the order of the
# report's marker section when two markers have equal values (stable sort).
MARKER_ANNOTATIONS = {
    # Lymphocytes
    "CD8A": "cytotoxic T-cells",
    "CD8B": "cytotoxic T-cells",
    "CD3D": "pan T-cells",
    "CD3E": "pan T-cells",
    "CD4": "helper T-cells",
    "MS4A1": "B-cells (CD20)",
    "CD19": "B-cells",
    # Myeloid / pan-immune / regulatory
    "CD68": "macrophages",
    "CD163": "M2 macrophages",
    "PTPRC": "pan-immune (CD45)",
    "FOXP3": "regulatory T-cells",
    # Epithelium
    "EPCAM": "epithelial/tumor",
    "KRT18": "epithelial",
    "KRT7": "epithelial",
    # Proliferation
    "MKI67": "proliferation",
    "PCNA": "proliferation",
    # Stroma
    "COL1A1": "fibroblasts/collagen",
    "VIM": "mesenchymal",
    "ACTA2": "smooth muscle/CAFs",
    "FAP": "cancer-associated fibroblasts",
    # Vasculature, checkpoints, antigen presentation
    "VEGFA": "angiogenesis",
    "PDCD1": "immune checkpoint (PD-1)",
    "CD274": "immune checkpoint (PD-L1)",
    "CTLA4": "immune checkpoint",
    "HLA-A": "antigen presentation",
}
def classify_expression(value: float) -> str:
    """Map a predicted expression value to a coarse, human-readable tier.

    Tiers are checked from lowest to highest, and each upper bound is
    inclusive: <=0.1 "absent", <=0.5 "very_low", <=1.0 "low",
    <=2.0 "moderate", <=4.0 "high". Anything else is "very_high". That
    includes values for which every comparison is false, such as NaN.
    """
    upper_bounds = (
        (0.1, "absent"),
        (0.5, "very_low"),
        (1.0, "low"),
        (2.0, "moderate"),
        (4.0, "high"),
    )
    for limit, tier in upper_bounds:
        if value <= limit:
            return tier
    return "very_high"
def _build_image_transform(image_size: int = 224):
    """Build the preprocessing pipeline applied to each input patch.

    Parameters
    ----------
    image_size : int
        Side length in pixels of the square output. Defaults to 224, the
        ``img_size`` the vision encoder is created with in ``PhoenixEngine``.

    Returns
    -------
    torchvision.transforms.v2.Compose
        Callable mapping a PIL image to a float32 channels-first tensor of
        spatial size ``(image_size, image_size)``. Pixels are scaled to
        [0, 1] and then normalized with ``IMAGE_MEAN`` / ``IMAGE_STD``.
    """
    # Imported inside the function rather than at module import time.
    from torchvision.transforms import v2, InterpolationMode

    size = (image_size, image_size)
    return v2.Compose([
        v2.Resize(size, InterpolationMode.BICUBIC),
        # No-op after the exact-size resize above; kept so the output size is
        # still guaranteed if Resize is ever made aspect-preserving.
        v2.CenterCrop(size),
        # v2.ToTensor is deprecated. ToImage + ToDtype(scale=True) is the
        # documented replacement and yields the same float32 values in [0, 1].
        v2.ToImage(),
        v2.ToDtype(torch.float32, scale=True),
        v2.Normalize(IMAGE_MEAN, IMAGE_STD),
    ])
class PhoenixEngine:
    """
    Phoenix flow-matching model for spatial gene expression prediction.

    Models are loaded lazily on the first call to ``predict`` (or
    ``predict_formatted``), so constructing an engine only checks that the
    required files exist.

    Parameters
    ----------
    model_dir : str | Path
        Directory containing flow_model.pth, stats_table.npz, xenium_human_multi.npy
    device : str
        Torch device string. CPU is used instead whenever CUDA is unavailable.
    num_samples : int
        Number of ODE samples to average (higher = more stable, slower). Must be >= 1.

    Raises
    ------
    FileNotFoundError
        If any of the three required files is missing from ``model_dir``.
    ValueError
        If ``num_samples`` is smaller than 1.
    """

    # Marker sets summed into the per-compartment scores of predict_formatted.
    # Order is kept fixed so the floating-point sums are reproducible.
    _IMMUNE_GENES = ("CD8A", "CD8B", "CD3D", "CD3E", "CD4", "PTPRC", "CD68", "CD163", "CD19", "MS4A1", "FOXP3")
    _TUMOR_GENES = ("EPCAM", "KRT18", "KRT7", "MKI67", "PCNA")
    _STROMA_GENES = ("COL1A1", "VIM", "ACTA2", "FAP")

    def __init__(self, model_dir: Union[str, Path], device: str = "cuda:0", num_samples: int = 5):
        if num_samples < 1:
            # Zero samples would otherwise only fail later, inside predict().
            raise ValueError(f"num_samples must be >= 1, got {num_samples}")
        self.model_dir = Path(model_dir)
        self.device = torch.device(device if torch.cuda.is_available() else "cpu")
        self.num_samples = num_samples
        self._loaded = False
        self._weight_path = self.model_dir / "flow_model.pth"
        self._stats_path = self.model_dir / "stats_table.npz"
        self._panel_path = self.model_dir / "xenium_human_multi.npy"
        # Fail fast at construction instead of on the first prediction.
        missing = [p for p in (self._weight_path, self._stats_path, self._panel_path) if not p.exists()]
        if missing:
            raise FileNotFoundError(f"Missing Phoenix model files: {[p.name for p in missing]}")

    def _load(self):
        """Load gene panel, normalization stats, vision encoder and flow model once."""
        if self._loaded:
            return
        logger.info("Loading Phoenix engine...")

        # Gene panel: output order of the flow model's per-gene predictions.
        # NOTE(review): allow_pickle=True can execute code from the file; only
        # point model_dir at trusted files.
        self.gene_list = list(np.load(str(self._panel_path), allow_pickle=True))
        self.n_genes = len(self.gene_list)

        # De-normalization statistics used in _run_flow. The context manager
        # closes the .npz archive once both arrays have been read.
        with np.load(str(self._stats_path)) as stats:
            self.stats_mean = stats["mean"]
            self.stats_std = stats["std"]

        import timm
        self.vision_model = timm.create_model(
            VISION_MODEL_NAME,
            pretrained=True,
            img_size=224,
            num_classes=0,
            global_pool="token",
            init_values=1e-5,
            dynamic_img_size=False,
        )
        self.vision_model = self.vision_model.eval().to(self.device)
        for p in self.vision_model.parameters():
            p.requires_grad = False

        from phoenix_vendor import FlowTransformerModel, FlowTransformerConfig
        cfg = FlowTransformerConfig(**FLOW_CONFIG)
        self.flow_model = FlowTransformerModel(cfg, vision_model=None)
        # NOTE(review): weights_only=False unpickles arbitrary Python objects;
        # only load checkpoints from a trusted source.
        state_dict = torch.load(str(self._weight_path), map_location=self.device, weights_only=False)
        # The encoder uses timm's pretrained weights above, so "vision_model.*"
        # entries in the checkpoint are dropped.
        # NOTE(review): if those entries are fine-tuned encoder weights they are
        # discarded here. Confirm the encoder was frozen during training.
        flow_keys = {k: v for k, v in state_dict.items() if not k.startswith("vision_model.")}
        load_result = self.flow_model.load_state_dict(flow_keys, strict=False)
        # strict=False hides key-name mismatches (e.g. an unexpected prefix),
        # which would leave layers at their random initialization. Surface them.
        missing_keys = list(getattr(load_result, "missing_keys", []))
        unexpected_keys = list(getattr(load_result, "unexpected_keys", []))
        if missing_keys:
            logger.warning("Phoenix flow model: %d missing keys, e.g. %s", len(missing_keys), missing_keys[:5])
        if unexpected_keys:
            logger.warning("Phoenix flow model: %d unexpected keys, e.g. %s", len(unexpected_keys), unexpected_keys[:5])
        self.flow_model = self.flow_model.eval().to(self.device)

        self.transform = _build_image_transform()
        from zuko.utils import odeint
        self._odeint = odeint
        self._loaded = True
        logger.info("Phoenix engine ready")

    def _run_flow(self, image_tensor: torch.Tensor) -> np.ndarray:
        """Integrate the flow ODE ``num_samples`` times and return de-normalized expression.

        Parameters
        ----------
        image_tensor : torch.Tensor
            Preprocessed image batch already on ``self.device``. ``predict``
            passes a batch of one.

        Returns
        -------
        np.ndarray
            One value per gene, in ``self.gene_list`` order.
        """
        # Pure inference. Without no_grad the solve tracks gradients through the
        # trainable flow-model parameters: it wastes memory, and Tensor.numpy()
        # raises on tensors that require grad.
        with torch.no_grad():
            feats = self.vision_model.forward_features(image_tensor)

            def velocity_fn(t_scalar, x: torch.Tensor) -> torch.Tensor:
                # Broadcast the solver's scalar time to one entry per batch row.
                t_vec = torch.full((x.shape[0],), t_scalar, device=self.device)
                return self.flow_model(x, t_vec, feats)

            all_preds = []
            for _ in range(self.num_samples):
                # Each sample starts from independent standard-normal noise.
                x_0 = torch.randn(1, self.n_genes, 1, device=self.device)
                # Loose tolerances favour speed; averaging samples smooths results.
                x_1 = self._odeint(velocity_fn, x_0, 0.0, 1.0, atol=1e-1, rtol=1e-1)
                # reshape(-1) keeps a 1-D result even for a single-gene panel,
                # where squeeze() would collapse it to a 0-d array.
                all_preds.append(x_1.reshape(-1).cpu().numpy())

        pred_mean = np.mean(all_preds, axis=0)
        # NOTE(review): this clip happens in normalized space. With non-negative
        # stats_std, every gene is therefore floored at stats_mean after
        # de-normalization. If non-negative expression is the intent, the clip
        # likely belongs after the line below. Confirm against the reference code.
        pred_mean = np.clip(pred_mean, 0, None)
        expression = pred_mean * self.stats_std + self.stats_mean
        return expression

    def predict(self, image: Union[str, Path, Image.Image]) -> dict:
        """Predict gene expression from an H&E patch.

        Parameters
        ----------
        image : str | Path | PIL.Image.Image
            Path to an image file, or an already-open PIL image. Either way it
            is converted to RGB.

        Returns
        -------
        dict
            ``{gene: value}`` for every gene in the panel, values as Python floats.
        """
        self._load()
        if isinstance(image, (str, Path)):
            # Close the file handle once the pixels have been converted.
            with Image.open(image) as opened:
                img = opened.convert("RGB")
        else:
            img = image.convert("RGB")
        tensor = self.transform(img).unsqueeze(0).to(self.device)
        expression = self._run_flow(tensor)
        return {gene: float(val) for gene, val in zip(self.gene_list, expression)}

    def predict_formatted(self, image: Union[str, Path, Image.Image], top_n: int = 30) -> dict:
        """Predict expression and return a structured report.

        Parameters
        ----------
        image : str | Path | PIL.Image.Image
            Same as ``predict``.
        top_n : int
            Number of highest-expressed genes returned under ``"top_genes"``.
            The text summary lists at most 15 of them.

        Returns
        -------
        dict
            Keys:

            - ``expression``: raw ``{gene: value}``.
            - ``top_genes``: ``(gene, value)`` pairs, highest first.
            - ``marker_results``: annotated markers present in the panel.
            - ``cell_type_scores``: immune / tumor / stroma sums.
            - ``dominant_type``: the highest of those scores.
            - ``summary_text``: a markdown report.
            - ``n_genes``: number of genes in ``expression``.
        """
        raw = self.predict(image)

        # Annotated markers that exist in the panel, in MARKER_ANNOTATIONS order.
        marker_results = {
            gene: {
                "value": round(raw[gene], 3),
                "tier": classify_expression(raw[gene]),
                "annotation": desc,
            }
            for gene, desc in MARKER_ANNOTATIONS.items()
            if gene in raw
        }

        # Highest expression first; the stable sort keeps panel order for ties.
        sorted_genes = sorted(raw.items(), key=lambda x: -x[1])
        top_genes = sorted_genes[:top_n]

        # NOTE(review): these are plain sums over gene sets of different sizes
        # (11 immune, 5 tumor, 4 stroma), which may bias "dominant" toward the
        # larger set. A per-gene mean may be the intended comparison.
        immune_score = sum(raw.get(g, 0) for g in self._IMMUNE_GENES)
        tumor_score = sum(raw.get(g, 0) for g in self._TUMOR_GENES)
        stroma_score = sum(raw.get(g, 0) for g in self._STROMA_GENES)
        scores = {"immune": immune_score, "tumor": tumor_score, "stroma": stroma_score}
        dominant = max(scores, key=scores.get)

        lines = [
            f"**PHOENIX EXPRESSION PROFILE** ({len(raw)} genes, flow-matching model)",
            f"**Dominant tissue type:** {dominant.upper()} "
            f"(immune={immune_score:.1f}, tumor={tumor_score:.1f}, stroma={stroma_score:.1f})",
            "",
            "**Top expressed genes:**",
        ]
        for gene, val in top_genes[:15]:
            annot = MARKER_ANNOTATIONS.get(gene, "")
            tier = classify_expression(val)
            label = f" ({annot})" if annot else ""
            lines.append(f"- {gene}{label}: {val:.3f} [{tier}]")

        if marker_results:
            lines.append("")
            lines.append("**Key marker genes:**")
            for gene, info in sorted(marker_results.items(), key=lambda x: -x[1]["value"]):
                lines.append(f"- {gene} ({info['annotation']}): {info['value']} [{info['tier']}]")

        return {
            "expression": raw,
            "top_genes": top_genes,
            "marker_results": marker_results,
            "cell_type_scores": scores,
            "dominant_type": dominant,
            "summary_text": "\n".join(lines),
            "n_genes": len(raw),
        }