"""SGD optimizer factory.

Registers an ``"sgd"`` builder that splits a model's trainable parameters
into a weight-decayed group and a non-decayed group before handing them to
``torch.optim.SGD``.
"""

import torch.optim as optim

from taoTrain.core.base import BaseModel
from taoTrain.config import TrainingConfig
from .registry import register_optimizer


def _separate_parameters(model: BaseModel) -> tuple[list, list]:
    """
    Split trainable model parameters into decay / no-decay buckets.

    A parameter lands in the no-decay bucket when its qualified name contains
    ``"bias"`` or ``"norm"``; every other trainable parameter is decayed.
    Parameters with ``requires_grad=False`` are left out of both buckets.
    Within each bucket, parameters keep the order in which
    ``model.named_parameters()`` yields them.

    Args:
        model: Model whose parameters are partitioned.

    Returns:
        Tuple of (decay_params, no_decay_params).
    """
    # NOTE(review): this is a plain substring test on the parameter name, so a
    # norm layer registered under a name without "norm" (e.g. "ln_1", "bn1")
    # is still decayed -- confirm against the model's naming scheme.
    no_decay_markers = ("bias", "norm")

    def _exempt(qualified_name: str) -> bool:
        return any(marker in qualified_name for marker in no_decay_markers)

    # Frozen parameters are never handed to the optimizer.
    trainable = [
        (qualified_name, tensor)
        for qualified_name, tensor in model.named_parameters()
        if tensor.requires_grad
    ]

    decay_params = [tensor for qualified_name, tensor in trainable if not _exempt(qualified_name)]
    no_decay_params = [tensor for qualified_name, tensor in trainable if _exempt(qualified_name)]
    return decay_params, no_decay_params


@register_optimizer("sgd")
def create_sgd(model: BaseModel, config: TrainingConfig) -> optim.SGD:
    """
    Build an SGD optimizer whose weight decay skips bias/norm parameters.

    Args:
        model: Model whose trainable parameters will be optimized.
        config: Training configuration; ``config.optimizer`` supplies
            ``learning_rate``, ``weight_decay`` and ``betas``.

    Returns:
        A ``torch.optim.SGD`` instance with two parameter groups, in this
        order: the decayed group (config ``weight_decay``) and the exempt
        group (``weight_decay`` 0.0).
    """
    opt_cfg = config.optimizer
    decayed, undecayed = _separate_parameters(model)

    # Group order matters for optimizer state_dict compatibility; keep it fixed.
    groups = [
        {"params": decayed, "weight_decay": opt_cfg.weight_decay},
        {"params": undecayed, "weight_decay": 0.0},
    ]

    # SGD has no beta pair; the config's first beta is reused as the momentum
    # coefficient. dampening / nesterov are left at PyTorch's defaults.
    return optim.SGD(
        groups,
        lr=opt_cfg.learning_rate,
        momentum=opt_cfg.betas[0],
    )