"""Data-loading helpers: a DataLoader wrapper, a batch collator and batch-size heuristics."""

import os
import glob
import random
from typing import Optional, Callable
from pathlib import Path

import torch
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler

# Heuristic constants used by get_optimal_batch_size: the model is assumed to
# hold this many parameters, each stored in this many bytes.
_ESTIMATED_MODEL_PARAMS = 350_000_000
_BYTES_PER_PARAM = 4


class PretextDataLoader:
    """Efficient dataloader with prefetching for Codsworth training.

    Wraps a ``torch.utils.data.DataLoader``. When ``world_size > 1`` a
    ``DistributedSampler`` is attached so each rank iterates its own shard;
    shuffling is then delegated to the sampler.

    Args:
        dataset: Dataset handed to the underlying ``DataLoader``.
        batch_size: Number of samples per batch.
        shuffle: Whether to shuffle (by the sampler when distributed).
        num_workers: Number of loader worker processes.
        pin_memory: Forwarded to ``DataLoader``.
        persistent_workers: Forwarded to ``DataLoader`` when ``num_workers > 0``.
        prefetch_factor: Forwarded to ``DataLoader`` when ``num_workers > 0``.
        drop_last: Whether to drop the final incomplete batch.
        global_rank: Rank of this process, used by the distributed sampler.
        world_size: Number of participating processes.
        seed: Shuffle seed for the distributed sampler. Must be identical on
            every rank. Ignored when ``world_size <= 1``.

    Attributes:
        sampler: The ``DistributedSampler`` in use, or ``None`` when not
            distributed.
        dataloader: The wrapped ``DataLoader``.
    """

    def __init__(
        self,
        dataset,
        batch_size: int = 32,
        shuffle: bool = True,
        num_workers: int = 4,
        pin_memory: bool = True,
        persistent_workers: bool = True,
        prefetch_factor: int = 2,
        drop_last: bool = True,
        global_rank: int = 0,
        world_size: int = 1,
        seed: int = 0,
    ):
        self.dataset = dataset
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.world_size = world_size
        self.global_rank = global_rank

        self.sampler: Optional[DistributedSampler] = None
        if world_size > 1:
            self.sampler = DistributedSampler(
                dataset,
                num_replicas=world_size,
                rank=global_rank,
                shuffle=shuffle,
                seed=seed,
            )
            # DataLoader rejects shuffle=True together with an explicit sampler.
            shuffle = False

        # Worker-only options are passed only when workers exist, so the
        # DataLoader defaults apply otherwise (some PyTorch versions reject an
        # explicit prefetch_factor when num_workers == 0).
        worker_options = {}
        if num_workers > 0:
            worker_options["persistent_workers"] = persistent_workers
            worker_options["prefetch_factor"] = prefetch_factor

        self.dataloader = DataLoader(
            dataset,
            batch_size=batch_size,
            shuffle=shuffle,
            sampler=self.sampler,
            num_workers=num_workers,
            pin_memory=pin_memory,
            drop_last=drop_last,
            **worker_options,
        )

    def set_epoch(self, epoch: int) -> None:
        """Tell the distributed sampler which epoch is starting.

        ``DistributedSampler`` derives its shuffle order from ``seed + epoch``;
        without this call every epoch iterates in the same order. This is a
        no-op when no distributed sampler is attached.
        """
        if self.sampler is not None:
            self.sampler.set_epoch(epoch)

    def __iter__(self):
        return iter(self.dataloader)

    def __len__(self) -> int:
        return len(self.dataloader)


class DataCollator:
    """Collator for batching training data.

    Stacks the ``"input_ids"`` and ``"labels"`` tensors of each item and
    replaces label positions equal to ``pad_token_id`` with
    ``label_pad_token_id``.
    """

    def __init__(
        self,
        pad_token_id: int = 0,
        label_pad_token_id: int = -100,
    ):
        self.pad_token_id = pad_token_id
        self.label_pad_token_id = label_pad_token_id

    def __call__(self, batch: list[dict]) -> dict:
        """Collate ``batch`` into a dict with ``"input_ids"`` and ``"labels"``.

        ``torch.stack`` is used, so the per-item tensors must share a shape.
        """
        input_ids = torch.stack([item["input_ids"] for item in batch])
        labels = torch.stack([item["labels"] for item in batch])

        labels = labels.masked_fill(labels == self.pad_token_id, self.label_pad_token_id)

        return {
            "input_ids": input_ids,
            "labels": labels,
        }


def create_distributed_dataloader(
    dataset,
    batch_size: int = 32,
    shuffle: bool = True,
    num_workers: int = 4,
    pin_memory: bool = True,
    persistent_workers: bool = True,
    prefetch_factor: int = 2,
    drop_last: bool = True,
    seed: int = 42,
) -> PretextDataLoader:
    """Build a ``PretextDataLoader`` configured from the process environment.

    The rank and world size are read from the ``RANK`` and ``WORLD_SIZE``
    environment variables (defaulting to 0 and 1). Python's ``random`` and
    torch's global RNG are seeded with ``seed + rank``; the distributed
    sampler receives ``seed`` unchanged so all ranks agree on the shuffle
    order.

    Returns:
        The configured ``PretextDataLoader``.
    """
    rank = int(os.environ.get("RANK", 0))
    world_size = int(os.environ.get("WORLD_SIZE", 1))

    # Per-rank offset gives each process its own global RNG stream.
    random.seed(seed + rank)
    torch.manual_seed(seed + rank)

    return PretextDataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        num_workers=num_workers,
        pin_memory=pin_memory,
        persistent_workers=persistent_workers,
        prefetch_factor=prefetch_factor,
        drop_last=drop_last,
        global_rank=rank,
        world_size=world_size,
        seed=seed,
    )


def optimize_dataloader(
    dataloader: DataLoader,
    batch_size: int = 32,
    gradient_accumulation_steps: int = 1,
) -> dict:
    """Suggest a batch size and accumulation step count.

    The effective batch size ``batch_size * gradient_accumulation_steps`` is
    passed to ``get_optimal_batch_size``; the accumulation step count is the
    effective batch size floor-divided by the suggested batch size.

    Args:
        dataloader: Currently unused; kept for interface compatibility.
        batch_size: Per-step batch size.
        gradient_accumulation_steps: Number of accumulation steps.

    Returns:
        A dict with keys ``"batch_size"`` and ``"gradient_accumulation_steps"``.

    Raises:
        ValueError: If the effective batch size is smaller than 1.
    """
    total_batch_size = batch_size * gradient_accumulation_steps
    if total_batch_size < 1:
        raise ValueError(
            "batch_size * gradient_accumulation_steps must be at least 1, "
            f"got {total_batch_size}"
        )

    optimal_batch_size = get_optimal_batch_size(total_batch_size)

    return {
        "batch_size": optimal_batch_size,
        "gradient_accumulation_steps": total_batch_size // optimal_batch_size,
    }


def get_optimal_batch_size(target_batch_size: int) -> int:
    """Scale ``target_batch_size`` down according to GPU memory headroom.

    Without CUDA the target is returned unchanged. Otherwise the total memory
    of the current CUDA device, minus an estimated model footprint, is
    converted to a parameter count and compared with fixed thresholds: the
    target is kept, halved or quartered (never below 1).
    """
    if not torch.cuda.is_available():
        return target_batch_size

    # Query the device this process is using; it is device 0 unless the
    # caller selected another one.
    device_index = torch.cuda.current_device()
    total_memory = torch.cuda.get_device_properties(device_index).total_memory

    estimated_model_memory = _ESTIMATED_MODEL_PARAMS * _BYTES_PER_PARAM
    available_memory = total_memory - estimated_model_memory
    max_params_in_memory = available_memory // _BYTES_PER_PARAM

    if max_params_in_memory > _ESTIMATED_MODEL_PARAMS:
        return target_batch_size
    if max_params_in_memory > _ESTIMATED_MODEL_PARAMS // 2:
        return max(1, target_batch_size // 2)
    return max(1, target_batch_size // 4)