""" PRobe Environment — async-native implementation. Episode lifecycle: 1. reset() → ObservationType (starts a new episode) 2. step(a) → (Obs, RewardType, done, info) (execute one action) 3. state() → dict (full internal snapshot) Tasks cycle automatically: 0 (ultra-easy) → 1 (easy) → … → 6 (causal chain) → 0 … Dynamic world features (v3) ─────────────────────────── • Code mutation — each episode applies surface-level variable renames, a line shift, and a constant nudge so the agent must read the code rather than memorise tokens. • GET_CONTEXT — the agent can spend a step probing a specific line to receive the surrounding ±5 lines of context. • Causal unlocks — finding certain issues appends a new context hint to the observation, modelling real-world situations where one discovery leads to deeper investigation. Thread / task safety: each Environment instance owns its own state. For concurrent GRPO rollouts spin up one instance per worker. """ from __future__ import annotations import asyncio import concurrent.futures import dataclasses import logging from typing import Any from uuid import uuid4 from openenv.core.env_server.interfaces import Environment from openenv.core.env_server.types import State try: from ..agent.models import ActionType, ProbeAction, ProbeObservation, RewardType from ._import_compat import ( CodeReviewGrader, EpisodeMemory, LINE_TOLERANCE, TASKS, mutate_task, run_scanner, ) except ImportError: from agent.models import ActionType, ProbeAction, ProbeObservation, RewardType # type: ignore[no-redef] from environment._import_compat import ( # type: ignore[no-redef] CodeReviewGrader, EpisodeMemory, LINE_TOLERANCE, TASKS, mutate_task, run_scanner, ) log = logging.getLogger(__name__) @dataclasses.dataclass class EpisodeState: """ All mutable state for a single review episode. Using a dataclass eliminates stringly-typed dict key access and makes the shape of an episode explicit and statically checkable. """ task: dict[str, Any] review_comments: list[dict[str, Any]] = dataclasses.field(default_factory=list) issues_found: list[str] = dataclasses.field(default_factory=list) # Issues found with the correct bug/backdoor label. correct_classifications: int = 0 review_decision: str | None = None review_submitted: bool = False cumulative_reward: float = 0.0 # Progressive context unlocked by finding key issues. context_hints: list[str] = dataclasses.field(default_factory=list) hints_unlocked: set[str] = dataclasses.field(default_factory=set) scanner_used: bool = False class ProbeEnvironment(Environment): """ PRobe — Pull Request Investigation Environment. Public interface is fully async. The sync wrappers (reset / step / state) required by openenv's create_app are also provided; they delegate to the async versions via asyncio.run() so they are safe to call from sync contexts (e.g. tests without an event loop, openenv HTTP wrappers). """ SUPPORTS_CONCURRENT_SESSIONS: bool = True # ── Construction ────────────────────────────────────────────────────── def __init__(self, memory_dir: str | None = None) -> None: self._episode_id: str = str(uuid4()) self._step_count: int = 0 self._reset_count: int = 0 self._memory: EpisodeMemory = EpisodeMemory( memory_dir=memory_dir, instance_id=self._episode_id[:8], ) initial_task = TASKS[0] self._grader: CodeReviewGrader = CodeReviewGrader(initial_task) self._episode: EpisodeState = EpisodeState(task=initial_task) # ── Async-native interface (primary) ────────────────────────────────── async def async_reset(self) -> ProbeObservation: task_id = self._reset_count % len(TASKS) episode_seed = self._reset_count self._reset_count += 1 self._episode_id = str(uuid4()) self._step_count = 0 # Apply surface mutation so the agent cannot memorise exact tokens mutated_task = mutate_task(TASKS[task_id], seed=episode_seed) self._grader = CodeReviewGrader(mutated_task) self._episode = EpisodeState(task=mutated_task) # Inject cross-episode prior-knowledge hint when this task was seen before prior_hint = self._memory.prior_hint(task_id, TASKS[task_id]) if prior_hint: self._episode.context_hints.append(prior_hint) log.debug("EpisodeMemory: injected prior hint for task %d", task_id) return self._build_observation(reward=0.0, done=False) async def async_step( self, action: ProbeAction ) -> tuple[ProbeObservation, RewardType, bool, dict[str, Any]]: self._step_count += 1 current_task = self._episode.task episode_done = False step_reward: RewardType if action.action_type == ActionType.ADD_COMMENT: step_reward = self._handle_add_comment(action) elif action.action_type == ActionType.GET_CONTEXT: step_reward = self._handle_get_context(action) elif action.action_type == ActionType.RUN_SCANNER: step_reward = self._handle_run_scanner() elif action.action_type == ActionType.REQUEST_CHANGES: step_reward = self._handle_request_changes(action) elif action.action_type == ActionType.APPROVE: step_reward = self._handle_approve() elif action.action_type == ActionType.SUBMIT_REVIEW: step_reward, episode_done = self._handle_submit_review() elif action.action_type == ActionType.ESCALATE_TO_SECURITY_REVIEW: step_reward, episode_done = self._handle_escalate(action) else: step_reward = RewardType( total=-0.05, components={"illegal_action_penalty": -0.05}, passed=False, explanation=f"Unknown action type: {action.action_type}", step=self._step_count, terminal=False, ) # Apply step-budget penalty when the episode runs out of steps. if not episode_done and self._step_count >= current_task["max_steps"]: penalised_total = max(-1.0, step_reward.total - 0.05) step_reward = RewardType( total=round(penalised_total, 4), components={**step_reward.components, "step_budget_penalty": -0.05}, passed=False, explanation=step_reward.explanation + " [Step limit reached.]", step=self._step_count, terminal=True, ) episode_done = True self._episode.cumulative_reward = round( self._episode.cumulative_reward + step_reward.total, 4 ) observation = self._build_observation(reward=step_reward.total, done=episode_done) info = { "episode_id": self._episode_id, "cumulative_reward": self._episode.cumulative_reward, "issues_found": list(self._episode.issues_found), "review_decision": self._episode.review_decision, } return observation, step_reward, episode_done, info async def async_state(self) -> dict[str, Any]: task = self._episode.task return { "episode_id": self._episode_id, "step_count": self._step_count, "task_id": task["id"], "task_difficulty": task["difficulty"], "task_name": task["name"], "issues_found": list(self._episode.issues_found), "total_issues": len(task["issues"]), "review_decision": self._episode.review_decision, "review_submitted": self._episode.review_submitted, "cumulative_reward": self._episode.cumulative_reward, "max_steps": task["max_steps"], "scanner_used": self._episode.scanner_used, "correct_classifications": self._episode.correct_classifications, "escalation_required": task.get("escalation_required", False), } # ── Sync wrappers (openenv / create_app compatibility) ──────────────── def reset(self) -> ProbeObservation: # type: ignore[override] try: asyncio.get_running_loop() except RuntimeError: return asyncio.run(self.async_reset()) # Called from inside a running loop (e.g. pytest-asyncio) -- run in a # fresh thread that has its own event loop. with concurrent.futures.ThreadPoolExecutor(max_workers=1) as pool: return pool.submit(asyncio.run, self.async_reset()).result() def step(self, action: ProbeAction) -> ProbeObservation: # type: ignore[override] """ Sync step for openenv compatibility. Returns only the Observation (reward is embedded in obs.reward). Use async_step() for the full (obs, reward, done, info) tuple. """ try: asyncio.get_running_loop() except RuntimeError: obs, _, _, _ = asyncio.run(self.async_step(action)) return obs with concurrent.futures.ThreadPoolExecutor(max_workers=1) as pool: obs, _, _, _ = pool.submit(asyncio.run, self.async_step(action)).result() return obs @property def state(self) -> State: # type: ignore[override] # Sync property required by openenv's create_app interface. # Full async snapshot is available via async_state(). return State(episode_id=self._episode_id, step_count=self._step_count) # ── Action handlers ─────────────────────────────────────────────────── def _handle_add_comment(self, action: ProbeAction) -> RewardType: # Use getattr rather than action.classification directly because older # serialised actions from GRPO rollout workers may omit the field. classification_val = getattr(action, "classification", None) classification_str = classification_val.value if classification_val else None entry = { "type": "comment", "line": action.line_number, "text": action.comment, "severity": action.severity.value if action.severity else None, "category": action.category.value if action.category else None, "classification": classification_str, } self._episode.review_comments.append(entry) score, new_finds, breakdown = self._grader.score_comment( line_number=action.line_number, comment=action.comment, already_found=self._episode.issues_found, classification=classification_str, ) self._episode.issues_found.extend(new_finds) # Track correct classifications for metrics logging if new_finds and classification_str: task_issues = {iss["id"]: iss for iss in self._episode.task["issues"]} for fid in new_finds: expected = task_issues.get(fid, {}).get("classification") if expected and classification_str.lower().replace("-", "_") == expected.lower().replace("-", "_"): self._episode.correct_classifications += 1 clamped = round(max(-1.0, min(1.0, score)), 4) if new_finds: explanation = f"Identified issue(s): {new_finds}" elif score < 0: explanation = "False-positive comment — matched no known issue." else: explanation = "Comment recorded; no new issue matched." # ── Causal unlock: check whether any newly found issue reveals context self._unlock_causal_hints(new_finds) return RewardType( total=clamped, components=breakdown, passed=bool(new_finds), explanation=explanation, step=self._step_count, terminal=False, ) def _unlock_causal_hints(self, newly_found: list[str]) -> None: """Append context hint text for any issue that has an 'unlocks' key.""" task = self._episode.task hint_map: dict[str, str] = task.get("context_hints", {}) for issue in task["issues"]: unlock_key = issue.get("unlocks") if ( unlock_key and issue["id"] in newly_found and unlock_key not in self._episode.hints_unlocked and unlock_key in hint_map ): self._episode.hints_unlocked.add(unlock_key) self._episode.context_hints.append(hint_map[unlock_key]) def _handle_get_context( self, action: ProbeAction ) -> RewardType: """ GET_CONTEXT - reveal +/-5 lines around the requested line number. Costs a small step penalty (-0.01) to discourage random probing, but waives the penalty when the probed line is near a real issue. """ requested_line = action.line_number task = self._episode.task code_lines = task["code"].split("\n") if requested_line is None: return RewardType( total=-0.02, components={"invalid_context_probe_penalty": -0.02}, passed=False, explanation="GET_CONTEXT requires a line_number.", step=self._step_count, terminal=False, ) # Build an 11-line window: 5 lines before the target, the target itself, # and 5 lines after. The asymmetric offsets (-6 / +5) account for # Python's 0-based list indexing against 1-based line numbers. window_start = max(0, requested_line - 6) window_end = min(len(code_lines), requested_line + 5) context_snippet = "\n".join( f"{line_idx + 1:3}: {code_lines[line_idx]}" for line_idx in range(window_start, window_end) ) is_near_real_issue = any( (iss["line_range"][0] - LINE_TOLERANCE) <= requested_line <= (iss["line_range"][1] + LINE_TOLERANCE) for iss in task["issues"] ) probe_penalty = 0.0 if is_near_real_issue else -0.01 self._episode.review_comments.append({ "type": "context_probe", "line": requested_line, "context": context_snippet, }) return RewardType( total=probe_penalty, components={"context_probe_penalty": probe_penalty}, passed=is_near_real_issue, explanation=f"Context around line {requested_line}:\n{context_snippet}", step=self._step_count, terminal=False, ) def _handle_run_scanner(self) -> RewardType: """ RUN_SCANNER — invoke the simulated static-analysis tool. Reward design ───────────── • First use in an episode: free (+0.0) — the agent should always try the scanner at least once. • Repeated use costs -0.02 per call — the tool output doesn't change within an episode (same seed), so redundant calls waste the step budget without new information. The scan result is stored in ``review_comments`` so it appears in ``review_history`` on the next observation. The agent must still call ``ADD_COMMENT`` to earn reward from any finding. """ task = self._episode.task # _mutation_seed is injected by mutate_task() on reset; fall back to # reset_count so the scanner seed is always defined even on the very # first episode before any mutation has occurred. episode_seed = task.get("_mutation_seed", self._reset_count) is_first_scan = not self._episode.scanner_used self._episode.scanner_used = True scan_result = run_scanner(task, seed=episode_seed) self._episode.review_comments.append({ "type": "scanner_result", "tool": scan_result["tool"], "findings": scan_result["findings"], "missed_count": scan_result["missed_count"], "note": scan_result["note"], }) penalty = 0.0 if is_first_scan else -0.02 finding_count = len(scan_result["findings"]) explanation = ( f"[{scan_result['tool']}] {finding_count} finding(s) reported " f"({scan_result['missed_count']} issue(s) may have been missed). " f"{scan_result['note']}" ) if not is_first_scan: explanation = "Scanner already run this episode — results unchanged. " + explanation log.debug( "RUN_SCANNER: %d findings, missed=%d, seed=%d", finding_count, scan_result["missed_count"], episode_seed, ) return RewardType( total=penalty, components={"scanner_penalty": penalty}, passed=is_first_scan, explanation=explanation, step=self._step_count, terminal=False, ) def _handle_request_changes(self, action: ProbeAction) -> RewardType: self._episode.review_decision = "request_changes" self._episode.review_comments.append( {"type": "request_changes", "text": action.comment} ) return RewardType( total=0.0, components={}, passed=True, explanation="REQUEST_CHANGES recorded. Terminal reward applied on SUBMIT_REVIEW.", step=self._step_count, terminal=False, ) def _handle_approve(self) -> RewardType: self._episode.review_decision = "approve" total_issue_count = len(self._episode.task["issues"]) found_issue_count = len(set(self._episode.issues_found)) coverage_fraction = found_issue_count / total_issue_count if total_issue_count > 0 else 0.0 # Approving a PR while fewer than 50 % of issues are found is treated # as reckless sign-off — a hard penalisation rather than a gentle nudge. if coverage_fraction < 0.5: return RewardType( total=-0.15, components={"bad_approval_penalty": -0.15}, passed=False, explanation=f"APPROVE with only {found_issue_count}/{total_issue_count} issues found.", step=self._step_count, terminal=False, ) return RewardType( total=0.02, components={"approval_credit": 0.02}, passed=True, explanation="APPROVE recorded.", step=self._step_count, terminal=False, ) def _handle_submit_review(self) -> tuple[RewardType, bool]: if self._episode.review_submitted: return ( RewardType( total=-0.05, components={"duplicate_submit_penalty": -0.05}, passed=False, explanation="Review already submitted.", step=self._step_count, terminal=False, ), False, ) self._episode.review_submitted = True task = self._episode.task # Deduplicate before scoring: the agent may have re-commented on the # same line in multiple steps, but each issue should only count once. unique_issues_found = list(set(self._episode.issues_found)) terminal_reward = self._grader.final_score( issues_found=unique_issues_found, review_decision=self._episode.review_decision, steps_used=self._step_count, max_steps=task["max_steps"], ) if unique_issues_found: self._memory.record(task["id"], unique_issues_found) log.debug( "EpisodeMemory: recorded %d finding(s) for task %d", len(unique_issues_found), task["id"], ) return terminal_reward, True def _handle_escalate(self, action: ProbeAction) -> tuple[RewardType, bool]: """ ESCALATE_TO_SECURITY_REVIEW - terminal action for adversarial tasks. Correct only when the task has escalation_required=True. Calling on a non-adversarial task incurs a false-alarm penalty. """ if self._episode.review_submitted: return ( RewardType( total=-0.05, components={"duplicate_submit_penalty": -0.05}, passed=False, explanation="Review already submitted.", step=self._step_count, terminal=False, ), False, ) self._episode.review_submitted = True self._episode.review_decision = "escalate_to_security_review" task = self._episode.task unique_issues_found = list(set(self._episode.issues_found)) terminal_reward = self._grader.final_score( issues_found=unique_issues_found, review_decision="escalate_to_security_review", steps_used=self._step_count, max_steps=task["max_steps"], ) if unique_issues_found: self._memory.record(task["id"], unique_issues_found) return terminal_reward, True def _build_observation(self, reward: float, done: bool) -> ProbeObservation: task = self._episode.task # Adversarial tasks hint that the PR is from an external contributor; # non-adversarial tasks hint it is from a trusted team member. # Neither phrasing reveals whether backdoors are present (partial observability). adversarial_hint = ( "This PR was submitted by an external contributor with no prior commit history." if task.get("escalation_required") else "This PR was submitted by a trusted team member." ) return ProbeObservation( code_snippet=task["code"], task_description=task["description"], file_name=task["file_name"], task_id=task["id"], task_difficulty=task["difficulty"], review_history=list(self._episode.review_comments), step_count=self._step_count, max_steps=task["max_steps"], issues_found_count=len(set(self._episode.issues_found)), total_issues=len(task["issues"]), done=done, reward=round(max(-1.0, min(1.0, reward)), 4), context_hints=list(self._episode.context_hints), adversarial_hint=adversarial_hint, metadata={ "cumulative_reward": self._episode.cumulative_reward, "review_decision": self._episode.review_decision, "episode_id": self._episode_id, "mutation_seed": task.get("_mutation_seed"), "correct_classifications": self._episode.correct_classifications, "escalation_required": task.get("escalation_required", False), }, )