WaveLSFromer / utils /metrics.py
ducheng678
Initial WaveLSFromer project
093b0a5
Raw
History Blame Contribute Delete
1.65 kB
import numpy as np
import torch
from torchmetrics import Metric
def RSE(pred, true):
return np.sqrt(np.sum((true - pred) ** 2)) / np.sqrt(
np.sum((true - true.mean()) ** 2)
)
def CORR(pred, true):
u = ((true - true.mean(0)) * (pred - pred.mean(0))).sum(0)
d = np.sqrt(((true - true.mean(0)) ** 2 * (pred - pred.mean(0)) ** 2).sum(0))
return (u / d).mean(-1)
def MAE(pred, true):
return np.mean(np.abs(pred - true))
def MSE(pred, true):
return np.mean((pred - true) ** 2)
def RMSE(pred, true):
return np.sqrt(MSE(pred, true))
def MAPE(pred, true):
return np.mean(np.abs((pred - true) / true))
def MSPE(pred, true):
return np.mean(np.square((pred - true) / true))
def metric(pred, true):
mae = MAE(pred, true)
mse = MSE(pred, true)
rmse = RMSE(pred, true)
mape = MAPE(pred, true)
mspe = MSPE(pred, true)
return mae, mse, rmse, mape, mspe
class MSELoss(Metric):
# Just testing to make sure torch metrics work as expected
is_differentiable = True
def __init__(self):
super().__init__()
self.add_state(
"sum_squared_errors",
default=torch.tensor(0, dtype=float),
dist_reduce_fx="sum",
)
self.add_state("n_observations", default=torch.tensor(0), dist_reduce_fx="sum")
def update(self, preds: torch.Tensor, target: torch.Tensor):
assert preds.shape == target.shape
self.sum_squared_errors += torch.sum(torch.square(preds - target))
self.n_observations += preds.numel()
def compute(self):
return self.sum_squared_errors / self.n_observations