import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Optional


class MetaController(nn.Module):
    """
    Actor-Critic RL Meta-Controller for hypergraph optimization.

    Reinforcement learning system that dynamically optimizes hypergraph
    architecture during runtime using policy gradient methods. The
    meta-controller uses an actor-critic architecture to learn optimal
    architectural modifications (pruning/creating hyperedges) that improve
    overall system performance.

    Architecture:
        - Input: 16384-dim SDR from Module A (SimHash Encoder)
        - Shared network: 2 layers (state_dim → 256 → 256) with GELU activation
        - Actor head: Linear(256 → 10) outputs policy logits π(a|s) over 10 meta-actions
        - Critic head: Linear(256 → 1) outputs state value estimate V(s)

    Meta-Actions (0-indexed):
        0. INCREASE_SPARSITY_THRESHOLD - Increase sparsity threshold for hyperedge activation
        1. DECREASE_SPARSITY_THRESHOLD - Decrease sparsity threshold for hyperedge activation
        2. PRUNE_WEAKEST_EDGE - Remove hyperedge with lowest weight
        3. CREATE_RANDOM_EDGE - Add new hyperedge connecting random nodes
        4. MERGE_SIMILAR_EDGES - Combine hyperedges with overlapping node sets
        5. SPLIT_DENSE_EDGE - Divide hyperedge with many nodes into two smaller edges
        6. BOOST_ACH - Boost acetylcholine (attention neuromodulator)
        7. BOOST_NE - Boost norepinephrine (arousal neuromodulator)
        8. TRIGGER_SLEEP - Trigger sleep consolidation mechanism
        9. NO_OP - No operation, continue with current configuration

    Reference:
        Sutton & Barto (2018) - Reinforcement Learning: An Introduction,
        Chapter 13.5: Actor-Critic Methods
    """

    # Meta-Actions: dictionary mapping action indices to (name, description) tuples
    META_ACTIONS = {
        0: (
            "INCREASE_SPARSITY_THRESHOLD",
            "Increase sparsity threshold for hyperedge activation",
        ),
        1: (
            "DECREASE_SPARSITY_THRESHOLD",
            "Decrease sparsity threshold for hyperedge activation",
        ),
        2: ("PRUNE_WEAKEST_EDGE", "Remove hyperedge with lowest weight"),
        3: ("CREATE_RANDOM_EDGE", "Add new hyperedge connecting random nodes"),
        4: ("MERGE_SIMILAR_EDGES", "Combine hyperedges with overlapping node sets"),
        5: (
            "SPLIT_DENSE_EDGE",
            "Divide hyperedge with many nodes into two smaller edges",
        ),
        6: ("BOOST_ACH", "Boost acetylcholine (attention neuromodulator)"),
        7: ("BOOST_NE", "Boost norepinephrine (arousal neuromodulator)"),
        8: ("TRIGGER_SLEEP", "Trigger sleep consolidation mechanism"),
        9: ("NO_OP", "No operation, continue with current configuration"),
    }

    def __init__(
        self,
        state_dim: int = 16384,
        hidden_dim: int = 256,
        num_actions: int = 10,
        device: Optional[str] = None,
    ) -> None:
        """
        Initialize MetaController with actor-critic architecture.

        Args:
            state_dim: Dimensionality of input state (SDR from Module A)
            hidden_dim: Size of hidden layers in shared network
            num_actions: Number of meta-actions (fixed at 10)
            device: Device to run on ('cuda' or 'cpu'). If None, auto-detects GPU.
""" super().__init__() assert num_actions == 10, f"num_actions must be 10, got {num_actions}" self.state_dim = state_dim self.hidden_dim = hidden_dim self.num_actions = num_actions # Auto-detect device if not specified if device is None: device = "cuda" if torch.cuda.is_available() else "cpu" self.device = torch.device(device) # Shared feature extraction (2 layers with GELU) self.shared = nn.Sequential( nn.Linear(state_dim, hidden_dim), nn.GELU(), nn.Linear(hidden_dim, hidden_dim), nn.GELU(), ) # Actor head: policy π(a|s) - outputs logits over actions self.actor = nn.Linear(hidden_dim, num_actions) # Critic head: value V(s) - outputs scalar state value self.critic = nn.Linear(hidden_dim, 1) # Move to device self.to(self.device) def forward(self, state: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: """ Forward pass through actor-critic network. Args: state: Input state tensor of shape (state_dim,) or (batch, state_dim) Returns: logits: Action logits of shape (num_actions,) or (batch, num_actions) value: State value of shape () or (batch,) - note: scalar/squeezed output Example: >>> controller = MetaController() >>> state = torch.randn(16384) >>> logits, value = controller(state) >>> logits.shape torch.Size([10]) >>> value.shape torch.Size([]) """ # Pass through shared network features = self.shared(state) # Actor: policy logits logits = self.actor(features) # Critic: value estimate (squeeze to remove last dimension) value = self.critic(features).squeeze(-1) return logits, value def select_action( self, state: torch.Tensor, training: bool = True ) -> tuple[int, float]: """ Sample action from policy during training, greedy during evaluation. Args: state: torch.Tensor [state_dim] - Single state (not batched) training: bool - If True, sample from policy; if False, use argmax Returns: action: int - Selected action index (0-9) value: float - Predicted value of the state Example: >>> controller = MetaController() >>> state = torch.randn(16384) >>> action, value = controller.select_action(state, training=True) >>> assert 0 <= action < 10 >>> assert isinstance(value, float) """ logits, value = self.forward(state) safe_logits = torch.clamp(logits, -50.0, 50.0) probs = F.softmax(safe_logits, dim=-1) probs = torch.where(torch.isfinite(probs), probs, torch.zeros_like(probs)) probs = probs + 1e-8 total = probs.sum() if not torch.isfinite(total) or total <= 0: probs = torch.ones_like(probs) / probs.numel() else: probs = probs / total if training: # Sample from policy (exploration) action = int(torch.multinomial(probs, 1).item()) else: # Greedy (exploitation) action = int(probs.argmax().item()) return action, float(value.item()) def compute_loss( self, state: torch.Tensor, action: int, reward: float, old_value: float, gamma: float = 0.99, ) -> torch.Tensor: """ Compute actor-critic loss with one-step TD error. 
        Loss = actor_loss + 0.5 * critic_loss

        Args:
            state: Current state (state_dim,) - SDR from Module A
            action: Action taken (int, 0-9) - Selected meta-action index
            reward: Reward received (float) - Performance improvement signal
            old_value: Value estimated before taking action (float) - Baseline for advantage
            gamma: Discount factor (unused in this one-step version: the reward
                is treated as the full return, with no bootstrapped next-state value)

        Returns:
            loss: Combined actor + critic loss (scalar tensor)

        Example:
            >>> controller = MetaController()
            >>> state = torch.randn(16384)
            >>> action = 2
            >>> reward = 1.0
            >>> old_value = 0.5
            >>> loss = controller.compute_loss(state, action, reward, old_value)
            >>> assert loss.ndim == 0  # Scalar loss
            >>> assert not torch.isnan(loss)  # No NaN
        """
        logits, value = self.forward(state)

        # TD error (advantage) - one-step temporal difference against the
        # pre-action baseline. Both terms are Python floats, so the advantage
        # is constant with respect to the network parameters.
        td_error = reward - old_value

        # Critic loss: squared error between the reward and the *current*
        # value prediction. Using the live `value` tensor (not the detached
        # float `old_value`) is what lets gradients reach the critic head.
        critic_loss = (reward - value) ** 2

        # Actor loss: policy gradient with advantage
        log_probs = F.log_softmax(logits, dim=-1)
        actor_loss = -log_probs[action] * td_error

        # Combined loss (weight critic by 0.5, standard practice)
        loss = actor_loss + 0.5 * critic_loss

        return loss

    def execute_action(self, action: int, hypergraph: Optional[object]) -> None:
        """
        Execute a meta-action on the hypergraph (Module B).

        Stub implementation - will be completed in later subtasks. Applies
        structural modifications to the hypergraph based on the selected
        meta-action (e.g., pruning edges, adjusting sparsity, boosting
        neuromodulators).

        Args:
            action: Action index (0-9) - Must be a valid meta-action from META_ACTIONS
            hypergraph: HypergraphManifold instance from Module B (None allowed for testing)

        Returns:
            None

        Raises:
            ValueError: If action index is out of valid range [0-9]

        Example:
            >>> controller = MetaController()
            >>> result = controller.execute_action(2, None)  # PRUNE_WEAKEST_EDGE
            >>> assert result is None  # Stub returns None
        """
        if not 0 <= action < self.num_actions:
            raise ValueError(
                f"Invalid action {action}. Must be in range [0, {self.num_actions - 1}]"
            )

        # Stub: actual implementation will modify hypergraph based on action
        # Will be implemented in subtask-3-3 through subtask-3-12
        return None
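
# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only, not part of the module API). It
# runs a single select → execute → learn step with a random state and a
# synthetic reward; the Adam optimizer, the 3e-4 learning rate, and the
# placeholder reward of 1.0 are assumptions chosen for the demo, not values
# prescribed by the class above.
if __name__ == "__main__":
    controller = MetaController(state_dim=16384)
    optimizer = torch.optim.Adam(controller.parameters(), lr=3e-4)

    # Observe a state (random here; in practice, the SDR from Module A).
    state = torch.randn(16384, device=controller.device)

    # Select an action and record the value baseline for the advantage.
    action, old_value = controller.select_action(state, training=True)
    name, _ = MetaController.META_ACTIONS[action]
    print(f"Selected action {action}: {name}")

    # Apply the action (stub for now) and observe a reward (synthetic here).
    controller.execute_action(action, hypergraph=None)
    reward = 1.0  # placeholder performance-improvement signal

    # One actor-critic update.
    loss = controller.compute_loss(state, action, reward, old_value)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    print(f"Loss: {loss.item():.4f}, baseline value: {old_value:.4f}")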