# Source: julia_env-pr-170 / src/envs/julia_env/server/julia_codeact_env.py
# Uploaded by burtenshaw (HF Staff) using huggingface_hub, commit be32845 (verified).
"""
Julia Code Action Environment.
This environment mirrors the PythonCodeActEnv but runs Julia code instead.
It executes Julia code using JuliaExecutor, captures output,
tracks the last exit code, and returns a JuliaObservation.
"""
import re
import uuid
from core.env_server import Environment
from core.tools import JuliaExecutor
from ..models import JuliaAction, JuliaObservation, JuliaState
from .julia_transforms import create_safe_julia_transform
class JuliaCodeActEnv(Environment):
    """
    Julia Code Action Environment for executing code and tracking state.

    Each ``step`` runs a JuliaAction's ``core_code`` followed by its
    ``test_code`` in a single Julia execution, parses the Julia ``Test``
    output to count passed/failed tests, computes an integer reward, and
    returns the result wrapped in a JuliaObservation. The environment state
    records the step count, the last exit code, whether the code appeared to
    compile, and the test counts of the most recent step.

    Example:
        >>> env = JuliaCodeActEnv()
        >>> obs = env.reset()
        >>> action = JuliaAction(
        ...     core_code="add(a, b) = a + b",
        ...     test_code='using Test; @testset "add" begin @test add(1, 2) == 3 end',
        ... )
        >>> obs = env.step(action)
        >>> print(obs.tests_passed)  # 1
        >>> print(obs.exit_code)  # 0
        >>> print(env.state.last_exit_code)  # 0
    """

    # Substrings, matched against lower-cased stderr, that are taken as
    # evidence the code failed to parse/load when a run exits non-zero
    # without reporting any test results.
    _COMPILE_ERROR_MARKERS = ("error", "syntax", "undefined", "loaderror")

    # Message of the exception Julia's Test module throws for a failing
    # top-level testset, e.g.
    # "Some tests did not pass: 1 passed, 1 failed, 0 errored, 0 broken."
    _FAILED_TESTSET_RE = re.compile(
        r"Some tests did not pass:\s*(\d+)\s+passed,\s*(\d+)\s+failed,"
        r"\s*(\d+)\s+errored"
    )

    # ANSI escape sequences emitted when Julia colours its output; their
    # digits must not be mistaken for test counts.
    _ANSI_ESCAPE_RE = re.compile(r"\x1b\[[0-9;]*[A-Za-z]")

    def __init__(self):
        """Initialize the executor, an empty state and the safety transform."""
        self._executor = JuliaExecutor()
        self._state = JuliaState()
        self.transform = create_safe_julia_transform()

    def reset(self) -> JuliaObservation:
        """
        Reset environment for a fresh Julia execution session.

        Starts a new episode (fresh episode id, step count 0), replaces the
        executor with a new JuliaExecutor, and returns an empty
        JuliaObservation with exit_code=0 and reward 0.0 (after the transform).
        """
        self._state = JuliaState(episode_id=str(uuid.uuid4()), step_count=0)
        self._state.last_exit_code = 0
        self._state.last_code_compiles = True
        self._executor = JuliaExecutor()

        observation = JuliaObservation(
            stdout="",
            stderr="",
            exit_code=0,
            reward=0.0,
            metadata={"core_code": "", "test_code": ""},
            tests_passed=0,
            tests_failed=0,
            code_compiles=True,
        )
        return self._apply_transform(observation)

    def step(self, action: JuliaAction) -> JuliaObservation:
        """
        Execute an action's Julia code and return the result as a JuliaObservation.

        ``core_code`` and ``test_code`` are joined with a blank line and run in
        a single execution; compilation status is inferred from that one run
        rather than by executing ``core_code`` separately first.

        Args:
            action: The JuliaAction whose ``core_code`` and ``test_code`` run.

        Returns:
            A JuliaObservation with the captured stdout/stderr, exit code,
            parsed test counts, inferred ``code_compiles`` flag and reward,
            after the configured transform has been applied.

        Raises:
            ValueError: If ``action`` is not a JuliaAction.
        """
        if not isinstance(action, JuliaAction):
            raise ValueError(f"Expected JuliaAction, got {type(action)}")

        combined_code = action.core_code + "\n\n" + action.test_code
        full_result = self._executor.run(combined_code)

        tests_passed, tests_failed = self._parse_test_results(
            full_result.stdout, full_result.stderr
        )

        # A clean exit, or any reported test result (pass or fail), means the
        # code compiled.
        code_compiles = (
            full_result.exit_code == 0 or tests_passed > 0 or tests_failed > 0
        )
        if not code_compiles:
            # Non-zero exit and no test results: call it a compilation failure
            # only when stderr looks like an error report; otherwise assume the
            # code compiled.
            stderr_lower = full_result.stderr.lower()
            code_compiles = not any(
                marker in stderr_lower for marker in self._COMPILE_ERROR_MARKERS
            )

        reward = self._calculate_reward(code_compiles, tests_passed, tests_failed)

        self._state.step_count += 1
        self._state.last_exit_code = full_result.exit_code
        self._state.last_code_compiles = code_compiles
        # NOTE(review): despite the "total_" prefix these hold the counts of
        # the most recent step only, not a running sum across steps.
        self._state.total_tests_passed = tests_passed
        self._state.total_tests_failed = tests_failed

        observation = JuliaObservation(
            stdout=full_result.stdout,
            stderr=full_result.stderr,
            exit_code=full_result.exit_code,
            reward=reward,
            metadata={"core_code": action.core_code, "test_code": action.test_code},
            tests_passed=tests_passed,
            tests_failed=tests_failed,
            code_compiles=code_compiles,
        )
        # Apply safety and quality transforms
        return self._apply_transform(observation)

    def _parse_test_results(self, stdout: str, stderr: str) -> tuple[int, int]:
        """
        Parse Julia test output to count passed/failed tests.

        Julia's Test module prints a summary table for a testset, e.g.::

            Test Summary:      | Pass  Fail  Total  Time
            Add function Tests |    1     1      2  1.5s

        The counts from every such table in the output are summed. Each row's
        cells are mapped to the header's column names, so digits in a testset
        name or in the Time column are never mistaken for counts, and output
        without a Time column is handled. If no table is found, the message of
        the exception thrown for a failing testset is used instead::

            Some tests did not pass: 1 passed, 1 failed, 0 errored, 0 broken.

        Errored tests are counted as failures; broken tests are ignored.

        Args:
            stdout: Standard output from Julia execution
            stderr: Standard error from Julia execution

        Returns:
            Tuple of (tests_passed, tests_failed); (0, 0) when no test results
            are recognised.
        """
        output = self._ANSI_ESCAPE_RE.sub("", stdout + "\n" + stderr)

        tables = self._summary_table_counts(output)
        if tables:
            passed = sum(table_passed for table_passed, _ in tables)
            failed = sum(table_failed for _, table_failed in tables)
            return passed, failed

        match = self._FAILED_TESTSET_RE.search(output)
        if match:
            passed, failed, errored = (int(group) for group in match.groups())
            return passed, failed + errored  # Treat errors as failures

        return 0, 0

    @staticmethod
    def _summary_table_counts(output: str) -> list[tuple[int, int]]:
        """
        Return (passed, failed) for each "Test Summary:" table in ``output``.

        Only the row directly below each header is read (the testset's own
        row). Columns are matched by name: "Pass" gives the passed count, and
        "Fail" plus "Error" give the failed count; missing columns count as 0.
        Tables whose counts cannot be parsed as integers are skipped.
        """
        counts = []
        lines = output.splitlines()
        for header, row in zip(lines, lines[1:]):
            if "Test Summary:" not in header:
                continue
            if "|" not in header or "|" not in row:
                continue
            # Everything right of the last "|" is the column/cell area; the
            # testset name on the left may contain digits (or even "|").
            columns = header.rsplit("|", 1)[1].split()
            cells = dict(zip(columns, row.rsplit("|", 1)[1].split()))
            try:
                passed = int(cells.get("Pass", 0))
                failed = int(cells.get("Fail", 0)) + int(cells.get("Error", 0))
            except ValueError:
                continue
            counts.append((passed, failed))
        return counts

    def _calculate_reward(
        self, code_compiles: bool, tests_passed: int, tests_failed: int
    ) -> int:
        """
        Integer reward for Julia GRPO.

        - Code that does not compile: -3.
        - Otherwise: 1 + 3 * tests_passed - tests_failed, plus a bonus of 2
          when at least one test passed and none failed.
        """
        if not code_compiles:
            return -3

        reward = 1
        reward += 3 * tests_passed - 1 * tests_failed
        if tests_failed == 0 and tests_passed > 0:
            reward += 2
        return reward

    def _apply_transform(self, observation: JuliaObservation) -> JuliaObservation:
        """Apply the configured safety/quality transform, if any, to observation."""
        if self.transform is not None:
            observation = self.transform(observation)
        return observation

    @property
    def state(self) -> JuliaState:
        """Return current environment state."""
        return self._state