StarMist0012's picture
Add files using upload-large-folder tool
3270dae verified
"""SGD optimizer factory."""
import torch.optim as optim
from taoTrain.core.base import BaseModel
from taoTrain.config import TrainingConfig
from .registry import register_optimizer
def _separate_parameters(model: BaseModel) -> tuple[list, list]:
    """
    Separate trainable model parameters into decay and no-decay groups.

    A parameter is placed in the no-decay group when its qualified name
    contains ``bias`` or ``norm``. The match is case-insensitive so that
    names such as ``LayerNorm.weight`` are recognised in the same way as
    ``layer_norm.weight``. Parameters whose ``requires_grad`` is false are
    skipped and appear in neither group.

    Args:
        model: Model instance exposing ``named_parameters()``.

    Returns:
        Tuple of (decay_params, no_decay_params). Each list keeps the order
        in which ``model.named_parameters()`` yielded its entries.
    """
    no_decay_markers = ("bias", "norm")
    decay_params = []
    no_decay_params = []

    for name, param in model.named_parameters():
        if not param.requires_grad:
            continue

        # Apply weight decay to all params except biases and layer norms.
        # Compare against the lower-cased name: a case-sensitive check would
        # miss modules registered under names like "LayerNorm" and silently
        # apply weight decay to their parameters.
        lowered = name.lower()
        if any(marker in lowered for marker in no_decay_markers):
            no_decay_params.append(param)
        else:
            decay_params.append(param)

    return decay_params, no_decay_params
@register_optimizer("sgd")
def create_sgd(model: BaseModel, config: TrainingConfig) -> optim.SGD:
    """
    Build an SGD optimizer that applies weight decay to only part of the model.

    The model's trainable parameters are split by ``_separate_parameters``
    into a decay group and a no-decay group. The first group uses the
    configured weight decay; the second uses a weight decay of 0.0.

    Args:
        model: Model instance whose parameters are optimized.
        config: TrainingConfig; its ``optimizer`` attribute supplies
            ``weight_decay``, ``learning_rate`` and ``betas``.

    Returns:
        SGD optimizer instance with two parameter groups, decayed first and
        non-decayed second.
    """
    opt_cfg = config.optimizer

    with_decay, without_decay = _separate_parameters(model)
    decayed_group = {"params": with_decay, "weight_decay": opt_cfg.weight_decay}
    undecayed_group = {"params": without_decay, "weight_decay": 0.0}

    # The first entry of ``betas`` is reused as the SGD momentum.
    return optim.SGD(
        [decayed_group, undecayed_group],
        lr=opt_cfg.learning_rate,
        momentum=opt_cfg.betas[0],
    )