Respair's picture
Upload folder using huggingface_hub
b386992 verified
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import dataclasses
import logging
from typing import List, Literal, Optional
import lightning.pytorch as pl
from torch.utils.data import DataLoader
from nemo.lightning.megatron_parallel import MegatronStep
class DataSampler:
    """Base interface for plugging a sampling strategy into a Lightning trainer.

    Subclasses must provide :meth:`setup` and :meth:`transform_dataloader`;
    :meth:`connect` only stores the trainer on the instance.
    """

    def connect(self, trainer: pl.Trainer):
        """Attach ``trainer`` to this sampler, exposing it as ``self.trainer``."""
        self.trainer = trainer

    def setup(self, global_rank: int) -> None:
        """Prepare the sampler for the process with rank ``global_rank``.

        Raises:
            NotImplementedError: always; subclasses must override.
        """
        raise NotImplementedError

    def transform_dataloader(self, dataloader: DataLoader, consumed_samples: int = 0) -> DataLoader:
        """Return ``dataloader`` adapted to this sampling strategy.

        Raises:
            NotImplementedError: always; subclasses must override.
        """
        raise NotImplementedError
class MegatronDataSampler(DataSampler):
    """Data sampler that drives Megatron's microbatch calculator and batch samplers.

    :meth:`setup` configures the microbatch calculator and
    :meth:`transform_dataloader` wraps dataloaders with a Megatron batch sampler.
    The ``on_megatron_*`` callbacks keep several values in sync with training
    progress: the per-step configuration, the consumed-sample count and the
    microbatch count.

    Args:
        seq_len: Sequence length reported to each Megatron step as ``seq_length``.
        micro_batch_size: Samples per microbatch on each data-parallel rank.
        global_batch_size: Target global batch size passed to the microbatch
            calculator and the batch sampler.
        rampup_batch_size: Optional batch-size ramp-up specification forwarded
            unchanged to the microbatch calculator and the sampler. Presumably
            ``[start, increment, ramp_samples]``; confirm against
            ``setup_microbatch_calculator``. When set, consumed samples are
            accumulated step by step instead of being derived from the step count.
        dataloader_type: Sampler flavour forwarded to ``add_megatron_sampler``.
        init_consumed_samples: Samples already consumed before this run (e.g. on
            resume). The training sampler starts from this offset.
        init_global_step: Global step at which this run started. It is subtracted
            from the current step when computing consumed samples.
        output_log: Whether to log ``consumed_samples`` and ``global_batch_size``
            through the Lightning module.
        decoder_seq_len: Optional decoder sequence length reported to each
            Megatron step as ``decoder_seq_length``.
    """

    # Module providing the microbatch-calculator functions: Megatron-Core's
    # ``num_microbatches_calculator`` or the Apex fallback. It is resolved lazily
    # and shared by all instances, so the fallback warning is logged once
    # instead of on every training step.
    _microbatch_calculator = None

    def __init__(
        self,
        seq_len: int,
        micro_batch_size: int = 4,
        global_batch_size: int = 8,
        rampup_batch_size: Optional[List[int]] = None,
        dataloader_type: Literal["single", "cyclic", "batch"] = "single",
        init_consumed_samples: int = 0,
        init_global_step: int = 0,
        output_log: bool = True,
        decoder_seq_len: Optional[int] = None,
    ):
        self.seq_len = seq_len
        self.decoder_seq_len = decoder_seq_len
        self.output_log = output_log
        self.micro_batch_size = micro_batch_size
        self.global_batch_size = global_batch_size
        self.rampup_batch_size = rampup_batch_size
        self.dataloader_type = dataloader_type
        self.init_consumed_samples = init_consumed_samples
        # Running total used by the ramp-up branch of compute_consumed_samples.
        self.prev_consumed_samples = self.init_consumed_samples
        # 0 until the first on_megatron_step_end finishes, then 1. It multiplies
        # the ramp-up increment, so the first computation adds no batch.
        self.if_first_step = 0
        # Global batch size recorded at the end of the previous step (None before any step).
        self.prev_global_batch_size = None
        self.init_global_step = init_global_step

    @classmethod
    def _calculator_fn(cls, name: str):
        """Return the function ``name`` from the microbatch-calculator module.

        Megatron-Core's ``num_microbatches_calculator`` is preferred. When it
        cannot be imported, Apex's ``pipeline_parallel.utils`` is used and a
        warning is logged. The module choice is cached. The attribute itself is
        looked up on every call, so later patches of the module are honoured.

        Raises:
            ImportError: if neither Megatron-Core nor Apex provides the module.
        """
        module = cls._microbatch_calculator
        if module is None:
            try:
                from megatron.core import num_microbatches_calculator as module
            except (ImportError, ModuleNotFoundError):
                logging.warning("Megatron num_microbatches_calculator not found, using Apex version.")
                from apex.transformer.pipeline_parallel import utils as module
            cls._microbatch_calculator = module
        return getattr(module, name)

    def setup(self, global_rank: int) -> None:
        """Initialise the microbatch calculator for this process.

        Args:
            global_rank: Global rank of the calling process.
        """
        from nemo.lightning.data import setup_microbatch_calculator

        setup_microbatch_calculator(global_rank, self.micro_batch_size, self.global_batch_size, self.rampup_batch_size)

    def transform_dataloader(self, dataloader: DataLoader, consumed_samples: int = 0) -> DataLoader:
        """Wrap ``dataloader`` with a Megatron batch sampler for this data-parallel rank.

        The dataloader's ``mode`` attribute selects the behaviour; it defaults
        to ``'train'`` when absent.

        * In train mode the sampler starts at ``init_consumed_samples``; in
          other modes it starts at 0.
        * The incomplete final batch is kept in ``test`` and ``predict`` modes
          and dropped in all other modes.

        Args:
            dataloader: The dataloader to wrap.
            consumed_samples: Currently unused; the starting offset comes from
                ``init_consumed_samples`` as described above.

        Returns:
            The dataloader returned by ``add_megatron_sampler``.
        """
        from megatron.core import parallel_state

        from nemo.lightning.data import add_megatron_sampler

        mode = getattr(dataloader, 'mode', 'train')
        data_parallel_rank = parallel_state.get_data_parallel_rank()
        data_parallel_size = parallel_state.get_data_parallel_world_size()
        return add_megatron_sampler(
            dataloader,
            micro_batch_size=self.micro_batch_size,
            global_batch_size=self.global_batch_size,
            rampup_batch_size=self.rampup_batch_size,
            consumed_samples=self.init_consumed_samples if mode == 'train' else 0,
            dataloader_type=self.dataloader_type,
            drop_last=mode not in ["test", "predict"],  # don't drop the incomplete batch in test and predict methods
            dataloader_mode=mode,  # dataloader wrapped with nemo.lightning.data.WrappedDataLoader has mode attribute
            rank=data_parallel_rank,
            world_size=data_parallel_size,
        )

    def compute_consumed_samples(self, steps_since_resume=0) -> int:
        """Return the total number of samples consumed so far.

        Returns 0 when no trainer is connected or the trainer's strategy is not
        a ``MegatronStrategy``.

        * With batch-size ramp-up: the previous total plus the current global
          batch size. The addition is skipped until the first step has ended.
        * Otherwise: ``init_consumed_samples + steps_since_resume * dp_size *
          micro_batch_size * num_microbatches``.

        Args:
            steps_since_resume: Steps taken since this run started.
        """
        from nemo.lightning.pytorch.strategies import MegatronStrategy
        from nemo.utils import AppState

        if not hasattr(self, "trainer") or not isinstance(self.trainer.strategy, MegatronStrategy):
            return 0

        app_state = AppState()

        if self.rampup_batch_size is not None:
            consumed_samples = self.prev_consumed_samples + self.if_first_step * self.current_global_batch_size
        else:
            consumed_samples = (
                self.init_consumed_samples
                + steps_since_resume * app_state.data_parallel_size * self.micro_batch_size * self.num_microbatches
            )

        return int(consumed_samples)

    # Megatron callbacks

    def on_megatron_step_start(self, step: MegatronStep) -> MegatronStep:
        """Return a copy of ``step`` filled with this sampler's sequence and batch configuration."""
        return dataclasses.replace(
            step,
            seq_length=self.seq_len,
            micro_batch_size=self.micro_batch_size,
            num_microbatches=self.num_microbatches,
            decoder_seq_length=self.decoder_seq_len,
        )

    def on_megatron_microbatches_start(self, step: MegatronStep) -> None:
        """Request a trainer stop when ramp-up has changed the global batch size since the last step."""
        if not step.trainer:
            return

        # do validation and save the checkpoint when gbs is changed
        if (
            self.rampup_batch_size is not None
            and self.prev_global_batch_size != self.current_global_batch_size
            and self.prev_global_batch_size
        ):
            step.trainer.should_stop = True

    def on_megatron_step_end(self, step: MegatronStep) -> None:
        """Update consumed-sample bookkeeping and the microbatch calculator after a step.

        The method does the following:

        * Records the current global batch size for the comparison in
          :meth:`on_megatron_microbatches_start`.
        * For a truthy ``step.step_i``: recomputes consumed samples, logs them
          while training (if ``output_log``), and passes them to
          ``update_num_microbatches``.
        * Logs the current global batch size when ``output_log`` is set and a
          trainer is attached.
        """
        trainer = step.trainer
        pl_module = step.pl_module

        # Resolve up front so a missing calculator raises before any state is mutated.
        update_num_microbatches = self._calculator_fn("update_num_microbatches")

        self.prev_global_batch_size = self.current_global_batch_size

        if step.step_i:
            consumed_samples = self.compute_consumed_samples(step.step_i + 1 - self.init_global_step)
            if self.output_log and trainer and getattr(trainer, "training", False):
                # You may need to turn off logging, for example when doing trainer.predict(model, data)
                pl_module.log(
                    'consumed_samples',
                    consumed_samples,
                    prog_bar=True,
                    batch_size=1,
                )

            self.prev_consumed_samples = consumed_samples

            # Kept inside this branch: consumed_samples only exists when step_i is truthy.
            update_num_microbatches(
                consumed_samples=consumed_samples,
                consistency_check=False,
            )

        if self.output_log and trainer:
            # You may need to turn off logging, for example when doing trainer.predict(model, data)
            pl_module.log(
                "global_batch_size",
                self.current_global_batch_size,
                prog_bar=True,
                batch_size=1,
            )
        self.if_first_step = 1

    @property
    def num_microbatches(self) -> int:
        """Current number of microbatches reported by the microbatch calculator."""
        return self._calculator_fn("get_num_microbatches")()

    @property
    def current_global_batch_size(self) -> int:
        """Current global batch size from the calculator, or 1 when it reports a falsy value."""
        current_global_batch_size = self._calculator_fn("get_current_global_batch_size")()
        return current_global_batch_size if current_global_batch_size else 1