Spaces:
Sleeping
Sleeping
File size: 2,870 Bytes
be32845 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 |
"""
envs/julia_env/julia_transforms.py
--------------------------------
Safety and quality transforms for Julia code.
"""
import re
from core.env_server.base_transforms import CompositeTransform
from core.env_server.interfaces import Transform
from ..models import JuliaObservation
# -------------------------
# Safety Transform
# -------------------------
class JuliaSafetyTransform(Transform):
    """Detects dangerous Julia operations and penalizes them with a negative reward.

    Every entry of ``dangerous_patterns`` is a regular expression that is
    searched for in the code stored under ``metadata["last_code"]``. A match
    only counts when it begins at an identifier boundary, so ``rm\\(`` flags
    ``rm(path)`` and ``Base.rm(path)`` but not ``norm(x)``.
    """

    def __init__(self, penalty: float = -3.0):
        """Store the penalty and the list of patterns treated as dangerous.

        Args:
            penalty: Value added to the observation's reward when a pattern
                matches (negative by default, i.e. a deduction).
        """
        self.penalty = penalty
        self.dangerous_patterns = [
            r"run\(",
            r"read\(",
            r"write\(",
            r"unsafe_",
            r"ccall\(",
            r"Base\.exit",
            r"Base\.kill",
            r"rm\(",  # file deletion
            r"download\(",  # downloading
        ]

    def __call__(self, observation):
        """Penalize the observation if its last executed code looks dangerous.

        Args:
            observation: Object to inspect. Anything that is not a
                ``JuliaObservation`` is returned untouched.

        Returns:
            The same observation object. On the first matching pattern the
            reward is increased by ``penalty`` and the pattern is recorded
            under ``metadata["safety_violation"]``. When nothing matches, a
            missing reward is normalised to ``0.0``.
        """
        # Only act on JuliaObservation objects
        if not isinstance(observation, JuliaObservation):
            return observation

        # Extract last executed code from metadata. The `or ""` also covers a
        # "last_code" key that is present but set to None.
        metadata = observation.metadata or {}
        code = metadata.get("last_code") or ""

        for pattern in self.dangerous_patterns:
            # The look-behind rejects hits that start in the middle of an
            # identifier (e.g. "rm(" inside "norm("). It is applied here rather
            # than baked into dangerous_patterns so that the attribute and the
            # reported violation string keep their original values.
            if re.search(r"(?<!\w)(?:" + pattern + r")", code):
                # Apply penalty and record violation
                observation.reward = (observation.reward or 0.0) + self.penalty
                observation.metadata = observation.metadata or {}
                observation.metadata["safety_violation"] = pattern
                return observation

        # Safe code gets neutral reward
        observation.reward = observation.reward or 0.0
        return observation
# -------------------------
# Quality Transform
# -------------------------
class JuliaQualityTransform(Transform):
    """Adjusts the reward of a Julia observation according to code length.

    Code whose stripped length is at most ``max_length_threshold`` characters
    earns ``concise_bonus``; longer code loses a fixed 0.1.
    """

    def __init__(self, concise_bonus=1, max_length_threshold=120):
        """Remember the bonus for short code and the length limit defining "short"."""
        self.max_length_threshold = max_length_threshold
        self.concise_bonus = concise_bonus

    def __call__(self, observation):
        """Apply the length-based adjustment and return the same observation.

        Objects that are not ``JuliaObservation`` instances pass through
        unchanged. A missing reward is treated as 0.0 before adjusting.
        """
        if isinstance(observation, JuliaObservation):
            meta = observation.metadata
            source = meta.get("last_code", "") if meta else ""
            base = observation.reward or 0.0

            # Surrounding whitespace does not count toward the length limit.
            is_concise = len(source.strip()) <= self.max_length_threshold
            observation.reward = (
                (base + self.concise_bonus) if is_concise else (base - 0.1)
            )
        return observation
# -------------------------
# Composite Transform
# -------------------------
def create_safe_julia_transform():
    """Build a CompositeTransform from a default safety and a default quality transform.

    The stages are passed in the order safety, then quality; presumably
    CompositeTransform applies them in that order (confirm against its
    implementation).
    """
    stages = [JuliaSafetyTransform(), JuliaQualityTransform()]
    return CompositeTransform(stages)
|