shreyas-joshi's picture
feat: Enhance module resolution and grading configuration
bc02c39
from __future__ import annotations
from dataclasses import dataclass
from graph.graph_manager import GraphManager
from tasks.easy_task import TASK_CONFIG as EASY_TASK
from tasks.hard_task import TASK_CONFIG as HARD_TASK
from tasks.medium_task import TASK_CONFIG as MEDIUM_TASK
@dataclass(frozen=True)
class TaskSpec:
task_id: str
task_level: str
description: str
default_modules: list[str]
grader: str
max_steps: int
allow_module_override: bool
expand_to_dependencies: bool
def _to_spec(config: dict[str, object]) -> TaskSpec:
return TaskSpec(
task_id=str(config["task_id"]),
task_level=str(config["task_level"]),
description=str(config["description"]),
default_modules=[str(item) for item in list(config["default_modules"])],
grader=str(config["grader"]),
max_steps=int(config["max_steps"]),
allow_module_override=bool(config["allow_module_override"]),
expand_to_dependencies=bool(config["expand_to_dependencies"]),
)
TASK_REGISTRY: dict[str, TaskSpec] = {
"style_review": _to_spec(EASY_TASK),
"logic_review": _to_spec(MEDIUM_TASK),
"cascade_review": _to_spec(HARD_TASK),
}
def list_tasks() -> list[TaskSpec]:
return [TASK_REGISTRY[key] for key in sorted(TASK_REGISTRY)]
def get_task(task_id: str) -> TaskSpec:
if task_id not in TASK_REGISTRY:
raise ValueError(f"Unknown task_id: {task_id}")
return TASK_REGISTRY[task_id]
def resolve_task_modules(
task: TaskSpec,
graph_manager: GraphManager,
module_override: list[str] | None = None,
) -> list[str]:
graph = graph_manager.load_graph()
available = set(graph.nodes())
base_modules = task.default_modules
if module_override:
if not task.allow_module_override:
raise ValueError(f"Task {task.task_id} does not allow module_override")
requested: list[str] = []
for module in module_override:
try:
requested.append(graph_manager.resolve_module_id(module))
except ValueError:
continue
if not requested:
raise ValueError("module_override does not include any known module ids")
base_modules = sorted(set(requested))
if not task.expand_to_dependencies:
return [module for module in base_modules if module in available]
expanded: set[str] = set()
for module_id in base_modules:
if module_id not in available:
continue
expanded.add(module_id)
expanded.update(graph.successors(module_id))
expanded.update(graph.predecessors(module_id))
# Dependency-aware policy: prefer leaves first and keep deterministic ordering.
traversal_rank = {node: idx for idx, node in enumerate(graph_manager.traversal_order())}
return sorted(expanded, key=lambda node: (int(traversal_rank.get(node, 10_000)), node))