File size: 5,497 Bytes
54cd552
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
"""
Phase 1: Extract Cell Embeddings
Demonstrates how to load GeneMamba and extract cell embeddings for downstream analysis.

Usage:
    python examples/1_extract_embeddings.py
"""

import torch
import numpy as np
from transformers import AutoTokenizer, AutoModel


def _load_model():
    """Load the pretrained GeneMamba model and move it to the best device (Step 1).

    Tries to load a local checkpoint via ``transformers``; if none is found,
    falls back to a randomly initialized model built from the local
    ``configuration_genemamba``/``modeling_genemamba`` files so the demo can
    still run end to end without downloaded weights.

    Returns:
        tuple: ``(model, device)`` — the model is in eval mode on ``device``.
    """
    print("\n[Step 1] Loading model and tokenizer...")

    # For this example, we use a local model path
    # In practice, you would use: "username/GeneMamba-24l-512d"
    model_name = "GeneMamba-24l-512d"  # Change to HF Hub path when available

    try:
        # NOTE: the tokenizer is loaded for completeness but is not used
        # later in this demo (inputs are simulated token ids).
        tokenizer = AutoTokenizer.from_pretrained(
            model_name,
            trust_remote_code=True,
            local_files_only=True  # Try local first
        )
        model = AutoModel.from_pretrained(
            model_name,
            trust_remote_code=True,
            local_files_only=True
        )
    except Exception as e:
        print(f"Note: Could not load from '{model_name}': {e}")
        print("Using mock data for demonstration...")

        # For demonstration without actual checkpoint
        from configuration_genemamba import GeneMambaConfig
        from modeling_genemamba import GeneMambaModel

        config = GeneMambaConfig(
            vocab_size=25426,
            hidden_size=512,
            num_hidden_layers=24,
            embedding_pooling="mean",
        )
        model = GeneMambaModel(config)
        tokenizer = None

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = model.to(device)
    model.eval()  # disable dropout etc. for deterministic inference

    print(f"βœ“ Model loaded on device: {device}")
    print(f"βœ“ Model config: hidden_size={model.config.hidden_size}, "
          f"num_layers={model.config.num_hidden_layers}")

    return model, device


def _prepare_sample_data(device):
    """Create simulated ranked-gene input ids on ``device`` (Step 2).

    Args:
        device: torch device the input tensor should live on.

    Returns:
        torch.Tensor: ``(batch_size, seq_len)`` int64 tensor of token ids
        in ``[2, vocab_size)`` (ids 0/1 are presumably reserved for special
        tokens — confirm against the tokenizer).
    """
    print("\n[Step 2] Preparing sample data...")

    batch_size = 8
    seq_len = 2048
    vocab_size = 25426

    # Simulate ranked gene sequences
    # In practice, this would come from your scRNA-seq data
    # Genes should be ranked by expression (highest first)
    # Allocate directly on the target device instead of CPU + .to(device).
    input_ids = torch.randint(2, vocab_size, (batch_size, seq_len), device=device)

    print("βœ“ Created sample input:")
    print(f"  - Batch size: {batch_size}")
    print(f"  - Sequence length: {seq_len}")
    print(f"  - Input shape: {input_ids.shape}")

    return input_ids


def _extract_embeddings(model, input_ids):
    """Run a no-grad forward pass and return the pooled cell embeddings (Step 3).

    Args:
        model: GeneMamba model in eval mode.
        input_ids: ``(batch, seq_len)`` tensor of token ids.

    Returns:
        torch.Tensor: ``(batch, hidden_size)`` pooled cell representations.
    """
    print("\n[Step 3] Extracting cell embeddings...")

    with torch.no_grad():
        outputs = model(input_ids, output_hidden_states=False)

    # Get the pooled embedding (cell representation)
    cell_embeddings = outputs.pooled_embedding

    print("βœ“ Extraction complete!")
    print(f"  - Cell embeddings shape: {cell_embeddings.shape}")
    print(f"  - Pooling method used: {outputs.embedding_pooling}")
    print(f"  - Embedding type: {cell_embeddings.dtype}")

    return cell_embeddings


def _run_downstream_examples(cell_embeddings):
    """Demonstrate downstream uses: clustering, PCA, similarity, stats (Steps 4-5).

    Args:
        cell_embeddings: ``(batch, hidden_size)`` tensor of cell embeddings.
    """
    print("\n[Step 4] Example downstream uses...")

    # Convert once; reused by both clustering and PCA below.
    embeddings_np = cell_embeddings.cpu().numpy()

    # Example 1: Clustering (KMeans)
    from sklearn.cluster import KMeans
    n_clusters = 3
    kmeans = KMeans(n_clusters=n_clusters, n_init=10)
    clusters = kmeans.fit_predict(embeddings_np)
    print(f"βœ“ Clustering: Assigned {len(np.unique(clusters))} clusters")

    # Example 2: Dimensionality reduction (PCA)
    from sklearn.decomposition import PCA
    pca = PCA(n_components=2)
    embedding_2d = pca.fit_transform(embeddings_np)
    print(f"βœ“ PCA reduction: {cell_embeddings.shape} β†’ {embedding_2d.shape}")

    # Example 3: Similarity search
    # Find the most similar cell to the first cell
    similarities = torch.nn.functional.cosine_similarity(
        cell_embeddings[0:1],
        cell_embeddings
    )
    most_similar_idx = torch.argmax(similarities).item()
    print(f"βœ“ Similarity search: Most similar cell to cell 0 is cell {most_similar_idx} "
          f"(similarity: {similarities[most_similar_idx]:.4f})")

    # Example 4: Statistics
    print("\n[Step 5] Embedding statistics:")
    print(f"  - Mean: {cell_embeddings.mean(dim=0).norm():.4f}")
    print(f"  - Std: {cell_embeddings.std(dim=0).mean():.4f}")
    print(f"  - Min: {cell_embeddings.min():.4f}")
    print(f"  - Max: {cell_embeddings.max():.4f}")


def _save_embeddings(cell_embeddings):
    """Persist the embeddings as a NumPy array in the working directory (Step 6).

    Args:
        cell_embeddings: ``(batch, hidden_size)`` tensor to save.
    """
    print("\n[Step 6] Saving embeddings...")

    np.save("cell_embeddings.npy", cell_embeddings.cpu().numpy())
    print("βœ“ Embeddings saved to 'cell_embeddings.npy'")


def main():
    """Run the Phase 1 demo: load GeneMamba and extract cell embeddings.

    Walks through model loading, simulated data preparation, embedding
    extraction, downstream analyses (KMeans, PCA, similarity search,
    statistics), and saving the result to ``cell_embeddings.npy``.

    Returns:
        tuple: ``(model, cell_embeddings)`` for further interactive use.
    """
    print("=" * 80)
    print("GeneMamba Phase 1: Extract Cell Embeddings")
    print("=" * 80)

    model, device = _load_model()
    input_ids = _prepare_sample_data(device)
    cell_embeddings = _extract_embeddings(model, input_ids)
    _run_downstream_examples(cell_embeddings)
    _save_embeddings(cell_embeddings)

    print("\n" + "=" * 80)
    print("Phase 1 Complete!")
    print("=" * 80)

    return model, cell_embeddings


# Script entry point: run the full demo. The model and embeddings are bound
# at module level so they remain inspectable in an interactive session
# (e.g. `python -i examples/1_extract_embeddings.py`).
if __name__ == "__main__":
    model, embeddings = main()