|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import argparse |
|
|
import importlib |
|
|
import os |
|
|
|
|
|
from loguru import logger as logging |
|
|
|
|
|
from cosmos_predict1.utils.config import Config, pretty_print_overrides |
|
|
from cosmos_predict1.utils.config_helper import get_config_module, override |
|
|
from cosmos_predict1.utils.lazy_config import instantiate |
|
|
from cosmos_predict1.utils.lazy_config.lazy import LazyConfig |
|
|
|
|
|
|
|
|
@logging.catch(reraise=True)
def launch(config: Config, args: argparse.Namespace) -> None:
    """Build every training component described by ``config`` and run training.

    Steps, in order:
      1. ``config.validate()`` and then ``config.freeze()``.
      2. Construct the trainer by calling ``config.trainer.type(config)``.
      3. Instantiate the model from ``config.model`` and call its
         ``on_model_init_end()`` hook.
      4. Instantiate the training and validation dataloaders from
         ``config.dataloader_train`` / ``config.dataloader_val``.
      5. Hand model and both dataloaders to ``trainer.train``.

    Args:
        config: The fully overridden experiment config.
        args: Parsed command-line namespace. It is accepted for interface
            compatibility but is not read anywhere in this function.

    Any exception is routed through loguru's ``catch`` decorator, which logs
    it and, because ``reraise=True``, propagates it to the caller.
    """
    # Check the config first, then lock it against further mutation.
    config.validate()
    config.freeze()

    # The trainer class is named in the config itself and receives the whole config.
    runner = config.trainer.type(config)

    net = instantiate(config.model)
    net.on_model_init_end()

    train_loader = instantiate(config.dataloader_train)
    val_loader = instantiate(config.dataloader_val)

    runner.train(net, train_loader, val_loader)
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Command-line entry point.
    #   --config   path to the config file (required); resolved to a module via
    #              get_config_module and built with that module's make_config().
    #   opts       every remaining token, applied to the config via override().
    #   --dryrun   log the resolved config plus overrides, write it to
    #              <config.job.path_local>/config.yaml, print that path, and skip training.
    parser = argparse.ArgumentParser(description="Training")
    parser.add_argument("--config", help="Path to the config file", required=True)
    parser.add_argument(
        "opts",
        help="""
Modify config options at the end of the command. For Yacs configs, use
space-separated "PATH.KEY VALUE" pairs.
For python-based LazyConfig, use "path.key=value".
        """.strip(),
        default=None,
        nargs=argparse.REMAINDER,
    )
    parser.add_argument(
        "--dryrun",
        action="store_true",
        help="Do a dry run without training. Useful for debugging the config.",
    )
    args = parser.parse_args()

    # Turn --config into an importable module name, build the base config from
    # the module's make_config() factory, then layer the CLI overrides on top.
    config_module = get_config_module(args.config)
    config = importlib.import_module(config_module).make_config()
    config = override(config, args.opts)

    if args.dryrun:
        logging.info(
            "Config:\n" + config.pretty_print(use_color=True) + "\n" + pretty_print_overrides(args.opts, use_color=True)
        )
        # Build the output path once and reuse it for both saving and reporting.
        # os.path.join also avoids a "//" when path_local already ends with a separator.
        config_yaml_path = os.path.join(config.job.path_local, "config.yaml")
        os.makedirs(config.job.path_local, exist_ok=True)
        LazyConfig.save_yaml(config, config_yaml_path)
        # Written to stdout instead of the logger; presumably this lets wrapper
        # scripts capture the path. NOTE(review): confirm with callers.
        print(config_yaml_path)
    else:
        launch(config, args)
|
|
|