'''
Author: Chris Xiao yl.xiao@mail.utoronto.ca
Date: 2023-09-16 17:41:29
LastEditors: Chris Xiao yl.xiao@mail.utoronto.ca
LastEditTime: 2023-12-17 18:22:42
FilePath: /EndoSAM/endoSAM/dataset.py
Description: EndoVisDataset class
I Love IU
Copyright (c) 2023 by Chris Xiao yl.xiao@mail.utoronto.ca, All Rights Reserved.
'''
import os
import glob

import cv2
import numpy as np
import torch
from torch.utils.data import Dataset

from utils import ResizeLongestSide, preprocess

modes = ['train', 'val', 'test']

class EndoVisDataset(Dataset):
    def __init__(self, root,
                 ann_format='png',
                 img_format='jpg',
                 mode='train',
                 encoder_size=1024):
        """Customized EndoVis dataset.

        Args:
            root (str): root directory containing the 'img' and 'ann' subfolders.
            ann_format (str, optional): file extension of the annotation masks. Defaults to 'png'.
            img_format (str, optional): file extension of the images. Defaults to 'jpg'.
            mode (str, optional): one of 'train', 'val' or 'test'. Defaults to 'train'.
            encoder_size (int, optional): input size expected by the SAM image encoder. Defaults to 1024.
        """
        super(EndoVisDataset, self).__init__()
        self.root = root
        self.mode = mode
        self.ann_format = ann_format
        self.img_format = img_format
        self.encoder_size = encoder_size
        self.ann_path = os.path.join(self.root, 'ann')
        self.img_path = os.path.join(self.root, 'img')
        if self.mode in modes:
            self.img_mode_path = os.path.join(self.img_path, self.mode)
            self.ann_mode_path = os.path.join(self.ann_path, self.mode)
        else:
            raise ValueError('Invalid mode: {}'.format(self.mode))
        # sort the glob results for a deterministic sample order across platforms
        self.imgs = sorted(glob.glob(os.path.join(self.img_mode_path, '*.{}'.format(self.img_format))))
        self.anns = sorted(glob.glob(os.path.join(self.ann_mode_path, '*.{}'.format(self.ann_format))))
        self.transform = ResizeLongestSide(self.encoder_size)

    def __len__(self):
        if self.mode in modes:
            assert len(self.imgs) == len(self.anns)
            return len(self.imgs)
        else:
            raise ValueError('Invalid mode: {}'.format(self.mode))

    def __getitem__(self, index) -> tuple:
        # load the image and convert from OpenCV's BGR ordering to RGB
        img_bgr = cv2.imread(self.imgs[index])
        img_rgb = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB)
        name = os.path.basename(self.imgs[index]).split('.')[0]
        # resize the longest side to encoder_size, then normalize/pad to the encoder input
        input_image = self.transform.apply_image(img_rgb)
        input_image_torch = torch.as_tensor(input_image).permute(2, 0, 1).contiguous()
        img = preprocess(input_image_torch, self.encoder_size)
        # load the annotation that matches the image by file name and binarize it (foreground = 1)
        ann_path = os.path.join(self.ann_mode_path, f"{name}.{self.ann_format}")
        ann = cv2.imread(ann_path, cv2.IMREAD_GRAYSCALE)
        ann = np.array(ann)
        ann[ann != 0] = 1
        return img, ann, name, img_bgr
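

# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the original file): how this dataset might
# be wrapped in a PyTorch DataLoader. It assumes the directory layout built in
# __init__ above ('<root>/img/<mode>' and '<root>/ann/<mode>'); the '../data'
# root below is only a placeholder.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    from torch.utils.data import DataLoader

    dataset = EndoVisDataset('../data', ann_format='png', img_format='jpg',
                             mode='train', encoder_size=1024)
    loader = DataLoader(dataset, batch_size=2, shuffle=True, num_workers=2)
    for img, ann, name, img_bgr in loader:
        # img: (B, 3, encoder_size, encoder_size) tensor ready for the SAM image encoder
        # ann: (B, H, W) binary masks; name: tuple of file stems; img_bgr: original BGR frames
        print(img.shape, ann.shape, name)
        break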