File size: 2,437 Bytes
f58914c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4caadb6
 
f58914c
 
 
 
 
 
 
 
 
4caadb6
 
 
 
 
f58914c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

"""
envs/julia_env/julia_transforms.py
--------------------------------
Safety and quality transforms for Julia code.
"""

import re

# Support both in-repo and standalone imports
try:
    # In-repo imports
    from openenv.core.env_server.interfaces import Transform
    from ..models import JuliaObservation
except ImportError:
    # Standalone imports
    from openenv.core.env_server.interfaces import Transform
    from models import JuliaObservation


# -------------------------
# Safety Transform
# -------------------------
class JuliaSafetyTransform(Transform):
    """Detect dangerous Julia operations and penalize them with a negative reward.

    The transform scans the code recorded in ``observation.metadata``
    (``core_code`` followed by ``test_code``) for each regex in
    ``self.dangerous_patterns``. On the first match it adds ``penalty`` to the
    observation's reward and records the matching pattern string under
    ``metadata["safety_violation"]``. If nothing matches, a missing/falsy
    reward is normalised to ``0.0``. Observations that are not
    ``JuliaObservation`` instances are returned unchanged.

    This is a best-effort regex denylist, not a sandbox: code that reaches the
    same functionality through indirection is not detected.

    Args:
        penalty: Amount added to the reward when a dangerous pattern is found.
            Defaults to ``-3.0``.
    """

    def __init__(self, penalty: float = -3.0):
        self.penalty = penalty
        # Every pattern is anchored with ``\b`` so identifiers that merely
        # *end* in a flagged name are not reported: e.g. ``norm(`` /
        # ``transform(`` no longer trip ``rm(``, ``thread(`` no longer trips
        # ``read(``, and ``overwrite(`` no longer trips ``write(``.
        # Qualified calls such as ``Base.rm(`` still match, because ``.`` is
        # a non-word character and therefore forms a word boundary.
        # NOTE(review): the ``@ccall`` macro form is not covered, because no
        # parenthesis follows ``ccall`` there.
        self.dangerous_patterns = [
            r"\brun\(",
            r"\bread\(",
            r"\bwrite\(",
            r"\bunsafe_",
            r"\bccall\(",
            r"\bBase\.exit",
            r"\bBase\.kill",
            r"\brm\(",  # file deletion
            r"\bdownload\(",  # downloading
        ]

    @staticmethod
    def _extract_code(metadata) -> str:
        """Join ``core_code`` and ``test_code`` from ``metadata`` into one string.

        Missing keys and keys explicitly set to ``None`` are treated as empty
        strings. Concatenating ``None`` would raise ``TypeError``.
        """
        if not metadata:
            return ""
        core_code = metadata.get("core_code") or ""
        test_code = metadata.get("test_code") or ""
        return core_code + "\n" + test_code

    def __call__(self, observation):
        """Penalize ``observation`` if its executed code uses a dangerous call.

        Args:
            observation: Any observation. Only ``JuliaObservation`` instances
                are inspected; anything else is returned untouched.

        Returns:
            The same observation object, mutated in place: ``reward`` is
            adjusted or normalised, and ``metadata["safety_violation"]`` is
            set on a match.
        """
        # Only act on JuliaObservation objects.
        if not isinstance(observation, JuliaObservation):
            return observation

        code = self._extract_code(observation.metadata)

        for pattern in self.dangerous_patterns:
            if re.search(pattern, code):
                # Apply the penalty and record which pattern triggered it.
                observation.reward = (observation.reward or 0.0) + self.penalty
                observation.metadata = observation.metadata or {}
                observation.metadata["safety_violation"] = pattern
                return observation

        # Safe code gets a neutral reward. Existing non-zero rewards are kept.
        observation.reward = observation.reward or 0.0
        return observation


# -------------------------
# Factory
# -------------------------
def create_safe_julia_transform():
    """Build the default safety transform for Julia code.

    Returns:
        A fresh ``JuliaSafetyTransform`` constructed with its default penalty.
    """
    transform = JuliaSafetyTransform()
    return transform