import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Optional


class MetaController(nn.Module):
    """
    Actor-critic RL meta-controller for hypergraph optimization.

    A reinforcement learning system that optimizes the hypergraph architecture
    at runtime using policy-gradient methods. The meta-controller uses an
    actor-critic architecture to learn architectural modifications (pruning or
    creating hyperedges, for example) that improve overall system performance.

    Architecture:
        - Input: 16384-dim SDR from Module A (SimHash Encoder)
        - Shared network: 2 layers (state_dim → 256 → 256) with GELU activation
        - Actor head: Linear(256 → 10) outputs policy logits π(a|s) over 10 meta-actions
        - Critic head: Linear(256 → 1) outputs the state-value estimate V(s)

    Meta-actions (0-indexed):
        0. INCREASE_SPARSITY_THRESHOLD - Raise the sparsity threshold for hyperedge activation
        1. DECREASE_SPARSITY_THRESHOLD - Lower the sparsity threshold for hyperedge activation
        2. PRUNE_WEAKEST_EDGE - Remove the hyperedge with the lowest weight
        3. CREATE_RANDOM_EDGE - Add a new hyperedge connecting random nodes
        4. MERGE_SIMILAR_EDGES - Combine hyperedges with overlapping node sets
        5. SPLIT_DENSE_EDGE - Divide a hyperedge with many nodes into two smaller edges
        6. BOOST_ACH - Boost acetylcholine (attention neuromodulator)
        7. BOOST_NE - Boost norepinephrine (arousal neuromodulator)
        8. TRIGGER_SLEEP - Trigger the sleep-consolidation mechanism
        9. NO_OP - No operation; keep the current configuration

    Reference:
        - Sutton & Barto (2018), Reinforcement Learning: An Introduction,
          Section 13.5: Actor-Critic Methods
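
    Example (illustrative sketch; a real loop would couple the reward to
    measured system performance):
        >>> controller = MetaController()
        >>> state = torch.randn(16384)
        >>> action, value = controller.select_action(state, training=True)
        >>> name, description = MetaController.META_ACTIONS[action]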
| """ |
|
|
| |
    META_ACTIONS = {
        0: (
            "INCREASE_SPARSITY_THRESHOLD",
            "Increase sparsity threshold for hyperedge activation",
        ),
        1: (
            "DECREASE_SPARSITY_THRESHOLD",
            "Decrease sparsity threshold for hyperedge activation",
        ),
        2: ("PRUNE_WEAKEST_EDGE", "Remove hyperedge with lowest weight"),
        3: ("CREATE_RANDOM_EDGE", "Add new hyperedge connecting random nodes"),
        4: ("MERGE_SIMILAR_EDGES", "Combine hyperedges with overlapping node sets"),
        5: (
            "SPLIT_DENSE_EDGE",
            "Divide hyperedge with many nodes into two smaller edges",
        ),
        6: ("BOOST_ACH", "Boost acetylcholine (attention neuromodulator)"),
        7: ("BOOST_NE", "Boost norepinephrine (arousal neuromodulator)"),
        8: ("TRIGGER_SLEEP", "Trigger sleep consolidation mechanism"),
        9: ("NO_OP", "No operation, continue with current configuration"),
    }

    def __init__(
        self,
        state_dim: int = 16384,
        hidden_dim: int = 256,
        num_actions: int = 10,
        device: Optional[str] = None,
    ) -> None:
        """
        Initialize MetaController with an actor-critic architecture.

        Args:
            state_dim: Dimensionality of the input state (SDR from Module A)
            hidden_dim: Size of the hidden layers in the shared network
            num_actions: Number of meta-actions (fixed at 10)
            device: Device to run on ('cuda' or 'cpu'). If None, auto-detects GPU.
        """
        super().__init__()
        assert num_actions == 10, f"num_actions must be 10, got {num_actions}"

        self.state_dim = state_dim
        self.hidden_dim = hidden_dim
        self.num_actions = num_actions

        # Auto-detect GPU when no device is given.
        if device is None:
            device = "cuda" if torch.cuda.is_available() else "cpu"
        self.device = torch.device(device)

        # Shared trunk: state_dim → hidden_dim → hidden_dim with GELU activations.
        self.shared = nn.Sequential(
            nn.Linear(state_dim, hidden_dim),
            nn.GELU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.GELU(),
        )

        # Actor head: policy logits over the meta-actions.
        self.actor = nn.Linear(hidden_dim, num_actions)

        # Critic head: scalar state-value estimate V(s).
        self.critic = nn.Linear(hidden_dim, 1)

        # Move all parameters to the selected device.
        self.to(self.device)

    def forward(self, state: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Forward pass through the actor-critic network.

        Args:
            state: Input state tensor of shape (state_dim,) or (batch, state_dim)

        Returns:
            logits: Action logits of shape (num_actions,) or (batch, num_actions)
            value: State value of shape () or (batch,); the trailing singleton dim is squeezed

        Example:
            >>> controller = MetaController()
            >>> state = torch.randn(16384)
            >>> logits, value = controller(state)
            >>> logits.shape
            torch.Size([10])
            >>> value.shape
            torch.Size([])
        """
        # Shared trunk features feed both heads.
        features = self.shared(state)

        # Actor: unnormalized policy logits π(a|s).
        logits = self.actor(features)

        # Critic: squeeze the trailing singleton dim so V(s) is scalar per sample.
        value = self.critic(features).squeeze(-1)

        return logits, value

    def select_action(
        self, state: torch.Tensor, training: bool = True
    ) -> tuple[int, float]:
        """
        Sample an action from the policy during training; act greedily during evaluation.

        Args:
            state: torch.Tensor [state_dim] - Single state (not batched)
            training: bool - If True, sample from the policy; if False, use argmax

        Returns:
            action: int - Selected action index (0-9)
            value: float - Predicted value of the state

        Example:
            >>> controller = MetaController()
            >>> state = torch.randn(16384)
            >>> action, value = controller.select_action(state, training=True)
            >>> assert 0 <= action < 10
            >>> assert isinstance(value, float)
        """
        # No gradients needed here: both outputs are converted to Python scalars.
        with torch.no_grad():
            logits, value = self.forward(state)

        # Guard against overflow/NaN before sampling: clamp the logits, zero out
        # non-finite probabilities, and renormalize, falling back to a uniform
        # distribution if the probability mass degenerates.
        safe_logits = torch.clamp(logits, -50.0, 50.0)
        probs = F.softmax(safe_logits, dim=-1)
        probs = torch.where(torch.isfinite(probs), probs, torch.zeros_like(probs))
        probs = probs + 1e-8
        total = probs.sum()
        if not torch.isfinite(total) or total <= 0:
            probs = torch.ones_like(probs) / probs.numel()
        else:
            probs = probs / total

        if training:
            # Explore: sample from the categorical policy distribution.
            action = int(torch.multinomial(probs, 1).item())
        else:
            # Exploit: pick the most probable action.
            action = int(probs.argmax().item())

        return action, float(value.item())

    def compute_loss(
        self,
        state: torch.Tensor,
        action: int,
        reward: float,
        old_value: float,
        gamma: float = 0.99,
    ) -> torch.Tensor:
        """
        Compute the actor-critic loss from a one-step advantage (reward minus baseline).

        Loss = actor_loss + 0.5 * critic_loss

        Args:
            state: Current state (state_dim,) - SDR from Module A
            action: Action taken (int, 0-9) - Selected meta-action index
            reward: Reward received (float) - Performance-improvement signal
            old_value: Value estimated before taking the action (float) - Baseline for the advantage
            gamma: Discount factor (reserved; unused since no next-state value is bootstrapped)

        Returns:
            loss: Combined actor + critic loss (scalar tensor)

        Example:
            >>> controller = MetaController()
            >>> state = torch.randn(16384)
            >>> action = 2
            >>> reward = 1.0
            >>> old_value = 0.5
            >>> loss = controller.compute_loss(state, action, reward, old_value)
            >>> assert loss.ndim == 0  # Scalar loss
            >>> assert not torch.isnan(loss)  # No NaN
        """
        logits, value = self.forward(state)

        # Advantage: one-step error against the pre-action baseline. `old_value`
        # is a detached float, so no gradient flows through the advantage.
        advantage = reward - old_value

        # Critic loss must use the *current* value prediction (a tensor), not
        # the detached `old_value`, or the critic head receives no gradient.
        critic_loss = (reward - value) ** 2

        # Policy-gradient actor loss: -log π(a|s) weighted by the advantage.
        log_probs = F.log_softmax(logits, dim=-1)
        actor_loss = -log_probs[action] * advantage

        loss = actor_loss + 0.5 * critic_loss

        return loss

    def execute_action(self, action: int, hypergraph: Optional[object]) -> None:
        """
        Execute a meta-action on the hypergraph (Module B).

        Stub implementation - will be completed in later subtasks.
        Applies structural modifications to the hypergraph based on the selected
        meta-action (e.g., pruning edges, adjusting sparsity, boosting neuromodulators).

        Args:
            action: Action index (0-9) - Must be a valid meta-action from META_ACTIONS
            hypergraph: HypergraphManifold instance from Module B (None allowed for testing)

        Returns:
            None

        Raises:
            ValueError: If the action index is outside the valid range [0, 9]

        Example:
            >>> controller = MetaController()
            >>> result = controller.execute_action(2, None)  # PRUNE_WEAKEST_EDGE
            >>> assert result is None  # Stub returns None
        """
        if not 0 <= action < self.num_actions:
            raise ValueError(
                f"Invalid action {action}. Must be in range [0, {self.num_actions - 1}]"
            )

        # Stub: structural modifications to `hypergraph` are implemented in
        # later subtasks; the range validation above is the only behavior for now.
        return None
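

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only, not part of the module API).
# Assumptions: the random `state` stands in for a real SDR from Module A, the
# random `reward` stands in for a measured performance-improvement signal,
# `hypergraph` is None because `execute_action` is still a stub, and the Adam
# learning rate and loop length are arbitrary.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    controller = MetaController()
    optimizer = torch.optim.Adam(controller.parameters(), lr=1e-4)

    for step in range(5):
        # Stand-in for a 16384-dim SDR from Module A (SimHash Encoder).
        state = torch.randn(controller.state_dim, device=controller.device)

        # Decide: sample a meta-action and record the value baseline.
        action, old_value = controller.select_action(state, training=True)
        controller.execute_action(action, hypergraph=None)  # stub for now

        # Stand-in reward; a real run would measure the performance change here.
        reward = float(torch.rand(1).item())

        # One actor-critic update on this single transition.
        loss = controller.compute_loss(state, action, reward, old_value)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        name, _ = MetaController.META_ACTIONS[action]
        print(f"step={step} action={name} value={old_value:.3f} loss={loss.item():.3f}")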