|
|
import os, sys, logging, argparse |
|
|
import torch |
|
|
from torch import nn |
|
|
import torch.nn.functional as F |
|
|
import numpy as np |
|
|
|
|
|
import utilspp as utpp |
|
|
from utilspp import mae, mse, ssim, psnr, lpips64, csi, hss |
|
|
from data.config import SEVIR_13_12, HKO7_5_20, METEONET_5_20 |
|
|
from data.loader import GET_TestLoader |
|
|
from data.dutils import resize |
|
|
|
|
|
class MetricListEvaluator():
    '''
    To evaluate a list of metrics. Supported metrics:

    - CSI, HSS (Eg. `csi-84, hss-84`)
    - CSI-pooled (Eg. `csi_4-84`)
    - MAE
    - MSE
    - SSIM
    - PSNR
    '''

    def __init__(self, metric_list):
        '''
        :param metric_list: iterable of metric-name strings, e.g.
            ``['mae', 'csi-84', 'csi_4-84']``.  A ``-<threshold>`` suffix
            selects the binarization threshold (integer values are given on
            the 0-255 pixel scale and rescaled to [0, 1]); a ``_<radius>``
            infix selects the pooled contingency-table variant.
        '''
        self.metric_holder = {}
        # number of eval() calls so far; used to average the scalar metrics
        self.batch_count = 0
        for metric_name in metric_list:
            threshold = ''
            radius = ''
            if '-' in metric_name:
                metric, threshold = metric_name.split('-')
                if '_' in metric:
                    metric, radius = metric.split('_')
                    radius = int(radius)

                # integer thresholds are specified on the 0-255 pixel scale
                threshold = float(threshold) / 255 if threshold.isdigit() else threshold
            self.metric_holder[metric_name] = self.init_metric(metric_name, threshold=threshold, radius=radius)

    def init_metric(self, metric_name, **kwarg):
        '''
        return a tuple of three items in order:
        - the function to call during eval
        - the value(s) to keep track of
        - a dict of any additional item to pass into the function
        '''
        base = metric_name.split('-')[0]
        if base in ['csi', 'hss']:
            # contingency-table metrics: accumulate TP/FP/TN/FN counts
            return [utpp.tfpn, np.array([0, 0, 0, 0], dtype=np.float32), {'threshold': kwarg['threshold']}]
        elif '_' in base:
            # pooled variant, e.g. 'csi_4-84'
            return [utpp.tfpn_pool, np.array([0, 0, 0, 0], dtype=np.float32), {'threshold': kwarg['threshold'], 'radius': kwarg['radius']}]
        else:
            # Scalar metric (mae/mse/ssim/psnr/...) resolved by name from the
            # module namespace.  globals() lookup replaces the original
            # eval(metric_name), which would execute arbitrary CLI input.
            return [globals()[metric_name], 0, {}]

    def eval(self, y_pred, y):
        '''
        Accumulate every configured metric over one batch.

        :param y_pred: predicted batch (tensor/array as expected by the
            metric functions)
        :param y: ground-truth batch
        '''
        self.batch_count += 1
        for _, metric in self.metric_holder.items():
            temp = metric[0](y_pred, y, **metric[-1])
            # Normalize the metric's return value before accumulating.
            # BUG FIX: the original `if temp is list:` compared identity with
            # the `list` type itself and was therefore always False.
            if isinstance(temp, list):
                temp = np.array(temp)
            elif isinstance(temp, torch.Tensor):
                temp = temp.detach().cpu().numpy()
            metric[1] += temp

    def get_results(self):
        '''
        Finalize and return ``{metric_name: value}``.

        Contingency-table metrics are computed from the accumulated
        TP/FP/TN/FN counts; scalar metrics are averaged over the number of
        evaluated batches.
        '''
        output_holder = {}
        for key, metric in self.metric_holder.items():
            val = metric[1]

            if metric[0] is utpp.tfpn:
                metric_name, threshold = key.split('-')
                val = globals()[metric_name](*list(metric[1]))
            elif metric[0] is utpp.tfpn_pool:
                metric_name, info = key.split('_')
                val = globals()[metric_name](*list(metric[1]))
            else:
                # scalar metrics were summed per batch; average them here
                val /= self.batch_count
            output_holder[key] = val
        return output_holder
|
|
|
|
|
|
|
|
if __name__ == '__main__':
    parser = argparse.ArgumentParser()

    # dataset / prediction-length options
    parser.add_argument('-d', '--dataset', type=str, default='', help='the dataset definition to be set')
    parser.add_argument('--out_len', type=int, required=True, help='The actual prediction length')

    # ensemble options: e_file must contain a `{}` placeholder for the member index
    parser.add_argument('--e_file', default='', type=str, help='Ensemble npy filename with included { }')
    parser.add_argument('--ens_no', default=1, type=int, help='Total ensemble number')

    parser.add_argument('-s', '--step', type=int, default=-1, help='The number of steps to run. -1: the entire dataloader')
    parser.add_argument('-b', '--batch_size', type=int, default=16, help='The batch size')

    parser.add_argument('--metrics', type=str, default=None, help='A list of metrics to be evaluated, separated by character /')

    parser.add_argument('--print_every', type=int, default=100, help='The number of steps between progress logs')
    args = parser.parse_args()

    # write the log next to the ensemble prediction files
    path_list = args.e_file.split("/")
    logfile_name = os.path.join(*path_list[:-1], 'ensemble_eval.log')
    logging.basicConfig(level=logging.NOTSET, handlers=[logging.FileHandler(logfile_name), logging.StreamHandler()], format='%(message)s')
    logging.info(f'Steps: {args.step}')

    # resolve the dataset definition (e.g. SEVIR_13_12 / HKO7_5_20 / METEONET_5_20) by name
    dataset_config = globals()[args.dataset]
    dataset_param, dataset_meta = dataset_config['param'], dataset_config['meta']
    loader = GET_TestLoader(dataset_meta, dataset_param, args.batch_size)

    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    # metrics default to the dataset definition, overridable from the CLI
    metric_list = dataset_meta['metrics']
    if args.metrics is not None:
        metric_list = args.metrics.lower().split('/')
        logging.info(f'Overwriting metrics list with: {metric_list}')
    evaluator = MetricListEvaluator(metric_list)

    for e in range(args.ens_no):
        # each ensemble member's predictions are stored as one big npy file
        prediction = np.load(args.e_file.format(str(e)))
        prediction = torch.tensor(prediction, device=device)

        step = 1
        if dataset_meta['dataset'] in ['SEVIR', 'HKO-7']:
            # these loaders are stateful samplers and must be rewound per member
            # NOTE(review): the meteo loader is not reset between members —
            # verify GET_TestLoader returns a restartable iterator for it
            loader.reset()

        while args.step < 0 or step <= args.step:
            if dataset_meta['dataset'] == 'SEVIR':
                data = loader.sample(batch_size=args.batch_size)
                if data is None:
                    break
                y = data['vil'][:, -args.out_len:]
            elif dataset_meta['dataset'] == 'HKO-7':
                setattr(args, 'seq_len', dataset_meta['seq_len'])
                try:
                    data = loader.sample(batch_size=args.batch_size)
                # BUG FIX: `except ... as e` shadowed (and then unbound) the
                # ensemble-index loop variable `e`; use a distinct name
                except Exception as err:
                    logging.error(err)
                    break
                x_seq, x_mask, dt_clip, _ = data
                x, y = utpp.hko7_preprocess(x_seq, x_mask, dt_clip, args)
            elif dataset_meta['dataset'].startswith('meteo'):
                try:
                    x, y = next(loader)
                except Exception as err:
                    logging.error(err)
                    break

            with torch.no_grad():
                y = y.to(device)
                # slice this batch's pre-computed predictions out of the npy dump
                y_pred = prediction[(step-1)*args.batch_size:step*args.batch_size]

                # match spatial resolution of ground truth to the predictions
                if y.shape[-1] != y_pred.shape[-1]:
                    y = resize(y, y_pred.shape[-1])

                y, y_pred = y.clamp(0, 1), y_pred.clamp(0, 1)
                evaluator.eval(y_pred, y)

            if step == 1 or step % args.print_every == 0:
                logging.info(f'E_ID:{e} -> {step} Steps evaluated')
            step += 1

    # results are aggregated over all ensemble members and all batches
    final_results = evaluator.get_results()
    for k, v in final_results.items():
        logging.info(f'{k}: {v}')