import os

import albumentations as A
import cv2
import numpy as np
import torch
from albumentations.pytorch import ToTensorV2
from torch.utils.data import Dataset

# Channel statistics used by A.Normalize in both pipelines below.
_NORM_MEAN = (0.485, 0.456, 0.406)
_NORM_STD = (0.229, 0.224, 0.225)
# _NORM_MEAN expressed in 0-255 pixel units, so padded image pixels land at
# roughly zero once A.Normalize has been applied.
_PAD_FILL = (123.675, 116.28, 103.53)


class CocoSegmentationDataset(Dataset):
    """Segmentation dataset that builds one merged mask per COCO image.

    Every item is an ``(image, mask)`` pair. The mask is the element-wise
    maximum of ``coco.annToMask`` over all annotations of the selected
    categories on that image.
    """

    def __init__(self, coco, image_folder, category_name=None, transform=None):
        """Select the categories and images the dataset will serve.

        Args:
            coco: Annotation index exposing the pycocotools-style methods
                ``getCatIds``, ``getImgIds``, ``getAnnIds``, ``loadImgs``,
                ``loadAnns`` and ``annToMask``.
            image_folder: Directory that the ``file_name`` of each image
                record is joined onto.
            category_name: Name of the single category to segment. When
                falsy, every category and every image is used.
            transform: Optional callable invoked as
                ``transform(image=..., mask=...)`` that returns a mapping
                with ``'image'`` and ``'mask'`` keys.

        Raises:
            ValueError: If ``category_name`` is given but matches no category.
        """
        self.coco = coco
        self.image_folder = image_folder
        self.transform = transform

        if category_name:
            self.cat_ids = self.coco.getCatIds(catNms=[category_name])
            if not self.cat_ids:
                # With pycocotools an empty id list means "no filter", so a
                # misspelled name would silently select every image and
                # every category instead of failing.
                raise ValueError(
                    f"Category {category_name!r} not found in the COCO annotations"
                )
            self.img_ids = self.coco.getImgIds(catIds=self.cat_ids)
        else:
            # Use all categories and all images if no specific category is provided
            self.cat_ids = self.coco.getCatIds()
            self.img_ids = self.coco.getImgIds()

    def __len__(self) -> int:
        """Return the number of selected images."""
        return len(self.img_ids)

    def __getitem__(self, index):
        """Load one image and its merged segmentation mask.

        Args:
            index: Position within the selected image ids.

        Returns:
            A tuple ``(image, mask)``. ``image`` is whatever ``transform``
            produces, or the RGB array read from disk when there is no
            transform. ``mask`` is a float32 tensor; a 2-D mask gains a
            leading channel dimension.

        Raises:
            FileNotFoundError: If the image file cannot be read.
        """
        img_id = self.img_ids[index]
        img_info = self.coco.loadImgs(img_id)[0]
        img_path = os.path.join(self.image_folder, img_info['file_name'])

        # cv2.imread signals failure by returning None rather than raising.
        image = cv2.imread(img_path)
        if image is None:
            raise FileNotFoundError(f"Could not read image: {img_path}")
        # OpenCV loads BGR; the pipelines below expect RGB.
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

        # Fetch annotations for the image. If self.cat_ids is everything,
        # this returns all annotations.
        ann_ids = self.coco.getAnnIds(
            imgIds=img_info['id'],
            catIds=self.cat_ids,
            iscrowd=None,
        )
        anns = self.coco.loadAnns(ann_ids)

        # uint8 keeps the mask compact and avoids handing the transform a
        # float64 array; the float conversion happens once, at the end.
        mask = np.zeros((img_info['height'], img_info['width']), dtype=np.uint8)
        for ann in anns:
            mask = np.maximum(mask, self.coco.annToMask(ann))

        if self.transform is not None:
            augmented = self.transform(image=image, mask=mask)
            image = augmented['image']
            mask = augmented['mask']

        if not isinstance(mask, torch.Tensor):
            # numpy-only transforms may hand back a strided view, which
            # torch.from_numpy cannot always wrap.
            mask = torch.from_numpy(np.ascontiguousarray(mask))
        # Always convert, so the mask dtype is float32 whether or not the
        # transform already produced a tensor.
        mask = mask.float()

        if mask.ndim == 2:
            mask = mask.unsqueeze(0)

        return image, mask


def _resize_and_pad(image_size):
    """Return the transforms that fit an image into an image_size square."""
    # NOTE(review): newer albumentations releases appear to rename
    # `value` / `mask_value` to `fill` / `fill_mask` - confirm against the
    # pinned version before upgrading.
    return [
        A.LongestMaxSize(max_size=image_size),
        A.PadIfNeeded(
            min_height=image_size,
            min_width=image_size,
            border_mode=cv2.BORDER_CONSTANT,
            value=_PAD_FILL,
            mask_value=0,
        ),
    ]


def _normalize_and_to_tensor():
    """Return the final normalization and tensor-conversion transforms."""
    return [
        A.Normalize(mean=_NORM_MEAN, std=_NORM_STD),
        ToTensorV2(),
    ]


def get_train_transforms(image_size=256):
    """Build the training pipeline: resize/pad, augment, normalize, tensorize.

    Args:
        image_size: Side length of the square output, in pixels.

    Returns:
        An ``A.Compose`` to be called as ``transform(image=..., mask=...)``.
    """
    return A.Compose([
        *_resize_and_pad(image_size),
        A.HorizontalFlip(p=0.5),
        A.VerticalFlip(p=0.3),
        A.RandomBrightnessContrast(p=0.4),
        # NOTE(review): (0.05, 0.05) looks like a degenerate range, i.e. a
        # constant +5% shift rather than a random jitter; use (-0.05, 0.05)
        # if a symmetric translation was intended - confirm.
        A.Affine(
            scale=(0.9, 1.1),
            rotate=(-15, 15),
            translate_percent=(0.05, 0.05),
            p=0.5,
        ),
        A.GaussianBlur(p=0.2),
        *_normalize_and_to_tensor(),
    ])


def get_val_transforms(image_size=256):
    """Build the validation pipeline: resize/pad, normalize, tensorize.

    Args:
        image_size: Side length of the square output, in pixels.

    Returns:
        An ``A.Compose`` to be called as ``transform(image=..., mask=...)``.
    """
    return A.Compose([
        *_resize_and_pad(image_size),
        *_normalize_and_to_tensor(),
    ])