Spaces:
Runtime error
Runtime error
| import os | |
| import torch | |
| from torch.utils.data import Dataset | |
| from torchvision import datasets, transforms | |
class COCODataset(Dataset):
    """Image-only dataset that wraps ``torchvision.datasets.ImageFolder``.

    The labels that ``ImageFolder`` derives from sub-directory names are
    discarded; each item is just the (optionally transformed) image. No
    annotation files are read.

    NOTE: ``ImageFolder`` expects ``root`` to contain one sub-directory per
    class (``root/<subdir>/<image>``). A directory of images with no
    sub-folders must be nested inside a parent directory that is passed as
    ``root``. Sub-directories under ``root`` that hold no valid images (for
    example an annotations folder) may make ``ImageFolder`` raise; confirm
    against the installed torchvision version.
    """

    def __init__(self, root, transform=None):
        """Index the images found under ``root``.

        :param root: Path of the directory holding the images, laid out as
            ``root/<subdir>/<image>`` as ``ImageFolder`` requires.
        :param transform: Optional callable applied to each loaded image.
            It is forwarded unchanged to ``ImageFolder``.
        """
        super().__init__()
        self.root = root
        self.transform = transform
        # The transform is applied inside ImageFolder, so __getitem__ only
        # has to strip the label from each (image, label) pair.
        self.dataset = datasets.ImageFolder(root, transform=transform)

    def __getitem__(self, index: int):
        """Return the image at ``index``, discarding its ImageFolder label."""
        image, _ = self.dataset[index]
        return image

    def __len__(self) -> int:
        """Return the number of images in the dataset."""
        return len(self.dataset)