# tracefix_rl/server/graders.py
# databoysu
# active graders
# 7266968
"""Task graders for TraceFix-RL.
The online validator expects importable grader callables for each task entry.
These graders execute the real task tests against the final code state so the
judge can verify actual solution quality instead of a canned lookup.
"""
from __future__ import annotations
from collections.abc import Mapping, Sequence
from typing import Any, Optional
from core.sandbox import run_code_with_tests
from tasks.tasks import ALL_TASKS
# Score bounds used by the graders in this module: MIN_SCORE is what the
# failure paths return, MAX_SCORE is returned when every task test passes.
# NOTE(review): both bounds sit strictly inside (0, 1) — presumably a
# requirement of the online validator; confirm before changing them.
MIN_SCORE = 0.01
MAX_SCORE = 0.98
def _clamp(score: float) -> float:
    """Clamp *score* into ``[MIN_SCORE, MAX_SCORE]``, rounded to 4 decimals.

    Args:
        score: Raw score to normalise.

    Returns:
        The bounded, rounded score. ``NaN`` is mapped to ``MIN_SCORE``
        because it would otherwise pass through ``min``/``max`` unchanged
        and leak out of the grader as a non-comparable value.
    """
    if score != score:  # NaN is the only float that is unequal to itself
        return MIN_SCORE
    return round(min(max(score, MIN_SCORE), MAX_SCORE), 4)
def _as_mapping(value: Any) -> Optional[Mapping[str, Any]]:
    """Return *value* viewed as a mapping, or ``None`` when that is impossible.

    A ``Mapping`` is returned unchanged. Otherwise the object's
    ``model_dump()`` and then ``dict()`` exporters are tried in that order,
    and the first one that yields a ``Mapping`` wins.

    Args:
        value: Any payload object.

    Returns:
        A mapping view of *value*, or ``None`` if no exporter produced one.
    """
    if isinstance(value, Mapping):
        return value

    for exporter_name in ("model_dump", "dict"):
        exporter = getattr(value, exporter_name, None)
        if not callable(exporter):
            continue
        try:
            dumped = exporter()
        except Exception:
            # Best effort: a failing exporter must not abort grading, and it
            # must not prevent the remaining exporter from being tried.
            continue
        if isinstance(dumped, Mapping):
            return dumped
    return None
def _find_score_value(payload: Any) -> Optional[float]:
    """Search *payload* for a numeric score and return it as a float.

    The score fields ``final_score``, ``grader_score``, ``score``, ``reward``
    and ``total_reward`` are checked first, in that order. If none holds a
    number, the containers ``metadata``, ``info``, ``observation`` and
    ``state`` are searched recursively. Mapping-like payloads (see
    ``_as_mapping``) are read by key, everything else by attribute.

    Args:
        payload: Mapping, model object or arbitrary object to inspect.

    Returns:
        The first score found, or ``None`` when there is none.
    """
    return _search_score(payload, depth_left=32)


def _search_score(payload: Any, depth_left: int) -> Optional[float]:
    """Recursive worker for ``_find_score_value`` with a nesting budget."""
    # The depth budget stops self-referential payloads (for example an object
    # whose ``state`` attribute is itself) from recursing without bound.
    if payload is None or depth_left < 0:
        return None

    mapping = _as_mapping(payload)

    def read(name: str) -> Any:
        if mapping is not None:
            return mapping.get(name)
        return getattr(payload, name, None)

    for name in ("final_score", "grader_score", "score", "reward", "total_reward"):
        value = read(name)
        # bool is a subclass of int; a True/False flag is not a score.
        if isinstance(value, (int, float)) and not isinstance(value, bool):
            return float(value)

    for name in ("metadata", "info", "observation", "state"):
        found = _search_score(read(name), depth_left - 1)
        if found is not None:
            return found
    return None
def _find_task(task_name: str) -> Optional[dict[str, Any]]:
    """Return the first entry of ``ALL_TASKS`` whose ``"name"`` is *task_name*.

    Returns ``None`` when no entry matches.
    """
    matches = (entry for entry in ALL_TASKS if entry.get("name") == task_name)
    return next(matches, None)
def _extract_final_observation(payload: Any) -> Any:
    """Drill into *payload* and return the innermost final observation.

    Resolution rules:

    * ``None`` stays ``None``.
    * Mapping-like payloads (see ``_as_mapping``): the first of
      ``final_observation``, ``observation``, ``state``, ``last_observation``
      that resolves to something non-``None`` wins. Otherwise a
      ``trajectory`` entry, when present, is resolved. Otherwise the payload
      itself is the observation.
    * Non-text sequences: only the last item is considered. If that item is
      itself a non-text sequence of length >= 2, its element at index 1 is
      resolved; if it is mapping-like it is resolved; otherwise it is
      returned as is. An empty sequence yields ``None``.
    * Anything else is returned unchanged.
    """
    if payload is None:
        return None

    as_map = _as_mapping(payload)
    if as_map is not None:
        for field in ("final_observation", "observation", "state", "last_observation"):
            inner = as_map.get(field)
            if inner is None:
                continue
            resolved = _extract_final_observation(inner)
            if resolved is not None:
                return resolved
        if "trajectory" in as_map:
            return _extract_final_observation(as_map.get("trajectory"))
        return payload

    text_like = (str, bytes, bytearray)
    if not isinstance(payload, Sequence) or isinstance(payload, text_like):
        return payload
    if not payload:
        return None

    tail = payload[-1]
    tail_is_plain_sequence = isinstance(tail, Sequence) and not isinstance(tail, text_like)
    if tail_is_plain_sequence and len(tail) >= 2:
        # NOTE(review): presumably steps are (action, observation) pairs —
        # confirm against the trajectory producer.
        return _extract_final_observation(tail[1])
    if isinstance(tail, Mapping) or hasattr(tail, "model_dump") or hasattr(tail, "dict"):
        return _extract_final_observation(tail)
    return tail
def _observation_to_source(observation: Any) -> Optional[str]:
    """Extract the program source text carried by *observation*.

    Mapping-like observations (see ``_as_mapping``) are read by key in the
    order ``source``, ``code_lines``, ``code``, ``code_dict``. If that yields
    nothing, the same fields are read as attributes of the original object.
    Every field may hold a string, a sequence of lines, or a mapping of line
    number to line text; see ``_render_source_field``.

    Args:
        observation: Final observation (mapping, model object or plain
            object), or ``None``.

    Returns:
        The first non-blank source text found, or ``None``.
    """
    if observation is None:
        return None

    mapping = _as_mapping(observation)
    if mapping is not None:
        for key in ("source", "code_lines", "code", "code_dict"):
            text = _render_source_field(mapping.get(key))
            if text is not None:
                return text

    # The attribute fallback keeps its historical order (``code`` is checked
    # before ``code_lines``) so existing callers see the same field chosen.
    for attr in ("source", "code", "code_lines", "code_dict"):
        text = _render_source_field(getattr(observation, attr, None))
        if text is not None:
            return text
    return None


def _render_source_field(value: Any) -> Optional[str]:
    """Render one candidate field as source text, or ``None`` if unusable.

    Strings are used as is, mappings are joined in ascending integer-key
    order (entries whose key is not an integer are skipped), and other
    non-bytes sequences are joined line by line. Blank results count as
    unusable so that the caller can move on to the next field.
    """
    if isinstance(value, str):
        text = value
    elif isinstance(value, Mapping):
        numbered: list[tuple[int, str]] = []
        for key, line in value.items():
            try:
                numbered.append((int(key), str(line)))
            except (TypeError, ValueError, OverflowError):
                continue  # key is not a line number
        numbered.sort(key=lambda item: item[0])
        text = "\n".join(line for _, line in numbered)
    elif isinstance(value, Sequence) and not isinstance(value, (bytes, bytearray)):
        text = "\n".join(str(line) for line in value)
    else:
        return None
    return text if text.strip() else None
def _evaluate_task(task_name: str, payload: Any) -> float:
    """Run the tests of task *task_name* against the code found in *payload*.

    Returns ``MAX_SCORE`` only when ``run_code_with_tests`` reports no syntax
    error and a non-empty result list in which every result has ``passed``
    set. Every other outcome returns ``MIN_SCORE``: unknown task, no usable
    source, an exception while running the tests, a syntax error, or any
    failing test.
    """
    spec = _find_task(task_name)
    if spec is None:
        return MIN_SCORE

    code = _observation_to_source(_extract_final_observation(payload))
    if not (code and code.strip()):
        return MIN_SCORE

    try:
        outcome = run_code_with_tests(source=code, test_callables=spec["tests"])
        _, test_results, syntax_error = outcome
    except Exception:
        # Grader boundary: any failure while running the tests scores minimum.
        return MIN_SCORE

    if syntax_error:
        return MIN_SCORE
    every_test_passed = bool(test_results) and all(item.passed for item in test_results)
    return MAX_SCORE if every_test_passed else MIN_SCORE
def grade(payload: Any = None, *args: Any, task_name: str = "", **kwargs: Any) -> float:
    """Execute the task's real tests against the final code state.

    Args:
        payload: Object carrying the final code state. When it is ``None``
            the first extra positional argument is used instead, and if
            there is none the keyword arguments themselves are graded.
        *args: Extra positional arguments; only the first may be consulted.
        task_name: Task to grade. When empty, ``task_id`` and then ``name``
            from ``kwargs`` are tried.
        **kwargs: Extra keyword arguments supplied by the caller.

    Returns:
        The score from ``_evaluate_task``, or ``MIN_SCORE`` when no task
        name can be determined.
    """
    if payload is None and args:
        payload = args[0]

    resolved_name = task_name or str(kwargs.get("task_id") or kwargs.get("name") or "")
    if not resolved_name:
        return MIN_SCORE

    subject = kwargs if payload is None else payload
    return _evaluate_task(resolved_name, subject)
def grade_valid_parentheses_wrong_mapping(*args: Any, **kwargs: Any) -> float:
    """Grade the ``valid_parentheses_wrong_mapping`` task.

    Arguments are forwarded to ``grade``; any ``task_name`` passed by the
    caller is replaced with this task's name.
    """
    return grade(*args, **{**kwargs, "task_name": "valid_parentheses_wrong_mapping"})
def grade_binary_search_off_by_one(*args: Any, **kwargs: Any) -> float:
    """Grade the ``binary_search_off_by_one`` task.

    Arguments are forwarded to ``grade``; any ``task_name`` passed by the
    caller is replaced with this task's name.
    """
    return grade(*args, **{**kwargs, "task_name": "binary_search_off_by_one"})
def grade_reverse_string_returns_original(*args: Any, **kwargs: Any) -> float:
    """Grade the ``reverse_string_returns_original`` task.

    Arguments are forwarded to ``grade``; any ``task_name`` passed by the
    caller is replaced with this task's name.
    """
    return grade(*args, **{**kwargs, "task_name": "reverse_string_returns_original"})
# Public API of this module: the generic grader plus one pre-bound grader
# per task name.
__all__ = [
    "grade",
    "grade_valid_parentheses_wrong_mapping",
    "grade_binary_search_off_by_one",
    "grade_reverse_string_returns_original",
]