| from typing import Dict, Any, Tuple, Optional |
| from environment.models import ( |
| ReviewAction, |
| ReviewState, |
| Observation |
| ) |
| from environment.tasks import TaskDefinitions |
| from environment.graders import TaskGrader, RewardCalculator |
|
|
|
|
class CodeReviewEnv:
    """Gym-style environment in which an agent reviews a code change.

    ``reset()`` loads a task and builds a fresh ``ReviewState``. Each
    ``step()`` validates an action dict into a ``ReviewAction``, applies it and
    returns ``(observation, reward, done, info)``. An episode ends when the
    agent records a final decision or when the step budget (``max_steps``) is
    spent.
    """

    DEFAULT_TASK_ID = "bug_detection_easy_1"
    # Decision recorded (and used for scoring) when none has been made yet.
    FALLBACK_DECISION = "changes_requested"
    # Decisions implied by terminal action types when the agent leaves
    # ``final_decision`` empty. Read-only lookup table.
    _IMPLIED_DECISIONS = {
        "approve": "approved",
        "request_changes": "changes_requested",
    }

    def __init__(self, max_steps: int = 50):
        """Create an uninitialised environment; call ``reset()`` before ``step()``.

        Args:
            max_steps: Step budget per episode. Once ``current_step`` reaches
                this value the episode is closed. Must be at least 1.

        Raises:
            ValueError: If ``max_steps`` is less than 1.
        """
        if max_steps < 1:
            raise ValueError(f"max_steps must be >= 1, got {max_steps}")
        self._state: Optional[ReviewState] = None
        self.grader: Optional[TaskGrader] = None
        self.reward_calculator = RewardCalculator()
        self.max_steps = max_steps
        self.current_task_id: Optional[str] = None

    def reset(self, task_id: Optional[str] = None) -> Dict[str, Any]:
        """Start a new episode and return the initial observation.

        Args:
            task_id: Identifier passed to ``TaskDefinitions.get_task``;
                ``DEFAULT_TASK_ID`` is used when omitted.

        Returns:
            The initial observation as a plain dict (see ``_get_observation``).
        """
        if task_id is None:
            task_id = self.DEFAULT_TASK_ID

        task_data = TaskDefinitions.get_task(task_id)
        code_context = TaskDefinitions.create_code_context(task_data)
        task_metadata = TaskDefinitions.create_task_metadata(task_data)
        # Only commit the new task id once the task has loaded successfully.
        self.current_task_id = task_id

        self._state = ReviewState(
            code_context=code_context,
            task_metadata=task_metadata,
            comments_made=[],
            suggestions_made=[],
            current_step=0,
            is_complete=False,
            final_decision=None,
            last_action_valid=True,
            last_error=None,
        )

        self.grader = TaskGrader(task_metadata.expected_issues)
        self.reward_calculator.reset()

        return self._get_observation()

    def step(self, action: Dict[str, Any]) -> Tuple[Dict[str, Any], float, bool, Dict[str, Any]]:
        """Apply one agent action and advance the episode.

        Args:
            action: Keyword arguments used to build a ``ReviewAction``.

        Returns:
            ``(observation, reward, done, info)``.

            - Before ``reset()``: returns an empty observation, a -0.1 penalty
              and ``done=True``.
            - After the episode has ended: returns reward 0.0 and ``done=True``
              without changing state.
            - On a malformed action: returns a -0.1 penalty. The step still
              counts toward ``max_steps``.
            - Otherwise: the reward comes from ``RewardCalculator``. ``info``
              carries ``step``, ``last_action_valid``, ``error`` and
              ``task_score``.
        """
        if self._state is None:
            return {}, -0.1, True, {"error": "Environment not initialized. Call reset() first."}

        if self._state.is_complete:
            return self._get_observation(), 0.0, True, {"error": "Episode already complete"}

        try:
            review_action = ReviewAction(**action)
        except Exception as e:  # agent-input boundary: penalise any validation failure instead of raising
            self._state.last_action_valid = False
            self._state.last_error = str(e)
            # Malformed actions consume budget too; otherwise an agent that keeps
            # sending invalid actions would never reach max_steps.
            self._state.current_step += 1
            self._close_if_budget_spent()
            return self._get_observation(), -0.1, self._state.is_complete, self._step_info()

        self._state.current_step += 1
        self._process_action(review_action)

        if not review_action.final_decision:
            implied = self._IMPLIED_DECISIONS.get(review_action.action_type.value)
            if implied:
                review_action.final_decision = implied

        # Record the agent's own decision *before* the budget check so that a
        # decision made on the last allowed step is not replaced by the fallback.
        if review_action.final_decision:
            self._state.is_complete = True
            self._state.final_decision = review_action.final_decision
        self._close_if_budget_spent()

        reward = self.reward_calculator.calculate_reward(
            review_action,
            self._state.comments_made,
            self._state.suggestions_made,
            self._state.final_decision or self.FALLBACK_DECISION,
            self.grader,
            self._state.last_action_valid,
        )

        return self._get_observation(), reward, self._state.is_complete, self._step_info()

    def _close_if_budget_spent(self) -> None:
        """Close the episode once ``max_steps`` is reached, defaulting the decision."""
        if self._state is None or self._state.current_step < self.max_steps:
            return
        self._state.is_complete = True
        if not self._state.final_decision:
            self._state.final_decision = self.FALLBACK_DECISION

    def _step_info(self) -> Dict[str, Any]:
        """Build the ``info`` dict that ``step`` returns for an initialised episode."""
        return {
            "step": self._state.current_step,
            "last_action_valid": self._state.last_action_valid,
            "error": self._state.last_error,
            "task_score": self.get_task_score(),
        }

    def _line_in_range(self, line: int) -> bool:
        """Return True if ``line`` is a line of the reviewed file.

        Line numbers are treated as 1-based: ``line_count`` itself is the last
        valid line.
        """
        return 1 <= line <= self._state.code_context.line_count

    def _reject_line(self, line: int) -> None:
        """Mark the current action invalid because ``line`` is out of range."""
        self._state.last_action_valid = False
        self._state.last_error = f"Line {line} out of range"

    def _process_action(self, action: ReviewAction) -> None:
        """Apply ``action`` to the current review state.

        - ``add_comment``: stores each comment whose line is in range.
        - ``suggest_fix``: stores each suggestion whose ``original_line`` is in range.
        - ``mark_as_resolved``: flags every stored comment on a targeted line
          as resolved.

        Other action types change nothing here; decisions are handled in
        ``step``. Out-of-range entries are skipped while the rest of the batch
        is still applied, and the action is then marked invalid with an error
        naming the offending line.
        """
        if self._state is None:
            return

        self._state.last_action_valid = True
        self._state.last_error = None
        kind = action.action_type.value

        if kind == "add_comment":
            for comment in action.comments or []:
                if self._line_in_range(comment.line_number):
                    self._state.comments_made.append(comment)
                else:
                    self._reject_line(comment.line_number)

        elif kind == "suggest_fix":
            for suggestion in action.suggestions or []:
                if self._line_in_range(suggestion.original_line):
                    self._state.suggestions_made.append(suggestion)
                else:
                    self._reject_line(suggestion.original_line)

        elif kind == "mark_as_resolved":
            # Collect target lines once, then resolve every stored comment on them.
            targets = {comment.line_number for comment in action.comments or []}
            for existing in self._state.comments_made:
                if existing.line_number in targets:
                    existing.resolved = True

    def _get_observation(self) -> Dict[str, Any]:
        """Return the agent-visible view of the state as a dict ({} before reset)."""
        if self._state is None:
            return {}

        context = self._state.code_context
        metadata = self._state.task_metadata
        return Observation(
            code_diff=context.code_diff,
            file_context=context.surrounding_code,
            file_path=context.file_path,
            language=context.language,
            task_description=metadata.description,
            task_difficulty=metadata.difficulty,
            current_step=self._state.current_step,
            max_steps=self.max_steps,
            previous_comments=self._state.comments_made,
            previous_suggestions=self._state.suggestions_made,
            review_complete=self._state.is_complete,
            final_decision_made=self._state.final_decision,
        ).model_dump()

    def get_task_score(self) -> float:
        """Return the grader's score for the review so far (0.0 before reset).

        The score is computed from the comments, suggestions and decision made
        so far. While no decision exists, ``FALLBACK_DECISION`` is used.
        """
        if not self.grader or self._state is None:
            return 0.0

        return self.grader.compute_score_from_state(
            comments=self._state.comments_made,
            suggestions=self._state.suggestions_made,
            final_decision=self._state.final_decision or self.FALLBACK_DECISION,
        )

    def close(self):
        """Release resources; this environment holds none, so this is a no-op."""
        pass

    def state(self) -> Dict[str, Any]:
        """Return the full internal ``ReviewState`` as a dict ({} before reset)."""
        if self._state is not None:
            return self._state.model_dump()
        return {}