|
|
import os |
|
|
import pickle |
|
|
import torch |
|
|
import logging |
|
|
import numpy as np |
|
|
|
|
|
from models.model import get_model |
|
|
from utils.misc import print_memory_info |
|
|
from utils.eval_utils import get_accuracy, eval_domain_dict |
|
|
from utils.registry import ADAPTATION_REGISTRY |
|
|
from datasets.data_loading import get_test_loader |
|
|
from conf import cfg, load_cfg_from_args, get_num_classes |
|
|
from methods import * |
|
|
|
|
|
# Module-level logger for this evaluation script (handlers are presumably
# installed elsewhere, e.g. during config loading — TODO confirm).
logger = logging.getLogger(__name__)
|
|
|
|
|
|
|
|
def evaluate(description):
    """Run test-time adaptation over a corruption benchmark and save results.

    Loads the experiment configuration from the command line, builds the base
    model and the configured adaptation method, then iterates over
    (recurrence, domain, severity) combinations, logging error rates and
    pickling a per-recurrence result dictionary into ``cfg.SAVE_DIR``.

    Args:
        description: Description string forwarded to the CLI/config parser
            via ``load_cfg_from_args``.
    """
    load_cfg_from_args(description)
    # Only the continual settings are handled by this script.
    valid_settings = [
        "continual",
        "continual_cdc",
    ]
    assert cfg.SETTING in valid_settings, f"The setting '{cfg.SETTING}' is not supported! Choose from: {valid_settings}"

    device = "cuda" if torch.cuda.is_available() else "cpu"
    num_classes = get_num_classes(dataset_name=cfg.CORRUPTION.DATASET)

    # CCC is evaluated in a single pass; every other dataset repeats the
    # domain sequence cfg.CORRUPTION.RECUR times.
    if cfg.CORRUPTION.DATASET == "ccc":
        num_recur = 1
    else:
        num_recur = cfg.CORRUPTION.RECUR

    # Abort if the final recurrence's result file already exists: the
    # experiment finished previously and must not be overwritten.
    if os.path.exists(os.path.join(cfg.SAVE_DIR,f"results_r={num_recur}.pkl")):
        print("Experiment already Done! No Overwriting possible!")
        return

    base_model, model_preprocess = get_model(cfg, num_classes, device)

    # Attach the preprocessing pipeline to the model object so downstream
    # code can reach it through the model alone (assumption — TODO confirm
    # which consumers read this attribute).
    base_model.model_preprocess = model_preprocess

    # Wrap the base model with the selected test-time adaptation method.
    available_adaptations = ADAPTATION_REGISTRY.registered_names()
    assert cfg.MODEL.ADAPTATION in available_adaptations, \
        f"The adaptation '{cfg.MODEL.ADAPTATION}' is not supported! Choose from: {available_adaptations}"
    model = ADAPTATION_REGISTRY.get(cfg.MODEL.ADAPTATION)(cfg=cfg, model=base_model, num_classes=num_classes)
    logger.info(f"Successfully prepared test-time adaptation method: {cfg.MODEL.ADAPTATION}")

    # In the "continual_cdc" setting a single pseudo-domain drives the loader
    # (which presumably handles the domain changes internally — verify in
    # get_test_loader); otherwise iterate the configured corruption types.
    domain_sequence = ["continual_cdc"] if cfg.SETTING == "continual_cdc" else cfg.CORRUPTION.TYPE
    domain_names_all = cfg.CORRUPTION.TYPE
    logger.info(f"Using {cfg.CORRUPTION.DATASET} with the following domain sequence: {domain_sequence}")

    severities = cfg.CORRUPTION.SEVERITY

    logger.info(f"Run {num_recur} recurs")
    all_errs = []  # one mean error value per recurrence
    for r in range(num_recur):

        logger.info(f"Start Recur={r+1}")
        result_dict = {}  # per-domain results accumulated within this recurrence

        for i_dom, domain_name in enumerate(domain_sequence):
            for severity in severities:
                test_data_loader = get_test_loader(
                    setting=cfg.SETTING,
                    dataset_name=cfg.CORRUPTION.DATASET,
                    data_root_dir=cfg.DATA_DIR,
                    domain_name=domain_name,
                    domain_names_all=domain_names_all,
                    severity=severity,
                    num_examples=cfg.CORRUPTION.NUM_EX,
                    # Offset the seed per (recurrence, domain) so each pass is
                    # reproducible yet not identical across repetitions.
                    rng_seed=cfg.RNG_SEED+r*len(domain_sequence)+i_dom,
                    batch_size=cfg.TEST.BATCH_SIZE,
                    shuffle=False,
                    workers=cfg.TEST.NUM_WORKERS,
                    preprocess=model_preprocess,
                )

                # Log the data transform only once (first domain of the first
                # recurrence); skipped for CCC, whose dataset presumably does
                # not expose a .transform attribute here — TODO confirm.
                if cfg.CORRUPTION.DATASET != "ccc" and (r == 0 and i_dom == 0):

                    logger.info(f"Using the following data transformation:\n{test_data_loader.dataset.transform}")

                acc, domain_dict, num_samples = get_accuracy(
                    model,
                    data_loader=test_data_loader,
                    dataset_name=cfg.CORRUPTION.DATASET,
                    domain_name=domain_name,
                    print_every=cfg.PRINT_EVERY,
                    device=device,
                )
                # Release cached GPU memory between domains to limit peak usage.
                torch.cuda.empty_cache()

                logger.info(f"{cfg.CORRUPTION.DATASET} error % [{domain_name}][#severity={severity}][#recur={r+1}][#samples={num_samples}]: {1-acc:.2%}")
                # Key results by domain, severity, and recurrence index so that
                # repeated visits to the same domain do not overwrite each other.
                for domain, values in domain_dict.items():
                    key = f"{domain}_{severity}_#{r+1}"
                    result_dict[key] = values

        # Aggregate the recurrence: mean accuracy across all recorded domains.
        avg_acc = eval_domain_dict(result_dict)["ACC"]['avg']
        all_errs.append(1-avg_acc)
        logger.info(f"#recur: {r+1}, mean avg error: {1-avg_acc:.2%}\n")
        # Persist this recurrence's raw results (also serves as the
        # completion marker checked at the top of this function).
        with open(os.path.join(cfg.SAVE_DIR, f"results_r={r+1}.pkl"), 'wb') as f:
            pickle.dump(result_dict, f)

        if cfg.TEST.DEBUG:
            print_memory_info()
    # Overall error averaged over all recurrences.
    logger.info(f"Mean avg error rate: {np.mean(np.array(all_errs)):.2%}\n")
|
|
|
|
|
|
|
|
if __name__ == '__main__':
    # Fix: the description literal previously contained a stray embedded
    # double quote ('"Evaluation.'), which would surface verbatim in the
    # CLI help text produced by load_cfg_from_args.
    evaluate('Evaluation.')
|
|
|