"""
Rust Coder Environment Implementation.

Evaluates LLM-generated Rust code against 10 sequential coding problems.
Multi-dimensional reward system: Compilation(40%), Correctness(20%),
Coverage(20%), Elegance(10%), Efficiency(10%).
"""

import json
import os
import re
import subprocess
import tempfile
import time
from typing import Dict, List, Optional, Tuple

from openenv.core.env_server.interfaces import Environment

from models import RustCoderAction, RustCoderObservation

# Resolve problems.json: look in same dir as this file, then parent
_HERE = os.path.dirname(os.path.abspath(__file__))
_PROBLEMS_PATHS = [
    os.path.join(_HERE, "problems.json"),        # server/problems.json
    os.path.join(_HERE, "..", "problems.json"),  # root problems.json
    "problems.json",                             # cwd fallback
]

# Header of a `fn main()` definition up to (and including) its opening brace.
# Accepts an optional visibility (`pub`, `pub(crate)`, ...) and return type.
# The leading \b stops it from matching identifiers that merely end in "fn".
_MAIN_FN_RE = re.compile(
    r"\b(?:pub(?:\([^)]*\))?\s+)?fn\s+main\s*\(\s*\)\s*(?:->\s*[^{]+)?\s*\{"
)

# Anything outside this set is replaced when deriving per-test binary names.
# Those names become file names, rustc crate names (taken from the file stem)
# and part of a Rust `println!` format string, so they must stay plain ASCII
# identifiers (no spaces, path separators or `{`/`}`).
_UNSAFE_NAME_CHARS_RE = re.compile(r"[^A-Za-z0-9_]")


def _find_problems_file() -> str:
    """Return the first existing problems.json path.

    Raises:
        FileNotFoundError: if none of the candidate paths exists.
    """
    for path in _PROBLEMS_PATHS:
        if os.path.exists(path):
            return os.path.abspath(path)
    raise FileNotFoundError(
        f"problems.json not found. Searched: {_PROBLEMS_PATHS}"
    )


class RustCoderEnvironment(Environment):
    """
    OpenEnv-compliant environment for evaluating Rust code submissions.

    Walks through the problems in problems.json in order:
      - reset(start_index) → selects the starting problem (index 0 by default)
        and returns its description.
      - step(action)       → compiles & tests the submitted code for the
        current problem, returns the reward plus the NEXT problem's prompt.
      - ``done`` becomes True once the last problem has been graded.

    Reward breakdown (all components normalized to [0, 1]):
        Compilation  40% — code compiles without errors
        Correctness  20% — fraction of test assertions that pass
        Coverage     20% — fraction of tests that were actually run
        Elegance     10% — code quality heuristics
        Efficiency   10% — execution time vs. problem baseline
    """

    SUPPORTS_CONCURRENT_SESSIONS: bool = True

    # Compile / run timeouts (seconds)
    COMPILE_TIMEOUT = 30
    RUN_TIMEOUT = 10

    # Weight of each reward component in the total reward (weights sum to 1.0).
    # The keys double as the keys of the observation's ``reward_breakdown``.
    REWARD_WEIGHTS: Dict[str, float] = {
        "Compilation": 0.40,
        "Correctness": 0.20,
        "Coverage": 0.20,
        "Elegance": 0.10,
        "Efficiency": 0.10,
    }

    # Maximum number of compiler-output characters returned in an observation.
    MAX_COMPILER_OUTPUT = 2000

    def __init__(self) -> None:
        """Initialize environment and load problems from JSON."""
        self.problems: List[Dict] = self._load_problems()
        self.current_problem_idx: int = 0
        self.step_count: int = 0

    # ------------------------------------------------------------------
    # Internal helpers
    # ------------------------------------------------------------------

    def _load_problems(self) -> List[Dict]:
        """Load and validate the problems list from problems.json.

        Raises:
            FileNotFoundError: if problems.json cannot be located.
            ValueError: if the file is not a non-empty JSON array.
        """
        path = _find_problems_file()
        with open(path, "r", encoding="utf-8") as f:
            data = json.load(f)
        if not isinstance(data, list) or len(data) == 0:
            raise ValueError("problems.json must be a non-empty JSON array.")
        return data

    def _current_problem(self) -> Dict:
        """Return the current problem, wrapping the index so it never overflows."""
        idx = self.current_problem_idx % len(self.problems)
        return self.problems[idx]

    def _advance(self) -> Tuple[bool, str, str]:
        """Move to the next problem.

        Returns:
            (done, description, starter_code) for the observation shown after
            a step: the next problem's prompt, or a completion banner with an
            empty starter once every problem has been attempted.
        """
        self.current_problem_idx += 1
        if self.current_problem_idx >= len(self.problems):
            return True, "--- ALL TASKS COMPLETED in this episode ---", ""
        next_prob = self.problems[self.current_problem_idx]
        title = next_prob.get("title", f"Problem {self.current_problem_idx + 1}")
        description = f"--- NEXT TASK: {title} ---\n\n{next_prob['description']}"
        return False, description, next_prob.get("starter_code", "")

    # ------------------------------------------------------------------
    # OpenEnv interface
    # ------------------------------------------------------------------

    @property
    def state(self):
        """Return minimal state info (step count only; episode_id is unset)."""
        from openenv.core.env_server.types import State
        return State(episode_id=None, step_count=self.step_count)

    def reset(self, start_index: int = 0) -> RustCoderObservation:
        """Start a new episode at ``start_index`` (wrapped modulo problem count).

        Resets the step counter and returns the selected problem's description
        and starter code with zero reward.
        """
        self.current_problem_idx = start_index % len(self.problems)
        self.step_count = 0
        problem = self.problems[self.current_problem_idx]

        return RustCoderObservation(
            problem_description=problem["description"],
            starter_code=problem.get("starter_code", ""),
            compilation_success=False,
            compilation_output="",
            test_results=[],
            reward_breakdown={},
            done=False,
            reward=0.0,
        )

    def step(self, action: RustCoderAction) -> RustCoderObservation:
        """Grade the submitted code for the current problem, then advance.

        The returned observation always describes the NEXT problem (or the
        completion banner), whether or not any code was submitted, so the agent
        always sees the problem its next submission will be graded against.
        An empty submission scores 0.0 on every component.
        """
        self.step_count += 1
        # _current_problem() wraps the index, so stepping again after ``done``
        # does not raise IndexError.
        problem = self._current_problem()
        code = action.code

        if not code.strip():
            done, next_desc, next_starter = self._advance()
            return RustCoderObservation(
                problem_description=next_desc,
                starter_code=next_starter,
                compilation_success=False,
                compilation_output="Error: no code submitted.",
                test_results=[],
                # Same keys as the normal path so consumers can index uniformly.
                reward_breakdown={name: 0.0 for name in self.REWARD_WEIGHTS},
                done=done,
                reward=0.0,
            )

        # ── 1. Compilation (40%) ──────────────────────────────────────
        compilation_success, compilation_output = self._compile_check(code)
        r_compilation = 1.0 if compilation_success else 0.0

        # ── 2. Correctness + Coverage (20% each) ─────────────────────
        test_results: List[Dict] = []
        r_correctness = 0.0
        r_coverage = 0.0

        if compilation_success:
            tests = problem.get("tests", [])
            if tests:
                test_results = self._run_tests(code, tests)
                passed = sum(1 for t in test_results if t.get("passed", False))
                ran = sum(1 for t in test_results if t.get("ran", False))
                r_correctness = passed / len(tests)
                r_coverage = ran / len(tests)
            else:
                # No tests defined — give full credit to both dimensions
                r_correctness = 1.0
                r_coverage = 1.0

        # ── 3. Elegance (10%) ─────────────────────────────────────────
        r_elegance = self._score_elegance(code)

        # ── 4. Efficiency (10%) ───────────────────────────────────────
        baseline_ms: float = problem.get("performance_baseline_ms", 100.0)
        r_efficiency = (
            self._score_efficiency(code, baseline_ms) if compilation_success else 0.0
        )

        # ── Total reward ──────────────────────────────────────────────
        scores = {
            "Compilation": r_compilation,
            "Correctness": r_correctness,
            "Coverage": r_coverage,
            "Elegance": r_elegance,
            "Efficiency": r_efficiency,
        }
        reward_breakdown = {name: round(value, 4) for name, value in scores.items()}
        # Weighted sum uses the unrounded component scores.
        total_reward = round(
            sum(scores[name] * weight for name, weight in self.REWARD_WEIGHTS.items()),
            4,
        )

        # ── Advance Logic ─────────────────────────────────────────────
        done, next_desc, next_starter = self._advance()

        return RustCoderObservation(
            problem_description=next_desc,
            starter_code=next_starter,
            compilation_success=compilation_success,
            compilation_output=compilation_output[: self.MAX_COMPILER_OUTPUT],
            test_results=test_results,
            reward_breakdown=reward_breakdown,
            done=done,
            reward=total_reward,
        )

    # ------------------------------------------------------------------
    # Compilation
    # ------------------------------------------------------------------

    def _compile_check(self, code: str) -> Tuple[bool, str]:
        """
        Compile code as a Rust library crate.

        Returns (success, compiler output). A timeout or a missing ``rustc``
        is reported as failure with an explanatory message.
        """
        with tempfile.TemporaryDirectory() as tmpdir:
            src = os.path.join(tmpdir, "submission.rs")
            out = os.path.join(tmpdir, "submission.rlib")
            with open(src, "w", encoding="utf-8") as f:
                f.write(code)
            try:
                proc = subprocess.run(
                    ["rustc", "--crate-type=lib", src, "-o", out, "--edition=2021"],
                    capture_output=True,
                    text=True,
                    timeout=self.COMPILE_TIMEOUT,
                )
                return proc.returncode == 0, (proc.stdout + proc.stderr).strip()
            except subprocess.TimeoutExpired:
                return False, "Compilation timed out."
            except FileNotFoundError:
                return False, "rustc not found — is the Rust toolchain installed?"

    # ------------------------------------------------------------------
    # Correctness / Coverage
    # ------------------------------------------------------------------

    def _strip_main(self, code: str) -> str:
        """
        Remove the first ``fn main() { ... }`` block from submitted code so we
        can inject our own test main.

        Braces are matched by simple counting, so braces inside string or char
        literals and comments within main's body can confuse it. If the braces
        never balance, the code is returned unchanged.
        """
        match = _MAIN_FN_RE.search(code)
        if not match:
            return code

        depth = 0
        # Start at the opening '{' that ends the regex match.
        for i in range(match.end() - 1, len(code)):
            ch = code[i]
            if ch == "{":
                depth += 1
            elif ch == "}":
                depth -= 1
                if depth == 0:
                    return code[: match.start()] + code[i + 1:]
        return code  # malformed; return as-is

    def _build_test_binary(
        self, code: str, assertion: str, tmpdir: str, test_name: str
    ) -> Tuple[bool, str, str]:
        """
        Build a runnable Rust binary that executes one test assertion.

        ``test_name`` is used as the file name, the crate name and inside the
        ``println!`` literal, so callers must pass a plain identifier.

        Returns (compiled_ok, binary_path, compiler_output).
        """
        body = self._strip_main(code)
        # `#![allow]` is a crate-level inner attribute: it covers every item and
        # may be followed by the submission's own inner attributes. An outer
        # `#[allow]` would only cover the first item and would be a hard error
        # before a leading `#![...]` in the submission.
        src_code = f"""
#![allow(unused_imports, dead_code, unused_variables, unused_mut)]
{body}

fn main() {{
    {assertion};
    println!("PASS:{test_name}");
}}
"""
        src_path = os.path.join(tmpdir, f"{test_name}.rs")
        bin_path = os.path.join(tmpdir, test_name)
        with open(src_path, "w", encoding="utf-8") as f:
            f.write(src_code)
        try:
            proc = subprocess.run(
                ["rustc", src_path, "-o", bin_path, "--edition=2021"],
                capture_output=True,
                text=True,
                timeout=self.COMPILE_TIMEOUT,
            )
            return proc.returncode == 0, bin_path, (proc.stdout + proc.stderr).strip()
        except subprocess.TimeoutExpired:
            return False, "", "Compile timed out for test."
        except FileNotFoundError:
            return False, "", "rustc not found."

    def _run_tests(self, code: str, tests: List[Dict]) -> List[Dict]:
        """
        Run each test assertion as a separate Rust binary.

        Returns list of result dicts with keys: name, passed, ran, error
        (one per input test, in the same order).
        """
        results: List[Dict] = []
        with tempfile.TemporaryDirectory() as tmpdir:
            for i, test in enumerate(tests):
                name = test.get("name", f"test_{i}")
                assertion = test.get("test_assertion", "")
                should_compile = test.get("should_compile", True)

                result: Dict = {
                    "name": name,
                    "passed": False,
                    "ran": False,
                    "error": None,
                }
                # Appended now and filled in below; every path leaves it complete.
                results.append(result)

                if not assertion:
                    result["error"] = "No test assertion defined."
                    continue

                # Index prefix keeps names unique even after truncation; the
                # sanitizer keeps them valid as file/crate names and inside the
                # Rust string literal that prints the PASS marker.
                bin_test_name = _UNSAFE_NAME_CHARS_RE.sub(
                    "_", f"t{i}_{str(name)[:20]}"
                )
                compiled, bin_path, compiler_out = self._build_test_binary(
                    code, assertion, tmpdir, bin_test_name
                )

                if not compiled:
                    if should_compile:
                        result["error"] = f"Compile error: {compiler_out[:300]}"
                    else:
                        # The test is flagged should_compile=False, so a failed
                        # build is expected; it is still not counted as run.
                        result["error"] = "Binary compile failed (expected for broken starter)."
                    continue

                result["ran"] = True
                try:
                    run_proc = subprocess.run(
                        [bin_path],
                        capture_output=True,
                        text=True,
                        timeout=self.RUN_TIMEOUT,
                    )
                except subprocess.TimeoutExpired:
                    result["error"] = "Test execution timed out."
                    continue
                except Exception as exc:  # launch/decode failure: record, don't crash
                    result["error"] = str(exc)
                    continue

                if run_proc.returncode == 0 and f"PASS:{bin_test_name}" in run_proc.stdout:
                    result["passed"] = True
                else:
                    result["error"] = (
                        f"Test failed. Exit={run_proc.returncode}. "
                        f"stderr={run_proc.stderr[:200]}"
                    )

        return results

    # ------------------------------------------------------------------
    # Elegance scoring
    # ------------------------------------------------------------------

    def _score_elegance(self, code: str) -> float:
        """
        Heuristic code-quality score in [0, 1], starting from 1.0.

        All checks are plain text matches, so occurrences inside comments and
        string literals count too.

        Penalties:
          - Each `.unwrap()` call       → -0.15 (max -0.45)
          - Each `.expect(` call        → -0.05 (max -0.15)
          - Lines > 100 chars           → -0.05 per violation (max -0.20)
          - `unsafe` keyword present    → -0.20 (applied unconditionally)

        Bonuses:
          - Any `?` character           → +0.10
          - `match` keyword             → +0.05
          - Has doc comments (`///`)    → +0.05
        """
        score = 1.0

        unwrap_count = len(re.findall(r"\.unwrap\(\)", code))
        score -= min(unwrap_count * 0.15, 0.45)

        expect_count = len(re.findall(r"\.expect\(", code))
        score -= min(expect_count * 0.05, 0.15)

        long_lines = sum(1 for line in code.splitlines() if len(line) > 100)
        score -= min(long_lines * 0.05, 0.20)

        # Word boundaries so identifiers such as `is_unsafe` / `rematch` don't count.
        if re.search(r"\bunsafe\b", code):
            score -= 0.20

        if "?" in code:
            score += 0.10

        if re.search(r"\bmatch\s", code):
            score += 0.05

        if "///" in code:
            score += 0.05

        return round(max(0.0, min(1.0, score)), 4)

    # ------------------------------------------------------------------
    # Efficiency scoring
    # ------------------------------------------------------------------

    def _score_efficiency(self, code: str, baseline_ms: float) -> float:
        """
        Time the execution of a binary built from the submission with an
        empty ``main``.

        Score = min(1.0, baseline_ms / actual_ms), with actual_ms floored at
        0.1 ms. Returns 0.0 if compilation or execution fails for any reason.

        NOTE(review): because ``main`` is empty, none of the submission's
        functions are invoked; this mostly measures process startup.
        """
        body = self._strip_main(code)
        # Inner (crate-level) attribute; see _build_test_binary for why.
        test_src = f"""
#![allow(unused_imports, dead_code, unused_variables)]
{body}

fn main() {{}}
"""
        with tempfile.TemporaryDirectory() as tmpdir:
            src_path = os.path.join(tmpdir, "eff.rs")
            bin_path = os.path.join(tmpdir, "eff")
            with open(src_path, "w", encoding="utf-8") as f:
                f.write(test_src)
            try:
                # Compile
                proc = subprocess.run(
                    ["rustc", src_path, "-o", bin_path, "--edition=2021"],
                    capture_output=True,
                    text=True,
                    timeout=self.COMPILE_TIMEOUT,
                )
                if proc.returncode != 0:
                    return 0.0

                # Time the run; perf_counter has the finest available resolution.
                t0 = time.perf_counter()
                run_proc = subprocess.run(
                    [bin_path], capture_output=True, timeout=self.RUN_TIMEOUT
                )
                elapsed_ms = (time.perf_counter() - t0) * 1000.0

                if run_proc.returncode != 0:
                    return 0.0

                return round(min(1.0, baseline_ms / max(elapsed_ms, 0.1)), 4)
            except Exception:
                # Best-effort metric: any failure (timeout, missing rustc,
                # OS error) scores zero instead of aborting the step.
                return 0.0