File size: 2,930 Bytes
54708e8
03b0173
54708e8
03b0173
54708e8
 
 
 
03b0173
54708e8
 
 
03b0173
 
 
 
54708e8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
03b0173
 
 
 
 
54708e8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
03b0173
54708e8
 
03b0173
54708e8
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
"""Scene-level task plugin interface (v0.4).

Each task lives at src/tasks/<name>/__init__.py and exports `TASK = TaskPlugin(...)`.

The atomic evaluation unit is a *scene*:
  validate_scene(scene_sandbox_dir) -> None
  evaluate_scene(scene_sandbox_dir, gt_scene) -> Dict[str, float]
  load_gt_scene(scene_id, gt_dir) -> Any

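After all scenes are settled, `aggregate(per_scene_metrics) -> Dict[str, float]`
produces the final leaderboard metrics. By default each key is the arithmetic mean
over the scenes that report a numeric value for it (rounded to 6 decimals); scenes
without a value for a key are skipped for that key. A task can override this via
`aggregate_fn`.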
"""
from __future__ import annotations

from dataclasses import dataclass, field
from typing import Any, Callable, Dict, List, Optional


def _default_aggregate(per_scene: Dict[str, Dict[str, float]]) -> Dict[str, float]:
    if not per_scene:
        return {}
    sums: Dict[str, float] = {}
    counts: Dict[str, int] = {}
    for metrics in per_scene.values():
        for k, v in metrics.items():
            try:
                fv = float(v)
            except (TypeError, ValueError):
                continue
            sums[k] = sums.get(k, 0.0) + fv
            counts[k] = counts.get(k, 0) + 1
    return {k: round(sums[k] / counts[k], 6) for k in sums}


@dataclass
class TaskPlugin:
    """Declarative description of one evaluation task.

    A task supplies three per-scene callables (validate, evaluate, load
    ground truth) plus leaderboard metadata. The ``validate_scene``,
    ``evaluate_scene``, ``load_gt_scene`` and ``aggregate`` methods delegate
    to those callables, adding return-type checks where documented.
    """

    # ---- identity ----
    name: str  # used as the task label in this class's error messages
    display_name: str
    description: str

    # ---- per-scene contract ----
    expected_scene_layout: str
    validate_scene_fn: Callable[[str], None]
    evaluate_scene_fn: Callable[[str, Any], Dict[str, float]]
    load_gt_scene_fn: Callable[[str, str], Any]

    # ---- leaderboard ----
    # Key that aggregate() requires in its result whenever per_scene is non-empty.
    primary_metric: str
    higher_is_better: bool = True
    leaderboard_columns: List[str] = field(default_factory=list)

    # ---- aggregate (optional override) ----
    # When unset (None), aggregate() falls back to the module-level _default_aggregate.
    aggregate_fn: Optional[Callable[[Dict[str, Dict[str, float]]], Dict[str, float]]] = None

    # ---- convenience wrappers ----
    def validate_scene(self, scene_dir: str) -> None:
        """Run ``validate_scene_fn`` on ``scene_dir``.

        Exceptions raised by the validator propagate unchanged; its return
        value is discarded.
        """
        self.validate_scene_fn(scene_dir)

    def evaluate_scene(self, scene_dir: str, gt: Any) -> Dict[str, float]:
        """Score one scene against its ground truth via ``evaluate_scene_fn``.

        Raises:
            RuntimeError: if ``evaluate_scene_fn`` does not return a dict.
        """
        metrics = self.evaluate_scene_fn(scene_dir, gt)
        if not isinstance(metrics, dict):
            raise RuntimeError(f"task {self.name}: evaluate_scene must return dict, got {type(metrics)}")
        return metrics

    def load_gt_scene(self, scene_id: str, gt_dir: str) -> Any:
        """Return whatever ``load_gt_scene_fn(scene_id, gt_dir)`` returns."""
        return self.load_gt_scene_fn(scene_id, gt_dir)

    def aggregate(self, per_scene: Dict[str, Dict[str, float]]) -> Dict[str, float]:
        """Combine per-scene metrics into the final leaderboard metrics.

        Uses ``aggregate_fn`` when set, otherwise ``_default_aggregate``.

        Raises:
            RuntimeError: if the aggregator does not return a dict, or if
                ``per_scene`` is non-empty and the result lacks
                ``primary_metric``.
        """
        aggregator = self.aggregate_fn or _default_aggregate
        agg = aggregator(per_scene)
        if not isinstance(agg, dict):
            # Report the offending type, matching evaluate_scene's contract error.
            raise RuntimeError(f"task {self.name}: aggregate must return dict, got {type(agg)}")
        # The primary-metric check is skipped for an empty per_scene mapping
        # (for which the default aggregator returns {}).
        if per_scene and self.primary_metric not in agg:
            raise RuntimeError(
                f"task {self.name}: aggregated metrics missing primary_metric "
                f"'{self.primary_metric}'. Keys: {list(agg.keys())}"
            )
        return agg