|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
from hydra.core.config_store import ConfigStore |
|
|
from megatron.core import parallel_state |
|
|
from torch.utils.data import DataLoader, DistributedSampler |
|
|
|
|
|
from cosmos_predict1.diffusion.training.callbacks.iter_speed import IterSpeed |
|
|
from cosmos_predict1.diffusion.training.callbacks.low_precision import LowPrecisionCallback |
|
|
from cosmos_predict1.diffusion.training.datasets.dataset_multiview import Dataset |
|
|
from cosmos_predict1.diffusion.training.models.extend_model_multiview import FSDPExtendDiffusionModel |
|
|
from cosmos_predict1.diffusion.training.networks.general_dit_lvg_multiview import VideoExtendMultiviewGeneralDIT |
|
|
from cosmos_predict1.utils import log |
|
|
from cosmos_predict1.utils.callbacks.grad_clip import GradClip |
|
|
from cosmos_predict1.utils.lazy_config import PLACEHOLDER |
|
|
from cosmos_predict1.utils.lazy_config import LazyCall as L |
|
|
from cosmos_predict1.utils.lazy_config import LazyDict |
|
|
|
|
|
|
|
|
def get_sampler(dataset): |
|
|
return DistributedSampler( |
|
|
dataset, |
|
|
num_replicas=parallel_state.get_data_parallel_world_size(), |
|
|
rank=parallel_state.get_data_parallel_rank(), |
|
|
shuffle=True, |
|
|
seed=0, |
|
|
) |
|
|
|
|
|
|
|
|
cs = ConfigStore.instance() |
|
|
|
|
|
num_frames = 57 |
|
|
num_views = 5 |
|
|
view_keys = ["pinhole_front_left", "pinhole_front", "pinhole_front_right", "pinhole_side_left", "pinhole_side_right"] |
|
|
example_multiview_dataset_waymo = L(Dataset)( |
|
|
dataset_dir="datasets/waymo", |
|
|
sequence_interval=1, |
|
|
num_frames=num_frames, |
|
|
view_keys=view_keys, |
|
|
video_size=(480, 848), |
|
|
) |
|
|
|
|
|
|
|
|
video2world_multiview_7b_example_waymo = LazyDict( |
|
|
dict( |
|
|
defaults=[ |
|
|
{"override /net": "faditv2_7b"}, |
|
|
{"override /conditioner": "video_cond"}, |
|
|
{"override /ckpt_klass": "fsdp"}, |
|
|
{"override /checkpoint": "local"}, |
|
|
{"override /vae": "cosmos_diffusion_tokenizer_comp8x8x8"}, |
|
|
"_self_", |
|
|
], |
|
|
job=dict( |
|
|
project="posttraining", |
|
|
group="diffusion_video2world", |
|
|
name="video2world_multiview_7b_example_waymo", |
|
|
), |
|
|
optimizer=dict( |
|
|
lr=2 ** (-14.3), |
|
|
weight_decay=0.1, |
|
|
betas=[0.9, 0.99], |
|
|
eps=1e-10, |
|
|
), |
|
|
checkpoint=dict( |
|
|
save_iter=200, |
|
|
|
|
|
broadcast_via_filesystem=False, |
|
|
load_path="checkpoints/Cosmos-Predict1-7B-Video2World-Sample-AV-Multiview/model.pt", |
|
|
load_training_state=False, |
|
|
strict_resume=False, |
|
|
keys_not_to_resume=[], |
|
|
), |
|
|
trainer=dict( |
|
|
max_iter=2000, |
|
|
distributed_parallelism="fsdp", |
|
|
logging_iter=200, |
|
|
callbacks=dict( |
|
|
grad_clip=L(GradClip)( |
|
|
model_key="model", |
|
|
fsdp_enabled=True, |
|
|
), |
|
|
low_prec=L(LowPrecisionCallback)(config=PLACEHOLDER, trainer=PLACEHOLDER, update_iter=1), |
|
|
iter_speed=L(IterSpeed)( |
|
|
every_n=200, |
|
|
hit_thres=5, |
|
|
), |
|
|
), |
|
|
), |
|
|
model_parallel=dict( |
|
|
sequence_parallel=False, |
|
|
tensor_model_parallel_size=1, |
|
|
context_parallel_size=1, |
|
|
), |
|
|
model=dict( |
|
|
n_views=num_views, |
|
|
|
|
|
latent_shape=[ |
|
|
16, |
|
|
16, |
|
|
88, |
|
|
160, |
|
|
], |
|
|
loss_reduce="mean", |
|
|
ema=dict( |
|
|
enabled=True, |
|
|
), |
|
|
fsdp_enabled=True, |
|
|
fsdp=dict( |
|
|
policy="block", |
|
|
checkpoint=True, |
|
|
min_num_params=1024, |
|
|
sharding_group_size=32, |
|
|
sharding_strategy="hybrid", |
|
|
), |
|
|
net=L(VideoExtendMultiviewGeneralDIT)( |
|
|
rope_h_extrapolation_ratio=1, |
|
|
rope_w_extrapolation_ratio=1, |
|
|
rope_t_extrapolation_ratio=2, |
|
|
n_views=num_views, |
|
|
), |
|
|
conditioner=dict( |
|
|
video_cond_bool=dict( |
|
|
condition_location="first_random_n", |
|
|
cfg_unconditional_type="zero_condition_region_condition_mask", |
|
|
apply_corruption_to_condition_region="noise_with_sigma", |
|
|
condition_on_augment_sigma=False, |
|
|
dropout_rate=0.0, |
|
|
first_random_n_num_condition_t_max=2, |
|
|
normalize_condition_latent=False, |
|
|
|
|
|
augment_sigma_sample_p_mean=-3.0, |
|
|
augment_sigma_sample_p_std=2.0, |
|
|
augment_sigma_sample_multiplier=1.0, |
|
|
) |
|
|
), |
|
|
vae=dict(pixel_chunk_duration=num_frames), |
|
|
), |
|
|
model_obj=L(FSDPExtendDiffusionModel)( |
|
|
config=PLACEHOLDER, |
|
|
fsdp_checkpointer=PLACEHOLDER, |
|
|
), |
|
|
|
|
|
scheduler=dict( |
|
|
warm_up_steps=[2500], |
|
|
cycle_lengths=[10000000000000], |
|
|
f_start=[1.0e-6], |
|
|
f_max=[1.0], |
|
|
f_min=[1.0], |
|
|
), |
|
|
dataloader_train=L(DataLoader)( |
|
|
dataset=example_multiview_dataset_waymo, |
|
|
sampler=L(get_sampler)(dataset=example_multiview_dataset_waymo), |
|
|
batch_size=1, |
|
|
drop_last=True, |
|
|
pin_memory=True, |
|
|
num_workers=8, |
|
|
), |
|
|
dataloader_val=L(DataLoader)( |
|
|
dataset=example_multiview_dataset_waymo, |
|
|
sampler=L(get_sampler)(dataset=example_multiview_dataset_waymo), |
|
|
batch_size=1, |
|
|
drop_last=True, |
|
|
pin_memory=True, |
|
|
num_workers=8, |
|
|
), |
|
|
) |
|
|
) |
|
|
|
|
|
|
|
|
def register_experiments(cs): |
|
|
|
|
|
for _item in [ |
|
|
video2world_multiview_7b_example_waymo, |
|
|
]: |
|
|
experiment_name = _item["job"]["name"] |
|
|
log.info(f"Registering experiment: {experiment_name}") |
|
|
cs.store( |
|
|
group="experiment", |
|
|
package="_global_", |
|
|
name=experiment_name, |
|
|
node=_item, |
|
|
) |
|
|
|