add cosmos-1-diffusion-text2world
Browse files- .gitignore +3 -1
- config_helper.py +2 -0
- cosmos1/models/diffusion/config/inference/cosmos-1-diffusion-text2world.py → cosmos1diffusiontext2world.py +0 -0
- cosmos1/models/diffusion/config/inference/cosmos-1-diffusion-video2world.py → cosmos1diffusionvideo2world.py +0 -0
- df_config_config.py +3 -0
- inference_utils.py +9 -2
.gitignore
CHANGED
|
@@ -16,7 +16,9 @@
|
|
| 16 |
# Misc
|
| 17 |
outputs/
|
| 18 |
checkpoints/*
|
| 19 |
-
|
|
|
|
|
|
|
| 20 |
|
| 21 |
# Data types
|
| 22 |
*.jit
|
|
|
|
| 16 |
# Misc
|
| 17 |
outputs/
|
| 18 |
checkpoints/*
|
| 19 |
+
checkpoints/README.md
|
| 20 |
+
checkpoints
|
| 21 |
+
.gitignore
|
| 22 |
|
| 23 |
# Data types
|
| 24 |
*.jit
|
config_helper.py
CHANGED
|
@@ -29,6 +29,7 @@ from omegaconf import DictConfig, OmegaConf
|
|
| 29 |
|
| 30 |
from .log import log
|
| 31 |
from .config import Config
|
|
|
|
| 32 |
|
| 33 |
|
| 34 |
def is_attrs_or_dataclass(obj) -> bool:
|
|
@@ -163,6 +164,7 @@ def import_all_modules_from_package(package_path: str, reload: bool = False, ski
|
|
| 163 |
reload (bool): Flag to determine whether to reload modules if they're already imported.
|
| 164 |
skip_underscore (bool): If True, skips importing modules that start with an underscore.
|
| 165 |
"""
|
|
|
|
| 166 |
log.debug(f"{'Reloading' if reload else 'Importing'} all modules from package {package_path}")
|
| 167 |
package = importlib.import_module(package_path)
|
| 168 |
package_directory = package.__path__
|
|
|
|
| 29 |
|
| 30 |
from .log import log
|
| 31 |
from .config import Config
|
| 32 |
+
from .inference import *
|
| 33 |
|
| 34 |
|
| 35 |
def is_attrs_or_dataclass(obj) -> bool:
|
|
|
|
| 164 |
reload (bool): Flag to determine whether to reload modules if they're already imported.
|
| 165 |
skip_underscore (bool): If True, skips importing modules that start with an underscore.
|
| 166 |
"""
|
| 167 |
+
return # TODO: we do not use this
|
| 168 |
log.debug(f"{'Reloading' if reload else 'Importing'} all modules from package {package_path}")
|
| 169 |
package = importlib.import_module(package_path)
|
| 170 |
package_directory = package.__path__
|
cosmos1/models/diffusion/config/inference/cosmos-1-diffusion-text2world.py → cosmos1diffusiontext2world.py
RENAMED
|
File without changes
|
cosmos1/models/diffusion/config/inference/cosmos-1-diffusion-video2world.py → cosmos1diffusionvideo2world.py
RENAMED
|
File without changes
|
df_config_config.py
CHANGED
|
@@ -22,6 +22,9 @@ from .df_config_registry import register_configs
|
|
| 22 |
from .config import Config as ori_Config
|
| 23 |
from .config_helper import import_all_modules_from_package
|
| 24 |
|
|
|
|
|
|
|
|
|
|
| 25 |
|
| 26 |
@attrs.define(slots=False)
|
| 27 |
class Config(ori_Config):
|
|
|
|
| 22 |
from .config import Config as ori_Config
|
| 23 |
from .config_helper import import_all_modules_from_package
|
| 24 |
|
| 25 |
+
# I added importing here
|
| 26 |
+
from .cosmos1diffusiontext2world import LazyDict
|
| 27 |
+
from .cosmos1diffusionvideo2world import LazyDict
|
| 28 |
|
| 29 |
@attrs.define(slots=False)
|
| 30 |
class Config(ori_Config):
|
inference_utils.py
CHANGED
|
@@ -29,6 +29,8 @@ from .model_v2w import DiffusionV2WModel
|
|
| 29 |
from .config_helper import get_config_module, override
|
| 30 |
from .utils_io import load_from_fileobj
|
| 31 |
from .misc import misc
|
|
|
|
|
|
|
| 32 |
|
| 33 |
TORCH_VERSION: Tuple[int, ...] = tuple(int(x) for x in torch.__version__.split(".")[:2])
|
| 34 |
if TORCH_VERSION >= (1, 11):
|
|
@@ -272,8 +274,13 @@ def load_model_by_config(
|
|
| 272 |
config_file="projects/cosmos_video/config/config.py",
|
| 273 |
model_class=DiffusionT2WModel,
|
| 274 |
):
|
| 275 |
-
|
| 276 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 277 |
|
| 278 |
config = override(config, ["--", f"experiment={config_job_name}"])
|
| 279 |
|
|
|
|
| 29 |
from .config_helper import get_config_module, override
|
| 30 |
from .utils_io import load_from_fileobj
|
| 31 |
from .misc import misc
|
| 32 |
+
from .df_config_config import make_config
|
| 33 |
+
from .log import log
|
| 34 |
|
| 35 |
TORCH_VERSION: Tuple[int, ...] = tuple(int(x) for x in torch.__version__.split(".")[:2])
|
| 36 |
if TORCH_VERSION >= (1, 11):
|
|
|
|
| 274 |
config_file="projects/cosmos_video/config/config.py",
|
| 275 |
model_class=DiffusionT2WModel,
|
| 276 |
):
|
| 277 |
+
# TODO: We need to modify this for huggingface because the config file path is different
|
| 278 |
+
# config_module = get_config_module(config_file)
|
| 279 |
+
# config = importlib.import_module(config_module).make_config()
|
| 280 |
+
if model_class in (DiffusionT2WModel, DiffusionV2WModel):
|
| 281 |
+
config = make_config()
|
| 282 |
+
else:
|
| 283 |
+
raise NotImplementedError("TODO: didn't implement autoregression")
|
| 284 |
|
| 285 |
config = override(config, ["--", f"experiment={config_job_name}"])
|
| 286 |
|