SecureCodeEnv / graders /static_analysis.py
vishaldhakad's picture
Change the reward system to strictly bound scores between 0 and 1
791664b
Raw
History Blame Contribute Delete
6.28 kB
"""
SecureCodeEnv - Static Analysis Grader v4
All scores clamped to (0.001, 0.999).
"""
import subprocess, json, tempfile, os, re
from graders.clamp import clamp
def grade_static_analysis(code: str, task: dict) -> dict:
    """Grade ``code`` for security issues using Bandit plus rule-based checks.

    The blended score is 60% Bandit score + 40% custom-check score. When any
    task-specific hard requirement is violated, the custom part is dropped and
    the score becomes ``bandit_score * 0.4`` (never above 0.40).

    Args:
        code: source text of the submission.
        task: task description; its ``"id"`` selects task-specific rules.

    Returns:
        dict with ``score``, ``bandit_score`` and ``ast_score`` (each passed
        through ``clamp``), the ``hard_fail`` flag, at most 10 ``issues`` and a
        one-line ``feedback`` string.
    """
    bandit_result = _run_bandit(code)
    custom_result = _run_custom_checks(code, task)
    failed_hard = custom_result.get("hard_fail", False)
    b_score = bandit_result["score"]

    if not failed_hard:
        raw = (b_score * 0.60) + (custom_result["score"] * 0.40)
    else:
        # A hard-requirement violation caps the result regardless of Bandit.
        raw = min(b_score * 0.4, 0.40)

    combined = bandit_result.get("issues", []) + custom_result.get("issues", [])
    return {
        "score": clamp(raw),
        "bandit_score": clamp(b_score),
        "ast_score": clamp(custom_result["score"]),
        "hard_fail": failed_hard,
        "issues": combined[:10],
        "feedback": _feedback(raw, combined, failed_hard),
    }
def _run_bandit(code: str) -> dict:
tmp = None
try:
with tempfile.NamedTemporaryFile(mode="w", suffix=".py",
delete=False, prefix="sce_ban_") as f:
f.write(code); tmp = f.name
res = subprocess.run(
["bandit", "-r", tmp, "-f", "json", "-q", "--exit-zero"],
capture_output=True, text=True, timeout=15)
data = json.loads(res.stdout or '{"results":[]}')
issues = data.get("results", [])
penalty = sum(
0.40 if i.get("issue_severity") == "HIGH"
else 0.20 if i.get("issue_severity") == "MEDIUM"
else 0.05
for i in issues)
return {
"score": max(0.0, 1.0 - penalty),
"issues": [{"severity": i.get("issue_severity"),
"text": i.get("issue_text", "")[:100],
"line": i.get("line_number")} for i in issues[:5]],
}
except FileNotFoundError:
return {"score": 0.75, "issues": [], "note": "bandit not installed"}
except Exception as e:
return {"score": 0.75, "issues": [], "note": str(e)[:40]}
finally:
if tmp and os.path.exists(tmp):
try: os.unlink(tmp)
except OSError: pass
# Task-specific rules, keyed by task id. Violating any of them is reported with
# severity CRITICAL and marks the submission as a hard fail.
# Rule types (interpreted by _violated):
#   forbidden_any     - substring match; violated if any pattern is present
#   required_any      - substring match; violated if no pattern is present
#   forbidden_pattern - regex, searched with re.IGNORECASE | re.DOTALL
HARD_REQUIREMENTS = {
    "easy_password_validator": [
        {"type": "forbidden_any",
         "patterns": ["hashlib.md5", "hashlib.sha1", "hashlib.sha256", "md5(", "sha1(", "sha256("],
         "message": "Weak hash — must use bcrypt"},
        {"type": "required_any", "patterns": ["bcrypt"],
         "message": "bcrypt not imported"},
    ],
    "easy_token_generator": [
        # Any Mersenne-Twister based helper from the random module is unsuitable
        # for tokens; random.choices() is the most common offender.
        {"type": "forbidden_any",
         "patterns": ["random.random(", "random.randint(", "random.choice(", "random.choices(",
                      "random.randrange(", "random.getrandbits(", "random.seed("],
         "message": "Weak PRNG — must use secrets"},
        {"type": "required_any", "patterns": ["secrets"],
         "message": "secrets module not imported"},
    ],
    "medium_sql_query_builder": [
        # An f-string that contains SELECT and an interpolation. The match is
        # confined to one line, or to the body of a single triple-quoted
        # f-string, so an unrelated f-string earlier in the file cannot combine
        # with a later (parameterized) SELECT and some unrelated brace.
        {"type": "forbidden_pattern",
         "regex": (r'\b(?:rf|fr|f)(?:"""|\'\'\')(?:(?!"""|\'\'\').)*?SELECT(?:(?!"""|\'\'\').)*?\{'
                   r'|\b(?:rf|fr|f)["\'][^\n]*SELECT[^\n]*\{'),
         "message": "f-string SQL — injection vulnerability"},
        {"type": "forbidden_pattern",
         "regex": r'["\']\s*\+\s*(username|role|user_input|query)',
         "message": "String concat SQL — injection vulnerability"},
    ],
    "medium_file_path_handler": [
        {"type": "required_any", "patterns": ["resolve()", "resolve(", "realpath"],
         "message": "Must use Path.resolve() or realpath"},
    ],
    "hard_jwt_validator": [
        # Signature/expiry verification switched off (either quote style, any
        # spacing, dict literal or keyword form), or "none" listed anywhere in
        # the allowed algorithms.
        {"type": "forbidden_pattern",
         "regex": (r'verify_(?:signature|exp)["\']?\s*[:=]\s*False'
                   r'|algorithms\s*=\s*\[[^\]]*["\']none["\']'),
         "message": "JWT verification disabled"},
        {"type": "required_any", "patterns": ["algorithms="],
         "message": "algorithms= not specified"},
    ],
    "hard_auth_middleware": [
        {"type": "required_any", "patterns": ["hmac.compare_digest"],
         "message": "hmac.compare_digest not used"},
    ],
    "easy_input_sanitizer": [
        # Lookbehind keeps the safe ast.literal_eval(...) from being flagged.
        {"type": "forbidden_pattern", "regex": r"(?<!\w)(?:eval|exec)\s*\(",
         "message": "eval/exec usage"},
    ],
    "hard_file_upload_handler": [
        {"type": "required_any", "patterns": ["uuid"],
         "message": "uuid not used"},
    ],
}
# Checks applied to every submission regardless of task. Each entry carries the
# severity reported when it is violated.
GENERIC_CHECKS = [
    # Negative lookbehind so the safe ast.literal_eval(...) is not flagged,
    # while eval(...)/exec(...) with optional whitespace before "(" are.
    {"type": "forbidden_pattern", "regex": r"(?<!\w)(?:eval|exec)\s*\(",
     "message": "eval/exec usage", "severity": "HIGH"},
    # Tolerate spacing around "=" (e.g. "shell = True").
    {"type": "forbidden_pattern", "regex": r"\bshell\s*=\s*True\b",
     "message": "shell=True", "severity": "HIGH"},
    {"type": "forbidden_any", "patterns": ["pickle.loads", "pickle.load"],
     "message": "Unsafe pickle", "severity": "HIGH"},
    {"type": "forbidden_any", "patterns": ["hashlib.md5", "hashlib.sha1"],
     "message": "Weak hash", "severity": "HIGH"},
]
def _run_custom_checks(code: str, task: dict) -> dict:
    """Evaluate the rule-based checks against ``code``.

    Every rule in ``GENERIC_CHECKS`` is applied, followed by the rules listed
    in ``HARD_REQUIREMENTS`` under ``task["id"]``. A violated generic rule is
    reported with its own severity (``"MEDIUM"`` if none is given); a violated
    task rule is reported as ``"CRITICAL"`` and sets ``hard_fail``.

    Returns:
        dict with ``score`` (fraction of rules passed), ``issues`` (one entry
        per violated rule, in evaluation order) and ``hard_fail`` (bool).
    """
    # Each plan entry: (rule, severity if violated, is a hard requirement)
    plan = [(rule, rule.get("severity", "MEDIUM"), False) for rule in GENERIC_CHECKS]
    plan += [(rule, "CRITICAL", True)
             for rule in HARD_REQUIREMENTS.get(task.get("id", ""), [])]

    issues = []
    hard_fail = False
    for rule, severity, is_hard in plan:
        if not _violated(code, rule):
            continue
        hard_fail = hard_fail or is_hard
        issues.append({"check": rule["message"], "severity": severity,
                       "message": rule["message"]})

    passed = len(plan) - len(issues)
    return {"score": passed / max(len(plan), 1), "issues": issues, "hard_fail": hard_fail}
def _violated(code: str, req: dict) -> bool:
t = req.get("type", "")
if t == "forbidden_any":
return any(p in code for p in req.get("patterns", []))
if t == "required_any":
return not any(p in code for p in req.get("patterns", []))
if t == "forbidden_pattern":
return bool(re.search(req.get("regex","NOMATCH"), code, re.IGNORECASE | re.DOTALL))
return False
def _feedback(score: float, issues: list, hard_fail: bool) -> str:
if hard_fail:
return f"CRITICAL: {'; '.join(i['message'] for i in issues if i.get('severity')=='CRITICAL')[:120]}"
if score >= 0.9: return "Clean β€” no significant security issues"
high = sum(1 for i in issues if i.get("severity") == "HIGH")
return f"{high} HIGH severity issue(s)" if high else f"Some issues (score: {score:.2f})"