Upload 16 files
Browse files- LICENSE +21 -0
- README.md +118 -103
- __init__.py +16 -16
- author.py +219 -219
- bugs.json +127 -127
- client.py +4 -4
- environment.py +613 -556
- grader.py +141 -147
- models.py +87 -114
- openenv.yaml +135 -135
- pyproject.toml +29 -29
- redteam.py +274 -274
- rubrics.py +136 -123
- test_runner.py +208 -181
- training.py +792 -708
LICENSE
ADDED
|
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
MIT License
|
| 2 |
+
|
| 3 |
+
Copyright (c) 2026 YUTA
|
| 4 |
+
|
| 5 |
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
| 6 |
+
of this software and associated documentation files (the "Software"), to deal
|
| 7 |
+
in the Software without restriction, including without limitation the rights
|
| 8 |
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
| 9 |
+
copies of the Software, and to permit persons to whom the Software is
|
| 10 |
+
furnished to do so, subject to the following conditions:
|
| 11 |
+
|
| 12 |
+
The above copyright notice and this permission notice shall be included in all
|
| 13 |
+
copies or substantial portions of the Software.
|
| 14 |
+
|
| 15 |
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
| 16 |
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
| 17 |
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
| 18 |
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
| 19 |
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
| 20 |
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
| 21 |
+
SOFTWARE.
|
README.md
CHANGED
|
@@ -1,103 +1,118 @@
|
|
| 1 |
-
---
|
| 2 |
-
title: Code Review Professional Workflow
|
| 3 |
-
emoji: 🔥
|
| 4 |
-
colorFrom: blue
|
| 5 |
-
colorTo: purple
|
| 6 |
-
sdk: docker
|
| 7 |
-
app_port: 7860
|
| 8 |
-
pinned: false
|
| 9 |
-
---
|
| 10 |
-
|
| 11 |
-
# Code Review Professional Workflow
|
| 12 |
-
|
| 13 |
-
|
| 14 |
-
|
| 15 |
-
|
| 16 |
-
|
| 17 |
-
|
| 18 |
-
|
| 19 |
-
|
| 20 |
-
|
| 21 |
-
|
| 22 |
-
|
| 23 |
-
|
| 24 |
-
|
| 25 |
-
|
| 26 |
-
|
| 27 |
-
|
| 28 |
-
|
| 29 |
-
|
| 30 |
-
|
| 31 |
-
|
| 32 |
-
|
| 33 |
-
|
| 34 |
-
|
| 35 |
-
-
|
| 36 |
-
|
| 37 |
-
|
| 38 |
-
|
| 39 |
-
|
| 40 |
-
|
| 41 |
-
|
| 42 |
-
|
| 43 |
-
|
| 44 |
-
|
| 45 |
-
|
| 46 |
-
|
| 47 |
-
|
| 48 |
-
|
| 49 |
-
|
| 50 |
-
|
| 51 |
-
|
| 52 |
-
|
| 53 |
-
|
| 54 |
-
|
| 55 |
-
|
| 56 |
-
##
|
| 57 |
-
|
| 58 |
-
|
| 59 |
-
|
| 60 |
-
|
| 61 |
-
|
| 62 |
-
|
| 63 |
-
|
|
| 64 |
-
|
|
| 65 |
-
|
| 66 |
-
|
| 67 |
-
|
| 68 |
-
|
|
| 69 |
-
|
|
| 70 |
-
|
| 71 |
-
|
| 72 |
-
|
| 73 |
-
|
|
| 74 |
-
|
|
| 75 |
-
|
| 76 |
-
|
| 77 |
-
|
| 78 |
-
|
|
| 79 |
-
|
|
| 80 |
-
|
| 81 |
-
|
| 82 |
-
|
| 83 |
-
|
|
| 84 |
-
|
|
| 85 |
-
|
| 86 |
-
|
| 87 |
-
|
| 88 |
-
|
|
| 89 |
-
|
|
| 90 |
-
|
| 91 |
-
|
| 92 |
-
|
| 93 |
-
|
|
| 94 |
-
|
|
| 95 |
-
|
| 96 |
-
|
| 97 |
-
```
|
| 98 |
-
|
| 99 |
-
``
|
| 100 |
-
|
| 101 |
-
##
|
| 102 |
-
|
| 103 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
---
|
| 2 |
+
title: Code Review Professional Workflow
|
| 3 |
+
emoji: 🔥
|
| 4 |
+
colorFrom: blue
|
| 5 |
+
colorTo: purple
|
| 6 |
+
sdk: docker
|
| 7 |
+
app_port: 7860
|
| 8 |
+
pinned: false
|
| 9 |
+
---
|
| 10 |
+
|
| 11 |
+
# Code Review Professional Workflow
|
| 12 |
+
|
| 13 |
+
This project is a multi-turn RL environment where an agent plays the role of a senior code reviewer.
|
| 14 |
+
Instead of just patching code, the agent must gather evidence (`inspect`, `run_tests`, `run_linter`,
|
| 15 |
+
`query_docs`) and convince a simulated developer persona to accept the fix.
|
| 16 |
+
|
| 17 |
+
### Why this environment is interesting
|
| 18 |
+
|
| 19 |
+
- It combines **technical correctness** (tests/lint) with **human acceptance** (negotiation).
|
| 20 |
+
- It includes **25 injected bug types** across 5 difficulty levels via `RedTeam`.
|
| 21 |
+
- It supports both a **full reward profile** (rich shaping) and a **core reward profile**
|
| 22 |
+
(minimal, baseline-friendly signal for ablations).
|
| 23 |
+
|
| 24 |
+
## Quick Start
|
| 25 |
+
|
| 26 |
+
```python
|
| 27 |
+
from environment import CodeReviewEnv
|
| 28 |
+
env = CodeReviewEnv(task="easy", reward_profile="full")
|
| 29 |
+
obs = env.reset()
|
| 30 |
+
print(obs.code_snippet)
|
| 31 |
+
```
|
| 32 |
+
|
| 33 |
+
## Demo Script (Non-Technical Friendly)
|
| 34 |
+
|
| 35 |
+
Use this 60-90 second flow in a demo:
|
| 36 |
+
|
| 37 |
+
1. Reset on `easy` and show the buggy snippet.
|
| 38 |
+
2. Take `inspect` and `run_tests` actions to show evidence gathering.
|
| 39 |
+
3. Ask `query_docs` once to show retrieval-assisted reasoning.
|
| 40 |
+
4. Propose a fix and show accepted/denied feedback from the author persona.
|
| 41 |
+
5. Repeat once on `harder` to show increased challenge.
|
| 42 |
+
|
| 43 |
+
Message for audience: "The agent is learning not only to fix code, but to justify and communicate the fix."
|
| 44 |
+
|
| 45 |
+
## Environment Endpoints
|
| 46 |
+
|
| 47 |
+
- `POST /reset` – reset environment (optional `task` parameter)
|
| 48 |
+
- `POST /step` – take an action (JSON)
|
| 49 |
+
- `GET /state` – get full environment state
|
| 50 |
+
- `GET /health` – health check
|
| 51 |
+
- `GET /metadata` – environment metadata
|
| 52 |
+
- `GET /schema` – action/observation schemas
|
| 53 |
+
- `POST /mcp` – minimal MCP endpoint
|
| 54 |
+
|
| 55 |
+
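## Tasks

Each task corresponds to a difficulty level (for example `easy` or `harder`); the taxonomy below lists the bugs that can be injected at each level.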
|
| 56 |
+
## 🐛 Bug Taxonomy (25 bugs across 5 difficulty levels)
|
| 57 |
+
|
| 58 |
+
The **RedTeam** randomly selects one bug from the current difficulty level at the start of every episode.
|
| 59 |
+
Your agent must figure out what’s broken, gather evidence, and convince the simulated author – or it won’t stick.
|
| 60 |
+
|
| 61 |
+
### 🟢 Easy – Null‑Checks & Simple Logic Errors
|
| 62 |
+
|
| 63 |
+
| # | Bug ID | What’s wrong | Injection method |
|
| 64 |
+
|---|--------|--------------|------------------|
|
| 65 |
+
| 1 | `null_check` | Missing `if key in dict:` guard → KeyError | AST: remove the if‑statement |
|
| 66 |
+
| 2 | `simple_typo` | Misspelled variable `users` → `usres` | AST: rename variable |
|
| 67 |
+
| 3 | `string_index` | String index shifted by +1 | AST: change constant in index |
|
| 68 |
+
| 4 | `default_value` | `dict.get(key, default)` replaced by bare `dict[key]` → KeyError when the key is missing | AST: replace `dict.get(key, default)` with `dict[key]` |
|
| 69 |
+
| 5 | `empty_return` | Function returns `None` prematurely | AST: insert `return None` early |
|
| 70 |
+
|
| 71 |
+
### 🟡 Medium – Off‑By‑One, Loop Logic & Simple Arithmetic
|
| 72 |
+
|
| 73 |
+
| # | Bug ID | What’s wrong | Injection method |
|
| 74 |
+
|---|--------|--------------|------------------|
|
| 75 |
+
| 6 | `off_by_one` | `range(x)` becomes `range(1, x-1)` – skips first & last | AST: modify range arguments |
|
| 76 |
+
| 7 | `loop_skip` | `range(len(arr))` becomes `range(len(arr)-1)` – misses last element | AST: change range length |
|
| 77 |
+
| 8 | `sign_error` | `sum += item` turned into `sum -= item` | AST: swap Add / Sub |
|
| 78 |
+
| 9 | `swap_args` | Function arguments swapped | AST: swap first two arguments |
|
| 79 |
+
|10 | `uninitialised_var` | Variable used before assignment in a loop | AST: remove the assignment statement |
|
| 80 |
+
|
| 81 |
+
### 🟠 Hard – Division‑By‑Zero, Floating‑Point & Edge Cases
|
| 82 |
+
|
| 83 |
+
| # | Bug ID | What’s wrong | Injection method |
|
| 84 |
+
|---|--------|--------------|------------------|
|
| 85 |
+
|11 | `division_by_zero_empty` | Empty‑list guard removed before averaging | AST: delete `if not data:` |
|
| 86 |
+
|12 | `division_by_zero_zero` | Denominator check removed | AST: remove the zero‑check |
|
| 87 |
+
|13 | `float_precision` | True division `/` replaced by integer division `//` | AST: change Div → FloorDiv |
|
| 88 |
+
|14 | `abs_usage` | `abs()` call removed when comparing differences | AST: delete `abs()` wrapper |
|
| 89 |
+
|15 | `round_error` | `round()` placed too early, causing precision drift | AST: inject `round()` prematurely |
|
| 90 |
+
|
| 91 |
+
### 🔴 Harder – Race Conditions & Atomicity Bugs
|
| 92 |
+
|
| 93 |
+
| # | Bug ID | What’s wrong | Injection method |
|
| 94 |
+
|---|--------|--------------|------------------|
|
| 95 |
+
|16 | `missing_lock` | Shared counter incremented without a lock | Template: remove `with lock:` |
|
| 96 |
+
|17 | `double_lock` | Acquiring the same lock twice → deadlock risk | Template: add extra `lock.acquire()` |
|
| 97 |
+
|18 | `global_nonatomic` | `count = count + 1` (read‑modify‑write) instead of `+=` | AST: modify assignment node |
|
| 98 |
+
|19 | `thread_safe_list` | List append across threads without synchronisation | Template: remove lock from list operation |
|
| 99 |
+
|20 | `volatile_read` | Shared flag read outside a lock → stale value | Template: remove synchronisation block |
|
| 100 |
+
|
| 101 |
+
### ⚫ Hardest – Deadlocks, Ordering & Complex Concurrency
|
| 102 |
+
|
| 103 |
+
| # | Bug ID | What’s wrong | Injection method |
|
| 104 |
+
|---|--------|--------------|------------------|
|
| 105 |
+
|21 | `deadlock_order` | Locks acquired in opposite order in two threads | Template: swap lock order |
|
| 106 |
+
|22 | `nested_lock_timeout` | `lock.acquire()` without a timeout → permanent hang | Template: remove timeout logic |
|
| 107 |
+
|23 | `fork_join` | Thread started but not joined (`join()` missing) | AST: remove `thread.join()` |
|
| 108 |
+
|24 | `mutex_release` | Lock released by a thread that never acquired it | Template: incorrect release logic |
|
| 109 |
+
|25 | `race_on_init` | Shared resource initialised after threads have started | Template: move initialisation after `join()` |
|
| 110 |
+
## Deployment
|
| 111 |
+
|
| 112 |
+
```bash
|
| 113 |
+
openenv push
|
| 114 |
+
```
|
| 115 |
+
|
| 116 |
+
## License
|
| 117 |
+
|
| 118 |
+
MIT
|
__init__.py
CHANGED
|
@@ -1,16 +1,16 @@
|
|
| 1 |
-
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
-
# All rights reserved.
|
| 3 |
-
#
|
| 4 |
-
# This source code is licensed under the BSD-style license found in the
|
| 5 |
-
# LICENSE file in the root directory of this source tree.
|
| 6 |
-
|
| 7 |
-
"""Criticrl Environment."""
|
| 8 |
-
|
| 9 |
-
from .client import CriticrlEnv
|
| 10 |
-
from .models import CriticrlAction, CriticrlObservation
|
| 11 |
-
|
| 12 |
-
__all__ = [
|
| 13 |
-
"CriticrlAction",
|
| 14 |
-
"CriticrlObservation",
|
| 15 |
-
"CriticrlEnv",
|
| 16 |
-
]
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the BSD-style license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
"""Criticrl Environment."""
|
| 8 |
+
|
| 9 |
+
from .client import CriticrlEnv
|
| 10 |
+
from .models import CriticrlAction, CriticrlObservation
|
| 11 |
+
|
| 12 |
+
__all__ = [
|
| 13 |
+
"CriticrlAction",
|
| 14 |
+
"CriticrlObservation",
|
| 15 |
+
"CriticrlEnv",
|
| 16 |
+
]
|
author.py
CHANGED
|
@@ -1,219 +1,219 @@
|
|
| 1 |
-
# cell 7 author.py – Final production version: stateful, evidence-driven, belief tracking
|
| 2 |
-
|
| 3 |
-
import re
|
| 4 |
-
import ast
|
| 5 |
-
from dataclasses import dataclass, field
|
| 6 |
-
from typing import List, Dict, Any, Optional
|
| 7 |
-
|
| 8 |
-
@dataclass
|
| 9 |
-
class PersonaAuthor:
|
| 10 |
-
"""
|
| 11 |
-
Simulates a human developer with:
|
| 12 |
-
- Continuous belief (confidence)
|
| 13 |
-
- Evidence-based reasoning
|
| 14 |
-
- Conversation memory
|
| 15 |
-
- Code inspection awareness
|
| 16 |
-
"""
|
| 17 |
-
|
| 18 |
-
personality: str = "defensive" # defensive | junior | collaborative
|
| 19 |
-
max_persuasion_rounds: int = 5
|
| 20 |
-
|
| 21 |
-
# Evidence weights
|
| 22 |
-
weight_test_pass: float = 0.5
|
| 23 |
-
weight_lint_clean: float = 0.2
|
| 24 |
-
weight_doc_found: float = 0.15
|
| 25 |
-
weight_explanation_quality: float = 0.15
|
| 26 |
-
|
| 27 |
-
# Personality thresholds
|
| 28 |
-
thresholds: Dict[str, float] = field(default_factory=lambda: {
|
| 29 |
-
"defensive": 0.7,
|
| 30 |
-
"junior": 0.3,
|
| 31 |
-
"collaborative": 0.5,
|
| 32 |
-
})
|
| 33 |
-
|
| 34 |
-
# Internal state
|
| 35 |
-
_confidence: float = 0.0
|
| 36 |
-
_conversation: List[Dict[str, Any]] = field(default_factory=list)
|
| 37 |
-
_pushback_count: int = 0
|
| 38 |
-
_last_evidence_score: float = 0.0
|
| 39 |
-
_stagnation_counter: int = 0
|
| 40 |
-
|
| 41 |
-
# ------------------------------------------------------------------
|
| 42 |
-
# Lifecycle
|
| 43 |
-
# ------------------------------------------------------------------
|
| 44 |
-
def __post_init__(self):
|
| 45 |
-
self.reset()
|
| 46 |
-
|
| 47 |
-
def reset(self):
|
| 48 |
-
self._confidence = 0.0
|
| 49 |
-
self._conversation.clear()
|
| 50 |
-
self._pushback_count = 0
|
| 51 |
-
self._last_evidence_score = 0.0
|
| 52 |
-
self._stagnation_counter = 0
|
| 53 |
-
|
| 54 |
-
# ------------------------------------------------------------------
|
| 55 |
-
# Main interaction
|
| 56 |
-
# ------------------------------------------------------------------
|
| 57 |
-
# Added weight for code change magnitude
|
| 58 |
-
weight_code_change: float = 0.1 # small change is better
|
| 59 |
-
|
| 60 |
-
def respond(self,
|
| 61 |
-
agent_comment: str = "",
|
| 62 |
-
agent_question: str = "",
|
| 63 |
-
test_results: Optional[str] = None,
|
| 64 |
-
lint_results: Optional[str] = None,
|
| 65 |
-
doc_results: Optional[str] = None,
|
| 66 |
-
proposed_fix: Optional[str] = None,
|
| 67 |
-
original_code: Optional[str] = None) -> str:
|
| 68 |
-
|
| 69 |
-
# Store conversation
|
| 70 |
-
self._conversation.append({
|
| 71 |
-
"comment": agent_comment,
|
| 72 |
-
"question": agent_question,
|
| 73 |
-
"test": test_results,
|
| 74 |
-
"lint": lint_results,
|
| 75 |
-
"docs": doc_results
|
| 76 |
-
})
|
| 77 |
-
|
| 78 |
-
# Extract structured evidence
|
| 79 |
-
evidence = self._extract_evidence(test_results, lint_results, doc_results)
|
| 80 |
-
|
| 81 |
-
# Code inspection
|
| 82 |
-
code_change = 0.0
|
| 83 |
-
if proposed_fix and original_code:
|
| 84 |
-
code_change = self._inspect_code(proposed_fix, original_code)
|
| 85 |
-
evidence["code_change"] = code_change
|
| 86 |
-
|
| 87 |
-
# Explanation score
|
| 88 |
-
text = (agent_comment + " " + agent_question).lower()
|
| 89 |
-
explanation_score = self._score_explanation(text)
|
| 90 |
-
|
| 91 |
-
# Compute evidence score – now includes code change penalty (1 - change)
|
| 92 |
-
evidence_score = (
|
| 93 |
-
self.weight_test_pass * evidence.get("test_pass_ratio", 0.0) +
|
| 94 |
-
self.weight_lint_clean * (1 - min(1.0, evidence.get("lint_errors", 0)/10)) +
|
| 95 |
-
self.weight_doc_found * (1.0 if evidence.get("doc_found") else 0.0) +
|
| 96 |
-
self.weight_explanation_quality * explanation_score +
|
| 97 |
-
self.weight_code_change * (1.0 - code_change) # surgical fix rewarded
|
| 98 |
-
)
|
| 99 |
-
|
| 100 |
-
evidence_score = max(0.0, min(1.0, evidence_score))
|
| 101 |
-
|
| 102 |
-
# Detect improvement
|
| 103 |
-
delta = evidence_score - self._last_evidence_score
|
| 104 |
-
self._last_evidence_score = evidence_score
|
| 105 |
-
|
| 106 |
-
if delta > 0.05:
|
| 107 |
-
self._stagnation_counter = 0
|
| 108 |
-
else:
|
| 109 |
-
self._stagnation_counter += 1
|
| 110 |
-
|
| 111 |
-
# Update belief (momentum)
|
| 112 |
-
lr = 0.3
|
| 113 |
-
self._confidence = (1 - lr) * self._confidence + lr * evidence_score
|
| 114 |
-
|
| 115 |
-
# Penalise stagnation
|
| 116 |
-
if self._stagnation_counter >= 2:
|
| 117 |
-
self._confidence *= 0.9
|
| 118 |
-
|
| 119 |
-
# Decision
|
| 120 |
-
threshold = self.thresholds.get(self.personality, 0.5)
|
| 121 |
-
|
| 122 |
-
if self._confidence >= threshold or self._pushback_count >= self.max_persuasion_rounds:
|
| 123 |
-
return "Alright, I'm convinced. Let's proceed with your fix."
|
| 124 |
-
|
| 125 |
-
# Otherwise push back
|
| 126 |
-
self._pushback_count += 1
|
| 127 |
-
return self._generate_pushback(evidence, text)
|
| 128 |
-
|
| 129 |
-
# ------------------------------------------------------------------
|
| 130 |
-
# Evidence extraction
|
| 131 |
-
# ------------------------------------------------------------------
|
| 132 |
-
def _extract_evidence(self, test_results, lint_results, doc_results):
|
| 133 |
-
evidence = {
|
| 134 |
-
"test_pass_ratio": 0.0,
|
| 135 |
-
"lint_errors": 0,
|
| 136 |
-
"doc_found": False
|
| 137 |
-
}
|
| 138 |
-
|
| 139 |
-
# Parse test results
|
| 140 |
-
if test_results:
|
| 141 |
-
match = re.search(r'(\d+)\s*/\s*(\d+)', test_results)
|
| 142 |
-
if match:
|
| 143 |
-
p, t = int(match.group(1)), int(match.group(2))
|
| 144 |
-
evidence["test_pass_ratio"] = p / t if t else 0.0
|
| 145 |
-
elif "true" in test_results.lower():
|
| 146 |
-
evidence["test_pass_ratio"] = 1.0
|
| 147 |
-
elif "false" in test_results.lower():
|
| 148 |
-
evidence["test_pass_ratio"] = 0.0
|
| 149 |
-
|
| 150 |
-
# Lint errors
|
| 151 |
-
if lint_results:
|
| 152 |
-
evidence["lint_errors"] = len(re.findall(r'error', lint_results.lower()))
|
| 153 |
-
|
| 154 |
-
# Docs
|
| 155 |
-
if doc_results and "no relevant" not in doc_results.lower():
|
| 156 |
-
evidence["doc_found"] = True
|
| 157 |
-
|
| 158 |
-
return evidence
|
| 159 |
-
|
| 160 |
-
# ------------------------------------------------------------------
|
| 161 |
-
# Explanation scoring
|
| 162 |
-
# ------------------------------------------------------------------
|
| 163 |
-
def _score_explanation(self, text: str) -> float:
|
| 164 |
-
score = 0.0
|
| 165 |
-
|
| 166 |
-
if "because" in text or "therefore" in text:
|
| 167 |
-
score += 0.3
|
| 168 |
-
if "test" in text or "example" in text:
|
| 169 |
-
score += 0.2
|
| 170 |
-
if len(text.split()) > 30:
|
| 171 |
-
score += 0.2
|
| 172 |
-
if "error" in text or "fix" in text:
|
| 173 |
-
score += 0.1
|
| 174 |
-
|
| 175 |
-
return min(1.0, score)
|
| 176 |
-
|
| 177 |
-
# ------------------------------------------------------------------
|
| 178 |
-
# Code inspection
|
| 179 |
-
# ------------------------------------------------------------------
|
| 180 |
-
def _inspect_code(self, new_code: str, old_code: str) -> float:
|
| 181 |
-
try:
|
| 182 |
-
t1 = ast.parse(old_code)
|
| 183 |
-
t2 = ast.parse(new_code)
|
| 184 |
-
|
| 185 |
-
n1 = len(list(ast.walk(t1)))
|
| 186 |
-
n2 = len(list(ast.walk(t2)))
|
| 187 |
-
|
| 188 |
-
change = abs(n2 - n1) / max(n1, 1)
|
| 189 |
-
return min(1.0, change)
|
| 190 |
-
except:
|
| 191 |
-
return 0.0
|
| 192 |
-
|
| 193 |
-
# ------------------------------------------------------------------
|
| 194 |
-
# Pushback generator
|
| 195 |
-
# ------------------------------------------------------------------
|
| 196 |
-
def _generate_pushback(self, evidence, text):
|
| 197 |
-
if evidence["test_pass_ratio"] < 0.5:
|
| 198 |
-
return "Tests are still failing. Show a passing case."
|
| 199 |
-
|
| 200 |
-
if evidence["lint_errors"] > 0:
|
| 201 |
-
return f"There are {evidence['lint_errors']} lint errors. Fix them."
|
| 202 |
-
|
| 203 |
-
if not evidence["doc_found"]:
|
| 204 |
-
return "Provide documentation or reference."
|
| 205 |
-
|
| 206 |
-
if "because" not in text:
|
| 207 |
-
return "Explain why this works."
|
| 208 |
-
|
| 209 |
-
if len(text.split()) < 20:
|
| 210 |
-
return "Too brief. Expand your reasoning."
|
| 211 |
-
|
| 212 |
-
return "Not convinced yet. Give a concrete example."
|
| 213 |
-
|
| 214 |
-
# ------------------------------------------------------------------
|
| 215 |
-
# Score
|
| 216 |
-
# ------------------------------------------------------------------
|
| 217 |
-
def get_negotiation_score(self) -> float:
|
| 218 |
-
penalty = 0.1 * min(3, self._pushback_count)
|
| 219 |
-
return max(0.0, min(1.0, self._confidence - penalty))
|
|
|
|
| 1 |
+
# cell 7 author.py – Final production version: stateful, evidence-driven, belief tracking
|
| 2 |
+
|
| 3 |
+
import re
|
| 4 |
+
import ast
|
| 5 |
+
from dataclasses import dataclass, field
|
| 6 |
+
from typing import List, Dict, Any, Optional
|
| 7 |
+
|
| 8 |
+
@dataclass
class PersonaAuthor:
    """
    Simulated code author who has to be persuaded to accept a reviewer's fix.

    Every call to ``respond`` converts the supplied evidence (test output,
    lint output, documentation lookup, explanation text and the size of the
    proposed change) into a score in ``[0, 1]``, blends that score into a
    running confidence value and then either accepts the fix or pushes back.

    Attributes:
        personality: Key looked up in ``thresholds`` (defensive | junior |
            collaborative). Unknown keys fall back to a threshold of 0.5.
        max_persuasion_rounds: Once this many pushbacks have been issued the
            next call to ``respond`` accepts regardless of confidence.
        weight_test_pass, weight_lint_clean, weight_doc_found,
        weight_explanation_quality, weight_code_change: Multipliers of the
            individual evidence terms. With the defaults they add up to 1.1,
            so the combined score is clamped to ``[0, 1]``.
        thresholds: Confidence needed for acceptance, per personality.
    """

    personality: str = "defensive"   # defensive | junior | collaborative
    max_persuasion_rounds: int = 5

    # Evidence weights
    weight_test_pass: float = 0.5
    weight_lint_clean: float = 0.2
    weight_doc_found: float = 0.15
    weight_explanation_quality: float = 0.15

    # Personality thresholds
    thresholds: Dict[str, float] = field(default_factory=lambda: {
        "defensive": 0.7,
        "junior": 0.3,
        "collaborative": 0.5,
    })

    # Internal state
    _confidence: float = 0.0
    _conversation: List[Dict[str, Any]] = field(default_factory=list)
    _pushback_count: int = 0
    _last_evidence_score: float = 0.0
    _stagnation_counter: int = 0

    # Weight for code change magnitude (small change is better). Declared
    # last so the generated __init__ keeps its original positional order.
    weight_code_change: float = 0.1

    # Tuning constants. They are deliberately un-annotated so that the
    # dataclass machinery treats them as plain class attributes, not fields.
    _BELIEF_LR = 0.3           # share of the new evidence score blended in
    _STAGNATION_DELTA = 0.05   # minimum score gain that counts as progress
    _STAGNATION_LIMIT = 2      # rounds without progress before decay starts
    _STAGNATION_DECAY = 0.9    # confidence multiplier while stagnating

    # ------------------------------------------------------------------
    # Lifecycle
    # ------------------------------------------------------------------
    def __post_init__(self):
        """Start every new instance from a clean negotiation state."""
        self.reset()

    def reset(self):
        """Clear confidence, conversation history and all counters."""
        self._confidence = 0.0
        self._conversation.clear()
        self._pushback_count = 0
        self._last_evidence_score = 0.0
        self._stagnation_counter = 0

    # ------------------------------------------------------------------
    # Main interaction
    # ------------------------------------------------------------------
    def respond(self,
                agent_comment: str = "",
                agent_question: str = "",
                test_results: Optional[str] = None,
                lint_results: Optional[str] = None,
                doc_results: Optional[str] = None,
                proposed_fix: Optional[str] = None,
                original_code: Optional[str] = None) -> str:
        """
        Process one negotiation turn and return the author's reply.

        Args:
            agent_comment: Free-text argument from the reviewer.
            agent_question: Free-text question from the reviewer.
            test_results: Raw test output, or ``None`` if none was gathered.
            lint_results: Raw linter output, or ``None``.
            doc_results: Raw documentation lookup result, or ``None``.
            proposed_fix: Source text of the proposed fix, or ``None``.
            original_code: Source text the fix is compared against, or ``None``.

        Returns:
            The acceptance sentence when confidence reaches the personality's
            threshold or the pushback budget is used up; otherwise a pushback
            message naming the weakest piece of evidence.
        """
        # Treat None like an empty string instead of failing on concatenation.
        agent_comment = agent_comment or ""
        agent_question = agent_question or ""

        # Store conversation
        self._conversation.append({
            "comment": agent_comment,
            "question": agent_question,
            "test": test_results,
            "lint": lint_results,
            "docs": doc_results
        })

        # Extract structured evidence
        evidence = self._extract_evidence(test_results, lint_results, doc_results)

        # Code inspection (only possible when both versions are supplied)
        code_change = 0.0
        if proposed_fix and original_code:
            code_change = self._inspect_code(proposed_fix, original_code)
        evidence["code_change"] = code_change

        # Explanation score
        text = (agent_comment + " " + agent_question).lower()
        explanation_score = self._score_explanation(text)

        # Weighted evidence score; the code term rewards small changes.
        evidence_score = (
            self.weight_test_pass * evidence.get("test_pass_ratio", 0.0) +
            self.weight_lint_clean * (1 - min(1.0, evidence.get("lint_errors", 0) / 10)) +
            self.weight_doc_found * (1.0 if evidence.get("doc_found") else 0.0) +
            self.weight_explanation_quality * explanation_score +
            self.weight_code_change * (1.0 - code_change)
        )
        evidence_score = max(0.0, min(1.0, evidence_score))

        # Detect improvement relative to the previous turn
        delta = evidence_score - self._last_evidence_score
        self._last_evidence_score = evidence_score

        if delta > self._STAGNATION_DELTA:
            self._stagnation_counter = 0
        else:
            self._stagnation_counter += 1

        # Update belief (exponential moving average of the evidence score)
        lr = self._BELIEF_LR
        self._confidence = (1 - lr) * self._confidence + lr * evidence_score

        # Penalise stagnation
        if self._stagnation_counter >= self._STAGNATION_LIMIT:
            self._confidence *= self._STAGNATION_DECAY

        # Decision: accept when convinced, or when the pushback budget is spent
        threshold = self.thresholds.get(self.personality, 0.5)

        if self._confidence >= threshold or self._pushback_count >= self.max_persuasion_rounds:
            return "Alright, I'm convinced. Let's proceed with your fix."

        # Otherwise push back
        self._pushback_count += 1
        return self._generate_pushback(evidence, text)

    # ------------------------------------------------------------------
    # Evidence extraction
    # ------------------------------------------------------------------
    def _extract_evidence(self, test_results, lint_results, doc_results) -> Dict[str, Any]:
        """
        Turn raw tool output into structured evidence.

        Returns:
            Dict with ``test_pass_ratio`` (float in ``[0, 1]``),
            ``lint_errors`` (int) and ``doc_found`` (bool).
        """
        evidence = {
            "test_pass_ratio": 0.0,
            "lint_errors": 0,
            "doc_found": False
        }

        # Parse test results: prefer a "passed / total" pair, else a boolean word
        if test_results:
            match = re.search(r'(\d+)\s*/\s*(\d+)', test_results)
            if match:
                p, t = int(match.group(1)), int(match.group(2))
                # Clamp so a malformed pair such as "12/10" cannot exceed 1.0
                evidence["test_pass_ratio"] = min(1.0, p / t) if t else 0.0
            elif "true" in test_results.lower():
                evidence["test_pass_ratio"] = 1.0
            elif "false" in test_results.lower():
                evidence["test_pass_ratio"] = 0.0

        # Lint errors: number of occurrences of the substring "error".
        # NOTE(review): a summary line such as "0 errors" is counted as well;
        # confirm against the linter's actual output format.
        if lint_results:
            evidence["lint_errors"] = lint_results.lower().count("error")

        # Docs: any non-empty result that does not say "no relevant"
        if doc_results and "no relevant" not in doc_results.lower():
            evidence["doc_found"] = True

        return evidence

    # ------------------------------------------------------------------
    # Explanation scoring
    # ------------------------------------------------------------------
    def _score_explanation(self, text: str) -> float:
        """
        Score lower-cased explanation text with keyword and length heuristics.

        The four bonuses add up to at most 0.8.
        """
        score = 0.0

        if "because" in text or "therefore" in text:
            score += 0.3
        if "test" in text or "example" in text:
            score += 0.2
        if len(text.split()) > 30:
            score += 0.2
        if "error" in text or "fix" in text:
            score += 0.1

        return min(1.0, score)

    # ------------------------------------------------------------------
    # Code inspection
    # ------------------------------------------------------------------
    def _inspect_code(self, new_code: str, old_code: str) -> float:
        """
        Estimate how large the proposed change is, as a value in ``[0, 1]``.

        The measure is the relative difference in AST node count between the
        two versions; 0.0 means "same size", 1.0 means "at least doubled,
        emptied out, or not parseable".

        Args:
            new_code: Source text of the proposed fix.
            old_code: Source text it is compared against.

        Returns:
            1.0 if ``new_code`` cannot be parsed (a broken fix must not be
            rewarded as a surgical one), 0.0 if only ``old_code`` cannot be
            parsed (no baseline to compare with), otherwise the clamped
            relative node-count difference.
        """
        parse_errors = (SyntaxError, ValueError, RecursionError, MemoryError)

        try:
            new_tree = ast.parse(new_code)
        except parse_errors:
            return 1.0

        try:
            old_tree = ast.parse(old_code)
        except parse_errors:
            return 0.0

        old_nodes = sum(1 for _ in ast.walk(old_tree))
        new_nodes = sum(1 for _ in ast.walk(new_tree))

        change = abs(new_nodes - old_nodes) / max(old_nodes, 1)
        return min(1.0, change)

    # ------------------------------------------------------------------
    # Pushback generator
    # ------------------------------------------------------------------
    def _generate_pushback(self, evidence, text) -> str:
        """Return the objection for the first evidence category that falls short."""
        if evidence["test_pass_ratio"] < 0.5:
            return "Tests are still failing. Show a passing case."

        if evidence["lint_errors"] > 0:
            return f"There are {evidence['lint_errors']} lint errors. Fix them."

        if not evidence["doc_found"]:
            return "Provide documentation or reference."

        if "because" not in text:
            return "Explain why this works."

        if len(text.split()) < 20:
            return "Too brief. Expand your reasoning."

        return "Not convinced yet. Give a concrete example."

    # ------------------------------------------------------------------
    # Score
    # ------------------------------------------------------------------
    def get_negotiation_score(self) -> float:
        """Return confidence minus 0.1 per pushback (capped at 3), clamped to [0, 1]."""
        penalty = 0.1 * min(3, self._pushback_count)
        return max(0.0, min(1.0, self._confidence - penalty))
|
bugs.json
CHANGED
|
@@ -1,127 +1,127 @@
|
|
| 1 |
-
{
|
| 2 |
-
"easy": {
|
| 3 |
-
"null_check": {
|
| 4 |
-
"type": "ast",
|
| 5 |
-
"bug_type": "null_check",
|
| 6 |
-
"oracle_hint": "Add back the if-guard that was removed"
|
| 7 |
-
},
|
| 8 |
-
"simple_typo": {
|
| 9 |
-
"type": "ast",
|
| 10 |
-
"bug_type": "simple_typo",
|
| 11 |
-
"oracle_hint": "Fix the misspelled variable name"
|
| 12 |
-
},
|
| 13 |
-
"string_index": {
|
| 14 |
-
"type": "ast",
|
| 15 |
-
"bug_type": "string_index",
|
| 16 |
-
"oracle_hint": "Correct the index offset"
|
| 17 |
-
},
|
| 18 |
-
"default_value": {
|
| 19 |
-
"type": "ast",
|
| 20 |
-
"bug_type": "default_value",
|
| 21 |
-
"oracle_hint": "Restore dict.get() with proper default"
|
| 22 |
-
},
|
| 23 |
-
"empty_return": {
|
| 24 |
-
"type": "ast",
|
| 25 |
-
"bug_type": "empty_return",
|
| 26 |
-
"oracle_hint": "Remove the premature return None"
|
| 27 |
-
}
|
| 28 |
-
},
|
| 29 |
-
"medium": {
|
| 30 |
-
"off_by_one": {
|
| 31 |
-
"type": "ast",
|
| 32 |
-
"bug_type": "off_by_one"
|
| 33 |
-
},
|
| 34 |
-
"loop_skip": {
|
| 35 |
-
"type": "ast",
|
| 36 |
-
"bug_type": "loop_skip"
|
| 37 |
-
},
|
| 38 |
-
"sign_error": {
|
| 39 |
-
"type": "ast",
|
| 40 |
-
"bug_type": "sign_error"
|
| 41 |
-
},
|
| 42 |
-
"swap_args": {
|
| 43 |
-
"type": "ast",
|
| 44 |
-
"bug_type": "swap_args"
|
| 45 |
-
},
|
| 46 |
-
"uninitialised_var": {
|
| 47 |
-
"type": "ast",
|
| 48 |
-
"bug_type": "uninitialised_var"
|
| 49 |
-
}
|
| 50 |
-
},
|
| 51 |
-
"hard": {
|
| 52 |
-
"division_by_zero_empty": {
|
| 53 |
-
"type": "ast",
|
| 54 |
-
"bug_type": "division_by_zero_empty"
|
| 55 |
-
},
|
| 56 |
-
"division_by_zero_zero": {
|
| 57 |
-
"type": "ast",
|
| 58 |
-
"bug_type": "division_by_zero_zero"
|
| 59 |
-
},
|
| 60 |
-
"float_precision": {
|
| 61 |
-
"type": "ast",
|
| 62 |
-
"bug_type": "float_precision"
|
| 63 |
-
},
|
| 64 |
-
"abs_usage": {
|
| 65 |
-
"type": "ast",
|
| 66 |
-
"bug_type": "abs_usage"
|
| 67 |
-
},
|
| 68 |
-
"round_error": {
|
| 69 |
-
"type": "ast",
|
| 70 |
-
"bug_type": "round_error"
|
| 71 |
-
}
|
| 72 |
-
},
|
| 73 |
-
"harder": {
|
| 74 |
-
"missing_lock": {
|
| 75 |
-
"type": "template",
|
| 76 |
-
"buggy": "counter = 0\ndef increment():\n global counter\n counter += 1",
|
| 77 |
-
"oracle": "counter = 0\nimport threading\nlock = threading.Lock()\ndef increment():\n global counter\n with lock:\n counter += 1"
|
| 78 |
-
},
|
| 79 |
-
"double_lock": {
|
| 80 |
-
"type": "template",
|
| 81 |
-
"buggy": "import threading\nlock = threading.Lock()\ndef do_work():\n lock.acquire()\n lock.acquire()\n print('working')\n lock.release()",
|
| 82 |
-
"oracle": "import threading\nlock = threading.Lock()\ndef do_work():\n with lock:\n print('working')"
|
| 83 |
-
},
|
| 84 |
-
"global_nonatomic": {
|
| 85 |
-
"type": "template",
|
| 86 |
-
"buggy": "count = 0\ndef add():\n global count\n count = count + 1",
|
| 87 |
-
"oracle": "count = 0\ndef add():\n global count\n count += 1"
|
| 88 |
-
},
|
| 89 |
-
"thread_safe_list": {
|
| 90 |
-
"type": "template",
|
| 91 |
-
"buggy": "import threading\nitems = []\ndef append_item(item):\n items.append(item)",
|
| 92 |
-
"oracle": "import threading\nitems = []\nlock = threading.Lock()\ndef append_item(item):\n with lock:\n items.append(item)"
|
| 93 |
-
},
|
| 94 |
-
"volatile_read": {
|
| 95 |
-
"type": "template",
|
| 96 |
-
"buggy": "import threading\nstop = False\ndef worker():\n while not stop:\n pass",
|
| 97 |
-
"oracle": "import threading\nstop = False\nlock = threading.Lock()\ndef worker():\n while True:\n with lock:\n if stop:\n break"
|
| 98 |
-
}
|
| 99 |
-
},
|
| 100 |
-
"hardest": {
|
| 101 |
-
"deadlock_order": {
|
| 102 |
-
"type": "template",
|
| 103 |
-
"buggy": "import threading\nlock1 = threading.Lock()\nlock2 = threading.Lock()\ndef thread1():\n with lock1:\n with lock2:\n pass\ndef thread2():\n with lock2:\n with lock1:\n pass",
|
| 104 |
-
"oracle": "import threading\nlock1 = threading.Lock()\nlock2 = threading.Lock()\ndef thread1():\n with lock1:\n with lock2:\n pass\ndef thread2():\n with lock1:\n with lock2:\n pass"
|
| 105 |
-
},
|
| 106 |
-
"nested_lock_timeout": {
|
| 107 |
-
"type": "template",
|
| 108 |
-
"buggy": "import threading\nlock = threading.Lock()\ndef work():\n lock.acquire()\n # critical section\n lock.release()",
|
| 109 |
-
"oracle": "import threading\nlock = threading.Lock()\ndef work():\n if lock.acquire(timeout=1):\n try:\n # critical section\n finally:\n lock.release()"
|
| 110 |
-
},
|
| 111 |
-
"fork_join": {
|
| 112 |
-
"type": "template",
|
| 113 |
-
"buggy": "import threading\ndef worker():\n pass\nt = threading.Thread(target=worker)\nt.start()",
|
| 114 |
-
"oracle": "import threading\ndef worker():\n pass\nt = threading.Thread(target=worker)\nt.start()\nt.join()"
|
| 115 |
-
},
|
| 116 |
-
"mutex_release": {
|
| 117 |
-
"type": "template",
|
| 118 |
-
"buggy": "import threading\nlock = threading.Lock()\ndef thread_A():\n lock.acquire()\n lock.release()\ndef thread_B():\n lock.release()",
|
| 119 |
-
"oracle": "import threading\nlock = threading.Lock()\ndef thread_A():\n with lock:\n pass\ndef thread_B():\n with lock:\n pass"
|
| 120 |
-
},
|
| 121 |
-
"race_on_init": {
|
| 122 |
-
"type": "template",
|
| 123 |
-
"buggy": "import threading\nitems = []\ndef init():\n global items\n items = [1,2,3]\nt = threading.Thread(target=init)\nt.start()\nprint(items)",
|
| 124 |
-
"oracle": "import threading\nitems = []\ndef init():\n global items\n items = [1,2,3]\nt = threading.Thread(target=init)\nt.start()\nt.join()\nprint(items)"
|
| 125 |
-
}
|
| 126 |
-
}
|
| 127 |
-
}
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"easy": {
|
| 3 |
+
"null_check": {
|
| 4 |
+
"type": "ast",
|
| 5 |
+
"bug_type": "null_check",
|
| 6 |
+
"oracle_hint": "Add back the if-guard that was removed"
|
| 7 |
+
},
|
| 8 |
+
"simple_typo": {
|
| 9 |
+
"type": "ast",
|
| 10 |
+
"bug_type": "simple_typo",
|
| 11 |
+
"oracle_hint": "Fix the misspelled variable name"
|
| 12 |
+
},
|
| 13 |
+
"string_index": {
|
| 14 |
+
"type": "ast",
|
| 15 |
+
"bug_type": "string_index",
|
| 16 |
+
"oracle_hint": "Correct the index offset"
|
| 17 |
+
},
|
| 18 |
+
"default_value": {
|
| 19 |
+
"type": "ast",
|
| 20 |
+
"bug_type": "default_value",
|
| 21 |
+
"oracle_hint": "Restore dict.get() with proper default"
|
| 22 |
+
},
|
| 23 |
+
"empty_return": {
|
| 24 |
+
"type": "ast",
|
| 25 |
+
"bug_type": "empty_return",
|
| 26 |
+
"oracle_hint": "Remove the premature return None"
|
| 27 |
+
}
|
| 28 |
+
},
|
| 29 |
+
"medium": {
|
| 30 |
+
"off_by_one": {
|
| 31 |
+
"type": "ast",
|
| 32 |
+
"bug_type": "off_by_one"
|
| 33 |
+
},
|
| 34 |
+
"loop_skip": {
|
| 35 |
+
"type": "ast",
|
| 36 |
+
"bug_type": "loop_skip"
|
| 37 |
+
},
|
| 38 |
+
"sign_error": {
|
| 39 |
+
"type": "ast",
|
| 40 |
+
"bug_type": "sign_error"
|
| 41 |
+
},
|
| 42 |
+
"swap_args": {
|
| 43 |
+
"type": "ast",
|
| 44 |
+
"bug_type": "swap_args"
|
| 45 |
+
},
|
| 46 |
+
"uninitialised_var": {
|
| 47 |
+
"type": "ast",
|
| 48 |
+
"bug_type": "uninitialised_var"
|
| 49 |
+
}
|
| 50 |
+
},
|
| 51 |
+
"hard": {
|
| 52 |
+
"division_by_zero_empty": {
|
| 53 |
+
"type": "ast",
|
| 54 |
+
"bug_type": "division_by_zero_empty"
|
| 55 |
+
},
|
| 56 |
+
"division_by_zero_zero": {
|
| 57 |
+
"type": "ast",
|
| 58 |
+
"bug_type": "division_by_zero_zero"
|
| 59 |
+
},
|
| 60 |
+
"float_precision": {
|
| 61 |
+
"type": "ast",
|
| 62 |
+
"bug_type": "float_precision"
|
| 63 |
+
},
|
| 64 |
+
"abs_usage": {
|
| 65 |
+
"type": "ast",
|
| 66 |
+
"bug_type": "abs_usage"
|
| 67 |
+
},
|
| 68 |
+
"round_error": {
|
| 69 |
+
"type": "ast",
|
| 70 |
+
"bug_type": "round_error"
|
| 71 |
+
}
|
| 72 |
+
},
|
| 73 |
+
"harder": {
|
| 74 |
+
"missing_lock": {
|
| 75 |
+
"type": "template",
|
| 76 |
+
"buggy": "counter = 0\ndef increment():\n global counter\n counter += 1",
|
| 77 |
+
"oracle": "counter = 0\nimport threading\nlock = threading.Lock()\ndef increment():\n global counter\n with lock:\n counter += 1"
|
| 78 |
+
},
|
| 79 |
+
"double_lock": {
|
| 80 |
+
"type": "template",
|
| 81 |
+
"buggy": "import threading\nlock = threading.Lock()\ndef do_work():\n lock.acquire()\n lock.acquire()\n print('working')\n lock.release()",
|
| 82 |
+
"oracle": "import threading\nlock = threading.Lock()\ndef do_work():\n with lock:\n print('working')"
|
| 83 |
+
},
|
| 84 |
+
"global_nonatomic": {
|
| 85 |
+
"type": "template",
|
| 86 |
+
"buggy": "count = 0\ndef add():\n global count\n count = count + 1",
|
| 87 |
+
"oracle": "count = 0\ndef add():\n global count\n count += 1"
|
| 88 |
+
},
|
| 89 |
+
"thread_safe_list": {
|
| 90 |
+
"type": "template",
|
| 91 |
+
"buggy": "import threading\nitems = []\ndef append_item(item):\n items.append(item)",
|
| 92 |
+
"oracle": "import threading\nitems = []\nlock = threading.Lock()\ndef append_item(item):\n with lock:\n items.append(item)"
|
| 93 |
+
},
|
| 94 |
+
"volatile_read": {
|
| 95 |
+
"type": "template",
|
| 96 |
+
"buggy": "import threading\nstop = False\ndef worker():\n while not stop:\n pass",
|
| 97 |
+
"oracle": "import threading\nstop = False\nlock = threading.Lock()\ndef worker():\n while True:\n with lock:\n if stop:\n break"
|
| 98 |
+
}
|
| 99 |
+
},
|
| 100 |
+
"hardest": {
|
| 101 |
+
"deadlock_order": {
|
| 102 |
+
"type": "template",
|
| 103 |
+
"buggy": "import threading\nlock1 = threading.Lock()\nlock2 = threading.Lock()\ndef thread1():\n with lock1:\n with lock2:\n pass\ndef thread2():\n with lock2:\n with lock1:\n pass",
|
| 104 |
+
"oracle": "import threading\nlock1 = threading.Lock()\nlock2 = threading.Lock()\ndef thread1():\n with lock1:\n with lock2:\n pass\ndef thread2():\n with lock1:\n with lock2:\n pass"
|
| 105 |
+
},
|
| 106 |
+
"nested_lock_timeout": {
|
| 107 |
+
"type": "template",
|
| 108 |
+
"buggy": "import threading\nlock = threading.Lock()\ndef work():\n lock.acquire()\n # critical section\n lock.release()",
|
| 109 |
+
"oracle": "import threading\nlock = threading.Lock()\ndef work():\n if lock.acquire(timeout=1):\n try:\n pass  # critical section\n finally:\n lock.release()"
|
| 110 |
+
},
|
| 111 |
+
"fork_join": {
|
| 112 |
+
"type": "template",
|
| 113 |
+
"buggy": "import threading\ndef worker():\n pass\nt = threading.Thread(target=worker)\nt.start()",
|
| 114 |
+
"oracle": "import threading\ndef worker():\n pass\nt = threading.Thread(target=worker)\nt.start()\nt.join()"
|
| 115 |
+
},
|
| 116 |
+
"mutex_release": {
|
| 117 |
+
"type": "template",
|
| 118 |
+
"buggy": "import threading\nlock = threading.Lock()\ndef thread_A():\n lock.acquire()\n lock.release()\ndef thread_B():\n lock.release()",
|
| 119 |
+
"oracle": "import threading\nlock = threading.Lock()\ndef thread_A():\n with lock:\n pass\ndef thread_B():\n with lock:\n pass"
|
| 120 |
+
},
|
| 121 |
+
"race_on_init": {
|
| 122 |
+
"type": "template",
|
| 123 |
+
"buggy": "import threading\nitems = []\ndef init():\n global items\n items = [1,2,3]\nt = threading.Thread(target=init)\nt.start()\nprint(items)",
|
| 124 |
+
"oracle": "import threading\nitems = []\ndef init():\n global items\n items = [1,2,3]\nt = threading.Thread(target=init)\nt.start()\nt.join()\nprint(items)"
|
| 125 |
+
}
|
| 126 |
+
}
|
| 127 |
+
}
|
client.py
CHANGED
|
@@ -1,5 +1,5 @@
|
|
| 1 |
-
# client.py – OpenEnv client entry point
|
| 2 |
-
from environment import CodeReviewEnv
|
| 3 |
-
|
| 4 |
-
# The OpenEnv framework will import this class as the environment.
|
| 5 |
__all__ = ["CodeReviewEnv"]
|
|
|
|
| 1 |
+
# client.py – OpenEnv client entry point
|
| 2 |
+
from environment import CodeReviewEnv
|
| 3 |
+
|
| 4 |
+
# The OpenEnv framework will import this class as the environment.
|
| 5 |
__all__ = ["CodeReviewEnv"]
|
environment.py
CHANGED
|
@@ -1,556 +1,613 @@
|
|
| 1 |
-
# environment.py – FULLY CORRECTED RL Environment (TRUE Markov + Fixed Bugs)
|
| 2 |
-
|
| 3 |
-
import sys
|
| 4 |
-
import subprocess
|
| 5 |
-
import tempfile
|
| 6 |
-
import os
|
| 7 |
-
import re
|
| 8 |
-
from dataclasses import dataclass, field
|
| 9 |
-
from typing import Tuple, Dict, Any, Optional, List
|
| 10 |
-
|
| 11 |
-
from models import (
|
| 12 |
-
AnyAction, WriteComment, ProposeFix, Execute, Inspect,
|
| 13 |
-
RunLinter, RunTests, QueryDocs, Skip, Done, AskQuestion,
|
| 14 |
-
Observation, Reward, State
|
| 15 |
-
)
|
| 16 |
-
from redteam import RedTeam
|
| 17 |
-
from test_runner import TestRunner
|
| 18 |
-
from author import PersonaAuthor
|
| 19 |
-
from rltool import ToolBox
|
| 20 |
-
from rubrics import (
|
| 21 |
-
ToolUsageRubric,
|
| 22 |
-
TestDeltaRubric,
|
| 23 |
-
LintDeltaRubric,
|
| 24 |
-
TerminalSuccessRubric,
|
| 25 |
-
ExplorationRubric,
|
| 26 |
-
AntiHackingRubric,
|
| 27 |
-
StepPenaltyRubric,
|
| 28 |
-
)
|
| 29 |
-
|
| 30 |
-
# ======================================================================
|
| 31 |
-
# FULLY MARKOV OBSERVATION (NOTHING HIDDEN)
|
| 32 |
-
# ======================================================================
|
| 33 |
-
@dataclass
|
| 34 |
-
class EnhancedObservation:
|
| 35 |
-
code_snippet: str
|
| 36 |
-
last_tool_output: str
|
| 37 |
-
|
| 38 |
-
current_test_score: float
|
| 39 |
-
current_lint_score: float
|
| 40 |
-
negotiation_score: float
|
| 41 |
-
|
| 42 |
-
previous_test_score: float
|
| 43 |
-
previous_lint_score: float
|
| 44 |
-
|
| 45 |
-
author_confidence: float
|
| 46 |
-
author_threshold: float
|
| 47 |
-
|
| 48 |
-
step: int
|
| 49 |
-
max_steps: int
|
| 50 |
-
progress_ratio: float
|
| 51 |
-
|
| 52 |
-
tests_run: bool
|
| 53 |
-
linter_run: bool
|
| 54 |
-
docs_queried: bool
|
| 55 |
-
|
| 56 |
-
last_action_type: str
|
| 57 |
-
action_history: List[str]
|
| 58 |
-
|
| 59 |
-
done: bool
|
| 60 |
-
|
| 61 |
-
bug_description: str
|
| 62 |
-
comments_count: int
|
| 63 |
-
|
| 64 |
-
# default fields must be at the very end
|
| 65 |
-
author_response: str = ""
|
| 66 |
-
|
| 67 |
-
|
| 68 |
-
#
|
| 69 |
-
#
|
| 70 |
-
|
| 71 |
-
|
| 72 |
-
|
| 73 |
-
|
| 74 |
-
|
| 75 |
-
|
| 76 |
-
f.
|
| 77 |
-
|
| 78 |
-
|
| 79 |
-
|
| 80 |
-
|
| 81 |
-
|
| 82 |
-
|
| 83 |
-
|
| 84 |
-
|
| 85 |
-
)
|
| 86 |
-
success
|
| 87 |
-
|
| 88 |
-
|
| 89 |
-
|
| 90 |
-
|
| 91 |
-
|
| 92 |
-
|
| 93 |
-
|
| 94 |
-
|
| 95 |
-
|
| 96 |
-
|
| 97 |
-
|
| 98 |
-
|
| 99 |
-
#
|
| 100 |
-
#
|
| 101 |
-
|
| 102 |
-
|
| 103 |
-
|
| 104 |
-
|
| 105 |
-
|
| 106 |
-
|
| 107 |
-
|
| 108 |
-
# Curriculum learning
|
| 109 |
-
auto_difficulty: bool = False
|
| 110 |
-
success_threshold: float = 0.7
|
| 111 |
-
|
| 112 |
-
# Reward shaping parameters
|
| 113 |
-
delta_weight: float = 0.3
|
| 114 |
-
tool_usage_bonus: float = 0.05
|
| 115 |
-
diversity_bonus: float = 0.03
|
| 116 |
-
|
| 117 |
-
_red_team: Optional[RedTeam] = field(init=False, default=None)
|
| 118 |
-
_author: Optional[PersonaAuthor] = field(init=False, default=None)
|
| 119 |
-
|
| 120 |
-
_current_code: str = field(init=False, default="")
|
| 121 |
-
_current_bug_id: str = field(init=False, default="")
|
| 122 |
-
_bug_description: str = field(init=False, default="")
|
| 123 |
-
_oracle_fix: str = field(init=False, default="")
|
| 124 |
-
|
| 125 |
-
_comments: list = field(init=False, default_factory=list)
|
| 126 |
-
_test_results: Optional[str] = field(init=False, default=None)
|
| 127 |
-
_lint_results: Optional[str] = field(init=False, default=None)
|
| 128 |
-
_doc_results: Optional[str] = field(init=False, default=None)
|
| 129 |
-
|
| 130 |
-
_step_count: int = field(init=False, default=0)
|
| 131 |
-
_done: bool = field(init=False, default=False)
|
| 132 |
-
|
| 133 |
-
# State tracking for dense rewards
|
| 134 |
-
_previous_test_score: float = field(init=False, default=0.0)
|
| 135 |
-
_previous_lint_score: float = field(init=False, default=0.0)
|
| 136 |
-
_current_test_score: float = field(init=False, default=0.0)
|
| 137 |
-
_current_lint_score: float = field(init=False, default=0.0)
|
| 138 |
-
|
| 139 |
-
# Tool usage tracking
|
| 140 |
-
_tests_run: bool = field(init=False, default=False)
|
| 141 |
-
_linter_run: bool = field(init=False, default=False)
|
| 142 |
-
_docs_queried: bool = field(init=False, default=False)
|
| 143 |
-
|
| 144 |
-
# Action history
|
| 145 |
-
_action_history: List[str] = field(init=False, default_factory=list)
|
| 146 |
-
_last_action_type: str = field(init=False, default="none")
|
| 147 |
-
|
| 148 |
-
|
| 149 |
-
|
| 150 |
-
|
| 151 |
-
|
| 152 |
-
|
| 153 |
-
|
| 154 |
-
|
| 155 |
-
|
| 156 |
-
|
| 157 |
-
#
|
| 158 |
-
|
| 159 |
-
|
| 160 |
-
|
| 161 |
-
|
| 162 |
-
|
| 163 |
-
|
| 164 |
-
|
| 165 |
-
|
| 166 |
-
|
| 167 |
-
|
| 168 |
-
|
| 169 |
-
|
| 170 |
-
|
| 171 |
-
|
| 172 |
-
|
| 173 |
-
|
| 174 |
-
|
| 175 |
-
|
| 176 |
-
|
| 177 |
-
|
| 178 |
-
|
| 179 |
-
|
| 180 |
-
|
| 181 |
-
self.
|
| 182 |
-
|
| 183 |
-
|
| 184 |
-
|
| 185 |
-
|
| 186 |
-
|
| 187 |
-
|
| 188 |
-
|
| 189 |
-
|
| 190 |
-
|
| 191 |
-
|
| 192 |
-
|
| 193 |
-
|
| 194 |
-
|
| 195 |
-
|
| 196 |
-
|
| 197 |
-
|
| 198 |
-
self.
|
| 199 |
-
|
| 200 |
-
|
| 201 |
-
|
| 202 |
-
self.
|
| 203 |
-
self.
|
| 204 |
-
|
| 205 |
-
|
| 206 |
-
|
| 207 |
-
|
| 208 |
-
|
| 209 |
-
|
| 210 |
-
|
| 211 |
-
|
| 212 |
-
|
| 213 |
-
|
| 214 |
-
|
| 215 |
-
|
| 216 |
-
|
| 217 |
-
|
| 218 |
-
|
| 219 |
-
|
| 220 |
-
|
| 221 |
-
|
| 222 |
-
|
| 223 |
-
self.
|
| 224 |
-
self.
|
| 225 |
-
self.
|
| 226 |
-
self.
|
| 227 |
-
|
| 228 |
-
|
| 229 |
-
|
| 230 |
-
|
| 231 |
-
|
| 232 |
-
|
| 233 |
-
|
| 234 |
-
|
| 235 |
-
|
| 236 |
-
|
| 237 |
-
|
| 238 |
-
|
| 239 |
-
|
| 240 |
-
|
| 241 |
-
|
| 242 |
-
|
| 243 |
-
|
| 244 |
-
|
| 245 |
-
|
| 246 |
-
self.
|
| 247 |
-
return
|
| 248 |
-
|
| 249 |
-
|
| 250 |
-
|
| 251 |
-
|
| 252 |
-
|
| 253 |
-
|
| 254 |
-
|
| 255 |
-
|
| 256 |
-
|
| 257 |
-
|
| 258 |
-
|
| 259 |
-
|
| 260 |
-
|
| 261 |
-
|
| 262 |
-
|
| 263 |
-
|
| 264 |
-
|
| 265 |
-
|
| 266 |
-
|
| 267 |
-
|
| 268 |
-
|
| 269 |
-
|
| 270 |
-
|
| 271 |
-
|
| 272 |
-
|
| 273 |
-
|
| 274 |
-
|
| 275 |
-
|
| 276 |
-
|
| 277 |
-
|
| 278 |
-
|
| 279 |
-
|
| 280 |
-
|
| 281 |
-
|
| 282 |
-
|
| 283 |
-
|
| 284 |
-
|
| 285 |
-
|
| 286 |
-
|
| 287 |
-
|
| 288 |
-
|
| 289 |
-
|
| 290 |
-
|
| 291 |
-
|
| 292 |
-
|
| 293 |
-
|
| 294 |
-
|
| 295 |
-
|
| 296 |
-
|
| 297 |
-
|
| 298 |
-
|
| 299 |
-
|
| 300 |
-
|
| 301 |
-
|
| 302 |
-
|
| 303 |
-
|
| 304 |
-
|
| 305 |
-
|
| 306 |
-
|
| 307 |
-
|
| 308 |
-
|
| 309 |
-
|
| 310 |
-
|
| 311 |
-
|
| 312 |
-
|
| 313 |
-
|
| 314 |
-
|
| 315 |
-
|
| 316 |
-
|
| 317 |
-
|
| 318 |
-
|
| 319 |
-
|
| 320 |
-
|
| 321 |
-
|
| 322 |
-
|
| 323 |
-
|
| 324 |
-
|
| 325 |
-
|
| 326 |
-
|
| 327 |
-
|
| 328 |
-
|
| 329 |
-
|
| 330 |
-
|
| 331 |
-
|
| 332 |
-
|
| 333 |
-
|
| 334 |
-
|
| 335 |
-
|
| 336 |
-
|
| 337 |
-
|
| 338 |
-
|
| 339 |
-
|
| 340 |
-
|
| 341 |
-
|
| 342 |
-
|
| 343 |
-
|
| 344 |
-
|
| 345 |
-
|
| 346 |
-
|
| 347 |
-
|
| 348 |
-
|
| 349 |
-
|
| 350 |
-
|
| 351 |
-
|
| 352 |
-
|
| 353 |
-
|
| 354 |
-
|
| 355 |
-
|
| 356 |
-
|
| 357 |
-
|
| 358 |
-
|
| 359 |
-
|
| 360 |
-
|
| 361 |
-
|
| 362 |
-
|
| 363 |
-
|
| 364 |
-
|
| 365 |
-
|
| 366 |
-
|
| 367 |
-
|
| 368 |
-
|
| 369 |
-
|
| 370 |
-
|
| 371 |
-
|
| 372 |
-
|
| 373 |
-
|
| 374 |
-
|
| 375 |
-
|
| 376 |
-
|
| 377 |
-
|
| 378 |
-
|
| 379 |
-
|
| 380 |
-
|
| 381 |
-
|
| 382 |
-
|
| 383 |
-
#
|
| 384 |
-
|
| 385 |
-
|
| 386 |
-
|
| 387 |
-
|
| 388 |
-
|
| 389 |
-
|
| 390 |
-
|
| 391 |
-
|
| 392 |
-
|
| 393 |
-
|
| 394 |
-
|
| 395 |
-
|
| 396 |
-
|
| 397 |
-
self._test_results = f"[
|
| 398 |
-
base_reward = 0.001
|
| 399 |
-
|
| 400 |
-
elif isinstance(action,
|
| 401 |
-
|
| 402 |
-
|
| 403 |
-
|
| 404 |
-
|
| 405 |
-
|
| 406 |
-
|
| 407 |
-
|
| 408 |
-
|
| 409 |
-
|
| 410 |
-
)
|
| 411 |
-
|
| 412 |
-
|
| 413 |
-
self.
|
| 414 |
-
|
| 415 |
-
|
| 416 |
-
|
| 417 |
-
|
| 418 |
-
|
| 419 |
-
|
| 420 |
-
|
| 421 |
-
|
| 422 |
-
|
| 423 |
-
|
| 424 |
-
|
| 425 |
-
|
| 426 |
-
|
| 427 |
-
|
| 428 |
-
|
| 429 |
-
|
| 430 |
-
|
| 431 |
-
|
| 432 |
-
|
| 433 |
-
|
| 434 |
-
|
| 435 |
-
|
| 436 |
-
|
| 437 |
-
|
| 438 |
-
|
| 439 |
-
|
| 440 |
-
|
| 441 |
-
|
| 442 |
-
|
| 443 |
-
|
| 444 |
-
|
| 445 |
-
|
| 446 |
-
|
| 447 |
-
|
| 448 |
-
|
| 449 |
-
|
| 450 |
-
|
| 451 |
-
|
| 452 |
-
|
| 453 |
-
|
| 454 |
-
|
| 455 |
-
|
| 456 |
-
self.
|
| 457 |
-
|
| 458 |
-
|
| 459 |
-
|
| 460 |
-
|
| 461 |
-
|
| 462 |
-
|
| 463 |
-
|
| 464 |
-
|
| 465 |
-
self.
|
| 466 |
-
|
| 467 |
-
|
| 468 |
-
|
| 469 |
-
|
| 470 |
-
|
| 471 |
-
|
| 472 |
-
|
| 473 |
-
|
| 474 |
-
|
| 475 |
-
|
| 476 |
-
|
| 477 |
-
|
| 478 |
-
|
| 479 |
-
|
| 480 |
-
|
| 481 |
-
|
| 482 |
-
|
| 483 |
-
|
| 484 |
-
|
| 485 |
-
|
| 486 |
-
|
| 487 |
-
|
| 488 |
-
|
| 489 |
-
|
| 490 |
-
|
| 491 |
-
|
| 492 |
-
|
| 493 |
-
|
| 494 |
-
|
| 495 |
-
|
| 496 |
-
|
| 497 |
-
|
| 498 |
-
|
| 499 |
-
|
| 500 |
-
|
| 501 |
-
|
| 502 |
-
|
| 503 |
-
|
| 504 |
-
|
| 505 |
-
|
| 506 |
-
|
| 507 |
-
|
| 508 |
-
|
| 509 |
-
|
| 510 |
-
|
| 511 |
-
|
| 512 |
-
|
| 513 |
-
|
| 514 |
-
|
| 515 |
-
|
| 516 |
-
|
| 517 |
-
|
| 518 |
-
|
| 519 |
-
|
| 520 |
-
|
| 521 |
-
|
| 522 |
-
|
| 523 |
-
|
| 524 |
-
|
| 525 |
-
|
| 526 |
-
|
| 527 |
-
|
| 528 |
-
|
| 529 |
-
|
| 530 |
-
|
| 531 |
-
|
| 532 |
-
|
| 533 |
-
|
| 534 |
-
|
| 535 |
-
|
| 536 |
-
|
| 537 |
-
|
| 538 |
-
|
| 539 |
-
|
| 540 |
-
|
| 541 |
-
|
| 542 |
-
|
| 543 |
-
|
| 544 |
-
|
| 545 |
-
|
| 546 |
-
|
| 547 |
-
|
| 548 |
-
|
| 549 |
-
|
| 550 |
-
|
| 551 |
-
|
| 552 |
-
|
| 553 |
-
|
| 554 |
-
|
| 555 |
-
|
| 556 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# environment.py – FULLY CORRECTED RL Environment (TRUE Markov + Fixed Bugs)
|
| 2 |
+
|
| 3 |
+
import sys
|
| 4 |
+
import subprocess
|
| 5 |
+
import tempfile
|
| 6 |
+
import os
|
| 7 |
+
import re
|
| 8 |
+
from dataclasses import dataclass, field
|
| 9 |
+
from typing import Tuple, Dict, Any, Optional, List
|
| 10 |
+
|
| 11 |
+
from models import (
|
| 12 |
+
AnyAction, WriteComment, ProposeFix, Execute, Inspect,
|
| 13 |
+
RunLinter, RunTests, QueryDocs, Skip, Done, AskQuestion,
|
| 14 |
+
Observation, Reward, State
|
| 15 |
+
)
|
| 16 |
+
from redteam import RedTeam
|
| 17 |
+
from test_runner import TestRunner
|
| 18 |
+
from author import PersonaAuthor
|
| 19 |
+
from rltool import ToolBox
|
| 20 |
+
from rubrics import (
|
| 21 |
+
ToolUsageRubric,
|
| 22 |
+
TestDeltaRubric,
|
| 23 |
+
LintDeltaRubric,
|
| 24 |
+
TerminalSuccessRubric,
|
| 25 |
+
ExplorationRubric,
|
| 26 |
+
AntiHackingRubric,
|
| 27 |
+
StepPenaltyRubric,
|
| 28 |
+
)
|
| 29 |
+
|
| 30 |
+
# ======================================================================
|
| 31 |
+
# FULLY MARKOV OBSERVATION (NOTHING HIDDEN)
|
| 32 |
+
# ======================================================================
|
| 33 |
+
@dataclass
class EnhancedObservation:
    """Observation record handed back to the agent by the environment.

    All fields except ``author_response`` are required; field order is part
    of the constructor signature and must not change.
    """

    # Code under review and the most recent tool output.
    code_snippet: str
    last_tool_output: str

    # Current quality signals.
    current_test_score: float
    current_lint_score: float
    negotiation_score: float

    # The same test/lint signals as recorded before the latest update.
    previous_test_score: float
    previous_lint_score: float

    # Author-side confidence and the threshold it is compared against.
    author_confidence: float
    author_threshold: float

    # Episode progress.
    step: int
    max_steps: int
    progress_ratio: float

    # Which tools have been used so far.
    tests_run: bool
    linter_run: bool
    docs_queried: bool

    # Recent agent behaviour.
    last_action_type: str
    action_history: List[str]

    # Whether the episode has finished.
    done: bool

    # Task context.
    bug_description: str
    comments_count: int

    # Fields with defaults have to follow every required field.
    author_response: str = ""
|
| 66 |
+
|
| 67 |
+
# ======================================================================
|
| 68 |
+
# HELPER FUNCTIONS
|
| 69 |
+
# ======================================================================
|
| 70 |
+
def execute_code(code: str, timeout_sec: int = 5) -> Tuple[bool, str, str]:
    """Run a Python snippet in a child interpreter and capture its output.

    The snippet is written to a temporary ``.py`` file and executed with
    ``sys.executable``.  The temporary file is always removed afterwards.

    Args:
        code: Python source to run.
        timeout_sec: Maximum run time in seconds before the child is
            abandoned.

    Returns:
        ``(success, stdout, stderr)``.  ``success`` is True only when the
        child exits with return code 0.  For empty/whitespace-only code, a
        timeout, or any failure to write or launch the snippet, ``success``
        is False, ``stdout`` is empty and ``stderr`` carries a short
        description of the problem.
    """
    if not code.strip():
        return False, "", "Error: Empty code"

    # delete=False: the file is closed before the child reads it and is
    # removed explicitly in the ``finally`` clause below.
    tmp_file = tempfile.NamedTemporaryFile(
        mode='w', suffix='.py', delete=False, encoding='utf-8'
    )
    tmp_path = tmp_file.name

    try:
        # Writing happens inside the try so that a failed write (for example
        # text that cannot be encoded as UTF-8) is reported like any other
        # execution error and the temporary file is still cleaned up.
        with tmp_file:
            tmp_file.write(code)

        result = subprocess.run(
            [sys.executable, tmp_path],
            capture_output=True,
            text=True,
            timeout=timeout_sec,
        )
        return result.returncode == 0, result.stdout, result.stderr
    except subprocess.TimeoutExpired:
        return False, "", f"Timeout after {timeout_sec}s"
    except Exception as e:
        # Boundary handler: callers expect a result tuple, not an exception.
        return False, "", f"Execution error: {str(e)}"
    finally:
        try:
            os.unlink(tmp_path)
        except OSError:
            # Best-effort cleanup; the file may already be gone.
            pass
|
| 96 |
+
|
| 97 |
+
|
| 98 |
+
# ======================================================================
|
| 99 |
+
# ENHANCED CODE REVIEW ENVIRONMENT
|
| 100 |
+
# ======================================================================
|
| 101 |
+
@dataclass
|
| 102 |
+
class CodeReviewEnv:
|
| 103 |
+
task: str = "easy"
|
| 104 |
+
max_steps: int = 10
|
| 105 |
+
step_penalty: float = 0.01
|
| 106 |
+
reward_profile: str = "full" # "full" or "core"
|
| 107 |
+
|
| 108 |
+
# Curriculum learning
|
| 109 |
+
auto_difficulty: bool = False
|
| 110 |
+
success_threshold: float = 0.7
|
| 111 |
+
|
| 112 |
+
# Reward shaping parameters
|
| 113 |
+
delta_weight: float = 0.3
|
| 114 |
+
tool_usage_bonus: float = 0.05
|
| 115 |
+
diversity_bonus: float = 0.03
|
| 116 |
+
|
| 117 |
+
_red_team: Optional[RedTeam] = field(init=False, default=None)
|
| 118 |
+
_author: Optional[PersonaAuthor] = field(init=False, default=None)
|
| 119 |
+
|
| 120 |
+
_current_code: str = field(init=False, default="")
|
| 121 |
+
_current_bug_id: str = field(init=False, default="")
|
| 122 |
+
_bug_description: str = field(init=False, default="")
|
| 123 |
+
_oracle_fix: str = field(init=False, default="")
|
| 124 |
+
|
| 125 |
+
_comments: list = field(init=False, default_factory=list)
|
| 126 |
+
_test_results: Optional[str] = field(init=False, default=None)
|
| 127 |
+
_lint_results: Optional[str] = field(init=False, default=None)
|
| 128 |
+
_doc_results: Optional[str] = field(init=False, default=None)
|
| 129 |
+
|
| 130 |
+
_step_count: int = field(init=False, default=0)
|
| 131 |
+
_done: bool = field(init=False, default=False)
|
| 132 |
+
|
| 133 |
+
# State tracking for dense rewards
|
| 134 |
+
_previous_test_score: float = field(init=False, default=0.0)
|
| 135 |
+
_previous_lint_score: float = field(init=False, default=0.0)
|
| 136 |
+
_current_test_score: float = field(init=False, default=0.0)
|
| 137 |
+
_current_lint_score: float = field(init=False, default=0.0)
|
| 138 |
+
|
| 139 |
+
# Tool usage tracking
|
| 140 |
+
_tests_run: bool = field(init=False, default=False)
|
| 141 |
+
_linter_run: bool = field(init=False, default=False)
|
| 142 |
+
_docs_queried: bool = field(init=False, default=False)
|
| 143 |
+
|
| 144 |
+
# Action history
|
| 145 |
+
_action_history: List[str] = field(init=False, default_factory=list)
|
| 146 |
+
_last_action_type: str = field(init=False, default="none")
|
| 147 |
+
_last_author_response: str = field(init=False, default="")
|
| 148 |
+
|
| 149 |
+
# FIXED: Track CUMULATIVE episode reward
|
| 150 |
+
_episode_total_reward: float = field(init=False, default=0.0)
|
| 151 |
+
_episode_rewards: List[float] = field(init=False, default_factory=list)
|
| 152 |
+
_difficulty_level: int = field(init=False, default=0)
|
| 153 |
+
|
| 154 |
+
# Bug-id bridge:
|
| 155 |
+
# RedTeam has fine-grained IDs, while TestRunner currently expects a
|
| 156 |
+
# smaller canonical set. Keep this mapping here so both modules can evolve
|
| 157 |
+
# independently without breaking evaluation.
|
| 158 |
+
_BUG_ID_CANONICAL_MAP = {
|
| 159 |
+
"division_by_zero_empty": "division_by_zero",
|
| 160 |
+
"division_by_zero_zero": "division_by_zero",
|
| 161 |
+
"sign_error": "wrong_operator",
|
| 162 |
+
}
|
| 163 |
+
|
| 164 |
+
# ===================================================================
|
| 165 |
+
def __post_init__(self):
    """Finish construction by configuring the task given at init time."""
    initial_task = self.task
    self.set_task(initial_task)
|
| 167 |
+
|
| 168 |
+
# ===================================================================
|
| 169 |
+
def _build_rubrics(self):
    """Assemble the reward rubric list for the configured ``reward_profile``.

    - ``"core"``: test delta, lint delta, terminal success, step penalty.
    - ``"full"``: the core rubrics plus tool-usage, exploration and
      anti-hacking rubrics, with the step penalty kept as the last entry.

    Raises:
        ValueError: If ``reward_profile`` is neither ``"core"`` nor ``"full"``.
    """
    # Rubrics shared by every profile, in evaluation order.
    shared = [
        TestDeltaRubric(weight=self.delta_weight),
        LintDeltaRubric(weight=self.delta_weight),
        TerminalSuccessRubric(),
    ]
    # The step penalty always goes last so both profiles end the same way.
    step_cost = StepPenaltyRubric(penalty=self.step_penalty)

    profile = self.reward_profile
    if profile == "core":
        return shared + [step_cost]
    if profile != "full":
        raise ValueError(f"Unknown reward_profile: {self.reward_profile}")

    shaping = [
        ToolUsageRubric(bonus=self.tool_usage_bonus),
        ExplorationRubric(penalty=-0.05, bonus=self.diversity_bonus * 0.7),
        AntiHackingRubric(),
    ]
    return shared + shaping + [step_cost]
|
| 192 |
+
|
| 193 |
+
# ===================================================================
|
| 194 |
+
def set_task(self, task: str):
    """Switch the environment to ``task`` and reset it.

    Creates a new ``RedTeam`` and ``PersonaAuthor``, rebuilds the rubric
    list, records the difficulty level that corresponds to the task (0 for
    ``"easy"`` up to 4 for ``"hardest"``) and resets the episode state.

    Raises:
        ValueError: If ``task`` is not one of the five known task names.
    """
    # Position in this tuple doubles as the numeric difficulty level.
    known_tasks = ("easy", "medium", "hard", "harder", "hardest")
    if task not in known_tasks:
        raise ValueError(f"Unknown task: {task}")

    self.task = task
    # seed=None: per the original author's note, a fixed seed would keep
    # selecting the same bug and weaken training diversity.
    self._red_team = RedTeam(task, seed=None)
    self._author = PersonaAuthor()
    self.rubrics = self._build_rubrics()
    self._difficulty_level = known_tasks.index(task)

    self._reset_internal()
|
| 212 |
+
|
| 213 |
+
# ===================================================================
|
| 214 |
+
def _reset_internal(self):
    """Clear per-episode state and inject a fresh bug for the current task.

    Resets counters, cached tool output, score tracking, tool-usage flags,
    action history and the cumulative episode reward, resets the author,
    then asks the red team to inject a bug into the task's base snippet and
    stores the resulting code, bug id, description and oracle fix.
    """
    # Episode bookkeeping.
    self._step_count = 0
    self._done = False
    self._comments = []
    self._episode_total_reward = 0.0

    # Cached tool output.
    self._test_results = self._lint_results = self._doc_results = None

    # Score tracking used for delta-based rewards.
    self._previous_test_score = self._previous_lint_score = 0.0
    self._current_test_score = self._current_lint_score = 0.0

    # Tool-usage flags.
    self._tests_run = self._linter_run = self._docs_queried = False

    # Action / dialogue history.
    self._action_history = []
    self._last_action_type = "none"
    self._last_author_response = ""

    self._author.reset()

    # Base snippet per task; any task not listed falls back to the last one.
    # NOTE(review): the indentation inside these snippets appears as a single
    # space in this view -- confirm against the repository copy.
    base_snippets = (
        ("easy", "def get_user(id):\n if id in users:\n return users[id]"),
        ("medium", "def process_items(items):\n for item in items:\n print(item)"),
        ("hard", "def average(data):\n if not data:\n return 0\n return sum(data) / len(data)"),
        ("harder", "counter = 0\ndef increment():\n global counter\n with lock:\n counter += 1"),
    )
    fallback_snippet = "def safe_work():\n with lock1:\n with lock2:\n do_work()"
    original = next(
        (snippet for name, snippet in base_snippets if self.task == name),
        fallback_snippet,
    )

    injected = self._red_team.inject_bug(original)
    buggy_code, bug_id, desc, oracle = injected
    self._current_code = buggy_code
    self._current_bug_id = bug_id
    self._bug_description = desc
    self._oracle_fix = oracle
    self._comments.append(f"[RedTeam] {desc}")
|
| 259 |
+
|
| 260 |
+
# ===================================================================
|
| 261 |
+
def reset(self) -> EnhancedObservation:
    """Start a new episode, adjusting the curriculum level first if enabled.

    When ``auto_difficulty`` is on and at least one episode has finished,
    the mean of the most recent (up to five) episode totals decides whether
    the difficulty level moves up (above ``success_threshold``, capped at 4)
    or down (below 0.3, floored at 0). The task and a new ``RedTeam`` are
    then derived from the level before the per-episode state is rebuilt.

    Returns:
        The initial observation of the new episode.
    """
    if self.auto_difficulty and len(self._episode_rewards) > 0:
        recent_window = self._episode_rewards[-5:]
        recent_performance = sum(recent_window) / len(recent_window)

        if recent_performance > self.success_threshold and self._difficulty_level < 4:
            self._difficulty_level += 1
            print(f"[Curriculum] Increasing difficulty to level {self._difficulty_level}")
        elif recent_performance < 0.3 and self._difficulty_level > 0:
            self._difficulty_level -= 1
            print(f"[Curriculum] Decreasing difficulty to level {self._difficulty_level}")

        level_to_task = {
            0: "easy",
            1: "medium",
            2: "hard",
            3: "harder",
            4: "hardest",
        }
        self.task = level_to_task[self._difficulty_level]
        # Unseeded on purpose: bug injection stays stochastic within a level.
        self._red_team = RedTeam(self.task, seed=None)

    self._reset_internal()
    return self._get_observation()
|
| 280 |
+
|
| 281 |
+
# ===================================================================
|
| 282 |
+
def _get_observation(self) -> EnhancedObservation:
    """Build the observation for the current episode state.

    The author's reply is exposed only when the most recent action was a
    dialogue action (comment, question or proposed fix); after any other
    action the field is an empty string.
    """
    dialogue_actions = ("write_comment", "ask_question", "propose_fix")
    author_response = (
        self._last_author_response
        if self._last_action_type in dialogue_actions
        else ""
    )

    # A non-positive `max_steps` is reported as a finished episode (1.0)
    # instead of dividing by it.
    if self.max_steps > 0:
        progress_ratio = self._step_count / self.max_steps
    else:
        progress_ratio = 1.0

    author = self._author

    return EnhancedObservation(
        code_snippet=self._current_code,
        last_tool_output=self._test_results or "",
        author_response=author_response,

        # Current quality metrics.
        current_test_score=self._current_test_score,
        current_lint_score=self._current_lint_score,
        negotiation_score=author.get_negotiation_score(),

        # Metrics as they were before the latest action.
        previous_test_score=self._previous_test_score,
        previous_lint_score=self._previous_lint_score,

        # Author model state; 0.5 is used when the personality has no
        # configured threshold.
        author_confidence=author._confidence,
        author_threshold=author.thresholds.get(author.personality, 0.5),

        # Episode progress.
        step=self._step_count,
        max_steps=self.max_steps,
        progress_ratio=progress_ratio,

        # Which tools have been used so far.
        tests_run=self._tests_run,
        linter_run=self._linter_run,
        docs_queried=self._docs_queried,

        # Recent behaviour (only the last five actions are included).
        last_action_type=self._last_action_type,
        action_history=self._action_history[-5:],

        done=self._done,

        bug_description=self._bug_description,
        comments_count=len(self._comments),
    )
|
| 324 |
+
|
| 325 |
+
# ===================================================================
|
| 326 |
+
def _get_action_type(self, action: AnyAction) -> str:
    """Return the string tag for an action instance, or "unknown"."""
    # Checked top to bottom; the first class the action is an instance of wins.
    tagged_classes = (
        (RunTests, "run_tests"),
        (RunLinter, "run_linter"),
        (QueryDocs, "query_docs"),
        (Execute, "execute"),
        (Inspect, "inspect"),
        (WriteComment, "write_comment"),
        (AskQuestion, "ask_question"),
        (ProposeFix, "propose_fix"),
        (Done, "done"),
        (Skip, "skip"),
    )
    for action_cls, tag in tagged_classes:
        if isinstance(action, action_cls):
            return tag
    return "unknown"
|
| 350 |
+
|
| 351 |
+
# ===================================================================
|
| 352 |
+
def _get_test_runner_bug_id(self) -> str:
|
| 353 |
+
"""
|
| 354 |
+
Normalize RedTeam bug ids to the canonical ids understood by TestRunner.
|
| 355 |
+
Falls back to the original id for known direct matches.
|
| 356 |
+
"""
|
| 357 |
+
return self._BUG_ID_CANONICAL_MAP.get(self._current_bug_id, self._current_bug_id)
|
| 358 |
+
|
| 359 |
+
# ===================================================================
|
| 360 |
+
def step(self, action: AnyAction) -> Tuple[EnhancedObservation, Reward, bool, Dict[str, Any]]:
    """
    Apply one agent action and advance the episode by a single step.

    The action is dispatched by type (tool use, dialogue with the author,
    fix proposal, or termination), which yields a small base reward and may
    update scores, cached tool output and the done flag. The step counter is
    then advanced, a fresh observation is built, and the final reward is
    computed as ``0.4 * base_reward`` plus the sum of all rubric scores,
    clipped to [-1, 1].

    Returns:
        ``(observation, reward, done, info)`` where ``info`` carries the
        action type, current scores, score deltas, the tool-usage flags as
        they were *before* this action, the base reward, the final reward
        and the running episode total.

    Raises:
        RuntimeError: if called after the episode has already finished.
    """
    if self._done:
        raise RuntimeError("Episode already finished")

    # Remember the scores from before this action so deltas can be reported.
    self._previous_test_score = self._current_test_score
    self._previous_lint_score = self._current_lint_score
    # Capture the tool-usage flags before the action can flip them; they are
    # passed to the rubrics through `info`.
    prev_tests_run = self._tests_run
    prev_linter_run = self._linter_run
    prev_docs_queried = self._docs_queried

    base_reward = 0.0
    action_type = self._get_action_type(action)

    self._action_history.append(action_type)
    self._last_action_type = action_type

    # ---------------------------- tool actions ----------------------------
    if isinstance(action, Execute):
        success, stdout, stderr = execute_code(self._current_code)
        output = (stdout + stderr).strip() or "No output"
        status = 'Success' if success else 'Failed'
        self._test_results = f"[Execute] {status}\n{output[:300]}"
        base_reward = 0.001 if success else -0.05

    elif isinstance(action, Inspect):
        self._test_results = f"[Inspect]\n{self._current_code[:500]}"
        base_reward = 0.001

    elif isinstance(action, RunLinter):
        raw_lint = ToolBox.run_linter(self._current_code)
        self._lint_results = raw_lint[:500]
        self._test_results = f"[Linter]\n{self._lint_results}"

        self._current_lint_score = self._run_linter_score(self._current_code)
        self._linter_run = True
        base_reward = 0.002

    elif isinstance(action, RunTests):
        runner = TestRunner(self._get_test_runner_bug_id())
        score, output = runner.run_tests(self._current_code)

        self._current_test_score = score
        self._tests_run = True

        self._test_results = f"[Tests] Score: {score:.2f}\n{output[:300]}"
        # Small bonus on top of the flat tool reward for a high test score.
        base_reward = 0.002 + (0.005 if score > 0.8 else 0.0)

    elif isinstance(action, QueryDocs):
        # Blank / whitespace-only topics fall back to a generic query.
        topic = (action.query_topic or "").strip()
        doc = ToolBox.query_docs(topic or "general bug fixing")
        self._doc_results = doc
        self._test_results = f"[Docs]\n{doc[:400]}"
        self._docs_queried = True
        base_reward = 0.001

    # ------------------------ communication actions -----------------------
    elif isinstance(action, WriteComment):
        self._comments.append(f"Agent: {action.comment_text}")

        response = self._author.respond(
            agent_comment=action.comment_text,
            test_results=self._test_results,
            lint_results=self._lint_results,
            doc_results=self._doc_results,
            proposed_fix=None,
            original_code=self._current_code,
        )

        self._comments.append(f"Author: {response}")
        self._last_author_response = response
        self._test_results = f"[Comment] Author: {response[:200]}"
        base_reward = 0.001

    elif isinstance(action, AskQuestion):
        self._comments.append(f"Agent: {action.question}")

        response = self._author.respond(
            agent_question=action.question,
            test_results=self._test_results,
            lint_results=self._lint_results,
            doc_results=self._doc_results,
            proposed_fix=None,
            original_code=self._current_code,
        )

        self._comments.append(f"Author: {response}")
        self._last_author_response = response
        self._test_results = f"[Question] Author: {response[:200]}"
        base_reward = 0.002

    # ---------------------------- fix proposal ----------------------------
    elif isinstance(action, ProposeFix):
        if not action.fix_code:
            base_reward = -0.05
            self._done = True
        else:
            # Keep the buggy version around: the author is shown it as the
            # "original" alongside the proposed replacement.
            original_buggy = self._current_code
            self._current_code = action.fix_code

            runner = TestRunner(self._get_test_runner_bug_id())
            test_score, _ = runner.run_tests(self._current_code)
            lint_score = self._run_linter_score(self._current_code)
            # Return value is not used in this branch; the call is retained
            # in case it has side effects on the author model.
            self._author.get_negotiation_score()

            self._current_test_score = test_score
            self._current_lint_score = lint_score

            # Author gating decides only whether the episode ends here: an
            # unconvinced author keeps it open while steps remain.
            threshold = self._author.thresholds.get(self._author.personality, 0.5)
            self._done = not (
                self._author._confidence < threshold
                and self._step_count < self.max_steps
            )

            author_feedback = self._author.respond(
                agent_comment=f"Proposed fix:\n{action.fix_code}",
                test_results=f"Score: {test_score:.2f}",
                lint_results=f"Score: {lint_score:.2f}",
                doc_results=self._doc_results,
                proposed_fix=action.fix_code,
                original_code=original_buggy,
            )
            self._test_results = f"[Fix] Author: {author_feedback[:200]}"
            self._comments.append(f"Author: {author_feedback}")
            self._last_author_response = author_feedback

            # Token base reward; the rubric sum below carries the real signal.
            base_reward = 0.001

    # -------------------------- termination actions -----------------------
    elif isinstance(action, Skip):
        base_reward = -0.03
        self._done = True

    elif isinstance(action, Done):
        base_reward = (
            self._current_test_score * 0.5 - 0.2
            if self._tests_run
            else -0.04
        )
        self._done = True

    else:
        # Unrecognised action type: penalise and end the episode.
        base_reward = -0.02
        self._done = True

    # Advance the step counter before building the observation and info so
    # both reflect the step that just completed.
    self._step_count += 1
    if self._step_count >= self.max_steps:
        self._done = True

    obs = self._get_observation()

    info = {
        "action_type": action_type,
        "test_score": self._current_test_score,
        "lint_score": self._current_lint_score,
        "test_delta": self._current_test_score - self._previous_test_score,
        "lint_delta": self._current_lint_score - self._previous_lint_score,
        "prev_tests_run": prev_tests_run,
        "prev_linter_run": prev_linter_run,
        "prev_docs_queried": prev_docs_queried,
        "docs_query_len": len((action.query_topic or "").strip()) if isinstance(action, QueryDocs) else 0,
        "base_reward": base_reward,
    }

    # Final reward: damped base reward plus every rubric's contribution,
    # clipped to [-1, 1].
    rubric_score = sum(
        rubric(self, action, obs, None, self._done, info)
        for rubric in self.rubrics
    )
    final_reward = 0.4 * base_reward + rubric_score
    final_reward = max(-1.0, min(1.0, final_reward))

    self._episode_total_reward += final_reward

    # A finished episode contributes its total to the history read by `reset`.
    if self._done:
        self._episode_rewards.append(self._episode_total_reward)

    info["final_reward"] = final_reward
    info["episode_total"] = self._episode_total_reward

    return obs, Reward(value=final_reward), self._done, info
|
| 574 |
+
|
| 575 |
+
# ===================================================================
|
| 576 |
+
def _run_linter_score(self, code: str) -> float:
|
| 577 |
+
"""Run pylint and return normalized score [0, 1]."""
|
| 578 |
+
try:
|
| 579 |
+
with tempfile.NamedTemporaryFile(mode='w', suffix='.py', delete=False) as f:
|
| 580 |
+
f.write(code)
|
| 581 |
+
tmp_path = f.name
|
| 582 |
+
|
| 583 |
+
result = subprocess.run(
|
| 584 |
+
['pylint', tmp_path, '--score=y', '--exit-zero'],
|
| 585 |
+
capture_output=True,
|
| 586 |
+
text=True,
|
| 587 |
+
timeout=5
|
| 588 |
+
)
|
| 589 |
+
|
| 590 |
+
match = re.search(r"rated at (\d+\.\d+)/10", result.stdout)
|
| 591 |
+
if match:
|
| 592 |
+
return float(match.group(1)) / 10.0
|
| 593 |
+
return 0.0
|
| 594 |
+
except:
|
| 595 |
+
return 0.0
|
| 596 |
+
finally:
|
| 597 |
+
try:
|
| 598 |
+
os.unlink(tmp_path)
|
| 599 |
+
except:
|
| 600 |
+
pass
|
| 601 |
+
|
| 602 |
+
# ===================================================================
|
| 603 |
+
def state(self) -> State:
    """Return a legacy-format snapshot of the episode as a ``State``."""
    # Copy the comment log so callers cannot mutate the live list.
    comments_snapshot = self._comments.copy()
    snapshot = {
        "pr_title": "Code Review",
        "pr_description": self._bug_description,
        "code_snippet": self._current_code,
        "comments": comments_snapshot,
        "test_results": self._test_results,
        "step": self._step_count,
        "done": self._done,
    }
    return State(**snapshot)
|
grader.py
CHANGED
|
@@ -1,148 +1,142 @@
|
|
| 1 |
-
# grader.py – Production‑grade, continuous reward, exploit‑aware
|
| 2 |
-
import ast
|
| 3 |
-
import subprocess
|
| 4 |
-
import tempfile
|
| 5 |
-
import os
|
| 6 |
-
import re
|
| 7 |
-
import
|
| 8 |
-
import
|
| 9 |
-
|
| 10 |
-
|
| 11 |
-
|
| 12 |
-
|
| 13 |
-
|
| 14 |
-
|
| 15 |
-
|
| 16 |
-
|
| 17 |
-
|
| 18 |
-
|
| 19 |
-
|
| 20 |
-
|
| 21 |
-
|
| 22 |
-
|
| 23 |
-
|
| 24 |
-
|
| 25 |
-
|
| 26 |
-
|
| 27 |
-
|
| 28 |
-
|
| 29 |
-
|
| 30 |
-
|
| 31 |
-
|
| 32 |
-
|
| 33 |
-
|
| 34 |
-
|
| 35 |
-
|
| 36 |
-
|
| 37 |
-
|
| 38 |
-
|
| 39 |
-
|
| 40 |
-
|
| 41 |
-
|
| 42 |
-
|
| 43 |
-
|
| 44 |
-
|
| 45 |
-
|
| 46 |
-
|
| 47 |
-
|
| 48 |
-
|
| 49 |
-
|
| 50 |
-
|
| 51 |
-
|
| 52 |
-
|
| 53 |
-
|
| 54 |
-
|
| 55 |
-
|
| 56 |
-
|
| 57 |
-
|
| 58 |
-
|
| 59 |
-
|
| 60 |
-
|
| 61 |
-
|
| 62 |
-
|
| 63 |
-
|
| 64 |
-
|
| 65 |
-
|
| 66 |
-
|
| 67 |
-
|
| 68 |
-
|
| 69 |
-
|
| 70 |
-
|
| 71 |
-
|
| 72 |
-
|
| 73 |
-
|
| 74 |
-
|
| 75 |
-
|
| 76 |
-
|
| 77 |
-
|
| 78 |
-
"""
|
| 79 |
-
|
| 80 |
-
|
| 81 |
-
|
| 82 |
-
|
| 83 |
-
|
| 84 |
-
|
| 85 |
-
|
| 86 |
-
|
| 87 |
-
|
| 88 |
-
|
| 89 |
-
|
| 90 |
-
|
| 91 |
-
|
| 92 |
-
|
| 93 |
-
|
| 94 |
-
|
| 95 |
-
|
| 96 |
-
|
| 97 |
-
|
| 98 |
-
|
| 99 |
-
|
| 100 |
-
|
| 101 |
-
|
| 102 |
-
|
| 103 |
-
|
| 104 |
-
|
| 105 |
-
|
| 106 |
-
|
| 107 |
-
|
| 108 |
-
|
| 109 |
-
|
| 110 |
-
|
| 111 |
-
|
| 112 |
-
|
| 113 |
-
|
| 114 |
-
|
| 115 |
-
|
| 116 |
-
|
| 117 |
-
|
| 118 |
-
|
| 119 |
-
|
| 120 |
-
|
| 121 |
-
|
| 122 |
-
|
| 123 |
-
|
| 124 |
-
|
| 125 |
-
|
| 126 |
-
|
| 127 |
-
|
| 128 |
-
|
| 129 |
-
|
| 130 |
-
|
| 131 |
-
|
| 132 |
-
|
| 133 |
-
|
| 134 |
-
|
| 135 |
-
|
| 136 |
-
|
| 137 |
-
|
| 138 |
-
|
| 139 |
-
|
| 140 |
-
|
| 141 |
-
|
| 142 |
-
nodes_prop = [type(n) for n in ast.walk(tree_prop)]
|
| 143 |
-
nodes_oracle = [type(n) for n in ast.walk(tree_oracle)]
|
| 144 |
-
common = sum(1 for n in nodes_prop if n in nodes_oracle)
|
| 145 |
-
total = max(len(nodes_prop), len(nodes_oracle))
|
| 146 |
-
return common / total if total > 0 else 0.0
|
| 147 |
-
except:
|
| 148 |
return 0.0
|
|
|
|
| 1 |
+
# grader.py – Production‑grade, continuous reward, exploit‑aware, example of monolithic scoring
|
| 2 |
+
import ast
|
| 3 |
+
import subprocess
|
| 4 |
+
import tempfile
|
| 5 |
+
import os
|
| 6 |
+
import re
|
| 7 |
+
from dataclasses import dataclass
|
| 8 |
+
from typing import Optional
|
| 9 |
+
|
| 10 |
+
@dataclass
class RigorousGrader:
    """Monolithic scorer for a proposed bug fix.

    Attributes are documented inline below. The only public entry point is
    `grade_fix`; the remaining methods are the individual scoring components.
    """

    # Identifier of the injected bug; selects the test cases to run.
    bug_id: str
    # Reference fix used for structural comparison; None disables that term.
    oracle_code: Optional[str] = None

    def grade_fix(self, proposed_fix: str) -> float:
        """
        Returns a smooth reward in [0,1] based on:
        - Syntax validity
        - Proportion of tests passed (continuous, not binary)
        - Lint quality (with conservative fallback)
        - Structural similarity to oracle (anti-gaming)
        - Exploit detection (hardcoded outputs / no real change)

        Code that does not parse, or that `_is_exploit` flags, scores 0.0.
        Otherwise the result is 0.5*tests + 0.3*lint + 0.2*oracle similarity.
        """
        # 1. Syntax check (binary - non-negotiable). ValueError covers source
        #    that `ast.parse` rejects before tokenising (e.g. null bytes on
        #    some Python versions).
        try:
            ast.parse(proposed_fix)
        except (SyntaxError, ValueError):
            return 0.0  # hard zero, not negative (RL stable)

        # 2. Exploit detection: trivial or hardcoded fixes
        if self._is_exploit(proposed_fix):
            return 0.0

        # 3. Continuous test score (proportion of passed test cases)
        test_score = self._run_continuous_tests(proposed_fix)

        # 4. Lint score (continuous, fallback 0.0 not 0.5)
        lint_score = self._get_lint_score(proposed_fix)

        # 5. Oracle similarity (structural)
        oracle_score = self._ast_similarity(proposed_fix) if self.oracle_code else 0.0

        # Weighted combination (all continuous)
        final = (0.5 * test_score) + (0.3 * lint_score) + (0.2 * oracle_score)
        return max(0.0, min(1.0, final))

    def _run_continuous_tests(self, code: str) -> float:
        """
        Returns proportion of passed tests (0.0 to 1.0).

        Returns 0.0 when no test cases are defined for `bug_id`.
        """
        test_cases = self._get_test_cases()
        if not test_cases:
            return 0.0

        passed = sum(
            1
            for test_input, expected in test_cases
            if self._run_single_test(code, test_input, expected)
        )
        return passed / len(test_cases)

    def _get_test_cases(self) -> list:
        """Return ``(input, expected)`` pairs for `bug_id` (empty if unknown)."""
        if self.bug_id == "null_check":
            return [
                ({"users": {"alice": "Alice"}, "id": "bob"}, None),  # should not crash
                ({"users": {"alice": "Alice"}, "id": "alice"}, "Alice"),
            ]
        elif self.bug_id == "off_by_one":
            return [
                ([1, 2, 3, 4], 4),  # should count all elements
                ([], 0),
            ]
        # Add more for other bugs...
        return []

    def _run_single_test(self, code: str, test_input, expected) -> bool:
        """Placeholder test harness: True if `code` executes without error.

        NOTE: `test_input` and `expected` are not used yet - the code is only
        executed at module level, no function is called and no output is
        compared.

        SECURITY: this runs `exec` on the proposed fix in-process with no
        sandbox. Do not use on untrusted input until it is replaced by an
        isolated runner.
        """
        try:
            exec_globals = {}
            exec(code, exec_globals)
            return True  # placeholder
        except (Exception, SystemExit):
            # SystemExit is caught as well so a fix calling exit() counts as
            # a failed test instead of terminating the grader.
            return False

    def _is_exploit(self, code: str) -> bool:
        """Detect hardcoded returns or trivial bypasses.

        Flags code that contains "return 0" with no "if" anywhere in it
        (case-insensitive substring checks), and blank code.
        """
        lower = code.lower()
        # Hardcoded return with no conditional anywhere in the code
        if "return 0" in lower and "if" not in lower:
            return True
        # Empty submission
        if code.strip() == "":
            return True
        return False

    def _get_lint_score(self, code: str) -> float:
        """Continuous pylint score in [0, 1]; 0.0 on any error or missing rating."""
        tmp_path = None
        try:
            with tempfile.NamedTemporaryFile(
                mode='w', suffix='.py', delete=False, encoding='utf-8'
            ) as f:
                # Record the path first so a failed write still gets cleaned up.
                tmp_path = f.name
                f.write(code)
            result = subprocess.run(
                ['pylint', tmp_path, '--score=y', '--exit-zero'],
                capture_output=True,
                text=True,
                timeout=5
            )
            match = re.search(r"rated at (\d+\.\d+)/10", result.stdout)
            if match:
                score = float(match.group(1)) / 10.0
            else:
                score = 0.0  # was 0.5 - now conservative
            return max(0.0, min(1.0, score))
        except Exception:
            return 0.0
        finally:
            if tmp_path is not None:
                try:
                    os.unlink(tmp_path)
                except OSError:
                    pass

    def _ast_similarity(self, proposed_code: str) -> float:
        """Structural similarity between the proposed code and the oracle.

        Compares the multisets of AST node types: the number of node types
        shared (counted with multiplicity) divided by the size of the larger
        tree. Repeating node types that occur in the oracle therefore does
        not inflate the score. Returns 0.0 when there is no oracle or either
        source cannot be analysed.
        """
        from collections import Counter

        if not self.oracle_code:
            return 0.0
        try:
            prop_counts = Counter(type(n) for n in ast.walk(ast.parse(proposed_code)))
            oracle_counts = Counter(type(n) for n in ast.walk(ast.parse(self.oracle_code)))
        except Exception:
            # Best-effort: unparsable or pathological input scores zero.
            return 0.0
        # Multiset intersection keeps min(count) per node type.
        common = sum((prop_counts & oracle_counts).values())
        total = max(sum(prop_counts.values()), sum(oracle_counts.values()))
        return common / total if total > 0 else 0.0
|
models.py
CHANGED
|
@@ -1,115 +1,88 @@
|
|
| 1 |
-
# models.py – Typed Models (Discriminated Unions, POMDP Separation)
|
| 2 |
-
from typing import Literal, Union, Annotated, Optional
|
| 3 |
-
from
|
| 4 |
-
|
| 5 |
-
|
| 6 |
-
#
|
| 7 |
-
#
|
| 8 |
-
|
| 9 |
-
|
| 10 |
-
|
| 11 |
-
"
|
| 12 |
-
|
| 13 |
-
|
| 14 |
-
|
| 15 |
-
|
| 16 |
-
|
| 17 |
-
|
| 18 |
-
|
| 19 |
-
|
| 20 |
-
|
| 21 |
-
|
| 22 |
-
|
| 23 |
-
|
| 24 |
-
|
| 25 |
-
|
| 26 |
-
|
| 27 |
-
|
| 28 |
-
|
| 29 |
-
|
| 30 |
-
|
| 31 |
-
@
|
| 32 |
-
|
| 33 |
-
|
| 34 |
-
|
| 35 |
-
|
| 36 |
-
|
| 37 |
-
|
| 38 |
-
|
| 39 |
-
|
| 40 |
-
|
| 41 |
-
|
| 42 |
-
|
| 43 |
-
|
| 44 |
-
|
| 45 |
-
|
| 46 |
-
|
| 47 |
-
|
| 48 |
-
|
| 49 |
-
|
| 50 |
-
|
| 51 |
-
|
| 52 |
-
|
| 53 |
-
|
| 54 |
-
|
| 55 |
-
|
| 56 |
-
|
| 57 |
-
|
| 58 |
-
|
| 59 |
-
|
| 60 |
-
|
| 61 |
-
|
| 62 |
-
|
| 63 |
-
|
| 64 |
-
|
| 65 |
-
|
| 66 |
-
|
| 67 |
-
|
| 68 |
-
|
| 69 |
-
|
| 70 |
-
|
| 71 |
-
|
| 72 |
-
|
| 73 |
-
|
| 74 |
-
|
| 75 |
-
|
| 76 |
-
|
| 77 |
-
|
| 78 |
-
|
| 79 |
-
#
|
| 80 |
-
|
| 81 |
-
|
| 82 |
-
|
| 83 |
-
|
| 84 |
-
|
| 85 |
-
|
| 86 |
-
|
| 87 |
-
|
| 88 |
-
# Observation (POMDP – what the agent sees)
|
| 89 |
-
# ----------------------------------------------------------------------
|
| 90 |
-
@dataclass(slots=True)
|
| 91 |
-
class Observation:
|
| 92 |
-
code_snippet: str
|
| 93 |
-
last_tool_output: str = ""
|
| 94 |
-
step: int = 0
|
| 95 |
-
done: bool = False
|
| 96 |
-
|
| 97 |
-
# ----------------------------------------------------------------------
|
| 98 |
-
# Reward (lightweight)
|
| 99 |
-
# ----------------------------------------------------------------------
|
| 100 |
-
@dataclass(slots=True)
|
| 101 |
-
class Reward:
|
| 102 |
-
value: float
|
| 103 |
-
|
| 104 |
-
# ----------------------------------------------------------------------
|
| 105 |
-
# State (full environment state – not exposed to agent)
|
| 106 |
-
# ----------------------------------------------------------------------
|
| 107 |
-
@dataclass(slots=True)
|
| 108 |
-
class State:
|
| 109 |
-
pr_title: str
|
| 110 |
-
pr_description: str
|
| 111 |
-
code_snippet: str
|
| 112 |
-
comments: list[str]
|
| 113 |
-
test_results: Optional[str]
|
| 114 |
-
step: int
|
| 115 |
done: bool
|
|
|
|
| 1 |
+
# models.py – Typed Models (Discriminated Unions, POMDP Separation)
|
| 2 |
+
from typing import Literal, Union, Annotated, Optional
|
| 3 |
+
from pydantic import BaseModel, Field, TypeAdapter, field_validator
|
| 4 |
+
|
| 5 |
+
# ----------------------------------------------------------------------
|
| 6 |
+
# Action classes (discriminated union)
|
| 7 |
+
# ----------------------------------------------------------------------
|
| 8 |
+
class Action(BaseModel):
    # Base class for every agent action. `action_type` is the tag each
    # subclass pins to a single literal value; it is also the discriminator
    # field of the `AnyAction` union defined later in this module.
    action_type: Literal["write_comment", "skip", "done", "ask_question",
                         "propose_fix", "execute", "inspect", "run_linter",
                         "run_tests", "query_docs"]
|
| 12 |
+
|
| 13 |
+
class WriteComment(Action):
    # Action tagged "write_comment" carrying the comment text.
    action_type: Literal["write_comment"] = "write_comment"
    # Required; must contain at least one character.
    comment_text: str = Field(..., min_length=1)
|
| 16 |
+
|
| 17 |
+
class Skip(Action):
    # Payload-free action tagged "skip".
    action_type: Literal["skip"] = "skip"
|
| 19 |
+
|
| 20 |
+
class Done(Action):
    # Payload-free action tagged "done".
    action_type: Literal["done"] = "done"
|
| 22 |
+
|
| 23 |
+
class AskQuestion(Action):
    # Action tagged "ask_question" carrying the question text.
    action_type: Literal["ask_question"] = "ask_question"
    # Required; must contain at least one character.
    question: str = Field(..., min_length=1)
|
| 26 |
+
|
| 27 |
+
class ProposeFix(Action):
    # Action tagged "propose_fix" carrying the replacement code.
    action_type: Literal["propose_fix"] = "propose_fix"
    # Required; must contain at least one character. The validator below
    # additionally rejects whitespace-only values.
    fix_code: str = Field(..., min_length=1)

    @field_validator('fix_code')
    @classmethod
    def not_empty(cls, v: str) -> str:
        # Reject strings that are empty once surrounding whitespace is
        # stripped; the value itself is returned unmodified.
        if not v.strip():
            raise ValueError('fix_code cannot be empty')
        return v
|
| 36 |
+
|
| 37 |
+
class Execute(Action):
    # Payload-free action tagged "execute".
    action_type: Literal["execute"] = "execute"
|
| 39 |
+
|
| 40 |
+
class Inspect(Action):
    # Payload-free action tagged "inspect".
    action_type: Literal["inspect"] = "inspect"
|
| 42 |
+
|
| 43 |
+
class RunLinter(Action):
    # Payload-free action tagged "run_linter".
    action_type: Literal["run_linter"] = "run_linter"
|
| 45 |
+
|
| 46 |
+
class RunTests(Action):
    # Payload-free action tagged "run_tests".
    action_type: Literal["run_tests"] = "run_tests"
|
| 48 |
+
|
| 49 |
+
class QueryDocs(Action):
    # Action tagged "query_docs" carrying the topic to look up.
    action_type: Literal["query_docs"] = "query_docs"
    # Required; must contain at least one character.
    query_topic: str = Field(..., min_length=1)
|
| 52 |
+
|
| 53 |
+
# Discriminated union for one‑line polymorphic deserialization
|
| 54 |
+
# Union of all concrete action classes, discriminated on `action_type`.
AnyAction = Annotated[
    Union[WriteComment, Skip, Done, AskQuestion, ProposeFix,
          Execute, Inspect, RunLinter, RunTests, QueryDocs],
    Field(discriminator='action_type')
]
# Module-level adapter for validating raw data against `AnyAction`.
action_adapter = TypeAdapter(AnyAction)
|
| 60 |
+
|
| 61 |
+
# ----------------------------------------------------------------------
|
| 62 |
+
# Observation (POMDP – what the agent sees)
|
| 63 |
+
# ----------------------------------------------------------------------
|
| 64 |
+
class Observation(BaseModel):
    # Base schema model used by API metadata endpoints.
    # Keep this lightweight for compatibility with legacy callers.
    code_snippet: str  # required; the only field without a default
    last_tool_output: str = ""
    step: int = 0
    done: bool = False
|
| 71 |
+
|
| 72 |
+
# ----------------------------------------------------------------------
|
| 73 |
+
# Reward (lightweight)
|
| 74 |
+
# ----------------------------------------------------------------------
|
| 75 |
+
class Reward(BaseModel):
    # Wrapper around a single scalar reward.
    value: float
|
| 77 |
+
|
| 78 |
+
# ----------------------------------------------------------------------
|
| 79 |
+
# State (full environment state – not exposed to agent)
|
| 80 |
+
# ----------------------------------------------------------------------
|
| 81 |
+
class State(BaseModel):
    # Full environment state snapshot; every field is required.
    pr_title: str
    pr_description: str
    code_snippet: str
    comments: list[str]
    # May be None (declared Optional).
    test_results: Optional[str]
    step: int
    done: bool
|
openenv.yaml
CHANGED
|
@@ -1,135 +1,135 @@
|
|
| 1 |
-
# openenv.yaml – Environment metadata for OpenEnv
|
| 2 |
-
name: CodeReview-Professional-Workflow
|
| 3 |
-
version: 1.0.0
|
| 4 |
-
description: |
|
| 5 |
-
Multi‑turn code review environment for professional tasks.
|
| 6 |
-
Agent must inspect, test, lint, query docs, and negotiate with a simulated author
|
| 7 |
-
to fix injected bugs. Supports DPO training on full trajectories.
|
| 8 |
-
author: yuvraj gupta
|
| 9 |
-
license: MIT
|
| 10 |
-
|
| 11 |
-
# ----------------------------------------------------------------------
|
| 12 |
-
# Tasks (difficulty progression)
|
| 13 |
-
# ----------------------------------------------------------------------
|
| 14 |
-
tasks:
|
| 15 |
-
- id: easy
|
| 16 |
-
description: "Fix missing null check in a dictionary lookup"
|
| 17 |
-
- id: medium
|
| 18 |
-
description: "Improve loop efficiency (replace range(len) with direct iteration)"
|
| 19 |
-
- id: hard
|
| 20 |
-
description: "Handle division by zero in average calculation"
|
| 21 |
-
- id: harder
|
| 22 |
-
description: "Fix race condition by adding a lock"
|
| 23 |
-
- id: hardest
|
| 24 |
-
description: "Resolve potential deadlock by standardising lock order"
|
| 25 |
-
|
| 26 |
-
# ----------------------------------------------------------------------
|
| 27 |
-
# Observation space (complete Markov state – agent sees everything)
|
| 28 |
-
# ----------------------------------------------------------------------
|
| 29 |
-
observation_space:
|
| 30 |
-
type: object
|
| 31 |
-
properties:
|
| 32 |
-
code_snippet:
|
| 33 |
-
type: string
|
| 34 |
-
description: "Current code snippet (may contain injected bug)"
|
| 35 |
-
last_tool_output:
|
| 36 |
-
type: string
|
| 37 |
-
description: "Raw output from last tool (test runner, linter, etc.)"
|
| 38 |
-
author_response:
|
| 39 |
-
type: string
|
| 40 |
-
description: "Latest feedback from the simulated human developer"
|
| 41 |
-
current_test_score:
|
| 42 |
-
type: number
|
| 43 |
-
description: "Proportion of tests passed (0.0–1.0)"
|
| 44 |
-
current_lint_score:
|
| 45 |
-
type: number
|
| 46 |
-
description: "Normalised pylint score (0.0–1.0)"
|
| 47 |
-
negotiation_score:
|
| 48 |
-
type: number
|
| 49 |
-
description: "Author's confidence minus pushback penalty"
|
| 50 |
-
previous_test_score:
|
| 51 |
-
type: number
|
| 52 |
-
description: "Test score before the last action"
|
| 53 |
-
previous_lint_score:
|
| 54 |
-
type: number
|
| 55 |
-
description: "Lint score before the last action"
|
| 56 |
-
author_confidence:
|
| 57 |
-
type: number
|
| 58 |
-
description: "Internal belief of the author (0.0–1.0)"
|
| 59 |
-
author_threshold:
|
| 60 |
-
type: number
|
| 61 |
-
description: "Confidence threshold for this personality"
|
| 62 |
-
step:
|
| 63 |
-
type: integer
|
| 64 |
-
description: "Current step number"
|
| 65 |
-
max_steps:
|
| 66 |
-
type: integer
|
| 67 |
-
description: "Maximum steps allowed in the episode"
|
| 68 |
-
progress_ratio:
|
| 69 |
-
type: number
|
| 70 |
-
description: "step / max_steps"
|
| 71 |
-
tests_run:
|
| 72 |
-
type: boolean
|
| 73 |
-
description: "Whether the agent has run tests at least once"
|
| 74 |
-
linter_run:
|
| 75 |
-
type: boolean
|
| 76 |
-
description: "Whether the agent has run the linter at least once"
|
| 77 |
-
docs_queried:
|
| 78 |
-
type: boolean
|
| 79 |
-
description: "Whether the agent has queried documentation"
|
| 80 |
-
last_action_type:
|
| 81 |
-
type: string
|
| 82 |
-
description: "String name of the last executed action"
|
| 83 |
-
action_history:
|
| 84 |
-
type: array
|
| 85 |
-
items:
|
| 86 |
-
type: string
|
| 87 |
-
description: "Last 5 action types"
|
| 88 |
-
done:
|
| 89 |
-
type: boolean
|
| 90 |
-
description: "Whether the episode has finished"
|
| 91 |
-
bug_description:
|
| 92 |
-
type: string
|
| 93 |
-
description: "Short description of the injected bug"
|
| 94 |
-
comments_count:
|
| 95 |
-
type: integer
|
| 96 |
-
description: "Number of comments exchanged so far"
|
| 97 |
-
|
| 98 |
-
# ----------------------------------------------------------------------
|
| 99 |
-
# Action space (short names as produced by the agent)
|
| 100 |
-
# ----------------------------------------------------------------------
|
| 101 |
-
action_space:
|
| 102 |
-
type: object
|
| 103 |
-
properties:
|
| 104 |
-
action_type:
|
| 105 |
-
type: string
|
| 106 |
-
enum:
|
| 107 |
-
- comment
|
| 108 |
-
- skip
|
| 109 |
-
- done
|
| 110 |
-
- question
|
| 111 |
-
- fix
|
| 112 |
-
- execute
|
| 113 |
-
- inspect
|
| 114 |
-
- run_linter
|
| 115 |
-
- run_tests
|
| 116 |
-
- query_docs
|
| 117 |
-
comment_text:
|
| 118 |
-
type: string
|
| 119 |
-
description: "Required for comment"
|
| 120 |
-
question:
|
| 121 |
-
type: string
|
| 122 |
-
description: "Required for question"
|
| 123 |
-
fix_code:
|
| 124 |
-
type: string
|
| 125 |
-
description: "Required for fix"
|
| 126 |
-
query_topic:
|
| 127 |
-
type: string
|
| 128 |
-
description: "Required for query_docs"
|
| 129 |
-
|
| 130 |
-
# ----------------------------------------------------------------------
|
| 131 |
-
# (Optional) Server configuration – used by openenv serve
|
| 132 |
-
# ----------------------------------------------------------------------
|
| 133 |
-
server:
|
| 134 |
-
app: server.app:app
|
| 135 |
-
port: 7860
|
|
|
|
| 1 |
+
# openenv.yaml – Environment metadata for OpenEnv
|
| 2 |
+
name: CodeReview-Professional-Workflow
|
| 3 |
+
version: 1.0.0
|
| 4 |
+
description: |
|
| 5 |
+
Multi‑turn code review environment for professional tasks.
|
| 6 |
+
Agent must inspect, test, lint, query docs, and negotiate with a simulated author
|
| 7 |
+
to fix injected bugs. Supports DPO training on full trajectories.
|
| 8 |
+
author: yuvraj gupta
|
| 9 |
+
license: MIT
|
| 10 |
+
|
| 11 |
+
# ----------------------------------------------------------------------
|
| 12 |
+
# Tasks (difficulty progression)
|
| 13 |
+
# ----------------------------------------------------------------------
|
| 14 |
+
tasks:
|
| 15 |
+
- id: easy
|
| 16 |
+
description: "Fix missing null check in a dictionary lookup"
|
| 17 |
+
- id: medium
|
| 18 |
+
description: "Improve loop efficiency (replace range(len) with direct iteration)"
|
| 19 |
+
- id: hard
|
| 20 |
+
description: "Handle division by zero in average calculation"
|
| 21 |
+
- id: harder
|
| 22 |
+
description: "Fix race condition by adding a lock"
|
| 23 |
+
- id: hardest
|
| 24 |
+
description: "Resolve potential deadlock by standardising lock order"
|
| 25 |
+
|
| 26 |
+
# ----------------------------------------------------------------------
|
| 27 |
+
# Observation space (complete Markov state – agent sees everything)
|
| 28 |
+
# ----------------------------------------------------------------------
|
| 29 |
+
observation_space:
|
| 30 |
+
type: object
|
| 31 |
+
properties:
|
| 32 |
+
code_snippet:
|
| 33 |
+
type: string
|
| 34 |
+
description: "Current code snippet (may contain injected bug)"
|
| 35 |
+
last_tool_output:
|
| 36 |
+
type: string
|
| 37 |
+
description: "Raw output from last tool (test runner, linter, etc.)"
|
| 38 |
+
author_response:
|
| 39 |
+
type: string
|
| 40 |
+
description: "Latest feedback from the simulated human developer"
|
| 41 |
+
current_test_score:
|
| 42 |
+
type: number
|
| 43 |
+
description: "Proportion of tests passed (0.0–1.0)"
|
| 44 |
+
current_lint_score:
|
| 45 |
+
type: number
|
| 46 |
+
description: "Normalised pylint score (0.0–1.0)"
|
| 47 |
+
negotiation_score:
|
| 48 |
+
type: number
|
| 49 |
+
description: "Author's confidence minus pushback penalty"
|
| 50 |
+
previous_test_score:
|
| 51 |
+
type: number
|
| 52 |
+
description: "Test score before the last action"
|
| 53 |
+
previous_lint_score:
|
| 54 |
+
type: number
|
| 55 |
+
description: "Lint score before the last action"
|
| 56 |
+
author_confidence:
|
| 57 |
+
type: number
|
| 58 |
+
description: "Internal belief of the author (0.0–1.0)"
|
| 59 |
+
author_threshold:
|
| 60 |
+
type: number
|
| 61 |
+
description: "Confidence threshold for this personality"
|
| 62 |
+
step:
|
| 63 |
+
type: integer
|
| 64 |
+
description: "Current step number"
|
| 65 |
+
max_steps:
|
| 66 |
+
type: integer
|
| 67 |
+
description: "Maximum steps allowed in the episode"
|
| 68 |
+
progress_ratio:
|
| 69 |
+
type: number
|
| 70 |
+
description: "step / max_steps"
|
| 71 |
+
tests_run:
|
| 72 |
+
type: boolean
|
| 73 |
+
description: "Whether the agent has run tests at least once"
|
| 74 |
+
linter_run:
|
| 75 |
+
type: boolean
|
| 76 |
+
description: "Whether the agent has run the linter at least once"
|
| 77 |
+
docs_queried:
|
| 78 |
+
type: boolean
|
| 79 |
+
description: "Whether the agent has queried documentation"
|
| 80 |
+
last_action_type:
|
| 81 |
+
type: string
|
| 82 |
+
description: "String name of the last executed action"
|
| 83 |
+
action_history:
|
| 84 |
+
type: array
|
| 85 |
+
items:
|
| 86 |
+
type: string
|
| 87 |
+
description: "Last 5 action types"
|
| 88 |
+
done:
|
| 89 |
+
type: boolean
|
| 90 |
+
description: "Whether the episode has finished"
|
| 91 |
+
bug_description:
|
| 92 |
+
type: string
|
| 93 |
+
description: "Short description of the injected bug"
|
| 94 |
+
comments_count:
|
| 95 |
+
type: integer
|
| 96 |
+
description: "Number of comments exchanged so far"
|
| 97 |
+
|
| 98 |
+
# ----------------------------------------------------------------------
|
| 99 |
+
# Action space (short names as produced by the agent)
|
| 100 |
+
# ----------------------------------------------------------------------
|
| 101 |
+
action_space:
|
| 102 |
+
type: object
|
| 103 |
+
properties:
|
| 104 |
+
action_type:
|
| 105 |
+
type: string
|
| 106 |
+
enum:
|
| 107 |
+
- comment
|
| 108 |
+
- skip
|
| 109 |
+
- done
|
| 110 |
+
- question
|
| 111 |
+
- fix
|
| 112 |
+
- execute
|
| 113 |
+
- inspect
|
| 114 |
+
- run_linter
|
| 115 |
+
- run_tests
|
| 116 |
+
- query_docs
|
| 117 |
+
comment_text:
|
| 118 |
+
type: string
|
| 119 |
+
description: "Required for comment"
|
| 120 |
+
question:
|
| 121 |
+
type: string
|
| 122 |
+
description: "Required for question"
|
| 123 |
+
fix_code:
|
| 124 |
+
type: string
|
| 125 |
+
description: "Required for fix"
|
| 126 |
+
query_topic:
|
| 127 |
+
type: string
|
| 128 |
+
description: "Required for query_docs"
|
| 129 |
+
|
| 130 |
+
# ----------------------------------------------------------------------
|
| 131 |
+
# (Optional) Server configuration – used by openenv serve
|
| 132 |
+
# ----------------------------------------------------------------------
|
| 133 |
+
server:
|
| 134 |
+
app: server.app:app
|
| 135 |
+
port: 7860
|
pyproject.toml
CHANGED
|
@@ -1,30 +1,30 @@
|
|
| 1 |
-
[build-system]
|
| 2 |
-
requires = ["setuptools>=61.0", "wheel"]
|
| 3 |
-
build-backend = "setuptools.build_meta"
|
| 4 |
-
|
| 5 |
-
[project]
|
| 6 |
-
name = "code_review_professional"
|
| 7 |
-
version = "1.0.0"
|
| 8 |
-
description = "Multi‑turn code review environment with AST injection, DPO training, and author negotiation"
|
| 9 |
-
authors = [{name = "yuvraj gupta", email = "yuvraj467229@gmail.com"}]
|
| 10 |
-
license = {text = "MIT"}
|
| 11 |
-
readme = "README.md"
|
| 12 |
-
requires-python = ">=3.10"
|
| 13 |
-
dependencies = [
|
| 14 |
-
"openenv-core>=0.2.0",
|
| 15 |
-
"fastapi>=0.115.0",
|
| 16 |
-
"uvicorn>=0.24.0",
|
| 17 |
-
"unsloth>=2025.3.1",
|
| 18 |
-
"trl>=0.15.0",
|
| 19 |
-
"accelerate>=1.2.0",
|
| 20 |
-
"pylint>=3.3.0",
|
| 21 |
-
"sentence-transformers>=3.3.0",
|
| 22 |
-
"datasets>=3.3.0",
|
| 23 |
-
"chromadb>=0.5.0",
|
| 24 |
-
]
|
| 25 |
-
|
| 26 |
-
[project.optional-dependencies]
|
| 27 |
-
dev = ["pytest>=7.0", "black>=23.0", "isort>=5.0"]
|
| 28 |
-
|
| 29 |
-
[tool.openenv]
|
| 30 |
server = "server.app:app"
|
|
|
|
| 1 |
+
[build-system]
|
| 2 |
+
requires = ["setuptools>=61.0", "wheel"]
|
| 3 |
+
build-backend = "setuptools.build_meta"
|
| 4 |
+
|
| 5 |
+
[project]
|
| 6 |
+
name = "code_review_professional"
|
| 7 |
+
version = "1.0.0"
|
| 8 |
+
description = "Multi‑turn code review environment with AST injection, DPO training, and author negotiation"
|
| 9 |
+
authors = [{name = "yuvraj gupta", email = "yuvraj467229@gmail.com"}]
|
| 10 |
+
license = {text = "MIT"}
|
| 11 |
+
readme = "README.md"
|
| 12 |
+
requires-python = ">=3.10"
|
| 13 |
+
dependencies = [
|
| 14 |
+
"openenv-core>=0.2.0",
|
| 15 |
+
"fastapi>=0.115.0",
|
| 16 |
+
"uvicorn>=0.24.0",
|
| 17 |
+
"unsloth>=2025.3.1",
|
| 18 |
+
"trl>=0.15.0",
|
| 19 |
+
"accelerate>=1.2.0",
|
| 20 |
+
"pylint>=3.3.0",
|
| 21 |
+
"sentence-transformers>=3.3.0",
|
| 22 |
+
"datasets>=3.3.0",
|
| 23 |
+
"chromadb>=0.5.0",
|
| 24 |
+
]
|
| 25 |
+
|
| 26 |
+
[project.optional-dependencies]
|
| 27 |
+
dev = ["pytest>=7.0", "black>=23.0", "isort>=5.0"]
|
| 28 |
+
|
| 29 |
+
[tool.openenv]
|
| 30 |
server = "server.app:app"
|
redteam.py
CHANGED
|
@@ -1,274 +1,274 @@
|
|
| 1 |
-
# redteam.py – Task‑aware bug injection (25 bugs, 5 difficulty levels)
|
| 2 |
-
import ast
|
| 3 |
-
import random
|
| 4 |
-
from dataclasses import dataclass, field
|
| 5 |
-
from typing import Tuple, Optional, List, Dict
|
| 6 |
-
|
| 7 |
-
# ----------------------------------------------------------------------
|
| 8 |
-
# 1. AST Bug Injector (extended for all simple bugs)
|
| 9 |
-
# ----------------------------------------------------------------------
|
| 10 |
-
class ASTBugInjector(ast.NodeTransformer):
|
| 11 |
-
def __init__(self, bug_type: str):
|
| 12 |
-
super().__init__()
|
| 13 |
-
self.bug_type = bug_type
|
| 14 |
-
self.modified = False
|
| 15 |
-
|
| 16 |
-
# --- Easy: null_check, simple_typo, string_index, default_value, empty_return ---
|
| 17 |
-
def visit_If(self, node: ast.If):
|
| 18 |
-
# null_check: remove the if-guard
|
| 19 |
-
if self.bug_type == "null_check" and not self.modified:
|
| 20 |
-
if node.body and len(node.body) == 1:
|
| 21 |
-
self.modified = True
|
| 22 |
-
return node.body[0]
|
| 23 |
-
# division_by_zero_empty: remove the empty check
|
| 24 |
-
if self.bug_type == "division_by_zero_empty" and not self.modified:
|
| 25 |
-
# pattern: if not data: return 0 – we delete the entire if
|
| 26 |
-
if (isinstance(node.test, ast.UnaryOp) and
|
| 27 |
-
isinstance(node.test.op, ast.Not) and
|
| 28 |
-
isinstance(node.test.operand, ast.Name)):
|
| 29 |
-
self.modified = True
|
| 30 |
-
return None # signal to remove this node from parent
|
| 31 |
-
return self.generic_visit(node)
|
| 32 |
-
|
| 33 |
-
def visit_Name(self, node: ast.Name):
|
| 34 |
-
if self.bug_type == "simple_typo" and not self.modified:
|
| 35 |
-
if node.id == "users":
|
| 36 |
-
self.modified = True
|
| 37 |
-
return ast.Name(id="usres", ctx=node.ctx)
|
| 38 |
-
return self.generic_visit(node)
|
| 39 |
-
|
| 40 |
-
def visit_Subscript(self, node: ast.Subscript):
|
| 41 |
-
if self.bug_type == "string_index" and not self.modified:
|
| 42 |
-
if isinstance(node.slice, ast.Index) and isinstance(node.slice.value, ast.Constant):
|
| 43 |
-
old_val = node.slice.value.value
|
| 44 |
-
if isinstance(old_val, int):
|
| 45 |
-
self.modified = True
|
| 46 |
-
node.slice = ast.Index(value=ast.Constant(value=old_val + 1))
|
| 47 |
-
return self.generic_visit(node)
|
| 48 |
-
|
| 49 |
-
def visit_Call(self, node: ast.Call):
|
| 50 |
-
# default_value: change dict.get(key) to dict[key] (no default)
|
| 51 |
-
if self.bug_type == "default_value" and not self.modified:
|
| 52 |
-
if (isinstance(node.func, ast.Attribute) and
|
| 53 |
-
node.func.attr == "get" and len(node.args) == 1):
|
| 54 |
-
self.modified = True
|
| 55 |
-
return ast.Subscript(
|
| 56 |
-
value=node.func.value,
|
| 57 |
-
slice=ast.Index(value=node.args[0]),
|
| 58 |
-
ctx=node.ctx
|
| 59 |
-
)
|
| 60 |
-
# abs_usage: remove abs()
|
| 61 |
-
if self.bug_type == "abs_usage" and not self.modified:
|
| 62 |
-
if isinstance(node.func, ast.Name) and node.func.id == "abs":
|
| 63 |
-
self.modified = True
|
| 64 |
-
return node.args[0]
|
| 65 |
-
return self.generic_visit(node)
|
| 66 |
-
|
| 67 |
-
def visit_FunctionDef(self, node: ast.FunctionDef):
|
| 68 |
-
# empty_return: insert a premature return None
|
| 69 |
-
if self.bug_type == "empty_return" and not self.modified:
|
| 70 |
-
self.modified = True
|
| 71 |
-
node.body.insert(0, ast.Return(value=ast.Constant(value=None)))
|
| 72 |
-
return self.generic_visit(node)
|
| 73 |
-
|
| 74 |
-
# --- Medium: off_by_one, loop_skip, sign_error, swap_args, uninitialised_var ---
|
| 75 |
-
def visit_For(self, node: ast.For):
|
| 76 |
-
if (self.bug_type in ("off_by_one", "loop_skip")) and not self.modified:
|
| 77 |
-
if (isinstance(node.iter, ast.Call) and
|
| 78 |
-
isinstance(node.iter.func, ast.Name) and
|
| 79 |
-
node.iter.func.id == "range"):
|
| 80 |
-
if self.bug_type == "off_by_one":
|
| 81 |
-
new_iter = ast.Call(
|
| 82 |
-
func=ast.Name(id='range', ctx=ast.Load()),
|
| 83 |
-
args=[
|
| 84 |
-
ast.Constant(value=1),
|
| 85 |
-
ast.BinOp(left=node.iter.args[0], op=ast.Sub(), right=ast.Constant(value=1))
|
| 86 |
-
],
|
| 87 |
-
keywords=[]
|
| 88 |
-
)
|
| 89 |
-
node.iter = new_iter
|
| 90 |
-
self.modified = True
|
| 91 |
-
elif self.bug_type == "loop_skip" and len(node.iter.args) == 1:
|
| 92 |
-
new_iter = ast.Call(
|
| 93 |
-
func=ast.Name(id='range', ctx=ast.Load()),
|
| 94 |
-
args=[ast.BinOp(left=node.iter.args[0], op=ast.Sub(), right=ast.Constant(value=1))],
|
| 95 |
-
keywords=[]
|
| 96 |
-
)
|
| 97 |
-
node.iter = new_iter
|
| 98 |
-
self.modified = True
|
| 99 |
-
return self.generic_visit(node)
|
| 100 |
-
|
| 101 |
-
def visit_BinOp(self, node: ast.BinOp):
|
| 102 |
-
# sign_error: flip Add/Sub, wrong_operator: Add->Sub, float_precision: Div->FloorDiv
|
| 103 |
-
if not self.modified:
|
| 104 |
-
if self.bug_type in ("wrong_operator", "sign_error"):
|
| 105 |
-
if isinstance(node.op, ast.Add):
|
| 106 |
-
node.op = ast.Sub()
|
| 107 |
-
self.modified = True
|
| 108 |
-
elif isinstance(node.op, ast.Sub):
|
| 109 |
-
node.op = ast.Add()
|
| 110 |
-
self.modified = True
|
| 111 |
-
elif self.bug_type == "float_precision" and isinstance(node.op, ast.Div):
|
| 112 |
-
node.op = ast.FloorDiv()
|
| 113 |
-
self.modified = True
|
| 114 |
-
return self.generic_visit(node)
|
| 115 |
-
|
| 116 |
-
def visit_arguments(self, node: ast.arguments):
|
| 117 |
-
# swap_args: swap first two arguments of a function
|
| 118 |
-
if self.bug_type == "swap_args" and not self.modified and len(node.args) >= 2:
|
| 119 |
-
self.modified = True
|
| 120 |
-
node.args[0], node.args[1] = node.args[1], node.args[0]
|
| 121 |
-
return self.generic_visit(node)
|
| 122 |
-
|
| 123 |
-
def visit_Assign(self, node: ast.Assign):
|
| 124 |
-
# uninitialised_var: remove an assignment statement (replaced with Pass)
|
| 125 |
-
if self.bug_type == "uninitialised_var" and not self.modified:
|
| 126 |
-
self.modified = True
|
| 127 |
-
return ast.Pass()
|
| 128 |
-
return self.generic_visit(node)
|
| 129 |
-
|
| 130 |
-
# ----------------------------------------------------------------------
|
| 131 |
-
# 2. Bug database (25 bugs, categorized by difficulty)
|
| 132 |
-
# ----------------------------------------------------------------------
|
| 133 |
-
BUG_DB = {
|
| 134 |
-
"easy": {
|
| 135 |
-
"null_check": {"type": "ast", "bug_type": "null_check"},
|
| 136 |
-
"simple_typo": {"type": "ast", "bug_type": "simple_typo"},
|
| 137 |
-
"string_index": {"type": "ast", "bug_type": "string_index"},
|
| 138 |
-
"default_value": {"type": "ast", "bug_type": "default_value"},
|
| 139 |
-
"empty_return": {"type": "ast", "bug_type": "empty_return"},
|
| 140 |
-
},
|
| 141 |
-
"medium": {
|
| 142 |
-
"off_by_one": {"type": "ast", "bug_type": "off_by_one"},
|
| 143 |
-
"loop_skip": {"type": "ast", "bug_type": "loop_skip"},
|
| 144 |
-
"sign_error": {"type": "ast", "bug_type": "sign_error"},
|
| 145 |
-
"swap_args": {"type": "ast", "bug_type": "swap_args"},
|
| 146 |
-
"uninitialised_var": {"type": "ast", "bug_type": "uninitialised_var"},
|
| 147 |
-
},
|
| 148 |
-
"hard": {
|
| 149 |
-
"division_by_zero_empty": {"type": "ast", "bug_type": "division_by_zero_empty"},
|
| 150 |
-
"division_by_zero_zero": {"type": "ast", "bug_type": "division_by_zero_empty"}, # same injector
|
| 151 |
-
"float_precision": {"type": "ast", "bug_type": "float_precision"},
|
| 152 |
-
"abs_usage": {"type": "ast", "bug_type": "abs_usage"},
|
| 153 |
-
"round_error": {"type": "ast", "bug_type": "round_error"}, # can be extended
|
| 154 |
-
},
|
| 155 |
-
"harder": {
|
| 156 |
-
"missing_lock": {
|
| 157 |
-
"type": "template",
|
| 158 |
-
"buggy": "counter = 0\ndef increment():\n global counter\n counter += 1",
|
| 159 |
-
"oracle": "counter = 0\nimport threading\nlock = threading.Lock()\ndef increment():\n global counter\n with lock:\n counter += 1",
|
| 160 |
-
},
|
| 161 |
-
"double_lock": {
|
| 162 |
-
"type": "template",
|
| 163 |
-
"buggy": "import threading\nlock = threading.Lock()\ndef do_work():\n lock.acquire()\n lock.acquire()\n print('working')\n lock.release()",
|
| 164 |
-
"oracle": "import threading\nlock = threading.Lock()\ndef do_work():\n with lock:\n print('working')",
|
| 165 |
-
},
|
| 166 |
-
"global_nonatomic": {
|
| 167 |
-
"type": "template",
|
| 168 |
-
"buggy": "count = 0\ndef add():\n global count\n count = count + 1",
|
| 169 |
-
"oracle": "count = 0\ndef add():\n global count\n count += 1",
|
| 170 |
-
},
|
| 171 |
-
"thread_safe_list": {
|
| 172 |
-
"type": "template",
|
| 173 |
-
"buggy": "import threading\nitems = []\ndef append_item(item):\n items.append(item)",
|
| 174 |
-
"oracle": "import threading\nitems = []\nlock = threading.Lock()\ndef append_item(item):\n with lock:\n items.append(item)",
|
| 175 |
-
},
|
| 176 |
-
"volatile_read": {
|
| 177 |
-
"type": "template",
|
| 178 |
-
"buggy": "import threading\nstop = False\ndef worker():\n while not stop:\n pass",
|
| 179 |
-
"oracle": "import threading\nstop = False\nlock = threading.Lock()\ndef worker():\n while True:\n with lock:\n if stop:\n break",
|
| 180 |
-
},
|
| 181 |
-
},
|
| 182 |
-
"hardest": {
|
| 183 |
-
"deadlock_order": {
|
| 184 |
-
"type": "template",
|
| 185 |
-
"buggy": "import threading\nlock1 = threading.Lock()\nlock2 = threading.Lock()\ndef thread1():\n with lock1:\n with lock2:\n pass\ndef thread2():\n with lock2:\n with lock1:\n pass",
|
| 186 |
-
"oracle": "import threading\nlock1 = threading.Lock()\nlock2 = threading.Lock()\ndef thread1():\n with lock1:\n with lock2:\n pass\ndef thread2():\n with lock1:\n with lock2:\n pass",
|
| 187 |
-
},
|
| 188 |
-
"nested_lock_timeout": {
|
| 189 |
-
"type": "template",
|
| 190 |
-
"buggy": "import threading\nlock = threading.Lock()\ndef work():\n lock.acquire()\n # critical section\n lock.release()",
|
| 191 |
-
"oracle": "import threading\nlock = threading.Lock()\ndef work():\n if lock.acquire(timeout=1):\n try:\n # critical section\n finally:\n lock.release()",
|
| 192 |
-
},
|
| 193 |
-
"fork_join": {
|
| 194 |
-
"type": "template",
|
| 195 |
-
"buggy": "import threading\ndef worker():\n pass\nt = threading.Thread(target=worker)\nt.start()",
|
| 196 |
-
"oracle": "import threading\ndef worker():\n pass\nt = threading.Thread(target=worker)\nt.start()\nt.join()",
|
| 197 |
-
},
|
| 198 |
-
"mutex_release": {
|
| 199 |
-
"type": "template",
|
| 200 |
-
"buggy": "import threading\nlock = threading.Lock()\ndef thread_A():\n lock.acquire()\n lock.release()\ndef thread_B():\n lock.release()",
|
| 201 |
-
"oracle": "import threading\nlock = threading.Lock()\ndef thread_A():\n with lock:\n pass\ndef thread_B():\n with lock:\n pass",
|
| 202 |
-
},
|
| 203 |
-
"race_on_init": {
|
| 204 |
-
"type": "template",
|
| 205 |
-
"buggy": "import threading\nitems = []\ndef init():\n global items\n items = [1,2,3]\nt = threading.Thread(target=init)\nt.start()\nprint(items)",
|
| 206 |
-
"oracle": "import threading\nitems = []\ndef init():\n global items\n items = [1,2,3]\nt = threading.Thread(target=init)\nt.start()\nt.join()\nprint(items)",
|
| 207 |
-
},
|
| 208 |
-
},
|
| 209 |
-
}
|
| 210 |
-
|
| 211 |
-
# ----------------------------------------------------------------------
|
| 212 |
-
# 3. Derived helpers
|
| 213 |
-
# ----------------------------------------------------------------------
|
| 214 |
-
TASK_BUG_MAP = {level: list(bugs.keys()) for level, bugs in BUG_DB.items()}
|
| 215 |
-
|
| 216 |
-
TEMPLATE_BUGS = {}
|
| 217 |
-
for level, bugs in BUG_DB.items():
|
| 218 |
-
for bug_id, bug in bugs.items():
|
| 219 |
-
if bug["type"] == "template":
|
| 220 |
-
TEMPLATE_BUGS[bug_id] = (bug["buggy"], bug["oracle"])
|
| 221 |
-
|
| 222 |
-
# ----------------------------------------------------------------------
|
| 223 |
-
# 4. RedTeam Controller (task‑aware)
|
| 224 |
-
# ----------------------------------------------------------------------
|
| 225 |
-
@dataclass
|
| 226 |
-
class RedTeam:
|
| 227 |
-
task: str
|
| 228 |
-
seed: Optional[int] = 42
|
| 229 |
-
noise_prob: float = 0.2
|
| 230 |
-
_random: random.Random = field(init=False)
|
| 231 |
-
|
| 232 |
-
def __post_init__(self):
|
| 233 |
-
self._random = random.Random(self.seed)
|
| 234 |
-
|
| 235 |
-
def inject_bug(self, original_code: str) -> Tuple[str, str, str, str]:
|
| 236 |
-
"""
|
| 237 |
-
Returns: (buggy_code, bug_type, description, oracle_fix)
|
| 238 |
-
Selects a bug appropriate for the task difficulty.
|
| 239 |
-
"""
|
| 240 |
-
bug_list = TASK_BUG_MAP.get(self.task, ["null_check"])
|
| 241 |
-
bug_type = self._random.choice(bug_list)
|
| 242 |
-
|
| 243 |
-
# Template bug: return hardcoded buggy + oracle
|
| 244 |
-
if bug_type in TEMPLATE_BUGS:
|
| 245 |
-
buggy_code, oracle_code = TEMPLATE_BUGS[bug_type]
|
| 246 |
-
description = f"Template bug: {bug_type}"
|
| 247 |
-
if self._random.random() < self.noise_prob:
|
| 248 |
-
buggy_code += "\n# TODO: refactor later"
|
| 249 |
-
return buggy_code, bug_type, description, oracle_code
|
| 250 |
-
|
| 251 |
-
# AST injection
|
| 252 |
-
try:
|
| 253 |
-
tree = ast.parse(original_code)
|
| 254 |
-
except SyntaxError:
|
| 255 |
-
return original_code, "parse_error", "Syntax error in original code", original_code
|
| 256 |
-
|
| 257 |
-
injector = ASTBugInjector(bug_type)
|
| 258 |
-
modified_tree = injector.visit(tree)
|
| 259 |
-
ast.fix_missing_locations(modified_tree)
|
| 260 |
-
|
| 261 |
-
if injector.modified:
|
| 262 |
-
buggy_code = ast.unparse(modified_tree)
|
| 263 |
-
oracle_fix = original_code
|
| 264 |
-
description = f"AST bug: {bug_type}"
|
| 265 |
-
else:
|
| 266 |
-
buggy_code = original_code
|
| 267 |
-
oracle_fix = original_code
|
| 268 |
-
bug_type = "no_op"
|
| 269 |
-
description = "No suitable code structure found for injection"
|
| 270 |
-
|
| 271 |
-
if self._random.random() < self.noise_prob:
|
| 272 |
-
buggy_code += "\n# TODO: refactor later"
|
| 273 |
-
|
| 274 |
-
return buggy_code, bug_type, description, oracle_fix
|
|
|
|
| 1 |
+
# redteam.py – Task‑aware bug injection (25 bugs, 5 difficulty levels)
|
| 2 |
+
import ast
|
| 3 |
+
import random
|
| 4 |
+
from dataclasses import dataclass, field
|
| 5 |
+
from typing import Tuple, Optional, List, Dict
|
| 6 |
+
|
| 7 |
+
# ----------------------------------------------------------------------
|
| 8 |
+
# 1. AST Bug Injector (extended for all simple bugs)
|
| 9 |
+
# ----------------------------------------------------------------------
|
| 10 |
+
class ASTBugInjector(ast.NodeTransformer):
|
| 11 |
+
def __init__(self, bug_type: str):
|
| 12 |
+
super().__init__()
|
| 13 |
+
self.bug_type = bug_type
|
| 14 |
+
self.modified = False
|
| 15 |
+
|
| 16 |
+
# --- Easy: null_check, simple_typo, string_index, default_value, empty_return ---
|
| 17 |
+
def visit_If(self, node: ast.If):
|
| 18 |
+
# null_check: remove the if-guard
|
| 19 |
+
if self.bug_type == "null_check" and not self.modified:
|
| 20 |
+
if node.body and len(node.body) == 1:
|
| 21 |
+
self.modified = True
|
| 22 |
+
return node.body[0]
|
| 23 |
+
# division_by_zero_empty: remove the empty check
|
| 24 |
+
if self.bug_type == "division_by_zero_empty" and not self.modified:
|
| 25 |
+
# pattern: if not data: return 0 – we delete the entire if
|
| 26 |
+
if (isinstance(node.test, ast.UnaryOp) and
|
| 27 |
+
isinstance(node.test.op, ast.Not) and
|
| 28 |
+
isinstance(node.test.operand, ast.Name)):
|
| 29 |
+
self.modified = True
|
| 30 |
+
return None # signal to remove this node from parent
|
| 31 |
+
return self.generic_visit(node)
|
| 32 |
+
|
| 33 |
+
def visit_Name(self, node: ast.Name):
|
| 34 |
+
if self.bug_type == "simple_typo" and not self.modified:
|
| 35 |
+
if node.id == "users":
|
| 36 |
+
self.modified = True
|
| 37 |
+
return ast.Name(id="usres", ctx=node.ctx)
|
| 38 |
+
return self.generic_visit(node)
|
| 39 |
+
|
| 40 |
+
def visit_Subscript(self, node: ast.Subscript):
|
| 41 |
+
if self.bug_type == "string_index" and not self.modified:
|
| 42 |
+
if isinstance(node.slice, ast.Index) and isinstance(node.slice.value, ast.Constant):
|
| 43 |
+
old_val = node.slice.value.value
|
| 44 |
+
if isinstance(old_val, int):
|
| 45 |
+
self.modified = True
|
| 46 |
+
node.slice = ast.Index(value=ast.Constant(value=old_val + 1))
|
| 47 |
+
return self.generic_visit(node)
|
| 48 |
+
|
| 49 |
+
def visit_Call(self, node: ast.Call):
|
| 50 |
+
# default_value: change dict.get(key) to dict[key] (no default)
|
| 51 |
+
if self.bug_type == "default_value" and not self.modified:
|
| 52 |
+
if (isinstance(node.func, ast.Attribute) and
|
| 53 |
+
node.func.attr == "get" and len(node.args) == 1):
|
| 54 |
+
self.modified = True
|
| 55 |
+
return ast.Subscript(
|
| 56 |
+
value=node.func.value,
|
| 57 |
+
slice=ast.Index(value=node.args[0]),
|
| 58 |
+
ctx=node.ctx
|
| 59 |
+
)
|
| 60 |
+
# abs_usage: remove abs()
|
| 61 |
+
if self.bug_type == "abs_usage" and not self.modified:
|
| 62 |
+
if isinstance(node.func, ast.Name) and node.func.id == "abs":
|
| 63 |
+
self.modified = True
|
| 64 |
+
return node.args[0]
|
| 65 |
+
return self.generic_visit(node)
|
| 66 |
+
|
| 67 |
+
def visit_FunctionDef(self, node: ast.FunctionDef):
|
| 68 |
+
# empty_return: insert a premature return None
|
| 69 |
+
if self.bug_type == "empty_return" and not self.modified:
|
| 70 |
+
self.modified = True
|
| 71 |
+
node.body.insert(0, ast.Return(value=ast.Constant(value=None)))
|
| 72 |
+
return self.generic_visit(node)
|
| 73 |
+
|
| 74 |
+
# --- Medium: off_by_one, loop_skip, sign_error, swap_args, uninitialised_var ---
|
| 75 |
+
def visit_For(self, node: ast.For):
|
| 76 |
+
if (self.bug_type in ("off_by_one", "loop_skip")) and not self.modified:
|
| 77 |
+
if (isinstance(node.iter, ast.Call) and
|
| 78 |
+
isinstance(node.iter.func, ast.Name) and
|
| 79 |
+
node.iter.func.id == "range"):
|
| 80 |
+
if self.bug_type == "off_by_one":
|
| 81 |
+
new_iter = ast.Call(
|
| 82 |
+
func=ast.Name(id='range', ctx=ast.Load()),
|
| 83 |
+
args=[
|
| 84 |
+
ast.Constant(value=1),
|
| 85 |
+
ast.BinOp(left=node.iter.args[0], op=ast.Sub(), right=ast.Constant(value=1))
|
| 86 |
+
],
|
| 87 |
+
keywords=[]
|
| 88 |
+
)
|
| 89 |
+
node.iter = new_iter
|
| 90 |
+
self.modified = True
|
| 91 |
+
elif self.bug_type == "loop_skip" and len(node.iter.args) == 1:
|
| 92 |
+
new_iter = ast.Call(
|
| 93 |
+
func=ast.Name(id='range', ctx=ast.Load()),
|
| 94 |
+
args=[ast.BinOp(left=node.iter.args[0], op=ast.Sub(), right=ast.Constant(value=1))],
|
| 95 |
+
keywords=[]
|
| 96 |
+
)
|
| 97 |
+
node.iter = new_iter
|
| 98 |
+
self.modified = True
|
| 99 |
+
return self.generic_visit(node)
|
| 100 |
+
|
| 101 |
+
def visit_BinOp(self, node: ast.BinOp):
|
| 102 |
+
# sign_error: flip Add/Sub, wrong_operator: Add->Sub, float_precision: Div->FloorDiv
|
| 103 |
+
if not self.modified:
|
| 104 |
+
if self.bug_type in ("wrong_operator", "sign_error"):
|
| 105 |
+
if isinstance(node.op, ast.Add):
|
| 106 |
+
node.op = ast.Sub()
|
| 107 |
+
self.modified = True
|
| 108 |
+
elif isinstance(node.op, ast.Sub):
|
| 109 |
+
node.op = ast.Add()
|
| 110 |
+
self.modified = True
|
| 111 |
+
elif self.bug_type == "float_precision" and isinstance(node.op, ast.Div):
|
| 112 |
+
node.op = ast.FloorDiv()
|
| 113 |
+
self.modified = True
|
| 114 |
+
return self.generic_visit(node)
|
| 115 |
+
|
| 116 |
+
def visit_arguments(self, node: ast.arguments):
    """Swap the first two positional parameters for the ``swap_args`` bug.

    Applies to the first parameter list that has at least two positional
    parameters, and only while ``self.modified`` is still unset.  Children
    are always visited via ``generic_visit``.
    """
    eligible = (
        self.bug_type == "swap_args"
        and not self.modified
        and len(node.args) >= 2
    )
    if eligible:
        self.modified = True
        first, second = node.args[0], node.args[1]
        node.args[0] = second
        node.args[1] = first
    return self.generic_visit(node)
|
| 122 |
+
|
| 123 |
+
def visit_Assign(self, node: ast.Assign):
    """Drop the first assignment statement for the ``uninitialised_var`` bug.

    The assignment is replaced by a bare ``pass`` so the surrounding block
    stays syntactically valid.  For any other bug type, or once a mutation
    has already happened, the node is visited normally.
    """
    if self.modified or self.bug_type != "uninitialised_var":
        return self.generic_visit(node)
    self.modified = True
    return ast.Pass()
|
| 129 |
+
|
| 130 |
+
# ----------------------------------------------------------------------
|
| 131 |
+
# 2. Bug database (25 bugs, categorized by difficulty)
|
| 132 |
+
# ----------------------------------------------------------------------
|
| 133 |
+
# Catalogue of injectable bugs, grouped by difficulty level.
#
# Each entry is one of:
#   * {"type": "ast", "bug_type": <name>}  - produced by mutating the caller's
#     code with the AST injector selected by ``bug_type``.
#   * {"type": "template", "buggy": <src>, "oracle": <src>}  - a fixed pair of
#     buggy program and reference fix.
#
# Every template string must be valid Python on its own: nested blocks carry
# real indentation, and a block never consists of a comment only.
BUG_DB = {
    "easy": {
        "null_check": {"type": "ast", "bug_type": "null_check"},
        "simple_typo": {"type": "ast", "bug_type": "simple_typo"},
        "string_index": {"type": "ast", "bug_type": "string_index"},
        "default_value": {"type": "ast", "bug_type": "default_value"},
        "empty_return": {"type": "ast", "bug_type": "empty_return"},
    },
    "medium": {
        "off_by_one": {"type": "ast", "bug_type": "off_by_one"},
        "loop_skip": {"type": "ast", "bug_type": "loop_skip"},
        "sign_error": {"type": "ast", "bug_type": "sign_error"},
        "swap_args": {"type": "ast", "bug_type": "swap_args"},
        "uninitialised_var": {"type": "ast", "bug_type": "uninitialised_var"},
    },
    "hard": {
        "division_by_zero_empty": {"type": "ast", "bug_type": "division_by_zero_empty"},
        "division_by_zero_zero": {"type": "ast", "bug_type": "division_by_zero_empty"},  # same injector
        "float_precision": {"type": "ast", "bug_type": "float_precision"},
        "abs_usage": {"type": "ast", "bug_type": "abs_usage"},
        "round_error": {"type": "ast", "bug_type": "round_error"},  # can be extended
    },
    "harder": {
        "missing_lock": {
            "type": "template",
            "buggy": "counter = 0\ndef increment():\n    global counter\n    counter += 1",
            "oracle": "counter = 0\nimport threading\nlock = threading.Lock()\ndef increment():\n    global counter\n    with lock:\n        counter += 1",
        },
        "double_lock": {
            "type": "template",
            "buggy": "import threading\nlock = threading.Lock()\ndef do_work():\n    lock.acquire()\n    lock.acquire()\n    print('working')\n    lock.release()",
            "oracle": "import threading\nlock = threading.Lock()\ndef do_work():\n    with lock:\n        print('working')",
        },
        # NOTE(review): the oracle below only rewrites ``count = count + 1`` as
        # ``count += 1``; it adds no lock - confirm this is the intended fix.
        "global_nonatomic": {
            "type": "template",
            "buggy": "count = 0\ndef add():\n    global count\n    count = count + 1",
            "oracle": "count = 0\ndef add():\n    global count\n    count += 1",
        },
        "thread_safe_list": {
            "type": "template",
            "buggy": "import threading\nitems = []\ndef append_item(item):\n    items.append(item)",
            "oracle": "import threading\nitems = []\nlock = threading.Lock()\ndef append_item(item):\n    with lock:\n        items.append(item)",
        },
        "volatile_read": {
            "type": "template",
            "buggy": "import threading\nstop = False\ndef worker():\n    while not stop:\n        pass",
            "oracle": "import threading\nstop = False\nlock = threading.Lock()\ndef worker():\n    while True:\n        with lock:\n            if stop:\n                break",
        },
    },
    "hardest": {
        "deadlock_order": {
            "type": "template",
            "buggy": "import threading\nlock1 = threading.Lock()\nlock2 = threading.Lock()\ndef thread1():\n    with lock1:\n        with lock2:\n            pass\ndef thread2():\n    with lock2:\n        with lock1:\n            pass",
            "oracle": "import threading\nlock1 = threading.Lock()\nlock2 = threading.Lock()\ndef thread1():\n    with lock1:\n        with lock2:\n            pass\ndef thread2():\n    with lock1:\n        with lock2:\n            pass",
        },
        "nested_lock_timeout": {
            "type": "template",
            "buggy": "import threading\nlock = threading.Lock()\ndef work():\n    lock.acquire()\n    # critical section\n    lock.release()",
            # ``pass`` keeps the ``try`` body non-empty; a comment alone is a SyntaxError.
            "oracle": "import threading\nlock = threading.Lock()\ndef work():\n    if lock.acquire(timeout=1):\n        try:\n            # critical section\n            pass\n        finally:\n            lock.release()",
        },
        "fork_join": {
            "type": "template",
            "buggy": "import threading\ndef worker():\n    pass\nt = threading.Thread(target=worker)\nt.start()",
            "oracle": "import threading\ndef worker():\n    pass\nt = threading.Thread(target=worker)\nt.start()\nt.join()",
        },
        "mutex_release": {
            "type": "template",
            "buggy": "import threading\nlock = threading.Lock()\ndef thread_A():\n    lock.acquire()\n    lock.release()\ndef thread_B():\n    lock.release()",
            "oracle": "import threading\nlock = threading.Lock()\ndef thread_A():\n    with lock:\n        pass\ndef thread_B():\n    with lock:\n        pass",
        },
        "race_on_init": {
            "type": "template",
            "buggy": "import threading\nitems = []\ndef init():\n    global items\n    items = [1,2,3]\nt = threading.Thread(target=init)\nt.start()\nprint(items)",
            "oracle": "import threading\nitems = []\ndef init():\n    global items\n    items = [1,2,3]\nt = threading.Thread(target=init)\nt.start()\nt.join()\nprint(items)",
        },
    },
}
|
| 210 |
+
|
| 211 |
+
# ----------------------------------------------------------------------
|
| 212 |
+
# 3. Derived helpers
|
| 213 |
+
# ----------------------------------------------------------------------
|
| 214 |
+
# Difficulty level -> list of bug ids available at that level.
TASK_BUG_MAP = {}
# Template bug id -> (buggy_code, oracle_code), collected across all levels.
TEMPLATE_BUGS = {}
for level, bugs in BUG_DB.items():
    TASK_BUG_MAP[level] = list(bugs)
    for bug_id, bug in bugs.items():
        if bug["type"] != "template":
            continue
        TEMPLATE_BUGS[bug_id] = (bug["buggy"], bug["oracle"])
|
| 221 |
+
|
| 222 |
+
# ----------------------------------------------------------------------
|
| 223 |
+
# 4. RedTeam Controller (task‑aware)
|
| 224 |
+
# ----------------------------------------------------------------------
|
| 225 |
+
@dataclass
class RedTeam:
    """Pick and inject a bug suited to the configured task difficulty.

    Randomness comes from a private ``random.Random`` seeded with ``seed``,
    so a given seed yields the same sequence of choices.
    """

    # Difficulty level; used as the key into TASK_BUG_MAP.
    task: str
    # Seed for the private random generator.
    seed: Optional[int] = 42
    # Probability of appending a distracting TODO comment to the buggy code.
    noise_prob: float = 0.2
    # Private generator, created in __post_init__.
    _random: random.Random = field(init=False)

    def __post_init__(self):
        self._random = random.Random(self.seed)

    def _with_noise(self, code: str) -> str:
        """Return *code*, with a TODO comment appended with probability ``noise_prob``."""
        if self._random.random() < self.noise_prob:
            return code + "\n# TODO: refactor later"
        return code

    def inject_bug(self, original_code: str) -> Tuple[str, str, str, str]:
        """Produce a buggy variant of the task.

        A bug id is drawn from the ids registered for ``self.task`` (falling
        back to ``null_check`` for unknown tasks).  Template bugs return their
        hard-coded buggy/oracle pair and ignore *original_code*; all other
        bugs are injected into *original_code* through ``ASTBugInjector``.

        Returns:
            ``(buggy_code, bug_type, description, oracle_fix)``.  ``bug_type``
            is ``"parse_error"`` when *original_code* does not parse and
            ``"no_op"`` when the injector found nothing to mutate; in both
            cases the code comes back unmodified (``no_op`` may still receive
            the noise comment).
        """
        candidates = TASK_BUG_MAP.get(self.task, ["null_check"])
        bug_type = self._random.choice(candidates)

        # Template bugs come with a fixed buggy program and reference fix.
        if bug_type in TEMPLATE_BUGS:
            template_buggy, template_oracle = TEMPLATE_BUGS[bug_type]
            return (
                self._with_noise(template_buggy),
                bug_type,
                f"Template bug: {bug_type}",
                template_oracle,
            )

        # Everything else is injected by rewriting the caller's AST.
        try:
            tree = ast.parse(original_code)
        except SyntaxError:
            return original_code, "parse_error", "Syntax error in original code", original_code

        injector = ASTBugInjector(bug_type)
        mutated = injector.visit(tree)
        ast.fix_missing_locations(mutated)

        if not injector.modified:
            buggy_code = original_code
            bug_type = "no_op"
            description = "No suitable code structure found for injection"
        else:
            buggy_code = ast.unparse(mutated)
            description = f"AST bug: {bug_type}"

        return self._with_noise(buggy_code), bug_type, description, original_code
|
rubrics.py
CHANGED
|
@@ -1,123 +1,136 @@
|
|
| 1 |
-
# rubrics.py – Self-contained Rubrics (no external OpenEnv dependency)
|
| 2 |
-
|
| 3 |
-
class Rubric:
|
| 4 |
-
"""Minimal Rubric base – compatible with OpenEnv but self‑contained."""
|
| 5 |
-
def __call__(self, env, action, obs, reward, done, info):
|
| 6 |
-
return 0.0
|
| 7 |
-
|
| 8 |
-
|
| 9 |
-
# --------------------------------------------------------------------------------
|
| 10 |
-
# 1. TOOL‑USAGE BONUS
|
| 11 |
-
# --------------------------------------------------------------------------------
|
| 12 |
-
class ToolUsageRubric(Rubric):
|
| 13 |
-
def __init__(self, bonus: float = 0.05):
|
| 14 |
-
self.bonus = bonus
|
| 15 |
-
|
| 16 |
-
def __call__(self, env, action, obs, reward, done, info):
|
| 17 |
-
score = 0.0
|
| 18 |
-
action_type = info.get("action_type", "")
|
| 19 |
-
|
| 20 |
-
|
| 21 |
-
|
| 22 |
-
|
| 23 |
-
|
| 24 |
-
|
| 25 |
-
|
| 26 |
-
|
| 27 |
-
|
| 28 |
-
|
| 29 |
-
|
| 30 |
-
|
| 31 |
-
|
| 32 |
-
score += 0.
|
| 33 |
-
|
| 34 |
-
|
| 35 |
-
|
| 36 |
-
#
|
| 37 |
-
#
|
| 38 |
-
# --
|
| 39 |
-
|
| 40 |
-
|
| 41 |
-
|
| 42 |
-
|
| 43 |
-
|
| 44 |
-
|
| 45 |
-
|
| 46 |
-
|
| 47 |
-
|
| 48 |
-
|
| 49 |
-
|
| 50 |
-
|
| 51 |
-
|
| 52 |
-
|
| 53 |
-
|
| 54 |
-
|
| 55 |
-
|
| 56 |
-
|
| 57 |
-
|
| 58 |
-
|
| 59 |
-
|
| 60 |
-
|
| 61 |
-
|
| 62 |
-
|
| 63 |
-
|
| 64 |
-
|
| 65 |
-
|
| 66 |
-
|
| 67 |
-
|
| 68 |
-
|
| 69 |
-
|
| 70 |
-
|
| 71 |
-
if
|
| 72 |
-
|
| 73 |
-
|
| 74 |
-
|
| 75 |
-
|
| 76 |
-
|
| 77 |
-
|
| 78 |
-
# --------------------------------------------------------------------------------
|
| 79 |
-
|
| 80 |
-
|
| 81 |
-
|
| 82 |
-
|
| 83 |
-
|
| 84 |
-
|
| 85 |
-
|
| 86 |
-
|
| 87 |
-
|
| 88 |
-
|
| 89 |
-
|
| 90 |
-
|
| 91 |
-
|
| 92 |
-
|
| 93 |
-
|
| 94 |
-
|
| 95 |
-
|
| 96 |
-
|
| 97 |
-
|
| 98 |
-
|
| 99 |
-
|
| 100 |
-
|
| 101 |
-
|
| 102 |
-
|
| 103 |
-
|
| 104 |
-
|
| 105 |
-
|
| 106 |
-
|
| 107 |
-
|
| 108 |
-
|
| 109 |
-
|
| 110 |
-
|
| 111 |
-
|
| 112 |
-
|
| 113 |
-
|
| 114 |
-
|
| 115 |
-
|
| 116 |
-
|
| 117 |
-
|
| 118 |
-
|
| 119 |
-
|
| 120 |
-
|
| 121 |
-
|
| 122 |
-
|
| 123 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# rubrics.py – Self-contained Rubrics (no external OpenEnv dependency)
|
| 2 |
+
|
| 3 |
+
class Rubric:
    """Base class for reward rubrics; a rubric is a callable returning a float.

    Subclasses override ``__call__``.  The base implementation contributes
    nothing to the reward.
    """

    def __call__(self, env, action, obs, reward, done, info):
        # Neutral contribution by default.
        return 0.0
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
# --------------------------------------------------------------------------------
|
| 10 |
+
# 1. TOOL‑USAGE BONUS
|
| 11 |
+
# --------------------------------------------------------------------------------
|
| 12 |
+
class ToolUsageRubric(Rubric):
    """Reward the use of tools (tests, linter, docs, questions)."""

    def __init__(self, bonus: float = 0.05):
        self.bonus = bonus

    def __call__(self, env, action, obs, reward, done, info):
        kind = info.get("action_type", "")
        # Prefer the pre-action flags carried in ``info``: the env flags are
        # mutated during the step, which would hide a first use.
        tests_before = info.get("prev_tests_run", env._tests_run)
        linter_before = info.get("prev_linter_run", env._linter_run)
        docs_before = info.get("prev_docs_queried", env._docs_queried)

        if kind in ("run_tests", "run_linter"):
            already_used = tests_before if kind == "run_tests" else linter_before
            total = 0.0
            if not already_used:
                total += self.bonus
            total += 0.015
            return total

        if kind == "query_docs":
            total = 0.0
            if not docs_before:
                total += self.bonus * 0.5
            # Docs are most useful early on and with a non-trivial query.
            if env._step_count <= 4 and info.get("docs_query_len", 0) >= 8:
                total += 0.01
            # Repeated docs calls after the first one are discouraged.
            if docs_before:
                total -= 0.01
            return total

        if kind == "ask_question" and env._step_count <= 3:
            return 0.0 + 0.02

        return 0.0
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
# --------------------------------------------------------------------------------
|
| 50 |
+
# 2. DELTA‑BASED REWARDS
|
| 51 |
+
# --------------------------------------------------------------------------------
|
| 52 |
+
class TestDeltaRubric(Rubric):
    """Reward the change in test score since the previous step."""

    def __init__(self, weight: float = 0.3):
        self.weight = weight

    def __call__(self, env, action, obs, reward, done, info):
        improvement = env._current_test_score - env._previous_test_score
        # The delta counts for less on ``propose_fix`` steps.
        is_fix = info.get("action_type") == "propose_fix"
        scale = self.weight * 0.4 if is_fix else self.weight
        return scale * improvement
|
| 62 |
+
|
| 63 |
+
|
| 64 |
+
class LintDeltaRubric(Rubric):
    """Reward the change in lint score, at half the configured weight."""

    def __init__(self, weight: float = 0.3):
        self.weight = weight

    def __call__(self, env, action, obs, reward, done, info):
        improvement = env._current_lint_score - env._previous_lint_score
        half_weight = self.weight * 0.5
        # The delta counts for less on ``propose_fix`` steps.
        is_fix = info.get("action_type") == "propose_fix"
        scale = half_weight * 0.4 if is_fix else half_weight
        return scale * improvement
|
| 74 |
+
|
| 75 |
+
|
| 76 |
+
# --------------------------------------------------------------------------------
|
| 77 |
+
# 3. TERMINAL SUCCESS BONUS
|
| 78 |
+
# --------------------------------------------------------------------------------
|
| 79 |
+
class TerminalSuccessRubric(Rubric):
    """Bonus on ``propose_fix`` steps when the test score is high."""

    def __call__(self, env, action, obs, reward, done, info):
        if info.get("action_type") != "propose_fix":
            return 0.0
        test_score = env._current_test_score
        # Two tiers: above 0.95 earns the full bonus, above 0.85 half of it.
        if test_score > 0.95:
            return 0.4
        if test_score > 0.85:
            return 0.2
        return 0.0
|
| 89 |
+
|
| 90 |
+
|
| 91 |
+
# --------------------------------------------------------------------------------
|
| 92 |
+
# 4. EXPLORATION & DIVERSITY
|
| 93 |
+
# --------------------------------------------------------------------------------
|
| 94 |
+
class ExplorationRubric(Rubric):
    """Penalise repeating one action, reward varied actions, over the last three steps."""

    def __init__(self, penalty: float = -0.05, bonus: float = 0.021):
        self.penalty = penalty
        self.bonus = bonus

    def __call__(self, env, action, obs, reward, done, info):
        history = env._action_history
        if len(history) < 3:
            return 0.0
        distinct = len(set(history[-3:]))
        # 1 distinct action -> penalty, 3 distinct -> bonus, 2 -> neutral.
        outcomes = {1: self.penalty, 3: self.bonus}
        return outcomes.get(distinct, 0.0)
|
| 109 |
+
|
| 110 |
+
|
| 111 |
+
# --------------------------------------------------------------------------------
|
| 112 |
+
# 5. ANTI‑HACKING & CONSISTENCY
|
| 113 |
+
# --------------------------------------------------------------------------------
|
| 114 |
+
class AntiHackingRubric(Rubric):
    """Adjust the reward of ``propose_fix`` steps based on prior tool use."""

    def __call__(self, env, action, obs, reward, done, info):
        if info.get("action_type") != "propose_fix":
            return 0.0

        adjustments = (
            # Proposing a fix without ever running the tests.
            (not env._tests_run, -0.25),
            # Proposing a fix almost immediately.
            (env._step_count < 2, -0.1),
            # Both tests and linter were run before the fix.
            (env._tests_run and env._linter_run, 0.02),
        )
        total = 0.0
        for applies, amount in adjustments:
            if applies:
                total += amount
        return total
|
| 126 |
+
|
| 127 |
+
|
| 128 |
+
# --------------------------------------------------------------------------------
|
| 129 |
+
# 6. STEP PENALTY
|
| 130 |
+
# --------------------------------------------------------------------------------
|
| 131 |
+
class StepPenaltyRubric(Rubric):
    """Constant per-step contribution (negative by default)."""

    def __init__(self, penalty: float = -0.01):
        self.penalty = penalty

    def __call__(self, env, action, obs, reward, done, info):
        # Same value on every step, independent of the action taken.
        return self.penalty
|
test_runner.py
CHANGED
|
@@ -1,181 +1,208 @@
|
|
| 1 |
-
# test_runner.py – Full production version with continuous scoring, dynamic function detection, and randomised tests
|
| 2 |
-
import subprocess
|
| 3 |
-
import tempfile
|
| 4 |
-
import os
|
| 5 |
-
import json
|
| 6 |
-
import ast
|
| 7 |
-
import random
|
| 8 |
-
import sys
|
| 9 |
-
from typing import Tuple, List, Any, Optional
|
| 10 |
-
from dataclasses import dataclass
|
| 11 |
-
|
| 12 |
-
|
| 13 |
-
|
| 14 |
-
|
| 15 |
-
|
| 16 |
-
|
| 17 |
-
|
| 18 |
-
|
| 19 |
-
|
| 20 |
-
|
| 21 |
-
|
| 22 |
-
|
| 23 |
-
|
| 24 |
-
|
| 25 |
-
|
| 26 |
-
|
| 27 |
-
|
| 28 |
-
|
| 29 |
-
|
| 30 |
-
|
| 31 |
-
|
| 32 |
-
|
| 33 |
-
|
| 34 |
-
|
| 35 |
-
|
| 36 |
-
|
| 37 |
-
|
| 38 |
-
|
| 39 |
-
|
| 40 |
-
|
| 41 |
-
|
| 42 |
-
|
| 43 |
-
|
| 44 |
-
|
| 45 |
-
|
| 46 |
-
|
| 47 |
-
|
| 48 |
-
|
| 49 |
-
|
| 50 |
-
|
| 51 |
-
|
| 52 |
-
|
| 53 |
-
|
| 54 |
-
|
| 55 |
-
|
| 56 |
-
|
| 57 |
-
|
| 58 |
-
|
| 59 |
-
|
| 60 |
-
|
| 61 |
-
|
| 62 |
-
|
| 63 |
-
|
| 64 |
-
|
| 65 |
-
|
| 66 |
-
|
| 67 |
-
|
| 68 |
-
|
| 69 |
-
|
| 70 |
-
|
| 71 |
-
|
| 72 |
-
|
| 73 |
-
|
| 74 |
-
|
| 75 |
-
|
| 76 |
-
|
| 77 |
-
|
| 78 |
-
|
| 79 |
-
|
| 80 |
-
|
| 81 |
-
|
| 82 |
-
|
| 83 |
-
|
| 84 |
-
|
| 85 |
-
|
| 86 |
-
|
| 87 |
-
|
| 88 |
-
|
| 89 |
-
|
| 90 |
-
|
| 91 |
-
|
| 92 |
-
|
| 93 |
-
|
| 94 |
-
|
| 95 |
-
|
| 96 |
-
|
| 97 |
-
|
| 98 |
-
|
| 99 |
-
|
| 100 |
-
|
| 101 |
-
|
| 102 |
-
|
| 103 |
-
|
| 104 |
-
|
| 105 |
-
|
| 106 |
-
|
| 107 |
-
|
| 108 |
-
|
| 109 |
-
|
| 110 |
-
|
| 111 |
-
|
| 112 |
-
|
| 113 |
-
|
| 114 |
-
|
| 115 |
-
|
| 116 |
-
|
| 117 |
-
|
| 118 |
-
|
| 119 |
-
|
| 120 |
-
|
| 121 |
-
|
| 122 |
-
|
| 123 |
-
|
| 124 |
-
|
| 125 |
-
|
| 126 |
-
|
| 127 |
-
|
| 128 |
-
|
| 129 |
-
|
| 130 |
-
|
| 131 |
-
|
| 132 |
-
|
| 133 |
-
|
| 134 |
-
|
| 135 |
-
|
| 136 |
-
|
| 137 |
-
|
| 138 |
-
|
| 139 |
-
|
| 140 |
-
|
| 141 |
-
|
| 142 |
-
|
| 143 |
-
|
| 144 |
-
|
| 145 |
-
|
| 146 |
-
|
| 147 |
-
|
| 148 |
-
|
| 149 |
-
|
| 150 |
-
|
| 151 |
-
|
| 152 |
-
|
| 153 |
-
|
| 154 |
-
|
| 155 |
-
|
| 156 |
-
|
| 157 |
-
|
| 158 |
-
|
| 159 |
-
|
| 160 |
-
|
| 161 |
-
|
| 162 |
-
|
| 163 |
-
|
| 164 |
-
|
| 165 |
-
elif
|
| 166 |
-
|
| 167 |
-
|
| 168 |
-
|
| 169 |
-
|
| 170 |
-
|
| 171 |
-
|
| 172 |
-
|
| 173 |
-
|
| 174 |
-
|
| 175 |
-
|
| 176 |
-
|
| 177 |
-
|
| 178 |
-
|
| 179 |
-
|
| 180 |
-
|
| 181 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# test_runner.py – Full production version with continuous scoring, dynamic function detection, and randomised tests
|
| 2 |
+
import subprocess
|
| 3 |
+
import tempfile
|
| 4 |
+
import os
|
| 5 |
+
import json
|
| 6 |
+
import ast
|
| 7 |
+
import random
|
| 8 |
+
import sys
|
| 9 |
+
from typing import Tuple, List, Any, Optional
|
| 10 |
+
from dataclasses import dataclass
|
| 11 |
+
|
| 12 |
+
# Bridge fine-grained RedTeam ids to canonical TestRunner families.
|
| 13 |
+
# This keeps evaluation stable even when bug generators use richer labels.
|
| 14 |
+
# Maps fine-grained bug ids to the canonical family the test harness has cases
# for.  Ids that are not listed here are used as-is by the lookup.
BUG_ID_CANONICAL_MAP = {
    # Easy bugs, exercised with the ``null_check`` cases.
    **dict.fromkeys(("simple_typo", "default_value", "empty_return"), "null_check"),
    # Medium arithmetic / control-flow aliases.
    "loop_skip": "off_by_one",
    "sign_error": "wrong_operator",
    # Hard numeric-safety bugs, exercised with the ``division_by_zero`` cases.
    **dict.fromkeys(
        (
            "division_by_zero_empty",
            "division_by_zero_zero",
            "float_precision",
            "abs_usage",
            "round_error",
        ),
        "division_by_zero",
    ),
}
|
| 31 |
+
|
| 32 |
+
@dataclass
class TestRunner:
    """Score a candidate fix by running fixed and fuzzed cases in a subprocess.

    The candidate code is written to a temporary script together with a small
    harness; the harness prints a JSON object ``{"passed": int, "total": int}``
    and the score is ``passed / total``.
    """

    # Bug id as produced by the bug generator; normalised via BUG_ID_CANONICAL_MAP.
    bug_id: str
    # Wall-clock limit for the test subprocess, in seconds.
    timeout_sec: int = 5
    # Address-space cap for the test subprocess, in MiB (POSIX only, best effort).
    max_memory_mb: int = 256
    fuzz_rounds: int = 3  # number of random test cases per bug

    def run_tests(self, fix_code: str) -> Tuple[float, str]:
        """Run the test harness against *fix_code*.

        Returns:
            ``(score, output_message)`` where ``score`` is the proportion of
            passed tests (0.0-1.0).  The score is 0.0 when no function is
            defined, when there are no test cases, on timeout, or on any
            unexpected error.
        """
        # 1. Detect the function defined in the agent's code (dynamic)
        func_name = self._get_defined_function_name(fix_code)
        if not func_name:
            return 0.0, "No function definition found in agent code"

        # 2. Normalize bug id so broader RedTeam ids still hit meaningful tests.
        canonical_bug_id = self._canonical_bug_id()

        # 3. Generate the test script (includes fixed + fuzzed test cases)
        test_script = self._generate_test_script(fix_code, func_name, canonical_bug_id)
        with tempfile.NamedTemporaryFile(mode='w', suffix='.py', delete=False, encoding='utf-8') as f:
            f.write(test_script)
            tmp_path = f.name

        try:
            result = subprocess.run(
                [sys.executable, tmp_path],
                capture_output=True,
                text=True,
                timeout=self.timeout_sec,
                encoding='utf-8',
                # The memory cap must be installed in the child only; calling
                # setrlimit here would cap the calling process instead.
                preexec_fn=self._make_memory_limiter(),
            )
            stdout = result.stdout.strip()
            # The harness prints its JSON last; anything the agent code prints
            # while the script loads comes before it.
            last_line = stdout.splitlines()[-1] if stdout else ""
            try:
                data = json.loads(last_line)
            except json.JSONDecodeError:
                data = None
            if isinstance(data, dict):
                passed = data.get("passed", 0)
                total = data.get("total", 1)
                score = passed / total if total > 0 else 0.0
                return score, stdout
            # Fallback: look for "True" (legacy)
            # NOTE(review): candidate code that merely prints "True" reaches
            # this branch with a full score - confirm the fallback is still needed.
            if "True" in result.stdout:
                return 1.0, result.stdout
            return 0.0, result.stdout
        except subprocess.TimeoutExpired:
            return 0.0, "Test execution timed out"
        except Exception as e:
            return 0.0, f"Unexpected error: {str(e)}"
        finally:
            try:
                os.unlink(tmp_path)
            except OSError:
                # Best-effort cleanup; a leftover temp file is not fatal.
                pass

    def _make_memory_limiter(self):
        """Return a ``preexec_fn`` capping the child's address space, or ``None``.

        ``None`` is returned when the ``resource`` module is unavailable, so
        the subprocess then runs without a memory cap.
        """
        try:
            import resource
        except ImportError:
            return None
        limit_bytes = self.max_memory_mb * 1024 * 1024

        def _limit_child_memory():
            try:
                resource.setrlimit(resource.RLIMIT_AS, (limit_bytes, limit_bytes))
            except (ValueError, OSError):
                # Best effort: e.g. the existing hard limit is already lower.
                pass

        return _limit_child_memory

    def _get_defined_function_name(self, code: str) -> Optional[str]:
        """Extract the target function name from the code.

        Looks for a function named 'fix' first; otherwise returns the first
        function found by ``ast.walk``.  Returns ``None`` when the code has no
        function definition or does not parse.
        """
        try:
            tree = ast.parse(code)
        except SyntaxError:
            return None
        first_func = None
        for node in ast.walk(tree):
            if isinstance(node, ast.FunctionDef):
                if node.name == "fix":
                    return "fix"
                if first_func is None:
                    first_func = node.name
        return first_func  # fallback if no 'fix' function exists

    def _canonical_bug_id(self) -> str:
        """Return canonical bug family used by this test harness."""
        return BUG_ID_CANONICAL_MAP.get(self.bug_id, self.bug_id)

    def _generate_test_script(self, fix_code: str, func_name: str, canonical_bug_id: str) -> str:
        """Generate a test script that runs fixed + fuzzed test cases and outputs JSON."""
        test_cases = self._get_test_cases(canonical_bug_id, func_name)
        fuzzed_cases = self._generate_fuzzed_cases(canonical_bug_id, func_name)
        all_cases = test_cases + fuzzed_cases

        # JSON text is not Python source (``null``/``true``/``false`` are
        # undefined names), so the payload is embedded as a string literal and
        # decoded by the generated script.
        payload = json.dumps(all_cases)

        lines = [
            fix_code,
            "",
            "import json",
            "",
            "def run_tests():",
            f"    test_cases = json.loads({payload!r})",
            "    passed = 0",
            "    total = len(test_cases)",
            "    for args, expected in test_cases:",
            "        try:",
            f"            result = {func_name}(*args) if isinstance(args, list) else {func_name}(args)",
            "            if result == expected:",
            "                passed += 1",
            "        except Exception:",
            "            pass",
            "    return {'passed': passed, 'total': total}",
            "",
            "if __name__ == '__main__':",
            "    result = run_tests()",
            "    print(json.dumps(result))",
        ]
        return "\n".join(lines)

    def _get_test_cases(self, canonical_bug_id: str, func_name: str) -> List[Tuple[List[Any], Any]]:
        """Return the fixed ``(arguments, expected_output)`` cases for a bug family.

        Families without fixed cases (e.g. the template bugs) yield an empty list.
        """
        if canonical_bug_id == "null_check":
            return [
                ([{"users": {"alice": "Alice"}, "id": "bob"}], None),  # missing key should not crash
                ([{"users": {"alice": "Alice"}, "id": "alice"}], "Alice"),
            ]
        elif canonical_bug_id == "off_by_one":
            return [
                ([[1, 2, 3, 4]], 4),
                ([[]], 0),
            ]
        elif canonical_bug_id == "division_by_zero":
            return [
                ([[]], 0),
                ([[1, 2, 3]], 2.0),
            ]
        elif canonical_bug_id == "wrong_operator":
            return [
                ([5, 3], 8),
                ([-1, 1], 0),
            ]
        else:
            # For missing_lock, deadlock_order, etc., return empty list (will be handled gracefully)
            return []

    def _generate_fuzzed_cases(self, canonical_bug_id: str, func_name: str) -> List[Tuple[List[Any], Any]]:
        """Generate ``fuzz_rounds`` random cases to prevent memorisation.

        Only bug families where meaningful fuzzing is possible produce cases;
        all others yield an empty list.
        """
        cases = []
        if canonical_bug_id == "null_check":
            # Random users dictionary and random ids
            for _ in range(self.fuzz_rounds):
                users = {f"user_{i}": f"name_{i}" for i in range(random.randint(1, 5))}
                # Pick existing or missing key
                if random.random() > 0.5:
                    key = random.choice(list(users.keys()))
                    expected = users[key]
                else:
                    key = "missing_" + str(random.randint(100, 999))
                    expected = None
                cases.append(([{"users": users, "id": key}], expected))
        elif canonical_bug_id == "off_by_one":
            for _ in range(self.fuzz_rounds):
                length = random.randint(0, 10)
                arr = list(range(length))
                cases.append(([arr], length))
        elif canonical_bug_id == "division_by_zero":
            for _ in range(self.fuzz_rounds):
                length = random.randint(0, 10)
                data = [random.randint(-100, 100) for _ in range(length)]
                expected = sum(data) / length if length else 0
                cases.append(([data], expected))
        elif canonical_bug_id == "wrong_operator":
            for _ in range(self.fuzz_rounds):
                a = random.randint(-100, 100)
                b = random.randint(-100, 100)
                cases.append(([a, b], a + b))
        return cases
|
training.py
CHANGED
|
@@ -1,708 +1,792 @@
|
|
| 1 |
-
# training.py
|
| 2 |
-
import json
|
| 3 |
-
import
|
| 4 |
-
import torch
|
| 5 |
-
|
| 6 |
-
from
|
| 7 |
-
from
|
| 8 |
-
import
|
| 9 |
-
import
|
| 10 |
-
import
|
| 11 |
-
|
| 12 |
-
|
| 13 |
-
|
| 14 |
-
from
|
| 15 |
-
from
|
| 16 |
-
|
| 17 |
-
|
| 18 |
-
|
| 19 |
-
|
| 20 |
-
|
| 21 |
-
|
| 22 |
-
|
| 23 |
-
|
| 24 |
-
|
| 25 |
-
|
| 26 |
-
|
| 27 |
-
|
| 28 |
-
|
| 29 |
-
|
| 30 |
-
|
| 31 |
-
|
| 32 |
-
|
| 33 |
-
|
| 34 |
-
|
| 35 |
-
|
| 36 |
-
|
| 37 |
-
|
| 38 |
-
|
| 39 |
-
|
| 40 |
-
|
| 41 |
-
|
| 42 |
-
|
| 43 |
-
|
| 44 |
-
|
| 45 |
-
|
| 46 |
-
|
| 47 |
-
|
| 48 |
-
|
| 49 |
-
|
| 50 |
-
|
| 51 |
-
|
| 52 |
-
|
| 53 |
-
|
| 54 |
-
|
| 55 |
-
|
| 56 |
-
|
| 57 |
-
|
| 58 |
-
|
| 59 |
-
|
| 60 |
-
|
| 61 |
-
|
| 62 |
-
|
| 63 |
-
|
| 64 |
-
|
| 65 |
-
|
| 66 |
-
|
| 67 |
-
|
| 68 |
-
|
| 69 |
-
|
| 70 |
-
|
| 71 |
-
|
| 72 |
-
|
| 73 |
-
|
| 74 |
-
|
| 75 |
-
|
| 76 |
-
return
|
| 77 |
-
|
| 78 |
-
|
| 79 |
-
|
| 80 |
-
|
| 81 |
-
|
| 82 |
-
return
|
| 83 |
-
elif action.action_type == "
|
| 84 |
-
return
|
| 85 |
-
elif action.action_type == "
|
| 86 |
-
return
|
| 87 |
-
elif action.action_type == "
|
| 88 |
-
return
|
| 89 |
-
elif action.action_type == "
|
| 90 |
-
return
|
| 91 |
-
|
| 92 |
-
return
|
| 93 |
-
|
| 94 |
-
|
| 95 |
-
|
| 96 |
-
|
| 97 |
-
|
| 98 |
-
|
| 99 |
-
|
| 100 |
-
|
| 101 |
-
|
| 102 |
-
|
| 103 |
-
|
| 104 |
-
model = FastLanguageModel.
|
| 105 |
-
|
| 106 |
-
|
| 107 |
-
|
| 108 |
-
|
| 109 |
-
|
| 110 |
-
|
| 111 |
-
|
| 112 |
-
|
| 113 |
-
|
| 114 |
-
|
| 115 |
-
|
| 116 |
-
|
| 117 |
-
|
| 118 |
-
|
| 119 |
-
|
| 120 |
-
#
|
| 121 |
-
|
| 122 |
-
|
| 123 |
-
|
| 124 |
-
|
| 125 |
-
|
| 126 |
-
|
| 127 |
-
|
| 128 |
-
|
| 129 |
-
|
| 130 |
-
|
| 131 |
-
|
| 132 |
-
|
| 133 |
-
|
| 134 |
-
|
| 135 |
-
|
| 136 |
-
|
| 137 |
-
|
| 138 |
-
|
| 139 |
-
|
| 140 |
-
|
| 141 |
-
|
| 142 |
-
|
| 143 |
-
|
| 144 |
-
|
| 145 |
-
|
| 146 |
-
|
| 147 |
-
|
| 148 |
-
|
| 149 |
-
|
| 150 |
-
|
| 151 |
-
|
| 152 |
-
|
| 153 |
-
|
| 154 |
-
|
| 155 |
-
|
| 156 |
-
|
| 157 |
-
|
| 158 |
-
|
| 159 |
-
|
| 160 |
-
|
| 161 |
-
|
| 162 |
-
|
| 163 |
-
|
| 164 |
-
|
| 165 |
-
|
| 166 |
-
'{"action_type": "
|
| 167 |
-
|
| 168 |
-
|
| 169 |
-
|
| 170 |
-
|
| 171 |
-
|
| 172 |
-
|
| 173 |
-
|
| 174 |
-
|
| 175 |
-
|
| 176 |
-
|
| 177 |
-
|
| 178 |
-
|
| 179 |
-
|
| 180 |
-
|
| 181 |
-
|
| 182 |
-
|
| 183 |
-
|
| 184 |
-
|
| 185 |
-
|
| 186 |
-
|
| 187 |
-
|
| 188 |
-
|
| 189 |
-
|
| 190 |
-
|
| 191 |
-
|
| 192 |
-
|
| 193 |
-
|
| 194 |
-
|
| 195 |
-
|
| 196 |
-
|
| 197 |
-
|
| 198 |
-
|
| 199 |
-
|
| 200 |
-
|
| 201 |
-
|
| 202 |
-
|
| 203 |
-
|
| 204 |
-
|
| 205 |
-
|
| 206 |
-
|
| 207 |
-
|
| 208 |
-
|
| 209 |
-
|
| 210 |
-
|
| 211 |
-
|
| 212 |
-
|
| 213 |
-
|
| 214 |
-
|
| 215 |
-
|
| 216 |
-
|
| 217 |
-
|
| 218 |
-
|
| 219 |
-
|
| 220 |
-
|
| 221 |
-
|
| 222 |
-
|
| 223 |
-
|
| 224 |
-
|
| 225 |
-
|
| 226 |
-
|
| 227 |
-
|
| 228 |
-
|
| 229 |
-
|
| 230 |
-
|
| 231 |
-
|
| 232 |
-
|
| 233 |
-
|
| 234 |
-
|
| 235 |
-
|
| 236 |
-
|
| 237 |
-
|
| 238 |
-
|
| 239 |
-
|
| 240 |
-
|
| 241 |
-
|
| 242 |
-
|
| 243 |
-
|
| 244 |
-
|
| 245 |
-
|
| 246 |
-
|
| 247 |
-
|
| 248 |
-
|
| 249 |
-
|
| 250 |
-
|
| 251 |
-
|
| 252 |
-
|
| 253 |
-
|
| 254 |
-
|
| 255 |
-
|
| 256 |
-
|
| 257 |
-
|
| 258 |
-
|
| 259 |
-
|
| 260 |
-
|
| 261 |
-
|
| 262 |
-
|
| 263 |
-
|
| 264 |
-
|
| 265 |
-
|
| 266 |
-
|
| 267 |
-
|
| 268 |
-
|
| 269 |
-
|
| 270 |
-
|
| 271 |
-
|
| 272 |
-
|
| 273 |
-
|
| 274 |
-
|
| 275 |
-
|
| 276 |
-
|
| 277 |
-
|
| 278 |
-
|
| 279 |
-
|
| 280 |
-
|
| 281 |
-
|
| 282 |
-
|
| 283 |
-
|
| 284 |
-
|
| 285 |
-
|
| 286 |
-
|
| 287 |
-
|
| 288 |
-
|
| 289 |
-
|
| 290 |
-
|
| 291 |
-
|
| 292 |
-
|
| 293 |
-
|
| 294 |
-
|
| 295 |
-
|
| 296 |
-
|
| 297 |
-
|
| 298 |
-
|
| 299 |
-
|
| 300 |
-
|
| 301 |
-
|
| 302 |
-
|
| 303 |
-
|
| 304 |
-
|
| 305 |
-
|
| 306 |
-
|
| 307 |
-
|
| 308 |
-
|
| 309 |
-
|
| 310 |
-
|
| 311 |
-
|
| 312 |
-
|
| 313 |
-
|
| 314 |
-
|
| 315 |
-
|
| 316 |
-
|
| 317 |
-
|
| 318 |
-
|
| 319 |
-
|
| 320 |
-
|
| 321 |
-
|
| 322 |
-
|
| 323 |
-
|
| 324 |
-
|
| 325 |
-
|
| 326 |
-
|
| 327 |
-
|
| 328 |
-
|
| 329 |
-
|
| 330 |
-
|
| 331 |
-
|
| 332 |
-
|
| 333 |
-
|
| 334 |
-
|
| 335 |
-
|
| 336 |
-
|
| 337 |
-
|
| 338 |
-
|
| 339 |
-
|
| 340 |
-
|
| 341 |
-
|
| 342 |
-
|
| 343 |
-
|
| 344 |
-
|
| 345 |
-
|
| 346 |
-
|
| 347 |
-
|
| 348 |
-
|
| 349 |
-
|
| 350 |
-
|
| 351 |
-
|
| 352 |
-
|
| 353 |
-
|
| 354 |
-
|
| 355 |
-
|
| 356 |
-
|
| 357 |
-
|
| 358 |
-
|
| 359 |
-
|
| 360 |
-
|
| 361 |
-
|
| 362 |
-
|
| 363 |
-
|
| 364 |
-
|
| 365 |
-
|
| 366 |
-
|
| 367 |
-
|
| 368 |
-
|
| 369 |
-
|
| 370 |
-
|
| 371 |
-
|
| 372 |
-
|
| 373 |
-
|
| 374 |
-
|
| 375 |
-
|
| 376 |
-
|
| 377 |
-
|
| 378 |
-
|
| 379 |
-
|
| 380 |
-
|
| 381 |
-
|
| 382 |
-
|
| 383 |
-
|
| 384 |
-
|
| 385 |
-
|
| 386 |
-
|
| 387 |
-
|
| 388 |
-
|
| 389 |
-
|
| 390 |
-
|
| 391 |
-
|
| 392 |
-
|
| 393 |
-
|
| 394 |
-
|
| 395 |
-
|
| 396 |
-
|
| 397 |
-
|
| 398 |
-
|
| 399 |
-
|
| 400 |
-
|
| 401 |
-
|
| 402 |
-
|
| 403 |
-
|
| 404 |
-
|
| 405 |
-
|
| 406 |
-
|
| 407 |
-
|
| 408 |
-
|
| 409 |
-
|
| 410 |
-
|
| 411 |
-
|
| 412 |
-
|
| 413 |
-
|
| 414 |
-
|
| 415 |
-
|
| 416 |
-
|
| 417 |
-
|
| 418 |
-
|
| 419 |
-
|
| 420 |
-
|
| 421 |
-
|
| 422 |
-
|
| 423 |
-
|
| 424 |
-
|
| 425 |
-
|
| 426 |
-
|
| 427 |
-
|
| 428 |
-
|
| 429 |
-
|
| 430 |
-
|
| 431 |
-
|
| 432 |
-
#
|
| 433 |
-
#
|
| 434 |
-
|
| 435 |
-
|
| 436 |
-
|
| 437 |
-
|
| 438 |
-
|
| 439 |
-
|
| 440 |
-
|
| 441 |
-
|
| 442 |
-
|
| 443 |
-
|
| 444 |
-
|
| 445 |
-
|
| 446 |
-
|
| 447 |
-
|
| 448 |
-
|
| 449 |
-
|
| 450 |
-
|
| 451 |
-
|
| 452 |
-
|
| 453 |
-
|
| 454 |
-
|
| 455 |
-
|
| 456 |
-
|
| 457 |
-
|
| 458 |
-
|
| 459 |
-
|
| 460 |
-
|
| 461 |
-
|
| 462 |
-
|
| 463 |
-
|
| 464 |
-
|
| 465 |
-
|
| 466 |
-
|
| 467 |
-
|
| 468 |
-
|
| 469 |
-
|
| 470 |
-
|
| 471 |
-
|
| 472 |
-
|
| 473 |
-
|
| 474 |
-
|
| 475 |
-
|
| 476 |
-
|
| 477 |
-
|
| 478 |
-
|
| 479 |
-
|
| 480 |
-
|
| 481 |
-
|
| 482 |
-
|
| 483 |
-
|
| 484 |
-
|
| 485 |
-
|
| 486 |
-
|
| 487 |
-
|
| 488 |
-
|
| 489 |
-
|
| 490 |
-
|
| 491 |
-
|
| 492 |
-
|
| 493 |
-
|
| 494 |
-
|
| 495 |
-
|
| 496 |
-
|
| 497 |
-
|
| 498 |
-
|
| 499 |
-
|
| 500 |
-
|
| 501 |
-
|
| 502 |
-
|
| 503 |
-
|
| 504 |
-
|
| 505 |
-
|
| 506 |
-
|
| 507 |
-
|
| 508 |
-
|
| 509 |
-
|
| 510 |
-
|
| 511 |
-
|
| 512 |
-
|
| 513 |
-
|
| 514 |
-
|
| 515 |
-
|
| 516 |
-
|
| 517 |
-
|
| 518 |
-
|
| 519 |
-
|
| 520 |
-
|
| 521 |
-
|
| 522 |
-
|
| 523 |
-
|
| 524 |
-
|
| 525 |
-
|
| 526 |
-
|
| 527 |
-
|
| 528 |
-
|
| 529 |
-
|
| 530 |
-
|
| 531 |
-
|
| 532 |
-
|
| 533 |
-
|
| 534 |
-
|
| 535 |
-
|
| 536 |
-
|
| 537 |
-
|
| 538 |
-
|
| 539 |
-
|
| 540 |
-
|
| 541 |
-
|
| 542 |
-
|
| 543 |
-
|
| 544 |
-
|
| 545 |
-
|
| 546 |
-
|
| 547 |
-
|
| 548 |
-
|
| 549 |
-
|
| 550 |
-
|
| 551 |
-
|
| 552 |
-
|
| 553 |
-
|
| 554 |
-
|
| 555 |
-
|
| 556 |
-
|
| 557 |
-
|
| 558 |
-
|
| 559 |
-
|
| 560 |
-
|
| 561 |
-
|
| 562 |
-
|
| 563 |
-
|
| 564 |
-
|
| 565 |
-
|
| 566 |
-
|
| 567 |
-
|
| 568 |
-
|
| 569 |
-
|
| 570 |
-
|
| 571 |
-
|
| 572 |
-
|
| 573 |
-
|
| 574 |
-
|
| 575 |
-
|
| 576 |
-
|
| 577 |
-
|
| 578 |
-
|
| 579 |
-
|
| 580 |
-
|
| 581 |
-
|
| 582 |
-
|
| 583 |
-
|
| 584 |
-
|
| 585 |
-
|
| 586 |
-
|
| 587 |
-
|
| 588 |
-
|
| 589 |
-
|
| 590 |
-
|
| 591 |
-
|
| 592 |
-
|
| 593 |
-
|
| 594 |
-
)
|
| 595 |
-
|
| 596 |
-
|
| 597 |
-
|
| 598 |
-
|
| 599 |
-
|
| 600 |
-
|
| 601 |
-
|
| 602 |
-
|
| 603 |
-
|
| 604 |
-
|
| 605 |
-
if
|
| 606 |
-
|
| 607 |
-
|
| 608 |
-
|
| 609 |
-
|
| 610 |
-
|
| 611 |
-
|
| 612 |
-
|
| 613 |
-
|
| 614 |
-
|
| 615 |
-
|
| 616 |
-
|
| 617 |
-
|
| 618 |
-
|
| 619 |
-
|
| 620 |
-
|
| 621 |
-
|
| 622 |
-
|
| 623 |
-
|
| 624 |
-
|
| 625 |
-
|
| 626 |
-
|
| 627 |
-
|
| 628 |
-
|
| 629 |
-
|
| 630 |
-
|
| 631 |
-
|
| 632 |
-
|
| 633 |
-
|
| 634 |
-
|
| 635 |
-
|
| 636 |
-
|
| 637 |
-
|
| 638 |
-
|
| 639 |
-
|
| 640 |
-
|
| 641 |
-
|
| 642 |
-
|
| 643 |
-
|
| 644 |
-
|
| 645 |
-
|
| 646 |
-
|
| 647 |
-
|
| 648 |
-
|
| 649 |
-
|
| 650 |
-
|
| 651 |
-
|
| 652 |
-
|
| 653 |
-
|
| 654 |
-
|
| 655 |
-
|
| 656 |
-
|
| 657 |
-
|
| 658 |
-
|
| 659 |
-
|
| 660 |
-
|
| 661 |
-
|
| 662 |
-
|
| 663 |
-
|
| 664 |
-
|
| 665 |
-
|
| 666 |
-
|
| 667 |
-
|
| 668 |
-
|
| 669 |
-
|
| 670 |
-
|
| 671 |
-
|
| 672 |
-
|
| 673 |
-
|
| 674 |
-
|
| 675 |
-
|
| 676 |
-
|
| 677 |
-
|
| 678 |
-
|
| 679 |
-
|
| 680 |
-
|
| 681 |
-
|
| 682 |
-
|
| 683 |
-
|
| 684 |
-
|
| 685 |
-
|
| 686 |
-
|
| 687 |
-
|
| 688 |
-
|
| 689 |
-
|
| 690 |
-
|
| 691 |
-
|
| 692 |
-
|
| 693 |
-
|
| 694 |
-
|
| 695 |
-
|
| 696 |
-
|
| 697 |
-
|
| 698 |
-
|
| 699 |
-
|
| 700 |
-
|
| 701 |
-
|
| 702 |
-
|
| 703 |
-
|
| 704 |
-
|
| 705 |
-
|
| 706 |
-
|
| 707 |
-
|
| 708 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# training.py
|
| 2 |
+
import json
|
| 3 |
+
import os
|
| 4 |
+
import torch
|
| 5 |
+
import torch.nn.functional as F
|
| 6 |
+
from torch.optim import AdamW
|
| 7 |
+
from dataclasses import dataclass
|
| 8 |
+
from typing import List, Dict, Tuple, Optional
|
| 9 |
+
import numpy as np
|
| 10 |
+
import re
|
| 11 |
+
import random
|
| 12 |
+
import matplotlib.pyplot as plt
|
| 13 |
+
|
| 14 |
+
from unsloth import FastLanguageModel
|
| 15 |
+
from transformers import TrainingArguments
|
| 16 |
+
from trl import SFTTrainer
|
| 17 |
+
from datasets import Dataset
|
| 18 |
+
|
| 19 |
+
# Import your environment and actions (unchanged)
|
| 20 |
+
from environment import CodeReviewEnv
|
| 21 |
+
from redteam import BUG_DB
|
| 22 |
+
from models import (
|
| 23 |
+
RunTests, RunLinter, Inspect,
|
| 24 |
+
ProposeFix, WriteComment, AskQuestion,
|
| 25 |
+
Done, Skip , QueryDocs
|
| 26 |
+
)
|
| 27 |
+
|
| 28 |
+
# ======================================================================
|
| 29 |
+
# 1. ACTION PARSING (improved with fallback)
|
| 30 |
+
# ======================================================================
|
| 31 |
+
@dataclass
class AgentAction:
    """Action decoded from raw model output: an action name plus an optional payload."""

    # Action name, e.g. "run_tests", "fix", "done"; parse_action() uses
    # "invalid" when nothing recognisable could be extracted.
    action_type: str
    # Free-text payload (fix code, comment text, question, docs query);
    # None when the action carries no content.
    content: Optional[str] = None
|
| 35 |
+
|
| 36 |
+
def _action_from_json(text: str) -> Optional[AgentAction]:
    """Decode *text* as a JSON action object; return None when it is not one."""
    try:
        data = json.loads(text)
    except (ValueError, TypeError):
        # Not valid JSON (json.JSONDecodeError is a ValueError) or not a
        # str/bytes payload: let the caller try the next strategy.
        return None
    if not isinstance(data, dict):
        return None
    action_type = data.get("action_type", "")
    if not isinstance(action_type, str):
        return None
    return AgentAction(action_type=action_type.lower(), content=data.get("content"))


def parse_action(output: str) -> AgentAction:
    """Turn raw model output into an AgentAction.

    Strategies are tried in order and the first that succeeds wins:

    1. The whole output is a JSON object.
    2. A JSON object inside a Markdown code fence.
    3. A bare ``"action_type": "<name>"`` field found by regex (no content).
    4. Keyword detection ("test", "lint", "inspect", "doc").

    Args:
        output: Text generated by the model.

    Returns:
        The decoded action. When nothing matches, an action of type
        "invalid" whose content is the original output.
    """
    parsed = _action_from_json(output)
    if parsed is not None:
        return parsed

    fenced = re.search(r'```(?:json)?\s*(\{.*?\})\s*```', output, re.DOTALL)
    if fenced:
        parsed = _action_from_json(fenced.group(1))
        if parsed is not None:
            return parsed

    field = re.search(r'"action_type"\s*:\s*"(\w+)"', output)
    if field:
        return AgentAction(action_type=field.group(1).lower())

    # Keyword detection as a last resort; order decides ties.
    lowered = output.lower()
    if "test" in lowered:
        return AgentAction("run_tests")
    if "lint" in lowered:
        return AgentAction("run_linter")
    if "inspect" in lowered:
        return AgentAction("inspect")
    if "doc" in lowered:  # also matches "documentation"
        # Bridge natural-language mentions to the docs retrieval action.
        return AgentAction("query_docs", "bug fix guidance")

    return AgentAction("invalid", output)
|
| 79 |
+
|
| 80 |
+
def map_to_env(action: AgentAction):
    """Translate a parsed AgentAction into the matching environment action object.

    Action types without a mapping (including "invalid" and "skip") become
    ``Skip()``. A missing content payload is passed on as an empty string.
    """
    kind = action.action_type
    payload = action.content or ""

    # Actions that take no payload.
    if kind == "run_tests":
        return RunTests()
    if kind == "run_linter":
        return RunLinter()
    if kind == "inspect":
        return Inspect()

    # Actions that forward the payload under their own keyword.
    if kind == "fix":
        return ProposeFix(fix_code=payload)
    if kind == "comment":
        return WriteComment(comment_text=payload)
    if kind == "question":
        return AskQuestion(question=payload)
    if kind == "query_docs":
        return QueryDocs(query_topic=payload)

    if kind == "done":
        return Done()
    return Skip()
|
| 99 |
+
|
| 100 |
+
# ======================================================================
|
| 101 |
+
# 2. MODEL SETUP (stabilised LoRA)
|
| 102 |
+
# ======================================================================
|
| 103 |
+
def load_model():
    """Load the 4-bit Gemma-2 2B instruct checkpoint and wrap it with LoRA adapters.

    Returns:
        A ``(model, tokenizer)`` pair. The tokenizer is given a Gemma-style
        chat template if it arrives without one.
    """
    base_model, tokenizer = FastLanguageModel.from_pretrained(
        model_name="unsloth/gemma-2-2b-it-bnb-4bit",
        max_seq_length=2048,
        load_in_4bit=True,
    )

    # Attention and MLP projections that receive LoRA adapters.
    lora_targets = [
        "q_proj", "k_proj", "v_proj", "o_proj",
        "gate_proj", "up_proj", "down_proj",
    ]
    # Author's notes: rank was lowered from 64 to 16 and dropout set to 0
    # for training stability; alpha is paired with r=16.
    peft_model = FastLanguageModel.get_peft_model(
        base_model,
        r=16,
        target_modules=lora_targets,
        lora_alpha=32,
        lora_dropout=0.0,
    )

    # Only install a fallback template when the tokenizer ships without one.
    if tokenizer.chat_template is None:
        tokenizer.chat_template = (
            "{{ bos_token }}{% for message in messages %}"
            "{% if message['role'] == 'user' %}"
            "<start_of_turn>user\n{{ message['content'] }}<end_of_turn>\n<start_of_turn>model\n"
            "{% elif message['role'] == 'assistant' %}"
            "{{ message['content'] }}<end_of_turn>\n"
            "{% endif %}{% endfor %}"
        )
    return peft_model, tokenizer
|
| 124 |
+
|
| 125 |
+
# ======================================================================
|
| 126 |
+
# 3. MODEL SANITY CHECK (new – ensures model can generate text)
|
| 127 |
+
# ======================================================================
|
| 128 |
+
def test_model_sanity(model, tokenizer) -> bool:
    """Generate a short reply to a fixed prompt to confirm the model emits text.

    Args:
        model: Causal LM exposing ``generate``.
        tokenizer: Tokenizer with a chat template matching the model.

    Returns:
        True when the decoded reply is non-empty, False otherwise.
    """
    print("\n" + "="*60)
    print("SANITY CHECK: Testing base model generation")
    print("="*60)
    test_prompt = "Hello, how are you?"
    messages = [{"role": "user", "content": test_prompt}]
    formatted = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    inputs = tokenizer(formatted, return_tensors="pt").to("cuda")

    # Compare against None explicitly: a pad token id of 0 is valid and
    # must not be replaced by the EOS id.
    pad_id = tokenizer.pad_token_id
    if pad_id is None:
        pad_id = tokenizer.eos_token_id

    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=30,
            do_sample=True,
            temperature=0.7,
            min_new_tokens=1,
            eos_token_id=tokenizer.eos_token_id,
            pad_token_id=pad_id,
        )
    # Drop the prompt tokens so only the newly generated part is decoded.
    generated_ids = outputs[0][inputs['input_ids'].shape[1]:]
    response = tokenizer.decode(generated_ids, skip_special_tokens=True).strip()
    print(f"Prompt: {test_prompt}")
    print(f"Response: {repr(response)}")
    if not response:
        print("❌ Model produces empty output – cannot train.")
        return False
    print("✓ Model sanity check PASSED\n")
    return True
|
| 155 |
+
|
| 156 |
+
# ======================================================================
|
| 157 |
+
# 4. SUPERVISED WARM-UP (teaches JSON output)
|
| 158 |
+
# ======================================================================
|
| 159 |
+
def supervised_warmup(model, tokenizer, n_examples=500, epochs=2):
    """Run a short SFT pass that teaches the model to answer with a JSON action.

    Each synthetic example pairs a prompt built by ``build_prompt`` (so the
    warm-up prompt always matches the rollout prompt) with a randomly chosen
    target action serialised as JSON.

    Args:
        model: PEFT-wrapped language model to fine-tune in place.
        tokenizer: Tokenizer with a chat template.
        n_examples: Number of synthetic training examples to generate.
        epochs: Number of passes over the synthetic dataset.
    """
    # Local import: only this function needs a throwaway observation object.
    from types import SimpleNamespace

    print("\n" + "="*60)
    print("SUPERVISED WARM-UP: Teaching JSON format")
    print("="*60)

    # Targets are serialised with json.dumps so that payloads containing
    # newlines are escaped and remain parseable by json.loads.
    action_templates = [
        json.dumps(spec)
        for spec in (
            {"action_type": "run_tests"},
            {"action_type": "run_linter"},
            {"action_type": "inspect"},
            {"action_type": "query_docs", "content": "python keyerror handling"},
            {"action_type": "fix", "content": "def corrected():\n pass"},
            {"action_type": "comment", "content": "This looks good."},
            {"action_type": "question", "content": "Why is this variable used?"},
            {"action_type": "done"},
        )
    ]
    last_outputs = [
        "Tests passed: 2/3",
        "Linter found 1 error",
        "Inspection complete",
        "No previous action",
    ]

    examples = []
    for i in range(n_examples):
        # Synthetic observation exposing the attributes build_prompt reads.
        synthetic_obs = SimpleNamespace(
            code_snippet=f"def example_{i}():\n return {i % 10}",
            author_response="",
            last_tool_output=random.choice(last_outputs),
        )
        prompt = build_prompt(synthetic_obs, [])

        action_json = random.choice(action_templates)
        messages = [
            {"role": "user", "content": prompt},
            {"role": "assistant", "content": action_json},
        ]
        full_text = tokenizer.apply_chat_template(messages, tokenize=False)
        examples.append({"text": full_text})

    dataset = Dataset.from_list(examples)
    trainer = SFTTrainer(
        model=model,
        tokenizer=tokenizer,
        train_dataset=dataset,
        dataset_text_field="text",
        max_seq_length=512,
        args=TrainingArguments(
            output_dir="warmup_output",
            num_train_epochs=epochs,
            per_device_train_batch_size=4,
            gradient_accumulation_steps=2,
            learning_rate=2e-5,
            logging_steps=50,
            save_strategy="no",
            fp16=True,
        ),
    )
    print(f"Training on {n_examples} examples for {epochs} epochs...")
    trainer.train()
    print("✓ Warm-up complete\n")
|
| 245 |
+
|
| 246 |
+
# ======================================================================
|
| 247 |
+
# 5. ACTION GENERATION WITH LOGPROB TRACKING (fixed)
|
| 248 |
+
# ======================================================================
|
| 249 |
+
def generate_action_with_logprob(
    prompt: str,
    model,
    tokenizer,
    temperature: float = 0.0,  # greedy by default for stability
    max_retries: int = 2
) -> Tuple[str, float]:
    """Generate one action string and the summed log-probability of its tokens.

    Args:
        prompt: User-turn text; it is wrapped with the tokenizer's chat template.
        model: Causal LM exposing ``generate``.
        tokenizer: Tokenizer matching the model.
        temperature: Sampling temperature; values <= 0 select greedy decoding.
        max_retries: Maximum number of generation attempts when the output
            is not valid JSON.

    Returns:
        ``(action_text, total_logprob)``. An empty generation is replaced by
        a random fallback action with log-prob -50.0. If no attempt yields
        valid JSON, the result is a skip action with log-prob -100.0.
    """
    messages = [{"role": "user", "content": prompt}]
    formatted = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    inputs = tokenizer(formatted, return_tensors="pt").to("cuda")
    prompt_len = inputs['input_ids'].shape[1]
    sampling = temperature > 0

    for _attempt in range(max_retries):
        with torch.no_grad():
            outputs = model.generate(
                **inputs,
                max_new_tokens=128,
                do_sample=sampling,
                temperature=max(temperature, 0.01) if sampling else 1.0,
                min_new_tokens=1,
                return_dict_in_generate=True,
                output_scores=True,
            )

        generated_ids = outputs.sequences[0][prompt_len:]
        action_text = tokenizer.decode(generated_ids, skip_special_tokens=True).strip()

        # outputs.scores holds one score tensor per generated step; pair
        # each with the token chosen at that step.
        step_logprobs = [
            F.log_softmax(step_scores[0], dim=-1)[token_id].item()
            for step_scores, token_id in zip(outputs.scores, generated_ids)
        ]
        total_logprob = sum(step_logprobs) if step_logprobs else -100.0

        # Empty text after decoding: substitute a safe fallback action.
        if not action_text:
            fallback_actions = [
                '{"action_type": "run_tests"}',
                '{"action_type": "run_linter"}',
                '{"action_type": "inspect"}',
                '{"action_type": "skip"}',
            ]
            action_text = random.choice(fallback_actions)
            total_logprob = -50.0
            print(f"[WARN] Empty generation → using fallback: {action_text}")
            return action_text, total_logprob

        # Accept only output that is valid JSON as a whole.
        try:
            json.loads(action_text)
        except ValueError:  # json.JSONDecodeError is a ValueError
            if not sampling:
                # Greedy decoding would regenerate the same text; stop retrying.
                break
            continue
        return action_text, total_logprob

    return '{"action_type":"skip"}', -100.0
|
| 308 |
+
|
| 309 |
+
# ======================================================================
|
| 310 |
+
# 6. PROMPT BUILDER
|
| 311 |
+
# ======================================================================
|
| 312 |
+
def build_prompt(obs, history_lines: List[str]) -> str:
    """Render the instruction prompt for the current observation.

    Args:
        obs: Observation providing ``code_snippet`` and, optionally,
            ``author_response``, ``last_tool_output`` and
            ``author_personality`` (defaults to "defensive").
        history_lines: Transcript lines from earlier steps; only the last
            six are appended to the prompt.

    Returns:
        The complete prompt text.
    """
    persona = getattr(obs, "author_personality", "defensive")
    # Missing or falsy fields are shown as explicit placeholders.
    author_block = (getattr(obs, "author_response", "") or "") or "(no response yet – start with inspection)"
    tool_block = (getattr(obs, "last_tool_output", "") or "") or "(none)"

    text = f"""You are an AI code review agent. Your goal is to convince a simulated human developer to accept your proposed fix and name your proposed fix function fix.

The developer has a **{persona}** personality and will only accept if you provide solid evidence:
- Tests pass (high pass ratio)
- Lint is clean (zero errors)
- Documentation or references are provided
- Your reasoning is clear, uses words like "because" or "therefore", and is detailed (over 30 words if needed)

Workflow:
1. Use `inspect` to understand the code.
2. Use `run_tests` and `run_linter` to gather evidence.
3. Use `query_docs` when you need references or language-specific guidance.
4. Propose a fix (`fix`) and explain why it works (`comment` or `question`).
5. If the developer pushes back, read their response carefully and address their specific concern.
6. Once convinced, use `done` to finish.

Code:
{obs.code_snippet}

Author says:
{author_block}

Last tool output:
{tool_block}

Available actions:
run_tests, run_linter, inspect, query_docs, fix, comment, question, done

Respond ONLY in JSON:
{{"action_type": "...", "content": "..."}}"""

    if not history_lines:
        return text
    recent = "\n".join(history_lines[-6:])
    return text + f"\n\nPrevious steps:\n{recent}"
|
| 354 |
+
|
| 355 |
+
# ======================================================================
|
| 356 |
+
# 7. TRAJECTORY STORAGE (unchanged)
|
| 357 |
+
# ======================================================================
|
| 358 |
+
@dataclass
class Trajectory:
    """Per-step record of one episode; all five lists are index-aligned by step."""

    states: List[str]       # prompt shown to the policy at each step
    actions: List[str]      # raw action text produced at each step
    rewards: List[float]    # reward received after each step
    logprobs: List[float]   # log-probability recorded for each action
    dones: List[bool]       # episode-finished flag after each step

    def __len__(self):
        """Return the number of recorded steps."""
        return len(self.states)

    def to_dict(self):
        """Return the five lists in a plain dict keyed by field name (no copies)."""
        field_names = ("states", "actions", "rewards", "logprobs", "dones")
        return {name: getattr(self, name) for name in field_names}
|
| 377 |
+
|
| 378 |
+
# ======================================================================
|
| 379 |
+
# 8. ROLLOUT COLLECTION (uses fixed generate)
|
| 380 |
+
# ======================================================================
|
| 381 |
+
def collect_trajectory(
    env: CodeReviewEnv,
    model,
    tokenizer,
    max_steps: int = 10,
    temperature: float = 0.0  # greedy by default
) -> Trajectory:
    """Roll out one episode and record every step.

    Args:
        env: Environment; it is reset at the start of the rollout.
        model: Policy language model.
        tokenizer: Tokenizer matching the model.
        max_steps: Upper bound on the number of steps.
        temperature: Forwarded to ``generate_action_with_logprob``.

    Returns:
        A Trajectory holding the prompts, raw actions, rewards, log-probs
        and done flags of the episode.
    """
    rollout = Trajectory(states=[], actions=[], rewards=[], logprobs=[], dones=[])
    transcript = []
    obs = env.reset()

    for _step in range(max_steps):
        prompt = build_prompt(obs, transcript)
        rollout.states.append(prompt)

        raw_action, raw_logprob = generate_action_with_logprob(
            prompt, model, tokenizer, temperature
        )
        rollout.actions.append(raw_action)
        rollout.logprobs.append(raw_logprob)

        # Raw text -> AgentAction -> environment action -> transition.
        obs, reward, finished, _info = env.step(map_to_env(parse_action(raw_action)))
        rollout.rewards.append(reward.value)
        rollout.dones.append(finished)

        # The transcript is fed into the next prompt.
        transcript.append(f"Agent: {raw_action}")
        transcript.append(f"Env: {obs.last_tool_output}")

        if finished:
            break

    return rollout
|
| 422 |
+
|
| 423 |
+
def collect_trajectories(
    env: CodeReviewEnv,
    model,
    tokenizer,
    n_trajectories: int,
    max_steps: int = 10,
    task_levels: Optional[List[str]] = None,
    task_weights: Optional[List[float]] = None,
) -> List[Trajectory]:
    """Collect several rollouts, sampling a task level for each one.

    Args:
        env: Environment; ``set_task`` is called before every rollout.
        model: Policy language model.
        tokenizer: Tokenizer matching the model.
        n_trajectories: Number of rollouts to collect.
        max_steps: Step limit per rollout.
        task_levels: Candidate task names; defaults to every key of ``BUG_DB``.
        task_weights: Optional sampling weights, one per task level.

    Returns:
        The collected trajectories in collection order.

    Raises:
        ValueError: If ``task_weights`` does not match ``task_levels`` in
            length, or its total is not positive.
    """
    # Default to the full bug distribution rather than a single task level.
    levels = list(BUG_DB.keys()) if task_levels is None else task_levels

    if task_weights is not None:
        if len(task_weights) != len(levels):
            raise ValueError("task_weights must match task_levels length")
        if sum(task_weights) <= 0:
            raise ValueError("task_weights must have a positive total")

    collected = []
    for number in range(1, n_trajectories + 1):
        # Weighted sampling supports curriculum-style schedules.
        (sampled_task,) = random.choices(levels, weights=task_weights, k=1)
        env.set_task(sampled_task)
        traj = collect_trajectory(env, model, tokenizer, max_steps)
        total_reward = sum(traj.rewards)
        print(
            f"Trajectory {number}/{n_trajectories}: "
            f"task={sampled_task}, steps={len(traj)}, reward={total_reward:.3f}"
        )
        collected.append(traj)
    return collected
|
| 452 |
+
|
| 453 |
+
# ======================================================================
|
| 454 |
+
# 9. ADVANTAGE ESTIMATION (unchanged)
|
| 455 |
+
# ======================================================================
|
| 456 |
+
def compute_returns_and_advantages(
    rewards: List[float],
    dones: List[bool],
    gamma: float = 0.99,
    standardize: bool = True
) -> Tuple[List[float], List[float]]:
    """Compute discounted returns and critic-free advantages.

    Args:
        rewards: Per-step rewards.
        dones: Per-step episode-end flags; a True flag stops later rewards
            from flowing back past that step.
        gamma: Discount factor.
        standardize: When True, advantages are the returns centred on their
            mean and divided by their standard deviation (plus 1e-8).
            When False, advantages are a copy of the returns.

    Returns:
        ``(advantages, returns)`` — note that advantages come first.
    """
    horizon = len(rewards)
    returns = [0.0] * horizon

    # Accumulate backwards; the carry restarts at every episode boundary.
    carry = 0.0
    for t in range(horizon - 1, -1, -1):
        if dones[t]:
            carry = 0.0
        carry = rewards[t] + gamma * carry
        returns[t] = carry

    if not standardize:
        return returns.copy(), returns

    centred = np.array(returns) - np.mean(returns)
    scale = np.std(centred) + 1e-8  # epsilon guards against zero spread
    return (centred / scale).tolist(), returns
|
| 483 |
+
# ======================================================================
|
| 484 |
+
# 10. COMPUTE NEW LOGPROBS (unchanged)
|
| 485 |
+
# ======================================================================
|
| 486 |
+
def compute_logprob(prompt: str, action: str, model, tokenizer) -> float:
    """Return the summed log-probability of *action* given *prompt* under *model*.

    Args:
        prompt: User-turn text; it is wrapped with the tokenizer's chat template.
        action: Action text whose tokens are scored.
        model: Causal LM returning ``logits`` from a forward pass.
        tokenizer: Tokenizer matching the model.

    Returns:
        Sum of per-token log-probabilities of the action tokens, or -100.0
        when no action token can be scored.
    """
    messages = [{"role": "user", "content": prompt}]
    formatted = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    inputs = tokenizer(formatted + action, return_tensors="pt").to("cuda")

    # Measure the prefix with the same tokenizer call (same special-token
    # handling) as the full sequence, so the action offset stays aligned
    # with `inputs`. The action ids are taken from the full encoding so
    # they match what the model actually sees.
    action_start = tokenizer(formatted, return_tensors="pt")["input_ids"].shape[1]
    action_ids = inputs["input_ids"][0, action_start:]
    n_action = action_ids.numel()
    if action_start < 1 or n_action == 0:
        return -100.0

    with torch.no_grad():
        logits = model(**inputs).logits

    # Logits at position p predict the token at position p + 1.
    step_logits = logits[0, action_start - 1:action_start - 1 + n_action].float()
    token_logprobs = F.log_softmax(step_logits, dim=-1).gather(-1, action_ids.unsqueeze(-1))
    return token_logprobs.sum().item()
|
| 508 |
+
|
| 509 |
+
# ======================================================================
|
| 510 |
+
# 11. PPO UPDATE (recomputes action log-probs inline under the current policy)
|
| 511 |
+
# ======================================================================
|
| 512 |
+
def ppo_update(
    trajectories: List[Trajectory],
    model,
    tokenizer,
    optimizer,
    n_epochs: int = 4,
    clip_epsilon: float = 0.2,
    entropy_coef: float = 0.01,
    gamma: float = 0.99,
) -> Dict[str, float]:
    """Run clipped-surrogate PPO updates over the collected trajectories.

    Advantages are computed per trajectory (standardised, no critic). One
    optimizer step is taken per (state, action) sample, in a fresh random
    order each epoch.

    Args:
        trajectories: Rollouts with prompts, action texts, rewards and the
            log-probs recorded at collection time.
        model: Policy model; it is switched to train mode.
        tokenizer: Tokenizer matching the model.
        optimizer: Optimizer over the model's trainable parameters.
        n_epochs: Number of passes over the samples.
        clip_epsilon: PPO ratio clipping range.
        entropy_coef: Weight of the entropy bonus.
        gamma: Discount factor used for returns.

    Returns:
        Mean ``loss``, ``policy_loss`` and ``entropy`` over all updates
        (all 0.0 when no update was performed).
    """
    model.train()

    all_states = []
    all_actions = []
    all_old_logprobs = []
    all_advantages = []

    for traj in trajectories:
        advantages, _returns = compute_returns_and_advantages(
            traj.rewards, traj.dones, gamma=gamma, standardize=True
        )
        all_states.extend(traj.states)
        all_actions.extend(traj.actions)
        all_old_logprobs.extend(traj.logprobs)
        all_advantages.extend(advantages)

    n_samples = len(all_states)
    total_loss = 0.0
    total_policy_loss = 0.0
    total_entropy = 0.0
    n_updates = 0

    for _epoch in range(n_epochs):
        for i in np.random.permutation(n_samples):
            state = all_states[i]
            action = all_actions[i]
            old_logprob = all_old_logprobs[i]
            advantage = all_advantages[i]

            # Use the same chat template as at rollout time.
            messages = [{"role": "user", "content": state}]
            formatted = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
            inputs = tokenizer(formatted + action, return_tensors="pt").to("cuda")

            # Measure the prefix with the same tokenizer call (same
            # special-token handling) as the full sequence so the action
            # offset stays aligned with `inputs`.
            action_start = tokenizer(formatted, return_tensors="pt")["input_ids"].shape[1]
            action_ids = inputs["input_ids"][0, action_start:]
            n_action = action_ids.numel()
            if action_start < 1 or n_action == 0:
                # Nothing to score; skip before paying for a forward pass.
                continue

            logits = model(**inputs).logits

            # Logits at position p predict the token at position p + 1.
            step_logits = logits[0, action_start - 1:action_start - 1 + n_action].float()
            log_probs = F.log_softmax(step_logits, dim=-1)
            new_logprob = log_probs.gather(-1, action_ids.unsqueeze(-1)).sum()
            # Mean per-token entropy of the predictive distribution.
            avg_entropy = -(log_probs.exp() * log_probs).sum(dim=-1).mean()

            ratio = torch.exp(new_logprob - old_logprob)
            surr1 = ratio * advantage
            surr2 = torch.clamp(ratio, 1 - clip_epsilon, 1 + clip_epsilon) * advantage
            policy_loss = -torch.min(surr1, surr2)
            loss = policy_loss - entropy_coef * avg_entropy

            optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()

            total_loss += loss.item()
            total_policy_loss += policy_loss.item()
            total_entropy += avg_entropy.item()
            n_updates += 1

    return {
        "loss": total_loss / n_updates if n_updates > 0 else 0.0,
        "policy_loss": total_policy_loss / n_updates if n_updates > 0 else 0.0,
        "entropy": total_entropy / n_updates if n_updates > 0 else 0.0,
    }
|
| 607 |
+
# ======================================================================
|
| 608 |
+
# 12. EVALUATION (unchanged)
|
| 609 |
+
# ======================================================================
|
| 610 |
+
def evaluate_policy(
    env: CodeReviewEnv,
    model,
    tokenizer,
    n_episodes: int = 10,
    max_steps: int = 10
) -> Dict[str, float]:
    """Roll out the current policy and summarise per-episode statistics.

    Collects ``n_episodes`` rollouts through ``collect_trajectory`` with
    ``temperature=0.0`` and aggregates each episode's summed reward and
    length.

    Args:
        env: Environment passed unchanged to ``collect_trajectory``.
        model: Policy model. It is switched to eval mode for the rollouts and
            put back into train mode afterwards if it was in train mode on
            entry, so evaluating mid-training does not change the caller's
            mode.
        tokenizer: Tokenizer passed unchanged to ``collect_trajectory``.
        n_episodes: Number of rollouts to collect.
        max_steps: Step limit forwarded to ``collect_trajectory``.

    Returns:
        Dict with ``avg_reward`` / ``std_reward`` (mean and standard deviation
        of the summed episode rewards), ``avg_length`` (mean ``len(traj)``)
        and ``success_rate`` (fraction of episodes whose summed reward is
        greater than 0.5). Every value is ``0.0`` when no episode was run.
    """
    # Remember the mode on entry; objects without a ``training`` flag are
    # treated as "not training" and are left in eval mode, as before.
    was_training = bool(getattr(model, "training", False))
    model.eval()

    episode_returns: List[float] = []
    episode_lengths: List[int] = []
    try:
        for _ in range(n_episodes):
            traj = collect_trajectory(
                env, model, tokenizer, max_steps, temperature=0.0
            )
            episode_returns.append(sum(traj.rewards))
            episode_lengths.append(len(traj))
    finally:
        if was_training:
            model.train()

    # Guard the empty case: np.mean([]) is NaN and the success rate would
    # divide by zero.
    if not episode_returns:
        return {
            "avg_reward": 0.0,
            "std_reward": 0.0,
            "avg_length": 0.0,
            "success_rate": 0.0,
        }

    success_count = sum(1 for ret in episode_returns if ret > 0.5)
    return {
        "avg_reward": float(np.mean(episode_returns)),
        "std_reward": float(np.std(episode_returns)),
        "avg_length": float(np.mean(episode_lengths)),
        "success_rate": success_count / len(episode_returns),
    }
|
| 636 |
+
|
| 637 |
+
# ======================================================================
|
| 638 |
+
# 13. MAIN TRAINING LOOP (added sanity check and warm-up)
|
| 639 |
+
# ======================================================================
|
| 640 |
+
def train_ppo(
    n_iterations: int = 50,
    trajectories_per_iter: int = 10,
    n_epochs: int = 4,
    max_steps: int = 10,
    learning_rate: float = 3e-5,
    clip_epsilon: float = 0.2,
    entropy_coef: float = 0.01,
    gamma: float = 0.99,
    eval_every: int = 5,
    task_levels: Optional[List[str]] = None,
    curriculum_weighted_sampling: bool = True,
    reward_profile: str = "full",
):
    """Run the full PPO training pipeline and save the resulting model.

    Steps, in order: load the model, run ``test_model_sanity`` (abort if it
    fails), run ``supervised_warmup``, then for each iteration collect
    trajectories, call ``ppo_update`` and optionally evaluate. Finally the
    model and tokenizer are saved to ``ppo_final_model`` and the reward and
    loss histories are plotted to ``reward_curve.png`` / ``loss_curve.png``.

    Args:
        n_iterations: Number of collect/update iterations.
        trajectories_per_iter: Passed to ``collect_trajectories``.
        n_epochs: PPO epochs per update, passed to ``ppo_update``.
        max_steps: Step limit used for both training rollouts and the
            periodic evaluation rollouts.
        learning_rate: Learning rate for the ``AdamW`` optimizer.
        clip_epsilon: PPO clipping range, passed to ``ppo_update``.
        entropy_coef: Entropy bonus coefficient, passed to ``ppo_update``.
        gamma: Discount factor, passed to ``ppo_update``.
        eval_every: Evaluate every this many iterations; a value of zero or
            less disables periodic evaluation.
        task_levels: Task levels to sample from; defaults to every key of
            ``BUG_DB``.
        curriculum_weighted_sampling: When true, pass progress-dependent
            per-level weights to ``collect_trajectories``; otherwise pass
            ``task_weights=None``.
        reward_profile: Passed to the ``CodeReviewEnv`` constructor.

    Returns:
        None. Returns early, before any training, if the sanity check fails.
    """
    print("Loading model...")
    model, tokenizer = load_model()

    # Sanity check before any training.
    if not test_model_sanity(model, tokenizer):
        print("\n❌ Model sanity check failed – cannot proceed.")
        return

    # Supervised warm-up to teach JSON format.
    supervised_warmup(model, tokenizer, n_examples=500, epochs=2)

    optimizer = AdamW(model.parameters(), lr=learning_rate)
    env = CodeReviewEnv(reward_profile=reward_profile)
    if task_levels is None:
        task_levels = list(BUG_DB.keys())

    print(f"\n{'='*60}")
    print("Starting PPO Training")
    print(f"Iterations: {n_iterations}")
    print(f"Trajectories per iteration: {trajectories_per_iter}")
    print(f"PPO epochs: {n_epochs}")
    print(f"Reward profile: {reward_profile}")
    print(f"{'='*60}\n")
    reward_history: List[float] = []
    loss_history: List[float] = []

    for iteration in range(n_iterations):
        print(f"\n--- Iteration {iteration + 1}/{n_iterations} ---")

        # Optional weighted curriculum: start with easier tasks and smoothly
        # ramp difficulty over training.
        if curriculum_weighted_sampling:
            progress = (iteration + 1) / max(n_iterations, 1)
            task_weights = _curriculum_task_weights(progress, task_levels)
        else:
            task_weights = None

        print("Collecting trajectories...")
        trajectories = collect_trajectories(
            env,
            model,
            tokenizer,
            trajectories_per_iter,
            max_steps,
            task_levels=task_levels,
            task_weights=task_weights,
        )

        avg_reward = np.mean([sum(t.rewards) for t in trajectories])
        avg_length = np.mean([len(t) for t in trajectories])
        reward_history.append(float(avg_reward))

        print(f"Avg reward: {avg_reward:.3f}")
        print(f"Avg length: {avg_length:.1f}")

        print("Updating policy...")
        metrics = ppo_update(
            trajectories,
            model,
            tokenizer,
            optimizer,
            n_epochs=n_epochs,
            clip_epsilon=clip_epsilon,
            entropy_coef=entropy_coef,
            gamma=gamma,
        )

        print(f"Loss: {metrics['loss']:.4f}")
        print(f"Policy loss: {metrics['policy_loss']:.4f}")
        print(f"Entropy: {metrics['entropy']:.4f}")
        loss_history.append(float(metrics["loss"]))

        # eval_every <= 0 means "never evaluate"; without the guard a zero
        # value would raise ZeroDivisionError in the modulo.
        if eval_every > 0 and (iteration + 1) % eval_every == 0:
            print("\nEvaluating policy...")
            # Evaluate under the same step limit the policy is trained with.
            eval_metrics = evaluate_policy(
                env, model, tokenizer, n_episodes=10, max_steps=max_steps
            )
            print(f"Eval avg reward: {eval_metrics['avg_reward']:.3f} ± {eval_metrics['std_reward']:.3f}")
            print(f"Eval success rate: {eval_metrics['success_rate']:.2%}")
            print(f"Eval avg length: {eval_metrics['avg_length']:.1f}")

    print("\n" + "="*60)
    print("Training complete. Saving model...")
    model.save_pretrained("ppo_final_model")
    tokenizer.save_pretrained("ppo_final_model")
    print("Model saved to ppo_final_model/")

    # Save training curves for quick before/after comparisons.
    # These are intentionally simple line plots to avoid extra dependencies.
    if reward_history:
        _save_training_curve(
            reward_history,
            "Average Reward per Iteration",
            "Average Reward",
            "reward_curve.png",
        )

    if loss_history:
        _save_training_curve(
            loss_history,
            "Training Loss per Iteration",
            "Loss",
            "loss_curve.png",
            color="tab:red",
        )

    if os.path.exists("reward_curve.png") and os.path.exists("loss_curve.png"):
        print("Saved reward_curve.png and loss_curve.png")
    print("="*60)


def _curriculum_task_weights(progress: float, task_levels: List[str]) -> List[float]:
    """Return one sampling weight per entry of ``task_levels`` at ``progress``.

    ``progress`` is the completed fraction of training. The "easy" and
    "medium" weights fall as it grows (floored at 0.15) while "hard",
    "harder" and "hardest" rise. Levels not named here get weight 1.0.
    """
    level_weights = {
        "easy": max(0.15, 0.55 - 0.40 * progress),
        "medium": max(0.15, 0.25 - 0.10 * progress),
        "hard": 0.10 + 0.05 * progress,
        "harder": 0.05 + 0.20 * progress,
        "hardest": 0.05 + 0.25 * progress,
    }
    return [level_weights.get(level, 1.0) for level in task_levels]


def _save_training_curve(values, title, ylabel, path, **plot_kwargs) -> None:
    """Plot ``values`` against 1-based iteration numbers and save to ``path``."""
    plt.figure(figsize=(8, 4))
    plt.plot(range(1, len(values) + 1), values, marker="o", **plot_kwargs)
    plt.title(title)
    plt.xlabel("Iteration")
    plt.ylabel(ylabel)
    plt.grid(alpha=0.3)
    plt.tight_layout()
    plt.savefig(path, dpi=150)
    plt.close()
|
| 777 |
+
|
| 778 |
+
# ======================================================================
|
| 779 |
+
# 14. ENTRY POINT (unchanged)
|
| 780 |
+
# ======================================================================
|
| 781 |
+
if __name__ == "__main__":
    # Hyperparameters for a stand-alone run, gathered in one mapping so the
    # whole configuration can be read at a glance before it is handed over.
    run_config = {
        "n_iterations": 50,
        "trajectories_per_iter": 10,
        "n_epochs": 4,
        "max_steps": 10,
        "learning_rate": 3e-5,
        "clip_epsilon": 0.2,
        "entropy_coef": 0.01,
        "gamma": 0.99,
        "eval_every": 5,
    }
    train_ppo(**run_config)
|