# codsworth-3.8m/codsworth/train/dataloader.py
# Uploaded by: Jaqshanahan
# Commit message: "Initial upload of Codsworth model"
# Commit: b84d85a (verified)
import os
import glob
import random
from typing import Optional, Callable
from pathlib import Path
import torch
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler
class PretextDataLoader:
    """Prefetching DataLoader wrapper for Codsworth training.

    Wraps ``torch.utils.data.DataLoader``. When ``world_size > 1`` a
    ``DistributedSampler`` is installed so each rank reads its own shard.
    In that case the DataLoader itself is built with ``shuffle=False``,
    because shuffling is delegated to the sampler.

    Args:
        dataset: Map-style dataset passed straight to DataLoader.
        batch_size: Samples per batch (per rank).
        shuffle: Whether to shuffle. Handled by the sampler in distributed mode.
        num_workers: Worker processes. ``persistent_workers`` and
            ``prefetch_factor`` are only forwarded when this is > 0.
        pin_memory: Forwarded to DataLoader.
        persistent_workers: Forwarded to DataLoader when ``num_workers > 0``.
        prefetch_factor: Forwarded to DataLoader when ``num_workers > 0``.
        drop_last: Forwarded to DataLoader.
        global_rank: Rank of this process (used only when ``world_size > 1``).
        world_size: Number of participating processes.
        collate_fn: Optional batch collation callable (e.g. a ``DataCollator``).
            ``None`` keeps DataLoader's default collation.
        seed: Seed for the ``DistributedSampler`` shuffle. Must be identical
            on every rank so the shards stay disjoint. The default of 0 matches
            DistributedSampler's own default.
    """

    def __init__(
        self,
        dataset,
        batch_size: int = 32,
        shuffle: bool = True,
        num_workers: int = 4,
        pin_memory: bool = True,
        persistent_workers: bool = True,
        prefetch_factor: int = 2,
        drop_last: bool = True,
        global_rank: int = 0,
        world_size: int = 1,
        collate_fn: Optional[Callable] = None,
        seed: int = 0,
    ):
        self.dataset = dataset
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.world_size = world_size
        self.global_rank = global_rank

        # Kept on the instance so callers can advance the epoch (see set_epoch).
        self.sampler: Optional[DistributedSampler] = None
        if world_size > 1:
            self.sampler = DistributedSampler(
                dataset,
                num_replicas=world_size,
                rank=global_rank,
                shuffle=shuffle,
                seed=seed,
            )
            # DataLoader rejects shuffle=True together with an explicit sampler.
            shuffle = False

        use_workers = num_workers > 0
        self.dataloader = DataLoader(
            dataset,
            batch_size=batch_size,
            shuffle=shuffle,
            sampler=self.sampler,
            num_workers=num_workers,
            pin_memory=pin_memory,
            collate_fn=collate_fn,
            # These options are only valid with multiprocessing workers.
            persistent_workers=persistent_workers if use_workers else False,
            prefetch_factor=prefetch_factor if use_workers else None,
            drop_last=drop_last,
        )

    def set_epoch(self, epoch: int) -> None:
        """Set the epoch on the distributed sampler; a no-op otherwise.

        DistributedSampler derives its permutation from ``seed + epoch``.
        Call this at the start of each epoch, otherwise every epoch
        replays the same order.
        """
        if self.sampler is not None:
            self.sampler.set_epoch(epoch)

    def __iter__(self):
        return iter(self.dataloader)

    def __len__(self) -> int:
        return len(self.dataloader)
class DataCollator:
    """Collate a list of examples into ``input_ids`` / ``labels`` batch tensors.

    Each example is a dict with tensor values under ``"input_ids"`` and
    ``"labels"``.

    - Equal shapes within a field: the tensors are stacked unchanged.
    - Different lengths (dim 0): the tensors are right-padded. ``input_ids``
      are padded with ``pad_token_id`` and ``labels`` with
      ``label_pad_token_id``.

    After batching, every label equal to ``pad_token_id`` is replaced with
    ``label_pad_token_id``. Other keys in the examples are dropped.

    NOTE: because masking is value-based, a genuine label token whose id equals
    ``pad_token_id`` is masked too. Make sure ``pad_token_id`` is a dedicated
    padding id.
    """

    def __init__(
        self,
        pad_token_id: int = 0,
        label_pad_token_id: int = -100,
    ):
        self.pad_token_id = pad_token_id
        self.label_pad_token_id = label_pad_token_id

    @staticmethod
    def _batch(tensors: list, padding_value: int) -> torch.Tensor:
        """Stack equal-shaped tensors; right-pad along dim 0 when shapes differ."""
        from torch.nn.utils.rnn import pad_sequence

        # An empty list falls through to torch.stack, which raises as before.
        if all(t.shape == tensors[0].shape for t in tensors[1:]):
            return torch.stack(tensors)
        return pad_sequence(tensors, batch_first=True, padding_value=padding_value)

    def __call__(self, batch: list[dict]) -> dict:
        input_ids = self._batch([item["input_ids"] for item in batch], self.pad_token_id)
        labels = self._batch([item["labels"] for item in batch], self.label_pad_token_id)
        # Exclude padding positions from the loss.
        labels = labels.masked_fill(labels == self.pad_token_id, self.label_pad_token_id)
        return {
            "input_ids": input_ids,
            "labels": labels,
        }
def create_distributed_dataloader(
    dataset,
    batch_size: int = 32,
    shuffle: bool = True,
    num_workers: int = 4,
    pin_memory: bool = True,
    persistent_workers: bool = True,
    prefetch_factor: int = 2,
    drop_last: bool = True,
    seed: int = 42,
) -> PretextDataLoader:
    """Build a PretextDataLoader configured from the ``RANK``/``WORLD_SIZE`` env vars.

    Missing env vars default to a single process (rank 0, world size 1).
    Python's ``random`` and torch's global RNG are seeded with
    ``seed + rank``.

    In distributed mode the same ``seed``, not ``seed + rank``, is also
    applied to the DistributedSampler. All ranks must share one sampler seed
    so their shards stay disjoint.

    Returns:
        The configured PretextDataLoader.
    """
    rank = int(os.environ.get("RANK", 0))
    world_size = int(os.environ.get("WORLD_SIZE", 1))

    random.seed(seed + rank)
    torch.manual_seed(seed + rank)

    loader = PretextDataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        num_workers=num_workers,
        pin_memory=pin_memory,
        persistent_workers=persistent_workers,
        prefetch_factor=prefetch_factor,
        drop_last=drop_last,
        global_rank=rank,
        world_size=world_size,
    )

    # DistributedSampler permutes with its own `seed + epoch` generator, not
    # the global RNG seeded above. Without this, `seed` would have no effect
    # on the distributed shuffle order.
    sampler = getattr(loader.dataloader, "sampler", None)
    if isinstance(sampler, DistributedSampler):
        sampler.seed = seed
    return loader
def optimize_dataloader(
    dataloader: DataLoader,
    batch_size: int = 32,
    gradient_accumulation_steps: int = 1,
) -> dict:
    """Re-split an effective batch into a micro-batch size and accumulation steps.

    The effective batch is ``batch_size * gradient_accumulation_steps``.
    ``get_optimal_batch_size`` proposes a micro-batch size. That size is
    rounded down to a divisor of the effective batch, so that
    ``batch_size * gradient_accumulation_steps`` in the result equals the
    effective batch exactly.

    Args:
        dataloader: Currently not consulted; kept for interface compatibility.
        batch_size: Current per-step batch size (>= 1).
        gradient_accumulation_steps: Current accumulation steps (>= 1).

    Returns:
        ``{"batch_size": int, "gradient_accumulation_steps": int}``.

    Raises:
        ValueError: If ``batch_size`` or ``gradient_accumulation_steps`` is < 1.
    """
    if batch_size < 1 or gradient_accumulation_steps < 1:
        raise ValueError(
            "batch_size and gradient_accumulation_steps must both be >= 1, "
            f"got {batch_size} and {gradient_accumulation_steps}"
        )
    total_batch_size = batch_size * gradient_accumulation_steps
    micro_batch = max(1, min(get_optimal_batch_size(total_batch_size), total_batch_size))
    # Floor division alone silently drops samples (e.g. 7 -> 3 * 2 = 6), so
    # step down to the largest divisor of the effective batch.
    while total_batch_size % micro_batch:
        micro_batch -= 1
    return {
        "batch_size": micro_batch,
        "gradient_accumulation_steps": total_batch_size // micro_batch,
    }
def get_optimal_batch_size(
    target_batch_size: int,
    model_params: int = 350_000_000,
    bytes_per_param: int = 4,
    device: int = 0,
) -> int:
    """Heuristically shrink ``target_batch_size`` based on GPU memory capacity.

    Without CUDA the target is returned unchanged. Otherwise the estimated
    weight memory (``model_params * bytes_per_param``) is subtracted from the
    device's *total* memory, and the remainder is converted back to a
    parameter count. That count selects the batch size:

    - more than ``model_params``: the full target;
    - more than ``model_params // 2``: half the target (at least 1);
    - otherwise: a quarter of the target (at least 1).

    The defaults reproduce the original hard-coded 350M-parameter, fp32 (4-byte),
    device-0 behaviour.

    Args:
        target_batch_size: Desired batch size.
        model_params: Estimated number of model parameters.
        bytes_per_param: Bytes per parameter used in the estimate.
        device: CUDA device index to inspect.

    Returns:
        The suggested batch size.
    """
    if not torch.cuda.is_available():
        return target_batch_size
    # NOTE: this is total device capacity, not currently free memory; memory
    # already held by other allocations is not accounted for.
    total_memory = torch.cuda.get_device_properties(device).total_memory
    available_memory = total_memory - model_params * bytes_per_param
    max_params_in_memory = available_memory // bytes_per_param
    if max_params_in_memory > model_params:
        return target_batch_size
    if max_params_in_memory > model_params // 2:
        return max(1, target_batch_size // 2)
    return max(1, target_batch_size // 4)