# VisionExtract / dataset.py
# Biswajeet1's picture
# Update dataset.py
# bd3980f verified
import os
import numpy as np
import torch
import cv2
from torch.utils.data import Dataset
import albumentations as A
from albumentations.pytorch import ToTensorV2
class CocoSegmentationDataset(Dataset):
    """Binary segmentation dataset backed by a COCO-style annotation object.

    Each item is an ``(image, mask)`` pair. The mask is the union of all
    annotation masks selected for the image and is always returned as a
    float32 tensor with a leading channel axis.

    Args:
        coco: Annotation index exposing ``getCatIds``, ``getImgIds``,
            ``getAnnIds``, ``loadImgs``, ``loadAnns`` and ``annToMask``
            (presumably a ``pycocotools.coco.COCO`` instance -- confirm
            against the caller).
        image_folder: Directory that the annotation ``file_name`` entries
            are joined onto.
        category_name: Optional single category to restrict the dataset to.
            When falsy, every category and every image is used.
        transform: Optional callable invoked as
            ``transform(image=image, mask=mask)`` and returning a mapping
            with ``'image'`` and ``'mask'`` keys.

    Raises:
        ValueError: If ``category_name`` is given but matches no category.
    """

    def __init__(self, coco, image_folder,
                 category_name=None,
                 transform=None):
        self.coco = coco
        self.image_folder = image_folder
        self.transform = transform

        if category_name:
            self.cat_ids = self.coco.getCatIds(catNms=[category_name])
            if not self.cat_ids:
                # An empty id list is presumably treated by the COCO helpers
                # as "no filter", which would silently select every image
                # and annotation; fail loudly on a misspelled name instead.
                raise ValueError(
                    f"Category {category_name!r} not found in annotations"
                )
            self.img_ids = self.coco.getImgIds(catIds=self.cat_ids)
        else:
            # Use all categories and all images if no specific category is provided
            self.cat_ids = self.coco.getCatIds()
            self.img_ids = self.coco.getImgIds()

    def __len__(self):
        """Return the number of images selected at construction time."""
        return len(self.img_ids)

    def __getitem__(self, index):
        """Load the image at ``index`` and its merged binary mask.

        Returns:
            Tuple ``(image, mask)``. ``image`` is whatever the transform
            produced, or the RGB array from OpenCV when no transform is
            set. ``mask`` is a float32 tensor; a 2-D mask gains a leading
            channel axis.

        Raises:
            FileNotFoundError: If OpenCV cannot read the image file.
        """
        img_id = self.img_ids[index]
        img_info = self.coco.loadImgs(img_id)[0]
        img_path = os.path.join(self.image_folder, img_info['file_name'])

        # Load image with OpenCV (BGR to RGB). imread signals failure by
        # returning None rather than raising, so check before converting.
        image = cv2.imread(img_path)
        if image is None:
            raise FileNotFoundError(f"Could not read image: {img_path}")
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

        # Fetch annotations for the image, restricted to the chosen categories.
        ann_ids = self.coco.getAnnIds(
            imgIds=img_info['id'],
            catIds=self.cat_ids,
            iscrowd=None,
        )
        anns = self.coco.loadAnns(ann_ids)

        # Union of all instance masks; uint8 keeps the array compact and
        # is converted to float32 once, just before returning.
        mask = np.zeros((img_info['height'], img_info['width']), dtype=np.uint8)
        for ann in anns:
            mask = np.maximum(mask, self.coco.annToMask(ann))

        if self.transform:
            augmented = self.transform(image=image, mask=mask)
            image = augmented['image']
            mask = augmented['mask']

        if not isinstance(mask, torch.Tensor):
            mask = torch.from_numpy(np.ascontiguousarray(mask))
        # Cast unconditionally so the dtype does not depend on whether the
        # transform already produced a tensor.
        mask = mask.float()
        if mask.ndim == 2:
            mask = mask.unsqueeze(0)
        return image, mask
def get_train_transforms(image_size=256):
    """Build the training augmentation pipeline.

    The image is resized so its longest side equals ``image_size``, padded
    to a square with a constant border, randomly augmented, normalized and
    converted to a tensor.

    Args:
        image_size: Target side length of the square output.

    Returns:
        An ``A.Compose`` pipeline taking ``image=`` and ``mask=`` keywords.
    """
    return A.Compose([
        A.LongestMaxSize(max_size=image_size),
        # The pad colour is the Normalize mean scaled by 255; the mask is
        # padded with 0.
        # NOTE(review): newer albumentations releases appear to rename
        # `value`/`mask_value` -- confirm against the pinned version.
        A.PadIfNeeded(
            min_height=image_size,
            min_width=image_size,
            border_mode=cv2.BORDER_CONSTANT,
            value=(123.675, 116.28, 103.53),
            mask_value=0,
        ),
        A.HorizontalFlip(p=0.5),
        A.VerticalFlip(p=0.3),
        A.RandomBrightnessContrast(p=0.4),
        A.Affine(
            scale=(0.9, 1.1),
            rotate=(-15, 15),
            # Symmetric range: a (0.05, 0.05) tuple is a degenerate interval
            # that shifts every sample by the same +5% instead of jittering.
            translate_percent=(-0.05, 0.05),
            p=0.5,
        ),
        A.GaussianBlur(p=0.2),
        A.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
        ToTensorV2(),
    ])
def get_val_transforms(image_size=256):
    """Build the validation preprocessing pipeline (no random augmentation).

    Steps, in order: resize the longest side to ``image_size``, pad to an
    ``image_size`` square with a constant border, normalize, convert to
    tensor.

    Args:
        image_size: Target side length of the square output.

    Returns:
        An ``A.Compose`` pipeline taking ``image=`` and ``mask=`` keywords.
    """
    resize_step = A.LongestMaxSize(max_size=image_size)
    # The pad colour is the Normalize mean scaled by 255; the mask is padded with 0.
    pad_step = A.PadIfNeeded(
        min_height=image_size,
        min_width=image_size,
        border_mode=cv2.BORDER_CONSTANT,
        value=(123.675, 116.28, 103.53),
        mask_value=0,
    )
    normalize_step = A.Normalize(
        mean=(0.485, 0.456, 0.406),
        std=(0.229, 0.224, 0.225),
    )
    pipeline = [resize_step, pad_step, normalize_step, ToTensorV2()]
    return A.Compose(pipeline)