| """ |
| Phase 3: Train from Scratch (Next-Token Objective) |
| Demonstrates how to initialize and train GeneMamba with next-token prediction. |
| If a checkpoint exists, training resumes from checkpoint automatically. |
| |
| Usage: |
| python examples/3_pretrain_from_scratch.py |
| """ |
|
|
| import torch |
| import numpy as np |
| from torch.utils.data import Dataset |
| from pathlib import Path |
| from transformers import ( |
| AutoTokenizer, |
| AutoConfig, |
| AutoModelForMaskedLM, |
| Trainer, |
| TrainingArguments, |
| ) |
| from transformers.trainer_utils import get_last_checkpoint |
|
|
|
|
class PretrainingDataset(Dataset):
    """Dataset of pre-tokenized gene sequences for pretraining.

    Each item is truncated or right-padded to a fixed ``max_length`` and
    returned as ``{"input_ids": LongTensor}``, ready for the next-token
    collator/trainer in this script.
    """

    def __init__(self, input_ids_list, max_length=2048, pad_token_id=1):
        """
        Args:
            input_ids_list: sequence of 1-D arrays/lists of token ids.
            max_length: fixed length every example is truncated/padded to.
            pad_token_id: id used for right padding. Defaults to 1, which
                matches the [PAD] id assumed elsewhere in this example —
                pass the tokenizer's real pad id for other vocabularies.
        """
        self.input_ids_list = input_ids_list
        self.max_length = max_length
        self.pad_token_id = pad_token_id

    def __len__(self):
        return len(self.input_ids_list)

    def __getitem__(self, idx):
        input_ids = self.input_ids_list[idx]

        if len(input_ids) >= self.max_length:
            # Truncate long sequences to the fixed length.
            input_ids = input_ids[:self.max_length]
        else:
            # Right-pad short sequences with the pad token id.
            input_ids = np.pad(
                input_ids,
                (0, self.max_length - len(input_ids)),
                constant_values=self.pad_token_id,
            )

        return {
            "input_ids": torch.tensor(input_ids, dtype=torch.long),
        }
|
|
|
|
class NextTokenTrainer(Trainer):
    """Trainer with a causal (next-token) objective.

    Aligns ``logits[:, :-1]`` against labels ``input_ids[:, 1:]`` and
    applies token-level cross-entropy, so the model is trained
    autoregressively even though it is loaded via an MLM head class.
    """

    def compute_loss(self, model, inputs, return_outputs=False,
                     num_items_in_batch=None):
        """Compute the shifted cross-entropy loss for one batch.

        Args:
            model: model under training; its output must expose ``.logits``.
            inputs: batch dict containing ``input_ids``.
            return_outputs: if True, also return the raw model outputs.
            num_items_in_batch: accepted for compatibility with
                ``transformers`` >= 4.46, whose ``Trainer`` passes this
                extra argument (omitting it raises a TypeError at train
                time). Unused here because CrossEntropyLoss averages.

        Returns:
            The loss tensor, or ``(loss, outputs)`` if ``return_outputs``.
        """
        input_ids = inputs["input_ids"]
        outputs = model(input_ids=input_ids)
        logits = outputs.logits

        # Shift so that position t predicts token t+1.
        shift_logits = logits[:, :-1, :].contiguous()
        shift_labels = input_ids[:, 1:].contiguous().to(shift_logits.device)

        # NOTE(review): pad tokens (id 1 in this script) are NOT masked out
        # of the loss; for real padded data consider
        # CrossEntropyLoss(ignore_index=pad_token_id).
        loss_fct = torch.nn.CrossEntropyLoss()
        loss = loss_fct(
            shift_logits.view(-1, shift_logits.size(-1)),
            shift_labels.view(-1),
        )

        return (loss, outputs) if return_outputs else loss
|
|
|
|
class NextTokenDataCollator:
    """Batch pre-tokenized examples by stacking their ``input_ids``.

    Applies no MLM masking — labels are derived from ``input_ids`` inside
    the trainer's shifted-loss computation.
    """

    def __call__(self, batch):
        stacked = torch.stack([example["input_ids"] for example in batch])
        return {"input_ids": stacked}
|
|
|
|
def create_mock_pretraining_data(n_sequences=5000, seq_len=2048,
                                 low_id=2, high_id=25426, seed=None):
    """Create mock pretraining data.

    Args:
        n_sequences: number of random sequences to generate.
        seq_len: length of each sequence.
        low_id: lowest token id drawn (inclusive); default 2 skips the
            special-token ids used elsewhere in this script.
        high_id: upper bound for token ids (exclusive); default matches
            the vocab size configured in ``main``.
        seed: optional seed for reproducible data; ``None`` keeps the
            original unseeded-random behavior.

    Returns:
        list of 1-D ``np.ndarray`` token-id sequences.
    """
    print("Creating mock pretraining dataset for from-scratch training...")

    # A local Generator avoids mutating global np.random state and makes
    # the data reproducible when a seed is given.
    rng = np.random.default_rng(seed)
    sequences = [rng.integers(low_id, high_id, seq_len) for _ in range(n_sequences)]

    print(f"✓ Created {n_sequences} sequences")

    return sequences
|
|
|
|
def main():
    """Run the full from-scratch pretraining demo: load tokenizer, build or
    resume a model, train with a next-token objective, evaluate, save, and
    smoke-test reloading. Returns (model, trainer, model_config)."""
    print("=" * 80)
    print("GeneMamba Phase 3: Train from Scratch (Next-Token)")
    print("=" * 80)

    # Hub id is used for the tokenizer and as a config template; weights are
    # freshly initialized below unless a local checkpoint is found.
    model_id = "mineself2016/GeneMamba"
    output_dir = "./from_scratch_pretrain"
    checkpoint_dir = Path(output_dir) / "checkpoint-last"

    # ---- Step 1: tokenizer (downloaded from the Hub; needs network) ----
    print("\n[Step 1] Loading tokenizer...")

    tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)

    print("✓ Tokenizer loaded:")
    print(f"  - vocab_size: {tokenizer.vocab_size}")
    print(f"  - [UNK] token/id: {tokenizer.unk_token}/{tokenizer.unk_token_id}")
    print(f"  - [PAD] token/id: {tokenizer.pad_token}/{tokenizer.pad_token_id}")
    print(f"  - [CLS] token/id: {tokenizer.cls_token}/{tokenizer.cls_token_id}")
    print(f"  - [MASK] token/id: {tokenizer.mask_token}/{tokenizer.mask_token_id}")

    # ---- Step 2: model config + checkpoint resume logic ----
    print("\n[Step 2] Building model (resume if checkpoint exists)...")

    # Start from the published config but shrink it for a quick demo run.
    model_config = AutoConfig.from_pretrained(model_id, trust_remote_code=True)
    model_config.vocab_size = 25426
    model_config.hidden_size = 256
    model_config.num_hidden_layers = 12
    model_config.intermediate_size = 1024
    model_config.max_position_embeddings = 2048
    # NOTE(review): "mamba_mode" semantics come from the remote model code —
    # presumably a pooling strategy; confirm against the Hub repo.
    model_config.mamba_mode = "mean"

    # Prefer an explicit "checkpoint-last" dir; otherwise fall back to the
    # newest "checkpoint-*" dir that Trainer wrote under output_dir.
    resume_from_checkpoint = None
    if checkpoint_dir.exists():
        resume_from_checkpoint = str(checkpoint_dir)
    else:
        resume_from_checkpoint = get_last_checkpoint(output_dir)

    if resume_from_checkpoint is not None:
        # Resume: load weights from the local checkpoint only.
        model = AutoModelForMaskedLM.from_pretrained(
            resume_from_checkpoint,
            trust_remote_code=True,
            local_files_only=True,
        )
        print(f"✓ Found checkpoint, resume from: {resume_from_checkpoint}")
    else:
        # Fresh start: random initialization from the config (no pretrained weights).
        model = AutoModelForMaskedLM.from_config(model_config, trust_remote_code=True)
        print("✓ No checkpoint found, start from scratch")

    total_params = sum(p.numel() for p in model.parameters())
    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)

    print(f"✓ Model initialized:")
    print(f"  - Total parameters: {total_params / 1e6:.2f}M")
    print(f"  - Trainable parameters: {trainable_params / 1e6:.2f}M")

    # ---- Step 3: mock data and 80/20 train/eval split ----
    print("\n[Step 3] Preparing training data...")

    sequences = create_mock_pretraining_data(n_sequences=5000, seq_len=2048)

    train_size = int(0.8 * len(sequences))
    train_sequences = sequences[:train_size]
    eval_sequences = sequences[train_size:]

    train_dataset = PretrainingDataset(train_sequences)
    eval_dataset = PretrainingDataset(eval_sequences)

    print(f"✓ Datasets created:")
    print(f"  - Train: {len(train_dataset)}")
    print(f"  - Eval: {len(eval_dataset)}")

    # ---- Step 4: collator (stacking only; no MLM masking) ----
    print("\n[Step 4] Setting up data collator...")

    data_collator = NextTokenDataCollator()
    print(f"✓ Data collator ready")

    # ---- Step 5: training arguments ----
    print("\n[Step 5] Setting up training...")

    # NOTE(review): "eval_strategy" (vs. older "evaluation_strategy") requires
    # a recent transformers version — verify against the pinned dependency.
    training_args = TrainingArguments(
        output_dir=output_dir,
        num_train_epochs=5,
        per_device_train_batch_size=16,
        per_device_eval_batch_size=16,
        learning_rate=5e-4,
        weight_decay=0.01,
        warmup_steps=500,
        logging_steps=50,
        eval_strategy="epoch",
        save_strategy="epoch",
        load_best_model_at_end=True,
        metric_for_best_model="eval_loss",
        report_to="none",
        seed=42,
        optim="adamw_torch",
        gradient_accumulation_steps=1,
        max_grad_norm=1.0,
    )

    print(f"✓ Training config:")
    print(f"  - Output: {output_dir}")
    print(f"  - Epochs: {training_args.num_train_epochs}")
    print(f"  - Batch size: {training_args.per_device_train_batch_size}")
    print(f"  - Learning rate: {training_args.learning_rate}")

    # ---- Step 6: train (NextTokenTrainer supplies the causal loss) ----
    print("\n[Step 6] Starting training...")
    print("(This may take a while. In practice, use more GPUs/data for real pretraining)")

    trainer = NextTokenTrainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        data_collator=data_collator,
    )

    # Passing resume_from_checkpoint also restores optimizer/scheduler state.
    train_result = trainer.train(resume_from_checkpoint=resume_from_checkpoint)

    print(f"✓ Training complete!")
    print(f"  - Final training loss: {train_result.training_loss:.4f}")

    # ---- Step 7: final evaluation on the held-out split ----
    print("\n[Step 7] Evaluating...")

    eval_results = trainer.evaluate()

    print(f"✓ Evaluation Results:")
    for metric, value in eval_results.items():
        if isinstance(value, (int, float)):
            print(f"  - {metric}: {value:.4f}")

    # ---- Step 8: save weights + config for later reuse ----
    print("\n[Step 8] Saving model...")

    save_dir = "./my_genemamba_from_scratch"
    model.save_pretrained(save_dir)
    model_config.save_pretrained(save_dir)

    print(f"✓ Model and config saved to '{save_dir}'")
    print(f"  Files created:")
    print(f"  - config.json")
    print(f"  - model.safetensors (or pytorch_model.bin)")

    # ---- Step 9: reload round-trip and forward-pass smoke test ----
    print("\n[Step 9] Reloading model from checkpoint...")

    loaded_model = AutoModelForMaskedLM.from_pretrained(
        save_dir,
        trust_remote_code=True,
    )

    loaded_model.eval()

    # Random ids in [2, 25426) match the training vocab range used above.
    with torch.no_grad():
        sample_input = torch.randint(2, 25426, (2, 2048))
        outputs = loaded_model(sample_input)
        logits = outputs.logits

    print(f"✓ Model reloaded and tested!")
    print(f"  - Input shape: {sample_input.shape}")
    print(f"  - Logits shape: {logits.shape}")

    # ---- Step 10: next-step pointers for the user ----
    print("\n[Step 10] Model ready for conversion/deployment!")
    print(f"✓ You can now:")
    print(f"  1. Push to Hugging Face Hub:")
    print(f"     model.push_to_hub('your-username/GeneMamba-custom')")
    print(f"  2. Use with downstream tasks:")
    print(f"     AutoModelForSequenceClassification.from_pretrained('{save_dir}', num_labels=N)")
    print(f"  3. Extract embeddings:")
    print(f"     AutoModel.from_pretrained('{save_dir}')")

    print("\n" + "=" * 80)
    print("Phase 3 Complete! Model trained from scratch and ready to use.")
    print("=" * 80)

    return model, trainer, model_config
|
|
|
|
# Script entry point: run the full pretraining demo and keep the results
# accessible for interactive inspection (e.g. python -i).
if __name__ == "__main__":
    model, trainer, model_config = main()
|
|