StyleFusion / datasets.py
escapist413's picture
Upload existing project files
5106c86
import os
import torch
from torch.utils.data import Dataset
from torchvision import datasets, transforms
class COCODataset(Dataset):
    """Image-only dataset backed by ``torchvision.datasets.ImageFolder``.

    Each item is only the (optionally transformed) image. The class label that
    ``ImageFolder`` returns alongside it is discarded. Despite the name, nothing
    here is COCO-specific: any directory tree readable by ``ImageFolder`` works.
    """

    def __init__(self, root, transform=None):
        """
        Initialize the dataset.

        :param root: directory handed to ``ImageFolder``. ``ImageFolder`` treats
            each sub-directory of ``root`` as a class, so images must be placed
            in at least one sub-directory (e.g. ``root/train2017/*.jpg``).
        :param transform: optional callable applied to every loaded image;
            forwarded unchanged to ``ImageFolder``.
        """
        super().__init__()
        self.root = root
        self.transform = transform
        self.dataset = datasets.ImageFolder(root, transform=transform)

    def __getitem__(self, index):
        """
        Return the image at ``index``.

        The underlying ``ImageFolder`` yields ``(image, label)`` pairs; the
        label is dropped because only images are consumed by callers.
        """
        image, _ = self.dataset[index]
        return image

    def __len__(self):
        """Return the number of images in the dataset."""
        return len(self.dataset)