# code-review-environment / train_env.py
# ashishbaberwal — commit 1939cbc: "New Final"
from __future__ import annotations
from dataclasses import dataclass
from typing import Any, Dict, List, Tuple
from environment.env import CodeReviewEnv
@dataclass
class TemplateAction:
    """A named, pre-built action that can be replayed in a training episode.

    ``TrainingEnv.run_episode`` forwards ``payload`` unchanged to the wrapped
    environment's ``step`` call; ``name`` is not read there and serves as a
    label for the template.
    """

    # Label identifying the template, e.g. "good_comment".
    name: str
    # Action dict handed verbatim to ``env.step``.
    payload: Dict[str, Any]
class TrainingEnv:
    """Thin wrapper around CodeReviewEnv for policy training experiments.

    Each call to ``run_episode`` resets the wrapped environment on the next
    task id (round-robin over ``task_ids``), replays a fixed plan of template
    actions, and reports the accumulated reward together with the env's task
    score.
    """

    def __init__(
        self,
        task_ids: List[str] | None = None,
        max_steps: int = 5,
        seed: int = 42,
        *,
        env: Any = None,
    ):
        """Create the wrapper.

        Args:
            task_ids: Task ids to cycle through. Any iterable of ids is
                accepted and copied, so later changes to the caller's
                collection do not affect the rotation. ``None`` or an empty
                iterable falls back to ``["bug_detection_easy_1"]``.
            max_steps: Value written to ``env.max_steps`` before every reset
                (presumably the env uses it to end episodes; confirm in
                CodeReviewEnv).
            seed: Seed passed to ``env.reset``; the same seed is reused for
                every episode.
            env: Optional pre-built environment exposing ``reset``, ``step``,
                ``get_task_score`` and a writable ``max_steps`` attribute.
                Defaults to a fresh ``CodeReviewEnv()``.
        """
        self.env = env if env is not None else CodeReviewEnv()
        self.max_steps = max_steps
        self.seed = seed
        # Materialize first: a generator would otherwise break len() in
        # next_task, and an exhausted/empty one would slip past the fallback.
        ids = [] if task_ids is None else list(task_ids)
        self.task_ids = ids or ["bug_detection_easy_1"]
        self.task_cursor = 0

    def next_task(self) -> str:
        """Return the next task id in round-robin order and advance the cursor."""
        task_id = self.task_ids[self.task_cursor % len(self.task_ids)]
        self.task_cursor += 1
        return task_id

    def run_episode(self, action_plan: List[TemplateAction]) -> Tuple[float, float, int]:
        """Run one episode on the next task using a fixed action plan.

        Args:
            action_plan: Actions submitted in order; each action's ``payload``
                is passed to ``env.step``. Once a step reports ``done``, the
                remaining actions are not submitted.

        Returns:
            ``(total_reward, task_score, steps)``: the sum of per-step rewards
            (as floats), ``env.get_task_score()`` after the last step (as a
            float), and the number of ``step`` calls made.
        """
        task_id = self.next_task()
        self.env.max_steps = self.max_steps
        # The initial observation is not needed: the plan is fixed up front.
        self.env.reset(task_id=task_id, seed=self.seed)
        total_reward = 0.0
        steps = 0
        for action in action_plan:
            _obs, reward, done, _info = self.env.step(action.payload)
            total_reward += float(reward)
            steps += 1
            if done:
                break
        task_score = float(self.env.get_task_score())
        return total_reward, task_score, steps
def default_action_catalog() -> Dict[str, List[TemplateAction]]:
    """Return the built-in template actions grouped by review phase.

    Keys are ``"phase_1"`` (issue comments), ``"phase_2"`` (fix suggestions)
    and ``"phase_3"`` (final decision), each mapping to two actions. In
    phases 1 and 2 the first template is the "good" one and the second a
    weak/bad one, per their names; phase 3 offers request-changes vs approve.
    A fresh structure is built on every call, so callers may mutate the
    returned payloads without affecting later calls.

    NOTE(review): the hard-coded line numbers and fix text (line 3,
    ``total`` / ``numbers``) appear to target one specific task snippet,
    presumably the default ``bug_detection_easy_1``; confirm before reusing
    these templates with other tasks.
    """

    def comment_action(name: str, line: int, text: str, severity: str) -> TemplateAction:
        # A single issue comment on `line`, with no fix suggestions.
        return TemplateAction(
            name,
            {
                "action_type": "add_comment",
                "comments": [
                    {
                        "line_number": line,
                        "content": text,
                        "is_issue": True,
                        "severity": severity,
                    }
                ],
                "suggestions": [],
            },
        )

    def fix_action(name: str, line: int, code: str, why: str) -> TemplateAction:
        # A single replacement suggestion for `line`, with no comments.
        return TemplateAction(
            name,
            {
                "action_type": "suggest_fix",
                "comments": [],
                "suggestions": [
                    {
                        "original_line": line,
                        "suggested_code": code,
                        "explanation": why,
                    }
                ],
            },
        )

    def decision_action(action_type: str, decision: str) -> TemplateAction:
        # For final decisions the template name equals the action type.
        return TemplateAction(
            action_type,
            {
                "action_type": action_type,
                "comments": [],
                "suggestions": [],
                "final_decision": decision,
            },
        )

    comment_phase = [
        comment_action(
            "good_comment",
            3,
            "Potential division_by_zero or similar correctness issue",
            "high",
        ),
        comment_action("weak_comment", 1, "maybe issue", "low"),
    ]
    fix_phase = [
        fix_action(
            "good_fix",
            3,
            "return total / len(numbers) if numbers else 0",
            "guard empty input",
        ),
        fix_action("bad_fix", 1, "pass", "placeholder"),
    ]
    decision_phase = [
        decision_action("request_changes", "changes_requested"),
        decision_action("approve", "approved"),
    ]

    catalog: Dict[str, List[TemplateAction]] = {}
    catalog["phase_1"] = comment_phase
    catalog["phase_2"] = fix_phase
    catalog["phase_3"] = decision_phase
    return catalog