# Authors: The MNE-Python contributors.
# License: BSD-3-Clause
# Copyright the MNE-Python contributors.

import collections.abc as abc
from functools import partial

import numpy as np

from .._fiff.meas_info import Info, create_info
from .._fiff.pick import _picks_to_idx
from ..filter import filter_data
from ..utils import (
    _validate_type,
    fill_doc,
    logger,
)
from ._covs_ged import _ssd_estimate
from ._mod_ged import _get_spectral_ratio, _ssd_mod
from .base import _GEDTransformer


@fill_doc
class SSD(_GEDTransformer):
    """
    Signal decomposition using the Spatio-Spectral Decomposition (SSD).

    SSD seeks to maximize the power at a frequency band of interest while
    simultaneously minimizing it at the flanking (surrounding) frequency bins
    (considered noise). It extremizes the covariance matrices associated with
    signal and noise :footcite:`NikulinEtAl2011`.

    SSD can either be used as a dimensionality reduction method or a
    'denoised' low rank factorization method :footcite:`HaufeEtAl2014b`.

    Parameters
    ----------
    %(info_not_none)s Must match the input data.
    filt_params_signal : dict
        Filtering for the frequencies of interest.
    filt_params_noise : dict
        Filtering for the frequencies of non-interest.
    reg : float | str | None (default)
        Which covariance estimator to use.
        ``None`` is the same as ``'empirical'``; any other value allows
        regularization for covariance estimation.
        If float, shrinkage is used (0 <= shrinkage <= 1).
        For str options, reg will be passed to method
        :func:`mne.compute_covariance`.
    n_components : int | None (default None)
        The number of components to extract from the signal.
        If None, the number of components equal to the rank of the data are
        returned (see ``rank``).
    picks : array of int | None (default None)
        The indices of good channels.
    sort_by_spectral_ratio : bool (default True)
        If set to True, the components are sorted according to the spectral
        ratio.
        See Eq. (24) in :footcite:`NikulinEtAl2011`.
    return_filtered : bool (default False)
        If return_filtered is True, data is bandpassed and projected onto the
        SSD components.
    n_fft : int (default None)
        If sort_by_spectral_ratio is set to True, then the SSD sources will be
        sorted according to their spectral ratio which is calculated based on
        :func:`mne.time_frequency.psd_array_welch`. The n_fft parameter sets the
        length of FFT used. The default (None) will use 1 second of data.
        See :func:`mne.time_frequency.psd_array_welch` for more information.
    cov_method_params : dict | None (default None)
        As in :class:`mne.decoding.SPoC`
        The default is None.
    restr_type : "restricting" | "whitening" | "ssd" | None
        Restricting transformation for covariance matrices before performing
        generalized eigendecomposition.
        If "restricting" only restriction to the principal subspace of signal_cov
        will be performed.
        If "whitening", covariance matrices will be additionally rescaled according
        to the whitening for the signal_cov.
        If "ssd", simplified version of "whitening" is performed.
        If None, no restriction will be applied. Defaults to "whitening".

        .. versionadded:: 1.11
    rank : None | dict | 'info' | 'full'
        As in :class:`mne.decoding.SPoC`
        This controls the rank computation that can be read from the
        measurement info or estimated from the data, which determines the
        maximum possible number of components.
        See Notes of :func:`mne.compute_rank` for details.
        We recommend to use 'full' when working with epoched data.

    Attributes
    ----------
    filters_ : array, shape (``n_channels or less``, n_channels)
        The spatial filters to be multiplied with the signal.
    patterns_ : array, shape (``n_channels or less``, n_channels)
        The patterns for reconstructing the signal from the filtered data.

    References
    ----------
    .. footbibliography::
    """

    def __init__(
        self,
        info,
        filt_params_signal,
        filt_params_noise,
        reg=None,
        n_components=None,
        picks=None,
        sort_by_spectral_ratio=True,
        return_filtered=False,
        n_fft=None,
        cov_method_params=None,
        *,
        restr_type="whitening",
        rank=None,
    ):
        """Initialize instance."""
        # Parameters are stored exactly as given; they are only checked in
        # ``_validate_params``, which ``fit`` calls.
        self.info = info
        self.filt_params_signal = filt_params_signal
        self.filt_params_noise = filt_params_noise
        self.reg = reg
        self.n_components = n_components
        self.picks = picks
        self.sort_by_spectral_ratio = sort_by_spectral_ratio
        self.return_filtered = return_filtered
        self.n_fft = n_fft
        self.cov_method_params = cov_method_params
        self.restr_type = restr_type
        self.rank = rank

        # Pre-bind the SSD settings to the covariance estimator that is handed
        # over to the parent class.
        cov_callable = partial(
            _ssd_estimate,
            reg=reg,
            cov_method_params=cov_method_params,
            info=info,
            picks=picks,
            n_fft=n_fft,
            filt_params_signal=filt_params_signal,
            filt_params_noise=filt_params_noise,
            rank=rank,
            sort_by_spectral_ratio=sort_by_spectral_ratio,
        )
        super().__init__(
            n_components=n_components,
            cov_callable=cov_callable,
            mod_ged_callable=_ssd_mod,
            restr_type=restr_type,
        )

    def _validate_params(self, X):
        """Validate the constructor parameters and derive fitted settings.

        Sets ``sfreq_``, ``freqs_signal_``, ``freqs_noise_`` and ``n_fft_``.

        Parameters
        ----------
        X : array
            The data being fitted; only the length of its last axis is used,
            to cap ``n_fft_``.

        Raises
        ------
        ValueError
            If ``l_freq`` or ``h_freq`` is missing from either filter
            dictionary, if the signal band is not contained in the noise band,
            or if ``info`` selects more than one channel type.
        """
        if isinstance(self.info, float):  # special case, mostly for testing
            self.sfreq_ = self.info
        else:
            _validate_type(self.info, Info, "info")
            self.sfreq_ = self.info["sfreq"]

        filt_params = {
            "signal": self.filt_params_signal,
            "noise": self.filt_params_noise,
        }
        # Both band edges must be present and numeric in both dictionaries.
        for key, params in filt_params.items():
            for name in ("l_freq", "h_freq"):
                if name not in params:
                    raise ValueError(
                        f"{name} must be defined in filter parameters for {key}"
                    )
                val = params[name]
                if not isinstance(val, int | float):
                    _validate_type(val, ("numeric",), f"{key} {name}")

        # check freq bands
        if (
            self.filt_params_noise["l_freq"] > self.filt_params_signal["l_freq"]
            or self.filt_params_signal["h_freq"] > self.filt_params_noise["h_freq"]
        ):
            raise ValueError(
                "Wrongly specified frequency bands!\n"
                "The signal band-pass must be within the noise "
                "band-pass!"
            )
        self.freqs_signal_ = (
            self.filt_params_signal["l_freq"],
            self.filt_params_signal["h_freq"],
        )
        self.freqs_noise_ = (
            self.filt_params_noise["l_freq"],
            self.filt_params_noise["h_freq"],
        )

        _validate_type(self.sort_by_spectral_ratio, (bool,), "sort_by_spectral_ratio")
        _validate_type(self.n_fft, ("numeric", None), "n_fft")
        # Fall back to ``sfreq_`` samples when n_fft is None, and never exceed
        # the number of samples along the last axis of X.
        self.n_fft_ = min(
            int(self.n_fft if self.n_fft is not None else self.sfreq_),
            X.shape[-1],
        )
        _validate_type(self.return_filtered, (bool,), "return_filtered")

        if isinstance(self.info, Info):
            ch_types = self.info.get_channel_types(picks=self.picks, unique=True)
            if len(ch_types) > 1:
                raise ValueError(
                    "At this point SSD only supports fitting "
                    f"single channel types. Your info has {len(ch_types)} types."
                )
        _validate_type(self.cov_method_params, (abc.Mapping, None), "cov_method_params")

    def _check_X(self, X, *, y=None, fit=False):
        """Check input data.

        Raises ``ValueError`` if ``info`` is an ``Info`` whose channel count
        differs from the second-to-last axis of ``X``.
        """
        X = self._check_data(X, y=y, fit=fit, atleast_3d=False)
        n_chan = X.shape[-2]
        if isinstance(self.info, Info) and n_chan != self.info["nchan"]:
            # The trailing space keeps the two sentences from running together.
            raise ValueError(
                "Info must match the input data. "
                f"Found {n_chan} channels but expected {self.info['nchan']}."
            )
        return X

    def fit(self, X, y=None):
        """Estimate the SSD decomposition on raw or epoched data.

        Parameters
        ----------
        X : array, shape ([n_epochs, ]n_channels, n_times)
            The input data from which to estimate the SSD. Either 2D array
            obtained from continuous data or 3D array obtained from epoched
            data.
        y : None
            Ignored; exists for compatibility with scikit-learn pipelines.

        Returns
        -------
        self : instance of SSD
            Returns the modified instance.
        """
        X = self._check_X(X, y=y, fit=True)
        self._validate_params(X)
        if isinstance(self.info, Info):
            info = self.info
        else:
            # Only a sampling frequency was given: build an all-EEG Info so
            # that the picks can still be resolved.
            info = create_info(X.shape[-2], self.sfreq_, ch_types="eeg")
        self.picks_ = _picks_to_idx(info, self.picks, none="data", exclude="bads")

        super().fit(X, y)

        logger.info("Done.")
        return self

    def transform(self, X):
        """Estimate epochs sources given the SSD filters.

        Parameters
        ----------
        X : array, shape ([n_epochs, ]n_channels, n_times)
            The input data from which to estimate the SSD. Either 2D array
            obtained from continuous data or 3D array obtained from epoched
            data.

        Returns
        -------
        X_ssd : array, shape ([n_epochs, ]n_components, n_times)
            The processed data.
        """
        X = self._check_X(X)
        # For the case where n_epochs dimension is absent.
        if X.ndim == 2:
            X = np.expand_dims(X, axis=0)
        X_aux = X[..., self.picks_, :]
        if self.return_filtered:
            X_aux = filter_data(X_aux, self.sfreq_, **self.filt_params_signal)
        # NOTE(review): squeeze() without an axis drops every length-1
        # dimension, not only the epoch axis added above, so a single-epoch 3D
        # input or a one-component result is reduced as well. Kept as-is so the
        # returned shape does not change for existing callers.
        X_ssd = super().transform(X_aux).squeeze()

        return X_ssd

    def fit_transform(self, X, y=None, **fit_params):
        """Fit SSD to data, then transform it.

        Fits transformer to ``X`` and ``y`` with optional parameters ``fit_params``, and
        returns a transformed version of ``X``.

        Parameters
        ----------
        X : array, shape ([n_epochs, ]n_channels, n_times)
            The input data from which to estimate the SSD. Either 2D array obtained from
            continuous data or 3D array obtained from epoched data.
        y : None
            Ignored; exists for compatibility with scikit-learn pipelines.
        **fit_params : dict
            Additional fitting parameters passed to the :meth:`mne.decoding.SSD.fit`
            method. Not used for this class.

        Returns
        -------
        X_ssd : array, shape ([n_epochs, ]n_components, n_times)
            The processed data.
        """
        # use parent TransformerMixin method but with custom docstring
        return super().fit_transform(X, y=y, **fit_params)

    def get_spectral_ratio(self, ssd_sources):
        """Get the spectral signal-to-noise ratio for each spatial filter.

        Spectral ratio measure for best n_components selection.
        See :footcite:`NikulinEtAl2011`, Eq. (24).

        Parameters
        ----------
        ssd_sources : array
            Data projected to SSD space.

        Returns
        -------
        spec_ratio : array, shape (n_channels)
            Array with the spectral ratio value for each component.
        sorter_spec : array, shape (n_channels)
            Array of indices for sorting spec_ratio.

        References
        ----------
        .. footbibliography::
        """
        # Uses the settings derived in ``_validate_params``, so the instance
        # must have been fitted first.
        spec_ratio, sorter_spec = _get_spectral_ratio(
            ssd_sources=ssd_sources,
            sfreq=self.sfreq_,
            n_fft=self.n_fft_,
            freqs_signal=self.freqs_signal_,
            freqs_noise=self.freqs_noise_,
        )
        return spec_ratio, sorter_spec

    def inverse_transform(self):
        """Not implemented yet."""
        raise NotImplementedError("inverse_transform is not yet available.")

    def apply(self, X):
        """Remove selected components from the signal.

        This procedure will reconstruct M/EEG signals from which the dynamics
        described by the excluded components is subtracted
        (denoised by low-rank factorization).
        See :footcite:`HaufeEtAl2014b` for more information.

        .. note:: Unlike in other classes with an apply method,
           only NumPy arrays are supported (not instances of MNE objects).

        Parameters
        ----------
        X : array, shape ([n_epochs, ]n_channels, n_times)
            The input data to denoise. Either 2D array obtained from
            continuous data or 3D array obtained from epoched data.

        Returns
        -------
        X : array, shape ([n_epochs, ]n_channels, n_times)
            The processed data.
        """
        X_ssd = self.transform(X)
        # Slicing with ``n_components=None`` keeps every pattern.
        pick_patterns = self.patterns_[: self.n_components].T
        X = pick_patterns @ X_ssd
        return X