mineself2016 committed on
Commit
2a939f5
·
verified ·
1 Parent(s): b8d89c6

Use next-token training and checkpoint resume in train-from-scratch

Browse files
Files changed (1) hide show
  1. README.md +52 -98
README.md CHANGED
@@ -20,8 +20,7 @@ A Hugging Face compatible implementation of GeneMamba, a foundational state-spac
20
  - [Quick Start](#quick-start)
21
  - [Phase 1: Extract Cell Embeddings](#phase-1-extract-cell-embeddings)
22
  - [Phase 2: Downstream Tasks](#phase-2-downstream-tasks)
23
- - [Phase 3: Continue Pretraining](#phase-3-continue-pretraining)
24
- - [Phase 4: Train from Scratch](#phase-4-train-from-scratch)
25
  - [Model Variants](#model-variants)
26
  - [Architecture](#architecture)
27
  - [Datasets](#datasets)
@@ -37,14 +36,14 @@ GeneMamba is a **state-space model (SSM)** based on **Mamba architecture** optim
37
 
38
  - **Takes ranked gene sequences** as input (genes sorted by expression level)
39
  - **Outputs cell embeddings** suitable for clustering, classification, and batch integration
40
- - **Supports multiple downstream tasks** including cell type annotation and masked LM pretraining
41
  - **Is compatible with Hugging Face Transformers** for easy integration into existing pipelines
42
 
43
  ### Key Features
44
 
45
  ✅ **Efficient Sequence Processing**: SSM-based architecture with linear complexity
46
  ✅ **Cell Representation Learning**: Direct cell embedding without intermediate steps
47
- ✅ **Multi-task Support**: Classification, masked LM, and embeddings in one model
48
  ✅ **Hugging Face Integration**: Standard `from_pretrained()` and `save_pretrained()` interface
49
  ✅ **Production Ready**: Pretrained checkpoints available on Hugging Face Hub
50
 
@@ -204,117 +203,74 @@ The model also supports:
204
 
205
  ---
206
 
207
- ### Phase 3: Continue Pretraining
208
 
209
- Fine-tune the model on your own single-cell data using **masked LM objective**:
 
210
 
211
  ```python
212
  import torch
 
213
  from transformers import AutoTokenizer, AutoModelForMaskedLM, Trainer, TrainingArguments
214
- from torch.utils.data import Dataset
215
 
216
- # Load model for masked LM
217
  tokenizer = AutoTokenizer.from_pretrained(
218
  "mineself2016/GeneMamba",
219
- trust_remote_code=True
220
- )
221
- model = AutoModelForMaskedLM.from_pretrained(
222
- "mineself2016/GeneMamba",
223
- trust_remote_code=True
224
  )
225
 
226
- print("vocab_size:", tokenizer.vocab_size) # 25426
227
- print("unk_token/id:", tokenizer.unk_token, tokenizer.unk_token_id) # [UNK], 0
228
- print("pad_token/id:", tokenizer.pad_token, tokenizer.pad_token_id) # [PAD], 1
229
- print("cls_token/id:", tokenizer.cls_token, tokenizer.cls_token_id) # None, None
230
- print("mask_token/id:", tokenizer.mask_token, tokenizer.mask_token_id) # None, None
231
 
232
- # Important:
233
- # GeneMamba tokenizer defines only [UNK]=0 and [PAD]=1 as special tokens.
234
- # There is no built-in [CLS]/[SEP]/[MASK].
235
 
236
- # Your pretraining dataset (with input_ids only, no labels needed)
237
- class PretrainDataset(Dataset):
238
- def __init__(self, input_ids_list):
239
- self.input_ids_list = input_ids_list
240
-
241
- def __len__(self):
242
- return len(self.input_ids_list)
243
-
244
- def __getitem__(self, idx):
245
- return {"input_ids": self.input_ids_list[idx]}
246
-
247
- # Custom MLM collator (replace masked positions with [UNK], id=0)
248
- class GeneMambaMLMCollator:
249
- def __init__(self, pad_token_id=1, unk_token_id=0, mlm_probability=0.15):
250
- self.pad_token_id = pad_token_id
251
- self.unk_token_id = unk_token_id
252
- self.mlm_probability = mlm_probability
253
-
254
- def __call__(self, features):
255
- input_ids = torch.stack([f["input_ids"] for f in features])
256
- labels = input_ids.clone()
257
-
258
- prob = torch.full(labels.shape, self.mlm_probability)
259
- mask_positions = torch.bernoulli(prob).bool()
260
- mask_positions &= input_ids.ne(self.pad_token_id)
261
-
262
- labels[~mask_positions] = -100
263
- input_ids[mask_positions] = self.unk_token_id
264
- return {"input_ids": input_ids, "labels": labels}
265
-
266
- data_collator = GeneMambaMLMCollator(
267
- pad_token_id=tokenizer.pad_token_id,
268
- unk_token_id=tokenizer.unk_token_id,
269
- mlm_probability=0.15,
270
  )
271
 
272
- # Train
273
- trainer = Trainer(
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
274
  model=model,
275
  args=TrainingArguments(
276
- output_dir="./pretrain_results",
277
  num_train_epochs=3,
278
  per_device_train_batch_size=32,
279
  learning_rate=2e-5,
280
  ),
281
  train_dataset=train_dataset,
282
- data_collator=data_collator,
283
  )
284
 
285
- trainer.train()
286
- ```
287
-
288
- ---
289
-
290
- ### Phase 4: Train from Scratch
291
-
292
- Initialize and train a new GeneMamba model from scratch:
293
-
294
- ```python
295
- import torch
296
- from transformers import AutoConfig, PreTrainedModel
297
- from transformers.utils.hub import register_and_push_to_hub_with_git_history
298
-
299
- # Create config
300
- config = AutoConfig.from_pretrained(
301
- "mineself2016/GeneMamba",
302
- trust_remote_code=True
303
- )
304
-
305
- # Modify hyperparameters if needed
306
- config.hidden_size = 512
307
- config.num_hidden_layers = 24
308
- config.vocab_size = 25426
309
-
310
- # Import and instantiate model
311
- from modeling_genemamba import GeneMambaForMaskedLM
312
-
313
- model = GeneMambaForMaskedLM(config)
314
-
315
- print(f"Total parameters: {model.num_parameters() / 1e9:.2f}B")
316
-
317
- # Now proceed with training as in Phase 3
318
  ```
319
 
320
  ---
@@ -380,7 +336,7 @@ model = AutoModelForSequenceClassification.from_pretrained(
380
  "mineself2016/GeneMamba", num_labels=10, trust_remote_code=True
381
  )
382
 
383
- # Masked LM
384
  from transformers import AutoModelForMaskedLM
385
  model = AutoModelForMaskedLM.from_pretrained("mineself2016/GeneMamba", trust_remote_code=True)
386
  ```
@@ -476,11 +432,9 @@ input_ids = tokenizer(gene_ids, return_tensors="pt", padding=True)["input_ids"]
476
 
477
  See the `examples/` directory for complete scripts:
478
 
479
- - `00_preprocess_to_input_ids.py` - h5ad to ranked gene token IDs
480
- - `01_extract_embeddings.py` - Extract cell embeddings
481
- - `10_finetune_classification.py` - Cell type annotation
482
- - `20_continue_pretraining_reference.py` - Domain adaptation
483
- - `21_pretrain_from_scratch_reference.py` - Training from scratch
484
 
485
  Run any example:
486
 
 
20
  - [Quick Start](#quick-start)
21
  - [Phase 1: Extract Cell Embeddings](#phase-1-extract-cell-embeddings)
22
  - [Phase 2: Downstream Tasks](#phase-2-downstream-tasks)
23
+ - [Phase 3: Train from Scratch](#phase-3-train-from-scratch)
 
24
  - [Model Variants](#model-variants)
25
  - [Architecture](#architecture)
26
  - [Datasets](#datasets)
 
36
 
37
  - **Takes ranked gene sequences** as input (genes sorted by expression level)
38
  - **Outputs cell embeddings** suitable for clustering, classification, and batch integration
39
+ - **Supports multiple downstream tasks** including cell type annotation and next-token pretraining
40
  - **Is compatible with Hugging Face Transformers** for easy integration into existing pipelines
41
 
42
  ### Key Features
43
 
44
  ✅ **Efficient Sequence Processing**: SSM-based architecture with linear complexity
45
  ✅ **Cell Representation Learning**: Direct cell embedding without intermediate steps
46
+ ✅ **Multi-task Support**: Classification, next-token pretraining, and embeddings in one model
47
  ✅ **Hugging Face Integration**: Standard `from_pretrained()` and `save_pretrained()` interface
48
  ✅ **Production Ready**: Pretrained checkpoints available on Hugging Face Hub
49
 
 
203
 
204
  ---
205
 
206
+ ### Phase 3: Train from Scratch
207
 
208
+ Train a new GeneMamba model with **next-token prediction**.
209
+ If a checkpoint exists, resume automatically; otherwise start from scratch.
210
 
211
  ```python
212
  import torch
213
+ from pathlib import Path
214
  from transformers import AutoTokenizer, AutoModelForMaskedLM, Trainer, TrainingArguments
 
215
 
 
216
  tokenizer = AutoTokenizer.from_pretrained(
217
  "mineself2016/GeneMamba",
218
+ trust_remote_code=True,
 
 
 
 
219
  )
220
 
221
+ print("vocab_size:", tokenizer.vocab_size) # 25426
222
+ print("unk/pad:", tokenizer.unk_token_id, tokenizer.pad_token_id) # 0, 1
223
+ print("cls/mask:", tokenizer.cls_token_id, tokenizer.mask_token_id) # None, None
 
 
224
 
225
+ # Build model config
226
+ from configuration_genemamba import GeneMambaConfig
227
+ from modeling_genemamba import GeneMambaForMaskedLM
228
 
229
+ config = GeneMambaConfig(
230
+ vocab_size=25426,
231
+ hidden_size=512,
232
+ num_hidden_layers=24,
233
+ max_position_embeddings=2048,
234
+ mamba_mode="mean",
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
235
  )
236
 
237
+ # Resume if a checkpoint exists. Note: Trainer saves checkpoints as
+ # checkpoint-<step>; point this at the one you want, or pass
+ # resume_from_checkpoint=True to trainer.train() to pick up the latest.
238
+ checkpoint_dir = Path("./from_scratch_pretrain/checkpoint-last")
239
+ if checkpoint_dir.exists():
240
+ model = AutoModelForMaskedLM.from_pretrained(
241
+ str(checkpoint_dir),
242
+ trust_remote_code=True,
243
+ local_files_only=True,
244
+ )
245
+ resume_from_checkpoint = str(checkpoint_dir)
246
+ else:
247
+ model = GeneMambaForMaskedLM(config)
248
+ resume_from_checkpoint = None
249
+
250
+ class NextTokenTrainer(Trainer):
251
+ def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
252
+ input_ids = inputs["input_ids"]
253
+ logits = model(input_ids=input_ids).logits
254
+ shift_logits = logits[:, :-1, :].contiguous()
255
+ shift_labels = input_ids[:, 1:].contiguous().to(shift_logits.device)
256
+ loss = torch.nn.functional.cross_entropy(
257
+ shift_logits.view(-1, shift_logits.size(-1)),
258
+ shift_labels.view(-1),
+ ignore_index=tokenizer.pad_token_id,  # do not train on padded positions
259
+ )
260
+ return loss
261
+
262
+ trainer = NextTokenTrainer(
263
  model=model,
264
  args=TrainingArguments(
265
+ output_dir="./from_scratch_pretrain",
266
  num_train_epochs=3,
267
  per_device_train_batch_size=32,
268
  learning_rate=2e-5,
269
  ),
270
  train_dataset=train_dataset,
 
271
  )
272
 
273
+ trainer.train(resume_from_checkpoint=resume_from_checkpoint)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
274
  ```
275
 
276
  ---
 
336
  "mineself2016/GeneMamba", num_labels=10, trust_remote_code=True
337
  )
338
 
339
+ # Language modeling head (used with next-token objective)
340
  from transformers import AutoModelForMaskedLM
341
  model = AutoModelForMaskedLM.from_pretrained("mineself2016/GeneMamba", trust_remote_code=True)
342
  ```
 
432
 
433
  See the `examples/` directory for complete scripts:
434
 
435
+ - `1_extract_embeddings.py` - Extract cell embeddings
436
+ - `2_finetune_classification.py` - Cell type annotation
437
+ - `4_pretrain_from_scratch.py` - Train from scratch (next-token + optional resume)
 
 
438
 
439
  Run any example:
440