ArshVerma's picture
Initial CodeLens OpenEnv submission
f8670cd
Raw
History Blame Contribute Delete
7.57 kB
from datetime import datetime, timezone
from typing import List, Optional, Set
from codelens_env.models import (
TaskId, Action, Observation, StepResult, ResetResult,
ActionType, ActionRecord, EpisodeResult, Severity, GroundTruthIssue, Reward
)
from codelens_env.scenarios import get_scenario
from codelens_env.graders.bug_grader import grade_bug_detection
from codelens_env.graders.security_grader import grade_security_audit
from codelens_env.graders.arch_grader import grade_architectural_review
class CodeLensEnv:
MAX_NOISE_BUDGET = 5
TASK_MAX_STEPS = {
TaskId.BUG_DETECTION: 10,
TaskId.SECURITY_AUDIT: 15,
TaskId.ARCHITECTURAL_REVIEW: 20,
}
SEVERITY_WEIGHTS = {
Severity.CRITICAL: 1.0,
Severity.HIGH: 0.8,
Severity.MEDIUM: 0.5,
Severity.LOW: 0.2,
Severity.INFO: 0.0,
}
def __init__(self):
self.task_id: Optional[TaskId] = None
self.seed: int = 42
self.scenario = None
self.step_count: int = 0
self.noise_budget: int = self.MAX_NOISE_BUDGET
self.history: List[ActionRecord] = []
self.matched_issue_ids: Set[str] = set()
self.done: bool = False
self.terminated_reason: str = ""
self.episode_id: str = ""
def reset(self, task_id: TaskId, seed: int = 42) -> ResetResult:
self.scenario = get_scenario(task_id, seed)
self.task_id = task_id
self.seed = seed
self.step_count = 0
self.noise_budget = self.MAX_NOISE_BUDGET
self.history = []
self.matched_issue_ids = set()
self.done = False
self.terminated_reason = ""
obs = self._build_observation()
return ResetResult(
task_id=task_id,
seed=seed,
scenario_hash=self.scenario.hash,
observation=obs
)
def step(self, action: Action) -> StepResult:
if self.done:
raise ValueError("Episode is already finished")
self.step_count += 1
reward = 0.0
match = None # Track matched ground truth issue (if any)
# Determine terminal state and reward
if action.action_type in (ActionType.APPROVE, ActionType.REQUEST_CHANGES):
self.done = True
self.terminated_reason = "terminal_action"
reward = 0.0 # Grader handles final score
elif action.action_type == ActionType.FLAG_ISSUE:
match = None
for issue in self.scenario.ground_truth_issues:
if self._is_match(action, issue):
match = issue
break
if match:
if match.id not in self.matched_issue_ids:
reward = self.SEVERITY_WEIGHTS.get(match.severity, 0.0)
self.matched_issue_ids.add(match.id)
else:
# Already matched: penalty
reward = -0.05
self.noise_budget -= 1
if self.noise_budget <= 0:
self.done = True
self.terminated_reason = "noise_exhausted"
else:
# False positive: penalty
reward = -0.05
self.noise_budget -= 1
if self.noise_budget <= 0:
self.done = True
self.terminated_reason = "noise_exhausted"
# Max steps check
max_steps = self.TASK_MAX_STEPS.get(self.task_id, 10)
if not self.done and self.step_count >= max_steps:
self.done = True
self.terminated_reason = "max_steps"
# Build reward reason
if action.action_type in (ActionType.APPROVE, ActionType.REQUEST_CHANGES):
reward_reason = "Terminal action submitted"
elif action.action_type == ActionType.FLAG_ISSUE:
if match and match.id in self.matched_issue_ids and reward > 0:
reward_reason = f"Correctly identified issue: {match.description[:60]}"
elif match and reward < 0:
reward_reason = "Duplicate issue flagged"
elif not match:
reward_reason = "False positive: no matching ground truth issue"
else:
reward_reason = f"Matched issue {match.id}" if match else "No match"
else:
reward_reason = "Non-scoring action"
# Record action
record = ActionRecord(
action_type=action.action_type,
body=action.body,
filename=action.filename,
line_number=action.line_number,
category=action.category,
severity=action.severity,
verdict=action.verdict,
reward=float(reward),
timestamp=datetime.now(timezone.utc).isoformat()
)
self.history.append(record)
return StepResult(
observation=self._build_observation(),
reward=float(reward),
reward_info=Reward(
value=float(max(0.0, reward)),
reason=reward_reason,
is_terminal=self.done
),
done=self.done,
info={"terminated_reason": self.terminated_reason}
)
def _is_match(self, action: Action, issue: GroundTruthIssue) -> bool:
if action.filename != issue.filename:
return False
if action.line_number is None:
return False
if abs(action.line_number - issue.line_number) > 3:
return False
if action.category != issue.category:
return False
body_lower = (action.body or "").lower()
return any(kw.lower() in body_lower for kw in issue.keywords)
def _build_observation(self) -> Observation:
max_steps = self.TASK_MAX_STEPS.get(self.task_id, 10)
diff = "\n".join(f.patch for f in self.scenario.files_changed)
return Observation(
task_id=self.task_id,
scenario_hash=self.scenario.hash,
pr_title=self.scenario.pr_title,
pr_description=self.scenario.pr_description,
diff=diff,
files_changed=self.scenario.files_changed,
step_count=self.step_count,
max_steps=max_steps,
noise_budget=self.noise_budget,
max_noise_budget=self.MAX_NOISE_BUDGET,
issues_flagged=len(self.matched_issue_ids),
)
def state(self) -> Observation:
return self._build_observation()
def get_final_result(self) -> EpisodeResult:
if self.task_id == TaskId.BUG_DETECTION:
final_score = grade_bug_detection(self.scenario, self.history)
elif self.task_id == TaskId.SECURITY_AUDIT:
final_score = grade_security_audit(self.scenario, self.history)
else:
final_score = grade_architectural_review(self.scenario, self.history)
return EpisodeResult(
episode_id=self.episode_id,
task_id=self.task_id,
scenario_hash=self.scenario.hash,
seed=self.seed,
final_score=round(final_score, 4),
steps_taken=self.step_count,
issues_found=len(self.matched_issue_ids),
issues_total=len(self.scenario.ground_truth_issues),
noise_penalties=self.MAX_NOISE_BUDGET - self.noise_budget,
history=self.history,
terminated_reason=self.terminated_reason
)