| """Nice output for Black. |
| |
The double calls (`out` -> `_out`, `err` -> `_err`) are for patching purposes in tests.
| """ |
|
|
| import json |
| import re |
| import tempfile |
| from typing import Any |
|
|
| from click import echo, style |
| from mypy_extensions import mypyc_attr |
|
|
|
|
@mypyc_attr(patchable=True)
def _out(message: str | None = None, nl: bool = True, **styles: Any) -> None:
    """Echo `message` to stderr via click, bold by default.

    Any click `style` keywords in `styles` are applied to the message; the
    caller may pass `bold=False` to suppress the default bolding. When
    `message` is None nothing is styled and click's `echo` is called with None.
    """
    if message is not None:
        # Only fill in bold when the caller did not choose it explicitly.
        styles.setdefault("bold", True)
        message = style(message, **styles)
    echo(message, nl=nl, err=True)
|
|
|
|
@mypyc_attr(patchable=True)
def _err(message: str | None = None, nl: bool = True, **styles: Any) -> None:
    """Echo `message` to stderr via click, in red by default.

    Any click `style` keywords in `styles` are applied to the message; an
    explicit `fg=` from the caller overrides the red default. When `message`
    is None nothing is styled and click's `echo` is called with None.
    """
    styled = message
    if styled is not None:
        # Red foreground unless the caller picked a colour.
        styles.setdefault("fg", "red")
        styled = style(styled, **styles)
    echo(styled, nl=nl, err=True)
|
|
|
|
@mypyc_attr(patchable=True)
def out(message: str | None = None, nl: bool = True, **styles: Any) -> None:
    """Print a (bold by default) message; public entry point over `_out`.

    Forwards every argument unchanged to `_out`. The indirection exists so
    tests can patch the implementation; keep the call shape stable.
    """
    _out(message, nl=nl, **styles)
|
|
|
|
def err(message: str | None = None, nl: bool = True, **styles: Any) -> None:
    """Print a (red by default) error message; public entry point over `_err`.

    Forwards every argument unchanged to `_err`. The indirection exists so
    tests can patch the implementation; keep the call shape stable.
    """
    _err(message, nl=nl, **styles)
|
|
|
|
def ipynb_diff(a: str, b: str, a_name: str, b_name: str) -> str:
    """Diff the code cells of two notebook JSON documents.

    Cells are paired by position: the i-th cell of `a` is compared against
    the i-th cell of `b`, and only cells whose `cell_type` in `a` is "code"
    are included. Each pair becomes a unified diff whose file names are
    suffixed with `:cell_<index>`; the per-cell diffs are concatenated.
    Indexing `b`'s cells by `a`'s positions means `b` is expected to have at
    least as many cells as `a` (otherwise IndexError).
    """
    old_nb, new_nb = json.loads(a), json.loads(b)
    chunks = []
    for index, cell in enumerate(old_nb["cells"]):
        if cell["cell_type"] != "code":
            continue
        # `source` is joined so both list-of-lines and plain-string forms work.
        old_src = "".join(cell["source"]) + "\n"
        new_src = "".join(new_nb["cells"][index]["source"]) + "\n"
        chunks.append(
            diff(old_src, new_src, f"{a_name}:cell_{index}", f"{b_name}:cell_{index}")
        )
    return "".join(chunks)
|
|
|
|
# Matches one line at a time together with its terminator ("\r\n", "\n" or
# "\r"). Since `.` does not match "\n" and the lazy `.*?` stops at the first
# "\r", only those three end a line; form feeds and similar characters stay
# inside it. The `$` alternative lets a final unterminated line match, and
# finditer() also yields one empty match at the very end of the string.
_line_pattern = re.compile(r"(.*?(?:\r\n|\n|\r|$))")
|
|
|
|
def _splitlines_no_ff(source: str) -> list[str]:
    """Break `source` into lines, each keeping its own line terminator.

    Only CRLF, LF and CR end a line; form feeds and the other separators that
    `str.splitlines` would honour are kept inside the line. This mimics how
    the Python parser splits source code, and is a simplified version of the
    function with the same name in CPython's Lib/ast.py.
    """
    pieces = [m.group(0) for m in _line_pattern.finditer(source)]
    # The pattern's `$` branch leaves a trailing empty match; discard it.
    if not pieces[-1]:
        del pieces[-1]
    return pieces
|
|
|
|
def diff(a: str, b: str, a_name: str, b_name: str) -> str:
    """Return a unified diff (5 lines of context) between strings `a` and `b`.

    `a_name` and `b_name` are used as the `---` / `+++` file labels. Any diff
    line that does not end in "\\n" (e.g. the last line of an input lacking a
    trailing newline) gets one appended, followed by GNU diff's
    "\\ No newline at end of file" marker.
    """
    import difflib

    hunks = difflib.unified_diff(
        _splitlines_no_ff(a),
        _splitlines_no_ff(b),
        fromfile=a_name,
        tofile=b_name,
        n=5,
    )
    pieces: list[str] = []
    for entry in hunks:
        # Work around https://bugs.python.org/issue2142: difflib passes an
        # unterminated line through as-is, which would glue it to the next one.
        # See also:
        # https://www.gnu.org/software/diffutils/manual/html_node/Incomplete-Lines.html
        if entry.endswith("\n"):
            pieces.append(entry)
            continue
        pieces.append(entry + "\n")
        pieces.append("\\ No newline at end of file\n")
    return "".join(pieces)
|
|
|
|
def color_diff(contents: str) -> str:
    """Wrap each line of a unified diff in ANSI colour escape codes.

    File headers (`---` / `+++`) are bold, hunk headers (`@@`) cyan, added
    lines green and removed lines red; every other line is left untouched.
    Lines are split and re-joined on "\\n" only.
    """

    def paint(line: str) -> str:
        # Header checks must precede the single-character +/- checks.
        if line.startswith(("+++", "---")):
            prefix = "\033[1m"
        elif line.startswith("@@"):
            prefix = "\033[36m"
        elif line.startswith("+"):
            prefix = "\033[32m"
        elif line.startswith("-"):
            prefix = "\033[31m"
        else:
            return line
        return prefix + line + "\033[0m"

    return "\n".join(paint(line) for line in contents.split("\n"))
|
|
|
|
@mypyc_attr(patchable=True)
def dump_to_file(*output: str, ensure_final_newline: bool = True) -> str:
    """Write the given strings to a new temporary log file and return its path.

    The file is created with prefix "blk_" and suffix ".log", UTF-8 encoded,
    and is not deleted on close (delete=False), so the caller can inspect it.
    When `ensure_final_newline` is true, every non-empty chunk that does not
    already end in "\\n" is followed by one.
    """
    with tempfile.NamedTemporaryFile(
        mode="w", prefix="blk_", suffix=".log", delete=False, encoding="utf8"
    ) as handle:
        for chunk in output:
            handle.write(chunk)
            needs_newline = ensure_final_newline and chunk and not chunk.endswith("\n")
            if needs_newline:
                handle.write("\n")
    return handle.name
|
|