Spaces:
Running
Running
| # Authors: The MNE-Python contributors. | |
| # License: BSD-3-Clause | |
| # Copyright the MNE-Python contributors. | |
| import warnings | |
| from functools import partial | |
| import numpy as np | |
| from scipy.signal import spectrogram | |
| from ..parallel import parallel_func | |
| from ..utils import _check_option, _ensure_int, logger, verbose, warn | |
| from ..utils.numerics import _mask_to_onsets_offsets | |
| # adapted from SciPy | |
| # https://github.com/scipy/scipy/blob/f71e7fad717801c4476312fe1e23f2dfbb4c9d7f/scipy/signal/_spectral_py.py#L2019 # noqa: E501 | |
def _median_biases(n):
    """Return median bias-correction factors for 0 through ``n`` averaged terms.

    This is a vectorized form of SciPy's scalar ``_median_bias``::

        def _median_bias(n):
            ii_2 = 2 * np.arange(1., (n - 1) // 2 + 1)
            return 1 + np.sum(1. / (ii_2 + 1) - 1. / ii_2)

    Parameters
    ----------
    n : int
        Largest number of terms for which a bias factor is needed.

    Returns
    -------
    biases : ndarray, shape (n + 1,)
        ``biases[k]`` is the factor for a median taken over ``k`` terms.
        Entries 0, 1 and 2 are always 1 (the SciPy sum is empty for them).
    """
    out = np.ones(n + 1)
    if n < 3:
        # No correction terms exist below three samples.
        return out
    # The scalar formula sums over (k - 1) // 2 even integers 2, 4, ...:
    #
    #     k=3: [2]        k=4: [2]
    #     k=5: [2, 4]     k=6: [2, 4]
    #
    # so one cumulative sum provides every partial sum that is needed.
    evens = 2 * np.arange(1, (n - 1) // 2 + 1)
    running = 1 + np.cumsum(1.0 / (evens + 1) - 1.0 / evens)
    # k = 3, 4 share running[0]; k = 5, 6 share running[1]; and so on.
    which = np.arange(2, n) // 2 - 1
    out[3:] = running[which]
    return out
def _decomp_aggregate_mask(epoch, func, average, freq_sl):
    """Run ``func`` on ``epoch`` and optionally collapse the last axis.

    Parameters
    ----------
    epoch : array
        Data handed to ``func`` unchanged.
    func : callable
        Must return a 3-tuple; only the third element (the spectrogram) is
        used.
    average : "mean" | "median" | None
        How to aggregate over the last axis. Any other value (including
        None) returns the un-aggregated spectrogram.
    freq_sl : slice
        Selection applied to the second-to-last (frequency) axis.

    Returns
    -------
    spect : ndarray
        The frequency-restricted spectrogram, aggregated over its last axis
        when ``average`` is "mean" or "median".
    """
    _freqs, _times, sxx = func(epoch)
    sxx = sxx[..., freq_sl, :]
    # Aggregate here (per epoch) so the full spectrogram need not be kept.
    if average == "median":
        # The bias depends on how many non-NaN values entered each median.
        n_valid = np.sum(~np.isnan(sxx), axis=-1)
        correction = _median_biases(sxx.shape[-1])[n_valid]
        return np.nanmedian(sxx, axis=-1) / correction
    if average == "mean":
        return np.nanmean(sxx, axis=-1)
    return sxx
def _spect_func(epoch, func, freq_sl, average, *, output="power"):
    """Compute the (optionally aggregated) spectrogram of ``epoch``.

    Inputs larger than 10e6 bytes are processed one 1-D slice at a time along
    the last axis to limit peak memory; smaller inputs are processed in a
    single call, which avoids the overhead of many Python-level calls.

    ``output`` is accepted for call-signature compatibility and is not used
    inside this function.
    """
    # Splitting costs speed, so only do it when the input is large.
    # Eventually we might want to write (really, go back to) our own
    # spectrogram implementation that, if possible, averages after each
    # transform, but that would incur a lot of Python-call overhead.
    shared = {"func": func, "average": average, "freq_sl": freq_sl}
    if epoch.nbytes <= 10e6:
        return _decomp_aggregate_mask(epoch, **shared)
    return np.apply_along_axis(_decomp_aggregate_mask, -1, epoch, **shared)
def _check_nfft(n, n_fft, n_per_seg, n_overlap):
    """Validate and reconcile ``n_fft``, ``n_per_seg`` and ``n_overlap``.

    Parameters
    ----------
    n : int
        Signal length in samples.
    n_fft : int
        FFT length.
    n_per_seg : int | None
        Requested segment length; None means "use ``n_fft``".
    n_overlap : int
        Requested overlap between segments.

    Returns
    -------
    n_fft, n_per_seg, n_overlap : int
        ``n_per_seg`` is clipped to at most ``n_fft`` and at most ``n``; the
        other two values are returned unchanged.

    Raises
    ------
    ValueError
        If ``n_per_seg`` is None and ``n_fft > n``, or if ``n_overlap`` is
        not smaller than the clipped ``n_per_seg``.
    """
    if n_per_seg is None:
        if n_fft > n:
            msg = (
                "If n_per_seg is None n_fft is not allowed to be > n_times. "
                "If you want zero-padding, you have to set n_per_seg to "
                f"relevant length. Got n_fft of {n_fft} while signal length "
                f"is {n}."
            )
            raise ValueError(msg)
        n_per_seg = n_fft
    # A segment can be no longer than the FFT length or the signal itself.
    n_per_seg = min(n_per_seg, n_fft, n)
    if not n_overlap < n_per_seg:
        msg = (
            "n_overlap cannot be greater than n_per_seg (or n_fft). Got "
            f"n_overlap of {n_overlap} while n_per_seg is {n_per_seg}."
        )
        raise ValueError(msg)
    return n_fft, n_per_seg, n_overlap
@verbose
def psd_array_welch(
    x,
    sfreq,
    fmin=0,
    fmax=np.inf,
    n_fft=256,
    n_overlap=0,
    n_per_seg=None,
    n_jobs=None,
    average="mean",
    window="hamming",
    remove_dc=True,
    *,
    output="power",
    verbose=None,
):
    """Compute power spectral density (PSD) using Welch's method.

    Welch's method is described in :footcite:t:`Welch1967`.

    Parameters
    ----------
    x : array, shape=(..., n_times)
        The data to compute PSD from.
    sfreq : float
        The sampling frequency.
    fmin : float
        The lower frequency of interest.
    fmax : float
        The upper frequency of interest.
    n_fft : int
        The length of FFT used, must be ``>= n_per_seg`` (default: 256).
        The segments will be zero-padded if ``n_fft > n_per_seg``.
    n_overlap : int
        The number of points of overlap between segments. Must be smaller
        than the effective ``n_per_seg`` (after it has been clipped to
        ``n_fft`` and ``n_times``), otherwise a ``ValueError`` is raised.
        The default value is 0.
    n_per_seg : int | None
        Length of each Welch segment (windowed according to ``window``).
        Defaults to None, which sets n_per_seg equal to n_fft.
    %(n_jobs)s
    %(average_psd)s

        .. versionadded:: 0.19.0
    %(window_psd)s

        .. versionadded:: 0.22.0
    %(remove_dc)s
    output : str
        The format of the returned ``psds`` array, ``'complex'`` or
        ``'power'``:

        * ``'power'`` : the power spectral density is returned.
        * ``'complex'`` : the complex fourier coefficients are returned per
          window.

        .. versionadded:: 1.4.0
    %(verbose)s

    Returns
    -------
    psds : ndarray, shape (..., n_freqs) or (..., n_freqs, n_segments)
        The power spectral densities. If ``average='mean'`` or
        ``average='median'``, the returned array will have the same shape
        as the input data plus an additional frequency dimension.
        If ``average=None``, the returned array will have the same shape as
        the input data plus two additional dimensions corresponding to
        frequencies and the unaggregated segments, respectively.
    freqs : ndarray, shape (n_freqs,)
        The frequencies.

    Notes
    -----
    Non-finite samples are handled as follows (``x`` is first flattened to
    two dimensions, ``(n_rows, n_times)``):

    * If NaNs occur at exactly the same time samples in every row, the data
      are split into the NaN-free spans, each span is analyzed separately,
      and the per-span results are combined with a weighted average whose
      weights are the number of samples used from each span.
    * Otherwise, any row containing NaN within the analyzed samples, and any
      row containing Inf within the analyzed samples, triggers a warning and
      its PSD is returned as NaN.

    .. versionadded:: 0.14.0

    References
    ----------
    .. footbibliography::
    """
    _check_option("average", average, (None, False, "mean", "median"))
    _check_option("output", output, ("power", "complex"))
    detrend = "constant" if remove_dc else False
    mode = "complex" if output == "complex" else "psd"
    n_fft = _ensure_int(n_fft, "n_fft")
    n_overlap = _ensure_int(n_overlap, "n_overlap")
    if n_per_seg is not None:
        n_per_seg = _ensure_int(n_per_seg, "n_per_seg")
    if average is False:
        average = None

    # Flatten all leading dimensions; the original ones are restored at the end
    dshape = x.shape[:-1]
    n_times = x.shape[-1]
    x = x.reshape(-1, n_times)

    # Prep the PSD
    n_fft, n_per_seg, n_overlap = _check_nfft(n_times, n_fft, n_per_seg, n_overlap)
    win_size = n_fft / float(sfreq)
    logger.info(f"Effective window size : {win_size:0.3f} (s)")
    freqs = np.arange(n_fft // 2 + 1, dtype=float) * (sfreq / n_fft)
    freq_mask = (freqs >= fmin) & (freqs <= fmax)
    if not freq_mask.any():
        raise ValueError(f"No frequencies found between fmin={fmin} and fmax={fmax}")
    # Contiguous slice from the first to the last in-band frequency bin
    freq_sl = slice(*(np.where(freq_mask)[0][[0, -1]] + [0, 1]))
    del freq_mask
    freqs = freqs[freq_sl]

    # Number of samples actually covered by whole segments; trailing samples
    # that do not fill a segment are not analyzed.
    step = max(int(n_per_seg) - int(n_overlap), 1)
    if n_times >= n_per_seg:
        n_segments = 1 + (n_times - n_per_seg) // step
        analyzed_end = step * (n_segments - 1) + n_per_seg
    else:
        n_segments = 0
        analyzed_end = 0

    nan_mask_full = np.isnan(x)
    nan_present = nan_mask_full.any()
    # "Aligned" means every row has its NaNs at exactly the same samples.
    # Compare the boolean masks directly (no float upcast of the full mask).
    aligned_nan = bool(nan_present and (nan_mask_full == nan_mask_full[0]).all())

    if analyzed_end > 0:
        # Inf always counts as non-finite per-channel
        nonfinite_mask = np.isinf(x[..., :analyzed_end])
        # NaNs count per-channel only if NOT aligned (i.e., not annotations)
        if nan_present and not aligned_nan:
            nonfinite_mask |= nan_mask_full[..., :analyzed_end]
        bad_ch = nonfinite_mask.any(axis=-1)
    else:
        bad_ch = np.zeros(x.shape[0], dtype=bool)
    if bad_ch.any():
        warn(
            "Non-finite values (NaN/Inf) detected in some channels; PSD for "
            "those channels will be NaN.",
        )
        # avoid downstream NumPy warnings by zeroing bad channels;
        # will overwrite their PSD rows with NaN at the end
        x = x.copy()
        x[bad_ch] = 0.0

    # Parallelize across first N-1 dimensions
    logger.debug(
        f"Spectrogram using {n_fft}-point FFT on {n_per_seg} samples with "
        f"{n_overlap} overlap and {window} window"
    )
    parallel, my_spect_func, n_jobs = parallel_func(_spect_func, n_jobs=n_jobs)
    _func = partial(
        spectrogram,
        detrend=detrend,
        noverlap=n_overlap,
        nperseg=n_per_seg,
        nfft=n_fft,
        fs=sfreq,
        window=window,
        mode=mode,
    )
    if nan_present and aligned_nan:
        # Aligned NaNs across channels → treat as bad annotations.
        # All rows share one mask, so the first row locates the good spans.
        t_onsets, t_offsets = _mask_to_onsets_offsets(~nan_mask_full[0])
        x_splits = [x[..., t_ons:t_off] for t_ons, t_off in zip(t_onsets, t_offsets)]
        # weights reflect the number of samples used from each span. For spans
        # longer than `n_per_seg`, trailing samples may be discarded. For spans
        # shorter than `n_per_seg`, the wrapped function
        # (`scipy.signal.spectrogram`) automatically reduces `n_per_seg` to
        # match the span length (with a warning).
        span_lengths = [span.shape[-1] for span in x_splits]
        weights = [
            w if w < n_per_seg else w - ((w - n_overlap) % step) for w in span_lengths
        ]
        agg_func = partial(np.average, weights=weights)
        if n_jobs > 1:
            logger.info(
                f"Data split into {len(x_splits)} (probably unequal) chunks due to "
                '"bad_*" annotations. Parallelization may be sub-optimal.'
            )
        if (np.array(span_lengths) < n_per_seg).any():
            logger.info(
                "At least one good data span is shorter than n_per_seg, and will be "
                "analyzed with a shorter window than the rest of the file."
            )

        def func(*args, **kwargs):
            # swallow SciPy warnings caused by short good data spans
            with warnings.catch_warnings():
                warnings.filterwarnings(
                    action="ignore",
                    module="scipy",
                    category=UserWarning,
                    message=r"nperseg = \d+ is greater than input length",
                )
                return _func(*args, **kwargs)

    else:
        # Either no NaNs, or NaNs are not aligned across channels.
        if nan_present and not aligned_nan:
            logger.info(
                "NaN masks are not aligned across channels; treating NaNs as "
                "per-channel contamination."
            )
        # Split rows (not time) across jobs; drop empty chunks
        x_splits = [arr for arr in np.array_split(x, n_jobs) if arr.size != 0]
        agg_func = np.concatenate
        func = _func

    f_spect = parallel(
        my_spect_func(d, func=func, freq_sl=freq_sl, average=average, output=output)
        for d in x_splits
    )
    psds = agg_func(f_spect, axis=0)
    shape = dshape + (len(freqs),)
    if average is None:
        # keep the trailing (unaggregated) segment axis
        shape = shape + (-1,)
    if bad_ch.any():
        psds[bad_ch] = np.nan
    psds = psds.reshape(shape)
    return psds, freqs