| from __future__ import print_function |
|
|
| import torch |
| import torchvision.datasets as datasets |
| from torch.utils.data import Dataset |
| from PIL import Image |
| from .tsv_io import TSVFile |
| import numpy as np |
| import base64 |
| import io |
|
|
|
|
| class TSVDataset(Dataset): |
| """ TSV dataset for ImageNet 1K training |
| """ |
| def __init__(self, tsv_file, transform=None, target_transform=None): |
| self.tsv = TSVFile(tsv_file) |
| self.transform = transform |
| self.target_transform = target_transform |
|
|
| def __getitem__(self, index): |
| """ |
| Args: |
| index (int): Index |
| Returns: |
| tuple: (image, target) where target is class_index of the target class. |
| """ |
| row = self.tsv.seek(index) |
| image_data = base64.b64decode(row[-1]) |
| image = Image.open(io.BytesIO(image_data)) |
| image = image.convert('RGB') |
| target = int(row[1]) |
|
|
| if self.transform is not None: |
| img = self.transform(image) |
| else: |
| img = image |
| if self.target_transform is not None: |
| target = self.target_transform(target) |
|
|
| return img, target |
|
|
| def __len__(self): |
| return self.tsv.num_rows() |
|
|