chatspatial-engine / phoenix_engine.py
arka2696
fix: Vendor Phoenix flow_simple.py, remove phoenix package dependency
3ae0177
"""
Phoenix Engine — Self-contained for HuggingFace Spaces deployment.
Flow-matching model for predicting 377-gene expression from H&E patches.
Reference:
Tran, Gindra et al., "Pan-cancer virtual spatial transcriptomics
from routine histology with Phoenix", bioRxiv (2026).
"""
import logging
from pathlib import Path
from typing import Union
import numpy as np
import torch
from PIL import Image
logger = logging.getLogger(__name__)

# Per-channel RGB normalisation statistics applied to H&E patches by the
# preprocessing pipeline (mean first, then standard deviation).
IMAGE_MEAN = (0.707223, 0.578729, 0.703617)
IMAGE_STD = (0.211883, 0.230117, 0.177517)

# Keyword arguments for FlowTransformerConfig. Presumably these must match
# the architecture stored in flow_model.pth -- confirm before editing.
FLOW_CONFIG = {
    "d_genes": 1,
    "d_image": 1536,
    "d_model": 512,
    "d_cross": 512,
    "n_heads": 8,
    "n_layers": 8,
    "qkv_bias": False,
    "ffn_bias": False,
    "ffn_mult": 4,
    "attn_drop": 0.0,
    "proj_drop": 0.0,
    "n_classes": 0,
    "cls_drop": 0.1,
    "checkpoint": False,
}

# timm identifier of the vision backbone used as the image feature extractor.
VISION_MODEL_NAME = "vit_giant_patch14_reg4_dinov2"
# Marker gene -> short biological annotation used in expression reports.
# Entry order is kept deliberately: report code iterates this mapping and then
# sorts with a stable sort, so equal values keep this order.
MARKER_ANNOTATIONS = dict([
    # T-cell lineage
    ("CD8A", "cytotoxic T-cells"),
    ("CD8B", "cytotoxic T-cells"),
    ("CD3D", "pan T-cells"),
    ("CD3E", "pan T-cells"),
    ("CD4", "helper T-cells"),
    # B-cells
    ("MS4A1", "B-cells (CD20)"),
    ("CD19", "B-cells"),
    # Myeloid / pan-immune
    ("CD68", "macrophages"),
    ("CD163", "M2 macrophages"),
    ("PTPRC", "pan-immune (CD45)"),
    ("FOXP3", "regulatory T-cells"),
    # Epithelium / tumour and proliferation
    ("EPCAM", "epithelial/tumor"),
    ("KRT18", "epithelial"),
    ("KRT7", "epithelial"),
    ("MKI67", "proliferation"),
    ("PCNA", "proliferation"),
    # Stroma
    ("COL1A1", "fibroblasts/collagen"),
    ("VIM", "mesenchymal"),
    ("ACTA2", "smooth muscle/CAFs"),
    ("FAP", "cancer-associated fibroblasts"),
    # Angiogenesis, checkpoints, antigen presentation
    ("VEGFA", "angiogenesis"),
    ("PDCD1", "immune checkpoint (PD-1)"),
    ("CD274", "immune checkpoint (PD-L1)"),
    ("CTLA4", "immune checkpoint"),
    ("HLA-A", "antigen presentation"),
])
def classify_expression(value: float) -> str:
    """Map an expression value to a coarse, human-readable tier label.

    Tiers use inclusive upper bounds (0.1, 0.5, 1.0, 2.0, 4.0); the first
    bound that ``value`` does not exceed decides the label. Anything above
    4.0 -- including values that compare false against every bound, such as
    NaN -- is labelled ``"very_high"``.
    """
    bounded_tiers = (
        (0.1, "absent"),
        (0.5, "very_low"),
        (1.0, "low"),
        (2.0, "moderate"),
        (4.0, "high"),
    )
    return next(
        (label for ceiling, label in bounded_tiers if value <= ceiling),
        "very_high",
    )
def _build_image_transform():
    """Build the preprocessing pipeline applied to H&E patches.

    The image is resized to 224x224 with bicubic interpolation, converted to a
    float32 tensor scaled to [0, 1], and normalised per channel with
    ``IMAGE_MEAN`` / ``IMAGE_STD``.

    Returns
    -------
    torchvision.transforms.v2.Compose
        Callable taking a PIL image and returning a normalised tensor.
    """
    from torchvision.transforms import v2, InterpolationMode

    return v2.Compose([
        v2.Resize((224, 224), InterpolationMode.BICUBIC),
        # A no-op after the exact-size resize above; kept as a size guard.
        v2.CenterCrop((224, 224)),
        # v2.ToTensor is deprecated. ToImage + ToDtype(scale=True) is the
        # documented replacement and yields the same [0, 1] float32 values.
        v2.ToImage(),
        v2.ToDtype(torch.float32, scale=True),
        v2.Normalize(IMAGE_MEAN, IMAGE_STD),
    ])
class PhoenixEngine:
    """
    Phoenix flow-matching model for spatial gene expression prediction.

    Construction only validates arguments and checks that the model files
    exist. The gene panel, normalisation statistics, vision backbone and flow
    model are loaded lazily on the first prediction.

    Parameters
    ----------
    model_dir : str | Path
        Directory containing flow_model.pth, stats_table.npz, xenium_human_multi.npy
    device : str
        Torch device string (default ``"cuda:0"``); CPU is used instead when
        CUDA is not available.
    num_samples : int
        Number of ODE samples to average (higher = more stable, slower).
        Must be at least 1.

    Raises
    ------
    ValueError
        If ``num_samples`` is less than 1.
    FileNotFoundError
        If any required model file is missing from ``model_dir``.
    """

    # Marker panels summed into the coarse compartment scores reported by
    # predict_formatted(); genes missing from the model panel contribute 0.
    _IMMUNE_GENES = ("CD8A", "CD8B", "CD3D", "CD3E", "CD4", "PTPRC", "CD68", "CD163", "CD19", "MS4A1", "FOXP3")
    _TUMOR_GENES = ("EPCAM", "KRT18", "KRT7", "MKI67", "PCNA")
    _STROMA_GENES = ("COL1A1", "VIM", "ACTA2", "FAP")
    # How many of the top genes are listed in the human-readable summary.
    _SUMMARY_TOP_GENES = 15

    def __init__(self, model_dir: Union[str, Path], device: str = "cuda:0", num_samples: int = 5):
        # Zero samples would average an empty list and yield all-NaN output.
        if num_samples < 1:
            raise ValueError(f"num_samples must be >= 1, got {num_samples}")
        self.model_dir = Path(model_dir)
        self.device = torch.device(device if torch.cuda.is_available() else "cpu")
        self.num_samples = num_samples
        self._loaded = False
        self._weight_path = self.model_dir / "flow_model.pth"
        self._stats_path = self.model_dir / "stats_table.npz"
        self._panel_path = self.model_dir / "xenium_human_multi.npy"
        missing = [p for p in [self._weight_path, self._stats_path, self._panel_path] if not p.exists()]
        if missing:
            raise FileNotFoundError(f"Missing Phoenix model files: {[p.name for p in missing]}")

    def _load(self):
        """Load panel, statistics and both models on first use.

        Returns immediately once a previous call has completed successfully.
        """
        if self._loaded:
            return
        logger.info("Loading Phoenix engine...")

        # NOTE(review): allow_pickle=True can execute pickled code on load;
        # only use model directories from a trusted source.
        self.gene_list = list(np.load(str(self._panel_path), allow_pickle=True))
        self.n_genes = len(self.gene_list)

        # np.load on an .npz returns an NpzFile holding an open file handle;
        # the context manager closes it once the arrays have been read.
        with np.load(str(self._stats_path)) as stats:
            self.stats_mean = stats["mean"]
            self.stats_std = stats["std"]

        import timm
        self.vision_model = timm.create_model(
            VISION_MODEL_NAME,
            pretrained=True,
            img_size=224,
            num_classes=0,
            global_pool="token",
            init_values=1e-5,
            dynamic_img_size=False,
        )
        # The backbone is inference-only: eval mode and no gradients.
        self.vision_model = self.vision_model.eval().to(self.device)
        for p in self.vision_model.parameters():
            p.requires_grad = False

        from phoenix_vendor import FlowTransformerModel, FlowTransformerConfig
        cfg = FlowTransformerConfig(**FLOW_CONFIG)
        self.flow_model = FlowTransformerModel(cfg, vision_model=None)
        # NOTE(review): weights_only=False unpickles arbitrary objects and can
        # execute code; only load trusted checkpoints. weights_only=True is
        # safer if the file contains only tensors -- confirm before switching.
        state_dict = torch.load(str(self._weight_path), map_location=self.device, weights_only=False)
        # The vision backbone is built from timm above, so its keys are dropped.
        flow_keys = {k: v for k, v in state_dict.items() if not k.startswith("vision_model.")}
        # strict=False tolerates key mismatches. Surface them so a prefix
        # mismatch cannot silently leave layers randomly initialised.
        incompatible = self.flow_model.load_state_dict(flow_keys, strict=False)
        if incompatible.missing_keys:
            logger.warning(
                "Phoenix flow model: %d parameter(s) missing from checkpoint, e.g. %s",
                len(incompatible.missing_keys),
                incompatible.missing_keys[:5],
            )
        if incompatible.unexpected_keys:
            logger.warning(
                "Phoenix flow model: %d unexpected checkpoint key(s) ignored, e.g. %s",
                len(incompatible.unexpected_keys),
                incompatible.unexpected_keys[:5],
            )
        self.flow_model = self.flow_model.eval().to(self.device)

        self.transform = _build_image_transform()
        from zuko.utils import odeint
        self._odeint = odeint
        self._loaded = True
        logger.info("Phoenix engine ready")

    @torch.no_grad()
    def _run_flow(self, image_tensor: torch.Tensor) -> np.ndarray:
        """Integrate the flow ODE from noise and return de-normalised expression.

        Assumes a batch of one image (the initial noise has leading dim 1).
        Returns one value per gene in ``self.gene_list`` order.
        """
        feats = self.vision_model.forward_features(image_tensor)

        # Image-conditioned velocity field; independent of the sample index,
        # so it is defined once outside the sampling loop.
        def velocity_fn(t_scalar: float, x: torch.Tensor) -> torch.Tensor:
            t_vec = torch.full((x.shape[0],), t_scalar, device=self.device)
            return self.flow_model(x, t_vec, feats)

        all_preds = []
        for _ in range(self.num_samples):
            x_0 = torch.randn(1, self.n_genes, 1, device=self.device)
            # parameters() returns a one-shot generator, so it is created per call.
            x_1 = self._odeint(
                velocity_fn, x_0, 0.0, 1.0,
                phi=self.flow_model.parameters(), atol=1e-1, rtol=1e-1,
            )
            # reshape(-1) (unlike squeeze) keeps the gene axis even for a 1-gene panel.
            all_preds.append(x_1.reshape(-1).cpu().numpy())

        pred_mean = np.mean(all_preds, axis=0)
        # NOTE(review): the clip is applied in normalised space, before the
        # de-normalisation below, so (for non-negative stats_std) no gene can
        # fall below its stats_mean. If the intent is only to forbid negative
        # expression, the clip belongs after de-normalisation -- confirm
        # against the reference Phoenix implementation.
        pred_mean = np.clip(pred_mean, 0, None)
        expression = pred_mean * self.stats_std + self.stats_mean
        return expression

    def predict(self, image: Union[str, Path, Image.Image]) -> dict:
        """Predict gene expression from an H&E patch.

        Parameters
        ----------
        image : str | Path | PIL.Image.Image
            Image file path or an already-opened PIL image; converted to RGB.

        Returns
        -------
        dict
            ``{gene: value}`` for every gene in the panel, values as floats.
        """
        self._load()
        if isinstance(image, (str, Path)):
            # Close the file deterministically; convert() returns a new, fully
            # loaded image that remains valid after the file is closed.
            with Image.open(image) as src:
                img = src.convert("RGB")
        else:
            img = image.convert("RGB")
        tensor = self.transform(img).unsqueeze(0).to(self.device)
        expression = self._run_flow(tensor)
        return {gene: float(val) for gene, val in zip(self.gene_list, expression)}

    def predict_formatted(self, image: Union[str, Path, Image.Image], top_n: int = 30) -> dict:
        """Predict expression and return a structured report.

        Parameters
        ----------
        image : str | Path | PIL.Image.Image
            Passed through to :meth:`predict`.
        top_n : int
            Number of highest-expressed genes kept in ``top_genes``.

        Returns
        -------
        dict
            Keys ``expression`` (raw predictions), ``top_genes``,
            ``marker_results`` (annotated MARKER_ANNOTATIONS genes present in
            the panel), ``cell_type_scores``, ``dominant_type``,
            ``summary_text`` (markdown) and ``n_genes``.
        """
        raw = self.predict(image)

        marker_results = {}
        for gene, desc in MARKER_ANNOTATIONS.items():
            if gene in raw:
                marker_results[gene] = {
                    "value": round(raw[gene], 3),
                    "tier": classify_expression(raw[gene]),
                    "annotation": desc,
                }

        sorted_genes = sorted(raw.items(), key=lambda x: -x[1])
        top_genes = sorted_genes[:top_n]

        immune_score = sum(raw.get(g, 0) for g in self._IMMUNE_GENES)
        tumor_score = sum(raw.get(g, 0) for g in self._TUMOR_GENES)
        stroma_score = sum(raw.get(g, 0) for g in self._STROMA_GENES)
        scores = {"immune": immune_score, "tumor": tumor_score, "stroma": stroma_score}
        # On ties, max() keeps the first key in insertion order.
        dominant = max(scores, key=scores.get)

        lines = [
            f"**PHOENIX EXPRESSION PROFILE** ({len(raw)} genes, flow-matching model)",
            f"**Dominant tissue type:** {dominant.upper()} "
            f"(immune={immune_score:.1f}, tumor={tumor_score:.1f}, stroma={stroma_score:.1f})",
            "",
            "**Top expressed genes:**",
        ]
        for gene, val in top_genes[: self._SUMMARY_TOP_GENES]:
            annot = MARKER_ANNOTATIONS.get(gene, "")
            tier = classify_expression(val)
            label = f" ({annot})" if annot else ""
            lines.append(f"- {gene}{label}: {val:.3f} [{tier}]")
        if marker_results:
            lines.append("")
            lines.append("**Key marker genes:**")
            for gene, info in sorted(marker_results.items(), key=lambda x: -x[1]["value"]):
                lines.append(f"- {gene} ({info['annotation']}): {info['value']} [{info['tier']}]")

        return {
            "expression": raw,
            "top_genes": top_genes,
            "marker_results": marker_results,
            "cell_type_scores": scores,
            "dominant_type": dominant,
            "summary_text": "\n".join(lines),
            "n_genes": len(raw),
        }