File size: 6,960 Bytes
02ff91f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
"""
Hierarchical Factored Action Space.

4 heads decoded sequentially at each step:
  Head 1: Meta-action — what high-level thing to do?
  Head 2: Specialist selection β€” which specialist(s) to call?
  Head 3: Delegation mode β€” how to call them?
  Head 4: Mode parameters β€” how many rounds, threshold, etc.?

Design: Sequential decomposition keeps each head's distribution
tractable for PPO. The policy sees a flattened joint action, but
training uses the factored structure.
"""

from __future__ import annotations
from dataclasses import dataclass
from enum import IntEnum
from typing import Optional
import numpy as np


class MetaAction(IntEnum):
    """High-level decisions the orchestrator can take (decoded by head 1)."""

    # Call one or more specialists.
    CALL_SPECIALIST = 0
    # Stop delegating and synthesize the output.
    STOP = 1
    # Call the conflict mediator.
    CALL_MEDIATOR = 2
    # Ask for task clarification (used when the task is ambiguous).
    CLARIFY_TASK = 3
    # Delegate a sub-problem (second level of delegation).
    DELEGATE_SUBTASK = 4
    # Retry a specialist that failed, using a fallback.
    RETRY_FAILED = 5
    # Spawn specialists in parallel.
    PARALLEL_SPAWN = 6
    # The policy requests that a new specialist be created.
    SPAWN_SPECIALIST = 7


class DelegationMode(IntEnum):
    """Execution strategies for the selected specialists (decoded by head 3)."""

    # A → B → C: each specialist sees the previous one's output.
    SEQUENTIAL = 0
    # A, B and C all run simultaneously.
    PARALLEL = 1
    # A, B and C run, then the mediator reduces their output.
    FAN_OUT_REDUCE = 2
    # Run a specialist, check its output, and loop until a threshold is reached.
    ITERATIVE = 3
    # Run A; if the condition is met run B, otherwise run C.
    CONDITIONAL = 4
    # Run in priority order and stop once the threshold is met.
    PRIORITY_QUEUE = 5
    # Send to all specialists and take the first one to complete.
    BROADCAST = 6


@dataclass
class FactoredAction:
    """
    Fully decoded action combining the output of all 4 heads.

    Instances of this class are what the environment's step() function
    receives.
    """
    # Head 1: the high-level decision.
    meta_action: MetaAction
    # Head 2: identifiers of the specialists to call.
    specialist_ids: list[str]
    # Head 3: how the selected specialists are called.
    delegation_mode: DelegationMode
    # Head 4: parameters specific to the delegation mode.
    mode_params: dict
    # Raw policy output, kept for logging.
    raw_action: Optional[np.ndarray] = None

    def is_terminal(self) -> bool:
        """Report whether this action ends the episode (meta-action is STOP)."""
        stop_action = MetaAction.STOP
        return self.meta_action == stop_action

    def to_log_dict(self) -> dict:
        """
        Summarize the action as a plain dict for logging.

        Enum fields are rendered by member name; the specialist list and
        the parameter dict are included as-is (not copied). raw_action is
        not part of the summary.
        """
        summary = {}
        summary["meta_action"] = self.meta_action.name
        summary["specialists"] = self.specialist_ids
        summary["mode"] = self.delegation_mode.name
        summary["params"] = self.mode_params
        return summary


class ActionDecoder:
    """
    Decodes a flat action vector from the policy into a FactoredAction.

    Action vector layout:
      [0]                     : meta_action index
                                (int, 0 .. NUM_META_ACTIONS - 1, i.e. 0–7)
      [1 : 1+max_specialists] : specialist selection (multi-hot float)
      [1+max_specialists]     : delegation_mode index
                                (int, 0 .. NUM_DELEGATION_MODES - 1, i.e. 0–6)
      [2+max_specialists : *] : mode_params (continuous, NUM_MODE_PARAMS floats)

    Total action dim = 1 + max_specialists + 1 + NUM_MODE_PARAMS
                     = max_specialists + 6
    """

    NUM_META_ACTIONS    = len(MetaAction)
    NUM_DELEGATION_MODES = len(DelegationMode)
    NUM_MODE_PARAMS     = 4

    def __init__(self, specialist_ids: list[str], max_specialists: int = 8):
        """
        Args:
            specialist_ids: Identifiers of the known specialists; position i
                            corresponds to slot i of the selection head.
            max_specialists: Upper bound on the number of selection slots.
                             The effective count is capped at
                             len(specialist_ids).
        """
        self.specialist_ids = specialist_ids
        self.max_specialists = min(len(specialist_ids), max_specialists)
        # Derived from the layout instead of a literal 6 so it stays in sync
        # with NUM_MODE_PARAMS.
        self.action_dim = 1 + self.max_specialists + 1 + self.NUM_MODE_PARAMS

    def decode(
        self,
        action_vector: np.ndarray,
        valid_specialist_mask: Optional[np.ndarray] = None,
    ) -> FactoredAction:
        """
        Decode a flat action vector into a FactoredAction.

        Args:
            action_vector: Flat numpy array from the policy
            valid_specialist_mask: Binary mask, 1 = valid, 0 = masked out
                                   (enforces DAG constraints)

        Returns:
            The decoded FactoredAction. Its raw_action is the float32 copy
            of action_vector. specialist_ids may be empty, including for
            CALL_SPECIALIST when no specialist is valid (none registered or
            all masked out).

        Raises:
            IndexError: If action_vector is too short to hold the meta-action,
                        the delegation-mode slot, or the mode parameters read
                        for the decoded mode.
        """
        action_vector = np.asarray(action_vector, dtype=np.float32)

        # Head 1: Meta-action (rounded, then clamped into the enum's range)
        meta_idx = int(np.clip(round(action_vector[0]), 0, self.NUM_META_ACTIONS - 1))
        meta_action = MetaAction(meta_idx)

        # Head 2: Specialist selection (multi-hot)
        spec_logits = action_vector[1: 1 + self.max_specialists]
        if valid_specialist_mask is None:
            valid = np.ones(spec_logits.shape, dtype=bool)
        else:
            valid = np.asarray(valid_specialist_mask)[:self.max_specialists] > 0

        # A specialist is selected when its logit is positive AND it is valid.
        selected_indices = np.where((spec_logits > 0.0) & valid)[0]
        if len(selected_indices) == 0 and meta_action == MetaAction.CALL_SPECIALIST:
            # Fallback: select the highest-scoring *valid* specialist.
            # Masked entries are pushed to -inf so that a masked-out slot can
            # never win the argmax, even when every valid logit is <= 0.
            candidates = np.where(valid, spec_logits, -np.inf)
            if candidates.size > 0 and valid.any():
                selected_indices = [int(np.argmax(candidates))]

        selected_ids = [
            self.specialist_ids[i]
            for i in selected_indices
            if i < len(self.specialist_ids)
        ]

        # Head 3: Delegation mode (rounded, then clamped into the enum's range)
        mode_idx = int(np.clip(
            round(action_vector[1 + self.max_specialists]),
            0, self.NUM_DELEGATION_MODES - 1
        ))
        delegation_mode = DelegationMode(mode_idx)

        # Head 4: Mode parameters
        param_start = 2 + self.max_specialists
        raw_params = action_vector[param_start: param_start + self.NUM_MODE_PARAMS]
        mode_params = self._decode_mode_params(delegation_mode, raw_params)

        return FactoredAction(
            meta_action=meta_action,
            specialist_ids=selected_ids,
            delegation_mode=delegation_mode,
            mode_params=mode_params,
            raw_action=action_vector,
        )

    def _decode_mode_params(
        self, mode: DelegationMode, raw_params: np.ndarray
    ) -> dict:
        """
        Decode mode-specific parameters from the raw continuous params.

        Each raw value is clipped to [0, 1] and then mapped linearly onto the
        range of the parameter it drives. Only the leading entries that the
        given mode needs are read; the rest are ignored.

        Args:
            mode: Delegation mode whose parameters are being decoded.
            raw_params: Continuous parameter slice of the action vector.

        Returns:
            Dict of parameter name -> value for the given mode. Modes without
            dedicated parameters get a "parallel_budget_ms" entry.

        Raises:
            IndexError: If raw_params has fewer entries than the mode reads.
        """
        p = np.clip(raw_params, 0.0, 1.0)
        if mode == DelegationMode.ITERATIVE:
            return {
                "max_rounds": int(1 + round(p[0] * 4)),          # 1–5 rounds
                "quality_threshold": float(0.5 + p[1] * 0.5),    # 0.5–1.0
            }
        elif mode == DelegationMode.PRIORITY_QUEUE:
            return {
                "stop_threshold": float(0.6 + p[0] * 0.4),       # 0.6–1.0
            }
        elif mode == DelegationMode.CONDITIONAL:
            return {
                "condition_threshold": float(0.4 + p[0] * 0.6),  # 0.4–1.0
            }
        else:
            return {"parallel_budget_ms": int(2000 + p[0] * 6000)}  # 2000–8000 ms

    def get_action_dim(self) -> int:
        """Return the length of the flat action vector this decoder expects."""
        return self.action_dim

    def build_specialist_mask(
        self, valid_specialist_ids: list[str]
    ) -> np.ndarray:
        """
        Build a binary mask for valid specialist selections.

        Args:
            valid_specialist_ids: Identifiers of the specialists that may
                                  currently be selected.

        Returns:
            float32 array of length max_specialists with 1.0 at slot i when
            specialist_ids[i] is in valid_specialist_ids, and 0.0 otherwise.
        """
        mask = np.zeros(self.max_specialists, dtype=np.float32)
        valid_set = set(valid_specialist_ids)
        for i, sid in enumerate(self.specialist_ids[: self.max_specialists]):
            if sid in valid_set:
                mask[i] = 1.0
        return mask