"""Interactive demo comparing a ground-truth signal with a predicted signal.

``demo_fn`` builds both signals from a small menu ("Sin", "Cos", "Random"),
applies a per-signal scale and shift, scores the pair with the metrics from
``.metrics``, and returns a matplotlib figure plus a summary string.
"""
import math

import matplotlib.pyplot as plt
import torch

from .metrics import *  # noqa: F401,F403 -- provides MSE, MAE, SignalDice

# Shared metric instance; it is invoked like a function in ``calculate_fn``.
# NOTE(review): presumably the Signal Dice Similarity Coefficient -- confirm
# against ``.metrics.SignalDice``.
SDSC = SignalDice()


def calculate_fn(sig1, sig2):
    """Score ``sig1`` against ``sig2`` and return a one-line summary.

    Args:
        sig1: First signal (the demo passes the ground truth here).
        sig2: Second signal (the demo passes the prediction here).

    Returns:
        A string of the form ``"MSE: x.xxxx, MAE: x.xxxx, SDSC: x.xxxx"``.
    """
    mse = MSE(sig1, sig2)
    mae = MAE(sig1, sig2)
    sdsc = SDSC(sig1, sig2)
    return f"MSE: {mse:.4f}, MAE: {mae:.4f}, SDSC: {sdsc:.4f}"


def demo_fn(
    gt_radio: str = 'Sin',
    pred_radio: str = 'Sin',
    gt_scale: float = 1.0,
    gt_shift: float = 0.0,
    pred_scale: float = 1.0,
    pred_shift: float = 0.0,
):
    """Build, score and plot a ground-truth / prediction signal pair.

    Both signals are sampled at 1000 evenly spaced points over ``[0, 2*pi]``.

    Args:
        gt_radio: Name of the ground-truth base signal: "Sin", "Cos" or "Random".
        pred_radio: Name of the predicted base signal (same choices).
        gt_scale: Multiplier applied to the ground-truth base signal.
        gt_shift: Offset added to the ground-truth signal after scaling.
        pred_scale: Multiplier applied to the predicted base signal.
        pred_shift: Offset added to the predicted signal after scaling.

    Returns:
        ``(fig, result_strings)``: the (already closed) matplotlib figure that
        overlays both signals, and the metric summary from ``calculate_fn``.

    Raises:
        KeyError: If ``gt_radio`` or ``pred_radio`` is not a known signal name.
    """
    # math.pi instead of np.pi: numpy is not imported in this module.
    t = torch.linspace(0, 2 * math.pi, 1000)
    signals = {
        "Sin": torch.sin(t),
        "Cos": torch.cos(t),
        # Drawn once per call, so selecting "Random" for both sides compares
        # the same noise sample against itself (up to scale/shift).
        "Random": torch.randn_like(t),
    }
    for name in (gt_radio, pred_radio):
        if name not in signals:
            raise KeyError(
                f"Unknown signal {name!r}; expected one of {sorted(signals)}"
            )

    gt = signals[gt_radio] * gt_scale + gt_shift
    pred = signals[pred_radio] * pred_scale + pred_shift

    result_strings = calculate_fn(gt, pred)

    # Draw on an explicit Figure/Axes rather than pyplot's implicit state.
    fig, ax = plt.subplots()
    ax.plot(t.numpy(), gt.numpy(), label="Ground Truth")
    ax.plot(t.numpy(), pred.numpy(), ":", label="Pred Signals")
    ax.axhline(0, c="k", ls=":")
    ax.legend()
    # Release pyplot's reference so repeated calls don't accumulate open
    # figures; the returned Figure object remains usable by the caller.
    plt.close(fig)

    return fig, result_strings