Spaces:
Running
Running
Final check:Passed
Browse files- .gitignore +1 -0
- inference.py +148 -26
- patchhawk/agent/environment.py +10 -2
- patchhawk/agent/sandbox.py +3 -3
- patchhawk/app/dashboard.py +18 -3
- patchhawk/data/generate_scenarios.py +19 -9
- patchhawk/data/scenarios.json +0 -0
- patchhawk/env_models.py +6 -0
- patchhawk/training/train_grpo.py +22 -4
.gitignore
CHANGED
|
@@ -2,6 +2,7 @@
|
|
| 2 |
__pycache__/
|
| 3 |
*.py[codz]
|
| 4 |
*$py.class
|
|
|
|
| 5 |
|
| 6 |
# C extensions
|
| 7 |
*.so
|
|
|
|
| 2 |
__pycache__/
|
| 3 |
*.py[codz]
|
| 4 |
*$py.class
|
| 5 |
+
wandb/
|
| 6 |
|
| 7 |
# C extensions
|
| 8 |
*.so
|
inference.py
CHANGED
|
@@ -2,7 +2,7 @@
|
|
| 2 |
"""
|
| 3 |
PatchHawk inference script β runs the LLM agent loop against the
|
| 4 |
OpenEnv-compliant PatchHawkEnv.
|
| 5 |
-
|
| 6 |
Environment variables:
|
| 7 |
API_BASE_URL β OpenAI-compatible API endpoint (required unless DRY_RUN=1)
|
| 8 |
MODEL_NAME β Model identifier (default: meta-llama/Llama-3.2-3B-Instruct)
|
|
@@ -29,11 +29,17 @@ from patchhawk.env_models import PatchHawkAction, PatchHawkObservation, PatchHaw
|
|
| 29 |
from patchhawk import tasks as graders
|
| 30 |
|
| 31 |
# ββ Configuration ββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 32 |
|
| 33 |
API_BASE_URL = os.getenv(
|
| 34 |
"API_BASE_URL", "https://router.huggingface.co/hf-inference/v1"
|
| 35 |
)
|
| 36 |
-
|
|
|
|
| 37 |
HF_TOKEN = os.getenv("HF_TOKEN", "")
|
| 38 |
DRY_RUN = os.getenv("DRY_RUN", "0") == "1"
|
| 39 |
SINGLE_TASK = os.getenv("TASK", "")
|
|
@@ -59,19 +65,53 @@ TASK_DEFS = [
|
|
| 59 |
# ββ Prompt builder βββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 60 |
|
| 61 |
SYSTEM_PROMPT = """\
|
| 62 |
-
You are PatchHawk, a security agent that detects supply-chain vulnerabilities
|
| 63 |
-
|
|
|
|
| 64 |
|
| 65 |
-
|
| 66 |
{
|
| 67 |
-
"
|
| 68 |
-
"
|
|
|
|
|
|
|
| 69 |
}
|
| 70 |
|
| 71 |
-
|
| 72 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 73 |
"""
|
| 74 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 75 |
|
| 76 |
def _build_user_prompt(obs: PatchHawkObservation, step: int) -> str:
|
| 77 |
parts = [
|
|
@@ -89,37 +129,119 @@ def _build_user_prompt(obs: PatchHawkObservation, step: int) -> str:
|
|
| 89 |
# ββ LLM caller βββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 90 |
|
| 91 |
|
| 92 |
-
|
| 93 |
-
|
| 94 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 95 |
|
| 96 |
-
|
| 97 |
-
|
| 98 |
-
|
|
|
|
|
|
|
| 99 |
)
|
| 100 |
-
|
| 101 |
-
|
| 102 |
-
|
|
|
|
|
|
|
|
|
|
| 103 |
temperature=0.2,
|
| 104 |
-
max_tokens=512,
|
| 105 |
)
|
| 106 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 107 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 108 |
|
| 109 |
def _parse_action(text: str) -> PatchHawkAction:
|
| 110 |
"""Parse LLM response text into a PatchHawkAction."""
|
| 111 |
-
# Try to extract JSON from the response
|
| 112 |
text = text.strip()
|
| 113 |
-
# Handle markdown code blocks
|
| 114 |
if "```json" in text:
|
| 115 |
text = text.split("```json")[1].split("```")[0].strip()
|
| 116 |
-
elif "```" in text:
|
| 117 |
text = text.split("```")[1].split("```")[0].strip()
|
| 118 |
|
| 119 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 120 |
return PatchHawkAction(
|
| 121 |
-
action_type=int(data
|
| 122 |
-
patch_content=data.get("patch_content"),
|
|
|
|
|
|
|
| 123 |
)
|
| 124 |
|
| 125 |
|
|
|
|
| 2 |
"""
|
| 3 |
PatchHawk inference script β runs the LLM agent loop against the
|
| 4 |
OpenEnv-compliant PatchHawkEnv.
|
| 5 |
+
a
|
| 6 |
Environment variables:
|
| 7 |
API_BASE_URL β OpenAI-compatible API endpoint (required unless DRY_RUN=1)
|
| 8 |
MODEL_NAME β Model identifier (default: meta-llama/Llama-3.2-3B-Instruct)
|
|
|
|
| 29 |
from patchhawk import tasks as graders
|
| 30 |
|
| 31 |
# ββ Configuration ββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 32 |
+
try:
|
| 33 |
+
from dotenv import load_dotenv
|
| 34 |
+
load_dotenv()
|
| 35 |
+
except ImportError:
|
| 36 |
+
pass
|
| 37 |
|
| 38 |
API_BASE_URL = os.getenv(
|
| 39 |
"API_BASE_URL", "https://router.huggingface.co/hf-inference/v1"
|
| 40 |
)
|
| 41 |
+
# Prefer explicit MODEL_NAME, fallback to GRPO_POLICY_MODEL from .env, then default to 32B model.
|
| 42 |
+
MODEL_NAME = os.getenv("MODEL_NAME", os.getenv("GRPO_POLICY_MODEL", "Qwen/Qwen2.5-Coder-32B-Instruct"))
|
| 43 |
HF_TOKEN = os.getenv("HF_TOKEN", "")
|
| 44 |
DRY_RUN = os.getenv("DRY_RUN", "0") == "1"
|
| 45 |
SINGLE_TASK = os.getenv("TASK", "")
|
|
|
|
| 65 |
# ββ Prompt builder βββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 66 |
|
| 67 |
SYSTEM_PROMPT = """\
|
| 68 |
+
You are PatchHawk, a security agent that detects supply-chain vulnerabilities in Python code.
|
| 69 |
+
|
| 70 |
+
Given a code snippet and static analysis flags, you must respond **EXACTLY** with a single JSON object. No extra text, no markdown formatting.
|
| 71 |
|
| 72 |
+
## Output JSON Schema
|
| 73 |
{
|
| 74 |
+
"reasoning": "<str>", // Mandatory: Explain what the vulnerability is, why it's dangerous, and your recommended fix (if any).
|
| 75 |
+
"risk_score": <float>, // 0.0 (no risk) to 1.0 (critical). Be precise to two decimals.
|
| 76 |
+
"action_type": <int>, // One of: 0=ANALYZE, 1=EXECUTE_SANDBOX, 2=BLOCK_PR, 3=SUBMIT_PATCH, 4=REQUEST_REVIEW
|
| 77 |
+
"patch_content": "<str|null>" // Full patched code if action_type=3, otherwise null. Must be valid Python.
|
| 78 |
}
|
| 79 |
|
| 80 |
+
## Action Type Guidelines
|
| 81 |
+
- **0 ANALYZE** β No immediate threat, but needs deeper review.
|
| 82 |
+
- **1 EXECUTE_SANDBOX** β Suspicious but not obviously malicious; run in isolated environment.
|
| 83 |
+
- **2 BLOCK_PR** β Severely malicious, unfixable (e.g., hidden backdoor, remote shell). Reject PR.
|
| 84 |
+
- **3 SUBMIT_PATCH** β Vulnerability can be fixed. Provide corrected code in `patch_content`.
|
| 85 |
+
- **4 REQUEST_REVIEW** β Complex or ambiguous; require human expert.
|
| 86 |
+
|
| 87 |
+
## Rules
|
| 88 |
+
- `reasoning` must be thorough: describe the flaw, its impact (CWE if known), and stepβbyβstep how to patch.
|
| 89 |
+
- Escape all double quotes inside strings with backslash (`\"`).
|
| 90 |
+
- If the code is benign, set `risk_score` β€ 0.2, `action_type` = 0, and `patch_content` = null.
|
| 91 |
+
- Never include comments or explanations outside the JSON object.
|
| 92 |
+
|
| 93 |
+
**Example valid response:**
|
| 94 |
+
{"reasoning": "Hardcoded password 'admin123' in __init__ allows credential bypass. Replace with env var.", "risk_score": 0.85, "action_type": 3, "patch_content": "import os\\nclass Malicious:\\n def __init__(self):\\n self.cache = []\\n self.password = os.getenv('DB_PASS')\\n ..."}
|
| 95 |
"""
|
| 96 |
|
| 97 |
+
# SYSTEM_PROMPT = """\
|
| 98 |
+
# You are PatchHawk, a security agent that detects supply-chain vulnerabilities
|
| 99 |
+
# in Python code. You will be given a code snippet and static analysis flags.
|
| 100 |
+
|
| 101 |
+
# Respond EXACTLY with a JSON object containing the following keys:
|
| 102 |
+
# {
|
| 103 |
+
# "reasoning": "<str>", // Step-by-step explanation of what the vulnerability is, why you are blocking/patching it, and how it can be fixed.
|
| 104 |
+
# "risk_score": <float>, // Your predicted risk score from 0.0 to 1.0 based on your analysis
|
| 105 |
+
# "action_type": <int>, // 0=ANALYZE, 1=EXECUTE_SANDBOX, 2=BLOCK_PR, 3=SUBMIT_PATCH, 4=REQUEST_REVIEW
|
| 106 |
+
# "patch_content": "<str|null>" // The full patched python code fixing the vulnerability
|
| 107 |
+
# }
|
| 108 |
+
|
| 109 |
+
# Be decisive. First, explain your findings thoroughly in the "reasoning" field.
|
| 110 |
+
# If the code is malicious but you can fix the vulnerability, use SUBMIT_PATCH (3) and provide the safe, corrected code in "patch_content".
|
| 111 |
+
# If the code is severely malicious and completely unfixable, use BLOCK_PR (2).
|
| 112 |
+
# IMPORTANT: Ensure your output is perfectly VALID JSON. Escape all double quotes inside strings properly.
|
| 113 |
+
# """
|
| 114 |
+
|
| 115 |
|
| 116 |
def _build_user_prompt(obs: PatchHawkObservation, step: int) -> str:
|
| 117 |
parts = [
|
|
|
|
| 129 |
# ββ LLM caller βββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 130 |
|
| 131 |
|
| 132 |
+
_local_pipeline = None
|
| 133 |
+
|
| 134 |
+
def _call_llm_local(messages: list[dict]) -> str:
|
| 135 |
+
"""Call a local HuggingFace model using transformers pipeline if remote API fails."""
|
| 136 |
+
global _local_pipeline
|
| 137 |
+
if _local_pipeline is None:
|
| 138 |
+
import torch
|
| 139 |
+
from transformers import pipeline
|
| 140 |
+
|
| 141 |
+
# User is already using this model in .env GRPO_POLICY_MODEL
|
| 142 |
+
local_model = os.getenv("GRPO_POLICY_MODEL", "unsloth/Qwen2.5-Coder-3B-Instruct")
|
| 143 |
+
print(f"\n[Fallback] Loading local model: {local_model} into memory. This may take a moment...", flush=True)
|
| 144 |
+
|
| 145 |
+
_local_pipeline = pipeline(
|
| 146 |
+
"text-generation",
|
| 147 |
+
model=local_model,
|
| 148 |
+
model_kwargs={"torch_dtype": torch.bfloat16}, # Half-precision to save VRAM natively fit on 12GB
|
| 149 |
+
device_map="auto"
|
| 150 |
+
)
|
| 151 |
+
print("[Fallback] Local model loaded successfully.\n", flush=True)
|
| 152 |
|
| 153 |
+
# Format messages array to a standard conversational string format
|
| 154 |
+
prompt = _local_pipeline.tokenizer.apply_chat_template(
|
| 155 |
+
messages,
|
| 156 |
+
tokenize=False,
|
| 157 |
+
add_generation_prompt=True
|
| 158 |
)
|
| 159 |
+
|
| 160 |
+
# Run Generation
|
| 161 |
+
outputs = _local_pipeline(
|
| 162 |
+
prompt,
|
| 163 |
+
max_new_tokens=2048,
|
| 164 |
+
do_sample=True,
|
| 165 |
temperature=0.2,
|
|
|
|
| 166 |
)
|
| 167 |
+
|
| 168 |
+
generated = outputs[0]["generated_text"]
|
| 169 |
+
|
| 170 |
+
print(f"\ngenerated:{generated}\n")
|
| 171 |
+
# Strip prompt from returned generated output
|
| 172 |
+
if generated.startswith(prompt):
|
| 173 |
+
generated = generated[len(prompt):]
|
| 174 |
+
|
| 175 |
+
return generated.strip()
|
| 176 |
+
|
| 177 |
|
| 178 |
+
def _call_llm(messages: list[dict]) -> str:
|
| 179 |
+
"""Call the OpenAI-compatible LLM and return the text content."""
|
| 180 |
+
from openai import OpenAI
|
| 181 |
+
|
| 182 |
+
try:
|
| 183 |
+
client = OpenAI(
|
| 184 |
+
base_url=API_BASE_URL,
|
| 185 |
+
api_key=HF_TOKEN or "no-key",
|
| 186 |
+
)
|
| 187 |
+
response = client.chat.completions.create(
|
| 188 |
+
model=MODEL_NAME,
|
| 189 |
+
messages=messages,
|
| 190 |
+
temperature=0.2,
|
| 191 |
+
max_tokens=512,
|
| 192 |
+
)
|
| 193 |
+
return response.choices[0].message.content or ""
|
| 194 |
+
except Exception as e:
|
| 195 |
+
print(f"[LLM ERROR] Remote API failed: {e}. Initiating local Fallback...", flush=True)
|
| 196 |
+
return _call_llm_local(messages)
|
| 197 |
+
|
| 198 |
+
|
| 199 |
+
import re
|
| 200 |
|
| 201 |
def _parse_action(text: str) -> PatchHawkAction:
|
| 202 |
"""Parse LLM response text into a PatchHawkAction."""
|
|
|
|
| 203 |
text = text.strip()
|
|
|
|
| 204 |
if "```json" in text:
|
| 205 |
text = text.split("```json")[1].split("```")[0].strip()
|
| 206 |
+
elif "```" in text and not text.startswith("{"):
|
| 207 |
text = text.split("```")[1].split("```")[0].strip()
|
| 208 |
|
| 209 |
+
def clean_patch(p: str) -> str:
|
| 210 |
+
if not p: return p
|
| 211 |
+
if "```python" in p:
|
| 212 |
+
return p.split("```python")[1].split("```")[0].strip()
|
| 213 |
+
if "```" in p:
|
| 214 |
+
return p.split("```")[1].split("```")[0].strip()
|
| 215 |
+
return p
|
| 216 |
+
|
| 217 |
+
try:
|
| 218 |
+
data = json.loads(text)
|
| 219 |
+
except json.JSONDecodeError:
|
| 220 |
+
action_match = re.search(r'"action_type"\s*:\s*(\d+)', text)
|
| 221 |
+
action_type = int(action_match.group(1)) if action_match else 2
|
| 222 |
+
|
| 223 |
+
risk_match = re.search(r'"risk_score"\s*:\s*([\d\.]+)', text)
|
| 224 |
+
risk_score = float(risk_match.group(1)) if risk_match else None
|
| 225 |
+
|
| 226 |
+
patch_match = re.search(r'"patch_content"\s*:\s*"(.*)', text, re.DOTALL)
|
| 227 |
+
patch_content = None
|
| 228 |
+
if patch_match:
|
| 229 |
+
raw_patch = patch_match.group(1).rsplit('"', 1)[0]
|
| 230 |
+
raw_patch = raw_patch.replace("\\n", "\n").replace('\\"', '"').replace("\\\\", "\\")
|
| 231 |
+
patch_content = clean_patch(raw_patch)
|
| 232 |
+
|
| 233 |
+
return PatchHawkAction(
|
| 234 |
+
action_type=action_type,
|
| 235 |
+
reasoning="JSON Error/Truncated Output. Recovered partial data.",
|
| 236 |
+
predicted_risk=risk_score,
|
| 237 |
+
patch_content=patch_content
|
| 238 |
+
)
|
| 239 |
+
|
| 240 |
return PatchHawkAction(
|
| 241 |
+
action_type=int(data.get("action_type", 2)),
|
| 242 |
+
patch_content=clean_patch(data.get("patch_content")),
|
| 243 |
+
reasoning=data.get("reasoning"),
|
| 244 |
+
predicted_risk=data.get("risk_score"),
|
| 245 |
)
|
| 246 |
|
| 247 |
|
patchhawk/agent/environment.py
CHANGED
|
@@ -308,8 +308,16 @@ class PatchHawkEnv(Environment[PatchHawkAction, PatchHawkObservation, PatchHawkS
|
|
| 308 |
if self.step_counter >= self.max_steps and not done:
|
| 309 |
done = True
|
| 310 |
if label == "malicious":
|
| 311 |
-
reward =
|
| 312 |
-
reason = "max steps reached on malicious scenario"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 313 |
|
| 314 |
self.cumulative_reward += reward
|
| 315 |
|
|
|
|
| 308 |
if self.step_counter >= self.max_steps and not done:
|
| 309 |
done = True
|
| 310 |
if label == "malicious":
|
| 311 |
+
reward -= 5.0
|
| 312 |
+
reason += " | max steps reached on malicious scenario"
|
| 313 |
+
|
| 314 |
+
# ββ Dynamic Risk Bonus βββββββββββββββββββββββββββββββββββ
|
| 315 |
+
predict_risk = getattr(action, "predicted_risk", None)
|
| 316 |
+
if predict_risk is not None:
|
| 317 |
+
actual_risk = 1.0 if label == "malicious" else 0.0
|
| 318 |
+
accuracy_bonus = (1.0 - abs(actual_risk - float(predict_risk))) * 2.0
|
| 319 |
+
reward += accuracy_bonus
|
| 320 |
+
reason += f" | AI risk accuracy bonus: +{accuracy_bonus:.2f}"
|
| 321 |
|
| 322 |
self.cumulative_reward += reward
|
| 323 |
|
patchhawk/agent/sandbox.py
CHANGED
|
@@ -31,7 +31,7 @@ def run_code(
|
|
| 31 |
temp_dir = tempfile.mkdtemp(prefix="patchhawk_sandbox_")
|
| 32 |
script_path = os.path.join(temp_dir, "script.py")
|
| 33 |
|
| 34 |
-
with open(script_path, "w") as f:
|
| 35 |
f.write(code)
|
| 36 |
|
| 37 |
result: Dict[str, Any] = {
|
|
@@ -91,7 +91,7 @@ def check_syntax(
|
|
| 91 |
temp_dir = tempfile.mkdtemp(prefix="patchhawk_syntax_")
|
| 92 |
script_path = os.path.join(temp_dir, "script.py")
|
| 93 |
|
| 94 |
-
with open(script_path, "w") as f:
|
| 95 |
f.write(code)
|
| 96 |
|
| 97 |
try:
|
|
@@ -107,7 +107,7 @@ def check_syntax(
|
|
| 107 |
"--cpus",
|
| 108 |
"0.5",
|
| 109 |
"-v",
|
| 110 |
-
f"{temp_dir}:/app:
|
| 111 |
"patchhawk-sandbox:latest",
|
| 112 |
"python",
|
| 113 |
"-m",
|
|
|
|
| 31 |
temp_dir = tempfile.mkdtemp(prefix="patchhawk_sandbox_")
|
| 32 |
script_path = os.path.join(temp_dir, "script.py")
|
| 33 |
|
| 34 |
+
with open(script_path, "w", encoding="utf-8") as f:
|
| 35 |
f.write(code)
|
| 36 |
|
| 37 |
result: Dict[str, Any] = {
|
|
|
|
| 91 |
temp_dir = tempfile.mkdtemp(prefix="patchhawk_syntax_")
|
| 92 |
script_path = os.path.join(temp_dir, "script.py")
|
| 93 |
|
| 94 |
+
with open(script_path, "w", encoding="utf-8") as f:
|
| 95 |
f.write(code)
|
| 96 |
|
| 97 |
try:
|
|
|
|
| 107 |
"--cpus",
|
| 108 |
"0.5",
|
| 109 |
"-v",
|
| 110 |
+
f"{temp_dir}:/app:rw",
|
| 111 |
"patchhawk-sandbox:latest",
|
| 112 |
"python",
|
| 113 |
"-m",
|
patchhawk/app/dashboard.py
CHANGED
|
@@ -181,7 +181,10 @@ def main():
|
|
| 181 |
final_action_type = PatchHawkEnv.ACTION_BLOCK_PR
|
| 182 |
else:
|
| 183 |
final_action_type = PatchHawkEnv.ACTION_REQUEST_REVIEW
|
| 184 |
-
action = PatchHawkAction(
|
|
|
|
|
|
|
|
|
|
| 185 |
|
| 186 |
# Visual Hacker Terminal Effect
|
| 187 |
if final_action_type == PatchHawkEnv.ACTION_SUBMIT_PATCH:
|
|
@@ -219,8 +222,13 @@ def main():
|
|
| 219 |
with st.expander("π€ Agent Thought Process (LLM Trace)"):
|
| 220 |
st.markdown(f"```json\n{llm_thought_process}\n```")
|
| 221 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 222 |
m1, m2, m3 = st.columns(3)
|
| 223 |
-
m1.metric("Risk Score", f"{
|
| 224 |
m2.metric("Decision", PatchHawkEnv.ACTION_NAMES[final_action_type])
|
| 225 |
m3.metric("Reward", f"{total_reward:+.2f}")
|
| 226 |
|
|
@@ -229,6 +237,10 @@ def main():
|
|
| 229 |
)
|
| 230 |
|
| 231 |
with tab1:
|
|
|
|
|
|
|
|
|
|
|
|
|
| 232 |
if final_action_type == PatchHawkEnv.ACTION_BLOCK_PR:
|
| 233 |
st.markdown(
|
| 234 |
"<div class='info-box status-malicious'>β BLOCKED β "
|
|
@@ -253,10 +265,13 @@ def main():
|
|
| 253 |
|
| 254 |
with tab2:
|
| 255 |
telem = obs.metadata.get("telemetry")
|
|
|
|
| 256 |
if telem:
|
| 257 |
st.json(telem)
|
|
|
|
|
|
|
| 258 |
else:
|
| 259 |
-
st.info("No sandbox
|
| 260 |
|
| 261 |
with tab3:
|
| 262 |
if final_action_type == PatchHawkEnv.ACTION_SUBMIT_PATCH and scenario.get(
|
|
|
|
| 181 |
final_action_type = PatchHawkEnv.ACTION_BLOCK_PR
|
| 182 |
else:
|
| 183 |
final_action_type = PatchHawkEnv.ACTION_REQUEST_REVIEW
|
| 184 |
+
action = PatchHawkAction(
|
| 185 |
+
action_type=final_action_type,
|
| 186 |
+
reasoning="Static rule-based fallback decision due to high risk score."
|
| 187 |
+
)
|
| 188 |
|
| 189 |
# Visual Hacker Terminal Effect
|
| 190 |
if final_action_type == PatchHawkEnv.ACTION_SUBMIT_PATCH:
|
|
|
|
| 222 |
with st.expander("π€ Agent Thought Process (LLM Trace)"):
|
| 223 |
st.markdown(f"```json\n{llm_thought_process}\n```")
|
| 224 |
|
| 225 |
+
# Opt for LLM's predicted risk score if available
|
| 226 |
+
display_risk = getattr(action, "predicted_risk", None)
|
| 227 |
+
if display_risk is None:
|
| 228 |
+
display_risk = risk
|
| 229 |
+
|
| 230 |
m1, m2, m3 = st.columns(3)
|
| 231 |
+
m1.metric("Risk Score", f"{float(display_risk):.2f}")
|
| 232 |
m2.metric("Decision", PatchHawkEnv.ACTION_NAMES[final_action_type])
|
| 233 |
m3.metric("Reward", f"{total_reward:+.2f}")
|
| 234 |
|
|
|
|
| 237 |
)
|
| 238 |
|
| 239 |
with tab1:
|
| 240 |
+
if hasattr(action, "reasoning") and action.reasoning:
|
| 241 |
+
st.markdown("### π§ Agent's Reasoning")
|
| 242 |
+
st.info(action.reasoning)
|
| 243 |
+
|
| 244 |
if final_action_type == PatchHawkEnv.ACTION_BLOCK_PR:
|
| 245 |
st.markdown(
|
| 246 |
"<div class='info-box status-malicious'>β BLOCKED β "
|
|
|
|
| 265 |
|
| 266 |
with tab2:
|
| 267 |
telem = obs.metadata.get("telemetry")
|
| 268 |
+
details = obs.metadata.get("details")
|
| 269 |
if telem:
|
| 270 |
st.json(telem)
|
| 271 |
+
elif dict(details) if details else None:
|
| 272 |
+
st.json(details)
|
| 273 |
else:
|
| 274 |
+
st.info("No sandbox telemetry generated for this action.")
|
| 275 |
|
| 276 |
with tab3:
|
| 277 |
if final_action_type == PatchHawkEnv.ACTION_SUBMIT_PATCH and scenario.get(
|
patchhawk/data/generate_scenarios.py
CHANGED
|
@@ -128,12 +128,16 @@ def auto_generate_unit_test(filename: str, code: str) -> str:
|
|
| 128 |
# ============================================================
|
| 129 |
|
| 130 |
|
| 131 |
-
def generate_track_b_scenarios(benign_files: list) -> list:
|
| 132 |
-
"""Generate
|
| 133 |
scenarios = []
|
| 134 |
|
| 135 |
-
|
| 136 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 137 |
bf = random.choice(benign_files)
|
| 138 |
attack_name, attack_data = random.choice(list(ATTACK_TEMPLATES.items()))
|
| 139 |
malicious_code = attack_data["inject"] + bf["code"]
|
|
@@ -187,7 +191,7 @@ def generate_track_b_scenarios(benign_files: list) -> list:
|
|
| 187 |
"result = subprocess.run(['echo', 'build ok'], capture_output=True)\n\n",
|
| 188 |
),
|
| 189 |
]
|
| 190 |
-
for i in range(
|
| 191 |
bf = random.choice(benign_files)
|
| 192 |
fp_name, fp_code = random.choice(fp_templates)
|
| 193 |
suspicious_code = fp_code + bf["code"]
|
|
@@ -205,8 +209,8 @@ def generate_track_b_scenarios(benign_files: list) -> list:
|
|
| 205 |
}
|
| 206 |
)
|
| 207 |
|
| 208 |
-
# ββ Functional / Clean (
|
| 209 |
-
for i in range(
|
| 210 |
bf = random.choice(benign_files)
|
| 211 |
test_code = auto_generate_unit_test(bf["filename"], bf["code"])
|
| 212 |
scenarios.append(
|
|
@@ -222,7 +226,7 @@ def generate_track_b_scenarios(benign_files: list) -> list:
|
|
| 222 |
}
|
| 223 |
)
|
| 224 |
|
| 225 |
-
return scenarios
|
| 226 |
|
| 227 |
|
| 228 |
# ============================================================
|
|
@@ -486,6 +490,12 @@ def main():
|
|
| 486 |
type=str,
|
| 487 |
default="patchhawk/data/scenarios.json",
|
| 488 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 489 |
parser.add_argument(
|
| 490 |
"--use-sdk",
|
| 491 |
action="store_true",
|
|
@@ -535,7 +545,7 @@ def main():
|
|
| 535 |
return
|
| 536 |
|
| 537 |
# Track B (always)
|
| 538 |
-
scenarios = generate_track_b_scenarios(benign_files)
|
| 539 |
|
| 540 |
# Track A (optional)
|
| 541 |
if args.use_sdk:
|
|
|
|
| 128 |
# ============================================================
|
| 129 |
|
| 130 |
|
| 131 |
+
def generate_track_b_scenarios(benign_files: list, num_samples: int = 55) -> list:
|
| 132 |
+
"""Generate proportional scenarios dynamically based on num_samples."""
|
| 133 |
scenarios = []
|
| 134 |
|
| 135 |
+
tp_count = int(num_samples * 0.45)
|
| 136 |
+
fp_count = int(num_samples * 0.27)
|
| 137 |
+
fn_count = num_samples - tp_count - fp_count
|
| 138 |
+
|
| 139 |
+
# ββ True Positives (45%) ββββββββββββββββββββββββββββββββββ
|
| 140 |
+
for i in range(tp_count):
|
| 141 |
bf = random.choice(benign_files)
|
| 142 |
attack_name, attack_data = random.choice(list(ATTACK_TEMPLATES.items()))
|
| 143 |
malicious_code = attack_data["inject"] + bf["code"]
|
|
|
|
| 191 |
"result = subprocess.run(['echo', 'build ok'], capture_output=True)\n\n",
|
| 192 |
),
|
| 193 |
]
|
| 194 |
+
for i in range(fp_count):
|
| 195 |
bf = random.choice(benign_files)
|
| 196 |
fp_name, fp_code = random.choice(fp_templates)
|
| 197 |
suspicious_code = fp_code + bf["code"]
|
|
|
|
| 209 |
}
|
| 210 |
)
|
| 211 |
|
| 212 |
+
# ββ Functional / Clean (28%) ββββββββββββββββββββββββββββββ
|
| 213 |
+
for i in range(fn_count):
|
| 214 |
bf = random.choice(benign_files)
|
| 215 |
test_code = auto_generate_unit_test(bf["filename"], bf["code"])
|
| 216 |
scenarios.append(
|
|
|
|
| 226 |
}
|
| 227 |
)
|
| 228 |
|
| 229 |
+
return scenarios
|
| 230 |
|
| 231 |
|
| 232 |
# ============================================================
|
|
|
|
| 490 |
type=str,
|
| 491 |
default="patchhawk/data/scenarios.json",
|
| 492 |
)
|
| 493 |
+
parser.add_argument(
|
| 494 |
+
"--num-samples",
|
| 495 |
+
type=int,
|
| 496 |
+
default=55,
|
| 497 |
+
help="Number of scenarios to generate with Track B (mutation engine).",
|
| 498 |
+
)
|
| 499 |
parser.add_argument(
|
| 500 |
"--use-sdk",
|
| 501 |
action="store_true",
|
|
|
|
| 545 |
return
|
| 546 |
|
| 547 |
# Track B (always)
|
| 548 |
+
scenarios = generate_track_b_scenarios(benign_files, args.num_samples)
|
| 549 |
|
| 550 |
# Track A (optional)
|
| 551 |
if args.use_sdk:
|
patchhawk/data/scenarios.json
CHANGED
|
The diff for this file is too large to render.
See raw diff
|
|
|
patchhawk/env_models.py
CHANGED
|
@@ -53,6 +53,12 @@ class PatchHawkAction(Action):
|
|
| 53 |
patch_content: Optional[str] = Field(
|
| 54 |
None, description="The unified context patch if action is SUBMIT_PATCH"
|
| 55 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 56 |
|
| 57 |
|
| 58 |
# ββ State ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
|
|
|
| 53 |
patch_content: Optional[str] = Field(
|
| 54 |
None, description="The unified context patch if action is SUBMIT_PATCH"
|
| 55 |
)
|
| 56 |
+
reasoning: Optional[str] = Field(
|
| 57 |
+
None, description="Explanation of the vulnerability and chosen action"
|
| 58 |
+
)
|
| 59 |
+
predicted_risk: Optional[float] = Field(
|
| 60 |
+
None, description="LLM predicted risk score (0.0 to 1.0)"
|
| 61 |
+
)
|
| 62 |
|
| 63 |
|
| 64 |
# ββ State ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
patchhawk/training/train_grpo.py
CHANGED
|
@@ -33,6 +33,7 @@ def _build_prompt(scenario: dict) -> str:
|
|
| 33 |
f"<code_snippet>\n{scenario['code_snippet']}\n</code_snippet>\n"
|
| 34 |
"Respond in STRICT XML:\n"
|
| 35 |
"<thought>...</thought>\n"
|
|
|
|
| 36 |
"<action>0-4</action>\n"
|
| 37 |
"<patch>...</patch> (ONLY if action=3)\n"
|
| 38 |
)
|
|
@@ -90,7 +91,10 @@ def train_agent(args):
|
|
| 90 |
else:
|
| 91 |
print("No GPU found β training will be slow.")
|
| 92 |
|
| 93 |
-
|
|
|
|
|
|
|
|
|
|
| 94 |
|
| 95 |
# 4βbit quantisation config
|
| 96 |
bnb_config = BitsAndBytesConfig(
|
|
@@ -147,6 +151,10 @@ def train_agent(args):
|
|
| 147 |
score += 0.5
|
| 148 |
else:
|
| 149 |
score -= 1.0
|
|
|
|
|
|
|
|
|
|
|
|
|
| 150 |
if re.search(r"<action>[0-4]</action>", text):
|
| 151 |
score += 0.5
|
| 152 |
else:
|
|
@@ -194,12 +202,22 @@ def train_agent(args):
|
|
| 194 |
patch_match = re.search(r"<patch>(.*?)</patch>", text, re.DOTALL)
|
| 195 |
if patch_match:
|
| 196 |
patch = patch_match.group(1).strip()
|
|
|
|
|
|
|
|
|
|
| 197 |
|
| 198 |
try:
|
| 199 |
# Reset environment to the exact scenario
|
| 200 |
-
env.reset(
|
| 201 |
-
obs = env.step(PatchHawkAction(
|
| 202 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 203 |
except Exception as exc:
|
| 204 |
print(f"env_reward crash: {exc}")
|
| 205 |
rewards.append(-3.0)
|
|
|
|
| 33 |
f"<code_snippet>\n{scenario['code_snippet']}\n</code_snippet>\n"
|
| 34 |
"Respond in STRICT XML:\n"
|
| 35 |
"<thought>...</thought>\n"
|
| 36 |
+
"<risk_score>0.0 to 1.0</risk_score>\n"
|
| 37 |
"<action>0-4</action>\n"
|
| 38 |
"<patch>...</patch> (ONLY if action=3)\n"
|
| 39 |
)
|
|
|
|
| 91 |
else:
|
| 92 |
print("No GPU found β training will be slow.")
|
| 93 |
|
| 94 |
+
from dotenv import load_dotenv
|
| 95 |
+
load_dotenv()
|
| 96 |
+
|
| 97 |
+
MODEL_NAME = os.getenv("GRPO_POLICY_MODEL", "Qwen/Qwen2.5-Coder-3B-Instruct")
|
| 98 |
|
| 99 |
# 4βbit quantisation config
|
| 100 |
bnb_config = BitsAndBytesConfig(
|
|
|
|
| 151 |
score += 0.5
|
| 152 |
else:
|
| 153 |
score -= 1.0
|
| 154 |
+
if re.search(r"<risk_score>[\d\.]+</risk_score>", text):
|
| 155 |
+
score += 0.5
|
| 156 |
+
else:
|
| 157 |
+
score -= 1.0
|
| 158 |
if re.search(r"<action>[0-4]</action>", text):
|
| 159 |
score += 0.5
|
| 160 |
else:
|
|
|
|
| 202 |
patch_match = re.search(r"<patch>(.*?)</patch>", text, re.DOTALL)
|
| 203 |
if patch_match:
|
| 204 |
patch = patch_match.group(1).strip()
|
| 205 |
+
|
| 206 |
+
risk_match = re.search(r"<risk_score>([\d\.]+)</risk_score>", text)
|
| 207 |
+
predicted_risk = float(risk_match.group(1)) if risk_match else None
|
| 208 |
|
| 209 |
try:
|
| 210 |
# Reset environment to the exact scenario
|
| 211 |
+
env.reset(scenario=scenario)
|
| 212 |
+
obs = env.step(PatchHawkAction(
|
| 213 |
+
action_type=action_type,
|
| 214 |
+
patch_content=patch,
|
| 215 |
+
predicted_risk=predicted_risk
|
| 216 |
+
))
|
| 217 |
+
reward_val = float(obs.reward or 0.0)
|
| 218 |
+
rewards.append(reward_val)
|
| 219 |
+
val_msg = obs.metadata.get('validation') or ("Telemetry Extracted" if obs.metadata.get('telemetry') else "None")
|
| 220 |
+
print(f"[Env Reward] Action: {action_type} | Reward: {reward_val:+.2f} | Docker: {val_msg}")
|
| 221 |
except Exception as exc:
|
| 222 |
print(f"env_reward crash: {exc}")
|
| 223 |
rewards.append(-3.0)
|