Spaces:
Runtime error
Runtime error
| """ | |
| Path consistency tracking with Dirichlet prior. | |
| Non-zero from Episode 1, avoids cold-start problem from v3. | |
| """ | |
| from __future__ import annotations | |
| import numpy as np | |
| from collections import defaultdict | |
class PathConsistencyTracker:
    """
    Tracks how consistently the policy routes the same task type.

    For each task class the tracker counts how often each delegation path
    (normalised to a string key by ``_path_to_key``) has been recorded.
    ``consistency_score`` turns those counts into a smoothed relative
    frequency using a symmetric Dirichlet prior (``DIRICHLET_ALPHA``
    pseudo-counts per path), so the score is non-zero from episode 1,
    before any path has been recorded.
    """

    # Pseudo-count added to every path considered by ``consistency_score``.
    # Must be > 0: with 0 and no history for a task class the score would
    # be 0 / 0.
    DIRICHLET_ALPHA = 1.0

    def __init__(self, specialist_ids: list[str]):
        """
        Args:
            specialist_ids: Stored on the instance as ``specialist_ids``;
                not read by any method of this class.
        """
        self.specialist_ids = specialist_ids
        # task_class -> path_key -> number of times that path was recorded.
        self._task_path_counts: dict[str, dict[str, int]] = defaultdict(
            lambda: defaultdict(int)
        )

    def record_path(self, task_class: str, delegation_path: list) -> None:
        """Record one use of ``delegation_path`` for ``task_class``."""
        path_key = self._path_to_key(delegation_path)
        self._task_path_counts[task_class][path_key] += 1

    def consistency_score(
        self, delegation_path: list, task_class: str
    ) -> float:
        """
        Score how consistent this path is with previous paths for this task class.

        The score is ``(n_p + alpha) / (N + alpha * K)`` where ``n_p`` is the
        recorded count of this path, ``N`` the total recorded count for the
        task class, and ``K`` the number of distinct recorded paths, plus one
        if this path has not been recorded yet.

        This method is read-only: it does not record the path and does not
        create an entry for an unseen ``task_class``.

        Returns:
            A float in (0.0, 1.0]. It is non-zero from episode 1 due to the
            Dirichlet prior. With no history for ``task_class`` it is 1.0.
        """
        path_key = self._path_to_key(delegation_path)
        # .get (not []) so the defaultdict is not populated by a query.
        counts = self._task_path_counts.get(task_class, {})
        # Add Dirichlet prior counts to every observed path and to this one.
        all_paths = set(counts.keys()) | {path_key}
        pseudo_counts = {
            p: counts.get(p, 0) + self.DIRICHLET_ALPHA for p in all_paths
        }
        total = sum(pseudo_counts.values())
        return float(pseudo_counts[path_key] / total)

    def _path_to_key(self, delegation_path: list) -> str:
        """
        Convert a delegation path to a hashable string key.

        Each edge contributes its callee:
        - ``edge.callee_id`` if the edge has that attribute;
        - ``edge.get("callee_id", "?")`` if the edge is a dict;
        - otherwise the edge itself (e.g. a bare specialist-id string).

        Every part is converted with ``str()`` and the parts are joined with
        ``"->"``. An empty path maps to ``"empty"``.
        """
        if not delegation_path:
            return "empty"
        parts = []
        for edge in delegation_path:
            if hasattr(edge, "callee_id"):
                callee = edge.callee_id
            elif isinstance(edge, dict):
                callee = edge.get("callee_id", "?")
            else:
                # Keep unrecognised edges instead of dropping them; dropping
                # would collapse distinct paths (e.g. ["a"] vs ["b"]) onto
                # the same key and inflate the consistency score.
                callee = edge
            # str() so non-string ids (ints, None) don't make join() raise.
            parts.append(str(callee))
        return "->".join(parts)