Spaces:
Running
Running
| # Authors: The MNE-Python contributors. | |
| # License: BSD-3-Clause | |
| # Copyright the MNE-Python contributors. | |
| import collections.abc as abc | |
| from functools import partial | |
| import numpy as np | |
| from .._fiff.meas_info import Info, create_info | |
| from .._fiff.pick import _picks_to_idx | |
| from ..filter import filter_data | |
| from ..utils import ( | |
| _validate_type, | |
| fill_doc, | |
| logger, | |
| ) | |
| from ._covs_ged import _ssd_estimate | |
| from ._mod_ged import _get_spectral_ratio, _ssd_mod | |
| from .base import _GEDTransformer | |
@fill_doc
class SSD(_GEDTransformer):
    """
    Signal decomposition using the Spatio-Spectral Decomposition (SSD).

    SSD seeks to maximize the power at a frequency band of interest while
    simultaneously minimizing it at the flanking (surrounding) frequency bins
    (considered noise). It extremizes the covariance matrices associated with
    signal and noise :footcite:`NikulinEtAl2011`.

    SSD can either be used as a dimensionality reduction method or a
    'denoised' low rank factorization method :footcite:`HaufeEtAl2014b`.

    Parameters
    ----------
    %(info_not_none)s Must match the input data.
    filt_params_signal : dict
        Filtering for the frequencies of interest. Must contain the keys
        ``"l_freq"`` and ``"h_freq"``.
    filt_params_noise : dict
        Filtering for the frequencies of non-interest. Must contain the keys
        ``"l_freq"`` and ``"h_freq"``, and its band must enclose the signal
        band.
    reg : float | str | None (default)
        Which covariance estimator to use.
        If not None (same as 'empirical'), allow regularization for covariance
        estimation. If float, shrinkage is used (0 <= shrinkage <= 1). For str
        options, reg will be passed to method :func:`mne.compute_covariance`.
    n_components : int | None (default None)
        The number of components to extract from the signal.
        If None, the number of components equal to the rank of the data are
        returned (see ``rank``).
    picks : array of int | None (default None)
        The indices of good channels.
    sort_by_spectral_ratio : bool (default True)
        If set to True, the components are sorted according to the spectral
        ratio.
        See Eq. (24) in :footcite:`NikulinEtAl2011`.
    return_filtered : bool (default False)
        If return_filtered is True, data is bandpassed and projected onto the
        SSD components.
    n_fft : int (default None)
        If sort_by_spectral_ratio is set to True, then the SSD sources will be
        sorted according to their spectral ratio which is calculated based on
        :func:`mne.time_frequency.psd_array_welch`. The n_fft parameter sets the
        length of FFT used. The default (None) will use 1 second of data.
        See :func:`mne.time_frequency.psd_array_welch` for more information.
    cov_method_params : dict | None (default None)
        As in :class:`mne.decoding.SPoC`
        The default is None.
    restr_type : "restricting" | "whitening" | "ssd" | None
        Restricting transformation for covariance matrices before performing
        generalized eigendecomposition.
        If "restricting" only restriction to the principal subspace of signal_cov
        will be performed.
        If "whitening", covariance matrices will be additionally rescaled according
        to the whitening for the signal_cov.
        If "ssd", simplified version of "whitening" is performed.
        If None, no restriction will be applied. Defaults to "whitening".

        .. versionadded:: 1.11
    rank : None | dict | 'info' | 'full'
        As in :class:`mne.decoding.SPoC`
        This controls the rank computation that can be read from the
        measurement info or estimated from the data, which determines the
        maximum possible number of components.
        See Notes of :func:`mne.compute_rank` for details.
        We recommend to use 'full' when working with epoched data.

    Attributes
    ----------
    filters_ : array, shape (``n_channels or less``, n_channels)
        The spatial filters to be multiplied with the signal.
    patterns_ : array, shape (``n_channels or less``, n_channels)
        The patterns for reconstructing the signal from the filtered data.
    sfreq_ : float
        Sampling frequency, taken from ``info`` during :meth:`fit`.
    freqs_signal_ : tuple of float
        ``(l_freq, h_freq)`` of the signal band, set during :meth:`fit`.
    freqs_noise_ : tuple of float
        ``(l_freq, h_freq)`` of the noise band, set during :meth:`fit`.
    n_fft_ : int
        FFT length actually used; ``n_fft`` (or one second of data) capped at
        the number of time samples seen during :meth:`fit`.
    picks_ : array of int
        Channel indices resolved from ``picks`` during :meth:`fit`.

    References
    ----------
    .. footbibliography::
    """

    def __init__(
        self,
        info,
        filt_params_signal,
        filt_params_noise,
        reg=None,
        n_components=None,
        picks=None,
        sort_by_spectral_ratio=True,
        return_filtered=False,
        n_fft=None,
        cov_method_params=None,
        *,
        restr_type="whitening",
        rank=None,
    ):
        """Initialize instance."""
        # Constructor arguments are stored unmodified; validation is deferred
        # to ``_validate_params``, which runs at fit time.
        self.info = info
        self.filt_params_signal = filt_params_signal
        self.filt_params_noise = filt_params_noise
        self.reg = reg
        self.n_components = n_components
        self.picks = picks
        self.sort_by_spectral_ratio = sort_by_spectral_ratio
        self.return_filtered = return_filtered
        self.n_fft = n_fft
        self.cov_method_params = cov_method_params
        self.restr_type = restr_type
        self.rank = rank

        # Bind the SSD-specific settings to the covariance estimator so the
        # parent class can invoke it without knowing about them.
        cov_callable = partial(
            _ssd_estimate,
            reg=reg,
            cov_method_params=cov_method_params,
            info=info,
            picks=picks,
            n_fft=n_fft,
            filt_params_signal=filt_params_signal,
            filt_params_noise=filt_params_noise,
            rank=rank,
            sort_by_spectral_ratio=sort_by_spectral_ratio,
        )
        super().__init__(
            n_components=n_components,
            cov_callable=cov_callable,
            mod_ged_callable=_ssd_mod,
            restr_type=restr_type,
        )

    def _validate_params(self, X):
        """Validate constructor arguments and derive fit-time attributes.

        Sets ``sfreq_``, ``freqs_signal_``, ``freqs_noise_`` and ``n_fft_``.

        Parameters
        ----------
        X : array, shape ([n_epochs, ]n_channels, n_times)
            The data being fitted; only its last dimension (number of time
            samples) is used, to cap ``n_fft_``.

        Raises
        ------
        ValueError
            If a filter dict lacks ``l_freq``/``h_freq``, if the signal band
            is not contained in the noise band, or if ``info`` (restricted to
            ``picks``) holds more than one channel type.
        """
        if isinstance(self.info, float):  # special case, mostly for testing
            self.sfreq_ = self.info
        else:
            _validate_type(self.info, Info, "info")
            self.sfreq_ = self.info["sfreq"]

        # Both filter dicts need numeric band edges; signal is checked first.
        filt_params = {
            "signal": self.filt_params_signal,
            "noise": self.filt_params_noise,
        }
        for key, params in filt_params.items():
            for freq_name in ("l_freq", "h_freq"):
                if freq_name not in params:
                    raise ValueError(
                        f"{freq_name} must be defined in filter parameters for {key}"
                    )
                val = params[freq_name]
                if not isinstance(val, int | float):
                    _validate_type(val, ("numeric",), f"{key} {freq_name}")

        # check freq bands
        if (
            self.filt_params_noise["l_freq"] > self.filt_params_signal["l_freq"]
            or self.filt_params_signal["h_freq"] > self.filt_params_noise["h_freq"]
        ):
            raise ValueError(
                "Wrongly specified frequency bands!\n"
                "The signal band-pass must be within the noise "
                "band-pass!"
            )
        self.freqs_signal_ = (
            self.filt_params_signal["l_freq"],
            self.filt_params_signal["h_freq"],
        )
        self.freqs_noise_ = (
            self.filt_params_noise["l_freq"],
            self.filt_params_noise["h_freq"],
        )

        _validate_type(self.sort_by_spectral_ratio, (bool,), "sort_by_spectral_ratio")
        _validate_type(self.n_fft, ("numeric", None), "n_fft")
        # Default FFT length is one second of data, never longer than the
        # number of available time samples.
        self.n_fft_ = min(
            int(self.n_fft if self.n_fft is not None else self.sfreq_),
            X.shape[-1],
        )
        _validate_type(self.return_filtered, (bool,), "return_filtered")

        if isinstance(self.info, Info):
            ch_types = self.info.get_channel_types(picks=self.picks, unique=True)
            if len(ch_types) > 1:
                raise ValueError(
                    "At this point SSD only supports fitting "
                    f"single channel types. Your info has {len(ch_types)} types."
                )
        _validate_type(self.cov_method_params, (abc.Mapping, None), "cov_method_params")

    def _check_X(self, X, *, y=None, fit=False):
        """Check input data.

        Delegates to the parent's ``_check_data`` and then verifies that the
        channel dimension (second to last axis) matches ``info["nchan"]`` when
        ``info`` is an :class:`mne.Info`.

        Raises
        ------
        ValueError
            If the number of channels in ``X`` differs from ``info["nchan"]``.
        """
        X = self._check_data(X, y=y, fit=fit, atleast_3d=False)
        n_chan = X.shape[-2]
        if isinstance(self.info, Info) and n_chan != self.info["nchan"]:
            raise ValueError(
                "Info must match the input data. "
                f"Found {n_chan} channels but expected {self.info['nchan']}."
            )
        return X

    def fit(self, X, y=None):
        """Estimate the SSD decomposition on raw or epoched data.

        Parameters
        ----------
        X : array, shape ([n_epochs, ]n_channels, n_times)
            The input data from which to estimate the SSD. Either 2D array
            obtained from continuous data or 3D array obtained from epoched
            data.
        y : None
            Ignored; exists for compatibility with scikit-learn pipelines.

        Returns
        -------
        self : instance of SSD
            Returns the modified instance.
        """
        X = self._check_X(X, y=y, fit=True)
        self._validate_params(X)
        if isinstance(self.info, Info):
            info = self.info
        else:
            # ``info`` was given as a bare sampling frequency: build a
            # placeholder Info (all channels typed "eeg") to resolve picks.
            info = create_info(X.shape[-2], self.sfreq_, ch_types="eeg")
        self.picks_ = _picks_to_idx(info, self.picks, none="data", exclude="bads")
        super().fit(X, y)
        logger.info("Done.")
        return self

    def transform(self, X):
        """Estimate epochs sources given the SSD filters.

        Parameters
        ----------
        X : array, shape ([n_epochs, ]n_channels, n_times)
            The input data from which to estimate the SSD. Either 2D array
            obtained from continuous data or 3D array obtained from epoched
            data.

        Returns
        -------
        X_ssd : array, shape ([n_epochs, ]n_components, n_times)
            The processed data. The epoch axis is present if and only if it
            was present in ``X``; length-1 axes are never dropped otherwise.
        """
        X = self._check_X(X)
        # For the case where n_epochs dimension is absent: add a temporary
        # leading epoch axis so the rest of the code always works on 3D data.
        added_epoch_axis = X.ndim == 2
        if added_epoch_axis:
            X = np.expand_dims(X, axis=0)
        X_aux = X[..., self.picks_, :]
        if self.return_filtered:
            X_aux = filter_data(X_aux, self.sfreq_, **self.filt_params_signal)
        X_ssd = super().transform(X_aux)
        # Remove only the axis added above. A blanket ``squeeze()`` would also
        # drop a genuine single-epoch, single-component or single-sample axis
        # and break the documented output shape (and the matmul in ``apply``).
        # NOTE(review): assumes the parent ``transform`` keeps the leading
        # epoch axis of its 3D input -- confirm against ``_GEDTransformer``.
        if added_epoch_axis:
            X_ssd = X_ssd[0]
        return X_ssd

    def fit_transform(self, X, y=None, **fit_params):
        """Fit SSD to data, then transform it.

        Fits transformer to ``X`` and ``y`` with optional parameters ``fit_params``, and
        returns a transformed version of ``X``.

        Parameters
        ----------
        X : array, shape ([n_epochs, ]n_channels, n_times)
            The input data from which to estimate the SSD. Either 2D array obtained from
            continuous data or 3D array obtained from epoched data.
        y : None
            Ignored; exists for compatibility with scikit-learn pipelines.
        **fit_params : dict
            Additional fitting parameters passed to the :meth:`mne.decoding.SSD.fit`
            method. Not used for this class.

        Returns
        -------
        X_ssd : array, shape ([n_epochs, ]n_components, n_times)
            The processed data.
        """
        # use parent TransformerMixin method but with custom docstring
        return super().fit_transform(X, y=y, **fit_params)

    def get_spectral_ratio(self, ssd_sources):
        """Get the spectral signal-to-noise ratio for each spatial filter.

        Spectral ratio measure for best n_components selection
        See :footcite:`NikulinEtAl2011`, Eq. (24).

        Parameters
        ----------
        ssd_sources : array
            Data projected to SSD space.

        Returns
        -------
        spec_ratio : array, shape (n_channels)
            Array with the spectral ratio value for each component.
        sorter_spec : array, shape (n_channels)
            Array of indices for sorting spec_ratio.

        References
        ----------
        .. footbibliography::
        """
        # Uses the sampling rate, FFT length and band edges stored by ``fit``.
        spec_ratio, sorter_spec = _get_spectral_ratio(
            ssd_sources=ssd_sources,
            sfreq=self.sfreq_,
            n_fft=self.n_fft_,
            freqs_signal=self.freqs_signal_,
            freqs_noise=self.freqs_noise_,
        )
        return spec_ratio, sorter_spec

    def inverse_transform(self, X=None):
        """Not implemented yet.

        Parameters
        ----------
        X : array | None
            Ignored. Accepted so that the conventional call
            ``inverse_transform(X)`` reaches the ``NotImplementedError`` below
            instead of failing with a ``TypeError``.

        Raises
        ------
        NotImplementedError
            Always.
        """
        raise NotImplementedError("inverse_transform is not yet available.")

    def apply(self, X):
        """Remove selected components from the signal.

        This procedure will reconstruct M/EEG signals from which the dynamics
        described by the excluded components is subtracted
        (denoised by low-rank factorization).
        See :footcite:`HaufeEtAl2014b` for more information.

        .. note:: Unlike in other classes with an apply method,
           only NumPy arrays are supported (not instances of MNE objects).

        Parameters
        ----------
        X : array, shape ([n_epochs, ]n_channels, n_times)
            The input data from which to estimate the SSD. Either 2D array
            obtained from continuous data or 3D array obtained from epoched
            data.

        Returns
        -------
        X : array, shape ([n_epochs, ]n_channels, n_times)
            The processed data.
        """
        X_ssd = self.transform(X)
        # Keep the first ``n_components`` patterns (all of them when it is
        # None) and transpose so the matmul maps component space back to
        # channel space; matmul broadcasts over a leading epoch axis.
        pick_patterns = self.patterns_[: self.n_components].T
        X = pick_patterns @ X_ssd
        return X