# Image Classification
# English
# TTA
# ReservoirTTA / test_time.py
# GuillaumeVray
# Uploading files
# 02ba886
import os
import pickle
import torch
import logging
import numpy as np
from models.model import get_model
from utils.misc import print_memory_info
from utils.eval_utils import get_accuracy, eval_domain_dict
from utils.registry import ADAPTATION_REGISTRY
from datasets.data_loading import get_test_loader
from conf import cfg, load_cfg_from_args, get_num_classes
from methods import *
# Module-level logger, named after this module; used by evaluate() below.
logger = logging.getLogger(__name__)
def evaluate(description):
    """Evaluate the configured test-time adaptation method and pickle the results.

    ``description`` is forwarded unchanged to ``load_cfg_from_args``. The
    function then builds the base model, wraps it with the adaptation method
    registered under ``cfg.MODEL.ADAPTATION`` and calls ``get_accuracy`` for
    every (domain, severity) pair, repeating the whole domain sequence
    ``num_recur`` times with the same adaptation object. After each recurrence
    the collected per-domain results are written to
    ``<cfg.SAVE_DIR>/results_r=<recur>.pkl``.

    Returns ``None``. The function returns early, before any model is built,
    when the result file of the last recurrence already exists.

    Raises:
        AssertionError: if ``cfg.SETTING`` or ``cfg.MODEL.ADAPTATION`` is not
            one of the supported / registered names.
    """
    load_cfg_from_args(description)

    # NOTE(review): the two asserts below validate configuration and are
    # stripped under ``python -O``; consider raising explicitly instead.
    valid_settings = [
        "continual",      # train on sequence of domain shifts without knowing when a shift occurs
        "continual_cdc",
    ]
    assert cfg.SETTING in valid_settings, f"The setting '{cfg.SETTING}' is not supported! Choose from: {valid_settings}"

    device = "cuda" if torch.cuda.is_available() else "cpu"
    num_classes = get_num_classes(dataset_name=cfg.CORRUPTION.DATASET)

    # Number of passes over the domain sequence; "ccc" is always run once.
    num_recur = 1 if cfg.CORRUPTION.DATASET == "ccc" else cfg.CORRUPTION.RECUR

    # The last recurrence's result file marks a finished experiment.
    final_results_path = os.path.join(cfg.SAVE_DIR, f"results_r={num_recur}.pkl")
    if os.path.exists(final_results_path):
        print("Experiment already Done! No Overwriting possible!")
        return

    # Base model plus its input pre-processing (if available); the
    # pre-processing is attached to the model as an attribute.
    base_model, model_preprocess = get_model(cfg, num_classes, device)
    base_model.model_preprocess = model_preprocess

    # Look up and instantiate the test-time adaptation method.
    available_adaptations = ADAPTATION_REGISTRY.registered_names()
    assert cfg.MODEL.ADAPTATION in available_adaptations, \
        f"The adaptation '{cfg.MODEL.ADAPTATION}' is not supported! Choose from: {available_adaptations}"
    adaptation_cls = ADAPTATION_REGISTRY.get(cfg.MODEL.ADAPTATION)
    model = adaptation_cls(cfg=cfg, model=base_model, num_classes=num_classes)
    logger.info(f"Successfully prepared test-time adaptation method: {cfg.MODEL.ADAPTATION}")

    # Domain sequence: a single pseudo-domain for "continual_cdc", otherwise
    # the configured corruption types in order.
    if cfg.SETTING == "continual_cdc":
        domain_sequence = ["continual_cdc"]
    else:
        domain_sequence = cfg.CORRUPTION.TYPE
    all_domain_names = cfg.CORRUPTION.TYPE
    logger.info(f"Using {cfg.CORRUPTION.DATASET} with the following domain sequence: {domain_sequence}")

    severity_levels = cfg.CORRUPTION.SEVERITY

    logger.info(f"Run {num_recur} recurs")
    all_errs = []
    for r in range(num_recur):
        logger.info(f"Start Recur={r+1}")
        round_results = {}

        for i_dom, domain_name in enumerate(domain_sequence):
            for severity in severity_levels:
                # The seed depends on the recurrence and the domain position,
                # not on the severity.
                loader_kwargs = dict(
                    setting=cfg.SETTING,
                    dataset_name=cfg.CORRUPTION.DATASET,
                    data_root_dir=cfg.DATA_DIR,
                    domain_name=domain_name,
                    domain_names_all=all_domain_names,
                    severity=severity,
                    num_examples=cfg.CORRUPTION.NUM_EX,
                    rng_seed=cfg.RNG_SEED + r * len(domain_sequence) + i_dom,
                    batch_size=cfg.TEST.BATCH_SIZE,
                    shuffle=False,
                    workers=cfg.TEST.NUM_WORKERS,
                    preprocess=model_preprocess,
                )
                test_data_loader = get_test_loader(**loader_kwargs)

                if cfg.CORRUPTION.DATASET != "ccc" and r == 0 and i_dom == 0:
                    # Note that the input normalization is done inside of the model
                    logger.info(f"Using the following data transformation:\n{test_data_loader.dataset.transform}")

                # Run the adaptation method over this loader.
                acc, domain_dict, num_samples = get_accuracy(
                    model,
                    data_loader=test_data_loader,
                    dataset_name=cfg.CORRUPTION.DATASET,
                    domain_name=domain_name,
                    print_every=cfg.PRINT_EVERY,
                    device=device,
                )
                torch.cuda.empty_cache()

                logger.info(f"{cfg.CORRUPTION.DATASET} error % [{domain_name}][#severity={severity}][#recur={r+1}][#samples={num_samples}]: {1-acc:.2%}")

                # Store per-domain values under "<domain>_<severity>_#<recur>".
                round_results.update(
                    {f"{domain}_{severity}_#{r+1}": values for domain, values in domain_dict.items()}
                )

        # Summarise and persist the results of this recurrence.
        avg_acc = eval_domain_dict(round_results)["ACC"]['avg']
        all_errs.append(1 - avg_acc)
        logger.info(f"#recur: {r+1}, mean avg error: {1-avg_acc:.2%}\n")
        round_results_path = os.path.join(cfg.SAVE_DIR, f"results_r={r+1}.pkl")
        with open(round_results_path, 'wb') as results_file:
            pickle.dump(round_results, results_file)

    # Final summary over all recurrences.
    if cfg.TEST.DEBUG:
        print_memory_info()
    logger.info(f"Mean avg error rate: {np.mean(np.array(all_errs)):.2%}\n")
# Script entry point: run the evaluation when this file is executed directly.
if __name__ == '__main__':
    # The original literal was '"Evaluation.' with a stray, unbalanced double
    # quote. The string is passed to evaluate(), which forwards it to
    # load_cfg_from_args (presumably as the CLI description -- confirm).
    evaluate("Evaluation.")