# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import warnings
from dataclasses import dataclass
from typing import Optional
from omegaconf import MISSING
from verl.base_config import BaseConfig
# Names exported by `from <this module> import *`: the optimizer configuration
# dataclasses defined below plus the `build_optimizer` factory function.
__all__ = [
    "OptimizerConfig",
    "FSDPOptimizerConfig",
    "McoreOptimizerConfig",
    "build_optimizer",
    "VeOmniOptimizerConfig",
    "TorchtitanOptimizerConfig",
]
@dataclass
class OptimizerConfig(BaseConfig):
    """Optimizer settings shared by all backend-specific optimizer configs.

    Args:
        lr (float): Learning rate. Defaults to 1e-3; the omegaconf ``MISSING`` placeholder is rejected.
        lr_warmup_steps_ratio (float): Warmup steps ratio; total steps will be injected at runtime.
        total_training_steps (int): Total training steps; the -1 default must be overridden at runtime.
        weight_decay (float): Weight decay factor.
        lr_warmup_steps (Optional[int]): Number of warmup steps; None delegates to lr_warmup_steps_ratio.
            Defaults to -1.
        betas (tuple[float, float]): Beta coefficients handed to Adam-style optimizers.
        clip_grad (float): Gradient clipping norm.
        grad_clip (Optional[float]): Deprecated alias of ``clip_grad``; when set, it overwrites ``clip_grad``.
    """

    # Field names registered with BaseConfig as mutable (presumably the only ones that may be
    # reassigned after construction -- confirm against BaseConfig).
    _mutable_fields = {"clip_grad", "total_training_steps", "lr_warmup_steps"}

    lr: float = 1e-3
    lr_warmup_steps_ratio: float = 0.0
    total_training_steps: int = -1
    weight_decay: float = 0.01
    lr_warmup_steps: Optional[int] = -1
    betas: tuple[float, float] = (0.9, 0.999)
    clip_grad: float = 1.0
    # Deprecated spelling of `clip_grad`; kept so older configs keep working.
    grad_clip: Optional[float] = None

    def __post_init__(self):
        """Validate ``lr`` and migrate the deprecated ``grad_clip`` value into ``clip_grad``."""
        # A learning rate left at the omegaconf MISSING placeholder is not usable.
        assert self.lr != MISSING

        # Nothing to migrate when the deprecated field was not supplied.
        if self.grad_clip is None:
            return

        warnings.warn(
            "`grad_clip` is deprecated, use `clip_grad` instead.",
            DeprecationWarning,
            stacklevel=2,
        )
        self.clip_grad = self.grad_clip
@dataclass
class VeOmniOptimizerConfig(OptimizerConfig):
    """Optimizer settings for the VeOmni backend, layered on top of OptimizerConfig.

    Args:
        optimizer (str): Optimizer name; default is "adamw".
        lr (float): Learning rate (inherited from OptimizerConfig).
        lr_min (float): Minimum learning rate.
        lr_start (float): Starting learning rate for warmup.
        lr_decay_ratio (float): LR decay ratio.
        lr_scheduler_type (str): LR scheduler type: "constant" or "cosine".
        override_optimizer_config (Optional[dict]): Optional dict of additional optimizer settings
            (presumably consumed by the VeOmni engine -- confirm against the caller).
    """

    # Independent set with the same entries as the base class, so edits here never leak into it.
    _mutable_fields = set(OptimizerConfig._mutable_fields)

    optimizer: str = "adamw"
    lr_min: float = 0.0
    lr_start: float = 0.0
    lr_decay_ratio: float = 1.0
    lr_scheduler_type: str = "constant"
    override_optimizer_config: Optional[dict] = None
@dataclass
class FSDPOptimizerConfig(OptimizerConfig):
    """Optimizer settings for the FSDP backend, layered on top of OptimizerConfig.

    Args:
        optimizer (str): Optimizer class name (e.g., "AdamW", "AdamW8bit", "_AdamW").
        optimizer_impl (str): Module path to import optimizer from (e.g., "torch.optim", "torchao.optim",
            "bitsandbytes.optim").
        lr (float): Learning rate (inherited from OptimizerConfig).
        min_lr_ratio (Optional[float]): Minimum LR ratio for cosine schedule.
        warmup_style (Optional[str]): Deprecated alias of ``lr_scheduler_type``; when set it must be
            "constant" or "cosine" and it overwrites ``lr_scheduler_type``.
        lr_scheduler_type (str): LR scheduler type: "constant" or "cosine".
        num_cycles (float): Number of cosine cycles in LR schedule.
        override_optimizer_config (Optional[dict]): Extra keyword arguments merged over the optimizer
            arguments by ``build_optimizer``.
        zero_indexed_step (bool): Whether the LR schedule uses 0-indexed steps. If True (default),
            step counting starts at 0. If False, step counting starts at 1.
    """

    # Base-class entries plus `lr_scheduler_type`, which __post_init__ may reassign below.
    _mutable_fields = OptimizerConfig._mutable_fields | {"lr_scheduler_type"}

    optimizer: str = "AdamW"
    optimizer_impl: str = "torch.optim"
    min_lr_ratio: Optional[float] = None
    # Deprecated spelling of `lr_scheduler_type`; kept so older configs keep working.
    warmup_style: Optional[str] = None
    lr_scheduler_type: str = "constant"
    num_cycles: float = 0.5
    override_optimizer_config: Optional[dict] = None
    zero_indexed_step: bool = True

    def __post_init__(self):
        """Migrate the deprecated ``warmup_style``, validate the scheduler type, then run the base checks."""
        supported_types = ["constant", "cosine"]

        legacy_style = self.warmup_style
        if legacy_style is not None:
            assert legacy_style in supported_types
            warnings.warn(
                "`warmup_style` is deprecated, use `lr_scheduler_type` instead.",
                DeprecationWarning,
                stacklevel=2,
            )
            # The deprecated field wins over whatever lr_scheduler_type was given.
            self.lr_scheduler_type = legacy_style

        assert self.lr_scheduler_type in supported_types
        return super().__post_init__()
@dataclass
class McoreOptimizerConfig(OptimizerConfig):
    """Optimizer settings for the Mcore backend, layered on top of OptimizerConfig.

    Args:
        optimizer (str): Optimizer name; default is "adam".
        lr (float): Learning rate (inherited from OptimizerConfig).
        clip_grad (float): Gradient clipping norm (inherited from OptimizerConfig).
        lr_warmup_init (float): Initial learning rate for warmup; defaults to 0.0.
        lr_decay_steps (Optional[int]): Number of decay steps.
        lr_decay_style (str): LR decay style: "constant", "linear", "cosine", or "inverse_square_root".
        min_lr (float): Minimum learning rate.
        weight_decay_incr_style (str): Weight decay increment style: "constant" or "cosine".
        lr_wsd_decay_style (str): Decay style of the WSD schedule: "constant", "exponential", or "cosine"
            (WSD presumably stands for warmup-stable-decay as in Megatron-LM -- confirm).
        lr_wsd_decay_steps (Optional[int]): Number of steps for the WSD decay.
        use_checkpoint_opt_param_scheduler (bool): Whether to use checkpoint optimizer parameter scheduler.
        override_optimizer_config (Optional[dict]): Optional dict of additional optimizer settings
            (presumably consumed by the Mcore engine -- confirm against the caller).
    """

    optimizer: str = "adam"
    lr_warmup_init: float = 0.0
    lr_decay_steps: Optional[int] = None
    lr_decay_style: str = "linear"
    min_lr: float = 0.0
    weight_decay_incr_style: str = "constant"
    lr_wsd_decay_style: str = "exponential"
    lr_wsd_decay_steps: Optional[int] = None
    use_checkpoint_opt_param_scheduler: bool = False
    override_optimizer_config: Optional[dict] = None
@dataclass
class TorchtitanOptimizerConfig(OptimizerConfig):
    """Optimizer settings for the Torchtitan backend, layered on top of OptimizerConfig.

    Args:
        name (str): Optimizer name; default is "AdamW".
        eps (float): Epsilon value for AdamW optimizer, default 1e-8.
        decay_type (str): Decay type: "linear", "sqrt", or "cosine" (presumably the shape of the
            learning-rate decay rather than of weight decay -- confirm against the Torchtitan engine).
        min_lr_factor (float): Minimum learning rate factor.
    """

    name: str = "AdamW"
    eps: float = 1e-8
    decay_type: str = "linear"
    min_lr_factor: float = 0.0
def build_optimizer(parameters, config: "FSDPOptimizerConfig"):
    """Build an optimizer based on the configuration.

    Dynamically imports ``config.optimizer_impl`` and instantiates the class named
    ``config.optimizer`` found in that module.

    Keyword arguments for the optimizer are assembled in this order:

    1. ``lr`` and ``weight_decay`` taken from ``config``;
    2. ``betas`` from ``config``, only when the optimizer name contains "adam" or
       "ademamix" (case-insensitive);
    3. every entry of ``config.override_optimizer_config`` (if not None), which may add
       new keys or replace the ones set in steps 1-2.

    Args:
        parameters: Model parameters to optimize; forwarded unchanged as the first positional
            argument of the optimizer class.
        config: FSDPOptimizerConfig with optimizer settings.

    Returns:
        Optimizer instance.

    Raises:
        ImportError: If the module ``config.optimizer_impl`` cannot be imported.
        AttributeError: If the imported module has no attribute named ``config.optimizer``.

    Examples:
        # PyTorch AdamW
        config.optimizer_impl = "torch.optim"
        config.optimizer = "AdamW"

        # TorchAO AdamW with bf16 stochastic rounding
        config.optimizer_impl = "torchao.optim"
        config.optimizer = "_AdamW"
        config.override_optimizer_config = {"bf16_stochastic_round": True}

        # BitsAndBytes AdamW 8bit
        config.optimizer_impl = "bitsandbytes.optim"
        config.optimizer = "AdamW8bit"
    """
    import importlib

    optimizer_args = {
        "lr": config.lr,
        "weight_decay": config.weight_decay,
    }

    # Only Adam-family optimizers are given `betas`; other optimizers may not accept it.
    optimizer_name_lower = config.optimizer.lower()
    if "adam" in optimizer_name_lower or "ademamix" in optimizer_name_lower:
        optimizer_args["betas"] = config.betas

    # User overrides are applied last so they win over the defaults assembled above.
    if config.override_optimizer_config is not None:
        optimizer_args.update(config.override_optimizer_config)

    # The import and the attribute lookup are guarded separately. With a single try block,
    # an AttributeError raised *while importing* the module would reach the "optimizer not
    # found" handler with `module` still unbound and be masked by an UnboundLocalError.
    try:
        module = importlib.import_module(config.optimizer_impl)
    except ImportError as e:
        raise ImportError(
            f"Failed to import module '{config.optimizer_impl}'. Make sure the package is installed. Error: {e}"
        ) from e

    try:
        optimizer_cls = getattr(module, config.optimizer)
    except AttributeError as e:
        raise AttributeError(
            f"Optimizer '{config.optimizer}' not found in module '{config.optimizer_impl}'. "
            f"Available optimizers: {dir(module)}"
        ) from e

    return optimizer_cls(parameters, **optimizer_args)
|