Spaces:
Sleeping
Sleeping
codemaverick2 committed on
Commit ·
78f3eb2
1
Parent(s): e48a1e4
Add diversity/exploration bonuses, near-miss type check, context truncation
Browse files- README.md +6 -2
- inference.py +6 -0
- server/environment.py +31 -2
- server/graders.py +5 -3
- tests/test_environment.py +53 -0
- tests/test_graders.py +10 -3
README.md
CHANGED
|
@@ -208,7 +208,9 @@ Near-miss (±3-5 lines): graduated partial credit via exponential decay
|
|
| 208 |
| TP + early (first 40% of steps) | +0.02 bonus |
|
| 209 |
| TP + high confidence (≥0.7) | +0.01 bonus |
|
| 210 |
| PBRS potential shaping (Φ(s')−Φ(s)) | +0.03–0.08 |
|
| 211 |
-
|
|
|
|
|
|
|
|
| 212 |
| False positive | −0.05 |
|
| 213 |
| False positive flood (4th+ FP) | escalating −0.03 extra |
|
| 214 |
| High-confidence FP | −0.03 extra |
|
|
@@ -220,7 +222,9 @@ Near-miss (±3-5 lines): graduated partial credit via exponential decay
|
|
| 220 |
### Reward shaping foundations
|
| 221 |
|
| 222 |
- **Potential-Based Reward Shaping** (Ng et al. 1999): Φ(s) = (tp/total_gt) × 0.5. Policy-invariant shaping that improves sample efficiency without changing the optimal policy.
|
| 223 |
-
- **Graduated near-miss** (exponential decay): reward = 0.10 × e^(−0.6 × (line_diff − 2)) for lines 3-5 off. Gives smooth gradient signal for line-number refinement.
|
|
|
|
|
|
|
| 224 |
- **Variable-Length Return Normalization** (VL Norm 2025): normalized_return = cumulative_reward / steps_used. Makes return comparable across tasks of different lengths.
|
| 225 |
- **Flood protection**: escalating FP penalty prevents reward hacking via flag-spamming.
|
| 226 |
|
|
|
|
| 208 |
| TP + early (first 40% of steps) | +0.02 bonus |
|
| 209 |
| TP + high confidence (≥0.7) | +0.01 bonus |
|
| 210 |
| PBRS potential shaping (Φ(s')−Φ(s)) | +0.03–0.08 |
|
| 211 |
+
| Diversity bonus (first TP in new issue category) | +0.02 |
|
| 212 |
+
| Exploration bonus (first TP in new file, multi-file tasks) | +0.01 |
|
| 213 |
+
| Near-miss (±3-5 lines, compatible type, exp decay) | +0.020–0.055 |
|
| 214 |
| False positive | −0.05 |
|
| 215 |
| False positive flood (4th+ FP) | escalating −0.03 extra |
|
| 216 |
| High-confidence FP | −0.03 extra |
|
|
|
|
| 222 |
### Reward shaping foundations
|
| 223 |
|
| 224 |
- **Potential-Based Reward Shaping** (Ng et al. 1999): Φ(s) = (tp/total_gt) × 0.5. Policy-invariant shaping that improves sample efficiency without changing the optimal policy.
|
| 225 |
+
- **Graduated near-miss** (exponential decay): reward = 0.10 × e^(−0.6 × (line_diff − 2)) for lines 3-5 off with compatible issue type. Gives smooth gradient signal for line-number refinement.
|
| 226 |
+
- **Diversity bonus**: +0.02 for first TP in a new issue category (security/bug/performance). Encourages covering all issue types instead of spamming one.
|
| 227 |
+
- **Exploration bonus**: +0.01 for first TP in a new file (multi-file tasks only). Encourages cross-file coverage.
|
| 228 |
- **Variable-Length Return Normalization** (VL Norm 2025): normalized_return = cumulative_reward / steps_used. Makes return comparable across tasks of different lengths.
|
| 229 |
- **Flood protection**: escalating FP penalty prevents reward hacking via flag-spamming.
|
| 230 |
|
inference.py
CHANGED
|
@@ -404,6 +404,12 @@ def run_task(task_id: str, http_client: httpx.Client) -> dict:
|
|
| 404 |
if combined_feedback:
|
| 405 |
messages.append({"role": "user", "content": combined_feedback})
|
| 406 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 407 |
atype = action.get("action_type", "")
|
| 408 |
print(f" Step {step_count:2d}: {atype:20s} | reward={str(last_reward):8s} | score={obs.get('current_score', 0.0):.3f}")
|
| 409 |
|
|
|
|
| 404 |
if combined_feedback:
|
| 405 |
messages.append({"role": "user", "content": combined_feedback})
|
| 406 |
|
| 407 |
+
# Context window management: keep system + initial prompt + last 12 exchanges
|
| 408 |
+
# This prevents token limit errors on long episodes (25+ steps)
|
| 409 |
+
max_history = 2 + 24 # system + initial user + 12 assistant/user pairs
|
| 410 |
+
if len(messages) > max_history:
|
| 411 |
+
messages = messages[:2] + messages[-(max_history - 2):]
|
| 412 |
+
|
| 413 |
atype = action.get("action_type", "")
|
| 414 |
print(f" Step {step_count:2d}: {atype:20s} | reward={str(last_reward):8s} | score={obs.get('current_score', 0.0):.3f}")
|
| 415 |
|
server/environment.py
CHANGED
|
@@ -44,6 +44,10 @@ _VALIDATION_PENALTY = -0.02
|
|
| 44 |
# Flood protection: escalating FP penalty
|
| 45 |
_FP_FLOOD_THRESHOLD = 3 # FPs before escalation kicks in
|
| 46 |
_FP_FLOOD_MULTIPLIER = 1.5 # each extra FP beyond threshold costs 1.5x more
|
|
|
|
|
|
|
|
|
|
|
|
|
| 47 |
|
| 48 |
_SEV_RANK = {"low": 0, "medium": 1, "high": 2, "critical": 3}
|
| 49 |
|
|
@@ -80,6 +84,8 @@ class CodeReviewEnvironment(_BaseEnv):
|
|
| 80 |
self._fp_count: int = 0 # total false positives this episode
|
| 81 |
self._matched_gt_indices: Set[int] = set() # GT indices already matched
|
| 82 |
self._episode_rewards: List[float] = [] # for VL return normalization
|
|
|
|
|
|
|
| 83 |
|
| 84 |
def reset(
|
| 85 |
self,
|
|
@@ -104,6 +110,8 @@ class CodeReviewEnvironment(_BaseEnv):
|
|
| 104 |
self._fp_count = 0
|
| 105 |
self._matched_gt_indices = set()
|
| 106 |
self._episode_rewards = []
|
|
|
|
|
|
|
| 107 |
|
| 108 |
self._state = ReviewState(
|
| 109 |
task_id=task_id,
|
|
@@ -401,6 +409,11 @@ class CodeReviewEnvironment(_BaseEnv):
|
|
| 401 |
fix_suggestion=action.fix_suggestion,
|
| 402 |
)
|
| 403 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 404 |
# Classify: TP, near-miss (with line distance), or FP
|
| 405 |
is_tp = False
|
| 406 |
is_near = False
|
|
@@ -460,16 +473,32 @@ class CodeReviewEnvironment(_BaseEnv):
|
|
| 460 |
pbrs_bonus = round(phi_after - phi_before, 4)
|
| 461 |
reward_breakdown["pbrs_shaping"] = pbrs_bonus
|
| 462 |
|
| 463 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 464 |
reward_breakdown["total"] = round(reward, 4)
|
| 465 |
|
| 466 |
sev_note = f", severity +{severity_bonus:.2f}" if severity_bonus else ""
|
| 467 |
temp_note = f", early +{temporal_bonus:.2f}" if temporal_bonus else ""
|
| 468 |
conf_note = f", conf +{confidence_bonus:.2f}" if confidence_bonus else ""
|
| 469 |
pbrs_note = f", progress +{pbrs_bonus:.2f}" if pbrs_bonus > 0 else ""
|
|
|
|
| 470 |
feedback = (
|
| 471 |
f"Correct! Issue at {action.filename}:{action.line_number} confirmed. "
|
| 472 |
-
f"[+{reward:.2f}{sev_note}{temp_note}{conf_note}{pbrs_note}]"
|
| 473 |
)
|
| 474 |
|
| 475 |
elif is_near:
|
|
|
|
| 44 |
# Flood protection: escalating FP penalty
|
| 45 |
_FP_FLOOD_THRESHOLD = 3 # FPs before escalation kicks in
|
| 46 |
_FP_FLOOD_MULTIPLIER = 1.5 # each extra FP beyond threshold costs 1.5x more
|
| 47 |
+
# Diversity bonus: reward for covering a new issue category
|
| 48 |
+
_DIVERSITY_BONUS = 0.02 # first TP in a new issue_type category
|
| 49 |
+
# Exploration bonus: first flag in a previously unflagged file
|
| 50 |
+
_FILE_EXPLORATION_BONUS = 0.01
|
| 51 |
|
| 52 |
_SEV_RANK = {"low": 0, "medium": 1, "high": 2, "critical": 3}
|
| 53 |
|
|
|
|
| 84 |
self._fp_count: int = 0 # total false positives this episode
|
| 85 |
self._matched_gt_indices: Set[int] = set() # GT indices already matched
|
| 86 |
self._episode_rewards: List[float] = [] # for VL return normalization
|
| 87 |
+
self._found_categories: Set[str] = set() # issue types already found (for diversity bonus)
|
| 88 |
+
self._flagged_files: Set[str] = set() # files already flagged (for exploration bonus)
|
| 89 |
|
| 90 |
def reset(
|
| 91 |
self,
|
|
|
|
| 110 |
self._fp_count = 0
|
| 111 |
self._matched_gt_indices = set()
|
| 112 |
self._episode_rewards = []
|
| 113 |
+
self._found_categories = set()
|
| 114 |
+
self._flagged_files = set()
|
| 115 |
|
| 116 |
self._state = ReviewState(
|
| 117 |
task_id=task_id,
|
|
|
|
| 409 |
fix_suggestion=action.fix_suggestion,
|
| 410 |
)
|
| 411 |
|
| 412 |
+
# Track file exploration
|
| 413 |
+
is_new_file = action.filename not in self._flagged_files
|
| 414 |
+
if action.filename:
|
| 415 |
+
self._flagged_files.add(action.filename)
|
| 416 |
+
|
| 417 |
# Classify: TP, near-miss (with line distance), or FP
|
| 418 |
is_tp = False
|
| 419 |
is_near = False
|
|
|
|
| 473 |
pbrs_bonus = round(phi_after - phi_before, 4)
|
| 474 |
reward_breakdown["pbrs_shaping"] = pbrs_bonus
|
| 475 |
|
| 476 |
+
# Diversity bonus: first TP in a new issue category
|
| 477 |
+
diversity_bonus = 0.0
|
| 478 |
+
gt_type = matched_gt_issue.issue_type
|
| 479 |
+
if gt_type not in self._found_categories:
|
| 480 |
+
self._found_categories.add(gt_type)
|
| 481 |
+
diversity_bonus = _DIVERSITY_BONUS
|
| 482 |
+
reward_breakdown["diversity_bonus"] = diversity_bonus
|
| 483 |
+
|
| 484 |
+
# Exploration bonus: first flag in a new file (multi-file tasks)
|
| 485 |
+
exploration_bonus = 0.0
|
| 486 |
+
if is_new_file and len(self._task.get("code_files", {})) > 1:
|
| 487 |
+
exploration_bonus = _FILE_EXPLORATION_BONUS
|
| 488 |
+
reward_breakdown["exploration_bonus"] = exploration_bonus
|
| 489 |
+
|
| 490 |
+
reward = (base_reward + severity_bonus + temporal_bonus +
|
| 491 |
+
confidence_bonus + pbrs_bonus + diversity_bonus + exploration_bonus)
|
| 492 |
reward_breakdown["total"] = round(reward, 4)
|
| 493 |
|
| 494 |
sev_note = f", severity +{severity_bonus:.2f}" if severity_bonus else ""
|
| 495 |
temp_note = f", early +{temporal_bonus:.2f}" if temporal_bonus else ""
|
| 496 |
conf_note = f", conf +{confidence_bonus:.2f}" if confidence_bonus else ""
|
| 497 |
pbrs_note = f", progress +{pbrs_bonus:.2f}" if pbrs_bonus > 0 else ""
|
| 498 |
+
div_note = f", new-type +{diversity_bonus:.2f}" if diversity_bonus else ""
|
| 499 |
feedback = (
|
| 500 |
f"Correct! Issue at {action.filename}:{action.line_number} confirmed. "
|
| 501 |
+
f"[+{reward:.2f}{sev_note}{temp_note}{conf_note}{pbrs_note}{div_note}]"
|
| 502 |
)
|
| 503 |
|
| 504 |
elif is_near:
|
server/graders.py
CHANGED
|
@@ -58,21 +58,23 @@ def match_quality(flagged: Issue, gt: Issue) -> str:
|
|
| 58 |
"""
|
| 59 |
Return quality of match between flagged and gt:
|
| 60 |
"exact" β within Β±2 lines and right issue type
|
| 61 |
-
"near" β within Β±3-5 lines
|
| 62 |
"none" β no meaningful match
|
| 63 |
"""
|
| 64 |
if flagged.filename != gt.filename:
|
| 65 |
return "none"
|
| 66 |
|
| 67 |
line_diff = abs(flagged.line_number - gt.line_number)
|
|
|
|
| 68 |
|
| 69 |
if line_diff <= EXACT_TOLERANCE:
|
| 70 |
-
compat = _TYPE_COMPAT.get(gt.issue_type, {gt.issue_type})
|
| 71 |
if flagged.issue_type in compat:
|
| 72 |
return "exact"
|
| 73 |
|
| 74 |
if line_diff <= NEAR_TOLERANCE:
|
| 75 |
-
|
|
|
|
|
|
|
| 76 |
|
| 77 |
return "none"
|
| 78 |
|
|
|
|
| 58 |
"""
|
| 59 |
Return quality of match between flagged and gt:
|
| 60 |
"exact" β within Β±2 lines and right issue type
|
| 61 |
+
"near" β within Β±3-5 lines, same file, and compatible issue type
|
| 62 |
"none" β no meaningful match
|
| 63 |
"""
|
| 64 |
if flagged.filename != gt.filename:
|
| 65 |
return "none"
|
| 66 |
|
| 67 |
line_diff = abs(flagged.line_number - gt.line_number)
|
| 68 |
+
compat = _TYPE_COMPAT.get(gt.issue_type, {gt.issue_type})
|
| 69 |
|
| 70 |
if line_diff <= EXACT_TOLERANCE:
|
|
|
|
| 71 |
if flagged.issue_type in compat:
|
| 72 |
return "exact"
|
| 73 |
|
| 74 |
if line_diff <= NEAR_TOLERANCE:
|
| 75 |
+
# Near-miss requires compatible type to avoid rewarding wrong-type flags
|
| 76 |
+
if flagged.issue_type in compat:
|
| 77 |
+
return "near"
|
| 78 |
|
| 79 |
return "none"
|
| 80 |
|
tests/test_environment.py
CHANGED
|
@@ -838,3 +838,56 @@ class TestFunctionRanges:
|
|
| 838 |
def test_function_ranges_nonempty_for_python(self, env):
|
| 839 |
obs = env.reset(task_id="bug-detection")
|
| 840 |
assert len(obs.code_metadata["function_ranges"]) > 0
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 838 |
def test_function_ranges_nonempty_for_python(self, env):
|
| 839 |
obs = env.reset(task_id="bug-detection")
|
| 840 |
assert len(obs.code_metadata["function_ranges"]) > 0
|
| 841 |
+
|
| 842 |
+
|
| 843 |
+
# ---------------------------------------------------------------------------
|
| 844 |
+
# Diversity bonus
|
| 845 |
+
# ---------------------------------------------------------------------------
|
| 846 |
+
|
| 847 |
+
class TestDiversityBonus:
|
| 848 |
+
def test_first_tp_in_category_gets_diversity_bonus(self, env):
|
| 849 |
+
"""First TP in a new issue category should include diversity_bonus."""
|
| 850 |
+
env.reset(task_id="security-audit")
|
| 851 |
+
obs = env.step(ReviewAction(
|
| 852 |
+
action_type="flag_issue", line_number=8, filename="app.py",
|
| 853 |
+
issue_type="security", severity="high", description="hardcoded secret"
|
| 854 |
+
))
|
| 855 |
+
# First security TP β should have diversity bonus
|
| 856 |
+
assert obs.reward_breakdown.get("diversity_bonus", 0) > 0
|
| 857 |
+
|
| 858 |
+
def test_second_tp_same_category_no_diversity_bonus(self, env):
|
| 859 |
+
"""Second TP in same category should NOT get diversity bonus."""
|
| 860 |
+
env.reset(task_id="security-audit")
|
| 861 |
+
env.step(ReviewAction(
|
| 862 |
+
action_type="flag_issue", line_number=8, filename="app.py",
|
| 863 |
+
issue_type="security", severity="high", description="hardcoded secret"
|
| 864 |
+
))
|
| 865 |
+
obs2 = env.step(ReviewAction(
|
| 866 |
+
action_type="flag_issue", line_number=19, filename="app.py",
|
| 867 |
+
issue_type="security", severity="critical", description="sql injection"
|
| 868 |
+
))
|
| 869 |
+
assert obs2.reward_breakdown.get("diversity_bonus", 0) == 0
|
| 870 |
+
|
| 871 |
+
|
| 872 |
+
# ---------------------------------------------------------------------------
|
| 873 |
+
# Exploration bonus (multi-file tasks)
|
| 874 |
+
# ---------------------------------------------------------------------------
|
| 875 |
+
|
| 876 |
+
class TestExplorationBonus:
|
| 877 |
+
def test_multifile_first_flag_gets_exploration_bonus(self, env):
|
| 878 |
+
"""First flag in a new file of a multi-file task gets exploration bonus."""
|
| 879 |
+
env.reset(task_id="comprehensive-review")
|
| 880 |
+
obs = env.step(ReviewAction(
|
| 881 |
+
action_type="flag_issue", line_number=7, filename="models.py",
|
| 882 |
+
issue_type="security", severity="critical", description="plaintext password"
|
| 883 |
+
))
|
| 884 |
+
assert obs.reward_breakdown.get("exploration_bonus", 0) > 0
|
| 885 |
+
|
| 886 |
+
def test_singlefile_no_exploration_bonus(self, env):
|
| 887 |
+
"""Single-file tasks should not give exploration bonus."""
|
| 888 |
+
env.reset(task_id="bug-detection")
|
| 889 |
+
obs = env.step(ReviewAction(
|
| 890 |
+
action_type="flag_issue", line_number=6, filename="utils.py",
|
| 891 |
+
issue_type="bug", severity="high", description="off by one"
|
| 892 |
+
))
|
| 893 |
+
assert obs.reward_breakdown.get("exploration_bonus", 0) == 0
|
tests/test_graders.py
CHANGED
|
@@ -104,11 +104,18 @@ class TestMatchQuality:
|
|
| 104 |
gt = _issue(6, "utils.py", "bug", "high")
|
| 105 |
assert match_quality(f, gt) == "none"
|
| 106 |
|
| 107 |
-
def
|
| 108 |
-
"""Near match
|
| 109 |
f = _issue(10, "utils.py", "performance", "high")
|
| 110 |
gt = _issue(6, "utils.py", "bug", "high")
|
| 111 |
-
# 4 lines away →
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 112 |
assert match_quality(f, gt) == "near"
|
| 113 |
|
| 114 |
def test_near_tolerance_constant(self):
|
|
|
|
| 104 |
gt = _issue(6, "utils.py", "bug", "high")
|
| 105 |
assert match_quality(f, gt) == "none"
|
| 106 |
|
| 107 |
+
def test_near_requires_compatible_type(self):
|
| 108 |
+
"""Near match requires compatible issue type (not just proximity)."""
|
| 109 |
f = _issue(10, "utils.py", "performance", "high")
|
| 110 |
gt = _issue(6, "utils.py", "bug", "high")
|
| 111 |
+
# 4 lines away but wrong type → none
|
| 112 |
+
assert match_quality(f, gt) == "none"
|
| 113 |
+
|
| 114 |
+
def test_near_with_compatible_type(self):
|
| 115 |
+
"""Near match works with compatible type (bug/logic)."""
|
| 116 |
+
f = _issue(10, "utils.py", "logic", "high")
|
| 117 |
+
gt = _issue(6, "utils.py", "bug", "high")
|
| 118 |
+
# 4 lines away, compatible type → near
|
| 119 |
assert match_quality(f, gt) == "near"
|
| 120 |
|
| 121 |
def test_near_tolerance_constant(self):
|