Spaces:
Build error
Build error
| from abc import ABC, abstractmethod | |
| from enum import Enum | |
| import time | |
| from threading import Thread, Lock | |
| from pathlib import Path | |
| import numpy as np | |
| import matplotlib.pyplot as plt | |
| from portiloop.src import ADS | |
| if ADS: | |
| import alsaaudio | |
| import pylsl | |
| import wave | |
| from scipy.signal import find_peaks | |
| import numpy as np | |
| import matplotlib.pyplot as plt | |
| # Abstract interface for developers: | |
class Stimulator(ABC):
    """Base class for stimulation strategies.

    A concrete stimulator receives the output of a detector and decides
    whether, and how, to stimulate.
    """

    def stimulate(self, detection_signal):
        """
        React to the output of the Detector.

        Args:
            detection_signal: Object: the value returned by Detector.add_datapoints.

        Raises:
            NotImplementedError: always; subclasses are expected to override this.
        """
        raise NotImplementedError

    def test_stimulus(self):
        """
        Optional hook invoked when the 'Test stimulus' button is pressed.

        The default implementation does nothing.
        """
        return None
| # Example implementation for sleep spindles | |
class SleepSpindleRealTimeStimulator(Stimulator):
    """Stimulator that emits LSL markers and plays a sound on spindle detections.

    A detection triggers a stimulation only when more than ``wait_t`` seconds
    have elapsed since the previous detection. Every detection, stimulated or
    not, resets that timer. When a delayer is attached (see ``add_delayer``),
    detections are forwarded to it. Only a "FAST_STIM" marker is sent
    immediately; the sound is played later through the delayer's
    ``stimulate`` callback.
    """

    def __init__(self):
        self._sound = Path(__file__).parent.parent / 'sounds' / 'stimulus.wav'
        print(f"DEBUG:{self._sound}")
        self._thread = None
        self._lock = Lock()
        self.last_detected_ts = time.time()
        self.wait_t = 0.4  # 400 ms
        self.delayer = None

        lsl_markers_info = pylsl.StreamInfo(name='Portiloop_stimuli',
                                            type='Markers',
                                            channel_count=1,
                                            channel_format='string',
                                            source_id='portiloop1')  # TODO: replace this by unique device identifier
        self.lsl_outlet_markers = pylsl.StreamOutlet(lsl_markers_info)

        # Open the WAV file once, configure the ALSA PCM device to match it,
        # and cache the audio payload so a stimulation never reopens the file.
        with wave.open(str(self._sound), 'rb') as f:
            device = 'default'
            sample_width = f.getsampwidth()
            if sample_width == 1:
                # 8bit is unsigned in wav files
                pcm_format = alsaaudio.PCM_FORMAT_U8
            # Otherwise we assume signed data, little endian
            elif sample_width == 2:
                pcm_format = alsaaudio.PCM_FORMAT_S16_LE
            elif sample_width == 3:
                pcm_format = alsaaudio.PCM_FORMAT_S24_3LE
            elif sample_width == 4:
                pcm_format = alsaaudio.PCM_FORMAT_S32_LE
            else:
                raise ValueError('Unsupported format')

            # One period holds framerate // 8 frames, i.e. 1/8 s of audio.
            self.periodsize = f.getframerate() // 8
            self.pcm = alsaaudio.PCM(channels=f.getnchannels(), rate=f.getframerate(),
                                     format=pcm_format, periodsize=self.periodsize, device=device)
            # Each period is stored exactly once, in file order.
            self.wav_list = self._read_wav_chunks(f, self.periodsize)

    @staticmethod
    def _read_wav_chunks(wav_file, frames_per_chunk):
        """Read the remaining frames of ``wav_file`` as a list of byte chunks.

        Each chunk holds at most ``frames_per_chunk`` frames. Reading stops at
        the first empty read (end of file).
        """
        chunks = []
        while True:
            chunk = wav_file.readframes(frames_per_chunk)
            if not chunk:
                break
            chunks.append(chunk)
        return chunks

    def play_sound(self):
        '''
        Play the cached WAV data (``self.wav_list``) through the ALSA PCM device.
        '''
        for data in self.wav_list:
            self.pcm.write(data)

    def stimulate(self, detection_signal):
        """
        Handle a batch of detector outputs.

        Args:
            detection_signal: iterable of truthy/falsy values, one per detection slot.
        """
        for sig in detection_signal:
            # We detect a stimulation
            if sig:
                # Record time of stimulation
                ts = time.time()
                # Only stimulate if the previous detection is older than wait_t
                if ts - self.last_detected_ts > self.wait_t:
                    if self.delayer is not None:
                        # If we have a delayer, notify it
                        self.delayer.detected()
                        # Send the LSL marker for the fast stimulation
                        self.send_stimulation("FAST_STIM", False)
                    else:
                        self.send_stimulation("STIM", True)
                # Every detection resets the timer, stimulated or not
                self.last_detected_ts = ts

    def send_stimulation(self, lsl_text, sound):
        """
        Push an LSL marker and optionally play the stimulus sound.

        Args:
            lsl_text: str: marker text pushed to the 'Portiloop_stimuli' LSL outlet.
            sound: bool: if True, start the sound in a background thread. This is
                skipped if a sound thread is already running.
        """
        print(f"Stimulating with text: {lsl_text}")
        # Send lsl stimulation
        self.lsl_outlet_markers.push_sample([lsl_text])
        # Send sound to patient
        if sound:
            self._start_sound_thread()

    def _start_sound_thread(self):
        """Start playing the sound in a daemon thread, unless one is already running."""
        with self._lock:
            if self._thread is None:
                self._thread = Thread(target=self._t_sound, daemon=True)
                self._thread.start()

    def _t_sound(self):
        """Thread target: play the sound, then mark the player as idle.

        The ``finally`` clause ensures a playback error cannot leave ``_thread``
        set, which would otherwise block every future sound.
        """
        try:
            self.play_sound()
        finally:
            with self._lock:
                self._thread = None

    def test_stimulus(self):
        """Play the stimulus sound without sending an LSL marker."""
        self._start_sound_thread()

    def add_delayer(self, delayer):
        """
        Attach a delayer.

        The delayer's ``stimulate`` attribute is replaced by a callback that sends
        a "DELAY_STIM" marker and plays the sound.
        """
        self.delayer = delayer
        self.delayer.stimulate = lambda: self.send_stimulation("DELAY_STIM", True)
class SpindleTrainRealTimeStimulator(SleepSpindleRealTimeStimulator):
    """Stimulator that only reacts to spindles belonging to a spindle train.

    A detection counts as part of a train when it arrives more than ``wait_t``
    seconds, but less than ``max_spindle_train_t`` seconds, after the previous
    detection.
    """

    def __init__(self):
        # Assigned before the parent constructor runs, as in the original ordering.
        self.max_spindle_train_t = 6.0
        super().__init__()

    def stimulate(self, detection_signal):
        """
        Handle a batch of detector outputs, stimulating only on spindle trains.

        Args:
            detection_signal: iterable of truthy/falsy values, one per detection slot.
        """
        for detected in detection_signal:
            if not detected:
                continue
            now = time.time()
            since_previous = now - self.last_detected_ts
            in_train = self.wait_t < since_previous < self.max_spindle_train_t
            if in_train:
                if self.delayer is None:
                    self.send_stimulation("STIM", True)
                else:
                    # Let the delayer time the sound; only the fast marker goes out now.
                    self.delayer.detected()
                    self.send_stimulation("FAST_STIM", False)
            # Every detection resets the reference time, stimulated or not.
            self.last_detected_ts = now
class IsolatedSpindleRealTimeStimulator(SpindleTrainRealTimeStimulator):
    """Stimulator that only reacts to isolated spindles.

    A detection is considered isolated when more than ``max_spindle_train_t``
    seconds have elapsed since the previous detection.
    """

    def stimulate(self, detection_signal):
        """
        Handle a batch of detector outputs, stimulating only on isolated spindles.

        Args:
            detection_signal: iterable of truthy/falsy values, one per detection slot.
        """
        for detected in detection_signal:
            if not detected:
                continue
            now = time.time()
            isolated = now - self.last_detected_ts > self.max_spindle_train_t
            if isolated:
                if self.delayer is None:
                    self.send_stimulation("STIM", True)
                else:
                    # Let the delayer time the sound; only the fast marker goes out now.
                    self.delayer.detected()
                    self.send_stimulation("FAST_STIM", False)
            # Every detection resets the reference time, stimulated or not.
            self.last_detected_ts = now
# Class that delays stimulation so that it always lands on a signal peak or trough
class UpStateDelayer:
    """Delay a stimulation so that it lands on the next peak (or trough) of the signal.

    State machine (see ``States``):
        NO_SPINDLE --detected()--> BUFFERING --buffer full--> DELAYING
        --delay elapsed--> ``stimulate()`` callback --> NO_SPINDLE

    ``step`` measures buffering/delaying with wall-clock time, while
    ``step_timesteps`` counts samples at ``sample_freq`` instead.
    """

    # Set to True to plot each buffer and its detected peaks (debugging only).
    _DEBUG_PLOT = False

    def __init__(self, sample_freq, peak, time_to_buffer, stimulate=None):
        '''
        args:
            sample_freq: int -> Sampling frequency of signal in Hz
            peak: bool -> True to target signal peaks, False to target troughs
            time_to_buffer: float -> Time spent building the buffer, in seconds
            stimulate: callable or None -> called without arguments once the delay has elapsed
        '''
        self.sample_freq = sample_freq
        self.peak = peak
        self.buffer = []
        self.time_to_buffer = time_to_buffer
        self.stimulate = stimulate
        self.state = States.NO_SPINDLE
        # State-machine bookkeeping, initialised so nothing is read before being written.
        self.time_started = time.time()
        self.time_to_wait = -1
        self.delaying_counter = 0

    def step(self, point):
        '''
        Step the delayer using wall-clock time, adding a point to the buffer if necessary.

        Returns True if the stimulation is actually done during this step.
        '''
        if self.state == States.NO_SPINDLE:
            return False
        if self.state == States.BUFFERING:
            self.buffer.append(point)
            # Done buffering: move on to the delaying stage
            if time.time() - self.time_started >= self.time_to_buffer:
                self._start_delaying()
            return False
        if self.state == States.DELAYING:
            if time.time() - self.time_started >= self.time_to_wait:
                self._finish_delay()
                return True
        return False

    def step_timesteps(self, point):
        '''
        Step the delayer by counting samples, adding a point to the buffer if necessary.

        Returns True if the stimulation is actually done during this step.
        '''
        if self.state == States.NO_SPINDLE:
            return False
        if self.state == States.BUFFERING:
            self.buffer.append(point)
            # Done buffering once time_to_buffer seconds' worth of samples are collected
            if len(self.buffer) >= self.time_to_buffer * self.sample_freq:
                self._start_delaying()
            return False
        if self.state == States.DELAYING:
            self.delaying_counter += 1
            if self.delaying_counter >= self.time_to_wait * self.sample_freq:
                self._finish_delay()
                return True
        return False

    def _start_delaying(self):
        """Compute the delay from the buffer, clear the buffer and enter DELAYING."""
        self.time_to_wait = self.compute_time_to_wait()
        self.state = States.DELAYING
        self.buffer = []
        # Reference points for step() (wall clock) and step_timesteps() (sample count).
        self.time_started = time.time()
        self.delaying_counter = 0

    def _finish_delay(self):
        """Call the stimulation callback (if any) and return to NO_SPINDLE."""
        if self.stimulate is not None:
            self.stimulate()
        self.time_to_wait = -1
        self.state = States.NO_SPINDLE

    def detected(self):
        '''
        Notify the delayer that a spindle was detected.

        Starts buffering if idle; ignored while already buffering or delaying.
        '''
        if self.state == States.NO_SPINDLE:
            self.state = States.BUFFERING
            # Start of the wall-clock buffering window checked by step().
            self.time_started = time.time()

    def compute_time_to_wait(self):
        """
        Compute the time (in seconds) to wait before stimulating.

        The mean distance between consecutive peaks in the buffer estimates the
        oscillation period. The delay is that period minus the time already
        elapsed since the last peak in the buffer. When fewer than two peaks are
        found, the period cannot be estimated and a 0.1 s fallback is returned.
        The buffer itself is not modified.
        """
        signal = np.asarray(self.buffer)
        # To target valleys, search for peaks on the inverted signal
        if not self.peak:
            signal = -signal
        peaks, _ = find_peaks(signal, prominence=1)

        if self._DEBUG_PLOT:
            plt.figure()
            plt.plot(signal)
            for peak in peaks:
                plt.axvline(x=peak)
            plt.plot(np.zeros_like(signal), "--", color="gray")
            plt.show()

        # sample_freq / 10 samples converted to seconds, i.e. 0.1 s
        fallback = (self.sample_freq / 10) * (1.0 / self.sample_freq)
        if len(peaks) == 0:
            print("No peaks found, increase buffer size")
            return fallback
        if len(peaks) == 1:
            # np.mean(np.diff(...)) would be NaN here, and a NaN delay never elapses.
            print("Only one peak found, increase buffer size")
            return fallback

        # Average distance between consecutive peaks, in samples
        avg_dist = np.mean(np.diff(peaks))
        since_last_peak = len(signal) - peaks[-1]
        if avg_dist < since_last_peak:
            print("Average distance between peaks is smaller than the time to last peak, decrease buffer size")
            return since_last_peak * (1.0 / self.sample_freq)
        return (avg_dist - since_last_peak) * (1.0 / self.sample_freq)
class States(Enum):
    """States of the UpStateDelayer state machine."""

    # Idle: waiting for a spindle detection.
    NO_SPINDLE = 0
    # Collecting signal samples used to estimate the oscillation period.
    BUFFERING = 1
    # Waiting for the computed delay before stimulating.
    DELAYING = 2
if __name__ == "__main__":
    # Quick visual check: plot a noisy sine wave.
    # numpy (np) and matplotlib.pyplot (plt) come from the module-level imports.
    freq = 250  # samples per second
    # Not named `time`: that would rebind the module-level `time` import used by the classes above.
    duration = 10
    x = np.linspace(0, duration * np.pi, num=duration * freq)
    noise = np.random.normal(scale=1, size=x.size)
    y = np.sin(x) + noise
    plt.plot(x, y)
    plt.show()