File size: 18,680 Bytes
11952db
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

"""
Julia Code Action Environment.

This module provides a server-side environment implementation for executing
Julia code actions using JuliaExecutor.
"""

import itertools
import logging
import re
import time
import uuid

# Support both in-repo and standalone imports
try:
    # In-repo imports (when running from OpenEnv repository)
    from openenv.core.env_server.interfaces import Action, Environment, Observation
    from ..models import JuliaAction, JuliaObservation, JuliaState
    from .julia_executor import JuliaExecutor
    from .julia_transforms import create_safe_julia_transform
except ImportError:
    # Standalone imports (when environment is standalone)
    from openenv.core.env_server.interfaces import Action, Environment, Observation
    from models import JuliaAction, JuliaObservation, JuliaState
    from server.julia_executor import JuliaExecutor
    from server.julia_transforms import create_safe_julia_transform

# Get logger for this module (inherits from julia_env logger)
logger = logging.getLogger("julia_env.codeact")

# Thread-safe request counter for tracking
_request_counter = itertools.count(1)


def _detect_infinite_loop(code: str) -> tuple[bool, str]:
    """
    Detect potential infinite loops in Julia code.

    This function scans for `while true` loops without break/return/error statements.

    Args:
        code: Julia code string to analyze

    Returns:
        Tuple of (has_infinite_loop: bool, reason: str)
    """
    # Remove comments and strings to avoid false positives
    # Remove single-line comments
    code_without_comments = re.sub(r"#.*", "", code)
    # Remove multi-line strings (triple quotes)
    code_without_comments = re.sub(
        r'""".*?"""', "", code_without_comments, flags=re.DOTALL
    )
    # Remove single-line strings
    code_without_comments = re.sub(r'"[^"]*"', "", code_without_comments)

    # Find all while true blocks
    while_true_pattern = r"\bwhile\s+true\b"
    while_true_matches = list(
        re.finditer(while_true_pattern, code_without_comments, re.IGNORECASE)
    )

    if not while_true_matches:
        return False, ""

    # For each while true, check if there's a break/return/error in the same block
    for match in while_true_matches:
        start_pos = match.end()

        # Find the end of this while block by counting 'while'/'end' pairs
        # Simplified heuristic: look for break/return/error before the corresponding 'end'
        remaining_code = code_without_comments[start_pos:]

        # Extract potential loop body (up to next 'end' keyword)
        # This is a simplified check - doesn't perfectly handle nested blocks
        end_match = re.search(r"\bend\b", remaining_code)
        if end_match:
            loop_body = remaining_code[: end_match.start()]
        else:
            loop_body = remaining_code

        # Check for loop exit mechanisms in this block
        has_break = re.search(r"\bbreak\b", loop_body) is not None
        has_return = re.search(r"\breturn\b", loop_body) is not None
        has_error = re.search(r"\berror\(", loop_body) is not None
        has_throw = re.search(r"\bthrow\(", loop_body) is not None
        has_exit = re.search(r"\bexit\(", loop_body) is not None

        if not (has_break or has_return or has_error or has_throw or has_exit):
            loop_preview = loop_body[:100].strip()
            return (
                True,
                f"Infinite loop detected: 'while true' without break/return/error/throw. Preview: {loop_preview}",
            )

    return False, ""


class JuliaCodeActEnv(Environment):
    """
    Julia Code Action Environment for executing code and tracking state.

    This environment executes Julia code submitted as JuliaAction during step,
    maintains the last exit code in its state, and returns results wrapped
    in JuliaObservation.

    Example:
        >>> env = JuliaCodeActEnv()
        >>> obs = env.reset()
        >>> action = JuliaAction(core_code='println("Hello, Julia!")', test_code='')
        >>> obs = env.step(action)
        >>> print(obs.stdout)  # "Hello, Julia!\\n"
        >>> print(obs.exit_code)  # 0
        >>> print(env.state.last_exit_code)  # 0
    """

    # Allow concurrent sessions - each session has its own isolated state
    SUPPORTS_CONCURRENT_SESSIONS = True

    def __init__(self, use_process_pool: bool = True):
        """
        Initialize the Julia Code Act Environment.

        Args:
            use_process_pool: Use persistent Julia process pool for better performance
                            and to avoid Juliaup lock contention (default: True)
        """
        self._executor = JuliaExecutor(use_process_pool=use_process_pool)
        self._state = JuliaState()
        self.transform = create_safe_julia_transform()

    def reset(self, **kwargs) -> Observation:
        """
        Reset environment for a fresh Julia execution session.
        Returns an empty JuliaObservation with exit_code=0.

        Note: Executor is reused to leverage process pool.
        """
        self._state = JuliaState(episode_id=str(uuid.uuid4()), step_count=0)
        self._state.last_exit_code = 0
        self._state.last_code_compiles = True
        # Don't recreate executor - reuse it to leverage process pool

        observation = JuliaObservation(
            stdout="",
            stderr="",
            exit_code=0,
            reward=0.0,
            metadata={"core_code": "", "test_code": ""},
            tests_passed=0,
            tests_failed=0,
            code_compiles=True,
        )

        observation = self._apply_transform(observation)
        return observation

    def step(self, action: Action, **kwargs) -> Observation:
        """
        Execute Julia code and return the result as JuliaObservation.

        Optimized single-pass execution:
        - Runs core_code + test_code together
        - Infers compilation status from combined execution
        - 2x faster than double execution

        Args:
            action: JuliaAction with core_code and optional test_code
            **kwargs: Optional parameters including:
                - timeout: Execution timeout in seconds (default: 120)

        Raises:
            ValueError: If action is not a JuliaAction.
        """
        request_id = next(_request_counter)

        if not isinstance(action, JuliaAction):
            logger.error(f"[REQ-{request_id}] Invalid action type: {type(action)}")
            raise ValueError(f"Expected JuliaAction, got {type(action)}")

        # Get timeout from kwargs (default handled by executor)
        timeout = kwargs.get("timeout")

        # Log request details
        code_preview = (
            action.core_code[:200] + "..."
            if len(action.core_code) > 200
            else action.core_code
        )
        logger.info(f"[REQ-{request_id}] === NEW EXECUTION REQUEST ===")
        logger.info(
            f"[REQ-{request_id}] Session: {self._state.episode_id}, Step: {self._state.step_count}"
        )
        logger.info(
            f"[REQ-{request_id}] Code length: {len(action.core_code)} chars, Test length: {len(action.test_code or '')} chars"
        )
        logger.debug(f"[REQ-{request_id}] Code preview: {code_preview}")
        logger.info(
            f"[REQ-{request_id}] Timeout: {timeout}s"
            if timeout
            else f"[REQ-{request_id}] Timeout: default"
        )

        start_time = time.time()

        # Single execution: Run core_code + test_code together (if test_code provided)
        if action.test_code:
            combined_code = action.core_code + "\n\n" + action.test_code
        else:
            combined_code = action.core_code

        # Pre-execution check: detect infinite loops to avoid timeout
        has_infinite_loop, loop_reason = _detect_infinite_loop(action.core_code)
        if has_infinite_loop:
            logger.warning(f"[REQ-{request_id}] INFINITE LOOP DETECTED: {loop_reason}")

            # Update environment state
            self._state.step_count += 1
            self._state.last_exit_code = 1
            self._state.last_code_compiles = True  # Code compiles but has infinite loop
            self._state.total_tests_passed = 0
            self._state.total_tests_failed = 0

            # Build observation with penalty
            observation = JuliaObservation(
                stdout="",
                stderr=f"Infinite loop detected (pre-execution check): {loop_reason}",
                exit_code=1,
                reward=-1.0,  # Penalize infinite loops
                metadata={
                    "core_code": action.core_code,
                    "test_code": action.test_code or "",
                    "infinite_loop_detected": True,
                    "infinite_loop_reason": loop_reason,
                },
                tests_passed=0,
                tests_failed=0,
                code_compiles=True,  # Code would compile, but not run
            )

            logger.info(
                f"[REQ-{request_id}] RESULT: infinite_loop=True, "
                f"tests_passed=0, tests_failed=0, reward=-1.00"
            )

            observation = self._apply_transform(observation)
            return observation

        try:
            full_result = self._executor.run(combined_code, timeout=timeout)
            execution_time = time.time() - start_time

            logger.info(
                f"[REQ-{request_id}] Execution completed in {execution_time:.2f}s, exit_code={full_result.exit_code}"
            )

            # Log stderr if present (often contains errors or test output)
            if full_result.stderr:
                stderr_preview = (
                    full_result.stderr[:500] + "..."
                    if len(full_result.stderr) > 500
                    else full_result.stderr
                )
                logger.debug(f"[REQ-{request_id}] Stderr: {stderr_preview}")

        except Exception as e:
            execution_time = time.time() - start_time
            logger.error(
                f"[REQ-{request_id}] EXECUTION FAILED after {execution_time:.2f}s: {e}"
            )
            raise

        # Parse test results from execution output
        tests_passed, tests_failed = self._parse_test_results(
            full_result.stdout, full_result.stderr
        )

        # Infer compilation status from execution
        # If tests ran, code compiled successfully
        # If exit_code != 0 and no tests ran, code didn't compile
        code_compiles = (
            full_result.exit_code == 0  # Clean execution
            or tests_passed > 0  # Some tests passed (code must have compiled)
            or tests_failed > 0  # Some tests failed (code compiled but tests failed)
        )

        # If no tests detected and non-zero exit, check for compilation errors.
        # BUGFIX: stderr is lowercased here, so every indicator must be
        # lowercase too - the previous mixed-case "loadError" could never
        # match. Code compiled iff no error indicator appears in stderr.
        if not code_compiles and tests_passed == 0 and tests_failed == 0:
            stderr_lower = full_result.stderr.lower()
            code_compiles = not any(
                err in stderr_lower
                for err in ("error", "syntax", "undefined", "loaderror")
            )

        # Calculate reward based on compilation and test results
        reward = self._calculate_reward(code_compiles, tests_passed, tests_failed)

        # Log final results
        logger.info(
            f"[REQ-{request_id}] RESULT: compiles={code_compiles}, "
            f"tests_passed={tests_passed}, tests_failed={tests_failed}, reward={reward:.2f}"
        )

        # Update environment state
        self._state.step_count += 1
        self._state.last_exit_code = full_result.exit_code
        self._state.last_code_compiles = code_compiles
        # NOTE(review): despite the "total_" names, these fields are
        # overwritten with the latest step's counts, not accumulated across
        # steps - confirm this is the intended semantics.
        self._state.total_tests_passed = tests_passed
        self._state.total_tests_failed = tests_failed

        # Build observation
        observation = JuliaObservation(
            stdout=full_result.stdout,
            stderr=full_result.stderr,
            exit_code=full_result.exit_code,
            reward=reward,
            metadata={
                "core_code": action.core_code,
                "test_code": action.test_code or "",
            },
            tests_passed=tests_passed,
            tests_failed=tests_failed,
            code_compiles=code_compiles,
        )

        # Apply safety and quality transforms
        observation = self._apply_transform(observation)

        return observation

    def _parse_test_results(self, stdout: str, stderr: str) -> tuple[int, int]:
        """
        Parse Julia test output to count passed/failed tests.

        Julia's Test module outputs results like:
        "Test Summary:      | Pass  Fail  Total  Time"
        "Add function Tests |    1     1      2  1.5s"

        Also checks error messages:
        "Some tests did not pass: 1 passed, 1 failed, 0 errored, 0 broken."

        Args:
            stdout: Standard output from Julia execution
            stderr: Standard error from Julia execution

        Returns:
            Tuple of (tests_passed, tests_failed)
        """
        # Combine stdout and stderr for analysis
        passed = 0
        failed = 0
        output = stdout + "\n" + stderr

        # Method 1: Look for "Some tests did not pass" error message
        # Pattern: "Some tests did not pass: X passed, Y failed, Z errored, W broken."
        error_pattern = r"Some tests did not pass:\s*(\d+)\s+passed,\s*(\d+)\s+failed,\s*(\d+)\s+errored"
        match = re.search(error_pattern, output)

        if match:
            passed = int(match.group(1))
            failed = int(match.group(2))
            errored = int(match.group(3))
            return passed, failed + errored  # Treat errors as failures

        # Method 2: Look for Test Summary table
        # Multiple possible formats:
        # All pass:     "Test Summary: | Pass  Total  Time"
        #               "My Tests     |    3      3  0.5s"
        # Some fail:    "Test Summary: | Pass  Fail  Total  Time"
        #               "My Tests     |    2     1      3  0.5s"
        # All error:    "Test Summary: | Error  Total  Time"
        #               "My Tests     |     3      3  0.9s"
        # Mixed:        "Test Summary: | Pass  Fail  Error  Total  Time"
        #               "My Tests     |    1     1      1      3  0.5s"
        summary_lines = output.split("\n")
        for i, line in enumerate(summary_lines):
            if "Test Summary:" in line and i + 1 < len(summary_lines):
                header_line = line
                next_line = summary_lines[i + 1]

                # Determine which columns are present
                has_pass = "Pass" in header_line
                has_fail = "Fail" in header_line
                has_error = "Error" in header_line

                # Extract all numbers from the line
                all_numbers = re.findall(r"\d+", next_line)
                if not all_numbers:
                    continue

                # Last number is always Total, second to last is Time (skip it)
                # Extract based on which columns exist
                if has_pass and has_fail and has_error:
                    # Pass  Fail  Error  Total  Time
                    if len(all_numbers) >= 5:
                        passed = int(all_numbers[0])
                        failed = int(all_numbers[1]) + int(
                            all_numbers[2]
                        )  # Fail + Error
                        return passed, failed
                elif has_pass and has_fail:
                    # Pass  Fail  Total  Time
                    if len(all_numbers) >= 4:
                        passed = int(all_numbers[0])
                        failed = int(all_numbers[1])
                        return passed, failed
                elif has_pass and has_error:
                    # Pass  Error  Total  Time
                    if len(all_numbers) >= 4:
                        passed = int(all_numbers[0])
                        failed = int(all_numbers[1])  # Treat errors as failures
                        return passed, failed
                elif has_fail and has_error:
                    # Fail  Error  Total  Time (no passes)
                    if len(all_numbers) >= 4:
                        passed = 0
                        failed = int(all_numbers[0]) + int(all_numbers[1])
                        return passed, failed
                elif has_pass:
                    # Pass  Total  Time (no failures/errors)
                    if len(all_numbers) >= 3:
                        passed = int(all_numbers[0])
                        failed = 0
                        return passed, failed
                elif has_error:
                    # Error  Total  Time (all errors, no passes)
                    if len(all_numbers) >= 3:
                        passed = 0
                        failed = int(all_numbers[0])  # Treat all errors as failures
                        return passed, failed
                elif has_fail:
                    # Fail  Total  Time (all failures, no passes)
                    if len(all_numbers) >= 3:
                        passed = 0
                        failed = int(all_numbers[0])
                        return passed, failed

        return passed, failed

    def _calculate_reward(
        self, code_compiles: bool, tests_passed: int, tests_failed: int
    ) -> float:
        """
        Normalized percentage-based reward for Julia GRPO.
        Returns rewards in [-1, 1.5] range for comparability across problems.
        """
        if not code_compiles:
            return -1.0

        total_tests = tests_passed + tests_failed
        if total_tests == 0:
            return 0.0  # No signal when no tests run

        pass_rate = tests_passed / total_tests

        # Scaled 0-1 with bonus for perfection
        if pass_rate == 1.0:
            return 1.5  # Bonus for passing all tests
        return pass_rate

    def _apply_transform(self, observation: JuliaObservation) -> JuliaObservation:
        """Apply safety and quality transforms to observation."""
        if self.transform:
            observation = self.transform(observation)
        return observation

    @property
    def state(self) -> JuliaState:
        """Return current environment state."""
        return self._state