Spaces:
Running
Running
| """Tools for data interpolation.""" | |
| # Authors: The MNE-Python contributors. | |
| # License: BSD-3-Clause | |
| # Copyright the MNE-Python contributors. | |
| from itertools import chain | |
| import numpy as np | |
| from scipy.sparse.csgraph import connected_components | |
| from .._fiff.meas_info import create_info | |
| from ..epochs import BaseEpochs, EpochsArray | |
| from ..evoked import Evoked, EvokedArray | |
| from ..io import BaseRaw, RawArray | |
| from ..transforms import _cart_to_sph, _sph_to_cart | |
| from ..utils import _ensure_int, _validate_type | |
def equalize_bads(insts, interp_thresh=1.0, copy=True):
    """Give every instance of a list the same set of bad channels.

    A channel marked as bad in at least one instance is either interpolated
    wherever it is bad, or marked as bad in all instances, depending on the
    fraction of the total duration during which it is good. Afterwards the
    instances share one list of bads and can therefore be concatenated.

    Parameters
    ----------
    insts : list
        The instances (Evoked, Epochs or Raw) to process. Each instance
        should have its bad channels marked.
    interp_thresh : float
        Value between 0 and 1 (default 1). A channel whose good fraction,
        weighted by the duration of each instance, is at least this value is
        interpolated in the instances where it is bad; a channel below the
        threshold stays marked as bad everywhere. With 1 no channel is ever
        interpolated, with 0 every bad channel is interpolated.
    copy : bool
        If True the instances are copied before being modified, otherwise
        they are modified in place.

    Returns
    -------
    insts_bads : list
        The instances, all with the same channel(s) marked as bad, possibly
        with some formerly bad channels interpolated.

    Raises
    ------
    ValueError
        If ``interp_thresh`` is not between 0 and 1.
    """
    if not 0 <= interp_thresh <= 1:
        raise ValueError(f"interp_thresh must be between 0 and 1, got {interp_thresh}")

    # every channel that is bad in at least one instance
    all_bads = list(set(chain.from_iterable(inst.info["bads"] for inst in insts)))

    # weight of an instance: its number of samples (times epochs for Epochs);
    # the type of the first instance decides for the whole list
    is_epochs = isinstance(insts[0], BaseEpochs)
    durations = [
        len(inst) * len(inst.times) if is_epochs else len(inst.times)
        for inst in insts
    ]
    total_duration = np.sum(durations)

    # fraction of the total duration during which each channel is good
    good_times = [
        sum(
            duration
            for duration, inst in zip(durations, insts)
            if ch_name not in inst.info["bads"]
        )
        / total_duration
        for ch_name in all_bads
    ]

    # channels that are not good often enough remain bad in all instances
    bads_keep = [
        ch_name
        for ch_name, good_time in zip(all_bads, good_times)
        if good_time < interp_thresh
    ]

    if copy:
        insts = [inst.copy() for inst in insts]

    keep = set(bads_keep)
    for inst in insts:
        # interpolate only if this instance has bads that are not kept
        if set(inst.info["bads"]) - keep:
            inst.interpolate_bads(exclude=bads_keep)
        inst.info["bads"] = bads_keep

    return insts
def interpolate_bridged_electrodes(inst, bridged_idx, bad_limit=4):
    """Interpolate bridged electrode pairs.

    Because bridged electrodes contain brain signal, it's just that the
    signal is spatially smeared between the two electrodes, we can
    make a virtual channel midway between the bridged pairs and use
    that to aid in interpolation rather than completely discarding the
    data from the two channels.

    Parameters
    ----------
    inst : instance of Epochs, Evoked, or Raw
        The data object with channels that are to be interpolated. It is
        modified in place.
    bridged_idx : list of tuple
        The indices of channels marked as bridged with each bridged
        pair stored as a tuple. If empty, ``inst`` is returned unchanged.
    bad_limit : int
        The maximum number of electrodes that can be bridged together
        (included) and interpolated. Above this number, an error will be
        raised.

        .. versionadded:: 1.2

    Returns
    -------
    inst : instance of Epochs, Evoked, or Raw
        The modified data object.

    Raises
    ------
    ValueError
        If ``bad_limit`` is not strictly positive.
    RuntimeError
        If ``inst`` has no montage, if the montage positions are not in the
        ``head`` coordinate frame, or if a group of electrodes bridged
        together contains more than ``bad_limit`` channels.

    See Also
    --------
    mne.preprocessing.compute_bridged_electrodes

    Notes
    -----
    The bad channels of ``inst`` are temporarily replaced while the bridged
    channels are interpolated. The original list is put back before the
    function exits, including when an error is raised.
    """
    _validate_type(inst, (BaseRaw, BaseEpochs, Evoked))
    bad_limit = _ensure_int(bad_limit, "bad_limit")
    if bad_limit <= 0:
        raise ValueError(
            "Argument 'bad_limit' should be a strictly positive "
            f"integer. Provided {bad_limit} is invalid."
        )
    montage = inst.get_montage()
    if montage is None:
        raise RuntimeError("No channel positions found in ``inst``")
    pos = montage.get_positions()
    if pos["coord_frame"] != "head":
        raise RuntimeError(
            f"Montage channel positions must be in ``head`` got {pos['coord_frame']}"
        )

    # no bridged pair: nothing to interpolate
    if len(bridged_idx) == 0:
        return inst

    # build an undirected graph whose nodes are the bridged channel indices
    nodes = sorted(set(chain.from_iterable(bridged_idx)))
    G_dense = np.zeros((len(nodes), len(nodes)))
    # fill the edges with a weight of 1
    for bridge in bridged_idx:
        idx0 = np.searchsorted(nodes, bridge[0])
        idx1 = np.searchsorted(nodes, bridge[1])
        G_dense[idx0, idx1] = 1
        G_dense[idx1, idx0] = 1
    # each connected component is a group of electrodes bridged together
    n_groups, labels = connected_components(G_dense, directed=False)
    groups_idx = [
        [nodes[j] for j in np.flatnonzero(labels == label)]
        for label in range(n_groups)
    ]
    ch_names = inst.info.ch_names
    groups_names = [[ch_names[idx] for idx in group_idx] for group_idx in groups_idx]

    # refuse bridged areas that include too many electrodes; this is checked
    # before ``inst`` is touched so that a failure leaves it unchanged
    for group_names in groups_names:
        if len(group_names) > bad_limit:
            raise RuntimeError(
                f"The channels {', '.join(group_names)} are bridged together "
                "and form a large area of bridged electrodes. Interpolation "
                "might be inaccurate."
            )

    # store bads orig to put back at the end
    bads_orig = inst.info["bads"]
    inst.info["bads"] = list()
    try:
        # make one virtual channel per group of bridged electrodes
        virtual_chs = dict()
        bads = set()
        for group_no, group_names in enumerate(groups_names, start=1):
            virtual_name = f"virtual {group_no}"
            bads.update(group_names)
            # compute centroid position in spherical "head" coordinates
            pos_virtual = _find_centroid_sphere(pos["ch_pos"], group_names)
            # create the virtual channel info and set the position
            virtual_info = create_info([virtual_name], inst.info["sfreq"], "eeg")
            virtual_info["chs"][0]["loc"][:3] = pos_virtual
            # the virtual channel is the average of the bridged channels
            data = inst.get_data(picks=group_names)
            if isinstance(inst, BaseRaw):
                data = np.average(data, axis=0).reshape(1, -1)
                virtual_ch = RawArray(data, virtual_info, first_samp=inst.first_samp)
            elif isinstance(inst, BaseEpochs):
                # channels are on axis 1 for epochs, keep the epochs axis
                data = np.average(data, axis=1).reshape(len(data), 1, -1)
                virtual_ch = EpochsArray(data, virtual_info, tmin=inst.tmin)
            else:  # evoked
                data = np.average(data, axis=0).reshape(1, -1)
                virtual_ch = EvokedArray(
                    data,
                    virtual_info,
                    tmin=inst.tmin,
                    nave=inst.nave,
                    kind=inst.kind,
                )
            virtual_chs[virtual_name] = virtual_ch

        # add the virtual channels
        inst.add_channels(list(virtual_chs.values()), force_update_info=True)

        # use the virtual channels to interpolate
        inst.info["bads"] = list(bads)
        inst.interpolate_bads()

        # drop virtual channels
        inst.drop_channels(list(virtual_chs.keys()))
    finally:
        # always put the original bads back, even if a step above failed
        inst.info["bads"] = bads_orig
    return inst
def _find_centroid_sphere(ch_pos, group_names):
    """Compute the centroid position between N electrodes.

    The centroid is determined in spherical "head" coordinates, which is
    more accurate than averaging in cartesian coordinates only, as the
    latter cuts through the scalp.

    The locations are averaged in cartesian coordinates, the result is
    converted to spherical coordinates, and its radius is replaced with the
    average radius of the N points in spherical coordinates.

    Parameters
    ----------
    ch_pos : OrderedDict
        The position of all channels in cartesian coordinates.
    group_names : list | tuple
        The name of the N electrodes used to determine the centroid.

    Returns
    -------
    pos_centroid : array of shape (3,)
        The position of the centroid in cartesian coordinates.
    """
    cart = np.array([ch_pos[name] for name in group_names])

    # the radius is the first spherical component; average it over electrodes
    mean_radius = np.average(_cart_to_sph(cart), axis=0)[0]

    # direction of the centroid, taken from the cartesian average
    centroid_sph = _cart_to_sph(np.average(cart, axis=0))

    # overwrite the radius, then convert back to cartesian
    centroid_sph[0, 0] = mean_radius
    return _sph_to_cart(centroid_sph)[0]