| """ |
| Three OpenEnv tasks with AST-based graders scoring 0.0-1.0. |
| """ |
| from __future__ import annotations |
|
|
| import ast |
| from dataclasses import dataclass |
| from typing import Callable, Dict, List, Optional, Sequence, Tuple |
|
|
| from acre.tasks.easy_task import EasyTask |
| from acre.tasks.hard_task import HardTask |
| from acre.tasks.medium_task import MediumTask |
| from acre.tasks.grader import grade_task |
|
|
|
|
@dataclass
class Task:
    """One refactoring task: metadata, input samples, reference outputs, grader."""

    # Identifier used as the lookup key for the task.
    id: str
    # Human-readable title.
    name: str
    # Free-text task statement.
    description: str
    # Difficulty label (the registry in this file uses "easy"/"medium"/"hard").
    difficulty: str
    # Input snippets; the first one is what ``initial_code`` returns.
    samples: List[str]
    # Reference outputs, addressed by index in ``expected_output_for_index``.
    expected_outputs: List[str]
    # Heuristic grader; ``grade`` wraps it and clamps its result.
    _grade_fn: Callable[[str], float]

    @property
    def initial_code(self) -> str:
        """Return the first sample as a string, or "" when there are no samples."""
        if not self.samples:
            return ""
        return str(self.samples[0])

    def expected_output_for_index(self, idx: int) -> str:
        """Return the expected output at ``idx``.

        An out-of-range (or negative) index falls back to the first expected
        output; with no expected outputs at all the result is "".
        """
        outputs = self.expected_outputs
        in_range = 0 <= idx < len(outputs)
        if in_range:
            return str(outputs[idx])
        if outputs:
            return str(outputs[0])
        return ""

    def grade(self, code: str) -> float:
        """Run the task's grader on ``code`` and clamp the result to [0.0, 1.0].

        Any exception raised while grading is turned into a score of 0.0.
        """
        try:
            raw_score = self._grade_fn(code)
            bounded = min(1.0, max(0.0, raw_score))
            return float(bounded)
        except Exception:
            return 0.0

    def grade_against_expected(self, code: str) -> float:
        """Score ``code`` against every expected output and keep the best score.

        The caller does not say which sample was active, so the highest
        ``grade_task`` score over all expected outputs is returned. Without
        any expected outputs the score is 0.0.
        """
        if not self.expected_outputs:
            return 0.0
        best = max(grade_task(code, reference) for reference in self.expected_outputs)
        return float(best)
|
|
|
|
| def _safe_unparse(tree: ast.AST) -> str: |
| try: |
| return ast.unparse(tree) |
| except Exception: |
| return "" |
|
|
|
|
| def _has_unreachable_after_terminator(stmts: Sequence[ast.stmt]) -> bool: |
| unreachable = False |
| for s in stmts: |
| if unreachable: |
| |
| if isinstance(s, ast.Expr) and isinstance(s.value, ast.Constant) and isinstance(s.value.value, str): |
| continue |
| return True |
| if isinstance(s, (ast.Return, ast.Raise)): |
| unreachable = True |
| return False |
|
|
|
|
| def _tree_has_unreachable(tree: ast.AST) -> bool: |
| class _Scan(ast.NodeVisitor): |
| def __init__(self) -> None: |
| self.bad = False |
|
|
| def visit_FunctionDef(self, node: ast.FunctionDef) -> None: |
| if _has_unreachable_after_terminator(node.body): |
| self.bad = True |
| self.generic_visit(node) |
|
|
| def visit_AsyncFunctionDef(self, node: ast.AsyncFunctionDef) -> None: |
| if _has_unreachable_after_terminator(node.body): |
| self.bad = True |
| self.generic_visit(node) |
|
|
| s = _Scan() |
| s.visit(tree) |
| return bool(s.bad) |
|
|
|
|
| |
| |
| |
# Input snippets for the "rename_variables" task (passed as its `samples`).
# Every snippet uses at least one of the placeholder names x / tmp / i that
# _grade_easy looks for. The first entry is what Task.initial_code returns.
| _EASY_SAMPLES: List[str] = [ |
| """\ |
| def compute(x, y, tmp): |
| tmp = x + y |
| x = tmp * 2 |
| result = x |
| return result |
| """, |
| """\ |
| def normalize(tmp, x): |
| for i in range(3): |
| tmp = tmp + i |
| return tmp * x |
| """, |
| """\ |
| def score(items): |
| tmp = 0 |
| for i in items: |
| tmp += i |
| x = tmp |
| return x |
| """, |
| """\ |
| def transform(x): |
| tmp = x |
| if tmp > 10: |
| tmp = tmp - 1 |
| return tmp |
| """, |
| """\ |
| def merge(a, b): |
| x = a |
| tmp = b |
| return x + tmp |
| """, |
| ] |
|
|
# Reference outputs for the "rename_variables" task (passed as its
# `expected_outputs`). Entries 1-4 are renamed versions of the _EASY_SAMPLES
# entries at the same positions (same function names, same order).
# NOTE(review): entry 0 comes from EasyTask.expected_output, defined outside
# this file - presumably the reference for the first sample; confirm.
| _EASY_EXPECTED: List[str] = [ |
| EasyTask.expected_output, |
| """\ |
| def normalize(temp_value, value): |
| for index in range(3): |
| temp_value = temp_value + index |
| return temp_value * value |
| """, |
| """\ |
| def score(items): |
| total = 0 |
| for item in items: |
| total += item |
| value = total |
| return value |
| """, |
| """\ |
| def transform(value): |
| temp_value = value |
| if temp_value > 10: |
| temp_value = temp_value - 1 |
| return temp_value |
| """, |
| """\ |
| def merge(a, b): |
| left = a |
| right = b |
| return left + right |
| """, |
| ] |
|
|
|
|
| def _grade_easy(code: str) -> float: |
| """Score = fraction of generic names removed from all scopes.""" |
| generic = {"x", "tmp", "i"} |
| try: |
| tree = ast.parse(code) |
| except SyntaxError: |
| return 0.0 |
|
|
| remaining: set[str] = set() |
|
|
| class _Collector(ast.NodeVisitor): |
| def visit_Name(self, node: ast.Name) -> None: |
| if node.id in generic: |
| remaining.add(node.id) |
| self.generic_visit(node) |
|
|
| def visit_arg(self, node: ast.arg) -> None: |
| if node.arg in generic: |
| remaining.add(node.arg) |
| self.generic_visit(node) |
|
|
| _Collector().visit(tree) |
| renamed = len(generic - remaining) |
| return renamed / len(generic) |
|
|
|
|
| |
| |
| |
# Input snippets for the "remove_dead_code" task (passed as its `samples`).
# Each snippet contains dead code: an `if False:` / `while False:` branch,
# an assignment to a throwaway name (unused_var, unused, unused2, dead),
# and/or a statement placed after a `return`.
| _MEDIUM_SAMPLES: List[str] = [ |
| """\ |
| def process(data): |
| result = [] |
| for item in data: |
| result.append(item * 2) |
| if False: |
| print("never runs") |
| unused_var = 42 |
| return result |
| print("unreachable") |
| """, |
| """\ |
| def build(values): |
| out = [] |
| for v in values: |
| out.append(v + 1) |
| while False: |
| out.append(999) |
| dead = 0 |
| return out |
| dead += 1 |
| """, |
| """\ |
| def route(flag): |
| if False: |
| return 1 |
| if True: |
| x = 2 |
| y = x |
| return y |
| """, |
| """\ |
| def clean(xs): |
| res = [] |
| for x in xs: |
| res.append(x * 2) |
| unused = "remove me" |
| if False: |
| unused2 = 123 |
| return res |
| """, |
| """\ |
| def calc(n): |
| total = 0 |
| for i in range(n): |
| total += i |
| return total |
| print("dead") |
| """, |
| ] |
|
|
# Reference outputs for the "remove_dead_code" task (passed as its
# `expected_outputs`). Entries 1-4 are the cleaned-up versions of the
# _MEDIUM_SAMPLES entries at the same positions (same function names).
# NOTE(review): entry 0 comes from MediumTask.expected_output, defined
# outside this file - presumably the reference for the first sample; confirm.
# NOTE(review): the `route` and `calc` references contain no list
# comprehension, so they cannot earn the list-comprehension quarter of
# _grade_medium - confirm that this is intended.
| _MEDIUM_EXPECTED: List[str] = [ |
| MediumTask.expected_output, |
| """\ |
| def build(values): |
| return [v + 1 for v in values] |
| """, |
| """\ |
| def route(flag): |
| x = 2 |
| y = x |
| return y |
| """, |
| """\ |
| def clean(xs): |
| return [x * 2 for x in xs] |
| """, |
| """\ |
| def calc(n): |
| total = 0 |
| for index in range(n): |
| total += index |
| return total |
| """, |
| ] |
|
|
|
|
| def _grade_medium(code: str) -> float: |
| """Score = fraction of dead-code patterns eliminated (4 checks, 0.25 each).""" |
| try: |
| tree = ast.parse(code) |
| except SyntaxError: |
| return 0.0 |
|
|
| source = _safe_unparse(tree) |
| score = 0.0 |
|
|
| |
| if ("if False" not in source) and ("while False" not in source): |
| score += 0.25 |
|
|
| |
| if not _tree_has_unreachable(tree): |
| score += 0.25 |
|
|
| |
| has_listcomp = any(isinstance(n, ast.ListComp) for n in ast.walk(tree)) |
| if has_listcomp: |
| score += 0.25 |
|
|
| |
| if all(name not in source for name in ["unused_var", "unused", "dead", "unused2"]): |
| score += 0.25 |
|
|
| return score |
|
|
|
|
| |
| |
| |
# Input snippets for the "full_refactor" task (passed as its `samples`).
# Each snippet combines several smells that _grade_hard checks for: generic
# parameter names (x, tmp), a call to a trivial helper (add / plus / helper),
# a manual append loop, and constant-condition branches (`if False:` /
# `if True:`); some also contain `not not` or a statement after `return`.
| _HARD_SAMPLES: List[str] = [ |
| """\ |
| def add(p, q): |
| return p + q |
| 
| def compute(x, data, tmp): |
| result = [] |
| for item in data: |
| result.append(item * 2) |
| if False: |
| y = 999 |
| if True: |
| val = add(x, tmp) |
| unused = 0 |
| flag = not not True |
| return val |
| print("dead") |
| """, |
| """\ |
| def helper(a, b): |
| return a + b |
| 
| def pipeline(tmp, xs, x): |
| out = [] |
| for i in xs: |
| out.append(i * 2) |
| if True: |
| y = helper(tmp, x) |
| if False: |
| y = 0 |
| return y |
| y = 123 |
| """, |
| """\ |
| def add(p, q): |
| return p + q |
| 
| def compute(x, data, tmp): |
| result = [] |
| for item in data: |
| result.append(item * 2) |
| if False: |
| print("never") |
| val = add(x, tmp) |
| return val |
| """, |
| """\ |
| def add(p, q): |
| return p + q |
| 
| def compute(x, data, tmp): |
| res = [] |
| for item in data: |
| res.append(item * 2) |
| flag = not not True |
| if True: |
| return add(x, tmp) |
| """, |
| """\ |
| def plus(p, q): |
| return p + q |
| 
| def compute(tmp, data, x): |
| out = [] |
| for item in data: |
| out.append(item * 2) |
| if False: |
| tmp = 999 |
| if True: |
| val = plus(x, tmp) |
| return val |
| """, |
| ] |
|
|
# Reference outputs for the "full_refactor" task (passed as its
# `expected_outputs`). Entries 1-4 rename the generic parameters, replace the
# append loop with a list comprehension, and inline the helper call as a
# plain addition.
# NOTE(review): entry 0 comes from HardTask.expected_output, defined outside
# this file - presumably the reference for the first sample; confirm.
| _HARD_EXPECTED: List[str] = [ |
| HardTask.expected_output, |
| """\ |
| def pipeline(offset, xs, value): |
| _ = [item * 2 for item in xs] |
| return offset + value |
| """, |
| """\ |
| def compute(value, data, offset): |
| _ = [item * 2 for item in data] |
| return value + offset |
| """, |
| """\ |
| def compute(value, data, offset): |
| _ = [item * 2 for item in data] |
| return value + offset |
| """, |
| """\ |
| def compute(offset, data, value): |
| _ = [item * 2 for item in data] |
| return value + offset |
| """, |
| ] |
|
|
|
|
| def _grade_hard(code: str) -> float: |
| """Score = fraction of 7 quality checks passed.""" |
| try: |
| tree = ast.parse(code) |
| except SyntaxError: |
| return 0.0 |
|
|
| source = _safe_unparse(tree) |
| checks = 0 |
|
|
| |
| has_generic = False |
|
|
| class _GenCheck(ast.NodeVisitor): |
| def visit_arg(self, node: ast.arg) -> None: |
| nonlocal has_generic |
| if node.arg in {"x", "tmp", "i"}: |
| has_generic = True |
|
|
| _GenCheck().visit(tree) |
| if not has_generic: |
| checks += 1 |
|
|
| |
| if ("if False" not in source) and ("while False" not in source): |
| checks += 1 |
|
|
| |
| if "if True" not in source: |
| checks += 1 |
|
|
| |
| if any(isinstance(n, ast.ListComp) for n in ast.walk(tree)): |
| checks += 1 |
|
|
| |
| calls = [n for n in ast.walk(tree) if isinstance(n, ast.Call)] |
| fn_names = {c.func.id for c in calls if isinstance(c.func, ast.Name)} |
| if not ({"add", "plus", "helper"} & fn_names): |
| checks += 1 |
|
|
| |
| if not _tree_has_unreachable(tree): |
| checks += 1 |
|
|
| |
| if "not not" not in source: |
| checks += 1 |
|
|
| return checks / 7 |
|
|
|
|
| |
| |
| |
|
|
class TaskRegistry:
    """In-memory catalogue of the three built-in tasks, keyed by task id."""

    def __init__(self) -> None:
        """Create the registry and populate it with the built-in tasks."""
        self._tasks: Dict[str, Task] = {}
        self._register_all()

    def _register_all(self) -> None:
        """Build the easy, medium and hard tasks and store each under its id."""

        def _add(task: Task) -> None:
            # Every task is stored under its own ``id``.
            self._tasks[task.id] = task

        _add(
            Task(
                id="rename_variables",
                name="Rename Variables (Easy)",
                description=EasyTask.description,
                difficulty="easy",
                samples=_EASY_SAMPLES,
                expected_outputs=_EASY_EXPECTED,
                _grade_fn=_grade_easy,
            )
        )
        _add(
            Task(
                id="remove_dead_code",
                name="Remove Dead Code (Medium)",
                description=MediumTask.description,
                difficulty="medium",
                samples=_MEDIUM_SAMPLES,
                expected_outputs=_MEDIUM_EXPECTED,
                _grade_fn=_grade_medium,
            )
        )
        _add(
            Task(
                id="full_refactor",
                name="Full Refactor (Hard)",
                description=HardTask.description,
                difficulty="hard",
                samples=_HARD_SAMPLES,
                expected_outputs=_HARD_EXPECTED,
                _grade_fn=_grade_hard,
            )
        )

    def get_task(self, task_id: str) -> Optional[Task]:
        """Return the task registered under ``task_id``, or None if unknown."""
        try:
            return self._tasks[task_id]
        except KeyError:
            return None

    def list_tasks(self) -> List[dict]:
        """Return one summary dict per registered task, in registration order.

        Each dict has the keys ``id``, ``name``, ``description``,
        ``difficulty`` and ``initial_code``.
        """
        summaries: List[dict] = []
        for task in self._tasks.values():
            summaries.append(
                {
                    "id": task.id,
                    "name": task.name,
                    "description": task.description,
                    "difficulty": task.difficulty,
                    "initial_code": task.initial_code,
                }
            )
        return summaries
|
|