StarMist0012's picture
Add files using upload-large-folder tool
3270dae verified
"""Optimizer registry and factory for instantiating optimizers."""
from enum import Enum
from typing import Dict, Type, Callable, Any

import torch.optim as optim

from taoTrain.core.base import BaseModel
from taoTrain.config import TrainingConfig, OptimizerEnum
# Global registry for optimizers
_OPTIMIZER_REGISTRY: Dict[str, Callable] = {}
def register_optimizer(name: str):
    """
    Build a decorator that registers an optimizer factory under ``name``.

    Usage::

        @register_optimizer("adamw")
        def build_adamw(model, config):
            ...

    The decorated function is stored in the module registry and handed
    back unchanged, so it stays directly callable.

    Args:
        name: Registry key for the optimizer (e.g., 'adamw', 'adam', 'sgd').

    Raises:
        ValueError: When the returned decorator is applied and ``name`` is
            already present in the registry.
    """
    def _register(factory: Callable) -> Callable:
        already_taken = name in _OPTIMIZER_REGISTRY
        if already_taken:
            # Refuse silent overwrites so two modules can't clash on a name.
            raise ValueError(f"Optimizer '{name}' is already registered")
        _OPTIMIZER_REGISTRY[name] = factory
        return factory

    return _register
def get_registered_optimizers() -> Dict[str, Callable]:
    """
    Return a snapshot of every registered optimizer factory.

    The result is a new, shallow-copied dict. Adding or removing keys on it
    does not modify the registry itself.
    """
    snapshot = dict(_OPTIMIZER_REGISTRY)
    return snapshot
def get_optimizer(
    model: BaseModel,
    config: TrainingConfig,
) -> optim.Optimizer:
    """
    Create an optimizer instance from config.

    Resolves the registry name from ``config.optimizer.optimizer_type`` and
    calls the matching factory as ``factory(model, config)``.

    Args:
        model: Model to optimize; passed through to the factory unchanged.
        config: TrainingConfig with optimizer configuration. Its
            ``optimizer.optimizer_type`` may be a plain string or an
            ``Enum`` member whose ``.value`` is the registered name.

    Returns:
        Optimizer instance produced by the registered factory.

    Raises:
        ValueError: If optimizer type is not registered
    """
    optimizer_type = config.optimizer.optimizer_type
    # Unwrap enum members before anything else. If OptimizerEnum mixes in
    # ``str``, checking ``isinstance(..., str)`` first would keep the member
    # itself. Lookup would still succeed, but on recent Python versions the
    # member formats as ``OptimizerEnum.X`` in the error message instead of
    # its registry name.
    if isinstance(optimizer_type, Enum):
        optimizer_name = optimizer_type.value
    else:
        optimizer_name = optimizer_type

    # Single lookup. Registered values are factory callables, never None.
    factory_fn = _OPTIMIZER_REGISTRY.get(optimizer_name)
    if factory_fn is None:
        raise ValueError(
            f"Unknown optimizer: {optimizer_name}. "
            f"Available: {list(_OPTIMIZER_REGISTRY.keys())}"
        )
    return factory_fn(model, config)
def register_builtin_optimizers():
    """
    Register all built-in optimizers by importing their modules.

    Importing each sibling module runs its registration decorators. The
    import is deferred to call time, rather than placed at the top of the
    file, to avoid circular imports. Python caches imported modules, so
    calling this again does not re-run those decorators.
    """
    from . import adamw, adam, sgd, hybrid_muon_adamw  # noqa: F401
# Auto-register built-in optimizers when this module is imported, so the
# registry is populated before any caller reaches get_optimizer().
# This runs after register_optimizer is defined above, which the imported
# submodules rely on.
register_builtin_optimizers()