|
|
"""Rewrite assertion AST to produce nice error messages.""" |
|
|
|
|
|
from __future__ import annotations |
|
|
|
|
|
import ast |
|
|
from collections import defaultdict |
|
|
import errno |
|
|
import functools |
|
|
import importlib.abc |
|
|
import importlib.machinery |
|
|
import importlib.util |
|
|
import io |
|
|
import itertools |
|
|
import marshal |
|
|
import os |
|
|
from pathlib import Path |
|
|
from pathlib import PurePath |
|
|
import struct |
|
|
import sys |
|
|
import tokenize |
|
|
import types |
|
|
from typing import Callable |
|
|
from typing import IO |
|
|
from typing import Iterable |
|
|
from typing import Iterator |
|
|
from typing import Sequence |
|
|
from typing import TYPE_CHECKING |
|
|
|
|
|
from _pytest._io.saferepr import DEFAULT_REPR_MAX_SIZE |
|
|
from _pytest._io.saferepr import saferepr |
|
|
from _pytest._version import version |
|
|
from _pytest.assertion import util |
|
|
from _pytest.config import Config |
|
|
from _pytest.main import Session |
|
|
from _pytest.pathlib import absolutepath |
|
|
from _pytest.pathlib import fnmatch_ex |
|
|
from _pytest.stash import StashKey |
|
|
|
|
|
|
|
|
|
|
|
from _pytest.assertion.util import format_explanation as _format_explanation |
|
|
|
|
|
|
|
|
if TYPE_CHECKING: |
|
|
from _pytest.assertion import AssertionState |
|
|
|
|
|
|
|
|
class Sentinel:
    """Marker type; each instance is a unique placeholder object compared by identity."""
|
|
|
|
|
|
|
|
# Stash key under which the AssertionState lives on the config object
# (read via ``config.stash[assertstate_key]``).
assertstate_key = StashKey["AssertionState"]()

# Cached rewritten pycs are tagged with both the interpreter cache tag and the
# pytest version; the tag is part of the cache filename, so a cache written by
# another interpreter or pytest release is simply not found.
PYTEST_TAG = f"{sys.implementation.cache_tag}-pytest-{version}"
# ".pyc" normally, ".pyo" when ``__debug__`` is False (running under -O).
PYC_EXT = ".py" + ("c" if __debug__ else "o")
# Suffix appended to a module's base name to form its cached pyc filename.
PYC_TAIL = "." + PYTEST_TAG + PYC_EXT

# Pushed onto the traversal stack in AssertionRewriter.run() to mark the point
# where a function/class scope ends.
_SCOPE_END_MARKER = Sentinel()
|
|
|
|
|
|
|
|
class AssertionRewritingHook(importlib.abc.MetaPathFinder, importlib.abc.Loader):
    """PEP302/PEP451 import hook which rewrites asserts.

    As a finder, :meth:`find_spec` claims the modules whose asserts should be
    rewritten: conftest files, files given on the command line, files matching
    the ``python_files`` patterns, and names registered via
    :meth:`mark_rewrite`.  As a loader, :meth:`exec_module` runs such a module
    from assertion-rewritten code, reusing or refreshing an on-disk pyc cache.
    """

    def __init__(self, config: Config) -> None:
        self.config = config
        try:
            self.fnpats = config.getini("python_files")
        except ValueError:
            # Fall back to the default test file patterns.
            self.fnpats = ["test_*.py", "*_test.py"]
        self.session: Session | None = None
        # Module name -> source file of every module executed by this hook.
        self._rewritten_names: dict[str, Path] = {}
        # Names registered with mark_rewrite(); their submodules match too.
        self._must_rewrite: set[str] = set()
        # Set while a pyc is being written; find_spec() steps aside meanwhile
        # so that writing the cache cannot recurse into rewriting.
        self._writing_pyc = False
        # Module basenames for which the early bailout is never taken.
        self._basenames_to_check_rewrite = {"conftest"}
        # Memoized answers of _is_marked_for_rewrite(); reset by mark_rewrite().
        self._marked_for_rewrite_cache: dict[str, bool] = {}
        # Whether the session's initial paths were merged into the basenames.
        self._session_paths_checked = False

    def set_session(self, session: Session | None) -> None:
        """Attach *session*; its initial paths are re-collected lazily."""
        self._session_paths_checked = False
        self.session = session

    # Delegate used to locate a module's regular spec.
    _find_spec = importlib.machinery.PathFinder.find_spec

    def find_spec(
        self,
        name: str,
        path: Sequence[str | bytes] | None = None,
        target: types.ModuleType | None = None,
    ) -> importlib.machinery.ModuleSpec | None:
        """Return a spec that loads *name* through this hook, or None to decline.

        Only existing source files (``SourceFileLoader``) accepted by
        :meth:`_should_rewrite` are claimed; the returned spec keeps the
        original ``submodule_search_locations``.
        """
        if self._writing_pyc:
            return None
        state = self.config.stash[assertstate_key]
        if self._early_rewrite_bailout(name, state):
            return None
        state.trace(f"find_module called for: {name}")

        spec = self._find_spec(name, path)
        if spec is None and path is not None:
            # The path finder found nothing; fall back to treating each entry
            # of *path* as a file location and take the first hit.
            candidates = (
                importlib.util.spec_from_file_location(name, entry) for entry in path
            )
            spec = next((found for found in candidates if found is not None), None)

        if not (
            spec is not None
            # Namespace packages have no origin and nothing to rewrite.
            and spec.origin is not None
            # Only plain source files can be rewritten.
            and isinstance(spec.loader, importlib.machinery.SourceFileLoader)
            # The source file must actually exist.
            and os.path.exists(spec.origin)
        ):
            return None

        fn = spec.origin
        if not self._should_rewrite(name, fn, state):
            return None
        return importlib.util.spec_from_file_location(
            name,
            fn,
            loader=self,
            submodule_search_locations=spec.submodule_search_locations,
        )

    def create_module(
        self, spec: importlib.machinery.ModuleSpec
    ) -> types.ModuleType | None:
        """Return None so the import system uses default module creation."""
        return None

    def exec_module(self, module: types.ModuleType) -> None:
        """Execute *module* from its assertion-rewritten code object.

        A cached pyc in the directory given by ``get_cache_dir`` is used when
        ``_read_pyc`` accepts it.  Otherwise the source is rewritten and, unless
        bytecode writing is disabled or the cache directory cannot be created,
        the result is written back through ``_write_pyc``.
        """
        spec = module.__spec__
        assert spec is not None
        assert spec.origin is not None
        fn = Path(spec.origin)
        state = self.config.stash[assertstate_key]

        self._rewritten_names[module.__name__] = fn

        write = not sys.dont_write_bytecode
        cache_dir = get_cache_dir(fn)
        if write and not try_makedirs(cache_dir):
            write = False
            state.trace(f"read only directory: {cache_dir}")

        # Drop the last three characters (the ".py" suffix) and add the
        # pytest-specific tail.
        pyc = cache_dir / (fn.name[:-3] + PYC_TAIL)
        # The cache is consulted even when the directory turned out read-only.
        co = _read_pyc(fn, pyc, state.trace)
        if co is not None:
            state.trace(f"found cached rewritten pyc for {fn}")
        else:
            state.trace(f"rewriting {fn!r}")
            source_stat, co = _rewrite_test(fn, self.config)
            if write:
                self._writing_pyc = True
                try:
                    _write_pyc(state, co, source_stat, pyc)
                finally:
                    self._writing_pyc = False
        exec(co, module.__dict__)

    def _early_rewrite_bailout(self, name: str, state: AssertionState) -> bool:
        """A fast way to get out of rewriting modules.

        Profiling has shown that the call to PathFinder.find_spec (inside of
        the find_spec from this class) is a major slowdown, so, this method
        tries to filter what we're sure won't be rewritten before getting to
        it.
        """
        if self.session is not None and not self._session_paths_checked:
            self._session_paths_checked = True
            # "c:/projects/my_project/path.py" -> "path": the last path
            # component without its extension.
            self._basenames_to_check_rewrite.update(
                os.path.splitext(str(initial_path).split(os.sep)[-1])[0]
                for initial_path in self.session._initialpaths
            )

        # "conftest" is always among the basenames checked here.
        parts = name.split(".")
        if parts[-1] in self._basenames_to_check_rewrite:
            return False

        # Match the dotted module name as if it were a file path.
        as_path = PurePath(*parts).with_suffix(".py")
        for pat in self.fnpats:
            # A pattern with a directory part needs the real file path, which
            # is unknown at this point, so no early decision is possible.
            if os.path.dirname(pat) or fnmatch_ex(pat, as_path):
                return False

        if self._is_marked_for_rewrite(name, state):
            return False

        state.trace(f"early skip of rewriting module: {name}")
        return True

    def _should_rewrite(self, name: str, fn: str, state: AssertionState) -> bool:
        """Decide whether module *name*, found at *fn*, gets its asserts rewritten."""
        # conftest files are always rewritten.
        if os.path.basename(fn) == "conftest.py":
            state.trace(f"rewriting conftest file: {fn!r}")
            return True

        session = self.session
        if session is not None and session.isinitpath(absolutepath(fn)):
            state.trace(f"matched test file (was specified on cmdline): {fn!r}")
            return True

        # Other modules are rewritten when they match a test file pattern...
        fn_path = PurePath(fn)
        if any(fnmatch_ex(pat, fn_path) for pat in self.fnpats):
            state.trace(f"matched test file {fn!r}")
            return True

        # ...or when they were explicitly marked via mark_rewrite().
        return self._is_marked_for_rewrite(name, state)

    def _is_marked_for_rewrite(self, name: str, state: AssertionState) -> bool:
        """Return True if *name* or one of its parent packages was marked."""
        cached = self._marked_for_rewrite_cache.get(name)
        if cached is not None:
            return cached

        for marked in self._must_rewrite:
            if name == marked or name.startswith(marked + "."):
                state.trace(f"matched marked file {name!r} (from {marked!r})")
                self._marked_for_rewrite_cache[name] = True
                return True

        self._marked_for_rewrite_cache[name] = False
        return False

    def mark_rewrite(self, *names: str) -> None:
        """Mark import names as needing to be rewritten.

        The named module or package as well as any nested modules will
        be rewritten on import.  Names already present in ``sys.modules`` that
        were not loaded by this hook trigger a warning, unless their docstring
        disables rewriting.
        """
        already_imported = (
            set(names).intersection(sys.modules).difference(self._rewritten_names)
        )
        for name in already_imported:
            mod = sys.modules[name]
            if AssertionRewriter.is_rewrite_disabled(mod.__doc__ or ""):
                continue
            if isinstance(mod.__loader__, type(self)):
                continue
            self._warn_already_imported(name)
        self._must_rewrite.update(names)
        self._marked_for_rewrite_cache.clear()

    def _warn_already_imported(self, name: str) -> None:
        """Issue a config-time warning that *name* can no longer be rewritten."""
        from _pytest.warning_types import PytestAssertRewriteWarning

        warning = PytestAssertRewriteWarning(
            f"Module already imported so cannot be rewritten: {name}"
        )
        self.config.issue_config_time_warning(warning, stacklevel=5)

    def get_data(self, pathname: str | bytes) -> bytes:
        """Optional PEP302 get_data API."""
        with open(pathname, "rb") as stream:
            return stream.read()

    if sys.version_info >= (3, 10):
        if sys.version_info >= (3, 12):
            from importlib.resources.abc import TraversableResources
        else:
            from importlib.abc import TraversableResources

        def get_resource_reader(self, name: str) -> TraversableResources:
            """Return a file-based resource reader for the rewritten module *name*."""
            if sys.version_info >= (3, 11):
                from importlib.resources.readers import FileReader
            else:
                from importlib.readers import FileReader

            rewritten_path = self._rewritten_names[name]
            return FileReader(types.SimpleNamespace(path=rewritten_path))
|
|
|
|
|
|
|
|
def _write_pyc_fp(
    fp: IO[bytes], source_stat: os.stat_result, co: types.CodeType
) -> None:
    """Serialize *co* into *fp* using the CPython pyc layout.

    The 16-byte header holds the interpreter's magic number, a zero PEP 552
    flags word (timestamp-based pyc), and the source mtime and size, each
    truncated to an unsigned 32-bit little-endian integer.  The marshalled
    code object follows.  Only pytest reads these files, but there is little
    reason to deviate from the standard format.
    """
    truncated_mtime = int(source_stat.st_mtime) & 0xFFFFFFFF
    truncated_size = source_stat.st_size & 0xFFFFFFFF
    # "<4sLLL": 4 raw magic bytes, then flags, mtime and size as unsigned
    # 32-bit little-endian integers -- 16 bytes in total, no padding.
    header = struct.pack(
        "<4sLLL", importlib.util.MAGIC_NUMBER, 0, truncated_mtime, truncated_size
    )
    fp.write(header)
    fp.write(marshal.dumps(co))
|
|
|
|
|
|
|
|
def _write_pyc( |
|
|
state: AssertionState, |
|
|
co: types.CodeType, |
|
|
source_stat: os.stat_result, |
|
|
pyc: Path, |
|
|
) -> bool: |
|
|
proc_pyc = f"{pyc}.{os.getpid()}" |
|
|
try: |
|
|
with open(proc_pyc, "wb") as fp: |
|
|
_write_pyc_fp(fp, source_stat, co) |
|
|
except OSError as e: |
|
|
state.trace(f"error writing pyc file at {proc_pyc}: errno={e.errno}") |
|
|
return False |
|
|
|
|
|
try: |
|
|
os.replace(proc_pyc, pyc) |
|
|
except OSError as e: |
|
|
state.trace(f"error writing pyc file at {pyc}: {e}") |
|
|
|
|
|
|
|
|
|
|
|
return False |
|
|
return True |
|
|
|
|
|
|
|
|
def _rewrite_test(fn: Path, config: Config) -> tuple[os.stat_result, types.CodeType]:
    """Read and rewrite *fn* and return the code object.

    The source file is stat-ed before it is read.  The returned stat result
    is what the pyc header records, and the assert-rewritten module is
    compiled with ``dont_inherit=True`` so no future flags leak in from here.
    """
    source_stat = os.stat(fn)
    raw_source = fn.read_bytes()
    filename = str(fn)
    module_ast = ast.parse(raw_source, filename=filename)
    rewrite_asserts(module_ast, raw_source, filename, config)
    code = compile(module_ast, filename, "exec", dont_inherit=True)
    return source_stat, code
|
|
|
|
|
|
|
|
def _read_pyc( |
|
|
source: Path, pyc: Path, trace: Callable[[str], None] = lambda x: None |
|
|
) -> types.CodeType | None: |
|
|
"""Possibly read a pytest pyc containing rewritten code. |
|
|
|
|
|
Return rewritten code if successful or None if not. |
|
|
""" |
|
|
try: |
|
|
fp = open(pyc, "rb") |
|
|
except OSError: |
|
|
return None |
|
|
with fp: |
|
|
try: |
|
|
stat_result = os.stat(source) |
|
|
mtime = int(stat_result.st_mtime) |
|
|
size = stat_result.st_size |
|
|
data = fp.read(16) |
|
|
except OSError as e: |
|
|
trace(f"_read_pyc({source}): OSError {e}") |
|
|
return None |
|
|
|
|
|
if len(data) != (16): |
|
|
trace(f"_read_pyc({source}): invalid pyc (too short)") |
|
|
return None |
|
|
if data[:4] != importlib.util.MAGIC_NUMBER: |
|
|
trace(f"_read_pyc({source}): invalid pyc (bad magic number)") |
|
|
return None |
|
|
if data[4:8] != b"\x00\x00\x00\x00": |
|
|
trace(f"_read_pyc({source}): invalid pyc (unsupported flags)") |
|
|
return None |
|
|
mtime_data = data[8:12] |
|
|
if int.from_bytes(mtime_data, "little") != mtime & 0xFFFFFFFF: |
|
|
trace(f"_read_pyc({source}): out of date") |
|
|
return None |
|
|
size_data = data[12:16] |
|
|
if int.from_bytes(size_data, "little") != size & 0xFFFFFFFF: |
|
|
trace(f"_read_pyc({source}): invalid pyc (incorrect size)") |
|
|
return None |
|
|
try: |
|
|
co = marshal.load(fp) |
|
|
except Exception as e: |
|
|
trace(f"_read_pyc({source}): marshal.load error {e}") |
|
|
return None |
|
|
if not isinstance(co, types.CodeType): |
|
|
trace(f"_read_pyc({source}): not a code object") |
|
|
return None |
|
|
return co |
|
|
|
|
|
|
|
|
def rewrite_asserts(
    mod: ast.Module,
    source: bytes,
    module_path: str | None = None,
    config: Config | None = None,
) -> None:
    """Rewrite the assert statements in *mod* in place.

    :param mod: Parsed module whose asserts are replaced.
    :param source: Raw module source; used to recover the original assert text.
    :param module_path: Filename used when warning about suspicious asserts.
    :param config: Optional pytest config; its ``enable_assertion_pass_hook``
        setting controls whether pass-hook calls are generated.
    """
    rewriter = AssertionRewriter(module_path, config, source)
    rewriter.run(mod)
|
|
|
|
|
|
|
|
def _saferepr(obj: object) -> str:
    r"""Get a safe repr of an object for assertion error messages.

    The assertion formatting (util.format_explanation()) requires
    newlines to be escaped since they are a special character for it.
    Normally assertion.util.format_explanation() does this but for a
    custom repr it is possible to contain one of the special escape
    sequences, especially '\n{' and '\n}' are likely to be present in
    JSON reprs.

    Bound methods are shown by their name only.
    """
    if not isinstance(obj, types.MethodType):
        limit = _get_maxsize_for_saferepr(util._config)
        return saferepr(obj, maxsize=limit).replace("\n", "\\n")
    # Skip the redundant "<bound method ...>" information.
    return obj.__name__
|
|
|
|
|
|
|
|
def _get_maxsize_for_saferepr(config: Config | None) -> int | None:
    """Get `maxsize` configuration for saferepr based on the given config object.

    The limit follows the assertion verbosity:

    * verbosity 0, or no config at all: ``DEFAULT_REPR_MAX_SIZE``;
    * verbosity 1: ten times ``DEFAULT_REPR_MAX_SIZE``;
    * verbosity 2 or more: ``None`` (no limit).
    """
    verbosity = (
        0 if config is None else config.get_verbosity(Config.VERBOSITY_ASSERTIONS)
    )
    if verbosity >= 2:
        return None
    return DEFAULT_REPR_MAX_SIZE * 10 if verbosity >= 1 else DEFAULT_REPR_MAX_SIZE
|
|
|
|
|
|
|
|
def _format_assertmsg(obj: object) -> str:
    r"""Format the custom assertion message given.

    For strings this simply replaces newlines with '\n~' so that
    util.format_explanation() will preserve them instead of escaping
    newlines. For other objects saferepr() is used first.

    Literal ``%`` signs are doubled because the message becomes part of a
    %-format template.  reprlib escapes newlines inside strings but not inside
    a custom ``__repr__``.  For non-string objects, escaped ``\n`` sequences
    are therefore also turned back into preserved newlines.
    """
    if isinstance(obj, str):
        text = obj
        replacements = [("\n", "\n~"), ("%", "%%")]
    else:
        text = saferepr(obj, _get_maxsize_for_saferepr(util._config))
        replacements = [("\n", "\n~"), ("%", "%%"), ("\\n", "\n~")]

    for old, new in replacements:
        text = text.replace(old, new)
    return text
|
|
|
|
|
|
|
|
def _should_repr_global_name(obj: object) -> bool:
    """Decide whether a non-local name's value is worth showing via its repr.

    Callables and objects that have a ``__name__`` are not repr'd.  If probing
    for ``__name__`` raises, the object is shown.
    """
    if callable(obj):
        return False
    try:
        has_name = hasattr(obj, "__name__")
    except Exception:
        return True
    return not has_name
|
|
|
|
|
|
|
|
def _format_boolop(explanations: Iterable[str], is_or: bool) -> str:
    """Join *explanations* with ``or``/``and`` and wrap the result in parentheses.

    *is_or* is only tested for truthiness.  Literal ``%`` signs in the result
    are doubled (``%%``), presumably because callers apply %-formatting to
    it -- the caller is not fully visible here, confirm.
    """
    joiner = " or " if is_or else " and "
    explanation = "(" + joiner.join(explanations) + ")"
    return explanation.replace("%", "%%")
|
|
|
|
|
|
|
|
def _call_reprcompare(
    ops: Sequence[str],
    results: Sequence[bool],
    expls: Sequence[str],
    each_obj: Sequence[object],
) -> str:
    """Return the explanation for the first failing comparison.

    *ops*, *results* and *expls* are parallel per-comparison sequences, and
    comparison ``i`` relates ``each_obj[i]`` and ``each_obj[i + 1]``.  A result
    whose truth value cannot be determined counts as failing.  If no
    comparison fails, the last one is explained.  When ``util._reprcompare``
    is set and returns a custom explanation for the chosen comparison, that
    explanation is used instead.
    """
    for idx, outcome, expl in zip(range(len(ops)), results, expls):
        try:
            failed = not outcome
        except Exception:
            failed = True
        if failed:
            break

    reprcompare = util._reprcompare
    if reprcompare is not None:
        custom = reprcompare(ops[idx], each_obj[idx], each_obj[idx + 1])
        if custom is not None:
            return custom
    return expl
|
|
|
|
|
|
|
|
def _call_assertion_pass(lineno: int, orig: str, expl: str) -> None:
    """Forward a passed assertion to ``util._assertion_pass`` when one is set."""
    hook = util._assertion_pass
    if hook is None:
        return
    hook(lineno, orig, expl)
|
|
|
|
|
|
|
|
def _check_if_assertion_pass_impl() -> bool:
    """Report whether a pytest_assertion_pass implementation is registered.

    Generated code checks this first, so the (possibly expensive) explanation
    is only built when a plugin actually wants it.
    """
    return bool(util._assertion_pass)
|
|
|
|
|
|
|
|
# %-style templates used to render unary operators; "%s" is the operand.
UNARY_MAP = {
    ast.Not: "not %s",
    ast.Invert: "~%s",
    ast.USub: "-%s",
    ast.UAdd: "+%s",
}

# Source symbols for binary operators and comparison operators.  ``ast.Mod``
# maps to "%%" (a literal percent sign once %-formatted), presumably because
# these symbols end up inside %-format explanation templates.
BINOP_MAP = {
    # Bitwise and shift operators.
    ast.BitOr: "|",
    ast.BitXor: "^",
    ast.BitAnd: "&",
    ast.LShift: "<<",
    ast.RShift: ">>",
    # Arithmetic operators.
    ast.Add: "+",
    ast.Sub: "-",
    ast.Mult: "*",
    ast.Div: "/",
    ast.FloorDiv: "//",
    ast.Mod: "%%",
    # Comparisons.
    ast.Eq: "==",
    ast.NotEq: "!=",
    ast.Lt: "<",
    ast.LtE: "<=",
    ast.Gt: ">",
    ast.GtE: ">=",
    ast.Pow: "**",
    ast.Is: "is",
    ast.IsNot: "is not",
    ast.In: "in",
    ast.NotIn: "not in",
    ast.MatMult: "@",
}
|
|
|
|
|
|
|
|
def traverse_node(node: ast.AST) -> Iterator[ast.AST]:
    """Yield *node* and all its descendants in depth-first pre-order.

    An explicit stack of child iterators replaces recursion.  Each node's
    children are fetched only after that node has been yielded, just as
    lazily as a recursive generator would.
    """
    pending = [iter((node,))]
    while pending:
        try:
            current = next(pending[-1])
        except StopIteration:
            pending.pop()
            continue
        yield current
        pending.append(ast.iter_child_nodes(current))
|
|
|
|
|
|
|
|
@functools.lru_cache(maxsize=1)
def _get_assertion_exprs(src: bytes) -> dict[int, str]:
    """Return a mapping from {lineno: "assertion test expression"}.

    *src* is tokenized.  For every ``assert`` keyword, the source text of the
    asserted expression is collected, possibly across several physical lines.
    Collection stops at the end of the statement, or at the top-level ``,``
    that introduces the assertion message.  Each text is keyed by the line
    number of its ``assert`` keyword.

    visit_Assert calls this once per rewritten assert with the same module
    source, so caching only the most recent result (``maxsize=1``) avoids
    re-tokenizing the module for each assert.
    """
    ret: dict[int, str] = {}

    # Bracket nesting depth inside the current assert; a "," only separates
    # the assertion message when it is not nested in brackets.
    depth = 0
    # Source fragments collected for the current assert expression.
    lines: list[str] = []
    # Line of the ``assert`` keyword being processed; None outside an assert.
    assert_lineno: int | None = None
    # Physical line numbers already (partially) appended to ``lines``.
    seen_lines: set[int] = set()

    def _write_and_reset() -> None:
        # Store the collected text (without trailing whitespace and a trailing
        # line-continuation backslash) and reset the per-assert state.
        nonlocal depth, lines, assert_lineno, seen_lines
        assert assert_lineno is not None
        ret[assert_lineno] = "".join(lines).rstrip().rstrip("\\")
        depth = 0
        lines = []
        assert_lineno = None
        seen_lines = set()

    tokens = tokenize.tokenize(io.BytesIO(src).readline)
    for tp, source, (lineno, offset), _, line in tokens:
        if tp == tokenize.NAME and source == "assert":
            assert_lineno = lineno
        elif assert_lineno is not None:
            # keep track of depth for the assert-message `,` lookup
            if tp == tokenize.OP and source in "([{":
                depth += 1
            elif tp == tokenize.OP and source in ")]}":
                depth -= 1

            if not lines:
                # First token of the expression: take its line from the token's
                # column onwards.
                lines.append(line[offset:])
                seen_lines.add(lineno)
            # a non-nested comma separates the expression from the message
            elif depth == 0 and tp == tokenize.OP and source == ",":
                if lineno in seen_lines and len(lines) == 1:
                    # one line assert with message: the only fragment starts at
                    # the first token's column, so translate the comma's offset
                    # into that trimmed fragment before cutting.
                    offset_in_trimmed = offset + len(lines[-1]) - len(line)
                    lines[-1] = lines[-1][:offset_in_trimmed]
                elif lineno in seen_lines:
                    # multi-line assert with message: the last fragment is a
                    # whole physical line, so cut it at the comma's column.
                    lines[-1] = lines[-1][:offset]
                else:
                    # multi line assert with escaped newline before message:
                    # the comma's line was not collected yet.
                    lines.append(line[:offset])
                _write_and_reset()
            elif tp in {tokenize.NEWLINE, tokenize.ENDMARKER}:
                # End of the assert statement without a message.
                _write_and_reset()
            elif lines and lineno not in seen_lines:
                # A token on a new physical line: collect the whole line.
                lines.append(line)
                seen_lines.add(lineno)

    return ret
|
|
|
|
|
|
|
|
class AssertionRewriter(ast.NodeVisitor): |
|
|
"""Assertion rewriting implementation. |
|
|
|
|
|
The main entrypoint is to call .run() with an ast.Module instance, |
|
|
this will then find all the assert statements and rewrite them to |
|
|
provide intermediate values and a detailed assertion error. See |
|
|
http://pybites.blogspot.be/2011/07/behind-scenes-of-pytests-new-assertion.html |
|
|
for an overview of how this works. |
|
|
|
|
|
The entry point here is .run() which will iterate over all the |
|
|
statements in an ast.Module and for each ast.Assert statement it |
|
|
finds call .visit() with it. Then .visit_Assert() takes over and |
|
|
is responsible for creating new ast statements to replace the |
|
|
original assert statement: it rewrites the test of an assertion |
|
|
to provide intermediate values and replace it with an if statement |
|
|
which raises an assertion error with a detailed explanation in |
|
|
case the expression is false and calls pytest_assertion_pass hook |
|
|
if expression is true. |
|
|
|
|
|
For this .visit_Assert() uses the visitor pattern to visit all the |
|
|
AST nodes of the ast.Assert.test field, each visit call returning |
|
|
an AST node and the corresponding explanation string. During this |
|
|
state is kept in several instance attributes: |
|
|
|
|
|
:statements: All the AST statements which will replace the assert |
|
|
statement. |
|
|
|
|
|
:variables: This is populated by .variable() with each variable |
|
|
used by the statements so that they can all be set to None at |
|
|
the end of the statements. |
|
|
|
|
|
:variable_counter: Counter to create new unique variables needed |
|
|
by statements. Variables are created using .variable() and |
|
|
have the form of "@py_assert0". |
|
|
|
|
|
:expl_stmts: The AST statements which will be executed to get |
|
|
data from the assertion. This is the code which will construct |
|
|
the detailed assertion message that is used in the AssertionError |
|
|
or for the pytest_assertion_pass hook. |
|
|
|
|
|
:explanation_specifiers: A dict filled by .explanation_param() |
|
|
with %-formatting placeholders and their corresponding |
|
|
expressions to use in the building of an assertion message. |
|
|
This is used by .pop_format_context() to build a message. |
|
|
|
|
|
:stack: A stack of the explanation_specifiers dicts maintained by |
|
|
.push_format_context() and .pop_format_context() which allows |
|
|
       building another %-formatted string while already building one.
|
|
|
|
|
:scope: A tuple containing the current scope used for variables_overwrite. |
|
|
|
|
|
:variables_overwrite: A dict filled with references to variables |
|
|
that change value within an assert. This happens when a variable is |
|
|
       reassigned with the walrus operator.
|
|
|
|
|
This state, except the variables_overwrite, is reset on every new assert |
|
|
statement visited and used by the other visitors. |
|
|
""" |
|
|
|
|
|
def __init__( |
|
|
self, module_path: str | None, config: Config | None, source: bytes |
|
|
) -> None: |
|
|
super().__init__() |
|
|
self.module_path = module_path |
|
|
self.config = config |
|
|
if config is not None: |
|
|
self.enable_assertion_pass_hook = config.getini( |
|
|
"enable_assertion_pass_hook" |
|
|
) |
|
|
else: |
|
|
self.enable_assertion_pass_hook = False |
|
|
self.source = source |
|
|
self.scope: tuple[ast.AST, ...] = () |
|
|
self.variables_overwrite: defaultdict[tuple[ast.AST, ...], dict[str, str]] = ( |
|
|
defaultdict(dict) |
|
|
) |
|
|
|
|
|
def run(self, mod: ast.Module) -> None: |
|
|
"""Find all assert statements in *mod* and rewrite them.""" |
|
|
if not mod.body: |
|
|
|
|
|
return |
|
|
|
|
|
|
|
|
|
|
|
doc = getattr(mod, "docstring", None) |
|
|
expect_docstring = doc is None |
|
|
if doc is not None and self.is_rewrite_disabled(doc): |
|
|
return |
|
|
pos = 0 |
|
|
item = None |
|
|
for item in mod.body: |
|
|
if ( |
|
|
expect_docstring |
|
|
and isinstance(item, ast.Expr) |
|
|
and isinstance(item.value, ast.Constant) |
|
|
and isinstance(item.value.value, str) |
|
|
): |
|
|
doc = item.value.value |
|
|
if self.is_rewrite_disabled(doc): |
|
|
return |
|
|
expect_docstring = False |
|
|
elif ( |
|
|
isinstance(item, ast.ImportFrom) |
|
|
and item.level == 0 |
|
|
and item.module == "__future__" |
|
|
): |
|
|
pass |
|
|
else: |
|
|
break |
|
|
pos += 1 |
|
|
|
|
|
|
|
|
if isinstance(item, ast.FunctionDef) and item.decorator_list: |
|
|
lineno = item.decorator_list[0].lineno |
|
|
else: |
|
|
lineno = item.lineno |
|
|
|
|
|
if sys.version_info >= (3, 10): |
|
|
aliases = [ |
|
|
ast.alias("builtins", "@py_builtins", lineno=lineno, col_offset=0), |
|
|
ast.alias( |
|
|
"_pytest.assertion.rewrite", |
|
|
"@pytest_ar", |
|
|
lineno=lineno, |
|
|
col_offset=0, |
|
|
), |
|
|
] |
|
|
else: |
|
|
aliases = [ |
|
|
ast.alias("builtins", "@py_builtins"), |
|
|
ast.alias("_pytest.assertion.rewrite", "@pytest_ar"), |
|
|
] |
|
|
imports = [ |
|
|
ast.Import([alias], lineno=lineno, col_offset=0) for alias in aliases |
|
|
] |
|
|
mod.body[pos:pos] = imports |
|
|
|
|
|
|
|
|
self.scope = (mod,) |
|
|
nodes: list[ast.AST | Sentinel] = [mod] |
|
|
while nodes: |
|
|
node = nodes.pop() |
|
|
if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef, ast.ClassDef)): |
|
|
self.scope = tuple((*self.scope, node)) |
|
|
nodes.append(_SCOPE_END_MARKER) |
|
|
if node == _SCOPE_END_MARKER: |
|
|
self.scope = self.scope[:-1] |
|
|
continue |
|
|
assert isinstance(node, ast.AST) |
|
|
for name, field in ast.iter_fields(node): |
|
|
if isinstance(field, list): |
|
|
new: list[ast.AST] = [] |
|
|
for i, child in enumerate(field): |
|
|
if isinstance(child, ast.Assert): |
|
|
|
|
|
new.extend(self.visit(child)) |
|
|
else: |
|
|
new.append(child) |
|
|
if isinstance(child, ast.AST): |
|
|
nodes.append(child) |
|
|
setattr(node, name, new) |
|
|
elif ( |
|
|
isinstance(field, ast.AST) |
|
|
|
|
|
|
|
|
and not isinstance(field, ast.expr) |
|
|
): |
|
|
nodes.append(field) |
|
|
|
|
|
@staticmethod |
|
|
def is_rewrite_disabled(docstring: str) -> bool: |
|
|
return "PYTEST_DONT_REWRITE" in docstring |
|
|
|
|
|
def variable(self) -> str: |
|
|
"""Get a new variable.""" |
|
|
|
|
|
name = "@py_assert" + str(next(self.variable_counter)) |
|
|
self.variables.append(name) |
|
|
return name |
|
|
|
|
|
def assign(self, expr: ast.expr) -> ast.Name: |
|
|
"""Give *expr* a name.""" |
|
|
name = self.variable() |
|
|
self.statements.append(ast.Assign([ast.Name(name, ast.Store())], expr)) |
|
|
return ast.copy_location(ast.Name(name, ast.Load()), expr) |
|
|
|
|
|
def display(self, expr: ast.expr) -> ast.expr: |
|
|
"""Call saferepr on the expression.""" |
|
|
return self.helper("_saferepr", expr) |
|
|
|
|
|
def helper(self, name: str, *args: ast.expr) -> ast.expr: |
|
|
"""Call a helper in this module.""" |
|
|
py_name = ast.Name("@pytest_ar", ast.Load()) |
|
|
attr = ast.Attribute(py_name, name, ast.Load()) |
|
|
return ast.Call(attr, list(args), []) |
|
|
|
|
|
def builtin(self, name: str) -> ast.Attribute: |
|
|
"""Return the builtin called *name*.""" |
|
|
builtin_name = ast.Name("@py_builtins", ast.Load()) |
|
|
return ast.Attribute(builtin_name, name, ast.Load()) |
|
|
|
|
|
def explanation_param(self, expr: ast.expr) -> str: |
|
|
"""Return a new named %-formatting placeholder for expr. |
|
|
|
|
|
This creates a %-formatting placeholder for expr in the |
|
|
current formatting context, e.g. ``%(py0)s``. The placeholder |
|
|
and expr are placed in the current format context so that it |
|
|
can be used on the next call to .pop_format_context(). |
|
|
""" |
|
|
specifier = "py" + str(next(self.variable_counter)) |
|
|
self.explanation_specifiers[specifier] = expr |
|
|
return "%(" + specifier + ")s" |
|
|
|
|
|
def push_format_context(self) -> None: |
|
|
"""Create a new formatting context. |
|
|
|
|
|
The format context is used for when an explanation wants to |
|
|
have a variable value formatted in the assertion message. In |
|
|
this case the value required can be added using |
|
|
.explanation_param(). Finally .pop_format_context() is used |
|
|
to format a string of %-formatted values as added by |
|
|
.explanation_param(). |
|
|
""" |
|
|
self.explanation_specifiers: dict[str, ast.expr] = {} |
|
|
self.stack.append(self.explanation_specifiers) |
|
|
|
|
|
def pop_format_context(self, expl_expr: ast.expr) -> ast.Name: |
|
|
"""Format the %-formatted string with current format context. |
|
|
|
|
|
The expl_expr should be an str ast.expr instance constructed from |
|
|
the %-placeholders created by .explanation_param(). This will |
|
|
add the required code to format said string to .expl_stmts and |
|
|
return the ast.Name instance of the formatted string. |
|
|
""" |
|
|
current = self.stack.pop() |
|
|
if self.stack: |
|
|
self.explanation_specifiers = self.stack[-1] |
|
|
keys: list[ast.expr | None] = [ast.Constant(key) for key in current.keys()] |
|
|
format_dict = ast.Dict(keys, list(current.values())) |
|
|
form = ast.BinOp(expl_expr, ast.Mod(), format_dict) |
|
|
name = "@py_format" + str(next(self.variable_counter)) |
|
|
if self.enable_assertion_pass_hook: |
|
|
self.format_variables.append(name) |
|
|
self.expl_stmts.append(ast.Assign([ast.Name(name, ast.Store())], form)) |
|
|
return ast.Name(name, ast.Load()) |
|
|
|
|
|
def generic_visit(self, node: ast.AST) -> tuple[ast.Name, str]: |
|
|
"""Handle expressions we don't have custom code for.""" |
|
|
assert isinstance(node, ast.expr) |
|
|
res = self.assign(node) |
|
|
return res, self.explanation_param(self.display(res)) |
|
|
|
|
|
    def visit_Assert(self, assert_: ast.Assert) -> list[ast.stmt]:
        """Return the AST statements to replace the ast.Assert instance.

        This rewrites the test of an assertion to provide
        intermediate values and replace it with an if statement which
        raises an assertion error with a detailed explanation in case
        the expression is false.

        When ``enable_assertion_pass_hook`` is set, the if statement also has an
        else-branch.  That branch builds the explanation and calls the
        ``_call_assertion_pass`` helper, but only if
        ``_check_if_assertion_pass_impl`` reports an implementation at runtime.
        The temporaries used are reset to None afterwards.  Generated nodes
        without a location get the location of *assert_*.

        :returns: The statements that replace the assert (``self.statements``).
        """
        # A non-empty tuple is always truthy, so such an assert can never fail.
        if isinstance(assert_.test, ast.Tuple) and len(assert_.test.elts) >= 1:
            import warnings

            from _pytest.warning_types import PytestAssertRewriteWarning

            # warn_explicit below needs a filename.
            assert self.module_path is not None
            warnings.warn_explicit(
                PytestAssertRewriteWarning(
                    "assertion is always true, perhaps remove parentheses?"
                ),
                category=None,
                filename=self.module_path,
                lineno=assert_.lineno,
            )

        # Reset the per-assert state (see the class docstring).
        self.statements: list[ast.stmt] = []
        self.variables: list[str] = []
        self.variable_counter = itertools.count()

        if self.enable_assertion_pass_hook:
            self.format_variables: list[str] = []

        self.stack: list[dict[str, ast.expr]] = []
        self.expl_stmts: list[ast.stmt] = []
        self.push_format_context()

        # Rewrite the test.  The visitors append the evaluation statements to
        # self.statements and return the final condition plus a %-format
        # explanation template.
        top_condition, explanation = self.visit(assert_.test)

        negation = ast.UnaryOp(ast.Not(), top_condition)

        if self.enable_assertion_pass_hook:
            msg = self.pop_format_context(ast.Constant(explanation))

            # Failure branch: optional custom message, then the explanation.
            if assert_.msg:
                assertmsg = self.helper("_format_assertmsg", assert_.msg)
                gluestr = "\n>assert "
            else:
                assertmsg = ast.Constant("")
                gluestr = "assert "
            err_explanation = ast.BinOp(ast.Constant(gluestr), ast.Add(), msg)
            err_msg = ast.BinOp(assertmsg, ast.Add(), err_explanation)
            err_name = ast.Name("AssertionError", ast.Load())
            fmt = self.helper("_format_explanation", err_msg)
            exc = ast.Call(err_name, [fmt], [])
            raise_ = ast.Raise(exc, None)
            statements_fail = []
            statements_fail.extend(self.expl_stmts)
            statements_fail.append(raise_)

            # Pass branch: report the original assert source text and the
            # formatted explanation to the pass hook.
            fmt_pass = self.helper("_format_explanation", msg)
            orig = _get_assertion_exprs(self.source)[assert_.lineno]
            hook_call_pass = ast.Expr(
                self.helper(
                    "_call_assertion_pass",
                    ast.Constant(assert_.lineno),
                    ast.Constant(orig),
                    fmt_pass,
                )
            )

            # Only build the explanation when a hook implementation exists.
            hook_impl_test = ast.If(
                self.helper("_check_if_assertion_pass_impl"),
                [*self.expl_stmts, hook_call_pass],
                [],
            )
            statements_pass: list[ast.stmt] = [hook_impl_test]

            # Test for assertion condition
            main_test = ast.If(negation, statements_fail, statements_pass)
            self.statements.append(main_test)
            if self.format_variables:
                # Reset the @py_format* names to None after the check.
                variables: list[ast.expr] = [
                    ast.Name(name, ast.Store()) for name in self.format_variables
                ]
                clear_format = ast.Assign(variables, ast.Constant(None))
                self.statements.append(clear_format)

        else:
            # Failure-only rewriting.  ``body`` is the very list object
            # self.expl_stmts, so the formatting assignment that
            # pop_format_context() appends below also lands in the if-body,
            # ahead of the raise appended last.
            body = self.expl_stmts
            self.statements.append(ast.If(negation, body, []))
            if assert_.msg:
                assertmsg = self.helper("_format_assertmsg", assert_.msg)
                explanation = "\n>assert " + explanation
            else:
                assertmsg = ast.Constant("")
                explanation = "assert " + explanation
            template = ast.BinOp(assertmsg, ast.Add(), ast.Constant(explanation))
            msg = self.pop_format_context(template)
            fmt = self.helper("_format_explanation", msg)
            err_name = ast.Name("AssertionError", ast.Load())
            exc = ast.Call(err_name, [fmt], [])
            raise_ = ast.Raise(exc, None)

            body.append(raise_)

        # Clear temporary variables by setting them to None.
        if self.variables:
            variables = [ast.Name(name, ast.Store()) for name in self.variables]
            clear = ast.Assign(variables, ast.Constant(None))
            self.statements.append(clear)

        # Give every generated node that lacks a location the assert's
        # location; nodes that already have one keep it.
        for stmt in self.statements:
            for node in traverse_node(stmt):
                if getattr(node, "lineno", None) is None:
                    ast.copy_location(node, assert_)
        return self.statements
|
|
|
|
|
def visit_NamedExpr(self, name: ast.NamedExpr) -> tuple[ast.NamedExpr, str]: |
|
|
|
|
|
|
|
|
|
|
|
locs = ast.Call(self.builtin("locals"), [], []) |
|
|
target_id = name.target.id |
|
|
inlocs = ast.Compare(ast.Constant(target_id), [ast.In()], [locs]) |
|
|
dorepr = self.helper("_should_repr_global_name", name) |
|
|
test = ast.BoolOp(ast.Or(), [inlocs, dorepr]) |
|
|
expr = ast.IfExp(test, self.display(name), ast.Constant(target_id)) |
|
|
return name, self.explanation_param(expr) |
|
|
|
|
|
def visit_Name(self, name: ast.Name) -> tuple[ast.Name, str]:
    """Explain a bare name reference.

    The node is returned unchanged. Its explanation shows the displayed value
    when, at runtime, the name is in ``locals()`` or the
    ``_should_repr_global_name`` helper returns true for it; otherwise the
    explanation is just the identifier.
    """
    ident = name.id
    in_locals = ast.Compare(
        ast.Constant(ident),
        [ast.In()],
        [ast.Call(self.builtin("locals"), [], [])],
    )
    repr_ok = self.helper("_should_repr_global_name", name)
    condition = ast.BoolOp(ast.Or(), [in_locals, repr_ok])
    return name, self.explanation_param(
        ast.IfExp(condition, self.display(name), ast.Constant(ident))
    )
|
|
|
|
|
def visit_BoolOp(self, boolop: ast.BoolOp) -> tuple[ast.Name, str]:
    """Rewrite an ``and``/``or`` chain while keeping short-circuit evaluation.

    Every operand's rewritten result is assigned to the same temporary
    ``res_var``. Each operand after the first is evaluated only inside an
    ``if`` on the previous result: ``if res`` for ``and``, ``if not res`` for
    ``or``. On the explanation path, one formatted explanation per operand is
    appended to a runtime list, nested under the same conditions. That list
    is finally passed to the ``_format_boolop`` helper together with the
    ``is_or`` flag.

    Returns a load of ``res_var`` and the string produced by
    ``self.explanation_param`` for the formatted explanation.
    """
    res_var = self.variable()
    # Runtime list collecting the explanation of each operand that was reached.
    expl_list = self.assign(ast.List([], ast.Load()))
    app = ast.Attribute(expl_list, "append", ast.Load())
    is_or = int(isinstance(boolop.op, ast.Or))
    # ``body`` receives the current operand's assignment; ``save`` and
    # ``fail_save`` let the outer statement lists be restored at the end.
    body = save = self.statements
    fail_save = self.expl_stmts
    levels = len(boolop.values) - 1
    self.push_format_context()
    for i, v in enumerate(boolop.values):
        if i:
            fail_inner: list[ast.stmt] = []
            # ``cond`` was set at the end of the previous iteration, so this
            # operand's explanation is only recorded if the operand was reached.
            self.expl_stmts.append(ast.If(cond, fail_inner, []))
            self.expl_stmts = fail_inner
            # If this operand is a comparison whose left side is a walrus that
            # rebinds a name used as a bare operand (a node with ``id``) earlier
            # in the chain, map that name to the walrus node in
            # variables_overwrite, then rename the walrus target to a fresh
            # temporary.
            if (
                isinstance(v, ast.Compare)
                and isinstance(v.left, ast.NamedExpr)
                and v.left.target.id
                in [
                    ast_expr.id
                    for ast_expr in boolop.values[:i]
                    if hasattr(ast_expr, "id")
                ]
            ):
                pytest_temp = self.variable()
                self.variables_overwrite[self.scope][v.left.target.id] = v.left
                v.left.target.id = pytest_temp
        self.push_format_context()
        res, expl = self.visit(v)
        body.append(ast.Assign([ast.Name(res_var, ast.Store())], res))
        expl_format = self.pop_format_context(ast.Constant(expl))
        call = ast.Call(app, [expl_format], [])
        self.expl_stmts.append(ast.Expr(call))
        if i < levels:
            # Move on to the next operand only if this one did not already
            # decide the result: truthy for ``and``, falsy for ``or``.
            cond: ast.expr = res
            if is_or:
                cond = ast.UnaryOp(ast.Not(), cond)
            inner: list[ast.stmt] = []
            self.statements.append(ast.If(cond, inner, []))
            self.statements = body = inner
    self.statements = save
    self.expl_stmts = fail_save
    expl_template = self.helper("_format_boolop", expl_list, ast.Constant(is_or))
    expl = self.pop_format_context(expl_template)
    return ast.Name(res_var, ast.Load()), self.explanation_param(expl)
|
|
|
|
|
def visit_UnaryOp(self, unary: ast.UnaryOp) -> tuple[ast.Name, str]:
    """Rewrite a unary operation so that its result is stored in a temporary.

    The operand is visited first. The operation is then rebuilt on the
    rewritten operand, keeping the original source location, and assigned to
    a temporary. The explanation is the ``UNARY_MAP`` %-pattern for the
    operator, applied to the operand's explanation.
    """
    template = UNARY_MAP[unary.op.__class__]
    operand, operand_text = self.visit(unary.operand)
    rebuilt = ast.UnaryOp(unary.op, operand)
    temp = self.assign(ast.copy_location(rebuilt, unary))
    return temp, template % (operand_text,)
|
|
|
|
|
def visit_BinOp(self, binop: ast.BinOp) -> tuple[ast.Name, str]:
    """Rewrite a binary operation so that its result is stored in a temporary.

    Both operands are visited left to right. The operation is rebuilt on the
    rewritten operands, keeping the original source location, and assigned to
    a temporary. The explanation is ``(<left> <symbol> <right>)``, with the
    operator symbol taken from ``BINOP_MAP``.
    """
    op_symbol = BINOP_MAP[binop.op.__class__]
    lhs, lhs_expl = self.visit(binop.left)
    rhs, rhs_expl = self.visit(binop.right)
    rebuilt = ast.BinOp(lhs, binop.op, rhs)
    temp = self.assign(ast.copy_location(rebuilt, binop))
    return temp, f"({lhs_expl} {op_symbol} {rhs_expl})"
|
|
|
|
|
def visit_Call(self, call: ast.Call) -> tuple[ast.Name, str]:
    """Rewrite a function call so that its result is stored in a temporary.

    The callee is visited first, then each positional argument, then each
    keyword argument. A bare-name argument for which a walrus expression was
    recorded in ``self.variables_overwrite`` for the current scope is replaced
    by that expression before it is visited. For keyword arguments the
    replacement is also written back onto the original ``ast.keyword``. The
    explanation shows the call's displayed result followed by the rendered
    call.
    """

    def resolve(node: ast.expr) -> ast.expr:
        # Look the scope up on every call: visiting earlier arguments may
        # have recorded new overwrites.
        if isinstance(node, ast.Name):
            recorded = self.variables_overwrite.get(self.scope, {})
            if node.id in recorded:
                return recorded[node.id]
        return node

    callee, callee_expl = self.visit(call.func)

    rendered_args: list[str] = []
    positional: list[ast.expr] = []
    for original_arg in call.args:
        rewritten, text = self.visit(resolve(original_arg))
        rendered_args.append(text)
        positional.append(rewritten)

    keywords: list[ast.keyword] = []
    for kw in call.keywords:
        kw.value = resolve(kw.value)
        rewritten, text = self.visit(kw.value)
        keywords.append(ast.keyword(kw.arg, rewritten))
        # ``**mapping`` unpacking appears as a keyword whose ``arg`` is None.
        rendered_args.append(f"{kw.arg}={text}" if kw.arg else f"**{text}")

    call_text = f"{callee_expl}({', '.join(rendered_args)})"
    temp = self.assign(ast.copy_location(ast.Call(callee, positional, keywords), call))
    shown = self.explanation_param(self.display(temp))
    return temp, f"{shown}\n{{{shown} = {call_text}\n}}"
|
|
|
|
|
def visit_Starred(self, starred: ast.Starred) -> tuple[ast.Starred, str]:
    """Rewrite a ``*value`` node, such as a starred call argument.

    The starred value is visited and wrapped in a fresh ``ast.Starred`` that
    reuses the original context object. The explanation is the value's
    explanation prefixed with ``*``.
    """
    inner, inner_text = self.visit(starred.value)
    return ast.Starred(inner, starred.ctx), f"*{inner_text}"
|
|
|
|
|
def visit_Attribute(self, attr: ast.Attribute) -> tuple[ast.Name, str]:
    """Rewrite an attribute load so that its result is stored in a temporary.

    Only ``Load`` contexts are rewritten: the owner expression is visited, the
    attribute access is rebuilt on it (keeping the source location) and
    assigned to a temporary. The explanation shows the displayed result
    followed by ``<owner>.<attr>``. Any other context is delegated unchanged
    to ``self.generic_visit``.
    """
    if isinstance(attr.ctx, ast.Load):
        owner, owner_text = self.visit(attr.value)
        loaded = ast.Attribute(owner, attr.attr, ast.Load())
        temp = self.assign(ast.copy_location(loaded, attr))
        shown = self.explanation_param(self.display(temp))
        return temp, "%s\n{%s = %s.%s\n}" % (shown, shown, owner_text, attr.attr)
    return self.generic_visit(attr)
|
|
|
|
|
def visit_Compare(self, comp: ast.Compare) -> tuple[ast.expr, str]:
    """Rewrite a (possibly chained) comparison into per-pair temporaries.

    For ``a op1 b op2 c ...`` each pairwise comparison is assigned to its own
    temporary. The returned expression is that temporary for a single
    operator, or an ``and`` of all temporaries for a chain. The explanation is
    produced at runtime by the ``_call_reprcompare`` helper, using the operator
    symbols, the pairwise results, the per-pair explanation strings and the
    rewritten operands.
    """
    self.push_format_context()
    # If the left operand is a bare name for which a walrus expression was
    # recorded in this scope, substitute that walrus expression.
    if isinstance(
        comp.left, ast.Name
    ) and comp.left.id in self.variables_overwrite.get(self.scope, {}):
        comp.left = self.variables_overwrite[self.scope][comp.left.id]
    # Record a walrus on the left under its target name for later lookups.
    if isinstance(comp.left, ast.NamedExpr):
        self.variables_overwrite[self.scope][comp.left.target.id] = comp.left
    left_res, left_expl = self.visit(comp.left)
    # Parenthesize nested comparisons / boolean operations in the text.
    if isinstance(comp.left, (ast.Compare, ast.BoolOp)):
        left_expl = f"({left_expl})"
    # One temporary per operator, each holding one pairwise result.
    res_variables = [self.variable() for i in range(len(comp.ops))]
    load_names: list[ast.expr] = [ast.Name(v, ast.Load()) for v in res_variables]
    store_names = [ast.Name(v, ast.Store()) for v in res_variables]
    it = zip(range(len(comp.ops)), comp.ops, comp.comparators)
    expls: list[ast.expr] = []
    syms: list[ast.expr] = []
    results = [left_res]
    # NOTE(review): each pairwise comparison is appended to self.statements
    # with no surrounding condition. Unlike Python's own chained comparison,
    # later pairs are therefore still computed when an earlier pair is false.
    # Confirm this is intended.
    for i, op, next_operand in it:
        # A walrus on the right that rebinds the name currently on the left
        # gets a fresh target name, and the old name is mapped to this walrus
        # node in variables_overwrite.
        if (
            isinstance(next_operand, ast.NamedExpr)
            and isinstance(left_res, ast.Name)
            and next_operand.target.id == left_res.id
        ):
            next_operand.target.id = self.variable()
            self.variables_overwrite[self.scope][left_res.id] = next_operand
        next_res, next_expl = self.visit(next_operand)
        if isinstance(next_operand, (ast.Compare, ast.BoolOp)):
            next_expl = f"({next_expl})"
        results.append(next_res)
        sym = BINOP_MAP[op.__class__]
        syms.append(ast.Constant(sym))
        expl = f"{left_expl} {sym} {next_expl}"
        expls.append(ast.Constant(expl))
        # <temp_i> = <left> <op> <right>, located at the original comparison.
        res_expr = ast.copy_location(ast.Compare(left_res, [op], [next_res]), comp)
        self.statements.append(ast.Assign([store_names[i]], res_expr))
        # This pair's right operand becomes the next pair's left operand.
        left_res, left_expl = next_res, next_expl
    # The explanation is rendered at runtime by the _call_reprcompare helper.
    expl_call = self.helper(
        "_call_reprcompare",
        ast.Tuple(syms, ast.Load()),
        ast.Tuple(load_names, ast.Load()),
        ast.Tuple(expls, ast.Load()),
        ast.Tuple(results, ast.Load()),
    )
    if len(comp.ops) > 1:
        # Chained comparison: true only if every pairwise result is true.
        res: ast.expr = ast.BoolOp(ast.And(), load_names)
    else:
        res = load_names[0]

    return res, self.explanation_param(self.pop_format_context(expl_call))
|
|
|
|
|
|
|
|
def try_makedirs(cache_dir: Path) -> bool:
    """Create *cache_dir*, including any missing parent directories.

    Returns True when the directory was created or already existed. Returns
    False instead of raising when it cannot be created because:

    * a path component is missing or is not a directory, or *cache_dir*
      exists as a non-directory (``FileNotFoundError``,
      ``NotADirectoryError``, ``FileExistsError``);
    * permission is denied (``PermissionError``);
    * the OS reports ``EROFS`` (read-only file system) or ``ENOSYS``.

    Any other ``OSError`` is re-raised.
    """
    try:
        os.makedirs(cache_dir, exist_ok=True)
    except (
        FileNotFoundError,
        NotADirectoryError,
        FileExistsError,
        PermissionError,
    ):
        return False
    except OSError as exc:
        # EROFS has no dedicated OSError subclass, so it is matched by errno;
        # ENOSYS is treated the same way.
        if exc.errno not in {errno.EROFS, errno.ENOSYS}:
            raise
        return False
    return True
|
|
|
|
|
|
|
|
def get_cache_dir(file_path: Path) -> Path:
    """Return the directory where the .pyc for the ``.py`` file *file_path* goes.

    Without ``sys.pycache_prefix`` this is the ``__pycache__`` directory next
    to the source file. With a prefix set, the source file's directory is
    mirrored under the prefix; the first path component and the file name are
    dropped. For example, prefix ``/tmp/pycs`` and source
    ``/home/user/proj/test_app.py`` give ``/tmp/pycs/home/user/proj``.
    """
    prefix = sys.pycache_prefix
    if not prefix:
        return file_path.parent / "__pycache__"
    # NOTE(review): parts[1:-1] assumes file_path is absolute (parts[0] is the
    # root/anchor). For a relative path, its first directory is dropped too.
    return Path(prefix) / Path(*file_path.parts[1:-1])
|
|
|