"""CodeArena RL environment server.

Exposes a small FastAPI app implementing a reset/step/state loop: an agent
receives buggy Python code, proposes fixes via ``/step`` and is rewarded by
``calculate_codearena_reward``. When ``/reset`` is called with
``task_id="auto"``, task difficulty follows ``AgentPerformanceTracker``.

NOTE(review): ``run_tests`` and the custom-mode path of ``/step`` execute the
submitted code as a plain subprocess of this server, guarded only by a
timeout -- there is no sandboxing in this module (what
``server.executor.run_code_with_tests`` does is not visible here; confirm).
Episode state is module-global, so the server effectively serves one
episode / one client at a time.
"""

import json
import os
import random
import secrets
import subprocess
import sys
import tempfile
import time
from typing import List, Optional

from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel

from server.executor import run_code_with_tests
from server.grader import calculate_codearena_reward

app = FastAPI(title="CodeArena RL Environment")
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
)

# -- State ------------------------------------------------
# All episode state is module-global and shared by every request.
current_task = {}         # task dict of the active episode ({} until /reset)
step_count = 0            # /step calls made in the current episode
previous_attempts = []    # proposed fixes submitted so far this episode
episode_rewards = []      # per-step rewards of the current episode
episode_recorded = False  # True once this episode was reported to the tracker

TASKS_DIR = os.path.join(os.path.dirname(__file__), "..", "tasks")

# Task folders that /reset accepts directly as ``task_id``.
TASK_CATEGORIES = ("easy", "medium", "hard", "type_errors", "security_bugs")
MAX_STEPS_PER_EPISODE = 5        # an episode ends after this many /step calls
SUCCESS_REWARD_THRESHOLD = 0.95  # ...or as soon as a step reward exceeds this
TEST_TIMEOUT_SECONDS = 5         # wall-clock limit per subprocess run
ERROR_LOG_LIMIT = 300            # max characters of stderr echoed to the agent


# -- Models -----------------------------------------------
class ResetRequest(BaseModel):
    task_id: str = "easy"
    buggy_code: Optional[str] = None


class StepRequest(BaseModel):
    proposed_fix: str


class AgentPerformanceTracker:
    """Track recent episode rewards and adapt the curriculum difficulty.

    Keeps a rolling window of the last 10 per-episode mean rewards and moves
    between "easy", "medium" and "hard" based on the window average.
    """

    def __init__(self):
        self.episode_rewards = []         # rolling window, newest last, <= 10 items
        self.current_difficulty = "easy"
        self.steps_at_difficulty = 0      # episodes recorded since last difficulty change

    def record_episode(self, avg_reward: float):
        """Record one finished episode's mean reward, then maybe change difficulty."""
        self.episode_rewards.append(round(avg_reward, 4))
        if len(self.episode_rewards) > 10:
            self.episode_rewards.pop(0)
        self.steps_at_difficulty += 1
        self._maybe_escalate()

    def _maybe_escalate(self):
        """Promote or demote difficulty once >= 3 episodes were seen at the current level.

        Thresholds on the window average: easy -> medium above 0.80,
        medium -> hard above 0.75, hard -> medium and medium -> easy below 0.35.
        NOTE(review): the reward window is not cleared on a change, so the
        next decision still averages in episodes from the previous level.
        """
        if self.steps_at_difficulty < 3:
            return
        if not self.episode_rewards:
            return
        avg = sum(self.episode_rewards) / len(self.episode_rewards)
        old = self.current_difficulty
        if avg > 0.80 and self.current_difficulty == "easy":
            self.current_difficulty = "medium"
            self.steps_at_difficulty = 0
        elif avg > 0.75 and self.current_difficulty == "medium":
            self.current_difficulty = "hard"
            self.steps_at_difficulty = 0
        elif avg < 0.35 and self.current_difficulty == "hard":
            self.current_difficulty = "medium"
            self.steps_at_difficulty = 0
        elif avg < 0.35 and self.current_difficulty == "medium":
            self.current_difficulty = "easy"
            self.steps_at_difficulty = 0
        if self.current_difficulty != old:
            print(
                f"[CURRICULUM] {old} -> {self.current_difficulty} "
                f"after avg={avg:.3f}"
            )

    def get_difficulty(self) -> str:
        """Return the difficulty folder that ``task_id="auto"`` resets draw from."""
        return self.current_difficulty

    def get_stats(self) -> dict:
        """Return a JSON-serialisable summary of the curriculum state."""
        avg = (
            sum(self.episode_rewards) / len(self.episode_rewards)
            if self.episode_rewards else 0.0
        )
        return {
            "current_difficulty": self.current_difficulty,
            "recent_avg_reward": round(avg, 3),
            "episodes_tracked": len(self.episode_rewards),
            "steps_at_current_difficulty": self.steps_at_difficulty,
        }


tracker = AgentPerformanceTracker()


# -- Helpers ----------------------------------------------
def load_random_task(difficulty: str):
    """Load a randomly chosen ``.json`` task from ``TASKS_DIR/<difficulty>``.

    Raises:
        ValueError: if the folder contains no ``.json`` files.
        FileNotFoundError: if the folder does not exist (from ``os.listdir``).
    """
    folder = os.path.join(TASKS_DIR, difficulty)
    # Sorted so the choice is reproducible when ``random`` is seeded.
    files = sorted(f for f in os.listdir(folder) if f.endswith(".json"))
    if not files:
        raise ValueError(f"No tasks found in {folder}")
    path = os.path.join(folder, random.choice(files))
    with open(path, encoding="utf-8") as f:
        return json.load(f)


# Script executed once per test by ``run_tests``. It runs the submitted code,
# calls the first function that code itself defines (functions it merely
# imported carry a different ``__module__``) with ``*input`` and prints a
# result line prefixed by a per-call random marker. Non-list inputs are not
# called and leave ``result`` as None.
_TEST_RUNNER_TEMPLATE = """\
import sys
sys.path.insert(0, '')
with open({code_path!r}, encoding='utf-8') as _src:
    _code = _src.read()
exec(_code)
import types
inp = {inp!r}
result = None
if isinstance(inp, list):
    funcs = [v for v in list(globals().values())
             if isinstance(v, types.FunctionType) and v.__module__ == __name__]
    if funcs:
        fn = funcs[0]
        result = fn(*inp)
expected = {expected!r}
if result == expected:
    print({marker!r}, 'PASS')
else:
    print({marker!r}, 'FAIL got', result, 'expected', expected)
"""


def run_tests(code: str, tests: list):
    """Run ``code`` against JSON-style ``tests``, one subprocess per test.

    Each test is a dict with ``"input"`` (list of positional arguments) and
    ``"expected"``.

    Returns:
        ``(compile_ok, passed, total, error_log, elapsed_seconds)``.
        ``compile_ok`` becomes False when a test's runner exits non-zero
        without passing. ``error_log`` holds the most recent truncated stderr,
        or ``"Timeout"``.
    """
    passed = 0
    total = len(tests)
    compile_ok = True
    error_log = ""
    start_time = time.time()
    # A random marker means stray "PASS" text printed by the submission is not
    # counted as a pass. Best-effort only: the untrusted code runs in the same
    # process as the runner and could still inspect it.
    marker = secrets.token_hex(16)
    pass_line = f"{marker} PASS"

    with tempfile.NamedTemporaryFile(
        mode="w", suffix=".py", delete=False, encoding="utf-8"
    ) as tmp:
        tmp.write(code)
        code_path = tmp.name

    try:
        for test in tests:
            script = _TEST_RUNNER_TEMPLATE.format(
                code_path=code_path,
                inp=test["input"],
                expected=test["expected"],
                marker=marker,
            )
            with tempfile.NamedTemporaryFile(
                mode="w", suffix=".py", delete=False, encoding="utf-8"
            ) as runner:
                runner.write(script)
                runner_path = runner.name
            try:
                proc = subprocess.run(
                    [sys.executable, runner_path],
                    capture_output=True, text=True, timeout=TEST_TIMEOUT_SECONDS,
                )
                if pass_line in proc.stdout.splitlines():
                    passed += 1
                elif proc.returncode != 0:
                    compile_ok = False
                    error_log = proc.stderr[:ERROR_LOG_LIMIT]
            except subprocess.TimeoutExpired:
                error_log = "Timeout"
            finally:
                os.unlink(runner_path)
    finally:
        os.unlink(code_path)

    return compile_ok, passed, total, error_log, time.time() - start_time


def evaluate_fix(code: str, task: dict):
    """Grade ``code`` against ``task``.

    Tasks carrying a ``test_code`` suite go through ``run_code_with_tests``.
    The reported total is clamped to at least 1. All other tasks go through
    ``run_tests`` with the task's ``tests`` list.

    Returns:
        ``(compile_ok, passed, total, error_log, elapsed_seconds)``.
    """
    if "test_code" in task:
        result = run_code_with_tests(
            code=code,
            test_code=task.get("test_code", ""),
            timeout=max(float(task.get("optimal_time_seconds", 0.05)) * 10, 2.0),
        )
        return (
            result.compile_success,
            result.test_passed,
            max(result.test_total, 1),
            result.runtime_errors,
            result.execution_time_seconds,
        )
    return run_tests(code, task.get("tests", []))


def _run_custom_fix(code: str):
    """Run user-supplied custom-mode code once; exit status 0 counts as a pass.

    Returns the same 5-tuple shape as ``evaluate_fix``, with one pseudo-test.
    The temp file is removed even on timeout or error.
    """
    compile_ok, passed, total, error_log, elapsed = True, 0, 1, "", 0.0
    tmp_path = None
    try:
        with tempfile.NamedTemporaryFile(
            mode="w", suffix=".py", delete=False, encoding="utf-8"
        ) as f:
            f.write(code)
            tmp_path = f.name
        start_time = time.time()
        proc = subprocess.run(
            [sys.executable, tmp_path],
            capture_output=True, text=True, timeout=TEST_TIMEOUT_SECONDS,
        )
        elapsed = time.time() - start_time
        compile_ok = proc.returncode == 0
        passed = 1 if compile_ok else 0
        error_log = "" if compile_ok else proc.stderr[:ERROR_LOG_LIMIT]
    except Exception as e:  # timeout / OS errors are reported as a failed run
        compile_ok = False
        passed = 0
        error_log = str(e)
    finally:
        if tmp_path is not None:
            try:
                os.unlink(tmp_path)
            except OSError:
                pass  # best-effort cleanup; do not mask the grading result
    return compile_ok, passed, total, error_log, elapsed


def _initial_test_count(task: dict) -> int:
    """Test total to show before the first attempt.

    Custom tasks are graded as one pseudo-test. ``test_code`` suites report 1
    because their real size is only known after running them.
    """
    if task.get("task_id") == "custom" or "test_code" in task:
        return 1
    return len(task.get("tests", []))


# -- Endpoints --------------------------------------------
@app.get("/")
def health():
    return {"status": "ok", "environment": "CodeArena"}


@app.post("/reset")
def reset(req: ResetRequest):
    """Start a new episode and return its initial observation.

    ``buggy_code`` selects custom mode. Otherwise ``task_id`` names a task
    folder, or is ``"auto"`` to follow the curriculum tracker. Unknown ids
    fall back to ``"easy"``.
    """
    global current_task, step_count, previous_attempts, episode_rewards
    global episode_recorded
    step_count = 0
    previous_attempts = []
    episode_rewards = []
    episode_recorded = False

    if req.buggy_code:
        # Custom mode -- user pasted their own broken code. The "tests" entry
        # is kept for compatibility; /step grades custom code by exit status.
        current_task = {
            "task_id": "custom",
            "buggy_code": req.buggy_code,
            "description": "User-provided code -- fix the bug",
            "tests": [
                {"input": [1, 2], "expected": None},
                {"input": [0, 0], "expected": None},
                {"input": [5, 5], "expected": None},
            ],
        }
    elif req.task_id in TASK_CATEGORIES:
        current_task = load_random_task(req.task_id)
    elif req.task_id == "auto":
        current_task = load_random_task(tracker.get_difficulty())
    else:
        current_task = load_random_task("easy")

    return {
        "task_id": current_task["task_id"],
        "curriculum_info": tracker.get_stats(),
        "observation": {
            "buggy_code": current_task["buggy_code"],
            "error_log": "",
            "test_results": f"0/{_initial_test_count(current_task)} tests passing",
            "previous_attempts": [],
        },
    }


@app.post("/step")
def step(req: StepRequest):
    """Grade one proposed fix and return reward, done flag and observation.

    Raises:
        HTTPException(400): if no episode has been started with /reset.
    """
    global step_count, previous_attempts, episode_rewards, episode_recorded
    if not current_task:
        raise HTTPException(
            status_code=400, detail="No active task; call /reset first."
        )
    step_count += 1
    is_repeated_fix = req.proposed_fix in previous_attempts

    if current_task["task_id"] == "custom":
        # Custom tasks have no expected values: only check that the code runs.
        compile_ok, passed, total, error_log, execution_time_seconds = (
            _run_custom_fix(req.proposed_fix)
        )
    else:
        compile_ok, passed, total, error_log, execution_time_seconds = evaluate_fix(
            req.proposed_fix,
            current_task,
        )
    total = max(total, 1)

    reward, reward_components = calculate_codearena_reward(
        compile_ok=compile_ok,
        passed=passed,
        total=total,
        execution_time_seconds=execution_time_seconds,
        optimal_time_seconds=float(current_task.get("optimal_time_seconds", 0.05)),
        buggy_code=current_task.get("buggy_code", ""),
        proposed_fix=req.proposed_fix,
        task_category=current_task.get(
            "difficulty", current_task.get("task_id", "easy").split("-")[0]
        ),
        step_count=step_count,
        is_repeated_fix=is_repeated_fix,
    )
    episode_rewards.append(reward)
    previous_attempts.append(req.proposed_fix)

    done = reward > SUCCESS_REWARD_THRESHOLD or step_count >= MAX_STEPS_PER_EPISODE
    # Report each episode to the curriculum exactly once, even if the client
    # keeps calling /step after ``done``.
    if done and episode_rewards and not episode_recorded:
        tracker.record_episode(sum(episode_rewards) / len(episode_rewards))
        episode_recorded = True

    return {
        "reward": reward,
        "done": done,
        "reward_components": reward_components,
        "curriculum_info": tracker.get_stats(),
        "observation": {
            "buggy_code": current_task["buggy_code"],
            "error_log": error_log,
            # ``total`` is the number of tests actually graded this step.
            "test_results": f"{passed}/{total} tests passing",
            "previous_attempts": previous_attempts[-3:],
        },
    }


@app.get("/state")
def state():
    """Return the current episode's task id, step count and all attempts."""
    return {
        "task_id": current_task.get("task_id", "none"),
        "step_count": step_count,
        "observation": {
            "buggy_code": current_task.get("buggy_code", ""),
            "error_log": "",
            "test_results": "",
            "previous_attempts": previous_attempts,
        },
    }


@app.get("/curriculum")
async def get_curriculum():
    """Return the curriculum tracker's summary statistics."""
    return tracker.get_stats()