"""
Phase 1: Extract Cell Embeddings

Demonstrates how to load GeneMamba and extract cell embeddings for
downstream analysis.

Usage:
    python examples/1_extract_embeddings.py
"""
import torch
import numpy as np
from transformers import AutoTokenizer, AutoModel
def main():
    """Run the Phase 1 demo: load GeneMamba, extract pooled cell embeddings,
    and showcase downstream analyses (clustering, PCA, similarity search).

    Falls back to a randomly initialized model when no local checkpoint is
    available, so the script always runs end-to-end.

    Returns:
        tuple[torch.nn.Module, torch.Tensor]: the loaded model and a
        (batch_size, hidden_size) tensor of pooled cell embeddings.
    """
    print("=" * 80)
    print("GeneMamba Phase 1: Extract Cell Embeddings")
    print("=" * 80)

    # ============================================================
    # Step 1: Load pretrained model and tokenizer
    # ============================================================
    print("\n[Step 1] Loading model and tokenizer...")

    # For this example, we use a local model path.
    # In practice, you would use: "username/GeneMamba-24l-512d"
    model_name = "GeneMamba-24l-512d"  # Change to HF Hub path when available

    try:
        tokenizer = AutoTokenizer.from_pretrained(
            model_name,
            trust_remote_code=True,
            local_files_only=True,  # Try local first
        )
        model = AutoModel.from_pretrained(
            model_name,
            trust_remote_code=True,
            local_files_only=True,
        )
    except Exception as e:
        print(f"Note: Could not load from '{model_name}': {e}")
        print("Using mock data for demonstration...")
        # No checkpoint available: build a randomly initialized model from the
        # local config/modeling files so the demo still runs end-to-end.
        from configuration_genemamba import GeneMambaConfig
        from modeling_genemamba import GeneMambaModel

        config = GeneMambaConfig(
            vocab_size=25426,
            hidden_size=512,
            num_hidden_layers=24,
            embedding_pooling="mean",
        )
        model = GeneMambaModel(config)
        tokenizer = None  # Not used below; inputs are raw token ids.

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = model.to(device)
    model.eval()  # Disable dropout etc. for deterministic inference.
    print(f"✓ Model loaded on device: {device}")
    print(f"✓ Model config: hidden_size={model.config.hidden_size}, "
          f"num_layers={model.config.num_hidden_layers}")

    # ============================================================
    # Step 2: Prepare simulated single-cell data
    # ============================================================
    print("\n[Step 2] Preparing sample data...")

    batch_size = 8
    seq_len = 2048
    vocab_size = 25426

    # Simulate ranked gene sequences.
    # In practice, this would come from your scRNA-seq data;
    # genes should be ranked by expression (highest first).
    # Ids start at 2, presumably to skip special tokens — TODO confirm.
    input_ids = torch.randint(2, vocab_size, (batch_size, seq_len)).to(device)
    print("✓ Created sample input:")
    print(f"  - Batch size: {batch_size}")
    print(f"  - Sequence length: {seq_len}")
    print(f"  - Input shape: {input_ids.shape}")

    # ============================================================
    # Step 3: Inference - Extract embeddings
    # ============================================================
    print("\n[Step 3] Extracting cell embeddings...")

    with torch.no_grad():
        outputs = model(input_ids, output_hidden_states=False)
        # The pooled embedding is the per-cell representation.
        cell_embeddings = outputs.pooled_embedding

    print("✓ Extraction complete!")
    print(f"  - Cell embeddings shape: {cell_embeddings.shape}")
    print(f"  - Pooling method used: {outputs.embedding_pooling}")
    print(f"  - Embedding type: {cell_embeddings.dtype}")

    # ============================================================
    # Step 4: Example downstream analyses
    # ============================================================
    print("\n[Step 4] Example downstream uses...")

    # Move to CPU/numpy once; reused by every sklearn example below.
    embeddings_np = cell_embeddings.cpu().numpy()

    # Example 1: Clustering (KMeans)
    from sklearn.cluster import KMeans

    n_clusters = 3
    # random_state pinned so the demo output is reproducible run-to-run.
    kmeans = KMeans(n_clusters=n_clusters, n_init=10, random_state=0)
    clusters = kmeans.fit_predict(embeddings_np)
    print(f"✓ Clustering: Assigned {len(np.unique(clusters))} clusters")

    # Example 2: Dimensionality reduction (PCA)
    from sklearn.decomposition import PCA

    pca = PCA(n_components=2)
    embedding_2d = pca.fit_transform(embeddings_np)
    print(f"✓ PCA reduction: {cell_embeddings.shape} → {embedding_2d.shape}")

    # Example 3: Similarity search — find the cell most similar to cell 0.
    similarities = torch.nn.functional.cosine_similarity(
        cell_embeddings[0:1],
        cell_embeddings,
    )
    # BUG FIX: exclude cell 0 itself — its self-similarity is 1.0, so a plain
    # argmax over all cells would always report "cell 0" as the answer.
    most_similar_idx = torch.argmax(similarities[1:]).item() + 1
    print(f"✓ Similarity search: Most similar cell to cell 0 is cell {most_similar_idx} "
          f"(similarity: {similarities[most_similar_idx]:.4f})")

    # Example 4: Statistics
    print("\n[Step 5] Embedding statistics:")
    print(f"  - Mean: {cell_embeddings.mean(dim=0).norm():.4f}")
    print(f"  - Std: {cell_embeddings.std(dim=0).mean():.4f}")
    print(f"  - Min: {cell_embeddings.min():.4f}")
    print(f"  - Max: {cell_embeddings.max():.4f}")

    # ============================================================
    # Step 6: Save embeddings (optional)
    # ============================================================
    print("\n[Step 6] Saving embeddings...")
    np.save("cell_embeddings.npy", embeddings_np)
    print("✓ Embeddings saved to 'cell_embeddings.npy'")

    print("\n" + "=" * 80)
    print("Phase 1 Complete!")
    print("=" * 80)

    return model, cell_embeddings
# Script entry point: run the demo and keep the model and embeddings bound
# at module level for interactive inspection (e.g. in a REPL session).
if __name__ == "__main__":
    model, embeddings = main()