| |
| from __future__ import annotations |
|
|
| import ast |
| from bisect import bisect_right |
| from collections.abc import Iterable |
| from collections.abc import Iterator |
| import inspect |
| import textwrap |
| import tokenize |
| import types |
| from typing import overload |
| import warnings |
|
|
|
|
class Source:
    """Immutable container for a fragment of source code.

    Two parallel views are kept, one string per line with no trailing
    newline: ``lines`` is what the object represents (deindented when the
    object is constructed from text, a sequence or a code-bearing object) and
    ``raw_lines`` keeps the lines as they were before deindenting.  Slicing
    narrows both views; ``strip``, ``indent`` and ``deindent`` only transform
    ``lines`` and carry ``raw_lines`` over unchanged.
    """

    def __init__(self, obj: object = None) -> None:
        # Any falsy input (None, "", an empty sequence, an empty Source)
        # yields an empty fragment.
        if not obj:
            self.lines: list[str] = []
            self.raw_lines: list[str] = []
            return

        # Building from another Source shares its line lists instead of
        # copying them.
        if isinstance(obj, Source):
            self.raw_lines = obj.raw_lines
            self.lines = obj.lines
            return

        if isinstance(obj, (tuple, list)):
            pieces = [item.rstrip("\n") for item in obj]
        elif isinstance(obj, str):
            pieces = obj.split("\n")
        else:
            # Look the source up through the object's code object; if that
            # raises TypeError, ask inspect about the object itself.
            try:
                text = inspect.getsource(getrawcode(obj))
            except TypeError:
                text = inspect.getsource(obj)
            pieces = text.split("\n")

        self.lines = deindent(pieces)
        self.raw_lines = pieces

    def __eq__(self, other: object) -> bool:
        # Equality looks at the deindented lines only.
        if isinstance(other, Source):
            return self.lines == other.lines
        return NotImplemented

    # Mutable-looking value type with custom equality: keep it unhashable.
    __hash__ = None

    @overload
    def __getitem__(self, key: int) -> str: ...

    @overload
    def __getitem__(self, key: slice) -> Source: ...

    def __getitem__(self, key: int | slice) -> str | Source:
        """Return one line for an int key, or a new Source for a slice.

        Slices with a step other than None/1 raise IndexError.
        """
        if not isinstance(key, int):
            if key.step not in (None, 1):
                raise IndexError("cannot slice a Source with a step")
            window = slice(key.start, key.stop)
            sliced = Source()
            sliced.lines = self.lines[window]
            sliced.raw_lines = self.raw_lines[window]
            return sliced
        return self.lines[key]

    def __iter__(self) -> Iterator[str]:
        return iter(self.lines)

    def __len__(self) -> int:
        return len(self.lines)

    def strip(self) -> Source:
        """Return a copy without leading and trailing whitespace-only lines."""
        text_lines = self.lines
        lo, hi = 0, len(self)
        while lo < hi and not text_lines[lo].strip():
            lo += 1
        while hi > lo and not text_lines[hi - 1].strip():
            hi -= 1
        result = Source()
        result.raw_lines = self.raw_lines
        result.lines = text_lines[lo:hi]
        return result

    def indent(self, indent: str = " " * 4) -> Source:
        """Return a copy with *indent* prepended to every line."""
        result = Source()
        result.raw_lines = self.raw_lines
        result.lines = [indent + text for text in self.lines]
        return result

    def getstatement(self, lineno: int) -> Source:
        """Return the statement containing 0-based line *lineno* as a Source."""
        first, last = self.getstatementrange(lineno)
        return self[first:last]

    def getstatementrange(self, lineno: int) -> tuple[int, int]:
        """Return the half-open ``(start, end)`` line span of the smallest
        statement containing 0-based line *lineno*.

        Raises IndexError when *lineno* is outside ``0 <= lineno < len(self)``.
        """
        if not 0 <= lineno < len(self):
            raise IndexError("lineno out of range")
        _, first, last = getstatementrange_ast(lineno, self)
        return first, last

    def deindent(self) -> Source:
        """Return a copy whose lines have their common indentation removed."""
        result = Source()
        result.lines = deindent(self.lines)
        result.raw_lines = self.raw_lines
        return result

    def __str__(self) -> str:
        return "\n".join(self.lines)
|
|
|
|
| |
| |
| |
|
|
|
|
def findsource(obj) -> tuple[Source | None, int]:
    """Return ``(source, lineno)`` as located by ``inspect.findsource(obj)``.

    The returned Source keeps inspect's lines verbatim in ``raw_lines`` and
    right-stripped in ``lines``.  If inspect raises anything, the source is
    treated as unavailable and ``(None, -1)`` is returned.
    """
    try:
        raw, start_line = inspect.findsource(obj)
    except Exception:
        # Best effort: any failure to locate the source means "not found".
        return None, -1
    found = Source()
    found.raw_lines = raw
    found.lines = [text.rstrip() for text in raw]
    return found, start_line
|
|
|
|
def getrawcode(obj: object, trycall: bool = True) -> types.CodeType:
    """Return the code object behind *obj*.

    Uses ``obj.__code__`` when present.  Otherwise, if *trycall* is true and
    *obj* is not a class, falls back to the code object of ``obj.__call__``
    (looked up once, without further fallback).  Raises TypeError when no
    code object can be found.
    """
    try:
        return obj.__code__
    except AttributeError:
        pass
    dunder_call = getattr(obj, "__call__", None) if trycall else None
    if dunder_call and not isinstance(obj, type):
        return getrawcode(dunder_call, trycall=False)
    raise TypeError(f"could not get code object for {obj!r}")
|
|
|
|
def deindent(lines: Iterable[str]) -> list[str]:
    """Remove the common leading whitespace from *lines* (textwrap.dedent).

    The lines are joined with newlines, dedented and split back into a list.
    Only CR, LF and CRLF count as line breaks -- the same set Python source
    uses for physical lines.  Characters such as form feed or U+2028, which
    str.splitlines() would also split on, therefore stay inside their line
    instead of breaking it in two.  As with str.splitlines(), a single
    trailing line break does not produce a final empty entry, and empty
    input yields an empty list.
    """
    text = textwrap.dedent("\n".join(lines))
    # Normalise CRLF/CR to LF so those keep acting as line breaks, as they
    # did with splitlines().
    text = text.replace("\r\n", "\n").replace("\r", "\n")
    if not text:
        return []
    result = text.split("\n")
    if text.endswith("\n"):
        # Match splitlines(): one trailing break does not start a new line.
        result.pop()
    return result
|
|
|
|
def get_statement_startend2(lineno: int, node: ast.AST) -> tuple[int, int | None]:
    """Return the 0-based ``(start, end)`` line span of the statement that
    contains 0-based line *lineno* in the parsed tree *node*.

    Every statement and except handler in *node* contributes a start line.
    Decorators of classes/functions and the header line of ``else``/
    ``finally`` parts contribute start lines as well.  The span runs from the
    last start at or before *lineno* up to, but excluding, the next start.
    ``end`` is None when no statement starts after *lineno*.  If no statement
    starts at or before *lineno* (e.g. it lies in leading comments, or *node*
    contains no statements), ``start`` is 0.
    """
    # AST line numbers are 1-based; everything collected here is 0-based.
    values: list[int] = []
    for x in ast.walk(node):
        if not isinstance(x, (ast.stmt, ast.ExceptHandler)):
            continue
        # lineno points at the class/def line, so add the decorators too.
        if isinstance(x, (ast.ClassDef, ast.FunctionDef, ast.AsyncFunctionDef)):
            values.extend(d.lineno - 1 for d in x.decorator_list)
        values.append(x.lineno - 1)
        for name in ("finalbody", "orelse"):
            val: list[ast.stmt] | None = getattr(x, name, None)
            if val:
                # Treat the finally/else part as its own statement, assuming
                # its header line directly precedes its first statement.
                values.append(val[0].lineno - 1 - 1)
    values.sort()
    insert_index = bisect_right(values, lineno)
    # insert_index == 0 means no statement starts at or before lineno.
    # values[-1] would then wrap to the *last* statement (or raise on an
    # empty list), so start from the top of the source instead.
    start = values[insert_index - 1] if insert_index > 0 else 0
    end = values[insert_index] if insert_index < len(values) else None
    return start, end
|
|
|
|
def getstatementrange_ast(
    lineno: int,
    source: Source,
    assertion: bool = False,
    astnode: ast.AST | None = None,
) -> tuple[ast.AST, int, int]:
    """Locate the statement that contains line *lineno* of *source*.

    *lineno* is a 0-based index into ``source.lines``.  Returns
    ``(astnode, start, end)``: ``[start, end)`` is a half-open range of
    0-based indices into ``source.lines``.  ``astnode`` is the tree the range
    was computed from, either the caller-supplied *astnode* or, when that is
    None, the result of parsing ``str(source)``.

    The initial span comes from ``get_statement_startend2``.  When it covers
    more than one line, the span is tokenized and cut short if
    ``inspect.BlockFinder`` signals the end of the block (or tokenizing
    raises IndentationError).  Trailing comment-only and blank lines are
    then dropped from the end.

    *assertion* is accepted but not read by this function.

    Exceptions from ``ast.parse`` (e.g. SyntaxError) propagate when
    *astnode* is not supplied.
    """
    if astnode is None:
        content = str(source)
        # Parse with all warnings suppressed; the tree is only used to find
        # statement boundaries.
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            astnode = ast.parse(content, "source", "exec")

    start, end = get_statement_startend2(lineno, astnode)
    # end is None when no statement starts after lineno: run the span to the
    # end of the source.
    if end is None:
        end = len(source.lines)

    if end > start + 1:
        # Multi-line span: feed it through inspect.BlockFinder to see whether
        # the block finishes before `end`.
        block_finder = inspect.BlockFinder()
        # Start BlockFinder in "started" mode when the first line of the span
        # is non-empty and begins with whitespace (i.e. is indented).
        block_finder.started = (
            bool(source.lines[start]) and source.lines[start][0].isspace()
        )
        # source.lines carry no newlines; re-add them for the tokenizer.
        it = ((x + "\n") for x in source.lines[start:end])
        try:
            for tok in tokenize.generate_tokens(lambda: next(it)):
                block_finder.tokeneater(*tok)
        except (inspect.EndOfBlock, IndentationError):
            # block_finder.last counts rows of the tokenized slice, which
            # begins at `start`, so shift it back into source coordinates.
            end = block_finder.last + start
        except Exception:
            # Best effort: on any other tokenizer failure keep the AST-based end.
            pass

    # Drop trailing comment-only and blank lines from the span.
    # NOTE(review): this loop is bounded by 0, not by start.
    while end:
        line = source.lines[end - 1].lstrip()
        if line.startswith("#") or not line:
            end -= 1
        else:
            break
    return astnode, start, end
|
|