CodeSecure / server /security_environment.py
Hassan Shaikh
fix: enforce strict open-interval task scores
2916eb9
from __future__ import annotations
import random
import uuid
from typing import Any, Optional
try:
from core.env_server.interfaces import Environment
except ImportError:
try:
from openenv.core.env_server.interfaces import Environment
except ImportError:
from openenv_core.env_server.interfaces import Environment
try:
from ..models import (
CodeSecurityAction,
CodeSecurityObservation,
CodeSecurityState,
FindingRecord,
)
from .grader import evaluate_finding, final_grade
from .tasks import TaskSpec, get_task, list_task_ids
except ImportError:
from models import (
CodeSecurityAction,
CodeSecurityObservation,
CodeSecurityState,
FindingRecord,
)
from server.grader import evaluate_finding, final_grade
from server.tasks import TaskSpec, get_task, list_task_ids
class CodeSecurityAuditorEnvironment(
    Environment[CodeSecurityAction, CodeSecurityObservation, CodeSecurityState]
):
    """Real-world code security auditing simulator with deterministic graders.

    An episode starts with :meth:`reset`, which selects a task, and advances
    with :meth:`step`. Three action types are handled: ``inspect_file``,
    ``submit_finding`` and ``submit_final_report``. An episode ends when a
    final report is submitted or when the task's ``max_steps`` is reached, at
    which point a final score is computed and stored in the state.

    A per-episode ``quality_multiplier`` starts at 1.0 and is lowered by
    invalid, repeated, duplicate, or low-scoring actions; it scales both the
    reward for confirmed findings and the final score.
    """

    SUPPORTS_CONCURRENT_SESSIONS = True
    # Final scores are clamped into [MIN_STRICT_SCORE, MAX_STRICT_SCORE], so a
    # finished episode never reports exactly 0.0 or 1.0.
    MIN_STRICT_SCORE = 0.001
    MAX_STRICT_SCORE = 0.999
    # Lowest value the quality multiplier can be degraded to.
    _QUALITY_FLOOR = 0.2

    def __init__(self, default_task_id: Optional[str] = "easy"):
        """Create an environment with no active task.

        Args:
            default_task_id: Task used by :meth:`reset` when neither an
                explicit task nor a seed is given. A falsy value (``None`` or
                ``""``) makes :meth:`reset` cycle through ``list_task_ids()``
                instead.
        """
        # Let the base class set up whatever state it maintains.
        super().__init__()
        self._default_task_id = default_task_id
        self._task_cursor = 0
        self._task: Optional[TaskSpec] = None
        self._state = CodeSecurityState()
        # Message describing why the most recent action was rejected, if it was.
        self._last_action_error: Optional[str] = None

    def reset(
        self,
        seed: Optional[int] = None,
        episode_id: Optional[str] = None,
        **kwargs: Any,
    ) -> CodeSecurityObservation:
        """Start a new episode and return its first observation.

        Task selection, in order of precedence:

        1. ``task_id`` (or ``task``) keyword argument, if truthy.
        2. ``seed``: a task id is drawn with ``random.Random(seed)``.
        3. The ``default_task_id`` given to the constructor, if truthy.
        4. Round-robin over ``list_task_ids()`` using an internal cursor.

        Args:
            seed: Optional seed for the task choice.
            episode_id: Optional episode identifier; a UUID4 string is
                generated when omitted or empty.
            **kwargs: May carry ``task_id`` or ``task``.

        Returns:
            The initial observation (reward 0.0, not done).
        """
        requested_task = kwargs.get("task_id") or kwargs.get("task")
        if requested_task is not None:
            task = get_task(str(requested_task))
        elif seed is not None:
            rng = random.Random(seed)
            task = get_task(rng.choice(list_task_ids()))
        elif self._default_task_id:
            task = get_task(self._default_task_id)
        else:
            task_order = list_task_ids()
            task = get_task(task_order[self._task_cursor % len(task_order)])
            self._task_cursor += 1

        self._task = task
        self._last_action_error = None
        self._state = CodeSecurityState(
            episode_id=episode_id or str(uuid.uuid4()),
            step_count=0,
            task_id=task.id,
            task_title=task.title,
            difficulty=task.difficulty,
            objective=task.objective,
            max_steps=task.max_steps,
            inspected_files=[],
            findings_submitted=[],
            matched_vulnerability_ids=[],
            false_positive_count=0,
            duplicate_submission_count=0,
            quality_multiplier=1.0,
            final_score=None,
        )
        # ``task_id`` and ``available_task_ids`` are already part of the
        # metadata assembled by _build_observation, so nothing extra is passed.
        return self._build_observation(
            reward=0.0,
            done=False,
            feedback=(
                "Audit started. Use inspect_file before submit_finding. "
                "Finish with submit_final_report."
            ),
            focused_file=None,
            excerpt="",
        )

    def step(
        self,
        action: CodeSecurityAction,
        timeout_s: Optional[float] = None,
        **kwargs: Any,
    ) -> CodeSecurityObservation:
        """Apply one action and return the resulting observation.

        Stepping a finished episode returns a terminal observation with
        reward 0.0 and does not change any counters. Otherwise the step
        counter is incremented, the action is dispatched on
        ``action.action_type``, and the episode is auto-finalized when the
        step budget is exhausted.

        Args:
            action: The action to apply.
            timeout_s: Accepted but unused.
            **kwargs: Accepted but unused.

        Returns:
            The observation after the action. Its ``last_action_error``
            metadata entry holds the rejection message when the action was
            rejected as invalid, otherwise ``None``.

        Raises:
            RuntimeError: If :meth:`reset` has not been called yet.
        """
        del timeout_s, kwargs
        task = self._require_task()
        # An error only ever describes the action handled by this call.
        self._last_action_error = None

        if self._state.final_score is not None:
            return self._build_observation(
                reward=0.0,
                done=True,
                feedback="Episode already terminated. Call reset() to start a new task.",
                focused_file=None,
                excerpt="",
            )

        self._state.step_count += 1
        feedback = ""
        reward = 0.0
        focused_file = None
        excerpt = ""

        if action.action_type == "inspect_file":
            reward, feedback, focused_file, excerpt = self._handle_inspect_file(action, task)
        elif action.action_type == "submit_finding":
            reward, feedback = self._handle_submit_finding(action, task)
        elif action.action_type == "submit_final_report":
            reward, feedback = self._handle_submit_final_report()
        else:
            feedback = self._reject(0.03, f"Unsupported action_type={action.action_type}.")

        done = self._state.final_score is not None
        if not done and self._state.step_count >= self._state.max_steps:
            # Step budget exhausted: finalize now. The final score replaces
            # whatever reward and feedback the action itself produced.
            score = self._compute_final_score(task)
            self._state.final_score = score
            done = True
            reward = score
            feedback = (
                f"Max steps reached. Auto-finalized audit score={score:.3f}. "
                "Use fewer but higher-quality findings to improve precision."
            )

        return self._build_observation(
            reward=reward,
            done=done,
            feedback=feedback,
            focused_file=focused_file,
            excerpt=excerpt,
        )

    @property
    def state(self) -> CodeSecurityState:
        """Return the live state object of the current episode."""
        return self._state

    def _require_task(self) -> TaskSpec:
        """Return the active task.

        Raises:
            RuntimeError: If no task is active because reset() was never called.
        """
        if self._task is None:
            raise RuntimeError("Environment has no active task. Call reset() first.")
        return self._task

    def _degrade_quality(self, amount: float) -> None:
        """Lower the quality multiplier by ``amount``, never below the floor."""
        self._state.quality_multiplier = max(
            self._QUALITY_FLOOR, self._state.quality_multiplier - amount
        )

    def _reject(self, penalty: float, message: str) -> str:
        """Penalize an invalid action and record why it was rejected.

        Args:
            penalty: Amount to subtract from the quality multiplier.
            message: Human-readable reason; stored as the last action error.

        Returns:
            ``message`` unchanged, so callers can use it as feedback.
        """
        self._degrade_quality(penalty)
        self._last_action_error = message
        return message

    def _format_file(self, content: str) -> str:
        """Return ``content`` with each line prefixed by its 1-based number."""
        return "\n".join(
            f"{number:>3}: {line}"
            for number, line in enumerate(content.splitlines(), start=1)
        )

    def _handle_inspect_file(
        self,
        action: CodeSecurityAction,
        task: TaskSpec,
    ) -> tuple[float, str, Optional[str], str]:
        """Handle an ``inspect_file`` action.

        Returns:
            ``(reward, feedback, focused_file, excerpt)``. An unknown filename
            is rejected with reward 0.0 and an empty excerpt. A first
            inspection earns 0.04 if the file still holds an unmatched
            vulnerability, else 0.02. Re-inspecting earns nothing and lowers
            the quality multiplier.
        """
        filename = action.filename or ""
        if filename not in task.repository:
            message = self._reject(0.04, f"Unknown file '{filename}'.")
            return 0.0, message, filename or None, ""

        first_time = filename not in self._state.inspected_files
        if first_time:
            self._state.inspected_files.append(filename)

        excerpt = self._format_file(task.repository[filename])
        unmatched_in_file = any(
            vuln.filename == filename and vuln.id not in self._state.matched_vulnerability_ids
            for vuln in task.vulnerabilities
        )

        if first_time and unmatched_in_file:
            reward = 0.04
            feedback = "Useful inspection: this file likely contains unresolved security issues."
        elif first_time:
            reward = 0.02
            feedback = "Inspection noted. No strong security signal yet."
        else:
            reward = 0.0
            feedback = "File already inspected; repeated reads do not improve score."
            # NOTE(review): penalty assumed to apply to repeated reads only;
            # the original indentation was lost -- confirm against upstream.
            self._degrade_quality(0.01)
        return reward, feedback, filename, excerpt

    def _handle_submit_finding(
        self,
        action: CodeSecurityAction,
        task: TaskSpec,
    ) -> tuple[float, str]:
        """Handle a ``submit_finding`` action.

        A finding missing ``filename``, ``line_start``, ``vuln_type`` or
        ``severity`` is rejected without being recorded. Otherwise it is
        graded with ``evaluate_finding``, appended to the submitted findings,
        and classified as one of: confirmed match, duplicate of an already
        confirmed vulnerability, partial progress (component score >= 0.45),
        or likely false positive.

        Returns:
            ``(reward, feedback)`` for this finding.
        """
        required_fields = (
            ("filename", bool(action.filename)),
            ("line_start", action.line_start is not None),
            ("vuln_type", bool(action.vuln_type)),
            ("severity", bool(action.severity)),
        )
        required_missing = [name for name, present in required_fields if not present]
        if required_missing:
            missing = ", ".join(required_missing)
            return 0.0, self._reject(0.05, f"Incomplete finding. Missing fields: {missing}.")

        # A finding without an explicit end line covers a single line.
        line_end = action.line_end if action.line_end is not None else action.line_start
        evaluation = evaluate_finding(
            task=task,
            filename=action.filename,
            vuln_type=action.vuln_type,
            severity=action.severity,
            line_start=action.line_start,
            line_end=line_end,
            confidence=action.confidence,
            matched_already=self._state.matched_vulnerability_ids,
        )

        finding_id = f"finding-{len(self._state.findings_submitted) + 1}"
        finding_record = FindingRecord(
            finding_id=finding_id,
            filename=action.filename,
            line_start=action.line_start,
            line_end=line_end,
            vuln_type=action.vuln_type,
            severity=action.severity,
            confidence=action.confidence,
            evidence=(action.evidence or "").strip(),
            summary=(action.summary or "").strip(),
            matched_vulnerability_id=evaluation.matched_vulnerability_id,
            component_score=evaluation.component_score,
        )
        # Every complete finding is recorded, whatever its classification, so
        # it counts towards the averages used by the final score.
        self._state.findings_submitted.append(finding_record)

        matched_id = evaluation.matched_vulnerability_id
        if evaluation.is_confirmed_match and matched_id is not None:
            self._state.matched_vulnerability_ids.append(matched_id)
            reward = min(
                1.0,
                (0.25 + 0.75 * evaluation.component_score) * self._state.quality_multiplier,
            )
            feedback = (
                f"{evaluation.feedback} "
                f"Confirmed={len(self._state.matched_vulnerability_ids)}/{len(task.vulnerabilities)}."
            )
            return reward, feedback

        if matched_id is not None and matched_id in self._state.matched_vulnerability_ids:
            self._state.duplicate_submission_count += 1
            self._degrade_quality(0.04)
            return 0.01, evaluation.feedback

        if evaluation.component_score >= 0.45:
            # The penalty is applied first, so the partial reward below is
            # scaled by the already-lowered multiplier.
            self._degrade_quality(0.01)
            reward = min(0.2, 0.2 * evaluation.component_score * self._state.quality_multiplier)
            return reward, f"Partial progress: {evaluation.feedback}"

        self._state.false_positive_count += 1
        self._degrade_quality(0.05)
        return 0.0, f"Likely false positive: {evaluation.feedback}"

    def _handle_submit_final_report(self) -> tuple[float, str]:
        """Finalize the episode and return ``(final_score, feedback)``."""
        task = self._require_task()
        score = self._compute_final_score(task)
        self._state.final_score = score
        feedback = (
            f"Audit finalized. Final deterministic score={score:.3f}. "
            f"Confirmed {len(self._state.matched_vulnerability_ids)} of {len(task.vulnerabilities)} vulnerabilities."
        )
        return score, feedback

    def _compute_final_score(self, task: TaskSpec) -> float:
        """Compute the episode's final score from the current state.

        Averages the component score and a confidence-calibration term over
        all submitted findings (both 0.0 when there are none), passes them to
        ``final_grade`` with the match/false-positive/duplicate counters,
        scales the result by the quality multiplier and clamps it into
        ``[MIN_STRICT_SCORE, MAX_STRICT_SCORE]``.
        """
        findings = self._state.findings_submitted
        if findings:
            count = len(findings)
            avg_component = sum(f.component_score for f in findings) / count
            # Calibration is 1 - |confidence - 0.75|, floored at 0: it is
            # highest for a stated confidence of exactly 0.75.
            avg_calibration = (
                sum(max(0.0, 1.0 - abs(f.confidence - 0.75)) for f in findings) / count
            )
        else:
            avg_component = 0.0
            avg_calibration = 0.0

        score = final_grade(
            task=task,
            confirmed_vulnerability_ids=self._state.matched_vulnerability_ids,
            findings_count=len(findings),
            false_positive_count=self._state.false_positive_count,
            duplicate_count=self._state.duplicate_submission_count,
            avg_component_score=avg_component,
            avg_confidence_calibration=avg_calibration,
        )
        # This quality factor makes spam and random guesses strictly dominated,
        # limiting reward hacking while preserving partial-credit gradients.
        score *= self._state.quality_multiplier
        return max(self.MIN_STRICT_SCORE, min(self.MAX_STRICT_SCORE, score))

    def _build_observation(
        self,
        *,
        reward: float,
        done: bool,
        feedback: str,
        focused_file: Optional[str],
        excerpt: str,
        extra_metadata: Optional[dict[str, Any]] = None,
    ) -> CodeSecurityObservation:
        """Assemble the observation returned to the agent.

        Args:
            reward: Reward for the step; clamped into [0.0, 1.0].
            done: Whether the episode has ended.
            feedback: Message stored as ``last_feedback``.
            focused_file: File the observation is about, if any.
            excerpt: Numbered file contents, or an empty string.
            extra_metadata: Optional entries merged over the default
                metadata; keys given here override the defaults.

        Raises:
            RuntimeError: If no task is active.
        """
        task = self._require_task()
        # Only a subset of each finding is exposed; evidence, summary and the
        # matched vulnerability id are left out.
        findings_public = [
            {
                "finding_id": f.finding_id,
                "filename": f.filename,
                "line_start": f.line_start,
                "line_end": f.line_end,
                "vuln_type": f.vuln_type,
                "severity": f.severity,
                "confidence": f.confidence,
                "component_score": round(f.component_score, 3),
            }
            for f in self._state.findings_submitted
        ]
        confirmed_count = len(self._state.matched_vulnerability_ids)
        total_count = len(task.vulnerabilities)
        # max(1, ...) avoids dividing by zero for a task with no vulnerabilities.
        score_hint = confirmed_count / max(1, total_count)

        metadata = {
            "quality_multiplier": round(self._state.quality_multiplier, 4),
            "false_positive_count": self._state.false_positive_count,
            "duplicate_submission_count": self._state.duplicate_submission_count,
            "confirmed_vulnerabilities": confirmed_count,
            "total_vulnerabilities": total_count,
            "task_id": task.id,
            "difficulty": task.difficulty,
            "available_task_ids": list_task_ids(),
            "last_action_error": self._last_action_error,
        }
        if extra_metadata:
            metadata.update(extra_metadata)

        return CodeSecurityObservation(
            done=done,
            reward=max(0.0, min(1.0, reward)),
            metadata=metadata,
            task_id=task.id,
            task_title=task.title,
            difficulty=task.difficulty,
            objective=task.objective,
            instructions=(
                "Valid actions: inspect_file, submit_finding, submit_final_report. "
                "For submit_finding include filename, line_start/line_end, vuln_type, severity, confidence."
            ),
            available_files=sorted(task.repository.keys()),
            focused_file=focused_file,
            file_excerpt=excerpt,
            findings_so_far=findings_public,
            steps_remaining=max(0, self._state.max_steps - self._state.step_count),
            last_feedback=feedback,
            score_hint=max(0.0, min(1.0, score_hint)),
        )