Upload 5 files
Browse files- app.py +36 -0
- libs/__init__.py +2 -0
- libs/functions.py +44 -0
- libs/metrics.py +45 -0
app.py
ADDED
|
@@ -0,0 +1,36 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import numpy as np
|
| 2 |
+
import torch
|
| 3 |
+
import gradio as gr
|
| 4 |
+
|
| 5 |
+
from libs import *
|
| 6 |
+
|
| 7 |
+
with gr.Blocks(title="Signal Similarity Live Demo") as demo:
|
| 8 |
+
|
| 9 |
+
default_fig, default_result = demo_fn()
|
| 10 |
+
|
| 11 |
+
with gr.Row():
|
| 12 |
+
signal_images = gr.Plot(default_fig,label="Visualize Signals")
|
| 13 |
+
with gr.Row():
|
| 14 |
+
singal_results = gr.Text(default_result, interactive=False)
|
| 15 |
+
|
| 16 |
+
with gr.Row():
|
| 17 |
+
with gr.Column():
|
| 18 |
+
gr.Markdown("### Ground Truth Control")
|
| 19 |
+
gt_button = gr.Radio(["Sin", "Cos", "Random"], label="Signal Type", value="Sin")
|
| 20 |
+
gt_scale_slider = gr.Slider(minimum=-2, maximum=2, step=0.1, label="Scaling Factor", value=1)
|
| 21 |
+
gt_shift_slider = gr.Slider(minimum=-2, maximum=2, step=0.1, label="Shifting Factor", value=0)
|
| 22 |
+
with gr.Column():
|
| 23 |
+
gr.Markdown("### Pred Signal Control")
|
| 24 |
+
signal_button = gr.Radio(["Sin", "Cos", "Random"], label="Signal Type", value="Sin")
|
| 25 |
+
pred_scale_slider = gr.Slider(minimum=-2, maximum=2, step=0.1, label="Scaling Factor", value=1)
|
| 26 |
+
pred_shift_slider = gr.Slider(minimum=-2, maximum=2, step=0.1, label="Shifting Factor", value=0)
|
| 27 |
+
|
| 28 |
+
gt_button.change(fn=demo_fn, inputs=[gt_button, signal_button, gt_scale_slider, gt_shift_slider, pred_scale_slider, pred_shift_slider], outputs=[signal_images, singal_results])
|
| 29 |
+
signal_button.change(fn=demo_fn, inputs=[gt_button, signal_button, gt_scale_slider, gt_shift_slider, pred_scale_slider, pred_shift_slider], outputs=[signal_images, singal_results])
|
| 30 |
+
gt_scale_slider.change(fn=demo_fn, inputs=[gt_button, signal_button, gt_scale_slider, gt_shift_slider, pred_scale_slider, pred_shift_slider], outputs=[signal_images, singal_results])
|
| 31 |
+
gt_shift_slider.change(fn=demo_fn, inputs=[gt_button, signal_button, gt_scale_slider, gt_shift_slider, pred_scale_slider, pred_shift_slider], outputs=[signal_images, singal_results])
|
| 32 |
+
pred_scale_slider.change(fn=demo_fn, inputs=[gt_button, signal_button, gt_scale_slider, gt_shift_slider, pred_scale_slider, pred_shift_slider], outputs=[signal_images, singal_results])
|
| 33 |
+
pred_shift_slider.change(fn=demo_fn, inputs=[gt_button, signal_button, gt_scale_slider, gt_shift_slider, pred_scale_slider, pred_shift_slider], outputs=[signal_images, singal_results])
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
demo.launch(share=True)
|
libs/__init__.py
ADDED
|
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from .functions import *
|
| 2 |
+
from .metrics import *
|
libs/functions.py
ADDED
|
@@ -0,0 +1,44 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import matplotlib.pyplot as plt
|
| 3 |
+
|
| 4 |
+
from .metrics import *
|
| 5 |
+
|
| 6 |
+
SDSC = SignalDice()
|
| 7 |
+
|
| 8 |
+
def calculate_fn(sig1, sig2):
|
| 9 |
+
mse = MSE(sig1, sig2)
|
| 10 |
+
mae = MAE(sig1, sig2)
|
| 11 |
+
sdsc = SDSC(sig1, sig2)
|
| 12 |
+
|
| 13 |
+
return f"MSE: {mse:.4f}, MAE: {mae:.4f}, SDSC: {sdsc:.4f}"
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
def demo_fn( gt_radio:str='Sin',
|
| 17 |
+
pred_radio:str='Sin',
|
| 18 |
+
gt_scale:float=1.0,
|
| 19 |
+
gt_shift:float=0.0,
|
| 20 |
+
pred_scale:float=1.0,
|
| 21 |
+
pred_shift:float=0.0):
|
| 22 |
+
t = torch.linspace(0,2*np.pi, 1000)
|
| 23 |
+
|
| 24 |
+
signals = {
|
| 25 |
+
"Sin" : torch.sin(t),
|
| 26 |
+
"Cos" : torch.cos(t),
|
| 27 |
+
"Random": torch.randn_like(t)
|
| 28 |
+
}
|
| 29 |
+
|
| 30 |
+
gt = signals[gt_radio] * gt_scale + gt_shift
|
| 31 |
+
pred = signals[pred_radio] * pred_scale + pred_shift
|
| 32 |
+
|
| 33 |
+
result_strings = calculate_fn(gt, pred)
|
| 34 |
+
|
| 35 |
+
# Draw Image
|
| 36 |
+
fig = plt.figure()
|
| 37 |
+
plt.plot(t, gt.numpy(), label="Ground Truth")
|
| 38 |
+
plt.plot(t, pred.numpy(), ":", label="Pred Signals")
|
| 39 |
+
plt.axhline(0, c="k", ls=":")
|
| 40 |
+
plt.legend()
|
| 41 |
+
plt.close()
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
return fig, result_strings
|
libs/metrics.py
ADDED
|
@@ -0,0 +1,45 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import numpy as np
|
| 3 |
+
import torch
|
| 4 |
+
import torch.nn as nn
|
| 5 |
+
|
| 6 |
+
def MSE(gt:torch.Tensor, pred:torch.Tensor) -> np.float32:
|
| 7 |
+
r"""
|
| 8 |
+
Simple MSE Function
|
| 9 |
+
"""
|
| 10 |
+
return torch.mean(torch.square(gt - pred)).item()
|
| 11 |
+
|
| 12 |
+
def MAE(gt:np.array, pred:np.array) -> np.float32:
|
| 13 |
+
r"""
|
| 14 |
+
Simple MAE Function
|
| 15 |
+
"""
|
| 16 |
+
return torch.mean(torch.abs(gt - pred)).item()
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
class SignalDice(nn.Module):
|
| 20 |
+
def __init__(self, eps=1e-6):
|
| 21 |
+
super(SignalDice,self).__init__()
|
| 22 |
+
self.eps = eps
|
| 23 |
+
|
| 24 |
+
def calc_inter(self, a, b, same_sign_mat):
|
| 25 |
+
a = a * same_sign_mat
|
| 26 |
+
b = b * same_sign_mat
|
| 27 |
+
return torch.where(a >= b, b, a)
|
| 28 |
+
|
| 29 |
+
def calc_union(self, a, b):
|
| 30 |
+
return a + b
|
| 31 |
+
|
| 32 |
+
def forward(self, inputs:torch.Tensor, targets:torch.Tensor)->np.float32:
|
| 33 |
+
# Make abs value
|
| 34 |
+
in_abs = torch.abs(inputs)
|
| 35 |
+
tar_abs = torch.abs(targets)
|
| 36 |
+
|
| 37 |
+
# Make Heaviside Matrix
|
| 38 |
+
with torch.no_grad():
|
| 39 |
+
same_sign_mat = torch.heaviside(inputs * targets, torch.tensor([0.]))
|
| 40 |
+
|
| 41 |
+
self.intersection = self.calc_inter(in_abs, tar_abs, same_sign_mat)
|
| 42 |
+
self.union = self.calc_union(in_abs, tar_abs)
|
| 43 |
+
|
| 44 |
+
return torch.mean((2 * torch.sum(self.intersection) + self.eps) / (torch.sum(self.union) + self.eps)).item()
|
| 45 |
+
|