coding_env / server /transforms.py
Zach Wentz
Test deployment attempt 2
f920a21
raw
history blame
2.99 kB
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
"""Transforms specific to coding environments."""
import ast
import re
from core.env_server.base_transforms import CompositeTransform
from core.env_server.interfaces import Transform
from core.env_server.types import Observation
from ..models import CodeObservation
class CodeSafetyTransform(Transform):
"""Evaluates code safety and assigns penalties for dangerous patterns."""
def __init__(self, penalty: float = -1.0):
self.penalty = penalty
self.dangerous_patterns = [
r"import\s+os",
r"import\s+subprocess",
r"eval\(",
r"exec\(",
r"__import__",
r"open\(",
]
def __call__(self, observation: Observation) -> Observation:
if not isinstance(observation, CodeObservation):
return observation
if "last_code" in observation.metadata:
code = observation.metadata["last_code"]
for pattern in self.dangerous_patterns:
if re.search(pattern, code):
observation.reward = self.penalty
observation.metadata["safety_violation"] = pattern
break
else:
if observation.reward is None:
observation.reward = 0.0
return observation
class CodeQualityTransform(Transform):
"""Evaluates and rewards code quality metrics."""
def __init__(
self,
concise_bonus: float = 0.1,
max_length_threshold: int = 100,
syntax_penalty: float = -0.2,
):
self.concise_bonus = concise_bonus
self.max_length_threshold = max_length_threshold
self.syntax_penalty = syntax_penalty
def __call__(self, observation: Observation) -> Observation:
if not isinstance(observation, CodeObservation):
return observation
quality_score = 0.0
if "last_code" in observation.metadata:
code = observation.metadata["last_code"]
# Reward concise code
if len(code.strip()) <= self.max_length_threshold:
quality_score += self.concise_bonus
# Check syntax (redundant but useful for quality assessment)
try:
ast.parse(code)
except SyntaxError:
quality_score += self.syntax_penalty
# Add to existing reward
if observation.reward is None:
observation.reward = quality_score
else:
observation.reward += quality_score
return observation
def create_safe_coding_transform() -> CompositeTransform:
"""Create a transform focused on safe coding practices and quality."""
return CompositeTransform([CodeSafetyTransform(), CodeQualityTransform()])