import os
import logging
import argparse

import torch
import numpy as np

import utilspp as utpp
from utilspp import mae, mse, ssim, psnr, lpips64, csi, hss
from data.config import SEVIR_13_12, HKO7_5_20, METEONET_5_20
from data.loader import GET_TestLoader
from data.dutils import resize

class MetricListEvaluator:
    '''
    Evaluate a list of metrics. Supported metric-name formats:
    - CSI, HSS with a threshold (e.g. `csi-84`, `hss-84`)
    - pooled CSI/HSS with a pooling radius (e.g. `csi_4-84`)
    - MAE
    - MSE
    - SSIM
    - PSNR
    '''
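    # Minimal usage sketch:
    #   evaluator = MetricListEvaluator(['mse', 'csi-84', 'csi_4-84'])
    #   evaluator.eval(y_pred, y)          # call once per batch
    #   scores = evaluator.get_results()   # {'mse': ..., 'csi-84': ..., 'csi_4-84': ...}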
    def __init__(self, metric_list):
        self.metric_holder = {}
        self.batch_count = 0
        for metric_name in metric_list:
            threshold = ''
            radius = ''
            if '-' in metric_name:
                # e.g. 'csi-84' -> metric 'csi', threshold '84'
                metric, threshold = metric_name.split('-')
                if '_' in metric:
                    # e.g. 'csi_4-84' -> metric 'csi', pooling radius 4
                    metric, radius = metric.split('_')
                    radius = int(radius)
            # thresholds are given on the 0-255 pixel scale; rescale to [0, 1]
            threshold = float(threshold) / 255 if threshold.isdigit() else threshold
            self.metric_holder[metric_name] = self.init_metric(metric_name, threshold=threshold, radius=radius)

    def init_metric(self, metric_name, **kwarg):
        '''
        Return a list of three items, in order:
        - the function to call during eval
        - the value(s) to keep track of
        - a dict of any additional arguments to pass into the function
        '''
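        # For example, 'csi-84' maps to:
        #   [utpp.tfpn, np.array([0, 0, 0, 0], dtype=np.float32), {'threshold': 84 / 255}]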
        if metric_name.split('-')[0] in ['csi', 'hss']:
            # accumulate confusion-matrix counts via tfpn; the score itself is computed in get_results
            return [utpp.tfpn, np.array([0, 0, 0, 0], dtype=np.float32), {'threshold': kwarg['threshold']}]
        elif '_' in metric_name.split('-')[0]:  # pooled variant, e.g. 'csi_4-84'
            return [utpp.tfpn_pool, np.array([0, 0, 0, 0], dtype=np.float32), {'threshold': kwarg['threshold'], 'radius': kwarg['radius']}]
        else:
            # resolve the metric name to the function of the same name imported from utilspp
            return [eval(metric_name), 0, {}]

    def eval(self, y_pred, y):
        self.batch_count += 1
        for _, metric in self.metric_holder.items():
            temp = metric[0](y_pred, y, **metric[-1])
            if isinstance(temp, list):
                temp = np.array(temp)
            elif isinstance(temp, torch.Tensor):
                temp = temp.detach().cpu().numpy()
            metric[1] += temp
            
    def get_results(self):
        output_holder = {}
        for key, metric in self.metric_holder.items():
            val = metric[1]
            # special handling of tfpn: compute the final score from the accumulated counts now
            if metric[0] is utpp.tfpn:
                metric_name = key.split('-')[0]
                val = eval(metric_name)(*list(metric[1]))
            elif metric[0] is utpp.tfpn_pool:
                # e.g. 'csi_4-84' -> base metric name 'csi'
                metric_name = key.split('_')[0]
                val = eval(metric_name)(*list(metric[1]))
            else:
                # averaged metrics: divide the running sum by the number of batches
                val /= self.batch_count
            output_holder[key] = val
        return output_holder
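
# For reference, the standard reductions from the accumulated confusion-matrix
# counts (a hedged sketch; the argument order and the exact utilspp
# implementation may differ):
#   CSI = tp / (tp + fn + fp)
#   HSS = 2 * (tp * tn - fn * fp) / ((tp + fn) * (fn + tn) + (tp + fp) * (fp + tn))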


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    # Dataset related
    parser.add_argument('-d', '--dataset', type=str, default='', help='Name of the dataset config to use (e.g. SEVIR_13_12)')
    parser.add_argument('--out_len', type=int, required=True, help='The actual prediction length')
    # ensemble npy filename with a {} placeholder for the ensemble index
    parser.add_argument('--e_file', default='', type=str, help='Ensemble npy filename containing a {} placeholder')
    parser.add_argument('--ens_no', default=1, type=int, help='Total number of ensemble members')
    # hyperparams
    parser.add_argument('-s', '--step', type=int, default=-1, help='The number of steps to run. -1: the entire dataloader')
    parser.add_argument('-b', '--batch_size', type=int, default=16, help='The batch size')
    # config override
    parser.add_argument('--metrics', type=str, default=None, help='A list of metrics to evaluate, separated by "/"')
    # logging related
    parser.add_argument('--print_every', type=int, default=100, help='The number of steps between progress log messages')
    args = parser.parse_args()

    # Prepare logger (the log file sits next to the ensemble prediction files)
    logfile_name = os.path.join(os.path.dirname(args.e_file), 'ensemble_eval.log')
    logging.basicConfig(level=logging.NOTSET, handlers=[logging.FileHandler(logfile_name), logging.StreamHandler()], format='%(message)s')
    logging.info(f'Steps: {args.step}')

    dataset_config = globals()[args.dataset]
    dataset_param, dataset_meta = dataset_config['param'], dataset_config['meta']
    loader = GET_TestLoader(dataset_meta, dataset_param, args.batch_size)

    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    # prepare metrics
    metric_list = dataset_meta['metrics']
    if args.metrics is not None: 
        metric_list = args.metrics.lower().split('/')
        logging.info(f'Overwriting metrics list with: {metric_list}')
    evaluator = MetricListEvaluator(metric_list)

    for e in range(args.ens_no):
        prediction = np.load(args.e_file.format(str(e)))
        prediction = torch.tensor(prediction, device=device)
        
        step = 1
        if dataset_meta['dataset'] in ['SEVIR', 'HKO-7']:
            loader.reset()  # reset the loader so targets stay aligned with the saved predictions
            
        while args.step < 0 or step <= args.step:
            if dataset_meta['dataset'] == 'SEVIR':
                data = loader.sample(batch_size=args.batch_size)
                if data is None:
                    break
                y = data['vil'][:, -args.out_len:]  # ground truth frames, expected to align with the saved predictions
            elif dataset_meta['dataset'] == 'HKO-7':
                setattr(args, 'seq_len', dataset_meta['seq_len'])
                try:
                    data = loader.sample(batch_size=args.batch_size)
                except Exception as err:  # do not shadow the ensemble index `e`
                    logging.error(err)
                    break
                x_seq, x_mask, dt_clip, _ = data
                x, y = utpp.hko7_preprocess(x_seq, x_mask, dt_clip, args)
            elif dataset_meta['dataset'].startswith('meteo'):
                try:
                    x, y = next(loader)
                except Exception as err:
                    logging.error(err)
                    break

            with torch.no_grad():
                y = y.to(device)
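                # slice the saved ensemble predictions to align with this loader batch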
                y_pred = prediction[(step-1)*args.batch_size:step*args.batch_size]

                if y.shape[-1] != y_pred.shape[-1]:
                    y = resize(y, y_pred.shape[-1])

                y, y_pred = y.clamp(0,1), y_pred.clamp(0,1) # B T C H W
                evaluator.eval(y_pred, y)
            # log/print every
            if step == 1 or step % args.print_every == 0:
                logging.info(f'E_ID:{e} -> {step} Steps evaluated')
            step += 1
    # log the final scores
    final_results = evaluator.get_results()
    for k, v in final_results.items():
        logging.info(f'{k}: {v}')
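
# Example invocation (hypothetical script/file names, shown for illustration):
#   python ensemble_eval.py -d SEVIR_13_12 --out_len 12 \
#       --e_file runs/ensemble/pred_{}.npy --ens_no 5 --metrics csi-84/csi_4-84/mse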