"""Optimizer registry and factory for instantiating optimizers.

Optimizer factory functions register themselves under a name with the
``register_optimizer`` decorator; ``get_optimizer`` looks the factory up by
the name stored in ``config.optimizer.optimizer_type`` and calls it with
``(model, config)``.
"""

from typing import Dict, Type, Callable, Any

import torch.optim as optim

from taoTrain.core.base import BaseModel
from taoTrain.config import TrainingConfig, OptimizerEnum

# Signature every registered factory is called with by ``get_optimizer``.
OptimizerFactory = Callable[[BaseModel, TrainingConfig], optim.Optimizer]

# Global registry for optimizers: registry key (optimizer name) -> factory.
_OPTIMIZER_REGISTRY: Dict[str, OptimizerFactory] = {}


def _to_registry_key(name: Any) -> Any:
    """Return the registry key used for *name*.

    Enum members (anything exposing ``.value``) are reduced to their
    ``.value``; plain strings have no ``.value`` attribute and are returned
    unchanged. Routing both registration and lookup through this helper
    guarantees they agree on the key whether callers pass an enum or a str.
    """
    return getattr(name, "value", name)


def register_optimizer(name: str) -> Callable[[OptimizerFactory], OptimizerFactory]:
    """
    Decorator to register a custom optimizer factory function.

    The decorated function is stored unchanged and is later called as
    ``fn(model, config)`` by ``get_optimizer``.

    Args:
        name: Name of the optimizer (e.g., 'adamw', 'adam', 'sgd'). An enum
            member is also accepted and is registered under its ``.value``.

    Returns:
        A decorator that registers its argument and returns it unchanged.

    Raises:
        ValueError: When the decorator is applied, if the name is already
            registered.
    """
    key = _to_registry_key(name)

    def decorator(fn: OptimizerFactory) -> OptimizerFactory:
        if key in _OPTIMIZER_REGISTRY:
            raise ValueError(f"Optimizer '{key}' is already registered")
        _OPTIMIZER_REGISTRY[key] = fn
        return fn

    return decorator


def get_registered_optimizers() -> Dict[str, OptimizerFactory]:
    """Return a shallow copy of the registry (name -> factory function).

    Mutating the returned dict does not affect the registry itself.
    """
    return _OPTIMIZER_REGISTRY.copy()


def get_optimizer(
    model: BaseModel,
    config: TrainingConfig,
) -> optim.Optimizer:
    """
    Create an optimizer instance from config.

    The optimizer name is read from ``config.optimizer.optimizer_type``,
    which may be a plain string or an enum member (its ``.value`` is used).

    Args:
        model: Model to optimize; passed through to the factory.
        config: TrainingConfig with optimizer configuration; passed through
            to the factory.

    Returns:
        The optimizer returned by the registered factory for
        ``(model, config)``.

    Raises:
        ValueError: If optimizer type is not registered (including when it
            is missing/None).
    """
    optimizer_name = _to_registry_key(config.optimizer.optimizer_type)

    try:
        factory_fn = _OPTIMIZER_REGISTRY[optimizer_name]
    except KeyError:
        # ``from None`` hides the internal KeyError; the ValueError message
        # already names the missing key and the available choices.
        raise ValueError(
            f"Unknown optimizer: {optimizer_name}. "
            f"Available: {sorted(_OPTIMIZER_REGISTRY)}"
        ) from None

    return factory_fn(model, config)


def register_builtin_optimizers():
    """Register all built-in optimizers.

    Importing each submodule runs its ``@register_optimizer`` decorators.
    Python caches imported modules, so calling this more than once does not
    re-run registration (and so does not hit the duplicate-name error).
    """
    # Import here to trigger decorator registration (avoid circular imports)
    from . import adamw  # noqa: F401
    from . import adam  # noqa: F401
    from . import sgd  # noqa: F401
    from . import hybrid_muon_adamw  # noqa: F401


# Auto-register built-in optimizers when module is imported
register_builtin_optimizers()