"""
Phase 1: Extract Cell Embeddings

Demonstrates how to load GeneMamba and extract cell embeddings for
downstream analysis.

Usage:
    python examples/1_extract_embeddings.py
"""

import torch
import numpy as np


def _load_model(model_name):
    """Load the GeneMamba model and tokenizer from a local path.

    If loading fails for any reason, fall back to a randomly initialised
    ``GeneMambaModel`` so the rest of the demo can still run.

    Args:
        model_name: Local path (or, later, HF Hub id) of the checkpoint.

    Returns:
        A ``(model, tokenizer)`` tuple; ``tokenizer`` is ``None`` on the
        fallback path.
    """
    # Imported lazily, like the optional sklearn / fallback imports below.
    from transformers import AutoModel, AutoTokenizer

    try:
        tokenizer = AutoTokenizer.from_pretrained(
            model_name,
            trust_remote_code=True,
            local_files_only=True,  # Try local first
        )
        model = AutoModel.from_pretrained(
            model_name,
            trust_remote_code=True,
            local_files_only=True,
        )
    except Exception as e:  # Deliberate demo fallback: any load failure -> random weights.
        print(f"Note: Could not load from '{model_name}': {e}")
        print("Falling back to a randomly initialized model for demonstration "
              "(embeddings will not be meaningful)...")

        # For demonstration without an actual checkpoint
        from configuration_genemamba import GeneMambaConfig
        from modeling_genemamba import GeneMambaModel

        config = GeneMambaConfig(
            vocab_size=25426,
            hidden_size=512,
            num_hidden_layers=24,
            embedding_pooling="mean",
        )
        model = GeneMambaModel(config)
        tokenizer = None

    return model, tokenizer


def most_similar_cell(embeddings, query_idx=0):
    """Find the cell whose embedding is most cosine-similar to a query cell.

    The query cell itself is excluded; otherwise it would always win with a
    similarity of 1.0.

    Args:
        embeddings: 2-D tensor of cell embeddings, one row per cell.
        query_idx: Row index of the query cell (negative indices allowed).

    Returns:
        ``(index, similarity)`` of the best-matching other cell, as a Python
        ``int`` and ``float``.

    Raises:
        ValueError: If fewer than two cells are given.
    """
    if embeddings.shape[0] < 2:
        raise ValueError("need at least two cells to search for a neighbour")

    similarities = torch.nn.functional.cosine_similarity(
        embeddings[query_idx].unsqueeze(0), embeddings
    )
    # Mask out the query cell so it cannot match itself.
    similarities[query_idx] = float("-inf")
    best = int(torch.argmax(similarities).item())
    return best, float(similarities[best].item())


def main(*, model_name="GeneMamba-24l-512d", batch_size=8, seq_len=2048,
         output_path="cell_embeddings.npy"):
    """Run the embedding-extraction demo end to end.

    Args:
        model_name: Local model path. In practice this would be a Hub id such
            as "username/GeneMamba-24l-512d".
        batch_size: Number of simulated cells (must be >= 2 for the
            downstream PCA / similarity examples).
        seq_len: Number of gene tokens per simulated cell.
        output_path: Where the embeddings are saved with ``np.save``.

    Returns:
        ``(model, cell_embeddings)``; the embeddings tensor stays on the
        model's device.
    """
    if batch_size < 2:
        raise ValueError("batch_size must be at least 2 for the downstream examples")

    print("=" * 80)
    print("GeneMamba Phase 1: Extract Cell Embeddings")
    print("=" * 80)

    # ============================================================
    # Step 1: Load pretrained model and tokenizer
    # ============================================================
    print("\n[Step 1] Loading model and tokenizer...")

    # For this example, we use a local model path; change it to the HF Hub
    # path when available. The tokenizer is not needed for the simulated ids.
    model, _tokenizer = _load_model(model_name)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = model.to(device)
    model.eval()

    print(f"✓ Model loaded on device: {device}")
    print(f"✓ Model config: hidden_size={model.config.hidden_size}, "
          f"num_layers={model.config.num_hidden_layers}")

    # ============================================================
    # Step 2: Prepare simulated single-cell data
    # ============================================================
    print("\n[Step 2] Preparing sample data...")

    # Take the vocabulary size from the model so random ids are always valid.
    vocab_size = model.config.vocab_size

    # Simulate ranked gene sequences. In practice these come from your
    # scRNA-seq data, with genes ranked by expression (highest first).
    # Ids 0 and 1 are skipped — presumably reserved special tokens; confirm
    # against the tokenizer.
    input_ids = torch.randint(2, vocab_size, (batch_size, seq_len)).to(device)

    print("✓ Created sample input:")
    print(f"  - Batch size: {batch_size}")
    print(f"  - Sequence length: {seq_len}")
    print(f"  - Input shape: {input_ids.shape}")

    # ============================================================
    # Step 3: Inference - Extract embeddings
    # ============================================================
    print("\n[Step 3] Extracting cell embeddings...")

    with torch.no_grad():
        outputs = model(input_ids, output_hidden_states=False)

    # Get the pooled embedding (cell representation)
    cell_embeddings = outputs.pooled_embedding

    print("✓ Extraction complete!")
    print(f"  - Cell embeddings shape: {cell_embeddings.shape}")
    print(f"  - Pooling method used: {outputs.embedding_pooling}")
    print(f"  - Embedding type: {cell_embeddings.dtype}")

    # ============================================================
    # Step 4: Example downstream analyses
    # ============================================================
    print("\n[Step 4] Example downstream uses...")

    # Convert once; .float() keeps half-precision outputs numpy-compatible.
    embeddings_np = cell_embeddings.float().cpu().numpy()

    # Example 1: Clustering (KMeans)
    from sklearn.cluster import KMeans
    n_clusters = min(3, batch_size)  # KMeans needs n_samples >= n_clusters
    kmeans = KMeans(n_clusters=n_clusters, n_init=10, random_state=0)
    clusters = kmeans.fit_predict(embeddings_np)
    print(f"✓ Clustering: Assigned {len(np.unique(clusters))} clusters")

    # Example 2: Dimensionality reduction (PCA)
    from sklearn.decomposition import PCA
    pca = PCA(n_components=2)
    embedding_2d = pca.fit_transform(embeddings_np)
    print(f"✓ PCA reduction: {cell_embeddings.shape} → {embedding_2d.shape}")

    # Example 3: Similarity search — nearest *other* cell to cell 0
    most_similar_idx, best_similarity = most_similar_cell(cell_embeddings, 0)
    print(f"✓ Similarity search: Most similar cell to cell 0 is cell {most_similar_idx} "
          f"(similarity: {best_similarity:.4f})")

    # ============================================================
    # Step 5: Embedding statistics
    # ============================================================
    print("\n[Step 5] Embedding statistics:")
    print(f"  - Mean-vector L2 norm: {cell_embeddings.mean(dim=0).norm():.4f}")
    print(f"  - Mean per-dim std: {cell_embeddings.std(dim=0).mean():.4f}")
    print(f"  - Min: {cell_embeddings.min():.4f}")
    print(f"  - Max: {cell_embeddings.max():.4f}")

    # ============================================================
    # Step 6: Save embeddings (optional)
    # ============================================================
    print("\n[Step 6] Saving embeddings...")

    np.save(output_path, embeddings_np)
    print(f"✓ Embeddings saved to '{output_path}'")

    print("\n" + "=" * 80)
    print("Phase 1 Complete!")
    print("=" * 80)

    return model, cell_embeddings


if __name__ == "__main__":
    model, embeddings = main()