| """ |
| Phase 3: Continue Pretraining |
| Demonstrates how to continue pretraining GeneMamba on your own data using masked LM objective. |
| |
| Usage: |
| python examples/3_continue_pretraining.py |
| """ |
|
|
| import torch |
| import numpy as np |
| from torch.utils.data import Dataset |
| from transformers import ( |
| AutoModelForMaskedLM, |
| AutoTokenizer, |
| Trainer, |
| TrainingArguments, |
| DataCollatorForLanguageModeling, |
| ) |
|
|
|
|
class PretrainingDataset(Dataset):
    """
    Dataset for pretraining/continued pretraining.

    Wraps a list of token-id sequences and serves each one truncated or
    right-padded to a fixed length.
    """

    def __init__(self, input_ids_list, max_length=2048):
        # input_ids_list: sequence of 1-D integer arrays (token ids)
        self.input_ids_list = input_ids_list
        # every returned sample is exactly this long
        self.max_length = max_length

    def __len__(self):
        return len(self.input_ids_list)

    def __getitem__(self, idx):
        seq = self.input_ids_list[idx]
        limit = self.max_length

        if len(seq) < limit:
            # Right-pad short sequences with id 1 — presumably the pad
            # token of the GeneMamba vocab; TODO confirm against tokenizer.
            seq = np.pad(seq, (0, limit - len(seq)), constant_values=1)
        else:
            # Truncate anything at or beyond the limit.
            seq = seq[:limit]

        return {"input_ids": torch.tensor(seq, dtype=torch.long)}
|
|
|
|
def create_mock_pretraining_data(n_sequences=5000, seq_len=2048, seed=None):
    """Create mock single-cell sequences for pretraining.

    Args:
        n_sequences: number of sequences to generate.
        seq_len: length of each sequence.
        seed: optional RNG seed for reproducible data; ``None`` (the
            default) keeps the original non-deterministic behavior.

    Returns:
        list of ``np.ndarray`` of shape ``(seq_len,)`` with token ids
        drawn uniformly from [2, 25426) — ids 0/1 are reserved
        (presumably mask/pad; TODO confirm against the real vocab).
    """
    print("Creating mock pretraining dataset...")

    # default_rng gives an isolated generator so a fixed seed here does
    # not perturb global np.random state used elsewhere.
    rng = np.random.default_rng(seed)
    sequences = [rng.integers(2, 25426, seq_len) for _ in range(n_sequences)]

    print(f"✓ Created {n_sequences} sequences of length {seq_len}")

    return sequences
|
|
|
|
def main():
    """End-to-end continue-pretraining demo.

    Loads (or locally initializes) a GeneMamba masked-LM model, builds a
    mock dataset, runs MLM training via the HF Trainer, evaluates, saves
    the model, and does a quick masked-inference sanity check.

    Returns:
        (model, trainer) — the trained model and the Trainer instance.
    """
    print("=" * 80)
    print("GeneMamba Phase 3: Continue Pretraining")
    print("=" * 80)

    # ---- Step 1: model + tokenizer -----------------------------------
    print("\n[Step 1] Loading model for masked LM...")

    try:
        # Prefer a locally cached checkpoint (local_files_only=True means
        # no network access; trust_remote_code loads the custom model class).
        model = AutoModelForMaskedLM.from_pretrained(
            "GeneMamba-24l-512d",
            trust_remote_code=True,
            local_files_only=True,
        )
        tokenizer = AutoTokenizer.from_pretrained(
            "GeneMamba-24l-512d",
            trust_remote_code=True,
            local_files_only=True,
        )
    except Exception as e:
        # Fallback: build a randomly initialized model from the local
        # source files (demo still runs without the checkpoint).
        print(f"Note: Could not load from hub ({e})")
        print("Using local initialization...")

        from configuration_genemamba import GeneMambaConfig
        from modeling_genemamba import GeneMambaForMaskedLM

        config = GeneMambaConfig(
            vocab_size=25426,
            hidden_size=512,
            num_hidden_layers=24,
        )
        model = GeneMambaForMaskedLM(config)
        # No tokenizer on this path; the custom collator below handles masking.
        tokenizer = None

    print(f"✓ Model loaded")
    print(f" - Architecture: {model.config.num_hidden_layers} layers, "
          f"hidden_size={model.config.hidden_size}")

    # ---- Step 2: data ------------------------------------------------
    print("\n[Step 2] Preparing pretraining dataset...")

    sequences = create_mock_pretraining_data(n_sequences=5000, seq_len=2048)

    # 90/10 train/eval split; no shuffle (mock sequences are already random).
    train_size = int(0.9 * len(sequences))
    train_sequences = sequences[:train_size]
    eval_sequences = sequences[train_size:]

    train_dataset = PretrainingDataset(train_sequences)
    eval_dataset = PretrainingDataset(eval_sequences)

    print(f"✓ Datasets created:")
    print(f" - Training: {len(train_dataset)} samples")
    print(f" - Evaluation: {len(eval_dataset)} samples")

    # ---- Step 3: MLM collator ----------------------------------------
    print("\n[Step 3] Setting up data collator...")

    if tokenizer is not None:
        # Standard HF collator: dynamic masking, 15% of tokens per batch.
        data_collator = DataCollatorForLanguageModeling(
            tokenizer=tokenizer,
            mlm=True,
            mlm_probability=0.15,
        )
    else:
        # Minimal MLM collator for the tokenizer-less fallback path.
        # Unlike BERT's 80/10/10 scheme, every selected position is
        # replaced by the mask id.
        class CustomDataCollator:
            def __call__(self, batch):
                input_ids = torch.stack([item["input_ids"] for item in batch])

                labels = input_ids.clone()
                # Bernoulli(0.15) selection per position.
                # NOTE(review): padding positions can also be selected
                # for masking — confirm this is acceptable.
                mask = torch.rand(input_ids.shape) < 0.15

                # Token id 0 is used as the mask token here — TODO
                # confirm against the actual vocab/special tokens.
                input_ids[mask] = 0

                # Only masked positions contribute to the loss
                # (-100 is ignored by the CE loss).
                labels[~mask] = -100

                return {"input_ids": input_ids, "labels": labels}

        data_collator = CustomDataCollator()

    print(f"✓ Data collator ready (MLM probability: 0.15)")

    # ---- Step 4: training configuration ------------------------------
    print("\n[Step 4] Setting up training...")

    output_dir = "./pretrain_results"

    training_args = TrainingArguments(
        output_dir=output_dir,
        num_train_epochs=2,
        per_device_train_batch_size=16,
        per_device_eval_batch_size=16,
        learning_rate=2e-5,
        weight_decay=0.01,
        warmup_steps=500,
        logging_steps=100,
        # Evaluate and checkpoint once per epoch; keep the best-eval-loss
        # checkpoint loaded at the end.
        eval_strategy="epoch",
        save_strategy="epoch",
        load_best_model_at_end=True,
        metric_for_best_model="eval_loss",
        report_to="none",
        seed=42,
    )

    print(f"✓ Training config:")
    print(f" - Output dir: {output_dir}")
    print(f" - Epochs: {training_args.num_train_epochs}")
    print(f" - Batch size: {training_args.per_device_train_batch_size}")
    print(f" - Learning rate: {training_args.learning_rate}")
    print(f" - MLM masking: 15%")

    # ---- Step 5: train -----------------------------------------------
    print("\n[Step 5] Starting continued pretraining...")

    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        data_collator=data_collator,
    )

    train_result = trainer.train()

    print(f"✓ Training complete!")
    print(f" - Final training loss: {train_result.training_loss:.4f}")

    # ---- Step 6: evaluation ------------------------------------------
    print("\n[Step 6] Evaluating on held-out set...")

    eval_results = trainer.evaluate()

    print(f"✓ Evaluation Results:")
    for metric, value in eval_results.items():
        # Skip non-numeric entries (e.g. epoch metadata strings).
        if isinstance(value, (int, float)):
            print(f" - {metric}: {value:.4f}")

    # ---- Step 7: save ------------------------------------------------
    print("\n[Step 7] Saving continued pretrained model...")

    save_dir = "./genemamba_continued_pretrain"
    model.save_pretrained(save_dir)
    if tokenizer is not None:
        tokenizer.save_pretrained(save_dir)

    print(f"✓ Model saved to '{save_dir}'")

    # ---- Step 8: masked-inference sanity check -----------------------
    print("\n[Step 8] Testing inference on masked input...")

    model.eval()

    # Random batch of ids in the normal-token range, with the first 10
    # positions set to 0 (the assumed mask id — see collator note above).
    sample_input = torch.randint(2, 25426, (1, 2048))
    sample_input[0, :10] = 0

    with torch.no_grad():
        outputs = model(sample_input)
        logits = outputs.logits
        # Greedy per-position prediction over the vocab.
        predictions = torch.argmax(logits, dim=-1)

    print(f"✓ Sample predictions generated")
    print(f" - Input shape: {sample_input.shape}")
    print(f" - Output logits shape: {logits.shape}")
    print(f" - Top predicted genes (tokens): {predictions[0, :10].tolist()}")

    print("\n" + "=" * 80)
    print("Phase 3 Complete! Model ready for downstream tasks or further training.")
    print("=" * 80)

    return model, trainer
|
|
|
|
if __name__ == "__main__":
    # Keep the returned handles so an interactive session (python -i) can
    # inspect the trained model and trainer.
    model, trainer = main()
|
|