|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
"""optimizer config options: |
|
|
|
|
|
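    fused_adam - FusedAdamConfig
    adamw - AdamWConfig

This module also defines FusedAdamDiscConfig and the learning-rate scheduler
configs WarmupLRConfig, WarmupLRDiscConfig and WarmupCosineLRConfig.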
|
|
""" |
|
|
|
|
|
import torch |
|
|
|
|
|
from cosmos_predict1.utils import fused_adam |
|
|
from cosmos_predict1.utils.lazy_config import PLACEHOLDER |
|
|
from cosmos_predict1.utils.lazy_config import LazyCall as L |
|
|
from cosmos_predict1.utils.lazy_config import LazyDict |
|
|
from cosmos_predict1.utils.scheduler import WarmupCosineLR, WarmupLambdaLR |
|
|
|
|
|
# Lazy config for `fused_adam.FusedAdam`, declared through `LazyCall` (aliased as `L`).
# `params` is left as PLACEHOLDER -- presumably replaced with the real parameter
# list before the optimizer is instantiated (TODO confirm against the caller).
FusedAdamConfig: LazyDict = L(fused_adam.FusedAdam)(
    capturable=True,
    master_weights=True,
    adam_w_mode=True,
    params=PLACEHOLDER,
    lr=1e-4,
    betas=(0.5, 0.999),
    eps=1e-8,
    weight_decay=0.01,
)
|
|
|
|
|
# Lazy config for the stock `torch.optim.AdamW` optimizer.
# Uses the same lr / betas / eps / weight_decay values as FusedAdamConfig.
# `params` is a PLACEHOLDER -- presumably supplied before instantiation
# (TODO confirm against the caller).
AdamWConfig: LazyDict = L(torch.optim.AdamW)(
    params=PLACEHOLDER,
    lr=1e-4,
    betas=(0.5, 0.999),
    eps=1e-8,
    weight_decay=0.01,
)
|
|
|
|
|
# Lazy config for the `WarmupLambdaLR` scheduler with `warmup=5000`.
# `optimizer` is a PLACEHOLDER -- presumably bound to the built optimizer
# before the scheduler is instantiated (TODO confirm against the caller).
WarmupLRConfig: LazyDict = L(WarmupLambdaLR)(
    optimizer=PLACEHOLDER,
    warmup=5000,
)
|
|
|
|
|
# Second lazy `fused_adam.FusedAdam` config. Every argument matches
# FusedAdamConfig except the learning rate, which is 4e-4 here instead of 1e-4.
# NOTE(review): the "Disc" suffix suggests a discriminator optimizer -- confirm.
FusedAdamDiscConfig: LazyDict = L(fused_adam.FusedAdam)(
    capturable=True,
    master_weights=True,
    adam_w_mode=True,
    params=PLACEHOLDER,
    lr=4e-4,
    betas=(0.5, 0.999),
    eps=1e-8,
    weight_decay=0.01,
)
|
|
|
|
|
# Lazy `WarmupLambdaLR` scheduler config with the same arguments as
# WarmupLRConfig (`warmup=5000`), kept as a separate config object.
# `optimizer` is a PLACEHOLDER -- presumably bound before instantiation
# (TODO confirm against the caller).
WarmupLRDiscConfig: LazyDict = L(WarmupLambdaLR)(
    optimizer=PLACEHOLDER,
    warmup=5000,
)
|
|
|
|
|
# Lazy config for the `WarmupCosineLR` scheduler.
# `optimizer` is a PLACEHOLDER -- presumably bound to the built optimizer
# before instantiation (TODO confirm against the caller).
WarmupCosineLRConfig: LazyDict = L(WarmupCosineLR)(
    optimizer=PLACEHOLDER,
    warmup_iters=5000,
    lr_decay_iters=1000000,
    min_lr=1e-8,
)
|
|
|