|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import os |
|
|
|
|
|
from hydra.core.config_store import ConfigStore |
|
|
from megatron.core import parallel_state |
|
|
from torch.utils.data import DataLoader, DistributedSampler |
|
|
|
|
|
from cosmos_predict1.diffusion.training.callbacks.iter_speed import IterSpeed |
|
|
from cosmos_predict1.diffusion.training.callbacks.low_precision import LowPrecisionCallback |
|
|
from cosmos_predict1.diffusion.training.datasets.dataset_3D import Dataset_3D |
|
|
from cosmos_predict1.diffusion.training.models.extend_model import FSDPExtendDiffusionModel |
|
|
from cosmos_predict1.diffusion.training.networks.general_dit_action import ActionConditionalVideoExtendGeneralDIT |
|
|
from cosmos_predict1.utils import log |
|
|
from cosmos_predict1.utils.callback import ProgressBarCallback |
|
|
from cosmos_predict1.utils.callbacks.grad_clip import GradClip |
|
|
from cosmos_predict1.utils.lazy_config import PLACEHOLDER |
|
|
from cosmos_predict1.utils.lazy_config import LazyCall as L |
|
|
from cosmos_predict1.utils.lazy_config import LazyDict |
|
|
|
|
|
# Hydra config store used by register_experiments() below.
cs = ConfigStore.instance()

# Root of the Bridge dataset; annotation splits live under "annotation/".
base_path = "datasets/bridge/"


def _annotation_path(split):
    """Return the annotation directory for a dataset *split* ("train"/"val"/"test")."""
    return os.path.join(base_path, "annotation/" + split)


train_annotation_path = _annotation_path("train")
val_annotation_path = _annotation_path("val")
test_annotation_path = _annotation_path("test")
|
|
|
|
|
|
|
|
def get_sampler(dataset):
    """Build a DistributedSampler sharded over Megatron's data-parallel group.

    Uses the data-parallel world size/rank (not the global ranks) so that
    model-parallel replicas of the same data-parallel rank see identical
    batches. Shuffles with a fixed seed for reproducibility.
    """
    num_replicas = parallel_state.get_data_parallel_world_size()
    dp_rank = parallel_state.get_data_parallel_rank()
    return DistributedSampler(dataset, num_replicas=num_replicas, rank=dp_rank, shuffle=True, seed=0)
|
|
|
|
|
|
|
|
def _bridge_dataset(mode):
    """Lazily construct a Bridge ``Dataset_3D`` config for *mode* ("train" or "val").

    The train and val configs are identical except for ``mode``; building both
    from one helper keeps the two copies from drifting apart.
    """
    return L(Dataset_3D)(
        train_annotation_path=train_annotation_path,
        val_annotation_path=val_annotation_path,
        test_annotation_path=test_annotation_path,
        video_path=base_path,
        sequence_interval=1,
        num_frames=2,  # 2-frame clips: 1 conditioning frame + 1 predicted frame
        cam_ids=[0],
        accumulate_action=False,
        video_size=[256, 320],  # [height, width]
        val_start_frame_interval=1,
        mode=mode,
        load_action=True,  # action-conditional training needs per-step actions
        load_t5_embeddings=False,  # no text conditioning in this recipe
    )


bridge_train_dataset = _bridge_dataset("train")

bridge_val_dataset = _bridge_dataset("val")
|
|
|
|
|
|
|
|
def _bridge_dataloader(dataset, batch_size):
    """Lazily construct a DataLoader config over *dataset* with the given *batch_size*.

    Train and val loaders differ only in their dataset and batch size, so one
    builder covers both; the sampler shards over the data-parallel group.
    """
    return L(DataLoader)(
        dataset=dataset,
        sampler=L(get_sampler)(dataset=dataset),
        batch_size=batch_size,
        drop_last=True,
        pin_memory=True,
        num_workers=8,
    )


dataloader_train = _bridge_dataloader(bridge_train_dataset, 8)

dataloader_val = _bridge_dataloader(bridge_val_dataset, 1)
|
|
|
|
|
|
|
|
# Post-training experiment: action-conditional video2world on the Bridge
# dataset with the 7B DiT, 2-frame clips (first frame conditions the next),
# trained under FSDP.
video2world_action_bridge_2frames = LazyDict(
    {
        # Hydra defaults list: swap in the action-conditional net/conditioner
        # and the FSDP checkpointer before this config's own overrides apply.
        "defaults": [
            {"override /net": "faditv2_7b"},
            {"override /conditioner": "action_conditional_video_cond"},
            {"override /ckpt_klass": "fsdp"},
            {"override /checkpoint": "local"},
            {"override /vae": "cosmos_diffusion_tokenizer_comp8x8x8"},
            "_self_",
        ],
        "job": {
            "project": "posttraining",
            "group": "diffusion_video2world_action",
            "name": "video2world_action_bridge_2frames",
        },
        "optimizer": {
            "lr": 4e-4,
            "weight_decay": 0.1,
            "betas": [0.9, 0.99],
            "eps": 1e-10,
        },
        "checkpoint": {
            "save_iter": 500,
            "broadcast_via_filesystem": False,
            # Fine-tune from the pretrained 7B video2world weights; do not
            # resume optimizer/trainer state and allow missing/new keys.
            "load_path": "checkpoints/Cosmos-Predict1-7B-Video2World/model.pt",
            "load_training_state": False,
            "strict_resume": False,
            "keys_not_to_resume": [],
        },
        "trainer": {
            "max_iter": 2_000,
            "distributed_parallelism": "fsdp",
            "logging_iter": 200,
            "callbacks": {
                "grad_clip": L(GradClip)(model_key="model", fsdp_enabled=True),
                "low_prec": L(LowPrecisionCallback)(config=PLACEHOLDER, trainer=PLACEHOLDER, update_iter=1),
                "iter_speed": L(IterSpeed)(every_n=10, hit_thres=0),
                "progress_bar": L(ProgressBarCallback)(),
            },
        },
        # FSDP only: no tensor/sequence/context parallelism.
        "model_parallel": {
            "sequence_parallel": False,
            "tensor_model_parallel_size": 1,
            "context_parallel_size": 1,
        },
        "model": {
            # [16, 2, 32, 40] — presumably [C, T, H, W] in latent space; 32x40
            # matches 256x320 video at 8x spatial compression. TODO confirm.
            "latent_shape": [16, 2, 32, 40],
            "loss_reduce": "mean",
            "ema": {"enabled": True},
            "fsdp_enabled": True,
            "fsdp": {
                "policy": "block",
                "checkpoint": False,
                "min_num_params": 1024,
                "sharding_group_size": 32,
                "sharding_strategy": "hybrid",
            },
            "net": L(ActionConditionalVideoExtendGeneralDIT)(
                rope_h_extrapolation_ratio=1,
                rope_w_extrapolation_ratio=1,
                rope_t_extrapolation_ratio=2,
            ),
            "conditioner": {
                "video_cond_bool": {
                    "condition_location": "first_random_n",
                    "cfg_unconditional_type": "zero_condition_region_condition_mask",
                    "first_random_n_num_condition_t_max": 1,
                    "apply_corruption_to_condition_region": "noise_with_sigma",
                    "condition_on_augment_sigma": False,
                }
            },
            "vae": {"pixel_chunk_duration": 1},
        },
        "model_obj": L(FSDPExtendDiffusionModel)(
            config=PLACEHOLDER,
            fsdp_checkpointer=PLACEHOLDER,
        ),
        # LambdaLinearScheduler-style knobs: linear warm-up, then effectively
        # constant LR (cycle length far exceeds max_iter).
        "scheduler": {
            "warm_up_steps": [2500],
            "cycle_lengths": [10000000000000],
            "f_start": [1.0e-6],
            "f_max": [1.0],
            "f_min": [1.0],
        },
        "dataloader_train": dataloader_train,
        "dataloader_val": dataloader_val,
    }
)
|
|
|
|
|
|
|
|
def register_experiments(cs):
    """Register every experiment config in this module with the Hydra store *cs*.

    Each experiment is stored under the "experiment" group at the global
    package, keyed by its job name, so it can be selected on the command line.
    """
    experiments = (video2world_action_bridge_2frames,)
    for experiment in experiments:
        experiment_name = experiment["job"]["name"]
        log.info(f"Registering experiment: {experiment_name}")
        cs.store(group="experiment", package="_global_", name=experiment_name, node=experiment)
|
|
|