"""
Inference utilities for Tahoe cell type classifier.
"""
import numpy as np
import scanpy as sc
import torch
import json
from pathlib import Path
from safetensors.torch import load_file
from tqdm.auto import tqdm
from src.tahoe_classifier.models.encoder import CellEncoder, CellEncoderConfig
from src.tahoe_classifier.models.classifier import CellTypeClassifier, CellTypeClassifierConfig
def load_model(model_path=".", device="cuda"):
    """Load a trained cell-type classifier from a directory.

    Expects ``config.json`` (training hyperparameters and label mapping)
    and ``model.safetensors`` (merged encoder + classifier weights) inside
    ``model_path``.

    Args:
        model_path: Directory containing ``config.json`` and
            ``model.safetensors``.
        device: Torch device to place the model on (e.g. "cuda" or "cpu").

    Returns:
        Tuple ``(model, vocab, collator_cfg)``: the classifier in eval mode
        on ``device``, the gene vocabulary, and the collator configuration
        taken from the pretrained Tahoe base model.
    """
    model_path = Path(model_path)

    # Training-time hyperparameters saved alongside the weights.
    with open(model_path / "config.json") as f:
        config = json.load(f)

    # Vocab size may be recorded in the saved config; 60697 is the
    # Tahoe base-model default used at training time.
    encoder_config = CellEncoderConfig(
        vocab_size=config.get("vocab_size", 60697),
        d_model=config["d_model"],
        n_layers=config["n_layers"],
        n_heads=config["n_heads"],
        expansion_ratio=config["expansion_ratio"],
    )

    # Only the vocabulary and collator config are needed from the base
    # checkpoint; the encoder weights themselves are overwritten by the
    # merged weights loaded below. (CellEncoder is already imported at
    # module level — no need to re-import here.)
    _, vocab, collator_cfg = CellEncoder.from_pretrained_tahoe(
        model_size=config["model_size"],
        device="cpu",
    )

    classifier_config = CellTypeClassifierConfig(
        num_labels=config["num_labels"],
        encoder_config=encoder_config.to_dict(),
        classifier_dropout=config["classifier_dropout"],
    )
    encoder = CellEncoder(encoder_config)
    model = CellTypeClassifier(classifier_config, encoder=encoder)

    # Merged (encoder + classifier head) weights produced by training.
    state_dict = load_file(model_path / "model.safetensors")
    model.load_state_dict(state_dict)
    model.to(device)
    model.eval()
    return model, vocab, collator_cfg
def prepare_data(adata, vocab, collator_cfg, gene_id_key="ensembl_id", max_length=2048):
    """Convert an AnnData object into padded token/expression arrays.

    For each cell, keeps the ``max_length`` highest-expressed genes that
    exist in the model vocabulary, then bins expression values per cell
    into ``collator_cfg["num_bins"]`` discrete levels.

    Args:
        adata: AnnData with gene identifiers in ``adata.var[gene_id_key]``
            (falls back to the "gene_id" column when that key is absent).
        vocab: Gene vocabulary exposing ``get_stoi()`` (str -> token id).
        collator_cfg: Dict-like; ``num_bins`` controls binning (default 51).
        gene_id_key: Column in ``adata.var`` holding gene identifiers.
        max_length: Maximum number of genes kept per cell.

    Returns:
        Tuple ``(gene_ids_batch, expr_batch)``, both of shape
        ``(n_cells, seq_len)``: int64 token ids and float32 binned values.

    Raises:
        ValueError: If no gene in ``adata`` is present in the vocabulary
            (previously this surfaced as an opaque zero-size-reduction
            error from ``np.max``).
    """
    gene_col = gene_id_key if gene_id_key in adata.var.columns else "gene_id"
    gene_names = adata.var[gene_col].tolist()
    gene2idx = vocab.get_stoi()

    # Map every gene to its vocab id; -1 marks out-of-vocabulary genes.
    gene_ids_map = np.array([gene2idx.get(g, -1) for g in gene_names], dtype=np.int64)
    valid_mask = gene_ids_map >= 0
    valid_indices = np.where(valid_mask)[0]
    gene_ids = gene_ids_map[valid_mask]
    if len(gene_ids) == 0:
        raise ValueError(
            f"No genes in adata.var['{gene_col}'] were found in the model vocabulary; "
            "check that gene_id_key matches the vocabulary's identifier scheme."
        )

    n_cells = adata.n_obs
    seq_len = min(len(gene_ids), max_length)
    gene_ids_batch = np.zeros((n_cells, seq_len), dtype=np.int64)
    expr_batch = np.zeros((n_cells, seq_len), dtype=np.float32)
    for i in tqdm(range(n_cells), desc="Processing cells"):
        x = adata.X[i, valid_indices]
        # Sparse rows need densifying; dense rows just flatten.
        x = x.toarray().flatten() if hasattr(x, "toarray") else np.array(x).flatten()
        # Top seq_len genes by expression, in descending order.
        indices = np.argsort(-x)[:seq_len]
        gene_ids_batch[i] = gene_ids[indices]
        expr_batch[i] = x[indices]

    # Per-cell max-normalized binning into num_bins discrete levels; the
    # 1e-6 floor guards against division by zero for all-zero cells.
    num_bins = collator_cfg.get("num_bins", 51)
    expr_max = np.clip(np.max(expr_batch, axis=1, keepdims=True), 1e-6, None)
    expr_batch = np.clip(np.floor(expr_batch / expr_max * (num_bins - 1)), 0, num_bins - 1)
    return gene_ids_batch, expr_batch
def predict_cell_types(
    model, vocab, collator_cfg,
    h5ad_path, label_key=None, gene_id_key="ensembl_id",
    batch_size=32, device="cuda", model_path=".",
):
    """Predict cell types for every cell in an h5ad file.

    Args:
        model: Trained CellTypeClassifier (as returned by ``load_model``).
        vocab: Gene vocabulary matching the model.
        collator_cfg: Collator configuration (binning parameters).
        h5ad_path: Path to the input ``.h5ad`` file.
        label_key: Optional ``adata.obs`` column with ground-truth labels;
            when present, accuracy / macro-F1 / a classification report
            are printed.
        gene_id_key: Column in ``adata.var`` holding gene identifiers.
        batch_size: Number of cells per inference batch.
        device: Torch device for inference.
        model_path: Directory holding ``config.json``; only read when the
            model config itself carries no ``id2label`` mapping.

    Returns:
        The AnnData object with predictions added in
        ``adata.obs["predicted_cell_type"]``.
    """
    # Resolve the id -> label mapping. Prefer a mapping carried on the
    # model config; otherwise read the saved config.json from model_path.
    # (The previous code rebuilt the path from model.config._name_or_path,
    # which load_model never sets, so it silently fell back to the CWD.)
    id_to_label = getattr(model.config, "id2label", None)
    if not id_to_label:
        with open(Path(model_path) / "config.json") as f:
            id_to_label = json.load(f)["id2label"]
    # Keys may be ints (HF-style configs) or strings (raw JSON); normalize
    # to str so the lookup below works for both.
    id_to_label = {str(k): v for k, v in id_to_label.items()}

    # Load and tokenize the data.
    adata = sc.read_h5ad(h5ad_path)
    gene_ids, expr_values = prepare_data(adata, vocab, collator_cfg, gene_id_key)

    # Batched inference; gradients are never needed here.
    all_preds = []
    model.to(device)
    with torch.no_grad():
        for i in tqdm(range(0, len(gene_ids), batch_size), desc="Predicting"):
            batch_genes = torch.tensor(gene_ids[i:i+batch_size], device=device)
            batch_expr = torch.tensor(expr_values[i:i+batch_size], device=device)
            outputs = model(gene_ids=batch_genes, expression_values=batch_expr)
            preds = torch.argmax(outputs.logits, dim=-1)
            all_preds.extend(preds.cpu().numpy())

    # Decode class indices to label strings and attach to the AnnData.
    predicted_labels = [id_to_label[str(p)] for p in all_preds]
    adata.obs["predicted_cell_type"] = predicted_labels

    # Report metrics when ground-truth labels are available.
    if label_key and label_key in adata.obs.columns:
        from sklearn.metrics import accuracy_score, f1_score, classification_report
        true_labels = adata.obs[label_key].values
        accuracy = accuracy_score(true_labels, predicted_labels)
        f1_macro = f1_score(true_labels, predicted_labels, average="macro")
        print(f"\nAccuracy: {accuracy:.4f}")
        print(f"F1 Macro: {f1_macro:.4f}")
        print("\n" + classification_report(true_labels, predicted_labels))
    return adata
|