""" Phase 3: Train from Scratch (Next-Token Objective) Demonstrates how to initialize and train GeneMamba with next-token prediction. If a checkpoint exists, training resumes from checkpoint automatically. Usage: python examples/3_pretrain_from_scratch.py """ import torch import numpy as np from torch.utils.data import Dataset from pathlib import Path from transformers import ( AutoTokenizer, AutoConfig, AutoModelForMaskedLM, Trainer, TrainingArguments, ) from transformers.trainer_utils import get_last_checkpoint class PretrainingDataset(Dataset): """Dataset for pretraining.""" def __init__(self, input_ids_list, max_length=2048): self.input_ids_list = input_ids_list self.max_length = max_length def __len__(self): return len(self.input_ids_list) def __getitem__(self, idx): input_ids = self.input_ids_list[idx] # Pad or truncate if len(input_ids) >= self.max_length: input_ids = input_ids[:self.max_length] else: input_ids = np.pad( input_ids, (0, self.max_length - len(input_ids)), constant_values=1 ) return { "input_ids": torch.tensor(input_ids, dtype=torch.long), } class NextTokenTrainer(Trainer): """Use next-token prediction loss: logits[:, :-1] vs labels[:, 1:].""" def compute_loss(self, model, inputs, return_outputs=False): input_ids = inputs["input_ids"] outputs = model(input_ids=input_ids) logits = outputs.logits shift_logits = logits[:, :-1, :].contiguous() shift_labels = input_ids[:, 1:].contiguous().to(shift_logits.device) loss_fct = torch.nn.CrossEntropyLoss() loss = loss_fct( shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1), ) return (loss, outputs) if return_outputs else loss class NextTokenDataCollator: """Simple collator for pre-tokenized input_ids (no MLM masking).""" def __call__(self, batch): input_ids = torch.stack([item["input_ids"] for item in batch]) return {"input_ids": input_ids} def create_mock_pretraining_data(n_sequences=5000, seq_len=2048): """Create mock pretraining data.""" print("Creating mock pretraining dataset for from-scratch training...") sequences = [] for _ in range(n_sequences): seq = np.random.randint(2, 25426, seq_len) sequences.append(seq) print(f"✓ Created {n_sequences} sequences") return sequences def main(): print("=" * 80) print("GeneMamba Phase 3: Train from Scratch (Next-Token)") print("=" * 80) model_id = "mineself2016/GeneMamba" output_dir = "./from_scratch_pretrain" checkpoint_dir = Path(output_dir) / "checkpoint-last" # ============================================================ # Step 1: Load tokenizer spec # ============================================================ print("\n[Step 1] Loading tokenizer...") tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True) print("✓ Tokenizer loaded:") print(f" - vocab_size: {tokenizer.vocab_size}") print(f" - [UNK] token/id: {tokenizer.unk_token}/{tokenizer.unk_token_id}") print(f" - [PAD] token/id: {tokenizer.pad_token}/{tokenizer.pad_token_id}") print(f" - [CLS] token/id: {tokenizer.cls_token}/{tokenizer.cls_token_id}") print(f" - [MASK] token/id: {tokenizer.mask_token}/{tokenizer.mask_token_id}") # ============================================================ # Step 2: Build config and initialize/resume model # ============================================================ print("\n[Step 2] Building model (resume if checkpoint exists)...") model_config = AutoConfig.from_pretrained(model_id, trust_remote_code=True) model_config.vocab_size = 25426 model_config.hidden_size = 256 model_config.num_hidden_layers = 12 model_config.intermediate_size = 1024 
    model_config.max_position_embeddings = 2048
    model_config.mamba_mode = "mean"

    # Resume from an explicit "checkpoint-last" directory if present; otherwise
    # fall back to the most recent checkpoint-<step> directory written by a
    # previous Trainer run.
    resume_from_checkpoint = None
    if checkpoint_dir.exists():
        resume_from_checkpoint = str(checkpoint_dir)
    elif Path(output_dir).exists():
        # get_last_checkpoint requires the output directory to exist
        resume_from_checkpoint = get_last_checkpoint(output_dir)

    if resume_from_checkpoint is not None:
        model = AutoModelForMaskedLM.from_pretrained(
            resume_from_checkpoint,
            trust_remote_code=True,
            local_files_only=True,
        )
        print(f"✓ Found checkpoint, resume from: {resume_from_checkpoint}")
    else:
        model = AutoModelForMaskedLM.from_config(model_config, trust_remote_code=True)
        print("✓ No checkpoint found, start from scratch")

    # Count parameters
    total_params = sum(p.numel() for p in model.parameters())
    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)

    print("✓ Model initialized:")
    print(f"  - Total parameters: {total_params / 1e6:.2f}M")
    print(f"  - Trainable parameters: {trainable_params / 1e6:.2f}M")

    # ============================================================
    # Step 3: Prepare data
    # ============================================================
    print("\n[Step 3] Preparing training data...")

    sequences = create_mock_pretraining_data(n_sequences=5000, seq_len=2048)

    # Split
    train_size = int(0.8 * len(sequences))
    train_sequences = sequences[:train_size]
    eval_sequences = sequences[train_size:]

    train_dataset = PretrainingDataset(train_sequences)
    eval_dataset = PretrainingDataset(eval_sequences)

    print("✓ Datasets created:")
    print(f"  - Train: {len(train_dataset)}")
    print(f"  - Eval: {len(eval_dataset)}")

    # ============================================================
    # Step 4: Data collator for next-token training
    # ============================================================
    print("\n[Step 4] Setting up data collator...")

    data_collator = NextTokenDataCollator()
    print("✓ Data collator ready")

    # ============================================================
    # Step 5: Training arguments
    # ============================================================
    print("\n[Step 5] Setting up training...")

    training_args = TrainingArguments(
        output_dir=output_dir,
        num_train_epochs=5,
        per_device_train_batch_size=16,
        per_device_eval_batch_size=16,
        learning_rate=5e-4,
        weight_decay=0.01,
        warmup_steps=500,
        logging_steps=50,
        eval_strategy="epoch",
        save_strategy="epoch",
        load_best_model_at_end=True,
        metric_for_best_model="eval_loss",
        report_to="none",
        seed=42,
        optim="adamw_torch",
        gradient_accumulation_steps=1,
        max_grad_norm=1.0,
    )

    print("✓ Training config:")
    print(f"  - Output: {output_dir}")
    print(f"  - Epochs: {training_args.num_train_epochs}")
    print(f"  - Batch size: {training_args.per_device_train_batch_size}")
    print(f"  - Learning rate: {training_args.learning_rate}")

    # ============================================================
    # Step 6: Train
    # ============================================================
    print("\n[Step 6] Starting training...")
    print("(This may take a while. In practice, use more GPUs/data for real pretraining)")

    trainer = NextTokenTrainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        data_collator=data_collator,
    )

    train_result = trainer.train(resume_from_checkpoint=resume_from_checkpoint)

    print("✓ Training complete!")
    print(f"  - Final training loss: {train_result.training_loss:.4f}")

    # ============================================================
    # Step 7: Evaluate
    # ============================================================
    print("\n[Step 7] Evaluating...")

    eval_results = trainer.evaluate()

    print("✓ Evaluation Results:")
    for metric, value in eval_results.items():
        if isinstance(value, (int, float)):
            print(f"  - {metric}: {value:.4f}")

    # ============================================================
    # Step 8: Save model and config
    # ============================================================
    print("\n[Step 8] Saving model...")

    save_dir = "./my_genemamba_from_scratch"
    model.save_pretrained(save_dir)
    model_config.save_pretrained(save_dir)

    print(f"✓ Model and config saved to '{save_dir}'")
    print("  Files created:")
    print("  - config.json")
    print("  - model.safetensors (or pytorch_model.bin)")

    # ============================================================
    # Step 9: Reload and verify
    # ============================================================
    print("\n[Step 9] Reloading model from checkpoint...")

    loaded_model = AutoModelForMaskedLM.from_pretrained(
        save_dir,
        trust_remote_code=True,
    )
    loaded_model.eval()

    # Test inference
    with torch.no_grad():
        sample_input = torch.randint(2, 25426, (2, 2048))
        outputs = loaded_model(sample_input)
        logits = outputs.logits

    print("✓ Model reloaded and tested!")
    print(f"  - Input shape: {sample_input.shape}")
    print(f"  - Logits shape: {logits.shape}")

    # ============================================================
    # Step 10: Optional - Convert to different format
    # ============================================================
    print("\n[Step 10] Model ready for conversion/deployment!")
    print("✓ You can now:")
    print("  1. Push to Hugging Face Hub:")
    print("     model.push_to_hub('your-username/GeneMamba-custom')")
    print("  2. Use with downstream tasks:")
    print(f"     AutoModelForSequenceClassification.from_pretrained('{save_dir}', num_labels=N)")
    print("  3. Extract embeddings:")
    print(f"     AutoModel.from_pretrained('{save_dir}')")

    print("\n" + "=" * 80)
    print("Phase 3 Complete! Model trained from scratch and ready to use.")
    print("=" * 80)

    return model, trainer, model_config


if __name__ == "__main__":
    model, trainer, model_config = main()
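
# ----------------------------------------------------------------------------
# Optional follow-ups (sketches only; the names below assume the variables
# defined inside main(), e.g. `eval_results` from Step 7):
#
# 1) Perplexity. NextTokenTrainer optimizes a mean next-token cross-entropy,
#    so perplexity is simply its exponential:
#
#        import math
#        print(f"Eval perplexity: {math.exp(eval_results['eval_loss']):.2f}")
#
# 2) Multi-GPU pretraining. Trainer detects a distributed environment launched
#    with torchrun automatically, e.g. assuming a single node with 4 GPUs:
#
#        torchrun --nproc_per_node=4 examples/3_pretrain_from_scratch.py
# ----------------------------------------------------------------------------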