|
|
import pandas as pd |
|
|
import torch |
|
|
import logging |
|
|
import numpy as np |
|
|
from typing import Union |
|
|
import pickle |
|
|
import os |
|
|
from collections import deque |
|
|
|
|
|
logger = logging.getLogger(__name__) |
|
|
|
|
|
|
|
|
def split_results_by_domain(domain_dict: dict, labels: torch.Tensor, domains: list, predictions: torch.Tensor, confs: torch.Tensor):
    """
    Separate label/prediction/confidence triplets by domain.

    Input:
        domain_dict: Dictionary mapping domain names to lists of triplets
            [[label1, prediction1, conf1], ...]; updated in place
        labels: Tensor containing the ground-truth labels of the batch
        domains: Per-sample domain names (length = batch size)
        predictions: Tensor containing the predictions of the model
        confs: Tensor containing the confidence of the predictions
    Returns:
        domain_dict: Updated dictionary containing the domain-separated triplets
    """
    assert predictions.shape[0] == labels.shape[0], "The batch size of predictions and labels does not match!"

    for i in range(labels.shape[0]):
        # setdefault creates the list on the first occurrence of a domain
        domain_dict.setdefault(domains[i], []).append(
            [labels[i].item(), predictions[i].item(), confs[i].item()]
        )

    return domain_dict
|
|
|
|
|
|
|
|
def eval_domain_dict(domain_dict: dict):
    """
    Compute per-domain accuracy. Useful for settings where the domains are mixed.

    Input:
        domain_dict: Dictionary mapping each domain name to a list of
            [label, prediction, confidence] triplets
    Returns:
        result_dict: {"ACC": {domain: accuracy, ..., "avg": mean_accuracy}, "ECE": {}}
            NOTE(review): the "ECE" entry is kept for interface compatibility
            but is never populated here.
    """
    result_dict = {"ACC": {}, "ECE": {}}
    logger.info("Splitting the results by domain...")
    for key, triplets in domain_dict.items():
        label_prediction_arr = np.array(triplets)
        # column 0 = labels, column 1 = predictions
        result_dict["ACC"][key] = np.mean(label_prediction_arr[:, 0] == label_prediction_arr[:, 1])

    # average accuracy over all domains ("avg" is assigned after the RHS is evaluated)
    result_dict["ACC"]["avg"] = sum(result_dict["ACC"].values()) / len(result_dict["ACC"])
    return result_dict
|
|
|
|
|
|
|
|
def flatten_dict(dict_):
    """Flatten a two-level dictionary into one level, joining keys with '_'.

    E.g. {"ACC": {"d1": 0.5}} -> {"ACC_d1": 0.5}
    """
    return {
        f"{outer_key}_{inner_key}": value
        for outer_key, inner_dict in dict_.items()
        for inner_key, value in inner_dict.items()
    }
|
|
|
|
|
|
|
|
def load_error_dict(exp_dir: str, result_file: str = "result.pkl"):
    """
    Aggregate the per-seed result files of an experiment into mean results.

    Input:
        exp_dir: Experiment directory containing one sub-directory per seed
        result_file: File name of the pickled domain dict inside each seed directory
    Returns:
        df: pandas Series with the seed-averaged, flattened per-domain results
    """
    records = []
    # sorted() makes the traversal order deterministic across file systems
    for seed_folder in sorted(os.listdir(exp_dir)):
        seed_result_file = os.path.join(exp_dir, seed_folder, result_file)
        if os.path.exists(seed_result_file):
            # NOTE: pickle.load is only safe on trusted, self-generated result files
            with open(seed_result_file, "rb") as f:
                result_dict = pickle.load(f)
            result_dict = eval_domain_dict(result_dict)
            result_dict = flatten_dict(result_dict)
            records.append(result_dict)
        else:
            # use the module logger instead of print for consistent output handling
            logger.warning(f"Warning ! {seed_result_file} does not exist")
    df = pd.DataFrame.from_records(records).mean(axis=0)
    return df
|
|
|
|
|
|
|
|
def get_accuracy(model: torch.nn.Module,
                 data_loader: torch.utils.data.DataLoader,
                 dataset_name: str,
                 domain_name: str,
                 print_every: int,
                 device: Union[str, torch.device]):
    """
    Evaluate a classification model on a data loader.

    Input:
        model: Model producing class logits; list inputs are moved to the device element-wise
        data_loader: Loader yielding batches (images, labels[, domains, ...])
        dataset_name: Dataset name; for "ccc" evaluation stops after 7.5M samples
        domain_name: Fallback domain name when the batch carries no domain info
        print_every: Log the running error every `print_every` batches (<= 0 disables logging)
        device: Device used for the forward pass
    Returns:
        accuracy: Fraction of correctly classified samples
        domain_dict: Label/prediction/confidence triplets, separated by domain
        num_samples: Number of evaluated samples
    """
    num_correct = 0.
    num_samples = 0
    domain_dict = {}

    with torch.no_grad():
        for i, data in enumerate(data_loader):
            imgs, labels = data[0], data[1]
            labels = labels.to(device, dtype=torch.int64)

            # some pipelines provide a list of (augmented) image batches instead of one tensor
            output = model([img.to(device) for img in imgs]) if isinstance(imgs, list) else model(imgs.to(device))
            predictions = output.argmax(1)
            confs = output.softmax(1).amax(1)

            # labels are already on `device`; accumulate as a plain float so the
            # final division works even if the loop body never runs
            num_correct += (predictions == labels).float().sum().item()
            # batch size: for a list input, read it from the first tensor
            current_num_samples = imgs[0].shape[0] if isinstance(imgs, list) else imgs.shape[0]
            num_samples += current_num_samples

            if len(data) >= 3:
                # the batch itself carries per-sample domain information
                domain_dict = split_results_by_domain(domain_dict, data[1], data[2], predictions, confs)
            else:
                # fix: use the batch size, not len(imgs) -- for a list input,
                # len(imgs) is the number of views, not the number of samples
                domain_dict = split_results_by_domain(domain_dict, labels, [domain_name] * current_num_samples, predictions, confs)

            if print_every > 0 and (i+1) % print_every == 0:
                message = f"domain={domain_name} #batches={i+1:<6} #samples={num_samples:<9} running error = {1 - num_correct/num_samples:.2%}"
                logger.info(message)

            # CCC is an (effectively) infinite stream; cap the evaluated samples
            if dataset_name == "ccc" and num_samples >= 7500000:
                break

    accuracy = num_correct / num_samples
    return accuracy, domain_dict, num_samples
|
|
|