Upload 23 files
- .gitattributes +35 -35
- .gitignore +3 -0
- Dockerfile +23 -23
- README.md +117 -14
- __init__.py +12 -10
- app.py +104 -104
- environment.py +624 -624
- grader.py +13 -12
- models.py +111 -111
- pyproject.toml +38 -30
- requirements-training.txt +9 -0
- requirements.txt +5 -11
- rltool.py +143 -127
- rubrics.py +136 -136
- training.py +934 -934
- training_data.json +0 -0
.gitattributes
CHANGED
@@ -1,35 +1,35 @@
(all 35 lines are marked changed, but the old and new text is identical; the content is listed once below)
*.7z filter=lfs diff=lfs merge=lfs -text
*.arrow filter=lfs diff=lfs merge=lfs -text
*.bin filter=lfs diff=lfs merge=lfs -text
*.bz2 filter=lfs diff=lfs merge=lfs -text
*.ckpt filter=lfs diff=lfs merge=lfs -text
*.ftz filter=lfs diff=lfs merge=lfs -text
*.gz filter=lfs diff=lfs merge=lfs -text
*.h5 filter=lfs diff=lfs merge=lfs -text
*.joblib filter=lfs diff=lfs merge=lfs -text
*.lfs.* filter=lfs diff=lfs merge=lfs -text
*.mlmodel filter=lfs diff=lfs merge=lfs -text
*.model filter=lfs diff=lfs merge=lfs -text
*.msgpack filter=lfs diff=lfs merge=lfs -text
*.npy filter=lfs diff=lfs merge=lfs -text
*.npz filter=lfs diff=lfs merge=lfs -text
*.onnx filter=lfs diff=lfs merge=lfs -text
*.ot filter=lfs diff=lfs merge=lfs -text
*.parquet filter=lfs diff=lfs merge=lfs -text
*.pb filter=lfs diff=lfs merge=lfs -text
*.pickle filter=lfs diff=lfs merge=lfs -text
*.pkl filter=lfs diff=lfs merge=lfs -text
*.pt filter=lfs diff=lfs merge=lfs -text
*.pth filter=lfs diff=lfs merge=lfs -text
*.rar filter=lfs diff=lfs merge=lfs -text
*.safetensors filter=lfs diff=lfs merge=lfs -text
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
*.tar.* filter=lfs diff=lfs merge=lfs -text
*.tar filter=lfs diff=lfs merge=lfs -text
*.tflite filter=lfs diff=lfs merge=lfs -text
*.tgz filter=lfs diff=lfs merge=lfs -text
*.wasm filter=lfs diff=lfs merge=lfs -text
*.xz filter=lfs diff=lfs merge=lfs -text
*.zip filter=lfs diff=lfs merge=lfs -text
*.zst filter=lfs diff=lfs merge=lfs -text
*tfevents* filter=lfs diff=lfs merge=lfs -text
.gitignore
ADDED
@@ -0,0 +1,3 @@
+__pycache__/
+*.pyc
+tmp/
Dockerfile
CHANGED
@@ -1,24 +1,24 @@
(all lines are marked changed, but the old and new Dockerfile are identical; the content is listed once below)
# Dockerfile - OpenEnv server with FastAPI and all dependencies
FROM python:3.10-slim

# Install system dependencies required for chromadb and sentence-transformers
RUN apt-get update && apt-get install -y --no-install-recommends \
    build-essential \
    && rm -rf /var/lib/apt/lists/*

WORKDIR /app

# Copy requirements and install Python dependencies
COPY requirements.txt .
RUN pip install --no-cache-dir -r requirements.txt

# Copy the rest of the application
COPY . .

# Expose the port used by the FastAPI server
EXPOSE 7860

# Run the server using uvicorn
# Note: 'server.app:app' assumes the FastAPI app is in server/app.py
ENV ENABLE_WEB_INTERFACE=true
CMD ["uvicorn", "server.app:app", "--host", "0.0.0.0", "--port", "7860"]
README.md
CHANGED
@@ -1,14 +1,117 @@
- ---
- title: CodeReview Training
- emoji: π€
- colorFrom: blue
- colorTo: green
- sdk:
- (old lines 7-14 are blank in the diff)
(the new 117-line README follows below)
---
title: CodeReview Training
emoji: "π€"
colorFrom: blue
colorTo: green
sdk: docker
pinned: false
---

# CodeReview Professional Workflow

`CodeReview Professional Workflow` is an OpenEnv environment for training code-fixing agents on realistic review loops instead of one-shot coding tasks. The agent has to inspect buggy code, run tests, lint the patch, query docs, and persuade a simulated author before the episode is considered solved.

## Quick links

| Artifact | Link |
| --- | --- |
| Hugging Face Space | [100XZX001/CodeReview-Professional-Workflow](https://huggingface.co/spaces/100XZX001/CodeReview-Professional-Workflow) |
| Colab-ready training notebook | [notebooks/code_review_unsloth_training.ipynb](notebooks/code_review_unsloth_training.ipynb) |
| Local training script | [training.py](training.py) |
| OpenEnv manifest | [openenv.yaml](openenv.yaml) |
| Submission slide deck | [submission_assets/code_review_openenv_submission.pptx](submission_assets/code_review_openenv_submission.pptx) |
| Training artifacts folder | [outputs/README.md](outputs/README.md) |

## Why this environment

Most code agents are evaluated on static patch generation. Real review work is messier:

- you have to diagnose the failure mode before patching
- you often need tool feedback before you know whether the fix is safe
- you may need to explain the fix to another developer before it is accepted

This environment turns that workflow into a multi-step RL setting with dense rewards and stateful interaction.

## How the environment works

Each episode samples one injected bug from five difficulty bands:

1. `easy`: null checks, missing defaults, simple indexing mistakes
2. `medium`: off-by-one and wrong-operator bugs
3. `hard`: numerical safety failures like divide-by-zero
4. `harder`: concurrency issues like missing locks
5. `hardest`: deadlock and coordination mistakes

The agent can take actions such as:

- `inspect`
- `run_tests`
- `run_linter`
- `query_docs`
- `fix`
- `comment`
- `question`
- `done`

Rewards combine test delta, lint delta, tool usage, exploration behavior, step penalties, and terminal success. The observation includes the current code, latest tool output, previous scores, author confidence, progress counters, and recent action history.
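
To make the interaction loop concrete, here is a minimal sketch of driving one episode against the HTTP server defined in `app.py` (included later in this commit). It assumes the server is reachable on port 7860; the exact JSON shape of the action payload is defined by `models.py` and its `action_adapter`, so the `{"action_type": "run_tests"}` body below is a hypothetical placeholder, not the confirmed schema.

```python
# Minimal sketch of one episode over the HTTP API exposed by app.py.
# Assumption: the server runs locally on port 7860 (see the Dockerfile CMD);
# the action payload shape must match models.py's action_adapter.
import requests

BASE = "http://localhost:7860"

obs = requests.post(f"{BASE}/reset", params={"task": "easy"}).json()
print(obs["bug_description"])

done = False
while not done:
    # Hypothetical payload; a trained policy would choose the action here.
    step = requests.post(f"{BASE}/step", json={"action_type": "run_tests"}).json()
    obs, reward, done = step["observation"], step["reward"], step["done"]
    print(f"reward={reward:.3f} done={done}")
```

The loop always terminates because the environment forces `done=True` once `max_steps` (10 by default) is reached.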

## OpenEnv-first setup

This repo is structured as an OpenEnv environment rather than a custom one-off app:

- the environment metadata lives in [openenv.yaml](openenv.yaml)
- the Space is configured as a Docker-based OpenEnv deployment
- runtime dependencies are kept lightweight for the Space build
- training-only packages live separately so judges can run the environment without pulling the full training stack

The project now targets `openenv-core>=0.2.3`.

## Training

The main training entrypoint is [training.py](training.py), which uses Unsloth plus a PPO-style loop over real environment interaction. For judges who want a rerunnable workflow, the repo also includes a Colab-ready notebook:

- [notebooks/code_review_unsloth_training.ipynb](notebooks/code_review_unsloth_training.ipynb)
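
For readers who want to see what "real environment interaction" looks like without the full PPO machinery, here is a small in-process sketch. The `CodeReviewEnv` constructor arguments and the `fix_code` field are taken from `environment.py` in this commit; the no-argument `RunTests()` constructor is an assumption about `models.py`.

```python
# In-process sketch of the interface training.py builds its rollout loop around.
from environment import CodeReviewEnv
from models import RunTests, ProposeFix

env = CodeReviewEnv(task="hard", max_steps=10, reward_profile="full")
obs = env.reset()
print(obs.bug_description)

# Gather evidence first, then submit a patch (here just echoing the buggy code).
obs, reward, done, info = env.step(RunTests())
obs, reward, done, info = env.step(ProposeFix(fix_code=obs.code_snippet))
print(info["test_delta"], info["final_reward"], done)
```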

### Install locally

```bash
pip install -e .
pip install -r requirements-training.txt
```

### Run training

```bash
python training.py
```

The training run writes these evidence plots to the working directory:

- `warmup_loss.png`
- `reward_curve.png`
- `loss_curve.png`
- `training_summary.png`

For submission hygiene, copy a real run into `outputs/<run-name>/` and link that folder from this README before final judging.

## Results and evidence

The expected evidence bundle for a real training run is:

- warm-up loss curve
- PPO reward curve
- PPO loss curve
- combined summary panel

Use [outputs/README.md](outputs/README.md) as the landing page for committed run artifacts.

## Submission materials

This repo is set up so every judge-facing artifact can be reached from this README:

- environment Space: [100XZX001/CodeReview-Professional-Workflow](https://huggingface.co/spaces/100XZX001/CodeReview-Professional-Workflow)
- training notebook: [notebooks/code_review_unsloth_training.ipynb](notebooks/code_review_unsloth_training.ipynb)
- slide deck: [submission_assets/code_review_openenv_submission.pptx](submission_assets/code_review_openenv_submission.pptx)
- evidence folder: [outputs/README.md](outputs/README.md)

No large video files are stored in the repo; any future video or blog submission should be linked by URL from this README.
__init__.py
CHANGED
@@ -4,13 +4,15 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.

-"""
-
-from .client import
-from .models import
-
-__all__ = [
-    "
-    "
-    "
-
+"""Code Review Professional Workflow OpenEnv package."""
+
+from .client import CodeReviewEnv
+from .models import AnyAction, Observation, Reward, State
+
+__all__ = [
+    "AnyAction",
+    "Observation",
+    "Reward",
+    "State",
+    "CodeReviewEnv",
+]
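Assuming the package is installed with `pip install -e .`, these exports make the client and shared models importable from the top-level package; the package name below is a placeholder, since the actual distribution name is defined in `pyproject.toml` rather than shown here.

```python
# Hypothetical package name; the real import path comes from pyproject.toml.
from code_review_env import CodeReviewEnv, AnyAction, Observation, Reward, State
```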
app.py
CHANGED
@@ -1,104 +1,104 @@
(all 104 lines are marked changed, but the old and new file content is identical; it is listed once below)
# server/app.py - OpenEnv HTTP server
import sys
import os
sys.path.insert(0, os.path.dirname(os.path.dirname(__file__)))

from fastapi import FastAPI, HTTPException
from environment import CodeReviewEnv
from models import AnyAction, Observation, Reward, State, action_adapter

app = FastAPI(title="Code Review Environment", version="1.0.0")
env = CodeReviewEnv()

# ----------------------------------------------------------------------
# Health & metadata endpoints
# ----------------------------------------------------------------------
@app.get("/")
def root():
    print("[ROOT] Health check hit")
    return {"status": "crazy good"}

@app.get("/health")
def health():
    print("[HEALTH] Service is healthy")
    return {"status": "healthy"}

@app.get("/metadata")
def metadata():
    print("[METADATA] Requested")
    return {
        "name": "Code Review Professional Workflow",
        "description": (
            "Multi-turn code review environment for professional-level bug fixing. "
            "The agent must inspect, test, lint, query documentation, and negotiate with "
            "a simulated (persona-driven) author to get a fix accepted. "
            "Includes 25 bugs across 5 difficulty levels, AST-based injection, "
            "a reward-shaping system (full/core profiles), and curriculum learning. "
            "Designed for RL training (PPO, DPO, or any policy-gradient method)."
        )
    }

@app.get("/schema")
def schema():
    print("[SCHEMA] Requested")
    return {
        "action": AnyAction.model_json_schema(),
        "observation": Observation.model_json_schema(),
        "state": State.model_json_schema()
    }

@app.post("/mcp")
def mcp():
    print("[MCP] Ping received")
    return {"jsonrpc": "2.0", "result": None}

# ----------------------------------------------------------------------
# Environment endpoints
# ----------------------------------------------------------------------
@app.post("/reset")
def reset(task: str = "easy"):
    try:
        print(f"[RESET] Starting new episode | task={task}")

        env.set_task(task)
        obs = env.reset()

        print(f"[RESET DONE] step={env._step_count}")

        return obs.__dict__
    except Exception as e:
        print(f"[RESET ERROR] {e}")
        raise HTTPException(status_code=400, detail=str(e))

@app.post("/step")
def step(action: dict):
    try:
        print(f"[STEP INPUT] {action}")

        parsed_action = action_adapter.validate_python(action)
        obs, reward, done, info = env.step(parsed_action)

        print(f"[STEP OUTPUT] reward={reward.value:.4f} | done={done}")

        return {
            "observation": obs.__dict__,
            "reward": reward.value,
            "done": done,
            "info": info
        }
    except Exception as e:
        print(f"[STEP ERROR] {e}")
        raise HTTPException(status_code=400, detail=str(e))

@app.get("/state")
def state():
    print("[STATE] Requested")
    return env._get_observation().__dict__

# ----------------------------------------------------------------------
# Main entry point (for local testing)
# ----------------------------------------------------------------------
if __name__ == "__main__":
    import uvicorn
    print("[SERVER START] Running on http://0.0.0.0:7860")
    uvicorn.run(app, host="0.0.0.0", port=7860)
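A quick way to check what `/step` expects is to read the JSON schemas the server already publishes. This sketch relies only on the `/schema` route shown above; the printed keys come straight from its return value.

```python
# Sketch: introspecting the action/observation/state schemas before building payloads.
import requests

schema = requests.get("http://localhost:7860/schema").json()
print(list(schema.keys()))   # ['action', 'observation', 'state']
print(schema["action"])      # JSON schema produced by AnyAction.model_json_schema()
```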
environment.py
CHANGED
@@ -1,628 +1,628 @@
(every line is marked changed with matching text on both sides; the file is listed once below)
# environment.py - FULLY CORRECTED RL Environment (TRUE Markov + Fixed Bugs)

import sys
import subprocess
import tempfile
import os
import re
from dataclasses import dataclass, field
from typing import Tuple, Dict, Any, Optional, List

from models import (
    AnyAction, WriteComment, ProposeFix, Execute, Inspect,
    RunLinter, RunTests, QueryDocs, Skip, Done, AskQuestion,
    Observation, Reward, State
)
from redteam import RedTeam
from test_runner import TestRunner
from author import PersonaAuthor
from rltool import ToolBox
from rubrics import (
    ToolUsageRubric,
    TestDeltaRubric,
    LintDeltaRubric,
    TerminalSuccessRubric,
    ExplorationRubric,
    AntiHackingRubric,
    StepPenaltyRubric,
)

# ======================================================================
# FULLY MARKOV OBSERVATION (NOTHING HIDDEN)
# ======================================================================
@dataclass
class EnhancedObservation:
    code_snippet: str
    last_tool_output: str

    current_test_score: float
    current_lint_score: float
    negotiation_score: float

    previous_test_score: float
    previous_lint_score: float

    author_confidence: float
    author_threshold: float

    step: int
    max_steps: int
    progress_ratio: float

    tests_run: bool
    linter_run: bool
    docs_queried: bool

    last_action_type: str
    action_history: List[str]

    done: bool

    bug_description: str
    comments_count: int

    # default fields must be at the very end
    author_response: str = ""

# ======================================================================
# HELPER FUNCTIONS
# ======================================================================
def execute_code(code: str, timeout_sec: int = 5) -> Tuple[bool, str, str]:
    if not code.strip():
        return False, "", "Error: Empty code"

    with tempfile.NamedTemporaryFile(mode='w', suffix='.py', delete=False, encoding='utf-8') as f:
        f.write(code)
        tmp_path = f.name

    try:
        result = subprocess.run(
            [sys.executable, tmp_path],
            capture_output=True,
            text=True,
            timeout=timeout_sec
        )
        success = (result.returncode == 0)
        return success, result.stdout, result.stderr
    except subprocess.TimeoutExpired:
        return False, "", f"Timeout after {timeout_sec}s"
    except Exception as e:
        return False, "", f"Execution error: {str(e)}"
    finally:
        try:
            os.unlink(tmp_path)
        except:
            pass


# ======================================================================
# ENHANCED CODE REVIEW ENVIRONMENT
# ======================================================================
@dataclass
class CodeReviewEnv:
    task: str = "easy"
    max_steps: int = 10
    step_penalty: float = 0.01
    reward_profile: str = "full"  # "full" or "core"

    # Curriculum learning
    auto_difficulty: bool = False
    success_threshold: float = 0.7

    # Reward shaping parameters
    delta_weight: float = 0.3
    tool_usage_bonus: float = 0.05
    diversity_bonus: float = 0.03

    _red_team: Optional[RedTeam] = field(init=False, default=None)
    _author: Optional[PersonaAuthor] = field(init=False, default=None)

    _current_code: str = field(init=False, default="")
    _current_bug_id: str = field(init=False, default="")
    _bug_description: str = field(init=False, default="")
    _oracle_fix: str = field(init=False, default="")

    _comments: list = field(init=False, default_factory=list)
    _test_results: Optional[str] = field(init=False, default=None)
    _lint_results: Optional[str] = field(init=False, default=None)
    _doc_results: Optional[str] = field(init=False, default=None)

    _step_count: int = field(init=False, default=0)
    _done: bool = field(init=False, default=False)

    # State tracking for dense rewards
    _previous_test_score: float = field(init=False, default=0.0)
    _previous_lint_score: float = field(init=False, default=0.0)
    _current_test_score: float = field(init=False, default=0.0)
    _current_lint_score: float = field(init=False, default=0.0)

    # Tool usage tracking
    _tests_run: bool = field(init=False, default=False)
    _linter_run: bool = field(init=False, default=False)
    _docs_queried: bool = field(init=False, default=False)

    # Action history
    _action_history: List[str] = field(init=False, default_factory=list)
    _last_action_type: str = field(init=False, default="none")
    _last_author_response: str = field(init=False, default="")

    # FIXED: Track CUMULATIVE episode reward
    _episode_total_reward: float = field(init=False, default=0.0)
    _episode_rewards: List[float] = field(init=False, default_factory=list)
    _difficulty_level: int = field(init=False, default=0)

    # Bug-id bridge:
    # RedTeam has fine-grained IDs, while TestRunner currently expects a
    # smaller canonical set. Keep this mapping here so both modules can evolve
    # independently without breaking evaluation.
    _BUG_ID_CANONICAL_MAP = {
        # Easy-family
        "simple_typo": "null_check",
        "default_value": "null_check",
        "empty_return": "null_check",
        "string_index": "off_by_one",

        # Medium-family
        "loop_skip": "off_by_one",
        "sign_error": "wrong_operator",
        "swap_args": "wrong_operator",
        "uninitialised_var": "null_check",

        # Hard-family
        "division_by_zero_empty": "division_by_zero",
        "division_by_zero_zero": "division_by_zero",
        "float_precision": "division_by_zero",
        "abs_usage": "division_by_zero",
        "round_error": "division_by_zero",
    }

    # ===================================================================
    def __post_init__(self):
        self.set_task(self.task)

    # ===================================================================
    def _build_rubrics(self):
        """
        Build rubric stack from a named reward profile.
        - full: richer shaping for exploration/tool-use behavior
        - core: minimal stable signal for quick ablations/baselines
        """
        core_rubrics = [
            TestDeltaRubric(weight=self.delta_weight),
            LintDeltaRubric(weight=self.delta_weight),
            TerminalSuccessRubric(),
            StepPenaltyRubric(penalty=self.step_penalty),
        ]
        if self.reward_profile == "core":
            return core_rubrics
        if self.reward_profile == "full":
            return [
                *core_rubrics[:-1],  # step penalty appended at end for consistent ordering
                ToolUsageRubric(bonus=self.tool_usage_bonus),
                ExplorationRubric(penalty=-0.05, bonus=self.diversity_bonus * 0.7),
                AntiHackingRubric(),
                core_rubrics[-1],
            ]
        raise ValueError(f"Unknown reward_profile: {self.reward_profile}")

    # ===================================================================
    def set_task(self, task: str):
        if task not in ["easy", "medium", "hard", "harder", "hardest"]:
            raise ValueError(f"Unknown task: {task}")

        self.task = task
        # Use stochastic bug sampling across episodes; fixed seed here would
        # repeatedly select the same bug and weaken training diversity.
        self._red_team = RedTeam(task, seed=None)
        self._author = PersonaAuthor()
        self.rubrics = self._build_rubrics()

        task_to_level = {
            "easy": 0, "medium": 1, "hard": 2,
            "harder": 3, "hardest": 4
        }
        self._difficulty_level = task_to_level[task]

        self._reset_internal()

    # ===================================================================
    def _reset_internal(self):
        self._step_count = 0  # ← FIXED
        self._comments = []
        self._test_results = None
        self._lint_results = None
        self._doc_results = None
        self._done = False

        # Reset state tracking
        self._previous_test_score = 0.0
        self._previous_lint_score = 0.0
        self._current_test_score = 0.0
        self._current_lint_score = 0.0

        self._tests_run = False
        self._linter_run = False
        self._docs_queried = False

        self._action_history = []
        self._last_action_type = "none"
        self._last_author_response = ""

        # FIXED: Reset episode cumulative reward
        self._episode_total_reward = 0.0

        self._author.reset()

        # Base tasks
        if self.task == "easy":
            original = "def get_user(id):\n if id in users:\n return users[id]"
        elif self.task == "medium":
            original = "def process_items(items):\n for item in items:\n print(item)"
        elif self.task == "hard":
            original = "def average(data):\n if not data:\n return 0\n return sum(data) / len(data)"
        elif self.task == "harder":
            original = "counter = 0\ndef increment():\n global counter\n with lock:\n counter += 1"
        else:
            original = "def safe_work():\n with lock1:\n with lock2:\n do_work()"

        buggy_code, bug_id, desc, oracle = self._red_team.inject_bug(original)
        self._current_code = buggy_code
        self._current_bug_id = bug_id
        self._bug_description = desc
        self._oracle_fix = oracle
        self._comments.append(f"[RedTeam] {desc}")

    # ===================================================================
    def reset(self) -> EnhancedObservation:
        """Reset with optional curriculum adjustment."""
        if self.auto_difficulty and len(self._episode_rewards) > 0:
            recent_performance = sum(self._episode_rewards[-5:]) / min(5, len(self._episode_rewards))

            if recent_performance > self.success_threshold and self._difficulty_level < 4:
                self._difficulty_level += 1
                print(f"[Curriculum] Increasing difficulty to level {self._difficulty_level}")
            elif recent_performance < 0.3 and self._difficulty_level > 0:
                self._difficulty_level -= 1
                print(f"[Curriculum] Decreasing difficulty to level {self._difficulty_level}")

            level_to_task = {0: "easy", 1: "medium", 2: "hard", 3: "harder", 4: "hardest"}
            self.task = level_to_task[self._difficulty_level]
            # Keep curriculum stochastic for better coverage within each level.
            self._red_team = RedTeam(self.task, seed=None)

        self._reset_internal()
        return self._get_observation()

    # ===================================================================
    def _get_observation(self) -> EnhancedObservation:
        """Return COMPLETE Markov state."""
        # Keep the author's message separate from tool output.
        # Using `_test_results` here can leak unrelated outputs (tests/linter/docs)
        # and gives the policy a noisy signal for dialogue actions.
        if self._last_action_type in ("comment", "question", "fix"):
            author_response = self._last_author_response
        else:
            author_response = ""

        return EnhancedObservation(
            code_snippet=self._current_code,
            last_tool_output=self._test_results or "",
            author_response=author_response,  # ← now field exists

            current_test_score=self._current_test_score,
            current_lint_score=self._current_lint_score,
            negotiation_score=self._author.get_negotiation_score(),

            previous_test_score=self._previous_test_score,
            previous_lint_score=self._previous_lint_score,

            author_confidence=self._author._confidence,
            author_threshold=self._author.thresholds.get(self._author.personality, 0.5),

            step=self._step_count,
            max_steps=self.max_steps,
            # Guard against accidental `max_steps=0` configs.
            progress_ratio=(self._step_count / self.max_steps) if self.max_steps > 0 else 1.0,

            tests_run=self._tests_run,
            linter_run=self._linter_run,
            docs_queried=self._docs_queried,

            last_action_type=self._last_action_type,
            action_history=self._action_history[-5:],

            done=self._done,

            bug_description=self._bug_description,
            comments_count=len(self._comments),
        )

    # ===================================================================
    def _get_action_type(self, action: AnyAction) -> str:
        """Extract action type as string."""
        if isinstance(action, RunTests):
            return "run_tests"
        elif isinstance(action, RunLinter):
            return "run_linter"
        elif isinstance(action, QueryDocs):
            return "query_docs"
        elif isinstance(action, Execute):
            return "execute"
        elif isinstance(action, Inspect):
            return "inspect"
        elif isinstance(action, WriteComment):
            return "comment"
        elif isinstance(action, AskQuestion):
            return "question"
        elif isinstance(action, ProposeFix):
            return "fix"
        elif isinstance(action, Done):
            return "done"
        elif isinstance(action, Skip):
            return "skip"
        else:
            return "unknown"

    # ===================================================================
    def _get_test_runner_bug_id(self) -> str:
        """
        Normalize RedTeam bug ids to the canonical ids understood by TestRunner.
        Falls back to the original id for known direct matches.
        """
        return self._BUG_ID_CANONICAL_MAP.get(self._current_bug_id, self._current_bug_id)

    # ===================================================================
    def step(self, action: AnyAction) -> Tuple[EnhancedObservation, Reward, bool, Dict[str, Any]]:
        """
        TRUE RL STEP with:
        - Complete Markov observations (no hidden state)
        - Dense intermediate rewards
        - Delta-based credit assignment (no double-counting)
        - Proper episode reward tracking
        """
        if self._done:
            raise RuntimeError("Episode already finished")

        # Store previous metrics for delta computation
        self._previous_test_score = self._current_test_score
        self._previous_lint_score = self._current_lint_score
        # Snapshot tool-usage flags BEFORE action mutates them.
        # Rubrics use these to detect true "first-use" behavior.
        prev_tests_run = self._tests_run
        prev_linter_run = self._linter_run
        prev_docs_queried = self._docs_queried

        base_reward = 0.0
        action_type = self._get_action_type(action)

        # Update action history
        self._action_history.append(action_type)
        self._last_action_type = action_type

        # ==============================================================
        # TOOL ACTIONS
        # ==============================================================
        if isinstance(action, Execute):
            success, stdout, stderr = execute_code(self._current_code)
            output = (stdout + stderr).strip() or "No output"
            self._test_results = f"[Execute] {'Success' if success else 'Failed'}\n{output[:300]}"
            base_reward = 0.001 if success else -0.05

        elif isinstance(action, Inspect):
            self._test_results = f"[Inspect]\n{self._current_code[:500]}"
            base_reward = 0.001

        elif isinstance(action, RunLinter):
            lint_output = ToolBox.run_linter(self._current_code)
            self._lint_results = lint_output[:500]
            self._test_results = f"[Linter]\n{self._lint_results}"

            self._current_lint_score = self._run_linter_score(self._current_code)
            self._linter_run = True
            base_reward = 0.002

        elif isinstance(action, RunTests):
            runner = TestRunner(self._get_test_runner_bug_id())
            score, output = runner.run_tests(self._current_code)

            self._current_test_score = score
            self._tests_run = True

            self._test_results = f"[Tests] Score: {score:.2f}\n{output[:300]}"
            base_reward = 0.002

            if score > 0.8:
                base_reward += 0.005

        elif isinstance(action, QueryDocs):
            # Normalize query to avoid rewarding empty/noisy requests.
            query_topic = (action.query_topic or "").strip()
            doc = ToolBox.query_docs(query_topic if query_topic else "general bug fixing")
            self._doc_results = doc
            self._test_results = f"[Docs]\n{doc[:400]}"
            self._docs_queried = True
            base_reward = 0.001

        # ==============================================================
        # COMMUNICATION ACTIONS
        # ==============================================================
        elif isinstance(action, WriteComment):
            self._comments.append(f"Agent: {action.comment_text}")

            response = self._author.respond(
                agent_comment=action.comment_text,
                test_results=self._test_results,
                lint_results=self._lint_results,
                doc_results=self._doc_results,
                proposed_fix=None,
                original_code=self._current_code
            )

            self._comments.append(f"Author: {response}")
            self._last_author_response = response
            self._test_results = f"[Comment] Author: {response[:200]}"
            base_reward = 0.001

        elif isinstance(action, AskQuestion):
            self._comments.append(f"Agent: {action.question}")

            response = self._author.respond(
                agent_question=action.question,
                test_results=self._test_results,
                lint_results=self._lint_results,
                doc_results=self._doc_results,
                proposed_fix=None,
                original_code=self._current_code  # ← FIXED
            )

            self._comments.append(f"Author: {response}")
            self._last_author_response = response
            self._test_results = f"[Question] Author: {response[:200]}"
            base_reward = 0.002

        # ==============================================================
        # FINAL FIX ACTION
        # ==============================================================
        elif isinstance(action, ProposeFix):
            if not action.fix_code:
                base_reward = -0.05
                self._done = True
            else:
                # Save original code BEFORE overwriting (for author.respond)
                original_buggy = self._current_code
                self._current_code = action.fix_code

                runner = TestRunner(self._get_test_runner_bug_id())
                test_score, test_output = runner.run_tests(self._current_code)
                lint_score = self._run_linter_score(self._current_code)
                negotiation_score = self._author.get_negotiation_score()

                self._current_test_score = test_score
                self._current_lint_score = lint_score

                # Author gating - determines if the episode ends, reward is separate
                threshold = self._author.thresholds.get(self._author.personality, 0.5)
                if self._author._confidence < threshold:
                    if self._step_count < self.max_steps:
                        self._done = False
                    else:
                        self._done = True
                else:
                    self._done = True

                # Get author's verbal feedback (pushback/acceptance)
                author_feedback = self._author.respond(
                    agent_comment=f"Proposed fix:\n{action.fix_code}",
                    test_results=f"Score: {test_score:.2f}",
                    lint_results=f"Score: {lint_score:.2f}",
                    doc_results=self._doc_results,
                    proposed_fix=action.fix_code,
                    original_code=original_buggy  # now correctly the buggy code, not the fix
                )
                self._test_results = f"[Fix] Author: {author_feedback[:200]}"
                self._comments.append(f"Author: {author_feedback}")
                self._last_author_response = author_feedback

                base_reward = 0.001  # rubrics provide the real signal

        # ==============================================================
        # TERMINATION ACTIONS
        # ==============================================================
        elif isinstance(action, Skip):
            base_reward = -0.03
            self._done = True

        elif isinstance(action, Done):
            if self._tests_run:
                base_reward = self._current_test_score * 0.5 - 0.2
            else:
                base_reward = -0.04
            self._done = True

        else:
            base_reward = -0.02
            self._done = True

        # ==============================================================
        # STEP UPDATE (before rubric computation so info contains final step)
        # ==============================================================
        self._step_count += 1
        if self._step_count >= self.max_steps:
            self._done = True

        # Get fresh observation (needed for rubrics that may read obs)
        obs = self._get_observation()

        # Prepare info dict (rubrics may need action_type and deltas)
        info = {
            "action_type": action_type,
            "test_score": self._current_test_score,
            "lint_score": self._current_lint_score,
            "test_delta": self._current_test_score - self._previous_test_score,
            "lint_delta": self._current_lint_score - self._previous_lint_score,
            "prev_tests_run": prev_tests_run,
            "prev_linter_run": prev_linter_run,
            "prev_docs_queried": prev_docs_queried,
            "docs_query_len": len((action.query_topic or "").strip()) if isinstance(action, QueryDocs) else 0,
            "base_reward": base_reward,
        }

        # ==============================================================
        # COMPUTE FINAL REWARD USING RUBRICS
        # ==============================================================
        rubric_score = sum(r(self, action, obs, None, self._done, info) for r in self.rubrics)
        final_reward = 0.4 * base_reward + rubric_score
        final_reward = max(-1.0, min(1.0, final_reward))  # safety clip

        # Track cumulative episode reward
        self._episode_total_reward += final_reward

        # Store episode total if done
        if self._done:
            self._episode_rewards.append(self._episode_total_reward)

        # Complete info
        info["final_reward"] = final_reward
        info["episode_total"] = self._episode_total_reward

        return obs, Reward(value=final_reward), self._done, info

    # ===================================================================
    def _run_linter_score(self, code: str) -> float:
        """Run pylint and return normalized score [0, 1]."""
        try:
            with tempfile.NamedTemporaryFile(mode='w', suffix='.py', delete=False) as f:
                f.write(code)
                tmp_path = f.name

            result = subprocess.run(
                ['pylint', tmp_path, '--score=y', '--exit-zero'],
                capture_output=True,
                text=True,
                timeout=5
            )

            match = re.search(r"rated at (\d+\.\d+)/10", result.stdout)
            if match:
                return float(match.group(1)) / 10.0
            return 0.0
        except:
            return 0.0
        finally:
            try:
                os.unlink(tmp_path)
            except:
                pass

    # ===================================================================
    def state(self) -> State:
        """Legacy compatibility."""
        return State(
            pr_title="Code Review",
            pr_description=self._bug_description,
            code_snippet=self._current_code,
            comments=self._comments.copy(),
            test_results=self._test_results,
            step=self._step_count,
            done=self._done
        )
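As a quick sanity check on the reward plumbing in `step()`, here is a worked example of the final combination, a sketch assuming `TestDeltaRubric` simply scales the test-score delta by its weight (the actual rubric formulas live in `rubrics.py`):

```python
# Illustrative numbers only: base reward from a RunTests action plus one rubric term.
base_reward = 0.002                     # RunTests base reward in step()
rubric_score = 0.3 * 0.5                # assumed: delta_weight * test-score delta
final_reward = max(-1.0, min(1.0, 0.4 * base_reward + rubric_score))
print(final_reward)                     # 0.1508
```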
|
|
|
| 1 |
+
# environment.py β FULLY CORRECTED RL Environment (TRUE Markov + Fixed Bugs)
|
| 2 |
+
|
| 3 |
+
import sys
|
| 4 |
+
import subprocess
|
| 5 |
+
import tempfile
|
| 6 |
+
import os
|
| 7 |
+
import re
|
| 8 |
+
from dataclasses import dataclass, field
|
| 9 |
+
from typing import Tuple, Dict, Any, Optional, List
|
| 10 |
+
|
| 11 |
+
from models import (
|
| 12 |
+
AnyAction, WriteComment, ProposeFix, Execute, Inspect,
|
| 13 |
+
RunLinter, RunTests, QueryDocs, Skip, Done, AskQuestion,
|
| 14 |
+
Observation, Reward, State
|
| 15 |
+
)
|
| 16 |
+
from redteam import RedTeam
|
| 17 |
+
from test_runner import TestRunner
|
| 18 |
+
from author import PersonaAuthor
|
| 19 |
+
from rltool import ToolBox
|
| 20 |
+
from rubrics import (
|
| 21 |
+
ToolUsageRubric,
|
| 22 |
+
TestDeltaRubric,
|
| 23 |
+
LintDeltaRubric,
|
| 24 |
+
TerminalSuccessRubric,
|
| 25 |
+
ExplorationRubric,
|
| 26 |
+
AntiHackingRubric,
|
| 27 |
+
StepPenaltyRubric,
|
| 28 |
+
)
|
| 29 |
+
|
| 30 |
+
# ======================================================================
|
| 31 |
+
# FULLY MARKOV OBSERVATION (NOTHING HIDDEN)
|
| 32 |
+
# ======================================================================
|
| 33 |
+
@dataclass
|
| 34 |
+
class EnhancedObservation:
|
| 35 |
+
code_snippet: str
|
| 36 |
+
last_tool_output: str
|
| 37 |
+
|
| 38 |
+
current_test_score: float
|
| 39 |
+
current_lint_score: float
|
| 40 |
+
negotiation_score: float
|
| 41 |
+
|
| 42 |
+
previous_test_score: float
|
| 43 |
+
previous_lint_score: float
|
| 44 |
+
|
| 45 |
+
author_confidence: float
|
| 46 |
+
author_threshold: float
|
| 47 |
+
|
| 48 |
+
step: int
|
| 49 |
+
max_steps: int
|
| 50 |
+
progress_ratio: float
|
| 51 |
+
|
| 52 |
+
tests_run: bool
|
| 53 |
+
linter_run: bool
|
| 54 |
+
docs_queried: bool
|
| 55 |
+
|
| 56 |
+
last_action_type: str
|
| 57 |
+
action_history: List[str]
|
| 58 |
+
|
| 59 |
+
done: bool
|
| 60 |
+
|
| 61 |
+
bug_description: str
|
| 62 |
+
comments_count: int
|
| 63 |
+
|
| 64 |
+
# default fields must be at the very end
|
| 65 |
+
author_response: str = ""
|
| 66 |
+
|
| 67 |
+
# ======================================================================
|
| 68 |
+
# HELPER FUNCTIONS
|
| 69 |
+
# ======================================================================
|
| 70 |
+
def execute_code(code: str, timeout_sec: int = 5) -> Tuple[bool, str, str]:
|
| 71 |
+
if not code.strip():
|
| 72 |
+
return False, "", "Error: Empty code"
|
| 73 |
+
|
| 74 |
+
with tempfile.NamedTemporaryFile(mode='w', suffix='.py', delete=False, encoding='utf-8') as f:
|
| 75 |
+
f.write(code)
|
| 76 |
+
tmp_path = f.name
|
| 77 |
+
|
| 78 |
+
try:
|
| 79 |
+
result = subprocess.run(
|
| 80 |
+
[sys.executable, tmp_path],
|
| 81 |
+
capture_output=True,
|
| 82 |
+
text=True,
|
| 83 |
+
timeout=timeout_sec
|
| 84 |
+
)
|
| 85 |
+
success = (result.returncode == 0)
|
| 86 |
+
return success, result.stdout, result.stderr
|
| 87 |
+
except subprocess.TimeoutExpired:
|
| 88 |
+
return False, "", f"Timeout after {timeout_sec}s"
|
| 89 |
+
except Exception as e:
|
| 90 |
+
return False, "", f"Execution error: {str(e)}"
|
| 91 |
+
finally:
|
| 92 |
+
try:
|
| 93 |
+
os.unlink(tmp_path)
|
| 94 |
+
except:
|
| 95 |
+
pass
|
| 96 |
+
|
| 97 |
+
|
| 98 |
+
# ======================================================================
|
| 99 |
+
# ENHANCED CODE REVIEW ENVIRONMENT
|
| 100 |
+
# ======================================================================
|
| 101 |
+
@dataclass
|
| 102 |
+
class CodeReviewEnv:
|
| 103 |
+
task: str = "easy"
|
| 104 |
+
max_steps: int = 10
|
| 105 |
+
step_penalty: float = 0.01
|
| 106 |
+
reward_profile: str = "full" # "full" or "core"
|
| 107 |
+
|
| 108 |
+
# Curriculum learning
|
| 109 |
+
auto_difficulty: bool = False
|
| 110 |
+
success_threshold: float = 0.7
|
| 111 |
+
|
| 112 |
+
# Reward shaping parameters
|
| 113 |
+
delta_weight: float = 0.3
|
| 114 |
+
tool_usage_bonus: float = 0.05
|
| 115 |
+
diversity_bonus: float = 0.03
|
| 116 |
+
|
| 117 |
+
_red_team: Optional[RedTeam] = field(init=False, default=None)
|
| 118 |
+
_author: Optional[PersonaAuthor] = field(init=False, default=None)
|
| 119 |
+
|
| 120 |
+
_current_code: str = field(init=False, default="")
|
| 121 |
+
_current_bug_id: str = field(init=False, default="")
|
| 122 |
+
_bug_description: str = field(init=False, default="")
|
| 123 |
+
_oracle_fix: str = field(init=False, default="")
|
| 124 |
+
|
| 125 |
+
_comments: list = field(init=False, default_factory=list)
|
| 126 |
+
_test_results: Optional[str] = field(init=False, default=None)
|
| 127 |
+
_lint_results: Optional[str] = field(init=False, default=None)
|
| 128 |
+
_doc_results: Optional[str] = field(init=False, default=None)
|
| 129 |
+
|
| 130 |
+
_step_count: int = field(init=False, default=0)
|
| 131 |
+
_done: bool = field(init=False, default=False)
|
| 132 |
+
|
| 133 |
+
# State tracking for dense rewards
|
| 134 |
+
_previous_test_score: float = field(init=False, default=0.0)
|
| 135 |
+
_previous_lint_score: float = field(init=False, default=0.0)
|
| 136 |
+
_current_test_score: float = field(init=False, default=0.0)
|
| 137 |
+
_current_lint_score: float = field(init=False, default=0.0)
|
| 138 |
+
|
| 139 |
+
# Tool usage tracking
|
| 140 |
+
_tests_run: bool = field(init=False, default=False)
|
| 141 |
+
_linter_run: bool = field(init=False, default=False)
|
| 142 |
+
_docs_queried: bool = field(init=False, default=False)
|
| 143 |
+
|
| 144 |
+
# Action history
|
| 145 |
+
_action_history: List[str] = field(init=False, default_factory=list)
|
| 146 |
+
_last_action_type: str = field(init=False, default="none")
|
| 147 |
+
_last_author_response: str = field(init=False, default="")
|
| 148 |
+
|
| 149 |
+
# FIXED: Track CUMULATIVE episode reward
|
| 150 |
+
_episode_total_reward: float = field(init=False, default=0.0)
|
| 151 |
+
_episode_rewards: List[float] = field(init=False, default_factory=list)
|
| 152 |
+
_difficulty_level: int = field(init=False, default=0)
|
| 153 |
+
|
| 154 |
+
# Bug-id bridge:
|
| 155 |
+
# RedTeam has fine-grained IDs, while TestRunner currently expects a
|
| 156 |
+
# smaller canonical set. Keep this mapping here so both modules can evolve
|
| 157 |
+
# independently without breaking evaluation.
|
| 158 |
+
_BUG_ID_CANONICAL_MAP = {
|
| 159 |
+
# Easy-family
|
| 160 |
+
"simple_typo": "null_check",
|
| 161 |
+
"default_value": "null_check",
|
| 162 |
+
"empty_return": "null_check",
|
| 163 |
+
"string_index": "off_by_one",
|
| 164 |
+
|
| 165 |
+
# Medium-family
|
| 166 |
+
"loop_skip": "off_by_one",
|
| 167 |
+
"sign_error": "wrong_operator",
|
| 168 |
+
"swap_args": "wrong_operator",
|
| 169 |
+
"uninitialised_var": "null_check",
|
| 170 |
+
|
| 171 |
+
# Hard-family
|
| 172 |
+
"division_by_zero_empty": "division_by_zero",
|
| 173 |
+
"division_by_zero_zero": "division_by_zero",
|
| 174 |
+
"float_precision": "division_by_zero",
|
| 175 |
+
"abs_usage": "division_by_zero",
|
| 176 |
+
"round_error": "division_by_zero",
|
| 177 |
+
}
|
| 178 |
+
|
| 179 |
+
# ===================================================================
|
| 180 |
+
def __post_init__(self):
|
| 181 |
+
self.set_task(self.task)
|
| 182 |
+
|
| 183 |
+
# ===================================================================
|
| 184 |
+
def _build_rubrics(self):
|
| 185 |
+
"""
|
| 186 |
+
Build rubric stack from a named reward profile.
|
| 187 |
+
- full: richer shaping for exploration/tool-use behavior
|
| 188 |
+
- core: minimal stable signal for quick ablations/baselines
|
| 189 |
+
"""
|
| 190 |
+
core_rubrics = [
|
| 191 |
+
TestDeltaRubric(weight=self.delta_weight),
|
| 192 |
+
LintDeltaRubric(weight=self.delta_weight),
|
| 193 |
+
TerminalSuccessRubric(),
|
| 194 |
+
StepPenaltyRubric(penalty=self.step_penalty),
|
| 195 |
+
]
|
| 196 |
+
if self.reward_profile == "core":
|
| 197 |
+
return core_rubrics
|
| 198 |
+
if self.reward_profile == "full":
|
| 199 |
+
return [
|
| 200 |
+
*core_rubrics[:-1], # step penalty appended at end for consistent ordering
|
| 201 |
+
ToolUsageRubric(bonus=self.tool_usage_bonus),
|
| 202 |
+
ExplorationRubric(penalty=-0.05, bonus=self.diversity_bonus * 0.7),
|
| 203 |
+
AntiHackingRubric(),
|
| 204 |
+
core_rubrics[-1],
|
| 205 |
+
]
|
| 206 |
+
raise ValueError(f"Unknown reward_profile: {self.reward_profile}")
|
| 207 |
+
|
| 208 |
+
# ===================================================================
|
| 209 |
+
def set_task(self, task: str):
|
| 210 |
+
if task not in ["easy", "medium", "hard", "harder", "hardest"]:
|
| 211 |
+
raise ValueError(f"Unknown task: {task}")
|
| 212 |
+
|
| 213 |
+
self.task = task
|
| 214 |
+
# Use stochastic bug sampling across episodes; fixed seed here would
|
| 215 |
+
# repeatedly select the same bug and weaken training diversity.
|
| 216 |
+
self._red_team = RedTeam(task, seed=None)
|
| 217 |
+
self._author = PersonaAuthor()
|
| 218 |
+
self.rubrics = self._build_rubrics()
|
| 219 |
+
|
| 220 |
+
task_to_level = {
|
| 221 |
+
"easy": 0, "medium": 1, "hard": 2,
|
| 222 |
+
"harder": 3, "hardest": 4
|
| 223 |
+
}
|
| 224 |
+
self._difficulty_level = task_to_level[task]
|
| 225 |
+
|
| 226 |
+
self._reset_internal()
|
| 227 |
+
|
| 228 |
+
# ===================================================================
|
| 229 |
+
def _reset_internal(self):
|
| 230 |
+
self._step_count = 0  # FIXED
|
| 231 |
+
self._comments = []
|
| 232 |
+
self._test_results = None
|
| 233 |
+
self._lint_results = None
|
| 234 |
+
self._doc_results = None
|
| 235 |
+
self._done = False
|
| 236 |
+
|
| 237 |
+
# Reset state tracking
|
| 238 |
+
self._previous_test_score = 0.0
|
| 239 |
+
self._previous_lint_score = 0.0
|
| 240 |
+
self._current_test_score = 0.0
|
| 241 |
+
self._current_lint_score = 0.0
|
| 242 |
+
|
| 243 |
+
self._tests_run = False
|
| 244 |
+
self._linter_run = False
|
| 245 |
+
self._docs_queried = False
|
| 246 |
+
|
| 247 |
+
self._action_history = []
|
| 248 |
+
self._last_action_type = "none"
|
| 249 |
+
self._last_author_response = ""
|
| 250 |
+
|
| 251 |
+
# FIXED: Reset episode cumulative reward
|
| 252 |
+
self._episode_total_reward = 0.0
|
| 253 |
+
|
| 254 |
+
self._author.reset()
|
| 255 |
+
|
| 256 |
+
# Base tasks
|
| 257 |
+
if self.task == "easy":
|
| 258 |
+
original = "def get_user(id):\n if id in users:\n return users[id]"
|
| 259 |
+
elif self.task == "medium":
|
| 260 |
+
original = "def process_items(items):\n for item in items:\n print(item)"
|
| 261 |
+
elif self.task == "hard":
|
| 262 |
+
original = "def average(data):\n if not data:\n return 0\n return sum(data) / len(data)"
|
| 263 |
+
elif self.task == "harder":
|
| 264 |
+
original = "counter = 0\ndef increment():\n global counter\n with lock:\n counter += 1"
|
| 265 |
+
else:
|
| 266 |
+
original = "def safe_work():\n with lock1:\n with lock2:\n do_work()"
|
| 267 |
+
|
| 268 |
+
buggy_code, bug_id, desc, oracle = self._red_team.inject_bug(original)
|
| 269 |
+
self._current_code = buggy_code
|
| 270 |
+
self._current_bug_id = bug_id
|
| 271 |
+
self._bug_description = desc
|
| 272 |
+
self._oracle_fix = oracle
|
| 273 |
+
self._comments.append(f"[RedTeam] {desc}")
|
| 274 |
+
|
| 275 |
+
# ===================================================================
|
| 276 |
+
def reset(self) -> EnhancedObservation:
|
| 277 |
+
"""Reset with optional curriculum adjustment."""
|
| 278 |
+
if self.auto_difficulty and len(self._episode_rewards) > 0:
|
| 279 |
+
recent_performance = sum(self._episode_rewards[-5:]) / min(5, len(self._episode_rewards))
|
| 280 |
+
|
| 281 |
+
if recent_performance > self.success_threshold and self._difficulty_level < 4:
|
| 282 |
+
self._difficulty_level += 1
|
| 283 |
+
print(f"[Curriculum] Increasing difficulty to level {self._difficulty_level}")
|
| 284 |
+
elif recent_performance < 0.3 and self._difficulty_level > 0:
|
| 285 |
+
self._difficulty_level -= 1
|
| 286 |
+
print(f"[Curriculum] Decreasing difficulty to level {self._difficulty_level}")
|
| 287 |
+
|
| 288 |
+
level_to_task = {0: "easy", 1: "medium", 2: "hard", 3: "harder", 4: "hardest"}
|
| 289 |
+
self.task = level_to_task[self._difficulty_level]
|
| 290 |
+
# Keep curriculum stochastic for better coverage within each level.
|
| 291 |
+
self._red_team = RedTeam(self.task, seed=None)
|
| 292 |
+
|
| 293 |
+
self._reset_internal()
|
| 294 |
+
return self._get_observation()
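The curriculum update above is a simple moving-average rule over the last five episode returns. A standalone sketch of that rule follows; the 0.3 lower bound comes from the code above, while the success threshold is an environment attribute configured elsewhere, so the 0.7 used here is only an assumed example value:

```python
# Sketch of the auto-difficulty rule used in reset(); threshold value is illustrative.
def adjust_level(episode_rewards, level, success_threshold=0.7, max_level=4):
    if not episode_rewards:
        return level
    recent = episode_rewards[-5:]
    recent_performance = sum(recent) / len(recent)
    if recent_performance > success_threshold and level < max_level:
        return level + 1          # doing well -> harder task
    if recent_performance < 0.3 and level > 0:
        return level - 1          # struggling -> easier task
    return level

level_to_task = {0: "easy", 1: "medium", 2: "hard", 3: "harder", 4: "hardest"}
print(level_to_task[adjust_level([0.9, 0.8, 0.85], level=1)])  # -> "hard"
```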
|
| 295 |
+
|
| 296 |
+
# ===================================================================
|
| 297 |
+
def _get_observation(self) -> EnhancedObservation:
|
| 298 |
+
"""Return COMPLETE Markov state."""
|
| 299 |
+
# Keep the author's message separate from tool output.
|
| 300 |
+
# Using `_test_results` here can leak unrelated outputs (tests/linter/docs)
|
| 301 |
+
# and gives the policy a noisy signal for dialogue actions.
|
| 302 |
+
if self._last_action_type in ("comment", "question", "fix"):
|
| 303 |
+
author_response = self._last_author_response
|
| 304 |
+
else:
|
| 305 |
+
author_response = ""
|
| 306 |
+
|
| 307 |
+
return EnhancedObservation(
|
| 308 |
+
code_snippet=self._current_code,
|
| 309 |
+
last_tool_output=self._test_results or "",
|
| 310 |
+
author_response=author_response,  # now field exists
|
| 311 |
+
|
| 312 |
+
current_test_score=self._current_test_score,
|
| 313 |
+
current_lint_score=self._current_lint_score,
|
| 314 |
+
negotiation_score=self._author.get_negotiation_score(),
|
| 315 |
+
|
| 316 |
+
previous_test_score=self._previous_test_score,
|
| 317 |
+
previous_lint_score=self._previous_lint_score,
|
| 318 |
+
|
| 319 |
+
author_confidence=self._author._confidence,
|
| 320 |
+
author_threshold=self._author.thresholds.get(self._author.personality, 0.5),
|
| 321 |
+
|
| 322 |
+
step=self._step_count,
|
| 323 |
+
max_steps=self.max_steps,
|
| 324 |
+
# Guard against accidental `max_steps=0` configs.
|
| 325 |
+
progress_ratio=(self._step_count / self.max_steps) if self.max_steps > 0 else 1.0,
|
| 326 |
+
|
| 327 |
+
tests_run=self._tests_run,
|
| 328 |
+
linter_run=self._linter_run,
|
| 329 |
+
docs_queried=self._docs_queried,
|
| 330 |
+
|
| 331 |
+
last_action_type=self._last_action_type,
|
| 332 |
+
action_history=self._action_history[-5:],
|
| 333 |
+
|
| 334 |
+
done=self._done,
|
| 335 |
+
|
| 336 |
+
bug_description=self._bug_description,
|
| 337 |
+
comments_count=len(self._comments),
|
| 338 |
+
)
|
| 339 |
+
|
| 340 |
+
# ===================================================================
|
| 341 |
+
def _get_action_type(self, action: AnyAction) -> str:
|
| 342 |
+
"""Extract action type as string."""
|
| 343 |
+
if isinstance(action, RunTests):
|
| 344 |
+
return "run_tests"
|
| 345 |
+
elif isinstance(action, RunLinter):
|
| 346 |
+
return "run_linter"
|
| 347 |
+
elif isinstance(action, QueryDocs):
|
| 348 |
+
return "query_docs"
|
| 349 |
+
elif isinstance(action, Execute):
|
| 350 |
+
return "execute"
|
| 351 |
+
elif isinstance(action, Inspect):
|
| 352 |
+
return "inspect"
|
| 353 |
+
elif isinstance(action, WriteComment):
|
| 354 |
+
return "comment"
|
| 355 |
+
elif isinstance(action, AskQuestion):
|
| 356 |
+
return "question"
|
| 357 |
+
elif isinstance(action, ProposeFix):
|
| 358 |
+
return "fix"
|
| 359 |
+
elif isinstance(action, Done):
|
| 360 |
+
return "done"
|
| 361 |
+
elif isinstance(action, Skip):
|
| 362 |
+
return "skip"
|
| 363 |
+
else:
|
| 364 |
+
return "unknown"
|
| 365 |
+
|
| 366 |
+
# ===================================================================
|
| 367 |
+
def _get_test_runner_bug_id(self) -> str:
|
| 368 |
+
"""
|
| 369 |
+
Normalize RedTeam bug ids to the canonical ids understood by TestRunner.
|
| 370 |
+
Falls back to the original id for known direct matches.
|
| 371 |
+
"""
|
| 372 |
+
return self._BUG_ID_CANONICAL_MAP.get(self._current_bug_id, self._current_bug_id)
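The lookup above is a plain `dict.get` with the original id as the default, so ids that are already canonical pass through untouched. A tiny illustration using two ids taken from the map above:

```python
# Behavior of the canonical-map lookup: mapped ids are rewritten,
# already-canonical ids fall through unchanged.
BUG_ID_CANONICAL_MAP = {"simple_typo": "null_check", "loop_skip": "off_by_one"}

for bug_id in ("simple_typo", "division_by_zero"):
    print(bug_id, "->", BUG_ID_CANONICAL_MAP.get(bug_id, bug_id))
# simple_typo -> null_check
# division_by_zero -> division_by_zero
```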
|
| 373 |
+
|
| 374 |
+
# ===================================================================
|
| 375 |
+
def step(self, action: AnyAction) -> Tuple[EnhancedObservation, Reward, bool, Dict[str, Any]]:
|
| 376 |
+
"""
|
| 377 |
+
TRUE RL STEP with:
|
| 378 |
+
- Complete Markov observations (no hidden state)
|
| 379 |
+
- Dense intermediate rewards
|
| 380 |
+
- Delta-based credit assignment (no double-counting)
|
| 381 |
+
- Proper episode reward tracking
|
| 382 |
+
"""
|
| 383 |
+
if self._done:
|
| 384 |
+
raise RuntimeError("Episode already finished")
|
| 385 |
+
|
| 386 |
+
# Store previous metrics for delta computation
|
| 387 |
+
self._previous_test_score = self._current_test_score
|
| 388 |
+
self._previous_lint_score = self._current_lint_score
|
| 389 |
+
# Snapshot tool-usage flags BEFORE action mutates them.
|
| 390 |
+
# Rubrics use these to detect true "first-use" behavior.
|
| 391 |
+
prev_tests_run = self._tests_run
|
| 392 |
+
prev_linter_run = self._linter_run
|
| 393 |
+
prev_docs_queried = self._docs_queried
|
| 394 |
+
|
| 395 |
+
base_reward = 0.0
|
| 396 |
+
action_type = self._get_action_type(action)
|
| 397 |
+
|
| 398 |
+
# Update action history
|
| 399 |
+
self._action_history.append(action_type)
|
| 400 |
+
self._last_action_type = action_type
|
| 401 |
+
|
| 402 |
+
# ==============================================================
|
| 403 |
+
# TOOL ACTIONS
|
| 404 |
+
# ==============================================================
|
| 405 |
+
if isinstance(action, Execute):
|
| 406 |
+
success, stdout, stderr = execute_code(self._current_code)
|
| 407 |
+
output = (stdout + stderr).strip() or "No output"
|
| 408 |
+
self._test_results = f"[Execute] {'Success' if success else 'Failed'}\n{output[:300]}"
|
| 409 |
+
base_reward = 0.001 if success else -0.05
|
| 410 |
+
|
| 411 |
+
elif isinstance(action, Inspect):
|
| 412 |
+
self._test_results = f"[Inspect]\n{self._current_code[:500]}"
|
| 413 |
+
base_reward = 0.001
|
| 414 |
+
|
| 415 |
+
elif isinstance(action, RunLinter):
|
| 416 |
+
lint_output = ToolBox.run_linter(self._current_code)
|
| 417 |
+
self._lint_results = lint_output[:500]
|
| 418 |
+
self._test_results = f"[Linter]\n{self._lint_results}"
|
| 419 |
+
|
| 420 |
+
self._current_lint_score = self._run_linter_score(self._current_code)
|
| 421 |
+
self._linter_run = True
|
| 422 |
+
base_reward = 0.002
|
| 423 |
+
|
| 424 |
+
elif isinstance(action, RunTests):
|
| 425 |
+
runner = TestRunner(self._get_test_runner_bug_id())
|
| 426 |
+
score, output = runner.run_tests(self._current_code)
|
| 427 |
+
|
| 428 |
+
self._current_test_score = score
|
| 429 |
+
self._tests_run = True
|
| 430 |
+
|
| 431 |
+
self._test_results = f"[Tests] Score: {score:.2f}\n{output[:300]}"
|
| 432 |
+
base_reward = 0.002
|
| 433 |
+
|
| 434 |
+
if score > 0.8:
|
| 435 |
+
base_reward += 0.005
|
| 436 |
+
|
| 437 |
+
elif isinstance(action, QueryDocs):
|
| 438 |
+
# Normalize query to avoid rewarding empty/noisy requests.
|
| 439 |
+
query_topic = (action.query_topic or "").strip()
|
| 440 |
+
doc = ToolBox.query_docs(query_topic if query_topic else "general bug fixing")
|
| 441 |
+
self._doc_results = doc
|
| 442 |
+
self._test_results = f"[Docs]\n{doc[:400]}"
|
| 443 |
+
self._docs_queried = True
|
| 444 |
+
base_reward = 0.001
|
| 445 |
+
|
| 446 |
+
# ==============================================================
|
| 447 |
+
# COMMUNICATION ACTIONS
|
| 448 |
+
# ==============================================================
|
| 449 |
+
elif isinstance(action, WriteComment):
|
| 450 |
+
self._comments.append(f"Agent: {action.comment_text}")
|
| 451 |
+
|
| 452 |
+
response = self._author.respond(
|
| 453 |
+
agent_comment=action.comment_text,
|
| 454 |
+
test_results=self._test_results,
|
| 455 |
+
lint_results=self._lint_results,
|
| 456 |
+
doc_results=self._doc_results,
|
| 457 |
+
proposed_fix=None,
|
| 458 |
+
original_code=self._current_code
|
| 459 |
+
)
|
| 460 |
+
|
| 461 |
+
self._comments.append(f"Author: {response}")
|
| 462 |
+
self._last_author_response = response
|
| 463 |
+
self._test_results = f"[Comment] Author: {response[:200]}"
|
| 464 |
+
base_reward = 0.001
|
| 465 |
+
|
| 466 |
+
elif isinstance(action, AskQuestion):
|
| 467 |
+
self._comments.append(f"Agent: {action.question}")
|
| 468 |
+
|
| 469 |
+
response = self._author.respond(
|
| 470 |
+
agent_question=action.question,
|
| 471 |
+
test_results=self._test_results,
|
| 472 |
+
lint_results=self._lint_results,
|
| 473 |
+
doc_results=self._doc_results,
|
| 474 |
+
proposed_fix=None,
|
| 475 |
+
original_code=self._current_code  # FIXED
|
| 476 |
+
)
|
| 477 |
+
|
| 478 |
+
self._comments.append(f"Author: {response}")
|
| 479 |
+
self._last_author_response = response
|
| 480 |
+
self._test_results = f"[Question] Author: {response[:200]}"
|
| 481 |
+
base_reward = 0.002
|
| 482 |
+
|
| 483 |
+
# ==============================================================
|
| 484 |
+
# FINAL FIX ACTION
|
| 485 |
+
# ==============================================================
|
| 486 |
+
elif isinstance(action, ProposeFix):
|
| 487 |
+
if not action.fix_code:
|
| 488 |
+
base_reward = -0.05
|
| 489 |
+
self._done = True
|
| 490 |
+
else:
|
| 491 |
+
# Save original code BEFORE overwriting (for author.respond)
|
| 492 |
+
original_buggy = self._current_code
|
| 493 |
+
self._current_code = action.fix_code
|
| 494 |
+
|
| 495 |
+
runner = TestRunner(self._get_test_runner_bug_id())
|
| 496 |
+
test_score, test_output = runner.run_tests(self._current_code)
|
| 497 |
+
lint_score = self._run_linter_score(self._current_code)
|
| 498 |
+
negotiation_score = self._author.get_negotiation_score()
|
| 499 |
+
|
| 500 |
+
self._current_test_score = test_score
|
| 501 |
+
self._current_lint_score = lint_score
|
| 502 |
+
|
| 503 |
+
# Author gating - determines if the episode ends; reward is separate
|
| 504 |
+
threshold = self._author.thresholds.get(self._author.personality, 0.5)
|
| 505 |
+
if self._author._confidence < threshold:
|
| 506 |
+
if self._step_count < self.max_steps:
|
| 507 |
+
self._done = False
|
| 508 |
+
else:
|
| 509 |
+
self._done = True
|
| 510 |
+
else:
|
| 511 |
+
self._done = True
|
| 512 |
+
|
| 513 |
+
# Get author's verbal feedback (pushback/acceptance)
|
| 514 |
+
author_feedback = self._author.respond(
|
| 515 |
+
agent_comment=f"Proposed fix:\n{action.fix_code}",
|
| 516 |
+
test_results=f"Score: {test_score:.2f}",
|
| 517 |
+
lint_results=f"Score: {lint_score:.2f}",
|
| 518 |
+
doc_results=self._doc_results,
|
| 519 |
+
proposed_fix=action.fix_code,
|
| 520 |
+
original_code=original_buggy # now correctly the buggy code, not the fix
|
| 521 |
+
)
|
| 522 |
+
self._test_results = f"[Fix] Author: {author_feedback[:200]}"
|
| 523 |
+
self._comments.append(f"Author: {author_feedback}")
|
| 524 |
+
self._last_author_response = author_feedback
|
| 525 |
+
|
| 526 |
+
base_reward = 0.001 # rubrics provide the real signal
|
| 527 |
+
|
| 528 |
+
# ==============================================================
|
| 529 |
+
# TERMINATION ACTIONS
|
| 530 |
+
# ==============================================================
|
| 531 |
+
elif isinstance(action, Skip):
|
| 532 |
+
base_reward = -0.03
|
| 533 |
+
self._done = True
|
| 534 |
+
|
| 535 |
+
elif isinstance(action, Done):
|
| 536 |
+
if self._tests_run:
|
| 537 |
+
base_reward = self._current_test_score * 0.5 - 0.2
|
| 538 |
+
else:
|
| 539 |
+
base_reward = -0.04
|
| 540 |
+
self._done = True
|
| 541 |
+
|
| 542 |
+
else:
|
| 543 |
+
base_reward = -0.02
|
| 544 |
+
self._done = True
|
| 545 |
+
|
| 546 |
+
# ==============================================================
|
| 547 |
+
# STEP UPDATE (before rubric computation so info contains final step)
|
| 548 |
+
# ==============================================================
|
| 549 |
+
self._step_count += 1
|
| 550 |
+
if self._step_count >= self.max_steps:
|
| 551 |
+
self._done = True
|
| 552 |
+
|
| 553 |
+
# Get fresh observation (needed for rubrics that may read obs)
|
| 554 |
+
obs = self._get_observation()
|
| 555 |
+
|
| 556 |
+
# Prepare info dict (rubrics may need action_type and deltas)
|
| 557 |
+
info = {
|
| 558 |
+
"action_type": action_type,
|
| 559 |
+
"test_score": self._current_test_score,
|
| 560 |
+
"lint_score": self._current_lint_score,
|
| 561 |
+
"test_delta": self._current_test_score - self._previous_test_score,
|
| 562 |
+
"lint_delta": self._current_lint_score - self._previous_lint_score,
|
| 563 |
+
"prev_tests_run": prev_tests_run,
|
| 564 |
+
"prev_linter_run": prev_linter_run,
|
| 565 |
+
"prev_docs_queried": prev_docs_queried,
|
| 566 |
+
"docs_query_len": len((action.query_topic or "").strip()) if isinstance(action, QueryDocs) else 0,
|
| 567 |
+
"base_reward": base_reward,
|
| 568 |
+
}
|
| 569 |
+
|
| 570 |
+
# ==============================================================
|
| 571 |
+
# COMPUTE FINAL REWARD USING RUBRICS
|
| 572 |
+
# ==============================================================
|
| 573 |
+
rubric_score = sum(r(self, action, obs, None, self._done, info) for r in self.rubrics)
|
| 574 |
+
final_reward = 0.4 * base_reward + rubric_score
|
| 575 |
+
final_reward = max(-1.0, min(1.0, final_reward)) # safety clip
|
| 576 |
+
|
| 577 |
+
# Track cumulative episode reward
|
| 578 |
+
self._episode_total_reward += final_reward
|
| 579 |
+
|
| 580 |
+
# Store episode total if done
|
| 581 |
+
if self._done:
|
| 582 |
+
self._episode_rewards.append(self._episode_total_reward)
|
| 583 |
+
|
| 584 |
+
# Complete info
|
| 585 |
+
info["final_reward"] = final_reward
|
| 586 |
+
info["episode_total"] = self._episode_total_reward
|
| 587 |
+
|
| 588 |
+
return obs, Reward(value=final_reward), self._done, info
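The final per-step reward above is the base reward scaled by 0.4 plus the summed rubric score, clipped to [-1, 1]. A small worked example of that arithmetic, with made-up inputs:

```python
# Worked example of the step() reward combination; inputs are illustrative.
def combine_reward(base_reward, rubric_scores, base_scale=0.4, low=-1.0, high=1.0):
    final = base_scale * base_reward + sum(rubric_scores)
    return max(low, min(high, final))  # safety clip, as in step()

print(combine_reward(0.002, [0.06, 0.015, -0.01]))   # -> 0.0658
print(combine_reward(0.001, [0.4, 0.12, 0.9]))       # -> 1.0 (clipped)
```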
|
| 589 |
+
|
| 590 |
+
# ===================================================================
|
| 591 |
+
def _run_linter_score(self, code: str) -> float:
|
| 592 |
+
"""Run pylint and return normalized score [0, 1]."""
|
| 593 |
+
try:
|
| 594 |
+
with tempfile.NamedTemporaryFile(mode='w', suffix='.py', delete=False) as f:
|
| 595 |
+
f.write(code)
|
| 596 |
+
tmp_path = f.name
|
| 597 |
+
|
| 598 |
result = subprocess.run(
|
| 599 |
+
[sys.executable, '-m', 'pylint', tmp_path, '--score=y', '--exit-zero'],
|
| 600 |
capture_output=True,
|
| 601 |
text=True,
|
| 602 |
timeout=5
|
| 603 |
+
)
|
| 604 |
+
|
| 605 |
+
match = re.search(r"rated at (\d+\.\d+)/10", result.stdout)
|
| 606 |
+
if match:
|
| 607 |
+
return float(match.group(1)) / 10.0
|
| 608 |
+
return 0.0
|
| 609 |
+
except Exception:
|
| 610 |
+
return 0.0
|
| 611 |
+
finally:
|
| 612 |
+
try:
|
| 613 |
+
os.unlink(tmp_path)
|
| 614 |
+
except OSError:
|
| 615 |
+
pass
|
| 616 |
+
|
| 617 |
+
# ===================================================================
|
| 618 |
+
def state(self) -> State:
|
| 619 |
+
"""Legacy compatibility."""
|
| 620 |
+
return State(
|
| 621 |
+
pr_title="Code Review",
|
| 622 |
+
pr_description=self._bug_description,
|
| 623 |
+
code_snippet=self._current_code,
|
| 624 |
+
comments=self._comments.copy(),
|
| 625 |
+
test_results=self._test_results,
|
| 626 |
+
step=self._step_count,
|
| 627 |
+
done=self._done
|
| 628 |
+
)
|
grader.py
CHANGED
|
@@ -1,11 +1,12 @@
|
|
| 1 |
# grader.py - Production‑grade, continuous reward, exploit‑aware, example of monolithic scoring
|
| 2 |
import ast
|
| 3 |
-
import subprocess
|
| 4 |
-
import tempfile
|
| 5 |
-
import os
|
| 6 |
-
import re
|
| 7 |
-
|
| 8 |
-
from
|
|
|
|
| 9 |
|
| 10 |
@dataclass
|
| 11 |
class RigorousGrader:
|
|
@@ -105,11 +106,11 @@ class RigorousGrader:
|
|
| 105 |
f.write(code)
|
| 106 |
f.flush()
|
| 107 |
tmp_path = f.name
|
| 108 |
-
result = subprocess.run(
|
| 109 |
-
['pylint', tmp_path, '--score=y', '--exit-zero'],
|
| 110 |
-
capture_output=True,
|
| 111 |
-
text=True,
|
| 112 |
-
timeout=5
|
| 113 |
)
|
| 114 |
match = re.search(r"rated at (\d+\.\d+)/10", result.stdout)
|
| 115 |
if match:
|
|
@@ -139,4 +140,4 @@ class RigorousGrader:
|
|
| 139 |
total = max(len(nodes_prop), len(nodes_oracle))
|
| 140 |
return common / total if total > 0 else 0.0
|
| 141 |
except:
|
| 142 |
-
return 0.0
|
|
|
|
| 1 |
# grader.py - Production‑grade, continuous reward, exploit‑aware, example of monolithic scoring
|
| 2 |
import ast
|
| 3 |
+
import subprocess
|
| 4 |
+
import tempfile
|
| 5 |
+
import os
|
| 6 |
+
import re
|
| 7 |
+
import sys
|
| 8 |
+
from dataclasses import dataclass
|
| 9 |
+
from typing import Optional
|
| 10 |
|
| 11 |
@dataclass
|
| 12 |
class RigorousGrader:
|
|
|
|
| 106 |
f.write(code)
|
| 107 |
f.flush()
|
| 108 |
tmp_path = f.name
|
| 109 |
+
result = subprocess.run(
|
| 110 |
+
[sys.executable, '-m', 'pylint', tmp_path, '--score=y', '--exit-zero'],
|
| 111 |
+
capture_output=True,
|
| 112 |
+
text=True,
|
| 113 |
+
timeout=5
|
| 114 |
)
|
| 115 |
match = re.search(r"rated at (\d+\.\d+)/10", result.stdout)
|
| 116 |
if match:
|
|
|
|
| 140 |
total = max(len(nodes_prop), len(nodes_oracle))
|
| 141 |
return common / total if total > 0 else 0.0
|
| 142 |
except:
|
| 143 |
+
return 0.0
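The functional change in this hunk is invoking pylint as `sys.executable -m pylint` instead of relying on a bare `pylint` entry point, so the linter always runs in the same interpreter and virtual environment as the server. A standalone sketch of that invocation plus the score-parsing regex used above (pylint must be installed; the helper name is illustrative):

```python
# Sketch: run pylint through the current interpreter and extract its 0-10 rating.
import os
import re
import subprocess
import sys
import tempfile

def pylint_score(code: str) -> float:
    with tempfile.NamedTemporaryFile(mode="w", suffix=".py", delete=False) as f:
        f.write(code)
        tmp_path = f.name
    try:
        result = subprocess.run(
            [sys.executable, "-m", "pylint", tmp_path, "--score=y", "--exit-zero"],
            capture_output=True, text=True, timeout=30,
        )
        # Note: a negative rating (badly broken code) does not match and yields 0.0.
        match = re.search(r"rated at (\d+\.\d+)/10", result.stdout)
        return float(match.group(1)) / 10.0 if match else 0.0
    finally:
        os.unlink(tmp_path)

print(pylint_score("X = 1\n"))
```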
|
models.py
CHANGED
|
@@ -1,112 +1,112 @@
|
|
| 1 |
-
# models.py - Typed Models (Discriminated Unions, POMDP Separation)
|
| 2 |
-
from typing import Literal, Union, Annotated, Optional
|
| 3 |
-
from pydantic import BaseModel, Field, TypeAdapter, field_validator
|
| 4 |
-
|
| 5 |
-
# ----------------------------------------------------------------------
|
| 6 |
-
# Action classes (discriminated union)
|
| 7 |
-
# ----------------------------------------------------------------------
|
| 8 |
-
class Action(BaseModel):
|
| 9 |
-
action_type: Literal["comment", "skip", "done", "question",
|
| 10 |
-
"fix", "execute", "inspect", "run_linter",
|
| 11 |
-
"run_tests", "query_docs"]
|
| 12 |
-
|
| 13 |
-
class WriteComment(Action):
|
| 14 |
-
action_type: Literal["comment"] = "comment"
|
| 15 |
-
comment_text: str = Field(..., min_length=1)
|
| 16 |
-
|
| 17 |
-
class Skip(Action):
|
| 18 |
-
action_type: Literal["skip"] = "skip"
|
| 19 |
-
|
| 20 |
-
class Done(Action):
|
| 21 |
-
action_type: Literal["done"] = "done"
|
| 22 |
-
|
| 23 |
-
class AskQuestion(Action):
|
| 24 |
-
action_type: Literal["question"] = "question"
|
| 25 |
-
question: str = Field(..., min_length=1)
|
| 26 |
-
|
| 27 |
-
class ProposeFix(Action):
|
| 28 |
-
action_type: Literal["fix"] = "fix"
|
| 29 |
-
fix_code: str = Field(..., min_length=1)
|
| 30 |
-
@field_validator('fix_code')
|
| 31 |
-
@classmethod
|
| 32 |
-
def not_empty(cls, v: str) -> str:
|
| 33 |
-
if not v.strip():
|
| 34 |
-
raise ValueError('fix_code cannot be empty')
|
| 35 |
-
return v
|
| 36 |
-
|
| 37 |
-
class Execute(Action):
|
| 38 |
-
action_type: Literal["execute"] = "execute"
|
| 39 |
-
|
| 40 |
-
class Inspect(Action):
|
| 41 |
-
action_type: Literal["inspect"] = "inspect"
|
| 42 |
-
|
| 43 |
-
class RunLinter(Action):
|
| 44 |
-
action_type: Literal["run_linter"] = "run_linter"
|
| 45 |
-
|
| 46 |
-
class RunTests(Action):
|
| 47 |
-
action_type: Literal["run_tests"] = "run_tests"
|
| 48 |
-
|
| 49 |
-
class QueryDocs(Action):
|
| 50 |
-
action_type: Literal["query_docs"] = "query_docs"
|
| 51 |
-
query_topic: str = Field(..., min_length=1)
|
| 52 |
-
|
| 53 |
-
# Discriminated union for oneβline polymorphic deserialization
|
| 54 |
-
AnyAction = Annotated[
|
| 55 |
-
Union[WriteComment, Skip, Done, AskQuestion, ProposeFix,
|
| 56 |
-
Execute, Inspect, RunLinter, RunTests, QueryDocs],
|
| 57 |
-
Field(discriminator='action_type')
|
| 58 |
-
]
|
| 59 |
-
action_adapter = TypeAdapter(AnyAction)
|
| 60 |
-
|
| 61 |
-
|
| 62 |
-
def map_to_env(action_type: str, content: Optional[str] = None) -> AnyAction:
|
| 63 |
-
"""
|
| 64 |
-
Convert lightweight agent outputs into typed environment actions.
|
| 65 |
-
Kept at module level so training/inference code can reuse one mapping.
|
| 66 |
-
"""
|
| 67 |
-
if action_type == "run_tests":
|
| 68 |
-
return RunTests()
|
| 69 |
-
if action_type == "run_linter":
|
| 70 |
-
return RunLinter()
|
| 71 |
-
if action_type == "inspect":
|
| 72 |
-
return Inspect()
|
| 73 |
-
if action_type == "fix":
|
| 74 |
-
return ProposeFix(fix_code=content or "")
|
| 75 |
-
if action_type == "comment":
|
| 76 |
-
return WriteComment(comment_text=content or "")
|
| 77 |
-
if action_type == "question":
|
| 78 |
-
return AskQuestion(question=content or "")
|
| 79 |
-
if action_type == "query_docs":
|
| 80 |
-
return QueryDocs(query_topic=content or "")
|
| 81 |
-
if action_type == "done":
|
| 82 |
-
return Done()
|
| 83 |
-
return Skip()
|
| 84 |
-
|
| 85 |
-
# ----------------------------------------------------------------------
|
| 86 |
-
# Observation (POMDP - what the agent sees)
|
| 87 |
-
# ----------------------------------------------------------------------
|
| 88 |
-
class Observation(BaseModel):
|
| 89 |
-
# Base schema model used by API metadata endpoints.
|
| 90 |
-
# Keep this lightweight for compatibility with legacy callers.
|
| 91 |
-
code_snippet: str
|
| 92 |
-
last_tool_output: str = ""
|
| 93 |
-
step: int = 0
|
| 94 |
-
done: bool = False
|
| 95 |
-
|
| 96 |
-
# ----------------------------------------------------------------------
|
| 97 |
-
# Reward (lightweight)
|
| 98 |
-
# ----------------------------------------------------------------------
|
| 99 |
-
class Reward(BaseModel):
|
| 100 |
-
value: float
|
| 101 |
-
|
| 102 |
-
# ----------------------------------------------------------------------
|
| 103 |
-
# State (full environment state - not exposed to agent)
|
| 104 |
-
# ----------------------------------------------------------------------
|
| 105 |
-
class State(BaseModel):
|
| 106 |
-
pr_title: str
|
| 107 |
-
pr_description: str
|
| 108 |
-
code_snippet: str
|
| 109 |
-
comments: list[str]
|
| 110 |
-
test_results: Optional[str]
|
| 111 |
-
step: int
|
| 112 |
done: bool
|
|
|
|
| 1 |
+
# models.py - Typed Models (Discriminated Unions, POMDP Separation)
|
| 2 |
+
from typing import Literal, Union, Annotated, Optional
|
| 3 |
+
from pydantic import BaseModel, Field, TypeAdapter, field_validator
|
| 4 |
+
|
| 5 |
+
# ----------------------------------------------------------------------
|
| 6 |
+
# Action classes (discriminated union)
|
| 7 |
+
# ----------------------------------------------------------------------
|
| 8 |
+
class Action(BaseModel):
|
| 9 |
+
action_type: Literal["comment", "skip", "done", "question",
|
| 10 |
+
"fix", "execute", "inspect", "run_linter",
|
| 11 |
+
"run_tests", "query_docs"]
|
| 12 |
+
|
| 13 |
+
class WriteComment(Action):
|
| 14 |
+
action_type: Literal["comment"] = "comment"
|
| 15 |
+
comment_text: str = Field(..., min_length=1)
|
| 16 |
+
|
| 17 |
+
class Skip(Action):
|
| 18 |
+
action_type: Literal["skip"] = "skip"
|
| 19 |
+
|
| 20 |
+
class Done(Action):
|
| 21 |
+
action_type: Literal["done"] = "done"
|
| 22 |
+
|
| 23 |
+
class AskQuestion(Action):
|
| 24 |
+
action_type: Literal["question"] = "question"
|
| 25 |
+
question: str = Field(..., min_length=1)
|
| 26 |
+
|
| 27 |
+
class ProposeFix(Action):
|
| 28 |
+
action_type: Literal["fix"] = "fix"
|
| 29 |
+
fix_code: str = Field(..., min_length=1)
|
| 30 |
+
@field_validator('fix_code')
|
| 31 |
+
@classmethod
|
| 32 |
+
def not_empty(cls, v: str) -> str:
|
| 33 |
+
if not v.strip():
|
| 34 |
+
raise ValueError('fix_code cannot be empty')
|
| 35 |
+
return v
|
| 36 |
+
|
| 37 |
+
class Execute(Action):
|
| 38 |
+
action_type: Literal["execute"] = "execute"
|
| 39 |
+
|
| 40 |
+
class Inspect(Action):
|
| 41 |
+
action_type: Literal["inspect"] = "inspect"
|
| 42 |
+
|
| 43 |
+
class RunLinter(Action):
|
| 44 |
+
action_type: Literal["run_linter"] = "run_linter"
|
| 45 |
+
|
| 46 |
+
class RunTests(Action):
|
| 47 |
+
action_type: Literal["run_tests"] = "run_tests"
|
| 48 |
+
|
| 49 |
+
class QueryDocs(Action):
|
| 50 |
+
action_type: Literal["query_docs"] = "query_docs"
|
| 51 |
+
query_topic: str = Field(..., min_length=1)
|
| 52 |
+
|
| 53 |
+
# Discriminated union for oneβline polymorphic deserialization
|
| 54 |
+
AnyAction = Annotated[
|
| 55 |
+
Union[WriteComment, Skip, Done, AskQuestion, ProposeFix,
|
| 56 |
+
Execute, Inspect, RunLinter, RunTests, QueryDocs],
|
| 57 |
+
Field(discriminator='action_type')
|
| 58 |
+
]
|
| 59 |
+
action_adapter = TypeAdapter(AnyAction)
|
| 60 |
+
|
| 61 |
+
|
| 62 |
+
def map_to_env(action_type: str, content: Optional[str] = None) -> AnyAction:
|
| 63 |
+
"""
|
| 64 |
+
Convert lightweight agent outputs into typed environment actions.
|
| 65 |
+
Kept at module level so training/inference code can reuse one mapping.
|
| 66 |
+
"""
|
| 67 |
+
if action_type == "run_tests":
|
| 68 |
+
return RunTests()
|
| 69 |
+
if action_type == "run_linter":
|
| 70 |
+
return RunLinter()
|
| 71 |
+
if action_type == "inspect":
|
| 72 |
+
return Inspect()
|
| 73 |
+
if action_type == "fix":
|
| 74 |
+
return ProposeFix(fix_code=content or "")
|
| 75 |
+
if action_type == "comment":
|
| 76 |
+
return WriteComment(comment_text=content or "")
|
| 77 |
+
if action_type == "question":
|
| 78 |
+
return AskQuestion(question=content or "")
|
| 79 |
+
if action_type == "query_docs":
|
| 80 |
+
return QueryDocs(query_topic=content or "")
|
| 81 |
+
if action_type == "done":
|
| 82 |
+
return Done()
|
| 83 |
+
return Skip()
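Both entry points above produce the same typed objects: `map_to_env` for the lightweight (action_type, content) pairs the trainer emits, and `action_adapter` for raw dicts or JSON, where pydantic picks the concrete class from the `action_type` discriminator. A short usage sketch, assuming this module is importable as `models`:

```python
# Usage sketch for the discriminated-union action types defined above.
from models import ProposeFix, action_adapter, map_to_env

# From a raw dict/JSON payload: the discriminator selects the concrete class.
action = action_adapter.validate_python(
    {"action_type": "question", "question": "Is the empty-list case intended?"}
)
print(type(action).__name__)  # AskQuestion

# From the trainer's (action_type, content) pair.
fix = map_to_env("fix", "def get_user(id):\n    return users.get(id)")
assert isinstance(fix, ProposeFix)
print(fix.action_type, len(fix.fix_code))
```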
|
| 84 |
+
|
| 85 |
+
# ----------------------------------------------------------------------
|
| 86 |
+
# Observation (POMDP - what the agent sees)
|
| 87 |
+
# ----------------------------------------------------------------------
|
| 88 |
+
class Observation(BaseModel):
|
| 89 |
+
# Base schema model used by API metadata endpoints.
|
| 90 |
+
# Keep this lightweight for compatibility with legacy callers.
|
| 91 |
+
code_snippet: str
|
| 92 |
+
last_tool_output: str = ""
|
| 93 |
+
step: int = 0
|
| 94 |
+
done: bool = False
|
| 95 |
+
|
| 96 |
+
# ----------------------------------------------------------------------
|
| 97 |
+
# Reward (lightweight)
|
| 98 |
+
# ----------------------------------------------------------------------
|
| 99 |
+
class Reward(BaseModel):
|
| 100 |
+
value: float
|
| 101 |
+
|
| 102 |
+
# ----------------------------------------------------------------------
|
| 103 |
+
# State (full environment state - not exposed to agent)
|
| 104 |
+
# ----------------------------------------------------------------------
|
| 105 |
+
class State(BaseModel):
|
| 106 |
+
pr_title: str
|
| 107 |
+
pr_description: str
|
| 108 |
+
code_snippet: str
|
| 109 |
+
comments: list[str]
|
| 110 |
+
test_results: Optional[str]
|
| 111 |
+
step: int
|
| 112 |
done: bool
|
pyproject.toml
CHANGED
|
@@ -1,30 +1,38 @@
|
|
| 1 |
-
[build-system]
|
| 2 |
-
requires = ["setuptools>=61.0", "wheel"]
|
| 3 |
-
build-backend = "setuptools.build_meta"
|
| 4 |
-
|
| 5 |
-
[project]
|
| 6 |
-
name = "code_review_professional"
|
| 7 |
-
version = "1.0.0"
|
| 8 |
-
description = "Multi
|
| 9 |
-
authors = [{name = "yuvraj gupta", email = "yuvraj467229@gmail.com"}]
|
| 10 |
-
license = {text = "MIT"}
|
| 11 |
-
readme = "README.md"
|
| 12 |
-
requires-python = ">=3.10"
|
| 13 |
-
dependencies = [
|
| 14 |
-
"openenv-core>=0.2.
|
| 15 |
-
"fastapi>=0.115.0",
|
| 16 |
-
"uvicorn>=0.24.0",
|
| 17 |
-
"
|
| 18 |
-
|
| 19 |
-
|
| 20 |
-
|
| 21 |
-
|
| 22 |
-
|
| 23 |
-
"
|
| 24 |
-
|
| 25 |
-
|
| 26 |
-
|
| 27 |
-
|
| 28 |
-
|
| 29 |
-
|
| 30 |
-
|
|
|
|
|
|
|
|
| 1 |
+
[build-system]
|
| 2 |
+
requires = ["setuptools>=61.0", "wheel"]
|
| 3 |
+
build-backend = "setuptools.build_meta"
|
| 4 |
+
|
| 5 |
+
[project]
|
| 6 |
+
name = "code_review_professional"
|
| 7 |
+
version = "1.0.0"
|
| 8 |
+
description = "Multi-turn code review environment for OpenEnv with author negotiation and RL training hooks."
|
| 9 |
+
authors = [{name = "yuvraj gupta", email = "yuvraj467229@gmail.com"}]
|
| 10 |
+
license = {text = "MIT"}
|
| 11 |
+
readme = "README.md"
|
| 12 |
+
requires-python = ">=3.10"
|
| 13 |
+
dependencies = [
|
| 14 |
+
"openenv-core>=0.2.3",
|
| 15 |
+
"fastapi>=0.115.0",
|
| 16 |
+
"uvicorn>=0.24.0",
|
| 17 |
+
"pylint>=3.3.0",
|
| 18 |
+
]
|
| 19 |
+
|
| 20 |
+
[project.optional-dependencies]
|
| 21 |
+
dev = ["pytest>=7.0", "black>=23.0", "isort>=5.0"]
|
| 22 |
+
training = [
|
| 23 |
+
"accelerate>=1.2.0",
|
| 24 |
+
"chromadb>=0.5.0",
|
| 25 |
+
"datasets>=3.3.0",
|
| 26 |
+
"matplotlib>=3.9.0",
|
| 27 |
+
"sentence-transformers>=3.3.0",
|
| 28 |
+
"torch>=2.4.0",
|
| 29 |
+
"trl>=0.15.0",
|
| 30 |
+
"unsloth>=2025.3.1",
|
| 31 |
+
]
|
| 32 |
+
|
| 33 |
+
[project.urls]
|
| 34 |
+
Homepage = "https://huggingface.co/spaces/100XZX001/CodeReview-Professional-Workflow"
|
| 35 |
+
Repository = "https://huggingface.co/spaces/100XZX001/CodeReview-Professional-Workflow"
|
| 36 |
+
|
| 37 |
+
[tool.openenv]
|
| 38 |
+
server = "server.app:app"
|
requirements-training.txt
ADDED
|
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
| 1 |
+
accelerate>=1.2.0
|
| 2 |
+
chromadb>=0.5.0
|
| 3 |
+
datasets>=3.3.0
|
| 4 |
+
matplotlib>=3.9.0
|
| 5 |
+
sentence-transformers>=3.3.0
|
| 6 |
+
torch>=2.4.0
|
| 7 |
+
transformers>=4.48.0
|
| 8 |
+
trl>=0.15.0
|
| 9 |
+
unsloth>=2025.3.1
|
requirements.txt
CHANGED
|
@@ -1,11 +1,5 @@
|
|
| 1 |
-
|
| 2 |
-
|
| 3 |
-
|
| 4 |
-
|
| 5 |
-
|
| 6 |
-
sentence-transformers
|
| 7 |
-
chromadb
|
| 8 |
-
pylint
|
| 9 |
-
pydantic
|
| 10 |
-
matplotlib
|
| 11 |
-
huggingface_hub
|
|
|
|
| 1 |
+
openenv-core>=0.2.3
|
| 2 |
+
fastapi>=0.115.0
|
| 3 |
+
uvicorn>=0.24.0
|
| 4 |
+
pylint>=3.3.0
|
| 5 |
+
pydantic>=2.8.0
|
|
|
rltool.py
CHANGED
|
@@ -1,127 +1,143 @@
|
|
| 1 |
-
# tools.py
|
| 2 |
-
import
|
| 3 |
-
import
|
| 4 |
-
import
|
| 5 |
-
|
| 6 |
-
from
|
| 7 |
-
|
| 8 |
-
|
| 9 |
-
|
| 10 |
-
|
| 11 |
-
|
| 12 |
-
|
| 13 |
-
|
| 14 |
-
|
| 15 |
-
|
| 16 |
-
|
| 17 |
-
|
| 18 |
-
|
| 19 |
-
|
| 20 |
-
|
| 21 |
-
|
| 22 |
-
|
| 23 |
-
|
| 24 |
-
|
| 25 |
-
|
| 26 |
-
|
| 27 |
-
|
| 28 |
-
|
| 29 |
-
|
| 30 |
-
|
| 31 |
-
|
| 32 |
-
|
| 33 |
-
|
| 34 |
-
|
| 35 |
-
|
| 36 |
-
|
| 37 |
-
|
| 38 |
-
|
| 39 |
-
|
| 40 |
-
|
| 41 |
-
|
| 42 |
-
|
| 43 |
-
|
| 44 |
-
|
| 45 |
-
|
| 46 |
-
|
| 47 |
-
|
| 48 |
-
|
| 49 |
-
|
| 50 |
-
|
| 51 |
-
|
| 52 |
-
|
| 53 |
-
|
| 54 |
-
|
| 55 |
-
|
| 56 |
-
|
| 57 |
-
|
| 58 |
-
|
| 59 |
-
|
| 60 |
-
|
| 61 |
-
|
| 62 |
-
|
| 63 |
-
|
| 64 |
-
|
| 65 |
-
|
| 66 |
-
|
| 67 |
-
|
| 68 |
-
|
| 69 |
-
|
| 70 |
-
|
| 71 |
-
|
| 72 |
-
|
| 73 |
-
|
| 74 |
-
|
| 75 |
-
|
| 76 |
-
|
| 77 |
-
|
| 78 |
-
|
| 79 |
-
|
| 80 |
-
|
| 81 |
-
|
| 82 |
-
|
| 83 |
-
|
| 84 |
-
|
| 85 |
-
|
| 86 |
-
|
| 87 |
-
|
| 88 |
-
|
| 89 |
-
|
| 90 |
-
|
| 91 |
-
|
| 92 |
-
|
| 93 |
-
|
| 94 |
-
|
| 95 |
-
|
| 96 |
-
|
| 97 |
-
|
| 98 |
-
|
| 99 |
-
|
| 100 |
-
|
| 101 |
-
|
| 102 |
-
|
| 103 |
-
|
| 104 |
-
|
| 105 |
-
|
| 106 |
-
|
| 107 |
-
|
| 108 |
-
|
| 109 |
-
|
| 110 |
-
|
| 111 |
-
|
| 112 |
-
|
| 113 |
-
|
| 114 |
-
|
| 115 |
-
|
| 116 |
-
|
| 117 |
-
|
| 118 |
-
|
| 119 |
-
|
| 120 |
-
|
| 121 |
-
|
| 122 |
-
|
| 123 |
-
|
| 124 |
-
|
| 125 |
-
|
| 126 |
-
|
| 127 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# tools.py - Real vector retrieval for query_docs, linter, and test runner
|
| 2 |
+
import os
|
| 3 |
+
import subprocess
|
| 4 |
+
import sys
|
| 5 |
+
import tempfile
|
| 6 |
+
from dataclasses import dataclass
|
| 7 |
+
|
| 8 |
+
try:
|
| 9 |
+
from sentence_transformers import SentenceTransformer
|
| 10 |
+
except ImportError:
|
| 11 |
+
SentenceTransformer = None
|
| 12 |
+
|
| 13 |
+
try:
|
| 14 |
+
import chromadb
|
| 15 |
+
except ImportError:
|
| 16 |
+
chromadb = None
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
@dataclass
|
| 20 |
+
class ToolBox:
|
| 21 |
+
_embedder = None
|
| 22 |
+
_client = None
|
| 23 |
+
_collection = None
|
| 24 |
+
|
| 25 |
+
@classmethod
|
| 26 |
+
def _get_embedder(cls):
|
| 27 |
+
if cls._embedder is None:
|
| 28 |
+
if SentenceTransformer is None:
|
| 29 |
+
return None
|
| 30 |
+
cls._embedder = SentenceTransformer("all-MiniLM-L6-v2")
|
| 31 |
+
return cls._embedder
|
| 32 |
+
|
| 33 |
+
@classmethod
|
| 34 |
+
def _get_collection(cls):
|
| 35 |
+
if cls._collection is None:
|
| 36 |
+
if chromadb is None:
|
| 37 |
+
return None
|
| 38 |
+
cls._client = chromadb.Client()
|
| 39 |
+
cls._collection = cls._client.create_collection("docs")
|
| 40 |
+
docs = [
|
| 41 |
+
"KeyError occurs when a dictionary key is missing. Use dict.get() or check 'if key in dict'.",
|
| 42 |
+
"pylint error C0304: missing final newline. Add a newline at the end of file.",
|
| 43 |
+
"Deadlock happens when two threads acquire locks in opposite order. Always acquire locks in the same order.",
|
| 44 |
+
"Division by zero: check if list is empty before calculating average, or use try/except.",
|
| 45 |
+
"Threading.Lock: use 'with lock:' to automatically acquire and release.",
|
| 46 |
+
"Off-by-one errors: adjust loop ranges, e.g., range(1, len(arr)-1).",
|
| 47 |
+
]
|
| 48 |
+
embedder = cls._get_embedder()
|
| 49 |
+
if embedder is None:
|
| 50 |
+
return None
|
| 51 |
+
embeddings = embedder.encode(docs).tolist()
|
| 52 |
+
for i, doc in enumerate(docs):
|
| 53 |
+
cls._collection.add(ids=[str(i)], documents=[doc], embeddings=[embeddings[i]])
|
| 54 |
+
return cls._collection
|
| 55 |
+
|
| 56 |
+
@staticmethod
|
| 57 |
+
def run_linter(code: str) -> str:
|
| 58 |
+
with tempfile.NamedTemporaryFile(mode="w", suffix=".py", delete=False, encoding="utf-8") as f:
|
| 59 |
+
f.write(code)
|
| 60 |
+
f.flush()
|
| 61 |
+
tmp_path = f.name
|
| 62 |
+
try:
|
| 63 |
+
result = subprocess.run(
|
| 64 |
+
[sys.executable, "-m", "pylint", tmp_path, "--exit-zero", "--output-format=text"],
|
| 65 |
+
capture_output=True,
|
| 66 |
+
text=True,
|
| 67 |
+
timeout=10,
|
| 68 |
+
encoding="utf-8",
|
| 69 |
+
)
|
| 70 |
+
output = result.stdout
|
| 71 |
+
if "Your code has been rated" in output:
|
| 72 |
+
output = output.split("Your code has been rated")[0]
|
| 73 |
+
output = output.strip()
|
| 74 |
+
if not output:
|
| 75 |
+
return "No linting issues found."
|
| 76 |
+
return output[:500]
|
| 77 |
+
except FileNotFoundError:
|
| 78 |
+
return "Linter (pylint) not installed."
|
| 79 |
+
except subprocess.TimeoutExpired:
|
| 80 |
+
return "Linter timed out."
|
| 81 |
+
except Exception as e:
|
| 82 |
+
return f"Linter error: {str(e)}"
|
| 83 |
+
finally:
|
| 84 |
+
try:
|
| 85 |
+
os.unlink(tmp_path)
|
| 86 |
+
except OSError:
|
| 87 |
+
pass
|
| 88 |
+
|
| 89 |
+
@staticmethod
|
| 90 |
+
def run_tests(test_script: str) -> str:
|
| 91 |
+
with tempfile.NamedTemporaryFile(mode="w", suffix=".py", delete=False, encoding="utf-8") as f:
|
| 92 |
+
f.write(test_script)
|
| 93 |
+
f.flush()
|
| 94 |
+
tmp_path = f.name
|
| 95 |
+
try:
|
| 96 |
+
result = subprocess.run(
|
| 97 |
+
[sys.executable, tmp_path],
|
| 98 |
+
capture_output=True,
|
| 99 |
+
text=True,
|
| 100 |
+
timeout=10,
|
| 101 |
+
encoding="utf-8",
|
| 102 |
+
)
|
| 103 |
+
output = result.stdout + result.stderr
|
| 104 |
+
return output.strip() or "Test executed successfully (no output)."
|
| 105 |
+
except subprocess.TimeoutExpired:
|
| 106 |
+
return "Test execution timed out."
|
| 107 |
+
except Exception as e:
|
| 108 |
+
return f"Test runner error: {str(e)}"
|
| 109 |
+
finally:
|
| 110 |
+
try:
|
| 111 |
+
os.unlink(tmp_path)
|
| 112 |
+
except OSError:
|
| 113 |
+
pass
|
| 114 |
+
|
| 115 |
+
@classmethod
|
| 116 |
+
def query_docs(cls, topic: str) -> str:
|
| 117 |
+
"""Retrieve top 3 relevant docs; fall back cleanly when vector deps are missing."""
|
| 118 |
+
try:
|
| 119 |
+
embedder = cls._get_embedder()
|
| 120 |
+
collection = cls._get_collection()
|
| 121 |
+
if embedder is None or collection is None:
|
| 122 |
+
raise RuntimeError("Vector retrieval dependencies are unavailable")
|
| 123 |
+
query_emb = embedder.encode([topic]).tolist()
|
| 124 |
+
results = collection.query(query_embeddings=query_emb, n_results=3)
|
| 125 |
+
if results["documents"] and results["documents"][0]:
|
| 126 |
+
snippets = []
|
| 127 |
+
for i, doc in enumerate(results["documents"][0]):
|
| 128 |
+
snippets.append(f"[{i + 1}] {doc}")
|
| 129 |
+
return "Relevant documentation:\n" + "\n".join(snippets)
|
| 130 |
+
return "No relevant documentation found."
|
| 131 |
+
except Exception:
|
| 132 |
+
topic_lower = topic.lower()
|
| 133 |
+
fallback = {
|
| 134 |
+
"null check": "To avoid KeyError, use 'if key in dict:' before accessing.",
|
| 135 |
+
"keyerror": "Catch KeyError with try/except or use dict.get().",
|
| 136 |
+
"deadlock": "Always acquire locks in the same order to avoid deadlock.",
|
| 137 |
+
"race": "Protect shared state with a lock or make the update atomic.",
|
| 138 |
+
"division": "Guard empty inputs before dividing or return a safe default.",
|
| 139 |
+
}
|
| 140 |
+
for key, value in fallback.items():
|
| 141 |
+
if key in topic_lower:
|
| 142 |
+
return value
|
| 143 |
+
return "No relevant documentation found. Try being more specific."
|
rubrics.py
CHANGED
|
@@ -1,136 +1,136 @@
|
|
| 1 |
-
# rubrics.py - Self-contained Rubrics (no external OpenEnv dependency)
|
| 2 |
-
|
| 3 |
-
class Rubric:
|
| 4 |
-
"""Minimal Rubric base β compatible with OpenEnv but selfβcontained."""
|
| 5 |
-
def __call__(self, env, action, obs, reward, done, info):
|
| 6 |
-
return 0.0
|
| 7 |
-
|
| 8 |
-
|
| 9 |
-
# --------------------------------------------------------------------------------
|
| 10 |
-
# 1. TOOLβUSAGE BONUS
|
| 11 |
-
# --------------------------------------------------------------------------------
|
| 12 |
-
class ToolUsageRubric(Rubric):
|
| 13 |
-
def __init__(self, bonus: float = 0.05):
|
| 14 |
-
self.bonus = bonus
|
| 15 |
-
|
| 16 |
-
def __call__(self, env, action, obs, reward, done, info):
|
| 17 |
-
score = 0.0
|
| 18 |
-
action_type = info.get("action_type", "")
|
| 19 |
-
# Use pre-action flags from `info` so first-use bonuses are
|
| 20 |
-
# computed correctly even though env flags are mutated in-step.
|
| 21 |
-
prev_tests_run = info.get("prev_tests_run", env._tests_run)
|
| 22 |
-
prev_linter_run = info.get("prev_linter_run", env._linter_run)
|
| 23 |
-
prev_docs_queried = info.get("prev_docs_queried", env._docs_queried)
|
| 24 |
-
|
| 25 |
-
if action_type == "run_tests":
|
| 26 |
-
if not prev_tests_run:
|
| 27 |
-
score += self.bonus
|
| 28 |
-
score += 0.015
|
| 29 |
-
elif action_type == "run_linter":
|
| 30 |
-
if not prev_linter_run:
|
| 31 |
-
score += self.bonus
|
| 32 |
-
score += 0.015
|
| 33 |
-
elif action_type == "query_docs":
|
| 34 |
-
if not prev_docs_queried:
|
| 35 |
-
score += self.bonus * 0.5
|
| 36 |
-
# Encourage docs usage when it is likely useful:
|
| 37 |
-
# - early exploration phase
|
| 38 |
-
# - non-trivial query text
|
| 39 |
-
if env._step_count <= 4 and info.get("docs_query_len", 0) >= 8:
|
| 40 |
-
score += 0.01
|
| 41 |
-
# Discourage repeated docs calls after the first-use signal.
|
| 42 |
-
if prev_docs_queried:
|
| 43 |
-
score -= 0.01
|
| 44 |
-
elif action_type == "question" and env._step_count <= 3:
|
| 45 |
-
score += 0.02
|
| 46 |
-
return score
|
| 47 |
-
|
| 48 |
-
|
| 49 |
-
# --------------------------------------------------------------------------------
|
| 50 |
-
# 2. DELTAβBASED REWARDS
|
| 51 |
-
# --------------------------------------------------------------------------------
|
| 52 |
-
class TestDeltaRubric(Rubric):
|
| 53 |
-
def __init__(self, weight: float = 0.3):
|
| 54 |
-
self.weight = weight
|
| 55 |
-
|
| 56 |
-
def __call__(self, env, action, obs, reward, done, info):
|
| 57 |
-
delta = env._current_test_score - env._previous_test_score
|
| 58 |
-
effective = self.weight
|
| 59 |
-
if info.get("action_type") == "fix":
|
| 60 |
-
effective *= 0.4
|
| 61 |
-
return effective * delta
|
| 62 |
-
|
| 63 |
-
|
| 64 |
-
class LintDeltaRubric(Rubric):
|
| 65 |
-
def __init__(self, weight: float = 0.3):
|
| 66 |
-
self.weight = weight
|
| 67 |
-
|
| 68 |
-
def __call__(self, env, action, obs, reward, done, info):
|
| 69 |
-
delta = env._current_lint_score - env._previous_lint_score
|
| 70 |
-
effective = self.weight * 0.5
|
| 71 |
-
if info.get("action_type") == "fix":
|
| 72 |
-
effective *= 0.4
|
| 73 |
-
return effective * delta
|
| 74 |
-
|
| 75 |
-
|
| 76 |
-
# --------------------------------------------------------------------------------
|
| 77 |
-
# 3. TERMINAL SUCCESS BONUS
|
| 78 |
-
# --------------------------------------------------------------------------------
|
| 79 |
-
class TerminalSuccessRubric(Rubric):
|
| 80 |
-
def __call__(self, env, action, obs, reward, done, info):
|
| 81 |
-
if info.get("action_type") != "fix":
|
| 82 |
-
return 0.0
|
| 83 |
-
score = 0.0
|
| 84 |
-
if env._current_test_score > 0.95:
|
| 85 |
-
score += 0.4
|
| 86 |
-
elif env._current_test_score > 0.85:
|
| 87 |
-
score += 0.2
|
| 88 |
-
return score
|
| 89 |
-
|
| 90 |
-
|
| 91 |
-
# --------------------------------------------------------------------------------
|
| 92 |
-
# 4. EXPLORATION & DIVERSITY
|
| 93 |
-
# --------------------------------------------------------------------------------
|
| 94 |
-
class ExplorationRubric(Rubric):
|
| 95 |
-
def __init__(self, penalty: float = -0.05, bonus: float = 0.021):
|
| 96 |
-
self.penalty = penalty
|
| 97 |
-
self.bonus = bonus
|
| 98 |
-
|
| 99 |
-
def __call__(self, env, action, obs, reward, done, info):
|
| 100 |
-
if len(env._action_history) < 3:
|
| 101 |
-
return 0.0
|
| 102 |
-
recent = env._action_history[-3:]
|
| 103 |
-
unique = len(set(recent))
|
| 104 |
-
if unique == 1:
|
| 105 |
-
return self.penalty
|
| 106 |
-
elif unique == 3:
|
| 107 |
-
return self.bonus
|
| 108 |
-
return 0.0
|
| 109 |
-
|
| 110 |
-
|
| 111 |
-
# --------------------------------------------------------------------------------
|
| 112 |
-
# 5. ANTIβHACKING & CONSISTENCY
|
| 113 |
-
# --------------------------------------------------------------------------------
|
| 114 |
-
class AntiHackingRubric(Rubric):
|
| 115 |
-
def __call__(self, env, action, obs, reward, done, info):
|
| 116 |
-
if info.get("action_type") != "fix":
|
| 117 |
-
return 0.0
|
| 118 |
-
score = 0.0
|
| 119 |
-
if not env._tests_run:
|
| 120 |
-
score -= 0.25
|
| 121 |
-
if env._step_count < 2:
|
| 122 |
-
score -= 0.1
|
| 123 |
-
if env._tests_run and env._linter_run:
|
| 124 |
-
score += 0.02
|
| 125 |
-
return score
|
| 126 |
-
|
| 127 |
-
|
| 128 |
-
# --------------------------------------------------------------------------------
|
| 129 |
-
# 6. STEP PENALTY
|
| 130 |
-
# --------------------------------------------------------------------------------
|
| 131 |
-
class StepPenaltyRubric(Rubric):
|
| 132 |
-
def __init__(self, penalty: float = -0.01):
|
| 133 |
-
self.penalty = penalty
|
| 134 |
-
|
| 135 |
-
def __call__(self, env, action, obs, reward, done, info):
|
| 136 |
-
return self.penalty
|
|
|
|
| 1 |
+
# rubrics.py - Self-contained Rubrics (no external OpenEnv dependency)
|
| 2 |
+
|
| 3 |
+
class Rubric:
|
| 4 |
+
"""Minimal Rubric base β compatible with OpenEnv but selfβcontained."""
|
| 5 |
+
def __call__(self, env, action, obs, reward, done, info):
|
| 6 |
+
return 0.0
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
# --------------------------------------------------------------------------------
|
| 10 |
+
# 1. TOOLβUSAGE BONUS
|
| 11 |
+
# --------------------------------------------------------------------------------
|
| 12 |
+
class ToolUsageRubric(Rubric):
|
| 13 |
+
def __init__(self, bonus: float = 0.05):
|
| 14 |
+
self.bonus = bonus
|
| 15 |
+
|
| 16 |
+
def __call__(self, env, action, obs, reward, done, info):
|
| 17 |
+
score = 0.0
|
| 18 |
+
action_type = info.get("action_type", "")
|
| 19 |
+
# Use pre-action flags from `info` so first-use bonuses are
|
| 20 |
+
# computed correctly even though env flags are mutated in-step.
|
| 21 |
+
prev_tests_run = info.get("prev_tests_run", env._tests_run)
|
| 22 |
+
prev_linter_run = info.get("prev_linter_run", env._linter_run)
|
| 23 |
+
prev_docs_queried = info.get("prev_docs_queried", env._docs_queried)
|
| 24 |
+
|
| 25 |
+
if action_type == "run_tests":
|
| 26 |
+
if not prev_tests_run:
|
| 27 |
+
score += self.bonus
|
| 28 |
+
score += 0.015
|
| 29 |
+
elif action_type == "run_linter":
|
| 30 |
+
if not prev_linter_run:
|
| 31 |
+
score += self.bonus
|
| 32 |
+
score += 0.015
|
| 33 |
+
elif action_type == "query_docs":
|
| 34 |
+
if not prev_docs_queried:
|
| 35 |
+
score += self.bonus * 0.5
|
| 36 |
+
# Encourage docs usage when it is likely useful:
|
| 37 |
+
# - early exploration phase
|
| 38 |
+
# - non-trivial query text
|
| 39 |
+
if env._step_count <= 4 and info.get("docs_query_len", 0) >= 8:
|
| 40 |
+
score += 0.01
|
| 41 |
+
# Discourage repeated docs calls after the first-use signal.
|
| 42 |
+
if prev_docs_queried:
|
| 43 |
+
score -= 0.01
|
| 44 |
+
elif action_type == "question" and env._step_count <= 3:
|
| 45 |
+
score += 0.02
|
| 46 |
+
return score
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
# --------------------------------------------------------------------------------
|
| 50 |
+
# 2. DELTAβBASED REWARDS
|
| 51 |
+
# --------------------------------------------------------------------------------
|
| 52 |
+
class TestDeltaRubric(Rubric):
|
| 53 |
+
def __init__(self, weight: float = 0.3):
|
| 54 |
+
self.weight = weight
|
| 55 |
+
|
| 56 |
+
def __call__(self, env, action, obs, reward, done, info):
|
| 57 |
+
delta = env._current_test_score - env._previous_test_score
|
| 58 |
+
effective = self.weight
|
| 59 |
+
if info.get("action_type") == "fix":
|
| 60 |
+
effective *= 0.4
|
| 61 |
+
return effective * delta
|
| 62 |
+
|
| 63 |
+
|
| 64 |
+
class LintDeltaRubric(Rubric):
|
| 65 |
+
def __init__(self, weight: float = 0.3):
|
| 66 |
+
self.weight = weight
|
| 67 |
+
|
| 68 |
+
def __call__(self, env, action, obs, reward, done, info):
|
| 69 |
+
delta = env._current_lint_score - env._previous_lint_score
|
| 70 |
+
effective = self.weight * 0.5
|
| 71 |
+
if info.get("action_type") == "fix":
|
| 72 |
+
effective *= 0.4
|
| 73 |
+
return effective * delta
|
| 74 |
+
|
| 75 |
+
|
| 76 |
+
# --------------------------------------------------------------------------------
|
| 77 |
+
# 3. TERMINAL SUCCESS BONUS
|
| 78 |
+
# --------------------------------------------------------------------------------
|
| 79 |
+
class TerminalSuccessRubric(Rubric):
|
| 80 |
+
def __call__(self, env, action, obs, reward, done, info):
|
| 81 |
+
if info.get("action_type") != "fix":
|
| 82 |
+
return 0.0
|
| 83 |
+
score = 0.0
|
| 84 |
+
if env._current_test_score > 0.95:
|
| 85 |
+
score += 0.4
|
| 86 |
+
elif env._current_test_score > 0.85:
|
| 87 |
+
score += 0.2
|
| 88 |
+
return score
|
| 89 |
+
|
| 90 |
+
|
| 91 |
+
# --------------------------------------------------------------------------------
|
| 92 |
+
# 4. EXPLORATION & DIVERSITY
|
| 93 |
+
# --------------------------------------------------------------------------------
|
| 94 |
+
class ExplorationRubric(Rubric):
|
| 95 |
+
def __init__(self, penalty: float = -0.05, bonus: float = 0.021):
|
| 96 |
+
self.penalty = penalty
|
| 97 |
+
self.bonus = bonus
|
| 98 |
+
|
| 99 |
+
def __call__(self, env, action, obs, reward, done, info):
|
| 100 |
+
if len(env._action_history) < 3:
|
| 101 |
+
return 0.0
|
| 102 |
+
recent = env._action_history[-3:]
|
| 103 |
+
unique = len(set(recent))
|
| 104 |
+
if unique == 1:
|
| 105 |
+
return self.penalty
|
| 106 |
+
elif unique == 3:
|
| 107 |
+
return self.bonus
|
| 108 |
+
return 0.0
|
| 109 |
+
|
| 110 |
+
|
| 111 |
+
# --------------------------------------------------------------------------------
|
| 112 |
+
# 5. ANTIβHACKING & CONSISTENCY
|
| 113 |
+
# --------------------------------------------------------------------------------
|
| 114 |
+
class AntiHackingRubric(Rubric):
|
| 115 |
+
def __call__(self, env, action, obs, reward, done, info):
|
| 116 |
+
if info.get("action_type") != "fix":
|
| 117 |
+
return 0.0
|
| 118 |
+
score = 0.0
|
| 119 |
+
if not env._tests_run:
|
| 120 |
+
score -= 0.25
|
| 121 |
+
if env._step_count < 2:
|
| 122 |
+
score -= 0.1
|
| 123 |
+
if env._tests_run and env._linter_run:
|
| 124 |
+
score += 0.02
|
| 125 |
+
return score
|
| 126 |
+
|
| 127 |
+
|
| 128 |
+
# --------------------------------------------------------------------------------
|
| 129 |
+
# 6. STEP PENALTY
|
| 130 |
+
# --------------------------------------------------------------------------------
|
| 131 |
+
class StepPenaltyRubric(Rubric):
|
| 132 |
+
def __init__(self, penalty: float = -0.01):
|
| 133 |
+
self.penalty = penalty
|
| 134 |
+
|
| 135 |
+
def __call__(self, env, action, obs, reward, done, info):
|
| 136 |
+
return self.penalty
|
training.py
CHANGED
|
@@ -1,935 +1,935 @@
|
|
| 1 |
-
# training.py β Memoryβsafe: Phiβ3βmini + Expert Demos + Fast PPO (2 iterations)
|
| 2 |
-
import os
|
| 3 |
-
os.environ["TRITON_DISABLE"] = "1"
|
| 4 |
-
os.environ["TOKENIZERS_PARALLELISM"] = "false" # Issue #12: prevent OOM from parallel tokenization
|
| 5 |
-
|
| 6 |
-
import torch._dynamo
|
| 7 |
-
torch._dynamo.config.disable = True
|
| 8 |
-
import json
|
| 9 |
-
import torch
|
| 10 |
-
import torch.nn.functional as F
|
| 11 |
-
from torch.optim import AdamW
|
| 12 |
-
from dataclasses import dataclass
|
| 13 |
-
from typing import List, Dict, Tuple, Optional
|
| 14 |
-
import numpy as np
|
| 15 |
-
import re
|
| 16 |
-
import random
|
| 17 |
-
import matplotlib.pyplot as plt
|
| 18 |
-
|
| 19 |
-
from unsloth import FastLanguageModel
|
| 20 |
-
from transformers import TrainingArguments
|
| 21 |
-
from trl import SFTTrainer
|
| 22 |
-
from datasets import Dataset
|
| 23 |
-
|
| 24 |
-
from environment import CodeReviewEnv
|
| 25 |
-
from redteam import BUG_DB
|
| 26 |
-
from models import (
|
| 27 |
-
RunTests, RunLinter, Inspect,
|
| 28 |
-
ProposeFix, WriteComment, AskQuestion,
|
| 29 |
-
Done, Skip, QueryDocs, map_to_env as model_map_to_env
|
| 30 |
-
)
|
| 31 |
-
|
| 32 |
-
# ======================================================================
|
| 33 |
-
@dataclass
|
| 34 |
-
class AgentAction:
|
| 35 |
-
action_type: str
|
| 36 |
-
content: Optional[str] = None
|
| 37 |
-
|
| 38 |
-
def parse_action(output: str) -> AgentAction:
|
| 39 |
-
try:
|
| 40 |
-
data = json.loads(output)
|
| 41 |
-
return AgentAction(
|
| 42 |
-
action_type=data.get("action_type", "").lower(),
|
| 43 |
-
content=data.get("content")
|
| 44 |
-
)
|
| 45 |
-
except:
|
| 46 |
-
pass
|
| 47 |
-
json_match = re.search(r'```(?:json)?\s*(\{.*?\})\s*```', output, re.DOTALL)
|
| 48 |
-
if json_match:
|
| 49 |
-
try:
|
| 50 |
-
data = json.loads(json_match.group(1))
|
| 51 |
-
return AgentAction(
|
| 52 |
-
action_type=data.get("action_type", "").lower(),
|
| 53 |
-
content=data.get("content")
|
| 54 |
-
)
|
| 55 |
-
except:
|
| 56 |
-
pass
|
| 57 |
-
action_pattern = r'"action_type"\s*:\s*"(\w+)"'
|
| 58 |
-
match = re.search(action_pattern, output)
|
| 59 |
-
if match:
|
| 60 |
-
return AgentAction(action_type=match.group(1).lower())
|
| 61 |
-
output_lower = output.lower()
|
| 62 |
-
if "test" in output_lower:
|
| 63 |
-
return AgentAction("run_tests")
|
| 64 |
-
if "lint" in output_lower:
|
| 65 |
-
return AgentAction("run_linter")
|
| 66 |
-
if "inspect" in output_lower:
|
| 67 |
-
return AgentAction("inspect")
|
| 68 |
-
if "doc" in output_lower or "documentation" in output_lower:
|
| 69 |
-
return AgentAction("query_docs", "bug fix guidance")
|
| 70 |
-
return AgentAction("invalid", output)
|
| 71 |
-
|
| 72 |
-
def map_to_env(action: AgentAction):
|
| 73 |
-
return model_map_to_env(action.action_type, action.content)
|
| 74 |
-
|
| 75 |
-
# ======================================================================
|
| 76 |
-
def load_model():
|
| 77 |
-
model, tokenizer = FastLanguageModel.from_pretrained(
|
| 78 |
-
model_name="unsloth/Phi-3-mini-4k-instruct-bnb-4bit",
|
| 79 |
-
max_seq_length=480, # smaller window for memory
|
| 80 |
-
load_in_4bit=True,
|
| 81 |
-
)
|
| 82 |
-
model = FastLanguageModel.get_peft_model(
|
| 83 |
-
model,
|
| 84 |
-
r=16,
|
| 85 |
-
target_modules=[
|
| 86 |
-
"q_proj", "k_proj", "v_proj", "o_proj",
|
| 87 |
-
"gate_proj", "up_proj", "down_proj"
|
| 88 |
-
],
|
| 89 |
-
lora_alpha=32,
|
| 90 |
-
lora_dropout=0.0,
|
| 91 |
-
)
|
| 92 |
-
return model, tokenizer
|
| 93 |
-
|
| 94 |
-
def test_model_sanity(model, tokenizer) -> bool:
|
| 95 |
-
print("\n" + "="*60)
|
| 96 |
-
print("SANITY CHECK: Testing base model generation")
|
| 97 |
-
print("="*60)
|
| 98 |
-
test_prompt = "Hello, how are you?"
|
| 99 |
-
messages = [{"role": "user", "content": test_prompt}]
|
| 100 |
-
formatted = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
|
| 101 |
-
inputs = tokenizer(formatted, return_tensors="pt", max_length=256, truncation=True).to("cuda")
|
| 102 |
-
with torch.no_grad():
|
| 103 |
-
outputs = model.generate(
|
| 104 |
-
**inputs,
|
| 105 |
-
max_new_tokens=30,
|
| 106 |
-
do_sample=True,
|
| 107 |
-
temperature=0.7,
|
| 108 |
-
min_new_tokens=1,
|
| 109 |
-
eos_token_id=tokenizer.eos_token_id,
|
| 110 |
-
pad_token_id=tokenizer.pad_token_id or tokenizer.eos_token_id,
|
| 111 |
-
)
|
| 112 |
-
generated_ids = outputs[0][inputs['input_ids'].shape[1]:]
|
| 113 |
-
response = tokenizer.decode(generated_ids, skip_special_tokens=True).strip()
|
| 114 |
-
print(f"Prompt: {test_prompt}")
|
| 115 |
-
print(f"Response: {repr(response)}")
|
| 116 |
-
if len(response) == 0:
|
| 117 |
-
print("β Model produces empty output β cannot train.")
|
| 118 |
-
return False
|
| 119 |
-
print("β Model sanity check PASSED\n")
|
| 120 |
-
return True
|
| 121 |
-
|
| 122 |
-
# ======================================================================
|
| 123 |
-
def _expert_fix_from_context(obs) -> str:
|
| 124 |
-
"""
|
| 125 |
-
Build a conservative fix template named `fix` (required by tests).
|
| 126 |
-
Uses bug hints + code snippet patterns to create realistic fixes.
|
| 127 |
-
"""
|
| 128 |
-
bug = (getattr(obs, "bug_description", "") or "").lower()
|
| 129 |
-
code = getattr(obs, "code_snippet", "") or ""
|
| 130 |
-
|
| 131 |
-
if "division" in bug or "average" in code.lower():
|
| 132 |
-
return (
|
| 133 |
-
"def fix(data):\n"
|
| 134 |
-
" if not data:\n"
|
| 135 |
-
" return 0\n"
|
| 136 |
-
" return sum(data) / len(data)"
|
| 137 |
-
)
|
| 138 |
-
|
| 139 |
-
if "operator" in bug or "sign" in bug:
|
| 140 |
-
return (
|
| 141 |
-
"def fix(a, b):\n"
|
| 142 |
-
" return a + b"
|
| 143 |
-
)
|
| 144 |
-
|
| 145 |
-
if "off_by_one" in bug or "loop" in bug:
|
| 146 |
-
return (
|
| 147 |
-
"def fix(items):\n"
|
| 148 |
-
" return len(items)"
|
| 149 |
-
)
|
| 150 |
-
|
| 151 |
-
if "null" in bug or "key" in bug or "dict" in code.lower():
|
| 152 |
-
return (
|
| 153 |
-
"def fix(payload):\n"
|
| 154 |
-
" users = payload.get('users', {})\n"
|
| 155 |
-
" user_id = payload.get('id')\n"
|
| 156 |
-
" return users.get(user_id)"
|
| 157 |
-
)
|
| 158 |
-
|
| 159 |
-
# Concurrency-heavy tasks (harder/hardest).
|
| 160 |
-
if "race" in bug or "missing_lock" in bug or "thread_safe" in bug or "global_nonatomic" in bug:
|
| 161 |
-
return (
|
| 162 |
-
"import threading\n"
|
| 163 |
-
"_lock = threading.Lock()\n"
|
| 164 |
-
"\n"
|
| 165 |
-
"def fix(counter):\n"
|
| 166 |
-
" with _lock:\n"
|
| 167 |
-
" if counter is None:\n"
|
| 168 |
-
" return 0\n"
|
| 169 |
-
" return counter + 1"
|
| 170 |
-
)
|
| 171 |
-
|
| 172 |
-
if "deadlock" in bug or "double_lock" in bug or "lock order" in bug or "nested_lock" in bug:
|
| 173 |
-
return (
|
| 174 |
-
"import threading\n"
|
| 175 |
-
"_lock_a = threading.Lock()\n"
|
| 176 |
-
"_lock_b = threading.Lock()\n"
|
| 177 |
-
"\n"
|
| 178 |
-
"def fix(work):\n"
|
| 179 |
-
" first, second = (_lock_a, _lock_b)\n"
|
| 180 |
-
" if id(first) > id(second):\n"
|
| 181 |
-
" first, second = second, first\n"
|
| 182 |
-
" with first:\n"
|
| 183 |
-
" with second:\n"
|
| 184 |
-
" return work() if callable(work) else work"
|
| 185 |
-
)
|
| 186 |
-
|
| 187 |
-
if "fork_join" in bug or "join" in bug:
|
| 188 |
-
return (
|
| 189 |
-
"import threading\n"
|
| 190 |
-
"\n"
|
| 191 |
-
"def fix(worker):\n"
|
| 192 |
-
" t = threading.Thread(target=worker)\n"
|
| 193 |
-
" t.start()\n"
|
| 194 |
-
" t.join()\n"
|
| 195 |
-
" return True"
|
| 196 |
-
)
|
| 197 |
-
|
| 198 |
-
# Generic safe fallback keeps the RL pipeline alive for unknown bugs.
|
| 199 |
-
return (
|
| 200 |
-
"def fix(data):\n"
|
| 201 |
-
" if data is None:\n"
|
| 202 |
-
" return None\n"
|
| 203 |
-
" return data"
|
| 204 |
-
)
|
| 205 |
-
|
| 206 |
-
|
| 207 |
-
def _expert_supervised_policy(obs) -> str:
|
| 208 |
-
"""
|
| 209 |
-
Real workflow policy:
|
| 210 |
-
inspect -> tests/linter -> docs -> fix -> negotiate -> done.
|
| 211 |
-
"""
|
| 212 |
-
author_msg = (getattr(obs, "author_response", "") or "").lower()
|
| 213 |
-
tool_output = (getattr(obs, "last_tool_output", "") or "").lower()
|
| 214 |
-
|
| 215 |
-
if not getattr(obs, "tests_run", False):
|
| 216 |
-
if "inspect" not in tool_output:
|
| 217 |
-
return '{"action_type": "inspect"}'
|
| 218 |
-
return '{"action_type": "run_tests"}'
|
| 219 |
-
|
| 220 |
-
if not getattr(obs, "linter_run", False):
|
| 221 |
-
return '{"action_type": "run_linter"}'
|
| 222 |
-
|
| 223 |
-
if not getattr(obs, "docs_queried", False):
|
| 224 |
-
return '{"action_type": "query_docs", "content": "python bug fixing best practices for edge cases and null safety"}'
|
| 225 |
-
|
| 226 |
-
# Use docs again on hard tasks when evidence is still weak.
|
| 227 |
-
if getattr(obs, "current_test_score", 0.0) < 0.6 and getattr(obs, "step", 0) >= 3:
|
| 228 |
-
bug_hint = (getattr(obs, "bug_description", "") or "concurrency bug").replace('"', "'")
|
| 229 |
-
return json.dumps(
|
| 230 |
-
{
|
| 231 |
-
"action_type": "query_docs",
|
| 232 |
-
"content": f"python {bug_hint} lock ordering race condition mitigation patterns",
|
| 233 |
-
}
|
| 234 |
-
)
|
| 235 |
-
|
| 236 |
-
# If test quality is poor, propose a concrete fix.
|
| 237 |
-
if getattr(obs, "current_test_score", 0.0) < 0.95:
|
| 238 |
-
fix_code = _expert_fix_from_context(obs)
|
| 239 |
-
return json.dumps({"action_type": "fix", "content": fix_code})
|
| 240 |
-
|
| 241 |
-
# If author is still unconvinced, provide causal explanation.
|
| 242 |
-
if author_msg and ("not convinced" in author_msg or "explain" in author_msg or "brief" in author_msg):
|
| 243 |
-
return (
|
| 244 |
-
'{"action_type": "comment", "content": "This fix works because it handles the failing edge case directly, '
|
| 245 |
-
'keeps behavior deterministic, and aligns with the observed test and lint feedback. '
|
| 246 |
-
'The change is intentionally small to reduce regression risk."}'
|
| 247 |
-
)
|
| 248 |
-
|
| 249 |
-
# If negotiation is strong enough and quality is good, terminate.
|
| 250 |
-
conf = float(getattr(obs, "author_confidence", 0.0))
|
| 251 |
-
threshold = float(getattr(obs, "author_threshold", 0.5))
|
| 252 |
-
score = float(getattr(obs, "current_test_score", 0.0))
|
| 253 |
-
if conf >= threshold and score >= 0.8:
|
| 254 |
-
return '{"action_type": "done"}'
|
| 255 |
-
|
| 256 |
-
# Nudge conversation forward when tests are okay but acceptance is pending.
|
| 257 |
-
return (
|
| 258 |
-
'{"action_type": "question", "content": "Would you like a quick walkthrough of a failing scenario, the root cause, and how the fix prevents regressions?"}'
|
| 259 |
-
)
|
| 260 |
-
|
| 261 |
-
# ======================================================================
|
| 262 |
-
def supervised_warmup(model, tokenizer, env, n_episodes=16, epochs=1, max_steps=8):
|
| 263 |
-
print("\n" + "="*60)
|
| 264 |
-
print("SUPERVISED WARM-UP: Real environment demonstrations")
|
| 265 |
-
print("="*60)
|
| 266 |
-
|
| 267 |
-
examples = []
|
| 268 |
-
tasks = ["easy", "medium", "hard", "harder", "hardest"]
|
| 269 |
-
for ep in range(n_episodes):
|
| 270 |
-
task = random.choice(tasks)
|
| 271 |
-
env.set_task(task)
|
| 272 |
-
obs = env.reset()
|
| 273 |
-
history = []
|
| 274 |
-
done = False
|
| 275 |
-
|
| 276 |
-
steps = 0
|
| 277 |
-
while not done and steps < max_steps:
|
| 278 |
-
prompt = build_prompt(obs, history)
|
| 279 |
-
action_text = _expert_supervised_policy(obs)
|
| 280 |
-
action = parse_action(action_text)
|
| 281 |
-
env_action = map_to_env(action)
|
| 282 |
-
next_obs, _, done, _ = env.step(env_action)
|
| 283 |
-
|
| 284 |
-
messages = [
|
| 285 |
-
{"role": "user", "content": prompt},
|
| 286 |
-
{"role": "assistant", "content": action_text},
|
| 287 |
-
]
|
| 288 |
-
full_text = tokenizer.apply_chat_template(messages, tokenize=False)
|
| 289 |
-
examples.append({"text": full_text})
|
| 290 |
-
|
| 291 |
-
history.append(f"Agent: {action_text}")
|
| 292 |
-
history.append(f"Env: {next_obs.last_tool_output}")
|
| 293 |
-
history = history[-8:]
|
| 294 |
-
obs = next_obs
|
| 295 |
-
steps += 1
|
| 296 |
-
|
| 297 |
-
print(f"Supervised episode {ep+1}: task={task}, steps={steps}, done={done}")
|
| 298 |
-
|
| 299 |
-
if not examples:
|
| 300 |
-
print("No supervised examples generated; skipping warm-up.")
|
| 301 |
-
return
|
| 302 |
-
|
| 303 |
-
dataset = Dataset.from_list(examples)
|
| 304 |
-
trainer = SFTTrainer(
|
| 305 |
-
model=model,
|
| 306 |
-
tokenizer=tokenizer,
|
| 307 |
-
train_dataset=dataset,
|
| 308 |
-
dataset_text_field="text",
|
| 309 |
-
max_seq_length=480,
|
| 310 |
-
args=TrainingArguments(
|
| 311 |
-
output_dir="warmup_output",
|
| 312 |
-
num_train_epochs=epochs,
|
| 313 |
-
per_device_train_batch_size=2,
|
| 314 |
-
gradient_accumulation_steps=2,
|
| 315 |
-
learning_rate=2e-5,
|
| 316 |
-
logging_steps=50,
|
| 317 |
-
save_strategy="no",
|
| 318 |
-
bf16=True,
|
| 319 |
-
),
|
| 320 |
-
)
|
| 321 |
-
print(f"Training on {len(examples)} real env examples for {epochs} epochs...")
|
| 322 |
-
trainer.train()
|
| 323 |
-
print("β Supervised warm-up (real env) complete\n")
|
| 324 |
-
torch.cuda.empty_cache()
|
| 325 |
-
|
| 326 |
-
# ======================================================================
|
| 327 |
-
def generate_action_with_logprob(prompt, model, tokenizer, temperature=0.0, max_retries=2):
|
| 328 |
-
messages = [{"role": "user", "content": prompt}]
|
| 329 |
-
formatted = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
|
| 330 |
-
inputs = tokenizer(formatted, return_tensors="pt", max_length=480, truncation=True).to("cuda")
|
| 331 |
-
|
| 332 |
-
for attempt in range(max_retries):
|
| 333 |
-
with torch.no_grad():
|
| 334 |
-
outputs = model.generate(
|
| 335 |
-
**inputs,
|
| 336 |
-
max_new_tokens=64,
|
| 337 |
-
do_sample=(temperature > 0),
|
| 338 |
-
temperature=max(temperature, 0.01) if temperature > 0 else 1.0,
|
| 339 |
-
min_new_tokens=1,
|
| 340 |
-
return_dict_in_generate=True,
|
| 341 |
-
output_scores=True,
|
| 342 |
-
)
|
| 343 |
-
generated_ids = outputs.sequences[0][inputs['input_ids'].shape[1]:]
|
| 344 |
-
action_text = tokenizer.decode(generated_ids, skip_special_tokens=True).strip()
|
| 345 |
-
|
| 346 |
-
logprobs = []
|
| 347 |
-
for idx, token_id in enumerate(generated_ids):
|
| 348 |
-
if idx < len(outputs.scores):
|
| 349 |
-
token_logits = outputs.scores[idx][0]
|
| 350 |
-
token_logprob = F.log_softmax(token_logits, dim=-1)[token_id].item()
|
| 351 |
-
logprobs.append(token_logprob)
|
| 352 |
-
total_logprob = sum(logprobs) if logprobs else -100.0
|
| 353 |
-
|
| 354 |
-
if not action_text:
|
| 355 |
-
fallback_actions = [
|
| 356 |
-
'{"action_type": "run_tests"}',
|
| 357 |
-
'{"action_type": "run_linter"}',
|
| 358 |
-
'{"action_type": "inspect"}',
|
| 359 |
-
'{"action_type": "skip"}',
|
| 360 |
-
]
|
| 361 |
-
action_text = random.choice(fallback_actions)
|
| 362 |
-
total_logprob = -50.0
|
| 363 |
-
print(f"[WARN] Empty generation β using fallback: {action_text}")
|
| 364 |
-
return action_text, total_logprob
|
| 365 |
-
|
| 366 |
-
try:
|
| 367 |
-
json.loads(action_text)
|
| 368 |
-
return action_text, total_logprob
|
| 369 |
-
except:
|
| 370 |
-
if attempt == max_retries - 1:
|
| 371 |
-
return '{"action_type":"skip"}', -100.0
|
| 372 |
-
continue
|
| 373 |
-
return '{"action_type":"skip"}', -100.0
|
| 374 |
-
|
| 375 |
-
# ======================================================================
|
| 376 |
-
def build_prompt(obs, history_lines: List[str]) -> str:
|
| 377 |
-
author_msg = getattr(obs, "author_response", "") or ""
|
| 378 |
-
tool_output = getattr(obs, "last_tool_output", "") or ""
|
| 379 |
-
author_personality = getattr(obs, "author_personality", "defensive")
|
| 380 |
-
|
| 381 |
-
prompt = f"""You are an AI code review agent. Your goal is to convince a simulated human developer to accept your proposed fix and name your proposed fix function fix.
|
| 382 |
-
|
| 383 |
-
The developer has a **{author_personality}** personality and will only accept if you provide solid evidence:
|
| 384 |
-
- Tests pass (high pass ratio)
|
| 385 |
-
- Lint is clean (zero errors)
|
| 386 |
-
- Documentation or references are provided
|
| 387 |
-
- Your reasoning is clear, uses words like "because" or "therefore", and is detailed (over 30 words if needed)
|
| 388 |
-
|
| 389 |
-
Workflow:
|
| 390 |
-
1. Use `inspect` to understand the code.
|
| 391 |
-
2. Use `run_tests` and `run_linter` to gather evidence.
|
| 392 |
-
3. Use `query_docs` when you need references or language-specific guidance.
|
| 393 |
-
4. Propose a fix (`fix`) and explain why it works (`comment` or `question`).
|
| 394 |
-
5. If the developer pushes back, read their response carefully and address their specific concern.
|
| 395 |
-
6. Once convinced, use `done` to finish.
|
| 396 |
-
|
| 397 |
-
Code:
|
| 398 |
-
{obs.code_snippet}
|
| 399 |
-
|
| 400 |
-
Author says:
|
| 401 |
-
{author_msg if author_msg else "(no response yet β start with inspection)"}
|
| 402 |
-
|
| 403 |
-
Last tool output:
|
| 404 |
-
{tool_output if tool_output else "(none)"}
|
| 405 |
-
|
| 406 |
-
Available actions:
|
| 407 |
-
run_tests, run_linter, inspect, query_docs, fix, comment, question, done
|
| 408 |
-
|
| 409 |
-
Respond ONLY in JSON:
|
| 410 |
-
{{"action_type": "...", "content": "..."}}"""
|
| 411 |
-
|
| 412 |
-
if history_lines:
|
| 413 |
-
history = "\n".join(history_lines[-6:])
|
| 414 |
-
prompt += f"\n\nPrevious steps:\n{history}"
|
| 415 |
-
return prompt
|
| 416 |
-
|
| 417 |
-
# ======================================================================
|
| 418 |
-
@dataclass
|
| 419 |
-
class Trajectory:
|
| 420 |
-
states: List[str]
|
| 421 |
-
actions: List[str]
|
| 422 |
-
rewards: List[float]
|
| 423 |
-
logprobs: List[float]
|
| 424 |
-
dones: List[bool]
|
| 425 |
-
def __len__(self): return len(self.states)
|
| 426 |
-
|
| 427 |
-
def collect_trajectory(env, model, tokenizer, max_steps=6, temperature=0.0):
|
| 428 |
-
obs = env.reset()
|
| 429 |
-
history_lines = []
|
| 430 |
-
states, actions, rewards, logprobs, dones = [], [], [], [], []
|
| 431 |
-
for step in range(max_steps):
|
| 432 |
-
prompt = build_prompt(obs, history_lines)
|
| 433 |
-
states.append(prompt)
|
| 434 |
-
action_text, logprob = generate_action_with_logprob(prompt, model, tokenizer, temperature)
|
| 435 |
-
actions.append(action_text)
|
| 436 |
-
logprobs.append(logprob)
|
| 437 |
-
action = parse_action(action_text)
|
| 438 |
-
env_action = map_to_env(action)
|
| 439 |
-
next_obs, reward, done, _ = env.step(env_action)
|
| 440 |
-
rewards.append(reward.value)
|
| 441 |
-
dones.append(done)
|
| 442 |
-
history_lines.append(f"Agent: {action_text}")
|
| 443 |
-
history_lines.append(f"Env: {next_obs.last_tool_output}")
|
| 444 |
-
obs = next_obs
|
| 445 |
-
if done: break
|
| 446 |
-
return Trajectory(states, actions, rewards, logprobs, dones)
|
| 447 |
-
|
| 448 |
-
def collect_trajectories(env, model, tokenizer, n_trajectories, max_steps=6,
|
| 449 |
-
task_levels=None, task_weights=None):
|
| 450 |
-
if task_levels is None:
|
| 451 |
-
task_levels = list(BUG_DB.keys())
|
| 452 |
-
if task_weights is not None and len(task_weights) != len(task_levels):
|
| 453 |
-
raise ValueError("task_weights must match task_levels length")
|
| 454 |
-
if task_weights is not None and sum(task_weights) <= 0:
|
| 455 |
-
raise ValueError("task_weights must have a positive total")
|
| 456 |
-
trajectories = []
|
| 457 |
-
for i in range(n_trajectories):
|
| 458 |
-
sampled_task = random.choices(task_levels, weights=task_weights, k=1)[0]
|
| 459 |
-
env.set_task(sampled_task)
|
| 460 |
-
traj = collect_trajectory(env, model, tokenizer, max_steps)
|
| 461 |
-
total_reward = sum(traj.rewards)
|
| 462 |
-
print(f"Trajectory {i+1}/{n_trajectories}: task={sampled_task}, steps={len(traj)}, reward={total_reward:.3f}")
|
| 463 |
-
trajectories.append(traj)
|
| 464 |
-
return trajectories
|
| 465 |
-
|
| 466 |
-
def compute_returns_and_advantages(rewards, dones, gamma=0.99, standardize=True):
|
| 467 |
-
"""
|
| 468 |
-
Compute discounted returns and REINFORCE-style baseline advantages.
|
| 469 |
-
Advantages are centered and optionally standardised.
|
| 470 |
-
"""
|
| 471 |
-
n = len(rewards)
|
| 472 |
-
returns = [0.0]*n
|
| 473 |
-
running = 0.0
|
| 474 |
-
for t in reversed(range(n)):
|
| 475 |
-
if dones[t]: running = 0.0
|
| 476 |
-
running = rewards[t] + gamma * running
|
| 477 |
-
returns[t] = running
|
| 478 |
-
if standardize:
|
| 479 |
-
advantages = np.array(returns) - np.mean(returns)
|
| 480 |
-
adv_std = np.std(advantages) + 1e-8
|
| 481 |
-
advantages = (advantages / adv_std).tolist()
|
| 482 |
-
else:
|
| 483 |
-
advantages = returns.copy()
|
| 484 |
-
return advantages, returns
|
| 485 |
-
|
| 486 |
-
def ppo_update(trajectories, model, tokenizer, optimizer, n_epochs=1, clip_epsilon=0.2,
|
| 487 |
-
entropy_coef=0.01, gamma=0.99):
|
| 488 |
-
model.train()
|
| 489 |
-
all_states, all_actions, all_old_logprobs, all_advantages = [], [], [], []
|
| 490 |
-
for traj in trajectories:
|
| 491 |
-
advantages, _ = compute_returns_and_advantages(traj.rewards, traj.dones, gamma=gamma, standardize=True)
|
| 492 |
-
all_states.extend(traj.states)
|
| 493 |
-
all_actions.extend(traj.actions)
|
| 494 |
-
all_old_logprobs.extend(traj.logprobs)
|
| 495 |
-
all_advantages.extend(advantages)
|
| 496 |
-
n_samples = len(all_states)
|
| 497 |
-
total_loss, total_policy_loss, total_entropy, n_updates = 0.0, 0.0, 0.0, 0
|
| 498 |
-
for epoch in range(n_epochs):
|
| 499 |
-
indices = np.random.permutation(n_samples)
|
| 500 |
-
for i in indices:
|
| 501 |
-
state = all_states[i]
|
| 502 |
-
action = all_actions[i]
|
| 503 |
-
old_logprob = all_old_logprobs[i]
|
| 504 |
-
advantage = all_advantages[i]
|
| 505 |
-
messages = [{"role": "user", "content": state}]
|
| 506 |
-
formatted = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
|
| 507 |
-
full_text = formatted + action
|
| 508 |
-
inputs = tokenizer(full_text, return_tensors="pt", max_length=480, truncation=True).to("cuda")
|
| 509 |
-
outputs = model(**inputs)
|
| 510 |
-
logits = outputs.logits
|
| 511 |
-
action_ids = tokenizer.encode(action, add_special_tokens=False)
|
| 512 |
-
prefix_ids = tokenizer.encode(formatted, add_special_tokens=False)
|
| 513 |
-
action_start = len(prefix_ids)
|
| 514 |
-
logprobs = []
|
| 515 |
-
entropy = 0.0
|
| 516 |
-
for idx, token_id in enumerate(action_ids):
|
| 517 |
-
position = action_start + idx - 1
|
| 518 |
-
if 0 <= position < logits.shape[1]:
|
| 519 |
-
token_logits = logits[0, position]
|
| 520 |
-
log_probs = F.log_softmax(token_logits, dim=-1)
|
| 521 |
-
token_logprob = log_probs[token_id]
|
| 522 |
-
logprobs.append(token_logprob)
|
| 523 |
-
probs = F.softmax(token_logits, dim=-1)
|
| 524 |
-
entropy += -(probs * log_probs).sum()
|
| 525 |
-
if not logprobs: continue
|
| 526 |
-
new_logprob = sum(logprobs)
|
| 527 |
-
avg_entropy = entropy / len(logprobs) if logprobs else 0.0
|
| 528 |
-
ratio = torch.exp(new_logprob - old_logprob)
|
| 529 |
-
surr1 = ratio * advantage
|
| 530 |
-
surr2 = torch.clamp(ratio, 1 - clip_epsilon, 1 + clip_epsilon) * advantage
|
| 531 |
-
policy_loss = -torch.min(surr1, surr2)
|
| 532 |
-
loss = policy_loss - entropy_coef * avg_entropy
|
| 533 |
-
optimizer.zero_grad()
|
| 534 |
-
loss.backward()
|
| 535 |
-
torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
|
| 536 |
-
optimizer.step()
|
| 537 |
-
total_loss += loss.item()
|
| 538 |
-
total_policy_loss += policy_loss.item()
|
| 539 |
-
total_entropy += avg_entropy.item()
|
| 540 |
-
n_updates += 1
|
| 541 |
-
torch.cuda.empty_cache()
|
| 542 |
-
return {"loss": total_loss / n_updates if n_updates else 0.0,
|
| 543 |
-
"policy_loss": total_policy_loss / n_updates if n_updates else 0.0,
|
| 544 |
-
"entropy": total_entropy / n_updates if n_updates else 0.0}
|
| 545 |
-
|
| 546 |
-
def evaluate_policy(env, model, tokenizer, n_episodes=3, max_steps=6,
|
| 547 |
-
task_levels=None, verbose=False):
|
| 548 |
-
"""Evaluate the current policy across task levels. Returns metrics + optional traces."""
|
| 549 |
-
model.eval()
|
| 550 |
-
if task_levels is None:
|
| 551 |
-
task_levels = list(BUG_DB.keys())
|
| 552 |
-
total_rewards = []
|
| 553 |
-
traces = [] # human-readable behavior logs
|
| 554 |
-
for ep in range(n_episodes):
|
| 555 |
-
task = task_levels[ep % len(task_levels)]
|
| 556 |
-
env.set_task(task)
|
| 557 |
-
traj = collect_trajectory(env, model, tokenizer, max_steps, temperature=0.0)
|
| 558 |
-
ep_reward = sum(traj.rewards)
|
| 559 |
-
total_rewards.append(ep_reward)
|
| 560 |
-
if verbose:
|
| 561 |
-
actions_taken = []
|
| 562 |
-
for a in traj.actions:
|
| 563 |
-
try:
|
| 564 |
-
actions_taken.append(json.loads(a).get("action_type", "?"))
|
| 565 |
-
except Exception:
|
| 566 |
-
actions_taken.append("?")
|
| 567 |
-
traces.append({
|
| 568 |
-
"task": task,
|
| 569 |
-
"reward": round(ep_reward, 4),
|
| 570 |
-
"steps": len(traj),
|
| 571 |
-
"actions": actions_taken,
|
| 572 |
-
})
|
| 573 |
-
return {
|
| 574 |
-
"avg_reward": float(np.mean(total_rewards)),
|
| 575 |
-
"std_reward": float(np.std(total_rewards)),
|
| 576 |
-
"min_reward": float(np.min(total_rewards)),
|
| 577 |
-
"max_reward": float(np.max(total_rewards)),
|
| 578 |
-
"traces": traces,
|
| 579 |
-
}
|
| 580 |
-
|
| 581 |
-
# ======================================================================
|
| 582 |
-
# MANUAL WARM-UP (no SFTTrainer β no multiprocessing OOM)
|
| 583 |
-
# ======================================================================
|
| 584 |
-
def json_warmup(model, tokenizer, json_path="training_data.json",
|
| 585 |
-
n_episodes=20, epochs=2, lr=2e-5):
|
| 586 |
-
"""
|
| 587 |
-
Supervised warm-up from pre-generated expert demonstrations.
|
| 588 |
-
Uses raw cross-entropy on action tokens with manual gradient steps.
|
| 589 |
-
NO SFTTrainer, NO multiprocessing β runs safely on any GPU.
|
| 590 |
-
"""
|
| 591 |
-
print("\n" + "="*60)
|
| 592 |
-
print("SUPERVISED WARM-UP: training_data.json (manual cross-entropy)")
|
| 593 |
-
print("="*60)
|
| 594 |
-
|
| 595 |
-
with open(json_path, encoding="utf-8") as f:
|
| 596 |
-
data = json.load(f)
|
| 597 |
-
|
| 598 |
-
# Each episode = 7 steps. Select n_episodes worth.
|
| 599 |
-
steps_per_episode = 7
|
| 600 |
-
max_examples = n_episodes * steps_per_episode
|
| 601 |
-
if max_examples < len(data):
|
| 602 |
-
data = data[:max_examples]
|
| 603 |
-
|
| 604 |
-
print(f" {len(data)} examples ({len(data)//steps_per_episode} episodes), "
|
| 605 |
-
f"{epochs} epoch(s), lr={lr}")
|
| 606 |
-
|
| 607 |
-
model.train()
|
| 608 |
-
warmup_opt = AdamW(model.parameters(), lr=lr)
|
| 609 |
-
warmup_losses = [] # per-epoch avg loss
|
| 610 |
-
|
| 611 |
-
for epoch in range(epochs):
|
| 612 |
-
random.shuffle(data)
|
| 613 |
-
epoch_loss = 0.0
|
| 614 |
-
n_valid = 0
|
| 615 |
-
|
| 616 |
-
for i, example in enumerate(data):
|
| 617 |
-
prompt = example["prompt"]
|
| 618 |
-
action = example["action"]
|
| 619 |
-
|
| 620 |
-
# ---- tokenize full sequence (prompt + action) ----
|
| 621 |
-
messages = [
|
| 622 |
-
{"role": "user", "content": prompt},
|
| 623 |
-
{"role": "assistant", "content": action},
|
| 624 |
-
]
|
| 625 |
-
full_text = tokenizer.apply_chat_template(messages, tokenize=False)
|
| 626 |
-
inputs = tokenizer(full_text, return_tensors="pt",
|
| 627 |
-
max_length=480, truncation=True).to("cuda")
|
| 628 |
-
|
| 629 |
-
# ---- find where the action tokens start ----
|
| 630 |
-
prompt_only = tokenizer.apply_chat_template(
|
| 631 |
-
[{"role": "user", "content": prompt}],
|
| 632 |
-
tokenize=False, add_generation_prompt=True
|
| 633 |
-
)
|
| 634 |
-
prompt_ids = tokenizer.encode(prompt_only, add_special_tokens=False)
|
| 635 |
-
prompt_len = len(prompt_ids)
|
| 636 |
-
|
| 637 |
-
total_len = inputs.input_ids.shape[1]
|
| 638 |
-
if prompt_len >= total_len:
|
| 639 |
-
continue # prompt was truncated away, skip
|
| 640 |
-
|
| 641 |
-
# ---- cross-entropy on action tokens only ----
|
| 642 |
-
outputs = model(**inputs)
|
| 643 |
-
logits = outputs.logits
|
| 644 |
-
|
| 645 |
-
# next-token prediction: logits[t] predicts token[t+1]
|
| 646 |
-
shift_logits = logits[0, prompt_len - 1 : total_len - 1]
|
| 647 |
-
shift_labels = inputs.input_ids[0, prompt_len : total_len]
|
| 648 |
-
|
| 649 |
-
min_len = min(shift_logits.shape[0], shift_labels.shape[0])
|
| 650 |
-
if min_len == 0:
|
| 651 |
-
continue
|
| 652 |
-
|
| 653 |
-
loss = F.cross_entropy(shift_logits[:min_len], shift_labels[:min_len])
|
| 654 |
-
|
| 655 |
-
warmup_opt.zero_grad()
|
| 656 |
-
loss.backward()
|
| 657 |
-
torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
|
| 658 |
-
warmup_opt.step()
|
| 659 |
-
|
| 660 |
-
epoch_loss += loss.item()
|
| 661 |
-
n_valid += 1
|
| 662 |
-
|
| 663 |
-
if (i + 1) % 25 == 0:
|
| 664 |
-
avg = epoch_loss / n_valid
|
| 665 |
-
print(f" epoch {epoch+1} step {i+1:3d}/{len(data)} "
|
| 666 |
-
f"running_loss={avg:.4f}")
|
| 667 |
-
|
| 668 |
-
avg_loss = epoch_loss / max(n_valid, 1)
|
| 669 |
-
warmup_losses.append(avg_loss)
|
| 670 |
-
print(f" Epoch {epoch+1} done: avg_loss={avg_loss:.4f} "
|
| 671 |
-
f"({n_valid} valid examples)")
|
| 672 |
-
|
| 673 |
-
torch.cuda.empty_cache()
|
| 674 |
-
print(f"β Warm-up complete. Loss: "
|
| 675 |
-
f"{' β '.join(f'{l:.4f}' for l in warmup_losses)}\n")
|
| 676 |
-
return warmup_losses
|
| 677 |
-
|
| 678 |
-
|
| 679 |
-
# ======================================================================
|
| 680 |
-
# MAIN TRAINING PIPELINE
|
| 681 |
-
# ======================================================================
|
| 682 |
-
def train_ppo():
|
| 683 |
-
# --- Hyperparameters ---
|
| 684 |
-
n_iterations = 8 # enough for a clear upward trend
|
| 685 |
-
trajectories_per_iter = 4 # on-policy data per iteration
|
| 686 |
-
n_epochs = 1
|
| 687 |
-
max_steps = 6
|
| 688 |
-
learning_rate = 3e-5
|
| 689 |
-
clip_epsilon = 0.2
|
| 690 |
-
entropy_coef = 0.01
|
| 691 |
-
gamma = 0.99
|
| 692 |
-
|
| 693 |
-
# --- Pre-load embedder before LLM (Issue #13) ---
|
| 694 |
-
from rltool import ToolBox
|
| 695 |
-
print("Pre-loading sentence-transformer embedder...")
|
| 696 |
-
ToolBox._get_embedder()
|
| 697 |
-
print("β Embedder ready")
|
| 698 |
-
|
| 699 |
-
# --- Load model ---
|
| 700 |
-
print("Loading model...")
|
| 701 |
-
model, tokenizer = load_model()
|
| 702 |
-
if not test_model_sanity(model, tokenizer):
|
| 703 |
-
return
|
| 704 |
-
env = CodeReviewEnv()
|
| 705 |
-
task_levels = list(BUG_DB.keys())
|
| 706 |
-
|
| 707 |
-
# ==================================================================
|
| 708 |
-
# PHASE 0: BASELINE (untrained policy)
|
| 709 |
-
# ==================================================================
|
| 710 |
-
print("\n" + "="*60)
|
| 711 |
-
print("PHASE 0 β BASELINE EVALUATION (untrained)")
|
| 712 |
-
print("="*60)
|
| 713 |
-
baseline = evaluate_policy(env, model, tokenizer, n_episodes=5,
|
| 714 |
-
max_steps=max_steps, task_levels=task_levels,
|
| 715 |
-
verbose=True)
|
| 716 |
-
baseline_reward = baseline["avg_reward"]
|
| 717 |
-
print(f"Baseline avg reward: {baseline_reward:.4f} "
|
| 718 |
-
f"(min={baseline['min_reward']:.4f}, max={baseline['max_reward']:.4f})")
|
| 719 |
-
print("Baseline behavior:")
|
| 720 |
-
for t in baseline["traces"]:
|
| 721 |
-
print(f" task={t['task']:8s} reward={t['reward']:+.4f} "
|
| 722 |
-
f"steps={t['steps']} actions={t['actions']}")
|
| 723 |
-
|
| 724 |
-
# ==================================================================
|
| 725 |
-
# PHASE 1: SUPERVISED WARM-UP (expert demos, manual CE)
|
| 726 |
-
# ==================================================================
|
| 727 |
-
warmup_losses = json_warmup(
|
| 728 |
-
model, tokenizer,
|
| 729 |
-
json_path="training_data.json",
|
| 730 |
-
n_episodes=20, # 140 examples (20 Γ 7 steps)
|
| 731 |
-
epochs=2,
|
| 732 |
-
lr=2e-5,
|
| 733 |
-
)
|
| 734 |
-
|
| 735 |
-
# Post-warmup evaluation
|
| 736 |
-
print("="*60)
|
| 737 |
-
print("POST WARM-UP EVALUATION")
|
| 738 |
-
print("="*60)
|
| 739 |
-
post_warmup = evaluate_policy(env, model, tokenizer, n_episodes=5,
|
| 740 |
-
max_steps=max_steps, task_levels=task_levels,
|
| 741 |
-
verbose=True)
|
| 742 |
-
warmup_reward = post_warmup["avg_reward"]
|
| 743 |
-
print(f"Post-warmup avg reward: {warmup_reward:.4f} "
|
| 744 |
-
f"(Ξ vs baseline: {warmup_reward - baseline_reward:+.4f})")
|
| 745 |
-
print("Post-warmup behavior:")
|
| 746 |
-
for t in post_warmup["traces"]:
|
| 747 |
-
print(f" task={t['task']:8s} reward={t['reward']:+.4f} "
|
| 748 |
-
f"steps={t['steps']} actions={t['actions']}")
|
| 749 |
-
|
| 750 |
-
# ==================================================================
|
| 751 |
-
# PHASE 2: TRUE RL β PPO (on-policy, real environment interaction)
|
| 752 |
-
# ==================================================================
|
| 753 |
-
optimizer = AdamW(model.parameters(), lr=learning_rate)
|
| 754 |
-
print(f"\n{'='*60}")
|
| 755 |
-
print(f"PHASE 2 β PPO TRAINING: {n_iterations} iterations Γ "
|
| 756 |
-
f"{trajectories_per_iter} trajectories (true RL)")
|
| 757 |
-
print(f"{'='*60}\n")
|
| 758 |
-
|
| 759 |
-
reward_history = []
|
| 760 |
-
eval_history = []
|
| 761 |
-
loss_history = []
|
| 762 |
-
policy_loss_history = []
|
| 763 |
-
entropy_history = []
|
| 764 |
-
|
| 765 |
-
for iteration in range(n_iterations):
|
| 766 |
-
print(f"\n--- PPO Iteration {iteration + 1}/{n_iterations} ---")
|
| 767 |
-
|
| 768 |
-
# Collect on-policy trajectories from REAL environment
|
| 769 |
-
trajectories = collect_trajectories(
|
| 770 |
-
env, model, tokenizer, trajectories_per_iter, max_steps,
|
| 771 |
-
task_levels=task_levels, task_weights=None
|
| 772 |
-
)
|
| 773 |
-
avg_reward = float(np.mean([sum(t.rewards) for t in trajectories]))
|
| 774 |
-
reward_history.append(avg_reward)
|
| 775 |
-
print(f" Collect avg reward: {avg_reward:+.4f}")
|
| 776 |
-
|
| 777 |
-
# PPO policy gradient update
|
| 778 |
-
metrics = ppo_update(
|
| 779 |
-
trajectories, model, tokenizer, optimizer,
|
| 780 |
-
n_epochs=n_epochs, clip_epsilon=clip_epsilon,
|
| 781 |
-
entropy_coef=entropy_coef, gamma=gamma
|
| 782 |
-
)
|
| 783 |
-
loss_history.append(float(metrics["loss"]))
|
| 784 |
-
policy_loss_history.append(float(metrics["policy_loss"]))
|
| 785 |
-
entropy_history.append(float(metrics["entropy"]))
|
| 786 |
-
print(f" Update loss={metrics['loss']:.4f} "
|
| 787 |
-
f"policy={metrics['policy_loss']:.4f} "
|
| 788 |
-
f"entropy={metrics['entropy']:.4f}")
|
| 789 |
-
|
| 790 |
-
# Evaluate greedy policy after update
|
| 791 |
-
eval_m = evaluate_policy(env, model, tokenizer, n_episodes=3,
|
| 792 |
-
max_steps=max_steps, task_levels=task_levels,
|
| 793 |
-
verbose=False)
|
| 794 |
-
eval_history.append(eval_m["avg_reward"])
|
| 795 |
-
delta = eval_m["avg_reward"] - baseline_reward
|
| 796 |
-
print(f" Eval avg reward: {eval_m['avg_reward']:+.4f} "
|
| 797 |
-
f"(Ξ baseline: {delta:+.4f})")
|
| 798 |
-
|
| 799 |
-
# ==================================================================
|
| 800 |
-
# PHASE 3: FINAL EVALUATION (proof of learning)
|
| 801 |
-
# ==================================================================
|
| 802 |
-
print("\n" + "="*60)
|
| 803 |
-
print("PHASE 3 β FINAL EVALUATION (after all training)")
|
| 804 |
-
print("="*60)
|
| 805 |
-
final = evaluate_policy(env, model, tokenizer, n_episodes=5,
|
| 806 |
-
max_steps=max_steps, task_levels=task_levels,
|
| 807 |
-
verbose=True)
|
| 808 |
-
print(f"Final avg reward: {final['avg_reward']:.4f} "
|
| 809 |
-
f"(min={final['min_reward']:.4f}, max={final['max_reward']:.4f})")
|
| 810 |
-
print("Final behavior:")
|
| 811 |
-
for t in final["traces"]:
|
| 812 |
-
print(f" task={t['task']:8s} reward={t['reward']:+.4f} "
|
| 813 |
-
f"steps={t['steps']} actions={t['actions']}")
|
| 814 |
-
|
| 815 |
-
total_improvement = final["avg_reward"] - baseline_reward
|
| 816 |
-
ppo_improvement = final["avg_reward"] - warmup_reward
|
| 817 |
-
print(f"\n{'='*60}")
|
| 818 |
-
print("TRAINING SUMMARY")
|
| 819 |
-
print(f" Baseline reward: {baseline_reward:+.4f}")
|
| 820 |
-
print(f" Post-warmup reward: {warmup_reward:+.4f} "
|
| 821 |
-
f"(warmup Ξ: {warmup_reward - baseline_reward:+.4f})")
|
| 822 |
-
print(f" Final reward: {final['avg_reward']:+.4f} "
|
| 823 |
-
f"(PPO Ξ: {ppo_improvement:+.4f})")
|
| 824 |
-
print(f" Total improvement: {total_improvement:+.4f}")
|
| 825 |
-
print(f" Reward trend (PPO): {' β '.join(f'{r:+.3f}' for r in reward_history)}")
|
| 826 |
-
print(f" Loss trend (PPO): {' β '.join(f'{l:.4f}' for l in loss_history)}")
|
| 827 |
-
if total_improvement > 0:
|
| 828 |
-
print(f" β Agent IMPROVED by {total_improvement:+.4f}")
|
| 829 |
-
else:
|
| 830 |
-
print(f" β No overall improvement detected")
|
| 831 |
-
print(f"{'='*60}")
|
| 832 |
-
|
| 833 |
-
# ==================================================================
|
| 834 |
-
# PLOTS
|
| 835 |
-
# ==================================================================
|
| 836 |
-
iters = list(range(1, n_iterations + 1))
|
| 837 |
-
|
| 838 |
-
# --- 1. Warm-up loss curve ---
|
| 839 |
-
if warmup_losses:
|
| 840 |
-
fig, ax = plt.subplots(figsize=(7, 4))
|
| 841 |
-
ax.plot(range(1, len(warmup_losses) + 1), warmup_losses,
|
| 842 |
-
marker="o", linewidth=2, color="tab:purple")
|
| 843 |
-
ax.set_title("Warm-up Loss (supervised, per epoch)",
|
| 844 |
-
fontsize=13, fontweight="bold")
|
| 845 |
-
ax.set_xlabel("Epoch")
|
| 846 |
-
ax.set_ylabel("Cross-Entropy Loss")
|
| 847 |
-
ax.grid(alpha=0.3)
|
| 848 |
-
fig.tight_layout()
|
| 849 |
-
fig.savefig("warmup_loss.png", dpi=150)
|
| 850 |
-
plt.close(fig)
|
| 851 |
-
|
| 852 |
-
# --- 2. PPO reward curve ---
|
| 853 |
-
fig, ax = plt.subplots(figsize=(9, 5))
|
| 854 |
-
ax.plot(iters, reward_history, marker="o", linewidth=2,
|
| 855 |
-
label="Collect reward", color="tab:blue")
|
| 856 |
-
ax.plot(iters, eval_history, marker="s", linewidth=2, linestyle="--",
|
| 857 |
-
label="Eval reward", color="tab:green")
|
| 858 |
-
ax.axhline(y=baseline_reward, color="tab:gray", linestyle=":",
|
| 859 |
-
linewidth=1.5, label=f"Baseline ({baseline_reward:+.3f})")
|
| 860 |
-
ax.axhline(y=warmup_reward, color="tab:purple", linestyle=":",
|
| 861 |
-
linewidth=1.5, label=f"Post-warmup ({warmup_reward:+.3f})")
|
| 862 |
-
ax.set_title("PPO Reward per Iteration", fontsize=14, fontweight="bold")
|
| 863 |
-
ax.set_xlabel("Iteration")
|
| 864 |
-
ax.set_ylabel("Average Reward")
|
| 865 |
-
ax.legend(loc="best", fontsize=8)
|
| 866 |
-
ax.grid(alpha=0.3)
|
| 867 |
-
fig.tight_layout()
|
| 868 |
-
fig.savefig("reward_curve.png", dpi=150)
|
| 869 |
-
plt.close(fig)
|
| 870 |
-
|
| 871 |
-
# --- 3. PPO loss curve ---
|
| 872 |
-
fig, ax = plt.subplots(figsize=(9, 5))
|
| 873 |
-
ax.plot(iters, loss_history, marker="o", linewidth=2,
|
| 874 |
-
label="Total loss", color="tab:red")
|
| 875 |
-
ax.plot(iters, policy_loss_history, marker="^", linewidth=2, linestyle="--",
|
| 876 |
-
label="Policy loss", color="tab:orange")
|
| 877 |
-
ax.set_title("PPO Loss per Iteration", fontsize=14, fontweight="bold")
|
| 878 |
-
ax.set_xlabel("Iteration")
|
| 879 |
-
ax.set_ylabel("Loss")
|
| 880 |
-
ax.legend(loc="best")
|
| 881 |
-
ax.grid(alpha=0.3)
|
| 882 |
-
fig.tight_layout()
|
| 883 |
-
fig.savefig("loss_curve.png", dpi=150)
|
| 884 |
-
plt.close(fig)
|
| 885 |
-
|
| 886 |
-
# --- 4. Combined 3-panel summary ---
|
| 887 |
-
fig, axes = plt.subplots(1, 3, figsize=(18, 5))
|
| 888 |
-
|
| 889 |
-
# Panel A: warm-up loss
|
| 890 |
-
if warmup_losses:
|
| 891 |
-
axes[0].plot(range(1, len(warmup_losses) + 1), warmup_losses,
|
| 892 |
-
marker="o", linewidth=2, color="tab:purple")
|
| 893 |
-
axes[0].set_title("A. Warm-up Loss β")
|
| 894 |
-
axes[0].set_xlabel("Epoch")
|
| 895 |
-
axes[0].set_ylabel("CE Loss")
|
| 896 |
-
axes[0].grid(alpha=0.3)
|
| 897 |
-
|
| 898 |
-
# Panel B: PPO reward
|
| 899 |
-
axes[1].plot(iters, reward_history, marker="o", linewidth=2,
|
| 900 |
-
color="tab:blue", label="Collect")
|
| 901 |
-
axes[1].plot(iters, eval_history, marker="s", linewidth=2,
|
| 902 |
-
linestyle="--", color="tab:green", label="Eval")
|
| 903 |
-
axes[1].axhline(y=baseline_reward, color="tab:gray", linestyle=":",
|
| 904 |
-
linewidth=1.5, label="Baseline")
|
| 905 |
-
axes[1].axhline(y=warmup_reward, color="tab:purple", linestyle=":",
|
| 906 |
-
linewidth=1.5, label="Post-warmup")
|
| 907 |
-
axes[1].set_title("B. PPO Reward β")
|
| 908 |
-
axes[1].set_xlabel("Iteration")
|
| 909 |
-
axes[1].set_ylabel("Avg Reward")
|
| 910 |
-
axes[1].legend(fontsize=7)
|
| 911 |
-
axes[1].grid(alpha=0.3)
|
| 912 |
-
|
| 913 |
-
# Panel C: PPO loss
|
| 914 |
-
axes[2].plot(iters, loss_history, marker="o", linewidth=2,
|
| 915 |
-
color="tab:red", label="Total")
|
| 916 |
-
axes[2].plot(iters, policy_loss_history, marker="^", linewidth=2,
|
| 917 |
-
linestyle="--", color="tab:orange", label="Policy")
|
| 918 |
-
axes[2].set_title("C. PPO Loss β")
|
| 919 |
-
axes[2].set_xlabel("Iteration")
|
| 920 |
-
axes[2].set_ylabel("Loss")
|
| 921 |
-
axes[2].legend(fontsize=7)
|
| 922 |
-
axes[2].grid(alpha=0.3)
|
| 923 |
-
|
| 924 |
-
fig.suptitle("Code Review Agent β Full Training Evidence",
|
| 925 |
-
fontsize=14, fontweight="bold")
|
| 926 |
-
fig.tight_layout()
|
| 927 |
-
fig.savefig("training_summary.png", dpi=150)
|
| 928 |
-
plt.close(fig)
|
| 929 |
-
|
| 930 |
-
print("Plots saved: warmup_loss.png, reward_curve.png, "
|
| 931 |
-
"loss_curve.png, training_summary.png")
|
| 932 |
-
print("="*60)
|
| 933 |
-
|
| 934 |
-
if __name__ == "__main__":
|
| 935 |
train_ppo()
|
|
|
|
| 1 |
+
# training.py β Memoryβsafe: Phiβ3βmini + Expert Demos + Fast PPO (2 iterations)
|
| 2 |
+
import os
|
| 3 |
+
os.environ["TRITON_DISABLE"] = "1"
|
| 4 |
+
os.environ["TOKENIZERS_PARALLELISM"] = "false" # Issue #12: prevent OOM from parallel tokenization
|
| 5 |
+
|
| 6 |
+
import torch._dynamo
|
| 7 |
+
torch._dynamo.config.disable = True
|
| 8 |
+
import json
|
| 9 |
+
import torch
|
| 10 |
+
import torch.nn.functional as F
|
| 11 |
+
from torch.optim import AdamW
|
| 12 |
+
from dataclasses import dataclass
|
| 13 |
+
from typing import List, Dict, Tuple, Optional
|
| 14 |
+
import numpy as np
|
| 15 |
+
import re
|
| 16 |
+
import random
|
| 17 |
+
import matplotlib.pyplot as plt
|
| 18 |
+
|
| 19 |
+
from unsloth import FastLanguageModel
|
| 20 |
+
from transformers import TrainingArguments
|
| 21 |
+
from trl import SFTTrainer
|
| 22 |
+
from datasets import Dataset
|
| 23 |
+
|
| 24 |
+
from environment import CodeReviewEnv
|
| 25 |
+
from redteam import BUG_DB
|
| 26 |
+
from models import (
|
| 27 |
+
RunTests, RunLinter, Inspect,
|
| 28 |
+
ProposeFix, WriteComment, AskQuestion,
|
| 29 |
+
Done, Skip, QueryDocs, map_to_env as model_map_to_env
|
| 30 |
+
)
|
| 31 |
+
|
| 32 |
+
# ======================================================================
|
| 33 |
+
@dataclass
|
| 34 |
+
class AgentAction:
|
| 35 |
+
action_type: str
|
| 36 |
+
content: Optional[str] = None
|
| 37 |
+
|
| 38 |
+
def parse_action(output: str) -> AgentAction:
|
| 39 |
+
try:
|
| 40 |
+
data = json.loads(output)
|
| 41 |
+
return AgentAction(
|
| 42 |
+
action_type=data.get("action_type", "").lower(),
|
| 43 |
+
content=data.get("content")
|
| 44 |
+
)
|
| 45 |
+
except:
|
| 46 |
+
pass
|
| 47 |
+
json_match = re.search(r'```(?:json)?\s*(\{.*?\})\s*```', output, re.DOTALL)
|
| 48 |
+
if json_match:
|
| 49 |
+
try:
|
| 50 |
+
data = json.loads(json_match.group(1))
|
| 51 |
+
return AgentAction(
|
| 52 |
+
action_type=data.get("action_type", "").lower(),
|
| 53 |
+
content=data.get("content")
|
| 54 |
+
)
|
| 55 |
+
except:
|
| 56 |
+
pass
|
| 57 |
+
action_pattern = r'"action_type"\s*:\s*"(\w+)"'
|
| 58 |
+
match = re.search(action_pattern, output)
|
| 59 |
+
if match:
|
| 60 |
+
return AgentAction(action_type=match.group(1).lower())
|
| 61 |
+
output_lower = output.lower()
|
| 62 |
+
if "test" in output_lower:
|
| 63 |
+
return AgentAction("run_tests")
|
| 64 |
+
if "lint" in output_lower:
|
| 65 |
+
return AgentAction("run_linter")
|
| 66 |
+
if "inspect" in output_lower:
|
| 67 |
+
return AgentAction("inspect")
|
| 68 |
+
if "doc" in output_lower or "documentation" in output_lower:
|
| 69 |
+
return AgentAction("query_docs", "bug fix guidance")
|
| 70 |
+
return AgentAction("invalid", output)
|
| 71 |
+
|
| 72 |
+
def map_to_env(action: AgentAction):
|
| 73 |
+
return model_map_to_env(action.action_type, action.content)
|
| 74 |
+
|
| 75 |
+
# ======================================================================
|
| 76 |
+
def load_model():
|
| 77 |
+
model, tokenizer = FastLanguageModel.from_pretrained(
|
| 78 |
+
model_name="unsloth/Phi-3-mini-4k-instruct-bnb-4bit",
|
| 79 |
+
max_seq_length=480, # smaller window for memory
|
| 80 |
+
load_in_4bit=True,
|
| 81 |
+
)
|
| 82 |
+
model = FastLanguageModel.get_peft_model(
|
| 83 |
+
model,
|
| 84 |
+
r=16,
|
| 85 |
+
target_modules=[
|
| 86 |
+
"q_proj", "k_proj", "v_proj", "o_proj",
|
| 87 |
+
"gate_proj", "up_proj", "down_proj"
|
| 88 |
+
],
|
| 89 |
+
lora_alpha=32,
|
| 90 |
+
lora_dropout=0.0,
|
| 91 |
+
)
|
| 92 |
+
return model, tokenizer
|
| 93 |
+
|
| 94 |
+
def test_model_sanity(model, tokenizer) -> bool:
|
| 95 |
+
print("\n" + "="*60)
|
| 96 |
+
print("SANITY CHECK: Testing base model generation")
|
| 97 |
+
print("="*60)
|
| 98 |
+
test_prompt = "Hello, how are you?"
|
| 99 |
+
messages = [{"role": "user", "content": test_prompt}]
|
| 100 |
+
formatted = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
|
| 101 |
+
inputs = tokenizer(formatted, return_tensors="pt", max_length=256, truncation=True).to("cuda")
|
| 102 |
+
with torch.no_grad():
|
| 103 |
+
outputs = model.generate(
|
| 104 |
+
**inputs,
|
| 105 |
+
max_new_tokens=30,
|
| 106 |
+
do_sample=True,
|
| 107 |
+
temperature=0.7,
|
| 108 |
+
min_new_tokens=1,
|
| 109 |
+
eos_token_id=tokenizer.eos_token_id,
|
| 110 |
+
pad_token_id=tokenizer.pad_token_id or tokenizer.eos_token_id,
|
| 111 |
+
)
|
| 112 |
+
generated_ids = outputs[0][inputs['input_ids'].shape[1]:]
|
| 113 |
+
response = tokenizer.decode(generated_ids, skip_special_tokens=True).strip()
|
| 114 |
+
print(f"Prompt: {test_prompt}")
|
| 115 |
+
print(f"Response: {repr(response)}")
|
| 116 |
+
if len(response) == 0:
|
| 117 |
+
print("β Model produces empty output β cannot train.")
|
| 118 |
+
return False
|
| 119 |
+
print("β Model sanity check PASSED\n")
|
| 120 |
+
return True
|
| 121 |
+
|
| 122 |
+
# ======================================================================
|
| 123 |
+
def _expert_fix_from_context(obs) -> str:
|
| 124 |
+
"""
|
| 125 |
+
Build a conservative fix template named `fix` (required by tests).
|
| 126 |
+
Uses bug hints + code snippet patterns to create realistic fixes.
|
| 127 |
+
"""
|
| 128 |
+
bug = (getattr(obs, "bug_description", "") or "").lower()
|
| 129 |
+
code = getattr(obs, "code_snippet", "") or ""
|
| 130 |
+
|
| 131 |
+
if "division" in bug or "average" in code.lower():
|
| 132 |
+
return (
|
| 133 |
+
"def fix(data):\n"
|
| 134 |
+
" if not data:\n"
|
| 135 |
+
" return 0\n"
|
| 136 |
+
" return sum(data) / len(data)"
|
| 137 |
+
)
|
| 138 |
+
|
| 139 |
+
if "operator" in bug or "sign" in bug:
|
| 140 |
+
return (
|
| 141 |
+
"def fix(a, b):\n"
|
| 142 |
+
" return a + b"
|
| 143 |
+
)
|
| 144 |
+
|
| 145 |
+
if "off_by_one" in bug or "loop" in bug:
|
| 146 |
+
return (
|
| 147 |
+
"def fix(items):\n"
|
| 148 |
+
" return len(items)"
|
| 149 |
+
)
|
| 150 |
+
|
| 151 |
+
if "null" in bug or "key" in bug or "dict" in code.lower():
|
| 152 |
+
return (
|
| 153 |
+
"def fix(payload):\n"
|
| 154 |
+
" users = payload.get('users', {})\n"
|
| 155 |
+
" user_id = payload.get('id')\n"
|
| 156 |
+
" return users.get(user_id)"
|
| 157 |
+
)
|
| 158 |
+
|
| 159 |
+
# Concurrency-heavy tasks (harder/hardest).
|
| 160 |
+
if "race" in bug or "missing_lock" in bug or "thread_safe" in bug or "global_nonatomic" in bug:
|
| 161 |
+
return (
|
| 162 |
+
"import threading\n"
|
| 163 |
+
"_lock = threading.Lock()\n"
|
| 164 |
+
"\n"
|
| 165 |
+
"def fix(counter):\n"
|
| 166 |
+
" with _lock:\n"
|
| 167 |
+
" if counter is None:\n"
|
| 168 |
+
" return 0\n"
|
| 169 |
+
" return counter + 1"
|
| 170 |
+
)
|
| 171 |
+
|
| 172 |
+
if "deadlock" in bug or "double_lock" in bug or "lock order" in bug or "nested_lock" in bug:
|
| 173 |
+
return (
|
| 174 |
+
"import threading\n"
|
| 175 |
+
"_lock_a = threading.Lock()\n"
|
| 176 |
+
"_lock_b = threading.Lock()\n"
|
| 177 |
+
"\n"
|
| 178 |
+
"def fix(work):\n"
|
| 179 |
+
" first, second = (_lock_a, _lock_b)\n"
|
| 180 |
+
" if id(first) > id(second):\n"
|
| 181 |
+
" first, second = second, first\n"
|
| 182 |
+
" with first:\n"
|
| 183 |
+
" with second:\n"
|
| 184 |
+
" return work() if callable(work) else work"
|
| 185 |
+
)
|
| 186 |
+
|
| 187 |
+
if "fork_join" in bug or "join" in bug:
|
| 188 |
+
return (
|
| 189 |
+
"import threading\n"
|
| 190 |
+
"\n"
|
| 191 |
+
"def fix(worker):\n"
|
| 192 |
+
" t = threading.Thread(target=worker)\n"
|
| 193 |
+
" t.start()\n"
|
| 194 |
+
" t.join()\n"
|
| 195 |
+
" return True"
|
| 196 |
+
)
|
| 197 |
+
|
| 198 |
+
# Generic safe fallback keeps the RL pipeline alive for unknown bugs.
|
| 199 |
+
return (
|
| 200 |
+
"def fix(data):\n"
|
| 201 |
+
" if data is None:\n"
|
| 202 |
+
" return None\n"
|
| 203 |
+
" return data"
|
| 204 |
+
)
|
| 205 |
+
|
| 206 |
+
|
| 207 |
+
def _expert_supervised_policy(obs) -> str:
|
| 208 |
+
"""
|
| 209 |
+
Real workflow policy:
|
| 210 |
+
inspect -> tests/linter -> docs -> fix -> negotiate -> done.
|
| 211 |
+
"""
|
| 212 |
+
author_msg = (getattr(obs, "author_response", "") or "").lower()
|
| 213 |
+
tool_output = (getattr(obs, "last_tool_output", "") or "").lower()
|
| 214 |
+
|
| 215 |
+
if not getattr(obs, "tests_run", False):
|
| 216 |
+
if "inspect" not in tool_output:
|
| 217 |
+
return '{"action_type": "inspect"}'
|
| 218 |
+
return '{"action_type": "run_tests"}'
|
| 219 |
+
|
| 220 |
+
if not getattr(obs, "linter_run", False):
|
| 221 |
+
return '{"action_type": "run_linter"}'
|
| 222 |
+
|
| 223 |
+
if not getattr(obs, "docs_queried", False):
|
| 224 |
+
return '{"action_type": "query_docs", "content": "python bug fixing best practices for edge cases and null safety"}'
|
| 225 |
+
|
| 226 |
+
# Use docs again on hard tasks when evidence is still weak.
|
| 227 |
+
if getattr(obs, "current_test_score", 0.0) < 0.6 and getattr(obs, "step", 0) >= 3:
|
| 228 |
+
bug_hint = (getattr(obs, "bug_description", "") or "concurrency bug").replace('"', "'")
|
| 229 |
+
return json.dumps(
|
| 230 |
+
{
|
| 231 |
+
"action_type": "query_docs",
|
| 232 |
+
"content": f"python {bug_hint} lock ordering race condition mitigation patterns",
|
| 233 |
+
}
|
| 234 |
+
)
|
| 235 |
+
|
| 236 |
+
# If test quality is poor, propose a concrete fix.
|
| 237 |
+
if getattr(obs, "current_test_score", 0.0) < 0.95:
|
| 238 |
+
fix_code = _expert_fix_from_context(obs)
|
| 239 |
+
return json.dumps({"action_type": "fix", "content": fix_code})
|
| 240 |
+
|
| 241 |
+
# If author is still unconvinced, provide causal explanation.
|
| 242 |
+
if author_msg and ("not convinced" in author_msg or "explain" in author_msg or "brief" in author_msg):
|
| 243 |
+
return (
|
| 244 |
+
'{"action_type": "comment", "content": "This fix works because it handles the failing edge case directly, '
|
| 245 |
+
'keeps behavior deterministic, and aligns with the observed test and lint feedback. '
|
| 246 |
+
'The change is intentionally small to reduce regression risk."}'
|
| 247 |
+
)
|
| 248 |
+
|
| 249 |
+
# If negotiation is strong enough and quality is good, terminate.
|
| 250 |
+
conf = float(getattr(obs, "author_confidence", 0.0))
|
| 251 |
+
threshold = float(getattr(obs, "author_threshold", 0.5))
|
| 252 |
+
score = float(getattr(obs, "current_test_score", 0.0))
|
| 253 |
+
if conf >= threshold and score >= 0.8:
|
| 254 |
+
return '{"action_type": "done"}'
|
| 255 |
+
|
| 256 |
+
# Nudge conversation forward when tests are okay but acceptance is pending.
|
| 257 |
+
return (
|
| 258 |
+
'{"action_type": "question", "content": "Would you like a quick walkthrough of a failing scenario, the root cause, and how the fix prevents regressions?"}'
|
| 259 |
+
)
|
| 260 |
+
|
| 261 |
+
# ======================================================================
|
| 262 |
+
def supervised_warmup(model, tokenizer, env, n_episodes=16, epochs=1, max_steps=8):
|
| 263 |
+
print("\n" + "="*60)
|
| 264 |
+
print("SUPERVISED WARM-UP: Real environment demonstrations")
|
| 265 |
+
print("="*60)
|
| 266 |
+
|
| 267 |
+
examples = []
|
| 268 |
+
tasks = ["easy", "medium", "hard", "harder", "hardest"]
|
| 269 |
+
for ep in range(n_episodes):
|
| 270 |
+
task = random.choice(tasks)
|
| 271 |
+
env.set_task(task)
|
| 272 |
+
obs = env.reset()
|
| 273 |
+
history = []
|
| 274 |
+
done = False
|
| 275 |
+
|
| 276 |
+
steps = 0
|
| 277 |
+
while not done and steps < max_steps:
|
| 278 |
+
prompt = build_prompt(obs, history)
|
| 279 |
+
action_text = _expert_supervised_policy(obs)
|
| 280 |
+
action = parse_action(action_text)
|
| 281 |
+
env_action = map_to_env(action)
|
| 282 |
+
next_obs, _, done, _ = env.step(env_action)
|
| 283 |
+
|
| 284 |
+
messages = [
|
| 285 |
+
{"role": "user", "content": prompt},
|
| 286 |
+
{"role": "assistant", "content": action_text},
|
| 287 |
+
]
|
| 288 |
+
full_text = tokenizer.apply_chat_template(messages, tokenize=False)
|
| 289 |
+
examples.append({"text": full_text})
|
| 290 |
+
|
| 291 |
+
history.append(f"Agent: {action_text}")
|
| 292 |
+
history.append(f"Env: {next_obs.last_tool_output}")
|
| 293 |
+
history = history[-8:]
|
| 294 |
+
obs = next_obs
|
| 295 |
+
steps += 1
|
| 296 |
+
|
| 297 |
+
print(f"Supervised episode {ep+1}: task={task}, steps={steps}, done={done}")
|
| 298 |
+
|
| 299 |
+
if not examples:
|
| 300 |
+
print("No supervised examples generated; skipping warm-up.")
|
| 301 |
+
return
|
| 302 |
+
|
| 303 |
+
dataset = Dataset.from_list(examples)
|
| 304 |
+
trainer = SFTTrainer(
|
| 305 |
+
model=model,
|
| 306 |
+
tokenizer=tokenizer,
|
| 307 |
+
train_dataset=dataset,
|
| 308 |
+
dataset_text_field="text",
|
| 309 |
+
max_seq_length=480,
|
| 310 |
+
args=TrainingArguments(
|
| 311 |
+
output_dir="warmup_output",
|
| 312 |
+
num_train_epochs=epochs,
|
| 313 |
+
per_device_train_batch_size=2,
|
| 314 |
+
gradient_accumulation_steps=2,
|
| 315 |
+
learning_rate=2e-5,
|
| 316 |
+
logging_steps=50,
|
| 317 |
+
save_strategy="no",
|
| 318 |
+
bf16=True,
|
| 319 |
+
),
|
| 320 |
+
)
|
| 321 |
+
print(f"Training on {len(examples)} real env examples for {epochs} epochs...")
|
| 322 |
+
trainer.train()
|
| 323 |
+
print("β Supervised warm-up (real env) complete\n")
|
| 324 |
+
torch.cuda.empty_cache()
|
| 325 |
+
|
| 326 |
+
# ======================================================================
|
| 327 |
+
def generate_action_with_logprob(prompt, model, tokenizer, temperature=0.0, max_retries=2):
|
| 328 |
+
messages = [{"role": "user", "content": prompt}]
|
| 329 |
+
formatted = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
|
| 330 |
+
inputs = tokenizer(formatted, return_tensors="pt", max_length=480, truncation=True).to("cuda")
|
| 331 |
+
|
| 332 |
+
for attempt in range(max_retries):
|
| 333 |
+
with torch.no_grad():
|
| 334 |
+
outputs = model.generate(
|
| 335 |
+
**inputs,
|
| 336 |
+
max_new_tokens=64,
|
| 337 |
+
do_sample=(temperature > 0),
|
| 338 |
+
temperature=max(temperature, 0.01) if temperature > 0 else 1.0,
|
| 339 |
+
min_new_tokens=1,
|
| 340 |
+
return_dict_in_generate=True,
|
| 341 |
+
output_scores=True,
|
| 342 |
+
)
|
| 343 |
+
generated_ids = outputs.sequences[0][inputs['input_ids'].shape[1]:]
|
| 344 |
+
action_text = tokenizer.decode(generated_ids, skip_special_tokens=True).strip()
|
| 345 |
+
|
| 346 |
+
logprobs = []
|
| 347 |
+
for idx, token_id in enumerate(generated_ids):
|
| 348 |
+
if idx < len(outputs.scores):
|
| 349 |
+
token_logits = outputs.scores[idx][0]
|
| 350 |
+
token_logprob = F.log_softmax(token_logits, dim=-1)[token_id].item()
|
| 351 |
+
logprobs.append(token_logprob)
|
| 352 |
+
total_logprob = sum(logprobs) if logprobs else -100.0
|
| 353 |
+
|
| 354 |
+
if not action_text:
|
| 355 |
+
fallback_actions = [
|
| 356 |
+
'{"action_type": "run_tests"}',
|
| 357 |
+
'{"action_type": "run_linter"}',
|
| 358 |
+
'{"action_type": "inspect"}',
|
| 359 |
+
'{"action_type": "skip"}',
|
| 360 |
+
]
|
| 361 |
+
action_text = random.choice(fallback_actions)
|
| 362 |
+
total_logprob = -50.0
|
| 363 |
+
print(f"[WARN] Empty generation β using fallback: {action_text}")
|
| 364 |
+
return action_text, total_logprob
|
| 365 |
+
|
| 366 |
+
try:
|
| 367 |
+
json.loads(action_text)
|
| 368 |
+
return action_text, total_logprob
|
| 369 |
+
except:
|
| 370 |
+
if attempt == max_retries - 1:
|
| 371 |
+
return '{"action_type":"skip"}', -100.0
|
| 372 |
+
continue
|
| 373 |
+
return '{"action_type":"skip"}', -100.0
|
| 374 |
+
|
| 375 |
+
# ======================================================================
|
| 376 |
+
def build_prompt(obs, history_lines: List[str]) -> str:
|
| 377 |
+
author_msg = getattr(obs, "author_response", "") or ""
|
| 378 |
+
tool_output = getattr(obs, "last_tool_output", "") or ""
|
| 379 |
+
author_personality = getattr(obs, "author_personality", "defensive")
|
| 380 |
+
|
| 381 |
+
prompt = f"""You are an AI code review agent. Your goal is to convince a simulated human developer to accept your proposed fix and name your proposed fix function fix.
|
| 382 |
+
|
| 383 |
+
The developer has a **{author_personality}** personality and will only accept if you provide solid evidence:
|
| 384 |
+
- Tests pass (high pass ratio)
|
| 385 |
+
- Lint is clean (zero errors)
|
| 386 |
+
- Documentation or references are provided
|
| 387 |
+
- Your reasoning is clear, uses words like "because" or "therefore", and is detailed (over 30 words if needed)
|
| 388 |
+
|
| 389 |
+
Workflow:
|
| 390 |
+
1. Use `inspect` to understand the code.
|
| 391 |
+
2. Use `run_tests` and `run_linter` to gather evidence.
|
| 392 |
+
3. Use `query_docs` when you need references or language-specific guidance.
|
| 393 |
+
4. Propose a fix (`fix`) and explain why it works (`comment` or `question`).
|
| 394 |
+
5. If the developer pushes back, read their response carefully and address their specific concern.
|
| 395 |
+
6. Once convinced, use `done` to finish.
|
| 396 |
+
|
| 397 |
+
Code:
|
| 398 |
+
{obs.code_snippet}
|
| 399 |
+
|
| 400 |
+
Author says:
|
| 401 |
+
{author_msg if author_msg else "(no response yet β start with inspection)"}
|
| 402 |
+
|
| 403 |
+
Last tool output:
|
| 404 |
+
{tool_output if tool_output else "(none)"}
|
| 405 |
+
|
| 406 |
+
Available actions:
|
| 407 |
+
run_tests, run_linter, inspect, query_docs, fix, comment, question, done
|
| 408 |
+
|
| 409 |
+
Respond ONLY in JSON:
|
| 410 |
+
{{"action_type": "...", "content": "..."}}"""
|
| 411 |
+
|
| 412 |
+
if history_lines:
|
| 413 |
+
history = "\n".join(history_lines[-6:])
|
| 414 |
+
prompt += f"\n\nPrevious steps:\n{history}"
|
| 415 |
+
return prompt
|
| 416 |
+
|
| 417 |
+
# ======================================================================
|
| 418 |
+
@dataclass
|
| 419 |
+
class Trajectory:
|
| 420 |
+
states: List[str]
|
| 421 |
+
actions: List[str]
|
| 422 |
+
rewards: List[float]
|
| 423 |
+
logprobs: List[float]
|
| 424 |
+
dones: List[bool]
|
| 425 |
+
def __len__(self): return len(self.states)
|
| 426 |
+
|
| 427 |
+
def collect_trajectory(env, model, tokenizer, max_steps=6, temperature=0.0):
|
| 428 |
+
obs = env.reset()
|
| 429 |
+
history_lines = []
|
| 430 |
+
states, actions, rewards, logprobs, dones = [], [], [], [], []
|
| 431 |
+
for step in range(max_steps):
|
| 432 |
+
prompt = build_prompt(obs, history_lines)
|
| 433 |
+
states.append(prompt)
|
| 434 |
+
action_text, logprob = generate_action_with_logprob(prompt, model, tokenizer, temperature)
|
| 435 |
+
actions.append(action_text)
|
| 436 |
+
logprobs.append(logprob)
|
| 437 |
+
action = parse_action(action_text)
|
| 438 |
+
env_action = map_to_env(action)
|
| 439 |
+
next_obs, reward, done, _ = env.step(env_action)
|
| 440 |
+
rewards.append(reward.value)
|
| 441 |
+
dones.append(done)
|
| 442 |
+
history_lines.append(f"Agent: {action_text}")
|
| 443 |
+
history_lines.append(f"Env: {next_obs.last_tool_output}")
|
| 444 |
+
obs = next_obs
|
| 445 |
+
if done: break
|
| 446 |
+
return Trajectory(states, actions, rewards, logprobs, dones)
|
| 447 |
+
def collect_trajectories(env, model, tokenizer, n_trajectories, max_steps=6,
                         task_levels=None, task_weights=None):
    if task_levels is None:
        task_levels = list(BUG_DB.keys())
    if task_weights is not None and len(task_weights) != len(task_levels):
        raise ValueError("task_weights must match task_levels length")
    if task_weights is not None and sum(task_weights) <= 0:
        raise ValueError("task_weights must have a positive total")
    trajectories = []
    for i in range(n_trajectories):
        sampled_task = random.choices(task_levels, weights=task_weights, k=1)[0]
        env.set_task(sampled_task)
        traj = collect_trajectory(env, model, tokenizer, max_steps)
        total_reward = sum(traj.rewards)
        print(f"Trajectory {i+1}/{n_trajectories}: task={sampled_task}, steps={len(traj)}, reward={total_reward:.3f}")
        trajectories.append(traj)
    return trajectories

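# Hypothetical weighted-sampling example (the task names are illustrative;
# the real levels are the keys of BUG_DB). Weights are passed straight to
# random.choices, so they do not need to sum to 1:
#   collect_trajectories(env, model, tokenizer, n_trajectories=4,
#                        task_levels=["easy", "medium", "hard"],
#                        task_weights=[0.6, 0.3, 0.1])
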
def compute_returns_and_advantages(rewards, dones, gamma=0.99, standardize=True):
    """
    Compute discounted returns and REINFORCE-style baseline advantages.
    Advantages are centered and optionally standardised.
    """
    n = len(rewards)
    returns = [0.0] * n
    running = 0.0
    for t in reversed(range(n)):
        if dones[t]: running = 0.0
        running = rewards[t] + gamma * running
        returns[t] = running
    if standardize:
        advantages = np.array(returns) - np.mean(returns)
        adv_std = np.std(advantages) + 1e-8
        advantages = (advantages / adv_std).tolist()
    else:
        advantages = returns.copy()
    return advantages, returns

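# Worked example (gamma=0.99, one episode, no mid-episode done):
#   rewards = [0.1, 0.0, 1.0]
#   returns = [0.1 + 0.99*(0.0 + 0.99*1.0),  0.0 + 0.99*1.0,  1.0]
#           = [1.0801, 0.99, 1.0]
# With standardize=True the advantages are those returns minus their mean,
# divided by their standard deviation: roughly [+1.41, -0.83, -0.58].
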
def ppo_update(trajectories, model, tokenizer, optimizer, n_epochs=1, clip_epsilon=0.2,
               entropy_coef=0.01, gamma=0.99):
    model.train()
    all_states, all_actions, all_old_logprobs, all_advantages = [], [], [], []
    for traj in trajectories:
        advantages, _ = compute_returns_and_advantages(traj.rewards, traj.dones, gamma=gamma, standardize=True)
        all_states.extend(traj.states)
        all_actions.extend(traj.actions)
        all_old_logprobs.extend(traj.logprobs)
        all_advantages.extend(advantages)
    n_samples = len(all_states)
    total_loss, total_policy_loss, total_entropy, n_updates = 0.0, 0.0, 0.0, 0
    for epoch in range(n_epochs):
        indices = np.random.permutation(n_samples)
        for i in indices:
            state = all_states[i]
            action = all_actions[i]
            old_logprob = all_old_logprobs[i]
            advantage = all_advantages[i]
            # Re-score the sampled action under the current policy.
            messages = [{"role": "user", "content": state}]
            formatted = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
            full_text = formatted + action
            inputs = tokenizer(full_text, return_tensors="pt", max_length=480, truncation=True).to("cuda")
            outputs = model(**inputs)
            logits = outputs.logits
            action_ids = tokenizer.encode(action, add_special_tokens=False)
            prefix_ids = tokenizer.encode(formatted, add_special_tokens=False)
            action_start = len(prefix_ids)
            logprobs = []
            entropy = 0.0
            for idx, token_id in enumerate(action_ids):
                # logits at position p predict the token at p+1, so the logit for
                # action token `idx` lives at position action_start + idx - 1
                position = action_start + idx - 1
                if 0 <= position < logits.shape[1]:
                    token_logits = logits[0, position]
                    log_probs = F.log_softmax(token_logits, dim=-1)
                    token_logprob = log_probs[token_id]
                    logprobs.append(token_logprob)
                    probs = F.softmax(token_logits, dim=-1)
                    entropy += -(probs * log_probs).sum()
            if not logprobs: continue
            new_logprob = sum(logprobs)
            avg_entropy = entropy / len(logprobs) if logprobs else 0.0
            # PPO clipped surrogate objective on the sequence-level ratio
            ratio = torch.exp(new_logprob - old_logprob)
            surr1 = ratio * advantage
            surr2 = torch.clamp(ratio, 1 - clip_epsilon, 1 + clip_epsilon) * advantage
            policy_loss = -torch.min(surr1, surr2)
            loss = policy_loss - entropy_coef * avg_entropy
            optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()
            total_loss += loss.item()
            total_policy_loss += policy_loss.item()
            total_entropy += avg_entropy.item()
            n_updates += 1
    torch.cuda.empty_cache()
    return {"loss": total_loss / n_updates if n_updates else 0.0,
            "policy_loss": total_policy_loss / n_updates if n_updates else 0.0,
            "entropy": total_entropy / n_updates if n_updates else 0.0}

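# Numerical illustration of the clipped surrogate above (clip_epsilon=0.2,
# made-up log-probabilities):
#   old_logprob = -12.0, new_logprob = -11.5  ->  ratio = exp(0.5) ~ 1.65
#   advantage > 0:  min(1.65*A, clamp(1.65, 0.8, 1.2)*A) = 1.2*A   (gain capped)
#   advantage < 0:  min(1.65*A, 1.2*A) = 1.65*A                    (penalty kept)
# i.e. the update gets no extra credit for moving far from the sampling policy.
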
def evaluate_policy(env, model, tokenizer, n_episodes=3, max_steps=6,
                    task_levels=None, verbose=False):
    """Evaluate the current policy across task levels. Returns metrics + optional traces."""
    model.eval()
    if task_levels is None:
        task_levels = list(BUG_DB.keys())
    total_rewards = []
    traces = []  # human-readable behavior logs
    for ep in range(n_episodes):
        task = task_levels[ep % len(task_levels)]
        env.set_task(task)
        traj = collect_trajectory(env, model, tokenizer, max_steps, temperature=0.0)
        ep_reward = sum(traj.rewards)
        total_rewards.append(ep_reward)
        if verbose:
            actions_taken = []
            for a in traj.actions:
                try:
                    actions_taken.append(json.loads(a).get("action_type", "?"))
                except Exception:
                    actions_taken.append("?")
            traces.append({
                "task": task,
                "reward": round(ep_reward, 4),
                "steps": len(traj),
                "actions": actions_taken,
            })
    return {
        "avg_reward": float(np.mean(total_rewards)),
        "std_reward": float(np.std(total_rewards)),
        "min_reward": float(np.min(total_rewards)),
        "max_reward": float(np.max(total_rewards)),
        "traces": traces,
    }

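# Shape of the returned metrics dict (numbers are illustrative only):
#   {"avg_reward": 0.41, "std_reward": 0.12, "min_reward": 0.25, "max_reward": 0.55,
#    "traces": [{"task": "...", "reward": 0.41, "steps": 5,
#                "actions": ["inspect", "run_tests", "fix", "done"]}]}
# "traces" stays empty unless verbose=True.
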
# ======================================================================
# MANUAL WARM-UP (no SFTTrainer, no multiprocessing OOM)
# ======================================================================
def json_warmup(model, tokenizer, json_path="training_data.json",
                n_episodes=20, epochs=2, lr=2e-5):
    """
    Supervised warm-up from pre-generated expert demonstrations.
    Uses raw cross-entropy on action tokens with manual gradient steps.
    NO SFTTrainer, NO multiprocessing - runs safely on any GPU.
    """
    print("\n" + "="*60)
    print("SUPERVISED WARM-UP: training_data.json (manual cross-entropy)")
    print("="*60)

    with open(json_path, encoding="utf-8") as f:
        data = json.load(f)

    # Each episode = 7 steps. Select n_episodes worth.
    steps_per_episode = 7
    max_examples = n_episodes * steps_per_episode
    if max_examples < len(data):
        data = data[:max_examples]

    print(f"  {len(data)} examples ({len(data)//steps_per_episode} episodes), "
          f"{epochs} epoch(s), lr={lr}")

    model.train()
    warmup_opt = AdamW(model.parameters(), lr=lr)
    warmup_losses = []  # per-epoch avg loss

    for epoch in range(epochs):
        random.shuffle(data)
        epoch_loss = 0.0
        n_valid = 0

        for i, example in enumerate(data):
            prompt = example["prompt"]
            action = example["action"]

            # ---- tokenize full sequence (prompt + action) ----
            messages = [
                {"role": "user", "content": prompt},
                {"role": "assistant", "content": action},
            ]
            full_text = tokenizer.apply_chat_template(messages, tokenize=False)
            inputs = tokenizer(full_text, return_tensors="pt",
                               max_length=480, truncation=True).to("cuda")

            # ---- find where the action tokens start ----
            prompt_only = tokenizer.apply_chat_template(
                [{"role": "user", "content": prompt}],
                tokenize=False, add_generation_prompt=True
            )
            prompt_ids = tokenizer.encode(prompt_only, add_special_tokens=False)
            prompt_len = len(prompt_ids)

            total_len = inputs.input_ids.shape[1]
            if prompt_len >= total_len:
                continue  # prompt was truncated away, skip

            # ---- cross-entropy on action tokens only ----
            outputs = model(**inputs)
            logits = outputs.logits

            # next-token prediction: logits[t] predicts token[t+1]
            shift_logits = logits[0, prompt_len - 1 : total_len - 1]
            shift_labels = inputs.input_ids[0, prompt_len : total_len]

            min_len = min(shift_logits.shape[0], shift_labels.shape[0])
            if min_len == 0:
                continue

            loss = F.cross_entropy(shift_logits[:min_len], shift_labels[:min_len])

            warmup_opt.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            warmup_opt.step()

            epoch_loss += loss.item()
            n_valid += 1

            if (i + 1) % 25 == 0:
                avg = epoch_loss / n_valid
                print(f"  epoch {epoch+1} step {i+1:3d}/{len(data)} "
                      f"running_loss={avg:.4f}")

        avg_loss = epoch_loss / max(n_valid, 1)
        warmup_losses.append(avg_loss)
        print(f"  Epoch {epoch+1} done: avg_loss={avg_loss:.4f} "
              f"({n_valid} valid examples)")

    torch.cuda.empty_cache()
    print(f"✓ Warm-up complete. Loss: "
          f"{' → '.join(f'{l:.4f}' for l in warmup_losses)}\n")
    return warmup_losses

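# Assumed shape of one training_data.json entry, inferred from the
# example["prompt"] / example["action"] accesses above (values illustrative):
#   {"prompt": "You are an AI code review agent. ...",
#    "action": "{\"action_type\": \"run_tests\", \"content\": \"\"}"}
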

# ======================================================================
# MAIN TRAINING PIPELINE
# ======================================================================
def train_ppo():
    # --- Hyperparameters ---
    n_iterations = 8           # enough for a clear upward trend
    trajectories_per_iter = 4  # on-policy data per iteration
    n_epochs = 1
    max_steps = 6
    learning_rate = 3e-5
    clip_epsilon = 0.2
    entropy_coef = 0.01
    gamma = 0.99

    # --- Pre-load embedder before LLM (Issue #13) ---
    from rltool import ToolBox
    print("Pre-loading sentence-transformer embedder...")
    ToolBox._get_embedder()
    print("✓ Embedder ready")

    # --- Load model ---
    print("Loading model...")
    model, tokenizer = load_model()
    if not test_model_sanity(model, tokenizer):
        return
    env = CodeReviewEnv()
    task_levels = list(BUG_DB.keys())

    # ==================================================================
    # PHASE 0: BASELINE (untrained policy)
    # ==================================================================
    print("\n" + "="*60)
    print("PHASE 0 - BASELINE EVALUATION (untrained)")
    print("="*60)
    baseline = evaluate_policy(env, model, tokenizer, n_episodes=5,
                               max_steps=max_steps, task_levels=task_levels,
                               verbose=True)
    baseline_reward = baseline["avg_reward"]
    print(f"Baseline avg reward: {baseline_reward:.4f} "
          f"(min={baseline['min_reward']:.4f}, max={baseline['max_reward']:.4f})")
    print("Baseline behavior:")
    for t in baseline["traces"]:
        print(f"  task={t['task']:8s} reward={t['reward']:+.4f} "
              f"steps={t['steps']} actions={t['actions']}")

    # ==================================================================
    # PHASE 1: SUPERVISED WARM-UP (expert demos, manual CE)
    # ==================================================================
    warmup_losses = json_warmup(
        model, tokenizer,
        json_path="training_data.json",
        n_episodes=20,  # 140 examples (20 x 7 steps)
        epochs=2,
        lr=2e-5,
    )

    # Post-warmup evaluation
    print("="*60)
    print("POST WARM-UP EVALUATION")
    print("="*60)
    post_warmup = evaluate_policy(env, model, tokenizer, n_episodes=5,
                                  max_steps=max_steps, task_levels=task_levels,
                                  verbose=True)
    warmup_reward = post_warmup["avg_reward"]
    print(f"Post-warmup avg reward: {warmup_reward:.4f} "
          f"(Δ vs baseline: {warmup_reward - baseline_reward:+.4f})")
    print("Post-warmup behavior:")
    for t in post_warmup["traces"]:
        print(f"  task={t['task']:8s} reward={t['reward']:+.4f} "
              f"steps={t['steps']} actions={t['actions']}")

    # ==================================================================
    # PHASE 2: TRUE RL - PPO (on-policy, real environment interaction)
    # ==================================================================
    optimizer = AdamW(model.parameters(), lr=learning_rate)
    print(f"\n{'='*60}")
    print(f"PHASE 2 - PPO TRAINING: {n_iterations} iterations x "
          f"{trajectories_per_iter} trajectories (true RL)")
    print(f"{'='*60}\n")

    reward_history = []
    eval_history = []
    loss_history = []
    policy_loss_history = []
    entropy_history = []

    for iteration in range(n_iterations):
        print(f"\n--- PPO Iteration {iteration + 1}/{n_iterations} ---")

        # Collect on-policy trajectories from REAL environment
        trajectories = collect_trajectories(
            env, model, tokenizer, trajectories_per_iter, max_steps,
            task_levels=task_levels, task_weights=None
        )
        avg_reward = float(np.mean([sum(t.rewards) for t in trajectories]))
        reward_history.append(avg_reward)
        print(f"  Collect avg reward: {avg_reward:+.4f}")

        # PPO policy gradient update
        metrics = ppo_update(
            trajectories, model, tokenizer, optimizer,
            n_epochs=n_epochs, clip_epsilon=clip_epsilon,
            entropy_coef=entropy_coef, gamma=gamma
        )
        loss_history.append(float(metrics["loss"]))
        policy_loss_history.append(float(metrics["policy_loss"]))
        entropy_history.append(float(metrics["entropy"]))
        print(f"  Update loss={metrics['loss']:.4f} "
              f"policy={metrics['policy_loss']:.4f} "
              f"entropy={metrics['entropy']:.4f}")

        # Evaluate greedy policy after update
        eval_m = evaluate_policy(env, model, tokenizer, n_episodes=3,
                                 max_steps=max_steps, task_levels=task_levels,
                                 verbose=False)
        eval_history.append(eval_m["avg_reward"])
        delta = eval_m["avg_reward"] - baseline_reward
        print(f"  Eval avg reward: {eval_m['avg_reward']:+.4f} "
              f"(Δ baseline: {delta:+.4f})")

    # ==================================================================
    # PHASE 3: FINAL EVALUATION (proof of learning)
    # ==================================================================
    print("\n" + "="*60)
    print("PHASE 3 - FINAL EVALUATION (after all training)")
    print("="*60)
    final = evaluate_policy(env, model, tokenizer, n_episodes=5,
                            max_steps=max_steps, task_levels=task_levels,
                            verbose=True)
    print(f"Final avg reward: {final['avg_reward']:.4f} "
          f"(min={final['min_reward']:.4f}, max={final['max_reward']:.4f})")
    print("Final behavior:")
    for t in final["traces"]:
        print(f"  task={t['task']:8s} reward={t['reward']:+.4f} "
              f"steps={t['steps']} actions={t['actions']}")

    total_improvement = final["avg_reward"] - baseline_reward
    ppo_improvement = final["avg_reward"] - warmup_reward
    print(f"\n{'='*60}")
    print("TRAINING SUMMARY")
    print(f"  Baseline reward: {baseline_reward:+.4f}")
    print(f"  Post-warmup reward: {warmup_reward:+.4f} "
          f"(warmup Δ: {warmup_reward - baseline_reward:+.4f})")
    print(f"  Final reward: {final['avg_reward']:+.4f} "
          f"(PPO Δ: {ppo_improvement:+.4f})")
    print(f"  Total improvement: {total_improvement:+.4f}")
    print(f"  Reward trend (PPO): {' → '.join(f'{r:+.3f}' for r in reward_history)}")
    print(f"  Loss trend (PPO): {' → '.join(f'{l:.4f}' for l in loss_history)}")
    if total_improvement > 0:
        print(f"  ✓ Agent IMPROVED by {total_improvement:+.4f}")
    else:
        print("  ✗ No overall improvement detected")
    print(f"{'='*60}")

    # ==================================================================
    # PLOTS
    # ==================================================================
    iters = list(range(1, n_iterations + 1))

    # --- 1. Warm-up loss curve ---
    if warmup_losses:
        fig, ax = plt.subplots(figsize=(7, 4))
        ax.plot(range(1, len(warmup_losses) + 1), warmup_losses,
                marker="o", linewidth=2, color="tab:purple")
        ax.set_title("Warm-up Loss (supervised, per epoch)",
                     fontsize=13, fontweight="bold")
        ax.set_xlabel("Epoch")
        ax.set_ylabel("Cross-Entropy Loss")
        ax.grid(alpha=0.3)
        fig.tight_layout()
        fig.savefig("warmup_loss.png", dpi=150)
        plt.close(fig)

    # --- 2. PPO reward curve ---
    fig, ax = plt.subplots(figsize=(9, 5))
    ax.plot(iters, reward_history, marker="o", linewidth=2,
            label="Collect reward", color="tab:blue")
    ax.plot(iters, eval_history, marker="s", linewidth=2, linestyle="--",
            label="Eval reward", color="tab:green")
    ax.axhline(y=baseline_reward, color="tab:gray", linestyle=":",
               linewidth=1.5, label=f"Baseline ({baseline_reward:+.3f})")
    ax.axhline(y=warmup_reward, color="tab:purple", linestyle=":",
               linewidth=1.5, label=f"Post-warmup ({warmup_reward:+.3f})")
    ax.set_title("PPO Reward per Iteration", fontsize=14, fontweight="bold")
    ax.set_xlabel("Iteration")
    ax.set_ylabel("Average Reward")
    ax.legend(loc="best", fontsize=8)
    ax.grid(alpha=0.3)
    fig.tight_layout()
    fig.savefig("reward_curve.png", dpi=150)
    plt.close(fig)

    # --- 3. PPO loss curve ---
    fig, ax = plt.subplots(figsize=(9, 5))
    ax.plot(iters, loss_history, marker="o", linewidth=2,
            label="Total loss", color="tab:red")
    ax.plot(iters, policy_loss_history, marker="^", linewidth=2, linestyle="--",
            label="Policy loss", color="tab:orange")
    ax.set_title("PPO Loss per Iteration", fontsize=14, fontweight="bold")
    ax.set_xlabel("Iteration")
    ax.set_ylabel("Loss")
    ax.legend(loc="best")
    ax.grid(alpha=0.3)
    fig.tight_layout()
    fig.savefig("loss_curve.png", dpi=150)
    plt.close(fig)

    # --- 4. Combined 3-panel summary ---
    fig, axes = plt.subplots(1, 3, figsize=(18, 5))

    # Panel A: warm-up loss
    if warmup_losses:
        axes[0].plot(range(1, len(warmup_losses) + 1), warmup_losses,
                     marker="o", linewidth=2, color="tab:purple")
    axes[0].set_title("A. Warm-up Loss ↓")
    axes[0].set_xlabel("Epoch")
    axes[0].set_ylabel("CE Loss")
    axes[0].grid(alpha=0.3)

    # Panel B: PPO reward
    axes[1].plot(iters, reward_history, marker="o", linewidth=2,
                 color="tab:blue", label="Collect")
    axes[1].plot(iters, eval_history, marker="s", linewidth=2,
                 linestyle="--", color="tab:green", label="Eval")
    axes[1].axhline(y=baseline_reward, color="tab:gray", linestyle=":",
                    linewidth=1.5, label="Baseline")
    axes[1].axhline(y=warmup_reward, color="tab:purple", linestyle=":",
                    linewidth=1.5, label="Post-warmup")
    axes[1].set_title("B. PPO Reward ↑")
    axes[1].set_xlabel("Iteration")
    axes[1].set_ylabel("Avg Reward")
    axes[1].legend(fontsize=7)
    axes[1].grid(alpha=0.3)

    # Panel C: PPO loss
    axes[2].plot(iters, loss_history, marker="o", linewidth=2,
                 color="tab:red", label="Total")
    axes[2].plot(iters, policy_loss_history, marker="^", linewidth=2,
                 linestyle="--", color="tab:orange", label="Policy")
    axes[2].set_title("C. PPO Loss ↓")
    axes[2].set_xlabel("Iteration")
    axes[2].set_ylabel("Loss")
    axes[2].legend(fontsize=7)
    axes[2].grid(alpha=0.3)

    fig.suptitle("Code Review Agent - Full Training Evidence",
                 fontsize=14, fontweight="bold")
    fig.tight_layout()
    fig.savefig("training_summary.png", dpi=150)
    plt.close(fig)

    print("Plots saved: warmup_loss.png, reward_curve.png, "
          "loss_curve.png, training_summary.png")
    print("="*60)


if __name__ == "__main__":
    train_ppo()

training_data.json
CHANGED
The diff for this file is too large to render.
See raw diff