Spaces:
Sleeping
Sleeping
| """ | |
| Grading logic for the Code Review Environment. | |
| """ | |
| from __future__ import annotations | |
| import re | |
| from typing import List, Tuple, Set | |
| import sys | |
| import os | |
| sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) | |
| from models import Issue | |
# Ordinal rank of each severity label, from least (0) to most (3) severe.
# Derived from the ordered tuple so the ranking is stated exactly once.
_SEV_RANK = {
    label: rank
    for rank, label in enumerate(("low", "medium", "high", "critical"))
}

# Keyed by a ground-truth issue type; the value is the set of flagged issue
# types accepted as a match for it.  "bug" and "logic" are interchangeable;
# "security" and "performance" only match themselves.
_TYPE_COMPAT = {
    "bug": {"bug", "logic"},
    "logic": {"bug", "logic"},
    "security": {"security"},
    "performance": {"performance"},
}
def match_issue(flagged: Issue, gt: Issue, line_tolerance: int = 2) -> bool:
    """Return True when *flagged* counts as a hit on the ground-truth issue *gt*.

    All three conditions must hold:

    * both issues name the same file;
    * their line numbers differ by at most *line_tolerance*;
    * the flagged issue type is in the set ``_TYPE_COMPAT`` lists for the
      ground-truth type (a ground-truth type missing from that table only
      matches itself).
    """
    # Conditions are checked in this order and short-circuit, so the type
    # table is only consulted once file and line already agree.
    return (
        flagged.filename == gt.filename
        and abs(flagged.line_number - gt.line_number) <= line_tolerance
        and flagged.issue_type in _TYPE_COMPAT.get(gt.issue_type, {gt.issue_type})
    )
def grade_episode(
    flagged: List[Issue],
    ground_truth: List[Issue],
    line_tolerance: int = 2,
) -> float:
    """Score a finished episode as ``0.70 * F1 + 0.30 * severity_accuracy``.

    Each flagged issue is paired with the first ground-truth issue that has
    not been paired yet and that ``match_issue`` accepts under
    *line_tolerance*; a flagged issue with no such partner is a false
    positive.  Every pair also earns a severity score of 1.0 minus 0.34 per
    rank of difference in ``_SEV_RANK`` (never below 0.0; labels missing from
    the table are ranked 1).  ``severity_accuracy`` is the sum of those
    scores divided by the number of ground-truth issues, so missed issues
    lower it as well.

    With no ground truth the result is 1.0 when nothing was flagged and 0.0
    otherwise.  The returned value is clamped to [0.0, 1.0] and rounded to
    four decimal places.
    """
    if not ground_truth:
        return 0.0 if flagged else 1.0

    claimed: Set[int] = set()        # indices of ground-truth issues already paired
    closeness: List[float] = []      # one severity score per successful pairing
    true_pos = 0
    false_pos = 0

    for candidate in flagged:
        for idx, truth in enumerate(ground_truth):
            if idx not in claimed and match_issue(candidate, truth, line_tolerance):
                claimed.add(idx)
                true_pos += 1
                rank_gap = abs(
                    _SEV_RANK.get(candidate.severity, 1)
                    - _SEV_RANK.get(truth.severity, 1)
                )
                closeness.append(max(0.0, 1.0 - rank_gap * 0.34))
                break
        else:
            # The inner loop finished without a break: nothing matched.
            false_pos += 1

    false_neg = len(ground_truth) - len(claimed)

    flagged_total = true_pos + false_pos
    truth_total = true_pos + false_neg
    precision = true_pos / flagged_total if flagged_total > 0 else 0.0
    recall = true_pos / truth_total if truth_total > 0 else 0.0

    pr_sum = precision + recall
    f1 = (2 * precision * recall / pr_sum) if pr_sum > 0 else 0.0

    severity_accuracy = sum(closeness) / len(ground_truth) if closeness else 0.0

    combined = 0.70 * f1 + 0.30 * severity_accuracy
    return round(min(1.0, max(0.0, combined)), 4)
def compute_live_score(
    flagged: List[Issue],
    ground_truth: List[Issue],
    line_tolerance: int = 2,
) -> float:
    """Return the F1 score of *flagged* against *ground_truth* (no severity term).

    Intended for per-step feedback.  Each flagged issue is paired with the
    first not-yet-paired ground-truth issue that ``match_issue`` accepts;
    unpaired flagged issues are false positives and unpaired ground-truth
    issues are false negatives.

    Args:
        flagged: Issues reported so far.
        ground_truth: The reference issues.
        line_tolerance: Maximum line-number difference passed on to
            ``match_issue``.  Defaults to 2, so existing two-argument calls
            behave as before; pass the value used for ``grade_episode`` to
            keep live feedback consistent with the final grade.

    Returns:
        F1 rounded to four decimal places.  With no ground truth: 1.0 when
        nothing was flagged, otherwise 0.0.
    """
    if not ground_truth:
        return 1.0 if not flagged else 0.0

    tp = 0
    fp = 0
    matched: Set[int] = set()  # indices of ground-truth issues already paired
    for flag in flagged:
        for i, gt in enumerate(ground_truth):
            if i not in matched and match_issue(flag, gt, line_tolerance):
                tp += 1
                matched.add(i)
                break
        else:
            # No ground-truth issue accepted this flag.
            fp += 1

    fn = len(ground_truth) - len(matched)
    precision = tp / (tp + fp) if (tp + fp) > 0 else 0.0
    recall = tp / (tp + fn) if (tp + fn) > 0 else 0.0
    f1 = (2 * precision * recall / (precision + recall)) if (precision + recall) > 0 else 0.0
    return round(f1, 4)
# Heuristics used by run_keyword_baseline.  Each entry is a 5-tuple:
#
#   (regex, filename_hint, issue_type, severity, description)
#
# The regex is searched against one source line at a time.  A truthy
# filename_hint limits the pattern to files whose name contains that
# substring; every entry currently uses None (no restriction).  Order
# matters: only the first pattern that hits a given line is reported.
_PATTERNS = [
    # --- correctness -----------------------------------------------------
    (
        r"range\(len\(\w+\)\s*\+\s*1\)",
        None,
        "bug",
        "high",
        "Off-by-one error: range(len(x) + 1) iterates one past the end",
    ),
    (
        r"left,\s*right\s*=\s*0,\s*len\(",
        None,
        "bug",
        "medium",
        "Binary search upper bound should be len(arr) - 1",
    ),
    (
        r"counts\[word\]\s*=\s*0\b",
        None,
        "bug",
        "low",
        "Counter initialized to 0 instead of 1",
    ),
    # --- security --------------------------------------------------------
    (
        r'SECRET_KEY\s*=\s*["\']',
        None,
        "security",
        "high",
        "Hardcoded SECRET_KEY in source code",
    ),
    (
        r'PASSWORD\s*=\s*["\']',
        None,
        "security",
        "high",
        "Hardcoded password in source code",
    ),
    (
        r"f['\"].*SELECT.*\{",
        None,
        "security",
        "critical",
        "SQL injection via f-string query construction",
    ),
    (
        r"f['\"].*DELETE.*\{",
        None,
        "security",
        "critical",
        "SQL injection via f-string DELETE query",
    ),
    (
        r"render_template_string\(f['\"]",
        None,
        "security",
        "high",
        "XSS: unsanitized user input in render_template_string",
    ),
    (
        r"shell\s*=\s*True",
        None,
        "security",
        "critical",
        "Command injection risk: shell=True with user input",
    ),
    (
        r"hashlib\.md5\(",
        None,
        "security",
        "medium",
        "MD5 is cryptographically broken, use SHA-256 or HMAC-SHA256",
    ),
    (
        r"expected\s*==\s*\w+_hash",
        None,
        "security",
        "medium",
        "Timing attack: use hmac.compare_digest() for constant-time comparison",
    ),
    (
        r"password\s*=\s*models\.CharField",
        None,
        "security",
        "critical",
        "Plaintext password storage in database",
    ),
    (
        r"os\.path\.join\(['\"]\/",
        None,
        "security",
        "high",
        "Path traversal: os.path.join with absolute prefix doesn't prevent traversal",
    ),
    # --- performance and data modelling ------------------------------------
    (
        r"\.objects\.get\(id=item\.",
        None,
        "performance",
        "high",
        "N+1 query: database lookup inside a loop",
    ),
    (
        r"FloatField\(\)",
        None,
        "bug",
        "medium",
        "FloatField for monetary values causes precision errors, use DecimalField",
    ),
    (
        r"BinaryField\(\)",
        None,
        "security",
        "high",
        "BinaryField with pickled data is a deserialization vulnerability",
    ),
]
def run_keyword_baseline(task: dict) -> List[Issue]:
    """Scan a task's source files with the ``_PATTERNS`` regex heuristics.

    *task* is read through its ``"code_files"`` entry, a mapping of filename
    to source text (a missing key is treated as no files).  Every line of
    every file is tested against the patterns in list order.  A pattern with
    a truthy filename hint is skipped for files whose name does not contain
    the hint.

    At most one ``Issue`` is produced per (filename, line): the first pattern
    that hits wins.  Line numbers are 1-based.  Findings are returned in the
    order they were discovered.
    """
    reported: set = set()        # (filename, line number) pairs already emitted
    hits: List[Issue] = []

    for filename, code in task.get("code_files", {}).items():
        for lineno, text in enumerate(code.splitlines(), start=1):
            location = (filename, lineno)
            for regex, name_hint, kind, level, message in _PATTERNS:
                # Optional per-pattern restriction to certain files.
                if name_hint and name_hint not in filename:
                    continue
                # Skip misses, and lines an earlier pattern already reported.
                if not re.search(regex, text) or location in reported:
                    continue
                reported.add(location)
                hits.append(
                    Issue(
                        line_number=lineno,
                        filename=filename,
                        issue_type=kind,
                        severity=level,
                        description=message,
                    )
                )
    return hits