|
|
""" |
|
|
loaders.py |
|
|
|
|
|
Entry point for accessing the dataset. Creates PyTorch DataLoaders for the COCO captioning subset. |
|
|
|
|
|
""" |
|
|
|
|
|
from torch.utils.data import DataLoader |
|
|
|
|
|
from data.coco_dataset import CocoCaptionDataset |
|
|
from data.collate import CocoCollator |
|
|
|
|
|
|
|
|
def get_coco_dataloaders(
    batch_size=16,
    image_size=224,
    num_workers=8,
    tokenizer_name="t5-small",
    max_caption_length=64,
    data_dir="data/processed",
    normalize=True
):
    """Build the train, validation and test DataLoaders for the COCO captions.

    One ``CocoCaptionDataset`` is created per split ("train", "val", "test")
    and each is wrapped in a ``DataLoader`` that batches with a shared
    ``CocoCollator``.

    Args:
        batch_size: Batch size handed to every DataLoader.
        image_size: Forwarded unchanged to ``CocoCaptionDataset``.
        num_workers: Worker count handed to every DataLoader.
        tokenizer_name: Forwarded unchanged to ``CocoCaptionDataset``.
        max_caption_length: Forwarded unchanged to ``CocoCaptionDataset``.
        data_dir: Forwarded unchanged to ``CocoCaptionDataset``.
        normalize: Forwarded unchanged to ``CocoCaptionDataset``.

    Returns:
        A ``(train_loader, val_loader, test_loader)`` tuple.
    """
    # Options that are identical for all three splits; only the split name
    # and the random_caption flag vary per dataset.
    dataset_options = dict(
        image_size=image_size,
        tokenizer_name=tokenizer_name,
        max_caption_length=max_caption_length,
        data_dir=data_dir,
        normalize=normalize,
    )

    # random_caption is enabled for the training split only.
    # NOTE(review): presumably this samples one of several reference captions
    # per image -- confirm against CocoCaptionDataset.
    train_set = CocoCaptionDataset(
        split="train", random_caption=True, **dataset_options
    )
    val_set = CocoCaptionDataset(
        split="val", random_caption=False, **dataset_options
    )
    test_set = CocoCaptionDataset(
        split="test", random_caption=False, **dataset_options
    )

    # The padding id is read from the training dataset's tokenizer and the
    # resulting collator is reused by all three loaders.
    batch_collator = CocoCollator(
        pad_token_id=train_set.tokenizer.pad_token_id
    )

    # Loader settings common to every split; shuffling is decided per split.
    loader_options = dict(
        batch_size=batch_size,
        num_workers=num_workers,
        collate_fn=batch_collator,
        pin_memory=True,
    )

    # Only the training loader shuffles; val/test keep dataset order.
    train_loader = DataLoader(train_set, shuffle=True, **loader_options)
    val_loader = DataLoader(val_set, shuffle=False, **loader_options)
    test_loader = DataLoader(test_set, shuffle=False, **loader_options)

    return train_loader, val_loader, test_loader
|
|
|