|
|
""" |
|
|
Inference utilities for Tahoe cell type classifier. |
|
|
""" |
|
|
|
|
|
import numpy as np |
|
|
import scanpy as sc |
|
|
import torch |
|
|
import json |
|
|
from pathlib import Path |
|
|
from safetensors.torch import load_file |
|
|
from tqdm.auto import tqdm |
|
|
|
|
|
from src.tahoe_classifier.models.encoder import CellEncoder, CellEncoderConfig |
|
|
from src.tahoe_classifier.models.classifier import CellTypeClassifier, CellTypeClassifierConfig |
|
|
|
|
|
|
|
|
def load_model(model_path=".", device="cuda"):
    """Load a trained cell type classifier from a directory.

    Args:
        model_path: Directory containing ``config.json`` and
            ``model.safetensors``.
        device: Torch device to place the model on (e.g. ``"cuda"``, ``"cpu"``).

    Returns:
        Tuple ``(model, vocab, collator_cfg)`` where ``model`` is in eval mode
        on ``device``.
    """
    model_path = Path(model_path)

    with open(model_path / "config.json") as f:
        config = json.load(f)

    encoder_config = CellEncoderConfig(
        # Tahoe gene vocabulary size; allow an explicit override from config.
        vocab_size=config.get("vocab_size", 60697),
        d_model=config["d_model"],
        n_layers=config["n_layers"],
        n_heads=config["n_heads"],
        expansion_ratio=config["expansion_ratio"],
    )

    # Fetch the vocabulary and collator settings from the pretrained Tahoe
    # checkpoint; the pretrained weights are discarded (fine-tuned weights
    # are loaded from model.safetensors below).
    _, vocab, collator_cfg = CellEncoder.from_pretrained_tahoe(
        model_size=config["model_size"],
        device="cpu",
    )

    classifier_config = CellTypeClassifierConfig(
        num_labels=config["num_labels"],
        encoder_config=encoder_config.to_dict(),
        classifier_dropout=config["classifier_dropout"],
    )

    encoder = CellEncoder(encoder_config)
    model = CellTypeClassifier(classifier_config, encoder=encoder)

    state_dict = load_file(model_path / "model.safetensors")
    model.load_state_dict(state_dict)

    model.to(device)
    model.eval()

    return model, vocab, collator_cfg
|
|
|
|
|
|
|
|
def prepare_data(adata, vocab, collator_cfg, gene_id_key="ensembl_id", max_length=2048):
    """Preprocess an AnnData object for inference.

    Maps genes to vocabulary token ids, keeps the top ``max_length`` most
    expressed genes per cell (descending by expression), and bins expression
    into ``num_bins`` discrete levels normalized by each cell's maximum.

    Args:
        adata: AnnData-like object with gene metadata in ``var`` and an
            expression matrix in ``X`` (dense or sparse rows).
        vocab: Vocabulary exposing ``get_stoi()`` (gene name -> token id).
        collator_cfg: Dict-like config; ``num_bins`` is read (default 51).
        gene_id_key: Column in ``adata.var`` holding gene identifiers.
        max_length: Maximum number of genes kept per cell.

    Returns:
        Tuple ``(gene_ids_batch, expr_batch)``, both of shape
        ``(n_cells, seq_len)`` with ``seq_len <= max_length``.

    Raises:
        KeyError: If neither ``gene_id_key`` nor ``"gene_id"`` is in
            ``adata.var``.
        ValueError: If no genes map into the vocabulary.
    """
    # Fall back to the conventional "gene_id" column when the requested
    # key is absent; fail loudly if neither exists.
    gene_col = gene_id_key if gene_id_key in adata.var.columns else "gene_id"
    if gene_col not in adata.var.columns:
        raise KeyError(
            f"Neither '{gene_id_key}' nor 'gene_id' found in adata.var columns"
        )
    gene_names = adata.var[gene_col].tolist()

    # Map gene names to vocab ids; -1 marks out-of-vocabulary genes.
    gene2idx = vocab.get_stoi()
    gene_ids_map = np.array([gene2idx.get(g, -1) for g in gene_names], dtype=np.int64)
    valid_mask = gene_ids_map >= 0
    valid_indices = np.where(valid_mask)[0]
    gene_ids = gene_ids_map[valid_mask]
    if gene_ids.size == 0:
        # Without this guard the function would silently return (n_cells, 0)
        # arrays and crash downstream with an unrelated error.
        raise ValueError("No genes in adata matched the model vocabulary")

    n_cells = adata.n_obs
    seq_len = min(len(gene_ids), max_length)

    gene_ids_batch = np.zeros((n_cells, seq_len), dtype=np.int64)
    expr_batch = np.zeros((n_cells, seq_len), dtype=np.float32)

    for i in tqdm(range(n_cells), desc="Processing cells"):
        x = adata.X[i, valid_indices]
        # Sparse rows expose .toarray(); dense rows are coerced to 1-D arrays.
        x = x.toarray().flatten() if hasattr(x, "toarray") else np.array(x).flatten()
        # Top seq_len genes by expression, highest first.
        indices = np.argsort(-x)[:seq_len]
        gene_ids_batch[i] = gene_ids[indices]
        expr_batch[i] = x[indices]

    # Per-cell max-normalized binning into num_bins discrete levels;
    # the 1e-6 clip avoids division by zero for all-zero cells.
    num_bins = collator_cfg.get("num_bins", 51)
    expr_max = np.clip(np.max(expr_batch, axis=1, keepdims=True), 1e-6, None)
    expr_batch = np.clip(np.floor(expr_batch / expr_max * (num_bins - 1)), 0, num_bins - 1)

    return gene_ids_batch, expr_batch
|
|
|
|
|
|
|
|
def predict_cell_types(
    model, vocab, collator_cfg,
    h5ad_path, label_key=None, gene_id_key="ensembl_id",
    batch_size=32, device="cuda"
):
    """Predict cell types for every cell in an h5ad file.

    Args:
        model: Trained ``CellTypeClassifier`` (as returned by ``load_model``).
        vocab: Gene vocabulary matching the model.
        collator_cfg: Collator configuration (provides ``num_bins``).
        h5ad_path: Path to the input ``.h5ad`` file.
        label_key: Optional ``adata.obs`` column with ground-truth labels;
            when present, accuracy / macro-F1 / a classification report are
            printed.
        gene_id_key: Column in ``adata.var`` holding gene identifiers.
        batch_size: Number of cells per forward pass.
        device: Torch device used for inference.

    Returns:
        The AnnData object with predictions stored in
        ``adata.obs["predicted_cell_type"]``.
    """
    # Resolve the id -> label map. Primary source is config.json next to the
    # model (legacy behavior); if that file is missing, fall back to any
    # id2label attached to the model config instead of hard-crashing.
    config_dir = Path(getattr(model.config, "_name_or_path", ".") or ".")
    try:
        with open(config_dir / "config.json") as f:
            id_to_label = json.load(f)["id2label"]
    except (OSError, KeyError):
        id_to_label = getattr(model.config, "id2label", None)
        if id_to_label is None:
            raise
    # Normalize keys to strings so lookups work for both JSON ({"0": ...})
    # and in-memory ({0: ...}) label maps.
    id_to_label = {str(k): v for k, v in id_to_label.items()}

    adata = sc.read_h5ad(h5ad_path)
    gene_ids, expr_values = prepare_data(adata, vocab, collator_cfg, gene_id_key)

    all_preds = []
    model.to(device)
    model.eval()  # ensure dropout is disabled even if the caller skipped it

    with torch.no_grad():
        for start in tqdm(range(0, len(gene_ids), batch_size), desc="Predicting"):
            # from_numpy avoids an extra host-side copy before the device move.
            batch_genes = torch.from_numpy(gene_ids[start:start + batch_size]).to(device)
            batch_expr = torch.from_numpy(expr_values[start:start + batch_size]).to(device)

            outputs = model(gene_ids=batch_genes, expression_values=batch_expr)
            preds = torch.argmax(outputs.logits, dim=-1)
            all_preds.extend(preds.cpu().numpy())

    predicted_labels = [id_to_label[str(p)] for p in all_preds]

    adata.obs["predicted_cell_type"] = predicted_labels

    # Optional evaluation against ground-truth labels.
    if label_key and label_key in adata.obs.columns:
        from sklearn.metrics import accuracy_score, f1_score, classification_report

        true_labels = adata.obs[label_key].values
        accuracy = accuracy_score(true_labels, predicted_labels)
        f1_macro = f1_score(true_labels, predicted_labels, average="macro")

        print(f"\nAccuracy: {accuracy:.4f}")
        print(f"F1 Macro: {f1_macro:.4f}")
        print("\n" + classification_report(true_labels, predicted_labels))

    return adata
|
|
|