|
|
import torch |
|
|
import matplotlib.pyplot as plt |
|
|
|
|
|
from .metrics import * |
|
|
|
|
|
# Module-level SignalDice instance, created once at import time and reused
# by calculate_fn for every comparison.
# NOTE(review): SignalDice is presumably provided by `from .metrics import *`
# -- confirm against the metrics module.
SDSC = SignalDice()
|
|
|
|
|
def calculate_fn(sig1, sig2):
    """Score two signals and return the results as one formatted string.

    Calls ``MSE``, ``MAE`` and the module-level ``SDSC`` metric, in that
    order, each with ``(sig1, sig2)``.

    NOTE(review): ``MSE``/``MAE`` presumably come from ``.metrics`` and
    compute mean squared / mean absolute error -- confirm there.

    Args:
        sig1: First signal, passed as the first argument to every metric.
        sig2: Second signal, passed as the second argument to every metric.

    Returns:
        A string of the form ``"MSE: <v>, MAE: <v>, SDSC: <v>"`` with each
        value rendered to four decimal places.
    """
    # Evaluated eagerly, top to bottom, so the metric call order is fixed.
    scores = (
        ("MSE", MSE(sig1, sig2)),
        ("MAE", MAE(sig1, sig2)),
        ("SDSC", SDSC(sig1, sig2)),
    )
    return ", ".join(f"{label}: {value:.4f}" for label, value in scores)
|
|
|
|
|
|
|
|
def demo_fn(gt_radio: str = 'Sin',
            pred_radio: str = 'Sin',
            gt_scale: float = 1.0,
            gt_shift: float = 0.0,
            pred_scale: float = 1.0,
            pred_shift: float = 0.0):
    """Build a ground-truth/prediction signal pair, score it and plot it.

    Both signals are sampled at 1000 evenly spaced points on ``[0, 2*pi]``.
    Each one is a base waveform selected by name, multiplied by its scale
    and then offset by its shift.

    Args:
        gt_radio: Base waveform for the ground truth; one of ``"Sin"``,
            ``"Cos"`` or ``"Random"``.
        pred_radio: Base waveform for the prediction; same choices.
        gt_scale: Multiplier applied to the ground-truth waveform.
        gt_shift: Offset added to the scaled ground-truth waveform.
        pred_scale: Multiplier applied to the prediction waveform.
        pred_shift: Offset added to the scaled prediction waveform.

    Returns:
        A ``(fig, result_strings)`` tuple: the matplotlib figure showing
        both signals, and the metric summary string produced by
        ``calculate_fn(gt, pred)``.

    Raises:
        KeyError: If ``gt_radio`` or ``pred_radio`` is not one of the
            supported waveform names.
    """
    # math.pi is used for the interval bound because this module does not
    # import numpy itself; a function-local stdlib import keeps the block
    # self-contained.
    import math

    t = torch.linspace(0, 2 * math.pi, 1000)

    # The noise is drawn once per call, so choosing "Random" for both
    # signals gives them the same base samples.
    signals = {
        "Sin": torch.sin(t),
        "Cos": torch.cos(t),
        "Random": torch.randn_like(t),
    }

    gt = signals[gt_radio] * gt_scale + gt_shift
    pred = signals[pred_radio] * pred_scale + pred_shift

    result_strings = calculate_fn(gt, pred)

    # Draw on an explicit figure/axes pair rather than pyplot's implicit
    # "current figure", and close exactly that figure afterwards. Closing
    # unregisters it from pyplot; the Figure object is still returned.
    fig, ax = plt.subplots()
    x = t.numpy()
    ax.plot(x, gt.numpy(), label="Ground Truth")
    ax.plot(x, pred.numpy(), ":", label="Pred Signals")
    ax.axhline(0, c="k", ls=":")
    ax.legend()
    plt.close(fig)

    return fig, result_strings
|
|
|