Spaces:
Build error
Build error
Yann Bouteiller
commited on
Commit
·
6da664d
1
Parent(s):
0671861
Added spindle stimulation mode
Browse files- portiloop/capture.py +98 -2
- portiloop/stimulation.py +29 -2
portiloop/capture.py
CHANGED
|
@@ -429,6 +429,58 @@ class DummyAlsaMixer:
|
|
| 429 |
self.volume = volume
|
| 430 |
|
| 431 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 432 |
class Capture:
|
| 433 |
def __init__(self, detector_cls=None, stimulator_cls=None):
|
| 434 |
# {now.strftime('%m_%d_%Y_%H_%M_%S')}
|
|
@@ -465,6 +517,8 @@ class Capture:
|
|
| 465 |
self._t_capture = None
|
| 466 |
self.channel_states = ['disabled', 'disabled', 'disabled', 'disabled', 'disabled', 'disabled', 'disabled', 'disabled']
|
| 467 |
self.channel_detection = 2
|
|
|
|
|
|
|
| 468 |
|
| 469 |
self.detector_cls = detector_cls
|
| 470 |
self.stimulator_cls = stimulator_cls
|
|
@@ -552,6 +606,21 @@ class Capture:
|
|
| 552 |
style={'description_width': 'initial'}
|
| 553 |
)
|
| 554 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 555 |
self.b_accordion_channels = widgets.Accordion(
|
| 556 |
children=[
|
| 557 |
widgets.GridBox([
|
|
@@ -796,6 +865,8 @@ class Capture:
|
|
| 796 |
self.b_radio_ch7.observe(self.on_b_radio_ch7, 'value')
|
| 797 |
self.b_radio_ch8.observe(self.on_b_radio_ch8, 'value')
|
| 798 |
self.b_channel_detect.observe(self.on_b_channel_detect, 'value')
|
|
|
|
|
|
|
| 799 |
self.b_power_line.observe(self.on_b_power_line, 'value')
|
| 800 |
self.b_custom_fir.observe(self.on_b_custom_fir, 'value')
|
| 801 |
self.b_custom_fir_order.observe(self.on_b_custom_fir_order, 'value')
|
|
@@ -823,6 +894,7 @@ class Capture:
|
|
| 823 |
widgets.HBox([self.b_filter, self.b_detect, self.b_stimulate, self.b_record, self.b_lsl, self.b_display]),
|
| 824 |
widgets.HBox([self.b_threshold, self.b_test_stimulus]),
|
| 825 |
self.b_volume,
|
|
|
|
| 826 |
self.b_accordion_filter,
|
| 827 |
self.b_capture,
|
| 828 |
self.b_pause]))
|
|
@@ -846,6 +918,8 @@ class Capture:
|
|
| 846 |
self.b_radio_ch8.disabled = False
|
| 847 |
self.b_power_line.disabled = False
|
| 848 |
self.b_channel_detect.disabled = False
|
|
|
|
|
|
|
| 849 |
self.b_polyak_mean.disabled = False
|
| 850 |
self.b_polyak_std.disabled = False
|
| 851 |
self.b_epsilon.disabled = False
|
|
@@ -880,6 +954,8 @@ class Capture:
|
|
| 880 |
self.b_radio_ch7.disabled = True
|
| 881 |
self.b_radio_ch8.disabled = True
|
| 882 |
self.b_channel_detect.disabled = True
|
|
|
|
|
|
|
| 883 |
self.b_power_line.disabled = True
|
| 884 |
self.b_polyak_mean.disabled = True
|
| 885 |
self.b_polyak_std.disabled = True
|
|
@@ -916,7 +992,17 @@ class Capture:
|
|
| 916 |
|
| 917 |
def on_b_channel_detect(self, value):
|
| 918 |
self.channel_detection = value['new']
|
| 919 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 920 |
def on_b_capture(self, value):
|
| 921 |
val = value['new']
|
| 922 |
if val == 'Start':
|
|
@@ -1208,6 +1294,13 @@ class Capture:
|
|
| 1208 |
|
| 1209 |
buffer = []
|
| 1210 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1211 |
while True:
|
| 1212 |
with self._lock_msg_out:
|
| 1213 |
if self._msg_out is not None:
|
|
@@ -1238,12 +1331,15 @@ class Capture:
|
|
| 1238 |
if lsl:
|
| 1239 |
lsl_outlet_raw.push_sample(point)
|
| 1240 |
lsl_outlet.push_sample(filtered_point[-1])
|
|
|
|
|
|
|
|
|
|
| 1241 |
|
| 1242 |
with self._pause_detect_lock:
|
| 1243 |
pause = self._pause_detect
|
| 1244 |
if detector is not None and not pause:
|
| 1245 |
detection_signal = detector.detect(filtered_point)
|
| 1246 |
-
if stimulator is not None:
|
| 1247 |
stimulator.stimulate(detection_signal)
|
| 1248 |
with self._test_stimulus_lock:
|
| 1249 |
test_stimulus = self._test_stimulus
|
|
|
|
| 429 |
self.volume = volume
|
| 430 |
|
| 431 |
|
| 432 |
+
class UpStateDelayer:
|
| 433 |
+
def __init__(self, sample_freq, spindle_freq, peak):
|
| 434 |
+
'''
|
| 435 |
+
args:
|
| 436 |
+
buffer_size: int -> Size of desired buffer in length
|
| 437 |
+
sample_freq: int -> Sampling frequency of signal in Hz
|
| 438 |
+
'''
|
| 439 |
+
# Get number of timesteps for a whole spindle
|
| 440 |
+
self.spindle_timesteps = (1/spindle_freq) * sample_freq # s *
|
| 441 |
+
self.sample_freq = sample_freq
|
| 442 |
+
self.buffer_size = 1.5 * self.spindle_timesteps
|
| 443 |
+
self.peak = peak
|
| 444 |
+
self.buffer = []
|
| 445 |
+
|
| 446 |
+
def add_point(self, point):
|
| 447 |
+
'''
|
| 448 |
+
Adds a point to the buffer to be able to keep track of peaks
|
| 449 |
+
'''
|
| 450 |
+
self.buffer.append(point)
|
| 451 |
+
if len(self.buffer) > self.buffer_size:
|
| 452 |
+
self.buffer.pop(0)
|
| 453 |
+
|
| 454 |
+
def stimulate(self):
|
| 455 |
+
# Calculate how far away is last peak
|
| 456 |
+
last_peak = -1
|
| 457 |
+
count = 0
|
| 458 |
+
for idx, point in reversed(list(enumerate(self.buffer))):
|
| 459 |
+
if self.peak:
|
| 460 |
+
try:
|
| 461 |
+
sup = point >= self.buffer[idx+1]
|
| 462 |
+
except IndexError:
|
| 463 |
+
sup = False
|
| 464 |
+
try:
|
| 465 |
+
inf = point >= self.buffer[idx-1]
|
| 466 |
+
except IndexError:
|
| 467 |
+
inf = False
|
| 468 |
+
else:
|
| 469 |
+
try:
|
| 470 |
+
sup = point <= self.buffer[idx+1]
|
| 471 |
+
except IndexError:
|
| 472 |
+
sup = False
|
| 473 |
+
try:
|
| 474 |
+
inf = point <= self.buffer[idx-1]
|
| 475 |
+
except IndexError:
|
| 476 |
+
inf = False
|
| 477 |
+
if sup and inf:
|
| 478 |
+
last_peak = count
|
| 479 |
+
return self.spindle_timesteps - last_peak
|
| 480 |
+
count += 1
|
| 481 |
+
return -1
|
| 482 |
+
|
| 483 |
+
|
| 484 |
class Capture:
|
| 485 |
def __init__(self, detector_cls=None, stimulator_cls=None):
|
| 486 |
# {now.strftime('%m_%d_%Y_%H_%M_%S')}
|
|
|
|
| 517 |
self._t_capture = None
|
| 518 |
self.channel_states = ['disabled', 'disabled', 'disabled', 'disabled', 'disabled', 'disabled', 'disabled', 'disabled']
|
| 519 |
self.channel_detection = 2
|
| 520 |
+
self.spindle_detection_mode = 'Fast'
|
| 521 |
+
self.spindle_freq = 10
|
| 522 |
|
| 523 |
self.detector_cls = detector_cls
|
| 524 |
self.stimulator_cls = stimulator_cls
|
|
|
|
| 606 |
style={'description_width': 'initial'}
|
| 607 |
)
|
| 608 |
|
| 609 |
+
self.b_spindle_mode = widgets.Dropdown(
|
| 610 |
+
options=['Fast', 'Peak', 'Through'],
|
| 611 |
+
value='Fast',
|
| 612 |
+
description='Spindle Stimulation Mode',
|
| 613 |
+
disabled=False,
|
| 614 |
+
style={'description_width': 'initial'}
|
| 615 |
+
)
|
| 616 |
+
|
| 617 |
+
self.b_spindle_freq = widgets.IntText(
|
| 618 |
+
value=self.spindle_freq,
|
| 619 |
+
description='Spindle Freq (Hz):',
|
| 620 |
+
disabled=False,
|
| 621 |
+
style={'description_width': 'initial'}
|
| 622 |
+
)
|
| 623 |
+
|
| 624 |
self.b_accordion_channels = widgets.Accordion(
|
| 625 |
children=[
|
| 626 |
widgets.GridBox([
|
|
|
|
| 865 |
self.b_radio_ch7.observe(self.on_b_radio_ch7, 'value')
|
| 866 |
self.b_radio_ch8.observe(self.on_b_radio_ch8, 'value')
|
| 867 |
self.b_channel_detect.observe(self.on_b_channel_detect, 'value')
|
| 868 |
+
self.b_spindle_mode.observe(self.on_b_spindle_mode, 'value')
|
| 869 |
+
self.b_spindle_freq.observe(self.on_b_spindle_freq, 'value')
|
| 870 |
self.b_power_line.observe(self.on_b_power_line, 'value')
|
| 871 |
self.b_custom_fir.observe(self.on_b_custom_fir, 'value')
|
| 872 |
self.b_custom_fir_order.observe(self.on_b_custom_fir_order, 'value')
|
|
|
|
| 894 |
widgets.HBox([self.b_filter, self.b_detect, self.b_stimulate, self.b_record, self.b_lsl, self.b_display]),
|
| 895 |
widgets.HBox([self.b_threshold, self.b_test_stimulus]),
|
| 896 |
self.b_volume,
|
| 897 |
+
widgets.HBox([self.b_spindle_mode, self.b_spindle_freq]),
|
| 898 |
self.b_accordion_filter,
|
| 899 |
self.b_capture,
|
| 900 |
self.b_pause]))
|
|
|
|
| 918 |
self.b_radio_ch8.disabled = False
|
| 919 |
self.b_power_line.disabled = False
|
| 920 |
self.b_channel_detect.disabled = False
|
| 921 |
+
self.b_spindle_freq.disabled = False
|
| 922 |
+
self.b_spindle_mode.disabled = False
|
| 923 |
self.b_polyak_mean.disabled = False
|
| 924 |
self.b_polyak_std.disabled = False
|
| 925 |
self.b_epsilon.disabled = False
|
|
|
|
| 954 |
self.b_radio_ch7.disabled = True
|
| 955 |
self.b_radio_ch8.disabled = True
|
| 956 |
self.b_channel_detect.disabled = True
|
| 957 |
+
self.b_spindle_freq.disabled = True
|
| 958 |
+
self.b_spindle_mode.disabled = True
|
| 959 |
self.b_power_line.disabled = True
|
| 960 |
self.b_polyak_mean.disabled = True
|
| 961 |
self.b_polyak_std.disabled = True
|
|
|
|
| 992 |
|
| 993 |
def on_b_channel_detect(self, value):
|
| 994 |
self.channel_detection = value['new']
|
| 995 |
+
|
| 996 |
+
def on_b_spindle_freq(self, value):
|
| 997 |
+
val = value['new']
|
| 998 |
+
if val > 0:
|
| 999 |
+
self.spindle_freq = val
|
| 1000 |
+
else:
|
| 1001 |
+
self.b_spindle_freq.value = self.spindle_freq
|
| 1002 |
+
|
| 1003 |
+
def on_b_spindle_mode(self, value):
|
| 1004 |
+
self.spindle_detection_mode = value['new']
|
| 1005 |
+
|
| 1006 |
def on_b_capture(self, value):
|
| 1007 |
val = value['new']
|
| 1008 |
if val == 'Start':
|
|
|
|
| 1294 |
|
| 1295 |
buffer = []
|
| 1296 |
|
| 1297 |
+
if not self.spindle_detection_mode == 'Fast':
|
| 1298 |
+
print('here')
|
| 1299 |
+
stimulation_delayer = UpStateDelayer(self.frequency, self.spindle_freq, self.spindle_detection_mode == 'Peak')
|
| 1300 |
+
stimulator.add_delayer(stimulation_delayer)
|
| 1301 |
+
else:
|
| 1302 |
+
stimulation_delayer = None
|
| 1303 |
+
|
| 1304 |
while True:
|
| 1305 |
with self._lock_msg_out:
|
| 1306 |
if self._msg_out is not None:
|
|
|
|
| 1331 |
if lsl:
|
| 1332 |
lsl_outlet_raw.push_sample(point)
|
| 1333 |
lsl_outlet.push_sample(filtered_point[-1])
|
| 1334 |
+
|
| 1335 |
+
if stimulation_delayer is not None:
|
| 1336 |
+
stimulation_delayer.add_point(point[channel-1])
|
| 1337 |
|
| 1338 |
with self._pause_detect_lock:
|
| 1339 |
pause = self._pause_detect
|
| 1340 |
if detector is not None and not pause:
|
| 1341 |
detection_signal = detector.detect(filtered_point)
|
| 1342 |
+
if stimulator is not None:
|
| 1343 |
stimulator.stimulate(detection_signal)
|
| 1344 |
with self._test_stimulus_lock:
|
| 1345 |
test_stimulus = self._test_stimulus
|
portiloop/stimulation.py
CHANGED
|
@@ -37,6 +37,8 @@ class SleepSpindleRealTimeStimulator(Stimulator):
|
|
| 37 |
self._thread = None
|
| 38 |
self._lock = Lock()
|
| 39 |
self.last_detected_ts = time.time()
|
|
|
|
|
|
|
| 40 |
self.wait_t = 0.4 # 400 ms
|
| 41 |
|
| 42 |
lsl_markers_info = pylsl.StreamInfo(name='Portiloop_stimuli',
|
|
@@ -45,6 +47,7 @@ class SleepSpindleRealTimeStimulator(Stimulator):
|
|
| 45 |
channel_format='string',
|
| 46 |
source_id='portiloop1') # TODO: replace this by unique device identifier
|
| 47 |
self.lsl_outlet_markers = pylsl.StreamOutlet(lsl_markers_info)
|
|
|
|
| 48 |
|
| 49 |
# Initialize Alsa stuff
|
| 50 |
# Open WAV file and set PCM device
|
|
@@ -86,11 +89,32 @@ class SleepSpindleRealTimeStimulator(Stimulator):
|
|
| 86 |
|
| 87 |
def stimulate(self, detection_signal):
|
| 88 |
for sig in detection_signal:
|
| 89 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 90 |
ts = time.time()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 91 |
if ts - self.last_detected_ts > self.wait_t:
|
| 92 |
with self._lock:
|
| 93 |
-
if self._thread is None:
|
| 94 |
self._thread = Thread(target=self._t_sound, daemon=True)
|
| 95 |
self._thread.start()
|
| 96 |
self.last_detected_ts = ts
|
|
@@ -106,3 +130,6 @@ class SleepSpindleRealTimeStimulator(Stimulator):
|
|
| 106 |
if self._thread is None:
|
| 107 |
self._thread = Thread(target=self._t_sound, daemon=True)
|
| 108 |
self._thread.start()
|
|
|
|
|
|
|
|
|
|
|
|
| 37 |
self._thread = None
|
| 38 |
self._lock = Lock()
|
| 39 |
self.last_detected_ts = time.time()
|
| 40 |
+
self.wait_counter = 0
|
| 41 |
+
self.delayed = False
|
| 42 |
self.wait_t = 0.4 # 400 ms
|
| 43 |
|
| 44 |
lsl_markers_info = pylsl.StreamInfo(name='Portiloop_stimuli',
|
|
|
|
| 47 |
channel_format='string',
|
| 48 |
source_id='portiloop1') # TODO: replace this by unique device identifier
|
| 49 |
self.lsl_outlet_markers = pylsl.StreamOutlet(lsl_markers_info)
|
| 50 |
+
self.delayer = None
|
| 51 |
|
| 52 |
# Initialize Alsa stuff
|
| 53 |
# Open WAV file and set PCM device
|
|
|
|
| 89 |
|
| 90 |
def stimulate(self, detection_signal):
|
| 91 |
for sig in detection_signal:
|
| 92 |
+
# We are waiting for a delayed stimulation
|
| 93 |
+
if self.delayed:
|
| 94 |
+
if self.wait_counter >= self.wait_time:
|
| 95 |
+
with self._lock:
|
| 96 |
+
if self._thread is None:
|
| 97 |
+
self._thread = Thread(target=self._t_sound, daemon=True)
|
| 98 |
+
self._thread.start()
|
| 99 |
+
self.delayed = False
|
| 100 |
+
else:
|
| 101 |
+
self.wait_counter += 1
|
| 102 |
+
# We detect a stimulation
|
| 103 |
+
elif sig:
|
| 104 |
+
# Record time of stimulation
|
| 105 |
ts = time.time()
|
| 106 |
+
|
| 107 |
+
# Prompt delayer to try and get a stimulation
|
| 108 |
+
if self.delayer is not None:
|
| 109 |
+
self.wait_time = self.delayer.stimulate()
|
| 110 |
+
self.delayed = True
|
| 111 |
+
self.wait_counter = 0
|
| 112 |
+
continue
|
| 113 |
+
|
| 114 |
+
# Stimulate if allowed
|
| 115 |
if ts - self.last_detected_ts > self.wait_t:
|
| 116 |
with self._lock:
|
| 117 |
+
if self._thread is None:
|
| 118 |
self._thread = Thread(target=self._t_sound, daemon=True)
|
| 119 |
self._thread.start()
|
| 120 |
self.last_detected_ts = ts
|
|
|
|
| 130 |
if self._thread is None:
|
| 131 |
self._thread = Thread(target=self._t_sound, daemon=True)
|
| 132 |
self._thread.start()
|
| 133 |
+
|
| 134 |
+
def add_delayer(self, delayer):
|
| 135 |
+
self.delayer = delayer
|