"""
Phase 3: Train from Scratch (Next-Token Objective)
Demonstrates how to initialize and train GeneMamba with next-token prediction.
If a checkpoint exists, training resumes from checkpoint automatically.
Usage:
python examples/3_pretrain_from_scratch.py
"""
import torch
import numpy as np
from torch.utils.data import Dataset
from pathlib import Path
from transformers import (
    AutoTokenizer,
    AutoConfig,
    AutoModelForMaskedLM,
    Trainer,
    TrainingArguments,
)
from transformers.trainer_utils import get_last_checkpoint

class PretrainingDataset(Dataset):
    """Dataset for pretraining."""

    def __init__(self, input_ids_list, max_length=2048):
        self.input_ids_list = input_ids_list
        self.max_length = max_length

    def __len__(self):
        return len(self.input_ids_list)

    def __getitem__(self, idx):
        input_ids = self.input_ids_list[idx]
        # Pad or truncate
        if len(input_ids) >= self.max_length:
            input_ids = input_ids[:self.max_length]
        else:
            input_ids = np.pad(
                input_ids,
                (0, self.max_length - len(input_ids)),
                constant_values=1,
            )
        return {
            "input_ids": torch.tensor(input_ids, dtype=torch.long),
        }

class NextTokenTrainer(Trainer):
    """Use next-token prediction loss: logits[:, :-1] vs labels[:, 1:]."""

    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
        # num_items_in_batch is accepted (and ignored) so this override also works
        # with newer transformers releases that pass it to compute_loss.
        input_ids = inputs["input_ids"]
        outputs = model(input_ids=input_ids)
        logits = outputs.logits
        # Shift so that the logits at position t predict the token at position t+1
        shift_logits = logits[:, :-1, :].contiguous()
        shift_labels = input_ids[:, 1:].contiguous().to(shift_logits.device)
        loss_fct = torch.nn.CrossEntropyLoss()
        loss = loss_fct(
            shift_logits.view(-1, shift_logits.size(-1)),
            shift_labels.view(-1),
        )
        return (loss, outputs) if return_outputs else loss
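
# Note: with the mock data used in this demo every sequence already fills
# max_length, so no padding positions reach the loss. For real, variable-length
# profiles you would probably want to skip pad positions, e.g. (illustrative,
# assuming the padding id is 1 as in PretrainingDataset above):
#     loss_fct = torch.nn.CrossEntropyLoss(ignore_index=1)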

class NextTokenDataCollator:
    """Simple collator for pre-tokenized input_ids (no MLM masking)."""

    def __call__(self, batch):
        input_ids = torch.stack([item["input_ids"] for item in batch])
        return {"input_ids": input_ids}
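
# The collator returns only input_ids; no labels field is needed because
# NextTokenTrainer derives the targets by shifting input_ids one position.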

def create_mock_pretraining_data(n_sequences=5000, seq_len=2048):
    """Create mock pretraining data."""
    print("Creating mock pretraining dataset for from-scratch training...")
    sequences = []
    for _ in range(n_sequences):
        seq = np.random.randint(2, 25426, seq_len)
        sequences.append(seq)
    print(f"✓ Created {n_sequences} sequences")
    return sequences
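
# The mock ids are drawn from 2..25425, leaving ids 0 and 1 free for special
# tokens (the dataset pads with id 1). For real pretraining, replace these
# random sequences with expression profiles tokenized by the GeneMamba tokenizer.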

def main():
    print("=" * 80)
    print("GeneMamba Phase 3: Train from Scratch (Next-Token)")
    print("=" * 80)

    model_id = "mineself2016/GeneMamba"
    output_dir = "./from_scratch_pretrain"
    checkpoint_dir = Path(output_dir) / "checkpoint-last"

    # ============================================================
    # Step 1: Load tokenizer
    # ============================================================
    print("\n[Step 1] Loading tokenizer...")
    tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
    print("✓ Tokenizer loaded:")
    print(f" - vocab_size: {tokenizer.vocab_size}")
    print(f" - [UNK] token/id: {tokenizer.unk_token}/{tokenizer.unk_token_id}")
    print(f" - [PAD] token/id: {tokenizer.pad_token}/{tokenizer.pad_token_id}")
    print(f" - [CLS] token/id: {tokenizer.cls_token}/{tokenizer.cls_token_id}")
    print(f" - [MASK] token/id: {tokenizer.mask_token}/{tokenizer.mask_token_id}")
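
    # trust_remote_code=True pulls the custom GeneMamba tokenizer/model classes
    # from the model repository rather than from the transformers library itself.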

    # ============================================================
    # Step 2: Build config and initialize/resume model
    # ============================================================
    print("\n[Step 2] Building model (resume if checkpoint exists)...")
    model_config = AutoConfig.from_pretrained(model_id, trust_remote_code=True)
    model_config.vocab_size = 25426
    model_config.hidden_size = 256
    model_config.num_hidden_layers = 12
    model_config.intermediate_size = 1024
    model_config.max_position_embeddings = 2048
    model_config.mamba_mode = "mean"
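    # These overrides define a small, demo-sized model; the released GeneMamba
    # checkpoint may use different dimensions. vocab_size must match the
    # vocabulary used to produce the training token ids (25426 here).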

    resume_from_checkpoint = None
    if checkpoint_dir.exists():
        resume_from_checkpoint = str(checkpoint_dir)
    elif Path(output_dir).is_dir():
        # get_last_checkpoint needs an existing directory; on the first run the
        # output directory has not been created yet, so keep None in that case.
        resume_from_checkpoint = get_last_checkpoint(output_dir)

    if resume_from_checkpoint is not None:
        model = AutoModelForMaskedLM.from_pretrained(
            resume_from_checkpoint,
            trust_remote_code=True,
            local_files_only=True,
        )
        print(f"✓ Found checkpoint, resume from: {resume_from_checkpoint}")
    else:
        model = AutoModelForMaskedLM.from_config(model_config, trust_remote_code=True)
        print("✓ No checkpoint found, start from scratch")

    # Count parameters
    total_params = sum(p.numel() for p in model.parameters())
    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f"✓ Model initialized:")
    print(f" - Total parameters: {total_params / 1e6:.2f}M")
    print(f" - Trainable parameters: {trainable_params / 1e6:.2f}M")

    # ============================================================
    # Step 3: Prepare data
    # ============================================================
    print("\n[Step 3] Preparing training data...")
    sequences = create_mock_pretraining_data(n_sequences=5000, seq_len=2048)

    # Split
    train_size = int(0.8 * len(sequences))
    train_sequences = sequences[:train_size]
    eval_sequences = sequences[train_size:]

    train_dataset = PretrainingDataset(train_sequences)
    eval_dataset = PretrainingDataset(eval_sequences)
    print(f"✓ Datasets created:")
    print(f" - Train: {len(train_dataset)}")
    print(f" - Eval: {len(eval_dataset)}")
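
    # To pretrain on real data, swap the mock sequences for pre-tokenized
    # profiles, e.g. (illustrative; assumes an array of id sequences saved earlier):
    #     sequences = np.load("tokenized_cells.npy", allow_pickle=True)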

    # ============================================================
    # Step 4: Data collator for next-token training
    # ============================================================
    print("\n[Step 4] Setting up data collator...")
    data_collator = NextTokenDataCollator()
    print(f"✓ Data collator ready")

    # ============================================================
    # Step 5: Training arguments
    # ============================================================
    print("\n[Step 5] Setting up training...")
    training_args = TrainingArguments(
        output_dir=output_dir,
        num_train_epochs=5,
        per_device_train_batch_size=16,
        per_device_eval_batch_size=16,
        learning_rate=5e-4,
        weight_decay=0.01,
        warmup_steps=500,
        logging_steps=50,
        eval_strategy="epoch",
        save_strategy="epoch",
        load_best_model_at_end=True,
        metric_for_best_model="eval_loss",
        report_to="none",
        seed=42,
        optim="adamw_torch",
        gradient_accumulation_steps=1,
        max_grad_norm=1.0,
    )
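    # The values above are sized for a quick demo run. For real pretraining you
    # would likely enable mixed precision (e.g. bf16=True on supported GPUs) and
    # raise gradient_accumulation_steps to reach a larger effective batch size.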

    print(f"✓ Training config:")
    print(f" - Output: {output_dir}")
    print(f" - Epochs: {training_args.num_train_epochs}")
    print(f" - Batch size: {training_args.per_device_train_batch_size}")
    print(f" - Learning rate: {training_args.learning_rate}")

    # ============================================================
    # Step 6: Train
    # ============================================================
    print("\n[Step 6] Starting training...")
    print("(This may take a while. In practice, use more GPUs/data for real pretraining)")
    trainer = NextTokenTrainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        data_collator=data_collator,
    )
    train_result = trainer.train(resume_from_checkpoint=resume_from_checkpoint)
    print(f"✓ Training complete!")
    print(f" - Final training loss: {train_result.training_loss:.4f}")
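    # Passing resume_from_checkpoint to trainer.train() restores optimizer and
    # scheduler state along with the weights. On this random mock data, do not
    # expect the loss to fall much below ln(25426) ≈ 10.1 (the uniform baseline).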

    # ============================================================
    # Step 7: Evaluate
    # ============================================================
    print("\n[Step 7] Evaluating...")
    eval_results = trainer.evaluate()
    print(f"✓ Evaluation Results:")
    for metric, value in eval_results.items():
        if isinstance(value, (int, float)):
            print(f" - {metric}: {value:.4f}")

    # ============================================================
    # Step 8: Save model and config
    # ============================================================
    print("\n[Step 8] Saving model...")
    save_dir = "./my_genemamba_from_scratch"
    model.save_pretrained(save_dir)
    model_config.save_pretrained(save_dir)
    print(f"✓ Model and config saved to '{save_dir}'")
    print(f" Files created:")
    print(f" - config.json")
    print(f" - model.safetensors (or pytorch_model.bin)")
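    # Optionally save the tokenizer alongside the weights so the directory is
    # fully self-contained:
    #     tokenizer.save_pretrained(save_dir)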

    # ============================================================
    # Step 9: Reload and verify
    # ============================================================
    print("\n[Step 9] Reloading model from checkpoint...")
    loaded_model = AutoModelForMaskedLM.from_pretrained(
        save_dir,
        trust_remote_code=True,
    )
    loaded_model.eval()

    # Test inference
    with torch.no_grad():
        sample_input = torch.randint(2, 25426, (2, 2048))
        outputs = loaded_model(sample_input)
        logits = outputs.logits
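    # With this config the logits should come out as
    # (batch, seq_len, vocab_size) = (2, 2048, 25426).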
    print(f"✓ Model reloaded and tested!")
    print(f" - Input shape: {sample_input.shape}")
    print(f" - Logits shape: {logits.shape}")

    # ============================================================
    # Step 10: Optional - Convert to different format
    # ============================================================
    print("\n[Step 10] Model ready for conversion/deployment!")
    print(f"✓ You can now:")
    print(f" 1. Push to Hugging Face Hub:")
    print(f" model.push_to_hub('your-username/GeneMamba-custom')")
    print(f" 2. Use with downstream tasks:")
    print(f" AutoModelForSequenceClassification.from_pretrained('{save_dir}', num_labels=N)")
    print(f" 3. Extract embeddings:")
    print(f" AutoModel.from_pretrained('{save_dir}')")

    print("\n" + "=" * 80)
    print("Phase 3 Complete! Model trained from scratch and ready to use.")
    print("=" * 80)

    return model, trainer, model_config

if __name__ == "__main__":
    model, trainer, model_config = main()