import os
import json
import random
import hydra
import logging
import numpy as np
from omegaconf import DictConfig
from tqdm import tqdm
import torch
import statistics
from continuum.metrics import Logger
from continual_clip import utils
from continual_clip.models import load_model
from continual_clip.datasets import build_cl_scenarios
from torch.utils.data import DataLoader, DistributedSampler
# Number of distributed replicas; this script runs single-process (world size 1).
WORLD_NUM = 1
@hydra.main(config_path=None, config_name=None, version_base="1.1")
def continual_clip(cfg: DictConfig) -> None:
    """Run class-incremental evaluation of a continual CLIP model.

    For every task in the evaluation scenario the model is first adapted on
    the training split, then evaluated on all tasks seen so far. Per-task
    metrics (accuracy, forgetting, BWT, FWT, ...) are appended as JSON lines
    to ``cfg.log_path``, followed by a final summary line.

    Args:
        cfg: Hydra/OmegaConf configuration with at least ``dataset_root``,
            ``class_order`` and ``log_path`` entries.
    """
    set_seed(RANDOM_SEED)
    # NOTE(review): hard-coded working directory overrides whatever the
    # config supplied — consider moving this into the hydra config.
    cfg.workdir = "/***/DMNSP/cil"
    cfg.dataset_root = os.path.join(cfg.workdir, cfg.dataset_root)
    utils.save_config(cfg)
    cfg.class_order = utils.get_class_order(os.path.join(cfg.workdir, cfg.class_order))
    origin_flag = False
    devices = [0]
    model = load_model(cfg, devices[0], origin_flag)
    eval_dataset, classes_names = build_cl_scenarios(
        cfg, is_train=False, transforms=model.transforms
    )
    # Original printed eval_dataset twice; print it once.
    print(eval_dataset)
    train_dataset, train_classes_names = build_cl_scenarios(
        cfg, is_train=True, transforms=model.transforms
    )
    model.classes_names = classes_names
    print("Using devices", devices)
    model = torch.nn.DataParallel(model, device_ids=devices)
    # Truncate the log file so every run starts with a fresh log.
    with open(cfg.log_path, 'w+') as f:
        pass
    acc_list = []
    forgetting_list = []
    metric_logger = Logger(list_subsets=["test"])
    world = WORLD_NUM
    for task_id, _ in enumerate(eval_dataset):
        logging.info(f"Evaluation for task {task_id} has started.")
        # task_id is forwarded so the model knows which task it adapts to.
        model.module.adaptation(task_id, cfg, train_dataset, train_classes_names, world)
        eval_sampler = DistributedSampler(eval_dataset[:task_id + 1], num_replicas=world, rank=0)
        eval_loader = DataLoader(eval_dataset[:task_id + 1], batch_size=64, sampler=eval_sampler, num_workers=8)
        # Move the model to the evaluation device once per task — the
        # original re-invoked .cuda() on the model for every batch.
        net = model.module.cuda(devices[0])
        for inputs, targets, task_ids in tqdm(eval_loader):
            # Inputs/targets go to the GPU once; the original moved `inputs`
            # to the device a second time in the forward call.
            inputs = inputs.cuda(device=devices[0])
            targets = targets.cuda(device=devices[0])
            outputs = net(inputs, task_ids)
            metric_logger.add([outputs.cpu().argmax(dim=1), targets.cpu(), task_ids], subset="test")
        acc_list.append(100 * metric_logger.accuracy)
        forgetting_list.append(100 * metric_logger.forgetting)
        with open(cfg.log_path, 'a+') as f:
            f.write(json.dumps({
                'task': task_id,
                'acc': round(100 * metric_logger.accuracy, 2),
                'avg_acc': round(100 * metric_logger.average_incremental_accuracy, 2),
                'forgetting': round(100 * metric_logger.forgetting, 6),
                'acc_per_task': [round(100 * acc_t, 2) for acc_t in metric_logger.accuracy_per_task],
                'bwt': round(100 * metric_logger.backward_transfer, 2),
                'fwt': round(100 * metric_logger.forward_transfer, 2),
            }) + '\n')
        metric_logger.end_task()
    # Final summary over all tasks.
    with open(cfg.log_path, 'a+') as f:
        f.write(json.dumps({
            'last_Cifar100': round(acc_list[-1], 2),
            'avg_Cifar100': round(statistics.mean(acc_list), 2),
            'avg_forgetting': round(statistics.mean(forgetting_list), 2)
        }) + '\n')
# Candidate seeds: 386, 2345, 157 (performance might slightly vary across
# different machines).
RANDOM_SEED = 386
def set_seed(seed: int) -> None:
    """Seed every RNG used by this script for reproducible runs.

    Covers Python's ``random``, NumPy, and PyTorch (CPU plus all CUDA
    devices), and forces cuDNN into deterministic mode.

    Args:
        seed: Value applied to every random number generator.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    # NOTE(review): PYTHONHASHSEED only affects str hashing if set before
    # interpreter start; setting it here mainly propagates to subprocesses.
    os.environ['PYTHONHASHSEED'] = str(seed)
    # Trade cuDNN autotuning speed for bitwise-deterministic kernels.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
if __name__ == "__main__":
    # hydra.main parses CLI overrides and invokes the decorated function.
    continual_clip()