Spaces:
Sleeping
Sleeping
File size: 18,680 Bytes
11952db |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446 447 448 449 450 451 452 453 454 455 456 457 458 459 460 461 462 463 464 465 466 467 468 469 470 471 472 473 474 475 476 477 478 479 |
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
"""
Julia Code Action Environment.
This module provides a server-side environment implementation for executing
Julia code actions using JuliaExecutor.
"""
import itertools
import logging
import re
import time
import uuid
# Support both in-repo and standalone imports
try:
# In-repo imports (when running from OpenEnv repository)
from openenv.core.env_server.interfaces import Action, Environment, Observation
from ..models import JuliaAction, JuliaObservation, JuliaState
from .julia_executor import JuliaExecutor
from .julia_transforms import create_safe_julia_transform
except ImportError:
# Standalone imports (when environment is standalone)
from openenv.core.env_server.interfaces import Action, Environment, Observation
from models import JuliaAction, JuliaObservation, JuliaState
from server.julia_executor import JuliaExecutor
from server.julia_transforms import create_safe_julia_transform
# Module-level logger. The dotted name places it under the "julia_env" logger
# in the logging hierarchy, so its records propagate to that logger's handlers.
logger = logging.getLogger("julia_env.codeact")
# Monotonically increasing request id (first value is 1); step() draws one id
# per call and uses it to tag that call's log lines.
# NOTE(review): described as thread-safe - this relies on next() on an
# itertools.count being atomic under CPython's GIL; confirm if another
# interpreter or a free-threaded build is targeted.
_request_counter = itertools.count(1)
# Julia comments and string literals, matched in ONE left-to-right pass so that
# a "#" inside a string is not mistaken for a comment (and a quote inside a
# comment is not mistaken for a string).
_JULIA_NOISE_RE = re.compile(
    r"#=.*?=#"  # block comment (nested block comments are not tracked)
    r"|#[^\n]*"  # line comment
    r'|""".*?"""'  # triple-quoted string
    r'|"(?:\\.|[^"\\])*"',  # ordinary string, honouring backslash escapes
    re.DOTALL,
)

# Case-insensitive to stay compatible with the historical behaviour.
_WHILE_TRUE_RE = re.compile(r"\bwhile\s+true\b", re.IGNORECASE)

# Anything that can leave (or abort) a loop.
_LOOP_EXIT_RE = re.compile(r"\b(?:break|return)\b|\b(?:error|throw|exit)\(")

# Brackets plus every keyword that opens or closes an `end`-terminated block.
# The look-behind skips field accesses (`x.end`) and quoted symbols (`:end`).
_BLOCK_TOKEN_RE = re.compile(
    r"[\[\](){}]"
    r"|(?<![.:\w])(?:baremodule|begin|do|for|function|if|let|macro|module"
    r"|quote|struct|try|while|end)\b"
)


def _extract_loop_body(code: str, start: int) -> str:
    """
    Return the body of the block whose header ends at ``start``.

    Walks forward counting block openers against ``end`` keywords until the
    ``end`` that closes the enclosing block is reached. Keywords that appear
    inside brackets are ignored, because there ``end`` is an index
    (``a[end]``) and ``for``/``if`` belong to comprehensions or generators
    rather than to blocks.

    Args:
        code: Julia source with comments and strings already removed
        start: Offset just past the block header (e.g. after ``while true``)

    Returns:
        The text between ``start`` and the matching ``end``; everything after
        ``start`` when the block is never closed.
    """
    block_depth = 1  # we start inside the block being measured
    bracket_depth = 0
    for token in _BLOCK_TOKEN_RE.finditer(code, start):
        text = token.group()
        if text in "([{":
            bracket_depth += 1
        elif text in ")]}":
            bracket_depth = max(bracket_depth - 1, 0)
        elif bracket_depth:
            continue  # keyword inside brackets: not block structure
        elif text == "end":
            block_depth -= 1
            if block_depth == 0:
                return code[start : token.start()]
        else:
            block_depth += 1
    return code[start:]


def _detect_infinite_loop(code: str) -> tuple[bool, str]:
    """
    Detect potential infinite loops in Julia code.

    Scans for ``while true`` loops whose body contains no way out: no
    ``break``, ``return``, ``error(``, ``throw(`` or ``exit(``. The body is
    delimited by the ``end`` that actually closes the loop, so exits placed
    after a nested ``for``/``if`` block or after an ``a[end]`` index are seen.

    This is a deliberately conservative heuristic: an exit statement anywhere
    inside the loop body (even one that only leaves an inner loop) counts as
    an exit, so a positive result means no exit mechanism was found at all.
    Quotes nested inside string interpolation and nested block comments are
    not modelled.

    Args:
        code: Julia code string to analyze

    Returns:
        Tuple of (has_infinite_loop: bool, reason: str). ``reason`` is empty
        when no infinite loop is detected.
    """
    # Replace comments/strings with a space so neighbouring tokens never fuse.
    code_without_noise = _JULIA_NOISE_RE.sub(" ", code)

    for match in _WHILE_TRUE_RE.finditer(code_without_noise):
        loop_body = _extract_loop_body(code_without_noise, match.end())
        if _LOOP_EXIT_RE.search(loop_body) is None:
            loop_preview = loop_body[:100].strip()
            return (
                True,
                f"Infinite loop detected: 'while true' without break/return/error/throw. Preview: {loop_preview}",
            )
    return False, ""
class JuliaCodeActEnv(Environment):
    """
    Julia Code Action Environment for executing code and tracking state.

    This environment executes Julia code submitted as JuliaAction during step,
    maintains the last exit code in its state, and returns results wrapped
    in JuliaObservation.

    Example:
        >>> env = JuliaCodeActEnv()
        >>> obs = env.reset()
        >>> action = JuliaAction(core_code='println("Hello, Julia!")', test_code='')
        >>> obs = env.step(action)
        >>> print(obs.stdout) # "Hello, Julia!\\n"
        >>> print(obs.exit_code) # 0
        >>> print(env.state.last_exit_code) # 0
    """

    # Allow concurrent sessions - each session has its own isolated state
    SUPPORTS_CONCURRENT_SESSIONS = True

    # Fragments looked for in lower-cased stderr when a run exits non-zero
    # without any test results; all entries must therefore be lowercase.
    _COMPILE_ERROR_MARKERS = ("error", "syntax", "undefined", "loaderror")

    # ANSI colour/style sequences; removed before parsing test output because
    # their digits would otherwise be indistinguishable from test counts.
    _ANSI_ESCAPE_RE = re.compile(r"\x1b\[[0-9;]*m")

    # "Some tests did not pass: X passed, Y failed, Z errored, W broken."
    _TESTS_FAILED_RE = re.compile(
        r"Some tests did not pass:\s*(\d+)\s+passed,\s*(\d+)\s+failed,\s*(\d+)\s+errored"
    )

    def __init__(self, use_process_pool: bool = True):
        """
        Initialize the Julia Code Act Environment.

        Args:
            use_process_pool: Use persistent Julia process pool for better performance
                and to avoid Juliaup lock contention (default: True)
        """
        self._executor = JuliaExecutor(use_process_pool=use_process_pool)
        self._state = JuliaState()
        self.transform = create_safe_julia_transform()

    def reset(self, **kwargs) -> Observation:
        """
        Reset environment for a fresh Julia execution session.

        Replaces the state with a new JuliaState carrying a fresh episode id
        and a zero step count. The executor is intentionally kept so that its
        process pool is reused across episodes.

        Args:
            **kwargs: Accepted for interface compatibility; not used here.

        Returns:
            An empty JuliaObservation (exit_code=0, reward=0.0) after the
            environment's transform has been applied.
        """
        self._state = JuliaState(episode_id=str(uuid.uuid4()), step_count=0)
        self._state.last_exit_code = 0
        self._state.last_code_compiles = True

        observation = JuliaObservation(
            stdout="",
            stderr="",
            exit_code=0,
            reward=0.0,
            metadata={"core_code": "", "test_code": ""},
            tests_passed=0,
            tests_failed=0,
            code_compiles=True,
        )
        return self._apply_transform(observation)

    def step(self, action: Action, **kwargs) -> Observation:
        """
        Execute Julia code and return the result as JuliaObservation.

        Single-pass execution: core_code and test_code (when provided) are
        concatenated and run once; test counts are parsed from the output and
        the compilation status is inferred from that same run. Before running,
        core_code is screened for an obvious ``while true`` loop without an
        exit; such code is not executed and receives a reward of -1.0.

        Args:
            action: JuliaAction with core_code and optional test_code
            **kwargs: Optional parameters including:
                - timeout: Execution timeout in seconds, forwarded to the
                  executor (the executor applies its default when omitted)

        Returns:
            JuliaObservation with stdout/stderr, exit code, test counts,
            compilation flag and reward, after the transform has been applied.

        Raises:
            ValueError: If ``action`` is not a JuliaAction.
            Exception: Anything raised by the executor is logged and re-raised.
        """
        request_id = next(_request_counter)
        tag = f"[REQ-{request_id}]"

        if not isinstance(action, JuliaAction):
            logger.error("%s Invalid action type: %s", tag, type(action))
            raise ValueError(f"Expected JuliaAction, got {type(action)}")

        # Get timeout from kwargs (default handled by executor)
        timeout = kwargs.get("timeout")
        test_code = action.test_code or ""

        # Log request details
        logger.info("%s === NEW EXECUTION REQUEST ===", tag)
        logger.info(
            "%s Session: %s, Step: %s",
            tag,
            self._state.episode_id,
            self._state.step_count,
        )
        logger.info(
            "%s Code length: %s chars, Test length: %s chars",
            tag,
            len(action.core_code),
            len(test_code),
        )
        logger.debug(
            "%s Code preview: %s", tag, self._preview(action.core_code, 200)
        )
        if timeout:
            logger.info("%s Timeout: %ss", tag, timeout)
        else:
            logger.info("%s Timeout: default", tag)

        start_time = time.time()

        # Pre-execution check: detect infinite loops to avoid burning the timeout
        has_infinite_loop, loop_reason = _detect_infinite_loop(action.core_code)
        if has_infinite_loop:
            logger.warning("%s INFINITE LOOP DETECTED: %s", tag, loop_reason)

            self._state.step_count += 1
            self._state.last_exit_code = 1
            self._state.last_code_compiles = True  # Code compiles but has infinite loop
            self._state.total_tests_passed = 0
            self._state.total_tests_failed = 0

            observation = JuliaObservation(
                stdout="",
                stderr=f"Infinite loop detected (pre-execution check): {loop_reason}",
                exit_code=1,
                reward=-1.0,  # Penalize infinite loops
                metadata={
                    "core_code": action.core_code,
                    "test_code": test_code,
                    "infinite_loop_detected": True,
                    "infinite_loop_reason": loop_reason,
                },
                tests_passed=0,
                tests_failed=0,
                code_compiles=True,  # Code would compile, but not run
            )
            logger.info(
                "%s RESULT: infinite_loop=True, "
                "tests_passed=0, tests_failed=0, reward=-1.00",
                tag,
            )
            return self._apply_transform(observation)

        # Single execution: run core_code + test_code together (if test_code provided)
        if action.test_code:
            combined_code = action.core_code + "\n\n" + action.test_code
        else:
            combined_code = action.core_code

        try:
            full_result = self._executor.run(combined_code, timeout=timeout)
        except Exception as e:
            logger.error(
                "%s EXECUTION FAILED after %.2fs: %s",
                tag,
                time.time() - start_time,
                e,
            )
            raise

        execution_time = time.time() - start_time
        logger.info(
            "%s Execution completed in %.2fs, exit_code=%s",
            tag,
            execution_time,
            full_result.exit_code,
        )
        # Log stderr if present (often contains errors or test output)
        if full_result.stderr:
            logger.debug(
                "%s Stderr: %s", tag, self._preview(full_result.stderr, 500)
            )

        tests_passed, tests_failed = self._parse_test_results(
            full_result.stdout, full_result.stderr
        )
        code_compiles = self._infer_code_compiles(
            full_result.exit_code, full_result.stderr, tests_passed, tests_failed
        )
        reward = self._calculate_reward(code_compiles, tests_passed, tests_failed)

        logger.info(
            "%s RESULT: compiles=%s, tests_passed=%s, tests_failed=%s, reward=%.2f",
            tag,
            code_compiles,
            tests_passed,
            tests_failed,
            reward,
        )

        # Update environment state
        # NOTE(review): despite the "total_" prefix these are overwritten with
        # the counts of the latest step, not accumulated - confirm that this
        # is what JuliaState consumers expect.
        self._state.step_count += 1
        self._state.last_exit_code = full_result.exit_code
        self._state.last_code_compiles = code_compiles
        self._state.total_tests_passed = tests_passed
        self._state.total_tests_failed = tests_failed

        observation = JuliaObservation(
            stdout=full_result.stdout,
            stderr=full_result.stderr,
            exit_code=full_result.exit_code,
            reward=reward,
            metadata={
                "core_code": action.core_code,
                "test_code": test_code,
            },
            tests_passed=tests_passed,
            tests_failed=tests_failed,
            code_compiles=code_compiles,
        )
        # Apply safety and quality transforms
        return self._apply_transform(observation)

    @staticmethod
    def _preview(text: str, limit: int) -> str:
        """Return ``text`` cut to ``limit`` characters, with "..." appended if it was cut."""
        return text[:limit] + "..." if len(text) > limit else text

    @classmethod
    def _infer_code_compiles(
        cls, exit_code: int, stderr: str, tests_passed: int, tests_failed: int
    ) -> bool:
        """
        Infer whether the submitted code compiled from a single combined run.

        A clean exit, or any test having run (passed or failed), means the
        code compiled. Otherwise the run failed before any test reported, and
        the code is considered non-compiling only if stderr mentions one of
        ``_COMPILE_ERROR_MARKERS``.
        """
        if exit_code == 0 or tests_passed > 0 or tests_failed > 0:
            return True
        stderr_lower = (stderr or "").lower()
        return not any(marker in stderr_lower for marker in cls._COMPILE_ERROR_MARKERS)

    def _parse_test_results(self, stdout: str, stderr: str) -> tuple[int, int]:
        """
        Parse Julia test output to count passed/failed tests.

        Two sources are consulted, in this order:

        1. The failure message
           "Some tests did not pass: 1 passed, 1 failed, 0 errored, 0 broken."
        2. The first parseable summary table, e.g.
           "Test Summary: | Pass Fail Total Time"
           "Add function Tests | 1 1 2 1.5s"

        Errored tests are counted as failures. ANSI colour sequences are
        stripped first, and only the cells to the right of the ``|`` separator
        are read, so digits in the testset name or in colour codes cannot be
        mistaken for counts.

        Args:
            stdout: Standard output from Julia execution
            stderr: Standard error from Julia execution

        Returns:
            Tuple of (tests_passed, tests_failed); (0, 0) when no test output
            is recognised.
        """
        output = self._ANSI_ESCAPE_RE.sub("", f"{stdout or ''}\n{stderr or ''}")

        # Method 1: explicit failure message
        failure = self._TESTS_FAILED_RE.search(output)
        if failure:
            passed, failed, errored = (int(count) for count in failure.groups())
            return passed, failed + errored  # Treat errors as failures

        # Method 2: "Test Summary:" header followed by the top-level result row
        lines = output.split("\n")
        for header_line, row_line in zip(lines, lines[1:]):
            if "Test Summary:" not in header_line:
                continue
            counts = self._parse_summary_row(header_line, row_line)
            if counts is not None:
                return counts
        return 0, 0

    @staticmethod
    def _parse_summary_row(header_line: str, row_line: str) -> "tuple[int, int] | None":
        """
        Read pass/fail counts from one summary header and the row below it.

        Column names from the header (Pass, Fail, Error, Broken, Total, Time -
        whichever are present) are paired positionally with the row's cells.

        Returns:
            (passed, failed + errored), or None when the two lines do not form
            a table row with at least one Pass/Fail/Error column.
        """
        _, header_sep, header_cells = header_line.rpartition("|")
        _, row_sep, row_cells = row_line.rpartition("|")
        if not (header_sep and row_sep):
            return None

        names = header_cells.split()
        cells = row_cells.split()
        if len(names) != len(cells):
            # A blank cell would shift every later value; refuse to guess.
            return None

        # Non-numeric cells (the "Time" column, e.g. "0.5s") are simply dropped.
        counts = {
            name: int(cell) for name, cell in zip(names, cells) if cell.isdecimal()
        }
        if not any(name in counts for name in ("Pass", "Fail", "Error")):
            return None
        return counts.get("Pass", 0), counts.get("Fail", 0) + counts.get("Error", 0)

    def _calculate_reward(
        self, code_compiles: bool, tests_passed: int, tests_failed: int
    ) -> float:
        """
        Normalized percentage-based reward for Julia GRPO.

        Returns:
            -1.0 when the code does not compile, 0.0 when no tests ran,
            1.5 when every test passed, otherwise the pass rate in (0, 1) -
            or 0.0 when every test failed.
        """
        if not code_compiles:
            return -1.0

        total_tests = tests_passed + tests_failed
        if total_tests == 0:
            return 0.0  # No signal when no tests run

        if tests_failed == 0:
            return 1.5  # Bonus for passing all tests
        return tests_passed / total_tests

    def _apply_transform(self, observation: JuliaObservation) -> JuliaObservation:
        """Apply safety and quality transforms to observation (no-op when no transform is set)."""
        if self.transform:
            observation = self.transform(observation)
        return observation

    @property
    def state(self) -> JuliaState:
        """Return current environment state."""
        return self._state
|