mineself2016 committed
Commit e3faf24 · verified · 1 Parent(s): 0bcc4ad

Add preprocess docs, examples, tokenizer mapping assets

README.md ADDED
@@ -0,0 +1,90 @@
1
+ ---
2
+ library_name: transformers
3
+ tags:
4
+ - genomics
5
+ - single-cell
6
+ - mamba
7
+ - biology
8
+ pipeline_tag: feature-extraction
9
+ ---
10
+
11
+ # GeneMamba2-24l-512d
12
+
13
+ This repository contains a GeneMamba checkpoint plus full usage assets:
14
+ - model weights (`model.safetensors`)
15
+ - custom modeling/config files for `trust_remote_code=True`
16
+ - preprocessing example from `h5ad` to `input_ids`
17
+ - tokenizer assets and id mapping files
18
+
19
+ ## 1) Input format (very important)
20
+
21
+ GeneMamba's input is a list of **ranked gene token IDs** per cell:
22
+ 1. Start from a single cell's expression vector
23
+ 2. Keep genes with expression > 0
24
+ 3. Sort the remaining genes by expression, descending
25
+ 4. Convert each gene ID (Ensembl, e.g. `ENSG00000000003`) to its token ID
26
+ 5. Use the resulting list as `input_ids`
27
+
28
+ Each sample is one list of integers:
29
+
30
+ ```python
31
+ {"input_ids": [145, 2088, 531, 91, ...]}
32
+ ```
33
+
34
+ For batched input, the shape is typically `(batch_size, seq_len)` after padding/truncation; a minimal padding sketch follows below.
35
+
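+ A minimal padding sketch (illustrative only, not part of the released preprocessing code; `pad_id = 1` and `max_len = 2048` follow the `[PAD]` id and `max_position_embeddings` documented in this repo):
+
+ ```python
+ import torch
+
+ def pad_batch(list_of_input_ids, pad_id=1, max_len=2048):
+     # Truncate each ranked token list, then right-pad to a common length.
+     seqs = [ids[:max_len] for ids in list_of_input_ids]
+     width = max(len(s) for s in seqs)
+     input_ids = torch.full((len(seqs), width), pad_id, dtype=torch.long)
+     attention_mask = torch.zeros((len(seqs), width), dtype=torch.long)
+     for i, s in enumerate(seqs):
+         input_ids[i, :len(s)] = torch.tensor(s, dtype=torch.long)
+         attention_mask[i, :len(s)] = 1
+     return input_ids, attention_mask
+ ```
+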
36
+ ## 2) Where the tokenizer and ID mappings come from
37
+
38
+ - Main tokenizer used for model inference: `tokenizer.json`
39
+ - Original full tokenizer table: `tokenizer_assets/gene_tokenizer.json`
40
+ - Gene symbol -> token id mapping: `tokenizer_assets/symbol2id.pkl`
41
+ - Token id -> gene symbol mapping: `tokenizer_assets/id2symbol.pkl`
42
+
43
+ Special tokens:
44
+ - `[UNK]` = 0
45
+ - `[PAD]` = 1
46
+
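+ A short sketch of reading these assets (illustrative; it assumes `tokenizer.json` keeps its vocabulary under `model.vocab`, which is how the preprocessing script in this repo reads it):
+
+ ```python
+ import json
+ import pickle
+
+ with open("tokenizer.json") as f:
+     vocab = json.load(f)["model"]["vocab"]         # gene token -> token ID
+
+ with open("tokenizer_assets/symbol2id.pkl", "rb") as f:
+     symbol2id = pickle.load(f)                      # gene symbol -> token ID
+ with open("tokenizer_assets/id2symbol.pkl", "rb") as f:
+     id2symbol = pickle.load(f)                      # token ID -> gene symbol
+
+ print(vocab.get("[UNK]"), vocab.get("[PAD]"))       # expected: 0 1
+ ```
+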
47
+ ## 3) Preprocess your data
48
+
49
+ See script:
50
+ - `examples/0_preprocess_to_input_ids.py`
51
+
52
+ Example:
53
+
54
+ ```bash
55
+ python examples/0_preprocess_to_input_ids.py \
56
+ --h5ad /path/to/your_data.h5ad \
57
+ --tokenizer_json tokenizer.json \
58
+ --output_arrow ./my_data/sorted_gene_token_ids.arrow
59
+ ```
60
+
61
+ The output Arrow file has a single column, `input_ids`; the snippet below shows one way to read it back.
62
+
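+ Reading it back (a sketch; the preprocessing script writes the file with `pyarrow.ipc.new_stream`, so it is opened as an Arrow IPC stream):
+
+ ```python
+ import pyarrow as pa
+
+ with pa.OSFile("./my_data/sorted_gene_token_ids.arrow", "rb") as source:
+     table = pa.ipc.open_stream(source).read_all()
+
+ input_ids = table.column("input_ids").to_pylist()  # one ranked token-ID list per cell
+ print(len(input_ids), input_ids[0][:10])
+ ```
+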
63
+ ## 4) Load model and extract embedding
64
+
65
+ ```python
66
+ from transformers import AutoModel, AutoTokenizer
67
+
68
+ model = AutoModel.from_pretrained(
69
+ "mineself2016/GeneMamba2-24l-512d",
70
+ trust_remote_code=True
71
+ )
72
+
73
+ tokenizer = AutoTokenizer.from_pretrained(
74
+ "mineself2016/GeneMamba2-24l-512d",
75
+ trust_remote_code=True
76
+ )
77
+ ```
78
+
79
+ For a more complete example (a minimal inline sketch also follows below), see:
80
+ - `examples/1_extract_embeddings.py`
81
+
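+ A minimal inline inference sketch (illustrative; the token IDs below are arbitrary and of equal length, so no padding or attention mask is needed):
+
+ ```python
+ import torch
+ from transformers import AutoModel
+
+ model = AutoModel.from_pretrained(
+     "mineself2016/GeneMamba2-24l-512d", trust_remote_code=True
+ )
+ model.eval()
+
+ # Two toy cells, each a list of ranked gene token IDs.
+ batch = torch.tensor([
+     [145, 2088, 531, 91],
+     [77, 12, 9083, 4],
+ ])
+
+ with torch.no_grad():
+     out = model(batch)
+
+ cell_embeddings = out.pooled_embedding  # shape: (2, 512)
+ print(cell_embeddings.shape)
+ ```
+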
82
+ ## 5) Source of preprocessing logic
83
+
84
+ The preprocessing/tokenization pipeline is aligned with assets from:
85
+ - `/project/zhiwei/cq5/PythonWorkSpace/gene_mamba`
86
+
87
+ Key references used:
88
+ - tokenizer: `gene_tokenizer.json`
89
+ - mappings: `symbol2id.pkl`, `id2symbol.pkl`
90
+ - dataset build logic (Arrow + `input_ids`): `utils.py` (`build_dataset`)
config.json ADDED
@@ -0,0 +1,22 @@
1
+ {
2
+ "model_type": "genemamba",
3
+ "architectures": [
4
+ "GeneMambaModel"
5
+ ],
6
+ "vocab_size": 25426,
7
+ "max_position_embeddings": 2048,
8
+ "hidden_size": 512,
9
+ "num_hidden_layers": 24,
10
+ "intermediate_size": 2048,
11
+ "hidden_dropout_prob": 0.1,
12
+ "initializer_range": 0.02,
13
+ "mamba_mode": "gate",
14
+ "embedding_pooling": "mean",
15
+ "num_labels": 2,
16
+ "pad_token_id": 1,
17
+ "eos_token_id": 2,
18
+ "bos_token_id": 0,
19
+ "use_cache": true,
20
+ "torch_dtype": "float32",
21
+ "transformers_version": "4.40.2"
22
+ }
configuration_genemamba.py ADDED
@@ -0,0 +1,97 @@
1
+ """
2
+ Configuration for GeneMamba model.
3
+ Defines all hyperparameters and settings for the GeneMamba architecture.
4
+ """
5
+
6
+ from transformers import PretrainedConfig
7
+ from typing import Optional
8
+
9
+
10
+ class GeneMambaConfig(PretrainedConfig):
11
+ """
12
+ Configuration class for GeneMamba model.
13
+
14
+ This class stores the configuration of a GeneMamba model, inheriting from PretrainedConfig.
15
+ It can be used to instantiate models from pretrained checkpoints or customize model initialization.
16
+
17
+ Args:
18
+ vocab_size (int, optional, defaults to 25426):
19
+ Vocabulary size of the model. Number of gene tokens (Ensembl Gene IDs).
20
+
21
+ hidden_size (int, optional, defaults to 512):
22
+ Dimensionality of the hidden/embedding layers (d_model in Mamba).
23
+
24
+ num_hidden_layers (int, optional, defaults to 24):
25
+ Number of Mamba layers (mamba_layer).
26
+
27
+ intermediate_size (int, optional, defaults to 2048):
28
+ Dimensionality of intermediate representations in MLP.
29
+
30
+ max_position_embeddings (int, optional, defaults to 2048):
31
+ Maximum sequence length (seq_len).
32
+
33
+ hidden_dropout_prob (float, optional, defaults to 0.1):
34
+ Dropout probability for hidden states.
35
+
36
+ initializer_range (float, optional, defaults to 0.02):
37
+ Standard deviation of truncated normal initializer.
38
+
39
+ mamba_mode (str, optional, defaults to "gate"):
40
+ Aggregation mode for bidirectional Mamba layers.
41
+ Options: "mean", "sum", "concat", "gate".
42
+
43
+ embedding_pooling (str, optional, defaults to "mean"):
44
+ Method for pooling to get cell embedding.
45
+ Options: "CLS", "mean", "weighted".
46
+
47
+ num_labels (int, optional, defaults to 2):
48
+ Number of labels for sequence classification tasks.
49
+
50
+ pad_token_id (int, optional, defaults to 1):
51
+ Token ID for padding.
52
+
53
+ bos_token_id (int, optional, defaults to None):
54
+ Token ID for beginning of sequence.
55
+
56
+ eos_token_id (int, optional, defaults to None):
57
+ Token ID for end of sequence.
58
+ """
59
+
60
+ model_type = "genemamba"
61
+ attribute_map = {
62
+ "hidden_size": "hidden_size",
63
+ "num_hidden_layers": "num_hidden_layers",
64
+ }
65
+
66
+ def __init__(
67
+ self,
68
+ vocab_size: int = 25426,
69
+ hidden_size: int = 512,
70
+ num_hidden_layers: int = 24,
71
+ intermediate_size: int = 2048,
72
+ max_position_embeddings: int = 2048,
73
+ hidden_dropout_prob: float = 0.1,
74
+ initializer_range: float = 0.02,
75
+ mamba_mode: str = "gate",
76
+ embedding_pooling: str = "mean",
77
+ num_labels: int = 2,
78
+ pad_token_id: int = 1,
79
+ bos_token_id: Optional[int] = None,
80
+ eos_token_id: Optional[int] = None,
81
+ **kwargs
82
+ ):
83
+ super().__init__(pad_token_id=pad_token_id, **kwargs)
84
+
85
+ self.vocab_size = vocab_size
86
+ self.hidden_size = hidden_size
87
+ self.num_hidden_layers = num_hidden_layers
88
+ self.intermediate_size = intermediate_size
89
+ self.max_position_embeddings = max_position_embeddings
90
+ self.hidden_dropout_prob = hidden_dropout_prob
91
+ self.initializer_range = initializer_range
92
+ self.mamba_mode = mamba_mode
93
+ self.embedding_pooling = embedding_pooling
94
+ self.num_labels = num_labels
95
+ self.pad_token_id = pad_token_id
96
+ self.bos_token_id = bos_token_id
97
+ self.eos_token_id = eos_token_id
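+
+
+ if __name__ == "__main__":
+     # Quick sanity check (sketch): build a config matching the released
+     # GeneMamba2-24l-512d checkpoint; the values mirror config.json in this repo.
+     cfg = GeneMambaConfig(vocab_size=25426, hidden_size=512, num_hidden_layers=24)
+     print(cfg)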
examples/0_preprocess_to_input_ids.py ADDED
@@ -0,0 +1,75 @@
1
+ import argparse
2
+ import json
3
+ from pathlib import Path
4
+
5
+ import numpy as np
6
+ import pandas as pd
7
+ import scanpy as sc
8
+ import pyarrow as pa
9
+
10
+
11
+ def load_vocab(tokenizer_json_path: str):
12
+ with open(tokenizer_json_path, "r") as f:
13
+ tokenizer = json.load(f)
14
+ vocab = tokenizer["model"]["vocab"]
15
+ pad_id = vocab.get("[PAD]", 1)
16
+ unk_id = vocab.get("[UNK]", 0)
17
+ return vocab, pad_id, unk_id
18
+
19
+
20
+ def ranked_gene_ids_for_cell(expr_values, gene_names, vocab):
21
+ nonzero_idx = np.where(expr_values > 0)[0]
22
+ if len(nonzero_idx) == 0:
23
+ return []
24
+
25
+ genes = np.array(gene_names)[nonzero_idx]
26
+ values = expr_values[nonzero_idx]
27
+
28
+ order = np.argsort(-values)
29
+ ranked_genes = genes[order]
30
+
31
+ token_ids = [vocab[g] for g in ranked_genes if g in vocab]
32
+ return token_ids
33
+
34
+
35
+ def main():
36
+ parser = argparse.ArgumentParser(description="Convert h5ad to GeneMamba input_ids (Arrow)")
37
+ parser.add_argument("--h5ad", required=True, help="Input h5ad file")
38
+ parser.add_argument("--tokenizer_json", required=True, help="Path to tokenizer.json or gene_tokenizer.json")
39
+ parser.add_argument("--output_arrow", required=True, help="Output arrow file path")
40
+ parser.add_argument("--max_cells", type=int, default=None, help="Optional: process first N cells only")
41
+ args = parser.parse_args()
42
+
43
+ adata = sc.read_h5ad(args.h5ad)
44
+ vocab, _, _ = load_vocab(args.tokenizer_json)
45
+
46
+ gene_names = list(adata.var_names)
47
+ n_cells = adata.n_obs if args.max_cells is None else min(args.max_cells, adata.n_obs)
48
+
49
+ rows = []
50
+ X = adata.X
51
+
52
+ for i in range(n_cells):
53
+ row = X[i]
54
+ if hasattr(row, "toarray"):
55
+ expr = row.toarray().ravel()
56
+ else:
57
+ expr = np.asarray(row).ravel()
58
+
59
+ token_ids = ranked_gene_ids_for_cell(expr, gene_names, vocab)
60
+ rows.append(token_ids)
61
+
62
+ df = pd.DataFrame({"input_ids": rows})
63
+ table = pa.Table.from_pandas(df)
64
+
65
+ output_path = Path(args.output_arrow)
66
+ output_path.parent.mkdir(parents=True, exist_ok=True)
67
+ with pa.OSFile(str(output_path), "wb") as sink:
68
+ with pa.ipc.new_stream(sink, table.schema) as writer:
69
+ writer.write_table(table)
70
+
71
+ print(f"Saved {len(rows)} cells to {output_path}")
72
+
73
+
74
+ if __name__ == "__main__":
75
+ main()
examples/1_extract_embeddings.py ADDED
@@ -0,0 +1,150 @@
1
+ """
2
+ Phase 1: Extract Cell Embeddings
3
+ Demonstrates how to load GeneMamba and extract cell embeddings for downstream analysis.
4
+
5
+ Usage:
6
+ python examples/1_extract_embeddings.py
7
+ """
8
+
9
+ import torch
10
+ import numpy as np
11
+ from transformers import AutoTokenizer, AutoModel
12
+
13
+
14
+ def main():
15
+ print("=" * 80)
16
+ print("GeneMamba Phase 1: Extract Cell Embeddings")
17
+ print("=" * 80)
18
+
19
+ # ============================================================
20
+ # Step 1: Load pretrained model and tokenizer
21
+ # ============================================================
22
+ print("\n[Step 1] Loading model and tokenizer...")
23
+
24
+ # For this example, we first try a local model path.
26
+ # To load from the Hugging Face Hub instead, use: "mineself2016/GeneMamba2-24l-512d"
27
+ model_name = "GeneMamba-24l-512d" # Change to the Hub path above to load remotely
27
+
28
+ try:
29
+ tokenizer = AutoTokenizer.from_pretrained(
30
+ model_name,
31
+ trust_remote_code=True,
32
+ local_files_only=True # Try local first
33
+ )
34
+ model = AutoModel.from_pretrained(
35
+ model_name,
36
+ trust_remote_code=True,
37
+ local_files_only=True
38
+ )
39
+ except Exception as e:
40
+ print(f"Note: Could not load from '{model_name}': {e}")
41
+ print("Using mock data for demonstration...")
42
+
43
+ # For demonstration without actual checkpoint
44
+ from configuration_genemamba import GeneMambaConfig
45
+ from modeling_genemamba import GeneMambaModel
46
+
47
+ config = GeneMambaConfig(
48
+ vocab_size=25426,
49
+ hidden_size=512,
50
+ num_hidden_layers=24,
51
+ embedding_pooling="mean",
52
+ )
53
+ model = GeneMambaModel(config)
54
+ tokenizer = None
55
+
56
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
57
+ model = model.to(device)
58
+ model.eval()
59
+
60
+ print(f"✓ Model loaded on device: {device}")
61
+ print(f"✓ Model config: hidden_size={model.config.hidden_size}, "
62
+ f"num_layers={model.config.num_hidden_layers}")
63
+
64
+ # ============================================================
65
+ # Step 2: Prepare simulated single-cell data
66
+ # ============================================================
67
+ print("\n[Step 2] Preparing sample data...")
68
+
69
+ batch_size = 8
70
+ seq_len = 2048
71
+ vocab_size = 25426
72
+
73
+ # Simulate ranked gene sequences
74
+ # In practice, this would come from your scRNA-seq data
75
+ # Genes should be ranked by expression (highest first)
76
+ input_ids = torch.randint(2, vocab_size, (batch_size, seq_len)).to(device)
77
+
78
+ print(f"✓ Created sample input:")
79
+ print(f" - Batch size: {batch_size}")
80
+ print(f" - Sequence length: {seq_len}")
81
+ print(f" - Input shape: {input_ids.shape}")
82
+
83
+ # ============================================================
84
+ # Step 3: Inference - Extract embeddings
85
+ # ============================================================
86
+ print("\n[Step 3] Extracting cell embeddings...")
87
+
88
+ with torch.no_grad():
89
+ outputs = model(input_ids, output_hidden_states=False)
90
+
91
+ # Get the pooled embedding (cell representation)
92
+ cell_embeddings = outputs.pooled_embedding
93
+
94
+ print(f"✓ Extraction complete!")
95
+ print(f" - Cell embeddings shape: {cell_embeddings.shape}")
96
+ print(f" - Pooling method used: {outputs.embedding_pooling}")
97
+ print(f" - Embedding type: {cell_embeddings.dtype}")
98
+
99
+ # ============================================================
100
+ # Step 4: Example downstream analyses
101
+ # ============================================================
102
+ print("\n[Step 4] Example downstream uses...")
103
+
104
+ # Example 1: Clustering (KMeans)
105
+ from sklearn.cluster import KMeans
106
+ n_clusters = 3
107
+ kmeans = KMeans(n_clusters=n_clusters, n_init=10)
108
+ clusters = kmeans.fit_predict(cell_embeddings.cpu().numpy())
109
+ print(f"✓ Clustering: Assigned {len(np.unique(clusters))} clusters")
110
+
111
+ # Example 2: Dimensionality reduction (PCA)
112
+ from sklearn.decomposition import PCA
113
+ pca = PCA(n_components=2)
114
+ embedding_2d = pca.fit_transform(cell_embeddings.cpu().numpy())
115
+ print(f"✓ PCA reduction: {cell_embeddings.shape} → {embedding_2d.shape}")
116
+
117
+ # Example 3: Similarity search
118
+ # Find the most similar cell to the first cell
119
+ similarities = torch.nn.functional.cosine_similarity(
120
+ cell_embeddings[0:1],
121
+ cell_embeddings
122
+ )
123
+ most_similar_idx = torch.argmax(similarities).item()
124
+ print(f"✓ Similarity search: Most similar cell to cell 0 is cell {most_similar_idx} "
125
+ f"(similarity: {similarities[most_similar_idx]:.4f})")
126
+
127
+ # Step 5: Embedding statistics
128
+ print("\n[Step 5] Embedding statistics:")
129
+ print(f" - Mean: {cell_embeddings.mean(dim=0).norm():.4f}")
130
+ print(f" - Std: {cell_embeddings.std(dim=0).mean():.4f}")
131
+ print(f" - Min: {cell_embeddings.min():.4f}")
132
+ print(f" - Max: {cell_embeddings.max():.4f}")
133
+
134
+ # ============================================================
135
+ # Step 6: Save embeddings (optional)
136
+ # ============================================================
137
+ print("\n[Step 6] Saving embeddings...")
138
+
139
+ np.save("cell_embeddings.npy", cell_embeddings.cpu().numpy())
140
+ print("✓ Embeddings saved to 'cell_embeddings.npy'")
141
+
142
+ print("\n" + "=" * 80)
143
+ print("Phase 1 Complete!")
144
+ print("=" * 80)
145
+
146
+ return model, cell_embeddings
147
+
148
+
149
+ if __name__ == "__main__":
150
+ model, embeddings = main()
examples/__pycache__/0_preprocess_to_input_ids.cpython-39.pyc ADDED
Binary file (2.64 kB).
 
model.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:ccb1fcb0ee4b3ea2013099b9b187455e160d3b66b76c606715231b70b13c2784
3
+ size 262998656
modeling_genemamba.py ADDED
@@ -0,0 +1,395 @@
1
+ """
2
+ PyTorch implementation of GeneMamba model for Hugging Face Transformers.
3
+ Includes backbone model and task-specific heads for various downstream tasks.
4
+ """
5
+
6
+ import math
7
+ import logging
8
+ from typing import Optional, Tuple, Union
9
+
10
+ import torch
11
+ import torch.nn as nn
12
+ import torch.nn.functional as F
13
+ from torch.nn.init import normal_, constant_
14
+
15
+ from transformers import PreTrainedModel, PretrainedConfig
16
+ from transformers.modeling_outputs import SequenceClassifierOutput, ModelOutput
18
+
19
+ from mamba_ssm import Mamba
20
+ from mamba_ssm.ops.triton.layer_norm import RMSNorm
21
+
22
+ from .configuration_genemamba import GeneMambaConfig
23
+ from .modeling_outputs import GeneMambaModelOutput, GeneMambaSequenceClassifierOutput, GeneMambaMaskedLMOutput
24
+
25
+ logger = logging.getLogger(__name__)
26
+
27
+
28
+ # ===========================
29
+ # Core Architecture Components
30
+ # ===========================
31
+
32
+ class EncoderLayer(nn.Module):
33
+ """
34
+ Single Mamba encoder layer with residual connection.
35
+ Applies a Mamba2 or Mamba layer followed by addition with input.
36
+
37
+ Args:
38
+ hidden_size (int): Dimension of hidden states.
39
+ """
40
+
41
+ def __init__(self, hidden_size: int):
42
+ super(EncoderLayer, self).__init__()
43
+ self.mamba = Mamba(d_model=hidden_size, d_state=64, d_conv=4, expand=2)
44
+
45
+ def forward(self, X: torch.Tensor) -> torch.Tensor:
46
+ """
47
+ Args:
48
+ X (torch.Tensor): Input tensor of shape (batch_size, seq_len, hidden_size).
49
+
50
+ Returns:
51
+ torch.Tensor: Output after Mamba layer and residual connection.
52
+ """
53
+ output = self.mamba(X) + X
54
+ return output
55
+
56
+
57
+ class MambaMixer(nn.Module):
58
+ """
59
+ Stack of Mamba encoder layers with bidirectional processing and aggregation.
60
+ Processes sequences in both forward and reverse directions, then aggregates.
61
+
62
+ Args:
63
+ mode (str): Aggregation mode. Options: "mean", "sum", "concat", "gate".
64
+ hidden_size (int): Dimension of hidden states.
65
+ num_hidden_layers (int): Number of Mamba layers.
66
+ """
67
+
68
+ def __init__(
69
+ self,
70
+ mode: str = "gate",
71
+ hidden_size: int = 512,
72
+ num_hidden_layers: int = 24
73
+ ):
74
+ super(MambaMixer, self).__init__()
75
+ self.mode = mode
76
+ self.hidden_size = hidden_size
77
+
78
+ # Create Mamba layers
79
+ self.layers = nn.ModuleList(
80
+ [EncoderLayer(hidden_size) for _ in range(num_hidden_layers)]
81
+ )
82
+
83
+ # Aggregation modules for certain modes
84
+ if mode in ["concat", "gate"]:
85
+ self.aggr = nn.Linear(hidden_size * 2, hidden_size)
86
+
87
+ def flip_sequence(self, X: torch.Tensor, mask: Optional[torch.Tensor] = None) -> torch.Tensor:
88
+ """
89
+ Reverse a sequence based on actual length (ignoring padding).
90
+
91
+ Args:
92
+ X (torch.Tensor): Input tensor of shape (batch_size, seq_len, hidden_size).
93
+ mask (torch.Tensor, optional): Padding mask of shape (batch_size, seq_len).
94
+
95
+ Returns:
96
+ torch.Tensor: Reversed tensor.
97
+ """
98
+ batch_size, seq_length, embedding_dim = X.size()
99
+
100
+ if mask is None:
101
+ # Simple flip
102
+ return X.flip([1])
103
+
104
+ # Flip based on actual sequence length (marked by mask)
105
+ lengths = (~mask).sum(dim=1)
106
+ pos_tensor = torch.arange(seq_length, device=X.device).unsqueeze(0).expand(batch_size, -1)
107
+ flip_mask = pos_tensor < lengths.unsqueeze(1)
108
+ reversed_positions = torch.where(
109
+ flip_mask,
110
+ lengths.unsqueeze(1) - 1 - pos_tensor,
111
+ pos_tensor
112
+ )
113
+
114
+ X_reverse = torch.gather(X, 1, reversed_positions.unsqueeze(-1).expand(-1, -1, embedding_dim))
115
+ return X_reverse
116
+
117
+ def forward(
118
+ self,
119
+ X: torch.Tensor,
120
+ padding_mask: Optional[torch.Tensor] = None
121
+ ) -> torch.Tensor:
122
+ """
123
+ Process sequence through bidirectional Mamba layers.
124
+
125
+ Args:
126
+ X (torch.Tensor): Input tensor of shape (batch_size, seq_len, hidden_size).
127
+ padding_mask (torch.Tensor, optional): Padding mask.
128
+
129
+ Returns:
130
+ torch.Tensor: Output after processing all layers and aggregation.
131
+ """
132
+
133
+ for layer in self.layers:
134
+ # Flip sequence for reverse processing
135
+ X_flip = self.flip_sequence(X, padding_mask)
136
+
137
+ # Forward and reverse passes
138
+ X_f = layer(X)
139
+ X_b = layer(X_flip)
140
+
141
+ # Flip back the reverse output
142
+ X_b = self.flip_sequence(X_b, padding_mask)
143
+
144
+ # Aggregate forward and reverse
145
+ if self.mode == "mean":
146
+ X = (X_f + X_b) / 2
147
+ elif self.mode == "sum":
148
+ X = X_f + X_b
149
+ elif self.mode == "concat":
150
+ X = torch.cat([X_f, X_b], dim=-1)
151
+ X = self.aggr(X)
152
+ elif self.mode == "gate":
153
+ z = torch.sigmoid(self.aggr(torch.cat([X_f, X_b], dim=-1)))
154
+ X = z * X_f + (1 - z) * X_b
155
+ else:
156
+ raise ValueError(f"Invalid aggregation mode: {self.mode}")
157
+
158
+ return X
159
+
160
+
161
+ # ===========================
162
+ # Base Model Classes
163
+ # ===========================
164
+
165
+ class GeneMambaPreTrainedModel(PreTrainedModel):
166
+ """
167
+ Base class for all GeneMamba models.
168
+ Handles weight initialization and provides standard model interfaces.
169
+ """
170
+
171
+ config_class = GeneMambaConfig
172
+ base_model_prefix = "genemamba"
173
+ supports_gradient_checkpointing = True
174
+
175
+ def _init_weights(self, module):
176
+ """Initialize module weights."""
177
+ if isinstance(module, nn.Linear):
178
+ normal_(module.weight, std=self.config.initializer_range)
179
+ if module.bias is not None:
180
+ constant_(module.bias, 0.0)
181
+ elif isinstance(module, nn.Embedding):
182
+ normal_(module.weight, std=self.config.initializer_range)
183
+ if module.padding_idx is not None:
184
+ module.weight.data[module.padding_idx].zero_()
185
+ elif isinstance(module, nn.LayerNorm):
186
+ constant_(module.bias, 0.0)
187
+ constant_(module.weight, 1.0)
188
+
189
+
190
+ class GeneMambaModel(GeneMambaPreTrainedModel):
191
+ """
192
+ GeneMamba backbone model - outputs cell embeddings and hidden states.
193
+ This is the core model used by task-specific heads.
194
+
195
+ Args:
196
+ config (GeneMambaConfig): Model configuration class.
197
+ """
198
+
199
+ def __init__(self, config: GeneMambaConfig):
200
+ super().__init__(config)
201
+ self.config = config
202
+
203
+ # Embedding layer
204
+ self.embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
205
+
206
+ # Mamba layers with bidirectional aggregation
207
+ self.mamba_mixer = MambaMixer(
208
+ mode=config.mamba_mode,
209
+ hidden_size=config.hidden_size,
210
+ num_hidden_layers=config.num_hidden_layers
211
+ )
212
+
213
+ # Final layer normalization
214
+ self.norm = RMSNorm(config.hidden_size)
215
+
216
+ self.apply(self._init_weights)
217
+
218
+ def get_input_embeddings(self) -> nn.Embedding:
219
+ """Return embedding layer."""
220
+ return self.embeddings
221
+
222
+ def set_input_embeddings(self, value: nn.Embedding):
223
+ """Set embedding layer."""
224
+ self.embeddings = value
225
+
226
+ def forward(
227
+ self,
228
+ input_ids: torch.Tensor,
229
+ attention_mask: Optional[torch.Tensor] = None,
230
+ output_hidden_states: bool = False,
231
+ ) -> GeneMambaModelOutput:
232
+ """
233
+ Args:
234
+ input_ids (torch.Tensor): Token indices of shape (batch_size, seq_len).
235
+ attention_mask (torch.Tensor, optional): Attention mask of shape (batch_size, seq_len).
236
+ output_hidden_states (bool): Whether to output hidden states from all layers.
237
+
238
+ Returns:
239
+ GeneMambaModelOutput: Contains last_hidden_state, pooled_embedding, etc.
240
+ """
241
+ # Get embeddings
242
+ hidden_states = self.embeddings(input_ids)
243
+
244
+ # Pass through Mamba layers
245
+ hidden_states = self.mamba_mixer(hidden_states, attention_mask)
246
+
247
+ # Apply final normalization
248
+ hidden_states = self.norm(hidden_states)
249
+
250
+ # Compute pooled embedding (cell representation)
251
+ if self.config.embedding_pooling == "CLS":
252
+ # Use first token (CLS)
253
+ pooled_embedding = hidden_states[:, 0, :]
254
+ elif self.config.embedding_pooling == "mean":
255
+ # Mean pooling over sequence
256
+ if attention_mask is not None:
257
+ mask = attention_mask.unsqueeze(-1).expand(hidden_states.shape).float()
258
+ pooled_embedding = (hidden_states * mask).sum(dim=1) / mask.sum(dim=1)
259
+ else:
260
+ pooled_embedding = hidden_states.mean(dim=1)
261
+ else:
262
+ raise ValueError(f"Unsupported embedding_pooling: {self.config.embedding_pooling}")
263
+
264
+ return GeneMambaModelOutput(
265
+ last_hidden_state=hidden_states,
266
+ pooled_embedding=pooled_embedding,
267
+ hidden_states=hidden_states if output_hidden_states else None,
268
+ embedding_pooling=self.config.embedding_pooling,
269
+ )
270
+
271
+
272
+ # ===========================
273
+ # Task-Specific Models
274
+ # ===========================
275
+
277
+ class GeneMambaForMaskedLM(GeneMambaPreTrainedModel):
278
+ """
279
+ GeneMamba model for masked language modeling (MLM).
280
+ Suitable for pretraining and domain adaptation.
281
+
282
+ Args:
283
+ config (GeneMambaConfig): Model configuration class.
284
+ """
285
+
286
+ def __init__(self, config: GeneMambaConfig):
287
+ super().__init__(config)
288
+ self.genemamba = GeneMambaModel(config)
289
+
290
+ # Language modeling head
291
+ self.lm_head = nn.Linear(config.hidden_size, config.vocab_size)
292
+
293
+ self.apply(self._init_weights)
294
+
295
+ def forward(
296
+ self,
297
+ input_ids: torch.Tensor,
298
+ attention_mask: Optional[torch.Tensor] = None,
299
+ labels: Optional[torch.Tensor] = None,
300
+ output_hidden_states: bool = False,
301
+ ) -> GeneMambaMaskedLMOutput:
302
+ """
303
+ Args:
304
+ input_ids (torch.Tensor): Token indices of shape (batch_size, seq_len).
305
+ attention_mask (torch.Tensor, optional): Attention mask.
306
+ labels (torch.Tensor, optional): Target token ids for MLM loss.
307
+ output_hidden_states (bool): Whether to output hidden states.
308
+
309
+ Returns:
310
+ GeneMambaMaskedLMOutput: Contains logits and optional loss.
311
+ """
312
+ outputs = self.genemamba(
313
+ input_ids=input_ids,
314
+ attention_mask=attention_mask,
315
+ output_hidden_states=output_hidden_states,
316
+ )
317
+
318
+ logits = self.lm_head(outputs.last_hidden_state)
319
+
320
+ loss = None
321
+ if labels is not None:
322
+ loss_fct = nn.CrossEntropyLoss()
323
+ loss = loss_fct(logits.view(-1, self.config.vocab_size), labels.view(-1))
324
+
325
+ return GeneMambaMaskedLMOutput(
326
+ loss=loss,
327
+ logits=logits,
328
+ hidden_states=outputs.hidden_states if output_hidden_states else None,
329
+ )
330
+
331
+
333
+ class GeneMambaForSequenceClassification(GeneMambaPreTrainedModel):
334
+ """
335
+ GeneMamba model for sequence classification tasks.
336
+ Ideal for cell type annotation, tissue classification, etc.
337
+
338
+ Args:
339
+ config (GeneMambaConfig): Model configuration class.
340
+ """
341
+
342
+ def __init__(self, config: GeneMambaConfig):
343
+ super().__init__(config)
344
+ self.num_labels = config.num_labels
345
+ self.config = config
346
+
347
+ self.genemamba = GeneMambaModel(config)
348
+
349
+ # Classification head
350
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
351
+ self.classifier = nn.Linear(config.hidden_size, config.num_labels)
352
+
353
+ self.apply(self._init_weights)
354
+
355
+ def forward(
356
+ self,
357
+ input_ids: torch.Tensor,
358
+ attention_mask: Optional[torch.Tensor] = None,
359
+ labels: Optional[torch.Tensor] = None,
360
+ output_hidden_states: bool = False,
361
+ ) -> GeneMambaSequenceClassifierOutput:
362
+ """
363
+ Args:
364
+ input_ids (torch.Tensor): Token indices of shape (batch_size, seq_len).
365
+ attention_mask (torch.Tensor, optional): Attention mask.
366
+ labels (torch.Tensor, optional): Class labels for classification loss.
367
+ output_hidden_states (bool): Whether to output hidden states.
368
+
369
+ Returns:
370
+ GeneMambaSequenceClassifierOutput: Contains logits, optional loss, and embedding.
371
+ """
372
+ outputs = self.genemamba(
373
+ input_ids=input_ids,
374
+ attention_mask=attention_mask,
375
+ output_hidden_states=output_hidden_states,
376
+ )
377
+
378
+ pooled_embedding = outputs.pooled_embedding
379
+ logits = self.classifier(self.dropout(pooled_embedding))
380
+
381
+ loss = None
382
+ if labels is not None:
383
+ loss_fct = nn.CrossEntropyLoss()
384
+ loss = loss_fct(logits, labels)
385
+
386
+ return GeneMambaSequenceClassifierOutput(
387
+ loss=loss,
388
+ logits=logits,
389
+ hidden_states=outputs.hidden_states if output_hidden_states else None,
390
+ pooled_embedding=pooled_embedding,
391
+ )
392
+
393
+
+ # Register the classes so that AutoModel / AutoModelForMaskedLM /
+ # AutoModelForSequenceClassification resolve to them when loaded with
+ # trust_remote_code=True (via PreTrainedModel.register_for_auto_class).
+ GeneMambaModel.register_for_auto_class("AutoModel")
+ GeneMambaForMaskedLM.register_for_auto_class("AutoModelForMaskedLM")
+ GeneMambaForSequenceClassification.register_for_auto_class("AutoModelForSequenceClassification")
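+
+
+ # Usage sketch for the classification head (illustrative; "path/to/finetuned-checkpoint"
+ # is a placeholder rather than a released artifact, and num_labels must match your labels):
+ #
+ #   model = GeneMambaForSequenceClassification.from_pretrained(
+ #       "path/to/finetuned-checkpoint", trust_remote_code=True
+ #   )
+ #   out = model(input_ids=batch_ids, labels=batch_labels)
+ #   loss, logits = out.loss, out.logits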
modeling_outputs.py ADDED
@@ -0,0 +1,81 @@
1
+ """
2
+ Custom ModelOutput classes for GeneMamba.
3
+ Defines the output structure for different GeneMamba tasks.
4
+ """
5
+
6
+ from dataclasses import dataclass
7
+ from typing import Optional, Tuple
8
+ import torch
9
+ from transformers.utils import ModelOutput
10
+
11
+
12
+ @dataclass
13
+ class GeneMambaModelOutput(ModelOutput):
14
+ """
15
+ Base output class for GeneMamba models.
16
+
17
+ Attributes:
18
+ last_hidden_state (torch.FloatTensor of shape (batch_size, sequence_length, hidden_size)):
19
+ Sequence of hidden-states at the output of the last layer of the model.
20
+
21
+ hidden_states (tuple(torch.FloatTensor), optional):
22
+ Hidden-states of the model at the output of each layer plus the initial embedding outputs.
23
+
24
+ pooled_embedding (torch.FloatTensor of shape (batch_size, hidden_size)):
25
+ Cell/sequence-level embedding (pooled representation) used for downstream tasks.
26
+ This is the recommended embedding to use for classification, clustering, etc.
27
+
28
+ embedding_pooling (str):
29
+ The pooling method used to generate pooled_embedding.
30
+ """
31
+
32
+ last_hidden_state: torch.FloatTensor = None
33
+ hidden_states: Optional[Tuple[torch.FloatTensor]] = None
34
+ pooled_embedding: torch.FloatTensor = None
35
+ embedding_pooling: str = "mean"
36
+
37
+
38
+ @dataclass
39
+ class GeneMambaSequenceClassifierOutput(ModelOutput):
40
+ """
41
+ Output class for GeneMamba sequence classification models.
42
+
43
+ Attributes:
44
+ loss (torch.FloatTensor of shape (), optional):
45
+ Classification loss (if labels were provided).
46
+
47
+ logits (torch.FloatTensor of shape (batch_size, num_labels)):
48
+ Classification scores (before softmax).
49
+
50
+ hidden_states (tuple(torch.FloatTensor), optional):
51
+ Hidden-states of the model at the output of each layer.
52
+
53
+ pooled_embedding (torch.FloatTensor of shape (batch_size, hidden_size), optional):
54
+ Cell embedding before classification head.
55
+ """
56
+
57
+ loss: Optional[torch.FloatTensor] = None
58
+ logits: torch.FloatTensor = None
59
+ hidden_states: Optional[Tuple[torch.FloatTensor]] = None
60
+ pooled_embedding: Optional[torch.FloatTensor] = None
61
+
62
+
63
+ @dataclass
64
+ class GeneMambaMaskedLMOutput(ModelOutput):
65
+ """
66
+ Output class for GeneMamba masked language modeling.
67
+
68
+ Attributes:
69
+ loss (torch.FloatTensor of shape (), optional):
70
+ MLM loss (if labels were provided).
71
+
72
+ logits (torch.FloatTensor of shape (batch_size, sequence_length, vocab_size)):
73
+ Prediction scores of the language modeling head.
74
+
75
+ hidden_states (tuple(torch.FloatTensor), optional):
76
+ Hidden-states of the model at the output of each layer.
77
+ """
78
+
79
+ loss: Optional[torch.FloatTensor] = None
80
+ logits: torch.FloatTensor = None
81
+ hidden_states: Optional[Tuple[torch.FloatTensor]] = None
special_tokens_map.json ADDED
@@ -0,0 +1,4 @@
1
+ {
2
+ "pad_token": "[PAD]",
3
+ "unk_token": "[UNK]"
4
+ }
tokenizer.json ADDED
The diff for this file is too large to render. See raw diff
 
tokenizer_assets/gene_tokenizer.json ADDED
The diff for this file is too large to render. See raw diff
 
tokenizer_assets/id2symbol.pkl ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:3d5090ff562a77a03b19c37f6a010d639b8d64b1624db2e9a7c3291f9d389293
3
+ size 634589
tokenizer_assets/symbol2id.pkl ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:ecc7193b1f549e513903ba37410788632252a2dda4d07876a1d91730d8697dbe
3
+ size 526232
tokenizer_config.json ADDED
@@ -0,0 +1,8 @@
1
+ {
2
+ "added_tokens_decoder": {},
3
+ "clean_up_tokenization_spaces": true,
4
+ "model_max_length": 1000000000000000019884624838656,
5
+ "pad_token": "[PAD]",
6
+ "tokenizer_class": "PreTrainedTokenizerFast",
7
+ "unk_token": "[UNK]"
8
+ }