Spaces:
Running
Running
| from typing import Dict, List, Optional, Type, Union | |
| from ..datasets.ab_dataset import ABDataset | |
| # from benchmark.data.visualize import visualize_classes_in_object_detection | |
| # from benchmark.scenario.val_domain_shift import get_val_domain_shift_transform | |
| from ..dataset import get_dataset | |
| import copy | |
| from torchvision.transforms import Compose | |
| from ..datasets.registery import static_dataset_registery | |
| from ..build.scenario import Scenario as DAScenario | |
| from copy import deepcopy | |
| from utils.common.log import logger | |
| import random | |
| from .scenario import _ABDatasetMetaInfo, Scenario | |
| def _check(source_datasets_meta_info: List[_ABDatasetMetaInfo], target_datasets_meta_info: List[_ABDatasetMetaInfo]): | |
| # requirements for simplity | |
| # 1. no same class in source datasets | |
| source_datasets_class = [i.classes for i in source_datasets_meta_info] | |
| for ci1, c1 in enumerate(source_datasets_class): | |
| for ci2, c2 in enumerate(source_datasets_class): | |
| if ci1 == ci2: | |
| continue | |
| c1_name = source_datasets_meta_info[ci1].name | |
| c2_name = source_datasets_meta_info[ci2].name | |
| intersection = set(c1).intersection(set(c2)) | |
| assert len(intersection) == 0, f'{c1_name} has intersection with {c2_name}: {intersection}' | |
def build_cl_scenario(
    da_scenario: DAScenario,
    target_datasets_name: List[str],
    num_classes_per_task: int,
    max_num_tasks: int,
    data_dirs,
    sanity_check=False
):
    """Build a continual-learning ``Scenario`` on top of an existing DA scenario.

    Each target dataset is split into tasks of ``num_classes_per_task`` classes,
    the resulting task sequence is truncated to ``max_num_tasks``, and target
    class indices are numbered starting right after the largest source class
    index used by ``da_scenario``.

    Args:
        da_scenario: existing domain-adaptation scenario; its
            ``config['source_datasets_name']`` and ``all_datasets_idx_map`` are
            read to compute the target class-index offset.
        target_datasets_name: names of target datasets, looked up in
            ``static_dataset_registery``. A name repeated later in the list
            reuses (aliases) the task splits generated for its first occurrence.
        num_classes_per_task: number of classes grouped into each CL task.
        max_num_tasks: minimum required and final number of tasks.
        data_dirs: forwarded to ``Scenario`` unchanged.
        sanity_check: when True, walk through all tasks once and replace any
            task whose training set is "too small" with a copy of an earlier task.

    Returns:
        A freshly constructed ``Scenario``.

    Raises:
        RuntimeError: if fewer than ``max_num_tasks`` tasks were generated.
    """
    # Must stay the first statement: locals() here captures exactly the
    # function arguments as the scenario config.
    config = deepcopy(locals())

    # Find, for each source dataset, its class-index map in the DA scenario
    # (matched by name prefix), and track the largest source class index so
    # target classes can be numbered after it.
    source_datasets_idx_map = {}
    source_class_idx_max = 0
    for sd in da_scenario.config['source_datasets_name']:
        da_scenario_idx_map = None
        for k, v in da_scenario.all_datasets_idx_map.items():
            if k.startswith(sd):
                da_scenario_idx_map = v
                break
        # NOTE(review): if no key matches, da_scenario_idx_map stays None and
        # the max(...) below raises AttributeError — presumably every source
        # dataset is guaranteed a map entry; confirm against DAScenario.
        source_datasets_idx_map[sd] = da_scenario_idx_map
        source_class_idx_max = max(source_class_idx_max, max(list(da_scenario_idx_map.values())))
    target_class_idx_start = source_class_idx_max + 1

    # Resolve dataset meta info from the static registry; the two trailing
    # None args fill the fields that are only known per-task (set below).
    target_datasets_meta_info = [_ABDatasetMetaInfo(d, *static_dataset_registery[d][1:], None, None) for d in target_datasets_name]

    task_datasets_seq = []
    num_tasks_per_dataset = {}  # dataset index -> number of tasks it produced
    for td_info_i, td_info in enumerate(target_datasets_meta_info):
        if td_info_i >= 1:
            # A dataset name seen before: re-append (alias, not copy) the task
            # slice that its first occurrence generated.
            for _td_info_i, _td_info in enumerate(target_datasets_meta_info[0: td_info_i]):
                if _td_info.name == td_info.name:
                    print(len(task_datasets_seq))
                    # Offset = total tasks of all datasets before the matched one.
                    task_index_offset = sum([v if __i < _td_info_i else 0 for __i, v in enumerate(num_tasks_per_dataset.values())])
                    task_datasets_seq += task_datasets_seq[task_index_offset: task_index_offset + num_tasks_per_dataset[_td_info_i]]
                    print(len(task_datasets_seq))
                    break
            # NOTE(review): this continue is unconditional, so any dataset after
            # the first that does NOT repeat an earlier name contributes no
            # tasks at all — looks like the intended usage repeats one dataset
            # name; confirm with callers. (Indentation reconstructed from a
            # flattened paste.)
            continue

        td_classes = td_info.classes
        num_tasks_per_dataset[td_info_i] = 0

        # Carve the dataset's classes into consecutive chunks of
        # num_classes_per_task; each chunk becomes one CL task.
        for ci in range(0, len(td_classes), num_classes_per_task):
            task_i = ci // num_classes_per_task
            task_datasets_seq += [_ABDatasetMetaInfo(
                # Task name encodes dataset, task index and class-index range.
                f'{td_info.name}|task-{task_i}|ci-{ci}-{ci + num_classes_per_task - 1}',
                td_classes[ci: ci + num_classes_per_task],
                td_info.task_type,
                td_info.object_type,
                td_info.class_aliases,
                td_info.shift_type,
                # Classes outside this chunk are the "ignored" classes.
                td_classes[:ci] + td_classes[ci + num_classes_per_task: ],
                # Global label map: local class index -> offset global index.
                {cii: cii + target_class_idx_start for cii in range(ci, ci + num_classes_per_task)}
            )]
            num_tasks_per_dataset[td_info_i] += 1

            # NOTE(review): for every non-final chunk this also appends an
            # extra task spanning from ci to the end of the class list, so
            # these "tail" tasks overlap each other and the chunk tasks —
            # possibly intentional padding before truncation to max_num_tasks,
            # but confirm; also note the '-task-' vs '|task-' name separator.
            if ci + num_classes_per_task < len(td_classes) - 1:
                task_datasets_seq += [_ABDatasetMetaInfo(
                    f'{td_info.name}-task-{task_i + 1}|ci-{ci}-{ci + num_classes_per_task - 1}',
                    td_classes[ci: len(td_classes)],
                    td_info.task_type,
                    td_info.object_type,
                    td_info.class_aliases,
                    td_info.shift_type,
                    td_classes[:ci],
                    {cii: cii + target_class_idx_start for cii in range(ci, len(td_classes))}
                )]
                num_tasks_per_dataset[td_info_i] += 1

        # Next dataset's classes get fresh global indices after this one's.
        target_class_idx_start += len(td_classes)

    if len(task_datasets_seq) < max_num_tasks:
        print(len(task_datasets_seq), max_num_tasks)
        raise RuntimeError()
    # Keep exactly max_num_tasks tasks; recompute the total number of global
    # class indices actually used by the kept tasks.
    task_datasets_seq = task_datasets_seq[0: max_num_tasks]
    target_class_idx_start = max([max(list(td.idx_map.values())) + 1 for td in task_datasets_seq])

    scenario = Scenario(config, task_datasets_seq, target_class_idx_start, source_class_idx_max + 1, data_dirs)

    if sanity_check:
        # Walk every task once; tasks with fewer than 300 training samples are
        # replaced (in task_datasets_seq) by a deep copy of an earlier task.
        selected_tasks_index = []
        for task_index, _ in enumerate(scenario.target_tasks_order):
            cur_datasets = scenario.get_cur_task_train_datasets()
            if len(cur_datasets) < 300:
                # Deterministic replacement choice (halved index) instead of a
                # random pick, to keep runs reproducible.
                replaced_task_index = task_index // 2  # ensure no random
                # NOTE(review): fails for task_index == 0 (0 // 2 == 0) —
                # presumably the first task is never undersized; confirm.
                assert replaced_task_index != task_index
                # Avoid reusing a replacement task twice.
                while replaced_task_index in selected_tasks_index:
                    replaced_task_index += 1
                task_datasets_seq[task_index] = deepcopy(task_datasets_seq[replaced_task_index])
                selected_tasks_index += [replaced_task_index]
                logger.warning(f'replace {task_index}-th task with {replaced_task_index}-th task')
            scenario.next_task()

        if len(selected_tasks_index) > 0:
            # Some tasks were swapped: rebuild the scenario from the patched
            # sequence and re-verify that every task now has training data.
            target_class_idx_start = max([max(list(td.idx_map.values())) + 1 for td in task_datasets_seq])
            scenario = Scenario(config, task_datasets_seq, target_class_idx_start, source_class_idx_max + 1, data_dirs)
            for task_index, _ in enumerate(scenario.target_tasks_order):
                cur_datasets = scenario.get_cur_task_train_datasets()
                logger.info(f'task {task_index}, len {len(cur_datasets)}')
                assert len(cur_datasets) > 0
                scenario.next_task()

        # Rebuild once more so the returned scenario starts at task 0 (the
        # loops above advanced the previous instance past its last task).
        scenario = Scenario(config, task_datasets_seq, target_class_idx_start, source_class_idx_max + 1, data_dirs)

    return scenario