File size: 2,518 Bytes
d3fa071
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
"""
Batch integration downstream example:
- Extract embeddings with frozen GeneMamba
- Evaluate simple batch mixing score proxy (silhouette by batch)

Expected h5ad columns:
- obs['batch']
"""

import argparse
import numpy as np
import scanpy as sc
import torch
from sklearn.metrics import silhouette_score
from sklearn.preprocessing import LabelEncoder
from transformers import AutoModel


def build_ranked_input_ids(adata, symbol2id, seq_len=2048, pad_id=1):
    """Convert an expression matrix into rank-ordered gene-token ids.

    For each cell, genes with nonzero expression are sorted by descending
    expression value and mapped to vocabulary ids via ``symbol2id``; genes
    absent from the vocabulary are dropped. Each row is right-padded with
    ``pad_id`` up to ``seq_len``.

    Args:
        adata: AnnData-like object exposing ``.X``, ``.var_names``, ``.n_obs``.
        symbol2id: dict mapping gene symbol -> integer token id.
        seq_len: fixed output sequence length per cell.
        pad_id: token id used to pad short rows.

    Returns:
        np.ndarray of shape ``(n_obs, seq_len)``, dtype int64.
    """
    gene_names = np.asarray(adata.var_names)
    # Hoist vocabulary lookups out of the per-cell loop: one validity flag
    # and one id per gene column, instead of a dict probe per nonzero entry.
    known = np.array([g in symbol2id for g in gene_names], dtype=bool)
    ids_by_col = np.array(
        [symbol2id[g] if ok else pad_id for g, ok in zip(gene_names, known)],
        dtype=np.int64,
    )

    X = adata.X
    out = np.full((adata.n_obs, seq_len), pad_id, dtype=np.int64)

    for i in range(adata.n_obs):
        row = X[i]
        # Sparse rows (e.g. CSR) expose .toarray(); dense rows are used as-is.
        if hasattr(row, "toarray"):
            expr = row.toarray().ravel()
        else:
            expr = np.asarray(row).ravel()

        nz = np.where(expr > 0)[0]
        if nz.size == 0:
            continue  # all-zero cell: leave the row fully padded

        # Stable sort makes tie ordering deterministic across runs
        # (plain quicksort argsort orders equal values arbitrarily).
        order = np.argsort(-expr[nz], kind="stable")
        ranked = nz[order]
        ranked = ranked[known[ranked]]  # drop genes missing from the vocabulary
        ranked = ranked[:seq_len]
        out[i, : ranked.size] = ids_by_col[ranked]

    return out


def main():
    """CLI entry point: embed cells with a frozen GeneMamba model and print a
    silhouette score over batch labels as a batch-mixing proxy.

    Reads an .h5ad file (must contain obs['batch']) and a gene-symbol-to-id
    mapping saved with np.save, builds rank-ordered input ids, runs the model
    in inference mode, and scores the pooled embeddings.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--model_path", required=True)
    parser.add_argument("--h5ad", required=True)
    parser.add_argument("--symbol2id_npy", required=True)
    parser.add_argument("--seq_len", type=int, default=2048)
    parser.add_argument("--batch_size", type=int, default=64)
    args = parser.parse_args()

    adata = sc.read_h5ad(args.h5ad)
    # Raise (not assert) so the check survives `python -O`.
    if "batch" not in adata.obs:
        raise ValueError("h5ad must include obs['batch']")

    # The mapping was saved via np.save on a dict; .item() unwraps the 0-d array.
    symbol2id = np.load(args.symbol2id_npy, allow_pickle=True).item()
    input_ids = build_ranked_input_ids(adata, symbol2id, seq_len=args.seq_len)

    model = AutoModel.from_pretrained(args.model_path, trust_remote_code=True)
    # Fall back to CPU when no GPU is present instead of crashing on .cuda().
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.eval().to(device)

    embeds = []
    with torch.no_grad():
        for s in range(0, input_ids.shape[0], args.batch_size):
            # from_numpy shares memory with the int64 slice (no extra copy
            # before the device transfer), unlike torch.tensor(...).
            batch = torch.from_numpy(input_ids[s : s + args.batch_size]).to(device)
            out = model(batch)
            embeds.append(out.pooled_embedding.detach().cpu().numpy())
    embeds = np.concatenate(embeds, axis=0)

    batch_labels = LabelEncoder().fit_transform(adata.obs["batch"].values)
    score = silhouette_score(embeds, batch_labels, metric="euclidean")

    print("silhouette_by_batch:", score)
    print("(Closer to 0 typically indicates better batch mixing than very high positive values.)")


# Run the CLI only when executed as a script, not on import.
if __name__ == "__main__":
    main()