""" Training entry point for ColiFormer. This script wraps finetune.py and loads configuration from YAML files. Usage: python scripts/train.py --config configs/train_ecoli_alm.yaml python scripts/train.py --config configs/train_ecoli_quick.yaml """ import argparse import os import sys from pathlib import Path # Add parent directory to path to import finetune sys.path.insert(0, str(Path(__file__).parent.parent)) def load_config(config_path: str) -> dict: """ Load configuration from YAML file. Args: config_path: Path to YAML config file Returns: Dictionary with configuration values """ # Lazy import so `python scripts/train.py --help` works without dependencies installed. import yaml if not os.path.exists(config_path): raise FileNotFoundError(f"Config file not found: {config_path}") with open(config_path, 'r') as f: config = yaml.safe_load(f) return config def config_to_args(config: dict) -> argparse.Namespace: """ Convert config dictionary to argparse.Namespace compatible with finetune.py. Args: config: Configuration dictionary from YAML Returns: argparse.Namespace with all required arguments """ # Extract nested config values data_config = config.get('data', {}) training_config = config.get('training', {}) checkpoint_config = config.get('checkpoint', {}) alm_config = config.get('alm', {}) gc_penalty_config = config.get('gc_penalty', {}) # Build args namespace args = argparse.Namespace() # Data paths args.dataset_dir = data_config.get('dataset_dir', 'data') # Checkpoint paths args.checkpoint_dir = checkpoint_config.get('checkpoint_dir', 'models/checkpoints') args.checkpoint_filename = checkpoint_config.get('checkpoint_filename', 'finetune.ckpt') # Training parameters args.batch_size = training_config.get('batch_size', 6) args.max_epochs = training_config.get('max_epochs', 15) args.num_workers = training_config.get('num_workers', 5) args.accumulate_grad_batches = training_config.get('accumulate_grad_batches', 1) args.num_gpus = training_config.get('num_gpus', 4) args.learning_rate = training_config.get('learning_rate', 5e-5) args.warmup_fraction = training_config.get('warmup_fraction', 0.1) args.save_every_n_steps = training_config.get('save_every_n_steps', 512) args.seed = training_config.get('seed', 123) args.log_every_n_steps = training_config.get('log_every_n_steps', 20) args.debug = training_config.get('debug', False) # GC penalty (legacy) args.gc_penalty_weight = gc_penalty_config.get('weight', 0.0) # ALM parameters args.use_lagrangian = alm_config.get('enabled', False) args.gc_target = alm_config.get('gc_target', 0.52) args.curriculum_epochs = alm_config.get('curriculum_epochs', 3) args.lagrangian_rho = alm_config.get('initial_penalty_factor', 20.0) # Use initial_penalty_factor as rho args.alm_tolerance = alm_config.get('tolerance', 1e-5) args.alm_dual_tolerance = alm_config.get('dual_tolerance', 1e-5) args.alm_penalty_update_factor = alm_config.get('penalty_update_factor', 10.0) args.alm_initial_penalty_factor = alm_config.get('initial_penalty_factor', 20.0) args.alm_tolerance_update_factor = alm_config.get('tolerance_update_factor', 0.1) args.alm_rel_penalty_increase_threshold = alm_config.get('rel_penalty_increase_threshold', 0.1) args.alm_max_penalty = alm_config.get('max_penalty', 1e6) args.alm_min_penalty = alm_config.get('min_penalty', 1e-6) return args def validate_config(config: dict): """ Validate configuration before training. 
def validate_config(config: dict):
    """
    Validate configuration before training.

    Args:
        config: Configuration dictionary

    Raises:
        ValueError: If configuration is invalid
    """
    data_config = config.get('data', {})
    dataset_dir = data_config.get('dataset_dir', 'data')

    # Check dataset directory exists
    if not os.path.exists(dataset_dir):
        raise ValueError(f"Dataset directory not found: {dataset_dir}")

    # Check for expected data files
    finetune_set = os.path.join(dataset_dir, 'finetune_set.json')
    if not os.path.exists(finetune_set):
        raise ValueError(
            f"Training data not found: {finetune_set}\n"
            "Please run data preprocessing first:\n"
            "  python scripts/preprocess_data.py"
        )

    # Validate checkpoint directory can be created
    checkpoint_config = config.get('checkpoint', {})
    checkpoint_dir = checkpoint_config.get('checkpoint_dir', 'models/checkpoints')
    os.makedirs(checkpoint_dir, exist_ok=True)


def main():
    """Main entry point for training."""
    parser = argparse.ArgumentParser(
        description="Train ENCOT model with configuration file",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Examples:
  # Train with main ALM configuration
  python scripts/train.py --config configs/train_ecoli_alm.yaml

  # Quick test training (CPU, 1 epoch)
  python scripts/train.py --config configs/train_ecoli_quick.yaml

  # Override config values from command line
  python scripts/train.py --config configs/train_ecoli_alm.yaml --num_gpus 2 --batch_size 4
"""
    )
    parser.add_argument(
        "--config",
        type=str,
        required=True,
        help="Path to YAML configuration file"
    )
    parser.add_argument(
        "--num_gpus",
        type=int,
        default=None,
        help="Override number of GPUs from config"
    )
    parser.add_argument(
        "--batch_size",
        type=int,
        default=None,
        help="Override batch size from config"
    )
    parser.add_argument(
        "--max_epochs",
        type=int,
        default=None,
        help="Override max epochs from config"
    )

    args = parser.parse_args()

    try:
        # Lazy import so `--help` works even if training deps are missing.
        from finetune import main as finetune_main

        # Load configuration
        print(f"Loading configuration from {args.config}...")
        config = load_config(args.config)

        # Override with command-line arguments if provided
        if args.num_gpus is not None:
            config.setdefault('training', {})['num_gpus'] = args.num_gpus
        if args.batch_size is not None:
            config.setdefault('training', {})['batch_size'] = args.batch_size
        if args.max_epochs is not None:
            config.setdefault('training', {})['max_epochs'] = args.max_epochs

        # Validate configuration
        print("Validating configuration...")
        validate_config(config)

        # Convert config to args namespace
        train_args = config_to_args(config)

        # Print training summary
        print("\n" + "=" * 60)
        print("Training Configuration Summary")
        print("=" * 60)
        print(f"Dataset directory: {train_args.dataset_dir}")
        print(f"Checkpoint directory: {train_args.checkpoint_dir}")
        print(f"Checkpoint filename: {train_args.checkpoint_filename}")
        print(f"Batch size: {train_args.batch_size}")
        print(f"Max epochs: {train_args.max_epochs}")
        print(f"Learning rate: {train_args.learning_rate}")
        print(f"Number of GPUs: {train_args.num_gpus}")
        print(f"ALM enabled: {train_args.use_lagrangian}")
        if train_args.use_lagrangian:
            print(f"GC target: {train_args.gc_target}")
            print(f"Curriculum epochs: {train_args.curriculum_epochs}")
        print("=" * 60 + "\n")

        # Run training
        finetune_main(train_args)

    except Exception as e:
        print(f"Error: {e}", file=sys.stderr)
        sys.exit(1)


if __name__ == "__main__":
    main()