Spaces:
Sleeping
Sleeping
| """ | |
| Customized dataset. Extended from vanilla PANet script by Wang et al. | |
| """ | |
| import os | |
| import random | |
| import torch | |
| import numpy as np | |
| from dataloaders.common import ReloadPairedDataset, ValidationDataset | |
| from dataloaders.ManualAnnoDatasetv2 import ManualAnnoDataset | |
def attrib_basic(_sample, class_id):
    """Attach the basic attribute dict to a data sample.

    Args:
        _sample: data sample (unused; kept to satisfy the attribute-callback
            interface of the dataset's ``add_attrib``)
        class_id: class label associated with the data
            (sometimes indicating from which subset the data are drawn)

    Returns:
        dict mapping ``'class_id'`` to the given label
    """
    return dict(class_id=class_id)
def getMaskOnly(label, class_id, class_ids):
    """
    Generate FG/BG masks from a semantic segmentation mask.

    Args:
        label: semantic mask (integer tensor)
        class_id: semantic class of interest
        class_ids: all class ids in this episode; pixels of any of these
            classes are excluded from the background mask

    Returns:
        dict with binary 'fg_mask' and 'bg_mask' tensors shaped like `label`
    """
    # Foreground: exactly the pixels of the class of interest.
    fg_mask = torch.where(label == class_id,
                          torch.ones_like(label), torch.zeros_like(label))
    # Background starts as everything that is not the class of interest ...
    bg_mask = torch.where(label != class_id,
                          torch.ones_like(label), torch.zeros_like(label))
    # ... then drop pixels of every episode class, so other foreground classes
    # are not counted as background. (Fixed: the loop variable used to reuse
    # the name `class_id`, shadowing the parameter above.)
    for cid in class_ids:
        bg_mask[label == cid] = 0
    return {'fg_mask': fg_mask,
            'bg_mask': bg_mask}
def getMasks(*args, **kwargs):
    """Scribble/instance-based mask generation; not supported in this variant.

    Raises:
        NotImplementedError: always; only ``getMaskOnly`` is implemented here.
    """
    raise NotImplementedError
def fewshot_pairing(paired_sample, n_ways, n_shots, cnt_query, coco=False, mask_only = True):
    """
    Postprocess paired sample for few-shot settings.

    For now only 1-way is tested but multi-way is left possible (inherited
    from the original PANet implementation).

    Args:
        paired_sample:
            data sample from a PairedDataset; flat list of dicts, grouped per
            way as <n_shots> supports followed by <cnt_query[way]> queries
        n_ways:
            n-way few-shot learning
        n_shots:
            n-shot few-shot learning
        cnt_query:
            number of query images for each class in the support set
        coco:
            MS COCO dataset flag, kept from the original PANet code for
            further extension; selects per-class label dicts
        mask_only:
            only produce masks, no scribbles/instances. Suitable for medical
            images. Must be True — the scribble path is unimplemented.

    Returns:
        dict with support images/masks, query images/labels/masks and
        episode class-index bookkeeping
    """
    if not mask_only:
        raise NotImplementedError

    ###### Compose the support and query image list ######
    # separation points between supports and queries of each way
    cumsum_idx = np.cumsum([0,] + [n_shots + x for x in cnt_query])

    # support class ids (one per way)
    class_ids = [paired_sample[cumsum_idx[i]]['basic_class_id'] for i in range(n_ways)]

    # support images for each class
    support_images = [[paired_sample[cumsum_idx[i] + j]['image'] for j in range(n_shots)]
                      for i in range(n_ways)]

    # support image labels
    if coco:
        support_labels = [[paired_sample[cumsum_idx[i] + j]['label'][class_ids[i]]
                           for j in range(n_shots)] for i in range(n_ways)]
    else:
        support_labels = [[paired_sample[cumsum_idx[i] + j]['label'] for j in range(n_shots)]
                          for i in range(n_ways)]

    if not mask_only:  # unreachable today (guarded above); kept for extension
        support_scribbles = [[paired_sample[cumsum_idx[i] + j]['scribble'] for j in range(n_shots)]
                             for i in range(n_ways)]
        support_insts = [[paired_sample[cumsum_idx[i] + j]['inst'] for j in range(n_shots)]
                         for i in range(n_ways)]
    else:
        # BUGFIX: support_scribbles was left undefined on this branch, so the
        # return statement below raised a NameError on the only reachable path.
        support_scribbles = []
        support_insts = []

    # query images, masks and class indices (queries sit at the tail of each way)
    query_images = [paired_sample[cumsum_idx[i+1] - j - 1]['image'] for i in range(n_ways)
                    for j in range(cnt_query[i])]
    if coco:
        query_labels = [paired_sample[cumsum_idx[i+1] - j - 1]['label'][class_ids[i]]
                        for i in range(n_ways) for j in range(cnt_query[i])]
    else:
        query_labels = [paired_sample[cumsum_idx[i+1] - j - 1]['label'] for i in range(n_ways)
                        for j in range(cnt_query[i])]
    # per-query sorted episode class indices present in that query (0 = BG)
    query_cls_idx = [sorted([0,] + [class_ids.index(x) + 1
                                    for x in set(np.unique(query_label)) & set(class_ids)])
                     for query_label in query_labels]

    ###### Generate support image masks ######
    if not mask_only:
        support_mask = [[getMasks(support_labels[way][shot], support_scribbles[way][shot],
                                  class_ids[way], class_ids)
                         for shot in range(n_shots)] for way in range(n_ways)]
    else:
        support_mask = [[getMaskOnly(support_labels[way][shot],
                                     class_ids[way], class_ids)
                         for shot in range(n_shots)] for way in range(n_ways)]

    ###### Generate query label (class indices in one episode, i.e. the ground truth) ######
    query_labels_tmp = [torch.zeros_like(x) for x in query_labels]
    for i, query_label_tmp in enumerate(query_labels_tmp):
        query_label_tmp[query_labels[i] == 255] = 255  # keep the ignore label
        for j in range(n_ways):
            query_label_tmp[query_labels[i] == class_ids[j]] = j + 1

    ###### Generate query mask for each semantic class (including BG) ######
    # BG class
    query_masks = [[torch.where(query_label == 0,
                                torch.ones_like(query_label),
                                torch.zeros_like(query_label))[None, ...],]
                   for query_label in query_labels]
    # Other classes in query image
    for i, query_label in enumerate(query_labels):
        for idx in query_cls_idx[i][1:]:
            mask = torch.where(query_label == class_ids[idx - 1],
                               torch.ones_like(query_label),
                               torch.zeros_like(query_label))[None, ...]
            query_masks[i].append(mask)

    return {'class_ids': class_ids,
            'support_images': support_images,
            'support_mask': support_mask,
            'support_inst': support_insts,  # leave these interfaces
            'support_scribbles': support_scribbles,
            'query_images': query_images,
            'query_labels': query_labels_tmp,
            'query_masks': query_masks,
            'query_cls_idx': query_cls_idx,
            }
def med_fewshot(dataset_name, base_dir, idx_split, mode, scan_per_load,
                transforms, act_labels, n_ways, n_shots, max_iters_per_load,
                min_fg = '', n_queries=1, fix_parent_len = None, exclude_list = None, **kwargs):
    """
    Dataset wrapper for episodic few-shot training on medical images.

    Args:
        dataset_name:
            indicates what dataset to use
        base_dir:
            dataset directory
        idx_split:
            index of split
        mode:
            which mode to use
            choose from ('train', 'val', 'trainval', 'trainaug')
        scan_per_load:
            number of scans to load into memory as the dataset is large
            use that together with reload_buffer
        transforms:
            transformations to be performed on images/masks
        act_labels:
            active labels involved in training process. Should be a subset of all labels
        n_ways:
            n-way few-shot learning, should be no more than # of object class labels
        n_shots:
            n-shot few-shot learning
        max_iters_per_load:
            number of pairs per load (epoch size)
        min_fg:
            minimum-foreground requirement forwarded to the dataset
        n_queries:
            number of query images
        fix_parent_len:
            fixed length of the parent dataset
        exclude_list:
            scan ids to exclude (default: none)

    Returns:
        (paired episodic dataset, underlying parent dataset)
    """
    # BUGFIX: the default used to be the mutable literal `[]`, which is shared
    # across calls; use a None sentinel instead (backward compatible).
    if exclude_list is None:
        exclude_list = []

    med_set = ManualAnnoDataset
    mydataset = med_set(which_dataset=dataset_name, base_dir=base_dir, idx_split=idx_split,
                        mode=mode, scan_per_load=scan_per_load, transforms=transforms,
                        min_fg=min_fg, fix_length=fix_parent_len,
                        exclude_list=exclude_list, **kwargs)
    mydataset.add_attrib('basic', attrib_basic, {})

    # Create sub-datasets and add class_id attribute. Here the class file is
    # internally loaded and reloaded inside
    subsets = mydataset.subsets([{'basic': {'class_id': ii}}
                                 for ii, _ in enumerate(mydataset.label_name)])

    # Choose the classes of queries: number of queries drawn for each way
    cnt_query = np.bincount(random.choices(population=range(n_ways), k=n_queries), minlength=n_ways)

    # Set the number of images for each class: <n_shots> supports + cnt_query[i] queries
    n_elements = [n_shots + x for x in cnt_query]

    # Create paired dataset. We do not include background.
    paired_data = ReloadPairedDataset([subsets[ii] for ii in act_labels],
                                      n_elements=n_elements,
                                      curr_max_iters=max_iters_per_load,
                                      pair_based_transforms=[
                                          (fewshot_pairing, {'n_ways': n_ways, 'n_shots': n_shots,
                                                             'cnt_query': cnt_query, 'mask_only': True})])
    return paired_data, mydataset
def update_loader_dset(loader, parent_set):
    """
    Update a data loader and the parent dataset behind it.

    The parent dataset reloads its in-memory buffer of scans, then the
    loader's (paired) dataset rebuilds its pairing indices accordingly.

    Args:
        loader: actual dataloader; its `.dataset` must expose `update_index()`
        parent_set: parent dataset which actually stores the data; must
            expose `reload_buffer()`
    """
    parent_set.reload_buffer()
    loader.dataset.update_index()
    # plain string literal: the former f-string had no placeholders (F541)
    print('###### Loader and dataset have been updated ######')
def med_fewshot_val(dataset_name, base_dir, idx_split, scan_per_load, act_labels, npart, fix_length = None, nsup = 1, transforms=None, mode='val', **kwargs):
    """
    Build the validation set for medical images.

    Args:
        dataset_name:
            indicates what dataset to use
        base_dir:
            SABS dataset directory
        idx_split:
            index of split
        scan_per_load:
            number of scans to load into memory as the dataset is large
            use that together with reload_buffer
        act_labels:
            actual labels involved in training process. Should be a subset of all labels
        npart: number of chunks for splitting a 3d volume
        fix_length: fixed length of the underlying dataset
        nsup: number of support scans, equivalent to nshot
        transforms: transformations to be performed on images/masks
        mode: (original split) which split to use
            choose from ('train', 'val', 'trainval', 'trainaug')

    Returns:
        (validation dataset wrapper, underlying parent dataset)
    """
    parent = ManualAnnoDataset(which_dataset=dataset_name,
                               base_dir=base_dir,
                               idx_split=idx_split,
                               mode=mode,
                               scan_per_load=scan_per_load,
                               transforms=transforms,
                               min_fg=1,
                               fix_length=fix_length,
                               nsup=nsup,
                               **kwargs)
    parent.add_attrib('basic', attrib_basic, {})
    valset = ValidationDataset(parent, test_classes=act_labels, npart=npart)
    return valset, parent