File size: 8,186 Bytes
b386992 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 |
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from typing import Any, Dict, Literal
from lightning.pytorch.utilities.types import EVAL_DATALOADERS
from megatron.core import parallel_state
from megatron.energon import DefaultTaskEncoder, WorkerConfig, get_savable_loader, get_train_dataset
from nemo.collections.multimodal.data.energon.base import EnergonMultiModalDataModule
class DiffusionDataModule(EnergonMultiModalDataModule):
    """
    A PyTorch Lightning DataModule for energon-backed diffusion datasets.

    Thin specialization of ``EnergonMultiModalDataModule`` for diffusion training:
    the tokenizer and image processor are deliberately set to ``None`` (the task
    encoder is expected to perform all sample encoding), and validation can
    optionally reuse the training split.

    Attributes:
        path (str): Path to the energon dataset.
        seq_length (int): The maximum sequence length for tokenized text.
        micro_batch_size (int): The per-GPU batch size for training and validation.
        global_batch_size (int): The global batch size across data-parallel ranks.
        num_workers (int): Number of workers for the training data loader.
        num_workers_val (int): Number of workers for the validation data loader (fixed to 1).
        pin_memory (bool): Whether to pin memory in the DataLoader.
        task_encoder (DefaultTaskEncoder): Encoder responsible for encoding and batching samples.
        use_train_split_for_val (bool): If True, the 'train' split is also used for validation.
        virtual_epoch_length (int): Virtual epoch length passed to energon.
        packing_buffer_size (int | None): Buffer size for energon sample packing.
        max_samples_per_sequence (int | None): Max samples per sequence passed to energon.
    """

    def __init__(
        self,
        path: str,
        seq_length: int = 2048,
        micro_batch_size: int = 1,
        global_batch_size: int = 8,
        num_workers: int = 1,
        pin_memory: bool = True,
        task_encoder: DefaultTaskEncoder = None,
        use_train_split_for_val: bool = False,
        virtual_epoch_length: int = 1_000_000_000,  # a hack to avoid energon end of epoch warning
        packing_buffer_size: int | None = None,
        max_samples_per_sequence: int | None = None,
    ) -> None:
        """
        Initialize the DiffusionDataModule.

        Parameters:
            path (str): Path to the energon dataset.
            seq_length (int, optional): The maximum sequence length for tokenized text. Defaults to 2048.
            micro_batch_size (int, optional): The per-GPU batch size. Defaults to 1.
            global_batch_size (int, optional): The global batch size across data-parallel ranks. Defaults to 8.
            num_workers (int, optional): Number of workers for training data loading. Defaults to 1.
            pin_memory (bool, optional): Whether to pin memory in the DataLoader. Defaults to True.
            task_encoder (DefaultTaskEncoder, optional): Encoder responsible for encoding and batching
                samples. Defaults to None.
            use_train_split_for_val (bool, optional): If True, validation reads from the 'train' split.
                Defaults to False.
            virtual_epoch_length (int, optional): Virtual epoch length passed to energon; set very large
                to avoid energon's end-of-epoch warning. Defaults to 1_000_000_000.
            packing_buffer_size (int | None, optional): Buffer size for energon sample packing.
                Defaults to None.
            max_samples_per_sequence (int | None, optional): Max samples per sequence passed to energon.
                Defaults to None.
        """
        # Tokenizer and image processor are intentionally None: the task encoder
        # owns all sample preprocessing for diffusion training.
        super().__init__(
            path=path,
            tokenizer=None,
            image_processor=None,
            seq_length=seq_length,
            micro_batch_size=micro_batch_size,
            global_batch_size=global_batch_size,
            num_workers=num_workers,
            pin_memory=pin_memory,
            task_encoder=task_encoder,
        )
        self.use_train_split_for_val = use_train_split_for_val
        self.virtual_epoch_length = virtual_epoch_length
        # Validation always uses a single loader worker, independent of num_workers.
        self.num_workers_val = 1
        self.packing_buffer_size = packing_buffer_size
        self.max_samples_per_sequence = max_samples_per_sequence

    def datasets_provider(self, worker_config, split: Literal['train', 'val'] = 'val'):
        """
        Provide the dataset for training or validation.

        This method retrieves the dataset for the specified split (either 'train' or 'val') and
        configures it according to the worker configuration. When ``use_train_split_for_val`` is set,
        the 'train' split is returned regardless of the requested split.

        Parameters:
            worker_config: Configuration for the data loader workers.
            split (Literal['train', 'val'], optional): The data split to retrieve. Defaults to 'val'.

        Returns:
            Dataset: The energon dataset configured for the specified split.

        Raises:
            ValueError: If ``split`` is not 'train' or 'val'.
        """
        if split not in {'train', 'val'}:
            raise ValueError("Invalid value for split. Allowed values are 'train' or 'val'.")
        if self.use_train_split_for_val:
            split = 'train'
        _dataset = get_train_dataset(
            self.path,
            batch_size=self.micro_batch_size,
            task_encoder=self.task_encoder,
            worker_config=worker_config,
            max_samples_per_sequence=self.max_samples_per_sequence,
            shuffle_buffer_size=None,  # no shuffle buffer; ordering is left to energon defaults
            split_part=split,
            virtual_epoch_length=self.virtual_epoch_length,
            packing_buffer_size=self.packing_buffer_size,
        )
        return _dataset

    def val_dataloader(self) -> EVAL_DATALOADERS:
        """
        Initialize and return the validation DataLoader.

        Delegates to ``train_dataloader`` when ``use_train_split_for_val`` is set, and caches the
        constructed loader in ``val_dataloader_object`` so repeated calls return the same object.
        When the Megatron parallel state is initialized, the worker config is derived from the
        data-parallel rank/world size; otherwise a default single-process worker config is used.

        Returns:
            EVAL_DATALOADERS: The (savable) DataLoader for the validation dataset.
        """
        if self.use_train_split_for_val:
            return self.train_dataloader()
        if self.val_dataloader_object:
            return self.val_dataloader_object

        if not parallel_state.is_initialized():
            # BUGFIX: message previously had a typo ("Muiltimodal") and reported
            # self.num_workers, but the config below actually uses num_workers_val.
            message = (
                "Multimodal val data loader parallel state is not initialized, "
                f"using default worker config with num_workers {self.num_workers_val}"
            )
            logging.info(message)
            worker_config = WorkerConfig.default_worker_config(self.num_workers_val)
        else:
            rank = parallel_state.get_data_parallel_rank()
            world_size = parallel_state.get_data_parallel_world_size()
            data_parallel_group = parallel_state.get_data_parallel_group()
            logging.info(f"rank {rank} world_size {world_size} data_parallel_group {data_parallel_group}")
            worker_config = WorkerConfig(
                rank=rank,
                world_size=world_size,
                num_workers=self.num_workers_val,
                data_parallel_group=data_parallel_group,
                worker_debug_path=None,
                worker_log_level=0,
            )
        val_dataset = self.datasets_provider(worker_config, split='val')
        energon_loader = get_savable_loader(val_dataset, worker_config=worker_config)
        self.val_dataloader_object = energon_loader
        return self.val_dataloader_object

    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
        """
        Load the state of the data module from a checkpoint.

        This method is called when loading a checkpoint. It restores the state of the data module,
        including the state of the dataloader and the number of consumed samples. Restoration is
        best-effort: a failure is logged as a warning rather than raised, so training can continue
        with a fresh data pipeline.

        Parameters:
            state_dict (Dict[str, Any]): The state dictionary containing the saved state of the
                data module.
        """
        try:
            super().load_state_dict(state_dict)
        except Exception as e:
            # Deliberately swallow: a stale/incompatible data state should not abort resume.
            logging.warning(f"datamodule.load_state_dict failed {e}")
|