# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import hydra
import torch
from torch import nn

from cosmos_predict1.utils import log
from cosmos_predict1.utils.fused_adam import FusedAdam


def get_regular_param_group(net: nn.Module):
    """
    Separate the parameters of the network into two groups: decay and no_decay.
    Based on the nanoGPT codebase. Parameters with >= 2 dimensions (e.g. weight
    matrices) go to the decay group; 1-D parameters (biases, norm gains) do not decay.
    """
    param_dict = {pn: p for pn, p in net.named_parameters()}
    # Keep only parameters that require gradients.
    param_dict = {pn: p for pn, p in param_dict.items() if p.requires_grad}
    decay_params = [p for _, p in param_dict.items() if p.dim() >= 2]
    nodecay_params = [p for _, p in param_dict.items() if p.dim() < 2]
    return decay_params, nodecay_params
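
# Illustrative note (not from the original file): for a toy module such as
#   net = nn.Sequential(nn.Linear(4, 4), nn.LayerNorm(4))
# the 2-D Linear weight lands in ``decay_params``, while the 1-D Linear bias
# and the LayerNorm weight/bias land in ``nodecay_params``.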


def get_base_optimizer(
    model: nn.Module,
    lr: float,
    weight_decay: float,
    optim_type: str = "adamw",
    sharding: bool = False,
    **kwargs,
) -> torch.optim.Optimizer:
    net_decay_param, net_nodecay_param = get_regular_param_group(model)
    num_decay_params = sum(p.numel() for p in net_decay_param)
    num_nodecay_params = sum(p.numel() for p in net_nodecay_param)
    net_param_total = num_decay_params + num_nodecay_params
    log.critical(f"total num parameters : {net_param_total:,}")
    # Note: both groups are placed in a single param group, so the same
    # weight_decay is applied to decay and no-decay parameters alike.
    param_group = [
        {
            "params": net_decay_param + net_nodecay_param,
            "lr": lr,
            "weight_decay": weight_decay,
        },
    ]
    if optim_type == "adamw":
        opt_cls = torch.optim.AdamW
    elif optim_type == "fusedadam":
        opt_cls = FusedAdam
    else:
        raise ValueError(f"Unknown optimizer type: {optim_type}")
    return opt_cls(param_group, **kwargs)
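
# Illustrative usage (hypothetical values, not from the original file); extra
# keyword arguments such as ``betas`` are forwarded verbatim to the chosen
# optimizer's constructor:
#   optimizer = get_base_optimizer(model, lr=1e-4, weight_decay=0.1,
#                                  optim_type="adamw", betas=(0.9, 0.95))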


def get_base_scheduler(
    optimizer: torch.optim.Optimizer,
    model: nn.Module,
    scheduler_config: dict,
):
    # Instantiate the schedule object described by the (hydra) config and give
    # it a handle to the model; its ``schedule`` callable maps the step index
    # to an LR multiplier consumed by LambdaLR.
    net_scheduler = hydra.utils.instantiate(scheduler_config)
    net_scheduler.model = model
    return torch.optim.lr_scheduler.LambdaLR(
        optimizer,
        lr_lambda=[
            net_scheduler.schedule,
        ],
    )
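

if __name__ == "__main__":
    # Minimal smoke-test sketch (not part of the original file). It assumes a
    # toy model and a stand-in schedule class; the real scheduler class is
    # whatever ``scheduler_config["_target_"]`` points at in the project
    # configs, and only needs a ``schedule(step) -> float`` method plus a
    # writable ``model`` attribute.
    class _ConstantSchedule:
        model = None

        def schedule(self, step: int) -> float:
            return 1.0  # constant LR multiplier

    toy_model = nn.Sequential(nn.Linear(8, 8), nn.ReLU(), nn.Linear(8, 2))
    optimizer = get_base_optimizer(toy_model, lr=1e-4, weight_decay=0.1, optim_type="adamw")
    scheduler = get_base_scheduler(
        optimizer,
        toy_model,
        scheduler_config={"_target_": "__main__._ConstantSchedule"},
    )
    toy_model(torch.randn(3, 8)).sum().backward()
    optimizer.step()
    scheduler.step()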