"""
Phase 3: Continue Pretraining
Demonstrates how to continue pretraining GeneMamba on your own data using masked LM objective.
Usage:
    python examples/downstream/20_continue_pretraining_reference.py
"""
import torch
import numpy as np
from torch.utils.data import Dataset
from transformers import (
    AutoModelForMaskedLM,
    AutoTokenizer,
    Trainer,
    TrainingArguments,
    DataCollatorForLanguageModeling,
)
class PretrainingDataset(Dataset):
    """
    Dataset for pretraining/continued pretraining.
    Pads or truncates each token-id sequence to max_length.
    """
    def __init__(self, input_ids_list, max_length=2048):
        self.input_ids_list = input_ids_list
        self.max_length = max_length

    def __len__(self):
        return len(self.input_ids_list)

    def __getitem__(self, idx):
        input_ids = self.input_ids_list[idx]
        # Pad or truncate to max_length
        if len(input_ids) >= self.max_length:
            input_ids = input_ids[:self.max_length]
        else:
            input_ids = np.pad(
                input_ids,
                (0, self.max_length - len(input_ids)),
                constant_values=1,  # Pad token ID
            )
        return {
            "input_ids": torch.tensor(input_ids, dtype=torch.long),
        }
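
# Minimal usage sketch: wrap a list of variable-length token-id arrays; each
# item comes back padded/truncated to max_length (pad id 1, per the class above).
#   ds = PretrainingDataset([np.array([2, 7, 9]), np.array([3, 4])], max_length=8)
#   ds[0]["input_ids"]  # -> tensor([2, 7, 9, 1, 1, 1, 1, 1])
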
def create_mock_pretraining_data(n_sequences=5000, seq_len=2048):
    """Create mock single-cell sequences for pretraining."""
    print("Creating mock pretraining dataset...")
    # Create ranked gene sequences.
    # In practice, these would come from your scRNA-seq data
    # (see the rank_cells_to_sequences sketch below).
    sequences = []
    for _ in range(n_sequences):
        # Random ranked sequence
        seq = np.random.randint(2, 25426, seq_len)
        sequences.append(seq)
    print(f"✓ Created {n_sequences} sequences of length {seq_len}")
    return sequences
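
# For real data, ranked sequences are typically derived from a cell-by-gene
# count matrix by sorting each cell's detected genes by expression. A minimal
# sketch under that assumption (`rank_cells_to_sequences` and `gene_token_ids`
# are hypothetical names; map genes to token ids via your tokenizer's vocab):
def rank_cells_to_sequences(counts, gene_token_ids):
    """counts: (n_cells, n_genes) array; gene_token_ids: (n_genes,) token ids."""
    sequences = []
    for cell in counts:
        expressed = np.nonzero(cell)[0]                  # genes detected in this cell
        order = expressed[np.argsort(-cell[expressed])]  # highest expression first
        sequences.append(gene_token_ids[order])
    return sequences
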
def main():
    print("=" * 80)
    print("GeneMamba Phase 3: Continue Pretraining")
    print("=" * 80)

    # ============================================================
    # Step 1: Load pretrained model for masked LM
    # ============================================================
    print("\n[Step 1] Loading model for masked LM...")
    try:
        model = AutoModelForMaskedLM.from_pretrained(
            "GeneMamba-24l-512d",
            trust_remote_code=True,
            local_files_only=True,
        )
        tokenizer = AutoTokenizer.from_pretrained(
            "GeneMamba-24l-512d",
            trust_remote_code=True,
            local_files_only=True,
        )
    except Exception as e:
        print(f"Note: Could not load from hub ({e})")
        print("Using local initialization...")
        # Initialize locally
        from configuration_genemamba import GeneMambaConfig
        from modeling_genemamba import GeneMambaForMaskedLM

        config = GeneMambaConfig(
            vocab_size=25426,
            hidden_size=512,
            num_hidden_layers=24,
        )
        model = GeneMambaForMaskedLM(config)
        tokenizer = None

    print("✓ Model loaded")
    print(f"  - Architecture: {model.config.num_hidden_layers} layers, "
          f"hidden_size={model.config.hidden_size}")
    # ============================================================
    # Step 2: Prepare pretraining data
    # ============================================================
    print("\n[Step 2] Preparing pretraining dataset...")
    sequences = create_mock_pretraining_data(n_sequences=5000, seq_len=2048)

    # Split train/eval
    train_size = int(0.9 * len(sequences))
    train_sequences = sequences[:train_size]
    eval_sequences = sequences[train_size:]

    train_dataset = PretrainingDataset(train_sequences)
    eval_dataset = PretrainingDataset(eval_sequences)

    print("✓ Datasets created:")
    print(f"  - Training: {len(train_dataset)} samples")
    print(f"  - Evaluation: {len(eval_dataset)} samples")
    # ============================================================
    # Step 3: Set up data collator for MLM
    # ============================================================
    print("\n[Step 3] Setting up data collator...")
    if tokenizer is not None:
        data_collator = DataCollatorForLanguageModeling(
            tokenizer=tokenizer,
            mlm=True,
            mlm_probability=0.15,  # Mask 15% of tokens
        )
    else:
        # Custom collator if no tokenizer available
        class CustomDataCollator:
            def __call__(self, batch):
                input_ids = torch.stack([item["input_ids"] for item in batch])
                # Create masked labels (for MLM loss)
                labels = input_ids.clone()
                mask = torch.rand(input_ids.shape) < 0.15
                # Never mask padding (id=1), so no loss is computed on pad positions
                mask &= input_ids != 1
                # Set input to [MASK] token (id=0)
                input_ids[mask] = 0
                # Set labels to -100 where not masked (loss ignores these)
                labels[~mask] = -100
                return {"input_ids": input_ids, "labels": labels}

        data_collator = CustomDataCollator()

    print("✓ Data collator ready (MLM probability: 0.15)")
    # ============================================================
    # Step 4: Set up training arguments
    # ============================================================
    print("\n[Step 4] Setting up training...")
    output_dir = "./pretrain_results"

    training_args = TrainingArguments(
        output_dir=output_dir,
        num_train_epochs=2,
        per_device_train_batch_size=16,
        per_device_eval_batch_size=16,
        learning_rate=2e-5,
        weight_decay=0.01,
        warmup_steps=500,
        logging_steps=100,
        eval_strategy="epoch",
        save_strategy="epoch",
        load_best_model_at_end=True,
        metric_for_best_model="eval_loss",
        report_to="none",  # Disable W&B
        seed=42,
    )
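    # If GPU memory is tight, a smaller per-device batch with gradient
    # accumulation keeps the same effective batch size of 16 (a sketch;
    # tune to your hardware):
    #   per_device_train_batch_size=4, gradient_accumulation_steps=4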
print(f"✓ Training config:")
print(f" - Output dir: {output_dir}")
print(f" - Epochs: {training_args.num_train_epochs}")
print(f" - Batch size: {training_args.per_device_train_batch_size}")
print(f" - Learning rate: {training_args.learning_rate}")
print(f" - MLM masking: 15%")
    # ============================================================
    # Step 5: Train
    # ============================================================
    print("\n[Step 5] Starting continued pretraining...")
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        data_collator=data_collator,
    )

    train_result = trainer.train()

    print("✓ Training complete!")
    print(f"  - Final training loss: {train_result.training_loss:.4f}")
    # ============================================================
    # Step 6: Evaluate
    # ============================================================
    print("\n[Step 6] Evaluating on held-out set...")
    eval_results = trainer.evaluate()

    print("✓ Evaluation Results:")
    for metric, value in eval_results.items():
        if isinstance(value, (int, float)):
            print(f"  - {metric}: {value:.4f}")
    # ============================================================
    # Step 7: Save model
    # ============================================================
    print("\n[Step 7] Saving continued pretrained model...")
    save_dir = "./genemamba_continued_pretrain"
    model.save_pretrained(save_dir)
    if tokenizer is not None:
        tokenizer.save_pretrained(save_dir)
    print(f"✓ Model saved to '{save_dir}'")
    # ============================================================
    # Step 8: Test model inference
    # ============================================================
    print("\n[Step 8] Testing inference on masked input...")
    model.eval()

    # Create sample input with masked tokens
    sample_input = torch.randint(2, 25426, (1, 2048))
    sample_input[0, :10] = 0  # Mask first 10 tokens

    with torch.no_grad():
        outputs = model(sample_input)
        logits = outputs.logits
        predictions = torch.argmax(logits, dim=-1)

    print("✓ Sample predictions generated")
    print(f"  - Input shape: {sample_input.shape}")
    print(f"  - Output logits shape: {logits.shape}")
    print(f"  - Top predicted genes (tokens): {predictions[0, :10].tolist()}")
print("\n" + "=" * 80)
print("Phase 3 Complete! Model ready for downstream tasks or further training.")
print("=" * 80)
return model, trainer
if __name__ == "__main__":
model, trainer = main()