Spaces:
Running
Running
| """A module which implements the time-frequency estimation. | |
| Morlet code inspired by Matlab code from Sheraz Khan & Brainstorm & SPM | |
| """ | |
| # Authors: The MNE-Python contributors. | |
| # License: BSD-3-Clause | |
| # Copyright the MNE-Python contributors. | |
| import inspect | |
| from copy import deepcopy | |
| from functools import partial | |
| import matplotlib.pyplot as plt | |
| import numpy as np | |
| from scipy.fft import fft, ifft | |
| from scipy.signal import argrelmax | |
| from .._fiff.meas_info import ContainsMixin, Info | |
| from .._fiff.pick import _picks_to_idx, pick_info | |
| from ..baseline import _check_baseline, rescale | |
| from ..channels.channels import UpdateChannelsMixin | |
| from ..channels.layout import _find_topomap_coords, _merge_ch_data, _pair_grad_sensors | |
| from ..defaults import _BORDER_DEFAULT, _EXTRAPOLATE_DEFAULT, _INTERPOLATION_DEFAULT | |
| from ..filter import next_fast_len | |
| from ..parallel import parallel_func | |
| from ..utils import ( | |
| ExtendedTimeMixin, | |
| GetEpochsMixin, | |
| SizeMixin, | |
| _build_data_frame, | |
| _check_combine, | |
| _check_event_id, | |
| _check_fname, | |
| _check_method_kwargs, | |
| _check_option, | |
| _check_pandas_index_arguments, | |
| _check_pandas_installed, | |
| _check_time_format, | |
| _convert_times, | |
| _ensure_events, | |
| _freq_mask, | |
| _import_h5io_funcs, | |
| _is_numeric, | |
| _pl, | |
| _prepare_read_metadata, | |
| _prepare_write_metadata, | |
| _time_mask, | |
| _validate_type, | |
| check_fname, | |
| copy_doc, | |
| copy_function_doc_to_method_doc, | |
| fill_doc, | |
| legacy, | |
| logger, | |
| object_diff, | |
| repr_html, | |
| sizeof_fmt, | |
| verbose, | |
| warn, | |
| ) | |
| from ..utils.spectrum import _get_instance_type_string | |
| from ..viz.topo import _imshow_tfr, _imshow_tfr_unified, _plot_topo | |
| from ..viz.topomap import ( | |
| _add_colorbar, | |
| _get_pos_outlines, | |
| _set_contour_locator, | |
| plot_tfr_topomap, | |
| plot_topomap, | |
| ) | |
| from ..viz.utils import ( | |
| _make_combine_callable, | |
| _prepare_joint_axes, | |
| _set_title_multiple_electrodes, | |
| _setup_cmap, | |
| _setup_vmin_vmax, | |
| add_background_image, | |
| figure_nobar, | |
| plt_show, | |
| ) | |
| from .multitaper import dpss_windows, tfr_array_multitaper | |
| from .spectrum import EpochsSpectrum | |
def morlet(sfreq, freqs, n_cycles=7.0, sigma=None, zero_mean=False):
    """Compute Morlet wavelets for the given frequency range.

    Parameters
    ----------
    sfreq : float
        The sampling Frequency.
    freqs : float | array-like, shape (n_freqs,)
        Frequencies to compute Morlet wavelets for.
    n_cycles : float | array-like, shape (n_freqs,)
        Number of cycles. Can be a fixed number (float) or one per frequency
        (array-like).
    sigma : float, default None
        It controls the width of the wavelet ie its temporal
        resolution. If sigma is None the temporal resolution
        is adapted with the frequency like for all wavelet transform.
        The higher the frequency the shorter is the wavelet.
        If sigma is fixed the temporal resolution is fixed
        like for the short time Fourier transform and the number
        of oscillations increases with the frequency.
    zero_mean : bool, default False
        Make sure the wavelet has a mean of zero.

    Returns
    -------
    Ws : list of ndarray | ndarray
        The wavelets time series. If ``freqs`` was a float, a single
        ndarray is returned instead of a list of ndarray.

    Raises
    ------
    ValueError
        If any frequency is not strictly positive, if ``freqs`` has more
        than one dimension, or if ``n_cycles`` is neither a single value nor
        one value per frequency.

    See Also
    --------
    mne.time_frequency.fwhm

    Notes
    -----
    %(morlet_reference)s
    %(fwhm_morlet_notes)s

    References
    ----------
    .. footbibliography::

    Examples
    --------
    Let's show a simple example of the relationship between ``n_cycles`` and
    the FWHM using :func:`mne.time_frequency.fwhm`:

    .. plot::

        import numpy as np
        import matplotlib.pyplot as plt
        from mne.time_frequency import morlet, fwhm

        sfreq, freq, n_cycles = 1000., 10, 7  # i.e., 700 ms
        this_fwhm = fwhm(freq, n_cycles)
        wavelet = morlet(sfreq=sfreq, freqs=freq, n_cycles=n_cycles)
        M, w = len(wavelet), n_cycles # convert to SciPy convention
        s = w * sfreq / (2 * freq * np.pi)  # from SciPy docs

        _, ax = plt.subplots(layout="constrained")
        colors = dict(real="#66CCEE", imag="#EE6677")
        t = np.arange(-M // 2 + 1, M // 2 + 1) / sfreq
        for kind in ('real', 'imag'):
            ax.plot(
                t, getattr(wavelet, kind), label=kind, color=colors[kind],
            )
        ax.plot(t, np.abs(wavelet), label='abs', color='k', lw=1., zorder=6)
        half_max = np.max(np.abs(wavelet)) / 2.
        ax.plot([-this_fwhm / 2., this_fwhm / 2.], [half_max, half_max],
                color='k', linestyle='-', label='FWHM', zorder=6)
        ax.legend(loc='upper right')
        ax.set(xlabel='Time (s)', ylabel='Amplitude')
    """  # noqa: E501
    n_cycles = np.array(n_cycles, float).ravel()

    freqs = np.array(freqs, float)
    if np.any(freqs <= 0):
        raise ValueError("all frequencies in 'freqs' must be greater than 0.")
    _check_option("freqs.ndim", freqs.ndim, [0, 1])

    # Promote a scalar frequency to a 1-element array *before* comparing
    # sizes, so that a scalar ``freqs`` combined with several ``n_cycles``
    # yields the intended ValueError rather than failing on ``len()`` of a
    # 0-d array.
    singleton = freqs.ndim == 0
    if singleton:
        freqs = freqs[np.newaxis]

    if n_cycles.size not in (1, freqs.size):
        raise ValueError("n_cycles should be fixed or defined for each frequency.")
    if n_cycles.size == 1:
        # one shared value: repeat it so both arrays can be iterated together
        n_cycles = np.full(freqs.shape, n_cycles[0])

    Ws = list()
    for f, this_n_cycles in zip(freqs, n_cycles):
        # sigma_t is the stddev of gaussian window in the time domain; can be
        # scale-dependent or fixed across freqs
        if sigma is None:
            sigma_t = this_n_cycles / (2.0 * np.pi * f)
        else:
            sigma_t = this_n_cycles / (2.0 * np.pi * sigma)
        # time vector. We go 5 standard deviations out to make sure we're
        # *very* close to zero at the ends. We also make sure that there's a
        # sample at exactly t=0
        t = np.arange(0.0, 5.0 * sigma_t, 1.0 / sfreq)
        t = np.r_[-t[::-1], t[1:]]
        oscillation = np.exp(2.0 * 1j * np.pi * f * t)
        if zero_mean:
            # this offset is equivalent to the κ_σ term in Wikipedia's
            # equations, and satisfies the "admissibility criterion" for CWTs
            real_offset = np.exp(-2 * (np.pi * f * sigma_t) ** 2)
            oscillation -= real_offset
        gaussian_envelope = np.exp(-(t**2) / (2.0 * sigma_t**2))
        W = oscillation * gaussian_envelope
        # the scaling factor here is proportional to what is used in
        # Tallon-Baudry 1997: (sigma_t*sqrt(pi))^(-1/2).  It yields a wavelet
        # with norm sqrt(2) for the full wavelet / norm 1 for the real part
        W /= np.sqrt(0.5) * np.linalg.norm(W.ravel())
        Ws.append(W)
    if singleton:
        Ws = Ws[0]
    return Ws
def fwhm(freq, n_cycles):
    """Return the full-width at half maximum of a Morlet wavelet.

    The formula is taken from :footcite:t:`Cohen2019`.

    Parameters
    ----------
    freq : float
        Oscillation frequency of the wavelet.
    n_cycles : float
        Wavelet duration, given as a number of oscillation cycles.

    Returns
    -------
    fwhm : float
        The full-width half maximum of the wavelet.

    Notes
    -----
     .. versionadded:: 1.3

    References
    ----------
    .. footbibliography::
    """
    # sqrt(2 * ln 2) is the constant factor of the formula
    half_max_factor = np.sqrt(2 * np.log(2))
    angular_term = np.pi * freq
    return n_cycles * half_max_factor / angular_term
def _make_dpss(
    sfreq,
    freqs,
    n_cycles=7.0,
    time_bandwidth=4.0,
    zero_mean=False,
    return_weights=False,
):
    """Compute DPSS tapers for the given frequency range.

    Parameters
    ----------
    sfreq : float
        The sampling frequency.
    freqs : ndarray, shape (n_freqs,)
        The frequencies in Hz.
    n_cycles : float | ndarray, shape (n_freqs,), default 7.
        The number of cycles globally or for each frequency.
    time_bandwidth : float, default 4.0
        Time x Bandwidth product.
        The number of good tapers (low-bias) is chosen automatically based on
        this to equal floor(time_bandwidth - 1).
        Default is 4.0, giving 3 good tapers.
    zero_mean : bool | None, default False
        Make sure the wavelet has a mean of zero.
    return_weights : bool
        Whether to return the concentration weights.

    Returns
    -------
    Ws : list of array
        The wavelets time series, indexed as ``Ws[taper][freq]``.
    Cs : list of array
        The concentration weights, indexed as ``Cs[taper][freq]``. Only
        returned if return_weights=True.

    Raises
    ------
    ValueError
        If any frequency is not strictly positive, if ``time_bandwidth`` is
        below 2.0, or if ``n_cycles`` is neither a single value nor one value
        per frequency.
    """
    freqs = np.array(freqs)
    if np.any(freqs <= 0):
        raise ValueError("all frequencies in 'freqs' must be greater than 0.")

    if time_bandwidth < 2.0:
        raise ValueError("time_bandwidth should be >= 2.0 for good tapers")
    n_taps = int(np.floor(time_bandwidth - 1))
    n_cycles = np.atleast_1d(n_cycles)

    if n_cycles.size != 1 and n_cycles.size != len(freqs):
        raise ValueError("n_cycles should be fixed or defined for each frequency.")

    # One (initially empty) list of wavelets / weights per taper
    Ws = [list() for _ in range(n_taps)]
    Cs = [list() for _ in range(n_taps)]
    for k, f in enumerate(freqs):
        this_n_cycles = n_cycles[k] if len(n_cycles) != 1 else n_cycles[0]

        t_win = this_n_cycles / float(f)
        t = np.arange(0.0, t_win, 1.0 / sfreq)
        # Making sure wavelets are centered before tapering
        oscillation = np.exp(2.0 * 1j * np.pi * f * (t - t_win / 2.0))

        # Get dpss tapers. The call does not depend on the taper index, so it
        # is done once per frequency and shared by all tapers below.
        tapers, conc = dpss_windows(
            t.shape[0], time_bandwidth / 2.0, n_taps, sym=False
        )

        for m in range(n_taps):
            Wk = oscillation * tapers[m]
            if zero_mean:  # to make it zero mean
                real_offset = Wk.mean()
                Wk -= real_offset
            Wk /= np.sqrt(0.5) * np.linalg.norm(Wk.ravel())
            Ck = np.sqrt(conc[m])

            Ws[m].append(Wk)
            Cs[m].append(Ck)

    if return_weights:
        return Ws, Cs
    return Ws
| # Low level convolution | |
| def _get_nfft(wavelets, X, use_fft=True, check=True): | |
| n_times = X.shape[-1] | |
| max_size = max(w.size for w in wavelets) | |
| if max_size > n_times: | |
| msg = ( | |
| f"At least one of the wavelets ({max_size}) is longer than the " | |
| f"signal ({n_times}). Consider using a longer signal or " | |
| "shorter wavelets." | |
| ) | |
| if check: | |
| if use_fft: | |
| warn(msg, UserWarning) | |
| else: | |
| raise ValueError(msg) | |
| nfft = n_times + max_size - 1 | |
| nfft = next_fast_len(nfft) # 2 ** int(np.ceil(np.log2(nfft))) | |
| return nfft | |
def _cwt_gen(X, Ws, *, fsize=0, mode="same", decim=1, use_fft=True):
    """Compute cwt with fft based convolutions or temporal convolutions.

    This is a generator: it produces one time-frequency array per row of
    ``X``.

    Parameters
    ----------
    X : array of shape (n_signals, n_times)
        The data.
    Ws : list of array
        Wavelets time series.
    fsize : int
        FFT length. Only used when ``use_fft`` is True.
    mode : {'full', 'valid', 'same'}
        See numpy.convolve.
    decim : int | slice, default 1
        To reduce memory usage, decimation factor after time-frequency
        decomposition.
        If `int`, returns tfr[..., ::decim].
        If `slice`, returns tfr[..., decim].

        .. note:: Decimation may create aliasing artifacts.
    use_fft : bool, default True
        Use the FFT for convolutions or not.

    Yields
    ------
    tfr : array, shape (n_freqs, n_time_decim)
        The time-frequency transform of one signal. The same array object is
        overwritten and yielded again for every signal, so a consumer that
        needs to keep a result must copy it before advancing the generator.
    """
    _check_option("mode", mode, ["same", "valid", "full"])
    decim = _ensure_slice(decim)
    X = np.asarray(X)

    # Precompute wavelets for given frequency range to save time
    _, n_times = X.shape
    # number of time samples left after applying the decimation slice
    n_times_out = X[:, decim].shape[1]
    n_freqs = len(Ws)

    # precompute FFTs of Ws
    if use_fft:
        fft_Ws = np.empty((n_freqs, fsize), dtype=np.complex128)
        for i, W in enumerate(Ws):
            fft_Ws[i] = fft(W, fsize)

    # Make generator looping across signals
    # (single output buffer, allocated once and reused for every signal)
    tfr = np.zeros((n_freqs, n_times_out), dtype=np.complex128)
    for x in X:
        if use_fft:
            fft_x = fft(x, fsize)

        # Loop across wavelets
        for ii, W in enumerate(Ws):
            if use_fft:
                # product of the spectra, truncated to the length of the
                # full linear convolution
                ret = ifft(fft_x * fft_Ws[ii])[: n_times + W.size - 1]
            else:
                # Work around multarray.correlate->OpenBLAS bug on ppc64le
                # ret = np.correlate(x, W, mode=mode)
                ret = np.convolve(x, W.real, mode=mode) + 1j * np.convolve(
                    x, W.imag, mode=mode
                )

            # Center and decimate decomposition
            if mode == "valid":
                sz = int(abs(W.size - n_times)) + 1
                offset = (n_times - sz) // 2
                # NOTE(review): relies on decim.step being an integer here;
                # presumably guaranteed by _ensure_slice -- confirm
                this_slice = slice(offset // decim.step, (offset + sz) // decim.step)
                if use_fft:
                    ret = _centered(ret, sz)
                tfr[ii, this_slice] = ret[decim]
            elif mode == "full" and not use_fft:
                # trim the full temporal convolution to n_times samples
                start = (W.size - 1) // 2
                end = len(ret) - (W.size // 2)
                ret = ret[start:end]
                tfr[ii, :] = ret[decim]
            else:
                if use_fft:
                    ret = _centered(ret, n_times)
                tfr[ii, :] = ret[decim]
        yield tfr
| # Loop of convolution: single trial | |
def _compute_tfr(
    epoch_data,
    freqs,
    sfreq=1.0,
    method="morlet",
    n_cycles=7.0,
    zero_mean=None,
    time_bandwidth=None,
    use_fft=True,
    decim=1,
    output="complex",
    return_weights=False,
    n_jobs=None,
    *,
    verbose=None,
):
    """Compute time-frequency transforms.

    Parameters
    ----------
    epoch_data : array of shape (n_epochs, n_channels, n_times)
        The epochs.
    freqs : array-like of floats, shape (n_freqs)
        The frequencies.
    sfreq : float | int, default 1.0
        Sampling frequency of the data.
    method : 'multitaper' | 'morlet', default 'morlet'
        The time-frequency method. 'morlet' convolves a Morlet wavelet.
        'multitaper' uses complex exponentials windowed with multiple DPSS
        tapers.
    n_cycles : float | array of float, default 7.0
        Number of cycles in the wavelet. Fixed number
        or one per frequency.
    zero_mean : bool | None, default None
        None means True for method='multitaper' and False for method='morlet'.
        If True, make sure the wavelets have a mean of zero.
    time_bandwidth : float, default None
        If None and method=multitaper, will be set to 4.0 (3 tapers).
        Time x (Full) Bandwidth product. Only applies if
        method == 'multitaper'. The number of good tapers (low-bias) is
        chosen automatically based on this to equal floor(time_bandwidth - 1).
    use_fft : bool, default True
        Use the FFT for convolutions or not.
    decim : int | slice, default 1
        To reduce memory usage, decimation factor after time-frequency
        decomposition.
        If `int`, returns tfr[..., ::decim].
        If `slice`, returns tfr[..., decim].

        .. note::
            Decimation may create aliasing artifacts, yet decimation
            is done after the convolutions.
    output : str, default ``'complex'``

        * 'complex' : single trial complex.
        * 'power' : single trial power.
        * 'phase' : single trial phase.
        * 'avg_power' : average of single trial power.
        * 'itc' : inter-trial coherence.
        * 'avg_power_itc' : average of single trial power and inter-trial
          coherence across trials.
    return_weights : bool, default False
        Whether to return the taper weights. Only applies if method='multitaper' and
        output='complex' or 'phase'.
    %(n_jobs)s
        The number of epochs to process at the same time. The parallelization
        is implemented across channels.
    %(verbose)s

    Returns
    -------
    out : array
        Time frequency transform of epoch_data. If output is in ['complex',
        'phase', 'power'], then shape of ``out`` is ``(n_epochs, n_chans,
        n_freqs, n_times)``, else it is ``(n_chans, n_freqs, n_times)``.
        However, using multitaper method and output ``'complex'`` or
        ``'phase'`` results in shape of ``out`` being ``(n_epochs, n_chans,
        n_tapers, n_freqs, n_times)``. If output is ``'avg_power_itc'``, the
        real values in the ``output`` contain average power and the imaginary
        values contain the ITC: ``out = avg_power + i * itc``.
    weights : array of shape (n_tapers, n_freqs)
        The taper weights. Only returned if method='multitaper', output='complex' or
        'phase', and return_weights=True.
    """
    # Check data
    epoch_data = np.asarray(epoch_data)
    if epoch_data.ndim != 3:
        raise ValueError(
            "epoch_data must be of shape (n_epochs, n_chans, "
            f"n_times), got {epoch_data.shape}"
        )

    # Check params
    freqs, sfreq, zero_mean, n_cycles, time_bandwidth, decim = _check_tfr_param(
        freqs,
        sfreq,
        method,
        zero_mean,
        n_cycles,
        time_bandwidth,
        use_fft,
        decim,
        output,
    )
    # weights are only handed back for per-taper (multitaper complex/phase)
    # outputs; the request is silently dropped otherwise
    return_weights = (
        return_weights and method == "multitaper" and output in ["complex", "phase"]
    )

    decim = _ensure_slice(decim)
    if (freqs > sfreq / 2.0).any():
        raise ValueError(
            "Cannot compute freq above Nyquist freq of the data "
            f"({sfreq / 2.0:0.1f} Hz), got {freqs.max():0.1f} Hz"
        )

    # We decimate *after* decomposition, so we need to create our kernels
    # for the original sfreq
    if method == "morlet":
        W = morlet(sfreq, freqs, n_cycles=n_cycles, zero_mean=zero_mean)
        Ws = [W]  # to have same dimensionality as the 'multitaper' case
        weights = None  # no tapers for Morlet estimates

    elif method == "multitaper":
        Ws, weights = _make_dpss(
            sfreq,
            freqs,
            n_cycles=n_cycles,
            time_bandwidth=time_bandwidth,
            zero_mean=zero_mean,
            return_weights=True,  # required for converting complex → power
        )
        weights = np.asarray(weights)

    # Check wavelets
    # (only the first wavelet of the first taper is compared here; the
    # _get_nfft call below looks at the longest of all wavelets)
    if len(Ws[0][0]) > epoch_data.shape[2]:
        raise ValueError(
            "At least one of the wavelets is longer than the "
            f"signal ({len(Ws[0][0])} > {epoch_data.shape[2]} samples). "
            "Use a longer signal or shorter wavelets."
        )

    # Initialize output
    n_freqs = len(freqs)
    n_tapers = len(Ws)
    # n_times is the number of samples *after* decimation
    n_epochs, n_chans, n_times = epoch_data[:, :, decim].shape
    if output in ("power", "phase", "avg_power", "itc"):
        dtype = np.float64
    elif output in ("complex", "avg_power_itc"):
        # avg_power_itc is stored as power + 1i * itc to keep a
        # simple dimensionality
        dtype = np.complex128

    # channels come first during computation because the work is split
    # across channels; the epochs axis is moved to the front at the end
    if ("avg_" in output) or ("itc" in output):
        out = np.empty((n_chans, n_freqs, n_times), dtype)
    elif output in ["complex", "phase"] and method == "multitaper":
        out = np.empty((n_chans, n_epochs, n_tapers, n_freqs, n_times), dtype)
    else:
        out = np.empty((n_chans, n_epochs, n_freqs, n_times), dtype)

    # Parallel computation
    # flatten the per-taper lists into a single list of wavelets
    all_Ws = sum([list(W) for W in Ws], list())
    # called for its length check only; the returned value is not used here
    _get_nfft(all_Ws, epoch_data, use_fft)
    parallel, my_cwt, n_jobs = parallel_func(_time_frequency_loop, n_jobs)

    # Parallelization is applied across channels.
    tfrs = parallel(
        my_cwt(channel, Ws, output, use_fft, "same", decim, weights)
        for channel in epoch_data.transpose(1, 0, 2)
    )

    # FIXME: to avoid overheads we should use np.array_split()
    for channel_idx, tfr in enumerate(tfrs):
        out[channel_idx] = tfr

    if ("avg_" not in output) and ("itc" not in output):
        # This is to enforce that the first dimension is for epochs
        out = np.moveaxis(out, 1, 0)

    if return_weights:
        return out, weights
    return out
| def _check_tfr_param( | |
| freqs, sfreq, method, zero_mean, n_cycles, time_bandwidth, use_fft, decim, output | |
| ): | |
| """Aux. function to _compute_tfr to check the params validity.""" | |
| # Check freqs | |
| if not isinstance(freqs, list | np.ndarray): | |
| raise ValueError(f"freqs must be an array-like, got {type(freqs)} instead.") | |
| freqs = np.asarray(freqs, dtype=float) | |
| if freqs.ndim != 1: | |
| raise ValueError( | |
| f"freqs must be of shape (n_freqs,), got {np.array(freqs.shape)} instead." | |
| ) | |
| # Check sfreq | |
| if not isinstance(sfreq, float | int): | |
| raise ValueError(f"sfreq must be a float or an int, got {type(sfreq)} instead.") | |
| sfreq = float(sfreq) | |
| # Default zero_mean = True if multitaper else False | |
| zero_mean = method == "multitaper" if zero_mean is None else zero_mean | |
| if not isinstance(zero_mean, bool): | |
| raise ValueError( | |
| f"zero_mean should be of type bool, got {type(zero_mean)}. instead" | |
| ) | |
| freqs = np.asarray(freqs) | |
| # Check n_cycles | |
| if isinstance(n_cycles, int | float): | |
| n_cycles = float(n_cycles) | |
| elif isinstance(n_cycles, list | np.ndarray): | |
| n_cycles = np.array(n_cycles) | |
| if len(n_cycles) != len(freqs): | |
| raise ValueError( | |
| "n_cycles must be a float or an array of length " | |
| f"{len(freqs)} frequencies, got {len(n_cycles)} cycles instead." | |
| ) | |
| else: | |
| raise ValueError( | |
| f"n_cycles must be a float or an array, got {type(n_cycles)} instead." | |
| ) | |
| # Check time_bandwidth | |
| if (method == "morlet") and (time_bandwidth is not None): | |
| raise ValueError('time_bandwidth only applies to "multitaper" method.') | |
| elif method == "multitaper": | |
| time_bandwidth = 4.0 if time_bandwidth is None else float(time_bandwidth) | |
| # Check use_fft | |
| if not isinstance(use_fft, bool): | |
| raise ValueError(f"use_fft must be a boolean, got {type(use_fft)} instead.") | |
| # Check decim | |
| if isinstance(decim, int): | |
| decim = slice(None, None, decim) | |
| if not isinstance(decim, slice): | |
| raise ValueError( | |
| f"decim must be an integer or a slice, got {type(decim)} instead." | |
| ) | |
| # Check output | |
| _check_option( | |
| "output", | |
| output, | |
| ["complex", "power", "phase", "avg_power_itc", "avg_power", "itc"], | |
| ) | |
| _check_option("method", method, ["multitaper", "morlet"]) | |
| return freqs, sfreq, zero_mean, n_cycles, time_bandwidth, decim | |
def _time_frequency_loop(X, Ws, output, use_fft, mode, decim, weights=None):
    """Aux. function to _compute_tfr.

    Loops time-frequency transform across wavelets and epochs.

    Parameters
    ----------
    X : array, shape (n_epochs, n_times)
        The epochs data of a single channel.
    Ws : list, shape (n_tapers, n_wavelets, n_times)
        The wavelets.
    output : str

        * 'complex' : single trial complex.
        * 'power' : single trial power.
        * 'phase' : single trial phase.
        * 'avg_power' : average of single trial power.
        * 'itc' : inter-trial coherence.
        * 'avg_power_itc' : average of single trial power and inter-trial
          coherence across trials.
    use_fft : bool
        Use the FFT for convolutions or not.
    mode : {'full', 'valid', 'same'}
        See numpy.convolve.
    decim : slice
        The decimation slice: e.g. power[:, decim]
    weights : array, shape (n_tapers, n_wavelets) | None
        Concentration weights for each taper in the wavelets, if present.

    Returns
    -------
    tfrs : array
        Shape ``(n_freqs, n_times)`` for the averaged outputs ('avg_power',
        'itc', 'avg_power_itc'), ``(n_epochs, n_tapers, n_freqs, n_times)``
        for 'complex' or 'phase' when ``weights`` is given, and
        ``(n_epochs, n_freqs, n_times)`` otherwise. The dtype is complex for
        'complex' and 'avg_power_itc' and float for the other outputs.
    """
    # Set output type
    dtype = np.float64
    if output in ["complex", "avg_power_itc"]:
        dtype = np.complex128

    # Init outputs
    decim = _ensure_slice(decim)
    n_tapers = len(Ws)
    n_epochs, n_times = X[:, decim].shape  # n_times is after decimation
    n_freqs = len(Ws[0])
    if ("avg_" in output) or ("itc" in output):
        tfrs = np.zeros((n_freqs, n_times), dtype=dtype)
    elif output in ["complex", "phase"] and weights is not None:
        tfrs = np.zeros((n_epochs, n_tapers, n_freqs, n_times), dtype=dtype)
    else:
        tfrs = np.zeros((n_epochs, n_freqs, n_times), dtype=dtype)
    if weights is not None:
        weights = np.expand_dims(weights, axis=-1)  # add singleton time dimension

    # Loops across tapers.
    for taper_idx, W in enumerate(Ws):
        # No need to check here, it's done earlier (outside parallel part)
        nfft = _get_nfft(W, X, use_fft, check=False)
        coefs = _cwt_gen(X, W, fsize=nfft, mode=mode, decim=decim, use_fft=use_fft)

        # Inter-trial phase locking is apparently computed per taper...
        if "itc" in output:
            plf = np.zeros((n_freqs, n_times), dtype=np.complex128)

        # Loop across epochs
        for epoch_idx, tfr in enumerate(coefs):
            # Transform complex values
            if output not in ["complex", "phase"] and weights is not None:
                tfr = weights[taper_idx] * tfr  # weight each taper estimate
            if output in ["power", "avg_power"]:
                tfr = (tfr * tfr.conj()).real  # power
            elif output == "phase":
                tfr = np.angle(tfr)
            elif output == "avg_power_itc":
                tfr_abs = np.abs(tfr)
                plf += tfr / tfr_abs  # phase
                tfr = tfr_abs**2  # power
            elif output == "itc":
                plf += tfr / np.abs(tfr)  # phase
                continue  # not need to stack anything else than plf

            # Stack or add
            # (averaged outputs accumulate over epochs and tapers; per-epoch
            # outputs are written into their own epoch / taper slot)
            if ("avg_" in output) or ("itc" in output):
                tfrs += tfr
            elif output in ["complex", "phase"] and weights is not None:
                tfrs[epoch_idx, taper_idx] += tfr
            else:
                tfrs[epoch_idx] += tfr

        # Compute inter trial coherence
        # (the magnitude of this taper's summed unit phasors is added on top
        # of what previous tapers contributed)
        if output == "avg_power_itc":
            tfrs += 1j * np.abs(plf)
        elif output == "itc":
            tfrs += np.abs(plf)

    # Normalization of average metrics
    if ("avg_" in output) or ("itc" in output):
        tfrs /= n_epochs

    # Normalization by taper weights
    # NOTE(review): 'itc' is excluded from this branch, so with several tapers
    # the per-taper values added above are divided by n_epochs only, whereas
    # the ITC part of 'avg_power_itc' is also divided by n_tapers -- confirm
    # this asymmetry is intended.
    if n_tapers > 1 and output not in ["complex", "phase", "itc"]:
        if "avg_" not in output:  # add singleton epochs dimension to weights
            weights = np.expand_dims(weights, axis=0)
        tfrs.real *= 2 / (weights * weights.conj()).real.sum(axis=-3)
        if output == "avg_power_itc":  # weight itc by the number of tapers
            tfrs.imag = tfrs.imag / n_tapers
    return tfrs
def cwt(X, Ws, use_fft=True, mode="same", decim=1):
    """Compute time-frequency decomposition with continuous wavelet transform.

    Parameters
    ----------
    X : array, shape (n_signals, n_times)
        The signals to decompose.
    Ws : list of array
        The wavelets, given as time series.
    use_fft : bool
        Whether the convolutions are done with the FFT. Defaults to True.
    mode : 'same' | 'valid' | 'full'
        Convention for convolution. 'full' is currently not implemented with
        ``use_fft=False``. Defaults to ``'same'``.
    %(decim_tfr)s

    Returns
    -------
    tfr : array, shape (n_signals, n_freqs, n_times)
        The time-frequency decompositions.

    See Also
    --------
    mne.time_frequency.tfr_morlet : Compute time-frequency decomposition
                                    with Morlet wavelets.
    """
    # The FFT length is derived from the wavelets and the signal; this call
    # is also where over-long wavelets are reported.
    return _cwt_array(X, Ws, _get_nfft(Ws, X, use_fft), mode, decim, use_fft)
def _cwt_array(X, Ws, nfft, mode, decim, use_fft):
    """Run ``_cwt_gen`` over all signals and stack the results in one array.

    The result is a complex array with one row per signal, one per wavelet,
    and as many time samples as remain after applying ``decim``.
    """
    decim = _ensure_slice(decim)
    n_signals, n_times = X[:, decim].shape
    stacked = np.empty((n_signals, len(Ws), n_times), dtype=np.complex128)
    per_signal = _cwt_gen(
        X, Ws, fsize=nfft, mode=mode, decim=decim, use_fft=use_fft
    )
    # Writing into the preallocated rows copies each generated array.
    for tfr, row in zip(per_signal, stacked):
        row[...] = tfr
    return stacked
def _tfr_aux(
    method, inst, freqs, decim, return_itc, picks, average, output, **tfr_params
):
    """Check the legacy TFR arguments and delegate to ``inst.compute_tfr``.

    ``average`` and ``return_itc`` are only forwarded when ``inst`` is an
    epochs object; for any other instance ``average`` is forced to False.
    """
    from ..epochs import BaseEpochs

    call_kwargs = {
        "method": method,
        "freqs": freqs,
        "picks": picks,
        "decim": decim,
        "output": output,
        **tfr_params,
    }
    if isinstance(inst, BaseEpochs):
        call_kwargs["average"] = average
        call_kwargs["return_itc"] = return_itc
    elif average:
        logger.info("inst is Evoked, setting `average=False`")
        average = False

    # Reject combinations that cannot be computed
    if average and output == "complex":
        raise ValueError('output must be "power" if average=True')
    if return_itc and not average:
        raise ValueError("Inter-trial coherence is not supported with average=False")
    return inst.compute_tfr(**call_kwargs)
def tfr_morlet(
    inst,
    freqs,
    n_cycles,
    use_fft=False,
    return_itc=True,
    decim=1,
    n_jobs=None,
    picks=None,
    zero_mean=True,
    average=True,
    output="power",
    verbose=None,
):
    """Compute Time-Frequency Representation (TFR) using Morlet wavelets.

    Same computation as `~mne.time_frequency.tfr_array_morlet`, but
    operates on `~mne.Epochs` or `~mne.Evoked` objects instead of
    :class:`NumPy arrays <numpy.ndarray>`.

    Parameters
    ----------
    inst : Epochs | Evoked
        The epochs or evoked object.
    %(freqs_tfr_array)s
    %(n_cycles_tfr)s
    use_fft : bool, default False
        Whether to use FFT-based convolution.
    return_itc : bool, default True
        Return inter-trial coherence (ITC) as well as averaged power.
        Must be ``False`` for evoked data.
    %(decim_tfr)s
    %(n_jobs)s
    picks : array-like of int | None, default None
        The indices of the channels to decompose. If None, all available
        good data channels are decomposed.
    zero_mean : bool, default True
        Make sure the wavelet has a mean of zero.

        .. versionadded:: 0.13.0
    %(average_tfr)s
    output : str
        Can be ``"power"`` (default) or ``"complex"``. If ``"complex"``, then
        ``average`` must be ``False``.

        .. versionadded:: 0.15.0
    %(verbose)s

    Returns
    -------
    power : AverageTFR | EpochsTFR
        The averaged or single-trial power.
    itc : AverageTFR | EpochsTFR
        The inter-trial coherence (ITC). Only returned if return_itc
        is True.

    See Also
    --------
    mne.time_frequency.tfr_array_morlet
    mne.time_frequency.tfr_multitaper
    mne.time_frequency.tfr_array_multitaper
    mne.time_frequency.tfr_stockwell
    mne.time_frequency.tfr_array_stockwell

    Notes
    -----
    %(morlet_reference)s
    %(temporal_window_tfr_intro)s
    %(temporal_window_tfr_morlet_notes)s

    See :func:`mne.time_frequency.morlet` for more information about the
    Morlet wavelet.

    References
    ----------
    .. footbibliography::
    """
    # Method-specific settings travel as keyword arguments; everything else
    # is positional for the shared helper.
    return _tfr_aux(
        "morlet",
        inst,
        freqs,
        decim,
        return_itc,
        picks,
        average,
        output,
        n_cycles=n_cycles,
        n_jobs=n_jobs,
        use_fft=use_fft,
        zero_mean=zero_mean,
    )
def tfr_array_morlet(
    data,
    sfreq,
    freqs,
    n_cycles=7.0,
    zero_mean=True,
    use_fft=True,
    decim=1,
    output="complex",
    n_jobs=None,
    *,
    verbose=None,
):
    """Compute Time-Frequency Representation (TFR) using Morlet wavelets.

    Same computation as `~mne.time_frequency.tfr_morlet`, but operates on
    :class:`NumPy arrays <numpy.ndarray>` instead of `~mne.Epochs` objects.

    Parameters
    ----------
    data : array of shape (n_epochs, n_channels, n_times)
        The epochs.
    sfreq : float | int
        Sampling frequency of the data.
    %(freqs_tfr_array)s
    %(n_cycles_tfr)s
    zero_mean : bool | None
        If True, make sure the wavelets have a mean of zero. Defaults to
        ``True``. ``None`` is treated as ``False`` for this (Morlet) method.

        .. versionchanged:: 1.8
            The default changed from ``zero_mean=False`` to ``True``.
    use_fft : bool
        Use the FFT for convolutions or not. default True.
    %(decim_tfr)s
    output : str, default ``'complex'``

        * ``'complex'`` : single trial complex.
        * ``'power'`` : single trial power.
        * ``'phase'`` : single trial phase.
        * ``'avg_power'`` : average of single trial power.
        * ``'itc'`` : inter-trial coherence.
        * ``'avg_power_itc'`` : average of single trial power and inter-trial
          coherence across trials.
    %(n_jobs)s
        The number of epochs to process at the same time. The parallelization
        is implemented across channels. Default 1.
    %(verbose)s

    Returns
    -------
    out : array
        Time frequency transform of ``data``.

        - if ``output in ('complex', 'phase', 'power')``, array of shape
          ``(n_epochs, n_chans, n_freqs, n_times)``
        - else, array of shape ``(n_chans, n_freqs, n_times)``

        If ``output`` is ``'avg_power_itc'``, the real values in ``out``
        contain the average power and the imaginary values contain the ITC:
        :math:`out = power_{avg} + i * itc`.

    See Also
    --------
    mne.time_frequency.tfr_morlet
    mne.time_frequency.tfr_multitaper
    mne.time_frequency.tfr_array_multitaper
    mne.time_frequency.tfr_stockwell
    mne.time_frequency.tfr_array_stockwell

    Notes
    -----
    %(morlet_reference)s
    %(temporal_window_tfr_intro)s
    %(temporal_window_tfr_morlet_notes)s

    .. versionadded:: 0.14.0

    References
    ----------
    .. footbibliography::
    """
    # time_bandwidth is a multitaper-only setting, so it is always None here
    return _compute_tfr(
        epoch_data=data,
        freqs=freqs,
        sfreq=sfreq,
        method="morlet",
        n_cycles=n_cycles,
        zero_mean=zero_mean,
        time_bandwidth=None,
        use_fft=use_fft,
        decim=decim,
        output=output,
        n_jobs=n_jobs,
        verbose=verbose,
    )
def tfr_multitaper(
    inst,
    freqs,
    n_cycles,
    time_bandwidth=4.0,
    use_fft=True,
    return_itc=True,
    decim=1,
    n_jobs=None,
    picks=None,
    average=True,
    *,
    verbose=None,
):
    """Compute Time-Frequency Representation (TFR) using DPSS tapers.

    Same computation as :func:`~mne.time_frequency.tfr_array_multitaper`, but
    operates on :class:`~mne.Epochs` or :class:`~mne.Evoked` objects instead of
    :class:`NumPy arrays <numpy.ndarray>`.

    Parameters
    ----------
    inst : Epochs | Evoked
        The epochs or evoked object.
    %(freqs_tfr_array)s
    %(n_cycles_tfr)s
    %(time_bandwidth_tfr)s
    use_fft : bool, default True
        The fft based convolution or not.
    return_itc : bool, default True
        Return inter-trial coherence (ITC) as well as averaged (or
        single-trial) power.
    %(decim_tfr)s
    %(n_jobs)s
    %(picks_good_data)s
    %(average_tfr)s
    %(verbose)s

    Returns
    -------
    power : AverageTFR | EpochsTFR
        The averaged or single-trial power.
    itc : AverageTFR | EpochsTFR
        The inter-trial coherence (ITC). Only returned if return_itc
        is True.

    See Also
    --------
    mne.time_frequency.tfr_array_multitaper
    mne.time_frequency.tfr_stockwell
    mne.time_frequency.tfr_array_stockwell
    mne.time_frequency.tfr_morlet
    mne.time_frequency.tfr_array_morlet

    Notes
    -----
    %(temporal_window_tfr_intro)s
    %(temporal_window_tfr_multitaper_notes)s
    %(time_bandwidth_tfr_notes)s

    .. versionadded:: 0.9.0
    """
    from ..epochs import EpochsArray
    from ..evoked import Evoked

    # NOTE(review): ``verbose`` is not referenced in this body; presumably a
    # decorator not visible here consumes it -- confirm.
    # For backwards compatibility, an Evoked combined with ``average=False``
    # is wrapped as a one-epoch EpochsArray before being handed on.
    if isinstance(inst, Evoked) and not average:
        inst = EpochsArray(inst.data[np.newaxis], inst.info, tmin=inst.tmin, proj=False)

    # the multitaper path always uses zero-mean wavelets and "power" output
    return _tfr_aux(
        method="multitaper",
        inst=inst,
        freqs=freqs,
        decim=decim,
        return_itc=return_itc,
        picks=picks,
        average=average,
        output="power",
        n_cycles=n_cycles,
        n_jobs=n_jobs,
        use_fft=use_fft,
        zero_mean=True,
        time_bandwidth=time_bandwidth,
    )
| # TFR(s) class | |
| class BaseTFR(ContainsMixin, UpdateChannelsMixin, SizeMixin, ExtendedTimeMixin): | |
| """Base class for RawTFR, EpochsTFR, and AverageTFR (for type checking only). | |
| .. note:: | |
| This class should not be instantiated directly; it is provided in the public API | |
| only for type-checking purposes (e.g., ``isinstance(my_obj, BaseTFR)``). To | |
| create TFR objects, use the ``.compute_tfr()`` methods on :class:`~mne.io.Raw`, | |
| :class:`~mne.Epochs`, or :class:`~mne.Evoked`, or use the constructors listed | |
| below under "See Also". | |
| Parameters | |
| ---------- | |
| inst : instance of Raw, Epochs, or Evoked | |
| The data from which to compute the time-frequency representation. | |
| %(method_tfr)s | |
| %(freqs_tfr)s | |
| %(tmin_tmax_psd)s | |
| %(picks_good_data_noref)s | |
| %(proj_psd)s | |
| %(decim_tfr)s | |
| %(n_jobs)s | |
| %(reject_by_annotation_tfr)s | |
| %(verbose)s | |
| %(method_kw_tfr)s | |
| See Also | |
| -------- | |
| mne.time_frequency.RawTFR | |
| mne.time_frequency.RawTFRArray | |
| mne.time_frequency.EpochsTFR | |
| mne.time_frequency.EpochsTFRArray | |
| mne.time_frequency.AverageTFR | |
| mne.time_frequency.AverageTFRArray | |
| """ | |
def __init__(
    self,
    inst,
    method,
    freqs,
    tmin,
    tmax,
    picks,
    proj,
    *,
    decim,
    n_jobs,
    reject_by_annotation=None,
    verbose=None,
    **method_kw,
):
    """Compute the TFR of ``inst``, or restore one from a state dict.

    Parameters are described in the class docstring. When ``inst`` is a
    ``dict`` it is passed to ``__setstate__`` and nothing is computed.
    """
    from ..epochs import BaseEpochs
    from ._stockwell import tfr_array_stockwell

    # a dict means "restore from a previously saved state"
    if isinstance(inst, dict):
        self.__setstate__(inst)
        return

    problem = [
        f"{name}=None"
        for name, value in (("method", method), ("freqs", freqs))
        if value is None
    ]
    if problem:
        # TODO when py3.11 is min version, replace if/elif/else block with
        # classname = inspect.currentframe().f_back.f_code.co_qualname.split(".")[0]
        # (must be evaluated in this frame: ``f_back`` has to be the caller)
        _varnames = inspect.currentframe().f_back.f_code.co_varnames
        if "BaseRaw" in _varnames:
            classname = "RawTFR"
        elif "Evoked" in _varnames:
            classname = "AverageTFR"
        else:
            assert "BaseEpochs" in _varnames and "Evoked" not in _varnames
            classname = "EpochsTFR"
        # end TODO
        raise ValueError(
            f"{classname} got unsupported parameter value{_pl(problem)} "
            f"{' and '.join(problem)}."
        )

    # validate the method; "stockwell" is only offered for epoched input
    valid_methods = ["morlet", "multitaper"]
    if isinstance(inst, BaseEpochs):
        valid_methods.append("stockwell")
    method = _check_option("method", method, valid_methods)

    # for stockwell, `tmin, tmax` already added to `method_kw` by calling method,
    # and `freqs` vector has been pre-computed
    if method != "stockwell":
        method_kw["freqs"] = freqs
        # if the constructor is called directly, this prevents a key error below
        method_kw.setdefault("output", "power")
    self._freqs = np.asarray(freqs, dtype=np.float64)
    del freqs

    # per-taper outputs need the taper weights stored alongside the data
    if method == "multitaper" and method_kw.get("output") in ["complex", "phase"]:
        method_kw["return_weights"] = True

    # check validity of kwargs manually to save compute time if any are invalid
    tfr_funcs = {
        "morlet": tfr_array_morlet,
        "multitaper": tfr_array_multitaper,
        "stockwell": tfr_array_stockwell,
    }
    chosen_func = tfr_funcs[method]
    _check_method_kwargs(chosen_func, method_kw, msg=f'TFR method "{method}"')
    self._tfr_func = partial(chosen_func, **method_kw)

    # apply proj if desired (on a copy, so the caller's instance is untouched)
    if proj:
        inst = inst.copy().apply_proj()
    self.inst = inst

    # prep picks and add the info object. bads and non-data channels are dropped by
    # _picks_to_idx() so we update the info accordingly:
    self._picks = _picks_to_idx(inst.info, picks, "data", with_ref_meg=False)
    self.info = pick_info(inst.info, sel=self._picks, copy=True)

    # assign some attributes
    self._method = method
    self._inst_type = type(inst)
    self._baseline = None
    self._weights = None
    self.preload = True  # needed for __getitem__, never False for TFRs

    # dimension names; self._dims may also get updated by child classes
    self._needs_taper_dim = method == "multitaper" and method_kw["output"] in (
        "complex",
        "phase",
    )
    if self._needs_taper_dim:
        self._dims = ("channel", "taper", "freq", "time")
    else:
        self._dims = ("channel", "freq", "time")

    # get the instance data
    time_mask = _time_mask(inst.times, tmin, tmax, sfreq=self.sfreq)
    get_instance_data_kw = {"time_mask": time_mask}
    if reject_by_annotation is not None:
        get_instance_data_kw["reject_by_annotation"] = reject_by_annotation
    data = self._get_instance_data(**get_instance_data_kw)

    # compute the TFR
    self._decim = _ensure_slice(decim)
    self._raw_times = inst.times[time_mask]
    self._compute_tfr(data, n_jobs, verbose)
    self._update_epoch_attributes()

    # "apply" decim to the rest of the object (data is decimated in _compute_tfr)
    with self.info._unlock():
        self.info["sfreq"] /= self._decim.step
    _decim_times = inst.times[self._decim]
    _decim_time_mask = _time_mask(_decim_times, tmin, tmax, sfreq=self.sfreq)
    self._raw_times = _decim_times[_decim_time_mask].copy()
    self._set_times(self._raw_times)
    self._decim = 1

    # record data type (for repr and html_repr). ITC handled in the calling method.
    if method == "stockwell":
        self._data_type = "Power Estimates"
    else:
        data_types = {
            "power": "Power Estimates",
            "avg_power": "Average Power Estimates",
            "avg_power_itc": "Average Power Estimates",
            "phase": "Phase",
            "complex": "Complex Amplitude",
        }
        self._data_type = data_types[method_kw["output"]]

    # check for correct shape and bad values. `tfr_array_stockwell` doesn't take kw
    # `output` so it may be missing here, so use `.get()`
    negative_ok = method_kw.get("output", "") in ("complex", "phase")
    self._check_values(negative_ok=negative_ok)

    # we don't need these anymore, and they make save/load harder
    del self._picks
    del self._tfr_func
    del self._needs_taper_dim
    del self._shape  # calculated from self._data henceforth
    del self.inst  # save memory
| def __abs__(self): | |
| """Return the absolute value.""" | |
| tfr = self.copy() | |
| tfr.data = np.abs(tfr.data) | |
| return tfr | |
| def __add__(self, other): | |
| """Add two TFR instances. | |
| %(__add__tfr)s | |
| """ | |
| self._check_compatibility(other) | |
| out = self.copy() | |
| out.data += other.data | |
| return out | |
| def __iadd__(self, other): | |
| """Add a TFR instance to another, in-place. | |
| %(__iadd__tfr)s | |
| """ | |
| self._check_compatibility(other) | |
| self.data += other.data | |
| return self | |
| def __sub__(self, other): | |
| """Subtract two TFR instances. | |
| %(__sub__tfr)s | |
| """ | |
| self._check_compatibility(other) | |
| out = self.copy() | |
| out.data -= other.data | |
| return out | |
| def __isub__(self, other): | |
| """Subtract a TFR instance from another, in-place. | |
| %(__isub__tfr)s | |
| """ | |
| self._check_compatibility(other) | |
| self.data -= other.data | |
| return self | |
| def __mul__(self, num): | |
| """Multiply a TFR instance by a scalar. | |
| %(__mul__tfr)s | |
| """ | |
| out = self.copy() | |
| out.data *= num | |
| return out | |
| def __imul__(self, num): | |
| """Multiply a TFR instance by a scalar, in-place. | |
| %(__imul__tfr)s | |
| """ | |
| self.data *= num | |
| return self | |
| def __truediv__(self, num): | |
| """Divide a TFR instance by a scalar. | |
| %(__truediv__tfr)s | |
| """ | |
| out = self.copy() | |
| out.data /= num | |
| return out | |
| def __itruediv__(self, num): | |
| """Divide a TFR instance by a scalar, in-place. | |
| %(__itruediv__tfr)s | |
| """ | |
| self.data /= num | |
| return self | |
| def __eq__(self, other): | |
| """Test equivalence of two TFR instances.""" | |
| return object_diff(vars(self), vars(other)) == "" | |
def __getstate__(self):
    """Prepare object for serialization.

    Returns a plain ``dict`` holding the data array together with the
    metadata (frequencies, times, dims, info, ...) read from this instance.
    """
    return {
        "method": self.method,
        "data": self._data,
        "sfreq": self.sfreq,
        "dims": self._dims,
        "freqs": self.freqs,
        "times": self.times,
        "inst_type_str": _get_instance_type_string(self),
        "data_type": self._data_type,
        "info": self.info,
        "baseline": self._baseline,
        "decim": self._decim,
        "weights": self._weights,
    }
def __setstate__(self, state):
    """Unpack from serialized format.

    Keys missing from ``state`` fall back to the defaults listed below;
    the restored object is validated with ``_check_state`` at the end.
    """
    from ..epochs import Epochs
    from ..evoked import Evoked
    from ..io import Raw

    # fallbacks for keys that may be absent from ``state``
    fallbacks = {
        "method": "unknown",
        "baseline": None,
        "decim": 1,
        "data_type": "TFR",
        "inst_type_str": "Unknown",
    }
    defaults = dict(fallbacks, **state)

    self._method = defaults["method"]
    self._data = defaults["data"]
    self._freqs = np.asarray(defaults["freqs"], dtype=np.float64)
    self._dims = defaults["dims"]
    self._raw_times = np.asarray(defaults["times"], dtype=np.float64)
    self._baseline = defaults["baseline"]
    self.info = Info(**defaults["info"])
    self._data_type = defaults["data_type"]
    self._decim = defaults["decim"]
    self.preload = True
    self._set_times(self._raw_times)
    self._weights = state.get("weights")  # objs saved before #12910 won't have

    # Handle instance type. Prior to gh-11282, Raw was not a possibility so if
    # `inst_type_str` is missing it must be Epochs or Evoked
    if "epoch" in self._dims:
        unknown_class = Epochs
    else:
        unknown_class = Evoked
    inst_types = {
        "Raw": Raw,
        "Epochs": Epochs,
        "Evoked": Evoked,
        "Unknown": unknown_class,
    }
    self._inst_type = inst_types[defaults["inst_type_str"]]

    # sanity check data/freqs/times/info/weights agreement
    self._check_state()
def __repr__(self):
    """Build string representation of the TFR object."""
    inst_type_str = _get_instance_type_string(self)
    # ``nave`` is only shown when the instance has that attribute
    nave = f" (nave={self.nave})" if hasattr(self, "nave") else ""
    # shape & dimension names, e.g. "3 channels × 5 freqs × 100 times"
    dims = " × ".join(f"{size} {dim}s" for size, dim in zip(self.shape, self._dims))
    freq_range = f"{self.freqs[0]:0.1f} - {self.freqs[-1]:0.1f} Hz"
    time_range = f"{self.times[0]:0.2f} - {self.times[-1]:0.2f} s"
    header = f"<{self._data_type} from {inst_type_str}{nave}, "
    details = f"{self.method} method | {dims}, {freq_range}, {time_range}, "
    footer = f"{sizeof_fmt(self._size)}>"
    return header + details + footer
def _repr_html_(self, caption=None):
    """Build HTML representation of the TFR object."""
    from ..html_templates import _get_html_template

    template = _get_html_template("repr", "tfr.html.jinja")
    # ``nave`` falls back to 0 when the instance has no such attribute
    return template.render(
        tfr=self,
        inst_type=_get_instance_type_string(self),
        nave=getattr(self, "nave", 0),
        caption=caption,
    )
| def _check_compatibility(self, other): | |
| """Check compatibility of two TFR instances, in preparation for arithmetic.""" | |
| operation = inspect.currentframe().f_back.f_code.co_name.strip("_") | |
| if operation.startswith("i"): | |
| operation = operation[1:] | |
| msg = f"Cannot {operation} the two TFR instances: {{}} do not match{{}}." | |
| extra = "" | |
| if not isinstance(other, type(self)): | |
| problem = "types" | |
| extra = f" (self is {type(self)}, other is {type(other)})" | |
| elif not self.times.shape == other.times.shape or np.any( | |
| self.times != other.times | |
| ): | |
| problem = "times" | |
| elif not self.freqs.shape == other.freqs.shape or np.any( | |
| self.freqs != other.freqs | |
| ): | |
| problem = "freqs" | |
| else: # should be OK | |
| return | |
| raise RuntimeError(msg.format(problem, extra)) | |
| def _check_state(self): | |
| """Check data/freqs/times/info/weights agreement during __setstate__.""" | |
| msg = "{} axis of data ({}) doesn't match {} attribute ({})" | |
| n_chan_info = len(self.info["chs"]) | |
| n_chan = self._data.shape[self._dims.index("channel")] | |
| n_freq = self._data.shape[self._dims.index("freq")] | |
| n_time = self._data.shape[self._dims.index("time")] | |
| n_taper = ( | |
| self._data.shape[self._dims.index("taper")] | |
| if "taper" in self._dims | |
| else None | |
| ) | |
| if n_taper is not None and self._weights is None: | |
| raise ValueError("Taper dimension in data, but no weights found.") | |
| if n_chan_info != n_chan: | |
| msg = msg.format("Channel", n_chan, "info", n_chan_info) | |
| elif n_freq != len(self.freqs): | |
| msg = msg.format("Frequency", n_freq, "freqs", self.freqs.size) | |
| elif n_time != len(self.times): | |
| msg = msg.format("Time", n_time, "times", self.times.size) | |
| elif n_taper is not None and n_taper != self._weights.shape[0]: | |
| msg = msg.format("Taper", n_taper, "weights", self._weights.shape[0]) | |
| elif n_taper is not None and n_freq != self._weights.shape[1]: | |
| msg = msg.format("Frequency", n_freq, "weights", self._weights.shape[1]) | |
| else: | |
| return | |
| raise ValueError(msg) | |
| def _check_values(self, negative_ok=False): | |
| """Check TFR results for correct shape and bad values.""" | |
| assert len(self._dims) == self._data.ndim | |
| assert self._data.shape == self._shape | |
| # Check for implausible power values: take min() across all but the channel axis | |
| # TODO: should this be more fine-grained (report "chan X in epoch Y")? | |
| ch_dim = self._dims.index("channel") | |
| dims = np.arange(self._data.ndim).tolist() | |
| dims.pop(ch_dim) | |
| negative_values = self._data.min(axis=tuple(dims)) < 0 | |
| if negative_values.any() and not negative_ok: | |
| chs = np.array(self.ch_names)[negative_values].tolist() | |
| s = _pl(negative_values.sum()) | |
| warn( | |
| f"Negative value in time-frequency decomposition for channel{s} " | |
| f"{', '.join(chs)}", | |
| UserWarning, | |
| ) | |
def _compute_tfr(self, data, n_jobs, verbose):
    """Run the stored TFR function and unpack its result onto the instance.

    Sets ``_data`` (and, depending on method/output, ``_itc`` or
    ``_weights``) plus ``_shape``, the shape the data is expected to have.
    """
    result = self._tfr_func(
        data,
        self.sfreq,
        decim=self._decim,
        n_jobs=n_jobs,
        verbose=verbose,
    )
    # assign ._data and maybe ._itc
    if self.method == "stockwell":
        # tfr_array_stockwell always returns ITC (sometimes it's None)
        self._data, self._itc, freqs = result
        assert np.array_equal(self._freqs, freqs)
    else:
        output = self._tfr_func.keywords.get("output", "")
        if self.method == "multitaper" and output in ["complex", "phase"]:
            self._data, self._weights = result
        elif output.endswith("_itc"):
            # power and ITC arrive packed as real / imaginary parts
            self._data, self._itc = result.real, result.imag
        else:
            self._data = result

    # remove fake "epoch" dimension
    if self.method != "stockwell" and _get_instance_type_string(self) != "Epochs":
        self._data = np.squeeze(self._data, axis=0)

    # this is *expected* shape, it gets asserted later in _check_values()
    # (and then deleted afterwards)
    n_times = len(self._raw_times[self._decim])  # don't use self.times, not set yet
    expected_shape = [len(self.ch_names), len(self.freqs), n_times]
    # deal with the "taper" dimension
    if self._needs_taper_dim:
        tapers_dim = 2 if _get_instance_type_string(self) == "Epochs" else 1
        expected_shape.insert(1, self._data.shape[tapers_dim])
    self._shape = tuple(expected_shape)
| def _onselect( | |
| self, | |
| eclick, | |
| erelease, | |
| picks=None, | |
| exclude="bads", | |
| combine="mean", | |
| baseline=None, | |
| mode=None, | |
| cmap=None, | |
| source_plot_joint=False, | |
| topomap_args=None, | |
| verbose=None, | |
| ): | |
| """Respond to rectangle selector in TFR image plots with a topomap plot.""" | |
| if abs(eclick.x - erelease.x) < 0.1 or abs(eclick.y - erelease.y) < 0.1: | |
| return | |
| t_range = (min(eclick.xdata, erelease.xdata), max(eclick.xdata, erelease.xdata)) | |
| f_range = (min(eclick.ydata, erelease.ydata), max(eclick.ydata, erelease.ydata)) | |
| # snap to nearest measurement point | |
| t_idx = np.abs(self.times - np.atleast_2d(t_range).T).argmin(axis=1) | |
| f_idx = np.abs(self.freqs - np.atleast_2d(f_range).T).argmin(axis=1) | |
| tmin, tmax = self.times[t_idx] | |
| fmin, fmax = self.freqs[f_idx] | |
| # immutable → mutable default | |
| if topomap_args is None: | |
| topomap_args = dict() | |
| topomap_args.setdefault("cmap", cmap) | |
| topomap_args.setdefault("vlim", (None, None)) | |
| # figure out which channel types we're dealing with | |
| types = list() | |
| if "eeg" in self: | |
| types.append("eeg") | |
| if "mag" in self: | |
| types.append("mag") | |
| if "grad" in self: | |
| grad_picks = _pair_grad_sensors( | |
| self.info, topomap_coords=False, raise_error=False | |
| ) | |
| if len(grad_picks) > 1: | |
| types.append("grad") | |
| elif len(types) == 0: | |
| logger.info( | |
| "Need at least 2 gradiometer pairs to plot a gradiometer topomap." | |
| ) | |
| return # Don't draw a figure for nothing. | |
| fig = figure_nobar() | |
| t_range = f"{tmin:.3f}" if tmin == tmax else f"{tmin:.3f} - {tmax:.3f}" | |
| f_range = f"{fmin:.2f}" if fmin == fmax else f"{fmin:.2f} - {fmax:.2f}" | |
| fig.suptitle(f"{t_range} s,\n{f_range} Hz") | |
| if source_plot_joint: | |
| ax = fig.add_subplot() | |
| data, times, freqs = self.get_data( | |
| picks=picks, exclude=exclude, return_times=True, return_freqs=True | |
| ) | |
| # merge grads before baselining (makes ERDs visible) | |
| ch_types = np.array(self.get_channel_types(unique=True)) | |
| ch_type = ch_types.item() # will error if there are more than one | |
| data, pos = _merge_if_grads( | |
| data=data, | |
| info=self.info, | |
| ch_type=ch_type, | |
| sphere=topomap_args.get("sphere"), | |
| combine=combine, | |
| ) | |
| # baseline and crop | |
| data, *_ = _prep_data_for_plot( | |
| data, | |
| times, | |
| freqs, | |
| tmin=tmin, | |
| tmax=tmax, | |
| fmin=fmin, | |
| fmax=fmax, | |
| baseline=baseline, | |
| mode=mode, | |
| taper_weights=self.weights, | |
| verbose=verbose, | |
| ) | |
| # average over times and freqs | |
| data = data.mean((-2, -1)) | |
| im, _ = plot_topomap(data, pos, axes=ax, show=False, **topomap_args) | |
| _add_colorbar(ax, im, topomap_args["cmap"], title="AU") | |
| plt_show(fig=fig) | |
| else: | |
| for idx, ch_type in enumerate(types): | |
| ax = fig.add_subplot(1, len(types), idx + 1) | |
| plot_tfr_topomap( | |
| self, | |
| ch_type=ch_type, | |
| tmin=tmin, | |
| tmax=tmax, | |
| fmin=fmin, | |
| fmax=fmax, | |
| baseline=baseline, | |
| mode=mode, | |
| axes=ax, | |
| **topomap_args, | |
| ) | |
| ax.set_title(ch_type) | |
| def _update_epoch_attributes(self): | |
| # overwritten in EpochsTFR; adds things needed for to_data_frame and __getitem__ | |
| pass | |
| def _detrend_picks(self): | |
| """Provide compatibility with __iter__.""" | |
| return list() | |
| def baseline(self): | |
| """Start and end of the baseline period (in seconds).""" | |
| return self._baseline | |
| def ch_names(self): | |
| """The channel names.""" | |
| return self.info["ch_names"] | |
| def data(self): | |
| """The time-frequency-resolved power estimates.""" | |
| return self._data | |
| def data(self, data): | |
| self._data = data | |
| def freqs(self): | |
| """The frequencies at which power estimates were computed.""" | |
| return self._freqs | |
| def method(self): | |
| """The method used to compute the time-frequency power estimates.""" | |
| return self._method | |
| def sfreq(self): | |
| """Sampling frequency of the data.""" | |
| return self.info["sfreq"] | |
| def shape(self): | |
| """Data shape.""" | |
| return self._data.shape | |
| def times(self): | |
| """The time points present in the data (in seconds).""" | |
| return self._times_readonly | |
| def weights(self): | |
| """The weights used for each taper in the time-frequency estimates.""" | |
| return self._weights | |
def crop(self, tmin=None, tmax=None, fmin=None, fmax=None, include_tmax=True):
    """Crop data to a given time interval in place.

    Parameters
    ----------
    %(tmin_tmax_psd)s
    fmin : float | None
        Lowest frequency of selection in Hz.

        .. versionadded:: 0.18.0
    fmax : float | None
        Highest frequency of selection in Hz.

        .. versionadded:: 0.18.0
    %(include_tmax)s

    Returns
    -------
    %(inst_tfr)s
        The modified instance.
    """
    # time cropping is delegated to the parent class
    super().crop(tmin=tmin, tmax=tmax, include_tmax=include_tmax)

    if fmin is not None or fmax is not None:
        freq_mask = _freq_mask(
            self.freqs, sfreq=self.info["sfreq"], fmin=fmin, fmax=fmax
        )
    else:
        freq_mask = slice(None)

    self._freqs = self.freqs[freq_mask]
    # Deal with broadcasting (boolean arrays do not broadcast, but indices
    # do, so we need to convert freq_mask to make use of broadcasting)
    if isinstance(freq_mask, np.ndarray):
        freq_mask = np.where(freq_mask)[0]
    self._data = self._data[..., freq_mask, :]
    # Keep per-taper weights (indexed as (taper, freq) in ``_check_state``) in
    # step with the cropped frequency axis; otherwise the object would no
    # longer pass its own consistency check.
    weights = getattr(self, "_weights", None)
    if weights is not None:
        self._weights = weights[:, freq_mask]
    return self
| def copy(self): | |
| """Return copy of the TFR instance. | |
| Returns | |
| ------- | |
| %(inst_tfr)s | |
| A copy of the object. | |
| """ | |
| return deepcopy(self) | |
def apply_baseline(self, baseline, mode="mean", verbose=None):
    """Baseline correct the data.

    Parameters
    ----------
    %(baseline_rescale)s

        How baseline is computed is determined by the ``mode`` parameter.
    mode : 'mean' | 'ratio' | 'logratio' | 'percent' | 'zscore' | 'zlogratio'
        Perform baseline correction by

        - subtracting the mean of baseline values ('mean')
        - dividing by the mean of baseline values ('ratio')
        - dividing by the mean of baseline values and taking the log
          ('logratio')
        - subtracting the mean of baseline values followed by dividing by
          the mean of baseline values ('percent')
        - subtracting the mean of baseline values and dividing by the
          standard deviation of baseline values ('zscore')
        - dividing by the mean of baseline values, taking the log, and
          dividing by the standard deviation of log baseline values
          ('zlogratio')
    %(verbose)s

    Returns
    -------
    %(inst_tfr)s
        The modified instance.
    """
    # validate and remember the baseline interval, then rescale with
    # ``copy=False``
    checked = _check_baseline(baseline, times=self.times, sfreq=self.sfreq)
    self._baseline = checked
    rescale(self.data, self.times, self.baseline, mode, copy=False, verbose=verbose)
    return self
def get_data(
    self,
    picks=None,
    exclude="bads",
    fmin=None,
    fmax=None,
    tmin=None,
    tmax=None,
    return_times=False,
    return_freqs=False,
    return_tapers=False,
):
    """Get time-frequency data in NumPy array format.

    Parameters
    ----------
    %(picks_good_data_noref)s
    %(exclude_spectrum_get_data)s
    %(fmin_fmax_tfr)s
    %(tmin_tmax_psd)s
    return_times : bool
        Whether to return the time values for the requested time range.
        Default is ``False``.
    return_freqs : bool
        Whether to return the frequency bin values for the requested
        frequency range. Default is ``False``.
    return_tapers : bool
        Whether to return the taper numbers. Default is ``False``.

        .. versionadded:: 1.10.0

    Returns
    -------
    data : array
        The requested data in a NumPy array.
    times : array
        The time values for the requested data range. Only returned if
        ``return_times`` is ``True``.
    freqs : array
        The frequency values for the requested data range. Only returned if
        ``return_freqs`` is ``True``.
    tapers : array | None
        The taper numbers. Only returned if ``return_tapers`` is ``True``. Will be
        ``None`` if a taper dimension is not present in the data.

    Notes
    -----
    Returns a copy of the underlying data (not a view).
    """
    # fill in open-ended bounds
    if tmin is None:
        tmin = self.times[0]
    if tmax is None:
        tmax = self.times[-1]
    if fmin is None:
        fmin = 0
    if fmax is None:
        fmax = np.inf

    picks = _picks_to_idx(
        self.info, picks, "data_or_ica", exclude=exclude, with_ref_meg=False
    )
    # inclusive [min, max] ranges expressed as half-open index ranges
    fmin_idx = np.searchsorted(self.freqs, fmin)
    fmax_idx = np.searchsorted(self.freqs, fmax, side="right")
    tmin_idx = np.searchsorted(self.times, tmin)
    tmax_idx = np.searchsorted(self.times, tmax, side="right")
    freq_picks = np.arange(fmin_idx, fmax_idx)
    time_picks = np.arange(tmin_idx, tmax_idx)

    # normally there's a risk of np.take reducing array dimension if there
    # were only one channel or frequency selected, but `_picks_to_idx`
    # and np.arange both always return arrays, so we're safe; the result
    # will always have the same `ndim` as it started with.
    data = self._data
    for indices, dim_name in (
        (picks, "channel"),
        (freq_picks, "freq"),
        (time_picks, "time"),
    ):
        data = data.take(indices, self._dims.index(dim_name))

    extras = []
    if return_times:
        extras.append(self._raw_times[tmin_idx:tmax_idx])
    if return_freqs:
        extras.append(self._freqs[fmin_idx:fmax_idx])
    if return_tapers:
        if "taper" in self._dims:
            tapers = np.arange(self.shape[self._dims.index("taper")])
        else:
            tapers = None
        extras.append(tapers)

    # a bare array unless at least one of the extras was requested
    if not (return_times or return_freqs or return_tapers):
        return data
    return (data, *extras)
| def plot( | |
| self, | |
| picks=None, | |
| *, | |
| exclude=(), | |
| tmin=None, | |
| tmax=None, | |
| fmin=0.0, | |
| fmax=np.inf, | |
| baseline=None, | |
| mode="mean", | |
| dB=False, | |
| combine=None, | |
| layout=None, # TODO deprecate? not used in orig implementation either | |
| yscale="auto", | |
| vlim=(None, None), | |
| cnorm=None, | |
| cmap=None, | |
| colorbar=True, | |
| title=None, # don't deprecate this one; has (useful) option title="auto" | |
| mask=None, | |
| mask_style=None, | |
| mask_cmap="Greys", | |
| mask_alpha=0.1, | |
| axes=None, | |
| show=True, | |
| verbose=None, | |
| ): | |
| """Plot TFRs as two-dimensional time-frequency images. | |
| Parameters | |
| ---------- | |
| %(picks_good_data)s | |
| %(exclude_spectrum_plot)s | |
| %(tmin_tmax_psd)s | |
| %(fmin_fmax_tfr)s | |
| %(baseline_rescale)s | |
| How baseline is computed is determined by the ``mode`` parameter. | |
| %(mode_tfr_plot)s | |
| %(dB_tfr_plot)s | |
| %(combine_tfr_plot)s | |
| .. versionchanged:: 1.3 | |
| Added support for ``callable``. | |
| %(layout_spectrum_plot_topo)s | |
| %(yscale_tfr_plot)s | |
| .. versionadded:: 0.14.0 | |
| %(vlim_tfr_plot)s | |
| %(cnorm)s | |
| .. versionadded:: 0.24 | |
| %(cmap_topomap)s | |
| %(colorbar)s | |
| %(title_tfr_plot)s | |
| %(mask_tfr_plot)s | |
| .. versionadded:: 0.16.0 | |
| %(mask_style_tfr_plot)s | |
| .. versionadded:: 0.17 | |
| %(mask_cmap_tfr_plot)s | |
| .. versionadded:: 0.17 | |
| %(mask_alpha_tfr_plot)s | |
| .. versionadded:: 0.16.0 | |
| %(axes_tfr_plot)s | |
| %(show)s | |
| %(verbose)s | |
| Returns | |
| ------- | |
| figs : list of instances of matplotlib.figure.Figure | |
| A list of figures containing the time-frequency power. | |
| """ | |
| # the rectangle selector plots topomaps, which needs all channels uncombined, | |
| # so we keep a reference to that state here, and (because the topomap plotting | |
| # function wants an AverageTFR) update it with `comment` and `nave` values in | |
| # case we started out with a singleton EpochsTFR or RawTFR | |
| initial_state = self.__getstate__() | |
| initial_state.setdefault("comment", "") | |
| initial_state.setdefault("nave", 1) | |
| # `_picks_to_idx` also gets done inside `get_data()`` below, but we do it here | |
| # because we need the indices later | |
| idx_picks = _picks_to_idx( | |
| self.info, picks, "data_or_ica", exclude=exclude, with_ref_meg=False | |
| ) | |
| pick_names = np.array(self.ch_names)[idx_picks].tolist() # for titles | |
| ch_types = self.get_channel_types(idx_picks) | |
| # get data arrays | |
| data, times, freqs = self.get_data( | |
| picks=idx_picks, exclude=(), return_times=True, return_freqs=True | |
| ) | |
| # pass tmin/tmax here ↓↓↓, not here ↑↑↑; we want to crop *after* baselining | |
| data, times, freqs = _prep_data_for_plot( | |
| data, | |
| times, | |
| freqs, | |
| tmin=tmin, | |
| tmax=tmax, | |
| fmin=fmin, | |
| fmax=fmax, | |
| baseline=baseline, | |
| mode=mode, | |
| dB=dB, | |
| taper_weights=self.weights, | |
| verbose=verbose, | |
| ) | |
| # shape | |
| ch_axis = self._dims.index("channel") | |
| freq_axis = self._dims.index("freq") | |
| time_axis = self._dims.index("time") | |
| want_shape = list(self.shape) | |
| want_shape[ch_axis] = len(idx_picks) if combine is None else 1 | |
| want_shape[freq_axis] = len(freqs) # in case there was fmin/fmax cropping | |
| want_shape[time_axis] = len(times) # in case there was tmin/tmax cropping | |
| want_shape = [ | |
| n for dim, n in zip(self._dims, want_shape) if dim != "taper" | |
| ] # tapers must be aggregated over by now | |
| want_shape = tuple(want_shape) | |
| # combine | |
| combine_was_none = combine is None | |
| combine = _make_combine_callable( | |
| combine, axis=ch_axis, valid=("mean", "rms"), keepdims=True | |
| ) | |
| try: | |
| data = combine(data) # no need to copy; get_data() never returns a view | |
| except Exception as e: | |
| msg = ( | |
| "Something went wrong with the callable passed to 'combine'; see " | |
| "traceback." | |
| ) | |
| raise ValueError(msg) from e | |
| # call succeeded, check type and shape | |
| mismatch = False | |
| if not isinstance(data, np.ndarray): | |
| mismatch = "type" | |
| extra = "" | |
| elif data.shape not in (want_shape, want_shape[1:]): | |
| mismatch = "shape" | |
| extra = f" of shape {data.shape}" | |
| if mismatch: | |
| raise RuntimeError( | |
| f"Wrong {mismatch} yielded by callable passed to 'combine'. Make sure " | |
| "your function takes a single argument (an array of shape " | |
| "(n_channels, n_freqs, n_times)) and returns an array of shape " | |
| f"(n_freqs, n_times); yours yielded: {type(data)}{extra}." | |
| ) | |
| # restore singleton collapsed axis (removed by user-provided callable): | |
| # (n_freqs, n_times) → (1, n_freqs, n_times) | |
| if data.shape == (len(freqs), len(times)): | |
| data = data[np.newaxis] | |
| assert data.shape == want_shape | |
| # cmap handling. power may be negative depending on baseline strategy so set | |
| # `norm` empirically — but only if user didn't set limits explicitly. | |
| norm = False if vlim == (None, None) else data.min() >= 0.0 | |
| vmin, vmax = _setup_vmin_vmax(data, *vlim, norm=norm) | |
| cmap = _setup_cmap(cmap, norm=norm) | |
| # prepare figure(s) | |
| if axes is None: | |
| figs = [plt.figure(layout="constrained") for _ in range(data.shape[0])] | |
| axes = [fig.add_subplot() for fig in figs] | |
| elif isinstance(axes, plt.Axes): | |
| figs = [axes.get_figure()] | |
| axes = [axes] | |
| elif isinstance(axes, np.ndarray): # allow plotting into a grid of axes | |
| figs = [ax.get_figure() for ax in axes.flat] | |
| elif hasattr(axes, "__iter__") and len(axes): | |
| figs = [ax.get_figure() for ax in axes] | |
| else: | |
| raise ValueError( | |
| f"axes must be None, Axes, or list/array of Axes, got {type(axes)}" | |
| ) | |
| if len(axes) != data.shape[0]: | |
| raise RuntimeError( | |
| f"Mismatch between picked channels ({data.shape[0]}) and axes " | |
| f"({len(axes)}); there must be one axes for each picked channel." | |
| ) | |
| # check if we're being called from within plot_joint(). If so, get the | |
| # `topomap_args` from the calling context and pass it to the onselect handler. | |
| # (we need 2 `f_back` here because of the verbose decorator) | |
| calling_frame = inspect.currentframe().f_back.f_back | |
| source_plot_joint = calling_frame.f_code.co_name == "plot_joint" | |
| topomap_args = ( | |
| dict() | |
| if not source_plot_joint | |
| else calling_frame.f_locals.get("topomap_args", dict()) | |
| ) | |
| # plot | |
| for ix, _fig in enumerate(figs): | |
| # restrict the onselect instance to the channel type of the picks used in | |
| # the image plot | |
| uniq_types = np.unique(ch_types) | |
| ch_type = None if len(uniq_types) > 1 else uniq_types.item() | |
| this_tfr = AverageTFR(inst=initial_state).pick(ch_type, verbose=verbose) | |
| _onselect_callback = partial( | |
| this_tfr._onselect, | |
| picks=None, # already restricted the picks in `this_tfr` | |
| exclude=(), | |
| baseline=baseline, | |
| mode=mode, | |
| cmap=cmap, | |
| source_plot_joint=source_plot_joint, | |
| topomap_args=topomap_args, | |
| ) | |
| # draw the image plot | |
| _imshow_tfr( | |
| ax=axes[ix], | |
| tfr=data[[ix]], | |
| ch_idx=0, | |
| tmin=times[0], | |
| tmax=times[-1], | |
| vmin=vmin, | |
| vmax=vmax, | |
| onselect=_onselect_callback, | |
| ylim=None, | |
| freq=freqs, | |
| x_label="Time (s)", | |
| y_label="Frequency (Hz)", | |
| colorbar=colorbar, | |
| cmap=cmap, | |
| yscale=yscale, | |
| mask=mask, | |
| mask_style=mask_style, | |
| mask_cmap=mask_cmap, | |
| mask_alpha=mask_alpha, | |
| cnorm=cnorm, | |
| ) | |
| # handle title. automatic title is: | |
| # f"{Baselined} {power} ({ch_name})" or | |
| # f"{Baselined} {power} ({combination} of {N} {ch_type}s)" | |
| if title == "auto": | |
| if combine_was_none: # one plot per channel | |
| which_chs = pick_names[ix] | |
| elif len(pick_names) == 1: # there was only one pick anyway | |
| which_chs = pick_names[0] | |
| else: # one plot for all chs combined | |
| which_chs = _set_title_multiple_electrodes( | |
| None, combine, pick_names, all_=True, ch_type=ch_type | |
| ) | |
| _prefix = "Power" if baseline is None else "Baselined power" | |
| _title = f"{_prefix} ({which_chs})" | |
| else: | |
| _title = title | |
| _fig.suptitle(_title) | |
| plt_show(show) | |
| return figs | |
def plot_joint(
    self,
    *,
    timefreqs=None,
    picks=None,
    exclude=(),
    combine="mean",
    tmin=None,
    tmax=None,
    fmin=None,
    fmax=None,
    baseline=None,
    mode="mean",
    dB=False,
    yscale="auto",
    vlim=(None, None),
    cnorm=None,
    cmap=None,
    colorbar=True,
    title=None,  # TODO consider deprecating this one, or adding an "auto" option
    show=True,
    topomap_args=None,
    image_args=None,
    verbose=None,
):
    """Plot TFRs as a two-dimensional image with topomap highlights.

    Parameters
    ----------
    %(timefreqs)s
    %(picks_good_data)s
    %(exclude_psd)s
        Default is an empty :class:`tuple` which includes all channels.
    %(combine_tfr_plot_joint)s

        .. versionchanged:: 1.3
            Added support for ``callable``.
    %(tmin_tmax_psd)s
    %(fmin_fmax_tfr)s
    %(baseline_rescale)s

        How baseline is computed is determined by the ``mode`` parameter.
    %(mode_tfr_plot)s
    %(dB_tfr_plot)s
    %(yscale_tfr_plot)s
    %(vlim_tfr_plot_joint)s
    %(cnorm)s
    %(cmap_tfr_plot_topo)s
    %(colorbar_tfr_plot_joint)s
    %(title_none)s
    %(show)s
    %(topomap_args)s
    %(image_args)s
    %(verbose)s

    Returns
    -------
    fig : matplotlib.figure.Figure | list of matplotlib.figure.Figure
        The figure containing the topography. If the picked channels span more
        than one channel type, a list with one figure per channel type is
        returned instead.

    Notes
    -----
    %(notes_timefreqs_tfr_plot_joint)s

    .. versionadded:: 0.16.0
    """
    from matplotlib import ticker
    from matplotlib.patches import ConnectionPatch

    # handle recursion: when the picks span several channel types we make one
    # joint figure per type by calling ourselves on a per-type copy
    picks = _picks_to_idx(
        self.info, picks, "data_or_ica", exclude=exclude, with_ref_meg=False
    )
    all_ch_types = np.array(self.get_channel_types())
    uniq_ch_types = sorted(set(all_ch_types[picks]))
    if len(uniq_ch_types) > 1:
        msg = "Multiple channel types selected, returning one figure per type."
        logger.info(msg)
        figs = list()
        for this_type in uniq_ch_types:
            this_picks = np.intersect1d(
                picks,
                np.nonzero(np.isin(all_ch_types, this_type))[0],
                assume_unique=True,
            )
            # TODO might be nice to not "copy first, then pick"; alternative might
            # be to subset the data with `this_picks` and then construct the "copy"
            # using __getstate__ and __setstate__
            _tfr = self.copy().pick(this_picks)
            figs.append(
                _tfr.plot_joint(
                    timefreqs=timefreqs,
                    picks=None,
                    exclude=(),
                    combine=combine,
                    tmin=tmin,
                    tmax=tmax,
                    fmin=fmin,
                    fmax=fmax,
                    baseline=baseline,
                    mode=mode,
                    dB=dB,
                    yscale=yscale,
                    vlim=vlim,
                    cnorm=cnorm,  # forward the caller's color normalization
                    cmap=cmap,
                    colorbar=colorbar,
                    title=title,
                    show=False,  # shown once, below, after all figures exist
                    topomap_args=topomap_args,
                    image_args=image_args,  # forward extra image-plot kwargs
                    verbose=verbose,
                )
            )
        # the children were created with show=False, so honor `show` here
        plt_show(show)
        return figs
    ch_type = uniq_ch_types[0]

    # handle defaults
    _validate_type(combine, ("str", "callable"), item_name="combine")  # no `None`
    image_args = dict() if image_args is None else image_args
    # copy so that the defaults set below never leak into the caller's dict
    topomap_args = dict() if topomap_args is None else topomap_args.copy()
    # make sure if topomap_args["ch_type"] is set, it matches what is in `self.info`
    topomap_args.setdefault("ch_type", ch_type)
    if topomap_args["ch_type"] != ch_type:
        raise ValueError(
            f"topomap_args['ch_type'] is {topomap_args['ch_type']} which does not "
            f"match the channel type present in the object ({ch_type})."
        )
    # some necessary defaults
    topomap_args.setdefault("outlines", "head")
    topomap_args.setdefault("contours", 6)
    # don't pass these:
    topomap_args.pop("axes", None)
    topomap_args.pop("show", None)
    topomap_args.pop("colorbar", None)

    # get the time/freq limits of the image plot, to make sure requested annotation
    # times/freqs are in range
    _, times, freqs = self.get_data(
        picks=picks,
        exclude=(),
        tmin=tmin,
        tmax=tmax,
        fmin=fmin,
        fmax=fmax,
        return_times=True,
        return_freqs=True,
    )
    # validate requested annotation times and freqs
    timefreqs = _get_timefreqs(self, timefreqs)
    valid_timefreqs = dict()
    while timefreqs:
        (_time, _freq), (t_win, f_win) = timefreqs.popitem()
        # convert to half-windows
        t_win /= 2
        f_win /= 2
        # make sure the times / freqs are in-bounds
        msg = (
            "Requested {} exceeds the range of the data ({}). Choose different "
            "`timefreqs`."
        )
        if (times > _time).all() or (times < _time).all():
            _var = f"time point ({_time:0.3f} s)"
            _range = f"{times[0]:0.3f} - {times[-1]:0.3f} s"
            raise ValueError(msg.format(_var, _range))
        elif (freqs > _freq).all() or (freqs < _freq).all():
            _var = f"frequency ({_freq:0.1f} Hz)"
            _range = f"{freqs[0]:0.1f} - {freqs[-1]:0.1f} Hz"
            raise ValueError(msg.format(_var, _range))
        # snap the times/freqs to the nearest point we have an estimate for, and
        # store the validated points
        if t_win == 0:
            _time = times[np.argmin(np.abs(times - _time))]
        if f_win == 0:
            _freq = freqs[np.argmin(np.abs(freqs - _freq))]
        valid_timefreqs[(_time, _freq)] = (t_win, f_win)

    # prep data for topomaps (unlike image plot, must include all channels of the
    # current ch_type). Don't pass tmin/tmax here (crop later after baselining)
    topomap_picks = _picks_to_idx(self.info, ch_type)
    data, times, freqs = self.get_data(
        picks=topomap_picks, exclude=(), return_times=True, return_freqs=True
    )
    # merge grads before baselining (makes ERDS visible)
    info = pick_info(self.info, sel=topomap_picks, copy=True)
    data, pos = _merge_if_grads(
        data=data,
        info=info,
        ch_type=ch_type,
        sphere=topomap_args.get("sphere"),
        combine=combine,
    )
    # loop over intended topomap locations, to find one vlim that works for all.
    tf_array = np.array(list(valid_timefreqs))  # each row is [time, freq]
    tf_array = tf_array[tf_array[:, 0].argsort()]  # sort by time
    _vmin, _vmax = (np.inf, -np.inf)
    topomap_arrays = list()
    topomap_titles = list()
    for _time, _freq in tf_array:
        # reduce data to the range of interest in the TF plane (i.e., finally crop)
        t_win, f_win = valid_timefreqs[(_time, _freq)]
        _tmin, _tmax = np.array([-1, 1]) * t_win + _time
        _fmin, _fmax = np.array([-1, 1]) * f_win + _freq
        _data, *_ = _prep_data_for_plot(
            data,
            times,
            freqs,
            tmin=_tmin,
            tmax=_tmax,
            fmin=_fmin,
            fmax=_fmax,
            baseline=baseline,
            mode=mode,
            taper_weights=self.weights,
            verbose=verbose,
        )
        _data = _data.mean(axis=(-1, -2))  # avg over times and freqs
        topomap_arrays.append(_data)
        _vmin = min(_data.min(), _vmin)
        _vmax = max(_data.max(), _vmax)
        # construct topomap subplot title
        t_pm = "" if t_win == 0 else f" ± {t_win:0.2f}"
        f_pm = "" if f_win == 0 else f" ± {f_win:0.1f}"
        _title = f"{_time:0.2f}{t_pm} s,\n{_freq:0.1f}{f_pm} Hz"
        topomap_titles.append(_title)

    # handle cmap. Power may be negative depending on baseline strategy so set
    # `norm` empirically. vmin/vmax will be handled separately within the `plot()`
    # call for the image plot.
    norm = np.min(topomap_arrays) >= 0.0
    cmap = _setup_cmap(cmap, norm=norm)
    topomap_args.setdefault("cmap", cmap[0])  # prevent interactive cbar
    # finalize topomap vlims and compute contour locations.
    # By passing `data=None` here ↓↓↓↓ we effectively assert vmin & vmax aren't None
    _vlim = _setup_vmin_vmax(data=None, vmin=_vmin, vmax=_vmax, norm=norm)
    topomap_args.setdefault("vlim", _vlim)
    locator, topomap_args["contours"] = _set_contour_locator(
        *topomap_args["vlim"], topomap_args["contours"]
    )

    # initialize figure and do the image plot. `self.plot()` needed to wait to be
    # called until after `topomap_args` was fully populated --- we don't pass the
    # dict through to `self.plot()` explicitly here, but we do "reach back" and get
    # it if it's needed by the interactive rectangle selector.
    fig, image_ax, topomap_axes = _prepare_joint_axes(len(valid_timefreqs))
    fig = self.plot(
        picks=picks,
        exclude=(),
        tmin=tmin,
        tmax=tmax,
        fmin=fmin,
        fmax=fmax,
        baseline=baseline,
        mode=mode,
        dB=dB,
        combine=combine,
        yscale=yscale,
        vlim=vlim,
        cnorm=cnorm,
        cmap=cmap,
        colorbar=False,
        title=title,
        # mask, mask_style, mask_cmap, mask_alpha
        axes=image_ax,
        show=False,
        verbose=verbose,
        **image_args,
    )[0]  # [0] because `.plot()` always returns a list

    # now, actually plot the topomaps (the loop name `topo_title` keeps the
    # `title` parameter from being shadowed)
    for ax, topo_title, _data in zip(topomap_axes, topomap_titles, topomap_arrays):
        ax.set_title(topo_title)
        plot_topomap(_data, pos, axes=ax, show=False, **topomap_args)

    # draw colorbar, using the image of the last topomap axes drawn above
    if colorbar:
        cbar = fig.colorbar(ax.images[0])
        cbar.locator = ticker.MaxNLocator(nbins=5) if locator is None else locator
        cbar.update_ticks()

    # draw the connection lines between time-frequency image and topoplots
    for (time_, freq_), topo_ax in zip(tf_array, topomap_axes):
        con = ConnectionPatch(
            xyA=[time_, freq_],
            xyB=[0.5, 0],
            coordsA="data",
            coordsB="axes fraction",
            axesA=image_ax,
            axesB=topo_ax,
            color="grey",
            linestyle="-",
            linewidth=1.5,
            alpha=0.66,
            zorder=1,
            clip_on=False,
        )
        fig.add_artist(con)

    plt_show(show)
    return fig
def plot_topo(
    self,
    picks=None,
    baseline=None,
    mode="mean",
    tmin=None,
    tmax=None,
    fmin=None,
    fmax=None,
    vmin=None,  # TODO deprecate in favor of `vlim` (needs helper func refactor)
    vmax=None,
    layout=None,
    cmap="RdBu_r",
    title=None,  # don't deprecate; topo titles aren't standard (color, size, just.)
    dB=False,
    colorbar=True,
    layout_scale=0.945,
    show=True,
    border="none",
    fig_facecolor="k",
    fig_background=None,
    font_color="w",
    yscale="auto",
    verbose=None,
):
    """Plot a TFR image for each channel in a sensor layout arrangement.

    Parameters
    ----------
    %(picks_good_data)s
    %(baseline_rescale)s

        How baseline is computed is determined by the ``mode`` parameter.
    %(mode_tfr_plot)s
    %(tmin_tmax_psd)s
    %(fmin_fmax_tfr)s
    %(vmin_vmax_tfr_plot_topo)s
    %(layout_spectrum_plot_topo)s
    %(cmap_tfr_plot_topo)s
    %(title_none)s
    %(dB_tfr_plot)s
    %(colorbar)s
    %(layout_scale)s
    %(show)s
    %(border_topo)s
    %(fig_facecolor)s
    %(fig_background)s
    %(font_color)s
    %(yscale_tfr_plot)s
    %(verbose)s

    Returns
    -------
    fig : matplotlib.figure.Figure
        The figure containing the topography.
    """
    # pull what we need off the instance (times are copied), then restrict the
    # info and data to the requested picks
    all_times = self.times.copy()
    all_freqs = self.freqs
    all_data = self.data
    full_info = self.info
    picked_info, picked_data = _prepare_picks(full_info, all_data, picks, axis=0)
    del picks

    # one helper call does: baseline, crop, complex → power, taper aggregation,
    # and dB scaling
    prep_kwargs = dict(
        tmin=tmin,
        tmax=tmax,
        fmin=fmin,
        fmax=fmax,
        baseline=baseline,
        mode=mode,
        dB=dB,
        taper_weights=self.weights,
        verbose=verbose,
    )
    tfr_data, plot_times, plot_freqs = _prep_data_for_plot(
        picked_data, all_times, all_freqs, **prep_kwargs
    )

    # resolve color limits from the prepared data
    vmin, vmax = _setup_vmin_vmax(tfr_data, vmin, vmax)

    # fall back to a layout inferred from the (unpicked) measurement info
    if layout is None:
        from mne import find_layout

        layout = find_layout(self.info)

    # both drawing callbacks share the data, the frequencies, and the selector
    on_select = partial(self._onselect, baseline=baseline, mode=mode)
    shared_kwargs = dict(tfr=tfr_data, freq=plot_freqs, onselect=on_select)
    click_fun = partial(
        _imshow_tfr, yscale=yscale, cmap=(cmap, True), **shared_kwargs
    )
    imshow = partial(_imshow_tfr_unified, cmap=cmap, **shared_kwargs)

    topo_kwargs = dict(
        info=picked_info,
        times=plot_times,
        show_func=imshow,
        click_func=click_fun,
        layout=layout,
        colorbar=colorbar,
        vmin=vmin,
        vmax=vmax,
        cmap=cmap,
        layout_scale=layout_scale,
        title=title,
        border=border,
        x_label="Time (s)",
        y_label="Frequency (Hz)",
        fig_facecolor=fig_facecolor,
        font_color=font_color,
        unified=True,
        img=True,
    )
    fig = _plot_topo(**topo_kwargs)
    add_background_image(fig, fig_background)
    plt_show(show)
    return fig
def plot_topomap(
    self,
    tmin=None,
    tmax=None,
    fmin=0.0,
    fmax=np.inf,
    *,
    ch_type=None,
    baseline=None,
    mode="mean",
    sensors=True,
    show_names=False,
    mask=None,
    mask_params=None,
    contours=6,
    outlines="head",
    sphere=None,
    image_interp=_INTERPOLATION_DEFAULT,
    extrapolate=_EXTRAPOLATE_DEFAULT,
    border=_BORDER_DEFAULT,
    res=64,
    size=2,
    cmap=None,
    vlim=(None, None),
    cnorm=None,
    colorbar=True,
    cbar_fmt="%1.1e",
    units=None,
    axes=None,
    show=True,
):
    # Thin method wrapper: every argument is forwarded unchanged, by keyword, to
    # the module-level `plot_tfr_topomap` function, with this object as the
    # instance to plot.
    forwarded = dict(
        tmin=tmin,
        tmax=tmax,
        fmin=fmin,
        fmax=fmax,
        ch_type=ch_type,
        baseline=baseline,
        mode=mode,
        sensors=sensors,
        show_names=show_names,
        mask=mask,
        mask_params=mask_params,
        contours=contours,
        outlines=outlines,
        sphere=sphere,
        image_interp=image_interp,
        extrapolate=extrapolate,
        border=border,
        res=res,
        size=size,
        cmap=cmap,
        vlim=vlim,
        cnorm=cnorm,
        colorbar=colorbar,
        cbar_fmt=cbar_fmt,
        units=units,
        axes=axes,
        show=show,
    )
    return plot_tfr_topomap(self, **forwarded)
def save(self, fname, *, overwrite=False, verbose=None):
    """Save time-frequency data to disk (in HDF5 format).

    Parameters
    ----------
    fname : path-like
        Path of file to save to, which should end with ``-tfr.h5`` or ``-tfr.hdf5``.
    %(overwrite)s
    %(verbose)s

    See Also
    --------
    mne.time_frequency.read_tfrs
    """
    # only the writer half of the (reader, writer) pair is needed here
    write_hdf5 = _import_h5io_funcs()[1]
    # check the file extension, then resolve the path / overwrite policy
    check_fname(fname, "time-frequency object", (".h5", ".hdf5"))
    fname = _check_fname(fname, overwrite=overwrite, verbose=verbose)
    # serialize through the same state dict that __getstate__ produces
    state = self.__getstate__()
    if "metadata" in state:
        # metadata, when present, goes through a dedicated conversion helper
        # before it is written
        state["metadata"] = _prepare_write_metadata(state["metadata"])
    write_opts = dict(overwrite=overwrite, title="mnepython", slash="replace")
    write_hdf5(fname, state, **write_opts)
def to_data_frame(
    self,
    picks=None,
    index=None,
    long_format=False,
    time_format=None,
    *,
    verbose=None,
):
    """Export data in tabular structure as a pandas DataFrame.

    Channels are converted to columns in the DataFrame. By default, additional
    columns ``'time'``, ``'freq'``, ``'taper'``, ``'epoch'``, and ``'condition'``
    (epoch event description) are added, unless ``index`` is not ``None`` (in which
    case the columns specified in ``index`` will be used to form the DataFrame's
    index instead). ``'epoch'``, and ``'condition'`` are not supported for
    ``AverageTFR``. ``'taper'`` is only supported when a taper dimension is
    present, such as for complex or phase multitaper data.

    Parameters
    ----------
    %(picks_all)s
    %(index_df_epo)s
        Valid string values are ``'time'``, ``'freq'``, ``'taper'``, ``'epoch'``,
        and ``'condition'`` for ``EpochsTFR`` and ``'time'``, ``'freq'``, and
        ``'taper'`` for ``AverageTFR``. Defaults to ``None``.
    %(long_format_df_epo)s
    %(time_format_df)s

        .. versionadded:: 0.23
    %(verbose)s

    Returns
    -------
    %(df_return)s
    """
    # check pandas once here, instead of in each private utils function
    pd = _check_pandas_installed()  # noqa
    # triage for Epoch-derived or unaggregated spectra
    from_epo = isinstance(self, EpochsTFR)
    unagg_mt = "taper" in self._dims
    # arg checking
    valid_index_args = ["time", "freq"]
    if from_epo:
        valid_index_args.extend(["epoch", "condition"])
    if unagg_mt:
        valid_index_args.append("taper")
    valid_time_formats = ["ms", "timedelta"]
    index = _check_pandas_index_arguments(index, valid_index_args)
    time_format = _check_time_format(time_format, valid_time_formats)
    # get data
    picks = _picks_to_idx(self.info, picks, "all", exclude=())
    data, times, freqs, tapers = self.get_data(
        picks, return_times=True, return_freqs=True, return_tapers=True
    )
    # normalize to 5 axes: (epochs, channels, tapers, freqs, times)
    ch_axis = self._dims.index("channel")
    if not from_epo:
        data = data[np.newaxis]  # add singleton "epochs" axis
        ch_axis += 1
    if not unagg_mt:
        data = np.expand_dims(data, -3)  # add singleton "tapers" axis
    n_epochs, n_picks, n_tapers, n_freqs, n_times = data.shape
    # reshape to (epochs*tapers*freqs*times) x signals; after moving channels
    # last, rows vary fastest over time, then freq, then taper, then epoch
    data = np.moveaxis(data, ch_axis, -1)
    n_rows = n_epochs * n_tapers * n_freqs * n_times
    data = data.reshape(n_rows, n_picks)
    # prepare extra columns / multiindex; each column below is laid out to match
    # the row order described above
    mindex = list()
    default_index = list()
    times = _convert_times(times, time_format, meas_date=self.info["meas_date"])
    times = np.tile(times, n_epochs * n_freqs * n_tapers)
    freqs = np.tile(np.repeat(freqs, n_times), n_epochs * n_tapers)
    mindex.append(("time", times))
    mindex.append(("freq", freqs))
    if from_epo:
        mindex.append(
            ("epoch", np.repeat(self.selection, n_times * n_freqs * n_tapers))
        )
        rev_event_id = {v: k for k, v in self.event_id.items()}
        conditions = [rev_event_id[k] for k in self.events[:, 2]]
        mindex.append(
            ("condition", np.repeat(conditions, n_times * n_freqs * n_tapers))
        )
        default_index.extend(["condition", "epoch"])
    if unagg_mt:
        tapers = np.repeat(np.tile(tapers, n_epochs), n_freqs * n_times)
        mindex.append(("taper", tapers))
        default_index.append("taper")
    default_index.extend(["freq", "time"])
    # sanity check: every index column must have the same number of rows. Each
    # `mindex` entry is a (name, values) pair, so compare the *values* lengths
    # (the length of the pair itself is always 2 and would check nothing).
    assert all(len(values) == len(mindex[0][1]) for _, values in mindex[1:])
    # build DataFrame
    df = _build_data_frame(
        self, data, picks, long_format, mindex, index, default_index=default_index
    )
    return df
class AverageTFR(BaseTFR):
    """Data object for spectrotemporal representations of averaged data.

    .. warning:: The preferred means of creating AverageTFR objects is via the
                 instance methods :meth:`mne.Epochs.compute_tfr` and
                 :meth:`mne.Evoked.compute_tfr`, or via
                 :meth:`mne.time_frequency.EpochsTFR.average`. Direct class
                 instantiation is discouraged.

    Parameters
    ----------
    inst : instance of Evoked | instance of Epochs | dict
        The data from which to compute the time-frequency representation. Passing a
        :class:`dict` will create the AverageTFR using the ``__setstate__`` interface
        and is not recommended for typical use cases.
    freqs : ndarray, shape (n_freqs,)
        The frequencies in Hz.
    %(method_tfr)s
    %(freqs_tfr)s
    %(tmin_tmax_psd)s
    %(picks_good_data_noref)s
    %(proj_psd)s
    %(decim_tfr)s
    %(comment_averagetfr)s
    %(n_jobs)s
    %(verbose)s
    %(method_kw_tfr)s

    Attributes
    ----------
    %(baseline_tfr_attr)s
    %(ch_names_tfr_attr)s
    %(comment_averagetfr_attr)s
    %(freqs_tfr_attr)s
    %(info_not_none)s
    %(method_tfr_attr)s
    %(nave_tfr_attr)s
    %(sfreq_tfr_attr)s
    %(shape_tfr_attr)s
    %(weights_tfr_attr)s

    See Also
    --------
    RawTFR
    EpochsTFR
    AverageTFRArray
    mne.Evoked.compute_tfr
    mne.time_frequency.EpochsTFR.average

    Notes
    -----
    The old API (prior to version 1.7) was::

        AverageTFR(info, data, times, freqs, nave, comment=None, method=None)

    That API is still available via :class:`~mne.time_frequency.AverageTFRArray` for
    cases where the data are precomputed or do not originate from MNE-Python objects.
    The preferred new API uses instance methods::

        evoked.compute_tfr(method, freqs, ...)
        epochs.compute_tfr(method, freqs, average=True, ...)

    The new API also supports AverageTFR instantiation from a :class:`dict`, but this
    is primarily for save/load and internal purposes, and wraps ``__setstate__``.
    During the transition from the old to the new API, it may be expedient to use
    :class:`~mne.time_frequency.AverageTFRArray` as a "quick-fix" approach to updating
    scripts under active development.

    References
    ----------
    .. footbibliography::
    """

    def __init__(
        self,
        *,
        inst=None,
        freqs=None,
        method=None,
        tmin=None,
        tmax=None,
        picks=None,
        proj=False,
        decim=1,
        comment=None,
        n_jobs=None,
        verbose=None,
        **method_kw,
    ):
        from ..epochs import BaseEpochs
        from ..evoked import Evoked
        from ._stockwell import _check_input_st, _compute_freqs_st

        # dict is allowed for __setstate__ compatibility, and Epochs.compute_tfr() can
        # return an AverageTFR depending on its parameters, so Epochs input is allowed
        _validate_type(
            inst, (BaseEpochs, Evoked, dict), "object passed to AverageTFR constructor"
        )
        # stockwell API is very different from multitaper/morlet
        if method == "stockwell" and not isinstance(inst, dict):
            if isinstance(freqs, str) and freqs == "auto":
                fmin, fmax = None, None
            elif len(freqs) == 2:
                fmin, fmax = freqs
            else:
                raise ValueError(
                    "for Stockwell method, freqs must be a length-2 iterable "
                    f'or "auto", got {freqs}.'
                )
            method_kw.update(fmin=fmin, fmax=fmax)
            # Compute freqs. We need a couple lines of code dupe here (also in
            # BaseTFR.__init__) to get the subset of times to pass to _check_input_st()
            _mask = _time_mask(inst.times, tmin, tmax, sfreq=inst.info["sfreq"])
            _times = inst.times[_mask].copy()
            _, default_nfft, _ = _check_input_st(_times, None)
            n_fft = method_kw.get("n_fft", default_nfft)
            *_, freqs = _compute_freqs_st(fmin, fmax, n_fft, inst.info["sfreq"])

        # use Evoked.comment or str(Epochs.event_id) as the default comment...
        if comment is None:
            comment = getattr(inst, "comment", ",".join(getattr(inst, "event_id", "")))
        # ...but don't overwrite if it's coming in with a comment already set
        if isinstance(inst, dict):
            inst.setdefault("comment", comment)
        else:
            self._comment = getattr(self, "_comment", comment)
        super().__init__(
            inst,
            method,
            freqs,
            tmin=tmin,
            tmax=tmax,
            picks=picks,
            proj=proj,
            decim=decim,
            n_jobs=n_jobs,
            verbose=verbose,
            **method_kw,
        )

    def __getstate__(self):
        """Prepare AverageTFR object for serialization."""
        out = super().__getstate__()
        out.update(nave=self.nave, comment=self.comment)
        # NOTE: self._itc should never exist in the instance returned to the user; it
        # is temporarily present in the output from the tfr_array_* function, and is
        # split out into a separate AverageTFR object (and deleted from the object
        # holding power estimates) before those objects are passed back to the user.
        # The following lines are there because we make use of __getstate__ to achieve
        # that splitting of objects.
        if hasattr(self, "_itc"):
            out.update(itc=self._itc)
        return out

    def __setstate__(self, state):
        """Unpack AverageTFR from serialized format."""
        if state["data"].ndim not in [3, 4]:
            raise ValueError(
                f"AverageTFR data should be 3D or 4D, got {state['data'].ndim}."
            )
        # Set dims now since optional tapers makes it difficult to disentangle later
        state["dims"] = ("channel",)
        if state["data"].ndim == 4:
            state["dims"] += ("taper",)
        state["dims"] += ("freq", "time")
        super().__setstate__(state)
        # fall back to an empty comment / a single average when the state has none
        self._comment = state.get("comment", "")
        self._nave = state.get("nave", 1)

    # `comment` and `nave` are read as plain attributes (see `__getstate__`), so
    # each getter/setter pair must be a property backed by a private attribute;
    # as two plain methods the second definition would just shadow the first.
    @property
    def comment(self):
        return self._comment

    @comment.setter
    def comment(self, comment):
        self._comment = comment

    @property
    def nave(self):
        return self._nave

    @nave.setter
    def nave(self, nave):
        self._nave = nave

    def _get_instance_data(self, time_mask):
        # AverageTFRs can be constructed from Epochs data, so we triage shape here.
        # Evoked data get a fake singleton "epoch" axis prepended
        dim = slice(None) if _get_instance_type_string(self) == "Epochs" else np.newaxis
        data = self.inst.get_data(picks=self._picks)[dim, :, time_mask]
        # use the instance's own `nave` if it has one, else the size of the first
        # data axis
        self._nave = getattr(self.inst, "nave", data.shape[0])
        return data
class AverageTFRArray(AverageTFR):
    """Data object for *precomputed* spectrotemporal representations of averaged data.

    Parameters
    ----------
    %(info_not_none)s
    %(data_tfr)s
    %(times)s
    %(freqs_tfr_array)s
    nave : int
        The number of averaged TFRs.
    %(comment_averagetfr_attr)s
    %(method_tfr_array)s
    %(weights_tfr_array)s

    Attributes
    ----------
    %(baseline_tfr_attr)s
    %(ch_names_tfr_attr)s
    %(comment_averagetfr_attr)s
    %(freqs_tfr_attr)s
    %(info_not_none)s
    %(method_tfr_attr)s
    %(nave_tfr_attr)s
    %(sfreq_tfr_attr)s
    %(shape_tfr_attr)s
    %(weights_tfr_attr)s

    See Also
    --------
    AverageTFR
    EpochsTFRArray
    mne.Epochs.compute_tfr
    mne.Evoked.compute_tfr
    """

    def __init__(
        self,
        info,
        data,
        times,
        freqs,
        *,
        nave=None,
        comment=None,
        method=None,
        weights=None,
    ):
        # required fields always go into the state dict; optional ones are added
        # only when given, so `__setstate__` sees no key for anything left as None
        state = dict(info=info, data=data, times=times, freqs=freqs)
        maybe = dict(nave=nave, comment=comment, method=method, weights=weights)
        state.update({key: val for key, val in maybe.items() if val is not None})
        # populate the instance through the `__setstate__` interface
        self.__setstate__(state)
| class EpochsTFR(BaseTFR, GetEpochsMixin): | |
| """Data object for spectrotemporal representations of epoched data. | |
| .. important:: | |
| The preferred means of creating EpochsTFR objects from :class:`~mne.Epochs` | |
| objects is via the instance method :meth:`~mne.Epochs.compute_tfr`. | |
| To create an EpochsTFR object from pre-computed data (i.e., a NumPy array) use | |
| :class:`~mne.time_frequency.EpochsTFRArray`. | |
| Parameters | |
| ---------- | |
| inst : instance of Epochs | |
| The data from which to compute the time-frequency representation. | |
| %(freqs_tfr_epochs)s | |
| %(method_tfr_epochs)s | |
| %(tmin_tmax_psd)s | |
| %(picks_good_data_noref)s | |
| %(proj_psd)s | |
| %(decim_tfr)s | |
| %(n_jobs)s | |
| %(verbose)s | |
| %(method_kw_tfr)s | |
| Attributes | |
| ---------- | |
| %(baseline_tfr_attr)s | |
| %(ch_names_tfr_attr)s | |
| %(comment_tfr_attr)s | |
| %(drop_log)s | |
| %(event_id_attr)s | |
| %(events_attr)s | |
| %(freqs_tfr_attr)s | |
| %(info_not_none)s | |
| %(metadata_attr)s | |
| %(method_tfr_attr)s | |
| %(selection_attr)s | |
| %(sfreq_tfr_attr)s | |
| %(shape_tfr_attr)s | |
| %(weights_tfr_attr)s | |
| See Also | |
| -------- | |
| mne.Epochs.compute_tfr | |
| RawTFR | |
| AverageTFR | |
| EpochsTFRArray | |
| References | |
| ---------- | |
| .. footbibliography:: | |
| """ | |
def __init__(
    self,
    *,
    inst=None,
    freqs=None,
    method=None,
    tmin=None,
    tmax=None,
    picks=None,
    proj=False,
    decim=1,
    n_jobs=None,
    verbose=None,
    **method_kw,
):
    # imported at call time rather than at module import
    from ..epochs import BaseEpochs

    # dict is allowed for __setstate__ compatibility
    accepted_types = (BaseEpochs, dict)
    _validate_type(
        inst, accepted_types, "object passed to EpochsTFR constructor", "Epochs"
    )
    # Everything else is forwarded untouched to the parent constructor; note that
    # the parent takes (inst, method, freqs) positionally, in that order.
    shared_kwargs = dict(
        tmin=tmin,
        tmax=tmax,
        picks=picks,
        proj=proj,
        decim=decim,
        n_jobs=n_jobs,
        verbose=verbose,
    )
    super().__init__(inst, method, freqs, **shared_kwargs, **method_kw)
def __getitem__(self, item):
    """Subselect epochs from an EpochsTFR.

    Parameters
    ----------
    %(item)s
        Access options are the same as for :class:`~mne.Epochs` objects, see the
        docstring Notes section of :meth:`mne.Epochs.__getitem__` for explanation.

    Returns
    -------
    %(getitem_epochstfr_return)s
    """
    # Pure delegation: this override only exists to carry the docstring above.
    selected = super().__getitem__(item)
    return selected
def __getstate__(self):
    """Prepare EpochsTFR object for serialization.

    Extends the parent's state dict with the epoch-specific fields
    (metadata, drop log, events and their ids, selection, raw times).
    """
    state = super().__getstate__()
    state["metadata"] = self._metadata
    state["drop_log"] = self.drop_log
    state["event_id"] = self.event_id
    state["events"] = self.events
    state["selection"] = self.selection
    state["raw_times"] = self._raw_times
    return state
def __setstate__(self, state):
    """Unpack EpochsTFR from serialized format.

    Parameters
    ----------
    state : dict
        The serialized state. ``state["data"]`` must be 4D, or 5D when a taper
        axis is present. The keys ``metadata``, ``events``, ``event_id``,
        ``selection`` and ``drop_log`` are optional; defaults are generated for
        any that are missing. ``state["dims"]`` is overwritten.

    Raises
    ------
    ValueError
        If ``state["data"]`` is neither 4D nor 5D.
    """
    ndim = state["data"].ndim
    if ndim not in [4, 5]:
        raise ValueError(f"EpochsTFR data should be 4D or 5D, got {ndim}.")
    # Set dims now since optional tapers makes it difficult to disentangle later
    state["dims"] = ("epoch", "channel")
    if ndim == 5:
        state["dims"] += ("taper",)
    state["dims"] += ("freq", "time")
    super().__setstate__(state)
    self._metadata = state.get("metadata", None)
    n_epochs = self.shape[0]
    n_times = self.shape[-1]
    # Placeholder events: one per epoch, evenly spaced sample numbers, event code 1
    fake_samps = np.linspace(
        n_times, n_times * (n_epochs + 1), n_epochs, dtype=int, endpoint=False
    )
    fake_events = np.dstack(
        (fake_samps, np.zeros_like(fake_samps), np.ones_like(fake_samps))
    ).squeeze(axis=0)
    self.events = state.get("events", _ensure_events(fake_events))
    self.event_id = state.get("event_id", _check_event_id(None, self.events))
    self.selection = state.get("selection", np.arange(n_epochs))
    if "drop_log" in state:
        self.drop_log = state["drop_log"]
    else:
        # Build the default only when it is needed: computing it unconditionally
        # (as an argument to ``state.get``) fails on an empty selection even if
        # a drop log was supplied. A set gives O(1) membership tests.
        kept = set(np.asarray(self.selection).tolist())
        n_entries = max(len(self.events), max(kept, default=-1) + 1)
        self.drop_log = tuple(
            () if k in kept else ("IGNORED",) for k in range(n_entries)
        )
    self._bad_dropped = True  # always true, need for `equalize_event_counts()`
def __next__(self, return_event_id=False):
    """Iterate over EpochsTFR objects.

    NOTE: __iter__() and _stop_iter() are defined by the GetEpochs mixin.

    Parameters
    ----------
    return_event_id : bool
        If ``True``, return both the EpochsTFR data and its associated ``event_id``.

    Returns
    -------
    epoch : array of shape (n_channels, n_freqs, n_times)
        The single-epoch time-frequency data.
    event_id : int
        The integer event id associated with the epoch. Only returned if
        ``return_event_id`` is ``True``.
    """
    # past the last epoch: let the mixin helper terminate the iteration
    if self._current >= len(self._data):
        self._stop_iter()
    position = self._current
    epoch = self._data[position]
    # the event id is the last column of the matching events row
    event_id = self.events[position][-1]
    self._current = position + 1
    return (epoch, event_id) if return_event_id else epoch
def _check_singleton(self):
    """Check if self contains only one Epoch, and return it as an AverageTFR.

    Raises ``NotImplementedError`` (naming the calling function) when more than
    one epoch is present.
    """
    if self.shape[0] <= 1:
        # exhaust the iterator so iteration finishes cleanly, then take the first
        return list(self.iter_evoked())[0]
    # name of the function that called us, used to build a helpful message
    calling_func = inspect.currentframe().f_back.f_code.co_name
    raise NotImplementedError(
        f"Cannot call {calling_func}() from EpochsTFR with multiple epochs; "
        "please subselect a single epoch before plotting."
    )
def _get_instance_data(self, time_mask):
    """Return the picked data of ``self.inst``, masked along its third axis."""
    picked = self.inst.get_data(picks=self._picks)
    return picked[:, :, time_mask]
def _update_epoch_attributes(self):
    """Copy epoch bookkeeping from ``self.inst`` and prepend the epoch dimension."""
    source = self.inst
    # adjust dims and shape
    if self.method != "stockwell":  # stockwell consumes epochs dimension
        self._dims = ("epoch",) + self._dims
        self._shape = (len(source),) + self._shape
    # we need these for to_data_frame()
    self.event_id = source.event_id.copy()
    self.events = source.events.copy()
    self.selection = source.selection.copy()
    # we need these for __getitem__()
    self.drop_log = deepcopy(source.drop_log)
    self._metadata = source.metadata  # stored by reference, not copied
    # we need this for compatibility with equalize_event_counts()
    self._bad_dropped = True
def average(self, method="mean", *, dim="epochs", copy=False):
    """Aggregate the EpochsTFR across epochs, frequencies, or times.

    Parameters
    ----------
    method : "mean" | "median" | callable
        How to aggregate the data across the given ``dim``. If callable,
        must take a :class:`NumPy array<numpy.ndarray>` of shape
        ``(n_epochs, n_channels, n_freqs, n_times)`` and return an array
        with one fewer dimensions (which dimension is collapsed depends on
        the value of ``dim``). Default is ``"mean"``.
    dim : "epochs" | "freqs" | "times"
        The dimension along which to combine the data.
    copy : bool
        Whether to return a copy of the modified instance, or modify in place.
        Ignored when ``dim="epochs"`` or ``"times"`` because those options return
        different types (:class:`~mne.time_frequency.AverageTFR` and
        :class:`~mne.time_frequency.EpochsSpectrum`, respectively).

    Returns
    -------
    tfr : instance of EpochsTFR | AverageTFR | EpochsSpectrum
        The aggregated TFR object.

    Raises
    ------
    NotImplementedError
        If the data contain a taper dimension.
    RuntimeError
        If ``method`` returns an array whose shape is not the input shape with
        the aggregated dimension removed.

    Notes
    -----
    Passing in ``np.median`` is considered unsafe for complex data; pass
    the string ``"median"`` instead to compute the *marginal* median
    (i.e. the median of the real and imaginary components separately).
    See discussion here:

    https://github.com/scipy/scipy/pull/12676#issuecomment-783370228

    Averaging is not supported for data containing a taper dimension.
    """
    if "taper" in self._dims:
        raise NotImplementedError(
            "Averaging multitaper tapers across epochs, frequencies, or times is "
            "not supported. If averaging across epochs, consider averaging the "
            "epochs before computing the complex/phase spectrum."
        )

    _check_option("dim", dim, ("epochs", "freqs", "times"))
    axis = self._dims.index(dim[:-1])  # self._dims entries aren't plural

    func = _check_combine(mode=method, axis=axis)
    data = func(self.data)

    n_epochs, n_channels, n_freqs, n_times = self.data.shape
    freqs, times = self.freqs, self.times
    if dim == "epochs":
        expected_shape = self._data.shape[1:]
    elif dim == "freqs":
        expected_shape = (n_epochs, n_channels, n_times)
        # the collapsed axis is represented by the mean frequency
        freqs = np.mean(self.freqs, keepdims=True)
    elif dim == "times":
        expected_shape = (n_epochs, n_channels, n_freqs)
        # the collapsed axis is represented by the mean time
        times = np.mean(self.times, keepdims=True)
    if data.shape != expected_shape:
        raise RuntimeError(
            "EpochsTFR.average() got a method that resulted in data of shape "
            f"{data.shape}, but it should be {expected_shape}."
        )
    state = self.__getstate__()
    # restore singleton freqs axis (not necessary for epochs/times: class changes)
    if dim == "freqs":
        data = np.expand_dims(data, axis=axis)
    else:
        state["dims"] = (*state["dims"][:axis], *state["dims"][axis + 1 :])
    state["data"] = data
    state["info"] = deepcopy(self.info)
    state["freqs"] = freqs
    state["times"] = times
    if dim == "epochs":
        state["inst_type_str"] = "Evoked"
        state["nave"] = n_epochs
        state["comment"] = f"{method} of {n_epochs} EpochsTFR{_pl(n_epochs)}"
        out = AverageTFR(inst=state)
        out._data_type = "Average Power"
        return out

    elif dim == "times":
        return EpochsSpectrum(
            state,
            method=None,
            fmin=None,
            fmax=None,
            tmin=None,
            tmax=None,
            picks=None,
            exclude=None,
            proj=None,
            remove_dc=None,
            n_jobs=None,
        )
    # ↓↓↓ these two are for dim == "freqs"
    elif copy:
        return EpochsTFR(inst=state, method=None, freqs=None)
    else:
        # ``data`` already had its singleton frequency axis restored above;
        # expanding it a second time would leave the instance with 5-D data.
        self._data = data
        self._freqs = freqs
        return self
def drop(self, indices, reason="USER", verbose=None):
    """Drop epochs based on indices or boolean mask.

    .. note:: The indices refer to the current set of undropped epochs
              rather than the complete set of dropped and undropped epochs.
              They are therefore not necessarily consistent with any
              external indices (e.g., behavioral logs). To drop epochs
              based on external criteria, do not use the ``preload=True``
              flag when constructing an Epochs object, and call this
              method before calling the :meth:`mne.Epochs.drop_bad` or
              :meth:`mne.Epochs.load_data` methods.

    Parameters
    ----------
    indices : array of int or bool
        Set epochs to remove by specifying indices to remove or a boolean
        mask to apply (where True values get removed). Events are
        correspondingly modified.
    reason : str
        Reason for dropping the epochs ('ECG', 'timeout', 'blink' etc).
        Default: 'USER'.
    %(verbose)s

    Returns
    -------
    epochs : instance of Epochs or EpochsTFR
        The epochs with indices dropped. Operates in-place.
    """
    from ..epochs import BaseEpochs

    # Reuse the Epochs implementation, called unbound with this object as
    # ``self``; its return value is discarded and this instance is returned.
    drop_from_epochs = BaseEpochs.drop
    drop_from_epochs(self, indices=indices, reason=reason, verbose=verbose)
    return self
def iter_evoked(self, copy=False):
    """Iterate over EpochsTFR to yield a sequence of AverageTFR objects.

    The AverageTFR objects will each contain a single epoch (i.e., no averaging is
    performed). This method resets the EpochTFR instance's iteration state to the
    first epoch.

    Parameters
    ----------
    copy : bool
        Whether to yield copies of the data and measurement info, or views/pointers.

    Yields
    ------
    tfr : instance of AverageTFR
        One object per epoch, with ``nave`` set to 1 and the epoch's event id
        (as a string) passed as ``comment``.
    """
    self.__iter__()
    # A single state dict is reused for every yielded object; only the per-epoch
    # entries are overwritten on each pass.
    state = self.__getstate__()
    state["inst_type_str"] = "Evoked"
    state["dims"] = state["dims"][1:]  # drop "epochs"
    while True:
        try:
            data, event_id = self.__next__(return_event_id=True)
        except StopIteration:
            return
        if copy:
            state["info"] = deepcopy(self.info)
            state["data"] = data.copy()
        else:
            state["data"] = data
        state["nave"] = 1
        yield AverageTFR(inst=state, method=None, freqs=None, comment=str(event_id))
def plot(
    self,
    picks=None,
    *,
    exclude=(),
    tmin=None,
    tmax=None,
    fmin=None,
    fmax=None,
    baseline=None,
    mode="mean",
    dB=False,
    combine=None,
    layout=None,  # TODO deprecate; not used in orig implementation
    yscale="auto",
    vlim=(None, None),
    cnorm=None,
    cmap=None,
    colorbar=True,
    title=None,  # don't deprecate this one; has (useful) option title="auto"
    mask=None,
    mask_style=None,
    mask_cmap="Greys",
    mask_alpha=0.1,
    axes=None,
    show=True,
    verbose=None,
):
    # Only a single epoch can be plotted: ``_check_singleton`` returns it (or
    # raises), and every argument is forwarded unchanged to its ``plot``.
    singleton_epoch = self._check_singleton()
    forwarded = dict(
        picks=picks,
        exclude=exclude,
        tmin=tmin,
        tmax=tmax,
        fmin=fmin,
        fmax=fmax,
        baseline=baseline,
        mode=mode,
        dB=dB,
        combine=combine,
        layout=layout,
        yscale=yscale,
        vlim=vlim,
        cnorm=cnorm,
        cmap=cmap,
        colorbar=colorbar,
        title=title,
        mask=mask,
        mask_style=mask_style,
        mask_cmap=mask_cmap,
        mask_alpha=mask_alpha,
        axes=axes,
        show=show,
        verbose=verbose,
    )
    return singleton_epoch.plot(**forwarded)
def plot_topo(
    self,
    picks=None,
    baseline=None,
    mode="mean",
    tmin=None,
    tmax=None,
    fmin=None,
    fmax=None,
    vmin=None,  # TODO deprecate in favor of `vlim` (needs helper func refactor)
    vmax=None,
    layout=None,
    cmap=None,
    title=None,  # don't deprecate; topo titles aren't standard (color, size, just.)
    dB=False,
    colorbar=True,
    layout_scale=0.945,
    show=True,
    border="none",
    fig_facecolor="k",
    fig_background=None,
    font_color="w",
    yscale="auto",
    verbose=None,
):
    # Only a single epoch can be plotted: ``_check_singleton`` returns it (or
    # raises), and every argument is forwarded unchanged to its ``plot_topo``.
    singleton_epoch = self._check_singleton()
    forwarded = dict(
        picks=picks,
        baseline=baseline,
        mode=mode,
        tmin=tmin,
        tmax=tmax,
        fmin=fmin,
        fmax=fmax,
        vmin=vmin,
        vmax=vmax,
        layout=layout,
        cmap=cmap,
        title=title,
        dB=dB,
        colorbar=colorbar,
        layout_scale=layout_scale,
        show=show,
        border=border,
        fig_facecolor=fig_facecolor,
        fig_background=fig_background,
        font_color=font_color,
        yscale=yscale,
        verbose=verbose,
    )
    return singleton_epoch.plot_topo(**forwarded)
def plot_joint(
    self,
    *,
    timefreqs=None,
    picks=None,
    exclude=(),
    combine="mean",
    tmin=None,
    tmax=None,
    fmin=None,
    fmax=None,
    baseline=None,
    mode="mean",
    dB=False,
    yscale="auto",
    vlim=(None, None),
    cnorm=None,
    cmap=None,
    colorbar=True,
    title=None,
    show=True,
    topomap_args=None,
    image_args=None,
    verbose=None,
):
    # Only a single epoch can be plotted: ``_check_singleton`` returns it (or
    # raises), and every argument is forwarded unchanged to its ``plot_joint``.
    singleton_epoch = self._check_singleton()
    forwarded = dict(
        timefreqs=timefreqs,
        picks=picks,
        exclude=exclude,
        combine=combine,
        tmin=tmin,
        tmax=tmax,
        fmin=fmin,
        fmax=fmax,
        baseline=baseline,
        mode=mode,
        dB=dB,
        yscale=yscale,
        vlim=vlim,
        cnorm=cnorm,
        cmap=cmap,
        colorbar=colorbar,
        title=title,
        show=show,
        topomap_args=topomap_args,
        image_args=image_args,
        verbose=verbose,
    )
    return singleton_epoch.plot_joint(**forwarded)
def plot_topomap(
    self,
    tmin=None,
    tmax=None,
    fmin=0.0,
    fmax=np.inf,
    *,
    ch_type=None,
    baseline=None,
    mode="mean",
    sensors=True,
    show_names=False,
    mask=None,
    mask_params=None,
    contours=6,
    outlines="head",
    sphere=None,
    image_interp=_INTERPOLATION_DEFAULT,
    extrapolate=_EXTRAPOLATE_DEFAULT,
    border=_BORDER_DEFAULT,
    res=64,
    size=2,
    cmap=None,
    vlim=(None, None),
    cnorm=None,
    colorbar=True,
    cbar_fmt="%1.1e",
    units=None,
    axes=None,
    show=True,
):
    # Only a single epoch can be plotted: ``_check_singleton`` returns it (or
    # raises), and every argument is forwarded unchanged to its ``plot_topomap``.
    singleton_epoch = self._check_singleton()
    forwarded = dict(
        tmin=tmin,
        tmax=tmax,
        fmin=fmin,
        fmax=fmax,
        ch_type=ch_type,
        baseline=baseline,
        mode=mode,
        sensors=sensors,
        show_names=show_names,
        mask=mask,
        mask_params=mask_params,
        contours=contours,
        outlines=outlines,
        sphere=sphere,
        image_interp=image_interp,
        extrapolate=extrapolate,
        border=border,
        res=res,
        size=size,
        cmap=cmap,
        vlim=vlim,
        cnorm=cnorm,
        colorbar=colorbar,
        cbar_fmt=cbar_fmt,
        units=units,
        axes=axes,
        show=show,
    )
    return singleton_epoch.plot_topomap(**forwarded)
class EpochsTFRArray(EpochsTFR):
    """Data object for *precomputed* spectrotemporal representations of epoched data.

    Parameters
    ----------
    %(info_not_none)s
    %(data_tfr)s
    %(times)s
    %(freqs_tfr_array)s
    %(comment_tfr_attr)s
    %(method_tfr_array)s
    %(events_epochstfr)s
    %(event_id_epochstfr)s
    %(selection)s
    %(drop_log)s
    %(metadata_epochstfr)s
    %(weights_tfr_array)s

    Attributes
    ----------
    %(baseline_tfr_attr)s
    %(ch_names_tfr_attr)s
    %(comment_tfr_attr)s
    %(drop_log)s
    %(event_id_attr)s
    %(events_attr)s
    %(freqs_tfr_attr)s
    %(info_not_none)s
    %(metadata_attr)s
    %(method_tfr_attr)s
    %(selection_attr)s
    %(sfreq_tfr_attr)s
    %(shape_tfr_attr)s
    %(weights_tfr_attr)s

    See Also
    --------
    AverageTFR
    mne.Epochs.compute_tfr
    mne.Evoked.compute_tfr
    """

    def __init__(
        self,
        info,
        data,
        times,
        freqs,
        *,
        comment=None,
        method=None,
        events=None,
        event_id=None,
        selection=None,
        drop_log=None,
        metadata=None,
        weights=None,
    ):
        required = dict(info=info, data=data, times=times, freqs=freqs)
        # Optional fields left at ``None`` are omitted entirely, so the state dict
        # only carries keys the caller actually provided.
        provided = {
            name: value
            for name, value in dict(
                comment=comment,
                method=method,
                events=events,
                event_id=event_id,
                selection=selection,
                drop_log=drop_log,
                metadata=metadata,
                weights=weights,
            ).items()
            if value is not None
        }
        # The object is populated through the same path used for unpickling.
        self.__setstate__({**required, **provided})
class RawTFR(BaseTFR):
    """Data object for spectrotemporal representations of continuous data.

    .. warning:: The preferred means of creating RawTFR objects from
                 :class:`~mne.io.Raw` objects is via the instance method
                 :meth:`~mne.io.Raw.compute_tfr`. Direct class instantiation
                 is not supported.

    Parameters
    ----------
    inst : instance of Raw
        The data from which to compute the time-frequency representation.
    %(method_tfr)s
    %(freqs_tfr)s
    %(tmin_tmax_psd)s
    %(picks_good_data_noref)s
    %(proj_psd)s
    %(reject_by_annotation_tfr)s
    %(decim_tfr)s
    %(n_jobs)s
    %(verbose)s
    %(method_kw_tfr)s

    Attributes
    ----------
    ch_names : list
        The channel names.
    freqs : array
        Frequencies at which the amplitude, power, or fourier coefficients
        have been computed.
    %(info_not_none)s
    method : str
        The method used to compute the spectra (``'morlet'``, ``'multitaper'``
        or ``'stockwell'``).
    %(weights_tfr_attr)s

    See Also
    --------
    mne.io.Raw.compute_tfr
    EpochsTFR
    AverageTFR

    References
    ----------
    .. footbibliography::
    """

    def __init__(
        self,
        inst,
        method=None,
        freqs=None,
        *,
        tmin=None,
        tmax=None,
        picks=None,
        proj=False,
        reject_by_annotation=False,
        decim=1,
        n_jobs=None,
        verbose=None,
        **method_kw,
    ):
        # imported at call time rather than at module import
        from ..io import BaseRaw

        # dict is allowed for __setstate__ compatibility
        accepted_types = (BaseRaw, dict)
        _validate_type(
            inst, accepted_types, "object passed to RawTFR constructor", "Raw"
        )
        # Everything else is forwarded untouched to the parent constructor.
        shared_kwargs = dict(
            tmin=tmin,
            tmax=tmax,
            picks=picks,
            proj=proj,
            reject_by_annotation=reject_by_annotation,
            decim=decim,
            n_jobs=n_jobs,
            verbose=verbose,
        )
        super().__init__(inst, method, freqs, **shared_kwargs, **method_kw)

    def __setstate__(self, state):
        """Unpack RawTFR from serialized format.

        ``state["data"]`` must be 3D, or 4D when a taper axis is present;
        ``state["dims"]`` is overwritten accordingly.
        """
        ndim = state["data"].ndim
        if ndim not in [3, 4]:
            raise ValueError(f"RawTFR data should be 3D or 4D, got {ndim}.")
        # Set dims now since optional tapers makes it difficult to disentangle later
        dims = ("channel",)
        if ndim == 4:
            dims += ("taper",)
        state["dims"] = dims + ("freq", "time")
        super().__setstate__(state)

    def __getitem__(self, item):
        """Get RawTFR data.

        Parameters
        ----------
        item : int | slice | array-like
            Indexing is similar to a :class:`NumPy array<numpy.ndarray>`; see
            Notes.

        Returns
        -------
        %(getitem_tfr_return)s

        Notes
        -----
        The last axis is always time, the next-to-last axis is always
        frequency, and the first axis is always channel. If
        ``method='multitaper'`` and ``output='complex'`` then the second axis
        will be taper index.

        Integer-, list-, and slice-based indexing is possible:

        - ``raw_tfr[[0, 2]]`` gives the whole time-frequency plane for the
          first and third channels.
        - ``raw_tfr[..., :3, :]`` gives the first 3 frequency bins and all
          times for all channels (and tapers, if present).
        - ``raw_tfr[..., :100]`` gives the first 100 time samples in all
          frequency bins for all channels (and tapers).
        - ``raw_tfr[(4, 7)]`` is the same as ``raw_tfr[4, 7]``.

        .. note::

           Unlike :class:`~mne.io.Raw` objects (which returns a tuple of the
           requested data values and the corresponding times), accessing
           :class:`~mne.time_frequency.RawTFR` values via subscript does
           **not** return the corresponding frequency bin values. If you need
           them, use ``RawTFR.freqs[freq_indices]`` or
           ``RawTFR.get_data(..., return_freqs=True)``.
        """
        from ..io import BaseRaw

        # Borrow the Raw indexing machinery: bind its parameter parser to this
        # object, then call its item getter unbound with this object as ``self``.
        bound_parser = partial(BaseRaw._parse_get_set_params, self)
        self._parse_get_set_params = bound_parser
        return BaseRaw._getitem(self, item, return_times=False)

    def _get_instance_data(self, time_mask, reject_by_annotation):
        """Fetch the picked data spanned by ``time_mask``, with a leading axis."""
        # first and last sample index selected by the mask
        selected = np.where(time_mask)[0]
        start, stop = selected[[0, -1]]
        rba = "NaN" if reject_by_annotation else None
        # ``stop + 1`` because the last selected sample is passed as an exclusive end
        data = self.inst.get_data(
            self._picks, start, stop + 1, reject_by_annotation=rba
        )
        # prepend a singleton "epochs" axis
        return data[np.newaxis]
class RawTFRArray(RawTFR):
    """Data object for *precomputed* spectrotemporal representations of continuous data.

    Parameters
    ----------
    %(info_not_none)s
    %(data_tfr)s
    %(times)s
    %(freqs_tfr_array)s
    %(method_tfr_array)s
    %(weights_tfr_array)s

    Attributes
    ----------
    %(baseline_tfr_attr)s
    %(ch_names_tfr_attr)s
    %(freqs_tfr_attr)s
    %(info_not_none)s
    %(method_tfr_attr)s
    %(sfreq_tfr_attr)s
    %(shape_tfr_attr)s
    %(weights_tfr_attr)s

    See Also
    --------
    RawTFR
    mne.io.Raw.compute_tfr
    EpochsTFRArray
    AverageTFRArray
    """

    def __init__(
        self,
        info,
        data,
        times,
        freqs,
        *,
        method=None,
        weights=None,
    ):
        required = dict(info=info, data=data, times=times, freqs=freqs)
        # Optional fields left at ``None`` are omitted entirely, so the state dict
        # only carries keys the caller actually provided.
        provided = {
            name: value
            for name, value in dict(method=method, weights=weights).items()
            if value is not None
        }
        # The object is populated through the same path used for unpickling.
        self.__setstate__({**required, **provided})
def combine_tfr(all_tfr, weights="nave"):
    """Merge AverageTFR data by weighted addition.

    Create a new :class:`mne.time_frequency.AverageTFR` instance, using a combination of
    the supplied instances as its data. By default, the mean (weighted by trials) is
    used. Subtraction can be performed by passing negative weights (e.g., [1, -1]). Data
    must have the same channels and the same time instants.

    Parameters
    ----------
    all_tfr : list of AverageTFR
        The tfr datasets.
    weights : list of float | str
        The weights to apply to the data of each AverageTFR instance.
        Can also be ``'nave'`` to weight according to tfr.nave,
        or ``'equal'`` to use equal weighting (each weighted as ``1/N``).

    Returns
    -------
    tfr : AverageTFR
        The new TFR data.

    Notes
    -----
    Aggregating multitaper TFR datasets with a taper dimension such as for complex or
    phase data is not supported.

    .. versionadded:: 0.11.0
    """
    if any("taper" in inst._dims for inst in all_tfr):
        raise NotImplementedError(
            "Aggregating multitaper tapers across TFR datasets is not supported."
        )

    # the first instance serves as the template for the result
    combined = all_tfr[0].copy()
    n_tfr = len(all_tfr)

    if isinstance(weights, str):
        if weights == "nave":
            weights = np.array([inst.nave for inst in all_tfr], float)
            weights /= weights.sum()
        elif weights == "equal":
            weights = [1.0 / n_tfr] * n_tfr
        else:
            raise ValueError('Weights must be a list of float, or "nave" or "equal"')
    weights = np.array(weights, float)
    if weights.ndim != 1 or weights.size != n_tfr:
        raise ValueError("Weights must be the same size as all_tfr")

    others = all_tfr[1:]
    ch_names = combined.ch_names
    for other in others:
        assert other.ch_names == ch_names, (
            f"{combined} and {other} do not contain the same channels"
        )
        assert np.max(np.abs(other.times - combined.times)) < 1e-7, (
            f"{combined} and {other} do not contain the same time instants"
        )

    # use union of bad channels
    combined.info["bads"] = list(
        set(combined.info["bads"]).union(*(other.info["bads"] for other in others))
    )

    # XXX : should be refactored with combined_evoked function
    combined.data = sum(w * inst.data for w, inst in zip(weights, all_tfr))
    # nave of the weighted combination, truncated to int and never below 1
    effective_nave = 1.0 / sum(w**2 / inst.nave for w, inst in zip(weights, all_tfr))
    combined.nave = max(int(effective_nave), 1)
    return combined
| # Utils | |
| # ↓↓↓↓↓↓↓↓↓↓↓ this is still used in _stockwell.py | |
def _get_data(inst, return_itc):
    """Get data from Epochs or Evoked instance as epochs x ch x time.

    Raises ``TypeError`` for any other instance type, and ``ValueError`` when
    ``return_itc`` is requested for Evoked input.
    """
    from ..epochs import BaseEpochs
    from ..evoked import Evoked

    if isinstance(inst, BaseEpochs):
        # epochs data are returned without copying
        return inst.get_data(copy=False)
    if not isinstance(inst, Evoked):
        raise TypeError("inst must be Epochs or Evoked")
    if return_itc:
        raise ValueError("return_itc must be False for evoked data")
    # evoked data: prepend a singleton axis and return a copy
    return inst.data[np.newaxis].copy()
def _prepare_picks(info, data, picks, axis):
    """Prepare the picks.

    Resolves ``picks`` to indices (bad channels excluded), then returns the
    correspondingly reduced ``info`` and ``data`` indexed along ``axis``.
    """
    picks = _picks_to_idx(info, picks, exclude="bads")
    picked_info = pick_info(info, picks)
    # full slices everywhere except the requested axis, which takes the picks
    indexer = [slice(None)] * data.ndim
    indexer[axis] = picks
    return picked_info, data[tuple(indexer)]
def _centered(arr, newsize):
    """Aux Function to center data.

    Returns the central ``newsize`` portion of ``arr`` along every axis.
    """
    target = np.asarray(newsize)
    current = np.array(arr.shape)
    # when the excess is odd, the extra sample is left on the trailing side
    first = (current - target) // 2
    last = first + target
    window = tuple(slice(lo, hi) for lo, hi in zip(first, last))
    return arr[window]
def _ensure_slice(decim):
    """Aux function checking the decim parameter.

    Returns a ``slice`` whose ``step`` is never ``None``: an integer-like
    ``decim`` becomes ``slice(None, None, decim)``, and a slice without a step
    gets step 1.
    """
    _validate_type(decim, ("int-like", slice), "decim")
    if isinstance(decim, slice):
        as_slice = decim
    else:
        as_slice = slice(None, None, int(decim))
    # ensure that we can actually use `decim.step`
    if as_slice.step is None:
        as_slice = slice(as_slice.start, as_slice.stop, 1)
    return as_slice
| # i/o | |
def write_tfrs(fname, tfr, overwrite=False, *, verbose=None):
    """Write a TFR dataset to hdf5.

    Parameters
    ----------
    fname : path-like
        The file name, which should end with ``-tfr.h5``.
    tfr : RawTFR | EpochsTFR | AverageTFR | list of RawTFR | list of EpochsTFR | list of AverageTFR
        The (list of) TFR object(s) to save in one file. If ``tfr.comment`` is ``None``,
        a sequential numeric string name will be generated on the fly, based on the
        order in which the TFR objects are passed. This can be used to selectively load
        single TFR objects from the file later.
    %(overwrite)s
    %(verbose)s

    See Also
    --------
    read_tfrs

    Notes
    -----
    .. versionadded:: 0.9.0
    """  # noqa E501
    _, write_hdf5 = _import_h5io_funcs()
    # a single object is treated as a one-element list
    tfr_items = tfr if isinstance(tfr, list | tuple) else [tfr]
    entries = []
    for position, item in enumerate(tfr_items):
        # objects without a comment are keyed by their position in the list
        label = getattr(item, "comment", None)
        if label is None:
            label = position
        state = item.__getstate__()
        if "metadata" in state:
            state["metadata"] = _prepare_write_metadata(state["metadata"])
        entries.append((label, state))
    write_hdf5(
        fname, entries, overwrite=overwrite, title="mnepython", slash="replace"
    )
def read_tfrs(fname, condition=None, *, verbose=None):
    """Load a TFR object from disk.

    Parameters
    ----------
    fname : path-like
        Path to a TFR file in HDF5 format, which should end with ``-tfr.h5`` or
        ``-tfr.hdf5``.
    condition : int or str | list of int or str | None
        The condition to load. If ``None``, all conditions will be returned.
        Defaults to ``None``.
    %(verbose)s

    Returns
    -------
    tfr : RawTFR | EpochsTFR | AverageTFR | list of RawTFR | list of EpochsTFR | list of AverageTFR
        The loaded time-frequency object.

    See Also
    --------
    mne.time_frequency.RawTFR.save
    mne.time_frequency.EpochsTFR.save
    mne.time_frequency.AverageTFR.save
    write_tfrs

    Notes
    -----
    .. versionadded:: 0.9.0
    """  # noqa E501
    read_hdf5, _ = _import_h5io_funcs()
    fname = _check_fname(fname=fname, overwrite="read", must_exist=False)
    valid_fnames = ("-tfr.h5", "-tfr.hdf5", "_tfr.h5", "_tfr.hdf5")
    check_fname(fname, "tfr", valid_fnames)
    logger.info(f"Reading {fname} ...")
    hdf5_dict = read_hdf5(fname, title="mnepython", slash="replace")

    # maybe multiple TFRs from write_tfrs()
    if "inst_type_str" not in hdf5_dict:
        return _read_multiple_tfrs(hdf5_dict, condition=condition, verbose=verbose)

    # single TFR from TFR.save(): choose the class from what the state contains
    if "epoch" in hdf5_dict["dims"]:
        Klass = EpochsTFR
    elif "nave" in hdf5_dict:
        Klass = AverageTFR
    else:
        Klass = RawTFR
    out = Klass(inst=hdf5_dict)
    if getattr(out, "metadata", None) is not None:
        out.metadata = _prepare_read_metadata(out.metadata)
    return out
def _read_multiple_tfrs(tfr_data, condition=None, *, verbose=None):
    """Read (possibly multiple) TFR datasets from an h5 file written by write_tfrs().

    Parameters
    ----------
    tfr_data : list of tuple
        ``(key, state_dict)`` pairs, where ``key`` is the comment (or the
        auto-assigned integer) under which the TFR was written.
    condition : int | str | list of int or str | None
        Key(s) of the dataset(s) to load; ``None`` loads everything. Only
        supported for AverageTFR entries.
    verbose : bool | str | int | None
        Accepted for signature compatibility; not used in this function body.

    Returns
    -------
    out : RawTFR | EpochsTFR | AverageTFR | list
        The single matching object, or a list of them (in file order) when more
        than one entry was loaded.

    Raises
    ------
    NotImplementedError
        If ``condition`` is given and an entry is not an AverageTFR.
    ValueError
        If nothing was loaded, or if any requested condition is absent.
    """
    # Normalize ``condition`` so a list/tuple of keys is honored; comparing a key
    # against the list itself can never match.
    if condition is None:
        wanted = None
    elif isinstance(condition, list | tuple):
        wanted = list(condition)
    else:
        wanted = [condition]

    out = list()
    keys = list()
    matched = list()
    # tfr_data is a list of (comment, tfr_dict) tuples
    for key, tfr in tfr_data:
        keys.append(str(key))  # auto-assigned keys are ints
        is_epochs = tfr["data"].ndim == 4
        is_average = "nave" in tfr
        if wanted is not None:
            if not is_average:
                raise NotImplementedError(
                    "condition is only supported when reading AverageTFRs."
                )
            if key not in wanted:
                continue
            matched.append(key)
        tfr = dict(tfr)
        tfr["info"] = Info(tfr["info"])
        tfr["info"]._check_consistency()
        if "metadata" in tfr:
            tfr["metadata"] = _prepare_read_metadata(tfr["metadata"])
        # additional keys needed for TFR __setstate__
        defaults = dict(baseline=None, data_type="Power Estimates")
        if is_epochs:
            Klass = EpochsTFR
            defaults.update(
                inst_type_str="Epochs", dims=("epoch", "channel", "freq", "time")
            )
        elif is_average:
            Klass = AverageTFR
            defaults.update(inst_type_str="Evoked", dims=("channel", "freq", "time"))
        else:
            Klass = RawTFR
            defaults.update(inst_type_str="Raw", dims=("channel", "freq", "time"))
        out.append(Klass(inst=defaults | tfr))

    missing = [] if wanted is None else [c for c in wanted if c not in matched]
    if len(out) == 0 or missing:
        # when nothing at all was found, report the condition exactly as given
        not_found = condition if len(out) == 0 else missing
        raise ValueError(
            f'Cannot find condition "{not_found}" in this file. '
            f"The file contains conditions {', '.join(keys)}"
        )
    if len(out) == 1:
        out = out[0]
    return out
def _get_timefreqs(tfr, timefreqs):
    """Find and/or setup timefreqs for `tfr.plot_joint`.

    Parameters
    ----------
    tfr : object
        Object exposing ``data``, ``times`` and ``freqs``. Only accessed when
        ``timefreqs`` is None; ``data`` is then indexed with the frequency on
        axis 1 and the time on axis 2 (axis 0 is presumably channels).
    timefreqs : None | pair of numbers | sequence of pairs | dict
        If a dict, every key and every value must be a pair of numbers.
        If None, a single pair is picked automatically from ``tfr.data``.

    Returns
    -------
    timefreqs : dict
        Maps each ``(time, freq)`` tuple to an array: the user-supplied value
        in dict mode, ``np.array([0, 0])`` otherwise.

    Raises
    ------
    ValueError
        If ``timefreqs`` is not in one of the accepted forms.
    """
    timefreq_error_msg = (
        "Supplied `timefreqs` are somehow malformed. Please supply None, "
        "a list of tuple pairs, or a dict of such tuple pairs, not {}"
    )

    def _is_pair(item):
        # a sized object holding exactly two numeric entries
        return (
            hasattr(item, "__len__")
            and len(item) == 2
            and all(_is_numeric(n) for n in item)
        )

    if isinstance(timefreqs, dict):
        for k, v in timefreqs.items():
            for item in (k, v):
                if not _is_pair(item):
                    # format the message (it used to be passed unformatted)
                    raise ValueError(timefreq_error_msg.format(item))
    elif timefreqs is not None:
        if not hasattr(timefreqs, "__len__"):
            raise ValueError(timefreq_error_msg.format(timefreqs))
        if _is_pair(timefreqs):
            timefreqs = [tuple(timefreqs)]  # stick a pair of numbers in a list
        else:
            for item in timefreqs:
                if not _is_pair(item):
                    raise ValueError(timefreq_error_msg.format(item))
    # If None, automatic identification of max peak
    else:
        order = max((1, tfr.data.shape[2] // 30))
        peaks_idx = argrelmax(tfr.data, order=order, axis=2)
        if peaks_idx[0].size == 0:
            # no local maximum along time: fall back to the global maximum;
            # unravel_index follows the data axes, i.e. (axis 0, freq, time)
            _, p_f, p_t = np.unravel_index(tfr.data.argmax(), tfr.data.shape)
        else:
            # compare each peak using the value at its own location,
            # including the axis-0 index on which it was found
            peakmax_idx = np.argmax(tfr.data[peaks_idx])
            p_f = peaks_idx[1][peakmax_idx]
            p_t = peaks_idx[2][peakmax_idx]
        timefreqs = [(tfr.times[p_t], tfr.freqs[p_f])]
    timefreqs = {
        tuple(k): np.asarray(timefreqs[k])
        if isinstance(timefreqs, dict)
        else np.array([0, 0])
        for k in timefreqs
    }
    return timefreqs
def _check_tfr_complex(tfr, reason="source space estimation"):
    """Raise if the data of a time-frequency object is not complex-valued.

    Parameters
    ----------
    tfr : object
        Object whose ``data`` attribute is checked.
    reason : str
        Inserted into the error message to say why complex data is needed.

    Raises
    ------
    RuntimeError
        If ``tfr.data`` is not complex.
    """
    if np.iscomplexobj(tfr.data):
        return
    raise RuntimeError(f"Time-frequency data must be complex for {reason}")
def _merge_if_grads(data, info, ch_type, sphere, combine=None):
    """Return data and sensor positions, merging the data for ``"grad"``.

    For any ``ch_type`` other than ``"grad"`` the data is returned unchanged
    together with the positions from ``_get_pos_outlines``. For ``"grad"``
    the picks from ``_pair_grad_sensors`` select the data, which is merged
    with ``_merge_ch_data``; positions are computed for every second pick.
    ``combine`` is used as the merge method only when it is a string,
    otherwise ``"rms"`` is used.
    """
    if ch_type != "grad":
        pos, _ = _get_pos_outlines(info, picks=ch_type, sphere=sphere)
        return data, pos
    grad_picks = _pair_grad_sensors(info, topomap_coords=False)
    pos = _find_topomap_coords(info, picks=grad_picks[::2], sphere=sphere)
    grad_method = "rms"
    if isinstance(combine, str):
        grad_method = combine
    merged, _ = _merge_ch_data(data[grad_picks], ch_type, [], method=grad_method)
    return merged, pos
def _prep_data_for_plot(
    data,
    times,
    freqs,
    *,
    tmin=None,
    tmax=None,
    fmin=None,
    fmax=None,
    baseline=None,
    mode=None,
    dB=False,
    taper_weights=None,
    verbose=None,
):
    """Rescale, crop and (if needed) aggregate TFR data before plotting.

    Parameters
    ----------
    data : array
        The last two axes are (freqs, times). If ``taper_weights`` is given,
        a taper axis is expected directly before the frequency axis.
    times, freqs : array
        Values along the last and second-to-last axis of ``data``.
    tmin, tmax, fmin, fmax : float | None
        Crop limits, passed to ``_time_mask``.
    baseline, mode
        Passed to ``rescale``; ``data`` is copied there only when
        ``baseline`` is not None.
    dB : bool
        If True, return ``10 * log10`` of the result.
    taper_weights : array, shape (n_tapers, n_freqs) | None
        Weights used to combine per-taper estimates. If they span the
        uncropped frequency axis they are cropped together with ``data``.
    verbose
        Passed to ``rescale``.

    Returns
    -------
    data : array
        The processed data (real-valued if the input was complex).
    times, freqs : array
        The cropped axes.
    """
    # baseline
    copy = baseline is not None
    data = rescale(data, times, baseline, mode, copy=copy, verbose=verbose)
    n_freqs_in = len(freqs)
    # crop times
    time_mask = np.nonzero(_time_mask(times, tmin, tmax))[0]
    times = times[time_mask]
    # crop freqs
    freq_mask = np.nonzero(_time_mask(freqs, fmin, fmax))[0]
    freqs = freqs[freq_mask]
    # crop data
    data = data[..., freq_mask, :][..., time_mask]
    # handle unaggregated multitaper (complex or phase multitaper data)
    if taper_weights is not None:  # assumes a taper dimension
        # Keep the weights aligned with the cropped frequency axis; without
        # this, cropping with fmin/fmax leaves weights and data with
        # different n_freqs. Weights that do not span the uncropped axis
        # (e.g., already cropped by the caller) are left untouched.
        if taper_weights.shape[-1] == n_freqs_in:
            taper_weights = taper_weights[..., freq_mask]
        logger.info("Aggregating multitaper estimates before plotting...")
        if np.iscomplexobj(data):  # complex coefficients → power
            data = _tfr_from_mt(data, taper_weights)
        else:  # tapered phase data → weighted phase data
            # channels, tapers, freqs, time
            assert data.ndim == 4
            # weights as a function of (tapers, freqs)
            assert taper_weights.ndim == 2
            data = (data * taper_weights[np.newaxis, :, :, np.newaxis]).mean(axis=1)
    # handle remaining complex amplitude → real power
    if np.iscomplexobj(data):
        data = (data * data.conj()).real
    if dB:
        data = 10 * np.log10(data)
    return data, times, freqs
def _tfr_from_mt(x_mt, weights):
    """Combine complex multitaper coefficients across tapers into power.

    Parameters
    ----------
    x_mt : array, shape (..., n_tapers, n_freqs, n_times)
        The complex-valued multitaper coefficients.
    weights : array, shape (n_tapers, n_freqs)
        The weights to use to combine the tapered estimates.

    Returns
    -------
    tfr : array, shape (..., n_freqs, n_times)
        The time-frequency power estimates.
    """
    # trailing singleton axis lets the weights broadcast over time (and over
    # any axes that precede the taper axis)
    w = weights[..., np.newaxis]
    tapered = w * x_mt
    # squared magnitude of the weighted coefficients, summed over tapers
    power = (tapered * tapered.conj()).real.sum(axis=-3)
    # normalize by the summed squared magnitude of the weights
    weight_norm = (w * w.conj()).real.sum(axis=-3)
    power *= 2 / weight_norm
    return power