# Image Classification
# English
# TTA
# ReservoirTTA / utils /eval_utils.py
# GuillaumeVray
# Uploading files
# 02ba886
import pandas as pd
import torch
import logging
import numpy as np
from typing import Union
import pickle
import os
from collections import deque
logger = logging.getLogger(__name__)
def split_results_by_domain(domain_dict: dict, labels: torch.Tensor, domains: list, predictions: torch.Tensor, confs: torch.Tensor):
    """
    Separates the (label, prediction, confidence) triples of one batch by domain
    Input:
        domain_dict: Dictionary, where the keys are the domain names and the values are lists of
                     triples [[label1, prediction1, conf1], ...]. It is extended in place.
        labels: Tensor containing the ground-truth labels, one entry per sample
        domains: Sequence holding the domain of every sample, indexed by sample position
        predictions: Tensor containing the predictions of the model
        confs: Tensor containing confidence of predictions
    Returns:
        domain_dict: The same dictionary object, now also containing the triples of this batch
    Raises:
        AssertionError: If predictions and labels differ in their first dimension
    """
    assert predictions.shape[0] == labels.shape[0], "The batch size of predictions and labels does not match!"

    # Convert each tensor to Python scalars once instead of calling .item()
    # three times per sample; the stored values are the same.
    label_values = labels.tolist()
    prediction_values = predictions.tolist()
    conf_values = confs.tolist()

    for i in range(labels.shape[0]):
        # setdefault creates the per-domain list the first time a domain is seen
        domain_dict.setdefault(domains[i], []).append(
            [label_values[i], prediction_values[i], conf_values[i]]
        )
    return domain_dict
def eval_domain_dict(domain_dict: dict):
    """
    Compute the accuracy of every domain and the unweighted average over the domains.
    Input:
        domain_dict: Dictionary mapping a domain name to a list of per-sample rows. Column 0 of a
                     row is read as the label and column 1 as the prediction; further columns
                     are not used here.
    Returns:
        result_dict: {"ACC": {<domain>: accuracy, ..., "avg": mean of the domain accuracies},
                      "ECE": {}}. The "ECE" entry is returned empty; nothing in this function
                      fills it.
    """
    logger.info("Splitting the results by domain...")

    per_domain_acc = {}
    for domain, rows in domain_dict.items():
        table = np.array(rows)  # rows: samples, cols: (label, prediction, ...)
        num_hits = (table[:, 0] == table[:, 1]).sum()
        per_domain_acc[domain] = num_hits / table.shape[0]

    # Every domain has the same weight, regardless of how many samples it holds.
    per_domain_acc["avg"] = sum(per_domain_acc.values()) / len(per_domain_acc)
    return {"ACC": per_domain_acc, "ECE": {}}
def flatten_dict(dict_):
    """
    Collapse a two-level dictionary into a single level.
    Input:
        dict_: Dictionary whose values are dictionaries themselves
    Returns:
        A new dictionary with one entry per inner item, keyed "<outer key>_<inner key>".
        If two entries produce the same joined key, the later one wins.
    """
    return {
        f"{outer_key}_{inner_key}": leaf
        for outer_key, inner in dict_.items()
        for inner_key, leaf in inner.items()
    }
def load_error_dict(exp_dir: str, result_file : str = "result.pkl"):
    """
    Average the evaluated results of all runs stored below an experiment directory.
    Input:
        exp_dir: Directory whose entries are searched for result files. Each entry is presumably
                 one seed folder (taken from the original variable naming; verify against caller).
        result_file: Name of the pickle file looked up inside every entry of exp_dir
    Returns:
        The column-wise mean (DataFrame.mean(axis=0)) over the flattened result dictionaries of
        all result files that were found. Entries without a result file only print a warning.
    """
    per_run_records = []
    for entry in os.listdir(exp_dir):
        candidate = os.path.join(exp_dir, entry, result_file)
        if not os.path.exists(candidate):
            print(f"Warning ! {candidate} does not exist")
            continue

        # NOTE(review): pickle.load executes arbitrary code from the file; only point this at
        # result files produced by trusted runs.
        with open(candidate, "rb") as handle:
            stored_results = pickle.load(handle)

        per_run_records.append(flatten_dict(eval_domain_dict(stored_results)))

    return pd.DataFrame.from_records(per_run_records).mean(axis=0)
def get_accuracy(model: torch.nn.Module,
                 data_loader: torch.utils.data.DataLoader,
                 dataset_name: str,
                 domain_name: str,
                 print_every: int,
                 device: Union[str, torch.device]):
    """
    Run the model over a data loader and measure its accuracy.
    Input:
        model: Model that is called on each batch; its output is reduced along dim 1
               (argmax for the prediction, max of the softmax for the confidence)
        data_loader: Yields batches where item 0 holds the images (a tensor, or a list of tensors
                     that is passed to the model as a list), item 1 the labels and, optionally,
                     item 2 the domain of every sample
        dataset_name: Name of the dataset; for "ccc" the loop stops once at least 7500000 samples
                      were processed
        domain_name: Domain assigned to every sample when the batch carries no domain item;
                     also used in the progress message
        print_every: Log the running error every print_every batches; values <= 0 disable it
        device: Device the images and labels are moved to
    Returns:
        accuracy: Fraction of correctly predicted samples (NaN if the loader yielded no samples)
        domain_dict: Dictionary mapping each domain to its [label, prediction, confidence] rows
        num_samples: Number of processed samples
    """
    num_correct = 0.
    num_samples = 0
    domain_dict = {}
    with torch.no_grad():
        for i, data in enumerate(data_loader):
            imgs, labels = data[0], data[1]
            labels = labels.to(device, dtype=torch.int64)

            if isinstance(imgs, list):
                # several tensors per batch: the batch size is taken from the first one
                output = model([img.to(device) for img in imgs])
                current_num_samples = imgs[0].shape[0]
            else:
                output = model(imgs.to(device))
                current_num_samples = imgs.shape[0]

            predictions = output.argmax(1)
            confs = output.softmax(1).amax(1)

            # labels already live on `device`, so no second transfer is needed
            num_correct += (predictions == labels).float().sum()
            num_samples += current_num_samples

            if len(data) >= 3:
                domain_dict = split_results_by_domain(domain_dict, data[1], data[2], predictions, confs)
            else:
                # One domain entry per sample. len(imgs) would be the number of tensors in the
                # list (not the batch size) whenever imgs is a list.
                domain_dict = split_results_by_domain(
                    domain_dict, labels, [domain_name] * current_num_samples, predictions, confs
                )

            # track progress
            if print_every > 0 and (i+1) % print_every == 0:
                message = f"domain={domain_name} #batches={i+1:<6} #samples={num_samples:<9} running error = {1 - num_correct/num_samples:.2%}"
                logger.info(message)

            if dataset_name == "ccc" and num_samples >= 7500000:
                break

    if num_samples == 0:
        # Nothing was evaluated: num_correct is still a plain float (no .item()) and the
        # ratio is undefined, so report NaN instead of failing.
        logger.warning("domain=%s: the data loader yielded no samples, accuracy is undefined", domain_name)
        return float("nan"), domain_dict, num_samples

    accuracy = num_correct.item() / num_samples
    return accuracy, domain_dict, num_samples