# rust_coder/server/rust_coder_environment.py
# NOTE: HuggingFace file-viewer chrome that preceded this module
# ("Parthiban007 — Upload folder using huggingface_hub, commit 7bc8744,
# raw / history / blame, 17.9 kB") was plain text, not Python, and has
# been converted into this comment so the file parses.
"""
Rust Coder Environment Implementation.
Evaluates LLM-generated Rust code against 10 sequential coding problems.
Multi-dimensional reward system: Compilation(40%), Correctness(20%),
Coverage(20%), Elegance(10%), Efficiency(10%).
"""
import json
import os
import re
import subprocess
import tempfile
import time
from typing import Dict, List, Optional, Tuple
from openenv.core.env_server.interfaces import Environment
from models import RustCoderAction, RustCoderObservation
# Candidate locations for problems.json, in priority order: the directory
# containing this file, its parent, then the current working directory.
_HERE = os.path.dirname(os.path.abspath(__file__))
_PROBLEMS_PATHS = [
    os.path.join(_HERE, "problems.json"),             # server/problems.json
    os.path.join(_HERE, os.pardir, "problems.json"),  # root problems.json
    "problems.json",                                  # cwd fallback
]
def _find_problems_file() -> str:
    """Return the absolute path of the first existing problems.json.

    Raises:
        FileNotFoundError: if none of the candidate paths exists.
    """
    found = next((p for p in _PROBLEMS_PATHS if os.path.exists(p)), None)
    if found is None:
        raise FileNotFoundError(
            f"problems.json not found. Searched: {_PROBLEMS_PATHS}"
        )
    return os.path.abspath(found)
class RustCoderEnvironment(Environment):
    """
    OpenEnv-compliant environment for evaluating Rust code submissions.

    A single episode walks through all problems sequentially:
    - reset(start_index=0) → jumps to the problem at ``start_index`` and
      returns its description and starter code.
    - step(action) → compiles & tests the submitted code, returns the
      reward for the current problem, then advances to the next one
      (the observation carries the NEXT problem's description).
    - ``done`` becomes True only after the final problem is evaluated.

    Reward breakdown (all components normalized to [0, 1]):
      Compilation 40% — code compiles without errors
      Correctness 20% — fraction of test assertions that pass
      Coverage    20% — fraction of tests attempted to run
      Elegance    10% — code quality heuristics
      Efficiency  10% — execution time vs. problem baseline
    """

    # Multiple sessions may run concurrently; per-session state is confined
    # to instance attributes and per-call temporary directories.
    SUPPORTS_CONCURRENT_SESSIONS: bool = True
    # Compile / run timeouts (seconds) passed to subprocess.run
    COMPILE_TIMEOUT: int = 30
    RUN_TIMEOUT: int = 10
def __init__(self) -> None:
    """Load the problem set and zero the episode counters."""
    self.problems: List[Dict] = self._load_problems()  # parsed problems.json
    self.current_problem_idx: int = 0  # index of the task currently served
    self.step_count: int = 0           # steps taken in the current episode
# ------------------------------------------------------------------
# Internal helpers
# ------------------------------------------------------------------
def _load_problems(self) -> List[Dict]:
    """Read problems.json from disk and return its list of problem dicts.

    Raises:
        ValueError: if the file does not contain a non-empty JSON array.
    """
    problems_path = _find_problems_file()
    with open(problems_path, "r", encoding="utf-8") as fh:
        loaded = json.load(fh)
    if not (isinstance(loaded, list) and loaded):
        raise ValueError("problems.json must be a non-empty JSON array.")
    return loaded
def _current_problem(self) -> Dict:
idx = self.current_problem_idx % len(self.problems)
return self.problems[idx]
# ------------------------------------------------------------------
# OpenEnv interface
# ------------------------------------------------------------------
@property
def state(self):
    """Minimal session state: no episode id, plus the running step count."""
    # Imported lazily — presumably so importing this module does not
    # require the server types at module load time; TODO confirm.
    from openenv.core.env_server.types import State
    snapshot = State(episode_id=None, step_count=self.step_count)
    return snapshot
def reset(self, start_index: int = 0) -> RustCoderObservation:
    """
    Start a new episode, defaulting to the first problem.

    Args:
        start_index: problem index to begin at; wrapped modulo the
            number of problems so any integer is accepted.

    Returns:
        An observation carrying the selected problem's description and
        starter code; all evaluation fields are zeroed since no code
        has been submitted yet.
    """
    self.current_problem_idx = start_index % len(self.problems)
    self.step_count = 0
    problem = self.problems[self.current_problem_idx]
    return RustCoderObservation(
        problem_description=problem["description"],
        # .get() with a default, matching step(): a problem record
        # without starter code must not crash the episode start.
        starter_code=problem.get("starter_code", ""),
        compilation_success=False,
        compilation_output="",
        test_results=[],
        reward_breakdown={},
        done=False,
        reward=0.0,
    )
def step(self, action: RustCoderAction) -> RustCoderObservation:
    """
    Evaluate the submitted code, then advance to the next task.

    Scoring (weighted sum, each component in [0, 1]):
        Compilation 40%, Correctness 20%, Coverage 20%,
        Elegance 10%, Efficiency 10%.

    The returned observation carries the NEXT problem's description and
    starter code (or a completion banner), and ``done=True`` once every
    problem in the episode has been evaluated.
    """
    self.step_count += 1
    problem = self.problems[self.current_problem_idx]
    code = action.code

    if not code.strip():
        # Nothing to evaluate: zero reward, advance, and re-show the
        # current problem so the agent can retry on the next task.
        done = self.current_problem_idx >= len(self.problems) - 1
        if not done:
            self.current_problem_idx += 1
        return RustCoderObservation(
            problem_description=problem["description"],
            starter_code=problem.get("starter_code", ""),
            compilation_success=False,
            compilation_output="Error: no code submitted.",
            test_results=[],
            # Capitalized keys to match the normal path below, so
            # consumers always see one reward_breakdown schema.
            reward_breakdown={
                "Compilation": 0.0,
                "Correctness": 0.0,
                "Coverage": 0.0,
                "Elegance": 0.0,
                "Efficiency": 0.0,
            },
            done=done,
            reward=0.0,
        )

    # ── 1. Compilation (40%) ──────────────────────────────────────
    compilation_success, compilation_output = self._compile_check(code)
    r_compilation = 1.0 if compilation_success else 0.0

    # ── 2. Correctness + Coverage (20% each) ─────────────────────
    test_results: List[Dict] = []
    r_correctness = 0.0
    r_coverage = 0.0
    if compilation_success:
        tests = problem.get("tests", [])
        if tests:
            test_results = self._run_tests(code, tests)
            passed = sum(1 for t in test_results if t.get("passed", False))
            ran = sum(1 for t in test_results if t.get("ran", False))
            r_correctness = passed / len(tests)
            r_coverage = ran / len(tests)
        else:
            # No tests defined — give full credit to both dimensions.
            r_correctness = 1.0
            r_coverage = 1.0

    # ── 3. Elegance (10%) ── scored even when compilation fails.
    r_elegance = self._score_elegance(code)

    # ── 4. Efficiency (10%) ── only meaningful if the code compiles.
    baseline_ms: float = problem.get("performance_baseline_ms", 100.0)
    r_efficiency = 0.0
    if compilation_success:
        r_efficiency = self._score_efficiency(code, baseline_ms)

    # ── Total reward ──────────────────────────────────────────────
    reward_breakdown = {
        "Compilation": round(r_compilation, 4),
        "Correctness": round(r_correctness, 4),
        "Coverage": round(r_coverage, 4),
        "Elegance": round(r_elegance, 4),
        "Efficiency": round(r_efficiency, 4),
    }
    total_reward = round(
        r_compilation * 0.40
        + r_correctness * 0.20
        + r_coverage * 0.20
        + r_elegance * 0.10
        + r_efficiency * 0.10,
        4,
    )

    # ── Advance to the next task within the same episode ─────────
    self.current_problem_idx += 1
    done = self.current_problem_idx >= len(self.problems)
    next_prob_desc = "--- ALL TASKS COMPLETED in this episode ---"
    next_starter = ""
    if not done:
        next_prob = self.problems[self.current_problem_idx]
        next_prob_desc = f"--- NEXT TASK: {next_prob['title']} ---\n\n{next_prob['description']}"
        next_starter = next_prob['starter_code']

    return RustCoderObservation(
        problem_description=next_prob_desc,
        starter_code=next_starter,
        compilation_success=compilation_success,
        compilation_output=compilation_output[:2000],  # cap length
        test_results=test_results,
        reward_breakdown=reward_breakdown,
        done=done,
        reward=total_reward,
    )
# ------------------------------------------------------------------
# Compilation
# ------------------------------------------------------------------
def _compile_check(self, code: str) -> Tuple[bool, str]:
    """
    Compile the submission as a Rust library crate (no main required).

    Returns:
        (success, compiler output) — output is stdout+stderr, stripped.
    """
    with tempfile.TemporaryDirectory() as workdir:
        source_path = os.path.join(workdir, "submission.rs")
        rlib_path = os.path.join(workdir, "submission.rlib")
        with open(source_path, "w", encoding="utf-8") as fh:
            fh.write(code)
        compile_cmd = [
            "rustc", "--crate-type=lib", source_path, "-o", rlib_path,
            "--edition=2021",
        ]
        try:
            result = subprocess.run(
                compile_cmd,
                capture_output=True,
                text=True,
                timeout=self.COMPILE_TIMEOUT,
            )
        except subprocess.TimeoutExpired:
            return False, "Compilation timed out."
        except FileNotFoundError:
            return False, "rustc not found — is the Rust toolchain installed?"
        return result.returncode == 0, (result.stdout + result.stderr).strip()
# ------------------------------------------------------------------
# Correctness / Coverage
# ------------------------------------------------------------------
def _strip_main(self, code: str) -> str:
"""
Remove fn main() { ... } blocks from submitted code so we can
inject our own test main. Handles simple single-level braces.
"""
# Remove pub/private fn main() { ... }
pattern = re.compile(
r'(pub\s+)?fn\s+main\s*\(\s*\)\s*(?:->\s*[^{]+)?\s*\{',
re.MULTILINE,
)
match = pattern.search(code)
if not match:
return code
start = match.start()
depth = 0
i = match.end() - 1 # position of the opening '{'
while i < len(code):
if code[i] == '{':
depth += 1
elif code[i] == '}':
depth -= 1
if depth == 0:
return code[:start] + code[i + 1:]
i += 1
return code # malformed; return as-is
def _build_test_binary(
    self, code: str, assertion: str, tmpdir: str, test_name: str
) -> Tuple[bool, str, str]:
    """
    Compile a one-off Rust binary whose main() executes a single test
    assertion and prints a ``PASS:<test_name>`` marker on success.

    Returns:
        (compiled_ok, binary_path, compiler_output)
    """
    stripped = self._strip_main(code)
    harness_src = f"""
#[allow(unused_imports, dead_code, unused_variables, unused_mut)]
{stripped}
fn main() {{
{assertion};
println!("PASS:{test_name}");
}}
"""
    source_file = os.path.join(tmpdir, f"{test_name}.rs")
    binary_file = os.path.join(tmpdir, test_name)
    with open(source_file, "w", encoding="utf-8") as fh:
        fh.write(harness_src)
    try:
        build = subprocess.run(
            ["rustc", source_file, "-o", binary_file, "--edition=2021"],
            capture_output=True,
            text=True,
            timeout=self.COMPILE_TIMEOUT,
        )
    except subprocess.TimeoutExpired:
        return False, "", "Compile timed out for test."
    except FileNotFoundError:
        return False, "", "rustc not found."
    return build.returncode == 0, binary_file, (build.stdout + build.stderr).strip()
def _run_tests(self, code: str, tests: List[Dict]) -> List[Dict]:
    """
    Run each test assertion as a separate Rust binary.

    Args:
        code: the submitted Rust source (any existing main is stripped).
        tests: problem test specs; each may carry ``name``,
            ``test_assertion`` and ``should_compile``.

    Returns:
        One dict per test with keys: name, passed, ran, error.
    """
    results = []
    with tempfile.TemporaryDirectory() as tmpdir:
        for i, test in enumerate(tests):
            name = test.get("name", f"test_{i}")
            assertion = test.get("test_assertion", "")
            should_compile = test.get("should_compile", True)
            result: Dict = {
                "name": name,
                "passed": False,
                "ran": False,
                "error": None,
            }
            if not assertion:
                result["error"] = "No test assertion defined."
                results.append(result)
                continue
            # Sanitize the JSON-supplied name before it becomes part of a
            # temp-file path and the PASS marker: spaces, slashes, etc.
            # would otherwise break the rustc output path. The same
            # sanitized name is embedded in the binary and checked below,
            # so the marker stays consistent.
            safe_name = re.sub(r"[^A-Za-z0-9_]", "_", name[:20])
            bin_test_name = f"t{i}_{safe_name}"
            compiled, bin_path, compiler_out = self._build_test_binary(
                code, assertion, tmpdir, bin_test_name
            )
            if not compiled:
                if not should_compile:
                    # The problem's starter code deliberately doesn't compile;
                    # if the submission also doesn't compile this test → skip
                    result["error"] = "Binary compile failed (expected for broken starter)."
                else:
                    result["error"] = f"Compile error: {compiler_out[:300]}"
                result["ran"] = False
                results.append(result)
                continue
            # Run the binary and look for the PASS marker on stdout.
            result["ran"] = True
            try:
                run_proc = subprocess.run(
                    [bin_path],
                    capture_output=True,
                    text=True,
                    timeout=self.RUN_TIMEOUT,
                )
                stdout = run_proc.stdout.strip()
                if run_proc.returncode == 0 and f"PASS:{bin_test_name}" in stdout:
                    result["passed"] = True
                else:
                    result["error"] = (
                        f"Test failed. Exit={run_proc.returncode}. "
                        f"stderr={run_proc.stderr[:200]}"
                    )
            except subprocess.TimeoutExpired:
                result["error"] = "Test execution timed out."
            except Exception as exc:
                result["error"] = str(exc)
            results.append(result)
    return results
# ------------------------------------------------------------------
# Elegance scoring
# ------------------------------------------------------------------
def _score_elegance(self, code: str) -> float:
"""
Heuristic code-quality score in [0, 1].
Penalties:
- Each `.unwrap()` call β†’ -0.15 (max -0.45)
- Each `.expect(` call β†’ -0.05 (max -0.15)
- Lines > 100 chars β†’ -0.05 per violation (max -0.20)
- `unsafe` blocks β†’ -0.20 unless problem requires FFI
Bonuses:
- Uses `?` operator β†’ +0.10
- Uses `match` expressions β†’ +0.05
- Has doc comments (`///`) β†’ +0.05
"""
score = 1.0
unwrap_count = len(re.findall(r'\.unwrap\(\)', code))
score -= min(unwrap_count * 0.15, 0.45)
expect_count = len(re.findall(r'\.expect\(', code))
score -= min(expect_count * 0.05, 0.15)
long_lines = sum(1 for line in code.splitlines() if len(line) > 100)
score -= min(long_lines * 0.05, 0.20)
if "unsafe" in code:
score -= 0.20
if "?" in code:
score += 0.10
if "match " in code or "match\n" in code:
score += 0.05
if "///" in code:
score += 0.05
return round(max(0.0, min(1.0, score)), 4)
# ------------------------------------------------------------------
# Efficiency scoring
# ------------------------------------------------------------------
def _score_efficiency(self, code: str, baseline_ms: float) -> float:
    """
    Score = min(1.0, baseline_ms / measured_ms), where measured_ms is
    the wall-clock time of running a binary built from the submission
    with an empty main() (startup + static-init overhead only).
    Returns 0.0 if compilation or execution fails.
    """
    stripped = self._strip_main(code)
    harness_src = f"""
#[allow(unused_imports, dead_code, unused_variables)]
{stripped}
fn main() {{}}
"""
    with tempfile.TemporaryDirectory() as workdir:
        source_file = os.path.join(workdir, "eff.rs")
        binary_file = os.path.join(workdir, "eff")
        with open(source_file, "w", encoding="utf-8") as fh:
            fh.write(harness_src)
        try:
            build = subprocess.run(
                ["rustc", source_file, "-o", binary_file, "--edition=2021"],
                capture_output=True, text=True, timeout=self.COMPILE_TIMEOUT,
            )
            if build.returncode != 0:
                return 0.0
            # Time only the run, not the compile.
            started = time.monotonic()
            executed = subprocess.run(
                [binary_file], capture_output=True, timeout=self.RUN_TIMEOUT
            )
            elapsed_ms = (time.monotonic() - started) * 1000.0
            if executed.returncode != 0:
                return 0.0
            # Floor at 0.1 ms so near-instant runs don't divide by ~zero.
            return round(min(1.0, baseline_ms / max(elapsed_ms, 0.1)), 4)
        except Exception:
            # Best-effort metric: any failure (timeout, missing rustc, ...)
            # simply yields no efficiency credit.
            return 0.0