vLAR's picture
update
54708e8
"""Scene-level task plugin interface (v0.4).
Each task lives at src/tasks/<name>/__init__.py and exports `TASK = TaskPlugin(...)`.
The atomic evaluation unit is a *scene*:
validate_scene(scene_sandbox_dir) -> None
evaluate_scene(scene_sandbox_dir, gt_scene) -> Dict[str, float]
load_gt_scene(scene_id, gt_dir) -> Any
After all scenes are settled, `aggregate(per_scene_metrics) -> Dict[str, float]`
produces the final leaderboard metrics (default = arithmetic mean per key; scenes
that are missing a key, or report a non-numeric value for it, are skipped for that
key — the task can override).
"""
from __future__ import annotations
from dataclasses import dataclass, field
from typing import Any, Callable, Dict, List, Optional
def _default_aggregate(per_scene: Dict[str, Dict[str, float]]) -> Dict[str, float]:
if not per_scene:
return {}
sums: Dict[str, float] = {}
counts: Dict[str, int] = {}
for metrics in per_scene.values():
for k, v in metrics.items():
try:
fv = float(v)
except (TypeError, ValueError):
continue
sums[k] = sums.get(k, 0.0) + fv
counts[k] = counts.get(k, 0) + 1
return {k: round(sums[k] / counts[k], 6) for k in sums}
@dataclass
class TaskPlugin:
    """Bundle of callbacks and leaderboard settings describing one task.

    The three ``*_fn`` callbacks implement the per-scene contract; the methods
    of this class are thin wrappers around them that add return-type checks
    where noted.
    """

    # ---- identity ----
    name: str
    display_name: str
    description: str

    # ---- per-scene contract ----
    expected_scene_layout: str
    # Called with the scene directory; its return value is discarded.
    validate_scene_fn: Callable[[str], None]
    # Called with (scene directory, ground truth); must return a dict.
    evaluate_scene_fn: Callable[[str, Any], Dict[str, float]]
    # Called with (scene id, ground-truth directory).
    load_gt_scene_fn: Callable[[str, str], Any]

    # ---- leaderboard ----
    # Key that must be present in the aggregated metrics of a non-empty run.
    primary_metric: str
    higher_is_better: bool = True
    leaderboard_columns: List[str] = field(default_factory=list)

    # ---- aggregate (optional override) ----
    # When unset (or falsy), the module-level default aggregator is used.
    aggregate_fn: Optional[Callable[[Dict[str, Dict[str, float]]], Dict[str, float]]] = None

    # ---- convenience wrappers ----
    def validate_scene(self, scene_dir: str) -> None:
        """Run ``validate_scene_fn`` on ``scene_dir`` and return ``None``."""
        self.validate_scene_fn(scene_dir)

    def evaluate_scene(self, scene_dir: str, gt: Any) -> Dict[str, float]:
        """Run ``evaluate_scene_fn`` and return its metrics dict.

        Raises:
            RuntimeError: If the callback returns something other than a dict.
        """
        result = self.evaluate_scene_fn(scene_dir, gt)
        if isinstance(result, dict):
            return result
        raise RuntimeError(f"task {self.name}: evaluate_scene must return dict, got {type(result)}")

    def load_gt_scene(self, scene_id: str, gt_dir: str) -> Any:
        """Return whatever ``load_gt_scene_fn(scene_id, gt_dir)`` returns."""
        return self.load_gt_scene_fn(scene_id, gt_dir)

    def aggregate(self, per_scene: Dict[str, Dict[str, float]]) -> Dict[str, float]:
        """Reduce per-scene metrics to the final metrics dict.

        Uses ``aggregate_fn`` when it is set, otherwise the module default.

        Raises:
            RuntimeError: If the aggregator does not return a dict, or if
                ``per_scene`` is non-empty and the result lacks
                ``primary_metric``.
        """
        reducer = self.aggregate_fn if self.aggregate_fn else _default_aggregate
        summary = reducer(per_scene)
        if not isinstance(summary, dict):
            raise RuntimeError(f"task {self.name}: aggregate must return dict")

        # The primary metric is only required when at least one scene was given.
        primary_missing = bool(per_scene) and self.primary_metric not in summary
        if primary_missing:
            raise RuntimeError(
                f"task {self.name}: aggregated metrics missing primary_metric "
                f"'{self.primary_metric}'. Keys: {list(summary.keys())}"
            )
        return summary