Spaces:
Sleeping
Sleeping
| """ | |
| Copied from https://github.com/talshaharabany/AutoSAM | |
| """ | |
| import os | |
| from PIL import Image | |
| import torch.utils.data as data | |
| import torchvision.transforms as transforms | |
| import numpy as np | |
| import random | |
| import torch | |
| from dataloaders.PolypTransforms import get_polyp_transform | |
| import cv2 | |
# Canonical folder names of the individual polyp datasets.
KVASIR = "Kvasir"
CLINIC_DB = "CVC-ClinicDB"
COLON_DB = "CVC-ColonDB"
ETIS_DB = "ETIS-LaribPolypDB"
CVC300 = "CVC-300"
# Datasets included by default when collecting image/mask pairs.
DATASETS = (KVASIR, CLINIC_DB, COLON_DB, ETIS_DB)
# Datasets skipped by filter_files_and_get_ds_mean_and_std.
EXCLUDE_DS = (CVC300, )
def create_suppport_set_for_polyps(n_support=10, root_dir="/disk4/Lev/Projects/Self-supervised-Fewshot-Medical-Image-Segmentation/data/PolypDataset/TrainDataset"):
    """
    Write a "support.txt" file in *root_dir* listing n_support image/mask
    path pairs sampled without replacement from the dataset.

    n_support: number of support pairs to sample.
    root_dir: dataset root containing "images" and "masks" subdirectories.
        Both listings are sorted, so pairs are matched index-by-index
        (assumes images and masks share sort order -- same as the original).

    Raises ValueError if n_support exceeds the number of available images.
    """
    image_dir = os.path.join(root_dir, "images")
    mask_dir = os.path.join(root_dir, "masks")
    image_paths = sorted(
        os.path.join(image_dir, f) for f in os.listdir(image_dir)
        if f.endswith(('.jpg', '.png')))
    mask_paths = sorted(
        os.path.join(mask_dir, f) for f in os.listdir(mask_dir)
        if f.endswith('.png'))
    # Sample unique indices in one shot instead of rejection-sampling in a
    # while-loop: the old loop spun forever when n_support > len(image_paths).
    indices = random.sample(range(len(image_paths)), n_support)
    supp_images = [image_paths[i] for i in indices]
    supp_masks = [mask_paths[i] for i in indices]
    with open(os.path.join(root_dir, "support.txt"), 'w') as file:
        for image_path, mask_path in zip(supp_images, supp_masks):
            file.write(f"{image_path} {mask_path}\n")
def create_train_val_test_split_for_polyps(root_dir="/disk4/Lev/Projects/Self-supervised-Fewshot-Medical-Image-Segmentation/data/PolypDataset/"):
    """
    Create a "split.txt" file inside each known dataset subdirectory of
    *root_dir* using the fixed per-dataset train/test sizes below; leftover
    images become the validation split (see create_train_val_test_split).
    """
    num_train_images_per_dataset = {
        "CVC-ClinicDB": 548, "Kvasir": 900, "CVC-300": 0, "CVC-ColonDB": 0}
    num_test_images_per_dataset = {
        "CVC-ClinicDB": 64, "Kvasir": 100, "CVC-300": 60, "CVC-ColonDB": 380}
    for subdir in os.listdir(root_dir):
        subdir_path = os.path.join(root_dir, subdir)
        if not os.path.isdir(subdir_path):
            continue
        # Skip folders with no configured sizes: the original indexed the
        # dicts directly and raised KeyError on any unlisted subdirectory.
        if subdir not in num_train_images_per_dataset:
            print(f"No split sizes configured for {subdir}, skipping")
            continue
        split_file = os.path.join(subdir_path, "split.txt")
        image_dir = os.path.join(subdir_path, "images")
        create_train_val_test_split(
            image_dir, split_file,
            train_number=num_train_images_per_dataset[subdir],
            test_number=num_test_images_per_dataset[subdir])
def create_train_val_test_split(root, split_file, train_number=100, test_number=20):
    """
    Write a train/val/test split file for the images found in *root*.

    root: directory containing the image files (.jpg / .png).
    split_file: path of the split file to create. Format: each section name
        ("train" / "val" / "test") on its own line, followed by one image
        base name (extension stripped) per line.
    train_number: number of images assigned to the train split.
    test_number: number of images assigned to the test split.
    All remaining images (if any) go to the val split.
    """
    files = os.listdir(root)
    # Keep only images and strip the extension. os.path.splitext preserves
    # base names that contain dots, unlike the old split('.')[0].
    files = [os.path.splitext(f)[0]
             for f in files if f.endswith(('.jpg', '.png'))]
    random.shuffle(files)
    num_files = len(files)
    num_train = train_number
    num_test = test_number
    # Clamp at zero so a too-small dataset cannot produce a negative val
    # count (which made the test slice overlap the train slice).
    num_val = max(0, num_files - num_train - num_test)
    print(f"num_train: {num_train}, num_val: {num_val}, num_test: {num_test}")
    train = files[:num_train]
    val = files[num_train:num_train + num_val]
    test = files[num_train + num_val:]
    with open(split_file, 'w') as file:
        file.write("train\n")
        for f in train:
            file.write(f + "\n")
        file.write("val\n")
        for f in val:
            file.write(f + "\n")
        file.write("test\n")
        for f in test:
            file.write(f + "\n")
class PolypDataset(data.Dataset):
    """
    dataloader for polyp segmentation tasks

    Image/mask pairs are collected either from explicit image_root/gt_root
    directories (including their immediate subdirectories) or, when those are
    not given, from per-dataset split.txt files under *root*. Items are
    loaded with OpenCV, optionally augmented, normalized, and resized either
    by a SAM transform (sam_trans) or by torch interpolation to image_size.
    """

    def __init__(self, root, image_root=None, gt_root=None, trainsize=352, augmentations=None, train=True, sam_trans=None, datasets=DATASETS, image_size=(1024, 1024), ds_mean=None, ds_std=None):
        # root: dataset root used when image_root/gt_root are not both given;
        #   expected to contain one folder per dataset with a split.txt.
        # image_root/gt_root: explicit image and mask directories; when both
        #   are set, split files are ignored.
        # trainsize: minimum size used only by self.resize (legacy path).
        # augmentations: callable(image, mask) -> (image, mask).
        # sam_trans: optional SAM transform; when set it owns resizing and
        #   normalization (mean/std forced to 0/1 below).
        # datasets: dataset folder names accepted by get_image_gt_pairs.
        # image_size: output (H, W) used when sam_trans is None.
        # ds_mean/ds_std: optional precomputed normalization statistics that
        #   override the ones computed from the data.
        self.trainsize = trainsize
        self.augmentations = augmentations
        self.datasets = datasets
        if isinstance(image_size, int):
            image_size = (image_size, image_size)
        self.image_size = image_size
        if image_root is not None and gt_root is not None:
            self.images = [
                os.path.join(image_root, f) for f in os.listdir(image_root) if f.endswith('.jpg') or f.endswith('.png')]
            self.gts = [
                os.path.join(gt_root, f) for f in os.listdir(gt_root) if f.endswith('.png')]
            # also look in subdirectories (one level deep, mirrored between
            # image_root and gt_root)
            for subdir in os.listdir(image_root):
                # if not dir, continue
                if not os.path.isdir(os.path.join(image_root, subdir)):
                    continue
                subdir_image_root = os.path.join(image_root, subdir)
                subdir_gt_root = os.path.join(gt_root, subdir)
                self.images.extend([os.path.join(subdir_image_root, f) for f in os.listdir(
                    subdir_image_root) if f.endswith('.jpg') or f.endswith('.png')])
                self.gts.extend([os.path.join(subdir_gt_root, f) for f in os.listdir(
                    subdir_gt_root) if f.endswith('.png')])
        else:
            self.images, self.gts = self.get_image_gt_pairs(
                root, split="train" if train else "test", datasets=self.datasets)
        # Sorting both lists keeps image/mask pairs aligned index-by-index
        # (assumes matching file names sort identically -- same convention as
        # the rest of the file).
        self.images = sorted(self.images)
        self.gts = sorted(self.gts)
        # NOTE(review): video (VPS) data appears to be assumed pre-filtered,
        # so the size check / statistics pass is skipped for it -- confirm.
        if not 'VPS' in root:
            self.filter_files_and_get_ds_mean_and_std()
        if ds_mean is not None and ds_std is not None:
            self.mean, self.std = ds_mean, ds_std
        self.size = len(self.images)
        self.train = train
        self.sam_trans = sam_trans
        if self.sam_trans is not None:
            # sam trans takes care of norm
            self.mean, self.std = 0 , 1

    def get_image_gt_pairs(self, dir_root: str, split="train", datasets: tuple = DATASETS):
        """
        for each folder in dir root, get all image-gt pairs. Assumes each subdir has a split.txt file
        dir_root: root directory of all subdirectories, each subdirectory contains images and masks folders
        split: train, val, or test
        Returns (image_paths, gt_paths) accumulated over all matching folders.
        """
        image_paths = []
        gt_paths = []
        for folder in os.listdir(dir_root):
            # Only folders explicitly listed in `datasets` are read.
            if folder not in datasets:
                continue
            split_file = os.path.join(dir_root, folder, "split.txt")
            if os.path.isfile(split_file):
                image_root = os.path.join(dir_root, folder, "images")
                gt_root = os.path.join(dir_root, folder, "masks")
                image_paths_tmp, gt_paths_tmp = self.get_image_gt_pairs_from_text_file(
                    image_root, gt_root, split_file, split=split)
                image_paths.extend(image_paths_tmp)
                gt_paths.extend(gt_paths_tmp)
            else:
                # Missing split files are reported but not fatal.
                print(
                    f"No split.txt file found in {os.path.join(dir_root, folder)}")
        return image_paths, gt_paths

    def get_image_gt_pairs_from_text_file(self, image_root: str, gt_root: str, text_file: str, split: str = "train"):
        """
        image_root: root directory of images
        gt_root: root directory of ground truth
        text_file: text file containing train, val, test split with the following format:
        train:
        image1
        image2
        ...
        val:
        image1
        image2
        ...
        test:
        image1
        image2
        ...
        split: train, val, or test
        Returns (image_paths, gt_paths) for the requested split only.
        """
        # Initialize a dictionary to hold file names for each split
        splits = {"train": [], "val": [], "test": []}
        current_split = None
        # Read the text file and categorize file names under each split
        with open(text_file, 'r') as file:
            for line in file:
                line = line.strip()
                if line in splits:
                    current_split = line
                elif line and current_split:
                    splits[current_split].append(line)
        # Get the file names for the requested split
        file_names = splits[split]
        # Create image-ground truth pairs
        image_paths = []
        gt_paths = []
        for name in file_names:
            # NOTE(review): '.png' is appended unconditionally even though
            # images elsewhere in this file may be '.jpg' -- verify the split
            # datasets really store images as PNG.
            image_path = os.path.join(image_root, name + '.png')
            gt_path = os.path.join(gt_root, name + '.png')
            image_paths.append(image_path)
            gt_paths.append(gt_path)
        return image_paths, gt_paths

    def get_support_from_dirs(self, support_image_dir, support_mask_dir, n_support=1):
        """
        Sample n_support image/mask pairs from explicit directories and return
        them as stacked tensors (images, labels), augmented/normalized/resized
        the same way as regular items.
        """
        support_images = []
        support_labels = []
        # get all images and masks
        support_image_paths = sorted([os.path.join(support_image_dir, f) for f in os.listdir(
            support_image_dir) if f.endswith('.jpg') or f.endswith('.png')])
        support_mask_paths = sorted([os.path.join(support_mask_dir, f) for f in os.listdir(
            support_mask_dir) if f.endswith('.png')])
        # sample n_support images and masks (with replacement; duplicates are
        # possible since indices are drawn independently)
        for i in range(n_support):
            index = random.randint(0, len(support_image_paths) - 1)
            support_img = self.cv2_loader(
                support_image_paths[index], is_mask=False)
            support_mask = self.cv2_loader(
                support_mask_paths[index], is_mask=True)
            support_images.append(support_img)
            support_labels.append(support_mask)
        if self.augmentations:
            # NOTE(review): the second comprehension zips the ALREADY
            # augmented support_images with the raw labels and augments
            # again, so images and labels may receive different random
            # transforms -- looks like a bug, confirm intended behavior.
            support_images = [self.augmentations(
                img, mask)[0] for img, mask in zip(support_images, support_labels)]
            support_labels = [self.augmentations(
                img, mask)[1] for img, mask in zip(support_images, support_labels)]
        # Normalize only images that still look like raw 0..255 data.
        support_images = [(support_image - self.mean) / self.std if support_image.max() == 255 and support_image.min() == 0 else support_image for support_image in support_images]
        if self.sam_trans is not None:
            support_images = [self.sam_trans.preprocess(
                img).squeeze(0) for img in support_images]
            support_labels = [self.sam_trans.preprocess(
                mask) for mask in support_labels]
        else:
            image_size = self.image_size
            # Bilinear for images, nearest for labels (keeps masks binary).
            support_images = [torch.nn.functional.interpolate(img.unsqueeze(
                0), size=image_size, mode='bilinear', align_corners=False).squeeze(0) for img in support_images]
            support_labels = [torch.nn.functional.interpolate(mask.unsqueeze(0).unsqueeze(
                0), size=image_size, mode='nearest').squeeze(0).squeeze(0) for mask in support_labels]
        return torch.stack(support_images), torch.stack(support_labels)

    def get_support_from_text_file(self, text_file, n_support=1):
        """
        each row in the file has 2 paths divided by space, the first is the image path and the second is the mask path
        Returns the FIRST n_support (image_path, mask_path) entries (no random
        sampling); raises ValueError if the file has fewer than n_support rows.
        """
        support_images = []
        support_labels = []
        with open(text_file, 'r') as file:
            for line in file:
                image_path, mask_path = line.strip().split()
                support_images.append(image_path)
                support_labels.append(mask_path)
        # indices = random.choices(range(len(support_images)), k=n_support)
        if n_support > len(support_images):
            raise ValueError(f"n_support ({n_support}) is larger than the number of images in the text file ({len(support_images)})")
        n_support_images = support_images[:n_support]
        n_support_labels = support_labels[:n_support]
        return n_support_images, n_support_labels

    def get_support(self, n_support=1, support_image_dir=None, support_mask_dir=None, text_file=None):
        """
        Get support set from specified directories, text file or from the dataset itself

        Returns (support_images, support_gts, case) where images/gts are lists
        of tensors with a leading batch dim of 1. NOTE(review): when both
        support dirs are given, get_support_from_dirs returns a different
        shape (stacked tensors, no case) -- callers must handle both.
        """
        if support_image_dir is not None and support_mask_dir:
            return self.get_support_from_dirs(support_image_dir, support_mask_dir, n_support=n_support)
        elif text_file is not None:
            support_image_paths, support_gt_paths = self.get_support_from_text_file(text_file, n_support=n_support)
        else:
            # randomly sample n_support images and masks from the dataset
            # (with replacement -- random.choices may repeat indices)
            indices = random.choices(range(self.size), k=n_support)
            # indices = list(range(n_support))
            print(f"support indices:{indices}")
            support_image_paths = [self.images[index] for index in indices]
            support_gt_paths = [self.gts[index] for index in indices]
        support_images = []
        support_gts = []
        for image_path, gt_path in zip(support_image_paths, support_gt_paths):
            support_img = self.cv2_loader(image_path, is_mask=False)
            support_mask = self.cv2_loader(gt_path, is_mask=True)
            out = self.process_image_gt(support_img, support_mask)
            support_images.append(out['image'].unsqueeze(0))
            support_gts.append(out['label'].unsqueeze(0))
            if len(support_images) >= n_support:
                break
        # NOTE(review): `out` is unbound if n_support < 1 or the path lists
        # are empty -- this would raise NameError here.
        return support_images, support_gts, out['case']
        # return torch.stack(support_images), torch.stack(support_gts), out['case']

    def process_image_gt(self, image, gt, dataset=""):
        """
        image and gt are expected to be output from self.cv2_loader

        Applies augmentations, SAM transform or mean/std normalization,
        binarizes the mask at 0.5, and resizes to self.image_size when no
        sam_trans is set. Returns a dict with 'image', 'label',
        'original_size', 'image_size' and 'case'.
        NOTE(review): `mask` is only bound inside the augmentations branch;
        with augmentations=None this raises NameError -- all callers in this
        file pass augmentations.
        """
        original_size = tuple(image.shape[-2:])
        if self.augmentations:
            image, mask = self.augmentations(image, gt)
            if self.sam_trans:
                image, mask = self.sam_trans.apply_image_torch(
                    image.unsqueeze(0)), self.sam_trans.apply_image_torch(mask)
            elif image.max() <= 255 and image.min() >= 0:
                # Only normalize raw-looking (0..255) images.
                image = (image - self.mean) / self.std
        # Binarize the (possibly interpolated) mask.
        mask[mask > 0.5] = 1
        mask[mask <= 0.5] = 0
        # image_size = tuple(img.shape[-2:])
        image_size = self.image_size
        if self.sam_trans is None:
            image = torch.nn.functional.interpolate(image.unsqueeze(
                0), size=image_size, mode='bilinear', align_corners=False).squeeze(0)
            mask = torch.nn.functional.interpolate(mask.unsqueeze(0).unsqueeze(
                0), size=image_size, mode='nearest').squeeze(0).squeeze(0)
        # img = (img - img.min()) / (img.max() - img.min()) # TODO uncomment this if results get worse
        return {'image': self.sam_trans.preprocess(image).squeeze(0) if self.sam_trans else image,
                'label': self.sam_trans.preprocess(mask) if self.sam_trans else mask,
                'original_size': torch.Tensor(original_size),
                'image_size': torch.Tensor(image_size),
                'case': dataset}  # case to be compatible with polyp video dataset

    def get_dataset_name_from_path(self, path):
        """Return the first dataset name appearing in *path*, or ''. """
        for dataset in self.datasets:
            if dataset in path:
                return dataset
        return ""

    def __getitem__(self, index):
        """Load, process and return the image/mask pair at *index*."""
        image = self.cv2_loader(self.images[index], is_mask=False)
        gt = self.cv2_loader(self.gts[index], is_mask=True)
        dataset = self.get_dataset_name_from_path(self.images[index])
        return self.process_image_gt(image, gt, dataset)

    def filter_files_and_get_ds_mean_and_std(self):
        """
        Drop pairs from excluded datasets or with mismatched image/mask sizes,
        and compute per-dataset mean/std (average of per-image raw 0..255
        statistics) into self.mean / self.std.
        """
        assert len(self.images) == len(self.gts)
        images = []
        gts = []
        ds_mean = 0
        ds_std = 0
        for img_path, gt_path in zip(self.images, self.gts):
            # Skip datasets listed in EXCLUDE_DS.
            if any([ex_ds in img_path for ex_ds in EXCLUDE_DS]):
                continue
            img = Image.open(img_path)
            gt = Image.open(gt_path)
            # Keep only pairs whose image and mask dimensions agree.
            if img.size == gt.size:
                images.append(img_path)
                gts.append(gt_path)
                ds_mean += np.array(img).mean()
                ds_std += np.array(img).std()
        self.images = images
        self.gts = gts
        self.mean = ds_mean / len(self.images)
        self.std = ds_std / len(self.images)

    def rgb_loader(self, path):
        """Load *path* as an RGB PIL image (legacy loader)."""
        with open(path, 'rb') as f:
            img = Image.open(f)
            return img.convert('RGB')

    def binary_loader(self, path):
        """Load *path* as a single-channel (grayscale) numpy array."""
        # with open(path, 'rb') as f:
        #     img = Image.open(f)
        #     return img.convert('1')
        img = cv2.imread(path, 0)
        return img

    def cv2_loader(self, path, is_mask):
        """
        Load *path* with OpenCV: masks are read grayscale and binarized to
        {0, 1}; images are read BGR and converted to RGB.
        """
        if is_mask:
            img = cv2.imread(path, 0)
            img[img > 0] = 1
        else:
            img = cv2.cvtColor(cv2.imread(
                path, cv2.IMREAD_COLOR), cv2.COLOR_BGR2RGB)
        return img

    def resize(self, img, gt):
        """
        Upscale PIL image/mask so both sides are at least self.trainsize
        (legacy path; unused by __getitem__).
        """
        assert img.size == gt.size
        w, h = img.size
        if h < self.trainsize or w < self.trainsize:
            h = max(h, self.trainsize)
            w = max(w, self.trainsize)
            return img.resize((w, h), Image.BILINEAR), gt.resize((w, h), Image.NEAREST)
        else:
            return img, gt

    def __len__(self):
        """Number of image/mask pairs after filtering."""
        # return 32
        return self.size
class SuperpixPolypDataset(PolypDataset):
    """
    PolypDataset variant whose labels are precomputed superpixel maps.

    Ground-truth files are named 'superpix-MIDDLE_<id>.png' with a matching
    foreground mask 'fgmask_<id>.png' in the same directory. __getitem__
    returns a PANet-style episode dict built from one image: the full
    (foreground-restricted) superpixel map as support and one randomly
    chosen superpixel as query label.
    """

    def __init__(self, root, image_root=None, gt_root=None, trainsize=352, augmentations=None, train=True, sam_trans=None, datasets=DATASETS, image_size=(1024, 1024), ds_mean=None, ds_std=None):
        # Mirrors PolypDataset.__init__, except top-level ground-truth files
        # must contain 'superpix' in their name.
        self.trainsize = trainsize
        self.augmentations = augmentations
        self.datasets = datasets
        self.image_size = image_size
        if image_root is not None and gt_root is not None:
            self.images = [
                os.path.join(image_root, f) for f in os.listdir(image_root) if f.endswith('.jpg') or f.endswith('.png')]
            self.gts = [
                os.path.join(gt_root, f) for f in os.listdir(gt_root) if f.endswith('.png') and 'superpix' in f]
            # also look in subdirectories (one level deep)
            for subdir in os.listdir(image_root):
                if not os.path.isdir(os.path.join(image_root, subdir)):
                    continue
                subdir_image_root = os.path.join(image_root, subdir)
                subdir_gt_root = os.path.join(gt_root, subdir)
                self.images.extend([os.path.join(subdir_image_root, f) for f in os.listdir(
                    subdir_image_root) if f.endswith('.jpg') or f.endswith('.png')])
                self.gts.extend([os.path.join(subdir_gt_root, f) for f in os.listdir(
                    subdir_gt_root) if f.endswith('.png')])
        else:
            self.images, self.gts = self.get_image_gt_pairs(
                root, split="train" if train else "test", datasets=self.datasets)
        self.images = sorted(self.images)
        self.gts = sorted(self.gts)
        if not 'VPS' in root:
            self.filter_files_and_get_ds_mean_and_std()
        if ds_mean is not None and ds_std is not None:
            self.mean, self.std = ds_mean, ds_std
        self.size = len(self.images)
        self.train = train
        self.sam_trans = sam_trans
        if self.sam_trans is not None:
            # sam trans takes care of norm
            self.mean, self.std = 0, 1

    def __getitem__(self, index):
        """Return one support/query episode built from the image at *index*."""
        image = self.cv2_loader(self.images[index], is_mask=False)
        # Load the superpixel map as an image (ids), not a binary mask.
        gt = self.cv2_loader(self.gts[index], is_mask=False)
        gt = gt[:, :, 0]
        # Derive the foreground-mask path from the superpixel file name:
        # 'superpix-MIDDLE_<id>.png' -> 'fgmask_<id>.png'.
        fgpath = os.path.basename(self.gts[index]).split('.png')[0].split('superpix-MIDDLE_')
        fgpath = os.path.join(os.path.dirname(self.gts[index]), 'fgmask_' + fgpath[1] + '.png')
        fg = self.cv2_loader(fgpath, is_mask=True)
        dataset = self.get_dataset_name_from_path(self.images[index])
        # BUG FIX: the original `gt[1-fg] = 0` fancy-indexed gt with an
        # integer (uint8) array, zeroing rows 0/1 instead of the background.
        # Use a boolean mask so superpixels outside the foreground are
        # discarded as intended.
        gt[fg == 0] = 0
        # randomly choose a superpixel id from the remaining (foreground) ids;
        # index 0 of np.unique is the background label 0
        sp_id = random.choice(np.unique(gt)[1:])
        sp = (gt == sp_id).astype(np.uint8)
        out = self.process_image_gt(image, gt, dataset)
        support_image, support_sp, dataset = out["image"], out["label"], out["case"]
        out = self.process_image_gt(image, sp, dataset)
        query_image, query_sp, dataset = out["image"], out["label"], out["case"]
        # TODO tile the masks to have 3 channels?
        support_bg_mask = 1 - support_sp
        support_masks = {"fg_mask": support_sp, "bg_mask": support_bg_mask}
        batch = {"support_images": [[support_image]],
                 "support_mask": [[support_masks]],
                 "query_images": [query_image],
                 "query_labels": [query_sp],
                 "scan_id": [dataset]
                 }
        return batch
def get_superpix_polyp_dataset(image_size: tuple = (1024, 1024), sam_trans=None):
    """
    Build the superpixel training dataset for self-supervised training.

    image_size: output (H, W) of images/labels when sam_trans is None.
    sam_trans: optional SAM transform forwarded to the dataset.
    Returns the train SuperpixPolypDataset.
    """
    # Only the train transform is used here; the test transform is discarded.
    transform_train, _ = get_polyp_transform()
    image_root = './data/PolypDataset/TrainDataset/images/'
    gt_root = './data/PolypDataset/TrainDataset/superpixels/'
    ds_train = SuperpixPolypDataset(root=image_root, image_root=image_root, gt_root=gt_root,
                                    augmentations=transform_train,
                                    sam_trans=sam_trans,
                                    image_size=image_size)
    return ds_train
def get_polyp_dataset(image_size, sam_trans=None):
    """
    Build the polyp train and test datasets.

    image_size: output (H, W) forwarded to PolypDataset.
    sam_trans: optional SAM transform forwarded to PolypDataset.
    Returns (train_dataset, test_dataset).
    """
    transform_train, transform_test = get_polyp_transform()
    # NOTE(review): both splits use transform_test (transform_train is
    # unused) -- preserved as-is, but worth confirming it is intentional.
    train_images = './data/PolypDataset/TrainDataset/images/'
    train_masks = './data/PolypDataset/TrainDataset/masks/'
    ds_train = PolypDataset(root=train_images, image_root=train_images,
                            gt_root=train_masks, augmentations=transform_test,
                            sam_trans=sam_trans, train=True,
                            image_size=image_size)
    test_images = './data/PolypDataset/TestDataset/test/images/'
    test_masks = './data/PolypDataset/TestDataset/test/masks/'
    ds_test = PolypDataset(root=test_images, image_root=test_images,
                           gt_root=test_masks, train=False,
                           augmentations=transform_test, sam_trans=sam_trans,
                           image_size=image_size)
    return ds_train, ds_test
def get_tests_polyp_dataset(sam_trans):
    """
    Build the four per-dataset polyp test sets.

    sam_trans: optional SAM transform forwarded to each PolypDataset.
    Returns (Kvasir, CVC-ClinicDB, CVC-ColonDB, ETIS-LaribPolypDB) datasets.
    """
    transform_train, transform_test = get_polyp_transform()

    def _build(name):
        # One test dataset per folder name.
        image_root = f'./data/polyp/TestDataset/{name}/images/'
        gt_root = f'./data/polyp/TestDataset/{name}/masks/'
        # BUG FIX: the original passed (image_root, gt_root) positionally,
        # binding image_root to `root` and gt_root to `image_root` while
        # leaving gt_root=None -- so the explicit-directories branch of
        # PolypDataset.__init__ never ran and no pairs were loaded. Pass
        # keywords, matching get_polyp_dataset.
        return PolypDataset(root=image_root, image_root=image_root,
                            gt_root=gt_root, augmentations=transform_test,
                            train=False, sam_trans=sam_trans)

    ds_Kvasir = _build('Kvasir')
    ds_ClinicDB = _build('CVC-ClinicDB')
    ds_ColonDB = _build('CVC-ColonDB')
    ds_ETIS = _build('ETIS-LaribPolypDB')
    return ds_Kvasir, ds_ClinicDB, ds_ColonDB, ds_ETIS
if __name__ == '__main__':
    # One-off dataset preparation utilities; re-enable the split creation
    # below if the per-dataset split.txt files are missing.
    # create_train_val_test_split_for_polyps()
    create_suppport_set_for_polyps()