Spaces:
Sleeping
Sleeping
| """ | |
| CodeCourtEnv — OpenEnv-compliant environment. | |
| Implements reset / step / render following the OpenEnv spec. | |
| """ | |
| import random | |
| from typing import Dict, Any, Optional, Tuple | |
| from env.dynamic_curriculum import build_dynamic_problem, generate_dynamic_trap_tests | |
| from env.problem_types import ARCHETYPES, build_problem | |
| from env.state import EpisodeState | |
| from oracle.executor import OracleExecutor | |
| from oracle.validator import ProblemValidator | |
| from rewards.rubrics import SetterRubric, SolverRubric | |
| from rewards.elo import EloTracker | |
class CodeCourtEnv:
    """
    Adversarial Curriculum Arena environment.

    Two LLM agents compete:
      - Setter: generates a problem + test cases
      - Solver: writes code to solve the problem

    Minimax rewards:
      - Setter wins (+50) if Solver fails AND Setter can solve it
      - Solver wins (+50) if it passes all test cases

    The environment is single-step: every ``reset()`` opens an episode and
    the following ``step()`` closes it.
    """

    ENV_NAME = "codecourt-v1"
    VERSION = "1.0.0"

    # Curriculum knobs: highest difficulty level, and the number of
    # consecutive solver wins required before the difficulty is raised.
    _MAX_DIFFICULTY = 3
    _PROMOTION_STREAK = 3

    def __init__(
        self,
        archetypes: Optional[list] = None,
        time_limit: float = 2.0,
        memory_limit_mb: int = 256,
        difficulty_progression: bool = True,
        seed: int = 42,
        dynamic_problems: bool = True,
        dynamic_traps: bool = True,
    ):
        """
        Build the environment and its collaborators.

        Args:
            archetypes: Archetype names to sample from. ``None`` or an empty
                list selects every key of ``ARCHETYPES``.
            time_limit: Forwarded to ``OracleExecutor`` as its first argument.
            memory_limit_mb: Forwarded to ``OracleExecutor`` as its second
                argument.
            difficulty_progression: When true, difficulty rises after a
                streak of solver wins (see ``step``).
            seed: Seed for the private ``random.Random`` used by ``reset``.
            dynamic_problems: Use ``build_dynamic_problem`` instead of
                ``build_problem`` when creating an episode's problem.
            dynamic_traps: Append tests from ``generate_dynamic_trap_tests``
                to the hidden test set during ``step``.
        """
        self.archetypes = archetypes or list(ARCHETYPES.keys())
        self.oracle = OracleExecutor(time_limit, memory_limit_mb)
        self.validator = ProblemValidator()
        self.setter_rubric = SetterRubric()
        self.solver_rubric = SolverRubric()
        self.elo = EloTracker()
        self.difficulty_progression = difficulty_progression
        self.dynamic_problems = dynamic_problems
        self.dynamic_traps = dynamic_traps
        self.rng = random.Random(seed)
        self._episode_count = 0
        self._current_state: Optional[EpisodeState] = None
        self._current_difficulty = 1
        self._solver_pass_streak = 0

    # ──────────────────────────────────────────────────────────
    # OpenEnv Interface
    # ──────────────────────────────────────────────────────────

    def reset(self) -> Dict[str, Any]:
        """
        Start a new episode and return its initial observation.

        An archetype, a task index within it and a variant seed are drawn
        from ``self.rng`` (in that order), a fresh ``EpisodeState`` is
        stored, and the problem is built at the current difficulty.

        Returns:
            Observation dict with the episode id, archetype, task id,
            difficulty, problem description, the public test cases, the
            number of hidden test cases, the variant seed, the generation
            mode (``"static"`` when the problem does not declare one) and
            the current Elo stats.
        """
        # NOTE: the order of the three rng draws determines reproducibility
        # for a given seed; do not reorder them.
        archetype = self.rng.choice(self.archetypes)
        task_id = self.rng.randint(0, len(ARCHETYPES[archetype]["tasks"]) - 1)
        variant_seed = self.rng.randint(0, 10**9)

        self._current_state = EpisodeState(
            episode_id=self._episode_count,
            archetype=archetype,
            task_id=task_id,
            difficulty=self._current_difficulty,
        )
        self._episode_count += 1

        # Build the ground-truth problem (Setter starts from this template)
        if self.dynamic_problems:
            problem = build_dynamic_problem(
                archetype, task_id, self._current_difficulty, seed=variant_seed
            )
        else:
            problem = build_problem(
                archetype, task_id, self._current_difficulty, seed=variant_seed
            )
        self._current_state.problem = problem

        obs = {
            "episode_id": self._current_state.episode_id,
            "archetype": archetype,
            "task_id": task_id,
            "difficulty": self._current_difficulty,
            "problem_template": problem["description"],
            "public_test_cases": problem["public_test_cases"],
            "hidden_test_count": len(problem["hidden_test_cases"]),
            "variant_seed": variant_seed,
            "generation_mode": problem.get("generation_mode", "static"),
            "elo": self.elo.get_stats(),
        }
        return obs

    @staticmethod
    def _cases_or_fallback(problem: Dict[str, Any], key: str) -> list:
        """Return ``problem[key]``; only if absent, fall back to ``problem["test_cases"]``."""
        # The fallback must be looked up lazily: problems that provide the
        # split public/hidden lists are not required to also carry a
        # combined "test_cases" entry.
        if key in problem:
            return problem[key]
        return problem["test_cases"]

    def step(
        self,
        setter_code: str,
        solver_code: str,
    ) -> Tuple[Dict, Dict, bool, Dict]:
        """
        Run one full episode step:
          1. Validate problem
          2. Run setter_code against test cases (self-consistency)
          3. Run solver_code against test cases
          4. Compute rewards
          5. Update Elo, difficulty

        Args:
            setter_code: Setter's reference solution source.
            solver_code: Solver's submitted solution source.

        Returns:
            ``(setter_reward_info, solver_reward_info, done, info)``. Each
            reward info is ``{"reward": total, "breakdown": {...}}``;
            ``done`` is always ``True`` (one step per episode). In ``info``,
            ``"difficulty"`` is the value *after* progression has been
            applied, and ``"solver_pass_rate"`` mirrors the hidden pass rate.

        Raises:
            AssertionError: if ``reset()`` has not been called yet.
            KeyError: if the problem provides neither the requested
                public/hidden list nor a combined ``"test_cases"`` list.
        """
        state = self._current_state
        # Raised explicitly (rather than via `assert`) so the guard is not
        # stripped when Python runs with -O; the exception type is unchanged.
        if state is None:
            raise AssertionError("Call reset() before step()")

        problem = state.problem
        public_test_cases = self._cases_or_fallback(problem, "public_test_cases")
        hidden_test_cases = self._cases_or_fallback(problem, "hidden_test_cases")

        trap_test_cases = []
        if self.dynamic_traps:
            trap_test_cases = generate_dynamic_trap_tests(problem, solver_code)
            if trap_test_cases:
                problem["trap_test_cases"] = trap_test_cases
                # Build a new list so the problem's own hidden list is not
                # mutated; only input/expected are carried over.
                hidden_test_cases = hidden_test_cases + [
                    {"input": tc["input"], "expected": tc["expected"]}
                    for tc in trap_test_cases
                ]
        all_test_cases = public_test_cases + hidden_test_cases

        # ── 1. Validate problem structure ──
        validation = self.validator.validate(problem)
        state.setter_valid = validation.valid

        # ── 2. Setter self-consistency check ──
        setter_run = self.oracle.run_against_tests(setter_code, all_test_cases)
        state.setter_code = setter_code
        state.setter_result = setter_run

        # ── 3. Solver attempts ──
        solver_public_run = self.oracle.run_against_tests(solver_code, public_test_cases)
        solver_hidden_run = self.oracle.run_against_tests(solver_code, hidden_test_cases)
        solver_run = self.oracle.run_against_tests(solver_code, all_test_cases)
        state.solver_code = solver_code
        state.solver_public_result = solver_public_run
        state.solver_hidden_result = solver_hidden_run
        state.solver_result = solver_run

        # ── 4. Rewards ──
        optimal_complexity = problem.get("optimal_complexity", "O(N)")
        setter_breakdown = self.setter_rubric.score(
            setter_result=setter_run,
            solver_public_result=solver_public_run,
            solver_hidden_result=solver_hidden_run,
            problem_valid=state.setter_valid,
            optimal_complexity=optimal_complexity,
        )
        solver_breakdown = self.solver_rubric.score(
            public_result=solver_public_run,
            hidden_result=solver_hidden_run,
            solver_code=solver_code,
            optimal_complexity=optimal_complexity,
        )
        state.setter_reward = setter_breakdown.total
        state.solver_reward = solver_breakdown.total

        # ── 5. Determine outcome ──
        # The solver is judged on the hidden set (including any traps);
        # the setter must pass every public + hidden case.
        solver_passed = solver_hidden_run["overall_status"] == "pass"
        setter_can_solve = setter_run["overall_status"] == "pass"
        if not state.setter_valid or not setter_can_solve:
            # An invalid episode leaves the solver streak untouched.
            state.outcome = "invalid"
        elif solver_passed:
            state.outcome = "solver_wins"
            self._solver_pass_streak += 1
        else:
            state.outcome = "setter_wins"
            self._solver_pass_streak = 0

        # ── 6. Update Elo ──
        # NOTE(review): an "invalid" outcome is reported as setter_won=False;
        # confirm EloTracker.update treats that as intended.
        self.elo.update(
            setter_won=(state.outcome == "setter_wins"),
            setter_reward=state.setter_reward,
            solver_reward=state.solver_reward,
        )

        # ── 7. Difficulty progression ──
        if self.difficulty_progression:
            if (
                self._solver_pass_streak >= self._PROMOTION_STREAK
                and self._current_difficulty < self._MAX_DIFFICULTY
            ):
                self._current_difficulty += 1
                self._solver_pass_streak = 0

        state.done = True
        info = {
            "outcome": state.outcome,
            "setter_valid": state.setter_valid,
            "setter_pass_rate": setter_run["pass_rate"],
            "solver_public_pass_rate": solver_public_run["pass_rate"],
            "solver_hidden_pass_rate": solver_hidden_run["pass_rate"],
            "solver_pass_rate": solver_hidden_run["pass_rate"],
            "difficulty": self._current_difficulty,
            "elo": self.elo.get_stats(),
            "validation_errors": validation.errors,
            "validation_warnings": validation.warnings,
            "generation_mode": problem.get("generation_mode", "static"),
            "dynamic_trap_count": len(trap_test_cases),
            "dynamic_traps": trap_test_cases,
            "effective_hidden_test_count": len(hidden_test_cases),
        }
        return (
            {"reward": state.setter_reward, "breakdown": setter_breakdown.__dict__},
            {"reward": state.solver_reward, "breakdown": solver_breakdown.__dict__},
            True,  # done — one step per episode
            info,
        )

    def render(self, mode: str = "text") -> str:
        """
        Human-readable state summary.

        Args:
            mode: Accepted for interface compatibility; currently unused.

        Returns:
            A multi-line summary of the current episode, or a hint to call
            ``reset()`` when no episode exists yet.
        """
        s = self._current_state
        if s is None:
            return "No active episode. Call reset() first."
        # NOTE(review): the reward lines use a float format spec, so this
        # presumably expects EpisodeState to default rewards to numbers
        # before step() has run — confirm.
        lines = [
            f"═══ Episode {s.episode_id} ═══",
            f"Archetype : {s.archetype} / Task {s.task_id} / Difficulty {s.difficulty}",
            f"Outcome : {s.outcome}",
            f"Setter R : {s.setter_reward:+.1f}",
            f"Solver R : {s.solver_reward:+.1f}",
            f"Elo : Setter={self.elo.setter_elo:.0f} "
            f"Solver={self.elo.solver_elo:.0f}",
        ]
        return "\n".join(lines)

    def get_metrics(self) -> Dict[str, Any]:
        """
        Return aggregate training metrics.

        Returns:
            Dict with the number of episodes started, the current
            difficulty, and every entry of ``self.elo.get_stats()`` merged
            in at the top level.
        """
        return {
            "total_episodes": self._episode_count,
            "current_difficulty": self._current_difficulty,
            **self.elo.get_stats(),
        }