Spaces:
Runtime error
Runtime error
| """ | |
| file - dataset.py | |
| Customized dataset class to loop through the AVA dataset and apply needed image augmentations for training. | |
| Copyright (C) Yunxiao Shi 2017 - 2021 | |
| NIMA is released under the MIT license. See LICENSE for the full license text. | |
| """ | |
| import os | |
| import pandas as pd | |
| from PIL import Image | |
| import torch | |
| from torch.utils import data | |
| import torchvision.transforms as transforms | |
class AVADataset(data.Dataset):
    """Dataset pairing AVA images with their rating distributions.

    Args:
        csv_file: an 11-column csv file; column one contains the names of the
            image files (without the ``.jpg`` extension), columns 2-11 contain
            the empirical distributions of ratings. It is read with
            ``pandas.read_csv`` defaults, so the first line is taken as the
            header row.
        root_dir: directory containing the images
        transform: optional callable applied to the loaded PIL image
            (preprocessing and augmentation of the training images)
    """

    def __init__(self, csv_file, root_dir, transform=None):
        self.annotations = pd.read_csv(csv_file)
        self.root_dir = root_dir
        self.transform = transform

    def __len__(self):
        """Return the number of rows (images) listed in the csv file."""
        return len(self.annotations)

    def __getitem__(self, idx):
        """Load one image and its rating distribution.

        Args:
            idx: integer row position; a single-element tensor is also
                accepted and converted to a Python number first.

        Returns:
            A dict with keys ``'img_id'`` (the full path of the image file),
            ``'image'`` (the RGB image, passed through ``transform`` when one
            was given) and ``'annotations'`` (float array of shape
            ``(number_of_rating_columns, 1)``).
        """
        # Samplers may hand over tensor indices; ``iloc`` wants a plain int.
        if torch.is_tensor(idx):
            idx = idx.item()

        img_name = os.path.join(
            self.root_dir, str(self.annotations.iloc[idx, 0]) + '.jpg')

        # ``Image.open`` is lazy and keeps the file open; the context manager
        # guarantees the handle is released even if decoding fails.
        # ``convert`` returns a new, fully loaded image that stays valid
        # after the source file has been closed.
        with Image.open(img_name) as img:
            image = img.convert('RGB')

        # Every column after the file name belongs to the rating distribution;
        # it is returned as a column vector of floats.
        annotations = self.annotations.iloc[idx, 1:].to_numpy()
        annotations = annotations.astype('float').reshape(-1, 1)

        sample = {'img_id': img_name, 'image': image, 'annotations': annotations}
        if self.transform:
            sample['image'] = self.transform(sample['image'])
        return sample
if __name__ == '__main__':
    # Sanity check: build the training pipeline and print the shape of every
    # batch of images and labels the loader produces.
    root = './data/images'
    csv_file = './data/train_labels.csv'
    train_transform = transforms.Compose([
        # ``Scale`` was deprecated and later removed from torchvision;
        # ``Resize`` is its drop-in replacement (an int size matches the
        # shorter image side to 256 while keeping the aspect ratio).
        transforms.Resize(256),
        transforms.RandomCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        # Channel statistics commonly used with ImageNet-pretrained models.
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225]),
    ])
    dset = AVADataset(csv_file=csv_file, root_dir=root, transform=train_transform)
    train_loader = data.DataLoader(dset, batch_size=4, shuffle=True, num_workers=4)
    # The loop variable is named ``batch`` so that it does not shadow the
    # imported ``torch.utils.data`` module bound to ``data``.
    for batch in train_loader:
        images = batch['image']
        print(images.size())
        labels = batch['annotations']
        print(labels.size())