Commit ·
8d66fec
1
Parent(s): 900e1f4
Updated structure and fixed module import issue
Browse files- ACRE_FINAL/.gitignore +0 -26
- ACRE_FINAL/Dockerfile +0 -23
- ACRE_FINAL/README.md +0 -174
- ACRE_FINAL/acre/__init__.py +0 -14
- ACRE_FINAL/acre/actions/__init__.py +0 -6
- ACRE_FINAL/acre/actions/transformations.py +0 -518
- ACRE_FINAL/acre/datasets/__init__.py +0 -6
- ACRE_FINAL/acre/datasets/code_samples.py +0 -34
- ACRE_FINAL/acre/demo.py +0 -185
- ACRE_FINAL/acre/main.py +0 -39
- ACRE_FINAL/acre/tasks/__init__.py +0 -3
- ACRE_FINAL/acre/tasks/task_registry.py +0 -222
- ACRE_FINAL/acre/training/__init__.py +0 -6
- ACRE_FINAL/acre/training/train_agent.py +0 -75
- ACRE_FINAL/acre/utils/__init__.py +0 -6
- ACRE_FINAL/acre/utils/metrics.py +0 -33
- ACRE_FINAL/inference.py +0 -278
- ACRE_FINAL/models.py +0 -156
- ACRE_FINAL/openenv.yaml +0 -85
- ACRE_FINAL/openenv_interface.py +0 -116
- ACRE_FINAL/requirements.txt +0 -11
- ACRE_FINAL/server.py +0 -667
- ACRE_FINAL/validate.py +0 -281
- README.md +4 -4
- acre/tasks/task_registry.py +212 -31
- inference.py +26 -63
- openenv_interface.py +17 -1
- validate.py +10 -2
ACRE_FINAL/.gitignore
DELETED
|
@@ -1,26 +0,0 @@
|
|
| 1 |
-
__pycache__/
|
| 2 |
-
*.pyc
|
| 3 |
-
*.pyo
|
| 4 |
-
*.pyd
|
| 5 |
-
.Python
|
| 6 |
-
*.egg-info/
|
| 7 |
-
dist/
|
| 8 |
-
build/
|
| 9 |
-
.env
|
| 10 |
-
.venv
|
| 11 |
-
venv/
|
| 12 |
-
*.zip
|
| 13 |
-
acre_agent.zip
|
| 14 |
-
*.log
|
| 15 |
-
.DS_Store
|
| 16 |
-
.deps/
|
| 17 |
-
libs/
|
| 18 |
-
numpy.libs/
|
| 19 |
-
*.dll
|
| 20 |
-
*.so
|
| 21 |
-
*.dylib
|
| 22 |
-
env/
|
| 23 |
-
ENV/
|
| 24 |
-
.cache/
|
| 25 |
-
.huggingface/
|
| 26 |
-
Thumbs.db
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
ACRE_FINAL/Dockerfile
DELETED
|
@@ -1,23 +0,0 @@
|
|
| 1 |
-
FROM python:3.11-slim
|
| 2 |
-
|
| 3 |
-
WORKDIR /app
|
| 4 |
-
|
| 5 |
-
RUN apt-get update && apt-get install -y --no-install-recommends \
|
| 6 |
-
build-essential \
|
| 7 |
-
&& rm -rf /var/lib/apt/lists/*
|
| 8 |
-
|
| 9 |
-
COPY requirements.txt .
|
| 10 |
-
RUN pip install --no-cache-dir -r requirements.txt
|
| 11 |
-
|
| 12 |
-
COPY . .
|
| 13 |
-
|
| 14 |
-
ENV API_BASE_URL=https://api.openai.com/v1
|
| 15 |
-
ENV MODEL_NAME=gpt-4o-mini
|
| 16 |
-
ENV PORT=7860
|
| 17 |
-
|
| 18 |
-
EXPOSE 7860
|
| 19 |
-
|
| 20 |
-
HEALTHCHECK --interval=30s --timeout=10s --start-period=5s --retries=3 \
|
| 21 |
-
CMD python -c "import requests; requests.get('http://localhost:7860/').raise_for_status()"
|
| 22 |
-
|
| 23 |
-
CMD ["python", "server.py"]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
ACRE_FINAL/README.md
DELETED
|
@@ -1,174 +0,0 @@
|
|
| 1 |
-
---
|
| 2 |
-
title: ACRE - Autonomous Code Refactoring Environment
|
| 3 |
-
colorFrom: blue
|
| 4 |
-
colorTo: green
|
| 5 |
-
sdk: docker
|
| 6 |
-
app_port: 7860
|
| 7 |
-
pinned: false
|
| 8 |
-
license: mit
|
| 9 |
-
tags:
|
| 10 |
-
- openenv
|
| 11 |
-
---
|
| 12 |
-
|
| 13 |
-
# ACRE - Autonomous Code Refactoring Environment
|
| 14 |
-
|
| 15 |
-
ACRE is an OpenEnv-compatible environment for autonomous Python code refactoring. An agent receives real code-cleanup tasks and must improve the code through AST-based transformations while receiving dense reward feedback for correctness, simplification, and performance.
|
| 16 |
-
|
| 17 |
-
## Environment Overview and Motivation
|
| 18 |
-
|
| 19 |
-
This project simulates a realistic developer workflow: cleaning up messy Python code, removing dead logic, simplifying loops, and inlining trivial helpers. The canonical OpenEnv wrapper lives in `openenv_interface.py`, while the original Gymnasium-compatible environment remains available for RL training and demos.
|
| 20 |
-
|
| 21 |
-
## Definitions of Action and Observation Spaces
|
| 22 |
-
|
| 23 |
-
### Action Space - Discrete(5)
|
| 24 |
-
|
| 25 |
-
| Action | Name | Description |
|
| 26 |
-
|---|---|---|
|
| 27 |
-
| 0 | rename_variable | Rename generic variables like `x`, `tmp`, and `i` |
|
| 28 |
-
| 1 | remove_dead_code | Remove unreachable statements, `if False` branches, and unused assignments |
|
| 29 |
-
| 2 | simplify_loop | Convert append-loops into list comprehensions |
|
| 30 |
-
| 3 | optimize_condition | Simplify `not not x`, `if True`, `if False`, and boolean comparisons |
|
| 31 |
-
| 4 | inline_function | Inline simple single-return module-level functions |
|
| 32 |
-
|
| 33 |
-
### Observation Space - Box(4,)
|
| 34 |
-
|
| 35 |
-
The environment tracks:
|
| 36 |
-
|
| 37 |
-
- `code_length`
|
| 38 |
-
- `complexity_score`
|
| 39 |
-
- `runtime_s`
|
| 40 |
-
- `error_flag`
|
| 41 |
-
|
| 42 |
-
### Typed OpenEnv Models
|
| 43 |
-
|
| 44 |
-
The submission-facing interface uses Pydantic models in `models.py`:
|
| 45 |
-
|
| 46 |
-
- `ObservationModel`
|
| 47 |
-
- `ActionModel`
|
| 48 |
-
- `RewardModel`
|
| 49 |
-
- `StateResponse`
|
| 50 |
-
|
| 51 |
-
The canonical interface is:
|
| 52 |
-
|
| 53 |
-
```python
|
| 54 |
-
observation = env.reset(...)
|
| 55 |
-
observation, reward, done, info = env.step(action)
|
| 56 |
-
state = env.state()
|
| 57 |
-
```
|
| 58 |
-
|
| 59 |
-
## Task Descriptions with Expected Difficulty Levels
|
| 60 |
-
|
| 61 |
-
| Task ID | Difficulty | Objective |
|
| 62 |
-
|---|---|---|
|
| 63 |
-
| `rename_variables` | Easy | Remove generic variable names from the snippet |
|
| 64 |
-
| `remove_dead_code` | Medium | Eliminate dead branches, unreachable code, and unused assignments |
|
| 65 |
-
| `full_refactor` | Hard | Combine renaming, dead-code removal, loop simplification, condition optimization, and inlining |
|
| 66 |
-
|
| 67 |
-
Each task includes a deterministic AST-based grader returning a score in `[0.0, 1.0]`.
|
| 68 |
-
|
| 69 |
-
## Reward Design
|
| 70 |
-
|
| 71 |
-
Rewards are shaped throughout the trajectory instead of only at the end.
|
| 72 |
-
|
| 73 |
-
- Success reward for syntactically valid, executable output
|
| 74 |
-
- Complexity reward when control-flow complexity decreases
|
| 75 |
-
- Performance reward when runtime improves
|
| 76 |
-
- Error penalty for invalid or failing code
|
| 77 |
-
- No-change penalty to discourage loops and unproductive actions
|
| 78 |
-
|
| 79 |
-
Raw reward range is `[-32, 20]`, normalized to `[0.0, 1.0]` with `(raw + 32) / 52`.
|
| 80 |
-
|
| 81 |
-
## HTTP API
|
| 82 |
-
|
| 83 |
-
| Method | Path | Purpose |
|
| 84 |
-
|---|---|---|
|
| 85 |
-
| GET | `/` | Health check |
|
| 86 |
-
| GET | `/health` | Compatibility health check |
|
| 87 |
-
| POST | `/reset` | Reset environment and return typed observation/state |
|
| 88 |
-
| POST | `/step` | Apply one action and return typed observation/reward/done |
|
| 89 |
-
| GET | `/state` | Return the current typed state |
|
| 90 |
-
| GET | `/tasks` | List available tasks |
|
| 91 |
-
| POST | `/tasks/{task_id}/grade` | Grade submitted code |
|
| 92 |
-
|
| 93 |
-
## Setup and Usage Instructions
|
| 94 |
-
|
| 95 |
-
### Local setup
|
| 96 |
-
|
| 97 |
-
```bash
|
| 98 |
-
pip install -r requirements.txt
|
| 99 |
-
python server.py
|
| 100 |
-
```
|
| 101 |
-
|
| 102 |
-
### Baseline inference
|
| 103 |
-
|
| 104 |
-
Set environment variables before running:
|
| 105 |
-
|
| 106 |
-
```bash
|
| 107 |
-
export API_BASE_URL=https://api.openai.com/v1
|
| 108 |
-
export MODEL_NAME=gpt-4o-mini
|
| 109 |
-
export HF_TOKEN=your_key
|
| 110 |
-
export ENV_URL=http://localhost:7860
|
| 111 |
-
python inference.py
|
| 112 |
-
```
|
| 113 |
-
|
| 114 |
-
Notes:
|
| 115 |
-
|
| 116 |
-
- `API_BASE_URL` and `MODEL_NAME` have defaults in `inference.py`
|
| 117 |
-
- `HF_TOKEN` is optional because the script falls back to a deterministic heuristic baseline
|
| 118 |
-
- `LOCAL_IMAGE_NAME` is read for evaluator compatibility when using a local Docker image launcher
|
| 119 |
-
|
| 120 |
-
### Docker / Hugging Face Spaces
|
| 121 |
-
|
| 122 |
-
```bash
|
| 123 |
-
docker build -t acre .
|
| 124 |
-
docker run -p 7860:7860 \
|
| 125 |
-
-e API_BASE_URL=https://api.openai.com/v1 \
|
| 126 |
-
-e MODEL_NAME=gpt-4o-mini \
|
| 127 |
-
-e HF_TOKEN=your_key \
|
| 128 |
-
-e ENV_URL=http://localhost:7860 \
|
| 129 |
-
acre
|
| 130 |
-
```
|
| 131 |
-
|
| 132 |
-
The repository is configured for a Docker-based Hugging Face Space and includes the `openenv` tag in the front matter.
|
| 133 |
-
|
| 134 |
-
## Validation
|
| 135 |
-
|
| 136 |
-
Run the repository validator:
|
| 137 |
-
|
| 138 |
-
```bash
|
| 139 |
-
python validate.py --url http://localhost:7860
|
| 140 |
-
```
|
| 141 |
-
|
| 142 |
-
When using the official hackathon tooling, also run:
|
| 143 |
-
|
| 144 |
-
```bash
|
| 145 |
-
openenv validate
|
| 146 |
-
```
|
| 147 |
-
|
| 148 |
-
## Interactive Demo
|
| 149 |
-
|
| 150 |
-
Start the server and open:
|
| 151 |
-
|
| 152 |
-
```text
|
| 153 |
-
http://localhost:7860/demo
|
| 154 |
-
```
|
| 155 |
-
|
| 156 |
-
The demo shows:
|
| 157 |
-
|
| 158 |
-
- Original code
|
| 159 |
-
- Optimized code
|
| 160 |
-
- Unified diff
|
| 161 |
-
- Per-step action and reward logs
|
| 162 |
-
|
| 163 |
-
## Baseline Performance Scores
|
| 164 |
-
|
| 165 |
-
The deterministic fallback policy used by `inference.py` produces the following reproducible task scores:
|
| 166 |
-
|
| 167 |
-
| Task | Score |
|
| 168 |
-
|---|---|
|
| 169 |
-
| `rename_variables` | 1.0 |
|
| 170 |
-
| `remove_dead_code` | 1.0 |
|
| 171 |
-
| `full_refactor` | 1.0 |
|
| 172 |
-
| Average | 1.0 |
|
| 173 |
-
|
| 174 |
-
These scores come from the built-in heuristic policy with `HF_TOKEN` unset, which keeps the baseline reproducible across runs.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
ACRE_FINAL/acre/__init__.py
DELETED
|
@@ -1,14 +0,0 @@
|
|
| 1 |
-
"""
|
| 2 |
-
ACRE (Autonomous Code Refactoring Environment).
|
| 3 |
-
|
| 4 |
-
Package skeleton for an RL-based code refactoring system.
|
| 5 |
-
"""
|
| 6 |
-
|
| 7 |
-
__all__ = [
|
| 8 |
-
"env",
|
| 9 |
-
"actions",
|
| 10 |
-
"datasets",
|
| 11 |
-
"training",
|
| 12 |
-
"utils",
|
| 13 |
-
]
|
| 14 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
ACRE_FINAL/acre/actions/__init__.py
DELETED
|
@@ -1,6 +0,0 @@
|
|
| 1 |
-
"""Action definitions and transformations for ACRE."""
|
| 2 |
-
|
| 3 |
-
from .transformations import Transformation, TransformationResult
|
| 4 |
-
|
| 5 |
-
__all__ = ["Transformation", "TransformationResult"]
|
| 6 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
ACRE_FINAL/acre/actions/transformations.py
DELETED
|
@@ -1,518 +0,0 @@
|
|
| 1 |
-
from __future__ import annotations
|
| 2 |
-
|
| 3 |
-
import ast
|
| 4 |
-
import copy
|
| 5 |
-
from dataclasses import dataclass
|
| 6 |
-
from itertools import zip_longest
|
| 7 |
-
from typing import Any, Dict, Protocol
|
| 8 |
-
|
| 9 |
-
|
| 10 |
-
@dataclass(frozen=True)
|
| 11 |
-
class TransformationResult:
|
| 12 |
-
"""Output of applying a transformation (placeholder)."""
|
| 13 |
-
|
| 14 |
-
code: str
|
| 15 |
-
changed: bool
|
| 16 |
-
metadata: Dict[str, Any]
|
| 17 |
-
|
| 18 |
-
|
| 19 |
-
class Transformation(Protocol):
|
| 20 |
-
"""Protocol for a code transformation."""
|
| 21 |
-
|
| 22 |
-
name: str
|
| 23 |
-
|
| 24 |
-
def apply(self, code: str) -> TransformationResult: ...
|
| 25 |
-
|
| 26 |
-
|
| 27 |
-
def noop_transformation(code: str) -> TransformationResult:
|
| 28 |
-
"""Baseline transformation that leaves code unchanged."""
|
| 29 |
-
return TransformationResult(code=code, changed=False, metadata={"kind": "noop"})
|
| 30 |
-
|
| 31 |
-
|
| 32 |
-
def _finalize_result(*, original: str, out: str, meta: Dict[str, Any]) -> TransformationResult:
|
| 33 |
-
"""
|
| 34 |
-
Standardize metadata across transformations.
|
| 35 |
-
|
| 36 |
-
- Adds `lines_changed` and `impact` for explainability/metrics.
|
| 37 |
-
- Ensures formatting-only changes don't count as `changed`.
|
| 38 |
-
"""
|
| 39 |
-
|
| 40 |
-
def _count_lines_changed(a: str, b: str) -> int:
|
| 41 |
-
a_lines = a.splitlines()
|
| 42 |
-
b_lines = b.splitlines()
|
| 43 |
-
changed = 0
|
| 44 |
-
for x, y in zip_longest(a_lines, b_lines, fillvalue=None):
|
| 45 |
-
if x != y:
|
| 46 |
-
changed += 1
|
| 47 |
-
return int(changed)
|
| 48 |
-
|
| 49 |
-
lines_changed = _count_lines_changed(original, out)
|
| 50 |
-
|
| 51 |
-
# Fallback identity check: AST round-trips can reformat without changing meaning.
|
| 52 |
-
# If the textual content is the same after stripping, treat it as unchanged.
|
| 53 |
-
if out.strip() == original.strip():
|
| 54 |
-
meta["success"] = False
|
| 55 |
-
meta["lines_changed"] = 0
|
| 56 |
-
meta["impact"] = "low"
|
| 57 |
-
return TransformationResult(code=original, changed=False, metadata=meta)
|
| 58 |
-
|
| 59 |
-
meta["lines_changed"] = lines_changed
|
| 60 |
-
meta["impact"] = "high" if lines_changed >= 3 else "low"
|
| 61 |
-
meta["success"] = True
|
| 62 |
-
return TransformationResult(code=out, changed=True, metadata=meta)
|
| 63 |
-
|
| 64 |
-
|
| 65 |
-
def _unchanged(*, code: str, meta: Dict[str, Any]) -> TransformationResult:
|
| 66 |
-
meta.setdefault("success", False)
|
| 67 |
-
meta.setdefault("lines_changed", 0)
|
| 68 |
-
meta.setdefault("impact", "low")
|
| 69 |
-
return TransformationResult(code=code, changed=False, metadata=meta)
|
| 70 |
-
|
| 71 |
-
|
| 72 |
-
def rename_variable(code: str) -> TransformationResult:
|
| 73 |
-
"""
|
| 74 |
-
Rename simple, generic variable names to more descriptive ones.
|
| 75 |
-
|
| 76 |
-
Hackathon-scope heuristic:
|
| 77 |
-
- Rename generic names in priority order: x, tmp, i.
|
| 78 |
-
- Uses descriptive base names and avoids collisions.
|
| 79 |
-
- Applies to Name nodes and function args.
|
| 80 |
-
"""
|
| 81 |
-
meta: Dict[str, Any] = {"type": "rename_variable", "success": False}
|
| 82 |
-
try:
|
| 83 |
-
tree = ast.parse(code)
|
| 84 |
-
|
| 85 |
-
class _NameCollector(ast.NodeVisitor):
|
| 86 |
-
def __init__(self) -> None:
|
| 87 |
-
self.names: set[str] = set()
|
| 88 |
-
|
| 89 |
-
def visit_Name(self, node: ast.Name) -> None: # noqa: N802
|
| 90 |
-
self.names.add(node.id)
|
| 91 |
-
|
| 92 |
-
def visit_arg(self, node: ast.arg) -> None: # noqa: N802
|
| 93 |
-
self.names.add(node.arg)
|
| 94 |
-
|
| 95 |
-
collector = _NameCollector()
|
| 96 |
-
collector.visit(tree)
|
| 97 |
-
|
| 98 |
-
rename_plan = [
|
| 99 |
-
("x", "value"),
|
| 100 |
-
("tmp", "temp_value"),
|
| 101 |
-
("i", "index"),
|
| 102 |
-
]
|
| 103 |
-
|
| 104 |
-
old = ""
|
| 105 |
-
base_new = "value"
|
| 106 |
-
for candidate_old, candidate_base in rename_plan:
|
| 107 |
-
if candidate_old in collector.names:
|
| 108 |
-
old = candidate_old
|
| 109 |
-
base_new = candidate_base
|
| 110 |
-
break
|
| 111 |
-
|
| 112 |
-
if not old:
|
| 113 |
-
return _unchanged(code=code, meta=meta)
|
| 114 |
-
|
| 115 |
-
new = base_new
|
| 116 |
-
i = 1
|
| 117 |
-
while new in collector.names:
|
| 118 |
-
new = f"{base_new}{i}"
|
| 119 |
-
i += 1
|
| 120 |
-
|
| 121 |
-
class _Renamer(ast.NodeTransformer):
|
| 122 |
-
def __init__(self, old_name: str, new_name: str) -> None:
|
| 123 |
-
self.old_name = old_name
|
| 124 |
-
self.new_name = new_name
|
| 125 |
-
self.changed = False
|
| 126 |
-
|
| 127 |
-
def visit_Name(self, node: ast.Name) -> ast.AST: # noqa: N802
|
| 128 |
-
if node.id == self.old_name:
|
| 129 |
-
self.changed = True
|
| 130 |
-
return ast.copy_location(ast.Name(id=self.new_name, ctx=node.ctx), node)
|
| 131 |
-
return node
|
| 132 |
-
|
| 133 |
-
def visit_arg(self, node: ast.arg) -> ast.AST: # noqa: N802
|
| 134 |
-
if node.arg == self.old_name:
|
| 135 |
-
self.changed = True
|
| 136 |
-
new_node = copy.copy(node)
|
| 137 |
-
new_node.arg = self.new_name
|
| 138 |
-
return new_node
|
| 139 |
-
return node
|
| 140 |
-
|
| 141 |
-
renamer = _Renamer(old, new)
|
| 142 |
-
tree = renamer.visit(tree)
|
| 143 |
-
ast.fix_missing_locations(tree)
|
| 144 |
-
|
| 145 |
-
if not renamer.changed:
|
| 146 |
-
return _unchanged(code=code, meta=meta)
|
| 147 |
-
|
| 148 |
-
out = ast.unparse(tree)
|
| 149 |
-
meta["old"] = old
|
| 150 |
-
meta["new"] = new
|
| 151 |
-
# Renames tend to be small diffs; label as low impact unless the diff is large.
|
| 152 |
-
return _finalize_result(original=code, out=out, meta=meta)
|
| 153 |
-
except Exception:
|
| 154 |
-
return _unchanged(code=code, meta=meta)
|
| 155 |
-
|
| 156 |
-
|
| 157 |
-
def remove_dead_code(code: str) -> TransformationResult:
|
| 158 |
-
"""
|
| 159 |
-
Remove simple dead code patterns.
|
| 160 |
-
|
| 161 |
-
Hackathon-scope heuristics:
|
| 162 |
-
- Drop statements after `return` / `raise` in the same block.
|
| 163 |
-
- Remove `if False: ...` blocks (keep `else` if present).
|
| 164 |
-
- Remove assignments to unused names in a block (very simple check).
|
| 165 |
-
"""
|
| 166 |
-
meta: Dict[str, Any] = {"type": "remove_dead_code", "success": False}
|
| 167 |
-
|
| 168 |
-
try:
|
| 169 |
-
tree = ast.parse(code)
|
| 170 |
-
|
| 171 |
-
def _is_const_bool(expr: ast.AST, value: bool) -> bool:
|
| 172 |
-
return isinstance(expr, ast.Constant) and isinstance(expr.value, bool) and expr.value is value
|
| 173 |
-
|
| 174 |
-
class _LoadNameCollector(ast.NodeVisitor):
|
| 175 |
-
def __init__(self) -> None:
|
| 176 |
-
self.loaded: set[str] = set()
|
| 177 |
-
|
| 178 |
-
def visit_Name(self, node: ast.Name) -> None: # noqa: N802
|
| 179 |
-
if isinstance(node.ctx, ast.Load):
|
| 180 |
-
self.loaded.add(node.id)
|
| 181 |
-
|
| 182 |
-
class _DeadCode(ast.NodeTransformer):
|
| 183 |
-
def __init__(self) -> None:
|
| 184 |
-
self.changed = False
|
| 185 |
-
|
| 186 |
-
def _prune_unreachable(self, stmts: list[ast.stmt]) -> list[ast.stmt]:
|
| 187 |
-
out: list[ast.stmt] = []
|
| 188 |
-
unreachable = False
|
| 189 |
-
for s in stmts:
|
| 190 |
-
if unreachable:
|
| 191 |
-
self.changed = True
|
| 192 |
-
continue
|
| 193 |
-
out.append(s)
|
| 194 |
-
if isinstance(s, (ast.Return, ast.Raise)):
|
| 195 |
-
unreachable = True
|
| 196 |
-
return out
|
| 197 |
-
|
| 198 |
-
def _remove_unused_assigns(self, stmts: list[ast.stmt]) -> list[ast.stmt]:
|
| 199 |
-
collector = _LoadNameCollector()
|
| 200 |
-
for s in stmts:
|
| 201 |
-
collector.visit(s)
|
| 202 |
-
used = collector.loaded
|
| 203 |
-
|
| 204 |
-
out: list[ast.stmt] = []
|
| 205 |
-
for s in stmts:
|
| 206 |
-
if isinstance(s, ast.Assign) and all(isinstance(t, ast.Name) for t in s.targets):
|
| 207 |
-
targets = [t.id for t in s.targets if isinstance(t, ast.Name)]
|
| 208 |
-
# Remove only if *all* assigned names are unused.
|
| 209 |
-
if targets and all(t not in used for t in targets):
|
| 210 |
-
self.changed = True
|
| 211 |
-
continue
|
| 212 |
-
if isinstance(s, ast.AnnAssign) and isinstance(s.target, ast.Name):
|
| 213 |
-
if s.target.id not in used:
|
| 214 |
-
self.changed = True
|
| 215 |
-
continue
|
| 216 |
-
out.append(s)
|
| 217 |
-
return out
|
| 218 |
-
|
| 219 |
-
def _clean_block(self, stmts: list[ast.stmt]) -> list[ast.stmt]:
|
| 220 |
-
# First apply transformations inside statements.
|
| 221 |
-
visited = [self.visit(s) for s in stmts]
|
| 222 |
-
flat: list[ast.stmt] = []
|
| 223 |
-
for s in visited:
|
| 224 |
-
if s is None:
|
| 225 |
-
self.changed = True
|
| 226 |
-
continue
|
| 227 |
-
if isinstance(s, list):
|
| 228 |
-
flat.extend([x for x in s if isinstance(x, ast.stmt)])
|
| 229 |
-
self.changed = True
|
| 230 |
-
else:
|
| 231 |
-
flat.append(s)
|
| 232 |
-
|
| 233 |
-
flat = self._prune_unreachable(flat)
|
| 234 |
-
flat = self._remove_unused_assigns(flat)
|
| 235 |
-
return flat
|
| 236 |
-
|
| 237 |
-
def visit_Module(self, node: ast.Module) -> ast.AST: # noqa: N802
|
| 238 |
-
node.body = self._clean_block(node.body)
|
| 239 |
-
return node
|
| 240 |
-
|
| 241 |
-
def visit_FunctionDef(self, node: ast.FunctionDef) -> ast.AST: # noqa: N802
|
| 242 |
-
node.body = self._clean_block(node.body)
|
| 243 |
-
return node
|
| 244 |
-
|
| 245 |
-
def visit_AsyncFunctionDef(self, node: ast.AsyncFunctionDef) -> ast.AST: # noqa: N802
|
| 246 |
-
node.body = self._clean_block(node.body)
|
| 247 |
-
return node
|
| 248 |
-
|
| 249 |
-
def visit_If(self, node: ast.If) -> ast.AST | list[ast.stmt]: # noqa: N802
|
| 250 |
-
node = self.generic_visit(node)
|
| 251 |
-
if _is_const_bool(node.test, False):
|
| 252 |
-
self.changed = True
|
| 253 |
-
return node.orelse or []
|
| 254 |
-
return node
|
| 255 |
-
|
| 256 |
-
def visit_While(self, node: ast.While) -> ast.AST | None: # noqa: N802
|
| 257 |
-
node = self.generic_visit(node)
|
| 258 |
-
if _is_const_bool(node.test, False):
|
| 259 |
-
self.changed = True
|
| 260 |
-
return None
|
| 261 |
-
return node
|
| 262 |
-
|
| 263 |
-
dc = _DeadCode()
|
| 264 |
-
tree = dc.visit(tree)
|
| 265 |
-
ast.fix_missing_locations(tree)
|
| 266 |
-
if not dc.changed:
|
| 267 |
-
return _unchanged(code=code, meta=meta)
|
| 268 |
-
|
| 269 |
-
out = ast.unparse(tree)
|
| 270 |
-
return _finalize_result(original=code, out=out, meta=meta)
|
| 271 |
-
except Exception:
|
| 272 |
-
return _unchanged(code=code, meta=meta)
|
| 273 |
-
|
| 274 |
-
|
| 275 |
-
def simplify_loops(code: str) -> TransformationResult:
|
| 276 |
-
"""
|
| 277 |
-
Simplify very basic loop patterns into more pythonic forms.
|
| 278 |
-
|
| 279 |
-
Supported pattern (only when adjacent in the same block):
|
| 280 |
-
- xs = []
|
| 281 |
-
for t in it:
|
| 282 |
-
xs.append(expr)
|
| 283 |
-
=> xs = [expr for t in it]
|
| 284 |
-
"""
|
| 285 |
-
meta: Dict[str, Any] = {"type": "simplify_loops", "success": False}
|
| 286 |
-
try:
|
| 287 |
-
tree = ast.parse(code)
|
| 288 |
-
|
| 289 |
-
class _LoopSimplifier(ast.NodeTransformer):
|
| 290 |
-
def __init__(self) -> None:
|
| 291 |
-
self.changed = False
|
| 292 |
-
|
| 293 |
-
def _simplify_body(self, body: list[ast.stmt]) -> list[ast.stmt]:
|
| 294 |
-
out: list[ast.stmt] = []
|
| 295 |
-
i = 0
|
| 296 |
-
while i < len(body):
|
| 297 |
-
cur = body[i]
|
| 298 |
-
nxt = body[i + 1] if i + 1 < len(body) else None
|
| 299 |
-
|
| 300 |
-
if (
|
| 301 |
-
isinstance(cur, ast.Assign)
|
| 302 |
-
and len(cur.targets) == 1
|
| 303 |
-
and isinstance(cur.targets[0], ast.Name)
|
| 304 |
-
and isinstance(cur.value, ast.List)
|
| 305 |
-
and cur.value.elts == []
|
| 306 |
-
and isinstance(nxt, ast.For)
|
| 307 |
-
and len(nxt.body) == 1
|
| 308 |
-
and isinstance(nxt.body[0], ast.Expr)
|
| 309 |
-
and isinstance(nxt.body[0].value, ast.Call)
|
| 310 |
-
):
|
| 311 |
-
list_name = cur.targets[0].id
|
| 312 |
-
call = nxt.body[0].value
|
| 313 |
-
if (
|
| 314 |
-
isinstance(call.func, ast.Attribute)
|
| 315 |
-
and isinstance(call.func.value, ast.Name)
|
| 316 |
-
and call.func.value.id == list_name
|
| 317 |
-
and call.func.attr == "append"
|
| 318 |
-
and len(call.args) == 1
|
| 319 |
-
and not call.keywords
|
| 320 |
-
):
|
| 321 |
-
# Build list comprehension: [call.args[0] for <target> in <iter>]
|
| 322 |
-
comp = ast.ListComp(
|
| 323 |
-
elt=call.args[0],
|
| 324 |
-
generators=[
|
| 325 |
-
ast.comprehension(
|
| 326 |
-
target=nxt.target,
|
| 327 |
-
iter=nxt.iter,
|
| 328 |
-
ifs=[],
|
| 329 |
-
is_async=0,
|
| 330 |
-
)
|
| 331 |
-
],
|
| 332 |
-
)
|
| 333 |
-
new_assign = ast.Assign(targets=[ast.Name(id=list_name, ctx=ast.Store())], value=comp)
|
| 334 |
-
out.append(ast.copy_location(new_assign, cur))
|
| 335 |
-
self.changed = True
|
| 336 |
-
i += 2
|
| 337 |
-
continue
|
| 338 |
-
|
| 339 |
-
out.append(cur)
|
| 340 |
-
i += 1
|
| 341 |
-
|
| 342 |
-
return out
|
| 343 |
-
|
| 344 |
-
def visit_Module(self, node: ast.Module) -> ast.AST: # noqa: N802
|
| 345 |
-
node = self.generic_visit(node)
|
| 346 |
-
node.body = self._simplify_body(node.body)
|
| 347 |
-
return node
|
| 348 |
-
|
| 349 |
-
def visit_FunctionDef(self, node: ast.FunctionDef) -> ast.AST: # noqa: N802
|
| 350 |
-
node = self.generic_visit(node)
|
| 351 |
-
node.body = self._simplify_body(node.body)
|
| 352 |
-
return node
|
| 353 |
-
|
| 354 |
-
def visit_AsyncFunctionDef(self, node: ast.AsyncFunctionDef) -> ast.AST: # noqa: N802
|
| 355 |
-
node = self.generic_visit(node)
|
| 356 |
-
node.body = self._simplify_body(node.body)
|
| 357 |
-
return node
|
| 358 |
-
|
| 359 |
-
simp = _LoopSimplifier()
|
| 360 |
-
tree = simp.visit(tree)
|
| 361 |
-
ast.fix_missing_locations(tree)
|
| 362 |
-
if not simp.changed:
|
| 363 |
-
return _unchanged(code=code, meta=meta)
|
| 364 |
-
|
| 365 |
-
out = ast.unparse(tree)
|
| 366 |
-
return _finalize_result(original=code, out=out, meta=meta)
|
| 367 |
-
except Exception:
|
| 368 |
-
return _unchanged(code=code, meta=meta)
|
| 369 |
-
|
| 370 |
-
|
| 371 |
-
def simplify_loop(code: str) -> TransformationResult:
|
| 372 |
-
# Backwards-compatible alias for the environment's action mapping.
|
| 373 |
-
return simplify_loops(code)
|
| 374 |
-
|
| 375 |
-
|
| 376 |
-
def optimize_condition(code: str) -> TransformationResult:
|
| 377 |
-
"""
|
| 378 |
-
Simplify redundant boolean conditions.
|
| 379 |
-
|
| 380 |
-
Hackathon-scope heuristics:
|
| 381 |
-
- Replace `if True:` with its body; `if False:` with `else` (if present).
|
| 382 |
-
- Simplify `not not X` -> `X`.
|
| 383 |
-
- Simplify comparisons to True/False: `X == True` -> `X`, `X == False` -> `not X`.
|
| 384 |
-
"""
|
| 385 |
-
meta: Dict[str, Any] = {"type": "optimize_condition", "success": False}
|
| 386 |
-
try:
|
| 387 |
-
tree = ast.parse(code)
|
| 388 |
-
|
| 389 |
-
def _is_bool_const(node: ast.AST, value: bool) -> bool:
|
| 390 |
-
return isinstance(node, ast.Constant) and isinstance(node.value, bool) and node.value is value
|
| 391 |
-
|
| 392 |
-
class _CondOpt(ast.NodeTransformer):
|
| 393 |
-
def __init__(self) -> None:
|
| 394 |
-
self.changed = False
|
| 395 |
-
|
| 396 |
-
def visit_UnaryOp(self, node: ast.UnaryOp) -> ast.AST: # noqa: N802
|
| 397 |
-
node = self.generic_visit(node)
|
| 398 |
-
if isinstance(node.op, ast.Not) and isinstance(node.operand, ast.UnaryOp) and isinstance(node.operand.op, ast.Not):
|
| 399 |
-
self.changed = True
|
| 400 |
-
return node.operand.operand
|
| 401 |
-
return node
|
| 402 |
-
|
| 403 |
-
def visit_Compare(self, node: ast.Compare) -> ast.AST: # noqa: N802
|
| 404 |
-
node = self.generic_visit(node)
|
| 405 |
-
if len(node.ops) == 1 and len(node.comparators) == 1:
|
| 406 |
-
op = node.ops[0]
|
| 407 |
-
rhs = node.comparators[0]
|
| 408 |
-
if isinstance(op, (ast.Eq, ast.Is)) and _is_bool_const(rhs, True):
|
| 409 |
-
self.changed = True
|
| 410 |
-
return node.left
|
| 411 |
-
if isinstance(op, (ast.Eq, ast.Is)) and _is_bool_const(rhs, False):
|
| 412 |
-
self.changed = True
|
| 413 |
-
return ast.UnaryOp(op=ast.Not(), operand=node.left)
|
| 414 |
-
return node
|
| 415 |
-
|
| 416 |
-
def visit_If(self, node: ast.If) -> ast.AST | list[ast.stmt]: # noqa: N802
|
| 417 |
-
node = self.generic_visit(node)
|
| 418 |
-
if _is_bool_const(node.test, True):
|
| 419 |
-
self.changed = True
|
| 420 |
-
return node.body
|
| 421 |
-
if _is_bool_const(node.test, False):
|
| 422 |
-
self.changed = True
|
| 423 |
-
return node.orelse or []
|
| 424 |
-
return node
|
| 425 |
-
|
| 426 |
-
opt = _CondOpt()
|
| 427 |
-
tree = opt.visit(tree)
|
| 428 |
-
ast.fix_missing_locations(tree)
|
| 429 |
-
if not opt.changed:
|
| 430 |
-
return _unchanged(code=code, meta=meta)
|
| 431 |
-
|
| 432 |
-
out = ast.unparse(tree)
|
| 433 |
-
return _finalize_result(original=code, out=out, meta=meta)
|
| 434 |
-
except Exception:
|
| 435 |
-
return _unchanged(code=code, meta=meta)
|
| 436 |
-
|
| 437 |
-
|
| 438 |
-
def inline_function(code: str) -> TransformationResult:
|
| 439 |
-
"""
|
| 440 |
-
Inline very simple functions into their call sites.
|
| 441 |
-
|
| 442 |
-
Supported pattern:
|
| 443 |
-
- def f(a, b): return <expr using only a,b>
|
| 444 |
-
- Replace calls: f(x, y) -> <expr with a->x, b->y>
|
| 445 |
-
Only handles module-level functions and positional args.
|
| 446 |
-
"""
|
| 447 |
-
meta: Dict[str, Any] = {"type": "inline_function", "success": False}
|
| 448 |
-
try:
|
| 449 |
-
tree = ast.parse(code)
|
| 450 |
-
|
| 451 |
-
simple_fns: Dict[str, tuple[list[str], ast.AST]] = {}
|
| 452 |
-
for node in tree.body:
|
| 453 |
-
if not isinstance(node, ast.FunctionDef):
|
| 454 |
-
continue
|
| 455 |
-
if node.decorator_list:
|
| 456 |
-
continue
|
| 457 |
-
args = node.args
|
| 458 |
-
if args.vararg or args.kwarg or args.kwonlyargs or args.defaults or args.posonlyargs:
|
| 459 |
-
continue
|
| 460 |
-
if len(node.body) != 1 or not isinstance(node.body[0], ast.Return) or node.body[0].value is None:
|
| 461 |
-
continue
|
| 462 |
-
arg_names = [a.arg for a in args.args]
|
| 463 |
-
# Ensure the return expression only references the function's args.
|
| 464 |
-
referenced: set[str] = set()
|
| 465 |
-
|
| 466 |
-
class _Ref(ast.NodeVisitor):
|
| 467 |
-
def visit_Name(self, n: ast.Name) -> None: # noqa: N802
|
| 468 |
-
if isinstance(n.ctx, ast.Load):
|
| 469 |
-
referenced.add(n.id)
|
| 470 |
-
|
| 471 |
-
_Ref().visit(node.body[0].value)
|
| 472 |
-
if not referenced.issubset(set(arg_names)):
|
| 473 |
-
continue
|
| 474 |
-
simple_fns[node.name] = (arg_names, node.body[0].value)
|
| 475 |
-
|
| 476 |
-
if not simple_fns:
|
| 477 |
-
return _unchanged(code=code, meta=meta)
|
| 478 |
-
|
| 479 |
-
class _Substitute(ast.NodeTransformer):
|
| 480 |
-
def __init__(self, mapping: Dict[str, ast.AST]) -> None:
|
| 481 |
-
self.mapping = mapping
|
| 482 |
-
|
| 483 |
-
def visit_Name(self, n: ast.Name) -> ast.AST: # noqa: N802
|
| 484 |
-
if isinstance(n.ctx, ast.Load) and n.id in self.mapping:
|
| 485 |
-
return copy.deepcopy(self.mapping[n.id])
|
| 486 |
-
return n
|
| 487 |
-
|
| 488 |
-
class _Inliner(ast.NodeTransformer):
|
| 489 |
-
def __init__(self) -> None:
|
| 490 |
-
self.changed = False
|
| 491 |
-
|
| 492 |
-
def visit_Call(self, node: ast.Call) -> ast.AST: # noqa: N802
|
| 493 |
-
node = self.generic_visit(node)
|
| 494 |
-
if not isinstance(node.func, ast.Name):
|
| 495 |
-
return node
|
| 496 |
-
fn = simple_fns.get(node.func.id)
|
| 497 |
-
if fn is None:
|
| 498 |
-
return node
|
| 499 |
-
arg_names, expr = fn
|
| 500 |
-
if node.keywords or len(node.args) != len(arg_names):
|
| 501 |
-
return node
|
| 502 |
-
mapping = {name: arg for name, arg in zip(arg_names, node.args, strict=True)}
|
| 503 |
-
new_expr = _Substitute(mapping).visit(copy.deepcopy(expr))
|
| 504 |
-
self.changed = True
|
| 505 |
-
return ast.copy_location(new_expr, node)
|
| 506 |
-
|
| 507 |
-
inliner = _Inliner()
|
| 508 |
-
tree = inliner.visit(tree)
|
| 509 |
-
ast.fix_missing_locations(tree)
|
| 510 |
-
if not inliner.changed:
|
| 511 |
-
return _unchanged(code=code, meta=meta)
|
| 512 |
-
|
| 513 |
-
out = ast.unparse(tree)
|
| 514 |
-
meta["inlined"] = sorted(simple_fns.keys())
|
| 515 |
-
return _finalize_result(original=code, out=out, meta=meta)
|
| 516 |
-
except Exception:
|
| 517 |
-
return _unchanged(code=code, meta=meta)
|
| 518 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
ACRE_FINAL/acre/datasets/__init__.py
DELETED
|
@@ -1,6 +0,0 @@
|
|
| 1 |
-
"""Datasets and sample code providers for ACRE."""
|
| 2 |
-
|
| 3 |
-
from .code_samples import CodeSample, CodeSampleDataset
|
| 4 |
-
|
| 5 |
-
__all__ = ["CodeSample", "CodeSampleDataset"]
|
| 6 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
ACRE_FINAL/acre/datasets/code_samples.py
DELETED
|
@@ -1,34 +0,0 @@
|
|
| 1 |
-
from __future__ import annotations
|
| 2 |
-
|
| 3 |
-
from dataclasses import dataclass
|
| 4 |
-
from typing import Iterable, Iterator, List, Optional
|
| 5 |
-
|
| 6 |
-
|
| 7 |
-
@dataclass(frozen=True)
|
| 8 |
-
class CodeSample:
|
| 9 |
-
"""A single code sample (placeholder)."""
|
| 10 |
-
|
| 11 |
-
id: str
|
| 12 |
-
language: str
|
| 13 |
-
code: str
|
| 14 |
-
|
| 15 |
-
|
| 16 |
-
class CodeSampleDataset:
|
| 17 |
-
"""
|
| 18 |
-
Minimal in-memory dataset stub.
|
| 19 |
-
|
| 20 |
-
Later versions can back this with files, Git repos, or benchmark suites.
|
| 21 |
-
"""
|
| 22 |
-
|
| 23 |
-
def __init__(self, samples: Optional[Iterable[CodeSample]] = None) -> None:
|
| 24 |
-
self._samples: List[CodeSample] = list(samples or [])
|
| 25 |
-
|
| 26 |
-
def __len__(self) -> int:
|
| 27 |
-
return len(self._samples)
|
| 28 |
-
|
| 29 |
-
def __iter__(self) -> Iterator[CodeSample]:
|
| 30 |
-
return iter(self._samples)
|
| 31 |
-
|
| 32 |
-
def add(self, sample: CodeSample) -> None:
|
| 33 |
-
self._samples.append(sample)
|
| 34 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
ACRE_FINAL/acre/demo.py
DELETED
|
@@ -1,185 +0,0 @@
|
|
| 1 |
-
from __future__ import annotations
|
| 2 |
-
|
| 3 |
-
import os
|
| 4 |
-
import random
|
| 5 |
-
import sys
|
| 6 |
-
from typing import Any, Optional, Tuple
|
| 7 |
-
|
| 8 |
-
from acre.datasets.code_samples import CodeSample, CodeSampleDataset
|
| 9 |
-
from acre.env.refactor_env import RefactorEnv
|
| 10 |
-
|
| 11 |
-
|
| 12 |
-
def _load_model(path: str):
|
| 13 |
-
"""Load a Stable-Baselines3 PPO model if available; otherwise return None."""
|
| 14 |
-
if not os.path.exists(path):
|
| 15 |
-
return None
|
| 16 |
-
try:
|
| 17 |
-
from stable_baselines3 import PPO
|
| 18 |
-
except Exception:
|
| 19 |
-
return None
|
| 20 |
-
try:
|
| 21 |
-
return PPO.load(path)
|
| 22 |
-
except Exception:
|
| 23 |
-
return None
|
| 24 |
-
|
| 25 |
-
|
| 26 |
-
def _messy_sample_code() -> str:
|
| 27 |
-
# Intentionally "messy" but valid Python for demo purposes.
|
| 28 |
-
return (
|
| 29 |
-
"def add(a,b):\n"
|
| 30 |
-
" x=0\n"
|
| 31 |
-
" for i in range(a):\n"
|
| 32 |
-
" x=x+1\n"
|
| 33 |
-
" if True:\n"
|
| 34 |
-
" x = x\n"
|
| 35 |
-
" if False:\n"
|
| 36 |
-
" y=123\n"
|
| 37 |
-
" else:\n"
|
| 38 |
-
" y=0\n"
|
| 39 |
-
" def f(p,q):\n"
|
| 40 |
-
" return p+q\n"
|
| 41 |
-
" r = f(x,y)\n"
|
| 42 |
-
" return r\n"
|
| 43 |
-
)
|
| 44 |
-
|
| 45 |
-
|
| 46 |
-
def _format_code_block(code: str) -> str:
|
| 47 |
-
return "\n".join(f" {line}" for line in code.rstrip().splitlines()) + "\n"
|
| 48 |
-
|
| 49 |
-
|
| 50 |
-
def _safe_print(text: str) -> None:
|
| 51 |
-
"""
|
| 52 |
-
Print text safely across Windows consoles (some default encodings can't print emojis).
|
| 53 |
-
"""
|
| 54 |
-
encoding = sys.stdout.encoding or "utf-8"
|
| 55 |
-
try:
|
| 56 |
-
text.encode(encoding)
|
| 57 |
-
print(text, flush=True)
|
| 58 |
-
except Exception:
|
| 59 |
-
# Fall back to ASCII-friendly markers if emojis can't be encoded.
|
| 60 |
-
safe = text.replace("✅", "[OK]").replace("⚠️", "[WARN]").replace("⚠", "[WARN]")
|
| 61 |
-
print(safe, flush=True)
|
| 62 |
-
|
| 63 |
-
|
| 64 |
-
def _compute_runtime(executor: Any, code: str) -> float:
|
| 65 |
-
"""Best-effort runtime metric using the current executor contract."""
|
| 66 |
-
try:
|
| 67 |
-
res = executor.run(code, filename="demo.py")
|
| 68 |
-
if getattr(res, "exit_code", 1) == 0 and isinstance(getattr(res, "metrics", None), dict):
|
| 69 |
-
return float(res.metrics.get("runtime_s", 0.0) or 0.0)
|
| 70 |
-
except Exception:
|
| 71 |
-
pass
|
| 72 |
-
return 0.0
|
| 73 |
-
|
| 74 |
-
|
| 75 |
-
def _choose_action(model: Any, obs, env: RefactorEnv, rng: random.Random) -> Tuple[int, str]:
|
| 76 |
-
"""Choose an action from the model, falling back to random."""
|
| 77 |
-
n_actions = int(getattr(getattr(env, "action_space", None), "n", 5))
|
| 78 |
-
if model is None:
|
| 79 |
-
a = int(rng.randint(0, n_actions - 1))
|
| 80 |
-
return a, "random"
|
| 81 |
-
|
| 82 |
-
try:
|
| 83 |
-
action, _state = model.predict(obs, deterministic=True)
|
| 84 |
-
# SB3 may return scalar or 1-element array.
|
| 85 |
-
if hasattr(action, "__len__"):
|
| 86 |
-
a = int(action[0])
|
| 87 |
-
else:
|
| 88 |
-
a = int(action)
|
| 89 |
-
return a, "ppo"
|
| 90 |
-
except Exception:
|
| 91 |
-
a = int(rng.randint(0, n_actions - 1))
|
| 92 |
-
return a, "random"
|
| 93 |
-
|
| 94 |
-
|
| 95 |
-
def run_demo(*, model_path: str = "acre_agent.zip", seed: int = 0) -> None:
|
| 96 |
-
rng = random.Random(seed)
|
| 97 |
-
|
| 98 |
-
# Create a dataset with one messy sample so `reset()` loads it deterministically.
|
| 99 |
-
dataset = CodeSampleDataset(
|
| 100 |
-
[
|
| 101 |
-
CodeSample(
|
| 102 |
-
id="demo_sample",
|
| 103 |
-
language="python",
|
| 104 |
-
code=_messy_sample_code(),
|
| 105 |
-
)
|
| 106 |
-
]
|
| 107 |
-
)
|
| 108 |
-
env = RefactorEnv(dataset=dataset, seed=seed)
|
| 109 |
-
|
| 110 |
-
model = _load_model(model_path)
|
| 111 |
-
model_status = "loaded" if model is not None else "not found (using random actions)"
|
| 112 |
-
|
| 113 |
-
# Reset and capture the original code/metrics.
|
| 114 |
-
obs, info = env.reset()
|
| 115 |
-
original_code = getattr(env, "_code", "")
|
| 116 |
-
original_complexity = float(getattr(env, "_compute_complexity")(original_code))
|
| 117 |
-
original_runtime = _compute_runtime(env.executor, original_code)
|
| 118 |
-
|
| 119 |
-
print("=" * 72)
|
| 120 |
-
print("ACRE: Autonomous RL Code Refactoring Agent (5-step episode)")
|
| 121 |
-
print(f"Model: {model_path} -> {model_status}")
|
| 122 |
-
print(f"Sample: {info.get('sample_id')} ({info.get('language')})")
|
| 123 |
-
print("=" * 72)
|
| 124 |
-
print("\nORIGINAL CODE:\n")
|
| 125 |
-
print(_format_code_block(original_code))
|
| 126 |
-
|
| 127 |
-
total_reward = 0.0
|
| 128 |
-
successful_transformations = 0
|
| 129 |
-
steps_taken = 0
|
| 130 |
-
|
| 131 |
-
for step_idx in range(1, 6):
|
| 132 |
-
action, policy = _choose_action(model, obs, env, rng)
|
| 133 |
-
obs, reward, terminated, truncated, step_info = env.step(action)
|
| 134 |
-
total_reward += float(reward)
|
| 135 |
-
steps_taken = step_idx
|
| 136 |
-
|
| 137 |
-
action_name = step_info.get("action_name", "unknown")
|
| 138 |
-
transform_meta = step_info.get("transform", {})
|
| 139 |
-
if isinstance(transform_meta, dict) and bool(transform_meta.get("success", False)):
|
| 140 |
-
successful_transformations += 1
|
| 141 |
-
transformed_code = getattr(env, "_code", "")
|
| 142 |
-
|
| 143 |
-
print("-" * 72)
|
| 144 |
-
print(f"STEP {step_idx}/5")
|
| 145 |
-
print(f"policy={policy} action={action} ({action_name})")
|
| 146 |
-
print(f"transform={transform_meta}")
|
| 147 |
-
print(f"reward={float(reward):.2f} components={step_info.get('reward_components')}")
|
| 148 |
-
print("\nUPDATED CODE:\n")
|
| 149 |
-
print(_format_code_block(transformed_code))
|
| 150 |
-
|
| 151 |
-
if terminated or truncated:
|
| 152 |
-
break
|
| 153 |
-
|
| 154 |
-
final_code = getattr(env, "_code", "")
|
| 155 |
-
final_complexity = float(getattr(env, "_compute_complexity")(final_code))
|
| 156 |
-
final_runtime = _compute_runtime(env.executor, final_code)
|
| 157 |
-
|
| 158 |
-
print("=" * 72)
|
| 159 |
-
print("FINAL SUMMARY")
|
| 160 |
-
print("=" * 72)
|
| 161 |
-
print(f"total_reward: {total_reward:.2f}")
|
| 162 |
-
print(f"complexity: {original_complexity:.0f} -> {final_complexity:.0f}")
|
| 163 |
-
print(f"runtime_s: {original_runtime:.4f} -> {final_runtime:.4f}")
|
| 164 |
-
|
| 165 |
-
complexity_improvement = ((original_complexity - final_complexity) / max(original_complexity, 1.0)) * 100.0
|
| 166 |
-
print(f"complexity improvement: {complexity_improvement:.2f}%")
|
| 167 |
-
|
| 168 |
-
print("\nCHANGES APPLIED:")
|
| 169 |
-
print(f"- Total steps: {steps_taken}")
|
| 170 |
-
print(f"- Successful transformations: {successful_transformations}")
|
| 171 |
-
|
| 172 |
-
if total_reward > 0:
|
| 173 |
-
_safe_print("\n✅ Code improved successfully")
|
| 174 |
-
else:
|
| 175 |
-
_safe_print("\n⚠️ No significant improvement")
|
| 176 |
-
|
| 177 |
-
print("\nFINAL CODE:\n")
|
| 178 |
-
print(_format_code_block(final_code))
|
| 179 |
-
|
| 180 |
-
env.close()
|
| 181 |
-
|
| 182 |
-
|
| 183 |
-
if __name__ == "__main__":
|
| 184 |
-
run_demo()
|
| 185 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
ACRE_FINAL/acre/main.py
DELETED
|
@@ -1,39 +0,0 @@
|
|
| 1 |
-
from __future__ import annotations
|
| 2 |
-
|
| 3 |
-
import argparse
|
| 4 |
-
|
| 5 |
-
from acre.training.train_agent import TrainConfig, train
|
| 6 |
-
|
| 7 |
-
|
| 8 |
-
def _build_parser() -> argparse.ArgumentParser:
|
| 9 |
-
parser = argparse.ArgumentParser(prog="acre", description="ACRE: Autonomous Code Refactoring Environment")
|
| 10 |
-
sub = parser.add_subparsers(dest="command", required=False)
|
| 11 |
-
|
| 12 |
-
train_p = sub.add_parser("train", help="Run training (stub)")
|
| 13 |
-
train_p.add_argument("--total-steps", type=int, default=100, help="Total training steps (stub)")
|
| 14 |
-
|
| 15 |
-
sub.add_parser("demo", help="Run a small demo (stub)")
|
| 16 |
-
|
| 17 |
-
return parser
|
| 18 |
-
|
| 19 |
-
|
| 20 |
-
def run_demo() -> None:
|
| 21 |
-
# Placeholder for a future interactive/demo flow.
|
| 22 |
-
print("ACRE demo mode is not implemented yet.")
|
| 23 |
-
|
| 24 |
-
|
| 25 |
-
def main(argv: list[str] | None = None) -> None:
|
| 26 |
-
parser = _build_parser()
|
| 27 |
-
args = parser.parse_args(argv)
|
| 28 |
-
|
| 29 |
-
if args.command == "demo":
|
| 30 |
-
run_demo()
|
| 31 |
-
return
|
| 32 |
-
|
| 33 |
-
total_steps = getattr(args, "total_steps", 100)
|
| 34 |
-
train(config=TrainConfig(total_steps=total_steps))
|
| 35 |
-
|
| 36 |
-
|
| 37 |
-
if __name__ == "__main__":
|
| 38 |
-
main()
|
| 39 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
ACRE_FINAL/acre/tasks/__init__.py
DELETED
|
@@ -1,3 +0,0 @@
|
|
| 1 |
-
from acre.tasks.task_registry import Task, TaskRegistry
|
| 2 |
-
|
| 3 |
-
__all__ = ["Task", "TaskRegistry"]
|
|
|
|
|
|
|
|
|
|
|
|
ACRE_FINAL/acre/tasks/task_registry.py
DELETED
|
@@ -1,222 +0,0 @@
|
|
| 1 |
-
"""
|
| 2 |
-
Three OpenEnv tasks with AST-based graders scoring 0.0-1.0.
|
| 3 |
-
"""
|
| 4 |
-
from __future__ import annotations
|
| 5 |
-
|
| 6 |
-
import ast
|
| 7 |
-
from dataclasses import dataclass
|
| 8 |
-
from typing import Callable, Dict, List, Optional
|
| 9 |
-
|
| 10 |
-
|
| 11 |
-
@dataclass
|
| 12 |
-
class Task:
|
| 13 |
-
id: str
|
| 14 |
-
name: str
|
| 15 |
-
description: str
|
| 16 |
-
difficulty: str
|
| 17 |
-
initial_code: str
|
| 18 |
-
_grade_fn: Callable[[str], float]
|
| 19 |
-
|
| 20 |
-
def grade(self, code: str) -> float:
|
| 21 |
-
"""Return a score in [0.0, 1.0]."""
|
| 22 |
-
try:
|
| 23 |
-
return float(min(1.0, max(0.0, self._grade_fn(code))))
|
| 24 |
-
except Exception:
|
| 25 |
-
return 0.0
|
| 26 |
-
|
| 27 |
-
|
| 28 |
-
# ---------------------------------------------------------------------------
|
| 29 |
-
# Task 1 — Easy: Rename generic variables
|
| 30 |
-
# ---------------------------------------------------------------------------
|
| 31 |
-
_EASY_CODE = """\
|
| 32 |
-
def compute(x, y, tmp):
|
| 33 |
-
tmp = x + y
|
| 34 |
-
x = tmp * 2
|
| 35 |
-
result = x
|
| 36 |
-
return result
|
| 37 |
-
"""
|
| 38 |
-
|
| 39 |
-
|
| 40 |
-
def _grade_easy(code: str) -> float:
|
| 41 |
-
"""Score = fraction of generic names (x, tmp) removed from all scopes."""
|
| 42 |
-
generic = {"x", "tmp"}
|
| 43 |
-
try:
|
| 44 |
-
tree = ast.parse(code)
|
| 45 |
-
except SyntaxError:
|
| 46 |
-
return 0.0
|
| 47 |
-
|
| 48 |
-
remaining: set[str] = set()
|
| 49 |
-
|
| 50 |
-
class _Collector(ast.NodeVisitor):
|
| 51 |
-
def visit_Name(self, node: ast.Name) -> None:
|
| 52 |
-
if node.id in generic:
|
| 53 |
-
remaining.add(node.id)
|
| 54 |
-
self.generic_visit(node)
|
| 55 |
-
|
| 56 |
-
def visit_arg(self, node: ast.arg) -> None:
|
| 57 |
-
if node.arg in generic:
|
| 58 |
-
remaining.add(node.arg)
|
| 59 |
-
self.generic_visit(node)
|
| 60 |
-
|
| 61 |
-
_Collector().visit(tree)
|
| 62 |
-
renamed = len(generic - remaining)
|
| 63 |
-
return renamed / len(generic)
|
| 64 |
-
|
| 65 |
-
|
| 66 |
-
# ---------------------------------------------------------------------------
|
| 67 |
-
# Task 2 — Medium: Remove dead code
|
| 68 |
-
# ---------------------------------------------------------------------------
|
| 69 |
-
_MEDIUM_CODE = """\
|
| 70 |
-
def process(data):
|
| 71 |
-
result = []
|
| 72 |
-
for item in data:
|
| 73 |
-
result.append(item * 2)
|
| 74 |
-
if False:
|
| 75 |
-
print("never runs")
|
| 76 |
-
unused_var = 42
|
| 77 |
-
return result
|
| 78 |
-
print("unreachable")
|
| 79 |
-
"""
|
| 80 |
-
|
| 81 |
-
|
| 82 |
-
def _grade_medium(code: str) -> float:
|
| 83 |
-
"""Score = fraction of dead-code patterns eliminated (3 checks, ~0.33 each)."""
|
| 84 |
-
try:
|
| 85 |
-
tree = ast.parse(code)
|
| 86 |
-
except SyntaxError:
|
| 87 |
-
return 0.0
|
| 88 |
-
|
| 89 |
-
source = ast.unparse(tree)
|
| 90 |
-
score = 0.0
|
| 91 |
-
|
| 92 |
-
# Check 1: if-False block removed
|
| 93 |
-
if "if False" not in source:
|
| 94 |
-
score += 1 / 3
|
| 95 |
-
|
| 96 |
-
# Check 2: unused_var assignment removed
|
| 97 |
-
if "unused_var" not in source:
|
| 98 |
-
score += 1 / 3
|
| 99 |
-
|
| 100 |
-
# Check 3: list comprehension used (loop simplified)
|
| 101 |
-
has_listcomp = any(isinstance(n, ast.ListComp) for n in ast.walk(tree))
|
| 102 |
-
if has_listcomp:
|
| 103 |
-
score += 1 / 3
|
| 104 |
-
|
| 105 |
-
return score
|
| 106 |
-
|
| 107 |
-
|
| 108 |
-
# ---------------------------------------------------------------------------
|
| 109 |
-
# Task 3 — Hard: Full refactor
|
| 110 |
-
# ---------------------------------------------------------------------------
|
| 111 |
-
_HARD_CODE = """\
|
| 112 |
-
def add(p, q):
|
| 113 |
-
return p + q
|
| 114 |
-
|
| 115 |
-
def compute(x, data, tmp):
|
| 116 |
-
result = []
|
| 117 |
-
for item in data:
|
| 118 |
-
result.append(item * 2)
|
| 119 |
-
if False:
|
| 120 |
-
y = 999
|
| 121 |
-
if True:
|
| 122 |
-
val = add(x, tmp)
|
| 123 |
-
unused = 0
|
| 124 |
-
flag = not not True
|
| 125 |
-
return val
|
| 126 |
-
print("dead")
|
| 127 |
-
"""
|
| 128 |
-
|
| 129 |
-
|
| 130 |
-
def _grade_hard(code: str) -> float:
|
| 131 |
-
"""Score = fraction of 5 quality checks passed."""
|
| 132 |
-
try:
|
| 133 |
-
tree = ast.parse(code)
|
| 134 |
-
except SyntaxError:
|
| 135 |
-
return 0.0
|
| 136 |
-
|
| 137 |
-
source = ast.unparse(tree)
|
| 138 |
-
checks = 0
|
| 139 |
-
|
| 140 |
-
# 1. No generic variable names x/tmp in function signature or body
|
| 141 |
-
has_generic = False
|
| 142 |
-
|
| 143 |
-
class _GenCheck(ast.NodeVisitor):
|
| 144 |
-
def visit_arg(self, node: ast.arg) -> None:
|
| 145 |
-
nonlocal has_generic
|
| 146 |
-
if node.arg in {"x", "tmp"}:
|
| 147 |
-
has_generic = True
|
| 148 |
-
|
| 149 |
-
_GenCheck().visit(tree)
|
| 150 |
-
if not has_generic:
|
| 151 |
-
checks += 1
|
| 152 |
-
|
| 153 |
-
# 2. No if False block
|
| 154 |
-
if "if False" not in source:
|
| 155 |
-
checks += 1
|
| 156 |
-
|
| 157 |
-
# 3. if True removed (body inlined)
|
| 158 |
-
if "if True" not in source:
|
| 159 |
-
checks += 1
|
| 160 |
-
|
| 161 |
-
# 4. List comprehension used
|
| 162 |
-
if any(isinstance(n, ast.ListComp) for n in ast.walk(tree)):
|
| 163 |
-
checks += 1
|
| 164 |
-
|
| 165 |
-
# 5. add() call inlined (no call to 'add')
|
| 166 |
-
calls = [n for n in ast.walk(tree) if isinstance(n, ast.Call)]
|
| 167 |
-
fn_names = {c.func.id for c in calls if isinstance(c.func, ast.Name)}
|
| 168 |
-
if "add" not in fn_names:
|
| 169 |
-
checks += 1
|
| 170 |
-
|
| 171 |
-
return checks / 5
|
| 172 |
-
|
| 173 |
-
|
| 174 |
-
# ---------------------------------------------------------------------------
|
| 175 |
-
# Registry
|
| 176 |
-
# ---------------------------------------------------------------------------
|
| 177 |
-
|
| 178 |
-
class TaskRegistry:
|
| 179 |
-
def __init__(self) -> None:
|
| 180 |
-
self._tasks: Dict[str, Task] = {}
|
| 181 |
-
self._register_all()
|
| 182 |
-
|
| 183 |
-
def _register_all(self) -> None:
|
| 184 |
-
self._tasks["rename_variables"] = Task(
|
| 185 |
-
id="rename_variables",
|
| 186 |
-
name="Rename Variables (Easy)",
|
| 187 |
-
description="Rename generic variable names (x, tmp) to descriptive ones",
|
| 188 |
-
difficulty="easy",
|
| 189 |
-
initial_code=_EASY_CODE,
|
| 190 |
-
_grade_fn=_grade_easy,
|
| 191 |
-
)
|
| 192 |
-
self._tasks["remove_dead_code"] = Task(
|
| 193 |
-
id="remove_dead_code",
|
| 194 |
-
name="Remove Dead Code (Medium)",
|
| 195 |
-
description="Remove unreachable code, if False blocks, and unused variables",
|
| 196 |
-
difficulty="medium",
|
| 197 |
-
initial_code=_MEDIUM_CODE,
|
| 198 |
-
_grade_fn=_grade_medium,
|
| 199 |
-
)
|
| 200 |
-
self._tasks["full_refactor"] = Task(
|
| 201 |
-
id="full_refactor",
|
| 202 |
-
name="Full Refactor (Hard)",
|
| 203 |
-
description="Apply all transformations: rename, dead code, loops, conditions, inlining",
|
| 204 |
-
difficulty="hard",
|
| 205 |
-
initial_code=_HARD_CODE,
|
| 206 |
-
_grade_fn=_grade_hard,
|
| 207 |
-
)
|
| 208 |
-
|
| 209 |
-
def get_task(self, task_id: str) -> Optional[Task]:
|
| 210 |
-
return self._tasks.get(task_id)
|
| 211 |
-
|
| 212 |
-
def list_tasks(self) -> List[dict]:
|
| 213 |
-
return [
|
| 214 |
-
{
|
| 215 |
-
"id": t.id,
|
| 216 |
-
"name": t.name,
|
| 217 |
-
"description": t.description,
|
| 218 |
-
"difficulty": t.difficulty,
|
| 219 |
-
"initial_code": t.initial_code,
|
| 220 |
-
}
|
| 221 |
-
for t in self._tasks.values()
|
| 222 |
-
]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
ACRE_FINAL/acre/training/__init__.py
DELETED
|
@@ -1,6 +0,0 @@
|
|
| 1 |
-
"""Training utilities for ACRE."""
|
| 2 |
-
|
| 3 |
-
from .train_agent import TrainConfig, train
|
| 4 |
-
|
| 5 |
-
__all__ = ["TrainConfig", "train"]
|
| 6 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
ACRE_FINAL/acre/training/train_agent.py
DELETED
|
@@ -1,75 +0,0 @@
|
|
| 1 |
-
from __future__ import annotations
|
| 2 |
-
|
| 3 |
-
from dataclasses import dataclass
|
| 4 |
-
from typing import Optional
|
| 5 |
-
|
| 6 |
-
from acre.env.refactor_env import RefactorEnv
|
| 7 |
-
|
| 8 |
-
|
| 9 |
-
@dataclass(frozen=True)
|
| 10 |
-
class TrainConfig:
|
| 11 |
-
"""Configuration stub for training."""
|
| 12 |
-
|
| 13 |
-
total_steps: int = 5_000
|
| 14 |
-
seed: Optional[int] = None
|
| 15 |
-
model_path: str = "acre_agent.zip"
|
| 16 |
-
|
| 17 |
-
|
| 18 |
-
def train(*, env: Optional[RefactorEnv] = None, config: Optional[TrainConfig] = None) -> None:
|
| 19 |
-
"""
|
| 20 |
-
Train a PPO agent on `RefactorEnv` using Stable-Baselines3.
|
| 21 |
-
|
| 22 |
-
This is intentionally lightweight (hackathon-friendly) and focuses on a
|
| 23 |
-
working demo: basic training loop, simple logging, and saving the model.
|
| 24 |
-
"""
|
| 25 |
-
_config = config or TrainConfig()
|
| 26 |
-
_env = env or RefactorEnv(seed=_config.seed)
|
| 27 |
-
|
| 28 |
-
try:
|
| 29 |
-
from stable_baselines3 import PPO
|
| 30 |
-
from stable_baselines3.common.callbacks import BaseCallback
|
| 31 |
-
from stable_baselines3.common.monitor import Monitor
|
| 32 |
-
from stable_baselines3.common.vec_env import DummyVecEnv
|
| 33 |
-
except Exception as e: # pragma: no cover
|
| 34 |
-
print("Stable-Baselines3 is required for training. Install with `pip install -r requirements.txt`.")
|
| 35 |
-
print(f"Import error: {e}")
|
| 36 |
-
return None
|
| 37 |
-
|
| 38 |
-
class EpisodeRewardPrinter(BaseCallback):
|
| 39 |
-
"""Print episode reward when an episode ends (via Monitor)."""
|
| 40 |
-
|
| 41 |
-
def __init__(self) -> None:
|
| 42 |
-
super().__init__()
|
| 43 |
-
self.episode_count = 0
|
| 44 |
-
|
| 45 |
-
def _on_step(self) -> bool:
|
| 46 |
-
infos = self.locals.get("infos", [])
|
| 47 |
-
for info in infos:
|
| 48 |
-
ep = info.get("episode") if isinstance(info, dict) else None
|
| 49 |
-
if isinstance(ep, dict) and "r" in ep:
|
| 50 |
-
self.episode_count += 1
|
| 51 |
-
print(f"episode={self.episode_count} reward={ep['r']:.2f} length={int(ep.get('l', 0))}")
|
| 52 |
-
return True
|
| 53 |
-
|
| 54 |
-
# Wrap with Monitor so SB3 can compute episode stats and expose them in `info["episode"]`.
|
| 55 |
-
def make_env() -> RefactorEnv:
|
| 56 |
-
return Monitor(_env)
|
| 57 |
-
|
| 58 |
-
vec_env = DummyVecEnv([make_env])
|
| 59 |
-
|
| 60 |
-
model = PPO(
|
| 61 |
-
policy="MlpPolicy",
|
| 62 |
-
env=vec_env,
|
| 63 |
-
verbose=0,
|
| 64 |
-
seed=_config.seed,
|
| 65 |
-
n_steps=64,
|
| 66 |
-
batch_size=64,
|
| 67 |
-
)
|
| 68 |
-
|
| 69 |
-
print(f"Training PPO for {int(_config.total_steps)} timesteps...")
|
| 70 |
-
model.learn(total_timesteps=int(_config.total_steps), callback=EpisodeRewardPrinter())
|
| 71 |
-
|
| 72 |
-
model.save(_config.model_path)
|
| 73 |
-
print(f"Saved model to {_config.model_path!r}")
|
| 74 |
-
return None
|
| 75 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
ACRE_FINAL/acre/utils/__init__.py
DELETED
|
@@ -1,6 +0,0 @@
|
|
| 1 |
-
"""Shared utility helpers for ACRE."""
|
| 2 |
-
|
| 3 |
-
from .metrics import Metric, MetricLogger
|
| 4 |
-
|
| 5 |
-
__all__ = ["Metric", "MetricLogger"]
|
| 6 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
ACRE_FINAL/acre/utils/metrics.py
DELETED
|
@@ -1,33 +0,0 @@
|
|
| 1 |
-
from __future__ import annotations
|
| 2 |
-
|
| 3 |
-
from dataclasses import dataclass, field
|
| 4 |
-
from typing import Dict, Iterable, List, Tuple
|
| 5 |
-
|
| 6 |
-
|
| 7 |
-
@dataclass(frozen=True)
|
| 8 |
-
class Metric:
|
| 9 |
-
"""Single scalar metric value (placeholder)."""
|
| 10 |
-
|
| 11 |
-
name: str
|
| 12 |
-
value: float
|
| 13 |
-
|
| 14 |
-
|
| 15 |
-
@dataclass
|
| 16 |
-
class MetricLogger:
|
| 17 |
-
"""Tiny metric logger stub."""
|
| 18 |
-
|
| 19 |
-
_history: Dict[str, List[float]] = field(default_factory=dict)
|
| 20 |
-
|
| 21 |
-
def log(self, metric: Metric) -> None:
|
| 22 |
-
self._history.setdefault(metric.name, []).append(metric.value)
|
| 23 |
-
|
| 24 |
-
def latest(self) -> Dict[str, float]:
|
| 25 |
-
return {k: v[-1] for k, v in self._history.items() if v}
|
| 26 |
-
|
| 27 |
-
def as_series(self) -> Dict[str, Tuple[float, ...]]:
|
| 28 |
-
return {k: tuple(v) for k, v in self._history.items()}
|
| 29 |
-
|
| 30 |
-
def extend(self, metrics: Iterable[Metric]) -> None:
|
| 31 |
-
for m in metrics:
|
| 32 |
-
self.log(m)
|
| 33 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
ACRE_FINAL/inference.py
DELETED
|
@@ -1,278 +0,0 @@
|
|
| 1 |
-
"""
|
| 2 |
-
ACRE inference script for OpenEnv submission evaluation.
|
| 3 |
-
|
| 4 |
-
Required environment variables:
|
| 5 |
-
API_BASE_URL: LLM API endpoint (default allowed)
|
| 6 |
-
MODEL_NAME: model identifier (default allowed)
|
| 7 |
-
HF_TOKEN: API token for the OpenAI-compatible endpoint
|
| 8 |
-
ENV_URL: running ACRE server base URL
|
| 9 |
-
|
| 10 |
-
Optional:
|
| 11 |
-
LOCAL_IMAGE_NAME: present for evaluator compatibility when using a local
|
| 12 |
-
Docker image launcher.
|
| 13 |
-
|
| 14 |
-
Stdout format uses strict START / STEP / END event markers.
|
| 15 |
-
"""
|
| 16 |
-
from __future__ import annotations
|
| 17 |
-
|
| 18 |
-
import json
|
| 19 |
-
import os
|
| 20 |
-
import re
|
| 21 |
-
import sys
|
| 22 |
-
import time
|
| 23 |
-
from typing import Dict, List, Tuple
|
| 24 |
-
|
| 25 |
-
import requests
|
| 26 |
-
from openai import OpenAI
|
| 27 |
-
|
| 28 |
-
API_BASE_URL: str = os.getenv("API_BASE_URL", "https://api.openai.com/v1")
|
| 29 |
-
MODEL_NAME: str = os.getenv("MODEL_NAME", "gpt-4o-mini")
|
| 30 |
-
HF_TOKEN: str | None = os.getenv("HF_TOKEN")
|
| 31 |
-
ENV_URL: str | None = os.getenv("ENV_URL")
|
| 32 |
-
LOCAL_IMAGE_NAME: str | None = os.getenv("LOCAL_IMAGE_NAME")
|
| 33 |
-
|
| 34 |
-
TASKS: List[str] = ["rename_variables", "remove_dead_code", "full_refactor"]
|
| 35 |
-
|
| 36 |
-
ACTION_MEANINGS: Dict[int, str] = {
|
| 37 |
-
0: "rename_variable",
|
| 38 |
-
1: "remove_dead_code",
|
| 39 |
-
2: "simplify_loop",
|
| 40 |
-
3: "optimize_condition",
|
| 41 |
-
4: "inline_function",
|
| 42 |
-
}
|
| 43 |
-
|
| 44 |
-
SYSTEM_PROMPT = """\
|
| 45 |
-
You are an RL agent that refactors Python code. Choose one action per step.
|
| 46 |
-
|
| 47 |
-
Actions:
|
| 48 |
-
0 rename_variable - rename generic names (x, tmp, i) to descriptive ones
|
| 49 |
-
1 remove_dead_code - remove unreachable stmts, if False blocks, unused vars
|
| 50 |
-
2 simplify_loop - convert append-loops to list comprehensions
|
| 51 |
-
3 optimize_condition- simplify 'not not x', 'if True/False', 'x==True'
|
| 52 |
-
4 inline_function - inline simple single-return module-level functions
|
| 53 |
-
|
| 54 |
-
Respond ONLY with valid JSON (no markdown):
|
| 55 |
-
{"action": <0-4>, "reason": "<one sentence>"}"""
|
| 56 |
-
|
| 57 |
-
|
| 58 |
-
def _env_url() -> str:
    """Return the configured ACRE server base URL without trailing slashes.

    Raises:
        RuntimeError: if ``ENV_URL`` is unset or empty.
    """
    if not ENV_URL:
        raise RuntimeError("ENV_URL must be set before running inference.py")
    return ENV_URL.rstrip("/")
|
| 62 |
-
|
| 63 |
-
|
| 64 |
-
def _post(path: str, payload: dict | None = None) -> dict:
    """POST ``payload`` as JSON to ``path`` on the ACRE server.

    A missing or empty payload is sent as an empty JSON object. The request
    uses a 30 second timeout; HTTP error statuses raise through
    ``raise_for_status()``.

    Returns:
        The decoded JSON body of the response.
    """
    url = f"{_env_url()}{path}"
    body = payload or {}
    response = requests.post(url, json=body, timeout=30)
    response.raise_for_status()
    return response.json()
|
| 68 |
-
|
| 69 |
-
|
| 70 |
-
def _get(path: str) -> dict:
    """GET ``path`` from the ACRE server and return the decoded JSON body.

    The request uses a 30 second timeout; HTTP error statuses raise through
    ``raise_for_status()``.
    """
    url = f"{_env_url()}{path}"
    response = requests.get(url, timeout=30)
    response.raise_for_status()
    return response.json()
|
| 74 |
-
|
| 75 |
-
|
| 76 |
-
def reset_env(task_id: str) -> dict:
    """Reset the remote environment for ``task_id`` via ``POST /reset``."""
    request_body = {"task_id": task_id}
    return _post("/reset", request_body)
|
| 78 |
-
|
| 79 |
-
|
| 80 |
-
def step_env(action: int) -> dict:
    """Apply ``action`` to the remote environment via ``POST /step``."""
    request_body = {"action": action}
    return _post("/step", request_body)
|
| 82 |
-
|
| 83 |
-
|
| 84 |
-
def get_state() -> dict:
    """Fetch the current environment state snapshot via ``GET /state``."""
    snapshot = _get("/state")
    return snapshot
|
| 86 |
-
|
| 87 |
-
|
| 88 |
-
def grade(task_id: str, code: str) -> float:
    """Ask the server to grade ``code`` for ``task_id`` and return the score.

    Sends ``{"code": code}`` to ``POST /tasks/{task_id}/grade`` through the
    shared ``_post`` helper, so the URL building, timeout and HTTP error
    handling match every other request in this module.

    Returns:
        The ``score`` field of the reply as a float, or ``0.0`` when the
        field is missing or null.
    """
    result = _post(f"/tasks/{task_id}/grade", {"code": code})
    score = result.get("score")
    # A null score would make float() raise; treat it like a missing score.
    return float(score) if score is not None else 0.0
|
| 96 |
-
|
| 97 |
-
|
| 98 |
-
def choose_action(client: OpenAI, state: dict, task_id: str) -> Tuple[int, str]:
    """Choose the next refactoring action for ``task_id``.

    When ``HF_TOKEN`` is set the chat model is asked for a JSON object of the
    form ``{"action": <0-4>, "reason": "..."}``. Whenever the token is missing,
    the reply is unusable, or the request raises, a rule-based heuristic over
    the current code decides instead.

    Returns:
        ``(action, reason)`` where ``action`` is an int in ``0..4``.
    """

    def heuristic_action() -> Tuple[int, str]:
        """Return the first matching ``(action, reason)`` rule for the task."""
        source = str(state.get("current_code", ""))
        steps_taken = int(state.get("episode_steps", 0))

        generic = re.search(r"\b(x|tmp|i)\b", source) is not None
        if_false = re.search(r"\bif\s+False\b", source) is not None
        if_true = re.search(r"\bif\s+True\b", source) is not None
        append_loop = ".append(" in source and "for " in source
        double_not = "not not" in source
        add_call = "add(" in source

        dead_code = if_false or "unused" in source
        trivial_condition = if_true or double_not

        # Each table is ordered by priority; the last row always matches.
        if task_id == "rename_variables":
            rules = [
                (generic, 0, "heuristic: remove generic names first"),
                (dead_code, 1, "heuristic: remove dead code"),
                (append_loop, 2, "heuristic: simplify loop"),
                (trivial_condition, 3, "heuristic: optimize conditions"),
                (True, 4, "heuristic: inline simple function"),
            ]
        elif task_id == "remove_dead_code":
            rules = [
                (dead_code, 1, "heuristic: remove dead code patterns"),
                (append_loop, 2, "heuristic: convert append-loop"),
                (trivial_condition, 3, "heuristic: simplify conditions"),
                (generic, 0, "heuristic: clean generic names"),
                (True, 4, "heuristic: inline helper"),
            ]
        else:
            rules = [
                (generic, 0, "heuristic: rename generic variables"),
                (append_loop, 2, "heuristic: simplify loop into listcomp"),
                (if_false or trivial_condition, 3, "heuristic: optimize boolean branches"),
                (add_call, 4, "heuristic: inline add() call"),
                (steps_taken >= 2, 1, "heuristic: remove remaining dead code"),
                (True, 3, "heuristic: condition optimization as safe default"),
            ]
        return next((action, why) for matched, action, why in rules if matched)

    if not HF_TOKEN:
        return heuristic_action()

    steps_left = state.get('max_steps', 5) - state.get('episode_steps', 0)
    user_prompt = (
        f"Task: {task_id}\n"
        f"Steps remaining: {steps_left}\n"
        f"Complexity: {state.get('complexity', 0)}\n\n"
        f"Current code:\n```python\n{state.get('current_code', '')}\n```\n\n"
        "Choose the best action."
    )
    messages = [
        {"role": "system", "content": SYSTEM_PROMPT},
        {"role": "user", "content": user_prompt},
    ]

    try:
        completion = client.chat.completions.create(
            model=MODEL_NAME,
            messages=messages,
            temperature=0.0,
            max_tokens=120,
        )
        reply = (completion.choices[0].message.content or "").strip()

        if "{" not in reply or "}" not in reply:
            return heuristic_action()

        # Keep only the outermost {...} span so surrounding chatter is ignored.
        found = re.search(r"\{.*\}", reply, flags=re.DOTALL)
        parsed = json.loads(found.group(0) if found else reply)

        action = int(parsed.get("action", -1))
        reason = str(parsed.get("reason", ""))
        if not 0 <= action <= 4:
            return heuristic_action()
        return action, reason or "llm-selected action"
    except Exception:
        # Best effort: any API or parsing failure falls back to the heuristic.
        return heuristic_action()
|
| 185 |
-
|
| 186 |
-
|
| 187 |
-
def run_episode(client: OpenAI, task_id: str, episode_num: int) -> float:
    """Run one episode of ``task_id`` and return the grader's task score.

    Resets the remote environment, then repeatedly chooses and applies an
    action until the server reports the episode finished or the step budget is
    used up. One JSON line is printed to stdout per event: ``START``, one
    ``STEP`` per action, and ``END``.

    The step budget is the ``max_steps`` value reported by the server state
    (5 when absent) rather than a constant baked into the client.

    Returns:
        The score returned by ``grade`` for the final code.
    """
    reset_env(task_id)
    state = get_state()

    print(
        json.dumps(
            {
                "event": "START",
                "episode": episode_num,
                "task_id": task_id,
                "initial_complexity": state.get("complexity", 0),
                "initial_code_length": len(state.get("current_code", "")),
                "timestamp": time.time(),
            }
        ),
        flush=True,
    )

    cumulative_reward = 0.0
    max_steps = int(state.get("max_steps") or 5)

    for step_num in range(1, max_steps + 1):
        action, reason = choose_action(client, state, task_id)
        result = step_env(action)
        state = get_state()

        # Tolerate explicit nulls as well as missing keys in the reply.
        reward_payload = result.get("reward") or {}
        info = result.get("info") or {}

        raw_reward = float(reward_payload.get("raw", 0.0))
        # Fall back to the (raw + 32) / 52 mapping when the server omits it.
        norm_reward = float(reward_payload.get("normalized", (raw_reward + 32) / 52))
        cumulative_reward += raw_reward

        print(
            json.dumps(
                {
                    "event": "STEP",
                    "episode": episode_num,
                    "step": step_num,
                    "action": action,
                    "action_name": ACTION_MEANINGS.get(action, "unknown"),
                    "reason": reason,
                    "reward": round(raw_reward, 4),
                    "normalized_reward": round(norm_reward, 4),
                    "cumulative_reward": round(cumulative_reward, 4),
                    "changed": info.get("changed", False),
                    "reward_components": reward_payload.get("components", {}),
                    "done": result.get("done", False),
                }
            ),
            flush=True,
        )

        if result.get("done") or result.get("terminated") or result.get("truncated"):
            break

    final_state = get_state()
    task_score = grade(task_id, final_state.get("current_code", ""))

    print(
        json.dumps(
            {
                "event": "END",
                "episode": episode_num,
                "task_id": task_id,
                "cumulative_reward": round(cumulative_reward, 4),
                # NOTE(review): this applies the per-step mapping to a sum over
                # several steps, so the value can leave [0, 1] - confirm intent.
                "normalized_cumulative": round((cumulative_reward + 32) / 52, 4),
                "task_score": round(task_score, 4),
                "final_complexity": final_state.get("complexity", 0),
                "timestamp": time.time(),
            }
        ),
        flush=True,
    )

    return task_score
|
| 260 |
-
|
| 261 |
-
|
| 262 |
-
def main() -> None:
    """Run one episode per entry in ``TASKS`` and exit with a pass/fail code.

    Exits with status 0 when the mean task score is at least 0.5 and with
    status 1 otherwise. Aborts immediately when ``ENV_URL`` is not set.
    """
    if not ENV_URL:
        raise SystemExit("ENV_URL is required. Example: ENV_URL=http://localhost:7860")

    # The client needs some key even when only the heuristic path will run.
    api_key = HF_TOKEN if HF_TOKEN else "dummy"
    client = OpenAI(base_url=API_BASE_URL, api_key=api_key)

    scores: List[float] = [
        run_episode(client, task_id, episode)
        for episode, task_id in enumerate(TASKS, start=1)
    ]

    passed = bool(scores) and sum(scores) / len(scores) >= 0.5
    sys.exit(0 if passed else 1)
|
| 275 |
-
|
| 276 |
-
|
| 277 |
-
if __name__ == "__main__":
|
| 278 |
-
main()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
ACRE_FINAL/models.py
DELETED
|
@@ -1,156 +0,0 @@
|
|
| 1 |
-
from __future__ import annotations
|
| 2 |
-
|
| 3 |
-
from typing import Any, Dict, List, Optional, Sequence
|
| 4 |
-
|
| 5 |
-
from pydantic import BaseModel, Field
|
| 6 |
-
|
| 7 |
-
|
| 8 |
-
class ObservationModel(BaseModel):
    """Typed view of the 4-element observation vector.

    Vector order is ``[code_length, complexity_score, runtime_s, error_flag]``.
    """

    code_length: float
    complexity_score: float
    runtime_s: float
    error_flag: bool

    @classmethod
    def from_vector(cls, values: Sequence[float]) -> "ObservationModel":
        """Build a model from a 4-element sequence.

        Raises:
            ValueError: if ``values`` does not contain exactly four items.
        """
        vector = list(values)
        if len(vector) != 4:
            raise ValueError(f"observation vector must have length 4, got {len(vector)}")
        length, complexity, runtime, errored = vector
        return cls(
            code_length=float(length),
            complexity_score=float(complexity),
            runtime_s=float(runtime),
            error_flag=bool(errored),
        )

    def to_vector(self) -> List[float]:
        """Return the observation as four floats; the flag becomes 0.0 or 1.0."""
        error_as_float = float(int(self.error_flag))
        return [
            float(self.code_length),
            float(self.complexity_score),
            float(self.runtime_s),
            error_as_float,
        ]
|
| 33 |
-
|
| 34 |
-
|
| 35 |
-
class ActionModel(BaseModel):
    """A discrete action index with an optional human-readable name."""

    # Action index, validated to the inclusive range 0..4.
    action: int = Field(ge=0, le=4)
    # Optional label for the action; None when not supplied.
    action_name: Optional[str] = None
|
| 38 |
-
|
| 39 |
-
|
| 40 |
-
class RewardModel(BaseModel):
    """Reward for one step: raw value, normalized value and its components."""

    # Unscaled reward value.
    raw: float
    # Scaled reward; validation rejects anything outside [0.0, 1.0].
    normalized: float = Field(ge=0.0, le=1.0)
    # Named reward contributions mapped to their float values.
    components: Dict[str, float]
|
| 44 |
-
|
| 45 |
-
|
| 46 |
-
class HealthResponse(BaseModel):
    """Response schema carrying a status, an environment name and a version."""

    status: str
    env: str
    version: str
|
| 50 |
-
|
| 51 |
-
|
| 52 |
-
class CompatibilityHealthResponse(BaseModel):
    """Response schema carrying a status and a service name."""

    status: str
    service: str
|
| 55 |
-
|
| 56 |
-
|
| 57 |
-
class ResetRequest(BaseModel):
    """Request body for a reset; every field is optional and defaults to None."""

    # Identifier of the task to load.
    task_id: Optional[str] = None
    # Seed forwarded to the environment reset.
    seed: Optional[int] = None
    # Explicit source code to start the episode from.
    code: Optional[str] = None
|
| 61 |
-
|
| 62 |
-
|
| 63 |
-
class StepRequest(BaseModel):
    """Request body for a step: the action index to apply."""

    # Action index, validated to the inclusive range 0..4.
    action: int = Field(ge=0, le=4)
|
| 65 |
-
|
| 66 |
-
|
| 67 |
-
class GradeRequest(BaseModel):
    """Request body for grading: the source code to be scored."""

    code: str
|
| 69 |
-
|
| 70 |
-
|
| 71 |
-
class TaskInfo(BaseModel):
    """Description of one task, including the code an episode starts from."""

    id: str
    name: str
    description: str
    difficulty: str
    # Source code the task starts with.
    initial_code: str
|
| 77 |
-
|
| 78 |
-
|
| 79 |
-
class TasksResponse(BaseModel):
    """Response schema wrapping a list of ``TaskInfo`` entries."""

    tasks: List[TaskInfo]
|
| 81 |
-
|
| 82 |
-
|
| 83 |
-
class GradeResponse(BaseModel):
    """Grading result for one task: its id, the score and a pass flag."""

    task_id: str
    score: float
    passed: bool
|
| 87 |
-
|
| 88 |
-
|
| 89 |
-
class StateResponse(BaseModel):
    """Snapshot of the environment state.

    The three ``Optional`` fields have no default, so they must be passed
    explicitly (``None`` is accepted).
    """

    current_code: str
    episode_steps: int
    max_steps: int
    complexity: float
    last_runtime: float
    last_error: bool
    sample_id: Optional[str]
    language: Optional[str]
    task_id: Optional[str]
    # The same observation in structured and in flat-vector form.
    observation: ObservationModel
    observation_vector: List[float]
    # Action index mapped to the action's name.
    action_meanings: Dict[int, str]
|
| 102 |
-
|
| 103 |
-
|
| 104 |
-
class ResetResponse(BaseModel):
    """Response schema for a reset: first observation, info and state."""

    # The same observation in structured and in flat-vector form.
    observation: ObservationModel
    observation_vector: List[float]
    info: Dict[str, Any]
    # Required field; None is accepted.
    task_id: Optional[str]
    state: StateResponse
|
| 110 |
-
|
| 111 |
-
|
| 112 |
-
class StepResponse(BaseModel):
    """Response schema for a step: action, observation, reward, flags, state."""

    action: ActionModel
    # The same observation in structured and in flat-vector form.
    observation: ObservationModel
    observation_vector: List[float]
    reward: RewardModel
    done: bool
    terminated: bool
    truncated: bool
    info: Dict[str, Any]
    state: StateResponse
|
| 122 |
-
|
| 123 |
-
|
| 124 |
-
class OptimizeRequest(BaseModel):
    """Request body for an optimization run over caller-supplied code."""

    code: str
    task_id: Optional[str] = None
    # Step budget; defaults to 5 and is validated to the range 1..5.
    max_steps: int = Field(default=5, ge=1, le=5)
    # Action-source switches and their defaults.
    use_rl: bool = True
    use_llm: bool = False
    fallback_to_llm: bool = True
    # Optional per-request settings; None when not supplied.
    rl_model_path: Optional[str] = None
    api_base_url: Optional[str] = None
    model_name: Optional[str] = None
    api_token: Optional[str] = None
|
| 135 |
-
|
| 136 |
-
|
| 137 |
-
class OptimizationStep(BaseModel):
    """Record of one step taken during an optimization run."""

    step: int
    action: int
    action_name: str
    reason: str
    source: str
    reward: float
    normalized_reward: float
    changed: bool
    complexity: float
|
| 147 |
-
|
| 148 |
-
|
| 149 |
-
class OptimizeResponse(BaseModel):
    """Result of an optimization run: code before and after, diff and steps."""

    original_code: str
    optimized_code: str
    diff: str
    steps: List[OptimizationStep]
    cumulative_reward: float
    # Required fields; None is accepted for both.
    task_id: Optional[str]
    task_score: Optional[float]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
ACRE_FINAL/openenv.yaml
DELETED
|
@@ -1,85 +0,0 @@
|
|
| 1 |
-
name: ACRE
|
| 2 |
-
version: "1.0.0"
|
| 3 |
-
description: >
|
| 4 |
-
Autonomous Code Refactoring Environment - an RL environment where an
|
| 5 |
-
agent improves Python code quality using AST-level transformations.
|
| 6 |
-
author: "Nikhil Pratap Singh, Pranav Mangal, Ananya Gupta"
|
| 7 |
-
entrypoint: "openenv_interface:OpenEnvRefactorEnv"
|
| 8 |
-
tags:
|
| 9 |
-
- openenv
|
| 10 |
-
|
| 11 |
-
tasks:
|
| 12 |
-
- id: rename_variables
|
| 13 |
-
name: "Rename Variables (Easy)"
|
| 14 |
-
description: "Rename generic variable names (x, tmp) to descriptive ones"
|
| 15 |
-
difficulty: easy
|
| 16 |
-
reward_range: [0.0, 1.0]
|
| 17 |
-
max_steps: 5
|
| 18 |
-
|
| 19 |
-
- id: remove_dead_code
|
| 20 |
-
name: "Remove Dead Code (Medium)"
|
| 21 |
-
description: "Remove unreachable statements, if-False blocks, and unused assignments"
|
| 22 |
-
difficulty: medium
|
| 23 |
-
reward_range: [0.0, 1.0]
|
| 24 |
-
max_steps: 5
|
| 25 |
-
|
| 26 |
-
- id: full_refactor
|
| 27 |
-
name: "Full Refactor (Hard)"
|
| 28 |
-
description: "Apply all transformations - rename, dead code removal, loop simplification, condition optimization, and function inlining"
|
| 29 |
-
difficulty: hard
|
| 30 |
-
reward_range: [0.0, 1.0]
|
| 31 |
-
max_steps: 5
|
| 32 |
-
|
| 33 |
-
observation_space:
|
| 34 |
-
type: Box
|
| 35 |
-
shape: [4]
|
| 36 |
-
dtype: float32
|
| 37 |
-
low: [0.0, 0.0, 0.0, 0.0]
|
| 38 |
-
high: [inf, inf, inf, 1.0]
|
| 39 |
-
fields:
|
| 40 |
-
- code_length
|
| 41 |
-
- complexity_score
|
| 42 |
-
- runtime_s
|
| 43 |
-
- error_flag
|
| 44 |
-
|
| 45 |
-
action_space:
|
| 46 |
-
type: Discrete
|
| 47 |
-
n: 5
|
| 48 |
-
actions:
|
| 49 |
-
0: rename_variable
|
| 50 |
-
1: remove_dead_code
|
| 51 |
-
2: simplify_loop
|
| 52 |
-
3: optimize_condition
|
| 53 |
-
4: inline_function
|
| 54 |
-
|
| 55 |
-
api:
|
| 56 |
-
health: "GET /"
|
| 57 |
-
reset: "POST /reset"
|
| 58 |
-
step: "POST /step"
|
| 59 |
-
state: "GET /state"
|
| 60 |
-
tasks: "GET /tasks"
|
| 61 |
-
grade: "POST /tasks/{task_id}/grade"
|
| 62 |
-
|
| 63 |
-
reward:
|
| 64 |
-
raw_range: [-32, 20]
|
| 65 |
-
normalized_range: [0.0, 1.0]
|
| 66 |
-
formula: "(raw + 32) / 52"
|
| 67 |
-
components:
|
| 68 |
-
success: { max: 10, min: -10 }
|
| 69 |
-
complexity: { max: 5, min: -5 }
|
| 70 |
-
performance: { max: 5, min: -2 }
|
| 71 |
-
error: { max: 0, min: -15 }
|
| 72 |
-
no_change: { max: 0, min: -2 }
|
| 73 |
-
|
| 74 |
-
validation:
|
| 75 |
-
python_api:
|
| 76 |
-
reset: "ObservationModel"
|
| 77 |
-
step: "(ObservationModel, RewardModel, done, info)"
|
| 78 |
-
state: "StateResponse"
|
| 79 |
-
http_api:
|
| 80 |
-
health: "GET /"
|
| 81 |
-
reset: "POST /reset"
|
| 82 |
-
step: "POST /step"
|
| 83 |
-
state: "GET /state"
|
| 84 |
-
tasks: "GET /tasks"
|
| 85 |
-
grade: "POST /tasks/{task_id}/grade"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
ACRE_FINAL/openenv_interface.py
DELETED
|
@@ -1,116 +0,0 @@
|
|
| 1 |
-
from __future__ import annotations
|
| 2 |
-
|
| 3 |
-
from typing import Any, Dict, Optional, Tuple
|
| 4 |
-
|
| 5 |
-
try:
    from openenv.env import Env as OpenEnvBase
except Exception:  # pragma: no cover
    # Fallback base class used when the openenv import fails for any reason.
    # It accepts and ignores all constructor arguments, so subclasses can call
    # super().__init__(...) with the same keywords either way.
    class OpenEnvBase:
        def __init__(self, *args: Any, **kwargs: Any) -> None:
            return None
|
| 11 |
-
|
| 12 |
-
from acre.datasets.code_samples import CodeSample, CodeSampleDataset
|
| 13 |
-
from acre.env.refactor_env import RefactorEnv
|
| 14 |
-
from acre.tasks.task_registry import TaskRegistry
|
| 15 |
-
from models import ActionModel, ObservationModel, RewardModel, StateResponse
|
| 16 |
-
|
| 17 |
-
|
| 18 |
-
class OpenEnvRefactorEnv(OpenEnvBase):
    """
    Canonical OpenEnv interface for ACRE.

    This wrapper keeps the strict hackathon contract:
      - reset() -> ObservationModel
      - step(action) -> (ObservationModel, RewardModel, done, info)
      - state() -> StateResponse
    """

    def __init__(
        self,
        *,
        env: Optional[RefactorEnv] = None,
        registry: Optional[TaskRegistry] = None,
    ) -> None:
        """Wrap ``env`` and ``registry``, creating defaults for omitted ones."""
        super().__init__(
            name="ACRE",
            state_space="ObservationModel",
            action_space="ActionModel",
            episode_max_length=RefactorEnv.MAX_STEPS,
        )
        self._env = env or RefactorEnv()
        self._registry = registry or TaskRegistry()
        self._task_id: Optional[str] = None
        self._last_reset_info: Dict[str, Any] = {}

    @property
    def action_meanings(self) -> Dict[int, str]:
        """Action index to action name mapping of the wrapped environment."""
        return self._env.ACTION_MEANINGS

    @property
    def last_reset_info(self) -> Dict[str, Any]:
        """Copy of the info dict returned by the most recent reset."""
        return dict(self._last_reset_info)

    def _load_episode_source(self, *, task_id: Optional[str], code: Optional[str]) -> None:
        """Point the wrapped env at a one-sample dataset for the next episode.

        Explicit ``code`` wins; otherwise the task's initial code is used.
        When neither is available the env's dataset is left untouched.

        Raises:
            ValueError: if ``task_id`` is given but unknown to the registry.
        """
        initial_code = code
        if initial_code is None and task_id:
            task = self._registry.get_task(task_id)
            if task is None:
                raise ValueError(f"Task '{task_id}' not found")
            initial_code = task.initial_code

        if initial_code is None:
            return None

        # NOTE(review): this replacement is never undone here, so a later
        # reset() without task_id/code reuses this sample - confirm intended.
        self._env.dataset = CodeSampleDataset(
            [
                CodeSample(
                    id=task_id or "custom",
                    language="python",
                    code=initial_code,
                )
            ]
        )
        return None

    def reset(
        self,
        *,
        seed: Optional[int] = None,
        task_id: Optional[str] = None,
        code: Optional[str] = None,
    ) -> ObservationModel:
        """Start a new episode and return its first observation.

        The info dict from the wrapped env is kept for ``last_reset_info``.
        """
        self._task_id = task_id
        self._load_episode_source(task_id=task_id, code=code)
        observation, info = self._env.reset(seed=seed)
        self._last_reset_info = dict(info)
        return ObservationModel.from_vector(observation.tolist())

    def step(self, action: int | ActionModel) -> Tuple[ObservationModel, RewardModel, bool, Dict[str, Any]]:
        """Apply ``action`` and return ``(observation, reward, done, info)``.

        ``done`` is true when the wrapped env reports either termination or
        truncation. If ``info`` carries no ``normalized_reward`` the value is
        derived from the raw reward as ``(raw + 32) / 52``. It is then clamped
        to ``[0, 1]`` so a slight overshoot cannot fail ``RewardModel``
        validation.
        """
        action_value = action.action if isinstance(action, ActionModel) else int(action)
        observation, raw_reward, terminated, truncated, info = self._env.step(action_value)

        raw = float(raw_reward)
        normalized = info.get("normalized_reward")
        if normalized is None:
            normalized = (raw + 32.0) / 52.0
        normalized = min(1.0, max(0.0, float(normalized)))

        reward = RewardModel(
            raw=raw,
            normalized=normalized,
            components=dict(info.get("reward_components", {})),
        )
        done = bool(terminated or truncated)
        return ObservationModel.from_vector(observation.tolist()), reward, done, dict(info)

    def state(self) -> StateResponse:
        """Return a typed snapshot of the wrapped env's current state.

        Keys missing from the raw state fall back to neutral defaults.
        """
        raw_state = self._env.state()
        observation_vector = list(raw_state.get("observation", [0.0, 0.0, 0.0, 0.0]))
        observation = ObservationModel.from_vector(observation_vector)
        return StateResponse(
            current_code=str(raw_state.get("current_code", "")),
            episode_steps=int(raw_state.get("episode_steps", 0)),
            max_steps=int(raw_state.get("max_steps", RefactorEnv.MAX_STEPS)),
            complexity=float(raw_state.get("complexity", 0.0)),
            last_runtime=float(raw_state.get("last_runtime", 0.0)),
            last_error=bool(raw_state.get("last_error", False)),
            sample_id=raw_state.get("sample_id"),
            language=raw_state.get("language"),
            task_id=self._task_id,
            observation=observation,
            observation_vector=observation.to_vector(),
            action_meanings=dict(raw_state.get("action_meanings", {})),
        )
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
ACRE_FINAL/requirements.txt
DELETED
|
@@ -1,11 +0,0 @@
|
|
| 1 |
-
fastapi>=0.109.0
|
| 2 |
-
uvicorn[standard]>=0.27.0
|
| 3 |
-
numpy>=1.26
|
| 4 |
-
gymnasium
|
| 5 |
-
stable-baselines3
|
| 6 |
-
radon>=6.0.1
|
| 7 |
-
openai>=1.0.0
|
| 8 |
-
openenv>=0.1.13
|
| 9 |
-
requests>=2.31.0
|
| 10 |
-
pydantic>=2.0.0
|
| 11 |
-
typing_extensions>=4.0.0
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
ACRE_FINAL/server.py
DELETED
|
@@ -1,667 +0,0 @@
|
|
| 1 |
-
"""
|
| 2 |
-
ACRE OpenEnv HTTP server.
|
| 3 |
-
|
| 4 |
-
Endpoints (all required by OpenEnv spec):
|
| 5 |
-
GET / — health check (must return HTTP 200)
|
| 6 |
-
POST /reset — reset environment, returns observation + info
|
| 7 |
-
POST /step — take one step, returns obs/reward/done/info
|
| 8 |
-
GET /state — full current state snapshot
|
| 9 |
-
GET /tasks — list all tasks with initial code
|
| 10 |
-
POST /tasks/{task_id}/grade — grade code for a specific task
|
| 11 |
-
"""
|
| 12 |
-
from __future__ import annotations
|
| 13 |
-
|
| 14 |
-
import difflib
|
| 15 |
-
import os
|
| 16 |
-
import re
|
| 17 |
-
import json
|
| 18 |
-
from typing import Optional
|
| 19 |
-
|
| 20 |
-
import uvicorn
|
| 21 |
-
import numpy as np
|
| 22 |
-
from fastapi import FastAPI, HTTPException
|
| 23 |
-
from fastapi.middleware.cors import CORSMiddleware
|
| 24 |
-
from fastapi.responses import HTMLResponse
|
| 25 |
-
from openai import OpenAI
|
| 26 |
-
|
| 27 |
-
try:
|
| 28 |
-
from stable_baselines3 import PPO
|
| 29 |
-
except Exception:
|
| 30 |
-
PPO = None # type: ignore[assignment]
|
| 31 |
-
|
| 32 |
-
from acre.tasks.task_registry import TaskRegistry
|
| 33 |
-
from models import (
|
| 34 |
-
ActionModel,
|
| 35 |
-
CompatibilityHealthResponse,
|
| 36 |
-
GradeRequest,
|
| 37 |
-
GradeResponse,
|
| 38 |
-
HealthResponse,
|
| 39 |
-
OptimizationStep,
|
| 40 |
-
OptimizeRequest,
|
| 41 |
-
OptimizeResponse,
|
| 42 |
-
ResetRequest,
|
| 43 |
-
ResetResponse,
|
| 44 |
-
StateResponse,
|
| 45 |
-
StepRequest,
|
| 46 |
-
StepResponse,
|
| 47 |
-
TaskInfo,
|
| 48 |
-
TasksResponse,
|
| 49 |
-
)
|
| 50 |
-
from openenv_interface import OpenEnvRefactorEnv
|
| 51 |
-
|
| 52 |
-
DEFAULT_API_BASE_URL = os.getenv("API_BASE_URL", "https://api.openai.com/v1")
|
| 53 |
-
DEFAULT_MODEL_NAME = os.getenv("MODEL_NAME", "gpt-4o-mini")
|
| 54 |
-
DEFAULT_RL_MODEL_PATH = os.getenv("RL_MODEL_PATH", "acre_agent.zip")
|
| 55 |
-
|
| 56 |
-
# ---------------------------------------------------------------------------
|
| 57 |
-
# App setup
|
| 58 |
-
# ---------------------------------------------------------------------------
|
| 59 |
-
|
| 60 |
-
app = FastAPI(
|
| 61 |
-
title="ACRE — Autonomous Code Refactoring Environment",
|
| 62 |
-
description="OpenEnv-compatible RL environment for Python code refactoring.",
|
| 63 |
-
version="1.0.0",
|
| 64 |
-
)
|
| 65 |
-
|
| 66 |
-
app.add_middleware(
|
| 67 |
-
CORSMiddleware,
|
| 68 |
-
allow_origins=["*"],
|
| 69 |
-
allow_methods=["*"],
|
| 70 |
-
allow_headers=["*"],
|
| 71 |
-
)
|
| 72 |
-
|
| 73 |
-
# Global singletons
|
| 74 |
-
registry = TaskRegistry()
|
| 75 |
-
_env: Optional[OpenEnvRefactorEnv] = None
|
| 76 |
-
_rl_model_cache: dict[str, object] = {}
|
| 77 |
-
|
| 78 |
-
|
| 79 |
-
def get_env() -> OpenEnvRefactorEnv:
    """Return the process-wide environment, creating it on first use."""
    global _env
    if _env is not None:
        return _env
    _env = OpenEnvRefactorEnv(registry=registry)
    return _env
|
| 84 |
-
|
| 85 |
-
|
| 86 |
-
def _state_response() -> StateResponse:
    """Return the current state snapshot of the shared environment."""
    env = get_env()
    return env.state()
|
| 88 |
-
|
| 89 |
-
|
| 90 |
-
def _choose_action_heuristic(code: str, task_id: Optional[str]) -> int:
    """Pick a refactoring action index (0..4) from surface patterns in ``code``.

    Each task has an ordered priority list of ``(pattern present, action)``
    pairs; the first pair whose pattern is present wins, otherwise the task's
    fallback action is returned. Any ``task_id`` other than
    ``"rename_variables"`` or ``"remove_dead_code"`` (including ``None``) uses
    the general ordering.

    Actions: 0 rename_variable, 1 remove_dead_code, 2 simplify_loop,
    3 optimize_condition, 4 inline_function.
    """
    has_generic = re.search(r"\b(x|tmp|i)\b", code) is not None
    has_if_false = re.search(r"\bif\s+False\b", code) is not None
    has_if_true = re.search(r"\bif\s+True\b", code) is not None
    has_append_loop = ".append(" in code and "for " in code
    has_double_not = "not not" in code
    has_add_call = "add(" in code

    # Derived flags shared by several orderings.
    has_dead_code = has_if_false or "unused" in code
    has_trivial_condition = has_if_true or has_double_not

    if task_id == "rename_variables":
        priorities = (
            (has_generic, 0),
            (has_dead_code, 1),
            (has_append_loop, 2),
            (has_trivial_condition, 3),
        )
        fallback = 4
    elif task_id == "remove_dead_code":
        priorities = (
            (has_dead_code, 1),
            (has_append_loop, 2),
            (has_trivial_condition, 3),
            (has_generic, 0),
        )
        fallback = 4
    else:
        priorities = (
            (has_generic, 0),
            (has_append_loop, 2),
            (has_if_false or has_trivial_condition, 3),
            (has_add_call, 4),
        )
        fallback = 1

    return next((action for present, action in priorities if present), fallback)
|
| 129 |
-
|
| 130 |
-
|
| 131 |
-
def _choose_action_llm(
    *,
    code: str,
    task_id: Optional[str],
    step_index: int,
    max_steps: int,
    api_base_url: str,
    model_name: str,
    api_token: str,
) -> tuple[int, str, str]:
    """Ask a chat model for the next action, falling back to the heuristic.

    Returns:
        ``(action, reason, source)`` where ``source`` is ``"llm"`` when the
        model's answer was used and ``"heuristic"`` otherwise. The heuristic is
        used when the token is blank, the request raises, or the reply does not
        contain a JSON object with an ``action`` in ``0..4``.
    """
    token = api_token.strip()
    if not token:
        return _choose_action_heuristic(code, task_id), "empty token -> heuristic", "heuristic"

    # Pass the stripped token: stray whitespace or a newline from a pasted key
    # would otherwise go out in the Authorization header.
    client = OpenAI(base_url=api_base_url, api_key=token)
    messages = [
        {
            "role": "system",
            "content": (
                "You are a code-refactoring action selector. Return ONLY compact JSON: "
                '{"action": <0-4>, "reason": "..."}.\n'
                "Actions: 0=rename_variable,1=remove_dead_code,2=simplify_loop,3=optimize_condition,4=inline_function"
            ),
        },
        {
            "role": "user",
            "content": (
                f"task_id={task_id or 'auto'}\n"
                f"step={step_index}/{max_steps}\n"
                "Current code:\n"
                f"```python\n{code}\n```"
            ),
        },
    ]
    try:
        resp = client.chat.completions.create(
            model=model_name,
            messages=messages,
            temperature=0.0,
            max_tokens=120,
        )
        raw = (resp.choices[0].message.content or "").strip()
        # Keep only the outermost {...} span so surrounding chatter is ignored.
        m = re.search(r"\{.*\}", raw, flags=re.DOTALL)
        blob = m.group(0) if m else raw
        parsed = json.loads(blob)
        # Valid JSON that is not an object (list, number, ...) is a malformed
        # answer, not a transport error; let it reach the final fallback.
        if isinstance(parsed, dict):
            action = int(parsed.get("action", -1))
            reason = str(parsed.get("reason", "llm-selected action"))
            if 0 <= action <= 4:
                return action, reason, "llm"
    except Exception as exc:
        # Best effort: report the failure in the reason and keep going.
        return _choose_action_heuristic(code, task_id), f"llm error -> heuristic: {exc}", "heuristic"

    return _choose_action_heuristic(code, task_id), "invalid llm output -> heuristic", "heuristic"
|
| 183 |
-
|
| 184 |
-
|
| 185 |
-
def _choose_action_rl(observation: list[float], model_path: str) -> tuple[Optional[int], str, str]:
    """Ask the cached PPO policy for the next refactoring action.

    Args:
        observation: Observation vector; converted to a float32 array before
            being handed to the policy.
        model_path: Path of the saved policy; also the key used in
            ``_rl_model_cache`` so each file is loaded at most once.

    Returns:
        ``(action, reason, "rl")``.  ``action`` is an int in ``0..4`` on
        success and ``None`` whenever the policy cannot be used (PPO not
        importable, model file missing, prediction outside ``0..4``, or any
        exception while loading/predicting); ``reason`` explains which.
    """
    origin = "rl"

    # Guard clauses: nothing to do without the library or the model file.
    if PPO is None:
        return None, "stable-baselines3 unavailable", origin
    if not os.path.exists(model_path):
        return None, f"rl model not found: {model_path}", origin

    try:
        policy = _rl_model_cache.get(model_path)
        if policy is None:
            policy = PPO.load(model_path)
            _rl_model_cache[model_path] = policy

        features = np.asarray(observation, dtype=np.float32)
        predicted, _ = policy.predict(features, deterministic=True)
        chosen = int(predicted)
    except Exception as exc:
        # Best-effort: any load/predict failure is reported, never raised.
        return None, f"rl failure: {exc}", origin

    if 0 <= chosen <= 4:
        return chosen, "rl policy action", origin
    return None, f"invalid rl action: {chosen}", origin
|
| 205 |
-
|
| 206 |
-
|
| 207 |
-
def _demo_html() -> str:
|
| 208 |
-
return """<!doctype html>
|
| 209 |
-
<html lang=\"en\">
|
| 210 |
-
<head>
|
| 211 |
-
<meta charset=\"utf-8\" />
|
| 212 |
-
<meta name=\"viewport\" content=\"width=device-width, initial-scale=1\" />
|
| 213 |
-
<title>ACRE Refactor Demo</title>
|
| 214 |
-
<style>
|
| 215 |
-
@import url('https://fonts.googleapis.com/css2?family=Space+Grotesk:wght@400;600;700&display=swap');
|
| 216 |
-
:root {
|
| 217 |
-
--bg0: #0b1f2a;
|
| 218 |
-
--bg1: #14344a;
|
| 219 |
-
--ink: #eaf7ff;
|
| 220 |
-
--muted: #a7c8db;
|
| 221 |
-
--brand: #1ec28b;
|
| 222 |
-
--warn: #ffcb47;
|
| 223 |
-
--panel: rgba(8, 24, 36, 0.72);
|
| 224 |
-
--stroke: rgba(140, 197, 225, 0.35);
|
| 225 |
-
}
|
| 226 |
-
* { box-sizing: border-box; }
|
| 227 |
-
body {
|
| 228 |
-
margin: 0;
|
| 229 |
-
color: var(--ink);
|
| 230 |
-
font-family: 'Space Grotesk', sans-serif;
|
| 231 |
-
background:
|
| 232 |
-
radial-gradient(circle at 12% 18%, rgba(30, 194, 139, 0.28), transparent 35%),
|
| 233 |
-
radial-gradient(circle at 88% 8%, rgba(255, 203, 71, 0.22), transparent 30%),
|
| 234 |
-
linear-gradient(150deg, var(--bg0), var(--bg1));
|
| 235 |
-
min-height: 100vh;
|
| 236 |
-
}
|
| 237 |
-
.wrap {
|
| 238 |
-
max-width: 1200px;
|
| 239 |
-
margin: 0 auto;
|
| 240 |
-
padding: 28px 20px 40px;
|
| 241 |
-
}
|
| 242 |
-
h1 {
|
| 243 |
-
margin: 0 0 6px;
|
| 244 |
-
font-size: clamp(1.6rem, 2vw + 1rem, 2.6rem);
|
| 245 |
-
letter-spacing: 0.2px;
|
| 246 |
-
}
|
| 247 |
-
.sub { margin: 0 0 20px; color: var(--muted); }
|
| 248 |
-
.grid {
|
| 249 |
-
display: grid;
|
| 250 |
-
grid-template-columns: 1fr;
|
| 251 |
-
gap: 16px;
|
| 252 |
-
}
|
| 253 |
-
.panel {
|
| 254 |
-
border: 1px solid var(--stroke);
|
| 255 |
-
border-radius: 14px;
|
| 256 |
-
background: var(--panel);
|
| 257 |
-
backdrop-filter: blur(4px);
|
| 258 |
-
padding: 14px;
|
| 259 |
-
}
|
| 260 |
-
.controls {
|
| 261 |
-
display: grid;
|
| 262 |
-
grid-template-columns: 1fr 1fr;
|
| 263 |
-
gap: 8px;
|
| 264 |
-
margin-bottom: 10px;
|
| 265 |
-
}
|
| 266 |
-
textarea, pre {
|
| 267 |
-
width: 100%;
|
| 268 |
-
min-height: 260px;
|
| 269 |
-
border: 1px solid var(--stroke);
|
| 270 |
-
border-radius: 10px;
|
| 271 |
-
padding: 12px;
|
| 272 |
-
background: rgba(1, 13, 24, 0.82);
|
| 273 |
-
color: #dcf4ff;
|
| 274 |
-
font-family: Consolas, 'Courier New', monospace;
|
| 275 |
-
font-size: 13px;
|
| 276 |
-
line-height: 1.4;
|
| 277 |
-
overflow: auto;
|
| 278 |
-
white-space: pre;
|
| 279 |
-
}
|
| 280 |
-
button, select {
|
| 281 |
-
border: 1px solid var(--stroke);
|
| 282 |
-
border-radius: 10px;
|
| 283 |
-
padding: 10px 12px;
|
| 284 |
-
background: rgba(11, 36, 52, 0.9);
|
| 285 |
-
color: var(--ink);
|
| 286 |
-
font-weight: 600;
|
| 287 |
-
}
|
| 288 |
-
button.primary {
|
| 289 |
-
background: linear-gradient(120deg, #19a7ff, #1ec28b);
|
| 290 |
-
color: #032235;
|
| 291 |
-
border: none;
|
| 292 |
-
}
|
| 293 |
-
.cols {
|
| 294 |
-
display: grid;
|
| 295 |
-
grid-template-columns: 1fr;
|
| 296 |
-
gap: 14px;
|
| 297 |
-
}
|
| 298 |
-
.meta {
|
| 299 |
-
color: var(--muted);
|
| 300 |
-
font-size: 0.92rem;
|
| 301 |
-
margin-top: 8px;
|
| 302 |
-
}
|
| 303 |
-
.badge {
|
| 304 |
-
color: #082b22;
|
| 305 |
-
background: var(--brand);
|
| 306 |
-
border-radius: 999px;
|
| 307 |
-
padding: 2px 9px;
|
| 308 |
-
font-size: 12px;
|
| 309 |
-
font-weight: 700;
|
| 310 |
-
}
|
| 311 |
-
.warn {
|
| 312 |
-
color: #2a1c00;
|
| 313 |
-
background: var(--warn);
|
| 314 |
-
}
|
| 315 |
-
@media (min-width: 900px) {
|
| 316 |
-
.cols { grid-template-columns: 1fr 1fr; }
|
| 317 |
-
}
|
| 318 |
-
</style>
|
| 319 |
-
</head>
|
| 320 |
-
<body>
|
| 321 |
-
<div class=\"wrap\">
|
| 322 |
-
<h1>ACRE Live Refactor Arena</h1>
|
| 323 |
-
<p class=\"sub\">Paste old code, run the agent, and compare before and after with a full diff and step-by-step rewards.</p>
|
| 324 |
-
|
| 325 |
-
<div class=\"panel\">
|
| 326 |
-
<div class=\"controls\">
|
| 327 |
-
<button onclick=\"loadExample(1)\">Load Example 1</button>
|
| 328 |
-
<button onclick=\"loadExample(2)\">Load Example 2</button>
|
| 329 |
-
<select id=\"task\">
|
| 330 |
-
<option value=\"\">Auto strategy</option>
|
| 331 |
-
<option value=\"rename_variables\">rename_variables</option>
|
| 332 |
-
<option value=\"remove_dead_code\">remove_dead_code</option>
|
| 333 |
-
<option value=\"full_refactor\">full_refactor</option>
|
| 334 |
-
</select>
|
| 335 |
-
<button class=\"primary\" onclick=\"runOptimize()\">Run Optimization</button>
|
| 336 |
-
</div>
|
| 337 |
-
<div class=\"controls\" style=\"margin-bottom: 10px;\">
|
| 338 |
-
<select id=\"mode\">
|
| 339 |
-
<option value=\"rl_then_llm\">RL First -> LLM Fallback</option>
|
| 340 |
-
<option value=\"heuristic\">Heuristic Agent (no API key)</option>
|
| 341 |
-
<option value=\"llm\">LLM Agent (OpenAI-compatible API)</option>
|
| 342 |
-
</select>
|
| 343 |
-
<input id=\"rlModelPath\" placeholder=\"RL model path\" value=\"acre_agent.zip\" style=\"border:1px solid var(--stroke);border-radius:10px;padding:10px 12px;background:rgba(1,13,24,0.82);color:#dcf4ff;\" />
|
| 344 |
-
<input id=\"baseUrl\" placeholder=\"API base URL (optional)\" value=\"https://api.openai.com/v1\" style=\"border:1px solid var(--stroke);border-radius:10px;padding:10px 12px;background:rgba(1,13,24,0.82);color:#dcf4ff;\" />
|
| 345 |
-
<input id=\"modelName\" placeholder=\"Model name (optional)\" value=\"gpt-4o-mini\" style=\"border:1px solid var(--stroke);border-radius:10px;padding:10px 12px;background:rgba(1,13,24,0.82);color:#dcf4ff;\" />
|
| 346 |
-
<input id=\"apiToken\" type=\"password\" placeholder=\"Paste API token here for LLM mode\" style=\"border:1px solid var(--stroke);border-radius:10px;padding:10px 12px;background:rgba(1,13,24,0.82);color:#dcf4ff;\" />
|
| 347 |
-
</div>
|
| 348 |
-
<div class=\"controls\" style=\"margin-bottom: 10px;\">
|
| 349 |
-
<label style=\"display:flex;align-items:center;gap:8px;padding:8px 10px;border:1px solid var(--stroke);border-radius:10px;\">
|
| 350 |
-
<input id=\"autoSuggest\" type=\"checkbox\" />
|
| 351 |
-
Auto suggest after typing pause
|
| 352 |
-
</label>
|
| 353 |
-
</div>
|
| 354 |
-
<textarea id=\"input\" spellcheck=\"false\" placeholder=\"Paste your Python code here...\"></textarea>
|
| 355 |
-
<p class=\"meta\" id=\"status\">Status: ready</p>
|
| 356 |
-
</div>
|
| 357 |
-
|
| 358 |
-
<div class=\"cols\" style=\"margin-top: 14px\">
|
| 359 |
-
<div class=\"panel\">
|
| 360 |
-
<h3>Original Code</h3>
|
| 361 |
-
<pre id=\"original\"></pre>
|
| 362 |
-
</div>
|
| 363 |
-
<div class=\"panel\">
|
| 364 |
-
<h3>Optimized Code</h3>
|
| 365 |
-
<pre id=\"optimized\"></pre>
|
| 366 |
-
</div>
|
| 367 |
-
</div>
|
| 368 |
-
|
| 369 |
-
<div class=\"panel\" style=\"margin-top: 14px\">
|
| 370 |
-
<h3>Diff</h3>
|
| 371 |
-
<pre id=\"diff\"></pre>
|
| 372 |
-
</div>
|
| 373 |
-
|
| 374 |
-
<div class=\"panel\" style=\"margin-top: 14px\">
|
| 375 |
-
<h3>Step Logs</h3>
|
| 376 |
-
<pre id=\"steps\"></pre>
|
| 377 |
-
</div>
|
| 378 |
-
</div>
|
| 379 |
-
|
| 380 |
-
<script>
|
| 381 |
-
const EX1 = `def compute(x, y, tmp):\n tmp = x + y\n x = tmp * 2\n result = x\n return result\n`;
|
| 382 |
-
const EX2 = `def add(p, q):\n return p + q\n\ndef compute(x, data, tmp):\n result = []\n for item in data:\n result.append(item * 2)\n if False:\n y = 999\n if True:\n val = add(x, tmp)\n unused = 0\n flag = not not True\n return val\n print(\"dead\")\n`;
|
| 383 |
-
let autoTimer = null;
|
| 384 |
-
|
| 385 |
-
function loadExample(i) {
|
| 386 |
-
document.getElementById('input').value = i === 1 ? EX1 : EX2;
|
| 387 |
-
document.getElementById('status').textContent = `Status: loaded example ${i}`;
|
| 388 |
-
}
|
| 389 |
-
|
| 390 |
-
async function runOptimize() {
|
| 391 |
-
const code = document.getElementById('input').value;
|
| 392 |
-
const task = document.getElementById('task').value || null;
|
| 393 |
-
const mode = document.getElementById('mode').value;
|
| 394 |
-
const useRl = mode === 'rl_then_llm';
|
| 395 |
-
const useLlm = mode === 'llm' || mode === 'rl_then_llm';
|
| 396 |
-
const fallbackToLlm = mode === 'rl_then_llm';
|
| 397 |
-
const rlModelPath = document.getElementById('rlModelPath').value || null;
|
| 398 |
-
const apiToken = document.getElementById('apiToken').value || null;
|
| 399 |
-
const apiBaseUrl = document.getElementById('baseUrl').value || null;
|
| 400 |
-
const modelName = document.getElementById('modelName').value || null;
|
| 401 |
-
if (!code.trim()) {
|
| 402 |
-
document.getElementById('status').innerHTML = 'Status: <span class=\"badge warn\">please paste code first</span>';
|
| 403 |
-
return;
|
| 404 |
-
}
|
| 405 |
-
if (mode === 'llm' && (!apiToken || !apiToken.trim())) {
|
| 406 |
-
document.getElementById('status').innerHTML = 'Status: <span class=\"badge warn\">paste API token for LLM mode</span>';
|
| 407 |
-
return;
|
| 408 |
-
}
|
| 409 |
-
|
| 410 |
-
document.getElementById('status').textContent = 'Status: running optimization...';
|
| 411 |
-
try {
|
| 412 |
-
const res = await fetch('/optimize', {
|
| 413 |
-
method: 'POST',
|
| 414 |
-
headers: {'Content-Type': 'application/json'},
|
| 415 |
-
body: JSON.stringify({
|
| 416 |
-
code,
|
| 417 |
-
task_id: task,
|
| 418 |
-
max_steps: 5,
|
| 419 |
-
use_rl: useRl,
|
| 420 |
-
use_llm: useLlm,
|
| 421 |
-
fallback_to_llm: fallbackToLlm,
|
| 422 |
-
rl_model_path: rlModelPath,
|
| 423 |
-
api_base_url: apiBaseUrl,
|
| 424 |
-
model_name: modelName,
|
| 425 |
-
api_token: apiToken,
|
| 426 |
-
})
|
| 427 |
-
});
|
| 428 |
-
const data = await res.json();
|
| 429 |
-
if (!res.ok) {
|
| 430 |
-
throw new Error(data.detail || 'request failed');
|
| 431 |
-
}
|
| 432 |
-
|
| 433 |
-
document.getElementById('original').textContent = data.original_code;
|
| 434 |
-
document.getElementById('optimized').textContent = data.optimized_code;
|
| 435 |
-
document.getElementById('diff').textContent = data.diff || '(no diff)';
|
| 436 |
-
document.getElementById('steps').textContent = JSON.stringify(data.steps, null, 2);
|
| 437 |
-
|
| 438 |
-
const scoreText = data.task_score === null ? 'n/a' : data.task_score;
|
| 439 |
-
document.getElementById('status').innerHTML = `Status: <span class=\"badge\">done</span> cumulative_reward=${data.cumulative_reward.toFixed(2)} task_score=${scoreText}`;
|
| 440 |
-
} catch (err) {
|
| 441 |
-
document.getElementById('status').innerHTML = `Status: <span class=\"badge warn\">error</span> ${err.message}`;
|
| 442 |
-
}
|
| 443 |
-
}
|
| 444 |
-
|
| 445 |
-
loadExample(1);
|
| 446 |
-
document.getElementById('input').addEventListener('input', () => {
|
| 447 |
-
if (!document.getElementById('autoSuggest').checked) {
|
| 448 |
-
return;
|
| 449 |
-
}
|
| 450 |
-
if (autoTimer) {
|
| 451 |
-
clearTimeout(autoTimer);
|
| 452 |
-
}
|
| 453 |
-
autoTimer = setTimeout(() => {
|
| 454 |
-
runOptimize();
|
| 455 |
-
}, 1200);
|
| 456 |
-
});
|
| 457 |
-
</script>
|
| 458 |
-
</body>
|
| 459 |
-
</html>"""
|
| 460 |
-
|
| 461 |
-
|
| 462 |
-
# ---------------------------------------------------------------------------
|
| 463 |
-
# Routes
|
| 464 |
-
# ---------------------------------------------------------------------------
|
| 465 |
-
|
| 466 |
-
@app.get("/", response_model=HealthResponse)
def health() -> HealthResponse:
    """Liveness probe on the root URL; reports a fixed ok/ACRE/1.0.0 payload."""
    fields = {"status": "ok", "env": "ACRE", "version": "1.0.0"}
    return HealthResponse(**fields)
|
| 470 |
-
|
| 471 |
-
|
| 472 |
-
@app.get("/health", response_model=CompatibilityHealthResponse)
def health_compat() -> CompatibilityHealthResponse:
    """Alternate health route kept for clients that probe ``/health``."""
    fields = {"status": "healthy", "service": "acre-env"}
    return CompatibilityHealthResponse(**fields)
|
| 476 |
-
|
| 477 |
-
|
| 478 |
-
@app.get("/demo", response_class=HTMLResponse)
def demo_ui() -> HTMLResponse:
    """Serve the browser demo page produced by ``_demo_html``."""
    page = _demo_html()
    return HTMLResponse(content=page)
|
| 482 |
-
|
| 483 |
-
|
| 484 |
-
@app.post("/reset", response_model=ResetResponse)
def reset(req: ResetRequest = ResetRequest()) -> ResetResponse:
    """Start a new episode, optionally from a task id and/or supplied code.

    Raises:
        HTTPException: 404 carrying the message of any ``ValueError`` raised
            by the environment's ``reset``.
    """
    env = get_env()
    reset_args = {"seed": req.seed, "task_id": req.task_id, "code": req.code}
    try:
        first_obs = env.reset(**reset_args)
    except ValueError as exc:
        raise HTTPException(status_code=404, detail=str(exc)) from exc

    # Field values are gathered in the same order the response declares them.
    vector = first_obs.to_vector()
    reset_info = env.last_reset_info
    snapshot = _state_response()
    return ResetResponse(
        observation=first_obs,
        observation_vector=vector,
        info=reset_info,
        task_id=req.task_id,
        state=snapshot,
    )
|
| 499 |
-
|
| 500 |
-
|
| 501 |
-
@app.post("/step", response_model=StepResponse)
def step(req: StepRequest) -> StepResponse:
    """Apply a single refactoring action to the current environment.

    Raises:
        HTTPException: 400 when ``req.action`` is outside ``0..4``.
    """
    env = get_env()
    chosen = req.action
    if not 0 <= chosen <= 4:
        raise HTTPException(status_code=400, detail="action must be 0–4")

    obs, reward, done, info = env.step(chosen)

    # Prefer the name reported by the step; otherwise look it up, else "unknown".
    fallback_name = env.action_meanings.get(chosen, "unknown")
    action_name = str(info.get("action_name", fallback_name))

    described_action = ActionModel(action=chosen, action_name=action_name)
    vector = obs.to_vector()
    snapshot = _state_response()
    return StepResponse(
        action=described_action,
        observation=obs,
        observation_vector=vector,
        reward=reward,
        done=done,
        terminated=done,  # mirrors ``done``; this route never reports truncation
        truncated=False,
        info=info,
        state=snapshot,
    )
|
| 522 |
-
|
| 523 |
-
|
| 524 |
-
@app.get("/state", response_model=StateResponse)
def state() -> StateResponse:
    """Expose the current environment state as built by ``_state_response``."""
    snapshot = _state_response()
    return snapshot
|
| 528 |
-
|
| 529 |
-
|
| 530 |
-
@app.get("/tasks", response_model=TasksResponse)
def list_tasks() -> TasksResponse:
    """List every task known to the registry, validated as ``TaskInfo``."""
    catalogue = []
    for entry in registry.list_tasks():
        catalogue.append(TaskInfo.model_validate(entry))
    return TasksResponse(tasks=catalogue)
|
| 534 |
-
|
| 535 |
-
|
| 536 |
-
@app.post("/tasks/{task_id}/grade", response_model=GradeResponse)
def grade(task_id: str, req: GradeRequest) -> GradeResponse:
    """Score submitted code with the grader of the named task.

    The response carries the score rounded to four decimals and a ``passed``
    flag that is true when the unrounded score is at least 0.8.

    Raises:
        HTTPException: 404 when the registry has no task with ``task_id``.
    """
    task = registry.get_task(task_id)
    if task is None:
        raise HTTPException(status_code=404, detail=f"Task '{task_id}' not found")

    raw_score = task.grade(req.code)
    rounded = round(raw_score, 4)
    is_pass = raw_score >= 0.8  # threshold is applied before rounding
    return GradeResponse(task_id=task_id, score=rounded, passed=is_pass)
|
| 548 |
-
|
| 549 |
-
|
| 550 |
-
@app.post("/optimize", response_model=OptimizeResponse)
def optimize(req: OptimizeRequest) -> OptimizeResponse:
    """Run one complete optimization episode and return comparison artifacts.

    The environment is reset with the submitted code and then stepped at most
    ``req.max_steps`` times.  For every step the action comes from:

    * the RL policy when ``use_rl`` is set and it yields an action;
    * otherwise the LLM selector when ``use_rl``, ``fallback_to_llm`` and
      ``use_llm`` are all set (the RL failure reason is prefixed to the
      LLM's reason);
    * otherwise the heuristic (again prefixed with the RL failure reason);
    * with ``use_rl`` off: the LLM selector if ``use_llm``, else the heuristic.

    Returns:
        The original and final code, a unified diff, the per-step log, the
        summed raw reward and, when a task id was given, its grade.

    Raises:
        HTTPException: 400 for blank code; 404 when the environment's reset
            raises ``ValueError`` or the task id is missing from the registry.
    """
    code = req.code.strip("\n")
    if not code.strip():
        raise HTTPException(status_code=400, detail="code must be non-empty")

    env = get_env()
    try:
        env.reset(task_id=req.task_id, code=code)
    except ValueError as exc:
        raise HTTPException(status_code=404, detail=str(exc)) from exc

    def ask_llm(snippet: str, step_no: int) -> tuple[int, str, str]:
        # One call site so the fallback and LLM-only paths pass identical arguments.
        return _choose_action_llm(
            code=snippet,
            task_id=req.task_id,
            step_index=step_no,
            max_steps=req.max_steps,
            api_base_url=req.api_base_url or DEFAULT_API_BASE_URL,
            model_name=req.model_name or DEFAULT_MODEL_NAME,
            api_token=req.api_token or "",
        )

    steps: list[OptimizationStep] = []
    cumulative_reward = 0.0

    for step_idx in range(1, req.max_steps + 1):
        before = env.state()
        current_code = before.current_code
        obs_list = [float(x) for x in before.observation_vector]

        action: int
        reason: str
        source: str

        if not req.use_rl:
            if req.use_llm:
                action, reason, source = ask_llm(current_code, step_idx)
            else:
                action = _choose_action_heuristic(current_code, req.task_id)
                reason, source = "heuristic policy", "heuristic"
        else:
            rl_action, rl_reason, rl_source = _choose_action_rl(
                observation=obs_list,
                model_path=req.rl_model_path or DEFAULT_RL_MODEL_PATH,
            )
            if rl_action is not None:
                action, reason, source = rl_action, rl_reason, rl_source
            elif req.fallback_to_llm and req.use_llm:
                action, llm_reason, source = ask_llm(current_code, step_idx)
                reason = f"{rl_reason}; {llm_reason}"
            else:
                action = _choose_action_heuristic(current_code, req.task_id)
                reason = f"{rl_reason}; heuristic fallback"
                source = "heuristic"

        _, reward, done, info = env.step(action)
        after = env.state()

        raw_reward = float(reward.raw)
        cumulative_reward += raw_reward
        steps.append(
            OptimizationStep(
                step=step_idx,
                action=action,
                action_name=info.get("action_name", "unknown"),
                reason=reason,
                source=source,
                reward=raw_reward,
                normalized_reward=float(reward.normalized),
                changed=bool(info.get("changed", False)),
                complexity=float(after.complexity),
            )
        )

        if done:
            break

    final_code = str(env.state().current_code)
    diff_text = "\n".join(
        difflib.unified_diff(
            code.splitlines(),
            final_code.splitlines(),
            fromfile="original.py",
            tofile="optimized.py",
            lineterm="",
        )
    )

    task_score: Optional[float] = None
    if req.task_id:
        task = registry.get_task(req.task_id)
        if task is None:
            raise HTTPException(status_code=404, detail=f"Task '{req.task_id}' not found")
        task_score = round(task.grade(final_code), 4)

    return OptimizeResponse(
        original_code=code,
        optimized_code=final_code,
        diff=diff_text,
        steps=steps,
        cumulative_reward=round(cumulative_reward, 4),
        task_id=req.task_id,
        task_score=task_score,
    )
|
| 659 |
-
|
| 660 |
-
|
| 661 |
-
# ---------------------------------------------------------------------------
|
| 662 |
-
# Entry point
|
| 663 |
-
# ---------------------------------------------------------------------------
|
| 664 |
-
|
| 665 |
-
if __name__ == "__main__":
    # Script entry: listen on all interfaces, on $PORT when set, else 7860.
    listen_port = int(os.environ.get("PORT", 7860))
    uvicorn.run(app, host="0.0.0.0", port=listen_port)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
ACRE_FINAL/validate.py
DELETED
|
@@ -1,281 +0,0 @@
|
|
| 1 |
-
"""
|
| 2 |
-
ACRE pre-submission validator.
|
| 3 |
-
|
| 4 |
-
Checks the repository against the submission checklist and, when a server URL is
|
| 5 |
-
available, probes the HTTP API as well.
|
| 6 |
-
|
| 7 |
-
Run:
|
| 8 |
-
python validate.py --url http://localhost:7860
|
| 9 |
-
"""
|
| 10 |
-
from __future__ import annotations
|
| 11 |
-
|
| 12 |
-
import argparse
|
| 13 |
-
import ast
|
| 14 |
-
import re
|
| 15 |
-
import sys
|
| 16 |
-
from typing import Any, Tuple
|
| 17 |
-
|
| 18 |
-
try:
|
| 19 |
-
import requests
|
| 20 |
-
except ImportError:
|
| 21 |
-
print("[ERROR] requests is required. Run: pip install requests")
|
| 22 |
-
sys.exit(1)
|
| 23 |
-
|
| 24 |
-
# Status tags printed by check(); wrapped in ANSI colour escapes
# (92 = bright green, 91 = bright red, 0 = reset).
PASS = "\033[92m[PASS]\033[0m"
FAIL = "\033[91m[FAIL]\033[0m"
|
| 26 |
-
|
| 27 |
-
|
| 28 |
-
def check(label: str, ok: bool, detail: str = "") -> bool:
    """Print one PASS/FAIL result line and hand ``ok`` back to the caller.

    Args:
        label: Description of the check.
        ok: Outcome; selects the PASS or FAIL tag.
        detail: Optional extra text appended after `` - ``.

    Returns:
        ``ok`` unchanged, so callers can count failures inline.
    """
    tag = PASS if ok else FAIL
    suffix = f" - {detail}" if detail else ""
    print(f" {tag} {label}" + suffix)
    return ok
|
| 35 |
-
|
| 36 |
-
|
| 37 |
-
def get(url: str, path: str, timeout: int = 15) -> Tuple[bool, Any]:
    """GET ``url + path`` and decode the JSON body.

    Returns:
        ``(True, parsed_json)`` on success, otherwise ``(False, message)``
        where ``message`` is the text of whatever exception occurred
        (connection failure, non-2xx status, invalid JSON, ...).
    """
    try:
        reply = requests.get(f"{url}{path}", timeout=timeout)
        reply.raise_for_status()
        body = reply.json()
    except Exception as err:
        # Deliberately broad: the validator reports failures instead of crashing.
        return False, str(err)
    else:
        return True, body
|
| 44 |
-
|
| 45 |
-
|
| 46 |
-
def post(url: str, path: str, payload: dict, timeout: int = 15) -> Tuple[bool, Any]:
    """POST ``payload`` as JSON to ``url + path`` and decode the JSON reply.

    Returns:
        ``(True, parsed_json)`` on success, otherwise ``(False, message)``
        where ``message`` is the text of whatever exception occurred.
    """
    try:
        reply = requests.post(f"{url}{path}", json=payload, timeout=timeout)
        reply.raise_for_status()
        body = reply.json()
    except Exception as err:
        # Deliberately broad: the validator reports failures instead of crashing.
        return False, str(err)
    else:
        return True, body
|
| 53 |
-
|
| 54 |
-
|
| 55 |
-
def read_text(path: str) -> str:
    """Return the full contents of ``path`` decoded as UTF-8.

    Raises:
        FileNotFoundError: If ``path`` does not exist (callers rely on this).
    """
    with open(path, mode="r", encoding="utf-8") as source:
        contents = source.read()
    return contents
|
| 58 |
-
|
| 59 |
-
|
| 60 |
-
def run_validation(base_url: str) -> int:
|
| 61 |
-
failures = 0
|
| 62 |
-
|
| 63 |
-
print("\n" + "=" * 60)
|
| 64 |
-
print(" ACRE Pre-Submission Validator")
|
| 65 |
-
print("=" * 60)
|
| 66 |
-
print(f" Target: {base_url}\n")
|
| 67 |
-
|
| 68 |
-
print("1. Static repository checks")
|
| 69 |
-
try:
|
| 70 |
-
interface_src = read_text("openenv_interface.py")
|
| 71 |
-
tree = ast.parse(interface_src)
|
| 72 |
-
classes = {node.name: node for node in tree.body if isinstance(node, ast.ClassDef)}
|
| 73 |
-
env_cls = classes.get("OpenEnvRefactorEnv")
|
| 74 |
-
failures += 0 if check("openenv_interface.py exists", True) else 1
|
| 75 |
-
failures += 0 if check("OpenEnvRefactorEnv is defined", env_cls is not None) else 1
|
| 76 |
-
if env_cls is not None:
|
| 77 |
-
methods = {node.name for node in env_cls.body if isinstance(node, ast.FunctionDef)}
|
| 78 |
-
for method_name in ["reset", "step", "state"]:
|
| 79 |
-
failures += 0 if check(
|
| 80 |
-
f"OpenEnvRefactorEnv implements {method_name}()",
|
| 81 |
-
method_name in methods,
|
| 82 |
-
) else 1
|
| 83 |
-
except FileNotFoundError:
|
| 84 |
-
failures += 1
|
| 85 |
-
check("openenv_interface.py exists", False, "file not found")
|
| 86 |
-
|
| 87 |
-
try:
|
| 88 |
-
models_src = read_text("models.py")
|
| 89 |
-
for name in ["ObservationModel", "ActionModel", "RewardModel"]:
|
| 90 |
-
failures += 0 if check(
|
| 91 |
-
f"{name} is defined in models.py",
|
| 92 |
-
f"class {name}" in models_src,
|
| 93 |
-
) else 1
|
| 94 |
-
except FileNotFoundError:
|
| 95 |
-
failures += 1
|
| 96 |
-
check("models.py exists", False, "file not found")
|
| 97 |
-
|
| 98 |
-
print("\n2. Health check (GET /)")
|
| 99 |
-
ok, data = get(base_url, "/")
|
| 100 |
-
failures += 0 if check("GET / returns HTTP 200", ok) else 1
|
| 101 |
-
if ok:
|
| 102 |
-
failures += 0 if check(
|
| 103 |
-
"Response has status field",
|
| 104 |
-
isinstance(data, dict) and "status" in data,
|
| 105 |
-
str(data),
|
| 106 |
-
) else 1
|
| 107 |
-
|
| 108 |
-
print("\n3. Tasks (GET /tasks)")
|
| 109 |
-
ok, data = get(base_url, "/tasks")
|
| 110 |
-
failures += 0 if check("GET /tasks returns 200", ok) else 1
|
| 111 |
-
if ok:
|
| 112 |
-
tasks = data.get("tasks", []) if isinstance(data, dict) else []
|
| 113 |
-
failures += 0 if check("At least 3 tasks defined", len(tasks) >= 3, f"found {len(tasks)}") else 1
|
| 114 |
-
difficulties = [t.get("difficulty", "") for t in tasks]
|
| 115 |
-
for diff in ["easy", "medium", "hard"]:
|
| 116 |
-
failures += 0 if check(f"Task with difficulty '{diff}' exists", diff in difficulties) else 1
|
| 117 |
-
for task in tasks:
|
| 118 |
-
failures += 0 if check(
|
| 119 |
-
f"Task '{task.get('id')}' has initial_code",
|
| 120 |
-
bool(task.get("initial_code")),
|
| 121 |
-
) else 1
|
| 122 |
-
|
| 123 |
-
print("\n4. Reset (POST /reset)")
|
| 124 |
-
ok, data = post(base_url, "/reset", {})
|
| 125 |
-
failures += 0 if check("POST /reset returns 200", ok) else 1
|
| 126 |
-
if ok:
|
| 127 |
-
observation = data.get("observation", {})
|
| 128 |
-
failures += 0 if check("Response has observation field", isinstance(observation, dict)) else 1
|
| 129 |
-
failures += 0 if check(
|
| 130 |
-
"Observation is typed with 4 fields",
|
| 131 |
-
{"code_length", "complexity_score", "runtime_s", "error_flag"}.issubset(observation),
|
| 132 |
-
str(observation),
|
| 133 |
-
) else 1
|
| 134 |
-
|
| 135 |
-
ok, _ = post(base_url, "/reset", {"task_id": "rename_variables"})
|
| 136 |
-
failures += 0 if check("POST /reset with task_id works", ok) else 1
|
| 137 |
-
|
| 138 |
-
print("\n5. State (GET /state)")
|
| 139 |
-
ok, data = get(base_url, "/state")
|
| 140 |
-
failures += 0 if check("GET /state returns 200", ok) else 1
|
| 141 |
-
if ok:
|
| 142 |
-
required_keys = [
|
| 143 |
-
"current_code",
|
| 144 |
-
"episode_steps",
|
| 145 |
-
"max_steps",
|
| 146 |
-
"complexity",
|
| 147 |
-
"observation",
|
| 148 |
-
"observation_vector",
|
| 149 |
-
"action_meanings",
|
| 150 |
-
]
|
| 151 |
-
for key in required_keys:
|
| 152 |
-
failures += 0 if check(f"State has '{key}' field", key in data) else 1
|
| 153 |
-
|
| 154 |
-
print("\n6. Step (POST /step)")
|
| 155 |
-
post(base_url, "/reset", {"task_id": "rename_variables"})
|
| 156 |
-
for action in range(5):
|
| 157 |
-
ok, data = post(base_url, "/step", {"action": action})
|
| 158 |
-
failures += 0 if check(
|
| 159 |
-
f"Action {action} executes without error",
|
| 160 |
-
ok and isinstance(data, dict) and "reward" in data and "done" in data,
|
| 161 |
-
) else 1
|
| 162 |
-
if ok:
|
| 163 |
-
reward_payload = data.get("reward", {})
|
| 164 |
-
norm = reward_payload.get("normalized", -1)
|
| 165 |
-
failures += 0 if check(
|
| 166 |
-
f"Action {action} returns typed reward payload",
|
| 167 |
-
{"raw", "normalized", "components"}.issubset(reward_payload),
|
| 168 |
-
str(reward_payload),
|
| 169 |
-
) else 1
|
| 170 |
-
failures += 0 if check(
|
| 171 |
-
f"Action {action} normalized_reward in [0,1]",
|
| 172 |
-
isinstance(norm, (int, float)) and 0.0 <= float(norm) <= 1.0,
|
| 173 |
-
f"got {norm}",
|
| 174 |
-
) else 1
|
| 175 |
-
if data.get("done"):
|
| 176 |
-
break
|
| 177 |
-
|
| 178 |
-
ok, data = post(base_url, "/step", {"action": 99})
|
| 179 |
-
check("Invalid action returns error (not crash)", not ok or "detail" in str(data), "(expected 4xx)")
|
| 180 |
-
|
| 181 |
-
print("\n7. Task graders (POST /tasks/{id}/grade)")
|
| 182 |
-
for task_id in ["rename_variables", "remove_dead_code", "full_refactor"]:
|
| 183 |
-
ok, data = post(base_url, f"/tasks/{task_id}/grade", {"code": "def f(): pass"})
|
| 184 |
-
failures += 0 if check(f"Grade endpoint for '{task_id}' works", ok) else 1
|
| 185 |
-
if ok:
|
| 186 |
-
score = data.get("score", -1)
|
| 187 |
-
failures += 0 if check(
|
| 188 |
-
f"Score for '{task_id}' in [0.0, 1.0]",
|
| 189 |
-
isinstance(score, (int, float)) and 0.0 <= float(score) <= 1.0,
|
| 190 |
-
f"got {score}",
|
| 191 |
-
) else 1
|
| 192 |
-
|
| 193 |
-
print("\n8. openenv.yaml")
|
| 194 |
-
try:
|
| 195 |
-
openenv_yaml = read_text("openenv.yaml")
|
| 196 |
-
failures += 0 if check("openenv.yaml exists", True) else 1
|
| 197 |
-
for field in ["tasks:", "action_space:", "observation_space:", "reward:", "entrypoint:", "validation:"]:
|
| 198 |
-
failures += 0 if check(f"openenv.yaml has '{field}' section", field in openenv_yaml) else 1
|
| 199 |
-
except FileNotFoundError:
|
| 200 |
-
failures += 1
|
| 201 |
-
check("openenv.yaml exists", False, "file not found")
|
| 202 |
-
|
| 203 |
-
print("\n9. inference.py")
|
| 204 |
-
try:
|
| 205 |
-
inference_src = read_text("inference.py")
|
| 206 |
-
failures += 0 if check("inference.py exists", True) else 1
|
| 207 |
-
for marker in ['"event": "START"', '"event": "STEP"', '"event": "END"']:
|
| 208 |
-
failures += 0 if check(f"inference.py emits {marker}", marker in inference_src) else 1
|
| 209 |
-
failures += 0 if check(
|
| 210 |
-
"Uses OpenAI client",
|
| 211 |
-
"from openai import OpenAI" in inference_src,
|
| 212 |
-
) else 1
|
| 213 |
-
for var in ["API_BASE_URL", "MODEL_NAME", "HF_TOKEN", "ENV_URL", "LOCAL_IMAGE_NAME"]:
|
| 214 |
-
failures += 0 if check(f"inference.py reads {var} from env", var in inference_src) else 1
|
| 215 |
-
failures += 0 if check(
|
| 216 |
-
"API_BASE_URL has a default",
|
| 217 |
-
'os.getenv("API_BASE_URL", "https://api.openai.com/v1")' in inference_src,
|
| 218 |
-
) else 1
|
| 219 |
-
failures += 0 if check(
|
| 220 |
-
"MODEL_NAME has a default",
|
| 221 |
-
'os.getenv("MODEL_NAME", "gpt-4o-mini")' in inference_src,
|
| 222 |
-
) else 1
|
| 223 |
-
failures += 0 if check(
|
| 224 |
-
"HF_TOKEN has no default",
|
| 225 |
-
re.search(r'HF_TOKEN\s*:\s*.*os\.getenv\("HF_TOKEN"\)', inference_src) is not None,
|
| 226 |
-
) else 1
|
| 227 |
-
except FileNotFoundError:
|
| 228 |
-
failures += 1
|
| 229 |
-
check("inference.py exists", False, "file not found")
|
| 230 |
-
|
| 231 |
-
print("\n10. Dockerfile")
|
| 232 |
-
try:
|
| 233 |
-
dockerfile = read_text("Dockerfile")
|
| 234 |
-
failures += 0 if check("Dockerfile exists", True) else 1
|
| 235 |
-
failures += 0 if check("Exposes port 7860", "7860" in dockerfile) else 1
|
| 236 |
-
failures += 0 if check("Has CMD/ENTRYPOINT", "CMD" in dockerfile or "ENTRYPOINT" in dockerfile) else 1
|
| 237 |
-
failures += 0 if check("Does not set a default HF_TOKEN", "ENV HF_TOKEN" not in dockerfile) else 1
|
| 238 |
-
except FileNotFoundError:
|
| 239 |
-
failures += 1
|
| 240 |
-
check("Dockerfile exists", False, "file not found")
|
| 241 |
-
|
| 242 |
-
print("\n11. README / Hugging Face metadata")
|
| 243 |
-
try:
|
| 244 |
-
readme = read_text("README.md")
|
| 245 |
-
failures += 0 if check("README has docker SDK front matter", "sdk: docker" in readme) else 1
|
| 246 |
-
failures += 0 if check("README includes openenv tag", "openenv" in readme) else 1
|
| 247 |
-
for section in [
|
| 248 |
-
"Environment Overview and Motivation",
|
| 249 |
-
"Definitions of Action and Observation Spaces",
|
| 250 |
-
"Task Descriptions with Expected Difficulty Levels",
|
| 251 |
-
"Setup and Usage Instructions",
|
| 252 |
-
"Baseline Performance Scores",
|
| 253 |
-
]:
|
| 254 |
-
failures += 0 if check(f"README includes '{section}'", section in readme) else 1
|
| 255 |
-
except FileNotFoundError:
|
| 256 |
-
failures += 1
|
| 257 |
-
check("README.md exists", False, "file not found")
|
| 258 |
-
|
| 259 |
-
print("\n" + "=" * 60)
|
| 260 |
-
if failures == 0:
|
| 261 |
-
print(f" {PASS} All checks passed. Repository is submission-ready.")
|
| 262 |
-
else:
|
| 263 |
-
print(f" {FAIL} {failures} check(s) failed. Fix before submitting.")
|
| 264 |
-
print("=" * 60 + "\n")
|
| 265 |
-
|
| 266 |
-
return failures
|
| 267 |
-
|
| 268 |
-
|
| 269 |
-
def main() -> None:
    """Parse the command line, run the validator, and exit with its result.

    The only option is ``--url``, the base URL of the running ACRE server
    (defaults to ``http://localhost:7860``). The value returned by
    ``run_validation`` (its failure count) is used as the process exit
    status, so ``0`` means every check passed.
    """
    cli = argparse.ArgumentParser(description="ACRE pre-submission validator")
    cli.add_argument(
        "--url",
        help="Base URL of the running ACRE server",
        default="http://localhost:7860",
    )
    options = cli.parse_args()
    failure_count = run_validation(options.url)
    sys.exit(failure_count)
|
| 278 |
-
|
| 279 |
-
|
| 280 |
-
if __name__ == "__main__":
|
| 281 |
-
main()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
README.md
CHANGED
|
@@ -167,9 +167,9 @@ The deterministic fallback policy used by `inference.py` produces the following
|
|
| 167 |
|
| 168 |
| Task | Score |
|
| 169 |
|---|---|
|
| 170 |
-
| `rename_variables` | 1.
|
| 171 |
-
| `remove_dead_code` |
|
| 172 |
-
| `full_refactor` |
|
| 173 |
-
| Average |
|
| 174 |
|
| 175 |
These scores come from the built-in heuristic policy with `HF_TOKEN` unset, which keeps the baseline reproducible across runs.
|
|
|
|
| 167 |
|
| 168 |
| Task | Score |
|
| 169 |
|---|---|
|
| 170 |
+
| `rename_variables` | 1.0000 |
|
| 171 |
+
| `remove_dead_code` | 0.2500 |
|
| 172 |
+
| `full_refactor` | 0.7143 |
|
| 173 |
+
| Average | 0.6548 |
|
| 174 |
|
| 175 |
These scores come from the built-in heuristic policy with `HF_TOKEN` unset, which keeps the baseline reproducible across runs.
|
acre/tasks/task_registry.py
CHANGED
|
@@ -5,7 +5,7 @@ from __future__ import annotations
|
|
| 5 |
|
| 6 |
import ast
|
| 7 |
from dataclasses import dataclass
|
| 8 |
-
from typing import Callable, Dict, List, Optional
|
| 9 |
|
| 10 |
|
| 11 |
@dataclass
|
|
@@ -14,9 +14,13 @@ class Task:
|
|
| 14 |
name: str
|
| 15 |
description: str
|
| 16 |
difficulty: str
|
| 17 |
-
|
| 18 |
_grade_fn: Callable[[str], float]
|
| 19 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 20 |
def grade(self, code: str) -> float:
|
| 21 |
"""Return a score in [0.0, 1.0]."""
|
| 22 |
try:
|
|
@@ -25,21 +29,90 @@ class Task:
|
|
| 25 |
return 0.0
|
| 26 |
|
| 27 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 28 |
# ---------------------------------------------------------------------------
|
| 29 |
# Task 1 — Easy: Rename generic variables
|
| 30 |
# ---------------------------------------------------------------------------
|
| 31 |
-
|
|
|
|
| 32 |
def compute(x, y, tmp):
|
| 33 |
tmp = x + y
|
| 34 |
x = tmp * 2
|
| 35 |
result = x
|
| 36 |
return result
|
| 37 |
-
"""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 38 |
|
| 39 |
|
| 40 |
def _grade_easy(code: str) -> float:
|
| 41 |
-
"""Score = fraction of generic names
|
| 42 |
-
generic = {"x", "tmp"}
|
| 43 |
try:
|
| 44 |
tree = ast.parse(code)
|
| 45 |
except SyntaxError:
|
|
@@ -66,7 +139,8 @@ def _grade_easy(code: str) -> float:
|
|
| 66 |
# ---------------------------------------------------------------------------
|
| 67 |
# Task 2 — Medium: Remove dead code
|
| 68 |
# ---------------------------------------------------------------------------
|
| 69 |
-
|
|
|
|
| 70 |
def process(data):
|
| 71 |
result = []
|
| 72 |
for item in data:
|
|
@@ -76,31 +150,74 @@ def process(data):
|
|
| 76 |
unused_var = 42
|
| 77 |
return result
|
| 78 |
print("unreachable")
|
| 79 |
-
"""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 80 |
|
| 81 |
|
| 82 |
def _grade_medium(code: str) -> float:
|
| 83 |
-
"""Score = fraction of dead-code patterns eliminated (
|
| 84 |
try:
|
| 85 |
tree = ast.parse(code)
|
| 86 |
except SyntaxError:
|
| 87 |
return 0.0
|
| 88 |
|
| 89 |
-
source =
|
| 90 |
score = 0.0
|
| 91 |
|
| 92 |
-
# Check 1: if-False
|
| 93 |
-
if "if False" not in source:
|
| 94 |
-
score +=
|
| 95 |
|
| 96 |
-
# Check 2:
|
| 97 |
-
if
|
| 98 |
-
score +=
|
| 99 |
|
| 100 |
# Check 3: list comprehension used (loop simplified)
|
| 101 |
has_listcomp = any(isinstance(n, ast.ListComp) for n in ast.walk(tree))
|
| 102 |
if has_listcomp:
|
| 103 |
-
score +=
|
|
|
|
|
|
|
|
|
|
|
|
|
| 104 |
|
| 105 |
return score
|
| 106 |
|
|
@@ -108,7 +225,8 @@ def _grade_medium(code: str) -> float:
|
|
| 108 |
# ---------------------------------------------------------------------------
|
| 109 |
# Task 3 — Hard: Full refactor
|
| 110 |
# ---------------------------------------------------------------------------
|
| 111 |
-
|
|
|
|
| 112 |
def add(p, q):
|
| 113 |
return p + q
|
| 114 |
|
|
@@ -124,34 +242,89 @@ def compute(x, data, tmp):
|
|
| 124 |
flag = not not True
|
| 125 |
return val
|
| 126 |
print("dead")
|
| 127 |
-
"""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 128 |
|
| 129 |
|
| 130 |
def _grade_hard(code: str) -> float:
|
| 131 |
-
"""Score = fraction of
|
| 132 |
try:
|
| 133 |
tree = ast.parse(code)
|
| 134 |
except SyntaxError:
|
| 135 |
return 0.0
|
| 136 |
|
| 137 |
-
source =
|
| 138 |
checks = 0
|
| 139 |
|
| 140 |
-
# 1. No generic variable names x/tmp in function signature
|
| 141 |
has_generic = False
|
| 142 |
|
| 143 |
class _GenCheck(ast.NodeVisitor):
|
| 144 |
def visit_arg(self, node: ast.arg) -> None:
|
| 145 |
nonlocal has_generic
|
| 146 |
-
if node.arg in {"x", "tmp"}:
|
| 147 |
has_generic = True
|
| 148 |
|
| 149 |
_GenCheck().visit(tree)
|
| 150 |
if not has_generic:
|
| 151 |
checks += 1
|
| 152 |
|
| 153 |
-
# 2. No if False block
|
| 154 |
-
if "if False" not in source:
|
| 155 |
checks += 1
|
| 156 |
|
| 157 |
# 3. if True removed (body inlined)
|
|
@@ -162,13 +335,21 @@ def _grade_hard(code: str) -> float:
|
|
| 162 |
if any(isinstance(n, ast.ListComp) for n in ast.walk(tree)):
|
| 163 |
checks += 1
|
| 164 |
|
| 165 |
-
# 5.
|
| 166 |
calls = [n for n in ast.walk(tree) if isinstance(n, ast.Call)]
|
| 167 |
fn_names = {c.func.id for c in calls if isinstance(c.func, ast.Name)}
|
| 168 |
-
if "add"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 169 |
checks += 1
|
| 170 |
|
| 171 |
-
return checks /
|
| 172 |
|
| 173 |
|
| 174 |
# ---------------------------------------------------------------------------
|
|
@@ -186,7 +367,7 @@ class TaskRegistry:
|
|
| 186 |
name="Rename Variables (Easy)",
|
| 187 |
description="Rename generic variable names (x, tmp) to descriptive ones",
|
| 188 |
difficulty="easy",
|
| 189 |
-
|
| 190 |
_grade_fn=_grade_easy,
|
| 191 |
)
|
| 192 |
self._tasks["remove_dead_code"] = Task(
|
|
@@ -194,7 +375,7 @@ class TaskRegistry:
|
|
| 194 |
name="Remove Dead Code (Medium)",
|
| 195 |
description="Remove unreachable code, if False blocks, and unused variables",
|
| 196 |
difficulty="medium",
|
| 197 |
-
|
| 198 |
_grade_fn=_grade_medium,
|
| 199 |
)
|
| 200 |
self._tasks["full_refactor"] = Task(
|
|
@@ -202,7 +383,7 @@ class TaskRegistry:
|
|
| 202 |
name="Full Refactor (Hard)",
|
| 203 |
description="Apply all transformations: rename, dead code, loops, conditions, inlining",
|
| 204 |
difficulty="hard",
|
| 205 |
-
|
| 206 |
_grade_fn=_grade_hard,
|
| 207 |
)
|
| 208 |
|
|
|
|
| 5 |
|
| 6 |
import ast
|
| 7 |
from dataclasses import dataclass
|
| 8 |
+
from typing import Callable, Dict, List, Optional, Sequence
|
| 9 |
|
| 10 |
|
| 11 |
@dataclass
|
|
|
|
| 14 |
name: str
|
| 15 |
description: str
|
| 16 |
difficulty: str
|
| 17 |
+
samples: List[str]
|
| 18 |
_grade_fn: Callable[[str], float]
|
| 19 |
|
| 20 |
+
@property
|
| 21 |
+
def initial_code(self) -> str:
|
| 22 |
+
return str(self.samples[0]) if self.samples else ""
|
| 23 |
+
|
| 24 |
def grade(self, code: str) -> float:
|
| 25 |
"""Return a score in [0.0, 1.0]."""
|
| 26 |
try:
|
|
|
|
| 29 |
return 0.0
|
| 30 |
|
| 31 |
|
| 32 |
+
def _safe_unparse(tree: ast.AST) -> str:
    """Render *tree* back to source text, or return ``""`` if that fails.

    Any exception raised by ``ast.unparse`` is swallowed and reported to the
    caller as an empty string instead of propagating.
    """
    try:
        rendered = ast.unparse(tree)
    except Exception:
        return ""
    else:
        return rendered
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
def _has_unreachable_after_terminator(stmts: Sequence[ast.stmt]) -> bool:
    """Return True if any statement in *stmts* follows a terminator.

    A terminator is a statement after which nothing else in the same block
    can run: ``return``, ``raise``, ``break`` or ``continue``. Only the given
    statement list is inspected; nested blocks are the caller's job.

    Bare string-literal expression statements that follow a terminator are
    skipped rather than reported, so a stray docstring-like string alone does
    not count as unreachable code.
    """
    terminated = False
    for stmt in stmts:
        if terminated:
            is_bare_string = (
                isinstance(stmt, ast.Expr)
                and isinstance(stmt.value, ast.Constant)
                and isinstance(stmt.value.value, str)
            )
            if is_bare_string:
                continue
            return True
        if isinstance(stmt, (ast.Return, ast.Raise, ast.Break, ast.Continue)):
            terminated = True
    return False


def _tree_has_unreachable(tree: ast.AST) -> bool:
    """Return True if any function in *tree* contains unreachable statements.

    Every statement list found inside each (sync or async) function is
    checked: the function body itself and the ``body`` / ``orelse`` /
    ``finalbody`` lists of nested compound statements, including ``except``
    handlers. Code outside of functions is not inspected.
    """
    block_fields = ("body", "orelse", "finalbody")
    for func in ast.walk(tree):
        if not isinstance(func, (ast.FunctionDef, ast.AsyncFunctionDef)):
            continue
        for node in ast.walk(func):
            for field in block_fields:
                block = getattr(node, field, None)
                # Lambda/IfExp also expose ``body``/``orelse`` but as single
                # expressions; only real statement lists are scanned.
                if isinstance(block, list) and _has_unreachable_after_terminator(block):
                    return True
    return False
|
| 70 |
+
|
| 71 |
+
|
| 72 |
# ---------------------------------------------------------------------------
|
| 73 |
# Task 1 — Easy: Rename generic variables
|
| 74 |
# ---------------------------------------------------------------------------
|
| 75 |
+
_EASY_SAMPLES: List[str] = [
|
| 76 |
+
"""\
|
| 77 |
def compute(x, y, tmp):
|
| 78 |
tmp = x + y
|
| 79 |
x = tmp * 2
|
| 80 |
result = x
|
| 81 |
return result
|
| 82 |
+
""",
|
| 83 |
+
"""\
|
| 84 |
+
def normalize(tmp, x):
|
| 85 |
+
for i in range(3):
|
| 86 |
+
tmp = tmp + i
|
| 87 |
+
return tmp * x
|
| 88 |
+
""",
|
| 89 |
+
"""\
|
| 90 |
+
def score(items):
|
| 91 |
+
tmp = 0
|
| 92 |
+
for i in items:
|
| 93 |
+
tmp += i
|
| 94 |
+
x = tmp
|
| 95 |
+
return x
|
| 96 |
+
""",
|
| 97 |
+
"""\
|
| 98 |
+
def transform(x):
|
| 99 |
+
tmp = x
|
| 100 |
+
if tmp > 10:
|
| 101 |
+
tmp = tmp - 1
|
| 102 |
+
return tmp
|
| 103 |
+
""",
|
| 104 |
+
"""\
|
| 105 |
+
def merge(a, b):
|
| 106 |
+
x = a
|
| 107 |
+
tmp = b
|
| 108 |
+
return x + tmp
|
| 109 |
+
""",
|
| 110 |
+
]
|
| 111 |
|
| 112 |
|
| 113 |
def _grade_easy(code: str) -> float:
|
| 114 |
+
"""Score = fraction of generic names removed from all scopes."""
|
| 115 |
+
generic = {"x", "tmp", "i"}
|
| 116 |
try:
|
| 117 |
tree = ast.parse(code)
|
| 118 |
except SyntaxError:
|
|
|
|
| 139 |
# ---------------------------------------------------------------------------
|
| 140 |
# Task 2 — Medium: Remove dead code
|
| 141 |
# ---------------------------------------------------------------------------
|
| 142 |
+
_MEDIUM_SAMPLES: List[str] = [
|
| 143 |
+
"""\
|
| 144 |
def process(data):
|
| 145 |
result = []
|
| 146 |
for item in data:
|
|
|
|
| 150 |
unused_var = 42
|
| 151 |
return result
|
| 152 |
print("unreachable")
|
| 153 |
+
""",
|
| 154 |
+
"""\
|
| 155 |
+
def build(values):
|
| 156 |
+
out = []
|
| 157 |
+
for v in values:
|
| 158 |
+
out.append(v + 1)
|
| 159 |
+
while False:
|
| 160 |
+
out.append(999)
|
| 161 |
+
dead = 0
|
| 162 |
+
return out
|
| 163 |
+
dead += 1
|
| 164 |
+
""",
|
| 165 |
+
"""\
|
| 166 |
+
def route(flag):
|
| 167 |
+
if False:
|
| 168 |
+
return 1
|
| 169 |
+
if True:
|
| 170 |
+
x = 2
|
| 171 |
+
y = x
|
| 172 |
+
return y
|
| 173 |
+
""",
|
| 174 |
+
"""\
|
| 175 |
+
def clean(xs):
|
| 176 |
+
res = []
|
| 177 |
+
for x in xs:
|
| 178 |
+
res.append(x * 2)
|
| 179 |
+
unused = "remove me"
|
| 180 |
+
if False:
|
| 181 |
+
unused2 = 123
|
| 182 |
+
return res
|
| 183 |
+
""",
|
| 184 |
+
"""\
|
| 185 |
+
def calc(n):
|
| 186 |
+
total = 0
|
| 187 |
+
for i in range(n):
|
| 188 |
+
total += i
|
| 189 |
+
return total
|
| 190 |
+
print("dead")
|
| 191 |
+
""",
|
| 192 |
+
]
|
| 193 |
|
| 194 |
|
| 195 |
def _grade_medium(code: str) -> float:
    """Score = fraction of dead-code patterns eliminated (4 checks, 0.25 each).

    Checks:
      1. no ``if False`` / ``while False`` text remains,
      2. no unreachable statements remain (per ``_tree_has_unreachable``),
      3. at least one list comprehension is present,
      4. no ``unused`` / ``dead`` sentinel text remains.

    Returns 0.0 when *code* does not parse.
    """
    try:
        tree = ast.parse(code)
    except SyntaxError:
        return 0.0

    # Textual checks run on the normalised source. If unparsing fails the
    # helper yields "", which would make every "pattern not in source" check
    # pass vacuously, so fall back to the submitted text in that case.
    source = _safe_unparse(tree) or code
    score = 0.0

    # Check 1: if/while-False removed
    if "if False" not in source and "while False" not in source:
        score += 0.25

    # Check 2: no unreachable statements after a terminator
    if not _tree_has_unreachable(tree):
        score += 0.25

    # Check 3: list comprehension used (loop simplified)
    if any(isinstance(node, ast.ListComp) for node in ast.walk(tree)):
        score += 0.25

    # Check 4: obvious dead/unused sentinel names removed. "unused" also
    # covers "unused_var" and "unused2" because this is a substring match.
    if all(marker not in source for marker in ("unused", "dead")):
        score += 0.25

    return score
|
| 223 |
|
|
|
|
| 225 |
# ---------------------------------------------------------------------------
|
| 226 |
# Task 3 — Hard: Full refactor
|
| 227 |
# ---------------------------------------------------------------------------
|
| 228 |
+
_HARD_SAMPLES: List[str] = [
|
| 229 |
+
"""\
|
| 230 |
def add(p, q):
|
| 231 |
return p + q
|
| 232 |
|
|
|
|
| 242 |
flag = not not True
|
| 243 |
return val
|
| 244 |
print("dead")
|
| 245 |
+
""",
|
| 246 |
+
"""\
|
| 247 |
+
def helper(a, b):
|
| 248 |
+
return a + b
|
| 249 |
+
|
| 250 |
+
def pipeline(tmp, xs, x):
|
| 251 |
+
out = []
|
| 252 |
+
for i in xs:
|
| 253 |
+
out.append(i * 2)
|
| 254 |
+
if True:
|
| 255 |
+
y = helper(tmp, x)
|
| 256 |
+
if False:
|
| 257 |
+
y = 0
|
| 258 |
+
return y
|
| 259 |
+
y = 123
|
| 260 |
+
""",
|
| 261 |
+
"""\
|
| 262 |
+
def add(p, q):
|
| 263 |
+
return p + q
|
| 264 |
+
|
| 265 |
+
def compute(x, data, tmp):
|
| 266 |
+
result = []
|
| 267 |
+
for item in data:
|
| 268 |
+
result.append(item * 2)
|
| 269 |
+
if False:
|
| 270 |
+
print("never")
|
| 271 |
+
val = add(x, tmp)
|
| 272 |
+
return val
|
| 273 |
+
""",
|
| 274 |
+
"""\
|
| 275 |
+
def add(p, q):
|
| 276 |
+
return p + q
|
| 277 |
+
|
| 278 |
+
def compute(x, data, tmp):
|
| 279 |
+
res = []
|
| 280 |
+
for item in data:
|
| 281 |
+
res.append(item * 2)
|
| 282 |
+
flag = not not True
|
| 283 |
+
if True:
|
| 284 |
+
return add(x, tmp)
|
| 285 |
+
""",
|
| 286 |
+
"""\
|
| 287 |
+
def plus(p, q):
|
| 288 |
+
return p + q
|
| 289 |
+
|
| 290 |
+
def compute(tmp, data, x):
|
| 291 |
+
out = []
|
| 292 |
+
for item in data:
|
| 293 |
+
out.append(item * 2)
|
| 294 |
+
if False:
|
| 295 |
+
tmp = 999
|
| 296 |
+
if True:
|
| 297 |
+
val = plus(x, tmp)
|
| 298 |
+
return val
|
| 299 |
+
""",
|
| 300 |
+
]
|
| 301 |
|
| 302 |
|
| 303 |
def _grade_hard(code: str) -> float:
|
| 304 |
+
"""Score = fraction of 7 quality checks passed."""
|
| 305 |
try:
|
| 306 |
tree = ast.parse(code)
|
| 307 |
except SyntaxError:
|
| 308 |
return 0.0
|
| 309 |
|
| 310 |
+
source = _safe_unparse(tree)
|
| 311 |
checks = 0
|
| 312 |
|
| 313 |
+
# 1. No generic variable names x/tmp/i in function signature
|
| 314 |
has_generic = False
|
| 315 |
|
| 316 |
class _GenCheck(ast.NodeVisitor):
|
| 317 |
def visit_arg(self, node: ast.arg) -> None:
|
| 318 |
nonlocal has_generic
|
| 319 |
+
if node.arg in {"x", "tmp", "i"}:
|
| 320 |
has_generic = True
|
| 321 |
|
| 322 |
_GenCheck().visit(tree)
|
| 323 |
if not has_generic:
|
| 324 |
checks += 1
|
| 325 |
|
| 326 |
+
# 2. No if/while False block
|
| 327 |
+
if ("if False" not in source) and ("while False" not in source):
|
| 328 |
checks += 1
|
| 329 |
|
| 330 |
# 3. if True removed (body inlined)
|
|
|
|
| 335 |
if any(isinstance(n, ast.ListComp) for n in ast.walk(tree)):
|
| 336 |
checks += 1
|
| 337 |
|
| 338 |
+
# 5. helper calls inlined (no call sites remain)
|
| 339 |
calls = [n for n in ast.walk(tree) if isinstance(n, ast.Call)]
|
| 340 |
fn_names = {c.func.id for c in calls if isinstance(c.func, ast.Name)}
|
| 341 |
+
if not ({"add", "plus", "helper"} & fn_names):
|
| 342 |
+
checks += 1
|
| 343 |
+
|
| 344 |
+
# 6. no unreachable after return/raise
|
| 345 |
+
if not _tree_has_unreachable(tree):
|
| 346 |
+
checks += 1
|
| 347 |
+
|
| 348 |
+
# 7. remove double-not
|
| 349 |
+
if "not not" not in source:
|
| 350 |
checks += 1
|
| 351 |
|
| 352 |
+
return checks / 7
|
| 353 |
|
| 354 |
|
| 355 |
# ---------------------------------------------------------------------------
|
|
|
|
| 367 |
name="Rename Variables (Easy)",
|
| 368 |
description="Rename generic variable names (x, tmp) to descriptive ones",
|
| 369 |
difficulty="easy",
|
| 370 |
+
samples=_EASY_SAMPLES,
|
| 371 |
_grade_fn=_grade_easy,
|
| 372 |
)
|
| 373 |
self._tasks["remove_dead_code"] = Task(
|
|
|
|
| 375 |
name="Remove Dead Code (Medium)",
|
| 376 |
description="Remove unreachable code, if False blocks, and unused variables",
|
| 377 |
difficulty="medium",
|
| 378 |
+
samples=_MEDIUM_SAMPLES,
|
| 379 |
_grade_fn=_grade_medium,
|
| 380 |
)
|
| 381 |
self._tasks["full_refactor"] = Task(
|
|
|
|
| 383 |
name="Full Refactor (Hard)",
|
| 384 |
description="Apply all transformations: rename, dead code, loops, conditions, inlining",
|
| 385 |
difficulty="hard",
|
| 386 |
+
samples=_HARD_SAMPLES,
|
| 387 |
_grade_fn=_grade_hard,
|
| 388 |
)
|
| 389 |
|
inference.py
CHANGED
|
@@ -1,17 +1,18 @@
|
|
| 1 |
"""
|
| 2 |
ACRE inference script for OpenEnv submission evaluation.
|
| 3 |
|
| 4 |
-
|
| 5 |
-
API_BASE_URL: LLM API endpoint (default allowed)
|
| 6 |
-
MODEL_NAME: model identifier (default allowed)
|
| 7 |
-
HF_TOKEN: API token for the OpenAI-compatible endpoint
|
| 8 |
-
ENV_URL: running ACRE server base URL
|
| 9 |
-
|
| 10 |
-
|
| 11 |
-
|
| 12 |
-
|
| 13 |
-
|
| 14 |
-
|
|
|
|
| 15 |
"""
|
| 16 |
from __future__ import annotations
|
| 17 |
|
|
@@ -20,7 +21,7 @@ import os
|
|
| 20 |
import re
|
| 21 |
import sys
|
| 22 |
import time
|
| 23 |
-
from typing import Dict, List, Tuple
|
| 24 |
|
| 25 |
import requests
|
| 26 |
from openai import OpenAI
|
|
@@ -95,7 +96,7 @@ def grade(task_id: str, code: str) -> float:
|
|
| 95 |
return float(response.json().get("score", 0.0))
|
| 96 |
|
| 97 |
|
| 98 |
-
def choose_action(client: OpenAI, state: dict, task_id: str) -> Tuple[int, str]:
|
| 99 |
def heuristic_action() -> Tuple[int, str]:
|
| 100 |
code = str(state.get("current_code", ""))
|
| 101 |
step_i = int(state.get("episode_steps", 0))
|
|
@@ -141,7 +142,8 @@ def choose_action(client: OpenAI, state: dict, task_id: str) -> Tuple[int, str]:
|
|
| 141 |
return 1, "heuristic: remove remaining dead code"
|
| 142 |
return 3, "heuristic: condition optimization as safe default"
|
| 143 |
|
| 144 |
-
|
|
|
|
| 145 |
return heuristic_action()
|
| 146 |
|
| 147 |
messages = [
|
|
@@ -184,23 +186,12 @@ def choose_action(client: OpenAI, state: dict, task_id: str) -> Tuple[int, str]:
|
|
| 184 |
return heuristic_action()
|
| 185 |
|
| 186 |
|
| 187 |
-
def run_episode(client: OpenAI, task_id: str, episode_num: int) -> float:
|
| 188 |
reset_env(task_id)
|
| 189 |
state = get_state()
|
| 190 |
|
| 191 |
-
|
| 192 |
-
|
| 193 |
-
{
|
| 194 |
-
"event": "START",
|
| 195 |
-
"episode": episode_num,
|
| 196 |
-
"task_id": task_id,
|
| 197 |
-
"initial_complexity": state.get("complexity", 0),
|
| 198 |
-
"initial_code_length": len(state.get("current_code", "")),
|
| 199 |
-
"timestamp": time.time(),
|
| 200 |
-
}
|
| 201 |
-
),
|
| 202 |
-
flush=True,
|
| 203 |
-
)
|
| 204 |
|
| 205 |
cumulative_reward = 0.0
|
| 206 |
|
|
@@ -214,25 +205,8 @@ def run_episode(client: OpenAI, task_id: str, episode_num: int) -> float:
|
|
| 214 |
norm_reward = float(reward_payload.get("normalized", (raw_reward + 32) / 52))
|
| 215 |
cumulative_reward += raw_reward
|
| 216 |
|
| 217 |
-
|
| 218 |
-
|
| 219 |
-
{
|
| 220 |
-
"event": "STEP",
|
| 221 |
-
"episode": episode_num,
|
| 222 |
-
"step": step_num,
|
| 223 |
-
"action": action,
|
| 224 |
-
"action_name": ACTION_MEANINGS.get(action, "unknown"),
|
| 225 |
-
"reason": reason,
|
| 226 |
-
"reward": round(raw_reward, 4),
|
| 227 |
-
"normalized_reward": round(norm_reward, 4),
|
| 228 |
-
"cumulative_reward": round(cumulative_reward, 4),
|
| 229 |
-
"changed": result.get("info", {}).get("changed", False),
|
| 230 |
-
"reward_components": reward_payload.get("components", {}),
|
| 231 |
-
"done": result.get("done", False),
|
| 232 |
-
}
|
| 233 |
-
),
|
| 234 |
-
flush=True,
|
| 235 |
-
)
|
| 236 |
|
| 237 |
if result.get("done") or result.get("terminated") or result.get("truncated"):
|
| 238 |
break
|
|
@@ -240,21 +214,8 @@ def run_episode(client: OpenAI, task_id: str, episode_num: int) -> float:
|
|
| 240 |
final_state = get_state()
|
| 241 |
task_score = grade(task_id, final_state.get("current_code", ""))
|
| 242 |
|
| 243 |
-
|
| 244 |
-
|
| 245 |
-
{
|
| 246 |
-
"event": "END",
|
| 247 |
-
"episode": episode_num,
|
| 248 |
-
"task_id": task_id,
|
| 249 |
-
"cumulative_reward": round(cumulative_reward, 4),
|
| 250 |
-
"normalized_cumulative": round((cumulative_reward + 32) / 52, 4),
|
| 251 |
-
"task_score": round(task_score, 4),
|
| 252 |
-
"final_complexity": final_state.get("complexity", 0),
|
| 253 |
-
"timestamp": time.time(),
|
| 254 |
-
}
|
| 255 |
-
),
|
| 256 |
-
flush=True,
|
| 257 |
-
)
|
| 258 |
|
| 259 |
return task_score
|
| 260 |
|
|
@@ -263,7 +224,9 @@ def main() -> None:
|
|
| 263 |
if not ENV_URL:
|
| 264 |
raise SystemExit("ENV_URL is required. Example: ENV_URL=http://localhost:7860")
|
| 265 |
|
| 266 |
-
client
|
|
|
|
|
|
|
| 267 |
|
| 268 |
scores: List[float] = []
|
| 269 |
for i, task_id in enumerate(TASKS, start=1):
|
|
|
|
| 1 |
"""
|
| 2 |
ACRE inference script for OpenEnv submission evaluation.
|
| 3 |
|
| 4 |
+
Environment variables:
|
| 5 |
+
- API_BASE_URL: LLM API endpoint (default allowed)
|
| 6 |
+
- MODEL_NAME: model identifier (default allowed)
|
| 7 |
+
- HF_TOKEN: API token for the OpenAI-compatible endpoint (NO default)
|
| 8 |
+
- ENV_URL: running ACRE server base URL (required)
|
| 9 |
+
- LOCAL_IMAGE_NAME: present for evaluator compatibility (optional)
|
| 10 |
+
- USE_LLM: set to "1" to enable LLM action selection when HF_TOKEN is set
|
| 11 |
+
|
| 12 |
+
STRICT stdout format (do not change):
|
| 13 |
+
START <task_id>
|
| 14 |
+
STEP <action_int>
|
| 15 |
+
END <score_float>
|
| 16 |
"""
|
| 17 |
from __future__ import annotations
|
| 18 |
|
|
|
|
| 21 |
import re
|
| 22 |
import sys
|
| 23 |
import time
|
| 24 |
+
from typing import Dict, List, Optional, Tuple
|
| 25 |
|
| 26 |
import requests
|
| 27 |
from openai import OpenAI
|
|
|
|
| 96 |
return float(response.json().get("score", 0.0))
|
| 97 |
|
| 98 |
|
| 99 |
+
def choose_action(client: Optional[OpenAI], state: dict, task_id: str) -> Tuple[int, str]:
|
| 100 |
def heuristic_action() -> Tuple[int, str]:
|
| 101 |
code = str(state.get("current_code", ""))
|
| 102 |
step_i = int(state.get("episode_steps", 0))
|
|
|
|
| 142 |
return 1, "heuristic: remove remaining dead code"
|
| 143 |
return 3, "heuristic: condition optimization as safe default"
|
| 144 |
|
| 145 |
+
use_llm = bool(HF_TOKEN) and os.getenv("USE_LLM", "0") == "1"
|
| 146 |
+
if (not use_llm) or client is None:
|
| 147 |
return heuristic_action()
|
| 148 |
|
| 149 |
messages = [
|
|
|
|
| 186 |
return heuristic_action()
|
| 187 |
|
| 188 |
|
| 189 |
+
def run_episode(client: Optional[OpenAI], task_id: str, episode_num: int) -> float:
|
| 190 |
reset_env(task_id)
|
| 191 |
state = get_state()
|
| 192 |
|
| 193 |
+
# STRICT logging format required by evaluator.
|
| 194 |
+
print(f"START {task_id}", flush=True)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 195 |
|
| 196 |
cumulative_reward = 0.0
|
| 197 |
|
|
|
|
| 205 |
norm_reward = float(reward_payload.get("normalized", (raw_reward + 32) / 52))
|
| 206 |
cumulative_reward += raw_reward
|
| 207 |
|
| 208 |
+
# STRICT logging format required by evaluator.
|
| 209 |
+
print(f"STEP {int(action)}", flush=True)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 210 |
|
| 211 |
if result.get("done") or result.get("terminated") or result.get("truncated"):
|
| 212 |
break
|
|
|
|
| 214 |
final_state = get_state()
|
| 215 |
task_score = grade(task_id, final_state.get("current_code", ""))
|
| 216 |
|
| 217 |
+
# STRICT logging format required by evaluator.
|
| 218 |
+
print(f"END {task_score:.4f}", flush=True)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 219 |
|
| 220 |
return task_score
|
| 221 |
|
|
|
|
| 224 |
if not ENV_URL:
|
| 225 |
raise SystemExit("ENV_URL is required. Example: ENV_URL=http://localhost:7860")
|
| 226 |
|
| 227 |
+
client: Optional[OpenAI] = None
|
| 228 |
+
if HF_TOKEN and os.getenv("USE_LLM", "0") == "1":
|
| 229 |
+
client = OpenAI(base_url=API_BASE_URL, api_key=HF_TOKEN)
|
| 230 |
|
| 231 |
scores: List[float] = []
|
| 232 |
for i, task_id in enumerate(TASKS, start=1):
|
openenv_interface.py
CHANGED
|
@@ -62,7 +62,23 @@ class OpenEnvRefactorEnv(OpenEnvBase):
|
|
| 62 |
task = self._registry.get_task(task_id)
|
| 63 |
if task is None:
|
| 64 |
raise ValueError(f"Task '{task_id}' not found")
|
| 65 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 66 |
|
| 67 |
if initial_code is None:
|
| 68 |
return None
|
|
|
|
| 62 |
task = self._registry.get_task(task_id)
|
| 63 |
if task is None:
|
| 64 |
raise ValueError(f"Task '{task_id}' not found")
|
| 65 |
+
# Load a multi-sample dataset for this task. Sample selection is
|
| 66 |
+
# deterministic given the `seed` passed to `reset()`.
|
| 67 |
+
samples = list(getattr(task, "samples", []) or [])
|
| 68 |
+
if not samples:
|
| 69 |
+
initial_code = task.initial_code
|
| 70 |
+
else:
|
| 71 |
+
self._env.dataset = CodeSampleDataset(
|
| 72 |
+
[
|
| 73 |
+
CodeSample(
|
| 74 |
+
id=f"{task_id}:{i}",
|
| 75 |
+
language="python",
|
| 76 |
+
code=str(src),
|
| 77 |
+
)
|
| 78 |
+
for i, src in enumerate(samples)
|
| 79 |
+
]
|
| 80 |
+
)
|
| 81 |
+
return None
|
| 82 |
|
| 83 |
if initial_code is None:
|
| 84 |
return None
|
validate.py
CHANGED
|
@@ -204,8 +204,16 @@ def run_validation(base_url: str) -> int:
|
|
| 204 |
try:
|
| 205 |
inference_src = read_text("inference.py")
|
| 206 |
failures += 0 if check("inference.py exists", True) else 1
|
| 207 |
-
|
| 208 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 209 |
failures += 0 if check(
|
| 210 |
"Uses OpenAI client",
|
| 211 |
"from openai import OpenAI" in inference_src,
|
|
|
|
| 204 |
try:
|
| 205 |
inference_src = read_text("inference.py")
|
| 206 |
failures += 0 if check("inference.py exists", True) else 1
|
| 207 |
+
# Accept either the older JSON event markers or the strict hackathon
|
| 208 |
+
# line-based format:
|
| 209 |
+
# START <task_id>
|
| 210 |
+
# STEP <action>
|
| 211 |
+
# END <score>
|
| 212 |
+
json_markers_ok = all(m in inference_src for m in ['"event": "START"', '"event": "STEP"', '"event": "END"'])
|
| 213 |
+
line_markers_ok = all(m in inference_src for m in ["START ", "STEP ", "END "])
|
| 214 |
+
failures += 0 if check("inference.py emits START marker", json_markers_ok or line_markers_ok) else 1
|
| 215 |
+
failures += 0 if check("inference.py emits STEP marker", json_markers_ok or line_markers_ok) else 1
|
| 216 |
+
failures += 0 if check("inference.py emits END marker", json_markers_ok or line_markers_ok) else 1
|
| 217 |
failures += 0 if check(
|
| 218 |
"Uses OpenAI client",
|
| 219 |
"from openai import OpenAI" in inference_src,
|