#!/usr/bin/env python3
"""
generate_sft_data.py
====================
Build SFT warmup data for metacognitive format training.
Generates demonstration trajectories that teach the model:
1. The tool-call format (<budget_prediction>, <think>, <tool_call>)
2. Short predictions + brief reasoning on safe files
3. Long predictions + deep reasoning on buggy files
4. Proper flag/skip decisions
These are used for 2-3 epochs of SFT BEFORE GRPO, so the model
arrives at RL already knowing the output format. This eliminates
the ~30% zero-reward rate caused by malformed completions.
Usage:
    python scripts/generate_sft_data.py
    # Output: data/sft_demonstrations.json
"""
from __future__ import annotations
import json
import os
import random
import sys
from pathlib import Path
ROOT = Path(__file__).resolve().parent.parent
sys.path.insert(0, str(ROOT))
CVE_DATA = ROOT / "data" / "cve_training_data.json"
OUTPUT = ROOT / "data" / "sft_demonstrations.json"
# Mirrors the system prompt in train_grpo.py (copied here, not imported; keep the two in sync)
SYSTEM_PROMPT_BASE = """You are an expert security code investigator specializing in CVE vulnerability analysis. You are given a CVE vulnerability description and a list of files from a code patch.
Your mission:
1. Use read_file to examine source code of the most suspicious files first
2. Use search_code to find vulnerability-related patterns (e.g., unsafe functions, missing checks)
3. Use get_function_list to understand file structure and complexity
4. Use flag_vulnerable to mark files containing or related to the vulnerability - provide DETAILED reasoning
5. Use skip_file to mark files that are safe - briefly explain why
6. Use submit_report to provide a detailed triage analysis
CRITICAL - Thinking Strategy:
You have a THINKING BUDGET. Use it wisely, like a real security engineer:
- When flagging a file as VULNERABLE: Think deeply. Explain the exact code pattern that matches
the CVE. Mention function names, unsafe operations, missing checks. Your reasoning should be
100+ characters proving you understand the vulnerability.
- When skipping a SAFE file: Be brief. "Header file with only declarations" or "Test file,
no production logic" is sufficient. Don't waste thinking effort on obviously safe files.
- The environment TRACKS your thinking allocation and rewards you for thinking deeply on the
RIGHT files (actual bugs) and briefly on safe files.
Investigation strategy:
- Prioritize files with high complexity and recent changes - bugs hide in complex code
- Match the CVE description to code patterns (e.g., "buffer overflow" → look for unchecked memcpy/strcpy)
- Header files (.h) are usually declarations, not vulnerability sources
- Test files are rarely the vulnerability source
- Write a thorough report: mention the CVE ID, affected files, vulnerability type, and root cause
You have limited investigation points and flags. Be strategic - read the most suspicious files first.
CRITICAL - Metacognitive Format (REQUIRED):
Before EVERY <think> block, you MUST emit a budget prediction first:
<budget_prediction>short|medium|long</budget_prediction>
<think>
...your reasoning here...
</think>
<tool_call>{"name": "...", "arguments": {...}}</tool_call>
Budget bands:
- short : 0–80 characters of reasoning. Use for obviously safe files
(test files, headers with no logic, boilerplate).
- medium : 80–250 characters. Use when you need to verify but don't
see strong red flags.
- long : 250+ characters. Use when you suspect the file is vulnerable
and need to lay out the bug pattern (function name, unsafe
operation, missing check, exploit path).
You will be SCORED on:
1. Calibration β does the actual length of your <think> match the band
you predicted?
2. Difficulty awareness β do you predict 'long' on actually-vulnerable
files and 'short' on safe ones?
3. Coupling β every prediction must be followed by a real tool call
against a file (no orphan predictions).
The optimal policy predicts BEFORE thinking, thinks the predicted amount,
and predicts longer for bugs. Be honest about the difficulty.
"""

# ── Vulnerability-specific reasoning templates ────────────────────────
# These are realistic security analysis patterns for different CVE types
BUG_REASONING_TEMPLATES = [
    # Buffer overflow / integer overflow
    (
        "Looking at {file} in the context of {cve_id}. The CVE describes {vuln_type}. "
        "This file is in {component} and has complexity score {complexity}. "
        "The function handling user input does not validate the size parameter "
        "before passing it to the allocation routine. This matches the CVE pattern: "
        "attacker-controlled length flows into a memory operation without bounds checking. "
        "The missing sanitization on the arithmetic creates an integer overflow primitive "
        "that leads to heap corruption. Flagging as vulnerable."
    ),
    # Use-after-free
    (
        "Examining {file} for {cve_id}. The CVE is about {vuln_type}. "
        "This file contains the object lifecycle management code in {component}. "
        "The reference counting pattern here has a window where the object can be "
        "freed while a callback still holds a dangling pointer. The race condition "
        "between the release path and the async handler is the root cause. "
        "No locking protects the critical section. This is the vulnerable file."
    ),
    # Injection / command injection
    (
        "Analyzing {file} which is part of {component}. {cve_id} describes {vuln_type}. "
        "The input parsing function constructs a command string using string concatenation "
        "with user-supplied data. No escaping or parameterized query is used. "
        "The attacker can inject arbitrary commands through the unsanitized parameter. "
        "This is a textbook injection vulnerability matching the CVE description exactly."
    ),
    # Auth bypass / privilege escalation
    (
        "Reviewing {file} for {cve_id} ({vuln_type}). This file implements the "
        "authorization check in {component}. The conditional logic has a short-circuit "
        "evaluation that skips the permission verification when a specific flag is set. "
        "An attacker can set this flag through the public API, bypassing the intended "
        "access control. The fix would require validating permissions unconditionally."
    ),
    # Generic deep analysis
    (
        "Deep analysis of {file} for {cve_id}. CVE type: {vuln_type}. "
        "Component: {component}, complexity: {complexity}/100, churn: {churn}. "
        "The vulnerable code path starts at the entry point and flows through "
        "the handler without proper validation. The specific issue is that "
        "user-controlled data reaches a sensitive operation (memory write, "
        "system call, or privilege check) without adequate sanitization. "
        "This matches the vulnerability pattern described in the advisory."
    ),
]

SAFE_REASONING_TEMPLATES = [
    "Header file with type declarations only. No executable logic.",
    "Test file. No production code paths.",
    "Factory/initialization boilerplate. No user input handling.",
    "Configuration constants. No control flow.",
    "Utility math functions. No external input.",
    "Interface definitions. Pure declarations.",
    "Build system file. Not executable code.",
    "Documentation or metadata. No logic.",
    "Logging helpers. No security-sensitive operations.",
    "Static data definitions. No dynamic behavior.",
    "Wrapper module. Delegates to other files.",
    "Type aliases and enums. Declarative only.",
]

MEDIUM_REASONING_TEMPLATES = [
    (
        "Checking {file} in {component}. Has moderate complexity ({complexity}) "
        "but the functions here handle internal data only, not user input. "
        "No obvious match to the {vuln_type} pattern from {cve_id}. Skipping."
    ),
    (
        "Reviewing {file}. Complexity is {complexity} with {churn} recent changes. "
        "The code processes data but uses safe library functions throughout. "
        "No unchecked operations that match the CVE description. Safe to skip."
    ),
]


def extract_vuln_type(description: str) -> str:
    """Extract a short vulnerability type from CVE description."""
    desc_lower = description.lower()
    if "buffer overflow" in desc_lower or "heap" in desc_lower:
        return "buffer overflow"
    if "integer overflow" in desc_lower:
        return "integer overflow leading to memory corruption"
    if "use-after-free" in desc_lower or "use after free" in desc_lower:
        return "use-after-free"
    if "injection" in desc_lower or "sql" in desc_lower:
        return "injection vulnerability"
    if "privilege" in desc_lower or "escalat" in desc_lower:
        return "privilege escalation"
    if "bypass" in desc_lower or "auth" in desc_lower:
        return "authentication bypass"
    if "denial" in desc_lower or "dos" in desc_lower:
        return "denial of service"
    if "remote code" in desc_lower or "arbitrary code" in desc_lower:
        return "remote code execution"
    if "traversal" in desc_lower or "path" in desc_lower:
        return "path traversal"
    if "xss" in desc_lower or "cross-site" in desc_lower:
        return "cross-site scripting"
    return "security vulnerability"


def generate_completion(files: list, cve_id: str, cve_desc: str, rng: random.Random) -> str:
    """Generate a single demonstration completion for an episode.

    Every file gets one <budget_prediction> / <think> / <tool_call> segment:
      - buggy files (label == 1): long prediction, deep reasoning, flag_vulnerable
      - tests and headers: short prediction, one-line reasoning, skip_file
      - other safe files with complexity > 30 or churn > 20: medium prediction, skip_file
      - all remaining safe files: short prediction, skip_file
    """
    vuln_type = extract_vuln_type(cve_desc)

    def segment(band: str, reasoning: str, tool: str, fpath: str, tool_reasoning: str) -> str:
        # Build the tool call with json.dumps so quotes or backslashes in a
        # file path cannot produce malformed JSON.
        call = json.dumps(
            {"name": tool, "arguments": {"file_path": fpath, "reasoning": tool_reasoning}},
            ensure_ascii=False,
        )
        return (
            f"<budget_prediction>{band}</budget_prediction>\n"
            f"<think>\n{reasoning}\n</think>\n"
            f"<tool_call>{call}</tool_call>"
        )

    parts = []
    for f in files:
        fpath = f["file"]
        label = f["label"]
        features = f.get("features", [0, 0, 0, 0])
        churn, complexity, todos, recency = features
        component = f.get("file_component", "unknown")
        is_test = f.get("is_test_file", False)
        lang = f.get("file_language", "")

        if label == 1:
            # Buggy file: long prediction + deep reasoning + flag
            template = rng.choice(BUG_REASONING_TEMPLATES)
            reasoning = template.format(
                file=fpath, cve_id=cve_id, vuln_type=vuln_type,
                component=component, complexity=complexity, churn=churn,
            )
            parts.append(segment(
                "long", reasoning, "flag_vulnerable", fpath, f"{vuln_type} in {component}"
            ))
        elif is_test or "test" in fpath.lower() or "Header" in lang or fpath.endswith(".h"):
            # Obviously safe: short prediction + brief reasoning + skip
            reasoning = rng.choice(SAFE_REASONING_TEMPLATES)
            parts.append(segment(
                "short", reasoning, "skip_file", fpath, "safe - no vulnerability pattern"
            ))
        elif complexity > 30 or churn > 20:
            # Medium-complexity safe file: medium prediction
            template = rng.choice(MEDIUM_REASONING_TEMPLATES)
            reasoning = template.format(
                file=fpath, component=component, complexity=complexity,
                churn=churn, vuln_type=vuln_type, cve_id=cve_id,
            )
            parts.append(segment(
                "medium", reasoning, "skip_file", fpath, "reviewed, no match to CVE pattern"
            ))
        else:
            # Simple safe file: short prediction
            reasoning = rng.choice(SAFE_REASONING_TEMPLATES)
            parts.append(segment(
                "short", reasoning, "skip_file", fpath, f"safe - {reasoning[:40]}"
            ))

    return "\n\n".join(parts)


def build_user_prompt(cve_id: str, cve_desc: str, files: list) -> str:
    """Build the user message (same format as train_grpo.py)."""
    file_list = "\n".join(
        f" • {f['file']} [{f.get('file_language', 'unknown')}] "
        f"complexity={f.get('features', [0,0,0,0])[1]} "
        f"churn={f.get('features', [0,0,0,0])[0]}"
        for f in files
    )
    return (
        f"CVE: {cve_id}\n"
        f"Description: {cve_desc}\n\n"
        f"Files to investigate:\n{file_list}\n\n"
        f"Begin your security investigation. Use the available tools to analyze the files.\n"
        f"Remember: Think DEEPLY when flagging suspicious files (explain the vulnerability pattern).\n"
        f"Be BRIEF when skipping safe files. Submit a thorough triage report when done."
    )


def main():
    rng = random.Random(42)
    with open(CVE_DATA, encoding="utf-8") as f:
        all_files = json.load(f)

    # Group by CVE
    cve_groups = {}
    for entry in all_files:
        cve_id = entry["cveId"]
        if cve_id not in cve_groups:
            cve_groups[cve_id] = {
                "cve_id": cve_id,
                "cve_description": entry["cve_description"],
                "files": [],
            }
        cve_groups[cve_id]["files"].append(entry)

    # Filter to episodes with at least 1 bug and at most 15 files
    # (matches the "easy" difficulty used in training)
    episodes = []
    for cve_id, group in cve_groups.items():
        files = group["files"]
        n_bugs = sum(1 for f in files if f["label"] == 1)
        if n_bugs >= 1 and len(files) <= 15:
            episodes.append(group)

    rng.shuffle(episodes)
    episodes = episodes[:50]  # up to 50 demonstrations
    print(f"Building SFT data from {len(episodes)} episodes...")

    sft_data = []
    for ep in episodes:
        # Subsample files if too many: keep every buggy file plus at most 6 random safe files
        files = ep["files"]
        bugs = [f for f in files if f["label"] == 1]
        safe = [f for f in files if f["label"] == 0]
        if len(safe) > 6:
            safe = rng.sample(safe, 6)
        selected = bugs + safe
        rng.shuffle(selected)

        user_prompt = build_user_prompt(ep["cve_id"], ep["cve_description"], selected)
        completion = generate_completion(selected, ep["cve_id"], ep["cve_description"], rng)
        sft_data.append({
            "messages": [
                {"role": "system", "content": SYSTEM_PROMPT_BASE},
                {"role": "user", "content": user_prompt},
                {"role": "assistant", "content": completion},
            ],
            "cve_id": ep["cve_id"],
            "n_files": len(selected),
            "n_bugs": len(bugs),
        })

    os.makedirs(OUTPUT.parent, exist_ok=True)
    with open(OUTPUT, "w", encoding="utf-8") as f:
        json.dump(sft_data, f, indent=2)

    # Stats (max(..., 1) avoids a ZeroDivisionError when no episode qualifies)
    total_files = sum(d["n_files"] for d in sft_data)
    total_bugs = sum(d["n_bugs"] for d in sft_data)
    avg_completion_len = sum(len(d["messages"][2]["content"]) for d in sft_data) / max(len(sft_data), 1)
    print(f"✅ Wrote {len(sft_data)} SFT demonstrations to {OUTPUT}")
    print(f"   Total files: {total_files} ({total_bugs} bugs, {total_files - total_bugs} safe)")
    print(f"   Avg completion length: {avg_completion_len:.0f} chars")
if __name__ == "__main__":
main()