File size: 5,468 Bytes
02ff91f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
"""
State Representation — Fully observable episode state for the RL policy.

State components:
  1. Task embedding (384-dim) — what needs to be done
  2. Roster embedding matrix (N × 384) — available specialists
  3. Called specialist embeddings (N × 384, zero rows for uncalled) — who has been called
  4. Delegation graph adjacency vector (100-dim, flattened 10 × 10) — call structure
  5. Scratchpad summary embedding (384-dim) — context so far
  6. Scalar features (8-dim) — step count, depth, specialists called, SLA usage, phase
  7. Called specialist mask (N-dim) — binary, who's been called

Flattened total: 876 + 769*N dims, where N = max_specialists (per-specialist
arrays are zero-padded to N rows; e.g. 7028 dims for the default N = 8)
"""

from __future__ import annotations
import numpy as np
from dataclasses import dataclass, field
from typing import Optional


@dataclass
class EpisodeState:
    """
    Complete state for one timestep in an episode.
    Built by the SpindleFlowEnv at each step.

    The per-specialist arrays (roster_embeddings, called_embeddings,
    called_mask) are expected to be padded to ``max_specialists`` rows so
    that the flattened observation has a fixed length; see
    ``observation_dim``.
    """

    # Layout widths shared by to_flat_vector() and observation_dim().
    # These are deliberately unannotated so @dataclass treats them as plain
    # class attributes rather than constructor fields.
    EMBED_DIM = 384      # width of every semantic embedding row
    GRAPH_DIM = 100      # flattened 10x10 delegation adjacency
    NUM_SCALARS = 8      # length of the scalar feature block

    # Core semantic representations
    task_embedding: np.ndarray          # (384,)
    roster_embeddings: np.ndarray       # (max_specialists, 384)
    called_embeddings: np.ndarray       # (max_specialists, 384) — 0s for uncalled
    scratchpad_embedding: np.ndarray    # (384,)

    # Structural signals
    delegation_graph_adj: np.ndarray    # (100,) flat adjacency
    called_mask: np.ndarray             # (max_specialists,) binary

    # Scalar features
    step_count: int
    delegation_depth: int
    num_specialists_called: int
    max_specialists: int
    max_depth: int
    elapsed_ms: float
    sla_budget_ms: float
    phase: int  # 1, 2, or 3 (curriculum phase)

    def to_flat_vector(self) -> np.ndarray:
        """
        Flatten the full state to a 1D float32 array for the policy.
        This is the observation that the LSTM policy receives.

        Concatenation order: task embedding, roster embeddings, called
        embeddings, scratchpad embedding, delegation adjacency, called mask,
        then the 8 scalar features built below.

        Returns:
            1D ``np.float32`` array. Its length equals
            ``observation_dim(self.max_specialists)`` when every array field
            has the shape noted on its declaration.
        """
        # Guard the ratio denominators: with plain ints, a zero limit
        # (e.g. max_depth=0) would raise ZeroDivisionError.
        depth_limit = max(self.max_depth, 1)
        roster_limit = max(self.max_specialists, 1)

        scalar_features = np.array([
            self.step_count / 10.0,
            self.delegation_depth / depth_limit,
            self.num_specialists_called / roster_limit,
            self.elapsed_ms / max(self.sla_budget_ms, 1.0),
            float(self.phase) / 3.0,
            float(self.num_specialists_called > 0),
            # ">=" keeps the "depth limit reached" flag set even if the
            # reported depth ever overshoots the configured maximum.
            float(self.delegation_depth >= self.max_depth),
            # Flag once 80% of the SLA budget has been consumed.
            float(self.elapsed_ms > self.sla_budget_ms * 0.8),
        ], dtype=np.float32)

        parts = [
            self.task_embedding.flatten(),
            self.roster_embeddings.flatten(),
            self.called_embeddings.flatten(),
            self.scratchpad_embedding.flatten(),
            self.delegation_graph_adj.flatten(),
            self.called_mask.flatten(),
            scalar_features,
        ]
        return np.concatenate(parts).astype(np.float32)

    @staticmethod
    def observation_dim(max_specialists: int = 8) -> int:
        """
        Compute the flat observation dimension given max_specialists.

        Equals 876 + 769 * max_specialists with the current layout widths
        (7028 for the default of 8).
        """
        embed = EpisodeState.EMBED_DIM
        task       = embed
        roster     = max_specialists * embed
        called     = max_specialists * embed
        scratchpad = embed
        graph      = EpisodeState.GRAPH_DIM   # 10×10 adjacency
        mask       = max_specialists
        scalars    = EpisodeState.NUM_SCALARS
        return task + roster + called + scratchpad + graph + mask + scalars


def build_state(
    task_embedding: np.ndarray,
    registry,           # SpecialistRegistry
    called_ids: list[str],
    delegation_graph,   # DelegationGraph
    scratchpad,         # SharedScratchpad
    step_count: int,
    elapsed_ms: float,
    sla_budget_ms: float,
    max_specialists: int = 8,
    max_depth: int = 2,
    phase: int = 1,
    active_ids: list[str] | None = None,
) -> EpisodeState:
    """
    Factory function to build EpisodeState from all environment components.
    Called at each step by SpindleFlowEnv.

    Args:
        task_embedding: task embedding, stored on the state unchanged
            (presumably shape (384,) — confirm with the caller).
        registry: specialist registry; this function uses ``list_ids()``,
            ``get(sid).to_state_vector()`` and ``embed_query``.
        called_ids: ids of specialists called so far in the episode.
        delegation_graph: uses ``to_adjacency_vector(ids, max_size=10)``
            and the ``depth`` attribute.
        scratchpad: uses ``to_summary_vector(registry.embed_query)``.
        step_count, elapsed_ms, sla_budget_ms, max_depth, phase: copied
            onto the returned state.
        max_specialists: number of roster slots. Ids beyond this many are
            dropped; unused slots stay zero-filled.
        active_ids: explicit per-episode roster (top-K by task similarity + any
            spawned specialists). When provided, replaces the default
            insertion-order slice of ``registry.list_ids()``.

    Returns:
        EpisodeState whose roster/called matrices are (max_specialists, 384)
        float32 arrays and whose called_mask is a (max_specialists,) float32
        0/1 vector aligned with the roster slots.
    """
    # Only the first max_specialists ids get a slot in the observation.
    source_ids = active_ids if active_ids is not None else registry.list_ids()
    all_ids = list(source_ids)[:max_specialists]

    # Set lookup instead of scanning the called_ids list once per slot.
    called_set = set(called_ids)

    roster_matrix = np.zeros((max_specialists, 384), dtype=np.float32)
    called_matrix = np.zeros((max_specialists, 384), dtype=np.float32)
    called_mask = np.zeros(max_specialists, dtype=np.float32)
    for i, sid in enumerate(all_ids):
        roster_matrix[i] = registry.get(sid).to_state_vector()
        if sid in called_set:
            # Reuse the row just written rather than querying the registry
            # a second time; slice assignment copies the values.
            called_matrix[i] = roster_matrix[i]
            called_mask[i] = 1.0

    # Delegation graph adjacency vector (10x10 flattened by the graph)
    adj_vector = np.array(
        delegation_graph.to_adjacency_vector(all_ids, max_size=10),
        dtype=np.float32,
    )

    # Scratchpad summary embedding
    scratchpad_emb = np.array(
        scratchpad.to_summary_vector(registry.embed_query),
        dtype=np.float32,
    )

    return EpisodeState(
        task_embedding=task_embedding,
        roster_embeddings=roster_matrix,
        called_embeddings=called_matrix,
        scratchpad_embedding=scratchpad_emb,
        delegation_graph_adj=adj_vector,
        called_mask=called_mask,
        step_count=step_count,
        delegation_depth=delegation_graph.depth,
        # NOTE(review): counts every entry of called_ids, including ids that
        # fall outside the visible roster slots (and duplicates, if any).
        num_specialists_called=len(called_ids),
        max_specialists=max_specialists,
        max_depth=max_depth,
        elapsed_ms=elapsed_ms,
        sla_budget_ms=sla_budget_ms,
        phase=phase,
    )