Spaces:
Runtime error
Runtime error
Now original waveform is displayed
Browse files- Gradio_app.ipynb +16 -20
- app.py +7 -6
Gradio_app.ipynb
CHANGED
|
@@ -2,14 +2,14 @@
|
|
| 2 |
"cells": [
|
| 3 |
{
|
| 4 |
"cell_type": "code",
|
| 5 |
-
"execution_count":
|
| 6 |
"metadata": {},
|
| 7 |
"outputs": [
|
| 8 |
{
|
| 9 |
"name": "stdout",
|
| 10 |
"output_type": "stream",
|
| 11 |
"text": [
|
| 12 |
-
"Running on local URL: http://127.0.0.1:
|
| 13 |
"\n",
|
| 14 |
"To create a public link, set `share=True` in `launch()`.\n"
|
| 15 |
]
|
|
@@ -17,7 +17,7 @@
|
|
| 17 |
{
|
| 18 |
"data": {
|
| 19 |
"text/html": [
|
| 20 |
-
"<div><iframe src=\"http://127.0.0.1:
|
| 21 |
],
|
| 22 |
"text/plain": [
|
| 23 |
"<IPython.core.display.HTML object>"
|
|
@@ -30,22 +30,16 @@
|
|
| 30 |
"data": {
|
| 31 |
"text/plain": []
|
| 32 |
},
|
| 33 |
-
"execution_count":
|
| 34 |
"metadata": {},
|
| 35 |
"output_type": "execute_result"
|
| 36 |
},
|
| 37 |
-
{
|
| 38 |
-
"name": "stderr",
|
| 39 |
-
"output_type": "stream",
|
| 40 |
-
"text": [
|
| 41 |
-
"No artists with labels found to put in legend. Note that artists whose label start with an underscore are ignored when legend() is called with no argument.\n"
|
| 42 |
-
]
|
| 43 |
-
},
|
| 44 |
{
|
| 45 |
"name": "stdout",
|
| 46 |
"output_type": "stream",
|
| 47 |
"text": [
|
| 48 |
-
"
|
|
|
|
| 49 |
]
|
| 50 |
}
|
| 51 |
],
|
|
@@ -83,8 +77,9 @@
|
|
| 83 |
" if len(waveform.shape) == 1:\n",
|
| 84 |
" waveform = waveform.reshape(1, waveform.shape[0])\n",
|
| 85 |
"\n",
|
|
|
|
| 86 |
" processed_input = prepare_waveform(waveform)\n",
|
| 87 |
-
"
|
| 88 |
" # Make prediction\n",
|
| 89 |
" with torch.inference_mode():\n",
|
| 90 |
" output = model(processed_input)\n",
|
|
@@ -92,33 +87,34 @@
|
|
| 92 |
" p_phase = output[:, 0]\n",
|
| 93 |
" s_phase = output[:, 1]\n",
|
| 94 |
"\n",
|
| 95 |
-
" return processed_input, p_phase, s_phase\n",
|
|
|
|
| 96 |
"\n",
|
| 97 |
"def mark_phases(waveform, uploaded_file, p_thres, s_thres):\n",
|
| 98 |
"\n",
|
| 99 |
" if uploaded_file is not None:\n",
|
| 100 |
" waveform = uploaded_file.name\n",
|
| 101 |
"\n",
|
| 102 |
-
" processed_input, p_phase, s_phase = make_prediction(waveform)\n",
|
| 103 |
"\n",
|
| 104 |
" # Create a plot of the waveform with the phases marked\n",
|
| 105 |
" if sum(processed_input[0][2] == 0): #if input is 1C\n",
|
| 106 |
" fig, ax = plt.subplots(nrows=2, figsize=(10, 2), sharex=True)\n",
|
| 107 |
"\n",
|
| 108 |
-
" ax[0].plot(
|
| 109 |
" ax[0].set_ylabel('Norm. Ampl.')\n",
|
| 110 |
"\n",
|
| 111 |
" else: #if input is 3C\n",
|
| 112 |
" fig, ax = plt.subplots(nrows=4, figsize=(10, 6), sharex=True)\n",
|
| 113 |
-
" ax[0].plot(
|
| 114 |
-
" ax[1].plot(
|
| 115 |
-
" ax[2].plot(
|
| 116 |
"\n",
|
| 117 |
" ax[0].set_ylabel('Z')\n",
|
| 118 |
" ax[1].set_ylabel('N')\n",
|
| 119 |
" ax[2].set_ylabel('E')\n",
|
| 120 |
"\n",
|
| 121 |
-
"
|
| 122 |
" do_we_have_p = (p_phase.std().item()*60 < p_thres)\n",
|
| 123 |
" if do_we_have_p:\n",
|
| 124 |
" p_phase_plot = p_phase*processed_input.shape[-1]\n",
|
|
|
|
| 2 |
"cells": [
|
| 3 |
{
|
| 4 |
"cell_type": "code",
|
| 5 |
+
"execution_count": 16,
|
| 6 |
"metadata": {},
|
| 7 |
"outputs": [
|
| 8 |
{
|
| 9 |
"name": "stdout",
|
| 10 |
"output_type": "stream",
|
| 11 |
"text": [
|
| 12 |
+
"Running on local URL: http://127.0.0.1:7869\n",
|
| 13 |
"\n",
|
| 14 |
"To create a public link, set `share=True` in `launch()`.\n"
|
| 15 |
]
|
|
|
|
| 17 |
{
|
| 18 |
"data": {
|
| 19 |
"text/html": [
|
| 20 |
+
"<div><iframe src=\"http://127.0.0.1:7869/\" width=\"100%\" height=\"500\" allow=\"autoplay; camera; microphone; clipboard-read; clipboard-write;\" frameborder=\"0\" allowfullscreen></iframe></div>"
|
| 21 |
],
|
| 22 |
"text/plain": [
|
| 23 |
"<IPython.core.display.HTML object>"
|
|
|
|
| 30 |
"data": {
|
| 31 |
"text/plain": []
|
| 32 |
},
|
| 33 |
+
"execution_count": 16,
|
| 34 |
"metadata": {},
|
| 35 |
"output_type": "execute_result"
|
| 36 |
},
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 37 |
{
|
| 38 |
"name": "stdout",
|
| 39 |
"output_type": "stream",
|
| 40 |
"text": [
|
| 41 |
+
"4\n",
|
| 42 |
+
"0.02744414610788226\n"
|
| 43 |
]
|
| 44 |
}
|
| 45 |
],
|
|
|
|
| 77 |
" if len(waveform.shape) == 1:\n",
|
| 78 |
" waveform = waveform.reshape(1, waveform.shape[0])\n",
|
| 79 |
"\n",
|
| 80 |
+
" orig_waveform = waveform[:, :6000].copy()\n",
|
| 81 |
" processed_input = prepare_waveform(waveform)\n",
|
| 82 |
+
"\n",
|
| 83 |
" # Make prediction\n",
|
| 84 |
" with torch.inference_mode():\n",
|
| 85 |
" output = model(processed_input)\n",
|
|
|
|
| 87 |
" p_phase = output[:, 0]\n",
|
| 88 |
" s_phase = output[:, 1]\n",
|
| 89 |
"\n",
|
| 90 |
+
" return processed_input, p_phase, s_phase, orig_waveform\n",
|
| 91 |
+
"\n",
|
| 92 |
"\n",
|
| 93 |
"def mark_phases(waveform, uploaded_file, p_thres, s_thres):\n",
|
| 94 |
"\n",
|
| 95 |
" if uploaded_file is not None:\n",
|
| 96 |
" waveform = uploaded_file.name\n",
|
| 97 |
"\n",
|
| 98 |
+
" processed_input, p_phase, s_phase, orig_waveform = make_prediction(waveform)\n",
|
| 99 |
"\n",
|
| 100 |
" # Create a plot of the waveform with the phases marked\n",
|
| 101 |
" if sum(processed_input[0][2] == 0): #if input is 1C\n",
|
| 102 |
" fig, ax = plt.subplots(nrows=2, figsize=(10, 2), sharex=True)\n",
|
| 103 |
"\n",
|
| 104 |
+
" ax[0].plot(orig_waveform[0], color='black', lw=1)\n",
|
| 105 |
" ax[0].set_ylabel('Norm. Ampl.')\n",
|
| 106 |
"\n",
|
| 107 |
" else: #if input is 3C\n",
|
| 108 |
" fig, ax = plt.subplots(nrows=4, figsize=(10, 6), sharex=True)\n",
|
| 109 |
+
" ax[0].plot(orig_waveform[0], color='black', lw=1)\n",
|
| 110 |
+
" ax[1].plot(orig_waveform[1], color='black', lw=1)\n",
|
| 111 |
+
" ax[2].plot(orig_waveform[2], color='black', lw=1)\n",
|
| 112 |
"\n",
|
| 113 |
" ax[0].set_ylabel('Z')\n",
|
| 114 |
" ax[1].set_ylabel('N')\n",
|
| 115 |
" ax[2].set_ylabel('E')\n",
|
| 116 |
"\n",
|
| 117 |
+
"\n",
|
| 118 |
" do_we_have_p = (p_phase.std().item()*60 < p_thres)\n",
|
| 119 |
" if do_we_have_p:\n",
|
| 120 |
" p_phase_plot = p_phase*processed_input.shape[-1]\n",
|
app.py
CHANGED
|
@@ -36,6 +36,7 @@ def make_prediction(waveform):
|
|
| 36 |
if len(waveform.shape) == 1:
|
| 37 |
waveform = waveform.reshape(1, waveform.shape[0])
|
| 38 |
|
|
|
|
| 39 |
processed_input = prepare_waveform(waveform)
|
| 40 |
|
| 41 |
# Make prediction
|
|
@@ -45,7 +46,7 @@ def make_prediction(waveform):
|
|
| 45 |
p_phase = output[:, 0]
|
| 46 |
s_phase = output[:, 1]
|
| 47 |
|
| 48 |
-
return processed_input, p_phase, s_phase
|
| 49 |
|
| 50 |
|
| 51 |
def mark_phases(waveform, uploaded_file, p_thres, s_thres):
|
|
@@ -53,20 +54,20 @@ def mark_phases(waveform, uploaded_file, p_thres, s_thres):
|
|
| 53 |
if uploaded_file is not None:
|
| 54 |
waveform = uploaded_file.name
|
| 55 |
|
| 56 |
-
processed_input, p_phase, s_phase = make_prediction(waveform)
|
| 57 |
|
| 58 |
# Create a plot of the waveform with the phases marked
|
| 59 |
if sum(processed_input[0][2] == 0): # if input is 1C
|
| 60 |
fig, ax = plt.subplots(nrows=2, figsize=(10, 2), sharex=True)
|
| 61 |
|
| 62 |
-
ax[0].plot(
|
| 63 |
ax[0].set_ylabel("Norm. Ampl.")
|
| 64 |
|
| 65 |
else: # if input is 3C
|
| 66 |
fig, ax = plt.subplots(nrows=4, figsize=(10, 6), sharex=True)
|
| 67 |
-
ax[0].plot(
|
| 68 |
-
ax[1].plot(
|
| 69 |
-
ax[2].plot(
|
| 70 |
|
| 71 |
ax[0].set_ylabel("Z")
|
| 72 |
ax[1].set_ylabel("N")
|
|
|
|
| 36 |
if len(waveform.shape) == 1:
|
| 37 |
waveform = waveform.reshape(1, waveform.shape[0])
|
| 38 |
|
| 39 |
+
orig_waveform = waveform[:, :6000].copy()
|
| 40 |
processed_input = prepare_waveform(waveform)
|
| 41 |
|
| 42 |
# Make prediction
|
|
|
|
| 46 |
p_phase = output[:, 0]
|
| 47 |
s_phase = output[:, 1]
|
| 48 |
|
| 49 |
+
return processed_input, p_phase, s_phase, orig_waveform
|
| 50 |
|
| 51 |
|
| 52 |
def mark_phases(waveform, uploaded_file, p_thres, s_thres):
|
|
|
|
| 54 |
if uploaded_file is not None:
|
| 55 |
waveform = uploaded_file.name
|
| 56 |
|
| 57 |
+
processed_input, p_phase, s_phase, orig_waveform = make_prediction(waveform)
|
| 58 |
|
| 59 |
# Create a plot of the waveform with the phases marked
|
| 60 |
if sum(processed_input[0][2] == 0): # if input is 1C
|
| 61 |
fig, ax = plt.subplots(nrows=2, figsize=(10, 2), sharex=True)
|
| 62 |
|
| 63 |
+
ax[0].plot(orig_waveform[0], color="black", lw=1)
|
| 64 |
ax[0].set_ylabel("Norm. Ampl.")
|
| 65 |
|
| 66 |
else: # if input is 3C
|
| 67 |
fig, ax = plt.subplots(nrows=4, figsize=(10, 6), sharex=True)
|
| 68 |
+
ax[0].plot(orig_waveform[0], color="black", lw=1)
|
| 69 |
+
ax[1].plot(orig_waveform[1], color="black", lw=1)
|
| 70 |
+
ax[2].plot(orig_waveform[2], color="black", lw=1)
|
| 71 |
|
| 72 |
ax[0].set_ylabel("Z")
|
| 73 |
ax[1].set_ylabel("N")
|