File size: 4,940 Bytes
3147616
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
"""
Inference utilities for Tahoe cell type classifier.
"""

import numpy as np
import scanpy as sc
import torch
import json
from pathlib import Path
from safetensors.torch import load_file
from tqdm.auto import tqdm

from src.tahoe_classifier.models.encoder import CellEncoder, CellEncoderConfig
from src.tahoe_classifier.models.classifier import CellTypeClassifier, CellTypeClassifierConfig


def load_model(model_path=".", device="cuda"):
    """Load a trained cell type classifier from a directory.

    Args:
        model_path: Directory containing ``config.json`` and ``model.safetensors``.
        device: Torch device to move the model to (e.g. "cuda" or "cpu").

    Returns:
        Tuple ``(model, vocab, collator_cfg)``: the eval-mode
        ``CellTypeClassifier``, plus the vocabulary and collator config taken
        from the pretrained Tahoe base model.
    """
    model_path = Path(model_path)

    # Load the training-time configuration saved next to the weights.
    with open(model_path / "config.json") as f:
        config = json.load(f)

    # Rebuild the encoder architecture from the saved hyperparameters.
    # NOTE(review): 60697 is the historical hard-coded vocab size; a
    # "vocab_size" entry in config.json takes precedence when present —
    # confirm it matches the Tahoe vocabulary.
    encoder_config = CellEncoderConfig(
        vocab_size=config.get("vocab_size", 60697),
        d_model=config["d_model"],
        n_layers=config["n_layers"],
        n_heads=config["n_heads"],
        expansion_ratio=config["expansion_ratio"],
    )

    # Fetch vocab and collator config from the pretrained base model; the
    # base weights are discarded since the fine-tuned ones are loaded below.
    # (CellEncoder is already imported at module top — no local re-import.)
    _, vocab, collator_cfg = CellEncoder.from_pretrained_tahoe(
        model_size=config["model_size"],
        device="cpu",
    )

    # Assemble the classifier head on top of a freshly constructed encoder.
    classifier_config = CellTypeClassifierConfig(
        num_labels=config["num_labels"],
        encoder_config=encoder_config.to_dict(),
        classifier_dropout=config["classifier_dropout"],
    )
    encoder = CellEncoder(encoder_config)
    model = CellTypeClassifier(classifier_config, encoder=encoder)

    # Load the merged (encoder + head) fine-tuned weights.
    state_dict = load_file(model_path / "model.safetensors")
    model.load_state_dict(state_dict)

    model.to(device)
    model.eval()

    return model, vocab, collator_cfg


def prepare_data(adata, vocab, collator_cfg, gene_id_key="ensembl_id", max_length=2048):
    """Convert an AnnData object into padded token/expression arrays for inference.

    Args:
        adata: AnnData-like object with gene metadata in ``var`` and a
            cells-by-genes matrix in ``X`` (dense or sparse rows).
        vocab: Vocabulary exposing ``get_stoi()`` mapping gene id -> token index.
        collator_cfg: Dict-like config; ``num_bins`` (default 51) controls
            expression binning.
        gene_id_key: Column in ``adata.var`` holding gene identifiers; falls
            back to ``"gene_id"`` when absent.
        max_length: Maximum number of genes kept per cell.

    Returns:
        Tuple ``(gene_ids_batch, expr_batch)`` of shape ``(n_cells, seq_len)``:
        int64 token ids and float32 binned expression values.

    Raises:
        ValueError: If neither ``gene_id_key`` nor ``"gene_id"`` is a column
            of ``adata.var``.
    """
    # Resolve the gene-identifier column, failing loudly instead of with an
    # opaque KeyError further down.
    if gene_id_key in adata.var.columns:
        gene_col = gene_id_key
    elif "gene_id" in adata.var.columns:
        gene_col = "gene_id"
    else:
        raise ValueError(
            f"Neither {gene_id_key!r} nor 'gene_id' found in adata.var columns"
        )
    gene_names = adata.var[gene_col].tolist()

    # Map genes to vocabulary indices; drop genes unknown to the vocab.
    gene2idx = vocab.get_stoi()
    gene_ids_map = np.array([gene2idx.get(g, -1) for g in gene_names], dtype=np.int64)
    valid_mask = gene_ids_map >= 0
    valid_indices = np.where(valid_mask)[0]
    gene_ids = gene_ids_map[valid_mask]

    n_cells = adata.n_obs
    seq_len = min(len(gene_ids), max_length)

    gene_ids_batch = np.zeros((n_cells, seq_len), dtype=np.int64)
    expr_batch = np.zeros((n_cells, seq_len), dtype=np.float32)

    # No overlap between adata genes and the vocab: nothing to tokenize, and
    # the binning step below would crash on an empty axis — return early.
    if seq_len == 0:
        return gene_ids_batch, expr_batch

    # For each cell keep the seq_len most highly expressed valid genes.
    for i in tqdm(range(n_cells), desc="Processing cells"):
        x = adata.X[i, valid_indices]
        # Sparse rows expose .toarray(); dense rows are handled uniformly.
        x = x.toarray().flatten() if hasattr(x, "toarray") else np.array(x).flatten()
        indices = np.argsort(-x)[:seq_len]
        gene_ids_batch[i] = gene_ids[indices]
        expr_batch[i] = x[indices]

    # Per-cell max binning: scale each cell by its own maximum expression and
    # discretize into num_bins integer levels (floor keeps zeros in bin 0; the
    # 1e-6 clip guards against division by zero for all-zero cells).
    num_bins = collator_cfg.get("num_bins", 51)
    expr_max = np.clip(np.max(expr_batch, axis=1, keepdims=True), 1e-6, None)
    expr_batch = np.clip(np.floor(expr_batch / expr_max * (num_bins - 1)), 0, num_bins - 1)

    return gene_ids_batch, expr_batch


def predict_cell_types(
    model, vocab, collator_cfg,
    h5ad_path, label_key=None, gene_id_key="ensembl_id",
    batch_size=32, device="cuda", model_path=None,
):
    """Predict cell types for every cell in an h5ad file.

    Args:
        model: Trained ``CellTypeClassifier`` (as returned by ``load_model``).
        vocab: Vocabulary used to tokenize genes.
        collator_cfg: Collator configuration dict (binning parameters).
        h5ad_path: Path to the input ``.h5ad`` file.
        label_key: Optional ``adata.obs`` column holding ground-truth labels;
            when present, accuracy / macro-F1 / a classification report are
            printed to stdout.
        gene_id_key: Column in ``adata.var`` holding gene identifiers.
        batch_size: Number of cells per forward pass.
        device: Torch device used for inference.
        model_path: Directory containing ``config.json`` (for the id2label
            mapping). Defaults to ``model.config._name_or_path`` when set,
            otherwise the current directory.

    Returns:
        The AnnData object with predictions added in
        ``obs["predicted_cell_type"]``.
    """
    # Resolve where config.json lives. An explicit model_path argument
    # overrides the (fragile) discovery via model.config._name_or_path.
    if model_path is None:
        model_path = getattr(model.config, "_name_or_path", ".") or "."
    with open(Path(model_path) / "config.json") as f:
        config = json.load(f)

    # JSON object keys are strings, so predictions are looked up via str(id).
    id_to_label = config["id2label"]

    # Load the data and tokenize it for the encoder.
    adata = sc.read_h5ad(h5ad_path)
    gene_ids, expr_values = prepare_data(adata, vocab, collator_cfg, gene_id_key)

    # Batched inference without gradient tracking.
    all_preds = []
    model.to(device)

    with torch.no_grad():
        for start in tqdm(range(0, len(gene_ids), batch_size), desc="Predicting"):
            batch_genes = torch.tensor(gene_ids[start:start + batch_size], device=device)
            batch_expr = torch.tensor(expr_values[start:start + batch_size], device=device)

            outputs = model(gene_ids=batch_genes, expression_values=batch_expr)
            preds = torch.argmax(outputs.logits, dim=-1)
            # .tolist() yields plain Python ints, so str(p) below always
            # matches the JSON string keys (numpy scalar str() is avoided).
            all_preds.extend(preds.cpu().tolist())

    # Decode integer predictions to label strings and attach to the AnnData.
    predicted_labels = [id_to_label[str(p)] for p in all_preds]
    adata.obs["predicted_cell_type"] = predicted_labels

    # Report metrics when ground truth is available.
    if label_key and label_key in adata.obs.columns:
        from sklearn.metrics import accuracy_score, f1_score, classification_report

        true_labels = adata.obs[label_key].values
        accuracy = accuracy_score(true_labels, predicted_labels)
        f1_macro = f1_score(true_labels, predicted_labels, average="macro")

        print(f"\nAccuracy: {accuracy:.4f}")
        print(f"F1 Macro: {f1_macro:.4f}")
        print("\n" + classification_report(true_labels, predicted_labels))

    return adata