|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
from __future__ import annotations |
|
|
|
|
|
import os |
|
|
from typing import Any, Dict, Optional, Type, TypeVar, Union |
|
|
|
|
|
import attrs |
|
|
import torch |
|
|
from megatron.core import ModelParallelConfig |
|
|
|
|
|
from cosmos_predict1.utils import callback |
|
|
from cosmos_predict1.utils.lazy_config import LazyCall as L |
|
|
from cosmos_predict1.utils.lazy_config import LazyDict |
|
|
from cosmos_predict1.utils.misc import Color |
|
|
|
|
|
T = TypeVar("T") |
|
|
|
|
|
|
|
|
def _is_attrs_instance(obj: object) -> bool: |
|
|
""" |
|
|
Helper function to check if an object is an instance of an attrs-defined class. |
|
|
|
|
|
Args: |
|
|
obj: The object to check. |
|
|
|
|
|
Returns: |
|
|
bool: True if the object is an instance of an attrs-defined class, False otherwise. |
|
|
""" |
|
|
return hasattr(obj, "__attrs_attrs__") |
|
|
|
|
|
|
|
|
def make_freezable(cls: T) -> T:
    """
    Class decorator that lets instances of an attrs-defined class be frozen.

    NOTE: the wrapped class must be declared with ``attrs.define(slots=False)``
    because the mechanism stores an ``_is_frozen`` flag in the instance
    ``__dict__``.

    Once ``freeze()`` is called on an instance, any further attribute
    assignment raises ``AttributeError``. Freezing recurses into attribute
    values that are themselves attrs instances exposing a ``freeze`` method.

    Usage:
        @make_freezable
        @attrs.define(slots=False)
        class MyClass:
            attribute1: int
            attribute2: str

        obj = MyClass(1, 'a')
        obj.freeze()  # Freeze the instance
        obj.attribute1 = 2  # Raises AttributeError

    Args:
        cls: The class to be decorated.

    Returns:
        The same class, augmented with a ``freeze()`` method and a guarded
        ``__setattr__``.
    """
    if not hasattr(cls, "__dict__"):
        raise TypeError(
            "make_freezable cannot be used with classes that do not define __dict__. Make sure that the wrapped "
            "class was defined with `@attrs.define(slots=False)`"
        )

    # Keep a handle on the pre-existing setattr so normal writes still work.
    wrapped_setattr = cls.__setattr__

    def guarded_setattr(self, key, value) -> None:
        """Reject writes on frozen instances; otherwise defer to the original setattr."""
        # Unfreezing (writing _is_frozen itself) is always allowed.
        if getattr(self, "_is_frozen", False) and key != "_is_frozen":
            raise AttributeError("Cannot modify frozen instance")
        wrapped_setattr(self, key, value)

    def freeze(self: object) -> None:
        """Freeze this instance and, recursively, its attrs-valued attributes."""
        for value in attrs.asdict(self, recurse=False).values():
            if _is_attrs_instance(value) and hasattr(value, "freeze"):
                value.freeze()
        self._is_frozen = True

    cls.__setattr__ = guarded_setattr
    cls.freeze = freeze
    return cls
|
|
|
|
|
|
|
|
def _pretty_print_attrs_instance(obj: object, indent: int = 0, use_color: bool = False) -> str:
    """
    Recursively render an attrs instance as an indented bullet list.

    Args:
        obj: An instance of an attrs-defined class (asserted).
        indent: Current nesting depth; one leading space per level.
        use_color: When True, render with ANSI colors via ``Color``.

    Returns:
        str: The multi-line rendering (no trailing newline).
    """
    assert attrs.has(obj.__class__)

    pad = " " * indent
    rendered: list[str] = []
    for field_ in attrs.fields(obj.__class__):
        value = getattr(obj, field_.name)
        # Both branches share the same "* name" prefix; only coloring differs.
        if use_color:
            prefix = pad + Color.cyan("* ") + Color.green(field_.name)
        else:
            prefix = pad + "* " + field_.name
        if attrs.has(value.__class__):
            # Nested attrs object: header line, then its own rendering one level deeper.
            rendered.append(prefix + ":")
            rendered.append(_pretty_print_attrs_instance(value, indent + 1, use_color))
        elif use_color:
            rendered.append(prefix + ": " + Color.yellow(value))
        else:
            rendered.append(prefix + ": " + str(value))
    return "\n".join(rendered)
|
|
|
|
|
|
|
|
def pretty_print_overrides(overrides: Optional[list[str]] = None, use_color: bool = False) -> str:
    """
    Pretty-print a list of "key=value" override strings.

    Args:
        overrides: Override strings of the form ``name=value``. Entries equal
            to ``"--"`` (an argument separator) are skipped. ``None`` is
            treated as an empty list.
        use_color: When True, render with ANSI colors via ``Color``.

    Returns:
        str: A newline-joined, human-readable rendering of the overrides.
    """
    lines: list[str] = []
    # Fixed: the header previously applied colors unconditionally; it now
    # honors use_color like the per-override lines below.
    if use_color:
        lines.append(Color.cyan("* ") + Color.green("overrides") + ": ")
    else:
        lines.append("* overrides: ")
    # Fixed: the default of None previously crashed when iterated.
    for override in overrides or []:
        if override == "--":
            continue
        # Fixed: split only on the first "=" so values may themselves contain "=".
        attribute_name, attribute_value = override.split("=", 1)
        if use_color:
            lines.append(" " + Color.cyan("* ") + Color.green(attribute_name) + ": " + Color.yellow(attribute_value))
        else:
            lines.append(" " + "* " + attribute_name + ": " + str(attribute_value))

    return "\n".join(lines)
|
|
|
|
|
|
|
|
@make_freezable
@attrs.define(slots=False)
class JobConfig:
    """Identity of a job (project / group / name) and its derived output paths."""

    # Top-level project name.
    project: str = ""

    # Experiment group within the project.
    group: str = ""

    # Concrete job name within the group.
    name: str = ""

    @property
    def path(self) -> str:
        """Relative output path: "<project>/<group>/<name>"."""
        return "/".join((self.project, self.group, self.name))

    @property
    def path_local(self) -> str:
        """Local output path rooted at $OUTPUT_ROOT (defaults to "checkpoints")."""
        local_root = os.environ.get("OUTPUT_ROOT", "checkpoints")
        return f"{local_root}/{self.path}"
|
|
|
|
|
|
|
|
@make_freezable
@attrs.define(slots=False)
class EMAConfig:
    """Configuration for the exponential-moving-average (EMA) copy of model weights.

    NOTE(review): fields are consumed outside this file (EMA callback /
    trainer); descriptions are presumed from the names — confirm there.
    """

    # Whether an EMA copy of the model is maintained at all.
    enabled: bool = False

    # EMA decay factor; values closer to 1.0 give a slower-moving average.
    beta: float = 0.9999

    # Presumably renames EMA buffers for torch.compile compatibility — confirm
    # against the EMA callback implementation.
    torch_compile_buffer_renaming: bool = False
|
|
|
|
|
|
|
|
@make_freezable
@attrs.define(slots=False)
class DDPConfig:
    """DistributedDataParallel options.

    NOTE(review): field names mirror torch DDP constructor kwargs; confirm in
    the trainer that they are forwarded verbatim.
    """

    # Whether DDP should search for parameters that received no gradient.
    find_unused_parameters: bool = False

    # Enable DDP's static-graph optimization (graph assumed identical each iteration).
    static_graph: bool = True

    # Whether module buffers are broadcast from rank 0 at the start of forward.
    broadcast_buffers: bool = True
|
|
|
|
|
|
|
|
@make_freezable
@attrs.define(slots=False)
class CuDNNConfig:
    """cuDNN backend flags.

    NOTE(review): presumably applied to ``torch.backends.cudnn`` by the
    trainer — confirm where these are consumed.
    """

    # Force deterministic cuDNN algorithms (slower, reproducible).
    deterministic: bool = False

    # Let cuDNN benchmark and pick the fastest algorithms for fixed shapes.
    benchmark: bool = True
|
|
|
|
|
|
|
|
@make_freezable
@attrs.define(slots=False)
class JITConfig:
    """Settings for optional JIT export of the model.

    NOTE(review): consumed outside this file (see CheckpointConfig.jit);
    descriptions are presumed from the field names — confirm there.
    """

    # Whether JIT export is performed at all.
    enabled: bool = False

    # Example input shape used for export; None means unspecified.
    # (Annotation modernized from Union[list[int], None] — file already uses
    # PEP 604 unions, e.g. TrainerConfig.max_val_iter.)
    input_shape: list[int] | None = None

    # Device on which the example input is placed.
    device: str = "cuda"

    # Dtype of the example input, in string form.
    dtype: str = "bfloat16"

    # Presumably forwarded as the exporter's `strict` flag — confirm.
    strict: bool = True
|
|
|
|
|
|
|
|
@make_freezable
@attrs.define(slots=False)
class CheckpointConfig:
    """Configuration for checkpoint saving and loading.

    NOTE(review): most fields are consumed by checkpointing code outside this
    file; descriptions are presumed from the field names — confirm there.
    """

    # Checkpointer selection/spec; None presumably selects a default implementation.
    type: Optional[Dict] = None

    # Whether torch distributed checkpoint (DCP) async mode is enabled.
    dcp_async_mode_enabled: bool = False

    # Save a checkpoint every N iterations (huge sentinel = effectively never).
    save_iter: int = 999999999

    # Path of a checkpoint to load at startup; empty string means none.
    load_path: str = ""

    # Whether to also restore training state (presumably optimizer/scheduler/iteration).
    load_training_state: bool = False

    # Restore only the scheduler state (a subset of the training state).
    only_load_scheduler_state: bool = False

    # Whether state-dict loading is strict.
    strict_resume: bool = True

    # Verbose logging during checkpoint operations.
    verbose: bool = True

    # JIT-export settings applied at checkpoint time.
    jit: JITConfig = attrs.field(factory=JITConfig)

    # State-dict keys to skip when resuming.
    # Fixed: explicit factory instead of a bare mutable [] default, for
    # consistency with the other attrs.field(factory=...) declarations above
    # (behavior-identical: attrs.define auto-wraps mutable literals anyway).
    keys_not_to_resume: list[str] = attrs.field(factory=list)

    # Broadcast the checkpoint via the filesystem instead of collectives — confirm.
    broadcast_via_filesystem: bool = False

    # Load EMA weights into the regular (non-EMA) model.
    load_ema_to_reg: bool = False

    # Save checkpoints asynchronously.
    async_saving: bool = True
|
|
|
|
|
|
|
|
@make_freezable
@attrs.define(slots=False)
class TrainerConfig:
    """Configuration for the training loop driven by cosmos_predict1's Trainer."""

    # Imported inside the class body — presumably to avoid a circular import
    # between this config module and the trainer module; confirm.
    from cosmos_predict1.utils.trainer import Trainer

    # Trainer class to instantiate; may be overridden with a subclass.
    type: Type[Trainer] = Trainer

    # Callbacks attached to the trainer, keyed by name.
    # NOTE(review): this LazyDict instance is a class-level default shared by
    # all TrainerConfig instances (attrs only auto-factories plain list/dict/
    # set literals) — confirm it is never mutated per-instance.
    callbacks: LazyDict = LazyDict(
        dict(
            ema=L(callback.EMAModelCallback)(),
            progress_bar=L(callback.ProgressBarCallback)(),
        )
    )

    # Distributed strategy identifier (e.g. "ddp").
    distributed_parallelism: str = "ddp"

    # Options for the DDP strategy.
    ddp: DDPConfig = attrs.field(factory=DDPConfig)

    # cuDNN backend flags.
    cudnn: CuDNNConfig = attrs.field(factory=CuDNNConfig)

    # Base random seed.
    seed: int = 0

    # Kwargs for the gradient scaler (disabled by default).
    grad_scaler_args: dict = attrs.field(factory=lambda: dict(enabled=False))

    # Maximum training iterations (huge sentinel = effectively unlimited).
    max_iter: int = 999999999

    # Cap on iterations per validation run; None means no cap.
    max_val_iter: int | None = None

    # Log every N iterations.
    logging_iter: int = 100

    # Whether validation runs at all.
    run_validation: bool = True

    # Run validation every N iterations (sentinel = effectively never).
    validation_iter: int = 999999999

    # Timeout period (units not visible here — presumably seconds; confirm
    # against the trainer's timer callback).
    timeout_period: int = 999999999

    # Memory format applied to model tensors.
    memory_format: torch.memory_format = torch.preserve_format

    # Gradient-accumulation steps per optimizer update.
    grad_accum_iter: int = 1
|
|
|
|
|
|
|
|
|
|
@make_freezable
@attrs.define(slots=False)
class Config:
    """Config for a job.

    See /README.md/Configuration System for more info.
    """

    # Model spec (required; the only field without a default).
    model: LazyDict

    # Optimizer spec; the dummy placeholder is expected to be overridden.
    optimizer: LazyDict = LazyDict(dict(dummy=None))

    # LR scheduler spec.
    scheduler: LazyDict = LazyDict(dict(dummy=None))

    # Training dataloader spec.
    dataloader_train: LazyDict = LazyDict(dict(dummy=None))

    # Validation dataloader spec.
    dataloader_val: LazyDict = LazyDict(dict(dummy=None))

    # Job identity (project/group/name); determines output paths.
    # Fixed: this field was declared twice in the original class (a redundant
    # duplicate appeared after pretty_print()); the duplicate is removed.
    job: JobConfig = attrs.field(factory=JobConfig)

    # Training-loop configuration.
    trainer: TrainerConfig = attrs.field(factory=TrainerConfig)

    # Megatron-core model-parallel settings.
    model_parallel: ModelParallelConfig = attrs.field(factory=ModelParallelConfig)

    # Checkpoint save/load configuration.
    checkpoint: CheckpointConfig = attrs.field(factory=CheckpointConfig)

    def pretty_print(self, use_color: bool = False) -> str:
        """Return a recursive, human-readable rendering of this config."""
        return _pretty_print_attrs_instance(self, 0, use_color)

    def to_dict(self) -> dict[str, Any]:
        """Return a recursive plain-dict representation of the config."""
        return attrs.asdict(self)

    def validate(self) -> None:
        """Validate that the config has all required fields."""
        assert self.job.project != "", "Project name is required."
        assert self.job.group != "", "Group name is required."
        assert self.job.name != "", "Job name is required."
|
|
|