File size: 6,541 Bytes
a481509 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 |
"""A mypy_ plugin for managing a number of platform-specific annotations.
Its functionality can be split into three distinct parts:
* Assigning the (platform-dependent) precisions of certain `~numpy.number`
subclasses, including the likes of `~numpy.int_`, `~numpy.intp` and
`~numpy.longlong`. See the documentation on
:ref:`scalar types <arrays.scalars.built-in>` for a comprehensive overview
of the affected classes. Without the plugin the precision of all relevant
classes will be inferred as `~typing.Any`.
* Removing all extended-precision `~numpy.number` subclasses that are
unavailable for the platform in question. Most notably this includes the
likes of `~numpy.float128` and `~numpy.complex256`. Without the plugin *all*
extended-precision types will, as far as mypy is concerned, be available
to all platforms.
* Assigning the (platform-dependent) precision of `~numpy.ctypeslib.c_intp`.
Without the plugin the type will default to `ctypes.c_int64`.
.. versionadded:: 1.22
.. deprecated:: 2.3
Examples
--------
To enable the plugin, one must add it to their mypy `configuration file`_:
.. code-block:: ini
[mypy]
plugins = numpy.typing.mypy_plugin
.. _mypy: https://mypy-lang.org/
.. _configuration file: https://mypy.readthedocs.io/en/stable/config_file.html
"""
from collections.abc import Callable, Iterable
from typing import TYPE_CHECKING, Final, TypeAlias, cast
import numpy as np
__all__: list[str] = []
def _get_precision_dict() -> dict[str, str]:
names = [
("_NBitByte", np.byte),
("_NBitShort", np.short),
("_NBitIntC", np.intc),
("_NBitIntP", np.intp),
("_NBitInt", np.int_),
("_NBitLong", np.long),
("_NBitLongLong", np.longlong),
("_NBitHalf", np.half),
("_NBitSingle", np.single),
("_NBitDouble", np.double),
("_NBitLongDouble", np.longdouble),
]
ret: dict[str, str] = {}
for name, typ in names:
n = 8 * np.dtype(typ).itemsize
ret[f"{_MODULE}._nbit.{name}"] = f"{_MODULE}._nbit_base._{n}Bit"
return ret
def _get_extended_precision_list() -> list[str]:
extended_names = [
"float96",
"float128",
"complex192",
"complex256",
]
return [i for i in extended_names if hasattr(np, i)]
def _get_c_intp_name() -> str:
# Adapted from `np.core._internal._getintp_ctype`
return {
"i": "c_int",
"l": "c_long",
"q": "c_longlong",
}.get(np.dtype("n").char, "c_long")
_MODULE: Final = "numpy._typing"
#: A dictionary mapping type-aliases in `numpy._typing._nbit` to
#: concrete `numpy.typing.NBitBase` subclasses.
_PRECISION_DICT: Final = _get_precision_dict()
#: A list with the names of all extended precision `np.number` subclasses.
_EXTENDED_PRECISION_LIST: Final = _get_extended_precision_list()
#: The name of the ctypes equivalent of `np.intp`
_C_INTP: Final = _get_c_intp_name()
try:
if TYPE_CHECKING:
from mypy.typeanal import TypeAnalyser
import mypy.types
from mypy.build import PRI_MED
from mypy.nodes import ImportFrom, MypyFile, Statement
from mypy.plugin import AnalyzeTypeContext, Plugin
except ModuleNotFoundError as e:
def plugin(version: str) -> type:
raise e
else:
_HookFunc: TypeAlias = Callable[[AnalyzeTypeContext], mypy.types.Type]
def _hook(ctx: AnalyzeTypeContext) -> mypy.types.Type:
"""Replace a type-alias with a concrete ``NBitBase`` subclass."""
typ, _, api = ctx
name = typ.name.split(".")[-1]
name_new = _PRECISION_DICT[f"{_MODULE}._nbit.{name}"]
return cast("TypeAnalyser", api).named_type(name_new)
def _index(iterable: Iterable[Statement], id: str) -> int:
"""Identify the first ``ImportFrom`` instance the specified `id`."""
for i, value in enumerate(iterable):
if getattr(value, "id", None) == id:
return i
raise ValueError("Failed to identify a `ImportFrom` instance "
f"with the following id: {id!r}")
def _override_imports(
file: MypyFile,
module: str,
imports: list[tuple[str, str | None]],
) -> None:
"""Override the first `module`-based import with new `imports`."""
# Construct a new `from module import y` statement
import_obj = ImportFrom(module, 0, names=imports)
import_obj.is_top_level = True
# Replace the first `module`-based import statement with `import_obj`
for lst in [file.defs, cast("list[Statement]", file.imports)]:
i = _index(lst, module)
lst[i] = import_obj
class _NumpyPlugin(Plugin):
"""A mypy plugin for handling versus numpy-specific typing tasks."""
def get_type_analyze_hook(self, fullname: str) -> _HookFunc | None:
"""Set the precision of platform-specific `numpy.number`
subclasses.
For example: `numpy.int_`, `numpy.longlong` and `numpy.longdouble`.
"""
if fullname in _PRECISION_DICT:
return _hook
return None
def get_additional_deps(
self, file: MypyFile
) -> list[tuple[int, str, int]]:
"""Handle all import-based overrides.
* Import platform-specific extended-precision `numpy.number`
subclasses (*e.g.* `numpy.float96` and `numpy.float128`).
* Import the appropriate `ctypes` equivalent to `numpy.intp`.
"""
fullname = file.fullname
if fullname == "numpy":
_override_imports(
file,
f"{_MODULE}._extended_precision",
imports=[(v, v) for v in _EXTENDED_PRECISION_LIST],
)
elif fullname == "numpy.ctypeslib":
_override_imports(
file,
"ctypes",
imports=[(_C_INTP, "_c_intp")],
)
return [(PRI_MED, fullname, -1)]
def plugin(version: str) -> type:
import warnings
plugin = "numpy.typing.mypy_plugin"
# Deprecated 2025-01-10, NumPy 2.3
warn_msg = (
f"`{plugin}` is deprecated, and will be removed in a future "
f"release. Please remove `plugins = {plugin}` in your mypy config."
f"(deprecated in NumPy 2.3)"
)
warnings.warn(warn_msg, DeprecationWarning, stacklevel=3)
return _NumpyPlugin
|