from __future__ import annotations

import ast

# Readable phrasing for conditions, keyed by their unparsed source text.
# Built once at import time instead of on every ``describe_condition`` call.
_CONDITION_PHRASES: dict[str, str] = {
    "p.grad is None": "parameter has no gradient",
    "closure is not None": "closure exists",
    "weight_decay != 0": "weight decay is nonzero",
    "weight_decay != 0.0": "weight decay is nonzero",
    "len(state) == 0": "optimizer state is empty for this parameter",
    '"momentum_buffer" not in state': "momentum buffer does not exist",
    "'momentum_buffer' not in state": "momentum buffer does not exist",
    "grad.is_sparse": "gradient is sparse",
    "norm > 0": "matrix norm is positive",
    "update.ndim == 2": "update is a 2D matrix",
}

# Readable phrasing for expressions, keyed by their unparsed source text.
# Both quote styles are listed so the lookup does not depend on which quote
# character the unparsed text uses for string subscripts.
_EXPRESSION_PHRASES: dict[str, str] = {
    "self.param_groups": "parameter groups",
    'group["params"]': "parameters in group",
    "group['params']": "parameters in group",
    'group["lr"]': "learning rate",
    "group['lr']": "learning rate",
    'group["momentum"]': "momentum coefficient",
    "group['momentum']": "momentum coefficient",
    'group["weight_decay"]': "weight decay coefficient",
    "group['weight_decay']": "weight decay coefficient",
    'group["betas"]': "Adam beta coefficients",
    "group['betas']": "Adam beta coefficients",
    "p.grad": "parameter gradient",
    "self.state[p]": "persistent optimizer state for parameter",
    "torch.zeros_like(p)": "zero tensor with parameter shape",
    "torch.clone(grad).detach()": "detached copy of gradient",
    "grad.sign()": "sign of gradient",
    "state['exp_avg']": "first moment estimate",
    'state["exp_avg"]': "first moment estimate",
    "state['exp_avg_sq']": "second moment estimate",
    'state["exp_avg_sq"]': "second moment estimate",
    "state['momentum_buffer']": "momentum buffer",
    'state["momentum_buffer"]': "momentum buffer",
}


def unparse(node: ast.AST | None) -> str:
    """Return the source text for *node*, or ``""`` when unavailable.

    This is a deliberate best-effort helper: ``None`` and any node that
    ``ast.unparse`` cannot render both yield an empty string instead of
    raising, so callers can always embed the result in a description.
    """
    if node is None:
        return ""
    try:
        return ast.unparse(node)
    except Exception:
        return ""


def _attr_name(node: ast.AST) -> str:
    """Return the attribute name of *node*, or ``""`` if it is not an attribute."""
    return node.attr if isinstance(node, ast.Attribute) else ""


def _call_attr(node: ast.Call) -> ast.Attribute | None:
    """Return the ``obj.method`` attribute being called, or ``None`` for non-method calls."""
    return node.func if isinstance(node.func, ast.Attribute) else None


def _base_target_from_inplace_chain(node: ast.Call) -> str:
    """Return the real mutated tensor for chained in-place calls.

    Example:
        exp_avg.mul_(beta1).add_(grad, alpha=1-beta1)

    The final ``add_`` call has ``node.func.value`` equal to the call
    ``exp_avg.mul_(beta1)``. The tensor being mutated is still ``exp_avg``.

    Returns the unparsed receiver at the root of the chain, or the
    placeholder ``"tensor"`` when *node* is not a method call at all.
    """
    attr = _call_attr(node)
    if attr is None:
        return "tensor"

    # Peel off one method call per iteration until reaching the receiver.
    value = attr.value
    while isinstance(value, ast.Call) and isinstance(value.func, ast.Attribute):
        value = value.func.value
    return unparse(value)


def _chained_calls(node: ast.Call) -> list[ast.Call]:
    """Return calls in execution order for an in-place call chain.

    The walk starts at the outermost call and follows ``func.value`` inward.
    A call whose ``func`` is not an attribute (e.g. ``f(x)``) ends the walk
    and is included, so it can only ever appear as the first element.
    """
    calls: list[ast.Call] = []
    current: ast.AST = node
    while isinstance(current, ast.Call):
        calls.append(current)
        if not isinstance(current.func, ast.Attribute):
            break
        current = current.func.value
    # Collected outermost-first; execution order is innermost-first.
    calls.reverse()
    return calls


def _keyword_value(node: ast.Call, name: str, default: str = "") -> str:
    """Return the unparsed value of keyword argument *name*, or *default* if absent."""
    for kw in node.keywords:
        if kw.arg == name:
            return unparse(kw.value)
    return default


def describe_condition(node: ast.AST) -> str:
    """Return a readable phrase for a condition node.

    Known conditions are mapped to plain-language phrases; anything else is
    returned as its unparsed source text.
    """
    text = unparse(node)
    return _CONDITION_PHRASES.get(text, text)


def describe_expression(node: ast.AST) -> str:
    """Return a readable phrase for an expression node.

    Known expressions are mapped to plain-language phrases; anything else is
    returned as its unparsed source text.
    """
    text = unparse(node)
    return _EXPRESSION_PHRASES.get(text, text)


def _describe_chained_inplace(node: ast.Call) -> str | None:
    """Describe a recognised chain of method calls as a single update.

    Recognised chains are ``mul_().add_()``, ``mul_().addcmul_()`` and
    ``sqrt().div_().add_()``.

    Returns:
        The description, or ``None`` when *node* is not one of those chains.
    """
    # Keep only method calls so the names and the calls stay aligned.  A chain
    # rooted at a plain call such as ``f(x).mul_(a).add_(b)`` carries a leading
    # non-method call; if it stayed in the list the unpacking below would fail.
    method_calls = [
        call
        for call in _chained_calls(node)
        if isinstance(call.func, ast.Attribute)
    ]
    if len(method_calls) < 2:
        return None

    target = _base_target_from_inplace_chain(node)
    method_names = [_attr_name(call.func) for call in method_calls]

    # exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
    if method_names == ["mul_", "add_"]:
        mul_call, add_call = method_calls
        scale = unparse(mul_call.args[0]) if mul_call.args else "coefficient"
        arg = describe_expression(add_call.args[0]) if add_call.args else "tensor"
        alpha = _keyword_value(add_call, "alpha", "1")
        return f"UPDATE {target}: {target} ← {scale} · {target} + {alpha} · {arg}"

    # exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
    if method_names == ["mul_", "addcmul_"]:
        mul_call, addcmul_call = method_calls
        scale = unparse(mul_call.args[0]) if mul_call.args else "coefficient"
        left = (
            unparse(addcmul_call.args[0])
            if len(addcmul_call.args) >= 1
            else "tensor_a"
        )
        right = (
            unparse(addcmul_call.args[1])
            if len(addcmul_call.args) >= 2
            else "tensor_b"
        )
        value = _keyword_value(addcmul_call, "value", "1")
        # Multiplying a tensor by itself reads better as a square.
        product = f"{left}²" if left == right else f"{left} ⊙ {right}"
        return f"UPDATE {target}: {target} ← {scale} · {target} + {value} · {product}"

    # exp_avg_sq.sqrt().div_(...).add_(eps)
    if method_names == ["sqrt", "div_", "add_"]:
        sqrt_call, div_call, add_call = method_calls
        base = _base_target_from_inplace_chain(sqrt_call)
        divisor = unparse(div_call.args[0]) if div_call.args else "divisor"
        eps = unparse(add_call.args[0]) if add_call.args else "epsilon"
        return f"COMPUTE DENOMINATOR: sqrt({base}) / {divisor} + {eps}"

    return None


def describe_call(node: ast.Call) -> str:
    """Return a pseudocode description of a call node.

    Recognised call chains are tried first, then single in-place tensor
    methods (``add_``, ``mul_``, ``addcmul_``, ``addcdiv_``), then a few
    well-known calls matched by name. Anything else is described as
    ``CALL <source text>``.
    """
    chained = _describe_chained_inplace(node)
    if chained is not None:
        return chained

    text = unparse(node)
    func_text = unparse(node.func)

    if func_text.endswith(".add_") and len(node.args) >= 1:
        target = _base_target_from_inplace_chain(node)
        arg = describe_expression(node.args[0])
        alpha = _keyword_value(node, "alpha", "")
        # A leading minus on alpha is rendered as a subtraction.
        if alpha and alpha.startswith("-"):
            return f"UPDATE {target}: {target} ← {target} - {alpha[1:]} · {arg}"
        if alpha:
            return f"UPDATE {target}: {target} ← {target} + {alpha} · {arg}"
        return f"ADD {arg} TO {target} IN PLACE"

    if func_text.endswith(".mul_") and len(node.args) == 1:
        target = _base_target_from_inplace_chain(node)
        return (
            f"SCALE {target}: {target} ← {target} · {describe_expression(node.args[0])}"
        )

    if func_text.endswith(".addcmul_") and len(node.args) >= 2:
        target = _base_target_from_inplace_chain(node)
        value = _keyword_value(node, "value", "1")
        left = unparse(node.args[0])
        right = unparse(node.args[1])
        product = f"{left}²" if left == right else f"{left} ⊙ {right}"
        return f"UPDATE {target}: {target} ← {target} + {value} · {product}"

    if func_text.endswith(".addcdiv_") and len(node.args) >= 2:
        target = _base_target_from_inplace_chain(node)
        value = _keyword_value(node, "value", "1")
        negative = value.startswith("-")
        sign = "-" if negative else "+"
        coefficient = value[1:] if negative else value
        return f"UPDATE {target}: {target} ← {target} {sign} {coefficient} · {unparse(node.args[0])} / {unparse(node.args[1])}"

    if func_text == "super().__init__":
        return "INITIALIZE BASE OPTIMIZER WITH PARAMETERS AND DEFAULT HYPERPARAMETERS"

    if func_text == "torch.clone":
        # The tensor may be passed positionally or as ``input=``.  With
        # neither present, fall through to the generic description rather
        # than indexing an empty argument list.
        if node.args:
            return f"COPY {unparse(node.args[0])}"
        source = _keyword_value(node, "input")
        if source:
            return f"COPY {source}"

    if func_text.endswith("zero_grad"):
        return "CLEAR PARAMETER GRADIENTS"
    if func_text.endswith("backward"):
        return "COMPUTE GRADIENTS BY BACKPROPAGATION"
    if func_text.endswith("step"):
        return "APPLY OPTIMIZER STEP"

    return f"CALL {text}"