File size: 7,702 Bytes
b386992 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 |
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import dataclasses
import logging
from typing import List, Literal, Optional
import lightning.pytorch as pl
from torch.utils.data import DataLoader
from nemo.lightning.megatron_parallel import MegatronStep
class DataSampler:
    """Abstract interface for objects that adapt dataloaders for a trainer.

    Concrete subclasses (e.g. a Megatron-aware sampler) must implement
    :meth:`setup` and :meth:`transform_dataloader`; :meth:`connect` is a
    default hook that simply records the trainer.
    """

    def connect(self, trainer: pl.Trainer):
        """Attach the Lightning trainer so later hooks can reach it."""
        self.trainer = trainer

    def setup(self, global_rank: int) -> None:
        """Perform rank-aware initialization. Must be overridden."""
        raise NotImplementedError()

    def transform_dataloader(self, dataloader: DataLoader, consumed_samples: int = 0) -> DataLoader:
        """Return ``dataloader`` wrapped for distributed sampling. Must be overridden."""
        raise NotImplementedError()
class MegatronDataSampler(DataSampler):
    """Data sampler that wires dataloaders into Megatron's batch machinery.

    Responsibilities:
      * initialize Megatron's global micro-batch calculator (:meth:`setup`),
      * wrap dataloaders with a data-parallel Megatron sampler
        (:meth:`transform_dataloader`),
      * track consumed samples across steps — including optional
        global-batch-size ramp-up — and log them via the Megatron step
        callbacks.
    """

    def __init__(
        self,
        seq_len: int,
        micro_batch_size: int = 4,
        global_batch_size: int = 8,
        rampup_batch_size: Optional[List[int]] = None,
        dataloader_type: Literal["single", "cyclic", "batch"] = "single",
        init_consumed_samples: int = 0,
        init_global_step: int = 0,
        output_log: bool = True,
        decoder_seq_len: Optional[int] = None,
    ):
        """
        Args:
            seq_len: Sequence length fed to the model at each step.
            micro_batch_size: Per-rank batch size of a single forward pass.
            global_batch_size: Effective batch size across data-parallel ranks.
            rampup_batch_size: Optional Megatron-style ramp-up schedule for the
                global batch size (presumably ``[start, increment, samples]`` —
                consumed opaquely by Megatron's calculator).
            dataloader_type: Sampler flavor forwarded to Megatron.
            init_consumed_samples: Samples already consumed when (re)starting.
            init_global_step: Global step at (re)start; subtracted from the
                current step to get steps-since-resume.
            output_log: Whether to log ``consumed_samples`` /
                ``global_batch_size`` to the progress bar.
            decoder_seq_len: Optional decoder sequence length for
                encoder-decoder models.
        """
        self.seq_len = seq_len
        self.decoder_seq_len = decoder_seq_len
        self.output_log = output_log
        self.micro_batch_size = micro_batch_size
        self.global_batch_size = global_batch_size
        self.rampup_batch_size = rampup_batch_size
        self.dataloader_type = dataloader_type
        self.init_consumed_samples = init_consumed_samples
        self.prev_consumed_samples = self.init_consumed_samples
        # 0 until the first step completes, then 1 (set in on_megatron_step_end).
        # Used as a multiplier so ramp-up accounting only adds a batch after
        # at least one step has actually run.
        self.if_first_step = 0
        self.prev_global_batch_size = None
        self.init_global_step = init_global_step

    def setup(self, global_rank: int) -> None:
        """Initialize Megatron's global micro-batch calculator for this rank."""
        from nemo.lightning.data import setup_microbatch_calculator

        setup_microbatch_calculator(global_rank, self.micro_batch_size, self.global_batch_size, self.rampup_batch_size)

    def transform_dataloader(self, dataloader: DataLoader, consumed_samples: int = 0) -> DataLoader:
        """Wrap ``dataloader`` with a Megatron data-parallel sampler.

        NOTE(review): the ``consumed_samples`` argument is currently ignored;
        training dataloaders resume from ``self.init_consumed_samples`` and all
        other modes start from 0. Kept as-is for backward compatibility.
        """
        from megatron.core import parallel_state

        from nemo.lightning.data import add_megatron_sampler

        mode = getattr(dataloader, 'mode', 'train')

        data_parallel_rank = parallel_state.get_data_parallel_rank()
        data_parallel_size = parallel_state.get_data_parallel_world_size()
        return add_megatron_sampler(
            dataloader,
            micro_batch_size=self.micro_batch_size,
            global_batch_size=self.global_batch_size,
            rampup_batch_size=self.rampup_batch_size,
            consumed_samples=self.init_consumed_samples if mode == 'train' else 0,
            dataloader_type=self.dataloader_type,
            drop_last=mode not in ["test", "predict"],  # don't drop the incomplete batch in test and predict methods
            dataloader_mode=mode,  # dataloader wrapped with nemo.lightning.data.WrappedDataLoader has mode attribute
            rank=data_parallel_rank,
            world_size=data_parallel_size,
        )

    def compute_consumed_samples(self, steps_since_resume=0) -> int:
        """Return the total number of samples consumed so far.

        With ramp-up enabled the count advances one (current) global batch at a
        time; otherwise it is derived linearly from ``steps_since_resume``.
        Returns 0 when not attached to a :class:`MegatronStrategy` trainer.
        """
        from nemo.lightning.pytorch.strategies import MegatronStrategy
        from nemo.utils import AppState

        if not hasattr(self, "trainer") or not isinstance(self.trainer.strategy, MegatronStrategy):
            return 0

        app_state = AppState()

        if self.rampup_batch_size is not None:
            # if_first_step is 0 before the first completed step, so nothing is
            # added until training has actually advanced.
            consumed_samples = self.prev_consumed_samples + self.if_first_step * self.current_global_batch_size
        else:
            consumed_samples = (
                self.init_consumed_samples
                + steps_since_resume * app_state.data_parallel_size * self.micro_batch_size * self.num_microbatches
            )
        return int(consumed_samples)

    # Megatron callbacks

    def on_megatron_step_start(self, step: MegatronStep) -> MegatronStep:
        """Inject batch geometry (seq length, micro-batch, microbatches) into the step."""
        return dataclasses.replace(
            step,
            seq_length=self.seq_len,
            micro_batch_size=self.micro_batch_size,
            num_microbatches=self.num_microbatches,
            decoder_seq_length=self.decoder_seq_len,
        )

    def on_megatron_microbatches_start(self, step: MegatronStep) -> None:
        """Stop the trainer when the ramp-up schedule changes the global batch size.

        Stopping forces validation and a checkpoint at the batch-size boundary.
        """
        if not step.trainer:
            return

        # do validation and save the checkpoint when gbs is changed
        if (
            self.rampup_batch_size is not None
            and self.prev_global_batch_size != self.current_global_batch_size
            and self.prev_global_batch_size
        ):
            step.trainer.should_stop = True

    def on_megatron_step_end(self, step: MegatronStep) -> None:
        """Update consumed-sample bookkeeping and logging after a step completes."""
        trainer = step.trainer
        pl_module = step.pl_module

        try:
            from megatron.core.num_microbatches_calculator import update_num_microbatches

        except (ImportError, ModuleNotFoundError):
            logging.warning("Megatron num_microbatches_calculator not found, using Apex version.")
            from apex.transformer.pipeline_parallel.utils import update_num_microbatches

        self.prev_global_batch_size = self.current_global_batch_size

        if step.step_i:
            consumed_samples = self.compute_consumed_samples(step.step_i + 1 - self.init_global_step)
            if self.output_log and trainer and getattr(trainer, "training", False):
                # You may need to turn off logging, for example when doing trainer.predict(model, data)
                pl_module.log(
                    'consumed_samples',
                    consumed_samples,
                    prog_bar=True,
                    batch_size=1,
                )

            self.prev_consumed_samples = consumed_samples

            # Let Megatron advance the ramp-up schedule based on progress.
            update_num_microbatches(
                consumed_samples=consumed_samples,
                consistency_check=False,
            )
        if self.output_log and trainer:
            # You may need to turn off logging, for example when doing trainer.predict(model, data)
            pl_module.log(
                "global_batch_size",
                self.current_global_batch_size,
                prog_bar=True,
                batch_size=1,
            )
        self.if_first_step = 1

    @property
    def num_microbatches(self) -> int:
        """Current number of microbatches, from Megatron (or the Apex fallback)."""
        try:
            from megatron.core.num_microbatches_calculator import get_num_microbatches

        except (ImportError, ModuleNotFoundError):
            logging.warning("Megatron num_microbatches_calculator not found, using Apex version.")
            from apex.transformer.pipeline_parallel.utils import get_num_microbatches

        return get_num_microbatches()

    @property
    def current_global_batch_size(self) -> int:
        """Current global batch size, falling back to 1 when the calculator reports none."""
        try:
            from megatron.core.num_microbatches_calculator import get_current_global_batch_size

        except (ImportError, ModuleNotFoundError):
            logging.warning("Megatron num_microbatches_calculator not found, using Apex version.")
            from apex.transformer.pipeline_parallel.utils import get_current_global_batch_size

        # Single call (the original queried the calculator twice); falsy -> 1.
        return get_current_global_batch_size() or 1
|