Spaces:
Running
Running
| import enum | |
| from functools import reduce | |
| from typing import Dict, List, Tuple | |
| import numpy as np | |
| import copy | |
| from utils.common.log import logger | |
| from ..datasets.ab_dataset import ABDataset | |
| from ..dataloader import FastDataLoader, InfiniteDataLoader, build_dataloader | |
| from data import get_dataset, MergedDataset, Scenario as DAScenario | |
| class _ABDatasetMetaInfo: | |
| def __init__(self, name, classes, task_type, object_type, class_aliases, shift_type, ignore_classes, idx_map): | |
| self.name = name | |
| self.classes = classes | |
| self.class_aliases = class_aliases | |
| self.shift_type = shift_type | |
| self.task_type = task_type | |
| self.object_type = object_type | |
| self.ignore_classes = ignore_classes | |
| self.idx_map = idx_map | |
| def __repr__(self) -> str: | |
| return f'({self.name}, {self.classes}, {self.idx_map})' | |
class Scenario:
    """A class-incremental learning scenario over an ordered list of target tasks.

    Tracks the current task index and the class-index offset of that task, and
    builds train/val/test datasets for the current task on demand.
    """

    def __init__(self, config, target_datasets_info: 'List[_ABDatasetMetaInfo]',
                 num_classes: int, num_source_classes: int, data_dirs):
        """
        Args:
            config: scenario config dict; must contain 'num_classes_per_task'
                and a 'da_scenario' object exposing ``.to_json()``.
            target_datasets_info: ordered metadata of the target tasks.
            num_classes: total number of classes across all tasks.
            num_source_classes: number of source classes; target class indices
                start right after these.
            data_dirs: mapping from dataset name to its root directory.
        """
        self.config = config
        self.target_datasets_info = target_datasets_info
        self.num_classes = num_classes
        self.cur_task_index = 0
        self.num_source_classes = num_source_classes
        # Target classes are indexed after all source classes.
        self.cur_class_offset = num_source_classes
        self.data_dirs = data_dirs
        self.target_tasks_order = [i.name for i in self.target_datasets_info]
        # NOTE(review): despite the name, this sums the number of classes over
        # all tasks, not the number of tasks — confirm against callers.
        self.num_tasks_to_be_learn = sum(len(i.classes) for i in target_datasets_info)

        logger.info(f'[scenario build] # classes: {num_classes}, # tasks to be learnt: {len(target_datasets_info)}, '
                    f'# classes per task: {config["num_classes_per_task"]}')

    def to_json(self):
        """Return a JSON-serializable summary of the scenario."""
        config = copy.deepcopy(self.config)
        config['da_scenario'] = config['da_scenario'].to_json()
        target_datasets_info = [str(i) for i in self.target_datasets_info]
        return dict(
            config=config, target_datasets_info=target_datasets_info,
            num_classes=self.num_classes
        )

    def __str__(self):
        return f'Scenario({self.to_json()})'

    def get_cur_class_offset(self):
        """Index of the first class belonging to the current task."""
        return self.cur_class_offset

    def get_cur_num_class(self):
        """Number of classes in the current task."""
        return len(self.target_datasets_info[self.cur_task_index].classes)

    def get_nc_per_task(self):
        """Number of classes per task (assumes every task has as many classes as the first)."""
        return len(self.target_datasets_info[0].classes)

    def next_task(self):
        """Advance to the next task, shifting the class offset past the current task."""
        self.cur_class_offset += len(self.target_datasets_info[self.cur_task_index].classes)
        self.cur_task_index += 1
        # Fixed: was a bare print(); use the module logger like the rest of the class.
        logger.info(f'now, cur task: {self.cur_task_index}, cur_class_offset: {self.cur_class_offset}')

    def get_cur_task_datasets(self):
        """Build the train/val/test datasets for the current task.

        'train' covers only the current task's dataset; 'val' and 'test' merge
        all tasks seen so far (indices 0..cur_task_index). Any split with fewer
        than 1000 samples is replicated 5x as a cheap size augmentation.
        """
        dataset_info = self.target_datasets_info[self.cur_task_index]
        # Names may carry a variant suffix after '|'; the prefix keys data_dirs.
        dataset_name = dataset_info.name.split('|')[0]

        res = {
            'train': get_dataset(dataset_name=dataset_name,
                                 root_dir=self.data_dirs[dataset_name],
                                 split='train',
                                 transform=None,
                                 ignore_classes=dataset_info.ignore_classes,
                                 idx_map=dataset_info.idx_map),
        }
        for split in ('val', 'test'):
            # NOTE(review): every previously-seen task is loaded from the
            # *current* task's dataset_name/root_dir; only ignore_classes and
            # idx_map vary per task. This assumes all tasks are class
            # partitions of the same underlying dataset — confirm.
            res[split] = MergedDataset([
                get_dataset(dataset_name=dataset_name,
                            root_dir=self.data_dirs[dataset_name],
                            split=split,
                            transform=None,
                            ignore_classes=di.ignore_classes,
                            idx_map=di.idx_map)
                for di in self.target_datasets_info[0: self.cur_task_index + 1]
            ])

        # Replicate small splits 5x so downstream loaders have enough samples.
        if len(res['train']) < 1000:
            res['train'] = MergedDataset([res['train']] * 5)
            logger.info('aug train dataset')
        if len(res['val']) < 1000:
            res['val'] = MergedDataset(res['val'].datasets * 5)
            logger.info('aug val dataset')
        if len(res['test']) < 1000:
            res['test'] = MergedDataset(res['test'].datasets * 5)
            logger.info('aug test dataset')

        for k, v in res.items():
            logger.info(f'{k} dataset: {len(v)}')
        return res

    def get_cur_task_train_datasets(self):
        """Return only the train split of the current task's dataset (no merging, no augmentation)."""
        dataset_info = self.target_datasets_info[self.cur_task_index]
        dataset_name = dataset_info.name.split('|')[0]
        return get_dataset(dataset_name=dataset_name,
                           root_dir=self.data_dirs[dataset_name],
                           split='train',
                           transform=None,
                           ignore_classes=dataset_info.ignore_classes,
                           idx_map=dataset_info.idx_map)

    def get_online_cur_task_samples_for_training(self, num_samples):
        """Fetch one batch of `num_samples` training inputs (labels dropped) for the current task."""
        dataset = self.get_cur_task_datasets()['train']
        # build_dataloader(dataset, batch_size, num_workers, shuffle, sampler); keep first batch's inputs.
        return next(iter(build_dataloader(dataset, num_samples, 0, True, None)))[0]