Spaces:
Running
Running
| # Authors: The MNE-Python contributors. | |
| # License: BSD-3-Clause | |
| # Copyright the MNE-Python contributors. | |
| from collections import Counter | |
| import numpy as np | |
| from sklearn.base import BaseEstimator | |
| from .._fiff.pick import _picks_to_idx, pick_info, pick_types | |
| from ..parallel import parallel_func | |
| from ..utils import logger, verbose | |
| from .base import _set_cv | |
| from .transformer import MNETransformerMixin | |
class EMS(MNETransformerMixin, BaseEstimator):
    """Transformer to compute event-matched spatial filters.

    This version of EMS :footcite:`SchurgerEtAl2013` operates on the entire
    time course, so no time window needs to be specified. The result is a
    spatial filter at each time point and a corresponding time course.
    Intuitively, the result gives the similarity between the filter at each
    time point and the data vector (sensors) at that time point.

    .. note:: EMS only works for binary classification.

    Attributes
    ----------
    filters_ : ndarray, shape (n_channels, n_times)
        The set of spatial filters.
    classes_ : ndarray, shape (n_classes,)
        The target classes.

    References
    ----------
    .. footbibliography::
    """

    def __sklearn_tags__(self):
        """Return sklearn tags, flagging the estimator as not multi-class."""
        from sklearn.utils import ClassifierTags

        tags = super().__sklearn_tags__()
        clf_tags = tags.classifier_tags
        if clf_tags is None:
            # The parent tags carry no classifier section yet: create one.
            clf_tags = ClassifierTags()
            tags.classifier_tags = clf_tags
        clf_tags.multi_class = False
        return tags

    def __repr__(self):  # noqa: D105
        if not hasattr(self, "filters_"):
            return "<EMS: not fitted.>"
        n_filters = len(self.filters_)
        n_classes = len(self.classes_)
        return f"<EMS: fitted with {n_filters} filters on {n_classes} classes.>"

    def fit(self, X, y):
        """Fit the spatial filters.

        The filters are the difference between the average of the epochs of
        the first class and the average of the epochs of the second class,
        scaled to unit norm across channels at every time point.

        .. note:: EMS is fitted on data normalized by channel type before
                  the fitting of the spatial filters. No such scaling is
                  visible in this method itself.

        Parameters
        ----------
        X : array, shape (n_epochs, n_channels, n_times)
            The training data.
        y : array of int, shape (n_epochs)
            The target classes.

        Returns
        -------
        self : instance of EMS
            Returns self.

        Raises
        ------
        ValueError
            If ``y`` contains more than two distinct classes.
        """
        X, y = self._check_data(X, y=y, fit=True, return_y=True)
        # ``codes`` re-expresses y as indices into the sorted unique labels.
        labels, codes = np.unique(y, return_inverse=True)
        if len(labels) > 2:
            raise ValueError("EMS only works for binary classification.")
        self.classes_ = labels

        first_mean = X[codes == 0].mean(axis=0)
        second_mean = X[codes == 1].mean(axis=0)
        contrast = first_mean - second_mean
        # Normalize over the channel axis, independently for each time point.
        contrast /= np.linalg.norm(contrast, axis=0)[np.newaxis, :]
        self.filters_ = contrast
        return self

    def transform(self, X):
        """Transform the data by the spatial filters.

        Parameters
        ----------
        X : array, shape (n_epochs, n_channels, n_times)
            The input data.

        Returns
        -------
        X : array, shape (n_epochs, n_times)
            The input data transformed by the spatial filters.
        """
        X = self._check_data(X)
        # Weight every channel by its filter, then collapse the channel axis.
        weighted = X * self.filters_
        return np.sum(weighted, axis=1)
@verbose
def compute_ems(
    epochs, conditions=None, picks=None, n_jobs=None, cv=None, *, verbose=None
):
    """Compute event-matched spatial filter on epochs.

    This version of EMS :footcite:`SchurgerEtAl2013` operates on the entire
    time course. No time window needs to be specified. The result is a
    spatial filter at each time point and a corresponding time course.
    Intuitively, the result gives the similarity between the filter at each
    time point and the data vector (sensors) at that time point.

    .. note:: EMS only works for binary classification.

    .. note:: By default the present function applies a leave-one-out
              cross-validation, following Schurger et al's paper. However,
              we recommend using a stratified k-fold cross-validation.
              Indeed, leave-one-out tends to overfit and cannot be used to
              estimate the variance of the prediction within a given fold.

    .. note:: Because of the leave-one-out, this function needs an equal
              number of epochs in each of the two conditions.

    Parameters
    ----------
    epochs : instance of mne.Epochs
        The epochs.
    conditions : list of str | None, default None
        If a list of strings, strings must match the epochs.event_id's key as
        well as the number of conditions supported by the objective_function.
        If None keys in epochs.event_id are used.
    %(picks_good_data)s
    %(n_jobs)s
    cv : cross-validation object | str | None, default LeaveOneOut
        The cross-validation scheme.
    %(verbose)s

    Returns
    -------
    surrogate_trials : ndarray, shape (n_splits, n_times)
        The trial surrogates, one per cross-validation split. Each one is
        computed from the first test epoch of its split.
    mean_spatial_filter : ndarray, shape (n_channels, n_times)
        The set of spatial filters, averaged over the cross-validation
        splits.
    conditions : ndarray, shape (n_epochs,)
        The event id of each epoch that was used. Values correspond to
        original event ids.

    Raises
    ------
    ValueError
        If the event types of ``epochs`` do not all have the same number of
        epochs, or if the number of conditions is not exactly 2.

    References
    ----------
    .. footbibliography::
    """
    logger.info("...computing surrogate time series. This can take some time")

    # Default to leave-one-out cv
    cv = "LeaveOneOut" if cv is None else cv
    picks = _picks_to_idx(epochs.info, picks)

    # NOTE(review): the counts are compared across every event type of the
    # incoming epochs, before the conditions are selected and before bad
    # epochs are dropped.
    if len(set(Counter(epochs.events[:, 2]).values())) != 1:
        raise ValueError(
            "The same number of epochs is required by "
            "this function. Please consider "
            "`epochs.equalize_event_counts`"
        )

    if conditions is None:
        conditions = epochs.event_id.keys()
        epochs = epochs.copy()
    else:
        epochs = epochs[conditions]

    epochs.drop_bad()

    if len(conditions) != 2:
        raise ValueError(
            "Currently this function expects exactly 2 "
            f"conditions but you gave me {len(conditions)}"
        )

    ev = epochs.events[:, 2]
    # Special care to avoid path dependent mappings and orders
    conditions = sorted(conditions)
    cond_idx = [np.where(ev == epochs.event_id[k])[0] for k in conditions]

    info = pick_info(epochs.info, picks)
    data = epochs.get_data(picks=picks)

    # Scale (z-score) the data by channel type
    # XXX the z-scoring is applied outside the CV, which is not standard.
    for ch_type in ["mag", "grad", "eeg"]:
        if ch_type not in epochs:
            continue
        # FIXME should be applied to all sort of data channels
        if ch_type == "eeg":
            this_picks = pick_types(info, meg=False, eeg=True)
        else:
            this_picks = pick_types(info, meg=ch_type, eeg=False)
        if len(this_picks) == 0:
            # The channel type exists in the epochs but was excluded by
            # ``picks``: there is nothing to scale, and np.std of an empty
            # selection would only yield NaN plus a RuntimeWarning.
            continue
        data[:, this_picks] /= np.std(data[:, this_picks])

    # Setup cross-validation. Need to use _set_cv to deal with sklearn
    # changes in cv object handling.
    _, cv_splits = _set_cv(cv, "classifier", X=ev, y=ev)

    parallel, p_func, n_jobs = parallel_func(_run_ems, n_jobs=n_jobs)
    # FIXME this parallelization should be removed.
    #   1) it's numpy computation so it's already efficient,
    #   2) it duplicates the data in RAM,
    #   3) the computation is already super fast.
    out = parallel(
        p_func(_ems_diff, data, cond_idx, train, test) for train, test in cv_splits
    )

    surrogate_trials, spatial_filter = zip(*out)
    surrogate_trials = np.array(surrogate_trials)
    spatial_filter = np.mean(spatial_filter, axis=0)

    return surrogate_trials, spatial_filter, ev
def _ems_diff(data0, data1):
    """Compute the default diff objective function.

    Parameters
    ----------
    data0 : array-like
        Data of the first condition.
    data1 : array-like
        Data of the second condition.

    Returns
    -------
    diff : ndarray
        The average of ``data0`` over its first axis minus the average of
        ``data1`` over its first axis.
    """
    average0 = np.mean(data0, axis=0)
    average1 = np.mean(data1, axis=0)
    return average0 - average1
def _run_ems(objective_function, data, cond_idx, train, test):
    """Run EMS on a single cross-validation split.

    Parameters
    ----------
    objective_function : callable
        Called with one array per condition, each holding the training
        epochs of that condition. Its return value is the unnormalized
        filter.
    data : ndarray
        The data, indexed by epoch along its first axis.
    cond_idx : list of array of int
        For each condition, the indices of its epochs in ``data``.
    train : array of int
        The indices of the training epochs of this split.
    test : array of int
        The indices of the test epochs of this split. Only the first one
        is used.

    Returns
    -------
    surrogate : ndarray
        The first test epoch multiplied by the normalized filter and summed
        over its first axis.
    d : ndarray
        The filter, scaled to unit norm along its first axis.
    """
    # Keep, for each condition, only the epochs that belong to the train set.
    training_sets = [data[np.intersect1d(indices, train)] for indices in cond_idx]
    spatial_filter = objective_function(*training_sets)

    norms = np.sqrt(np.sum(spatial_filter**2, axis=0))
    spatial_filter /= norms[np.newaxis, :]

    # compute surrogates
    held_out = data[test[0]]
    return np.sum(held_out * spatial_filter, axis=0), spatial_filter