| import os |
| import glob |
| import random |
| from typing import Optional, Callable |
| from pathlib import Path |
|
|
| import torch |
| from torch.utils.data import DataLoader |
| from torch.utils.data.distributed import DistributedSampler |
|
|
|
|
class PretextDataLoader:
    """Thin wrapper around ``torch.utils.data.DataLoader`` with optional sharding.

    When ``world_size > 1`` the dataset is sharded across ranks with a
    ``DistributedSampler``; otherwise a plain ``DataLoader`` is built. The
    wrapper is iterable and sized, delegating both to the inner loader.

    Attributes:
        dataset: The dataset passed to the constructor.
        batch_size: Per-process batch size.
        num_workers: Number of worker processes requested.
        world_size: Total number of participating processes.
        global_rank: Rank of this process.
        sampler: The ``DistributedSampler`` in use, or ``None`` when
            ``world_size <= 1``.
        dataloader: The underlying ``DataLoader``.
    """

    def __init__(
        self,
        dataset,
        batch_size: int = 32,
        shuffle: bool = True,
        num_workers: int = 4,
        pin_memory: bool = True,
        persistent_workers: bool = True,
        prefetch_factor: int = 2,
        drop_last: bool = True,
        global_rank: int = 0,
        world_size: int = 1,
        collate_fn: Optional[Callable] = None,
    ):
        """Build the underlying ``DataLoader``.

        Args:
            dataset: Dataset to draw samples from.
            batch_size: Number of samples per batch on this process.
            shuffle: Whether to shuffle. In distributed mode the shuffling is
                delegated to the ``DistributedSampler``.
            num_workers: Number of loader worker processes (0 loads in the
                calling process).
            pin_memory: Forwarded to ``DataLoader``.
            persistent_workers: Forwarded to ``DataLoader`` only when
                ``num_workers > 0``.
            prefetch_factor: Forwarded to ``DataLoader`` only when
                ``num_workers > 0``.
            drop_last: Whether to drop the final incomplete batch.
            global_rank: Rank of this process among ``world_size`` replicas.
            world_size: Number of replicas; values above 1 enable sharding.
            collate_fn: Optional callable merging a list of samples into a
                batch; ``None`` keeps the ``DataLoader`` default.
        """
        self.dataset = dataset
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.world_size = world_size
        self.global_rank = global_rank

        # Keep a handle on the sampler so set_epoch() can reach it later.
        self.sampler = None
        if world_size > 1:
            self.sampler = DistributedSampler(
                dataset,
                num_replicas=world_size,
                rank=global_rank,
                shuffle=shuffle,
            )
            # DataLoader rejects shuffle=True together with a sampler.
            shuffle = False

        loader_kwargs = {
            "batch_size": batch_size,
            "shuffle": shuffle,
            "sampler": self.sampler,
            "num_workers": num_workers,
            "pin_memory": pin_memory,
            "drop_last": drop_last,
            "collate_fn": collate_fn,
        }
        # Worker-only options are omitted entirely when loading in-process so
        # the DataLoader falls back to its own defaults for them.
        if num_workers > 0:
            loader_kwargs["persistent_workers"] = persistent_workers
            loader_kwargs["prefetch_factor"] = prefetch_factor

        self.dataloader = DataLoader(dataset, **loader_kwargs)

    def set_epoch(self, epoch: int) -> None:
        """Forward the epoch number to the distributed sampler, if any.

        Call this at the start of every epoch so a shuffling
        ``DistributedSampler`` produces a new ordering. It is a no-op when no
        sampler is in use.
        """
        if self.sampler is not None:
            self.sampler.set_epoch(epoch)

    def __iter__(self):
        """Return a fresh iterator over the underlying ``DataLoader``."""
        return iter(self.dataloader)

    def __len__(self) -> int:
        """Return the number of batches the underlying ``DataLoader`` yields."""
        return len(self.dataloader)
|
|
|
|
class DataCollator:
    """Collate per-sample ``input_ids``/``labels`` tensors into batch tensors.

    Samples of identical shape are stacked. Samples of differing length are
    padded on the right with ``pad_token_id`` up to the longest sample in the
    batch. Every label position equal to ``pad_token_id`` is then replaced by
    ``label_pad_token_id``.
    """

    def __init__(
        self,
        pad_token_id: int = 0,
        label_pad_token_id: int = -100,
    ):
        """Store the padding ids.

        Args:
            pad_token_id: Id used to pad sequences and recognised as padding
                inside ``labels``.
            label_pad_token_id: Value written into ``labels`` wherever the
                label equals ``pad_token_id``.
        """
        self.pad_token_id = pad_token_id
        self.label_pad_token_id = label_pad_token_id

    def _stack_or_pad(self, tensors: list) -> torch.Tensor:
        """Stack same-shaped tensors, or right-pad ragged ones to a common length."""
        if len({tuple(t.shape) for t in tensors}) <= 1:
            # Uniform (or empty) batch: identical to a plain torch.stack call.
            return torch.stack(tensors)
        return torch.nn.utils.rnn.pad_sequence(
            tensors,
            batch_first=True,
            padding_value=self.pad_token_id,
        )

    def __call__(self, batch: list[dict]) -> dict:
        """Merge a list of samples into one batch.

        Args:
            batch: Samples, each a mapping with tensor values under the
                ``"input_ids"`` and ``"labels"`` keys.

        Returns:
            A dict with batched ``"input_ids"`` and ``"labels"`` tensors.
        """
        input_ids = self._stack_or_pad([item["input_ids"] for item in batch])
        labels = self._stack_or_pad([item["labels"] for item in batch])

        # Padding positions in the labels are swapped for the label pad value.
        labels = labels.masked_fill(labels == self.pad_token_id, self.label_pad_token_id)

        return {
            "input_ids": input_ids,
            "labels": labels,
        }
|
|
|
|
def create_distributed_dataloader(
    dataset,
    batch_size: int = 32,
    shuffle: bool = True,
    num_workers: int = 4,
    pin_memory: bool = True,
    persistent_workers: bool = True,
    prefetch_factor: int = 2,
    drop_last: bool = True,
    seed: int = 42,
) -> PretextDataLoader:
    """Build a ``PretextDataLoader`` configured from the process environment.

    The rank and world size are read from the ``RANK`` and ``WORLD_SIZE``
    environment variables, defaulting to 0 and 1 when unset.

    Side effect: seeds Python's ``random`` module and torch's global RNG with
    ``seed + rank``, so each rank gets a different random stream.

    Args:
        dataset: Dataset to draw samples from.
        batch_size: Per-process batch size.
        shuffle: Whether to shuffle the data.
        num_workers: Number of loader worker processes.
        pin_memory: Forwarded to the loader.
        persistent_workers: Forwarded to the loader.
        prefetch_factor: Forwarded to the loader.
        drop_last: Whether to drop the final incomplete batch.
        seed: Base seed; the process rank is added to it.

    Returns:
        The configured ``PretextDataLoader``.
    """
    rank = int(os.environ.get("RANK", 0))
    world_size = int(os.environ.get("WORLD_SIZE", 1))

    # Offset by rank so per-process randomness differs between ranks.
    random.seed(seed + rank)
    torch.manual_seed(seed + rank)

    return PretextDataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        num_workers=num_workers,
        pin_memory=pin_memory,
        persistent_workers=persistent_workers,
        prefetch_factor=prefetch_factor,
        drop_last=drop_last,
        global_rank=rank,
        world_size=world_size,
    )
|
|
|
|
def optimize_dataloader(
    dataloader: DataLoader,
    batch_size: int = 32,
    gradient_accumulation_steps: int = 1,
) -> dict:
    """Suggest a per-step batch size and matching accumulation step count.

    The requested effective batch size is
    ``batch_size * gradient_accumulation_steps``. It is handed to
    ``get_optimal_batch_size`` and the accumulation count is the floor
    quotient of the effective size by the returned per-step size.

    Args:
        dataloader: Accepted for interface purposes; not read by this function.
        batch_size: Requested per-step batch size.
        gradient_accumulation_steps: Requested number of accumulation steps.

    Returns:
        A dict with ``"batch_size"`` and ``"gradient_accumulation_steps"``.
    """
    effective_size = batch_size * gradient_accumulation_steps
    per_step_size = get_optimal_batch_size(effective_size)
    accumulation_steps = effective_size // per_step_size

    return dict(
        batch_size=per_step_size,
        gradient_accumulation_steps=accumulation_steps,
    )
|
|
|
|
def get_optimal_batch_size(target_batch_size: int) -> int:
    """Scale a target batch size down according to a coarse GPU memory estimate.

    Without CUDA the target is returned unchanged. With CUDA, the total memory
    of the current device (not the currently free memory) minus a fixed model
    footprint estimate is converted into a count of 4-byte values and compared
    against two fixed thresholds:

    * more than 350M values fit: the full target,
    * more than 175M values fit: half the target (at least 1),
    * otherwise: a quarter of the target (at least 1).

    Args:
        target_batch_size: Desired batch size.

    Returns:
        The target batch size, possibly reduced by a factor of 2 or 4.
    """
    if not torch.cuda.is_available():
        return target_batch_size

    reference_param_count = 350_000_000
    bytes_per_param = 4
    estimated_model_memory = reference_param_count * bytes_per_param

    # Query the device this process is actually using rather than always GPU 0.
    device_index = torch.cuda.current_device()
    total_memory = torch.cuda.get_device_properties(device_index).total_memory

    available_memory = total_memory - estimated_model_memory
    max_params_in_memory = available_memory // bytes_per_param

    if max_params_in_memory > reference_param_count:
        return target_batch_size
    if max_params_in_memory > reference_param_count // 2:
        return max(1, target_batch_size // 2)
    return max(1, target_batch_size // 4)