Spaces:
Sleeping
Sleeping
| import os | |
| import torch | |
| import numpy as np | |
| from PIL import Image | |
| from pycocotools.coco import COCO | |
| import torchvision.transforms as T | |
class LegoDataset(torch.utils.data.Dataset):
    """COCO-annotated detection dataset with a single foreground class.

    Each item is an ``(image, target)`` pair. ``target`` is a dict with the
    keys ``"boxes"``, ``"labels"``, ``"masks"`` and ``"image_id"``. Every
    annotation is labelled ``1`` ('lego'), and the masks are all-zero
    placeholders: the annotations' segmentation data is never read.

    Args:
        root: Directory that the ``file_name`` of each image record is
            joined onto.
        annFile: Path of the COCO annotation file handed to ``COCO``.
        transforms: Optional callable applied to the RGB PIL image. When
            omitted (or falsy), the image is converted with ``T.ToTensor()``.
    """

    def __init__(self, root, annFile, transforms=None):
        self.root = root
        self.coco = COCO(annFile)
        self.ids = list(self.coco.imgs.keys())
        self.transforms = transforms or T.Compose([T.ToTensor()])

    @staticmethod
    def _build_target(annotations, height, width, img_id):
        """Assemble the target dict for one image.

        Args:
            annotations: Iterable of annotation dicts, each with a ``"bbox"``
                entry of the form ``[xmin, ymin, width, height]``.
            height: Height used for the placeholder masks.
            width: Width used for the placeholder masks.
            img_id: Image identifier stored under ``"image_id"``.

        Returns:
            Dict with ``"boxes"`` (float32, N x 4, corner format
            ``[xmin, ymin, xmax, ymax]``), ``"labels"`` (int64, N, all ones),
            ``"masks"`` (uint8, N x height x width, all zeros) and
            ``"image_id"`` (one-element tensor).
        """
        boxes = [
            [xmin, ymin, xmin + box_w, ymin + box_h]
            for xmin, ymin, box_w, box_h in (ann["bbox"] for ann in annotations)
        ]
        count = len(boxes)
        return {
            # reshape keeps the (N, 4) layout even when there are no boxes;
            # a bare as_tensor([]) would collapse to shape (0,).
            "boxes": torch.as_tensor(boxes, dtype=torch.float32).reshape(-1, 4),
            # 'lego' is the only class, labeled as 1.
            "labels": torch.ones((count,), dtype=torch.int64),
            # Dummy masks for Mask R-CNN, filled with zeros; allocated in one
            # call so the shape stays (N, H, W) when N == 0.
            "masks": torch.zeros((count, height, width), dtype=torch.uint8),
            "image_id": torch.tensor([img_id]),
        }

    def __getitem__(self, index):
        """Load the image at ``index`` and build its detection target.

        Returns:
            Tuple ``(img, target)``: the (transformed) image and the dict
            described in ``_build_target``.
        """
        img_id = self.ids[index]
        img_info = self.coco.loadImgs(img_id)[0]
        img = Image.open(
            os.path.join(self.root, img_info["file_name"])
        ).convert("RGB")

        annotations = self.coco.loadAnns(self.coco.getAnnIds(imgIds=img_id))

        # Prefer the size recorded in the annotation file (as before); fall
        # back to the decoded image when the record omits it.
        height = img_info.get("height", img.height)
        width = img_info.get("width", img.width)
        target = self._build_target(annotations, height, width, img_id)

        if self.transforms is not None:
            img = self.transforms(img)
        return img, target

    def __len__(self):
        """Return the number of images listed in the annotation file."""
        return len(self.ids)