| import options as opt |
| import matplotlib.pyplot as plt |
| import torch.optim as optim |
| import numpy as np |
| import time |
|
|
| from dataset import GridDataset |
| from torch.utils.data import DataLoader |
|
|
|
|
def dataset2dataloader(dataset, num_workers=opt.num_workers, shuffle=True):
    """Wrap *dataset* in a ``DataLoader`` configured from the options module.

    Args:
        dataset: Dataset object handed to ``DataLoader`` unchanged.
        num_workers: Worker count forwarded to ``DataLoader``. The default is
            read from ``opt.num_workers`` once, when this function is defined.
        shuffle: Forwarded to ``DataLoader`` as its ``shuffle`` flag.

    Returns:
        A ``DataLoader`` over *dataset* with ``batch_size=opt.batch_size`` and
        ``drop_last=False``.
    """
    # batch_size is looked up at call time, unlike the num_workers default.
    loader_kwargs = {
        "batch_size": opt.batch_size,
        "shuffle": shuffle,
        "num_workers": num_workers,
        "drop_last": False,
    }
    return DataLoader(dataset, **loader_kwargs)
|
|
|
|
# Build the training dataset; every path and padding value is taken from the
# options module, and the file list is the training list.
dataset = GridDataset(
    phase='train',
    file_list=opt.train_list,
    # Input locations.
    video_path=opt.video_path,
    image_dir=opt.images_dir,
    alignments_dir=opt.alignments_dir,
    # Padding settings for the video and text sides.
    vid_pad=opt.vid_padding,
    txt_pad=opt.txt_padding,
)

# Module-level loader over the training dataset, using the helper's defaults.
loader = dataset2dataloader(dataset)
|
|
|
|
def fetch_samples(num_samples=10, source=None):
    """Collect up to *num_samples* leading items from a loader.

    Args:
        num_samples: Maximum number of items to collect. Zero or a negative
            value returns an empty list without starting iteration.
        source: Iterable to draw items from. When omitted, the module-level
            ``loader`` is used.

    Returns:
        A list holding the first ``num_samples`` items yielded by the
        iterable, or fewer if it is exhausted first.
    """
    samples = []

    # Nothing requested: return before touching the iterable, so a request
    # for zero items neither starts iteration nor yields a stray item.
    if num_samples <= 0:
        return samples

    if source is None:
        source = loader

    for item in source:
        samples.append(item)
        # Check immediately after appending so that no extra item is pulled
        # from the iterable once the limit has been reached.
        if len(samples) >= num_samples:
            break

    return samples
|
|
|
|
if __name__ == '__main__':
    # Run the smoke test only when executed as a script. With worker
    # processes started via the "spawn" method, each worker re-imports this
    # module; without the guard it would try to iterate the loader itself.
    samples = fetch_samples()
    print(samples[0])
    print('END')