# STLDM_official / ens_eval.py
import os, sys, logging, argparse
import torch
from torch import nn
import torch.nn.functional as F
import numpy as np
import utilspp as utpp
from utilspp import mae, mse, ssim, psnr, lpips64, csi, hss
from data.config import SEVIR_13_12, HKO7_5_20, METEONET_5_20
from data.loader import GET_TestLoader
from data.dutils import resize
class MetricListEvaluator:
    '''
    Evaluate a list of metrics over batches. Supported metric names:
    - CSI, HSS at a 0-255 threshold (e.g. `csi-84`, `hss-84`)
    - pooled CSI with a pooling radius (e.g. `csi_4-84`)
    - MAE
    - MSE
    - SSIM
    - PSNR
    '''
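    # Example (hypothetical metric list): MetricListEvaluator(['csi-84', 'csi_4-84', 'mae'])
    # tracks confusion-matrix counts for CSI at threshold 84/255, pooled CSI
    # with pooling radius 4, and a running MAE sum.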
    def __init__(self, metric_list):
        self.metric_holder = {}
        self.batch_count = 0
        for metric_name in metric_list:
            # default to the plain name so metrics without '-'/'_' parse safely
            metric, threshold, radius = metric_name, '', ''
            if '-' in metric_name:
                metric, threshold = metric_name.split('-')
            if '_' in metric:
                metric, radius = metric.split('_')
                radius = int(radius)
            # initialize metrics; thresholds are given in 0-255 pixel units, rescale to [0, 1]
            threshold = float(threshold) / 255 if threshold.isdigit() else threshold
            self.metric_holder[metric_name] = self.init_metric(metric_name, threshold=threshold, radius=radius)
    def init_metric(self, metric_name, **kwarg):
        '''
        Return a list of three items, in order:
        - the function to call during eval
        - the running value(s) to keep track of
        - a dict of any additional kwargs to pass into the function
        '''
        if metric_name.split('-')[0] in ['csi', 'hss']:
            # accumulate the four confusion-matrix counts via tfpn;
            # the final CSI/HSS score is computed later in get_results()
            return [utpp.tfpn, np.array([0, 0, 0, 0], dtype=np.float32), {'threshold': kwarg['threshold']}]
        elif '_' in metric_name.split('-')[0]:  # indicates pooling
            return [utpp.tfpn_pool, np.array([0, 0, 0, 0], dtype=np.float32), {'threshold': kwarg['threshold'], 'radius': kwarg['radius']}]
        else:
            # directly resolve the string name into the imported metric function
            return [eval(metric_name), 0, {}]
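
    # For instance, for 'csi-84' the holder entry built above is roughly
    # [utpp.tfpn, np.array([0., 0., 0., 0.]), {'threshold': 84 / 255}];
    # the counts fill in as eval() is called.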
    def eval(self, y_pred, y):
        self.batch_count += 1
        for _, metric in self.metric_holder.items():
            temp = metric[0](y_pred, y, **metric[-1])
            if isinstance(temp, list):
                temp = np.array(temp)
            elif isinstance(temp, torch.Tensor):
                temp = temp.detach().cpu().numpy()
            metric[1] += temp
    def get_results(self):
        output_holder = {}
        for key, metric in self.metric_holder.items():
            val = metric[1]
            # special handling of tfpn counts => compute the final score now
            if metric[0] is utpp.tfpn:
                metric_name, _ = key.split('-')
                val = eval(metric_name)(*list(metric[1]))
            elif metric[0] is utpp.tfpn_pool:
                metric_name, _ = key.split('_')
                val = eval(metric_name)(*list(metric[1]))
            else:
                # averaged metrics: divide the running sum by the batch count
                val /= self.batch_count
            output_holder[key] = val
        return output_holder
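
# Minimal usage sketch (hypothetical tensors; assumes y_pred and y are
# (B, T, C, H, W) tensors with values in [0, 1]):
#
#   evaluator = MetricListEvaluator(['mae', 'csi-84'])
#   for y_pred, y in batches:      # 'batches' is a placeholder iterable
#       evaluator.eval(y_pred, y)  # accumulate once per batch
#   print(evaluator.get_results())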
if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    # Dataset related
    parser.add_argument('-d', '--dataset', type=str, default='', help='The dataset definition to use')
    parser.add_argument('--out_len', type=int, required=True, help='The actual prediction length')
    # ensemble npy filename with a {} placeholder for the member index
    parser.add_argument('--e_file', default='', type=str, help='Ensemble npy filename containing a {} placeholder')
    parser.add_argument('--ens_no', default=1, type=int, help='Total number of ensemble members')
    # hyperparams
    parser.add_argument('-s', '--step', type=int, default=-1, help='The number of steps to run. -1: the entire dataloader')
    parser.add_argument('-b', '--batch_size', type=int, default=16, help='The batch size')
    # config override
    parser.add_argument('--metrics', type=str, default=None, help='A list of metrics to be evaluated, separated by the character /')
    # logging related
    parser.add_argument('--print_every', type=int, default=100, help='The number of steps between progress log messages')
    args = parser.parse_args()
    # Prepare logger; write the log next to the ensemble files
    # (os.path.dirname, unlike splitting and rejoining on "/", keeps absolute paths intact)
    logfile_name = os.path.join(os.path.dirname(args.e_file), 'ensemble_eval.log')
    logging.basicConfig(level=logging.NOTSET, handlers=[logging.FileHandler(logfile_name), logging.StreamHandler()], format='%(message)s')
logging.info(f'Steps: {args.step}')
dataset_config = globals()[args.dataset]
dataset_param, dataset_meta = dataset_config['param'], dataset_config['meta']
loader = GET_TestLoader(dataset_meta, dataset_param, args.batch_size)
device = 'cuda' if torch.cuda.is_available() else 'cpu'
# prepare metrics
metric_list = dataset_meta['metrics']
if args.metrics is not None:
metric_list = args.metrics.lower().split('/')
logging.info(f'Overwriting metrics list with: {metric_list}')
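    # e.g. --metrics 'mae/csi-84/csi_4-84' (hypothetical values) would evaluate
    # MAE plus CSI at threshold 84, with and without a pooling radius of 4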
evaluator = MetricListEvaluator(metric_list)
for e in range(args.ens_no):
prediction = np.load(args.e_file.format(str(e)))
prediction = torch.tensor(prediction, device=device)
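        # assumption: each .npy file holds one ensemble member's stacked
        # per-sample forecasts, shape (N, T, C, H, W) in [0, 1], so the
        # per-step slicing below stays aligned with the loader's batches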
step = 1
        if dataset_meta['dataset'] in ['SEVIR', 'HKO-7']:
            loader.reset()  # reset, otherwise the samples drift out of alignment with the stored predictions
        while args.step < 0 or step <= args.step:
            if dataset_meta['dataset'] == 'SEVIR':
                data = loader.sample(batch_size=args.batch_size)
                if data is None:
                    break
                y = data['vil'][:, -args.out_len:]  # expected to match the stored predictions
            elif dataset_meta['dataset'] == 'HKO-7':
                args.seq_len = dataset_meta['seq_len']
                try:
                    data = loader.sample(batch_size=args.batch_size)
                except Exception as err:  # use err so the ensemble index e is not shadowed
                    logging.error(err)
                    break
                x_seq, x_mask, dt_clip, _ = data
                x, y = utpp.hko7_preprocess(x_seq, x_mask, dt_clip, args)
            elif dataset_meta['dataset'].startswith('meteo'):
                try:
                    x, y = next(loader)  # StopIteration also ends the loop here
                except Exception as err:
                    logging.error(err)
                    break
            else:
                raise ValueError(f"Unsupported dataset: {dataset_meta['dataset']}")
with torch.no_grad():
y = y.to(device)
                # slice out this member's predictions for the current batch
                y_pred = prediction[(step - 1) * args.batch_size:step * args.batch_size]
                if y.shape[-1] != y_pred.shape[-1]:
                    y = resize(y, y_pred.shape[-1])  # match spatial resolution before scoring
                y, y_pred = y.clamp(0, 1), y_pred.clamp(0, 1)  # B T C H W
evaluator.eval(y_pred, y)
# log/print every
if step == 1 or step % args.print_every == 0:
logging.info(f'E_ID:{e} -> {step} Steps evaluated')
step += 1
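
    # Note: the evaluator keeps accumulating across all ensemble members, so the
    # reported scores pool every member's predictions together rather than
    # evaluating each member separately.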
# log the final scores
final_results = evaluator.get_results()
for k, v in final_results.items():
logging.info(f'{k}: {v}')