Spaces:
Build error
Build error
ybouteiller
commited on
Commit
·
45a88e4
1
Parent(s):
ba200d7
full pipeline
Browse files- portiloop/capture.py +56 -5
- portiloop/detection.py +45 -28
- portiloop/notebooks/tests.ipynb +41 -4
- portiloop/stimulation.py +63 -2
portiloop/capture.py
CHANGED
|
@@ -12,6 +12,7 @@ import multiprocessing as mp
|
|
| 12 |
import warnings
|
| 13 |
import shutil
|
| 14 |
from threading import Thread, Lock
|
|
|
|
| 15 |
|
| 16 |
import matplotlib.pyplot as plt
|
| 17 |
from EDFlib.edfwriter import EDFwriter
|
|
@@ -198,7 +199,7 @@ class FilterPipeline:
|
|
| 198 |
sampling_rate,
|
| 199 |
power_line_fq=60,
|
| 200 |
use_custom_fir=False,
|
| 201 |
-
custom_fir_order=
|
| 202 |
custom_fir_cutoff=30,
|
| 203 |
alpha_avg=0.1,
|
| 204 |
alpha_std=0.001,
|
|
@@ -411,7 +412,7 @@ class Capture:
|
|
| 411 |
self.polyak_std = 0.001
|
| 412 |
self.epsilon = 0.000001
|
| 413 |
self.custom_fir = False
|
| 414 |
-
self.custom_fir_order =
|
| 415 |
self.custom_fir_cutoff = 30
|
| 416 |
self.filter = True
|
| 417 |
self.record = False
|
|
@@ -436,6 +437,19 @@ class Capture:
|
|
| 436 |
self.detector_cls = detector_cls
|
| 437 |
self.stimulator_cls = stimulator_cls
|
| 438 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 439 |
# widgets ===============================
|
| 440 |
|
| 441 |
# CHANNELS ------------------------------
|
|
@@ -657,6 +671,22 @@ class Capture:
|
|
| 657 |
indent=False
|
| 658 |
)
|
| 659 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 660 |
# CALLBACKS ----------------------
|
| 661 |
|
| 662 |
self.b_capture.observe(self.on_b_capture, 'value')
|
|
@@ -684,6 +714,8 @@ class Capture:
|
|
| 684 |
self.b_polyak_mean.observe(self.on_b_polyak_mean, 'value')
|
| 685 |
self.b_polyak_std.observe(self.on_b_polyak_std, 'value')
|
| 686 |
self.b_epsilon.observe(self.on_b_epsilon, 'value')
|
|
|
|
|
|
|
| 687 |
|
| 688 |
self.display_buttons()
|
| 689 |
|
|
@@ -698,7 +730,8 @@ class Capture:
|
|
| 698 |
self.b_power_line,
|
| 699 |
self.b_clock,
|
| 700 |
widgets.HBox([self.b_filter, self.b_detect, self.b_stimulate, self.b_record, self.b_lsl, self.b_display]),
|
| 701 |
-
self.b_threshold,
|
|
|
|
| 702 |
self.b_accordion_filter,
|
| 703 |
self.b_capture]))
|
| 704 |
|
|
@@ -727,6 +760,7 @@ class Capture:
|
|
| 727 |
self.b_custom_fir_cutoff.disabled = not self.custom_fir
|
| 728 |
self.b_stimulate.disabled = not self.detect
|
| 729 |
self.b_threshold.disabled = not self.detect
|
|
|
|
| 730 |
|
| 731 |
def disable_buttons(self):
|
| 732 |
self.b_frequency.disabled = True
|
|
@@ -754,6 +788,7 @@ class Capture:
|
|
| 754 |
self.b_custom_fir_order.disabled = True
|
| 755 |
self.b_custom_fir_cutoff.disabled = True
|
| 756 |
self.b_threshold.disabled = True
|
|
|
|
| 757 |
|
| 758 |
def on_b_radio_ch2(self, value):
|
| 759 |
self.channel_states[1] = value['new']
|
|
@@ -789,6 +824,7 @@ class Capture:
|
|
| 789 |
return
|
| 790 |
detector_cls = self.detector_cls if self.detect else None
|
| 791 |
stimulator_cls = self.stimulator_cls if self.stimulate else None
|
|
|
|
| 792 |
self._t_capture = Thread(target=self.start_capture,
|
| 793 |
args=(self.filter,
|
| 794 |
detector_cls,
|
|
@@ -918,6 +954,16 @@ class Capture:
|
|
| 918 |
def on_b_display(self, value):
|
| 919 |
val = value['new']
|
| 920 |
self.display = val
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 921 |
|
| 922 |
def open_recording_file(self):
|
| 923 |
nb_signals = self.nb_signals
|
|
@@ -1011,8 +1057,9 @@ class Capture:
|
|
| 1011 |
lsl_info = StreamInfo(name='Portiloop',
|
| 1012 |
type='EEG',
|
| 1013 |
channel_count=8,
|
|
|
|
| 1014 |
channel_format='float32',
|
| 1015 |
-
source_id='') # TODO: replace this by unique device identifier
|
| 1016 |
lsl_outlet = StreamOutlet(lsl_info)
|
| 1017 |
|
| 1018 |
buffer = []
|
|
@@ -1046,9 +1093,13 @@ class Capture:
|
|
| 1046 |
|
| 1047 |
if detector is not None:
|
| 1048 |
detection_signal = detector.detect(filtered_point)
|
| 1049 |
-
|
| 1050 |
if stimulator is not None:
|
| 1051 |
stimulator.stimulate(detection_signal)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1052 |
|
| 1053 |
if lsl:
|
| 1054 |
lsl_outlet.push_sample(filtered_point[-1])
|
|
|
|
| 12 |
import warnings
|
| 13 |
import shutil
|
| 14 |
from threading import Thread, Lock
|
| 15 |
+
import alsaaudio
|
| 16 |
|
| 17 |
import matplotlib.pyplot as plt
|
| 18 |
from EDFlib.edfwriter import EDFwriter
|
|
|
|
| 199 |
sampling_rate,
|
| 200 |
power_line_fq=60,
|
| 201 |
use_custom_fir=False,
|
| 202 |
+
custom_fir_order=20,
|
| 203 |
custom_fir_cutoff=30,
|
| 204 |
alpha_avg=0.1,
|
| 205 |
alpha_std=0.001,
|
|
|
|
| 412 |
self.polyak_std = 0.001
|
| 413 |
self.epsilon = 0.000001
|
| 414 |
self.custom_fir = False
|
| 415 |
+
self.custom_fir_order = 20
|
| 416 |
self.custom_fir_cutoff = 30
|
| 417 |
self.filter = True
|
| 418 |
self.record = False
|
|
|
|
| 437 |
self.detector_cls = detector_cls
|
| 438 |
self.stimulator_cls = stimulator_cls
|
| 439 |
|
| 440 |
+
self._test_stimulus_lock = Lock()
|
| 441 |
+
self._test_stimulus = False
|
| 442 |
+
|
| 443 |
+
mixers = alsaaudio.mixers()
|
| 444 |
+
if 'PCM' in mixers:
|
| 445 |
+
self.mixer = alsaaudio.Mixer(control='PCM')
|
| 446 |
+
else:
|
| 447 |
+
assert len(mixers) > 0, 'No ALSA mixer found'
|
| 448 |
+
warnings.warn(f"Could not find mixer PCM, using {mixers[0]} instead.")
|
| 449 |
+
self.mixer = alsaaudio.Mixer(control=mixers[0])
|
| 450 |
+
self.volume = self.mixer.getvolume()[0] # we will set the same volume on all channels
|
| 451 |
+
|
| 452 |
+
|
| 453 |
# widgets ===============================
|
| 454 |
|
| 455 |
# CHANNELS ------------------------------
|
|
|
|
| 671 |
indent=False
|
| 672 |
)
|
| 673 |
|
| 674 |
+
self.b_volume = widgets.IntSlider(
|
| 675 |
+
value=self.volume,
|
| 676 |
+
min=0,
|
| 677 |
+
max=100,
|
| 678 |
+
step=1,
|
| 679 |
+
description="Volume",
|
| 680 |
+
disabled=False
|
| 681 |
+
)
|
| 682 |
+
|
| 683 |
+
self.b_test_stimulus = widgets.Button(
|
| 684 |
+
description='Test stimulus',
|
| 685 |
+
disabled=True,
|
| 686 |
+
button_style='', # 'success', 'info', 'warning', 'danger' or ''
|
| 687 |
+
tooltip='Send a test stimulus'
|
| 688 |
+
)
|
| 689 |
+
|
| 690 |
# CALLBACKS ----------------------
|
| 691 |
|
| 692 |
self.b_capture.observe(self.on_b_capture, 'value')
|
|
|
|
| 714 |
self.b_polyak_mean.observe(self.on_b_polyak_mean, 'value')
|
| 715 |
self.b_polyak_std.observe(self.on_b_polyak_std, 'value')
|
| 716 |
self.b_epsilon.observe(self.on_b_epsilon, 'value')
|
| 717 |
+
self.b_volume.observe(self.on_b_volume, 'value')
|
| 718 |
+
self.b_test_stimulus.on_click(self.on_b_test_stimulus)
|
| 719 |
|
| 720 |
self.display_buttons()
|
| 721 |
|
|
|
|
| 730 |
self.b_power_line,
|
| 731 |
self.b_clock,
|
| 732 |
widgets.HBox([self.b_filter, self.b_detect, self.b_stimulate, self.b_record, self.b_lsl, self.b_display]),
|
| 733 |
+
widgets.HBox([self.b_threshold, self.b_test_stimulus]),
|
| 734 |
+
self.b_volume,
|
| 735 |
self.b_accordion_filter,
|
| 736 |
self.b_capture]))
|
| 737 |
|
|
|
|
| 760 |
self.b_custom_fir_cutoff.disabled = not self.custom_fir
|
| 761 |
self.b_stimulate.disabled = not self.detect
|
| 762 |
self.b_threshold.disabled = not self.detect
|
| 763 |
+
self.b_test_stimulus.disabled = True # only enabled when running
|
| 764 |
|
| 765 |
def disable_buttons(self):
|
| 766 |
self.b_frequency.disabled = True
|
|
|
|
| 788 |
self.b_custom_fir_order.disabled = True
|
| 789 |
self.b_custom_fir_cutoff.disabled = True
|
| 790 |
self.b_threshold.disabled = True
|
| 791 |
+
self.b_test_stimulus.disabled = not self.stimulate # only enabled when running
|
| 792 |
|
| 793 |
def on_b_radio_ch2(self, value):
|
| 794 |
self.channel_states[1] = value['new']
|
|
|
|
| 824 |
return
|
| 825 |
detector_cls = self.detector_cls if self.detect else None
|
| 826 |
stimulator_cls = self.stimulator_cls if self.stimulate else None
|
| 827 |
+
|
| 828 |
self._t_capture = Thread(target=self.start_capture,
|
| 829 |
args=(self.filter,
|
| 830 |
detector_cls,
|
|
|
|
| 954 |
def on_b_display(self, value):
|
| 955 |
val = value['new']
|
| 956 |
self.display = val
|
| 957 |
+
|
| 958 |
+
def on_b_volume(self, value):
|
| 959 |
+
val = value['new']
|
| 960 |
+
if val >= 0 and val <= 100:
|
| 961 |
+
self.volume = val
|
| 962 |
+
self.mixer.setvolume(self.volume)
|
| 963 |
+
|
| 964 |
+
def on_b_test_stimulus(self, b):
|
| 965 |
+
with self._test_stimulus_lock:
|
| 966 |
+
self._test_stimulus = True
|
| 967 |
|
| 968 |
def open_recording_file(self):
|
| 969 |
nb_signals = self.nb_signals
|
|
|
|
| 1057 |
lsl_info = StreamInfo(name='Portiloop',
|
| 1058 |
type='EEG',
|
| 1059 |
channel_count=8,
|
| 1060 |
+
nominal_srate=self.frequency,
|
| 1061 |
channel_format='float32',
|
| 1062 |
+
source_id='portiloop1') # TODO: replace this by unique device identifier
|
| 1063 |
lsl_outlet = StreamOutlet(lsl_info)
|
| 1064 |
|
| 1065 |
buffer = []
|
|
|
|
| 1093 |
|
| 1094 |
if detector is not None:
|
| 1095 |
detection_signal = detector.detect(filtered_point)
|
|
|
|
| 1096 |
if stimulator is not None:
|
| 1097 |
stimulator.stimulate(detection_signal)
|
| 1098 |
+
with self._test_stimulus_lock:
|
| 1099 |
+
test_stimulus = self._test_stimulus
|
| 1100 |
+
self._test_stimulus = False
|
| 1101 |
+
if test_stimulus:
|
| 1102 |
+
stimulator.test_stimulus()
|
| 1103 |
|
| 1104 |
if lsl:
|
| 1105 |
lsl_outlet.push_sample(filtered_point[-1])
|
portiloop/detection.py
CHANGED
|
@@ -82,51 +82,68 @@ class SleepSpindleRealTimeDetector(Detector):
|
|
| 82 |
return res
|
| 83 |
|
| 84 |
def add_datapoint(self, input_float):
|
|
|
|
|
|
|
|
|
|
| 85 |
input_float = input_float[self.channel - 1]
|
| 86 |
result = None
|
|
|
|
| 87 |
self.buffer.append(input_float)
|
| 88 |
if len(self.buffer) > self.window_size:
|
|
|
|
| 89 |
self.buffer = self.buffer[1:]
|
| 90 |
self.current_stride_counter += 1
|
| 91 |
if self.current_stride_counter == self.stride_counters[self.interpreter_counter]:
|
| 92 |
-
|
|
|
|
| 93 |
self.interpreter_counter += 1
|
| 94 |
self.interpreter_counter %= self.num_models_parallel
|
| 95 |
self.current_stride_counter = 0
|
| 96 |
return result
|
| 97 |
-
|
| 98 |
-
def call_model(self, idx, input_float=None):
|
| 99 |
-
if input_float is None:
|
| 100 |
-
# For debugging purposes
|
| 101 |
-
input_shape = self.input_details[0]['shape']
|
| 102 |
-
input = np.array(np.random.random_sample(input_shape), dtype=np.int8)
|
| 103 |
-
else:
|
| 104 |
-
# Convert float input to Int
|
| 105 |
-
input_scale, input_zero_point = self.input_details[0]["quantization"]
|
| 106 |
-
input = np.asarray(input_float) / input_scale + input_zero_point
|
| 107 |
-
input = input.astype(self.input_details[0]["dtype"])
|
| 108 |
-
input = input.reshape((1, 1, -1))
|
| 109 |
-
|
| 110 |
-
# TODO: Milo please implement this:
|
| 111 |
-
# self.interpreters[idx].set_tensor(self.input_details[0]['index'], (self.h[idx], input))
|
| 112 |
|
| 113 |
-
|
| 114 |
-
|
|
|
|
| 115 |
|
| 116 |
-
#
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 117 |
|
| 118 |
-
|
| 119 |
-
|
| 120 |
-
|
| 121 |
-
|
| 122 |
-
# output = float(output - output_zero_point) * output_scale
|
| 123 |
|
| 124 |
-
|
| 125 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 126 |
|
|
|
|
|
|
|
|
|
|
| 127 |
if self.verbose:
|
| 128 |
-
print(f"Computed output {
|
| 129 |
|
| 130 |
-
return
|
|
|
|
|
|
|
| 131 |
|
| 132 |
|
|
|
|
| 82 |
return res
|
| 83 |
|
| 84 |
def add_datapoint(self, input_float):
|
| 85 |
+
'''
|
| 86 |
+
Add one datapoint to the buffer
|
| 87 |
+
'''
|
| 88 |
input_float = input_float[self.channel - 1]
|
| 89 |
result = None
|
| 90 |
+
# Add to current buffer
|
| 91 |
self.buffer.append(input_float)
|
| 92 |
if len(self.buffer) > self.window_size:
|
| 93 |
+
# Remove the end of the buffer
|
| 94 |
self.buffer = self.buffer[1:]
|
| 95 |
self.current_stride_counter += 1
|
| 96 |
if self.current_stride_counter == self.stride_counters[self.interpreter_counter]:
|
| 97 |
+
# If we have reached the next window size, we send the current buffer to the inference function and update the hidden state
|
| 98 |
+
result, self.h[self.interpreter_counter] = self.forward_tflite(self.interpreter_counter, self.buffer, self.h[self.interpreter_counter])
|
| 99 |
self.interpreter_counter += 1
|
| 100 |
self.interpreter_counter %= self.num_models_parallel
|
| 101 |
self.current_stride_counter = 0
|
| 102 |
return result
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 103 |
|
| 104 |
+
def forward_tflite(self, idx, input_x, input_h):
|
| 105 |
+
input_details = self.interpreters[idx].get_input_details()
|
| 106 |
+
output_details = self.interpreters[idx].get_output_details()
|
| 107 |
|
| 108 |
+
# convert input to int
|
| 109 |
+
input_scale, input_zero_point = input_details[1]["quantization"]
|
| 110 |
+
input_x = np.asarray(input_x) / input_scale + input_zero_point
|
| 111 |
+
input_data_x = input_x.astype(input_details[1]["dtype"])
|
| 112 |
+
input_data_x = np.expand_dims(input_data_x, (0, 1))
|
| 113 |
+
|
| 114 |
+
# input_scale, input_zero_point = input_details[0]["quantization"]
|
| 115 |
+
# input = np.asarray(input) / input_scale + input_zero_point
|
| 116 |
+
|
| 117 |
+
# Test the model on random input data.
|
| 118 |
+
input_shape_h = input_details[0]['shape']
|
| 119 |
+
input_shape_x = input_details[1]['shape']
|
| 120 |
+
|
| 121 |
+
# input_data_h = np.array(np.random.random_sample(input_shape_h), dtype=np.int8)
|
| 122 |
+
# input_data_x = np.array(np.random.random_sample(input_shape_x), dtype=np.int8)
|
| 123 |
+
self.interpreters[idx].set_tensor(input_details[0]['index'], input_h)
|
| 124 |
+
self.interpreters[idx].set_tensor(input_details[1]['index'], input_data_x)
|
| 125 |
|
| 126 |
+
if self.verbose:
|
| 127 |
+
start_time = time.time()
|
| 128 |
+
|
| 129 |
+
self.interpreters[idx].invoke()
|
|
|
|
| 130 |
|
| 131 |
+
if self.verbose:
|
| 132 |
+
end_time = time.time()
|
| 133 |
+
|
| 134 |
+
# The function `get_tensor()` returns a copy of the tensor data.
|
| 135 |
+
# Use `tensor()` in order to get a pointer to the tensor.
|
| 136 |
+
output_data_h = self.interpreters[idx].get_tensor(output_details[0]['index'])
|
| 137 |
+
output_data_y = self.interpreters[idx].get_tensor(output_details[1]['index'])
|
| 138 |
|
| 139 |
+
output_scale, output_zero_point = output_details[1]["quantization"]
|
| 140 |
+
output_data_y = float(output_data_y - output_zero_point) * output_scale
|
| 141 |
+
|
| 142 |
if self.verbose:
|
| 143 |
+
print(f"Computed output {output_data_y} in {end_time - start_time} seconds")
|
| 144 |
|
| 145 |
+
return output_data_y, output_data_h
|
| 146 |
+
|
| 147 |
+
|
| 148 |
|
| 149 |
|
portiloop/notebooks/tests.ipynb
CHANGED
|
@@ -2,24 +2,61 @@
|
|
| 2 |
"cells": [
|
| 3 |
{
|
| 4 |
"cell_type": "code",
|
| 5 |
-
"execution_count":
|
| 6 |
"id": "7b2fc5da",
|
| 7 |
"metadata": {
|
| 8 |
"scrolled": false
|
| 9 |
},
|
| 10 |
-
"outputs": [
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 11 |
"source": [
|
| 12 |
"from portiloop.capture import Capture\n",
|
| 13 |
"from portiloop.detection import SleepSpindleRealTimeDetector\n",
|
| 14 |
"from portiloop.stimulation import SleepSpindleRealTimeStimulator\n",
|
| 15 |
"\n",
|
| 16 |
-
"
|
|
|
|
|
|
|
|
|
|
| 17 |
]
|
| 18 |
},
|
| 19 |
{
|
| 20 |
"cell_type": "code",
|
| 21 |
"execution_count": null,
|
| 22 |
-
"id": "
|
| 23 |
"metadata": {},
|
| 24 |
"outputs": [],
|
| 25 |
"source": []
|
|
|
|
| 2 |
"cells": [
|
| 3 |
{
|
| 4 |
"cell_type": "code",
|
| 5 |
+
"execution_count": 1,
|
| 6 |
"id": "7b2fc5da",
|
| 7 |
"metadata": {
|
| 8 |
"scrolled": false
|
| 9 |
},
|
| 10 |
+
"outputs": [
|
| 11 |
+
{
|
| 12 |
+
"data": {
|
| 13 |
+
"application/vnd.jupyter.widget-view+json": {
|
| 14 |
+
"model_id": "5bd498c14c0b47ef8fc0c7b25d6197c0",
|
| 15 |
+
"version_major": 2,
|
| 16 |
+
"version_minor": 0
|
| 17 |
+
},
|
| 18 |
+
"text/plain": [
|
| 19 |
+
"VBox(children=(Accordion(children=(GridBox(children=(Label(value='CH1'), Label(value='CH2'), Label(value='CH3'…"
|
| 20 |
+
]
|
| 21 |
+
},
|
| 22 |
+
"metadata": {},
|
| 23 |
+
"output_type": "display_data"
|
| 24 |
+
},
|
| 25 |
+
{
|
| 26 |
+
"name": "stdout",
|
| 27 |
+
"output_type": "stream",
|
| 28 |
+
"text": [
|
| 29 |
+
"DEBUG:/home/mendel/software/portiloop-software/portiloop/sounds/stimulus.wav\n",
|
| 30 |
+
"PID capture: 4311\n",
|
| 31 |
+
"DEBUG: new config[5]:0xe1\n",
|
| 32 |
+
"DEBUG: new config[6]:0xe1\n",
|
| 33 |
+
"DEBUG: new config[7]:0xe1\n",
|
| 34 |
+
"DEBUG: new config[8]:0xe1\n",
|
| 35 |
+
"DEBUG: new config[9]:0xe1\n",
|
| 36 |
+
"DEBUG: new config[10]:0xe1\n",
|
| 37 |
+
"DEBUG: new config[11]:0xe1\n",
|
| 38 |
+
"DEBUG: new config[12]:0xe1\n",
|
| 39 |
+
"DEBUG: new config[13]:0x0\n",
|
| 40 |
+
"DEBUG: new config[14]:0x0\n",
|
| 41 |
+
"DEBUG: new config[3]:0xe8\n"
|
| 42 |
+
]
|
| 43 |
+
}
|
| 44 |
+
],
|
| 45 |
"source": [
|
| 46 |
"from portiloop.capture import Capture\n",
|
| 47 |
"from portiloop.detection import SleepSpindleRealTimeDetector\n",
|
| 48 |
"from portiloop.stimulation import SleepSpindleRealTimeStimulator\n",
|
| 49 |
"\n",
|
| 50 |
+
"my_detector_class = SleepSpindleRealTimeDetector # you may want to implement yours\n",
|
| 51 |
+
"my_stimulator_class = SleepSpindleRealTimeStimulator # you may also want to implement yours\n",
|
| 52 |
+
"\n",
|
| 53 |
+
"cap = Capture(detector_cls=my_detector_class, stimulator_cls=my_stimulator_class)"
|
| 54 |
]
|
| 55 |
},
|
| 56 |
{
|
| 57 |
"cell_type": "code",
|
| 58 |
"execution_count": null,
|
| 59 |
+
"id": "fd7c79a7",
|
| 60 |
"metadata": {},
|
| 61 |
"outputs": [],
|
| 62 |
"source": []
|
portiloop/stimulation.py
CHANGED
|
@@ -1,8 +1,10 @@
|
|
| 1 |
from abc import ABC, abstractmethod
|
| 2 |
import time
|
| 3 |
-
from playsound import playsound
|
| 4 |
from threading import Thread, Lock
|
| 5 |
from pathlib import Path
|
|
|
|
|
|
|
|
|
|
| 6 |
|
| 7 |
|
| 8 |
# Abstract interface for developers:
|
|
@@ -18,6 +20,12 @@ class Stimulator(ABC):
|
|
| 18 |
detection_signal: Object: the output of the Detector.add_datapoints method.
|
| 19 |
"""
|
| 20 |
raise NotImplementedError
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 21 |
|
| 22 |
|
| 23 |
# Example implementation for sleep spindles
|
|
@@ -30,6 +38,52 @@ class SleepSpindleRealTimeStimulator(Stimulator):
|
|
| 30 |
self._lock = Lock()
|
| 31 |
self.last_detected_ts = time.time()
|
| 32 |
self.wait_t = 0.4 # 400 ms
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 33 |
|
| 34 |
def stimulate(self, detection_signal):
|
| 35 |
for sig in detection_signal:
|
|
@@ -43,6 +97,13 @@ class SleepSpindleRealTimeStimulator(Stimulator):
|
|
| 43 |
self.last_detected_ts = ts
|
| 44 |
|
| 45 |
def _t_sound(self):
|
| 46 |
-
|
|
|
|
| 47 |
with self._lock:
|
| 48 |
self._thread = None
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
from abc import ABC, abstractmethod
|
| 2 |
import time
|
|
|
|
| 3 |
from threading import Thread, Lock
|
| 4 |
from pathlib import Path
|
| 5 |
+
import alsaaudio
|
| 6 |
+
import wave
|
| 7 |
+
import pylsl
|
| 8 |
|
| 9 |
|
| 10 |
# Abstract interface for developers:
|
|
|
|
| 20 |
detection_signal: Object: the output of the Detector.add_datapoints method.
|
| 21 |
"""
|
| 22 |
raise NotImplementedError
|
| 23 |
+
|
| 24 |
+
def test_stimulus(self):
|
| 25 |
+
"""
|
| 26 |
+
Optional: this is called when the 'Test stimulus' button is pressed.
|
| 27 |
+
"""
|
| 28 |
+
pass
|
| 29 |
|
| 30 |
|
| 31 |
# Example implementation for sleep spindles
|
|
|
|
| 38 |
self._lock = Lock()
|
| 39 |
self.last_detected_ts = time.time()
|
| 40 |
self.wait_t = 0.4 # 400 ms
|
| 41 |
+
|
| 42 |
+
lsl_markers_info = pylsl.StreamInfo(name='Portiloop_stimuli',
|
| 43 |
+
type='Markers',
|
| 44 |
+
channel_count=1,
|
| 45 |
+
channel_format='string',
|
| 46 |
+
source_id='portiloop1') # TODO: replace this by unique device identifier
|
| 47 |
+
self.lsl_outlet_markers = pylsl.StreamOutlet(lsl_markers_info)
|
| 48 |
+
|
| 49 |
+
# Initialize Alsa stuff
|
| 50 |
+
# Open WAV file and set PCM device
|
| 51 |
+
with wave.open(str(self._sound), 'rb') as f:
|
| 52 |
+
device = 'default'
|
| 53 |
+
|
| 54 |
+
format = None
|
| 55 |
+
|
| 56 |
+
# 8bit is unsigned in wav files
|
| 57 |
+
if f.getsampwidth() == 1:
|
| 58 |
+
format = alsaaudio.PCM_FORMAT_U8
|
| 59 |
+
# Otherwise we assume signed data, little endian
|
| 60 |
+
elif f.getsampwidth() == 2:
|
| 61 |
+
format = alsaaudio.PCM_FORMAT_S16_LE
|
| 62 |
+
elif f.getsampwidth() == 3:
|
| 63 |
+
format = alsaaudio.PCM_FORMAT_S24_3LE
|
| 64 |
+
elif f.getsampwidth() == 4:
|
| 65 |
+
format = alsaaudio.PCM_FORMAT_S32_LE
|
| 66 |
+
else:
|
| 67 |
+
raise ValueError('Unsupported format')
|
| 68 |
+
|
| 69 |
+
self.periodsize = f.getframerate() // 8
|
| 70 |
+
|
| 71 |
+
self.pcm = alsaaudio.PCM(channels=f.getnchannels(), rate=f.getframerate(), format=format, periodsize=self.periodsize, device=device)
|
| 72 |
+
|
| 73 |
+
# Store data in list to avoid reopening the file
|
| 74 |
+
data = f.readframes(self.periodsize)
|
| 75 |
+
self.wav_list = [data]
|
| 76 |
+
while data:
|
| 77 |
+
self.wav_list.append(data)
|
| 78 |
+
data = f.readframes(self.periodsize)
|
| 79 |
+
|
| 80 |
+
|
| 81 |
+
def play_sound(self):
|
| 82 |
+
'''
|
| 83 |
+
Open the wav file and play a sound
|
| 84 |
+
'''
|
| 85 |
+
for data in self.wav_list:
|
| 86 |
+
self.pcm.write(data)
|
| 87 |
|
| 88 |
def stimulate(self, detection_signal):
|
| 89 |
for sig in detection_signal:
|
|
|
|
| 97 |
self.last_detected_ts = ts
|
| 98 |
|
| 99 |
def _t_sound(self):
|
| 100 |
+
self.lsl_outlet_markers.push_sample(['STIM'])
|
| 101 |
+
self.play_sound()
|
| 102 |
with self._lock:
|
| 103 |
self._thread = None
|
| 104 |
+
|
| 105 |
+
def test_stimulus(self):
|
| 106 |
+
with self._lock:
|
| 107 |
+
if self._thread is None:
|
| 108 |
+
self._thread = Thread(target=self._t_sound, daemon=True)
|
| 109 |
+
self._thread.start()
|