code-review-env / server /graders.py
codemaverick2
Code Review Environment OpenEnv hackathon submission
ff9fcbd
"""
Grading logic for the Code Review Environment.
"""
from __future__ import annotations
import re
from typing import List, Tuple, Set
import sys
import os
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from models import Issue
# Severity labels ordered from least to most severe; the rank difference
# between a flagged and a ground-truth severity drives severity credit.
_SEV_RANK = {"low": 0, "medium": 1, "high": 2, "critical": 3}

# Ground-truth issue_type -> flagged issue_types accepted as a match.
# "bug" and "logic" are interchangeable; any type not listed here must match
# exactly (see match_issue).
_TYPE_COMPAT = {
    "bug": {"bug", "logic"},
    "logic": {"bug", "logic"},
    "security": {"security"},
    "performance": {"performance"},
}

# Severity credit lost per rank of difference: 0 -> 1.0, 1 -> 0.66,
# 2 -> 0.32, 3 -> 0.0 (clamped at zero).
_SEV_PENALTY_PER_RANK = 0.34


def match_issue(flagged: Issue, gt: Issue, line_tolerance: int = 2) -> bool:
    """Return True if *flagged* may be credited as finding *gt*.

    A match needs the same filename, line numbers at most ``line_tolerance``
    apart, and a compatible issue type (see ``_TYPE_COMPAT``). Severity is
    not considered here; it only feeds the severity part of grade_episode.
    """
    if flagged.filename != gt.filename:
        return False
    if abs(flagged.line_number - gt.line_number) > line_tolerance:
        return False
    return flagged.issue_type in _TYPE_COMPAT.get(gt.issue_type, {gt.issue_type})


def _match_flags(
    flagged: List[Issue],
    ground_truth: List[Issue],
    line_tolerance: int,
) -> List[Tuple[int, int]]:
    """Pair flags with ground-truth issues one-to-one, maximising the pair count.

    Returns ``(flag_index, gt_index)`` pairs. Uses augmenting paths (Kuhn's
    algorithm), so how many flags get credited does not depend on the order
    of ``flagged``; each flag tries its closest-line candidates first.
    """
    candidates: List[List[int]] = []
    for flag in flagged:
        scored = sorted(
            (abs(flag.line_number - gt.line_number), i)
            for i, gt in enumerate(ground_truth)
            if match_issue(flag, gt, line_tolerance)
        )
        candidates.append([i for _, i in scored])

    # owner[g] is the flag index currently paired with ground_truth[g], or -1.
    owner: List[int] = [-1] * len(ground_truth)

    def _assign(f: int, visited: Set[int]) -> bool:
        # Pair flag f, re-routing an already-paired flag to another candidate
        # when that frees up a slot (an augmenting path).
        for g in candidates[f]:
            if g in visited:
                continue
            visited.add(g)
            if owner[g] == -1 or _assign(owner[g], visited):
                owner[g] = f
                return True
        return False

    for f in range(len(flagged)):
        _assign(f, set())

    return [(f, g) for g, f in enumerate(owner) if f != -1]


def _f1(tp: int, n_flagged: int, n_truth: int) -> float:
    """F1 from the true-positive count and the sizes of both issue lists."""
    precision = tp / n_flagged if n_flagged else 0.0
    recall = tp / n_truth if n_truth else 0.0
    if precision + recall == 0:
        return 0.0
    return 2 * precision * recall / (precision + recall)


def _severity_credit(flagged_severity: str, true_severity: str) -> float:
    """Credit in [0, 1] for a severity guess; unknown labels count as "medium"."""
    distance = abs(_SEV_RANK.get(flagged_severity, 1) - _SEV_RANK.get(true_severity, 1))
    return max(0.0, 1.0 - distance * _SEV_PENALTY_PER_RANK)


def grade_episode(
    flagged: List[Issue],
    ground_truth: List[Issue],
    line_tolerance: int = 2,
) -> float:
    """Compute a 0.0–1.0 score: 0.70 * F1 + 0.30 * severity_accuracy.

    Flags are paired one-to-one with ground-truth issues via a maximum
    matching, so the number of credited flags does not depend on the order
    in which issues were flagged. Unpaired flags are false positives;
    unpaired ground-truth issues are false negatives.

    severity_accuracy sums per-pair severity credit (1.0 for an exact
    severity, minus 0.34 per rank of difference, floored at 0) and divides
    by the number of *all* ground-truth issues, so missed issues lower it too.

    Returns 1.0 when there is nothing to find and nothing was flagged, 0.0
    when there is nothing to find but something was flagged. The result is
    clamped to [0, 1] and rounded to 4 decimal places.
    """
    if not ground_truth:
        return 1.0 if not flagged else 0.0

    pairs = _match_flags(flagged, ground_truth, line_tolerance)
    f1 = _f1(len(pairs), len(flagged), len(ground_truth))
    severity_accuracy = sum(
        _severity_credit(flagged[f].severity, ground_truth[g].severity)
        for f, g in pairs
    ) / len(ground_truth)

    final = 0.70 * f1 + 0.30 * severity_accuracy
    return round(min(1.0, max(0.0, final)), 4)


def compute_live_score(
    flagged: List[Issue],
    ground_truth: List[Issue],
    line_tolerance: int = 2,
) -> float:
    """F1-only score for per-step feedback (no severity bonus).

    Uses the same one-to-one matching as grade_episode, so the live score
    stays consistent with the F1 component of the final grade. Empty-input
    handling matches grade_episode; the result is rounded to 4 places.
    """
    if not ground_truth:
        return 1.0 if not flagged else 0.0
    pairs = _match_flags(flagged, ground_truth, line_tolerance)
    return round(_f1(len(pairs), len(flagged), len(ground_truth)), 4)
# Line-level regex heuristics used by run_keyword_baseline. Each entry is
# (pattern, filename_hint, issue_type, severity, description):
#   pattern        regex passed to re.search against a single source line
#   filename_hint  if truthy, the pattern is only tried on files whose name
#                  contains this substring (None = every file)
#   issue_type, severity, description
#                  copied verbatim into the reported Issue
# Patterns are case-sensitive and only see one line at a time, so this is a
# rough baseline rather than a detector. Order matters: the baseline reports
# at most one finding per line, taken from the first pattern that matches.
_PATTERNS = [
    # --- Logic bugs ---
    (r"range\(len\(\w+\)\s*\+\s*1\)", None, "bug", "high",
     "Off-by-one error: range(len(x) + 1) iterates one past the end"),
    (r"left,\s*right\s*=\s*0,\s*len\(", None, "bug", "medium",
     "Binary search upper bound should be len(arr) - 1"),
    (r"counts\[word\]\s*=\s*0\b", None, "bug", "low",
     "Counter initialized to 0 instead of 1"),
    # --- Hardcoded secrets (assignment of a string literal) ---
    (r'SECRET_KEY\s*=\s*["\']', None, "security", "high",
     "Hardcoded SECRET_KEY in source code"),
    (r'PASSWORD\s*=\s*["\']', None, "security", "high",
     "Hardcoded password in source code"),
    # --- Injection: f-string with an interpolation brace on the same line ---
    (r"f['\"].*SELECT.*\{", None, "security", "critical",
     "SQL injection via f-string query construction"),
    (r"f['\"].*DELETE.*\{", None, "security", "critical",
     "SQL injection via f-string DELETE query"),
    (r"render_template_string\(f['\"]", None, "security", "high",
     "XSS: unsanitized user input in render_template_string"),
    (r"shell\s*=\s*True", None, "security", "critical",
     "Command injection risk: shell=True with user input"),
    # --- Weak crypto / unsafe comparisons ---
    (r"hashlib\.md5\(", None, "security", "medium",
     "MD5 is cryptographically broken, use SHA-256 or HMAC-SHA256"),
    (r"expected\s*==\s*\w+_hash", None, "security", "medium",
     "Timing attack: use hmac.compare_digest() for constant-time comparison"),
    # --- Django-style model / ORM patterns ---
    (r"password\s*=\s*models\.CharField", None, "security", "critical",
     "Plaintext password storage in database"),
    (r"os\.path\.join\(['\"]\/", None, "security", "high",
     "Path traversal: os.path.join with absolute prefix doesn't prevent traversal"),
    (r"\.objects\.get\(id=item\.", None, "performance", "high",
     "N+1 query: database lookup inside a loop"),
    (r"FloatField\(\)", None, "bug", "medium",
     "FloatField for monetary values causes precision errors, use DecimalField"),
    (r"BinaryField\(\)", None, "security", "high",
     "BinaryField with pickled data is a deserialization vulnerability"),
]
def run_keyword_baseline(task: dict) -> List[Issue]:
    """Flag issues using the line-level regex heuristics in ``_PATTERNS``.

    Scans every file in ``task["code_files"]`` (filename -> source text)
    line by line. At most one Issue is emitted per line, taken from the
    first pattern in ``_PATTERNS`` that matches it. A missing or None
    ``"code_files"`` entry yields no findings.

    Line numbers are 1-based and only LF, CRLF and CR count as line breaks
    (Python's universal newlines). ``str.splitlines()`` would also break on
    form feeds, vertical tabs and Unicode line/paragraph separators, which
    shifts every later line number away from the ground-truth numbering.
    """
    compiled = [
        (re.compile(pattern), fname_hint, itype, severity, desc)
        for pattern, fname_hint, itype, severity, desc in _PATTERNS
    ]

    findings: List[Issue] = []
    for filename, code in (task.get("code_files") or {}).items():
        lines = re.split(r"\r\n|\r|\n", code)
        for line_idx, line in enumerate(lines, start=1):
            for regex, fname_hint, itype, severity, desc in compiled:
                # Optional filename filter
                if fname_hint and fname_hint not in filename:
                    continue
                if regex.search(line):
                    findings.append(Issue(
                        line_number=line_idx,
                        filename=filename,
                        issue_type=itype,
                        severity=severity,
                        description=desc,
                    ))
                    break  # one finding per line: the first matching pattern wins
    return findings