# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from __future__ import annotations
import ast
from typing import Any, Optional, Sequence, Set
from onnxscript import sourceinfo
from onnxscript._internal import ast_utils
def _get_loop_var(for_stmt: ast.For, formatter: sourceinfo.Formatter) -> str:
    """Return the name of the variable bound by a ``for`` loop.

    Args:
        for_stmt: The for-loop AST node whose target is inspected.
        formatter: Callable invoked as ``formatter(node, message)`` to build
            the text of the error raised for an unsupported target.

    Returns:
        The identifier of the loop target.

    Raises:
        TypeError: If the loop target is not a plain name (``ast.Name``),
            e.g. a tuple target such as ``for a, b in ...``.
    """
    target = for_stmt.target
    if isinstance(target, ast.Name):
        return target.id
    raise TypeError(formatter(for_stmt, "For loop target must be a single variable."))
def _used_vars(expr: Optional[ast.expr]) -> Set[str]:
    """Collect the identifiers that appear in an expression.

    The expression tree is walked recursively and the ``id`` of every
    ``ast.Name`` reached is gathered. Calls are treated specially:

    * the callee expression itself is not visited;
    * positional arguments are walked recursively;
    * a keyword argument contributes only when its value is a bare name;
      any other keyword value is not walked.

    Args:
        expr: The expression to inspect, or ``None``.

    Returns:
        The set of collected names (empty when ``expr`` is ``None``).
    """
    if expr is None:
        return set()
    if isinstance(expr, ast.Name):
        return {expr.id}

    if isinstance(expr, ast.Call):
        # Skip the callee; keyword values count only when they are plain names.
        names = {kw.value.id for kw in expr.keywords if isinstance(kw.value, ast.Name)}
        operands = expr.args
    else:
        names = set()
        operands = ast.iter_child_nodes(expr)  # type: ignore[assignment]

    for operand in operands:
        names |= _used_vars(operand)
    return names
def _lhs_vars(lhs: ast.expr) -> Set[str]:
    """Return the names bound by the left-hand side of an assignment.

    Args:
        lhs: The assignment target: either a single name or a tuple whose
            elements are all single names.

    Returns:
        The set of identifiers being assigned.

    Raises:
        AssertionError: If the target, or any element of a tuple target, is
            not an ``ast.Name``.
    """
    targets = lhs.elts if isinstance(lhs, ast.Tuple) else [lhs]
    names = set()
    for target in targets:
        assert isinstance(target, ast.Name), "Only simple assignments supported."
        names.add(target.id)
    return names
def assigned_vars(
    stmt: ast.stmt | list[ast.stmt], formatter: sourceinfo.Formatter
) -> Set[str]:
    """Return every variable that may be assigned when ``stmt`` executes.

    Args:
        stmt: A single statement, or a list of statements.
        formatter: Callable invoked as ``formatter(node, message)`` to build
            error text.

    Returns:
        The names that some execution of the input may assign. A ``for``
        loop contributes its loop variable in addition to the names assigned
        in its body; an ``if`` contributes the names of both branches.

    Raises:
        ValueError: If a statement of an unsupported type is encountered.
    """

    def union_over(block: Sequence[ast.stmt]) -> Set[str]:
        names: set[Any] = set()
        for inner in block:
            names |= assigned_vars(inner, formatter)
        return names

    if isinstance(stmt, list):
        return union_over(stmt)
    if isinstance(stmt, ast.Assign):
        return _lhs_vars(stmt.targets[0])
    if isinstance(stmt, ast.AnnAssign):
        return _lhs_vars(stmt.target)
    if isinstance(stmt, (ast.Return, ast.Break)):
        return set()
    if isinstance(stmt, ast.If):
        return union_over(stmt.body) | union_over(stmt.orelse)
    if isinstance(stmt, ast.For):
        return union_over(stmt.body) | {_get_loop_var(stmt, formatter)}
    if isinstance(stmt, ast.While):
        return union_over(stmt.body)
    # Print calls and doc-strings bind nothing.
    if ast_utils.is_print_call(stmt) or ast_utils.is_doc_string(stmt):
        return set()
    raise ValueError(formatter(stmt, f"Unsupported statement type {type(stmt)!r}."))
def do_liveness_analysis(fun: ast.FunctionDef, formatter: sourceinfo.Formatter):
    """Annotate each statement of a function with its live-variable sets.

    Statements are processed back to front. Every visited statement ``s``
    receives two attributes: ``s.live_out`` (the set handed in from the
    statement that follows it) and ``s.live_in`` (the set computed for the
    point just before it). Loop bodies are re-visited until the set computed
    for the loop stops changing, so attributes on body statements hold the
    values from the final pass.

    Args:
        fun: The function definition to analyze; annotated in place.
        formatter: Callable invoked as ``formatter(node, message)`` to build
            error text.

    Raises:
        ValueError: If a statement of an unsupported type is encountered.
        TypeError: If a ``for`` loop target is not a single variable.
    """

    def annotate(stmt: ast.stmt, after: Set[str]) -> Set[str]:
        stmt.live_out = after  # type: ignore[attr-defined]
        before = transfer(stmt, after)
        stmt.live_in = before  # type: ignore[attr-defined]
        return before

    def walk_backwards(block: Sequence[ast.stmt], after: Set[str]) -> Set[str]:
        for inner in reversed(block):
            after = annotate(inner, after)
        return after

    def transfer(stmt: ast.stmt, after: Set[str]) -> Set[str]:
        if isinstance(stmt, (ast.Assign, ast.AnnAssign)):
            target = stmt.targets[0] if isinstance(stmt, ast.Assign) else stmt.target
            return after.difference(_lhs_vars(target)) | _used_vars(stmt.value)
        if isinstance(stmt, ast.Return):
            return _used_vars(stmt.value)
        if isinstance(stmt, ast.If):
            then_live = walk_backwards(stmt.body, after)
            else_live = walk_backwards(stmt.orelse, after)
            return then_live | else_live | _used_vars(stmt.test)
        if isinstance(stmt, ast.For):
            loop_var = _get_loop_var(stmt, formatter)
            head = after
            # Iterate the body until the set at the loop head is stable.
            while True:
                next_head = walk_backwards(stmt.body, head).difference({loop_var})
                if next_head == head:
                    return next_head
                head = next_head
        if isinstance(stmt, ast.While):
            cond_names = _used_vars(stmt.test)
            head = after | cond_names
            # Same fixed-point scheme; the condition's names stay live throughout.
            while True:
                next_head = walk_backwards(stmt.body, head) | cond_names
                if next_head == head:
                    return next_head
                head = next_head
        # Pass-through statements. For `break`, this is sufficient for the
        # current restricted usage, where a (conditional) break is allowed only
        # as the last statement of a loop. Break statements in the middle of the
        # loop, however, will require a generalization.
        if (
            isinstance(stmt, ast.Break)
            or ast_utils.is_doc_string(stmt)
            or isinstance(stmt, ast.FunctionDef)
            or ast_utils.is_print_call(stmt)
        ):
            return after
        raise ValueError(formatter(stmt, f"Unsupported statement type {type(stmt)!r}."))

    assert isinstance(fun, ast.FunctionDef)
    walk_backwards(fun.body, set())
def exposed_uses(stmts: Sequence[ast.stmt], formatter: sourceinfo.Formatter):
    """Return the set of variables that are used before being defined by given block.

    In essence, this identifies the "inputs" to a given code-block.
    For example, consider the following code-block:
    ::

        x = x + 10
        y = 20
        z = x + y
        x = 30

    The exposed_uses of this code-block is { x }. The value of z is not used within
    the block. Even though the value of y is used within the block, it is assigned
    a value before it is used. However, in contrast, the incoming value of x is used
    (in the first statement). Hence x is included in the exposed_uses.

    Args:
        stmts: The statements making up the code-block.
        formatter: Callable invoked as ``formatter(node, message)`` to build
            error text.

    Returns:
        The set of variable names used before being assigned within the block.

    Raises:
        ValueError: If a statement of an unsupported type is encountered.
        TypeError: If a ``for`` loop target is not a single variable.
    """

    def visitBlock(block: Sequence[ast.stmt], live_out: Set[str]) -> Set[str]:
        for stmt in reversed(block):
            live_out = visit(stmt, live_out)
        return live_out

    def visit(stmt: ast.stmt, live_out: Set[str]) -> Set[str]:
        if isinstance(stmt, ast.Assign):
            return live_out.difference(_lhs_vars(stmt.targets[0])) | _used_vars(stmt.value)
        if isinstance(stmt, ast.AnnAssign):
            return live_out.difference(_lhs_vars(stmt.target)) | _used_vars(stmt.value)
        if isinstance(stmt, ast.Return):
            return _used_vars(stmt.value)
        if isinstance(stmt, ast.If):
            # Both branches start from the very same `live_out` object, so no
            # handler below may modify the set it receives.
            live1 = visitBlock(stmt.body, live_out)
            live2 = visitBlock(stmt.orelse, live_out)
            return (live1 | live2) | _used_vars(stmt.test)
        if ast_utils.is_print_call(stmt):
            return live_out
        if ast_utils.is_doc_string(stmt):
            return live_out
        if isinstance(stmt, ast.For):
            # Analysis assumes loop may execute zero times. Results can be improved
            # for loops that execute at least once.
            loop_var_set = {_get_loop_var(stmt, formatter)}
            used_after_loop = live_out.difference(loop_var_set)
            used_inside_loop = visitBlock(stmt.body, set()).difference(loop_var_set)
            used_in_loop_header = _used_vars(stmt.iter)
            return used_inside_loop | used_in_loop_header | used_after_loop
        if isinstance(stmt, ast.While):
            # Analysis assumes loop may execute zero times. Results can be improved
            # for loops that execute at least once.
            used_inside_loop = visitBlock(stmt.body, set())
            used_in_loop_header = _used_vars(stmt.test)
            return used_inside_loop | used_in_loop_header | live_out
        if isinstance(stmt, ast.Break):
            # Currently, we assume that break statements are only allowed as the last
            # statement in a loop, as "if cond: break".
            return live_out
        if isinstance(stmt, ast.FunctionDef):
            # The definition binds `stmt.name`, and the nested body may read
            # variables of the enclosing scope. Build a new set instead of calling
            # `live_out.remove(...)`: the incoming set may be shared with the other
            # branch of an enclosing `if`, and mutating it in place would drop the
            # name from that branch's result as well.
            return live_out.difference({stmt.name}) | outer_scope_variables(stmt, formatter)
        raise ValueError(formatter(stmt, f"Unsupported statement type {type(stmt)!r}."))

    return visitBlock(stmts, set())
def outer_scope_variables(fun: ast.FunctionDef, formatter: sourceinfo.Formatter):
    """Return the set of outer-scope variables used in a nested function.

    These are the exposed uses of the function body, minus the function's own
    parameters. Every kind of parameter is treated as locally bound:
    positional-only, regular, keyword-only, ``*args`` and ``**kwargs``.

    Args:
        fun: The function-ast to analyze.
        formatter: The formatter object.

    Returns:
        A set of variable names (strings).
    """
    assert isinstance(fun, ast.FunctionDef)
    used_vars_ = exposed_uses(fun.body, formatter)
    signature = fun.args
    # `posonlyargs` is absent from the AST on Python versions before 3.8.
    named_params = (
        *getattr(signature, "posonlyargs", ()),
        *signature.args,
        *signature.kwonlyargs,
    )
    params = {param.arg for param in named_params}
    for variadic in (signature.vararg, signature.kwarg):
        if variadic is not None:
            params.add(variadic.arg)
    return used_vars_.difference(params)
|