import os
import logging
import argparse

import torch
import numpy as np

import utilspp as utpp
from utilspp import mae, mse, ssim, psnr, lpips64, csi, hss
from data.config import SEVIR_13_12, HKO7_5_20, METEONET_5_20
from data.loader import GET_TestLoader
from data.dutils import resize
class MetricListEvaluator():
    '''
    Evaluates a list of metrics. Supported metrics:
    - CSI, HSS (e.g. `csi-84`, `hss-84`)
    - CSI-pooled (e.g. `csi_4-84`)
    - MAE
    - MSE
    - SSIM
    - PSNR
    '''
    def __init__(self, metric_list):
        self.metric_holder = {}
        self.batch_count = 0
        for metric_name in metric_list:
            threshold = ''
            radius = ''
            if '-' in metric_name:
                metric, threshold = metric_name.split('-')
                if '_' in metric:
                    metric, radius = metric.split('_')
                    radius = int(radius)
            # initialize metrics; integer thresholds are given on the 0-255
            # scale and rescaled to [0, 1] to match the clamped tensors
            threshold = float(threshold) / 255 if threshold.isdigit() else threshold
            self.metric_holder[metric_name] = self.init_metric(metric_name, threshold=threshold, radius=radius)
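    # Metric-name grammar, as parsed above (illustrative examples):
    #   'mae'      -> metric='mae', no threshold
    #   'csi-84'   -> metric='csi', threshold=84/255
    #   'csi_4-84' -> metric='csi', pooling radius=4, threshold=84/255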
    def init_metric(self, metric_name, **kwarg):
        '''
        Returns a list of three items, in order:
        - the function to call during eval
        - the value(s) to keep track of
        - a dict of any additional items to pass into the function
        '''
        if metric_name.split('-')[0] in ['csi', 'hss']:
            # accumulate the four confusion-matrix counts (tp, ...) instead;
            # the final score is computed in get_results
            return [utpp.tfpn, np.array([0, 0, 0, 0], dtype=np.float32), {'threshold': kwarg['threshold']}]
        elif '_' in metric_name.split('-')[0]:  # indicates pooling
            return [utpp.tfpn_pool, np.array([0, 0, 0, 0], dtype=np.float32), {'threshold': kwarg['threshold'], 'radius': kwarg['radius']}]
        else:
            # directly convert the string name into a function call
            return [eval(metric_name), 0, {}]
    def eval(self, y_pred, y):
        self.batch_count += 1
        for _, metric in self.metric_holder.items():
            temp = metric[0](y_pred, y, **metric[-1])
            if isinstance(temp, list):
                temp = np.array(temp)
            elif isinstance(temp, torch.Tensor):
                temp = temp.detach().cpu().numpy()
            metric[1] += temp
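    # Note: scalar metrics are averaged over batches in get_results, so every
    # batch is weighted equally (a smaller final batch counts the same as a full one).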
    def get_results(self):
        output_holder = {}
        for key, metric in self.metric_holder.items():
            val = metric[1]
            # special handling of tfpn counts => compute the final score now
            if metric[0] is utpp.tfpn:
                metric_name, _ = key.split('-')
                val = eval(metric_name)(*list(metric[1]))
            elif metric[0] is utpp.tfpn_pool:
                metric_name, _ = key.split('_')
                val = eval(metric_name)(*list(metric[1]))
            else:
                val /= self.batch_count
            output_holder[key] = val
        return output_holder
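# A minimal usage sketch (hypothetical tensors for illustration; `pred` and
# `target` are assumed to be B x T x C x H x W tensors in [0, 1], as prepared
# in the evaluation loop below):
#
#   evaluator = MetricListEvaluator(['mae', 'csi-84', 'csi_4-84'])
#   for pred, target in batches:
#       evaluator.eval(pred, target)
#   scores = evaluator.get_results()  # {'mae': ..., 'csi-84': ..., 'csi_4-84': ...}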
if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    # Dataset related
    parser.add_argument('-d', '--dataset', type=str, default='', help='The dataset definition to use')
    parser.add_argument('--out_len', type=int, required=True, help='The actual prediction length')
    # ensemble npy filename with a {} placeholder, filled with the ensemble index
    parser.add_argument('--e_file', default='', type=str, help='Ensemble npy filename containing a {} placeholder')
    parser.add_argument('--ens_no', default=1, type=int, help='Total number of ensemble members')
    # hyperparams
    parser.add_argument('-s', '--step', type=int, default=-1, help='The number of steps to run. -1: the entire dataloader')
    parser.add_argument('-b', '--batch_size', type=int, default=16, help='The batch size')
    # config override
    parser.add_argument('--metrics', type=str, default=None, help='A list of metrics to evaluate, separated by the character /')
    # logging related
    parser.add_argument('--print_every', type=int, default=100, help='The number of steps between progress log messages')
    args = parser.parse_args()

    # Prepare logger: write the log next to the ensemble files
    logfile_name = os.path.join(os.path.dirname(args.e_file), 'ensemble_eval.log')
    logging.basicConfig(level=logging.NOTSET, handlers=[logging.FileHandler(logfile_name), logging.StreamHandler()], format='%(message)s')
    logging.info(f'Steps: {args.step}')
    dataset_config = globals()[args.dataset]
    dataset_param, dataset_meta = dataset_config['param'], dataset_config['meta']
    loader = GET_TestLoader(dataset_meta, dataset_param, args.batch_size)
    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    # prepare metrics
    metric_list = dataset_meta['metrics']
    if args.metrics is not None:
        metric_list = args.metrics.lower().split('/')
        logging.info(f'Overwriting metrics list with: {metric_list}')
    evaluator = MetricListEvaluator(metric_list)
    for e in range(args.ens_no):
        prediction = np.load(args.e_file.format(str(e)))
        prediction = torch.tensor(prediction, device=device)
        step = 1
        if dataset_meta['dataset'] in ['SEVIR', 'HKO-7']:
            loader.reset()  # reset, otherwise predictions and targets misalign
        while args.step < 0 or step <= args.step:
            if dataset_meta['dataset'] == 'SEVIR':
                data = loader.sample(batch_size=args.batch_size)
                if data is None:
                    break
                y = data['vil'][:, -args.out_len:]  # expected to align with the prediction
            elif dataset_meta['dataset'] == 'HKO-7':
                args.seq_len = dataset_meta['seq_len']
                try:
                    data = loader.sample(batch_size=args.batch_size)
                except Exception as err:
                    logging.error(err)
                    break
                x_seq, x_mask, dt_clip, _ = data
                x, y = utpp.hko7_preprocess(x_seq, x_mask, dt_clip, args)
            elif dataset_meta['dataset'].startswith('meteo'):
                try:
                    x, y = next(loader)
                except Exception as err:
                    logging.error(err)
                    break
            with torch.no_grad():
                y = y.to(device)
                # slice this step's batch out of the preloaded ensemble member
                y_pred = prediction[(step - 1) * args.batch_size:step * args.batch_size]
                if y.shape[-1] != y_pred.shape[-1]:
                    y = resize(y, y_pred.shape[-1])
                y, y_pred = y.clamp(0, 1), y_pred.clamp(0, 1)  # B T C H W
                evaluator.eval(y_pred, y)
            # log progress every `print_every` steps
            if step == 1 or step % args.print_every == 0:
                logging.info(f'E_ID:{e} -> {step} steps evaluated')
            step += 1
    # log the final scores
    final_results = evaluator.get_results()
    for k, v in final_results.items():
        logging.info(f'{k}: {v}')
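# Example invocation (hypothetical script name and paths, for illustration only):
#   python ensemble_eval.py -d SEVIR_13_12 --out_len 12 \
#       --e_file out/sevir/ens_{}.npy --ens_no 4 --metrics mae/csi-84/csi_4-84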