| """ |
| Batch integration downstream example: |
| - Extract embeddings with frozen GeneMamba |
| - Evaluate simple batch mixing score proxy (silhouette by batch) |
| |
| Expected h5ad columns: |
| - obs['batch'] |
| """ |
|
|
| import argparse |
| import numpy as np |
| import scanpy as sc |
| import torch |
| from sklearn.metrics import silhouette_score |
| from sklearn.preprocessing import LabelEncoder |
| from transformers import AutoModel |
|
|
|
|
def build_ranked_input_ids(adata, symbol2id, seq_len=2048, pad_id=1):
    """Encode each cell as a sequence of vocabulary ids ranked by expression.

    For every cell (row of ``adata.X``), genes with expression > 0 that are
    present in ``symbol2id`` are ordered from highest to lowest expression
    and mapped to their ids. The first ``seq_len`` ids fill the row and the
    remaining positions hold ``pad_id``. Genes with equal expression keep
    the order in which they appear in ``adata.var_names``.

    Args:
        adata: AnnData-like object exposing ``X`` (dense array/matrix or a
            scipy sparse matrix, one row per cell), ``var_names`` and
            ``n_obs``.
        symbol2id: Mapping from gene name (as in ``var_names``) to integer
            token id. Genes missing from the mapping are skipped.
        seq_len: Number of token positions per cell.
        pad_id: Id written to positions not filled by ranked genes.

    Returns:
        ``np.ndarray`` of shape ``(adata.n_obs, seq_len)``, dtype int64.
        Cells with no expressed in-vocabulary gene are entirely ``pad_id``.
    """
    gene_names = list(adata.var_names)
    # Resolve vocabulary membership and ids once per gene rather than once
    # per (cell, expressed gene) pair inside the loop below.
    in_vocab = np.array([g in symbol2id for g in gene_names], dtype=bool)
    gene_ids = np.array(
        [symbol2id[g] if g in symbol2id else pad_id for g in gene_names],
        dtype=np.int64,
    )

    X = adata.X
    out = np.full((adata.n_obs, seq_len), pad_id, dtype=np.int64)

    for i in range(adata.n_obs):
        row = X[i]
        # Sparse rows must be densified; dense rows may come back as 1xN
        # matrices, so flatten both to a 1-D array.
        if hasattr(row, "toarray"):
            expr = row.toarray().ravel()
        else:
            expr = np.asarray(row).ravel()

        keep = np.flatnonzero((expr > 0) & in_vocab)
        if keep.size == 0:
            continue

        # Cast before negating: unary minus is not supported on bool arrays.
        # A stable sort makes tie-breaking deterministic (var_names order).
        order = np.argsort(-expr[keep].astype(np.float64), kind="stable")
        ids = gene_ids[keep[order[:seq_len]]]
        out[i, : ids.size] = ids

    return out
|
|
|
|
def _embed_cells(model, input_ids, batch_size, device):
    """Run ``model`` over ``input_ids`` in mini-batches; return stacked ``pooled_embedding`` arrays."""
    chunks = []
    with torch.no_grad():
        for start in range(0, input_ids.shape[0], batch_size):
            batch = torch.as_tensor(
                input_ids[start : start + batch_size], dtype=torch.long, device=device
            )
            # NOTE(review): no attention mask is passed, so pad positions are
            # fed to the model as ordinary tokens — confirm this is intended.
            out = model(batch)
            # presumably (batch, dim); .float() guards against half/bfloat16
            # outputs, which .numpy() cannot convert.
            chunks.append(out.pooled_embedding.float().cpu().numpy())
    return np.concatenate(chunks, axis=0)


def main():
    """Embed cells with a frozen GeneMamba model and print silhouette-by-batch.

    Reads an .h5ad file that must contain ``obs['batch']``, ranks each
    cell's expressed genes into token ids, takes the model's
    ``pooled_embedding`` per cell, and prints the silhouette score of those
    embeddings using batch labels as clusters.

    Raises:
        ValueError: if ``obs['batch']`` is missing, or if the number of
            distinct batches is outside ``[2, n_cells - 1]`` (the range
            accepted by ``silhouette_score``).
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--model_path", required=True)
    parser.add_argument("--h5ad", required=True)
    parser.add_argument("--symbol2id_npy", required=True)
    parser.add_argument("--seq_len", type=int, default=2048)
    parser.add_argument("--batch_size", type=int, default=64)
    parser.add_argument(
        "--device",
        default=None,
        help="Torch device for inference (default: cuda if available, else cpu).",
    )
    parser.add_argument(
        "--silhouette_sample_size",
        type=int,
        default=None,
        help="Compute the silhouette on a random subset of this many cells "
        "(default: all cells; the full computation is quadratic in n_cells).",
    )
    args = parser.parse_args()

    if args.batch_size < 1:
        parser.error("--batch_size must be >= 1")
    if args.seq_len < 1:
        parser.error("--seq_len must be >= 1")

    adata = sc.read_h5ad(args.h5ad)
    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if "batch" not in adata.obs:
        raise ValueError("h5ad must include obs['batch']")

    # Validate batch labels before the expensive model load and inference:
    # silhouette_score requires 2 <= n_labels <= n_samples - 1.
    batch_labels = LabelEncoder().fit_transform(adata.obs["batch"].values)
    n_batches = len(np.unique(batch_labels))
    if not 2 <= n_batches <= adata.n_obs - 1:
        raise ValueError(
            f"silhouette by batch needs 2 to n_cells - 1 distinct batches; "
            f"got {n_batches} batch(es) for {adata.n_obs} cell(s)"
        )

    # NOTE: allow_pickle=True unpickles the file — only load trusted vocab files.
    symbol2id = np.load(args.symbol2id_npy, allow_pickle=True).item()
    input_ids = build_ranked_input_ids(adata, symbol2id, seq_len=args.seq_len)

    device = args.device or ("cuda" if torch.cuda.is_available() else "cpu")
    # NOTE: trust_remote_code=True executes code shipped with the checkpoint.
    model = AutoModel.from_pretrained(args.model_path, trust_remote_code=True)
    model.eval().to(device)

    embeds = _embed_cells(model, input_ids, args.batch_size, device)

    score = silhouette_score(
        embeds,
        batch_labels,
        metric="euclidean",
        sample_size=args.silhouette_sample_size,
        random_state=0,
    )

    print("silhouette_by_batch:", score)
    print("(Closer to 0 typically indicates better batch mixing than very high positive values.)")


if __name__ == "__main__":
    main()
|
|