""" Batch integration downstream example: - Extract embeddings with frozen GeneMamba - Evaluate simple batch mixing score proxy (silhouette by batch) Expected h5ad columns: - obs['batch'] """ import argparse import numpy as np import scanpy as sc import torch from sklearn.metrics import silhouette_score from sklearn.preprocessing import LabelEncoder from transformers import AutoModel def build_ranked_input_ids(adata, symbol2id, seq_len=2048, pad_id=1): gene_names = np.array(adata.var_names) X = adata.X out = np.full((adata.n_obs, seq_len), pad_id, dtype=np.int64) for i in range(adata.n_obs): row = X[i] if hasattr(row, "toarray"): expr = row.toarray().ravel() else: expr = np.asarray(row).ravel() nz = np.where(expr > 0)[0] if len(nz) == 0: continue genes = gene_names[nz] vals = expr[nz] order = np.argsort(-vals) ranked_genes = genes[order] ids = [symbol2id[g] for g in ranked_genes if g in symbol2id][:seq_len] out[i, : len(ids)] = ids return out def main(): parser = argparse.ArgumentParser() parser.add_argument("--model_path", required=True) parser.add_argument("--h5ad", required=True) parser.add_argument("--symbol2id_npy", required=True) parser.add_argument("--seq_len", type=int, default=2048) parser.add_argument("--batch_size", type=int, default=64) args = parser.parse_args() adata = sc.read_h5ad(args.h5ad) assert "batch" in adata.obs, "h5ad must include obs['batch']" symbol2id = np.load(args.symbol2id_npy, allow_pickle=True).item() input_ids = build_ranked_input_ids(adata, symbol2id, seq_len=args.seq_len) model = AutoModel.from_pretrained(args.model_path, trust_remote_code=True) model.eval().cuda() embeds = [] with torch.no_grad(): for s in range(0, input_ids.shape[0], args.batch_size): batch = torch.tensor(input_ids[s : s + args.batch_size], dtype=torch.long, device="cuda") out = model(batch) embeds.append(out.pooled_embedding.detach().cpu().numpy()) embeds = np.concatenate(embeds, axis=0) batch_labels = LabelEncoder().fit_transform(adata.obs["batch"].values) score = silhouette_score(embeds, batch_labels, metric="euclidean") print("silhouette_by_batch:", score) print("(Closer to 0 typically indicates better batch mixing than very high positive values.)") if __name__ == "__main__": main()