Spaces:
Paused
Paused
| """ | |
| Hierarchical Factored Action Space. | |
| 4 heads decoded sequentially at each step: | |
| Head 1: Meta-action — what high-level thing to do? | |
| Head 2: Specialist selection — which specialist(s) to call? | |
| Head 3: Delegation mode — how to call them? | |
| Head 4: Mode parameters — how many rounds, threshold, etc.? | |
| Design: Sequential decomposition keeps each head's distribution | |
| tractable for PPO. The policy sees a flattened joint action, but | |
| training uses the factored structure. | |
| """ | |
| from __future__ import annotations | |
| from dataclasses import dataclass | |
| from enum import IntEnum | |
| from typing import Optional | |
| import numpy as np | |
| class MetaAction(IntEnum): | |
| """Top-level orchestrator decisions.""" | |
| CALL_SPECIALIST = 0 # Call one or more specialists | |
| STOP = 1 # Stop delegation, synthesize output | |
| CALL_MEDIATOR = 2 # Call conflict mediator | |
| CLARIFY_TASK = 3 # Request task clarification (if ambiguous) | |
| DELEGATE_SUBTASK = 4 # Delegate a sub-problem (2nd level) | |
| RETRY_FAILED = 5 # Retry a failed specialist with fallback | |
| PARALLEL_SPAWN = 6 # Spawn parallel specialists | |
| SPAWN_SPECIALIST = 7 # Policy requests a new specialist be created | |
| class DelegationMode(IntEnum): | |
| """How to execute the selected specialists.""" | |
| SEQUENTIAL = 0 # A → B → C (each sees previous output) | |
| PARALLEL = 1 # A, B, C all run simultaneously | |
| FAN_OUT_REDUCE = 2 # A, B, C run → mediator reduces output | |
| ITERATIVE = 3 # Run specialist, check output, loop until threshold | |
| CONDITIONAL = 4 # Run A; if condition met, run B, else C | |
| PRIORITY_QUEUE = 5 # Run in priority order, stop when threshold met | |
| BROADCAST = 6 # Send to all specialists, take first to complete | |
| class FactoredAction: | |
| """ | |
| The complete action decoded from all 4 heads. | |
| This is what gets passed to the environment's step() function. | |
| """ | |
| meta_action: MetaAction | |
| specialist_ids: list[str] # Which specialists to call | |
| delegation_mode: DelegationMode # How to call them | |
| mode_params: dict # Mode-specific parameters | |
| raw_action: Optional[np.ndarray] = None # Raw policy output (for logging) | |
| def is_terminal(self) -> bool: | |
| """Returns True if this action ends the episode.""" | |
| return self.meta_action == MetaAction.STOP | |
| def to_log_dict(self) -> dict: | |
| return { | |
| "meta_action": self.meta_action.name, | |
| "specialists": self.specialist_ids, | |
| "mode": self.delegation_mode.name, | |
| "params": self.mode_params, | |
| } | |
| class ActionDecoder: | |
| """ | |
| Decodes a flat action vector from the policy into a FactoredAction. | |
| Action vector layout: | |
| [0] : meta_action index (int, 0–6) | |
| [1 : 1+max_specialists] : specialist selection (multi-hot float) | |
| [1+max_specialists] : delegation_mode index (int, 0–6) | |
| [2+max_specialists : *] : mode_params (continuous, 4 floats) | |
| Total action dim = 1 + max_specialists + 1 + 4 = max_specialists + 6 | |
| """ | |
| NUM_META_ACTIONS = len(MetaAction) | |
| NUM_DELEGATION_MODES = len(DelegationMode) | |
| NUM_MODE_PARAMS = 4 | |
| def __init__(self, specialist_ids: list[str], max_specialists: int = 8): | |
| self.specialist_ids = specialist_ids | |
| self.max_specialists = min(len(specialist_ids), max_specialists) | |
| self.action_dim = self.max_specialists + 6 | |
| def decode( | |
| self, | |
| action_vector: np.ndarray, | |
| valid_specialist_mask: Optional[np.ndarray] = None, | |
| ) -> FactoredAction: | |
| """ | |
| Decode a flat action vector into a FactoredAction. | |
| Args: | |
| action_vector: Flat numpy array from the policy | |
| valid_specialist_mask: Binary mask, 1 = valid, 0 = masked out | |
| (enforces DAG constraints) | |
| """ | |
| action_vector = np.asarray(action_vector, dtype=np.float32) | |
| # Head 1: Meta-action | |
| meta_idx = int(np.clip(round(action_vector[0]), 0, self.NUM_META_ACTIONS - 1)) | |
| meta_action = MetaAction(meta_idx) | |
| # Head 2: Specialist selection (multi-hot) | |
| spec_logits = action_vector[1: 1 + self.max_specialists] | |
| if valid_specialist_mask is not None: | |
| spec_logits = spec_logits * valid_specialist_mask[:self.max_specialists] | |
| selected_indices = np.where(spec_logits > 0.0)[0] | |
| if len(selected_indices) == 0 and meta_action == MetaAction.CALL_SPECIALIST: | |
| # Fallback: select the highest-scoring specialist | |
| selected_indices = [int(np.argmax(spec_logits))] | |
| selected_ids = [ | |
| self.specialist_ids[i] | |
| for i in selected_indices | |
| if i < len(self.specialist_ids) | |
| ] | |
| # Head 3: Delegation mode | |
| mode_idx = int(np.clip( | |
| round(action_vector[1 + self.max_specialists]), | |
| 0, self.NUM_DELEGATION_MODES - 1 | |
| )) | |
| delegation_mode = DelegationMode(mode_idx) | |
| # Head 4: Mode parameters | |
| param_start = 2 + self.max_specialists | |
| raw_params = action_vector[param_start: param_start + self.NUM_MODE_PARAMS] | |
| mode_params = self._decode_mode_params(delegation_mode, raw_params) | |
| return FactoredAction( | |
| meta_action=meta_action, | |
| specialist_ids=selected_ids, | |
| delegation_mode=delegation_mode, | |
| mode_params=mode_params, | |
| raw_action=action_vector, | |
| ) | |
| def _decode_mode_params( | |
| self, mode: DelegationMode, raw_params: np.ndarray | |
| ) -> dict: | |
| """Decode mode-specific parameters from the raw continuous params.""" | |
| p = np.clip(raw_params, 0.0, 1.0) | |
| if mode == DelegationMode.ITERATIVE: | |
| return { | |
| "max_rounds": int(1 + round(p[0] * 4)), # 1–5 rounds | |
| "quality_threshold": float(0.5 + p[1] * 0.5), # 0.5–1.0 | |
| } | |
| elif mode == DelegationMode.PRIORITY_QUEUE: | |
| return { | |
| "stop_threshold": float(0.6 + p[0] * 0.4), # 0.6–1.0 | |
| } | |
| elif mode == DelegationMode.CONDITIONAL: | |
| return { | |
| "condition_threshold": float(0.4 + p[0] * 0.6), # 0.4–1.0 | |
| } | |
| else: | |
| return {"parallel_budget_ms": int(2000 + p[0] * 6000)} | |
| def get_action_dim(self) -> int: | |
| return self.action_dim | |
| def build_specialist_mask( | |
| self, valid_specialist_ids: list[str] | |
| ) -> np.ndarray: | |
| """Build a binary mask for valid specialist selections.""" | |
| mask = np.zeros(self.max_specialists, dtype=np.float32) | |
| valid_set = set(valid_specialist_ids) | |
| for i, sid in enumerate(self.specialist_ids[: self.max_specialists]): | |
| if sid in valid_set: | |
| mask[i] = 1.0 | |
| return mask | |