Spaces:
Build error
Build error
Merge pull request #5 from Portiloop/milo/filetypes_and_staging
Browse files- portiloop/src/demo/offline.py +29 -6
- portiloop/src/demo/phase_demo.py +63 -0
- portiloop/src/demo/test_offline.py +12 -7
- portiloop/src/demo/utils.py +49 -12
- portiloop/src/stimulation.py +54 -6
portiloop/src/demo/offline.py
CHANGED
|
@@ -1,13 +1,12 @@
|
|
| 1 |
-
import matplotlib.pyplot as plt
|
| 2 |
import numpy as np
|
| 3 |
from portiloop.src.detection import SleepSpindleRealTimeDetector
|
| 4 |
-
|
| 5 |
from portiloop.src.processing import FilterPipeline
|
| 6 |
-
from portiloop.src.demo.utils import compute_output_table, xdf2array, offline_detect, offline_filter, OfflineSleepSpindleRealTimeStimulator
|
| 7 |
import gradio as gr
|
| 8 |
|
| 9 |
|
| 10 |
-
def run_offline(xdf_file, detect_filter_opts, threshold, channel_num, freq):
|
| 11 |
# Get the options from the checkbox group
|
| 12 |
offline_filtering = 0 in detect_filter_opts
|
| 13 |
lacourse = 1 in detect_filter_opts
|
|
@@ -30,6 +29,7 @@ def run_offline(xdf_file, detect_filter_opts, threshold, channel_num, freq):
|
|
| 30 |
# Read the xdf file to a numpy array
|
| 31 |
print("Loading xdf file...")
|
| 32 |
data_whole, columns = xdf2array(xdf_file.name, int(channel_num))
|
|
|
|
| 33 |
# Do the offline filtering of the data
|
| 34 |
if offline_filtering:
|
| 35 |
print("Filtering offline...")
|
|
@@ -39,13 +39,18 @@ def run_offline(xdf_file, detect_filter_opts, threshold, channel_num, freq):
|
|
| 39 |
data_whole = np.concatenate((data_whole, offline_filtered_data), axis=1)
|
| 40 |
columns.append("offline_filtered_signal")
|
| 41 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 42 |
# Do Wamsley's method
|
| 43 |
if wamsley:
|
| 44 |
print("Running Wamsley detection...")
|
| 45 |
wamsley_data = offline_detect("Wamsley", \
|
| 46 |
data_whole[:, columns.index("offline_filtered_signal")],\
|
| 47 |
data_whole[:, columns.index("time_stamps")],\
|
| 48 |
-
freq)
|
| 49 |
wamsley_data = np.expand_dims(wamsley_data, axis=1)
|
| 50 |
data_whole = np.concatenate((data_whole, wamsley_data), axis=1)
|
| 51 |
columns.append("wamsley_spindles")
|
|
@@ -56,7 +61,7 @@ def run_offline(xdf_file, detect_filter_opts, threshold, channel_num, freq):
|
|
| 56 |
lacourse_data = offline_detect("Lacourse", \
|
| 57 |
data_whole[:, columns.index("offline_filtered_signal")],\
|
| 58 |
data_whole[:, columns.index("time_stamps")],\
|
| 59 |
-
freq)
|
| 60 |
lacourse_data = np.expand_dims(lacourse_data, axis=1)
|
| 61 |
data_whole = np.concatenate((data_whole, lacourse_data), axis=1)
|
| 62 |
columns.append("lacourse_spindles")
|
|
@@ -72,12 +77,17 @@ def run_offline(xdf_file, detect_filter_opts, threshold, channel_num, freq):
|
|
| 72 |
if online_detection:
|
| 73 |
detector = SleepSpindleRealTimeDetector(threshold=threshold, channel=1) # always 1 because we have only one channel
|
| 74 |
stimulator = OfflineSleepSpindleRealTimeStimulator()
|
|
|
|
|
|
|
|
|
|
|
|
|
| 75 |
|
| 76 |
if online_filtering or online_detection:
|
| 77 |
print("Running online filtering and detection...")
|
| 78 |
|
| 79 |
points = []
|
| 80 |
online_activations = []
|
|
|
|
| 81 |
|
| 82 |
# Go through the data
|
| 83 |
for index, point in enumerate(data):
|
|
@@ -93,6 +103,13 @@ def run_offline(xdf_file, detect_filter_opts, threshold, channel_num, freq):
|
|
| 93 |
# Detect the spindles
|
| 94 |
result = detector.detect([filtered_point])
|
| 95 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 96 |
# Stimulate if necessary
|
| 97 |
stim = stimulator.stimulate(result)
|
| 98 |
if stim:
|
|
@@ -112,6 +129,12 @@ def run_offline(xdf_file, detect_filter_opts, threshold, channel_num, freq):
|
|
| 112 |
data_whole = np.concatenate((data_whole, online_activations), axis=1)
|
| 113 |
columns.append("online_stimulations")
|
| 114 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 115 |
print("Saving output...")
|
| 116 |
# Output the data to a csv file
|
| 117 |
np.savetxt("output.csv", data_whole, delimiter=",", header=",".join(columns), comments="")
|
|
|
|
|
|
|
| 1 |
import numpy as np
|
| 2 |
from portiloop.src.detection import SleepSpindleRealTimeDetector
|
| 3 |
+
from portiloop.src.stimulation import UpStateDelayer
|
| 4 |
from portiloop.src.processing import FilterPipeline
|
| 5 |
+
from portiloop.src.demo.utils import compute_output_table, sleep_stage, xdf2array, offline_detect, offline_filter, OfflineSleepSpindleRealTimeStimulator
|
| 6 |
import gradio as gr
|
| 7 |
|
| 8 |
|
| 9 |
+
def run_offline(xdf_file, detect_filter_opts, threshold, channel_num, freq, stimulation_phase="Fast", buffer_time=0.25):
|
| 10 |
# Get the options from the checkbox group
|
| 11 |
offline_filtering = 0 in detect_filter_opts
|
| 12 |
lacourse = 1 in detect_filter_opts
|
|
|
|
| 29 |
# Read the xdf file to a numpy array
|
| 30 |
print("Loading xdf file...")
|
| 31 |
data_whole, columns = xdf2array(xdf_file.name, int(channel_num))
|
| 32 |
+
|
| 33 |
# Do the offline filtering of the data
|
| 34 |
if offline_filtering:
|
| 35 |
print("Filtering offline...")
|
|
|
|
| 39 |
data_whole = np.concatenate((data_whole, offline_filtered_data), axis=1)
|
| 40 |
columns.append("offline_filtered_signal")
|
| 41 |
|
| 42 |
+
# Do the sleep staging approximation
|
| 43 |
+
if wamsley or lacourse:
|
| 44 |
+
print("Sleep staging...")
|
| 45 |
+
mask = sleep_stage(data_whole[:, columns.index("offline_filtered_signal")], threshold=150, group_size=100)
|
| 46 |
+
|
| 47 |
# Do Wamsley's method
|
| 48 |
if wamsley:
|
| 49 |
print("Running Wamsley detection...")
|
| 50 |
wamsley_data = offline_detect("Wamsley", \
|
| 51 |
data_whole[:, columns.index("offline_filtered_signal")],\
|
| 52 |
data_whole[:, columns.index("time_stamps")],\
|
| 53 |
+
freq, mask)
|
| 54 |
wamsley_data = np.expand_dims(wamsley_data, axis=1)
|
| 55 |
data_whole = np.concatenate((data_whole, wamsley_data), axis=1)
|
| 56 |
columns.append("wamsley_spindles")
|
|
|
|
| 61 |
lacourse_data = offline_detect("Lacourse", \
|
| 62 |
data_whole[:, columns.index("offline_filtered_signal")],\
|
| 63 |
data_whole[:, columns.index("time_stamps")],\
|
| 64 |
+
freq, mask)
|
| 65 |
lacourse_data = np.expand_dims(lacourse_data, axis=1)
|
| 66 |
data_whole = np.concatenate((data_whole, lacourse_data), axis=1)
|
| 67 |
columns.append("lacourse_spindles")
|
|
|
|
| 77 |
if online_detection:
|
| 78 |
detector = SleepSpindleRealTimeDetector(threshold=threshold, channel=1) # always 1 because we have only one channel
|
| 79 |
stimulator = OfflineSleepSpindleRealTimeStimulator()
|
| 80 |
+
if stimulation_phase != "Fast":
|
| 81 |
+
stimulation_delayer = UpStateDelayer(freq, stimulation_phase == 'Peak', time_to_buffer=buffer_time, stimulate=lambda: None)
|
| 82 |
+
stimulator.add_delayer(stimulation_delayer)
|
| 83 |
+
|
| 84 |
|
| 85 |
if online_filtering or online_detection:
|
| 86 |
print("Running online filtering and detection...")
|
| 87 |
|
| 88 |
points = []
|
| 89 |
online_activations = []
|
| 90 |
+
delayed_stims = []
|
| 91 |
|
| 92 |
# Go through the data
|
| 93 |
for index, point in enumerate(data):
|
|
|
|
| 103 |
# Detect the spindles
|
| 104 |
result = detector.detect([filtered_point])
|
| 105 |
|
| 106 |
+
if stimulation_phase != "Fast":
|
| 107 |
+
delayed_stim = stimulation_delayer.step_timesteps(filtered_point[0])
|
| 108 |
+
if delayed_stim:
|
| 109 |
+
delayed_stims.append(1)
|
| 110 |
+
else:
|
| 111 |
+
delayed_stims.append(0)
|
| 112 |
+
|
| 113 |
# Stimulate if necessary
|
| 114 |
stim = stimulator.stimulate(result)
|
| 115 |
if stim:
|
|
|
|
| 129 |
data_whole = np.concatenate((data_whole, online_activations), axis=1)
|
| 130 |
columns.append("online_stimulations")
|
| 131 |
|
| 132 |
+
if stimulation_phase != "Fast":
|
| 133 |
+
delayed_stims = np.array(delayed_stims)
|
| 134 |
+
delayed_stims = np.expand_dims(delayed_stims, axis=1)
|
| 135 |
+
data_whole = np.concatenate((data_whole, delayed_stims), axis=1)
|
| 136 |
+
columns.append("delayed_stimulations")
|
| 137 |
+
|
| 138 |
print("Saving output...")
|
| 139 |
# Output the data to a csv file
|
| 140 |
np.savetxt("output.csv", data_whole, delimiter=",", header=",".join(columns), comments="")
|
portiloop/src/demo/phase_demo.py
ADDED
|
@@ -0,0 +1,63 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import gradio as gr
|
| 2 |
+
|
| 3 |
+
from portiloop.src.demo.offline import run_offline
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
def on_upload_file(file):
|
| 7 |
+
# Check if file extension is .xdf
|
| 8 |
+
if file.name.split(".")[-1] != "xdf":
|
| 9 |
+
raise gr.Error("Please upload a .xdf file.")
|
| 10 |
+
else:
|
| 11 |
+
return file.name
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
def main():
|
| 15 |
+
with gr.Blocks(title="Portiloop") as demo:
|
| 16 |
+
gr.Markdown("# Portiloop Demo")
|
| 17 |
+
gr.Markdown("This Demo takes as input an XDF file coming from the Portiloop EEG device and allows you to convert it to CSV and perform the following actions:: \n * Filter the data offline \n * Perform offline spindle detection using Wamsley or Lacourse. \n * Simulate the Portiloop online filtering and spindle detection with different parameters.")
|
| 18 |
+
gr.Markdown("Upload your XDF file and click **Run Inference** to start the processing...")
|
| 19 |
+
|
| 20 |
+
with gr.Row():
|
| 21 |
+
xdf_file_button = gr.UploadButton(label="Click to Upload", type="file", file_count="single")
|
| 22 |
+
xdf_file_static = gr.File(label="XDF File", type='file', interactive=False)
|
| 23 |
+
|
| 24 |
+
xdf_file_button.upload(on_upload_file, xdf_file_button, xdf_file_static)
|
| 25 |
+
|
| 26 |
+
# Make a checkbox group for the options
|
| 27 |
+
detect_filter = gr.CheckboxGroup(['Offline Filtering', 'Lacourse Detection', 'Wamsley Detection', 'Online Filtering', 'Online Detection'], type='index', label="Filtering/Detection options")
|
| 28 |
+
|
| 29 |
+
# Options for phase stimulation
|
| 30 |
+
with gr.Row():
|
| 31 |
+
# Dropwdown for phase
|
| 32 |
+
phase = gr.Dropdown(choices=["Peak", "Fast", "Valley"], value="Peak", label="Phase", interactive=True)
|
| 33 |
+
buffer_time = gr.Slider(0, 1, value=0.3, step=0.01, label="Buffer Time", interactive=True)
|
| 34 |
+
|
| 35 |
+
# Threshold value
|
| 36 |
+
threshold = gr.Slider(0, 1, value=0.82, step=0.01, label="Threshold", interactive=True)
|
| 37 |
+
# Detection Channel
|
| 38 |
+
detect_channel = gr.Dropdown(choices=["1", "2", "3", "4", "5", "6", "7", "8"], value="2", label="Detection Channel in XDF recording", interactive=True)
|
| 39 |
+
# Frequency
|
| 40 |
+
freq = gr.Dropdown(choices=["100", "200", "250", "256", "500", "512", "1000", "1024"], value="250", label="Sampling Frequency (Hz)", interactive=True)
|
| 41 |
+
|
| 42 |
+
with gr.Row():
|
| 43 |
+
output_array = gr.File(label="Output CSV File")
|
| 44 |
+
output_table = gr.Markdown(label="Output Table")
|
| 45 |
+
|
| 46 |
+
run_inference = gr.Button(value="Run Inference")
|
| 47 |
+
run_inference.click(
|
| 48 |
+
fn=run_offline,
|
| 49 |
+
inputs=[
|
| 50 |
+
xdf_file_static,
|
| 51 |
+
detect_filter,
|
| 52 |
+
threshold,
|
| 53 |
+
detect_channel,
|
| 54 |
+
freq,
|
| 55 |
+
phase,
|
| 56 |
+
buffer_time],
|
| 57 |
+
outputs=[output_array, output_table])
|
| 58 |
+
|
| 59 |
+
demo.queue()
|
| 60 |
+
demo.launch(share=False)
|
| 61 |
+
|
| 62 |
+
if __name__ == "__main__":
|
| 63 |
+
main()
|
portiloop/src/demo/test_offline.py
CHANGED
|
@@ -2,7 +2,9 @@ import itertools
|
|
| 2 |
import unittest
|
| 3 |
from portiloop.src.demo.offline import run_offline
|
| 4 |
from pathlib import Path
|
|
|
|
| 5 |
|
|
|
|
| 6 |
|
| 7 |
class TestOffline(unittest.TestCase):
|
| 8 |
|
|
@@ -21,7 +23,7 @@ class TestOffline(unittest.TestCase):
|
|
| 21 |
all_options_iterator = itertools.product(*map(combinatorial_config.get, keys))
|
| 22 |
all_options_dicts = [dict(zip(keys, values)) for values in all_options_iterator]
|
| 23 |
self.filtered_options = [value for value in all_options_dicts if (value['online_detection'] and value['online_filtering']) or not value['online_detection']]
|
| 24 |
-
self.xdf_file = Path(__file__).parents[3] / "
|
| 25 |
|
| 26 |
|
| 27 |
def test_all_options(self):
|
|
@@ -30,17 +32,20 @@ class TestOffline(unittest.TestCase):
|
|
| 30 |
self.assertTrue(config['online_filtering'])
|
| 31 |
|
| 32 |
def test_single_option(self):
|
|
|
|
|
|
|
|
|
|
|
|
|
| 33 |
res = list(run_offline(
|
| 34 |
self.xdf_file,
|
| 35 |
-
|
| 36 |
-
online_filtering=True,
|
| 37 |
-
online_detection=True,
|
| 38 |
-
wamsley=True,
|
| 39 |
-
lacourse=True,
|
| 40 |
threshold=0.5,
|
| 41 |
channel_num=2,
|
| 42 |
-
freq=250
|
|
|
|
|
|
|
| 43 |
print(res)
|
|
|
|
| 44 |
|
| 45 |
def tearDown(self):
|
| 46 |
pass
|
|
|
|
| 2 |
import unittest
|
| 3 |
from portiloop.src.demo.offline import run_offline
|
| 4 |
from pathlib import Path
|
| 5 |
+
import matplotlib.pyplot as plt
|
| 6 |
|
| 7 |
+
from portiloop.src.demo.utils import sleep_stage, xdf2array
|
| 8 |
|
| 9 |
class TestOffline(unittest.TestCase):
|
| 10 |
|
|
|
|
| 23 |
all_options_iterator = itertools.product(*map(combinatorial_config.get, keys))
|
| 24 |
all_options_dicts = [dict(zip(keys, values)) for values in all_options_iterator]
|
| 25 |
self.filtered_options = [value for value in all_options_dicts if (value['online_detection'] and value['online_filtering']) or not value['online_detection']]
|
| 26 |
+
self.xdf_file = Path(__file__).parents[3] / "test_file.xdf"
|
| 27 |
|
| 28 |
|
| 29 |
def test_all_options(self):
|
|
|
|
| 32 |
self.assertTrue(config['online_filtering'])
|
| 33 |
|
| 34 |
def test_single_option(self):
|
| 35 |
+
|
| 36 |
+
# Test options correspond to an index in the possible checkbox group options
|
| 37 |
+
test_options = [0, 1, 2]
|
| 38 |
+
|
| 39 |
res = list(run_offline(
|
| 40 |
self.xdf_file,
|
| 41 |
+
test_options,
|
|
|
|
|
|
|
|
|
|
|
|
|
| 42 |
threshold=0.5,
|
| 43 |
channel_num=2,
|
| 44 |
+
freq=250,
|
| 45 |
+
stimulation_phase="Peak",
|
| 46 |
+
buffer_time=0.3))
|
| 47 |
print(res)
|
| 48 |
+
pass
|
| 49 |
|
| 50 |
def tearDown(self):
|
| 51 |
pass
|
portiloop/src/demo/utils.py
CHANGED
|
@@ -13,6 +13,32 @@ STREAM_NAMES = {
|
|
| 13 |
}
|
| 14 |
|
| 15 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 16 |
class OfflineSleepSpindleRealTimeStimulator(Stimulator):
|
| 17 |
def __init__(self):
|
| 18 |
self.last_detected_ts = time.time()
|
|
@@ -87,15 +113,19 @@ def xdf2array(xdf_path, channel):
|
|
| 87 |
return np.array(csv_list), columns
|
| 88 |
|
| 89 |
|
| 90 |
-
def offline_detect(method, data, timesteps, freq):
|
|
|
|
|
|
|
|
|
|
| 91 |
# Get the spindle data from the offline methods
|
| 92 |
time = np.arange(0, len(data)) / freq
|
|
|
|
| 93 |
if method == "Lacourse":
|
| 94 |
detector = DetectSpindle(method='Lacourse2018')
|
| 95 |
-
spindles, _, _ = detect_Lacourse2018(
|
| 96 |
elif method == "Wamsley":
|
| 97 |
detector = DetectSpindle(method='Wamsley2012')
|
| 98 |
-
spindles, _, _ = detect_Wamsley2012(
|
| 99 |
else:
|
| 100 |
raise ValueError("Invalid method")
|
| 101 |
|
|
@@ -134,18 +164,25 @@ def offline_filter(signal, freq):
|
|
| 134 |
def compute_output_table(online_stimulation, lacourse_spindles, wamsley_spindles):
|
| 135 |
# Count the number of spindles detected by each method
|
| 136 |
online_stimulation_count = np.sum(online_stimulation)
|
| 137 |
-
|
| 138 |
-
|
| 139 |
-
|
| 140 |
-
|
| 141 |
-
|
| 142 |
-
|
| 143 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 144 |
|
| 145 |
# Create markdown table with the results
|
| 146 |
table = "| Method | Detected spindles | Overlap with Portiloop |\n"
|
| 147 |
table += "| --- | --- | --- |\n"
|
| 148 |
table += f"| Online | {online_stimulation_count} | {online_stimulation_count} |\n"
|
| 149 |
-
|
| 150 |
-
|
|
|
|
|
|
|
| 151 |
return table
|
|
|
|
|
|
| 13 |
}
|
| 14 |
|
| 15 |
|
| 16 |
+
def sleep_stage(data, threshold=150, group_size=2):
|
| 17 |
+
"""Sleep stage approximation using a threshold and a group size.
|
| 18 |
+
Returns a numpy array containing all indices in the input data which CAN be used for offline detection.
|
| 19 |
+
These indices can then be used to reconstruct the signal from the original data.
|
| 20 |
+
"""
|
| 21 |
+
# Find all indexes where the signal is above or below the threshold
|
| 22 |
+
above = np.where(data > threshold)
|
| 23 |
+
below = np.where(data < -threshold)
|
| 24 |
+
indices = np.concatenate((above, below), axis=1)[0]
|
| 25 |
+
|
| 26 |
+
indices = np.sort(indices)
|
| 27 |
+
# Get all the indices where the difference between two consecutive indices is larger than 100
|
| 28 |
+
groups = np.where(np.diff(indices) <= group_size)[0] + 1
|
| 29 |
+
# Get the important indices
|
| 30 |
+
important_indices = indices[groups]
|
| 31 |
+
# Get all the indices between the important indices
|
| 32 |
+
group_filler = [np.arange(indices[groups[n] - 1] + 1, index) for n, index in enumerate(important_indices)]
|
| 33 |
+
# Create flat array from fillers
|
| 34 |
+
group_filler = np.concatenate(group_filler)
|
| 35 |
+
# Append all group fillers to the indices
|
| 36 |
+
masked_indices = np.sort(np.concatenate((indices, group_filler)))
|
| 37 |
+
unmasked_indices = np.setdiff1d(np.arange(len(data)), masked_indices)
|
| 38 |
+
|
| 39 |
+
return unmasked_indices
|
| 40 |
+
|
| 41 |
+
|
| 42 |
class OfflineSleepSpindleRealTimeStimulator(Stimulator):
|
| 43 |
def __init__(self):
|
| 44 |
self.last_detected_ts = time.time()
|
|
|
|
| 113 |
return np.array(csv_list), columns
|
| 114 |
|
| 115 |
|
| 116 |
+
def offline_detect(method, data, timesteps, freq, mask):
|
| 117 |
+
# Extract only the interesting elements from the mask
|
| 118 |
+
data_masked = data[mask]
|
| 119 |
+
|
| 120 |
# Get the spindle data from the offline methods
|
| 121 |
time = np.arange(0, len(data)) / freq
|
| 122 |
+
time_masked = time[mask]
|
| 123 |
if method == "Lacourse":
|
| 124 |
detector = DetectSpindle(method='Lacourse2018')
|
| 125 |
+
spindles, _, _ = detect_Lacourse2018(data_masked, freq, time_masked, detector)
|
| 126 |
elif method == "Wamsley":
|
| 127 |
detector = DetectSpindle(method='Wamsley2012')
|
| 128 |
+
spindles, _, _ = detect_Wamsley2012(data_masked, freq, time_masked, detector)
|
| 129 |
else:
|
| 130 |
raise ValueError("Invalid method")
|
| 131 |
|
|
|
|
| 164 |
def compute_output_table(online_stimulation, lacourse_spindles, wamsley_spindles):
|
| 165 |
# Count the number of spindles detected by each method
|
| 166 |
online_stimulation_count = np.sum(online_stimulation)
|
| 167 |
+
if lacourse_spindles is not None:
|
| 168 |
+
lacourse_spindles_count = sum([1 for index, spindle in enumerate(lacourse_spindles) if spindle == 1 and lacourse_spindles[index - 1] == 0])
|
| 169 |
+
# Count how many spindles were detected by both online and lacourse
|
| 170 |
+
both_online_lacourse = sum([1 for index, spindle in enumerate(online_stimulation) if spindle == 1 and lacourse_spindles[index] == 1])
|
| 171 |
+
|
| 172 |
+
if wamsley_spindles is not None:
|
| 173 |
+
wamsley_spindles_count = sum([1 for index, spindle in enumerate(wamsley_spindles) if spindle == 1 and wamsley_spindles[index - 1] == 0])
|
| 174 |
+
# Count how many spindles were detected by both online and wamsley
|
| 175 |
+
both_online_wamsley = sum([1 for index, spindle in enumerate(online_stimulation) if spindle == 1 and wamsley_spindles[index] == 1])
|
| 176 |
+
|
| 177 |
+
|
| 178 |
|
| 179 |
# Create markdown table with the results
|
| 180 |
table = "| Method | Detected spindles | Overlap with Portiloop |\n"
|
| 181 |
table += "| --- | --- | --- |\n"
|
| 182 |
table += f"| Online | {online_stimulation_count} | {online_stimulation_count} |\n"
|
| 183 |
+
if lacourse_spindles is not None:
|
| 184 |
+
table += f"| Lacourse | {lacourse_spindles_count} | {both_online_lacourse} |\n"
|
| 185 |
+
if wamsley_spindles is not None:
|
| 186 |
+
table += f"| Wamsley | {wamsley_spindles_count} | {both_online_wamsley} |\n"
|
| 187 |
return table
|
| 188 |
+
|
portiloop/src/stimulation.py
CHANGED
|
@@ -3,6 +3,8 @@ from enum import Enum
|
|
| 3 |
import time
|
| 4 |
from threading import Thread, Lock
|
| 5 |
from pathlib import Path
|
|
|
|
|
|
|
| 6 |
|
| 7 |
from portiloop.src import ADS
|
| 8 |
|
|
@@ -146,20 +148,18 @@ class SleepSpindleRealTimeStimulator(Stimulator):
|
|
| 146 |
|
| 147 |
# Class that delays stimulation to always stimulate peak or through
|
| 148 |
class UpStateDelayer:
|
| 149 |
-
def __init__(self, sample_freq,
|
| 150 |
'''
|
| 151 |
args:
|
| 152 |
sample_freq: int -> Sampling frequency of signal in Hz
|
| 153 |
time_to_wait: float -> Time to wait to build buffer in seconds
|
| 154 |
'''
|
| 155 |
# Get number of timesteps for a whole spindle
|
| 156 |
-
self.spindle_timesteps = (1/spindle_freq) * sample_freq # s *
|
| 157 |
self.sample_freq = sample_freq
|
| 158 |
-
self.buffer_size = 1.5 * self.spindle_timesteps
|
| 159 |
self.peak = peak
|
| 160 |
self.buffer = []
|
| 161 |
self.time_to_buffer = time_to_buffer
|
| 162 |
-
self.stimulate =
|
| 163 |
|
| 164 |
self.state = States.NO_SPINDLE
|
| 165 |
|
|
@@ -192,10 +192,39 @@ class UpStateDelayer:
|
|
| 192 |
return True
|
| 193 |
return False
|
| 194 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 195 |
def detected(self):
|
| 196 |
if self.state == States.NO_SPINDLE:
|
| 197 |
self.state = States.BUFFERING
|
| 198 |
-
self.time_started = time.time()
|
| 199 |
|
| 200 |
def compute_time_to_wait(self):
|
| 201 |
"""
|
|
@@ -208,8 +237,27 @@ class UpStateDelayer:
|
|
| 208 |
# Returns the index of the last peak in the buffer
|
| 209 |
peaks, _ = find_peaks(self.buffer, prominence=1)
|
| 210 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 211 |
# Compute the time until next peak and return it
|
| 212 |
-
|
|
|
|
|
|
|
|
|
|
| 213 |
|
| 214 |
class States(Enum):
|
| 215 |
NO_SPINDLE = 0
|
|
|
|
| 3 |
import time
|
| 4 |
from threading import Thread, Lock
|
| 5 |
from pathlib import Path
|
| 6 |
+
import numpy as np
|
| 7 |
+
import matplotlib.pyplot as plt
|
| 8 |
|
| 9 |
from portiloop.src import ADS
|
| 10 |
|
|
|
|
| 148 |
|
| 149 |
# Class that delays stimulation to always stimulate peak or through
|
| 150 |
class UpStateDelayer:
|
| 151 |
+
def __init__(self, sample_freq, peak, time_to_buffer, stimulate=None):
|
| 152 |
'''
|
| 153 |
args:
|
| 154 |
sample_freq: int -> Sampling frequency of signal in Hz
|
| 155 |
time_to_wait: float -> Time to wait to build buffer in seconds
|
| 156 |
'''
|
| 157 |
# Get number of timesteps for a whole spindle
|
|
|
|
| 158 |
self.sample_freq = sample_freq
|
|
|
|
| 159 |
self.peak = peak
|
| 160 |
self.buffer = []
|
| 161 |
self.time_to_buffer = time_to_buffer
|
| 162 |
+
self.stimulate = stimulate
|
| 163 |
|
| 164 |
self.state = States.NO_SPINDLE
|
| 165 |
|
|
|
|
| 192 |
return True
|
| 193 |
return False
|
| 194 |
|
| 195 |
+
def step_timesteps(self, point):
|
| 196 |
+
'''
|
| 197 |
+
Step the delayer, ads a point to buffer if necessary.
|
| 198 |
+
Returns True if stimulation is actually done
|
| 199 |
+
'''
|
| 200 |
+
if self.state == States.NO_SPINDLE:
|
| 201 |
+
return False
|
| 202 |
+
elif self.state == States.BUFFERING:
|
| 203 |
+
self.buffer.append(point)
|
| 204 |
+
# If we are done buffering, move on to the waiting stage
|
| 205 |
+
if len(self.buffer) >= self.time_to_buffer * self.sample_freq:
|
| 206 |
+
# Compute the necessary time to wait
|
| 207 |
+
self.time_to_wait = self.compute_time_to_wait()
|
| 208 |
+
self.state = States.DELAYING
|
| 209 |
+
self.buffer = []
|
| 210 |
+
self.delaying_counter = 0
|
| 211 |
+
return False
|
| 212 |
+
elif self.state == States.DELAYING:
|
| 213 |
+
# Check if we are done delaying
|
| 214 |
+
self.delaying_counter += 1
|
| 215 |
+
if self.delaying_counter >= self.time_to_wait * self.sample_freq:
|
| 216 |
+
# Actually stimulate the patient after the delay
|
| 217 |
+
if self.stimulate is not None:
|
| 218 |
+
self.stimulate()
|
| 219 |
+
# Reset state
|
| 220 |
+
self.time_to_wait = -1
|
| 221 |
+
self.state = States.NO_SPINDLE
|
| 222 |
+
return True
|
| 223 |
+
return False
|
| 224 |
+
|
| 225 |
def detected(self):
|
| 226 |
if self.state == States.NO_SPINDLE:
|
| 227 |
self.state = States.BUFFERING
|
|
|
|
| 228 |
|
| 229 |
def compute_time_to_wait(self):
|
| 230 |
"""
|
|
|
|
| 237 |
# Returns the index of the last peak in the buffer
|
| 238 |
peaks, _ = find_peaks(self.buffer, prominence=1)
|
| 239 |
|
| 240 |
+
# Make a figure to show the peaks
|
| 241 |
+
if False:
|
| 242 |
+
plt.figure()
|
| 243 |
+
plt.plot(self.buffer)
|
| 244 |
+
for peak in peaks:
|
| 245 |
+
plt.axvline(x=peak)
|
| 246 |
+
plt.plot(np.zeros_like(self.buffer), "--", color="gray")
|
| 247 |
+
plt.show()
|
| 248 |
+
|
| 249 |
+
if len(peaks) == 0:
|
| 250 |
+
print("No peaks found, increase buffer size")
|
| 251 |
+
return (self.sample_freq / 10) * (1.0 / self.sample_freq)
|
| 252 |
+
|
| 253 |
+
# Compute average distance between each peak
|
| 254 |
+
avg_dist = np.mean(np.diff(peaks))
|
| 255 |
+
|
| 256 |
# Compute the time until next peak and return it
|
| 257 |
+
if (avg_dist < len(self.buffer) - peaks[-1]):
|
| 258 |
+
print("Average distance between peaks is smaller than the time to last peak, decrease buffer size")
|
| 259 |
+
return (len(self.buffer) - peaks[-1]) * (1.0 / self.sample_freq)
|
| 260 |
+
return (avg_dist - (len(self.buffer) - peaks[-1])) * (1.0 / self.sample_freq)
|
| 261 |
|
| 262 |
class States(Enum):
|
| 263 |
NO_SPINDLE = 0
|