""" Inference utilities for Tahoe cell type classifier. """ import numpy as np import scanpy as sc import torch import json from pathlib import Path from safetensors.torch import load_file from tqdm.auto import tqdm from src.tahoe_classifier.models.encoder import CellEncoder, CellEncoderConfig from src.tahoe_classifier.models.classifier import CellTypeClassifier, CellTypeClassifierConfig def load_model(model_path=".", device="cuda"): """Load trained model from directory.""" model_path = Path(model_path) # Load config with open(model_path / "config.json") as f: config = json.load(f) # Create encoder config encoder_config = CellEncoderConfig( vocab_size=60697, d_model=config["d_model"], n_layers=config["n_layers"], n_heads=config["n_heads"], expansion_ratio=config["expansion_ratio"], ) # Load base model components (vocab, collator_cfg) from src.tahoe_classifier.models.encoder import CellEncoder _, vocab, collator_cfg = CellEncoder.from_pretrained_tahoe( model_size=config["model_size"], device="cpu" ) # Create classifier classifier_config = CellTypeClassifierConfig( num_labels=config["num_labels"], encoder_config=encoder_config.to_dict(), classifier_dropout=config["classifier_dropout"], ) from src.tahoe_classifier.models.encoder import CellEncoder as Encoder encoder = Encoder(encoder_config) model = CellTypeClassifier(classifier_config, encoder=encoder) # Load merged weights state_dict = load_file(model_path / "model.safetensors") model.load_state_dict(state_dict) model.to(device) model.eval() return model, vocab, collator_cfg def prepare_data(adata, vocab, collator_cfg, gene_id_key="ensembl_id", max_length=2048): """Preprocess h5ad for inference.""" gene_col = gene_id_key if gene_id_key in adata.var.columns else "gene_id" gene_names = adata.var[gene_col].tolist() gene2idx = vocab.get_stoi() gene_ids_map = np.array([gene2idx.get(g, -1) for g in gene_names], dtype=np.int64) valid_mask = gene_ids_map >= 0 valid_indices = np.where(valid_mask)[0] gene_ids = gene_ids_map[valid_mask] n_cells = adata.n_obs seq_len = min(len(gene_ids), max_length) gene_ids_batch = np.zeros((n_cells, seq_len), dtype=np.int64) expr_batch = np.zeros((n_cells, seq_len), dtype=np.float32) for i in tqdm(range(n_cells), desc="Processing cells"): x = adata.X[i, valid_indices] x = x.toarray().flatten() if hasattr(x, "toarray") else np.array(x).flatten() indices = np.argsort(-x)[:seq_len] gene_ids_batch[i] = gene_ids[indices] expr_batch[i] = x[indices] # Binning num_bins = collator_cfg.get("num_bins", 51) expr_max = np.clip(np.max(expr_batch, axis=1, keepdims=True), 1e-6, None) expr_batch = np.clip(np.floor(expr_batch / expr_max * (num_bins - 1)), 0, num_bins - 1) return gene_ids_batch, expr_batch def predict_cell_types( model, vocab, collator_cfg, h5ad_path, label_key=None, gene_id_key="ensembl_id", batch_size=32, device="cuda" ): """Predict cell types for h5ad file.""" # Load config for label mapping with open(Path(model.config._name_or_path if hasattr(model.config, '_name_or_path') else '.') / "config.json") as f: config = json.load(f) id_to_label = config["id2label"] # Load data adata = sc.read_h5ad(h5ad_path) gene_ids, expr_values = prepare_data(adata, vocab, collator_cfg, gene_id_key) # Run inference all_preds = [] model.to(device) with torch.no_grad(): for i in tqdm(range(0, len(gene_ids), batch_size), desc="Predicting"): batch_genes = torch.tensor(gene_ids[i:i+batch_size], device=device) batch_expr = torch.tensor(expr_values[i:i+batch_size], device=device) outputs = model(gene_ids=batch_genes, expression_values=batch_expr) preds = torch.argmax(outputs.logits, dim=-1) all_preds.extend(preds.cpu().numpy()) # Decode predictions predicted_labels = [id_to_label[str(p)] for p in all_preds] # Add to adata adata.obs["predicted_cell_type"] = predicted_labels # Compute metrics if ground truth available if label_key and label_key in adata.obs.columns: from sklearn.metrics import accuracy_score, f1_score, classification_report true_labels = adata.obs[label_key].values accuracy = accuracy_score(true_labels, predicted_labels) f1_macro = f1_score(true_labels, predicted_labels, average="macro") print(f"\nAccuracy: {accuracy:.4f}") print(f"F1 Macro: {f1_macro:.4f}") print("\n" + classification_report(true_labels, predicted_labels)) return adata