| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| |
|
| | import ast |
| | import collections |
| | import io |
| | import sys |
| | import token |
| | import tokenize |
| | from abc import ABCMeta |
| | from ast import Module, expr, AST |
| | from typing import Callable, Dict, Iterable, Iterator, List, Optional, Tuple, Union, cast, Any, TYPE_CHECKING |
| |
|
| | from six import iteritems |
| |
|
| |
|
if TYPE_CHECKING:  # These definitions exist only for the type checker; NodeNG comes from astroid.
    from .astroid_compat import NodeNG

    # Type class used to expand out the definition of AST to include fields added by this library.
    # It's only used for type checking; at runtime the attributes are set on plain AST nodes.
    class EnhancedAST(AST):
        # The first/last tokens that make up this node; assigned externally after tokenizing.
        first_token = None
        last_token = None
        lineno = 0

    # A node from either the stdlib ``ast`` tree or an ``astroid`` tree.
    AstNode = Union[EnhancedAST, NodeNG]

    # Python 2's tokenize produces plain 5-tuples; Python 3 has a named TokenInfo class.
    if sys.version_info[0] == 2:
        TokenInfo = Tuple[int, str, Tuple[int, int], Tuple[int, int], str]
    else:
        TokenInfo = tokenize.TokenInfo
| |
|
| |
|
def token_repr(tok_type, string):
    """Return a readable ``TYPE:'string'`` description of a token type/string pair."""
    # lstrip('u') drops the Python 2 unicode-literal prefix that repr() may produce.
    text = repr(string).lstrip('u')
    return '{0}:{1}'.format(token.tok_name[tok_type], text)
| |
|
| |
|
class Token(collections.namedtuple('Token', 'type string start end line index startpos endpos')):
    """
    An 8-tuple token: the five fields produced by the ``tokenize`` module, plus three
    extras maintained by this package.

    - [0] .type       Token type (see token.py)
    - [1] .string     The token's text
    - [2] .start      (row, column) pair where the token starts
    - [3] .end        (row, column) pair where the token ends
    - [4] .line       The full source line containing the token
    - [5] .index      Position of this token within the token list it belongs to
    - [6] .startpos   Character offset of the token's start in the input text
    - [7] .endpos     Character offset of the token's end in the input text
    """
    def __str__(self):
        # Delegate to the shared helper; self[:2] is (type, string).
        return token_repr(*self[:2])
| |
|
| |
|
# ast.Constant exists only on Python 3.6+.  On older versions, provide a stand-in
# whose ``value`` is a fresh sentinel object, so attribute access works but can
# never compare equal to a real constant's value.
if sys.version_info >= (3, 6):
    AstConstant = ast.Constant
else:
    class AstConstant:
        value = object()
| |
|
| |
|
def match_token(token, tok_type, tok_str=None):
    """Return True if ``token`` has the given type and, when ``tok_str`` is given, that string."""
    if token.type != tok_type:
        return False
    return tok_str is None or token.string == tok_str
| |
|
| |
|
def expect_token(token, tok_type, tok_str=None):
    """
    Assert that ``token`` has the expected type (and string, when ``tok_str`` is given).
    Raises an informative ValueError when it does not match.
    """
    # Guard clause: nothing to do when the token matches.
    if match_token(token, tok_type, tok_str):
        return
    raise ValueError("Expected token %s, got %s on line %s col %s" % (
        token_repr(tok_type, tok_str), str(token),
        token.start[0], token.start[1] + 1))
| |
|
| | |
| | |
if sys.version_info >= (3, 7):
    def is_non_coding_token(token_type):
        """True for token types (NL, COMMENT, ENCODING) that don't affect the syntax tree."""
        return token_type in {token.NL, token.COMMENT, token.ENCODING}
else:
    def is_non_coding_token(token_type):
        """True for token types that don't affect the syntax tree."""
        # Before 3.7, the non-coding pseudo-tokens all live at/above N_TOKENS.
        return token_type >= token.N_TOKENS
| |
|
| |
|
def generate_tokens(text):
    """
    Tokenize ``text`` (source code) with the standard library and return the token iterator.
    """
    readline = io.StringIO(text).readline
    # cast() only appeases the type checker; StringIO.readline already returns str.
    return tokenize.generate_tokens(cast(Callable[[], str], readline))
| |
|
| |
|
def iter_children_func(node):
    """
    Return the child-iteration function appropriate for ``node``: the astroid
    variant when the node exposes ``get_children``, the stdlib ``ast`` variant
    otherwise.  Both skip singleton child nodes.
    """
    if hasattr(node, 'get_children'):
        return iter_children_astroid
    return iter_children_ast
| |
|
| |
|
def iter_children_astroid(node):
    """Return the direct children of an astroid node, treating f-strings as leaves."""
    # JoinedStr (f-string) children have unreliable positions, so don't descend.
    return [] if is_joined_str(node) else node.get_children()
| |
|
| |
|
# Singleton AST classes (operator, context, etc. nodes) that carry no position info
# and are skipped when iterating children.  Iterates ast's namespace directly:
# plain dict iteration works on both Python 2 and 3, so the third-party
# six.iteritems round-trip is unnecessary here.
SINGLETONS = {c for c in ast.__dict__.values() if isinstance(c, type) and
              issubclass(c, (ast.expr_context, ast.boolop, ast.operator, ast.unaryop, ast.cmpop))}
| |
|
def iter_children_ast(node):
    """Yield the direct children of a stdlib ``ast`` node, skipping singleton nodes."""
    # f-string internals have unreliable positions; treat the JoinedStr as a leaf.
    if is_joined_str(node):
        return

    if isinstance(node, ast.Dict):
        # Emit each key before its matching value.  A None key (the `**d`
        # dict-unpacking form) has no node and is skipped.
        for key, value in zip(node.keys, node.values):
            if key is not None:
                yield key
            yield value
        return

    for child in ast.iter_child_nodes(node):
        if child.__class__ in SINGLETONS:
            continue
        yield child
| |
|
| |
|
# Class-name sets used for quick node classification by name, which works for
# both ast and astroid nodes.  Plain dict iteration replaces six.iteritems:
# .items() works on both Python 2 and 3, so the third-party helper adds nothing.
stmt_class_names = {n for n, c in ast.__dict__.items()
                    if isinstance(c, type) and issubclass(c, ast.stmt)}
# Expression class names from ast, plus astroid-only expression node names.
expr_class_names = ({n for n, c in ast.__dict__.items()
                     if isinstance(c, type) and issubclass(c, ast.expr)} |
                    {'AssignName', 'DelName', 'Const', 'AssignAttr', 'DelAttr'})
| |
|
| | |
| | |
def is_expr(node):
    """Return whether ``node`` is an expression node (ast or astroid)."""
    # Compare by class name so astroid nodes are handled too.
    name = type(node).__name__
    return name in expr_class_names
| |
|
def is_stmt(node):
    """Return whether ``node`` is a statement node (ast or astroid)."""
    # Compare by class name so astroid nodes are handled too.
    name = type(node).__name__
    return name in stmt_class_names
| |
|
def is_module(node):
    """Return whether ``node`` is a module node (ast or astroid)."""
    return type(node).__name__ == 'Module'
| |
|
def is_joined_str(node):
    """Return whether ``node`` is a JoinedStr node, i.e. an f-string."""
    # Name comparison keeps this working for both ast and astroid trees.
    return type(node).__name__ == 'JoinedStr'
| |
|
| |
|
def is_starred(node):
    """Return whether ``node`` is a starred expression node (``*expr``)."""
    return type(node).__name__ == 'Starred'
| |
|
| |
|
def is_slice(node):
    """Return whether ``node`` represents a slice, e.g. `1:2` in `x[1:2]`."""
    name = type(node).__name__
    if name in ('Slice', 'ExtSlice'):
        return True
    # An extended subscript like `x[1:2, 3]` parses as a Tuple containing at
    # least one Slice element, so recurse into tuple elements.
    return name == 'Tuple' and any(is_slice(elt) for elt in cast(ast.Tuple, node).elts)
| |
|
| |
|
def is_empty_astroid_slice(node):
    """True for an astroid Slice node with no lower, upper or step (a bare ``::``)."""
    # Must be a Slice by name but NOT a stdlib ast node — i.e. astroid's Slice.
    if type(node).__name__ != "Slice" or isinstance(node, ast.AST):
        return False
    return node.lower is None and node.upper is None and node.step is None
| |
|
| |
|
| | |
# Sentinel pushed onto the visit_tree() stack to mark a node still awaiting its previsit phase.
_PREVISIT = object()
| |
|
def visit_tree(node, previsit, postvisit):
    """
    Scans the tree under the node depth-first using an explicit stack. It avoids implicit recursion
    via the function call stack to avoid hitting 'maximum recursion depth exceeded' error.

    It calls ``previsit()`` and ``postvisit()`` as follows:

    * ``previsit(node, par_value)`` - should return ``(par_value, value)``
      ``par_value`` is as returned from ``previsit()`` of the parent.

    * ``postvisit(node, par_value, value)`` - should return ``value``
      ``par_value`` is as returned from ``previsit()`` of the parent, and ``value`` is as
      returned from ``previsit()`` of this node itself. The return ``value`` is ignored except
      the one for the root node, which is returned from the overall ``visit_tree()`` call.

    For the initial node, ``par_value`` is None. ``postvisit`` may be None.
    """
    # A missing postvisit becomes a no-op so the main loop needs no branching.
    if not postvisit:
        postvisit = lambda node, pvalue, value: None

    iter_children = iter_children_func(node)
    done = set()
    ret = None
    # Each stack entry is (node, parent's previsit value, phase-marker/previsit value):
    # _PREVISIT means the node hasn't been previsited yet; anything else is the
    # value to pass to postvisit.
    stack = [(node, None, _PREVISIT)]
    while stack:
        current, par_value, value = stack.pop()
        if value is _PREVISIT:
            # Seeing a node twice would mean the tree has a cycle.
            assert current not in done
            done.add(current)

            pvalue, post_value = previsit(current, par_value)
            # Re-push the node with its postvisit payload; children get inserted
            # above it so they are fully processed before the postvisit fires.
            stack.append((current, par_value, post_value))

            # Each insert at the fixed index `ins` shifts earlier siblings right,
            # leaving children reversed on the stack so the FIRST child is popped
            # (and previsited) first — i.e. pre-order traversal.
            ins = len(stack)
            for n in iter_children(current):
                stack.insert(ins, (n, pvalue, _PREVISIT))
        else:
            # Only the root's postvisit result survives as the overall return value.
            ret = postvisit(current, par_value, cast(Optional[Token], value))
    return ret
| |
|
| |
|
| |
|
def walk(node):
    """
    Yield ``node`` and every descendant in depth-first pre-order (each parent before
    its children), without using call-stack recursion.

    Similar to ``ast.walk()`` but with a different order; works for both ``ast`` and
    ``astroid`` trees and, like ``iter_children()``, skips the singleton helper nodes
    that ``ast`` produces.
    """
    iter_children = iter_children_func(node)
    seen = set()
    pending = [node]
    while pending:
        current = pending.pop()
        # Visiting the same node twice would mean the tree contains a cycle.
        assert current not in seen
        seen.add(current)

        yield current

        # Push children reversed so the first child is popped — and yielded — next.
        pending.extend(reversed(list(iter_children(current))))
| |
|
| |
|
def replace(text, replacements):
    """
    Apply multiple slice replacements to ``text`` at once — a convenience for code
    modifications over ranges such as those from ``ASTTokens.get_text_range(node)``.
    ``replacements`` is an iterable of ``(start, end, new_text)`` tuples.

    For example, ``replace("this is a test", [(0, 4, "X"), (8, 9, "THE")])`` produces
    ``"X is THE test"``.
    """
    pieces = []
    cursor = 0
    # Sorting guarantees the (assumed non-overlapping) ranges are applied left-to-right.
    for start, end, new_text in sorted(replacements):
        pieces.append(text[cursor:start])
        pieces.append(new_text)
        cursor = end
    pieces.append(text[cursor:])
    return ''.join(pieces)
| |
|
| |
|
class NodeMethods(object):
    """
    Resolves ``visit_{node_type}`` methods for node classes, caching each lookup.
    """
    def __init__(self):
        # Maps a node class to the visitor method resolved for it.
        self._cache = {}

    def get(self, obj, cls):
        """
        Using the lowercase name of the class as node_type, returns `obj.visit_{node_type}`,
        or `obj.visit_default` if the type-specific method is not found.
        """
        try:
            return self._cache[cls]
        except KeyError:
            # First lookup for this class: resolve and remember the method.
            method = getattr(obj, 'visit_' + cls.__name__.lower(), obj.visit_default)
            self._cache[cls] = method
            return method
| |
|
| |
|
if sys.version_info[0] == 2:
    def patched_generate_tokens(original_tokens):
        """Python 2 tokens need no fixing; pass them through unchanged."""
        return iter(original_tokens)
else:
    def patched_generate_tokens(original_tokens):
        """
        Fixes tokens yielded by `tokenize.generate_tokens` to handle more non-ASCII
        characters in identifiers (workaround for
        https://github.com/python/cpython/issues/68382).  Use only on strings known
        to be valid syntax, since error tokens are assumed not to be actual errors.
        Merges runs of adjacent NAME, NUMBER and/or ERRORTOKEN tokens into one NAME.
        """
        pending = []
        for current in original_tokens:
            mergeable = current.type in (tokenize.NAME, tokenize.ERRORTOKEN, tokenize.NUMBER)
            # Only extend the run when the new token starts exactly where the last ended.
            adjacent = not pending or pending[-1].end == current.start
            if mergeable and adjacent:
                pending.append(current)
                continue
            # Flush the accumulated run, then emit the non-mergeable token itself.
            for merged in combine_tokens(pending):
                yield merged
            pending = []
            yield current
        for merged in combine_tokens(pending):
            yield merged
| |
|
def combine_tokens(group):
    """
    Merge a run of tokens into a single NAME token, but only when the run contains
    at least one ERRORTOKEN and sits entirely on one source line; otherwise the
    group is returned unchanged.
    """
    has_error = any(tok.type == tokenize.ERRORTOKEN for tok in group)
    on_one_line = len({tok.line for tok in group}) == 1
    if not (has_error and on_one_line):
        return group
    merged = tokenize.TokenInfo(
        type=tokenize.NAME,
        string="".join(t.string for t in group),
        start=group[0].start,
        end=group[-1].end,
        line=group[0].line,
    )
    return [merged]
| |
|
| |
|
def last_stmt(node):
    """
    Return the last statement contained (recursively) in the given AST node,
    or the node itself when it holds no child statements.
    """
    # match_case only exists on 3.10+; getattr with () keeps isinstance happy elsewhere.
    stmt_types = (ast.stmt, ast.excepthandler, getattr(ast, "match_case", ()))
    children = [c for c in ast.iter_child_nodes(node) if isinstance(c, stmt_types)]
    return last_stmt(children[-1]) if children else node
| |
|
| |
|
if sys.version_info[:2] >= (3, 8):
    from functools import lru_cache

    # Cached: whether f-string positions work is a property of the running
    # interpreter, so the probe only needs to run once.
    @lru_cache(maxsize=None)
    def fstring_positions_work():
        """
        The positions attached to nodes inside f-string FormattedValues have some bugs
        that were fixed in Python 3.9.7 in https://github.com/python/cpython/pull/27729.
        This checks for those bugs more concretely without relying on the Python version.
        Specifically this checks:
         - Values with a format spec or conversion
         - Repeated (i.e. identical-looking) expressions
         - Multiline f-strings implicitly concatenated.
        """
        source = """(
            f"a {b}{b} c {d!r} e {f:g} h {i:{j}} k {l:{m:n}}"
            f"a {b}{b} c {d!r} e {f:g} h {i:{j}} k {l:{m:n}}"
            f"{x + y + z} {x} {y} {z} {z} {z!a} {z:z}"
        )"""
        tree = ast.parse(source)
        name_nodes = [node for node in ast.walk(tree) if isinstance(node, ast.Name)]
        name_positions = [(node.lineno, node.col_offset) for node in name_nodes]
        # Buggy interpreters attach the same position to distinct expressions.
        positions_are_unique = len(set(name_positions)) == len(name_positions)
        # With correct positions, each Name's source segment is exactly its identifier.
        correct_source_segments = all(
            ast.get_source_segment(source, node) == node.id
            for node in name_nodes
        )
        return positions_are_unique and correct_source_segments

    def annotate_fstring_nodes(tree):
        """
        Add a special attribute `_broken_positions` to nodes inside f-strings
        if the lineno/col_offset cannot be trusted.
        """
        for joinedstr in walk(tree):
            if not isinstance(joinedstr, ast.JoinedStr):
                continue
            for part in joinedstr.values:
                # Direct parts of an f-string (constants and formatted values)
                # are always marked as having untrustworthy positions.
                setattr(part, '_broken_positions', True)

                if isinstance(part, ast.FormattedValue):
                    # Only mark the expression's inner nodes when this
                    # interpreter is known to produce broken positions.
                    if not fstring_positions_work():
                        for child in walk(part.value):
                            setattr(child, '_broken_positions', True)

                    if part.format_spec:
                        # The format spec is itself a JoinedStr; mark it and
                        # recurse, since specs may contain nested f-string values.
                        setattr(part.format_spec, '_broken_positions', True)
                        annotate_fstring_nodes(part.format_spec)
else:
    def fstring_positions_work():
        # Pre-3.8 lacks ast.get_source_segment (needed for the probe above);
        # conservatively report that positions cannot be trusted.
        return False

    def annotate_fstring_nodes(_tree):
        # No position fix-ups are applied on pre-3.8 interpreters.
        pass
| |
|