File size: 1,046 Bytes
3cd9d06
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
import torch
import matplotlib.pyplot as plt

from .metrics import *

# Module-level SignalDice instance; calculate_fn invokes it as SDSC(sig1, sig2).
# NOTE(review): SignalDice is presumably provided by `from .metrics import *` — confirm.
SDSC = SignalDice()

def calculate_fn(sig1, sig2):
    """Score two signals and return the results as one summary line.

    Calls ``MSE``, ``MAE`` and the module-level ``SDSC`` instance, in that
    order, each with ``(sig1, sig2)``.

    Args:
        sig1: First signal, passed unchanged as the first argument of
            every metric.
        sig2: Second signal, passed unchanged as the second argument of
            every metric.

    Returns:
        A string of the form ``"MSE: <v>, MAE: <v>, SDSC: <v>"`` where each
        value is rendered with four decimal places.
    """
    # Unpacking the generator evaluates the metrics one by one, in order.
    mse, mae, sdsc = (metric(sig1, sig2) for metric in (MSE, MAE, SDSC))
    return f"MSE: {mse:.4f}, MAE: {mae:.4f}, SDSC: {sdsc:.4f}"


def demo_fn(
    gt_radio: str = "Sin",
    pred_radio: str = "Sin",
    gt_scale: float = 1.0,
    gt_shift: float = 0.0,
    pred_scale: float = 1.0,
    pred_shift: float = 0.0,
):
    """Build two synthetic signals, score them, and plot them together.

    A 1000-point time axis over ``[0, 2*pi]`` is used to generate three
    base signals ("Sin", "Cos", "Random"). The ground-truth and predicted
    signals are each one of those base signals, multiplied by a scale and
    offset by a shift.

    Args:
        gt_radio: Key of the base signal used as ground truth; one of
            ``"Sin"``, ``"Cos"``, ``"Random"``.
        pred_radio: Key of the base signal used as prediction; same choices.
        gt_scale: Multiplier applied to the ground-truth base signal.
        gt_shift: Offset added to the scaled ground-truth signal.
        pred_scale: Multiplier applied to the prediction base signal.
        pred_shift: Offset added to the scaled prediction signal.

    Returns:
        A ``(fig, result_strings)`` tuple: the matplotlib figure containing
        both curves, and the summary string produced by ``calculate_fn``.

    Raises:
        KeyError: If ``gt_radio`` or ``pred_radio`` is not one of the
            available base-signal keys.
    """
    # Local import: the original relied on a name `np` that this file never
    # imports itself; math.pi is the same float value with no hidden dependency.
    import math

    t = torch.linspace(0, 2 * math.pi, 1000)

    # All three base signals are built once, so choosing "Random" for both
    # inputs yields the very same noise realisation for gt and pred.
    signals = {
        "Sin": torch.sin(t),
        "Cos": torch.cos(t),
        "Random": torch.randn_like(t),
    }

    gt = signals[gt_radio] * gt_scale + gt_shift
    pred = signals[pred_radio] * pred_scale + pred_shift

    result_strings = calculate_fn(gt, pred)

    # Draw on an explicit Figure/Axes pair instead of pyplot's implicit
    # "current figure", so the drawing and the close below always target
    # this figure and never one created by another caller.
    fig, ax = plt.subplots()
    time_axis = t.numpy()
    ax.plot(time_axis, gt.numpy(), label="Ground Truth")
    ax.plot(time_axis, pred.numpy(), ":", label="Pred Signals")
    ax.axhline(0, c="k", ls=":")
    ax.legend()
    # Unregister the figure from pyplot so repeated calls do not accumulate
    # open figures; the returned Figure object remains usable by the caller.
    plt.close(fig)

    return fig, result_strings