File size: 1,212 Bytes
5106c86
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
import os
import torch
from torch.utils.data import Dataset
from torchvision import datasets, transforms


class COCODataset(Dataset):
    def __init__(self, root, transform=None):
        """
        初始化 COCO 数据集
        :param root: 存储数据集的路径
        :param train: 是否使用训练集
        :param transform: 应用于图像的转换
        """
        super().__init__()
        self.root = root
        self.transform = transform
        self.dataset = datasets.ImageFolder(root, transform=transform)
        # self.dataset = datasets.CocoDetection(root=os.path.join(root, 'coco'),
        #                                       annFile=os.path.join(root,
        #                                                            'annotations/instances_train2017.json' if train else 'annotations/instances_val2017.json'),
        #                                       transform=self.transform)

    def __getitem__(self, index):
        """
        获取一个数据点和其标签
        """
        image, _ = self.dataset[index]
        return image

    def __len__(self):
        """
        返回数据集中的数据点数
        """
        return len(self.dataset)