Spaces:
Sleeping
Sleeping
| from __future__ import annotations | |
| from dataclasses import dataclass | |
| from typing import Any, Dict, List, Tuple | |
| from environment.env import CodeReviewEnv | |
@dataclass
class TemplateAction:
    """A named, pre-built action payload that can be fed to ``CodeReviewEnv.step``.

    The ``@dataclass`` decorator generates an ``__init__(name, payload)``, which is
    what lets catalog entries be written positionally as
    ``TemplateAction(name, payload)``. It also provides ``__repr__`` and ``__eq__``.
    """

    # Human-readable label for the template (e.g. "good_fix", "approve").
    name: str
    # Action dict passed as-is to the environment's ``step`` method.
    payload: Dict[str, Any]
class TrainingEnv:
    """Thin wrapper around CodeReviewEnv for policy training experiments.

    Each call to :meth:`run_episode` resets the wrapped environment on the next
    task id in a round-robin rotation and replays a fixed list of template actions.
    """

    def __init__(self, task_ids: List[str] | None = None, max_steps: int = 5, seed: int = 42):
        """Create the wrapped environment and the task rotation.

        Args:
            task_ids: Task ids to cycle through. ``None`` or an empty sequence
                falls back to ``["bug_detection_easy_1"]``. The ids are copied,
                so later changes to the caller's list do not affect the rotation.
            max_steps: Value assigned to ``env.max_steps`` before every reset.
            seed: Seed passed to every ``env.reset`` call.
        """
        self.env = CodeReviewEnv()
        self.max_steps = max_steps
        self.seed = seed
        # Copy instead of aliasing, so the rotation is isolated from the caller.
        self.task_ids = list(task_ids) if task_ids else ["bug_detection_easy_1"]
        self.task_cursor = 0

    def next_task(self) -> str:
        """Return the next task id in round-robin order and advance the cursor."""
        task_id = self.task_ids[self.task_cursor % len(self.task_ids)]
        self.task_cursor += 1
        return task_id

    def run_episode(self, action_plan: List[TemplateAction]) -> Tuple[float, float, int]:
        """Reset on the next task and step through ``action_plan`` until done.

        Actions remaining after the environment reports ``done`` are skipped.

        Args:
            action_plan: Template actions whose ``payload`` is passed to ``env.step``.

        Returns:
            ``(total_reward, task_score, steps)``: the summed per-step rewards,
            ``env.get_task_score()`` after the last step, and the number of steps
            actually taken.
        """
        task_id = self.next_task()
        # Push the configured step budget onto the env before resetting it.
        self.env.max_steps = self.max_steps
        self.env.reset(task_id=task_id, seed=self.seed)

        done = False
        total_reward = 0.0
        steps = 0
        for action in action_plan:
            if done:
                break
            _obs, reward, done, _info = self.env.step(action.payload)
            total_reward += float(reward)
            steps += 1

        task_score = float(self.env.get_task_score())
        return total_reward, task_score, steps
def default_action_catalog() -> Dict[str, List[TemplateAction]]:
    """Build and return the scripted review actions, grouped by phase.

    A brand-new dict (with brand-new payload dicts) is built on every call, so
    callers may mutate the result without affecting later calls.

    Phases:
        ``"phase_1"``: ``add_comment`` actions -- ``"good_comment"`` (line 3,
            severity ``"high"``) and ``"weak_comment"`` (line 1, severity ``"low"``).
        ``"phase_2"``: ``suggest_fix`` actions -- ``"good_fix"`` (an
            empty-input guard for line 3) and ``"bad_fix"`` (``pass`` on line 1).
        ``"phase_3"``: final decisions -- ``"request_changes"`` and ``"approve"``.
    """

    def comment_action(label: str, line: int, text: str, severity: str) -> TemplateAction:
        # A comment action carries exactly one issue-flagged comment and no fixes.
        review_note = {
            "line_number": line,
            "content": text,
            "is_issue": True,
            "severity": severity,
        }
        return TemplateAction(
            label,
            {"action_type": "add_comment", "comments": [review_note], "suggestions": []},
        )

    def fix_action(label: str, line: int, code: str, reason: str) -> TemplateAction:
        # A fix action carries exactly one suggestion and no comments.
        patch = {
            "original_line": line,
            "suggested_code": code,
            "explanation": reason,
        }
        return TemplateAction(
            label,
            {"action_type": "suggest_fix", "comments": [], "suggestions": [patch]},
        )

    def decision_action(kind: str, verdict: str) -> TemplateAction:
        # Decision templates are named after their action_type and add a final verdict.
        return TemplateAction(
            kind,
            {
                "action_type": kind,
                "comments": [],
                "suggestions": [],
                "final_decision": verdict,
            },
        )

    catalog: Dict[str, List[TemplateAction]] = {}
    catalog["phase_1"] = [
        comment_action(
            "good_comment",
            3,
            "Potential division_by_zero or similar correctness issue",
            "high",
        ),
        comment_action("weak_comment", 1, "maybe issue", "low"),
    ]
    catalog["phase_2"] = [
        fix_action(
            "good_fix",
            3,
            "return total / len(numbers) if numbers else 0",
            "guard empty input",
        ),
        fix_action("bad_fix", 1, "pass", "placeholder"),
    ]
    catalog["phase_3"] = [
        decision_action("request_changes", "changes_requested"),
        decision_action("approve", "approved"),
    ]
    return catalog