SecureCodeEnv / graders /correctness.py
vishaldhakad's picture
change in reward system to strict it between the 0-1
791664b
Raw
History Blame Contribute Delete
5.96 kB
"""
SecureCodeEnv - Correctness Grader v4
Weight: 25% of total reward.
All scores clamped to (0.001, 0.999).
"""
from sandbox.executor import safe_exec
from graders.clamp import clamp
def _is_seq(v):
return isinstance(v, (list, tuple))
def grade_correctness(code: str, task: dict) -> dict:
    """Grade ``code`` against every test case listed in ``task``.

    Args:
        code: Source text of the submission under test.
        task: Task definition; its ``"test_cases"`` entry (a list of
            test-case dicts) drives the grading.

    Returns:
        A dict with the keys ``score`` (pass ratio passed through
        ``clamp``), ``passed``, ``total``, ``details`` (one result dict
        per test case, in order) and ``feedback`` (human-readable summary).
        When the task has no test cases the score is ``clamp(0.5)`` and
        the counts are zero.
    """
    cases = task.get("test_cases", [])
    if not cases:
        # Nothing to run: report a neutral score rather than 0 or 1.
        return {
            "score": clamp(0.5),
            "passed": 0,
            "total": 0,
            "details": [],
            "feedback": "No test cases defined",
        }

    outcomes = [_run_test_case(code, case) for case in cases]
    n_total = len(cases)
    n_passed = sum(1 for outcome in outcomes if outcome["passed"])
    ratio = n_passed / n_total

    return {
        "score": clamp(ratio),
        "passed": n_passed,
        "total": n_total,
        "details": outcomes,
        "feedback": _feedback(ratio, n_passed, n_total),
    }
def _run_test_case(code: str, tc: dict) -> dict:
    """Run a single test case against ``code`` and report the outcome.

    Test cases carrying an ``"fn_class"`` key are delegated to
    ``_run_class_test``. Otherwise the function named by ``tc["fn"]``
    (default ``"solution"``) is executed through ``safe_exec`` with
    ``tc["input"]`` as its arguments, and the first assertion key found
    in ``tc`` decides the verdict. Assertion keys are checked in this
    order: ``sql_injection_check``, ``expected_not_none``, ``expected``,
    ``expected_type`` (optionally with ``expected_len``),
    ``expected_contains``, ``expected_not_contains``,
    ``expected_min_len``, ``expected_max_len``, ``expected_ok``.

    Args:
        code: Source text of the submission under test.
        tc: Test-case definition.

    Returns:
        A dict that always contains ``passed`` and ``description``, plus
        optional diagnostic keys (``error``, ``note``, ``got``,
        ``expected``, ``got_type``) depending on the assertion used.
    """
    if "fn_class" in tc:
        return _run_class_test(code, tc)

    fn_name = tc.get("fn", "solution")
    inputs = tc.get("input", [])
    desc = tc.get("description", "")
    expected_exc = tc.get("expected_exception")

    exec_result = safe_exec(code, inputs, function_name=fn_name, timeout=5)

    if not exec_result["ok"]:
        # Coerce to str so a None/non-string error payload cannot crash
        # the grader on .lower() or slicing.
        error_str = str(exec_result.get("error") or "")
        exc_type = str(exec_result.get("type") or "")
        if expected_exc:
            wanted = expected_exc.lower()
            if (exc_type == expected_exc
                    or wanted in error_str.lower()
                    or wanted in exc_type.lower()):
                return {"passed": True, "description": desc,
                        "note": f"Expected {expected_exc} raised"}
        return {"passed": False, "description": desc,
                "error": error_str[:200]}

    output = exec_result.get("output")

    # SQL injection parameterization check: the output must be a
    # (query, params, ...) sequence whose query text uses a placeholder
    # and does not embed the raw payload.
    if tc.get("sql_injection_check"):
        if not _is_seq(output) or len(output) < 2:
            return {"passed": False, "description": desc,
                    "error": "Not a 2-element sequence"}
        query = str(output[0])
        payload = str(inputs[0] if inputs else "").strip()
        has_ph = any(p in query for p in ("%s", "?", ":param", "%(username"))
        # An empty payload is a substring of every string, so it can never
        # count as "embedded" in the query.
        safe = not payload or payload not in query
        return {"passed": has_ph and safe, "description": desc,
                "note": f"placeholder={has_ph} payload_safe={safe}"}

    # Not-None
    if "expected_not_none" in tc:
        return {"passed": output is not None, "description": desc}

    # Equality
    if "expected" in tc:
        return {"passed": output == tc["expected"], "description": desc,
                "got": output, "expected": tc["expected"]}

    # Type check (JSON converts tuple→list, so the two are interchangeable)
    if "expected_type" in tc:
        tname = tc["expected_type"]
        atype = type(output).__name__
        ok = atype == tname or {atype, tname} == {"list", "tuple"}
        if ok and "expected_len" in tc:
            ok = hasattr(output, "__len__") and len(output) == tc["expected_len"]
        return {"passed": ok, "description": desc, "got_type": atype}

    # Contains
    if "expected_contains" in tc:
        return {"passed": tc["expected_contains"] in str(output),
                "description": desc}

    # Not-contains (single forbidden string or a list of them)
    if "expected_not_contains" in tc:
        forbidden = tc["expected_not_contains"]
        rendered = str(output)
        if isinstance(forbidden, list):
            ok = not any(f in rendered for f in forbidden)
        else:
            ok = forbidden not in rendered
        return {"passed": ok, "description": desc, "got": rendered[:100]}

    # Min length (of the string form of the output)
    if "expected_min_len" in tc:
        return {"passed": (output is not None
                           and len(str(output)) >= tc["expected_min_len"]),
                "description": desc}

    # Max length (of the string form of the output)
    if "expected_max_len" in tc:
        return {"passed": (output is not None
                           and len(str(output)) <= tc["expected_max_len"]),
                "description": desc}

    # Ok-flag (dict with "ok" key)
    if "expected_ok" in tc:
        return {"passed": (isinstance(output, dict)
                           and output.get("ok") == tc["expected_ok"]),
                "description": desc}

    # The only expectation was an exception, yet execution succeeded:
    # that is a failure, not an assertion-free pass.
    if expected_exc:
        return {"passed": False, "description": desc,
                "error": f"Expected {expected_exc} but no exception was raised"}

    return {"passed": True, "description": desc, "note": "No assertion"}
def _run_class_test(code: str, tc: dict) -> dict:
    """Run a class-based test case by wrapping ``code`` in a small harness.

    The harness appends a ``run_task`` function to the submitted code.
    ``run_task`` instantiates ``tc["fn_class"]`` with ``tc["init_args"]``
    and then either:

    * ``method == "is_allowed_multi"``: calls ``is_allowed(inputs[0])``
      three times and returns the last result;
    * ``method == "independent_clients"``: returns whether
      ``is_allowed("client_a")`` and ``is_allowed("client_b")`` are both
      ``True``;
    * otherwise: calls the named method with ``inputs`` unpacked.

    Args:
        code: Source text of the submission under test.
        tc: Test-case definition (keys read: ``fn_class``, ``init_args``,
            ``method``, ``input``, ``description``, ``expected``,
            ``expected_last``).

    Returns:
        A dict with ``passed`` and ``description``, plus ``error`` when
        execution failed.
    """
    class_name = tc.get("fn_class", "Solution")
    init_args = tc.get("init_args", [])
    method = tc.get("method", "is_allowed")
    inputs = tc.get("input", [])
    desc = tc.get("description", "")

    # The template lines must start at column 0 with a properly indented
    # ``run_task`` body, otherwise the generated program does not compile.
    # NOTE(review): ``class_name`` is interpolated verbatim into executed
    # code; it is assumed to come from trusted task definitions.
    harness = f"""
{code}
def run_task(args):
    init_args = args[0]; method = args[1]; inputs = args[2]
    obj = {class_name}(*init_args)
    if method == "is_allowed_multi":
        result = None
        for _ in range(3):
            result = obj.is_allowed(inputs[0])
        return result
    if method == "independent_clients":
        return obj.is_allowed("client_a") == obj.is_allowed("client_b") == True
    return getattr(obj, method)(*inputs)
"""
    result = safe_exec(harness, [[init_args, method, inputs]],
                       function_name="run_task", timeout=5)
    if not result["ok"]:
        # Coerce to str so a None/non-string error cannot break the slice.
        error_str = str(result.get("error") or "")
        return {"passed": False, "description": desc, "error": error_str[:200]}

    output = result.get("output")
    if "expected" in tc:
        return {"passed": output == tc["expected"], "description": desc}
    if "expected_last" in tc:
        return {"passed": output == tc["expected_last"], "description": desc}
    # No assertion key present: successful execution counts as a pass.
    return {"passed": True, "description": desc}
def _feedback(score: float, passed: int, total: int) -> str:
if score >= 0.9: return f"Excellent — {passed}/{total} tests passed"
elif score >= 0.7: return f"Good — {passed}/{total} tests passed"
elif score >= 0.5: return f"Partial — {passed}/{total} tests passed"
else: return f"Poor — {passed}/{total} tests passed"