Spaces:
Running
Running
| # Authors: The MNE-Python contributors. | |
| # License: BSD-3-Clause | |
| # Copyright the MNE-Python contributors. | |
| import numpy as np | |
| from numpy.polynomial.legendre import legval | |
| from scipy.interpolate import RectBivariateSpline | |
| from scipy.linalg import pinv | |
| from scipy.spatial.distance import pdist, squareform | |
| from .._fiff.meas_info import _simplify_info | |
| from .._fiff.pick import pick_channels, pick_info, pick_types | |
| from ..surface import _normalize_vectors | |
| from ..utils import _validate_type, logger, verbose, warn | |
| def _calc_h(cosang, stiffness=4, n_legendre_terms=50): | |
| """Calculate spherical spline h function between points on a sphere. | |
| Parameters | |
| ---------- | |
| cosang : array-like | float | |
| cosine of angles between pairs of points on a spherical surface. This | |
| is equivalent to the dot product of unit vectors. | |
| stiffness : float | |
| stiffnes of the spline. Also referred to as ``m``. | |
| n_legendre_terms : int | |
| number of Legendre terms to evaluate. | |
| """ | |
| factors = [ | |
| (2 * n + 1) / (n ** (stiffness - 1) * (n + 1) ** (stiffness - 1) * 4 * np.pi) | |
| for n in range(1, n_legendre_terms + 1) | |
| ] | |
| return legval(cosang, [0] + factors) | |
| def _calc_g(cosang, stiffness=4, n_legendre_terms=50): | |
| """Calculate spherical spline g function between points on a sphere. | |
| Parameters | |
| ---------- | |
| cosang : array-like of float, shape(n_channels, n_channels) | |
| cosine of angles between pairs of points on a spherical surface. This | |
| is equivalent to the dot product of unit vectors. | |
| stiffness : float | |
| stiffness of the spline. | |
| n_legendre_terms : int | |
| number of Legendre terms to evaluate. | |
| Returns | |
| ------- | |
| G : np.ndrarray of float, shape(n_channels, n_channels) | |
| The G matrix. | |
| """ | |
| factors = [ | |
| (2 * n + 1) / (n**stiffness * (n + 1) ** stiffness * 4 * np.pi) | |
| for n in range(1, n_legendre_terms + 1) | |
| ] | |
| return legval(cosang, [0] + factors) | |
| def _make_interpolation_matrix(pos_from, pos_to, alpha=1e-5): | |
| """Compute interpolation matrix based on spherical splines. | |
| Implementation based on [1] | |
| Parameters | |
| ---------- | |
| pos_from : np.ndarray of float, shape(n_good_sensors, 3) | |
| The positions to interpolate from. | |
| pos_to : np.ndarray of float, shape(n_bad_sensors, 3) | |
| The positions to interpolate. | |
| alpha : float | |
| Regularization parameter. Defaults to 1e-5. | |
| Returns | |
| ------- | |
| interpolation : np.ndarray of float, shape(len(pos_from), len(pos_to)) | |
| The interpolation matrix that maps good signals to the location | |
| of bad signals. | |
| References | |
| ---------- | |
| [1] Perrin, F., Pernier, J., Bertrand, O. and Echallier, JF. (1989). | |
| Spherical splines for scalp potential and current density mapping. | |
| Electroencephalography Clinical Neurophysiology, Feb; 72(2):184-7. | |
| """ | |
| pos_from = pos_from.copy() | |
| pos_to = pos_to.copy() | |
| n_from = pos_from.shape[0] | |
| n_to = pos_to.shape[0] | |
| # normalize sensor positions to sphere | |
| _normalize_vectors(pos_from) | |
| _normalize_vectors(pos_to) | |
| # cosine angles between source positions | |
| cosang_from = pos_from.dot(pos_from.T) | |
| cosang_to_from = pos_to.dot(pos_from.T) | |
| G_from = _calc_g(cosang_from) | |
| G_to_from = _calc_g(cosang_to_from) | |
| assert G_from.shape == (n_from, n_from) | |
| assert G_to_from.shape == (n_to, n_from) | |
| if alpha is not None: | |
| G_from.flat[:: len(G_from) + 1] += alpha | |
| C = np.vstack( | |
| [ | |
| np.hstack([G_from, np.ones((n_from, 1))]), | |
| np.hstack([np.ones((1, n_from)), [[0]]]), | |
| ] | |
| ) | |
| C_inv = pinv(C) | |
| interpolation = np.hstack([G_to_from, np.ones((n_to, 1))]) @ C_inv[:, :-1] | |
| assert interpolation.shape == (n_to, n_from) | |
| return interpolation | |
| def _do_interp_dots(inst, interpolation, goods_idx, bads_idx): | |
| """Dot product of channel mapping matrix to channel data.""" | |
| from ..epochs import BaseEpochs | |
| from ..evoked import Evoked | |
| from ..io import BaseRaw | |
| _validate_type(inst, (BaseRaw, BaseEpochs, Evoked), "inst") | |
| inst._data[..., bads_idx, :] = np.matmul( | |
| interpolation, inst._data[..., goods_idx, :] | |
| ) | |
| def _interpolate_bads_eeg(inst, origin, exclude=None, ecog=False, verbose=None): | |
| if exclude is None: | |
| exclude = list() | |
| bads_idx = np.zeros(len(inst.ch_names), dtype=bool) | |
| goods_idx = np.zeros(len(inst.ch_names), dtype=bool) | |
| picks = pick_types(inst.info, meg=False, eeg=not ecog, ecog=ecog, exclude=exclude) | |
| inst.info._check_consistency() | |
| bads_idx[picks] = [inst.ch_names[ch] in inst.info["bads"] for ch in picks] | |
| if len(picks) == 0 or bads_idx.sum() == 0: | |
| return | |
| goods_idx[picks] = True | |
| goods_idx[bads_idx] = False | |
| pos = inst._get_channel_positions(picks) | |
| # Make sure only EEG are used | |
| bads_idx_pos = bads_idx[picks] | |
| goods_idx_pos = goods_idx[picks] | |
| # test spherical fit | |
| distance = np.linalg.norm(pos - origin, axis=-1) | |
| distance = np.mean(distance / np.mean(distance)) | |
| if np.abs(1.0 - distance) > 0.1: | |
| warn( | |
| "Your spherical fit is poor, interpolation results are " | |
| "likely to be inaccurate." | |
| ) | |
| pos_good = pos[goods_idx_pos] - origin | |
| pos_bad = pos[bads_idx_pos] - origin | |
| logger.info(f"Computing interpolation matrix from {len(pos_good)} sensor positions") | |
| interpolation = _make_interpolation_matrix(pos_good, pos_bad) | |
| logger.info(f"Interpolating {len(pos_bad)} sensors") | |
| _do_interp_dots(inst, interpolation, goods_idx, bads_idx) | |
| def _interpolate_bads_ecog(inst, *, origin, exclude=None, verbose=None): | |
| _interpolate_bads_eeg(inst, origin, exclude=exclude, ecog=True, verbose=verbose) | |
| def _interpolate_bads_meg( | |
| inst, mode="accurate", *, origin, verbose=None, ref_meg=False | |
| ): | |
| return _interpolate_bads_meeg( | |
| inst, mode, ref_meg=ref_meg, eeg=False, origin=origin, verbose=verbose | |
| ) | |
| def _interpolate_bads_nan( | |
| inst, | |
| *, | |
| ch_type, | |
| ref_meg=False, | |
| exclude=(), | |
| verbose=None, | |
| ): | |
| info = _simplify_info(inst.info) | |
| picks_type = pick_types(info, ref_meg=ref_meg, exclude=exclude, **{ch_type: True}) | |
| use_ch_names = [inst.info["ch_names"][p] for p in picks_type] | |
| bads_type = [ch for ch in inst.info["bads"] if ch in use_ch_names] | |
| if len(bads_type) == 0 or len(picks_type) == 0: | |
| return | |
| # select the bad channels to be interpolated | |
| picks_bad = pick_channels(inst.info["ch_names"], bads_type, exclude=[]) | |
| inst._data[..., picks_bad, :] = np.nan | |
| def _interpolate_bads_meeg( | |
| inst, | |
| mode="accurate", | |
| *, | |
| meg=True, | |
| eeg=True, | |
| ref_meg=False, | |
| exclude=(), | |
| origin, | |
| method=None, | |
| verbose=None, | |
| ): | |
| from ..forward import _map_meg_or_eeg_channels | |
| if method is None: | |
| method = {"meg": "MNE", "eeg": "MNE"} | |
| bools = dict(meg=meg, eeg=eeg) | |
| info = _simplify_info(inst.info) | |
| for ch_type, do in bools.items(): | |
| if not do: | |
| continue | |
| kw = dict(meg=False, eeg=False) | |
| kw[ch_type] = True | |
| picks_type = pick_types(info, ref_meg=ref_meg, exclude=exclude, **kw) | |
| picks_good = pick_types(info, ref_meg=ref_meg, exclude="bads", **kw) | |
| use_ch_names = [inst.info["ch_names"][p] for p in picks_type] | |
| bads_type = [ch for ch in inst.info["bads"] if ch in use_ch_names] | |
| if len(bads_type) == 0 or len(picks_type) == 0: | |
| continue | |
| # select the bad channels to be interpolated | |
| picks_bad = pick_channels(inst.info["ch_names"], bads_type, exclude=[]) | |
| # do MNE based interpolation | |
| if ch_type == "eeg": | |
| picks_to = picks_type | |
| bad_sel = np.isin(picks_type, picks_bad) | |
| else: | |
| picks_to = picks_bad | |
| bad_sel = slice(None) | |
| info_from = pick_info(inst.info, picks_good) | |
| info_to = pick_info(inst.info, picks_to) | |
| mapping = _map_meg_or_eeg_channels(info_from, info_to, mode=mode, origin=origin) | |
| mapping = mapping[bad_sel] | |
| _do_interp_dots(inst, mapping, picks_good, picks_bad) | |
| def _interpolate_bads_nirs(inst, exclude=(), verbose=None): | |
| from mne.preprocessing.nirs import _validate_nirs_info | |
| if len(pick_types(inst.info, fnirs=True, exclude=())) == 0: | |
| return | |
| # Returns pick of all nirs and ensures channels are correctly ordered | |
| picks_nirs = _validate_nirs_info(inst.info) | |
| nirs_ch_names = [inst.info["ch_names"][p] for p in picks_nirs] | |
| nirs_ch_names = [ch for ch in nirs_ch_names if ch not in exclude] | |
| bads_nirs = [ch for ch in inst.info["bads"] if ch in nirs_ch_names] | |
| if len(bads_nirs) == 0: | |
| return | |
| picks_bad = pick_channels(inst.info["ch_names"], bads_nirs, exclude=[]) | |
| bads_mask = [p in picks_bad for p in picks_nirs] | |
| chs = [inst.info["chs"][i] for i in picks_nirs] | |
| locs3d = np.array([ch["loc"][:3] for ch in chs]) | |
| dist = pdist(locs3d) | |
| dist = squareform(dist) | |
| for bad in picks_bad: | |
| dists_to_bad = dist[bad] | |
| # Ignore distances to self | |
| dists_to_bad[dists_to_bad == 0] = np.inf | |
| # Ignore distances to other bad channels | |
| dists_to_bad[bads_mask] = np.inf | |
| # Find closest remaining channels for same frequency | |
| closest_idx = np.argmin(dists_to_bad) + (bad % 2) | |
| inst._data[bad] = inst._data[closest_idx] | |
| # TODO: this seems like a bug because it does not respect reset_bads | |
| inst.info["bads"] = [ch for ch in inst.info["bads"] if ch in exclude] | |
| return inst | |
| def _find_seeg_electrode_shaft(pos, tol_shaft=0.002, tol_spacing=1): | |
| # 1) find nearest neighbor to define the electrode shaft line | |
| # 2) find all contacts on the same line | |
| # 3) remove contacts with large distances | |
| dist = squareform(pdist(pos)) | |
| np.fill_diagonal(dist, np.inf) | |
| shafts = list() | |
| shaft_ts = list() | |
| for i, n1 in enumerate(pos): | |
| if any([i in shaft for shaft in shafts]): | |
| continue | |
| n2 = pos[np.argmin(dist[i])] # 1 | |
| # https://mathworld.wolfram.com/Point-LineDistance3-Dimensional.html | |
| shaft_dists = np.linalg.norm( | |
| np.cross((pos - n1), (pos - n2)), axis=1 | |
| ) / np.linalg.norm(n2 - n1) | |
| shaft = np.where(shaft_dists < tol_shaft)[0] # 2 | |
| shaft_prev = None | |
| for _ in range(10): # avoid potential cycles | |
| if np.array_equal(shaft, shaft_prev): | |
| break | |
| shaft_prev = shaft | |
| # compute median shaft line | |
| v = np.median( | |
| [ | |
| pos[i] - pos[j] | |
| for idx, i in enumerate(shaft) | |
| for j in shaft[idx + 1 :] | |
| ], | |
| axis=0, | |
| ) | |
| c = np.median(pos[shaft], axis=0) | |
| # recompute distances | |
| shaft_dists = np.linalg.norm( | |
| np.cross((pos - c), (pos - c + v)), axis=1 | |
| ) / np.linalg.norm(v) | |
| shaft = np.where(shaft_dists < tol_shaft)[0] | |
| ts = np.array([np.dot(c - n0, v) / np.linalg.norm(v) ** 2 for n0 in pos[shaft]]) | |
| shaft_order = np.argsort(ts) | |
| shaft = shaft[shaft_order] | |
| ts = ts[shaft_order] | |
| # only include the largest group with spacing with the error tolerance | |
| # avoid interpolating across spans between contacts | |
| t_diffs = np.diff(ts) | |
| t_diff_med = np.median(t_diffs) | |
| spacing_errors = (t_diffs - t_diff_med) / t_diff_med | |
| groups = list() | |
| group = [shaft[0]] | |
| for j in range(len(shaft) - 1): | |
| if spacing_errors[j] > tol_spacing: | |
| groups.append(group) | |
| group = [shaft[j + 1]] | |
| else: | |
| group.append(shaft[j + 1]) | |
| groups.append(group) | |
| group = [group for group in groups if i in group][0] | |
| ts = ts[np.isin(shaft, group)] | |
| shaft = np.array(group, dtype=int) | |
| shafts.append(shaft) | |
| shaft_ts.append(ts) | |
| return shafts, shaft_ts | |
| def _interpolate_bads_seeg( | |
| inst, exclude=None, tol_shaft=0.002, tol_spacing=1, verbose=None | |
| ): | |
| if exclude is None: | |
| exclude = list() | |
| picks = pick_types(inst.info, meg=False, seeg=True, exclude=exclude) | |
| inst.info._check_consistency() | |
| bads_idx = np.isin(np.array(inst.ch_names)[picks], inst.info["bads"]) | |
| if len(picks) == 0 or bads_idx.sum() == 0: | |
| return | |
| pos = inst._get_channel_positions(picks) | |
| # Make sure only sEEG are used | |
| bads_idx_pos = bads_idx[picks] | |
| shafts, shaft_ts = _find_seeg_electrode_shaft( | |
| pos, tol_shaft=tol_shaft, tol_spacing=tol_spacing | |
| ) | |
| # interpolate the bad contacts | |
| picks_bad = list(np.where(bads_idx_pos)[0]) | |
| for shaft, ts in zip(shafts, shaft_ts): | |
| bads_shaft = np.array([idx for idx in picks_bad if idx in shaft]) | |
| if bads_shaft.size == 0: | |
| continue | |
| goods_shaft = shaft[np.isin(shaft, bads_shaft, invert=True)] | |
| if goods_shaft.size < 4: # cubic spline requires 3 channels | |
| msg = "No shaft" if shaft.size < 4 else "Not enough good channels" | |
| no_shaft_chs = " and ".join(np.array(inst.ch_names)[bads_shaft]) | |
| raise RuntimeError( | |
| f"{msg} found in a line with {no_shaft_chs} " | |
| "at least 3 good channels on the same line " | |
| f"are required for interpolation, {goods_shaft.size} found. " | |
| f"Dropping {no_shaft_chs} is recommended." | |
| ) | |
| logger.debug( | |
| f"Interpolating {np.array(inst.ch_names)[bads_shaft]} using " | |
| f"data from {np.array(inst.ch_names)[goods_shaft]}" | |
| ) | |
| bads_shaft_idx = np.where(np.isin(shaft, bads_shaft))[0] | |
| goods_shaft_idx = np.where(~np.isin(shaft, bads_shaft))[0] | |
| z = inst._data[..., goods_shaft, :] | |
| is_epochs = z.ndim == 3 | |
| if is_epochs: | |
| z = z.swapaxes(0, 1) | |
| z = z.reshape(z.shape[0], -1) | |
| y = np.arange(z.shape[-1]) | |
| out = RectBivariateSpline(x=ts[goods_shaft_idx], y=y, z=z)( | |
| x=ts[bads_shaft_idx], y=y | |
| ) | |
| if is_epochs: | |
| out = out.reshape(bads_shaft.size, inst._data.shape[0], -1) | |
| out = out.swapaxes(0, 1) | |
| inst._data[..., bads_shaft, :] = out | |