# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Any

import torch

from cosmos_predict1.utils.lazy_config import LazyDict, instantiate
class Model(torch.nn.Module):
    """The base model class. It is inherited from torch.nn.Module.

    All models should inherit Model. It should include the implementations for all the
    computation graphs. All inheriting child classes should implement the following methods:

    - training_step(): The training step of the model, including the loss computation.
    - validation_step(): The validation step of the model, including the loss computation.
    - forward(): The computation graph for model inference.

    The following methods have default implementations in Model:

    - init_optimizer_scheduler(): Creates the optimizer and scheduler for the model.
    """

    def __init__(self) -> None:
        super().__init__()
        # Lifecycle hook for subclasses; the base implementation is a no-op.
        self.on_model_init_start(set_barrier=False)

    def init_optimizer_scheduler(
        self, optimizer_config: LazyDict, scheduler_config: LazyDict
    ) -> tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LRScheduler]:
        """Creates the optimizer and scheduler for the model.

        The configs are lazy configs: this method fills in the fields that depend on the
        model/optimizer (``params`` and ``optimizer`` respectively) before instantiation.
        Note that both config objects are mutated in place.

        Args:
            optimizer_config (LazyDict): Lazy config for the optimizer; its ``params``
                field is set to this model's parameters before instantiation.
            scheduler_config (LazyDict): Lazy config for the scheduler; its ``optimizer``
                field is set to the freshly created optimizer before instantiation.

        Returns:
            optimizer (torch.optim.Optimizer): The model optimizer.
            scheduler (torch.optim.lr_scheduler.LRScheduler): The optimization scheduler.
        """
        optimizer_config.params = self.parameters()
        optimizer = instantiate(optimizer_config)
        scheduler_config.optimizer = optimizer
        scheduler = instantiate(scheduler_config)
        return optimizer, scheduler

    def training_step(
        self, data_batch: dict[str, torch.Tensor], iteration: int
    ) -> tuple[dict[str, torch.Tensor], torch.Tensor]:
        """The training step of the model, including the loss computation.

        Must be overridden by subclasses; the base implementation raises.

        Args:
            data_batch (dict[str, torch.Tensor]): Data batch (dictionary of tensors).
            iteration (int): Current iteration number.

        Returns:
            output_batch (dict[str, torch.Tensor]): Auxiliary model output from the training batch.
            loss (torch.Tensor): The total loss for backprop (weighted sum of various losses).

        Raises:
            NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError

    def validation_step(
        self, data_batch: dict[str, torch.Tensor], iteration: int
    ) -> tuple[dict[str, torch.Tensor], torch.Tensor]:
        """The validation step of the model, including the loss computation.

        Must be overridden by subclasses; the base implementation raises.

        Args:
            data_batch (dict[str, torch.Tensor]): Data batch (dictionary of tensors).
            iteration (int): Current iteration number.

        Returns:
            output_batch (dict[str, torch.Tensor]): Auxiliary model output from the validation batch.
            loss (torch.Tensor): The total loss (weighted sum of various losses).

        Raises:
            NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError

    def forward(self, *args: Any, **kwargs: Any) -> Any:
        """The computation graph for model inference.

        Must be overridden by subclasses; the base implementation raises.

        Args:
            *args: Whatever you decide to pass into the forward method.
            **kwargs: Keyword arguments are also possible.

        Return:
            Your model's output.

        Raises:
            NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError

    def on_model_init_start(self, set_barrier: bool = False) -> None:
        """Hook invoked at the start of model construction. No-op by default.

        Args:
            set_barrier (bool): Presumably whether to synchronize distributed
                ranks at this point — no-op here; confirm in overriding subclasses.
        """
        return

    def on_model_init_end(self, set_barrier: bool = False) -> None:
        """Hook invoked at the end of model construction. No-op by default.

        NOTE(review): not called by ``__init__`` in this base class — presumably
        invoked externally (e.g. by the trainer); confirm against callers.

        Args:
            set_barrier (bool): Presumably whether to synchronize distributed
                ranks at this point — no-op here; confirm in overriding subclasses.
        """
        return

    def on_train_start(self, memory_format: torch.memory_format = torch.preserve_format) -> None:
        """The model preparation before the training is launched. No-op by default.

        Args:
            memory_format (torch.memory_format): Memory format of the model.
        """
        pass

    def on_before_zero_grad(
        self, optimizer: torch.optim.Optimizer, scheduler: torch.optim.lr_scheduler.LRScheduler, iteration: int
    ) -> None:
        """Hook before zero_grad() is called. No-op by default.

        Args:
            optimizer (torch.optim.Optimizer): The model optimizer.
            scheduler (torch.optim.lr_scheduler.LRScheduler): The optimization scheduler.
            iteration (int): Current iteration number.
        """
        pass

    def on_after_backward(self, iteration: int = 0) -> None:
        """Hook after loss.backward() is called. No-op by default.

        This method is called immediately after the backward pass, allowing for custom operations
        or modifications to be performed on the gradients before the optimizer step.

        Args:
            iteration (int): Current iteration number.
        """
        pass