'''
Author: Chris Xiao yl.xiao@mail.utoronto.ca
Date: 2023-09-16 17:41:29
LastEditors: Chris Xiao yl.xiao@mail.utoronto.ca
LastEditTime: 2023-12-17 18:22:42
FilePath: /EndoSAM/endoSAM/dataset.py
Description: EndoVisDataset class
I Love IU
Copyright (c) 2023 by Chris Xiao yl.xiao@mail.utoronto.ca, All Rights Reserved.
'''
import glob
import os

import cv2
import numpy as np
import torch
from torch.utils.data import Dataset

from utils import ResizeLongestSide, preprocess

# Dataset splits for which a sub-directory is expected under both
# `<root>/img` and `<root>/ann`.
modes = ['train', 'val', 'test']


class EndoVisDataset(Dataset):
    """EndoVis image / binary-mask dataset.

    Expected directory layout::

        <root>/img/<mode>/*.<img_format>
        <root>/ann/<mode>/*.<ann_format>

    An image and its annotation are paired by file stem (the file name
    without its final extension).
    """

    def __init__(self, root, ann_format='png', img_format='jpg',
                 mode='train', encoder_size=1024):
        """Collect the image and annotation file lists for one split.

        Args:
            root (str): root directory containing the ``img`` and ``ann``
                sub-directories.
            ann_format (str, optional): annotation file extension, without
                the leading dot. Defaults to "png".
            img_format (str, optional): image file extension, without the
                leading dot. Defaults to "jpg".
            mode (str, optional): one of "train", "val" or "test".
                Defaults to "train".
            encoder_size (int, optional): size handed to
                ``ResizeLongestSide`` and ``preprocess`` (presumably the
                input resolution of the image encoder -- confirm against
                ``utils``). Defaults to 1024.

        Raises:
            ValueError: if ``mode`` is not one of ``modes``.
        """
        super().__init__()
        self.root = root
        self.mode = mode
        self.ann_format = ann_format
        self.img_format = img_format
        self.encoder_size = encoder_size
        self.ann_path = os.path.join(self.root, 'ann')
        self.img_path = os.path.join(self.root, 'img')

        if self.mode not in modes:
            raise ValueError('Invalid mode: {}'.format(self.mode))
        self.img_mode_path = os.path.join(self.img_path, self.mode)
        self.ann_mode_path = os.path.join(self.ann_path, self.mode)

        # glob gives no ordering guarantee; sort so that a given index maps
        # to the same sample on every run and on every machine.
        self.imgs = sorted(glob.glob(
            os.path.join(self.img_mode_path, '*.{}'.format(self.img_format))))
        self.anns = sorted(glob.glob(
            os.path.join(self.ann_mode_path, '*.{}'.format(self.ann_format))))

        self.transform = ResizeLongestSide(self.encoder_size)

    def __len__(self):
        """Return the number of images in the selected split.

        Raises:
            ValueError: if ``self.mode`` is not one of ``modes``.
            AssertionError: if the image and annotation counts differ.
        """
        if self.mode not in modes:
            raise ValueError('Invalid mode: {}'.format(self.mode))
        if len(self.imgs) != len(self.anns):
            # Raised explicitly (rather than via `assert`) so the check is
            # not stripped under `python -O`; the exception type is kept.
            raise AssertionError(
                'Found {} images but {} annotations'.format(
                    len(self.imgs), len(self.anns)))
        return len(self.imgs)

    def __getitem__(self, index) -> tuple:
        """Load one sample.

        Args:
            index (int): position in the sorted image list.

        Returns:
            tuple: ``(img, ann, name, img_bgr)`` where ``img`` is the output
            of ``preprocess`` for the resized RGB image, ``ann`` is the
            annotation as a 0/1 array, ``name`` is the image file stem and
            ``img_bgr`` is the untouched image as read by OpenCV (BGR).

        Raises:
            OSError: if the image or its annotation cannot be read.
        """
        img_file = self.imgs[index]
        img_bgr = cv2.imread(img_file)
        if img_bgr is None:
            # cv2.imread signals failure by returning None, not by raising.
            raise OSError('Could not read image file: {}'.format(img_file))
        img_rgb = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB)

        # splitext drops only the final extension, so stems that themselves
        # contain dots (e.g. "seq1.frame001.jpg") are kept whole.
        name = os.path.splitext(os.path.basename(img_file))[0]

        input_image = self.transform.apply_image(img_rgb)
        # Move the last axis to the front (H, W, C -> C, H, W).
        input_image_torch = torch.as_tensor(input_image).permute(2, 0, 1).contiguous()
        img = preprocess(input_image_torch, self.encoder_size)

        ann_path = os.path.join(self.ann_mode_path, f"{name}.{self.ann_format}")
        ann = cv2.imread(ann_path, cv2.IMREAD_GRAYSCALE)
        if ann is None:
            # Without this check a missing annotation would silently turn
            # into a 0-d object array instead of a mask.
            raise OSError('Could not read annotation file: {}'.format(ann_path))
        # Binarise: every non-zero label becomes foreground (1).
        ann = (ann != 0).astype(ann.dtype)

        return img, ann, name, img_bgr