| """ |
| Phase 1: Extract Cell Embeddings |
| Demonstrates how to load GeneMamba and extract cell embeddings for downstream analysis. |
| |
| Usage: |
| python examples/1_extract_embeddings.py |
| """ |
|
|
| import torch |
| import numpy as np |
| from transformers import AutoTokenizer, AutoModel |
|
|
|
|
def main():
    """Run the GeneMamba cell-embedding extraction demo end to end.

    Loads the pretrained GeneMamba model (falling back to a randomly
    initialized mock when the checkpoint is unavailable), builds a random
    batch of gene-token ids, extracts pooled cell embeddings, demonstrates
    downstream analyses (k-means clustering, PCA, cosine-similarity search),
    prints summary statistics, and saves the embeddings to
    ``cell_embeddings.npy`` in the current directory.

    Returns:
        tuple: ``(model, cell_embeddings)`` where ``cell_embeddings`` is a
        ``(batch_size, hidden_size)`` tensor on the compute device.
    """
    print("=" * 80)
    print("GeneMamba Phase 1: Extract Cell Embeddings")
    print("=" * 80)

    # ------------------------------------------------------------------
    # Step 1: load model and tokenizer.
    # ------------------------------------------------------------------
    print("\n[Step 1] Loading model and tokenizer...")

    model_name = "GeneMamba-24l-512d"

    try:
        tokenizer = AutoTokenizer.from_pretrained(
            model_name,
            trust_remote_code=True,
            local_files_only=True,
        )
        model = AutoModel.from_pretrained(
            model_name,
            trust_remote_code=True,
            local_files_only=True,
        )
    except Exception as e:
        print(f"Note: Could not load from '{model_name}': {e}")
        print("Using mock data for demonstration...")

        # Checkpoint unavailable: fall back to a randomly initialized model
        # with the same architecture so the rest of the demo still runs.
        from configuration_genemamba import GeneMambaConfig
        from modeling_genemamba import GeneMambaModel

        config = GeneMambaConfig(
            vocab_size=25426,
            hidden_size=512,
            num_hidden_layers=24,
            embedding_pooling="mean",
        )
        model = GeneMambaModel(config)
        tokenizer = None  # no tokenizer needed: the demo feeds random ids

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = model.to(device)
    model.eval()  # inference mode: disables dropout etc.

    print(f"✓ Model loaded on device: {device}")
    print(f"✓ Model config: hidden_size={model.config.hidden_size}, "
          f"num_layers={model.config.num_hidden_layers}")

    # ------------------------------------------------------------------
    # Step 2: build a random batch of token ids as stand-in input.
    # ------------------------------------------------------------------
    print("\n[Step 2] Preparing sample data...")

    batch_size = 8
    seq_len = 2048
    vocab_size = 25426

    # Fixed seed so the demo produces the same "cells" on every run.
    torch.manual_seed(0)
    # Ids start at 2 — presumably 0/1 are reserved special tokens; confirm
    # against the tokenizer's special-token map.
    input_ids = torch.randint(2, vocab_size, (batch_size, seq_len)).to(device)

    print(f"✓ Created sample input:")
    print(f"  - Batch size: {batch_size}")
    print(f"  - Sequence length: {seq_len}")
    print(f"  - Input shape: {input_ids.shape}")

    # ------------------------------------------------------------------
    # Step 3: forward pass — one pooled embedding per cell.
    # ------------------------------------------------------------------
    print("\n[Step 3] Extracting cell embeddings...")

    with torch.no_grad():
        outputs = model(input_ids, output_hidden_states=False)

    # (batch_size, hidden_size) — one vector per input cell.
    cell_embeddings = outputs.pooled_embedding

    print(f"✓ Extraction complete!")
    print(f"  - Cell embeddings shape: {cell_embeddings.shape}")
    print(f"  - Pooling method used: {outputs.embedding_pooling}")
    print(f"  - Embedding type: {cell_embeddings.dtype}")

    # ------------------------------------------------------------------
    # Step 4: example downstream analyses on the embeddings.
    # ------------------------------------------------------------------
    print("\n[Step 4] Example downstream uses...")

    # Clustering: fixed random_state so the demo's clusters are reproducible.
    from sklearn.cluster import KMeans
    n_clusters = 3
    kmeans = KMeans(n_clusters=n_clusters, n_init=10, random_state=0)
    clusters = kmeans.fit_predict(cell_embeddings.cpu().numpy())
    print(f"✓ Clustering: Assigned {len(np.unique(clusters))} clusters")

    # Dimensionality reduction for visualization.
    from sklearn.decomposition import PCA
    pca = PCA(n_components=2)
    embedding_2d = pca.fit_transform(cell_embeddings.cpu().numpy())
    print(f"✓ PCA reduction: {cell_embeddings.shape} → {embedding_2d.shape}")

    # Similarity search: cosine similarity of cell 0 against every cell
    # (including itself, so the top hit is trivially cell 0).
    similarities = torch.nn.functional.cosine_similarity(
        cell_embeddings[0:1],
        cell_embeddings,
    )
    most_similar_idx = torch.argmax(similarities).item()
    print(f"✓ Similarity search: Most similar cell to cell 0 is cell {most_similar_idx} "
          f"(similarity: {similarities[most_similar_idx]:.4f})")

    # ------------------------------------------------------------------
    # Step 5: quick sanity statistics over the embedding matrix.
    # ------------------------------------------------------------------
    print("\n[Step 5] Embedding statistics:")
    # "Mean" here is the L2 norm of the mean embedding vector.
    print(f"  - Mean: {cell_embeddings.mean(dim=0).norm():.4f}")
    print(f"  - Std: {cell_embeddings.std(dim=0).mean():.4f}")
    print(f"  - Min: {cell_embeddings.min():.4f}")
    print(f"  - Max: {cell_embeddings.max():.4f}")

    # ------------------------------------------------------------------
    # Step 6: persist embeddings for later phases.
    # ------------------------------------------------------------------
    print("\n[Step 6] Saving embeddings...")

    np.save("cell_embeddings.npy", cell_embeddings.cpu().numpy())
    print("✓ Embeddings saved to 'cell_embeddings.npy'")

    print("\n" + "=" * 80)
    print("Phase 1 Complete!")
    print("=" * 80)

    return model, cell_embeddings
|
|
|
|
# Keep the returned model and embeddings bound at module level so they are
# available for interactive inspection (e.g. `python -i`) after the run.
if __name__ == "__main__":
    model, embeddings = main()
|
|