# Authors: The MNE-Python contributors. # License: BSD-3-Clause # Copyright the MNE-Python contributors. import numpy as np from sklearn.base import BaseEstimator, TransformerMixin, check_array, clone from sklearn.preprocessing import RobustScaler, StandardScaler from sklearn.utils import check_X_y from sklearn.utils.validation import check_is_fitted from .._fiff.pick import ( _pick_data_channels, _picks_by_type, _picks_to_idx, pick_info, ) from ..cov import _check_scalings_user from ..epochs import BaseEpochs from ..filter import filter_data from ..time_frequency import psd_array_multitaper from ..utils import _check_option, _validate_type, check_version, fill_doc from ._fixes import validate_data # TODO VERSION remove with sklearn 1.4+ class MNETransformerMixin(TransformerMixin): """TransformerMixin plus some helpers.""" def _check_data( self, epochs_data, *, y=None, atleast_3d=True, fit=False, return_y=False, multi_output=False, check_n_features=True, ): # Sklearn calls asarray under the hood which works, but elsewhere they check for # __len__ then look at the size of obj[0]... which is an epoch of shape (1, ...) # rather than what they expect (shape (...)). So we explicitly get the NumPy # array to make everyone happy. if isinstance(epochs_data, BaseEpochs): epochs_data = epochs_data.get_data(copy=False) kwargs = dict(dtype=np.float64, allow_nd=True, order="C") if check_version("sklearn", "1.4"): # TODO VERSION sklearn 1.4+ kwargs["force_writeable"] = True if hasattr(self, "n_features_in_") and check_n_features: if y is None: epochs_data = validate_data( self, epochs_data, **kwargs, reset=fit, ) else: epochs_data, y = validate_data( self, epochs_data, y, **kwargs, reset=fit, ) elif y is None: epochs_data = check_array(epochs_data, **kwargs) else: epochs_data, y = check_X_y( X=epochs_data, y=y, multi_output=multi_output, **kwargs ) if fit: self.n_features_in_ = epochs_data.shape[1] if atleast_3d: epochs_data = np.atleast_3d(epochs_data) return (epochs_data, y) if return_y else epochs_data class _ConstantScaler: """Scale channel types using constant values.""" def __init__(self, info, scalings, do_scaling=True): self._scalings = scalings self._info = info self._do_scaling = do_scaling def fit(self, X, y=None): scalings = _check_scalings_user(self._scalings) picks_by_type = _picks_by_type( pick_info(self._info, _pick_data_channels(self._info, exclude=())) ) std = np.ones(sum(len(p[1]) for p in picks_by_type)) if X.shape[1] != len(std): raise ValueError( f"info had {len(std)} data channels but X has {len(X)} channels" ) if self._do_scaling: # this is silly, but necessary for completeness for kind, picks in picks_by_type: std[picks] = 1.0 / scalings[kind] self.std_ = std self.mean_ = np.zeros_like(std) return self def transform(self, X): return X / self.std_ def inverse_transform(self, X, y=None): return X * self.std_ def fit_transform(self, X, y=None): return self.fit(X, y).transform(X) def _sklearn_reshape_apply(func, return_result, X, *args, **kwargs): """Reshape epochs and apply function.""" _validate_type(X, np.ndarray, "X") if X.size == 0: return X.copy() if return_result else None orig_shape = X.shape X = np.reshape(X.transpose(0, 2, 1), (-1, orig_shape[1])) X = func(X, *args, **kwargs) if return_result: X.shape = (orig_shape[0], orig_shape[2], orig_shape[1]) X = X.transpose(0, 2, 1) return X @fill_doc class Scaler(MNETransformerMixin, BaseEstimator): """Standardize channel data. This class scales data for each channel. It differs from scikit-learn classes (e.g., :class:`sklearn.preprocessing.StandardScaler`) in that it scales each *channel* by estimating μ and σ using data from all time points and epochs, as opposed to standardizing each *feature* (i.e., each time point for each channel) by estimating using μ and σ using data from all epochs. Parameters ---------- %(info)s Only necessary if ``scalings`` is a dict or None. scalings : dict, str, default None Scaling method to be applied to data channel wise. * if scalings is None (default), scales mag by 1e15, grad by 1e13, and eeg by 1e6. * if scalings is :class:`dict`, keys are channel types and values are scale factors. * if ``scalings=='median'``, :class:`sklearn.preprocessing.RobustScaler` is used (requires sklearn version 0.17+). * if ``scalings=='mean'``, :class:`sklearn.preprocessing.StandardScaler` is used. with_mean : bool, default True If True, center the data using mean (or median) before scaling. Ignored for channel-type scaling. with_std : bool, default True If True, scale the data to unit variance (``scalings='mean'``), quantile range (``scalings='median``), or using channel type if ``scalings`` is a dict or None). """ def __init__(self, info=None, scalings=None, with_mean=True, with_std=True): self.info = info self.with_mean = with_mean self.with_std = with_std self.scalings = scalings def fit(self, epochs_data, y=None): """Standardize data across channels. Parameters ---------- epochs_data : array, shape (n_epochs, n_channels, n_times) The data to concatenate channels. y : array, shape (n_epochs,) The label for each epoch. Returns ------- self : instance of Scaler The modified instance. """ epochs_data = self._check_data(epochs_data, y=y, fit=True, multi_output=True) assert epochs_data.ndim == 3, epochs_data.shape _validate_type(self.scalings, (dict, str, type(None)), "scalings") if isinstance(self.scalings, str): _check_option( "scalings", self.scalings, ["mean", "median"], extra="when str" ) if self.scalings is None or isinstance(self.scalings, dict): if self.info is None: raise ValueError( f'Need to specify "info" if scalings is {type(self.scalings)}' ) self.scaler_ = _ConstantScaler(self.info, self.scalings, self.with_std) elif self.scalings == "mean": self.scaler_ = StandardScaler( with_mean=self.with_mean, with_std=self.with_std ) else: # scalings == 'median': self.scaler_ = RobustScaler( with_centering=self.with_mean, with_scaling=self.with_std ) _sklearn_reshape_apply(self.scaler_.fit, False, epochs_data, y=y) return self def transform(self, epochs_data): """Standardize data across channels. Parameters ---------- epochs_data : array, shape (n_epochs, n_channels[, n_times]) The data. Returns ------- X : array, shape (n_epochs, n_channels, n_times) The data concatenated over channels. Notes ----- This function makes a copy of the data before the operations and the memory usage may be large with big data. """ check_is_fitted(self, "scaler_") epochs_data = self._check_data(epochs_data, atleast_3d=False) if epochs_data.ndim == 2: # can happen with SlidingEstimator if self.info is not None: assert len(self.info["ch_names"]) == epochs_data.shape[1] epochs_data = epochs_data[..., np.newaxis] assert epochs_data.ndim == 3, epochs_data.shape return _sklearn_reshape_apply(self.scaler_.transform, True, epochs_data) def fit_transform(self, epochs_data, y=None): """Fit to data, then transform it. Fits transformer to epochs_data and y and returns a transformed version of epochs_data. Parameters ---------- epochs_data : array, shape (n_epochs, n_channels, n_times) The data. y : None | array, shape (n_epochs,) The label for each epoch. Defaults to None. Returns ------- X : array, shape (n_epochs, n_channels, n_times) The data concatenated over channels. Notes ----- This function makes a copy of the data before the operations and the memory usage may be large with big data. """ return self.fit(epochs_data, y).transform(epochs_data) def inverse_transform(self, epochs_data): """Invert standardization of data across channels. Parameters ---------- epochs_data : array, shape ([n_epochs, ]n_channels, n_times) The data. Returns ------- X : array, shape (n_epochs, n_channels, n_times) The data concatenated over channels. Notes ----- This function makes a copy of the data before the operations and the memory usage may be large with big data. """ epochs_data = self._check_data(epochs_data, atleast_3d=False) squeeze = False # Can happen with CSP if epochs_data.ndim == 2: squeeze = True epochs_data = epochs_data[..., np.newaxis] assert epochs_data.ndim == 3, epochs_data.shape out = _sklearn_reshape_apply(self.scaler_.inverse_transform, True, epochs_data) if squeeze: out = out[..., 0] return out class Vectorizer(MNETransformerMixin, BaseEstimator): """Transform n-dimensional array into 2D array of n_samples by n_features. This class reshapes an n-dimensional array into an n_samples * n_features array, usable by the estimators and transformers of scikit-learn. Attributes ---------- features_shape_ : tuple Stores the original shape of data. Examples -------- >>> from sklearn.linear_model import LogisticRegression >>> from sklearn.pipeline import make_pipeline >>> from sklearn.preprocessing import StandardScaler >>> clf = make_pipeline(Vectorizer(), StandardScaler(), LogisticRegression()) """ def fit(self, X, y=None): """Store the shape of the features of X. Parameters ---------- X : array-like The data to fit. Can be, for example a list, or an array of at least 2d. The first dimension must be of length n_samples, where samples are the independent samples used by the estimator (e.g. n_epochs for epoched data). y : None | array, shape (n_samples,) Used for scikit-learn compatibility. Returns ------- self : instance of Vectorizer Return the modified instance. """ X = self._check_data(X, y=y, atleast_3d=False, fit=True, check_n_features=False) self.features_shape_ = X.shape[1:] return self def transform(self, X): """Convert given array into two dimensions. Parameters ---------- X : array-like The data to fit. Can be, for example a list, or an array of at least 2d. The first dimension must be of length n_samples, where samples are the independent samples used by the estimator (e.g. n_epochs for epoched data). Returns ------- X : array, shape (n_samples, n_features) The transformed data. """ X = self._check_data(X, atleast_3d=False) if X.shape[1:] != self.features_shape_: raise ValueError("Shape of X used in fit and transform must be same") return X.reshape(len(X), -1) def fit_transform(self, X, y=None): """Fit the data, then transform in one step. Parameters ---------- X : array-like The data to fit. Can be, for example a list, or an array of at least 2d. The first dimension must be of length n_samples, where samples are the independent samples used by the estimator (e.g. n_epochs for epoched data). y : None | array, shape (n_samples,) Used for scikit-learn compatibility. Returns ------- X : array, shape (n_samples, -1) The transformed data. """ return self.fit(X).transform(X) def inverse_transform(self, X): """Transform 2D data back to its original feature shape. Parameters ---------- X : array-like, shape (n_samples, n_features) Data to be transformed back to original shape. Returns ------- X : array The data transformed into shape as used in fit. The first dimension is of length n_samples. """ X = self._check_data(X, atleast_3d=False, check_n_features=False) if X.ndim not in (2, 3): raise ValueError( f"X should be of 2 or 3 dimensions but has shape {X.shape}" ) return X.reshape(X.shape[:-1] + self.features_shape_) @fill_doc class PSDEstimator(MNETransformerMixin, BaseEstimator): """Compute power spectral density (PSD) using a multi-taper method. Parameters ---------- sfreq : float The sampling frequency. fmin : float The lower frequency of interest. fmax : float The upper frequency of interest. bandwidth : float The bandwidth of the multi taper windowing function in Hz. adaptive : bool Use adaptive weights to combine the tapered spectra into PSD (slow, use n_jobs >> 1 to speed up computation). low_bias : bool Only use tapers with more than 90%% spectral concentration within bandwidth. n_jobs : int Number of parallel jobs to use (only used if adaptive=True). %(normalization)s See Also -------- mne.time_frequency.psd_array_multitaper mne.io.Raw.compute_psd mne.Epochs.compute_psd mne.Evoked.compute_psd """ def __init__( self, sfreq=2 * np.pi, fmin=0, fmax=np.inf, bandwidth=None, adaptive=False, low_bias=True, n_jobs=None, normalization="length", ): self.sfreq = sfreq self.fmin = fmin self.fmax = fmax self.bandwidth = bandwidth self.adaptive = adaptive self.low_bias = low_bias self.n_jobs = n_jobs self.normalization = normalization def __sklearn_tags__(self): """...""" tags = super().__sklearn_tags__() tags.target_tags.required = False tags.requires_fit = False return tags def fit(self, epochs_data, y=None): """Compute power spectral density (PSD) using a multi-taper method. Parameters ---------- epochs_data : array, shape (n_epochs, n_channels, n_times) The data. y : array, shape (n_epochs,) The label for each epoch. Returns ------- self : instance of PSDEstimator The modified instance. """ self._check_data(epochs_data, y=y, fit=True) self.fitted_ = True # sklearn compliance return self def transform(self, epochs_data): """Compute power spectral density (PSD) using a multi-taper method. Parameters ---------- epochs_data : array, shape (n_epochs, n_channels, n_times) The data. Returns ------- psd : array, shape (n_signals, n_freqs) or (n_freqs,) The computed PSD. """ epochs_data = self._check_data(epochs_data) psd, _ = psd_array_multitaper( epochs_data, sfreq=self.sfreq, fmin=self.fmin, fmax=self.fmax, bandwidth=self.bandwidth, adaptive=self.adaptive, low_bias=self.low_bias, normalization=self.normalization, n_jobs=self.n_jobs, ) return psd @fill_doc class FilterEstimator(MNETransformerMixin, BaseEstimator): """Estimator to filter RtEpochs. Applies a zero-phase low-pass, high-pass, band-pass, or band-stop filter to the channels selected by "picks". l_freq and h_freq are the frequencies below which and above which, respectively, to filter out of the data. Thus the uses are: - l_freq < h_freq: band-pass filter - l_freq > h_freq: band-stop filter - l_freq is not None, h_freq is None: low-pass filter - l_freq is None, h_freq is not None: high-pass filter If n_jobs > 1, more memory is required as "len(picks) * n_times" additional time points need to be temporarily stored in memory. Parameters ---------- %(info_not_none)s %(l_freq)s %(h_freq)s %(picks_good_data)s %(filter_length)s %(l_trans_bandwidth)s %(h_trans_bandwidth)s n_jobs : int | str Number of jobs to run in parallel. Can be 'cuda' if ``cupy`` is installed properly and method='fir'. method : str 'fir' will use overlap-add FIR filtering, 'iir' will use IIR filtering. iir_params : dict | None Dictionary of parameters to use for IIR filtering. See mne.filter.construct_iir_filter for details. If iir_params is None and method="iir", 4th order Butterworth will be used. %(fir_design)s See Also -------- TemporalFilter Notes ----- This is primarily meant for use in realtime applications. In general it is not recommended in a normal processing pipeline as it may result in edge artifacts. Use with caution. """ def __init__( self, info, l_freq, h_freq, picks=None, filter_length="auto", l_trans_bandwidth="auto", h_trans_bandwidth="auto", n_jobs=None, method="fir", iir_params=None, fir_design="firwin", ): self.info = info self.l_freq = l_freq self.h_freq = h_freq self.picks = picks self.filter_length = filter_length self.l_trans_bandwidth = l_trans_bandwidth self.h_trans_bandwidth = h_trans_bandwidth self.n_jobs = n_jobs self.method = method self.iir_params = iir_params self.fir_design = fir_design def fit(self, epochs_data, y): """Filter data. Parameters ---------- epochs_data : array, shape (n_epochs, n_channels, n_times) The data. y : array, shape (n_epochs,) The label for each epoch. Returns ------- self : instance of FilterEstimator The modified instance. """ self.picks_ = _picks_to_idx(self.info, self.picks) self._check_data(epochs_data, y=y, fit=True) if self.l_freq == 0: self.l_freq = None if self.info["lowpass"] is None or ( self.h_freq is not None and (self.l_freq is None or self.l_freq < self.h_freq) and self.h_freq < self.info["lowpass"] ): with self.info._unlock(): self.info["lowpass"] = self.h_freq if self.info["highpass"] is None or ( self.l_freq is not None and (self.h_freq is None or self.l_freq < self.h_freq) and self.l_freq > self.info["highpass"] ): with self.info._unlock(): self.info["highpass"] = self.l_freq return self def transform(self, epochs_data): """Filter data. Parameters ---------- epochs_data : array, shape (n_epochs, n_channels, n_times) The data. Returns ------- X : array, shape (n_epochs, n_channels, n_times) The data after filtering. """ return filter_data( self._check_data(epochs_data), self.info["sfreq"], self.l_freq, self.h_freq, self.picks_, self.filter_length, self.l_trans_bandwidth, self.h_trans_bandwidth, method=self.method, iir_params=self.iir_params, n_jobs=self.n_jobs, copy=False, fir_design=self.fir_design, verbose=False, ) class UnsupervisedSpatialFilter(MNETransformerMixin, BaseEstimator): """Use unsupervised spatial filtering across time and samples. Parameters ---------- estimator : instance of sklearn.base.BaseEstimator Estimator using some decomposition algorithm. average : bool, default False If True, the estimator is fitted on the average across samples (e.g. epochs). """ def __init__(self, estimator, average=False): self.estimator = estimator self.average = average def fit(self, X, y=None): """Fit the spatial filters. Parameters ---------- X : array, shape (n_epochs, n_channels, n_times) The data to be filtered. y : None | array, shape (n_samples,) Used for scikit-learn compatibility. Returns ------- self : instance of UnsupervisedSpatialFilter Return the modified instance. """ # sklearn.utils.estimator_checks.check_estimator(self.estimator) is probably # too strict for us, given that we don't fully adhere yet, so just check attrs for attr in ("fit", "transform", "fit_transform"): if not hasattr(self.estimator, attr): raise ValueError( "estimator must be a scikit-learn " f"transformer, missing {attr} method" ) _validate_type(self.average, bool, "average") X = self._check_data(X, y=y, fit=True) if self.average: X = np.mean(X, axis=0).T else: n_epochs, n_channels, n_times = X.shape # trial as time samples X = np.transpose(X, (1, 0, 2)).reshape((n_channels, n_epochs * n_times)).T self.estimator_ = clone(self.estimator) self.estimator_.fit(X) return self def fit_transform(self, X, y=None): """Transform the data to its filtered components after fitting. Parameters ---------- X : array, shape (n_epochs, n_channels, n_times) The data to be filtered. y : None | array, shape (n_samples,) Used for scikit-learn compatibility. Returns ------- X : array, shape (n_epochs, n_channels, n_times) The transformed data. """ return self.fit(X).transform(X) def transform(self, X): """Transform the data to its spatial filters. Parameters ---------- X : array, shape (n_epochs, n_channels, n_times) The data to be filtered. Returns ------- X : array, shape (n_epochs, n_channels, n_times) The transformed data. """ check_is_fitted(self.estimator_) X = self._check_data(X) return self._apply_method(X, "transform") def inverse_transform(self, X): """Inverse transform the data to its original space. Parameters ---------- X : array, shape (n_epochs, n_components, n_times) The data to be inverted. Returns ------- X : array, shape (n_epochs, n_channels, n_times) The transformed data. """ return self._apply_method(X, "inverse_transform") def _apply_method(self, X, method): """Vectorize time samples as trials, apply method and reshape back. Parameters ---------- X : array, shape (n_epochs, n_dims, n_times) The data to be inverted. Returns ------- X : array, shape (n_epochs, n_dims, n_times) The transformed data. """ n_epochs, n_channels, n_times = X.shape # trial as time samples X = np.transpose(X, [1, 0, 2]) X = np.reshape(X, [n_channels, n_epochs * n_times]).T # apply method method = getattr(self.estimator_, method) X = method(X) # put it back to n_epochs, n_dimensions X = np.reshape(X.T, [-1, n_epochs, n_times]).transpose([1, 0, 2]) return X @fill_doc class TemporalFilter(MNETransformerMixin, BaseEstimator): """Estimator to filter data array along the last dimension. Applies a zero-phase low-pass, high-pass, band-pass, or band-stop filter to the channels. l_freq and h_freq are the frequencies below which and above which, respectively, to filter out of the data. Thus the uses are: - l_freq < h_freq: band-pass filter - l_freq > h_freq: band-stop filter - l_freq is not None, h_freq is None: low-pass filter - l_freq is None, h_freq is not None: high-pass filter See :func:`mne.filter.filter_data`. Parameters ---------- l_freq : float | None Low cut-off frequency in Hz. If None the data are only low-passed. h_freq : float | None High cut-off frequency in Hz. If None the data are only high-passed. sfreq : float, default 1.0 Sampling frequency in Hz. filter_length : str | int, default 'auto' Length of the FIR filter to use (if applicable): * int: specified length in samples. * 'auto' (default in 0.14): the filter length is chosen based on the size of the transition regions (7 times the reciprocal of the shortest transition band). * str: (default in 0.13 is "10s") a human-readable time in units of "s" or "ms" (e.g., "10s" or "5500ms") will be converted to that number of samples if ``phase="zero"``, or the shortest power-of-two length at least that duration for ``phase="zero-double"``. l_trans_bandwidth : float | str Width of the transition band at the low cut-off frequency in Hz (high pass or cutoff 1 in bandpass). Can be "auto" (default in 0.14) to use a multiple of ``l_freq``:: min(max(l_freq * 0.25, 2), l_freq) Only used for ``method='fir'``. h_trans_bandwidth : float | str Width of the transition band at the high cut-off frequency in Hz (low pass or cutoff 2 in bandpass). Can be "auto" (default in 0.14) to use a multiple of ``h_freq``:: min(max(h_freq * 0.25, 2.), info['sfreq'] / 2. - h_freq) Only used for ``method='fir'``. n_jobs : int | str, default 1 Number of jobs to run in parallel. Can be 'cuda' if ``cupy`` is installed properly and method='fir'. method : str, default 'fir' 'fir' will use overlap-add FIR filtering, 'iir' will use IIR forward-backward filtering (via filtfilt). iir_params : dict | None, default None Dictionary of parameters to use for IIR filtering. See mne.filter.construct_iir_filter for details. If iir_params is None and method="iir", 4th order Butterworth will be used. fir_window : str, default 'hamming' The window to use in FIR design, can be "hamming", "hann", or "blackman". fir_design : str Can be "firwin" (default) to use :func:`scipy.signal.firwin`, or "firwin2" to use :func:`scipy.signal.firwin2`. "firwin" uses a time-domain design technique that generally gives improved attenuation using fewer samples than "firwin2". .. versionadded:: 0.15 See Also -------- FilterEstimator Vectorizer mne.filter.filter_data """ def __init__( self, l_freq=None, h_freq=None, sfreq=1.0, filter_length="auto", l_trans_bandwidth="auto", h_trans_bandwidth="auto", n_jobs=None, method="fir", iir_params=None, fir_window="hamming", fir_design="firwin", ): self.l_freq = l_freq self.h_freq = h_freq self.sfreq = sfreq self.filter_length = filter_length self.l_trans_bandwidth = l_trans_bandwidth self.h_trans_bandwidth = h_trans_bandwidth self.n_jobs = n_jobs self.method = method self.iir_params = iir_params self.fir_window = fir_window self.fir_design = fir_design def __sklearn_tags__(self): """...""" tags = super().__sklearn_tags__() tags.target_tags.required = False tags.requires_fit = False return tags def fit(self, X, y=None): """Do nothing (for scikit-learn compatibility purposes). Parameters ---------- X : array, shape ([n_epochs, ]n_channels, n_times) The data to be filtered over the last dimension. The channels dimension can be zero when passing a 2D array. y : None Not used, for scikit-learn compatibility issues. Returns ------- self : instance of TemporalFilter The modified instance. """ self.fitted_ = True # sklearn compliance self._check_data(X, y=y, atleast_3d=False, fit=True) return self def transform(self, X): """Filter data along the last dimension. Parameters ---------- X : array, shape ([n_epochs, ]n_channels, n_times) The data to be filtered over the last dimension. The channels dimension can be zero when passing a 2D array. Returns ------- X : array The data after filtering. """ # noqa: E501 X = self._check_data(X, atleast_3d=False) X = np.atleast_2d(X) if X.ndim > 3: raise ValueError( "Array must be of at max 3 dimensions instead " f"got {X.ndim} dimensional matrix" ) shape = X.shape X = X.reshape(-1, shape[-1]) X = filter_data( X, self.sfreq, self.l_freq, self.h_freq, filter_length=self.filter_length, l_trans_bandwidth=self.l_trans_bandwidth, h_trans_bandwidth=self.h_trans_bandwidth, n_jobs=self.n_jobs, method=self.method, iir_params=self.iir_params, copy=False, fir_window=self.fir_window, fir_design=self.fir_design, ) return X.reshape(shape)