#!/usr/bin/env python
# train.py
# Runs train_olmoe_adapter.py with parameters when called.
"""
Run script for fine-tuning OlmoE with adapters on specific text domains.

Handles argument parsing and configuration: user-facing options are declared on
``ScriptArguments``, all other training settings are pinned in
``build_command``, and the resulting command is executed as a child process
whose exit status becomes this script's exit status.
"""

import argparse
import os
import shlex
import subprocess
import sys
from dataclasses import dataclass, field
from typing import Dict, List, Mapping, Optional


@dataclass
class ScriptArguments:
    """
    Arguments for the run script that aren't covered by TrainingArguments.
    """

    # NOTE(review): the module docstring targets OlmoE, but this default names
    # a dense OLMo checkpoint -- confirm it is the intended base model.
    model_path: str = field(
        default="allenai/OLMo-7B-Instruct",
        metadata={"help": "Path to the model to fine-tune"}
    )
    output_dir: str = field(
        default="./output_olmoe_adapter",
        metadata={"help": "Directory to save the model and logs"}
    )
    adapter_size: int = field(
        default=64,
        metadata={"help": "Size of the adapter layers"}
    )
    dataset_name: str = field(
        default="mlfoundations/dclm-baseline-1.0",
        metadata={"help": "Name of the dataset to use"}
    )
    max_steps: int = field(
        default=10000,
        metadata={"help": "Maximum number of training steps"}
    )
    learning_rate: float = field(
        default=5e-5,
        metadata={"help": "Learning rate for fine-tuning"}
    )
    per_device_batch_size: int = field(
        default=8,
        metadata={"help": "Batch size per device"}
    )
    gradient_accumulation_steps: int = field(
        default=1,
        metadata={"help": "Number of steps to accumulate gradients"}
    )
    # Quantization options are currently disabled; re-enable them together
    # with the matching block in build_command().
    # use_8bit: bool = field(
    #     default=False,
    #     metadata={"help": "Whether to use 8-bit precision"}
    # )
    # use_4bit: bool = field(
    #     default=False,
    #     metadata={"help": "Whether to use 4-bit precision"}
    # )


def build_command(
    args: ScriptArguments,
    python_executable: Optional[str] = None,
) -> List[str]:
    """Return the argv list that launches train_olmoe_adapter.py for *args*.

    Args:
        args: Parsed script arguments.
        python_executable: Interpreter used to run the training script.
            Defaults to the interpreter running this script
            (``sys.executable``), falling back to ``"python"`` if that is
            unavailable.

    Returns:
        The command as a list of arguments. It is meant to be run without a
        shell, so values containing spaces or shell metacharacters are passed
        to the training script verbatim.
    """
    # Reuse the current interpreter so the child sees the same environment
    # and installed packages as this launcher.
    python = python_executable or sys.executable or "python"
    cmd = [
        python,
        "train_olmoe_adapter.py",
        # Model arguments
        f"--model_name_or_path={args.model_path}",
        f"--adapter_size={args.adapter_size}",
        "--freeze_base_model=True",  # Always freeze the base model
        f"--checkpoint_dir={args.output_dir}",
        # Data arguments
        f"--dataset_name={args.dataset_name}",
        "--streaming=True",  # Always stream for large datasets
        "--streaming_buffer_size=8192",
        "--max_seq_length=1024",
        # Training arguments
        f"--output_dir={args.output_dir}",
        f"--per_device_train_batch_size={args.per_device_batch_size}",
        f"--gradient_accumulation_steps={args.gradient_accumulation_steps}",
        f"--learning_rate={args.learning_rate}",
        f"--max_steps={args.max_steps}",
        "--warmup_steps=500",
        "--logging_steps=10",
        "--save_steps=1000",
        "--save_total_limit=2",
        "--dataloader_num_workers=4",
        "--seed=42",
    ]
    # Add precision flags if needed
    # if args.use_8bit:
    #     cmd.append("--load_in_8bit")
    # if args.use_4bit:
    #     cmd.append("--load_in_4bit")
    return cmd


def build_child_env(base_env: Optional[Mapping[str, str]] = None) -> Dict[str, str]:
    """Return a copy of *base_env* with the current directory prepended to PYTHONPATH.

    Args:
        base_env: Environment to copy; defaults to ``os.environ``.

    Returns:
        A new dict. The caller's environment (and ``os.environ``) is left
        untouched, and any existing PYTHONPATH entries are preserved after
        the current directory instead of being overwritten.
    """
    env = dict(os.environ if base_env is None else base_env)
    cwd = os.getcwd()
    existing = env.get("PYTHONPATH")
    env["PYTHONPATH"] = os.pathsep.join([cwd, existing]) if existing else cwd
    return env


def run_command(cmd: List[str], env: Optional[Mapping[str, str]] = None) -> int:
    """Run *cmd* without a shell and return a shell-style exit code.

    Args:
        cmd: Program and arguments to execute.
        env: Environment for the child; ``None`` inherits the current one.

    Returns:
        0 on success, the child's exit status on failure, or ``128 + N``
        when the child was killed by signal N (``subprocess`` reports that as
        ``-N``). The result is always a value ``sys.exit`` preserves, unlike
        the raw wait status from ``os.system``: an exit status of 1 is encoded
        as 256 there, which ``sys.exit`` truncates to 0.
    """
    completed = subprocess.run(cmd, env=env, check=False)
    code = completed.returncode
    return 128 - code if code < 0 else code


def main() -> None:
    """Parse arguments, launch the training script, and exit with its status."""
    # Imported lazily: only CLI parsing needs transformers, so the helpers
    # above stay importable (and testable) without it.
    from transformers import HfArgumentParser

    # Parse command-line arguments
    parser = HfArgumentParser(ScriptArguments)
    args = parser.parse_args_into_dataclasses()[0]

    # Create output directory
    os.makedirs(args.output_dir, exist_ok=True)

    # Prepare command for training
    cmd = build_command(args)

    # Print the command for logging. Flush so the line precedes the child's
    # output even when stdout is piped to a file.
    print(f"Running command: {shlex.join(cmd)}", flush=True)

    # Execute the training script
    ret = run_command(cmd, env=build_child_env())
    if ret != 0:
        print(f"Training failed with exit code {ret}")
        sys.exit(ret)

    print("Training completed successfully!")


if __name__ == "__main__":
    main()