|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
"""Converting code to AST. |
|
|
|
|
|
Adapted from Tangent. |
|
|
""" |
|
|
|
|
|
import ast |
|
|
import inspect |
|
|
import io |
|
|
import linecache |
|
|
import re |
|
|
import sys |
|
|
import textwrap |
|
|
import tokenize |
|
|
|
|
|
import astunparse |
|
|
import gast |
|
|
|
|
|
from malt.pyct import errors |
|
|
from malt.pyct import inspect_utils |
|
|
|
|
|
|
|
|
# Legacy Python 2 preamble (a single newline). It is no longer selected, but
# the name is kept so that existing references to it continue to resolve.
PY2_PREAMBLE = textwrap.dedent("""
""")
PY3_PREAMBLE = ''

# From Python 3.9 on, the standard `ast` module provides `unparse`, so the
# module-level name `astunparse` is rebound to it.
if sys.version_info >= (3, 9):
    astunparse = ast

# This module uses f-strings, so it can only be loaded by Python 3. The former
# Python 2 branch (which read the non-existent `sys.maxint`) was unreachable
# and has been removed.
STANDARD_PREAMBLE = PY3_PREAMBLE
MAX_SIZE = sys.maxsize

# Number of `__future__` mentions in the preamble; `parse_expression` passes
# it to `parse` as the number of leading nodes to drop.
STANDARD_PREAMBLE_LEN = STANDARD_PREAMBLE.count('__future__')

# Matches the (possibly empty) run of whitespace at the start of a string.
_LEADING_WHITESPACE = re.compile(r'\s*')
|
|
|
|
|
|
|
|
def _unfold_continuations(code_string):
    """Joins lines that were split with a trailing backslash.

    Every backslash that is immediately followed by a newline is deleted
    together with that newline, wherever it occurs in the text.
    """
    continuation = '\\\n'
    return ''.join(code_string.split(continuation))
|
|
|
|
|
|
|
|
def dedent_block(code_string):
    """Removes the indentation of the first indented line from a code block.

    Backslash continuations are unfolded first. The text of the first INDENT
    token is taken as the block indentation; if a significant token appears
    before any INDENT token, the (unfolded) code is returned unchanged.

    Args:
      code_string: Text, the source code to dedent.

    Returns:
      Text, the dedented source code.

    Raises:
      errors.UnsupportedLanguageElementError: if an indentation token contains
        spaces while the block indentation contains tabs, or contains tabs
        while the block indentation does not.
    """
    code_string = _unfold_continuations(code_string)
    readline = io.StringIO(code_string).readline

    collected = []
    try:
        for token in tokenize.generate_tokens(readline):
            collected.append(token)
    except tokenize.TokenError:
        # Tokenization stopped early; carry on with the tokens produced
        # before the failure.
        pass

    # Find the indentation of the block: the first INDENT token, unless some
    # token other than a newline, string or comment comes first.
    ignorable = (
        tokenize.NL, tokenize.NEWLINE, tokenize.STRING, tokenize.COMMENT)
    first_indent = None
    for token in collected:
        kind, text = token[0], token[1]
        if kind == tokenize.INDENT:
            first_indent = text
            break
        if kind not in ignorable:
            first_indent = ''
            break

    if not first_indent:
        return code_string

    indent_width = len(first_indent)
    tabs_expected = '\t' in first_indent
    for index, token in enumerate(collected):
        kind, text = token[0], token[1]
        if kind != tokenize.INDENT:
            continue
        has_space = ' ' in text
        has_tab = '\t' in text
        if (has_space and tabs_expected) or (has_tab and not tabs_expected):
            raise errors.UnsupportedLanguageElementError(
                'code mixing tabs and spaces for indentation is not allowed')
        if len(text) >= indent_width:
            text = text[indent_width:]
        collected[index] = (kind, text)

    retokenized = tokenize.untokenize(collected)

    # Only the leading whitespace of the retokenized lines is used: each
    # original line is trimmed by the amount its indentation shrank.
    result_lines = []
    line_pairs = zip(code_string.split('\n'), retokenized.split('\n'))
    for original_line, retokenized_line in line_pairs:
        old_width = len(_LEADING_WHITESPACE.match(original_line).group())
        new_width = len(_LEADING_WHITESPACE.match(retokenized_line).group())
        if old_width > new_width:
            result_lines.append(original_line[old_width - new_width:])
        else:
            result_lines.append(original_line)

    return '\n'.join(result_lines)
|
|
|
|
|
|
|
|
def parse_entity(entity, future_features):
    """Returns the AST and source code of given entity.

    Lambdas are delegated to `_parse_lambda`; `future_features` is not applied
    to them.

    Args:
      entity: Any, Python function/method/class
      future_features: Iterable[Text], future features to use (e.g.
        'print_statement'). See
        https://docs.python.org/2/reference/simple_stmts.html#future. Any
        iterable is accepted, including one-shot iterators.

    Returns:
      gast.AST, Text: the parsed AST node; the source code that was parsed to
      generate the AST (including any prefixes that this function may have
      added).

    Raises:
      errors.InaccessibleSourceCodeError: if retrieving the source code of
        `entity` raises OSError.
    """
    if inspect_utils.islambda(entity):
        return _parse_lambda(entity)

    try:
        original_source = inspect_utils.getimmediatesource(entity)
    except OSError as e:
        raise errors.InaccessibleSourceCodeError(
            f'Unable to locate the source code of {entity}. Note that functions'
            ' defined in certain environments, like the interactive Python shell,'
            ' do not expose their source code. If that is the case, you should'
            ' define them in a .py source file. If you are certain the code is'
            ' graph-compatible, wrap the call using'
            f' @tf.autograph.experimental.do_not_convert. Original error: {e}'
        ) from e

    source = dedent_block(original_source)

    # The features are both iterated and counted below; materialize them once
    # so that generators and other unsized iterables work too.
    future_features = tuple(future_features)
    future_statements = tuple(
        'from __future__ import {}'.format(name) for name in future_features)
    source = '\n'.join(future_statements + (source,))

    return parse(source, preamble_len=len(future_features)), source
|
|
|
|
|
|
|
|
def _slice_utf8_columns(line, start=None, stop=None):
    """Slices `line` using column offsets expressed in UTF-8 bytes."""
    encoded = line.encode('utf-8')
    # `errors='ignore'` keeps this from raising should an offset ever fall
    # inside a multi-byte character.
    return encoded[start:stop].decode('utf-8', errors='ignore')


def _without_context(node, lines, minl, maxl):
    """Returns a clean node and source code without indenting and context.

    Args:
      node: the AST node to extract. The `lineno` and `end_lineno` attributes
        of this node and of every node below it are decreased by `minl`, in
        place.
      lines: List[Text], the lines of the file that contains the node.
      minl: Int, 1-based number of the first line spanned by the node.
      maxl: Int, 1-based number of the last line spanned by the node.

    Returns:
      Tuple (node, code_block): the same node object, and the text of lines
      `minl` to `maxl` with the first line cut at the node's start column, the
      last line cut at its end column (when those are known) and trailing
      whitespace removed from every line.
    """
    for n in gast.walk(node):
        lineno = getattr(n, 'lineno', None)
        if lineno is not None:
            n.lineno = lineno - minl
        end_lineno = getattr(n, 'end_lineno', None)
        if end_lineno is not None:
            n.end_lineno = end_lineno - minl

    # Slicing copies the list, so `lines` itself is left untouched.
    code_lines = lines[minl - 1:maxl]

    # AST column offsets count UTF-8 bytes rather than characters, so they
    # must not be used directly as string indices: that would cut at the wrong
    # position whenever non-ASCII text precedes the offset.
    # The end is trimmed before the start so that, when the node sits on a
    # single line, the start offset still refers to the same position.
    end_col_offset = getattr(node, 'end_col_offset', None)
    if end_col_offset is not None:
        code_lines[-1] = _slice_utf8_columns(
            code_lines[-1], stop=end_col_offset)

    col_offset = getattr(node, 'col_offset', None)
    if col_offset is not None:
        code_lines[0] = _slice_utf8_columns(code_lines[0], start=col_offset)
    else:
        # No column information: fall back to locating the `lambda` keyword.
        # The regex position is a character index, so plain slicing is right.
        match = re.search(r'(?<!\w)lambda(?!\w)', code_lines[0])
        if match is not None:
            code_lines[0] = code_lines[0][match.start(0):]

    code_block = '\n'.join(c.rstrip() for c in code_lines)

    return node, code_block
|
|
|
|
|
|
|
|
def _arg_name(node):
    """Returns the name carried by an argument node.

    `None` is passed through, a `gast.Name` yields its `id`, and anything else
    is asserted to be a string and returned as is.
    """
    if node is not None:
        if isinstance(node, gast.Name):
            return node.id
        assert isinstance(node, str)
    return node
|
|
|
|
|
|
|
|
def _node_matches_argspec(node, func):
    """Returns True if node fits the argspec of func.

    The names of the positional, variadic positional, variadic keyword and
    keyword-only parameters found in `node.args` must all equal those reported
    by `inspect.getfullargspec(func)`.
    """
    spec = inspect.getfullargspec(func)
    signature = node.args

    def names(arg_nodes):
        return tuple(_arg_name(arg) for arg in arg_nodes)

    return (
        names(signature.args) == tuple(spec.args)
        and spec.varargs == _arg_name(signature.vararg)
        and spec.varkw == _arg_name(signature.kwarg)
        and names(signature.kwonlyargs) == tuple(spec.kwonlyargs))
|
|
|
|
|
|
|
|
def _parse_lambda(lam):
    """Returns the AST and source code of given lambda function.

    The whole source file of the lambda is parsed, and the lambda nodes whose
    line span covers the first line of `lam`'s code object become candidates.
    If more than one candidate remains, they are narrowed down by comparing
    argument names against `lam`.

    Args:
      lam: types.LambdaType, the lambda function to look up.

    Returns:
      gast.AST, Text: the matching AST node and its source code, as produced
      by `_without_context`.

    Raises:
      errors.UnsupportedLanguageElementError: if no candidate is found, or if
        the candidates cannot be narrowed down to exactly one.
    """
    mod = inspect.getmodule(lam)
    f = inspect.getsourcefile(lam)
    def_line = lam.__code__.co_firstlineno

    # `inspect.getmodule` returns None when the module cannot be determined;
    # in that case the lines are looked up without module globals instead of
    # failing on `None.__dict__`.
    module_globals = None if mod is None else mod.__dict__
    lines = linecache.getlines(f, module_globals)
    source = ''.join(lines)

    # Keep the top-level nodes that start at or before the definition line.
    # Nodes without a line number are kept as well.
    all_nodes = parse(source, preamble_len=0, single_node=False)
    search_nodes = []
    for node in all_nodes:
        if getattr(node, 'lineno', def_line) <= def_line:
            search_nodes.append(node)
        else:
            # First node starting past the lambda: stop the search.
            break

    # Collect every lambda node found in the shortlist.
    lambda_nodes = [
        n
        for node in search_nodes
        for n in gast.walk(node)
        if isinstance(n, gast.Lambda)
    ]

    # Keep the lambda nodes whose line span contains the definition line.
    candidates = []
    for lambda_node in lambda_nodes:
        minl, maxl = MAX_SIZE, 0
        for n in gast.walk(lambda_node):
            minl = min(minl, getattr(n, 'lineno', minl))
            lineno = getattr(n, 'lineno', maxl)
            end_lineno = getattr(n, 'end_lineno', None)
            if end_lineno is not None:
                # Prefer the end line when the node records one.
                lineno = end_lineno
            maxl = max(maxl, lineno)
        if minl <= def_line <= maxl:
            candidates.append((lambda_node, minl, maxl))

    # Exactly one node found.
    if len(candidates) == 1:
        (node, minl, maxl), = candidates
        return _without_context(node, lines, minl, maxl)

    if not candidates:
        lambda_codes = '\n'.join([unparse(ln) for ln in lambda_nodes])
        raise errors.UnsupportedLanguageElementError(
            f'could not parse the source code of {lam}:'
            f' no matching AST found among candidates:\n{lambda_codes}')

    # Several nodes found: try to narrow the selection down by signature.
    matches = [v for v in candidates if _node_matches_argspec(v[0], lam)]
    if len(matches) == 1:
        (node, minl, maxl), = matches
        return _without_context(node, lines, minl, maxl)

    # Could not narrow down to a single node.
    match_listing = '\n'.join(
        'Match {}:\n{}\n'.format(
            i, unparse(node, include_encoding_marker=False))
        for i, (node, _, _) in enumerate(matches))
    raise errors.UnsupportedLanguageElementError(
        f'could not parse the source code of {lam}: found multiple definitions'
        ' with identical signatures at the location. This error'
        ' may be avoided by defining each lambda on a single line and with'
        f' unique argument names. The matching definitions were:\n{match_listing}')
|
|
|
|
|
|
|
|
|
|
|
def parse(src, preamble_len=0, single_node=True):
    """Parses a piece of code into gast nodes.

    Args:
      src: Text
      preamble_len: Int, number of leading nodes of the parsed module body to
        drop.
      single_node: Bool, whether `src` is assumed to be represented by exactly
        one AST node.

    Returns:
      The only remaining node when `single_node` is true; otherwise the list
      of remaining top-level nodes.

    Raises:
      ValueError: if `single_node` is true and the number of remaining nodes
        is not exactly one.
    """
    body = gast.parse(src).body
    remaining = body[preamble_len:] if preamble_len else body
    if not single_node:
        return remaining
    if len(remaining) != 1:
        raise ValueError('expected exactly one node, got {}'.format(remaining))
    return remaining[0]
|
|
|
|
|
|
|
|
def parse_expression(src):
    """Returns the AST of given identifier.

    Args:
      src: A piece of code that represents a single Python expression. Leading
        and trailing whitespace is stripped before parsing.

    Returns:
      A gast.AST object: the value of the parsed expression statement.

    Raises:
      ValueError: if src does not consist of a single Expression.
    """
    src = STANDARD_PREAMBLE + src.strip()
    node = parse(src, preamble_len=STANDARD_PREAMBLE_LEN, single_node=True)
    # This check must not depend on `__debug__`: with `python -O` it would be
    # skipped, and statements that happen to have a `value` attribute (such as
    # assignments) would be accepted silently instead of raising ValueError.
    if not isinstance(node, gast.Expr):
        raise ValueError(
            'expected exactly one node of type Expr, got {}'.format(node))
    return node.value
|
|
|
|
|
|
|
|
def unparse(node, indentation=None, include_encoding_marker=True):
    """Returns the source code of given AST.

    Args:
      node: The code to compile, as an AST object, or a list or tuple of such
        objects.
      indentation: Unused, deprecated.
      include_encoding_marker: Bool, whether to include a comment on the first
        line to explicitly specify UTF-8 encoding.

    Returns:
      Text, the source code generated for each node, stripped of surrounding
      whitespace and joined with newlines.
    """
    del indentation
    nodes = node if isinstance(node, (list, tuple)) else (node,)

    chunks = ['# coding=utf-8'] if include_encoding_marker else []
    for item in nodes:
        # gast nodes are converted to standard `ast` nodes before unparsing;
        # anything else is handed over as is.
        std_node = gast.gast_to_ast(item) if isinstance(item, gast.AST) else item
        if astunparse is ast:
            # Only done when the standard library `ast` does the unparsing.
            ast.fix_missing_locations(std_node)
        chunks.append(astunparse.unparse(std_node).strip())

    return '\n'.join(chunks)
|
|
|