File size: 6,033 Bytes
bc5030f | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 | from __future__ import annotations
import os
import random
import sys
from typing import Any, Optional, Tuple
from acre.datasets.code_samples import CodeSample, CodeSampleDataset
from acre.env.refactor_env import RefactorEnv
def _load_model(path: str):
    """Try to load a Stable-Baselines3 PPO checkpoint stored at *path*.

    Best-effort by design: ``None`` is returned when the path does not exist,
    when ``stable_baselines3`` cannot be imported, or when ``PPO.load`` raises
    for any reason. A ``None`` result means "no trained policy available".
    """
    model = None
    if os.path.exists(path):
        try:
            from stable_baselines3 import PPO
        except Exception:
            PPO = None
        if PPO is not None:
            try:
                model = PPO.load(path)
            except Exception:
                model = None
    return model
def _messy_sample_code() -> str:
    """Return a deliberately messy, but syntactically valid, Python snippet for the demo.

    The snippet defines ``add(a, b)`` full of redundant constructs:
    - a counting loop that merely reproduces ``a``;
    - a no-op ``if True`` branch;
    - a dead ``if False`` branch;
    - a trivial nested helper.

    Nested statements carry real indentation so the snippet actually parses.
    With a uniform one-space indent, ``compile()`` raises ``IndentationError``.
    """
    return (
        "def add(a,b):\n"
        "    x=0\n"
        "    for i in range(a):\n"
        "        x=x+1\n"
        "    if True:\n"
        "        x = x\n"
        "    if False:\n"
        "        y=123\n"
        "    else:\n"
        "        y=0\n"
        "    def f(p,q):\n"
        "        return p+q\n"
        "    r = f(x,y)\n"
        "    return r\n"
    )
def _format_code_block(code: str) -> str:
    """Indent each line of *code* by one space for display.

    Trailing whitespace at the end of *code* (including trailing blank lines)
    is removed first. The result always ends with exactly one newline.
    """
    indented = [" " + line for line in code.rstrip().splitlines()]
    return "\n".join(indented) + "\n"
def _safe_print(text: str) -> None:
    """
    Print text safely on consoles whose encoding cannot represent every character.

    When *text* cannot be encoded with the stdout encoding:
    - the known status emojis are swapped for ASCII markers
      ("✅" -> "[OK]", "⚠️"/"⚠" -> "[WARN]");
    - any other character the encoding still cannot represent is replaced
      by the codec's replacement character (typically "?").

    This avoids raising ``UnicodeEncodeError`` from ``print``.
    """
    # getattr: a replaced/None stdout may not expose ``encoding``.
    encoding = getattr(sys.stdout, "encoding", None) or "utf-8"
    try:
        text.encode(encoding)
    except (UnicodeEncodeError, LookupError):
        # Fall back to ASCII-friendly markers if emojis can't be encoded.
        text = text.replace("✅", "[OK]").replace("⚠️", "[WARN]").replace("⚠", "[WARN]")
        try:
            # Replace whatever else the console still cannot represent.
            text = text.encode(encoding, errors="replace").decode(encoding)
        except LookupError:
            # The stream reported a codec name Python doesn't know: degrade to ASCII.
            text = text.encode("ascii", errors="replace").decode("ascii")
    print(text, flush=True)
def _compute_runtime(executor: Any, code: str) -> float:
    """Return the executor-reported runtime of *code* in seconds, or ``0.0``.

    Runs ``executor.run(code, filename="demo.py")``. The value of
    ``metrics["runtime_s"]`` is used only when the result has ``exit_code == 0``
    and a dict ``metrics``. A falsy or missing value, a failed run, or any
    exception yields ``0.0``.
    """
    try:
        result = executor.run(code, filename="demo.py")
        succeeded = getattr(result, "exit_code", 1) == 0
        metrics = getattr(result, "metrics", None)
        if succeeded and isinstance(metrics, dict):
            runtime = metrics.get("runtime_s", 0.0) or 0.0
            return float(runtime)
    except Exception:
        # Best-effort metric: never let a broken executor abort the demo.
        pass
    return 0.0
def _action_to_int(action: Any) -> int:
    """Convert a predicted action (scalar, 0-d/1-element array, or sequence) to ``int``."""
    if hasattr(action, "ravel"):
        # numpy/torch arrays, including 0-d ones. A 0-d ndarray defines __len__
        # but cannot be indexed with [0]. SB3 squeezes the batch axis for a
        # single observation, so a Discrete action can arrive in that form.
        action = action.ravel()[0]
    elif hasattr(action, "__len__"):
        action = action[0]
    return int(action)


def _choose_action(model: Any, obs, env: RefactorEnv, rng: random.Random) -> Tuple[int, str]:
    """Choose an action for *obs*, preferring the trained model and falling back to random.

    Args:
        model: Object exposing ``predict(obs, deterministic=True)`` (e.g. an
            SB3 PPO model), or ``None``.
        obs: Current observation, forwarded unchanged to ``model.predict``.
        env: Environment whose ``action_space.n`` bounds random actions
            (5 is used when that attribute is missing).
        rng: Random source for fallback actions.

    Returns:
        ``(action, source)``, where ``source`` is one of:
        - ``"ppo"`` when the model produced the action;
        - ``"random"`` when the action was drawn uniformly from
          ``[0, n_actions - 1]``, because no model is loaded or the
          prediction/conversion failed.
    """
    n_actions = int(getattr(getattr(env, "action_space", None), "n", 5))
    if model is not None:
        try:
            action, _state = model.predict(obs, deterministic=True)
            return _action_to_int(action), "ppo"
        except Exception:
            # Best-effort demo: any prediction failure falls back to a random action.
            pass
    return int(rng.randint(0, n_actions - 1)), "random"
def _measure_code(env: RefactorEnv, code: str) -> Tuple[float, float]:
    """Return ``(complexity, runtime_s)`` for *code* using the env's complexity helper and executor."""
    complexity = float(env._compute_complexity(code))
    runtime = _compute_runtime(env.executor, code)
    return complexity, runtime


def _print_step(
    step_idx: int,
    max_steps: int,
    policy: str,
    action: int,
    reward: float,
    step_info: dict,
    transform_meta: Any,
    code: str,
) -> None:
    """Print one step's report: chosen action, transform metadata, reward and resulting code."""
    print("-" * 72)
    print(f"STEP {step_idx}/{max_steps}")
    print(f"policy={policy} action={action} ({step_info.get('action_name', 'unknown')})")
    print(f"transform={transform_meta}")
    print(f"reward={float(reward):.2f} components={step_info.get('reward_components')}")
    print("\nUPDATED CODE:\n")
    print(_format_code_block(code))


def run_demo(*, model_path: str = "acre_agent.zip", seed: int = 0, max_steps: int = 5) -> None:
    """Run one refactoring episode on the built-in messy sample and print a report to stdout.

    Args:
        model_path: Path to a Stable-Baselines3 PPO checkpoint. When it is
            missing or cannot be loaded, actions are sampled at random.
        seed: Seed for both the environment and the fallback action RNG.
        max_steps: Maximum number of environment steps. The episode also stops
            early when ``env.step`` reports ``terminated`` or ``truncated``.
    """
    rng = random.Random(seed)
    # A single-sample dataset makes ``reset()`` load the demo sample deterministically.
    dataset = CodeSampleDataset(
        [
            CodeSample(
                id="demo_sample",
                language="python",
                code=_messy_sample_code(),
            )
        ]
    )
    env = RefactorEnv(dataset=dataset, seed=seed)
    try:
        model = _load_model(model_path)
        if model is not None:
            model_status = "loaded"
        elif os.path.exists(model_path):
            # The file exists but loading failed (e.g. stable_baselines3 missing
            # or an incompatible checkpoint), so don't report it as "not found".
            model_status = "found but could not be loaded (using random actions)"
        else:
            model_status = "not found (using random actions)"

        # Reset and capture the original code/metrics (read from env internals).
        obs, info = env.reset()
        original_code = getattr(env, "_code", "")
        original_complexity, original_runtime = _measure_code(env, original_code)

        print("=" * 72)
        print(f"ACRE: Autonomous RL Code Refactoring Agent ({max_steps}-step episode)")
        print(f"Model: {model_path} -> {model_status}")
        print(f"Sample: {info.get('sample_id')} ({info.get('language')})")
        print("=" * 72)
        print("\nORIGINAL CODE:\n")
        print(_format_code_block(original_code))

        total_reward = 0.0
        successful_transformations = 0
        steps_taken = 0
        for step_idx in range(1, max_steps + 1):
            action, policy = _choose_action(model, obs, env, rng)
            obs, reward, terminated, truncated, step_info = env.step(action)
            total_reward += float(reward)
            steps_taken = step_idx

            transform_meta = step_info.get("transform", {})
            if isinstance(transform_meta, dict) and transform_meta.get("success", False):
                successful_transformations += 1

            _print_step(
                step_idx,
                max_steps,
                policy,
                action,
                reward,
                step_info,
                transform_meta,
                getattr(env, "_code", ""),
            )
            if terminated or truncated:
                break

        final_code = getattr(env, "_code", "")
        final_complexity, final_runtime = _measure_code(env, final_code)

        print("=" * 72)
        print("FINAL SUMMARY")
        print("=" * 72)
        print(f"total_reward: {total_reward:.2f}")
        print(f"complexity: {original_complexity:.0f} -> {final_complexity:.0f}")
        print(f"runtime_s: {original_runtime:.4f} -> {final_runtime:.4f}")
        # Relative drop in complexity; max(..., 1.0) guards against dividing by zero.
        complexity_improvement = (
            (original_complexity - final_complexity) / max(original_complexity, 1.0)
        ) * 100.0
        print(f"complexity improvement: {complexity_improvement:.2f}%")
        print("\nCHANGES APPLIED:")
        print(f"- Total steps: {steps_taken}")
        print(f"- Successful transformations: {successful_transformations}")
        # The verdict is based on the accumulated reward, not on the complexity delta.
        if total_reward > 0:
            _safe_print("\n✅ Code improved successfully")
        else:
            _safe_print("\n⚠️ No significant improvement")
        print("\nFINAL CODE:\n")
        print(_format_code_block(final_code))
    finally:
        # Always release the environment, even if reset/step/printing raised.
        env.close()
if __name__ == "__main__":
    # Command-line entry point. Defaults match run_demo(), so running the
    # script without arguments behaves exactly as before.
    import argparse

    _parser = argparse.ArgumentParser(description="Run the ACRE refactoring demo episode.")
    _parser.add_argument(
        "--model-path",
        default="acre_agent.zip",
        help="Path to a Stable-Baselines3 PPO checkpoint (default: acre_agent.zip).",
    )
    _parser.add_argument(
        "--seed",
        type=int,
        default=0,
        help="Seed for the environment and the fallback random policy (default: 0).",
    )
    _args = _parser.parse_args()
    run_demo(model_path=_args.model_path, seed=_args.seed)
|