# Spaces: Sleeping
# File size: 6,301 Bytes
# a21d9ae
from typing import Dict, Any, Tuple, Optional
from environment.models import (
ReviewAction,
ReviewState,
Observation
)
from environment.tasks import TaskDefinitions
from environment.graders import TaskGrader, RewardCalculator
class CodeReviewEnv:
    """Environment in which an agent reviews a code change via reset/step.

    ``reset`` loads a task and builds a fresh ``ReviewState``; ``step``
    validates a dict action, applies it to that state, and returns an
    ``(observation, reward, done, info)`` tuple. Rewards come from
    ``RewardCalculator`` and task scores from ``TaskGrader``.
    """

    def __init__(self):
        """Create an environment with no active episode."""
        self._state: Optional[ReviewState] = None
        self.grader: Optional[TaskGrader] = None
        self.reward_calculator = RewardCalculator()
        # Step budget per episode; reaching it ends the episode in step().
        self.max_steps = 50
        self.current_task_id: Optional[str] = None

    def reset(self, task_id: Optional[str] = None) -> Dict[str, Any]:
        """Start a new episode and return its first observation.

        Args:
            task_id: Identifier passed to ``TaskDefinitions.get_task``.
                When ``None``, ``"bug_detection_easy_1"`` is used.

        Returns:
            The observation dict for the freshly initialised state.
        """
        if task_id is None:
            task_id = "bug_detection_easy_1"
        self.current_task_id = task_id

        task_data = TaskDefinitions.get_task(task_id)
        code_context = TaskDefinitions.create_code_context(task_data)
        task_metadata = TaskDefinitions.create_task_metadata(task_data)

        self._state = ReviewState(
            code_context=code_context,
            task_metadata=task_metadata,
            comments_made=[],
            suggestions_made=[],
            current_step=0,
            is_complete=False,
            final_decision=None,
            last_action_valid=True,
            last_error=None,
        )
        self.grader = TaskGrader(task_metadata.expected_issues)
        self.reward_calculator.reset()
        return self._get_observation()

    def step(
        self, action: Dict[str, Any]
    ) -> Tuple[Dict[str, Any], float, bool, Dict[str, Any]]:
        """Apply one agent action and advance the episode.

        Args:
            action: Keyword arguments used to construct a ``ReviewAction``.

        Returns:
            A tuple ``(observation, reward, done, info)``:

            * before ``reset()``: ``({}, -0.1, True, {"error": ...})``;
            * after the episode has finished: reward ``0.0``, ``done`` True;
            * when ``ReviewAction`` cannot be built from ``action``: reward
              ``-0.1``, ``done`` False, and the error text in ``info``;
            * otherwise the reward from ``RewardCalculator.calculate_reward``
              and an ``info`` dict with the step counter, validity flag,
              last error and current task score.
        """
        if self._state is None:
            return {}, -0.1, True, {"error": "Environment not initialized. Call reset() first."}
        if self._state.is_complete:
            return self._get_observation(), 0.0, True, {"error": "Episode already complete"}

        try:
            review_action = ReviewAction(**action)
        except Exception as exc:
            # Boundary with agent-supplied input: any construction failure is
            # reported back as an invalid action rather than raised.
            # NOTE(review): this path does not advance current_step, so an
            # agent that only sends invalid actions never hits max_steps.
            message = str(exc)
            self._state.last_action_valid = False
            self._state.last_error = message
            return (
                self._get_observation(),
                -0.1,
                False,
                {"error": message, "last_action_valid": False},
            )

        self._state.current_step += 1
        self._process_action(review_action)
        self._apply_decision(review_action)

        reward = self.reward_calculator.calculate_reward(
            review_action,
            self._state.comments_made,
            self._state.suggestions_made,
            self._state.final_decision or "changes_requested",
            self.grader,
            self._state.last_action_valid,
        )
        info = {
            "step": self._state.current_step,
            "last_action_valid": self._state.last_action_valid,
            "error": self._state.last_error,
            "task_score": self.get_task_score(),
        }
        return self._get_observation(), reward, self._state.is_complete, info

    def _apply_decision(self, review_action: "ReviewAction") -> None:
        """Record a final decision and/or end the episode on step exhaustion.

        ``approve`` and ``request_changes`` actions that carry no explicit
        ``final_decision`` get ``"approved"`` / ``"changes_requested"``
        written onto the action. Called by ``step`` only while the episode
        is still open.
        """
        state = self._state
        if state is None:
            return

        if not review_action.final_decision:
            action_kind = review_action.action_type.value
            if action_kind == "approve":
                review_action.final_decision = "approved"
            elif action_kind == "request_changes":
                review_action.final_decision = "changes_requested"

        # The agent's own decision is applied BEFORE the step-budget check so
        # that a decision submitted on the very last allowed step is kept
        # instead of being overwritten by the timeout default.
        if review_action.final_decision:
            state.is_complete = True
            state.final_decision = review_action.final_decision

        if state.current_step >= self.max_steps:
            state.is_complete = True
            if not state.final_decision:
                # Budget exhausted without any decision: use the default.
                state.final_decision = "changes_requested"

    def _process_action(self, action: "ReviewAction") -> None:
        """Apply comments, suggestions or resolutions from ``action``.

        Resets ``last_action_valid`` / ``last_error`` first. In-range items
        are appended to the state; an out-of-range item marks the action
        invalid and records an error, while other items in the same action
        are still processed.
        """
        state = self._state
        if state is None:
            return

        state.last_action_valid = True
        state.last_error = None

        action_kind = action.action_type.value
        # NOTE(review): only the upper bound is checked here; zero or
        # negative line numbers pass unless the models reject them - confirm.
        line_count = state.code_context.line_count

        if action_kind == "add_comment":
            for comment in action.comments:
                if comment.line_number <= line_count:
                    state.comments_made.append(comment)
                else:
                    state.last_action_valid = False
                    state.last_error = f"Line {comment.line_number} out of range"
        elif action_kind == "suggest_fix":
            for suggestion in action.suggestions:
                if suggestion.original_line <= line_count:
                    state.suggestions_made.append(suggestion)
                else:
                    state.last_action_valid = False
                    state.last_error = f"Line {suggestion.original_line} out of range"
        elif action_kind == "mark_as_resolved":
            # Every stored comment on a referenced line is marked resolved.
            for comment in action.comments:
                for existing_comment in state.comments_made:
                    if existing_comment.line_number == comment.line_number:
                        existing_comment.resolved = True

    def _get_observation(self) -> Dict[str, Any]:
        """Return the current state as an ``Observation`` dict.

        Returns ``{}`` when no episode has been started.
        """
        if self._state is None:
            return {}
        state = self._state
        return Observation(
            code_diff=state.code_context.code_diff,
            file_context=state.code_context.surrounding_code,
            file_path=state.code_context.file_path,
            language=state.code_context.language,
            task_description=state.task_metadata.description,
            task_difficulty=state.task_metadata.difficulty,
            current_step=state.current_step,
            max_steps=self.max_steps,
            previous_comments=state.comments_made,
            previous_suggestions=state.suggestions_made,
            review_complete=state.is_complete,
            final_decision_made=state.final_decision,
        ).model_dump()

    def get_task_score(self) -> float:
        """Return the grader's score for the current state.

        Returns ``0.0`` when there is no grader or no active state. A missing
        final decision is passed to the grader as ``"changes_requested"``.
        """
        if not self.grader or self._state is None:
            return 0.0
        return self.grader.compute_score_from_state(
            comments=self._state.comments_made,
            suggestions=self._state.suggestions_made,
            final_decision=self._state.final_decision or "changes_requested",
        )

    def close(self):
        """Do nothing; there are no resources to release."""
        pass

    def state(self) -> Dict[str, Any]:
        """Return the full internal state as a dict, or ``{}`` if unset."""
        if self._state:
            return self._state.model_dump()
        return {}