Spaces:
Sleeping
Sleeping
| from __future__ import annotations | |
| import ast | |
def unparse(node: ast.AST | None) -> str:
    """Render *node* back to source text without ever raising.

    Returns an empty string when *node* is ``None`` and the placeholder
    ``"<expression>"`` when ``ast.unparse`` raises for any reason.
    """
    if node is not None:
        try:
            rendered = ast.unparse(node)
        except Exception:
            # Best-effort rendering: callers embed the result in descriptive
            # text, so a placeholder is preferred over propagating the error.
            rendered = "<expression>"
        return rendered
    return ""
| def _attr_name(node: ast.AST) -> str: | |
| return node.attr if isinstance(node, ast.Attribute) else "" | |
| def _call_attr(node: ast.Call) -> ast.Attribute | None: | |
| return node.func if isinstance(node.func, ast.Attribute) else None | |
def _base_target_from_inplace_chain(node: ast.Call) -> str:
    """Return the source text of the tensor mutated by a chained in-place call.

    Example:
        exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)

    The final ``add_`` call has ``node.func.value`` equal to the call
    ``exp_avg.mul_(beta1)``. Because ``mul_`` is itself an in-place method,
    the tensor being mutated is still ``exp_avg``.

    Only receivers that are in-place method calls (name ending in a single
    trailing underscore, the PyTorch naming convention) are unwrapped. A
    receiver produced by any other call is a distinct value, so it is
    returned as written:

        torch.zeros_like(p).add_(g)  ->  "torch.zeros_like(p)"
        x.clone().mul_(2)            ->  "x.clone()"

    Returns the literal ``"tensor"`` when *node* is not a method call.
    """
    attr = _call_attr(node)
    if attr is None:
        return "tensor"
    value = attr.value
    while isinstance(value, ast.Call) and isinstance(value.func, ast.Attribute):
        method = value.func.attr
        # Stop at the first non-in-place call: unwrapping it would name the
        # wrong object (e.g. the ``torch`` module for ``torch.zeros_like(p)``).
        # Dunder names also end in an underscore but are not in-place methods.
        if not method.endswith("_") or method.endswith("__"):
            break
        value = value.func.value
    return unparse(value)
| def _chained_calls(node: ast.Call) -> list[ast.Call]: | |
| """Return calls in execution order for an in-place call chain.""" | |
| calls: list[ast.Call] = [] | |
| current: ast.AST = node | |
| while isinstance(current, ast.Call): | |
| calls.append(current) | |
| if isinstance(current.func, ast.Attribute): | |
| current = current.func.value | |
| else: | |
| break | |
| return list(reversed(calls)) | |
def _keyword_value(node: ast.Call, name: str, default: str = "") -> str:
    """Return the source text of keyword argument *name* in the call *node*.

    The first keyword whose name equals *name* wins; *default* is returned
    when the call has no such keyword.
    """
    match = next((kw for kw in node.keywords if kw.arg == name), None)
    if match is None:
        return default
    return unparse(match.value)
def describe_condition(node: ast.AST) -> str:
    """Describe a condition node in plain English when it is a known one.

    The node is rendered to source text and compared against a fixed set of
    recognised spellings. Unrecognised conditions are returned as their
    source text unchanged.
    """
    source = unparse(node)
    # Grouped by phrase so that alternative spellings of one condition sit together.
    spellings_by_phrase = {
        "parameter has no gradient": ("p.grad is None",),
        "closure exists": ("closure is not None",),
        "weight decay is nonzero": ("weight_decay != 0", "weight_decay != 0.0"),
        "optimizer state is empty for this parameter": ("len(state) == 0",),
        "momentum buffer does not exist": (
            '"momentum_buffer" not in state',
            "'momentum_buffer' not in state",
        ),
        "gradient is sparse": ("grad.is_sparse",),
        "matrix norm is positive": ("norm > 0",),
        "update is a 2D matrix": ("update.ndim == 2",),
    }
    for phrase, spellings in spellings_by_phrase.items():
        if source in spellings:
            return phrase
    return source
def describe_expression(node: ast.AST) -> str:
    """Describe an expression node in plain English when it is a known one.

    The node is rendered to source text and compared against a fixed set of
    recognised spellings. Unrecognised expressions are returned as their
    source text unchanged.
    """
    source = unparse(node)
    # Grouped by phrase so that both quote styles of one subscript sit together.
    spellings_by_phrase = {
        "parameter groups": ("self.param_groups",),
        "parameters in group": ('group["params"]', "group['params']"),
        "learning rate": ('group["lr"]', "group['lr']"),
        "momentum coefficient": ('group["momentum"]', "group['momentum']"),
        "weight decay coefficient": (
            'group["weight_decay"]',
            "group['weight_decay']",
        ),
        "Adam beta coefficients": ('group["betas"]', "group['betas']"),
        "parameter gradient": ("p.grad",),
        "persistent optimizer state for parameter": ("self.state[p]",),
        "zero tensor with parameter shape": ("torch.zeros_like(p)",),
        "detached copy of gradient": ("torch.clone(grad).detach()",),
        "sign of gradient": ("grad.sign()",),
        "first moment estimate": ("state['exp_avg']", 'state["exp_avg"]'),
        "second moment estimate": ("state['exp_avg_sq']", 'state["exp_avg_sq"]'),
        "momentum buffer": ("state['momentum_buffer']", 'state["momentum_buffer"]'),
    }
    for phrase, spellings in spellings_by_phrase.items():
        if source in spellings:
            return phrase
    return source
def _describe_chained_inplace(node: ast.Call) -> str | None:
    """Describe a recognised chain of tensor method calls as one update formula.

    Three method-name sequences are recognised, in execution order:

    * ``mul_`` then ``add_``       e.g. ``exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)``
    * ``mul_`` then ``addcmul_``   e.g. ``exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)``
    * ``sqrt``, ``div_``, ``add_`` e.g. ``exp_avg_sq.sqrt().div_(...).add_(eps)``

    Returns the description string, or ``None`` when *node* is not one of
    these chains, so that the caller can fall back to single-call handling.
    """
    # Keep only method calls. The innermost link of a chain may be a plain
    # function call (``f(x).mul_(b).add_(c)``); leaving it in ``calls`` while
    # matching on method names alone made the tuple unpacking below raise
    # ValueError, because the two lists had different lengths.
    calls = [
        call for call in _chained_calls(node) if isinstance(call.func, ast.Attribute)
    ]
    if len(calls) < 2:
        return None
    target = _base_target_from_inplace_chain(node)
    method_names = [_attr_name(call.func) for call in calls]
    # exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
    if method_names == ["mul_", "add_"]:
        mul_call, add_call = calls
        scale = unparse(mul_call.args[0]) if mul_call.args else "coefficient"
        arg = describe_expression(add_call.args[0]) if add_call.args else "tensor"
        # NOTE(review): ``alpha`` is inserted verbatim, so a compound value
        # such as ``1 - beta1`` is rendered without parentheses.
        alpha = _keyword_value(add_call, "alpha", "1")
        return f"UPDATE {target}: {target} ← {scale} · {target} + {alpha} · {arg}"
    # exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
    if method_names == ["mul_", "addcmul_"]:
        mul_call, addcmul_call = calls
        scale = unparse(mul_call.args[0]) if mul_call.args else "coefficient"
        left = (
            unparse(addcmul_call.args[0]) if len(addcmul_call.args) >= 1 else "tensor_a"
        )
        right = (
            unparse(addcmul_call.args[1]) if len(addcmul_call.args) >= 2 else "tensor_b"
        )
        value = _keyword_value(addcmul_call, "value", "1")
        # Identical operands are shown as a square instead of an element-wise product.
        if left == right:
            product = f"{left}²"
        else:
            product = f"{left} ⊙ {right}"
        return f"UPDATE {target}: {target} ← {scale} · {target} + {value} · {product}"
    # exp_avg_sq.sqrt().div_(...).add_(eps)
    if method_names == ["sqrt", "div_", "add_"]:
        sqrt_call, div_call, add_call = calls
        # The operand of the square root is the receiver of the ``sqrt`` call itself.
        base = _base_target_from_inplace_chain(sqrt_call)
        divisor = unparse(div_call.args[0]) if div_call.args else "divisor"
        eps = unparse(add_call.args[0]) if add_call.args else "epsilon"
        return f"COMPUTE DENOMINATOR: sqrt({base}) / {divisor} + {eps}"
    return None
def describe_call(node: ast.Call) -> str:
    """Return a pseudocode-style description of the call *node*.

    Resolution order:

    1. A recognised chain of tensor methods (see ``_describe_chained_inplace``).
    2. A single in-place tensor method: ``add_``, ``mul_``, ``addcmul_``,
       ``addcdiv_``.
    3. A few well-known calls: ``super().__init__``, ``torch.clone``, and any
       callee whose text ends with ``zero_grad``, ``backward`` or ``step``.
    4. Otherwise the generic ``CALL <source text>``.
    """
    chained = _describe_chained_inplace(node)
    if chained is not None:
        return chained
    func_text = unparse(node.func)
    if func_text.endswith(".add_") and len(node.args) >= 1:
        target = _base_target_from_inplace_chain(node)
        arg = describe_expression(node.args[0])
        alpha = _keyword_value(node, "alpha", "")
        # A leading minus on ``alpha`` is folded into the operator so the
        # formula reads as a subtraction.
        if alpha and alpha.startswith("-"):
            return f"UPDATE {target}: {target} ← {target} - {alpha[1:]} · {arg}"
        if alpha:
            return f"UPDATE {target}: {target} ← {target} + {alpha} · {arg}"
        return f"ADD {arg} TO {target} IN PLACE"
    if func_text.endswith(".mul_") and len(node.args) == 1:
        target = _base_target_from_inplace_chain(node)
        return (
            f"SCALE {target}: {target} ← {target} · {describe_expression(node.args[0])}"
        )
    if func_text.endswith(".addcmul_") and len(node.args) >= 2:
        target = _base_target_from_inplace_chain(node)
        value = _keyword_value(node, "value", "1")
        left = unparse(node.args[0])
        right = unparse(node.args[1])
        # Identical operands are shown as a square instead of an element-wise product.
        product = f"{left}²" if left == right else f"{left} ⊙ {right}"
        return f"UPDATE {target}: {target} ← {target} + {value} · {product}"
    if func_text.endswith(".addcdiv_") and len(node.args) >= 2:
        target = _base_target_from_inplace_chain(node)
        value = _keyword_value(node, "value", "1")
        sign = "-" if value.startswith("-") else "+"
        coefficient = value[1:] if value.startswith("-") else value
        return f"UPDATE {target}: {target} ← {target} {sign} {coefficient} · {unparse(node.args[0])} / {unparse(node.args[1])}"
    if func_text == "super().__init__":
        return "INITIALIZE BASE OPTIMIZER WITH PARAMETERS AND DEFAULT HYPERPARAMETERS"
    if func_text == "torch.clone":
        # The tensor may be passed positionally or as the ``input`` keyword.
        # Indexing ``node.args[0]`` unconditionally raised IndexError for
        # ``torch.clone(input=grad)`` and for a call with no arguments.
        if node.args:
            return f"COPY {unparse(node.args[0])}"
        source = _keyword_value(node, "input", "")
        if source:
            return f"COPY {source}"
        # No identifiable source tensor: fall through to the generic description.
    if func_text.endswith("zero_grad"):
        return "CLEAR PARAMETER GRADIENTS"
    if func_text.endswith("backward"):
        return "COMPUTE GRADIENTS BY BACKPROPAGATION"
    if func_text.endswith("step"):
        return "APPLY OPTIMIZER STEP"
    return f"CALL {unparse(node)}"