Add preprocess docs, examples, tokenizer mapping assets

- README.md +90 -0
- config.json +22 -0
- configuration_genemamba.py +97 -0
- examples/0_preprocess_to_input_ids.py +75 -0
- examples/1_extract_embeddings.py +150 -0
- examples/__pycache__/0_preprocess_to_input_ids.cpython-39.pyc +0 -0
- model.safetensors +3 -0
- modeling_genemamba.py +395 -0
- modeling_outputs.py +81 -0
- special_tokens_map.json +4 -0
- tokenizer.json +0 -0
- tokenizer_assets/gene_tokenizer.json +0 -0
- tokenizer_assets/id2symbol.pkl +3 -0
- tokenizer_assets/symbol2id.pkl +3 -0
- tokenizer_config.json +8 -0
README.md
ADDED
@@ -0,0 +1,90 @@
---
library_name: transformers
tags:
- genomics
- single-cell
- mamba
- biology
pipeline_tag: feature-extraction
---

# GeneMamba2-24l-512d

This repository contains a GeneMamba checkpoint plus the full set of usage assets:
- model weights (`model.safetensors`)
- custom modeling/config files for `trust_remote_code=True`
- a preprocessing example from `h5ad` to `input_ids`
- tokenizer assets and id-mapping files

## 1) Input format (very important)

GeneMamba input is a sequence of **ranked gene token IDs** per cell:
1. Start from one cell's expression vector
2. Keep genes with expression > 0
3. Sort genes by expression, descending
4. Convert each gene ID (Ensembl, e.g. `ENSG00000000003`) to a token ID
5. Use the resulting list as `input_ids` (a worked sketch follows below)

Each sample is one list of integers:

```python
{"input_ids": [145, 2088, 531, 91, ...]}
```

For batch input, the shape is typically `(batch_size, seq_len)` after padding/truncation.
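
The ranking above reduces to a few lines of NumPy. A minimal sketch (the gene names, expression values, and vocab below are made up for illustration):

```python
import numpy as np

# Toy data: one cell's expression over five genes (illustrative only)
expr = np.array([0.0, 5.2, 1.1, 0.0, 3.7])
gene_ids = ["ENSG_A", "ENSG_B", "ENSG_C", "ENSG_D", "ENSG_E"]
vocab = {"ENSG_B": 145, "ENSG_C": 531, "ENSG_E": 2088}  # gene -> token id

nonzero = np.flatnonzero(expr > 0)             # step 2: keep expressed genes
order = nonzero[np.argsort(-expr[nonzero])]    # step 3: sort by expression, descending
input_ids = [vocab[gene_ids[i]] for i in order if gene_ids[i] in vocab]  # step 4
print(input_ids)  # [145, 2088, 531]
```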

## 2) Where the tokenizer and id mappings come from

- Main tokenizer used for model inference: `tokenizer.json`
- Original full tokenizer table: `tokenizer_assets/gene_tokenizer.json`
- Gene symbol -> token id mapping: `tokenizer_assets/symbol2id.pkl`
- Token id -> gene symbol mapping: `tokenizer_assets/id2symbol.pkl`

Special tokens:
- `[UNK]` = 0
- `[PAD]` = 1
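
If your data carries gene symbols rather than Ensembl IDs, the pickle mappings bridge the two. A short sketch of loading them (assumes the files were downloaded locally; the exact keys depend on the pickled tables):

```python
import pickle

with open("tokenizer_assets/symbol2id.pkl", "rb") as f:
    symbol2id = pickle.load(f)
with open("tokenizer_assets/id2symbol.pkl", "rb") as f:
    id2symbol = pickle.load(f)

# Round-trip one entry to confirm the two tables are inverses
symbol = next(iter(symbol2id))
assert id2symbol[symbol2id[symbol]] == symbol
```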
## 3) Preprocess your data

See the script:
- `examples/0_preprocess_to_input_ids.py`

Example:

```bash
python examples/0_preprocess_to_input_ids.py \
  --h5ad /path/to/your_data.h5ad \
  --tokenizer_json tokenizer.json \
  --output_arrow ./my_data/sorted_gene_token_ids.arrow
```

The output Arrow file has a single column, `input_ids`.
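
The file is written in Arrow stream format, so it can be sanity-checked by reading it back with pyarrow; a quick sketch:

```python
import pyarrow as pa

with pa.OSFile("./my_data/sorted_gene_token_ids.arrow", "rb") as source:
    table = pa.ipc.open_stream(source).read_all()

print(table.num_rows, "cells")
print(table.column("input_ids")[0])  # ranked token ids of the first cell
```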
## 4) Load model and extract embedding

```python
from transformers import AutoModel, AutoTokenizer

model = AutoModel.from_pretrained(
    "mineself2016/GeneMamba2-24l-512d",
    trust_remote_code=True
)

tokenizer = AutoTokenizer.from_pretrained(
    "mineself2016/GeneMamba2-24l-512d",
    trust_remote_code=True
)
```
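
Once loaded, a forward pass yields the cell embedding directly; a short sketch (the `pooled_embedding` field is defined by this repo's custom modeling code):

```python
import torch

input_ids = torch.tensor([[145, 2088, 531, 91]])  # one preprocessed cell

model.eval()
with torch.no_grad():
    out = model(input_ids)

cell_embedding = out.pooled_embedding  # shape: (1, 512)
```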
A more complete example:
- `examples/1_extract_embeddings.py`

## 5) Source of preprocessing logic

The preprocessing/tokenization pipeline is aligned with the assets from:
- `/project/zhiwei/cq5/PythonWorkSpace/gene_mamba`

Key references used:
- tokenizer: `gene_tokenizer.json`
- mappings: `symbol2id.pkl`, `id2symbol.pkl`
- dataset build logic (Arrow + `input_ids`): `utils.py` (`build_dataset`)
config.json
ADDED
@@ -0,0 +1,22 @@
{
  "model_type": "genemamba",
  "architectures": [
    "GeneMambaModel"
  ],
  "vocab_size": 25426,
  "max_position_embeddings": 2048,
  "hidden_size": 512,
  "num_hidden_layers": 24,
  "intermediate_size": 2048,
  "hidden_dropout_prob": 0.1,
  "initializer_range": 0.02,
  "mamba_mode": "gate",
  "embedding_pooling": "mean",
  "num_labels": 2,
  "pad_token_id": 1,
  "eos_token_id": 2,
  "bos_token_id": 0,
  "use_cache": true,
  "torch_dtype": "float32",
  "transformers_version": "4.40.2"
}
configuration_genemamba.py
ADDED
@@ -0,0 +1,97 @@
"""
Configuration for the GeneMamba model.
Defines all hyperparameters and settings for the GeneMamba architecture.
"""

from typing import Optional

from transformers import PretrainedConfig


class GeneMambaConfig(PretrainedConfig):
    """
    Configuration class for the GeneMamba model.

    Stores the configuration of a GeneMamba model, inheriting from PretrainedConfig.
    It can be used to instantiate models from pretrained checkpoints or to
    customize model initialization.

    Args:
        vocab_size (int, optional, defaults to 25426):
            Vocabulary size of the model, i.e. the number of gene tokens (Ensembl gene IDs).
        hidden_size (int, optional, defaults to 512):
            Dimensionality of the hidden/embedding layers (d_model in Mamba).
        num_hidden_layers (int, optional, defaults to 24):
            Number of Mamba layers.
        intermediate_size (int, optional, defaults to 2048):
            Dimensionality of intermediate representations in the MLP.
        max_position_embeddings (int, optional, defaults to 2048):
            Maximum sequence length.
        hidden_dropout_prob (float, optional, defaults to 0.1):
            Dropout probability for hidden states.
        initializer_range (float, optional, defaults to 0.02):
            Standard deviation of the truncated normal initializer.
        mamba_mode (str, optional, defaults to "gate"):
            Aggregation mode for the bidirectional Mamba layers.
            Options: "mean", "sum", "concat", "gate".
        embedding_pooling (str, optional, defaults to "mean"):
            Method for pooling token states into a cell embedding.
            Options: "CLS", "mean", "weighted".
        num_labels (int, optional, defaults to 2):
            Number of labels for sequence classification tasks.
        pad_token_id (int, optional, defaults to 1):
            Token ID used for padding.
        bos_token_id (int, optional, defaults to None):
            Token ID for beginning of sequence.
        eos_token_id (int, optional, defaults to None):
            Token ID for end of sequence.
    """

    model_type = "genemamba"

    def __init__(
        self,
        vocab_size: int = 25426,
        hidden_size: int = 512,
        num_hidden_layers: int = 24,
        intermediate_size: int = 2048,
        max_position_embeddings: int = 2048,
        hidden_dropout_prob: float = 0.1,
        initializer_range: float = 0.02,
        mamba_mode: str = "gate",
        embedding_pooling: str = "mean",
        num_labels: int = 2,
        pad_token_id: int = 1,
        bos_token_id: Optional[int] = None,
        eos_token_id: Optional[int] = None,
        **kwargs,
    ):
        super().__init__(
            pad_token_id=pad_token_id,
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
            **kwargs,
        )

        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.num_hidden_layers = num_hidden_layers
        self.intermediate_size = intermediate_size
        self.max_position_embeddings = max_position_embeddings
        self.hidden_dropout_prob = hidden_dropout_prob
        self.initializer_range = initializer_range
        self.mamba_mode = mamba_mode
        self.embedding_pooling = embedding_pooling
        self.num_labels = num_labels
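A quick sketch of how this config plugs into the standard save/load round-trip (the directory path is illustrative):

```python
from configuration_genemamba import GeneMambaConfig

config = GeneMambaConfig(hidden_size=512, num_hidden_layers=24, mamba_mode="gate")
config.save_pretrained("./genemamba-config")        # writes config.json
reloaded = GeneMambaConfig.from_pretrained("./genemamba-config")
assert reloaded.mamba_mode == "gate"
```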
examples/0_preprocess_to_input_ids.py
ADDED
@@ -0,0 +1,75 @@
import argparse
import json
from pathlib import Path

import numpy as np
import pandas as pd
import scanpy as sc
import pyarrow as pa


def load_vocab(tokenizer_json_path: str):
    with open(tokenizer_json_path, "r") as f:
        tokenizer = json.load(f)
    vocab = tokenizer["model"]["vocab"]
    pad_id = vocab.get("[PAD]", 1)
    unk_id = vocab.get("[UNK]", 0)
    return vocab, pad_id, unk_id


def ranked_gene_ids_for_cell(expr_values, gene_names, vocab):
    nonzero_idx = np.where(expr_values > 0)[0]
    if len(nonzero_idx) == 0:
        return []

    genes = np.array(gene_names)[nonzero_idx]
    values = expr_values[nonzero_idx]

    order = np.argsort(-values)
    ranked_genes = genes[order]

    token_ids = [vocab[g] for g in ranked_genes if g in vocab]
    return token_ids


def main():
    parser = argparse.ArgumentParser(description="Convert h5ad to GeneMamba input_ids (Arrow)")
    parser.add_argument("--h5ad", required=True, help="Input h5ad file")
    parser.add_argument("--tokenizer_json", required=True, help="Path to tokenizer.json or gene_tokenizer.json")
    parser.add_argument("--output_arrow", required=True, help="Output Arrow file path")
    parser.add_argument("--max_cells", type=int, default=None, help="Optional: process only the first N cells")
    args = parser.parse_args()

    adata = sc.read_h5ad(args.h5ad)
    vocab, _, _ = load_vocab(args.tokenizer_json)

    gene_names = list(adata.var_names)
    n_cells = adata.n_obs if args.max_cells is None else min(args.max_cells, adata.n_obs)

    rows = []
    X = adata.X

    for i in range(n_cells):
        row = X[i]
        if hasattr(row, "toarray"):
            expr = row.toarray().ravel()
        else:
            expr = np.asarray(row).ravel()

        token_ids = ranked_gene_ids_for_cell(expr, gene_names, vocab)
        rows.append(token_ids)

    df = pd.DataFrame({"input_ids": rows})
    table = pa.Table.from_pandas(df)

    output_path = Path(args.output_arrow)
    output_path.parent.mkdir(parents=True, exist_ok=True)
    with pa.OSFile(str(output_path), "wb") as sink:
        with pa.ipc.new_stream(sink, table.schema) as writer:
            writer.write_table(table)

    print(f"Saved {len(rows)} cells to {output_path}")


if __name__ == "__main__":
    main()
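The script stores variable-length id lists, so batching them for the model needs `[PAD]` = 1 padding plus an attention mask. A minimal collate sketch (the function name and toy ids are illustrative):

```python
import torch

PAD_ID = 1  # [PAD] token id from the tokenizer

def collate(batch_of_id_lists, max_len=2048):
    """Truncate, pad with PAD_ID, and build an attention mask."""
    seqs = [ids[:max_len] for ids in batch_of_id_lists]
    width = max(len(s) for s in seqs)
    input_ids = torch.full((len(seqs), width), PAD_ID, dtype=torch.long)
    attention_mask = torch.zeros((len(seqs), width), dtype=torch.long)
    for i, s in enumerate(seqs):
        input_ids[i, : len(s)] = torch.tensor(s, dtype=torch.long)
        attention_mask[i, : len(s)] = 1
    return input_ids, attention_mask

ids, mask = collate([[145, 2088, 531], [91, 7]])
print(ids)   # second row padded with 1 ([PAD])
print(mask)  # 1 for real tokens, 0 for padding
```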
examples/1_extract_embeddings.py
ADDED
@@ -0,0 +1,150 @@
"""
Phase 1: Extract Cell Embeddings
Demonstrates how to load GeneMamba and extract cell embeddings for downstream analysis.

Usage:
    python examples/1_extract_embeddings.py
"""

import torch
import numpy as np
from transformers import AutoTokenizer, AutoModel


def main():
    print("=" * 80)
    print("GeneMamba Phase 1: Extract Cell Embeddings")
    print("=" * 80)

    # ============================================================
    # Step 1: Load pretrained model and tokenizer
    # ============================================================
    print("\n[Step 1] Loading model and tokenizer...")

    # For this example we use a local model path.
    # In practice, use the Hub repo id, e.g. "mineself2016/GeneMamba2-24l-512d".
    model_name = "GeneMamba-24l-512d"

    try:
        tokenizer = AutoTokenizer.from_pretrained(
            model_name,
            trust_remote_code=True,
            local_files_only=True  # Try local files first
        )
        model = AutoModel.from_pretrained(
            model_name,
            trust_remote_code=True,
            local_files_only=True
        )
    except Exception as e:
        print(f"Note: Could not load from '{model_name}': {e}")
        print("Using a randomly initialized model for demonstration...")

        # Fall back to an uninitialized model when no checkpoint is available
        from configuration_genemamba import GeneMambaConfig
        from modeling_genemamba import GeneMambaModel

        config = GeneMambaConfig(
            vocab_size=25426,
            hidden_size=512,
            num_hidden_layers=24,
            embedding_pooling="mean",
        )
        model = GeneMambaModel(config)
        tokenizer = None

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = model.to(device)
    model.eval()

    print(f"✓ Model loaded on device: {device}")
    print(f"✓ Model config: hidden_size={model.config.hidden_size}, "
          f"num_layers={model.config.num_hidden_layers}")

    # ============================================================
    # Step 2: Prepare simulated single-cell data
    # ============================================================
    print("\n[Step 2] Preparing sample data...")

    batch_size = 8
    seq_len = 2048
    vocab_size = 25426

    # Simulate ranked gene sequences.
    # In practice, these come from your scRNA-seq data,
    # with genes ranked by expression (highest first).
    input_ids = torch.randint(2, vocab_size, (batch_size, seq_len)).to(device)

    print("✓ Created sample input:")
    print(f"  - Batch size: {batch_size}")
    print(f"  - Sequence length: {seq_len}")
    print(f"  - Input shape: {input_ids.shape}")

    # ============================================================
    # Step 3: Inference - extract embeddings
    # ============================================================
    print("\n[Step 3] Extracting cell embeddings...")

    with torch.no_grad():
        outputs = model(input_ids, output_hidden_states=False)

    # The pooled embedding is the cell representation
    cell_embeddings = outputs.pooled_embedding

    print("✓ Extraction complete!")
    print(f"  - Cell embeddings shape: {cell_embeddings.shape}")
    print(f"  - Pooling method used: {outputs.embedding_pooling}")
    print(f"  - Embedding dtype: {cell_embeddings.dtype}")

    # ============================================================
    # Step 4: Example downstream analyses
    # ============================================================
    print("\n[Step 4] Example downstream uses...")

    # Example 1: Clustering (KMeans)
    from sklearn.cluster import KMeans
    n_clusters = 3
    kmeans = KMeans(n_clusters=n_clusters, n_init=10)
    clusters = kmeans.fit_predict(cell_embeddings.cpu().numpy())
    print(f"✓ Clustering: Assigned {len(np.unique(clusters))} clusters")

    # Example 2: Dimensionality reduction (PCA)
    from sklearn.decomposition import PCA
    pca = PCA(n_components=2)
    embedding_2d = pca.fit_transform(cell_embeddings.cpu().numpy())
    print(f"✓ PCA reduction: {cell_embeddings.shape} → {embedding_2d.shape}")

    # Example 3: Similarity search
    # Find the cell most similar to cell 0 (excluding cell 0 itself)
    similarities = torch.nn.functional.cosine_similarity(
        cell_embeddings[0:1],
        cell_embeddings
    )
    similarities[0] = -float("inf")
    most_similar_idx = torch.argmax(similarities).item()
    print(f"✓ Similarity search: Most similar cell to cell 0 is cell {most_similar_idx} "
          f"(similarity: {similarities[most_similar_idx]:.4f})")

    # ============================================================
    # Step 5: Embedding statistics
    # ============================================================
    print("\n[Step 5] Embedding statistics:")
    print(f"  - Mean-vector norm: {cell_embeddings.mean(dim=0).norm():.4f}")
    print(f"  - Mean std:         {cell_embeddings.std(dim=0).mean():.4f}")
    print(f"  - Min:              {cell_embeddings.min():.4f}")
    print(f"  - Max:              {cell_embeddings.max():.4f}")

    # ============================================================
    # Step 6: Save embeddings (optional)
    # ============================================================
    print("\n[Step 6] Saving embeddings...")

    np.save("cell_embeddings.npy", cell_embeddings.cpu().numpy())
    print("✓ Embeddings saved to 'cell_embeddings.npy'")

    print("\n" + "=" * 80)
    print("Phase 1 Complete!")
    print("=" * 80)

    return model, cell_embeddings


if __name__ == "__main__":
    model, embeddings = main()
examples/__pycache__/0_preprocess_to_input_ids.cpython-39.pyc
ADDED
Binary file (2.64 kB)
model.safetensors
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:ccb1fcb0ee4b3ea2013099b9b187455e160d3b66b76c606715231b70b13c2784
size 262998656
modeling_genemamba.py
ADDED
@@ -0,0 +1,395 @@
"""
PyTorch implementation of the GeneMamba model for Hugging Face Transformers.
Includes the backbone model and task-specific heads for downstream tasks.
"""

import logging
from typing import Optional

import torch
import torch.nn as nn
from torch.nn.init import constant_, normal_

from transformers import PreTrainedModel

from mamba_ssm import Mamba
from mamba_ssm.ops.triton.layer_norm import RMSNorm

from .configuration_genemamba import GeneMambaConfig
from .modeling_outputs import (
    GeneMambaMaskedLMOutput,
    GeneMambaModelOutput,
    GeneMambaSequenceClassifierOutput,
)

logger = logging.getLogger(__name__)


# ===========================
# Core Architecture Components
# ===========================

class EncoderLayer(nn.Module):
    """
    Single Mamba encoder layer with a residual connection.
    Applies a Mamba layer, then adds the input back.

    Args:
        hidden_size (int): Dimension of hidden states.
    """

    def __init__(self, hidden_size: int):
        super().__init__()
        self.mamba = Mamba(d_model=hidden_size, d_state=64, d_conv=4, expand=2)

    def forward(self, X: torch.Tensor) -> torch.Tensor:
        """
        Args:
            X (torch.Tensor): Input tensor of shape (batch_size, seq_len, hidden_size).

        Returns:
            torch.Tensor: Output after the Mamba layer and residual connection.
        """
        return self.mamba(X) + X


class MambaMixer(nn.Module):
    """
    Stack of Mamba encoder layers with bidirectional processing and aggregation.
    Processes sequences in both forward and reverse directions, then aggregates.

    Args:
        mode (str): Aggregation mode. Options: "mean", "sum", "concat", "gate".
        hidden_size (int): Dimension of hidden states.
        num_hidden_layers (int): Number of Mamba layers.
    """

    def __init__(
        self,
        mode: str = "gate",
        hidden_size: int = 512,
        num_hidden_layers: int = 24
    ):
        super().__init__()
        self.mode = mode
        self.hidden_size = hidden_size

        # Create Mamba layers
        self.layers = nn.ModuleList(
            [EncoderLayer(hidden_size) for _ in range(num_hidden_layers)]
        )

        # Aggregation module for modes that mix forward/backward states
        if mode in ["concat", "gate"]:
            self.aggr = nn.Linear(hidden_size * 2, hidden_size)

    def flip_sequence(self, X: torch.Tensor, mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        """
        Reverse a sequence based on its actual length (ignoring padding).

        Args:
            X (torch.Tensor): Input tensor of shape (batch_size, seq_len, hidden_size).
            mask (torch.Tensor, optional): Boolean padding mask of shape
                (batch_size, seq_len); True marks padding positions.

        Returns:
            torch.Tensor: Reversed tensor.
        """
        batch_size, seq_length, embedding_dim = X.size()

        if mask is None:
            # No padding information: simple full flip
            return X.flip([1])

        # Flip only the non-padded prefix of each sequence
        lengths = (~mask).sum(dim=1)
        pos_tensor = torch.arange(seq_length, device=X.device).unsqueeze(0).expand(batch_size, -1)
        flip_mask = pos_tensor < lengths.unsqueeze(1)
        reversed_positions = torch.where(
            flip_mask,
            lengths.unsqueeze(1) - 1 - pos_tensor,
            pos_tensor,
        )

        return torch.gather(X, 1, reversed_positions.unsqueeze(-1).expand(-1, -1, embedding_dim))

    def forward(
        self,
        X: torch.Tensor,
        padding_mask: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        """
        Process a sequence through the bidirectional Mamba layers.

        Args:
            X (torch.Tensor): Input tensor of shape (batch_size, seq_len, hidden_size).
            padding_mask (torch.Tensor, optional): Boolean padding mask (True = padding).

        Returns:
            torch.Tensor: Output after all layers and aggregation.
        """
        for layer in self.layers:
            # Flip sequence for the reverse pass
            X_flip = self.flip_sequence(X, padding_mask)

            # Forward and reverse passes share the same layer weights
            X_f = layer(X)
            X_b = layer(X_flip)

            # Flip the reverse output back to the original order
            X_b = self.flip_sequence(X_b, padding_mask)

            # Aggregate forward and reverse states
            if self.mode == "mean":
                X = (X_f + X_b) / 2
            elif self.mode == "sum":
                X = X_f + X_b
            elif self.mode == "concat":
                X = self.aggr(torch.cat([X_f, X_b], dim=-1))
            elif self.mode == "gate":
                z = torch.sigmoid(self.aggr(torch.cat([X_f, X_b], dim=-1)))
                X = z * X_f + (1 - z) * X_b
            else:
                raise ValueError(f"Invalid aggregation mode: {self.mode}")

        return X


# ===========================
# Base Model Classes
# ===========================

class GeneMambaPreTrainedModel(PreTrainedModel):
    """
    Base class for all GeneMamba models.
    Handles weight initialization and provides the standard model interfaces.
    """

    config_class = GeneMambaConfig
    base_model_prefix = "genemamba"
    supports_gradient_checkpointing = True

    def _init_weights(self, module):
        """Initialize module weights."""
        if isinstance(module, nn.Linear):
            normal_(module.weight, std=self.config.initializer_range)
            if module.bias is not None:
                constant_(module.bias, 0.0)
        elif isinstance(module, nn.Embedding):
            normal_(module.weight, std=self.config.initializer_range)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()
        elif isinstance(module, nn.LayerNorm):
            constant_(module.bias, 0.0)
            constant_(module.weight, 1.0)


class GeneMambaModel(GeneMambaPreTrainedModel):
    """
    GeneMamba backbone model; outputs cell embeddings and hidden states.
    This is the core model used by the task-specific heads.

    Args:
        config (GeneMambaConfig): Model configuration class.
    """

    def __init__(self, config: GeneMambaConfig):
        super().__init__(config)
        self.config = config

        # Embedding layer
        self.embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)

        # Mamba layers with bidirectional aggregation
        self.mamba_mixer = MambaMixer(
            mode=config.mamba_mode,
            hidden_size=config.hidden_size,
            num_hidden_layers=config.num_hidden_layers,
        )

        # Final layer normalization
        self.norm = RMSNorm(config.hidden_size)

        self.apply(self._init_weights)

    def get_input_embeddings(self) -> nn.Embedding:
        """Return the embedding layer."""
        return self.embeddings

    def set_input_embeddings(self, value: nn.Embedding):
        """Set the embedding layer."""
        self.embeddings = value

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        output_hidden_states: bool = False,
    ) -> GeneMambaModelOutput:
        """
        Args:
            input_ids (torch.Tensor): Token indices of shape (batch_size, seq_len).
            attention_mask (torch.Tensor, optional): Attention mask of shape
                (batch_size, seq_len); 1 for real tokens, 0 for padding.
            output_hidden_states (bool): Whether to output hidden states.

        Returns:
            GeneMambaModelOutput: Contains last_hidden_state, pooled_embedding, etc.
        """
        # Get embeddings
        hidden_states = self.embeddings(input_ids)

        # MambaMixer expects a boolean padding mask (True = padding),
        # so convert from the HF attention-mask convention (1 = real token)
        padding_mask = (attention_mask == 0) if attention_mask is not None else None

        # Pass through Mamba layers
        hidden_states = self.mamba_mixer(hidden_states, padding_mask)

        # Apply final normalization
        hidden_states = self.norm(hidden_states)

        # Compute pooled embedding (cell representation)
        if self.config.embedding_pooling == "CLS":
            # Use the first token (CLS)
            pooled_embedding = hidden_states[:, 0, :]
        elif self.config.embedding_pooling == "mean":
            # Mean pooling over non-padded positions
            if attention_mask is not None:
                mask = attention_mask.unsqueeze(-1).expand(hidden_states.shape).float()
                pooled_embedding = (hidden_states * mask).sum(dim=1) / mask.sum(dim=1)
            else:
                pooled_embedding = hidden_states.mean(dim=1)
        else:
            raise ValueError(f"Unsupported embedding_pooling: {self.config.embedding_pooling}")

        return GeneMambaModelOutput(
            last_hidden_state=hidden_states,
            pooled_embedding=pooled_embedding,
            hidden_states=hidden_states if output_hidden_states else None,
            embedding_pooling=self.config.embedding_pooling,
        )


# ===========================
# Task-Specific Models
# ===========================

class GeneMambaForMaskedLM(GeneMambaPreTrainedModel):
    """
    GeneMamba model for masked language modeling (MLM).
    Suitable for pretraining and domain adaptation.

    Args:
        config (GeneMambaConfig): Model configuration class.
    """

    def __init__(self, config: GeneMambaConfig):
        super().__init__(config)
        self.genemamba = GeneMambaModel(config)

        # Language modeling head
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size)

        self.apply(self._init_weights)

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        output_hidden_states: bool = False,
    ) -> GeneMambaMaskedLMOutput:
        """
        Args:
            input_ids (torch.Tensor): Token indices of shape (batch_size, seq_len).
            attention_mask (torch.Tensor, optional): Attention mask.
            labels (torch.Tensor, optional): Target token ids for the MLM loss.
            output_hidden_states (bool): Whether to output hidden states.

        Returns:
            GeneMambaMaskedLMOutput: Contains logits and an optional loss.
        """
        outputs = self.genemamba(
            input_ids=input_ids,
            attention_mask=attention_mask,
            output_hidden_states=output_hidden_states,
        )

        logits = self.lm_head(outputs.last_hidden_state)

        loss = None
        if labels is not None:
            loss_fct = nn.CrossEntropyLoss()
            loss = loss_fct(logits.view(-1, self.config.vocab_size), labels.view(-1))

        return GeneMambaMaskedLMOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states if output_hidden_states else None,
        )


class GeneMambaForSequenceClassification(GeneMambaPreTrainedModel):
    """
    GeneMamba model for sequence classification tasks.
    Suited to cell type annotation, tissue classification, etc.

    Args:
        config (GeneMambaConfig): Model configuration class.
    """

    def __init__(self, config: GeneMambaConfig):
        super().__init__(config)
        self.num_labels = config.num_labels
        self.config = config

        self.genemamba = GeneMambaModel(config)

        # Classification head
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.classifier = nn.Linear(config.hidden_size, config.num_labels)

        self.apply(self._init_weights)

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        output_hidden_states: bool = False,
    ) -> GeneMambaSequenceClassifierOutput:
        """
        Args:
            input_ids (torch.Tensor): Token indices of shape (batch_size, seq_len).
            attention_mask (torch.Tensor, optional): Attention mask.
            labels (torch.Tensor, optional): Class labels for the classification loss.
            output_hidden_states (bool): Whether to output hidden states.

        Returns:
            GeneMambaSequenceClassifierOutput: Contains logits, optional loss, and the embedding.
        """
        outputs = self.genemamba(
            input_ids=input_ids,
            attention_mask=attention_mask,
            output_hidden_states=output_hidden_states,
        )

        pooled_embedding = outputs.pooled_embedding
        logits = self.classifier(self.dropout(pooled_embedding))

        loss = None
        if labels is not None:
            loss_fct = nn.CrossEntropyLoss()
            loss = loss_fct(logits, labels)

        return GeneMambaSequenceClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states if output_hidden_states else None,
            pooled_embedding=pooled_embedding,
        )


# Register the custom classes so Auto* loading works with trust_remote_code=True
GeneMambaModel.register_for_auto_class("AutoModel")
GeneMambaForMaskedLM.register_for_auto_class("AutoModelForMaskedLM")
GeneMambaForSequenceClassification.register_for_auto_class("AutoModelForSequenceClassification")
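The length-aware flip in `MambaMixer.flip_sequence` is the subtle part: only the non-padded prefix of each sequence is reversed, so padding stays in place. A tiny self-contained sketch of the same index arithmetic on a toy batch:

```python
import torch

X = torch.arange(1, 6, dtype=torch.float32).view(1, 5, 1)    # positions 1..5
mask = torch.tensor([[False, False, False, True, True]])      # last two are padding

lengths = (~mask).sum(dim=1)                                  # -> tensor([3])
pos = torch.arange(5).unsqueeze(0)
flip = pos < lengths.unsqueeze(1)
rev = torch.where(flip, lengths.unsqueeze(1) - 1 - pos, pos)
out = torch.gather(X, 1, rev.unsqueeze(-1))

print(out.squeeze())  # tensor([3., 2., 1., 4., 5.]) -- prefix reversed, pads untouched
```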
modeling_outputs.py
ADDED
@@ -0,0 +1,81 @@
"""
Custom ModelOutput classes for GeneMamba.
Defines the output structure for the different GeneMamba tasks.
"""

from dataclasses import dataclass
from typing import Optional, Tuple

import torch
from transformers.utils import ModelOutput


@dataclass
class GeneMambaModelOutput(ModelOutput):
    """
    Base output class for GeneMamba models.

    Attributes:
        last_hidden_state (torch.FloatTensor of shape (batch_size, sequence_length, hidden_size)):
            Sequence of hidden states at the output of the last layer of the model.
        hidden_states (tuple(torch.FloatTensor), optional):
            Hidden states of the model at the output of each layer plus the initial embedding outputs.
        pooled_embedding (torch.FloatTensor of shape (batch_size, hidden_size)):
            Cell/sequence-level embedding (pooled representation) used for downstream tasks.
            This is the recommended embedding for classification, clustering, etc.
        embedding_pooling (str):
            The pooling method used to generate pooled_embedding.
    """

    last_hidden_state: torch.FloatTensor = None
    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    pooled_embedding: torch.FloatTensor = None
    embedding_pooling: str = "mean"


@dataclass
class GeneMambaSequenceClassifierOutput(ModelOutput):
    """
    Output class for GeneMamba sequence classification models.

    Attributes:
        loss (torch.FloatTensor of shape (), optional):
            Classification loss (if labels were provided).
        logits (torch.FloatTensor of shape (batch_size, num_labels)):
            Classification scores (before softmax).
        hidden_states (tuple(torch.FloatTensor), optional):
            Hidden states of the model at the output of each layer.
        pooled_embedding (torch.FloatTensor of shape (batch_size, hidden_size), optional):
            Cell embedding before the classification head.
    """

    loss: Optional[torch.FloatTensor] = None
    logits: torch.FloatTensor = None
    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    pooled_embedding: Optional[torch.FloatTensor] = None


@dataclass
class GeneMambaMaskedLMOutput(ModelOutput):
    """
    Output class for GeneMamba masked language modeling.

    Attributes:
        loss (torch.FloatTensor of shape (), optional):
            MLM loss (if labels were provided).
        logits (torch.FloatTensor of shape (batch_size, sequence_length, vocab_size)):
            Prediction scores of the language modeling head.
        hidden_states (tuple(torch.FloatTensor), optional):
            Hidden states of the model at the output of each layer.
    """

    loss: Optional[torch.FloatTensor] = None
    logits: torch.FloatTensor = None
    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
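Since these classes inherit from transformers' ModelOutput, fields are reachable both as attributes and by key (None-valued fields are dropped from the mapping); a quick sketch:

```python
import torch
from modeling_outputs import GeneMambaModelOutput

out = GeneMambaModelOutput(
    last_hidden_state=torch.zeros(1, 4, 512),
    pooled_embedding=torch.zeros(1, 512),
)
assert out.pooled_embedding is out["pooled_embedding"]  # dict-style access works too
print(list(out.keys()))  # non-None fields only
```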
special_tokens_map.json
ADDED
@@ -0,0 +1,4 @@
{
  "pad_token": "[PAD]",
  "unk_token": "[UNK]"
}
tokenizer.json
ADDED
Diff too large to render.
tokenizer_assets/gene_tokenizer.json
ADDED
Diff too large to render.
tokenizer_assets/id2symbol.pkl
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:3d5090ff562a77a03b19c37f6a010d639b8d64b1624db2e9a7c3291f9d389293
size 634589
tokenizer_assets/symbol2id.pkl
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:ecc7193b1f549e513903ba37410788632252a2dda4d07876a1d91730d8697dbe
size 526232
tokenizer_config.json
ADDED
@@ -0,0 +1,8 @@
{
  "added_tokens_decoder": {},
  "clean_up_tokenization_spaces": true,
  "model_max_length": 1000000000000000019884624838656,
  "pad_token": "[PAD]",
  "tokenizer_class": "PreTrainedTokenizerFast",
  "unk_token": "[UNK]"
}
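This config pairs with `tokenizer.json`, so the fast tokenizer also loads standalone; a sketch (the lookup key is illustrative and depends on the actual vocab):

```python
from transformers import PreTrainedTokenizerFast

tok = PreTrainedTokenizerFast(
    tokenizer_file="tokenizer.json",
    pad_token="[PAD]",
    unk_token="[UNK]",
)
print(tok.convert_tokens_to_ids("ENSG00000000003"))  # token id, or 0 ([UNK]) if absent
```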