from __future__ import annotations import random import uuid from typing import Any, Optional try: from core.env_server.interfaces import Environment except ImportError: try: from openenv.core.env_server.interfaces import Environment except ImportError: from openenv_core.env_server.interfaces import Environment try: from ..models import ( CodeSecurityAction, CodeSecurityObservation, CodeSecurityState, FindingRecord, ) from .grader import evaluate_finding, final_grade from .tasks import TaskSpec, get_task, list_task_ids except ImportError: from models import ( CodeSecurityAction, CodeSecurityObservation, CodeSecurityState, FindingRecord, ) from server.grader import evaluate_finding, final_grade from server.tasks import TaskSpec, get_task, list_task_ids class CodeSecurityAuditorEnvironment( Environment[CodeSecurityAction, CodeSecurityObservation, CodeSecurityState] ): """Real-world code security auditing simulator with deterministic graders.""" SUPPORTS_CONCURRENT_SESSIONS = True MIN_STRICT_SCORE = 0.001 MAX_STRICT_SCORE = 0.999 def __init__(self, default_task_id: str = "easy"): self._default_task_id = default_task_id self._task_cursor = 0 self._task: Optional[TaskSpec] = None self._state = CodeSecurityState() def reset( self, seed: Optional[int] = None, episode_id: Optional[str] = None, **kwargs: Any, ) -> CodeSecurityObservation: requested_task = kwargs.get("task_id") or kwargs.get("task") if requested_task is not None: task = get_task(str(requested_task)) elif seed is not None: rng = random.Random(seed) task = get_task(rng.choice(list_task_ids())) elif self._default_task_id: task = get_task(self._default_task_id) else: task_order = list_task_ids() task = get_task(task_order[self._task_cursor % len(task_order)]) self._task_cursor += 1 self._task = task self._state = CodeSecurityState( episode_id=episode_id or str(uuid.uuid4()), step_count=0, task_id=task.id, task_title=task.title, difficulty=task.difficulty, objective=task.objective, max_steps=task.max_steps, inspected_files=[], findings_submitted=[], matched_vulnerability_ids=[], false_positive_count=0, duplicate_submission_count=0, quality_multiplier=1.0, final_score=None, ) return self._build_observation( reward=0.0, done=False, feedback=( "Audit started. Use inspect_file before submit_finding. " "Finish with submit_final_report." ), focused_file=None, excerpt="", extra_metadata={ "available_task_ids": list_task_ids(), "task_id": task.id, }, ) def step( self, action: CodeSecurityAction, timeout_s: Optional[float] = None, **kwargs: Any, ) -> CodeSecurityObservation: del timeout_s, kwargs task = self._require_task() if self._state.final_score is not None: return self._build_observation( reward=0.0, done=True, feedback="Episode already terminated. Call reset() to start a new task.", focused_file=None, excerpt="", ) self._state.step_count += 1 feedback = "" reward = 0.0 focused_file = None excerpt = "" if action.action_type == "inspect_file": reward, feedback, focused_file, excerpt = self._handle_inspect_file(action, task) elif action.action_type == "submit_finding": reward, feedback = self._handle_submit_finding(action, task) elif action.action_type == "submit_final_report": reward, feedback = self._handle_submit_final_report() else: feedback = f"Unsupported action_type={action.action_type}." self._degrade_quality(0.03) done = self._state.final_score is not None if not done and self._state.step_count >= self._state.max_steps: score = self._compute_final_score(task) self._state.final_score = score done = True reward = score feedback = ( f"Max steps reached. Auto-finalized audit score={score:.3f}. " "Use fewer but higher-quality findings to improve precision." ) return self._build_observation( reward=reward, done=done, feedback=feedback, focused_file=focused_file, excerpt=excerpt, extra_metadata={ "last_action_error": None, }, ) @property def state(self) -> CodeSecurityState: return self._state def _require_task(self) -> TaskSpec: if self._task is None: raise RuntimeError("Environment has no active task. Call reset() first.") return self._task def _degrade_quality(self, amount: float) -> None: self._state.quality_multiplier = max(0.2, self._state.quality_multiplier - amount) def _format_file(self, content: str) -> str: lines = content.splitlines() numbered = [f"{idx + 1:>3}: {line}" for idx, line in enumerate(lines)] return "\n".join(numbered) def _handle_inspect_file( self, action: CodeSecurityAction, task: TaskSpec, ) -> tuple[float, str, Optional[str], str]: filename = action.filename or "" if filename not in task.repository: self._degrade_quality(0.04) return 0.0, f"Unknown file '{filename}'.", filename or None, "" first_time = filename not in self._state.inspected_files if first_time: self._state.inspected_files.append(filename) excerpt = self._format_file(task.repository[filename]) unmatched_in_file = any( vuln.filename == filename and vuln.id not in self._state.matched_vulnerability_ids for vuln in task.vulnerabilities ) if first_time and unmatched_in_file: reward = 0.04 feedback = "Useful inspection: this file likely contains unresolved security issues." elif first_time: reward = 0.02 feedback = "Inspection noted. No strong security signal yet." else: reward = 0.0 feedback = "File already inspected; repeated reads do not improve score." self._degrade_quality(0.01) return reward, feedback, filename, excerpt def _handle_submit_finding( self, action: CodeSecurityAction, task: TaskSpec, ) -> tuple[float, str]: required_missing = [] if not action.filename: required_missing.append("filename") if action.line_start is None: required_missing.append("line_start") if not action.vuln_type: required_missing.append("vuln_type") if not action.severity: required_missing.append("severity") if required_missing: self._degrade_quality(0.05) missing = ", ".join(required_missing) return 0.0, f"Incomplete finding. Missing fields: {missing}." line_end = action.line_end if action.line_end is not None else action.line_start evaluation = evaluate_finding( task=task, filename=action.filename, vuln_type=action.vuln_type, severity=action.severity, line_start=action.line_start, line_end=line_end, confidence=action.confidence, matched_already=self._state.matched_vulnerability_ids, ) finding_id = f"finding-{len(self._state.findings_submitted) + 1}" finding_record = FindingRecord( finding_id=finding_id, filename=action.filename, line_start=action.line_start, line_end=line_end, vuln_type=action.vuln_type, severity=action.severity, confidence=action.confidence, evidence=(action.evidence or "").strip(), summary=(action.summary or "").strip(), matched_vulnerability_id=evaluation.matched_vulnerability_id, component_score=evaluation.component_score, ) self._state.findings_submitted.append(finding_record) if evaluation.is_confirmed_match and evaluation.matched_vulnerability_id is not None: self._state.matched_vulnerability_ids.append(evaluation.matched_vulnerability_id) reward = min(1.0, (0.25 + 0.75 * evaluation.component_score) * self._state.quality_multiplier) feedback = ( f"{evaluation.feedback} " f"Confirmed={len(self._state.matched_vulnerability_ids)}/{len(task.vulnerabilities)}." ) return reward, feedback if ( evaluation.matched_vulnerability_id is not None and evaluation.matched_vulnerability_id in self._state.matched_vulnerability_ids ): self._state.duplicate_submission_count += 1 self._degrade_quality(0.04) return 0.01, evaluation.feedback if evaluation.component_score >= 0.45: self._degrade_quality(0.01) reward = min(0.2, 0.2 * evaluation.component_score * self._state.quality_multiplier) return reward, f"Partial progress: {evaluation.feedback}" self._state.false_positive_count += 1 self._degrade_quality(0.05) return 0.0, f"Likely false positive: {evaluation.feedback}" def _handle_submit_final_report(self) -> tuple[float, str]: task = self._require_task() score = self._compute_final_score(task) self._state.final_score = score feedback = ( f"Audit finalized. Final deterministic score={score:.3f}. " f"Confirmed {len(self._state.matched_vulnerability_ids)} of {len(task.vulnerabilities)} vulnerabilities." ) return score, feedback def _compute_final_score(self, task: TaskSpec) -> float: if self._state.findings_submitted: avg_component = sum(f.component_score for f in self._state.findings_submitted) / len( self._state.findings_submitted ) else: avg_component = 0.0 if self._state.findings_submitted: avg_calibration = sum( max(0.0, 1.0 - abs(f.confidence - 0.75)) for f in self._state.findings_submitted ) / len(self._state.findings_submitted) else: avg_calibration = 0.0 score = final_grade( task=task, confirmed_vulnerability_ids=self._state.matched_vulnerability_ids, findings_count=len(self._state.findings_submitted), false_positive_count=self._state.false_positive_count, duplicate_count=self._state.duplicate_submission_count, avg_component_score=avg_component, avg_confidence_calibration=avg_calibration, ) # This quality factor makes spam and random guesses strictly dominated, # limiting reward hacking while preserving partial-credit gradients. score *= self._state.quality_multiplier return max(self.MIN_STRICT_SCORE, min(self.MAX_STRICT_SCORE, score)) def _build_observation( self, *, reward: float, done: bool, feedback: str, focused_file: Optional[str], excerpt: str, extra_metadata: Optional[dict[str, Any]] = None, ) -> CodeSecurityObservation: task = self._require_task() findings_public = [ { "finding_id": f.finding_id, "filename": f.filename, "line_start": f.line_start, "line_end": f.line_end, "vuln_type": f.vuln_type, "severity": f.severity, "confidence": f.confidence, "component_score": round(f.component_score, 3), } for f in self._state.findings_submitted ] score_hint = len(self._state.matched_vulnerability_ids) / max(1, len(task.vulnerabilities)) metadata = { "quality_multiplier": round(self._state.quality_multiplier, 4), "false_positive_count": self._state.false_positive_count, "duplicate_submission_count": self._state.duplicate_submission_count, "confirmed_vulnerabilities": len(self._state.matched_vulnerability_ids), "total_vulnerabilities": len(task.vulnerabilities), "task_id": task.id, "difficulty": task.difficulty, "available_task_ids": list_task_ids(), "last_action_error": None, } if extra_metadata: metadata.update(extra_metadata) return CodeSecurityObservation( done=done, reward=max(0.0, min(1.0, reward)), metadata=metadata, task_id=task.id, task_title=task.title, difficulty=task.difficulty, objective=task.objective, instructions=( "Valid actions: inspect_file, submit_finding, submit_final_report. " "For submit_finding include filename, line_start/line_end, vuln_type, severity, confidence." ), available_files=sorted(task.repository.keys()), focused_file=focused_file, file_excerpt=excerpt, findings_so_far=findings_public, steps_remaining=max(0, self._state.max_steps - self._state.step_count), last_feedback=feedback, score_hint=max(0.0, min(1.0, score_hint)), )