codemaverick2 committed on
Commit
78f3eb2
·
1 Parent(s): e48a1e4

Add diversity/exploration bonuses, near-miss type check, context truncation

Browse files
README.md CHANGED
@@ -208,7 +208,9 @@ Near-miss (±3-5 lines): graduated partial credit via exponential decay
208
 | TP + early (first 40% of steps) | +0.02 bonus |
209
 | TP + high confidence (≥0.7) | +0.01 bonus |
210
 | PBRS potential shaping (Φ(s')−Φ(s)) | +0.03–0.08 |
211
- | Near-miss (±3-5 lines, exponential decay) | +0.020–0.055 |


212
 | False positive | −0.05 |
213
 | False positive flood (4th+ FP) | escalating −0.03 extra |
214
 | High-confidence FP | −0.03 extra |
@@ -220,7 +222,9 @@ Near-miss (±3-5 lines): graduated partial credit via exponential decay
220
 ### Reward shaping foundations
221
 
222
 - **Potential-Based Reward Shaping** (Ng et al. 1999): Φ(s) = (tp/total_gt) × 0.5. Policy-invariant shaping that improves sample efficiency without changing the optimal policy.
223
- - **Graduated near-miss** (exponential decay): reward = 0.10 × e^(−0.6 × (line_diff − 2)) for lines 3-5 off. Gives smooth gradient signal for line-number refinement.
 
 
224
  - **Variable-Length Return Normalization** (VL Norm 2025): normalized_return = cumulative_reward / steps_used. Makes return comparable across tasks of different lengths.
225
  - **Flood protection**: escalating FP penalty prevents reward hacking via flag-spamming.
226
 
 
208
  | TP + early (first 40% of steps) | +0.02 bonus |
209
 | TP + high confidence (≥0.7) | +0.01 bonus |
210
 | PBRS potential shaping (Φ(s')−Φ(s)) | +0.03–0.08 |
211
+ | Diversity bonus (first TP in new issue category) | +0.02 |
212
+ | Exploration bonus (first TP in new file, multi-file tasks) | +0.01 |
213
+ | Near-miss (±3-5 lines, compatible type, exp decay) | +0.020–0.055 |
214
 | False positive | −0.05 |
215
 | False positive flood (4th+ FP) | escalating −0.03 extra |
216
 | High-confidence FP | −0.03 extra |
 
222
  ### Reward shaping foundations
223
 
224
 - **Potential-Based Reward Shaping** (Ng et al. 1999): Φ(s) = (tp/total_gt) × 0.5. Policy-invariant shaping that improves sample efficiency without changing the optimal policy.
225
+ - **Graduated near-miss** (exponential decay): reward = 0.10 × e^(−0.6 × (line_diff − 2)) for lines 3-5 off with compatible issue type. Gives smooth gradient signal for line-number refinement.
226
+ - **Diversity bonus**: +0.02 for first TP in a new issue category (security/bug/performance). Encourages covering all issue types instead of spamming one.
227
+ - **Exploration bonus**: +0.01 for first TP in a new file (multi-file tasks only). Encourages cross-file coverage.
228
  - **Variable-Length Return Normalization** (VL Norm 2025): normalized_return = cumulative_reward / steps_used. Makes return comparable across tasks of different lengths.
229
  - **Flood protection**: escalating FP penalty prevents reward hacking via flag-spamming.
230
 
inference.py CHANGED
@@ -404,6 +404,12 @@ def run_task(task_id: str, http_client: httpx.Client) -> dict:
404
  if combined_feedback:
405
  messages.append({"role": "user", "content": combined_feedback})
406
 
 
 
 
 
 
 
407
  atype = action.get("action_type", "")
408
  print(f" Step {step_count:2d}: {atype:20s} | reward={str(last_reward):8s} | score={obs.get('current_score', 0.0):.3f}")
409
 
 
404
  if combined_feedback:
405
  messages.append({"role": "user", "content": combined_feedback})
406
 
407
+ # Context window management: keep system + initial prompt + last 12 exchanges
408
+ # This prevents token limit errors on long episodes (25+ steps)
409
+ max_history = 2 + 24 # system + initial user + 12 assistant/user pairs
410
+ if len(messages) > max_history:
411
+ messages = messages[:2] + messages[-(max_history - 2):]
412
+
413
  atype = action.get("action_type", "")
414
  print(f" Step {step_count:2d}: {atype:20s} | reward={str(last_reward):8s} | score={obs.get('current_score', 0.0):.3f}")
415
 
server/environment.py CHANGED
@@ -44,6 +44,10 @@ _VALIDATION_PENALTY = -0.02
44
  # Flood protection: escalating FP penalty
45
  _FP_FLOOD_THRESHOLD = 3 # FPs before escalation kicks in
46
  _FP_FLOOD_MULTIPLIER = 1.5 # each extra FP beyond threshold costs 1.5x more
 
 
 
 
47
 
48
  _SEV_RANK = {"low": 0, "medium": 1, "high": 2, "critical": 3}
49
 
@@ -80,6 +84,8 @@ class CodeReviewEnvironment(_BaseEnv):
80
  self._fp_count: int = 0 # total false positives this episode
81
  self._matched_gt_indices: Set[int] = set() # GT indices already matched
82
  self._episode_rewards: List[float] = [] # for VL return normalization
 
 
83
 
84
  def reset(
85
  self,
@@ -104,6 +110,8 @@ class CodeReviewEnvironment(_BaseEnv):
104
  self._fp_count = 0
105
  self._matched_gt_indices = set()
106
  self._episode_rewards = []
 
 
107
 
108
  self._state = ReviewState(
109
  task_id=task_id,
@@ -401,6 +409,11 @@ class CodeReviewEnvironment(_BaseEnv):
401
  fix_suggestion=action.fix_suggestion,
402
  )
403
 
 
 
 
 
 
404
  # Classify: TP, near-miss (with line distance), or FP
405
  is_tp = False
406
  is_near = False
@@ -460,16 +473,32 @@ class CodeReviewEnvironment(_BaseEnv):
460
  pbrs_bonus = round(phi_after - phi_before, 4)
461
  reward_breakdown["pbrs_shaping"] = pbrs_bonus
462
 
463
- reward = base_reward + severity_bonus + temporal_bonus + confidence_bonus + pbrs_bonus
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
464
  reward_breakdown["total"] = round(reward, 4)
465
 
466
  sev_note = f", severity +{severity_bonus:.2f}" if severity_bonus else ""
467
  temp_note = f", early +{temporal_bonus:.2f}" if temporal_bonus else ""
468
  conf_note = f", conf +{confidence_bonus:.2f}" if confidence_bonus else ""
469
  pbrs_note = f", progress +{pbrs_bonus:.2f}" if pbrs_bonus > 0 else ""
 
470
  feedback = (
471
  f"Correct! Issue at {action.filename}:{action.line_number} confirmed. "
472
- f"[+{reward:.2f}{sev_note}{temp_note}{conf_note}{pbrs_note}]"
473
  )
474
 
475
  elif is_near:
 
44
  # Flood protection: escalating FP penalty
45
  _FP_FLOOD_THRESHOLD = 3 # FPs before escalation kicks in
46
  _FP_FLOOD_MULTIPLIER = 1.5 # each extra FP beyond threshold costs 1.5x more
47
+ # Diversity bonus: reward for covering a new issue category
48
+ _DIVERSITY_BONUS = 0.02 # first TP in a new issue_type category
49
+ # Exploration bonus: first flag in a previously unflagged file
50
+ _FILE_EXPLORATION_BONUS = 0.01
51
 
52
  _SEV_RANK = {"low": 0, "medium": 1, "high": 2, "critical": 3}
53
 
 
84
  self._fp_count: int = 0 # total false positives this episode
85
  self._matched_gt_indices: Set[int] = set() # GT indices already matched
86
  self._episode_rewards: List[float] = [] # for VL return normalization
87
+ self._found_categories: Set[str] = set() # issue types already found (for diversity bonus)
88
+ self._flagged_files: Set[str] = set() # files already flagged (for exploration bonus)
89
 
90
  def reset(
91
  self,
 
110
  self._fp_count = 0
111
  self._matched_gt_indices = set()
112
  self._episode_rewards = []
113
+ self._found_categories = set()
114
+ self._flagged_files = set()
115
 
116
  self._state = ReviewState(
117
  task_id=task_id,
 
409
  fix_suggestion=action.fix_suggestion,
410
  )
411
 
412
+ # Track file exploration
413
+ is_new_file = action.filename not in self._flagged_files
414
+ if action.filename:
415
+ self._flagged_files.add(action.filename)
416
+
417
  # Classify: TP, near-miss (with line distance), or FP
418
  is_tp = False
419
  is_near = False
 
473
  pbrs_bonus = round(phi_after - phi_before, 4)
474
  reward_breakdown["pbrs_shaping"] = pbrs_bonus
475
 
476
+ # Diversity bonus: first TP in a new issue category
477
+ diversity_bonus = 0.0
478
+ gt_type = matched_gt_issue.issue_type
479
+ if gt_type not in self._found_categories:
480
+ self._found_categories.add(gt_type)
481
+ diversity_bonus = _DIVERSITY_BONUS
482
+ reward_breakdown["diversity_bonus"] = diversity_bonus
483
+
484
+ # Exploration bonus: first flag in a new file (multi-file tasks)
485
+ exploration_bonus = 0.0
486
+ if is_new_file and len(self._task.get("code_files", {})) > 1:
487
+ exploration_bonus = _FILE_EXPLORATION_BONUS
488
+ reward_breakdown["exploration_bonus"] = exploration_bonus
489
+
490
+ reward = (base_reward + severity_bonus + temporal_bonus +
491
+ confidence_bonus + pbrs_bonus + diversity_bonus + exploration_bonus)
492
  reward_breakdown["total"] = round(reward, 4)
493
 
494
  sev_note = f", severity +{severity_bonus:.2f}" if severity_bonus else ""
495
  temp_note = f", early +{temporal_bonus:.2f}" if temporal_bonus else ""
496
  conf_note = f", conf +{confidence_bonus:.2f}" if confidence_bonus else ""
497
  pbrs_note = f", progress +{pbrs_bonus:.2f}" if pbrs_bonus > 0 else ""
498
+ div_note = f", new-type +{diversity_bonus:.2f}" if diversity_bonus else ""
499
  feedback = (
500
  f"Correct! Issue at {action.filename}:{action.line_number} confirmed. "
501
+ f"[+{reward:.2f}{sev_note}{temp_note}{conf_note}{pbrs_note}{div_note}]"
502
  )
503
 
504
  elif is_near:
server/graders.py CHANGED
@@ -58,21 +58,23 @@ def match_quality(flagged: Issue, gt: Issue) -> str:
58
  """
59
  Return quality of match between flagged and gt:
60
 "exact" — within ±2 lines and right issue type
61
- "near" — within ±3-5 lines and same file (regardless of type)
62
 "none" — no meaningful match
63
  """
64
  if flagged.filename != gt.filename:
65
  return "none"
66
 
67
  line_diff = abs(flagged.line_number - gt.line_number)
 
68
 
69
  if line_diff <= EXACT_TOLERANCE:
70
- compat = _TYPE_COMPAT.get(gt.issue_type, {gt.issue_type})
71
  if flagged.issue_type in compat:
72
  return "exact"
73
 
74
  if line_diff <= NEAR_TOLERANCE:
75
- return "near"
 
 
76
 
77
  return "none"
78
 
 
58
  """
59
  Return quality of match between flagged and gt:
60
 "exact" — within ±2 lines and right issue type
61
+ "near" — within ±3-5 lines, same file, and compatible issue type
62
 "none" — no meaningful match
63
  """
64
  if flagged.filename != gt.filename:
65
  return "none"
66
 
67
  line_diff = abs(flagged.line_number - gt.line_number)
68
+ compat = _TYPE_COMPAT.get(gt.issue_type, {gt.issue_type})
69
 
70
  if line_diff <= EXACT_TOLERANCE:
 
71
  if flagged.issue_type in compat:
72
  return "exact"
73
 
74
  if line_diff <= NEAR_TOLERANCE:
75
+ # Near-miss requires compatible type to avoid rewarding wrong-type flags
76
+ if flagged.issue_type in compat:
77
+ return "near"
78
 
79
  return "none"
80
 
tests/test_environment.py CHANGED
@@ -838,3 +838,56 @@ class TestFunctionRanges:
838
  def test_function_ranges_nonempty_for_python(self, env):
839
  obs = env.reset(task_id="bug-detection")
840
  assert len(obs.code_metadata["function_ranges"]) > 0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
838
  def test_function_ranges_nonempty_for_python(self, env):
839
  obs = env.reset(task_id="bug-detection")
840
  assert len(obs.code_metadata["function_ranges"]) > 0
841
+
842
+
843
+ # ---------------------------------------------------------------------------
844
+ # Diversity bonus
845
+ # ---------------------------------------------------------------------------
846
+
847
+ class TestDiversityBonus:
848
+ def test_first_tp_in_category_gets_diversity_bonus(self, env):
849
+ """First TP in a new issue category should include diversity_bonus."""
850
+ env.reset(task_id="security-audit")
851
+ obs = env.step(ReviewAction(
852
+ action_type="flag_issue", line_number=8, filename="app.py",
853
+ issue_type="security", severity="high", description="hardcoded secret"
854
+ ))
855
+ # First security TP → should have diversity bonus
856
+ assert obs.reward_breakdown.get("diversity_bonus", 0) > 0
857
+
858
+ def test_second_tp_same_category_no_diversity_bonus(self, env):
859
+ """Second TP in same category should NOT get diversity bonus."""
860
+ env.reset(task_id="security-audit")
861
+ env.step(ReviewAction(
862
+ action_type="flag_issue", line_number=8, filename="app.py",
863
+ issue_type="security", severity="high", description="hardcoded secret"
864
+ ))
865
+ obs2 = env.step(ReviewAction(
866
+ action_type="flag_issue", line_number=19, filename="app.py",
867
+ issue_type="security", severity="critical", description="sql injection"
868
+ ))
869
+ assert obs2.reward_breakdown.get("diversity_bonus", 0) == 0
870
+
871
+
872
+ # ---------------------------------------------------------------------------
873
+ # Exploration bonus (multi-file tasks)
874
+ # ---------------------------------------------------------------------------
875
+
876
+ class TestExplorationBonus:
877
+ def test_multifile_first_flag_gets_exploration_bonus(self, env):
878
+ """First flag in a new file of a multi-file task gets exploration bonus."""
879
+ env.reset(task_id="comprehensive-review")
880
+ obs = env.step(ReviewAction(
881
+ action_type="flag_issue", line_number=7, filename="models.py",
882
+ issue_type="security", severity="critical", description="plaintext password"
883
+ ))
884
+ assert obs.reward_breakdown.get("exploration_bonus", 0) > 0
885
+
886
+ def test_singlefile_no_exploration_bonus(self, env):
887
+ """Single-file tasks should not give exploration bonus."""
888
+ env.reset(task_id="bug-detection")
889
+ obs = env.step(ReviewAction(
890
+ action_type="flag_issue", line_number=6, filename="utils.py",
891
+ issue_type="bug", severity="high", description="off by one"
892
+ ))
893
+ assert obs.reward_breakdown.get("exploration_bonus", 0) == 0
tests/test_graders.py CHANGED
@@ -104,11 +104,18 @@ class TestMatchQuality:
104
  gt = _issue(6, "utils.py", "bug", "high")
105
  assert match_quality(f, gt) == "none"
106
 
107
- def test_near_ignores_type_difference(self):
108
- """Near match checks same file + line range, ignores type."""
109
  f = _issue(10, "utils.py", "performance", "high")
110
  gt = _issue(6, "utils.py", "bug", "high")
111
- # 4 lines away → near
 
 
 
 
 
 
 
112
  assert match_quality(f, gt) == "near"
113
 
114
  def test_near_tolerance_constant(self):
 
104
  gt = _issue(6, "utils.py", "bug", "high")
105
  assert match_quality(f, gt) == "none"
106
 
107
+ def test_near_requires_compatible_type(self):
108
+ """Near match requires compatible issue type (not just proximity)."""
109
  f = _issue(10, "utils.py", "performance", "high")
110
  gt = _issue(6, "utils.py", "bug", "high")
111
+ # 4 lines away but wrong type β†’ none
112
+ assert match_quality(f, gt) == "none"
113
+
114
+ def test_near_with_compatible_type(self):
115
+ """Near match works with compatible type (bug/logic)."""
116
+ f = _issue(10, "utils.py", "logic", "high")
117
+ gt = _issue(6, "utils.py", "bug", "high")
118
+ # 4 lines away, compatible type → near
119
  assert match_quality(f, gt) == "near"
120
 
121
  def test_near_tolerance_constant(self):