Add action validation and reward recording
Browse filesAdd validation and tighten action handling across the env and inference code.
- codereview_env/models.py: add a pydantic model_validator on Action to enforce required fields for FLAG_ISSUE (body, filename, line_number, category, severity) and require verdict/body for APPROVE/REQUEST_CHANGES actions.
- codereview_env/env.py: record ActionRecord before applying action (initialize reward=0.0), append it to history, compute reward delta and update running_score, then update the history record's reward (rounded to 4 decimals). Also guard history access when determining termination reason and set a default terminated_reason.
- codereview_env/scenarios.py: change required_verdict for arch_003, arch_005, and arch_008 from NEEDS_DISCUSSION to REQUEST_CHANGES.
- inference.py: normalize verdict strings to lowercase ("lgtm" and "request_changes").
These changes improve data validation, ensure accurate per-action reward recording in history, and standardize verdict values.
- codereview_env/env.py +11 -6
- codereview_env/models.py +15 -1
- codereview_env/scenarios.py +3 -3
- inference.py +2 -2
|
@@ -49,8 +49,8 @@ class CodeReviewEnv:
|
|
| 49 |
s = self._state
|
| 50 |
s["step_count"] += 1
|
| 51 |
|
| 52 |
-
# Record action in history
|
| 53 |
-
|
| 54 |
action_type=action.action_type,
|
| 55 |
body=action.body,
|
| 56 |
filename=action.filename,
|
|
@@ -58,14 +58,19 @@ class CodeReviewEnv:
|
|
| 58 |
severity=action.severity,
|
| 59 |
category=action.category,
|
| 60 |
verdict=action.verdict,
|
|
|
|
| 61 |
timestamp=datetime.now(timezone.utc).isoformat()
|
| 62 |
-
)
|
|
|
|
| 63 |
|
| 64 |
# Apply action logic and compute incremental reward delta
|
| 65 |
prev_score = s["running_score"]
|
| 66 |
reward_delta = self._apply_action(action)
|
| 67 |
s["running_score"] = prev_score + reward_delta
|
| 68 |
|
|
|
|
|
|
|
|
|
|
| 69 |
# Check termination
|
| 70 |
s["done"] = (
|
| 71 |
action.action_type in (ActionType.APPROVE, ActionType.REQUEST_CHANGES)
|
|
@@ -150,13 +155,13 @@ class CodeReviewEnv:
|
|
| 150 |
missed_ids = list(all_gt_ids - s["issues_found"])
|
| 151 |
final_score = self._grade(sc, s)
|
| 152 |
|
| 153 |
-
terminated_reason = "
|
| 154 |
if s["done"]:
|
| 155 |
if s["noise_budget"] <= 0:
|
| 156 |
terminated_reason = "noise_exhausted"
|
| 157 |
-
elif s["history"][-1].action_type in (ActionType.APPROVE, ActionType.REQUEST_CHANGES):
|
| 158 |
terminated_reason = "terminal_action"
|
| 159 |
-
|
| 160 |
terminated_reason = "max_steps"
|
| 161 |
|
| 162 |
return EpisodeResult(
|
|
|
|
| 49 |
s = self._state
|
| 50 |
s["step_count"] += 1
|
| 51 |
|
| 52 |
+
# Record action in history (reward will be updated after calculation)
|
| 53 |
+
record = ActionRecord(
|
| 54 |
action_type=action.action_type,
|
| 55 |
body=action.body,
|
| 56 |
filename=action.filename,
|
|
|
|
| 58 |
severity=action.severity,
|
| 59 |
category=action.category,
|
| 60 |
verdict=action.verdict,
|
| 61 |
+
reward=0.0,
|
| 62 |
timestamp=datetime.now(timezone.utc).isoformat()
|
| 63 |
+
)
|
| 64 |
+
s["history"].append(record)
|
| 65 |
|
| 66 |
# Apply action logic and compute incremental reward delta
|
| 67 |
prev_score = s["running_score"]
|
| 68 |
reward_delta = self._apply_action(action)
|
| 69 |
s["running_score"] = prev_score + reward_delta
|
| 70 |
|
| 71 |
+
# Update the history record with the actual reward
|
| 72 |
+
record.reward = round(reward_delta, 4)
|
| 73 |
+
|
| 74 |
# Check termination
|
| 75 |
s["done"] = (
|
| 76 |
action.action_type in (ActionType.APPROVE, ActionType.REQUEST_CHANGES)
|
|
|
|
| 155 |
missed_ids = list(all_gt_ids - s["issues_found"])
|
| 156 |
final_score = self._grade(sc, s)
|
| 157 |
|
| 158 |
+
terminated_reason = ""
|
| 159 |
if s["done"]:
|
| 160 |
if s["noise_budget"] <= 0:
|
| 161 |
terminated_reason = "noise_exhausted"
|
| 162 |
+
elif s["history"] and s["history"][-1].action_type in (ActionType.APPROVE, ActionType.REQUEST_CHANGES):
|
| 163 |
terminated_reason = "terminal_action"
|
| 164 |
+
else:
|
| 165 |
terminated_reason = "max_steps"
|
| 166 |
|
| 167 |
return EpisodeResult(
|
|
@@ -1,6 +1,6 @@
|
|
| 1 |
from enum import Enum
|
| 2 |
from typing import List, Optional, Union
|
| 3 |
-
from pydantic import BaseModel
|
| 4 |
|
| 5 |
class TaskId(str, Enum):
|
| 6 |
BUG_DETECTION = "bug_detection"
|
|
@@ -73,6 +73,20 @@ class Action(BaseModel):
|
|
| 73 |
severity: Optional[Severity] = None
|
| 74 |
verdict: Optional[Verdict] = None
|
| 75 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 76 |
class ActionRecord(BaseModel):
|
| 77 |
"""Immutable record of a step taken — stored in episode history."""
|
| 78 |
action_type: ActionType
|
|
|
|
| 1 |
from enum import Enum
|
| 2 |
from typing import List, Optional, Union
|
| 3 |
+
from pydantic import BaseModel, model_validator
|
| 4 |
|
| 5 |
class TaskId(str, Enum):
|
| 6 |
BUG_DETECTION = "bug_detection"
|
|
|
|
| 73 |
severity: Optional[Severity] = None
|
| 74 |
verdict: Optional[Verdict] = None
|
| 75 |
|
| 76 |
+
@model_validator(mode="after")
|
| 77 |
+
def validate_action_fields(self) -> "Action":
|
| 78 |
+
if self.action_type == ActionType.FLAG_ISSUE:
|
| 79 |
+
if not self.body or not self.filename or self.line_number is None:
|
| 80 |
+
raise ValueError("flag_issue requires body, filename, and line_number")
|
| 81 |
+
if not self.category or not self.severity:
|
| 82 |
+
raise ValueError("flag_issue requires category and severity")
|
| 83 |
+
elif self.action_type in (ActionType.APPROVE, ActionType.REQUEST_CHANGES):
|
| 84 |
+
if not self.verdict:
|
| 85 |
+
raise ValueError(f"{self.action_type.value} action requires a verdict")
|
| 86 |
+
if not self.body:
|
| 87 |
+
raise ValueError(f"{self.action_type.value} action requires a body summary")
|
| 88 |
+
return self
|
| 89 |
+
|
| 90 |
class ActionRecord(BaseModel):
|
| 91 |
"""Immutable record of a step taken — stored in episode history."""
|
| 92 |
action_type: ActionType
|
|
@@ -810,7 +810,7 @@ arch_003 = Scenario(
|
|
| 810 |
line_number=2,
|
| 811 |
description="Tight coupling: service depends on concrete implementation instead of abstraction",
|
| 812 |
keywords=["tight coupling", "dependency injection"],
|
| 813 |
-
required_verdict=Verdict.
|
| 814 |
)
|
| 815 |
],
|
| 816 |
hash="arch_003",
|
|
@@ -881,7 +881,7 @@ arch_005 = Scenario(
|
|
| 881 |
line_number=6,
|
| 882 |
description="Missing resilience (retry, timeout, circuit breaker) on external API dependency",
|
| 883 |
keywords=["retry", "resilience"],
|
| 884 |
-
required_verdict=Verdict.
|
| 885 |
)
|
| 886 |
],
|
| 887 |
hash="arch_005",
|
|
@@ -984,7 +984,7 @@ arch_008 = Scenario(
|
|
| 984 |
line_number=2,
|
| 985 |
description="Secret leaked in code comment; should be in environment variables only",
|
| 986 |
keywords=["secret", "comment"],
|
| 987 |
-
required_verdict=Verdict.
|
| 988 |
)
|
| 989 |
],
|
| 990 |
hash="arch_008",
|
|
|
|
| 810 |
line_number=2,
|
| 811 |
description="Tight coupling: service depends on concrete implementation instead of abstraction",
|
| 812 |
keywords=["tight coupling", "dependency injection"],
|
| 813 |
+
required_verdict=Verdict.REQUEST_CHANGES
|
| 814 |
)
|
| 815 |
],
|
| 816 |
hash="arch_003",
|
|
|
|
| 881 |
line_number=6,
|
| 882 |
description="Missing resilience (retry, timeout, circuit breaker) on external API dependency",
|
| 883 |
keywords=["retry", "resilience"],
|
| 884 |
+
required_verdict=Verdict.REQUEST_CHANGES
|
| 885 |
)
|
| 886 |
],
|
| 887 |
hash="arch_005",
|
|
|
|
| 984 |
line_number=2,
|
| 985 |
description="Secret leaked in code comment; should be in environment variables only",
|
| 986 |
keywords=["secret", "comment"],
|
| 987 |
+
required_verdict=Verdict.REQUEST_CHANGES
|
| 988 |
)
|
| 989 |
],
|
| 990 |
hash="arch_008",
|
|
@@ -165,9 +165,9 @@ def sanitize_action(action_dict: dict, task_id: str) -> dict:
|
|
| 165 |
|
| 166 |
elif action_type in ("approve", "request_changes"):
|
| 167 |
if action_type == "approve":
|
| 168 |
-
action_dict["verdict"] = "
|
| 169 |
else:
|
| 170 |
-
action_dict["verdict"] = "
|
| 171 |
if "body" not in action_dict:
|
| 172 |
action_dict["body"] = "Review complete."
|
| 173 |
|
|
|
|
| 165 |
|
| 166 |
elif action_type in ("approve", "request_changes"):
|
| 167 |
if action_type == "approve":
|
| 168 |
+
action_dict["verdict"] = "lgtm"
|
| 169 |
else:
|
| 170 |
+
action_dict["verdict"] = "request_changes"
|
| 171 |
if "body" not in action_dict:
|
| 172 |
action_dict["body"] = "Review complete."
|
| 173 |
|