| from __future__ import annotations |
|
|
| import ast |
| import copy |
| from dataclasses import dataclass |
| from itertools import zip_longest |
| from typing import Any, Dict, Protocol |
|
|
|
|
@dataclass(frozen=True)
class TransformationResult:
    """Immutable record of what a transformation produced for one piece of source."""

    # Resulting source text (the transformations hand back the input when nothing changed).
    code: str
    # Whether the transformation reports a change to the source.
    changed: bool
    # Free-form details about the run, e.g. "type", "success", "lines_changed", "impact".
    metadata: Dict[str, Any]
|
|
|
|
class Transformation(Protocol):
    """Structural interface for a code transformation.

    Any object exposing a ``name`` attribute and an ``apply(code)`` method
    returning a ``TransformationResult`` matches this protocol; no explicit
    subclassing is required.
    """

    # Identifier of the transformation.
    name: str

    def apply(self, code: str) -> TransformationResult:
        """Transform ``code`` and report the outcome."""
|
|
|
|
def noop_transformation(code: str) -> TransformationResult:
    """Identity transformation: return ``code`` untouched, tagged with kind "noop"."""
    metadata: Dict[str, Any] = {"kind": "noop"}
    return TransformationResult(metadata=metadata, changed=False, code=code)
|
|
|
|
def _finalize_result(*, original: str, out: str, meta: Dict[str, Any]) -> TransformationResult:
    """
    Standardize metadata across transformations.

    - Adds `lines_changed` and `impact` for explainability/metrics.
    - Ensures formatting-only changes don't count as `changed`: when `out`
      differs from `original` only in surrounding whitespace, or both parse to
      the same AST (comments, quoting and spacing are not part of the AST),
      the ORIGINAL text is returned with `changed=False`.
    - `lines_changed` is a position-by-position line comparison. When
      `original` parses, it is compared after re-rendering through
      `ast.unparse`, so the normalization `ast.unparse` applied to `out`
      (quote style, spacing, dropped comments) is not counted as a change.

    `meta` is mutated in place (`success`, `lines_changed`, `impact`) and
    becomes the result's metadata.
    """

    def _count_lines_changed(a: str, b: str) -> int:
        # Pairs lines by position; a line missing on either side counts as changed.
        return sum(1 for x, y in zip_longest(a.splitlines(), b.splitlines()) if x != y)

    baseline = original
    same_program = out.strip() == original.strip()
    if not same_program:
        try:
            original_tree = ast.parse(original)
            out_tree = ast.parse(out)
            same_program = ast.dump(original_tree) == ast.dump(out_tree)
            baseline = ast.unparse(original_tree)
        except (SyntaxError, ValueError, RecursionError):
            # Unparseable text: keep the plain whitespace-insensitive comparison.
            pass

    if same_program:
        meta["success"] = False
        meta["lines_changed"] = 0
        meta["impact"] = "low"
        return TransformationResult(code=original, changed=False, metadata=meta)

    lines_changed = _count_lines_changed(baseline, out)
    meta["lines_changed"] = lines_changed
    meta["impact"] = "high" if lines_changed >= 3 else "low"
    meta["success"] = True
    return TransformationResult(code=out, changed=True, metadata=meta)
|
|
|
|
def _unchanged(*, code: str, meta: Dict[str, Any]) -> TransformationResult:
    """Return ``code`` as-is, filling any missing standard metadata keys with neutral defaults.

    Keys already present in ``meta`` are left alone; ``meta`` is mutated in place
    and becomes the result's metadata.
    """
    defaults = (("success", False), ("lines_changed", 0), ("impact", "low"))
    for key, default in defaults:
        meta.setdefault(key, default)
    return TransformationResult(code=code, changed=False, metadata=meta)
|
|
|
|
def rename_variable(code: str) -> TransformationResult:
    """
    Rename simple, generic variable names to more descriptive ones.

    Hackathon-scope heuristic:
    - Rename generic names in priority order: x, tmp, i.
    - Uses descriptive base names and avoids collisions with every identifier
      the module already uses (names, parameters, def/class names, imports,
      ``except ... as`` targets, ``global``/``nonlocal`` and ``match`` captures).
    - Rewrites Name nodes, function/lambda parameters, ``except ... as``
      targets, ``global``/``nonlocal`` declarations and ``match`` captures, so
      every binding and use of the old name moves together.
    - A candidate is skipped (the next one in priority order is tried) when it
      is never bound in the module (it then refers to something defined
      elsewhere), or when renaming could break the program: it names a
      def/class/import, it is assigned directly in a class body (it is then an
      attribute reachable as ``obj.<name>``), or it is a parameter that also
      appears as a call keyword (``f(x=...)``).

    Any failure yields an unchanged result (best effort).
    """
    meta: Dict[str, Any] = {"type": "rename_variable", "success": False}
    try:
        tree = ast.parse(code)

        class _NameCollector(ast.NodeVisitor):
            """Record the identifiers a rename has to respect."""

            def __init__(self) -> None:
                # Every identifier a new name could collide with.
                self.names: set[str] = set()
                # Identifiers that receive a value somewhere in the module.
                self.bound: set[str] = set()
                # Identifiers used in positions the renamer does not rewrite.
                self.unsafe: set[str] = set()
                self.arg_names: set[str] = set()
                self.keyword_names: set[str] = set()

            def _bind(self, name: str) -> None:
                self.names.add(name)
                self.bound.add(name)

            def _mark_unsafe(self, name: str) -> None:
                self.names.add(name)
                self.unsafe.add(name)

            def visit_Name(self, node: ast.Name) -> None:
                if isinstance(node.ctx, ast.Store):
                    self._bind(node.id)
                else:
                    self.names.add(node.id)

            def visit_arg(self, node: ast.arg) -> None:
                self._bind(node.arg)
                self.arg_names.add(node.arg)
                self.generic_visit(node)  # annotations may mention names too

            def visit_keyword(self, node: ast.keyword) -> None:
                if node.arg is not None:  # None means `**mapping` unpacking
                    self.keyword_names.add(node.arg)
                self.generic_visit(node)

            def visit_FunctionDef(self, node: ast.FunctionDef) -> None:
                self._mark_unsafe(node.name)
                self.generic_visit(node)

            def visit_AsyncFunctionDef(self, node: ast.AsyncFunctionDef) -> None:
                self._mark_unsafe(node.name)
                self.generic_visit(node)

            def visit_ClassDef(self, node: ast.ClassDef) -> None:
                self._mark_unsafe(node.name)
                # Names assigned directly in a class body become attributes.
                for stmt in node.body:
                    if isinstance(stmt, ast.Assign):
                        targets = list(stmt.targets)
                    elif isinstance(stmt, (ast.AnnAssign, ast.AugAssign)):
                        targets = [stmt.target]
                    else:
                        continue
                    for target in targets:
                        for sub in ast.walk(target):
                            if isinstance(sub, ast.Name) and isinstance(sub.ctx, ast.Store):
                                self._mark_unsafe(sub.id)
                self.generic_visit(node)

            def visit_alias(self, node: ast.alias) -> None:
                # `import a.b` binds `a`; `import a as b` / `from m import a as b` bind `b`.
                self._mark_unsafe(node.asname or node.name.split(".")[0])

            def visit_ExceptHandler(self, node: ast.ExceptHandler) -> None:
                if node.name:
                    self._bind(node.name)
                self.generic_visit(node)

            def visit_Global(self, node: ast.Global) -> None:
                self.names.update(node.names)

            def visit_Nonlocal(self, node: ast.Nonlocal) -> None:
                self.names.update(node.names)

            def visit_MatchAs(self, node: ast.MatchAs) -> None:
                if node.name:
                    self._bind(node.name)
                self.generic_visit(node)

            def visit_MatchStar(self, node: ast.MatchStar) -> None:
                if node.name:
                    self._bind(node.name)

            def visit_MatchMapping(self, node: ast.MatchMapping) -> None:
                if node.rest:
                    self._bind(node.rest)
                self.generic_visit(node)

        collector = _NameCollector()
        collector.visit(tree)

        # A parameter that is also used as a call keyword may be the target of that
        # keyword; renaming only the parameter would break those calls.
        blocked = collector.unsafe | (collector.arg_names & collector.keyword_names)

        rename_plan = [
            ("x", "value"),
            ("tmp", "temp_value"),
            ("i", "index"),
        ]

        old = ""
        base_new = "value"
        for candidate_old, candidate_base in rename_plan:
            if candidate_old in collector.bound and candidate_old not in blocked:
                old = candidate_old
                base_new = candidate_base
                break

        if not old:
            return _unchanged(code=code, meta=meta)

        # Append a numeric suffix until the new name is unused anywhere in the module.
        new = base_new
        suffix = 1
        while new in collector.names:
            new = f"{base_new}{suffix}"
            suffix += 1

        class _Renamer(ast.NodeTransformer):
            """Rewrite every binding and use of ``old_name`` to ``new_name``."""

            def __init__(self, old_name: str, new_name: str) -> None:
                self.old_name = old_name
                self.new_name = new_name
                self.changed = False

            def _swap(self, name: str) -> str:
                if name == self.old_name:
                    self.changed = True
                    return self.new_name
                return name

            def visit_Name(self, node: ast.Name) -> ast.AST:
                node.id = self._swap(node.id)
                return node

            def visit_arg(self, node: ast.arg) -> ast.AST:
                node.arg = self._swap(node.arg)
                return self.generic_visit(node)

            def visit_ExceptHandler(self, node: ast.ExceptHandler) -> ast.AST:
                if node.name is not None:
                    node.name = self._swap(node.name)
                return self.generic_visit(node)

            def visit_Global(self, node: ast.Global) -> ast.AST:
                node.names = [self._swap(name) for name in node.names]
                return node

            def visit_Nonlocal(self, node: ast.Nonlocal) -> ast.AST:
                node.names = [self._swap(name) for name in node.names]
                return node

            def visit_MatchAs(self, node: ast.MatchAs) -> ast.AST:
                if node.name is not None:
                    node.name = self._swap(node.name)
                return self.generic_visit(node)

            def visit_MatchStar(self, node: ast.MatchStar) -> ast.AST:
                if node.name is not None:
                    node.name = self._swap(node.name)
                return node

            def visit_MatchMapping(self, node: ast.MatchMapping) -> ast.AST:
                if node.rest is not None:
                    node.rest = self._swap(node.rest)
                return self.generic_visit(node)

        renamer = _Renamer(old, new)
        tree = renamer.visit(tree)
        ast.fix_missing_locations(tree)

        if not renamer.changed:
            return _unchanged(code=code, meta=meta)

        out = ast.unparse(tree)
        meta["old"] = old
        meta["new"] = new
        return _finalize_result(original=code, out=out, meta=meta)
    except Exception:
        # Best effort: unparseable input or any AST problem leaves the code as-is.
        return _unchanged(code=code, meta=meta)
|
|
|
|
def remove_dead_code(code: str) -> TransformationResult:
    """
    Remove simple dead code patterns.

    Hackathon-scope heuristics:
    - Drop statements after `return` / `raise` in the same block (any block).
    - Remove `if False: ...` blocks (keep `else` if present).
    - Remove `while False: ...` loops (keep `else`, which then always runs).
    - Remove assignments to unused names inside function bodies (very simple
      check): the name must never be read, deleted, augmented (`+=`) or
      declared `global`/`nonlocal` anywhere in that function, and the assigned
      value must be a literal, a plain name, or a list/tuple/set/dict of those,
      so dropping its evaluation cannot drop a side effect.
      Module-level assignments are kept because other modules may import them.

    A block left empty by these removals receives `pass` so the output parses.
    Any failure yields an unchanged result (best effort).
    """
    meta: Dict[str, Any] = {"type": "remove_dead_code", "success": False}

    try:
        tree = ast.parse(code)

        def _is_const_bool(expr: ast.AST, value: bool) -> bool:
            return isinstance(expr, ast.Constant) and isinstance(expr.value, bool) and expr.value is value

        def _is_pure(expr: ast.AST) -> bool:
            """True for values built only from literals and plain names (no calls/operators)."""
            if isinstance(expr, (ast.Constant, ast.Name)):
                return True
            if isinstance(expr, ast.UnaryOp):
                # Signed literals such as `-1`.
                return isinstance(expr.operand, ast.Constant)
            if isinstance(expr, (ast.List, ast.Tuple, ast.Set)):
                return all(_is_pure(elt) for elt in expr.elts)
            if isinstance(expr, ast.Dict):
                # A None key marks `**mapping` unpacking, which can run user code.
                return all(k is not None and _is_pure(k) for k in expr.keys) and all(
                    _is_pure(v) for v in expr.values
                )
            return False

        class _UsedNameCollector(ast.NodeVisitor):
            """Collect names that are read, deleted, augmented or declared global/nonlocal."""

            def __init__(self) -> None:
                self.used: set[str] = set()

            def visit_Name(self, node: ast.Name) -> None:
                if not isinstance(node.ctx, ast.Store):
                    self.used.add(node.id)

            def visit_AugAssign(self, node: ast.AugAssign) -> None:
                # `x += 1` reads x even though its target carries a Store context.
                if isinstance(node.target, ast.Name):
                    self.used.add(node.target.id)
                self.generic_visit(node)

            def visit_Global(self, node: ast.Global) -> None:
                self.used.update(node.names)

            def visit_Nonlocal(self, node: ast.Nonlocal) -> None:
                self.used.update(node.names)

        try_types = tuple(getattr(ast, name) for name in ("Try", "TryStar") if hasattr(ast, name))

        class _DeadCode(ast.NodeTransformer):
            def __init__(self) -> None:
                self.changed = False

            def _prune_unreachable(self, stmts: list[ast.stmt]) -> list[ast.stmt]:
                out: list[ast.stmt] = []
                for s in stmts:
                    out.append(s)
                    if isinstance(s, (ast.Return, ast.Raise)):
                        if len(out) < len(stmts):
                            self.changed = True
                        break
                return out

            def _is_dead_store(self, stmt: ast.stmt, used: set[str]) -> bool:
                if isinstance(stmt, ast.Assign):
                    if not all(isinstance(t, ast.Name) for t in stmt.targets):
                        return False
                    names = [t.id for t in stmt.targets if isinstance(t, ast.Name)]
                    value: ast.expr | None = stmt.value
                elif isinstance(stmt, ast.AnnAssign) and isinstance(stmt.target, ast.Name):
                    names = [stmt.target.id]
                    value = stmt.value
                else:
                    return False
                if any(name in used for name in names):
                    return False
                # Dropping the statement also drops evaluating its value.
                return value is None or _is_pure(value)

            def _remove_unused_assigns(self, stmts: list[ast.stmt]) -> list[ast.stmt]:
                collector = _UsedNameCollector()
                for s in stmts:
                    collector.visit(s)
                out: list[ast.stmt] = []
                for s in stmts:
                    if self._is_dead_store(s, collector.used):
                        self.changed = True
                        continue
                    out.append(s)
                return out

            def _clean_block(self, stmts: list[ast.stmt], *, drop_unused_assigns: bool) -> list[ast.stmt]:
                flat: list[ast.stmt] = []
                for s in stmts:
                    visited = self.visit(s)
                    if visited is None:
                        self.changed = True
                    elif isinstance(visited, list):
                        flat.extend(x for x in visited if isinstance(x, ast.stmt))
                        self.changed = True
                    else:
                        flat.append(visited)
                flat = self._prune_unreachable(flat)
                if drop_unused_assigns:
                    flat = self._remove_unused_assigns(flat)
                return flat

            def visit_Module(self, node: ast.Module) -> ast.AST:
                # Module-level names may be imported elsewhere, so assignments are kept.
                node.body = self._clean_block(node.body, drop_unused_assigns=False)
                return node

            def visit_FunctionDef(self, node: ast.FunctionDef) -> ast.AST:
                node.body = self._clean_block(node.body, drop_unused_assigns=True)
                return node

            def visit_AsyncFunctionDef(self, node: ast.AsyncFunctionDef) -> ast.AST:
                node.body = self._clean_block(node.body, drop_unused_assigns=True)
                return node

            def visit_If(self, node: ast.If) -> ast.AST | list[ast.stmt]:
                node = self.generic_visit(node)
                if _is_const_bool(node.test, False):
                    self.changed = True
                    return node.orelse or []
                return node

            def visit_While(self, node: ast.While) -> ast.AST | list[ast.stmt] | None:
                node = self.generic_visit(node)
                if _is_const_bool(node.test, False):
                    self.changed = True
                    # The body never runs, but a `while ... else:` clause always does.
                    return node.orelse or None
                return node

            def finish(self, root: ast.AST) -> None:
                """Prune unreachable code in nested blocks and refill blocks left empty."""
                for node in ast.walk(root):
                    for field in ("body", "orelse", "finalbody"):
                        block = getattr(node, field, None)
                        if isinstance(block, list):
                            setattr(node, field, self._prune_unreachable(block))
                    if not isinstance(node, ast.Module) and getattr(node, "body", None) == []:
                        node.body = [ast.Pass()]
                    if isinstance(node, try_types) and not node.handlers and not node.finalbody:
                        # `try:` needs at least one `except` or a `finally`.
                        node.finalbody = [ast.Pass()]

        dc = _DeadCode()
        tree = dc.visit(tree)
        dc.finish(tree)
        ast.fix_missing_locations(tree)
        if not dc.changed:
            return _unchanged(code=code, meta=meta)

        out = ast.unparse(tree)
        return _finalize_result(original=code, out=out, meta=meta)
    except Exception:
        # Best effort: unparseable input or any AST problem leaves the code as-is.
        return _unchanged(code=code, meta=meta)
|
|
|
|
def simplify_loops(code: str) -> TransformationResult:
    """
    Simplify very basic loop patterns into more pythonic forms.

    Supported pattern (only when adjacent in the same module/function block):
    - xs = []
      for t in it:
          xs.append(expr)
      => xs = [expr for t in it]

    The rewrite is skipped when it would change behavior or produce invalid code:
    - the loop has an `else` clause;
    - the append argument is starred (`xs.append(*a)`);
    - `expr`, `it` or the loop target mention `xs` (the comprehension would
      see `xs` before it is assigned);
    - a later statement in the block mentions a loop variable (a comprehension
      does not leak it);
    - `expr` or `it` contain `yield` or `:=`, which are invalid or behave
      differently inside a comprehension.

    Any failure yields an unchanged result (best effort).
    """
    meta: Dict[str, Any] = {"type": "simplify_loops", "success": False}
    try:
        tree = ast.parse(code)

        def _names(node: ast.AST) -> set[str]:
            return {n.id for n in ast.walk(node) if isinstance(n, ast.Name)}

        def _has_yield_or_walrus(node: ast.AST) -> bool:
            return any(isinstance(n, (ast.Yield, ast.YieldFrom, ast.NamedExpr)) for n in ast.walk(node))

        def _as_list_comp(cur: ast.stmt, nxt: ast.stmt | None, rest: list[ast.stmt]) -> ast.Assign | None:
            """Return `xs = [expr for t in it]` when (cur, nxt) is a safe append loop, else None."""
            if not (
                isinstance(cur, ast.Assign)
                and len(cur.targets) == 1
                and isinstance(cur.targets[0], ast.Name)
                and isinstance(cur.value, ast.List)
                and not cur.value.elts
                and isinstance(nxt, ast.For)
                and not nxt.orelse
                and len(nxt.body) == 1
                and isinstance(nxt.body[0], ast.Expr)
                and isinstance(nxt.body[0].value, ast.Call)
            ):
                return None
            list_name = cur.targets[0].id
            call = nxt.body[0].value
            if not (
                isinstance(call.func, ast.Attribute)
                and isinstance(call.func.value, ast.Name)
                and call.func.value.id == list_name
                and call.func.attr == "append"
                and len(call.args) == 1
                and not call.keywords
            ):
                return None
            elt = call.args[0]
            if isinstance(elt, ast.Starred):
                return None
            if list_name in _names(elt) | _names(nxt.iter) | _names(nxt.target):
                return None
            if _has_yield_or_walrus(elt) or _has_yield_or_walrus(nxt.iter):
                return None
            loop_vars = _names(nxt.target)
            if any(loop_vars & _names(stmt) for stmt in rest):
                return None

            comp = ast.ListComp(
                elt=elt,
                generators=[
                    ast.comprehension(
                        target=nxt.target,
                        iter=nxt.iter,
                        ifs=[],
                        is_async=0,
                    )
                ],
            )
            new_assign = ast.Assign(targets=[ast.Name(id=list_name, ctx=ast.Store())], value=comp)
            return ast.copy_location(new_assign, cur)

        class _LoopSimplifier(ast.NodeTransformer):
            def __init__(self) -> None:
                self.changed = False

            def _simplify_body(self, body: list[ast.stmt]) -> list[ast.stmt]:
                out: list[ast.stmt] = []
                i = 0
                while i < len(body):
                    nxt = body[i + 1] if i + 1 < len(body) else None
                    replacement = _as_list_comp(body[i], nxt, body[i + 2:])
                    if replacement is not None:
                        out.append(replacement)
                        self.changed = True
                        i += 2
                    else:
                        out.append(body[i])
                        i += 1
                return out

            def visit_Module(self, node: ast.Module) -> ast.AST:
                node = self.generic_visit(node)
                node.body = self._simplify_body(node.body)
                return node

            def visit_FunctionDef(self, node: ast.FunctionDef) -> ast.AST:
                node = self.generic_visit(node)
                node.body = self._simplify_body(node.body)
                return node

            def visit_AsyncFunctionDef(self, node: ast.AsyncFunctionDef) -> ast.AST:
                node = self.generic_visit(node)
                node.body = self._simplify_body(node.body)
                return node

        simp = _LoopSimplifier()
        tree = simp.visit(tree)
        ast.fix_missing_locations(tree)
        if not simp.changed:
            return _unchanged(code=code, meta=meta)

        out = ast.unparse(tree)
        return _finalize_result(original=code, out=out, meta=meta)
    except Exception:
        # Best effort: unparseable input or any AST problem leaves the code as-is.
        return _unchanged(code=code, meta=meta)
|
|
|
|
def simplify_loop(code: str) -> TransformationResult:
    """Alias for ``simplify_loops``: delegates to it and returns its result as-is."""
    result = simplify_loops(code)
    return result
|
|
|
|
def optimize_condition(code: str) -> TransformationResult:
    """
    Simplify redundant boolean conditions.

    Hackathon-scope heuristics:
    - Replace `if True:` with its body; `if False:` with `else` (if present).
    - Simplify `not not X` -> `X`.
    - Simplify comparisons to True/False: `X == True` -> `X`, `X == False` -> `not X`
      (likewise `is`, `!=`, `is not`, and with the constant on the left).

    The `not not` and comparison rewrites are applied only where an expression
    is used for its truth value: tests of `if`/`while`/`assert`/conditional
    expressions, comprehension filters, and `and`/`or`/`not` operands inside
    them. In value positions (e.g. `y = x == True`) they would change the stored
    value, so those are left alone. The comparison rewrite still assumes `X` is
    a bool: for other values the truthiness can differ (`2 == True` is False).

    A block emptied by dropping an `if False:` receives `pass`.
    Any failure yields an unchanged result (best effort).
    """
    meta: Dict[str, Any] = {"type": "optimize_condition", "success": False}
    try:
        tree = ast.parse(code)

        def _bool_const(node: ast.AST) -> bool | None:
            """Return the value of a literal True/False, else None."""
            if isinstance(node, ast.Constant) and isinstance(node.value, bool):
                return node.value
            return None

        def _strip_bool_compare(node: ast.Compare) -> ast.expr | None:
            """Rewrite `X ==/is/!=/is not True|False` to `X` or `not X`; None if not applicable."""
            if len(node.ops) != 1:
                return None
            op = node.ops[0]
            if isinstance(op, (ast.Eq, ast.Is)):
                negated = False
            elif isinstance(op, (ast.NotEq, ast.IsNot)):
                negated = True
            else:
                return None
            subject: ast.expr = node.left
            const = _bool_const(node.comparators[0])
            if const is None:
                subject = node.comparators[0]
                const = _bool_const(node.left)
            if const is None:
                return None
            # `X == True` / `X != False` keep X; `X == False` / `X != True` negate it.
            if const != negated:
                return subject
            return ast.UnaryOp(op=ast.Not(), operand=subject)

        def _fill_empty_blocks(root: ast.AST) -> None:
            """Give `pass` to any block emptied by dropping an `if False:`."""
            try_types = tuple(getattr(ast, name) for name in ("Try", "TryStar") if hasattr(ast, name))
            for node in ast.walk(root):
                if not isinstance(node, ast.Module) and getattr(node, "body", None) == []:
                    node.body = [ast.Pass()]
                if isinstance(node, try_types) and not node.handlers and not node.finalbody:
                    # `try:` needs at least one `except` or a `finally`.
                    node.finalbody = [ast.Pass()]

        class _CondOpt(ast.NodeTransformer):
            def __init__(self) -> None:
                self.changed = False

            def _simplify_test(self, expr: ast.expr) -> ast.expr:
                """Simplify an expression whose truth value is all that matters."""
                if isinstance(expr, ast.BoolOp):
                    # `and`/`or` propagate operand truthiness, so operands are tests too.
                    expr.values = [self._simplify_test(v) for v in expr.values]
                    return expr
                if isinstance(expr, ast.UnaryOp) and isinstance(expr.op, ast.Not):
                    operand = self._simplify_test(expr.operand)
                    if isinstance(operand, ast.UnaryOp) and isinstance(operand.op, ast.Not):
                        self.changed = True
                        return operand.operand
                    expr.operand = operand
                    return expr
                if isinstance(expr, ast.Compare):
                    replacement = _strip_bool_compare(expr)
                    if replacement is not None:
                        self.changed = True
                        return replacement
                return expr

            def _visit_with_test(self, node: ast.AST) -> ast.AST:
                node = self.generic_visit(node)
                node.test = self._simplify_test(node.test)
                return node

            visit_While = _visit_with_test
            visit_IfExp = _visit_with_test
            visit_Assert = _visit_with_test

            def visit_comprehension(self, node: ast.comprehension) -> ast.AST:
                node = self.generic_visit(node)
                node.ifs = [self._simplify_test(cond) for cond in node.ifs]
                return node

            def visit_If(self, node: ast.If) -> ast.AST | list[ast.stmt]:
                node = self.generic_visit(node)
                node.test = self._simplify_test(node.test)
                test_value = _bool_const(node.test)
                if test_value is True:
                    self.changed = True
                    return node.body
                if test_value is False:
                    self.changed = True
                    return node.orelse or []
                return node

        opt = _CondOpt()
        tree = opt.visit(tree)
        _fill_empty_blocks(tree)
        ast.fix_missing_locations(tree)
        if not opt.changed:
            return _unchanged(code=code, meta=meta)

        out = ast.unparse(tree)
        return _finalize_result(original=code, out=out, meta=meta)
    except Exception:
        # Best effort: unparseable input or any AST problem leaves the code as-is.
        return _unchanged(code=code, meta=meta)
|
|
|
|
def inline_function(code: str) -> TransformationResult:
    """
    Inline very simple functions into their call sites.

    Supported pattern:
    - def f(a, b): return <expr using only a,b>
    - Replace calls: f(x, y) -> <expr with a->x, b->y>
    Only handles module-level functions and positional args.

    Safety restrictions (otherwise the call/function is left alone):
    - Call arguments must be names or (signed) literals: the inlined expression
      may evaluate a parameter zero, one or several times and in a different
      order than the call did, which would change side effects of other args.
    - `f` must be bound exactly once in the module (by its `def`), so each call
      `f(...)` refers to that function and not to a shadowing local/parameter.
    - The returned expression must not contain lambdas, comprehensions, `:=`,
      `yield` or `await`, which introduce scopes or change the kind of the
      function they appear in.

    `metadata["inlined"]` lists the functions that were actually inlined.
    Any failure yields an unchanged result (best effort).
    """
    meta: Dict[str, Any] = {"type": "inline_function", "success": False}
    try:
        tree = ast.parse(code)

        unsafe_expr_types = (
            ast.Lambda,
            ast.ListComp,
            ast.SetComp,
            ast.DictComp,
            ast.GeneratorExp,
            ast.NamedExpr,
            ast.Yield,
            ast.YieldFrom,
            ast.Await,
        )

        def _binding_counts(root: ast.AST) -> Dict[str, int]:
            """Count how many places bind each name anywhere in `root`."""
            counts: Dict[str, int] = {}

            def bump(name: str) -> None:
                counts[name] = counts.get(name, 0) + 1

            for n in ast.walk(root):
                if isinstance(n, (ast.FunctionDef, ast.AsyncFunctionDef, ast.ClassDef)):
                    bump(n.name)
                elif isinstance(n, ast.Name) and not isinstance(n.ctx, ast.Load):
                    bump(n.id)
                elif isinstance(n, ast.arg):
                    bump(n.arg)
                elif isinstance(n, ast.alias):
                    bump(n.asname or n.name.split(".")[0])
                elif isinstance(n, (ast.Global, ast.Nonlocal)):
                    for name in n.names:
                        bump(name)
                elif isinstance(n, (ast.ExceptHandler, ast.MatchAs, ast.MatchStar)) and n.name:
                    bump(n.name)
                elif isinstance(n, ast.MatchMapping) and n.rest:
                    bump(n.rest)
            return counts

        def _is_atomic(arg: ast.expr) -> bool:
            """Names and (signed) literals: safe to duplicate, drop or reorder."""
            if isinstance(arg, (ast.Constant, ast.Name)):
                return True
            return (
                isinstance(arg, ast.UnaryOp)
                and isinstance(arg.op, (ast.USub, ast.UAdd))
                and isinstance(arg.operand, ast.Constant)
            )

        binding_counts = _binding_counts(tree)
        simple_fns: Dict[str, tuple[list[str], ast.expr]] = {}
        for node in tree.body:
            if not isinstance(node, ast.FunctionDef):
                continue
            if node.decorator_list:
                continue
            if binding_counts.get(node.name, 0) != 1:
                continue
            args = node.args
            if args.vararg or args.kwarg or args.kwonlyargs or args.defaults or args.posonlyargs:
                continue
            if len(node.body) != 1 or not isinstance(node.body[0], ast.Return) or node.body[0].value is None:
                continue
            expr = node.body[0].value
            if any(isinstance(n, unsafe_expr_types) for n in ast.walk(expr)):
                continue
            arg_names = [a.arg for a in args.args]
            # Without lambdas/comprehensions/walrus every Name in `expr` is a read.
            referenced = {n.id for n in ast.walk(expr) if isinstance(n, ast.Name)}
            if not referenced.issubset(arg_names):
                continue
            simple_fns[node.name] = (arg_names, expr)

        if not simple_fns:
            return _unchanged(code=code, meta=meta)

        class _Substitute(ast.NodeTransformer):
            def __init__(self, mapping: Dict[str, ast.AST]) -> None:
                self.mapping = mapping

            def visit_Name(self, n: ast.Name) -> ast.AST:
                if isinstance(n.ctx, ast.Load) and n.id in self.mapping:
                    return copy.deepcopy(self.mapping[n.id])
                return n

        class _Inliner(ast.NodeTransformer):
            def __init__(self) -> None:
                self.changed = False
                self.inlined: set[str] = set()

            def visit_Call(self, node: ast.Call) -> ast.AST:
                node = self.generic_visit(node)
                if not isinstance(node.func, ast.Name):
                    return node
                fn = simple_fns.get(node.func.id)
                if fn is None:
                    return node
                arg_names, expr = fn
                if node.keywords or len(node.args) != len(arg_names):
                    return node
                if not all(_is_atomic(arg) for arg in node.args):
                    return node
                mapping = {name: arg for name, arg in zip(arg_names, node.args, strict=True)}
                new_expr = _Substitute(mapping).visit(copy.deepcopy(expr))
                self.changed = True
                self.inlined.add(node.func.id)
                return ast.copy_location(new_expr, node)

        inliner = _Inliner()
        tree = inliner.visit(tree)
        ast.fix_missing_locations(tree)
        if not inliner.changed:
            return _unchanged(code=code, meta=meta)

        out = ast.unparse(tree)
        meta["inlined"] = sorted(inliner.inlined)
        return _finalize_result(original=code, out=out, meta=meta)
    except Exception:
        # Best effort: unparseable input or any AST problem leaves the code as-is.
        return _unchanged(code=code, meta=meta)
|
|
|
|