""" loaders.py Entry point for accessing dataset. Creates PyTorch DataLoaders for the COCO captioning subset. """ from torch.utils.data import DataLoader from data.coco_dataset import CocoCaptionDataset from data.collate import CocoCollator def get_coco_dataloaders( batch_size=16, image_size=224, num_workers=8, tokenizer_name="t5-small", max_caption_length=64, data_dir="data/processed", normalize=True ): # Build datasets train_ds = CocoCaptionDataset( split="train", image_size=image_size, tokenizer_name=tokenizer_name, max_caption_length=max_caption_length, data_dir=data_dir, random_caption=True, normalize=normalize, ) val_ds = CocoCaptionDataset( split="val", image_size=image_size, tokenizer_name=tokenizer_name, max_caption_length=max_caption_length, data_dir=data_dir, random_caption=False, normalize=normalize ) test_ds = CocoCaptionDataset( split="test", image_size=image_size, tokenizer_name=tokenizer_name, max_caption_length=max_caption_length, data_dir=data_dir, random_caption=False, normalize=normalize ) # Collator (padding for captions) pad_token_id = train_ds.tokenizer.pad_token_id collator = CocoCollator(pad_token_id=pad_token_id) # DataLoaders train_loader = DataLoader( train_ds, batch_size=batch_size, shuffle=True, num_workers=num_workers, collate_fn=collator, pin_memory=True, ) val_loader = DataLoader( val_ds, batch_size=batch_size, shuffle=False, num_workers=num_workers, collate_fn=collator, pin_memory=True, ) test_loader = DataLoader( test_ds, batch_size=batch_size, shuffle=False, num_workers=num_workers, collate_fn=collator, pin_memory=True, ) return train_loader, val_loader, test_loader