File size: 4,843 Bytes
f44f429 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 | """Reinforcement Learning Environment Core.
Defines the environment logic, maintaining the current trajectory
state and mediating between incoming requests and the headless grader.
"""
import random
from typing import Optional, Dict, Any
from server.tasks import TASKS
from server.grader import grade_action
from server.models import StepResult, StateResponse, Action, Observation
# Message placed in StepResult.info["error"] when step() is called on an
# episode that has already finished.
ERROR_EPISODE_COMPLETED = (
    "Episode already completed. "
    "Call /reset to start a new episode."
)
class CodeSecurityEnv:
    """Simulates the stateful progression of a software security assessment.

    Episode flow:

    * ``reset()`` loads a task; the code snippet starts hidden.
    * An optional ``request_file`` action reveals the snippet and keeps the
      episode open. Only the first such request in an episode is rewarded.
    * Any other action is scored by ``grade_action`` and ends the episode.
    """

    # Reward granted for the first request_file action of an episode.
    FILE_REQUEST_REWARD: float = 0.20

    def __init__(self) -> None:
        """Initialize a fresh environment instance with no active task."""
        self.current_task: Optional[Dict[str, Any]] = None
        self.step_count: int = 0
        self.done: bool = False
        self.total_reward: float = 0.0
        # Candidate ids for random task selection in reset().
        self._task_ids = list(TASKS.keys())

    def reset(self, task_id: Optional[str] = None, seed: Optional[int] = None) -> Observation:
        """Start a new episode on a specific or randomly chosen task.

        Args:
            task_id: Key into ``TASKS`` selecting the task to load. If it is
                None, empty, or not a known key, a task is chosen at random.
            seed: If given, passed to ``random.seed`` before task selection.
                Note that this reseeds the module-global ``random`` generator.

        Returns:
            The initial Observation for the new episode (snippet hidden).
        """
        if seed is not None:
            random.seed(seed)
        if task_id and task_id in TASKS:
            self.current_task = TASKS[task_id]
        else:
            # NOTE(review): an unknown task_id silently falls back to a random
            # task; consider rejecting unknown ids explicitly.
            self.current_task = TASKS[random.choice(self._task_ids)]
        self.step_count = 0
        self.done = False
        self.total_reward = 0.0
        return self._make_observation()

    def step(self, action: Action) -> StepResult:
        """Advance the environment by one agent action.

        If no task is active, ``reset()`` is called first (random task).

        Args:
            action: The agent's Action. If its ``request_file`` attribute is
                truthy, the snippet is revealed and the episode continues.
                Otherwise the action is serialized with ``model_dump()``,
                scored by ``grade_action``, and the episode ends.

        Returns:
            A StepResult with the new observation, the step reward, the done
            flag, and an info dict. After grading, ``info`` contains
            ``reward_breakdown``, ``task_name`` and ``step_count``. On an
            already-finished episode it contains ``error`` and the reward is 0.
        """
        if self.current_task is None:
            self.reset()

        if self.done:
            return StepResult(
                observation=self._make_observation(),
                reward=0.0,
                done=True,
                info={"error": ERROR_EPISODE_COMPLETED},
            )

        if getattr(action, "request_file", False):
            return self._request_file()

        try:
            reward, breakdown = grade_action(action.model_dump(), self.current_task)
        except Exception as e:
            # Boundary: a grader failure scores 0 and is reported in info
            # rather than crashing the request handler.
            reward, breakdown = 0.0, {"error": f"Evaluation error: {e}"}

        self.step_count += 1
        self.total_reward += reward
        # A graded submission always terminates the episode.
        self.done = True
        return StepResult(
            observation=self._make_observation(),
            reward=reward,
            done=self.done,
            info={
                "reward_breakdown": breakdown,
                "task_name": self._task_name(),
                "step_count": self.step_count,
            },
        )

    def _request_file(self) -> StepResult:
        """Handle an intermediate request_file action; the episode stays open."""
        # Only the first reveal (no prior steps this episode) is rewarded, so an
        # agent cannot accumulate unbounded reward by repeating request_file.
        first_request = self.step_count == 0
        reward = self.FILE_REQUEST_REWARD if first_request else 0.0
        self.step_count += 1
        self.total_reward += reward
        info: Dict[str, Any] = {
            "task_name": self._task_name(),
            "step_count": self.step_count,
        }
        if not first_request:
            info["warning"] = "File already revealed; no additional reward."
        return StepResult(
            observation=self._make_observation(),
            reward=reward,
            done=False,
            info=info,
        )

    def _task_name(self) -> str:
        """Return the active task's ``name`` entry, or a placeholder."""
        if not self.current_task:
            return "Unknown Task"
        return self.current_task.get("name", "Unknown Task")

    def state(self) -> StateResponse:
        """Return a summary of the current session: task id, steps, done, total reward."""
        current_id = self.current_task["id"] if self.current_task else ""
        return StateResponse(
            task_id=current_id,
            step=self.step_count,
            done=self.done,
            total_reward=self.total_reward,
        )

    def _make_observation(self) -> Observation:
        """Build the Observation for the active task.

        The code snippet is replaced with a placeholder until at least one
        step has been taken in the current episode.

        Raises:
            KeyError: If no task is active.
        """
        t = self.current_task
        if not t:
            raise KeyError("Attempted observation render without an initialized active task")
        if self.step_count > 0:
            snippet = t["code_snippet"]
        else:
            snippet = "<FILE CONTENTS HIDDEN - Submit {\"request_file\": true} to view>"
        return Observation(
            task_id=t["id"],
            language=t["language"],
            difficulty=t["difficulty"],
            code_snippet=snippet,
            context=t["context"],
            pr_title=t["pr_title"],
            file_path=t["file_path"],
        )
|