import os
import sys
import copy
import torch

import numpy as np
import core.model as arch
import torch.optim as optim
import torch.distributed as dist
import torch.multiprocessing as mp

from pprint import pprint
from contextlib import redirect_stdout
from datetime import datetime
from time import time
from tqdm import tqdm
from core.data import get_dataloader
from core.utils import *
from core.model.buffer import *
from core.model import bic
from torch.utils.data import DataLoader
from core.utils import Logger, fmt_date_str
from torch.optim.lr_scheduler import MultiStepLR, LambdaLR
from copy import deepcopy

from core.scheduler import CosineSchedule, PatienceSchedule, CosineAnnealingWarmUp

class Trainer(object):
    """
    The Trainer.

    Builds a trainer from the config dict: sets up the logger, device, model, dataloaders,
    buffer, optimizer and meters, then runs the continual-learning loop.
    """

    def __init__(self, rank, config):
        self.rank = rank
        self.config = config
        self.distribute = self.config['n_gpu'] > 1
        # Distributed training is currently disabled; the branch below is kept for future use.
        assert not self.distribute
        if self.distribute:
            dist.init_process_group(
                backend='nccl',
                init_method='tcp://127.0.0.1:23456',
                world_size=self.config['n_gpu'],
                rank=rank,
            )
        self.logger = self._init_logger(config)
        self.device = self._init_device(config)

        with redirect_stdout(self.logger.file):
            pprint(config)

        self.init_cls_num, self.inc_cls_num, self.task_num = self._init_data(config)
        self.model = self._init_model(config)
        (
            self.train_loader,
            self.test_loader,
        ) = self._init_dataloader(config)

        self.buffer = self._init_buffer(config)

        self.task_idx = 0
        (
            self.init_epoch,
            self.inc_epoch,
            self.optimizer,
            self.scheduler,
        ) = self._init_optim(config)

        self.train_meter, self.test_meter = self._init_meter()

        self.val_per_epoch = config['val_per_epoch']

        if self.config["classifier"]["name"] == "bic":
            self.stage2_epoch = config['stage2_epoch']
    def _init_logger(self, config, mode='train'):
        '''
        Init logger.

        Args:
            config (dict): Parsed config file.

        Returns:
            logger (Logger)
        '''
        save_path = config['save_path']

        log_path = os.path.join(save_path, "log", config['classifier']['name'])
        os.makedirs(log_path, exist_ok=True)

        log_prefix = f"{config['dataset']}..{config['backbone']['name']}--ep{config['epoch']}--s{config['seed']}__{datetime.now().strftime('%Y-%m-%d_%H-%M')}"
        log_file = os.path.join(log_path, f"{log_prefix}.log")
        logger = Logger(log_file)

        # Redirect stdout so that every print statement is also written to the log file.
        sys.stdout = logger

        return logger
    def _init_device(self, config):
        """
        Init the device from the config.

        Args:
            config (dict): Parsed config file.

        Returns:
            device: the torch.device assigned to this process's rank.
        """
        init_seed(config['seed'], config['deterministic'])

        device = torch.device(f'cuda:{config["device_ids"][self.rank]}')
        torch.cuda.set_device(device)

        return device
    def _init_files(self, config):
        pass

    def _init_writer(self, config):
        pass
    def _init_meter(self):
        """
        Init the AverageMeters of the train/test stages, which track the running averages of
        batch_time, data_time, calc_time, loss and acc1.

        Returns:
            tuple: A tuple of (train_meter, test_meter).
        """
        train_meter = AverageMeter(
            "train",
            ["batch_time", "data_time", "calc_time", "loss", "acc1"],
        )

        test_meter = AverageMeter(
            "test",
            ["batch_time", "data_time", "calc_time", "acc1"],
        )

        return train_meter, test_meter
    def _init_optim(self, config):
        """
        Init the optimizer and scheduler from the config and, if necessary, load their state dicts from a checkpoint.

        Args:
            config (dict): Parsed config file.

        Returns:
            tuple: A tuple of (init_epoch, inc_epoch, optimizer, scheduler).
        """
        if 'init_epoch' in config.keys():
            init_epoch = config['init_epoch']
        else:
            init_epoch = config['epoch']

        model = self.model.module if self.distribute else self.model

        # The first task may use a dedicated optimizer config ('init_optimizer').
        if self.task_idx == 0 and 'init_optimizer' in config.keys():
            optimizer = get_instance(
                torch.optim, "init_optimizer", config, params=model.get_parameters(config)
            )
        else:
            optimizer = get_instance(
                torch.optim, "optimizer", config, params=model.get_parameters(config)
            )

        if config['lr_scheduler']['name'] == "CosineSchedule":
            scheduler = CosineSchedule(optimizer, K=config['lr_scheduler']['kwargs']['K'])
        elif config['lr_scheduler']['name'] == "PatienceSchedule":
            scheduler = PatienceSchedule(optimizer, patience=config['lr_scheduler']['kwargs']['patience'], factor=config['lr_scheduler']['kwargs']['factor'])
        elif config['lr_scheduler']['name'] == "Constant":
            scheduler = optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda e: 1)
        elif config['lr_scheduler']['name'] == "CosineAnnealingWarmUp":
            # T_max counts total optimizer steps (batches per epoch times epochs), not epochs.
            T_max = len(self.train_loader.get_loader(self.task_idx))
            T_max *= init_epoch if self.task_idx == 0 else config['epoch']
            scheduler = CosineAnnealingWarmUp(optimizer, config['lr_scheduler']['kwargs']['warmup_length'], T_max)
        else:
            scheduler = get_instance(torch.optim.lr_scheduler, "lr_scheduler", config, optimizer=optimizer)

        return init_epoch, config['epoch'], optimizer, scheduler
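    # A minimal sketch of the config keys _init_optim reads (the values below are hypothetical;
    # get_instance is assumed to build the class named by config[key]['name'] with config[key]['kwargs']):
    #
    #   optimizer:     {name: SGD, kwargs: {lr: 0.1, momentum: 0.9, weight_decay: 5.0e-4}}
    #   lr_scheduler:  {name: MultiStepLR, kwargs: {milestones: [40, 60], gamma: 0.1}}
    #   epoch: 100
    #   init_epoch: 200   # optional; used only for the first task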
    def _init_data(self, config):
        return config['init_cls_num'], config['inc_cls_num'], config['task_num']
    def _init_model(self, config):
        """
        Init the model (backbone + classifier) from the config dict, load pretrained params or resume from a
        checkpoint, then wrap it for distributed training if necessary.

        Args:
            config (dict): Parsed config file.

        Returns:
            model: the classifier wrapping the backbone, moved to the target device.
        """
        # Some backbones do not accept a 'device' kwarg, so fall back to the plain constructor.
        try:
            backbone = get_instance(arch, "backbone", config, **{'device': self.device})
        except TypeError:
            backbone = get_instance(arch, "backbone", config)

        model = get_instance(arch, "classifier", config, **{'device': self.device, 'backbone': backbone}).to(self.device)

        if self.distribute:
            model = torch.nn.parallel.DistributedDataParallel(
                model,
                device_ids=[self.device]
            )

        return model
    def _init_dataloader(self, config):
        '''
        Init DataLoaders.

        Args:
            config (dict): Parsed config file.

        Returns:
            train_loaders (list): Each task's train dataloader.
            test_loaders (list): Each task's test dataloader.
        '''
        train_loaders = get_dataloader(config, "train")
        test_loaders = get_dataloader(config, "test", cls_map=train_loaders.cls_map)

        if self.distribute:
            # Rebuild each per-task dataloader with a DistributedSampler and a per-GPU batch size.
            for loaders in [train_loaders, test_loaders]:
                for i, dataloader in enumerate(loaders.dataloaders):
                    dataset = dataloader.dataset
                    sampler = torch.utils.data.distributed.DistributedSampler(dataset, shuffle=True)
                    loaders.dataloaders[i] = DataLoader(
                        dataset,
                        sampler=sampler,
                        batch_size=dataloader.batch_size // self.config['n_gpu'],
                        num_workers=dataloader.num_workers,
                        drop_last=dataloader.drop_last
                    )

        return train_loaders, test_loaders
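    # Note: judging from its use in this file, the object returned by get_dataloader is expected to
    # expose .cls_map, .dataloaders (a list with one DataLoader per task) and .get_loader(task_idx).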
    def _init_buffer(self, config):
        '''
        Init Buffer.

        Args:
            config (dict): Parsed config file.

        Returns:
            buffer (Buffer): a buffer for old samples.
        '''
        buffer = get_instance(arch, "buffer", config)

        return buffer
    def train_loop(self):
        """
        The main train loop: before_task, train, test, after_task for each task.
        """
        experiment_begin = time()
        method_name = self.config["classifier"]["name"]
        testing_times = self.config['testing_times']

        # Accumulated (and best) accuracies per task, plus the full task-by-task accuracy table.
        batch_last_acc_list = np.zeros(self.task_num)
        task_last_acc_list = np.zeros(self.task_num)

        best_batch_last_acc_list = np.zeros(self.task_num)
        best_task_last_acc_list = np.zeros(self.task_num)

        acc_table = np.zeros((self.task_num, self.task_num))
        bwt_list, frgt_list = [], []

        model = self.model.module if self.distribute else self.model

        if method_name == 'RAPF':
            model.model.classes_names = self.train_loader.cls_map

        for task_idx in range(self.task_num):
            self.task_idx = task_idx
            if self.rank == 0:
                print(f"================Task {task_idx} Start!================")

            if hasattr(model, 'before_task'):
                model.before_task(task_idx, self.buffer, self.train_loader.get_loader(task_idx), self.test_loader.get_loader(task_idx))

            if self.rank == 0:
                print(f"Trainable Parameters for Task {task_idx} : {count_parameters(model)} / {count_all_parameters(model)} ({count_parameters(model) * 100 / count_all_parameters(model):.2f}%)")
            # Re-initialize the optimizer and scheduler for every task.
            _, _, self.optimizer, self.scheduler = self._init_optim(self.config)
            dataloader = self.train_loader.get_loader(task_idx)

            if method_name == "bic":
                # BiC uses its own SGD optimizer; the weight decay shrinks as more tasks arrive.
                w_decay = 2e-4 * self.task_num / (task_idx + 1)
                self.optimizer = optim.SGD(model.get_parameters(self.config), lr=0.1, momentum=0.9, weight_decay=w_decay)
                self.scheduler = MultiStepLR(self.optimizer, milestones=[100, 150, 200], gamma=0.1)

                # Split off a validation set for the bias-correction stage and update the buffer.
                dataloader, val_bias_dataloader = self.model.spilt_and_update(dataloader, self.buffer, task_idx, self.config)

            elif isinstance(self.buffer, (LinearBuffer, LinearHerdingBuffer)) and self.buffer.buffer_size > 0 and task_idx > 0:
                # Merge the replay buffer into the current task's training set.
                datasets = dataloader.dataset
                if isinstance(datasets.images, list):
                    datasets.images.extend(self.buffer.images)
                    datasets.labels.extend(self.buffer.labels)
                elif isinstance(datasets.images, np.ndarray):
                    datasets.images = np.concatenate((datasets.images, self.buffer.images), axis=0)
                    datasets.labels = np.concatenate((datasets.labels, self.buffer.labels), axis=0)
                else:
                    raise TypeError(f"Unsupported dataset image container: {type(datasets.images)}")

                dataloader = DataLoader(
                    datasets,
                    shuffle=True,
                    batch_size=self.config['batch_size'],
                    drop_last=False,
                    num_workers=self.config['num_workers']
                )

            if method_name in ["LoRAsub_DRS"]:
                print('Replacing Optim & Scheduler')
                self.optimizer = self.model.get_optimizer(self.config['optimizer']['kwargs']['lr'], self.config['optimizer']['kwargs']['weight_decay'])
                self.scheduler = CosineSchedule(self.optimizer, K=self.config['epoch'])

            if method_name == 'CL_LoRA':
                self.model.set_optim(self.optimizer)
            if self.rank == 0:
                print(f"================Task {task_idx} Training!================")
                print(f"The training samples number : {len(dataloader.dataset)}")

            best_batch_last_acc, best_task_last_acc = 0., 0.
            best_bwt, best_frgt = float('-inf'), float('inf')

            for epoch_idx in range(self.init_epoch if task_idx == 0 else self.inc_epoch):
                if self.rank == 0:
                    print("================Train on train set================")
                train_meter = self._train(epoch_idx, dataloader)

                acc1 = torch.tensor(train_meter.avg("acc1"), device=self.device)
                loss = torch.tensor(train_meter.avg("loss"), device=self.device)
                if self.distribute:
                    # Average the training metrics across all processes.
                    dist.barrier()
                    dist.all_reduce(acc1, op=dist.ReduceOp.SUM)
                    dist.all_reduce(loss, op=dist.ReduceOp.SUM)
                    acc1 = acc1 / self.config['n_gpu']
                    loss = loss / self.config['n_gpu']
                    dist.barrier()

                acc1 = acc1.item()
                loss = loss.item()

                if self.rank == 0:
                    print(f"Epoch [{epoch_idx}/{self.init_epoch if task_idx == 0 else self.inc_epoch}] Learning Rate {self.scheduler.get_last_lr()}\t|\tLoss: {loss:.4f} \tAverage Acc: {acc1:.2f} ")

                if (epoch_idx + 1) % self.val_per_epoch == 0 or (epoch_idx + 1) == self.inc_epoch:
                    if self.rank == 0:
                        print("================Validation on test set================")

                    # Per-epoch validation is disabled for these methods.
                    if method_name in ['TRGP',
                                       'RanPAC',
                                       'MInfLoRA2',
                                       'MInfLoRA3',
                                       'PRAKA',
                                       'TRGP_CLIP',
                                       'LoRAsub_DRS',
                                       'CL_LoRA',
                                       ]:
                        if self.rank == 0:
                            print(" * Disabled validation for this method")
                    else:
                        test_acc = self._validate(task_idx)

                        batch_last_acc, per_task_acc = test_acc['avg_acc'], test_acc['per_task_acc']
                        best_batch_last_acc = max(batch_last_acc, best_batch_last_acc)

                        task_last_acc = np.mean(per_task_acc)
                        best_task_last_acc = max(task_last_acc, best_task_last_acc)

                        frgt, bwt = compute_frgt(acc_table, per_task_acc, task_idx), compute_bwt(acc_table, per_task_acc, task_idx)
                        best_frgt, best_bwt = min(frgt, best_frgt), max(bwt, best_bwt)

                        if self.rank == 0:
                            print(f" * [Batch] Last Average Acc: {batch_last_acc:.2f} (Best: {best_batch_last_acc:.2f})")
                            print(f" * [Task] Last Average Acc: {task_last_acc:.2f} (Best: {best_task_last_acc:.2f})")
                            print(f" * Forgetting: {frgt:.3f} (Best: {best_frgt:.3f})")
                            print(f" * Backward Transfer: {bwt:.2f} (Best: {best_bwt:.2f})")
                            print(f" * Per-Task Acc: {per_task_acc}")

                if self.config['lr_scheduler']['name'] == "PatienceSchedule":
                    self.scheduler.step(train_meter.avg('loss'))
                    if self.scheduler.get_last_lr() < self.config['lr_scheduler']['kwargs']['stopping_lr']:
                        if self.rank == 0:
                            print(f"{self.scheduler.get_last_lr()} < {self.config['lr_scheduler']['kwargs']['stopping_lr']}, stopping this task now")
                        break
                else:
                    self.scheduler.step()
            if hasattr(model, 'after_task'):
                model.after_task(task_idx, self.buffer, self.train_loader.get_loader(task_idx), self.test_loader.get_loader(task_idx))

            # Update the replay buffer with samples from the current task.
            if method_name not in ['bic', 'ERACE', 'ERAML']:
                self.buffer.total_classes += self.init_cls_num if task_idx == 0 else self.inc_cls_num
                if self.buffer.buffer_size > 0:
                    if self.buffer.strategy == 'herding':
                        herding_update(self.train_loader.get_loader(task_idx).dataset, self.buffer, model.backbone, self.device)
                    elif self.buffer.strategy == 'random':
                        random_update(self.train_loader.get_loader(task_idx).dataset, self.buffer)
                    elif self.buffer.strategy == 'balance_random':
                        balance_random_update(self.train_loader.get_loader(task_idx).dataset, self.buffer)

            # Stage 2 of BiC: train only the bias-correction layers on the held-out split.
            if self.config["classifier"]["name"] == "bic" and task_idx > 0:
                bias_scheduler = optim.lr_scheduler.LambdaLR(model.bias_optimizer, lr_lambda=lambda e: 1)

                for epoch_idx in range(self.stage2_epoch):
                    if self.rank == 0:
                        print("================ Train on the train set (stage2)================")
                    train_meter = self.stage2_train(epoch_idx, val_bias_dataloader)

                    if self.rank == 0:
                        print(f"Epoch [{epoch_idx}/{self.stage2_epoch}] Learning Rate {bias_scheduler.get_last_lr()}\t|\tLoss: {train_meter.avg('loss'):.4f} \tAverage Acc: {train_meter.avg('acc1'):.2f} ")

                    if (epoch_idx + 1) % self.val_per_epoch == 0 or (epoch_idx + 1) == self.inc_epoch:
                        if self.rank == 0:
                            print("================ Test on the test set (stage2)================")

                        test_acc = self._validate(task_idx)

                        batch_last_acc, per_task_acc = test_acc['avg_acc'], test_acc['per_task_acc']
                        best_batch_last_acc = max(batch_last_acc, best_batch_last_acc)

                        task_last_acc = np.mean(per_task_acc)
                        best_task_last_acc = max(task_last_acc, best_task_last_acc)

                        frgt, bwt = compute_frgt(acc_table, per_task_acc, task_idx), compute_bwt(acc_table, per_task_acc, task_idx)
                        best_frgt, best_bwt = min(frgt, best_frgt), max(bwt, best_bwt)

                        if self.rank == 0:
                            print(f" * [Batch] Last Average Acc: {batch_last_acc:.2f} (Best: {best_batch_last_acc:.2f})")
                            print(f" * [Task] Last Average Acc: {task_last_acc:.2f} (Best: {best_task_last_acc:.2f})")
                            print(f" * Forgetting: {frgt:.3f} (Best: {best_frgt:.3f})")
                            print(f" * Backward Transfer: {bwt:.2f} (Best: {best_bwt:.2f})")
                            print(f" * Per-Task Acc: {per_task_acc}")
            # Final testing for this task, repeated 'testing_times' times; results are accumulated and averaged below.
            for test_idx in range(testing_times):
                if self.rank == 0:
                    print(f"================Test {test_idx + 1}/{testing_times} of Task {task_idx}!================")

                test_acc = self._validate(task_idx)

                batch_last_acc, per_task_acc = test_acc['avg_acc'], test_acc['per_task_acc']
                best_batch_last_acc = max(batch_last_acc, best_batch_last_acc)

                task_last_acc = np.mean(per_task_acc)
                best_task_last_acc = max(task_last_acc, best_task_last_acc)

                frgt, bwt = compute_frgt(acc_table, per_task_acc, task_idx), compute_bwt(acc_table, per_task_acc, task_idx)
                best_frgt, best_bwt = min(frgt, best_frgt), max(bwt, best_bwt)

                if self.rank == 0:
                    print(f" * [Batch] Last Average Acc: {batch_last_acc:.2f} (Best: {best_batch_last_acc:.2f})")
                    print(f" * [Task] Last Average Acc: {task_last_acc:.2f} (Best: {best_task_last_acc:.2f})")
                    print(f" * Forgetting: {frgt:.3f} (Best: {best_frgt:.3f})")
                    print(f" * Backward Transfer: {bwt:.2f} (Best: {best_bwt:.2f})")
                    print(f" * Per-Task Acc: {per_task_acc}")

                batch_last_acc_list[task_idx] += batch_last_acc
                task_last_acc_list[task_idx] += task_last_acc
                acc_table[task_idx][:task_idx + 1] += np.array(per_task_acc)

                best_batch_last_acc_list[task_idx] = best_batch_last_acc
                best_task_last_acc_list[task_idx] = best_task_last_acc

            # Average over the repeated test runs.
            batch_last_acc_list[task_idx] /= testing_times
            task_last_acc_list[task_idx] /= testing_times
            acc_table[task_idx] /= testing_times

            batch_last_acc = batch_last_acc_list[task_idx]
            task_last_acc = task_last_acc_list[task_idx]

            frgt, bwt = compute_frgt(acc_table, acc_table[task_idx], task_idx), compute_bwt(acc_table, acc_table[task_idx], task_idx)
            best_frgt, best_bwt = min(frgt, best_frgt), max(bwt, best_bwt)
            if task_idx > 1:
                frgt_list.append(frgt)
                bwt_list.append(bwt)

            if self.rank == 0:
                print(f"================Result of Task {task_idx} Testing!================")
                print(f" * [Batch] Last Average Acc: {batch_last_acc:.2f} (Best: {best_batch_last_acc:.2f})")
                print(f" * [Task] Last Average Acc: {task_last_acc:.2f} (Best: {best_task_last_acc:.2f})")
                print(f" * Forgetting: {frgt:.3f} (Best: {best_frgt:.3f})")
                print(f" * Backward Transfer: {bwt:.2f} (Best: {best_bwt:.2f})")
                print(f" * Per-Task Acc: {acc_table[task_idx][:task_idx + 1]}")
        # Overall metrics across all tasks.
        batch_ovr_avg_acc = np.mean(batch_last_acc_list)
        best_batch_ovr_avg_acc = np.mean(best_batch_last_acc_list)

        task_ovr_avg_acc = np.sum(np.sum(acc_table[:task_idx + 1], axis=1) / np.arange(1, task_idx + 2)) / (task_idx + 1)

        ovr_bwt = np.mean(bwt_list) if len(bwt_list) > 0 else float('-inf')
        ovr_frgt = np.mean(frgt_list) if len(frgt_list) > 0 else float('inf')

        if self.rank == 0:
            print(f"================Overall Result of {self.task_num} Tasks!================")
            print(f" * [Batch] Last Average Acc: {batch_last_acc:.2f} (Best: {best_batch_last_acc:.2f})")
            print(f" * [Task] Last Average Acc: {task_last_acc:.2f} (Best: {best_task_last_acc:.2f})")
            print(f" * Forgetting: {frgt:.3f} (Best: {best_frgt:.3f})")
            print(f" * Backward Transfer: {bwt:.2f} (Best: {best_bwt:.2f})")
            print(f" * [Batch] Overall Avg Acc : {batch_ovr_avg_acc:.2f} (Best: {best_batch_ovr_avg_acc:.2f})")
            print(f" * [Task] Overall Avg Acc : {task_ovr_avg_acc:.2f}")
            print(f" * Overall Frgt : {ovr_frgt:.3f}")
            print(f" * Overall BwT : {ovr_bwt:.2f}")
            print(f" * Average Acc Table : \n{acc_table}")

            print("================Model Performance Analysis================")
            print(f" * Time Costs : {(time() - experiment_begin):.2f} sec")
            fps = compute_fps(model, self.config)
            avg_fps, best_fps = fps['avg_fps'], fps['best_fps']
            print(f" * Average FPS (Best FPS) : {avg_fps:.0f} ({best_fps:.0f})")
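    # For reference: compute_frgt / compute_bwt are imported from core.utils. A hedged sketch of the
    # standard definitions they presumably follow, with acc_table[i][j] = accuracy on task j after task i:
    #
    #   forgetting(k) = mean over j < k of ( max over i < k of acc_table[i][j] - acc_table[k][j] )
    #   bwt(k)        = mean over j < k of ( acc_table[k][j] - acc_table[j][j] )
    #
    # If the helpers in core/utils differ, their implementation takes precedence.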
    def stage2_train(self, epoch_idx, dataloader):
        """
        The stage-2 training of the BiC method: the backbone is frozen and only the bias-correction layers are trained.

        Args:
            epoch_idx (int): Epoch index.
            dataloader: Dataloader of the held-out bias-correction split.

        Returns:
            AverageMeter: meter holding the stage-2 loss and acc1.
        """
        model = self.model.module if self.distribute else self.model

        model.eval()
        for layer in model.bias_layers:
            layer.train()

        meter = self.train_meter
        meter.reset()

        total = len(dataloader)
        for b, batch in tqdm(enumerate(dataloader), total=total, disable=(self.rank != 0)):
            output, acc, loss = model.stage2(batch)

            meter.update("acc1", 100 * acc)
            meter.update("loss", loss.item())

        return meter
    def _train(self, epoch_idx, dataloader):
        """
        The train stage of a single epoch.

        Args:
            epoch_idx (int): Epoch index.
            dataloader: Dataloader of the current task (possibly merged with the buffer).

        Returns:
            AverageMeter: meter holding the training loss and acc1.
        """
        model = self.model.module if self.distribute else self.model

        model.train()
        if self.config['classifier']['name'] == 'bic':
            # In stage 1 of BiC, the bias-correction layers are kept frozen.
            for layer in model.bias_layers:
                layer.eval()

        meter = deepcopy(self.train_meter)
        meter.reset()

        total = len(dataloader)
        init_seed(self.config['seed'] + epoch_idx, self.config['deterministic'])
        for b, batch in tqdm(enumerate(dataloader), total=total, disable=(self.rank != 0)):
            batch['batch_id'] = b

            # These methods step their scheduler per batch rather than per epoch.
            if self.config['classifier']['name'] in ['MOE_ADAPTER4CL', 'DMNSP', 'DMNSP_CIL']:
                self.scheduler.step(total * epoch_idx + b)

            if self.config["classifier"]["name"] in ['TRGP', 'DMNSP', 'DMNSP_CIL', 'TRGP_CLIP',
                                                     'GPM', 'MoE_Test2', 'API', 'L2P']:
                # These methods compute and modify gradients inside observe(), so backward() is not called here.
                self.optimizer.zero_grad()
                output, acc, loss = model.observe(batch)
            elif self.config["classifier"]["name"] in ['bic']:
                output, acc, loss = model.observe(batch)
                self.optimizer.zero_grad()
                loss.backward(retain_graph=True)
            else:
                output, acc, loss = model.observe(batch)
                self.optimizer.zero_grad()
                loss.backward()

            self.optimizer.step()

            if self.config["classifier"]["name"] in ['ERACE', 'ERAML']:
                model.add_reservoir()

            meter.update("acc1", 100 * acc)
            meter.update("loss", loss.item())

        return meter
    def _validate(self, task_idx):
        """
        Evaluate the model on the test sets of all tasks seen so far.

        Args:
            task_idx (int): Index of the latest learned task.

        Returns:
            dict: {"avg_acc": float, "per_task_acc": list}
        """
        dataloaders = self.test_loader.get_loader(task_idx)

        model = self.model.module if self.distribute else self.model
        model.eval()

        if self.config["classifier"]["name"] == 'bic':
            for layer in model.bias_layers:
                layer.eval()

        per_task_acc = []
        count_all, correct_all = 0, 0

        if self.config['testing_per_task']:
            # Evaluate each seen task's test set separately.
            with torch.no_grad():
                for t, dataloader in enumerate(dataloaders):
                    correct_task, count_task = 0, 0

                    for b, batch in tqdm(enumerate(dataloader), total=len(dataloader), desc=f"Testing on Task {t} data", disable=self.rank != 0):
                        if self.config['setting'] == 'task-aware':
                            output, acc = model.inference(batch, task_id=t)
                        elif self.config['setting'] == 'task-agnostic':
                            output, acc = model.inference(batch)

                        correct_task += int(acc * batch['label'].shape[0])
                        count_task += batch['label'].shape[0]

                    correct_all += correct_task
                    count_all += count_task

                    if self.distribute:
                        # Placeholder: per-rank aggregation is not implemented yet.
                        pass

                    per_task_acc.append(round(correct_task * 100 / count_task, 2))

            if self.distribute:
                # Placeholder: per-rank aggregation is not implemented yet.
                pass
        else:
            # Merge the test sets of all seen tasks into a single loader, then split the accuracy
            # back out per task using the class-index boundaries.
            datasets = [dl.dataset for dl in dataloaders]

            all_images = np.concatenate([ds.images for ds in datasets], axis=0)
            all_labels = np.concatenate([ds.labels for ds in datasets], axis=0)

            merged_dataset = copy.deepcopy(datasets[0])
            merged_dataset.images = all_images
            merged_dataset.labels = all_labels

            merged_loader = DataLoader(
                merged_dataset,
                shuffle=True,
                batch_size=self.config['batch_size'],
                drop_last=False,
                num_workers=self.config['num_workers'],
                pin_memory=False
            )

            # Class-index range [start, end) owned by each seen task.
            class_boundaries = []
            start_cls = 0
            for t in range(task_idx + 1):
                n_cls = self.init_cls_num if t == 0 else self.inc_cls_num
                class_boundaries.append((start_cls, start_cls + n_cls))
                start_cls += n_cls

            correct_by_task = np.zeros(task_idx + 1, dtype=int)
            count_by_task = np.zeros(task_idx + 1, dtype=int)

            with torch.no_grad():
                for b, batch in tqdm(enumerate(merged_loader), total=len(merged_loader), desc=f"Testing merged tasks <= {task_idx}", disable=self.rank != 0):
                    if self.config['setting'] == 'task-aware':
                        # Task-aware evaluation is not supported on the merged test set; use per-task testing instead.
                        raise NotImplementedError("Task-aware evaluation requires per-task testing (testing_per_task).")
                    elif self.config['setting'] == 'task-agnostic':
                        output, acc = model.inference(batch)

                    preds = output.cpu().numpy()

                    labels = batch['label'].cpu().numpy()
                    correct_all += np.sum(preds == labels)
                    count_all += len(labels)

                    # Attribute each sample to its task via the class boundaries.
                    for t, (start, end) in enumerate(class_boundaries):
                        mask = (labels >= start) & (labels < end)
                        if np.any(mask):
                            correct_by_task[t] += np.sum(preds[mask] == labels[mask])
                            count_by_task[t] += np.sum(mask)

            per_task_acc = [round(c * 100 / n, 2) if n > 0 else 0 for c, n in zip(correct_by_task, count_by_task)]

        avg_acc = round(correct_all * 100 / count_all, 2)

        return {
            "avg_acc": avg_acc,
            "per_task_acc": per_task_acc
        }
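
# A minimal usage sketch (hypothetical entry point; the repo's actual runner script may differ):
#
#   import yaml
#
#   def main_worker(rank, config):
#       trainer = Trainer(rank, config)
#       trainer.train_loop()
#
#   if __name__ == "__main__":
#       with open("config.yaml") as f:
#           config = yaml.safe_load(f)
#       if config["n_gpu"] > 1:
#           mp.spawn(main_worker, args=(config,), nprocs=config["n_gpu"])
#       else:
#           main_worker(0, config)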