import os import torch import numpy as np from PIL import Image from pycocotools.coco import COCO import torchvision.transforms as T class LegoDataset(torch.utils.data.Dataset): def __init__(self, root, annFile, transforms=None): self.root = root self.coco = COCO(annFile) self.ids = list(self.coco.imgs.keys()) self.transforms = transforms or T.Compose([T.ToTensor()]) def __getitem__(self, index): img_id = self.ids[index] img_info = self.coco.loadImgs(img_id)[0] path = img_info["file_name"] img = Image.open(os.path.join(self.root, path)).convert("RGB") ann_ids = self.coco.getAnnIds(imgIds=img_id) annotations = self.coco.loadAnns(ann_ids) boxes = [] labels = [] masks = [] # Dummy masks for ann in annotations: xmin, ymin, width, height = ann["bbox"] boxes.append([xmin, ymin, xmin + width, ymin + height]) labels.append(1) # 'lego' is the only class, labeled as 1 # Dummy mask for Mask R-CNN, filled with zeros dummy_mask = np.zeros( (img_info["height"], img_info["width"]), dtype=np.uint8 ) masks.append(dummy_mask) boxes = torch.as_tensor(boxes, dtype=torch.float32) labels = torch.as_tensor(labels, dtype=torch.int64) masks = torch.as_tensor(np.array(masks), dtype=torch.uint8) target = { "boxes": boxes, "labels": labels, "masks": masks, "image_id": torch.tensor([img_id]), } if self.transforms: img = self.transforms(img) return img, target def __len__(self): return len(self.ids)