SecureCodeEnv / graders /attacks.py
vishaldhakad's picture
change in reward system to strict it between the 0-1
791664b
Raw
History Blame Contribute Delete
5 kB
"""
SecureCodeEnv - Attack Resistance Grader v4
All scores clamped to (0.001, 0.999).
"""
import os
import re

from sandbox.executor import safe_exec
from sandbox.payload_gen import get_payloads
from graders.clamp import clamp
def _is_seq(v):
    """Return True if *v* is a list or a tuple."""
    return isinstance(v, (list, tuple))


# Substring markers accepted as evidence of a bind parameter: PEP 249
# "format" (%s), "qmark" (?) and "pyformat" (%(name)s) styles, plus the
# markers this grader has always accepted.
_SQL_PLACEHOLDER_MARKERS = ("%s", "?", ":param", "%(", ":%")

# PEP 249 "named" (:name) and "numeric" (:1) placeholders. The lookbehind
# rejects a colon preceded by a word character or another colon, so
# PostgreSQL "::type" casts and "12:30"-style literals are not mistaken
# for placeholders.
_COLON_PLACEHOLDER_RE = re.compile(r"(?<![:\w]):\w+")


def _sql_safe(result, payload):
    """Judge whether query-builder output resists one injection payload.

    The result counts as safe only when it is a list/tuple of at least two
    items whose first item (the query text) uses a bind placeholder and does
    not contain the payload verbatim. The containment test ignores spaces
    and letter case.

    Args:
        result: Value returned by the graded function.
        payload: Injection string that was substituted into the call.

    Returns:
        True if the payload appears to have been bound as a parameter,
        False otherwise.
    """
    if not _is_seq(result) or len(result) < 2:
        return False
    query = str(result[0])
    has_placeholder = (
        any(marker in query for marker in _SQL_PLACEHOLDER_MARKERS)
        or _COLON_PLACEHOLDER_RE.search(query) is not None
    )
    if not has_placeholder:
        return False
    needle = payload.strip().replace(" ", "").lower()
    if not needle:
        # An empty needle is a substring of every string; with nothing to
        # inject, the placeholder check alone decides.
        return True
    haystack = query.lower().replace(" ", "")
    return needle not in haystack
def _path_safe(result, payload):
    """Judge whether path-resolver output resists one traversal payload.

    ``None`` and values that are not path-like count as a rejection (safe).
    A ``str``, ``bytes`` or ``os.PathLike`` result is decoded to text. The
    text must then be ``/var/uploads`` itself or lie under ``/var/uploads/``
    (case-sensitive prefix). It must also contain none of the traversal
    markers, which are compared case-insensitively.

    Args:
        result: Value returned by the graded function.
        payload: Traversal string that was passed in (unused; kept so every
            check shares the ``(result, payload)`` signature).

    Returns:
        True if the result is judged safe, False otherwise.
    """
    if result is None:
        return True
    if not isinstance(result, (str, bytes, os.PathLike)):
        return True
    # pathlib objects and bytes paths previously skipped every check.
    path = os.fsdecode(result)
    base = "/var/uploads"
    # Require a real directory boundary: a bare startswith(base) would also
    # accept sibling directories such as "/var/uploads_evil".
    if path != base and not path.startswith(base + "/"):
        return False
    lowered = path.lower()
    bad = ["/etc/", "/root/", "/proc/", "passwd", "shadow", "%2e", "%2f", ".."]
    return not any(marker in lowered for marker in bad)
def _upload_rejected_or_clean(result, payload):
    """Check upload-validator output against one traversal payload.

    Safe when the result is a dict whose ``ok`` is exactly ``False``, or
    whose ``safe_path`` (missing -> "") contains none of the traversal
    markers, compared case-insensitively.
    """
    if not isinstance(result, dict):
        return False
    if result.get("ok") is False:
        return True
    safe_path = str(result.get("safe_path", "")).lower()
    return all(marker not in safe_path
               for marker in ["..", "/etc/", "/root/", "%2e"])


def _returned_none(result, payload):
    """Safe only when the graded function returned exactly ``None``."""
    return result is None


def _returned_false(result, payload):
    """Safe only when the graded function returned exactly ``False``."""
    return result is False


def _html_is_clean(result, payload):
    """Safe when the result is a string free of every XSS marker.

    Markers are matched case-insensitively as plain substrings.
    """
    if not isinstance(result, str):
        return False
    lowered = result.lower()
    for marker in ["<script", "onerror", "javascript:", "<img ", "<svg", "onload"]:
        if marker in lowered:
            return False
    return True


# Per-task attack settings, keyed by task id:
#   fn_name           - function called in the submitted code
#   fn_args_template  - positional args; "{payload}" inside a string arg is
#                       replaced by each generated payload
#   attack_type       - payload family requested from get_payloads
#   check             - predicate(result, payload), truthy when blocked
#   exception_is_pass - whether a failed sandbox call counts as blocked
ATTACK_CONFIGS = {
    "medium_sql_query_builder": dict(
        fn_name="build_user_query",
        fn_args_template=["{payload}", "admin"],
        attack_type="injection",
        check=_sql_safe,
        exception_is_pass=False,
    ),
    "medium_file_path_handler": dict(
        fn_name="resolve_safe_path",
        fn_args_template=["/var/uploads", "{payload}"],
        attack_type="traversal",
        check=_path_safe,
        exception_is_pass=True,
    ),
    "hard_file_upload_handler": dict(
        fn_name="validate_upload",
        fn_args_template=["{payload}", "image/jpeg", 1000, "/uploads"],
        attack_type="traversal",
        check=_upload_rejected_or_clean,
        exception_is_pass=True,
    ),
    "hard_jwt_validator": dict(
        fn_name="validate_token",
        fn_args_template=["{payload}", "test_secret_key_for_grading"],
        attack_type="auth_bypass",
        check=_returned_none,
        exception_is_pass=True,
    ),
    "hard_auth_middleware": dict(
        fn_name="verify_csrf_token",
        fn_args_template=["{payload}", "legitimate_session_csrf_token"],
        attack_type="auth_bypass",
        check=_returned_false,
        exception_is_pass=True,
    ),
    "easy_input_sanitizer": dict(
        fn_name="sanitize_html",
        fn_args_template=["{payload}"],
        attack_type="xss",
        check=_html_is_clean,
        exception_is_pass=False,
    ),
}
def grade_attacks(code: str, task: dict, seed: int) -> dict:
    """Score how well a submission resists the attacks configured for its task.

    The task's ``ATTACK_CONFIGS`` entry selects the function to call, its
    argument template, the payload family and the per-result check. Eight
    seeded payloads are requested from ``get_payloads``. Each returned payload
    is substituted into the argument template, and the function is run once
    per payload in the sandbox.

    Args:
        code: Source code of the submission under test.
        task: Task description; only its ``"id"`` key is read.
        seed: Seed forwarded to ``get_payloads``.

    Returns:
        A dict with ``score`` (the blocked fraction passed through
        ``clamp``), ``blocked``, ``total``, ``attack_type`` and
        ``feedback``. Tasks without a config, or for which no payloads are
        generated, receive a neutral ``clamp(0.5)`` score.
    """
    task_id = task.get("id", "")
    config = ATTACK_CONFIGS.get(task_id)
    if not config:
        return {"score": clamp(0.5), "blocked": 0, "total": 0,
                "attack_type": "none",
                "feedback": "No attack grading defined — neutral score"}

    attack_type = config["attack_type"]
    payloads = get_payloads(attack_type, seed=seed, count=8)
    if not payloads:
        return {"score": clamp(0.5), "blocked": 0, "total": 0,
                "attack_type": attack_type,
                "feedback": "No payloads generated — neutral score"}

    fn_name = config["fn_name"]
    template = config["fn_args_template"]
    exception_is_pass = config.get("exception_is_pass", True)

    blocked = 0
    for payload in payloads:
        # Substitute the payload into every string argument; non-string
        # arguments are passed through unchanged.
        args = [a.replace("{payload}", payload) if isinstance(a, str) else a
                for a in template]
        result = safe_exec(code, args, function_name=fn_name, timeout=3)
        if not result["ok"]:
            # The sandboxed call did not complete normally (presumably an
            # exception or timeout -- confirm against sandbox.executor).
            # Tasks whose intended defence is to raise count this as blocked.
            if exception_is_pass:
                blocked += 1
            continue
        try:
            if config["check"](result.get("output"), payload):
                blocked += 1
        except Exception:
            # Deliberate best-effort: a check that cannot interpret the
            # output counts as "not blocked" instead of aborting grading.
            pass

    total = len(payloads)
    raw = blocked / total
    return {
        "score": clamp(raw),
        "blocked": blocked,
        "total": total,
        "attack_type": attack_type,
        "feedback": _feedback(raw, attack_type),
    }
def _feedback(score: float, attack_type: str) -> str:
    """Return a one-line, human-readable summary of attack resistance.

    Args:
        score: Fraction of payloads blocked (``grade_attacks`` passes
            ``blocked / total``).
        attack_type: Attack family key such as ``"injection"``; keys without
            a friendly name are shown verbatim.

    Returns:
        A message for the tier chosen by the thresholds 0.875 / 0.625 /
        0.375 (presumably 7/8, 5/8 and 3/8 of the default eight payloads),
        including the score as a whole percentage.
    """
    names = {"injection": "SQL injection", "traversal": "path traversal",
             "auth_bypass": "auth bypass", "xss": "XSS"}
    name = names.get(attack_type, attack_type)
    # The em dashes below replace mis-encoded "β€”" sequences (UTF-8 "—"
    # decoded with a Greek code page) that garbled the user-facing text.
    if score >= 0.875:
        return f"Excellent — {name} attacks blocked ({score:.0%})"
    if score >= 0.625:
        return f"Good — most {name} attacks blocked ({score:.0%})"
    if score >= 0.375:
        return f"Partial — {score:.0%} of {name} attacks blocked"
    return f"Vulnerable — {score:.0%} of {name} attacks blocked"