Spaces:
Paused
Paused
File size: 5,468 Bytes
02ff91f | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 | """
State Representation — Fully observable episode state for the RL policy.
State components:
1. Task embedding (384-dim) — what needs to be done
2. Roster embedding matrix (N × 384) — available specialists
3. Called specialist embeddings (K × 384) — who has been called
4. Delegation graph adjacency vector (100-dim) — call structure
5. Scratchpad summary embedding (384-dim) — context so far
6. Scalar features (8-dim) — step count, depth, costs, etc.
7. Called specialist mask (N-dim) — binary, who's been called
Flattened total: 876 + 769*N dims, where N = max_specialists (7028 for N = 8; per-specialist rows are zero-padded to max_specialists)
"""
from __future__ import annotations
import numpy as np
from dataclasses import dataclass, field
from typing import Optional
@dataclass
class EpisodeState:
    """
    Complete state for one timestep in an episode.

    Built by the SpindleFlowEnv at each step. Bundles the embedding arrays,
    the structural signals and the scalar counters of an episode, and can
    flatten them into a single 1D observation via ``to_flat_vector``.
    """

    # Core semantic representations
    task_embedding: np.ndarray  # (384,)
    roster_embeddings: np.ndarray  # (max_specialists, 384)
    called_embeddings: np.ndarray  # (max_specialists, 384) — 0s for uncalled
    scratchpad_embedding: np.ndarray  # (384,)

    # Structural signals
    delegation_graph_adj: np.ndarray  # (100,) flat adjacency
    called_mask: np.ndarray  # (max_specialists,) binary

    # Scalar features
    step_count: int
    delegation_depth: int
    num_specialists_called: int
    max_specialists: int
    max_depth: int
    elapsed_ms: float
    sla_budget_ms: float
    phase: int  # 1, 2, or 3 (curriculum phase)

    def to_flat_vector(self) -> np.ndarray:
        """
        Flatten the full state to a 1D float32 numpy array for the policy.

        This is the observation that the LSTM policy receives. The layout is:
        task embedding, roster embeddings, called embeddings, scratchpad
        embedding, delegation adjacency, called mask, then 8 scalar features.

        Returns:
            A 1D ``np.float32`` array holding every component, concatenated
            in the order listed above.
        """
        # Clamp every denominator to at least 1 so that a zero limit
        # (max_depth == 0, max_specialists == 0, sla_budget_ms == 0) yields a
        # finite feature instead of raising ZeroDivisionError.
        depth_limit = max(self.max_depth, 1)
        specialist_limit = max(self.max_specialists, 1)
        sla_limit = max(self.sla_budget_ms, 1.0)

        scalar_features = np.array(
            [
                self.step_count / 10.0,
                self.delegation_depth / depth_limit,
                self.num_specialists_called / specialist_limit,
                self.elapsed_ms / sla_limit,
                float(self.phase) / 3.0,
                # Indicator flags
                float(self.num_specialists_called > 0),  # anyone called yet
                float(self.delegation_depth == self.max_depth),  # at depth cap
                float(self.elapsed_ms > self.sla_budget_ms * 0.8),  # >80% of SLA
            ],
            dtype=np.float32,
        )

        parts = [
            self.task_embedding.flatten(),
            self.roster_embeddings.flatten(),
            self.called_embeddings.flatten(),
            self.scratchpad_embedding.flatten(),
            self.delegation_graph_adj.flatten(),
            self.called_mask.flatten(),
            scalar_features,
        ]
        return np.concatenate(parts).astype(np.float32)

    @staticmethod
    def observation_dim(max_specialists: int = 8) -> int:
        """
        Compute the flat observation dimension given max_specialists.

        Args:
            max_specialists: Number of specialist slots in the roster.

        Returns:
            The length of the vector produced by ``to_flat_vector`` when the
            arrays have the shapes noted on the fields above.
        """
        task = 384
        roster = max_specialists * 384
        called = max_specialists * 384
        scratchpad = 384
        graph = 100  # 10×10 adjacency
        mask = max_specialists
        scalars = 8
        return task + roster + called + scratchpad + graph + mask + scalars
def build_state(
    task_embedding: np.ndarray,
    registry,  # SpecialistRegistry
    called_ids: list[str],
    delegation_graph,  # DelegationGraph
    scratchpad,  # SharedScratchpad
    step_count: int,
    elapsed_ms: float,
    sla_budget_ms: float,
    max_specialists: int = 8,
    max_depth: int = 2,
    phase: int = 1,
    active_ids: list[str] | None = None,
) -> EpisodeState:
    """
    Factory function to build EpisodeState from all environment components.

    Called at each step by SpindleFlowEnv.

    Args:
        task_embedding: Task embedding, stored on the state unchanged.
        registry: Provides ``list_ids()``, ``get(sid).to_state_vector()`` and
            ``embed_query``.
        called_ids: Ids of the specialists called so far.
        delegation_graph: Provides ``to_adjacency_vector(ids, max_size=10)``
            and ``depth``.
        scratchpad: Provides ``to_summary_vector(embed_fn)``.
        step_count: Current step index within the episode.
        elapsed_ms: Time spent so far.
        sla_budget_ms: Time budget for the episode.
        max_specialists: Number of roster slots; ids beyond it are dropped.
        max_depth: Delegation depth limit recorded on the state.
        phase: Curriculum phase recorded on the state.
        active_ids: explicit per-episode roster (top-K by task similarity +
            any spawned specialists). When provided, replaces the default
            insertion-order slice.

    Returns:
        The populated ``EpisodeState``. Roster slots without a specialist
        stay zero-filled.
    """
    candidate_ids = (
        list(active_ids) if active_ids is not None else registry.list_ids()
    )
    # Truncating here guarantees every enumerate() index below is a valid row.
    all_ids = candidate_ids[:max_specialists]

    roster_matrix = np.zeros((max_specialists, 384), dtype=np.float32)
    called_matrix = np.zeros((max_specialists, 384), dtype=np.float32)
    called_mask = np.zeros(max_specialists, dtype=np.float32)

    # Build the set once so each membership test is O(1) instead of a list scan.
    called_lookup = set(called_ids)

    # Single pass: one registry lookup per specialist fills the roster row
    # and, when the specialist has been called, the matching called row/mask.
    for row, sid in enumerate(all_ids):
        state_vec = registry.get(sid).to_state_vector()
        roster_matrix[row] = state_vec
        if sid in called_lookup:
            called_matrix[row] = state_vec
            called_mask[row] = 1.0

    # Delegation graph adjacency vector
    adj_vector = np.array(
        delegation_graph.to_adjacency_vector(all_ids, max_size=10),
        dtype=np.float32,
    )

    # Scratchpad summary embedding
    scratchpad_emb = np.array(
        scratchpad.to_summary_vector(registry.embed_query),
        dtype=np.float32,
    )

    return EpisodeState(
        task_embedding=task_embedding,
        roster_embeddings=roster_matrix,
        called_embeddings=called_matrix,
        scratchpad_embedding=scratchpad_emb,
        delegation_graph_adj=adj_vector,
        called_mask=called_mask,
        step_count=step_count,
        delegation_depth=delegation_graph.depth,
        # Counts every entry in called_ids, including any not in the roster.
        num_specialists_called=len(called_ids),
        max_specialists=max_specialists,
        max_depth=max_depth,
        elapsed_ms=elapsed_ms,
        sla_budget_ms=sla_budget_ms,
        phase=phase,
    )
|