|
|
import os |
|
|
import torch |
|
|
import numpy as np |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def compute_depth_metrics(gt, pred, mask=None, median_align=False):
    """Compute standard depth-estimation metrics between predicted and ground truth depths.

    Args:
        gt: ground-truth depth tensor. Dimension 1 is squeezed away, so a
            ``(B, 1, H, W)`` layout is presumably expected -- TODO confirm with callers.
        pred: predicted depth tensor with the same shape as ``gt``.
        mask: optional boolean tensor with the same shape as ``gt`` selecting the
            pixels to evaluate. Defaults to ``gt > 0``.
        median_align: if True, the selected predictions are rescaled by
            ``median(gt) / median(pred)`` before any metric is computed.

    Returns:
        Tuple of 0-dim tensors
        ``(abs_, abs_rel, sq_rel, rmse, rmse_log, log10, a1, a2, a3)``:
        mean absolute error, mean absolute relative error, mean squared relative
        error, root mean squared error, root mean squared log10 error, mean
        absolute log10 error, and the fractions of pixels whose ratio
        ``max(gt/pred, pred/gt)`` is below 1.25, 1.25**2 and 1.25**3.

    Notes:
        If the mask selects no pixels, every metric is NaN (mean of an empty
        tensor). Non-positive predictions at selected pixels produce inf/NaN in
        the ratio- and log-based metrics.
    """
    if mask is None:
        mask = gt > 0

    gt = gt.squeeze(1)
    pred = pred.squeeze(1)
    mask = mask.squeeze(1)

    # Boolean indexing returns flattened copies, so later rescaling of `pred`
    # never touches the caller's tensor.
    gt = gt[mask]
    pred = pred[mask]

    # torch.median raises on an empty tensor, so only align when pixels exist.
    if median_align and pred.numel() > 0:
        pred = pred * (torch.median(gt) / torch.median(pred))

    # Threshold accuracies (delta < 1.25^k).
    thresh = torch.max(gt / pred, pred / gt)
    a1 = (thresh < 1.25).float().mean()
    a2 = (thresh < 1.25 ** 2).float().mean()
    a3 = (thresh < 1.25 ** 3).float().mean()

    diff = gt - pred

    # The square root must be applied AFTER averaging; taking it per pixel
    # would just give the mean absolute error.
    rmse = torch.sqrt((diff ** 2).mean())

    log_diff = torch.log10(gt) - torch.log10(pred)
    rmse_log = torch.sqrt((log_diff ** 2).mean())

    abs_ = torch.mean(torch.abs(diff))
    abs_rel = torch.mean(torch.abs(diff) / gt)
    sq_rel = torch.mean(diff ** 2 / gt)
    log10 = torch.mean(torch.abs(torch.log10(pred / gt)))

    return abs_, abs_rel, sq_rel, rmse, rmse_log, log10, a1, a2, a3
|
|
|
|
|
|
|
|
|
|
|
class AverageMeter(object):
    """Computes and stores the average and current value.

    Keeps the latest value (``val``), a weighted running ``sum``, the total
    weight ``count``, their ratio ``avg``, and the raw history of every value
    passed to ``update()`` in ``vals``.
    """

    def __init__(self):
        self.reset()

    def reset(self):
        """Clear all statistics, including the per-update history in ``vals``."""
        # The history is cleared here too, so that a reset meter does not keep
        # (and keep alive) values recorded before the reset.
        self.vals = []
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        """Record ``val`` with weight ``n`` and refresh the running average."""
        self.vals.append(val)
        self.val = val
        self.sum += val * n
        self.count += n
        # Avoid ZeroDivisionError when only zero-weight updates have been seen.
        if self.count:
            self.avg = self.sum / self.count

    def to_dict(self):
        """Return the scalar statistics (not the ``vals`` history) as a dict."""
        return {
            'val': self.val,
            'sum': self.sum,
            'count': self.count,
            'avg': self.avg
        }

    def from_dict(self, meter_dict):
        """Restore the scalar statistics from a dict produced by ``to_dict``."""
        self.val = meter_dict['val']
        self.sum = meter_dict['sum']
        self.count = meter_dict['count']
        self.avg = meter_dict['avg']
|
|
|
|
|
|
|
|
class Evaluator(object):
    """Accumulate depth metrics over batches and report their averages.

    One ``AverageMeter`` per metric is kept in ``self.metrics``. Each batch
    contributes with weight equal to its leading dimension (batch size).
    """

    # Metric keys, in the same order as the tuple returned by compute_depth_metrics.
    _METRIC_KEYS = (
        "err/abs_",
        "err/abs_rel",
        "err/sq_rel",
        "err/rms",
        "err/log_rms",
        "err/log10",
        "acc/a1",
        "acc/a2",
        "acc/a3",
    )

    # Subset of metrics echoed to stdout by print().
    _SUMMARY_KEYS = ("err/abs_rel", "err/rms", "acc/a1")

    def __init__(self, median_align=False):
        # Forwarded to compute_depth_metrics on every batch.
        self.median_align = median_align
        self.metrics = {key: AverageMeter() for key in self._METRIC_KEYS}

    def reset_eval_metrics(self):
        """
        Resets metrics used to evaluate the model
        """
        for key in self._METRIC_KEYS:
            self.metrics[key].reset()

    def compute_eval_metrics(self, gt_depth, pred_depth, mask):
        """
        Computes metrics for one batch and folds them into the running averages,
        weighting each batch by gt_depth.shape[0].
        """
        batch_size = gt_depth.shape[0]
        batch_values = compute_depth_metrics(gt_depth, pred_depth, mask, self.median_align)
        for key, value in zip(self._METRIC_KEYS, batch_values):
            self.metrics[key].update(value, batch_size)

    def print(self, dir=None):
        """Print the abs_rel / rms / a1 averages to stdout.

        If ``dir`` is given, all nine averages are also written to
        ``dir/result.txt``, overwriting any existing file.
        """
        all_averages = [self.metrics[key].avg for key in self._METRIC_KEYS]
        summary_averages = [self.metrics[key].avg for key in self._SUMMARY_KEYS]

        # Inside a method body the bare name resolves to the builtin print,
        # not to this method.
        print("\n " + ("{:>8} | " * 3).format("abs_rel", "rms", "a1"))
        print(("& {: 8.5f} " * 3).format(*summary_averages))

        if dir is None:
            return

        result_path = os.path.join(dir, "result.txt")
        with open(result_path, 'w') as handle:
            header = ("{:>9} | " * 9).format("abs_", "abs_rel", "sq_rel", "rms", "rms_log",
                                             "log10", "a1", "a2", "a3")
            print("\n " + header, file=handle)
            print(("& {: 8.5f} " * 9).format(*all_averages), file=handle)
|
|
|