Spaces:
Build error
Build error
File size: 7,522 Bytes
"""Tools 2-4/9: profile_python_hotspots, analyze_complexity, check_memory_access.
Three static-analysis tools the agent uses to *understand the input code* before
writing C++. All run on the AST — no Python execution required for these tools
(the verifier and benchmarker do the actual execution, sandboxed).
"""
from __future__ import annotations
import ast
import re
from typing import Any
# ----------------- Tool 2: profile_python_hotspots ----------------
def profile_python_hotspots_tool(tool_args: dict[str, Any], state) -> dict[str, Any]:
    """Return the top hot lines of the Python code (static cost estimate).

    The source is parsed with ``ast`` and never executed. Every line
    accumulates a cost from the nodes that start on it, each weighted by
    ``2 ** loop_depth`` (the number of enclosing ``for``/``while`` loops):

    - binary operations and augmented assignments (``+=`` etc.) weigh 1
    - calls weigh 2 (any call, not only ``np.*``)

    A ``for`` loop's iterable and any loop's ``else`` clause run once per
    entry to the loop, so they are charged at the *enclosing* depth; a
    ``while`` condition is re-evaluated every iteration and is charged at the
    inner depth.

    NOTE: only the ``"code"`` key of ``tool_args`` is read; a dynamic
    (cProfile-based) mode is not implemented here.

    Args:
        tool_args: Tool arguments. ``tool_args["code"]`` is analysed when it
            is present and non-empty.
        state: Fallback source provider; ``state.python_code`` is read only
            when ``tool_args`` supplies no code.

    Returns:
        A dict with ``hotspots`` (up to five entries with ``line_number``,
        ``estimated_cost`` and stripped ``source``, most expensive first),
        ``total_estimated_cost``, ``method`` and ``hint``. If the code cannot
        be parsed, a dict with ``error`` and an empty ``hotspots`` list.
    """
    source = tool_args.get("code") or state.python_code
    try:
        tree = ast.parse(source)
    except (SyntaxError, ValueError) as e:
        # ValueError: source containing NUL bytes on interpreters that do not
        # report it as a SyntaxError.
        return {"error": f"Python parse error: {e}", "hotspots": []}

    line_costs: dict[int, int] = {}

    class HotspotVisitor(ast.NodeVisitor):
        """Accumulate per-line costs while tracking loop nesting depth."""

        def __init__(self):
            self.loop_depth = 0

        def _charge(self, node, weight):
            # 2^depth: exponential weight per level of nesting.
            cost = weight << self.loop_depth
            line_costs[node.lineno] = line_costs.get(node.lineno, 0) + cost

        def _visit_all(self, nodes):
            for child in nodes:
                self.visit(child)

        def visit_For(self, node):
            # The iterable is evaluated once, before the first iteration.
            self.visit(node.iter)
            self.loop_depth += 1
            self.visit(node.target)
            self._visit_all(node.body)
            self.loop_depth -= 1
            # The else-clause runs at most once, after the loop finishes.
            self._visit_all(node.orelse)

        visit_AsyncFor = visit_For

        def visit_While(self, node):
            self.loop_depth += 1
            # The condition is re-evaluated on every iteration.
            self.visit(node.test)
            self._visit_all(node.body)
            self.loop_depth -= 1
            self._visit_all(node.orelse)

        def visit_BinOp(self, node):
            self._charge(node, 1)
            self.generic_visit(node)

        def visit_AugAssign(self, node):
            # `x += y` is an arithmetic operation just like `x = x + y`.
            self._charge(node, 1)
            self.generic_visit(node)

        def visit_Call(self, node):
            # Calls are weighted twice as heavily as a plain operation.
            self._charge(node, 2)
            self.generic_visit(node)

    HotspotVisitor().visit(tree)

    code_lines = source.splitlines()
    # Stable sort: lines with equal cost keep the order they were first seen.
    ranked = sorted(line_costs.items(), key=lambda item: -item[1])
    hotspots: list[dict[str, Any]] = [
        {
            "line_number": lineno,
            "estimated_cost": cost,
            "source": code_lines[lineno - 1].strip(),
        }
        for lineno, cost in ranked[:5]
        if 0 < lineno <= len(code_lines)
    ]

    return {
        "hotspots": hotspots,
        "total_estimated_cost": sum(line_costs.values()),
        "method": "static_ast_analysis",
        "hint": "Lines deep in loops dominate; vectorize or parallelize them first.",
    }
# ----------------- Tool 3: analyze_complexity ----------------
def analyze_complexity_tool(tool_args: dict[str, Any], state) -> dict[str, Any]:
    """Return a Big-O class + max loop nesting depth via AST.

    A loop nesting depth of k is reported as O(n^k) (O(1) for no loops). The
    estimate comes from loop nesting only; recursion is reported separately
    as a boolean flag and does not change the Big-O string.

    Recursion detection is deliberately simple: a function is recursive when
    its body contains a call, by bare name, to the function's own name.
    Mutual recursion and ``self.method()``-style calls are not detected.

    Args:
        tool_args: Tool arguments. ``tool_args["code"]`` is analysed when it
            is present and non-empty.
        state: Fallback source provider; ``state.python_code`` is read only
            when ``tool_args`` supplies no code.

    Returns:
        A dict with ``big_o_estimate``, ``max_loop_nesting_depth``,
        ``has_recursion`` and ``method``; or a dict with only ``error`` when
        the code cannot be parsed.
    """
    source = tool_args.get("code") or state.python_code
    try:
        tree = ast.parse(source)
    except (SyntaxError, ValueError) as e:
        # ValueError: source containing NUL bytes on interpreters that do not
        # report it as a SyntaxError.
        return {"error": f"Python parse error: {e}"}

    class DepthVisitor(ast.NodeVisitor):
        """Track the current and the deepest loop nesting level."""

        def __init__(self):
            self.depth = 0
            self.deepest = 0

        def _visit_loop(self, node):
            self.depth += 1
            self.deepest = max(self.deepest, self.depth)
            self.generic_visit(node)
            self.depth -= 1

        visit_For = _visit_loop
        visit_AsyncFor = _visit_loop
        visit_While = _visit_loop

    visitor = DepthVisitor()
    visitor.visit(tree)
    depth = visitor.deepest

    if depth == 0:
        big_o = "O(1)"
    elif depth == 1:
        big_o = "O(n)"
    else:
        big_o = f"O(n^{depth})"

    def _calls_itself(func) -> bool:
        """Return True when *func*'s body calls *func* by its bare name."""
        return any(
            isinstance(inner, ast.Call)
            and isinstance(inner.func, ast.Name)
            and inner.func.id == func.name
            for stmt in func.body
            for inner in ast.walk(stmt)
        )

    # Only a call made from inside the function's own body counts; calling a
    # helper defined elsewhere in the snippet is not recursion.
    has_recursion = any(
        _calls_itself(node)
        for node in ast.walk(tree)
        if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef))
    )

    return {
        "big_o_estimate": big_o,
        "max_loop_nesting_depth": depth,
        "has_recursion": has_recursion,
        "method": "static_ast_loop_depth",
    }
# ----------------- Tool 4: check_memory_access ----------------
# Patterns that suggest cache-unfriendly access.
# Swapped-index forms `X[j, i]` and `X[j][i]`. The natural `X[i][j]` order is
# deliberately NOT matched: it is the contiguous access this tool recommends.
_STRIDE_PATTERN = re.compile(r"\[\s*j\s*,\s*i\s*\]|\[\s*j\s*\]\s*\[\s*i\s*\]")
# Subscripting a transposed view, e.g. `A.T[k]`.
_TRANSPOSE_PATTERN = re.compile(r"\.T\s*\[")
# Explicit layout conversions already present in the code.
_NON_CONTIG_PATTERN = re.compile(r"\bnp\.ascontiguousarray\b|\bnp\.asfortranarray\b")


def check_memory_access_tool(tool_args: dict[str, Any], state) -> dict[str, Any]:
    """Detect cache-unfriendly stride patterns / aliasing risks via static patterns.

    This is a heuristic based on regex matches over the raw source text (so
    comments and string literals are matched too) plus one AST check:

    - swapped-index access (``D[j, i]`` / ``D[j][i]``) -> ``non_unit_stride``
    - subscripting a ``.T`` view -> ``in_loop_transpose``
    - ``np.ascontiguousarray`` / ``np.asfortranarray`` ->
      ``explicit_layout_handling``
    - a ``for`` nested inside another ``for`` when none of the above matched
      -> ``nested_loop_unanalyzed``

    Args:
        tool_args: Tool arguments. ``tool_args["code"]`` is analysed when it
            is present and non-empty.
        state: Fallback source provider; ``state.python_code`` is read only
            when ``tool_args`` supplies no code.

    Returns:
        A dict with ``issues`` (list of ``type``/``severity``/``hint`` dicts),
        ``aliasing_risk`` (``"med"`` when the text mentions ``ndarray``, else
        ``"low"``) and a ``recommendation`` string.
    """
    code = tool_args.get("code") or state.python_code
    issues: list[dict[str, str]] = []

    if _STRIDE_PATTERN.search(code):
        issues.append({
            "type": "non_unit_stride",
            "severity": "high",
            "hint": "Detected D[j,i]-style access — likely column-major in a row-major array. "
                    "Cache misses dominate. Transpose the layout or swap loop order."
        })
    if _TRANSPOSE_PATTERN.search(code):
        issues.append({
            "type": "in_loop_transpose",
            "severity": "med",
            "hint": "`.T` in hot path may force a copy or non-contiguous access."
        })
    if _NON_CONTIG_PATTERN.search(code):
        issues.append({
            "type": "explicit_layout_handling",
            "severity": "info",
            "hint": "Code already handles contiguity — good; preserve in C++ via `restrict`."
        })

    # Fallback: when no textual pattern fired, at least point out nested
    # `for` loops so the caller double-checks the inner index by hand.
    if not issues:
        try:
            tree = ast.parse(code)
        except (SyntaxError, ValueError):
            # Best effort: unparsable code simply gets no AST-based hint.
            tree = None
        if tree is not None:
            nested_for = any(
                isinstance(inner, ast.For) and inner is not outer
                for outer in ast.walk(tree)
                if isinstance(outer, ast.For)
                for inner in ast.walk(outer)
            )
            if nested_for:
                issues.append({
                    "type": "nested_loop_unanalyzed",
                    "severity": "low",
                    "hint": "Nested loops detected. Verify that inner-loop index varies the contiguous dimension."
                })

    # numpy arrays can alias; agent should consider `restrict`.
    aliasing_risk = "med" if "ndarray" in code else "low"

    return {
        "issues": issues,
        "aliasing_risk": aliasing_risk,
        "recommendation": (
            "Use `__restrict__` qualifier on non-aliasing pointers in C++. "
            "Prefer SoA over AoS for SIMD-friendly access."
            if issues else "No obvious memory-access issues; proceed with default layout."
        ),
    }
# Public API of this module: the three tool entry points defined above.
__all__ = [
    "profile_python_hotspots_tool",
    "analyze_complexity_tool",
    "check_memory_access_tool",
]
|