Spaces:
Running
Running
| """Container classes for spectral data.""" | |
| # Authors: The MNE-Python contributors. | |
| # License: BSD-3-Clause | |
| # Copyright the MNE-Python contributors. | |
| from copy import deepcopy | |
| from functools import partial | |
| from inspect import signature | |
| import numpy as np | |
| from .._fiff.meas_info import ContainsMixin, Info | |
| from .._fiff.pick import _pick_data_channels, _picks_to_idx, pick_info | |
| from ..channels.channels import UpdateChannelsMixin | |
| from ..channels.layout import _merge_ch_data, find_layout | |
| from ..defaults import ( | |
| _BORDER_DEFAULT, | |
| _EXTRAPOLATE_DEFAULT, | |
| _INTERPOLATION_DEFAULT, | |
| _handle_default, | |
| ) | |
| from ..html_templates import _get_html_template | |
| from ..utils import ( | |
| GetEpochsMixin, | |
| _build_data_frame, | |
| _check_method_kwargs, | |
| _check_pandas_index_arguments, | |
| _check_pandas_installed, | |
| _check_sphere, | |
| _time_mask, | |
| _validate_type, | |
| fill_doc, | |
| legacy, | |
| logger, | |
| object_diff, | |
| repr_html, | |
| verbose, | |
| warn, | |
| ) | |
| from ..utils.check import ( | |
| _check_fname, | |
| _check_option, | |
| _import_h5io_funcs, | |
| _is_numeric, | |
| check_fname, | |
| ) | |
| from ..utils.misc import _pl | |
| from ..utils.spectrum import _get_instance_type_string, _split_psd_kwargs | |
| from ..viz.topo import _plot_timeseries, _plot_timeseries_unified, _plot_topo | |
| from ..viz.topomap import _make_head_outlines, _prepare_topomap_plot, plot_psds_topomap | |
| from ..viz.utils import ( | |
| _format_units_psd, | |
| _get_plot_ch_type, | |
| _make_combine_callable, | |
| _plot_psd, | |
| _prepare_sensor_names, | |
| plt_show, | |
| ) | |
| from .multitaper import _psd_from_mt, psd_array_multitaper | |
| from .psd import _check_nfft, psd_array_welch | |
| class SpectrumMixin: | |
| """Mixin providing spectral plotting methods to sensor-space containers.""" | |
| def plot_psd( | |
| self, | |
| fmin=0, | |
| fmax=np.inf, | |
| tmin=None, | |
| tmax=None, | |
| picks=None, | |
| proj=False, | |
| reject_by_annotation=True, | |
| *, | |
| method="auto", | |
| average=False, | |
| dB=True, | |
| estimate="power", | |
| xscale="linear", | |
| area_mode="std", | |
| area_alpha=0.33, | |
| color="black", | |
| line_alpha=None, | |
| spatial_colors=True, | |
| sphere=None, | |
| exclude="bads", | |
| ax=None, | |
| show=True, | |
| n_jobs=1, | |
| verbose=None, | |
| **method_kw, | |
| ): | |
| """%(plot_psd_doc)s. | |
| Parameters | |
| ---------- | |
| %(fmin_fmax_psd)s | |
| %(tmin_tmax_psd)s | |
| %(picks_good_data_noref)s | |
| %(proj_psd)s | |
| %(reject_by_annotation_psd)s | |
| %(method_plot_psd_auto)s | |
| %(average_plot_psd)s | |
| %(dB_plot_psd)s | |
| %(estimate_plot_psd)s | |
| %(xscale_plot_psd)s | |
| %(area_mode_plot_psd)s | |
| %(area_alpha_plot_psd)s | |
| %(color_plot_psd)s | |
| %(line_alpha_plot_psd)s | |
| %(spatial_colors_psd)s | |
| %(sphere_topomap_auto)s | |
| .. versionadded:: 0.22.0 | |
| exclude : list of str | 'bads' | |
| Channels names to exclude from being shown. If 'bads', the bad | |
| channels are excluded. Pass an empty list to plot all channels | |
| (including channels marked "bad", if any). | |
| .. versionadded:: 0.24.0 | |
| %(ax_plot_psd)s | |
| %(show)s | |
| %(n_jobs)s | |
| %(verbose)s | |
| %(method_kw_psd)s | |
| Returns | |
| ------- | |
| fig : instance of Figure | |
| Figure with frequency spectra of the data channels. | |
| Notes | |
| ----- | |
| %(notes_plot_psd_meth)s | |
| """ | |
| init_kw, plot_kw = _split_psd_kwargs(plot_fun=Spectrum.plot) | |
| return self.compute_psd(**init_kw).plot(**plot_kw) | |
| def plot_psd_topo( | |
| self, | |
| tmin=None, | |
| tmax=None, | |
| fmin=0, | |
| fmax=100, | |
| proj=False, | |
| *, | |
| method="auto", | |
| dB=True, | |
| layout=None, | |
| color="w", | |
| fig_facecolor="k", | |
| axis_facecolor="k", | |
| axes=None, | |
| block=False, | |
| show=True, | |
| n_jobs=None, | |
| verbose=None, | |
| **method_kw, | |
| ): | |
| """Plot power spectral density, separately for each channel. | |
| Parameters | |
| ---------- | |
| %(tmin_tmax_psd)s | |
| %(fmin_fmax_psd_topo)s | |
| %(proj_psd)s | |
| %(method_plot_psd_auto)s | |
| %(dB_spectrum_plot_topo)s | |
| %(layout_spectrum_plot_topo)s | |
| %(color_spectrum_plot_topo)s | |
| %(fig_facecolor)s | |
| %(axis_facecolor)s | |
| %(axes_spectrum_plot_topo)s | |
| %(block)s | |
| %(show)s | |
| %(n_jobs)s | |
| %(verbose)s | |
| %(method_kw_psd)s Defaults to ``dict(n_fft=2048)``. | |
| Returns | |
| ------- | |
| fig : instance of matplotlib.figure.Figure | |
| Figure distributing one image per channel across sensor topography. | |
| """ | |
| init_kw, plot_kw = _split_psd_kwargs(plot_fun=Spectrum.plot_topo) | |
| return self.compute_psd(**init_kw).plot_topo(**plot_kw) | |
| def plot_psd_topomap( | |
| self, | |
| bands=None, | |
| tmin=None, | |
| tmax=None, | |
| ch_type=None, | |
| *, | |
| proj=False, | |
| method="auto", | |
| normalize=False, | |
| agg_fun=None, | |
| dB=False, | |
| sensors=True, | |
| show_names=False, | |
| mask=None, | |
| mask_params=None, | |
| contours=0, | |
| outlines="head", | |
| sphere=None, | |
| image_interp=_INTERPOLATION_DEFAULT, | |
| extrapolate=_EXTRAPOLATE_DEFAULT, | |
| border=_BORDER_DEFAULT, | |
| res=64, | |
| size=1, | |
| cmap=None, | |
| vlim=(None, None), | |
| cnorm=None, | |
| colorbar=True, | |
| cbar_fmt="auto", | |
| units=None, | |
| axes=None, | |
| show=True, | |
| n_jobs=None, | |
| verbose=None, | |
| **method_kw, | |
| ): | |
| """Plot scalp topography of PSD for chosen frequency bands. | |
| Parameters | |
| ---------- | |
| %(bands_psd_topo)s | |
| %(tmin_tmax_psd)s | |
| %(ch_type_topomap_psd)s | |
| %(proj_psd)s | |
| %(method_plot_psd_auto)s | |
| %(normalize_psd_topo)s | |
| %(agg_fun_psd_topo)s | |
| %(dB_plot_topomap)s | |
| %(sensors_topomap)s | |
| %(show_names_topomap)s | |
| %(mask_evoked_topomap)s | |
| %(mask_params_topomap)s | |
| %(contours_topomap)s | |
| %(outlines_topomap)s | |
| %(sphere_topomap_auto)s | |
| %(image_interp_topomap)s | |
| %(extrapolate_topomap)s | |
| %(border_topomap)s | |
| %(res_topomap)s | |
| %(size_topomap)s | |
| %(cmap_topomap)s | |
| %(vlim_plot_topomap_psd)s | |
| %(cnorm)s | |
| .. versionadded:: 1.2 | |
| %(colorbar_topomap)s | |
| %(cbar_fmt_topomap_psd)s | |
| %(units_topomap)s | |
| %(axes_spectrum_plot_topomap)s | |
| %(show)s | |
| %(n_jobs)s | |
| %(verbose)s | |
| %(method_kw_psd)s | |
| Returns | |
| ------- | |
| fig : instance of Figure | |
| Figure showing one scalp topography per frequency band. | |
| """ | |
| init_kw, plot_kw = _split_psd_kwargs(plot_fun=Spectrum.plot_topomap) | |
| return self.compute_psd(**init_kw).plot_topomap(**plot_kw) | |
| def _set_legacy_nfft_default(self, tmin, tmax, method, method_kw): | |
| """Update method_kw with legacy n_fft default for plot_psd[_topo](). | |
| This method returns ``None`` and has a side effect of (maybe) updating | |
| the ``method_kw`` dict. | |
| """ | |
| if method == "welch" and method_kw.get("n_fft") is None: | |
| tm = _time_mask(self.times, tmin, tmax, sfreq=self.info["sfreq"]) | |
| method_kw["n_fft"] = min(np.sum(tm), 2048) | |
| class BaseSpectrum(ContainsMixin, UpdateChannelsMixin): | |
| """Base class for Spectrum and EpochsSpectrum.""" | |
| def __init__( | |
| self, | |
| inst, | |
| method, | |
| fmin, | |
| fmax, | |
| tmin, | |
| tmax, | |
| picks, | |
| exclude, | |
| proj, | |
| remove_dc, | |
| *, | |
| n_jobs, | |
| verbose=None, | |
| **method_kw, | |
| ): | |
| # arg checking | |
| self._sfreq = inst.info["sfreq"] | |
| if np.isfinite(fmax) and (fmax > self.sfreq / 2): | |
| raise ValueError( | |
| f"Requested fmax ({fmax} Hz) must not exceed ½ the sampling " | |
| f"frequency of the data ({0.5 * inst.info['sfreq']} Hz)." | |
| ) | |
| # method | |
| self._inst_type = type(inst) | |
| method = _validate_method(method, _get_instance_type_string(self)) | |
| psd_funcs = dict(welch=psd_array_welch, multitaper=psd_array_multitaper) | |
| # triage method and kwargs. partial() doesn't check validity of kwargs, | |
| # so we do it manually to save compute time if any are invalid. | |
| psd_funcs = dict(welch=psd_array_welch, multitaper=psd_array_multitaper) | |
| _check_method_kwargs(psd_funcs[method], method_kw, msg=f'PSD method "{method}"') | |
| self._psd_func = partial(psd_funcs[method], remove_dc=remove_dc, **method_kw) | |
| # apply proj if desired | |
| if proj: | |
| inst = inst.copy().apply_proj() | |
| self.inst = inst | |
| # prep times and picks | |
| self._time_mask = _time_mask(inst.times, tmin, tmax, sfreq=self.sfreq) | |
| self._picks = _picks_to_idx( | |
| inst.info, picks, "data", exclude, with_ref_meg=False | |
| ) | |
| # add the info object. bads and non-data channels were dropped by | |
| # _picks_to_idx() so we update the info accordingly: | |
| self.info = pick_info(inst.info, sel=self._picks, copy=True) | |
| # assign some attributes | |
| self.preload = True # needed for __getitem__, never False | |
| self._method = method | |
| # self._dims may also get updated by child classes | |
| self._dims = ( | |
| "channel", | |
| "freq", | |
| ) | |
| if method_kw.get("average", "") in (None, False): | |
| self._dims += ("segment",) | |
| if self._returns_complex_tapers(**method_kw): | |
| self._dims = self._dims[:-1] + ("taper",) + self._dims[-1:] | |
| # record data type (for repr and html_repr) | |
| self._data_type = ( | |
| "Fourier Coefficients" | |
| if method_kw.get("output") == "complex" | |
| else "Power Spectrum" | |
| ) | |
| # set nave (child constructor overrides this for Evoked input) | |
| self._nave = None | |
| def __eq__(self, other): | |
| """Test equivalence of two Spectrum instances.""" | |
| return object_diff(vars(self), vars(other)) == "" | |
| def __getstate__(self): | |
| """Prepare object for serialization.""" | |
| inst_type_str = _get_instance_type_string(self) | |
| out = dict( | |
| method=self.method, | |
| data=self._data, | |
| sfreq=self.sfreq, | |
| dims=self._dims, | |
| freqs=self.freqs, | |
| inst_type_str=inst_type_str, | |
| data_type=self._data_type, | |
| info=self.info, | |
| nave=self.nave, | |
| weights=self.weights, | |
| ) | |
| return out | |
| def __setstate__(self, state): | |
| """Unpack from serialized format.""" | |
| from ..epochs import Epochs | |
| from ..evoked import Evoked | |
| from ..io import Raw | |
| self._method = state["method"] | |
| self._data = state["data"] | |
| self._freqs = state["freqs"] | |
| self._dims = state["dims"] | |
| self._sfreq = state["sfreq"] | |
| self.info = Info(**state["info"]) | |
| self._data_type = state["data_type"] | |
| self._nave = state.get("nave") # objs saved before #11282 won't have `nave` | |
| self._weights = state.get("weights") # objs saved before #12747 won't have | |
| self.preload = True | |
| # instance type | |
| inst_types = dict(Raw=Raw, Epochs=Epochs, Evoked=Evoked, Array=np.ndarray) | |
| self._inst_type = inst_types[state["inst_type_str"]] | |
| def __repr__(self): | |
| """Build string representation of the Spectrum object.""" | |
| inst_type_str = _get_instance_type_string(self) | |
| # shape & dimension names | |
| dims = " × ".join( | |
| [f"{dim[0]} {dim[1]}s" for dim in zip(self.shape, self._dims)] | |
| ) | |
| freq_range = f"{self.freqs[0]:0.1f}-{self.freqs[-1]:0.1f} Hz" | |
| return ( | |
| f"<{self._data_type} (from {inst_type_str}, " | |
| f"{self.method} method) | {dims}, {freq_range}>" | |
| ) | |
| def _repr_html_(self, caption=None): | |
| """Build HTML representation of the Spectrum object.""" | |
| inst_type_str = _get_instance_type_string(self) | |
| units = [f"{ch_type}: {unit}" for ch_type, unit in self.units().items()] | |
| t = _get_html_template("repr", "spectrum.html.jinja") | |
| t = t.render( | |
| inst=self, computed_from=inst_type_str, units=units, filenames=None | |
| ) | |
| return t | |
| def _check_values(self): | |
| """Check PSD results for correct shape and bad values.""" | |
| assert len(self._dims) == self._data.ndim, (self._dims, self._data.ndim) | |
| assert self._data.shape == self._shape | |
| # TODO: should this be more fine-grained (report "chan X in epoch Y")? | |
| ch_dim = self._dims.index("channel") | |
| dims = list(range(self._data.ndim)) | |
| dims.pop(ch_dim) | |
| # take min() across all but the channel axis | |
| # (if the abs becomes memory intensive we could iterate over channels) | |
| use_data = self._data | |
| if use_data.dtype.kind == "c": | |
| use_data = np.abs(use_data) | |
| bad_value = use_data.min(axis=tuple(dims)) == 0 | |
| bad_value &= ~np.isin(self.ch_names, self.info["bads"]) | |
| if bad_value.any(): | |
| chs = np.array(self.ch_names)[bad_value].tolist() | |
| s = _pl(bad_value.sum()) | |
| warn(f"Zero value in spectrum for channel{s} {', '.join(chs)}", UserWarning) | |
| def _returns_complex_tapers(self, **method_kw): | |
| return self.method == "multitaper" and method_kw.get("output") == "complex" | |
| def _compute_spectra(self, data, fmin, fmax, n_jobs, method_kw, verbose): | |
| # make the spectra | |
| result = self._psd_func( | |
| data, self.sfreq, fmin=fmin, fmax=fmax, n_jobs=n_jobs, verbose=verbose | |
| ) | |
| # assign ._data (handling unaggregated multitaper output) | |
| if self._returns_complex_tapers(**method_kw): | |
| fourier_coefs, freqs, weights = result | |
| self._data = fourier_coefs | |
| self._weights = weights | |
| else: | |
| psds, freqs = result | |
| self._data = psds | |
| self._weights = None | |
| # assign properties (._data already assigned above) | |
| self._freqs = freqs | |
| # this is *expected* shape, it gets asserted later in _check_values() | |
| # (and then deleted afterwards) | |
| self._shape = (len(self.ch_names), len(self.freqs)) | |
| # append n_welch_segments (use "" as .get() default since None considered valid) | |
| if method_kw.get("average", "") in (None, False): | |
| n_welch_segments = _compute_n_welch_segments(data.shape[-1], method_kw) | |
| self._shape += (n_welch_segments,) | |
| # insert n_tapers | |
| if self._returns_complex_tapers(**method_kw): | |
| self._shape = self._shape[:-1] + (self._weights.size,) + self._shape[-1:] | |
| # we don't need these anymore, and they make save/load harder | |
| del self._picks | |
| del self._psd_func | |
| del self._time_mask | |
| def _detrend_picks(self): | |
| """Provide compatibility with __iter__.""" | |
| return list() | |
| def ch_names(self): | |
| return self.info["ch_names"] | |
| def data(self): | |
| return self._data | |
| def freqs(self): | |
| return self._freqs | |
| def method(self): | |
| return self._method | |
| def nave(self): | |
| return self._nave | |
| def nave(self, nave): | |
| self._nave = nave | |
| def weights(self): | |
| return self._weights | |
| def sfreq(self): | |
| return self._sfreq | |
| def shape(self): | |
| return self._data.shape | |
| def copy(self): | |
| """Return copy of the Spectrum instance. | |
| Returns | |
| ------- | |
| spectrum : instance of Spectrum | |
| A copy of the object. | |
| """ | |
| return deepcopy(self) | |
| def get_data( | |
| self, picks=None, exclude="bads", fmin=0, fmax=np.inf, return_freqs=False | |
| ): | |
| """Get spectrum data in NumPy array format. | |
| Parameters | |
| ---------- | |
| %(picks_good_data_noref)s | |
| %(exclude_spectrum_get_data)s | |
| %(fmin_fmax_psd)s | |
| return_freqs : bool | |
| Whether to return the frequency bin values for the requested | |
| frequency range. Default is ``False``. | |
| Returns | |
| ------- | |
| data : array | |
| The requested data in a NumPy array. | |
| freqs : array | |
| The frequency values for the requested range. Only returned if | |
| ``return_freqs`` is ``True``. | |
| """ | |
| picks = _picks_to_idx( | |
| self.info, picks, "data_or_ica", exclude=exclude, with_ref_meg=False | |
| ) | |
| fmin_idx = np.searchsorted(self.freqs, fmin) | |
| fmax_idx = np.searchsorted(self.freqs, fmax, side="right") | |
| freq_picks = np.arange(fmin_idx, fmax_idx) | |
| freq_axis = self._dims.index("freq") | |
| chan_axis = self._dims.index("channel") | |
| # normally there's a risk of np.take reducing array dimension if there | |
| # were only one channel or frequency selected, but `_picks_to_idx` | |
| # always returns an array of picks, and np.arange always returns an | |
| # array of freq bin indices, so we're safe; the result will always be | |
| # 2D. | |
| data = self._data.take(picks, chan_axis).take(freq_picks, freq_axis) | |
| if return_freqs: | |
| freqs = self._freqs[fmin_idx:fmax_idx] | |
| return (data, freqs) | |
| return data | |
| def plot( | |
| self, | |
| *, | |
| picks=None, | |
| average=False, | |
| dB=True, | |
| amplitude=False, | |
| xscale="linear", | |
| ci="sd", | |
| ci_alpha=0.3, | |
| color="black", | |
| alpha=None, | |
| spatial_colors=True, | |
| sphere=None, | |
| exclude=(), | |
| axes=None, | |
| show=True, | |
| ): | |
| """%(plot_psd_doc)s. | |
| Parameters | |
| ---------- | |
| %(picks_all_data_noref)s | |
| .. versionchanged:: 1.5 | |
| In version 1.5, the default behavior changed so that all | |
| :term:`data channels` (not just "good" data channels) are shown by | |
| default. | |
| average : bool | |
| Whether to average across channels before plotting. If ``True``, interactive | |
| plotting of scalp topography is disabled, and parameters ``ci`` and | |
| ``ci_alpha`` control the style of the confidence band around the mean. | |
| Default is ``False``. | |
| %(dB_spectrum_plot)s | |
| amplitude : bool | |
| Whether to plot an amplitude spectrum (``True``) or power spectrum | |
| (``False``). | |
| .. versionchanged:: 1.8 | |
| In version 1.8, the default changed to ``amplitude=False``. | |
| %(xscale_plot_psd)s | |
| ci : float | 'sd' | 'range' | None | |
| Type of confidence band drawn around the mean when ``average=True``. If | |
| ``'sd'`` the band spans ±1 standard deviation across channels. If | |
| ``'range'`` the band spans the range across channels at each frequency. If a | |
| :class:`float`, it indicates the (bootstrapped) confidence interval to | |
| display, and must satisfy ``0 < ci <= 100``. If ``None``, no band is drawn. | |
| Default is ``sd``. | |
| ci_alpha : float | |
| Opacity of the confidence band. Must satisfy ``0 <= ci_alpha <= 1``. Default | |
| is 0.3. | |
| %(color_plot_psd)s | |
| alpha : float | None | |
| Opacity of the spectrum line(s). If :class:`float`, must satisfy | |
| ``0 <= alpha <= 1``. If ``None``, opacity will be ``1`` when | |
| ``average=True`` and ``0.1`` when ``average=False``. Default is ``None``. | |
| %(spatial_colors_psd)s | |
| %(sphere_topomap_auto)s | |
| %(exclude_spectrum_plot)s | |
| .. versionchanged:: 1.5 | |
| In version 1.5, the default behavior changed from ``exclude='bads'`` to | |
| ``exclude=()``. | |
| %(axes_spectrum_plot_topomap)s | |
| %(show)s | |
| Returns | |
| ------- | |
| fig : instance of matplotlib.figure.Figure | |
| Figure with spectra plotted in separate subplots for each channel type. | |
| """ | |
| # Must nest this _mpl_figure import because of the BACKEND global | |
| # stuff | |
| from ..viz._mpl_figure import _line_figure, _split_picks_by_type | |
| # arg checking | |
| ci = _check_ci(ci) | |
| _check_option("xscale", xscale, ("log", "linear")) | |
| sphere = _check_sphere(sphere, self.info) | |
| # defaults | |
| scalings = _handle_default("scalings", None) | |
| titles = _handle_default("titles", None) | |
| units = _handle_default("units", None) | |
| _validate_type(amplitude, bool, "amplitude") | |
| estimate = "amplitude" if amplitude else "power" | |
| logger.info(f"Plotting {estimate} spectral density ({dB=}).") | |
| # split picks by channel type | |
| picks = _picks_to_idx( | |
| self.info, picks, "data", exclude=exclude, with_ref_meg=False | |
| ) | |
| (picks_list, units_list, scalings_list, titles_list) = _split_picks_by_type( | |
| self, picks, units, scalings, titles | |
| ) | |
| # prepare data (e.g. aggregate across dims, convert complex to power) | |
| psd_list = [ | |
| self._prepare_data_for_plot( | |
| self._data.take(_p, axis=self._dims.index("channel")) | |
| ) | |
| for _p in picks_list | |
| ] | |
| # initialize figure | |
| fig, axes = _line_figure(self, axes, picks=picks) | |
| # don't add ylabels & titles if figure has unexpected number of axes | |
| make_label = len(axes) == len(fig.axes) | |
| # Plot Frequency [Hz] xlabel only on the last axis | |
| xlabels_list = [False] * (len(axes) - 1) + [True] | |
| # plot | |
| _plot_psd( | |
| self, | |
| fig, | |
| self.freqs, | |
| psd_list, | |
| picks_list, | |
| titles_list, | |
| units_list, | |
| scalings_list, | |
| axes, | |
| make_label, | |
| color, | |
| area_mode=ci, | |
| area_alpha=ci_alpha, | |
| dB=dB, | |
| estimate=estimate, | |
| average=average, | |
| spatial_colors=spatial_colors, | |
| xscale=xscale, | |
| line_alpha=alpha, | |
| sphere=sphere, | |
| xlabels_list=xlabels_list, | |
| ) | |
| plt_show(show, fig) | |
| return fig | |
| def plot_topo( | |
| self, | |
| *, | |
| dB=True, | |
| layout=None, | |
| color="w", | |
| fig_facecolor="k", | |
| axis_facecolor="k", | |
| axes=None, | |
| block=False, | |
| show=True, | |
| ): | |
| """Plot power spectral density, separately for each channel. | |
| Parameters | |
| ---------- | |
| %(dB_spectrum_plot_topo)s | |
| %(layout_spectrum_plot_topo)s | |
| %(color_spectrum_plot_topo)s | |
| %(fig_facecolor)s | |
| %(axis_facecolor)s | |
| %(axes_spectrum_plot_topo)s | |
| %(block)s | |
| %(show)s | |
| Returns | |
| ------- | |
| fig : instance of matplotlib.figure.Figure | |
| Figure distributing one image per channel across sensor topography. | |
| """ | |
| if layout is None: | |
| layout = find_layout(self.info) | |
| psds, freqs = self.get_data(return_freqs=True) | |
| # prepare data (e.g. aggregate across dims, convert complex to power) | |
| psds = self._prepare_data_for_plot(psds) | |
| if dB: | |
| psds = 10 * np.log10(psds) | |
| y_label = "dB" | |
| else: | |
| y_label = "Power" | |
| show_func = partial( | |
| _plot_timeseries_unified, data=[psds], color=color, times=[freqs] | |
| ) | |
| click_func = partial(_plot_timeseries, data=[psds], color=color, times=[freqs]) | |
| picks = _pick_data_channels(self.info) | |
| info = pick_info(self.info, picks) | |
| fig = _plot_topo( | |
| info, | |
| times=freqs, | |
| show_func=show_func, | |
| click_func=click_func, | |
| layout=layout, | |
| axis_facecolor=axis_facecolor, | |
| fig_facecolor=fig_facecolor, | |
| x_label="Frequency (Hz)", | |
| unified=True, | |
| y_label=y_label, | |
| axes=axes, | |
| ) | |
| plt_show(show, block=block) | |
| return fig | |
| def plot_topomap( | |
| self, | |
| bands=None, | |
| ch_type=None, | |
| *, | |
| normalize=False, | |
| agg_fun=None, | |
| dB=False, | |
| sensors=True, | |
| show_names=False, | |
| mask=None, | |
| mask_params=None, | |
| contours=6, | |
| outlines="head", | |
| sphere=None, | |
| image_interp=_INTERPOLATION_DEFAULT, | |
| extrapolate=_EXTRAPOLATE_DEFAULT, | |
| border=_BORDER_DEFAULT, | |
| res=64, | |
| size=1, | |
| cmap=None, | |
| vlim=(None, None), | |
| cnorm=None, | |
| colorbar=True, | |
| cbar_fmt="auto", | |
| units=None, | |
| axes=None, | |
| show=True, | |
| ): | |
| """Plot scalp topography of PSD for chosen frequency bands. | |
| Parameters | |
| ---------- | |
| %(bands_psd_topo)s | |
| %(ch_type_topomap_psd)s | |
| %(normalize_psd_topo)s | |
| %(agg_fun_psd_topo)s | |
| %(dB_plot_topomap)s | |
| %(sensors_topomap)s | |
| %(show_names_topomap)s | |
| %(mask_evoked_topomap)s | |
| %(mask_params_topomap)s | |
| %(contours_topomap)s | |
| %(outlines_topomap)s | |
| %(sphere_topomap_auto)s | |
| %(image_interp_topomap)s | |
| %(extrapolate_topomap)s | |
| %(border_topomap)s | |
| %(res_topomap)s | |
| %(size_topomap)s | |
| %(cmap_topomap)s | |
| %(vlim_plot_topomap_psd)s | |
| %(cnorm)s | |
| %(colorbar_topomap)s | |
| %(cbar_fmt_topomap_psd)s | |
| %(units_topomap)s | |
| %(axes_spectrum_plot_topomap)s | |
| %(show)s | |
| Returns | |
| ------- | |
| fig : instance of Figure | |
| Figure showing one scalp topography per frequency band. | |
| """ | |
| ch_type = _get_plot_ch_type(self, ch_type) | |
| if units is None: | |
| units = _handle_default("units", None) | |
| unit = units[ch_type] if hasattr(units, "keys") else units | |
| scalings = _handle_default("scalings", None) | |
| scaling = scalings[ch_type] | |
| ( | |
| picks, | |
| pos, | |
| merge_channels, | |
| names, | |
| ch_type, | |
| sphere, | |
| clip_origin, | |
| ) = _prepare_topomap_plot(self, ch_type, sphere=sphere) | |
| outlines = _make_head_outlines(sphere, pos, outlines, clip_origin) | |
| psds, freqs = self.get_data(picks=picks, return_freqs=True) | |
| # prepare data (e.g. aggregate across dims, convert complex to power) | |
| psds = self._prepare_data_for_plot(psds) | |
| psds *= scaling**2 | |
| if merge_channels: | |
| psds, names = _merge_ch_data(psds, ch_type, names, method="mean") | |
| names = _prepare_sensor_names(names, show_names) | |
| return plot_psds_topomap( | |
| psds=psds, | |
| freqs=freqs, | |
| pos=pos, | |
| bands=bands, | |
| ch_type=ch_type, | |
| normalize=normalize, | |
| agg_fun=agg_fun, | |
| dB=dB, | |
| sensors=sensors, | |
| names=names, | |
| mask=mask, | |
| mask_params=mask_params, | |
| contours=contours, | |
| outlines=outlines, | |
| sphere=sphere, | |
| image_interp=image_interp, | |
| extrapolate=extrapolate, | |
| border=border, | |
| res=res, | |
| size=size, | |
| cmap=cmap, | |
| vlim=vlim, | |
| cnorm=cnorm, | |
| colorbar=colorbar, | |
| cbar_fmt=cbar_fmt, | |
| unit=unit, | |
| axes=axes, | |
| show=show, | |
| ) | |
| def _prepare_data_for_plot(self, data): | |
| # handle unaggregated Welch | |
| if "segment" in self._dims: | |
| logger.info("Aggregating Welch estimates (median) before plotting...") | |
| data = np.nanmedian(data, axis=self._dims.index("segment")) | |
| # handle unaggregated multitaper (also handles complex -> power) | |
| elif "taper" in self._dims: | |
| logger.info("Aggregating multitaper estimates before plotting...") | |
| data = _psd_from_mt(data, self.weights) | |
| # handle complex data (should only be Welch remaining) | |
| if np.iscomplexobj(data): | |
| data = (data * data.conj()).real # Scaling may be slightly off | |
| # handle epochs | |
| if "epoch" in self._dims: | |
| # XXX TODO FIXME decide how to properly aggregate across repeated | |
| # measures (epochs) and non-repeated but correlated measures | |
| # (channels) when calculating stddev or a CI. For across-channel | |
| # aggregation, doi:10.1007/s10162-012-0321-8 used hotellings T**2 | |
| # with a correction factor that estimated data rank using monte | |
| # carlo simulations; seems like we could use our own data rank | |
| # estimation methods to similar effect. Their exact approach used | |
| # complex spectra though, here we've already converted to power; | |
| # not sure if that makes an important difference? Anyway that | |
| # aggregation would need to happen in the _plot_psd function | |
| # though, not here... for now we just average like we always did. | |
| # only log message if averaging will actually have an effect | |
| if data.shape[0] > 1: | |
| logger.info("Averaging across epochs before plotting...") | |
| # epoch axis should always be the first axis | |
| data = data.mean(axis=0) | |
| return data | |
| def save(self, fname, *, overwrite=False, verbose=None): | |
| """Save spectrum data to disk (in HDF5 format). | |
| Parameters | |
| ---------- | |
| fname : path-like | |
| Path of file to save to. | |
| %(overwrite)s | |
| %(verbose)s | |
| See Also | |
| -------- | |
| mne.time_frequency.read_spectrum | |
| """ | |
| _, write_hdf5 = _import_h5io_funcs() | |
| check_fname(fname, "spectrum", (".h5", ".hdf5")) | |
| fname = _check_fname(fname, overwrite=overwrite, verbose=verbose) | |
| out = self.__getstate__() | |
| write_hdf5(fname, out, overwrite=overwrite, title="mnepython", slash="replace") | |
| def to_data_frame( | |
| self, picks=None, index=None, copy=True, long_format=False, *, verbose=None | |
| ): | |
| """Export data in tabular structure as a pandas DataFrame. | |
| Channels are converted to columns in the DataFrame. By default, | |
| an additional column "freq" is added, unless ``index='freq'`` | |
| (in which case frequency values form the DataFrame's index). | |
| Parameters | |
| ---------- | |
| %(picks_all)s | |
| index : str | list of str | None | |
| Kind of index to use for the DataFrame. If ``None``, a sequential | |
| integer index (:class:`pandas.RangeIndex`) will be used. If a | |
| :class:`str`, a :class:`pandas.Index` will be used (see Notes). If | |
| a list of two or more string values, a :class:`pandas.MultiIndex` | |
| will be used. Defaults to ``None``. | |
| %(copy_df)s | |
| %(long_format_df_spe)s | |
| %(verbose)s | |
| Returns | |
| ------- | |
| %(df_return)s | |
| Notes | |
| ----- | |
| Valid values for ``index`` depend on whether the Spectrum was created | |
| from continuous data (:class:`~mne.io.Raw`, :class:`~mne.Evoked`) or | |
| discontinuous data (:class:`~mne.Epochs`). For continuous data, only | |
| ``None`` or ``'freq'`` is supported. For discontinuous data, additional | |
| valid values are ``'epoch'`` and ``'condition'``, or a :class:`list` | |
| comprising some of the valid string values (e.g., | |
| ``['freq', 'epoch']``). | |
| """ | |
| # check pandas once here, instead of in each private utils function | |
| pd = _check_pandas_installed() # noqa | |
| # triage for Epoch-derived or unaggregated spectra | |
| from_epo = _get_instance_type_string(self) == "Epochs" | |
| unagg_welch = "segment" in self._dims | |
| unagg_mt = "taper" in self._dims | |
| # arg checking | |
| valid_index_args = ["freq"] | |
| if from_epo: | |
| valid_index_args += ["epoch", "condition"] | |
| index = _check_pandas_index_arguments(index, valid_index_args) | |
| # get data | |
| picks = _picks_to_idx(self.info, picks, "all", exclude=()) | |
| data = self.get_data(picks) | |
| if copy: | |
| data = data.copy() | |
| # reshape | |
| if unagg_mt: | |
| data = np.moveaxis(data, self._dims.index("freq"), -2) | |
| if from_epo: | |
| n_epochs, n_picks, n_freqs = data.shape[:3] | |
| else: | |
| n_epochs, n_picks, n_freqs = (1,) + data.shape[:2] | |
| n_segs = data.shape[-1] if unagg_mt or unagg_welch else 1 | |
| data = np.moveaxis(data, self._dims.index("channel"), -1) | |
| # at this point, should be ([epoch], freq, [segment/taper], channel) | |
| data = data.reshape(n_epochs * n_freqs * n_segs, n_picks) | |
| # prepare extra columns / multiindex | |
| mindex = list() | |
| default_index = list() | |
| if from_epo: | |
| rev_event_id = {v: k for k, v in self.event_id.items()} | |
| _conds = [rev_event_id[k] for k in self.events[:, 2]] | |
| conditions = np.repeat(_conds, n_freqs * n_segs) | |
| epoch_nums = np.repeat(self.selection, n_freqs * n_segs) | |
| mindex.extend([("condition", conditions), ("epoch", epoch_nums)]) | |
| default_index.extend(["condition", "epoch"]) | |
| freqs = np.tile(np.repeat(self.freqs, n_segs), n_epochs) | |
| mindex.append(("freq", freqs)) | |
| default_index.append("freq") | |
| if unagg_mt or unagg_welch: | |
| name = "taper" if unagg_mt else "segment" | |
| seg_nums = np.tile(np.arange(n_segs), n_epochs * n_freqs) | |
| mindex.append((name, seg_nums)) | |
| default_index.append(name) | |
| # build DataFrame | |
| df = _build_data_frame( | |
| self, data, picks, long_format, mindex, index, default_index=default_index | |
| ) | |
| return df | |
| def units(self, latex=False): | |
| """Get the spectrum units for each channel type. | |
| Parameters | |
| ---------- | |
| latex : bool | |
| Whether to format the unit strings as LaTeX. Default is ``False``. | |
| Returns | |
| ------- | |
| units : dict | |
| Mapping from channel type to a string representation of the units | |
| for that channel type. | |
| """ | |
| units = _handle_default("si_units", None) | |
| return { | |
| ch_type: _format_units_psd(units[ch_type], power=True, latex=latex) | |
| for ch_type in sorted(self.get_channel_types(unique=True)) | |
| } | |
| class Spectrum(BaseSpectrum): | |
| """Data object for spectral representations of continuous data. | |
| .. warning:: The preferred means of creating Spectrum objects from | |
| continuous or averaged data is via the instance methods | |
| :meth:`mne.io.Raw.compute_psd` or | |
| :meth:`mne.Evoked.compute_psd`. Direct class instantiation | |
| is not supported. | |
| Parameters | |
| ---------- | |
| inst : instance of Raw or Evoked | |
| The data from which to compute the frequency spectrum. | |
| %(method_psd_auto)s | |
| ``'auto'`` (default) uses Welch's method for continuous data | |
| and multitaper for :class:`~mne.Evoked` data. | |
| %(fmin_fmax_psd)s | |
| %(tmin_tmax_psd)s | |
| %(picks_good_data_noref)s | |
| %(exclude_psd)s | |
| %(proj_psd)s | |
| %(remove_dc)s | |
| %(reject_by_annotation_psd)s | |
| %(n_jobs)s | |
| %(verbose)s | |
| %(method_kw_psd)s | |
| Attributes | |
| ---------- | |
| ch_names : list | |
| The channel names. | |
| freqs : array | |
| Frequencies at which the amplitude, power, or fourier coefficients | |
| have been computed. | |
| %(info_not_none)s | |
| method : ``'welch'`` | ``'multitaper'`` | |
| The method used to compute the spectrum. | |
| nave : int | None | |
| The number of trials averaged together when generating the spectrum. ``None`` | |
| indicates no averaging is known to have occurred. | |
| weights : array | None | |
| The weights for each taper. Only present if spectra computed with | |
| ``method='multitaper'`` and ``output='complex'``. | |
| .. versionadded:: 1.8 | |
| See Also | |
| -------- | |
| EpochsSpectrum | |
| SpectrumArray | |
| mne.io.Raw.compute_psd | |
| mne.Epochs.compute_psd | |
| mne.Evoked.compute_psd | |
| References | |
| ---------- | |
| .. footbibliography:: | |
| """ | |
| def __init__( | |
| self, | |
| inst, | |
| method, | |
| fmin, | |
| fmax, | |
| tmin, | |
| tmax, | |
| picks, | |
| exclude, | |
| proj, | |
| remove_dc, | |
| reject_by_annotation, | |
| *, | |
| n_jobs, | |
| verbose=None, | |
| **method_kw, | |
| ): | |
| from ..io import BaseRaw | |
| # triage reading from file | |
| if isinstance(inst, dict): | |
| self.__setstate__(inst) | |
| return | |
| # do the basic setup | |
| super().__init__( | |
| inst, | |
| method, | |
| fmin, | |
| fmax, | |
| tmin, | |
| tmax, | |
| picks, | |
| exclude, | |
| proj, | |
| remove_dc, | |
| n_jobs=n_jobs, | |
| verbose=verbose, | |
| **method_kw, | |
| ) | |
| # get just the data we want | |
| if isinstance(self.inst, BaseRaw): | |
| start, stop = np.where(self._time_mask)[0][[0, -1]] | |
| rba = "NaN" if reject_by_annotation else None | |
| data = self.inst.get_data( | |
| self._picks, start, stop + 1, reject_by_annotation=rba | |
| ) | |
| if np.any(np.isnan(data)) and method == "multitaper": | |
| raise NotImplementedError( | |
| 'Cannot use method="multitaper" when reject_by_annotation=True. ' | |
| 'Please use method="welch" instead.' | |
| ) | |
| else: # Evoked | |
| data = self.inst.data[self._picks][:, self._time_mask] | |
| # set nave | |
| self._nave = getattr(inst, "nave", None) | |
| # compute the spectra | |
| self._compute_spectra(data, fmin, fmax, n_jobs, method_kw, verbose) | |
| # check for correct shape and bad values | |
| self._check_values() | |
| del self._shape # calculated from self._data henceforth | |
| # save memory | |
| del self.inst | |
| def __getitem__(self, item): | |
| """Get Spectrum data. | |
| Parameters | |
| ---------- | |
| item : int | slice | array-like | |
| Indexing is similar to a :class:`NumPy array<numpy.ndarray>`; see | |
| Notes. | |
| Returns | |
| ------- | |
| %(getitem_spectrum_return)s | |
| Notes | |
| ----- | |
| Integer-, list-, and slice-based indexing is possible: | |
| - ``spectrum[0]`` gives all frequency bins in the first channel | |
| - ``spectrum[:3]`` gives all frequency bins in the first 3 channels | |
| - ``spectrum[[0, 2], 5]`` gives the value in the sixth frequency bin of | |
| the first and third channels | |
| - ``spectrum[(4, 7)]`` is the same as ``spectrum[4, 7]``. | |
| .. note:: | |
| Unlike :class:`~mne.io.Raw` objects (which returns a tuple of the | |
| requested data values and the corresponding times), accessing | |
| :class:`~mne.time_frequency.Spectrum` values via subscript does | |
| **not** return the corresponding frequency bin values. If you need | |
| them, use ``spectrum.freqs[freq_indices]`` or | |
| ``spectrum.get_data(..., return_freqs=True)``. | |
| """ | |
| from ..io import BaseRaw | |
| self._parse_get_set_params = partial(BaseRaw._parse_get_set_params, self) | |
| return BaseRaw._getitem(self, item, return_times=False) | |
| def _check_data_shape(data, info, freqs, dim_names, weights, is_epoched): | |
| if data.ndim != len(dim_names): | |
| raise ValueError( | |
| f"Expected data to have {len(dim_names)} dimensions, got {data.ndim}." | |
| ) | |
| allowed_dims = ["epoch", "channel", "freq", "segment", "taper"] | |
| if not is_epoched: | |
| allowed_dims.remove("epoch") | |
| # TODO maybe we should be nice and allow plural versions of each dimname? | |
| for dim in dim_names: | |
| _check_option("dim_names", dim, allowed_dims) | |
| if "channel" not in dim_names or "freq" not in dim_names: | |
| raise ValueError("Both 'channel' and 'freq' must be present in `dim_names`.") | |
| if list(dim_names).index("channel") != int(is_epoched): | |
| raise ValueError( | |
| f"'channel' must be the {'second' if is_epoched else 'first'} dimension of " | |
| "the data." | |
| ) | |
| want_n_chan = _pick_data_channels(info, exclude=()).size | |
| got_n_chan = data.shape[list(dim_names).index("channel")] | |
| if got_n_chan != want_n_chan: | |
| raise ValueError( | |
| f"The number of channels in `data` ({got_n_chan}) must match the number of " | |
| f"good + bad data channels in `info` ({want_n_chan})." | |
| ) | |
| # given we limit max array size and ensure channel & freq dims present, only one of | |
| # taper or segment can be present | |
| if "taper" in dim_names: | |
| if dim_names[-2] != "taper": # _psd_from_mt assumes this (called when plotting) | |
| raise ValueError( | |
| "'taper' must be the second to last dimension of the data." | |
| ) | |
| # expect weights for each taper | |
| actual = None if weights is None else weights.size | |
| expected = data.shape[list(dim_names).index("taper")] | |
| if actual != expected: | |
| raise ValueError( | |
| f"Expected size of `weights` to be {expected} to match 'n_tapers' in " | |
| f"`data`, got {actual}." | |
| ) | |
| elif "segment" in dim_names and dim_names[-1] != "segment": | |
| raise ValueError("'segment' must be the last dimension of the data.") | |
| # freq being in wrong position ruled out by above checks | |
| want_n_freq = freqs.size | |
| got_n_freq = data.shape[list(dim_names).index("freq")] | |
| if got_n_freq != want_n_freq: | |
| raise ValueError( | |
| f"The number of frequencies in `data` ({got_n_freq}) must match the number " | |
| f"of elements in `freqs` ({want_n_freq})." | |
| ) | |
| class SpectrumArray(Spectrum): | |
| """Data object for precomputed spectral data (in NumPy array format). | |
| Parameters | |
| ---------- | |
| data : ndarray, shape (n_channels, [n_tapers], n_freqs, [n_segments]) | |
| The spectra for each channel. | |
| %(info_not_none)s | |
| %(freqs_tfr_array)s | |
| dim_names : tuple of str | |
| The name of the dimensions in the data, in the order they occur. Must contain | |
| ``'channel'`` and ``'freq'``; if data are unaggregated estimates, also include | |
| either a ``'segment'`` (e.g., Welch-like algorithms) or ``'taper'`` (e.g., | |
| multitaper algorithms) dimension. If including ``'taper'``, you should also pass | |
| a ``weights`` parameter. | |
| .. versionadded:: 1.8 | |
| weights : ndarray | None | |
| Weights for the ``'taper'`` dimension, if present (see ``dim_names``). | |
| .. versionadded:: 1.8 | |
| %(verbose)s | |
| See Also | |
| -------- | |
| mne.create_info | |
| mne.EvokedArray | |
| mne.io.RawArray | |
| EpochsSpectrumArray | |
| Notes | |
| ----- | |
| %(notes_spectrum_array)s | |
| .. versionadded:: 1.6 | |
| """ | |
| def __init__( | |
| self, | |
| data, | |
| info, | |
| freqs, | |
| dim_names=("channel", "freq"), | |
| weights=None, | |
| *, | |
| verbose=None, | |
| ): | |
| # (channel, [taper], freq, [segment]) | |
| _check_option("data.ndim", data.ndim, (2, 3)) # only allow one extra dimension | |
| _check_data_shape(data, info, freqs, dim_names, weights, is_epoched=False) | |
| self.__setstate__( | |
| dict( | |
| method="unknown", | |
| data=data, | |
| sfreq=info["sfreq"], | |
| dims=dim_names, | |
| freqs=freqs, | |
| inst_type_str="Array", | |
| data_type=( | |
| "Fourier Coefficients" | |
| if np.iscomplexobj(data) | |
| else "Power Spectrum" | |
| ), | |
| info=info, | |
| weights=weights, | |
| ) | |
| ) | |
| class EpochsSpectrum(BaseSpectrum, GetEpochsMixin): | |
| """Data object for spectral representations of epoched data. | |
| .. warning:: The preferred means of creating Spectrum objects from Epochs | |
| is via the instance method :meth:`mne.Epochs.compute_psd`. | |
| Direct class instantiation is not supported. | |
| Parameters | |
| ---------- | |
| inst : instance of Epochs | |
| The data from which to compute the frequency spectrum. | |
| %(method_psd)s | |
| %(fmin_fmax_psd)s | |
| %(tmin_tmax_psd)s | |
| %(picks_good_data_noref)s | |
| %(exclude_psd)s | |
| %(proj_psd)s | |
| %(remove_dc)s | |
| %(n_jobs)s | |
| %(verbose)s | |
| %(method_kw_psd)s | |
| Attributes | |
| ---------- | |
| ch_names : list | |
| The channel names. | |
| freqs : array | |
| Frequencies at which the amplitude, power, or fourier coefficients | |
| have been computed. | |
| %(info_not_none)s | |
| method : ``'welch'`` | ``'multitaper'`` | |
| The method used to compute the spectrum. | |
| weights : array | None | |
| The weights for each taper. Only present if spectra computed with | |
| ``method='multitaper'`` and ``output='complex'``. | |
| .. versionadded:: 1.8 | |
| See Also | |
| -------- | |
| EpochsSpectrumArray | |
| Spectrum | |
| mne.Epochs.compute_psd | |
| References | |
| ---------- | |
| .. footbibliography:: | |
| """ | |
| def __init__( | |
| self, | |
| inst, | |
| method, | |
| fmin, | |
| fmax, | |
| tmin, | |
| tmax, | |
| picks, | |
| exclude, | |
| proj, | |
| remove_dc, | |
| *, | |
| n_jobs, | |
| verbose=None, | |
| **method_kw, | |
| ): | |
| # triage reading from file | |
| if isinstance(inst, dict): | |
| self.__setstate__(inst) | |
| return | |
| # do the basic setup | |
| super().__init__( | |
| inst, | |
| method, | |
| fmin, | |
| fmax, | |
| tmin, | |
| tmax, | |
| picks, | |
| exclude, | |
| proj, | |
| remove_dc, | |
| n_jobs=n_jobs, | |
| verbose=verbose, | |
| **method_kw, | |
| ) | |
| # get just the data we want | |
| data = self.inst._get_data(picks=self._picks, on_empty="raise")[ | |
| :, :, self._time_mask | |
| ] | |
| # compute the spectra | |
| self._compute_spectra(data, fmin, fmax, n_jobs, method_kw, verbose) | |
| self._dims = ("epoch",) + self._dims | |
| self._shape = (len(self.inst),) + self._shape | |
| # check for correct shape and bad values | |
| self._check_values() | |
| del self._shape | |
| # we need these for to_data_frame() | |
| self.event_id = self.inst.event_id.copy() | |
| self.events = self.inst.events.copy() | |
| self.selection = self.inst.selection.copy() | |
| # we need these for __getitem__() | |
| self.drop_log = deepcopy(self.inst.drop_log) | |
| self._metadata = self.inst.metadata | |
| # save memory | |
| del self.inst | |
| def __getitem__(self, item): | |
| """Subselect epochs from an EpochsSpectrum. | |
| Parameters | |
| ---------- | |
| item : int | slice | array-like | str | |
| Access options are the same as for :class:`~mne.Epochs` objects, | |
| see the docstring of :meth:`mne.Epochs.__getitem__` for | |
| explanation. | |
| Returns | |
| ------- | |
| %(getitem_epochspectrum_return)s | |
| """ | |
| return super().__getitem__(item) | |
| def __getstate__(self): | |
| """Prepare object for serialization.""" | |
| out = super().__getstate__() | |
| out.update( | |
| metadata=self._metadata, | |
| drop_log=self.drop_log, | |
| event_id=self.event_id, | |
| events=self.events, | |
| selection=self.selection, | |
| ) | |
| return out | |
| def __setstate__(self, state): | |
| """Unpack from serialized format.""" | |
| super().__setstate__(state) | |
| self._metadata = state["metadata"] | |
| self.drop_log = state["drop_log"] | |
| self.event_id = state["event_id"] | |
| self.events = state["events"] | |
| self.selection = state["selection"] | |
| def average(self, method="mean"): | |
| """Average the spectra across epochs. | |
| Parameters | |
| ---------- | |
| method : 'mean' | 'median' | callable | |
| How to aggregate spectra across epochs. If callable, must take a | |
| :class:`NumPy array<numpy.ndarray>` of shape | |
| ``(n_epochs, n_channels, n_freqs)`` and return an array of shape | |
| ``(n_channels, n_freqs)``. Default is ``'mean'``. | |
| Returns | |
| ------- | |
| spectrum : instance of Spectrum | |
| The aggregated spectrum object. | |
| """ | |
| _validate_type(method, ("str", "callable"), "method") | |
| method = _make_combine_callable( | |
| method, axis=0, valid=("mean", "median"), keepdims=False | |
| ) | |
| if not callable(method): | |
| raise ValueError( | |
| '"method" must be a valid string or callable, ' | |
| f"got a {type(method).__name__} ({method})." | |
| ) | |
| # averaging unaggregated spectral estimates are not supported | |
| if "segment" in self._dims: | |
| raise NotImplementedError( | |
| "Averaging individual Welch segments across epochs is not " | |
| "supported. Consider averaging the signals before computing " | |
| "the Welch spectrum estimates." | |
| ) | |
| if "taper" in self._dims: | |
| raise NotImplementedError( | |
| "Averaging multitaper tapers across epochs is not supported. Consider " | |
| "averaging the signals before computing the complex spectrum." | |
| ) | |
| # serialize the object and update data, dims, and data type | |
| state = super().__getstate__() | |
| state["nave"] = state["data"].shape[0] | |
| state["data"] = method(state["data"]) | |
| state["dims"] = state["dims"][1:] | |
| state["data_type"] = f"Averaged {state['data_type']}" | |
| defaults = dict( | |
| method=None, | |
| fmin=None, | |
| fmax=None, | |
| tmin=None, | |
| tmax=None, | |
| picks=None, | |
| exclude=(), | |
| proj=None, | |
| remove_dc=None, | |
| reject_by_annotation=None, | |
| n_jobs=None, | |
| verbose=None, | |
| ) | |
| return Spectrum(state, **defaults) | |
| class EpochsSpectrumArray(EpochsSpectrum): | |
| """Data object for precomputed epoched spectral data (in NumPy array format). | |
| Parameters | |
| ---------- | |
| data : ndarray, shape (n_epochs, n_channels, [n_tapers], n_freqs, [n_segments]) | |
| The spectra for each channel in each epoch. | |
| %(info_not_none)s | |
| %(freqs_tfr_array)s | |
| %(events_epochs)s | |
| %(event_id)s | |
| dim_names : tuple of str | |
| The name of the dimensions in the data, in the order they occur. Must contain | |
| ``'channel'`` and ``'freq'``; if data are unaggregated estimates, also include | |
| either a ``'segment'`` (e.g., Welch-like algorithms) or ``'taper'`` (e.g., | |
| multitaper algorithms) dimension. If including ``'taper'``, you should also pass | |
| a ``weights`` parameter. | |
| .. versionadded:: 1.8 | |
| weights : ndarray | None | |
| Weights for the ``'taper'`` dimension, if present (see ``dim_names``). | |
| .. versionadded:: 1.8 | |
| %(verbose)s | |
| See Also | |
| -------- | |
| mne.create_info | |
| mne.EpochsArray | |
| SpectrumArray | |
| Notes | |
| ----- | |
| %(notes_spectrum_array)s | |
| .. versionadded:: 1.6 | |
| """ | |
| def __init__( | |
| self, | |
| data, | |
| info, | |
| freqs, | |
| events=None, | |
| event_id=None, | |
| dim_names=("epoch", "channel", "freq"), | |
| weights=None, | |
| *, | |
| verbose=None, | |
| ): | |
| # (epoch, channel, [taper], freq, [segment]) | |
| _check_option("data.ndim", data.ndim, (3, 4)) # only allow one extra dimension | |
| if list(dim_names).index("epoch") != 0: | |
| raise ValueError("'epoch' must be the first dimension of `data`.") | |
| if events is not None and data.shape[0] != events.shape[0]: | |
| raise ValueError( | |
| f"The first dimension of `data` ({data.shape[0]}) must match the first " | |
| f"dimension of `events` ({events.shape[0]})." | |
| ) | |
| _check_data_shape(data, info, freqs, dim_names, weights, is_epoched=True) | |
| self.__setstate__( | |
| dict( | |
| method="unknown", | |
| data=data, | |
| sfreq=info["sfreq"], | |
| dims=dim_names, | |
| freqs=freqs, | |
| inst_type_str="Array", | |
| data_type=( | |
| "Fourier Coefficients" | |
| if np.iscomplexobj(data) | |
| else "Power Spectrum" | |
| ), | |
| info=info, | |
| events=events, | |
| event_id=event_id, | |
| metadata=None, | |
| selection=np.arange(data.shape[0]), | |
| drop_log=tuple(tuple() for _ in range(data.shape[0])), | |
| weights=weights, | |
| ) | |
| ) | |
| def combine_spectrum(all_spectrum, weights="nave"): | |
| """Merge spectral data by weighted addition. | |
| Create a new :class:`mne.time_frequency.Spectrum` instance, using a combination of | |
| the supplied instances as its data. By default, the mean (weighted by trials) is | |
| used. Subtraction can be performed by passing negative weights (e.g., ``[1, -1]``). | |
| Data must have the same channels and the same frequencies. | |
| Parameters | |
| ---------- | |
| all_spectrum : list of Spectrum | |
| The Spectrum objects. | |
| weights : list of float | str | |
| The weights to apply to the data of each :class:`~mne.time_frequency.Spectrum` | |
| instance, or a string describing the weighting strategy to apply: 'nave' | |
| computes sum-to-one weights proportional to each object’s nave attribute; | |
| 'equal' weights each :class:`~mne.time_frequency.Spectrum` by | |
| ``1 / len(all_spectrum)``. | |
| Returns | |
| ------- | |
| spectrum : Spectrum | |
| The new spectral data. | |
| Notes | |
| ----- | |
| .. versionadded:: 1.10.0 | |
| """ | |
| spectrum = all_spectrum[0].copy() | |
| if isinstance(weights, str): | |
| if weights not in ("nave", "equal"): | |
| raise ValueError('Weights must be a list of float, or "nave" or "equal"') | |
| if weights == "nave": | |
| for s_ in all_spectrum: | |
| if s_.nave is None: | |
| raise ValueError(f"The 'nave' attribute is not specified for {s_}") | |
| weights = np.array([e.nave for e in all_spectrum], float) | |
| weights /= weights.sum() | |
| else: # == 'equal' | |
| weights = [1.0 / len(all_spectrum)] * len(all_spectrum) | |
| weights = np.array(weights, float) | |
| if weights.ndim != 1 or weights.size != len(all_spectrum): | |
| raise ValueError("Weights must be the same size as all_spectrum") | |
| ch_names = spectrum.ch_names | |
| for s_ in all_spectrum[1:]: | |
| assert s_.ch_names == ch_names, ( | |
| f"{spectrum} and {s_} do not contain the same channels" | |
| ) | |
| assert np.max(np.abs(s_.freqs - spectrum.freqs)) < 1e-7, ( | |
| f"{spectrum} and {s_} do not contain the same frequencies" | |
| ) | |
| # use union of bad channels | |
| bads = list( | |
| set(spectrum.info["bads"]).union(*(s_.info["bads"] for s_ in all_spectrum[1:])) | |
| ) | |
| spectrum.info["bads"] = bads | |
| # combine spectral data | |
| spectrum._data = sum(w * s_.data for w, s_ in zip(weights, all_spectrum)) | |
| if spectrum.nave is not None: | |
| spectrum._nave = max( | |
| int(1.0 / sum(w**2 / s_.nave for w, s_ in zip(weights, all_spectrum))), 1 | |
| ) | |
| return spectrum | |
| def read_spectrum(fname): | |
| """Load a :class:`mne.time_frequency.Spectrum` object from disk. | |
| Parameters | |
| ---------- | |
| fname : path-like | |
| Path to a spectrum file in HDF5 format, which should end with ``.h5`` or | |
| ``.hdf5``. | |
| Returns | |
| ------- | |
| spectrum : instance of Spectrum | |
| The loaded Spectrum object. | |
| See Also | |
| -------- | |
| mne.time_frequency.Spectrum.save | |
| """ | |
| read_hdf5, _ = _import_h5io_funcs() | |
| _validate_type(fname, "path-like", "fname") | |
| fname = _check_fname(fname=fname, overwrite="read", must_exist=False) | |
| # read it in | |
| hdf5_dict = read_hdf5(fname, title="mnepython", slash="replace") | |
| defaults = dict( | |
| method=None, | |
| fmin=None, | |
| fmax=None, | |
| tmin=None, | |
| tmax=None, | |
| picks=None, | |
| exclude=(), | |
| proj=None, | |
| remove_dc=None, | |
| reject_by_annotation=None, | |
| n_jobs=None, | |
| verbose=None, | |
| ) | |
| Klass = EpochsSpectrum if hdf5_dict["inst_type_str"] == "Epochs" else Spectrum | |
| return Klass(hdf5_dict, **defaults) | |
| def _check_ci(ci): | |
| ci = "sd" if ci == "std" else ci # be forgiving | |
| if _is_numeric(ci): | |
| if not (0 < ci <= 100): | |
| raise ValueError(f"ci must satisfy 0 < ci <= 100, got {ci}") | |
| ci /= 100.0 | |
| else: | |
| _check_option("ci", ci, [None, "sd", "range"]) | |
| return ci | |
| def _compute_n_welch_segments(n_times, method_kw): | |
| # get default values from psd_array_welch | |
| _defaults = dict() | |
| for param in ("n_fft", "n_per_seg", "n_overlap"): | |
| _defaults[param] = signature(psd_array_welch).parameters[param].default | |
| # override defaults with user-specified values | |
| for key, val in _defaults.items(): | |
| _defaults.update({key: method_kw.get(key, val)}) | |
| # sanity check values / replace `None`s with real numbers | |
| n_fft, n_per_seg, n_overlap = _check_nfft(n_times, **_defaults) | |
| # compute expected number of segments | |
| step = n_per_seg - n_overlap | |
| return (n_times - n_overlap) // step | |
| def _validate_method(method, instance_type): | |
| """Convert 'auto' to a real method name, and validate.""" | |
| if method == "auto": | |
| method = "welch" if instance_type.startswith("Raw") else "multitaper" | |
| _check_option("method", method, ("welch", "multitaper")) | |
| return method | |