|
|
|
|
|
|
|
|
from __future__ import annotations |
|
|
|
|
|
import ast |
|
|
from typing import Any, Optional, Sequence, Set |
|
|
|
|
|
from onnxscript import sourceinfo |
|
|
from onnxscript._internal import ast_utils |
|
|
|
|
|
|
|
|
def _get_loop_var(for_stmt: ast.For, formatter: sourceinfo.Formatter) -> str:
    """Return the name bound by the target of a ``for`` statement.

    Raises:
        TypeError: If the target is anything other than a single name
            (for example a tuple-unpacking target such as ``for a, b in ...``).
    """
    target = for_stmt.target
    if isinstance(target, ast.Name):
        return target.id
    raise TypeError(formatter(for_stmt, "For loop target must be a single variable."))
|
|
|
|
|
|
|
|
def _used_vars(expr: Optional[ast.expr]) -> Set[str]:
    """Return the set of variable names read by an expression.

    Names appearing anywhere in a non-call expression are collected
    recursively. For a call, only the positional arguments are analyzed
    recursively:

    * the callee expression is NOT visited, so ``f(x)`` yields ``{"x"}`` and
      ``op.Add(a, b)`` yields ``{"a", "b"}`` (neither ``f`` nor ``op``);
    * a keyword argument contributes its value only when that value is a bare
      name (``axis=k`` yields ``k``; ``axis=g(k)`` contributes nothing).

    Args:
        expr: The expression to analyze. ``None`` (e.g. a bare ``return``)
            yields the empty set.

    Returns:
        A new set of variable names.
    """
    if expr is None:
        return set()
    if isinstance(expr, ast.Name):
        return {expr.id}
    result: Set[str] = set()
    if isinstance(expr, ast.Call):
        # Only positional arguments are analyzed recursively; the callee is skipped.
        children = expr.args
        for keyword in expr.keywords:
            if isinstance(keyword.value, ast.Name):
                result.add(keyword.value.id)
    else:
        children = ast.iter_child_nodes(expr)
    for child in children:
        # Update in place instead of rebuilding the set for every child.
        result.update(_used_vars(child))
    return result
|
|
|
|
|
|
|
|
def _lhs_vars(lhs: ast.expr) -> Set[str]:
    """Collect the names bound by the target of an assignment.

    The target must be either a single name (``x = ...``) or a flat tuple of
    names (``a, b = ...``). Any other target form fails the assertion below.
    """
    elements = lhs.elts if isinstance(lhs, ast.Tuple) else [lhs]
    names: Set[str] = set()
    for element in elements:
        assert isinstance(element, ast.Name), "Only simple assignments supported."
        names.add(element.id)
    return names
|
|
|
|
|
|
|
|
def assigned_vars(
    stmt: ast.stmt | list[ast.stmt], formatter: sourceinfo.Formatter
) -> Set[str]:
    """Collect every variable name that may be bound when `stmt` executes.

    `stmt` may be a single statement or a list of statements. A control-flow
    statement contributes the names assigned on any of its paths: both
    branches of an ``if``, the body of a loop, and, for a ``for`` loop, the
    loop variable itself.

    Args:
        stmt: A statement-ast, or a list of statement-asts.
        formatter: Used to attach source information to error messages.

    Returns:
        The (possibly empty) set of assigned variable names.

    Raises:
        TypeError: If a ``for`` loop target is not a single variable.
        ValueError: If an unsupported statement type is encountered.
    """

    def union_of(stmts: Sequence[ast.stmt]) -> Set[str]:
        names: set[Any] = set()
        for child in stmts:
            names |= assigned_vars(child, formatter)
        return names

    if isinstance(stmt, list):
        return union_of(stmt)
    if isinstance(stmt, ast.Assign):
        return _lhs_vars(stmt.targets[0])
    if isinstance(stmt, ast.AnnAssign):
        return _lhs_vars(stmt.target)
    if isinstance(stmt, ast.If):
        return union_of(stmt.body) | union_of(stmt.orelse)
    if isinstance(stmt, ast.For):
        from_body = union_of(stmt.body)
        return from_body | {_get_loop_var(stmt, formatter)}
    if isinstance(stmt, ast.While):
        return union_of(stmt.body)
    # Statements that bind no variables.
    if isinstance(stmt, (ast.Return, ast.Break)):
        return set()
    if ast_utils.is_print_call(stmt) or ast_utils.is_doc_string(stmt):
        return set()
    raise ValueError(formatter(stmt, f"Unsupported statement type {type(stmt)!r}."))
|
|
|
|
|
|
|
|
def do_liveness_analysis(fun: ast.FunctionDef, formatter: sourceinfo.Formatter) -> None:
    """Perform liveness analysis of the given function-ast.

    The results of the analysis are stored directly with each statement-ast `s`
    (including statements nested in if/for/while bodies) as attributes
    `s.live_in` (variables live just before `s`) and `s.live_out` (variables
    live just after `s`).

    Loops are treated as executing zero or more times. Everything live after a
    loop is therefore also live at the loop head. The variables read by a for
    loop's iterable or a while loop's condition are live on entry to the loop.

    Args:
        fun: The function-ast to analyze.
        formatter: Used to attach source information to error messages.

    Raises:
        TypeError: If a for-loop target is not a single variable.
        ValueError: If an unsupported statement type is encountered.
    """

    def visit(stmt: ast.stmt, live_out: Set[str]) -> Set[str]:
        stmt.live_out = live_out
        live = do_visit(stmt, live_out)
        stmt.live_in = live
        return live

    def do_visit(stmt: ast.stmt, live_out: Set[str]) -> Set[str]:
        def visitBlock(block: Sequence[ast.stmt], live_out: Set[str]) -> Set[str]:
            # Propagate liveness backwards through the block, annotating each statement.
            for s in reversed(block):
                live_out = visit(s, live_out)
            return live_out

        if isinstance(stmt, ast.Assign):
            return live_out.difference(_lhs_vars(stmt.targets[0])) | _used_vars(stmt.value)
        if isinstance(stmt, ast.AnnAssign):
            return live_out.difference(_lhs_vars(stmt.target)) | _used_vars(stmt.value)
        if isinstance(stmt, ast.Return):
            # Nothing after a return executes, so only the returned value's uses are live.
            return _used_vars(stmt.value)
        if isinstance(stmt, ast.If):
            live1 = visitBlock(stmt.body, live_out)
            live2 = visitBlock(stmt.orelse, live_out)
            return live1 | live2 | _used_vars(stmt.test)
        if isinstance(stmt, ast.For):
            # The loop variable is bound by the loop itself, so it is excluded from the
            # loop-head set. The body may run zero times, so everything live after the
            # loop stays in the head set on every round. The set only grows, so this
            # fixpoint terminates.
            loop_var_set = {_get_loop_var(stmt, formatter)}
            used_after_loop = live_out.difference(loop_var_set)
            prev = None
            curr = used_after_loop
            while curr != prev:
                prev = curr
                curr = used_after_loop | visitBlock(stmt.body, prev).difference(loop_var_set)
            # The iterable is evaluated once, before the first iteration.
            return curr | _used_vars(stmt.iter)
        if isinstance(stmt, ast.While):
            # The condition is evaluated at the loop head. From there control either
            # exits (needing live_out) or runs the body. Both must be kept on every
            # round. The set only grows, so this fixpoint terminates.
            cond_vars = _used_vars(stmt.test)
            prev = None
            curr = live_out | cond_vars
            while curr != prev:
                prev = curr
                curr = live_out | cond_vars | visitBlock(stmt.body, prev)
            return curr
        if isinstance(stmt, ast.Break):
            # NOTE(review): break is treated as a no-op. When it is the last statement
            # of a loop body, the live_out passed here is the loop-head set, which
            # includes everything live after the loop, so the result is conservative.
            # A break in the middle of a body would need the loop's own live_out.
            return live_out
        if ast_utils.is_doc_string(stmt):
            return live_out
        if isinstance(stmt, ast.FunctionDef):
            # NOTE(review): unlike exposed_uses, a nested def is a no-op here. Its name
            # is not removed and the outer-scope variables it reads are not added.
            # Confirm this is intended.
            return live_out
        if ast_utils.is_print_call(stmt):
            return live_out
        raise ValueError(formatter(stmt, f"Unsupported statement type {type(stmt)!r}."))

    assert isinstance(fun, ast.FunctionDef)
    live: set[Any] = set()
    for s in reversed(fun.body):
        live = visit(s, live)
|
|
|
|
|
|
|
|
def exposed_uses(stmts: Sequence[ast.stmt], formatter: sourceinfo.Formatter):
    """Return the set of variables that are used before being defined by given block.

    In essence, this identifies the "inputs" to a given code-block.
    For example, consider the following code-block:
    ::

        x = x + 10
        y = 20
        z = x + y
        x = 30

    The exposed_uses of this code-block is { x }. The value of z is not used within
    the block. Even though the value of y is used within the block, it is assigned
    a value before it is used. However, in contrast, the incoming value of x is used
    (in the first statement). Hence x is included in the exposed_uses.

    Args:
        stmts: The statements forming the block.
        formatter: Used to attach source information to error messages.

    Returns:
        The set of variable names whose incoming values the block may read.

    Raises:
        TypeError: If a for-loop target is not a single variable.
        ValueError: If an unsupported statement type is encountered.
    """

    def visitBlock(block: Sequence[ast.stmt], live_out: Set[str]) -> Set[str]:
        for stmt in reversed(block):
            live_out = visit(stmt, live_out)
        return live_out

    def visit(stmt: ast.stmt, live_out: Set[str]) -> Set[str]:
        if isinstance(stmt, ast.Assign):
            return live_out.difference(_lhs_vars(stmt.targets[0])) | _used_vars(stmt.value)
        if isinstance(stmt, ast.AnnAssign):
            return live_out.difference(_lhs_vars(stmt.target)) | _used_vars(stmt.value)
        if isinstance(stmt, ast.Return):
            return _used_vars(stmt.value)
        if isinstance(stmt, ast.If):
            live1 = visitBlock(stmt.body, live_out)
            live2 = visitBlock(stmt.orelse, live_out)
            return (live1 | live2) | _used_vars(stmt.test)
        if ast_utils.is_print_call(stmt):
            return live_out
        if ast_utils.is_doc_string(stmt):
            return live_out
        if isinstance(stmt, ast.For):
            # Analysis assumes the loop may execute zero times, so everything used
            # after the loop (other than the loop variable) is also exposed.
            loop_var_set = {_get_loop_var(stmt, formatter)}
            used_after_loop = live_out.difference(loop_var_set)
            used_inside_loop = visitBlock(stmt.body, set()).difference(loop_var_set)
            used_in_loop_header = _used_vars(stmt.iter)
            return used_inside_loop | used_in_loop_header | used_after_loop
        if isinstance(stmt, ast.While):
            # Analysis assumes the loop may execute zero times, so everything used
            # after the loop is also exposed.
            used_inside_loop = visitBlock(stmt.body, set())
            used_in_loop_header = _used_vars(stmt.test)
            return used_inside_loop | used_in_loop_header | live_out
        if isinstance(stmt, ast.Break):
            # Loop bodies are visited with an empty live_out. The For/While branches
            # above add what is live after the loop separately, so treating break as
            # a no-op is conservative.
            return live_out
        if isinstance(stmt, ast.FunctionDef):
            # A nested def binds its own name, so later uses of that name are not
            # exposed. It also reads the outer-scope variables its body captures.
            # Build a new set instead of mutating live_out in place. The same set
            # object can be shared with other paths (both branches of an If receive
            # it), and removing the name in place would corrupt their results.
            return live_out.difference({stmt.name}) | outer_scope_variables(stmt, formatter)
        raise ValueError(formatter(stmt, f"Unsupported statement type {type(stmt)!r}."))

    return visitBlock(stmts, set())
|
|
|
|
|
|
|
|
def outer_scope_variables(fun: ast.FunctionDef, formatter: sourceinfo.Formatter):
    """Return the set of outer-scope variables used in a nested function.

    These are the variables the function body reads before assigning them,
    excluding the function's own parameters. Parameters of every kind are
    excluded: positional-only, regular, keyword-only, ``*args`` and
    ``**kwargs``.

    Args:
        fun: The function-ast to analyze.
        formatter: The formatter object, used for error messages.

    Returns:
        A set of variable names (strings).
    """
    assert isinstance(fun, ast.FunctionDef)
    used_vars_ = exposed_uses(fun.body, formatter)
    signature = fun.args
    # Previously only `signature.args` was considered, so keyword-only or
    # positional-only parameters read in the body were misreported as
    # outer-scope variables.
    inputs = [
        a.arg for a in (*signature.posonlyargs, *signature.args, *signature.kwonlyargs)
    ]
    if signature.vararg is not None:
        inputs.append(signature.vararg.arg)
    if signature.kwarg is not None:
        inputs.append(signature.kwarg.arg)
    return used_vars_.difference(inputs)
|
|
|