Spaces:
Running
Running
Fix Phase 2 OpenEnv validation traps: add grader paths to openenv.yaml and safe parameterless defaults
699f953 | # server/graders/dependency_grader.py | |
| # Grader for PyTorch Migration Time-Machine tasks (dep_easy, dep_medium, dep_hard). | |
| # | |
| # FIX SUMMARY: | |
| # 1. _score_flag: F1 was too loose — model could name extra packages and still score high | |
| # FIX: Added precision penalty so naming extra/wrong packages hurts | |
| # 2. _score_resolve: bonus of 0.15 for all-correct inflated scores to 0.99 | |
| # FIX: Removed bonus, tightened cross-constraint checking | |
| # 3. _score_migrate: fix_quality was too generous (0.6 partial credit) | |
| # FIX: Lowered partial credit to 0.25, required more precise token matching | |
| from typing import Dict, Any | |
| from .base_grader import grade_dynamic, safe_score | |
| try: | |
| from packaging.version import Version | |
| from packaging.specifiers import SpecifierSet | |
| _HAS_PACKAGING = True | |
| except ImportError: | |
| _HAS_PACKAGING = False | |
# Action types this grader can score; handed to grade_dynamic by grade().
VALID_ACTIONS = [
    'flag_outdated',
    'resolve_conflict',
    'migrate_api',
    'validate_tree',
]
# Forbidden entries handed to grade_dynamic by grade(); none for this grader.
FORBIDDEN = []
| def _normalize_ver(v: str) -> str: | |
| parts = str(v).strip().split('.') | |
| while len(parts) < 3: | |
| parts.append('0') | |
| return '.'.join(parts[:3]) | |
def _parse_version_tuple(v: str) -> tuple:
    """Convert a version string into a ``(major, minor, patch)`` int tuple.

    The value is first normalized to three dotted components. If any
    component is not a plain integer (for example ``"0rc1"``), the whole
    version collapses to ``(0, 0, 0)`` instead of raising.
    """
    try:
        components = _normalize_ver(v).split('.')[:3]
        return tuple(map(int, components))
    except (ValueError, AttributeError):
        return (0, 0, 0)
def _spec_equals(ver: tuple, spec: str) -> bool:
    """Return True when *ver* equals *spec*, honouring a trailing ``.*`` wildcard.

    ``"1.2.*"`` matches any version whose leading components are ``(1, 2)``.
    Without a wildcard the full three-component tuples must be equal.
    """
    spec = spec.strip()
    if spec.endswith('.*'):
        prefix = spec[:-2]
        # Version tuples hold three components, so compare at most three.
        width = min(len(prefix.split('.')), 3)
        return ver[:width] == _parse_version_tuple(prefix)[:width]
    return ver == _parse_version_tuple(spec)


def _simple_version_check(ver_str: str, constraint: str) -> bool:
    """Check *ver_str* against a comma-separated version constraint.

    Fallback used when the ``packaging`` library is unavailable. Every
    comma-separated clause must hold. Supported operators are ``>=``, ``<=``,
    ``!=``, ``>``, ``<``, ``==``, ``~=`` (compatible release) and a bare
    version (treated as ``==``). ``==`` and ``!=`` accept a trailing ``.*``
    wildcard. Versions are compared as ``(major, minor, patch)`` int tuples,
    so pre-release and local suffixes are not understood.

    Args:
        ver_str: Version to test, e.g. ``"1.24.0"``.
        constraint: Specifier string, e.g. ``">=1.21,<1.25"``. An empty
            string imposes no constraint.

    Returns:
        True when every clause is satisfied, otherwise False.
    """
    ver = _parse_version_tuple(ver_str)
    for part in (clause.strip() for clause in constraint.split(',')):
        if not part:
            continue
        if part.startswith('~='):
            # Compatible release: "~=X.Y.Z" means ">=X.Y.Z, ==X.Y.*" and
            # "~=X.Y" means ">=X.Y, ==X.*".
            spec = part[2:].strip()
            segments = len(spec.split('.'))
            if segments < 2:
                # A single-component "~=" is not a valid specifier.
                return False
            floor = _parse_version_tuple(spec)
            width = min(segments - 1, 3)
            if ver < floor or ver[:width] != floor[:width]:
                return False
        elif part.startswith('>='):
            if ver < _parse_version_tuple(part[2:]):
                return False
        elif part.startswith('<='):
            if ver > _parse_version_tuple(part[2:]):
                return False
        elif part.startswith('!='):
            if _spec_equals(ver, part[2:]):
                return False
        elif part.startswith('>'):
            if ver <= _parse_version_tuple(part[1:]):
                return False
        elif part.startswith('<'):
            if ver >= _parse_version_tuple(part[1:]):
                return False
        elif part.startswith('=='):
            if not _spec_equals(ver, part[2:]):
                return False
        else:
            # A bare version is treated as an equality pin.
            if not _spec_equals(ver, part):
                return False
    return True
| def _f1(predicted, expected): | |
| """Compute F1 score between predicted and expected sets.""" | |
| if not expected: | |
| return 1.0 if not predicted else 0.0 | |
| if not predicted: | |
| return 0.0 | |
| pred_s = set(str(p).strip() for p in predicted) | |
| exp_s = set(str(e).strip() for e in expected) | |
| tp = len(pred_s & exp_s) | |
| p = tp / len(pred_s) if pred_s else 0.0 | |
| r = tp / len(exp_s) if exp_s else 0.0 | |
| return round(2 * p * r / max(p + r, 0.001), 4) | |
def _downgrades(proposed: Dict, case: Dict) -> int:
    """Count proposed packages whose version is lower than the case requirement.

    Only packages present in ``case['requirements']`` are considered.
    Versions are compared with ``packaging.version.Version`` when that
    library imported successfully, otherwise as integer tuples. A package
    whose comparison raises is skipped rather than counted.
    """
    baseline = case.get('requirements', {})
    total = 0
    for name, candidate in proposed.items():
        if name not in baseline:
            continue
        try:
            current = baseline[name]
            if _HAS_PACKAGING:
                is_lower = Version(_normalize_ver(candidate)) < Version(_normalize_ver(current))
            else:
                is_lower = _parse_version_tuple(candidate) < _parse_version_tuple(current)
        except Exception:
            # Best-effort: an unparseable version is not treated as a downgrade.
            continue
        if is_lower:
            total += 1
    return total
def _score_flag(action: Dict, case: Dict) -> float:
    """Score deprecated-API detection (dep_easy).

    The score combines three parts:
      * precision (30%): share of flagged packages that are truly outdated,
        so naming extra or wrong packages is penalized;
      * recall (25%): share of the expected outdated packages that were flagged;
      * deprecated API (45%): 1.0 for an exact match, 0.50 when the last
        dotted segment of the expected API appears in the answer
        (case-insensitive), 0.20 when any segment appears, else 0.0.

    Args:
        action: Model output. Reads ``packages`` (mapping keyed by package
            name) and ``deprecated_api``. A missing, ``None`` or non-mapping
            ``packages`` is treated as "nothing flagged". A non-string
            ``deprecated_api`` is stringified before matching.
        case: Ground truth. Reads ``expected_outdated_packages`` and
            ``expected_deprecated_api``.

    Returns:
        0.3 when the case lists no expected outdated packages, otherwise the
        weighted sum passed through ``safe_score``.
    """
    expected_pkgs = set(case.get('expected_outdated_packages', []))
    if not expected_pkgs:
        return 0.3

    # Model output is untrusted: only a mapping contributes flagged names.
    packages = action.get('packages')
    flagged = set(packages) if isinstance(packages, dict) else set()

    true_positives = len(flagged & expected_pkgs)
    precision = true_positives / len(flagged) if flagged else 0.0
    recall = true_positives / len(expected_pkgs)

    expected_api = case.get('expected_deprecated_api', '')
    actual_api = action.get('deprecated_api') or ''
    if not isinstance(actual_api, str):
        # Coerce to text so the substring checks below cannot raise.
        actual_api = str(actual_api)

    if actual_api == expected_api:
        dep_ok = 1.0
    elif expected_api and expected_api.split('.')[-1].lower() in actual_api.lower():
        # Partial: only the last segment, e.g. "Variable" of "autograd.Variable".
        dep_ok = 0.50
    elif expected_api and any(p.lower() in actual_api.lower() for p in expected_api.split('.')):
        dep_ok = 0.20
    else:
        dep_ok = 0.0

    return safe_score(precision * 0.30 + recall * 0.25 + dep_ok * 0.45)
def _match_matrix_version(pkg_versions, ver):
    """Return the compatibility-matrix key matching *ver*, or None.

    An exact match on the normalized three-part version wins. Otherwise the
    first key sharing the same major.minor is accepted. A major-only match
    is never accepted.
    """
    wanted = _normalize_ver(ver)
    for key in pkg_versions:
        if _normalize_ver(key) == wanted:
            return key
    wanted_major_minor = wanted.split('.')[:2]
    for key in pkg_versions:
        if _normalize_ver(key).split('.')[:2] == wanted_major_minor:
            return key
    return None


def _cross_constraints_ok(deps, proposed) -> bool:
    """Return False if any proposed dependency version violates *deps*.

    *deps* maps a dependency name to a version constraint. Dependencies
    absent from *proposed* are not checked, and a non-dict *deps* imposes
    no constraints.
    """
    if not isinstance(deps, dict):
        return True
    for dep_pkg, constraint in deps.items():
        if dep_pkg not in proposed:
            continue
        dep_ver = _normalize_ver(proposed[dep_pkg])
        try:
            if _HAS_PACKAGING:
                satisfied = Version(dep_ver) in SpecifierSet(constraint)
            else:
                satisfied = _simple_version_check(dep_ver, constraint)
        except Exception:
            # Best-effort: a version or constraint that cannot be parsed is
            # skipped rather than counted against the answer.
            continue
        if not satisfied:
            return False
    return True


def _score_resolve(action: Dict, case: Dict) -> float:
    """Score version-conflict resolution (dep_medium).

    A conflict package counts as resolved when its proposed version is found
    in the case's compatibility matrix (exact match, or same major.minor)
    and every cross-dependency constraint recorded for that version is met
    by the other proposed versions. The base score is the resolved fraction
    of conflict packages, minus 0.15 for each proposed version that falls
    below the case's current requirement.

    A fully correct answer with no downgrades has a raw score of 1.0 and a
    half-correct one 0.5. The result goes through ``safe_score``, which
    presumably clamps it into the grader's allowed range (confirm in
    base_grader).

    Args:
        action: Model output. Reads ``packages`` (name -> version mapping).
        case: Ground truth. Reads ``compatibility_matrix``,
            ``conflict_packages`` and, via ``_downgrades``, ``requirements``.

    Returns:
        0.20 when the case has no conflict packages, 0.05 when ``packages``
        is missing, empty or not a mapping, otherwise the clamped score.
    """
    compat = case.get('compatibility_matrix', {})
    proposed = action.get('packages', {})
    conflict_pkgs = case.get('conflict_packages', [])
    if not conflict_pkgs:
        return 0.20
    # Model output is untrusted: a non-mapping answer scores like an empty one
    # instead of raising on the lookups below.
    if not proposed or not isinstance(proposed, dict):
        return 0.05

    valid = 0
    for pkg in conflict_pkgs:
        if pkg not in proposed or pkg not in compat:
            continue
        pkg_versions = compat[pkg]
        matched_ver = _match_matrix_version(pkg_versions, proposed[pkg])
        if not matched_ver:
            # Version is not in the compatibility matrix at all: no credit.
            continue
        if _cross_constraints_ok(pkg_versions[matched_ver], proposed):
            valid += 1

    base = valid / len(conflict_pkgs)
    down = _downgrades(proposed, case) * 0.15
    return safe_score(base - down)
def _score_migrate(action: Dict, case: Dict) -> float:
    """Score graph-break migration (dep_hard).

    The score is a weighted sum of four parts:
      * order (30%): 1.0 minus 0.30 for each prerequisite, per the case's
        dependency graph, that is missing from the completed items;
      * completeness (40%): share of the checklist that was completed;
      * fix quality (20%): for each covered item with a known fix, 1.0 when
        the expected token appears in the submitted change
        (case-insensitive), 0.25 when any single word of it appears, else
        0.0; averaged over those items;
      * efficiency (10%): 1.0 minus 0.15 for each completed item beyond the
        checklist length.

    NOTE(review): the order check only detects prerequisites that are
    absent from ``completed_items``. It does not compare positions, so a
    prerequisite listed after its dependent is not penalized. Confirm that
    this is intended.

    Args:
        action: Model output. Reads ``completed_items`` and ``code_changes``
            (item -> change text). A missing, ``None`` or non-mapping
            ``code_changes`` is treated as "no changes submitted".
        case: Ground truth. Reads ``graph_breaks``,
            ``checklist_dependency_graph`` and ``correct_fix_map``.

    Returns:
        0.5 when the case has no checklist, 0.0 when nothing was completed,
        otherwise the weighted sum passed through ``safe_score``.
    """
    checklist = case.get('graph_breaks', [])
    dep_graph = case.get('checklist_dependency_graph', {})
    completed = action.get('completed_items', [])
    fix_map = case.get('correct_fix_map', {})
    if not checklist:
        return 0.5
    if not completed:
        return 0.0

    # Model output is untrusted: tolerate a null or malformed code_changes.
    code_changes = action.get('code_changes')
    if not isinstance(code_changes, dict):
        code_changes = {}

    # Each prerequisite missing from the completed list is one violation.
    viol = sum(
        1 for item in completed
        for pre in dep_graph.get(item, [])
        if pre not in completed
    )
    order_score = max(0.0, 1.0 - viol * 0.30)

    covered = [b for b in checklist if b in completed]
    completeness = len(covered) / max(len(checklist), 1)

    fix_qs = []
    for b in covered:
        if b not in fix_map:
            continue
        expected_token = fix_map[b].lower()
        actual_fix = str(code_changes.get(b, '')).lower()
        if expected_token in actual_fix:
            fix_qs.append(1.0)
        elif any(word in actual_fix for word in expected_token.split()):
            # Partial credit for mentioning only part of the expected fix.
            fix_qs.append(0.25)
        else:
            fix_qs.append(0.0)
    fix_quality = sum(fix_qs) / len(fix_qs) if fix_qs else 0.0

    # Items beyond the checklist length count as wasted steps.
    extra = max(0, len(completed) - len(checklist))
    efficiency = max(0.0, 1.0 - extra * 0.15)

    return safe_score(order_score * 0.30 + completeness * 0.40 + fix_quality * 0.20 + efficiency * 0.10)
def compute_correctness(action: Dict, case: Dict) -> float:
    """Dispatch to the scorer that matches ``action['action_type']``.

    ``flag_outdated`` is scored by ``_score_flag``, ``resolve_conflict`` by
    ``_score_resolve``, and both ``migrate_api`` and ``validate_tree`` by
    ``_score_migrate``. Any other (or missing) action type yields ``None``.
    """
    atype = action.get('action_type')
    routes = (
        (('flag_outdated',), _score_flag),
        (('resolve_conflict',), _score_resolve),
        (('migrate_api', 'validate_tree'), _score_migrate),
    )
    for names, scorer in routes:
        if atype in names:
            return scorer(action, case)
    # Unrecognized action type: no correctness score is produced.
    return None
def grade(action: Dict = None, session: Any = None) -> float:
    """Entry point called by the router; runs the full reward pipeline.

    When both arguments are supplied the work is delegated to
    ``grade_dynamic`` with this module's scorer, valid actions, forbidden
    list and a step limit of 8. If either argument is omitted, as happens
    under parameterless reflection testing, a fixed 0.01 is returned.
    """
    if action is not None and session is not None:
        return grade_dynamic(
            action,
            session,
            compute_correctness,
            VALID_ACTIONS,
            FORBIDDEN,
            max_steps=8,
        )
    return 0.01