Spaces:
Sleeping
Sleeping
from __future__ import annotations

import random
import uuid
from typing import Any, Optional

# The Environment base class is resolved from whichever OpenEnv layout is
# installed: in-repo ``core``, the ``openenv`` package, or ``openenv_core``.
try:
    from core.env_server.interfaces import Environment
except ImportError:
    try:
        from openenv.core.env_server.interfaces import Environment
    except ImportError:
        from openenv_core.env_server.interfaces import Environment

# Package-relative imports when loaded as part of the package; otherwise fall
# back to top-level module names (e.g. when run from the project root).
try:
    from ..models import (
        CodeSecurityAction,
        CodeSecurityObservation,
        CodeSecurityState,
        FindingRecord,
    )
    from .grader import evaluate_finding, final_grade
    from .tasks import TaskSpec, get_task, list_task_ids
except ImportError:
    from models import (
        CodeSecurityAction,
        CodeSecurityObservation,
        CodeSecurityState,
        FindingRecord,
    )
    from server.grader import evaluate_finding, final_grade
    from server.tasks import TaskSpec, get_task, list_task_ids
class CodeSecurityAuditorEnvironment(
    Environment[CodeSecurityAction, CodeSecurityObservation, CodeSecurityState]
):
    """Real-world code security auditing simulator with deterministic graders.

    Each episode audits one task repository. Supported ``action_type`` values:

    * ``inspect_file`` -- return a line-numbered listing of one repository file;
    * ``submit_finding`` -- report a suspected vulnerability, graded by
      ``evaluate_finding``;
    * ``submit_final_report`` -- end the episode; the reward is the final score
      derived from ``final_grade``.

    Wasteful or careless actions lower ``quality_multiplier``. These include
    unknown files, repeated reads, and incomplete, duplicate, partial or
    false-positive findings, as well as unsupported action types. The
    multiplier scales the reward for confirmed and partially matching findings,
    and it also scales the final score. Once ``max_steps`` steps have been
    taken, the audit is finalized automatically.
    """

    SUPPORTS_CONCURRENT_SESSIONS = True
    # Bounds for the final score, kept strictly inside (0, 1).
    MIN_STRICT_SCORE = 0.001
    MAX_STRICT_SCORE = 0.999
    # Floor for the quality multiplier so penalties can never drive it to zero.
    MIN_QUALITY_MULTIPLIER = 0.2

    def __init__(self, default_task_id: str = "easy"):
        """Create an environment with no active task.

        Args:
            default_task_id: Task loaded by ``reset()`` when neither a task nor a
                seed is requested. Pass an empty string to rotate through
                ``list_task_ids()`` in order instead.
        """
        # Let the base class set up its own attributes.
        # NOTE(review): assumes the base ``__init__`` has no required arguments.
        super().__init__()
        self._default_task_id = default_task_id
        self._task_cursor = 0  # Round-robin position, used only when no default task.
        self._task: Optional[TaskSpec] = None
        self._state = CodeSecurityState()

    def reset(
        self,
        seed: Optional[int] = None,
        episode_id: Optional[str] = None,
        **kwargs: Any,
    ) -> CodeSecurityObservation:
        """Start a new episode and return its initial observation.

        The task is selected by the first rule that applies:

        1. ``kwargs["task_id"]`` (or ``kwargs["task"]``) when provided;
        2. a task drawn with ``random.Random(seed)`` when ``seed`` is given;
        3. ``default_task_id`` when it is non-empty;
        4. otherwise the next task in round-robin order.

        Args:
            seed: Optional seed for deterministic random task selection.
            episode_id: Optional episode id; a random UUID4 is used if omitted.
            **kwargs: May carry ``task_id`` / ``task``.

        Returns:
            The initial observation, with zero reward and ``done=False``.
        """
        requested_task = kwargs.get("task_id") or kwargs.get("task")
        if requested_task is not None:
            task = get_task(str(requested_task))
        elif seed is not None:
            task = get_task(random.Random(seed).choice(list_task_ids()))
        elif self._default_task_id:
            task = get_task(self._default_task_id)
        else:
            task_order = list_task_ids()
            task = get_task(task_order[self._task_cursor % len(task_order)])
            self._task_cursor += 1

        self._task = task
        self._state = CodeSecurityState(
            episode_id=episode_id or str(uuid.uuid4()),
            step_count=0,
            task_id=task.id,
            task_title=task.title,
            difficulty=task.difficulty,
            objective=task.objective,
            max_steps=task.max_steps,
            inspected_files=[],
            findings_submitted=[],
            matched_vulnerability_ids=[],
            false_positive_count=0,
            duplicate_submission_count=0,
            quality_multiplier=1.0,
            final_score=None,
        )
        return self._build_observation(
            reward=0.0,
            done=False,
            feedback=(
                "Audit started. Use inspect_file before submit_finding. "
                "Finish with submit_final_report."
            ),
            focused_file=None,
            excerpt="",
            extra_metadata={
                "available_task_ids": list_task_ids(),
                "task_id": task.id,
            },
        )

    def step(
        self,
        action: CodeSecurityAction,
        timeout_s: Optional[float] = None,
        **kwargs: Any,
    ) -> CodeSecurityObservation:
        """Apply one agent action and return the resulting observation.

        ``timeout_s`` and extra keyword arguments are accepted for interface
        compatibility and ignored. Once the episode has a final score, further
        calls are not counted as steps. They return ``done=True`` with zero
        reward.

        Raises:
            RuntimeError: If ``reset()`` has not been called yet.
        """
        del timeout_s, kwargs  # Unused; accepted for interface compatibility.
        task = self._require_task()

        if self._state.final_score is not None:
            return self._build_observation(
                reward=0.0,
                done=True,
                feedback="Episode already terminated. Call reset() to start a new task.",
                focused_file=None,
                excerpt="",
            )

        self._state.step_count += 1
        feedback = ""
        reward = 0.0
        focused_file: Optional[str] = None
        excerpt = ""
        last_action_error: Optional[str] = None

        if action.action_type == "inspect_file":
            reward, feedback, focused_file, excerpt = self._handle_inspect_file(action, task)
        elif action.action_type == "submit_finding":
            reward, feedback = self._handle_submit_finding(action, task)
        elif action.action_type == "submit_final_report":
            reward, feedback = self._handle_submit_final_report()
        else:
            feedback = f"Unsupported action_type={action.action_type}."
            # Surface the rejection so clients can tell it apart from a valid
            # action that simply earned no reward.
            last_action_error = feedback
            self._degrade_quality(0.03)

        done = self._state.final_score is not None
        if not done and self._state.step_count >= self._state.max_steps:
            # Step budget exhausted: finalize now. The final score replaces
            # this step's reward and feedback.
            score = self._compute_final_score(task)
            self._state.final_score = score
            done = True
            reward = score
            feedback = (
                f"Max steps reached. Auto-finalized audit score={score:.3f}. "
                "Use fewer but higher-quality findings to improve precision."
            )

        return self._build_observation(
            reward=reward,
            done=done,
            feedback=feedback,
            focused_file=focused_file,
            excerpt=excerpt,
            extra_metadata={"last_action_error": last_action_error},
        )

    def state(self) -> CodeSecurityState:
        """Return the live (mutable) state object of the current episode."""
        # NOTE(review): OpenEnv's ``Environment`` base appears to declare ``state``
        # as a read-only property, and its server reads ``env.state`` without
        # calling it; such callers would get a bound method here. Confirm
        # against the installed base class and convert to ``@property``
        # together with any ``env.state()`` call sites.
        return self._state

    def _require_task(self) -> TaskSpec:
        """Return the active task.

        Raises:
            RuntimeError: If ``reset()`` has not been called yet.
        """
        if self._task is None:
            raise RuntimeError("Environment has no active task. Call reset() first.")
        return self._task

    def _degrade_quality(self, amount: float) -> None:
        """Lower the quality multiplier by ``amount``, never below the floor."""
        self._state.quality_multiplier = max(
            self.MIN_QUALITY_MULTIPLIER, self._state.quality_multiplier - amount
        )

    def _format_file(self, content: str) -> str:
        """Return ``content`` with each line prefixed by its 1-based number.

        Numbers are right-aligned to width 3 and followed by ``": "``.
        """
        return "\n".join(
            f"{idx:>3}: {line}" for idx, line in enumerate(content.splitlines(), start=1)
        )

    def _handle_inspect_file(
        self,
        action: CodeSecurityAction,
        task: TaskSpec,
    ) -> tuple[float, str, Optional[str], str]:
        """Handle ``inspect_file``.

        Rewards the first read of a file more when it still contains unmatched
        vulnerabilities (0.04) than otherwise (0.02). Repeated reads earn
        nothing and slightly lower quality. Unknown files lower quality more.

        Returns:
            ``(reward, feedback, focused_file, excerpt)``; ``excerpt`` is the
            numbered file listing, or ``""`` for unknown files.
        """
        filename = action.filename or ""
        if filename not in task.repository:
            self._degrade_quality(0.04)
            return 0.0, f"Unknown file '{filename}'.", filename or None, ""

        first_time = filename not in self._state.inspected_files
        if first_time:
            self._state.inspected_files.append(filename)

        excerpt = self._format_file(task.repository[filename])
        unmatched_in_file = any(
            vuln.filename == filename and vuln.id not in self._state.matched_vulnerability_ids
            for vuln in task.vulnerabilities
        )

        if first_time and unmatched_in_file:
            reward = 0.04
            feedback = "Useful inspection: this file likely contains unresolved security issues."
        elif first_time:
            reward = 0.02
            feedback = "Inspection noted. No strong security signal yet."
        else:
            reward = 0.0
            feedback = "File already inspected; repeated reads do not improve score."
            self._degrade_quality(0.01)
        return reward, feedback, filename, excerpt

    def _handle_submit_finding(
        self,
        action: CodeSecurityAction,
        task: TaskSpec,
    ) -> tuple[float, str]:
        """Handle ``submit_finding``.

        A finding missing ``filename``, ``line_start``, ``vuln_type`` or
        ``severity`` is rejected without being recorded. Otherwise the
        finding is graded with ``evaluate_finding`` and recorded. It then
        falls into one of these outcomes:

        * confirmed match -- reward ``0.25 + 0.75 * component_score`` scaled by
          quality, capped at 1.0;
        * duplicate of an already-confirmed vulnerability -- 0.01;
        * partial match (``component_score >= 0.45``) -- small reward of at
          most 0.2;
        * otherwise a false positive -- no reward.

        Returns:
            ``(reward, feedback)``.
        """
        required_missing = []
        if not action.filename:
            required_missing.append("filename")
        if action.line_start is None:
            required_missing.append("line_start")
        if not action.vuln_type:
            required_missing.append("vuln_type")
        if not action.severity:
            required_missing.append("severity")
        if required_missing:
            self._degrade_quality(0.05)
            missing = ", ".join(required_missing)
            return 0.0, f"Incomplete finding. Missing fields: {missing}."

        # A single-line finding may omit line_end.
        line_end = action.line_end if action.line_end is not None else action.line_start
        evaluation = evaluate_finding(
            task=task,
            filename=action.filename,
            vuln_type=action.vuln_type,
            severity=action.severity,
            line_start=action.line_start,
            line_end=line_end,
            confidence=action.confidence,
            matched_already=self._state.matched_vulnerability_ids,
        )

        finding_id = f"finding-{len(self._state.findings_submitted) + 1}"
        self._state.findings_submitted.append(
            FindingRecord(
                finding_id=finding_id,
                filename=action.filename,
                line_start=action.line_start,
                line_end=line_end,
                vuln_type=action.vuln_type,
                severity=action.severity,
                confidence=action.confidence,
                evidence=(action.evidence or "").strip(),
                summary=(action.summary or "").strip(),
                matched_vulnerability_id=evaluation.matched_vulnerability_id,
                component_score=evaluation.component_score,
            )
        )

        if evaluation.is_confirmed_match and evaluation.matched_vulnerability_id is not None:
            self._state.matched_vulnerability_ids.append(evaluation.matched_vulnerability_id)
            reward = min(
                1.0,
                (0.25 + 0.75 * evaluation.component_score) * self._state.quality_multiplier,
            )
            feedback = (
                f"{evaluation.feedback} "
                f"Confirmed={len(self._state.matched_vulnerability_ids)}/{len(task.vulnerabilities)}."
            )
            return reward, feedback

        if (
            evaluation.matched_vulnerability_id is not None
            and evaluation.matched_vulnerability_id in self._state.matched_vulnerability_ids
        ):
            self._state.duplicate_submission_count += 1
            self._degrade_quality(0.04)
            return 0.01, evaluation.feedback

        if evaluation.component_score >= 0.45:
            self._degrade_quality(0.01)
            reward = min(0.2, 0.2 * evaluation.component_score * self._state.quality_multiplier)
            return reward, f"Partial progress: {evaluation.feedback}"

        self._state.false_positive_count += 1
        self._degrade_quality(0.05)
        return 0.0, f"Likely false positive: {evaluation.feedback}"

    def _handle_submit_final_report(self) -> tuple[float, str]:
        """Finalize the episode; the reward is the final score.

        Returns:
            ``(score, feedback)``.
        """
        task = self._require_task()
        score = self._compute_final_score(task)
        self._state.final_score = score
        feedback = (
            f"Audit finalized. Final deterministic score={score:.3f}. "
            f"Confirmed {len(self._state.matched_vulnerability_ids)} of {len(task.vulnerabilities)} vulnerabilities."
        )
        return score, feedback

    def _compute_final_score(self, task: TaskSpec) -> float:
        """Compute the final score for ``task`` from the submitted findings.

        Averages per-finding component scores and confidence calibration, and
        passes them to ``final_grade``. The result is scaled by the quality
        multiplier and clamped to ``[MIN_STRICT_SCORE, MAX_STRICT_SCORE]``.
        """
        findings = self._state.findings_submitted
        if findings:
            count = len(findings)
            avg_component = sum(f.component_score for f in findings) / count
            # Calibration is 1.0 at confidence 0.75 and falls off linearly.
            avg_calibration = sum(
                max(0.0, 1.0 - abs(f.confidence - 0.75)) for f in findings
            ) / count
        else:
            avg_component = 0.0
            avg_calibration = 0.0

        score = final_grade(
            task=task,
            confirmed_vulnerability_ids=self._state.matched_vulnerability_ids,
            findings_count=len(findings),
            false_positive_count=self._state.false_positive_count,
            duplicate_count=self._state.duplicate_submission_count,
            avg_component_score=avg_component,
            avg_confidence_calibration=avg_calibration,
        )
        # This quality factor makes spam and random guesses strictly dominated,
        # limiting reward hacking while preserving partial-credit gradients.
        score *= self._state.quality_multiplier
        return max(self.MIN_STRICT_SCORE, min(self.MAX_STRICT_SCORE, score))

    def _build_observation(
        self,
        *,
        reward: float,
        done: bool,
        feedback: str,
        focused_file: Optional[str],
        excerpt: str,
        extra_metadata: Optional[dict[str, Any]] = None,
    ) -> CodeSecurityObservation:
        """Assemble the observation returned to the agent.

        ``reward`` is clamped to ``[0, 1]``. Entries in ``extra_metadata``
        override the default metadata keys.

        Raises:
            RuntimeError: If no task is active.
        """
        task = self._require_task()
        findings_public = [
            {
                "finding_id": f.finding_id,
                "filename": f.filename,
                "line_start": f.line_start,
                "line_end": f.line_end,
                "vuln_type": f.vuln_type,
                "severity": f.severity,
                "confidence": f.confidence,
                "component_score": round(f.component_score, 3),
            }
            for f in self._state.findings_submitted
        ]
        score_hint = len(self._state.matched_vulnerability_ids) / max(1, len(task.vulnerabilities))
        metadata = {
            "quality_multiplier": round(self._state.quality_multiplier, 4),
            "false_positive_count": self._state.false_positive_count,
            "duplicate_submission_count": self._state.duplicate_submission_count,
            "confirmed_vulnerabilities": len(self._state.matched_vulnerability_ids),
            "total_vulnerabilities": len(task.vulnerabilities),
            "task_id": task.id,
            "difficulty": task.difficulty,
            "available_task_ids": list_task_ids(),
            "last_action_error": None,
        }
        if extra_metadata:
            metadata.update(extra_metadata)
        return CodeSecurityObservation(
            done=done,
            reward=max(0.0, min(1.0, reward)),
            metadata=metadata,
            task_id=task.id,
            task_title=task.title,
            difficulty=task.difficulty,
            objective=task.objective,
            instructions=(
                "Valid actions: inspect_file, submit_finding, submit_final_report. "
                "For submit_finding include filename, line_start/line_end, vuln_type, severity, confidence."
            ),
            available_files=sorted(task.repository.keys()),
            focused_file=focused_file,
            file_excerpt=excerpt,
            findings_so_far=findings_public,
            steps_remaining=max(0, self._state.max_steps - self._state.step_count),
            last_feedback=feedback,
            score_hint=max(0.0, min(1.0, score_hint)),
        )