Spaces:
Build error
Build error
ybouteiller
commited on
Commit
·
35cdf83
1
Parent(s):
120f728
debugged and cleaned + implement stimulation skeletton
Browse files- portiloop/capture.py +28 -17
- portiloop/{inference.py → detection.py} +73 -40
- portiloop/notebooks/tests.ipynb +5 -34
- portiloop/stimulation.py +35 -0
portiloop/capture.py
CHANGED
|
@@ -399,7 +399,7 @@ def _capture_process(p_data_o, p_msg_io, duration, frequency, python_clock, time
|
|
| 399 |
|
| 400 |
|
| 401 |
class Capture:
|
| 402 |
-
def __init__(self,
|
| 403 |
# {now.strftime('%m_%d_%Y_%H_%M_%S')}
|
| 404 |
self.filename = EDF_PATH / 'recording.edf'
|
| 405 |
self._p_capture = None
|
|
@@ -433,7 +433,8 @@ class Capture:
|
|
| 433 |
self._t_capture = None
|
| 434 |
self.channel_states = ['disabled', 'disabled', 'disabled', 'disabled', 'disabled', 'disabled', 'disabled', 'disabled']
|
| 435 |
|
| 436 |
-
self.
|
|
|
|
| 437 |
|
| 438 |
# widgets ===============================
|
| 439 |
|
|
@@ -665,6 +666,7 @@ class Capture:
|
|
| 665 |
self.b_duration.observe(self.on_b_duration, 'value')
|
| 666 |
self.b_filter.observe(self.on_b_filter, 'value')
|
| 667 |
self.b_detect.observe(self.on_b_detect, 'value')
|
|
|
|
| 668 |
self.b_record.observe(self.on_b_record, 'value')
|
| 669 |
self.b_lsl.observe(self.on_b_lsl, 'value')
|
| 670 |
self.b_display.observe(self.on_b_display, 'value')
|
|
@@ -707,7 +709,7 @@ class Capture:
|
|
| 707 |
self.b_filter.disabled = False
|
| 708 |
self.b_detect.disabled = False
|
| 709 |
self.b_record.disabled = False
|
| 710 |
-
self.
|
| 711 |
self.b_display.disabled = False
|
| 712 |
self.b_clock.disabled = False
|
| 713 |
self.b_radio_ch2.disabled = False
|
|
@@ -733,8 +735,9 @@ class Capture:
|
|
| 733 |
self.b_filter.disabled = True
|
| 734 |
self.b_stimulate.disabled = True
|
| 735 |
self.b_filter.disabled = True
|
|
|
|
| 736 |
self.b_record.disabled = True
|
| 737 |
-
self.
|
| 738 |
self.b_display.disabled = True
|
| 739 |
self.b_clock.disabled = True
|
| 740 |
self.b_radio_ch2.disabled = True
|
|
@@ -784,8 +787,18 @@ class Capture:
|
|
| 784 |
if self._t_capture is not None:
|
| 785 |
warnings.warn("Capture already running, operation aborted.")
|
| 786 |
return
|
|
|
|
|
|
|
| 787 |
self._t_capture = Thread(target=self.start_capture,
|
| 788 |
-
args=(self.filter,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 789 |
self._t_capture.start()
|
| 790 |
elif val == 'Stop':
|
| 791 |
with self._lock_msg_out:
|
|
@@ -944,8 +957,9 @@ class Capture:
|
|
| 944 |
|
| 945 |
def start_capture(self,
|
| 946 |
filter,
|
| 947 |
-
|
| 948 |
-
|
|
|
|
| 949 |
record,
|
| 950 |
lsl,
|
| 951 |
viz,
|
|
@@ -971,8 +985,8 @@ class Capture:
|
|
| 971 |
alpha_std=self.polyak_std,
|
| 972 |
epsilon=self.epsilon)
|
| 973 |
|
| 974 |
-
if
|
| 975 |
-
|
| 976 |
|
| 977 |
self._p_capture = mp.Process(target=_capture_process,
|
| 978 |
args=(p_data_o,
|
|
@@ -984,7 +998,7 @@ class Capture:
|
|
| 984 |
self.channel_states)
|
| 985 |
)
|
| 986 |
self._p_capture.start()
|
| 987 |
-
|
| 988 |
|
| 989 |
if viz:
|
| 990 |
live_disp = LiveDisplay(channel_names = self.signal_labels, window_len=width)
|
|
@@ -1030,14 +1044,11 @@ class Capture:
|
|
| 1030 |
|
| 1031 |
filtered_point = n_array.tolist()
|
| 1032 |
|
| 1033 |
-
if
|
| 1034 |
-
|
| 1035 |
|
| 1036 |
-
|
| 1037 |
-
|
| 1038 |
-
|
| 1039 |
-
if stimulate and True:
|
| 1040 |
-
print('stimulation')
|
| 1041 |
|
| 1042 |
if lsl:
|
| 1043 |
lsl_outlet.push_sample(filtered_point[-1])
|
|
|
|
| 399 |
|
| 400 |
|
| 401 |
class Capture:
|
| 402 |
+
def __init__(self, detector_cls=None, stimulator_cls=None):
|
| 403 |
# {now.strftime('%m_%d_%Y_%H_%M_%S')}
|
| 404 |
self.filename = EDF_PATH / 'recording.edf'
|
| 405 |
self._p_capture = None
|
|
|
|
| 433 |
self._t_capture = None
|
| 434 |
self.channel_states = ['disabled', 'disabled', 'disabled', 'disabled', 'disabled', 'disabled', 'disabled', 'disabled']
|
| 435 |
|
| 436 |
+
self.detector_cls = detector_cls
|
| 437 |
+
self.stimulator_cls = stimulator_cls
|
| 438 |
|
| 439 |
# widgets ===============================
|
| 440 |
|
|
|
|
| 666 |
self.b_duration.observe(self.on_b_duration, 'value')
|
| 667 |
self.b_filter.observe(self.on_b_filter, 'value')
|
| 668 |
self.b_detect.observe(self.on_b_detect, 'value')
|
| 669 |
+
self.b_stimulate.observe(self.on_b_stimulate, 'value')
|
| 670 |
self.b_record.observe(self.on_b_record, 'value')
|
| 671 |
self.b_lsl.observe(self.on_b_lsl, 'value')
|
| 672 |
self.b_display.observe(self.on_b_display, 'value')
|
|
|
|
| 709 |
self.b_filter.disabled = False
|
| 710 |
self.b_detect.disabled = False
|
| 711 |
self.b_record.disabled = False
|
| 712 |
+
self.b_lsl.disabled = False
|
| 713 |
self.b_display.disabled = False
|
| 714 |
self.b_clock.disabled = False
|
| 715 |
self.b_radio_ch2.disabled = False
|
|
|
|
| 735 |
self.b_filter.disabled = True
|
| 736 |
self.b_stimulate.disabled = True
|
| 737 |
self.b_filter.disabled = True
|
| 738 |
+
self.b_detect.disabled = True
|
| 739 |
self.b_record.disabled = True
|
| 740 |
+
self.b_lsl.disabled = True
|
| 741 |
self.b_display.disabled = True
|
| 742 |
self.b_clock.disabled = True
|
| 743 |
self.b_radio_ch2.disabled = True
|
|
|
|
| 787 |
if self._t_capture is not None:
|
| 788 |
warnings.warn("Capture already running, operation aborted.")
|
| 789 |
return
|
| 790 |
+
detector_cls = self.detector_cls if self.detect else None
|
| 791 |
+
stimulator_cls = self.stimulator_cls if self.stimulate else None
|
| 792 |
self._t_capture = Thread(target=self.start_capture,
|
| 793 |
+
args=(self.filter,
|
| 794 |
+
detector_cls,
|
| 795 |
+
self.threshold,
|
| 796 |
+
stimulator_cls,
|
| 797 |
+
self.record,
|
| 798 |
+
self.lsl,
|
| 799 |
+
self.display,
|
| 800 |
+
500,
|
| 801 |
+
self.python_clock))
|
| 802 |
self._t_capture.start()
|
| 803 |
elif val == 'Stop':
|
| 804 |
with self._lock_msg_out:
|
|
|
|
| 957 |
|
| 958 |
def start_capture(self,
|
| 959 |
filter,
|
| 960 |
+
detector_cls,
|
| 961 |
+
threshold,
|
| 962 |
+
stimulator_cls,
|
| 963 |
record,
|
| 964 |
lsl,
|
| 965 |
viz,
|
|
|
|
| 985 |
alpha_std=self.polyak_std,
|
| 986 |
epsilon=self.epsilon)
|
| 987 |
|
| 988 |
+
detector = detector_cls(threshold) if detector_cls is not None else None
|
| 989 |
+
stimulator = stimulator_cls() if stimulator_cls is not None else None
|
| 990 |
|
| 991 |
self._p_capture = mp.Process(target=_capture_process,
|
| 992 |
args=(p_data_o,
|
|
|
|
| 998 |
self.channel_states)
|
| 999 |
)
|
| 1000 |
self._p_capture.start()
|
| 1001 |
+
print(f"PID capture: {self._p_capture.pid}")
|
| 1002 |
|
| 1003 |
if viz:
|
| 1004 |
live_disp = LiveDisplay(channel_names = self.signal_labels, window_len=width)
|
|
|
|
| 1044 |
|
| 1045 |
filtered_point = n_array.tolist()
|
| 1046 |
|
| 1047 |
+
if detector is not None:
|
| 1048 |
+
detection_signal = detector.detect(filtered_point)
|
| 1049 |
|
| 1050 |
+
if stimulator is not None:
|
| 1051 |
+
stimulator.stimulate(detection_signal)
|
|
|
|
|
|
|
|
|
|
| 1052 |
|
| 1053 |
if lsl:
|
| 1054 |
lsl_outlet.push_sample(filtered_point[-1])
|
portiloop/{inference.py → detection.py}
RENAMED
|
@@ -1,19 +1,49 @@
|
|
| 1 |
-
from pycoral.utils import edgetpu
|
| 2 |
-
import time
|
| 3 |
from abc import ABC, abstractmethod
|
|
|
|
| 4 |
from pathlib import Path
|
|
|
|
|
|
|
| 5 |
import numpy as np
|
| 6 |
|
| 7 |
-
DEFAULT_MODEL_PATH = str(Path(__file__).parent / "models/portiloop_model_quant.tflite")
|
| 8 |
-
print(DEFAULT_MODEL_PATH)
|
| 9 |
|
| 10 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 11 |
@abstractmethod
|
| 12 |
-
def
|
| 13 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 14 |
|
| 15 |
-
class
|
| 16 |
-
def __init__(self, num_models_parallel=8, window_size=54, seq_stride=42, model_path=None, verbose=False, channel=2):
|
| 17 |
model_path = DEFAULT_MODEL_PATH if model_path is None else model_path
|
| 18 |
self.verbose = verbose
|
| 19 |
self.channel = channel
|
|
@@ -32,60 +62,63 @@ class QuantizedModelForInference(AbstractQuantizedModelForInference):
|
|
| 32 |
self.seq_stride = seq_stride
|
| 33 |
self.window_size = window_size
|
| 34 |
|
| 35 |
-
self.stride_counters = [np.floor((self.seq_stride / self.num_models_parallel) * i) for i in range(self.num_models_parallel)]
|
| 36 |
-
for idx
|
| 37 |
-
self.stride_counters[idx
|
|
|
|
|
|
|
| 38 |
self.current_stride_counter = self.stride_counters[0] - 1
|
| 39 |
|
| 40 |
-
|
| 41 |
-
|
|
|
|
| 42 |
res = []
|
| 43 |
-
for inp in
|
| 44 |
result = self.add_datapoint(inp)
|
| 45 |
if result is not None:
|
| 46 |
-
res.append(result)
|
| 47 |
return res
|
| 48 |
-
|
| 49 |
-
|
| 50 |
def add_datapoint(self, input_float):
|
| 51 |
-
input_float = input_float[self.channel-1]
|
| 52 |
result = None
|
| 53 |
self.buffer.append(input_float)
|
| 54 |
if len(self.buffer) > self.window_size:
|
| 55 |
self.buffer = self.buffer[1:]
|
| 56 |
self.current_stride_counter += 1
|
| 57 |
-
if self.current_stride_counter == self.
|
| 58 |
result = self.call_model(self.interpreter_counter, self.buffer)
|
| 59 |
self.interpreter_counter += 1
|
| 60 |
-
self.interpreter_counter %= self.
|
| 61 |
self.current_stride_counter = 0
|
| 62 |
return result
|
| 63 |
-
|
| 64 |
-
|
| 65 |
-
|
| 66 |
def call_model(self, idx, input_float=None):
|
| 67 |
if input_float is None:
|
| 68 |
-
# For
|
| 69 |
-
input_shape = input_details[0]['shape']
|
| 70 |
input = np.array(np.random.random_sample(input_shape), dtype=np.int8)
|
| 71 |
else:
|
| 72 |
# Convert float input to Int
|
| 73 |
-
input_scale, input_zero_point = input_details[0]["quantization"]
|
| 74 |
input = np.asarray(input_float) / input_scale + input_zero_point
|
| 75 |
-
input = input.astype(input_details[0]["dtype"])
|
| 76 |
-
|
| 77 |
-
interpreter.set_tensor(input_details[0]['index'], input)
|
| 78 |
-
if self.verbose:
|
| 79 |
-
start_time = time.time()
|
| 80 |
-
|
| 81 |
-
interpreter.invoke()
|
| 82 |
-
|
| 83 |
-
if self.verbose:
|
| 84 |
-
end_time = time.time()
|
| 85 |
|
| 86 |
-
|
| 87 |
-
|
| 88 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 89 |
|
| 90 |
if self.verbose:
|
| 91 |
print(f"Computed output {output} in {end_time - start_time} seconds")
|
|
|
|
|
|
|
|
|
|
| 1 |
from abc import ABC, abstractmethod
|
| 2 |
+
import time
|
| 3 |
from pathlib import Path
|
| 4 |
+
|
| 5 |
+
from pycoral.utils import edgetpu
|
| 6 |
import numpy as np
|
| 7 |
|
|
|
|
|
|
|
| 8 |
|
| 9 |
+
# Abstract interface for developers:
|
| 10 |
+
|
| 11 |
+
class Detector(ABC):
|
| 12 |
+
|
| 13 |
+
def __init__(self, threshold=None):
|
| 14 |
+
"""
|
| 15 |
+
If implementing __init__() in your subclass, it must take threshold as a keyword argument.
|
| 16 |
+
This is the value of the threshold that the user can set in the Portiloop GUI.
|
| 17 |
+
Caution: even if you don't need this manual threshold in your application,
|
| 18 |
+
your implementation of __init__() still needs to have this keyword argument.
|
| 19 |
+
"""
|
| 20 |
+
self.threshold = threshold
|
| 21 |
+
|
| 22 |
@abstractmethod
|
| 23 |
+
def detect(self, datapoints):
|
| 24 |
+
"""
|
| 25 |
+
Takes datapoints as input and outputs a detection signal.
|
| 26 |
+
|
| 27 |
+
Args:
|
| 28 |
+
datapoints: list of lists of n channels: may contain several datapoints.
|
| 29 |
+
A datapoint is a list of n floats, 1 for each channel.
|
| 30 |
+
In the current version of Portiloop, there is always only one datapoint per datapoints list.
|
| 31 |
+
|
| 32 |
+
Returns:
|
| 33 |
+
signal: Object: output detection signal (for instance, the output of a neural network);
|
| 34 |
+
this output signal is the input of the Stimulator.stimulate method.
|
| 35 |
+
If you don't mean to use a Stimulator, you can simply return None.
|
| 36 |
+
"""
|
| 37 |
+
raise NotImplementedError
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
# Example implementation for sleep spindles:
|
| 41 |
+
|
| 42 |
+
DEFAULT_MODEL_PATH = str(Path(__file__).parent / "models/portiloop_model_quant.tflite")
|
| 43 |
+
# print(DEFAULT_MODEL_PATH)
|
| 44 |
|
| 45 |
+
class SleepSpindleRealTimeDetector(Detector):
|
| 46 |
+
def __init__(self, threshold=0.5, num_models_parallel=8, window_size=54, seq_stride=42, model_path=None, verbose=False, channel=2):
|
| 47 |
model_path = DEFAULT_MODEL_PATH if model_path is None else model_path
|
| 48 |
self.verbose = verbose
|
| 49 |
self.channel = channel
|
|
|
|
| 62 |
self.seq_stride = seq_stride
|
| 63 |
self.window_size = window_size
|
| 64 |
|
| 65 |
+
self.stride_counters = [np.floor((self.seq_stride / self.num_models_parallel) * (i + 1)) for i in range(self.num_models_parallel)]
|
| 66 |
+
for idx in reversed(range(1, len(self.stride_counters))):
|
| 67 |
+
self.stride_counters[idx] -= self.stride_counters[idx-1]
|
| 68 |
+
assert sum(self.stride_counters) == self.seq_stride, f"{self.stride_counters} does not sum to {self.seq_stride}"
|
| 69 |
+
|
| 70 |
self.current_stride_counter = self.stride_counters[0] - 1
|
| 71 |
|
| 72 |
+
super().__init__(threshold)
|
| 73 |
+
|
| 74 |
+
def detect(self, datapoints):
|
| 75 |
res = []
|
| 76 |
+
for inp in datapoints:
|
| 77 |
result = self.add_datapoint(inp)
|
| 78 |
if result is not None:
|
| 79 |
+
res.append(result >= self.threshold)
|
| 80 |
return res
|
| 81 |
+
|
|
|
|
| 82 |
def add_datapoint(self, input_float):
|
| 83 |
+
input_float = input_float[self.channel - 1]
|
| 84 |
result = None
|
| 85 |
self.buffer.append(input_float)
|
| 86 |
if len(self.buffer) > self.window_size:
|
| 87 |
self.buffer = self.buffer[1:]
|
| 88 |
self.current_stride_counter += 1
|
| 89 |
+
if self.current_stride_counter == self.stride_counters[self.interpreter_counter]:
|
| 90 |
result = self.call_model(self.interpreter_counter, self.buffer)
|
| 91 |
self.interpreter_counter += 1
|
| 92 |
+
self.interpreter_counter %= self.num_models_parallel
|
| 93 |
self.current_stride_counter = 0
|
| 94 |
return result
|
| 95 |
+
|
|
|
|
|
|
|
| 96 |
def call_model(self, idx, input_float=None):
|
| 97 |
if input_float is None:
|
| 98 |
+
# For debugging purposes
|
| 99 |
+
input_shape = self.input_details[0]['shape']
|
| 100 |
input = np.array(np.random.random_sample(input_shape), dtype=np.int8)
|
| 101 |
else:
|
| 102 |
# Convert float input to Int
|
| 103 |
+
input_scale, input_zero_point = self.input_details[0]["quantization"]
|
| 104 |
input = np.asarray(input_float) / input_scale + input_zero_point
|
| 105 |
+
input = input.astype(self.input_details[0]["dtype"])
|
| 106 |
+
input = input.reshape((1, 1, -1))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 107 |
|
| 108 |
+
# FIXME: bad sequence length: 50 instead of 1:
|
| 109 |
+
# self.interpreters[idx].set_tensor(self.input_details[0]['index'], input)
|
| 110 |
+
#
|
| 111 |
+
# if self.verbose:
|
| 112 |
+
# start_time = time.time()
|
| 113 |
+
#
|
| 114 |
+
# self.interpreters[idx].invoke()
|
| 115 |
+
#
|
| 116 |
+
# if self.verbose:
|
| 117 |
+
# end_time = time.time()
|
| 118 |
+
# output = self.interpreters[idx].get_tensor(self.output_details[0]['index'])
|
| 119 |
+
# output_scale, output_zero_point = self.input_details[0]["quantization"]
|
| 120 |
+
# output = float(output - output_zero_point) * output_scale
|
| 121 |
+
output = np.random.uniform() # FIXME: remove
|
| 122 |
|
| 123 |
if self.verbose:
|
| 124 |
print(f"Computed output {output} in {end_time - start_time} seconds")
|
portiloop/notebooks/tests.ipynb
CHANGED
|
@@ -2,47 +2,18 @@
|
|
| 2 |
"cells": [
|
| 3 |
{
|
| 4 |
"cell_type": "code",
|
| 5 |
-
"execution_count":
|
| 6 |
"id": "7b2fc5da",
|
| 7 |
"metadata": {
|
| 8 |
"scrolled": false
|
| 9 |
},
|
| 10 |
-
"outputs": [
|
| 11 |
-
{
|
| 12 |
-
"data": {
|
| 13 |
-
"application/vnd.jupyter.widget-view+json": {
|
| 14 |
-
"model_id": "910f8e489b6341119f4d6e17a5b2aedc",
|
| 15 |
-
"version_major": 2,
|
| 16 |
-
"version_minor": 0
|
| 17 |
-
},
|
| 18 |
-
"text/plain": [
|
| 19 |
-
"VBox(children=(Accordion(children=(GridBox(children=(Label(value='CH1'), Label(value='CH2'), Label(value='CH3'…"
|
| 20 |
-
]
|
| 21 |
-
},
|
| 22 |
-
"metadata": {},
|
| 23 |
-
"output_type": "display_data"
|
| 24 |
-
},
|
| 25 |
-
{
|
| 26 |
-
"name": "stderr",
|
| 27 |
-
"output_type": "stream",
|
| 28 |
-
"text": [
|
| 29 |
-
"Process Process-1:\n",
|
| 30 |
-
"Traceback (most recent call last):\n",
|
| 31 |
-
" File \"/usr/lib/python3.7/multiprocessing/process.py\", line 297, in _bootstrap\n",
|
| 32 |
-
" self.run()\n",
|
| 33 |
-
" File \"/usr/lib/python3.7/multiprocessing/process.py\", line 99, in run\n",
|
| 34 |
-
" self._target(*self._args, **self._kwargs)\n",
|
| 35 |
-
" File \"/home/mendel/software/portiloop-software/portiloop/capture.py\", line 325, in _capture_process\n",
|
| 36 |
-
" assert data == [0x3E], \"The communication with the ADS cannot be established.\"\n",
|
| 37 |
-
"AssertionError: The communication with the ADS cannot be established.\n"
|
| 38 |
-
]
|
| 39 |
-
}
|
| 40 |
-
],
|
| 41 |
"source": [
|
| 42 |
"from portiloop.capture import Capture\n",
|
| 43 |
-
"from portiloop.
|
|
|
|
| 44 |
"\n",
|
| 45 |
-
"cap = Capture(
|
| 46 |
]
|
| 47 |
}
|
| 48 |
],
|
|
|
|
| 2 |
"cells": [
|
| 3 |
{
|
| 4 |
"cell_type": "code",
|
| 5 |
+
"execution_count": null,
|
| 6 |
"id": "7b2fc5da",
|
| 7 |
"metadata": {
|
| 8 |
"scrolled": false
|
| 9 |
},
|
| 10 |
+
"outputs": [],
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 11 |
"source": [
|
| 12 |
"from portiloop.capture import Capture\n",
|
| 13 |
+
"from portiloop.detection import SleepSpindleRealTimeDetector\n",
|
| 14 |
+
"from portiloop.stimulation import SleepSpindleRealTimeStimulator\n",
|
| 15 |
"\n",
|
| 16 |
+
"cap = Capture(detector_cls=SleepSpindleRealTimeDetector, stimulator_cls=SleepSpindleRealTimeStimulator)"
|
| 17 |
]
|
| 18 |
}
|
| 19 |
],
|
portiloop/stimulation.py
ADDED
|
@@ -0,0 +1,35 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from abc import ABC, abstractmethod
|
| 2 |
+
import time
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
# Abstract interface for developers:
|
| 6 |
+
|
| 7 |
+
class Stimulator(ABC):
|
| 8 |
+
|
| 9 |
+
@abstractmethod
|
| 10 |
+
def stimulate(self, detection_signal):
|
| 11 |
+
"""
|
| 12 |
+
Stimulates accordingly to the output of the Detector.
|
| 13 |
+
|
| 14 |
+
Args:
|
| 15 |
+
detection_signal: Object: the output of the Detector.add_datapoints method.
|
| 16 |
+
"""
|
| 17 |
+
raise NotImplementedError
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
# Example implementation for sleep spindles:
|
| 21 |
+
|
| 22 |
+
class SleepSpindleRealTimeStimulator(Stimulator):
|
| 23 |
+
def __init__(self):
|
| 24 |
+
self.last_detected_ts = time.time()
|
| 25 |
+
self.wait_t = 0.4 # 400 ms
|
| 26 |
+
|
| 27 |
+
def stimulate(self, detection_signal):
|
| 28 |
+
for sig in detection_signal:
|
| 29 |
+
if sig:
|
| 30 |
+
ts = time.time()
|
| 31 |
+
if ts - self.last_detected_ts > self.wait_t:
|
| 32 |
+
print("stimulation")
|
| 33 |
+
else:
|
| 34 |
+
print("same spindle")
|
| 35 |
+
self.last_detected_ts = ts
|