Spaces:
Runtime error
Runtime error
File size: 1,212 Bytes
5106c86 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 | import os
import torch
from torch.utils.data import Dataset
from torchvision import datasets, transforms
class COCODataset(Dataset):
def __init__(self, root, transform=None):
"""
初始化 COCO 数据集
:param root: 存储数据集的路径
:param train: 是否使用训练集
:param transform: 应用于图像的转换
"""
super().__init__()
self.root = root
self.transform = transform
self.dataset = datasets.ImageFolder(root, transform=transform)
# self.dataset = datasets.CocoDetection(root=os.path.join(root, 'coco'),
# annFile=os.path.join(root,
# 'annotations/instances_train2017.json' if train else 'annotations/instances_val2017.json'),
# transform=self.transform)
def __getitem__(self, index):
"""
获取一个数据点和其标签
"""
image, _ = self.dataset[index]
return image
def __len__(self):
"""
返回数据集中的数据点数
"""
return len(self.dataset)
|