adversarial-sast / engine.py
Ferr0's picture
repo-audit: detection via HF Inference (Qwen3-Coder-480B), 7B keeps refutation
3255f2b verified
Raw
History Blame Contribute Delete
12.8 kB
"""Shared SAST engine — stage 1 (detect) and stage 2 (adversarial refute).
Extracted verbatim from the snippet path so both the single-snippet audit and the
whole-repo audit reuse the *same* prompts, schemas and refutation logic. Pure functions:
they take an Outlines `model` + tokenizer, no GPU decorator (callers wrap as needed).
"""
import json
import re
from schemas import Report, Verdict
# System prompt for stage 1 (detection), shared by the Outlines path (`detect`) and the
# HF Inference path (`detect_hf`). Deliberately recall-oriented: it asks the model to be
# inclusive because every candidate is re-checked by the stage-2 refuter.
DETECT_SYS = (
    "You are a security code reviewer. List candidate SECURITY vulnerabilities with a real "
    "attack surface (command/SQL injection, path traversal, XSS, SSRF, deserialization, auth "
    "bypass...). Be inclusive about security issues — an adversarial pass verifies each — but "
    "ignore pure style, quality or error-handling nits. Use 1-based line numbers."
)
# System prompt for stage 2 (adversarial refutation of ONE candidate), shared by `refute` and
# `refute_hf`. It defines what counts as adequate protection (→ false positive) and names two
# traps where a check looks protective but is not (existence checks on include/file paths;
# isset/empty/strlen). This wording is the calibrated verifier step — edit with care.
REFUTE_SYS = (
    "You are an exploit analyst verifying ONE claimed vulnerability in the given code. Decide if it "
    "is GENUINELY exploitable. Judge ONLY what the code shows.\n"
    "- EXPLOITABLE (exploitable=true): attacker-controlled input reaches the dangerous sink with NO "
    "adequate sanitization in between; give a concrete proof-of-concept input.\n"
    "- FALSE POSITIVE (exploitable=false): adequate protection on the value before the sink — a "
    "prepared/parameterized query, an allow-list of permitted values, a cast to a safe type, proper "
    "escaping (escapeshellarg, htmlspecialchars), or strict value validation (is_numeric on the whole "
    "value). Also FP if the sink is unreachable/dead code, or the input is not attacker-controlled. "
    "If a proper sanitization like that is visible, it IS a false positive — say so.\n"
    "Two traps, do NOT fall for them:\n"
    "1. An EXISTENCE/presence check is NOT sanitization. If attacker input is used to build a path "
    "passed to include/require/file_get_contents/fopen/readfile — EVEN with a fixed directory prefix, "
    "a forced extension (.php), or a file_exists()/is_file() guard — it IS EXPLOITABLE (path traversal "
    "/ LFI): the attacker still controls which file loads, and such prefix/suffix/existence constraints "
    "are routinely bypassed (../, encoded traversal, existing sensitive files).\n"
    "2. isset/empty/strlen do not stop injection.\n"
    "If exploitable, set exploitable=true and give the PoC input."
)
# Severity → badge emoji used by the Markdown report (unknown severities render without one).
SEV = {"critical": "🔴", "high": "🟠", "medium": "🟡", "low": "⚪"}
# Canonical severity levels, least to most severe. Single source for the normaliser's
# membership set and the detector schema's enum, so the two cannot drift apart.
_SEVERITIES = ("low", "medium", "high", "critical")
_SEV_SET = set(_SEVERITIES)
# One detector finding — mirrors the Candidate fields of schemas.Report.
_FINDING_JSON_SCHEMA = dict(
    type="object",
    properties=dict(
        vuln_type={"type": "string"},
        line={"type": "integer"},
        severity={"type": "string", "enum": list(_SEVERITIES)},
        rationale={"type": "string"},
    ),
    required=["vuln_type", "line", "severity", "rationale"],
)
# Flat JSON schema mirroring schemas.Report — sent as the json_schema response_format to the
# HF Inference detector, which does NOT do Outlines-style constrained decoding (unlike ZeroGPU).
# NOTE(review): some OpenAI-compatible providers' strict mode also demands
# "additionalProperties": false on every object (and, for the Verdict schema, every property
# listed in "required"); if a provider rejects the schema, the HF helpers in this module fall
# back to their plain strict-JSON retry.
REPORT_JSON_SCHEMA = dict(
    type="object",
    properties=dict(findings=dict(type="array", items=_FINDING_JSON_SCHEMA)),
    required=["findings"],
)
# Flat JSON schema mirroring schemas.Verdict — response_format for the HF Inference refuter.
VERDICT_JSON_SCHEMA = dict(
    type="object",
    properties=dict(
        exploitable={"type": "boolean"},
        reasoning={"type": "string"},
        poc={"type": "string"},
    ),
    required=["exploitable", "reasoning"],
)
def number_lines(code):
    """Return *code* with each line prefixed by its 1-based number as ``N| ``.

    The detector is told to cite these prefixes, so its line numbers match the source.
    Lines are split with ``str.splitlines`` (so no trailing empty line is numbered).
    """
    rows = code.splitlines()
    prefixed = [f"{lineno}| {row}" for lineno, row in enumerate(rows, start=1)]
    return "\n".join(prefixed)
def _usage_tokens(resp):
    """Total tokens a chat-completions response reports in ``usage``; 0 when absent or null."""
    usage = getattr(resp, "usage", None)
    if not usage:
        return 0
    return getattr(usage, "total_tokens", 0) or 0
def _parse_json(txt):
    """Leniently parse a model reply into a JSON object.

    Strips a surrounding ``` fence (with optional language tag), then decodes the whole text.
    If that fails, decodes the JSON object that starts at the first ``{``, ignoring anything
    after it — this tolerates leading prose and trailing prose, even prose containing braces.

    Returns:
        The decoded dict, or None when the text is empty, holds no decodable object, or its
        top-level JSON value is not an object (callers call ``.get`` on the result).
    """
    if not txt:
        return None
    txt = txt.strip()
    if txt.startswith("```"):
        txt = re.sub(r"^```[a-zA-Z]*\n?", "", txt)
        txt = re.sub(r"\n?```$", "", txt).strip()
    try:
        data = json.loads(txt)
    except json.JSONDecodeError:
        start = txt.find("{")
        if start == -1:
            return None
        try:
            # raw_decode stops at the end of the first complete value, so trailing text
            # (including stray braces) no longer defeats the parse
            data, _ = json.JSONDecoder().raw_decode(txt, start)
        except json.JSONDecodeError:
            return None
    # a bare list/number/string is valid JSON but not the object callers expect
    return data if isinstance(data, dict) else None
def _clean(findings, max_findings=None):
    """Normalize HF-Inference findings to the Candidate shape (vuln_type/line/severity/rationale).

    The HF reply is not schema-constrained, so every field is checked. An entry is dropped if it
    is not a dict, lacks a non-empty ``vuln_type``, or has a ``line`` that cannot be converted
    to an int (including non-finite numbers). An unknown severity becomes ``"medium"``, and a
    null or missing rationale becomes ``""``.

    Args:
        findings: the ``"findings"`` value of the parsed reply, expected to be a list. Any other
            value (None, a dict, a number) yields no findings.
        max_findings: optional cap on how many are returned; falsy (None/0) means no cap.

    Returns:
        A list of dicts with keys ``vuln_type``, ``line`` (int), ``severity`` and ``rationale``.
    """
    if not isinstance(findings, (list, tuple)):
        return []
    out = []
    for f in findings:
        if not isinstance(f, dict):
            continue
        raw_vt = f.get("vuln_type")
        if raw_vt is None:  # str(None) == "None" would otherwise pass as a vulnerability type
            continue
        vt = str(raw_vt).strip()
        if not vt:
            continue
        try:
            # "abc"/NaN → ValueError, list/dict → TypeError, Infinity (json.loads accepts it)
            # → OverflowError
            ln = int(f["line"])
        except (KeyError, TypeError, ValueError, OverflowError):
            continue
        sev = str(f.get("severity") or "medium").strip().lower()
        if sev not in _SEV_SET:
            sev = "medium"
        rationale = f.get("rationale")
        out.append({"vuln_type": vt, "line": ln, "severity": sev,
                    "rationale": "" if rationale is None else str(rationale).strip()})
    return out[:max_findings] if max_findings else out
def _clean_verdict(data):
    """Normalize a parsed refuter reply to the Verdict shape (exploitable/reasoning/poc).

    Missing or non-dict input yields a not-exploitable verdict with empty text. A string
    ``exploitable`` (unconstrained replies sometimes say ``"false"``) is read by value, not by
    truthiness, because ``bool("false")`` is True and would turn a refutation into a false alarm.
    An unrecognised string counts as not exploitable. Null or missing text fields become ``""``.
    """
    if not isinstance(data, dict):
        data = {}
    flag = data.get("exploitable")
    if isinstance(flag, str):
        exploitable = flag.strip().lower() in {"true", "yes", "y", "1"}
    else:
        exploitable = bool(flag)
    return {"exploitable": exploitable,
            "reasoning": str(data.get("reasoning") or "").strip(),
            "poc": str(data.get("poc") or "").strip()}
def chat(tok, system, user):
    """Render a system + user exchange as an untokenized prompt via the tokenizer's chat template.

    The generation prompt is appended so the model continues as the assistant.
    """
    conversation = [
        {"role": "system", "content": system},
        {"role": "user", "content": user},
    ]
    return tok.apply_chat_template(conversation, tokenize=False, add_generation_prompt=True)
def detect(model, tok, code, lang, max_findings=5):
    """Stage 1 — list candidate vulnerabilities for one piece of code (Outlines path).

    Generation is constrained to the ``Report`` schema with a 512-new-token budget. Unlike
    ``detect_hf``, the code is sent without line-number prefixes. Returns at most
    ``max_findings`` entries of the report's ``findings`` list, or ``[]`` if the output cannot
    be decoded (e.g. cut off mid-JSON on a large/complex file).
    """
    request = (f"Language: {lang}\nCode:\n```\n{code}\n```\n"
               f"List up to {max_findings} candidate vulnerabilities.")
    prompt = chat(tok, DETECT_SYS, request)
    try:
        raw = model(prompt, output_type=Report, max_new_tokens=512)
        report = json.loads(raw)
    except ValueError:  # json.JSONDecodeError subclasses ValueError
        # truncated / undecodable output → nothing reliable to report
        return []
    findings = report.get("findings") or []
    return findings[:max_findings]
def detect_hf(client, model_id, code, lang, max_findings=None, max_tokens=2048):
    """Stage 1 via HF Inference (a large model, e.g. Qwen3-Coder-480B) — an off-GPU network call.

    Sends line-numbered code and asks for ALL candidates. There is no cap by default, because
    detector recall is what matters: the source→sink distance is what defeats the small model,
    and stage-2 refutation kills false positives. Prefers a strict json_schema response_format.
    If the provider rejects it, or the reply is not a JSON object, it retries once with a
    strict-JSON instruction and the lenient parse.

    Args:
        client: chat-completions client (``client.chat.completions.create``).
        model_id: model name passed to the provider.
        code: source text to scan.
        lang: language label shown to the model.
        max_findings: optional cap on returned findings (falsy = no cap).
        max_tokens: completion budget for each request (default 2048, as before).

    Returns:
        ``(findings, total_tokens)``. ``findings`` is normalized by ``_clean``.
        ``total_tokens`` sums the usage reported for every request made, for the caller's cost
        logging. If both attempts fail, findings is ``[]``; a warning is logged so that an
        inference failure can be told apart from a clean file.
    """
    import logging  # function-local so this block needs no new module-level import
    log = logging.getLogger(__name__)

    numbered = number_lines(code)
    user = (f"Language: {lang}\nCode (each line is prefixed with `N| ` where N is its 1-based line "
            f"number):\n```\n{numbered}\n```\nList ALL candidate security vulnerabilities you find. "
            f"Cite the 1-based line number from the prefix. Do not limit the number of findings.")
    msgs = [{"role": "system", "content": DETECT_SYS}, {"role": "user", "content": user}]
    tokens = 0
    # 1) preferred: constrained to the schema (when the provider supports json_schema)
    try:
        r = client.chat.completions.create(
            model=model_id, messages=msgs, temperature=0, max_tokens=max_tokens,
            response_format={"type": "json_schema",
                             "json_schema": {"name": "Report", "schema": REPORT_JSON_SCHEMA,
                                             "strict": True}},
        )
        tokens += _usage_tokens(r)
        data = _parse_json(r.choices[0].message.content)
        # a non-object reply (None, bare list…) is handled like a parse failure → retry below
        if isinstance(data, dict):
            return _clean(data.get("findings"), max_findings), tokens
    except Exception:
        # provider may not support response_format → fall through to the strict-JSON retry
        log.debug("detect_hf(%s): json_schema attempt failed; retrying without it",
                  model_id, exc_info=True)
    # 2) fallback: strict-JSON instruction, no response_format, one retry
    msgs2 = msgs + [{"role": "system", "content":
                     'Return ONLY a JSON object {"findings":[{"vuln_type":string,"line":int,'
                     '"severity":"low|medium|high|critical","rationale":string}]}. No prose, no markdown.'}]
    try:
        r = client.chat.completions.create(model=model_id, messages=msgs2, temperature=0,
                                           max_tokens=max_tokens)
        tokens += _usage_tokens(r)
        data = _parse_json(r.choices[0].message.content)
        if not isinstance(data, dict):
            log.warning("detect_hf(%s): unparseable detector reply; reporting no findings", model_id)
            return [], tokens
        return _clean(data.get("findings"), max_findings), tokens
    except Exception:
        # best-effort by design (a scan must not crash), but leave a trace
        log.warning("detect_hf(%s): detection failed; reporting no findings", model_id,
                    exc_info=True)
        return [], tokens
def refute(model, tok, code, candidate):
    """Stage 2 — adversarially verify ONE candidate (the calibrated, reused step; Outlines path).

    Generation is constrained to the ``Verdict`` schema with a 300-new-token budget, and the
    decoded verdict dict is returned. An undecodable verdict is conservatively reported as not
    exploitable, so a parse failure never raises a false alarm.
    """
    vuln, line, rationale = candidate["vuln_type"], candidate["line"], candidate["rationale"]
    # NOTE(review): line and rationale are joined with no separator ("at line 12User input…");
    # kept byte-identical because this prompt is the calibrated step shared with refute_hf.
    question = (f"Code:\n```\n{code}\n```\nClaimed vulnerability: {vuln} "
                f"at line {line}{rationale}\nIs it really exploitable?")
    prompt = chat(tok, REFUTE_SYS, question)
    try:
        raw = model(prompt, output_type=Verdict, max_new_tokens=300)
        return json.loads(raw)
    except ValueError:  # json.JSONDecodeError subclasses ValueError
        return {"exploitable": False, "reasoning": "verification inconclusive (parse error)", "poc": ""}
def refute_hf(client, model_id, code, candidate, max_tokens=512):
    """Stage 2 via HF Inference (small model) — an off-GPU network call.

    Because no GPU is used, a whole-repo scan needs no ZeroGPU reservation (no proxy-token
    expiry on long scans). Uses the same REFUTE_SYS + targeted context as the ZeroGPU path;
    `code` is the caller's tight window. Tries a strict json_schema response_format first, then
    retries once with a plain strict-JSON instruction.

    Args:
        client: chat-completions client (``client.chat.completions.create``).
        model_id: model name passed to the provider.
        code: the code window around the claimed sink.
        candidate: dict with ``vuln_type``, ``line`` and ``rationale``.
        max_tokens: completion budget for each request (default 512, as before).

    Returns:
        ``(verdict, total_tokens)``. The verdict is always Verdict-shaped. When no verdict is
        obtained it is conservatively not exploitable. Its reasoning is
        "verification inconclusive (parse error)" (reply unparseable; same wording as `refute`)
        or "verification inconclusive (inference error)" (request failed).
    """
    import logging  # function-local so this block needs no new module-level import
    log = logging.getLogger(__name__)

    # NOTE(review): line and rationale are joined with no separator; kept identical to `refute`
    # so both paths send the same calibrated prompt.
    user = (f"Code:\n```\n{code}\n```\nClaimed vulnerability: {candidate['vuln_type']} "
            f"at line {candidate['line']}{candidate['rationale']}\nIs it really exploitable?")
    msgs = [{"role": "system", "content": REFUTE_SYS}, {"role": "user", "content": user}]
    tokens = 0
    # 1) preferred: constrained to the Verdict schema
    try:
        r = client.chat.completions.create(
            model=model_id, messages=msgs, temperature=0, max_tokens=max_tokens,
            response_format={"type": "json_schema",
                             "json_schema": {"name": "Verdict", "schema": VERDICT_JSON_SCHEMA,
                                             "strict": True}},
        )
        tokens += _usage_tokens(r)
        data = _parse_json(r.choices[0].message.content)
        if isinstance(data, dict):
            return _clean_verdict(data), tokens
    except Exception:
        # provider may not support response_format → fall through to the strict-JSON retry
        log.debug("refute_hf(%s): json_schema attempt failed; retrying without it",
                  model_id, exc_info=True)
    # 2) fallback: strict-JSON instruction, no response_format, one retry
    msgs2 = msgs + [{"role": "system", "content":
                     'Return ONLY {"exploitable":bool,"reasoning":string,"poc":string}. No prose, no markdown.'}]
    try:
        r = client.chat.completions.create(model=model_id, messages=msgs2, temperature=0,
                                           max_tokens=max_tokens)
        tokens += _usage_tokens(r)
        data = _parse_json(r.choices[0].message.content)
    except Exception:
        log.warning("refute_hf(%s): verification request failed", model_id, exc_info=True)
        return {"exploitable": False, "reasoning": "verification inconclusive (inference error)", "poc": ""}, tokens
    if not isinstance(data, dict):
        # previously this produced a verdict with empty reasoning; say why instead
        log.warning("refute_hf(%s): unparseable verdict reply", model_id)
        return {"exploitable": False, "reasoning": "verification inconclusive (parse error)", "poc": ""}, tokens
    return _clean_verdict(data), tokens
def _md_code(text):
    """Wrap *text* in a Markdown inline code span that survives embedded backticks.

    The fence is one backtick longer than the longest backtick run inside *text*. It is padded
    with a space when *text* starts or ends with a backtick (CommonMark strips one such space on
    each side). Text without backticks renders exactly like a plain single-backtick span.
    """
    text = str(text)
    longest = max((len(run) for run in re.findall(r"`+", text)), default=0)
    fence = "`" * (longest + 1)
    pad = " " if text.startswith("`") or text.endswith("`") else ""
    return f"{fence}{pad}{text}{pad}{fence}"


def render_snippet(out, dt):
    """Markdown report for the single-snippet path.

    Args:
        out: result dict. If ``out["verified"]`` is falsy, ``out["candidates"]`` holds raw
            detector findings. Otherwise ``out["results"]`` holds
            ``{"candidate": ..., "verdict": ...}`` pairs from the refutation stage.
        dt: elapsed seconds, shown with one decimal.

    Returns:
        The report as a single Markdown string.
    """
    if not out.get("verified"):
        cs = out.get("candidates", [])
        if not cs:
            return f"No candidate vulnerabilities found · {dt:.1f}s"
        lines = [f"### ⚠️ {len(cs)} candidates — **unverified** (raw detector) · {dt:.1f}s",
                 "_Raw guesses — some are false positives. Flip **Verify ON** to refute them._\n"]
        for c in cs:
            rationale = c.get("rationale") or ""
            # separator so the rationale does not run into the severity word ("highUser input…")
            tail = f" — {rationale}" if rationale else ""
            lines.append(f"- **{c['vuln_type']}** · line {c['line']} · "
                         f"{SEV.get(c['severity'], '')} {c['severity']}{tail}")
        return "\n".join(lines)

    res = out.get("results", [])
    real, fp = [], []
    for r in res:  # single pass: confirmed vs refuted
        (real if r["verdict"].get("exploitable") else fp).append(r)
    lines = [f"### ✅ Verified · **{len(real)} confirmed**, {len(fp)} refuted · {dt:.1f}s\n"]
    if real:
        lines.append(f"#### ✗ Confirmed ({len(real)})")
        for r in real:
            c, v = r["candidate"], r["verdict"]
            lines.append(f"- **{c['vuln_type']}** · line {c['line']} · "
                         f"{SEV.get(c['severity'], '')} {c['severity']}")
            # two-space indent nests these bullets under the finding (one space makes siblings)
            if v.get("poc"):
                # PoCs often contain backticks (e.g. shell injection), which break a `…` span
                lines.append(f"  - PoC: {_md_code(v['poc'])}")
            lines.append(f"  - {v.get('reasoning', '')}")
    if fp:
        lines.append(f"\n#### ✓ Refuted as false positive ({len(fp)})")
        for r in fp:
            c, v = r["candidate"], r["verdict"]
            lines.append(f"- ~~{c['vuln_type']} · line {c['line']}~~ — {v.get('reasoning', '')}")
    if not res:
        lines.append("_Nothing flagged._")
    return "\n".join(lines)