Spaces:
Sleeping
Sleeping
| import os | |
| import numpy as np | |
| import torch | |
| import cv2 | |
| from torch.utils.data import Dataset | |
| import albumentations as A | |
| from albumentations.pytorch import ToTensorV2 | |
class CocoSegmentationDataset(Dataset):
    """Binary segmentation dataset backed by a COCO annotation object.

    Each item is an ``(image, mask)`` pair. ``mask`` is the pixel-wise union
    (element-wise maximum) of every selected annotation mask for that image.

    Args:
        coco: A COCO API object (e.g. ``pycocotools.coco.COCO``) providing
            ``getCatIds``, ``getImgIds``, ``loadImgs``, ``getAnnIds``,
            ``loadAnns`` and ``annToMask``.
        image_folder: Directory that each image's ``file_name`` is joined to.
        category_name: If truthy, restrict images and annotations to this one
            category. Otherwise all categories and all images are used.
        transform: Optional callable invoked as
            ``transform(image=..., mask=...)`` that returns a dict with
            ``"image"`` and ``"mask"`` keys (albumentations-style).

    Raises:
        ValueError: If ``category_name`` matches no category in ``coco``.
    """

    def __init__(self, coco, image_folder,
                 category_name=None,
                 transform=None):
        self.coco = coco
        self.image_folder = image_folder
        self.transform = transform

        if category_name:
            self.cat_ids = self.coco.getCatIds(catNms=[category_name])
            if not self.cat_ids:
                # pycocotools treats an empty id list as "no filter", so an
                # unknown name would silently select every image/annotation.
                raise ValueError(f"Unknown COCO category name: {category_name!r}")
            self.img_ids = self.coco.getImgIds(catIds=self.cat_ids)
        else:
            # Use all categories and all images if no specific category is provided
            self.cat_ids = self.coco.getCatIds()
            self.img_ids = self.coco.getImgIds()

    def __len__(self):
        """Return the number of images in the dataset."""
        return len(self.img_ids)

    def __getitem__(self, index):
        """Load image ``index`` and its mask, applying ``transform`` if set.

        Returns:
            ``(image, mask)``. ``mask`` is always a float32 tensor with a
            leading channel dimension, e.g. ``(1, H, W)``. ``image`` is
            whatever ``transform`` produces, or the RGB numpy array returned
            by OpenCV when no transform is configured.

        Raises:
            FileNotFoundError: If the image file cannot be read.
        """
        img_id = self.img_ids[index]
        img_info = self.coco.loadImgs(img_id)[0]
        img_path = os.path.join(self.image_folder, img_info['file_name'])
        image = self._load_rgb(img_path)

        # If self.cat_ids covers every category, this fetches all annotations.
        ann_ids = self.coco.getAnnIds(
            imgIds=img_info['id'],
            catIds=self.cat_ids,
            iscrowd=None,
        )
        mask = self._build_mask(img_info, self.coco.loadAnns(ann_ids))

        if self.transform:
            augmented = self.transform(image=image, mask=mask)
            image = augmented['image']
            mask = augmented['mask']

        return image, self._to_mask_tensor(mask)

    @staticmethod
    def _load_rgb(img_path):
        """Read an image with OpenCV and convert it from BGR to RGB order."""
        image = cv2.imread(img_path)
        if image is None:
            # cv2.imread reports failure by returning None instead of raising.
            raise FileNotFoundError(f"Could not read image: {img_path}")
        return cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

    def _build_mask(self, img_info, anns):
        """Union all annotation masks of one image into a single uint8 mask."""
        # uint8 (not numpy's float64 default) keeps the dtype stable through
        # augmentation and matches what annToMask produces.
        mask = np.zeros((img_info['height'], img_info['width']), dtype=np.uint8)
        for ann in anns:
            ann_mask = np.asarray(self.coco.annToMask(ann)).astype(np.uint8, copy=False)
            np.maximum(mask, ann_mask, out=mask)
        return mask

    @staticmethod
    def _to_mask_tensor(mask):
        """Return ``mask`` as a float32 tensor with a leading channel dim."""
        if not isinstance(mask, torch.Tensor):
            # from_numpy rejects negatively-strided arrays; make it contiguous.
            mask = torch.from_numpy(np.ascontiguousarray(mask))
        # Always cast: ToTensorV2 passes the mask dtype through unchanged.
        mask = mask.float()
        if mask.ndim == 2:
            mask = mask.unsqueeze(0)
        return mask
def get_train_transforms(image_size=256):
    """Build the albumentations augmentation pipeline used for training.

    The pipeline does the following, in order:

    - Resizes so the longest side is ``image_size``.
    - Pads to an ``image_size`` x ``image_size`` canvas.
    - Applies random flips, brightness/contrast, affine and blur augmentation.
    - Normalizes with ImageNet mean/std.
    - Converts to tensors with ``ToTensorV2``.

    Args:
        image_size: Side length, in pixels, of the square output.

    Returns:
        An ``A.Compose`` pipeline called as ``pipeline(image=..., mask=...)``.
    """
    return A.Compose([
        A.LongestMaxSize(max_size=image_size),
        # Image padding is the ImageNet mean on a 0-255 scale (0.485*255, ...),
        # so padded pixels become exactly 0 after Normalize; mask pads with 0.
        # NOTE(review): albumentations >= 2.0 renamed value/mask_value to
        # fill/fill_mask; confirm these arguments against the installed version.
        A.PadIfNeeded(
            min_height=image_size,
            min_width=image_size,
            border_mode=cv2.BORDER_CONSTANT,
            value=(123.675, 116.28, 103.53),
            mask_value=0,
        ),
        A.HorizontalFlip(p=0.5),
        A.VerticalFlip(p=0.3),
        A.RandomBrightnessContrast(p=0.4),
        A.Affine(
            scale=(0.9, 1.1),
            rotate=(-15, 15),
            # An (a, b) tuple is a sampling range. (0.05, 0.05) would always
            # shift by +5%; (-0.05, 0.05) gives a random +/-5% jitter.
            translate_percent=(-0.05, 0.05),
            p=0.5,
        ),
        A.GaussianBlur(p=0.2),
        A.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
        ToTensorV2(),
    ])
def get_val_transforms(image_size=256):
    """Build the deterministic albumentations pipeline used for validation.

    No random augmentation is applied. The pipeline does the following:

    - Resizes the longest side to ``image_size``.
    - Pads to a square ``image_size`` x ``image_size`` canvas, filling the
      image with the ImageNet mean (0-255 scale) and the mask with 0.
    - Normalizes with ImageNet statistics.
    - Converts to tensors.

    Args:
        image_size: Side length, in pixels, of the square output.

    Returns:
        An ``A.Compose`` pipeline called as ``pipeline(image=..., mask=...)``.
    """
    resize = A.LongestMaxSize(max_size=image_size)
    pad = A.PadIfNeeded(
        min_height=image_size,
        min_width=image_size,
        border_mode=cv2.BORDER_CONSTANT,
        value=(123.675, 116.28, 103.53),
        mask_value=0,
    )
    normalize = A.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225))
    to_tensor = ToTensorV2()

    steps = [resize, pad, normalize, to_tensor]
    return A.Compose(steps)