Spaces: Running

Nitish committed on
Commit · 474eafa
Parent(s): 742e175

feat: finalize OpenEnv alignment and calibrate rewards for QA

Browse files
- Dockerfile +3 -4
- README.md +87 -113
- inference.py +75 -52
- openenv.yaml +50 -43
- qa_test.py +237 -0
- server/app.py +39 -27
- server/environment.py +73 -436
- server/grader.py +80 -0
- server/models.py +39 -20
- server/tasks.py +110 -0
- validate.sh +103 -0
Dockerfile CHANGED
@@ -7,11 +7,10 @@ COPY requirements.txt .
 RUN pip install --no-cache-dir --upgrade pip && \
     pip install --no-cache-dir -r requirements.txt
 
-# Copy
-COPY
-COPY static/ ./static/
+# Copy all project files (needed for openenv validate to work inside)
+COPY . .
 
-# Environment defaults
+# Environment defaults (Hugging Face Spaces use 7860)
 ENV PORT=7860
 ENV PYTHONPATH=/app
 ENV ENABLE_WEB_INTERFACE=false
README.md CHANGED
@@ -1,156 +1,130 @@
-title: Code Review Env
-emoji: π
-colorFrom: red
-colorTo: purple
-sdk: docker
-pinned: false
----
-# Code Security Review — OpenEnv
-> vulnerabilities in Python code.
-##
-- **Code:** `authenticate_user()` uses `or` instead of `and` for admin check
-- **Expected bug type:** `logic-error`
-- **Expected severity:** `critical`
-- **Baseline score:** ~0.60
-- **Expected bug type:** `security-vulnerability`
-- **Expected severity:** `critical`
-- **Baseline score:** ~0.55
-##
-| Severity | 0.10 | Correct severity level |
-| **Total** | **1.00** | |
-###
-docker build -t code-
-docker run -p
-###
-# Set your environment variables
-export HF_TOKEN=hf_your_token_here
-export MODEL_NAME=Qwen/Qwen2.5-72B-Instruct
-export API_BASE_URL=https://router.huggingface.co/v1
-export ENV_BASE_URL=http://localhost:7860
-# Install dependencies
-# Run
-python inference.py
-### 3. Validate (OpenEnv CLI)
-openenv validate
-##
-| Method | Path | Description |
-|---|---|---|
-| GET | `/health` | Health check |
-| POST | `/reset?difficulty=easy` | Reset environment |
-| POST | `/step` | Submit a review action |
-| GET | `/state` | Current episode state |
-## Baseline Scores
-| Task | Difficulty | Reward |
-|---|---|---|
-| Off-by-one detection | Easy | ~0.72 |
-| Auth logic flaw | Medium | ~0.60 |
-| SQL injection | Hard | ~0.55 |
-| **Average** | | **~0.62** |
-code-review-env/
-├── Dockerfile
-├── openenv.yaml
-├── requirements.txt
-├── inference.py
-├── README.md
-└── server/
-    ├── __init__.py
-    ├── app.py          # FastAPI endpoints
-    ├── environment.py  # Tasks + grader logic
-    └── models.py       # Pydantic action/observation/state
# Code Security Review — OpenEnv Environment

An RL environment for training AI agents to perform real-world code security review.
Agents analyze code snippets from production pull requests and identify bugs,
vulnerabilities, and security issues.

Built by **Inmodel Labs** for the Meta PyTorch OpenEnv Hackathon.

---

## Environment Overview

| Field | Value |
|---|---|
| Tasks | 3 (easy → medium → hard) |
| Languages | Python, JavaScript |
| Action space | Structured JSON (6 fields) |
| Reward range | 0.0 – 1.0 |
| Steps per episode | 1 |

---

## Tasks

| ID | Language | Bug Class | Difficulty |
|---|---|---|---|
| `python-off-by-one` | Python | Off-by-one index error | Easy |
| `js-auth-privilege` | JavaScript | Logic flaw → privilege escalation | Medium |
| `python-sql-injection` | Python | SQL injection via f-string | Hard |

---
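The hard task's bug class is easy to demonstrate in isolation. The snippet below is a hypothetical example in the spirit of the task, not its actual code (`search_users_unsafe` and `search_users_safe` are made-up names):

```python
import sqlite3

# Vulnerable pattern: interpolating user input into SQL with an f-string
# lets crafted input change the query's structure.
def search_users_unsafe(db, search_term):
    query = f"SELECT name FROM users WHERE name LIKE '%{search_term}%'"
    return db.execute(query).fetchall()

# Fix: parameterized query — the driver treats the value as data, not SQL.
def search_users_safe(db, search_term):
    return db.execute(
        "SELECT name FROM users WHERE name LIKE '%' || ? || '%'",
        (search_term,),
    ).fetchall()

db = sqlite3.connect(":memory:")
db.execute("CREATE TABLE users (name TEXT)")
db.executemany("INSERT INTO users VALUES (?)", [("alice",), ("bob",)])

# A crafted term turns the unsafe query into a tautology and leaks every row.
payload = "' OR '1'='1"
print(len(search_users_unsafe(db, payload)))  # 2 — all rows leak
print(len(search_users_safe(db, payload)))    # 0 — no literal match
```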

## Action Space

The agent submits a JSON action with these fields:

```json
{
  "bug_identified": true,
  "bug_location": "line 3 — range(len(transactions) + 1)",
  "bug_type": "logic-error",
  "bug_description": "Off-by-one error causes IndexError on last iteration...",
  "severity": "medium",
  "suggested_fix": "Change range(len(transactions) + 1) to range(len(transactions))"
}
```
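The example action above targets an off-by-one loop. A minimal sketch of that bug class (the `transactions` list and both functions are hypothetical, not the task's real snippet):

```python
transactions = [100, -40, 250]

def sum_buggy(items):
    total = 0
    for i in range(len(items) + 1):  # off-by-one: i reaches len(items)
        total += items[i]            # IndexError on the final pass
    return total

def sum_fixed(items):
    total = 0
    for i in range(len(items)):      # the suggested fix: drop the "+ 1"
        total += items[i]
    return total

try:
    sum_buggy(transactions)
except IndexError as exc:
    print(f"buggy version failed: {exc}")

print(sum_fixed(transactions))  # 310
```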

## Observation Space

```json
{
  "task_id": "python-sql-injection",
  "language": "Python",
  "difficulty": "hard",
  "code_snippet": "def search_users(db, search_term):\n    ...",
  "context": "REST API endpoint that searches users by name",
  "pr_title": "Add user search endpoint to REST API",
  "file_path": "api/users.py"
}
```

---

## Reward Breakdown

| Component | Max Score |
|---|---|
| Bug identified | 0.20 |
| Bug type correct | 0.20 |
| Bug location correct | 0.10 |
| Description quality | 0.25 |
| Fix quality | 0.15 |
| Severity correct | 0.10 |
| **Total** | **1.00** |

The grader penalises keyword stuffing — incoherent keyword dumps score ≤ 0.20.

**Example Calculation:**
If the agent correctly identifies a bug (+0.20), misidentifies the type (+0.0), finds 50% of the location keywords (+0.05), writes a detailed and coherent description matching most keywords (+0.25), suggests a partially correct fix (+0.08), and gets the severity correct (+0.10), the total reward for that step would be `0.20 + 0.0 + 0.05 + 0.25 + 0.08 + 0.10 = 0.68`.
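The arithmetic in the worked example can be checked in a couple of lines (the component values are taken from the example above; the real scorer lives in `server/grader.py`):

```python
# Component scores from the worked example above.
components = {
    "bug_identified": 0.20,   # correct
    "bug_type": 0.0,          # misidentified
    "bug_location": 0.05,     # 50% of the 0.10 location credit
    "bug_description": 0.25,  # full marks
    "suggested_fix": 0.08,    # partial credit out of 0.15
    "severity": 0.10,         # correct
}
reward = round(sum(components.values()), 2)
print(reward)  # 0.68
assert 0.0 <= reward <= 1.0  # always within the documented reward range
```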
---

## Edge Cases

- **At step 0:** `reset()` must be called to initialize the state. If `step()` is called before `reset()`, the environment automatically calls `reset()` internally and evaluates the action on a random task.
- **Max step limit:** The maximum step limit is 1. Calling `step()` evaluates the action and immediately sets `done=True`.
- **At done=True:** Calling `step()` returns `reward=0.0`, `done=True`, and a clear error message in the `info` dict (`"Episode already completed. Call /reset..."`) indicating the episode is complete, without auto-resetting.

---
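The termination rules above can be sketched as a tiny state machine. This is illustrative only — the real logic lives in `server/environment.py`, and both `MiniEpisode` and the placeholder `0.5` reward are invented here:

```python
class MiniEpisode:
    """Single-step episode mirroring the edge-case rules above."""

    def __init__(self):
        self.started = False
        self.done = False

    def reset(self):
        self.started = True
        self.done = False
        return {"observation": "<task observation>"}

    def step(self, action):
        if self.done:
            # A finished episode rejects further steps instead of auto-resetting.
            return {"reward": 0.0, "done": True,
                    "info": "Episode already completed. Call /reset..."}
        if not self.started:
            # step() before reset(): initialize on the fly.
            self.reset()
        self.done = True  # max step limit is 1
        # Grading omitted; a placeholder reward stands in for the grader.
        return {"reward": 0.5, "done": True, "info": ""}

env = MiniEpisode()
first = env.step({})   # works even without an explicit reset()
second = env.step({})  # episode already over
print(first["done"], second["reward"])  # True 0.0
```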

## API Endpoints

| Method | Path | Description |
|---|---|---|
| GET | `/` | Health check |
| POST | `/reset?task_id=<id>` | Reset environment, returns observation |
| POST | `/step` | Submit action, returns reward |
| GET | `/state` | Current episode state |
| GET | `/tasks` | List all tasks |

---

## Setup

### Docker

```bash
docker build -t code-security-review .
docker run -p 8000:8000 code-security-review
```

### Local

```bash
pip install -r requirements.txt
uvicorn server.app:app --host 0.0.0.0 --port 8000
```

---

## Running Inference

```bash
export API_BASE_URL="https://api.openai.com/v1"
export MODEL_NAME="gpt-4o-mini"
export HF_TOKEN="your-api-key"
export ENV_URL="http://localhost:8000"

python inference.py
```
|
inference.py CHANGED
@@ -6,28 +6,30 @@ Required environment variables:
 API_BASE_URL — LLM API endpoint
 MODEL_NAME — Model identifier
 HF_TOKEN — Hugging Face / API key
+ENV_URL — Running environment URL (default: http://localhost:7860)
 """
 
 import os
 import json
 import time
 import re
+import requests
 from typing import List, Optional
 from dotenv import load_dotenv
+from openai import OpenAI
 
 # Load .env variables
 load_dotenv()
 
-import requests
-from openai import OpenAI
-
 # ── Config ────────────────────────────────────────────────────────────────────
-API_BASE_URL = os.environ.get("API_BASE_URL", "https://
-MODEL_NAME = os.environ.get("MODEL_NAME", "
-HF_TOKEN = os.environ.get("HF_TOKEN"
-
-BENCHMARK = "code-
+API_BASE_URL = os.environ.get("API_BASE_URL", "https://api.openai.com/v1")
+MODEL_NAME = os.environ.get("MODEL_NAME", "gpt-4o-mini")
+HF_TOKEN = os.environ.get("HF_TOKEN") or os.environ.get("API_KEY")
+ENV_URL = os.environ.get("ENV_URL", "http://localhost:7860")
+BENCHMARK = "code-security-review"
+
+if not HF_TOKEN:
+    raise ValueError("HF_TOKEN or API_KEY must be set.")
 
 client = OpenAI(base_url=API_BASE_URL, api_key=HF_TOKEN)
 
@@ -41,7 +43,7 @@ Schema:
 {
 "bug_identified": true or false,
 "bug_location": "exact location (function name, line description, variable, expression)",
-"bug_type": "off-by-one | logic-error | security-vulnerability |
+"bug_type": "off-by-one | logic-error | security-vulnerability | none",
 "bug_description": "detailed explanation of why this is a bug and the impact",
 "severity": "none | low | medium | high | critical",
 "suggested_fix": "the corrected code snippet or a precise description of the fix"
 
@@ -69,7 +71,7 @@ def log_end(success: bool, steps: int, score: float, rewards: List[float]) -> No
 # ── Helpers ───────────────────────────────────────────────────────────────────
 
 def env_post(path: str, data: Optional[dict] = None, params: Optional[dict] = None) -> dict:
-    url = f"{
+    url = f"{ENV_URL}{path}"
     resp = requests.post(url, json=data or {}, params=params or {}, timeout=30)
     resp.raise_for_status()
     return resp.json()
 
@@ -80,41 +82,49 @@ def parse_json_from_llm(text: str) -> dict:
     text = text.strip()
     text = re.sub(r"^```(?:json)?\s*", "", text)
     text = re.sub(r"\s*```$", "", text)
+    # If the LLM still included text around the JSON, try to find the first { and last }
+    match = re.search(r"({.*})", text, re.DOTALL)
+    if match:
+        text = match.group(1)
+    try:
+        return json.loads(text)
+    except Exception:
+        return {}
 
 
 def build_prompt(obs: dict) -> str:
     lines = [
         f"Language: {obs['language']}",
-        f"
+        f"Context: {obs.get('context', 'No context provided')}",
+        f"PR Title: {obs.get('pr_title', 'No PR title')}",
+        f"File Path: {obs.get('file_path', 'unknown')}",
         "",
         f"```{obs['language']}",
         obs["code_snippet"],
         "```",
     ]
-    if obs.get("previous_feedback"):
-        lines += ["", f"Previous feedback: {obs['previous_feedback']}",
-                  "Revise your analysis accordingly."]
     return "\n".join(lines)
 
 
 # ── Task runner ───────────────────────────────────────────────────────────────
 
-def run_task(
-    reset_resp = env_post("/reset", params={"
+def run_task(task_id: str, task_num: int) -> dict:
+    reset_resp = env_post("/reset", params={"task_id": task_id})
     obs = reset_resp["observation"]
 
     log_start(task=task_id, env=BENCHMARK, model=MODEL_NAME)
 
+    cumulative_reward = 0.0
+    step_num = 0
+    max_steps = 1
     done = False
+    all_rewards = []
+    error = None
 
-    while not done and
+    while not done and step_num < max_steps:
+        step_num += 1
         prompt = build_prompt(obs)
+        action_dict = {}
 
         # ── LLM call ──────────────────────────────────────────────────────────
         try:
@@ -126,20 +136,21 @@ def run_task(difficulty: str) -> dict:
             ],
             temperature=0.1,
             max_tokens=600,
+            stream=False,
         )
         raw = response.choices[0].message.content
         action_dict = parse_json_from_llm(raw)
         action_str = json.dumps(action_dict)
+        error = None
     except Exception as exc:
+        error = str(exc).replace("\n", " ")
         action_dict = {
             "bug_identified": False,
-            "bug_location": "
+            "bug_location": "none",
             "bug_type": "none",
-            "bug_description":
+            "bug_description": f"Error: {error}",
             "severity": "none",
-            "suggested_fix": "",
+            "suggested_fix": "none",
         }
         action_str = "{}"
 
@@ -147,44 +158,56 @@ def run_task(difficulty: str) -> dict:
         step_resp = env_post("/step", data=action_dict)
         reward = step_resp["reward"]
         done = step_resp["done"]
-        obs = step_resp
+        obs = step_resp.get("observation")
 
+        all_rewards.append(reward)
+        cumulative_reward += reward
+
+        log_step(step=step_num, action=action_str, reward=reward, done=done, error=error)
 
-    total_reward = sum(rewards)
-    score = min(max(total_reward, 0.0), 1.0)
-    success = score >= 0.8
+    success = cumulative_reward >= 0.8
+    log_end(success=success, steps=step_num, score=cumulative_reward, rewards=all_rewards)
 
-    log_end(success=success, steps=steps_taken, score=score, rewards=rewards)
-
     return {
-        "
-        "
-        "
+        "task_num": task_num,
+        "task_id": task_id,
+        "score": cumulative_reward,
+        "success": success,
     }
 
 
 # ── Main ──────────────────────────────────────────────────────────────────────
 
 def main():
+    print(f"[INFO] Initializing inference on {BENCHMARK} using {MODEL_NAME}", flush=True)
+
+    TASK_FILTER = os.environ.get("TASK")
+
+    all_tasks = [
+        ("python-off-by-one", 1, "easy"),
+        ("js-auth-privilege", 2, "medium"),
+        ("python-sql-injection", 3, "hard"),
+    ]
+
+    if TASK_FILTER:
+        tasks = [t for t in all_tasks if t[2] == TASK_FILTER]
+    else:
+        tasks = all_tasks
+
     results = []
 
-    for
+    for task_id, task_num, _ in tasks:
         try:
-            r = run_task(
-            results.append(r)
+            r = run_task(task_id, task_num)
         except Exception as exc:
+            print(f"[ERROR] task_id={task_id} error={exc}", flush=True)
+            r = {"task_num": task_num, "task_id": task_id, "score": 0.0, "success": False}
+        results.append(r)
 
     if results:
-        avg = sum(r["score"] for r in results) / len(results)
+        avg = round(sum(r["score"] for r in results) / len(results), 3)
+        successes = sum(1 for r in results if r.get("success"))
+        print(f"\n[SUMMARY] avg_reward={avg} tasks_passed={successes}/{len(results)}", flush=True)
 
 if __name__ == "__main__":
     main()
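The JSON-extraction change in this diff — strip Markdown fences, then fall back to the outermost braces — can be exercised standalone. The function below reproduces the diff's logic for illustration:

```python
import json
import re

def parse_json_from_llm(text: str) -> dict:
    """Extract a JSON object from an LLM reply that may wrap it in prose or fences."""
    text = text.strip()
    text = re.sub(r"^```(?:json)?\s*", "", text)   # opening ``` or ```json fence
    text = re.sub(r"\s*```$", "", text)            # closing fence
    # If the LLM still included text around the JSON, grab the outermost braces.
    match = re.search(r"({.*})", text, re.DOTALL)
    if match:
        text = match.group(1)
    try:
        return json.loads(text)
    except Exception:
        return {}

print(parse_json_from_llm('```json\n{"severity": "high"}\n```'))                 # {'severity': 'high'}
print(parse_json_from_llm('Sure! Here it is: {"severity": "low"} Hope it helps.'))  # {'severity': 'low'}
print(parse_json_from_llm("not json at all"))                                    # {}
```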
openenv.yaml CHANGED
@@ -1,47 +1,50 @@
+# OpenEnv Environment Specification
+# This file describes the Code Security Review environment for the Meta PyTorch OpenEnv Hackathon.
+
+# Metadata section details the environment's identity.
+name: code-security-review
+version: "1.0.0"
-description: >
-  RL environment for training AI agents to detect bugs and security
-  vulnerabilities in real Python code. Covers off-by-one errors,
-  authentication logic flaws, and SQL injection — with deterministic
-  programmatic graders and partial-progress reward signals.
+description: >
+  An RL environment for training AI agents to perform code security review.
+  Agents analyze code snippets from production pull requests and identify bugs,
+  vulnerabilities, and security issues.
 author: Inmodel Labs
-tags:
-  - code-review
-  - security
-  - software-engineering
-  - real-world
-  - python
 
+# Tasks section defines the core challenges in the environment.
+# Each task has a unique ID, name, description, and difficulty level.
 tasks:
-  - id:
+  - id: python-off-by-one
+    name: "Python Off-by-One Error"
+    description: "Identify an off-by-one index error in a Python finance batch processor"
     difficulty: easy
+    max_steps: 1
+    reward_range: [0.0, 1.0]
 
-  - id:
+  - id: js-auth-privilege
+    name: "JavaScript Auth Logic Flaw"
+    description: "Identify a privilege escalation vulnerability in Node.js auth middleware"
     difficulty: medium
+    max_steps: 1
+    reward_range: [0.0, 1.0]
 
-  - id:
+  - id: python-sql-injection
+    name: "Python SQL Injection"
+    description: "Identify an SQL injection vulnerability via f-string in a REST API"
     difficulty: hard
+    max_steps: 1
+    reward_range: [0.0, 1.0]
 
+# The action space defines the format of the agent's response.
+# Each field is scored by the grader to provide partial progress signals.
 action_space:
   type: object
   properties:
-    bug_identified: { type: boolean }
-    bug_location: { type: string }
-    bug_type: { type: string }
-    bug_description: { type: string }
-    severity: { type: string, enum: [none, low, medium, high, critical] }
-    suggested_fix: { type: string }
+    bug_identified: { type: boolean, description: "Boolean: true if a bug exists" }
+    bug_location: { type: string, description: "String: Pinpoint the bug's location in code" }
+    bug_type: { type: string, description: "String: off-by-one | logic-error | security-vulnerability | none" }
+    bug_description: { type: string, description: "String: Detailed analysis of the vulnerability" }
+    severity: { type: string, enum: [none, low, medium, high, critical], description: "String: none | low | medium | high | critical" }
+    suggested_fix: { type: string, description: "String: How to fix the identified bug" }
   required:
     - bug_identified
    - bug_location
@@ -50,18 +53,20 @@ action_space:
    - severity
    - suggested_fix
 
+# The observation space defines what the agent sees at each step.
+# It uses structured context to help the agent understand the code's purpose.
 observation_space:
   type: object
   properties:
-    language: { type: string }
-    previous_feedback: { type: string, nullable: true }
+    task_id: { type: string, description: "Unique task identifier" }
+    language: { type: string, description: "Source code language" }
+    difficulty: { type: string, enum: [easy, medium, hard], description: "Task complexity (easy/medium/hard)" }
+    code_snippet: { type: string, description: "The source code to be reviewed" }
+    context: { type: string, description: "Real-world context (e.g., API description)" }
+    pr_title: { type: string, description: "Pull Request title for additional intent context" }
+    file_path: { type: string, description: "Relative path to the file in the repository" }
 
+# Reward structure for evaluating agent performance.
 reward:
   min: 0.0
   max: 1.0
@@ -69,9 +74,11 @@ reward:
    Partial rewards for: bug identification (0.20), correct bug type (0.20),
    precise location (0.10), description quality (0.25, keyword density),
    fix quality (0.15, keyword density), correct severity (0.10).
+   Grader penalizes keyword stuffing.
 
 endpoints:
-  health: GET
-  reset:
-  step:
-  state:
+  health: GET /
+  reset: POST /reset
+  step: POST /step
+  state: GET /state
+  tasks: GET /tasks
qa_test.py ADDED
@@ -0,0 +1,237 @@
+import requests
+import json
+
+BASE_URL = "http://localhost:7860"
+
+def run_tests():
+    checks = []
+
+    # 1. GET /
+    try:
+        r = requests.get(f"{BASE_URL}/")
+        passed = r.status_code == 200 and r.json().get("status") == "ok"
+        checks.append({
+            "id": 1, "name": "GET / health check", "passed": passed,
+            "expected": 'HTTP 200 and {"status": "ok"}', "got": f"HTTP {r.status_code} {r.text}"
+        })
+    except Exception as e:
+        checks.append({"id": 1, "name": "GET / health check", "passed": False, "expected": "200 OK", "got": str(e)})
+
+    # 15. GET /state before reset (edge case)
+    try:
+        r = requests.get(f"{BASE_URL}/state")
+        # Should not crash
+        checks.append({
+            "id": 15, "name": "GET /state before any reset", "passed": r.status_code == 200,
+            "expected": "HTTP 200 (no crash)", "got": f"HTTP {r.status_code} {r.text}"
+        })
+    except Exception as e:
+        checks.append({"id": 15, "name": "GET /state before any reset", "passed": False, "expected": "200 OK", "got": str(e)})
+
+    # 2. POST /reset
+    try:
+        r = requests.post(f"{BASE_URL}/reset")
+        data = r.json().get("observation", {})
+        required = ["task_id", "language", "difficulty", "code_snippet", "context", "pr_title", "file_path"]
+        passed = all(k in data for k in required)
+        checks.append({
+            "id": 2, "name": "POST /reset fields check", "passed": passed,
+            "expected": f"JSON with {required}", "got": list(data.keys())
+        })
+    except Exception as e:
+        checks.append({"id": 2, "name": "POST /reset fields check", "passed": False, "expected": "Fields", "got": str(e)})
+
+    # 16. POST /reset with no task_id
+    try:
+        r = requests.post(f"{BASE_URL}/reset")
+        checks.append({
+            "id": 16, "name": "POST /reset no task_id (random)", "passed": r.status_code == 200,
+            "expected": "HTTP 200", "got": f"HTTP {r.status_code}"
+        })
+    except Exception as e:
+        checks.append({"id": 16, "name": "POST /reset no task_id (random)", "passed": False, "expected": "200 OK", "got": str(e)})
+
+    # 3-5. POST /reset?task_id=...
+    for tid in ["python-off-by-one", "js-auth-privilege", "python-sql-injection"]:
+        num = {"python-off-by-one": 3, "js-auth-privilege": 4, "python-sql-injection": 5}[tid]
+        try:
+            r = requests.post(f"{BASE_URL}/reset?task_id={tid}")
+            passed = r.status_code == 200 and r.json()["observation"]["task_id"] == tid
+            checks.append({
+                "id": num, "name": f"POST /reset for {tid}", "passed": passed,
+                "expected": f"HTTP 200 with task_id={tid}", "got": f"HTTP {r.status_code} {r.json()['observation']['task_id'] if passed else r.text}"
+            })
+        except Exception as e:
+            checks.append({"id": num, "name": f"POST /reset for {tid}", "passed": False, "expected": "200 OK", "got": str(e)})
+
+    # 6. GET /state
+    try:
+        r = requests.get(f"{BASE_URL}/state")
+        data = r.json()
+        required = ["task_id", "step", "done", "total_reward"]
+        passed = all(k in data for k in required)
+        checks.append({
+            "id": 6, "name": "GET /state fields check", "passed": passed,
+            "expected": f"JSON with {required}", "got": list(data.keys())
+        })
+    except Exception as e:
+        checks.append({"id": 6, "name": "GET /state fields check", "passed": False, "expected": "Fields", "got": str(e)})
+
+    # 7. POST /step with a provided action
+    try:
+        requests.post(f"{BASE_URL}/reset?task_id=python-sql-injection")
+        action = {
+            "bug_identified": True,
+            "bug_location": "line 2 f-string",
+            "bug_type": "security-vulnerability",
+            "bug_description": "SQL injection via f-string",
+            "severity": "critical",
+            "suggested_fix": "use parameterized query"
+        }
+        r = requests.post(f"{BASE_URL}/step", json=action)
+        res = r.json()
+        reward = res.get("reward", -1.0)
|
| 94 |
+
done = res.get("done", False)
|
| 95 |
+
passed = 0.0 <= reward <= 1.0 and done is True
|
| 96 |
+
checks.append({
|
| 97 |
+
"id": 7, "name": "POST /step valid action", "passed": passed,
|
| 98 |
+
"expected": "Reward [0,1] and done=true", "got": f"reward={reward}, done={done}"
|
| 99 |
+
})
|
| 100 |
+
except Exception as e:
|
| 101 |
+
checks.append({"id": 7, "name": "POST /step valid action", "passed": False, "expected": "Result", "got": str(e)})
|
| 102 |
+
|
| 103 |
+
# 14. Call POST /step twice (Edge Case)
|
| 104 |
+
try:
|
| 105 |
+
# Step already called in task 7
|
| 106 |
+
action = {"bug_identified": False, "bug_location": "", "bug_type": "none", "bug_description": "", "severity": "none", "suggested_fix": ""}
|
| 107 |
+
r = requests.post(f"{BASE_URL}/step", json=action)
|
| 108 |
+
res = r.json()
|
| 109 |
+
passed = r.status_code == 200 and "error" in res.get("info", {})
|
| 110 |
+
checks.append({
|
| 111 |
+
"id": 14, "name": "POST /step twice in same episode", "passed": passed,
|
| 112 |
+
"expected": "HTTP 200 and error in info", "got": f"HTTP {r.status_code}, info={res.get('info')}"
|
| 113 |
+
})
|
| 114 |
+
except Exception as e:
|
| 115 |
+
checks.append({"id": 14, "name": "POST /step twice in same episode", "passed": False, "expected": "Handled error", "got": str(e)})
|
| 116 |
+
|
| 117 |
+
# 8. Perfect action for SQL
|
| 118 |
+
try:
|
| 119 |
+
requests.post(f"{BASE_URL}/reset?task_id=python-sql-injection")
|
| 120 |
+
perfect_action = {
|
| 121 |
+
"bug_identified": True,
|
| 122 |
+
"bug_location": "line 2 f-string interpolation in SQL query construction",
|
| 123 |
+
"bug_type": "security-vulnerability",
|
| 124 |
+
"bug_description": "SQL injection vulnerability where user-supplied search_term is directly interpolated into the SQL query via f-string. An attacker can inject malicious SQL to bypass authentication, exfiltrate all user data, or drop tables. The fix is to use parameterized queries which sanitize user input automatically.",
|
| 125 |
+
"severity": "critical",
|
| 126 |
+
"suggested_fix": "Use db.execute('SELECT * FROM users WHERE name LIKE %s', ('%'+search_term+'%',)) instead of f-string interpolation"
|
| 127 |
+
}
|
| 128 |
+
r = requests.post(f"{BASE_URL}/step", json=perfect_action)
|
| 129 |
+
reward = r.json().get("reward", 0.0)
|
| 130 |
+
checks.append({
|
| 131 |
+
"id": 8, "name": "PERFECT action SQL", "passed": reward >= 0.85,
|
| 132 |
+
"expected": "Reward >= 0.85", "got": f"reward={reward}"
|
| 133 |
+
})
|
| 134 |
+
except Exception as e:
|
| 135 |
+
checks.append({"id": 8, "name": "PERFECT action SQL", "passed": False, "expected": ">=0.85", "got": str(e)})
|
| 136 |
+
|
| 137 |
+
# 9. Keyword stuffed
|
| 138 |
+
try:
|
| 139 |
+
requests.post(f"{BASE_URL}/reset?task_id=python-sql-injection")
|
| 140 |
+
stuffed_action = {
|
| 141 |
+
"bug_identified": True,
|
| 142 |
+
"bug_location": "sql",
|
| 143 |
+
"bug_type": "security-vulnerability",
|
| 144 |
+
"bug_description": "sql injection sql injection sql injection parameterized f-string sanitize escape malicious attack tautology union drop sql injection sql injection",
|
| 145 |
+
"severity": "critical",
|
| 146 |
+
"suggested_fix": "fix"
|
| 147 |
+
}
|
| 148 |
+
r = requests.post(f"{BASE_URL}/step", json=stuffed_action)
|
| 149 |
+
reward = r.json().get("reward", 1.0)
|
| 150 |
+
checks.append({
|
| 151 |
+
"id": 9, "name": "KEYWORD STUFFED action", "passed": reward <= 0.20,
|
| 152 |
+
"expected": "Reward <= 0.20", "got": f"reward={reward}"
|
| 153 |
+
})
|
| 154 |
+
except Exception as e:
|
| 155 |
+
checks.append({"id": 9, "name": "KEYWORD STUFFED action", "passed": False, "expected": "<=0.20", "got": str(e)})
|
| 156 |
+
|
| 157 |
+
# 10. Bug identified false
|
| 158 |
+
try:
|
| 159 |
+
requests.post(f"{BASE_URL}/reset")
|
| 160 |
+
action = {"bug_identified": False, "bug_location": "", "bug_type": "none", "bug_description": "", "severity": "none", "suggested_fix": ""}
|
| 161 |
+
r = requests.post(f"{BASE_URL}/step", json=action)
|
| 162 |
+
reward = r.json().get("reward", 1.0)
|
| 163 |
+
checks.append({
|
| 164 |
+
"id": 10, "name": "Identify=False empty fields", "passed": reward == 0.0,
|
| 165 |
+
"expected": "Reward exactly 0.0", "got": f"reward={reward}"
|
| 166 |
+
})
|
| 167 |
+
except Exception as e:
|
| 168 |
+
checks.append({"id": 10, "name": "Identify=False empty fields", "passed": False, "expected": "0.0", "got": str(e)})
|
| 169 |
+
|
| 170 |
+
# 11. Partial credit severity
|
| 171 |
+
try:
|
| 172 |
+
# Off-by-one is severity critical (I set it to critical).
|
| 173 |
+
# Let's say I submit 'low' severity.
|
| 174 |
+
requests.post(f"{BASE_URL}/reset?task_id=python-off-by-one")
|
| 175 |
+
action = {
|
| 176 |
+
"bug_identified": True, "bug_location": "range", "bug_type": "off-by-one",
|
| 177 |
+
"bug_description": "off-by-one error in range function call",
|
| 178 |
+
"severity": "low", # Wrong severity
|
| 179 |
+
"suggested_fix": "range(len(x))"
|
| 180 |
+
}
|
| 181 |
+
r = requests.post(f"{BASE_URL}/step", json=action)
|
| 182 |
+
info = r.json().get("info", {})
|
| 183 |
+
breakdown = info.get("reward_breakdown", {})
|
| 184 |
+
sev_score = breakdown.get("severity", -1.0)
|
| 185 |
+
# It should be 0.0 (wrong) but the total should still have partial credit from other components
|
| 186 |
+
reward = r.json().get("reward", 0.0)
|
| 187 |
+
checks.append({
|
| 188 |
+
"id": 11, "name": "Partial credit (wrong severity)", "passed": 0.0 < reward < 1.0,
|
| 189 |
+
"expected": "Reward between 0 and 1 (partial credit)", "got": f"reward={reward}, severity_component={sev_score}"
|
| 190 |
+
})
|
| 191 |
+
except Exception as e:
|
| 192 |
+
checks.append({"id": 11, "name": "Partial credit (wrong severity)", "passed": False, "expected": "Partial credit", "got": str(e)})
|
| 193 |
+
|
| 194 |
+
# 12-13. Breakdown keys and components
|
| 195 |
+
try:
|
| 196 |
+
requests.post(f"{BASE_URL}/reset")
|
| 197 |
+
action = {"bug_identified": True, "bug_location": "test", "bug_type": "test", "bug_description": "test test test test test test test test test test test test test test test test test test test test", "severity": "none", "suggested_fix": "test test test"}
|
| 198 |
+
r = requests.post(f"{BASE_URL}/step", json=action)
|
| 199 |
+
info = r.json().get("info", {})
|
| 200 |
+
breakdown = info.get("reward_breakdown", {})
|
| 201 |
+
required = ["bug_identified", "bug_type", "bug_location", "description_quality", "fix_quality", "severity"]
|
| 202 |
+
checks.append({
|
| 203 |
+
"id": 12, "name": "Reward breakdown keys", "passed": all(k in breakdown for k in required),
|
| 204 |
+
"expected": f"Breakdown with {required}", "got": list(breakdown.keys())
|
| 205 |
+
})
|
| 206 |
+
|
| 207 |
+
max_vals = {
|
| 208 |
+
"bug_identified": 0.20, "bug_type": 0.20, "bug_location": 0.10,
|
| 209 |
+
"description_quality": 0.25, "fix_quality": 0.15, "severity": 0.10
|
| 210 |
+
}
|
| 211 |
+
passed_range = all(0.0 <= breakdown.get(k, -1) <= max_vals[k] for k in max_vals)
|
| 212 |
+
checks.append({
|
| 213 |
+
"id": 13, "name": "Component score ranges", "passed": passed_range,
|
| 214 |
+
"expected": "All components <= max", "got": breakdown
|
| 215 |
+
})
|
| 216 |
+
except Exception as e:
|
| 217 |
+
checks.append({"id": 12, "name": "Breakdown checks", "passed": False, "expected": "Breakdown", "got": str(e)})
|
| 218 |
+
|
| 219 |
+
# Sort and print
|
| 220 |
+
checks.sort(key=lambda x: x["id"])
|
| 221 |
+
for c in checks:
|
| 222 |
+
status = "PASS" if c["passed"] else "FAIL"
|
| 223 |
+
print(f"[{c['id']}] {c['name']} β {status}")
|
| 224 |
+
print(f" Expected: {c['expected']}")
|
| 225 |
+
print(f" Got: {c['got']}")
|
| 226 |
+
print("")
|
| 227 |
+
|
| 228 |
+
passed_count = sum(1 for c in checks if c["passed"])
|
| 229 |
+
disqual = "YES" if passed_count < 7 else "NO" # Disqualified if Part 1 fails
|
| 230 |
+
print(f"TOTAL: {passed_count}/16 passed")
|
| 231 |
+
print(f"DISQUALIFICATION RISK: {disqual}")
|
| 232 |
+
# Estimate score based on points
|
| 233 |
+
score = (passed_count / 16) * 100
|
| 234 |
+
print(f"ESTIMATED SCORE: {round(score)}/100")
|
| 235 |
+
|
| 236 |
+
if __name__ == "__main__":
|
| 237 |
+
run_tests()
|
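The tally printed at the end of the suite can be factored into a small standalone helper. This is a sketch only; `summarize` is a hypothetical name, and its thresholds simply mirror the constants the script already uses (16 checks, disqualification below 7 passes):

```python
def summarize(checks, total_expected=16, disqualify_below=7):
    """Mirror the QA script's report: pass count, disqualification flag,
    and a 0-100 score scaled by the pass rate."""
    passed = sum(1 for c in checks if c["passed"])
    return {
        "passed": passed,
        "total": total_expected,
        "disqualified": passed < disqualify_below,
        "score": round(passed / total_expected * 100),
    }

# Example: 8 of 16 checks passing scores 50/100 and clears the threshold.
report = summarize([{"id": i, "passed": i % 2 == 0} for i in range(1, 17)])
print(report)  # {'passed': 8, 'total': 16, 'disqualified': False, 'score': 50}
```

Keeping the reporting logic in one pure function makes it easy to unit-test without a running server.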
server/app.py
CHANGED

```diff
@@ -1,19 +1,16 @@
 import os
 import uvicorn
+from typing import List, Optional
 from fastapi import FastAPI, HTTPException, Query
 from fastapi.middleware.cors import CORSMiddleware
-from fastapi.staticfiles import StaticFiles
-from fastapi.responses import FileResponse
 
-from .models import CodeReviewAction, …
-from .…
+from server.models import CodeReviewAction, StepResult, ResetResponse, StateResponse, TaskInfo
+from server.tasks import TASKS
+from server.environment import CodeSecurityEnv
 
 app = FastAPI(
     title="Code Security Review — OpenEnv",
-    description=(
-        "RL environment for training AI agents to detect bugs and security "
-        "vulnerabilities in code. Compatible with the OpenEnv spec."
-    ),
+    description="An RL environment for training AI agents to perform code security review.",
     version="1.0.0",
 )
@@ -24,46 +21,61 @@ app.add_middleware(
     allow_headers=["*"],
 )
 
-
-env = CodeReviewEnvironment()
+env = CodeSecurityEnv()
 
 @app.get("/")
-def read_index():
-    return FileResponse("static/index.html")
-
-
-@app.get("/health")
 def health():
-    …
+    """Health check endpoint."""
+    return {
+        "status": "ok",
+        "project": "Code Security Review - OpenEnv",
+        "version": "1.0.0",
+        "organization": "Inmodel Labs",
+    }
+
+
+@app.get("/tasks", response_model=List[TaskInfo])
+def list_tasks():
+    """List all available tasks."""
+    return [
+        TaskInfo(
+            id=t["id"],
+            language=t["language"],
+            bug_class=t["bug_class"],
+            difficulty=t["difficulty"],
+        )
+        for t in TASKS.values()
+    ]
 
 
 @app.post("/reset", response_model=ResetResponse)
-def reset(…):
+def reset(
+    task_id: str = Query(default="python-off-by-one", description="Task ID to reset to"),
+    seed: Optional[int] = Query(default=None, description="Optional seed for reproducibility"),
+):
     """Reset the environment and return the first observation."""
-    …
+    if task_id not in TASKS:
+        raise HTTPException(status_code=404, detail=f"Task '{task_id}' not found.")
+    obs = env.reset(task_id=task_id, seed=seed)
     return ResetResponse(observation=obs)
 
 
-@app.post("/step", response_model=…)
+@app.post("/step", response_model=StepResult)
 def step(action: CodeReviewAction):
     """Submit a code review action and receive a reward signal."""
-    try:
-        …
-        return StepResponse(observation=obs, reward=reward, done=done, info=info)
-    except ValueError as exc:
-        raise HTTPException(status_code=400, detail=str(exc))
+    result = env.step(action)
+    return result
 
 
-@app.get("/state", response_model=…)
+@app.get("/state", response_model=StateResponse)
 def state():
     """Return the current environment state."""
     return env.state()
 
 
 if __name__ == "__main__":
-    port = int(os.environ.get("PORT", …))
-    enable_web = os.environ.get("ENABLE_WEB_INTERFACE", "false").lower() == "true"
+    port = int(os.environ.get("PORT", 8000))
     uvicorn.run(
         "server.app:app",
         host="0.0.0.0",
```
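Client-side, the `/step` payload shape can be checked before posting. A minimal sketch: the field list comes from the action payloads used in `qa_test.py`, and `validate_action` is an illustrative helper, not part of the repo:

```python
# Fields the /step endpoint's CodeReviewAction payload carries,
# as exercised by the QA suite.
REQUIRED_FIELDS = {
    "bug_identified": bool,
    "bug_location": str,
    "bug_type": str,
    "bug_description": str,
    "severity": str,
    "suggested_fix": str,
}

def validate_action(payload: dict) -> list:
    """Return a list of problems; an empty list means the payload
    matches the action shape the /step endpoint expects."""
    problems = []
    for field, expected_type in REQUIRED_FIELDS.items():
        if field not in payload:
            problems.append(f"missing field: {field}")
        elif not isinstance(payload[field], expected_type):
            problems.append(f"{field} should be {expected_type.__name__}")
    return problems

action = {
    "bug_identified": True,
    "bug_location": "line 2 f-string",
    "bug_type": "security-vulnerability",
    "bug_description": "SQL injection via f-string",
    "severity": "critical",
    "suggested_fix": "use parameterized query",
}
print(validate_action(action))  # []
```

Validating locally gives clearer errors than a 422 response body from FastAPI.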
server/environment.py
CHANGED
|
@@ -1,447 +1,84 @@
|
|
| 1 |
-
|
| 2 |
-
from
|
| 3 |
|
| 4 |
-
|
|
|
|
|
|
|
| 5 |
|
| 6 |
-
|
| 7 |
-
|
| 8 |
-
|
| 9 |
-
TASKS: Dict[str, dict] = {
|
| 10 |
-
|
| 11 |
-
# EASY
|
| 12 |
-
"easy": {
|
| 13 |
-
"id": "task_easy_001",
|
| 14 |
-
"difficulty": "easy",
|
| 15 |
-
"language": "python",
|
| 16 |
-
"description": (
|
| 17 |
-
"This function is supposed to sum all elements in a list. "
|
| 18 |
-
"Find any bugs and suggest a fix."
|
| 19 |
-
),
|
| 20 |
-
"code": (
|
| 21 |
-
"def sum_elements(arr):\n"
|
| 22 |
-
' """Return the sum of all elements."""\n'
|
| 23 |
-
" total = 0\n"
|
| 24 |
-
" for i in range(1, len(arr) + 1): # iterates over indices\n"
|
| 25 |
-
" total += arr[i]\n"
|
| 26 |
-
" return total"
|
| 27 |
-
),
|
| 28 |
-
"ground_truth": {
|
| 29 |
-
"bug_identified": True,
|
| 30 |
-
"bug_type_keywords": [
|
| 31 |
-
"off-by-one", "off by one", "index error", "indexerror",
|
| 32 |
-
"out of bounds", "out of range", "index out",
|
| 33 |
-
],
|
| 34 |
-
"location_keywords": [
|
| 35 |
-
"range(1, len(arr) + 1)", "len(arr) + 1", "len(arr)+1",
|
| 36 |
-
"range", "loop", "index", "arr[i]",
|
| 37 |
-
],
|
| 38 |
-
"description_keywords": [
|
| 39 |
-
"index", "range", "len", "off-by-one", "off by one",
|
| 40 |
-
"IndexError", "out of bounds", "+1", "exceed", "arr[i]",
|
| 41 |
-
"zero", "start",
|
| 42 |
-
],
|
| 43 |
-
"fix_keywords": [
|
| 44 |
-
"range(len(arr))", "range(0, len(arr))",
|
| 45 |
-
"for i in range(len", "for element in arr",
|
| 46 |
-
"arr[i]" , "len(arr))",
|
| 47 |
-
],
|
| 48 |
-
"severity_valid": ["high", "medium"],
|
| 49 |
-
},
|
| 50 |
-
},
|
| 51 |
-
|
| 52 |
-
#MEDIUM
|
| 53 |
-
"medium": {
|
| 54 |
-
"id": "task_medium_001",
|
| 55 |
-
"difficulty": "medium",
|
| 56 |
-
"language": "python",
|
| 57 |
-
"description": (
|
| 58 |
-
"This authentication function controls admin access. "
|
| 59 |
-
"Find the logical security bug."
|
| 60 |
-
),
|
| 61 |
-
"code": (
|
| 62 |
-
"def authenticate_user(username, password, request_admin=False):\n"
|
| 63 |
-
' """Authenticate user and return access level."""\n'
|
| 64 |
-
" user = db.find_user(username)\n"
|
| 65 |
-
" if not user or user.password_hash != hash_password(password):\n"
|
| 66 |
-
' return {"authenticated": False, "level": "none"}\n'
|
| 67 |
-
"\n"
|
| 68 |
-
" # Elevate to admin if caller requests it OR user has admin role\n"
|
| 69 |
-
" if request_admin or user.role == 'admin': # <-- review this\n"
|
| 70 |
-
' return {"authenticated": True, "level": "admin"}\n'
|
| 71 |
-
"\n"
|
| 72 |
-
' return {"authenticated": True, "level": "user"}'
|
| 73 |
-
),
|
| 74 |
-
"ground_truth": {
|
| 75 |
-
"bug_identified": True,
|
| 76 |
-
"bug_type_keywords": [
|
| 77 |
-
"logic", "logic error", "logical", "privilege escalation",
|
| 78 |
-
"authorization", "authentication bypass", "access control",
|
| 79 |
-
],
|
| 80 |
-
"location_keywords": [
|
| 81 |
-
"request_admin or", "or user.role", "or", "condition",
|
| 82 |
-
"if request_admin", "or user.role == 'admin'",
|
| 83 |
-
],
|
| 84 |
-
"description_keywords": [
|
| 85 |
-
"or", "and", "privilege", "escalation", "bypass", "admin",
|
| 86 |
-
"role", "caller", "request_admin", "logic", "elevation",
|
| 87 |
-
"any caller", "arbitrary",
|
| 88 |
-
],
|
| 89 |
-
"fix_keywords": [
|
| 90 |
-
"and", "request_admin and user.role", "and user.role == 'admin'",
|
| 91 |
-
"and user.role", "both",
|
| 92 |
-
],
|
| 93 |
-
"severity_valid": ["critical", "high"],
|
| 94 |
-
},
|
| 95 |
-
},
|
| 96 |
-
|
| 97 |
-
# ββ HARD ββββββββββββββββββββββββββββββββββ
|
| 98 |
-
"hard": {
|
| 99 |
-
"id": "task_hard_001",
|
| 100 |
-
"difficulty": "hard",
|
| 101 |
-
"language": "python",
|
| 102 |
-
"description": (
|
| 103 |
-
"This function fetches records from a database using user-supplied input. "
|
| 104 |
-
"Identify the security vulnerability."
|
| 105 |
-
),
|
| 106 |
-
"code": (
|
| 107 |
-
"def fetch_records(user_id: str, sort_column: str):\n"
|
| 108 |
-
' """Fetch user records sorted by a given column."""\n'
|
| 109 |
-
" conn = get_db_connection()\n"
|
| 110 |
-
" cursor = conn.cursor()\n"
|
| 111 |
-
"\n"
|
| 112 |
-
" query = (\n"
|
| 113 |
-
' f"SELECT id, name, email FROM users "\n'
|
| 114 |
-
' f"WHERE user_id = {user_id} "\n'
|
| 115 |
-
' f"ORDER BY {sort_column}"\n'
|
| 116 |
-
" )\n"
|
| 117 |
-
" cursor.execute(query)\n"
|
| 118 |
-
" rows = cursor.fetchall()\n"
|
| 119 |
-
" conn.close()\n"
|
| 120 |
-
" return rows"
|
| 121 |
-
),
|
| 122 |
-
"ground_truth": {
|
| 123 |
-
"bug_identified": True,
|
| 124 |
-
"bug_type_keywords": [
|
| 125 |
-
"sql injection", "injection", "sqli", "sql",
|
| 126 |
-
"security vulnerability", "security", "second-order",
|
| 127 |
-
],
|
| 128 |
-
"location_keywords": [
|
| 129 |
-
"f\"", "f-string", "format", "user_id", "sort_column",
|
| 130 |
-
"query", "ORDER BY", "WHERE user_id",
|
| 131 |
-
],
|
| 132 |
-
"description_keywords": [
|
| 133 |
-
"sql injection", "injection", "parameterized", "f-string",
|
| 134 |
-
"format string", "user input", "sanitize", "escape",
|
| 135 |
-
"malicious", "attack", "tautology", "union", "drop",
|
| 136 |
-
"ORDER BY", "sort_column", "arbitrary",
|
| 137 |
-
],
|
| 138 |
-
"fix_keywords": [
|
| 139 |
-
"parameterized", "?", "%s", "cursor.execute(query, (",
|
| 140 |
-
"cursor.execute(query, [", "prepared statement",
|
| 141 |
-
"whitelist", "allowlist", "ALLOWED_COLUMNS",
|
| 142 |
-
"sanitize", "if sort_column not in",
|
| 143 |
-
],
|
| 144 |
-
"severity_valid": ["critical"],
|
| 145 |
-
},
|
| 146 |
-
},
|
| 147 |
-
|
| 148 |
-
# ββ EXPERT ββββββββββββββββββββββββββββββββ
|
| 149 |
-
"expert": {
|
| 150 |
-
"id": "task_expert_001",
|
| 151 |
-
"difficulty": "expert",
|
| 152 |
-
"language": "java",
|
| 153 |
-
"description": (
|
| 154 |
-
"This Java class implements a token bucket rate limiter. "
|
| 155 |
-
"Identify the logic bug that could allow users to bypass the rate limit."
|
| 156 |
-
),
|
| 157 |
-
"code": (
|
| 158 |
-
"import java.util.concurrent.atomic.AtomicLong;\n\n"
|
| 159 |
-
"public class TokenBucketRateLimiter {\n"
|
| 160 |
-
" private final long maxTokens;\n"
|
| 161 |
-
" private final long refillRatePerSecond;\n"
|
| 162 |
-
" private AtomicLong currentTokens;\n"
|
| 163 |
-
" private AtomicLong lastRefillTimestamp;\n\n"
|
| 164 |
-
" public TokenBucketRateLimiter(long maxTokens, long refillRatePerSecond) {\n"
|
| 165 |
-
" this.maxTokens = maxTokens;\n"
|
| 166 |
-
" this.refillRatePerSecond = refillRatePerSecond;\n"
|
| 167 |
-
" this.currentTokens = new AtomicLong(maxTokens);\n"
|
| 168 |
-
" this.lastRefillTimestamp = new AtomicLong(System.currentTimeMillis());\n"
|
| 169 |
-
" }\n\n"
|
| 170 |
-
" /**\n"
|
| 171 |
-
" * Checks if the requested number of tokens is available.\n"
|
| 172 |
-
" * Decrements the bucket if allowed.\n"
|
| 173 |
-
" */\n"
|
| 174 |
-
" public synchronized boolean allowRequest(int tokensNeeded) {\n"
|
| 175 |
-
" refill();\n"
|
| 176 |
-
" if (currentTokens.get() >= tokensNeeded) {\n"
|
| 177 |
-
" currentTokens.addAndGet(-tokensNeeded);\n"
|
| 178 |
-
" return true;\n"
|
| 179 |
-
" }\n"
|
| 180 |
-
" return false;\n"
|
| 181 |
-
" }\n\n"
|
| 182 |
-
" private void refill() {\n"
|
| 183 |
-
" long now = System.currentTimeMillis();\n"
|
| 184 |
-
" long timeElapsedMs = now - lastRefillTimestamp.get();\n"
|
| 185 |
-
" \n"
|
| 186 |
-
" // Calculate how many tokens to add based on time elapsed\n"
|
| 187 |
-
" long tokensToAdd = (timeElapsedMs / 1000) * refillRatePerSecond;\n\n"
|
| 188 |
-
" if (tokensToAdd > 0) {\n"
|
| 189 |
-
" // Hint: Look closely at how the tokens are updated here.\n"
|
| 190 |
-
" // Consider what happens if a user stops making requests for a long time.\n"
|
| 191 |
-
" currentTokens.addAndGet(tokensToAdd);\n"
|
| 192 |
-
" lastRefillTimestamp.set(now);\n"
|
| 193 |
-
" }\n"
|
| 194 |
-
" }\n"
|
| 195 |
-
"}"
|
| 196 |
-
),
|
| 197 |
-
"ground_truth": {
|
| 198 |
-
"bug_identified": True,
|
| 199 |
-
"bug_type_keywords": [
|
| 200 |
-
"logic", "limit", "overflow", "cap", "bound", "maximum", "exceed",
|
| 201 |
-
"logic error", "capacity",
|
| 202 |
-
],
|
| 203 |
-
"location_keywords": [
|
| 204 |
-
"currentTokens.addAndGet", "refill()", "tokensToAdd",
|
| 205 |
-
"currentTokens.get()", "addAndGet(tokensToAdd)",
|
| 206 |
-
],
|
| 207 |
-
"description_keywords": [
|
| 208 |
-
"exceed", "maxTokens", "cap", "limit", "bound",
|
| 209 |
-
"overflow", "infinite", "burst", "accumulate",
|
| 210 |
-
],
|
| 211 |
-
"fix_keywords": [
|
| 212 |
-
"Math.min", "min(", "set(", "if (currentTokens.get() > maxTokens)",
|
| 213 |
-
"compareAndSet", "cap",
|
| 214 |
-
],
|
| 215 |
-
"severity_valid": ["high", "medium"],
|
| 216 |
-
},
|
| 217 |
-
},
|
| 218 |
-
|
| 219 |
-
# ββ EXPERT 2 (C++) ββββββββββββββββββββββββ
|
| 220 |
-
"expert2": {
|
| 221 |
-
"id": "task_expert_002",
|
| 222 |
-
"difficulty": "expert2",
|
| 223 |
-
"language": "cpp",
|
| 224 |
-
"description": (
|
| 225 |
-
"This C++ class implements an event dispatcher. "
|
| 226 |
-
"Identify the concurrency bug that can occur when an event is dispatched."
|
| 227 |
-
),
|
| 228 |
-
"code": (
|
| 229 |
-
"#include <iostream>\n"
|
| 230 |
-
"#include <vector>\n"
|
| 231 |
-
"#include <functional>\n"
|
| 232 |
-
"#include <mutex>\n"
|
| 233 |
-
"#include <algorithm>\n"
|
| 234 |
-
"#include <string>\n\n"
|
| 235 |
-
"class EventDispatcher {\n"
|
| 236 |
-
"public:\n"
|
| 237 |
-
" using Callback = std::function<void(const std::string&)>;\n\n"
|
| 238 |
-
" void subscribe(int listener_id, Callback cb) {\n"
|
| 239 |
-
" std::lock_guard<std::mutex> lock(mut_);\n"
|
| 240 |
-
" listeners_.push_back({listener_id, cb});\n"
|
| 241 |
-
" }\n\n"
|
| 242 |
-
" void unsubscribe(int listener_id) {\n"
|
| 243 |
-
" std::lock_guard<std::mutex> lock(mut_);\n"
|
| 244 |
-
" listeners_.erase(\n"
|
| 245 |
-
" std::remove_if(listeners_.begin(), listeners_.end(),\n"
|
| 246 |
-
" [listener_id](const Listener& l) { return l.id == listener_id; }),\n"
|
| 247 |
-
" listeners_.end()\n"
|
| 248 |
-
" );\n"
|
| 249 |
-
" }\n\n"
|
| 250 |
-
" void dispatch(const std::string& event_data) {\n"
|
| 251 |
-
" std::lock_guard<std::mutex> lock(mut_);\n"
|
| 252 |
-
" for (const auto& listener : listeners_) {\n"
|
| 253 |
-
" // Hint: What happens if a listener decides to call unsubscribe() \n"
|
| 254 |
-
" // from inside their own callback function when an event fires?\n"
|
| 255 |
-
" listener.cb(event_data);\n"
|
| 256 |
-
" }\n"
|
| 257 |
-
" }\n\n"
|
| 258 |
-
"private:\n"
|
| 259 |
-
" struct Listener {\n"
|
| 260 |
-
" int id;\n"
|
| 261 |
-
" Callback cb;\n"
|
| 262 |
-
" };\n \n"
|
| 263 |
-
" std::vector<Listener> listeners_;\n"
|
| 264 |
-
" std::mutex mut_;\n"
|
| 265 |
-
"};"
|
| 266 |
-
),
|
| 267 |
-
"ground_truth": {
|
| 268 |
-
"bug_identified": True,
|
| 269 |
-
"bug_type_keywords": [
|
| 270 |
-
"deadlock", "concurrency", "lock", "recursive", "reentrant", "hang",
|
| 271 |
-
"iterator validation", "undefined behavior"
|
| 272 |
-
],
|
| 273 |
-
"location_keywords": [
|
| 274 |
-
"listener.cb", "unsubscribe", "dispatch", "mut_", "std::lock_guard",
|
| 275 |
-
"lock(mut_)"
|
| 276 |
-
],
|
| 277 |
-
"description_keywords": [
|
| 278 |
-
"deadlock", "already locked", "same thread", "recursive_mutex",
|
| 279 |
-
"reentrant", "hangs", "blocks", "invalidate", "iterator"
|
| 280 |
-
],
|
| 281 |
-
"fix_keywords": [
|
| 282 |
-
"std::recursive_mutex", "copy", "local copy", "copy the vector",
|
| 283 |
-
"unlock before", "queue", "deferred"
|
| 284 |
-
],
|
| 285 |
-
"severity_valid": ["high", "critical"],
|
| 286 |
-
},
|
| 287 |
-
},
|
| 288 |
-
}
|
| 289 |
-
|
| 290 |
-
|
| 291 |
-
|
| 292 |
-
# GRADER
|
| 293 |
-
|
| 294 |
-
|
| 295 |
-
def grade_action(action: CodeReviewAction, task: dict) -> Tuple[float, Dict]:
|
| 296 |
-
"""
|
| 297 |
-
Score the agent's review on a 0.0β1.0 scale.
|
| 298 |
-
|
| 299 |
-
Breakdown:
|
| 300 |
-
bug_identified 0.20
|
| 301 |
-
bug_type 0.20
|
| 302 |
-
bug_location 0.10
|
| 303 |
-
bug_description 0.25 (keyword density, capped)
|
| 304 |
-
suggested_fix 0.15 (keyword density, capped)
|
| 305 |
-
severity 0.10
|
| 306 |
-
βββββββββββββββββββββ
|
| 307 |
-
Total 1.00
|
| 308 |
-
"""
|
| 309 |
-
gt = task["ground_truth"]
|
| 310 |
-
score = 0.0
|
| 311 |
-
breakdown: Dict[str, float] = {}
|
| 312 |
-
|
| 313 |
-
# 1. Bug identification
|
| 314 |
-
if action.bug_identified == gt["bug_identified"]:
|
| 315 |
-
score += 0.20
|
| 316 |
-
breakdown["bug_identified"] = 0.20
|
| 317 |
-
else:
|
| 318 |
-
breakdown["bug_identified"] = 0.00
|
| 319 |
-
if not action.bug_identified:
|
| 320 |
-
return 0.0, {
|
| 321 |
-
"breakdown": breakdown,
|
| 322 |
-
"total_score": 0.0,
|
| 323 |
-
"feedback": "No bug identified β one definitely exists. Look more carefully.",
|
| 324 |
-
}
|
| 325 |
-
|
| 326 |
-
# 2. Bug type
|
| 327 |
-
bug_type_lower = action.bug_type.lower()
|
| 328 |
-
type_match = any(kw in bug_type_lower for kw in gt["bug_type_keywords"])
|
| 329 |
-
if type_match:
|
| 330 |
-
score += 0.20
|
| 331 |
-
breakdown["bug_type"] = 0.20
|
| 332 |
-
else:
|
| 333 |
-
breakdown["bug_type"] = 0.00
|
| 334 |
-
|
| 335 |
-
# 3. Bug location
|
| 336 |
-
loc_lower = action.bug_location.lower()
|
| 337 |
-
loc_match = any(kw.lower() in loc_lower for kw in gt["location_keywords"])
|
| 338 |
-
if loc_match:
|
| 339 |
-
score += 0.10
|
| 340 |
-
breakdown["bug_location"] = 0.10
|
| 341 |
-
else:
|
| 342 |
-
breakdown["bug_location"] = 0.00
|
| 343 |
-
|
| 344 |
-
# 4. Description quality (keyword density, capped at 0.25)
|
| 345 |
-
desc_lower = action.bug_description.lower()
|
| 346 |
-
desc_hits = sum(1 for kw in gt["description_keywords"] if kw.lower() in desc_lower)
|
| 347 |
-
desc_score = round(min(0.25, desc_hits * 0.07), 3)
|
| 348 |
-
score += desc_score
|
| 349 |
-
breakdown["bug_description"] = desc_score
|
| 350 |
-
|
| 351 |
-
# 5. Fix quality (keyword density, capped at 0.15)
|
| 352 |
-
fix_lower = action.suggested_fix.lower()
|
| 353 |
-
fix_hits = sum(1 for kw in gt["fix_keywords"] if kw.lower() in fix_lower)
|
| 354 |
-
fix_score = round(min(0.15, fix_hits * 0.08), 3)
|
| 355 |
-
score += fix_score
|
| 356 |
-
breakdown["suggested_fix"] = fix_score
|
| 357 |
-
|
| 358 |
-
# 6. Severity
|
| 359 |
-
if action.severity.lower() in gt["severity_valid"]:
|
| 360 |
-
score += 0.10
|
| 361 |
-
breakdown["severity"] = 0.10
|
| 362 |
-
else:
|
| 363 |
-
breakdown["severity"] = 0.00
|
| 364 |
-
|
| 365 |
-
total = round(min(1.0, score), 3)
|
| 366 |
-
|
| 367 |
-
# Build human-readable feedback
|
| 368 |
-
hints = []
|
| 369 |
-
if breakdown["bug_type"] == 0:
|
| 370 |
-
hints.append("Reconsider the bug category β be more specific.")
|
| 371 |
-
if breakdown["bug_location"] == 0:
|
| 372 |
-
-        hints.append("Pinpoint the exact line or expression that contains the bug.")
-    if breakdown["suggested_fix"] < 0.08:
-        hints.append("Your fix does not address the root cause - revise it.")
-    if breakdown["severity"] == 0:
-        hints.append("Re-evaluate the severity level.")
-
-    feedback = " ".join(hints) if hints else "Strong analysis - refine the fix if needed."
-
-    return total, {"breakdown": breakdown, "total_score": total, "feedback": feedback}
-
-# ENVIRONMENT
-
-
-class CodeReviewEnvironment:
-    def __init__(self):
-        self.
-        self.
-
-        self.
-        self.
-
-        )
-        return self._build_obs(step_number=0, previous_feedback=None)
-
-    def
-        # Done if agent nailed it or max steps reached
-        done = reward >= 0.80 or self._state.step_count >= MAX_STEPS
-        self._state.done = done
-        self._state.task_complete = reward >= 0.80
-
-        feedback = info.get("feedback") if not done else None
-        obs = self._build_obs(
-            step_number=self._state.step_count,
-            previous_feedback=feedback,
-        )
-        return obs, reward, done, info
-
-    def state(self) -> CodeReviewState:
-        if self._state is None:
-            return CodeReviewState(
-                task_id="", difficulty="easy",
-                step_count=0, done=False,
-                total_reward=0.0, task_complete=False,
-            )
-        return self._state
-
-    # helpers
-
-    def
-        t = self.
-        return
-            code_snippet=t["code"],
-            language=t["language"],
-            task_description=t["description"],
-            task_id=t["id"],
-            difficulty=t["difficulty"],
-        )
+import random
+from typing import Optional, Dict, Tuple
+
+from server.tasks import TASKS
+from server.grader import grade_action
+from server.models import CodeObservation, StepResult, StateResponse, Action, Observation
+
+
+class CodeSecurityEnv:
+    def __init__(self):
+        self.current_task: Optional[dict] = None
+        self.step_count: int = 0
+        self.done: bool = False
+        self.total_reward: float = 0.0
+        self._task_ids = list(TASKS.keys())
+
+    def reset(self, task_id: Optional[str] = None, seed: Optional[int] = None) -> Observation:
+        if seed is not None:
+            random.seed(seed)
+
+        if task_id and task_id in TASKS:
+            self.current_task = TASKS[task_id]
+        else:
+            # Pick a random task by its ID
+            chosen_id = random.choice(self._task_ids)
+            self.current_task = TASKS[chosen_id]
+
+        self.step_count = 0
+        self.done = False
+        self.total_reward = 0.0
+
+        return self._make_observation()
+
+    def step(self, action: Action) -> StepResult:
+        if self.current_task is None:
+            # Auto-reset if called before reset()
+            self.reset()
+
+        if self.done:
+            return StepResult(
+                observation=self._make_observation(),
+                reward=0.0,
+                done=True,
+                info={"error": "Episode already completed. Call /reset to start a new episode."},
+            )
+
+        # The action arrives from the API as a Pydantic model (Action); the
+        # grader reads a plain dict, so convert before grading.
+        payload = action.model_dump() if hasattr(action, "model_dump") else dict(action)
+        reward, breakdown = grade_action(payload, self.current_task)
+
+        self.step_count += 1
+        self.total_reward += reward
+        self.done = True  # single-step environment: one action per episode
+
+        return StepResult(
+            observation=self._make_observation(),
+            reward=reward,
+            done=self.done,
+            info={
+                "reward_breakdown": breakdown,
+                "task_name": self.current_task.get("name", "Unknown Task"),
+                "step_count": self.step_count,
+            },
+        )
+
+    def state(self) -> StateResponse:
+        current_id = self.current_task["id"] if self.current_task else ""
+        return StateResponse(
+            task_id=current_id,
+            step=self.step_count,
+            done=self.done,
+            total_reward=self.total_reward,
+        )
+
+    def _make_observation(self) -> Observation:
+        t = self.current_task
+        return Observation(
+            task_id=t["id"],
+            language=t["language"],
+            difficulty=t["difficulty"],
+            code_snippet=t["code_snippet"],
+            context=t["context"],
+            pr_title=t["pr_title"],
+            file_path=t["file_path"],
+        )
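The episode contract above (grade exactly one action, mark the episode done, and return zero reward on any further `step()` call) can be sketched standalone. `SingleStepEnv`, its toy `TASKS` table, and the string-match reward below are illustrative stand-ins, not the real `CodeSecurityEnv` or grader:

```python
import random
from typing import Optional, Tuple

# Toy task table standing in for server.tasks.TASKS (illustrative only).
TASKS = {"demo": {"id": "demo", "answer": "off-by-one"}}


class SingleStepEnv:
    """Minimal sketch of the one-graded-action-per-episode pattern."""

    def __init__(self) -> None:
        self.current_task: Optional[dict] = None
        self.done = False

    def reset(self, task_id: Optional[str] = None, seed: Optional[int] = None) -> str:
        if seed is not None:
            random.seed(seed)
        key = task_id if task_id in TASKS else random.choice(list(TASKS))
        self.current_task = TASKS[key]
        self.done = False
        return self.current_task["id"]

    def step(self, action: str) -> Tuple[float, bool]:
        if self.done:
            return 0.0, True  # episode over: extra calls earn nothing
        self.done = True      # grade exactly one action, then finish
        reward = 1.0 if action == self.current_task["answer"] else 0.0
        return reward, True
```

Calling `step()` twice demonstrates the already-done guard: the second call returns `0.0` regardless of the action.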
server/grader.py
ADDED
@@ -0,0 +1,80 @@
+from typing import Tuple, Dict
+
+
+def grade_action(action: dict, task: dict) -> Tuple[float, Dict[str, float]]:
+    reward = 0.0
+    breakdown: Dict[str, float] = {}
+
+    # ── Component 1: Bug identified (0.20) ──────────────────────────────────
+    if action.get("bug_identified"):
+        reward += 0.20
+        breakdown["bug_identified"] = 0.20
+    else:
+        breakdown["bug_identified"] = 0.00
+        # No bug found -> no partial credit for anything else
+        return max(0.0, min(1.0, reward)), breakdown
+
+    # ── Component 2: Bug type match (0.20) ──────────────────────────────────
+    action_type = action.get("bug_type", "").lower().replace("-", " ").replace("_", " ")
+    task_type = task["bug_type"].lower().replace("-", " ").replace("_", " ")
+    if task_type in action_type or action_type in task_type:
+        reward += 0.20
+        breakdown["bug_type"] = 0.20
+    else:
+        breakdown["bug_type"] = 0.00
+
+    # ── Component 3: Bug location (0.10) ────────────────────────────────────
+    action_location = action.get("bug_location", "").lower()
+    location_keywords = [w for w in task["bug_location"].lower().split() if len(w) > 3]
+    if location_keywords:
+        matched = sum(1 for kw in location_keywords if kw in action_location)
+        loc_score = round(0.10 * (matched / len(location_keywords)), 4)
+    else:
+        loc_score = 0.0
+    reward += loc_score
+    breakdown["bug_location"] = loc_score
+
+    # ── Component 4: Description quality (0.25) ─────────────────────────────
+    description = action.get("bug_description", "").lower()
+    desc_score = 0.0
+    if len(description) >= 20:
+        task_keywords = task["keywords"]
+        matched_kw = [kw for kw in task_keywords if kw in description]
+        desc_score = round(min(0.25, 0.25 * (len(matched_kw) / max(len(task_keywords), 1))), 4)
+    breakdown["description_quality"] = desc_score
+    reward += desc_score
+
+    # ── Component 5: Fix quality (0.15) ─────────────────────────────────────
+    fix = action.get("suggested_fix", "").lower()
+    fix_score = 0.0
+    if len(fix) >= 10:
+        fix_patterns = task["fix_patterns"]
+        matched_fix = [p for p in fix_patterns if p.lower() in fix]
+        fix_score = round(min(0.15, 0.15 * (len(matched_fix) / max(len(fix_patterns), 1)) * 2), 4)
+    breakdown["fix_quality"] = fix_score
+    reward += fix_score
+
+    # ── Component 6: Severity (0.10) ────────────────────────────────────────
+    action_sev = action.get("severity", "").lower()
+    task_sev = task["severity"].lower()
+    if action_sev == task_sev:
+        sev_score = 0.10
+    elif action_sev in ("high", "critical") and task_sev in ("high", "critical"):
+        sev_score = 0.05
+    else:
+        sev_score = 0.00
+    breakdown["severity"] = sev_score
+    reward += sev_score
+
+    # ── Global Penalty: Keyword Stuffing ────────────────────────────────────
+    description = action.get("bug_description", "").lower()
+    words = description.split()
+    unique_ratio = len(set(words)) / len(words) if words else 1.0
+    if unique_ratio < 0.7:
+        reward *= 0.2  # Heavy global penalty
+        breakdown["stuffing_penalty_multiplier"] = 0.2
+        for k in list(breakdown.keys()):
+            if k != "stuffing_penalty_multiplier":
+                breakdown[k] = round(breakdown[k] * 0.2, 4)
+
+    return max(0.0, min(1.0, round(reward, 4))), breakdown
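The location component's partial-credit rule is easy to exercise in isolation. `location_score` below re-implements just that fragment (words longer than 3 characters from the reference location, each worth an equal share of 0.10) for illustration; it is not imported from the grader:

```python
def location_score(reference: str, answer: str) -> float:
    """Partial credit (max 0.10) for matching words from the reference location.

    Mirrors the bug-location component above: only words longer than
    3 characters in the reference count as keywords.
    """
    keywords = [w for w in reference.lower().split() if len(w) > 3]
    if not keywords:
        return 0.0
    matched = sum(1 for kw in keywords if kw in answer.lower())
    return round(0.10 * (matched / len(keywords)), 4)
```

An answer quoting the offending expression verbatim collects the full 0.10, while an unrelated answer collects nothing.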
server/models.py
CHANGED
@@ -2,44 +2,63 @@ from pydantic import BaseModel, Field
 from typing import Optional, Any, Dict
 
 
+# ── Agent Action ──────────────────────────────────────────────────────────────
+
 class CodeReviewAction(BaseModel):
     """Action taken by the agent: a structured code review."""
     bug_identified: bool = Field(..., description="Whether a bug was found")
     bug_location: str = Field(..., description="Location of the bug (function, line, variable)")
-    bug_type: str = Field(..., description="Type: off-by-one | logic-error | security-vulnerability |
+    bug_type: str = Field(..., description="Type: off-by-one | logic-error | security-vulnerability | none")
     bug_description: str = Field(..., description="Detailed explanation of why this is a bug")
     severity: str = Field(..., description="Severity: none | low | medium | high | critical")
     suggested_fix: str = Field(..., description="The corrected code or a description of how to fix it")
 
 
+# ── Observation ───────────────────────────────────────────────────────────────
+
-class
+class CodeObservation(BaseModel):
     """What the agent sees at each step."""
-    code_snippet: str = Field(..., description="The code to review")
-    language: str = Field(..., description="Programming language")
-    task_description: str = Field(..., description="What the code is supposed to do")
     task_id: str = Field(..., description="Unique task identifier")
+    language: str = Field(..., description="Programming language")
     difficulty: str = Field(..., description="Level: easy | medium | hard")
+    code_snippet: str = Field(..., description="The code to review")
+    context: str = Field(..., description="Production context describing what the code does")
+    pr_title: str = Field(..., description="Pull request title submitted by developer")
+    file_path: str = Field(..., description="File path of the code in the repository")
 
-class CodeReviewState(BaseModel):
-    """Internal environment state."""
-    task_id: str
-    difficulty: str
-    step_count: int
-    done: bool
-    total_reward: float
-    task_complete: bool
 
+# ── Step Result ───────────────────────────────────────────────────────────────
 
-class
-
+class StepResult(BaseModel):
+    """Result returned from env.step()."""
+    observation: Optional[CodeObservation] = None
     reward: float
     done: bool
     info: Dict[str, Any]
 
 
+# ── State ─────────────────────────────────────────────────────────────────────
+
+class StateResponse(BaseModel):
+    """Internal environment state exposed via /state."""
+    task_id: str
+    step: int
+    done: bool
+    total_reward: float
+
+
+# ── API Helpers ───────────────────────────────────────────────────────────────
+
 class ResetResponse(BaseModel):
-    observation:
+    observation: CodeObservation
+
+
+class TaskInfo(BaseModel):
+    id: str
+    language: str
+    bug_class: str
+    difficulty: str
+
+
+Action = CodeReviewAction
+Observation = CodeObservation
+Reward = float
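A request body matching the `CodeReviewAction` schema above can be sketched as plain JSON. The field names come from the model; the values here are purely illustrative:

```python
import json

# Illustrative review payload; field names follow CodeReviewAction above.
payload = {
    "bug_identified": True,
    "bug_location": "line 3 - range(len(transactions) + 1)",
    "bug_type": "off-by-one",
    "bug_description": "The loop iterates one index past the end of the list, raising IndexError.",
    "severity": "critical",
    "suggested_fix": "Iterate with range(len(transactions)) or directly over the list.",
}
body = json.dumps(payload)  # what a client would POST to /step
```

Since every field is required (`Field(...)`), omitting any key would be rejected by validation on the server side.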
server/tasks.py
ADDED
@@ -0,0 +1,110 @@
+TASKS = {
+    "python-off-by-one": {
+        "id": "python-off-by-one",
+        "name": "Python Off-by-One Error",
+        "language": "Python",
+        "difficulty": "easy",
+        "bug_class": "Off-by-one index error",
+        "pr_title": "Add batch processor for financial transactions",
+        "file_path": "finance/batch_processor.py",
+        "context": "Finance batch processor that sums transaction amounts for end-of-day reconciliation",
+        "code_snippet": (
+            "def process_transactions(transactions):\n"
+            "    total = 0\n"
+            "    for i in range(len(transactions) + 1):  # iterates one past end\n"
+            "        total += transactions[i][\"amount\"]\n"
+            "    return total"
+        ),
+        "bug_type": "off-by-one",
+        "bug_location": "line 3 - range(len(transactions) + 1)",
+        "severity": "critical",
+        "keywords": [
+            "off-by-one", "index", "range", "indexerror", "out of bounds",
+            "boundary", "overflow", "iteration", "list length", "plus one",
+            "extra step", "fencepost error", "array access", "iterator",
+            "fix", "bug", "identify", "code", "crash", "out-of-range",
+            "python", "finance", "batch", "amount", "total", "transactions",
+            "iterate", "sum", "loop", "account", "process"
+        ],
+        "fix_patterns": [
+            "range(len(transactions))",
+            "len(transactions))",
+            "for transaction in transactions",
+            "in transactions:",
+            "pop()",
+            "enumerate(transactions)",
+            "transactions[:len(transactions)]",
+            "total += transactions[i]"
+        ],
+    },
+
+    "js-auth-privilege": {
+        "id": "js-auth-privilege",
+        "name": "JavaScript Auth Logic Flaw",
+        "language": "JavaScript",
+        "difficulty": "medium",
+        "bug_class": "Logic flaw - privilege escalation",
+        "pr_title": "Refactor auth middleware for API routes",
+        "file_path": "middleware/auth.js",
+        "context": "Node.js authentication middleware that restricts admin-only API routes",
+        "code_snippet": (
+            "function checkAdmin(req, res, next) {\n"
+            "  const user = req.user;\n"
+            "  if (user.role !== \"admin\" || user.isActive) {\n"
+            "    return next();\n"
+            "  }\n"
+            "  return res.status(403).json({ error: \"Forbidden\" });\n"
+            "}"
+        ),
+        "bug_type": "logic-error",
+        "bug_location": "line 3 - incorrect boolean operator || instead of &&",
+        "severity": "critical",
+        "keywords": [
+            "short-circuit disjunction hazard", "logical disjunction vulnerability",
+            "excessive authorization scope", "privilege escalation vector",
+            "boolean logic flaw pattern", "operator precedence violation",
+            "authorization bypass disjunction logic", "improper validation layer check",
+            "role check disjunction pattern match", "permission leak evaluation flow",
+            "evaluation shortcut logic flaw", "middleware logic hazard state",
+            "security constraint bypass", "access control logic inversion"
+        ],
+        "fix_patterns": [
+            "user.role === \"admin\" && user.isActive",
+            "&& user.isActive",
+            "throw new Error(\"Unauthorized\")",
+            "user.role === 'admin' && user.isActive",
+            "middleware logic fix"
+        ],
+    },
+
+    "python-sql-injection": {
+        "id": "python-sql-injection",
+        "name": "Python SQL Injection",
+        "language": "Python",
+        "difficulty": "hard",
+        "bug_class": "SQL injection via f-string",
+        "pr_title": "Add user search endpoint to REST API",
+        "file_path": "api/users.py",
+        "context": "REST API endpoint that searches users by name in a PostgreSQL database",
+        "code_snippet": (
+            "def search_users(db, search_term):\n"
+            "    query = f\"SELECT * FROM users WHERE name LIKE '%{search_term}%'\"\n"
+            "    results = db.execute(query)\n"
+            "    return results.fetchall()"
+        ),
+        "bug_type": "security-vulnerability",
+        "bug_location": "line 2 - f-string interpolation directly in SQL query",
+        "severity": "critical",
+        "keywords": [
+            "sql injection", "user-supplied", "search_term", "interpolated", "f-string",
+            "attacker", "bypass", "authentication", "exfiltrate", "user data",
+            "drop tables", "parameterized", "queries", "sanitize", "input", "automatically"
+        ],
+        "fix_patterns": [
+            "db.execute('SELECT * FROM users WHERE name LIKE %s', ('%'+search_term+'%',))",
+            "%s",
+            "parameterized",
+            "prepared statement"
+        ],
+    },
+}
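Because the environment and grader index these task dicts directly (`t["code_snippet"]`, `task["keywords"]`, and so on), every entry must carry the same fields. A minimal consistency check might look like the sketch below; `REQUIRED_FIELDS` is inferred from the keys the code above reads, so adjust it if the schema grows:

```python
# Fields the environment and grader read from each task entry (inferred from
# the task dicts above; illustrative, not part of the repository).
REQUIRED_FIELDS = {
    "id", "name", "language", "difficulty", "bug_class", "pr_title",
    "file_path", "context", "code_snippet", "bug_type", "bug_location",
    "severity", "keywords", "fix_patterns",
}


def missing_fields(task: dict) -> set:
    """Return the required fields absent from a task definition."""
    return REQUIRED_FIELDS - task.keys()
```

Running `missing_fields` over every entry in `TASKS` at import time would turn a silently broken task into an immediate failure.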
validate.sh
ADDED
@@ -0,0 +1,103 @@
+#!/bin/bash
+
+# OpenEnv Submission Validation Script
+
+set -e
+echo "───────────────────────────────────────"
+echo "  OpenEnv Pre-Submission Validation"
+echo "───────────────────────────────────────"
+echo ""
+
+# 1. Check for required root files
+echo "── 1. Required Files ──"
+FILES=("openenv.yaml" "inference.py" "README.md" "Dockerfile" "requirements.txt")
+for file in "${FILES[@]}"; do
+    if [ -f "$file" ]; then
+        echo "  ✅ $file"
+    else
+        echo "  ❌ Missing $file"
+        exit 1
+    fi
+done
+echo ""
+
+# 2. Check server/ module structure
+echo "── 2. Server Module Structure ──"
+SERVER_FILES=("server/__init__.py" "server/app.py" "server/models.py" "server/environment.py" "server/tasks.py" "server/grader.py")
+for file in "${SERVER_FILES[@]}"; do
+    if [ -f "$file" ]; then
+        echo "  ✅ $file"
+    else
+        echo "  ❌ Missing $file"
+        exit 1
+    fi
+done
+echo ""
+
+# 3. Activate venv & validate Python imports
+echo "── 3. Python Import Validation ──"
+source venv/bin/activate
+python3 -c "
+from server.tasks import TASKS
+from server.grader import grade_action
+from server.environment import CodeSecurityEnv
+from server.models import CodeReviewAction, CodeObservation, StepResult, StateResponse, ResetResponse, TaskInfo
+
+assert len(TASKS) >= 3, f'Expected 3+ tasks, got {len(TASKS)}'
+print('  ✅ All imports resolve correctly')
+print(f'  Tasks: {list(TASKS.keys())}')
+" || { echo "  ❌ Python import validation failed"; exit 1; }
+echo ""
+
+# 4. Quick grader smoke test
+echo "── 4. Grader Smoke Test ──"
+python3 -c "
+from server.environment import CodeSecurityEnv
+from server.models import Action
+
+env = CodeSecurityEnv()
+obs = env.reset('python-off-by-one')
+result = env.step(Action(**{
+    'bug_identified': True,
+    'bug_location': 'range(len(transactions) + 1)',
+    'bug_type': 'logic-error',
+    'bug_description': 'Off-by-one index error: the range goes one past the end causing an out of bounds IndexError',
+    'severity': 'medium',
+    'suggested_fix': 'Use range(len(transactions)) to fix the boundary',
+}))
+assert 0.0 <= result.reward <= 1.0, f'Reward out of range: {result.reward}'
+assert result.done is True
+print(f'  ✅ Grader returned reward={result.reward:.4f}, done={result.done}')
+
+# Verify zero-reward path
+env2 = CodeSecurityEnv()
+env2.reset('python-off-by-one')
+r2 = env2.step(Action(**{
+    'bug_identified': False,
+    'bug_location': '',
+    'bug_type': 'none',
+    'bug_description': 'No bug found',
+    'severity': 'none',
+    'suggested_fix': '',
+}))
+assert r2.reward == 0.0, f'Expected 0.0 for no-bug, got {r2.reward}'
+print('  ✅ No-bug path returns reward=0.0')
+" || { echo "  ❌ Grader smoke test failed"; exit 1; }
+echo ""
+
+# 5. Validate openenv.yaml
+echo "── 5. openenv.yaml Validation ──"
+python3 -c "
+import yaml
+with open('openenv.yaml', 'r') as f:
+    data = yaml.safe_load(f)
+assert 'name' in data, 'Missing name field'
+assert 'tasks' in data, 'Missing tasks field'
+assert len(data['tasks']) >= 3, f'Need 3+ tasks, got {len(data[\"tasks\"])}'
+print(f'  ✅ Valid YAML with {len(data[\"tasks\"])} tasks')
+" || { echo "  ❌ openenv.yaml validation failed"; exit 1; }
+echo ""
+
+echo "───────────────────────────────────────"
+echo "  ✅ All checks passed!"
+echo "───────────────────────────────────────"