from __future__ import annotations import ast import math import os import re from dataclasses import dataclass from typing import Any from env.executor import run_code as execute_submission PROBE_TIMEOUT_SECONDS = 2.0 METRICS_PATTERN = re.compile(r"ADAPT_METRICS:\s*time_ms=([0-9.]+)\s+peak_kb=([0-9.]+)") @dataclass class ComplexitySignals: nested_loop_depth: int = 0 list_comprehensions: int = 0 set_comprehensions: int = 0 dict_comprehensions: int = 0 generator_expressions: int = 0 sorting_calls: int = 0 materialized_builtin_inputs: int = 0 class ComplexityVisitor(ast.NodeVisitor): def __init__(self) -> None: self.signals = ComplexitySignals() self._loop_depth = 0 def visit_For(self, node: ast.For) -> Any: self._loop_depth += 1 self.signals.nested_loop_depth = max(self.signals.nested_loop_depth, self._loop_depth) self.generic_visit(node) self._loop_depth -= 1 def visit_While(self, node: ast.While) -> Any: self._loop_depth += 1 self.signals.nested_loop_depth = max(self.signals.nested_loop_depth, self._loop_depth) self.generic_visit(node) self._loop_depth -= 1 def visit_ListComp(self, node: ast.ListComp) -> Any: self.signals.list_comprehensions += 1 self.generic_visit(node) def visit_SetComp(self, node: ast.SetComp) -> Any: self.signals.set_comprehensions += 1 self.generic_visit(node) def visit_DictComp(self, node: ast.DictComp) -> Any: self.signals.dict_comprehensions += 1 self.generic_visit(node) def visit_GeneratorExp(self, node: ast.GeneratorExp) -> Any: self.signals.generator_expressions += 1 self.generic_visit(node) def visit_Call(self, node: ast.Call) -> Any: fn_name = _call_name(node.func) if fn_name in {"sorted", "list.sort"}: self.signals.sorting_calls += 1 if fn_name in {"sum", "max", "min", "any", "all", "len"} and node.args: first_arg = node.args[0] if isinstance(first_arg, (ast.ListComp, ast.SetComp, ast.DictComp, ast.List, ast.Set, ast.Dict)): self.signals.materialized_builtin_inputs += 1 self.generic_visit(node) def _call_name(node: ast.AST) -> str: if isinstance(node, ast.Name): return node.id if isinstance(node, ast.Attribute): parent = _call_name(node.value) return f"{parent}.{node.attr}" if parent else node.attr return "" def _probe_timeout() -> float: raw_value = os.getenv("ADAPT_PROBE_TIMEOUT", str(PROBE_TIMEOUT_SECONDS)) try: timeout = float(raw_value) except ValueError: return PROBE_TIMEOUT_SECONDS return timeout if timeout > 0 else PROBE_TIMEOUT_SECONDS def size_hint_from_input(input_text: str) -> float: text = str(input_text or "") for line in text.splitlines(): stripped = line.strip() if not stripped: continue token = stripped.split()[0] try: return float(int(token)) except ValueError: break return float(len(text)) def _build_measurement_harness(code: str) -> str: return f""" import sys as _adapt_sys import time as _adapt_time import tracemalloc as _adapt_tracemalloc _adapt_globals = {{"__name__": "__main__"}} _adapt_source = {code!r} _adapt_tracemalloc.start() _adapt_t0 = _adapt_time.perf_counter() exec(compile(_adapt_source, "", "exec"), _adapt_globals, _adapt_globals) _adapt_t1 = _adapt_time.perf_counter() _adapt_peak_kb = _adapt_tracemalloc.get_traced_memory()[1] / 1024 _adapt_tracemalloc.stop() print( f"ADAPT_METRICS: time_ms={{(_adapt_t1 - _adapt_t0) * 1000:.3f}} peak_kb={{_adapt_peak_kb:.1f}}", file=_adapt_sys.stderr, ) """ def _parse_harness_output(stderr: str) -> tuple[float, float]: match = METRICS_PATTERN.search(str(stderr or "")) if match is None: return 0.0, 0.0 try: return float(match.group(1)), float(match.group(2)) except ValueError: return 0.0, 0.0 def _fit_scaling_exponent(sizes: list[float], values: list[float]) -> float: if len(sizes) < 2 or len(values) < 2: return 1.0 log_n = [math.log(max(size, 1.0)) for size in sizes] log_v = [math.log(max(value, 1e-6)) for value in values] count = len(log_n) mean_n = sum(log_n) / count mean_v = sum(log_v) / count numerator = sum((log_n[index] - mean_n) * (log_v[index] - mean_v) for index in range(count)) denominator = sum((log_n[index] - mean_n) ** 2 for index in range(count)) return numerator / denominator if denominator > 1e-9 else 1.0 def _exponent_to_score(alpha: float) -> float: if alpha < 0.1: return 1.0 if alpha < 1.2: return 0.85 if alpha < 1.6: return 0.75 if alpha < 2.3: return 0.50 if alpha < 3.2: return 0.20 return 0.0 def _memory_to_score(peak_kb: float) -> float: mb = peak_kb / 1024.0 if mb < 1: return 1.0 if mb < 10: return 0.85 if mb < 50: return 0.65 if mb < 256: return 0.40 return 0.10 def _hints_from_scores(time_score: float, space_score: float, time_alpha: float) -> list[str]: hints: list[str] = [] if time_alpha >= 2.3: hints.append("Reduce quadratic-or-worse work; measured runtime growth looks steep across larger inputs.") elif time_score < 0.85: hints.append("Consider a more scalable algorithm so runtime grows more gently with input size.") if space_score < 0.85: hints.append("Reduce peak memory usage by avoiding large intermediate containers when possible.") return hints def _merge_hints(*hint_groups: list[str]) -> list[str]: merged: list[str] = [] seen: set[str] = set() for group in hint_groups: for hint in group: if hint and hint not in seen: seen.add(hint) merged.append(hint) return merged def _heuristic_fallback(code: str) -> dict[str, Any]: try: tree = ast.parse(code) except SyntaxError: return { "time_complexity_score": 0.0, "space_complexity_score": 0.0, "efficiency_score": 0.0, "optimization_hints": [], "complexity_signals": {"measurement_source": "heuristic"}, } visitor = ComplexityVisitor() visitor.visit(tree) signals = visitor.signals time_penalty = 0.0 space_penalty = 0.0 hints: list[str] = [] if signals.nested_loop_depth > 1: time_penalty += 0.2 * (signals.nested_loop_depth - 1) hints.append("Reduce nested iteration if the problem can be solved in fewer passes.") if signals.sorting_calls: time_penalty += 0.1 * signals.sorting_calls hints.append("Avoid sorting unless it is required by the algorithm; it adds extra time complexity.") temporary_materializations = ( signals.list_comprehensions + signals.set_comprehensions + signals.dict_comprehensions + signals.materialized_builtin_inputs ) if temporary_materializations: space_penalty += 0.12 * temporary_materializations hints.append("Avoid materializing temporary containers when a streaming pass or generator expression is enough.") if signals.materialized_builtin_inputs: space_penalty += 0.08 * signals.materialized_builtin_inputs hints.append("Use generator expressions inside reducers like sum(...) instead of building an intermediate list.") time_score = max(0.0, 1.0 - min(time_penalty, 0.7)) space_score = max(0.0, 1.0 - min(space_penalty, 0.7)) efficiency_score = round(0.55 * time_score + 0.45 * space_score, 4) return { "time_complexity_score": round(time_score, 4), "space_complexity_score": round(space_score, 4), "efficiency_score": efficiency_score, "optimization_hints": hints, "complexity_signals": { "nested_loop_depth": signals.nested_loop_depth, "list_comprehensions": signals.list_comprehensions, "set_comprehensions": signals.set_comprehensions, "dict_comprehensions": signals.dict_comprehensions, "generator_expressions": signals.generator_expressions, "sorting_calls": signals.sorting_calls, "materialized_builtin_inputs": signals.materialized_builtin_inputs, "measurement_source": "heuristic", }, } def _empirical_complexity(code: str, probe_inputs: list[str]) -> dict[str, Any]: heuristic = _heuristic_fallback(code) if len(probe_inputs) < 3: return heuristic harness = _build_measurement_harness(code) sizes: list[float] = [] times: list[float] = [] mem_peaks: list[float] = [] for probe_input in probe_inputs: result = execute_submission(harness, probe_input, timeout_seconds=_probe_timeout()) if bool(result.get("timed_out")) or int(result.get("exit_code", 0)) != 0: return heuristic wall_ms, peak_kb = _parse_harness_output(str(result.get("stderr", ""))) if wall_ms <= 0.0 and peak_kb <= 0.0: return heuristic sizes.append(size_hint_from_input(probe_input)) times.append(float(wall_ms)) mem_peaks.append(float(peak_kb)) if len(sizes) < 3: return heuristic empirical_time_score = _exponent_to_score(_fit_scaling_exponent(sizes, times)) empirical_space_score = _memory_to_score(max(mem_peaks)) time_alpha = _fit_scaling_exponent(sizes, times) space_alpha = _fit_scaling_exponent(sizes, mem_peaks) time_score = min(empirical_time_score, float(heuristic["time_complexity_score"])) space_score = min(empirical_space_score, float(heuristic["space_complexity_score"])) efficiency_score = round(min(0.7 * time_score + 0.3 * space_score, float(heuristic["efficiency_score"])), 4) optimization_hints = _merge_hints( _hints_from_scores(time_score, space_score, time_alpha), list(heuristic.get("optimization_hints", [])), ) complexity_signals = dict(heuristic.get("complexity_signals", {})) complexity_signals.update( { "time_exponent": round(time_alpha, 4), "space_exponent": round(space_alpha, 4), "peak_memory_kb": round(max(mem_peaks), 1), "measurement_source": "empirical", } ) return { "time_complexity_score": round(time_score, 4), "space_complexity_score": round(space_score, 4), "efficiency_score": efficiency_score, "optimization_hints": optimization_hints, "complexity_signals": complexity_signals, } def analyze_code_complexity(code: str, probe_inputs: list[str] | None = None) -> dict[str, Any]: if probe_inputs and len(probe_inputs) >= 3: try: return _empirical_complexity(code, probe_inputs) except Exception: pass return _heuristic_fallback(code)