|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
"""Callback configuration options for tokenizer training.

BASIC_CALLBACKS: the basic set of callbacks; always recommended to use.
"""
|
|
|
|
|
from cosmos_predict1.tokenizer.training.callbacks import ( |
|
|
AdaptCkptStateDict, |
|
|
ExpandLossMask, |
|
|
GradClipCallback, |
|
|
TorchCompile, |
|
|
) |
|
|
from cosmos_predict1.utils.callback import EMAModelCallback, LowPrecisionCallback, ProgressBarCallback |
|
|
from cosmos_predict1.utils.lazy_config import PLACEHOLDER |
|
|
from cosmos_predict1.utils.lazy_config import LazyCall as L |
|
|
|
|
|
# Default callback set for tokenizer training. Each entry is a LazyCall spec;
# `config` and `trainer` are PLACEHOLDERs filled in by the trainer at build time.
BASIC_CALLBACKS = {
    # Mixed-precision parameter casting, applied every iteration.
    "low_precision": L(LowPrecisionCallback)(update_iter=1, config=PLACEHOLDER, trainer=PLACEHOLDER),
    # Global gradient-norm clipping at norm 1 (non-verbose).
    "grad_clip": L(GradClipCallback)(grad_clip_norm=1, verbose=False, config=PLACEHOLDER, trainer=PLACEHOLDER),
    # Maintains an exponential moving average of model weights.
    "ema": L(EMAModelCallback)(config=PLACEHOLDER, trainer=PLACEHOLDER),
    # Console progress reporting.
    "progress_bar": L(ProgressBarCallback)(config=PLACEHOLDER, trainer=PLACEHOLDER),
    # Dilates the loss mask with a 51-wide kernel.
    "expand_loss_mask": L(ExpandLossMask)(kernel_size=51, config=PLACEHOLDER, trainer=PLACEHOLDER),
    # Remaps checkpoint state-dict keys when loading.
    "adapt_ckpt_state_dict": L(AdaptCkptStateDict)(config=PLACEHOLDER, trainer=PLACEHOLDER),
    # torch.compile hook; compilation of network/loss is disabled by default
    # and deferred until after 8 warm-up iterations.
    "torch_compile": L(TorchCompile)(
        compile_after_iterations=8,
        compile_network=False,
        compile_loss=False,
        compile_loss_keys=["flow", "perceptual"],
    ),
}
|
|
|