"""Inspect the first batch produced by ``causvid.data.TextDataset``.

Loads a caption file into ``TextDataset``, wraps it in a ``torch`` DataLoader
and prints the type of the first batch element, the batch length and the
first element itself, so the collated batch format can be checked by eye.
"""
from causvid.data import TextDataset
import torch


def main(caption_path="sample_dataset/captions_coco14_test.txt", batch_size=32):
    """Print a summary of the first batch built from ``caption_path``.

    Args:
        caption_path: Path of the caption file, passed unchanged to
            ``TextDataset``.
        batch_size: Number of dataset items per batch.
    """
    dataset = TextDataset(caption_path)
    # drop_last=True discards a trailing partial batch, so a dataset holding
    # fewer than ``batch_size`` items produces no batches at all.
    dataloader = torch.utils.data.DataLoader(
        dataset, batch_size=batch_size, shuffle=False, drop_last=True)

    # Only the first batch is inspected; None signals the loader was empty.
    batch = next(iter(dataloader), None)
    if batch is None:
        # Report explicitly instead of exiting silently, which would look
        # like a successful run.
        print(
            f"no full batch of size {batch_size} could be formed from "
            f"{caption_path} (drop_last=True); the dataset may be too small")
        return

    print(
        f"batch element type {type(batch[0])} batch length {len(batch)} batch first element {batch[0]}")


if __name__ == "__main__":
    main()