# NOTE: this file was captured from a web page (Hugging Face Spaces, status
# "Runtime error"); the status banner lines were converted to this comment.
import json
import math
import os
import pickle
import random

import imageio
import numpy as np
import torch
import torch.distributed as dist
from PIL import Image
from torch.utils.data import BatchSampler, Dataset
from torch.utils.data._utils.collate import default_collate
from torchvision import transforms

from datasets import register
class ImageFolder(Dataset):
    """Dataset over a flat folder of image (or mask) files.

    Args:
        path: directory holding the files.
        split_file: optional JSON file whose ``split_key`` entry lists filenames.
        split_key: key into the ``split_file`` JSON dict.
        first_k: if given, keep only the first k filenames.
        size: target square size used by ``img_transform``.
        repeat: virtual length multiplier (stretches an epoch).
        cache: 'none' loads lazily in ``__getitem__``; 'in_memory' preloads.
        mask: treat files as segmentation masks (nearest resize, no normalize).
    """

    def __init__(self, path, split_file=None, split_key=None, first_k=None,
                 size=None, repeat=1, cache='none', mask=False):
        self.repeat = repeat
        self.cache = cache
        self.path = path
        self.Train = False
        self.split_key = split_key
        self.size = size
        self.mask = mask

        if self.mask:
            # Nearest interpolation so mask label values are not blended.
            self.img_transform = transforms.Compose([
                transforms.Resize((self.size, self.size), interpolation=Image.NEAREST),
                transforms.ToTensor(),
            ])
        else:
            # ImageNet normalization statistics.
            self.img_transform = transforms.Compose([
                transforms.Resize((self.size, self.size)),
                transforms.ToTensor(),
                transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])
            ])

        if split_file is None:
            filenames = sorted(os.listdir(path))
        else:
            with open(split_file, 'r') as f:
                filenames = json.load(f)[split_key]
        if first_k is not None:
            filenames = filenames[:first_k]

        self.files = []
        for name in filenames:
            self.append_file(os.path.join(path, name))

    def append_file(self, file):
        # 'in_memory' stores the processed sample up front; 'none' keeps the path.
        if self.cache == 'none':
            self.files.append(file)
        elif self.cache == 'in_memory':
            self.files.append(self.img_process(file))

    def __len__(self):
        return len(self.files) * self.repeat

    def __getitem__(self, idx):
        item = self.files[idx % len(self.files)]
        if self.cache == 'none':
            return self.img_process(item)
        elif self.cache == 'in_memory':
            return item

    def img_process(self, file):
        # Masks are returned as raw file paths (decoded downstream);
        # regular images are decoded to RGB PIL images.
        if self.mask:
            return file
        return Image.open(file).convert('RGB')
class PairedImageFolders(Dataset):
    """Zip an image folder with its corresponding mask folder.

    ``root_path_1`` holds the input images, ``root_path_2`` the masks; both
    folders are expected to pair up index-for-index.
    """

    def __init__(self, root_path_1, root_path_2, **kwargs):
        self.dataset_1 = ImageFolder(root_path_1, **kwargs)
        # The second folder is always loaded in mask mode.
        self.dataset_2 = ImageFolder(root_path_2, mask=True, **kwargs)

    def __len__(self):
        return len(self.dataset_1)

    def __getitem__(self, idx):
        # (image_sample, mask_sample) at the same index.
        return self.dataset_1[idx], self.dataset_2[idx]
class ImageFolder_multi_task(Dataset):
    """Flat-folder image/mask dataset used by the multi-task pipeline.

    Behaves exactly like ``ImageFolder`` (duplicated here for the multi-task
    variant). See ``__init__`` parameters below.

    Args:
        path: directory holding the files.
        split_file: optional JSON file whose ``split_key`` entry lists filenames.
        split_key: key into the ``split_file`` JSON dict.
        first_k: if given, keep only the first k filenames.
        size: target square size used by ``img_transform``.
        repeat: virtual length multiplier (stretches an epoch).
        cache: 'none' loads lazily; 'in_memory' preloads processed samples.
        mask: treat files as segmentation masks (nearest resize, no normalize).
    """

    def __init__(self, path, split_file=None, split_key=None, first_k=None,
                 size=None, repeat=1, cache='none', mask=False):
        self.repeat = repeat
        self.cache = cache
        self.path = path
        self.Train = False
        self.split_key = split_key
        self.size = size
        self.mask = mask

        if self.mask:
            # Nearest-neighbour keeps discrete label values intact.
            self.img_transform = transforms.Compose([
                transforms.Resize((self.size, self.size), interpolation=Image.NEAREST),
                transforms.ToTensor(),
            ])
        else:
            # Standard ImageNet mean/std normalization.
            self.img_transform = transforms.Compose([
                transforms.Resize((self.size, self.size)),
                transforms.ToTensor(),
                transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])
            ])

        if split_file is None:
            filenames = sorted(os.listdir(path))
        else:
            with open(split_file, 'r') as f:
                filenames = json.load(f)[split_key]
        if first_k is not None:
            filenames = filenames[:first_k]

        self.files = []
        for name in filenames:
            self.append_file(os.path.join(path, name))

    def append_file(self, file):
        # Cache policy decides whether the path or the processed sample is kept.
        if self.cache == 'none':
            self.files.append(file)
        elif self.cache == 'in_memory':
            self.files.append(self.img_process(file))

    def __len__(self):
        return len(self.files) * self.repeat

    def __getitem__(self, idx):
        entry = self.files[idx % len(self.files)]
        if self.cache == 'none':
            return self.img_process(entry)
        elif self.cache == 'in_memory':
            return entry

    def img_process(self, file):
        # Mask mode hands back the path; image mode decodes to RGB.
        if self.mask:
            return file
        return Image.open(file).convert('RGB')
class PairedImageFolders_multi_task(Dataset):
    """Pair a multi-task image folder with its mask folder, index-aligned.

    ``model`` is accepted for signature compatibility but is not used here.
    """

    def __init__(self, root_path_1, root_path_2, model=None, **kwargs):
        self.dataset_1 = ImageFolder_multi_task(root_path_1, **kwargs)
        # Second root is always interpreted as masks.
        self.dataset_2 = ImageFolder_multi_task(root_path_2, mask=True, **kwargs)

    def __len__(self):
        return len(self.dataset_1)

    def __getitem__(self, idx):
        # (image_sample, mask_sample) at the same index.
        return self.dataset_1[idx], self.dataset_2[idx]
| # class MultiTaskDataset(Dataset): | |
| # """ | |
# usage example:
| # train_datasets = [SemData_Single(), SemData_Single()] | |
| # multi_task_train_dataset = MultiTaskDataset(train_datasets) | |
| # multi_task_batch_sampler = MultiTaskBatchSampler(train_datasets, batch_size=4, mix_opt=0, extra_task_ratio=0, drop_last=True) | |
| # multi_task_train_data = DataLoader(multi_task_train_dataset, batch_sampler=multi_task_batch_sampler) | |
| # for i, (task_id, input, target) in enumerate(multi_task_train_data): | |
| # pre = model(input) | |
| # """ | |
| # def __init__(self, datasets_image, datasets_gt): | |
| # self._datasets = datasets_image | |
| # task_id_2_image_set_dic = {} | |
| # for i, dataset in enumerate(datasets_image): | |
| # task_id = i | |
| # assert task_id not in task_id_2_image_set_dic, "Duplicate task_id %s" % task_id | |
| # task_id_2_image_set_dic[task_id] = dataset | |
| # self.datasets_1 = task_id_2_image_set_dic | |
| # | |
| # task_id_2_gt_set_dic = {} | |
| # for i, dataset in enumerate(datasets_gt): | |
| # task_id = i | |
| # assert task_id not in task_id_2_gt_set_dic, "Duplicate task_id %s" % task_id | |
| # task_id_2_gt_set_dic[task_id] = dataset | |
| # self.dataset_2 = task_id_2_gt_set_dic | |
| # | |
| # | |
| # def __len__(self): | |
| # return sum(len(dataset) for dataset in self._datasets) | |
| # | |
| # def __getitem__(self, idx): | |
| # task_id, sample_id = idx | |
| # # return self._task_id_2_data_set_dic[task_id][sample_id] | |
| # return self.dataset_1[task_id][sample_id], self.dataset_2[task_id][sample_id] | |
class MultiTaskDataset(Dataset):
    """Wrap several per-task datasets behind one (task_id, sample_id) index.

    Usage example::

        train_datasets = [SemData_Single(), SemData_Single()]
        multi_task_train_dataset = MultiTaskDataset(train_datasets)
        multi_task_batch_sampler = MultiTaskBatchSampler(
            train_datasets, batch_size=4, mix_opt=0, extra_task_ratio=0, drop_last=True)
        multi_task_train_data = DataLoader(
            multi_task_train_dataset, batch_sampler=multi_task_batch_sampler)
        for i, (task_id, input, target) in enumerate(multi_task_train_data):
            pre = model(input)
    """

    def __init__(self, datasets):
        self._datasets = datasets
        registry = {}
        # Task ids are the positions in the input list.
        for task_id, dataset in enumerate(datasets):
            assert task_id not in registry, "Duplicate task_id %s" % task_id
            registry[task_id] = dataset
        self._task_id_2_data_set_dic = registry

    def __len__(self):
        # Total samples across every task.
        return sum(map(len, self._datasets))

    def __getitem__(self, idx):
        # idx is a (task_id, sample_id) pair produced by the batch sampler.
        task_id, sample_id = idx
        return self._task_id_2_data_set_dic[task_id][sample_id]
def collate_fn(batch):
    """Collate ``(input, target)`` samples whose input may be a tuple with
    ``None`` placeholders.

    If the first sample's input is not a tuple, this defers entirely to
    :func:`default_collate`. Otherwise each tuple position is collated
    independently across the batch; a position that is ``None`` in the first
    sample is kept as a single ``None`` placeholder (it is assumed ``None``
    for every sample in the batch — TODO confirm against the dataset).

    Args:
        batch: non-empty list of ``(input, target)`` samples.

    Returns:
        ``default_collate(batch)``, or a list
        ``[collated_pos_0, ..., collated_pos_k, collated_targets]``.
    """
    # NOTE: a leftover debug print of the whole batch was removed here.
    if not isinstance(batch[0][0], tuple):
        return default_collate(batch)

    batch_num = len(batch)
    ret = []
    for item_idx in range(len(batch[0][0])):
        if batch[0][0][item_idx] is None:
            # Position absent for this task: propagate the placeholder.
            ret.append(None)
        else:
            ret.append(default_collate([batch[i][0][item_idx] for i in range(batch_num)]))
    # Targets are always collated as the final entry.
    ret.append(default_collate([batch[i][1] for i in range(batch_num)]))
    return ret
class DistrubutedMultiTaskBatchSampler(BatchSampler):
    """Distributed batch sampler over several per-task datasets.

    Every yielded batch is a list of ``(task_id, sample_id)`` pairs drawn from
    a single task, suitable for ``MultiTaskDataset``. Batches are partitioned
    across replicas so each rank sees a disjoint, equally-sized slice.

    Args:
        datasets: sequence of datasets (one per task); only ``len()`` is used.
        batch_size (int): samples per batch.
        num_replicas (int): number of distributed processes (world size);
            ``None`` reads it from ``torch.distributed``.
        rank (int): rank of the current process; ``None`` reads it from
            ``torch.distributed``.
        mix_opt (int): 0 shuffles all tasks together, 1 shuffles extra tasks
            (currently only stored).
        extra_task_ratio (float, optional): rate between task one and extra
            tasks (currently only stored).
        drop_last (bool, optional): ``True`` drops each task's trailing
            incomplete batch, and drops trailing batches so every replica
            gets the same count; ``False`` pads instead. (default: ``True``)
        shuffle (bool, optional): shuffle batches each epoch, seeded by
            ``seed + epoch`` so all replicas agree. (default: ``True``)

    Fixes over the previous revision:
      * ``_get_index_batches`` / ``_gen_task_indices`` were instance methods
        missing ``self`` and crashed with a TypeError when called — now
        ``@staticmethod``.
      * the shuffle now uses a private RNG seeded with ``seed + epoch`` so
        every replica produces the same order (``set_epoch`` now matters).
      * batches are paired with their task id *before* rank-slicing, so
        replicas no longer receive duplicate batches.
      * the index list is padded / truncated to ``total_size`` as in
        ``torch.utils.data.DistributedSampler``.
    """

    def __init__(self, datasets, batch_size, num_replicas, rank, mix_opt=0,
                 extra_task_ratio=0, drop_last=True, shuffle=True):
        if num_replicas is None:
            if not dist.is_available():
                raise RuntimeError("Requires distributed package to be available")
            num_replicas = dist.get_world_size()
        if rank is None:
            if not dist.is_available():
                raise RuntimeError("Requires distributed package to be available")
            rank = dist.get_rank()
        if rank >= num_replicas or rank < 0:
            raise ValueError(
                "Invalid rank {}, rank should be in the interval"
                " [0, {}]".format(rank, num_replicas - 1))
        self.num_replicas = num_replicas
        self.rank = rank
        assert mix_opt in [0, 1], 'mix_opt must equal 0 or 1'
        assert extra_task_ratio >= 0, 'extra_task_ratio must greater than 0'
        self._batch_size = batch_size
        self._mix_opt = mix_opt
        self._extra_task_ratio = extra_task_ratio
        self._drop_last = drop_last
        self.shuffle = shuffle

        # One entry per task: that task's list of per-batch index lists.
        train_data_list = []
        for dataset in datasets:
            train_data_list.append(
                self._get_index_batches(len(dataset), batch_size, self._drop_last))
        self._train_data_list = train_data_list
        self.total_len = sum(len(train_data) for train_data in self._train_data_list)

        # Per-replica batch count, mirroring DistributedSampler's arithmetic.
        if self._drop_last and self.total_len % self.num_replicas != 0:
            self.num_samples = math.ceil(
                (self.total_len - self.num_replicas) / self.num_replicas)
        else:
            self.num_samples = math.ceil(self.total_len / self.num_replicas)
        self.total_size = self.num_samples * self.num_replicas
        self.epoch = 0
        self.seed = 0

    def set_epoch(self, epoch):
        """Set the epoch so every replica reshuffles identically each epoch."""
        self.epoch = epoch

    @staticmethod
    def _get_index_batches(dataset_len, batch_size, drop_last):
        """Split ``range(dataset_len)`` into batch-sized index lists."""
        index = list(range(dataset_len))
        if drop_last and dataset_len % batch_size:
            del index[-(dataset_len % batch_size):]
        return [index[i:i + batch_size] for i in range(0, len(index), batch_size)]

    def __len__(self):
        return self.num_samples

    def __iter__(self):
        # Pair every batch with its task id up front so that shuffling and
        # rank-slicing operate on concrete batches (no duplicates per rank).
        all_batches = [
            (task_idx, batch)
            for task_idx, task_batches in enumerate(self._train_data_list)
            for batch in task_batches
        ]
        if self.shuffle:
            # Deterministic, epoch-dependent, identical on every replica.
            rng = random.Random(self.seed + self.epoch)
            rng.shuffle(all_batches)

        # Pad or truncate to an even multiple of num_replicas.
        if len(all_batches) < self.total_size:
            padding_size = self.total_size - len(all_batches)
            if padding_size <= len(all_batches):
                all_batches += all_batches[:padding_size]
            else:
                all_batches += (all_batches * math.ceil(
                    padding_size / len(all_batches)))[:padding_size]
        else:
            all_batches = all_batches[:self.total_size]

        # Interleaved slice: this rank's share.
        all_batches = all_batches[self.rank:self.total_size:self.num_replicas]
        assert len(all_batches) == self.num_samples

        for task_idx, batch in all_batches:
            yield [(task_idx, sample_id) for sample_id in batch]

    @staticmethod
    def _gen_task_indices(train_data_list, mix_opt, extra_task_ratio):
        """Return one task-id entry per batch (kept for compatibility)."""
        all_indices = []
        for i in range(len(train_data_list)):
            all_indices += [i] * len(train_data_list[i])
        return all_indices
| # def set_epoch(self, epoch) | |
| # r""" | |
| # Sets the epoch for this sampler. When :attr:`shuffle=True`, this ensures all replicas | |
| # use a different random ordering for each epoch. Otherwise, the next iteration of this | |
| # sampler will yield the same ordering. | |
| # Args: | |
| # epoch (int): Epoch number. | |
| # """ | |
| # self.epoch = epoch | |