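"""Actor-critic RL meta-controller for dynamic hypergraph optimization."""
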
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Optional


class MetaController(nn.Module):
    """
    Actor-Critic RL Meta-Controller for hypergraph optimization.

    Reinforcement learning system that optimizes the hypergraph architecture at
    runtime using policy gradient methods. The meta-controller uses an
    actor-critic architecture to learn architectural modifications (pruning or
    creating hyperedges) that improve overall system performance.

    Architecture:
    - Input: 16384-dim SDR from Module A (SimHash Encoder)
    - Shared network: 2 layers (state_dim → 256 → 256) with GELU activation
    - Actor head: Linear(256 → 10) outputs policy logits π(a|s) over 10 meta-actions
    - Critic head: Linear(256 → 1) outputs state value estimate V(s)

    Meta-Actions (0-indexed):
    0. INCREASE_SPARSITY_THRESHOLD - Increase sparsity threshold for hyperedge activation
    1. DECREASE_SPARSITY_THRESHOLD - Decrease sparsity threshold for hyperedge activation
    2. PRUNE_WEAKEST_EDGE - Remove hyperedge with lowest weight
    3. CREATE_RANDOM_EDGE - Add new hyperedge connecting random nodes
    4. MERGE_SIMILAR_EDGES - Combine hyperedges with overlapping node sets
    5. SPLIT_DENSE_EDGE - Divide hyperedge with many nodes into two smaller edges
    6. BOOST_ACH - Boost acetylcholine (attention neuromodulator)
    7. BOOST_NE - Boost norepinephrine (arousal neuromodulator)
    8. TRIGGER_SLEEP - Trigger sleep consolidation mechanism
    9. NO_OP - No operation, continue with current configuration

    Reference:
    - Sutton & Barto (2018), Reinforcement Learning: An Introduction,
      Section 13.5: Actor-Critic Methods
    """

    # Meta-Actions: Dictionary mapping action indices to (name, description) tuples
    META_ACTIONS = {
        0: (
            "INCREASE_SPARSITY_THRESHOLD",
            "Increase sparsity threshold for hyperedge activation",
        ),
        1: (
            "DECREASE_SPARSITY_THRESHOLD",
            "Decrease sparsity threshold for hyperedge activation",
        ),
        2: ("PRUNE_WEAKEST_EDGE", "Remove hyperedge with lowest weight"),
        3: ("CREATE_RANDOM_EDGE", "Add new hyperedge connecting random nodes"),
        4: ("MERGE_SIMILAR_EDGES", "Combine hyperedges with overlapping node sets"),
        5: (
            "SPLIT_DENSE_EDGE",
            "Divide hyperedge with many nodes into two smaller edges",
        ),
        6: ("BOOST_ACH", "Boost acetylcholine (attention neuromodulator)"),
        7: ("BOOST_NE", "Boost norepinephrine (arousal neuromodulator)"),
        8: ("TRIGGER_SLEEP", "Trigger sleep consolidation mechanism"),
        9: ("NO_OP", "No operation, continue with current configuration"),
    }

    def __init__(
        self,
        state_dim: int = 16384,
        hidden_dim: int = 256,
        num_actions: int = 10,
        device: Optional[str] = None,
    ) -> None:
        """
        Initialize MetaController with actor-critic architecture.

        Args:
            state_dim: Dimensionality of input state (SDR from Module A)
            hidden_dim: Size of hidden layers in shared network
            num_actions: Number of meta-actions (fixed at 10)
            device: Device to run on ('cuda' or 'cpu'). If None, auto-detects GPU.
        """
        super().__init__()
        assert num_actions == 10, f"num_actions must be 10, got {num_actions}"

        self.state_dim = state_dim
        self.hidden_dim = hidden_dim
        self.num_actions = num_actions

        # Auto-detect device if not specified
        if device is None:
            device = "cuda" if torch.cuda.is_available() else "cpu"
        self.device = torch.device(device)

        # Shared feature extraction (2 layers with GELU)
        self.shared = nn.Sequential(
            nn.Linear(state_dim, hidden_dim),
            nn.GELU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.GELU(),
        )

        # Actor head: policy π(a|s) - outputs logits over actions
        self.actor = nn.Linear(hidden_dim, num_actions)

        # Critic head: value V(s) - outputs scalar state value
        self.critic = nn.Linear(hidden_dim, 1)

        # Move to device
        self.to(self.device)

    def forward(self, state: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Forward pass through actor-critic network.

        Args:
            state: Input state tensor of shape (state_dim,) or (batch, state_dim)

        Returns:
            logits: Action logits of shape (num_actions,) or (batch, num_actions)
            value: State value of shape () or (batch,); the critic head's trailing singleton dimension is squeezed

        Example:
            >>> controller = MetaController()
            >>> state = torch.randn(16384)
            >>> logits, value = controller(state)
            >>> logits.shape
            torch.Size([10])
            >>> value.shape
            torch.Size([])
        """
        # Pass through shared network
        features = self.shared(state)

        # Actor: policy logits
        logits = self.actor(features)

        # Critic: value estimate (squeeze to remove last dimension)
        value = self.critic(features).squeeze(-1)

        return logits, value

    def select_action(
        self, state: torch.Tensor, training: bool = True
    ) -> tuple[int, float]:
        """
        Sample action from policy during training, greedy during evaluation.

        Args:
            state: torch.Tensor [state_dim] - Single state (not batched)
            training: bool - If True, sample from policy; if False, use argmax

        Returns:
            action: int - Selected action index (0-9)
            value: float - Predicted value of the state

        Example:
            >>> controller = MetaController()
            >>> state = torch.randn(16384)
            >>> action, value = controller.select_action(state, training=True)
            >>> assert 0 <= action < 10
            >>> assert isinstance(value, float)
        """
        # Inference only: no gradients are needed for action selection
        with torch.no_grad():
            logits, value = self.forward(state)

        # Numerical safety: clamp logits, zero out non-finite probabilities,
        # and renormalize; fall back to a uniform distribution if degenerate.
        safe_logits = torch.clamp(logits, -50.0, 50.0)
        probs = F.softmax(safe_logits, dim=-1)
        probs = torch.where(torch.isfinite(probs), probs, torch.zeros_like(probs))
        probs = probs + 1e-8
        total = probs.sum()
        if not torch.isfinite(total) or total <= 0:
            probs = torch.ones_like(probs) / probs.numel()
        else:
            probs = probs / total

        if training:
            # Sample from policy (exploration)
            action = int(torch.multinomial(probs, 1).item())
        else:
            # Greedy (exploitation)
            action = int(probs.argmax().item())

        return action, float(value.item())

    def compute_loss(
        self,
        state: torch.Tensor,
        action: int,
        reward: float,
        old_value: float,
        gamma: float = 0.99,
    ) -> torch.Tensor:
        """
        Compute the combined actor-critic loss from a single transition.

        Loss = actor_loss + 0.5 * critic_loss, where
        actor_loss = -log π(a|s) * advantage with advantage = reward - old_value,
        and critic_loss = (reward - V(s))^2 so the critic head receives gradients.

        Args:
            state: Current state (state_dim,) - SDR from Module A
            action: Action taken (int, 0-9) - Selected meta-action index
            reward: Reward received (float) - Performance improvement signal
            old_value: Value estimated before taking action (float) - Baseline for advantage
            gamma: Discount factor (unused; this one-step version does not bootstrap a next-state value)

        Returns:
            loss: Combined actor + critic loss (scalar tensor)

        Example:
            >>> controller = MetaController()
            >>> state = torch.randn(16384)
            >>> action = 2
            >>> reward = 1.0
            >>> old_value = 0.5
            >>> loss = controller.compute_loss(state, action, reward, old_value)
            >>> assert loss.ndim == 0  # Scalar loss
            >>> assert not torch.isnan(loss)  # No NaN
        """
        logits, value = self.forward(state)

        # Advantage: one-step improvement over the pre-action baseline.
        # No bootstrapped next-state value is used, so gamma does not appear.
        advantage = reward - old_value

        # Critic loss: regress the current value prediction toward the observed
        # reward so the critic head receives a training signal.
        critic_loss = (reward - value) ** 2

        # Actor loss: policy gradient weighted by the constant advantage
        log_probs = F.log_softmax(logits, dim=-1)
        actor_loss = -log_probs[action] * advantage

        # Combined loss (weight critic by 0.5, standard practice)
        loss = actor_loss + 0.5 * critic_loss

        return loss

    def execute_action(self, action: int, hypergraph: Optional[object]) -> None:
        """
        Execute a meta-action on the hypergraph (Module B).

        Stub implementation - will be completed in later subtasks.
        Applies structural modifications to the hypergraph based on the selected
        meta-action (e.g., pruning edges, adjusting sparsity, boosting neuromodulators).

        Args:
            action: Action index (0-9) - Must be valid meta-action from META_ACTIONS
            hypergraph: HypergraphManifold instance from Module B (None allowed for testing)

        Returns:
            None

        Raises:
            ValueError: If action index is out of valid range [0-9]

        Example:
            >>> controller = MetaController()
            >>> result = controller.execute_action(2, None)  # PRUNE_WEAKEST_EDGE
            >>> assert result is None  # Stub returns None
        """
        if not 0 <= action < self.num_actions:
            raise ValueError(
                f"Invalid action {action}. Must be in range [0, {self.num_actions - 1}]"
            )

        # Stub: actual implementation will modify hypergraph based on action
        # Will be implemented in subtask-3-3 through subtask-3-12
        return None
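

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only, not part of the module's API).
# It walks one actor-critic update step under assumed conditions: a random
# tensor stands in for the SDR from Module A, the reward is a synthetic
# stand-in for the real performance-improvement signal, and the Adam
# optimizer with lr=1e-4 is an assumption rather than a project setting.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    controller = MetaController(state_dim=16384)
    optimizer = torch.optim.Adam(controller.parameters(), lr=1e-4)

    # Stand-in state; the full system would pass Module A's 16384-dim SDR.
    state = torch.randn(16384, device=controller.device)

    # 1. Select a meta-action and record the pre-action value estimate.
    action, old_value = controller.select_action(state, training=True)
    name, description = controller.META_ACTIONS[action]
    print(f"Selected action {action}: {name} ({description})")

    # 2. Apply the action to the hypergraph (stub for now), observe a reward.
    controller.execute_action(action, hypergraph=None)
    reward = 1.0  # synthetic reward

    # 3. Take one gradient step on the combined actor-critic loss.
    loss = controller.compute_loss(state, action, reward, old_value)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    print(f"Loss: {loss.item():.4f}")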