Commit ·
bc5030f
1
Parent(s): 8422246
Fixed structure (moved files to root)
Browse files- .gitignore +26 -0
- Dockerfile +23 -0
- README.md +168 -4
- acre/__init__.py +14 -0
- acre/actions/__init__.py +6 -0
- acre/actions/transformations.py +518 -0
- acre/datasets/__init__.py +6 -0
- acre/datasets/code_samples.py +34 -0
- acre/demo.py +185 -0
- acre/main.py +39 -0
- acre/tasks/__init__.py +3 -0
- acre/tasks/task_registry.py +222 -0
- acre/training/__init__.py +6 -0
- acre/training/train_agent.py +75 -0
- acre/utils/__init__.py +6 -0
- acre/utils/metrics.py +33 -0
- inference.py +278 -0
- models.py +156 -0
- openenv.yaml +85 -0
- openenv_interface.py +116 -0
- requirements.txt +11 -0
- server.py +667 -0
- validate.py +281 -0
.gitignore
ADDED
|
@@ -0,0 +1,26 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
__pycache__/
|
| 2 |
+
*.pyc
|
| 3 |
+
*.pyo
|
| 4 |
+
*.pyd
|
| 5 |
+
.Python
|
| 6 |
+
*.egg-info/
|
| 7 |
+
dist/
|
| 8 |
+
build/
|
| 9 |
+
.env
|
| 10 |
+
.venv
|
| 11 |
+
venv/
|
| 12 |
+
*.zip
|
| 13 |
+
acre_agent.zip
|
| 14 |
+
*.log
|
| 15 |
+
.DS_Store
|
| 16 |
+
.deps/
|
| 17 |
+
libs/
|
| 18 |
+
numpy.libs/
|
| 19 |
+
*.dll
|
| 20 |
+
*.so
|
| 21 |
+
*.dylib
|
| 22 |
+
env/
|
| 23 |
+
ENV/
|
| 24 |
+
.cache/
|
| 25 |
+
.huggingface/
|
| 26 |
+
Thumbs.db
|
Dockerfile
ADDED
|
@@ -0,0 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
FROM python:3.11-slim
|
| 2 |
+
|
| 3 |
+
WORKDIR /app
|
| 4 |
+
|
| 5 |
+
RUN apt-get update && apt-get install -y --no-install-recommends \
|
| 6 |
+
build-essential \
|
| 7 |
+
&& rm -rf /var/lib/apt/lists/*
|
| 8 |
+
|
| 9 |
+
COPY requirements.txt .
|
| 10 |
+
RUN pip install --no-cache-dir -r requirements.txt
|
| 11 |
+
|
| 12 |
+
COPY . .
|
| 13 |
+
|
| 14 |
+
ENV API_BASE_URL=https://api.openai.com/v1
|
| 15 |
+
ENV MODEL_NAME=gpt-4o-mini
|
| 16 |
+
ENV PORT=7860
|
| 17 |
+
|
| 18 |
+
EXPOSE 7860
|
| 19 |
+
|
| 20 |
+
HEALTHCHECK --interval=30s --timeout=10s --start-period=5s --retries=3 \
|
| 21 |
+
CMD python -c "import requests; requests.get('http://localhost:7860/').raise_for_status()"
|
| 22 |
+
|
| 23 |
+
CMD ["python", "server.py"]
|
README.md
CHANGED
|
@@ -1,10 +1,174 @@
|
|
| 1 |
---
|
| 2 |
-
title: Autonomous Code Refactoring
|
| 3 |
-
emoji: ⚡
|
| 4 |
colorFrom: blue
|
| 5 |
-
colorTo:
|
| 6 |
sdk: docker
|
|
|
|
| 7 |
pinned: false
|
|
|
|
|
|
|
|
|
|
| 8 |
---
|
| 9 |
|
| 10 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
---
|
| 2 |
+
title: ACRE - Autonomous Code Refactoring Environment
|
|
|
|
| 3 |
colorFrom: blue
|
| 4 |
+
colorTo: green
|
| 5 |
sdk: docker
|
| 6 |
+
app_port: 7860
|
| 7 |
pinned: false
|
| 8 |
+
license: mit
|
| 9 |
+
tags:
|
| 10 |
+
- openenv
|
| 11 |
---
|
| 12 |
|
| 13 |
+
# ACRE - Autonomous Code Refactoring Environment
|
| 14 |
+
|
| 15 |
+
ACRE is an OpenEnv-compatible environment for autonomous Python code refactoring. An agent receives real code-cleanup tasks and must improve the code through AST-based transformations while receiving dense reward feedback for correctness, simplification, and performance.
|
| 16 |
+
|
| 17 |
+
## Environment Overview and Motivation
|
| 18 |
+
|
| 19 |
+
This project simulates a realistic developer workflow: cleaning up messy Python code, removing dead logic, simplifying loops, and inlining trivial helpers. The canonical OpenEnv wrapper lives in `openenv_interface.py`, while the original Gymnasium-compatible environment remains available for RL training and demos.
|
| 20 |
+
|
| 21 |
+
## Definitions of Action and Observation Spaces
|
| 22 |
+
|
| 23 |
+
### Action Space - Discrete(5)
|
| 24 |
+
|
| 25 |
+
| Action | Name | Description |
|
| 26 |
+
|---|---|---|
|
| 27 |
+
| 0 | rename_variable | Rename generic variables like `x`, `tmp`, and `i` |
|
| 28 |
+
| 1 | remove_dead_code | Remove unreachable statements, `if False` branches, and unused assignments |
|
| 29 |
+
| 2 | simplify_loop | Convert append-loops into list comprehensions |
|
| 30 |
+
| 3 | optimize_condition | Simplify `not not x`, `if True`, `if False`, and boolean comparisons |
|
| 31 |
+
| 4 | inline_function | Inline simple single-return module-level functions |
|
| 32 |
+
|
| 33 |
+
### Observation Space - Box(4,)
|
| 34 |
+
|
| 35 |
+
The environment tracks:
|
| 36 |
+
|
| 37 |
+
- `code_length`
|
| 38 |
+
- `complexity_score`
|
| 39 |
+
- `runtime_s`
|
| 40 |
+
- `error_flag`
|
| 41 |
+
|
| 42 |
+
### Typed OpenEnv Models
|
| 43 |
+
|
| 44 |
+
The submission-facing interface uses Pydantic models in `models.py`:
|
| 45 |
+
|
| 46 |
+
- `ObservationModel`
|
| 47 |
+
- `ActionModel`
|
| 48 |
+
- `RewardModel`
|
| 49 |
+
- `StateResponse`
|
| 50 |
+
|
| 51 |
+
The canonical interface is:
|
| 52 |
+
|
| 53 |
+
```python
|
| 54 |
+
observation = env.reset(...)
|
| 55 |
+
observation, reward, done, info = env.step(action)
|
| 56 |
+
state = env.state()
|
| 57 |
+
```
|
| 58 |
+
|
| 59 |
+
## Task Descriptions with Expected Difficulty Levels
|
| 60 |
+
|
| 61 |
+
| Task ID | Difficulty | Objective |
|
| 62 |
+
|---|---|---|
|
| 63 |
+
| `rename_variables` | Easy | Remove generic variable names from the snippet |
|
| 64 |
+
| `remove_dead_code` | Medium | Eliminate dead branches, unreachable code, and unused assignments |
|
| 65 |
+
| `full_refactor` | Hard | Combine renaming, dead-code removal, loop simplification, condition optimization, and inlining |
|
| 66 |
+
|
| 67 |
+
Each task includes a deterministic AST-based grader returning a score in `[0.0, 1.0]`.
|
| 68 |
+
|
| 69 |
+
## Reward Design
|
| 70 |
+
|
| 71 |
+
Rewards are shaped throughout the trajectory instead of only at the end.
|
| 72 |
+
|
| 73 |
+
- Success reward for syntactically valid, executable output
|
| 74 |
+
- Complexity reward when control-flow complexity decreases
|
| 75 |
+
- Performance reward when runtime improves
|
| 76 |
+
- Error penalty for invalid or failing code
|
| 77 |
+
- No-change penalty to discourage loops and unproductive actions
|
| 78 |
+
|
| 79 |
+
Raw reward range is `[-32, 20]`, normalized to `[0.0, 1.0]` with `(raw + 32) / 52`.
|
| 80 |
+
|
| 81 |
+
## HTTP API
|
| 82 |
+
|
| 83 |
+
| Method | Path | Purpose |
|
| 84 |
+
|---|---|---|
|
| 85 |
+
| GET | `/` | Health check |
|
| 86 |
+
| GET | `/health` | Compatibility health check |
|
| 87 |
+
| POST | `/reset` | Reset environment and return typed observation/state |
|
| 88 |
+
| POST | `/step` | Apply one action and return typed observation/reward/done |
|
| 89 |
+
| GET | `/state` | Return the current typed state |
|
| 90 |
+
| GET | `/tasks` | List available tasks |
|
| 91 |
+
| POST | `/tasks/{task_id}/grade` | Grade submitted code |
|
| 92 |
+
|
| 93 |
+
## Setup and Usage Instructions
|
| 94 |
+
|
| 95 |
+
### Local setup
|
| 96 |
+
|
| 97 |
+
```bash
|
| 98 |
+
pip install -r requirements.txt
|
| 99 |
+
python server.py
|
| 100 |
+
```
|
| 101 |
+
|
| 102 |
+
### Baseline inference
|
| 103 |
+
|
| 104 |
+
Set environment variables before running:
|
| 105 |
+
|
| 106 |
+
```bash
|
| 107 |
+
export API_BASE_URL=https://api.openai.com/v1
|
| 108 |
+
export MODEL_NAME=gpt-4o-mini
|
| 109 |
+
export HF_TOKEN=your_key
|
| 110 |
+
export ENV_URL=http://localhost:7860
|
| 111 |
+
python inference.py
|
| 112 |
+
```
|
| 113 |
+
|
| 114 |
+
Notes:
|
| 115 |
+
|
| 116 |
+
- `API_BASE_URL` and `MODEL_NAME` have defaults in `inference.py`
|
| 117 |
+
- `HF_TOKEN` is optional because the script falls back to a deterministic heuristic baseline
|
| 118 |
+
- `LOCAL_IMAGE_NAME` is read for evaluator compatibility when using a local Docker image launcher
|
| 119 |
+
|
| 120 |
+
### Docker / Hugging Face Spaces
|
| 121 |
+
|
| 122 |
+
```bash
|
| 123 |
+
docker build -t acre .
|
| 124 |
+
docker run -p 7860:7860 \
|
| 125 |
+
-e API_BASE_URL=https://api.openai.com/v1 \
|
| 126 |
+
-e MODEL_NAME=gpt-4o-mini \
|
| 127 |
+
-e HF_TOKEN=your_key \
|
| 128 |
+
-e ENV_URL=http://localhost:7860 \
|
| 129 |
+
acre
|
| 130 |
+
```
|
| 131 |
+
|
| 132 |
+
The repository is configured for a Docker-based Hugging Face Space and includes the `openenv` tag in the front matter.
|
| 133 |
+
|
| 134 |
+
## Validation
|
| 135 |
+
|
| 136 |
+
Run the repository validator:
|
| 137 |
+
|
| 138 |
+
```bash
|
| 139 |
+
python validate.py --url http://localhost:7860
|
| 140 |
+
```
|
| 141 |
+
|
| 142 |
+
When using the official hackathon tooling, also run:
|
| 143 |
+
|
| 144 |
+
```bash
|
| 145 |
+
openenv validate
|
| 146 |
+
```
|
| 147 |
+
|
| 148 |
+
## Interactive Demo
|
| 149 |
+
|
| 150 |
+
Start the server and open:
|
| 151 |
+
|
| 152 |
+
```text
|
| 153 |
+
http://localhost:7860/demo
|
| 154 |
+
```
|
| 155 |
+
|
| 156 |
+
The demo shows:
|
| 157 |
+
|
| 158 |
+
- Original code
|
| 159 |
+
- Optimized code
|
| 160 |
+
- Unified diff
|
| 161 |
+
- Per-step action and reward logs
|
| 162 |
+
|
| 163 |
+
## Baseline Performance Scores
|
| 164 |
+
|
| 165 |
+
The deterministic fallback policy used by `inference.py` produces the following reproducible task scores:
|
| 166 |
+
|
| 167 |
+
| Task | Score |
|
| 168 |
+
|---|---|
|
| 169 |
+
| `rename_variables` | 1.0 |
|
| 170 |
+
| `remove_dead_code` | 1.0 |
|
| 171 |
+
| `full_refactor` | 1.0 |
|
| 172 |
+
| Average | 1.0 |
|
| 173 |
+
|
| 174 |
+
These scores come from the built-in heuristic policy with `HF_TOKEN` unset, which keeps the baseline reproducible across runs.
|
acre/__init__.py
ADDED
|
@@ -0,0 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
ACRE (Autonomous Code Refactoring Environment).
|
| 3 |
+
|
| 4 |
+
Package skeleton for an RL-based code refactoring system.
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
__all__ = [
|
| 8 |
+
"env",
|
| 9 |
+
"actions",
|
| 10 |
+
"datasets",
|
| 11 |
+
"training",
|
| 12 |
+
"utils",
|
| 13 |
+
]
|
| 14 |
+
|
acre/actions/__init__.py
ADDED
|
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Action definitions and transformations for ACRE."""
|
| 2 |
+
|
| 3 |
+
from .transformations import Transformation, TransformationResult
|
| 4 |
+
|
| 5 |
+
__all__ = ["Transformation", "TransformationResult"]
|
| 6 |
+
|
acre/actions/transformations.py
ADDED
|
@@ -0,0 +1,518 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
import ast
|
| 4 |
+
import copy
|
| 5 |
+
from dataclasses import dataclass
|
| 6 |
+
from itertools import zip_longest
|
| 7 |
+
from typing import Any, Dict, Protocol
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
@dataclass(frozen=True)
|
| 11 |
+
class TransformationResult:
|
| 12 |
+
"""Output of applying a transformation (placeholder)."""
|
| 13 |
+
|
| 14 |
+
code: str
|
| 15 |
+
changed: bool
|
| 16 |
+
metadata: Dict[str, Any]
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
class Transformation(Protocol):
|
| 20 |
+
"""Protocol for a code transformation."""
|
| 21 |
+
|
| 22 |
+
name: str
|
| 23 |
+
|
| 24 |
+
def apply(self, code: str) -> TransformationResult: ...
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
def noop_transformation(code: str) -> TransformationResult:
|
| 28 |
+
"""Baseline transformation that leaves code unchanged."""
|
| 29 |
+
return TransformationResult(code=code, changed=False, metadata={"kind": "noop"})
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
def _finalize_result(*, original: str, out: str, meta: Dict[str, Any]) -> TransformationResult:
|
| 33 |
+
"""
|
| 34 |
+
Standardize metadata across transformations.
|
| 35 |
+
|
| 36 |
+
- Adds `lines_changed` and `impact` for explainability/metrics.
|
| 37 |
+
- Ensures formatting-only changes don't count as `changed`.
|
| 38 |
+
"""
|
| 39 |
+
|
| 40 |
+
def _count_lines_changed(a: str, b: str) -> int:
|
| 41 |
+
a_lines = a.splitlines()
|
| 42 |
+
b_lines = b.splitlines()
|
| 43 |
+
changed = 0
|
| 44 |
+
for x, y in zip_longest(a_lines, b_lines, fillvalue=None):
|
| 45 |
+
if x != y:
|
| 46 |
+
changed += 1
|
| 47 |
+
return int(changed)
|
| 48 |
+
|
| 49 |
+
lines_changed = _count_lines_changed(original, out)
|
| 50 |
+
|
| 51 |
+
# Fallback identity check: AST round-trips can reformat without changing meaning.
|
| 52 |
+
# If the textual content is the same after stripping, treat it as unchanged.
|
| 53 |
+
if out.strip() == original.strip():
|
| 54 |
+
meta["success"] = False
|
| 55 |
+
meta["lines_changed"] = 0
|
| 56 |
+
meta["impact"] = "low"
|
| 57 |
+
return TransformationResult(code=original, changed=False, metadata=meta)
|
| 58 |
+
|
| 59 |
+
meta["lines_changed"] = lines_changed
|
| 60 |
+
meta["impact"] = "high" if lines_changed >= 3 else "low"
|
| 61 |
+
meta["success"] = True
|
| 62 |
+
return TransformationResult(code=out, changed=True, metadata=meta)
|
| 63 |
+
|
| 64 |
+
|
| 65 |
+
def _unchanged(*, code: str, meta: Dict[str, Any]) -> TransformationResult:
|
| 66 |
+
meta.setdefault("success", False)
|
| 67 |
+
meta.setdefault("lines_changed", 0)
|
| 68 |
+
meta.setdefault("impact", "low")
|
| 69 |
+
return TransformationResult(code=code, changed=False, metadata=meta)
|
| 70 |
+
|
| 71 |
+
|
| 72 |
+
def rename_variable(code: str) -> TransformationResult:
|
| 73 |
+
"""
|
| 74 |
+
Rename simple, generic variable names to more descriptive ones.
|
| 75 |
+
|
| 76 |
+
Hackathon-scope heuristic:
|
| 77 |
+
- Rename generic names in priority order: x, tmp, i.
|
| 78 |
+
- Uses descriptive base names and avoids collisions.
|
| 79 |
+
- Applies to Name nodes and function args.
|
| 80 |
+
"""
|
| 81 |
+
meta: Dict[str, Any] = {"type": "rename_variable", "success": False}
|
| 82 |
+
try:
|
| 83 |
+
tree = ast.parse(code)
|
| 84 |
+
|
| 85 |
+
class _NameCollector(ast.NodeVisitor):
|
| 86 |
+
def __init__(self) -> None:
|
| 87 |
+
self.names: set[str] = set()
|
| 88 |
+
|
| 89 |
+
def visit_Name(self, node: ast.Name) -> None: # noqa: N802
|
| 90 |
+
self.names.add(node.id)
|
| 91 |
+
|
| 92 |
+
def visit_arg(self, node: ast.arg) -> None: # noqa: N802
|
| 93 |
+
self.names.add(node.arg)
|
| 94 |
+
|
| 95 |
+
collector = _NameCollector()
|
| 96 |
+
collector.visit(tree)
|
| 97 |
+
|
| 98 |
+
rename_plan = [
|
| 99 |
+
("x", "value"),
|
| 100 |
+
("tmp", "temp_value"),
|
| 101 |
+
("i", "index"),
|
| 102 |
+
]
|
| 103 |
+
|
| 104 |
+
old = ""
|
| 105 |
+
base_new = "value"
|
| 106 |
+
for candidate_old, candidate_base in rename_plan:
|
| 107 |
+
if candidate_old in collector.names:
|
| 108 |
+
old = candidate_old
|
| 109 |
+
base_new = candidate_base
|
| 110 |
+
break
|
| 111 |
+
|
| 112 |
+
if not old:
|
| 113 |
+
return _unchanged(code=code, meta=meta)
|
| 114 |
+
|
| 115 |
+
new = base_new
|
| 116 |
+
i = 1
|
| 117 |
+
while new in collector.names:
|
| 118 |
+
new = f"{base_new}{i}"
|
| 119 |
+
i += 1
|
| 120 |
+
|
| 121 |
+
class _Renamer(ast.NodeTransformer):
|
| 122 |
+
def __init__(self, old_name: str, new_name: str) -> None:
|
| 123 |
+
self.old_name = old_name
|
| 124 |
+
self.new_name = new_name
|
| 125 |
+
self.changed = False
|
| 126 |
+
|
| 127 |
+
def visit_Name(self, node: ast.Name) -> ast.AST: # noqa: N802
|
| 128 |
+
if node.id == self.old_name:
|
| 129 |
+
self.changed = True
|
| 130 |
+
return ast.copy_location(ast.Name(id=self.new_name, ctx=node.ctx), node)
|
| 131 |
+
return node
|
| 132 |
+
|
| 133 |
+
def visit_arg(self, node: ast.arg) -> ast.AST: # noqa: N802
|
| 134 |
+
if node.arg == self.old_name:
|
| 135 |
+
self.changed = True
|
| 136 |
+
new_node = copy.copy(node)
|
| 137 |
+
new_node.arg = self.new_name
|
| 138 |
+
return new_node
|
| 139 |
+
return node
|
| 140 |
+
|
| 141 |
+
renamer = _Renamer(old, new)
|
| 142 |
+
tree = renamer.visit(tree)
|
| 143 |
+
ast.fix_missing_locations(tree)
|
| 144 |
+
|
| 145 |
+
if not renamer.changed:
|
| 146 |
+
return _unchanged(code=code, meta=meta)
|
| 147 |
+
|
| 148 |
+
out = ast.unparse(tree)
|
| 149 |
+
meta["old"] = old
|
| 150 |
+
meta["new"] = new
|
| 151 |
+
# Renames tend to be small diffs; label as low impact unless the diff is large.
|
| 152 |
+
return _finalize_result(original=code, out=out, meta=meta)
|
| 153 |
+
except Exception:
|
| 154 |
+
return _unchanged(code=code, meta=meta)
|
| 155 |
+
|
| 156 |
+
|
| 157 |
+
def remove_dead_code(code: str) -> TransformationResult:
|
| 158 |
+
"""
|
| 159 |
+
Remove simple dead code patterns.
|
| 160 |
+
|
| 161 |
+
Hackathon-scope heuristics:
|
| 162 |
+
- Drop statements after `return` / `raise` in the same block.
|
| 163 |
+
- Remove `if False: ...` blocks (keep `else` if present).
|
| 164 |
+
- Remove assignments to unused names in a block (very simple check).
|
| 165 |
+
"""
|
| 166 |
+
meta: Dict[str, Any] = {"type": "remove_dead_code", "success": False}
|
| 167 |
+
|
| 168 |
+
try:
|
| 169 |
+
tree = ast.parse(code)
|
| 170 |
+
|
| 171 |
+
def _is_const_bool(expr: ast.AST, value: bool) -> bool:
|
| 172 |
+
return isinstance(expr, ast.Constant) and isinstance(expr.value, bool) and expr.value is value
|
| 173 |
+
|
| 174 |
+
class _LoadNameCollector(ast.NodeVisitor):
|
| 175 |
+
def __init__(self) -> None:
|
| 176 |
+
self.loaded: set[str] = set()
|
| 177 |
+
|
| 178 |
+
def visit_Name(self, node: ast.Name) -> None: # noqa: N802
|
| 179 |
+
if isinstance(node.ctx, ast.Load):
|
| 180 |
+
self.loaded.add(node.id)
|
| 181 |
+
|
| 182 |
+
class _DeadCode(ast.NodeTransformer):
|
| 183 |
+
def __init__(self) -> None:
|
| 184 |
+
self.changed = False
|
| 185 |
+
|
| 186 |
+
def _prune_unreachable(self, stmts: list[ast.stmt]) -> list[ast.stmt]:
|
| 187 |
+
out: list[ast.stmt] = []
|
| 188 |
+
unreachable = False
|
| 189 |
+
for s in stmts:
|
| 190 |
+
if unreachable:
|
| 191 |
+
self.changed = True
|
| 192 |
+
continue
|
| 193 |
+
out.append(s)
|
| 194 |
+
if isinstance(s, (ast.Return, ast.Raise)):
|
| 195 |
+
unreachable = True
|
| 196 |
+
return out
|
| 197 |
+
|
| 198 |
+
def _remove_unused_assigns(self, stmts: list[ast.stmt]) -> list[ast.stmt]:
|
| 199 |
+
collector = _LoadNameCollector()
|
| 200 |
+
for s in stmts:
|
| 201 |
+
collector.visit(s)
|
| 202 |
+
used = collector.loaded
|
| 203 |
+
|
| 204 |
+
out: list[ast.stmt] = []
|
| 205 |
+
for s in stmts:
|
| 206 |
+
if isinstance(s, ast.Assign) and all(isinstance(t, ast.Name) for t in s.targets):
|
| 207 |
+
targets = [t.id for t in s.targets if isinstance(t, ast.Name)]
|
| 208 |
+
# Remove only if *all* assigned names are unused.
|
| 209 |
+
if targets and all(t not in used for t in targets):
|
| 210 |
+
self.changed = True
|
| 211 |
+
continue
|
| 212 |
+
if isinstance(s, ast.AnnAssign) and isinstance(s.target, ast.Name):
|
| 213 |
+
if s.target.id not in used:
|
| 214 |
+
self.changed = True
|
| 215 |
+
continue
|
| 216 |
+
out.append(s)
|
| 217 |
+
return out
|
| 218 |
+
|
| 219 |
+
def _clean_block(self, stmts: list[ast.stmt]) -> list[ast.stmt]:
|
| 220 |
+
# First apply transformations inside statements.
|
| 221 |
+
visited = [self.visit(s) for s in stmts]
|
| 222 |
+
flat: list[ast.stmt] = []
|
| 223 |
+
for s in visited:
|
| 224 |
+
if s is None:
|
| 225 |
+
self.changed = True
|
| 226 |
+
continue
|
| 227 |
+
if isinstance(s, list):
|
| 228 |
+
flat.extend([x for x in s if isinstance(x, ast.stmt)])
|
| 229 |
+
self.changed = True
|
| 230 |
+
else:
|
| 231 |
+
flat.append(s)
|
| 232 |
+
|
| 233 |
+
flat = self._prune_unreachable(flat)
|
| 234 |
+
flat = self._remove_unused_assigns(flat)
|
| 235 |
+
return flat
|
| 236 |
+
|
| 237 |
+
def visit_Module(self, node: ast.Module) -> ast.AST: # noqa: N802
|
| 238 |
+
node.body = self._clean_block(node.body)
|
| 239 |
+
return node
|
| 240 |
+
|
| 241 |
+
def visit_FunctionDef(self, node: ast.FunctionDef) -> ast.AST: # noqa: N802
|
| 242 |
+
node.body = self._clean_block(node.body)
|
| 243 |
+
return node
|
| 244 |
+
|
| 245 |
+
def visit_AsyncFunctionDef(self, node: ast.AsyncFunctionDef) -> ast.AST: # noqa: N802
|
| 246 |
+
node.body = self._clean_block(node.body)
|
| 247 |
+
return node
|
| 248 |
+
|
| 249 |
+
def visit_If(self, node: ast.If) -> ast.AST | list[ast.stmt]: # noqa: N802
|
| 250 |
+
node = self.generic_visit(node)
|
| 251 |
+
if _is_const_bool(node.test, False):
|
| 252 |
+
self.changed = True
|
| 253 |
+
return node.orelse or []
|
| 254 |
+
return node
|
| 255 |
+
|
| 256 |
+
def visit_While(self, node: ast.While) -> ast.AST | None: # noqa: N802
|
| 257 |
+
node = self.generic_visit(node)
|
| 258 |
+
if _is_const_bool(node.test, False):
|
| 259 |
+
self.changed = True
|
| 260 |
+
return None
|
| 261 |
+
return node
|
| 262 |
+
|
| 263 |
+
dc = _DeadCode()
|
| 264 |
+
tree = dc.visit(tree)
|
| 265 |
+
ast.fix_missing_locations(tree)
|
| 266 |
+
if not dc.changed:
|
| 267 |
+
return _unchanged(code=code, meta=meta)
|
| 268 |
+
|
| 269 |
+
out = ast.unparse(tree)
|
| 270 |
+
return _finalize_result(original=code, out=out, meta=meta)
|
| 271 |
+
except Exception:
|
| 272 |
+
return _unchanged(code=code, meta=meta)
|
| 273 |
+
|
| 274 |
+
|
| 275 |
+
def simplify_loops(code: str) -> TransformationResult:
|
| 276 |
+
"""
|
| 277 |
+
Simplify very basic loop patterns into more pythonic forms.
|
| 278 |
+
|
| 279 |
+
Supported pattern (only when adjacent in the same block):
|
| 280 |
+
- xs = []
|
| 281 |
+
for t in it:
|
| 282 |
+
xs.append(expr)
|
| 283 |
+
=> xs = [expr for t in it]
|
| 284 |
+
"""
|
| 285 |
+
meta: Dict[str, Any] = {"type": "simplify_loops", "success": False}
|
| 286 |
+
try:
|
| 287 |
+
tree = ast.parse(code)
|
| 288 |
+
|
| 289 |
+
class _LoopSimplifier(ast.NodeTransformer):
|
| 290 |
+
def __init__(self) -> None:
|
| 291 |
+
self.changed = False
|
| 292 |
+
|
| 293 |
+
def _simplify_body(self, body: list[ast.stmt]) -> list[ast.stmt]:
|
| 294 |
+
out: list[ast.stmt] = []
|
| 295 |
+
i = 0
|
| 296 |
+
while i < len(body):
|
| 297 |
+
cur = body[i]
|
| 298 |
+
nxt = body[i + 1] if i + 1 < len(body) else None
|
| 299 |
+
|
| 300 |
+
if (
|
| 301 |
+
isinstance(cur, ast.Assign)
|
| 302 |
+
and len(cur.targets) == 1
|
| 303 |
+
and isinstance(cur.targets[0], ast.Name)
|
| 304 |
+
and isinstance(cur.value, ast.List)
|
| 305 |
+
and cur.value.elts == []
|
| 306 |
+
and isinstance(nxt, ast.For)
|
| 307 |
+
and len(nxt.body) == 1
|
| 308 |
+
and isinstance(nxt.body[0], ast.Expr)
|
| 309 |
+
and isinstance(nxt.body[0].value, ast.Call)
|
| 310 |
+
):
|
| 311 |
+
list_name = cur.targets[0].id
|
| 312 |
+
call = nxt.body[0].value
|
| 313 |
+
if (
|
| 314 |
+
isinstance(call.func, ast.Attribute)
|
| 315 |
+
and isinstance(call.func.value, ast.Name)
|
| 316 |
+
and call.func.value.id == list_name
|
| 317 |
+
and call.func.attr == "append"
|
| 318 |
+
and len(call.args) == 1
|
| 319 |
+
and not call.keywords
|
| 320 |
+
):
|
| 321 |
+
# Build list comprehension: [call.args[0] for <target> in <iter>]
|
| 322 |
+
comp = ast.ListComp(
|
| 323 |
+
elt=call.args[0],
|
| 324 |
+
generators=[
|
| 325 |
+
ast.comprehension(
|
| 326 |
+
target=nxt.target,
|
| 327 |
+
iter=nxt.iter,
|
| 328 |
+
ifs=[],
|
| 329 |
+
is_async=0,
|
| 330 |
+
)
|
| 331 |
+
],
|
| 332 |
+
)
|
| 333 |
+
new_assign = ast.Assign(targets=[ast.Name(id=list_name, ctx=ast.Store())], value=comp)
|
| 334 |
+
out.append(ast.copy_location(new_assign, cur))
|
| 335 |
+
self.changed = True
|
| 336 |
+
i += 2
|
| 337 |
+
continue
|
| 338 |
+
|
| 339 |
+
out.append(cur)
|
| 340 |
+
i += 1
|
| 341 |
+
|
| 342 |
+
return out
|
| 343 |
+
|
| 344 |
+
def visit_Module(self, node: ast.Module) -> ast.AST: # noqa: N802
|
| 345 |
+
node = self.generic_visit(node)
|
| 346 |
+
node.body = self._simplify_body(node.body)
|
| 347 |
+
return node
|
| 348 |
+
|
| 349 |
+
def visit_FunctionDef(self, node: ast.FunctionDef) -> ast.AST: # noqa: N802
|
| 350 |
+
node = self.generic_visit(node)
|
| 351 |
+
node.body = self._simplify_body(node.body)
|
| 352 |
+
return node
|
| 353 |
+
|
| 354 |
+
def visit_AsyncFunctionDef(self, node: ast.AsyncFunctionDef) -> ast.AST: # noqa: N802
|
| 355 |
+
node = self.generic_visit(node)
|
| 356 |
+
node.body = self._simplify_body(node.body)
|
| 357 |
+
return node
|
| 358 |
+
|
| 359 |
+
simp = _LoopSimplifier()
|
| 360 |
+
tree = simp.visit(tree)
|
| 361 |
+
ast.fix_missing_locations(tree)
|
| 362 |
+
if not simp.changed:
|
| 363 |
+
return _unchanged(code=code, meta=meta)
|
| 364 |
+
|
| 365 |
+
out = ast.unparse(tree)
|
| 366 |
+
return _finalize_result(original=code, out=out, meta=meta)
|
| 367 |
+
except Exception:
|
| 368 |
+
return _unchanged(code=code, meta=meta)
|
| 369 |
+
|
| 370 |
+
|
| 371 |
+
def simplify_loop(code: str) -> TransformationResult:
|
| 372 |
+
# Backwards-compatible alias for the environment's action mapping.
|
| 373 |
+
return simplify_loops(code)
|
| 374 |
+
|
| 375 |
+
|
| 376 |
+
def optimize_condition(code: str) -> TransformationResult:
|
| 377 |
+
"""
|
| 378 |
+
Simplify redundant boolean conditions.
|
| 379 |
+
|
| 380 |
+
Hackathon-scope heuristics:
|
| 381 |
+
- Replace `if True:` with its body; `if False:` with `else` (if present).
|
| 382 |
+
- Simplify `not not X` -> `X`.
|
| 383 |
+
- Simplify comparisons to True/False: `X == True` -> `X`, `X == False` -> `not X`.
|
| 384 |
+
"""
|
| 385 |
+
meta: Dict[str, Any] = {"type": "optimize_condition", "success": False}
|
| 386 |
+
try:
|
| 387 |
+
tree = ast.parse(code)
|
| 388 |
+
|
| 389 |
+
def _is_bool_const(node: ast.AST, value: bool) -> bool:
|
| 390 |
+
return isinstance(node, ast.Constant) and isinstance(node.value, bool) and node.value is value
|
| 391 |
+
|
| 392 |
+
class _CondOpt(ast.NodeTransformer):
|
| 393 |
+
def __init__(self) -> None:
|
| 394 |
+
self.changed = False
|
| 395 |
+
|
| 396 |
+
def visit_UnaryOp(self, node: ast.UnaryOp) -> ast.AST: # noqa: N802
|
| 397 |
+
node = self.generic_visit(node)
|
| 398 |
+
if isinstance(node.op, ast.Not) and isinstance(node.operand, ast.UnaryOp) and isinstance(node.operand.op, ast.Not):
|
| 399 |
+
self.changed = True
|
| 400 |
+
return node.operand.operand
|
| 401 |
+
return node
|
| 402 |
+
|
| 403 |
+
def visit_Compare(self, node: ast.Compare) -> ast.AST: # noqa: N802
|
| 404 |
+
node = self.generic_visit(node)
|
| 405 |
+
if len(node.ops) == 1 and len(node.comparators) == 1:
|
| 406 |
+
op = node.ops[0]
|
| 407 |
+
rhs = node.comparators[0]
|
| 408 |
+
if isinstance(op, (ast.Eq, ast.Is)) and _is_bool_const(rhs, True):
|
| 409 |
+
self.changed = True
|
| 410 |
+
return node.left
|
| 411 |
+
if isinstance(op, (ast.Eq, ast.Is)) and _is_bool_const(rhs, False):
|
| 412 |
+
self.changed = True
|
| 413 |
+
return ast.UnaryOp(op=ast.Not(), operand=node.left)
|
| 414 |
+
return node
|
| 415 |
+
|
| 416 |
+
def visit_If(self, node: ast.If) -> ast.AST | list[ast.stmt]: # noqa: N802
|
| 417 |
+
node = self.generic_visit(node)
|
| 418 |
+
if _is_bool_const(node.test, True):
|
| 419 |
+
self.changed = True
|
| 420 |
+
return node.body
|
| 421 |
+
if _is_bool_const(node.test, False):
|
| 422 |
+
self.changed = True
|
| 423 |
+
return node.orelse or []
|
| 424 |
+
return node
|
| 425 |
+
|
| 426 |
+
opt = _CondOpt()
|
| 427 |
+
tree = opt.visit(tree)
|
| 428 |
+
ast.fix_missing_locations(tree)
|
| 429 |
+
if not opt.changed:
|
| 430 |
+
return _unchanged(code=code, meta=meta)
|
| 431 |
+
|
| 432 |
+
out = ast.unparse(tree)
|
| 433 |
+
return _finalize_result(original=code, out=out, meta=meta)
|
| 434 |
+
except Exception:
|
| 435 |
+
return _unchanged(code=code, meta=meta)
|
| 436 |
+
|
| 437 |
+
|
| 438 |
+
def inline_function(code: str) -> TransformationResult:
|
| 439 |
+
"""
|
| 440 |
+
Inline very simple functions into their call sites.
|
| 441 |
+
|
| 442 |
+
Supported pattern:
|
| 443 |
+
- def f(a, b): return <expr using only a,b>
|
| 444 |
+
- Replace calls: f(x, y) -> <expr with a->x, b->y>
|
| 445 |
+
Only handles module-level functions and positional args.
|
| 446 |
+
"""
|
| 447 |
+
meta: Dict[str, Any] = {"type": "inline_function", "success": False}
|
| 448 |
+
try:
|
| 449 |
+
tree = ast.parse(code)
|
| 450 |
+
|
| 451 |
+
simple_fns: Dict[str, tuple[list[str], ast.AST]] = {}
|
| 452 |
+
for node in tree.body:
|
| 453 |
+
if not isinstance(node, ast.FunctionDef):
|
| 454 |
+
continue
|
| 455 |
+
if node.decorator_list:
|
| 456 |
+
continue
|
| 457 |
+
args = node.args
|
| 458 |
+
if args.vararg or args.kwarg or args.kwonlyargs or args.defaults or args.posonlyargs:
|
| 459 |
+
continue
|
| 460 |
+
if len(node.body) != 1 or not isinstance(node.body[0], ast.Return) or node.body[0].value is None:
|
| 461 |
+
continue
|
| 462 |
+
arg_names = [a.arg for a in args.args]
|
| 463 |
+
# Ensure the return expression only references the function's args.
|
| 464 |
+
referenced: set[str] = set()
|
| 465 |
+
|
| 466 |
+
class _Ref(ast.NodeVisitor):
|
| 467 |
+
def visit_Name(self, n: ast.Name) -> None: # noqa: N802
|
| 468 |
+
if isinstance(n.ctx, ast.Load):
|
| 469 |
+
referenced.add(n.id)
|
| 470 |
+
|
| 471 |
+
_Ref().visit(node.body[0].value)
|
| 472 |
+
if not referenced.issubset(set(arg_names)):
|
| 473 |
+
continue
|
| 474 |
+
simple_fns[node.name] = (arg_names, node.body[0].value)
|
| 475 |
+
|
| 476 |
+
if not simple_fns:
|
| 477 |
+
return _unchanged(code=code, meta=meta)
|
| 478 |
+
|
| 479 |
+
class _Substitute(ast.NodeTransformer):
|
| 480 |
+
def __init__(self, mapping: Dict[str, ast.AST]) -> None:
|
| 481 |
+
self.mapping = mapping
|
| 482 |
+
|
| 483 |
+
def visit_Name(self, n: ast.Name) -> ast.AST: # noqa: N802
|
| 484 |
+
if isinstance(n.ctx, ast.Load) and n.id in self.mapping:
|
| 485 |
+
return copy.deepcopy(self.mapping[n.id])
|
| 486 |
+
return n
|
| 487 |
+
|
| 488 |
+
class _Inliner(ast.NodeTransformer):
|
| 489 |
+
def __init__(self) -> None:
|
| 490 |
+
self.changed = False
|
| 491 |
+
|
| 492 |
+
def visit_Call(self, node: ast.Call) -> ast.AST: # noqa: N802
|
| 493 |
+
node = self.generic_visit(node)
|
| 494 |
+
if not isinstance(node.func, ast.Name):
|
| 495 |
+
return node
|
| 496 |
+
fn = simple_fns.get(node.func.id)
|
| 497 |
+
if fn is None:
|
| 498 |
+
return node
|
| 499 |
+
arg_names, expr = fn
|
| 500 |
+
if node.keywords or len(node.args) != len(arg_names):
|
| 501 |
+
return node
|
| 502 |
+
mapping = {name: arg for name, arg in zip(arg_names, node.args, strict=True)}
|
| 503 |
+
new_expr = _Substitute(mapping).visit(copy.deepcopy(expr))
|
| 504 |
+
self.changed = True
|
| 505 |
+
return ast.copy_location(new_expr, node)
|
| 506 |
+
|
| 507 |
+
inliner = _Inliner()
|
| 508 |
+
tree = inliner.visit(tree)
|
| 509 |
+
ast.fix_missing_locations(tree)
|
| 510 |
+
if not inliner.changed:
|
| 511 |
+
return _unchanged(code=code, meta=meta)
|
| 512 |
+
|
| 513 |
+
out = ast.unparse(tree)
|
| 514 |
+
meta["inlined"] = sorted(simple_fns.keys())
|
| 515 |
+
return _finalize_result(original=code, out=out, meta=meta)
|
| 516 |
+
except Exception:
|
| 517 |
+
return _unchanged(code=code, meta=meta)
|
| 518 |
+
|
acre/datasets/__init__.py
ADDED
|
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Datasets and sample code providers for ACRE."""
|
| 2 |
+
|
| 3 |
+
from .code_samples import CodeSample, CodeSampleDataset
|
| 4 |
+
|
| 5 |
+
__all__ = ["CodeSample", "CodeSampleDataset"]
|
| 6 |
+
|
acre/datasets/code_samples.py
ADDED
|
@@ -0,0 +1,34 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
from dataclasses import dataclass
|
| 4 |
+
from typing import Iterable, Iterator, List, Optional
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
@dataclass(frozen=True)
|
| 8 |
+
class CodeSample:
|
| 9 |
+
"""A single code sample (placeholder)."""
|
| 10 |
+
|
| 11 |
+
id: str
|
| 12 |
+
language: str
|
| 13 |
+
code: str
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
class CodeSampleDataset:
|
| 17 |
+
"""
|
| 18 |
+
Minimal in-memory dataset stub.
|
| 19 |
+
|
| 20 |
+
Later versions can back this with files, Git repos, or benchmark suites.
|
| 21 |
+
"""
|
| 22 |
+
|
| 23 |
+
def __init__(self, samples: Optional[Iterable[CodeSample]] = None) -> None:
|
| 24 |
+
self._samples: List[CodeSample] = list(samples or [])
|
| 25 |
+
|
| 26 |
+
def __len__(self) -> int:
|
| 27 |
+
return len(self._samples)
|
| 28 |
+
|
| 29 |
+
def __iter__(self) -> Iterator[CodeSample]:
|
| 30 |
+
return iter(self._samples)
|
| 31 |
+
|
| 32 |
+
def add(self, sample: CodeSample) -> None:
|
| 33 |
+
self._samples.append(sample)
|
| 34 |
+
|
acre/demo.py
ADDED
|
@@ -0,0 +1,185 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
import os
|
| 4 |
+
import random
|
| 5 |
+
import sys
|
| 6 |
+
from typing import Any, Optional, Tuple
|
| 7 |
+
|
| 8 |
+
from acre.datasets.code_samples import CodeSample, CodeSampleDataset
|
| 9 |
+
from acre.env.refactor_env import RefactorEnv
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
def _load_model(path: str):
|
| 13 |
+
"""Load a Stable-Baselines3 PPO model if available; otherwise return None."""
|
| 14 |
+
if not os.path.exists(path):
|
| 15 |
+
return None
|
| 16 |
+
try:
|
| 17 |
+
from stable_baselines3 import PPO
|
| 18 |
+
except Exception:
|
| 19 |
+
return None
|
| 20 |
+
try:
|
| 21 |
+
return PPO.load(path)
|
| 22 |
+
except Exception:
|
| 23 |
+
return None
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
def _messy_sample_code() -> str:
|
| 27 |
+
# Intentionally "messy" but valid Python for demo purposes.
|
| 28 |
+
return (
|
| 29 |
+
"def add(a,b):\n"
|
| 30 |
+
" x=0\n"
|
| 31 |
+
" for i in range(a):\n"
|
| 32 |
+
" x=x+1\n"
|
| 33 |
+
" if True:\n"
|
| 34 |
+
" x = x\n"
|
| 35 |
+
" if False:\n"
|
| 36 |
+
" y=123\n"
|
| 37 |
+
" else:\n"
|
| 38 |
+
" y=0\n"
|
| 39 |
+
" def f(p,q):\n"
|
| 40 |
+
" return p+q\n"
|
| 41 |
+
" r = f(x,y)\n"
|
| 42 |
+
" return r\n"
|
| 43 |
+
)
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
def _format_code_block(code: str) -> str:
|
| 47 |
+
return "\n".join(f" {line}" for line in code.rstrip().splitlines()) + "\n"
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
def _safe_print(text: str) -> None:
|
| 51 |
+
"""
|
| 52 |
+
Print text safely across Windows consoles (some default encodings can't print emojis).
|
| 53 |
+
"""
|
| 54 |
+
encoding = sys.stdout.encoding or "utf-8"
|
| 55 |
+
try:
|
| 56 |
+
text.encode(encoding)
|
| 57 |
+
print(text, flush=True)
|
| 58 |
+
except Exception:
|
| 59 |
+
# Fall back to ASCII-friendly markers if emojis can't be encoded.
|
| 60 |
+
safe = text.replace("✅", "[OK]").replace("⚠️", "[WARN]").replace("⚠", "[WARN]")
|
| 61 |
+
print(safe, flush=True)
|
| 62 |
+
|
| 63 |
+
|
| 64 |
+
def _compute_runtime(executor: Any, code: str) -> float:
|
| 65 |
+
"""Best-effort runtime metric using the current executor contract."""
|
| 66 |
+
try:
|
| 67 |
+
res = executor.run(code, filename="demo.py")
|
| 68 |
+
if getattr(res, "exit_code", 1) == 0 and isinstance(getattr(res, "metrics", None), dict):
|
| 69 |
+
return float(res.metrics.get("runtime_s", 0.0) or 0.0)
|
| 70 |
+
except Exception:
|
| 71 |
+
pass
|
| 72 |
+
return 0.0
|
| 73 |
+
|
| 74 |
+
|
| 75 |
+
def _choose_action(model: Any, obs, env: RefactorEnv, rng: random.Random) -> Tuple[int, str]:
|
| 76 |
+
"""Choose an action from the model, falling back to random."""
|
| 77 |
+
n_actions = int(getattr(getattr(env, "action_space", None), "n", 5))
|
| 78 |
+
if model is None:
|
| 79 |
+
a = int(rng.randint(0, n_actions - 1))
|
| 80 |
+
return a, "random"
|
| 81 |
+
|
| 82 |
+
try:
|
| 83 |
+
action, _state = model.predict(obs, deterministic=True)
|
| 84 |
+
# SB3 may return scalar or 1-element array.
|
| 85 |
+
if hasattr(action, "__len__"):
|
| 86 |
+
a = int(action[0])
|
| 87 |
+
else:
|
| 88 |
+
a = int(action)
|
| 89 |
+
return a, "ppo"
|
| 90 |
+
except Exception:
|
| 91 |
+
a = int(rng.randint(0, n_actions - 1))
|
| 92 |
+
return a, "random"
|
| 93 |
+
|
| 94 |
+
|
| 95 |
+
def run_demo(*, model_path: str = "acre_agent.zip", seed: int = 0) -> None:
|
| 96 |
+
rng = random.Random(seed)
|
| 97 |
+
|
| 98 |
+
# Create a dataset with one messy sample so `reset()` loads it deterministically.
|
| 99 |
+
dataset = CodeSampleDataset(
|
| 100 |
+
[
|
| 101 |
+
CodeSample(
|
| 102 |
+
id="demo_sample",
|
| 103 |
+
language="python",
|
| 104 |
+
code=_messy_sample_code(),
|
| 105 |
+
)
|
| 106 |
+
]
|
| 107 |
+
)
|
| 108 |
+
env = RefactorEnv(dataset=dataset, seed=seed)
|
| 109 |
+
|
| 110 |
+
model = _load_model(model_path)
|
| 111 |
+
model_status = "loaded" if model is not None else "not found (using random actions)"
|
| 112 |
+
|
| 113 |
+
# Reset and capture the original code/metrics.
|
| 114 |
+
obs, info = env.reset()
|
| 115 |
+
original_code = getattr(env, "_code", "")
|
| 116 |
+
original_complexity = float(getattr(env, "_compute_complexity")(original_code))
|
| 117 |
+
original_runtime = _compute_runtime(env.executor, original_code)
|
| 118 |
+
|
| 119 |
+
print("=" * 72)
|
| 120 |
+
print("ACRE: Autonomous RL Code Refactoring Agent (5-step episode)")
|
| 121 |
+
print(f"Model: {model_path} -> {model_status}")
|
| 122 |
+
print(f"Sample: {info.get('sample_id')} ({info.get('language')})")
|
| 123 |
+
print("=" * 72)
|
| 124 |
+
print("\nORIGINAL CODE:\n")
|
| 125 |
+
print(_format_code_block(original_code))
|
| 126 |
+
|
| 127 |
+
total_reward = 0.0
|
| 128 |
+
successful_transformations = 0
|
| 129 |
+
steps_taken = 0
|
| 130 |
+
|
| 131 |
+
for step_idx in range(1, 6):
|
| 132 |
+
action, policy = _choose_action(model, obs, env, rng)
|
| 133 |
+
obs, reward, terminated, truncated, step_info = env.step(action)
|
| 134 |
+
total_reward += float(reward)
|
| 135 |
+
steps_taken = step_idx
|
| 136 |
+
|
| 137 |
+
action_name = step_info.get("action_name", "unknown")
|
| 138 |
+
transform_meta = step_info.get("transform", {})
|
| 139 |
+
if isinstance(transform_meta, dict) and bool(transform_meta.get("success", False)):
|
| 140 |
+
successful_transformations += 1
|
| 141 |
+
transformed_code = getattr(env, "_code", "")
|
| 142 |
+
|
| 143 |
+
print("-" * 72)
|
| 144 |
+
print(f"STEP {step_idx}/5")
|
| 145 |
+
print(f"policy={policy} action={action} ({action_name})")
|
| 146 |
+
print(f"transform={transform_meta}")
|
| 147 |
+
print(f"reward={float(reward):.2f} components={step_info.get('reward_components')}")
|
| 148 |
+
print("\nUPDATED CODE:\n")
|
| 149 |
+
print(_format_code_block(transformed_code))
|
| 150 |
+
|
| 151 |
+
if terminated or truncated:
|
| 152 |
+
break
|
| 153 |
+
|
| 154 |
+
final_code = getattr(env, "_code", "")
|
| 155 |
+
final_complexity = float(getattr(env, "_compute_complexity")(final_code))
|
| 156 |
+
final_runtime = _compute_runtime(env.executor, final_code)
|
| 157 |
+
|
| 158 |
+
print("=" * 72)
|
| 159 |
+
print("FINAL SUMMARY")
|
| 160 |
+
print("=" * 72)
|
| 161 |
+
print(f"total_reward: {total_reward:.2f}")
|
| 162 |
+
print(f"complexity: {original_complexity:.0f} -> {final_complexity:.0f}")
|
| 163 |
+
print(f"runtime_s: {original_runtime:.4f} -> {final_runtime:.4f}")
|
| 164 |
+
|
| 165 |
+
complexity_improvement = ((original_complexity - final_complexity) / max(original_complexity, 1.0)) * 100.0
|
| 166 |
+
print(f"complexity improvement: {complexity_improvement:.2f}%")
|
| 167 |
+
|
| 168 |
+
print("\nCHANGES APPLIED:")
|
| 169 |
+
print(f"- Total steps: {steps_taken}")
|
| 170 |
+
print(f"- Successful transformations: {successful_transformations}")
|
| 171 |
+
|
| 172 |
+
if total_reward > 0:
|
| 173 |
+
_safe_print("\n✅ Code improved successfully")
|
| 174 |
+
else:
|
| 175 |
+
_safe_print("\n⚠️ No significant improvement")
|
| 176 |
+
|
| 177 |
+
print("\nFINAL CODE:\n")
|
| 178 |
+
print(_format_code_block(final_code))
|
| 179 |
+
|
| 180 |
+
env.close()
|
| 181 |
+
|
| 182 |
+
|
| 183 |
+
if __name__ == "__main__":
|
| 184 |
+
run_demo()
|
| 185 |
+
|
acre/main.py
ADDED
|
@@ -0,0 +1,39 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
import argparse
|
| 4 |
+
|
| 5 |
+
from acre.training.train_agent import TrainConfig, train
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
def _build_parser() -> argparse.ArgumentParser:
|
| 9 |
+
parser = argparse.ArgumentParser(prog="acre", description="ACRE: Autonomous Code Refactoring Environment")
|
| 10 |
+
sub = parser.add_subparsers(dest="command", required=False)
|
| 11 |
+
|
| 12 |
+
train_p = sub.add_parser("train", help="Run training (stub)")
|
| 13 |
+
train_p.add_argument("--total-steps", type=int, default=100, help="Total training steps (stub)")
|
| 14 |
+
|
| 15 |
+
sub.add_parser("demo", help="Run a small demo (stub)")
|
| 16 |
+
|
| 17 |
+
return parser
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
def run_demo() -> None:
|
| 21 |
+
# Placeholder for a future interactive/demo flow.
|
| 22 |
+
print("ACRE demo mode is not implemented yet.")
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
def main(argv: list[str] | None = None) -> None:
|
| 26 |
+
parser = _build_parser()
|
| 27 |
+
args = parser.parse_args(argv)
|
| 28 |
+
|
| 29 |
+
if args.command == "demo":
|
| 30 |
+
run_demo()
|
| 31 |
+
return
|
| 32 |
+
|
| 33 |
+
total_steps = getattr(args, "total_steps", 100)
|
| 34 |
+
train(config=TrainConfig(total_steps=total_steps))
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
if __name__ == "__main__":
|
| 38 |
+
main()
|
| 39 |
+
|
acre/tasks/__init__.py
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from acre.tasks.task_registry import Task, TaskRegistry
|
| 2 |
+
|
| 3 |
+
__all__ = ["Task", "TaskRegistry"]
|
acre/tasks/task_registry.py
ADDED
|
@@ -0,0 +1,222 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Three OpenEnv tasks with AST-based graders scoring 0.0-1.0.
|
| 3 |
+
"""
|
| 4 |
+
from __future__ import annotations
|
| 5 |
+
|
| 6 |
+
import ast
|
| 7 |
+
from dataclasses import dataclass
|
| 8 |
+
from typing import Callable, Dict, List, Optional
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
@dataclass
|
| 12 |
+
class Task:
|
| 13 |
+
id: str
|
| 14 |
+
name: str
|
| 15 |
+
description: str
|
| 16 |
+
difficulty: str
|
| 17 |
+
initial_code: str
|
| 18 |
+
_grade_fn: Callable[[str], float]
|
| 19 |
+
|
| 20 |
+
def grade(self, code: str) -> float:
|
| 21 |
+
"""Return a score in [0.0, 1.0]."""
|
| 22 |
+
try:
|
| 23 |
+
return float(min(1.0, max(0.0, self._grade_fn(code))))
|
| 24 |
+
except Exception:
|
| 25 |
+
return 0.0
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
# ---------------------------------------------------------------------------
|
| 29 |
+
# Task 1 — Easy: Rename generic variables
|
| 30 |
+
# ---------------------------------------------------------------------------
|
| 31 |
+
_EASY_CODE = """\
|
| 32 |
+
def compute(x, y, tmp):
|
| 33 |
+
tmp = x + y
|
| 34 |
+
x = tmp * 2
|
| 35 |
+
result = x
|
| 36 |
+
return result
|
| 37 |
+
"""
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
def _grade_easy(code: str) -> float:
|
| 41 |
+
"""Score = fraction of generic names (x, tmp) removed from all scopes."""
|
| 42 |
+
generic = {"x", "tmp"}
|
| 43 |
+
try:
|
| 44 |
+
tree = ast.parse(code)
|
| 45 |
+
except SyntaxError:
|
| 46 |
+
return 0.0
|
| 47 |
+
|
| 48 |
+
remaining: set[str] = set()
|
| 49 |
+
|
| 50 |
+
class _Collector(ast.NodeVisitor):
|
| 51 |
+
def visit_Name(self, node: ast.Name) -> None:
|
| 52 |
+
if node.id in generic:
|
| 53 |
+
remaining.add(node.id)
|
| 54 |
+
self.generic_visit(node)
|
| 55 |
+
|
| 56 |
+
def visit_arg(self, node: ast.arg) -> None:
|
| 57 |
+
if node.arg in generic:
|
| 58 |
+
remaining.add(node.arg)
|
| 59 |
+
self.generic_visit(node)
|
| 60 |
+
|
| 61 |
+
_Collector().visit(tree)
|
| 62 |
+
renamed = len(generic - remaining)
|
| 63 |
+
return renamed / len(generic)
|
| 64 |
+
|
| 65 |
+
|
| 66 |
+
# ---------------------------------------------------------------------------
|
| 67 |
+
# Task 2 — Medium: Remove dead code
|
| 68 |
+
# ---------------------------------------------------------------------------
|
| 69 |
+
_MEDIUM_CODE = """\
|
| 70 |
+
def process(data):
|
| 71 |
+
result = []
|
| 72 |
+
for item in data:
|
| 73 |
+
result.append(item * 2)
|
| 74 |
+
if False:
|
| 75 |
+
print("never runs")
|
| 76 |
+
unused_var = 42
|
| 77 |
+
return result
|
| 78 |
+
print("unreachable")
|
| 79 |
+
"""
|
| 80 |
+
|
| 81 |
+
|
| 82 |
+
def _grade_medium(code: str) -> float:
|
| 83 |
+
"""Score = fraction of dead-code patterns eliminated (3 checks, ~0.33 each)."""
|
| 84 |
+
try:
|
| 85 |
+
tree = ast.parse(code)
|
| 86 |
+
except SyntaxError:
|
| 87 |
+
return 0.0
|
| 88 |
+
|
| 89 |
+
source = ast.unparse(tree)
|
| 90 |
+
score = 0.0
|
| 91 |
+
|
| 92 |
+
# Check 1: if-False block removed
|
| 93 |
+
if "if False" not in source:
|
| 94 |
+
score += 1 / 3
|
| 95 |
+
|
| 96 |
+
# Check 2: unused_var assignment removed
|
| 97 |
+
if "unused_var" not in source:
|
| 98 |
+
score += 1 / 3
|
| 99 |
+
|
| 100 |
+
# Check 3: list comprehension used (loop simplified)
|
| 101 |
+
has_listcomp = any(isinstance(n, ast.ListComp) for n in ast.walk(tree))
|
| 102 |
+
if has_listcomp:
|
| 103 |
+
score += 1 / 3
|
| 104 |
+
|
| 105 |
+
return score
|
| 106 |
+
|
| 107 |
+
|
| 108 |
+
# ---------------------------------------------------------------------------
|
| 109 |
+
# Task 3 — Hard: Full refactor
|
| 110 |
+
# ---------------------------------------------------------------------------
|
| 111 |
+
_HARD_CODE = """\
|
| 112 |
+
def add(p, q):
|
| 113 |
+
return p + q
|
| 114 |
+
|
| 115 |
+
def compute(x, data, tmp):
|
| 116 |
+
result = []
|
| 117 |
+
for item in data:
|
| 118 |
+
result.append(item * 2)
|
| 119 |
+
if False:
|
| 120 |
+
y = 999
|
| 121 |
+
if True:
|
| 122 |
+
val = add(x, tmp)
|
| 123 |
+
unused = 0
|
| 124 |
+
flag = not not True
|
| 125 |
+
return val
|
| 126 |
+
print("dead")
|
| 127 |
+
"""
|
| 128 |
+
|
| 129 |
+
|
| 130 |
+
def _grade_hard(code: str) -> float:
|
| 131 |
+
"""Score = fraction of 5 quality checks passed."""
|
| 132 |
+
try:
|
| 133 |
+
tree = ast.parse(code)
|
| 134 |
+
except SyntaxError:
|
| 135 |
+
return 0.0
|
| 136 |
+
|
| 137 |
+
source = ast.unparse(tree)
|
| 138 |
+
checks = 0
|
| 139 |
+
|
| 140 |
+
# 1. No generic variable names x/tmp in function signature or body
|
| 141 |
+
has_generic = False
|
| 142 |
+
|
| 143 |
+
class _GenCheck(ast.NodeVisitor):
|
| 144 |
+
def visit_arg(self, node: ast.arg) -> None:
|
| 145 |
+
nonlocal has_generic
|
| 146 |
+
if node.arg in {"x", "tmp"}:
|
| 147 |
+
has_generic = True
|
| 148 |
+
|
| 149 |
+
_GenCheck().visit(tree)
|
| 150 |
+
if not has_generic:
|
| 151 |
+
checks += 1
|
| 152 |
+
|
| 153 |
+
# 2. No if False block
|
| 154 |
+
if "if False" not in source:
|
| 155 |
+
checks += 1
|
| 156 |
+
|
| 157 |
+
# 3. if True removed (body inlined)
|
| 158 |
+
if "if True" not in source:
|
| 159 |
+
checks += 1
|
| 160 |
+
|
| 161 |
+
# 4. List comprehension used
|
| 162 |
+
if any(isinstance(n, ast.ListComp) for n in ast.walk(tree)):
|
| 163 |
+
checks += 1
|
| 164 |
+
|
| 165 |
+
# 5. add() call inlined (no call to 'add')
|
| 166 |
+
calls = [n for n in ast.walk(tree) if isinstance(n, ast.Call)]
|
| 167 |
+
fn_names = {c.func.id for c in calls if isinstance(c.func, ast.Name)}
|
| 168 |
+
if "add" not in fn_names:
|
| 169 |
+
checks += 1
|
| 170 |
+
|
| 171 |
+
return checks / 5
|
| 172 |
+
|
| 173 |
+
|
| 174 |
+
# ---------------------------------------------------------------------------
|
| 175 |
+
# Registry
|
| 176 |
+
# ---------------------------------------------------------------------------
|
| 177 |
+
|
| 178 |
+
class TaskRegistry:
|
| 179 |
+
def __init__(self) -> None:
|
| 180 |
+
self._tasks: Dict[str, Task] = {}
|
| 181 |
+
self._register_all()
|
| 182 |
+
|
| 183 |
+
def _register_all(self) -> None:
|
| 184 |
+
self._tasks["rename_variables"] = Task(
|
| 185 |
+
id="rename_variables",
|
| 186 |
+
name="Rename Variables (Easy)",
|
| 187 |
+
description="Rename generic variable names (x, tmp) to descriptive ones",
|
| 188 |
+
difficulty="easy",
|
| 189 |
+
initial_code=_EASY_CODE,
|
| 190 |
+
_grade_fn=_grade_easy,
|
| 191 |
+
)
|
| 192 |
+
self._tasks["remove_dead_code"] = Task(
|
| 193 |
+
id="remove_dead_code",
|
| 194 |
+
name="Remove Dead Code (Medium)",
|
| 195 |
+
description="Remove unreachable code, if False blocks, and unused variables",
|
| 196 |
+
difficulty="medium",
|
| 197 |
+
initial_code=_MEDIUM_CODE,
|
| 198 |
+
_grade_fn=_grade_medium,
|
| 199 |
+
)
|
| 200 |
+
self._tasks["full_refactor"] = Task(
|
| 201 |
+
id="full_refactor",
|
| 202 |
+
name="Full Refactor (Hard)",
|
| 203 |
+
description="Apply all transformations: rename, dead code, loops, conditions, inlining",
|
| 204 |
+
difficulty="hard",
|
| 205 |
+
initial_code=_HARD_CODE,
|
| 206 |
+
_grade_fn=_grade_hard,
|
| 207 |
+
)
|
| 208 |
+
|
| 209 |
+
def get_task(self, task_id: str) -> Optional[Task]:
|
| 210 |
+
return self._tasks.get(task_id)
|
| 211 |
+
|
| 212 |
+
def list_tasks(self) -> List[dict]:
|
| 213 |
+
return [
|
| 214 |
+
{
|
| 215 |
+
"id": t.id,
|
| 216 |
+
"name": t.name,
|
| 217 |
+
"description": t.description,
|
| 218 |
+
"difficulty": t.difficulty,
|
| 219 |
+
"initial_code": t.initial_code,
|
| 220 |
+
}
|
| 221 |
+
for t in self._tasks.values()
|
| 222 |
+
]
|
acre/training/__init__.py
ADDED
|
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Training utilities for ACRE."""
|
| 2 |
+
|
| 3 |
+
from .train_agent import TrainConfig, train
|
| 4 |
+
|
| 5 |
+
__all__ = ["TrainConfig", "train"]
|
| 6 |
+
|
acre/training/train_agent.py
ADDED
|
@@ -0,0 +1,75 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
from dataclasses import dataclass
|
| 4 |
+
from typing import Optional
|
| 5 |
+
|
| 6 |
+
from acre.env.refactor_env import RefactorEnv
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
@dataclass(frozen=True)
|
| 10 |
+
class TrainConfig:
|
| 11 |
+
"""Configuration stub for training."""
|
| 12 |
+
|
| 13 |
+
total_steps: int = 5_000
|
| 14 |
+
seed: Optional[int] = None
|
| 15 |
+
model_path: str = "acre_agent.zip"
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
def train(*, env: Optional[RefactorEnv] = None, config: Optional[TrainConfig] = None) -> None:
|
| 19 |
+
"""
|
| 20 |
+
Train a PPO agent on `RefactorEnv` using Stable-Baselines3.
|
| 21 |
+
|
| 22 |
+
This is intentionally lightweight (hackathon-friendly) and focuses on a
|
| 23 |
+
working demo: basic training loop, simple logging, and saving the model.
|
| 24 |
+
"""
|
| 25 |
+
_config = config or TrainConfig()
|
| 26 |
+
_env = env or RefactorEnv(seed=_config.seed)
|
| 27 |
+
|
| 28 |
+
try:
|
| 29 |
+
from stable_baselines3 import PPO
|
| 30 |
+
from stable_baselines3.common.callbacks import BaseCallback
|
| 31 |
+
from stable_baselines3.common.monitor import Monitor
|
| 32 |
+
from stable_baselines3.common.vec_env import DummyVecEnv
|
| 33 |
+
except Exception as e: # pragma: no cover
|
| 34 |
+
print("Stable-Baselines3 is required for training. Install with `pip install -r requirements.txt`.")
|
| 35 |
+
print(f"Import error: {e}")
|
| 36 |
+
return None
|
| 37 |
+
|
| 38 |
+
class EpisodeRewardPrinter(BaseCallback):
|
| 39 |
+
"""Print episode reward when an episode ends (via Monitor)."""
|
| 40 |
+
|
| 41 |
+
def __init__(self) -> None:
|
| 42 |
+
super().__init__()
|
| 43 |
+
self.episode_count = 0
|
| 44 |
+
|
| 45 |
+
def _on_step(self) -> bool:
|
| 46 |
+
infos = self.locals.get("infos", [])
|
| 47 |
+
for info in infos:
|
| 48 |
+
ep = info.get("episode") if isinstance(info, dict) else None
|
| 49 |
+
if isinstance(ep, dict) and "r" in ep:
|
| 50 |
+
self.episode_count += 1
|
| 51 |
+
print(f"episode={self.episode_count} reward={ep['r']:.2f} length={int(ep.get('l', 0))}")
|
| 52 |
+
return True
|
| 53 |
+
|
| 54 |
+
# Wrap with Monitor so SB3 can compute episode stats and expose them in `info["episode"]`.
|
| 55 |
+
def make_env() -> RefactorEnv:
|
| 56 |
+
return Monitor(_env)
|
| 57 |
+
|
| 58 |
+
vec_env = DummyVecEnv([make_env])
|
| 59 |
+
|
| 60 |
+
model = PPO(
|
| 61 |
+
policy="MlpPolicy",
|
| 62 |
+
env=vec_env,
|
| 63 |
+
verbose=0,
|
| 64 |
+
seed=_config.seed,
|
| 65 |
+
n_steps=64,
|
| 66 |
+
batch_size=64,
|
| 67 |
+
)
|
| 68 |
+
|
| 69 |
+
print(f"Training PPO for {int(_config.total_steps)} timesteps...")
|
| 70 |
+
model.learn(total_timesteps=int(_config.total_steps), callback=EpisodeRewardPrinter())
|
| 71 |
+
|
| 72 |
+
model.save(_config.model_path)
|
| 73 |
+
print(f"Saved model to {_config.model_path!r}")
|
| 74 |
+
return None
|
| 75 |
+
|
acre/utils/__init__.py
ADDED
|
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Shared utility helpers for ACRE."""
|
| 2 |
+
|
| 3 |
+
from .metrics import Metric, MetricLogger
|
| 4 |
+
|
| 5 |
+
__all__ = ["Metric", "MetricLogger"]
|
| 6 |
+
|
acre/utils/metrics.py
ADDED
|
@@ -0,0 +1,33 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
from dataclasses import dataclass, field
|
| 4 |
+
from typing import Dict, Iterable, List, Tuple
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
@dataclass(frozen=True)
|
| 8 |
+
class Metric:
|
| 9 |
+
"""Single scalar metric value (placeholder)."""
|
| 10 |
+
|
| 11 |
+
name: str
|
| 12 |
+
value: float
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
@dataclass
|
| 16 |
+
class MetricLogger:
|
| 17 |
+
"""Tiny metric logger stub."""
|
| 18 |
+
|
| 19 |
+
_history: Dict[str, List[float]] = field(default_factory=dict)
|
| 20 |
+
|
| 21 |
+
def log(self, metric: Metric) -> None:
|
| 22 |
+
self._history.setdefault(metric.name, []).append(metric.value)
|
| 23 |
+
|
| 24 |
+
def latest(self) -> Dict[str, float]:
|
| 25 |
+
return {k: v[-1] for k, v in self._history.items() if v}
|
| 26 |
+
|
| 27 |
+
def as_series(self) -> Dict[str, Tuple[float, ...]]:
|
| 28 |
+
return {k: tuple(v) for k, v in self._history.items()}
|
| 29 |
+
|
| 30 |
+
def extend(self, metrics: Iterable[Metric]) -> None:
|
| 31 |
+
for m in metrics:
|
| 32 |
+
self.log(m)
|
| 33 |
+
|
inference.py
ADDED
|
@@ -0,0 +1,278 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
ACRE inference script for OpenEnv submission evaluation.
|
| 3 |
+
|
| 4 |
+
Required environment variables:
|
| 5 |
+
API_BASE_URL: LLM API endpoint (default allowed)
|
| 6 |
+
MODEL_NAME: model identifier (default allowed)
|
| 7 |
+
HF_TOKEN: API token for the OpenAI-compatible endpoint
|
| 8 |
+
ENV_URL: running ACRE server base URL
|
| 9 |
+
|
| 10 |
+
Optional:
|
| 11 |
+
LOCAL_IMAGE_NAME: present for evaluator compatibility when using a local
|
| 12 |
+
Docker image launcher.
|
| 13 |
+
|
| 14 |
+
Stdout format uses strict START / STEP / END event markers.
|
| 15 |
+
"""
|
| 16 |
+
from __future__ import annotations
|
| 17 |
+
|
| 18 |
+
import json
|
| 19 |
+
import os
|
| 20 |
+
import re
|
| 21 |
+
import sys
|
| 22 |
+
import time
|
| 23 |
+
from typing import Dict, List, Tuple
|
| 24 |
+
|
| 25 |
+
import requests
|
| 26 |
+
from openai import OpenAI
|
| 27 |
+
|
| 28 |
+
API_BASE_URL: str = os.getenv("API_BASE_URL", "https://api.openai.com/v1")
|
| 29 |
+
MODEL_NAME: str = os.getenv("MODEL_NAME", "gpt-4o-mini")
|
| 30 |
+
HF_TOKEN: str | None = os.getenv("HF_TOKEN")
|
| 31 |
+
ENV_URL: str | None = os.getenv("ENV_URL")
|
| 32 |
+
LOCAL_IMAGE_NAME: str | None = os.getenv("LOCAL_IMAGE_NAME")
|
| 33 |
+
|
| 34 |
+
TASKS: List[str] = ["rename_variables", "remove_dead_code", "full_refactor"]
|
| 35 |
+
|
| 36 |
+
ACTION_MEANINGS: Dict[int, str] = {
|
| 37 |
+
0: "rename_variable",
|
| 38 |
+
1: "remove_dead_code",
|
| 39 |
+
2: "simplify_loop",
|
| 40 |
+
3: "optimize_condition",
|
| 41 |
+
4: "inline_function",
|
| 42 |
+
}
|
| 43 |
+
|
| 44 |
+
SYSTEM_PROMPT = """\
|
| 45 |
+
You are an RL agent that refactors Python code. Choose one action per step.
|
| 46 |
+
|
| 47 |
+
Actions:
|
| 48 |
+
0 rename_variable - rename generic names (x, tmp, i) to descriptive ones
|
| 49 |
+
1 remove_dead_code - remove unreachable stmts, if False blocks, unused vars
|
| 50 |
+
2 simplify_loop - convert append-loops to list comprehensions
|
| 51 |
+
3 optimize_condition- simplify 'not not x', 'if True/False', 'x==True'
|
| 52 |
+
4 inline_function - inline simple single-return module-level functions
|
| 53 |
+
|
| 54 |
+
Respond ONLY with valid JSON (no markdown):
|
| 55 |
+
{"action": <0-4>, "reason": "<one sentence>"}"""
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
def _env_url() -> str:
|
| 59 |
+
if ENV_URL:
|
| 60 |
+
return ENV_URL.rstrip("/")
|
| 61 |
+
raise RuntimeError("ENV_URL must be set before running inference.py")
|
| 62 |
+
|
| 63 |
+
|
| 64 |
+
def _post(path: str, payload: dict | None = None) -> dict:
|
| 65 |
+
response = requests.post(f"{_env_url()}{path}", json=payload or {}, timeout=30)
|
| 66 |
+
response.raise_for_status()
|
| 67 |
+
return response.json()
|
| 68 |
+
|
| 69 |
+
|
| 70 |
+
def _get(path: str) -> dict:
|
| 71 |
+
response = requests.get(f"{_env_url()}{path}", timeout=30)
|
| 72 |
+
response.raise_for_status()
|
| 73 |
+
return response.json()
|
| 74 |
+
|
| 75 |
+
|
| 76 |
+
def reset_env(task_id: str) -> dict:
|
| 77 |
+
return _post("/reset", {"task_id": task_id})
|
| 78 |
+
|
| 79 |
+
|
| 80 |
+
def step_env(action: int) -> dict:
|
| 81 |
+
return _post("/step", {"action": action})
|
| 82 |
+
|
| 83 |
+
|
| 84 |
+
def get_state() -> dict:
|
| 85 |
+
return _get("/state")
|
| 86 |
+
|
| 87 |
+
|
| 88 |
+
def grade(task_id: str, code: str) -> float:
|
| 89 |
+
response = requests.post(
|
| 90 |
+
f"{_env_url()}/tasks/{task_id}/grade",
|
| 91 |
+
json={"code": code},
|
| 92 |
+
timeout=30,
|
| 93 |
+
)
|
| 94 |
+
response.raise_for_status()
|
| 95 |
+
return float(response.json().get("score", 0.0))
|
| 96 |
+
|
| 97 |
+
|
| 98 |
+
def choose_action(client: OpenAI, state: dict, task_id: str) -> Tuple[int, str]:
|
| 99 |
+
def heuristic_action() -> Tuple[int, str]:
|
| 100 |
+
code = str(state.get("current_code", ""))
|
| 101 |
+
step_i = int(state.get("episode_steps", 0))
|
| 102 |
+
|
| 103 |
+
has_generic = re.search(r"\b(x|tmp|i)\b", code) is not None
|
| 104 |
+
has_if_false = re.search(r"\bif\s+False\b", code) is not None
|
| 105 |
+
has_if_true = re.search(r"\bif\s+True\b", code) is not None
|
| 106 |
+
has_append_loop = ".append(" in code and "for " in code
|
| 107 |
+
has_double_not = "not not" in code
|
| 108 |
+
has_add_call = "add(" in code
|
| 109 |
+
|
| 110 |
+
if task_id == "rename_variables":
|
| 111 |
+
if has_generic:
|
| 112 |
+
return 0, "heuristic: remove generic names first"
|
| 113 |
+
if has_if_false or "unused" in code:
|
| 114 |
+
return 1, "heuristic: remove dead code"
|
| 115 |
+
if has_append_loop:
|
| 116 |
+
return 2, "heuristic: simplify loop"
|
| 117 |
+
if has_if_true or has_double_not:
|
| 118 |
+
return 3, "heuristic: optimize conditions"
|
| 119 |
+
return 4, "heuristic: inline simple function"
|
| 120 |
+
|
| 121 |
+
if task_id == "remove_dead_code":
|
| 122 |
+
if has_if_false or "unused" in code:
|
| 123 |
+
return 1, "heuristic: remove dead code patterns"
|
| 124 |
+
if has_append_loop:
|
| 125 |
+
return 2, "heuristic: convert append-loop"
|
| 126 |
+
if has_if_true or has_double_not:
|
| 127 |
+
return 3, "heuristic: simplify conditions"
|
| 128 |
+
if has_generic:
|
| 129 |
+
return 0, "heuristic: clean generic names"
|
| 130 |
+
return 4, "heuristic: inline helper"
|
| 131 |
+
|
| 132 |
+
if has_generic:
|
| 133 |
+
return 0, "heuristic: rename generic variables"
|
| 134 |
+
if has_append_loop:
|
| 135 |
+
return 2, "heuristic: simplify loop into listcomp"
|
| 136 |
+
if has_if_false or has_if_true or has_double_not:
|
| 137 |
+
return 3, "heuristic: optimize boolean branches"
|
| 138 |
+
if has_add_call:
|
| 139 |
+
return 4, "heuristic: inline add() call"
|
| 140 |
+
if step_i >= 2:
|
| 141 |
+
return 1, "heuristic: remove remaining dead code"
|
| 142 |
+
return 3, "heuristic: condition optimization as safe default"
|
| 143 |
+
|
| 144 |
+
if not HF_TOKEN:
|
| 145 |
+
return heuristic_action()
|
| 146 |
+
|
| 147 |
+
messages = [
|
| 148 |
+
{"role": "system", "content": SYSTEM_PROMPT},
|
| 149 |
+
{
|
| 150 |
+
"role": "user",
|
| 151 |
+
"content": (
|
| 152 |
+
f"Task: {task_id}\n"
|
| 153 |
+
f"Steps remaining: {state.get('max_steps', 5) - state.get('episode_steps', 0)}\n"
|
| 154 |
+
f"Complexity: {state.get('complexity', 0)}\n\n"
|
| 155 |
+
f"Current code:\n```python\n{state.get('current_code', '')}\n```\n\n"
|
| 156 |
+
"Choose the best action."
|
| 157 |
+
),
|
| 158 |
+
},
|
| 159 |
+
]
|
| 160 |
+
try:
|
| 161 |
+
response = client.chat.completions.create(
|
| 162 |
+
model=MODEL_NAME,
|
| 163 |
+
messages=messages,
|
| 164 |
+
temperature=0.0,
|
| 165 |
+
max_tokens=120,
|
| 166 |
+
)
|
| 167 |
+
raw = (response.choices[0].message.content or "").strip()
|
| 168 |
+
json_blob = raw
|
| 169 |
+
|
| 170 |
+
if "{" not in json_blob or "}" not in json_blob:
|
| 171 |
+
return heuristic_action()
|
| 172 |
+
|
| 173 |
+
match = re.search(r"\{.*\}", json_blob, flags=re.DOTALL)
|
| 174 |
+
if match:
|
| 175 |
+
json_blob = match.group(0)
|
| 176 |
+
|
| 177 |
+
parsed = json.loads(json_blob)
|
| 178 |
+
action = int(parsed.get("action", -1))
|
| 179 |
+
reason = str(parsed.get("reason", ""))
|
| 180 |
+
if 0 <= action <= 4:
|
| 181 |
+
return action, reason or "llm-selected action"
|
| 182 |
+
return heuristic_action()
|
| 183 |
+
except Exception:
|
| 184 |
+
return heuristic_action()
|
| 185 |
+
|
| 186 |
+
|
| 187 |
+
def run_episode(client: OpenAI, task_id: str, episode_num: int) -> float:
|
| 188 |
+
reset_env(task_id)
|
| 189 |
+
state = get_state()
|
| 190 |
+
|
| 191 |
+
print(
|
| 192 |
+
json.dumps(
|
| 193 |
+
{
|
| 194 |
+
"event": "START",
|
| 195 |
+
"episode": episode_num,
|
| 196 |
+
"task_id": task_id,
|
| 197 |
+
"initial_complexity": state.get("complexity", 0),
|
| 198 |
+
"initial_code_length": len(state.get("current_code", "")),
|
| 199 |
+
"timestamp": time.time(),
|
| 200 |
+
}
|
| 201 |
+
),
|
| 202 |
+
flush=True,
|
| 203 |
+
)
|
| 204 |
+
|
| 205 |
+
cumulative_reward = 0.0
|
| 206 |
+
|
| 207 |
+
for step_num in range(1, 6):
|
| 208 |
+
action, reason = choose_action(client, state, task_id)
|
| 209 |
+
result = step_env(action)
|
| 210 |
+
state = get_state()
|
| 211 |
+
|
| 212 |
+
reward_payload = result.get("reward", {})
|
| 213 |
+
raw_reward = float(reward_payload.get("raw", 0.0))
|
| 214 |
+
norm_reward = float(reward_payload.get("normalized", (raw_reward + 32) / 52))
|
| 215 |
+
cumulative_reward += raw_reward
|
| 216 |
+
|
| 217 |
+
print(
|
| 218 |
+
json.dumps(
|
| 219 |
+
{
|
| 220 |
+
"event": "STEP",
|
| 221 |
+
"episode": episode_num,
|
| 222 |
+
"step": step_num,
|
| 223 |
+
"action": action,
|
| 224 |
+
"action_name": ACTION_MEANINGS.get(action, "unknown"),
|
| 225 |
+
"reason": reason,
|
| 226 |
+
"reward": round(raw_reward, 4),
|
| 227 |
+
"normalized_reward": round(norm_reward, 4),
|
| 228 |
+
"cumulative_reward": round(cumulative_reward, 4),
|
| 229 |
+
"changed": result.get("info", {}).get("changed", False),
|
| 230 |
+
"reward_components": reward_payload.get("components", {}),
|
| 231 |
+
"done": result.get("done", False),
|
| 232 |
+
}
|
| 233 |
+
),
|
| 234 |
+
flush=True,
|
| 235 |
+
)
|
| 236 |
+
|
| 237 |
+
if result.get("done") or result.get("terminated") or result.get("truncated"):
|
| 238 |
+
break
|
| 239 |
+
|
| 240 |
+
final_state = get_state()
|
| 241 |
+
task_score = grade(task_id, final_state.get("current_code", ""))
|
| 242 |
+
|
| 243 |
+
print(
|
| 244 |
+
json.dumps(
|
| 245 |
+
{
|
| 246 |
+
"event": "END",
|
| 247 |
+
"episode": episode_num,
|
| 248 |
+
"task_id": task_id,
|
| 249 |
+
"cumulative_reward": round(cumulative_reward, 4),
|
| 250 |
+
"normalized_cumulative": round((cumulative_reward + 32) / 52, 4),
|
| 251 |
+
"task_score": round(task_score, 4),
|
| 252 |
+
"final_complexity": final_state.get("complexity", 0),
|
| 253 |
+
"timestamp": time.time(),
|
| 254 |
+
}
|
| 255 |
+
),
|
| 256 |
+
flush=True,
|
| 257 |
+
)
|
| 258 |
+
|
| 259 |
+
return task_score
|
| 260 |
+
|
| 261 |
+
|
| 262 |
+
def main() -> None:
|
| 263 |
+
if not ENV_URL:
|
| 264 |
+
raise SystemExit("ENV_URL is required. Example: ENV_URL=http://localhost:7860")
|
| 265 |
+
|
| 266 |
+
client = OpenAI(base_url=API_BASE_URL, api_key=HF_TOKEN or "dummy")
|
| 267 |
+
|
| 268 |
+
scores: List[float] = []
|
| 269 |
+
for i, task_id in enumerate(TASKS, start=1):
|
| 270 |
+
score = run_episode(client, task_id, i)
|
| 271 |
+
scores.append(score)
|
| 272 |
+
|
| 273 |
+
avg_score = sum(scores) / len(scores) if scores else 0.0
|
| 274 |
+
sys.exit(0 if avg_score >= 0.5 else 1)
|
| 275 |
+
|
| 276 |
+
|
| 277 |
+
if __name__ == "__main__":
|
| 278 |
+
main()
|
models.py
ADDED
|
@@ -0,0 +1,156 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
from typing import Any, Dict, List, Optional, Sequence
|
| 4 |
+
|
| 5 |
+
from pydantic import BaseModel, Field
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
class ObservationModel(BaseModel):
|
| 9 |
+
code_length: float
|
| 10 |
+
complexity_score: float
|
| 11 |
+
runtime_s: float
|
| 12 |
+
error_flag: bool
|
| 13 |
+
|
| 14 |
+
@classmethod
|
| 15 |
+
def from_vector(cls, values: Sequence[float]) -> "ObservationModel":
|
| 16 |
+
vector = list(values)
|
| 17 |
+
if len(vector) != 4:
|
| 18 |
+
raise ValueError(f"observation vector must have length 4, got {len(vector)}")
|
| 19 |
+
return cls(
|
| 20 |
+
code_length=float(vector[0]),
|
| 21 |
+
complexity_score=float(vector[1]),
|
| 22 |
+
runtime_s=float(vector[2]),
|
| 23 |
+
error_flag=bool(vector[3]),
|
| 24 |
+
)
|
| 25 |
+
|
| 26 |
+
def to_vector(self) -> List[float]:
|
| 27 |
+
return [
|
| 28 |
+
float(self.code_length),
|
| 29 |
+
float(self.complexity_score),
|
| 30 |
+
float(self.runtime_s),
|
| 31 |
+
float(int(self.error_flag)),
|
| 32 |
+
]
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
class ActionModel(BaseModel):
|
| 36 |
+
action: int = Field(ge=0, le=4)
|
| 37 |
+
action_name: Optional[str] = None
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
class RewardModel(BaseModel):
|
| 41 |
+
raw: float
|
| 42 |
+
normalized: float = Field(ge=0.0, le=1.0)
|
| 43 |
+
components: Dict[str, float]
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
class HealthResponse(BaseModel):
|
| 47 |
+
status: str
|
| 48 |
+
env: str
|
| 49 |
+
version: str
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
class CompatibilityHealthResponse(BaseModel):
|
| 53 |
+
status: str
|
| 54 |
+
service: str
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
class ResetRequest(BaseModel):
|
| 58 |
+
task_id: Optional[str] = None
|
| 59 |
+
seed: Optional[int] = None
|
| 60 |
+
code: Optional[str] = None
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
class StepRequest(BaseModel):
|
| 64 |
+
action: int = Field(ge=0, le=4)
|
| 65 |
+
|
| 66 |
+
|
| 67 |
+
class GradeRequest(BaseModel):
|
| 68 |
+
code: str
|
| 69 |
+
|
| 70 |
+
|
| 71 |
+
class TaskInfo(BaseModel):
|
| 72 |
+
id: str
|
| 73 |
+
name: str
|
| 74 |
+
description: str
|
| 75 |
+
difficulty: str
|
| 76 |
+
initial_code: str
|
| 77 |
+
|
| 78 |
+
|
| 79 |
+
class TasksResponse(BaseModel):
|
| 80 |
+
tasks: List[TaskInfo]
|
| 81 |
+
|
| 82 |
+
|
| 83 |
+
class GradeResponse(BaseModel):
|
| 84 |
+
task_id: str
|
| 85 |
+
score: float
|
| 86 |
+
passed: bool
|
| 87 |
+
|
| 88 |
+
|
| 89 |
+
class StateResponse(BaseModel):
|
| 90 |
+
current_code: str
|
| 91 |
+
episode_steps: int
|
| 92 |
+
max_steps: int
|
| 93 |
+
complexity: float
|
| 94 |
+
last_runtime: float
|
| 95 |
+
last_error: bool
|
| 96 |
+
sample_id: Optional[str]
|
| 97 |
+
language: Optional[str]
|
| 98 |
+
task_id: Optional[str]
|
| 99 |
+
observation: ObservationModel
|
| 100 |
+
observation_vector: List[float]
|
| 101 |
+
action_meanings: Dict[int, str]
|
| 102 |
+
|
| 103 |
+
|
| 104 |
+
class ResetResponse(BaseModel):
|
| 105 |
+
observation: ObservationModel
|
| 106 |
+
observation_vector: List[float]
|
| 107 |
+
info: Dict[str, Any]
|
| 108 |
+
task_id: Optional[str]
|
| 109 |
+
state: StateResponse
|
| 110 |
+
|
| 111 |
+
|
| 112 |
+
class StepResponse(BaseModel):
|
| 113 |
+
action: ActionModel
|
| 114 |
+
observation: ObservationModel
|
| 115 |
+
observation_vector: List[float]
|
| 116 |
+
reward: RewardModel
|
| 117 |
+
done: bool
|
| 118 |
+
terminated: bool
|
| 119 |
+
truncated: bool
|
| 120 |
+
info: Dict[str, Any]
|
| 121 |
+
state: StateResponse
|
| 122 |
+
|
| 123 |
+
|
| 124 |
+
class OptimizeRequest(BaseModel):
|
| 125 |
+
code: str
|
| 126 |
+
task_id: Optional[str] = None
|
| 127 |
+
max_steps: int = Field(default=5, ge=1, le=5)
|
| 128 |
+
use_rl: bool = True
|
| 129 |
+
use_llm: bool = False
|
| 130 |
+
fallback_to_llm: bool = True
|
| 131 |
+
rl_model_path: Optional[str] = None
|
| 132 |
+
api_base_url: Optional[str] = None
|
| 133 |
+
model_name: Optional[str] = None
|
| 134 |
+
api_token: Optional[str] = None
|
| 135 |
+
|
| 136 |
+
|
| 137 |
+
class OptimizationStep(BaseModel):
|
| 138 |
+
step: int
|
| 139 |
+
action: int
|
| 140 |
+
action_name: str
|
| 141 |
+
reason: str
|
| 142 |
+
source: str
|
| 143 |
+
reward: float
|
| 144 |
+
normalized_reward: float
|
| 145 |
+
changed: bool
|
| 146 |
+
complexity: float
|
| 147 |
+
|
| 148 |
+
|
| 149 |
+
class OptimizeResponse(BaseModel):
|
| 150 |
+
original_code: str
|
| 151 |
+
optimized_code: str
|
| 152 |
+
diff: str
|
| 153 |
+
steps: List[OptimizationStep]
|
| 154 |
+
cumulative_reward: float
|
| 155 |
+
task_id: Optional[str]
|
| 156 |
+
task_score: Optional[float]
|
openenv.yaml
ADDED
|
@@ -0,0 +1,85 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
name: ACRE
|
| 2 |
+
version: "1.0.0"
|
| 3 |
+
description: >
|
| 4 |
+
Autonomous Code Refactoring Environment - an RL environment where an
|
| 5 |
+
agent improves Python code quality using AST-level transformations.
|
| 6 |
+
author: "Nikhil Pratap Singh, Pranav Mangal, Ananya Gupta"
|
| 7 |
+
entrypoint: "openenv_interface:OpenEnvRefactorEnv"
|
| 8 |
+
tags:
|
| 9 |
+
- openenv
|
| 10 |
+
|
| 11 |
+
tasks:
|
| 12 |
+
- id: rename_variables
|
| 13 |
+
name: "Rename Variables (Easy)"
|
| 14 |
+
description: "Rename generic variable names (x, tmp) to descriptive ones"
|
| 15 |
+
difficulty: easy
|
| 16 |
+
reward_range: [0.0, 1.0]
|
| 17 |
+
max_steps: 5
|
| 18 |
+
|
| 19 |
+
- id: remove_dead_code
|
| 20 |
+
name: "Remove Dead Code (Medium)"
|
| 21 |
+
description: "Remove unreachable statements, if-False blocks, and unused assignments"
|
| 22 |
+
difficulty: medium
|
| 23 |
+
reward_range: [0.0, 1.0]
|
| 24 |
+
max_steps: 5
|
| 25 |
+
|
| 26 |
+
- id: full_refactor
|
| 27 |
+
name: "Full Refactor (Hard)"
|
| 28 |
+
description: "Apply all transformations - rename, dead code removal, loop simplification, condition optimization, and function inlining"
|
| 29 |
+
difficulty: hard
|
| 30 |
+
reward_range: [0.0, 1.0]
|
| 31 |
+
max_steps: 5
|
| 32 |
+
|
| 33 |
+
observation_space:
|
| 34 |
+
type: Box
|
| 35 |
+
shape: [4]
|
| 36 |
+
dtype: float32
|
| 37 |
+
low: [0.0, 0.0, 0.0, 0.0]
|
| 38 |
+
high: [inf, inf, inf, 1.0]
|
| 39 |
+
fields:
|
| 40 |
+
- code_length
|
| 41 |
+
- complexity_score
|
| 42 |
+
- runtime_s
|
| 43 |
+
- error_flag
|
| 44 |
+
|
| 45 |
+
action_space:
|
| 46 |
+
type: Discrete
|
| 47 |
+
n: 5
|
| 48 |
+
actions:
|
| 49 |
+
0: rename_variable
|
| 50 |
+
1: remove_dead_code
|
| 51 |
+
2: simplify_loop
|
| 52 |
+
3: optimize_condition
|
| 53 |
+
4: inline_function
|
| 54 |
+
|
| 55 |
+
api:
|
| 56 |
+
health: "GET /"
|
| 57 |
+
reset: "POST /reset"
|
| 58 |
+
step: "POST /step"
|
| 59 |
+
state: "GET /state"
|
| 60 |
+
tasks: "GET /tasks"
|
| 61 |
+
grade: "POST /tasks/{task_id}/grade"
|
| 62 |
+
|
| 63 |
+
reward:
|
| 64 |
+
raw_range: [-32, 20]
|
| 65 |
+
normalized_range: [0.0, 1.0]
|
| 66 |
+
formula: "(raw + 32) / 52"
|
| 67 |
+
components:
|
| 68 |
+
success: { max: 10, min: -10 }
|
| 69 |
+
complexity: { max: 5, min: -5 }
|
| 70 |
+
performance: { max: 5, min: -2 }
|
| 71 |
+
error: { max: 0, min: -15 }
|
| 72 |
+
no_change: { max: 0, min: -2 }
|
| 73 |
+
|
| 74 |
+
validation:
|
| 75 |
+
python_api:
|
| 76 |
+
reset: "ObservationModel"
|
| 77 |
+
step: "(ObservationModel, RewardModel, done, info)"
|
| 78 |
+
state: "StateResponse"
|
| 79 |
+
http_api:
|
| 80 |
+
health: "GET /"
|
| 81 |
+
reset: "POST /reset"
|
| 82 |
+
step: "POST /step"
|
| 83 |
+
state: "GET /state"
|
| 84 |
+
tasks: "GET /tasks"
|
| 85 |
+
grade: "POST /tasks/{task_id}/grade"
|
openenv_interface.py
ADDED
|
@@ -0,0 +1,116 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
from typing import Any, Dict, Optional, Tuple
|
| 4 |
+
|
| 5 |
+
try:
|
| 6 |
+
from openenv.env import Env as OpenEnvBase
|
| 7 |
+
except Exception: # pragma: no cover
|
| 8 |
+
class OpenEnvBase:
|
| 9 |
+
def __init__(self, *args: Any, **kwargs: Any) -> None:
|
| 10 |
+
return None
|
| 11 |
+
|
| 12 |
+
from acre.datasets.code_samples import CodeSample, CodeSampleDataset
|
| 13 |
+
from acre.env.refactor_env import RefactorEnv
|
| 14 |
+
from acre.tasks.task_registry import TaskRegistry
|
| 15 |
+
from models import ActionModel, ObservationModel, RewardModel, StateResponse
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
class OpenEnvRefactorEnv(OpenEnvBase):
|
| 19 |
+
"""
|
| 20 |
+
Canonical OpenEnv interface for ACRE.
|
| 21 |
+
|
| 22 |
+
This wrapper keeps the strict hackathon contract:
|
| 23 |
+
- reset() -> ObservationModel
|
| 24 |
+
- step(action) -> (ObservationModel, RewardModel, done, info)
|
| 25 |
+
- state() -> StateResponse
|
| 26 |
+
"""
|
| 27 |
+
|
| 28 |
+
def __init__(
|
| 29 |
+
self,
|
| 30 |
+
*,
|
| 31 |
+
env: Optional[RefactorEnv] = None,
|
| 32 |
+
registry: Optional[TaskRegistry] = None,
|
| 33 |
+
) -> None:
|
| 34 |
+
super().__init__(
|
| 35 |
+
name="ACRE",
|
| 36 |
+
state_space="ObservationModel",
|
| 37 |
+
action_space="ActionModel",
|
| 38 |
+
episode_max_length=RefactorEnv.MAX_STEPS,
|
| 39 |
+
)
|
| 40 |
+
self._env = env or RefactorEnv()
|
| 41 |
+
self._registry = registry or TaskRegistry()
|
| 42 |
+
self._task_id: Optional[str] = None
|
| 43 |
+
self._last_reset_info: Dict[str, Any] = {}
|
| 44 |
+
|
| 45 |
+
@property
|
| 46 |
+
def action_meanings(self) -> Dict[int, str]:
|
| 47 |
+
return self._env.ACTION_MEANINGS
|
| 48 |
+
|
| 49 |
+
@property
|
| 50 |
+
def last_reset_info(self) -> Dict[str, Any]:
|
| 51 |
+
return dict(self._last_reset_info)
|
| 52 |
+
|
| 53 |
+
def _load_episode_source(self, *, task_id: Optional[str], code: Optional[str]) -> None:
|
| 54 |
+
initial_code = code
|
| 55 |
+
if initial_code is None and task_id:
|
| 56 |
+
task = self._registry.get_task(task_id)
|
| 57 |
+
if task is None:
|
| 58 |
+
raise ValueError(f"Task '{task_id}' not found")
|
| 59 |
+
initial_code = task.initial_code
|
| 60 |
+
|
| 61 |
+
if initial_code is None:
|
| 62 |
+
return None
|
| 63 |
+
|
| 64 |
+
self._env.dataset = CodeSampleDataset(
|
| 65 |
+
[
|
| 66 |
+
CodeSample(
|
| 67 |
+
id=task_id or "custom",
|
| 68 |
+
language="python",
|
| 69 |
+
code=initial_code,
|
| 70 |
+
)
|
| 71 |
+
]
|
| 72 |
+
)
|
| 73 |
+
return None
|
| 74 |
+
|
| 75 |
+
def reset(
|
| 76 |
+
self,
|
| 77 |
+
*,
|
| 78 |
+
seed: Optional[int] = None,
|
| 79 |
+
task_id: Optional[str] = None,
|
| 80 |
+
code: Optional[str] = None,
|
| 81 |
+
) -> ObservationModel:
|
| 82 |
+
self._task_id = task_id
|
| 83 |
+
self._load_episode_source(task_id=task_id, code=code)
|
| 84 |
+
observation, info = self._env.reset(seed=seed)
|
| 85 |
+
self._last_reset_info = dict(info)
|
| 86 |
+
return ObservationModel.from_vector(observation.tolist())
|
| 87 |
+
|
| 88 |
+
def step(self, action: int | ActionModel) -> Tuple[ObservationModel, RewardModel, bool, Dict[str, Any]]:
|
| 89 |
+
action_value = action.action if isinstance(action, ActionModel) else int(action)
|
| 90 |
+
observation, raw_reward, terminated, truncated, info = self._env.step(action_value)
|
| 91 |
+
reward = RewardModel(
|
| 92 |
+
raw=float(raw_reward),
|
| 93 |
+
normalized=float(info.get("normalized_reward", 0.0)),
|
| 94 |
+
components=dict(info.get("reward_components", {})),
|
| 95 |
+
)
|
| 96 |
+
done = bool(terminated or truncated)
|
| 97 |
+
return ObservationModel.from_vector(observation.tolist()), reward, done, dict(info)
|
| 98 |
+
|
| 99 |
+
def state(self) -> StateResponse:
|
| 100 |
+
raw_state = self._env.state()
|
| 101 |
+
observation_vector = list(raw_state.get("observation", [0.0, 0.0, 0.0, 0.0]))
|
| 102 |
+
observation = ObservationModel.from_vector(observation_vector)
|
| 103 |
+
return StateResponse(
|
| 104 |
+
current_code=str(raw_state.get("current_code", "")),
|
| 105 |
+
episode_steps=int(raw_state.get("episode_steps", 0)),
|
| 106 |
+
max_steps=int(raw_state.get("max_steps", RefactorEnv.MAX_STEPS)),
|
| 107 |
+
complexity=float(raw_state.get("complexity", 0.0)),
|
| 108 |
+
last_runtime=float(raw_state.get("last_runtime", 0.0)),
|
| 109 |
+
last_error=bool(raw_state.get("last_error", False)),
|
| 110 |
+
sample_id=raw_state.get("sample_id"),
|
| 111 |
+
language=raw_state.get("language"),
|
| 112 |
+
task_id=self._task_id,
|
| 113 |
+
observation=observation,
|
| 114 |
+
observation_vector=observation.to_vector(),
|
| 115 |
+
action_meanings=dict(raw_state.get("action_meanings", {})),
|
| 116 |
+
)
|
requirements.txt
ADDED
|
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
fastapi>=0.109.0
|
| 2 |
+
uvicorn[standard]>=0.27.0
|
| 3 |
+
numpy>=1.26
|
| 4 |
+
gymnasium
|
| 5 |
+
stable-baselines3
|
| 6 |
+
radon>=6.0.1
|
| 7 |
+
openai>=1.0.0
|
| 8 |
+
openenv>=0.1.13
|
| 9 |
+
requests>=2.31.0
|
| 10 |
+
pydantic>=2.0.0
|
| 11 |
+
typing_extensions>=4.0.0
|
server.py
ADDED
|
@@ -0,0 +1,667 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
ACRE OpenEnv HTTP server.
|
| 3 |
+
|
| 4 |
+
Endpoints (all required by OpenEnv spec):
|
| 5 |
+
GET / — health check (must return HTTP 200)
|
| 6 |
+
POST /reset — reset environment, returns observation + info
|
| 7 |
+
POST /step — take one step, returns obs/reward/done/info
|
| 8 |
+
GET /state — full current state snapshot
|
| 9 |
+
GET /tasks — list all tasks with initial code
|
| 10 |
+
POST /tasks/{task_id}/grade — grade code for a specific task
|
| 11 |
+
"""
|
| 12 |
+
from __future__ import annotations
|
| 13 |
+
|
| 14 |
+
import difflib
|
| 15 |
+
import os
|
| 16 |
+
import re
|
| 17 |
+
import json
|
| 18 |
+
from typing import Optional
|
| 19 |
+
|
| 20 |
+
import uvicorn
|
| 21 |
+
import numpy as np
|
| 22 |
+
from fastapi import FastAPI, HTTPException
|
| 23 |
+
from fastapi.middleware.cors import CORSMiddleware
|
| 24 |
+
from fastapi.responses import HTMLResponse
|
| 25 |
+
from openai import OpenAI
|
| 26 |
+
|
| 27 |
+
try:
|
| 28 |
+
from stable_baselines3 import PPO
|
| 29 |
+
except Exception:
|
| 30 |
+
PPO = None # type: ignore[assignment]
|
| 31 |
+
|
| 32 |
+
from acre.tasks.task_registry import TaskRegistry
|
| 33 |
+
from models import (
|
| 34 |
+
ActionModel,
|
| 35 |
+
CompatibilityHealthResponse,
|
| 36 |
+
GradeRequest,
|
| 37 |
+
GradeResponse,
|
| 38 |
+
HealthResponse,
|
| 39 |
+
OptimizationStep,
|
| 40 |
+
OptimizeRequest,
|
| 41 |
+
OptimizeResponse,
|
| 42 |
+
ResetRequest,
|
| 43 |
+
ResetResponse,
|
| 44 |
+
StateResponse,
|
| 45 |
+
StepRequest,
|
| 46 |
+
StepResponse,
|
| 47 |
+
TaskInfo,
|
| 48 |
+
TasksResponse,
|
| 49 |
+
)
|
| 50 |
+
from openenv_interface import OpenEnvRefactorEnv
|
| 51 |
+
|
| 52 |
+
DEFAULT_API_BASE_URL = os.getenv("API_BASE_URL", "https://api.openai.com/v1")
|
| 53 |
+
DEFAULT_MODEL_NAME = os.getenv("MODEL_NAME", "gpt-4o-mini")
|
| 54 |
+
DEFAULT_RL_MODEL_PATH = os.getenv("RL_MODEL_PATH", "acre_agent.zip")
|
| 55 |
+
|
| 56 |
+
# ---------------------------------------------------------------------------
|
| 57 |
+
# App setup
|
| 58 |
+
# ---------------------------------------------------------------------------
|
| 59 |
+
|
| 60 |
+
app = FastAPI(
|
| 61 |
+
title="ACRE — Autonomous Code Refactoring Environment",
|
| 62 |
+
description="OpenEnv-compatible RL environment for Python code refactoring.",
|
| 63 |
+
version="1.0.0",
|
| 64 |
+
)
|
| 65 |
+
|
| 66 |
+
app.add_middleware(
|
| 67 |
+
CORSMiddleware,
|
| 68 |
+
allow_origins=["*"],
|
| 69 |
+
allow_methods=["*"],
|
| 70 |
+
allow_headers=["*"],
|
| 71 |
+
)
|
| 72 |
+
|
| 73 |
+
# Global singletons
|
| 74 |
+
registry = TaskRegistry()
|
| 75 |
+
_env: Optional[OpenEnvRefactorEnv] = None
|
| 76 |
+
_rl_model_cache: dict[str, object] = {}
|
| 77 |
+
|
| 78 |
+
|
| 79 |
+
def get_env() -> OpenEnvRefactorEnv:
|
| 80 |
+
global _env
|
| 81 |
+
if _env is None:
|
| 82 |
+
_env = OpenEnvRefactorEnv(registry=registry)
|
| 83 |
+
return _env
|
| 84 |
+
|
| 85 |
+
|
| 86 |
+
def _state_response() -> StateResponse:
|
| 87 |
+
return get_env().state()
|
| 88 |
+
|
| 89 |
+
|
| 90 |
+
def _choose_action_heuristic(code: str, task_id: Optional[str]) -> int:
|
| 91 |
+
has_generic = re.search(r"\b(x|tmp|i)\b", code) is not None
|
| 92 |
+
has_if_false = re.search(r"\bif\s+False\b", code) is not None
|
| 93 |
+
has_if_true = re.search(r"\bif\s+True\b", code) is not None
|
| 94 |
+
has_append_loop = ".append(" in code and "for " in code
|
| 95 |
+
has_double_not = "not not" in code
|
| 96 |
+
has_add_call = "add(" in code
|
| 97 |
+
|
| 98 |
+
if task_id == "rename_variables":
|
| 99 |
+
if has_generic:
|
| 100 |
+
return 0
|
| 101 |
+
if has_if_false or "unused" in code:
|
| 102 |
+
return 1
|
| 103 |
+
if has_append_loop:
|
| 104 |
+
return 2
|
| 105 |
+
if has_if_true or has_double_not:
|
| 106 |
+
return 3
|
| 107 |
+
return 4
|
| 108 |
+
|
| 109 |
+
if task_id == "remove_dead_code":
|
| 110 |
+
if has_if_false or "unused" in code:
|
| 111 |
+
return 1
|
| 112 |
+
if has_append_loop:
|
| 113 |
+
return 2
|
| 114 |
+
if has_if_true or has_double_not:
|
| 115 |
+
return 3
|
| 116 |
+
if has_generic:
|
| 117 |
+
return 0
|
| 118 |
+
return 4
|
| 119 |
+
|
| 120 |
+
if has_generic:
|
| 121 |
+
return 0
|
| 122 |
+
if has_append_loop:
|
| 123 |
+
return 2
|
| 124 |
+
if has_if_false or has_if_true or has_double_not:
|
| 125 |
+
return 3
|
| 126 |
+
if has_add_call:
|
| 127 |
+
return 4
|
| 128 |
+
return 1
|
| 129 |
+
|
| 130 |
+
|
| 131 |
+
def _choose_action_llm(
|
| 132 |
+
*,
|
| 133 |
+
code: str,
|
| 134 |
+
task_id: Optional[str],
|
| 135 |
+
step_index: int,
|
| 136 |
+
max_steps: int,
|
| 137 |
+
api_base_url: str,
|
| 138 |
+
model_name: str,
|
| 139 |
+
api_token: str,
|
| 140 |
+
) -> tuple[int, str, str]:
|
| 141 |
+
if not api_token.strip():
|
| 142 |
+
return _choose_action_heuristic(code, task_id), "empty token -> heuristic", "heuristic"
|
| 143 |
+
|
| 144 |
+
client = OpenAI(base_url=api_base_url, api_key=api_token)
|
| 145 |
+
messages = [
|
| 146 |
+
{
|
| 147 |
+
"role": "system",
|
| 148 |
+
"content": (
|
| 149 |
+
"You are a code-refactoring action selector. Return ONLY compact JSON: "
|
| 150 |
+
'{"action": <0-4>, "reason": "..."}.\n'
|
| 151 |
+
"Actions: 0=rename_variable,1=remove_dead_code,2=simplify_loop,3=optimize_condition,4=inline_function"
|
| 152 |
+
),
|
| 153 |
+
},
|
| 154 |
+
{
|
| 155 |
+
"role": "user",
|
| 156 |
+
"content": (
|
| 157 |
+
f"task_id={task_id or 'auto'}\n"
|
| 158 |
+
f"step={step_index}/{max_steps}\n"
|
| 159 |
+
"Current code:\n"
|
| 160 |
+
f"```python\n{code}\n```"
|
| 161 |
+
),
|
| 162 |
+
},
|
| 163 |
+
]
|
| 164 |
+
try:
|
| 165 |
+
resp = client.chat.completions.create(
|
| 166 |
+
model=model_name,
|
| 167 |
+
messages=messages,
|
| 168 |
+
temperature=0.0,
|
| 169 |
+
max_tokens=120,
|
| 170 |
+
)
|
| 171 |
+
raw = (resp.choices[0].message.content or "").strip()
|
| 172 |
+
m = re.search(r"\{.*\}", raw, flags=re.DOTALL)
|
| 173 |
+
blob = m.group(0) if m else raw
|
| 174 |
+
parsed = json.loads(blob)
|
| 175 |
+
action = int(parsed.get("action", -1))
|
| 176 |
+
reason = str(parsed.get("reason", "llm-selected action"))
|
| 177 |
+
if 0 <= action <= 4:
|
| 178 |
+
return action, reason, "llm"
|
| 179 |
+
except Exception as exc:
|
| 180 |
+
return _choose_action_heuristic(code, task_id), f"llm error -> heuristic: {exc}", "heuristic"
|
| 181 |
+
|
| 182 |
+
return _choose_action_heuristic(code, task_id), "invalid llm output -> heuristic", "heuristic"
|
| 183 |
+
|
| 184 |
+
|
| 185 |
+
def _choose_action_rl(observation: list[float], model_path: str) -> tuple[Optional[int], str, str]:
|
| 186 |
+
if PPO is None:
|
| 187 |
+
return None, "stable-baselines3 unavailable", "rl"
|
| 188 |
+
if not os.path.exists(model_path):
|
| 189 |
+
return None, f"rl model not found: {model_path}", "rl"
|
| 190 |
+
|
| 191 |
+
try:
|
| 192 |
+
model = _rl_model_cache.get(model_path)
|
| 193 |
+
if model is None:
|
| 194 |
+
model = PPO.load(model_path)
|
| 195 |
+
_rl_model_cache[model_path] = model
|
| 196 |
+
|
| 197 |
+
obs = np.asarray(observation, dtype=np.float32)
|
| 198 |
+
action, _ = model.predict(obs, deterministic=True)
|
| 199 |
+
action_i = int(action)
|
| 200 |
+
if 0 <= action_i <= 4:
|
| 201 |
+
return action_i, "rl policy action", "rl"
|
| 202 |
+
return None, f"invalid rl action: {action_i}", "rl"
|
| 203 |
+
except Exception as exc:
|
| 204 |
+
return None, f"rl failure: {exc}", "rl"
|
| 205 |
+
|
| 206 |
+
|
| 207 |
+
def _demo_html() -> str:
|
| 208 |
+
return """<!doctype html>
|
| 209 |
+
<html lang=\"en\">
|
| 210 |
+
<head>
|
| 211 |
+
<meta charset=\"utf-8\" />
|
| 212 |
+
<meta name=\"viewport\" content=\"width=device-width, initial-scale=1\" />
|
| 213 |
+
<title>ACRE Refactor Demo</title>
|
| 214 |
+
<style>
|
| 215 |
+
@import url('https://fonts.googleapis.com/css2?family=Space+Grotesk:wght@400;600;700&display=swap');
|
| 216 |
+
:root {
|
| 217 |
+
--bg0: #0b1f2a;
|
| 218 |
+
--bg1: #14344a;
|
| 219 |
+
--ink: #eaf7ff;
|
| 220 |
+
--muted: #a7c8db;
|
| 221 |
+
--brand: #1ec28b;
|
| 222 |
+
--warn: #ffcb47;
|
| 223 |
+
--panel: rgba(8, 24, 36, 0.72);
|
| 224 |
+
--stroke: rgba(140, 197, 225, 0.35);
|
| 225 |
+
}
|
| 226 |
+
* { box-sizing: border-box; }
|
| 227 |
+
body {
|
| 228 |
+
margin: 0;
|
| 229 |
+
color: var(--ink);
|
| 230 |
+
font-family: 'Space Grotesk', sans-serif;
|
| 231 |
+
background:
|
| 232 |
+
radial-gradient(circle at 12% 18%, rgba(30, 194, 139, 0.28), transparent 35%),
|
| 233 |
+
radial-gradient(circle at 88% 8%, rgba(255, 203, 71, 0.22), transparent 30%),
|
| 234 |
+
linear-gradient(150deg, var(--bg0), var(--bg1));
|
| 235 |
+
min-height: 100vh;
|
| 236 |
+
}
|
| 237 |
+
.wrap {
|
| 238 |
+
max-width: 1200px;
|
| 239 |
+
margin: 0 auto;
|
| 240 |
+
padding: 28px 20px 40px;
|
| 241 |
+
}
|
| 242 |
+
h1 {
|
| 243 |
+
margin: 0 0 6px;
|
| 244 |
+
font-size: clamp(1.6rem, 2vw + 1rem, 2.6rem);
|
| 245 |
+
letter-spacing: 0.2px;
|
| 246 |
+
}
|
| 247 |
+
.sub { margin: 0 0 20px; color: var(--muted); }
|
| 248 |
+
.grid {
|
| 249 |
+
display: grid;
|
| 250 |
+
grid-template-columns: 1fr;
|
| 251 |
+
gap: 16px;
|
| 252 |
+
}
|
| 253 |
+
.panel {
|
| 254 |
+
border: 1px solid var(--stroke);
|
| 255 |
+
border-radius: 14px;
|
| 256 |
+
background: var(--panel);
|
| 257 |
+
backdrop-filter: blur(4px);
|
| 258 |
+
padding: 14px;
|
| 259 |
+
}
|
| 260 |
+
.controls {
|
| 261 |
+
display: grid;
|
| 262 |
+
grid-template-columns: 1fr 1fr;
|
| 263 |
+
gap: 8px;
|
| 264 |
+
margin-bottom: 10px;
|
| 265 |
+
}
|
| 266 |
+
textarea, pre {
|
| 267 |
+
width: 100%;
|
| 268 |
+
min-height: 260px;
|
| 269 |
+
border: 1px solid var(--stroke);
|
| 270 |
+
border-radius: 10px;
|
| 271 |
+
padding: 12px;
|
| 272 |
+
background: rgba(1, 13, 24, 0.82);
|
| 273 |
+
color: #dcf4ff;
|
| 274 |
+
font-family: Consolas, 'Courier New', monospace;
|
| 275 |
+
font-size: 13px;
|
| 276 |
+
line-height: 1.4;
|
| 277 |
+
overflow: auto;
|
| 278 |
+
white-space: pre;
|
| 279 |
+
}
|
| 280 |
+
button, select {
|
| 281 |
+
border: 1px solid var(--stroke);
|
| 282 |
+
border-radius: 10px;
|
| 283 |
+
padding: 10px 12px;
|
| 284 |
+
background: rgba(11, 36, 52, 0.9);
|
| 285 |
+
color: var(--ink);
|
| 286 |
+
font-weight: 600;
|
| 287 |
+
}
|
| 288 |
+
button.primary {
|
| 289 |
+
background: linear-gradient(120deg, #19a7ff, #1ec28b);
|
| 290 |
+
color: #032235;
|
| 291 |
+
border: none;
|
| 292 |
+
}
|
| 293 |
+
.cols {
|
| 294 |
+
display: grid;
|
| 295 |
+
grid-template-columns: 1fr;
|
| 296 |
+
gap: 14px;
|
| 297 |
+
}
|
| 298 |
+
.meta {
|
| 299 |
+
color: var(--muted);
|
| 300 |
+
font-size: 0.92rem;
|
| 301 |
+
margin-top: 8px;
|
| 302 |
+
}
|
| 303 |
+
.badge {
|
| 304 |
+
color: #082b22;
|
| 305 |
+
background: var(--brand);
|
| 306 |
+
border-radius: 999px;
|
| 307 |
+
padding: 2px 9px;
|
| 308 |
+
font-size: 12px;
|
| 309 |
+
font-weight: 700;
|
| 310 |
+
}
|
| 311 |
+
.warn {
|
| 312 |
+
color: #2a1c00;
|
| 313 |
+
background: var(--warn);
|
| 314 |
+
}
|
| 315 |
+
@media (min-width: 900px) {
|
| 316 |
+
.cols { grid-template-columns: 1fr 1fr; }
|
| 317 |
+
}
|
| 318 |
+
</style>
|
| 319 |
+
</head>
|
| 320 |
+
<body>
|
| 321 |
+
<div class=\"wrap\">
|
| 322 |
+
<h1>ACRE Live Refactor Arena</h1>
|
| 323 |
+
<p class=\"sub\">Paste old code, run the agent, and compare before and after with a full diff and step-by-step rewards.</p>
|
| 324 |
+
|
| 325 |
+
<div class=\"panel\">
|
| 326 |
+
<div class=\"controls\">
|
| 327 |
+
<button onclick=\"loadExample(1)\">Load Example 1</button>
|
| 328 |
+
<button onclick=\"loadExample(2)\">Load Example 2</button>
|
| 329 |
+
<select id=\"task\">
|
| 330 |
+
<option value=\"\">Auto strategy</option>
|
| 331 |
+
<option value=\"rename_variables\">rename_variables</option>
|
| 332 |
+
<option value=\"remove_dead_code\">remove_dead_code</option>
|
| 333 |
+
<option value=\"full_refactor\">full_refactor</option>
|
| 334 |
+
</select>
|
| 335 |
+
<button class=\"primary\" onclick=\"runOptimize()\">Run Optimization</button>
|
| 336 |
+
</div>
|
| 337 |
+
<div class=\"controls\" style=\"margin-bottom: 10px;\">
|
| 338 |
+
<select id=\"mode\">
|
| 339 |
+
<option value=\"rl_then_llm\">RL First -> LLM Fallback</option>
|
| 340 |
+
<option value=\"heuristic\">Heuristic Agent (no API key)</option>
|
| 341 |
+
<option value=\"llm\">LLM Agent (OpenAI-compatible API)</option>
|
| 342 |
+
</select>
|
| 343 |
+
<input id=\"rlModelPath\" placeholder=\"RL model path\" value=\"acre_agent.zip\" style=\"border:1px solid var(--stroke);border-radius:10px;padding:10px 12px;background:rgba(1,13,24,0.82);color:#dcf4ff;\" />
|
| 344 |
+
<input id=\"baseUrl\" placeholder=\"API base URL (optional)\" value=\"https://api.openai.com/v1\" style=\"border:1px solid var(--stroke);border-radius:10px;padding:10px 12px;background:rgba(1,13,24,0.82);color:#dcf4ff;\" />
|
| 345 |
+
<input id=\"modelName\" placeholder=\"Model name (optional)\" value=\"gpt-4o-mini\" style=\"border:1px solid var(--stroke);border-radius:10px;padding:10px 12px;background:rgba(1,13,24,0.82);color:#dcf4ff;\" />
|
| 346 |
+
<input id=\"apiToken\" type=\"password\" placeholder=\"Paste API token here for LLM mode\" style=\"border:1px solid var(--stroke);border-radius:10px;padding:10px 12px;background:rgba(1,13,24,0.82);color:#dcf4ff;\" />
|
| 347 |
+
</div>
|
| 348 |
+
<div class=\"controls\" style=\"margin-bottom: 10px;\">
|
| 349 |
+
<label style=\"display:flex;align-items:center;gap:8px;padding:8px 10px;border:1px solid var(--stroke);border-radius:10px;\">
|
| 350 |
+
<input id=\"autoSuggest\" type=\"checkbox\" />
|
| 351 |
+
Auto suggest after typing pause
|
| 352 |
+
</label>
|
| 353 |
+
</div>
|
| 354 |
+
<textarea id=\"input\" spellcheck=\"false\" placeholder=\"Paste your Python code here...\"></textarea>
|
| 355 |
+
<p class=\"meta\" id=\"status\">Status: ready</p>
|
| 356 |
+
</div>
|
| 357 |
+
|
| 358 |
+
<div class=\"cols\" style=\"margin-top: 14px\">
|
| 359 |
+
<div class=\"panel\">
|
| 360 |
+
<h3>Original Code</h3>
|
| 361 |
+
<pre id=\"original\"></pre>
|
| 362 |
+
</div>
|
| 363 |
+
<div class=\"panel\">
|
| 364 |
+
<h3>Optimized Code</h3>
|
| 365 |
+
<pre id=\"optimized\"></pre>
|
| 366 |
+
</div>
|
| 367 |
+
</div>
|
| 368 |
+
|
| 369 |
+
<div class=\"panel\" style=\"margin-top: 14px\">
|
| 370 |
+
<h3>Diff</h3>
|
| 371 |
+
<pre id=\"diff\"></pre>
|
| 372 |
+
</div>
|
| 373 |
+
|
| 374 |
+
<div class=\"panel\" style=\"margin-top: 14px\">
|
| 375 |
+
<h3>Step Logs</h3>
|
| 376 |
+
<pre id=\"steps\"></pre>
|
| 377 |
+
</div>
|
| 378 |
+
</div>
|
| 379 |
+
|
| 380 |
+
<script>
|
| 381 |
+
const EX1 = `def compute(x, y, tmp):\n tmp = x + y\n x = tmp * 2\n result = x\n return result\n`;
|
| 382 |
+
const EX2 = `def add(p, q):\n return p + q\n\ndef compute(x, data, tmp):\n result = []\n for item in data:\n result.append(item * 2)\n if False:\n y = 999\n if True:\n val = add(x, tmp)\n unused = 0\n flag = not not True\n return val\n print(\"dead\")\n`;
|
| 383 |
+
let autoTimer = null;
|
| 384 |
+
|
| 385 |
+
function loadExample(i) {
|
| 386 |
+
document.getElementById('input').value = i === 1 ? EX1 : EX2;
|
| 387 |
+
document.getElementById('status').textContent = `Status: loaded example ${i}`;
|
| 388 |
+
}
|
| 389 |
+
|
| 390 |
+
async function runOptimize() {
|
| 391 |
+
const code = document.getElementById('input').value;
|
| 392 |
+
const task = document.getElementById('task').value || null;
|
| 393 |
+
const mode = document.getElementById('mode').value;
|
| 394 |
+
const useRl = mode === 'rl_then_llm';
|
| 395 |
+
const useLlm = mode === 'llm' || mode === 'rl_then_llm';
|
| 396 |
+
const fallbackToLlm = mode === 'rl_then_llm';
|
| 397 |
+
const rlModelPath = document.getElementById('rlModelPath').value || null;
|
| 398 |
+
const apiToken = document.getElementById('apiToken').value || null;
|
| 399 |
+
const apiBaseUrl = document.getElementById('baseUrl').value || null;
|
| 400 |
+
const modelName = document.getElementById('modelName').value || null;
|
| 401 |
+
if (!code.trim()) {
|
| 402 |
+
document.getElementById('status').innerHTML = 'Status: <span class=\"badge warn\">please paste code first</span>';
|
| 403 |
+
return;
|
| 404 |
+
}
|
| 405 |
+
if (mode === 'llm' && (!apiToken || !apiToken.trim())) {
|
| 406 |
+
document.getElementById('status').innerHTML = 'Status: <span class=\"badge warn\">paste API token for LLM mode</span>';
|
| 407 |
+
return;
|
| 408 |
+
}
|
| 409 |
+
|
| 410 |
+
document.getElementById('status').textContent = 'Status: running optimization...';
|
| 411 |
+
try {
|
| 412 |
+
const res = await fetch('/optimize', {
|
| 413 |
+
method: 'POST',
|
| 414 |
+
headers: {'Content-Type': 'application/json'},
|
| 415 |
+
body: JSON.stringify({
|
| 416 |
+
code,
|
| 417 |
+
task_id: task,
|
| 418 |
+
max_steps: 5,
|
| 419 |
+
use_rl: useRl,
|
| 420 |
+
use_llm: useLlm,
|
| 421 |
+
fallback_to_llm: fallbackToLlm,
|
| 422 |
+
rl_model_path: rlModelPath,
|
| 423 |
+
api_base_url: apiBaseUrl,
|
| 424 |
+
model_name: modelName,
|
| 425 |
+
api_token: apiToken,
|
| 426 |
+
})
|
| 427 |
+
});
|
| 428 |
+
const data = await res.json();
|
| 429 |
+
if (!res.ok) {
|
| 430 |
+
throw new Error(data.detail || 'request failed');
|
| 431 |
+
}
|
| 432 |
+
|
| 433 |
+
document.getElementById('original').textContent = data.original_code;
|
| 434 |
+
document.getElementById('optimized').textContent = data.optimized_code;
|
| 435 |
+
document.getElementById('diff').textContent = data.diff || '(no diff)';
|
| 436 |
+
document.getElementById('steps').textContent = JSON.stringify(data.steps, null, 2);
|
| 437 |
+
|
| 438 |
+
const scoreText = data.task_score === null ? 'n/a' : data.task_score;
|
| 439 |
+
document.getElementById('status').innerHTML = `Status: <span class=\"badge\">done</span> cumulative_reward=${data.cumulative_reward.toFixed(2)} task_score=${scoreText}`;
|
| 440 |
+
} catch (err) {
|
| 441 |
+
document.getElementById('status').innerHTML = `Status: <span class=\"badge warn\">error</span> ${err.message}`;
|
| 442 |
+
}
|
| 443 |
+
}
|
| 444 |
+
|
| 445 |
+
loadExample(1);
|
| 446 |
+
document.getElementById('input').addEventListener('input', () => {
|
| 447 |
+
if (!document.getElementById('autoSuggest').checked) {
|
| 448 |
+
return;
|
| 449 |
+
}
|
| 450 |
+
if (autoTimer) {
|
| 451 |
+
clearTimeout(autoTimer);
|
| 452 |
+
}
|
| 453 |
+
autoTimer = setTimeout(() => {
|
| 454 |
+
runOptimize();
|
| 455 |
+
}, 1200);
|
| 456 |
+
});
|
| 457 |
+
</script>
|
| 458 |
+
</body>
|
| 459 |
+
</html>"""
|
| 460 |
+
|
| 461 |
+
|
| 462 |
+
# ---------------------------------------------------------------------------
|
| 463 |
+
# Routes
|
| 464 |
+
# ---------------------------------------------------------------------------
|
| 465 |
+
|
| 466 |
+
@app.get("/", response_model=HealthResponse)
|
| 467 |
+
def health() -> HealthResponse:
|
| 468 |
+
"""Health check — OpenEnv pings this URL to verify the Space is live."""
|
| 469 |
+
return HealthResponse(status="ok", env="ACRE", version="1.0.0")
|
| 470 |
+
|
| 471 |
+
|
| 472 |
+
@app.get("/health", response_model=CompatibilityHealthResponse)
|
| 473 |
+
def health_compat() -> CompatibilityHealthResponse:
|
| 474 |
+
"""Compatibility health route used by some OpenEnv reference environments."""
|
| 475 |
+
return CompatibilityHealthResponse(status="healthy", service="acre-env")
|
| 476 |
+
|
| 477 |
+
|
| 478 |
+
@app.get("/demo", response_class=HTMLResponse)
|
| 479 |
+
def demo_ui() -> HTMLResponse:
|
| 480 |
+
"""Simple UI to compare original and optimized code side-by-side."""
|
| 481 |
+
return HTMLResponse(content=_demo_html())
|
| 482 |
+
|
| 483 |
+
|
| 484 |
+
@app.post("/reset", response_model=ResetResponse)
|
| 485 |
+
def reset(req: ResetRequest = ResetRequest()) -> ResetResponse:
|
| 486 |
+
"""Reset the environment. Optionally load a task's initial code."""
|
| 487 |
+
env = get_env()
|
| 488 |
+
try:
|
| 489 |
+
obs = env.reset(seed=req.seed, task_id=req.task_id, code=req.code)
|
| 490 |
+
except ValueError as exc:
|
| 491 |
+
raise HTTPException(status_code=404, detail=str(exc)) from exc
|
| 492 |
+
return ResetResponse(
|
| 493 |
+
observation=obs,
|
| 494 |
+
observation_vector=obs.to_vector(),
|
| 495 |
+
info=env.last_reset_info,
|
| 496 |
+
task_id=req.task_id,
|
| 497 |
+
state=_state_response(),
|
| 498 |
+
)
|
| 499 |
+
|
| 500 |
+
|
| 501 |
+
@app.post("/step", response_model=StepResponse)
|
| 502 |
+
def step(req: StepRequest) -> StepResponse:
|
| 503 |
+
"""Take one refactoring step."""
|
| 504 |
+
env = get_env()
|
| 505 |
+
if not (0 <= req.action <= 4):
|
| 506 |
+
raise HTTPException(status_code=400, detail="action must be 0–4")
|
| 507 |
+
|
| 508 |
+
obs, reward, done, info = env.step(req.action)
|
| 509 |
+
action_name = str(info.get("action_name", env.action_meanings.get(req.action, "unknown")))
|
| 510 |
+
|
| 511 |
+
return StepResponse(
|
| 512 |
+
action=ActionModel(action=req.action, action_name=action_name),
|
| 513 |
+
observation=obs,
|
| 514 |
+
observation_vector=obs.to_vector(),
|
| 515 |
+
reward=reward,
|
| 516 |
+
done=done,
|
| 517 |
+
terminated=done,
|
| 518 |
+
truncated=False,
|
| 519 |
+
info=info,
|
| 520 |
+
state=_state_response(),
|
| 521 |
+
)
|
| 522 |
+
|
| 523 |
+
|
| 524 |
+
@app.get("/state", response_model=StateResponse)
|
| 525 |
+
def state() -> StateResponse:
|
| 526 |
+
"""Return full current environment state (OpenEnv spec requirement)."""
|
| 527 |
+
return _state_response()
|
| 528 |
+
|
| 529 |
+
|
| 530 |
+
@app.get("/tasks", response_model=TasksResponse)
|
| 531 |
+
def list_tasks() -> TasksResponse:
|
| 532 |
+
"""Enumerate all tasks (easy → medium → hard)."""
|
| 533 |
+
return TasksResponse(tasks=[TaskInfo.model_validate(t) for t in registry.list_tasks()])
|
| 534 |
+
|
| 535 |
+
|
| 536 |
+
@app.post("/tasks/{task_id}/grade", response_model=GradeResponse)
|
| 537 |
+
def grade(task_id: str, req: GradeRequest) -> GradeResponse:
|
| 538 |
+
"""Grade submitted code against a task's grader (returns score 0.0–1.0)."""
|
| 539 |
+
task = registry.get_task(task_id)
|
| 540 |
+
if task is None:
|
| 541 |
+
raise HTTPException(status_code=404, detail=f"Task '{task_id}' not found")
|
| 542 |
+
score = task.grade(req.code)
|
| 543 |
+
return GradeResponse(
|
| 544 |
+
task_id=task_id,
|
| 545 |
+
score=round(score, 4),
|
| 546 |
+
passed=score >= 0.8,
|
| 547 |
+
)
|
| 548 |
+
|
| 549 |
+
|
| 550 |
+
@app.post("/optimize", response_model=OptimizeResponse)
|
| 551 |
+
def optimize(req: OptimizeRequest) -> OptimizeResponse:
|
| 552 |
+
"""Run a full optimization episode and return code comparison artifacts."""
|
| 553 |
+
code = req.code.strip("\n")
|
| 554 |
+
if not code.strip():
|
| 555 |
+
raise HTTPException(status_code=400, detail="code must be non-empty")
|
| 556 |
+
|
| 557 |
+
env = get_env()
|
| 558 |
+
try:
|
| 559 |
+
env.reset(task_id=req.task_id, code=code)
|
| 560 |
+
except ValueError as exc:
|
| 561 |
+
raise HTTPException(status_code=404, detail=str(exc)) from exc
|
| 562 |
+
|
| 563 |
+
steps: list[OptimizationStep] = []
|
| 564 |
+
cumulative_reward = 0.0
|
| 565 |
+
|
| 566 |
+
for step_idx in range(1, req.max_steps + 1):
|
| 567 |
+
state_now = env.state()
|
| 568 |
+
current_code = state_now.current_code
|
| 569 |
+
obs_list = [float(x) for x in state_now.observation_vector]
|
| 570 |
+
|
| 571 |
+
action: int
|
| 572 |
+
reason: str
|
| 573 |
+
source: str
|
| 574 |
+
|
| 575 |
+
if req.use_rl:
|
| 576 |
+
rl_action, rl_reason, rl_source = _choose_action_rl(
|
| 577 |
+
observation=obs_list,
|
| 578 |
+
model_path=req.rl_model_path or DEFAULT_RL_MODEL_PATH,
|
| 579 |
+
)
|
| 580 |
+
if rl_action is not None:
|
| 581 |
+
action, reason, source = rl_action, rl_reason, rl_source
|
| 582 |
+
elif req.fallback_to_llm and req.use_llm:
|
| 583 |
+
action, reason, source = _choose_action_llm(
|
| 584 |
+
code=current_code,
|
| 585 |
+
task_id=req.task_id,
|
| 586 |
+
step_index=step_idx,
|
| 587 |
+
max_steps=req.max_steps,
|
| 588 |
+
api_base_url=req.api_base_url or DEFAULT_API_BASE_URL,
|
| 589 |
+
model_name=req.model_name or DEFAULT_MODEL_NAME,
|
| 590 |
+
api_token=req.api_token or "",
|
| 591 |
+
)
|
| 592 |
+
reason = f"{rl_reason}; {reason}"
|
| 593 |
+
else:
|
| 594 |
+
action = _choose_action_heuristic(current_code, req.task_id)
|
| 595 |
+
reason = f"{rl_reason}; heuristic fallback"
|
| 596 |
+
source = "heuristic"
|
| 597 |
+
elif req.use_llm:
|
| 598 |
+
action, reason, source = _choose_action_llm(
|
| 599 |
+
code=current_code,
|
| 600 |
+
task_id=req.task_id,
|
| 601 |
+
step_index=step_idx,
|
| 602 |
+
max_steps=req.max_steps,
|
| 603 |
+
api_base_url=req.api_base_url or DEFAULT_API_BASE_URL,
|
| 604 |
+
model_name=req.model_name or DEFAULT_MODEL_NAME,
|
| 605 |
+
api_token=req.api_token or "",
|
| 606 |
+
)
|
| 607 |
+
else:
|
| 608 |
+
action = _choose_action_heuristic(current_code, req.task_id)
|
| 609 |
+
reason = "heuristic policy"
|
| 610 |
+
source = "heuristic"
|
| 611 |
+
|
| 612 |
+
_, reward, done, info = env.step(action)
|
| 613 |
+
state_now = env.state()
|
| 614 |
+
|
| 615 |
+
cumulative_reward += float(reward.raw)
|
| 616 |
+
steps.append(
|
| 617 |
+
OptimizationStep(
|
| 618 |
+
step=step_idx,
|
| 619 |
+
action=action,
|
| 620 |
+
action_name=info.get("action_name", "unknown"),
|
| 621 |
+
reason=reason,
|
| 622 |
+
source=source,
|
| 623 |
+
reward=float(reward.raw),
|
| 624 |
+
normalized_reward=float(reward.normalized),
|
| 625 |
+
changed=bool(info.get("changed", False)),
|
| 626 |
+
complexity=float(state_now.complexity),
|
| 627 |
+
)
|
| 628 |
+
)
|
| 629 |
+
|
| 630 |
+
if done:
|
| 631 |
+
break
|
| 632 |
+
|
| 633 |
+
final_code = str(env.state().current_code)
|
| 634 |
+
diff_lines = difflib.unified_diff(
|
| 635 |
+
code.splitlines(),
|
| 636 |
+
final_code.splitlines(),
|
| 637 |
+
fromfile="original.py",
|
| 638 |
+
tofile="optimized.py",
|
| 639 |
+
lineterm="",
|
| 640 |
+
)
|
| 641 |
+
diff_text = "\n".join(diff_lines)
|
| 642 |
+
|
| 643 |
+
task_score: Optional[float] = None
|
| 644 |
+
if req.task_id:
|
| 645 |
+
task = registry.get_task(req.task_id)
|
| 646 |
+
if task is None:
|
| 647 |
+
raise HTTPException(status_code=404, detail=f"Task '{req.task_id}' not found")
|
| 648 |
+
task_score = round(task.grade(final_code), 4)
|
| 649 |
+
|
| 650 |
+
return OptimizeResponse(
|
| 651 |
+
original_code=code,
|
| 652 |
+
optimized_code=final_code,
|
| 653 |
+
diff=diff_text,
|
| 654 |
+
steps=steps,
|
| 655 |
+
cumulative_reward=round(cumulative_reward, 4),
|
| 656 |
+
task_id=req.task_id,
|
| 657 |
+
task_score=task_score,
|
| 658 |
+
)
|
| 659 |
+
|
| 660 |
+
|
| 661 |
+
# ---------------------------------------------------------------------------
|
| 662 |
+
# Entry point
|
| 663 |
+
# ---------------------------------------------------------------------------
|
| 664 |
+
|
| 665 |
+
if __name__ == "__main__":
|
| 666 |
+
port = int(os.getenv("PORT", 7860))
|
| 667 |
+
uvicorn.run(app, host="0.0.0.0", port=port)
|
validate.py
ADDED
|
@@ -0,0 +1,281 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
ACRE pre-submission validator.
|
| 3 |
+
|
| 4 |
+
Checks the repository against the submission checklist and, when a server URL is
|
| 5 |
+
available, probes the HTTP API as well.
|
| 6 |
+
|
| 7 |
+
Run:
|
| 8 |
+
python validate.py --url http://localhost:7860
|
| 9 |
+
"""
|
| 10 |
+
from __future__ import annotations
|
| 11 |
+
|
| 12 |
+
import argparse
|
| 13 |
+
import ast
|
| 14 |
+
import re
|
| 15 |
+
import sys
|
| 16 |
+
from typing import Any, Tuple
|
| 17 |
+
|
| 18 |
+
try:
|
| 19 |
+
import requests
|
| 20 |
+
except ImportError:
|
| 21 |
+
print("[ERROR] requests is required. Run: pip install requests")
|
| 22 |
+
sys.exit(1)
|
| 23 |
+
|
| 24 |
+
PASS = "\033[92m[PASS]\033[0m"
|
| 25 |
+
FAIL = "\033[91m[FAIL]\033[0m"
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
def check(label: str, ok: bool, detail: str = "") -> bool:
|
| 29 |
+
status = PASS if ok else FAIL
|
| 30 |
+
message = f" {status} {label}"
|
| 31 |
+
if detail:
|
| 32 |
+
message += f" - {detail}"
|
| 33 |
+
print(message)
|
| 34 |
+
return ok
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
def get(url: str, path: str, timeout: int = 15) -> Tuple[bool, Any]:
|
| 38 |
+
try:
|
| 39 |
+
response = requests.get(f"{url}{path}", timeout=timeout)
|
| 40 |
+
response.raise_for_status()
|
| 41 |
+
return True, response.json()
|
| 42 |
+
except Exception as exc:
|
| 43 |
+
return False, str(exc)
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
def post(url: str, path: str, payload: dict, timeout: int = 15) -> Tuple[bool, Any]:
|
| 47 |
+
try:
|
| 48 |
+
response = requests.post(f"{url}{path}", json=payload, timeout=timeout)
|
| 49 |
+
response.raise_for_status()
|
| 50 |
+
return True, response.json()
|
| 51 |
+
except Exception as exc:
|
| 52 |
+
return False, str(exc)
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
def read_text(path: str) -> str:
|
| 56 |
+
with open(path, encoding="utf-8") as handle:
|
| 57 |
+
return handle.read()
|
| 58 |
+
|
| 59 |
+
|
| 60 |
+
def run_validation(base_url: str) -> int:
|
| 61 |
+
failures = 0
|
| 62 |
+
|
| 63 |
+
print("\n" + "=" * 60)
|
| 64 |
+
print(" ACRE Pre-Submission Validator")
|
| 65 |
+
print("=" * 60)
|
| 66 |
+
print(f" Target: {base_url}\n")
|
| 67 |
+
|
| 68 |
+
print("1. Static repository checks")
|
| 69 |
+
try:
|
| 70 |
+
interface_src = read_text("openenv_interface.py")
|
| 71 |
+
tree = ast.parse(interface_src)
|
| 72 |
+
classes = {node.name: node for node in tree.body if isinstance(node, ast.ClassDef)}
|
| 73 |
+
env_cls = classes.get("OpenEnvRefactorEnv")
|
| 74 |
+
failures += 0 if check("openenv_interface.py exists", True) else 1
|
| 75 |
+
failures += 0 if check("OpenEnvRefactorEnv is defined", env_cls is not None) else 1
|
| 76 |
+
if env_cls is not None:
|
| 77 |
+
methods = {node.name for node in env_cls.body if isinstance(node, ast.FunctionDef)}
|
| 78 |
+
for method_name in ["reset", "step", "state"]:
|
| 79 |
+
failures += 0 if check(
|
| 80 |
+
f"OpenEnvRefactorEnv implements {method_name}()",
|
| 81 |
+
method_name in methods,
|
| 82 |
+
) else 1
|
| 83 |
+
except FileNotFoundError:
|
| 84 |
+
failures += 1
|
| 85 |
+
check("openenv_interface.py exists", False, "file not found")
|
| 86 |
+
|
| 87 |
+
try:
|
| 88 |
+
models_src = read_text("models.py")
|
| 89 |
+
for name in ["ObservationModel", "ActionModel", "RewardModel"]:
|
| 90 |
+
failures += 0 if check(
|
| 91 |
+
f"{name} is defined in models.py",
|
| 92 |
+
f"class {name}" in models_src,
|
| 93 |
+
) else 1
|
| 94 |
+
except FileNotFoundError:
|
| 95 |
+
failures += 1
|
| 96 |
+
check("models.py exists", False, "file not found")
|
| 97 |
+
|
| 98 |
+
print("\n2. Health check (GET /)")
|
| 99 |
+
ok, data = get(base_url, "/")
|
| 100 |
+
failures += 0 if check("GET / returns HTTP 200", ok) else 1
|
| 101 |
+
if ok:
|
| 102 |
+
failures += 0 if check(
|
| 103 |
+
"Response has status field",
|
| 104 |
+
isinstance(data, dict) and "status" in data,
|
| 105 |
+
str(data),
|
| 106 |
+
) else 1
|
| 107 |
+
|
| 108 |
+
print("\n3. Tasks (GET /tasks)")
|
| 109 |
+
ok, data = get(base_url, "/tasks")
|
| 110 |
+
failures += 0 if check("GET /tasks returns 200", ok) else 1
|
| 111 |
+
if ok:
|
| 112 |
+
tasks = data.get("tasks", []) if isinstance(data, dict) else []
|
| 113 |
+
failures += 0 if check("At least 3 tasks defined", len(tasks) >= 3, f"found {len(tasks)}") else 1
|
| 114 |
+
difficulties = [t.get("difficulty", "") for t in tasks]
|
| 115 |
+
for diff in ["easy", "medium", "hard"]:
|
| 116 |
+
failures += 0 if check(f"Task with difficulty '{diff}' exists", diff in difficulties) else 1
|
| 117 |
+
for task in tasks:
|
| 118 |
+
failures += 0 if check(
|
| 119 |
+
f"Task '{task.get('id')}' has initial_code",
|
| 120 |
+
bool(task.get("initial_code")),
|
| 121 |
+
) else 1
|
| 122 |
+
|
| 123 |
+
print("\n4. Reset (POST /reset)")
|
| 124 |
+
ok, data = post(base_url, "/reset", {})
|
| 125 |
+
failures += 0 if check("POST /reset returns 200", ok) else 1
|
| 126 |
+
if ok:
|
| 127 |
+
observation = data.get("observation", {})
|
| 128 |
+
failures += 0 if check("Response has observation field", isinstance(observation, dict)) else 1
|
| 129 |
+
failures += 0 if check(
|
| 130 |
+
"Observation is typed with 4 fields",
|
| 131 |
+
{"code_length", "complexity_score", "runtime_s", "error_flag"}.issubset(observation),
|
| 132 |
+
str(observation),
|
| 133 |
+
) else 1
|
| 134 |
+
|
| 135 |
+
ok, _ = post(base_url, "/reset", {"task_id": "rename_variables"})
|
| 136 |
+
failures += 0 if check("POST /reset with task_id works", ok) else 1
|
| 137 |
+
|
| 138 |
+
print("\n5. State (GET /state)")
|
| 139 |
+
ok, data = get(base_url, "/state")
|
| 140 |
+
failures += 0 if check("GET /state returns 200", ok) else 1
|
| 141 |
+
if ok:
|
| 142 |
+
required_keys = [
|
| 143 |
+
"current_code",
|
| 144 |
+
"episode_steps",
|
| 145 |
+
"max_steps",
|
| 146 |
+
"complexity",
|
| 147 |
+
"observation",
|
| 148 |
+
"observation_vector",
|
| 149 |
+
"action_meanings",
|
| 150 |
+
]
|
| 151 |
+
for key in required_keys:
|
| 152 |
+
failures += 0 if check(f"State has '{key}' field", key in data) else 1
|
| 153 |
+
|
| 154 |
+
print("\n6. Step (POST /step)")
|
| 155 |
+
post(base_url, "/reset", {"task_id": "rename_variables"})
|
| 156 |
+
for action in range(5):
|
| 157 |
+
ok, data = post(base_url, "/step", {"action": action})
|
| 158 |
+
failures += 0 if check(
|
| 159 |
+
f"Action {action} executes without error",
|
| 160 |
+
ok and isinstance(data, dict) and "reward" in data and "done" in data,
|
| 161 |
+
) else 1
|
| 162 |
+
if ok:
|
| 163 |
+
reward_payload = data.get("reward", {})
|
| 164 |
+
norm = reward_payload.get("normalized", -1)
|
| 165 |
+
failures += 0 if check(
|
| 166 |
+
f"Action {action} returns typed reward payload",
|
| 167 |
+
{"raw", "normalized", "components"}.issubset(reward_payload),
|
| 168 |
+
str(reward_payload),
|
| 169 |
+
) else 1
|
| 170 |
+
failures += 0 if check(
|
| 171 |
+
f"Action {action} normalized_reward in [0,1]",
|
| 172 |
+
isinstance(norm, (int, float)) and 0.0 <= float(norm) <= 1.0,
|
| 173 |
+
f"got {norm}",
|
| 174 |
+
) else 1
|
| 175 |
+
if data.get("done"):
|
| 176 |
+
break
|
| 177 |
+
|
| 178 |
+
ok, data = post(base_url, "/step", {"action": 99})
|
| 179 |
+
check("Invalid action returns error (not crash)", not ok or "detail" in str(data), "(expected 4xx)")
|
| 180 |
+
|
| 181 |
+
print("\n7. Task graders (POST /tasks/{id}/grade)")
|
| 182 |
+
for task_id in ["rename_variables", "remove_dead_code", "full_refactor"]:
|
| 183 |
+
ok, data = post(base_url, f"/tasks/{task_id}/grade", {"code": "def f(): pass"})
|
| 184 |
+
failures += 0 if check(f"Grade endpoint for '{task_id}' works", ok) else 1
|
| 185 |
+
if ok:
|
| 186 |
+
score = data.get("score", -1)
|
| 187 |
+
failures += 0 if check(
|
| 188 |
+
f"Score for '{task_id}' in [0.0, 1.0]",
|
| 189 |
+
isinstance(score, (int, float)) and 0.0 <= float(score) <= 1.0,
|
| 190 |
+
f"got {score}",
|
| 191 |
+
) else 1
|
| 192 |
+
|
| 193 |
+
print("\n8. openenv.yaml")
|
| 194 |
+
try:
|
| 195 |
+
openenv_yaml = read_text("openenv.yaml")
|
| 196 |
+
failures += 0 if check("openenv.yaml exists", True) else 1
|
| 197 |
+
for field in ["tasks:", "action_space:", "observation_space:", "reward:", "entrypoint:", "validation:"]:
|
| 198 |
+
failures += 0 if check(f"openenv.yaml has '{field}' section", field in openenv_yaml) else 1
|
| 199 |
+
except FileNotFoundError:
|
| 200 |
+
failures += 1
|
| 201 |
+
check("openenv.yaml exists", False, "file not found")
|
| 202 |
+
|
| 203 |
+
print("\n9. inference.py")
|
| 204 |
+
try:
|
| 205 |
+
inference_src = read_text("inference.py")
|
| 206 |
+
failures += 0 if check("inference.py exists", True) else 1
|
| 207 |
+
for marker in ['"event": "START"', '"event": "STEP"', '"event": "END"']:
|
| 208 |
+
failures += 0 if check(f"inference.py emits {marker}", marker in inference_src) else 1
|
| 209 |
+
failures += 0 if check(
|
| 210 |
+
"Uses OpenAI client",
|
| 211 |
+
"from openai import OpenAI" in inference_src,
|
| 212 |
+
) else 1
|
| 213 |
+
for var in ["API_BASE_URL", "MODEL_NAME", "HF_TOKEN", "ENV_URL", "LOCAL_IMAGE_NAME"]:
|
| 214 |
+
failures += 0 if check(f"inference.py reads {var} from env", var in inference_src) else 1
|
| 215 |
+
failures += 0 if check(
|
| 216 |
+
"API_BASE_URL has a default",
|
| 217 |
+
'os.getenv("API_BASE_URL", "https://api.openai.com/v1")' in inference_src,
|
| 218 |
+
) else 1
|
| 219 |
+
failures += 0 if check(
|
| 220 |
+
"MODEL_NAME has a default",
|
| 221 |
+
'os.getenv("MODEL_NAME", "gpt-4o-mini")' in inference_src,
|
| 222 |
+
) else 1
|
| 223 |
+
failures += 0 if check(
|
| 224 |
+
"HF_TOKEN has no default",
|
| 225 |
+
re.search(r'HF_TOKEN\s*:\s*.*os\.getenv\("HF_TOKEN"\)', inference_src) is not None,
|
| 226 |
+
) else 1
|
| 227 |
+
except FileNotFoundError:
|
| 228 |
+
failures += 1
|
| 229 |
+
check("inference.py exists", False, "file not found")
|
| 230 |
+
|
| 231 |
+
print("\n10. Dockerfile")
|
| 232 |
+
try:
|
| 233 |
+
dockerfile = read_text("Dockerfile")
|
| 234 |
+
failures += 0 if check("Dockerfile exists", True) else 1
|
| 235 |
+
failures += 0 if check("Exposes port 7860", "7860" in dockerfile) else 1
|
| 236 |
+
failures += 0 if check("Has CMD/ENTRYPOINT", "CMD" in dockerfile or "ENTRYPOINT" in dockerfile) else 1
|
| 237 |
+
failures += 0 if check("Does not set a default HF_TOKEN", "ENV HF_TOKEN" not in dockerfile) else 1
|
| 238 |
+
except FileNotFoundError:
|
| 239 |
+
failures += 1
|
| 240 |
+
check("Dockerfile exists", False, "file not found")
|
| 241 |
+
|
| 242 |
+
print("\n11. README / Hugging Face metadata")
|
| 243 |
+
try:
|
| 244 |
+
readme = read_text("README.md")
|
| 245 |
+
failures += 0 if check("README has docker SDK front matter", "sdk: docker" in readme) else 1
|
| 246 |
+
failures += 0 if check("README includes openenv tag", "openenv" in readme) else 1
|
| 247 |
+
for section in [
|
| 248 |
+
"Environment Overview and Motivation",
|
| 249 |
+
"Definitions of Action and Observation Spaces",
|
| 250 |
+
"Task Descriptions with Expected Difficulty Levels",
|
| 251 |
+
"Setup and Usage Instructions",
|
| 252 |
+
"Baseline Performance Scores",
|
| 253 |
+
]:
|
| 254 |
+
failures += 0 if check(f"README includes '{section}'", section in readme) else 1
|
| 255 |
+
except FileNotFoundError:
|
| 256 |
+
failures += 1
|
| 257 |
+
check("README.md exists", False, "file not found")
|
| 258 |
+
|
| 259 |
+
print("\n" + "=" * 60)
|
| 260 |
+
if failures == 0:
|
| 261 |
+
print(f" {PASS} All checks passed. Repository is submission-ready.")
|
| 262 |
+
else:
|
| 263 |
+
print(f" {FAIL} {failures} check(s) failed. Fix before submitting.")
|
| 264 |
+
print("=" * 60 + "\n")
|
| 265 |
+
|
| 266 |
+
return failures
|
| 267 |
+
|
| 268 |
+
|
| 269 |
+
def main() -> None:
|
| 270 |
+
parser = argparse.ArgumentParser(description="ACRE pre-submission validator")
|
| 271 |
+
parser.add_argument(
|
| 272 |
+
"--url",
|
| 273 |
+
default="http://localhost:7860",
|
| 274 |
+
help="Base URL of the running ACRE server",
|
| 275 |
+
)
|
| 276 |
+
args = parser.parse_args()
|
| 277 |
+
sys.exit(run_validation(args.url))
|
| 278 |
+
|
| 279 |
+
|
| 280 |
+
if __name__ == "__main__":
|
| 281 |
+
main()
|