ArshVerma commited on
Commit
0320a8d
·
1 Parent(s): 99cbab5

Add action validation and reward recording

Browse files

Add validation and tighten action handling across the env and inference code.

- codereview_env/models.py: add a pydantic model_validator on Action to enforce required fields for FLAG_ISSUE (body, filename, line_number, category, severity) and require verdict/body for APPROVE/REQUEST_CHANGES actions.
- codereview_env/env.py: record ActionRecord before applying action (initialize reward=0.0), append it to history, compute reward delta and update running_score, then update the history record's reward (rounded to 4 decimals). Also guard history access when determining termination reason and set a default terminated_reason.
- codereview_env/scenarios.py: change required_verdict for arch_003, arch_005, and arch_008 from NEEDS_DISCUSSION to REQUEST_CHANGES.
- inference.py: normalize verdict strings to lowercase ("lgtm" and "request_changes").

These changes improve data validation, ensure accurate per-action reward recording in history, and standardize verdict values.

codereview_env/env.py CHANGED
@@ -49,8 +49,8 @@ class CodeReviewEnv:
49
  s = self._state
50
  s["step_count"] += 1
51
 
52
- # Record action in history
53
- s["history"].append(ActionRecord(
54
  action_type=action.action_type,
55
  body=action.body,
56
  filename=action.filename,
@@ -58,14 +58,19 @@ class CodeReviewEnv:
58
  severity=action.severity,
59
  category=action.category,
60
  verdict=action.verdict,
 
61
  timestamp=datetime.now(timezone.utc).isoformat()
62
- ))
 
63
 
64
  # Apply action logic and compute incremental reward delta
65
  prev_score = s["running_score"]
66
  reward_delta = self._apply_action(action)
67
  s["running_score"] = prev_score + reward_delta
68
 
 
 
 
69
  # Check termination
70
  s["done"] = (
71
  action.action_type in (ActionType.APPROVE, ActionType.REQUEST_CHANGES)
@@ -150,13 +155,13 @@ class CodeReviewEnv:
150
  missed_ids = list(all_gt_ids - s["issues_found"])
151
  final_score = self._grade(sc, s)
152
 
153
- terminated_reason = "max_steps"
154
  if s["done"]:
155
  if s["noise_budget"] <= 0:
156
  terminated_reason = "noise_exhausted"
157
- elif s["history"][-1].action_type in (ActionType.APPROVE, ActionType.REQUEST_CHANGES):
158
  terminated_reason = "terminal_action"
159
- elif s["step_count"] >= s["max_steps"]:
160
  terminated_reason = "max_steps"
161
 
162
  return EpisodeResult(
 
49
  s = self._state
50
  s["step_count"] += 1
51
 
52
+ # Record action in history (reward will be updated after calculation)
53
+ record = ActionRecord(
54
  action_type=action.action_type,
55
  body=action.body,
56
  filename=action.filename,
 
58
  severity=action.severity,
59
  category=action.category,
60
  verdict=action.verdict,
61
+ reward=0.0,
62
  timestamp=datetime.now(timezone.utc).isoformat()
63
+ )
64
+ s["history"].append(record)
65
 
66
  # Apply action logic and compute incremental reward delta
67
  prev_score = s["running_score"]
68
  reward_delta = self._apply_action(action)
69
  s["running_score"] = prev_score + reward_delta
70
 
71
+ # Update the history record with the actual reward
72
+ record.reward = round(reward_delta, 4)
73
+
74
  # Check termination
75
  s["done"] = (
76
  action.action_type in (ActionType.APPROVE, ActionType.REQUEST_CHANGES)
 
155
  missed_ids = list(all_gt_ids - s["issues_found"])
156
  final_score = self._grade(sc, s)
157
 
158
+ terminated_reason = ""
159
  if s["done"]:
160
  if s["noise_budget"] <= 0:
161
  terminated_reason = "noise_exhausted"
162
+ elif s["history"] and s["history"][-1].action_type in (ActionType.APPROVE, ActionType.REQUEST_CHANGES):
163
  terminated_reason = "terminal_action"
164
+ else:
165
  terminated_reason = "max_steps"
166
 
167
  return EpisodeResult(
codereview_env/models.py CHANGED
@@ -1,6 +1,6 @@
1
  from enum import Enum
2
  from typing import List, Optional, Union
3
- from pydantic import BaseModel
4
 
5
  class TaskId(str, Enum):
6
  BUG_DETECTION = "bug_detection"
@@ -73,6 +73,20 @@ class Action(BaseModel):
73
  severity: Optional[Severity] = None
74
  verdict: Optional[Verdict] = None
75
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
76
  class ActionRecord(BaseModel):
77
  """Immutable record of a step taken — stored in episode history."""
78
  action_type: ActionType
 
1
  from enum import Enum
2
  from typing import List, Optional, Union
3
+ from pydantic import BaseModel, model_validator
4
 
5
  class TaskId(str, Enum):
6
  BUG_DETECTION = "bug_detection"
 
73
  severity: Optional[Severity] = None
74
  verdict: Optional[Verdict] = None
75
 
76
+ @model_validator(mode="after")
77
+ def validate_action_fields(self) -> "Action":
78
+ if self.action_type == ActionType.FLAG_ISSUE:
79
+ if not self.body or not self.filename or self.line_number is None:
80
+ raise ValueError("flag_issue requires body, filename, and line_number")
81
+ if not self.category or not self.severity:
82
+ raise ValueError("flag_issue requires category and severity")
83
+ elif self.action_type in (ActionType.APPROVE, ActionType.REQUEST_CHANGES):
84
+ if not self.verdict:
85
+ raise ValueError(f"{self.action_type.value} action requires a verdict")
86
+ if not self.body:
87
+ raise ValueError(f"{self.action_type.value} action requires a body summary")
88
+ return self
89
+
90
  class ActionRecord(BaseModel):
91
  """Immutable record of a step taken — stored in episode history."""
92
  action_type: ActionType
codereview_env/scenarios.py CHANGED
@@ -810,7 +810,7 @@ arch_003 = Scenario(
810
  line_number=2,
811
  description="Tight coupling: service depends on concrete implementation instead of abstraction",
812
  keywords=["tight coupling", "dependency injection"],
813
- required_verdict=Verdict.NEEDS_DISCUSSION
814
  )
815
  ],
816
  hash="arch_003",
@@ -881,7 +881,7 @@ arch_005 = Scenario(
881
  line_number=6,
882
  description="Missing resilience (retry, timeout, circuit breaker) on external API dependency",
883
  keywords=["retry", "resilience"],
884
- required_verdict=Verdict.NEEDS_DISCUSSION
885
  )
886
  ],
887
  hash="arch_005",
@@ -984,7 +984,7 @@ arch_008 = Scenario(
984
  line_number=2,
985
  description="Secret leaked in code comment; should be in environment variables only",
986
  keywords=["secret", "comment"],
987
- required_verdict=Verdict.NEEDS_DISCUSSION
988
  )
989
  ],
990
  hash="arch_008",
 
810
  line_number=2,
811
  description="Tight coupling: service depends on concrete implementation instead of abstraction",
812
  keywords=["tight coupling", "dependency injection"],
813
+ required_verdict=Verdict.REQUEST_CHANGES
814
  )
815
  ],
816
  hash="arch_003",
 
881
  line_number=6,
882
  description="Missing resilience (retry, timeout, circuit breaker) on external API dependency",
883
  keywords=["retry", "resilience"],
884
+ required_verdict=Verdict.REQUEST_CHANGES
885
  )
886
  ],
887
  hash="arch_005",
 
984
  line_number=2,
985
  description="Secret leaked in code comment; should be in environment variables only",
986
  keywords=["secret", "comment"],
987
+ required_verdict=Verdict.REQUEST_CHANGES
988
  )
989
  ],
990
  hash="arch_008",
inference.py CHANGED
@@ -165,9 +165,9 @@ def sanitize_action(action_dict: dict, task_id: str) -> dict:
165
 
166
  elif action_type in ("approve", "request_changes"):
167
  if action_type == "approve":
168
- action_dict["verdict"] = "LGTM"
169
  else:
170
- action_dict["verdict"] = "REQUEST_CHANGES"
171
  if "body" not in action_dict:
172
  action_dict["body"] = "Review complete."
173
 
 
165
 
166
  elif action_type in ("approve", "request_changes"):
167
  if action_type == "approve":
168
+ action_dict["verdict"] = "lgtm"
169
  else:
170
+ action_dict["verdict"] = "request_changes"
171
  if "body" not in action_dict:
172
  action_dict["body"] = "Review complete."
173