PRobe / environment /probe_environment.py
Thakur, Mahipal
Added meaningful comments
4ec7361
"""
PRobe Environment — async-native implementation.
Episode lifecycle:
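1. reset() → ProbeObservation (starts a new episode)
2. step(a) → ProbeObservation (execute one action; use async_step(a) for the
full (ProbeObservation, RewardType, done, info) tuple)
3. state → State(episode_id, step_count); async_state() → dict (full internal snapshot)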
Tasks cycle automatically: 0 (ultra-easy) → 1 (easy) → … → 6 (causal chain) → 0 …
Dynamic world features (v3)
───────────────────────────
• Code mutation — each episode applies surface-level variable renames,
a line shift, and a constant nudge so the agent must
read the code rather than memorise tokens.
• GET_CONTEXT — the agent can spend a step probing a specific line to
receive the surrounding ±5 lines of context.
• Causal unlocks — finding certain issues appends a new context hint to
the observation, modelling real-world situations where
one discovery leads to deeper investigation.
Thread / task safety: each Environment instance owns its own state.
For concurrent GRPO rollouts spin up one instance per worker.
"""
from __future__ import annotations
import asyncio
import concurrent.futures
import dataclasses
import logging
from typing import Any
from uuid import uuid4
from openenv.core.env_server.interfaces import Environment
from openenv.core.env_server.types import State
try:
from ..agent.models import ActionType, ProbeAction, ProbeObservation, RewardType
from ._import_compat import (
CodeReviewGrader,
EpisodeMemory,
LINE_TOLERANCE,
TASKS,
mutate_task,
run_scanner,
)
except ImportError:
from agent.models import ActionType, ProbeAction, ProbeObservation, RewardType # type: ignore[no-redef]
from environment._import_compat import ( # type: ignore[no-redef]
CodeReviewGrader,
EpisodeMemory,
LINE_TOLERANCE,
TASKS,
mutate_task,
run_scanner,
)
# Module-level logger named after this module; used for debug traces below.
log = logging.getLogger(__name__)
@dataclasses.dataclass
class EpisodeState:
    """
    Mutable bookkeeping for one review episode.

    A fresh instance is created for every episode, so nothing here carries
    over between episodes. Keeping the fields on a dataclass (instead of a
    plain dict) gives every piece of episode state a name and a type.
    """

    # Task dict under review for this episode (the only required field).
    task: dict[str, Any]

    # Chronological log of everything the agent did: comments, context
    # probes, scanner results and request-changes notes.
    review_comments: list[dict[str, Any]] = dataclasses.field(default_factory=list)

    # Issue ids credited so far; may contain repeats until de-duplicated.
    issues_found: list[str] = dataclasses.field(default_factory=list)

    # Number of credited issues whose bug/backdoor label was also correct.
    correct_classifications: int = 0

    # Last decision recorded by the agent, or None when none was made yet.
    review_decision: str | None = None

    # True once a terminal submit/escalate action has been accepted.
    review_submitted: bool = False

    # Running sum of per-step reward totals.
    cumulative_reward: float = 0.0

    # Hint texts shown to the agent, in the order they were unlocked.
    context_hints: list[str] = dataclasses.field(default_factory=list)

    # Keys of hints already unlocked, so each hint is appended only once.
    hints_unlocked: set[str] = dataclasses.field(default_factory=set)

    # True after the scanner has been invoked at least once this episode.
    scanner_used: bool = False
class ProbeEnvironment(Environment):
"""
PRobe — Pull Request Investigation Environment.
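The primary interface is async (async_reset / async_step / async_state).
The sync members required by openenv's create_app are also provided:
reset() and step() delegate to the async versions via asyncio.run() (on a
helper thread when an event loop is already running), so they are safe to
call from sync contexts (e.g. tests without an event loop, openenv HTTP
wrappers); the ``state`` property returns a lightweight State snapshot.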
"""
SUPPORTS_CONCURRENT_SESSIONS: bool = True
# ── Construction ──────────────────────────────────────────────────────
def __init__(self, memory_dir: str | None = None) -> None:
    """
    Create an environment holding the first task, un-mutated, until reset.

    Args:
        memory_dir: Forwarded unchanged to ``EpisodeMemory`` as its
            ``memory_dir`` argument.
    """
    # Run the base-class initialiser so that any attributes openenv's
    # ``Environment`` sets up for itself exist on this instance.
    # NOTE(review): assumes ``Environment.__init__`` has no required
    # arguments — confirm against the installed openenv version.
    super().__init__()

    self._episode_id: str = str(uuid4())
    self._step_count: int = 0
    self._reset_count: int = 0

    # The memory store is tagged with the first 8 characters of the
    # episode id generated at construction time.
    self._memory: EpisodeMemory = EpisodeMemory(
        memory_dir=memory_dir,
        instance_id=self._episode_id[:8],
    )

    # Placeholder episode so the instance is usable before the first
    # reset; note that TASKS[0] is used here without mutation.
    initial_task = TASKS[0]
    self._grader: CodeReviewGrader = CodeReviewGrader(initial_task)
    self._episode: EpisodeState = EpisodeState(task=initial_task)
# ── Async-native interface (primary) ──────────────────────────────────
async def async_reset(self) -> ProbeObservation:
    """
    Start a new episode and return its first observation.

    The reset counter doubles as the mutation seed and, modulo the number
    of tasks, as the index of the task to serve, so tasks are visited in
    round-robin order.
    """
    seed = self._reset_count
    task_index = seed % len(TASKS)
    self._reset_count = seed + 1

    self._episode_id = str(uuid4())
    self._step_count = 0

    # Serve a surface-mutated copy so exact tokens cannot be memorised.
    base_task = TASKS[task_index]
    task_variant = mutate_task(base_task, seed=seed)
    self._grader = CodeReviewGrader(task_variant)
    self._episode = EpisodeState(task=task_variant)

    # Carry over a hint from earlier episodes when the memory store
    # returns one for this task.
    remembered = self._memory.prior_hint(task_index, base_task)
    if remembered:
        self._episode.context_hints.append(remembered)
        log.debug("EpisodeMemory: injected prior hint for task %d", task_index)

    return self._build_observation(reward=0.0, done=False)
async def async_step(
    self, action: ProbeAction
) -> tuple[ProbeObservation, RewardType, bool, dict[str, Any]]:
    """
    Execute one agent action and advance the episode by a step.

    Args:
        action: The action to apply; its ``action_type`` selects the handler.

    Returns:
        ``(observation, reward, done, info)`` where ``info`` carries the
        episode id, cumulative reward, issues found so far and the
        current review decision.
    """
    self._step_count += 1
    task = self._episode.task
    kind = action.action_type
    done = False
    reward: RewardType

    # Only SUBMIT_REVIEW and ESCALATE_TO_SECURITY_REVIEW can end the
    # episode themselves; their handlers return (reward, done).
    if kind == ActionType.ADD_COMMENT:
        reward = self._handle_add_comment(action)
    elif kind == ActionType.GET_CONTEXT:
        reward = self._handle_get_context(action)
    elif kind == ActionType.RUN_SCANNER:
        reward = self._handle_run_scanner()
    elif kind == ActionType.REQUEST_CHANGES:
        reward = self._handle_request_changes(action)
    elif kind == ActionType.APPROVE:
        reward = self._handle_approve()
    elif kind == ActionType.SUBMIT_REVIEW:
        reward, done = self._handle_submit_review()
    elif kind == ActionType.ESCALATE_TO_SECURITY_REVIEW:
        reward, done = self._handle_escalate(action)
    else:
        reward = RewardType(
            total=-0.05,
            components={"illegal_action_penalty": -0.05},
            passed=False,
            explanation=f"Unknown action type: {action.action_type}",
            step=self._step_count,
            terminal=False,
        )

    # Running out of steps without a terminal action ends the episode
    # with an extra -0.05, with the total floored at -1.0.
    out_of_steps = (not done) and self._step_count >= task["max_steps"]
    if out_of_steps:
        budget_components = dict(reward.components)
        budget_components["step_budget_penalty"] = -0.05
        reward = RewardType(
            total=round(max(-1.0, reward.total - 0.05), 4),
            components=budget_components,
            passed=False,
            explanation=reward.explanation + " [Step limit reached.]",
            step=self._step_count,
            terminal=True,
        )
        done = True

    running_total = self._episode.cumulative_reward + reward.total
    self._episode.cumulative_reward = round(running_total, 4)

    observation = self._build_observation(reward=reward.total, done=done)
    info = {
        "episode_id": self._episode_id,
        "cumulative_reward": self._episode.cumulative_reward,
        "issues_found": list(self._episode.issues_found),
        "review_decision": self._episode.review_decision,
    }
    return observation, reward, done, info
async def async_state(self) -> dict[str, Any]:
    """
    Return a plain-dict snapshot of the current episode.

    The ``issues_found`` entry is a copy, so callers can mutate it without
    touching the live episode.
    """
    episode = self._episode
    task = episode.task

    snapshot: dict[str, Any] = {
        "episode_id": self._episode_id,
        "step_count": self._step_count,
        "task_id": task["id"],
        "task_difficulty": task["difficulty"],
        "task_name": task["name"],
        "issues_found": list(episode.issues_found),
        "total_issues": len(task["issues"]),
        "review_decision": episode.review_decision,
        "review_submitted": episode.review_submitted,
        "cumulative_reward": episode.cumulative_reward,
        "max_steps": task["max_steps"],
        "scanner_used": episode.scanner_used,
        "correct_classifications": episode.correct_classifications,
        # Optional task key; absent means no escalation is required.
        "escalation_required": task.get("escalation_required", False),
    }
    return snapshot
# ── Sync wrappers (openenv / create_app compatibility) ────────────────
def reset(self) -> ProbeObservation:  # type: ignore[override]
    """
    Synchronous reset: run ``async_reset()`` to completion and return its result.

    With no event loop running in this thread the coroutine is driven
    directly by ``asyncio.run``. When a loop is already running (e.g.
    under pytest-asyncio) ``asyncio.run`` cannot be used here, so the
    coroutine is executed on a one-off worker thread with its own loop.
    """
    try:
        asyncio.get_running_loop()
    except RuntimeError:
        loop_is_running = False
    else:
        loop_is_running = True

    if not loop_is_running:
        return asyncio.run(self.async_reset())

    with concurrent.futures.ThreadPoolExecutor(max_workers=1) as executor:
        pending = executor.submit(asyncio.run, self.async_reset())
        return pending.result()
def step(self, action: ProbeAction) -> ProbeObservation:  # type: ignore[override]
    """
    Synchronous step for openenv compatibility.

    Runs ``async_step(action)`` to completion and returns only the
    observation (the reward is embedded in ``obs.reward``). Call
    ``async_step()`` directly for the full (obs, reward, done, info) tuple.

    When an event loop is already running in this thread, the coroutine
    is executed on a one-off worker thread that owns its own loop.
    """
    try:
        asyncio.get_running_loop()
    except RuntimeError:
        loop_is_running = False
    else:
        loop_is_running = True

    if loop_is_running:
        with concurrent.futures.ThreadPoolExecutor(max_workers=1) as executor:
            pending = executor.submit(asyncio.run, self.async_step(action))
            outcome = pending.result()
    else:
        outcome = asyncio.run(self.async_step(action))

    observation, _reward, _done, _info = outcome
    return observation
@property
def state(self) -> State:  # type: ignore[override]
    """
    Lightweight sync snapshot required by openenv's create_app interface.

    Only the episode id and step count are exposed here; the full
    snapshot is available from ``async_state()``.
    """
    snapshot = State(
        episode_id=self._episode_id,
        step_count=self._step_count,
    )
    return snapshot
# ── Action handlers ───────────────────────────────────────────────────
def _handle_add_comment(self, action: ProbeAction) -> RewardType:
    """
    Record a review comment, score it, and credit any newly matched issues.

    Args:
        action: ADD_COMMENT action; its line number, comment text,
            severity, category and (optional) classification are logged
            and passed to the grader.

    Returns:
        A non-terminal reward whose total is the grader's score clamped
        to [-1.0, 1.0]; ``passed`` is True when a new issue was matched.
    """

    def _canonical(label: str) -> str:
        # Compare labels case-insensitively, treating "-" and "_" alike.
        return label.lower().replace("-", "_")

    episode = self._episode

    # Read the field via getattr: older serialised actions from GRPO
    # rollout workers may not carry it at all.
    raw_label = getattr(action, "classification", None)
    label = raw_label.value if raw_label else None

    episode.review_comments.append(
        {
            "type": "comment",
            "line": action.line_number,
            "text": action.comment,
            "severity": action.severity.value if action.severity else None,
            "category": action.category.value if action.category else None,
            "classification": label,
        }
    )

    score, matched_ids, breakdown = self._grader.score_comment(
        line_number=action.line_number,
        comment=action.comment,
        already_found=episode.issues_found,
        classification=label,
    )
    episode.issues_found.extend(matched_ids)

    # Metrics only: count matches whose supplied label agrees with the
    # label stored on the task's issue entry.
    if matched_ids and label:
        issues_by_id = {issue["id"]: issue for issue in episode.task["issues"]}
        for issue_id in matched_ids:
            expected_label = issues_by_id.get(issue_id, {}).get("classification")
            if not expected_label:
                continue
            if _canonical(label) == _canonical(expected_label):
                episode.correct_classifications += 1

    bounded_score = round(max(-1.0, min(1.0, score)), 4)

    if matched_ids:
        explanation = f"Identified issue(s): {matched_ids}"
    elif score < 0:
        explanation = "False-positive comment — matched no known issue."
    else:
        explanation = "Comment recorded; no new issue matched."

    # A newly found issue may reveal an additional context hint.
    self._unlock_causal_hints(matched_ids)

    return RewardType(
        total=bounded_score,
        components=breakdown,
        passed=bool(matched_ids),
        explanation=explanation,
        step=self._step_count,
        terminal=False,
    )
def _unlock_causal_hints(self, newly_found: list[str]) -> None:
"""Append context hint text for any issue that has an 'unlocks' key."""
task = self._episode.task
hint_map: dict[str, str] = task.get("context_hints", {})
for issue in task["issues"]:
unlock_key = issue.get("unlocks")
if (
unlock_key
and issue["id"] in newly_found
and unlock_key not in self._episode.hints_unlocked
and unlock_key in hint_map
):
self._episode.hints_unlocked.add(unlock_key)
self._episode.context_hints.append(hint_map[unlock_key])
def _handle_get_context(self, action: ProbeAction) -> RewardType:
    """
    GET_CONTEXT - show the lines surrounding ``action.line_number``.

    A probe costs -0.01 unless the requested line lies within
    ``LINE_TOLERANCE`` of some issue's line range, in which case it is
    free. A probe without a line number costs -0.02 and reveals nothing.
    The snippet is also logged in the episode's review history.
    """
    target = action.line_number
    task = self._episode.task
    source_lines = task["code"].split("\n")

    if target is None:
        return RewardType(
            total=-0.02,
            components={"invalid_context_probe_penalty": -0.02},
            passed=False,
            explanation="GET_CONTEXT requires a line_number.",
            step=self._step_count,
            terminal=False,
        )

    # Line numbers are 1-based while list indices are 0-based, so the
    # window [target - 6, target + 5) in index space is target ± 5 lines,
    # clipped to the bounds of the file.
    first_idx = max(0, target - 6)
    stop_idx = min(len(source_lines), target + 5)
    numbered = [
        f"{idx + 1:3}: {source_lines[idx]}"
        for idx in range(first_idx, stop_idx)
    ]
    snippet = "\n".join(numbered)

    near_issue = False
    for issue in task["issues"]:
        if (
            issue["line_range"][0] - LINE_TOLERANCE
            <= target
            <= issue["line_range"][1] + LINE_TOLERANCE
        ):
            near_issue = True
            break

    if near_issue:
        penalty = 0.0
    else:
        penalty = -0.01

    self._episode.review_comments.append(
        {
            "type": "context_probe",
            "line": target,
            "context": snippet,
        }
    )

    return RewardType(
        total=penalty,
        components={"context_probe_penalty": penalty},
        passed=near_issue,
        explanation=f"Context around line {target}:\n{snippet}",
        step=self._step_count,
        terminal=False,
    )
def _handle_run_scanner(self) -> RewardType:
    """
    RUN_SCANNER — invoke the simulated static-analysis tool.

    Reward design:
      • The first scan of an episode is free (0.0).
      • Every later scan costs -0.02; the scanner is called with the
        same task and seed each time within an episode.

    The scan result is appended to ``review_comments`` so it shows up in
    ``review_history`` on the next observation. The scan itself earns
    nothing: findings only pay off once reported through ADD_COMMENT.
    """
    episode = self._episode
    task = episode.task

    # Prefer the seed stored on the task under "_mutation_seed"; fall
    # back to the reset counter when the task carries none.
    seed = task.get("_mutation_seed", self._reset_count)

    repeat_scan = episode.scanner_used
    episode.scanner_used = True

    report = run_scanner(task, seed=seed)
    episode.review_comments.append(
        {
            "type": "scanner_result",
            "tool": report["tool"],
            "findings": report["findings"],
            "missed_count": report["missed_count"],
            "note": report["note"],
        }
    )

    penalty = -0.02 if repeat_scan else 0.0
    n_findings = len(report["findings"])

    summary = (
        f"[{report['tool']}] {n_findings} finding(s) reported "
        f"({report['missed_count']} issue(s) may have been missed). "
        f"{report['note']}"
    )
    if repeat_scan:
        summary = "Scanner already run this episode — results unchanged. " + summary

    log.debug(
        "RUN_SCANNER: %d findings, missed=%d, seed=%d",
        n_findings, report["missed_count"], seed,
    )

    return RewardType(
        total=penalty,
        components={"scanner_penalty": penalty},
        passed=not repeat_scan,
        explanation=summary,
        step=self._step_count,
        terminal=False,
    )
def _handle_request_changes(self, action: ProbeAction) -> RewardType:
    """
    Record a REQUEST_CHANGES decision together with its comment text.

    The step itself is reward-neutral; the decision is only scored when
    the review is submitted.
    """
    episode = self._episode
    episode.review_decision = "request_changes"

    note = {"type": "request_changes", "text": action.comment}
    episode.review_comments.append(note)

    return RewardType(
        total=0.0,
        components={},
        passed=True,
        explanation="REQUEST_CHANGES recorded. Terminal reward applied on SUBMIT_REVIEW.",
        step=self._step_count,
        terminal=False,
    )
def _handle_approve(self) -> RewardType:
    """
    Record an APPROVE decision and apply its immediate reward.

    Approving while fewer than half of the task's issues have been found
    is treated as a reckless sign-off and costs -0.15; otherwise a small
    +0.02 credit is given. The action never ends the episode.

    Returns:
        A non-terminal reward describing the approval outcome.
    """
    self._episode.review_decision = "approve"

    total_issue_count = len(self._episode.task["issues"])
    found_issue_count = len(set(self._episode.issues_found))

    # A task with no issues at all has nothing left to find, so it counts
    # as fully covered. Scoring it as 0 % coverage would penalise the
    # correct approval of a clean PR as "0/0 issues found".
    if total_issue_count > 0:
        coverage_fraction = found_issue_count / total_issue_count
    else:
        coverage_fraction = 1.0

    if coverage_fraction < 0.5:
        return RewardType(
            total=-0.15,
            components={"bad_approval_penalty": -0.15},
            passed=False,
            explanation=f"APPROVE with only {found_issue_count}/{total_issue_count} issues found.",
            step=self._step_count,
            terminal=False,
        )

    return RewardType(
        total=0.02,
        components={"approval_credit": 0.02},
        passed=True,
        explanation="APPROVE recorded.",
        step=self._step_count,
        terminal=False,
    )
def _handle_submit_review(self) -> tuple[RewardType, bool]:
    """
    SUBMIT_REVIEW - close the episode and compute the terminal reward.

    Returns:
        ``(reward, done)``. A first submission returns the grader's final
        score with ``done=True``. A repeated submission returns a -0.05
        penalty with ``done=False`` and changes nothing else.
    """
    if self._episode.review_submitted:
        return (
            RewardType(
                total=-0.05,
                components={"duplicate_submit_penalty": -0.05},
                passed=False,
                explanation="Review already submitted.",
                step=self._step_count,
                terminal=False,
            ),
            False,
        )

    self._episode.review_submitted = True
    task = self._episode.task

    # Each issue counts once even if it was credited in several steps.
    # dict.fromkeys keeps first-discovery order, so the list handed to the
    # grader and persisted to memory is identical from run to run; a set
    # would order the ids by string hash, which varies between processes.
    unique_issues_found = list(dict.fromkeys(self._episode.issues_found))

    terminal_reward = self._grader.final_score(
        issues_found=unique_issues_found,
        review_decision=self._episode.review_decision,
        steps_used=self._step_count,
        max_steps=task["max_steps"],
    )

    # Only episodes that found something are written to memory.
    if unique_issues_found:
        self._memory.record(task["id"], unique_issues_found)
        log.debug(
            "EpisodeMemory: recorded %d finding(s) for task %d",
            len(unique_issues_found),
            task["id"],
        )

    return terminal_reward, True
def _handle_escalate(self, action: ProbeAction) -> tuple[RewardType, bool]:
    """
    ESCALATE_TO_SECURITY_REVIEW - terminal action for adversarial tasks.

    The decision is recorded as ``"escalate_to_security_review"`` and
    scoring is delegated to the grader's ``final_score``; whether the
    escalation was warranted is decided there, not in this method.

    Args:
        action: The escalate action. It is not read here and is kept so
            the handler signature stays unchanged for callers.

    Returns:
        ``(reward, done)``. A first terminal action returns the grader's
        final score with ``done=True``. If a review was already
        submitted, a -0.05 penalty is returned with ``done=False``.
    """
    if self._episode.review_submitted:
        return (
            RewardType(
                total=-0.05,
                components={"duplicate_submit_penalty": -0.05},
                passed=False,
                explanation="Review already submitted.",
                step=self._step_count,
                terminal=False,
            ),
            False,
        )

    self._episode.review_submitted = True
    self._episode.review_decision = "escalate_to_security_review"
    task = self._episode.task

    # De-duplicate while keeping first-discovery order, so the list given
    # to the grader and persisted to memory is the same on every run; a
    # set would order the ids by string hash, which varies per process.
    unique_issues_found = list(dict.fromkeys(self._episode.issues_found))

    terminal_reward = self._grader.final_score(
        issues_found=unique_issues_found,
        review_decision="escalate_to_security_review",
        steps_used=self._step_count,
        max_steps=task["max_steps"],
    )

    # Only episodes that found something are written to memory; the write
    # is traced the same way as on the SUBMIT_REVIEW path.
    if unique_issues_found:
        self._memory.record(task["id"], unique_issues_found)
        log.debug(
            "EpisodeMemory: recorded %d finding(s) for task %d",
            len(unique_issues_found),
            task["id"],
        )

    return terminal_reward, True
def _build_observation(self, reward: float, done: bool) -> ProbeObservation:
    """
    Assemble the observation the agent sees after a reset or a step.

    Args:
        reward: Reward of the step just taken; clamped to [-1.0, 1.0]
            and rounded to 4 decimals before being exposed.
        done: Whether the episode has ended.

    Returns:
        An observation with copies of the review history and context
        hints, so later episode mutations do not alter it.
    """
    episode = self._episode
    task = episode.task

    # The provenance sentence is chosen from the task's
    # ``escalation_required`` flag: external contributor when it is set,
    # trusted team member otherwise.
    if task.get("escalation_required"):
        adversarial_hint = (
            "This PR was submitted by an external contributor with no prior commit history."
        )
    else:
        adversarial_hint = "This PR was submitted by a trusted team member."

    bounded_reward = round(max(-1.0, min(1.0, reward)), 4)

    extra = {
        "cumulative_reward": episode.cumulative_reward,
        "review_decision": episode.review_decision,
        "episode_id": self._episode_id,
        "mutation_seed": task.get("_mutation_seed"),
        "correct_classifications": episode.correct_classifications,
        "escalation_required": task.get("escalation_required", False),
    }

    return ProbeObservation(
        code_snippet=task["code"],
        task_description=task["description"],
        file_name=task["file_name"],
        task_id=task["id"],
        task_difficulty=task["difficulty"],
        review_history=list(episode.review_comments),
        step_count=self._step_count,
        max_steps=task["max_steps"],
        issues_found_count=len(set(episode.issues_found)),
        total_issues=len(task["issues"]),
        done=done,
        reward=bounded_reward,
        context_hints=list(episode.context_hints),
        adversarial_hint=adversarial_hint,
        metadata=extra,
    )