Spaces:
Running
Running
codemaverick2 committed on
Commit ·
ff9fcbd
0
Parent(s):
Code Review Environment OpenEnv hackathon submission
Browse files- .env.example +10 -0
- .gitignore +9 -0
- Dockerfile +32 -0
- README.md +269 -0
- client.py +154 -0
- demo.py +154 -0
- inference.py +231 -0
- models.py +184 -0
- openenv.yaml +11 -0
- pyproject.toml +25 -0
- requirements.txt +7 -0
- server/__init__.py +1 -0
- server/app.py +306 -0
- server/environment.py +310 -0
- server/graders.py +170 -0
- tasks/__init__.py +3 -0
- tasks/data.py +434 -0
- tests/__init__.py +1 -0
- tests/test_environment.py +314 -0
- tests/test_graders.py +215 -0
- uv.lock +0 -0
.env.example
ADDED
|
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copy to .env and fill in values
|
| 2 |
+
|
| 3 |
+
# Required for baseline.py LLM inference
|
| 4 |
+
OPENAI_API_KEY=sk-...
|
| 5 |
+
|
| 6 |
+
# Optional: override the environment URL for baseline.py
|
| 7 |
+
ENV_URL=http://localhost:7860
|
| 8 |
+
|
| 9 |
+
# Optional: override the model for baseline.py
|
| 10 |
+
BASELINE_MODEL=gpt-4o-mini
|
.gitignore
ADDED
|
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
__pycache__/
|
| 2 |
+
*.pyc
|
| 3 |
+
*.pyo
|
| 4 |
+
.pytest_cache/
|
| 5 |
+
*.egg-info/
|
| 6 |
+
dist/
|
| 7 |
+
build/
|
| 8 |
+
.env
|
| 9 |
+
.DS_Store
|
Dockerfile
ADDED
|
@@ -0,0 +1,32 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Code Review Environment — Docker Image
|
| 2 |
+
#
|
| 3 |
+
# Build: docker build -t code-review-env .
|
| 4 |
+
# Run: docker run -p 7860:7860 code-review-env
|
| 5 |
+
# Test: curl http://localhost:7860/health
|
| 6 |
+
|
| 7 |
+
FROM python:3.11-slim
|
| 8 |
+
|
| 9 |
+
# Create non-root user (HF Spaces requirement)
|
| 10 |
+
RUN useradd -m -u 1000 appuser
|
| 11 |
+
|
| 12 |
+
WORKDIR /app
|
| 13 |
+
|
| 14 |
+
# Install dependencies first (caching layer)
|
| 15 |
+
COPY requirements.txt .
|
| 16 |
+
RUN pip install --no-cache-dir -r requirements.txt
|
| 17 |
+
|
| 18 |
+
# Copy application code
|
| 19 |
+
COPY --chown=appuser:appuser . .
|
| 20 |
+
|
| 21 |
+
# Switch to non-root user
|
| 22 |
+
USER appuser
|
| 23 |
+
|
| 24 |
+
# Expose port (HF Spaces uses 7860)
|
| 25 |
+
EXPOSE 7860
|
| 26 |
+
|
| 27 |
+
# Health check
|
| 28 |
+
HEALTHCHECK --interval=30s --timeout=10s --start-period=5s --retries=3 \
|
| 29 |
+
CMD python -c "import urllib.request; urllib.request.urlopen('http://localhost:7860/health')" || exit 1
|
| 30 |
+
|
| 31 |
+
# Start server
|
| 32 |
+
CMD ["uvicorn", "server.app:app", "--host", "0.0.0.0", "--port", "7860", "--workers", "1"]
|
README.md
ADDED
|
@@ -0,0 +1,269 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
---
|
| 2 |
+
title: Code Review Environment
|
| 3 |
+
emoji: 🔍
|
| 4 |
+
colorFrom: blue
|
| 5 |
+
colorTo: purple
|
| 6 |
+
sdk: docker
|
| 7 |
+
pinned: false
|
| 8 |
+
tags:
|
| 9 |
+
- openenv
|
| 10 |
+
- code-review
|
| 11 |
+
- security-audit
|
| 12 |
+
- reinforcement-learning
|
| 13 |
+
---
|
| 14 |
+
|
| 15 |
+
# Code Review Environment
|
| 16 |
+
|
| 17 |
+
An [OpenEnv](https://github.com/meta-pytorch/OpenEnv)-compatible environment for training and evaluating AI agents on code review and security auditing tasks.
|
| 18 |
+
|
| 19 |
+
The agent inspects code files, flags bugs and vulnerabilities with precise line numbers and severity ratings, and receives graded feedback — enabling reinforcement learning from human-quality code review signal.
|
| 20 |
+
|
| 21 |
+
## Why This Environment
|
| 22 |
+
|
| 23 |
+
Code review is one of the highest-value tasks in software engineering. Every professional software team does it daily. Training AI agents to perform thorough, accurate code reviews is commercially valuable and technically challenging:
|
| 24 |
+
|
| 25 |
+
- **Precise reasoning required**: agent must count lines, understand language semantics, reason about control flow
|
| 26 |
+
- **Real impact**: bugs found → prevented production incidents; vulnerabilities found → prevented security breaches
|
| 27 |
+
- **Natural difficulty progression**: obvious logic errors → subtle security vulnerabilities → complex architectural issues
|
| 28 |
+
- **Clear grading**: issues exist at specific lines with specific types — objective F1-based scoring
|
| 29 |
+
|
| 30 |
+
## Action Space
|
| 31 |
+
|
| 32 |
+
```json
|
| 33 |
+
{
|
| 34 |
+
"action_type": "flag_issue | clear_flag | request_hint | submit_review",
|
| 35 |
+
"line_number": 6,
|
| 36 |
+
"filename": "utils.py",
|
| 37 |
+
"issue_type": "bug | security | performance | logic",
|
| 38 |
+
"severity": "low | medium | high | critical",
|
| 39 |
+
"description": "Description of the issue",
|
| 40 |
+
"fix_suggestion": "How to fix it (optional)"
|
| 41 |
+
}
|
| 42 |
+
```
|
| 43 |
+
|
| 44 |
+
| Action | Description | Reward |
|
| 45 |
+
|--------|-------------|--------|
|
| 46 |
+
| `flag_issue` | Mark a line as containing an issue | +0.10 if correct, −0.05 if wrong |
|
| 47 |
+
| `clear_flag` | Remove a previously flagged issue | +0.03 if was FP, −0.03 if was TP |
|
| 48 |
+
| `request_hint` | Get a hint about what to look for | −0.01 |
|
| 49 |
+
| `submit_review` | Finalize and receive graded score | Final F1 score |
|
| 50 |
+
|
| 51 |
+
## Observation Space
|
| 52 |
+
|
| 53 |
+
```json
|
| 54 |
+
{
|
| 55 |
+
"task_id": "bug-detection",
|
| 56 |
+
"task_description": "Review this Python utility module...",
|
| 57 |
+
"code_files": {"utils.py": "def calculate_average(numbers):\n..."},
|
| 58 |
+
"language": "python",
|
| 59 |
+
"flagged_issues": [...],
|
| 60 |
+
"step_count": 3,
|
| 61 |
+
"max_steps": 15,
|
| 62 |
+
"hints_remaining": 2,
|
| 63 |
+
"feedback": "Good catch! Issue flagged at utils.py:6 [+0.10 reward]",
|
| 64 |
+
"current_score": 0.333,
|
| 65 |
+
"done": false,
|
| 66 |
+
"reward": 0.1
|
| 67 |
+
}
|
| 68 |
+
```
|
| 69 |
+
|
| 70 |
+
Note: `code_files` is only populated in the first observation (after `reset()`). Subsequent step observations omit it to keep payloads small.
|
| 71 |
+
|
| 72 |
+
## Tasks
|
| 73 |
+
|
| 74 |
+
### Task 1: `bug-detection` — Easy
|
| 75 |
+
|
| 76 |
+
Identify 3 logical bugs in a Python utility module (`utils.py`).
|
| 77 |
+
|
| 78 |
+
| Line | Issue | Severity |
|
| 79 |
+
|------|-------|----------|
|
| 80 |
+
| 6 | Off-by-one error: `range(len(numbers) + 1)` causes `IndexError` | High |
|
| 81 |
+
| 13 | Binary search upper bound: `len(arr)` should be `len(arr) - 1` | Medium |
|
| 82 |
+
| 33 | Word count initializes new entries to `0` instead of `1` | Low |
|
| 83 |
+
|
| 84 |
+
**Max steps:** 15
|
| 85 |
+
|
| 86 |
+
### Task 2: `security-audit` — Medium
|
| 87 |
+
|
| 88 |
+
Audit a Flask web application (`app.py`) for OWASP Top-10 vulnerabilities.
|
| 89 |
+
|
| 90 |
+
| Line | Issue | Severity |
|
| 91 |
+
|------|-------|----------|
|
| 92 |
+
| 8 | Hardcoded `SECRET_KEY` in source | High |
|
| 93 |
+
| 9 | Hardcoded `DB_PASSWORD` in source | High |
|
| 94 |
+
| 19 | SQL injection via f-string query | Critical |
|
| 95 |
+
| 27 | XSS via unsanitized `render_template_string` | High |
|
| 96 |
+
| 34 | Path traversal via `os.path.join` | High |
|
| 97 |
+
| 40 | Missing authentication on admin endpoint | Critical |
|
| 98 |
+
| 51 | Command injection via `shell=True` | Critical |
|
| 99 |
+
|
| 100 |
+
**Max steps:** 20
|
| 101 |
+
|
| 102 |
+
### Task 3: `comprehensive-review` — Hard
|
| 103 |
+
|
| 104 |
+
Comprehensive review of a Django e-commerce API across two files (`views.py`, `models.py`).
|
| 105 |
+
|
| 106 |
+
| File | Line | Issue | Severity |
|
| 107 |
+
|------|------|-------|----------|
|
| 108 |
+
| views.py | 21 | N+1 query in order creation loop | High |
|
| 109 |
+
| views.py | 26 | Race condition — stock check not atomic | Critical |
|
| 110 |
+
| views.py | 29 | Order created outside transaction | High |
|
| 111 |
+
| views.py | 47 | No max cap on `per_page` parameter | Medium |
|
| 112 |
+
| views.py | 66 | MD5 for payment verification (broken crypto) | Medium |
|
| 113 |
+
| views.py | 67 | Timing attack in payment hash comparison | Medium |
|
| 114 |
+
| models.py | 8 | Plaintext password storage | Critical |
|
| 115 |
+
| models.py | 16 | `FloatField` for monetary values | Medium |
|
| 116 |
+
| models.py | 18 | `BinaryField` with pickled data (RCE risk) | High |
|
| 117 |
+
|
| 118 |
+
**Max steps:** 30
|
| 119 |
+
|
| 120 |
+
## Scoring
|
| 121 |
+
|
| 122 |
+
```
|
| 123 |
+
final_score = 0.70 × F1 + 0.30 × severity_accuracy
|
| 124 |
+
|
| 125 |
+
where:
|
| 126 |
+
F1 = 2 × precision × recall / (precision + recall)
|
| 127 |
+
precision = correct_flags / total_flags
|
| 128 |
+
recall = correct_flags / total_gt_issues
|
| 129 |
+
severity_accuracy = avg(1 − |flag_sev_rank − gt_sev_rank| × 0.34) for matched issues
|
| 130 |
+
|
| 131 |
+
Matching tolerance: ±2 lines, same filename, compatible issue type
|
| 132 |
+
```
|
| 133 |
+
|
| 134 |
+
## API Endpoints
|
| 135 |
+
|
| 136 |
+
| Method | Endpoint | Description |
|
| 137 |
+
|--------|----------|-------------|
|
| 138 |
+
| `POST` | `/reset` | Start new episode. Body: `{"task_id": "bug-detection", "seed": 42}` |
|
| 139 |
+
| `POST` | `/step` | Take action. Body: ReviewAction JSON |
|
| 140 |
+
| `GET` | `/state` | Get current episode state |
|
| 141 |
+
| `GET` | `/health` | Health check → `{"status": "healthy"}` |
|
| 142 |
+
| `GET` | `/tasks` | List all tasks + action schema |
|
| 143 |
+
| `POST` | `/grader` | Grade findings: `{"task_id": "...", "flagged_issues": [...]}` |
|
| 144 |
+
| `POST` | `/baseline` | Run keyword heuristic on all tasks |
|
| 145 |
+
| `WS` | `/ws` | WebSocket session (OpenEnv standard) |
|
| 146 |
+
| `GET` | `/docs` | Swagger UI |
|
| 147 |
+
|
| 148 |
+
## Setup & Usage
|
| 149 |
+
|
| 150 |
+
### Local (uvicorn)
|
| 151 |
+
|
| 152 |
+
```bash
|
| 153 |
+
git clone https://github.com/CodeMaverick2/code-review-env
|
| 154 |
+
cd code-review-env
|
| 155 |
+
pip install -r requirements.txt
|
| 156 |
+
uvicorn server.app:app --host 0.0.0.0 --port 7860
|
| 157 |
+
```
|
| 158 |
+
|
| 159 |
+
### Docker
|
| 160 |
+
|
| 161 |
+
```bash
|
| 162 |
+
docker build -t code-review-env .
|
| 163 |
+
docker run -p 7860:7860 code-review-env
|
| 164 |
+
```
|
| 165 |
+
|
| 166 |
+
### Quick test
|
| 167 |
+
|
| 168 |
+
```bash
|
| 169 |
+
curl http://localhost:7860/health
|
| 170 |
+
|
| 171 |
+
curl -X POST http://localhost:7860/reset \
|
| 172 |
+
-H "Content-Type: application/json" \
|
| 173 |
+
-d '{"task_id": "bug-detection"}'
|
| 174 |
+
|
| 175 |
+
curl -X POST http://localhost:7860/step \
|
| 176 |
+
-H "Content-Type: application/json" \
|
| 177 |
+
-d '{"action_type": "flag_issue", "line_number": 6, "filename": "utils.py", "issue_type": "bug", "severity": "high", "description": "Off-by-one"}'
|
| 178 |
+
|
| 179 |
+
curl -X POST http://localhost:7860/step \
|
| 180 |
+
-H "Content-Type: application/json" \
|
| 181 |
+
-d '{"action_type": "submit_review"}'
|
| 182 |
+
```
|
| 183 |
+
|
| 184 |
+
### Python client
|
| 185 |
+
|
| 186 |
+
```python
|
| 187 |
+
from client import CodeReviewEnv, ReviewAction
|
| 188 |
+
|
| 189 |
+
with CodeReviewEnv("http://localhost:7860").sync() as env:
|
| 190 |
+
result = env.reset(task_id="bug-detection")
|
| 191 |
+
print(result.observation.code_files["utils.py"])
|
| 192 |
+
|
| 193 |
+
result = env.step(ReviewAction(
|
| 194 |
+
action_type="flag_issue",
|
| 195 |
+
line_number=6,
|
| 196 |
+
filename="utils.py",
|
| 197 |
+
issue_type="bug",
|
| 198 |
+
severity="high",
|
| 199 |
+
description="Off-by-one error in range()"
|
| 200 |
+
))
|
| 201 |
+
print(result.observation.feedback)
|
| 202 |
+
|
| 203 |
+
result = env.step(ReviewAction(action_type="submit_review"))
|
| 204 |
+
print(f"Final score: {result.reward:.3f}")
|
| 205 |
+
```
|
| 206 |
+
|
| 207 |
+
### Inference script
|
| 208 |
+
|
| 209 |
+
```bash
|
| 210 |
+
# No API key needed — uses built-in keyword heuristic
|
| 211 |
+
python inference.py
|
| 212 |
+
|
| 213 |
+
# With LLM (OpenAI-compatible API)
|
| 214 |
+
export API_BASE_URL=https://openrouter.ai/api/v1
|
| 215 |
+
export MODEL_NAME=openai/gpt-4o-mini
|
| 216 |
+
export HF_TOKEN=sk-...
|
| 217 |
+
python inference.py
|
| 218 |
+
```
|
| 219 |
+
|
| 220 |
+
### Demo
|
| 221 |
+
|
| 222 |
+
```bash
|
| 223 |
+
python demo.py
|
| 224 |
+
python demo.py --task security-audit
|
| 225 |
+
python demo.py --task comprehensive-review
|
| 226 |
+
```
|
| 227 |
+
|
| 228 |
+
### Tests
|
| 229 |
+
|
| 230 |
+
```bash
|
| 231 |
+
pip install pytest
|
| 232 |
+
pytest tests/ -v
|
| 233 |
+
```
|
| 234 |
+
|
| 235 |
+
## Baseline Scores
|
| 236 |
+
|
| 237 |
+
| Task | Keyword heuristic | GPT-4o-mini |
|
| 238 |
+
|------|-------------------|-------------|
|
| 239 |
+
| bug-detection | 1.00 | ~0.52 |
|
| 240 |
+
| security-audit | 0.75 | ~0.59 |
|
| 241 |
+
| comprehensive-review | 0.67 | ~0.17 |
|
| 242 |
+
| **Overall** | **0.81** | **~0.43** |
|
| 243 |
+
|
| 244 |
+
Keyword heuristic runs via `inference.py` with no API key. LLM scores use `API_BASE_URL` + `HF_TOKEN`.
|
| 245 |
+
|
| 246 |
+
## Project Structure
|
| 247 |
+
|
| 248 |
+
```
|
| 249 |
+
code-review-env/
|
| 250 |
+
├── README.md
|
| 251 |
+
├── openenv.yaml ← OpenEnv manifest
|
| 252 |
+
├── Dockerfile ← Container (HF Spaces, port 7860)
|
| 253 |
+
├── pyproject.toml ← Package config + entry points
|
| 254 |
+
├── requirements.txt
|
| 255 |
+
├── uv.lock
|
| 256 |
+
├── inference.py ← Inference script
|
| 257 |
+
├── demo.py ← Demo script (no API key needed)
|
| 258 |
+
├── client.py ← HTTP client
|
| 259 |
+
├── models.py ← ReviewAction, ReviewObservation, ReviewState, Issue
|
| 260 |
+
├── tasks/
|
| 261 |
+
│ └── data.py ← 3 task definitions + ground truth
|
| 262 |
+
├── server/
|
| 263 |
+
│ ├── app.py ← FastAPI application
|
| 264 |
+
│ ├── environment.py ← Core environment logic
|
| 265 |
+
│ └── graders.py ← F1 grading + keyword baseline
|
| 266 |
+
└── tests/
|
| 267 |
+
├── test_environment.py
|
| 268 |
+
└── test_graders.py
|
| 269 |
+
```
|
client.py
ADDED
|
@@ -0,0 +1,154 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
HTTP client for the Code Review Environment.
|
| 3 |
+
|
| 4 |
+
Usage:
|
| 5 |
+
from client import CodeReviewEnv, ReviewAction
|
| 6 |
+
|
| 7 |
+
with CodeReviewEnv(base_url="http://localhost:7860").sync() as env:
|
| 8 |
+
result = env.reset(task_id="bug-detection")
|
| 9 |
+
obs = result.observation
|
| 10 |
+
print(obs.task_description)
|
| 11 |
+
print(obs.code_files)
|
| 12 |
+
|
| 13 |
+
# Flag an issue
|
| 14 |
+
result = env.step(ReviewAction(
|
| 15 |
+
action_type="flag_issue",
|
| 16 |
+
line_number=6,
|
| 17 |
+
filename="utils.py",
|
| 18 |
+
issue_type="bug",
|
| 19 |
+
severity="high",
|
| 20 |
+
description="Off-by-one in range()"
|
| 21 |
+
))
|
| 22 |
+
print(result.observation.feedback)
|
| 23 |
+
|
| 24 |
+
# Submit
|
| 25 |
+
result = env.step(ReviewAction(action_type="submit_review"))
|
| 26 |
+
print(f"Score: {result.reward:.3f}")
|
| 27 |
+
"""
|
| 28 |
+
from __future__ import annotations
|
| 29 |
+
|
| 30 |
+
import os
|
| 31 |
+
import sys
|
| 32 |
+
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
|
| 33 |
+
|
| 34 |
+
from typing import Optional, Generic, TypeVar
|
| 35 |
+
from models import ReviewAction, ReviewObservation, ReviewState, Issue
|
| 36 |
+
|
| 37 |
+
ObsT = TypeVar("ObsT")


class StepResult(Generic[ObsT]):
    """Pairs an observation with the reward and done flag from one transition."""

    def __init__(
        self,
        observation: ObsT,
        reward: Optional[float] = None,
        done: bool = False,
    ):
        # observation: environment payload (typically a ReviewObservation)
        # reward: per-step reward, None when the server reports none
        # done: True once the episode has terminated
        self.observation = observation
        self.reward = reward
        self.done = done

    def __repr__(self) -> str:
        # current_score is only present on observation types that track it.
        score = getattr(self.observation, "current_score", None)
        return f"StepResult(done={self.done}, reward={self.reward}, score={score})"
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
# Optional dependencies: record availability as module flags instead of
# failing at import time, so the client degrades gracefully.
try:
    import httpx
except ImportError:
    _HAS_HTTPX = False
else:
    _HAS_HTTPX = True

try:
    from openenv.core.http_env_client import HTTPEnvClient as _OfficialClient
except ImportError:
    _HAS_OPENENV_CLIENT = False
else:
    _HAS_OPENENV_CLIENT = True
|
| 69 |
+
|
| 70 |
+
|
| 71 |
+
class SyncCodeReviewEnv:
    """Synchronous HTTP wrapper around the Code Review Environment REST API.

    Supports use as a context manager; `close()` releases the connection pool.
    """

    def __init__(self, base_url: str = "http://localhost:7860"):
        self.base_url = base_url.rstrip("/")
        if not _HAS_HTTPX:
            raise ImportError("httpx is required: pip install httpx")
        self._client = httpx.Client(timeout=30.0)

    def __enter__(self):
        return self

    def __exit__(self, *args):
        self.close()

    def close(self):
        """Release the underlying HTTP connection pool."""
        self._client.close()

    def _post_observation(self, path: str, payload: dict) -> StepResult[ReviewObservation]:
        """POST *payload* to *path* and wrap the JSON reply in a StepResult."""
        response = self._client.post(f"{self.base_url}{path}", json=payload)
        response.raise_for_status()
        observation = ReviewObservation.from_dict(response.json())
        return StepResult(observation=observation, reward=observation.reward, done=observation.done)

    def _get_json(self, path: str) -> dict:
        """GET *path* and return the decoded JSON body."""
        response = self._client.get(f"{self.base_url}{path}")
        response.raise_for_status()
        return response.json()

    def reset(
        self,
        task_id: Optional[str] = None,
        seed: Optional[int] = None,
        episode_id: Optional[str] = None,
    ) -> StepResult[ReviewObservation]:
        """Start a new episode; omitted fields fall back to server defaults."""
        payload = {}
        if task_id:
            payload["task_id"] = task_id
        if seed is not None:  # seed=0 is valid, so compare against None
            payload["seed"] = seed
        if episode_id:
            payload["episode_id"] = episode_id
        return self._post_observation("/reset", payload)

    def step(self, action: ReviewAction) -> StepResult[ReviewObservation]:
        """Apply one ReviewAction and return the resulting observation."""
        return self._post_observation("/step", action.to_dict())

    def state(self) -> ReviewState:
        """Fetch the current episode state from the server."""
        data = self._get_json("/state")
        return ReviewState(
            task_id=data.get("task_id", ""),
            difficulty=data.get("difficulty", ""),
            episode_id=data.get("episode_id"),
            step_count=data.get("step_count", 0),
            flagged_issues=[Issue.from_dict(item) for item in data.get("flagged_issues", [])],
            current_score=data.get("current_score", 0.0),
            submitted=data.get("submitted", False),
        )

    def health(self) -> dict:
        """Return the server health payload."""
        return self._get_json("/health")

    def list_tasks(self) -> dict:
        """Return the task catalog and action schema."""
        return self._get_json("/tasks")
|
| 138 |
+
|
| 139 |
+
|
| 140 |
+
class CodeReviewEnv:
    """Entry-point factory for the environment client.

    Either call `.sync()` for an explicit synchronous client, or use the
    instance directly as a context manager (which yields a sync client
    and closes it on exit).
    """

    def __init__(self, base_url: str = "http://localhost:7860"):
        self.base_url = base_url

    def sync(self) -> SyncCodeReviewEnv:
        """Build a new synchronous client bound to this base URL."""
        return SyncCodeReviewEnv(self.base_url)

    def __enter__(self):
        self._sync = self.sync()
        return self._sync

    def __exit__(self, *args):
        # Only close if __enter__ actually ran and created a client.
        active = getattr(self, "_sync", None)
        if active is not None:
            active.close()
|
demo.py
ADDED
|
@@ -0,0 +1,154 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Demo script for the Code Review Environment.
|
| 3 |
+
|
| 4 |
+
Runs a complete episode against the live environment using the
|
| 5 |
+
keyword-heuristic agent (no API key required).
|
| 6 |
+
|
| 7 |
+
Usage:
|
| 8 |
+
python demo.py
|
| 9 |
+
python demo.py --url https://tejasghatule-code-review-env.hf.space
|
| 10 |
+
python demo.py --task security-audit
|
| 11 |
+
"""
|
| 12 |
+
from __future__ import annotations
|
| 13 |
+
|
| 14 |
+
import argparse
|
| 15 |
+
import json
|
| 16 |
+
import sys
|
| 17 |
+
import os
|
| 18 |
+
|
| 19 |
+
import httpx
|
| 20 |
+
|
| 21 |
+
DEFAULT_URL = "https://tejasghatule-code-review-env.hf.space"
|
| 22 |
+
TASKS = ["bug-detection", "security-audit", "comprehensive-review"]
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
def run_keyword_agent(base_url: str, task_id: str) -> dict:
    """Run the built-in keyword-heuristic agent via the /baseline endpoint."""
    with httpx.Client(timeout=30) as client:
        # Verify the server is reachable before doing anything else.
        health = client.get(f"{base_url}/health")
        health.raise_for_status()
        print(f" Health : {health.json()}")

        # Reset an episode so the user sees the task metadata up front.
        # NOTE(review): /baseline below runs on all tasks server-side and
        # presumably ignores this reset — confirm against server/app.py.
        reset_resp = client.post(f"{base_url}/reset", json={"task_id": task_id})
        reset_resp.raise_for_status()
        observation = reset_resp.json()

        print(f" Task : {observation['task_id']} ({observation.get('difficulty', '')})")
        print(f" Files : {list(observation['code_files'].keys())}")
        print(f" Steps : 0 / {observation['max_steps']}")
        print()

        # The /baseline endpoint is deterministic — no LLM involved.
        baseline_resp = client.post(f"{base_url}/baseline")
        baseline_resp.raise_for_status()
        return baseline_resp.json()
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
def _known_flag_action(task_id: str) -> dict:
    """Return a known-good flag_issue action for *task_id* (demo purposes)."""
    catalog = {
        "bug-detection": {
            "action_type": "flag_issue",
            "line_number": 6,
            "filename": "utils.py",
            "issue_type": "bug",
            "severity": "high",
            "description": "Off-by-one: range(len(numbers) + 1) causes IndexError",
            "fix_suggestion": "Change to range(len(numbers))",
        },
        "security-audit": {
            "action_type": "flag_issue",
            "line_number": 8,
            "filename": "app.py",
            "issue_type": "security",
            "severity": "high",
            "description": "Hardcoded SECRET_KEY in source code",
            "fix_suggestion": "Use os.environ.get('SECRET_KEY')",
        },
        "comprehensive-review": {
            "action_type": "flag_issue",
            "line_number": 8,
            "filename": "models.py",
            "issue_type": "security",
            "severity": "critical",
            "description": "Plaintext password storage in database",
            "fix_suggestion": "Use Django's make_password / check_password",
        },
    }
    # Unknown task ids fall back to the easy task's action.
    return catalog.get(task_id, catalog["bug-detection"])


def run_manual_episode(base_url: str, task_id: str) -> None:
    """Walk through a full episode step-by-step to demonstrate the API."""
    with httpx.Client(timeout=30) as client:
        print(f"=== Episode Demo: {task_id} ===\n")

        # 1. Reset — start a fresh episode for the chosen task.
        response = client.post(f"{base_url}/reset", json={"task_id": task_id})
        response.raise_for_status()
        payload = response.json()

        print(f"Task : {payload['task_description'][:120]}...")
        print(f"Files : {list(payload['code_files'].keys())}")
        print(f"Max steps : {payload['max_steps']}")
        print(f"Score : {payload['current_score']}")
        print()

        # 2. Flag a known issue for this task.
        action = _known_flag_action(task_id)
        print(f"Step 1 — flag_issue at {action['filename']}:{action['line_number']}")
        response = client.post(f"{base_url}/step", json=action)
        response.raise_for_status()
        payload = response.json()
        print(f" Feedback : {payload['feedback']}")
        print(f" Reward : {payload['reward']}")
        print(f" Score : {payload['current_score']}")
        print()

        # 3. Request a hint (costs a small reward penalty).
        print("Step 2 — request_hint")
        response = client.post(f"{base_url}/step", json={"action_type": "request_hint"})
        response.raise_for_status()
        payload = response.json()
        print(f" Feedback : {payload['feedback']}")
        print()

        # 4. Submit the review for final grading.
        print("Step 3 — submit_review")
        response = client.post(f"{base_url}/step", json={"action_type": "submit_review"})
        response.raise_for_status()
        payload = response.json()
        print(f" Feedback : {payload['feedback']}")
        print(f" Final score : {payload['reward']:.4f}")
        print(f" Done : {payload['done']}")
        print()

        # 5. Inspect the post-episode server state.
        state_resp = client.get(f"{base_url}/state")
        state_resp.raise_for_status()
        s = state_resp.json()
        print(f"State — episode_id: {s['episode_id']}, steps: {s['step_count']}, submitted: {s['submitted']}")
|
| 131 |
+
|
| 132 |
+
|
| 133 |
+
def main():
    """CLI entry point: parse arguments and run the chosen demo mode."""
    parser = argparse.ArgumentParser(description="Code Review Environment demo")
    parser.add_argument("--url", default=DEFAULT_URL, help="Environment base URL")
    parser.add_argument("--task", default="bug-detection", choices=TASKS)
    parser.add_argument("--baseline", action="store_true", help="Run full baseline on all tasks")
    args = parser.parse_args()

    base_url = args.url.rstrip("/")
    print("Code Review Environment — Demo")
    print(f" URL : {base_url}")
    print(f" Task : {args.task}\n")

    if args.baseline:
        print("Running keyword-heuristic baseline on all tasks...\n")
        print(json.dumps(run_keyword_agent(base_url, args.task), indent=2))
    else:
        run_manual_episode(base_url, args.task)


if __name__ == "__main__":
    main()
|
inference.py
ADDED
|
@@ -0,0 +1,231 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Inference script for the Code Review Environment.
|
| 3 |
+
|
| 4 |
+
Environment variables:
|
| 5 |
+
API_BASE_URL — LLM API endpoint (e.g. https://openrouter.ai/api/v1)
|
| 6 |
+
MODEL_NAME — Model identifier (e.g. openai/gpt-4o-mini)
|
| 7 |
+
HF_TOKEN — API key for the LLM provider
|
| 8 |
+
ENV_URL — Environment base URL (default: localhost:7860)
|
| 9 |
+
|
| 10 |
+
Usage:
|
| 11 |
+
export API_BASE_URL=https://openrouter.ai/api/v1
|
| 12 |
+
export MODEL_NAME=openai/gpt-4o-mini
|
| 13 |
+
export HF_TOKEN=sk-...
|
| 14 |
+
python inference.py
|
| 15 |
+
"""
|
| 16 |
+
from __future__ import annotations
|
| 17 |
+
|
| 18 |
+
import os
|
| 19 |
+
import sys
|
| 20 |
+
import json
|
| 21 |
+
import time
|
| 22 |
+
|
| 23 |
+
import httpx
|
| 24 |
+
|
| 25 |
+
API_BASE_URL: str = os.environ.get("API_BASE_URL", "").rstrip("/")
|
| 26 |
+
MODEL_NAME: str = os.environ.get("MODEL_NAME", "gpt-4o-mini")
|
| 27 |
+
HF_TOKEN: str = os.environ.get("HF_TOKEN", "")
|
| 28 |
+
ENV_URL: str = os.environ.get("ENV_URL", "http://localhost:7860").rstrip("/")
|
| 29 |
+
|
| 30 |
+
TASK_IDS = ["bug-detection", "security-audit", "comprehensive-review"]
|
| 31 |
+
|
| 32 |
+
SYSTEM_PROMPT = """\
|
| 33 |
+
You are an expert software engineer performing a thorough code review.
|
| 34 |
+
|
| 35 |
+
Your job is to identify bugs, security vulnerabilities, and performance issues in code.
|
| 36 |
+
|
| 37 |
+
For each issue you find, respond with a single JSON object:
|
| 38 |
+
{"action_type": "flag_issue", "line_number": <int>, "filename": "<file>", "issue_type": "bug|security|performance|logic", "severity": "low|medium|high|critical", "description": "<explanation>", "fix_suggestion": "<fix>"}
|
| 39 |
+
|
| 40 |
+
When done, respond with:
|
| 41 |
+
{"action_type": "submit_review"}
|
| 42 |
+
|
| 43 |
+
Rules:
|
| 44 |
+
- Respond with raw JSON only — no markdown fences, no extra text
|
| 45 |
+
- One action per response
|
| 46 |
+
- Be precise with line numbers (count from line 1)
|
| 47 |
+
- Only flag real issues, not style preferences
|
| 48 |
+
"""
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
def chat_completion(messages: list) -> str:
    """Send *messages* to the configured chat-completion endpoint.

    Uses the module-level MODEL_NAME / API_BASE_URL / HF_TOKEN settings.

    Args:
        messages: OpenAI-style chat messages (list of role/content dicts).

    Returns:
        The assistant reply text with surrounding whitespace stripped, or
        an empty string when the provider returned no content.

    Raises:
        ImportError: if the ``openai`` package is not installed.
    """
    try:
        from openai import OpenAI
    except ImportError as exc:
        # Chain the original exception so the real import failure stays visible.
        raise ImportError(
            "The 'openai' package is required for LLM inference: pip install openai"
        ) from exc

    kwargs = {"api_key": HF_TOKEN or "no-key"}
    if API_BASE_URL:
        kwargs["base_url"] = API_BASE_URL

    client = OpenAI(**kwargs)
    response = client.chat.completions.create(
        model=MODEL_NAME,
        messages=messages,
        temperature=0.0,
        max_tokens=400,
    )
    # message.content can legitimately be None (e.g. tool-call-only replies);
    # fall back to "" instead of crashing on .strip().
    content = response.choices[0].message.content
    return (content or "").strip()
|
| 69 |
+
|
| 70 |
+
|
| 71 |
+
def parse_action(text: str) -> dict:
    """Extract the first JSON action object from an LLM reply.

    Tolerates markdown code fences and leading/trailing noise.  If the
    reply contains a JSON list, the first dict element is returned.  When
    nothing parseable is found, a safe ``submit_review`` action is
    returned so the episode can end gracefully.
    """
    candidate = text.strip()

    # Prefer the contents of a fenced code block when one is present.
    if "```" in candidate:
        for segment in candidate.split("```"):
            segment = segment.strip()
            if segment.startswith("json"):
                # Drop a leading "json" language tag.
                segment = segment[4:].strip()
            if segment[:1] in ("{", "["):
                candidate = segment
                break

    # Scan left-to-right for the first position where valid JSON decodes.
    decoder = json.JSONDecoder()
    for start, char in enumerate(candidate):
        if char not in ("{", "["):
            continue
        try:
            parsed, _ = decoder.raw_decode(candidate, start)
        except json.JSONDecodeError:
            continue
        if isinstance(parsed, dict):
            return parsed
        if isinstance(parsed, list):
            for element in parsed:
                if isinstance(element, dict):
                    return element

    # Nothing usable — default to ending the episode.
    return {"action_type": "submit_review"}
|
| 99 |
+
|
| 100 |
+
|
| 101 |
+
def run_keyword_fallback(base_url: str, task_id: str) -> dict:
    """Fallback: use the built-in /baseline endpoint (no LLM needed)."""
    with httpx.Client(timeout=30) as http:
        response = http.post(f"{base_url}/baseline")
        response.raise_for_status()
        payload = response.json()
    task_entry = payload["baseline_scores"].get(task_id, {})
    return {
        "task_id": task_id,
        "score": task_entry.get("score", 0.0),
        "steps": 0,
        "method": "keyword_heuristic",
    }
|
| 109 |
+
|
| 110 |
+
|
| 111 |
+
def run_task(task_id: str, http_client: httpx.Client) -> dict:
    """Run one LLM-driven review episode against the environment.

    Resets the environment to *task_id*, then loops: ask the LLM for the
    next action, POST it to /step, and feed the environment's feedback
    back into the conversation until the episode ends, the review is
    submitted, or max_steps is exhausted.

    Args:
        task_id: Environment task identifier (one of TASK_IDS).
        http_client: Shared httpx client; requests go to the module-level ENV_URL.

    Returns:
        Dict with ``task_id``, final ``score``, ``steps`` taken, and ``method="llm"``.
    """
    resp = http_client.post(f"{ENV_URL}/reset", json={"task_id": task_id}, timeout=30)
    resp.raise_for_status()
    obs = resp.json()

    # Concatenate every code file into one labelled prompt section.
    code_display = "\n\n".join(
        f"=== {fname} ===\n{code}"
        for fname, code in obs.get("code_files", {}).items()
    )

    messages = [
        {"role": "system", "content": SYSTEM_PROMPT},
        {
            "role": "user",
            "content": (
                f"Task: {obs.get('task_description', '')}\n\n"
                f"{code_display}\n\n"
                f"Review this code carefully. Flag every issue you find. "
                f"You have {obs.get('max_steps', 20)} steps total."
            ),
        },
    ]

    done = False
    step_count = 0
    max_steps = obs.get("max_steps", 20)
    final_score = 0.0

    while not done and step_count < max_steps:
        action_text = chat_completion(messages)
        action = parse_action(action_text)

        try:
            step_resp = http_client.post(f"{ENV_URL}/step", json=action, timeout=30)
            step_resp.raise_for_status()
            obs = step_resp.json()
        except Exception as e:
            # Abort on transport/server errors; the partial score is kept.
            print(f" Step error: {e}")
            break

        done = obs.get("done", False)
        step_count += 1
        final_score = obs.get("current_score", 0.0)
        reward = obs.get("reward")

        # Append both sides of the exchange so the model sees its history
        # and the environment's feedback on each action.
        messages.append({"role": "assistant", "content": action_text})
        messages.append({
            "role": "user",
            "content": (
                f"Feedback: {obs.get('feedback', '')} "
                f"(step {step_count}/{max_steps}, score: {obs.get('current_score', 0.0):.3f})"
            ),
        })

        atype = action.get("action_type", "")
        print(f" Step {step_count:2d}: {atype:20s} | reward={str(reward):8s} | score={obs.get('current_score', 0.0):.3f}")

        if atype == "submit_review":
            # On submit, the episode reward is the final grade; `or 0.0`
            # coalesces a missing/None reward to zero.
            final_score = obs.get("reward", obs.get("current_score", 0.0)) or 0.0
            break

        # Gentle pacing between steps to avoid hammering the server.
        time.sleep(0.3)

    return {
        "task_id": task_id,
        "score": float(final_score),
        "steps": step_count,
        "method": "llm",
    }
|
| 180 |
+
|
| 181 |
+
|
| 182 |
+
def main():
    """Entry point: run every task and print a score summary.

    Uses LLM-driven inference when both HF_TOKEN and API_BASE_URL are set;
    otherwise falls back to the server's deterministic keyword-heuristic
    baseline endpoint.

    Returns:
        Mapping of task_id -> result dict (score, steps, method).
    """
    use_llm = bool(HF_TOKEN and API_BASE_URL)

    print("Code Review Environment — Inference")
    print(f" Model : {MODEL_NAME}")
    print(f" API URL : {API_BASE_URL or '(not set — using keyword heuristic)'}")
    print(f" Env URL : {ENV_URL}")
    print(f" Tasks : {TASK_IDS}\n")

    # Fail fast if the environment server is unreachable.
    try:
        with httpx.Client(timeout=10) as probe:
            health = probe.get(f"{ENV_URL}/health")
            health.raise_for_status()
            print(f" Health: {health.json()}\n")
    except Exception as e:
        print(f"ERROR: Cannot reach environment at {ENV_URL}: {e}")
        sys.exit(1)

    results = {}

    if use_llm:
        # One shared client across tasks; generous timeout for LLM latency.
        with httpx.Client(timeout=60) as client:
            for task_id in TASK_IDS:
                print(f"Running task: {task_id}")
                result = run_task(task_id, client)
                results[task_id] = result
                print(f" → score: {result['score']:.4f} ({result['steps']} steps)\n")
    else:
        print("HF_TOKEN / API_BASE_URL not set — using built-in keyword heuristic baseline.\n")
        for task_id in TASK_IDS:
            print(f"Running task: {task_id}")
            result = run_keyword_fallback(ENV_URL, task_id)
            results[task_id] = result
            print(f" → score: {result['score']:.4f}\n")

    print("=" * 50)
    print("INFERENCE RESULTS")
    print("=" * 50)
    for task_id, r in results.items():
        print(f" {task_id:30s} score={r['score']:.4f}")

    # Unweighted mean across all tasks.
    overall = sum(r["score"] for r in results.values()) / len(results)
    print(f"\n Overall average: {overall:.4f}")
    print("=" * 50)

    return results
|
| 228 |
+
|
| 229 |
+
|
| 230 |
+
if __name__ == "__main__":
|
| 231 |
+
main()
|
models.py
ADDED
|
@@ -0,0 +1,184 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
import sys
|
| 4 |
+
import os
|
| 5 |
+
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
|
| 6 |
+
|
| 7 |
+
from dataclasses import dataclass, field
|
| 8 |
+
from typing import List, Optional, Dict, Any
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
@dataclass
class Issue:
    """A single code-review finding tied to a specific file and line."""

    line_number: int
    filename: str
    issue_type: str  # bug | security | performance | logic
    severity: str  # low | medium | high | critical
    description: str = ""
    fix_suggestion: Optional[str] = None

    def to_dict(self) -> dict:
        """Serialize this finding to a plain JSON-compatible dict."""
        keys = (
            "line_number",
            "filename",
            "issue_type",
            "severity",
            "description",
            "fix_suggestion",
        )
        return {key: getattr(self, key) for key in keys}

    @classmethod
    def from_dict(cls, d: dict) -> "Issue":
        """Build an Issue from a possibly-partial dict, coercing field types."""
        return cls(
            line_number=int(d.get("line_number", 0)),
            filename=str(d.get("filename", "")),
            issue_type=str(d.get("issue_type", "bug")),
            severity=str(d.get("severity", "medium")),
            description=str(d.get("description", "")),
            fix_suggestion=d.get("fix_suggestion"),
        )
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
# Prefer the real OpenEnv base classes when openenv-core is installed;
# otherwise define minimal stand-ins with the same attribute surface so
# this module still imports (e.g. for local testing without openenv).
try:
    from openenv.core.env_server import (
        Action as _BaseAction,
        Observation as _BaseObservation,
        State as _BaseState,
    )
except ImportError:
    @dataclass
    class _BaseAction:
        # Free-form transport-level metadata attached to an action.
        metadata: Dict[str, Any] = field(default_factory=dict)

    @dataclass
    class _BaseObservation:
        done: bool = False
        reward: Optional[float] = None
        metadata: Dict[str, Any] = field(default_factory=dict)

    @dataclass
    class _BaseState:
        episode_id: Optional[str] = None
        step_count: int = 0
|
| 63 |
+
|
| 64 |
+
|
| 65 |
+
@dataclass
class ReviewAction(_BaseAction):
    """
    Agent action during a code review episode.

    action_type:
        flag_issue — mark a line as containing an issue
        clear_flag — remove a previously flagged issue
        request_hint — get a hint (-0.01 reward)
        submit_review — end the episode and receive final grade

    Fields other than ``action_type`` are only meaningful for
    flag_issue / clear_flag: they identify and describe the issue.
    """
    action_type: str = "flag_issue"
    line_number: Optional[int] = None  # 1-indexed line being flagged
    filename: Optional[str] = None
    issue_type: Optional[str] = None  # bug | security | performance | logic
    severity: Optional[str] = None  # low | medium | high | critical
    description: str = ""
    fix_suggestion: Optional[str] = None
    metadata: Dict[str, Any] = field(default_factory=dict)

    def to_dict(self) -> dict:
        """Serialize to the wire-format action payload.

        NOTE(review): ``metadata`` is not serialized here and is not read
        back by from_dict — confirm it is meant to be transport-only.
        """
        return {
            "action_type": self.action_type,
            "line_number": self.line_number,
            "filename": self.filename,
            "issue_type": self.issue_type,
            "severity": self.severity,
            "description": self.description,
            "fix_suggestion": self.fix_suggestion,
        }

    @classmethod
    def from_dict(cls, d: dict) -> "ReviewAction":
        """Parse an action dict from the HTTP/WS layer; missing fields get defaults."""
        return cls(
            action_type=str(d.get("action_type", "flag_issue")),
            line_number=d.get("line_number"),
            filename=d.get("filename"),
            issue_type=d.get("issue_type"),
            severity=d.get("severity"),
            description=str(d.get("description", "")),
            fix_suggestion=d.get("fix_suggestion"),
        )
|
| 107 |
+
|
| 108 |
+
|
| 109 |
+
@dataclass
class ReviewObservation(_BaseObservation):
    """
    Observation returned after each reset/step call.
    code_files is only populated on reset; subsequent steps omit it.
    """
    task_id: str = ""
    task_description: str = ""
    code_files: Dict[str, str] = field(default_factory=dict)
    language: str = "python"
    flagged_issues: List[Issue] = field(default_factory=list)
    step_count: int = 0
    max_steps: int = 20
    hints_remaining: int = 3
    feedback: str = ""  # human-readable feedback on the last action
    current_score: float = 0.0
    done: bool = False
    reward: Optional[float] = None
    metadata: Dict[str, Any] = field(default_factory=dict)

    def to_dict(self) -> dict:
        """Serialize to a JSON-compatible dict (nested Issues become dicts)."""
        return {
            "task_id": self.task_id,
            "task_description": self.task_description,
            "code_files": self.code_files,
            "language": self.language,
            "flagged_issues": [i.to_dict() for i in self.flagged_issues],
            "step_count": self.step_count,
            "max_steps": self.max_steps,
            "hints_remaining": self.hints_remaining,
            "feedback": self.feedback,
            "current_score": self.current_score,
            "done": self.done,
            "reward": self.reward,
            "metadata": self.metadata,
        }

    @classmethod
    def from_dict(cls, d: dict) -> "ReviewObservation":
        """Inverse of to_dict; tolerates missing keys.

        Fix: previously the ``metadata`` key was dropped, so a
        to_dict/from_dict round-trip silently lost metadata even though
        to_dict emits it.
        """
        return cls(
            task_id=d.get("task_id", ""),
            task_description=d.get("task_description", ""),
            code_files=d.get("code_files", {}),
            language=d.get("language", "python"),
            flagged_issues=[Issue.from_dict(i) for i in d.get("flagged_issues", [])],
            step_count=d.get("step_count", 0),
            max_steps=d.get("max_steps", 20),
            hints_remaining=d.get("hints_remaining", 3),
            feedback=d.get("feedback", ""),
            current_score=d.get("current_score", 0.0),
            done=d.get("done", False),
            reward=d.get("reward"),
            metadata=d.get("metadata", {}),
        )
|
| 162 |
+
|
| 163 |
+
|
| 164 |
+
@dataclass
class ReviewState(_BaseState):
    """Server-side episode state exposed via the /state endpoint."""
    task_id: str = ""
    difficulty: str = ""
    episode_id: Optional[str] = None
    step_count: int = 0
    flagged_issues: List[Issue] = field(default_factory=list)
    current_score: float = 0.0
    submitted: bool = False  # True once submit_review ended the episode
    metadata: Dict[str, Any] = field(default_factory=dict)

    def to_dict(self) -> dict:
        """Serialize for the /state endpoint.

        NOTE(review): ``metadata`` is intentionally(?) omitted here —
        confirm whether clients are expected to see it.
        """
        return {
            "task_id": self.task_id,
            "difficulty": self.difficulty,
            "episode_id": self.episode_id,
            "step_count": self.step_count,
            "flagged_issues": [i.to_dict() for i in self.flagged_issues],
            "current_score": self.current_score,
            "submitted": self.submitted,
        }
|
openenv.yaml
ADDED
|
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
spec_version: 1
|
| 2 |
+
name: code_review_env
|
| 3 |
+
version: "1.0.0"
|
| 4 |
+
description: >
|
| 5 |
+
A code review and security audit environment for training AI agents.
|
| 6 |
+
The agent identifies bugs, security vulnerabilities, and performance issues
|
| 7 |
+
across three tasks of increasing difficulty (easy → medium → hard).
|
| 8 |
+
type: space
|
| 9 |
+
runtime: fastapi
|
| 10 |
+
app: server.app:app
|
| 11 |
+
port: 7860
|
pyproject.toml
ADDED
|
@@ -0,0 +1,25 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
[project]
|
| 2 |
+
name = "code-review-env"
|
| 3 |
+
version = "1.0.0"
|
| 4 |
+
description = "OpenEnv environment for code review and security audit training"
|
| 5 |
+
requires-python = ">=3.10"
|
| 6 |
+
dependencies = [
|
| 7 |
+
"fastapi>=0.100.0",
|
| 8 |
+
"uvicorn[standard]>=0.23.0",
|
| 9 |
+
"pydantic>=2.0.0",
|
| 10 |
+
"httpx>=0.24.0",
|
| 11 |
+
"openai>=1.0.0",
|
| 12 |
+
"python-dotenv>=1.0.0",
|
| 13 |
+
"websockets>=11.0",
|
| 14 |
+
"openenv-core>=0.2.0",
|
| 15 |
+
]
|
| 16 |
+
|
| 17 |
+
[project.optional-dependencies]
|
| 18 |
+
dev = ["pytest>=7.0", "pytest-asyncio"]
|
| 19 |
+
|
| 20 |
+
[project.scripts]
|
| 21 |
+
serve = "server.app:main"
|
| 22 |
+
|
| 23 |
+
[build-system]
|
| 24 |
+
requires = ["setuptools>=68"]
|
| 25 |
+
build-backend = "setuptools.build_meta"
|
requirements.txt
ADDED
|
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
fastapi>=0.100.0
|
| 2 |
+
uvicorn[standard]>=0.23.0
|
| 3 |
+
pydantic>=2.0.0
|
| 4 |
+
httpx>=0.24.0
|
| 5 |
+
openai>=1.0.0
|
| 6 |
+
python-dotenv>=1.0.0
|
| 7 |
+
websockets>=11.0
|
server/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
# server package
|
server/app.py
ADDED
|
@@ -0,0 +1,306 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
FastAPI application for the Code Review Environment.
|
| 3 |
+
|
| 4 |
+
Endpoints:
|
| 5 |
+
POST /reset — start new episode
|
| 6 |
+
POST /step — take an action
|
| 7 |
+
GET /state — get episode state
|
| 8 |
+
GET /health — health check
|
| 9 |
+
GET /tasks — list all tasks + action schema
|
| 10 |
+
POST /grader — grade a set of findings (stateless)
|
| 11 |
+
POST /baseline — run keyword-heuristic baseline on all tasks
|
| 12 |
+
WS /ws — persistent WebSocket session
|
| 13 |
+
GET /docs — Swagger UI (auto-generated)
|
| 14 |
+
"""
|
| 15 |
+
from __future__ import annotations
|
| 16 |
+
|
| 17 |
+
import sys
|
| 18 |
+
import os
|
| 19 |
+
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
| 20 |
+
|
| 21 |
+
import json
|
| 22 |
+
import asyncio
|
| 23 |
+
import dataclasses
|
| 24 |
+
from typing import Optional, List, Dict, Any
|
| 25 |
+
|
| 26 |
+
from fastapi import FastAPI, WebSocket, WebSocketDisconnect, HTTPException
|
| 27 |
+
from fastapi.middleware.cors import CORSMiddleware
|
| 28 |
+
from pydantic import BaseModel
|
| 29 |
+
|
| 30 |
+
from models import ReviewAction, Issue
|
| 31 |
+
from server.environment import CodeReviewEnvironment
|
| 32 |
+
from server.graders import grade_episode, run_keyword_baseline
|
| 33 |
+
from tasks.data import ALL_TASKS, TASK_IDS
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
def _serialize(obj) -> dict:
    """Convert a dataclass instance or plain dict into a JSON-safe dict.

    Raises:
        TypeError: for anything that is neither a dict nor a dataclass instance.
    """
    if isinstance(obj, dict):
        return obj
    if dataclasses.is_dataclass(obj) and not isinstance(obj, type):
        # asdict recurses into nested dataclasses and containers.
        return dataclasses.asdict(obj)
    raise TypeError(f"Cannot serialize {type(obj)}")
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
_env_instance = CodeReviewEnvironment()
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
def _make_app() -> FastAPI:
    """Build the FastAPI application.

    Prefers openenv-core's app factory when it is installed; otherwise
    constructs a local FastAPI app with /health, /reset, /step, /state and
    a /ws WebSocket endpoint wired to the environment.
    """
    try:
        from openenv.core.env_server import create_fastapi_app
        base = create_fastapi_app(CodeReviewEnvironment)
        return base
    except Exception:
        # openenv not available (or factory failed) — fall through to the
        # hand-rolled app below.
        pass

    _app = FastAPI(
        title="Code Review Environment",
        description=(
            "An OpenEnv environment for training AI agents to perform "
            "code review and security audits."
        ),
        version="1.0.0",
    )

    # Permissive CORS so browser-based clients can call the Space directly.
    _app.add_middleware(
        CORSMiddleware,
        allow_origins=["*"],
        allow_methods=["*"],
        allow_headers=["*"],
    )

    @_app.get("/health")
    async def health():
        """Liveness probe."""
        return {"status": "healthy"}

    @_app.post("/reset")
    async def reset(body: dict = None):
        """Start a new episode on the shared module-level environment."""
        body = body or {}
        task_id = body.get("task_id")
        seed = body.get("seed")
        episode_id = body.get("episode_id")
        obs = _env_instance.reset(task_id=task_id, seed=seed, episode_id=episode_id)
        return _serialize(obs)

    @_app.post("/step")
    async def step(body: dict):
        """Apply one ReviewAction to the shared environment and return the observation."""
        action = ReviewAction.from_dict(body)
        obs = _env_instance.step(action)
        return _serialize(obs)

    @_app.get("/state")
    async def state():
        """Return the shared environment's current episode state."""
        return _serialize(_env_instance.state)

    @_app.websocket("/ws")
    async def websocket_endpoint(websocket: WebSocket):
        """Persistent session: each connection gets its OWN environment instance,
        unlike the HTTP routes which share one global environment."""
        await websocket.accept()
        ws_env = CodeReviewEnvironment()
        try:
            while True:
                raw = await websocket.receive_text()
                msg = json.loads(raw)
                msg_type = msg.get("type", "")

                if msg_type == "reset":
                    data = msg.get("data", {})
                    obs = ws_env.reset(
                        task_id=data.get("task_id"),
                        seed=data.get("seed"),
                        episode_id=data.get("episode_id"),
                    )
                    await websocket.send_text(json.dumps({
                        "type": "observation",
                        "data": _serialize(obs),
                    }))

                elif msg_type == "step":
                    action = ReviewAction.from_dict(msg.get("data", {}))
                    obs = ws_env.step(action)
                    await websocket.send_text(json.dumps({
                        "type": "observation",
                        "data": _serialize(obs),
                    }))

                elif msg_type == "state":
                    await websocket.send_text(json.dumps({
                        "type": "state",
                        "data": _serialize(ws_env.state),
                    }))

                elif msg_type == "close":
                    break

                else:
                    await websocket.send_text(json.dumps({
                        "type": "error",
                        "data": f"Unknown message type: {msg_type}",
                    }))

        except WebSocketDisconnect:
            # Client disconnected; nothing to clean up.
            pass
        except Exception as e:
            # Best-effort error report — the socket may already be closed.
            try:
                await websocket.send_text(json.dumps({"type": "error", "data": str(e)}))
            except Exception:
                pass

    return _app
|
| 151 |
+
|
| 152 |
+
|
| 153 |
+
app = _make_app()
|
| 154 |
+
|
| 155 |
+
|
| 156 |
+
@app.get("/tasks")
|
| 157 |
+
async def list_tasks():
|
| 158 |
+
tasks_list = []
|
| 159 |
+
for task in ALL_TASKS.values():
|
| 160 |
+
tasks_list.append({
|
| 161 |
+
"task_id": task["task_id"],
|
| 162 |
+
"difficulty": task["difficulty"],
|
| 163 |
+
"description": task["description"],
|
| 164 |
+
"language": task.get("language", "python"),
|
| 165 |
+
"max_steps": task["max_steps"],
|
| 166 |
+
"num_issues": len(task["ground_truth_issues"]),
|
| 167 |
+
"files": list(task["code_files"].keys()),
|
| 168 |
+
})
|
| 169 |
+
|
| 170 |
+
action_schema = {
|
| 171 |
+
"type": "object",
|
| 172 |
+
"description": "ReviewAction — one action per /step call",
|
| 173 |
+
"required": ["action_type"],
|
| 174 |
+
"properties": {
|
| 175 |
+
"action_type": {
|
| 176 |
+
"type": "string",
|
| 177 |
+
"enum": ["flag_issue", "clear_flag", "request_hint", "submit_review"],
|
| 178 |
+
"description": (
|
| 179 |
+
"flag_issue: mark a line as problematic. "
|
| 180 |
+
"clear_flag: remove a previous flag. "
|
| 181 |
+
"request_hint: get a hint (-0.01 reward). "
|
| 182 |
+
"submit_review: end episode and receive final grade."
|
| 183 |
+
),
|
| 184 |
+
},
|
| 185 |
+
"line_number": {
|
| 186 |
+
"type": "integer",
|
| 187 |
+
"description": "Line number of the issue (required for flag_issue / clear_flag)",
|
| 188 |
+
},
|
| 189 |
+
"filename": {
|
| 190 |
+
"type": "string",
|
| 191 |
+
"description": "File where the issue is (required for flag_issue / clear_flag)",
|
| 192 |
+
},
|
| 193 |
+
"issue_type": {
|
| 194 |
+
"type": "string",
|
| 195 |
+
"enum": ["bug", "security", "performance", "logic"],
|
| 196 |
+
"description": "Category of issue (required for flag_issue)",
|
| 197 |
+
},
|
| 198 |
+
"severity": {
|
| 199 |
+
"type": "string",
|
| 200 |
+
"enum": ["low", "medium", "high", "critical"],
|
| 201 |
+
"description": "Severity level (required for flag_issue)",
|
| 202 |
+
},
|
| 203 |
+
"description": {
|
| 204 |
+
"type": "string",
|
| 205 |
+
"description": "Human-readable description of the issue",
|
| 206 |
+
},
|
| 207 |
+
"fix_suggestion": {
|
| 208 |
+
"type": "string",
|
| 209 |
+
"description": "Optional suggested fix",
|
| 210 |
+
},
|
| 211 |
+
},
|
| 212 |
+
"examples": [
|
| 213 |
+
{
|
| 214 |
+
"action_type": "flag_issue",
|
| 215 |
+
"line_number": 6,
|
| 216 |
+
"filename": "utils.py",
|
| 217 |
+
"issue_type": "bug",
|
| 218 |
+
"severity": "high",
|
| 219 |
+
"description": "Off-by-one error in range()",
|
| 220 |
+
"fix_suggestion": "Change range(len(numbers) + 1) to range(len(numbers))",
|
| 221 |
+
},
|
| 222 |
+
{"action_type": "submit_review"},
|
| 223 |
+
],
|
| 224 |
+
}
|
| 225 |
+
|
| 226 |
+
return {
|
| 227 |
+
"tasks": tasks_list,
|
| 228 |
+
"action_schema": action_schema,
|
| 229 |
+
"total_tasks": len(tasks_list),
|
| 230 |
+
}
|
| 231 |
+
|
| 232 |
+
|
| 233 |
+
class GraderRequest(BaseModel):
    """Request body for POST /grader: findings to score against one task."""
    # Must be one of TASK_IDS.
    task_id: str
    # Issue-shaped dicts (see Issue.to_dict) to grade against ground truth.
    flagged_issues: List[Dict[str, Any]]
|
| 236 |
+
|
| 237 |
+
@app.post("/grader")
|
| 238 |
+
async def run_grader(request: GraderRequest):
|
| 239 |
+
task = ALL_TASKS.get(request.task_id)
|
| 240 |
+
if not task:
|
| 241 |
+
raise HTTPException(
|
| 242 |
+
status_code=404,
|
| 243 |
+
detail=f"Unknown task_id '{request.task_id}'. Valid: {TASK_IDS}",
|
| 244 |
+
)
|
| 245 |
+
|
| 246 |
+
flagged = [Issue.from_dict(i) for i in request.flagged_issues]
|
| 247 |
+
ground_truth = [Issue.from_dict(gt) for gt in task["ground_truth_issues"]]
|
| 248 |
+
score = grade_episode(flagged, ground_truth)
|
| 249 |
+
|
| 250 |
+
tp = sum(
|
| 251 |
+
1 for f in flagged
|
| 252 |
+
if any(
|
| 253 |
+
True for gt in ground_truth
|
| 254 |
+
if abs(f.line_number - gt.line_number) <= 2
|
| 255 |
+
and f.filename == gt.filename
|
| 256 |
+
)
|
| 257 |
+
)
|
| 258 |
+
|
| 259 |
+
return {
|
| 260 |
+
"task_id": request.task_id,
|
| 261 |
+
"difficulty": task["difficulty"],
|
| 262 |
+
"score": score,
|
| 263 |
+
"max_score": 1.0,
|
| 264 |
+
"details": {
|
| 265 |
+
"total_flagged": len(flagged),
|
| 266 |
+
"true_positives": tp,
|
| 267 |
+
"false_positives": len(flagged) - tp,
|
| 268 |
+
"total_ground_truth": len(ground_truth),
|
| 269 |
+
},
|
| 270 |
+
}
|
| 271 |
+
|
| 272 |
+
|
| 273 |
+
@app.post("/baseline")
|
| 274 |
+
async def run_baseline():
|
| 275 |
+
results = {}
|
| 276 |
+
for task_id, task in ALL_TASKS.items():
|
| 277 |
+
findings = run_keyword_baseline(task)
|
| 278 |
+
ground_truth = [Issue.from_dict(gt) for gt in task["ground_truth_issues"]]
|
| 279 |
+
score = grade_episode(findings, ground_truth)
|
| 280 |
+
results[task_id] = {
|
| 281 |
+
"difficulty": task["difficulty"],
|
| 282 |
+
"score": score,
|
| 283 |
+
"findings_count": len(findings),
|
| 284 |
+
"ground_truth_count": len(ground_truth),
|
| 285 |
+
}
|
| 286 |
+
|
| 287 |
+
overall = sum(r["score"] for r in results.values()) / len(results)
|
| 288 |
+
return {
|
| 289 |
+
"baseline_scores": results,
|
| 290 |
+
"overall_average": round(overall, 4),
|
| 291 |
+
"method": "keyword_heuristic",
|
| 292 |
+
"note": (
|
| 293 |
+
"Run 'python baseline.py' with OPENAI_API_KEY for the LLM-based baseline. "
|
| 294 |
+
"This endpoint uses a deterministic regex heuristic."
|
| 295 |
+
),
|
| 296 |
+
}
|
| 297 |
+
|
| 298 |
+
|
| 299 |
+
def main():
    """Launch the API server with uvicorn (honours the PORT env var, default 7860)."""
    import uvicorn

    uvicorn.run(
        "server.app:app",
        host="0.0.0.0",
        port=int(os.environ.get("PORT", 7860)),
    )
|
| 303 |
+
|
| 304 |
+
|
| 305 |
+
if __name__ == "__main__":
|
| 306 |
+
main()
|
server/environment.py
ADDED
|
@@ -0,0 +1,310 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Core environment logic for the Code Review Environment.
|
| 3 |
+
"""
|
| 4 |
+
from __future__ import annotations
|
| 5 |
+
|
| 6 |
+
import random
|
| 7 |
+
import uuid
|
| 8 |
+
import sys
|
| 9 |
+
import os
|
| 10 |
+
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
| 11 |
+
|
| 12 |
+
from typing import Optional, List
|
| 13 |
+
|
| 14 |
+
from models import Issue, ReviewAction, ReviewObservation, ReviewState
|
| 15 |
+
from tasks.data import ALL_TASKS, TASK_IDS
|
| 16 |
+
from server.graders import grade_episode, compute_live_score, match_issue
|
| 17 |
+
|
| 18 |
+
# Prefer the real OpenEnv base class when the package is installed; fall back
# to a no-op stand-in so the environment can run (and be unit-tested) without
# the openenv dependency. _HAS_OPENENV records which branch was taken.
try:
    from openenv.core.env_server import Environment as _BaseEnv

    _HAS_OPENENV = True
except ImportError:
    _HAS_OPENENV = False

    class _BaseEnv:  # type: ignore[no-redef]
        # Minimal placeholder so CodeReviewEnvironment's inheritance works
        # even when openenv is absent.
        pass
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
class CodeReviewEnvironment(_BaseEnv):
    """
    A code review and security audit environment.

    The agent receives code files and must identify bugs, security
    vulnerabilities, and performance issues by flagging them with
    exact line numbers, types, and severity ratings.

    Episode flow:
      1. reset(task_id) — agent sees code files and task description
      2. step(flag_issue) — flag a problem; get per-step reward
      3. step(clear_flag) — remove an incorrectly flagged issue
      4. step(request_hint) — get a hint (costs -0.01 reward)
      5. step(submit_review) — episode ends, final grade is returned
         (or auto-ends when max_steps is reached)
    """

    SUPPORTS_CONCURRENT_SESSIONS = False

    def __init__(self) -> None:
        self._state = ReviewState()
        # Active task definition (one entry of ALL_TASKS); None before reset().
        self._task: Optional[dict] = None
        # Ground-truth issues the agent is expected to find for the active task.
        self._ground_truth: List[Issue] = []
        # Number of hints already handed out in this episode.
        self._hint_index: int = 0

    def reset(
        self,
        task_id: Optional[str] = None,
        seed: Optional[int] = None,
        episode_id: Optional[str] = None,
        **kwargs,
    ) -> ReviewObservation:
        """Start a new review episode.

        Args:
            task_id: Task to load; unknown or missing ids fall back to a
                random task from TASK_IDS.
            seed: Makes the random task choice deterministic.
            episode_id: Caller-supplied id; a fresh UUID4 when omitted.

        Returns:
            The initial observation containing the code files to review.
        """
        # Use an instance-local RNG when seeded so we never perturb the global
        # `random` module state; random.Random(seed).choice(...) produces the
        # same pick as random.seed(seed); random.choice(...).
        rng = random.Random(seed) if seed is not None else random

        if task_id is None or task_id not in ALL_TASKS:
            task_id = rng.choice(TASK_IDS)

        self._task = ALL_TASKS[task_id]
        self._ground_truth = [
            Issue.from_dict(gt)
            for gt in self._task["ground_truth_issues"]
        ]
        self._hint_index = 0

        self._state = ReviewState(
            task_id=task_id,
            difficulty=self._task["difficulty"],
            episode_id=episode_id or str(uuid.uuid4()),
            step_count=0,
            flagged_issues=[],
            current_score=0.0,
            submitted=False,
        )

        return ReviewObservation(
            task_id=task_id,
            task_description=self._task["description"],
            code_files=self._task["code_files"],
            language=self._task.get("language", "python"),
            flagged_issues=[],
            step_count=0,
            max_steps=self._task["max_steps"],
            hints_remaining=len(self._task.get("hints", [])),
            feedback=(
                f"New episode started. Task: {self._task['difficulty'].upper()}. "
                f"Review the code carefully and flag all issues you find. "
                f"Use 'submit_review' when done."
            ),
            current_score=0.0,
            done=False,
            reward=None,
        )

    def step(
        self,
        action: ReviewAction,
        timeout_s: Optional[float] = None,
        **kwargs,
    ) -> ReviewObservation:
        """Process one agent action and return the new observation.

        Rewards are per-step; the final grade is delivered when the agent
        submits or when the step budget runs out (auto-grade at half reward).
        """
        if self._task is None:
            return ReviewObservation(
                done=True,
                reward=0.0,
                feedback="Episode not initialized. Call reset() first.",
            )

        if self._state.submitted:
            # Episode already finished: report terminal state, give no reward.
            return ReviewObservation(
                task_id=self._state.task_id,
                task_description="",
                code_files={},
                flagged_issues=list(self._state.flagged_issues),
                step_count=self._state.step_count,
                max_steps=self._task["max_steps"],
                hints_remaining=0,
                feedback="Episode already submitted. Call reset() to start a new episode.",
                current_score=self._state.current_score,
                done=True,
                reward=0.0,
            )

        if isinstance(action, dict):
            # Accept raw dict payloads (e.g. straight from the HTTP layer).
            action = ReviewAction.from_dict(action)

        self._state.step_count += 1
        reward, feedback = self._process_action(action)

        max_steps = self._task["max_steps"]
        # _process_action may have submitted (submit_review); auto-end applies
        # only when the budget ran out without an explicit submission.
        auto_end = self._state.step_count >= max_steps and not self._state.submitted
        done = self._state.submitted or auto_end

        if auto_end:
            # Grade what was flagged so far; only half reward as a penalty
            # for not submitting explicitly.
            final = grade_episode(self._state.flagged_issues, self._ground_truth)
            self._state.current_score = final
            reward = final * 0.5  # partial credit for auto-end
            feedback += (
                f" Max steps reached. Auto-graded: {final:.3f}. "
                f"Submit earlier for best score."
            )
            self._state.submitted = True

        if not self._state.submitted:
            # Live (F1-only) progress indicator. Guarded so it can no longer
            # clobber the final grade set by submit_review / auto-end above.
            self._state.current_score = compute_live_score(
                self._state.flagged_issues, self._ground_truth
            )

        return ReviewObservation(
            task_id=self._state.task_id,
            task_description="",
            code_files={},
            language=self._task.get("language", "python"),
            flagged_issues=list(self._state.flagged_issues),
            step_count=self._state.step_count,
            max_steps=max_steps,
            hints_remaining=max(0, len(self._task.get("hints", [])) - self._hint_index),
            feedback=feedback,
            current_score=self._state.current_score,
            done=done,
            reward=reward,
        )

    @property
    def state(self) -> ReviewState:
        """Current episode state (used by the HTTP layer and tests)."""
        return self._state

    def _process_action(self, action: ReviewAction):
        """Dispatch one action to its handler; returns (reward, feedback)."""
        atype = (action.action_type or "").strip().lower()

        if atype == "flag_issue":
            return self._handle_flag(action)
        elif atype == "clear_flag":
            return self._handle_clear(action)
        elif atype == "request_hint":
            return self._handle_hint()
        elif atype == "submit_review":
            return self._handle_submit()
        else:
            return 0.0, (
                f"Unknown action_type '{action.action_type}'. "
                "Use: flag_issue | clear_flag | request_hint | submit_review"
            )

    def _handle_flag(self, action: ReviewAction):
        """Record a new finding: +0.10 for a true positive, -0.05 otherwise."""
        if action.line_number is None:
            return -0.02, "flag_issue requires 'line_number'."
        if not action.filename:
            return -0.02, "flag_issue requires 'filename'."
        # Coerce unknown categories/severities to safe defaults instead of
        # rejecting the action outright.
        if action.issue_type not in ("bug", "security", "performance", "logic", None):
            action.issue_type = "bug"
        if action.severity not in ("low", "medium", "high", "critical", None):
            action.severity = "medium"

        for existing in self._state.flagged_issues:
            if (existing.line_number == action.line_number
                    and existing.filename == action.filename):
                return 0.0, (
                    f"Line {action.line_number} in {action.filename} already flagged. "
                    "Use clear_flag first if you want to change the finding."
                )

        new_issue = Issue(
            line_number=action.line_number,
            filename=action.filename or "",
            issue_type=action.issue_type or "bug",
            severity=action.severity or "medium",
            description=action.description or "",
            fix_suggestion=action.fix_suggestion,
        )

        is_tp = any(
            match_issue(new_issue, gt)
            for gt in self._ground_truth
        )

        self._state.flagged_issues.append(new_issue)

        if is_tp:
            reward = 0.10
            feedback = (
                f"Good catch! Issue flagged at {action.filename}:{action.line_number}. "
                f"[+0.10 reward — correct finding]"
            )
        else:
            reward = -0.05
            feedback = (
                f"Issue flagged at {action.filename}:{action.line_number}. "
                f"[-0.05 reward — no matching ground-truth issue nearby]"
            )

        return reward, feedback

    def _handle_clear(self, action: ReviewAction):
        """Remove a previously flagged finding at (filename, line_number)."""
        if action.line_number is None or not action.filename:
            return -0.02, "clear_flag requires 'line_number' and 'filename'."

        kept: List[Issue] = []
        removed: List[Issue] = []
        for flagged in self._state.flagged_issues:
            if (flagged.line_number == action.line_number
                    and flagged.filename == action.filename):
                removed.append(flagged)
            else:
                kept.append(flagged)

        if not removed:
            return 0.0, (
                f"No flagged issue found at {action.filename}:{action.line_number}."
            )

        self._state.flagged_issues = kept

        # Judge the flags that were actually removed, keeping their real
        # issue_type. The previous synthetic Issue(issue_type="bug") probe
        # made match_issue's type check fail for security/performance ground
        # truths, so clearing a *correct* security finding was rewarded as a
        # false-positive removal.
        was_tp = any(
            match_issue(rem, gt)
            for rem in removed
            for gt in self._ground_truth
        )

        if was_tp:
            reward = -0.03
            feedback = (
                f"Removed a correct finding at {action.filename}:{action.line_number}. "
                f"[-0.03 reward]"
            )
        else:
            reward = 0.03
            feedback = (
                f"Removed a false positive at {action.filename}:{action.line_number}. "
                f"[+0.03 reward — good correction]"
            )

        return reward, feedback

    def _handle_hint(self):
        """Reveal the next hint for the task; always costs -0.01 reward."""
        hints = self._task.get("hints", [])
        if self._hint_index >= len(hints):
            return -0.01, "No more hints available for this task."

        hint = hints[self._hint_index]
        self._hint_index += 1
        remaining = len(hints) - self._hint_index
        return -0.01, f"Hint {self._hint_index}/{len(hints)}: {hint} ({remaining} hints left)"

    def _handle_submit(self):
        """Finalize the episode: grade all flags; the grade is the reward."""
        self._state.submitted = True
        final_score = grade_episode(self._state.flagged_issues, self._ground_truth)
        self._state.current_score = final_score

        tp_count = sum(
            1 for f in self._state.flagged_issues
            if any(match_issue(f, gt) for gt in self._ground_truth)
        )
        total_gt = len(self._ground_truth)
        total_flagged = len(self._state.flagged_issues)

        feedback = (
            f"Review submitted! Final score: {final_score:.3f}. "
            f"Found {tp_count}/{total_gt} real issues. "
            f"Total flags: {total_flagged} "
            f"({'perfect' if total_flagged == tp_count else f'{total_flagged - tp_count} false positives'})."
        )

        return final_score, feedback
|
server/graders.py
ADDED
|
@@ -0,0 +1,170 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Grading logic for the Code Review Environment.
|
| 3 |
+
"""
|
| 4 |
+
from __future__ import annotations
|
| 5 |
+
|
| 6 |
+
import re
|
| 7 |
+
from typing import List, Tuple, Set
|
| 8 |
+
|
| 9 |
+
import sys
|
| 10 |
+
import os
|
| 11 |
+
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
| 12 |
+
|
| 13 |
+
from models import Issue
|
| 14 |
+
|
| 15 |
+
# Ordinal rank per severity level; used to score how close a flagged
# severity is to the ground-truth severity.
_SEV_RANK = {"low": 0, "medium": 1, "high": 2, "critical": 3}

# Which flagged issue_type values are accepted for each ground-truth type.
# "bug" and "logic" are interchangeable; "security" and "performance" must
# be named exactly.
_TYPE_COMPAT = {
    "bug": {"bug", "logic"},
    "logic": {"bug", "logic"},
    "security": {"security"},
    "performance": {"performance"},
}


def match_issue(flagged: Issue, gt: Issue, line_tolerance: int = 2) -> bool:
    """Return True when *flagged* counts as a hit on the ground-truth *gt*.

    A match requires the same file, a line number within ``line_tolerance``
    lines of the ground-truth line, and a compatible issue type.
    """
    same_file = flagged.filename == gt.filename
    near_line = abs(flagged.line_number - gt.line_number) <= line_tolerance
    # Unknown ground-truth types fall back to exact-type matching.
    acceptable = _TYPE_COMPAT.get(gt.issue_type, {gt.issue_type})
    return same_file and near_line and flagged.issue_type in acceptable
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
def grade_episode(
    flagged: List[Issue],
    ground_truth: List[Issue],
    line_tolerance: int = 2,
) -> float:
    """Compute a 0.0–1.0 score: 0.70 * F1 + 0.30 * severity_accuracy.

    Each ground-truth issue can be claimed by at most one flag (greedy,
    first-come-first-served in flag order). Unmatched flags count as false
    positives; unclaimed ground truths as false negatives.
    """
    if not ground_truth:
        # A clean file is graded on restraint: any flag is a false positive.
        return 1.0 if not flagged else 0.0

    true_positives = 0
    false_positives = 0
    claimed: Set[int] = set()          # ground-truth indices already matched
    sev_scores: List[float] = []       # per-match severity closeness in [0, 1]

    for flag in flagged:
        hit = next(
            (
                i
                for i, gt in enumerate(ground_truth)
                if i not in claimed and match_issue(flag, gt, line_tolerance)
            ),
            None,
        )
        if hit is None:
            false_positives += 1
            continue
        claimed.add(hit)
        true_positives += 1
        gt = ground_truth[hit]
        # Each step of severity distance loses ~1/3 of the credit.
        gap = abs(_SEV_RANK.get(flag.severity, 1) - _SEV_RANK.get(gt.severity, 1))
        sev_scores.append(max(0.0, 1.0 - gap * 0.34))

    false_negatives = len(ground_truth) - len(claimed)

    total_flagged = true_positives + false_positives
    precision = true_positives / total_flagged if total_flagged > 0 else 0.0
    total_real = true_positives + false_negatives
    recall = true_positives / total_real if total_real > 0 else 0.0
    denom = precision + recall
    f1 = (2 * precision * recall / denom) if denom > 0 else 0.0

    # Severity accuracy is normalized by the *total* ground truth, so missed
    # issues also drag this component down.
    severity_accuracy = sum(sev_scores) / len(ground_truth) if sev_scores else 0.0

    final = 0.70 * f1 + 0.30 * severity_accuracy
    return round(min(1.0, max(0.0, final)), 4)
|
| 80 |
+
|
| 81 |
+
|
| 82 |
+
def compute_live_score(flagged: List[Issue], ground_truth: List[Issue]) -> float:
    """F1-only score for per-step feedback (no severity bonus)."""
    if not ground_truth:
        # Nothing to find: only an empty flag list is a perfect review.
        return 1.0 if not flagged else 0.0

    unclaimed = list(range(len(ground_truth)))  # gt indices still matchable
    tp = 0
    for flag in flagged:
        for pos, idx in enumerate(unclaimed):
            if match_issue(flag, ground_truth[idx]):
                del unclaimed[pos]  # safe: we break right after the removal
                tp += 1
                break

    fp = len(flagged) - tp
    fn = len(unclaimed)
    precision = tp / (tp + fp) if (tp + fp) > 0 else 0.0
    recall = tp / (tp + fn) if (tp + fn) > 0 else 0.0
    denom = precision + recall
    f1 = (2 * precision * recall / denom) if denom > 0 else 0.0
    return round(f1, 4)
|
| 107 |
+
|
| 108 |
+
|
| 109 |
+
# Heuristic rules for the keyword baseline. Each entry is a 5-tuple:
#   (regex, filename_filter, issue_type, severity, description)
# filename_filter of None means the rule applies to every file; otherwise it
# must be a substring of the filename. Consumed by run_keyword_baseline.
_PATTERNS = [
    # --- logic bugs ---
    (r"range\(len\(\w+\)\s*\+\s*1\)", None, "bug", "high",
     "Off-by-one error: range(len(x) + 1) iterates one past the end"),
    (r"left,\s*right\s*=\s*0,\s*len\(", None, "bug", "medium",
     "Binary search upper bound should be len(arr) - 1"),
    (r"counts\[word\]\s*=\s*0\b", None, "bug", "low",
     "Counter initialized to 0 instead of 1"),

    # --- security ---
    (r'SECRET_KEY\s*=\s*["\']', None, "security", "high",
     "Hardcoded SECRET_KEY in source code"),
    (r'PASSWORD\s*=\s*["\']', None, "security", "high",
     "Hardcoded password in source code"),
    (r"f['\"].*SELECT.*\{", None, "security", "critical",
     "SQL injection via f-string query construction"),
    (r"f['\"].*DELETE.*\{", None, "security", "critical",
     "SQL injection via f-string DELETE query"),
    (r"render_template_string\(f['\"]", None, "security", "high",
     "XSS: unsanitized user input in render_template_string"),
    (r"shell\s*=\s*True", None, "security", "critical",
     "Command injection risk: shell=True with user input"),
    (r"hashlib\.md5\(", None, "security", "medium",
     "MD5 is cryptographically broken, use SHA-256 or HMAC-SHA256"),
    (r"expected\s*==\s*\w+_hash", None, "security", "medium",
     "Timing attack: use hmac.compare_digest() for constant-time comparison"),
    (r"password\s*=\s*models\.CharField", None, "security", "critical",
     "Plaintext password storage in database"),
    (r"os\.path\.join\(['\"]\/", None, "security", "high",
     "Path traversal: os.path.join with absolute prefix doesn't prevent traversal"),

    # --- performance ---
    (r"\.objects\.get\(id=item\.", None, "performance", "high",
     "N+1 query: database lookup inside a loop"),

    # --- Django model smells ---
    (r"FloatField\(\)", None, "bug", "medium",
     "FloatField for monetary values causes precision errors, use DecimalField"),
    (r"BinaryField\(\)", None, "security", "high",
     "BinaryField with pickled data is a deserialization vulnerability"),
]
|
| 146 |
+
|
| 147 |
+
|
| 148 |
+
def run_keyword_baseline(task: dict) -> List[Issue]:
    """Scan a task's code files with the _PATTERNS regex heuristics.

    Emits at most one finding per (filename, line) pair — the first pattern
    that matches a line wins.
    """
    results: List[Issue] = []
    flagged_locations: set = set()

    for fname, source in task.get("code_files", {}).items():
        for lineno, text in enumerate(source.splitlines(), start=1):
            if (fname, lineno) in flagged_locations:
                continue
            for pattern, name_filter, issue_type, sev, message in _PATTERNS:
                # Rule restricted to files whose name contains the filter.
                if name_filter and name_filter not in fname:
                    continue
                if re.search(pattern, text) is None:
                    continue
                flagged_locations.add((fname, lineno))
                results.append(Issue(
                    line_number=lineno,
                    filename=fname,
                    issue_type=issue_type,
                    severity=sev,
                    description=message,
                ))
                break  # one finding per line; later patterns are ignored
    return results
|
tasks/__init__.py
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from tasks.data import ALL_TASKS, get_task, TASK_IDS
|
| 2 |
+
|
| 3 |
+
__all__ = ["ALL_TASKS", "get_task", "TASK_IDS"]
|
tasks/data.py
ADDED
|
@@ -0,0 +1,434 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Task definitions for the Code Review Environment.
|
| 3 |
+
"""
|
| 4 |
+
from __future__ import annotations
|
| 5 |
+
from typing import List, Dict, Any
|
| 6 |
+
|
| 7 |
+
def _issue(line: int, filename: str, itype: str, severity: str, desc: str, fix: str = "") -> dict:
|
| 8 |
+
return {
|
| 9 |
+
"line_number": line,
|
| 10 |
+
"filename": filename,
|
| 11 |
+
"issue_type": itype,
|
| 12 |
+
"severity": severity,
|
| 13 |
+
"description": desc,
|
| 14 |
+
"fix_suggestion": fix,
|
| 15 |
+
}
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
# Buggy utility module shown to the agent for the "bug-detection" task.
# NOTE: the line numbers in TASK_BUG_DETECTION's ground_truth_issues index
# into this string (line 6: off-by-one, line 13: search bound, line 33:
# counter init) — do not reflow or reformat it.
_UTILS_CODE = """\
def calculate_average(numbers):
    \"\"\"Calculate the average of a list of numbers.\"\"\"
    if not numbers:
        return 0
    total = 0
    for i in range(len(numbers) + 1):
        total += numbers[i]
    return total / len(numbers)


def binary_search(arr, target):
    \"\"\"Search for target in sorted array. Returns index or -1.\"\"\"
    left, right = 0, len(arr)
    while left <= right:
        mid = (left + right) // 2
        if arr[mid] == target:
            return mid
        elif arr[mid] < target:
            left = mid + 1
        else:
            right = mid - 1
    return -1


def count_words(text):
    \"\"\"Count word frequency in a text string.\"\"\"
    words = text.lower().split()
    counts = {}
    for word in words:
        if word in counts:
            counts[word] += 1
        else:
            counts[word] = 0
    return counts


def reverse_string(s):
    \"\"\"Return the reversed version of a string (no bug here).\"\"\"
    return s[::-1]
"""

# Easy task: three planted logic bugs in utils.py (see _UTILS_CODE above).
TASK_BUG_DETECTION: Dict[str, Any] = {
    "task_id": "bug-detection",
    "difficulty": "easy",
    "description": (
        "Review this Python utility module for logical bugs and errors.\n"
        "The code contains several functions with subtle bugs that would cause\n"
        "incorrect results or crashes. Identify all issues with exact line numbers,\n"
        "issue type, severity, and a clear description of the problem.\n\n"
        "File to review: utils.py"
    ),
    "language": "python",
    "code_files": {
        "utils.py": _UTILS_CODE,
    },
    # Line numbers below refer to positions inside _UTILS_CODE.
    "ground_truth_issues": [
        _issue(
            6, "utils.py", "bug", "high",
            "Off-by-one error: range(len(numbers) + 1) iterates one past the end, "
            "causing IndexError on the last iteration.",
            "Change to: range(len(numbers))"
        ),
        _issue(
            13, "utils.py", "bug", "medium",
            "Binary search upper bound is wrong: right = len(arr) causes IndexError "
            "when accessing arr[mid] on a full array.",
            "Change to: right = len(arr) - 1"
        ),
        _issue(
            33, "utils.py", "bug", "low",
            "Word count initializes new entries to 0 instead of 1, so every word's "
            "count is underreported by 1.",
            "Change to: counts[word] = 1"
        ),
    ],
    "max_steps": 15,
    "hints": [
        "Look carefully at loop boundary conditions — are they off by one?",
        "The binary_search function has an issue with its initial right bound.",
        "Check how new keys are initialized in the word count dictionary.",
    ],
}
|
| 101 |
+
|
| 102 |
+
|
| 103 |
+
# Deliberately vulnerable Flask app shown to the agent for the
# "security-audit" task. NOTE: ground_truth_issues line numbers index into
# this string (8/9 hardcoded secrets, 19 SQLi, 27 XSS, 34 path traversal,
# 40 missing auth, 51 command injection) — do not reflow or reformat it.
_APP_CODE = """\
import sqlite3
import os
import subprocess
from flask import Flask, request, render_template_string

app = Flask(__name__)

SECRET_KEY = "hardcoded_secret_key_123"
DB_PASSWORD = "admin123"


def get_db():
    return sqlite3.connect('users.db')


@app.route('/user/<username>')
def get_user(username):
    db = get_db()
    query = f"SELECT * FROM users WHERE username = '{username}'"
    result = db.execute(query).fetchone()
    return str(result)


@app.route('/search')
def search():
    term = request.args.get('term', '')
    template = f"<h1>Results for: {term}</h1>"
    return render_template_string(template)


@app.route('/file')
def read_file():
    filename = request.args.get('name', '')
    filepath = os.path.join('/data', filename)
    with open(filepath, 'r') as f:
        return f.read()


@app.route('/admin/delete', methods=['POST'])
def admin_delete():
    user_id = request.form.get('user_id')
    db = get_db()
    db.execute(f"DELETE FROM users WHERE id = {user_id}")
    db.commit()
    return "Deleted"


@app.route('/ping')
def ping():
    host = request.args.get('host', '')
    result = subprocess.run(f"ping -c 1 {host}", shell=True, capture_output=True)
    return result.stdout.decode()
"""

# Medium task: seven OWASP-style vulnerabilities in app.py (see _APP_CODE).
TASK_SECURITY_AUDIT: Dict[str, Any] = {
    "task_id": "security-audit",
    "difficulty": "medium",
    "description": (
        "Perform a security audit on this Flask web application.\n"
        "The code contains multiple OWASP Top-10 security vulnerabilities.\n"
        "Identify all security issues with their exact line numbers, severity ratings,\n"
        "and recommended fixes. Consider: injection attacks, broken authentication,\n"
        "sensitive data exposure, and improper input handling.\n\n"
        "File to review: app.py"
    ),
    "language": "python",
    "code_files": {
        "app.py": _APP_CODE,
    },
    # Line numbers below refer to positions inside _APP_CODE.
    "ground_truth_issues": [
        _issue(
            8, "app.py", "security", "high",
            "Hardcoded SECRET_KEY in source code. Anyone with repo access can forge sessions.",
            "Use: SECRET_KEY = os.environ.get('SECRET_KEY') and set it as an env var."
        ),
        _issue(
            9, "app.py", "security", "high",
            "Hardcoded database password in source code. Credentials should never be in code.",
            "Use: DB_PASSWORD = os.environ.get('DB_PASSWORD')"
        ),
        _issue(
            19, "app.py", "security", "critical",
            "SQL injection: username is interpolated directly into the query string. "
            "An attacker can supply username = \\' OR 1=1 -- to dump the database.",
            "Use parameterized queries: db.execute('SELECT * FROM users WHERE username = ?', (username,))"
        ),
        _issue(
            27, "app.py", "security", "high",
            "Cross-site scripting (XSS): user-supplied 'term' is rendered directly in an "
            "HTML template via render_template_string without escaping.",
            "Use flask.escape(term) or Markup.escape(term) before interpolating into HTML."
        ),
        _issue(
            34, "app.py", "security", "high",
            "Path traversal: os.path.join('/data', filename) does not prevent filenames "
            "like '../etc/passwd' from escaping the /data directory.",
            "Use: filename = os.path.basename(filename) and validate against an allowlist."
        ),
        _issue(
            40, "app.py", "security", "critical",
            "Missing authentication: the /admin/delete endpoint has no access control. "
            "Any unauthenticated user can delete records.",
            "Add @login_required decorator and check that request.user.is_admin is True."
        ),
        _issue(
            51, "app.py", "security", "critical",
            "Command injection: user-supplied 'host' is interpolated into a shell command "
            "with shell=True. Attacker can supply 'x; rm -rf /' to execute arbitrary commands.",
            "Use: subprocess.run(['ping', '-c', '1', host], shell=False) after validating host."
        ),
    ],
    "max_steps": 20,
    "hints": [
        "Look for hardcoded credentials and secrets at the top of the file.",
        "Check every place user input (request.args, request.form) touches a database query, "
        "template, file path, or shell command.",
        "The admin endpoint is missing an authorization check.",
    ],
}
|
| 223 |
+
|
| 224 |
+
|
| 225 |
+
_VIEWS_CODE = """\
|
| 226 |
+
import threading
|
| 227 |
+
from django.db import transaction
|
| 228 |
+
from django.contrib.auth.decorators import login_required
|
| 229 |
+
from django.http import JsonResponse
|
| 230 |
+
from .models import Order, Product, Cart
|
| 231 |
+
import hashlib
|
| 232 |
+
|
| 233 |
+
_lock = threading.Lock()
|
| 234 |
+
|
| 235 |
+
|
| 236 |
+
@login_required
|
| 237 |
+
def place_order(request):
|
| 238 |
+
user = request.user
|
| 239 |
+
cart_items = Cart.objects.filter(user=user)
|
| 240 |
+
|
| 241 |
+
if not cart_items.exists():
|
| 242 |
+
return JsonResponse({'error': 'Cart is empty'}, status=400)
|
| 243 |
+
|
| 244 |
+
total = 0
|
| 245 |
+
for item in cart_items:
|
| 246 |
+
product = Product.objects.get(id=item.product_id)
|
| 247 |
+
total += product.price * item.quantity
|
| 248 |
+
|
| 249 |
+
for item in cart_items:
|
| 250 |
+
product = Product.objects.get(id=item.product_id)
|
| 251 |
+
if product.stock < item.quantity:
|
| 252 |
+
return JsonResponse({'error': f'Insufficient stock for {product.name}'}, status=400)
|
| 253 |
+
|
| 254 |
+
order = Order.objects.create(
|
| 255 |
+
user=user,
|
| 256 |
+
total=total,
|
| 257 |
+
status='pending'
|
| 258 |
+
)
|
| 259 |
+
|
| 260 |
+
for item in cart_items:
|
| 261 |
+
product = Product.objects.get(id=item.product_id)
|
| 262 |
+
product.stock -= item.quantity
|
| 263 |
+
product.save()
|
| 264 |
+
|
| 265 |
+
cart_items.delete()
|
| 266 |
+
return JsonResponse({'order_id': order.id, 'total': float(total)})
|
| 267 |
+
|
| 268 |
+
|
| 269 |
+
@login_required
|
| 270 |
+
def get_order_history(request):
|
| 271 |
+
page = int(request.GET.get('page', 1))
|
| 272 |
+
per_page = int(request.GET.get('per_page', 10))
|
| 273 |
+
|
| 274 |
+
orders = Order.objects.filter(user=request.user)[
|
| 275 |
+
(page - 1) * per_page: page * per_page
|
| 276 |
+
]
|
| 277 |
+
|
| 278 |
+
result = []
|
| 279 |
+
for order in orders:
|
| 280 |
+
result.append({
|
| 281 |
+
'id': order.id,
|
| 282 |
+
'total': order.total,
|
| 283 |
+
'status': order.status,
|
| 284 |
+
})
|
| 285 |
+
|
| 286 |
+
return JsonResponse({'orders': result})
|
| 287 |
+
|
| 288 |
+
|
| 289 |
+
def verify_payment(order_id, payment_hash):
|
| 290 |
+
order = Order.objects.get(id=order_id)
|
| 291 |
+
expected = hashlib.md5(f"{order_id}{order.total}".encode()).hexdigest()
|
| 292 |
+
return expected == payment_hash
|
| 293 |
+
"""
|
| 294 |
+
|
| 295 |
+
_MODELS_CODE = """\
|
| 296 |
+
from django.db import models
|
| 297 |
+
import pickle
|
| 298 |
+
|
| 299 |
+
|
| 300 |
+
class User(models.Model):
|
| 301 |
+
username = models.CharField(max_length=150)
|
| 302 |
+
email = models.CharField(max_length=255)
|
| 303 |
+
password = models.CharField(max_length=255)
|
| 304 |
+
|
| 305 |
+
class Meta:
|
| 306 |
+
db_table = 'users'
|
| 307 |
+
|
| 308 |
+
|
| 309 |
+
class Product(models.Model):
|
| 310 |
+
name = models.CharField(max_length=255)
|
| 311 |
+
price = models.FloatField()
|
| 312 |
+
stock = models.IntegerField(default=0)
|
| 313 |
+
metadata = models.BinaryField()
|
| 314 |
+
|
| 315 |
+
class Meta:
|
| 316 |
+
db_table = 'products'
|
| 317 |
+
|
| 318 |
+
|
| 319 |
+
class Order(models.Model):
|
| 320 |
+
user = models.ForeignKey(User, on_delete=models.CASCADE)
|
| 321 |
+
total = models.FloatField()
|
| 322 |
+
status = models.CharField(max_length=50)
|
| 323 |
+
created_at = models.DateTimeField(auto_now_add=True)
|
| 324 |
+
|
| 325 |
+
class Meta:
|
| 326 |
+
db_table = 'orders'
|
| 327 |
+
|
| 328 |
+
|
| 329 |
+
class Cart(models.Model):
|
| 330 |
+
user = models.ForeignKey(User, on_delete=models.CASCADE)
|
| 331 |
+
product_id = models.IntegerField()
|
| 332 |
+
quantity = models.IntegerField()
|
| 333 |
+
|
| 334 |
+
class Meta:
|
| 335 |
+
db_table = 'cart'
|
| 336 |
+
"""
|
| 337 |
+
|
| 338 |
+
# Hard multi-file task: a Django e-commerce API spanning views.py and
# models.py, seeded with performance, transactional, and security issues.
# Ground-truth line numbers refer to lines within _VIEWS_CODE / _MODELS_CODE.
TASK_COMPREHENSIVE: Dict[str, Any] = {
    "task_id": "comprehensive-review",
    "difficulty": "hard",
    "description": (
        "Perform a comprehensive code review of this Django e-commerce API.\n"
        "The code spans two files and contains bugs, security vulnerabilities,\n"
        "performance issues, and data modeling problems.\n"
        "Find ALL issues across BOTH files. This is a hard task — look carefully\n"
        "for subtle architectural problems, not just surface-level issues.\n\n"
        "Files to review: views.py, models.py"
    ),
    "language": "python",
    "code_files": {
        "views.py": _VIEWS_CODE,
        "models.py": _MODELS_CODE,
    },
    # Each entry: (line, file, type, severity, description, suggested fix).
    "ground_truth_issues": [
        _issue(
            21, "views.py", "performance", "high",
            "N+1 query: Product.objects.get() is called inside a loop, issuing one SQL "
            "query per cart item. With 100 items this means 100 DB roundtrips.",
            "Use: Product.objects.filter(id__in=[i.product_id for i in cart_items]) "
            "then build a dict for O(1) lookup."
        ),
        _issue(
            26, "views.py", "bug", "critical",
            "Race condition: the stock check and stock decrement are not atomic. "
            "Two concurrent requests can both pass the check and oversell the product.",
            "Wrap in transaction.atomic() and use Product.objects.select_for_update() "
            "to lock rows during the check."
        ),
        _issue(
            29, "views.py", "bug", "high",
            "Order is created outside a database transaction. If stock decrement fails "
            "after the order is created, the database is left in an inconsistent state.",
            "Wrap the entire order creation flow in: with transaction.atomic():"
        ),
        _issue(
            47, "views.py", "security", "medium",
            "No maximum cap on per_page: an attacker can request per_page=1000000 "
            "to dump the entire orders table in one request, causing DoS or data leak.",
            "Add: per_page = min(int(request.GET.get('per_page', 10)), 100)"
        ),
        _issue(
            66, "views.py", "security", "medium",
            "MD5 is a cryptographically broken hash function and should not be used "
            "for payment verification. Collisions can be manufactured.",
            "Use HMAC-SHA256: hmac.new(SECRET.encode(), payload.encode(), hashlib.sha256).hexdigest()"
        ),
        _issue(
            67, "views.py", "security", "medium",
            "Timing attack: string comparison with == leaks timing information that "
            "allows an attacker to forge valid hashes byte-by-byte.",
            "Use: hmac.compare_digest(expected, payment_hash) for constant-time comparison."
        ),
        _issue(
            8, "models.py", "security", "critical",
            "Plaintext password storage: passwords are stored as raw strings in the "
            "database. Any DB breach immediately exposes all user passwords.",
            "Use Django's built-in: from django.contrib.auth.hashers import make_password, check_password"
        ),
        _issue(
            16, "models.py", "bug", "medium",
            "FloatField for monetary values causes floating-point precision errors "
            "(e.g., 0.1 + 0.2 != 0.3). This will produce wrong totals over time.",
            "Use: DecimalField(max_digits=10, decimal_places=2) for all monetary fields."
        ),
        _issue(
            18, "models.py", "security", "high",
            "BinaryField storing pickled data is dangerous: pickle.loads() on untrusted "
            "data can execute arbitrary code. Anyone who can write to this field can RCE.",
            "Use: JSONField() instead. If binary storage is required, validate/sign the data."
        ),
    ],
    # Largest step budget of the three tasks — two files to cover.
    "max_steps": 30,
    "hints": [
        "Look for database queries inside for loops — this is a classic N+1 problem.",
        "Check whether stock checks and order creation happen inside a database transaction.",
        "Look at models.py: how are passwords and monetary values stored?",
    ],
}
|
| 419 |
+
|
| 420 |
+
|
| 421 |
+
# Registry of every built-in task, keyed by its task_id.
ALL_TASKS: Dict[str, Dict[str, Any]] = {
    TASK_BUG_DETECTION["task_id"]: TASK_BUG_DETECTION,
    TASK_SECURITY_AUDIT["task_id"]: TASK_SECURITY_AUDIT,
    TASK_COMPREHENSIVE["task_id"]: TASK_COMPREHENSIVE,
}

# Valid task IDs, in registration order (dicts preserve insertion order).
TASK_IDS: List[str] = list(ALL_TASKS.keys())
|
| 428 |
+
|
| 429 |
+
|
| 430 |
+
def get_task(task_id: str) -> Dict[str, Any]:
    """Look up a task definition by its ID.

    Raises:
        KeyError: if *task_id* is not one of the registered tasks; the
            message lists the valid IDs.
    """
    try:
        return ALL_TASKS[task_id]
    except KeyError:
        raise KeyError(f"Unknown task_id '{task_id}'. Valid: {TASK_IDS}") from None
|
tests/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
# tests package
|
tests/test_environment.py
ADDED
|
@@ -0,0 +1,314 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Tests for CodeReviewEnvironment.
|
| 3 |
+
|
| 4 |
+
Run with: pytest tests/ -v
|
| 5 |
+
Or: python -m pytest tests/ -v
|
| 6 |
+
"""
|
| 7 |
+
import sys
|
| 8 |
+
import os
|
| 9 |
+
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
| 10 |
+
|
| 11 |
+
import pytest
|
| 12 |
+
from models import ReviewAction, ReviewObservation, ReviewState
|
| 13 |
+
from server.environment import CodeReviewEnvironment
|
| 14 |
+
from tasks.data import ALL_TASKS, TASK_IDS
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
# ---------------------------------------------------------------------------
|
| 18 |
+
# Fixtures
|
| 19 |
+
# ---------------------------------------------------------------------------
|
| 20 |
+
|
| 21 |
+
@pytest.fixture
def env():
    # A fresh, un-reset environment instance for each test.
    return CodeReviewEnvironment()


@pytest.fixture
def env_bug(env):
    # Environment already reset onto the easy single-file bug-detection task.
    env.reset(task_id="bug-detection")
    return env


@pytest.fixture
def env_sec(env):
    # Environment already reset onto the security-audit task.
    env.reset(task_id="security-audit")
    return env


@pytest.fixture
def env_hard(env):
    # Environment already reset onto the hard multi-file task.
    env.reset(task_id="comprehensive-review")
    return env
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
# ---------------------------------------------------------------------------
|
| 45 |
+
# reset() tests
|
| 46 |
+
# ---------------------------------------------------------------------------
|
| 47 |
+
|
| 48 |
+
class TestReset:
    """Contract of CodeReviewEnvironment.reset(): fresh observation,
    correct task selection, seed reproducibility, and state clearing."""

    def test_reset_returns_observation(self, env):
        obs = env.reset()
        assert isinstance(obs, ReviewObservation)

    def test_reset_done_is_false(self, env):
        obs = env.reset()
        assert obs.done is False

    def test_reset_reward_is_none(self, env):
        # No reward before the first action.
        obs = env.reset()
        assert obs.reward is None

    def test_reset_has_code_files(self, env):
        obs = env.reset()
        assert isinstance(obs.code_files, dict)
        assert len(obs.code_files) > 0

    def test_reset_step_count_zero(self, env):
        obs = env.reset()
        assert obs.step_count == 0

    def test_reset_no_flagged_issues(self, env):
        obs = env.reset()
        assert obs.flagged_issues == []

    def test_reset_specific_task(self, env):
        # Every registered task ID must be selectable by name.
        for task_id in TASK_IDS:
            obs = env.reset(task_id=task_id)
            assert obs.task_id == task_id

    def test_reset_bug_detection(self, env):
        obs = env.reset(task_id="bug-detection")
        assert "utils.py" in obs.code_files

    def test_reset_security_audit(self, env):
        obs = env.reset(task_id="security-audit")
        assert "app.py" in obs.code_files

    def test_reset_comprehensive(self, env):
        obs = env.reset(task_id="comprehensive-review")
        assert "views.py" in obs.code_files
        assert "models.py" in obs.code_files

    def test_reset_with_seed_is_reproducible(self, env):
        # Same seed must pick the same task.
        obs1 = env.reset(seed=42)
        task1 = obs1.task_id
        obs2 = env.reset(seed=42)
        task2 = obs2.task_id
        assert task1 == task2

    def test_reset_clears_previous_state(self, env):
        # Flags and step count from a prior episode must not leak through.
        env.reset(task_id="bug-detection")
        env.step(ReviewAction(
            action_type="flag_issue", line_number=6, filename="utils.py",
            issue_type="bug", severity="high", description="test"
        ))
        obs = env.reset(task_id="bug-detection")
        assert obs.flagged_issues == []
        assert obs.step_count == 0
|
| 108 |
+
|
| 109 |
+
|
| 110 |
+
# ---------------------------------------------------------------------------
|
| 111 |
+
# step() — flag_issue tests
|
| 112 |
+
# ---------------------------------------------------------------------------
|
| 113 |
+
|
| 114 |
+
class TestFlagIssue:
    """Rewards and state transitions for the 'flag_issue' action."""

    def test_flag_increments_step_count(self, env_bug):
        obs = env_bug.step(ReviewAction(
            action_type="flag_issue", line_number=6, filename="utils.py",
            issue_type="bug", severity="high", description="test"
        ))
        assert obs.step_count == 1

    def test_flag_adds_to_flagged_issues(self, env_bug):
        obs = env_bug.step(ReviewAction(
            action_type="flag_issue", line_number=6, filename="utils.py",
            issue_type="bug", severity="high", description="test"
        ))
        assert len(obs.flagged_issues) == 1

    def test_flag_true_positive_gives_positive_reward(self, env_bug):
        # Line 6 of utils.py is a real ground-truth bug in this task.
        obs = env_bug.step(ReviewAction(
            action_type="flag_issue", line_number=6, filename="utils.py",
            issue_type="bug", severity="high", description="off-by-one"
        ))
        assert obs.reward is not None and obs.reward > 0

    def test_flag_false_positive_gives_negative_reward(self, env_bug):
        # Line 100 has no ground-truth issue.
        obs = env_bug.step(ReviewAction(
            action_type="flag_issue", line_number=100, filename="utils.py",
            issue_type="bug", severity="low", description="nonexistent issue"
        ))
        assert obs.reward is not None and obs.reward < 0

    def test_flag_missing_line_number_gives_penalty(self, env_bug):
        # flag_issue without a line_number is malformed and must not pay out.
        obs = env_bug.step(ReviewAction(
            action_type="flag_issue", filename="utils.py",
            issue_type="bug", severity="high", description="test"
        ))
        assert obs.reward is not None and obs.reward <= 0

    def test_flag_duplicate_line_no_change(self, env_bug):
        env_bug.step(ReviewAction(
            action_type="flag_issue", line_number=6, filename="utils.py",
            issue_type="bug", severity="high", description="test"
        ))
        obs = env_bug.step(ReviewAction(
            action_type="flag_issue", line_number=6, filename="utils.py",
            issue_type="bug", severity="high", description="same line again"
        ))
        assert len(obs.flagged_issues) == 1  # not doubled

    def test_flag_multiple_issues(self, env_bug):
        # Lines 6, 13 and 33 are the three ground-truth bugs of this task.
        for line in [6, 13, 33]:
            env_bug.step(ReviewAction(
                action_type="flag_issue", line_number=line, filename="utils.py",
                issue_type="bug", severity="medium", description=f"bug at {line}"
            ))
        obs = env_bug.state
        assert len(obs.flagged_issues) == 3
+
|
| 170 |
+
|
| 171 |
+
# ---------------------------------------------------------------------------
|
| 172 |
+
# step() — clear_flag tests
|
| 173 |
+
# ---------------------------------------------------------------------------
|
| 174 |
+
|
| 175 |
+
class TestClearFlag:
    """Rewards and state transitions for the 'clear_flag' action."""

    def test_clear_removes_flag(self, env_bug):
        env_bug.step(ReviewAction(
            action_type="flag_issue", line_number=6, filename="utils.py",
            issue_type="bug", severity="high", description="test"
        ))
        obs = env_bug.step(ReviewAction(
            action_type="clear_flag", line_number=6, filename="utils.py",
            description=""
        ))
        assert len(obs.flagged_issues) == 0

    def test_clear_nonexistent_flag_no_reward(self, env_bug):
        # Clearing a line that was never flagged is a neutral no-op.
        obs = env_bug.step(ReviewAction(
            action_type="clear_flag", line_number=999, filename="utils.py",
            description=""
        ))
        assert obs.reward == 0.0

    def test_clear_false_positive_gives_positive_reward(self, env_bug):
        # First flag a FP
        env_bug.step(ReviewAction(
            action_type="flag_issue", line_number=100, filename="utils.py",
            issue_type="bug", severity="low", description="wrong"
        ))
        # Retracting a false positive is rewarded.
        obs = env_bug.step(ReviewAction(
            action_type="clear_flag", line_number=100, filename="utils.py",
            description=""
        ))
        assert obs.reward is not None and obs.reward > 0
|
| 205 |
+
|
| 206 |
+
|
| 207 |
+
# ---------------------------------------------------------------------------
|
| 208 |
+
# step() — request_hint tests
|
| 209 |
+
# ---------------------------------------------------------------------------
|
| 210 |
+
|
| 211 |
+
class TestRequestHint:
    """Rewards and budget accounting for the 'request_hint' action."""

    def test_hint_gives_small_negative_reward(self, env_bug):
        # Hints cost a small penalty so agents don't spam them.
        obs = env_bug.step(ReviewAction(action_type="request_hint"))
        assert obs.reward is not None and obs.reward < 0

    def test_hint_decrements_hints_remaining(self, env_bug):
        # Two consecutive hints must strictly decrease the remaining budget.
        # (The previous version also captured state.step_count into an unused
        # local; that dead code has been removed.)
        obs1 = env_bug.step(ReviewAction(action_type="request_hint"))
        obs2 = env_bug.step(ReviewAction(action_type="request_hint"))
        assert obs2.hints_remaining < obs1.hints_remaining

    def test_hint_content_in_feedback(self, env_bug):
        # The bug-detection hints mention loops, so feedback should contain
        # either the word "hint" or the hint text itself.
        obs = env_bug.step(ReviewAction(action_type="request_hint"))
        assert "hint" in obs.feedback.lower() or "loop" in obs.feedback.lower()
|
| 225 |
+
|
| 226 |
+
|
| 227 |
+
# ---------------------------------------------------------------------------
|
| 228 |
+
# step() — submit_review tests
|
| 229 |
+
# ---------------------------------------------------------------------------
|
| 230 |
+
|
| 231 |
+
class TestSubmitReview:
    """Episode termination and final scoring via 'submit_review'."""

    def test_submit_ends_episode(self, env_bug):
        obs = env_bug.step(ReviewAction(action_type="submit_review"))
        assert obs.done is True

    def test_submit_reward_is_float_in_range(self, env_bug):
        # Final score is normalized to [0, 1].
        obs = env_bug.step(ReviewAction(action_type="submit_review"))
        assert obs.reward is not None
        assert 0.0 <= obs.reward <= 1.0

    def test_submit_all_bugs_gives_high_score(self, env_bug):
        # Flag all 3 correct bugs
        for line, sev in [(6, "high"), (13, "medium"), (33, "low")]:
            env_bug.step(ReviewAction(
                action_type="flag_issue", line_number=line, filename="utils.py",
                issue_type="bug", severity=sev, description=f"bug at line {line}"
            ))
        obs = env_bug.step(ReviewAction(action_type="submit_review"))
        assert obs.reward is not None and obs.reward >= 0.7

    def test_submit_no_flags_gives_zero(self, env_bug):
        # Submitting an empty review on a task with real issues scores 0.
        obs = env_bug.step(ReviewAction(action_type="submit_review"))
        assert obs.reward == 0.0

    def test_submit_after_done_is_noop(self, env_bug):
        env_bug.step(ReviewAction(action_type="submit_review"))
        obs2 = env_bug.step(ReviewAction(action_type="submit_review"))
        assert obs2.done is True  # still done
|
| 259 |
+
|
| 260 |
+
|
| 261 |
+
# ---------------------------------------------------------------------------
|
| 262 |
+
# state property tests
|
| 263 |
+
# ---------------------------------------------------------------------------
|
| 264 |
+
|
| 265 |
+
class TestState:
    """The environment's `state` property mirrors episode progress."""

    def test_state_returns_review_state(self, env):
        env.reset(task_id="bug-detection")
        st = env.state
        assert isinstance(st, ReviewState)

    def test_state_has_episode_id(self, env):
        env.reset(task_id="bug-detection")
        assert env.state.episode_id is not None

    def test_state_tracks_step_count(self, env_bug):
        env_bug.step(ReviewAction(action_type="request_hint"))
        assert env_bug.state.step_count == 1

    def test_state_tracks_flagged_issues(self, env_bug):
        env_bug.step(ReviewAction(
            action_type="flag_issue", line_number=6, filename="utils.py",
            issue_type="bug", severity="high", description="test"
        ))
        assert len(env_bug.state.flagged_issues) == 1
|
| 285 |
+
|
| 286 |
+
|
| 287 |
+
# ---------------------------------------------------------------------------
|
| 288 |
+
# Unknown action type
|
| 289 |
+
# ---------------------------------------------------------------------------
|
| 290 |
+
|
| 291 |
+
class TestUnknownAction:
    """Unknown action types must be handled gracefully, not crash."""

    def test_unknown_action_type_no_crash(self, env_bug):
        obs = env_bug.step(ReviewAction(action_type="invalid_action"))
        assert obs is not None
        # The previous assertion (`obs.done is False or obs.done is True`)
        # was a tautology that passed for any value; assert the actual
        # contract instead: `done` is a real boolean.
        assert isinstance(obs.done, bool)
|
| 296 |
+
|
| 297 |
+
|
| 298 |
+
# ---------------------------------------------------------------------------
|
| 299 |
+
# Max steps auto-end
|
| 300 |
+
# ---------------------------------------------------------------------------
|
| 301 |
+
|
| 302 |
+
class TestMaxSteps:
    def test_episode_auto_ends_at_max_steps(self):
        """Verify episode ends when step budget is exhausted."""
        env = CodeReviewEnvironment()
        obs = env.reset(task_id="bug-detection")
        max_steps = obs.max_steps

        # Burn the entire budget with no-op-ish hint requests; the
        # environment must flip `done` on or before the last step.
        for _ in range(max_steps):
            obs = env.step(ReviewAction(action_type="request_hint"))
            if obs.done:
                break

        assert obs.done is True
|
tests/test_graders.py
ADDED
|
@@ -0,0 +1,215 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Tests for the grading logic.
|
| 3 |
+
"""
|
| 4 |
+
import sys
|
| 5 |
+
import os
|
| 6 |
+
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
| 7 |
+
|
| 8 |
+
import pytest
|
| 9 |
+
from models import Issue
|
| 10 |
+
from server.graders import grade_episode, match_issue, run_keyword_baseline
|
| 11 |
+
from tasks.data import ALL_TASKS, TASK_IDS
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
def _issue(line, filename, itype="bug", severity="medium", desc=""):
    """Shorthand constructor for an Issue used throughout these tests."""
    fields = {
        "line_number": line,
        "filename": filename,
        "issue_type": itype,
        "severity": severity,
        "description": desc,
    }
    return Issue(**fields)
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
# ---------------------------------------------------------------------------
|
| 20 |
+
# match_issue()
|
| 21 |
+
# ---------------------------------------------------------------------------
|
| 22 |
+
|
| 23 |
+
class TestMatchIssue:
    """Pairwise matching of one flagged issue against one ground-truth issue:
    file must match exactly, line within tolerance, type compatible."""

    def test_exact_match(self):
        f = _issue(6, "utils.py", "bug", "high")
        gt = _issue(6, "utils.py", "bug", "high")
        assert match_issue(f, gt) is True

    def test_line_within_tolerance(self):
        # Off by one line still matches under the default tolerance.
        f = _issue(7, "utils.py", "bug", "high")
        gt = _issue(6, "utils.py", "bug", "high")
        assert match_issue(f, gt, line_tolerance=2) is True

    def test_line_outside_tolerance(self):
        f = _issue(10, "utils.py", "bug", "high")
        gt = _issue(6, "utils.py", "bug", "high")
        assert match_issue(f, gt, line_tolerance=2) is False

    def test_wrong_filename(self):
        f = _issue(6, "other.py", "bug", "high")
        gt = _issue(6, "utils.py", "bug", "high")
        assert match_issue(f, gt) is False

    def test_bug_logic_interchangeable(self):
        # "bug" and "logic" are treated as the same category, both ways.
        f = _issue(6, "utils.py", "logic", "high")
        gt = _issue(6, "utils.py", "bug", "high")
        assert match_issue(f, gt) is True

    def test_logic_bug_interchangeable(self):
        f = _issue(6, "utils.py", "bug", "high")
        gt = _issue(6, "utils.py", "logic", "high")
        assert match_issue(f, gt) is True

    def test_wrong_type_no_match(self):
        f = _issue(6, "utils.py", "performance", "high")
        gt = _issue(6, "utils.py", "bug", "high")
        assert match_issue(f, gt) is False
|
| 58 |
+
|
| 59 |
+
|
| 60 |
+
# ---------------------------------------------------------------------------
|
| 61 |
+
# grade_episode()
|
| 62 |
+
# ---------------------------------------------------------------------------
|
| 63 |
+
|
| 64 |
+
class TestGradeEpisode:
    """End-of-episode scoring: F1-style blend of precision and recall,
    clamped to [0, 1], with severity agreement as a tie-breaker."""

    def test_empty_both_is_perfect(self):
        # Nothing to find and nothing flagged: vacuously perfect.
        assert grade_episode([], []) == 1.0

    def test_empty_flagged_is_zero(self):
        gt = [_issue(6, "utils.py")]
        assert grade_episode([], gt) == 0.0

    def test_false_positives_only_is_zero(self):
        flagged = [_issue(100, "utils.py"), _issue(200, "utils.py")]
        gt = [_issue(6, "utils.py")]
        score = grade_episode(flagged, gt)
        assert score == 0.0

    def test_perfect_match_is_near_one(self):
        gt = [
            _issue(6, "utils.py", "bug", "high"),
            _issue(13, "utils.py", "bug", "medium"),
        ]
        score = grade_episode(gt, gt)
        assert score >= 0.9

    def test_partial_match(self):
        gt = [
            _issue(6, "utils.py", "bug", "high"),
            _issue(13, "utils.py", "bug", "medium"),
            _issue(33, "utils.py", "bug", "low"),
        ]
        flagged = [_issue(6, "utils.py", "bug", "high")]  # only 1 of 3
        score = grade_episode(flagged, gt)
        # recall = 1/3, precision = 1/1, F1 = 0.5
        assert 0.3 < score < 0.6

    def test_false_positives_lower_score(self):
        gt = [_issue(6, "utils.py", "bug", "high")]
        perfect = [_issue(6, "utils.py", "bug", "high")]
        with_fp = [_issue(6, "utils.py", "bug", "high"), _issue(100, "utils.py")]
        assert grade_episode(perfect, gt) > grade_episode(with_fp, gt)

    def test_severity_mismatch_lowers_score(self):
        gt = [_issue(6, "utils.py", "bug", "critical")]
        exact = [_issue(6, "utils.py", "bug", "critical")]
        wrong_sev = [_issue(6, "utils.py", "bug", "low")]
        assert grade_episode(exact, gt) > grade_episode(wrong_sev, gt)

    def test_score_is_always_in_0_1(self):
        # Fuzz with random flag sets; score must never escape [0, 1].
        import random
        random.seed(0)
        gt = [_issue(i * 10, "f.py") for i in range(5)]
        for _ in range(20):
            n = random.randint(0, 10)
            flagged = [_issue(random.randint(1, 100), "f.py") for _ in range(n)]
            score = grade_episode(flagged, gt)
            assert 0.0 <= score <= 1.0, f"Score {score} out of range"

    def test_multifile_match(self):
        gt = [
            _issue(21, "views.py", "performance", "high"),
            _issue(8, "models.py", "security", "critical"),
        ]
        flagged = [
            _issue(21, "views.py", "performance", "high"),
            _issue(8, "models.py", "security", "critical"),
        ]
        score = grade_episode(flagged, gt)
        assert score >= 0.85

    def test_multifile_wrong_file_no_match(self):
        gt = [_issue(21, "views.py", "performance", "high")]
        flagged = [_issue(21, "models.py", "performance", "high")]  # wrong file
        assert grade_episode(flagged, gt) == 0.0
|
| 135 |
+
|
| 136 |
+
|
| 137 |
+
# ---------------------------------------------------------------------------
|
| 138 |
+
# run_keyword_baseline()
|
| 139 |
+
# ---------------------------------------------------------------------------
|
| 140 |
+
|
| 141 |
+
class TestKeywordBaseline:
    """Sanity checks for the regex/keyword heuristic baseline reviewer."""

    def test_baseline_returns_list(self):
        from tasks.data import TASK_BUG_DETECTION
        findings = run_keyword_baseline(TASK_BUG_DETECTION)
        assert isinstance(findings, list)

    def test_baseline_issues_have_correct_types(self):
        from tasks.data import TASK_BUG_DETECTION
        findings = run_keyword_baseline(TASK_BUG_DETECTION)
        # Every finding must be a well-formed Issue with valid enums.
        for f in findings:
            assert isinstance(f, Issue)
            assert f.issue_type in ("bug", "security", "performance", "logic")
            assert f.severity in ("low", "medium", "high", "critical")

    def test_baseline_finds_some_security_issues(self):
        from tasks.data import TASK_SECURITY_AUDIT
        findings = run_keyword_baseline(TASK_SECURITY_AUDIT)
        security_finds = [f for f in findings if f.issue_type == "security"]
        assert len(security_finds) >= 2

    def test_baseline_score_in_range(self):
        for task_id in TASK_IDS:
            task = ALL_TASKS[task_id]
            findings = run_keyword_baseline(task)
            gt = [Issue.from_dict(i) for i in task["ground_truth_issues"]]
            score = grade_episode(findings, gt)
            assert 0.0 <= score <= 1.0, f"Task {task_id}: score={score} out of range"

    def test_baseline_score_is_nonzero(self):
        """Heuristic should find at least something in most tasks."""
        for task_id in TASK_IDS:
            task = ALL_TASKS[task_id]
            findings = run_keyword_baseline(task)
            gt = [Issue.from_dict(i) for i in task["ground_truth_issues"]]
            score = grade_episode(findings, gt)
            # Not every task may have regex hits, but security-audit should
            if task_id == "security-audit":
                assert score > 0.0, f"Heuristic found nothing in {task_id}"
|
| 179 |
+
|
| 180 |
+
|
| 181 |
+
# ---------------------------------------------------------------------------
|
| 182 |
+
# Ground truth sanity checks
|
| 183 |
+
# ---------------------------------------------------------------------------
|
| 184 |
+
|
| 185 |
+
class TestGroundTruth:
    """Structural sanity checks over the bundled task dataset."""

    def test_all_tasks_have_3_plus_issues(self):
        for task_id, task in ALL_TASKS.items():
            issue_count = len(task["ground_truth_issues"])
            assert issue_count >= 3, f"Task {task_id} has fewer than 3 issues"

    def test_all_tasks_have_valid_difficulties(self):
        seen = {task["difficulty"] for task in ALL_TASKS.values()}
        for level in ("easy", "medium", "hard"):
            assert level in seen

    def test_all_issues_have_required_fields(self):
        required = ("line_number", "filename", "issue_type", "severity")
        for task_id, task in ALL_TASKS.items():
            for i, issue in enumerate(task["ground_truth_issues"]):
                for field in required:
                    assert field in issue, f"{task_id}[{i}] missing {field}"

    def test_bug_detection_issues_in_utils_py(self):
        issues = ALL_TASKS["bug-detection"]["ground_truth_issues"]
        assert all(issue["filename"] == "utils.py" for issue in issues)

    def test_comprehensive_has_multifile_issues(self):
        issues = ALL_TASKS["comprehensive-review"]["ground_truth_issues"]
        touched = {issue["filename"] for issue in issues}
        assert "views.py" in touched
        assert "models.py" in touched
|
uv.lock
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|