# File: GeneMamba2-24l-512d/examples/downstream/12_batch_integration_eval.py
# Page-header residue from the hosting site (uploader avatar): mineself2016
# Commit message: Normalize example naming order
# Commit: d3fa071 (verified)
"""
Batch integration downstream example:
- Extract embeddings with frozen GeneMamba
- Evaluate a simple batch-mixing proxy score (silhouette score computed on batch labels)
Expected h5ad columns:
- obs['batch']
"""
import argparse
import numpy as np
import scanpy as sc
import torch
from sklearn.metrics import silhouette_score
from sklearn.preprocessing import LabelEncoder
from transformers import AutoModel
def build_ranked_input_ids(adata, symbol2id, seq_len=2048, pad_id=1):
    """Turn each cell's expression profile into a rank-ordered token sequence.

    For every observation (row of ``adata.X``), genes with strictly positive
    expression are ordered by descending expression value, mapped to token
    ids through ``symbol2id`` (genes absent from the mapping are dropped),
    truncated to ``seq_len`` tokens and right-padded with ``pad_id``.

    Genes with equal expression keep their column order from
    ``adata.var_names``: the sort is stable, so the produced ids do not
    depend on the NumPy version or CPU-specific sort kernels.

    Args:
        adata: AnnData-like object exposing ``X`` (row-indexable, dense or
            with rows offering ``toarray()``), ``var_names`` and ``n_obs``.
        symbol2id: Mapping from gene name to integer token id.
        seq_len: Number of token slots per cell.
        pad_id: Token id used for unused slots (and for cells with no
            expressed in-vocabulary gene).

    Returns:
        ``numpy.ndarray`` of shape ``(adata.n_obs, seq_len)``, dtype int64.
    """
    gene_names = np.array(adata.var_names)

    # The vocabulary lookup is independent of the cell, so resolve it once
    # per gene column instead of once per expressed gene per cell.
    in_vocab = np.array([g in symbol2id for g in gene_names], dtype=bool)
    token_ids = np.array(
        [symbol2id[g] if known else pad_id for g, known in zip(gene_names, in_vocab)],
        dtype=np.int64,
    )

    X = adata.X
    out = np.full((adata.n_obs, seq_len), pad_id, dtype=np.int64)
    for i in range(adata.n_obs):
        row = X[i]
        if hasattr(row, "toarray"):
            # Sparse row: densify just this one cell.
            expr = row.toarray().ravel()
        else:
            expr = np.asarray(row).ravel()

        # Columns that are both expressed and representable as a token.
        keep = np.flatnonzero((expr > 0) & in_vocab)
        if keep.size == 0:
            continue

        # Stable descending sort: ties stay in var_names order. Filtering
        # before a stable sort yields the same sequence as sorting first
        # and filtering afterwards.
        order = np.argsort(-expr[keep], kind="stable")
        ids = token_ids[keep[order][:seq_len]]
        out[i, : ids.size] = ids
    return out
def main():
    """Embed all cells with a frozen model and print the batch silhouette score.

    Command-line arguments:
        --model_path: Path or identifier passed to ``AutoModel.from_pretrained``.
        --h5ad: Input ``.h5ad`` file; must contain ``obs['batch']``.
        --symbol2id_npy: ``.npy`` file holding a pickled gene-name -> token-id dict.
        --seq_len: Tokens per cell (default 2048).
        --batch_size: Cells per forward pass (default 64).
        --device: Torch device string; defaults to CUDA when available, else CPU.

    Raises:
        ValueError: If ``obs['batch']`` is missing, or if the number of
            distinct batches is outside the range the silhouette score
            accepts (at least 2 and fewer than the number of cells).
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--model_path", required=True)
    parser.add_argument("--h5ad", required=True)
    parser.add_argument("--symbol2id_npy", required=True)
    parser.add_argument("--seq_len", type=int, default=2048)
    parser.add_argument("--batch_size", type=int, default=64)
    parser.add_argument(
        "--device",
        default=None,
        help="torch device (default: cuda if available, otherwise cpu)",
    )
    args = parser.parse_args()

    adata = sc.read_h5ad(args.h5ad)
    # Explicit raise instead of assert: asserts are stripped under `python -O`.
    if "batch" not in adata.obs:
        raise ValueError("h5ad must include obs['batch']")

    # Validate the labels before the expensive model load / inference so an
    # unusable dataset fails immediately rather than after embedding.
    batch_labels = LabelEncoder().fit_transform(adata.obs["batch"].values)
    n_batches = len(np.unique(batch_labels))
    if not 1 < n_batches < adata.n_obs:
        raise ValueError(
            "silhouette score needs at least 2 batches and fewer batches than "
            f"cells; got {n_batches} batch(es) for {adata.n_obs} cell(s)"
        )

    # NOTE(security): allow_pickle=True unpickles arbitrary objects and can
    # execute code; only load symbol2id files from trusted sources.
    symbol2id = np.load(args.symbol2id_npy, allow_pickle=True).item()
    input_ids = build_ranked_input_ids(adata, symbol2id, seq_len=args.seq_len)

    # Do not hard-code CUDA: fall back to CPU when no GPU is visible.
    # NOTE(review): the remote model code may itself require CUDA — confirm.
    device = torch.device(
        args.device or ("cuda" if torch.cuda.is_available() else "cpu")
    )
    model = AutoModel.from_pretrained(args.model_path, trust_remote_code=True)
    model.eval().to(device)

    embeds = []
    with torch.no_grad():
        for s in range(0, input_ids.shape[0], args.batch_size):
            batch = torch.tensor(
                input_ids[s : s + args.batch_size], dtype=torch.long, device=device
            )
            out = model(batch)
            # Move each chunk to host memory right away to keep device usage flat.
            embeds.append(out.pooled_embedding.detach().cpu().numpy())
    embeds = np.concatenate(embeds, axis=0)

    score = silhouette_score(embeds, batch_labels, metric="euclidean")
    print("silhouette_by_batch:", score)
    print("(Closer to 0 typically indicates better batch mixing than very high positive values.)")
# Run the evaluation only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()