kanishcr7 committed on
Commit
d6abea2
·
1 Parent(s): 93d7cd0

Final check:Passed

Browse files
.gitignore CHANGED
@@ -2,6 +2,7 @@
2
  __pycache__/
3
  *.py[codz]
4
  *$py.class
 
5
 
6
  # C extensions
7
  *.so
 
2
  __pycache__/
3
  *.py[codz]
4
  *$py.class
5
+ wandb/
6
 
7
  # C extensions
8
  *.so
inference.py CHANGED
@@ -2,7 +2,7 @@
2
  """
3
  PatchHawk inference script — runs the LLM agent loop against the
4
  OpenEnv-compliant PatchHawkEnv.
5
-
6
  Environment variables:
7
  API_BASE_URL – OpenAI-compatible API endpoint (required unless DRY_RUN=1)
8
  MODEL_NAME – Model identifier (default: meta-llama/Llama-3.2-3B-Instruct)
@@ -29,11 +29,17 @@ from patchhawk.env_models import PatchHawkAction, PatchHawkObservation, PatchHaw
29
  from patchhawk import tasks as graders
30
 
31
  # ── Configuration ────────────────────────────────────────────────────
 
 
 
 
 
32
 
33
  API_BASE_URL = os.getenv(
34
  "API_BASE_URL", "https://router.huggingface.co/hf-inference/v1"
35
  )
36
- MODEL_NAME = os.getenv("MODEL_NAME", "Qwen/Qwen2.5-Coder-32B-Instruct")
 
37
  HF_TOKEN = os.getenv("HF_TOKEN", "")
38
  DRY_RUN = os.getenv("DRY_RUN", "0") == "1"
39
  SINGLE_TASK = os.getenv("TASK", "")
@@ -59,19 +65,53 @@ TASK_DEFS = [
59
  # ── Prompt builder ───────────────────────────────────────────────────
60
 
61
  SYSTEM_PROMPT = """\
62
- You are PatchHawk, a security agent that detects supply-chain vulnerabilities
63
- in Python code. You will be given a code snippet and static analysis flags.
 
64
 
65
- Respond with a JSON object containing:
66
  {
67
- "action_type": <int>, // 0=ANALYZE, 1=EXECUTE_SANDBOX, 2=BLOCK_PR, 3=SUBMIT_PATCH, 4=REQUEST_REVIEW
68
- "patch_content": <str|null> // required if action_type == 3
 
 
69
  }
70
 
71
- Be decisive. If the code is clearly malicious, BLOCK_PR (2). If you can
72
- generate a patch that removes the vulnerability, use SUBMIT_PATCH (3).
 
 
 
 
 
 
 
 
 
 
 
 
 
73
  """
74
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
75
 
76
  def _build_user_prompt(obs: PatchHawkObservation, step: int) -> str:
77
  parts = [
@@ -89,37 +129,119 @@ def _build_user_prompt(obs: PatchHawkObservation, step: int) -> str:
89
  # ── LLM caller ───────────────────────────────────────────────────────
90
 
91
 
92
- def _call_llm(messages: list[dict]) -> str:
93
- """Call the OpenAI-compatible LLM and return the text content."""
94
- from openai import OpenAI
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
95
 
96
- client = OpenAI(
97
- base_url=API_BASE_URL,
98
- api_key=HF_TOKEN or "no-key",
 
 
99
  )
100
- response = client.chat.completions.create(
101
- model=MODEL_NAME,
102
- messages=messages,
 
 
 
103
  temperature=0.2,
104
- max_tokens=512,
105
  )
106
- return response.choices[0].message.content or ""
 
 
 
 
 
 
 
 
 
107
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
108
 
109
  def _parse_action(text: str) -> PatchHawkAction:
110
  """Parse LLM response text into a PatchHawkAction."""
111
- # Try to extract JSON from the response
112
  text = text.strip()
113
- # Handle markdown code blocks
114
  if "```json" in text:
115
  text = text.split("```json")[1].split("```")[0].strip()
116
- elif "```" in text:
117
  text = text.split("```")[1].split("```")[0].strip()
118
 
119
- data = json.loads(text)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
120
  return PatchHawkAction(
121
- action_type=int(data["action_type"]),
122
- patch_content=data.get("patch_content"),
 
 
123
  )
124
 
125
 
 
2
  """
3
  PatchHawk inference script — runs the LLM agent loop against the
4
  OpenEnv-compliant PatchHawkEnv.
5
+ a
6
  Environment variables:
7
  API_BASE_URL – OpenAI-compatible API endpoint (required unless DRY_RUN=1)
8
  MODEL_NAME – Model identifier (default: meta-llama/Llama-3.2-3B-Instruct)
 
29
  from patchhawk import tasks as graders
30
 
31
  # ── Configuration ────────────────────────────────────────────────────
32
+ try:
33
+ from dotenv import load_dotenv
34
+ load_dotenv()
35
+ except ImportError:
36
+ pass
37
 
38
  API_BASE_URL = os.getenv(
39
  "API_BASE_URL", "https://router.huggingface.co/hf-inference/v1"
40
  )
41
+ # Prefer explicit MODEL_NAME, fallback to GRPO_POLICY_MODEL from .env, then default to 32B model.
42
+ MODEL_NAME = os.getenv("MODEL_NAME", os.getenv("GRPO_POLICY_MODEL", "Qwen/Qwen2.5-Coder-32B-Instruct"))
43
  HF_TOKEN = os.getenv("HF_TOKEN", "")
44
  DRY_RUN = os.getenv("DRY_RUN", "0") == "1"
45
  SINGLE_TASK = os.getenv("TASK", "")
 
65
  # ── Prompt builder ───────────────────────────────────────────────────
66
 
67
  SYSTEM_PROMPT = """\
68
+ You are PatchHawk, a security agent that detects supply-chain vulnerabilities in Python code.
69
+
70
+ Given a code snippet and static analysis flags, you must respond **EXACTLY** with a single JSON object. No extra text, no markdown formatting.
71
 
72
+ ## Output JSON Schema
73
  {
74
+ "reasoning": "<str>", // Mandatory: Explain what the vulnerability is, why it's dangerous, and your recommended fix (if any).
75
+ "risk_score": <float>, // 0.0 (no risk) to 1.0 (critical). Be precise to two decimals.
76
+ "action_type": <int>, // One of: 0=ANALYZE, 1=EXECUTE_SANDBOX, 2=BLOCK_PR, 3=SUBMIT_PATCH, 4=REQUEST_REVIEW
77
+ "patch_content": "<str|null>" // Full patched code if action_type=3, otherwise null. Must be valid Python.
78
  }
79
 
80
+ ## Action Type Guidelines
81
+ - **0 ANALYZE** – No immediate threat, but needs deeper review.
82
+ - **1 EXECUTE_SANDBOX** – Suspicious but not obviously malicious; run in isolated environment.
83
+ - **2 BLOCK_PR** – Severely malicious, unfixable (e.g., hidden backdoor, remote shell). Reject PR.
84
+ - **3 SUBMIT_PATCH** – Vulnerability can be fixed. Provide corrected code in `patch_content`.
85
+ - **4 REQUEST_REVIEW** – Complex or ambiguous; require human expert.
86
+
87
+ ## Rules
88
+ - `reasoning` must be thorough: describe the flaw, its impact (CWE if known), and step‑by‑step how to patch.
89
+ - Escape all double quotes inside strings with backslash (`\"`).
90
+ - If the code is benign, set `risk_score` ≀ 0.2, `action_type` = 0, and `patch_content` = null.
91
+ - Never include comments or explanations outside the JSON object.
92
+
93
+ **Example valid response:**
94
+ {"reasoning": "Hardcoded password 'admin123' in __init__ allows credential bypass. Replace with env var.", "risk_score": 0.85, "action_type": 3, "patch_content": "import os\\nclass Malicious:\\n def __init__(self):\\n self.cache = []\\n self.password = os.getenv('DB_PASS')\\n ..."}
95
  """
96
 
97
+ # SYSTEM_PROMPT = """\
98
+ # You are PatchHawk, a security agent that detects supply-chain vulnerabilities
99
+ # in Python code. You will be given a code snippet and static analysis flags.
100
+
101
+ # Respond EXACTLY with a JSON object containing the following keys:
102
+ # {
103
+ # "reasoning": "<str>", // Step-by-step explanation of what the vulnerability is, why you are blocking/patching it, and how it can be fixed.
104
+ # "risk_score": <float>, // Your predicted risk score from 0.0 to 1.0 based on your analysis
105
+ # "action_type": <int>, // 0=ANALYZE, 1=EXECUTE_SANDBOX, 2=BLOCK_PR, 3=SUBMIT_PATCH, 4=REQUEST_REVIEW
106
+ # "patch_content": "<str|null>" // The full patched python code fixing the vulnerability
107
+ # }
108
+
109
+ # Be decisive. First, explain your findings thoroughly in the "reasoning" field.
110
+ # If the code is malicious but you can fix the vulnerability, use SUBMIT_PATCH (3) and provide the safe, corrected code in "patch_content".
111
+ # If the code is severely malicious and completely unfixable, use BLOCK_PR (2).
112
+ # IMPORTANT: Ensure your output is perfectly VALID JSON. Escape all double quotes inside strings properly.
113
+ # """
114
+
115
 
116
  def _build_user_prompt(obs: PatchHawkObservation, step: int) -> str:
117
  parts = [
 
129
  # ── LLM caller ───────────────────────────────────────────────────────
130
 
131
 
132
+ _local_pipeline = None
133
+
134
def _call_llm_local(messages: list[dict]) -> str:
    """Generate a reply with a locally loaded HuggingFace model.

    This is the fallback used when the remote API call fails. The
    ``transformers`` text-generation pipeline is created on the first call
    and kept in the module-level ``_local_pipeline`` so that later calls
    reuse the already-loaded model.

    Args:
        messages: Chat messages to render through the tokenizer's chat
            template (the same list that is sent to the remote API).

    Returns:
        The generated text with the rendered prompt removed from the front
        (when present) and surrounding whitespace stripped.
    """
    global _local_pipeline

    if _local_pipeline is None:
        # Heavy dependencies are imported lazily: they are only needed when
        # the fallback path is actually taken.
        import torch
        from transformers import pipeline

        # The model id comes from GRPO_POLICY_MODEL when it is set.
        model_id = os.getenv("GRPO_POLICY_MODEL", "unsloth/Qwen2.5-Coder-3B-Instruct")
        print(
            f"\n[Fallback] Loading local model: {model_id} into memory. This may take a moment...",
            flush=True,
        )
        _local_pipeline = pipeline(
            "text-generation",
            model=model_id,
            # bfloat16 weights to reduce memory use (the original note
            # targets a 12 GB GPU).
            model_kwargs={"torch_dtype": torch.bfloat16},
            device_map="auto",
        )
        print("[Fallback] Local model loaded successfully.\n", flush=True)

    generator = _local_pipeline

    # Render the message list into a single prompt string that ends with
    # the assistant generation marker.
    chat_prompt = generator.tokenizer.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )

    result = generator(
        chat_prompt, max_new_tokens=2048, do_sample=True, temperature=0.2
    )
    completion = result[0]["generated_text"]

    print(f"\ngenerated:{completion}\n")

    # The pipeline output may echo the prompt; keep only the continuation.
    if completion.startswith(chat_prompt):
        completion = completion[len(chat_prompt):]
    return completion.strip()
176
+
177
 
178
def _call_llm(messages: list[dict]) -> str:
    """Ask the OpenAI-compatible endpoint for a completion.

    The request goes to ``API_BASE_URL`` using ``MODEL_NAME`` and
    ``HF_TOKEN`` (or the placeholder key ``"no-key"`` when the token is
    empty). If anything in the remote call raises, the error is printed
    and the local model fallback is used instead.

    Args:
        messages: Chat messages to send to the model.

    Returns:
        The text content of the first choice, or ``""`` when the content is
        empty; on remote failure, whatever ``_call_llm_local`` returns.
    """
    from openai import OpenAI

    try:
        remote = OpenAI(base_url=API_BASE_URL, api_key=HF_TOKEN or "no-key")
        completion = remote.chat.completions.create(
            model=MODEL_NAME,
            messages=messages,
            temperature=0.2,
            max_tokens=512,
        )
        content = completion.choices[0].message.content
        return content if content else ""
    except Exception as e:
        # Deliberately broad: any remote failure switches to the local model.
        print(f"[LLM ERROR] Remote API failed: {e}. Initiating local Fallback...", flush=True)
        return _call_llm_local(messages)
197
+
198
+
199
+ import re
200
 
201
  def _parse_action(text: str) -> PatchHawkAction:
202
  """Parse LLM response text into a PatchHawkAction."""
 
203
  text = text.strip()
 
204
  if "```json" in text:
205
  text = text.split("```json")[1].split("```")[0].strip()
206
+ elif "```" in text and not text.startswith("{"):
207
  text = text.split("```")[1].split("```")[0].strip()
208
 
209
+ def clean_patch(p: str) -> str:
210
+ if not p: return p
211
+ if "```python" in p:
212
+ return p.split("```python")[1].split("```")[0].strip()
213
+ if "```" in p:
214
+ return p.split("```")[1].split("```")[0].strip()
215
+ return p
216
+
217
+ try:
218
+ data = json.loads(text)
219
+ except json.JSONDecodeError:
220
+ action_match = re.search(r'"action_type"\s*:\s*(\d+)', text)
221
+ action_type = int(action_match.group(1)) if action_match else 2
222
+
223
+ risk_match = re.search(r'"risk_score"\s*:\s*([\d\.]+)', text)
224
+ risk_score = float(risk_match.group(1)) if risk_match else None
225
+
226
+ patch_match = re.search(r'"patch_content"\s*:\s*"(.*)', text, re.DOTALL)
227
+ patch_content = None
228
+ if patch_match:
229
+ raw_patch = patch_match.group(1).rsplit('"', 1)[0]
230
+ raw_patch = raw_patch.replace("\\n", "\n").replace('\\"', '"').replace("\\\\", "\\")
231
+ patch_content = clean_patch(raw_patch)
232
+
233
+ return PatchHawkAction(
234
+ action_type=action_type,
235
+ reasoning="JSON Error/Truncated Output. Recovered partial data.",
236
+ predicted_risk=risk_score,
237
+ patch_content=patch_content
238
+ )
239
+
240
  return PatchHawkAction(
241
+ action_type=int(data.get("action_type", 2)),
242
+ patch_content=clean_patch(data.get("patch_content")),
243
+ reasoning=data.get("reasoning"),
244
+ predicted_risk=data.get("risk_score"),
245
  )
246
 
247
 
patchhawk/agent/environment.py CHANGED
@@ -308,8 +308,16 @@ class PatchHawkEnv(Environment[PatchHawkAction, PatchHawkObservation, PatchHawkS
308
  if self.step_counter >= self.max_steps and not done:
309
  done = True
310
  if label == "malicious":
311
- reward = -5.0
312
- reason = "max steps reached on malicious scenario"
 
 
 
 
 
 
 
 
313
 
314
  self.cumulative_reward += reward
315
 
 
308
  if self.step_counter >= self.max_steps and not done:
309
  done = True
310
  if label == "malicious":
311
+ reward -= 5.0
312
+ reason += " | max steps reached on malicious scenario"
313
+
314
+ # ── Dynamic Risk Bonus ───────────────────────────────────
315
+ predict_risk = getattr(action, "predicted_risk", None)
316
+ if predict_risk is not None:
317
+ actual_risk = 1.0 if label == "malicious" else 0.0
318
+ accuracy_bonus = (1.0 - abs(actual_risk - float(predict_risk))) * 2.0
319
+ reward += accuracy_bonus
320
+ reason += f" | AI risk accuracy bonus: +{accuracy_bonus:.2f}"
321
 
322
  self.cumulative_reward += reward
323
 
patchhawk/agent/sandbox.py CHANGED
@@ -31,7 +31,7 @@ def run_code(
31
  temp_dir = tempfile.mkdtemp(prefix="patchhawk_sandbox_")
32
  script_path = os.path.join(temp_dir, "script.py")
33
 
34
- with open(script_path, "w") as f:
35
  f.write(code)
36
 
37
  result: Dict[str, Any] = {
@@ -91,7 +91,7 @@ def check_syntax(
91
  temp_dir = tempfile.mkdtemp(prefix="patchhawk_syntax_")
92
  script_path = os.path.join(temp_dir, "script.py")
93
 
94
- with open(script_path, "w") as f:
95
  f.write(code)
96
 
97
  try:
@@ -107,7 +107,7 @@ def check_syntax(
107
  "--cpus",
108
  "0.5",
109
  "-v",
110
- f"{temp_dir}:/app:ro",
111
  "patchhawk-sandbox:latest",
112
  "python",
113
  "-m",
 
31
  temp_dir = tempfile.mkdtemp(prefix="patchhawk_sandbox_")
32
  script_path = os.path.join(temp_dir, "script.py")
33
 
34
+ with open(script_path, "w", encoding="utf-8") as f:
35
  f.write(code)
36
 
37
  result: Dict[str, Any] = {
 
91
  temp_dir = tempfile.mkdtemp(prefix="patchhawk_syntax_")
92
  script_path = os.path.join(temp_dir, "script.py")
93
 
94
+ with open(script_path, "w", encoding="utf-8") as f:
95
  f.write(code)
96
 
97
  try:
 
107
  "--cpus",
108
  "0.5",
109
  "-v",
110
+ f"{temp_dir}:/app:rw",
111
  "patchhawk-sandbox:latest",
112
  "python",
113
  "-m",
patchhawk/app/dashboard.py CHANGED
@@ -181,7 +181,10 @@ def main():
181
  final_action_type = PatchHawkEnv.ACTION_BLOCK_PR
182
  else:
183
  final_action_type = PatchHawkEnv.ACTION_REQUEST_REVIEW
184
- action = PatchHawkAction(action_type=final_action_type)
 
 
 
185
 
186
  # Visual Hacker Terminal Effect
187
  if final_action_type == PatchHawkEnv.ACTION_SUBMIT_PATCH:
@@ -219,8 +222,13 @@ def main():
219
  with st.expander("πŸ€– Agent Thought Process (LLM Trace)"):
220
  st.markdown(f"```json\n{llm_thought_process}\n```")
221
 
 
 
 
 
 
222
  m1, m2, m3 = st.columns(3)
223
- m1.metric("Risk Score", f"{risk:.2f}")
224
  m2.metric("Decision", PatchHawkEnv.ACTION_NAMES[final_action_type])
225
  m3.metric("Reward", f"{total_reward:+.2f}")
226
 
@@ -229,6 +237,10 @@ def main():
229
  )
230
 
231
  with tab1:
 
 
 
 
232
  if final_action_type == PatchHawkEnv.ACTION_BLOCK_PR:
233
  st.markdown(
234
  "<div class='info-box status-malicious'>β›” BLOCKED β€” "
@@ -253,10 +265,13 @@ def main():
253
 
254
  with tab2:
255
  telem = obs.metadata.get("telemetry")
 
256
  if telem:
257
  st.json(telem)
 
 
258
  else:
259
- st.info("No sandbox execution for this path.")
260
 
261
  with tab3:
262
  if final_action_type == PatchHawkEnv.ACTION_SUBMIT_PATCH and scenario.get(
 
181
  final_action_type = PatchHawkEnv.ACTION_BLOCK_PR
182
  else:
183
  final_action_type = PatchHawkEnv.ACTION_REQUEST_REVIEW
184
+ action = PatchHawkAction(
185
+ action_type=final_action_type,
186
+ reasoning="Static rule-based fallback decision due to high risk score."
187
+ )
188
 
189
  # Visual Hacker Terminal Effect
190
  if final_action_type == PatchHawkEnv.ACTION_SUBMIT_PATCH:
 
222
  with st.expander("πŸ€– Agent Thought Process (LLM Trace)"):
223
  st.markdown(f"```json\n{llm_thought_process}\n```")
224
 
225
+ # Opt for LLM's predicted risk score if available
226
+ display_risk = getattr(action, "predicted_risk", None)
227
+ if display_risk is None:
228
+ display_risk = risk
229
+
230
  m1, m2, m3 = st.columns(3)
231
+ m1.metric("Risk Score", f"{float(display_risk):.2f}")
232
  m2.metric("Decision", PatchHawkEnv.ACTION_NAMES[final_action_type])
233
  m3.metric("Reward", f"{total_reward:+.2f}")
234
 
 
237
  )
238
 
239
  with tab1:
240
+ if hasattr(action, "reasoning") and action.reasoning:
241
+ st.markdown("### 🧠 Agent's Reasoning")
242
+ st.info(action.reasoning)
243
+
244
  if final_action_type == PatchHawkEnv.ACTION_BLOCK_PR:
245
  st.markdown(
246
  "<div class='info-box status-malicious'>β›” BLOCKED β€” "
 
265
 
266
  with tab2:
267
  telem = obs.metadata.get("telemetry")
268
+ details = obs.metadata.get("details")
269
  if telem:
270
  st.json(telem)
271
+ elif dict(details) if details else None:
272
+ st.json(details)
273
  else:
274
+ st.info("No sandbox telemetry generated for this action.")
275
 
276
  with tab3:
277
  if final_action_type == PatchHawkEnv.ACTION_SUBMIT_PATCH and scenario.get(
patchhawk/data/generate_scenarios.py CHANGED
@@ -128,12 +128,16 @@ def auto_generate_unit_test(filename: str, code: str) -> str:
128
  # ============================================================
129
 
130
 
131
- def generate_track_b_scenarios(benign_files: list) -> list:
132
- """Generate β‰₯ 50 scenarios: 25 TP, 15 FP, 15 functional."""
133
  scenarios = []
134
 
135
- # ── True Positives (25) ──────────────────────────────────
136
- for i in range(25):
 
 
 
 
137
  bf = random.choice(benign_files)
138
  attack_name, attack_data = random.choice(list(ATTACK_TEMPLATES.items()))
139
  malicious_code = attack_data["inject"] + bf["code"]
@@ -187,7 +191,7 @@ def generate_track_b_scenarios(benign_files: list) -> list:
187
  "result = subprocess.run(['echo', 'build ok'], capture_output=True)\n\n",
188
  ),
189
  ]
190
- for i in range(15):
191
  bf = random.choice(benign_files)
192
  fp_name, fp_code = random.choice(fp_templates)
193
  suspicious_code = fp_code + bf["code"]
@@ -205,8 +209,8 @@ def generate_track_b_scenarios(benign_files: list) -> list:
205
  }
206
  )
207
 
208
- # ── Functional / Clean (15) ──────────────────────────────
209
- for i in range(15):
210
  bf = random.choice(benign_files)
211
  test_code = auto_generate_unit_test(bf["filename"], bf["code"])
212
  scenarios.append(
@@ -222,7 +226,7 @@ def generate_track_b_scenarios(benign_files: list) -> list:
222
  }
223
  )
224
 
225
- return scenarios # 55 total from Track B alone
226
 
227
 
228
  # ============================================================
@@ -486,6 +490,12 @@ def main():
486
  type=str,
487
  default="patchhawk/data/scenarios.json",
488
  )
 
 
 
 
 
 
489
  parser.add_argument(
490
  "--use-sdk",
491
  action="store_true",
@@ -535,7 +545,7 @@ def main():
535
  return
536
 
537
  # Track B (always)
538
- scenarios = generate_track_b_scenarios(benign_files)
539
 
540
  # Track A (optional)
541
  if args.use_sdk:
 
128
  # ============================================================
129
 
130
 
131
+ def generate_track_b_scenarios(benign_files: list, num_samples: int = 55) -> list:
132
+ """Generate proportional scenarios dynamically based on num_samples."""
133
  scenarios = []
134
 
135
+ tp_count = int(num_samples * 0.45)
136
+ fp_count = int(num_samples * 0.27)
137
+ fn_count = num_samples - tp_count - fp_count
138
+
139
+ # ── True Positives (45%) ──────────────────────────────────
140
+ for i in range(tp_count):
141
  bf = random.choice(benign_files)
142
  attack_name, attack_data = random.choice(list(ATTACK_TEMPLATES.items()))
143
  malicious_code = attack_data["inject"] + bf["code"]
 
191
  "result = subprocess.run(['echo', 'build ok'], capture_output=True)\n\n",
192
  ),
193
  ]
194
+ for i in range(fp_count):
195
  bf = random.choice(benign_files)
196
  fp_name, fp_code = random.choice(fp_templates)
197
  suspicious_code = fp_code + bf["code"]
 
209
  }
210
  )
211
 
212
+ # ── Functional / Clean (28%) ──────────────────────────────
213
+ for i in range(fn_count):
214
  bf = random.choice(benign_files)
215
  test_code = auto_generate_unit_test(bf["filename"], bf["code"])
216
  scenarios.append(
 
226
  }
227
  )
228
 
229
+ return scenarios
230
 
231
 
232
  # ============================================================
 
490
  type=str,
491
  default="patchhawk/data/scenarios.json",
492
  )
493
+ parser.add_argument(
494
+ "--num-samples",
495
+ type=int,
496
+ default=55,
497
+ help="Number of scenarios to generate with Track B (mutation engine).",
498
+ )
499
  parser.add_argument(
500
  "--use-sdk",
501
  action="store_true",
 
545
  return
546
 
547
  # Track B (always)
548
+ scenarios = generate_track_b_scenarios(benign_files, args.num_samples)
549
 
550
  # Track A (optional)
551
  if args.use_sdk:
patchhawk/data/scenarios.json CHANGED
The diff for this file is too large to render. See raw diff
 
patchhawk/env_models.py CHANGED
@@ -53,6 +53,12 @@ class PatchHawkAction(Action):
53
  patch_content: Optional[str] = Field(
54
  None, description="The unified context patch if action is SUBMIT_PATCH"
55
  )
 
 
 
 
 
 
56
 
57
 
58
  # ── State ────────────────────────────────────────────────────────────
 
53
  patch_content: Optional[str] = Field(
54
  None, description="The unified context patch if action is SUBMIT_PATCH"
55
  )
56
+ reasoning: Optional[str] = Field(
57
+ None, description="Explanation of the vulnerability and chosen action"
58
+ )
59
+ predicted_risk: Optional[float] = Field(
60
+ None, description="LLM predicted risk score (0.0 to 1.0)"
61
+ )
62
 
63
 
64
  # ── State ────────────────────────────────────────────────────────────
patchhawk/training/train_grpo.py CHANGED
@@ -33,6 +33,7 @@ def _build_prompt(scenario: dict) -> str:
33
  f"<code_snippet>\n{scenario['code_snippet']}\n</code_snippet>\n"
34
  "Respond in STRICT XML:\n"
35
  "<thought>...</thought>\n"
 
36
  "<action>0-4</action>\n"
37
  "<patch>...</patch> (ONLY if action=3)\n"
38
  )
@@ -90,7 +91,10 @@ def train_agent(args):
90
  else:
91
  print("No GPU found β€” training will be slow.")
92
 
93
- MODEL_NAME = "Qwen/Qwen2.5-Coder-3B-Instruct"
 
 
 
94
 
95
  # 4‑bit quantisation config
96
  bnb_config = BitsAndBytesConfig(
@@ -147,6 +151,10 @@ def train_agent(args):
147
  score += 0.5
148
  else:
149
  score -= 1.0
 
 
 
 
150
  if re.search(r"<action>[0-4]</action>", text):
151
  score += 0.5
152
  else:
@@ -194,12 +202,22 @@ def train_agent(args):
194
  patch_match = re.search(r"<patch>(.*?)</patch>", text, re.DOTALL)
195
  if patch_match:
196
  patch = patch_match.group(1).strip()
 
 
 
197
 
198
  try:
199
  # Reset environment to the exact scenario
200
- env.reset(scenario_idx=env.scenarios.index(scenario))
201
- obs = env.step(PatchHawkAction(action_type=action_type, patch_content=patch))
202
- rewards.append(float(obs.reward or 0.0))
 
 
 
 
 
 
 
203
  except Exception as exc:
204
  print(f"env_reward crash: {exc}")
205
  rewards.append(-3.0)
 
33
  f"<code_snippet>\n{scenario['code_snippet']}\n</code_snippet>\n"
34
  "Respond in STRICT XML:\n"
35
  "<thought>...</thought>\n"
36
+ "<risk_score>0.0 to 1.0</risk_score>\n"
37
  "<action>0-4</action>\n"
38
  "<patch>...</patch> (ONLY if action=3)\n"
39
  )
 
91
  else:
92
  print("No GPU found β€” training will be slow.")
93
 
94
+ from dotenv import load_dotenv
95
+ load_dotenv()
96
+
97
+ MODEL_NAME = os.getenv("GRPO_POLICY_MODEL", "Qwen/Qwen2.5-Coder-3B-Instruct")
98
 
99
  # 4‑bit quantisation config
100
  bnb_config = BitsAndBytesConfig(
 
151
  score += 0.5
152
  else:
153
  score -= 1.0
154
+ if re.search(r"<risk_score>[\d\.]+</risk_score>", text):
155
+ score += 0.5
156
+ else:
157
+ score -= 1.0
158
  if re.search(r"<action>[0-4]</action>", text):
159
  score += 0.5
160
  else:
 
202
  patch_match = re.search(r"<patch>(.*?)</patch>", text, re.DOTALL)
203
  if patch_match:
204
  patch = patch_match.group(1).strip()
205
+
206
+ risk_match = re.search(r"<risk_score>([\d\.]+)</risk_score>", text)
207
+ predicted_risk = float(risk_match.group(1)) if risk_match else None
208
 
209
  try:
210
  # Reset environment to the exact scenario
211
+ env.reset(scenario=scenario)
212
+ obs = env.step(PatchHawkAction(
213
+ action_type=action_type,
214
+ patch_content=patch,
215
+ predicted_risk=predicted_risk
216
+ ))
217
+ reward_val = float(obs.reward or 0.0)
218
+ rewards.append(reward_val)
219
+ val_msg = obs.metadata.get('validation') or ("Telemetry Extracted" if obs.metadata.get('telemetry') else "None")
220
+ print(f"[Env Reward] Action: {action_type} | Reward: {reward_val:+.2f} | Docker: {val_msg}")
221
  except Exception as exc:
222
  print(f"env_reward crash: {exc}")
223
  rewards.append(-3.0)