coco-demo / data /loaders.py
evanec's picture
Upload 6 files
41bb8f7 verified
"""
loaders.py
Entry point for accessing the dataset. Creates PyTorch DataLoaders for the COCO captioning subset.
"""
from torch.utils.data import DataLoader
from data.coco_dataset import CocoCaptionDataset
from data.collate import CocoCollator
def get_coco_dataloaders(
    batch_size=16,
    image_size=224,
    num_workers=8,
    tokenizer_name="t5-small",
    max_caption_length=64,
    data_dir="data/processed",
    normalize=True
):
    """Create the train, validation and test DataLoaders for COCO captions.

    One ``CocoCaptionDataset`` is built per split ("train", "val", "test"),
    in that order, and each is wrapped in a ``DataLoader`` that shares a
    single ``CocoCollator``.

    Args:
        batch_size: Batch size given to every DataLoader.
        image_size: Forwarded unchanged to ``CocoCaptionDataset``.
        num_workers: Worker count given to every DataLoader.
        tokenizer_name: Forwarded unchanged to ``CocoCaptionDataset``.
        max_caption_length: Forwarded unchanged to ``CocoCaptionDataset``.
        data_dir: Forwarded unchanged to ``CocoCaptionDataset``.
        normalize: Forwarded unchanged to ``CocoCaptionDataset``.

    Returns:
        A tuple ``(train_loader, val_loader, test_loader)``.
    """
    split_names = ("train", "val", "test")

    # Options that are the same for every split's dataset.
    dataset_options = dict(
        image_size=image_size,
        tokenizer_name=tokenizer_name,
        max_caption_length=max_caption_length,
        data_dir=data_dir,
        normalize=normalize,
    )

    # Only the training split is built with random_caption=True; the
    # validation and test splits are built with random_caption=False.
    datasets = {
        name: CocoCaptionDataset(
            split=name,
            random_caption=(name == "train"),
            **dataset_options,
        )
        for name in split_names
    }

    # The collator is given the pad token id of the training dataset's
    # tokenizer and is then reused for all three loaders.
    collator = CocoCollator(pad_token_id=datasets["train"].tokenizer.pad_token_id)

    # Options that are the same for every DataLoader.
    loader_options = dict(
        batch_size=batch_size,
        num_workers=num_workers,
        collate_fn=collator,
        pin_memory=True,
    )

    # Shuffle the training loader only; val/test keep dataset order.
    train_loader, val_loader, test_loader = (
        DataLoader(datasets[name], shuffle=(name == "train"), **loader_options)
        for name in split_names
    )
    return train_loader, val_loader, test_loader