Nitish committed on
Commit
474eafa
·
1 Parent(s): 742e175

feat: finalize OpenEnv alignment and calibrate rewards for QA

Files changed (11)
  1. Dockerfile +3 -4
  2. README.md +87 -113
  3. inference.py +75 -52
  4. openenv.yaml +50 -43
  5. qa_test.py +237 -0
  6. server/app.py +39 -27
  7. server/environment.py +73 -436
  8. server/grader.py +80 -0
  9. server/models.py +39 -20
  10. server/tasks.py +110 -0
  11. validate.sh +103 -0
Dockerfile CHANGED
@@ -7,11 +7,10 @@ COPY requirements.txt .
7
  RUN pip install --no-cache-dir --upgrade pip && \
8
  pip install --no-cache-dir -r requirements.txt
9
 
10
- # Copy application code
11
- COPY server/ ./server/
12
- COPY static/ ./static/
13
 
14
- # Environment defaults
15
  ENV PORT=7860
16
  ENV PYTHONPATH=/app
17
  ENV ENABLE_WEB_INTERFACE=false
 
7
  RUN pip install --no-cache-dir --upgrade pip && \
8
  pip install --no-cache-dir -r requirements.txt
9
 
10
+ # Copy all project files (needed for `openenv validate` to run inside the container)
11
+ COPY . .
 
12
 
13
+ # Environment defaults (Hugging Face Spaces use 7860)
14
  ENV PORT=7860
15
  ENV PYTHONPATH=/app
16
  ENV ENABLE_WEB_INTERFACE=false
README.md CHANGED
@@ -1,156 +1,130 @@
1
  ---
2
- title: Code Review Env
3
- emoji: 🏃
4
- colorFrom: red
5
- colorTo: purple
6
- sdk: docker
7
- pinned: false
8
- ---
9
- # Code Security Review — OpenEnv
10
 
11
- > An RL environment for training AI agents to detect bugs and security
12
- > vulnerabilities in Python code.
13
 
14
- ## Motivation
 
 
 
 
 
 
15
 
16
- Code review is one of the highest-leverage tasks in software engineering, yet it
17
- remains bottlenecked on human attention. This environment trains agents to catch
18
- real bug categories — from simple off-by-one errors to critical SQL injection
19
- vulnerabilities — using structured, deterministic reward signals.
 
 
 
 
 
20
 
21
  ---
22
 
23
  ## Action Space
24
 
25
- | Field | Type | Description |
26
- |---|---|---|
27
- | `bug_identified` | bool | Whether a bug was found |
28
- | `bug_location` | string | Exact location (function, expression) |
29
- | `bug_type` | string | `off-by-one`, `logic-error`, `security-vulnerability`, `none` |
30
- | `bug_description` | string | Explanation of the bug and its impact |
31
- | `severity` | string | `none` / `low` / `medium` / `high` / `critical` |
32
- | `suggested_fix` | string | Corrected code or fix description |
 
 
 
 
33
 
34
  ## Observation Space
35
 
36
- | Field | Type | Description |
37
- |---|---|---|
38
- | `code_snippet` | string | The code to review |
39
- | `language` | string | Programming language |
40
- | `task_description` | string | What the code is supposed to do |
41
- | `task_id` | string | Unique task identifier |
42
- | `difficulty` | string | `easy` / `medium` / `hard` |
43
- | `step_number` | int | Current step within the episode |
44
- | `max_steps` | int | Maximum steps allowed (3) |
45
- | `previous_feedback` | string? | Feedback from prior step |
 
46
 
47
  ---
48
 
49
- ## Tasks
50
 
51
- ### Easy — Off-by-one in array traversal
52
- - **Code:** `sum_elements(arr)` iterates `range(1, len(arr)+1)` causing `IndexError`
53
- - **Expected bug type:** `off-by-one`
54
- - **Expected severity:** `high`
55
- - **Baseline score:** ~0.72
 
 
 
 
56
 
57
- ### Medium — Authentication logic flaw
58
- - **Code:** `authenticate_user()` uses `or` instead of `and` for admin check
59
- - **Expected bug type:** `logic-error`
60
- - **Expected severity:** `critical`
61
- - **Baseline score:** ~0.60
62
 
63
- ### Hard — SQL injection via f-string
64
- - **Code:** `fetch_records()` interpolates `user_id` and `sort_column` directly into SQL
65
- - **Expected bug type:** `security-vulnerability`
66
- - **Expected severity:** `critical`
67
- - **Baseline score:** ~0.55
68
 
69
  ---
70
 
71
- ## Reward Function
 
 
 
 
 
 
72
 
73
- Rewards are deterministic and provide partial progress signal:
74
 
75
- | Component | Max Score | Description |
76
  |---|---|---|
77
- | Bug identified | 0.20 | Correctly flags presence/absence of bug |
78
- | Bug type | 0.20 | Correct category of bug |
79
- | Bug location | 0.10 | Precise location identified |
80
- | Description quality | 0.25 | Keyword density in explanation |
81
- | Fix quality | 0.15 | Correct fix keywords present |
82
- | Severity | 0.10 | Correct severity level |
83
- | **Total** | **1.00** | |
84
 
85
  ---
86
 
87
  ## Setup
88
 
89
- ### 1. Build and run Docker
90
 
91
  ```bash
92
- docker build -t code-review-env .
93
- docker run -p 7860:7860 code-review-env
94
  ```
95
 
96
- ### 2. Run inference baseline
97
 
98
  ```bash
99
- # Set your environment variables
100
- export HF_TOKEN=hf_your_token_here
101
- export MODEL_NAME=Qwen/Qwen2.5-72B-Instruct
102
- export API_BASE_URL=https://router.huggingface.co/v1
103
- export ENV_BASE_URL=http://localhost:7860
104
-
105
- # Install dependencies
106
  pip install -r requirements.txt
107
-
108
- # Run
109
- python inference.py
110
- ```
111
-
112
- ### 3. Validate (OpenEnv CLI)
113
-
114
- ```bash
115
- openenv validate
116
  ```
117
 
118
  ---
119
 
120
- ## API Endpoints
121
-
122
- | Method | Path | Description |
123
- |---|---|---|
124
- | GET | `/health` | Health check |
125
- | POST | `/reset?difficulty=easy` | Reset environment |
126
- | POST | `/step` | Submit a review action |
127
- | GET | `/state` | Current episode state |
128
-
129
- ---
130
-
131
- ## Baseline Scores
132
-
133
- | Task | Difficulty | Reward |
134
- |---|---|---|
135
- | Off-by-one detection | Easy | ~0.72 |
136
- | Auth logic flaw | Medium | ~0.60 |
137
- | SQL injection | Hard | ~0.55 |
138
- | **Average** | | **~0.62** |
139
-
140
- ---
141
 
142
- ## Project Structure
 
 
 
 
143
 
144
- ```
145
- code-review-env/
146
- ├── Dockerfile
147
- ├── openenv.yaml
148
- ├── requirements.txt
149
- ├── inference.py
150
- ├── README.md
151
- └── server/
152
- ├── __init__.py
153
- ├── app.py # FastAPI endpoints
154
- ├── environment.py # Tasks + grader logic
155
- └── models.py # Pydantic action/observation/state
156
  ```
 
1
+ # Code Security Review — OpenEnv Environment
2
+
3
+ An RL environment for training AI agents to perform real-world code security review.
4
+ Agents analyze code snippets from production pull requests and identify bugs,
5
+ vulnerabilities, and security issues.
6
+
7
+ Built by **Inmodel Labs** for the Meta PyTorch OpenEnv Hackathon.
8
+
9
  ---
10
 
11
+ ## Environment Overview
 
12
 
13
+ | Field | Value |
14
+ |---|---|
15
+ | Tasks | 3 (easy → medium → hard) |
16
+ | Languages | Python, JavaScript |
17
+ | Action space | Structured JSON (6 fields) |
18
+ | Reward range | 0.0 – 1.0 |
19
+ | Steps per episode | 1 |
20
 
21
+ ---
22
+
23
+ ## Tasks
24
+
25
+ | ID | Language | Bug Class | Difficulty |
26
+ |---|---|---|---|
27
+ | `python-off-by-one` | Python | Off-by-one index error | Easy |
28
+ | `js-auth-privilege` | JavaScript | Logic flaw — privilege escalation | Medium |
29
+ | `python-sql-injection` | Python | SQL injection via f-string | Hard |
30
 
31
  ---
32
 
33
  ## Action Space
34
 
35
+ The agent submits a JSON action with these fields:
36
+
37
+ ```json
38
+ {
39
+ "bug_identified": true,
40
+ "bug_location": "line 3 — range(len(transactions) + 1)",
41
+ "bug_type": "logic-error",
42
+ "bug_description": "Off-by-one error causes IndexError on last iteration...",
43
+ "severity": "medium",
44
+ "suggested_fix": "Change range(len(transactions) + 1) to range(len(transactions))"
45
+ }
46
+ ```
47
 
48
  ## Observation Space
49
 
50
+ ```json
51
+ {
52
+ "task_id": "python-sql-injection",
53
+ "language": "Python",
54
+ "difficulty": "hard",
55
+ "code_snippet": "def search_users(db, search_term):\n ...",
56
+ "context": "REST API endpoint that searches users by name",
57
+ "pr_title": "Add user search endpoint to REST API",
58
+ "file_path": "api/users.py"
59
+ }
60
+ ```
61
 
62
  ---
63
 
64
+ ## Reward Breakdown
65
 
66
+ | Component | Max Score |
67
+ |---|---|
68
+ | Bug identified | 0.20 |
69
+ | Bug type correct | 0.20 |
70
+ | Bug location correct | 0.10 |
71
+ | Description quality | 0.25 |
72
+ | Fix quality | 0.15 |
73
+ | Severity correct | 0.10 |
74
+ | **Total** | **1.00** |
75
 
76
+ The grader penalizes keyword stuffing — incoherent keyword dumps score ≤ 0.20.
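One way to score "description quality" by keyword coverage while punishing stuffing (a hypothetical sketch only; the real logic lives in `server/grader.py` and may differ):

```python
def description_score(text: str, keywords: list[str], cap: float = 0.25) -> float:
    """Hypothetical sketch: reward keyword coverage, zero out keyword dumps.

    Coverage is the fraction of expected keywords present in the text.
    A crude stuffing check: if most tokens are keyword tokens, the text
    looks like an incoherent dump and earns nothing for this component.
    """
    words = text.lower().split()
    if not words or not keywords:
        return 0.0
    lowered = text.lower()
    coverage = sum(1 for kw in keywords if kw.lower() in lowered) / len(keywords)
    keyword_tokens = sum(1 for w in words if any(kw.lower() in w for kw in keywords))
    if keyword_tokens / len(words) > 0.5:  # mostly keywords -> stuffing
        return 0.0
    return round(cap * coverage, 3)
```

Under this sketch, a coherent sentence covering all keywords earns the full 0.25 cap, while a bare keyword dump earns 0.0 for the description component (leaving at most the 0.20 "bug identified" credit overall).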
 
 
 
 
77
 
78
+ **Example Calculation:**
79
+ If the agent correctly identifies a bug (+0.20), misidentifies the type (+0.0), finds 50% of the location keywords (+0.05), writes a detailed and coherent description matching most keywords (+0.25), suggests a partially correct fix (+0.08), and gets the severity correct (+0.10), the total reward for that step would be `0.20 + 0.0 + 0.05 + 0.25 + 0.08 + 0.10 = 0.68`.
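The arithmetic above can be sketched as a capped weighted sum. This is an illustrative stand-in that mirrors the Reward Breakdown table, not the actual grader code in `server/grader.py`:

```python
# Component caps taken from the Reward Breakdown table above.
COMPONENT_CAPS = {
    "bug_identified": 0.20,
    "bug_type": 0.20,
    "bug_location": 0.10,
    "description": 0.25,
    "fix": 0.15,
    "severity": 0.10,
}

def total_reward(scores: dict) -> float:
    """Clamp each component to [0, cap] and sum; result lies in [0, 1]."""
    total = 0.0
    for component, cap in COMPONENT_CAPS.items():
        total += min(max(scores.get(component, 0.0), 0.0), cap)
    return round(total, 2)

# The worked example from above:
example = {
    "bug_identified": 0.20,  # bug correctly flagged
    "bug_type": 0.0,         # wrong category
    "bug_location": 0.05,    # half the location keywords found
    "description": 0.25,     # coherent, keyword-rich description
    "fix": 0.08,             # partially correct fix
    "severity": 0.10,        # correct severity
}
print(total_reward(example))  # 0.68
```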
 
 
 
80
 
81
  ---
82
 
83
+ ## Edge Cases
84
+
85
+ - **Before `reset()`:** `reset()` initializes the state. If `step()` is called first, the environment resets itself internally and evaluates the action against a randomly selected task.
86
+ - **Max steps:** Episodes are single-step (`max_steps = 1`): calling `step()` evaluates the action and immediately sets `done=True`.
87
+ - **After `done=True`:** A further `step()` call returns `reward=0.0`, `done=True`, and an error message in the `info` dict (e.g. `"Episode already completed. Call /reset..."`), without auto-resetting.
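The lifecycle rules above can be mimicked with a minimal state machine (a standalone sketch of the contract, not the actual `server/environment.py` implementation; the fixed 0.5 reward stands in for the grader):

```python
import random

class EpisodeSketch:
    """Toy model of the reset/step contract described above."""

    TASKS = ["python-off-by-one", "js-auth-privilege", "python-sql-injection"]

    def __init__(self):
        self.task_id = None
        self.done = False

    def reset(self, task_id=None):
        self.task_id = task_id or random.choice(self.TASKS)
        self.done = False
        return {"task_id": self.task_id}

    def step(self, action):
        if self.task_id is None:
            # step() before reset(): auto-reset onto a random task
            self.reset()
        if self.done:
            # Episode over: zero reward, no auto-reset
            return {"reward": 0.0, "done": True,
                    "info": {"error": "Episode already completed. Call /reset..."}}
        self.done = True  # single-step episodes (max_steps = 1)
        return {"reward": 0.5, "done": True, "info": {}}  # 0.5 = grader stand-in
```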
88
+
89
+ ---
90
 
91
+ ## API Endpoints
92
 
93
+ | Method | Path | Description |
94
  |---|---|---|
95
+ | GET | `/` | Health check |
96
+ | POST | `/reset?task_id=<id>` | Reset environment, returns observation |
97
+ | POST | `/step` | Submit action, returns reward |
98
+ | GET | `/state` | Current episode state |
99
+ | GET | `/tasks` | List all tasks |
 
 
100
 
101
  ---
102
 
103
  ## Setup
104
 
105
+ ### Docker
106
 
107
  ```bash
108
+ docker build -t code-security-review .
109
+ docker run -p 8000:8000 code-security-review
110
  ```
111
 
112
+ ### Local
113
 
114
  ```bash
 
 
 
 
 
 
 
115
  pip install -r requirements.txt
116
+ uvicorn server.app:app --host 0.0.0.0 --port 8000
117
  ```
118
 
119
  ---
120
 
121
+ ## Running Inference
122
 
123
+ ```bash
124
+ export API_BASE_URL="https://api.openai.com/v1"
125
+ export MODEL_NAME="gpt-4o-mini"
126
+ export HF_TOKEN="your-api-key"
127
+ export ENV_URL="http://localhost:8000"
128
 
129
+ python inference.py
130
  ```
inference.py CHANGED
@@ -6,28 +6,30 @@ Required environment variables:
6
  API_BASE_URL — LLM API endpoint
7
  MODEL_NAME — Model identifier
8
  HF_TOKEN — Hugging Face / API key
9
- ENV_BASE_URL — Running environment URL (default: http://localhost:7860)
10
  """
11
 
12
  import os
13
  import json
14
  import time
15
  import re
 
16
  from typing import List, Optional
17
  from dotenv import load_dotenv
 
18
 
19
  # Load .env variables
20
  load_dotenv()
21
 
22
- import requests
23
- from openai import OpenAI
24
-
25
  # ── Config ────────────────────────────────────────────────────────────────────
26
- API_BASE_URL = os.environ.get("API_BASE_URL", "https://router.huggingface.co/v1")
27
- MODEL_NAME = os.environ.get("MODEL_NAME", "Qwen/Qwen2.5-72B-Instruct")
28
- HF_TOKEN = os.environ.get("HF_TOKEN", "")
29
- ENV_BASE_URL = os.environ.get("ENV_BASE_URL", "http://localhost:7860")
30
- BENCHMARK = "code-review-env"
 
 
 
31
 
32
  client = OpenAI(base_url=API_BASE_URL, api_key=HF_TOKEN)
33
 
@@ -41,7 +43,7 @@ Schema:
41
  {
42
  "bug_identified": true or false,
43
  "bug_location": "exact location (function name, line description, variable, expression)",
44
- "bug_type": "off-by-one | logic-error | security-vulnerability | null-dereference | none",
45
  "bug_description": "detailed explanation of why this is a bug and the impact",
46
  "severity": "none | low | medium | high | critical",
47
  "suggested_fix": "the corrected code snippet or a precise description of the fix"
@@ -69,7 +71,7 @@ def log_end(success: bool, steps: int, score: float, rewards: List[float]) -> No
69
  # ── Helpers ───────────────────────────────────────────────────────────────────
70
 
71
  def env_post(path: str, data: Optional[dict] = None, params: Optional[dict] = None) -> dict:
72
- url = f"{ENV_BASE_URL}{path}"
73
  resp = requests.post(url, json=data or {}, params=params or {}, timeout=30)
74
  resp.raise_for_status()
75
  return resp.json()
@@ -80,41 +82,49 @@ def parse_json_from_llm(text: str) -> dict:
80
  text = text.strip()
81
  text = re.sub(r"^```(?:json)?\s*", "", text)
82
  text = re.sub(r"\s*```$", "", text)
83
- return json.loads(text)
 
 
 
 
 
 
 
84
 
85
 
86
  def build_prompt(obs: dict) -> str:
87
  lines = [
88
  f"Language: {obs['language']}",
89
- f"Task: {obs['task_description']}",
 
 
90
  "",
91
  f"```{obs['language']}",
92
  obs["code_snippet"],
93
  "```",
94
  ]
95
- if obs.get("previous_feedback"):
96
- lines += ["", f"Previous feedback: {obs['previous_feedback']}",
97
- "Revise your analysis accordingly."]
98
  return "\n".join(lines)
99
 
100
 
101
  # ── Task runner ───────────────────────────────────────────────────────────────
102
 
103
- def run_task(difficulty: str) -> dict:
104
- reset_resp = env_post("/reset", params={"difficulty": difficulty})
105
  obs = reset_resp["observation"]
106
- task_id = obs['task_id']
107
-
108
  log_start(task=task_id, env=BENCHMARK, model=MODEL_NAME)
109
 
110
- rewards = []
111
- steps_taken = 0
 
112
  done = False
113
- last_error = None
 
114
 
115
- while not done and steps_taken < obs["max_steps"]:
116
- steps_taken += 1
117
  prompt = build_prompt(obs)
 
118
 
119
  # ── LLM call ──────────────────────────────────────────────────────────
120
  try:
@@ -126,20 +136,21 @@ def run_task(difficulty: str) -> dict:
126
  ],
127
  temperature=0.1,
128
  max_tokens=600,
 
129
  )
130
  raw = response.choices[0].message.content
131
  action_dict = parse_json_from_llm(raw)
132
  action_str = json.dumps(action_dict)
133
- last_error = None
134
  except Exception as exc:
135
- last_error = str(exc)
136
  action_dict = {
137
  "bug_identified": False,
138
- "bug_location": "error",
139
  "bug_type": "none",
140
- "bug_description": last_error,
141
  "severity": "none",
142
- "suggested_fix": "",
143
  }
144
  action_str = "{}"
145
 
@@ -147,44 +158,56 @@ def run_task(difficulty: str) -> dict:
147
  step_resp = env_post("/step", data=action_dict)
148
  reward = step_resp["reward"]
149
  done = step_resp["done"]
150
- obs = step_resp["observation"]
151
 
152
- rewards.append(reward)
153
- log_step(step=steps_taken, action=action_str, reward=reward, done=done, error=last_error)
 
 
154
 
155
- # Calculate final score (normalized to [0, 1])
156
- # Total reward is cumulative in this env, but we cap it at 1.0 for the score
157
- total_reward = sum(rewards)
158
- score = min(max(total_reward, 0.0), 1.0)
159
- success = score >= 0.8
160
 
161
- log_end(success=success, steps=steps_taken, score=score, rewards=rewards)
162
-
163
  return {
164
- "task_id": task_id,
165
- "score": score,
166
- "success": success
 
167
  }
168
 
169
 
170
  # ── Main ──────────────────────────────────────────────────────────────────────
171
 
172
  def main():
173
- tasks = ["easy", "medium", "hard"]
174
  results = []
175
 
176
- for difficulty in tasks:
177
  try:
178
- r = run_task(difficulty)
179
- results.append(r)
180
  except Exception as exc:
181
- # print(f"DEBUG: Task failed: {exc}", flush=True)
182
- log_end(success=False, steps=0, score=0.0, rewards=[])
 
183
 
184
  if results:
185
- avg = sum(r["score"] for r in results) / len(results)
186
- # Optional: summary for human review (will not interfere with [END] parsers)
187
- # print(f"\n[SUMMARY] avg_score={avg:.3f}")
188
 
189
  if __name__ == "__main__":
190
  main()
 
6
  API_BASE_URL — LLM API endpoint
7
  MODEL_NAME — Model identifier
8
  HF_TOKEN — Hugging Face / API key
9
+ ENV_URL — Running environment URL (default: http://localhost:7860)
10
  """
11
 
12
  import os
13
  import json
14
  import time
15
  import re
16
+ import requests
17
  from typing import List, Optional
18
  from dotenv import load_dotenv
19
+ from openai import OpenAI
20
 
21
  # Load .env variables
22
  load_dotenv()
23
 
 
 
 
24
  # ── Config ────────────────────────────────────────────────────────────────────
25
+ API_BASE_URL = os.environ.get("API_BASE_URL", "https://api.openai.com/v1")
26
+ MODEL_NAME = os.environ.get("MODEL_NAME", "gpt-4o-mini")
27
+ HF_TOKEN = os.environ.get("HF_TOKEN") or os.environ.get("API_KEY")
28
+ ENV_URL = os.environ.get("ENV_URL", "http://localhost:7860")
29
+ BENCHMARK = "code-security-review"
30
+
31
+ if not HF_TOKEN:
32
+ raise ValueError("HF_TOKEN or API_KEY must be set.")
33
 
34
  client = OpenAI(base_url=API_BASE_URL, api_key=HF_TOKEN)
35
 
 
43
  {
44
  "bug_identified": true or false,
45
  "bug_location": "exact location (function name, line description, variable, expression)",
46
+ "bug_type": "off-by-one | logic-error | security-vulnerability | none",
47
  "bug_description": "detailed explanation of why this is a bug and the impact",
48
  "severity": "none | low | medium | high | critical",
49
  "suggested_fix": "the corrected code snippet or a precise description of the fix"
 
71
  # ── Helpers ───────────────────────────────────────────────────────────────────
72
 
73
  def env_post(path: str, data: Optional[dict] = None, params: Optional[dict] = None) -> dict:
74
+ url = f"{ENV_URL}{path}"
75
  resp = requests.post(url, json=data or {}, params=params or {}, timeout=30)
76
  resp.raise_for_status()
77
  return resp.json()
 
82
  text = text.strip()
83
  text = re.sub(r"^```(?:json)?\s*", "", text)
84
  text = re.sub(r"\s*```$", "", text)
85
+ # If the LLM still included text around the JSON, try to find the first { and last }
86
+ match = re.search(r"({.*})", text, re.DOTALL)
87
+ if match:
88
+ text = match.group(1)
89
+ try:
90
+ return json.loads(text)
91
+ except Exception:
92
+ return {}
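A standalone sketch of the same tolerant-parsing idea (it mirrors `parse_json_from_llm` above as a self-contained function; it is not the module itself):

```python
import json
import re

def extract_json(text: str) -> dict:
    """Strip markdown fences, then grab the outermost {...} and parse it.

    Returns {} when no valid JSON object can be recovered, mirroring the
    fallback behaviour of parse_json_from_llm.
    """
    text = text.strip()
    text = re.sub(r"^```(?:json)?\s*", "", text)   # leading code fence
    text = re.sub(r"\s*```$", "", text)            # trailing code fence
    match = re.search(r"({.*})", text, re.DOTALL)  # outermost JSON object
    if match:
        text = match.group(1)
    try:
        return json.loads(text)
    except Exception:
        return {}

# Typical LLM outputs this survives:
print(extract_json('```json\n{"bug_identified": true}\n```'))
print(extract_json('Sure! Here is the review: {"severity": "high"} Hope that helps.'))
print(extract_json('no json at all'))  # {}
```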
93
 
94
 
95
  def build_prompt(obs: dict) -> str:
96
  lines = [
97
  f"Language: {obs['language']}",
98
+ f"Context: {obs.get('context', 'No context provided')}",
99
+ f"PR Title: {obs.get('pr_title', 'No PR title')}",
100
+ f"File Path: {obs.get('file_path', 'unknown')}",
101
  "",
102
  f"```{obs['language']}",
103
  obs["code_snippet"],
104
  "```",
105
  ]
 
 
 
106
  return "\n".join(lines)
107
 
108
 
109
  # ── Task runner ───────────────────────────────────────────────────────────────
110
 
111
+ def run_task(task_id: str, task_num: int) -> dict:
112
+ reset_resp = env_post("/reset", params={"task_id": task_id})
113
  obs = reset_resp["observation"]
114
+
 
115
  log_start(task=task_id, env=BENCHMARK, model=MODEL_NAME)
116
 
117
+ cumulative_reward = 0.0
118
+ step_num = 0
119
+ max_steps = 1
120
  done = False
121
+ all_rewards = []
122
+ error = None
123
 
124
+ while not done and step_num < max_steps:
125
+ step_num += 1
126
  prompt = build_prompt(obs)
127
+ action_dict = {}
128
 
129
  # ── LLM call ──────────────────────────────────────────────────────────
130
  try:
 
136
  ],
137
  temperature=0.1,
138
  max_tokens=600,
139
+ stream=False,
140
  )
141
  raw = response.choices[0].message.content
142
  action_dict = parse_json_from_llm(raw)
143
  action_str = json.dumps(action_dict)
144
+ error = None
145
  except Exception as exc:
146
+ error = str(exc).replace("\n", " ")
147
  action_dict = {
148
  "bug_identified": False,
149
+ "bug_location": "none",
150
  "bug_type": "none",
151
+ "bug_description": f"Error: {error}",
152
  "severity": "none",
153
+ "suggested_fix": "none",
154
  }
155
  action_str = "{}"
156
 
 
158
  step_resp = env_post("/step", data=action_dict)
159
  reward = step_resp["reward"]
160
  done = step_resp["done"]
161
+ obs = step_resp.get("observation")
162
 
163
+ all_rewards.append(reward)
164
+ cumulative_reward += reward
165
+
166
+ log_step(step=step_num, action=action_str, reward=reward, done=done, error=error)
167
 
168
+ success = cumulative_reward >= 0.8
169
+ log_end(success=success, steps=step_num, score=cumulative_reward, rewards=all_rewards)
 
 
 
170
 
 
 
171
  return {
172
+ "task_num": task_num,
173
+ "task_id": task_id,
174
+ "score": cumulative_reward,
175
+ "success": success,
176
  }
177
 
178
 
179
  # ── Main ──────────────────────────────────────────────────────────────────────
180
 
181
  def main():
182
+ print(f"[INFO] Initializing inference on {BENCHMARK} using {MODEL_NAME}", flush=True)
183
+
184
+ TASK_FILTER = os.environ.get("TASK")
185
+
186
+ all_tasks = [
187
+ ("python-off-by-one", 1, "easy"),
188
+ ("js-auth-privilege", 2, "medium"),
189
+ ("python-sql-injection", 3, "hard"),
190
+ ]
191
+
192
+ if TASK_FILTER:
193
+ tasks = [t for t in all_tasks if t[2] == TASK_FILTER]
194
+ else:
195
+ tasks = all_tasks
196
+
197
  results = []
198
 
199
+ for task_id, task_num, _ in tasks:
200
  try:
201
+ r = run_task(task_id, task_num)
 
202
  except Exception as exc:
203
+ print(f"[ERROR] task_id={task_id} error={exc}", flush=True)
204
+ r = {"task_num": task_num, "task_id": task_id, "score": 0.0, "success": False}
205
+ results.append(r)
206
 
207
  if results:
208
+ avg = round(sum(r["score"] for r in results) / len(results), 3)
209
+ successes = sum(1 for r in results if r.get("success"))
210
+ print(f"\n[SUMMARY] avg_reward={avg} tasks_passed={successes}/{len(results)}", flush=True)
211
 
212
  if __name__ == "__main__":
213
  main()
openenv.yaml CHANGED
@@ -1,47 +1,50 @@
1
- name: code-review-env
2
- version: 1.0.0
3
- description: >
4
- RL environment for training AI agents to detect bugs and security
5
- vulnerabilities in real Python code. Covers off-by-one errors,
6
- authentication logic flaws, and SQL injection β€” with deterministic
7
- programmatic graders and partial-progress reward signals.
8
 
 
 
 
 
 
 
 
9
  author: Inmodel Labs
10
- tags:
11
- - code-review
12
- - security
13
- - software-engineering
14
- - real-world
15
- - python
16
 
 
 
17
  tasks:
18
- - id: task_easy_001
 
 
19
  difficulty: easy
20
- description: "Detect off-by-one error in array traversal loop"
21
- reset_params:
22
- difficulty: easy
23
 
24
- - id: task_medium_001
 
 
25
  difficulty: medium
26
- description: "Detect authentication logic flaw enabling privilege escalation"
27
- reset_params:
28
- difficulty: medium
29
 
30
- - id: task_hard_001
 
 
31
  difficulty: hard
32
- description: "Detect SQL injection via unsanitised f-string database query"
33
- reset_params:
34
- difficulty: hard
35
 
 
 
36
  action_space:
37
  type: object
38
  properties:
39
- bug_identified: { type: boolean }
40
- bug_location: { type: string }
41
- bug_type: { type: string }
42
- bug_description: { type: string }
43
- severity: { type: string, enum: [none, low, medium, high, critical] }
44
- suggested_fix: { type: string }
45
  required:
46
  - bug_identified
47
  - bug_location
@@ -50,18 +53,20 @@ action_space:
50
  - severity
51
  - suggested_fix
52
 
 
 
53
  observation_space:
54
  type: object
55
  properties:
56
- code_snippet: { type: string }
57
- language: { type: string }
58
- task_description: { type: string }
59
- task_id: { type: string }
60
- difficulty: { type: string, enum: [easy, medium, hard] }
61
- step_number: { type: integer }
62
- max_steps: { type: integer }
63
- previous_feedback: { type: string, nullable: true }
64
 
 
65
  reward:
66
  min: 0.0
67
  max: 1.0
@@ -69,9 +74,11 @@ reward:
69
  Partial rewards for: bug identification (0.20), correct bug type (0.20),
70
  precise location (0.10), description quality (0.25, keyword density),
71
  fix quality (0.15, keyword density), correct severity (0.10).
 
72
 
73
  endpoints:
74
- health: GET /health
75
- reset: POST /reset
76
- step: POST /step
77
- state: GET /state
 
 
1
+ # OpenEnv Environment Specification
2
+ # This file describes the Code Security Review environment for the Meta PyTorch OpenEnv Hackathon.
 
 
 
 
 
3
 
4
+ # Metadata section details the environment's identity.
5
+ name: code-security-review
6
+ version: "1.0.0"
7
+ description: >
8
+ An RL environment for training AI agents to perform code security review.
9
+ Agents analyze code snippets from production pull requests and identify bugs,
10
+ vulnerabilities, and security issues.
11
  author: Inmodel Labs
 
 
 
 
 
 
12
 
13
+ # Tasks section defines the core challenges in the environment.
14
+ # Each task has a unique ID, name, description, and difficulty level.
15
  tasks:
16
+ - id: python-off-by-one
17
+ name: "Python Off-by-One Error"
18
+ description: "Identify an off-by-one index error in a Python finance batch processor"
19
  difficulty: easy
20
+ max_steps: 1
21
+ reward_range: [0.0, 1.0]
 
22
 
23
+ - id: js-auth-privilege
24
+ name: "JavaScript Auth Logic Flaw"
25
+ description: "Identify a privilege escalation vulnerability in Node.js auth middleware"
26
  difficulty: medium
27
+ max_steps: 1
28
+ reward_range: [0.0, 1.0]
 
29
 
30
+ - id: python-sql-injection
31
+ name: "Python SQL Injection"
32
+ description: "Identify an SQL injection vulnerability via f-string in a REST API"
33
  difficulty: hard
34
+ max_steps: 1
35
+ reward_range: [0.0, 1.0]
 
36
 
37
+ # The Action space defines the format of the agent's response.
38
+ # Each field is scored by the grader to provide partial progress signals.
39
  action_space:
40
  type: object
41
  properties:
42
+ bug_identified: { type: boolean, description: "True if a bug exists" }
43
+ bug_location: { type: string, description: "Pinpoint the bug's location in code" }
44
+ bug_type: { type: string, description: "off-by-one | logic-error | security-vulnerability | none" }
45
+ bug_description: { type: string, description: "Detailed analysis of the vulnerability" }
46
+ severity: { type: string, enum: [none, low, medium, high, critical], description: "Severity level" }
47
+ suggested_fix: { type: string, description: "How to fix the identified bug" }
48
  required:
49
  - bug_identified
50
  - bug_location
 
53
  - severity
54
  - suggested_fix
55
 
56
+ # The Observation space defines what the agent sees at each step.
57
+ # It uses a structured context to help the agent understand the code's purpose.
58
  observation_space:
59
  type: object
60
  properties:
61
+ task_id: { type: string, description: "Unique task identifier" }
62
+ language: { type: string, description: "Source code language" }
63
+ difficulty: { type: string, enum: [easy, medium, hard], description: "Task complexity (easy/medium/hard)" }
64
+ code_snippet: { type: string, description: "The source code to be reviewed" }
65
+ context: { type: string, description: "Real-world context (e.g., API description)" }
66
+ pr_title: { type: string, description: "Pull Request title for additional intent context" }
67
+ file_path: { type: string, description: "Relative path to the file in the repository" }
 
68
 
69
+ # Reward structure for evaluating agent performance.
70
  reward:
71
  min: 0.0
72
  max: 1.0
 
74
  Partial rewards for: bug identification (0.20), correct bug type (0.20),
75
  precise location (0.10), description quality (0.25, keyword density),
76
  fix quality (0.15, keyword density), correct severity (0.10).
77
+ Grader penalizes keyword stuffing.
78
 
79
  endpoints:
80
+ health: GET /
81
+ reset: POST /reset
82
+ step: POST /step
83
+ state: GET /state
84
+ tasks: GET /tasks
qa_test.py ADDED
@@ -0,0 +1,237 @@
1
+ import requests
2
+ import json
3
+
4
+ BASE_URL = "http://localhost:7860"
5
+
6
+ def run_tests():
7
+ checks = []
8
+
9
+ # 1. GET /
10
+ try:
11
+ r = requests.get(f"{BASE_URL}/")
12
+ passed = r.status_code == 200 and r.json().get("status") == "ok"
13
+ checks.append({
14
+ "id": 1, "name": "GET / health check", "passed": passed,
15
+ "expected": 'HTTP 200 and {"status": "ok"}', "got": f"HTTP {r.status_code} {r.text}"
16
+ })
17
+ except Exception as e:
18
+ checks.append({"id": 1, "name": "GET / health check", "passed": False, "expected": "200 OK", "got": str(e)})
19
+
20
+ # 15. GET /state before reset (Edge case)
21
+ try:
22
+ r = requests.get(f"{BASE_URL}/state")
23
+ # Should not crash
24
+ checks.append({
25
+ "id": 15, "name": "GET /state before any reset", "passed": r.status_code == 200,
26
+ "expected": "HTTP 200 (No crash)", "got": f"HTTP {r.status_code} {r.text}"
27
+ })
28
+ except Exception as e:
29
+ checks.append({"id": 15, "name": "GET /state before any reset", "passed": False, "expected": "200 OK", "got": str(e)})
30
+
31
+ # 2. POST /reset
32
+ try:
33
+ r = requests.post(f"{BASE_URL}/reset")
34
+ data = r.json().get("observation", {})
35
+ required = ["task_id", "language", "difficulty", "code_snippet", "context", "pr_title", "file_path"]
36
+ passed = all(k in data for k in required)
37
+ checks.append({
38
+ "id": 2, "name": "POST /reset fields check", "passed": passed,
39
+ "expected": f"JSON with {required}", "got": list(data.keys())
40
+ })
41
+ except Exception as e:
42
+ checks.append({"id": 2, "name": "POST /reset fields check", "passed": False, "expected": "Fields", "got": str(e)})
43
+
44
+ # 16. POST /reset no task_id
45
+ try:
46
+ r = requests.post(f"{BASE_URL}/reset")
47
+ checks.append({
48
+ "id": 16, "name": "POST /reset no task_id (Random)", "passed": r.status_code == 200,
49
+ "expected": "HTTP 200", "got": f"HTTP {r.status_code}"
50
+ })
51
+ except Exception as e:
52
+ checks.append({"id": 16, "name": "POST /reset no task_id (Random)", "passed": False, "expected": "200 OK", "got": str(e)})
53
+
54
+ # 3-5. POST /reset?task_id=...
55
+ for tid in ["python-off-by-one", "js-auth-privilege", "python-sql-injection"]:
56
+ try:
57
+ num = {"python-off-by-one": 3, "js-auth-privilege": 4, "python-sql-injection": 5}[tid]
58
+ r = requests.post(f"{BASE_URL}/reset?task_id={tid}")
59
+ passed = r.status_code == 200 and r.json()["observation"]["task_id"] == tid
60
+ checks.append({
61
+ "id": num, "name": f"POST /reset for {tid}", "passed": passed,
62
+ "expected": f"HTTP 200 with task_id={tid}", "got": f"HTTP {r.status_code} {r.json()['observation']['task_id'] if passed else r.text}"
63
+ })
64
+ except Exception as e:
65
+ checks.append({"id": num, "name": f"POST /reset for {tid}", "passed": False, "expected": "200 OK", "got": str(e)})
66
+
67
+ # 6. GET /state
68
+ try:
69
+ r = requests.get(f"{BASE_URL}/state")
70
+ data = r.json()
71
+ required = ["task_id", "step", "done", "total_reward"]
72
+ passed = all(k in data for k in required)
73
+ checks.append({
74
+ "id": 6, "name": "GET /state fields check", "passed": passed,
75
+ "expected": f"JSON with {required}", "got": list(data.keys())
76
+ })
77
+ except Exception as e:
78
+ checks.append({"id": 6, "name": "GET /state fields check", "passed": False, "expected": "Fields", "got": str(e)})
79
+
80
+ # 7. POST /step with PROVIDED action
81
+ try:
82
+ requests.post(f"{BASE_URL}/reset?task_id=python-sql-injection")
83
+ action = {
84
+ "bug_identified": True,
85
+ "bug_location": "line 2 f-string",
86
+ "bug_type": "security-vulnerability",
87
+ "bug_description": "SQL injection via f-string",
88
+ "severity": "critical",
89
+ "suggested_fix": "use parameterized query"
90
+ }
91
+ r = requests.post(f"{BASE_URL}/step", json=action)
92
+ res = r.json()
93
+ reward = res.get("reward", -1.0)
94
+ done = res.get("done", False)
95
+ passed = 0.0 <= reward <= 1.0 and done is True
96
+ checks.append({
97
+ "id": 7, "name": "POST /step valid action", "passed": passed,
98
+ "expected": "Reward [0,1] and done=true", "got": f"reward={reward}, done={done}"
99
+ })
100
+ except Exception as e:
101
+ checks.append({"id": 7, "name": "POST /step valid action", "passed": False, "expected": "Result", "got": str(e)})
102
+
103
+ # 14. Call POST /step twice (Edge Case)
104
+ try:
105
+ # Step already called in task 7
106
+ action = {"bug_identified": False, "bug_location": "", "bug_type": "none", "bug_description": "", "severity": "none", "suggested_fix": ""}
107
+ r = requests.post(f"{BASE_URL}/step", json=action)
108
+ res = r.json()
109
+ passed = r.status_code == 200 and "error" in res.get("info", {})
110
+ checks.append({
111
+ "id": 14, "name": "POST /step twice in same episode", "passed": passed,
112
+ "expected": "HTTP 200 and error in info", "got": f"HTTP {r.status_code}, info={res.get('info')}"
113
+ })
114
+ except Exception as e:
115
+ checks.append({"id": 14, "name": "POST /step twice in same episode", "passed": False, "expected": "Handled error", "got": str(e)})
116
+
117
+ # 8. Perfect action for SQL
118
+ try:
119
+ requests.post(f"{BASE_URL}/reset?task_id=python-sql-injection")
120
+ perfect_action = {
121
+ "bug_identified": True,
122
+ "bug_location": "line 2 f-string interpolation in SQL query construction",
123
+ "bug_type": "security-vulnerability",
124
+ "bug_description": "SQL injection vulnerability where user-supplied search_term is directly interpolated into the SQL query via f-string. An attacker can inject malicious SQL to bypass authentication, exfiltrate all user data, or drop tables. The fix is to use parameterized queries which sanitize user input automatically.",
125
+ "severity": "critical",
126
+ "suggested_fix": "Use db.execute('SELECT * FROM users WHERE name LIKE %s', ('%'+search_term+'%',)) instead of f-string interpolation"
127
+ }
128
+ r = requests.post(f"{BASE_URL}/step", json=perfect_action)
129
+ reward = r.json().get("reward", 0.0)
130
+ checks.append({
131
+ "id": 8, "name": "PERFECT action SQL", "passed": reward >= 0.85,
132
+ "expected": "Reward >= 0.85", "got": f"reward={reward}"
133
+ })
134
+ except Exception as e:
135
+ checks.append({"id": 8, "name": "PERFECT action SQL", "passed": False, "expected": ">=0.85", "got": str(e)})
136
+
137
+ # 9. Keyword stuffed
138
+ try:
139
+ requests.post(f"{BASE_URL}/reset?task_id=python-sql-injection")
140
+ stuffed_action = {
141
+ "bug_identified": True,
142
+ "bug_location": "sql",
143
+ "bug_type": "security-vulnerability",
144
+ "bug_description": "sql injection sql injection sql injection parameterized f-string sanitize escape malicious attack tautology union drop sql injection sql injection",
145
+ "severity": "critical",
146
+ "suggested_fix": "fix"
147
+ }
148
+ r = requests.post(f"{BASE_URL}/step", json=stuffed_action)
149
+ reward = r.json().get("reward", 1.0)
150
+ checks.append({
151
+ "id": 9, "name": "KEYWORD STUFFED action", "passed": reward <= 0.20,
152
+ "expected": "Reward <= 0.20", "got": f"reward={reward}"
153
+ })
154
+ except Exception as e:
155
+ checks.append({"id": 9, "name": "KEYWORD STUFFED action", "passed": False, "expected": "<=0.20", "got": str(e)})
156
+
157
+ # 10. Bug identified false
158
+ try:
159
+ requests.post(f"{BASE_URL}/reset")
160
+ action = {"bug_identified": False, "bug_location": "", "bug_type": "none", "bug_description": "", "severity": "none", "suggested_fix": ""}
161
+ r = requests.post(f"{BASE_URL}/step", json=action)
162
+ reward = r.json().get("reward", 1.0)
163
+ checks.append({
164
+ "id": 10, "name": "Identify=False empty fields", "passed": reward == 0.0,
165
+ "expected": "Reward exactly 0.0", "got": f"reward={reward}"
166
+ })
167
+ except Exception as e:
168
+ checks.append({"id": 10, "name": "Identify=False empty fields", "passed": False, "expected": "0.0", "got": str(e)})
169
+
170
+ # 11. Partial credit severity
171
+ try:
172
+ # The off-by-one task's ground-truth severity is 'critical'.
173
+ # Submit 'low' to check that the other components still earn partial credit.
174
+ requests.post(f"{BASE_URL}/reset?task_id=python-off-by-one")
175
+ action = {
176
+ "bug_identified": True, "bug_location": "range", "bug_type": "off-by-one",
177
+ "bug_description": "off-by-one error in range function call",
178
+ "severity": "low", # Wrong severity
179
+ "suggested_fix": "range(len(x))"
180
+ }
181
+ r = requests.post(f"{BASE_URL}/step", json=action)
182
+ info = r.json().get("info", {})
183
+ breakdown = info.get("reward_breakdown", {})
184
+ sev_score = breakdown.get("severity", -1.0)
185
+ # It should be 0.0 (wrong) but the total should still have partial credit from other components
186
+ reward = r.json().get("reward", 0.0)
187
+ checks.append({
188
+ "id": 11, "name": "Partial credit (wrong severity)", "passed": 0.0 < reward < 1.0,
189
+ "expected": "Reward between 0 and 1 (partial credit)", "got": f"reward={reward}, severity_component={sev_score}"
190
+ })
191
+ except Exception as e:
192
+ checks.append({"id": 11, "name": "Partial credit (wrong severity)", "passed": False, "expected": "Partial credit", "got": str(e)})
193
+
194
+ # 12-13. Breakdown keys and components
195
+ try:
196
+ requests.post(f"{BASE_URL}/reset")
197
+ action = {"bug_identified": True, "bug_location": "test", "bug_type": "test", "bug_description": "test test test test test test test test test test test test test test test test test test test test", "severity": "none", "suggested_fix": "test test test"}
198
+ r = requests.post(f"{BASE_URL}/step", json=action)
199
+ info = r.json().get("info", {})
200
+ breakdown = info.get("reward_breakdown", {})
201
+ required = ["bug_identified", "bug_type", "bug_location", "description_quality", "fix_quality", "severity"]
202
+ checks.append({
203
+ "id": 12, "name": "Reward breakdown keys", "passed": all(k in breakdown for k in required),
204
+ "expected": f"Breakdown with {required}", "got": list(breakdown.keys())
205
+ })
206
+
207
+ max_vals = {
208
+ "bug_identified": 0.20, "bug_type": 0.20, "bug_location": 0.10,
209
+ "description_quality": 0.25, "fix_quality": 0.15, "severity": 0.10
210
+ }
211
+ passed_range = all(0.0 <= breakdown.get(k, -1) <= max_vals[k] for k in max_vals)
212
+ checks.append({
213
+ "id": 13, "name": "Component score ranges", "passed": passed_range,
214
+ "expected": "All components <= max", "got": breakdown
215
+ })
216
+ except Exception as e:
217
+ checks.append({"id": 12, "name": "Breakdown checks", "passed": False, "expected": "Breakdown", "got": str(e)})
218
+
219
+ # Sort and print
220
+ checks.sort(key=lambda x: x["id"])
221
+ for c in checks:
222
+ status = "PASS" if c["passed"] else "FAIL"
223
+ print(f"[{c['id']}] {c['name']} - {status}")
224
+ print(f" Expected: {c['expected']}")
225
+ print(f" Got: {c['got']}")
226
+ print("")
227
+
228
+ passed_count = sum(1 for c in checks if c["passed"])
229
+ disqual = "YES" if passed_count < 7 else "NO"  # at risk if fewer than 7 of 16 checks pass
230
+ print(f"TOTAL: {passed_count}/16 passed")
231
+ print(f"DISQUALIFICATION RISK: {disqual}")
232
+ # Estimate score based on points
233
+ score = (passed_count / 16) * 100
234
+ print(f"ESTIMATED SCORE: {round(score)}/100")
235
+
236
+ if __name__ == "__main__":
237
+ run_tests()
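For reference, the roll-up math at the end of `run_tests()` can be exercised standalone; the 16-check total and 7-check threshold below mirror the constants hard-coded in the script above:

```python
def summarize(checks, total=16, min_pass=7):
    """Roll up QA check results the way run_tests() prints its summary.

    `total` and `min_pass` mirror the thresholds hard-coded in the script.
    """
    passed = sum(1 for c in checks if c["passed"])
    return {
        "passed": passed,
        "estimated_score": round(passed / total * 100),
        "disqualification_risk": passed < min_pass,
    }

results = [{"passed": True}] * 12 + [{"passed": False}] * 4
print(summarize(results))
# {'passed': 12, 'estimated_score': 75, 'disqualification_risk': False}
```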
server/app.py CHANGED
@@ -1,19 +1,16 @@
1
  import os
2
  import uvicorn
 
3
  from fastapi import FastAPI, HTTPException, Query
4
  from fastapi.middleware.cors import CORSMiddleware
5
- from fastapi.staticfiles import StaticFiles
6
- from fastapi.responses import FileResponse
7
 
8
- from .models import CodeReviewAction, CodeReviewState, StepResponse, ResetResponse
9
- from .environment import CodeReviewEnvironment
 
10
 
11
  app = FastAPI(
12
  title="Code Security Review β€” OpenEnv",
13
- description=(
14
- "RL environment for training AI agents to detect bugs and security "
15
- "vulnerabilities in code. Compatible with the OpenEnv spec."
16
- ),
17
  version="1.0.0",
18
  )
19
 
@@ -24,46 +21,61 @@ app.add_middleware(
24
  allow_headers=["*"],
25
  )
26
 
27
- app.mount("/static", StaticFiles(directory="static"), name="static")
28
 
29
- env = CodeReviewEnvironment()
30
 
31
  @app.get("/")
32
- def read_index():
33
- return FileResponse("static/index.html")
34
-
35
-
36
- @app.get("/health")
37
  def health():
38
- return {"status": "ok", "env": "code-review-env", "version": "1.0.0"}
39
 
40
 
41
  @app.post("/reset", response_model=ResetResponse)
42
- def reset(difficulty: str = Query(default="easy", description="easy | medium | hard")):
 
 
 
43
  """Reset the environment and return the first observation."""
44
- obs = env.reset(difficulty=difficulty)
 
 
45
  return ResetResponse(observation=obs)
46
 
47
 
48
- @app.post("/step", response_model=StepResponse)
49
  def step(action: CodeReviewAction):
50
  """Submit a code review action and receive a reward signal."""
51
- try:
52
- obs, reward, done, info = env.step(action)
53
- return StepResponse(observation=obs, reward=reward, done=done, info=info)
54
- except ValueError as exc:
55
- raise HTTPException(status_code=400, detail=str(exc))
56
 
57
 
58
- @app.get("/state", response_model=CodeReviewState)
59
  def state():
60
  """Return the current environment state."""
61
  return env.state()
62
 
63
 
64
  if __name__ == "__main__":
65
- port = int(os.environ.get("PORT", 7860))
66
- enable_web = os.environ.get("ENABLE_WEB_INTERFACE", "false").lower() == "true"
67
  uvicorn.run(
68
  "server.app:app",
69
  host="0.0.0.0",
 
1
  import os
2
  import uvicorn
3
+ from typing import List, Optional
4
  from fastapi import FastAPI, HTTPException, Query
5
  from fastapi.middleware.cors import CORSMiddleware
 
 
6
 
7
+ from server.models import CodeReviewAction, StepResult, ResetResponse, StateResponse, TaskInfo
8
+ from server.tasks import TASKS
9
+ from server.environment import CodeSecurityEnv
10
 
11
  app = FastAPI(
12
  title="Code Security Review β€” OpenEnv",
13
+ description="An RL environment for training AI agents to perform code security review.",
 
 
 
14
  version="1.0.0",
15
  )
16
 
 
21
  allow_headers=["*"],
22
  )
23
 
24
+ env = CodeSecurityEnv()
25
 
 
26
 
27
  @app.get("/")
 
 
 
 
 
28
  def health():
29
+ """Health check endpoint."""
30
+ return {
31
+ "status": "ok",
32
+ "project": "Code Security Review - OpenEnv",
33
+ "version": "1.0.0",
34
+ "organization": "Inmodel Labs",
35
+ }
36
+
37
+
38
+ @app.get("/tasks", response_model=List[TaskInfo])
39
+ def list_tasks():
40
+ """List all available tasks."""
41
+ return [
42
+ TaskInfo(
43
+ id=t["id"],
44
+ language=t["language"],
45
+ bug_class=t["bug_class"],
46
+ difficulty=t["difficulty"],
47
+ )
48
+ for t in TASKS.values()
49
+ ]
50
 
51
 
52
  @app.post("/reset", response_model=ResetResponse)
53
+ def reset(
54
+ task_id: str = Query(default="python-off-by-one", description="Task ID to reset to"),
55
+ seed: Optional[int] = Query(default=None, description="Optional seed for reproducibility")
56
+ ):
57
  """Reset the environment and return the first observation."""
58
+ if task_id not in TASKS:
59
+ raise HTTPException(status_code=404, detail=f"Task '{task_id}' not found.")
60
+ obs = env.reset(task_id=task_id, seed=seed)
61
  return ResetResponse(observation=obs)
62
 
63
 
64
+ @app.post("/step", response_model=StepResult)
65
  def step(action: CodeReviewAction):
66
  """Submit a code review action and receive a reward signal."""
67
+ result = env.step(action)
68
+ return result
 
 
 
69
 
70
 
71
+ @app.get("/state", response_model=StateResponse)
72
  def state():
73
  """Return the current environment state."""
74
  return env.state()
75
 
76
 
77
  if __name__ == "__main__":
78
+ port = int(os.environ.get("PORT", 8000))
 
79
  uvicorn.run(
80
  "server.app:app",
81
  host="0.0.0.0",
server/environment.py CHANGED
@@ -1,447 +1,84 @@
1
- from typing import Dict, Any, Tuple, Optional
2
- from .models import CodeReviewAction, CodeReviewObservation, CodeReviewState
3
 
4
- MAX_STEPS = 3
 
 
5
 
6
- # TASK DEFINITIONS
7
-
8
-
9
- TASKS: Dict[str, dict] = {
10
-
11
- # EASY
12
- "easy": {
13
- "id": "task_easy_001",
14
- "difficulty": "easy",
15
- "language": "python",
16
- "description": (
17
- "This function is supposed to sum all elements in a list. "
18
- "Find any bugs and suggest a fix."
19
- ),
20
- "code": (
21
- "def sum_elements(arr):\n"
22
- ' """Return the sum of all elements."""\n'
23
- " total = 0\n"
24
- " for i in range(1, len(arr) + 1): # iterates over indices\n"
25
- " total += arr[i]\n"
26
- " return total"
27
- ),
28
- "ground_truth": {
29
- "bug_identified": True,
30
- "bug_type_keywords": [
31
- "off-by-one", "off by one", "index error", "indexerror",
32
- "out of bounds", "out of range", "index out",
33
- ],
34
- "location_keywords": [
35
- "range(1, len(arr) + 1)", "len(arr) + 1", "len(arr)+1",
36
- "range", "loop", "index", "arr[i]",
37
- ],
38
- "description_keywords": [
39
- "index", "range", "len", "off-by-one", "off by one",
40
- "IndexError", "out of bounds", "+1", "exceed", "arr[i]",
41
- "zero", "start",
42
- ],
43
- "fix_keywords": [
44
- "range(len(arr))", "range(0, len(arr))",
45
- "for i in range(len", "for element in arr",
46
- "arr[i]" , "len(arr))",
47
- ],
48
- "severity_valid": ["high", "medium"],
49
- },
50
- },
51
-
52
- #MEDIUM
53
- "medium": {
54
- "id": "task_medium_001",
55
- "difficulty": "medium",
56
- "language": "python",
57
- "description": (
58
- "This authentication function controls admin access. "
59
- "Find the logical security bug."
60
- ),
61
- "code": (
62
- "def authenticate_user(username, password, request_admin=False):\n"
63
- ' """Authenticate user and return access level."""\n'
64
- " user = db.find_user(username)\n"
65
- " if not user or user.password_hash != hash_password(password):\n"
66
- ' return {"authenticated": False, "level": "none"}\n'
67
- "\n"
68
- " # Elevate to admin if caller requests it OR user has admin role\n"
69
- " if request_admin or user.role == 'admin': # <-- review this\n"
70
- ' return {"authenticated": True, "level": "admin"}\n'
71
- "\n"
72
- ' return {"authenticated": True, "level": "user"}'
73
- ),
74
- "ground_truth": {
75
- "bug_identified": True,
76
- "bug_type_keywords": [
77
- "logic", "logic error", "logical", "privilege escalation",
78
- "authorization", "authentication bypass", "access control",
79
- ],
80
- "location_keywords": [
81
- "request_admin or", "or user.role", "or", "condition",
82
- "if request_admin", "or user.role == 'admin'",
83
- ],
84
- "description_keywords": [
85
- "or", "and", "privilege", "escalation", "bypass", "admin",
86
- "role", "caller", "request_admin", "logic", "elevation",
87
- "any caller", "arbitrary",
88
- ],
89
- "fix_keywords": [
90
- "and", "request_admin and user.role", "and user.role == 'admin'",
91
- "and user.role", "both",
92
- ],
93
- "severity_valid": ["critical", "high"],
94
- },
95
- },
96
-
97
- # ── HARD ──────────────────────────────────
98
- "hard": {
99
- "id": "task_hard_001",
100
- "difficulty": "hard",
101
- "language": "python",
102
- "description": (
103
- "This function fetches records from a database using user-supplied input. "
104
- "Identify the security vulnerability."
105
- ),
106
- "code": (
107
- "def fetch_records(user_id: str, sort_column: str):\n"
108
- ' """Fetch user records sorted by a given column."""\n'
109
- " conn = get_db_connection()\n"
110
- " cursor = conn.cursor()\n"
111
- "\n"
112
- " query = (\n"
113
- ' f"SELECT id, name, email FROM users "\n'
114
- ' f"WHERE user_id = {user_id} "\n'
115
- ' f"ORDER BY {sort_column}"\n'
116
- " )\n"
117
- " cursor.execute(query)\n"
118
- " rows = cursor.fetchall()\n"
119
- " conn.close()\n"
120
- " return rows"
121
- ),
122
- "ground_truth": {
123
- "bug_identified": True,
124
- "bug_type_keywords": [
125
- "sql injection", "injection", "sqli", "sql",
126
- "security vulnerability", "security", "second-order",
127
- ],
128
- "location_keywords": [
129
- "f\"", "f-string", "format", "user_id", "sort_column",
130
- "query", "ORDER BY", "WHERE user_id",
131
- ],
132
- "description_keywords": [
133
- "sql injection", "injection", "parameterized", "f-string",
134
- "format string", "user input", "sanitize", "escape",
135
- "malicious", "attack", "tautology", "union", "drop",
136
- "ORDER BY", "sort_column", "arbitrary",
137
- ],
138
- "fix_keywords": [
139
- "parameterized", "?", "%s", "cursor.execute(query, (",
140
- "cursor.execute(query, [", "prepared statement",
141
- "whitelist", "allowlist", "ALLOWED_COLUMNS",
142
- "sanitize", "if sort_column not in",
143
- ],
144
- "severity_valid": ["critical"],
145
- },
146
- },
147
-
148
- # ── EXPERT ────────────────────────────────
149
- "expert": {
150
- "id": "task_expert_001",
151
- "difficulty": "expert",
152
- "language": "java",
153
- "description": (
154
- "This Java class implements a token bucket rate limiter. "
155
- "Identify the logic bug that could allow users to bypass the rate limit."
156
- ),
157
- "code": (
158
- "import java.util.concurrent.atomic.AtomicLong;\n\n"
159
- "public class TokenBucketRateLimiter {\n"
160
- " private final long maxTokens;\n"
161
- " private final long refillRatePerSecond;\n"
162
- " private AtomicLong currentTokens;\n"
163
- " private AtomicLong lastRefillTimestamp;\n\n"
164
- " public TokenBucketRateLimiter(long maxTokens, long refillRatePerSecond) {\n"
165
- " this.maxTokens = maxTokens;\n"
166
- " this.refillRatePerSecond = refillRatePerSecond;\n"
167
- " this.currentTokens = new AtomicLong(maxTokens);\n"
168
- " this.lastRefillTimestamp = new AtomicLong(System.currentTimeMillis());\n"
169
- " }\n\n"
170
- " /**\n"
171
- " * Checks if the requested number of tokens is available.\n"
172
- " * Decrements the bucket if allowed.\n"
173
- " */\n"
174
- " public synchronized boolean allowRequest(int tokensNeeded) {\n"
175
- " refill();\n"
176
- " if (currentTokens.get() >= tokensNeeded) {\n"
177
- " currentTokens.addAndGet(-tokensNeeded);\n"
178
- " return true;\n"
179
- " }\n"
180
- " return false;\n"
181
- " }\n\n"
182
- " private void refill() {\n"
183
- " long now = System.currentTimeMillis();\n"
184
- " long timeElapsedMs = now - lastRefillTimestamp.get();\n"
185
- " \n"
186
- " // Calculate how many tokens to add based on time elapsed\n"
187
- " long tokensToAdd = (timeElapsedMs / 1000) * refillRatePerSecond;\n\n"
188
- " if (tokensToAdd > 0) {\n"
189
- " // Hint: Look closely at how the tokens are updated here.\n"
190
- " // Consider what happens if a user stops making requests for a long time.\n"
191
- " currentTokens.addAndGet(tokensToAdd);\n"
192
- " lastRefillTimestamp.set(now);\n"
193
- " }\n"
194
- " }\n"
195
- "}"
196
- ),
197
- "ground_truth": {
198
- "bug_identified": True,
199
- "bug_type_keywords": [
200
- "logic", "limit", "overflow", "cap", "bound", "maximum", "exceed",
201
- "logic error", "capacity",
202
- ],
203
- "location_keywords": [
204
- "currentTokens.addAndGet", "refill()", "tokensToAdd",
205
- "currentTokens.get()", "addAndGet(tokensToAdd)",
206
- ],
207
- "description_keywords": [
208
- "exceed", "maxTokens", "cap", "limit", "bound",
209
- "overflow", "infinite", "burst", "accumulate",
210
- ],
211
- "fix_keywords": [
212
- "Math.min", "min(", "set(", "if (currentTokens.get() > maxTokens)",
213
- "compareAndSet", "cap",
214
- ],
215
- "severity_valid": ["high", "medium"],
216
- },
217
- },
218
-
219
- # ── EXPERT 2 (C++) ────────────────────────
220
- "expert2": {
221
- "id": "task_expert_002",
222
- "difficulty": "expert2",
223
- "language": "cpp",
224
- "description": (
225
- "This C++ class implements an event dispatcher. "
226
- "Identify the concurrency bug that can occur when an event is dispatched."
227
- ),
228
- "code": (
229
- "#include <iostream>\n"
230
- "#include <vector>\n"
231
- "#include <functional>\n"
232
- "#include <mutex>\n"
233
- "#include <algorithm>\n"
234
- "#include <string>\n\n"
235
- "class EventDispatcher {\n"
236
- "public:\n"
237
- " using Callback = std::function<void(const std::string&)>;\n\n"
238
- " void subscribe(int listener_id, Callback cb) {\n"
239
- " std::lock_guard<std::mutex> lock(mut_);\n"
240
- " listeners_.push_back({listener_id, cb});\n"
241
- " }\n\n"
242
- " void unsubscribe(int listener_id) {\n"
243
- " std::lock_guard<std::mutex> lock(mut_);\n"
244
- " listeners_.erase(\n"
245
- " std::remove_if(listeners_.begin(), listeners_.end(),\n"
246
- " [listener_id](const Listener& l) { return l.id == listener_id; }),\n"
247
- " listeners_.end()\n"
248
- " );\n"
249
- " }\n\n"
250
- " void dispatch(const std::string& event_data) {\n"
251
- " std::lock_guard<std::mutex> lock(mut_);\n"
252
- " for (const auto& listener : listeners_) {\n"
253
- " // Hint: What happens if a listener decides to call unsubscribe() \n"
254
- " // from inside their own callback function when an event fires?\n"
255
- " listener.cb(event_data);\n"
256
- " }\n"
257
- " }\n\n"
258
- "private:\n"
259
- " struct Listener {\n"
260
- " int id;\n"
261
- " Callback cb;\n"
262
- " };\n \n"
263
- " std::vector<Listener> listeners_;\n"
264
- " std::mutex mut_;\n"
265
- "};"
266
- ),
267
- "ground_truth": {
268
- "bug_identified": True,
269
- "bug_type_keywords": [
270
- "deadlock", "concurrency", "lock", "recursive", "reentrant", "hang",
271
- "iterator validation", "undefined behavior"
272
- ],
273
- "location_keywords": [
274
- "listener.cb", "unsubscribe", "dispatch", "mut_", "std::lock_guard",
275
- "lock(mut_)"
276
- ],
277
- "description_keywords": [
278
- "deadlock", "already locked", "same thread", "recursive_mutex",
279
- "reentrant", "hangs", "blocks", "invalidate", "iterator"
280
- ],
281
- "fix_keywords": [
282
- "std::recursive_mutex", "copy", "local copy", "copy the vector",
283
- "unlock before", "queue", "deferred"
284
- ],
285
- "severity_valid": ["high", "critical"],
286
- },
287
- },
288
- }
289
-
290
-
291
-
292
- # GRADER
293
-
294
-
295
- def grade_action(action: CodeReviewAction, task: dict) -> Tuple[float, Dict]:
296
- """
297
- Score the agent's review on a 0.0–1.0 scale.
298
-
299
- Breakdown:
300
- bug_identified 0.20
301
- bug_type 0.20
302
- bug_location 0.10
303
- bug_description 0.25 (keyword density, capped)
304
- suggested_fix 0.15 (keyword density, capped)
305
- severity 0.10
306
- ─────────────────────
307
- Total 1.00
308
- """
309
- gt = task["ground_truth"]
310
- score = 0.0
311
- breakdown: Dict[str, float] = {}
312
-
313
- # 1. Bug identification
314
- if action.bug_identified == gt["bug_identified"]:
315
- score += 0.20
316
- breakdown["bug_identified"] = 0.20
317
- else:
318
- breakdown["bug_identified"] = 0.00
319
- if not action.bug_identified:
320
- return 0.0, {
321
- "breakdown": breakdown,
322
- "total_score": 0.0,
323
- "feedback": "No bug identified β€” one definitely exists. Look more carefully.",
324
- }
325
-
326
- # 2. Bug type
327
- bug_type_lower = action.bug_type.lower()
328
- type_match = any(kw in bug_type_lower for kw in gt["bug_type_keywords"])
329
- if type_match:
330
- score += 0.20
331
- breakdown["bug_type"] = 0.20
332
- else:
333
- breakdown["bug_type"] = 0.00
334
-
335
- # 3. Bug location
336
- loc_lower = action.bug_location.lower()
337
- loc_match = any(kw.lower() in loc_lower for kw in gt["location_keywords"])
338
- if loc_match:
339
- score += 0.10
340
- breakdown["bug_location"] = 0.10
341
- else:
342
- breakdown["bug_location"] = 0.00
343
-
344
- # 4. Description quality (keyword density, capped at 0.25)
345
- desc_lower = action.bug_description.lower()
346
- desc_hits = sum(1 for kw in gt["description_keywords"] if kw.lower() in desc_lower)
347
- desc_score = round(min(0.25, desc_hits * 0.07), 3)
348
- score += desc_score
349
- breakdown["bug_description"] = desc_score
350
-
351
- # 5. Fix quality (keyword density, capped at 0.15)
352
- fix_lower = action.suggested_fix.lower()
353
- fix_hits = sum(1 for kw in gt["fix_keywords"] if kw.lower() in fix_lower)
354
- fix_score = round(min(0.15, fix_hits * 0.08), 3)
355
- score += fix_score
356
- breakdown["suggested_fix"] = fix_score
357
-
358
- # 6. Severity
359
- if action.severity.lower() in gt["severity_valid"]:
360
- score += 0.10
361
- breakdown["severity"] = 0.10
362
- else:
363
- breakdown["severity"] = 0.00
364
-
365
- total = round(min(1.0, score), 3)
366
-
367
- # Build human-readable feedback
368
- hints = []
369
- if breakdown["bug_type"] == 0:
370
- hints.append("Reconsider the bug category β€” be more specific.")
371
- if breakdown["bug_location"] == 0:
372
- hints.append("Pinpoint the exact line or expression that contains the bug.")
373
- if breakdown["suggested_fix"] < 0.08:
374
- hints.append("Your fix does not address the root cause β€” revise it.")
375
- if breakdown["severity"] == 0:
376
- hints.append("Re-evaluate the severity level.")
377
-
378
- feedback = " ".join(hints) if hints else "Strong analysis β€” refine the fix if needed."
379
-
380
- return total, {"breakdown": breakdown, "total_score": total, "feedback": feedback}
381
-
382
- # ENVIRONMENT
383
-
384
-
385
- class CodeReviewEnvironment:
386
  def __init__(self):
387
- self._state: Optional[CodeReviewState] = None
388
- self._current_task: Optional[dict] = None
389
 
390
- def reset(self, difficulty: str = "easy") -> CodeReviewObservation:
391
- if difficulty not in TASKS:
392
- difficulty = "easy"
393
- task = TASKS[difficulty]
394
- self._current_task = task
395
- self._state = CodeReviewState(
396
- task_id=task["id"],
397
- difficulty=difficulty,
398
- step_count=0,
399
- done=False,
400
- total_reward=0.0,
401
- task_complete=False,
 
 
 
 
 
402
  )
403
- return self._build_obs(step_number=0, previous_feedback=None)
404
 
405
- def step(self, action: CodeReviewAction) -> Tuple[CodeReviewObservation, float, bool, Dict]:
406
- if self._state is None or self._state.done:
407
- raise ValueError("Call reset() before step().")
408
-
409
- self._state.step_count += 1
410
- reward, info = grade_action(action, self._current_task)
411
- self._state.total_reward = round(self._state.total_reward + reward, 3)
412
-
413
- # Done if agent nailed it or max steps reached
414
- done = reward >= 0.80 or self._state.step_count >= MAX_STEPS
415
- self._state.done = done
416
- self._state.task_complete = reward >= 0.80
417
-
418
- feedback = info.get("feedback") if not done else None
419
- obs = self._build_obs(
420
- step_number=self._state.step_count,
421
- previous_feedback=feedback,
422
  )
423
- return obs, reward, done, info
424
-
425
- def state(self) -> CodeReviewState:
426
- if self._state is None:
427
- return CodeReviewState(
428
- task_id="", difficulty="easy",
429
- step_count=0, done=False,
430
- total_reward=0.0, task_complete=False,
431
- )
432
- return self._state
433
-
434
- # helpers
435
 
436
- def _build_obs(self, step_number: int, previous_feedback: Optional[str]) -> CodeReviewObservation:
437
- t = self._current_task
438
- return CodeReviewObservation(
439
- code_snippet=t["code"],
440
- language=t["language"],
441
- task_description=t["description"],
442
  task_id=t["id"],
 
443
  difficulty=t["difficulty"],
444
- step_number=step_number,
445
- max_steps=MAX_STEPS,
446
- previous_feedback=previous_feedback,
 
447
  )
 
1
+ import random
2
+ from typing import Optional, Dict, Tuple
3
 
4
+ from server.tasks import TASKS
5
+ from server.grader import grade_action
6
+ from server.models import CodeObservation, StepResult, StateResponse, Action, Observation
7
 
8
+ class CodeSecurityEnv:
9
  def __init__(self):
10
+ self.current_task: Optional[dict] = None
11
+ self.step_count: int = 0
12
+ self.done: bool = False
13
+ self.total_reward: float = 0.0
14
+ self._task_ids = list(TASKS.keys())
15
+
16
+ def reset(self, task_id: Optional[str] = None, seed: Optional[int] = None) -> Observation:
17
+ if seed is not None:
18
+ random.seed(seed)
19
+
20
+ if task_id and task_id in TASKS:
21
+ self.current_task = TASKS[task_id]
22
+ else:
23
+ # No valid task_id given: pick a random task
24
+ chosen_id = random.choice(self._task_ids)
25
+ self.current_task = TASKS[chosen_id]
26
+
27
+ self.step_count = 0
28
+ self.done = False
29
+ self.total_reward = 0.0
30
+
31
+ return self._make_observation()
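The seeding contract in `reset()` above (re-seed first, then choose) can be illustrated in isolation; the task IDs below are only examples:

```python
import random

# Illustrative task IDs; the real environment draws from TASKS.keys().
TASK_IDS = ["python-off-by-one", "js-auth-privilege", "python-sql-injection"]

def pick_task(seed=None):
    # Mirrors reset(): re-seed the module RNG, then sample a task.
    if seed is not None:
        random.seed(seed)
    return random.choice(TASK_IDS)

assert pick_task(seed=42) == pick_task(seed=42)  # same seed, same task
```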
32
+
33
+ def step(self, action: Action) -> StepResult:
34
+ if self.current_task is None:
35
+ # Auto-reset if called before reset()
36
+ self.reset()
37
+
38
+ if self.done:
39
+ return StepResult(
40
+ observation=self._make_observation(),
41
+ reward=0.0,
42
+ done=True,
43
+ info={"error": "Episode already completed. Call /reset to start a new episode."},
44
+ )
45
 
46
+ # The action arrives as a Pydantic model, but the grader reads fields
47
+ # via dict .get(), so convert it first (model_dump() on Pydantic v2).
48
+ reward, breakdown = grade_action(action.model_dump(), self.current_task)
49
+
50
+ self.step_count += 1
51
+ self.total_reward += reward
52
+ self.done = True # single-step environment β€” one action per episode
53
+
54
+ return StepResult(
55
+ observation=self._make_observation(),
56
+ reward=reward,
57
+ done=self.done,
58
+ info={
59
+ "reward_breakdown": breakdown,
60
+ "task_name": self.current_task.get("name", "Unknown Task"),
61
+ "step_count": self.step_count
62
+ },
63
  )
 
64
 
65
+ def state(self) -> StateResponse:
66
+ current_id = self.current_task["id"] if self.current_task else ""
67
+ return StateResponse(
68
+ task_id=current_id,
69
+ step=self.step_count,
70
+ done=self.done,
71
+ total_reward=self.total_reward,
72
  )
73
 
74
+ def _make_observation(self) -> Observation:
75
+ t = self.current_task
76
+ return Observation(
 
 
 
77
  task_id=t["id"],
78
+ language=t["language"],
79
  difficulty=t["difficulty"],
80
+ code_snippet=t["code_snippet"],
81
+ context=t["context"],
82
+ pr_title=t["pr_title"],
83
+ file_path=t["file_path"],
84
  )
server/grader.py ADDED
@@ -0,0 +1,80 @@
1
+ from typing import Tuple, Dict
2
+
3
+
4
+ def grade_action(action: dict, task: dict) -> Tuple[float, Dict[str, float]]:
5
+ reward = 0.0
6
+ breakdown: Dict[str, float] = {}
7
+
8
+ # ── Component 1: Bug identified (0.20) ──────────────────────────────────
9
+ if action.get("bug_identified"):
10
+ reward += 0.20
11
+ breakdown["bug_identified"] = 0.20
12
+ else:
13
+ breakdown["bug_identified"] = 0.00
14
+ # No bug found β†’ no partial credit for anything else
15
+ return max(0.0, min(1.0, reward)), breakdown
16
+
17
+ # ── Component 2: Bug type match (0.20) ──────────────────────────────────
18
+ action_type = action.get("bug_type", "").lower().replace("-", " ").replace("_", " ")
19
+ task_type = task["bug_type"].lower().replace("-", " ").replace("_", " ")
20
+ if task_type in action_type or action_type in task_type:
21
+ reward += 0.20
22
+ breakdown["bug_type"] = 0.20
23
+ else:
24
+ breakdown["bug_type"] = 0.00
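The separator-insensitive matching in Component 2 can be sketched on its own:

```python
def types_match(action_type: str, task_type: str) -> bool:
    """Match bug types with '-', '_' and ' ' treated as equivalent separators.

    As in Component 2 above, a substring match in either direction counts.
    """
    normalize = lambda s: s.lower().replace("-", " ").replace("_", " ")
    a, t = normalize(action_type), normalize(task_type)
    return t in a or a in t

assert types_match("security-vulnerability", "security_vulnerability")
assert not types_match("off-by-one", "sql-injection")
```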
25
+
26
+ # ── Component 3: Bug location (0.10) ────────────────────────────────────
27
+ action_location = action.get("bug_location", "").lower()
28
+ location_keywords = [w for w in task["bug_location"].lower().split() if len(w) > 3]
29
+ if location_keywords:
30
+ matched = sum(1 for kw in location_keywords if kw in action_location)
31
+ loc_score = round(0.10 * (matched / len(location_keywords)), 4)
32
+ else:
33
+ loc_score = 0.0
34
+ reward += loc_score
35
+ breakdown["bug_location"] = loc_score
36
+
37
+ # ── Component 4: Description quality (0.25) ──────────────────────────────
38
+ description = action.get("bug_description", "").lower()
39
+ desc_score = 0.0
40
+ if len(description) >= 20:
41
+ task_keywords = task["keywords"]
42
+ matched_kw = [kw for kw in task_keywords if kw in description]
43
+ desc_score = round(min(0.25, 0.25 * (len(matched_kw) / max(len(task_keywords), 1))), 4)
44
+ breakdown["description_quality"] = desc_score
45
+ reward += desc_score
46
+
47
+ # ── Component 5: Fix quality (0.15) ──────────────────────────────────────
48
+ fix = action.get("suggested_fix", "").lower()
49
+ fix_score = 0.0
50
+ if len(fix) >= 10:
51
+ fix_patterns = task["fix_patterns"]
52
+ matched_fix = [p for p in fix_patterns if p.lower() in fix]
53
+ fix_score = round(min(0.15, 0.15 * (len(matched_fix) / max(len(fix_patterns), 1)) * 2), 4)
54
+ breakdown["fix_quality"] = fix_score
55
+ reward += fix_score
56
+
57
+ # ── Component 6: Severity (0.10) ─────────────────────────────────────────
58
+ action_sev = action.get("severity", "").lower()
59
+ task_sev = task["severity"].lower()
60
+ if action_sev == task_sev:
61
+ sev_score = 0.10
62
+ elif action_sev in ("high", "critical") and task_sev in ("high", "critical"):
63
+ sev_score = 0.05
64
+ else:
65
+ sev_score = 0.00
66
+ breakdown["severity"] = sev_score
67
+ reward += sev_score
68
+
69
+ # ── Global Penalty: Keyword Stuffing ────────────────────────────────────
70
+ description = action.get("bug_description", "").lower()
71
+ words = description.split()
72
+ unique_ratio = len(set(words)) / len(words) if words else 1.0
73
+ if unique_ratio < 0.7:
74
+ reward *= 0.2 # Heavy global penalty
75
+ breakdown["stuffing_penalty_multiplier"] = 0.2
76
+ for k in list(breakdown.keys()):
77
+ if k != "stuffing_penalty_multiplier":
78
+ breakdown[k] = round(breakdown[k] * 0.2, 4)
79
+
80
+ return max(0.0, min(1.0, round(reward, 4))), breakdown
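
The keyword-stuffing penalty above can be sketched in isolation (a standalone snippet, not importing `server.grader`): if fewer than 70% of the description's words are unique, the final reward is multiplied by 0.2.

```python
# Standalone sketch of grade_action's keyword-stuffing check:
# low word-uniqueness in the description crushes the reward.
def stuffing_multiplier(description: str) -> float:
    words = description.lower().split()
    unique_ratio = len(set(words)) / len(words) if words else 1.0
    return 0.2 if unique_ratio < 0.7 else 1.0

# A varied description passes untouched; a repetitive one is penalized.
print(stuffing_multiplier("the range iterates one element past the end of the list"))  # → 1.0
print(stuffing_multiplier("bug bug bug bug fix fix fix fix fix fix"))  # → 0.2
```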
server/models.py CHANGED
@@ -2,44 +2,63 @@ from pydantic import BaseModel, Field
 from typing import Optional, Any, Dict
 
 
+# ── Agent Action ──────────────────────────────────────────────────────────────
+
 class CodeReviewAction(BaseModel):
     """Action taken by the agent: a structured code review."""
     bug_identified: bool = Field(..., description="Whether a bug was found")
     bug_location: str = Field(..., description="Location of the bug (function, line, variable)")
-    bug_type: str = Field(..., description="Type: off-by-one | logic-error | security-vulnerability | null-dereference | none")
+    bug_type: str = Field(..., description="Type: off-by-one | logic-error | security-vulnerability | none")
     bug_description: str = Field(..., description="Detailed explanation of why this is a bug")
     severity: str = Field(..., description="Severity: none | low | medium | high | critical")
     suggested_fix: str = Field(..., description="The corrected code or a description of how to fix it")
 
 
-class CodeReviewObservation(BaseModel):
+# ── Observation ───────────────────────────────────────────────────────────────
+
+class CodeObservation(BaseModel):
     """What the agent sees at each step."""
-    code_snippet: str = Field(..., description="The code to review")
-    language: str = Field(..., description="Programming language")
-    task_description: str = Field(..., description="What the code is supposed to do")
     task_id: str = Field(..., description="Unique task identifier")
+    language: str = Field(..., description="Programming language")
     difficulty: str = Field(..., description="Level: easy | medium | hard")
-    step_number: int = Field(..., description="Current step number within this episode")
-    max_steps: int = Field(..., description="Maximum steps allowed per episode")
-    previous_feedback: Optional[str] = Field(None, description="Feedback from previous step if any")
-
-
-class CodeReviewState(BaseModel):
-    """Internal environment state."""
-    task_id: str
-    difficulty: str
-    step_count: int
-    done: bool
-    total_reward: float
-    task_complete: bool
+    code_snippet: str = Field(..., description="The code to review")
+    context: str = Field(..., description="Production context describing what the code does")
+    pr_title: str = Field(..., description="Pull request title submitted by developer")
+    file_path: str = Field(..., description="File path of the code in the repository")
 
 
-class StepResponse(BaseModel):
-    observation: CodeReviewObservation
+# ── Step Result ───────────────────────────────────────────────────────────────
+
+class StepResult(BaseModel):
+    """Result returned from env.step()."""
+    observation: Optional[CodeObservation] = None
     reward: float
     done: bool
     info: Dict[str, Any]
 
 
+# ── State ─────────────────────────────────────────────────────────────────────
+
+class StateResponse(BaseModel):
+    """Internal environment state exposed via /state."""
+    task_id: str
+    step: int
+    done: bool
+    total_reward: float
+
+
+# ── API Helpers ───────────────────────────────────────────────────────────────
+
 class ResetResponse(BaseModel):
-    observation: CodeReviewObservation
+    observation: CodeObservation
+
+
+class TaskInfo(BaseModel):
+    id: str
+    language: str
+    bug_class: str
+    difficulty: str
+
+
+Action = CodeReviewAction
+Observation = CodeObservation
+Reward = float
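
For reference, this is the plain-dict payload shape a `CodeReviewAction` serializes to (field names taken from the models.py diff above; pydantic itself is not needed for this sketch):

```python
# Plain-dict sketch of a serialized CodeReviewAction; the example values
# are illustrative, matching the SQL-injection task's expected answer.
action_payload = {
    "bug_identified": True,
    "bug_location": "line 2: f-string interpolation in the SQL query",
    "bug_type": "security-vulnerability",
    "bug_description": "search_term is interpolated directly into the query string",
    "severity": "critical",
    "suggested_fix": "Use a parameterized query with %s placeholders",
}

expected_fields = {"bug_identified", "bug_location", "bug_type",
                   "bug_description", "severity", "suggested_fix"}
assert set(action_payload) == expected_fields
```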
server/tasks.py ADDED
@@ -0,0 +1,110 @@
+TASKS = {
+    "python-off-by-one": {
+        "id": "python-off-by-one",
+        "name": "Python Off-by-One Error",
+        "language": "Python",
+        "difficulty": "easy",
+        "bug_class": "Off-by-one index error",
+        "pr_title": "Add batch processor for financial transactions",
+        "file_path": "finance/batch_processor.py",
+        "context": "Finance batch processor that sums transaction amounts for end-of-day reconciliation",
+        "code_snippet": (
+            "def process_transactions(transactions):\n"
+            "    total = 0\n"
+            "    for i in range(len(transactions) + 1):  # iterates one past end\n"
+            "        total += transactions[i][\"amount\"]\n"
+            "    return total"
+        ),
+        "bug_type": "off-by-one",
+        "bug_location": "line 3 — range(len(transactions) + 1)",
+        "severity": "critical",
+        "keywords": [
+            "off-by-one", "index", "range", "indexerror", "out of bounds",
+            "boundary", "overflow", "iteration", "list length", "plus one",
+            "extra step", "fencepost error", "array access", "iterator",
+            "fix", "bug", "identify", "code", "crash", "out-of-range",
+            "python", "finance", "batch", "amount", "total", "transactions",
+            "iterate", "sum", "loop", "account", "process"
+        ],
+        "fix_patterns": [
+            "range(len(transactions))",
+            "len(transactions))",
+            "for transaction in transactions",
+            "in transactions:",
+            "pop()",
+            "enumerate(transactions)",
+            "transactions[:len(transactions)]",
+            "total += transactions[i]"
+        ],
+    },
+
+    "js-auth-privilege": {
+        "id": "js-auth-privilege",
+        "name": "JavaScript Auth Logic Flaw",
+        "language": "JavaScript",
+        "difficulty": "medium",
+        "bug_class": "Logic flaw — privilege escalation",
+        "pr_title": "Refactor auth middleware for API routes",
+        "file_path": "middleware/auth.js",
+        "context": "Node.js authentication middleware that restricts admin-only API routes",
+        "code_snippet": (
+            "function checkAdmin(req, res, next) {\n"
+            "  const user = req.user;\n"
+            "  if (user.role !== \"admin\" || user.isActive) {\n"
+            "    return next();\n"
+            "  }\n"
+            "  return res.status(403).json({ error: \"Forbidden\" });\n"
+            "}"
+        ),
+        "bug_type": "logic-error",
+        "bug_location": "line 3 — incorrect boolean operator || instead of &&",
+        "severity": "critical",
+        "keywords": [
+            "short-circuit disjunction hazard", "logical disjunction vulnerability",
+            "excessive authorization scope", "privilege escalation vector",
+            "boolean logic flaw pattern", "operator precedence violation",
+            "authorization bypass disjunction logic", "improper validation layer check",
+            "role check disjunction pattern match", "permission leak evaluation flow",
+            "evaluation shortcut logic flaw", "middleware logic hazard state",
+            "security constraint bypass", "access control logic inversion"
+        ],
+        "fix_patterns": [
+            "user.role === \"admin\" && user.isActive",
+            "&& user.isActive",
+            "throw new Error(\"Unauthorized\")",
+            "user.role === 'admin' && user.isActive",
+            "middleware logic fix"
+        ],
+    },
+
+    "python-sql-injection": {
+        "id": "python-sql-injection",
+        "name": "Python SQL Injection",
+        "language": "Python",
+        "difficulty": "hard",
+        "bug_class": "SQL injection via f-string",
+        "pr_title": "Add user search endpoint to REST API",
+        "file_path": "api/users.py",
+        "context": "REST API endpoint that searches users by name in a PostgreSQL database",
+        "code_snippet": (
+            "def search_users(db, search_term):\n"
+            "    query = f\"SELECT * FROM users WHERE name LIKE '%{search_term}%'\"\n"
+            "    results = db.execute(query)\n"
+            "    return results.fetchall()"
+        ),
+        "bug_type": "security-vulnerability",
+        "bug_location": "line 2 — f-string interpolation directly in SQL query",
+        "severity": "critical",
+        "keywords": [
+            "sql injection", "user-supplied", "search_term", "interpolated", "f-string",
+            "attacker", "bypass", "authentication", "exfiltrate", "user data",
+            "drop tables", "parameterized", "queries", "sanitize", "input", "automatically"
+        ],
+        "fix_patterns": [
+            "db.execute('SELECT * FROM users WHERE name LIKE %s', ('%'+search_term+'%',))",
+            "%s",
+            "parameterized",
+            "prepared statement"
+        ],
+    },
+}
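
Each task entry mixes agent-visible fields with a grader-only answer key. As a sketch (the field split is inferred from this diff, not imported from `server.tasks`), the observation projection looks like:

```python
# Which task fields reach the agent (observation) versus stay grader-only.
# OBS_FIELDS is an assumption inferred from the tasks.py schema above.
OBS_FIELDS = {"id", "language", "difficulty", "code_snippet",
              "context", "pr_title", "file_path"}

task = {
    "id": "python-sql-injection", "language": "Python", "difficulty": "hard",
    "code_snippet": "...", "context": "...", "pr_title": "...", "file_path": "...",
    # answer key: must never leak into the observation
    "bug_type": "security-vulnerability", "bug_location": "line 2",
    "severity": "critical", "keywords": [], "fix_patterns": [],
}

observation = {k: v for k, v in task.items() if k in OBS_FIELDS}
hidden = set(task) - OBS_FIELDS
```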
validate.sh ADDED
@@ -0,0 +1,103 @@
+#!/bin/bash
+
+# OpenEnv Submission Validation Script
+
+set -e
+echo "═══════════════════════════════════════"
+echo " OpenEnv Pre-Submission Validation"
+echo "═══════════════════════════════════════"
+echo ""
+
+# 1. Check for required root files
+echo "── 1. Required Files ──"
+FILES=("openenv.yaml" "inference.py" "README.md" "Dockerfile" "requirements.txt")
+for file in "${FILES[@]}"; do
+    if [ -f "$file" ]; then
+        echo "  ✅ $file"
+    else
+        echo "  ❌ Missing $file"
+        exit 1
+    fi
+done
+echo ""
+
+# 2. Check server/ module structure
+echo "── 2. Server Module Structure ──"
+SERVER_FILES=("server/__init__.py" "server/app.py" "server/models.py" "server/environment.py" "server/tasks.py" "server/grader.py")
+for file in "${SERVER_FILES[@]}"; do
+    if [ -f "$file" ]; then
+        echo "  ✅ $file"
+    else
+        echo "  ❌ Missing $file"
+        exit 1
+    fi
+done
+echo ""
+
+# 3. Activate venv (if present) & validate Python imports
+echo "── 3. Python Import Validation ──"
+if [ -d venv ]; then
+    source venv/bin/activate
+fi
+python3 -c "
+from server.tasks import TASKS
+from server.grader import grade_action
+from server.environment import CodeSecurityEnv
+from server.models import CodeReviewAction, CodeObservation, StepResult, StateResponse, ResetResponse, TaskInfo
+
+assert len(TASKS) >= 3, f'Expected 3+ tasks, got {len(TASKS)}'
+print('  ✅ All imports resolve correctly')
+print(f'  Tasks: {list(TASKS.keys())}')
+" || { echo "  ❌ Python import validation failed"; exit 1; }
+echo ""
+
+# 4. Quick grader smoke test
+echo "── 4. Grader Smoke Test ──"
+python3 -c "
+from server.environment import CodeSecurityEnv
+from server.models import Action
+
+env = CodeSecurityEnv()
+obs = env.reset('python-off-by-one')
+result = env.step(Action(**{
+    'bug_identified': True,
+    'bug_location': 'range(len(transactions) + 1)',
+    'bug_type': 'logic-error',
+    'bug_description': 'Off-by-one index error — the range goes one past the end causing an out of bounds IndexError',
+    'severity': 'medium',
+    'suggested_fix': 'Use range(len(transactions)) to fix the boundary',
+}))
+assert 0.0 <= result.reward <= 1.0, f'Reward out of range: {result.reward}'
+assert result.done is True
+print(f'  ✅ Grader returned reward={result.reward:.4f}, done={result.done}')
+
+# Verify zero-reward path
+env2 = CodeSecurityEnv()
+env2.reset('python-off-by-one')
+r2 = env2.step(Action(**{
+    'bug_identified': False,
+    'bug_location': '',
+    'bug_type': 'none',
+    'bug_description': 'No bug found',
+    'severity': 'none',
+    'suggested_fix': '',
+}))
+assert r2.reward == 0.0, f'Expected 0.0 for no-bug, got {r2.reward}'
+print('  ✅ No-bug path returns reward=0.0')
+" || { echo "  ❌ Grader smoke test failed"; exit 1; }
+echo ""
+
+# 5. Validate openenv.yaml
+echo "── 5. openenv.yaml Validation ──"
+python3 -c "
+import yaml
+with open('openenv.yaml', 'r') as f:
+    data = yaml.safe_load(f)
+assert 'name' in data, 'Missing name field'
+assert 'tasks' in data, 'Missing tasks field'
+assert len(data['tasks']) >= 3, f'Need 3+ tasks, got {len(data[\"tasks\"])}'
+print(f'  ✅ Valid YAML with {len(data[\"tasks\"])} tasks')
+" || { echo "  ❌ openenv.yaml validation failed"; exit 1; }
+echo ""
+
+echo "═══════════════════════════════════════"
+echo " ✅ All checks passed!"
+echo "═══════════════════════════════════════"