# tahoe-lung-classifier / inference.py
# Uploaded via huggingface_hub (commit 3147616). The Hugging Face page
# chrome that preceded the module docstring has been commented out here:
# the raw lines were scrape residue and not valid Python.
"""
Inference utilities for Tahoe cell type classifier.
"""
import numpy as np
import scanpy as sc
import torch
import json
from pathlib import Path
from safetensors.torch import load_file
from tqdm.auto import tqdm
from src.tahoe_classifier.models.encoder import CellEncoder, CellEncoderConfig
from src.tahoe_classifier.models.classifier import CellTypeClassifier, CellTypeClassifierConfig
def load_model(model_path=".", device="cuda"):
    """Load a trained cell-type classifier from a directory.

    Args:
        model_path: Directory containing ``config.json`` and
            ``model.safetensors``.
        device: Torch device to place the model on (e.g. "cuda" or "cpu").

    Returns:
        (model, vocab, collator_cfg): the classifier in eval mode, the gene
        vocabulary, and the collator config from the pretrained base model.
    """
    model_path = Path(model_path)

    # Load training-time configuration.
    with open(model_path / "config.json") as f:
        config = json.load(f)

    # Encoder hyperparameters; vocab_size falls back to the Tahoe default
    # when it is absent from config.json.
    encoder_config = CellEncoderConfig(
        vocab_size=config.get("vocab_size", 60697),
        d_model=config["d_model"],
        n_layers=config["n_layers"],
        n_heads=config["n_heads"],
        expansion_ratio=config["expansion_ratio"],
    )

    # Only vocab and collator_cfg are kept from the pretrained base; the
    # encoder weights are overwritten by the merged checkpoint below.
    _, vocab, collator_cfg = CellEncoder.from_pretrained_tahoe(
        model_size=config["model_size"],
        device="cpu",
    )

    classifier_config = CellTypeClassifierConfig(
        num_labels=config["num_labels"],
        encoder_config=encoder_config.to_dict(),
        classifier_dropout=config["classifier_dropout"],
    )
    encoder = CellEncoder(encoder_config)
    model = CellTypeClassifier(classifier_config, encoder=encoder)

    # Load the merged (encoder + classification head) weights.
    state_dict = load_file(model_path / "model.safetensors")
    model.load_state_dict(state_dict)
    model.to(device)
    model.eval()
    return model, vocab, collator_cfg
def prepare_data(adata, vocab, collator_cfg, gene_id_key="ensembl_id", max_length=2048):
    """Convert an AnnData object into token-id and binned-expression batches.

    Gene identifiers are mapped to vocabulary indices; genes missing from
    the vocabulary are dropped. For each cell the highest-expressed genes
    (up to ``max_length``) are kept, and their expression is discretized
    into ``collator_cfg["num_bins"]`` levels after per-cell max-normalization.

    Returns:
        (gene_ids_batch, expr_batch): int64 ids and float32 binned values,
        both of shape (n_cells, seq_len).
    """
    # Prefer the requested gene-id column; fall back to "gene_id".
    gene_col = gene_id_key if gene_id_key in adata.var.columns else "gene_id"
    stoi = vocab.get_stoi()
    mapped = np.array(
        [stoi.get(name, -1) for name in adata.var[gene_col].tolist()],
        dtype=np.int64,
    )
    keep = np.where(mapped >= 0)[0]  # column indices with a vocab entry
    vocab_ids = mapped[mapped >= 0]

    n_cells = adata.n_obs
    seq_len = min(len(vocab_ids), max_length)
    gene_ids_batch = np.zeros((n_cells, seq_len), dtype=np.int64)
    expr_batch = np.zeros((n_cells, seq_len), dtype=np.float32)

    for row in tqdm(range(n_cells), desc="Processing cells"):
        values = adata.X[row, keep]
        # Sparse rows expose toarray(); dense rows are coerced directly.
        if hasattr(values, "toarray"):
            values = values.toarray().flatten()
        else:
            values = np.array(values).flatten()
        top = np.argsort(-values)[:seq_len]  # highest expression first
        gene_ids_batch[row] = vocab_ids[top]
        expr_batch[row] = values[top]

    # Discretize: scale each cell by its max and floor into num_bins levels.
    num_bins = collator_cfg.get("num_bins", 51)
    row_max = np.clip(expr_batch.max(axis=1, keepdims=True), 1e-6, None)
    binned = np.floor(expr_batch / row_max * (num_bins - 1))
    expr_batch = np.clip(binned, 0, num_bins - 1)
    return gene_ids_batch, expr_batch
def predict_cell_types(
    model, vocab, collator_cfg,
    h5ad_path, label_key=None, gene_id_key="ensembl_id",
    batch_size=32, device="cuda", model_path=None,
):
    """Predict cell types for every cell in an h5ad file.

    Args:
        model: Trained classifier returned by ``load_model``.
        vocab: Gene vocabulary providing ``get_stoi``.
        collator_cfg: Collator configuration dict (``num_bins`` is used).
        h5ad_path: Path to the input ``.h5ad`` file.
        label_key: Optional ``adata.obs`` column with ground-truth labels;
            when present, accuracy/F1/classification report are printed.
        gene_id_key: ``adata.var`` column holding gene identifiers.
        batch_size: Cells per inference batch.
        device: Torch device used for inference.
        model_path: Directory containing ``config.json`` (for the id->label
            mapping). Defaults to the model config's ``_name_or_path``
            attribute, falling back to the current directory.

    Returns:
        The AnnData object with predictions added in
        ``adata.obs["predicted_cell_type"]``.
    """
    # Resolve the directory holding config.json for the id2label mapping.
    if model_path is None:
        model_path = getattr(model.config, "_name_or_path", ".")
    with open(Path(model_path) / "config.json") as f:
        config = json.load(f)
    id_to_label = config["id2label"]

    adata = sc.read_h5ad(h5ad_path)
    gene_ids, expr_values = prepare_data(adata, vocab, collator_cfg, gene_id_key)

    all_preds = []
    model.to(device)
    with torch.no_grad():
        for start in tqdm(range(0, len(gene_ids), batch_size), desc="Predicting"):
            batch_genes = torch.tensor(gene_ids[start:start + batch_size], device=device)
            batch_expr = torch.tensor(expr_values[start:start + batch_size], device=device)
            outputs = model(gene_ids=batch_genes, expression_values=batch_expr)
            preds = torch.argmax(outputs.logits, dim=-1)
            all_preds.extend(preds.cpu().numpy())

    # JSON serializes integer keys as strings, hence str() for the lookup.
    predicted_labels = [id_to_label[str(p)] for p in all_preds]
    adata.obs["predicted_cell_type"] = predicted_labels

    # Report metrics when ground-truth labels are available.
    if label_key and label_key in adata.obs.columns:
        from sklearn.metrics import accuracy_score, f1_score, classification_report
        true_labels = adata.obs[label_key].values
        accuracy = accuracy_score(true_labels, predicted_labels)
        f1_macro = f1_score(true_labels, predicted_labels, average="macro")
        print(f"\nAccuracy: {accuracy:.4f}")
        print(f"F1 Macro: {f1_macro:.4f}")
        print("\n" + classification_report(true_labels, predicted_labels))
    return adata