# GeneMamba / examples / 1_extract_embeddings.py
# (from Hugging Face upload "Upload GeneMamba model", commit 54cd552)
"""
Phase 1: Extract Cell Embeddings
Demonstrates how to load GeneMamba and extract cell embeddings for downstream analysis.
Usage:
python examples/1_extract_embeddings.py
"""
import torch
import numpy as np
from transformers import AutoTokenizer, AutoModel
def main():
    """Run the Phase 1 demo: load GeneMamba and extract cell embeddings.

    Loads the pretrained model and tokenizer (falling back to a randomly
    initialized model when no local checkpoint is available), runs a batch
    of simulated ranked-gene sequences through the model, and demonstrates
    downstream uses of the pooled embeddings (clustering, PCA, similarity
    search, summary statistics). The embeddings are saved to
    'cell_embeddings.npy' in the current directory.

    Returns:
        tuple: ``(model, cell_embeddings)`` where ``cell_embeddings`` is a
        ``(batch_size, hidden_size)`` tensor of pooled cell representations.
    """
    print("=" * 80)
    print("GeneMamba Phase 1: Extract Cell Embeddings")
    print("=" * 80)

    # ============================================================
    # Step 1: Load pretrained model and tokenizer
    # ============================================================
    print("\n[Step 1] Loading model and tokenizer...")
    # For this example, we use a local model path.
    # In practice, you would use: "username/GeneMamba-24l-512d"
    model_name = "GeneMamba-24l-512d"  # Change to HF Hub path when available

    try:
        tokenizer = AutoTokenizer.from_pretrained(
            model_name,
            trust_remote_code=True,
            local_files_only=True,  # Try local first
        )
        model = AutoModel.from_pretrained(
            model_name,
            trust_remote_code=True,
            local_files_only=True,
        )
    except Exception as e:
        # Fallback: build a randomly initialized model so the demo still
        # runs end-to-end without a downloaded checkpoint.
        print(f"Note: Could not load from '{model_name}': {e}")
        print("Using mock data for demonstration...")
        # For demonstration without actual checkpoint
        from configuration_genemamba import GeneMambaConfig
        from modeling_genemamba import GeneMambaModel

        config = GeneMambaConfig(
            vocab_size=25426,
            hidden_size=512,
            num_hidden_layers=24,
            embedding_pooling="mean",
        )
        model = GeneMambaModel(config)
        tokenizer = None  # No tokenizer in the mock path (unused below)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = model.to(device)
    model.eval()  # inference mode: disables dropout, etc.
    print(f"✓ Model loaded on device: {device}")
    print(f"✓ Model config: hidden_size={model.config.hidden_size}, "
          f"num_layers={model.config.num_hidden_layers}")

    # ============================================================
    # Step 2: Prepare simulated single-cell data
    # ============================================================
    print("\n[Step 2] Preparing sample data...")
    batch_size = 8
    seq_len = 2048
    vocab_size = 25426

    # Simulate ranked gene sequences.
    # In practice, this would come from your scRNA-seq data;
    # genes should be ranked by expression (highest first).
    # NOTE(review): low bound 2 presumably skips special token ids
    # (e.g. pad/cls) -- confirm against the tokenizer's vocabulary.
    input_ids = torch.randint(2, vocab_size, (batch_size, seq_len)).to(device)
    print(f"✓ Created sample input:")
    print(f"  - Batch size: {batch_size}")
    print(f"  - Sequence length: {seq_len}")
    print(f"  - Input shape: {input_ids.shape}")

    # ============================================================
    # Step 3: Inference - Extract embeddings
    # ============================================================
    print("\n[Step 3] Extracting cell embeddings...")
    with torch.no_grad():  # no gradients needed for feature extraction
        outputs = model(input_ids, output_hidden_states=False)
        # Get the pooled embedding (one vector per cell)
        cell_embeddings = outputs.pooled_embedding

    print(f"✓ Extraction complete!")
    print(f"  - Cell embeddings shape: {cell_embeddings.shape}")
    print(f"  - Pooling method used: {outputs.embedding_pooling}")
    print(f"  - Embedding type: {cell_embeddings.dtype}")

    # ============================================================
    # Step 4: Example downstream analyses
    # ============================================================
    print("\n[Step 4] Example downstream uses...")

    # Example 1: Clustering (KMeans)
    from sklearn.cluster import KMeans
    n_clusters = 3
    kmeans = KMeans(n_clusters=n_clusters, n_init=10)
    clusters = kmeans.fit_predict(cell_embeddings.cpu().numpy())
    print(f"✓ Clustering: Assigned {len(np.unique(clusters))} clusters")

    # Example 2: Dimensionality reduction (PCA)
    from sklearn.decomposition import PCA
    pca = PCA(n_components=2)
    embedding_2d = pca.fit_transform(cell_embeddings.cpu().numpy())
    print(f"✓ PCA reduction: {cell_embeddings.shape} → {embedding_2d.shape}")

    # Example 3: Similarity search
    # Find the most similar cell to the first cell. Note: cell 0 is
    # included in the comparison set, so its self-similarity (1.0) will
    # be the argmax.
    similarities = torch.nn.functional.cosine_similarity(
        cell_embeddings[0:1],
        cell_embeddings,
    )
    most_similar_idx = torch.argmax(similarities).item()
    print(f"✓ Similarity search: Most similar cell to cell 0 is cell {most_similar_idx} "
          f"(similarity: {similarities[most_similar_idx]:.4f})")

    # Example 4: Statistics
    # NOTE(review): "Mean" here is the norm of the per-dimension mean
    # vector, not the scalar mean of all entries.
    print("\n[Step 5] Embedding statistics:")
    print(f"  - Mean: {cell_embeddings.mean(dim=0).norm():.4f}")
    print(f"  - Std: {cell_embeddings.std(dim=0).mean():.4f}")
    print(f"  - Min: {cell_embeddings.min():.4f}")
    print(f"  - Max: {cell_embeddings.max():.4f}")

    # ============================================================
    # Step 6: Save embeddings (optional)
    # ============================================================
    print("\n[Step 6] Saving embeddings...")
    np.save("cell_embeddings.npy", cell_embeddings.cpu().numpy())
    print("✓ Embeddings saved to 'cell_embeddings.npy'")

    print("\n" + "=" * 80)
    print("Phase 1 Complete!")
    print("=" * 80)

    return model, cell_embeddings
# Script entry point: run the demo and keep the results available for
# interactive inspection (e.g. when executed with `python -i`).
if __name__ == "__main__":
    model, embeddings = main()