|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
from typing import Any, List |
|
|
|
|
|
import attrs |
|
|
|
|
|
from cosmos_predict1.diffusion.training.config.base.model import MultiviewModelConfig |
|
|
from cosmos_predict1.diffusion.training.config.text2world.registry import ( |
|
|
register_configs as register_configs_text2world, |
|
|
) |
|
|
from cosmos_predict1.diffusion.training.config.text2world_multiview.registry import ( |
|
|
register_configs as register_configs_text2world_multiview, |
|
|
) |
|
|
from cosmos_predict1.diffusion.training.config.video2world.registry import ( |
|
|
register_configs as register_configs_video2world, |
|
|
) |
|
|
from cosmos_predict1.diffusion.training.config.video2world_multiview.registry import ( |
|
|
register_configs as register_configs_video2world_multiview, |
|
|
) |
|
|
from cosmos_predict1.diffusion.training.models.model import DiffusionModel |
|
|
from cosmos_predict1.utils import config |
|
|
from cosmos_predict1.utils.config_helper import import_all_modules_from_package |
|
|
from cosmos_predict1.utils.lazy_config import PLACEHOLDER |
|
|
from cosmos_predict1.utils.lazy_config import LazyCall as L |
|
|
from cosmos_predict1.utils.lazy_config import LazyDict |
|
|
from cosmos_predict1.utils.trainer import Trainer |
|
|
|
|
|
|
|
|
def _default_config_groups() -> List[Any]:
    """Build a fresh defaults list for :class:`Config`.

    The first entry is the ``"_self_"`` marker. It is followed by one
    single-key ``{group: choice}`` dict per config group, in the order below.
    A ``None`` choice leaves that group unselected here, presumably so an
    experiment config can fill it in later (confirm against the registries).

    Returns:
        A newly allocated list, so every ``Config`` instance gets its own copy.
    """
    # (config group, default option) pairs; order is significant and preserved.
    group_choices = (
        ("data_train", None),
        ("data_val", None),
        ("optimizer", "fusedadamw"),
        ("scheduler", "lambdalinear"),
        ("callbacks", None),
        ("net", None),
        ("conditioner", "add_fps_image_size_padding_mask"),
        ("fsdp", None),
        ("ema", "power"),
        ("vae", "vae1"),
        ("checkpoint", "pbss"),
        ("ckpt_klass", "fsdp"),
        ("experiment", None),
    )
    return ["_self_"] + [{group: choice} for group, choice in group_choices]


@attrs.define(slots=False)
class Config(config.Config):
    """Top-level training config for the multiview diffusion pipeline.

    Attributes are documented with comments because attrs collects fields
    from annotations.
    """

    # Hydra-style defaults list (note the leading "_self_" entry); see
    # ``_default_config_groups`` for the group/option pairs it contains.
    defaults: List[Any] = attrs.field(factory=_default_config_groups)

    # Lazily constructed DiffusionModel; its ``config`` argument is left as
    # PLACEHOLDER, to be supplied later.
    model_obj: LazyDict = L(DiffusionModel)(config=PLACEHOLDER)
|
|
|
|
|
|
|
|
def make_config():
    """Create the base multiview training ``Config`` and register task configs.

    The returned config uses ``MultiviewModelConfig`` as its model. Optimizer,
    scheduler and both dataloaders are unset (``None``). The function also
    sets job metadata and a few trainer defaults, then clears ``callbacks``
    on the trainer and the top-level ``checkpoint``.

    As a side effect, it calls the register function for each of the four
    task families (text2world, video2world and their multiview variants). It
    then imports (with ``reload=True``) every module in each family's config
    package, presumably so that experiment modules register themselves on
    import.

    Returns:
        The populated ``Config`` instance.
    """
    cfg = Config(
        model=MultiviewModelConfig(),
        optimizer=None,
        scheduler=None,
        dataloader_train=None,
        dataloader_val=None,
    )

    # Job metadata. The name uses `${now:...}` interpolations, presumably
    # resolved to the launch timestamp at compose time.
    job = cfg.job
    job.project = "cosmos_predict1"
    job.group = "debug"
    job.name = "delete_${now:%Y-%m-%d}_${now:%H-%M-%S}"

    # Trainer class and iteration schedule.
    trainer = cfg.trainer
    trainer.type = Trainer
    trainer.max_iter = 400_000
    trainer.logging_iter = 10
    trainer.validation_iter = 100
    trainer.run_validation = False
    trainer.callbacks = None

    cfg.checkpoint = None

    # Register config groups for every task family, in a fixed order.
    registrars = (
        register_configs_text2world,
        register_configs_video2world,
        register_configs_text2world_multiview,
        register_configs_video2world_multiview,
    )
    for register in registrars:
        register()

    # Import (and reload) each family's config package, in the same order.
    config_packages = (
        "cosmos_predict1.diffusion.training.config.text2world",
        "cosmos_predict1.diffusion.training.config.video2world",
        "cosmos_predict1.diffusion.training.config.text2world_multiview",
        "cosmos_predict1.diffusion.training.config.video2world_multiview",
    )
    for package in config_packages:
        import_all_modules_from_package(package, reload=True)

    return cfg
|
|
|