IgnoreLee commited on
Commit
3cd9d06
·
verified ·
1 Parent(s): 63b5acf

Upload 5 files

Browse files
Files changed (4) hide show
  1. app.py +36 -0
  2. libs/__init__.py +2 -0
  3. libs/functions.py +44 -0
  4. libs/metrics.py +45 -0
app.py ADDED
@@ -0,0 +1,36 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import torch
3
+ import gradio as gr
4
+
5
+ from libs import *
6
+
7
+ with gr.Blocks(title="Signal Similarity Live Demo") as demo:
8
+
9
+ default_fig, default_result = demo_fn()
10
+
11
+ with gr.Row():
12
+ signal_images = gr.Plot(default_fig,label="Visualize Signals")
13
+ with gr.Row():
14
+ singal_results = gr.Text(default_result, interactive=False)
15
+
16
+ with gr.Row():
17
+ with gr.Column():
18
+ gr.Markdown("### Ground Truth Control")
19
+ gt_button = gr.Radio(["Sin", "Cos", "Random"], label="Signal Type", value="Sin")
20
+ gt_scale_slider = gr.Slider(minimum=-2, maximum=2, step=0.1, label="Scaling Factor", value=1)
21
+ gt_shift_slider = gr.Slider(minimum=-2, maximum=2, step=0.1, label="Shifting Factor", value=0)
22
+ with gr.Column():
23
+ gr.Markdown("### Pred Signal Control")
24
+ signal_button = gr.Radio(["Sin", "Cos", "Random"], label="Signal Type", value="Sin")
25
+ pred_scale_slider = gr.Slider(minimum=-2, maximum=2, step=0.1, label="Scaling Factor", value=1)
26
+ pred_shift_slider = gr.Slider(minimum=-2, maximum=2, step=0.1, label="Shifting Factor", value=0)
27
+
28
+ gt_button.change(fn=demo_fn, inputs=[gt_button, signal_button, gt_scale_slider, gt_shift_slider, pred_scale_slider, pred_shift_slider], outputs=[signal_images, singal_results])
29
+ signal_button.change(fn=demo_fn, inputs=[gt_button, signal_button, gt_scale_slider, gt_shift_slider, pred_scale_slider, pred_shift_slider], outputs=[signal_images, singal_results])
30
+ gt_scale_slider.change(fn=demo_fn, inputs=[gt_button, signal_button, gt_scale_slider, gt_shift_slider, pred_scale_slider, pred_shift_slider], outputs=[signal_images, singal_results])
31
+ gt_shift_slider.change(fn=demo_fn, inputs=[gt_button, signal_button, gt_scale_slider, gt_shift_slider, pred_scale_slider, pred_shift_slider], outputs=[signal_images, singal_results])
32
+ pred_scale_slider.change(fn=demo_fn, inputs=[gt_button, signal_button, gt_scale_slider, gt_shift_slider, pred_scale_slider, pred_shift_slider], outputs=[signal_images, singal_results])
33
+ pred_shift_slider.change(fn=demo_fn, inputs=[gt_button, signal_button, gt_scale_slider, gt_shift_slider, pred_scale_slider, pred_shift_slider], outputs=[signal_images, singal_results])
34
+
35
+
36
+ demo.launch(share=True)
libs/__init__.py ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ from .functions import *
2
+ from .metrics import *
libs/functions.py ADDED
@@ -0,0 +1,44 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import matplotlib.pyplot as plt
3
+
4
+ from .metrics import *
5
+
6
+ SDSC = SignalDice()
7
+
8
+ def calculate_fn(sig1, sig2):
9
+ mse = MSE(sig1, sig2)
10
+ mae = MAE(sig1, sig2)
11
+ sdsc = SDSC(sig1, sig2)
12
+
13
+ return f"MSE: {mse:.4f}, MAE: {mae:.4f}, SDSC: {sdsc:.4f}"
14
+
15
+
16
+ def demo_fn( gt_radio:str='Sin',
17
+ pred_radio:str='Sin',
18
+ gt_scale:float=1.0,
19
+ gt_shift:float=0.0,
20
+ pred_scale:float=1.0,
21
+ pred_shift:float=0.0):
22
+ t = torch.linspace(0,2*np.pi, 1000)
23
+
24
+ signals = {
25
+ "Sin" : torch.sin(t),
26
+ "Cos" : torch.cos(t),
27
+ "Random": torch.randn_like(t)
28
+ }
29
+
30
+ gt = signals[gt_radio] * gt_scale + gt_shift
31
+ pred = signals[pred_radio] * pred_scale + pred_shift
32
+
33
+ result_strings = calculate_fn(gt, pred)
34
+
35
+ # Draw Image
36
+ fig = plt.figure()
37
+ plt.plot(t, gt.numpy(), label="Ground Truth")
38
+ plt.plot(t, pred.numpy(), ":", label="Pred Signals")
39
+ plt.axhline(0, c="k", ls=":")
40
+ plt.legend()
41
+ plt.close()
42
+
43
+
44
+ return fig, result_strings
libs/metrics.py ADDED
@@ -0,0 +1,45 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import numpy as np
3
+ import torch
4
+ import torch.nn as nn
5
+
6
+ def MSE(gt:torch.Tensor, pred:torch.Tensor) -> np.float32:
7
+ r"""
8
+ Simple MSE Function
9
+ """
10
+ return torch.mean(torch.square(gt - pred)).item()
11
+
12
+ def MAE(gt:np.array, pred:np.array) -> np.float32:
13
+ r"""
14
+ Simple MAE Function
15
+ """
16
+ return torch.mean(torch.abs(gt - pred)).item()
17
+
18
+
19
+ class SignalDice(nn.Module):
20
+ def __init__(self, eps=1e-6):
21
+ super(SignalDice,self).__init__()
22
+ self.eps = eps
23
+
24
+ def calc_inter(self, a, b, same_sign_mat):
25
+ a = a * same_sign_mat
26
+ b = b * same_sign_mat
27
+ return torch.where(a >= b, b, a)
28
+
29
+ def calc_union(self, a, b):
30
+ return a + b
31
+
32
+ def forward(self, inputs:torch.Tensor, targets:torch.Tensor)->np.float32:
33
+ # Make abs value
34
+ in_abs = torch.abs(inputs)
35
+ tar_abs = torch.abs(targets)
36
+
37
+ # Make Heaviside Matrix
38
+ with torch.no_grad():
39
+ same_sign_mat = torch.heaviside(inputs * targets, torch.tensor([0.]))
40
+
41
+ self.intersection = self.calc_inter(in_abs, tar_abs, same_sign_mat)
42
+ self.union = self.calc_union(in_abs, tar_abs)
43
+
44
+ return torch.mean((2 * torch.sum(self.intersection) + self.eps) / (torch.sum(self.union) + self.eps)).item()
45
+