# Copyright (c) 2025, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from lightning.pytorch.callbacks.progress import TQDMProgressBar
from lightning.pytorch.callbacks.progress.tqdm_progress import _update_n


class MegatronProgressBar(TQDMProgressBar):
    """
    TQDM progress bar for megatron models.

    Drops the 's/it' rate from the training bar and advances it per step
    (global batch) instead of per microbatch.
    """

    def init_train_tqdm(self):
        """
        Build the training bar from the parent class and replace its
        ``bar_format`` with one that has no 's/it' field.

        The bar is also kept on ``self.bar`` and returned to the caller.
        """
        self.bar = progress = super().init_train_tqdm()
        progress.bar_format = "{desc}: {percentage:3.0f}%|{bar}| {n_fmt}/{total_fmt} [{elapsed}<{remaining}{postfix}]"
        return progress

    def on_train_epoch_start(self, trainer, *_):
        """
        Reset the training bar at the start of an epoch and label it with the
        current epoch number.
        """
        # When resuming from a checkpoint, trainer.num_training_batches is
        # truncated to (max_steps - step being resumed at), so a positive
        # trainer.max_steps is preferred as the bar total. The check does not
        # look at trainer.ckpt_path, so it applies to every run with a
        # positive max_steps, not only to resumed ones.
        total_steps = trainer.max_steps if trainer.max_steps > 0 else trainer.num_training_batches

        self.train_progress_bar.reset(total_steps)
        self.train_progress_bar.initial = 0
        self.train_progress_bar.set_description(f"Epoch {trainer.current_epoch}")

    def on_train_batch_end(self, trainer, pl_module, *_, **__):
        """
        Replace the parent ``on_train_batch_end`` so the bar position follows
        ``trainer.strategy.current_epoch_step`` (one tick per global batch)
        rather than advancing once per microbatch.
        """
        step = trainer.strategy.current_epoch_step
        progress = self.train_progress_bar

        # Nothing to draw unless the parent's refresh policy says so.
        if not self._should_update(step, progress.total):
            return

        _update_n(progress, step)
        progress.set_postfix(self.get_metrics(trainer, pl_module), refresh=False)


def calculate_data_parallel_groups() -> int:
    """
    Compute the number of data-parallel groups from the global ``AppState``.

    The size of one data-parallel group is taken as
    ``world_size // (pipeline_model_parallel_size * tensor_model_parallel_size)``,
    and the result is ``world_size`` floor-divided by that size. When
    ``world_size`` is an exact multiple of the model-parallel product, the
    result equals that product.

    Returns:
        int: ``world_size // data_parallel_group_size``.

    Raises:
        ZeroDivisionError: if ``world_size`` is smaller than the product of the
            pipeline and tensor model-parallel sizes (the group size floors to 0).
    """
    # Imported lazily, inside the function, exactly as the original did.
    from nemo.utils import AppState

    state = AppState()
    model_parallel_size = state.pipeline_model_parallel_size * state.tensor_model_parallel_size
    total_ranks = state.world_size

    ranks_per_data_parallel_group = total_ranks // model_parallel_size
    return total_ranks // ranks_per_data_parallel_group