| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| import ast |
| import collections |
| import io |
| import sys |
| import token |
| import tokenize |
| from abc import ABCMeta |
| from ast import Module, expr, AST |
| from functools import lru_cache |
| from typing import ( |
| Callable, |
| Dict, |
| Iterable, |
| Iterator, |
| List, |
| Optional, |
| Tuple, |
| Union, |
| cast, |
| Any, |
| TYPE_CHECKING, |
| Type, |
| ) |
|
|
# Names inside this guard exist only for static type checkers: TYPE_CHECKING is
# False at runtime, so the import and the definitions below are never executed.
# NOTE(review): indentation was lost in this copy of the file. The extent of the
# guard is inferred from AstNode's use of NodeNG (which is only imported here);
# confirm whether TokenInfo below was meant to live inside the guard as well.
if TYPE_CHECKING:
    from .astroid_compat import NodeNG

    class EnhancedAST(AST):
        """``ast.AST`` extended with the extra attributes declared below."""

        # Token attributes default to None, position attributes default to 0.
        first_token = None
        last_token = None
        lineno = 0
        end_lineno = 0
        end_col_offset = 0

    # A node from either supported tree flavour: stdlib ``ast`` or astroid.
    AstNode = Union[EnhancedAST, NodeNG]


# Alias for the standard library's token tuple type.
TokenInfo = tokenize.TokenInfo
|
|
|
|
def token_repr(tok_type, string):
    """
    Build a short, readable label for a token, formatted as ``TYPE_NAME:'text'``.

    ``tok_type`` is a numeric token type, looked up in ``token.tok_name``;
    ``string`` is the token text, which is rendered through ``repr()``.
    """
    type_name = token.tok_name[tok_type]
    # Strip any leading 'u' characters from the repr (e.g. a ``u'...'`` prefix).
    quoted_text = repr(string).lstrip('u')
    return '%s:%s' % (type_name, quoted_text)
|
|
|
|
class Token(collections.namedtuple(
        'Token',
        ('type', 'string', 'start', 'end', 'line', 'index', 'startpos', 'endpos'))):
    """
    Token is an 8-tuple holding the same 5 fields as the tokens produced by the tokenize
    module, followed by 3 extra fields that are useful for this module:

    - [0] .type     Token type (see token.py)
    - [1] .string   Token (a string)
    - [2] .start    Starting (row, column) indices of the token (a 2-tuple of ints)
    - [3] .end      Ending (row, column) indices of the token (a 2-tuple of ints)
    - [4] .line     Original line (string)
    - [5] .index    Index of the token in the list of tokens that it belongs to.
    - [6] .startpos Starting character offset into the input text.
    - [7] .endpos   Ending character offset into the input text.
    """

    def __str__(self):
        """Return the human-friendly ``TYPE:'text'`` label produced by ``token_repr``."""
        return token_repr(self.type, self.string)
|
|
|
|
def match_token(token, tok_type, tok_str=None):
    """
    Tell whether ``token`` has type ``tok_type`` and, when ``tok_str`` is given,
    also has exactly that string.
    """
    if not token.type == tok_type:
        return False
    # With no string constraint, a matching type is enough.
    if tok_str is None:
        return True
    return token.string == tok_str
|
|
|
|
def expect_token(token, tok_type, tok_str=None):
    """
    Check that ``token`` has the expected type and, when ``tok_str`` is given,
    the expected string as well.

    Returns None on success. Raises ValueError describing the expected token,
    the actual token, and its line and 1-based column when there is a mismatch.
    """
    if match_token(token, tok_type, tok_str):
        return

    expected = token_repr(tok_type, tok_str)
    actual = str(token)
    line_number = token.start[0]
    # token.start holds a 0-based column; report it 1-based.
    column = token.start[1] + 1
    raise ValueError("Expected token %s, got %s on line %s col %s" % (
        expected, actual, line_number, column))
|
|
|
|
def is_non_coding_token(token_type):
    """
    Tell whether ``token_type`` is one of the non-coding token types (NL,
    COMMENT, ENCODING), i.e. tokens that do not affect the syntax tree.
    """
    non_coding_types = (token.NL, token.COMMENT, token.ENCODING)
    return token_type in non_coding_types
|
|
|
|
def generate_tokens(text):
    """
    Tokenize ``text`` with the standard library and return the resulting
    iterator of ``tokenize.TokenInfo`` tuples.
    """
    source_buffer = io.StringIO(text)
    # The cast only informs type checkers; at runtime it returns its argument unchanged.
    readline = cast(Callable[[], str], source_buffer.readline)
    return tokenize.generate_tokens(readline)
|
|
|
|
def iter_children_func(node):
    """
    Pick the function that yields all direct children of an AST node while
    skipping singleton nodes.

    Nodes exposing a ``get_children`` attribute get ``iter_children_astroid``;
    everything else is treated as a stdlib ``ast`` node and gets ``iter_children_ast``.
    """
    if hasattr(node, 'get_children'):
        return iter_children_astroid
    return iter_children_ast
|
|
|
|
def iter_children_astroid(node, include_joined_str=False):
    """
    Return the direct children of a node via its ``get_children()`` method.

    A ``JoinedStr`` (f-string) node reports no children (an empty list) unless
    ``include_joined_str`` is true.
    """
    if include_joined_str or not is_joined_str(node):
        return node.get_children()
    return []
|
|
|
|
# Every class in the ``ast`` module deriving from (or being) one of the context /
# operator base classes. Nodes of these classes are skipped by iter_children_ast().
SINGLETONS = {
    candidate
    for candidate in vars(ast).values()
    if isinstance(candidate, type)
    and issubclass(
        candidate,
        (ast.expr_context, ast.boolop, ast.operator, ast.unaryop, ast.cmpop),
    )
}
|
|
|
|
def iter_children_ast(node, include_joined_str=False):
    """
    Yield the direct children of a stdlib ``ast`` node, leaving out singleton nodes.

    Nothing is yielded for a ``JoinedStr`` (f-string) node unless
    ``include_joined_str`` is true.
    """
    if is_joined_str(node) and not include_joined_str:
        return

    if not isinstance(node, ast.Dict):
        for child in ast.iter_child_nodes(node):
            # Context and operator nodes (see SINGLETONS) are not reported as children.
            if child.__class__ in SINGLETONS:
                continue
            yield child
        return

    # Dict nodes store keys and values in two parallel lists; emit them pairwise
    # (key, then its value). A key of None (as in ``{**other}``) is left out.
    for dict_key, dict_value in zip(node.keys, node.values):
        if dict_key is not None:
            yield dict_key
        yield dict_value
|
|
|
|
# Names of all classes in the ``ast`` module that are statement classes.
stmt_class_names = {
    name
    for name, obj in vars(ast).items()
    if isinstance(obj, type) and issubclass(obj, ast.stmt)
}

# Names of all classes in the ``ast`` module that are expression classes ...
expr_class_names = {
    name
    for name, obj in vars(ast).items()
    if isinstance(obj, type) and issubclass(obj, ast.expr)
}
# ... plus a few extra names that are not found in ``ast``
# (presumably astroid-only node classes -- confirm).
expr_class_names |= {'AssignName', 'DelName', 'Const', 'AssignAttr', 'DelAttr'}
|
|
| |
| |
def is_expr(node):
    """Tell whether ``node`` is an expression node, judged by its class name."""
    class_name = node.__class__.__name__
    return class_name in expr_class_names
|
|
def is_stmt(node):
    """Tell whether ``node`` is a statement node, judged by its class name."""
    class_name = node.__class__.__name__
    return class_name in stmt_class_names
|
|
def is_module(node):
    """Tell whether ``node`` is a module node, i.e. its class is named ``Module``."""
    class_name = node.__class__.__name__
    return class_name == 'Module'
|
|
def is_joined_str(node):
    """
    Tell whether ``node`` is a ``JoinedStr`` node (the representation of an
    f-string), judged by its class name.
    """
    class_name = node.__class__.__name__
    return class_name == 'JoinedStr'
|
|
|
|
def is_expr_stmt(node):
    """
    Tell whether ``node`` is an ``Expr`` node, i.e. a statement that consists of
    a single expression.
    """
    class_name = node.__class__.__name__
    return class_name == 'Expr'
|
|
|
|
|
|
# Classes treated as "constant" nodes: always ``ast.Constant``, plus
# ``astroid.nodes.Const`` when astroid can be imported.
try:
    from astroid.nodes import Const
except ImportError:
    # astroid is not importable; only the stdlib class applies.
    CONSTANT_CLASSES: Tuple[Type, ...] = (ast.Constant,)
else:
    CONSTANT_CLASSES = (ast.Constant, Const)
|
|
def is_constant(node):
    """Tell whether ``node`` is an instance of one of the ``CONSTANT_CLASSES``."""
    constant_node_types = CONSTANT_CLASSES
    return isinstance(node, constant_node_types)
|
|
|
|
def is_ellipsis(node):
    """Tell whether ``node`` is a constant node whose value is ``Ellipsis``."""
    if not is_constant(node):
        return False
    return node.value is Ellipsis
|
|
|
|
def is_starred(node):
    """Tell whether ``node`` is a starred expression node, judged by its class name."""
    class_name = node.__class__.__name__
    return class_name == 'Starred'
|
|
|
|
def is_slice(node):
    """
    Tell whether ``node`` represents a slice, e.g. ``1:2`` in ``x[1:2]``.

    A node counts as a slice when its class is named ``Slice`` or ``ExtSlice``,
    or when it is a ``Tuple`` with at least one element that is itself a slice.
    """
    class_name = node.__class__.__name__
    if class_name in ('Slice', 'ExtSlice'):
        return True
    if class_name != 'Tuple':
        return False
    # A tuple subscript such as ``x[1:2, 3]``: look for a slice among its elements.
    return any(is_slice(element) for element in cast(ast.Tuple, node).elts)
|
|
|
|
def is_empty_astroid_slice(node):
    """
    Tell whether ``node`` is a non-``ast`` node whose class is named ``Slice``
    and whose ``lower``, ``upper`` and ``step`` are all None.
    """
    if node.__class__.__name__ != "Slice":
        return False
    # Stdlib ast nodes never qualify; only foreign "Slice" classes are considered.
    if isinstance(node, ast.AST):
        return False
    return all(
        getattr(node, bound) is None
        for bound in ("lower", "upper", "step")
    )
|
|
|
|
| |
# Sentinel stored in the third slot of a stack entry to mark "not yet pre-visited".
_PREVISIT = object()


def visit_tree(node, previsit, postvisit):
    """
    Scans the tree under the node depth-first using an explicit stack. It avoids implicit recursion
    via the function call stack to avoid hitting 'maximum recursion depth exceeded' error.

    It calls ``previsit()`` and ``postvisit()`` as follows:

    * ``previsit(node, par_value)`` - should return ``(par_value, value)``
          ``par_value`` is as returned from ``previsit()`` of the parent.

    * ``postvisit(node, par_value, value)`` - should return ``value``
          ``par_value`` is as returned from ``previsit()`` of the parent, and ``value`` is as
          returned from ``previsit()`` of this node itself. The return ``value`` is ignored except
          the one for the root node, which is returned from the overall ``visit_tree()`` call.

    For the initial node, ``par_value`` is None. ``postvisit`` may be None.
    """
    if not postvisit:
        postvisit = lambda node, pvalue, value: None

    iter_children = iter_children_func(node)
    done = set()
    ret = None
    stack = [(node, None, _PREVISIT)]
    while stack:
        current, par_value, value = stack.pop()
        if value is _PREVISIT:
            # Each node must be pre-visited at most once.
            assert current not in done
            done.add(current)

            pvalue, post_value = previsit(current, par_value)
            # Re-push the node so that it is post-visited once all its children are done.
            stack.append((current, par_value, post_value))

            # Push the children in reverse so that the first child ends up on top of the
            # stack. A single reversed extend is linear in the number of children, unlike
            # inserting each child at a fixed index.
            pending = [(child, pvalue, _PREVISIT) for child in iter_children(current)]
            pending.reverse()
            stack.extend(pending)
        else:
            ret = postvisit(current, par_value, cast(Optional[Token], value))
    return ret
|
|
|
|
def walk(node, include_joined_str=False):
    """
    Recursively yield all descendant nodes in the tree starting at ``node`` (including ``node``
    itself), using depth-first pre-order traversal (yielding parents before their children).

    This is similar to ``ast.walk()``, but with a different order, and it works for both ``ast`` and
    ``astroid`` trees. Also, as ``iter_children()``, it skips singleton nodes generated by ``ast``.

    By default, ``JoinedStr`` (f-string) nodes and their contents are skipped
    because they previously couldn't be handled. Set ``include_joined_str`` to True to include them.
    """
    iter_children = iter_children_func(node)
    done = set()
    stack = [node]
    while stack:
        current = stack.pop()
        # Each node must be yielded at most once.
        assert current not in done
        done.add(current)

        yield current

        # Push the children in reverse so that the first child is popped (and yielded)
        # first. A single reversed extend is linear in the number of children, unlike
        # inserting each child at a fixed index.
        children = list(iter_children(current, include_joined_str))
        children.reverse()
        stack.extend(children)
|
|
|
|
def replace(text, replacements):
    """
    Substitute several slices of ``text`` at once and return the new string.

    ``replacements`` is an iterable of ``(start, end, new_text)`` tuples, e.g. ranges
    as identified by ``ASTTokens.get_text_range(node)``. They are applied in sorted order.

    For example, ``replace("this is a test", [(0, 4, "X"), (8, 9, "THE")])`` produces
    ``"X is THE test"``.
    """
    pieces = []
    cursor = 0
    for begin, stop, substitute in sorted(replacements):
        # Keep the untouched text before this range, then the substitute itself.
        pieces += (text[cursor:begin], substitute)
        cursor = stop
    # Whatever follows the last replaced range is kept unchanged.
    pieces.append(text[cursor:])
    return ''.join(pieces)
|
|
|
|
class NodeMethods:
    """
    Looks up ``visit_{node_type}`` methods by node class and remembers the results.
    """

    def __init__(self):
        # Maps a node class to the method that was resolved for it.
        self._cache = {}

    def get(self, obj, cls):
        """
        Return ``obj.visit_{node_type}``, where ``node_type`` is the lowercased name of
        ``cls``, falling back to ``obj.visit_default`` when no such method exists.

        The result is cached per ``cls``, so later calls with the same class return
        the method that was resolved the first time.
        """
        cached = self._cache.get(cls)
        if cached:
            return cached

        method_name = "visit_" + cls.__name__.lower()
        resolved = getattr(obj, method_name, obj.visit_default)
        self._cache[cls] = resolved
        return resolved
|
|
|
|
def patched_generate_tokens(original_tokens):
    """
    Fixes tokens yielded by `tokenize.generate_tokens` to handle more non-ASCII characters in identifiers.
    Workaround for https://github.com/python/cpython/issues/68382.
    Should only be used when tokenizing a string that is known to be valid syntax,
    because it assumes that error tokens are not actually errors.
    Combines groups of consecutive NAME, NUMBER, and/or ERRORTOKEN tokens into a single NAME token.

    ``original_tokens`` is an iterable of ``tokenize.TokenInfo``; the result is a generator
    yielding tokens in the same order, with each combinable group replaced by one token.
    """
    combinable_types = (tokenize.NAME, tokenize.ERRORTOKEN, tokenize.NUMBER)
    group = []
    for tok in original_tokens:
        if tok.type in combinable_types:
            # Only combine tokens that have no whitespace in between. A combinable token
            # that does not touch the pending group ends that group and starts a new one,
            # so that error tokens directly following it can still be merged with it.
            if group and group[-1].end != tok.start:
                yield from combine_tokens(group)
                group = []
            group.append(tok)
        else:
            yield from combine_tokens(group)
            group = []
            yield tok
    # Flush whatever is still pending once the input is exhausted.
    yield from combine_tokens(group)


def combine_tokens(group):
    """
    Merge a group of adjacent tokens into a single NAME token when appropriate.

    The group is returned unchanged unless it contains at least one ERRORTOKEN and
    all of its tokens come from the same source line. Otherwise a one-element list is
    returned, holding a NAME token that spans from the first token's start to the
    last token's end and whose string is the concatenation of all token strings.
    """
    has_error_token = any(tok.type == tokenize.ERRORTOKEN for tok in group)
    if not has_error_token or len({tok.line for tok in group}) != 1:
        return group

    first, last = group[0], group[-1]
    return [
        tokenize.TokenInfo(
            type=tokenize.NAME,
            string="".join(tok.string for tok in group),
            start=first.start,
            end=last.end,
            line=first.line,
        )
    ]
|
|
|
|
def last_stmt(node):
    """
    Descend to the innermost last statement of ``node``.

    As long as the current node has statement-like children, move to the last of
    them; the first node without any such children is returned. A node that has
    none to begin with is returned as is.
    """
    # Class names that count as statement containers although they are not
    # necessarily statement classes themselves.
    container_names = (
        "excepthandler",
        "ExceptHandler",
        "match_case",
        "MatchCase",
        "TryExcept",
        "TryFinally",
    )
    current = node
    while True:
        nested = [
            child for child in iter_children_func(current)(current)
            if is_stmt(child) or type(child).__name__ in container_names
        ]
        if not nested:
            return current
        current = nested[-1]
|
|
|
|
|
|
@lru_cache(maxsize=None)
def fstring_positions_work():
    """
    The positions attached to nodes inside f-string FormattedValues have some bugs
    that were fixed in Python 3.9.7 in https://github.com/python/cpython/pull/27729.
    This probes for those bugs directly rather than relying on the Python version.
    The sample source below covers:
    - Values with a format spec or conversion
    - Repeated (i.e. identical-looking) expressions
    - f-strings implicitly concatenated over multiple lines.
    - Multiline, triple-quoted f-strings.

    Returns True when every Name node in the sample has a unique position and its
    source segment equals its identifier. The result is cached after the first call.
    """
    source = """(
f"a {b}{b} c {d!r} e {f:g} h {i:{j}} k {l:{m:n}}"
f"a {b}{b} c {d!r} e {f:g} h {i:{j}} k {l:{m:n}}"
f"{x + y + z} {x} {y} {z} {z} {z!a} {z:z}"
f'''
{s} {t}
{u} {v}
'''
)"""
    names = [
        candidate
        for candidate in ast.walk(ast.parse(source))
        if isinstance(candidate, ast.Name)
    ]

    # Two names reported at the same (line, column) means positions are broken.
    distinct_positions = {(name.lineno, name.col_offset) for name in names}
    if len(distinct_positions) != len(names):
        return False

    # The text found at each reported position must be the identifier itself.
    return all(
        ast.get_source_segment(source, name) == name.id
        for name in names
    )
|
|
def annotate_fstring_nodes(tree):
    """
    Set the special attribute ``_broken_positions`` to True on nodes inside
    f-strings whose lineno/col_offset cannot be trusted.

    Nothing is annotated on Python 3.12 and later.
    """
    if sys.version_info >= (3, 12):
        return

    joined_strs = (
        candidate
        for candidate in walk(tree, include_joined_str=True)
        if isinstance(candidate, ast.JoinedStr)
    )
    for joined in joined_strs:
        for part in joined.values:
            # Every direct part of the f-string is flagged unconditionally.
            setattr(part, '_broken_positions', True)

            if not isinstance(part, ast.FormattedValue):
                continue

            # Nodes under the formatted expression are flagged only when the
            # interpreter's f-string positions are found to be unreliable.
            if not fstring_positions_work():
                for inner in walk(part.value):
                    setattr(inner, '_broken_positions', True)

            # A format spec, when present, is flagged as well.
            if part.format_spec:
                setattr(part.format_spec, '_broken_positions', True)
|
|