|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
from cosmos_predict1.autoregressive.callbacks.video_sampling_teacher_forcing import VideoSamplingTeacherForcing |
|
|
from cosmos_predict1.callbacks.grad_clip import GradClip |
|
|
from cosmos_predict1.utils.callback import ProgressBarCallback |
|
|
from cosmos_predict1.utils.lazy_config import LazyCall as L |
|
|
|
|
|
# Baseline callback set: a progress-bar callback and a GradClip callback.
# Each value is built with LazyCall (`L`), which records the target class and its
# arguments. Presumably instantiation is deferred until the config is resolved;
# NOTE(review): confirm against cosmos_predict1.utils.lazy_config.
BASIC_CALLBACKS = {
    "progress_bar": L(ProgressBarCallback)(),
    # The "${...}" string looks like a config interpolation pointing at the model's
    # FSDP flag. It is presumably resolved against the full experiment config — TODO confirm.
    "grad_clip": L(GradClip)(
        clip_norm=1.0,
        fsdp_enabled="${model.model_config.fsdp_enabled}",
        model_key="model",
    ),
}
|
|
|
|
|
# Single-entry callback set, registered under the key "vid_sampling_tf".
# The value is a LazyCall (`L`) description of VideoSamplingTeacherForcing with the
# arguments below:
#   every_n=500                 -- presumably the step interval between sampling runs (TODO confirm)
#   video_latent_shape          -- "${...}" interpolation, presumably resolved from the model config
#   num_frames_to_display=4     -- frame count passed to the callback
#   save_folder                 -- folder name passed to the callback for its outputs
VIDEO_TEACHER_FORCING_CALLBACK = {
    "vid_sampling_tf": L(VideoSamplingTeacherForcing)(
        every_n=500,
        video_latent_shape="${model.model_config.video_latent_shape}",
        num_frames_to_display=4,
        save_folder="video_sampling_teacher_forcing",
    ),
}
|
|
|