# SDSC_DEMO / libs/functions.py
# Uploaded by IgnoreLee ("Upload 5 files", commit 3cd9d06, verified).
import math

import matplotlib.pyplot as plt
import torch

from .metrics import *
# Single SignalDice metric instance, built once at import time and reused by
# calculate_fn as SDSC(sig1, sig2). SignalDice comes from the `.metrics`
# star import above.
SDSC = SignalDice()
def calculate_fn(sig1, sig2):
    """Compare two signals with MSE, MAE and SDSC and format the scores.

    Args:
        sig1: First signal (demo_fn passes the ground truth here).
        sig2: Second signal (demo_fn passes the prediction here).

    Returns:
        A string of the form "MSE: x, MAE: y, SDSC: z", each score rendered
        with four decimal places. Each metric result must therefore support
        the ``.4f`` format spec.
    """
    # The metrics are evaluated in a fixed order: MSE, then MAE, then SDSC.
    scores = (
        ("MSE", MSE(sig1, sig2)),
        ("MAE", MAE(sig1, sig2)),
        ("SDSC", SDSC(sig1, sig2)),
    )
    return ", ".join(f"{label}: {value:.4f}" for label, value in scores)
def demo_fn(gt_radio: str = 'Sin',
            pred_radio: str = 'Sin',
            gt_scale: float = 1.0,
            gt_shift: float = 0.0,
            pred_scale: float = 1.0,
            pred_shift: float = 0.0,
            num_points: int = 1000):
    """Build a ground-truth/prediction signal pair, score it, and plot it.

    Args:
        gt_radio: Base waveform for the ground truth: "Sin", "Cos" or "Random".
        pred_radio: Base waveform for the prediction (same choices).
        gt_scale: Multiplier applied to the ground-truth waveform.
        gt_shift: Constant added to the ground-truth waveform after scaling.
        pred_scale: Multiplier applied to the prediction waveform.
        pred_shift: Constant added to the prediction waveform after scaling.
        num_points: Number of samples taken evenly on [0, 2*pi] (default 1000).

    Returns:
        ``(fig, result_strings)``. ``fig`` is a matplotlib Figure showing both
        signals; pyplot has already closed it, but it can still be rendered.
        ``result_strings`` is the metrics summary from ``calculate_fn(gt, pred)``.

    Raises:
        KeyError: If ``gt_radio`` or ``pred_radio`` is not a known waveform name.
    """
    t = torch.linspace(0, 2 * math.pi, num_points)
    # Build every candidate once, up front. If "Random" is chosen for both
    # inputs, they share the exact same noise sample.
    signals = {
        "Sin": torch.sin(t),
        "Cos": torch.cos(t),
        "Random": torch.randn_like(t),
    }
    for choice in (gt_radio, pred_radio):
        if choice not in signals:
            raise KeyError(
                f"Unknown signal {choice!r}; expected one of {sorted(signals)}"
            )

    gt = signals[gt_radio] * gt_scale + gt_shift
    pred = signals[pred_radio] * pred_scale + pred_shift
    result_strings = calculate_fn(gt, pred)

    # Draw on an explicit figure/axes pair rather than pyplot's implicit
    # "current" figure, so the plot cannot land on some other active figure.
    fig, ax = plt.subplots()
    t_np = t.numpy()
    ax.plot(t_np, gt.numpy(), label="Ground Truth")
    ax.plot(t_np, pred.numpy(), ":", label="Pred Signals")
    ax.axhline(0, c="k", ls=":")
    ax.legend()
    # Release pyplot's reference to this specific figure. The Figure object
    # is returned and can still be rendered by the caller.
    plt.close(fig)
    return fig, result_strings