File size: 2,025 Bytes
02ff91f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
"""
Path-consistency tracking with a Dirichlet prior.
The score is non-zero from episode 1, which avoids the cold-start problem of v3.
"""

from __future__ import annotations
import numpy as np
from collections import defaultdict


class PathConsistencyTracker:
    """
    Tracks how consistently the policy routes the same task type.

    For each task class, counts how often each delegation path (reduced to a
    string key by ``_path_to_key``) has been recorded. ``consistency_score``
    adds a Dirichlet pseudo-count (``DIRICHLET_ALPHA`` = 1.0) to every recorded
    path and to the queried path, so the bonus is non-zero from episode 1,
    even for a task class or path that has never been recorded.
    """

    # Pseudo-count added to each path when computing consistency_score.
    DIRICHLET_ALPHA = 1.0

    def __init__(self, specialist_ids: list[str]):
        """
        Args:
            specialist_ids: Specialist identifiers. Stored on the instance;
                not read by any method of this class.
        """
        self.specialist_ids = specialist_ids
        # task_class -> path_key -> number of times that path was recorded.
        self._task_path_counts: dict[str, dict[str, int]] = defaultdict(
            lambda: defaultdict(int)
        )

    def record_path(self, task_class: str, delegation_path: list) -> None:
        """Record one occurrence of ``delegation_path`` for ``task_class``."""
        path_key = self._path_to_key(delegation_path)
        self._task_path_counts[task_class][path_key] += 1

    def consistency_score(
        self, delegation_path: list, task_class: str
    ) -> float:
        """
        Score how consistent this path is with previous paths for this task class.

        Note the argument order is (delegation_path, task_class), the reverse
        of ``record_path``.

        The score is ``(count(path) + alpha) / sum_p (count(p) + alpha)``, where
        ``p`` ranges over the paths recorded for ``task_class`` plus the queried
        path. The result lies in (0.0, 1.0]. It is 1.0 when nothing has been
        recorded for ``task_class`` yet (non-zero from episode 1 due to the
        Dirichlet prior).

        This method does not record the path; call ``record_path`` for that.
        """
        path_key = self._path_to_key(delegation_path)
        # .get() (not []) so scoring an unseen task class does not insert an
        # empty entry into the defaultdict.
        counts = self._task_path_counts.get(task_class, {})

        # Add Dirichlet prior counts to every known path and the queried one.
        all_paths = set(counts.keys()) | {path_key}
        pseudo_counts = {p: counts.get(p, 0) + self.DIRICHLET_ALPHA for p in all_paths}
        total = sum(pseudo_counts.values())

        return float(pseudo_counts[path_key] / total)

    def _path_to_key(self, delegation_path: list) -> str:
        """
        Convert a delegation path to a hashable string key.

        Each edge contributes one "->"-separated component:
          - an object with a ``callee_id`` attribute -> ``str(edge.callee_id)``
          - a dict -> ``str(edge.get("callee_id", "?"))``
          - a plain string (e.g. a specialist id) -> the string itself
          - anything else -> ``"?"``, kept as a placeholder so that paths of
            different lengths do not collapse onto the same key
        An empty (or None) path maps to ``"empty"``.
        """
        if not delegation_path:
            return "empty"
        parts = []
        for edge in delegation_path:
            if hasattr(edge, "callee_id"):
                callee = edge.callee_id
            elif isinstance(edge, dict):
                callee = edge.get("callee_id", "?")
            elif isinstance(edge, str):
                callee = edge
            else:
                callee = "?"
            # str() so non-string ids (ints, None) don't make join() raise.
            parts.append(str(callee))
        return "->".join(parts)