Image Classification
English
TTA
File size: 4,813 Bytes
02ba886
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
import pandas as pd
import torch
import logging
import numpy as np
from typing import Union
import pickle
import os
from collections import deque

# Module-level logger used by the evaluation helpers in this file.
logger = logging.getLogger(__name__)


def split_results_by_domain(domain_dict: dict, labels: torch.Tensor, domains: list, predictions: torch.Tensor, confs: torch.Tensor):
    """
    Separates the (label, prediction, confidence) triples of a batch by domain
    Input:
        domain_dict: Dictionary, where the keys are the domain names and the values are lists of
                     triples [[label1, prediction1, conf1], ...]. It is updated in place and returned.
        labels: Tensor containing the ground-truth labels of the batch
        domains: Sequence holding the domain of each sample (indexed like labels)
        predictions: Tensor containing the predictions of the model
        confs: Tensor containing the confidence of each prediction
    Returns:
        domain_dict: Updated dictionary containing the domain separated label/prediction/confidence triples
    """
    assert predictions.shape[0] == labels.shape[0], "The batch size of predictions and labels does not match!"

    # Convert every tensor to Python numbers with a single call instead of one .item() per element;
    # for tensors on an accelerator each .item() is a separate device-to-host transfer.
    label_list = labels.tolist()
    pred_list = predictions.tolist()
    conf_list = confs.tolist()

    for i, label in enumerate(label_list):
        domain_dict.setdefault(domains[i], []).append([label, pred_list[i], conf_list[i]])

    return domain_dict


def eval_domain_dict(domain_dict: dict):
    """
    Compute the accuracy of each domain. This is useful for settings where the domains are mixed
    Input:
        domain_dict: Dictionary mapping each domain to a list of rows [label, prediction, ...]
                     (split_results_by_domain writes rows [label, prediction, confidence])
    Returns:
        result_dict: {"ACC": {domain: accuracy, ..., "avg": average}, "ECE": {}}
                     "avg" is the unweighted mean of the per-domain accuracies, or NaN when
                     domain_dict is empty. The "ECE" entry is not filled by this function.
    """
    result_dict = {"ACC": {}, "ECE": {}}
    logger.info("Splitting the results by domain...")
    for key, rows in domain_dict.items():
        label_prediction_arr = np.array(rows)  # rows: samples, cols: (label, prediction, ...)
        labels = label_prediction_arr[:, 0]
        preds = label_prediction_arr[:, 1]
        correct = (labels == preds).sum()
        num_samples = label_prediction_arr.shape[0]
        result_dict["ACC"][key] = correct / num_samples

    domain_accs = list(result_dict["ACC"].values())
    # Unweighted mean over domains; guard against an empty domain_dict instead of dividing by zero
    result_dict["ACC"]["avg"] = sum(domain_accs) / len(domain_accs) if domain_accs else float("nan")
    return result_dict


def flatten_dict(dict_):
    """
    Collapse a two-level dictionary into a single level.

    Every inner entry ``dict_[outer][inner]`` is stored under the key ``f"{outer}_{inner}"``;
    outer keys whose inner dictionary is empty contribute nothing. If two entries map to
    the same combined key, the one visited last wins.
    """
    return {
        f"{outer_key}_{inner_key}": value
        for outer_key, inner in dict_.items()
        for inner_key, value in inner.items()
    }


def load_error_dict(exp_dir: str, result_file : str = "result.pkl"):
    """
    Aggregate the per-domain accuracies of all runs stored below an experiment directory.
    Input:
        exp_dir: Directory whose sub-directories (presumably one per seed) each contain a pickled domain_dict
        result_file: Name of the pickle file inside each sub-directory
    Returns:
        pandas Series holding, for every flattened metric (e.g. "ACC_<domain>", "ACC_avg"),
        its mean over all runs that were found
    """
    records = []
    # sorted() makes the processing (and warning) order deterministic across file systems
    for seed_folder in sorted(os.listdir(exp_dir)):
        seed_dir = os.path.join(exp_dir, seed_folder)
        if not os.path.isdir(seed_dir):
            # plain files next to the run folders cannot contain a result file; skip them silently
            continue
        seed_result_file = os.path.join(seed_dir, result_file)
        if not os.path.exists(seed_result_file):
            logger.warning(f"Warning ! {seed_result_file} does not exist")
            continue
        # NOTE: pickle.load can execute arbitrary code; only use this on trusted experiment directories.
        with open(seed_result_file, "rb") as f:
            domain_dict = pickle.load(f)
        records.append(flatten_dict(eval_domain_dict(domain_dict)))
    return pd.DataFrame.from_records(records).mean(axis=0)


def get_accuracy(model: torch.nn.Module,
                 data_loader: torch.utils.data.DataLoader,
                 dataset_name: str,
                 domain_name: str,
                 print_every: int,
                 device: Union[str, torch.device]):
    """
    Run the model over every batch of a data loader and measure its accuracy
    Input:
        model: Called (under torch.no_grad) with one batch tensor, or with a list of batch tensors
               when the loader yields a list of images; its output is treated as class scores along dim 1
        data_loader: Yields sequences where index 0 holds the images (tensor or list of tensors),
                     index 1 the labels and, optionally, index 2 the per-sample domains
        dataset_name: For "ccc", evaluation stops after 7,500,000 samples
        domain_name: Used in progress messages and as the domain of every sample when the loader
                     does not provide per-sample domains
        print_every: Log the running error every `print_every` batches (<= 0 disables logging)
        device: Device the images and labels are moved to
    Returns:
        accuracy: Fraction of correctly classified samples (NaN if no samples were processed)
        domain_dict: {domain: [[label, prediction, confidence], ...]} built by split_results_by_domain
        num_samples: Number of processed samples
    """
    # Sample budget after which evaluation on the "ccc" dataset stops
    ccc_max_samples = 7500000

    num_correct = 0
    num_samples = 0
    domain_dict = {}

    with torch.no_grad():
        for i, data in enumerate(data_loader):
            imgs, labels = data[0], data[1]
            labels = labels.to(device, dtype=torch.int64)

            # imgs is either one batch tensor or a list of batch tensors handed to the model as a whole
            if isinstance(imgs, list):
                output = model([img.to(device) for img in imgs])
                current_num_samples = imgs[0].shape[0]
            else:
                output = model(imgs.to(device))
                current_num_samples = imgs.shape[0]
            predictions = output.argmax(1)
            confs = output.softmax(1).amax(1)

            # Accumulate as a Python int: exact for any count (a float32 tensor is only exact up to 2**24)
            num_correct += (predictions == labels).sum().item()
            num_samples += current_num_samples

            if len(data) >= 3:
                domains = data[2]
            else:
                # One entry per sample; len(imgs) would count list entries rather than samples when imgs is a list
                domains = [domain_name] * current_num_samples
            domain_dict = split_results_by_domain(domain_dict, labels, domains, predictions, confs)

            # track progress
            if print_every > 0 and (i+1) % print_every == 0:
                message = f"domain={domain_name} #batches={i+1:<6} #samples={num_samples:<9} running error = {1 - num_correct/num_samples:.2%}"
                logger.info(message)

            if dataset_name == "ccc" and num_samples >= ccc_max_samples:
                break

    # Guard against a loader that yields no samples instead of dividing by zero
    accuracy = num_correct / num_samples if num_samples > 0 else float("nan")
    return accuracy, domain_dict, num_samples