| | """Rewrite assertion AST to produce nice error messages.""" |
| |
|
| | from __future__ import annotations |
| |
|
| | import ast |
| | from collections import defaultdict |
| | import errno |
| | import functools |
| | import importlib.abc |
| | import importlib.machinery |
| | import importlib.util |
| | import io |
| | import itertools |
| | import marshal |
| | import os |
| | from pathlib import Path |
| | from pathlib import PurePath |
| | import struct |
| | import sys |
| | import tokenize |
| | import types |
| | from typing import Callable |
| | from typing import IO |
| | from typing import Iterable |
| | from typing import Iterator |
| | from typing import Sequence |
| | from typing import TYPE_CHECKING |
| |
|
| | from _pytest._io.saferepr import DEFAULT_REPR_MAX_SIZE |
| | from _pytest._io.saferepr import saferepr |
| | from _pytest._version import version |
| | from _pytest.assertion import util |
| | from _pytest.config import Config |
| | from _pytest.main import Session |
| | from _pytest.pathlib import absolutepath |
| | from _pytest.pathlib import fnmatch_ex |
| | from _pytest.stash import StashKey |
| |
|
| |
|
| | |
| | from _pytest.assertion.util import format_explanation as _format_explanation |
| | |
| |
|
| | if TYPE_CHECKING: |
| | from _pytest.assertion import AssertionState |
| |
|
| |
|
class Sentinel:
    """Empty marker type; its instances serve as unique placeholder objects."""
| |
|
| |
|
# Key under which the assertion state object is looked up in ``config.stash``.
assertstate_key = StashKey["AssertionState"]()

# Rewritten modules are cached in pyc files whose name carries both the
# interpreter cache tag and the pytest version.
PYTEST_TAG = f"{sys.implementation.cache_tag}-pytest-{version}"
PYC_EXT = ".pyc" if __debug__ else ".pyo"
PYC_TAIL = f".{PYTEST_TAG}{PYC_EXT}"

# Pushed onto the work stack of AssertionRewriter.run() to mark where the
# traversal of a function/class scope ends.
_SCOPE_END_MARKER = Sentinel()
| |
|
| |
|
class AssertionRewritingHook(importlib.abc.MetaPathFinder, importlib.abc.Loader):
    """PEP302/PEP451 import hook which rewrites asserts."""

    def __init__(self, config: Config) -> None:
        """Initialise the hook's bookkeeping for the given *config*."""
        self.config = config
        try:
            patterns = config.getini("python_files")
        except ValueError:
            # getini() raised: use the built-in test file patterns instead.
            patterns = ["test_*.py", "*_test.py"]
        self.fnpats = patterns
        self.session: Session | None = None
        # Module name -> source path, recorded by exec_module().
        self._rewritten_names: dict[str, Path] = {}
        # Import names registered through mark_rewrite().
        self._must_rewrite: set[str] = set()
        # While True, find_spec() declines every import (set around pyc writes).
        self._writing_pyc = False
        # Final module-name components that are never bailed out early.
        self._basenames_to_check_rewrite = {"conftest"}
        # Memoised answers of _is_marked_for_rewrite(), reset by mark_rewrite().
        self._marked_for_rewrite_cache: dict[str, bool] = {}
        # Whether the session's initial paths were already inspected.
        self._session_paths_checked = False

    def set_session(self, session: Session | None) -> None:
        """Store *session* and schedule its initial paths for re-inspection."""
        self._session_paths_checked = False
        self.session = session

    # Module lookup itself is delegated to the standard path-based finder.
    _find_spec = importlib.machinery.PathFinder.find_spec

    def find_spec(
        self,
        name: str,
        path: Sequence[str | bytes] | None = None,
        target: types.ModuleType | None = None,
    ) -> importlib.machinery.ModuleSpec | None:
        """Return a spec loaded by this hook if *name* should be rewritten.

        ``None`` is returned whenever the module is not ours to handle, so
        that the import system falls through to the next finder.
        """
        if self._writing_pyc:
            return None
        state = self.config.stash[assertstate_key]
        if self._early_rewrite_bailout(name, state):
            return None
        state.trace(f"find_module called for: {name}")

        spec = self._find_spec(name, path)

        if spec is None and path is not None:
            # The regular finder came back empty: try each entry of *path*
            # directly as a file location and keep the first hit.
            for _path_str in path:
                spec = importlib.util.spec_from_file_location(name, _path_str)
                if spec is not None:
                    break

        # Only existing files that the plain source loader would handle qualify.
        if spec is None or spec.origin is None:
            return None
        if not isinstance(spec.loader, importlib.machinery.SourceFileLoader):
            return None
        if not os.path.exists(spec.origin):
            return None

        origin = spec.origin
        if not self._should_rewrite(name, origin, state):
            return None

        return importlib.util.spec_from_file_location(
            name,
            origin,
            loader=self,
            submodule_search_locations=spec.submodule_search_locations,
        )

    def create_module(
        self, spec: importlib.machinery.ModuleSpec
    ) -> types.ModuleType | None:
        """Return ``None`` so the default module creation is used."""
        return None

    def exec_module(self, module: types.ModuleType) -> None:
        """Execute *module* from its rewritten code, using the pyc cache if valid."""
        module_spec = module.__spec__
        assert module_spec is not None
        assert module_spec.origin is not None
        fn = Path(module_spec.origin)
        state = self.config.stash[assertstate_key]

        self._rewritten_names[module.__name__] = fn

        # Caching is skipped when bytecode writing is disabled or when the
        # cache directory cannot be created.
        write = not sys.dont_write_bytecode
        cache_dir = get_cache_dir(fn)
        if write and not try_makedirs(cache_dir):
            write = False
            state.trace(f"read only directory: {cache_dir}")

        # Drop the last three characters of the file name (the ".py" suffix,
        # presumably) before appending the pytest-specific pyc tail.
        pyc = cache_dir / (fn.name[:-3] + PYC_TAIL)
        co = _read_pyc(fn, pyc, state.trace)
        if co is not None:
            state.trace(f"found cached rewritten pyc for {fn}")
        else:
            state.trace(f"rewriting {fn!r}")
            source_stat, co = _rewrite_test(fn, self.config)
            if write:
                # Raise the guard so find_spec() ignores imports triggered
                # while the cache file is being written.
                self._writing_pyc = True
                try:
                    _write_pyc(state, co, source_stat, pyc)
                finally:
                    self._writing_pyc = False
        exec(co, module.__dict__)

    def _early_rewrite_bailout(self, name: str, state: AssertionState) -> bool:
        """Cheaply decide that *name* will certainly not be rewritten.

        The PathFinder lookup performed by find_spec() was measured to be
        a major slowdown, so modules that cannot possibly qualify are
        filtered out here before that lookup happens.

        Returns True when the module can be skipped.
        """
        if self.session is not None and not self._session_paths_checked:
            self._session_paths_checked = True
            for initial_path in self.session._initialpaths:
                # Keep only the last path component, without its extension.
                last_component = str(initial_path).split(os.sep)[-1]
                stem, _ext = os.path.splitext(last_component)
                self._basenames_to_check_rewrite.add(stem)

        parts = name.split(".")
        if parts[-1] in self._basenames_to_check_rewrite:
            return False

        # Turn the dotted module name into the relative path it would have.
        path = PurePath(*parts).with_suffix(".py")

        for pat in self.fnpats:
            # A pattern with a directory part cannot be judged from the module
            # name alone, so give up on the shortcut; otherwise a match means
            # the module may be a test file.
            if os.path.dirname(pat) or fnmatch_ex(pat, path):
                return False

        if self._is_marked_for_rewrite(name, state):
            return False

        state.trace(f"early skip of rewriting module: {name}")
        return True

    def _should_rewrite(self, name: str, fn: str, state: AssertionState) -> bool:
        """Tell whether the module *name* located at file *fn* must be rewritten."""
        if os.path.basename(fn) == "conftest.py":
            state.trace(f"rewriting conftest file: {fn!r}")
            return True

        if self.session is not None and self.session.isinitpath(absolutepath(fn)):
            state.trace(f"matched test file (was specified on cmdline): {fn!r}")
            return True

        fn_path = PurePath(fn)
        if any(fnmatch_ex(pat, fn_path) for pat in self.fnpats):
            state.trace(f"matched test file {fn!r}")
            return True

        return self._is_marked_for_rewrite(name, state)

    def _is_marked_for_rewrite(self, name: str, state: AssertionState) -> bool:
        """Tell whether *name* is, or lives below, a name given to mark_rewrite()."""
        cache = self._marked_for_rewrite_cache
        if name in cache:
            return cache[name]

        for marked in self._must_rewrite:
            if name == marked or name.startswith(marked + "."):
                state.trace(f"matched marked file {name!r} (from {marked!r})")
                cache[name] = True
                return True

        cache[name] = False
        return False

    def mark_rewrite(self, *names: str) -> None:
        """Mark import names as needing to be rewritten.

        The named module or package as well as any nested modules will
        be rewritten on import.
        """
        requested = set(names)
        imported = requested.intersection(sys.modules)
        already_imported = imported.difference(self._rewritten_names)
        for name in already_imported:
            mod = sys.modules[name]
            if AssertionRewriter.is_rewrite_disabled(mod.__doc__ or ""):
                continue
            if isinstance(mod.__loader__, type(self)):
                continue
            # Imported through some other loader and not opted out.
            self._warn_already_imported(name)
        self._must_rewrite.update(names)
        # Cached answers may be stale now that new names were registered.
        self._marked_for_rewrite_cache.clear()

    def _warn_already_imported(self, name: str) -> None:
        """Issue a config-time warning that *name* was imported too early."""
        from _pytest.warning_types import PytestAssertRewriteWarning

        warning = PytestAssertRewriteWarning(
            f"Module already imported so cannot be rewritten: {name}"
        )
        self.config.issue_config_time_warning(warning, stacklevel=5)

    def get_data(self, pathname: str | bytes) -> bytes:
        """Optional PEP302 get_data API."""
        with open(pathname, "rb") as stream:
            return stream.read()

    if sys.version_info >= (3, 10):
        if sys.version_info >= (3, 12):
            from importlib.resources.abc import TraversableResources
        else:
            from importlib.abc import TraversableResources

        def get_resource_reader(self, name: str) -> TraversableResources:
            """Return a file-based resource reader for the rewritten module *name*."""
            if sys.version_info < (3, 11):
                from importlib.readers import FileReader
            else:
                from importlib.resources.readers import FileReader

            # FileReader is handed an object exposing the source path as ``path``.
            location = types.SimpleNamespace(path=self._rewritten_names[name])
            return FileReader(location)
| |
|
| |
|
| | def _write_pyc_fp( |
| | fp: IO[bytes], source_stat: os.stat_result, co: types.CodeType |
| | ) -> None: |
| | |
| | |
| | |
| | fp.write(importlib.util.MAGIC_NUMBER) |
| | |
| | flags = b"\x00\x00\x00\x00" |
| | fp.write(flags) |
| | |
| | mtime = int(source_stat.st_mtime) & 0xFFFFFFFF |
| | size = source_stat.st_size & 0xFFFFFFFF |
| | |
| | fp.write(struct.pack("<LL", mtime, size)) |
| | fp.write(marshal.dumps(co)) |
| |
|
| |
|
def _write_pyc(
    state: AssertionState,
    co: types.CodeType,
    source_stat: os.stat_result,
    pyc: Path,
) -> bool:
    """Write the rewritten code object *co* to the cache file *pyc*.

    The data is first written to a per-process temporary file next to *pyc*
    and then moved into place with ``os.replace``. Failures are traced via
    ``state.trace`` and never raised: caching is best effort.

    Returns True if *pyc* was written, False otherwise.
    """
    # The pid suffix keeps concurrent processes from sharing a temporary file.
    proc_pyc = f"{pyc}.{os.getpid()}"
    try:
        with open(proc_pyc, "wb") as fp:
            _write_pyc_fp(fp, source_stat, co)
    except OSError as e:
        state.trace(f"error writing pyc file at {proc_pyc}: errno={e.errno}")
        # A partially written temporary file may exist; do not leave it behind.
        _discard_temp_pyc(proc_pyc)
        return False

    try:
        os.replace(proc_pyc, pyc)
    except OSError as e:
        state.trace(f"error writing pyc file at {pyc}: {e}")
        # Failing to publish the cache file is ignored, but the temporary
        # file would otherwise accumulate in the cache directory.
        _discard_temp_pyc(proc_pyc)
        return False
    return True


def _discard_temp_pyc(path: str) -> None:
    """Remove the temporary pyc file at *path*, ignoring any OS error."""
    try:
        os.unlink(path)
    except OSError:
        pass
| |
|
| |
|
def _rewrite_test(fn: Path, config: Config) -> tuple[os.stat_result, types.CodeType]:
    """Read and rewrite *fn* and return the code object.

    The stat result of *fn*, taken before its contents are read, is
    returned alongside the compiled code.
    """
    stat = os.stat(fn)
    source = fn.read_bytes()
    filename = str(fn)
    tree = ast.parse(source, filename=filename)
    rewrite_asserts(tree, source, filename, config)
    return stat, compile(tree, filename, "exec", dont_inherit=True)
| |
|
| |
|
| | def _read_pyc( |
| | source: Path, pyc: Path, trace: Callable[[str], None] = lambda x: None |
| | ) -> types.CodeType | None: |
| | """Possibly read a pytest pyc containing rewritten code. |
| | |
| | Return rewritten code if successful or None if not. |
| | """ |
| | try: |
| | fp = open(pyc, "rb") |
| | except OSError: |
| | return None |
| | with fp: |
| | try: |
| | stat_result = os.stat(source) |
| | mtime = int(stat_result.st_mtime) |
| | size = stat_result.st_size |
| | data = fp.read(16) |
| | except OSError as e: |
| | trace(f"_read_pyc({source}): OSError {e}") |
| | return None |
| | |
| | if len(data) != (16): |
| | trace(f"_read_pyc({source}): invalid pyc (too short)") |
| | return None |
| | if data[:4] != importlib.util.MAGIC_NUMBER: |
| | trace(f"_read_pyc({source}): invalid pyc (bad magic number)") |
| | return None |
| | if data[4:8] != b"\x00\x00\x00\x00": |
| | trace(f"_read_pyc({source}): invalid pyc (unsupported flags)") |
| | return None |
| | mtime_data = data[8:12] |
| | if int.from_bytes(mtime_data, "little") != mtime & 0xFFFFFFFF: |
| | trace(f"_read_pyc({source}): out of date") |
| | return None |
| | size_data = data[12:16] |
| | if int.from_bytes(size_data, "little") != size & 0xFFFFFFFF: |
| | trace(f"_read_pyc({source}): invalid pyc (incorrect size)") |
| | return None |
| | try: |
| | co = marshal.load(fp) |
| | except Exception as e: |
| | trace(f"_read_pyc({source}): marshal.load error {e}") |
| | return None |
| | if not isinstance(co, types.CodeType): |
| | trace(f"_read_pyc({source}): not a code object") |
| | return None |
| | return co |
| |
|
| |
|
def rewrite_asserts(
    mod: ast.Module,
    source: bytes,
    module_path: str | None = None,
    config: Config | None = None,
) -> None:
    """Rewrite the assert statements in mod.

    *mod* is modified in place by an AssertionRewriter built from the
    remaining arguments.
    """
    rewriter = AssertionRewriter(module_path, config, source)
    rewriter.run(mod)
| |
|
| |
|
| | def _saferepr(obj: object) -> str: |
| | r"""Get a safe repr of an object for assertion error messages. |
| | |
| | The assertion formatting (util.format_explanation()) requires |
| | newlines to be escaped since they are a special character for it. |
| | Normally assertion.util.format_explanation() does this but for a |
| | custom repr it is possible to contain one of the special escape |
| | sequences, especially '\n{' and '\n}' are likely to be present in |
| | JSON reprs. |
| | """ |
| | if isinstance(obj, types.MethodType): |
| | |
| | return obj.__name__ |
| |
|
| | maxsize = _get_maxsize_for_saferepr(util._config) |
| | return saferepr(obj, maxsize=maxsize).replace("\n", "\\n") |
| |
|
| |
|
def _get_maxsize_for_saferepr(config: Config | None) -> int | None:
    """Get `maxsize` configuration for saferepr based on the given config object.

    Without a config the assertion verbosity is taken to be 0. Verbosity 0
    gives the default size, verbosity 1 ten times the default, and
    verbosity 2 or more gives None.
    """
    verbosity = (
        0 if config is None else config.get_verbosity(Config.VERBOSITY_ASSERTIONS)
    )
    if verbosity < 1:
        return DEFAULT_REPR_MAX_SIZE
    if verbosity < 2:
        return DEFAULT_REPR_MAX_SIZE * 10
    return None
| |
|
| |
|
| | def _format_assertmsg(obj: object) -> str: |
| | r"""Format the custom assertion message given. |
| | |
| | For strings this simply replaces newlines with '\n~' so that |
| | util.format_explanation() will preserve them instead of escaping |
| | newlines. For other objects saferepr() is used first. |
| | """ |
| | |
| | |
| | |
| | |
| | replaces = [("\n", "\n~"), ("%", "%%")] |
| | if not isinstance(obj, str): |
| | obj = saferepr(obj, _get_maxsize_for_saferepr(util._config)) |
| | replaces.append(("\\n", "\n~")) |
| |
|
| | for r1, r2 in replaces: |
| | obj = obj.replace(r1, r2) |
| |
|
| | return obj |
| |
|
| |
|
| | def _should_repr_global_name(obj: object) -> bool: |
| | if callable(obj): |
| | return False |
| |
|
| | try: |
| | return not hasattr(obj, "__name__") |
| | except Exception: |
| | return True |
| |
|
| |
|
| | def _format_boolop(explanations: Iterable[str], is_or: bool) -> str: |
| | explanation = "(" + (is_or and " or " or " and ").join(explanations) + ")" |
| | return explanation.replace("%", "%%") |
| |
|
| |
|
def _call_reprcompare(
    ops: Sequence[str],
    results: Sequence[bool],
    expls: Sequence[str],
    each_obj: Sequence[object],
) -> str:
    """Return the explanation for the first falsy step of a comparison chain.

    The steps are scanned in order until one whose result is falsy (or
    whose truth value cannot be computed); if none is, the last step is
    used. When ``util._reprcompare`` is set and returns something other
    than None for that step's operator and operands, that value is
    returned instead of the step's own explanation.
    """
    for i, (_op, res, expl) in enumerate(zip(ops, results, expls)):
        try:
            failed = not res
        except Exception:
            # Evaluating the truth value itself blew up: treat as failing.
            failed = True
        if failed:
            break
    reprcompare = util._reprcompare
    if reprcompare is not None:
        custom = reprcompare(ops[i], each_obj[i], each_obj[i + 1])
        if custom is not None:
            return custom
    return expl
| |
|
| |
|
def _call_assertion_pass(lineno: int, orig: str, expl: str) -> None:
    """Forward the arguments to ``util._assertion_pass`` when it is set."""
    callback = util._assertion_pass
    if callback is None:
        return
    callback(lineno, orig, expl)
| |
|
| |
|
def _check_if_assertion_pass_impl() -> bool:
    """Check if any plugins implement the pytest_assertion_pass hook
    in order not to generate explanation unnecessarily (might be expensive)."""
    return bool(util._assertion_pass)
| |
|
| |
|
# Unary operator node type -> %-template taking the operand's explanation.
UNARY_MAP = {
    ast.Not: "not %s",
    ast.Invert: "~%s",
    ast.USub: "-%s",
    ast.UAdd: "+%s",
}

# Binary and comparison operator node type -> operator symbol.
BINOP_MAP = {
    ast.BitOr: "|",
    ast.BitXor: "^",
    ast.BitAnd: "&",
    ast.LShift: "<<",
    ast.RShift: ">>",
    ast.Add: "+",
    ast.Sub: "-",
    ast.Mult: "*",
    ast.Div: "/",
    ast.FloorDiv: "//",
    # Doubled, presumably because the symbol ends up inside a %-format
    # template (explanations are %-formatted) -- confirm against the users
    # of this map.
    ast.Mod: "%%",
    ast.Eq: "==",
    ast.NotEq: "!=",
    ast.Lt: "<",
    ast.LtE: "<=",
    ast.Gt: ">",
    ast.GtE: ">=",
    ast.Pow: "**",
    ast.Is: "is",
    ast.IsNot: "is not",
    ast.In: "in",
    ast.NotIn: "not in",
    ast.MatMult: "@",
}
| |
|
| |
|
def traverse_node(node: ast.AST) -> Iterator[ast.AST]:
    """Recursively yield node and all its children in depth-first order."""
    yield node
    # Stack of child iterators, innermost last; each node is yielded before
    # its own children are looked at.
    pending = [ast.iter_child_nodes(node)]
    while pending:
        # iter_child_nodes() only yields AST nodes, so None means "exhausted".
        child = next(pending[-1], None)
        if child is None:
            pending.pop()
            continue
        yield child
        pending.append(ast.iter_child_nodes(child))
| |
|
| |
|
| | @functools.lru_cache(maxsize=1) |
| | def _get_assertion_exprs(src: bytes) -> dict[int, str]: |
| | """Return a mapping from {lineno: "assertion test expression"}.""" |
| | ret: dict[int, str] = {} |
| |
|
| | depth = 0 |
| | lines: list[str] = [] |
| | assert_lineno: int | None = None |
| | seen_lines: set[int] = set() |
| |
|
| | def _write_and_reset() -> None: |
| | nonlocal depth, lines, assert_lineno, seen_lines |
| | assert assert_lineno is not None |
| | ret[assert_lineno] = "".join(lines).rstrip().rstrip("\\") |
| | depth = 0 |
| | lines = [] |
| | assert_lineno = None |
| | seen_lines = set() |
| |
|
| | tokens = tokenize.tokenize(io.BytesIO(src).readline) |
| | for tp, source, (lineno, offset), _, line in tokens: |
| | if tp == tokenize.NAME and source == "assert": |
| | assert_lineno = lineno |
| | elif assert_lineno is not None: |
| | |
| | if tp == tokenize.OP and source in "([{": |
| | depth += 1 |
| | elif tp == tokenize.OP and source in ")]}": |
| | depth -= 1 |
| |
|
| | if not lines: |
| | lines.append(line[offset:]) |
| | seen_lines.add(lineno) |
| | |
| | elif depth == 0 and tp == tokenize.OP and source == ",": |
| | |
| | if lineno in seen_lines and len(lines) == 1: |
| | offset_in_trimmed = offset + len(lines[-1]) - len(line) |
| | lines[-1] = lines[-1][:offset_in_trimmed] |
| | |
| | elif lineno in seen_lines: |
| | lines[-1] = lines[-1][:offset] |
| | |
| | else: |
| | lines.append(line[:offset]) |
| | _write_and_reset() |
| | elif tp in {tokenize.NEWLINE, tokenize.ENDMARKER}: |
| | _write_and_reset() |
| | elif lines and lineno not in seen_lines: |
| | lines.append(line) |
| | seen_lines.add(lineno) |
| |
|
| | return ret |
| |
|
| |
|
| | class AssertionRewriter(ast.NodeVisitor): |
| | """Assertion rewriting implementation. |
| | |
| | The main entrypoint is to call .run() with an ast.Module instance, |
| | this will then find all the assert statements and rewrite them to |
| | provide intermediate values and a detailed assertion error. See |
| | http://pybites.blogspot.be/2011/07/behind-scenes-of-pytests-new-assertion.html |
| | for an overview of how this works. |
| | |
| | The entry point here is .run() which will iterate over all the |
| | statements in an ast.Module and for each ast.Assert statement it |
| | finds call .visit() with it. Then .visit_Assert() takes over and |
| | is responsible for creating new ast statements to replace the |
| | original assert statement: it rewrites the test of an assertion |
| | to provide intermediate values and replace it with an if statement |
| | which raises an assertion error with a detailed explanation in |
| | case the expression is false and calls pytest_assertion_pass hook |
| | if expression is true. |
| | |
| | For this .visit_Assert() uses the visitor pattern to visit all the |
| | AST nodes of the ast.Assert.test field, each visit call returning |
| | an AST node and the corresponding explanation string. During this |
| | state is kept in several instance attributes: |
| | |
| | :statements: All the AST statements which will replace the assert |
| | statement. |
| | |
| | :variables: This is populated by .variable() with each variable |
| | used by the statements so that they can all be set to None at |
| | the end of the statements. |
| | |
| | :variable_counter: Counter to create new unique variables needed |
| | by statements. Variables are created using .variable() and |
| | have the form of "@py_assert0". |
| | |
| | :expl_stmts: The AST statements which will be executed to get |
| | data from the assertion. This is the code which will construct |
| | the detailed assertion message that is used in the AssertionError |
| | or for the pytest_assertion_pass hook. |
| | |
| | :explanation_specifiers: A dict filled by .explanation_param() |
| | with %-formatting placeholders and their corresponding |
| | expressions to use in the building of an assertion message. |
| | This is used by .pop_format_context() to build a message. |
| | |
| | :stack: A stack of the explanation_specifiers dicts maintained by |
| | .push_format_context() and .pop_format_context() which allows |
| | to build another %-formatted string while already building one. |
| | |
| | :scope: A tuple containing the current scope used for variables_overwrite. |
| | |
| | :variables_overwrite: A dict filled with references to variables |
| | that change value within an assert. This happens when a variable is |
| | reassigned with the walrus operator |
| | |
| | This state, except the variables_overwrite, is reset on every new assert |
| | statement visited and used by the other visitors. |
| | """ |
| |
|
| | def __init__( |
| | self, module_path: str | None, config: Config | None, source: bytes |
| | ) -> None: |
| | super().__init__() |
| | self.module_path = module_path |
| | self.config = config |
| | if config is not None: |
| | self.enable_assertion_pass_hook = config.getini( |
| | "enable_assertion_pass_hook" |
| | ) |
| | else: |
| | self.enable_assertion_pass_hook = False |
| | self.source = source |
| | self.scope: tuple[ast.AST, ...] = () |
| | self.variables_overwrite: defaultdict[tuple[ast.AST, ...], dict[str, str]] = ( |
| | defaultdict(dict) |
| | ) |
| |
|
| | def run(self, mod: ast.Module) -> None: |
| | """Find all assert statements in *mod* and rewrite them.""" |
| | if not mod.body: |
| | |
| | return |
| |
|
| | |
| | |
| | doc = getattr(mod, "docstring", None) |
| | expect_docstring = doc is None |
| | if doc is not None and self.is_rewrite_disabled(doc): |
| | return |
| | pos = 0 |
| | item = None |
| | for item in mod.body: |
| | if ( |
| | expect_docstring |
| | and isinstance(item, ast.Expr) |
| | and isinstance(item.value, ast.Constant) |
| | and isinstance(item.value.value, str) |
| | ): |
| | doc = item.value.value |
| | if self.is_rewrite_disabled(doc): |
| | return |
| | expect_docstring = False |
| | elif ( |
| | isinstance(item, ast.ImportFrom) |
| | and item.level == 0 |
| | and item.module == "__future__" |
| | ): |
| | pass |
| | else: |
| | break |
| | pos += 1 |
| | |
| | |
| | if isinstance(item, ast.FunctionDef) and item.decorator_list: |
| | lineno = item.decorator_list[0].lineno |
| | else: |
| | lineno = item.lineno |
| | |
| | if sys.version_info >= (3, 10): |
| | aliases = [ |
| | ast.alias("builtins", "@py_builtins", lineno=lineno, col_offset=0), |
| | ast.alias( |
| | "_pytest.assertion.rewrite", |
| | "@pytest_ar", |
| | lineno=lineno, |
| | col_offset=0, |
| | ), |
| | ] |
| | else: |
| | aliases = [ |
| | ast.alias("builtins", "@py_builtins"), |
| | ast.alias("_pytest.assertion.rewrite", "@pytest_ar"), |
| | ] |
| | imports = [ |
| | ast.Import([alias], lineno=lineno, col_offset=0) for alias in aliases |
| | ] |
| | mod.body[pos:pos] = imports |
| |
|
| | |
| | self.scope = (mod,) |
| | nodes: list[ast.AST | Sentinel] = [mod] |
| | while nodes: |
| | node = nodes.pop() |
| | if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef, ast.ClassDef)): |
| | self.scope = tuple((*self.scope, node)) |
| | nodes.append(_SCOPE_END_MARKER) |
| | if node == _SCOPE_END_MARKER: |
| | self.scope = self.scope[:-1] |
| | continue |
| | assert isinstance(node, ast.AST) |
| | for name, field in ast.iter_fields(node): |
| | if isinstance(field, list): |
| | new: list[ast.AST] = [] |
| | for i, child in enumerate(field): |
| | if isinstance(child, ast.Assert): |
| | |
| | new.extend(self.visit(child)) |
| | else: |
| | new.append(child) |
| | if isinstance(child, ast.AST): |
| | nodes.append(child) |
| | setattr(node, name, new) |
| | elif ( |
| | isinstance(field, ast.AST) |
| | |
| | |
| | and not isinstance(field, ast.expr) |
| | ): |
| | nodes.append(field) |
| |
|
| | @staticmethod |
| | def is_rewrite_disabled(docstring: str) -> bool: |
| | return "PYTEST_DONT_REWRITE" in docstring |
| |
|
| | def variable(self) -> str: |
| | """Get a new variable.""" |
| | |
| | name = "@py_assert" + str(next(self.variable_counter)) |
| | self.variables.append(name) |
| | return name |
| |
|
| | def assign(self, expr: ast.expr) -> ast.Name: |
| | """Give *expr* a name.""" |
| | name = self.variable() |
| | self.statements.append(ast.Assign([ast.Name(name, ast.Store())], expr)) |
| | return ast.copy_location(ast.Name(name, ast.Load()), expr) |
| |
|
| | def display(self, expr: ast.expr) -> ast.expr: |
| | """Call saferepr on the expression.""" |
| | return self.helper("_saferepr", expr) |
| |
|
| | def helper(self, name: str, *args: ast.expr) -> ast.expr: |
| | """Call a helper in this module.""" |
| | py_name = ast.Name("@pytest_ar", ast.Load()) |
| | attr = ast.Attribute(py_name, name, ast.Load()) |
| | return ast.Call(attr, list(args), []) |
| |
|
| | def builtin(self, name: str) -> ast.Attribute: |
| | """Return the builtin called *name*.""" |
| | builtin_name = ast.Name("@py_builtins", ast.Load()) |
| | return ast.Attribute(builtin_name, name, ast.Load()) |
| |
|
| | def explanation_param(self, expr: ast.expr) -> str: |
| | """Return a new named %-formatting placeholder for expr. |
| | |
| | This creates a %-formatting placeholder for expr in the |
| | current formatting context, e.g. ``%(py0)s``. The placeholder |
| | and expr are placed in the current format context so that it |
| | can be used on the next call to .pop_format_context(). |
| | """ |
| | specifier = "py" + str(next(self.variable_counter)) |
| | self.explanation_specifiers[specifier] = expr |
| | return "%(" + specifier + ")s" |
| |
|
| | def push_format_context(self) -> None: |
| | """Create a new formatting context. |
| | |
| | The format context is used for when an explanation wants to |
| | have a variable value formatted in the assertion message. In |
| | this case the value required can be added using |
| | .explanation_param(). Finally .pop_format_context() is used |
| | to format a string of %-formatted values as added by |
| | .explanation_param(). |
| | """ |
| | self.explanation_specifiers: dict[str, ast.expr] = {} |
| | self.stack.append(self.explanation_specifiers) |
| |
|
| | def pop_format_context(self, expl_expr: ast.expr) -> ast.Name: |
| | """Format the %-formatted string with current format context. |
| | |
| | The expl_expr should be an str ast.expr instance constructed from |
| | the %-placeholders created by .explanation_param(). This will |
| | add the required code to format said string to .expl_stmts and |
| | return the ast.Name instance of the formatted string. |
| | """ |
| | current = self.stack.pop() |
| | if self.stack: |
| | self.explanation_specifiers = self.stack[-1] |
| | keys: list[ast.expr | None] = [ast.Constant(key) for key in current.keys()] |
| | format_dict = ast.Dict(keys, list(current.values())) |
| | form = ast.BinOp(expl_expr, ast.Mod(), format_dict) |
| | name = "@py_format" + str(next(self.variable_counter)) |
| | if self.enable_assertion_pass_hook: |
| | self.format_variables.append(name) |
| | self.expl_stmts.append(ast.Assign([ast.Name(name, ast.Store())], form)) |
| | return ast.Name(name, ast.Load()) |
| |
|
| | def generic_visit(self, node: ast.AST) -> tuple[ast.Name, str]: |
| | """Handle expressions we don't have custom code for.""" |
| | assert isinstance(node, ast.expr) |
| | res = self.assign(node) |
| | return res, self.explanation_param(self.display(res)) |
| |
|
def visit_Assert(self, assert_: ast.Assert) -> list[ast.stmt]:
    """Return the AST statements to replace the ast.Assert instance.

    This rewrites the test of an assertion to provide
    intermediate values and replace it with an if statement which
    raises an assertion error with a detailed explanation in case
    the expression is false.
    """
    # A non-empty tuple as the test is always truthy: warn about it.
    if isinstance(assert_.test, ast.Tuple) and len(assert_.test.elts) >= 1:
        import warnings

        from _pytest.warning_types import PytestAssertRewriteWarning

        # The warning needs a filename, so a module path is required here.
        assert self.module_path is not None
        warnings.warn_explicit(
            PytestAssertRewriteWarning(
                "assertion is always true, perhaps remove parentheses?"
            ),
            category=None,
            filename=self.module_path,
            lineno=assert_.lineno,
        )

    # Reset the per-assert state.
    self.statements: list[ast.stmt] = []
    self.variables: list[str] = []
    self.variable_counter = itertools.count()

    if self.enable_assertion_pass_hook:
        self.format_variables: list[str] = []

    self.stack: list[dict[str, ast.expr]] = []
    self.expl_stmts: list[ast.stmt] = []
    self.push_format_context()
    # Rewrite the test expression; the visitors fill self.statements.
    top_condition, explanation = self.visit(assert_.test)

    negation = ast.UnaryOp(ast.Not(), top_condition)

    if self.enable_assertion_pass_hook:
        # Variant that also reports passing assertions.
        msg = self.pop_format_context(ast.Constant(explanation))

        # Failure branch: build the message and raise AssertionError.
        if assert_.msg:
            assertmsg = self.helper("_format_assertmsg", assert_.msg)
            gluestr = "\n>assert "
        else:
            assertmsg = ast.Constant("")
            gluestr = "assert "
        err_explanation = ast.BinOp(ast.Constant(gluestr), ast.Add(), msg)
        err_msg = ast.BinOp(assertmsg, ast.Add(), err_explanation)
        err_name = ast.Name("AssertionError", ast.Load())
        fmt = self.helper("_format_explanation", err_msg)
        exc = ast.Call(err_name, [fmt], [])
        raise_ = ast.Raise(exc, None)
        statements_fail = []
        statements_fail.extend(self.expl_stmts)
        statements_fail.append(raise_)

        # Success branch: call the pass hook with the original source text.
        fmt_pass = self.helper("_format_explanation", msg)
        orig = _get_assertion_exprs(self.source)[assert_.lineno]
        hook_call_pass = ast.Expr(
            self.helper(
                "_call_assertion_pass",
                ast.Constant(assert_.lineno),
                ast.Constant(orig),
                fmt_pass,
            )
        )
        # The explanation is only built when the runtime check says a hook
        # implementation exists.
        hook_impl_test = ast.If(
            self.helper("_check_if_assertion_pass_impl"),
            [*self.expl_stmts, hook_call_pass],
            [],
        )
        statements_pass: list[ast.stmt] = [hook_impl_test]

        # if not <test>: <failure branch> else: <success branch>
        main_test = ast.If(negation, statements_fail, statements_pass)
        self.statements.append(main_test)
        # Reset the format variables to None once the test has run.
        if self.format_variables:
            variables: list[ast.expr] = [
                ast.Name(name, ast.Store()) for name in self.format_variables
            ]
            clear_format = ast.Assign(variables, ast.Constant(None))
            self.statements.append(clear_format)

    else:
        # Variant without the pass hook.
        # ``body`` is the same list object as self.expl_stmts, so the
        # statements appended below by pop_format_context() and the final
        # raise all end up inside the ``if`` created here.
        body = self.expl_stmts
        self.statements.append(ast.If(negation, body, []))
        if assert_.msg:
            assertmsg = self.helper("_format_assertmsg", assert_.msg)
            explanation = "\n>assert " + explanation
        else:
            assertmsg = ast.Constant("")
            explanation = "assert " + explanation
        template = ast.BinOp(assertmsg, ast.Add(), ast.Constant(explanation))
        msg = self.pop_format_context(template)
        fmt = self.helper("_format_explanation", msg)
        err_name = ast.Name("AssertionError", ast.Load())
        exc = ast.Call(err_name, [fmt], [])
        raise_ = ast.Raise(exc, None)

        body.append(raise_)

    # Clear the temporary variables by setting them to None.
    if self.variables:
        variables = [ast.Name(name, ast.Store()) for name in self.variables]
        clear = ast.Assign(variables, ast.Constant(None))
        self.statements.append(clear)
    # Fix up source locations.
    for stmt in self.statements:
        for node in traverse_node(stmt):
            if getattr(node, "lineno", None) is None:
                # Nodes without a line number get the location of the
                # original assert; nodes that already have one keep it.
                ast.copy_location(node, assert_)
    return self.statements
| |
|
def visit_NamedExpr(self, name: ast.NamedExpr) -> tuple[ast.NamedExpr, str]:
    """Explain an assignment expression (``target := value``).

    The node itself is returned untouched.  The explanation is a placeholder
    bound to a conditional expression: when the target identifier is a key
    of ``locals()`` or the ``_should_repr_global_name`` helper is truthy it
    evaluates to ``self.display(name)``, otherwise to the bare identifier.
    """
    locals_call = ast.Call(func=self.builtin("locals"), args=[], keywords=[])
    bound_name = name.target.id
    is_local = ast.Compare(
        left=ast.Constant(bound_name),
        ops=[ast.In()],
        comparators=[locals_call],
    )
    wants_repr = self.helper("_should_repr_global_name", name)
    show_value = ast.BoolOp(op=ast.Or(), values=[is_local, wants_repr])
    shown = self.display(name)
    choice = ast.IfExp(test=show_value, body=shown, orelse=ast.Constant(bound_name))
    return name, self.explanation_param(choice)
| |
|
def visit_Name(self, name: ast.Name) -> tuple[ast.Name, str]:
    """Explain a plain name reference.

    The node is returned unchanged.  Its explanation is a placeholder bound
    to a conditional expression that evaluates to ``self.display(name)`` when
    the identifier is a key of ``locals()`` or the
    ``_should_repr_global_name`` helper is truthy, and to the identifier
    string itself otherwise.
    """
    identifier = ast.Constant(name.id)
    locals_call = ast.Call(func=self.builtin("locals"), args=[], keywords=[])
    is_local = ast.Compare(left=identifier, ops=[ast.In()], comparators=[locals_call])
    wants_repr = self.helper("_should_repr_global_name", name)
    show_value = ast.BoolOp(op=ast.Or(), values=[is_local, wants_repr])
    shown = self.display(name)
    choice = ast.IfExp(test=show_value, body=shown, orelse=ast.Constant(name.id))
    return name, self.explanation_param(choice)
| |
|
def visit_BoolOp(self, boolop: ast.BoolOp) -> tuple[ast.Name, str]:
    """Rewrite an ``and``/``or`` chain into nested ``if`` statements.

    Each operand is visited in turn and its result stored in one shared
    result variable.  The statements for operand ``i + 1`` are nested inside
    an ``if`` on operand ``i`` (negated for ``or``), and the explanation
    statements are nested in parallel.  Every operand's explanation is
    appended to a list that ``_format_boolop`` receives at the end.

    Returns the result variable (load context) and the explanation
    placeholder for the whole chain.
    """
    result_name = self.variable()
    expl_list = self.assign(ast.List(elts=[], ctx=ast.Load()))
    append_attr = ast.Attribute(value=expl_list, attr="append", ctx=ast.Load())
    is_or = int(isinstance(boolop.op, ast.Or))
    # Remember the enclosing statement lists; both are swapped for nested
    # lists inside the loop and restored once the chain is processed.
    outer_statements = self.statements
    outer_expl_stmts = self.expl_stmts
    current_body = outer_statements
    last_index = len(boolop.values) - 1
    self.push_format_context()
    for index, operand in enumerate(boolop.values):
        if index:
            # ``cond`` was set while handling the previous operand.
            nested_expl: list[ast.stmt] = []
            self.expl_stmts.append(ast.If(test=cond, body=nested_expl, orelse=[]))
            self.expl_stmts = nested_expl
        if isinstance(operand, ast.Compare) and isinstance(
            operand.left, ast.NamedExpr
        ):
            walrus = operand.left
            walrus_name = walrus.target.id
            earlier_ids = [
                previous.id
                for previous in boolop.values[:index]
                if hasattr(previous, "id")
            ]
            if walrus_name in earlier_ids:
                # The walrus rebinds a name used as an earlier operand:
                # record the walrus under the old name and retarget it to a
                # fresh temporary.
                temp_name = self.variable()
                self.variables_overwrite[self.scope][walrus_name] = walrus
                walrus.target.id = temp_name
        self.push_format_context()
        res, expl = self.visit(operand)
        store_result = ast.Name(result_name, ast.Store())
        current_body.append(ast.Assign(targets=[store_result], value=res))
        expl_format = self.pop_format_context(ast.Constant(expl))
        append_call = ast.Call(func=append_attr, args=[expl_format], keywords=[])
        self.expl_stmts.append(ast.Expr(append_call))
        if index < last_index:
            cond: ast.expr = ast.UnaryOp(ast.Not(), res) if is_or else res
            nested_body: list[ast.stmt] = []
            self.statements.append(ast.If(test=cond, body=nested_body, orelse=[]))
            self.statements = current_body = nested_body
    self.statements = outer_statements
    self.expl_stmts = outer_expl_stmts
    expl_template = self.helper("_format_boolop", expl_list, ast.Constant(is_or))
    chain_expl = self.pop_format_context(expl_template)
    return ast.Name(result_name, ast.Load()), self.explanation_param(chain_expl)
| |
|
def visit_UnaryOp(self, unary: ast.UnaryOp) -> tuple[ast.Name, str]:
    """Rewrite a unary operation over its visited operand.

    The explanation pattern is looked up in ``UNARY_MAP`` by operator class
    before the operand is visited.  The rebuilt operation takes its source
    location from the original node and is stored through ``self.assign``.
    """
    pattern = UNARY_MAP[type(unary.op)]
    operand_res, operand_expl = self.visit(unary.operand)
    rebuilt = ast.UnaryOp(op=unary.op, operand=operand_res)
    ast.copy_location(rebuilt, unary)
    res = self.assign(rebuilt)
    return res, pattern % (operand_expl,)
| |
|
def visit_BinOp(self, binop: ast.BinOp) -> tuple[ast.Name, str]:
    """Rewrite a binary operation over its visited operands.

    The operator symbol comes from ``BINOP_MAP``; the left operand is visited
    before the right one.  The explanation has the form
    ``(<left> <symbol> <right>)`` and the rebuilt operation, located at the
    original node, is stored through ``self.assign``.
    """
    symbol = BINOP_MAP[type(binop.op)]
    left_expr, left_expl = self.visit(binop.left)
    right_expr, right_expl = self.visit(binop.right)
    explanation = "(" + f"{left_expl} {symbol} {right_expl}" + ")"
    rebuilt = ast.BinOp(left=left_expr, op=binop.op, right=right_expr)
    ast.copy_location(rebuilt, binop)
    return self.assign(rebuilt), explanation
| |
|
def visit_Call(self, call: ast.Call) -> tuple[ast.Name, str]:
    """Rewrite a call, visiting the callee and then every argument.

    Positional and keyword arguments that are plain names registered in
    ``self.variables_overwrite`` for the current scope are swapped for the
    registered node before being visited (keyword values are replaced on the
    original keyword node).  The rebuilt call is stored through
    ``self.assign``.

    The explanation shows the displayed result followed by a nested
    ``{result = func(args...)}`` section.
    """
    new_func, func_expl = self.visit(call.func)
    arg_expls = []
    new_args = []
    new_kwargs = []

    for arg in call.args:
        if isinstance(arg, ast.Name):
            # Looked up per argument: visiting may register new overrides.
            overrides = self.variables_overwrite.get(self.scope, {})
            if arg.id in overrides:
                arg = overrides[arg.id]
        arg_res, arg_expl = self.visit(arg)
        arg_expls.append(arg_expl)
        new_args.append(arg_res)

    for keyword in call.keywords:
        if isinstance(keyword.value, ast.Name):
            overrides = self.variables_overwrite.get(self.scope, {})
            if keyword.value.id in overrides:
                keyword.value = overrides[keyword.value.id]
        kw_res, kw_expl = self.visit(keyword.value)
        new_kwargs.append(ast.keyword(arg=keyword.arg, value=kw_res))
        # A keyword without a name is a ``**mapping`` expansion.
        prefix = keyword.arg + "=" if keyword.arg else "**"
        arg_expls.append(prefix + kw_expl)

    joined_args = ", ".join(arg_expls)
    call_expl = f"{func_expl}({joined_args})"
    rebuilt = ast.Call(func=new_func, args=new_args, keywords=new_kwargs)
    ast.copy_location(rebuilt, call)
    res = self.assign(rebuilt)
    res_expl = self.explanation_param(self.display(res))
    return res, f"{res_expl}\n{{{res_expl} = {call_expl}\n}}"
| |
|
def visit_Starred(self, starred: ast.Starred) -> tuple[ast.Starred, str]:
    """Rewrite ``*value`` by visiting the wrapped value.

    Returns a new ``Starred`` node around the visited value, keeping the
    original context, and the value's explanation prefixed with ``*``.
    """
    inner_res, inner_expl = self.visit(starred.value)
    rebuilt = ast.Starred(value=inner_res, ctx=starred.ctx)
    return rebuilt, "*" + inner_expl
| |
|
def visit_Attribute(self, attr: ast.Attribute) -> tuple[ast.Name, str]:
    """Rewrite an attribute read (``value.attr``).

    Attribute nodes that are not in load context are handed to
    ``self.generic_visit``.  Otherwise the owner expression is visited, the
    rebuilt lookup is stored through ``self.assign``, and the explanation
    shows the displayed result followed by a nested
    ``{result = owner.attr}`` section.
    """
    if not isinstance(attr.ctx, ast.Load):
        return self.generic_visit(attr)
    owner, owner_expl = self.visit(attr.value)
    lookup = ast.Attribute(value=owner, attr=attr.attr, ctx=ast.Load())
    ast.copy_location(lookup, attr)
    res = self.assign(lookup)
    res_expl = self.explanation_param(self.display(res))
    expl = f"{res_expl}\n{{{res_expl} = {owner_expl}.{attr.attr}\n}}"
    return res, expl
| |
|
def visit_Compare(self, comp: ast.Compare) -> tuple[ast.expr, str]:
    """Rewrite a (possibly chained) comparison into pairwise comparisons.

    Every ``left <op> right`` link of the chain is assigned to its own result
    variable.  The returned expression is that variable for a single
    comparison, or an ``and`` over all of them for a chain.  The explanation
    is built by the ``_call_reprcompare`` helper from the operator symbols,
    the per-link results, the per-link explanation strings and the operand
    values.
    """
    self.push_format_context()
    # A left operand naming a registered override is replaced by the
    # registered node; a walrus on the left registers itself under its
    # target name.
    if isinstance(comp.left, ast.Name):
        overrides = self.variables_overwrite.get(self.scope, {})
        if comp.left.id in overrides:
            comp.left = overrides[comp.left.id]
    if isinstance(comp.left, ast.NamedExpr):
        self.variables_overwrite[self.scope][comp.left.target.id] = comp.left
    left_res, left_expl = self.visit(comp.left)
    if isinstance(comp.left, (ast.Compare, ast.BoolOp)):
        left_expl = f"({left_expl})"

    # One result variable per operator, allocated up front.
    res_variables = [self.variable() for _ in comp.ops]
    load_names: list[ast.expr] = [
        ast.Name(var, ast.Load()) for var in res_variables
    ]
    store_names = [ast.Name(var, ast.Store()) for var in res_variables]
    expls: list[ast.expr] = []
    syms: list[ast.expr] = []
    results = [left_res]

    for store_name, op, next_operand in zip(store_names, comp.ops, comp.comparators):
        rebinds_left = (
            isinstance(next_operand, ast.NamedExpr)
            and isinstance(left_res, ast.Name)
            and next_operand.target.id == left_res.id
        )
        if rebinds_left:
            # The walrus would overwrite the name on its left: retarget it
            # to a fresh variable and register it under the old name.
            next_operand.target.id = self.variable()
            self.variables_overwrite[self.scope][left_res.id] = next_operand
        next_res, next_expl = self.visit(next_operand)
        if isinstance(next_operand, (ast.Compare, ast.BoolOp)):
            next_expl = f"({next_expl})"
        results.append(next_res)
        sym = BINOP_MAP[type(op)]
        syms.append(ast.Constant(sym))
        expls.append(ast.Constant(f"{left_expl} {sym} {next_expl}"))
        link = ast.Compare(left=left_res, ops=[op], comparators=[next_res])
        ast.copy_location(link, comp)
        self.statements.append(ast.Assign(targets=[store_name], value=link))
        # The right operand of this link is the left operand of the next.
        left_res, left_expl = next_res, next_expl

    expl_call = self.helper(
        "_call_reprcompare",
        ast.Tuple(syms, ast.Load()),
        ast.Tuple(load_names, ast.Load()),
        ast.Tuple(expls, ast.Load()),
        ast.Tuple(results, ast.Load()),
    )
    res: ast.expr
    if len(comp.ops) > 1:
        res = ast.BoolOp(op=ast.And(), values=load_names)
    else:
        res = load_names[0]

    return res, self.explanation_param(self.pop_format_context(expl_call))
| |
|
| |
|
def try_makedirs(cache_dir: Path) -> bool:
    """Try to create *cache_dir* together with any missing parent directories.

    Returns True when the directory was created or already exists, and False
    when creation fails with a missing or non-directory path component, a
    permission error, or an ``OSError`` whose errno is ``EROFS`` or
    ``ENOSYS``.  Any other ``OSError`` is re-raised.
    """
    expected_failures = (
        FileNotFoundError,
        NotADirectoryError,
        FileExistsError,
        PermissionError,
    )
    try:
        os.makedirs(cache_dir, exist_ok=True)
    except expected_failures:
        return False
    except OSError as exc:
        # Only these two errno values are treated as "cannot cache here";
        # anything else propagates to the caller.
        if exc.errno not in (errno.EROFS, errno.ENOSYS):
            raise
        return False
    return True
| |
|
| |
|
def get_cache_dir(file_path: Path) -> Path:
    """Return the directory where the .pyc for the given .py file is written.

    Without ``sys.pycache_prefix`` this is the ``__pycache__`` directory next
    to the file.  With a prefix, the file's directory components (everything
    except the first path part and the file name) are appended to the prefix.
    """
    prefix = sys.pycache_prefix
    if not prefix:
        return file_path.parent / "__pycache__"
    # NOTE(review): dropping parts[0] presumably removes the root/drive of an
    # absolute path; confirm callers never pass a relative path.
    prefix_root = Path(prefix)
    directory_parts = file_path.parts[1:-1]
    return prefix_root / Path(*directory_parts)
| |
|