Spaces:
Running
Running
| """The check functions.""" | |
| # Authors: The MNE-Python contributors. | |
| # License: BSD-3-Clause | |
| # Copyright the MNE-Python contributors. | |
| import numbers | |
| import operator | |
| import os | |
| import re | |
| from builtins import input # noqa: A004, UP029 | |
| from difflib import get_close_matches | |
| from importlib import import_module | |
| from inspect import signature | |
| from pathlib import Path | |
| import numpy as np | |
| from ..defaults import HEAD_SIZE_DEFAULT, _handle_default | |
| from ..fixes import _compare_version, _median_complex | |
| from ._logging import _record_warnings, _verbose_safe_false, logger, verbose, warn | |
| def _ensure_int(x, name="unknown", must_be="an int", *, extra=""): | |
| """Ensure a variable is an integer.""" | |
| # This is preferred over numbers.Integral, see: | |
| # https://github.com/scipy/scipy/pull/7351#issuecomment-299713159 | |
| extra = f" {extra}" if extra else extra | |
| try: | |
| # someone passing True/False is much more likely to be an error than | |
| # intentional usage | |
| if isinstance(x, bool): | |
| raise TypeError() | |
| x = int(operator.index(x)) | |
| except TypeError: | |
| raise TypeError(f"{name} must be {must_be}{extra}, got {type(x)}") | |
| return x | |
| def _check_integer_or_list(arg, name): | |
| """Validate arguments that should be an integer or a list. | |
| Always returns a list. | |
| """ | |
| if not isinstance(arg, list): | |
| arg = [_ensure_int(arg, name=name, must_be="an integer or a list")] | |
| return arg | |
def check_fname(fname, filetype, endings, endings_err=()):
    """Enforce MNE filename conventions.

    Parameters
    ----------
    fname : path-like
        Name of the file.
    filetype : str
        Type of file. e.g., ICA, Epochs, etc.
    endings : tuple
        Acceptable endings for the filename. A mismatch only emits a warning.
        If empty, no convention check is performed.
    endings_err : tuple
        Obligatory possible endings for the filename. A mismatch raises an
        error.

    Raises
    ------
    OSError
        If ``endings_err`` is non-empty and ``fname`` ends with none of them.
    """

    def _join(options):
        # Renders "a", "a or b", "a, b or c". A single option must not get a
        # dangling " or " prefix, and an empty tuple must not raise.
        if len(options) < 2:
            return "".join(options)
        return " or ".join([", ".join(options[:-1]), options[-1]])

    _validate_type(fname, "path-like", "fname")
    fname = str(fname)
    if len(endings_err) > 0 and not fname.endswith(endings_err):
        raise OSError(
            f"The filename ({fname}) for file type {filetype} must end "
            f"with {_join(endings_err)}"
        )
    # Nothing to compare against, so there is no convention to warn about
    if len(endings) == 0:
        return
    if not fname.endswith(endings):
        warn(
            f"This filename ({fname}) does not conform to MNE naming conventions. "
            f"All {filetype} files should end with {_join(endings)}"
        )
def check_version(library, min_version="0.0", *, strip=True, return_version=False):
    r"""Check minimum library version required.

    Parameters
    ----------
    library : str
        The library name to import. Must have a ``__version__`` property.
    min_version : str
        The minimum version string. Anything that matches
        ``'(\d+ | [a-z]+ | \.)'``. Can also be empty to skip version
        check (just check for library presence).
    strip : bool
        If True (default), then PEP440 development markers like ``.devN``
        will be stripped from the version. This makes it so that
        ``check_version('mne', '1.1')`` will be ``True`` even when on version
        ``'1.1.dev0'`` (prerelease/dev version). This option is provided for
        backward compatibility with the behavior of ``LooseVersion``, and
        diverges from how modern parsing in ``packaging.version.parse`` works.

        .. versionadded:: 1.0
    return_version : bool
        If True (default False), also return the version (can be None if the
        library is missing).

        .. versionadded:: 1.0

    Returns
    -------
    ok : bool
        True if the library exists with at least the specified version.
    version : str | None
        The version. Only returned when ``return_version=True``.
    """
    ok, version = True, None
    try:
        module = import_module(library)
    except ImportError:
        ok = False
    else:
        # An empty/None or "0.0" minimum means "presence is enough"
        needs_compare = min_version and min_version != "0.0"
        # __version__ is only read when it is actually needed
        if needs_compare or return_version:
            version = module.__version__
            if strip:
                version = _strip_dev(version)
        if needs_compare and _compare_version(version, "<", min_version):
            ok = False
    return (ok, version) if return_version else ok
| def _strip_dev(version): | |
| # First capturing group () is what we want to keep, at the beginning: | |
| # | |
| # - at least one numeral, then | |
| # - repeats of {dot, at least one numeral} | |
| # | |
| # The rest (consume to the end of the string) is the stuff we want to cut | |
| # off: | |
| # | |
| # - A period (maybe), then | |
| # - "dev", "rc", or "+", then | |
| # - numerals, periods, dashes, and "a" through "g" (hex chars) | |
| # | |
| # Thanks https://www.regextester.com ! | |
| exp = r"^([0-9]+(?:\.[0-9]+)*)\.?(?:dev|rc|\+)[0-9+a-g\.\-]+$" | |
| match = re.match(exp, version) | |
| return match.groups()[0] if match is not None else version | |
def _require_version(lib, what, version="0.0"):
    """Require a library (optionally with a minimum version) for a purpose.

    Parameters
    ----------
    lib : str
        Name of the library to import.
    what : str
        What the library is needed for; inserted after "is required to" in
        the error message.
    version : str
        Minimum version, forwarded to ``check_version``.

    Raises
    ------
    ImportError
        If ``check_version`` reports the requirement as not satisfied.
    """
    ok, got = check_version(lib, version, return_version=True)
    if ok:
        return
    if got is None:
        why = "package was not found"
    else:
        why = f"got {repr(got)}"
    extra = "" if version == "0.0" else f" (version >= {version})"
    raise ImportError(f"The {lib} package{extra} is required to {what}, {why}")
def _import_h5py():
    """Import and return the ``h5py`` module.

    ``_require_version`` is called first so that a missing package is
    reported through its ImportError message.
    """
    _require_version(lib="h5py", what="read MATLAB files >= v7.3")
    import h5py as module

    return module
def _import_h5io_funcs():
    """Return ``(read_hdf5, write_hdf5)`` backed by the ``h5io`` package.

    The returned writer wraps ``h5io.write_hdf5`` and converts
    ``pathlib.Path`` keys and values to POSIX strings first.

    Returns
    -------
    read_hdf5 : callable
        ``h5io.read_hdf5``, unmodified.
    write_hdf5 : callable
        Wrapper around ``h5io.write_hdf5``.
    """
    h5io = _soft_import("h5io", "HDF5-based I/O")

    # Saving to HDF5 does not support pathlib.Path objects, which are more and more used
    # in MNE-Python.
    # Related issue in h5io: https://github.com/h5io/h5io/issues/113
    def cast_path_to_str(data: dict) -> dict:
        """Cast all paths value to string in data (in place, recursively)."""
        path_keys = []
        for key, value in data.items():
            # Keys are renamed after the loop: the dict must not change size
            # while it is being iterated.
            if isinstance(key, Path):
                path_keys.append(key)
            if isinstance(value, dict):
                cast_path_to_str(value)
            if isinstance(value, Path):
                data[key] = value.as_posix()
        for key in path_keys:
            data[key.as_posix()] = data.pop(key)
        return data

    def write_hdf5(fname, data, *args, **kwargs):
        """Write h5 and cast all paths to string in data."""
        if isinstance(data, dict):
            data = cast_path_to_str(data)
        elif isinstance(data, list):
            # Only dict elements directly inside the list are converted
            for idx, item in enumerate(data):
                if isinstance(item, dict):
                    data[idx] = cast_path_to_str(item)
        h5io.write_hdf5(fname, data, *args, **kwargs)

    return h5io.read_hdf5, write_hdf5
def _import_pymatreader_funcs(purpose):
    """Return ``pymatreader.read_mat``.

    Parameters
    ----------
    purpose : str
        What the package is needed for, forwarded to ``_soft_import`` for its
        error message.

    Returns
    -------
    read_mat : callable
        The ``read_mat`` attribute of the imported module.
    """
    module = _soft_import("pymatreader", purpose)
    return module.read_mat
# adapted from scikit-learn utils/validation.py
def check_random_state(seed):
    """Turn seed into a NumPy random number generator object.

    Parameters
    ----------
    seed : None | int | RandomState | Generator
        - ``None`` or the ``np.random`` module: the ``RandomState`` singleton
          used by ``np.random.mtrand`` is returned.
        - int: a new ``RandomState`` seeded with ``seed`` is returned.
        - ``RandomState`` or ``np.random.Generator`` instance: returned
          unchanged.

    Returns
    -------
    rng : RandomState | Generator
        The random number generator.

    Raises
    ------
    ValueError
        If ``seed`` is none of the above.
    """
    mtrand = np.random.mtrand
    if seed is None or seed is np.random:
        return mtrand._rand
    if isinstance(seed, (int, np.integer)):
        return mtrand.RandomState(seed)
    # Existing generator objects of either flavor are passed through untouched
    if isinstance(seed, (mtrand.RandomState, np.random.Generator)):
        return seed
    raise ValueError(
        f"{seed!r} cannot be used to seed a numpy.random.mtrand.RandomState instance"
    )
def _check_event_id(event_id, events):
    """Check event_id and convert it to a ``{name: int}`` dict.

    Parameters
    ----------
    event_id : None | dict | list | int-like
        If None, the unique values of ``events[:, 2]`` are used. A dict must
        have str keys and int-like values. A list or single value is turned
        into a dict keyed by the string form of each integer.
    events : array
        Only read when ``event_id`` is None (indexed as ``events[:, 2]``).

    Returns
    -------
    event_id : dict
        Mapping from event name to integer event code.
    """
    if event_id is None:
        # a plain list is routed through the list branch below, which
        # converts every entry to int
        event_id = list(np.unique(events[:, 2]))

    if isinstance(event_id, dict):
        for name in event_id.keys():
            _validate_type(name, str, "Event names")
        return {
            name: _ensure_int(code, f"event_id[{name}]")
            for name, code in event_id.items()
        }

    if isinstance(event_id, list):
        codes = [
            _ensure_int(code, f"event_id[{idx}]") for idx, code in enumerate(event_id)
        ]
        return {str(code): code for code in codes}

    code = _ensure_int(event_id, "event_id")
    return {str(code): code}
def _check_fname(
    fname,
    overwrite=False,
    must_exist=False,
    name="File",
    need_dir=False,
    *,
    check_bids_split=False,
    verbose=None,
) -> Path:
    """Check for file existence, and return its absolute path.

    Parameters
    ----------
    fname : path-like
        The path to check.
    overwrite : bool | str
        If falsy and the path exists, an error is raised. The special value
        ``"read"`` accepts an existing path without logging an overwrite
        message.
    must_exist : bool
        If True, the path must exist, be of the requested kind and be
        readable.
    name : str
        Name used for ``fname`` in error messages.
    need_dir : bool
        With ``must_exist=True``, require a directory instead of a file.
    check_bids_split : bool
        If True and ``mne_bids`` can be imported, reject a ``BIDSPath`` whose
        ``split`` attribute is not None.
    verbose : None
        Not used in the visible body.

    Returns
    -------
    fname : pathlib.Path
        The user-expanded, absolute path.
    """
    _validate_type(fname, "path-like", name)
    # Convert before the BIDS check, which still needs the original object
    as_path = Path(fname)
    # special case for MNE-BIDS, check split
    if check_bids_split:
        try:
            from mne_bids import BIDSPath
        except Exception:
            # best effort: without mne_bids there is nothing to check
            pass
        else:
            if isinstance(fname, BIDSPath) and fname.split is not None:
                raise ValueError(
                    f"Passing a BIDSPath {name} with `{fname.split=}` is unsafe as it "
                    "can unexpectedly lead to invalid BIDS split naming. Explicitly "
                    f"set `{name}.split = None` to avoid ambiguity. If you want the "
                    f"old misleading split naming, you can pass `str({name})`."
                )

    fname = as_path.expanduser().absolute()
    del as_path

    if not fname.exists():
        if must_exist:
            raise FileNotFoundError(f'{name} does not exist: "{fname}"')
        return fname

    # From here on the path exists on disk
    if not overwrite:
        raise FileExistsError(
            "Destination file exists. Please use option "
            '"overwrite=True" to force overwriting.'
        )
    if overwrite != "read":
        logger.info("Overwriting existing file.")

    if must_exist:
        if need_dir and not fname.is_dir():
            raise OSError(f"Need a directory for {name} but found a file at {fname}")
        if not need_dir and not fname.is_file():
            raise OSError(f"Need a file for {name} but found a directory at {fname}")
        if not os.access(fname, os.R_OK):
            raise PermissionError(f"{name} does not have read permissions: {fname}")
    return fname
def _check_subject(
    first,
    second,
    *,
    raise_error=True,
    first_kind="class subject attribute",
    second_kind="input subject",
):
    """Resolve a subject name from two candidate sources.

    Parameters
    ----------
    first : str | None
        First candidate (by default described as the class attribute).
    second : str | None
        Second candidate (by default described as the user input). It takes
        precedence when not None.
    raise_error : bool
        If True, raise when both candidates are None.
    first_kind : str
        Description of ``first`` for error messages.
    second_kind : str
        Description of ``second`` for error messages.

    Returns
    -------
    subject : str | None
        ``second`` if given, else ``first``, else None.

    Raises
    ------
    ValueError
        If both candidates are given but differ, or if both are None and
        ``raise_error`` is True.
    """
    if second is None:
        if first is not None:
            _validate_type(
                first, "str", f"Either {second_kind} subject or {first_kind}"
            )
            return first
        if raise_error is True:
            raise ValueError(
                f"Neither {second_kind} subject nor {first_kind} was a string"
            )
        return None

    _validate_type(second, "str", "subject input")
    if first is not None and first != second:
        raise ValueError(
            f"{first_kind} ({repr(first)}) did not match {second_kind} ({second})"
        )
    return second
def _check_preload(inst, msg):
    """Ensure data are preloaded.

    Parameters
    ----------
    inst : object
        The instance to check. TFR, Evoked, Spectrum and source-estimate
        instances are accepted without any check; anything else must have a
        truthy ``preload`` attribute.
    msg : str
        Description of the operation that needs the data, placed at the
        start of the second sentence of the error message.

    Raises
    ------
    RuntimeError
        If ``inst.preload`` is falsy.
    """
    from ..epochs import BaseEpochs
    from ..evoked import Evoked
    from ..source_estimate import _BaseSourceEstimate
    from ..time_frequency import BaseTFR
    from ..time_frequency.spectrum import BaseSpectrum

    # These types are exempt from the preload check
    if isinstance(inst, BaseTFR | Evoked | BaseSpectrum | _BaseSourceEstimate):
        return

    is_epochs = isinstance(inst, BaseEpochs)
    name = "epochs" if is_epochs else "raw"
    if not inst.preload:
        raise RuntimeError(
            "By default, MNE does not load data into main memory to "
            "conserve resources. " + msg + f" requires {name} data to be "
            "loaded. Use preload=True (or string) in the constructor or "
            f"{name}.load_data()."
        )
    if is_epochs:
        inst._handle_empty("raise", msg)
def _check_compensation_grade(info1, info2, name1, name2="data", ch_names=None):
    """Ensure that objects have same compensation_grade.

    Parameters
    ----------
    info1, info2 : Info | None
        The measurement infos to compare. If either one is None, the function
        returns without checking anything.
    name1, name2 : str
        Names of the two objects, used in the error message.
    ch_names : list of str | None
        If not None, the comparison is done on copies of the infos that were
        passed through ``pick_channels`` / ``pick_info`` with these names.

    Raises
    ------
    RuntimeError
        If the two compensation grades differ.
    """
    from .._fiff.compensator import get_current_comp
    from .._fiff.meas_info import Info
    from .._fiff.pick import pick_channels, pick_info

    # Bail out as soon as one of the infos is missing
    for t_info in (info1, info2):
        if t_info is None:
            return
        assert isinstance(t_info, Info), t_info  # or internal code is wrong

    if ch_names is not None:
        # Work on copies so that the caller's infos are left untouched
        info1 = info1.copy()
        info2 = info2.copy()

        # pick channels
        for t_info in [info1, info2]:
            # Any stored compensation entries are cleared before picking
            if t_info["comps"]:
                with t_info._unlock():
                    t_info["comps"] = []
            picks = pick_channels(t_info["ch_names"], ch_names, ordered=False)
            # NOTE(review): copy=False presumably modifies t_info in place,
            # since the return value is discarded -- confirm in pick_info
            pick_info(t_info, picks, copy=False)
    # "or 0" here aliases None -> 0, as they are equivalent
    grade1 = get_current_comp(info1) or 0
    grade2 = get_current_comp(info2) or 0

    # perform check
    if grade1 != grade2:
        raise RuntimeError(
            f"Compensation grade of {name1} ({grade1}) and {name2} ({grade2}) "
            "do not match"
        )
| def _soft_import(name, purpose, strict=True, *, min_version=None): | |
| """Import soft dependencies, providing informative errors on failure. | |
| Parameters | |
| ---------- | |
| name : str | |
| Name of the module to be imported. For example, 'pandas'. | |
| purpose : str | |
| A very brief statement (formulated as a noun phrase) explaining what | |
| functionality the package provides to MNE-Python. | |
| strict : bool | |
| Whether to raise an error if module import fails. | |
| """ | |
| # Mapping import namespaces to their pypi package name | |
| pip_name = dict( | |
| sklearn="scikit-learn", | |
| mne_bids="mne-bids", | |
| mne_nirs="mne-nirs", | |
| mne_features="mne-features", | |
| mne_qt_browser="mne-qt-browser", | |
| mne_connectivity="mne-connectivity", | |
| mne_gui_addons="mne-gui-addons", | |
| pyvista="pyvistaqt", | |
| ).get(name, name) | |
| got_version = None | |
| try: | |
| mod = import_module(name) | |
| except (ImportError, ModuleNotFoundError): | |
| mod = False | |
| else: | |
| have, got_version = check_version( | |
| name, | |
| min_version=min_version, | |
| return_version=True, | |
| ) | |
| if not have: | |
| mod = False | |
| if mod is False and strict: | |
| extra = "" if min_version is None else f">={min_version}" | |
| if got_version is not None: | |
| extra += f" (found version {got_version})" | |
| raise RuntimeError( | |
| f"For {purpose} to work, the module {name}{extra} is needed, " | |
| "but it could not be imported. Use the following installation method " | |
| "appropriate for your environment:\n\n" | |
| f" pip install {pip_name}\n" | |
| f" conda install -c conda-forge {pip_name}" | |
| ) | |
| return mod | |
def _check_pandas_installed(strict=True):
    """Soft-import ``pandas``.

    Parameters
    ----------
    strict : bool
        Forwarded to ``_soft_import``: whether a failed import raises.

    Returns
    -------
    mod : module | False
        Whatever ``_soft_import`` returns.
    """
    module = _soft_import("pandas", "dataframe integration", strict=strict)
    return module
def _check_eeglabio_installed(strict=True):
    """Soft-import ``eeglabio``.

    Parameters
    ----------
    strict : bool
        Forwarded to ``_soft_import``: whether a failed import raises.

    Returns
    -------
    mod : module | False
        Whatever ``_soft_import`` returns.
    """
    module = _soft_import("eeglabio", "exporting to EEGLab", strict=strict)
    return module
def _check_edfio_installed(strict=True):
    """Soft-import ``edfio``.

    Parameters
    ----------
    strict : bool
        Forwarded to ``_soft_import``: whether a failed import raises.

    Returns
    -------
    mod : module | False
        Whatever ``_soft_import`` returns.
    """
    module = _soft_import("edfio", "exporting to EDF", strict=strict)
    return module
def _check_pybv_installed(strict=True):
    """Soft-import ``pybv``.

    Parameters
    ----------
    strict : bool
        Forwarded to ``_soft_import``: whether a failed import raises.

    Returns
    -------
    mod : module | False
        Whatever ``_soft_import`` returns.
    """
    module = _soft_import("pybv", "exporting to BrainVision", strict=strict)
    return module
def _check_pymatreader_installed(strict=True):
    """Soft-import ``pymatreader``.

    Parameters
    ----------
    strict : bool
        Forwarded to ``_soft_import``: whether a failed import raises.

    Returns
    -------
    mod : module | False
        Whatever ``_soft_import`` returns.
    """
    purpose = "loading v7.3 (HDF5) .MAT files"
    return _soft_import("pymatreader", purpose, strict=strict)
| def _check_pandas_index_arguments(index, valid): | |
| """Check pandas index arguments.""" | |
| if index is None: | |
| return | |
| if isinstance(index, str): | |
| index = [index] | |
| if not isinstance(index, list): | |
| raise TypeError( | |
| "index must be `None` or a string or list of strings, got type " | |
| f"{type(index)}." | |
| ) | |
| invalid = set(index) - set(valid) | |
| if invalid: | |
| plural = ("is not a valid option", "are not valid options")[ | |
| int(len(invalid) > 1) | |
| ] | |
| raise ValueError( | |
| '"{}" {}. Valid index options are `None`, "{}".'.format( | |
| '", "'.join(invalid), plural, '", "'.join(valid) | |
| ) | |
| ) | |
| return index | |
| def _check_time_format(time_format, valid, meas_date=None): | |
| """Check time_format argument.""" | |
| if time_format not in valid and time_format is not None: | |
| valid_str = '", "'.join(valid) | |
| raise ValueError( | |
| f'"{time_format}" is not a valid time format. Valid options are ' | |
| f'"{valid_str}" and None.' | |
| ) | |
| # allow datetime only if meas_date available | |
| if time_format == "datetime" and meas_date is None: | |
| warn( | |
| "Cannot convert to Datetime when raw.info['meas_date'] is " | |
| "None. Falling back to Timedelta." | |
| ) | |
| time_format = "timedelta" | |
| return time_format | |
def _check_ch_locs(info, picks=None, ch_type=None):
    """Check if channel locations exist.

    Parameters
    ----------
    info : Info
        `~mne.Info` instance.
    picks : list of int
        Channel indices to consider. If provided, ``ch_type`` must be ``None``.
    ch_type : str | None
        The channel type to restrict the check to. If ``None``, check all
        channel types. If provided, ``picks`` must be ``None``.

    Returns
    -------
    has_locs : bool
        False if the first three ``loc`` entries of the considered channels
        are all zero, all non-finite, or all close to zero; True otherwise.

    Raises
    ------
    ValueError
        If both ``picks`` and ``ch_type`` are given.
    """
    from .._fiff.pick import _picks_to_idx, pick_info

    if picks is not None and ch_type is not None:
        raise ValueError("Either picks or ch_type may be provided, not both")

    # Restrict the info to the requested channels, if any
    if picks is None and ch_type is not None:
        picks = _picks_to_idx(info=info, picks=ch_type, none=ch_type)
    if picks is not None:
        info = pick_info(info=info, sel=picks)

    positions = np.array([ch["loc"][:3] for ch in info["chs"]])
    all_zero = (positions == 0).all()
    none_finite = (~np.isfinite(positions)).all()
    all_near_zero = np.allclose(positions, 0.0)
    return not (all_zero or none_finite or all_near_zero)
| def _is_numeric(n): | |
| return isinstance(n, numbers.Number) | |
class _IntLike:
    """Helper whose instances can be the second argument of ``isinstance``.

    ``isinstance(x, _IntLike())`` is True exactly when ``_ensure_int(x)``
    does not raise ``TypeError``.
    """

    def __instancecheck__(self, other):
        # isinstance() looks this hook up on the type of its second argument,
        # so it is found when an *instance* of this class is passed.
        try:
            _ensure_int(other)
        except TypeError:
            return False
        return True
# Instance (not class) so that ``isinstance(x, int_like)`` dispatches to
# ``_IntLike.__instancecheck__``.
int_like = _IntLike()
# Classes accepted by ``isinstance`` checks for path-like arguments.
path_like = (str, Path, os.PathLike)
| class _Callable: | |
| def __instancecheck__(cls, other): | |
| return callable(other) | |
| class _Sparse: | |
| def __instancecheck__(cls, other): | |
| from scipy import sparse | |
| return sparse.issparse(other) | |
# String aliases accepted by ``_validate_type`` mapped to the tuple that is
# handed to ``isinstance``. Entries are either real classes or helper
# instances (``int_like``, ``_Callable()``, ``_Sparse()``) that implement
# ``__instancecheck__``.
_multi = {
    "str": (str,),
    "numeric": (np.floating, float, int_like),
    "path-like": path_like,
    "int-like": (int_like,),
    "callable": (_Callable(),),
    "array-like": (list, tuple, set, np.ndarray),
    "sparse": (_Sparse(),),
}
| def _validate_type(item, types=None, item_name=None, type_name=None, *, extra=""): | |
| """Validate that `item` is an instance of `types`. | |
| Parameters | |
| ---------- | |
| item : object | |
| The thing to be checked. | |
| types : type | str | tuple of types | tuple of str | |
| The types to be checked against. | |
| If str, must be one of {'int', 'int-like', 'str', 'numeric', 'info', | |
| 'path-like', 'callable', 'array-like'}. | |
| If a tuple of str is passed, use 'int-like' and not 'int' for integers. | |
| item_name : str | None | |
| Name of the item to show inside the error message. | |
| type_name : str | None | |
| Possible types to show inside the error message that the checked item | |
| can be. | |
| extra : str | |
| Extra text to append to the warning. | |
| """ | |
| if types == "int": | |
| _ensure_int(item, name=item_name, extra=extra) | |
| return # terminate prematurely | |
| elif types == "info": | |
| from .._fiff.meas_info import Info as types | |
| if not isinstance(types, list | tuple): | |
| types = [types] | |
| check_types = sum( | |
| ( | |
| ( | |
| (type(None),) | |
| if type_ is None | |
| else (type_,) | |
| if not isinstance(type_, str) | |
| else _multi[type_] | |
| ) | |
| for type_ in types | |
| ), | |
| (), | |
| ) | |
| extra = f" {extra}" if extra else extra | |
| if not isinstance(item, check_types): | |
| if type_name is None: | |
| type_name = [ | |
| ( | |
| "None" | |
| if cls_ is None | |
| else cls_.__name__ | |
| if not isinstance(cls_, str) | |
| else cls_ | |
| ) | |
| for cls_ in types | |
| ] | |
| if len(type_name) == 1: | |
| type_name = type_name[0] | |
| elif len(type_name) == 2: | |
| type_name = " or ".join(type_name) | |
| else: | |
| type_name[-1] = "or " + type_name[-1] | |
| type_name = ", ".join(type_name) | |
| _item_name = "Item" if item_name is None else item_name | |
| _item_type = type(item) if item is not None else item | |
| raise TypeError( | |
| f"{_item_name} must be an instance of {type_name}{extra}, " | |
| f"got {_item_type} instead." | |
| ) | |
| def _check_range(val, min_val, max_val, name, min_inclusive=True, max_inclusive=True): | |
| """Check that item is within range. | |
| Parameters | |
| ---------- | |
| val : int | float | |
| The value to be checked. | |
| min_val : int | float | |
| The minimum value allowed. | |
| max_val : int | float | |
| The maximum value allowed. | |
| name : str | |
| The name of the value. | |
| min_inclusive : bool | |
| Whether ``val`` is allowed to be ``min_val``. | |
| max_inclusive : bool | |
| Whether ``val`` is allowed to be ``max_val``. | |
| """ | |
| below_min = val < min_val if min_inclusive else val <= min_val | |
| above_max = val > max_val if max_inclusive else val >= max_val | |
| if below_min or above_max: | |
| error_str = f"The value of {name} must be between {min_val} " | |
| if min_inclusive: | |
| error_str += "inclusive " | |
| error_str += f"and {max_val}" | |
| if max_inclusive: | |
| error_str += "inclusive " | |
| raise ValueError(error_str) | |
def _path_like(item):
    """Report whether `item` is `path-like`.

    Parameters
    ----------
    item : object
        The thing to be checked.

    Returns
    -------
    bool
        ``True`` if `item` passes ``_validate_type(..., "path-like")``;
        ``False`` otherwise.
    """
    try:
        _validate_type(item, types="path-like")
    except TypeError:
        return False
    else:
        return True
| def _check_if_nan(data, msg=" to be plotted"): | |
| """Raise if any of the values are NaN.""" | |
| if not np.isfinite(data).all(): | |
| raise ValueError(f"Some of the values {msg} are NaN.") | |
def _check_info_inv(info, forward, data_cov=None, noise_cov=None, verbose=None):
    """Return good channels common to forward model and covariance matrices.

    Parameters
    ----------
    info : Info
        Measurement info; its ``"ch_names"`` and ``"bads"`` entries are read.
    forward : dict-like
        Forward operator; only ``forward["info"]["ch_names"]`` is read.
    data_cov : object | None
        Data covariance; its ``ch_names`` attribute and ``"bads"`` entry are
        read when given.
    noise_cov : object | None
        Noise covariance; its ``ch_names`` attribute and ``"bads"`` entry are
        read when given.
    verbose : None
        Not used in the visible body.

    Returns
    -------
    picks : list of int
        Indices into ``info["ch_names"]`` of the channels that are kept.
    """
    from .._fiff.pick import pick_types

    # get a list of all channel names:
    fwd_ch_names = forward["info"]["ch_names"]

    # handle channels from forward model and info:
    ch_names = _compare_ch_names(info["ch_names"], fwd_ch_names, info["bads"])

    # make sure that no reference channels are left:
    ref_chs = pick_types(info, meg=False, ref_meg=True)
    ref_chs = [info["ch_names"][ch] for ch in ref_chs]
    ch_names = [ch for ch in ch_names if ch not in ref_chs]

    # inform about excluding channels:
    # (only logged when the bad lists differ AND a still-selected channel is
    # marked bad in the covariance)
    if (
        data_cov is not None
        and set(info["bads"]) != set(data_cov["bads"])
        and (len(set(ch_names).intersection(data_cov["bads"])) > 0)
    ):
        logger.info(
            'info["bads"] and data_cov["bads"] do not match, '
            "excluding bad channels from both."
        )
    if (
        noise_cov is not None
        and set(info["bads"]) != set(noise_cov["bads"])
        and (len(set(ch_names).intersection(noise_cov["bads"])) > 0)
    ):
        logger.info(
            'info["bads"] and noise_cov["bads"] do not match, '
            "excluding bad channels from both."
        )

    # handle channels from data cov if data cov is not None
    # Note: data cov is supposed to be None in tf_lcmv
    if data_cov is not None:
        ch_names = _compare_ch_names(ch_names, data_cov.ch_names, data_cov["bads"])

    # handle channels from noise cov if noise cov available:
    if noise_cov is not None:
        ch_names = _compare_ch_names(ch_names, noise_cov.ch_names, noise_cov["bads"])

    # inform about excluding any channels apart from bads and reference
    # (list + list builds a new list, so the += below do not touch
    # info["bads"])
    all_bads = info["bads"] + ref_chs
    if data_cov is not None:
        all_bads += data_cov["bads"]
    if noise_cov is not None:
        all_bads += noise_cov["bads"]
    dropped_nonbads = set(info["ch_names"]) - set(ch_names) - set(all_bads)
    if dropped_nonbads:
        logger.info(
            f"Excluding {len(dropped_nonbads)} channel(s) missing from the "
            "provided forward operator and/or covariance matrices"
        )

    # Translate the surviving names back into indices of info["ch_names"]
    picks = [info["ch_names"].index(k) for k in ch_names if k in info["ch_names"]]
    return picks
| def _compare_ch_names(names1, names2, bads): | |
| """Return channel names of common and good channels.""" | |
| ch_names = [ch for ch in names1 if ch not in bads and ch in names2] | |
| return ch_names | |
| def _check_channels_spatial_filter(ch_names, filters): | |
| """Return data channel indices to be used with spatial filter. | |
| Unlike ``pick_channels``, this respects the order of ch_names. | |
| """ | |
| sel = [] | |
| # first check for channel discrepancies between filter and data: | |
| for ch_name in filters["ch_names"]: | |
| if ch_name not in ch_names: | |
| raise ValueError( | |
| f"The spatial filter was computed with channel {ch_name} " | |
| "which is not present in the data. You should " | |
| "compute a new spatial filter restricted to the " | |
| "good data channels." | |
| ) | |
| # then compare list of channels and get selection based on data: | |
| sel = [ii for ii, ch_name in enumerate(ch_names) if ch_name in filters["ch_names"]] | |
| return sel | |
def _check_rank(rank):
    """Check rank parameter.

    Parameters
    ----------
    rank : None | dict | str
        The rank specification. A string must be ``"full"`` or ``"info"``.

    Returns
    -------
    rank : None | dict | str
        The input, unchanged.

    Raises
    ------
    ValueError
        If ``rank`` is a string other than ``"full"`` or ``"info"``.
    """
    _validate_type(rank, (None, dict, str), "rank")
    is_bad_str = isinstance(rank, str) and rank not in ["full", "info"]
    if is_bad_str:
        raise ValueError(f'rank, if str, must be "full" or "info", got {rank}')
    return rank
def _check_one_ch_type(method, info, forward, data_cov=None, noise_cov=None):
    """Check number of sensor types and presence of noise covariance matrix.

    Parameters
    ----------
    method : object
        Not used in the visible body.
    info : Info
        Measurement info.
    forward : object
        Forward operator, passed on to ``_check_info_inv``.
    data_cov : Covariance | CrossSpectralDensity | None
        Data covariance or cross-spectral density.
    noise_cov : Covariance | CrossSpectralDensity | None
        Noise covariance; must match the kind of ``data_cov`` or be None.

    Returns
    -------
    noise_cov : Covariance | CrossSpectralDensity
        A copy of the input, or an ad hoc covariance if None was given.
    picks : list of int
        The selected channel indices.
    allow_mismatch : bool
        True only when the ad hoc covariance was created here.

    Raises
    ------
    ValueError
        If more than one of mag/grad/eeg is present and ``noise_cov`` is
        None.
    """
    from .._fiff.pick import _contains_ch_type, pick_info
    from ..cov import Covariance, make_ad_hoc_cov
    from ..time_frequency.csd import CrossSpectralDensity

    if isinstance(data_cov, CrossSpectralDensity):
        _validate_type(noise_cov, [None, CrossSpectralDensity], "noise_cov")
        # FIXME
        # For a CSD all of its channels are used and info is not sub-selected
        picks = list(range(len(data_cov.ch_names)))
        info_pick = info
    else:
        _validate_type(noise_cov, [None, Covariance], "noise_cov")
        picks = _check_info_inv(
            info,
            forward,
            data_cov=data_cov,
            noise_cov=noise_cov,
            verbose=_verbose_safe_false(),
        )
        info_pick = pick_info(info, picks)
    # One flag per sensor type; more than one True means mixed sensor types
    ch_types = [_contains_ch_type(info_pick, tt) for tt in ("mag", "grad", "eeg")]
    if sum(ch_types) > 1:
        if noise_cov is None:
            raise ValueError(
                "Source reconstruction with several sensor types"
                " requires a noise covariance matrix to be "
                "able to apply whitening."
            )
    if noise_cov is None:
        noise_cov = make_ad_hoc_cov(info_pick, std=1.0)
        allow_mismatch = True
    else:
        # Copy so that removing "estimator" does not alter the caller's object
        noise_cov = noise_cov.copy()
        if isinstance(noise_cov, Covariance) and "estimator" in noise_cov:
            del noise_cov["estimator"]
        allow_mismatch = False
    _validate_type(noise_cov, (Covariance, CrossSpectralDensity), "noise_cov")
    return noise_cov, picks, allow_mismatch
def _check_depth(depth, kind="depth_mne"):
    """Check depth options.

    Parameters
    ----------
    depth : None | float | dict
        A dict is used as-is. Any other value is wrapped as
        ``dict(exp=...)``, with non-None values converted to ``float``.
    kind : str
        Key passed to ``_handle_default``.

    Returns
    -------
    depth : object
        The result of ``_handle_default(kind, depth)``.
    """
    if isinstance(depth, dict):
        return _handle_default(kind, depth)
    exp = float(depth) if depth is not None else None
    return _handle_default(kind, dict(exp=exp))
| def _check_dict_keys(mapping, valid_keys, key_description, valid_key_source): | |
| """Check that the keys in dictionary are valid against a set list. | |
| Return the input dictionary if it is valid, | |
| otherwise raise a ValueError with a readable error message. | |
| Parameters | |
| ---------- | |
| mapping : dict | |
| The user-provided dict whose keys we want to check. | |
| valid_keys : iterable | |
| The valid keys. | |
| key_description : str | |
| Description of the keys in ``mapping``, e.g., "channel name(s)" or | |
| "annotation(s)". | |
| valid_key_source : str | |
| Description of the ``valid_keys`` source, e.g., "info dict" or | |
| "annotations in the data". | |
| Returns | |
| ------- | |
| mapping | |
| If all keys are valid the input dict is returned unmodified. | |
| """ | |
| missing = set(mapping) - set(valid_keys) | |
| if len(missing): | |
| _is = "are" if len(missing) > 1 else "is" | |
| msg = ( | |
| f"Invalid {key_description} {missing} {_is} not present in " | |
| f"{valid_key_source}" | |
| ) | |
| raise ValueError(msg) | |
| return mapping | |
| def _check_option(parameter, value, allowed_values, extra=""): | |
| """Check the value of a parameter against a list of valid options. | |
| Return the value if it is valid, otherwise raise a ValueError with a | |
| readable error message. | |
| Parameters | |
| ---------- | |
| parameter : str | |
| The name of the parameter to check. This is used in the error message. | |
| value : any type | |
| The value of the parameter to check. | |
| allowed_values : list | |
| The list of allowed values for the parameter. | |
| extra : str | |
| Extra string to append to the invalid value sentence, e.g. | |
| "when using ico mode". | |
| Raises | |
| ------ | |
| ValueError | |
| When the value of the parameter is not one of the valid options. | |
| Returns | |
| ------- | |
| value : any type | |
| The value if it is valid. | |
| """ | |
| if value in allowed_values: | |
| return value | |
| # Prepare a nice error message for the user | |
| extra = f" {extra}" if extra else extra | |
| msg = ( | |
| "Invalid value for the '{parameter}' parameter{extra}. " | |
| "{options}, but got {value!r} instead." | |
| ) | |
| allowed_values = list(allowed_values) # e.g., if a dict was given | |
| if len(allowed_values) == 1: | |
| options = f"The only allowed value is {repr(allowed_values[0])}" | |
| else: | |
| options = "Allowed values are " | |
| if len(allowed_values) == 2: | |
| options += " and ".join(repr(v) for v in allowed_values) | |
| else: | |
| options += ", ".join(repr(v) for v in allowed_values[:-1]) | |
| options += f", and {repr(allowed_values[-1])}" | |
| raise ValueError( | |
| msg.format(parameter=parameter, options=options, value=value, extra=extra) | |
| ) | |
| def _check_all_same_channel_names(instances): | |
| """Check if a collection of instances all have the same channels.""" | |
| ch_names = instances[0].info["ch_names"] | |
| for inst in instances: | |
| if ch_names != inst.info["ch_names"]: | |
| return False | |
| return True | |
| def _check_combine(mode, valid=("mean", "median", "std"), axis=0): | |
| # XXX TODO Possibly de-duplicate with _make_combine_callable of mne/viz/utils.py | |
| if mode == "mean": | |
| def fun(data): | |
| return np.mean(data, axis=axis) | |
| elif mode == "std": | |
| def fun(data): | |
| return np.std(data, axis=axis) | |
| elif mode == "median" or mode == np.median: | |
| def fun(data): | |
| return _median_complex(data, axis=axis) | |
| elif callable(mode): | |
| fun = mode | |
| else: | |
| raise ValueError( | |
| "Combine option must be " | |
| + ", ".join(valid) | |
| + f" or callable, got {mode} (type {type(mode)})." | |
| ) | |
| return fun | |
def _check_src_normal(pick_ori, src):
    """Check that normal orientation is requested only for supported spaces.

    Parameters
    ----------
    pick_ori : object
        The orientation selection; only the value ``"normal"`` is checked.
    src : SourceSpaces
        The source space, whose ``kind`` attribute is inspected.

    Raises
    ------
    RuntimeError
        If ``pick_ori == "normal"`` and ``src.kind`` is neither
        ``"surface"`` nor ``"discrete"``.
    """
    from ..source_space import SourceSpaces

    _validate_type(src, SourceSpaces, "src")
    if pick_ori != "normal":
        return
    if src.kind in ("surface", "discrete"):
        return
    raise RuntimeError(
        "Normal source orientation is supported only for "
        "surface or discrete SourceSpaces, got type "
        f"{src.kind}"
    )
| def _check_stc_units(stc, threshold=1e-7): # 100 nAm threshold for warning | |
| max_cur = np.max(np.abs(stc.data)) | |
| if max_cur > threshold: | |
| warn( | |
| f"The maximum current magnitude is {1e9 * max_cur:.1f} nAm, which is very " | |
| "large. Are you trying to apply the forward model to noise-normalized " | |
| "(dSPM, sLORETA, or eLORETA) values? The result will only be " | |
| "correct if currents (in units of Am) are used." | |
| ) | |
def _check_qt_version(*, return_api=False, check_usable_display=True):
    """Check if Qt is installed.

    Parameters
    ----------
    return_api : bool
        If True, return ``(version, api)`` instead of only ``version``.
    check_usable_display : bool
        If True, also try ``_init_mne_qtapp()``. If that raises, Qt is
        reported as unavailable.

    Returns
    -------
    version
        ``QtCore.__version__`` if present, else ``QtCore.QT_VERSION_STR``.
        None if qtpy could not be imported or the app could not be created.
    api
        ``qtpy.API_NAME`` or None. Only returned when ``return_api`` is True.
    """
    from ..viz.backends._utils import _init_mne_qtapp

    try:
        from qtpy import API_NAME as api
        from qtpy import QtCore
    except Exception:
        qt_usable = False
    else:
        qt_usable = True

    if qt_usable:
        try:  # pyside
            version = QtCore.__version__
        except AttributeError:
            version = QtCore.QT_VERSION_STR
        # Having Qt installed is not enough -- sometimes the app is unusable
        # for example because there is no usable display (e.g., on a server),
        # so we have to try instantiating one to actually know.
        if check_usable_display:
            try:
                _init_mne_qtapp()
            except Exception:
                qt_usable = False

    if not qt_usable:
        api = version = None
    return (version, api) if return_api else version
def _check_sphere(sphere, info=None, sphere_units="m"):
    """Convert a ``sphere`` specification into a length-4 float array.

    Parameters
    ----------
    sphere : None | str | list of str | ConductorModel | float | array-like
        How the head sphere is specified:

        - None: ``HEAD_SIZE_DEFAULT`` is used, unless ``info`` is given and
          ``get_fitting_dig(info, "extra")`` succeeds, in which case it is
          handled like ``"auto"``.
        - ``"eeglab"``: center and radius are computed from the montage
          positions of Fpz, Oz, T7 and T8.
        - any other allowed string, or a list of strings: passed as
          ``dig_kinds`` to ``fit_sphere_to_headshape``.
        - ConductorModel: its ``r0`` and the ``rad`` of its first layer.
        - scalar: taken as the radius of a sphere centered at the origin.
        - array-like of shape (4,): used as ``(x, y, z, radius)``.
    info : object | None
        Measurement info. Required for ``"eeglab"`` (enforced with an
        ``assert``) and passed to the fitting helpers.
    sphere_units : str
        ``"m"`` or ``"mm"``; applies to numeric ``sphere`` input. It is
        overwritten with ``"m"`` in the eeglab, fitted and ConductorModel
        branches.

    Returns
    -------
    sphere : ndarray, shape (4,)
        ``(x, y, z, radius)`` as floats; values given in mm are divided by
        1000.

    Raises
    ------
    ValueError
        If no montage is set or a required channel is missing for
        ``"eeglab"``, if a ConductorModel is not spherical or has no layers,
        or if the result does not have shape (4,). ``_check_option`` is also
        called on string ``sphere`` values and on ``sphere_units``.
    """
    from ..bem import ConductorModel, fit_sphere_to_headshape, get_fitting_dig

    if sphere is None:
        sphere = HEAD_SIZE_DEFAULT
        if info is not None:
            # Decide if we have enough dig points to do the auto fit
            try:
                get_fitting_dig(info, "extra", verbose="error")
            except (RuntimeError, ValueError):
                # Fall back to the default head size set above.
                pass
            else:
                sphere = "auto"

    if isinstance(sphere, str):
        _check_option(
            "sphere", sphere, ("auto", "eeglab", "extra", "eeg", "cardinal", "hpi")
        )

    if isinstance(sphere, str) and sphere == "eeglab":
        assert info is not None

        # We need coordinates for the 2D plane formed by
        # Fpz<->Oz and T7<->T8, as this plane will be the horizon (i.e. it
        # will determine the location of the head circle).
        #
        # We implement some special-handling in case Fpz is missing, as this seems to be
        # a quite common situation in numerous EEG labs.
        montage = info.get_montage()
        if montage is None:
            raise ValueError(
                'No montage was set on your data, but sphere="eeglab" can only work if '
                "digitization points for the EEG channels are available. Consider "
                "calling set_montage() to apply a montage."
            )
        ch_pos = montage.get_positions()["ch_pos"]
        horizon_ch_names = ("Fpz", "Oz", "T7", "T8")

        if "FPz" in ch_pos:  # "fix" naming
            ch_pos["Fpz"] = ch_pos["FPz"]
            del ch_pos["FPz"]
        elif "Fpz" not in ch_pos and "Oz" in ch_pos:
            logger.info(
                "Approximating Fpz location by mirroring Oz along the X and Y axes."
            )
            # This assumes Fpz and Oz have the same Z coordinate
            # NOTE(review): relies on ch_pos values supporting elementwise
            # multiplication by a list (presumably NumPy arrays) -- confirm.
            ch_pos["Fpz"] = ch_pos["Oz"] * [-1, -1, 1]

        # Every horizon channel must be present by now; report the first one
        # that is missing.
        for ch_name in horizon_ch_names:
            if ch_name not in ch_pos:
                msg = (
                    f'sphere="eeglab" requires digitization points of the following '
                    f"electrode locations in the data: {', '.join(horizon_ch_names)}, "
                    f"but could not find: {ch_name}"
                )
                if ch_name == "Fpz":
                    msg += ", and was unable to approximate its location from Oz"
                raise ValueError(msg)

        # Calculate the radius from: T7<->T8, Fpz<->Oz
        radius = np.abs(
            [
                ch_pos["T7"][0],  # X axis
                ch_pos["T8"][0],  # X axis
                ch_pos["Fpz"][1],  # Y axis
                ch_pos["Oz"][1],  # Y axis
            ]
        ).mean()

        # Calculate the center of the head sphere.
        # Use 4 digpoints for each of the 3 axes to hopefully get a better approximation
        # than when using just 2 digpoints.
        sphere_locs = dict()
        for idx, axis in enumerate(("X", "Y", "Z")):
            sphere_locs[axis] = np.mean(
                [
                    ch_pos["T7"][idx],
                    ch_pos["T8"][idx],
                    ch_pos["Fpz"][idx],
                    ch_pos["Oz"][idx],
                ]
            )
        sphere = (sphere_locs["X"], sphere_locs["Y"], sphere_locs["Z"], radius)
        sphere_units = "m"
        del sphere_locs, radius, montage, ch_pos
    elif isinstance(sphere, str) or (
        isinstance(sphere, list) and all(isinstance(s, str) for s in sphere)
    ):
        # Fit a sphere to the head points.
        R, r0, _ = fit_sphere_to_headshape(
            info, dig_kinds=sphere, verbose=_verbose_safe_false(), units="m"
        )
        sphere = tuple(r0) + (R,)
        sphere_units = "m"
    elif isinstance(sphere, ConductorModel):
        if not sphere["is_sphere"] or len(sphere["layers"]) == 0:
            raise ValueError(
                "sphere, if a ConductorModel, must be spherical with multiple layers, "
                f"not a BEM or single-layer sphere (got {sphere})"
            )
        # Center from the model, radius from its first layer.
        sphere = tuple(sphere["r0"]) + (sphere["layers"][0]["rad"],)
        sphere_units = "m"
    sphere = np.array(sphere, dtype=float)
    if sphere.shape == ():
        # A bare scalar is a radius: place the sphere at the origin.
        sphere = np.concatenate([[0.0] * 3, [sphere]])
    if sphere.shape != (4,):
        raise ValueError(
            "sphere must be float or 1D array of shape (4,), "
            f"got array-like of shape {sphere.shape}"
        )
    _check_option("sphere_units", sphere_units, ("m", "mm"))
    if sphere_units == "mm":
        # In-place is safe: ``sphere`` was freshly created by np.array above.
        sphere /= 1000.0

    sphere = np.array(sphere, float)
    return sphere
| def _check_head_radius(radius, add_info=""): | |
| """Check that head radius is within a reasonable range (5. - 10.85 cm). | |
| Parameters | |
| ---------- | |
| radius : float | |
| Head radius in meters. | |
| add_info : str | |
| Additional info to add to the warning message. | |
| Notes | |
| ----- | |
| The maximum value was taken from the head size percentiles given in the | |
| following Wikipedia infographic: | |
| https://upload.wikimedia.org/wikipedia/commons/0/06/AvgHeadSizes.png | |
| the maximum radius is taken from the 99th percentile for men Glabella | |
| to back of the head measurements (Glabella is a point just above the | |
| Nasion): | |
| 21.7cm / 2 = 10.85 cm = 0.1085 m | |
| The minimum value was taken from The National Center for Health Statistics | |
| (USA) infant head circumference percentiles: | |
| https://www.cdc.gov/growthcharts/html_charts/hcageinf.htm | |
| we take the minimum to be the radius corresponding to the 3rd percentile | |
| head circumference of female 0-month infant, rounded down: | |
| 31.9302 cm circumference / (2 * pi) = 5.08 cm radius -> 0.05 m | |
| """ | |
| min_radius = 0.05 | |
| max_radius = 0.1085 | |
| if radius > max_radius: | |
| msg = ( | |
| f"Estimated head radius ({1e2 * radius:0.1f} cm) is " | |
| "above the 99th percentile for adult head size." | |
| ) | |
| warn(msg + add_info) | |
| elif radius < min_radius: | |
| msg = ( | |
| f"Estimated head radius ({1e2 * radius:0.1f} cm) is " | |
| "below the 3rd percentile for infant head size." | |
| ) | |
| warn(msg + add_info) | |
def _check_freesurfer_home():
    """Return the configured ``FREESURFER_HOME`` value or raise.

    Returns
    -------
    fs_home
        Whatever ``get_config("FREESURFER_HOME")`` returns, when it is not
        None.

    Raises
    ------
    RuntimeError
        If ``get_config("FREESURFER_HOME")`` returns None.
    """
    from .config import get_config

    freesurfer_home = get_config("FREESURFER_HOME")
    if freesurfer_home is not None:
        return freesurfer_home
    raise RuntimeError("The FREESURFER_HOME environment variable is not set.")
| def _suggest(val, options, cutoff=0.66): | |
| options = get_close_matches(val, options, cutoff=cutoff) | |
| if len(options) == 0: | |
| return "" | |
| elif len(options) == 1: | |
| return f" Did you mean {repr(options[0])}?" | |
| else: | |
| return f" Did you mean one of {repr(options)}?" | |
def _check_on_missing(on_missing, name="on_missing", *, extras=()):
    """Validate an ``on_missing``-style argument.

    Parameters
    ----------
    on_missing : str
        Value to validate. It is type-checked as ``str`` and then checked
        against ``"raise"``, ``"warn"``, ``"ignore"`` plus ``extras``.
    name : str
        Parameter name forwarded to the validation helpers.
    extras : iterable of str
        Additional accepted values.
    """
    _validate_type(on_missing, str, name)
    allowed = ["raise", "warn", "ignore", *extras]
    _check_option(name, on_missing, allowed)
def _on_missing(on_missing, msg, name="on_missing", error_klass=None):
    """Raise, warn, or stay silent according to ``on_missing``.

    Parameters
    ----------
    on_missing : str
        Validated with ``_check_on_missing(on_missing, name)``; then
        ``"raise"`` raises, ``"warn"`` warns, and ``"ignore"`` does nothing.
    msg : str
        Message used for the exception or the warning.
    name : str
        Parameter name forwarded to the validation.
    error_klass : type | None
        Exception class to raise; None means ``ValueError``.
    """
    _check_on_missing(on_missing, name)
    if error_klass is None:
        error_klass = ValueError
    # NOTE(review): the "error"/"warning" aliases look unreachable here,
    # because _check_on_missing is called without them as extras -- confirm
    # whether they should be accepted or dropped.
    aliases = {"error": "raise", "warning": "warn"}
    action = aliases.get(on_missing, on_missing)
    if action == "raise":
        raise error_klass(msg)
    if action == "warn":
        warn(msg)
        return
    # Ignore
    assert action == "ignore"
| def _safe_input(msg, *, alt=None, use=None): | |
| try: | |
| return input(msg) | |
| except EOFError: # MATLAB or other non-stdin | |
| if use is not None: | |
| return use | |
| raise RuntimeError( | |
| f"Could not use input() to get a response to:\n{msg}\n" | |
| f"You can {alt} to avoid this error." | |
| ) | |
def _ensure_events(events):
    """Coerce ``events`` to an integer NumPy array of shape (N, 3).

    Parameters
    ----------
    events : array-like
        Input converted with ``np.asarray``.

    Returns
    -------
    events : ndarray, shape (N, 3)
        The converted array.

    Raises
    ------
    TypeError
        If NumPy reports an inhomogeneous shape during conversion, or the
        resulting dtype is not an integer type.
    ValueError
        If the array is not 2D with exactly three columns. Any other
        ``ValueError`` raised by the conversion is propagated unchanged.
    """
    err_msg = f"events should be a NumPy array of integers, got {type(events)}"
    ragged_prefix = (
        "setting an array element with a sequence. The requested "
        "array has an inhomogeneous shape"
    )
    # The conversion runs inside _record_warnings(), presumably to capture
    # warnings emitted while converting.
    with _record_warnings():
        try:
            events = np.asarray(events)
        except ValueError as np_err:
            if not str(np_err).startswith(ragged_prefix):
                raise
            raise TypeError(err_msg) from None
    has_integer_dtype = np.issubdtype(events.dtype, np.integer)
    if not has_integer_dtype:
        raise TypeError(err_msg)
    if not (events.ndim == 2 and events.shape[1] == 3):
        raise ValueError(f"events must be of shape (N, 3), got {events.shape}")
    return events
def _to_rgb(*args, name="color", alpha=False):
    """Convert a color specification with matplotlib, with a clearer error.

    Parameters
    ----------
    *args
        Forwarded to ``colorConverter.to_rgba`` if ``alpha`` is truthy, else
        to ``colorConverter.to_rgb``.
    name : str
        Name of the offending argument, used in the error message.
    alpha : bool
        Whether to convert to RGBA instead of RGB.

    Returns
    -------
    color
        Whatever the selected matplotlib converter returns.

    Raises
    ------
    ValueError
        If the converter raises ``ValueError``; the message names ``name``
        and shows the offending argument(s).
    """
    from matplotlib.colors import colorConverter

    if alpha:
        convert, kind = colorConverter.to_rgba, "RGBA"
    else:
        convert, kind = colorConverter.to_rgb, "RGB"
    try:
        return convert(*args)
    except ValueError:
        # Show a lone argument directly, not as a 1-tuple.
        if len(args) == 1:
            shown = args[0]
        else:
            shown = args
        raise ValueError(
            f"Invalid {kind} argument(s) for {name}: {shown!r}"
        ) from None
| def _import_nibabel(why="use MRI files"): | |
| try: | |
| import nibabel as nib | |
| except ImportError as exp: | |
| raise exp.__class__(f"nibabel is required to {why}, got:\n{exp}") from None | |
| return nib | |
| def _check_method_kwargs(func, kwargs, msg=None): | |
| """Ensure **kwargs are compatible with the function they're passed to.""" | |
| from .misc import _pl | |
| valid = list(signature(func).parameters) | |
| is_invalid = np.isin(list(kwargs), valid, invert=True) | |
| if is_invalid.any(): | |
| invalid_kw = np.array(list(kwargs))[is_invalid].tolist() | |
| s = _pl(invalid_kw) | |
| if msg is None: | |
| msg = f'function "{func}"' | |
| raise TypeError( | |
| f"Got unexpected keyword argument{s} {', '.join(invalid_kw)} for {msg}." | |
| ) | |