# PRANAV05092003's picture
# Fixed structure (moved files to root)
# bc5030f
from __future__ import annotations
import os
import random
import sys
from typing import Any, Optional, Tuple
from acre.datasets.code_samples import CodeSample, CodeSampleDataset
from acre.env.refactor_env import RefactorEnv
def _load_model(path: str):
"""Load a Stable-Baselines3 PPO model if available; otherwise return None."""
if not os.path.exists(path):
return None
try:
from stable_baselines3 import PPO
except Exception:
return None
try:
return PPO.load(path)
except Exception:
return None
def _messy_sample_code() -> str:
# Intentionally "messy" but valid Python for demo purposes.
return (
"def add(a,b):\n"
" x=0\n"
" for i in range(a):\n"
" x=x+1\n"
" if True:\n"
" x = x\n"
" if False:\n"
" y=123\n"
" else:\n"
" y=0\n"
" def f(p,q):\n"
" return p+q\n"
" r = f(x,y)\n"
" return r\n"
)
def _format_code_block(code: str) -> str:
return "\n".join(f" {line}" for line in code.rstrip().splitlines()) + "\n"
def _safe_print(text: str) -> None:
"""
Print text safely across Windows consoles (some default encodings can't print emojis).
"""
encoding = sys.stdout.encoding or "utf-8"
try:
text.encode(encoding)
print(text, flush=True)
except Exception:
# Fall back to ASCII-friendly markers if emojis can't be encoded.
safe = text.replace("✅", "[OK]").replace("⚠️", "[WARN]").replace("⚠", "[WARN]")
print(safe, flush=True)
def _compute_runtime(executor: Any, code: str) -> float:
"""Best-effort runtime metric using the current executor contract."""
try:
res = executor.run(code, filename="demo.py")
if getattr(res, "exit_code", 1) == 0 and isinstance(getattr(res, "metrics", None), dict):
return float(res.metrics.get("runtime_s", 0.0) or 0.0)
except Exception:
pass
return 0.0
def _choose_action(model: Any, obs, env: RefactorEnv, rng: random.Random) -> Tuple[int, str]:
"""Choose an action from the model, falling back to random."""
n_actions = int(getattr(getattr(env, "action_space", None), "n", 5))
if model is None:
a = int(rng.randint(0, n_actions - 1))
return a, "random"
try:
action, _state = model.predict(obs, deterministic=True)
# SB3 may return scalar or 1-element array.
if hasattr(action, "__len__"):
a = int(action[0])
else:
a = int(action)
return a, "ppo"
except Exception:
a = int(rng.randint(0, n_actions - 1))
return a, "random"
def run_demo(*, model_path: str = "acre_agent.zip", seed: int = 0, max_steps: int = 5) -> None:
    """Run one short refactoring episode on the built-in sample and print a report.

    The report shows the original code, each step (policy, action, transform
    metadata, reward, updated code) and a final summary comparing complexity
    and runtime before and after.

    Args:
        model_path: Path handed to ``_load_model``. When no model can be
            loaded, actions are chosen at random.
        seed: Seeds both the fallback action RNG and the environment.
        max_steps: Upper bound on environment steps (default 5). The episode
            also stops early when the environment reports ``terminated`` or
            ``truncated``.
    """
    rng = random.Random(seed)

    # A one-sample dataset so that `reset()` loads the demo code deterministically.
    dataset = CodeSampleDataset(
        [
            CodeSample(
                id="demo_sample",
                language="python",
                code=_messy_sample_code(),
            )
        ]
    )
    env = RefactorEnv(dataset=dataset, seed=seed)

    # Everything after construction runs under try/finally so the environment
    # is closed even if a step or a metric computation raises.
    try:
        model = _load_model(model_path)
        model_status = "loaded" if model is not None else "not found (using random actions)"

        # Reset and capture the original code/metrics.
        # NOTE: `_code` and `_compute_complexity` are private members of the env.
        obs, info = env.reset()
        original_code = getattr(env, "_code", "")
        original_complexity = float(env._compute_complexity(original_code))
        original_runtime = _compute_runtime(env.executor, original_code)

        print("=" * 72)
        print(f"ACRE: Autonomous RL Code Refactoring Agent ({max_steps}-step episode)")
        print(f"Model: {model_path} -> {model_status}")
        print(f"Sample: {info.get('sample_id')} ({info.get('language')})")
        print("=" * 72)
        print("\nORIGINAL CODE:\n")
        print(_format_code_block(original_code))

        total_reward = 0.0
        successful_transformations = 0
        steps_taken = 0

        for step_idx in range(1, max_steps + 1):
            action, policy = _choose_action(model, obs, env, rng)
            obs, reward, terminated, truncated, step_info = env.step(action)
            total_reward += float(reward)
            steps_taken = step_idx

            action_name = step_info.get("action_name", "unknown")
            transform_meta = step_info.get("transform", {})
            # Only count steps whose transform metadata explicitly reports success.
            if isinstance(transform_meta, dict) and bool(transform_meta.get("success", False)):
                successful_transformations += 1

            transformed_code = getattr(env, "_code", "")
            print("-" * 72)
            print(f"STEP {step_idx}/{max_steps}")
            print(f"policy={policy} action={action} ({action_name})")
            print(f"transform={transform_meta}")
            print(f"reward={float(reward):.2f} components={step_info.get('reward_components')}")
            print("\nUPDATED CODE:\n")
            print(_format_code_block(transformed_code))

            if terminated or truncated:
                break

        final_code = getattr(env, "_code", "")
        final_complexity = float(env._compute_complexity(final_code))
        final_runtime = _compute_runtime(env.executor, final_code)

        print("=" * 72)
        print("FINAL SUMMARY")
        print("=" * 72)
        print(f"total_reward: {total_reward:.2f}")
        print(f"complexity: {original_complexity:.0f} -> {final_complexity:.0f}")
        print(f"runtime_s: {original_runtime:.4f} -> {final_runtime:.4f}")
        # Guard the denominator so a zero-complexity original cannot divide by zero.
        complexity_improvement = (
            (original_complexity - final_complexity) / max(original_complexity, 1.0)
        ) * 100.0
        print(f"complexity improvement: {complexity_improvement:.2f}%")
        print("\nCHANGES APPLIED:")
        print(f"- Total steps: {steps_taken}")
        print(f"- Successful transformations: {successful_transformations}")

        # The emoji lines go through _safe_print because some consoles cannot encode them.
        if total_reward > 0:
            _safe_print("\n✅ Code improved successfully")
        else:
            _safe_print("\n⚠️ No significant improvement")

        print("\nFINAL CODE:\n")
        print(_format_code_block(final_code))
    finally:
        env.close()
# Script entry point: run the demo with its default settings when this file
# is executed directly rather than imported.
if __name__ == "__main__":
    run_demo()