Shardul Dhekane
CRE Main Update
a21d9ae
from typing import Dict, Any, Tuple, Optional
from environment.models import (
ReviewAction,
ReviewState,
Observation
)
from environment.tasks import TaskDefinitions
from environment.graders import TaskGrader, RewardCalculator
class CodeReviewEnv:
    """Episodic code-review environment with a ``reset()`` / ``step()`` interface.

    ``reset()`` loads a task and returns the initial observation. ``step()``
    applies one review action and returns ``(observation, reward, done, info)``.
    Observations and ``state()`` are plain dicts produced by ``model_dump()``.
    """

    # Task loaded by ``reset()`` when no ``task_id`` is given.
    DEFAULT_TASK_ID = "bug_detection_easy_1"

    # Decision implied by an action type when the payload leaves
    # ``final_decision`` empty.
    _IMPLIED_DECISIONS = {
        "approve": "approved",
        "request_changes": "changes_requested",
    }

    def __init__(self, max_steps: int = 50):
        """Create an uninitialised environment; call ``reset()`` before ``step()``.

        Args:
            max_steps: Number of successfully parsed actions after which the
                episode is forced to end. Defaults to 50.

        Raises:
            ValueError: If ``max_steps`` is less than 1.
        """
        if max_steps < 1:
            raise ValueError(f"max_steps must be >= 1, got {max_steps}")
        self._state: Optional[ReviewState] = None
        self.grader: Optional[TaskGrader] = None
        self.reward_calculator = RewardCalculator()
        self.max_steps = max_steps
        self.current_task_id: Optional[str] = None

    def reset(self, task_id: Optional[str] = None) -> Dict[str, Any]:
        """Start a new episode and return its initial observation.

        Args:
            task_id: Identifier passed to ``TaskDefinitions.get_task``.
                Falls back to ``DEFAULT_TASK_ID`` when omitted.

        Returns:
            The observation dict for step 0.
        """
        if task_id is None:
            task_id = self.DEFAULT_TASK_ID
        self.current_task_id = task_id

        task_data = TaskDefinitions.get_task(task_id)
        code_context = TaskDefinitions.create_code_context(task_data)
        task_metadata = TaskDefinitions.create_task_metadata(task_data)

        self._state = ReviewState(
            code_context=code_context,
            task_metadata=task_metadata,
            comments_made=[],
            suggestions_made=[],
            current_step=0,
            is_complete=False,
            final_decision=None,
            last_action_valid=True,
            last_error=None,
        )
        # The grader is rebuilt for every episode from the task's expected issues.
        self.grader = TaskGrader(task_metadata.expected_issues)
        self.reward_calculator.reset()
        return self._get_observation()

    def step(
        self, action: Dict[str, Any]
    ) -> Tuple[Dict[str, Any], float, bool, Dict[str, Any]]:
        """Apply one review action to the current episode.

        Args:
            action: Keyword arguments used to construct a ``ReviewAction``.

        Returns:
            ``(observation, reward, done, info)``.
            - Before ``reset()``: ``({}, -0.1, True, {"error": ...})``.
            - After the episode ended: ``(obs, 0.0, True, {"error": ...})``.
            - Otherwise ``info`` carries ``"step"``, ``"last_action_valid"``,
              ``"error"`` and ``"task_score"``.
            An explicit or implied final decision ends the episode and takes
            priority over the step-limit default, even on the last allowed step.
        """
        if self._state is None:
            return {}, -0.1, True, {"error": "Environment not initialized. Call reset() first."}
        if self._state.is_complete:
            return self._get_observation(), 0.0, True, {"error": "Episode already complete"}

        try:
            review_action = ReviewAction(**action)
        except Exception as e:
            # Malformed payload: penalise and flag it, but keep the episode open.
            # NOTE(review): this path does not advance current_step, so invalid
            # actions never count toward max_steps — confirm that is intended.
            self._state.last_action_valid = False
            self._state.last_error = str(e)
            return self._get_observation(), -0.1, False, self._build_info()

        self._state.current_step += 1
        self._process_action(review_action)

        # approve / request_changes imply a decision when the payload omits one.
        if not review_action.final_decision:
            implied = self._IMPLIED_DECISIONS.get(review_action.action_type.value)
            if implied:
                review_action.final_decision = implied

        if review_action.final_decision:
            # Apply the agent's decision BEFORE the step-limit check, so a
            # decision made on the final allowed step is not overwritten.
            self._state.is_complete = True
            self._state.final_decision = review_action.final_decision
        elif self._state.current_step >= self.max_steps:
            # Step budget exhausted without a decision: close as changes_requested.
            self._state.is_complete = True
            if not self._state.final_decision:
                self._state.final_decision = "changes_requested"

        reward = self.reward_calculator.calculate_reward(
            review_action,
            self._state.comments_made,
            self._state.suggestions_made,
            self._state.final_decision or "changes_requested",
            self.grader,
            self._state.last_action_valid,
        )
        return self._get_observation(), reward, self._state.is_complete, self._build_info()

    def _build_info(self) -> Dict[str, Any]:
        """Assemble the ``info`` dict ``step()`` returns once an episode is running."""
        return {
            "step": self._state.current_step,
            "last_action_valid": self._state.last_action_valid,
            "error": self._state.last_error,
            "task_score": self.get_task_score(),
        }

    def _line_in_range(self, line_number: int) -> bool:
        """Return True if ``line_number`` addresses a line of the reviewed code.

        Lines are treated as 1-based, matching the inclusive upper bound
        ``line_count``.
        """
        return 1 <= line_number <= self._state.code_context.line_count

    def _process_action(self, action: ReviewAction) -> None:
        """Apply ``action``'s side effects to the current review state.

        Resets the validity flags first. An out-of-range line marks the action
        invalid and records the error (the last such error wins). In-range
        items from the same action are still applied. Action types other than
        add_comment / suggest_fix / mark_as_resolved change no review content.
        """
        if self._state is None:
            return
        self._state.last_action_valid = True
        self._state.last_error = None

        action_type = action.action_type.value
        if action_type == "add_comment":
            for comment in action.comments:
                if self._line_in_range(comment.line_number):
                    self._state.comments_made.append(comment)
                else:
                    self._state.last_action_valid = False
                    self._state.last_error = f"Line {comment.line_number} out of range"
        elif action_type == "suggest_fix":
            for suggestion in action.suggestions:
                if self._line_in_range(suggestion.original_line):
                    self._state.suggestions_made.append(suggestion)
                else:
                    self._state.last_action_valid = False
                    self._state.last_error = f"Line {suggestion.original_line} out of range"
        elif action_type == "mark_as_resolved":
            # Resolve every existing comment on any line named in the action.
            lines_to_resolve = {comment.line_number for comment in action.comments}
            for existing_comment in self._state.comments_made:
                if existing_comment.line_number in lines_to_resolve:
                    existing_comment.resolved = True

    def _get_observation(self) -> Dict[str, Any]:
        """Return the current observation as a dict, or ``{}`` before ``reset()``."""
        if self._state is None:
            return {}
        return Observation(
            code_diff=self._state.code_context.code_diff,
            file_context=self._state.code_context.surrounding_code,
            file_path=self._state.code_context.file_path,
            language=self._state.code_context.language,
            task_description=self._state.task_metadata.description,
            task_difficulty=self._state.task_metadata.difficulty,
            current_step=self._state.current_step,
            max_steps=self.max_steps,
            previous_comments=self._state.comments_made,
            previous_suggestions=self._state.suggestions_made,
            review_complete=self._state.is_complete,
            final_decision_made=self._state.final_decision,
        ).model_dump()

    def get_task_score(self) -> float:
        """Return the grader's score for the current review; 0.0 before ``reset()``.

        A missing final decision is scored as ``"changes_requested"``.
        """
        if not self.grader or self._state is None:
            return 0.0
        return self.grader.compute_score_from_state(
            comments=self._state.comments_made,
            suggestions=self._state.suggestions_made,
            final_decision=self._state.final_decision or "changes_requested",
        )

    def close(self):
        """No-op."""
        pass

    def state(self) -> Dict[str, Any]:
        """Return the full review state as a dict, or ``{}`` before ``reset()``."""
        if self._state is not None:
            return self._state.model_dump()
        return {}