100XZX001 committed on
Commit 94b1baf · verified · 1 parent: 86c792b

Upload 23 files

Files changed (16)
  1. .gitattributes +35 -35
  2. .gitignore +3 -0
  3. Dockerfile +23 -23
  4. README.md +117 -14
  5. __init__.py +12 -10
  6. app.py +104 -104
  7. environment.py +624 -624
  8. grader.py +13 -12
  9. models.py +111 -111
  10. pyproject.toml +38 -30
  11. requirements-training.txt +9 -0
  12. requirements.txt +5 -11
  13. rltool.py +143 -127
  14. rubrics.py +136 -136
  15. training.py +934 -934
  16. training_data.json +0 -0
.gitattributes CHANGED
@@ -1,35 +1,35 @@
- *.7z filter=lfs diff=lfs merge=lfs -text
- *.arrow filter=lfs diff=lfs merge=lfs -text
- *.bin filter=lfs diff=lfs merge=lfs -text
- *.bz2 filter=lfs diff=lfs merge=lfs -text
- *.ckpt filter=lfs diff=lfs merge=lfs -text
- *.ftz filter=lfs diff=lfs merge=lfs -text
- *.gz filter=lfs diff=lfs merge=lfs -text
- *.h5 filter=lfs diff=lfs merge=lfs -text
- *.joblib filter=lfs diff=lfs merge=lfs -text
- *.lfs.* filter=lfs diff=lfs merge=lfs -text
- *.mlmodel filter=lfs diff=lfs merge=lfs -text
- *.model filter=lfs diff=lfs merge=lfs -text
- *.msgpack filter=lfs diff=lfs merge=lfs -text
- *.npy filter=lfs diff=lfs merge=lfs -text
- *.npz filter=lfs diff=lfs merge=lfs -text
- *.onnx filter=lfs diff=lfs merge=lfs -text
- *.ot filter=lfs diff=lfs merge=lfs -text
- *.parquet filter=lfs diff=lfs merge=lfs -text
- *.pb filter=lfs diff=lfs merge=lfs -text
- *.pickle filter=lfs diff=lfs merge=lfs -text
- *.pkl filter=lfs diff=lfs merge=lfs -text
- *.pt filter=lfs diff=lfs merge=lfs -text
- *.pth filter=lfs diff=lfs merge=lfs -text
- *.rar filter=lfs diff=lfs merge=lfs -text
- *.safetensors filter=lfs diff=lfs merge=lfs -text
- saved_model/**/* filter=lfs diff=lfs merge=lfs -text
- *.tar.* filter=lfs diff=lfs merge=lfs -text
- *.tar filter=lfs diff=lfs merge=lfs -text
- *.tflite filter=lfs diff=lfs merge=lfs -text
- *.tgz filter=lfs diff=lfs merge=lfs -text
- *.wasm filter=lfs diff=lfs merge=lfs -text
- *.xz filter=lfs diff=lfs merge=lfs -text
- *.zip filter=lfs diff=lfs merge=lfs -text
- *.zst filter=lfs diff=lfs merge=lfs -text
- *tfevents* filter=lfs diff=lfs merge=lfs -text
+ *.7z filter=lfs diff=lfs merge=lfs -text
+ *.arrow filter=lfs diff=lfs merge=lfs -text
+ *.bin filter=lfs diff=lfs merge=lfs -text
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
+ *.ftz filter=lfs diff=lfs merge=lfs -text
+ *.gz filter=lfs diff=lfs merge=lfs -text
+ *.h5 filter=lfs diff=lfs merge=lfs -text
+ *.joblib filter=lfs diff=lfs merge=lfs -text
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
+ *.model filter=lfs diff=lfs merge=lfs -text
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
+ *.npy filter=lfs diff=lfs merge=lfs -text
+ *.npz filter=lfs diff=lfs merge=lfs -text
+ *.onnx filter=lfs diff=lfs merge=lfs -text
+ *.ot filter=lfs diff=lfs merge=lfs -text
+ *.parquet filter=lfs diff=lfs merge=lfs -text
+ *.pb filter=lfs diff=lfs merge=lfs -text
+ *.pickle filter=lfs diff=lfs merge=lfs -text
+ *.pkl filter=lfs diff=lfs merge=lfs -text
+ *.pt filter=lfs diff=lfs merge=lfs -text
+ *.pth filter=lfs diff=lfs merge=lfs -text
+ *.rar filter=lfs diff=lfs merge=lfs -text
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
+ *.tar filter=lfs diff=lfs merge=lfs -text
+ *.tflite filter=lfs diff=lfs merge=lfs -text
+ *.tgz filter=lfs diff=lfs merge=lfs -text
+ *.wasm filter=lfs diff=lfs merge=lfs -text
+ *.xz filter=lfs diff=lfs merge=lfs -text
+ *.zip filter=lfs diff=lfs merge=lfs -text
+ *.zst filter=lfs diff=lfs merge=lfs -text
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,3 @@
+ __pycache__/
+ *.pyc
+ tmp/
Dockerfile CHANGED
@@ -1,24 +1,24 @@
- # Dockerfile – OpenEnv server with FastAPI and all dependencies
- FROM python:3.10-slim
-
- # Install system dependencies required for chromadb and sentence-transformers
- RUN apt-get update && apt-get install -y --no-install-recommends \
-     build-essential \
-     && rm -rf /var/lib/apt/lists/*
-
- WORKDIR /app
-
- # Copy requirements and install Python dependencies
- COPY requirements.txt .
- RUN pip install --no-cache-dir -r requirements.txt
-
- # Copy the rest of the application
- COPY . .
-
- # Expose the port used by the FastAPI server
- EXPOSE 7860
-
- # Run the server using uvicorn
- # Note: 'server.app:app' assumes the FastAPI app is in server/app.py
- ENV ENABLE_WEB_INTERFACE=true
+ # Dockerfile – OpenEnv server with FastAPI and all dependencies
+ FROM python:3.10-slim
+
+ # Install system dependencies required for chromadb and sentence-transformers
+ RUN apt-get update && apt-get install -y --no-install-recommends \
+     build-essential \
+     && rm -rf /var/lib/apt/lists/*
+
+ WORKDIR /app
+
+ # Copy requirements and install Python dependencies
+ COPY requirements.txt .
+ RUN pip install --no-cache-dir -r requirements.txt
+
+ # Copy the rest of the application
+ COPY . .
+
+ # Expose the port used by the FastAPI server
+ EXPOSE 7860
+
+ # Run the server using uvicorn
+ # Note: 'server.app:app' assumes the FastAPI app is in server/app.py
+ ENV ENABLE_WEB_INTERFACE=true
  CMD ["uvicorn", "server.app:app", "--host", "0.0.0.0", "--port", "7860"]
README.md CHANGED
@@ -1,14 +1,117 @@
- ---
- title: CodeReview Training
- emoji: 🤖
- colorFrom: blue
- colorTo: green
- sdk: gradio
- sdk_version: 4.44.0
- app_file: app.py
- pinned: false
- ---
-
- # CodeReview PPO Training
-
- This Space trains an LLM agent to fix injected bugs using PPO and rubrics.
+ ---
+ title: CodeReview Training
+ emoji: "🤖"
+ colorFrom: blue
+ colorTo: green
+ sdk: docker
+ pinned: false
+ ---
+
+ # CodeReview Professional Workflow
+
+ `CodeReview Professional Workflow` is an OpenEnv environment for training code-fixing agents on realistic review loops instead of one-shot coding tasks. The agent has to inspect buggy code, run tests, lint the patch, query docs, and persuade a simulated author before the episode is considered solved.
+
+ ## Quick links
+
+ | Artifact | Link |
+ | --- | --- |
+ | Hugging Face Space | [100XZX001/CodeReview-Professional-Workflow](https://huggingface.co/spaces/100XZX001/CodeReview-Professional-Workflow) |
+ | Colab-ready training notebook | [notebooks/code_review_unsloth_training.ipynb](notebooks/code_review_unsloth_training.ipynb) |
+ | Local training script | [training.py](training.py) |
+ | OpenEnv manifest | [openenv.yaml](openenv.yaml) |
+ | Submission slide deck | [submission_assets/code_review_openenv_submission.pptx](submission_assets/code_review_openenv_submission.pptx) |
+ | Training artifacts folder | [outputs/README.md](outputs/README.md) |
+
+ ## Why this environment
+
+ Most code agents are evaluated on static patch generation. Real review work is messier:
+
+ - you have to diagnose the failure mode before patching
+ - you often need tool feedback before you know whether the fix is safe
+ - you may need to explain the fix to another developer before it is accepted
+
+ This environment turns that workflow into a multi-step RL setting with dense rewards and stateful interaction.
+
+ ## How the environment works
+
+ Each episode samples one injected bug from five difficulty bands (an illustrative `hard` example follows the list):
+
+ 1. `easy`: null checks, missing defaults, simple indexing mistakes
+ 2. `medium`: off-by-one and wrong-operator bugs
+ 3. `hard`: numerical safety failures like divide-by-zero
+ 4. `harder`: concurrency issues like missing locks
+ 5. `hardest`: deadlock and coordination mistakes
+
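As a concrete illustration, the `hard` band starts from the small `average` helper defined in `environment.py` in this commit; a divide-by-zero injection could then simply drop the empty-input guard. The buggy variant below is illustrative only, since the exact mutation applied by `RedTeam` lives in `redteam.py` and is not expanded in this commit view.

```python
# Base "hard" snippet as defined in environment.py (this commit).
def average(data):
    if not data:
        return 0
    return sum(data) / len(data)

# Illustrative injected variant (assumption): the empty-list guard is gone,
# so average_buggy([]) raises ZeroDivisionError instead of returning 0.
def average_buggy(data):
    return sum(data) / len(data)
```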
+ The agent can take actions such as:
+
+ - `inspect`
+ - `run_tests`
+ - `run_linter`
+ - `query_docs`
+ - `fix`
+ - `comment`
+ - `question`
+ - `done`
+
+ Rewards combine test delta, lint delta, tool usage, exploration behavior, step penalties, and terminal success. The observation includes the current code, latest tool output, previous scores, author confidence, progress counters, and recent action history. A minimal interaction sketch follows.
+
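The sketch below shows one way a minimal episode could look when driving `CodeReviewEnv` directly (the classes come from `environment.py` and `models.py` in this commit). The argument-free constructors for `Inspect` and `RunTests` are an assumption; `ProposeFix(fix_code=...)`, `reward.value`, and the `info` keys match the environment code further down this page.

```python
# Minimal direct-use sketch; see environment.py below for the real step logic.
from environment import CodeReviewEnv
from models import Inspect, RunTests, ProposeFix  # field names per environment.py

env = CodeReviewEnv(task="easy", max_steps=10)
obs = env.reset()
print(obs.bug_description)

obs, reward, done, info = env.step(Inspect())    # read the buggy snippet
obs, reward, done, info = env.step(RunTests())   # establish a test-score baseline

# A real agent would submit corrected code here; we echo the snippet as a stub.
obs, reward, done, info = env.step(ProposeFix(fix_code=obs.code_snippet))
print(reward.value, done, info["test_score"])
```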
+ ## OpenEnv-first setup
+
+ This repo is structured as an OpenEnv environment rather than a custom one-off app:
+
+ - the environment metadata lives in [openenv.yaml](openenv.yaml)
+ - the Space is configured as a Docker-based OpenEnv deployment
+ - runtime dependencies are kept lightweight for the Space build
+ - training-only packages live separately so judges can run the environment without pulling the full training stack
+
+ The project now targets `openenv-core>=0.2.3`.
+
+ ## Training
+
+ The main training entrypoint is [training.py](training.py), which uses Unsloth plus a PPO-style loop over real environment interaction (a rough rollout-collection sketch follows below). For judges who want a rerunnable workflow, the repo also includes a Colab-ready notebook:
+
+ - [notebooks/code_review_unsloth_training.ipynb](notebooks/code_review_unsloth_training.ipynb)
+
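training.py is the authoritative implementation; purely as a sketch of the rollout-collection half of such a loop, under the assumption that the policy is exposed as a callable mapping observations to actions (the Unsloth model and the PPO update step are omitted, and `policy` below is a placeholder, not code from this repo):

```python
# Illustrative rollout collection only; the PPO/Unsloth update step is omitted.
from environment import CodeReviewEnv

def collect_episode(env: CodeReviewEnv, policy):
    """policy(obs) -> action stands in for the LLM policy head."""
    trajectory = []
    obs = env.reset()
    done = False
    while not done:
        action = policy(obs)                           # hypothetical policy call
        next_obs, reward, done, info = env.step(action)
        trajectory.append((obs, action, reward.value, done))
        obs = next_obs
    return trajectory
```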
+ ### Install locally
+
+ ```bash
+ pip install -e .
+ pip install -r requirements-training.txt
+ ```
+
+ ### Run training
+
+ ```bash
+ python training.py
+ ```
+
+ The training run writes the evidence plots in the working directory (a minimal plotting sketch follows the list):
+
+ - `warmup_loss.png`
+ - `reward_curve.png`
+ - `loss_curve.png`
+ - `training_summary.png`
+
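The plotting code itself lives in training.py; as a sketch of how a curve like `reward_curve.png` could be produced from per-episode returns (the helper name and the use of matplotlib are assumptions, only the output file name matches the list above):

```python
# Sketch: save a reward curve from a list of per-episode returns.
import matplotlib.pyplot as plt

def save_reward_curve(episode_rewards, path="reward_curve.png"):
    plt.figure()
    plt.plot(episode_rewards)
    plt.xlabel("episode")
    plt.ylabel("total reward")
    plt.title("PPO reward curve")
    plt.savefig(path)
    plt.close()
```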
+ For submission hygiene, copy a real run into `outputs/<run-name>/` and link that folder from this README before final judging.
+
+ ## Results and evidence
+
+ The expected evidence bundle for a real training run is:
+
+ - warm-up loss curve
+ - PPO reward curve
+ - PPO loss curve
+ - combined summary panel
+
+ Use [outputs/README.md](outputs/README.md) as the landing page for committed run artifacts.
+
+ ## Submission materials
+
+ This repo is set up so every judge-facing artifact can be reached from this README:
+
+ - environment Space: [100XZX001/CodeReview-Professional-Workflow](https://huggingface.co/spaces/100XZX001/CodeReview-Professional-Workflow)
+ - training notebook: [notebooks/code_review_unsloth_training.ipynb](notebooks/code_review_unsloth_training.ipynb)
+ - slide deck: [submission_assets/code_review_openenv_submission.pptx](submission_assets/code_review_openenv_submission.pptx)
+ - evidence folder: [outputs/README.md](outputs/README.md)
+
+ No large video files are stored in the repo; any future video or blog submission should be linked by URL from this README.
__init__.py CHANGED
@@ -4,13 +4,15 @@
  # This source code is licensed under the BSD-style license found in the
  # LICENSE file in the root directory of this source tree.
 
- """Criticrl Environment."""
-
- from .client import CriticrlEnv
- from .models import CriticrlAction, CriticrlObservation
-
- __all__ = [
-     "CriticrlAction",
-     "CriticrlObservation",
-     "CriticrlEnv",
- ]
+ """Code Review Professional Workflow OpenEnv package."""
+
+ from .client import CodeReviewEnv
+ from .models import AnyAction, Observation, Reward, State
+
+ __all__ = [
+     "AnyAction",
+     "Observation",
+     "Reward",
+     "State",
+     "CodeReviewEnv",
+ ]
app.py CHANGED
@@ -1,104 +1,104 @@
1
- # server/app.py – OpenEnv HTTP server
2
- import sys
3
- import os
4
- sys.path.insert(0, os.path.dirname(os.path.dirname(__file__)))
5
-
6
- from fastapi import FastAPI, HTTPException
7
- from environment import CodeReviewEnv
8
- from models import AnyAction, Observation, Reward, State, action_adapter
9
-
10
- app = FastAPI(title="Code Review Environment", version="1.0.0")
11
- env = CodeReviewEnv()
12
-
13
- # ----------------------------------------------------------------------
14
- # Health & metadata endpoints
15
- # ----------------------------------------------------------------------
16
- @app.get("/")
17
- def root():
18
- print("[ROOT] Health check hit")
19
- return {"status": "crazy good"}
20
-
21
- @app.get("/health")
22
- def health():
23
- print("[HEALTH] Service is healthy")
24
- return {"status": "healthy"}
25
-
26
- @app.get("/metadata")
27
- def metadata():
28
- print("[METADATA] Requested")
29
- return {
30
- "name": "Code Review Professional Workflow",
31
- "description": (
32
- "Multi‑turn code review environment for professional‑level bug fixing. "
33
- "The agent must inspect, test, lint, query documentation, and negotiate with "
34
- "a simulated (persona‑driven) author to get a fix accepted. "
35
- "Includes 25 bugs across 5 difficulty levels, AST‑based injection, "
36
- "a reward‑shaping system (full/core profiles), and curriculum learning. "
37
- "Designed for RL training (PPO, DPO, or any policy‑gradient method)."
38
- )
39
- }
40
-
41
- @app.get("/schema")
42
- def schema():
43
- print("[SCHEMA] Requested")
44
- return {
45
- "action": AnyAction.model_json_schema(),
46
- "observation": Observation.model_json_schema(),
47
- "state": State.model_json_schema()
48
- }
49
-
50
- @app.post("/mcp")
51
- def mcp():
52
- print("[MCP] Ping received")
53
- return {"jsonrpc": "2.0", "result": None}
54
-
55
- # ----------------------------------------------------------------------
56
- # Environment endpoints
57
- # ----------------------------------------------------------------------
58
- @app.post("/reset")
59
- def reset(task: str = "easy"):
60
- try:
61
- print(f"[RESET] Starting new episode | task={task}")
62
-
63
- env.set_task(task)
64
- obs = env.reset()
65
-
66
- print(f"[RESET DONE] step={env._step_count}")
67
-
68
- return obs.__dict__
69
- except Exception as e:
70
- print(f"[RESET ERROR] {e}")
71
- raise HTTPException(status_code=400, detail=str(e))
72
-
73
- @app.post("/step")
74
- def step(action: dict):
75
- try:
76
- print(f"[STEP INPUT] {action}")
77
-
78
- parsed_action = action_adapter.validate_python(action)
79
- obs, reward, done, info = env.step(parsed_action)
80
-
81
- print(f"[STEP OUTPUT] reward={reward.value:.4f} | done={done}")
82
-
83
- return {
84
- "observation": obs.__dict__,
85
- "reward": reward.value,
86
- "done": done,
87
- "info": info
88
- }
89
- except Exception as e:
90
- print(f"[STEP ERROR] {e}")
91
- raise HTTPException(status_code=400, detail=str(e))
92
-
93
- @app.get("/state")
94
- def state():
95
- print("[STATE] Requested")
96
- return env._get_observation().__dict__
97
-
98
- # ----------------------------------------------------------------------
99
- # Main entry point (for local testing)
100
- # ----------------------------------------------------------------------
101
- if __name__ == "__main__":
102
- import uvicorn
103
- print("[SERVER START] Running on http://0.0.0.0:7860")
104
- uvicorn.run(app, host="0.0.0.0", port=7860)
 
1
+ # server/app.py – OpenEnv HTTP server
2
+ import sys
3
+ import os
4
+ sys.path.insert(0, os.path.dirname(os.path.dirname(__file__)))
5
+
6
+ from fastapi import FastAPI, HTTPException
7
+ from environment import CodeReviewEnv
8
+ from models import AnyAction, Observation, Reward, State, action_adapter
9
+
10
+ app = FastAPI(title="Code Review Environment", version="1.0.0")
11
+ env = CodeReviewEnv()
12
+
13
+ # ----------------------------------------------------------------------
14
+ # Health & metadata endpoints
15
+ # ----------------------------------------------------------------------
16
+ @app.get("/")
17
+ def root():
18
+ print("[ROOT] Health check hit")
19
+ return {"status": "crazy good"}
20
+
21
+ @app.get("/health")
22
+ def health():
23
+ print("[HEALTH] Service is healthy")
24
+ return {"status": "healthy"}
25
+
26
+ @app.get("/metadata")
27
+ def metadata():
28
+ print("[METADATA] Requested")
29
+ return {
30
+ "name": "Code Review Professional Workflow",
31
+ "description": (
32
+ "Multi‑turn code review environment for professional‑level bug fixing. "
33
+ "The agent must inspect, test, lint, query documentation, and negotiate with "
34
+ "a simulated (persona‑driven) author to get a fix accepted. "
35
+ "Includes 25 bugs across 5 difficulty levels, AST‑based injection, "
36
+ "a reward‑shaping system (full/core profiles), and curriculum learning. "
37
+ "Designed for RL training (PPO, DPO, or any policy‑gradient method)."
38
+ )
39
+ }
40
+
41
+ @app.get("/schema")
42
+ def schema():
43
+ print("[SCHEMA] Requested")
44
+ return {
45
+ "action": AnyAction.model_json_schema(),
46
+ "observation": Observation.model_json_schema(),
47
+ "state": State.model_json_schema()
48
+ }
49
+
50
+ @app.post("/mcp")
51
+ def mcp():
52
+ print("[MCP] Ping received")
53
+ return {"jsonrpc": "2.0", "result": None}
54
+
55
+ # ----------------------------------------------------------------------
56
+ # Environment endpoints
57
+ # ----------------------------------------------------------------------
58
+ @app.post("/reset")
59
+ def reset(task: str = "easy"):
60
+ try:
61
+ print(f"[RESET] Starting new episode | task={task}")
62
+
63
+ env.set_task(task)
64
+ obs = env.reset()
65
+
66
+ print(f"[RESET DONE] step={env._step_count}")
67
+
68
+ return obs.__dict__
69
+ except Exception as e:
70
+ print(f"[RESET ERROR] {e}")
71
+ raise HTTPException(status_code=400, detail=str(e))
72
+
73
+ @app.post("/step")
74
+ def step(action: dict):
75
+ try:
76
+ print(f"[STEP INPUT] {action}")
77
+
78
+ parsed_action = action_adapter.validate_python(action)
79
+ obs, reward, done, info = env.step(parsed_action)
80
+
81
+ print(f"[STEP OUTPUT] reward={reward.value:.4f} | done={done}")
82
+
83
+ return {
84
+ "observation": obs.__dict__,
85
+ "reward": reward.value,
86
+ "done": done,
87
+ "info": info
88
+ }
89
+ except Exception as e:
90
+ print(f"[STEP ERROR] {e}")
91
+ raise HTTPException(status_code=400, detail=str(e))
92
+
93
+ @app.get("/state")
94
+ def state():
95
+ print("[STATE] Requested")
96
+ return env._get_observation().__dict__
97
+
98
+ # ----------------------------------------------------------------------
99
+ # Main entry point (for local testing)
100
+ # ----------------------------------------------------------------------
101
+ if __name__ == "__main__":
102
+ import uvicorn
103
+ print("[SERVER START] Running on http://0.0.0.0:7860")
104
+ uvicorn.run(app, host="0.0.0.0", port=7860)
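For completeness, a minimal HTTP client for the `/reset` and `/step` routes above might look like the sketch below. The response keys match the handlers shown here, but the action payload is an assumption; the real schema comes from `models.py` (or from the `/schema` endpoint).

```python
# Minimal HTTP driver for the server above (assumes it runs on localhost:7860).
import requests

BASE = "http://localhost:7860"

obs = requests.post(f"{BASE}/reset", params={"task": "easy"}).json()
print(obs["bug_description"])

# The body below is a guess at one action encoding; check GET /schema first.
result = requests.post(f"{BASE}/step", json={"action_type": "run_tests"}).json()
print(result["reward"], result["done"], result["info"]["test_score"])
```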
environment.py CHANGED
@@ -1,628 +1,628 @@
1
- # environment.py – FULLY CORRECTED RL Environment (TRUE Markov + Fixed Bugs)
2
-
3
- import sys
4
- import subprocess
5
- import tempfile
6
- import os
7
- import re
8
- from dataclasses import dataclass, field
9
- from typing import Tuple, Dict, Any, Optional, List
10
-
11
- from models import (
12
- AnyAction, WriteComment, ProposeFix, Execute, Inspect,
13
- RunLinter, RunTests, QueryDocs, Skip, Done, AskQuestion,
14
- Observation, Reward, State
15
- )
16
- from redteam import RedTeam
17
- from test_runner import TestRunner
18
- from author import PersonaAuthor
19
- from rltool import ToolBox
20
- from rubrics import (
21
- ToolUsageRubric,
22
- TestDeltaRubric,
23
- LintDeltaRubric,
24
- TerminalSuccessRubric,
25
- ExplorationRubric,
26
- AntiHackingRubric,
27
- StepPenaltyRubric,
28
- )
29
-
30
- # ======================================================================
31
- # FULLY MARKOV OBSERVATION (NOTHING HIDDEN)
32
- # ======================================================================
33
- @dataclass
34
- class EnhancedObservation:
35
- code_snippet: str
36
- last_tool_output: str
37
-
38
- current_test_score: float
39
- current_lint_score: float
40
- negotiation_score: float
41
-
42
- previous_test_score: float
43
- previous_lint_score: float
44
-
45
- author_confidence: float
46
- author_threshold: float
47
-
48
- step: int
49
- max_steps: int
50
- progress_ratio: float
51
-
52
- tests_run: bool
53
- linter_run: bool
54
- docs_queried: bool
55
-
56
- last_action_type: str
57
- action_history: List[str]
58
-
59
- done: bool
60
-
61
- bug_description: str
62
- comments_count: int
63
-
64
- # default fields must be at the very end
65
- author_response: str = ""
66
-
67
- # ======================================================================
68
- # HELPER FUNCTIONS
69
- # ======================================================================
70
- def execute_code(code: str, timeout_sec: int = 5) -> Tuple[bool, str, str]:
71
- if not code.strip():
72
- return False, "", "Error: Empty code"
73
-
74
- with tempfile.NamedTemporaryFile(mode='w', suffix='.py', delete=False, encoding='utf-8') as f:
75
- f.write(code)
76
- tmp_path = f.name
77
-
78
- try:
79
- result = subprocess.run(
80
- [sys.executable, tmp_path],
81
- capture_output=True,
82
- text=True,
83
- timeout=timeout_sec
84
- )
85
- success = (result.returncode == 0)
86
- return success, result.stdout, result.stderr
87
- except subprocess.TimeoutExpired:
88
- return False, "", f"Timeout after {timeout_sec}s"
89
- except Exception as e:
90
- return False, "", f"Execution error: {str(e)}"
91
- finally:
92
- try:
93
- os.unlink(tmp_path)
94
- except:
95
- pass
96
-
97
-
98
- # ======================================================================
99
- # ENHANCED CODE REVIEW ENVIRONMENT
100
- # ======================================================================
101
- @dataclass
102
- class CodeReviewEnv:
103
- task: str = "easy"
104
- max_steps: int = 10
105
- step_penalty: float = 0.01
106
- reward_profile: str = "full" # "full" or "core"
107
-
108
- # Curriculum learning
109
- auto_difficulty: bool = False
110
- success_threshold: float = 0.7
111
-
112
- # Reward shaping parameters
113
- delta_weight: float = 0.3
114
- tool_usage_bonus: float = 0.05
115
- diversity_bonus: float = 0.03
116
-
117
- _red_team: Optional[RedTeam] = field(init=False, default=None)
118
- _author: Optional[PersonaAuthor] = field(init=False, default=None)
119
-
120
- _current_code: str = field(init=False, default="")
121
- _current_bug_id: str = field(init=False, default="")
122
- _bug_description: str = field(init=False, default="")
123
- _oracle_fix: str = field(init=False, default="")
124
-
125
- _comments: list = field(init=False, default_factory=list)
126
- _test_results: Optional[str] = field(init=False, default=None)
127
- _lint_results: Optional[str] = field(init=False, default=None)
128
- _doc_results: Optional[str] = field(init=False, default=None)
129
-
130
- _step_count: int = field(init=False, default=0)
131
- _done: bool = field(init=False, default=False)
132
-
133
- # State tracking for dense rewards
134
- _previous_test_score: float = field(init=False, default=0.0)
135
- _previous_lint_score: float = field(init=False, default=0.0)
136
- _current_test_score: float = field(init=False, default=0.0)
137
- _current_lint_score: float = field(init=False, default=0.0)
138
-
139
- # Tool usage tracking
140
- _tests_run: bool = field(init=False, default=False)
141
- _linter_run: bool = field(init=False, default=False)
142
- _docs_queried: bool = field(init=False, default=False)
143
-
144
- # Action history
145
- _action_history: List[str] = field(init=False, default_factory=list)
146
- _last_action_type: str = field(init=False, default="none")
147
- _last_author_response: str = field(init=False, default="")
148
-
149
- # FIXED: Track CUMULATIVE episode reward
150
- _episode_total_reward: float = field(init=False, default=0.0)
151
- _episode_rewards: List[float] = field(init=False, default_factory=list)
152
- _difficulty_level: int = field(init=False, default=0)
153
-
154
- # Bug-id bridge:
155
- # RedTeam has fine-grained IDs, while TestRunner currently expects a
156
- # smaller canonical set. Keep this mapping here so both modules can evolve
157
- # independently without breaking evaluation.
158
- _BUG_ID_CANONICAL_MAP = {
159
- # Easy-family
160
- "simple_typo": "null_check",
161
- "default_value": "null_check",
162
- "empty_return": "null_check",
163
- "string_index": "off_by_one",
164
-
165
- # Medium-family
166
- "loop_skip": "off_by_one",
167
- "sign_error": "wrong_operator",
168
- "swap_args": "wrong_operator",
169
- "uninitialised_var": "null_check",
170
-
171
- # Hard-family
172
- "division_by_zero_empty": "division_by_zero",
173
- "division_by_zero_zero": "division_by_zero",
174
- "float_precision": "division_by_zero",
175
- "abs_usage": "division_by_zero",
176
- "round_error": "division_by_zero",
177
- }
178
-
179
- # ===================================================================
180
- def __post_init__(self):
181
- self.set_task(self.task)
182
-
183
- # ===================================================================
184
- def _build_rubrics(self):
185
- """
186
- Build rubric stack from a named reward profile.
187
- - full: richer shaping for exploration/tool-use behavior
188
- - core: minimal stable signal for quick ablations/baselines
189
- """
190
- core_rubrics = [
191
- TestDeltaRubric(weight=self.delta_weight),
192
- LintDeltaRubric(weight=self.delta_weight),
193
- TerminalSuccessRubric(),
194
- StepPenaltyRubric(penalty=self.step_penalty),
195
- ]
196
- if self.reward_profile == "core":
197
- return core_rubrics
198
- if self.reward_profile == "full":
199
- return [
200
- *core_rubrics[:-1], # step penalty appended at end for consistent ordering
201
- ToolUsageRubric(bonus=self.tool_usage_bonus),
202
- ExplorationRubric(penalty=-0.05, bonus=self.diversity_bonus * 0.7),
203
- AntiHackingRubric(),
204
- core_rubrics[-1],
205
- ]
206
- raise ValueError(f"Unknown reward_profile: {self.reward_profile}")
207
-
208
- # ===================================================================
209
- def set_task(self, task: str):
210
- if task not in ["easy", "medium", "hard", "harder", "hardest"]:
211
- raise ValueError(f"Unknown task: {task}")
212
-
213
- self.task = task
214
- # Use stochastic bug sampling across episodes; fixed seed here would
215
- # repeatedly select the same bug and weaken training diversity.
216
- self._red_team = RedTeam(task, seed=None)
217
- self._author = PersonaAuthor()
218
- self.rubrics = self._build_rubrics()
219
-
220
- task_to_level = {
221
- "easy": 0, "medium": 1, "hard": 2,
222
- "harder": 3, "hardest": 4
223
- }
224
- self._difficulty_level = task_to_level[task]
225
-
226
- self._reset_internal()
227
-
228
- # ===================================================================
229
- def _reset_internal(self):
230
- self._step_count = 0 # ← FIXED
231
- self._comments = []
232
- self._test_results = None
233
- self._lint_results = None
234
- self._doc_results = None
235
- self._done = False
236
-
237
- # Reset state tracking
238
- self._previous_test_score = 0.0
239
- self._previous_lint_score = 0.0
240
- self._current_test_score = 0.0
241
- self._current_lint_score = 0.0
242
-
243
- self._tests_run = False
244
- self._linter_run = False
245
- self._docs_queried = False
246
-
247
- self._action_history = []
248
- self._last_action_type = "none"
249
- self._last_author_response = ""
250
-
251
- # FIXED: Reset episode cumulative reward
252
- self._episode_total_reward = 0.0
253
-
254
- self._author.reset()
255
-
256
- # Base tasks
257
- if self.task == "easy":
258
- original = "def get_user(id):\n if id in users:\n return users[id]"
259
- elif self.task == "medium":
260
- original = "def process_items(items):\n for item in items:\n print(item)"
261
- elif self.task == "hard":
262
- original = "def average(data):\n if not data:\n return 0\n return sum(data) / len(data)"
263
- elif self.task == "harder":
264
- original = "counter = 0\ndef increment():\n global counter\n with lock:\n counter += 1"
265
- else:
266
- original = "def safe_work():\n with lock1:\n with lock2:\n do_work()"
267
-
268
- buggy_code, bug_id, desc, oracle = self._red_team.inject_bug(original)
269
- self._current_code = buggy_code
270
- self._current_bug_id = bug_id
271
- self._bug_description = desc
272
- self._oracle_fix = oracle
273
- self._comments.append(f"[RedTeam] {desc}")
274
-
275
- # ===================================================================
276
- def reset(self) -> EnhancedObservation:
277
- """Reset with optional curriculum adjustment."""
278
- if self.auto_difficulty and len(self._episode_rewards) > 0:
279
- recent_performance = sum(self._episode_rewards[-5:]) / min(5, len(self._episode_rewards))
280
-
281
- if recent_performance > self.success_threshold and self._difficulty_level < 4:
282
- self._difficulty_level += 1
283
- print(f"[Curriculum] Increasing difficulty to level {self._difficulty_level}")
284
- elif recent_performance < 0.3 and self._difficulty_level > 0:
285
- self._difficulty_level -= 1
286
- print(f"[Curriculum] Decreasing difficulty to level {self._difficulty_level}")
287
-
288
- level_to_task = {0: "easy", 1: "medium", 2: "hard", 3: "harder", 4: "hardest"}
289
- self.task = level_to_task[self._difficulty_level]
290
- # Keep curriculum stochastic for better coverage within each level.
291
- self._red_team = RedTeam(self.task, seed=None)
292
-
293
- self._reset_internal()
294
- return self._get_observation()
295
-
296
- # ===================================================================
297
- def _get_observation(self) -> EnhancedObservation:
298
- """Return COMPLETE Markov state."""
299
- # Keep the author's message separate from tool output.
300
- # Using `_test_results` here can leak unrelated outputs (tests/linter/docs)
301
- # and gives the policy a noisy signal for dialogue actions.
302
- if self._last_action_type in ("comment", "question", "fix"):
303
- author_response = self._last_author_response
304
- else:
305
- author_response = ""
306
-
307
- return EnhancedObservation(
308
- code_snippet=self._current_code,
309
- last_tool_output=self._test_results or "",
310
- author_response=author_response, # ← now field exists
311
-
312
- current_test_score=self._current_test_score,
313
- current_lint_score=self._current_lint_score,
314
- negotiation_score=self._author.get_negotiation_score(),
315
-
316
- previous_test_score=self._previous_test_score,
317
- previous_lint_score=self._previous_lint_score,
318
-
319
- author_confidence=self._author._confidence,
320
- author_threshold=self._author.thresholds.get(self._author.personality, 0.5),
321
-
322
- step=self._step_count,
323
- max_steps=self.max_steps,
324
- # Guard against accidental `max_steps=0` configs.
325
- progress_ratio=(self._step_count / self.max_steps) if self.max_steps > 0 else 1.0,
326
-
327
- tests_run=self._tests_run,
328
- linter_run=self._linter_run,
329
- docs_queried=self._docs_queried,
330
-
331
- last_action_type=self._last_action_type,
332
- action_history=self._action_history[-5:],
333
-
334
- done=self._done,
335
-
336
- bug_description=self._bug_description,
337
- comments_count=len(self._comments),
338
- )
339
-
340
- # ===================================================================
341
- def _get_action_type(self, action: AnyAction) -> str:
342
- """Extract action type as string."""
343
- if isinstance(action, RunTests):
344
- return "run_tests"
345
- elif isinstance(action, RunLinter):
346
- return "run_linter"
347
- elif isinstance(action, QueryDocs):
348
- return "query_docs"
349
- elif isinstance(action, Execute):
350
- return "execute"
351
- elif isinstance(action, Inspect):
352
- return "inspect"
353
- elif isinstance(action, WriteComment):
354
- return "comment"
355
- elif isinstance(action, AskQuestion):
356
- return "question"
357
- elif isinstance(action, ProposeFix):
358
- return "fix"
359
- elif isinstance(action, Done):
360
- return "done"
361
- elif isinstance(action, Skip):
362
- return "skip"
363
- else:
364
- return "unknown"
365
-
366
- # ===================================================================
367
- def _get_test_runner_bug_id(self) -> str:
368
- """
369
- Normalize RedTeam bug ids to the canonical ids understood by TestRunner.
370
- Falls back to the original id for known direct matches.
371
- """
372
- return self._BUG_ID_CANONICAL_MAP.get(self._current_bug_id, self._current_bug_id)
373
-
374
- # ===================================================================
375
- def step(self, action: AnyAction) -> Tuple[EnhancedObservation, Reward, bool, Dict[str, Any]]:
376
- """
377
- TRUE RL STEP with:
378
- - Complete Markov observations (no hidden state)
379
- - Dense intermediate rewards
380
- - Delta-based credit assignment (no double-counting)
381
- - Proper episode reward tracking
382
- """
383
- if self._done:
384
- raise RuntimeError("Episode already finished")
385
-
386
- # Store previous metrics for delta computation
387
- self._previous_test_score = self._current_test_score
388
- self._previous_lint_score = self._current_lint_score
389
- # Snapshot tool-usage flags BEFORE action mutates them.
390
- # Rubrics use these to detect true "first-use" behavior.
391
- prev_tests_run = self._tests_run
392
- prev_linter_run = self._linter_run
393
- prev_docs_queried = self._docs_queried
394
-
395
- base_reward = 0.0
396
- action_type = self._get_action_type(action)
397
-
398
- # Update action history
399
- self._action_history.append(action_type)
400
- self._last_action_type = action_type
401
-
402
- # ==============================================================
403
- # TOOL ACTIONS
404
- # ==============================================================
405
- if isinstance(action, Execute):
406
- success, stdout, stderr = execute_code(self._current_code)
407
- output = (stdout + stderr).strip() or "No output"
408
- self._test_results = f"[Execute] {'Success' if success else 'Failed'}\n{output[:300]}"
409
- base_reward = 0.001 if success else -0.05
410
-
411
- elif isinstance(action, Inspect):
412
- self._test_results = f"[Inspect]\n{self._current_code[:500]}"
413
- base_reward = 0.001
414
-
415
- elif isinstance(action, RunLinter):
416
- lint_output = ToolBox.run_linter(self._current_code)
417
- self._lint_results = lint_output[:500]
418
- self._test_results = f"[Linter]\n{self._lint_results}"
419
-
420
- self._current_lint_score = self._run_linter_score(self._current_code)
421
- self._linter_run = True
422
- base_reward = 0.002
423
-
424
- elif isinstance(action, RunTests):
425
- runner = TestRunner(self._get_test_runner_bug_id())
426
- score, output = runner.run_tests(self._current_code)
427
-
428
- self._current_test_score = score
429
- self._tests_run = True
430
-
431
- self._test_results = f"[Tests] Score: {score:.2f}\n{output[:300]}"
432
- base_reward = 0.002
433
-
434
- if score > 0.8:
435
- base_reward += 0.005
436
-
437
- elif isinstance(action, QueryDocs):
438
- # Normalize query to avoid rewarding empty/noisy requests.
439
- query_topic = (action.query_topic or "").strip()
440
- doc = ToolBox.query_docs(query_topic if query_topic else "general bug fixing")
441
- self._doc_results = doc
442
- self._test_results = f"[Docs]\n{doc[:400]}"
443
- self._docs_queried = True
444
- base_reward = 0.001
445
-
446
- # ==============================================================
447
- # COMMUNICATION ACTIONS
448
- # ==============================================================
449
- elif isinstance(action, WriteComment):
450
- self._comments.append(f"Agent: {action.comment_text}")
451
-
452
- response = self._author.respond(
453
- agent_comment=action.comment_text,
454
- test_results=self._test_results,
455
- lint_results=self._lint_results,
456
- doc_results=self._doc_results,
457
- proposed_fix=None,
458
- original_code=self._current_code
459
- )
460
-
461
- self._comments.append(f"Author: {response}")
462
- self._last_author_response = response
463
- self._test_results = f"[Comment] Author: {response[:200]}"
464
- base_reward = 0.001
465
-
466
- elif isinstance(action, AskQuestion):
467
- self._comments.append(f"Agent: {action.question}")
468
-
469
- response = self._author.respond(
470
- agent_question=action.question,
471
- test_results=self._test_results,
472
- lint_results=self._lint_results,
473
- doc_results=self._doc_results,
474
- proposed_fix=None,
475
- original_code=self._current_code # ← FIXED
476
- )
477
-
478
- self._comments.append(f"Author: {response}")
479
- self._last_author_response = response
480
- self._test_results = f"[Question] Author: {response[:200]}"
481
- base_reward = 0.002
482
-
483
- # ==============================================================
484
- # FINAL FIX ACTION
485
- # ==============================================================
486
- elif isinstance(action, ProposeFix):
487
- if not action.fix_code:
488
- base_reward = -0.05
489
- self._done = True
490
- else:
491
- # Save original code BEFORE overwriting (for author.respond)
492
- original_buggy = self._current_code
493
- self._current_code = action.fix_code
494
-
495
- runner = TestRunner(self._get_test_runner_bug_id())
496
- test_score, test_output = runner.run_tests(self._current_code)
497
- lint_score = self._run_linter_score(self._current_code)
498
- negotiation_score = self._author.get_negotiation_score()
499
-
500
- self._current_test_score = test_score
501
- self._current_lint_score = lint_score
502
-
503
- # Author gating – determines if the episode ends, reward is separate
504
- threshold = self._author.thresholds.get(self._author.personality, 0.5)
505
- if self._author._confidence < threshold:
506
- if self._step_count < self.max_steps:
507
- self._done = False
508
- else:
509
- self._done = True
510
- else:
511
- self._done = True
512
-
513
- # Get author's verbal feedback (pushback/acceptance)
514
- author_feedback = self._author.respond(
515
- agent_comment=f"Proposed fix:\n{action.fix_code}",
516
- test_results=f"Score: {test_score:.2f}",
517
- lint_results=f"Score: {lint_score:.2f}",
518
- doc_results=self._doc_results,
519
- proposed_fix=action.fix_code,
520
- original_code=original_buggy # now correctly the buggy code, not the fix
521
- )
522
- self._test_results = f"[Fix] Author: {author_feedback[:200]}"
523
- self._comments.append(f"Author: {author_feedback}")
524
- self._last_author_response = author_feedback
525
-
526
- base_reward = 0.001 # rubrics provide the real signal
527
-
528
- # ==============================================================
529
- # TERMINATION ACTIONS
530
- # ==============================================================
531
- elif isinstance(action, Skip):
532
- base_reward = -0.03
533
- self._done = True
534
-
535
- elif isinstance(action, Done):
536
- if self._tests_run:
537
- base_reward = self._current_test_score * 0.5 - 0.2
538
- else:
539
- base_reward = -0.04
540
- self._done = True
541
-
542
- else:
543
- base_reward = -0.02
544
- self._done = True
545
-
546
- # ==============================================================
547
- # STEP UPDATE (before rubric computation so info contains final step)
548
- # ==============================================================
549
- self._step_count += 1
550
- if self._step_count >= self.max_steps:
551
- self._done = True
552
-
553
- # Get fresh observation (needed for rubrics that may read obs)
554
- obs = self._get_observation()
555
-
556
- # Prepare info dict (rubrics may need action_type and deltas)
557
- info = {
558
- "action_type": action_type,
559
- "test_score": self._current_test_score,
560
- "lint_score": self._current_lint_score,
561
- "test_delta": self._current_test_score - self._previous_test_score,
562
- "lint_delta": self._current_lint_score - self._previous_lint_score,
563
- "prev_tests_run": prev_tests_run,
564
- "prev_linter_run": prev_linter_run,
565
- "prev_docs_queried": prev_docs_queried,
566
- "docs_query_len": len((action.query_topic or "").strip()) if isinstance(action, QueryDocs) else 0,
567
- "base_reward": base_reward,
568
- }
569
-
570
- # ==============================================================
571
- # COMPUTE FINAL REWARD USING RUBRICS
572
- # ==============================================================
573
- rubric_score = sum(r(self, action, obs, None, self._done, info) for r in self.rubrics)
574
- final_reward = 0.4 * base_reward + rubric_score
575
- final_reward = max(-1.0, min(1.0, final_reward)) # safety clip
576
-
577
- # Track cumulative episode reward
578
- self._episode_total_reward += final_reward
579
-
580
- # Store episode total if done
581
- if self._done:
582
- self._episode_rewards.append(self._episode_total_reward)
583
-
584
- # Complete info
585
- info["final_reward"] = final_reward
586
- info["episode_total"] = self._episode_total_reward
587
-
588
- return obs, Reward(value=final_reward), self._done, info
589
-
590
- # ===================================================================
591
- def _run_linter_score(self, code: str) -> float:
592
- """Run pylint and return normalized score [0, 1]."""
593
- try:
594
- with tempfile.NamedTemporaryFile(mode='w', suffix='.py', delete=False) as f:
595
- f.write(code)
596
- tmp_path = f.name
597
-
598
  result = subprocess.run(
599
- ['pylint', tmp_path, '--score=y', '--exit-zero'],
600
  capture_output=True,
601
  text=True,
602
  timeout=5
603
- )
604
-
605
- match = re.search(r"rated at (\d+\.\d+)/10", result.stdout)
606
- if match:
607
- return float(match.group(1)) / 10.0
608
- return 0.0
609
- except:
610
- return 0.0
611
- finally:
612
- try:
613
- os.unlink(tmp_path)
614
- except:
615
- pass
616
-
617
- # ===================================================================
618
- def state(self) -> State:
619
- """Legacy compatibility."""
620
- return State(
621
- pr_title="Code Review",
622
- pr_description=self._bug_description,
623
- code_snippet=self._current_code,
624
- comments=self._comments.copy(),
625
- test_results=self._test_results,
626
- step=self._step_count,
627
- done=self._done
628
- )
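The rubric classes themselves live in rubrics.py (changed in this commit but not expanded on this page). From the call site in `step()`, `sum(r(self, action, obs, None, self._done, info) for r in self.rubrics)`, a custom rubric only needs to be a callable with that signature; a minimal illustrative one, assuming that inferred interface, could be:

```python
# Illustrative rubric matching the call site in CodeReviewEnv.step();
# the real TestDeltaRubric and friends are defined in rubrics.py.
class TestDeltaOnlyRubric:
    def __init__(self, weight: float = 0.3):
        self.weight = weight

    def __call__(self, env, action, obs, next_obs, done, info) -> float:
        # Reward only the per-step improvement in test score.
        return self.weight * info.get("test_delta", 0.0)
```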
 
1
+ # environment.py – FULLY CORRECTED RL Environment (TRUE Markov + Fixed Bugs)
2
+
3
+ import sys
4
+ import subprocess
5
+ import tempfile
6
+ import os
7
+ import re
8
+ from dataclasses import dataclass, field
9
+ from typing import Tuple, Dict, Any, Optional, List
10
+
11
+ from models import (
12
+ AnyAction, WriteComment, ProposeFix, Execute, Inspect,
13
+ RunLinter, RunTests, QueryDocs, Skip, Done, AskQuestion,
14
+ Observation, Reward, State
15
+ )
16
+ from redteam import RedTeam
17
+ from test_runner import TestRunner
18
+ from author import PersonaAuthor
19
+ from rltool import ToolBox
20
+ from rubrics import (
21
+ ToolUsageRubric,
22
+ TestDeltaRubric,
23
+ LintDeltaRubric,
24
+ TerminalSuccessRubric,
25
+ ExplorationRubric,
26
+ AntiHackingRubric,
27
+ StepPenaltyRubric,
28
+ )
29
+
30
+ # ======================================================================
31
+ # FULLY MARKOV OBSERVATION (NOTHING HIDDEN)
32
+ # ======================================================================
33
+ @dataclass
34
+ class EnhancedObservation:
35
+ code_snippet: str
36
+ last_tool_output: str
37
+
38
+ current_test_score: float
39
+ current_lint_score: float
40
+ negotiation_score: float
41
+
42
+ previous_test_score: float
43
+ previous_lint_score: float
44
+
45
+ author_confidence: float
46
+ author_threshold: float
47
+
48
+ step: int
49
+ max_steps: int
50
+ progress_ratio: float
51
+
52
+ tests_run: bool
53
+ linter_run: bool
54
+ docs_queried: bool
55
+
56
+ last_action_type: str
57
+ action_history: List[str]
58
+
59
+ done: bool
60
+
61
+ bug_description: str
62
+ comments_count: int
63
+
64
+ # default fields must be at the very end
65
+ author_response: str = ""
66
+
67
+ # ======================================================================
68
+ # HELPER FUNCTIONS
69
+ # ======================================================================
70
+ def execute_code(code: str, timeout_sec: int = 5) -> Tuple[bool, str, str]:
71
+ if not code.strip():
72
+ return False, "", "Error: Empty code"
73
+
74
+ with tempfile.NamedTemporaryFile(mode='w', suffix='.py', delete=False, encoding='utf-8') as f:
75
+ f.write(code)
76
+ tmp_path = f.name
77
+
78
+ try:
79
+ result = subprocess.run(
80
+ [sys.executable, tmp_path],
81
+ capture_output=True,
82
+ text=True,
83
+ timeout=timeout_sec
84
+ )
85
+ success = (result.returncode == 0)
86
+ return success, result.stdout, result.stderr
87
+ except subprocess.TimeoutExpired:
88
+ return False, "", f"Timeout after {timeout_sec}s"
89
+ except Exception as e:
90
+ return False, "", f"Execution error: {str(e)}"
91
+ finally:
92
+ try:
93
+ os.unlink(tmp_path)
94
+ except:
95
+ pass
96
+
97
+
98
+ # ======================================================================
99
+ # ENHANCED CODE REVIEW ENVIRONMENT
100
+ # ======================================================================
101
+ @dataclass
102
+ class CodeReviewEnv:
103
+ task: str = "easy"
104
+ max_steps: int = 10
105
+ step_penalty: float = 0.01
106
+ reward_profile: str = "full" # "full" or "core"
107
+
108
+ # Curriculum learning
109
+ auto_difficulty: bool = False
110
+ success_threshold: float = 0.7
111
+
112
+ # Reward shaping parameters
113
+ delta_weight: float = 0.3
114
+ tool_usage_bonus: float = 0.05
115
+ diversity_bonus: float = 0.03
116
+
117
+ _red_team: Optional[RedTeam] = field(init=False, default=None)
118
+ _author: Optional[PersonaAuthor] = field(init=False, default=None)
119
+
120
+ _current_code: str = field(init=False, default="")
121
+ _current_bug_id: str = field(init=False, default="")
122
+ _bug_description: str = field(init=False, default="")
123
+ _oracle_fix: str = field(init=False, default="")
124
+
125
+ _comments: list = field(init=False, default_factory=list)
126
+ _test_results: Optional[str] = field(init=False, default=None)
127
+ _lint_results: Optional[str] = field(init=False, default=None)
128
+ _doc_results: Optional[str] = field(init=False, default=None)
129
+
130
+ _step_count: int = field(init=False, default=0)
131
+ _done: bool = field(init=False, default=False)
132
+
133
+ # State tracking for dense rewards
134
+ _previous_test_score: float = field(init=False, default=0.0)
135
+ _previous_lint_score: float = field(init=False, default=0.0)
136
+ _current_test_score: float = field(init=False, default=0.0)
137
+ _current_lint_score: float = field(init=False, default=0.0)
138
+
139
+ # Tool usage tracking
140
+ _tests_run: bool = field(init=False, default=False)
141
+ _linter_run: bool = field(init=False, default=False)
142
+ _docs_queried: bool = field(init=False, default=False)
143
+
144
+ # Action history
145
+ _action_history: List[str] = field(init=False, default_factory=list)
146
+ _last_action_type: str = field(init=False, default="none")
147
+ _last_author_response: str = field(init=False, default="")
148
+
149
+ # FIXED: Track CUMULATIVE episode reward
150
+ _episode_total_reward: float = field(init=False, default=0.0)
151
+ _episode_rewards: List[float] = field(init=False, default_factory=list)
152
+ _difficulty_level: int = field(init=False, default=0)
153
+
154
+ # Bug-id bridge:
155
+ # RedTeam has fine-grained IDs, while TestRunner currently expects a
156
+ # smaller canonical set. Keep this mapping here so both modules can evolve
157
+ # independently without breaking evaluation.
158
+ _BUG_ID_CANONICAL_MAP = {
159
+ # Easy-family
160
+ "simple_typo": "null_check",
161
+ "default_value": "null_check",
162
+ "empty_return": "null_check",
163
+ "string_index": "off_by_one",
164
+
165
+ # Medium-family
166
+ "loop_skip": "off_by_one",
167
+ "sign_error": "wrong_operator",
168
+ "swap_args": "wrong_operator",
169
+ "uninitialised_var": "null_check",
170
+
171
+ # Hard-family
172
+ "division_by_zero_empty": "division_by_zero",
173
+ "division_by_zero_zero": "division_by_zero",
174
+ "float_precision": "division_by_zero",
175
+ "abs_usage": "division_by_zero",
176
+ "round_error": "division_by_zero",
177
+ }
178
+
179
+ # ===================================================================
180
+ def __post_init__(self):
181
+ self.set_task(self.task)
182
+
183
+ # ===================================================================
184
+ def _build_rubrics(self):
185
+ """
186
+ Build rubric stack from a named reward profile.
187
+ - full: richer shaping for exploration/tool-use behavior
188
+ - core: minimal stable signal for quick ablations/baselines
189
+ """
190
+ core_rubrics = [
191
+ TestDeltaRubric(weight=self.delta_weight),
192
+ LintDeltaRubric(weight=self.delta_weight),
193
+ TerminalSuccessRubric(),
194
+ StepPenaltyRubric(penalty=self.step_penalty),
195
+ ]
196
+ if self.reward_profile == "core":
197
+ return core_rubrics
198
+ if self.reward_profile == "full":
199
+ return [
200
+ *core_rubrics[:-1], # step penalty appended at end for consistent ordering
201
+ ToolUsageRubric(bonus=self.tool_usage_bonus),
202
+ ExplorationRubric(penalty=-0.05, bonus=self.diversity_bonus * 0.7),
203
+ AntiHackingRubric(),
204
+ core_rubrics[-1],
205
+ ]
206
+ raise ValueError(f"Unknown reward_profile: {self.reward_profile}")
207
+
208
+ # ===================================================================
209
+ def set_task(self, task: str):
210
+ if task not in ["easy", "medium", "hard", "harder", "hardest"]:
211
+ raise ValueError(f"Unknown task: {task}")
212
+
213
+ self.task = task
214
+ # Use stochastic bug sampling across episodes; fixed seed here would
215
+ # repeatedly select the same bug and weaken training diversity.
216
+ self._red_team = RedTeam(task, seed=None)
217
+ self._author = PersonaAuthor()
218
+ self.rubrics = self._build_rubrics()
219
+
220
+ task_to_level = {
221
+ "easy": 0, "medium": 1, "hard": 2,
222
+ "harder": 3, "hardest": 4
223
+ }
224
+ self._difficulty_level = task_to_level[task]
225
+
226
+ self._reset_internal()
227
+
228
+ # ===================================================================
229
+ def _reset_internal(self):
230
+ self._step_count = 0 # ← FIXED
231
+ self._comments = []
232
+ self._test_results = None
233
+ self._lint_results = None
234
+ self._doc_results = None
235
+ self._done = False
236
+
237
+ # Reset state tracking
238
+ self._previous_test_score = 0.0
239
+ self._previous_lint_score = 0.0
240
+ self._current_test_score = 0.0
241
+ self._current_lint_score = 0.0
242
+
243
+ self._tests_run = False
244
+ self._linter_run = False
245
+ self._docs_queried = False
246
+
247
+ self._action_history = []
248
+ self._last_action_type = "none"
249
+ self._last_author_response = ""
250
+
251
+ # FIXED: Reset episode cumulative reward
252
+ self._episode_total_reward = 0.0
253
+
254
+ self._author.reset()
255
+
256
+ # Base tasks
257
+ if self.task == "easy":
258
+ original = "def get_user(id):\n if id in users:\n return users[id]"
259
+ elif self.task == "medium":
260
+ original = "def process_items(items):\n for item in items:\n print(item)"
261
+ elif self.task == "hard":
262
+ original = "def average(data):\n if not data:\n return 0\n return sum(data) / len(data)"
263
+ elif self.task == "harder":
264
+ original = "counter = 0\ndef increment():\n global counter\n with lock:\n counter += 1"
265
+ else:
266
+ original = "def safe_work():\n with lock1:\n with lock2:\n do_work()"
267
+
268
+ buggy_code, bug_id, desc, oracle = self._red_team.inject_bug(original)
269
+ self._current_code = buggy_code
270
+ self._current_bug_id = bug_id
271
+ self._bug_description = desc
272
+ self._oracle_fix = oracle
273
+ self._comments.append(f"[RedTeam] {desc}")
274
+
275
+ # ===================================================================
276
+ def reset(self) -> EnhancedObservation:
277
+ """Reset with optional curriculum adjustment."""
278
+ if self.auto_difficulty and len(self._episode_rewards) > 0:
279
+ recent_performance = sum(self._episode_rewards[-5:]) / min(5, len(self._episode_rewards))
280
+
281
+ if recent_performance > self.success_threshold and self._difficulty_level < 4:
282
+ self._difficulty_level += 1
283
+ print(f"[Curriculum] Increasing difficulty to level {self._difficulty_level}")
284
+ elif recent_performance < 0.3 and self._difficulty_level > 0:
285
+ self._difficulty_level -= 1
286
+ print(f"[Curriculum] Decreasing difficulty to level {self._difficulty_level}")
287
+
288
+ level_to_task = {0: "easy", 1: "medium", 2: "hard", 3: "harder", 4: "hardest"}
289
+ self.task = level_to_task[self._difficulty_level]
290
+ # Keep curriculum stochastic for better coverage within each level.
291
+ self._red_team = RedTeam(self.task, seed=None)
292
+
293
+ self._reset_internal()
294
+ return self._get_observation()
295
+
296
+ # ===================================================================
297
+ def _get_observation(self) -> EnhancedObservation:
298
+ """Return COMPLETE Markov state."""
299
+ # Keep the author's message separate from tool output.
300
+ # Using `_test_results` here can leak unrelated outputs (tests/linter/docs)
301
+ # and gives the policy a noisy signal for dialogue actions.
302
+ if self._last_action_type in ("comment", "question", "fix"):
303
+ author_response = self._last_author_response
304
+ else:
305
+ author_response = ""
306
+
307
+ return EnhancedObservation(
308
+ code_snippet=self._current_code,
309
+ last_tool_output=self._test_results or "",
310
+ author_response=author_response, # ← now field exists
311
+
312
+ current_test_score=self._current_test_score,
313
+ current_lint_score=self._current_lint_score,
314
+ negotiation_score=self._author.get_negotiation_score(),
315
+
316
+ previous_test_score=self._previous_test_score,
317
+ previous_lint_score=self._previous_lint_score,
318
+
319
+ author_confidence=self._author._confidence,
320
+ author_threshold=self._author.thresholds.get(self._author.personality, 0.5),
321
+
322
+ step=self._step_count,
323
+ max_steps=self.max_steps,
324
+ # Guard against accidental `max_steps=0` configs.
325
+ progress_ratio=(self._step_count / self.max_steps) if self.max_steps > 0 else 1.0,
326
+
327
+ tests_run=self._tests_run,
328
+ linter_run=self._linter_run,
329
+ docs_queried=self._docs_queried,
330
+
331
+ last_action_type=self._last_action_type,
332
+ action_history=self._action_history[-5:],
333
+
334
+ done=self._done,
335
+
336
+ bug_description=self._bug_description,
337
+ comments_count=len(self._comments),
338
+ )
339
+
340
+ # ===================================================================
341
+ def _get_action_type(self, action: AnyAction) -> str:
342
+ """Extract action type as string."""
343
+ if isinstance(action, RunTests):
344
+ return "run_tests"
345
+ elif isinstance(action, RunLinter):
346
+ return "run_linter"
347
+ elif isinstance(action, QueryDocs):
348
+ return "query_docs"
349
+ elif isinstance(action, Execute):
350
+ return "execute"
351
+ elif isinstance(action, Inspect):
352
+ return "inspect"
353
+ elif isinstance(action, WriteComment):
354
+ return "comment"
355
+ elif isinstance(action, AskQuestion):
356
+ return "question"
357
+ elif isinstance(action, ProposeFix):
358
+ return "fix"
359
+ elif isinstance(action, Done):
360
+ return "done"
361
+ elif isinstance(action, Skip):
362
+ return "skip"
363
+ else:
364
+ return "unknown"
365
+
366
+ # ===================================================================
367
+ def _get_test_runner_bug_id(self) -> str:
368
+ """
369
+ Normalize RedTeam bug ids to the canonical ids understood by TestRunner.
370
+ Falls back to the original id for known direct matches.
371
+ """
372
+ return self._BUG_ID_CANONICAL_MAP.get(self._current_bug_id, self._current_bug_id)
373
+
374
+ # ===================================================================
375
+ def step(self, action: AnyAction) -> Tuple[EnhancedObservation, Reward, bool, Dict[str, Any]]:
376
+ """
377
+ TRUE RL STEP with:
378
+ - Complete Markov observations (no hidden state)
379
+ - Dense intermediate rewards
380
+ - Delta-based credit assignment (no double-counting)
381
+ - Proper episode reward tracking
382
+ """
383
+ if self._done:
384
+ raise RuntimeError("Episode already finished")
385
+
386
+ # Store previous metrics for delta computation
387
+ self._previous_test_score = self._current_test_score
388
+ self._previous_lint_score = self._current_lint_score
389
+ # Snapshot tool-usage flags BEFORE action mutates them.
390
+ # Rubrics use these to detect true "first-use" behavior.
391
+ prev_tests_run = self._tests_run
392
+ prev_linter_run = self._linter_run
393
+ prev_docs_queried = self._docs_queried
394
+
395
+ base_reward = 0.0
396
+ action_type = self._get_action_type(action)
397
+
398
+ # Update action history
399
+ self._action_history.append(action_type)
400
+ self._last_action_type = action_type
401
+
402
+ # ==============================================================
403
+ # TOOL ACTIONS
404
+ # ==============================================================
405
+ if isinstance(action, Execute):
406
+ success, stdout, stderr = execute_code(self._current_code)
407
+ output = (stdout + stderr).strip() or "No output"
408
+ self._test_results = f"[Execute] {'Success' if success else 'Failed'}\n{output[:300]}"
409
+ base_reward = 0.001 if success else -0.05
410
+
411
+ elif isinstance(action, Inspect):
412
+ self._test_results = f"[Inspect]\n{self._current_code[:500]}"
413
+ base_reward = 0.001
414
+
415
+ elif isinstance(action, RunLinter):
416
+ lint_output = ToolBox.run_linter(self._current_code)
417
+ self._lint_results = lint_output[:500]
418
+ self._test_results = f"[Linter]\n{self._lint_results}"
419
+
420
+ self._current_lint_score = self._run_linter_score(self._current_code)
421
+ self._linter_run = True
422
+ base_reward = 0.002
423
+
424
+ elif isinstance(action, RunTests):
425
+ runner = TestRunner(self._get_test_runner_bug_id())
426
+ score, output = runner.run_tests(self._current_code)
427
+
428
+ self._current_test_score = score
429
+ self._tests_run = True
430
+
431
+ self._test_results = f"[Tests] Score: {score:.2f}\n{output[:300]}"
432
+ base_reward = 0.002
433
+
434
+ if score > 0.8:
435
+ base_reward += 0.005
436
+
437
+ elif isinstance(action, QueryDocs):
438
+ # Normalize query to avoid rewarding empty/noisy requests.
439
+ query_topic = (action.query_topic or "").strip()
440
+ doc = ToolBox.query_docs(query_topic if query_topic else "general bug fixing")
441
+ self._doc_results = doc
442
+ self._test_results = f"[Docs]\n{doc[:400]}"
443
+ self._docs_queried = True
444
+ base_reward = 0.001
445
+
446
+ # ==============================================================
447
+ # COMMUNICATION ACTIONS
448
+ # ==============================================================
449
+ elif isinstance(action, WriteComment):
450
+ self._comments.append(f"Agent: {action.comment_text}")
451
+
452
+ response = self._author.respond(
453
+ agent_comment=action.comment_text,
454
+ test_results=self._test_results,
455
+ lint_results=self._lint_results,
456
+ doc_results=self._doc_results,
457
+ proposed_fix=None,
458
+ original_code=self._current_code
459
+ )
460
+
461
+ self._comments.append(f"Author: {response}")
462
+ self._last_author_response = response
463
+ self._test_results = f"[Comment] Author: {response[:200]}"
464
+ base_reward = 0.001
465
+
466
+ elif isinstance(action, AskQuestion):
467
+ self._comments.append(f"Agent: {action.question}")
468
+
469
+ response = self._author.respond(
470
+ agent_question=action.question,
471
+ test_results=self._test_results,
472
+ lint_results=self._lint_results,
473
+ doc_results=self._doc_results,
474
+ proposed_fix=None,
475
+ original_code=self._current_code  # the code currently under review
476
+ )
477
+
478
+ self._comments.append(f"Author: {response}")
479
+ self._last_author_response = response
480
+ self._test_results = f"[Question] Author: {response[:200]}"
481
+ base_reward = 0.002
482
+
483
+ # ==============================================================
484
+ # FINAL FIX ACTION
485
+ # ==============================================================
486
+ elif isinstance(action, ProposeFix):
487
+ if not action.fix_code:
488
+ base_reward = -0.05
489
+ self._done = True
490
+ else:
491
+ # Save original code BEFORE overwriting (for author.respond)
492
+ original_buggy = self._current_code
493
+ self._current_code = action.fix_code
494
+
495
+ runner = TestRunner(self._get_test_runner_bug_id())
496
+ test_score, test_output = runner.run_tests(self._current_code)
497
+ lint_score = self._run_linter_score(self._current_code)
498
+ negotiation_score = self._author.get_negotiation_score()
499
+
500
+ self._current_test_score = test_score
501
+ self._current_lint_score = lint_score
502
+
503
+ # Author gating – determines if the episode ends, reward is separate
504
+ threshold = self._author.thresholds.get(self._author.personality, 0.5)
505
+ if self._author._confidence < threshold:
506
+ if self._step_count < self.max_steps:
507
+ self._done = False
508
+ else:
509
+ self._done = True
510
+ else:
511
+ self._done = True
512
+
513
+ # Get author's verbal feedback (pushback/acceptance)
514
+ author_feedback = self._author.respond(
515
+ agent_comment=f"Proposed fix:\n{action.fix_code}",
516
+ test_results=f"Score: {test_score:.2f}",
517
+ lint_results=f"Score: {lint_score:.2f}",
518
+ doc_results=self._doc_results,
519
+ proposed_fix=action.fix_code,
520
+ original_code=original_buggy  # the buggy code prior to the fix, not the proposed fix
521
+ )
522
+ self._test_results = f"[Fix] Author: {author_feedback[:200]}"
523
+ self._comments.append(f"Author: {author_feedback}")
524
+ self._last_author_response = author_feedback
525
+
526
+ base_reward = 0.001 # rubrics provide the real signal
527
+
528
+ # ==============================================================
529
+ # TERMINATION ACTIONS
530
+ # ==============================================================
531
+ elif isinstance(action, Skip):
532
+ base_reward = -0.03
533
+ self._done = True
534
+
535
+ elif isinstance(action, Done):
536
+ if self._tests_run:
537
+ base_reward = self._current_test_score * 0.5 - 0.2
538
+ else:
539
+ base_reward = -0.04
540
+ self._done = True
541
+
542
+ else:
543
+ base_reward = -0.02
544
+ self._done = True
545
+
546
+ # ==============================================================
547
+ # STEP UPDATE (before rubric computation so info contains final step)
548
+ # ==============================================================
549
+ self._step_count += 1
550
+ if self._step_count >= self.max_steps:
551
+ self._done = True
552
+
553
+ # Get fresh observation (needed for rubrics that may read obs)
554
+ obs = self._get_observation()
555
+
556
+ # Prepare info dict (rubrics may need action_type and deltas)
557
+ info = {
558
+ "action_type": action_type,
559
+ "test_score": self._current_test_score,
560
+ "lint_score": self._current_lint_score,
561
+ "test_delta": self._current_test_score - self._previous_test_score,
562
+ "lint_delta": self._current_lint_score - self._previous_lint_score,
563
+ "prev_tests_run": prev_tests_run,
564
+ "prev_linter_run": prev_linter_run,
565
+ "prev_docs_queried": prev_docs_queried,
566
+ "docs_query_len": len((action.query_topic or "").strip()) if isinstance(action, QueryDocs) else 0,
567
+ "base_reward": base_reward,
568
+ }
569
+
570
+ # ==============================================================
571
+ # COMPUTE FINAL REWARD USING RUBRICS
572
+ # ==============================================================
573
+ rubric_score = sum(r(self, action, obs, None, self._done, info) for r in self.rubrics)
574
+ final_reward = 0.4 * base_reward + rubric_score
575
+ final_reward = max(-1.0, min(1.0, final_reward)) # safety clip
576
+
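+ # Illustrative worked example (assumes the rubric set from rubrics.py): a first
+ # run_tests that lifts the test score 0.0 -> 0.5 gives base_reward = 0.002 and
+ # rubric_score = 0.05 + 0.015 + 0.3*0.5 - 0.01 = 0.205, so
+ # final_reward = 0.4*0.002 + 0.205 ≈ 0.206 (well inside the clip range).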
577
+ # Track cumulative episode reward
578
+ self._episode_total_reward += final_reward
579
+
580
+ # Store episode total if done
581
+ if self._done:
582
+ self._episode_rewards.append(self._episode_total_reward)
583
+
584
+ # Complete info
585
+ info["final_reward"] = final_reward
586
+ info["episode_total"] = self._episode_total_reward
587
+
588
+ return obs, Reward(value=final_reward), self._done, info
589
+
590
+ # ===================================================================
591
+ def _run_linter_score(self, code: str) -> float:
592
+ """Run pylint and return normalized score [0, 1]."""
593
+ try:
594
+ with tempfile.NamedTemporaryFile(mode='w', suffix='.py', delete=False) as f:
595
+ f.write(code)
596
+ tmp_path = f.name
597
+
598
  result = subprocess.run(
599
+ [sys.executable, '-m', 'pylint', tmp_path, '--score=y', '--exit-zero'],
600
  capture_output=True,
601
  text=True,
602
  timeout=5
603
+ )
604
+
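+ # pylint's summary line reads "Your code has been rated at 8.50/10"; 8.50 maps to 0.85.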
605
+ match = re.search(r"rated at (\d+\.\d+)/10", result.stdout)
606
+ if match:
607
+ return float(match.group(1)) / 10.0
608
+ return 0.0
609
+ except Exception:
610
+ return 0.0
611
+ finally:
612
+ try:
613
+ os.unlink(tmp_path)
614
+ except Exception:
615
+ pass
616
+
617
+ # ===================================================================
618
+ def state(self) -> State:
619
+ """Legacy compatibility."""
620
+ return State(
621
+ pr_title="Code Review",
622
+ pr_description=self._bug_description,
623
+ code_snippet=self._current_code,
624
+ comments=self._comments.copy(),
625
+ test_results=self._test_results,
626
+ step=self._step_count,
627
+ done=self._done
628
+ )
grader.py CHANGED
@@ -1,11 +1,12 @@
1
  # grader.py – Production‑grade, continuous reward, exploit‑aware, example of monolithic scoring
2
  import ast
3
- import subprocess
4
- import tempfile
5
- import os
6
- import re
7
- from dataclasses import dataclass
8
- from typing import Optional
 
9
 
10
  @dataclass
11
  class RigorousGrader:
@@ -105,11 +106,11 @@ class RigorousGrader:
105
  f.write(code)
106
  f.flush()
107
  tmp_path = f.name
108
- result = subprocess.run(
109
- ['pylint', tmp_path, '--score=y', '--exit-zero'],
110
- capture_output=True,
111
- text=True,
112
- timeout=5
113
  )
114
  match = re.search(r"rated at (\d+\.\d+)/10", result.stdout)
115
  if match:
@@ -139,4 +140,4 @@ class RigorousGrader:
139
  total = max(len(nodes_prop), len(nodes_oracle))
140
  return common / total if total > 0 else 0.0
141
  except:
142
- return 0.0
 
1
  # grader.py – Production‑grade, continuous reward, exploit‑aware, example of monolithic scoring
2
  import ast
3
+ import subprocess
4
+ import tempfile
5
+ import os
6
+ import re
7
+ import sys
8
+ from dataclasses import dataclass
9
+ from typing import Optional
10
 
11
  @dataclass
12
  class RigorousGrader:
 
106
  f.write(code)
107
  f.flush()
108
  tmp_path = f.name
109
+ result = subprocess.run(
110
+ [sys.executable, '-m', 'pylint', tmp_path, '--score=y', '--exit-zero'],
111
+ capture_output=True,
112
+ text=True,
113
+ timeout=5
114
  )
115
  match = re.search(r"rated at (\d+\.\d+)/10", result.stdout)
116
  if match:
 
140
  total = max(len(nodes_prop), len(nodes_oracle))
141
  return common / total if total > 0 else 0.0
142
  except:
143
+ return 0.0
models.py CHANGED
@@ -1,112 +1,112 @@
1
- # models.py – Typed Models (Discriminated Unions, POMDP Separation)
2
- from typing import Literal, Union, Annotated, Optional
3
- from pydantic import BaseModel, Field, TypeAdapter, field_validator
4
-
5
- # ----------------------------------------------------------------------
6
- # Action classes (discriminated union)
7
- # ----------------------------------------------------------------------
8
- class Action(BaseModel):
9
- action_type: Literal["comment", "skip", "done", "question",
10
- "fix", "execute", "inspect", "run_linter",
11
- "run_tests", "query_docs"]
12
-
13
- class WriteComment(Action):
14
- action_type: Literal["comment"] = "comment"
15
- comment_text: str = Field(..., min_length=1)
16
-
17
- class Skip(Action):
18
- action_type: Literal["skip"] = "skip"
19
-
20
- class Done(Action):
21
- action_type: Literal["done"] = "done"
22
-
23
- class AskQuestion(Action):
24
- action_type: Literal["question"] = "question"
25
- question: str = Field(..., min_length=1)
26
-
27
- class ProposeFix(Action):
28
- action_type: Literal["fix"] = "fix"
29
- fix_code: str = Field(..., min_length=1)
30
- @field_validator('fix_code')
31
- @classmethod
32
- def not_empty(cls, v: str) -> str:
33
- if not v.strip():
34
- raise ValueError('fix_code cannot be empty')
35
- return v
36
-
37
- class Execute(Action):
38
- action_type: Literal["execute"] = "execute"
39
-
40
- class Inspect(Action):
41
- action_type: Literal["inspect"] = "inspect"
42
-
43
- class RunLinter(Action):
44
- action_type: Literal["run_linter"] = "run_linter"
45
-
46
- class RunTests(Action):
47
- action_type: Literal["run_tests"] = "run_tests"
48
-
49
- class QueryDocs(Action):
50
- action_type: Literal["query_docs"] = "query_docs"
51
- query_topic: str = Field(..., min_length=1)
52
-
53
- # Discriminated union for one‑line polymorphic deserialization
54
- AnyAction = Annotated[
55
- Union[WriteComment, Skip, Done, AskQuestion, ProposeFix,
56
- Execute, Inspect, RunLinter, RunTests, QueryDocs],
57
- Field(discriminator='action_type')
58
- ]
59
- action_adapter = TypeAdapter(AnyAction)
60
-
61
-
62
- def map_to_env(action_type: str, content: Optional[str] = None) -> AnyAction:
63
- """
64
- Convert lightweight agent outputs into typed environment actions.
65
- Kept at module level so training/inference code can reuse one mapping.
66
- """
67
- if action_type == "run_tests":
68
- return RunTests()
69
- if action_type == "run_linter":
70
- return RunLinter()
71
- if action_type == "inspect":
72
- return Inspect()
73
- if action_type == "fix":
74
- return ProposeFix(fix_code=content or "")
75
- if action_type == "comment":
76
- return WriteComment(comment_text=content or "")
77
- if action_type == "question":
78
- return AskQuestion(question=content or "")
79
- if action_type == "query_docs":
80
- return QueryDocs(query_topic=content or "")
81
- if action_type == "done":
82
- return Done()
83
- return Skip()
84
-
85
- # ----------------------------------------------------------------------
86
- # Observation (POMDP – what the agent sees)
87
- # ----------------------------------------------------------------------
88
- class Observation(BaseModel):
89
- # Base schema model used by API metadata endpoints.
90
- # Keep this lightweight for compatibility with legacy callers.
91
- code_snippet: str
92
- last_tool_output: str = ""
93
- step: int = 0
94
- done: bool = False
95
-
96
- # ----------------------------------------------------------------------
97
- # Reward (lightweight)
98
- # ----------------------------------------------------------------------
99
- class Reward(BaseModel):
100
- value: float
101
-
102
- # ----------------------------------------------------------------------
103
- # State (full environment state – not exposed to agent)
104
- # ----------------------------------------------------------------------
105
- class State(BaseModel):
106
- pr_title: str
107
- pr_description: str
108
- code_snippet: str
109
- comments: list[str]
110
- test_results: Optional[str]
111
- step: int
112
  done: bool
 
1
+ # models.py – Typed Models (Discriminated Unions, POMDP Separation)
2
+ from typing import Literal, Union, Annotated, Optional
3
+ from pydantic import BaseModel, Field, TypeAdapter, field_validator
4
+
5
+ # ----------------------------------------------------------------------
6
+ # Action classes (discriminated union)
7
+ # ----------------------------------------------------------------------
8
+ class Action(BaseModel):
9
+ action_type: Literal["comment", "skip", "done", "question",
10
+ "fix", "execute", "inspect", "run_linter",
11
+ "run_tests", "query_docs"]
12
+
13
+ class WriteComment(Action):
14
+ action_type: Literal["comment"] = "comment"
15
+ comment_text: str = Field(..., min_length=1)
16
+
17
+ class Skip(Action):
18
+ action_type: Literal["skip"] = "skip"
19
+
20
+ class Done(Action):
21
+ action_type: Literal["done"] = "done"
22
+
23
+ class AskQuestion(Action):
24
+ action_type: Literal["question"] = "question"
25
+ question: str = Field(..., min_length=1)
26
+
27
+ class ProposeFix(Action):
28
+ action_type: Literal["fix"] = "fix"
29
+ fix_code: str = Field(..., min_length=1)
30
+ @field_validator('fix_code')
31
+ @classmethod
32
+ def not_empty(cls, v: str) -> str:
33
+ if not v.strip():
34
+ raise ValueError('fix_code cannot be empty')
35
+ return v
36
+
37
+ class Execute(Action):
38
+ action_type: Literal["execute"] = "execute"
39
+
40
+ class Inspect(Action):
41
+ action_type: Literal["inspect"] = "inspect"
42
+
43
+ class RunLinter(Action):
44
+ action_type: Literal["run_linter"] = "run_linter"
45
+
46
+ class RunTests(Action):
47
+ action_type: Literal["run_tests"] = "run_tests"
48
+
49
+ class QueryDocs(Action):
50
+ action_type: Literal["query_docs"] = "query_docs"
51
+ query_topic: str = Field(..., min_length=1)
52
+
53
+ # Discriminated union for one‑line polymorphic deserialization
54
+ AnyAction = Annotated[
55
+ Union[WriteComment, Skip, Done, AskQuestion, ProposeFix,
56
+ Execute, Inspect, RunLinter, RunTests, QueryDocs],
57
+ Field(discriminator='action_type')
58
+ ]
59
+ action_adapter = TypeAdapter(AnyAction)
60
+
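+ # Usage sketch (pydantic v2): the discriminator picks the concrete Action class.
+ # action = action_adapter.validate_python({"action_type": "run_tests"})
+ # assert isinstance(action, RunTests)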
61
+
62
+ def map_to_env(action_type: str, content: Optional[str] = None) -> AnyAction:
63
+ """
64
+ Convert lightweight agent outputs into typed environment actions.
65
+ Kept at module level so training/inference code can reuse one mapping.
66
+ """
67
+ if action_type == "run_tests":
68
+ return RunTests()
69
+ if action_type == "run_linter":
70
+ return RunLinter()
71
+ if action_type == "inspect":
72
+ return Inspect()
73
+ if action_type == "fix":
74
+ return ProposeFix(fix_code=content or "")
75
+ if action_type == "comment":
76
+ return WriteComment(comment_text=content or "")
77
+ if action_type == "question":
78
+ return AskQuestion(question=content or "")
79
+ if action_type == "query_docs":
80
+ return QueryDocs(query_topic=content or "")
81
+ if action_type == "done":
82
+ return Done()
83
+ return Skip()
84
+
85
+ # ----------------------------------------------------------------------
86
+ # Observation (POMDP – what the agent sees)
87
+ # ----------------------------------------------------------------------
88
+ class Observation(BaseModel):
89
+ # Base schema model used by API metadata endpoints.
90
+ # Keep this lightweight for compatibility with legacy callers.
91
+ code_snippet: str
92
+ last_tool_output: str = ""
93
+ step: int = 0
94
+ done: bool = False
95
+
96
+ # ----------------------------------------------------------------------
97
+ # Reward (lightweight)
98
+ # ----------------------------------------------------------------------
99
+ class Reward(BaseModel):
100
+ value: float
101
+
102
+ # ----------------------------------------------------------------------
103
+ # State (full environment state – not exposed to agent)
104
+ # ----------------------------------------------------------------------
105
+ class State(BaseModel):
106
+ pr_title: str
107
+ pr_description: str
108
+ code_snippet: str
109
+ comments: list[str]
110
+ test_results: Optional[str]
111
+ step: int
112
  done: bool
pyproject.toml CHANGED
@@ -1,30 +1,38 @@
1
- [build-system]
2
- requires = ["setuptools>=61.0", "wheel"]
3
- build-backend = "setuptools.build_meta"
4
-
5
- [project]
6
- name = "code_review_professional"
7
- version = "1.0.0"
8
- description = "Multi‑turn code review environment with AST injection, DPO training, and author negotiation"
9
- authors = [{name = "yuvraj gupta", email = "yuvraj467229@gmail.com"}]
10
- license = {text = "MIT"}
11
- readme = "README.md"
12
- requires-python = ">=3.10"
13
- dependencies = [
14
- "openenv-core>=0.2.0",
15
- "fastapi>=0.115.0",
16
- "uvicorn>=0.24.0",
17
- "unsloth>=2025.3.1",
18
- "trl>=0.15.0",
19
- "accelerate>=1.2.0",
20
- "pylint>=3.3.0",
21
- "sentence-transformers>=3.3.0",
22
- "datasets>=3.3.0",
23
- "chromadb>=0.5.0",
24
- ]
25
-
26
- [project.optional-dependencies]
27
- dev = ["pytest>=7.0", "black>=23.0", "isort>=5.0"]
28
-
29
- [tool.openenv]
30
- server = "server.app:app"
 
 
 
 
 
 
 
 
 
1
+ [build-system]
2
+ requires = ["setuptools>=61.0", "wheel"]
3
+ build-backend = "setuptools.build_meta"
4
+
5
+ [project]
6
+ name = "code_review_professional"
7
+ version = "1.0.0"
8
+ description = "Multi-turn code review environment for OpenEnv with author negotiation and RL training hooks."
9
+ authors = [{name = "yuvraj gupta", email = "yuvraj467229@gmail.com"}]
10
+ license = {text = "MIT"}
11
+ readme = "README.md"
12
+ requires-python = ">=3.10"
13
+ dependencies = [
14
+ "openenv-core>=0.2.3",
15
+ "fastapi>=0.115.0",
16
+ "uvicorn>=0.24.0",
17
+ "pylint>=3.3.0",
18
+ ]
19
+
20
+ [project.optional-dependencies]
21
+ dev = ["pytest>=7.0", "black>=23.0", "isort>=5.0"]
22
+ training = [
23
+ "accelerate>=1.2.0",
24
+ "chromadb>=0.5.0",
25
+ "datasets>=3.3.0",
26
+ "matplotlib>=3.9.0",
27
+ "sentence-transformers>=3.3.0",
28
+ "torch>=2.4.0",
29
+ "trl>=0.15.0",
30
+ "unsloth>=2025.3.1",
31
+ ]
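+ # Training extras are optional; install with: pip install ".[training]"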
32
+
33
+ [project.urls]
34
+ Homepage = "https://huggingface.co/spaces/100XZX001/CodeReview-Professional-Workflow"
35
+ Repository = "https://huggingface.co/spaces/100XZX001/CodeReview-Professional-Workflow"
36
+
37
+ [tool.openenv]
38
+ server = "server.app:app"
requirements-training.txt ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ accelerate>=1.2.0
2
+ chromadb>=0.5.0
3
+ datasets>=3.3.0
4
+ matplotlib>=3.9.0
5
+ sentence-transformers>=3.3.0
6
+ torch>=2.4.0
7
+ transformers>=4.48.0
8
+ trl>=0.15.0
9
+ unsloth>=2025.3.1
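+ # Optional RL/training stack (unsloth/torch typically need a CUDA GPU);
+ # install with: pip install -r requirements-training.txt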
requirements.txt CHANGED
@@ -1,11 +1,5 @@
1
- unsloth
2
- transformers
3
- trl
4
- datasets
5
- torch
6
- sentence-transformers
7
- chromadb
8
- pylint
9
- pydantic
10
- matplotlib
11
- huggingface_hub
 
1
+ openenv-core>=0.2.3
2
+ fastapi>=0.115.0
3
+ uvicorn>=0.24.0
4
+ pylint>=3.3.0
5
+ pydantic>=2.8.0
 
 
 
 
 
 
rltool.py CHANGED
@@ -1,127 +1,143 @@
1
- # tools.py – Real vector retrieval for query_docs, linter, and test runner
2
- import subprocess
3
- import tempfile
4
- import os
5
- from dataclasses import dataclass
6
- from sentence_transformers import SentenceTransformer
7
- import chromadb
8
-
9
- @dataclass
10
- class ToolBox:
11
- _embedder = None
12
- _client = None
13
- _collection = None
14
-
15
- @classmethod
16
- def _get_embedder(cls):
17
- if cls._embedder is None:
18
- cls._embedder = SentenceTransformer('all-MiniLM-L6-v2')
19
- return cls._embedder
20
-
21
- @classmethod
22
- def _get_collection(cls):
23
- if cls._collection is None:
24
- cls._client = chromadb.Client()
25
- cls._collection = cls._client.create_collection("docs")
26
- # Pre‑load real documentation snippets (can be extended)
27
- docs = [
28
- "KeyError occurs when a dictionary key is missing. Use dict.get() or check 'if key in dict'.",
29
- "pylint error C0304: missing final newline. Add a newline at the end of file.",
30
- "Deadlock happens when two threads acquire locks in opposite order. Always acquire locks in the same order.",
31
- "Division by zero: check if list is empty before calculating average, or use try/except.",
32
- "Threading.Lock: use 'with lock:' to automatically acquire and release.",
33
- "Off‑by‑one errors: adjust loop ranges, e.g., range(1, len(arr)-1).",
34
- ]
35
- embedder = cls._get_embedder()
36
- embeddings = embedder.encode(docs).tolist()
37
- for i, doc in enumerate(docs):
38
- cls._collection.add(ids=[str(i)], documents=[doc], embeddings=[embeddings[i]])
39
- return cls._collection
40
-
41
- @staticmethod
42
- def run_linter(code: str) -> str:
43
- with tempfile.NamedTemporaryFile(mode='w', suffix='.py', delete=False, encoding='utf-8') as f:
44
- f.write(code)
45
- f.flush()
46
- tmp_path = f.name
47
- try:
48
- result = subprocess.run(
49
- ['pylint', tmp_path, '--exit-zero', '--output-format=text'],
50
- capture_output=True,
51
- text=True,
52
- timeout=10,
53
- encoding='utf-8'
54
- )
55
- output = result.stdout
56
- if "Your code has been rated" in output:
57
- output = output.split("Your code has been rated")[0]
58
- output = output.strip()
59
- if not output:
60
- return "No linting issues found."
61
- return output[:500]
62
- except FileNotFoundError:
63
- return "Linter (pylint) not installed."
64
- except subprocess.TimeoutExpired:
65
- return "Linter timed out."
66
- except Exception as e:
67
- return f"Linter error: {str(e)}"
68
- finally:
69
- try:
70
- os.unlink(tmp_path)
71
- except:
72
- pass
73
-
74
- @staticmethod
75
- def run_tests(test_script: str) -> str:
76
- with tempfile.NamedTemporaryFile(mode='w', suffix='.py', delete=False, encoding='utf-8') as f:
77
- f.write(test_script)
78
- f.flush()
79
- tmp_path = f.name
80
- try:
81
- result = subprocess.run(
82
- ['python', tmp_path],
83
- capture_output=True,
84
- text=True,
85
- timeout=10,
86
- encoding='utf-8'
87
- )
88
- output = result.stdout + result.stderr
89
- return output.strip() or "Test executed successfully (no output)."
90
- except subprocess.TimeoutExpired:
91
- return "Test execution timed out."
92
- except Exception as e:
93
- return f"Test runner error: {str(e)}"
94
- finally:
95
- try:
96
- os.unlink(tmp_path)
97
- except:
98
- pass
99
-
100
- @classmethod
101
- def query_docs(cls, topic: str) -> str:
102
- """Retrieve top 3 relevant docs. Forces agent to reason across multiple hints."""
103
- try:
104
- embedder = cls._get_embedder()
105
- collection = cls._get_collection()
106
- query_emb = embedder.encode([topic]).tolist()
107
- # Get top 3 results (not just 1)
108
- results = collection.query(query_embeddings=query_emb, n_results=3)
109
- if results['documents'] and results['documents'][0]:
110
- # Return concatenated snippets, labelled for clarity
111
- snippets = []
112
- for i, doc in enumerate(results['documents'][0]):
113
- snippets.append(f"[{i+1}] {doc}")
114
- return "Relevant documentation:\n" + "\n".join(snippets)
115
- return "No relevant documentation found."
116
- except Exception:
117
- # Fallback to keyword matching
118
- topic_lower = topic.lower()
119
- fallback = {
120
- "null check": "To avoid KeyError, use 'if key in dict:' before accessing.",
121
- "keyerror": "Catch KeyError with try/except or use dict.get().",
122
- "deadlock": "Always acquire locks in the same order to avoid deadlock.",
123
- }
124
- for key, value in fallback.items():
125
- if key in topic_lower:
126
- return value
127
- return "No relevant documentation found. Try being more specific."
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # tools.py - Real vector retrieval for query_docs, linter, and test runner
2
+ import os
3
+ import subprocess
4
+ import sys
5
+ import tempfile
6
+ from dataclasses import dataclass
7
+
8
+ try:
9
+ from sentence_transformers import SentenceTransformer
10
+ except ImportError:
11
+ SentenceTransformer = None
12
+
13
+ try:
14
+ import chromadb
15
+ except ImportError:
16
+ chromadb = None
17
+
18
+
19
+ @dataclass
20
+ class ToolBox:
21
+ _embedder = None
22
+ _client = None
23
+ _collection = None
24
+
25
+ @classmethod
26
+ def _get_embedder(cls):
27
+ if cls._embedder is None:
28
+ if SentenceTransformer is None:
29
+ return None
30
+ cls._embedder = SentenceTransformer("all-MiniLM-L6-v2")
31
+ return cls._embedder
32
+
33
+ @classmethod
34
+ def _get_collection(cls):
35
+ if cls._collection is None:
36
+ if chromadb is None:
37
+ return None
38
+ cls._client = chromadb.Client()
39
+ cls._collection = cls._client.create_collection("docs")
40
+ docs = [
41
+ "KeyError occurs when a dictionary key is missing. Use dict.get() or check 'if key in dict'.",
42
+ "pylint error C0304: missing final newline. Add a newline at the end of file.",
43
+ "Deadlock happens when two threads acquire locks in opposite order. Always acquire locks in the same order.",
44
+ "Division by zero: check if list is empty before calculating average, or use try/except.",
45
+ "Threading.Lock: use 'with lock:' to automatically acquire and release.",
46
+ "Off-by-one errors: adjust loop ranges, e.g., range(1, len(arr)-1).",
47
+ ]
48
+ embedder = cls._get_embedder()
49
+ if embedder is None:
50
+ return None
51
+ embeddings = embedder.encode(docs).tolist()
52
+ for i, doc in enumerate(docs):
53
+ cls._collection.add(ids=[str(i)], documents=[doc], embeddings=[embeddings[i]])
54
+ return cls._collection
55
+
56
+ @staticmethod
57
+ def run_linter(code: str) -> str:
58
+ with tempfile.NamedTemporaryFile(mode="w", suffix=".py", delete=False, encoding="utf-8") as f:
59
+ f.write(code)
60
+ f.flush()
61
+ tmp_path = f.name
62
+ try:
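+ # Run pylint via the current interpreter (sys.executable -m pylint) so the
+ # linter from the active virtualenv is used even if `pylint` is not on PATH.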
63
+ result = subprocess.run(
64
+ [sys.executable, "-m", "pylint", tmp_path, "--exit-zero", "--output-format=text"],
65
+ capture_output=True,
66
+ text=True,
67
+ timeout=10,
68
+ encoding="utf-8",
69
+ )
70
+ output = result.stdout
71
+ if "Your code has been rated" in output:
72
+ output = output.split("Your code has been rated")[0]
73
+ output = output.strip()
74
+ if not output:
75
+ return "No linting issues found."
76
+ return output[:500]
77
+ except FileNotFoundError:
78
+ return "Linter (pylint) not installed."
79
+ except subprocess.TimeoutExpired:
80
+ return "Linter timed out."
81
+ except Exception as e:
82
+ return f"Linter error: {str(e)}"
83
+ finally:
84
+ try:
85
+ os.unlink(tmp_path)
86
+ except OSError:
87
+ pass
88
+
89
+ @staticmethod
90
+ def run_tests(test_script: str) -> str:
91
+ with tempfile.NamedTemporaryFile(mode="w", suffix=".py", delete=False, encoding="utf-8") as f:
92
+ f.write(test_script)
93
+ f.flush()
94
+ tmp_path = f.name
95
+ try:
96
+ result = subprocess.run(
97
+ [sys.executable, tmp_path],
98
+ capture_output=True,
99
+ text=True,
100
+ timeout=10,
101
+ encoding="utf-8",
102
+ )
103
+ output = result.stdout + result.stderr
104
+ return output.strip() or "Test executed successfully (no output)."
105
+ except subprocess.TimeoutExpired:
106
+ return "Test execution timed out."
107
+ except Exception as e:
108
+ return f"Test runner error: {str(e)}"
109
+ finally:
110
+ try:
111
+ os.unlink(tmp_path)
112
+ except OSError:
113
+ pass
114
+
115
+ @classmethod
116
+ def query_docs(cls, topic: str) -> str:
117
+ """Retrieve top 3 relevant docs; fall back cleanly when vector deps are missing."""
118
+ try:
119
+ embedder = cls._get_embedder()
120
+ collection = cls._get_collection()
121
+ if embedder is None or collection is None:
122
+ raise RuntimeError("Vector retrieval dependencies are unavailable")
123
+ query_emb = embedder.encode([topic]).tolist()
124
+ results = collection.query(query_embeddings=query_emb, n_results=3)
125
+ if results["documents"] and results["documents"][0]:
126
+ snippets = []
127
+ for i, doc in enumerate(results["documents"][0]):
128
+ snippets.append(f"[{i + 1}] {doc}")
129
+ return "Relevant documentation:\n" + "\n".join(snippets)
130
+ return "No relevant documentation found."
131
+ except Exception:
132
+ topic_lower = topic.lower()
133
+ fallback = {
134
+ "null check": "To avoid KeyError, use 'if key in dict:' before accessing.",
135
+ "keyerror": "Catch KeyError with try/except or use dict.get().",
136
+ "deadlock": "Always acquire locks in the same order to avoid deadlock.",
137
+ "race": "Protect shared state with a lock or make the update atomic.",
138
+ "division": "Guard empty inputs before dividing or return a safe default.",
139
+ }
140
+ for key, value in fallback.items():
141
+ if key in topic_lower:
142
+ return value
143
+ return "No relevant documentation found. Try being more specific."
rubrics.py CHANGED
@@ -1,136 +1,136 @@
1
- # rubrics.py – Self-contained Rubrics (no external OpenEnv dependency)
2
-
3
- class Rubric:
4
- """Minimal Rubric base – compatible with OpenEnv but self‑contained."""
5
- def __call__(self, env, action, obs, reward, done, info):
6
- return 0.0
7
-
8
-
9
- # --------------------------------------------------------------------------------
10
- # 1. TOOL‑USAGE BONUS
11
- # --------------------------------------------------------------------------------
12
- class ToolUsageRubric(Rubric):
13
- def __init__(self, bonus: float = 0.05):
14
- self.bonus = bonus
15
-
16
- def __call__(self, env, action, obs, reward, done, info):
17
- score = 0.0
18
- action_type = info.get("action_type", "")
19
- # Use pre-action flags from `info` so first-use bonuses are
20
- # computed correctly even though env flags are mutated in-step.
21
- prev_tests_run = info.get("prev_tests_run", env._tests_run)
22
- prev_linter_run = info.get("prev_linter_run", env._linter_run)
23
- prev_docs_queried = info.get("prev_docs_queried", env._docs_queried)
24
-
25
- if action_type == "run_tests":
26
- if not prev_tests_run:
27
- score += self.bonus
28
- score += 0.015
29
- elif action_type == "run_linter":
30
- if not prev_linter_run:
31
- score += self.bonus
32
- score += 0.015
33
- elif action_type == "query_docs":
34
- if not prev_docs_queried:
35
- score += self.bonus * 0.5
36
- # Encourage docs usage when it is likely useful:
37
- # - early exploration phase
38
- # - non-trivial query text
39
- if env._step_count <= 4 and info.get("docs_query_len", 0) >= 8:
40
- score += 0.01
41
- # Discourage repeated docs calls after the first-use signal.
42
- if prev_docs_queried:
43
- score -= 0.01
44
- elif action_type == "question" and env._step_count <= 3:
45
- score += 0.02
46
- return score
47
-
48
-
49
- # --------------------------------------------------------------------------------
50
- # 2. DELTA‑BASED REWARDS
51
- # --------------------------------------------------------------------------------
52
- class TestDeltaRubric(Rubric):
53
- def __init__(self, weight: float = 0.3):
54
- self.weight = weight
55
-
56
- def __call__(self, env, action, obs, reward, done, info):
57
- delta = env._current_test_score - env._previous_test_score
58
- effective = self.weight
59
- if info.get("action_type") == "fix":
60
- effective *= 0.4
61
- return effective * delta
62
-
63
-
64
- class LintDeltaRubric(Rubric):
65
- def __init__(self, weight: float = 0.3):
66
- self.weight = weight
67
-
68
- def __call__(self, env, action, obs, reward, done, info):
69
- delta = env._current_lint_score - env._previous_lint_score
70
- effective = self.weight * 0.5
71
- if info.get("action_type") == "fix":
72
- effective *= 0.4
73
- return effective * delta
74
-
75
-
76
- # --------------------------------------------------------------------------------
77
- # 3. TERMINAL SUCCESS BONUS
78
- # --------------------------------------------------------------------------------
79
- class TerminalSuccessRubric(Rubric):
80
- def __call__(self, env, action, obs, reward, done, info):
81
- if info.get("action_type") != "fix":
82
- return 0.0
83
- score = 0.0
84
- if env._current_test_score > 0.95:
85
- score += 0.4
86
- elif env._current_test_score > 0.85:
87
- score += 0.2
88
- return score
89
-
90
-
91
- # --------------------------------------------------------------------------------
92
- # 4. EXPLORATION & DIVERSITY
93
- # --------------------------------------------------------------------------------
94
- class ExplorationRubric(Rubric):
95
- def __init__(self, penalty: float = -0.05, bonus: float = 0.021):
96
- self.penalty = penalty
97
- self.bonus = bonus
98
-
99
- def __call__(self, env, action, obs, reward, done, info):
100
- if len(env._action_history) < 3:
101
- return 0.0
102
- recent = env._action_history[-3:]
103
- unique = len(set(recent))
104
- if unique == 1:
105
- return self.penalty
106
- elif unique == 3:
107
- return self.bonus
108
- return 0.0
109
-
110
-
111
- # --------------------------------------------------------------------------------
112
- # 5. ANTI‑HACKING & CONSISTENCY
113
- # --------------------------------------------------------------------------------
114
- class AntiHackingRubric(Rubric):
115
- def __call__(self, env, action, obs, reward, done, info):
116
- if info.get("action_type") != "fix":
117
- return 0.0
118
- score = 0.0
119
- if not env._tests_run:
120
- score -= 0.25
121
- if env._step_count < 2:
122
- score -= 0.1
123
- if env._tests_run and env._linter_run:
124
- score += 0.02
125
- return score
126
-
127
-
128
- # --------------------------------------------------------------------------------
129
- # 6. STEP PENALTY
130
- # --------------------------------------------------------------------------------
131
- class StepPenaltyRubric(Rubric):
132
- def __init__(self, penalty: float = -0.01):
133
- self.penalty = penalty
134
-
135
- def __call__(self, env, action, obs, reward, done, info):
136
- return self.penalty
 
1
+ # rubrics.py – Self-contained Rubrics (no external OpenEnv dependency)
2
+
3
+ class Rubric:
4
+ """Minimal Rubric base – compatible with OpenEnv but self‑contained."""
5
+ def __call__(self, env, action, obs, reward, done, info):
6
+ return 0.0
7
+
8
+
9
+ # --------------------------------------------------------------------------------
10
+ # 1. TOOL‑USAGE BONUS
11
+ # --------------------------------------------------------------------------------
12
+ class ToolUsageRubric(Rubric):
13
+ def __init__(self, bonus: float = 0.05):
14
+ self.bonus = bonus
15
+
16
+ def __call__(self, env, action, obs, reward, done, info):
17
+ score = 0.0
18
+ action_type = info.get("action_type", "")
19
+ # Use pre-action flags from `info` so first-use bonuses are
20
+ # computed correctly even though env flags are mutated in-step.
21
+ prev_tests_run = info.get("prev_tests_run", env._tests_run)
22
+ prev_linter_run = info.get("prev_linter_run", env._linter_run)
23
+ prev_docs_queried = info.get("prev_docs_queried", env._docs_queried)
24
+
25
+ if action_type == "run_tests":
26
+ if not prev_tests_run:
27
+ score += self.bonus
28
+ score += 0.015
29
+ elif action_type == "run_linter":
30
+ if not prev_linter_run:
31
+ score += self.bonus
32
+ score += 0.015
33
+ elif action_type == "query_docs":
34
+ if not prev_docs_queried:
35
+ score += self.bonus * 0.5
36
+ # Encourage docs usage when it is likely useful:
37
+ # - early exploration phase
38
+ # - non-trivial query text
39
+ if env._step_count <= 4 and info.get("docs_query_len", 0) >= 8:
40
+ score += 0.01
41
+ # Discourage repeated docs calls after the first-use signal.
42
+ if prev_docs_queried:
43
+ score -= 0.01
44
+ elif action_type == "question" and env._step_count <= 3:
45
+ score += 0.02
46
+ return score
47
+
48
+
49
+ # --------------------------------------------------------------------------------
50
+ # 2. DELTA‑BASED REWARDS
51
+ # --------------------------------------------------------------------------------
52
+ class TestDeltaRubric(Rubric):
53
+ def __init__(self, weight: float = 0.3):
54
+ self.weight = weight
55
+
56
+ def __call__(self, env, action, obs, reward, done, info):
57
+ delta = env._current_test_score - env._previous_test_score
58
+ effective = self.weight
59
+ if info.get("action_type") == "fix":
60
+ effective *= 0.4
61
+ return effective * delta
62
+
63
+
64
+ class LintDeltaRubric(Rubric):
65
+ def __init__(self, weight: float = 0.3):
66
+ self.weight = weight
67
+
68
+ def __call__(self, env, action, obs, reward, done, info):
69
+ delta = env._current_lint_score - env._previous_lint_score
70
+ effective = self.weight * 0.5
71
+ if info.get("action_type") == "fix":
72
+ effective *= 0.4
73
+ return effective * delta
74
+
75
+
76
+ # --------------------------------------------------------------------------------
77
+ # 3. TERMINAL SUCCESS BONUS
78
+ # --------------------------------------------------------------------------------
79
+ class TerminalSuccessRubric(Rubric):
80
+ def __call__(self, env, action, obs, reward, done, info):
81
+ if info.get("action_type") != "fix":
82
+ return 0.0
83
+ score = 0.0
84
+ if env._current_test_score > 0.95:
85
+ score += 0.4
86
+ elif env._current_test_score > 0.85:
87
+ score += 0.2
88
+ return score
89
+
90
+
91
+ # --------------------------------------------------------------------------------
92
+ # 4. EXPLORATION & DIVERSITY
93
+ # --------------------------------------------------------------------------------
94
+ class ExplorationRubric(Rubric):
95
+ def __init__(self, penalty: float = -0.05, bonus: float = 0.021):
96
+ self.penalty = penalty
97
+ self.bonus = bonus
98
+
99
+ def __call__(self, env, action, obs, reward, done, info):
100
+ if len(env._action_history) < 3:
101
+ return 0.0
102
+ recent = env._action_history[-3:]
103
+ unique = len(set(recent))
104
+ if unique == 1:
105
+ return self.penalty
106
+ elif unique == 3:
107
+ return self.bonus
108
+ return 0.0
109
+
110
+
111
+ # --------------------------------------------------------------------------------
112
+ # 5. ANTI‑HACKING & CONSISTENCY
113
+ # --------------------------------------------------------------------------------
114
+ class AntiHackingRubric(Rubric):
115
+ def __call__(self, env, action, obs, reward, done, info):
116
+ if info.get("action_type") != "fix":
117
+ return 0.0
118
+ score = 0.0
119
+ if not env._tests_run:
120
+ score -= 0.25
121
+ if env._step_count < 2:
122
+ score -= 0.1
123
+ if env._tests_run and env._linter_run:
124
+ score += 0.02
125
+ return score
126
+
127
+
128
+ # --------------------------------------------------------------------------------
129
+ # 6. STEP PENALTY
130
+ # --------------------------------------------------------------------------------
131
+ class StepPenaltyRubric(Rubric):
132
+ def __init__(self, penalty: float = -0.01):
133
+ self.penalty = penalty
134
+
135
+ def __call__(self, env, action, obs, reward, done, info):
136
+ return self.penalty
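+ # Example wiring (sketch; the environment owns the actual list): sum every rubric's output per step.
+ # rubrics = [ToolUsageRubric(), TestDeltaRubric(), LintDeltaRubric(), TerminalSuccessRubric(),
+ #            ExplorationRubric(), AntiHackingRubric(), StepPenaltyRubric()]
+ # rubric_score = sum(r(env, action, obs, None, done, info) for r in rubrics)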
training.py CHANGED
@@ -1,935 +1,935 @@
1
- # training.py – Memory‑safe: Phi‑3‑mini + Expert Demos + Fast PPO (2 iterations)
2
- import os
3
- os.environ["TRITON_DISABLE"] = "1"
4
- os.environ["TOKENIZERS_PARALLELISM"] = "false" # Issue #12: prevent OOM from parallel tokenization
5
-
6
- import torch._dynamo
7
- torch._dynamo.config.disable = True
8
- import json
9
- import torch
10
- import torch.nn.functional as F
11
- from torch.optim import AdamW
12
- from dataclasses import dataclass
13
- from typing import List, Dict, Tuple, Optional
14
- import numpy as np
15
- import re
16
- import random
17
- import matplotlib.pyplot as plt
18
-
19
- from unsloth import FastLanguageModel
20
- from transformers import TrainingArguments
21
- from trl import SFTTrainer
22
- from datasets import Dataset
23
-
24
- from environment import CodeReviewEnv
25
- from redteam import BUG_DB
26
- from models import (
27
- RunTests, RunLinter, Inspect,
28
- ProposeFix, WriteComment, AskQuestion,
29
- Done, Skip, QueryDocs, map_to_env as model_map_to_env
30
- )
31
-
32
- # ======================================================================
33
- @dataclass
34
- class AgentAction:
35
- action_type: str
36
- content: Optional[str] = None
37
-
38
- def parse_action(output: str) -> AgentAction:
39
- try:
40
- data = json.loads(output)
41
- return AgentAction(
42
- action_type=data.get("action_type", "").lower(),
43
- content=data.get("content")
44
- )
45
- except:
46
- pass
47
- json_match = re.search(r'```(?:json)?\s*(\{.*?\})\s*```', output, re.DOTALL)
48
- if json_match:
49
- try:
50
- data = json.loads(json_match.group(1))
51
- return AgentAction(
52
- action_type=data.get("action_type", "").lower(),
53
- content=data.get("content")
54
- )
55
- except:
56
- pass
57
- action_pattern = r'"action_type"\s*:\s*"(\w+)"'
58
- match = re.search(action_pattern, output)
59
- if match:
60
- return AgentAction(action_type=match.group(1).lower())
61
- output_lower = output.lower()
62
- if "test" in output_lower:
63
- return AgentAction("run_tests")
64
- if "lint" in output_lower:
65
- return AgentAction("run_linter")
66
- if "inspect" in output_lower:
67
- return AgentAction("inspect")
68
- if "doc" in output_lower or "documentation" in output_lower:
69
- return AgentAction("query_docs", "bug fix guidance")
70
- return AgentAction("invalid", output)
71
-
72
- def map_to_env(action: AgentAction):
73
- return model_map_to_env(action.action_type, action.content)
74
-
75
- # ======================================================================
76
- def load_model():
77
- model, tokenizer = FastLanguageModel.from_pretrained(
78
- model_name="unsloth/Phi-3-mini-4k-instruct-bnb-4bit",
79
- max_seq_length=480, # smaller window for memory
80
- load_in_4bit=True,
81
- )
82
- model = FastLanguageModel.get_peft_model(
83
- model,
84
- r=16,
85
- target_modules=[
86
- "q_proj", "k_proj", "v_proj", "o_proj",
87
- "gate_proj", "up_proj", "down_proj"
88
- ],
89
- lora_alpha=32,
90
- lora_dropout=0.0,
91
- )
92
- return model, tokenizer
93
-
94
- def test_model_sanity(model, tokenizer) -> bool:
95
- print("\n" + "="*60)
96
- print("SANITY CHECK: Testing base model generation")
97
- print("="*60)
98
- test_prompt = "Hello, how are you?"
99
- messages = [{"role": "user", "content": test_prompt}]
100
- formatted = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
101
- inputs = tokenizer(formatted, return_tensors="pt", max_length=256, truncation=True).to("cuda")
102
- with torch.no_grad():
103
- outputs = model.generate(
104
- **inputs,
105
- max_new_tokens=30,
106
- do_sample=True,
107
- temperature=0.7,
108
- min_new_tokens=1,
109
- eos_token_id=tokenizer.eos_token_id,
110
- pad_token_id=tokenizer.pad_token_id or tokenizer.eos_token_id,
111
- )
112
- generated_ids = outputs[0][inputs['input_ids'].shape[1]:]
113
- response = tokenizer.decode(generated_ids, skip_special_tokens=True).strip()
114
- print(f"Prompt: {test_prompt}")
115
- print(f"Response: {repr(response)}")
116
- if len(response) == 0:
117
- print("❌ Model produces empty output – cannot train.")
118
- return False
119
- print("βœ“ Model sanity check PASSED\n")
120
- return True
121
-
122
- # ======================================================================
123
- def _expert_fix_from_context(obs) -> str:
124
- """
125
- Build a conservative fix template named `fix` (required by tests).
126
- Uses bug hints + code snippet patterns to create realistic fixes.
127
- """
128
- bug = (getattr(obs, "bug_description", "") or "").lower()
129
- code = getattr(obs, "code_snippet", "") or ""
130
-
131
- if "division" in bug or "average" in code.lower():
132
- return (
133
- "def fix(data):\n"
134
- " if not data:\n"
135
- " return 0\n"
136
- " return sum(data) / len(data)"
137
- )
138
-
139
- if "operator" in bug or "sign" in bug:
140
- return (
141
- "def fix(a, b):\n"
142
- " return a + b"
143
- )
144
-
145
- if "off_by_one" in bug or "loop" in bug:
146
- return (
147
- "def fix(items):\n"
148
- " return len(items)"
149
- )
150
-
151
- if "null" in bug or "key" in bug or "dict" in code.lower():
152
- return (
153
- "def fix(payload):\n"
154
- " users = payload.get('users', {})\n"
155
- " user_id = payload.get('id')\n"
156
- " return users.get(user_id)"
157
- )
158
-
159
- # Concurrency-heavy tasks (harder/hardest).
160
- if "race" in bug or "missing_lock" in bug or "thread_safe" in bug or "global_nonatomic" in bug:
161
- return (
162
- "import threading\n"
163
- "_lock = threading.Lock()\n"
164
- "\n"
165
- "def fix(counter):\n"
166
- " with _lock:\n"
167
- " if counter is None:\n"
168
- " return 0\n"
169
- " return counter + 1"
170
- )
171
-
172
- if "deadlock" in bug or "double_lock" in bug or "lock order" in bug or "nested_lock" in bug:
173
- return (
174
- "import threading\n"
175
- "_lock_a = threading.Lock()\n"
176
- "_lock_b = threading.Lock()\n"
177
- "\n"
178
- "def fix(work):\n"
179
- " first, second = (_lock_a, _lock_b)\n"
180
- " if id(first) > id(second):\n"
181
- " first, second = second, first\n"
182
- " with first:\n"
183
- " with second:\n"
184
- " return work() if callable(work) else work"
185
- )
186
-
187
- if "fork_join" in bug or "join" in bug:
188
- return (
189
- "import threading\n"
190
- "\n"
191
- "def fix(worker):\n"
192
- " t = threading.Thread(target=worker)\n"
193
- " t.start()\n"
194
- " t.join()\n"
195
- " return True"
196
- )
197
-
198
- # Generic safe fallback keeps the RL pipeline alive for unknown bugs.
199
- return (
200
- "def fix(data):\n"
201
- " if data is None:\n"
202
- " return None\n"
203
- " return data"
204
- )
205
-
206
-
207
- def _expert_supervised_policy(obs) -> str:
208
- """
209
- Real workflow policy:
210
- inspect -> tests/linter -> docs -> fix -> negotiate -> done.
211
- """
212
- author_msg = (getattr(obs, "author_response", "") or "").lower()
213
- tool_output = (getattr(obs, "last_tool_output", "") or "").lower()
214
-
215
- if not getattr(obs, "tests_run", False):
216
- if "inspect" not in tool_output:
217
- return '{"action_type": "inspect"}'
218
- return '{"action_type": "run_tests"}'
219
-
220
- if not getattr(obs, "linter_run", False):
221
- return '{"action_type": "run_linter"}'
222
-
223
- if not getattr(obs, "docs_queried", False):
224
- return '{"action_type": "query_docs", "content": "python bug fixing best practices for edge cases and null safety"}'
225
-
226
- # Use docs again on hard tasks when evidence is still weak.
227
- if getattr(obs, "current_test_score", 0.0) < 0.6 and getattr(obs, "step", 0) >= 3:
228
- bug_hint = (getattr(obs, "bug_description", "") or "concurrency bug").replace('"', "'")
229
- return json.dumps(
230
- {
231
- "action_type": "query_docs",
232
- "content": f"python {bug_hint} lock ordering race condition mitigation patterns",
233
- }
234
- )
235
-
236
- # If test quality is poor, propose a concrete fix.
237
- if getattr(obs, "current_test_score", 0.0) < 0.95:
238
- fix_code = _expert_fix_from_context(obs)
239
- return json.dumps({"action_type": "fix", "content": fix_code})
240
-
241
- # If author is still unconvinced, provide causal explanation.
242
- if author_msg and ("not convinced" in author_msg or "explain" in author_msg or "brief" in author_msg):
243
- return (
244
- '{"action_type": "comment", "content": "This fix works because it handles the failing edge case directly, '
245
- 'keeps behavior deterministic, and aligns with the observed test and lint feedback. '
246
- 'The change is intentionally small to reduce regression risk."}'
247
- )
248
-
249
- # If negotiation is strong enough and quality is good, terminate.
250
- conf = float(getattr(obs, "author_confidence", 0.0))
251
- threshold = float(getattr(obs, "author_threshold", 0.5))
252
- score = float(getattr(obs, "current_test_score", 0.0))
253
- if conf >= threshold and score >= 0.8:
254
- return '{"action_type": "done"}'
255
-
256
- # Nudge conversation forward when tests are okay but acceptance is pending.
257
- return (
258
- '{"action_type": "question", "content": "Would you like a quick walkthrough of a failing scenario, the root cause, and how the fix prevents regressions?"}'
259
- )
260
-
261
- # ======================================================================
262
- def supervised_warmup(model, tokenizer, env, n_episodes=16, epochs=1, max_steps=8):
263
- print("\n" + "="*60)
264
- print("SUPERVISED WARM-UP: Real environment demonstrations")
265
- print("="*60)
266
-
267
- examples = []
268
- tasks = ["easy", "medium", "hard", "harder", "hardest"]
269
- for ep in range(n_episodes):
270
- task = random.choice(tasks)
271
- env.set_task(task)
272
- obs = env.reset()
273
- history = []
274
- done = False
275
-
276
- steps = 0
277
- while not done and steps < max_steps:
278
- prompt = build_prompt(obs, history)
279
- action_text = _expert_supervised_policy(obs)
280
- action = parse_action(action_text)
281
- env_action = map_to_env(action)
282
- next_obs, _, done, _ = env.step(env_action)
283
-
284
- messages = [
285
- {"role": "user", "content": prompt},
286
- {"role": "assistant", "content": action_text},
287
- ]
288
- full_text = tokenizer.apply_chat_template(messages, tokenize=False)
289
- examples.append({"text": full_text})
290
-
291
- history.append(f"Agent: {action_text}")
292
- history.append(f"Env: {next_obs.last_tool_output}")
293
- history = history[-8:]
294
- obs = next_obs
295
- steps += 1
296
-
297
- print(f"Supervised episode {ep+1}: task={task}, steps={steps}, done={done}")
298
-
299
- if not examples:
300
- print("No supervised examples generated; skipping warm-up.")
301
- return
302
-
303
- dataset = Dataset.from_list(examples)
304
- trainer = SFTTrainer(
305
- model=model,
306
- tokenizer=tokenizer,
307
- train_dataset=dataset,
308
- dataset_text_field="text",
309
- max_seq_length=480,
310
- args=TrainingArguments(
311
- output_dir="warmup_output",
312
- num_train_epochs=epochs,
313
- per_device_train_batch_size=2,
314
- gradient_accumulation_steps=2,
315
- learning_rate=2e-5,
316
- logging_steps=50,
317
- save_strategy="no",
318
- bf16=True,
319
- ),
320
- )
321
- print(f"Training on {len(examples)} real env examples for {epochs} epochs...")
322
- trainer.train()
323
- print("βœ“ Supervised warm-up (real env) complete\n")
324
- torch.cuda.empty_cache()
325
-
326
- # ======================================================================
327
- def generate_action_with_logprob(prompt, model, tokenizer, temperature=0.0, max_retries=2):
328
- messages = [{"role": "user", "content": prompt}]
329
- formatted = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
330
- inputs = tokenizer(formatted, return_tensors="pt", max_length=480, truncation=True).to("cuda")
331
-
332
- for attempt in range(max_retries):
333
- with torch.no_grad():
334
- outputs = model.generate(
335
- **inputs,
336
- max_new_tokens=64,
337
- do_sample=(temperature > 0),
338
- temperature=max(temperature, 0.01) if temperature > 0 else 1.0,
339
- min_new_tokens=1,
340
- return_dict_in_generate=True,
341
- output_scores=True,
342
- )
343
- generated_ids = outputs.sequences[0][inputs['input_ids'].shape[1]:]
344
- action_text = tokenizer.decode(generated_ids, skip_special_tokens=True).strip()
345
-
346
- logprobs = []
347
- for idx, token_id in enumerate(generated_ids):
348
- if idx < len(outputs.scores):
349
- token_logits = outputs.scores[idx][0]
350
- token_logprob = F.log_softmax(token_logits, dim=-1)[token_id].item()
351
- logprobs.append(token_logprob)
352
- total_logprob = sum(logprobs) if logprobs else -100.0
353
-
354
- if not action_text:
355
- fallback_actions = [
356
- '{"action_type": "run_tests"}',
357
- '{"action_type": "run_linter"}',
358
- '{"action_type": "inspect"}',
359
- '{"action_type": "skip"}',
360
- ]
361
- action_text = random.choice(fallback_actions)
362
- total_logprob = -50.0
363
- print(f"[WARN] Empty generation β†’ using fallback: {action_text}")
364
- return action_text, total_logprob
365
-
366
- try:
367
- json.loads(action_text)
368
- return action_text, total_logprob
369
- except:
370
- if attempt == max_retries - 1:
371
- return '{"action_type":"skip"}', -100.0
372
- continue
373
- return '{"action_type":"skip"}', -100.0
374
-
375
- # ======================================================================
376
- def build_prompt(obs, history_lines: List[str]) -> str:
377
- author_msg = getattr(obs, "author_response", "") or ""
378
- tool_output = getattr(obs, "last_tool_output", "") or ""
379
- author_personality = getattr(obs, "author_personality", "defensive")
380
-
381
- prompt = f"""You are an AI code review agent. Your goal is to convince a simulated human developer to accept your proposed fix and name your proposed fix function fix.
382
-
383
- The developer has a **{author_personality}** personality and will only accept if you provide solid evidence:
384
- - Tests pass (high pass ratio)
385
- - Lint is clean (zero errors)
386
- - Documentation or references are provided
387
- - Your reasoning is clear, uses words like "because" or "therefore", and is detailed (over 30 words if needed)
388
-
389
- Workflow:
390
- 1. Use `inspect` to understand the code.
391
- 2. Use `run_tests` and `run_linter` to gather evidence.
392
- 3. Use `query_docs` when you need references or language-specific guidance.
393
- 4. Propose a fix (`fix`) and explain why it works (`comment` or `question`).
394
- 5. If the developer pushes back, read their response carefully and address their specific concern.
395
- 6. Once convinced, use `done` to finish.
396
-
397
- Code:
398
- {obs.code_snippet}
399
-
400
- Author says:
401
- {author_msg if author_msg else "(no response yet – start with inspection)"}
402
-
403
- Last tool output:
404
- {tool_output if tool_output else "(none)"}
405
-
406
- Available actions:
407
- run_tests, run_linter, inspect, query_docs, fix, comment, question, done
408
-
409
- Respond ONLY in JSON:
410
- {{"action_type": "...", "content": "..."}}"""
411
-
412
- if history_lines:
413
- history = "\n".join(history_lines[-6:])
414
- prompt += f"\n\nPrevious steps:\n{history}"
415
- return prompt
416
-
417
- # ======================================================================
418
- @dataclass
419
- class Trajectory:
420
- states: List[str]
421
- actions: List[str]
422
- rewards: List[float]
423
- logprobs: List[float]
424
- dones: List[bool]
425
- def __len__(self): return len(self.states)
426
-
427
- def collect_trajectory(env, model, tokenizer, max_steps=6, temperature=0.0):
428
- obs = env.reset()
429
- history_lines = []
430
- states, actions, rewards, logprobs, dones = [], [], [], [], []
431
- for step in range(max_steps):
432
- prompt = build_prompt(obs, history_lines)
433
- states.append(prompt)
434
- action_text, logprob = generate_action_with_logprob(prompt, model, tokenizer, temperature)
435
- actions.append(action_text)
436
- logprobs.append(logprob)
437
- action = parse_action(action_text)
438
- env_action = map_to_env(action)
439
- next_obs, reward, done, _ = env.step(env_action)
440
- rewards.append(reward.value)
441
- dones.append(done)
442
- history_lines.append(f"Agent: {action_text}")
443
- history_lines.append(f"Env: {next_obs.last_tool_output}")
444
- obs = next_obs
445
- if done: break
446
- return Trajectory(states, actions, rewards, logprobs, dones)
447
-
448
- def collect_trajectories(env, model, tokenizer, n_trajectories, max_steps=6,
449
- task_levels=None, task_weights=None):
450
- if task_levels is None:
451
- task_levels = list(BUG_DB.keys())
452
- if task_weights is not None and len(task_weights) != len(task_levels):
453
- raise ValueError("task_weights must match task_levels length")
454
- if task_weights is not None and sum(task_weights) <= 0:
455
- raise ValueError("task_weights must have a positive total")
456
- trajectories = []
457
- for i in range(n_trajectories):
458
- sampled_task = random.choices(task_levels, weights=task_weights, k=1)[0]
459
- env.set_task(sampled_task)
460
- traj = collect_trajectory(env, model, tokenizer, max_steps)
461
- total_reward = sum(traj.rewards)
462
- print(f"Trajectory {i+1}/{n_trajectories}: task={sampled_task}, steps={len(traj)}, reward={total_reward:.3f}")
463
- trajectories.append(traj)
464
- return trajectories
465
-
466
- def compute_returns_and_advantages(rewards, dones, gamma=0.99, standardize=True):
467
- """
468
- Compute discounted returns and REINFORCE-style baseline advantages.
469
- Advantages are centered and optionally standardised.
470
- """
471
- n = len(rewards)
472
- returns = [0.0]*n
473
- running = 0.0
474
- for t in reversed(range(n)):
475
- if dones[t]: running = 0.0
476
- running = rewards[t] + gamma * running
477
- returns[t] = running
478
- if standardize:
479
- advantages = np.array(returns) - np.mean(returns)
480
- adv_std = np.std(advantages) + 1e-8
481
- advantages = (advantages / adv_std).tolist()
482
- else:
483
- advantages = returns.copy()
484
- return advantages, returns
485
-
486
- def ppo_update(trajectories, model, tokenizer, optimizer, n_epochs=1, clip_epsilon=0.2,
487
- entropy_coef=0.01, gamma=0.99):
488
- model.train()
489
- all_states, all_actions, all_old_logprobs, all_advantages = [], [], [], []
490
- for traj in trajectories:
491
- advantages, _ = compute_returns_and_advantages(traj.rewards, traj.dones, gamma=gamma, standardize=True)
492
- all_states.extend(traj.states)
493
- all_actions.extend(traj.actions)
494
- all_old_logprobs.extend(traj.logprobs)
495
- all_advantages.extend(advantages)
496
- n_samples = len(all_states)
497
- total_loss, total_policy_loss, total_entropy, n_updates = 0.0, 0.0, 0.0, 0
498
- for epoch in range(n_epochs):
499
- indices = np.random.permutation(n_samples)
500
- for i in indices:
501
- state = all_states[i]
502
- action = all_actions[i]
503
- old_logprob = all_old_logprobs[i]
504
- advantage = all_advantages[i]
505
- messages = [{"role": "user", "content": state}]
506
- formatted = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
507
- full_text = formatted + action
508
- inputs = tokenizer(full_text, return_tensors="pt", max_length=480, truncation=True).to("cuda")
509
- outputs = model(**inputs)
510
- logits = outputs.logits
511
- action_ids = tokenizer.encode(action, add_special_tokens=False)
512
- prefix_ids = tokenizer.encode(formatted, add_special_tokens=False)
513
- action_start = len(prefix_ids)
514
- logprobs = []
515
- entropy = 0.0
516
- for idx, token_id in enumerate(action_ids):
517
- position = action_start + idx - 1
518
- if 0 <= position < logits.shape[1]:
519
- token_logits = logits[0, position]
520
- log_probs = F.log_softmax(token_logits, dim=-1)
521
- token_logprob = log_probs[token_id]
522
- logprobs.append(token_logprob)
523
- probs = F.softmax(token_logits, dim=-1)
524
- entropy += -(probs * log_probs).sum()
525
- if not logprobs: continue
526
- new_logprob = sum(logprobs)
527
- avg_entropy = entropy / len(logprobs) if logprobs else 0.0
528
- ratio = torch.exp(new_logprob - old_logprob)
529
- surr1 = ratio * advantage
530
- surr2 = torch.clamp(ratio, 1 - clip_epsilon, 1 + clip_epsilon) * advantage
531
- policy_loss = -torch.min(surr1, surr2)
532
- loss = policy_loss - entropy_coef * avg_entropy
533
- optimizer.zero_grad()
534
- loss.backward()
535
- torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
536
- optimizer.step()
537
- total_loss += loss.item()
538
- total_policy_loss += policy_loss.item()
539
- total_entropy += avg_entropy.item()
540
- n_updates += 1
541
- torch.cuda.empty_cache()
542
- return {"loss": total_loss / n_updates if n_updates else 0.0,
543
- "policy_loss": total_policy_loss / n_updates if n_updates else 0.0,
544
- "entropy": total_entropy / n_updates if n_updates else 0.0}
545
-
546
- def evaluate_policy(env, model, tokenizer, n_episodes=3, max_steps=6,
547
- task_levels=None, verbose=False):
548
- """Evaluate the current policy across task levels. Returns metrics + optional traces."""
549
- model.eval()
550
- if task_levels is None:
551
- task_levels = list(BUG_DB.keys())
552
- total_rewards = []
553
- traces = [] # human-readable behavior logs
554
- for ep in range(n_episodes):
555
- task = task_levels[ep % len(task_levels)]
556
- env.set_task(task)
557
- traj = collect_trajectory(env, model, tokenizer, max_steps, temperature=0.0)
558
- ep_reward = sum(traj.rewards)
559
- total_rewards.append(ep_reward)
560
- if verbose:
561
- actions_taken = []
562
- for a in traj.actions:
563
- try:
564
- actions_taken.append(json.loads(a).get("action_type", "?"))
565
- except Exception:
566
- actions_taken.append("?")
567
- traces.append({
568
- "task": task,
569
- "reward": round(ep_reward, 4),
570
- "steps": len(traj),
571
- "actions": actions_taken,
572
- })
573
- return {
574
- "avg_reward": float(np.mean(total_rewards)),
575
- "std_reward": float(np.std(total_rewards)),
576
- "min_reward": float(np.min(total_rewards)),
577
- "max_reward": float(np.max(total_rewards)),
578
- "traces": traces,
579
- }
580
-
581
- # ======================================================================
582
- # MANUAL WARM-UP (no SFTTrainer β†’ no multiprocessing OOM)
583
- # ======================================================================
584
- def json_warmup(model, tokenizer, json_path="training_data.json",
585
- n_episodes=20, epochs=2, lr=2e-5):
586
- """
587
- Supervised warm-up from pre-generated expert demonstrations.
588
- Uses raw cross-entropy on action tokens with manual gradient steps.
589
- NO SFTTrainer, NO multiprocessing – runs safely on any GPU.
590
- """
591
- print("\n" + "="*60)
592
- print("SUPERVISED WARM-UP: training_data.json (manual cross-entropy)")
593
- print("="*60)
594
-
595
- with open(json_path, encoding="utf-8") as f:
596
- data = json.load(f)
597
-
598
- # Each episode = 7 steps. Select n_episodes worth.
599
- steps_per_episode = 7
600
- max_examples = n_episodes * steps_per_episode
601
- if max_examples < len(data):
602
- data = data[:max_examples]
603
-
604
- print(f" {len(data)} examples ({len(data)//steps_per_episode} episodes), "
605
- f"{epochs} epoch(s), lr={lr}")
606
-
607
- model.train()
608
- warmup_opt = AdamW(model.parameters(), lr=lr)
609
- warmup_losses = [] # per-epoch avg loss
610
-
611
- for epoch in range(epochs):
612
- random.shuffle(data)
613
- epoch_loss = 0.0
614
- n_valid = 0
615
-
616
- for i, example in enumerate(data):
617
- prompt = example["prompt"]
618
- action = example["action"]
619
-
620
- # ---- tokenize full sequence (prompt + action) ----
621
- messages = [
622
- {"role": "user", "content": prompt},
623
- {"role": "assistant", "content": action},
624
- ]
625
- full_text = tokenizer.apply_chat_template(messages, tokenize=False)
626
- inputs = tokenizer(full_text, return_tensors="pt",
627
- max_length=480, truncation=True).to("cuda")
628
-
629
- # ---- find where the action tokens start ----
630
- prompt_only = tokenizer.apply_chat_template(
631
- [{"role": "user", "content": prompt}],
632
- tokenize=False, add_generation_prompt=True
633
- )
634
- prompt_ids = tokenizer.encode(prompt_only, add_special_tokens=False)
635
- prompt_len = len(prompt_ids)
636
-
637
- total_len = inputs.input_ids.shape[1]
638
- if prompt_len >= total_len:
639
- continue # prompt was truncated away, skip
640
-
641
- # ---- cross-entropy on action tokens only ----
642
- outputs = model(**inputs)
643
- logits = outputs.logits
644
-
645
- # next-token prediction: logits[t] predicts token[t+1]
646
- shift_logits = logits[0, prompt_len - 1 : total_len - 1]
647
- shift_labels = inputs.input_ids[0, prompt_len : total_len]
648
-
649
- min_len = min(shift_logits.shape[0], shift_labels.shape[0])
650
- if min_len == 0:
651
- continue
652
-
653
- loss = F.cross_entropy(shift_logits[:min_len], shift_labels[:min_len])
654
-
655
- warmup_opt.zero_grad()
656
- loss.backward()
657
- torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
658
- warmup_opt.step()
659
-
660
- epoch_loss += loss.item()
661
- n_valid += 1
662
-
663
- if (i + 1) % 25 == 0:
664
- avg = epoch_loss / n_valid
665
- print(f" epoch {epoch+1} step {i+1:3d}/{len(data)} "
666
- f"running_loss={avg:.4f}")
667
-
668
- avg_loss = epoch_loss / max(n_valid, 1)
669
- warmup_losses.append(avg_loss)
670
- print(f" Epoch {epoch+1} done: avg_loss={avg_loss:.4f} "
671
- f"({n_valid} valid examples)")
672
-
673
- torch.cuda.empty_cache()
674
- print(f"βœ“ Warm-up complete. Loss: "
675
- f"{' β†’ '.join(f'{l:.4f}' for l in warmup_losses)}\n")
676
- return warmup_losses
677
-
678
-
679
- # ======================================================================
680
- # MAIN TRAINING PIPELINE
681
- # ======================================================================
682
- def train_ppo():
683
- # --- Hyperparameters ---
684
- n_iterations = 8 # enough for a clear upward trend
685
- trajectories_per_iter = 4 # on-policy data per iteration
686
- n_epochs = 1
687
- max_steps = 6
688
- learning_rate = 3e-5
689
- clip_epsilon = 0.2
690
- entropy_coef = 0.01
691
- gamma = 0.99
692
-
693
- # --- Pre-load embedder before LLM (Issue #13) ---
694
- from rltool import ToolBox
695
- print("Pre-loading sentence-transformer embedder...")
696
- ToolBox._get_embedder()
697
- print("βœ“ Embedder ready")
698
-
699
- # --- Load model ---
700
- print("Loading model...")
701
- model, tokenizer = load_model()
702
- if not test_model_sanity(model, tokenizer):
703
- return
704
- env = CodeReviewEnv()
705
- task_levels = list(BUG_DB.keys())
706
-
707
- # ==================================================================
708
- # PHASE 0: BASELINE (untrained policy)
709
- # ==================================================================
710
- print("\n" + "="*60)
711
- print("PHASE 0 – BASELINE EVALUATION (untrained)")
712
- print("="*60)
713
- baseline = evaluate_policy(env, model, tokenizer, n_episodes=5,
714
- max_steps=max_steps, task_levels=task_levels,
715
- verbose=True)
716
- baseline_reward = baseline["avg_reward"]
717
- print(f"Baseline avg reward: {baseline_reward:.4f} "
718
- f"(min={baseline['min_reward']:.4f}, max={baseline['max_reward']:.4f})")
719
- print("Baseline behavior:")
720
- for t in baseline["traces"]:
721
- print(f" task={t['task']:8s} reward={t['reward']:+.4f} "
722
- f"steps={t['steps']} actions={t['actions']}")
723
-
724
- # ==================================================================
725
- # PHASE 1: SUPERVISED WARM-UP (expert demos, manual CE)
726
- # ==================================================================
727
- warmup_losses = json_warmup(
728
- model, tokenizer,
729
- json_path="training_data.json",
730
- n_episodes=20, # 140 examples (20 Γ— 7 steps)
731
- epochs=2,
732
- lr=2e-5,
733
- )
734
-
735
- # Post-warmup evaluation
736
- print("="*60)
737
- print("POST WARM-UP EVALUATION")
738
- print("="*60)
739
- post_warmup = evaluate_policy(env, model, tokenizer, n_episodes=5,
740
- max_steps=max_steps, task_levels=task_levels,
741
- verbose=True)
742
- warmup_reward = post_warmup["avg_reward"]
743
- print(f"Post-warmup avg reward: {warmup_reward:.4f} "
744
- f"(Ξ” vs baseline: {warmup_reward - baseline_reward:+.4f})")
745
- print("Post-warmup behavior:")
746
- for t in post_warmup["traces"]:
747
- print(f" task={t['task']:8s} reward={t['reward']:+.4f} "
748
- f"steps={t['steps']} actions={t['actions']}")
749
-
750
- # ==================================================================
751
- # PHASE 2: TRUE RL – PPO (on-policy, real environment interaction)
752
- # ==================================================================
753
- optimizer = AdamW(model.parameters(), lr=learning_rate)
754
- print(f"\n{'='*60}")
755
- print(f"PHASE 2 – PPO TRAINING: {n_iterations} iterations Γ— "
756
- f"{trajectories_per_iter} trajectories (true RL)")
757
- print(f"{'='*60}\n")
758
-
759
- reward_history = []
760
- eval_history = []
761
- loss_history = []
762
- policy_loss_history = []
763
- entropy_history = []
764
-
765
- for iteration in range(n_iterations):
766
- print(f"\n--- PPO Iteration {iteration + 1}/{n_iterations} ---")
767
-
768
- # Collect on-policy trajectories from REAL environment
769
- trajectories = collect_trajectories(
770
- env, model, tokenizer, trajectories_per_iter, max_steps,
771
- task_levels=task_levels, task_weights=None
772
- )
773
- avg_reward = float(np.mean([sum(t.rewards) for t in trajectories]))
774
- reward_history.append(avg_reward)
775
- print(f" Collect avg reward: {avg_reward:+.4f}")
776
-
777
- # PPO policy gradient update
778
- metrics = ppo_update(
779
- trajectories, model, tokenizer, optimizer,
780
- n_epochs=n_epochs, clip_epsilon=clip_epsilon,
781
- entropy_coef=entropy_coef, gamma=gamma
782
- )
783
- loss_history.append(float(metrics["loss"]))
784
- policy_loss_history.append(float(metrics["policy_loss"]))
785
- entropy_history.append(float(metrics["entropy"]))
786
- print(f" Update loss={metrics['loss']:.4f} "
787
- f"policy={metrics['policy_loss']:.4f} "
788
- f"entropy={metrics['entropy']:.4f}")
789
-
790
- # Evaluate greedy policy after update
791
- eval_m = evaluate_policy(env, model, tokenizer, n_episodes=3,
792
- max_steps=max_steps, task_levels=task_levels,
793
- verbose=False)
794
- eval_history.append(eval_m["avg_reward"])
795
- delta = eval_m["avg_reward"] - baseline_reward
796
- print(f" Eval avg reward: {eval_m['avg_reward']:+.4f} "
797
- f"(Ξ” baseline: {delta:+.4f})")
798
-
799
- # ==================================================================
800
- # PHASE 3: FINAL EVALUATION (proof of learning)
801
- # ==================================================================
802
- print("\n" + "="*60)
803
- print("PHASE 3 – FINAL EVALUATION (after all training)")
804
- print("="*60)
805
- final = evaluate_policy(env, model, tokenizer, n_episodes=5,
806
- max_steps=max_steps, task_levels=task_levels,
807
- verbose=True)
808
- print(f"Final avg reward: {final['avg_reward']:.4f} "
809
- f"(min={final['min_reward']:.4f}, max={final['max_reward']:.4f})")
810
- print("Final behavior:")
811
- for t in final["traces"]:
812
- print(f" task={t['task']:8s} reward={t['reward']:+.4f} "
813
- f"steps={t['steps']} actions={t['actions']}")
814
-
815
- total_improvement = final["avg_reward"] - baseline_reward
816
- ppo_improvement = final["avg_reward"] - warmup_reward
817
- print(f"\n{'='*60}")
818
- print("TRAINING SUMMARY")
819
- print(f" Baseline reward: {baseline_reward:+.4f}")
820
- print(f" Post-warmup reward: {warmup_reward:+.4f} "
821
- f"(warmup Ξ”: {warmup_reward - baseline_reward:+.4f})")
822
- print(f" Final reward: {final['avg_reward']:+.4f} "
823
- f"(PPO Ξ”: {ppo_improvement:+.4f})")
824
- print(f" Total improvement: {total_improvement:+.4f}")
825
- print(f" Reward trend (PPO): {' β†’ '.join(f'{r:+.3f}' for r in reward_history)}")
826
- print(f" Loss trend (PPO): {' β†’ '.join(f'{l:.4f}' for l in loss_history)}")
827
- if total_improvement > 0:
828
- print(f" βœ“ Agent IMPROVED by {total_improvement:+.4f}")
829
- else:
830
- print(f" βœ— No overall improvement detected")
831
- print(f"{'='*60}")
832
-
833
- # ==================================================================
834
- # PLOTS
835
- # ==================================================================
836
- iters = list(range(1, n_iterations + 1))
837
-
838
- # --- 1. Warm-up loss curve ---
839
- if warmup_losses:
840
- fig, ax = plt.subplots(figsize=(7, 4))
841
- ax.plot(range(1, len(warmup_losses) + 1), warmup_losses,
842
- marker="o", linewidth=2, color="tab:purple")
843
- ax.set_title("Warm-up Loss (supervised, per epoch)",
844
- fontsize=13, fontweight="bold")
845
- ax.set_xlabel("Epoch")
846
- ax.set_ylabel("Cross-Entropy Loss")
847
- ax.grid(alpha=0.3)
848
- fig.tight_layout()
849
- fig.savefig("warmup_loss.png", dpi=150)
850
- plt.close(fig)
851
-
852
- # --- 2. PPO reward curve ---
853
- fig, ax = plt.subplots(figsize=(9, 5))
854
- ax.plot(iters, reward_history, marker="o", linewidth=2,
855
- label="Collect reward", color="tab:blue")
856
- ax.plot(iters, eval_history, marker="s", linewidth=2, linestyle="--",
857
- label="Eval reward", color="tab:green")
858
- ax.axhline(y=baseline_reward, color="tab:gray", linestyle=":",
859
- linewidth=1.5, label=f"Baseline ({baseline_reward:+.3f})")
860
- ax.axhline(y=warmup_reward, color="tab:purple", linestyle=":",
861
- linewidth=1.5, label=f"Post-warmup ({warmup_reward:+.3f})")
862
- ax.set_title("PPO Reward per Iteration", fontsize=14, fontweight="bold")
863
- ax.set_xlabel("Iteration")
864
- ax.set_ylabel("Average Reward")
865
- ax.legend(loc="best", fontsize=8)
866
- ax.grid(alpha=0.3)
867
- fig.tight_layout()
868
- fig.savefig("reward_curve.png", dpi=150)
869
- plt.close(fig)
870
-
871
- # --- 3. PPO loss curve ---
872
- fig, ax = plt.subplots(figsize=(9, 5))
873
- ax.plot(iters, loss_history, marker="o", linewidth=2,
874
- label="Total loss", color="tab:red")
875
- ax.plot(iters, policy_loss_history, marker="^", linewidth=2, linestyle="--",
876
- label="Policy loss", color="tab:orange")
877
- ax.set_title("PPO Loss per Iteration", fontsize=14, fontweight="bold")
878
- ax.set_xlabel("Iteration")
879
- ax.set_ylabel("Loss")
880
- ax.legend(loc="best")
881
- ax.grid(alpha=0.3)
882
- fig.tight_layout()
883
- fig.savefig("loss_curve.png", dpi=150)
884
- plt.close(fig)
885
-
886
- # --- 4. Combined 3-panel summary ---
887
- fig, axes = plt.subplots(1, 3, figsize=(18, 5))
888
-
889
- # Panel A: warm-up loss
890
- if warmup_losses:
891
- axes[0].plot(range(1, len(warmup_losses) + 1), warmup_losses,
892
- marker="o", linewidth=2, color="tab:purple")
893
- axes[0].set_title("A. Warm-up Loss ↓")
894
- axes[0].set_xlabel("Epoch")
895
- axes[0].set_ylabel("CE Loss")
896
- axes[0].grid(alpha=0.3)
897
-
898
- # Panel B: PPO reward
899
- axes[1].plot(iters, reward_history, marker="o", linewidth=2,
900
- color="tab:blue", label="Collect")
901
- axes[1].plot(iters, eval_history, marker="s", linewidth=2,
902
- linestyle="--", color="tab:green", label="Eval")
903
- axes[1].axhline(y=baseline_reward, color="tab:gray", linestyle=":",
904
- linewidth=1.5, label="Baseline")
905
- axes[1].axhline(y=warmup_reward, color="tab:purple", linestyle=":",
906
- linewidth=1.5, label="Post-warmup")
907
- axes[1].set_title("B. PPO Reward ↑")
908
- axes[1].set_xlabel("Iteration")
909
- axes[1].set_ylabel("Avg Reward")
910
- axes[1].legend(fontsize=7)
911
- axes[1].grid(alpha=0.3)
912
-
913
- # Panel C: PPO loss
914
- axes[2].plot(iters, loss_history, marker="o", linewidth=2,
915
- color="tab:red", label="Total")
916
- axes[2].plot(iters, policy_loss_history, marker="^", linewidth=2,
917
- linestyle="--", color="tab:orange", label="Policy")
918
- axes[2].set_title("C. PPO Loss ↓")
919
- axes[2].set_xlabel("Iteration")
920
- axes[2].set_ylabel("Loss")
921
- axes[2].legend(fontsize=7)
922
- axes[2].grid(alpha=0.3)
923
-
924
- fig.suptitle("Code Review Agent – Full Training Evidence",
925
- fontsize=14, fontweight="bold")
926
- fig.tight_layout()
927
- fig.savefig("training_summary.png", dpi=150)
928
- plt.close(fig)
929
-
930
- print("Plots saved: warmup_loss.png, reward_curve.png, "
931
- "loss_curve.png, training_summary.png")
932
- print("="*60)
933
-
934
- if __name__ == "__main__":
935
  train_ppo()
 
1
+ # training.py – Memory-safe: Phi-3-mini + Expert Demos + Fast PPO (8 iterations)
2
+ import os
3
+ os.environ["TRITON_DISABLE"] = "1"
4
+ os.environ["TOKENIZERS_PARALLELISM"] = "false" # Issue #12: prevent OOM from parallel tokenization
5
+
6
+ import torch._dynamo
7
+ torch._dynamo.config.disable = True
8
+ import json
9
+ import torch
10
+ import torch.nn.functional as F
11
+ from torch.optim import AdamW
12
+ from dataclasses import dataclass
13
+ from typing import List, Dict, Tuple, Optional
14
+ import numpy as np
15
+ import re
16
+ import random
17
+ import matplotlib.pyplot as plt
18
+
19
+ from unsloth import FastLanguageModel
20
+ from transformers import TrainingArguments
21
+ from trl import SFTTrainer
22
+ from datasets import Dataset
23
+
24
+ from environment import CodeReviewEnv
25
+ from redteam import BUG_DB
26
+ from models import (
27
+ RunTests, RunLinter, Inspect,
28
+ ProposeFix, WriteComment, AskQuestion,
29
+ Done, Skip, QueryDocs, map_to_env as model_map_to_env
30
+ )
31
+
32
+ # ======================================================================
33
+ @dataclass
34
+ class AgentAction:
35
+ action_type: str
36
+ content: Optional[str] = None
37
+
38
+ def parse_action(output: str) -> AgentAction:
39
+ try:
40
+ data = json.loads(output)
41
+ return AgentAction(
42
+ action_type=data.get("action_type", "").lower(),
43
+ content=data.get("content")
44
+ )
45
+ except (json.JSONDecodeError, TypeError):
46
+ pass
47
+ json_match = re.search(r'```(?:json)?\s*(\{.*?\})\s*```', output, re.DOTALL)
48
+ if json_match:
49
+ try:
50
+ data = json.loads(json_match.group(1))
51
+ return AgentAction(
52
+ action_type=data.get("action_type", "").lower(),
53
+ content=data.get("content")
54
+ )
55
+ except (json.JSONDecodeError, TypeError):
56
+ pass
57
+ action_pattern = r'"action_type"\s*:\s*"(\w+)"'
58
+ match = re.search(action_pattern, output)
59
+ if match:
60
+ return AgentAction(action_type=match.group(1).lower())
61
+ output_lower = output.lower()
62
+ if "test" in output_lower:
63
+ return AgentAction("run_tests")
64
+ if "lint" in output_lower:
65
+ return AgentAction("run_linter")
66
+ if "inspect" in output_lower:
67
+ return AgentAction("inspect")
68
+ if "doc" in output_lower or "documentation" in output_lower:
69
+ return AgentAction("query_docs", "bug fix guidance")
70
+ return AgentAction("invalid", output)
71
+
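A few illustrative round-trips through the fallback chain above; the inputs are hypothetical and the snippet assumes it runs in the same module as parse_action:

# 1) Well-formed JSON is taken directly.
direct = parse_action('{"action_type": "run_tests"}')
assert direct.action_type == "run_tests"

# 2) A reply that only mentions the field in prose resolves via the "action_type" regex.
loose = parse_action('Sure, "action_type": "fix" is what I would do here.')
assert loose.action_type == "fix"

# 3) Pure free text falls through to the keyword heuristics.
keyword = parse_action("Let me run the linter before proposing anything.")
assert keyword.action_type == "run_linter"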
72
+ def map_to_env(action: AgentAction):
73
+ return model_map_to_env(action.action_type, action.content)
74
+
75
+ # ======================================================================
76
+ def load_model():
77
+ model, tokenizer = FastLanguageModel.from_pretrained(
78
+ model_name="unsloth/Phi-3-mini-4k-instruct-bnb-4bit",
79
+ max_seq_length=480, # smaller window for memory
80
+ load_in_4bit=True,
81
+ )
82
+ model = FastLanguageModel.get_peft_model(
83
+ model,
84
+ r=16,
85
+ target_modules=[
86
+ "q_proj", "k_proj", "v_proj", "o_proj",
87
+ "gate_proj", "up_proj", "down_proj"
88
+ ],
89
+ lora_alpha=32,
90
+ lora_dropout=0.0,
91
+ )
92
+ return model, tokenizer
93
+
94
+ def test_model_sanity(model, tokenizer) -> bool:
95
+ print("\n" + "="*60)
96
+ print("SANITY CHECK: Testing base model generation")
97
+ print("="*60)
98
+ test_prompt = "Hello, how are you?"
99
+ messages = [{"role": "user", "content": test_prompt}]
100
+ formatted = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
101
+ inputs = tokenizer(formatted, return_tensors="pt", max_length=256, truncation=True).to("cuda")
102
+ with torch.no_grad():
103
+ outputs = model.generate(
104
+ **inputs,
105
+ max_new_tokens=30,
106
+ do_sample=True,
107
+ temperature=0.7,
108
+ min_new_tokens=1,
109
+ eos_token_id=tokenizer.eos_token_id,
110
+ pad_token_id=tokenizer.pad_token_id or tokenizer.eos_token_id,
111
+ )
112
+ generated_ids = outputs[0][inputs['input_ids'].shape[1]:]
113
+ response = tokenizer.decode(generated_ids, skip_special_tokens=True).strip()
114
+ print(f"Prompt: {test_prompt}")
115
+ print(f"Response: {repr(response)}")
116
+ if len(response) == 0:
117
+ print("❌ Model produces empty output – cannot train.")
118
+ return False
119
+ print("βœ“ Model sanity check PASSED\n")
120
+ return True
121
+
122
+ # ======================================================================
123
+ def _expert_fix_from_context(obs) -> str:
124
+ """
125
+ Build a conservative fix template named `fix` (required by tests).
126
+ Uses bug hints + code snippet patterns to create realistic fixes.
127
+ """
128
+ bug = (getattr(obs, "bug_description", "") or "").lower()
129
+ code = getattr(obs, "code_snippet", "") or ""
130
+
131
+ if "division" in bug or "average" in code.lower():
132
+ return (
133
+ "def fix(data):\n"
134
+ " if not data:\n"
135
+ " return 0\n"
136
+ " return sum(data) / len(data)"
137
+ )
138
+
139
+ if "operator" in bug or "sign" in bug:
140
+ return (
141
+ "def fix(a, b):\n"
142
+ " return a + b"
143
+ )
144
+
145
+ if "off_by_one" in bug or "loop" in bug:
146
+ return (
147
+ "def fix(items):\n"
148
+ " return len(items)"
149
+ )
150
+
151
+ if "null" in bug or "key" in bug or "dict" in code.lower():
152
+ return (
153
+ "def fix(payload):\n"
154
+ " users = payload.get('users', {})\n"
155
+ " user_id = payload.get('id')\n"
156
+ " return users.get(user_id)"
157
+ )
158
+
159
+ # Concurrency-heavy tasks (harder/hardest).
160
+ if "race" in bug or "missing_lock" in bug or "thread_safe" in bug or "global_nonatomic" in bug:
161
+ return (
162
+ "import threading\n"
163
+ "_lock = threading.Lock()\n"
164
+ "\n"
165
+ "def fix(counter):\n"
166
+ " with _lock:\n"
167
+ " if counter is None:\n"
168
+ " return 0\n"
169
+ " return counter + 1"
170
+ )
171
+
172
+ if "deadlock" in bug or "double_lock" in bug or "lock order" in bug or "nested_lock" in bug:
173
+ return (
174
+ "import threading\n"
175
+ "_lock_a = threading.Lock()\n"
176
+ "_lock_b = threading.Lock()\n"
177
+ "\n"
178
+ "def fix(work):\n"
179
+ " first, second = (_lock_a, _lock_b)\n"
180
+ " if id(first) > id(second):\n"
181
+ " first, second = second, first\n"
182
+ " with first:\n"
183
+ " with second:\n"
184
+ " return work() if callable(work) else work"
185
+ )
186
+
187
+ if "fork_join" in bug or "join" in bug:
188
+ return (
189
+ "import threading\n"
190
+ "\n"
191
+ "def fix(worker):\n"
192
+ " t = threading.Thread(target=worker)\n"
193
+ " t.start()\n"
194
+ " t.join()\n"
195
+ " return True"
196
+ )
197
+
198
+ # Generic safe fallback keeps the RL pipeline alive for unknown bugs.
199
+ return (
200
+ "def fix(data):\n"
201
+ " if data is None:\n"
202
+ " return None\n"
203
+ " return data"
204
+ )
205
+
206
+
207
+ def _expert_supervised_policy(obs) -> str:
208
+ """
209
+ Real workflow policy:
210
+ inspect -> tests/linter -> docs -> fix -> negotiate -> done.
211
+ """
212
+ author_msg = (getattr(obs, "author_response", "") or "").lower()
213
+ tool_output = (getattr(obs, "last_tool_output", "") or "").lower()
214
+
215
+ if not getattr(obs, "tests_run", False):
216
+ if "inspect" not in tool_output:
217
+ return '{"action_type": "inspect"}'
218
+ return '{"action_type": "run_tests"}'
219
+
220
+ if not getattr(obs, "linter_run", False):
221
+ return '{"action_type": "run_linter"}'
222
+
223
+ if not getattr(obs, "docs_queried", False):
224
+ return '{"action_type": "query_docs", "content": "python bug fixing best practices for edge cases and null safety"}'
225
+
226
+ # Use docs again on hard tasks when evidence is still weak.
227
+ if getattr(obs, "current_test_score", 0.0) < 0.6 and getattr(obs, "step", 0) >= 3:
228
+ bug_hint = (getattr(obs, "bug_description", "") or "concurrency bug").replace('"', "'")
229
+ return json.dumps(
230
+ {
231
+ "action_type": "query_docs",
232
+ "content": f"python {bug_hint} lock ordering race condition mitigation patterns",
233
+ }
234
+ )
235
+
236
+ # If test quality is poor, propose a concrete fix.
237
+ if getattr(obs, "current_test_score", 0.0) < 0.95:
238
+ fix_code = _expert_fix_from_context(obs)
239
+ return json.dumps({"action_type": "fix", "content": fix_code})
240
+
241
+ # If author is still unconvinced, provide causal explanation.
242
+ if author_msg and ("not convinced" in author_msg or "explain" in author_msg or "brief" in author_msg):
243
+ return (
244
+ '{"action_type": "comment", "content": "This fix works because it handles the failing edge case directly, '
245
+ 'keeps behavior deterministic, and aligns with the observed test and lint feedback. '
246
+ 'The change is intentionally small to reduce regression risk."}'
247
+ )
248
+
249
+ # If negotiation is strong enough and quality is good, terminate.
250
+ conf = float(getattr(obs, "author_confidence", 0.0))
251
+ threshold = float(getattr(obs, "author_threshold", 0.5))
252
+ score = float(getattr(obs, "current_test_score", 0.0))
253
+ if conf >= threshold and score >= 0.8:
254
+ return '{"action_type": "done"}'
255
+
256
+ # Nudge conversation forward when tests are okay but acceptance is pending.
257
+ return (
258
+ '{"action_type": "question", "content": "Would you like a quick walkthrough of a failing scenario, the root cause, and how the fix prevents regressions?"}'
259
+ )
260
+
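Taken together, the branches above implement one fixed curriculum per episode. On a typical run the emitted actions look roughly like the hypothetical sequence below, which also lines up with the seven steps per episode assumed by json_warmup later in this file (the exact order depends on which observation fields the environment actually sets):

# Hypothetical episode emitted by _expert_supervised_policy (contents elided):
expert_episode_sketch = [
    '{"action_type": "inspect"}',
    '{"action_type": "run_tests"}',
    '{"action_type": "run_linter"}',
    '{"action_type": "query_docs", "content": "..."}',
    '{"action_type": "fix", "content": "def fix(data): ..."}',
    '{"action_type": "comment", "content": "This fix works because ..."}',
    '{"action_type": "done"}',
]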
261
+ # ======================================================================
262
+ def supervised_warmup(model, tokenizer, env, n_episodes=16, epochs=1, max_steps=8):
263
+ print("\n" + "="*60)
264
+ print("SUPERVISED WARM-UP: Real environment demonstrations")
265
+ print("="*60)
266
+
267
+ examples = []
268
+ tasks = ["easy", "medium", "hard", "harder", "hardest"]
269
+ for ep in range(n_episodes):
270
+ task = random.choice(tasks)
271
+ env.set_task(task)
272
+ obs = env.reset()
273
+ history = []
274
+ done = False
275
+
276
+ steps = 0
277
+ while not done and steps < max_steps:
278
+ prompt = build_prompt(obs, history)
279
+ action_text = _expert_supervised_policy(obs)
280
+ action = parse_action(action_text)
281
+ env_action = map_to_env(action)
282
+ next_obs, _, done, _ = env.step(env_action)
283
+
284
+ messages = [
285
+ {"role": "user", "content": prompt},
286
+ {"role": "assistant", "content": action_text},
287
+ ]
288
+ full_text = tokenizer.apply_chat_template(messages, tokenize=False)
289
+ examples.append({"text": full_text})
290
+
291
+ history.append(f"Agent: {action_text}")
292
+ history.append(f"Env: {next_obs.last_tool_output}")
293
+ history = history[-8:]
294
+ obs = next_obs
295
+ steps += 1
296
+
297
+ print(f"Supervised episode {ep+1}: task={task}, steps={steps}, done={done}")
298
+
299
+ if not examples:
300
+ print("No supervised examples generated; skipping warm-up.")
301
+ return
302
+
303
+ dataset = Dataset.from_list(examples)
304
+ trainer = SFTTrainer(
305
+ model=model,
306
+ tokenizer=tokenizer,
307
+ train_dataset=dataset,
308
+ dataset_text_field="text",
309
+ max_seq_length=480,
310
+ args=TrainingArguments(
311
+ output_dir="warmup_output",
312
+ num_train_epochs=epochs,
313
+ per_device_train_batch_size=2,
314
+ gradient_accumulation_steps=2,
315
+ learning_rate=2e-5,
316
+ logging_steps=50,
317
+ save_strategy="no",
318
+ bf16=True,
319
+ ),
320
+ )
321
+ print(f"Training on {len(examples)} real env examples for {epochs} epochs...")
322
+ trainer.train()
323
+ print("βœ“ Supervised warm-up (real env) complete\n")
324
+ torch.cuda.empty_cache()
325
+
326
+ # ======================================================================
327
+ def generate_action_with_logprob(prompt, model, tokenizer, temperature=0.0, max_retries=2):
328
+ messages = [{"role": "user", "content": prompt}]
329
+ formatted = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
330
+ inputs = tokenizer(formatted, return_tensors="pt", max_length=480, truncation=True).to("cuda")
331
+
332
+ for attempt in range(max_retries):
333
+ with torch.no_grad():
334
+ outputs = model.generate(
335
+ **inputs,
336
+ max_new_tokens=64,
337
+ do_sample=(temperature > 0),
338
+ temperature=max(temperature, 0.01) if temperature > 0 else 1.0,
339
+ min_new_tokens=1,
340
+ return_dict_in_generate=True,
341
+ output_scores=True,
342
+ )
343
+ generated_ids = outputs.sequences[0][inputs['input_ids'].shape[1]:]
344
+ action_text = tokenizer.decode(generated_ids, skip_special_tokens=True).strip()
345
+
346
+ logprobs = []
347
+ for idx, token_id in enumerate(generated_ids):
348
+ if idx < len(outputs.scores):
349
+ token_logits = outputs.scores[idx][0]
350
+ token_logprob = F.log_softmax(token_logits, dim=-1)[token_id].item()
351
+ logprobs.append(token_logprob)
352
+ total_logprob = sum(logprobs) if logprobs else -100.0
353
+
354
+ if not action_text:
355
+ fallback_actions = [
356
+ '{"action_type": "run_tests"}',
357
+ '{"action_type": "run_linter"}',
358
+ '{"action_type": "inspect"}',
359
+ '{"action_type": "skip"}',
360
+ ]
361
+ action_text = random.choice(fallback_actions)
362
+ total_logprob = -50.0
363
+ print(f"[WARN] Empty generation β†’ using fallback: {action_text}")
364
+ return action_text, total_logprob
365
+
366
+ try:
367
+ json.loads(action_text)
368
+ return action_text, total_logprob
369
+ except (json.JSONDecodeError, TypeError):
370
+ if attempt == max_retries - 1:
371
+ return '{"action_type":"skip"}', -100.0
372
+ continue
373
+ return '{"action_type":"skip"}', -100.0
374
+
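The value returned alongside the action text is the sequence log-probability, i.e. the sum of the per-token log-probabilities read off outputs.scores; ppo_update later re-scores the same tokens to form its importance ratio. A minimal numeric sketch with made-up values:

import math

token_logprobs = [-0.7, -0.1, -1.2, -0.3]    # hypothetical per-token log-probs
sequence_logprob = sum(token_logprobs)        # -2.3
sequence_prob = math.exp(sequence_logprob)    # ~0.10: probability of the whole action string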
375
+ # ======================================================================
376
+ def build_prompt(obs, history_lines: List[str]) -> str:
377
+ author_msg = getattr(obs, "author_response", "") or ""
378
+ tool_output = getattr(obs, "last_tool_output", "") or ""
379
+ author_personality = getattr(obs, "author_personality", "defensive")
380
+
381
+ prompt = f"""You are an AI code review agent. Your goal is to convince a simulated human developer to accept your proposed fix and name your proposed fix function fix.
382
+
383
+ The developer has a **{author_personality}** personality and will only accept if you provide solid evidence:
384
+ - Tests pass (high pass ratio)
385
+ - Lint is clean (zero errors)
386
+ - Documentation or references are provided
387
+ - Your reasoning is clear, uses words like "because" or "therefore", and is detailed (over 30 words if needed)
388
+
389
+ Workflow:
390
+ 1. Use `inspect` to understand the code.
391
+ 2. Use `run_tests` and `run_linter` to gather evidence.
392
+ 3. Use `query_docs` when you need references or language-specific guidance.
393
+ 4. Propose a fix (`fix`) and explain why it works (`comment` or `question`).
394
+ 5. If the developer pushes back, read their response carefully and address their specific concern.
395
+ 6. Once convinced, use `done` to finish.
396
+
397
+ Code:
398
+ {obs.code_snippet}
399
+
400
+ Author says:
401
+ {author_msg if author_msg else "(no response yet – start with inspection)"}
402
+
403
+ Last tool output:
404
+ {tool_output if tool_output else "(none)"}
405
+
406
+ Available actions:
407
+ run_tests, run_linter, inspect, query_docs, fix, comment, question, done
408
+
409
+ Respond ONLY in JSON:
410
+ {{"action_type": "...", "content": "..."}}"""
411
+
412
+ if history_lines:
413
+ history = "\n".join(history_lines[-6:])
414
+ prompt += f"\n\nPrevious steps:\n{history}"
415
+ return prompt
416
+
417
+ # ======================================================================
418
+ @dataclass
419
+ class Trajectory:
420
+ states: List[str]
421
+ actions: List[str]
422
+ rewards: List[float]
423
+ logprobs: List[float]
424
+ dones: List[bool]
425
+ def __len__(self): return len(self.states)
426
+
427
+ def collect_trajectory(env, model, tokenizer, max_steps=6, temperature=0.0):
428
+ obs = env.reset()
429
+ history_lines = []
430
+ states, actions, rewards, logprobs, dones = [], [], [], [], []
431
+ for step in range(max_steps):
432
+ prompt = build_prompt(obs, history_lines)
433
+ states.append(prompt)
434
+ action_text, logprob = generate_action_with_logprob(prompt, model, tokenizer, temperature)
435
+ actions.append(action_text)
436
+ logprobs.append(logprob)
437
+ action = parse_action(action_text)
438
+ env_action = map_to_env(action)
439
+ next_obs, reward, done, _ = env.step(env_action)
440
+ rewards.append(reward.value)
441
+ dones.append(done)
442
+ history_lines.append(f"Agent: {action_text}")
443
+ history_lines.append(f"Env: {next_obs.last_tool_output}")
444
+ obs = next_obs
445
+ if done: break
446
+ return Trajectory(states, actions, rewards, logprobs, dones)
447
+
448
+ def collect_trajectories(env, model, tokenizer, n_trajectories, max_steps=6,
449
+ task_levels=None, task_weights=None):
450
+ if task_levels is None:
451
+ task_levels = list(BUG_DB.keys())
452
+ if task_weights is not None and len(task_weights) != len(task_levels):
453
+ raise ValueError("task_weights must match task_levels length")
454
+ if task_weights is not None and sum(task_weights) <= 0:
455
+ raise ValueError("task_weights must have a positive total")
456
+ trajectories = []
457
+ for i in range(n_trajectories):
458
+ sampled_task = random.choices(task_levels, weights=task_weights, k=1)[0]
459
+ env.set_task(sampled_task)
460
+ traj = collect_trajectory(env, model, tokenizer, max_steps)
461
+ total_reward = sum(traj.rewards)
462
+ print(f"Trajectory {i+1}/{n_trajectories}: task={sampled_task}, steps={len(traj)}, reward={total_reward:.3f}")
463
+ trajectories.append(traj)
464
+ return trajectories
465
+
466
+ def compute_returns_and_advantages(rewards, dones, gamma=0.99, standardize=True):
467
+ """
468
+ Compute discounted returns and REINFORCE-style baseline advantages.
469
+ Advantages are centered and optionally standardized; a short worked example follows this function.
470
+ """
471
+ n = len(rewards)
472
+ returns = [0.0]*n
473
+ running = 0.0
474
+ for t in reversed(range(n)):
475
+ if dones[t]: running = 0.0
476
+ running = rewards[t] + gamma * running
477
+ returns[t] = running
478
+ if standardize:
479
+ advantages = np.array(returns) - np.mean(returns)
480
+ adv_std = np.std(advantages) + 1e-8
481
+ advantages = (advantages / adv_std).tolist()
482
+ else:
483
+ advantages = returns.copy()
484
+ return advantages, returns
485
+
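A minimal worked example of the function above, for a single three-step episode with gamma = 0.9 (values rounded):

# rewards = [0.0, 0.0, 1.0], dones = [False, False, True]
# Working backwards: G2 = 1.0, G1 = 0.9 * 1.0 = 0.9, G0 = 0.9 * 0.9 = 0.81
advantages, returns = compute_returns_and_advantages(
    [0.0, 0.0, 1.0], [False, False, True], gamma=0.9
)
# returns    ~= [0.81, 0.90, 1.00]
# advantages ~= [-1.20, -0.04, +1.25]   (mean-centred, divided by the std of the returns)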
486
+ def ppo_update(trajectories, model, tokenizer, optimizer, n_epochs=1, clip_epsilon=0.2,
487
+ entropy_coef=0.01, gamma=0.99):
488
+ model.train()
489
+ all_states, all_actions, all_old_logprobs, all_advantages = [], [], [], []
490
+ for traj in trajectories:
491
+ advantages, _ = compute_returns_and_advantages(traj.rewards, traj.dones, gamma=gamma, standardize=True)
492
+ all_states.extend(traj.states)
493
+ all_actions.extend(traj.actions)
494
+ all_old_logprobs.extend(traj.logprobs)
495
+ all_advantages.extend(advantages)
496
+ n_samples = len(all_states)
497
+ total_loss, total_policy_loss, total_entropy, n_updates = 0.0, 0.0, 0.0, 0
498
+ for epoch in range(n_epochs):
499
+ indices = np.random.permutation(n_samples)
500
+ for i in indices:
501
+ state = all_states[i]
502
+ action = all_actions[i]
503
+ old_logprob = all_old_logprobs[i]
504
+ advantage = all_advantages[i]
505
+ messages = [{"role": "user", "content": state}]
506
+ formatted = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
507
+ full_text = formatted + action
508
+ inputs = tokenizer(full_text, return_tensors="pt", max_length=480, truncation=True).to("cuda")
509
+ outputs = model(**inputs)
510
+ logits = outputs.logits
511
+ action_ids = tokenizer.encode(action, add_special_tokens=False)
512
+ prefix_ids = tokenizer.encode(formatted, add_special_tokens=False)
513
+ action_start = len(prefix_ids)
514
+ logprobs = []
515
+ entropy = 0.0
516
+ for idx, token_id in enumerate(action_ids):
517
+ position = action_start + idx - 1
518
+ if 0 <= position < logits.shape[1]:
519
+ token_logits = logits[0, position]
520
+ log_probs = F.log_softmax(token_logits, dim=-1)
521
+ token_logprob = log_probs[token_id]
522
+ logprobs.append(token_logprob)
523
+ probs = F.softmax(token_logits, dim=-1)
524
+ entropy += -(probs * log_probs).sum()
525
+ if not logprobs: continue
526
+ new_logprob = sum(logprobs)
527
+ avg_entropy = entropy / len(logprobs) if logprobs else 0.0
528
+ ratio = torch.exp(new_logprob - old_logprob)
529
+ surr1 = ratio * advantage
530
+ surr2 = torch.clamp(ratio, 1 - clip_epsilon, 1 + clip_epsilon) * advantage
531
+ policy_loss = -torch.min(surr1, surr2)
532
+ loss = policy_loss - entropy_coef * avg_entropy
533
+ optimizer.zero_grad()
534
+ loss.backward()
535
+ torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
536
+ optimizer.step()
537
+ total_loss += loss.item()
538
+ total_policy_loss += policy_loss.item()
539
+ total_entropy += avg_entropy.item()
540
+ n_updates += 1
541
+ torch.cuda.empty_cache()
542
+ return {"loss": total_loss / n_updates if n_updates else 0.0,
543
+ "policy_loss": total_policy_loss / n_updates if n_updates else 0.0,
544
+ "entropy": total_entropy / n_updates if n_updates else 0.0}
545
+
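For intuition, the clipped surrogate inside ppo_update bounds how much a single update can exploit a large importance ratio. A small self-contained sketch with made-up numbers (same formula, clip_epsilon = 0.2):

import torch

def clipped_surrogate(new_logprob, old_logprob, advantage, clip_epsilon=0.2):
    # Mirrors the per-sample objective in ppo_update above.
    ratio = torch.exp(new_logprob - old_logprob)
    surr1 = ratio * advantage
    surr2 = torch.clamp(ratio, 1 - clip_epsilon, 1 + clip_epsilon) * advantage
    return -torch.min(surr1, surr2)

# ratio = exp(0.5) ~= 1.65 is clipped to 1.2 for a positive advantage,
# so the loss is -(1.2 * 2.0) = -2.4 instead of about -3.3:
loss = clipped_surrogate(torch.tensor(-1.0), torch.tensor(-1.5), torch.tensor(2.0))
print(round(loss.item(), 2))  # -2.4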
546
+ def evaluate_policy(env, model, tokenizer, n_episodes=3, max_steps=6,
547
+ task_levels=None, verbose=False):
548
+ """Evaluate the current policy across task levels. Returns metrics + optional traces."""
549
+ model.eval()
550
+ if task_levels is None:
551
+ task_levels = list(BUG_DB.keys())
552
+ total_rewards = []
553
+ traces = [] # human-readable behavior logs
554
+ for ep in range(n_episodes):
555
+ task = task_levels[ep % len(task_levels)]
556
+ env.set_task(task)
557
+ traj = collect_trajectory(env, model, tokenizer, max_steps, temperature=0.0)
558
+ ep_reward = sum(traj.rewards)
559
+ total_rewards.append(ep_reward)
560
+ if verbose:
561
+ actions_taken = []
562
+ for a in traj.actions:
563
+ try:
564
+ actions_taken.append(json.loads(a).get("action_type", "?"))
565
+ except Exception:
566
+ actions_taken.append("?")
567
+ traces.append({
568
+ "task": task,
569
+ "reward": round(ep_reward, 4),
570
+ "steps": len(traj),
571
+ "actions": actions_taken,
572
+ })
573
+ return {
574
+ "avg_reward": float(np.mean(total_rewards)),
575
+ "std_reward": float(np.std(total_rewards)),
576
+ "min_reward": float(np.min(total_rewards)),
577
+ "max_reward": float(np.max(total_rewards)),
578
+ "traces": traces,
579
+ }
580
+
581
+ # ======================================================================
582
+ # MANUAL WARM-UP (no SFTTrainer β†’ no multiprocessing OOM)
583
+ # ======================================================================
584
+ def json_warmup(model, tokenizer, json_path="training_data.json",
585
+ n_episodes=20, epochs=2, lr=2e-5):
586
+ """
587
+ Supervised warm-up from pre-generated expert demonstrations.
588
+ Uses raw cross-entropy on the action tokens only, with manual gradient steps (see the token-alignment sketch after this function).
589
+ NO SFTTrainer, NO multiprocessing – runs safely on any GPU.
590
+ """
591
+ print("\n" + "="*60)
592
+ print("SUPERVISED WARM-UP: training_data.json (manual cross-entropy)")
593
+ print("="*60)
594
+
595
+ with open(json_path, encoding="utf-8") as f:
596
+ data = json.load(f)
597
+
598
+ # Each episode = 7 steps. Select n_episodes worth.
599
+ steps_per_episode = 7
600
+ max_examples = n_episodes * steps_per_episode
601
+ if max_examples < len(data):
602
+ data = data[:max_examples]
603
+
604
+ print(f" {len(data)} examples ({len(data)//steps_per_episode} episodes), "
605
+ f"{epochs} epoch(s), lr={lr}")
606
+
607
+ model.train()
608
+ warmup_opt = AdamW(model.parameters(), lr=lr)
609
+ warmup_losses = [] # per-epoch avg loss
610
+
611
+ for epoch in range(epochs):
612
+ random.shuffle(data)
613
+ epoch_loss = 0.0
614
+ n_valid = 0
615
+
616
+ for i, example in enumerate(data):
617
+ prompt = example["prompt"]
618
+ action = example["action"]
619
+
620
+ # ---- tokenize full sequence (prompt + action) ----
621
+ messages = [
622
+ {"role": "user", "content": prompt},
623
+ {"role": "assistant", "content": action},
624
+ ]
625
+ full_text = tokenizer.apply_chat_template(messages, tokenize=False)
626
+ inputs = tokenizer(full_text, return_tensors="pt",
627
+ max_length=480, truncation=True).to("cuda")
628
+
629
+ # ---- find where the action tokens start ----
630
+ prompt_only = tokenizer.apply_chat_template(
631
+ [{"role": "user", "content": prompt}],
632
+ tokenize=False, add_generation_prompt=True
633
+ )
634
+ prompt_ids = tokenizer.encode(prompt_only, add_special_tokens=False)
635
+ prompt_len = len(prompt_ids)
636
+
637
+ total_len = inputs.input_ids.shape[1]
638
+ if prompt_len >= total_len:
639
+ continue # prompt was truncated away, skip
640
+
641
+ # ---- cross-entropy on action tokens only ----
642
+ outputs = model(**inputs)
643
+ logits = outputs.logits
644
+
645
+ # next-token prediction: logits[t] predicts token[t+1]
646
+ shift_logits = logits[0, prompt_len - 1 : total_len - 1]
647
+ shift_labels = inputs.input_ids[0, prompt_len : total_len]
648
+
649
+ min_len = min(shift_logits.shape[0], shift_labels.shape[0])
650
+ if min_len == 0:
651
+ continue
652
+
653
+ loss = F.cross_entropy(shift_logits[:min_len], shift_labels[:min_len])
654
+
655
+ warmup_opt.zero_grad()
656
+ loss.backward()
657
+ torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
658
+ warmup_opt.step()
659
+
660
+ epoch_loss += loss.item()
661
+ n_valid += 1
662
+
663
+ if (i + 1) % 25 == 0:
664
+ avg = epoch_loss / n_valid
665
+ print(f" epoch {epoch+1} step {i+1:3d}/{len(data)} "
666
+ f"running_loss={avg:.4f}")
667
+
668
+ avg_loss = epoch_loss / max(n_valid, 1)
669
+ warmup_losses.append(avg_loss)
670
+ print(f" Epoch {epoch+1} done: avg_loss={avg_loss:.4f} "
671
+ f"({n_valid} valid examples)")
672
+
673
+ torch.cuda.empty_cache()
674
+ print(f"βœ“ Warm-up complete. Loss: "
675
+ f"{' β†’ '.join(f'{l:.4f}' for l in warmup_losses)}\n")
676
+ return warmup_losses
677
+
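The slicing in the loss above is the standard next-token shift, restricted to the action span: logits at position t score the token at position t + 1, so positions prompt_len - 1 through total_len - 2 are paired with labels prompt_len through total_len - 1. A small sketch with hypothetical token positions:

# Hypothetical tokenization: 5 prompt tokens followed by 3 action tokens.
# input_ids = [p0, p1, p2, p3, p4, a0, a1, a2]   (prompt_len = 5, total_len = 8)
prompt_len, total_len = 5, 8
scored_positions = list(range(prompt_len - 1, total_len - 1))  # [4, 5, 6] -> predict a0, a1, a2
scored_labels = list(range(prompt_len, total_len))             # [5, 6, 7] -> the action tokens
assert len(scored_positions) == len(scored_labels) == total_len - prompt_len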
678
+
679
+ # ======================================================================
680
+ # MAIN TRAINING PIPELINE
681
+ # ======================================================================
682
+ def train_ppo():
683
+ # --- Hyperparameters ---
684
+ n_iterations = 8 # enough for a clear upward trend
685
+ trajectories_per_iter = 4 # on-policy data per iteration
686
+ n_epochs = 1
687
+ max_steps = 6
688
+ learning_rate = 3e-5
689
+ clip_epsilon = 0.2
690
+ entropy_coef = 0.01
691
+ gamma = 0.99
692
+
693
+ # --- Pre-load embedder before LLM (Issue #13) ---
694
+ from rltool import ToolBox
695
+ print("Pre-loading sentence-transformer embedder...")
696
+ ToolBox._get_embedder()
697
+ print("βœ“ Embedder ready")
698
+
699
+ # --- Load model ---
700
+ print("Loading model...")
701
+ model, tokenizer = load_model()
702
+ if not test_model_sanity(model, tokenizer):
703
+ return
704
+ env = CodeReviewEnv()
705
+ task_levels = list(BUG_DB.keys())
706
+
707
+ # ==================================================================
708
+ # PHASE 0: BASELINE (untrained policy)
709
+ # ==================================================================
710
+ print("\n" + "="*60)
711
+ print("PHASE 0 – BASELINE EVALUATION (untrained)")
712
+ print("="*60)
713
+ baseline = evaluate_policy(env, model, tokenizer, n_episodes=5,
714
+ max_steps=max_steps, task_levels=task_levels,
715
+ verbose=True)
716
+ baseline_reward = baseline["avg_reward"]
717
+ print(f"Baseline avg reward: {baseline_reward:.4f} "
718
+ f"(min={baseline['min_reward']:.4f}, max={baseline['max_reward']:.4f})")
719
+ print("Baseline behavior:")
720
+ for t in baseline["traces"]:
721
+ print(f" task={t['task']:8s} reward={t['reward']:+.4f} "
722
+ f"steps={t['steps']} actions={t['actions']}")
723
+
724
+ # ==================================================================
725
+ # PHASE 1: SUPERVISED WARM-UP (expert demos, manual CE)
726
+ # ==================================================================
727
+ warmup_losses = json_warmup(
728
+ model, tokenizer,
729
+ json_path="training_data.json",
730
+ n_episodes=20, # 140 examples (20 Γ— 7 steps)
731
+ epochs=2,
732
+ lr=2e-5,
733
+ )
734
+
735
+ # Post-warmup evaluation
736
+ print("="*60)
737
+ print("POST WARM-UP EVALUATION")
738
+ print("="*60)
739
+ post_warmup = evaluate_policy(env, model, tokenizer, n_episodes=5,
740
+ max_steps=max_steps, task_levels=task_levels,
741
+ verbose=True)
742
+ warmup_reward = post_warmup["avg_reward"]
743
+ print(f"Post-warmup avg reward: {warmup_reward:.4f} "
744
+ f"(Ξ” vs baseline: {warmup_reward - baseline_reward:+.4f})")
745
+ print("Post-warmup behavior:")
746
+ for t in post_warmup["traces"]:
747
+ print(f" task={t['task']:8s} reward={t['reward']:+.4f} "
748
+ f"steps={t['steps']} actions={t['actions']}")
749
+
750
+ # ==================================================================
751
+ # PHASE 2: TRUE RL – PPO (on-policy, real environment interaction)
752
+ # ==================================================================
753
+ optimizer = AdamW(model.parameters(), lr=learning_rate)
754
+ print(f"\n{'='*60}")
755
+ print(f"PHASE 2 – PPO TRAINING: {n_iterations} iterations Γ— "
756
+ f"{trajectories_per_iter} trajectories (true RL)")
757
+ print(f"{'='*60}\n")
758
+
759
+ reward_history = []
760
+ eval_history = []
761
+ loss_history = []
762
+ policy_loss_history = []
763
+ entropy_history = []
764
+
765
+ for iteration in range(n_iterations):
766
+ print(f"\n--- PPO Iteration {iteration + 1}/{n_iterations} ---")
767
+
768
+ # Collect on-policy trajectories from REAL environment
769
+ trajectories = collect_trajectories(
770
+ env, model, tokenizer, trajectories_per_iter, max_steps,
771
+ task_levels=task_levels, task_weights=None
772
+ )
773
+ avg_reward = float(np.mean([sum(t.rewards) for t in trajectories]))
774
+ reward_history.append(avg_reward)
775
+ print(f" Collect avg reward: {avg_reward:+.4f}")
776
+
777
+ # PPO policy gradient update
778
+ metrics = ppo_update(
779
+ trajectories, model, tokenizer, optimizer,
780
+ n_epochs=n_epochs, clip_epsilon=clip_epsilon,
781
+ entropy_coef=entropy_coef, gamma=gamma
782
+ )
783
+ loss_history.append(float(metrics["loss"]))
784
+ policy_loss_history.append(float(metrics["policy_loss"]))
785
+ entropy_history.append(float(metrics["entropy"]))
786
+ print(f" Update loss={metrics['loss']:.4f} "
787
+ f"policy={metrics['policy_loss']:.4f} "
788
+ f"entropy={metrics['entropy']:.4f}")
789
+
790
+ # Evaluate greedy policy after update
791
+ eval_m = evaluate_policy(env, model, tokenizer, n_episodes=3,
792
+ max_steps=max_steps, task_levels=task_levels,
793
+ verbose=False)
794
+ eval_history.append(eval_m["avg_reward"])
795
+ delta = eval_m["avg_reward"] - baseline_reward
796
+ print(f" Eval avg reward: {eval_m['avg_reward']:+.4f} "
797
+ f"(Ξ” baseline: {delta:+.4f})")
798
+
799
+ # ==================================================================
800
+ # PHASE 3: FINAL EVALUATION (proof of learning)
801
+ # ==================================================================
802
+ print("\n" + "="*60)
803
+ print("PHASE 3 – FINAL EVALUATION (after all training)")
804
+ print("="*60)
805
+ final = evaluate_policy(env, model, tokenizer, n_episodes=5,
806
+ max_steps=max_steps, task_levels=task_levels,
807
+ verbose=True)
808
+ print(f"Final avg reward: {final['avg_reward']:.4f} "
809
+ f"(min={final['min_reward']:.4f}, max={final['max_reward']:.4f})")
810
+ print("Final behavior:")
811
+ for t in final["traces"]:
812
+ print(f" task={t['task']:8s} reward={t['reward']:+.4f} "
813
+ f"steps={t['steps']} actions={t['actions']}")
814
+
815
+ total_improvement = final["avg_reward"] - baseline_reward
816
+ ppo_improvement = final["avg_reward"] - warmup_reward
817
+ print(f"\n{'='*60}")
818
+ print("TRAINING SUMMARY")
819
+ print(f" Baseline reward: {baseline_reward:+.4f}")
820
+ print(f" Post-warmup reward: {warmup_reward:+.4f} "
821
+ f"(warmup Ξ”: {warmup_reward - baseline_reward:+.4f})")
822
+ print(f" Final reward: {final['avg_reward']:+.4f} "
823
+ f"(PPO Ξ”: {ppo_improvement:+.4f})")
824
+ print(f" Total improvement: {total_improvement:+.4f}")
825
+ print(f" Reward trend (PPO): {' β†’ '.join(f'{r:+.3f}' for r in reward_history)}")
826
+ print(f" Loss trend (PPO): {' β†’ '.join(f'{l:.4f}' for l in loss_history)}")
827
+ if total_improvement > 0:
828
+ print(f" βœ“ Agent IMPROVED by {total_improvement:+.4f}")
829
+ else:
830
+ print(f" βœ— No overall improvement detected")
831
+ print(f"{'='*60}")
832
+
833
+ # ==================================================================
834
+ # PLOTS
835
+ # ==================================================================
836
+ iters = list(range(1, n_iterations + 1))
837
+
838
+ # --- 1. Warm-up loss curve ---
839
+ if warmup_losses:
840
+ fig, ax = plt.subplots(figsize=(7, 4))
841
+ ax.plot(range(1, len(warmup_losses) + 1), warmup_losses,
842
+ marker="o", linewidth=2, color="tab:purple")
843
+ ax.set_title("Warm-up Loss (supervised, per epoch)",
844
+ fontsize=13, fontweight="bold")
845
+ ax.set_xlabel("Epoch")
846
+ ax.set_ylabel("Cross-Entropy Loss")
847
+ ax.grid(alpha=0.3)
848
+ fig.tight_layout()
849
+ fig.savefig("warmup_loss.png", dpi=150)
850
+ plt.close(fig)
851
+
852
+ # --- 2. PPO reward curve ---
853
+ fig, ax = plt.subplots(figsize=(9, 5))
854
+ ax.plot(iters, reward_history, marker="o", linewidth=2,
855
+ label="Collect reward", color="tab:blue")
856
+ ax.plot(iters, eval_history, marker="s", linewidth=2, linestyle="--",
857
+ label="Eval reward", color="tab:green")
858
+ ax.axhline(y=baseline_reward, color="tab:gray", linestyle=":",
859
+ linewidth=1.5, label=f"Baseline ({baseline_reward:+.3f})")
860
+ ax.axhline(y=warmup_reward, color="tab:purple", linestyle=":",
861
+ linewidth=1.5, label=f"Post-warmup ({warmup_reward:+.3f})")
862
+ ax.set_title("PPO Reward per Iteration", fontsize=14, fontweight="bold")
863
+ ax.set_xlabel("Iteration")
864
+ ax.set_ylabel("Average Reward")
865
+ ax.legend(loc="best", fontsize=8)
866
+ ax.grid(alpha=0.3)
867
+ fig.tight_layout()
868
+ fig.savefig("reward_curve.png", dpi=150)
869
+ plt.close(fig)
870
+
871
+ # --- 3. PPO loss curve ---
872
+ fig, ax = plt.subplots(figsize=(9, 5))
873
+ ax.plot(iters, loss_history, marker="o", linewidth=2,
874
+ label="Total loss", color="tab:red")
875
+ ax.plot(iters, policy_loss_history, marker="^", linewidth=2, linestyle="--",
876
+ label="Policy loss", color="tab:orange")
877
+ ax.set_title("PPO Loss per Iteration", fontsize=14, fontweight="bold")
878
+ ax.set_xlabel("Iteration")
879
+ ax.set_ylabel("Loss")
880
+ ax.legend(loc="best")
881
+ ax.grid(alpha=0.3)
882
+ fig.tight_layout()
883
+ fig.savefig("loss_curve.png", dpi=150)
884
+ plt.close(fig)
885
+
886
+ # --- 4. Combined 3-panel summary ---
887
+ fig, axes = plt.subplots(1, 3, figsize=(18, 5))
888
+
889
+ # Panel A: warm-up loss
890
+ if warmup_losses:
891
+ axes[0].plot(range(1, len(warmup_losses) + 1), warmup_losses,
892
+ marker="o", linewidth=2, color="tab:purple")
893
+ axes[0].set_title("A. Warm-up Loss ↓")
894
+ axes[0].set_xlabel("Epoch")
895
+ axes[0].set_ylabel("CE Loss")
896
+ axes[0].grid(alpha=0.3)
897
+
898
+ # Panel B: PPO reward
899
+ axes[1].plot(iters, reward_history, marker="o", linewidth=2,
900
+ color="tab:blue", label="Collect")
901
+ axes[1].plot(iters, eval_history, marker="s", linewidth=2,
902
+ linestyle="--", color="tab:green", label="Eval")
903
+ axes[1].axhline(y=baseline_reward, color="tab:gray", linestyle=":",
904
+ linewidth=1.5, label="Baseline")
905
+ axes[1].axhline(y=warmup_reward, color="tab:purple", linestyle=":",
906
+ linewidth=1.5, label="Post-warmup")
907
+ axes[1].set_title("B. PPO Reward ↑")
908
+ axes[1].set_xlabel("Iteration")
909
+ axes[1].set_ylabel("Avg Reward")
910
+ axes[1].legend(fontsize=7)
911
+ axes[1].grid(alpha=0.3)
912
+
913
+ # Panel C: PPO loss
914
+ axes[2].plot(iters, loss_history, marker="o", linewidth=2,
915
+ color="tab:red", label="Total")
916
+ axes[2].plot(iters, policy_loss_history, marker="^", linewidth=2,
917
+ linestyle="--", color="tab:orange", label="Policy")
918
+ axes[2].set_title("C. PPO Loss ↓")
919
+ axes[2].set_xlabel("Iteration")
920
+ axes[2].set_ylabel("Loss")
921
+ axes[2].legend(fontsize=7)
922
+ axes[2].grid(alpha=0.3)
923
+
924
+ fig.suptitle("Code Review Agent – Full Training Evidence",
925
+ fontsize=14, fontweight="bold")
926
+ fig.tight_layout()
927
+ fig.savefig("training_summary.png", dpi=150)
928
+ plt.close(fig)
929
+
930
+ print("Plots saved: warmup_loss.png, reward_curve.png, "
931
+ "loss_curve.png, training_summary.png")
932
+ print("="*60)
933
+
934
+ if __name__ == "__main__":
935
  train_ppo()
training_data.json CHANGED
The diff for this file is too large to render. See raw diff