from typing import Dict, Any, Tuple, Optional

from environment.models import ReviewAction, ReviewState, Observation
from environment.tasks import TaskDefinitions
from environment.graders import TaskGrader, RewardCalculator


class CodeReviewEnv:
    """Episodic environment in which an agent reviews a code change.

    Each episode is built from a task loaded through ``TaskDefinitions``.
    The agent submits actions (dicts parsed into ``ReviewAction``) that add
    comments, suggest fixes, mark comments as resolved, or end the review
    with a decision. Per-step rewards come from ``RewardCalculator``; the
    task score comes from ``TaskGrader``.
    """

    DEFAULT_TASK_ID = "bug_detection_easy_1"
    DEFAULT_MAX_STEPS = 50

    def __init__(self, max_steps: int = DEFAULT_MAX_STEPS):
        """Create an uninitialized environment; call ``reset()`` before ``step()``.

        Args:
            max_steps: Number of successfully parsed actions after which the
                episode is force-completed. Defaults to 50.
        """
        self._state: Optional[ReviewState] = None
        self.grader: Optional[TaskGrader] = None
        self.reward_calculator = RewardCalculator()
        self.max_steps = max_steps
        self.current_task_id: Optional[str] = None

    def reset(self, task_id: Optional[str] = None) -> Dict[str, Any]:
        """Start a new episode and return its initial observation.

        Args:
            task_id: Identifier of the task to load; ``DEFAULT_TASK_ID`` is
                used when None.

        Returns:
            The initial observation as a plain dict.
        """
        if task_id is None:
            task_id = self.DEFAULT_TASK_ID

        # Load the task before mutating any attributes so that, if loading
        # raises, current_task_id is not left naming a task whose state was
        # never built.
        task_data = TaskDefinitions.get_task(task_id)
        code_context = TaskDefinitions.create_code_context(task_data)
        task_metadata = TaskDefinitions.create_task_metadata(task_data)

        self.current_task_id = task_id
        self._state = ReviewState(
            code_context=code_context,
            task_metadata=task_metadata,
            comments_made=[],
            suggestions_made=[],
            current_step=0,
            is_complete=False,
            final_decision=None,
            last_action_valid=True,
            last_error=None,
        )
        self.grader = TaskGrader(task_metadata.expected_issues)
        self.reward_calculator.reset()
        return self._get_observation()

    def step(
        self, action: Dict[str, Any]
    ) -> Tuple[Dict[str, Any], float, bool, Dict[str, Any]]:
        """Apply one agent action and advance the episode.

        Args:
            action: Keyword arguments used to construct a ``ReviewAction``.

        Returns:
            A ``(observation, reward, done, info)`` tuple.
            - Uninitialized environment: ``({}, -0.1, True, {...})``.
            - Episode already complete: reward 0.0 and ``done=True``.
            - Action that fails to parse: reward -0.1 and ``done=False``; the
              step counter is not advanced.
            - Otherwise: the reward comes from the reward calculator, and
              ``info`` carries the step, validity, error and task score.
        """
        if self._state is None:
            return {}, -0.1, True, {"error": "Environment not initialized. Call reset() first."}
        if self._state.is_complete:
            return self._get_observation(), 0.0, True, {"error": "Episode already complete"}

        try:
            review_action = ReviewAction(**action)
        except Exception as e:
            # Boundary with agent-supplied input: any construction failure
            # (bad keys, bad values, non-mapping input) is reported back to
            # the agent instead of crashing. Such failures do not increment
            # current_step, so they do not count toward max_steps.
            self._state.last_action_valid = False
            self._state.last_error = str(e)
            return self._get_observation(), -0.1, False, {"error": str(e), "last_action_valid": False}

        self._state.current_step += 1
        self._process_action(review_action)

        # Terminal action types imply a decision when none was given. The
        # action is updated in place so that the action passed to the reward
        # calculator below carries the implied decision.
        if not review_action.final_decision:
            action_type = review_action.action_type.value
            if action_type == "approve":
                review_action.final_decision = "approved"
            elif action_type == "request_changes":
                review_action.final_decision = "changes_requested"

        if review_action.final_decision:
            # The agent's decision ends the episode. It is applied before the
            # step-limit check so that a decision made on the last allowed
            # step is not overwritten by the default below.
            self._state.is_complete = True
            self._state.final_decision = review_action.final_decision
        elif self._state.current_step >= self.max_steps:
            # Out of steps without a decision: close the review, defaulting
            # to requesting changes.
            self._state.is_complete = True
            if not self._state.final_decision:
                self._state.final_decision = "changes_requested"

        reward = self.reward_calculator.calculate_reward(
            review_action,
            self._state.comments_made,
            self._state.suggestions_made,
            self._state.final_decision or "changes_requested",
            self.grader,
            self._state.last_action_valid,
        )

        info = {
            "step": self._state.current_step,
            "last_action_valid": self._state.last_action_valid,
            "error": self._state.last_error,
            "task_score": self.get_task_score(),
        }
        return self._get_observation(), reward, self._state.is_complete, info

    def _line_in_range(self, line_number: int) -> bool:
        """Return True if ``line_number`` is a 1-based line of the code under review.

        Must only be called while ``self._state`` is set.
        """
        return 1 <= line_number <= self._state.code_context.line_count

    def _process_action(self, action: ReviewAction) -> None:
        """Record the effects of ``action`` on the current review state.

        Validity flags are reset first. For ``add_comment`` and
        ``suggest_fix``, each entry whose line is out of range marks the
        action invalid (``last_error`` names the last such line); in-range
        entries of the same action are still recorded. ``mark_as_resolved``
        flags every recorded comment whose line number matches one of the
        action's comments. Other action types change nothing here.
        """
        if self._state is None:
            return
        state = self._state
        state.last_action_valid = True
        state.last_error = None

        action_type = action.action_type.value
        if action_type == "add_comment":
            for comment in action.comments:
                if self._line_in_range(comment.line_number):
                    state.comments_made.append(comment)
                else:
                    state.last_action_valid = False
                    state.last_error = f"Line {comment.line_number} out of range"
        elif action_type == "suggest_fix":
            for suggestion in action.suggestions:
                if self._line_in_range(suggestion.original_line):
                    state.suggestions_made.append(suggestion)
                else:
                    state.last_action_valid = False
                    state.last_error = f"Line {suggestion.original_line} out of range"
        elif action_type == "mark_as_resolved":
            lines_to_resolve = {comment.line_number for comment in action.comments}
            for existing_comment in state.comments_made:
                if existing_comment.line_number in lines_to_resolve:
                    existing_comment.resolved = True

    def _get_observation(self) -> Dict[str, Any]:
        """Build the agent-facing observation as a dict ({} before ``reset()``)."""
        if self._state is None:
            return {}
        ctx = self._state.code_context
        meta = self._state.task_metadata
        return Observation(
            code_diff=ctx.code_diff,
            file_context=ctx.surrounding_code,
            file_path=ctx.file_path,
            language=ctx.language,
            task_description=meta.description,
            task_difficulty=meta.difficulty,
            current_step=self._state.current_step,
            max_steps=self.max_steps,
            previous_comments=self._state.comments_made,
            previous_suggestions=self._state.suggestions_made,
            review_complete=self._state.is_complete,
            final_decision_made=self._state.final_decision,
        ).model_dump()

    def get_task_score(self) -> float:
        """Return the grader's score for the review so far (0.0 before ``reset()``).

        A missing final decision is scored as ``"changes_requested"``.
        """
        if not self.grader or self._state is None:
            return 0.0
        return self.grader.compute_score_from_state(
            comments=self._state.comments_made,
            suggestions=self._state.suggestions_made,
            final_decision=self._state.final_decision or "changes_requested",
        )

    def close(self):
        """Release resources; this environment holds none, so this is a no-op."""
        pass

    def state(self) -> Dict[str, Any]:
        """Return the full internal review state as a dict ({} before ``reset()``)."""
        if self._state:
            return self._state.model_dump()
        return {}