# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

"""
Julia Code Action Environment.

This module provides a server-side environment implementation for executing
Julia code actions using JuliaExecutor.
"""

import itertools
import logging
import re
import time
import uuid

# Support both in-repo and standalone imports
try:
    # In-repo imports (when running from OpenEnv repository)
    from openenv.core.env_server.interfaces import Action, Environment, Observation

    from ..models import JuliaAction, JuliaObservation, JuliaState
    from .julia_executor import JuliaExecutor
    from .julia_transforms import create_safe_julia_transform
except ImportError:
    # Standalone imports (when environment is standalone)
    from openenv.core.env_server.interfaces import Action, Environment, Observation

    from models import JuliaAction, JuliaObservation, JuliaState
    from server.julia_executor import JuliaExecutor
    from server.julia_transforms import create_safe_julia_transform

# Get logger for this module (inherits from julia_env logger)
logger = logging.getLogger("julia_env.codeact")

# Request counter for tracking (next() on itertools.count is atomic under
# CPython's GIL, so concurrent sessions still get unique ids)
_request_counter = itertools.count(1)

# Comments, strings and character literals are blanked out before scanning
# Julia code for loop keywords. A single alternation is used so whichever
# construct starts first wins: a '#' inside a string is not a comment, and a
# quote inside a comment does not open a string.
_JULIA_NOISE_PATTERN = re.compile(
    r'"""[\s\S]*?"""'  # triple-quoted strings
    r'|"(?:\\[\s\S]|[^"\\])*"'  # ordinary strings, honouring escaped quotes
    r"|'(?:\\[^\n][^'\n]{0,9}|[^'\\\n])'"  # char literals such as '"' or '\n'
    r"|#=[\s\S]*?=#"  # block comments
    r"|#[^\n]*"  # line comments
)

# `while true` loop headers. Matching is case-sensitive: Julia keywords are
# lowercase and `True`/`TRUE` are not boolean literals in Julia.
_WHILE_TRUE_PATTERN = re.compile(r"\bwhile\s+true\b")

# Statements that can leave a `while true` loop (or the whole program).
_LOOP_EXIT_PATTERN = re.compile(r"\b(?:break|return)\b|\b(?:error|throw|exit)\(")

# Tokens needed to find the `end` that closes a block: keywords that open an
# `end`-terminated block, `end` itself, and brackets (inside which `end` means
# "last index" and `for`/`if` belong to comprehensions/generators).
_JULIA_BLOCK_TOKEN_PATTERN = re.compile(
    r"\b(?:function|macro|module|baremodule|struct|quote|begin|let|if|for|"
    r"while|try|do|(?:abstract|primitive)\s+type)\b"
    r"|\bend\b"
    r"|[\[\](){}]"
)
_OPEN_BRACKETS = frozenset("([{")
_CLOSE_BRACKETS = frozenset(")]}")

# Lowercase substrings of stderr that mark a run as "did not compile" when it
# exited non-zero without producing any test results.
_COMPILE_ERROR_MARKERS = ("error", "syntax", "undefined", "loaderror")

# Message of the exception Julia's Test module raises when a test set fails,
# e.g. "Some tests did not pass: 1 passed, 1 failed, 0 errored, 0 broken."
_TESTSET_EXCEPTION_PATTERN = re.compile(
    r"Some tests did not pass:\s*(\d+)\s+passed,\s*(\d+)\s+failed,\s*(\d+)\s+errored"
)

# ANSI escape sequences, stripped in case the Julia output is colourised.
_ANSI_ESCAPE_PATTERN = re.compile(r"\x1b\[[0-9;]*[A-Za-z]")


def _find_block_end(code: str, start: int) -> int:
    """
    Return the offset of the `end` that closes a block whose body starts at `start`.

    The block whose body begins at `start` is assumed to be already open.
    Block-opening keywords and `end` are only counted outside brackets, so
    `x[end]` indexing and comprehensions do not disturb the nesting count.

    Args:
        code: Julia code with comments/strings already blanked out
        start: Offset of the first character of the block body

    Returns:
        Offset of the matching `end`, or `len(code)` when there is none (the
        body then extends to the end of the code).
    """
    block_depth = 1
    bracket_depth = 0
    for token in _JULIA_BLOCK_TOKEN_PATTERN.finditer(code, start):
        text = token.group()
        if text in _OPEN_BRACKETS:
            bracket_depth += 1
        elif text in _CLOSE_BRACKETS:
            bracket_depth -= 1
        elif bracket_depth != 0:
            # Keyword inside brackets: `end` indexing or a comprehension clause.
            continue
        elif text == "end":
            block_depth -= 1
            if block_depth == 0:
                return token.start()
        else:
            block_depth += 1
    return len(code)


def _detect_infinite_loop(code: str) -> tuple[bool, str]:
    """
    Detect potential infinite loops in Julia code.

    Scans for `while true` loops whose body contains none of `break`,
    `return`, `error(`, `throw(` or `exit(`. This is a heuristic: an exit
    statement anywhere in the body (even one that only leaves a nested loop)
    counts, so it errs on the side of not flagging code.

    Comments, strings and character literals are blanked out first so their
    contents cannot create or hide keywords. The loop body runs up to the
    `end` matching the `while` at the same nesting depth, so nested blocks
    (`if`/`for`/`begin`/... `end`) and `x[end]` indexing do not cut it short.

    Args:
        code: Julia code string to analyze

    Returns:
        Tuple of (has_infinite_loop: bool, reason: str); reason is "" when no
        infinite loop is found.
    """
    cleaned = _JULIA_NOISE_PATTERN.sub(" ", code)

    for match in _WHILE_TRUE_PATTERN.finditer(cleaned):
        body_start = match.end()
        loop_body = cleaned[body_start : _find_block_end(cleaned, body_start)]

        if _LOOP_EXIT_PATTERN.search(loop_body) is None:
            loop_preview = loop_body[:100].strip()
            return (
                True,
                f"Infinite loop detected: 'while true' without break/return/error/throw. Preview: {loop_preview}",
            )

    return False, ""


class JuliaCodeActEnv(Environment):
    """
    Julia Code Action Environment for executing code and tracking state.

    This environment executes Julia code submitted as JuliaAction during step,
    maintains the last exit code in its state, and returns results wrapped in
    JuliaObservation.

    Example:
        >>> env = JuliaCodeActEnv()
        >>> obs = env.reset()
        >>> action = JuliaAction(core_code='println("Hello, Julia!")', test_code='')
        >>> obs = env.step(action)
        >>> print(obs.stdout)  # "Hello, Julia!\\n"
        >>> print(obs.exit_code)  # 0
        >>> print(env.state.last_exit_code)  # 0
    """

    # Allow concurrent sessions - each session has its own isolated state
    SUPPORTS_CONCURRENT_SESSIONS = True

    def __init__(self, use_process_pool: bool = True):
        """
        Initialize the Julia Code Act Environment.

        Args:
            use_process_pool: Use persistent Julia process pool for better
                performance and to avoid Juliaup lock contention (default: True)
        """
        self._executor = JuliaExecutor(use_process_pool=use_process_pool)
        self._state = JuliaState()
        self.transform = create_safe_julia_transform()

    def reset(self, **kwargs) -> Observation:
        """
        Reset environment for a fresh Julia execution session.

        Starts a new episode id with step_count=0 and returns an empty
        JuliaObservation with exit_code=0.

        Note: Executor is reused to leverage process pool.
        """
        self._state = JuliaState(episode_id=str(uuid.uuid4()), step_count=0)
        self._state.last_exit_code = 0
        self._state.last_code_compiles = True

        # Don't recreate executor - reuse it to leverage process pool
        observation = JuliaObservation(
            stdout="",
            stderr="",
            exit_code=0,
            reward=0.0,
            metadata={"core_code": "", "test_code": ""},
            tests_passed=0,
            tests_failed=0,
            code_compiles=True,
        )
        return self._apply_transform(observation)

    def step(self, action: Action, **kwargs) -> Observation:
        """
        Execute Julia code and return the result as JuliaObservation.

        Optimized single-pass execution:
        - Runs core_code + test_code together
        - Infers compilation status from combined execution
        - Single execution instead of two separate runs

        Before running, core_code is screened by `_detect_infinite_loop`; if
        it is flagged, nothing is executed and a penalty observation
        (exit_code=1, reward=-1.0) is returned instead.

        Args:
            action: JuliaAction with core_code and optional test_code
            **kwargs: Optional parameters including:
                - timeout: Execution timeout in seconds (default handled by
                  the executor when omitted)

        Returns:
            JuliaObservation with stdout/stderr, exit code, parsed test
            counts, inferred compilation status and reward.

        Raises:
            ValueError: If `action` is not a JuliaAction.
            Exception: Any error raised by the executor is logged and re-raised.
        """
        request_id = next(_request_counter)

        if not isinstance(action, JuliaAction):
            logger.error(f"[REQ-{request_id}] Invalid action type: {type(action)}")
            raise ValueError(f"Expected JuliaAction, got {type(action)}")

        # Get timeout from kwargs (default handled by executor)
        timeout = kwargs.get("timeout")

        # Log request details
        code_preview = (
            action.core_code[:200] + "..."
            if len(action.core_code) > 200
            else action.core_code
        )
        logger.info(f"[REQ-{request_id}] === NEW EXECUTION REQUEST ===")
        logger.info(
            f"[REQ-{request_id}] Session: {self._state.episode_id}, Step: {self._state.step_count}"
        )
        logger.info(
            f"[REQ-{request_id}] Code length: {len(action.core_code)} chars, Test length: {len(action.test_code or '')} chars"
        )
        logger.debug(f"[REQ-{request_id}] Code preview: {code_preview}")
        # `is not None` so an explicit timeout=0 is not reported as "default".
        logger.info(
            f"[REQ-{request_id}] Timeout: {timeout}s"
            if timeout is not None
            else f"[REQ-{request_id}] Timeout: default"
        )

        # Pre-execution check: detect infinite loops to avoid timeout
        has_infinite_loop, loop_reason = _detect_infinite_loop(action.core_code)
        if has_infinite_loop:
            return self._reject_infinite_loop(request_id, action, loop_reason)

        # Single execution: Run core_code + test_code together (if test_code provided)
        if action.test_code:
            combined_code = action.core_code + "\n\n" + action.test_code
        else:
            combined_code = action.core_code

        # Monotonic clock: durations must not jump with wall-clock changes.
        start_time = time.monotonic()
        try:
            full_result = self._executor.run(combined_code, timeout=timeout)
        except Exception as e:
            execution_time = time.monotonic() - start_time
            logger.exception(
                f"[REQ-{request_id}] EXECUTION FAILED after {execution_time:.2f}s: {e}"
            )
            raise
        execution_time = time.monotonic() - start_time

        logger.info(
            f"[REQ-{request_id}] Execution completed in {execution_time:.2f}s, exit_code={full_result.exit_code}"
        )
        # Log stderr if present (often contains errors or test output)
        if full_result.stderr:
            stderr_preview = (
                full_result.stderr[:500] + "..."
                if len(full_result.stderr) > 500
                else full_result.stderr
            )
            logger.debug(f"[REQ-{request_id}] Stderr: {stderr_preview}")

        # Parse test results, infer compilation status, and score the run
        tests_passed, tests_failed = self._parse_test_results(
            full_result.stdout, full_result.stderr
        )
        code_compiles = self._infer_code_compiles(
            full_result.exit_code, full_result.stderr, tests_passed, tests_failed
        )
        reward = self._calculate_reward(code_compiles, tests_passed, tests_failed)

        # Log final results
        logger.info(
            f"[REQ-{request_id}] RESULT: compiles={code_compiles}, "
            f"tests_passed={tests_passed}, tests_failed={tests_failed}, reward={reward:.2f}"
        )

        # Update environment state
        self._state.step_count += 1
        self._state.last_exit_code = full_result.exit_code
        self._state.last_code_compiles = code_compiles
        self._state.total_tests_passed = tests_passed
        self._state.total_tests_failed = tests_failed

        observation = JuliaObservation(
            stdout=full_result.stdout,
            stderr=full_result.stderr,
            exit_code=full_result.exit_code,
            reward=reward,
            metadata={
                "core_code": action.core_code,
                "test_code": action.test_code or "",
            },
            tests_passed=tests_passed,
            tests_failed=tests_failed,
            code_compiles=code_compiles,
        )

        # Apply safety and quality transforms
        return self._apply_transform(observation)

    def _reject_infinite_loop(
        self, request_id: int, action: JuliaAction, loop_reason: str
    ) -> Observation:
        """
        Record a step whose code was flagged as an infinite loop.

        The code is NOT executed. State is updated as a failed step and a
        penalty observation (exit_code=1, reward=-1.0) is returned.

        Args:
            request_id: Id used to tag log lines for this request
            action: The rejected action
            loop_reason: Explanation produced by `_detect_infinite_loop`

        Returns:
            The transformed penalty observation.
        """
        logger.warning(f"[REQ-{request_id}] INFINITE LOOP DETECTED: {loop_reason}")

        # Update environment state
        self._state.step_count += 1
        self._state.last_exit_code = 1
        # Assumed: the code was never run, so compilation was not checked.
        self._state.last_code_compiles = True
        self._state.total_tests_passed = 0
        self._state.total_tests_failed = 0

        observation = JuliaObservation(
            stdout="",
            stderr=f"Infinite loop detected (pre-execution check): {loop_reason}",
            exit_code=1,
            reward=-1.0,  # Penalize infinite loops
            metadata={
                "core_code": action.core_code,
                "test_code": action.test_code or "",
                "infinite_loop_detected": True,
                "infinite_loop_reason": loop_reason,
            },
            tests_passed=0,
            tests_failed=0,
            code_compiles=True,  # Code would compile, but not run
        )
        logger.info(
            f"[REQ-{request_id}] RESULT: infinite_loop=True, "
            f"tests_passed=0, tests_failed=0, reward=-1.00"
        )
        return self._apply_transform(observation)

    @staticmethod
    def _infer_code_compiles(
        exit_code: int, stderr: str, tests_passed: int, tests_failed: int
    ) -> bool:
        """
        Infer compilation status from a single combined run.

        A clean exit, or any test result (pass or fail), means the code got
        far enough to run. Otherwise the run counts as "did not compile" only
        when the lowercased stderr contains one of `_COMPILE_ERROR_MARKERS`
        (note: this also matches runtime errors); with no such marker the
        code is assumed to have compiled.
        """
        if exit_code == 0 or tests_passed > 0 or tests_failed > 0:
            return True
        stderr_lower = stderr.lower()
        return not any(marker in stderr_lower for marker in _COMPILE_ERROR_MARKERS)

    def _parse_test_results(self, stdout: str, stderr: str) -> tuple[int, int]:
        """
        Parse Julia test output to count passed/failed tests.

        Two sources are consulted, in order:

        1. The error message raised when a test set fails:
           "Some tests did not pass: 1 passed, 1 failed, 0 errored, 0 broken."

        2. The first "Test Summary:" table whose first row is readable, e.g.
           "Test Summary: | Pass  Fail  Total  Time"
           "Add function  |    1     1      2  1.5s"
           Column names after the '|' in the header are paired by position
           with the values after the '|' in the row, so only the columns that
           are actually printed need to exist, and digits in the test set
           name or the time value are never mistaken for counts.

        Errored tests are counted as failures; broken tests are ignored.

        Args:
            stdout: Standard output from Julia execution
            stderr: Standard error from Julia execution

        Returns:
            Tuple of (tests_passed, tests_failed); (0, 0) if nothing is found.
        """
        # Combine stdout and stderr for analysis
        output = _ANSI_ESCAPE_PATTERN.sub("", stdout + "\n" + stderr)

        # Method 1: "Some tests did not pass" error message
        exception_match = _TESTSET_EXCEPTION_PATTERN.search(output)
        if exception_match:
            passed, failed, errored = map(int, exception_match.groups())
            return passed, failed + errored  # Treat errors as failures

        # Method 2: Test Summary table (header line followed by its first row)
        lines = output.split("\n")
        for header, row in zip(lines, lines[1:]):
            if "Test Summary:" not in header:
                continue
            columns = dict(
                zip(header.rpartition("|")[2].split(), row.rpartition("|")[2].split())
            )
            if not any(name in columns for name in ("Pass", "Fail", "Error")):
                continue  # no usable columns (e.g. only Broken, or no row)
            try:
                passed = int(columns.get("Pass", 0))
                # Treat errors as failures
                failed = int(columns.get("Fail", 0)) + int(columns.get("Error", 0))
            except ValueError:
                continue  # following line is not a numeric summary row
            return passed, failed

        return 0, 0

    def _calculate_reward(
        self, code_compiles: bool, tests_passed: int, tests_failed: int
    ) -> float:
        """
        Normalized percentage-based reward for Julia GRPO.

        Returns rewards in [-1, 1.5] range for comparability across problems:
        -1.0 if the code does not compile, 0.0 if no tests ran, 1.5 if every
        test passed, otherwise the pass rate tests_passed / total in [0, 1).
        """
        if not code_compiles:
            return -1.0

        total_tests = tests_passed + tests_failed
        if total_tests == 0:
            return 0.0  # No signal when no tests run

        if tests_passed == total_tests:
            return 1.5  # Bonus for passing all tests
        return tests_passed / total_tests

    def _apply_transform(self, observation: JuliaObservation) -> JuliaObservation:
        """Apply safety and quality transforms to observation (if configured)."""
        if self.transform:
            observation = self.transform(observation)
        return observation

    @property
    def state(self) -> JuliaState:
        """Return current environment state."""
        return self._state