OptimizationArena / arena /pseudocode /torch_patterns.py
mmkuznecov's picture
first version
6551a95
from __future__ import annotations
import ast
def unparse(node: ast.AST | None) -> str:
    """Render an AST node back to source text.

    Returns an empty string when *node* is ``None`` and the placeholder
    ``"<expression>"`` when ``ast.unparse`` raises for any reason, so callers
    always receive a printable string.
    """
    if node is not None:
        try:
            rendered = ast.unparse(node)
        except Exception:  # best-effort: rendering must never abort the caller
            rendered = "<expression>"
        return rendered
    return ""
def _attr_name(node: ast.AST) -> str:
    """Return the attribute name of an ``ast.Attribute`` node, else ``""``."""
    if not isinstance(node, ast.Attribute):
        return ""
    return node.attr
def _call_attr(node: ast.Call) -> ast.Attribute | None:
    """Return the callee of *node* when it is an attribute access, else ``None``."""
    callee = node.func
    if isinstance(callee, ast.Attribute):
        return callee
    return None
def _base_target_from_inplace_chain(node: ast.Call) -> str:
    """Return the real mutated tensor for chained in-place calls.

    Example:
        exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)

    The final ``add_`` call has ``node.func.value`` equal to the call
    ``exp_avg.mul_(beta1)``. The tensor being mutated is still ``exp_avg``.

    Only in-place method calls (names with a single trailing underscore, such
    as ``mul_``) are unwound, because those return the tensor they were called
    on. A receiver produced by any other call is a different object, so it is
    reported as-is: ``x.sqrt().add_(eps)`` yields ``x.sqrt()`` and
    ``torch.clone(grad).mul_(2)`` yields ``torch.clone(grad)``.

    Args:
        node: The outermost call of the chain.

    Returns:
        Source text of the mutated receiver, or ``"tensor"`` when *node* is
        not a method-style call.
    """
    attr = _call_attr(node)
    if attr is None:
        return "tensor"

    value = attr.value
    while isinstance(value, ast.Call) and isinstance(value.func, ast.Attribute):
        method = value.func.attr
        # Dunder names also end in "_" but are not in-place tensor methods.
        if not method.endswith("_") or method.endswith("__"):
            break
        value = value.func.value
    return unparse(value)
def _chained_calls(node: ast.Call) -> list[ast.Call]:
    """Collect the calls of a method chain, first-executed call first.

    Starting at *node*, the receiver of each attribute-style call is followed
    for as long as it is itself a call. A call whose callee is not an
    attribute access ends the walk but is still part of the result.
    """
    outermost_first: list[ast.Call] = []
    cursor: ast.AST = node
    while isinstance(cursor, ast.Call):
        outermost_first.append(cursor)
        callee = cursor.func
        if not isinstance(callee, ast.Attribute):
            break
        cursor = callee.value
    return outermost_first[::-1]
def _keyword_value(node: ast.Call, name: str, default: str = "") -> str:
    """Return the source text of keyword argument *name* in call *node*.

    The first keyword whose name matches wins; *default* is returned when the
    call has no such keyword.
    """
    matches = (
        unparse(keyword.value)
        for keyword in node.keywords
        if keyword.arg == name
    )
    return next(matches, default)
def describe_condition(node: ast.AST) -> str:
    """Phrase a condition expression in plain English.

    The node is rendered to source text. A fixed set of conditions has an
    English phrase; any other condition is returned as its source text.
    """
    known_phrases = {
        "p.grad is None": "parameter has no gradient",
        "closure is not None": "closure exists",
        "weight_decay != 0": "weight decay is nonzero",
        "weight_decay != 0.0": "weight decay is nonzero",
        "len(state) == 0": "optimizer state is empty for this parameter",
        '"momentum_buffer" not in state': "momentum buffer does not exist",
        "'momentum_buffer' not in state": "momentum buffer does not exist",
        "grad.is_sparse": "gradient is sparse",
        "norm > 0": "matrix norm is positive",
        "update.ndim == 2": "update is a 2D matrix",
    }
    source = unparse(node)
    try:
        return known_phrases[source]
    except KeyError:
        return source
def describe_expression(node: ast.AST) -> str:
    """Phrase an expression in plain English.

    The node is rendered to source text. A fixed set of optimizer expressions
    has an English label; any other expression is returned as its source text.
    """
    labels = {
        "self.param_groups": "parameter groups",
        "p.grad": "parameter gradient",
        "self.state[p]": "persistent optimizer state for parameter",
        "torch.zeros_like(p)": "zero tensor with parameter shape",
        "torch.clone(grad).detach()": "detached copy of gradient",
        "grad.sign()": "sign of gradient",
    }
    # String-keyed subscripts are registered under both quoting styles.
    subscripts = (
        ("group", "params", "parameters in group"),
        ("group", "lr", "learning rate"),
        ("group", "momentum", "momentum coefficient"),
        ("group", "weight_decay", "weight decay coefficient"),
        ("group", "betas", "Adam beta coefficients"),
        ("state", "exp_avg", "first moment estimate"),
        ("state", "exp_avg_sq", "second moment estimate"),
        ("state", "momentum_buffer", "momentum buffer"),
    )
    for container, key, label in subscripts:
        labels[f'{container}["{key}"]'] = label
        labels[f"{container}['{key}']"] = label

    source = unparse(node)
    return labels.get(source, source)
def _describe_chained_inplace(node: ast.Call) -> str | None:
    """Describe a recognised chain of tensor method calls as one pseudocode line.

    Recognised chains, by method names in execution order:

    * ``mul_``, ``add_``:
      ``exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)``
    * ``mul_``, ``addcmul_``:
      ``exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)``
    * ``sqrt``, ``div_``, ``add_``:
      ``exp_avg_sq.sqrt().div_(...).add_(eps)``

    Args:
        node: The outermost (last executed) call of the chain.

    Returns:
        The pseudocode line, or ``None`` when *node* is not a chain of at
        least two calls or the chain matches none of the patterns above.
    """
    calls = _chained_calls(node)
    if len(calls) < 2:
        return None

    # A chain may start with a plain function call, e.g.
    # ``zeros_like(p).mul_(a).add_(b)``. Match and unpack on the method-style
    # calls only, so the name list and the call list always have equal length.
    method_calls = [call for call in calls if isinstance(call.func, ast.Attribute)]
    method_names = [_attr_name(call.func) for call in method_calls]
    target = _base_target_from_inplace_chain(node)

    # exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
    if method_names == ["mul_", "add_"]:
        mul_call, add_call = method_calls
        scale = unparse(mul_call.args[0]) if mul_call.args else "coefficient"
        arg = describe_expression(add_call.args[0]) if add_call.args else "tensor"
        alpha = _keyword_value(add_call, "alpha", "1")
        return f"UPDATE {target}: {target} ← {scale} · {target} + {alpha} · {arg}"

    # exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
    if method_names == ["mul_", "addcmul_"]:
        mul_call, addcmul_call = method_calls
        scale = unparse(mul_call.args[0]) if mul_call.args else "coefficient"
        operands = addcmul_call.args
        left = unparse(operands[0]) if len(operands) >= 1 else "tensor_a"
        right = unparse(operands[1]) if len(operands) >= 2 else "tensor_b"
        value = _keyword_value(addcmul_call, "value", "1")
        # A tensor multiplied by itself reads better as a square.
        product = f"{left}²" if left == right else f"{left} ⊙ {right}"
        return f"UPDATE {target}: {target} ← {scale} · {target} + {value} · {product}"

    # exp_avg_sq.sqrt().div_(...).add_(eps)
    if method_names == ["sqrt", "div_", "add_"]:
        sqrt_call, div_call, add_call = method_calls
        base = _base_target_from_inplace_chain(sqrt_call)
        divisor = unparse(div_call.args[0]) if div_call.args else "divisor"
        eps = unparse(add_call.args[0]) if add_call.args else "epsilon"
        return f"COMPUTE DENOMINATOR: sqrt({base}) / {divisor} + {eps}"

    return None
def describe_call(node: ast.Call) -> str:
    """Translate a call expression into one line of optimizer pseudocode.

    Recognised method chains are delegated to ``_describe_chained_inplace``.
    Otherwise the source text of the callee selects the description:

    * ``*.add_``, ``*.mul_``, ``*.addcmul_``, ``*.addcdiv_`` become an
      ``UPDATE``/``SCALE``/``ADD`` line for the mutated target.
    * ``super().__init__`` and ``torch.clone`` have fixed phrasings.
    * callees ending in ``zero_grad``, ``backward`` or ``step`` have fixed
      phrasings.
    * anything else becomes ``CALL <source text>``.

    Args:
        node: The call expression to describe.

    Returns:
        A single pseudocode line.
    """
    chained = _describe_chained_inplace(node)
    if chained is not None:
        return chained

    func_text = unparse(node.func)

    if func_text.endswith(".add_") and len(node.args) >= 1:
        target = _base_target_from_inplace_chain(node)
        arg = describe_expression(node.args[0])
        alpha = _keyword_value(node, "alpha", "")
        if alpha and alpha.startswith("-"):
            # A negative coefficient reads better as a subtraction.
            return f"UPDATE {target}: {target} ← {target} - {alpha[1:]} · {arg}"
        if alpha:
            return f"UPDATE {target}: {target} ← {target} + {alpha} · {arg}"
        return f"ADD {arg} TO {target} IN PLACE"

    if func_text.endswith(".mul_") and len(node.args) == 1:
        target = _base_target_from_inplace_chain(node)
        factor = describe_expression(node.args[0])
        return f"SCALE {target}: {target} ← {target} · {factor}"

    if func_text.endswith(".addcmul_") and len(node.args) >= 2:
        target = _base_target_from_inplace_chain(node)
        value = _keyword_value(node, "value", "1")
        left = unparse(node.args[0])
        right = unparse(node.args[1])
        # A tensor multiplied by itself reads better as a square.
        product = f"{left}²" if left == right else f"{left} ⊙ {right}"
        return f"UPDATE {target}: {target} ← {target} + {value} · {product}"

    if func_text.endswith(".addcdiv_") and len(node.args) >= 2:
        target = _base_target_from_inplace_chain(node)
        value = _keyword_value(node, "value", "1")
        negative = value.startswith("-")
        sign = "-" if negative else "+"
        coefficient = value[1:] if negative else value
        numerator = unparse(node.args[0])
        denominator = unparse(node.args[1])
        return (
            f"UPDATE {target}: {target} ← {target} {sign} {coefficient}"
            f" · {numerator} / {denominator}"
        )

    if func_text == "super().__init__":
        return "INITIALIZE BASE OPTIMIZER WITH PARAMETERS AND DEFAULT HYPERPARAMETERS"

    if func_text == "torch.clone":
        # The tensor may be passed by keyword (``torch.clone(input=grad)``),
        # in which case there is no positional argument to index.
        if node.args:
            source = unparse(node.args[0])
        else:
            source = _keyword_value(node, "input", "tensor")
        return f"COPY {source}"

    if func_text.endswith("zero_grad"):
        return "CLEAR PARAMETER GRADIENTS"
    if func_text.endswith("backward"):
        return "COMPUTE GRADIENTS BY BACKPROPAGATION"
    if func_text.endswith("step"):
        return "APPLY OPTIMIZER STEP"

    return f"CALL {unparse(node)}"