#!/usr/bin/env python
# train_olmoe_adapter.py
"""
Training script for OlmoE model with adapters on the mlfoundations/dclm-baseline-1.0 dataset.

This script demonstrates parameter-efficient fine-tuning using adapters.
"""

import os
import math
import logging
import argparse
from dataclasses import dataclass, field
from typing import Dict, List, Optional, Tuple, Any, Union

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, IterableDataset, default_collate, get_worker_info
from torch.optim import AdamW
from torch.optim.lr_scheduler import LambdaLR
from datasets import load_dataset
from transformers import (
    OlmoConfig,
    OlmoForCausalLM,
    AutoTokenizer,
    DataCollatorForLanguageModeling,
    HfArgumentParser,
    TrainingArguments,
    set_seed,
    get_scheduler,
)
from tqdm import tqdm
from accelerate import Accelerator, DistributedType
from accelerate.utils import find_batch_size

from modeling_olmoe import (
    OlmoEWithAdaptersForCausalLM,
    OlmoEForCausalLM,
)

# Set up logging
logger = logging.getLogger(__name__)
logging.basicConfig(
    format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
    datefmt="%m/%d/%Y %H:%M:%S",
    level=logging.INFO,
)


@dataclass
class ModelArguments:
    """Arguments for model configuration."""

    # NOTE(review): the weights are loaded with transformers' native
    # OlmoConfig / OlmoForCausalLM; confirm this default repo id is in the
    # transformers-native format (native checkpoints are typically published
    # under "-hf" ids).
    model_name_or_path: str = field(
        default="allenai/OLMo-7B-Instruct",
        metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models"},
    )
    adapter_size: int = field(
        default=64,
        metadata={"help": "Size of the adapter layers"},
    )
    freeze_base_model: bool = field(
        default=True,
        metadata={"help": "Whether to freeze all parameters except the adapters"},
    )
    # Falls back to TrainingArguments.output_dir in main() when unset.
    checkpoint_dir: Optional[str] = field(
        default=None,
        metadata={"help": "Path to save model checkpoints"},
    )


@dataclass
class DataArguments:
    """Arguments for dataset configuration."""

    dataset_name: str = field(
        default="mlfoundations/dclm-baseline-1.0",
        metadata={"help": "Dataset name or path for training"},
    )
    streaming: bool = field(
        default=True,
        metadata={"help": "Whether to stream the dataset"},
    )
    streaming_buffer_size: int = field(
        default=8192,
        metadata={"help": "Buffer size for streaming dataset"},
    )
    max_seq_length: int = field(
        default=1024,
        metadata={"help": "Maximum sequence length for training"},
    )
    # Used as the DataLoader's num_workers; tokenization runs inside those workers.
    preprocessing_num_workers: Optional[int] = field(
        default=None,
        metadata={"help": "Number of workers for preprocessing"},
    )
    text_column_name: str = field(
        default="text",
        metadata={"help": "Column name for text data"},
    )


class StreamingTextDataset(IterableDataset):
    """Tokenize a text dataset on the fly and pack it into fixed-length blocks.

    Each non-blank document is tokenized without special tokens, optionally
    followed by the tokenizer's EOS token, and appended to one running token
    stream. The stream is cut into non-overlapping blocks of ``max_seq_length``
    tokens. Each yielded item is
    ``{"input_ids": LongTensor[max_seq_length], "labels": <copy of input_ids>}``.
    Labels are left unshifted (presumably the model applies the causal shift,
    as HF causal-LM heads do; confirm for OlmoEWithAdaptersForCausalLM).
    A trailing partial block at the end of the data is dropped.

    Args:
        dataset_name: Name or path passed to ``datasets.load_dataset``.
        tokenizer: HF tokenizer used to encode the text column.
        max_seq_length: Tokens per yielded block; must be positive.
        streaming: Load the dataset in streaming mode.
        text_column_name: Column holding the raw text.
        buffer_size: Shuffle buffer size (streaming mode only).
        split: Dataset split to load.
        seed: Shuffle seed. ``None`` gives a non-reproducible order.
        append_eos: Insert ``tokenizer.eos_token_id`` after every document so
            packed blocks keep document boundaries.

    Raises:
        ValueError: if ``max_seq_length`` is not positive.
    """

    def __init__(
        self,
        dataset_name: str,
        tokenizer,
        max_seq_length: int,
        streaming: bool = True,
        text_column_name: str = "text",
        buffer_size: int = 8192,
        split: str = "train",
        seed: Optional[int] = None,
        append_eos: bool = True,
    ):
        # A non-positive block length previously produced an infinite loop of
        # empty tensors in __iter__.
        if max_seq_length <= 0:
            raise ValueError(f"max_seq_length must be positive, got {max_seq_length}")

        self.tokenizer = tokenizer
        self.max_seq_length = max_seq_length
        self.text_column_name = text_column_name
        self.streaming = streaming
        self.append_eos = append_eos

        self.dataset = load_dataset(
            dataset_name,
            split=split,
            streaming=streaming,
        )
        if streaming:
            # Approximate shuffling through a fixed-size buffer.
            self.dataset = self.dataset.shuffle(seed=seed, buffer_size=buffer_size)
        else:
            # Map-style dataset: a full permutation is available.
            self.dataset = self.dataset.shuffle(seed=seed)

    def __iter__(self):
        dataset = self.dataset

        # For the in-memory (map-style) dataset, every DataLoader worker would
        # otherwise iterate the full dataset and emit duplicate blocks, so each
        # worker takes its own shard. The streaming path is left as-is:
        # NOTE(review): HF streaming datasets are expected to split their shards
        # across PyTorch workers themselves; confirm for the installed version.
        worker_info = get_worker_info()
        if not self.streaming and worker_info is not None and worker_info.num_workers > 1:
            dataset = dataset.shard(num_shards=worker_info.num_workers, index=worker_info.id)

        eos_id = self.tokenizer.eos_token_id if self.append_eos else None
        block = self.max_seq_length
        buffer: List[int] = []

        for example in dataset:
            text = example[self.text_column_name]
            if not text or not text.strip():
                continue

            ids = self.tokenizer(
                text,
                truncation=False,
                return_attention_mask=False,
                return_token_type_ids=False,
                add_special_tokens=False,
            )["input_ids"]
            buffer.extend(ids)
            if eos_id is not None:
                buffer.append(eos_id)

            # Emit every complete block, then drop them from the buffer in one
            # deletion (repeated re-slicing was quadratic for long documents).
            full_len = (len(buffer) // block) * block
            for start in range(0, full_len, block):
                chunk = torch.tensor(buffer[start:start + block], dtype=torch.long)
                yield {"input_ids": chunk, "labels": chunk.clone()}
            if full_len:
                del buffer[:full_len]


def _resolve_warmup_steps(args: TrainingArguments, num_training_steps: int) -> int:
    """Return the number of warmup optimizer steps.

    ``warmup_steps`` wins when positive (a value in (0, 1) is read as a ratio
    of ``num_training_steps``). Otherwise ``warmup_ratio`` is used if the
    installed TrainingArguments has it.
    """
    warmup = args.warmup_steps or 0
    if 0 < warmup < 1:
        return math.ceil(num_training_steps * warmup)
    if warmup >= 1:
        return int(warmup)
    ratio = getattr(args, "warmup_ratio", 0.0) or 0.0
    return math.ceil(num_training_steps * ratio)


def _resolve_interval(value: Optional[float], total_steps: int) -> int:
    """Convert a HF-style step interval to an integer step count.

    Values >= 1 are absolute step counts. Values in (0, 1) are ratios of
    ``total_steps``. Returns 0 (disabled) for ``None`` or non-positive values.
    """
    if value is None or value <= 0:
        return 0
    if value < 1:
        return max(1, math.ceil(total_steps * value))
    return int(value)


def create_optimizer_and_scheduler(
    model: nn.Module,
    args: TrainingArguments,
    num_training_steps: int,
    scheduler_steps_per_update: int = 1,
) -> Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LRScheduler]:
    """Create an AdamW optimizer and an LR scheduler from ``TrainingArguments``.

    Args:
        model: Model whose trainable parameters are optimized. If it defines
            ``get_trainable_parameters()`` that is used; otherwise every
            parameter with ``requires_grad`` is used.
        args: Source of lr, Adam betas/eps, weight decay, scheduler type and
            warmup settings.
        num_training_steps: Total number of optimizer updates.
        scheduler_steps_per_update: How many times ``scheduler.step()`` is
            called per optimizer update. Accelerate's prepared scheduler steps
            once per process (without ``split_batches``), so pass
            ``accelerator.num_processes`` to keep warmup/decay aligned with
            real optimizer updates.

    Returns:
        ``(optimizer, scheduler)``.

    Raises:
        ValueError: if there are no trainable parameters, or
            ``scheduler_steps_per_update`` < 1.
    """
    if scheduler_steps_per_update < 1:
        raise ValueError(
            f"scheduler_steps_per_update must be >= 1, got {scheduler_steps_per_update}"
        )

    if hasattr(model, "get_trainable_parameters"):
        # list() so generators can be counted and still handed to AdamW.
        optimizer_params = list(model.get_trainable_parameters())
    else:
        optimizer_params = [p for p in model.parameters() if p.requires_grad]

    if not optimizer_params:
        raise ValueError(
            "No trainable parameters found; check freeze_base_model and the adapter setup."
        )

    # Count tensors and scalar elements separately; the tensor count alone is
    # not the number of trainable parameters. (Assumes plain tensors; param-group
    # dicts from get_trainable_parameters would not be counted here.)
    num_elements = sum(p.numel() for p in optimizer_params if isinstance(p, torch.Tensor))
    logger.info(
        f"Training {len(optimizer_params)} parameter tensors ({num_elements:,} trainable elements)"
    )

    optimizer = AdamW(
        optimizer_params,
        lr=args.learning_rate,
        betas=(args.adam_beta1, args.adam_beta2),
        eps=args.adam_epsilon,
        weight_decay=args.weight_decay,
    )

    num_warmup_steps = _resolve_warmup_steps(args, num_training_steps)
    scheduler = get_scheduler(
        name=args.lr_scheduler_type,
        optimizer=optimizer,
        num_warmup_steps=num_warmup_steps * scheduler_steps_per_update,
        num_training_steps=num_training_steps * scheduler_steps_per_update,
    )

    return optimizer, scheduler


def _load_adapter_model(model_args: ModelArguments) -> nn.Module:
    """Build the adapter model and copy the pretrained base weights into it."""
    config = OlmoConfig.from_pretrained(model_args.model_name_or_path, trust_remote_code=True)
    config.adapter_size = model_args.adapter_size

    logger.info(f"Loading OlmoE model with adapters from {model_args.model_name_or_path}")
    base_model = OlmoForCausalLM.from_pretrained(model_args.model_name_or_path, trust_remote_code=True)

    # The custom class is built from the config (random init), then the base
    # checkpoint is copied in. strict=False because the adapter weights are not
    # in the base checkpoint. The key mismatch is reported rather than discarded:
    # unexpected keys mean base weights were NOT copied and stay random.
    model = OlmoEWithAdaptersForCausalLM(config)
    incompatible = model.load_state_dict(base_model.state_dict(), strict=False)
    if incompatible.missing_keys:
        logger.info(
            f"{len(incompatible.missing_keys)} keys not found in base checkpoint "
            f"(expected for adapter weights), e.g. {incompatible.missing_keys[:10]}"
        )
    if incompatible.unexpected_keys:
        logger.warning(
            f"{len(incompatible.unexpected_keys)} base-model keys were NOT loaded into the "
            f"adapter model, e.g. {incompatible.unexpected_keys[:10]}"
        )

    # Release the second full copy of the weights before training starts.
    del base_model
    return model


def _save_checkpoint(accelerator: Accelerator, model, tokenizer, output_dir: str) -> None:
    """Save accelerator state plus a ``save_pretrained`` copy of model and tokenizer.

    Call on every process. ``save_state`` and ``get_state_dict`` may need all
    ranks under sharded strategies. Only the main process writes model files
    (``is_main_process``) and the tokenizer.
    """
    accelerator.wait_for_everyone()
    accelerator.save_state(output_dir)
    unwrapped_model = accelerator.unwrap_model(model)
    unwrapped_model.save_pretrained(
        output_dir,
        is_main_process=accelerator.is_main_process,
        save_function=accelerator.save,
        state_dict=accelerator.get_state_dict(model),
    )
    if accelerator.is_main_process:
        tokenizer.save_pretrained(output_dir)


def train(
    model_args: ModelArguments,
    data_args: DataArguments,
    training_args: TrainingArguments,
):
    """Fine-tune the adapter model on a packed, streamed text dataset.

    Runs exactly ``training_args.max_steps`` optimizer updates, or fewer if the
    data runs out. Each update accumulates ``gradient_accumulation_steps``
    micro-batches. Checkpoints go to ``model_args.checkpoint_dir`` every
    ``save_steps`` updates and at the end ("final-model").

    Raises:
        ValueError: if ``training_args.max_steps`` is not positive. The dataset
            has no known length, so the run length must be explicit.
    """
    num_training_steps = training_args.max_steps
    if num_training_steps <= 0:
        # TrainingArguments defaults max_steps to -1, which previously ended the
        # loop after a single step with a degenerate LR schedule.
        raise ValueError(
            "--max_steps must be a positive integer: the streamed dataset has no "
            "known length, so the number of optimizer steps cannot be inferred."
        )

    if training_args.fp16:
        mixed_precision = "fp16"
    elif training_args.bf16:
        mixed_precision = "bf16"
    else:
        mixed_precision = "no"

    accelerator = Accelerator(
        gradient_accumulation_steps=training_args.gradient_accumulation_steps,
        mixed_precision=mixed_precision,
    )

    logger.info(accelerator.state)
    if accelerator.is_local_main_process:
        logger.info(f"Model arguments: {model_args}")
        logger.info(f"Data arguments: {data_args}")
        logger.info(f"Training arguments: {training_args}")

    set_seed(training_args.seed)

    tokenizer = AutoTokenizer.from_pretrained(model_args.model_name_or_path, trust_remote_code=True)
    # Ensure a pad token exists so the saved tokenizer is usable for padded batches.
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    model = _load_adapter_model(model_args)
    if model_args.freeze_base_model:
        logger.info("Freezing base model parameters")
        model.freeze_base_model()

    logger.info(f"Loading dataset: {data_args.dataset_name}")
    train_dataset = StreamingTextDataset(
        dataset_name=data_args.dataset_name,
        tokenizer=tokenizer,
        max_seq_length=data_args.max_seq_length,
        streaming=data_args.streaming,
        buffer_size=data_args.streaming_buffer_size,
        text_column_name=data_args.text_column_name,
        seed=training_args.seed,
    )

    # Blocks are already fixed-length with labels, so plain stacking is enough.
    # DataCollatorForLanguageModeling would rebuild labels and mask every token
    # equal to pad_token (== EOS here) to -100.
    train_dataloader = DataLoader(
        train_dataset,
        batch_size=training_args.per_device_train_batch_size,
        collate_fn=default_collate,
        num_workers=data_args.preprocessing_num_workers or 0,
    )

    # The prepared scheduler is stepped num_processes times per update (no
    # split_batches), so its horizon is scaled to match.
    optimizer, lr_scheduler = create_optimizer_and_scheduler(
        model=model,
        args=training_args,
        num_training_steps=num_training_steps,
        scheduler_steps_per_update=accelerator.num_processes,
    )

    model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
        model, optimizer, train_dataloader, lr_scheduler
    )

    total_batch_size = (
        training_args.per_device_train_batch_size
        * accelerator.num_processes
        * training_args.gradient_accumulation_steps
    )
    logger.info(f"Total batch size (with parallel & accumulation): {total_batch_size}")
    logger.info(f"Number of training steps: {num_training_steps}")
    logger.info(
        f"Number of warmup steps: {_resolve_warmup_steps(training_args, num_training_steps)}"
    )

    logging_interval = _resolve_interval(training_args.logging_steps, num_training_steps)
    save_interval = _resolve_interval(training_args.save_steps, num_training_steps)
    max_grad_norm = training_args.max_grad_norm

    progress_bar = tqdm(
        range(num_training_steps),
        disable=not accelerator.is_local_main_process,
        desc="Training",
    )
    completed_steps = 0

    logger.info("Starting training...")
    model.train()
    for batch in train_dataloader:
        with accelerator.accumulate(model):
            outputs = model(**batch)
            loss = outputs.loss
            accelerator.backward(loss)
            if accelerator.sync_gradients and max_grad_norm is not None and max_grad_norm > 0:
                accelerator.clip_grad_norm_(model.parameters(), max_grad_norm)
            optimizer.step()
            lr_scheduler.step()
            optimizer.zero_grad()

        # Inside accumulate(), the optimizer only really updates on the
        # micro-batch where gradients are synchronised; count updates, not
        # micro-batches, so max/logging/save steps mean optimizer steps.
        if not accelerator.sync_gradients:
            continue

        progress_bar.update(1)
        completed_steps += 1

        if logging_interval and completed_steps % logging_interval == 0:
            # Collective: every rank reaches this on the same step.
            loss_value = accelerator.gather(loss.detach()).mean().item()
            current_lr = lr_scheduler.get_last_lr()[0]
            logger.info(f"Step {completed_steps}: loss = {loss_value:.4f}, lr = {current_lr:.8f}")
            # No-op unless trackers were configured on the Accelerator.
            accelerator.log({"loss": loss_value, "learning_rate": current_lr}, step=completed_steps)

        if (
            save_interval
            and completed_steps % save_interval == 0
            and model_args.checkpoint_dir is not None
        ):
            output_dir = os.path.join(model_args.checkpoint_dir, f"checkpoint-{completed_steps}")
            _save_checkpoint(accelerator, model, tokenizer, output_dir)
            logger.info(f"Saved checkpoint to {output_dir}")

        if completed_steps >= num_training_steps:
            break

    progress_bar.close()
    if completed_steps < num_training_steps:
        logger.warning(
            f"Dataset exhausted after {completed_steps} of {num_training_steps} optimizer steps"
        )

    if model_args.checkpoint_dir is not None:
        output_dir = os.path.join(model_args.checkpoint_dir, "final-model")
        _save_checkpoint(accelerator, model, tokenizer, output_dir)
        logger.info(f"Saved final model to {output_dir}")

    logger.info("Training complete!")


def main():
    """Parse CLI arguments into the three dataclasses and run training."""
    parser = HfArgumentParser((ModelArguments, DataArguments, TrainingArguments))
    model_args, data_args, training_args = parser.parse_args_into_dataclasses()

    # Default checkpoint location is the Trainer output directory.
    if model_args.checkpoint_dir is None:
        model_args.checkpoint_dir = training_args.output_dir
    if model_args.checkpoint_dir is not None:
        os.makedirs(model_args.checkpoint_dir, exist_ok=True)

    train(model_args, data_args, training_args)


if __name__ == "__main__":
    main()