Spaces:
Paused
Paused
| """ | |
| State Representation — Fully observable episode state for the RL policy. | |
| State components: | |
| 1. Task embedding (384-dim) — what needs to be done | |
| 2. Roster embedding matrix (N × 384) — available specialists | |
| 3. Called specialist embeddings (K × 384) — who has been called | |
| 4. Delegation graph adjacency vector (100-dim) — call structure | |
| 5. Scratchpad summary embedding (384-dim) — context so far | |
| 6. Scalar features (8-dim) — step count, depth, costs, etc. | |
| 7. Called specialist mask (N-dim) — binary, who's been called | |
| Flattened total: 876 + 769*N dims, where N = max_specialists (7028 for N = 8); the roster/called matrices and the mask are zero-padded to N rows | |
| """ | |
| from __future__ import annotations | |
| import numpy as np | |
| from dataclasses import dataclass, field | |
| from typing import Optional | |
@dataclass
class EpisodeState:
    """
    Complete state for one timestep in an episode.

    Built by the SpindleFlowEnv at each step. The class is a dataclass so it
    can be constructed with keyword arguments (one per field below) and gets
    a generated ``__init__``/``__repr__``/``__eq__``.
    """

    # Core semantic representations
    task_embedding: np.ndarray  # (384,)
    roster_embeddings: np.ndarray  # (max_specialists, 384)
    called_embeddings: np.ndarray  # (max_specialists, 384) — 0s for uncalled
    scratchpad_embedding: np.ndarray  # (384,)

    # Structural signals
    delegation_graph_adj: np.ndarray  # (100,) flat adjacency
    called_mask: np.ndarray  # (max_specialists,) binary

    # Scalar features
    step_count: int
    delegation_depth: int
    num_specialists_called: int
    max_specialists: int
    max_depth: int
    elapsed_ms: float
    sla_budget_ms: float
    phase: int  # 1, 2, or 3 (curriculum phase)

    def to_flat_vector(self) -> np.ndarray:
        """
        Flatten the full state to a 1D float32 numpy array for the policy.

        This is the observation that the LSTM policy receives. The segments
        are concatenated in this order: task embedding, roster embeddings,
        called embeddings, scratchpad embedding, delegation-graph adjacency,
        called mask, and finally the 8 scalar features:

        0. step_count / 10
        1. delegation_depth / max_depth
        2. num_specialists_called / max_specialists
        3. elapsed_ms / sla_budget_ms
        4. phase / 3
        5. 1.0 if at least one specialist has been called, else 0.0
        6. 1.0 if delegation_depth equals max_depth, else 0.0
        7. 1.0 if elapsed_ms exceeds 80% of sla_budget_ms, else 0.0

        Returns:
            A 1D ``np.float32`` array holding every segment back to back.
        """
        # Clamp every denominator to at least 1 so that a zero cap yields a
        # finite ratio instead of raising ZeroDivisionError. The SLA ratio
        # was already guarded this way; depth and roster caps now match it.
        depth_cap = max(self.max_depth, 1)
        roster_cap = max(self.max_specialists, 1)
        sla_cap = max(self.sla_budget_ms, 1.0)

        scalar_features = np.array(
            [
                self.step_count / 10.0,
                self.delegation_depth / depth_cap,
                self.num_specialists_called / roster_cap,
                self.elapsed_ms / sla_cap,
                float(self.phase) / 3.0,
                float(self.num_specialists_called > 0),
                float(self.delegation_depth == self.max_depth),
                float(self.elapsed_ms > self.sla_budget_ms * 0.8),
            ],
            dtype=np.float32,
        )

        parts = [
            self.task_embedding.flatten(),
            self.roster_embeddings.flatten(),
            self.called_embeddings.flatten(),
            self.scratchpad_embedding.flatten(),
            self.delegation_graph_adj.flatten(),
            self.called_mask.flatten(),
            scalar_features,
        ]
        return np.concatenate(parts).astype(np.float32)
def observation_dim(max_specialists: int = 8) -> int:
    """Return the length of the flat observation vector for a roster size.

    The vector has a fixed part (task embedding, scratchpad embedding,
    10×10 adjacency, 8 scalars) plus a part that grows with every roster
    slot (one roster row, one called row, one mask entry).
    """
    embed_dim = 384
    graph_cells = 100  # 10×10 adjacency
    scalar_count = 8

    # task + scratchpad embeddings, adjacency vector, scalar features
    fixed_part = embed_dim + embed_dim + graph_cells + scalar_count
    # roster row + called row + one mask bit per specialist slot
    per_slot = embed_dim + embed_dim + 1
    return fixed_part + max_specialists * per_slot
def build_state(
    task_embedding: np.ndarray,
    registry,  # SpecialistRegistry
    called_ids: list[str],
    delegation_graph,  # DelegationGraph
    scratchpad,  # SharedScratchpad
    step_count: int,
    elapsed_ms: float,
    sla_budget_ms: float,
    max_specialists: int = 8,
    max_depth: int = 2,
    phase: int = 1,
    active_ids: list[str] | None = None,
) -> EpisodeState:
    """
    Factory function to build EpisodeState from all environment components.

    Called at each step by SpindleFlowEnv.

    Args:
        task_embedding: Embedding of the task; stored on the state as given.
        registry: Provides ``list_ids()``, ``get(sid).to_state_vector()`` and
            ``embed_query`` (passed to the scratchpad summariser).
        called_ids: IDs of specialists that have been called so far.
        delegation_graph: Provides ``to_adjacency_vector(ids, max_size=10)``
            and a ``depth`` attribute.
        scratchpad: Provides ``to_summary_vector(embed_fn)``.
        step_count: Current step number within the episode.
        elapsed_ms: Time spent so far.
        sla_budget_ms: Time budget for the episode.
        max_specialists: Number of roster slots; the roster is truncated to
            this many IDs and unused slots stay zero.
        max_depth: Maximum delegation depth, forwarded to the state.
        phase: Curriculum phase, forwarded to the state.
        active_ids: explicit per-episode roster (top-K by task similarity +
            any spawned specialists). When provided, replaces the default
            insertion-order slice of ``registry.list_ids()``.

    Returns:
        The populated EpisodeState for this timestep.
    """
    candidate_ids = (
        list(active_ids) if active_ids is not None else registry.list_ids()
    )
    # Truncating here bounds every row index below, so no further
    # per-iteration bounds checks are needed.
    roster_ids = candidate_ids[:max_specialists]

    # Build the lookup once: set membership instead of a list scan per slot.
    called_lookup = set(called_ids)

    # Roster / called matrices and the called mask, filled in a single pass.
    roster_matrix = np.zeros((max_specialists, 384), dtype=np.float32)
    called_matrix = np.zeros((max_specialists, 384), dtype=np.float32)
    called_mask = np.zeros(max_specialists, dtype=np.float32)
    for row, sid in enumerate(roster_ids):
        roster_matrix[row] = registry.get(sid).to_state_vector()
        if sid in called_lookup:
            # Reuse the row just written rather than fetching the same
            # specialist vector from the registry a second time.
            called_matrix[row] = roster_matrix[row]
            called_mask[row] = 1.0

    # Delegation graph adjacency vector
    adj_vector = np.array(
        delegation_graph.to_adjacency_vector(roster_ids, max_size=10),
        dtype=np.float32,
    )

    # Scratchpad summary embedding
    scratchpad_emb = np.array(
        scratchpad.to_summary_vector(registry.embed_query),
        dtype=np.float32,
    )

    return EpisodeState(
        task_embedding=task_embedding,
        roster_embeddings=roster_matrix,
        called_embeddings=called_matrix,
        scratchpad_embedding=scratchpad_emb,
        delegation_graph_adj=adj_vector,
        called_mask=called_mask,
        step_count=step_count,
        delegation_depth=delegation_graph.depth,
        # NOTE(review): this counts every entry of called_ids, including IDs
        # outside the truncated roster or repeated IDs — confirm intended.
        num_specialists_called=len(called_ids),
        max_specialists=max_specialists,
        max_depth=max_depth,
        elapsed_ms=elapsed_ms,
        sla_budget_ms=sla_budget_ms,
        phase=phase,
    )