# sem-v6-training/src/sem_v6/modules/module_d.py
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Optional
class MetaController(nn.Module):
"""
Actor-Critic RL Meta-Controller for hypergraph optimization.
    Reinforcement learning system that dynamically optimizes the hypergraph
    architecture at runtime using policy-gradient methods. The meta-controller
    uses an actor-critic architecture to learn architectural modifications
    (e.g., pruning or creating hyperedges) that improve overall system
    performance.
Architecture:
    - Input: 16384-dim SDR (sparse distributed representation) from Module A (SimHash Encoder)
- Shared network: 2 layers (state_dim → 256 → 256) with GELU activation
- Actor head: Linear(256 → 10) outputs policy logits π(a|s) over 10 meta-actions
- Critic head: Linear(256 → 1) outputs state value estimate V(s)
Meta-Actions (0-indexed):
0. INCREASE_SPARSITY_THRESHOLD - Increase sparsity threshold for hyperedge activation
1. DECREASE_SPARSITY_THRESHOLD - Decrease sparsity threshold for hyperedge activation
2. PRUNE_WEAKEST_EDGE - Remove hyperedge with lowest weight
3. CREATE_RANDOM_EDGE - Add new hyperedge connecting random nodes
4. MERGE_SIMILAR_EDGES - Combine hyperedges with overlapping node sets
5. SPLIT_DENSE_EDGE - Divide hyperedge with many nodes into two smaller edges
6. BOOST_ACH - Boost acetylcholine (attention neuromodulator)
7. BOOST_NE - Boost norepinephrine (arousal neuromodulator)
8. TRIGGER_SLEEP - Trigger sleep consolidation mechanism
9. NO_OP - No operation, continue with current configuration
Reference:
- Sutton & Barto (2018) - Reinforcement Learning: An Introduction
- Chapter 13.5: Actor-Critic Methods
"""
# Meta-Actions: Dictionary mapping action indices to (name, description) tuples
META_ACTIONS = {
0: (
"INCREASE_SPARSITY_THRESHOLD",
"Increase sparsity threshold for hyperedge activation",
),
1: (
"DECREASE_SPARSITY_THRESHOLD",
"Decrease sparsity threshold for hyperedge activation",
),
2: ("PRUNE_WEAKEST_EDGE", "Remove hyperedge with lowest weight"),
3: ("CREATE_RANDOM_EDGE", "Add new hyperedge connecting random nodes"),
4: ("MERGE_SIMILAR_EDGES", "Combine hyperedges with overlapping node sets"),
5: (
"SPLIT_DENSE_EDGE",
"Divide hyperedge with many nodes into two smaller edges",
),
6: ("BOOST_ACH", "Boost acetylcholine (attention neuromodulator)"),
7: ("BOOST_NE", "Boost norepinephrine (arousal neuromodulator)"),
8: ("TRIGGER_SLEEP", "Trigger sleep consolidation mechanism"),
9: ("NO_OP", "No operation, continue with current configuration"),
}
def __init__(
self,
state_dim: int = 16384,
hidden_dim: int = 256,
num_actions: int = 10,
device: Optional[str] = None,
) -> None:
"""
Initialize MetaController with actor-critic architecture.
Args:
state_dim: Dimensionality of input state (SDR from Module A)
hidden_dim: Size of hidden layers in shared network
num_actions: Number of meta-actions (fixed at 10)
device: Device to run on ('cuda' or 'cpu'). If None, auto-detects GPU.
"""
super().__init__()
        if num_actions != 10:
            raise ValueError(f"num_actions must be 10, got {num_actions}")
self.state_dim = state_dim
self.hidden_dim = hidden_dim
self.num_actions = num_actions
# Auto-detect device if not specified
if device is None:
device = "cuda" if torch.cuda.is_available() else "cpu"
self.device = torch.device(device)
# Shared feature extraction (2 layers with GELU)
self.shared = nn.Sequential(
nn.Linear(state_dim, hidden_dim),
nn.GELU(),
nn.Linear(hidden_dim, hidden_dim),
nn.GELU(),
)
# Actor head: policy π(a|s) - outputs logits over actions
self.actor = nn.Linear(hidden_dim, num_actions)
# Critic head: value V(s) - outputs scalar state value
self.critic = nn.Linear(hidden_dim, 1)
# Move to device
self.to(self.device)
def forward(self, state: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
"""
Forward pass through actor-critic network.
Args:
state: Input state tensor of shape (state_dim,) or (batch, state_dim)
Returns:
logits: Action logits of shape (num_actions,) or (batch, num_actions)
value: State value of shape () or (batch,) - note: scalar/squeezed output
Example:
>>> controller = MetaController()
            >>> state = torch.randn(16384, device=controller.device)
>>> logits, value = controller(state)
>>> logits.shape
torch.Size([10])
>>> value.shape
torch.Size([])
"""
# Pass through shared network
features = self.shared(state)
# Actor: policy logits
logits = self.actor(features)
# Critic: value estimate (squeeze to remove last dimension)
value = self.critic(features).squeeze(-1)
return logits, value
def select_action(
self, state: torch.Tensor, training: bool = True
) -> tuple[int, float]:
"""
Sample action from policy during training, greedy during evaluation.
Args:
state: torch.Tensor [state_dim] - Single state (not batched)
training: bool - If True, sample from policy; if False, use argmax
Returns:
action: int - Selected action index (0-9)
value: float - Predicted value of the state
Example:
>>> controller = MetaController()
            >>> state = torch.randn(16384, device=controller.device)
>>> action, value = controller.select_action(state, training=True)
>>> assert 0 <= action < 10
>>> assert isinstance(value, float)
"""
        # Inference-only pass: gradients are recomputed in compute_loss, so
        # skip building the autograd graph here
        with torch.no_grad():
            logits, value = self.forward(state)
        # Clamp logits to avoid overflow in softmax, then scrub non-finite
        # entries and renormalize so multinomial sampling stays valid
        safe_logits = torch.clamp(logits, -50.0, 50.0)
        probs = F.softmax(safe_logits, dim=-1)
        probs = torch.where(torch.isfinite(probs), probs, torch.zeros_like(probs))
        probs = probs + 1e-8
        total = probs.sum()
        if not torch.isfinite(total) or total <= 0:
            # Degenerate distribution: fall back to uniform over all actions
            probs = torch.ones_like(probs) / probs.numel()
        else:
            probs = probs / total
if training:
# Sample from policy (exploration)
action = int(torch.multinomial(probs, 1).item())
else:
# Greedy (exploitation)
action = int(probs.argmax().item())
return action, float(value.item())
def compute_loss(
self,
state: torch.Tensor,
action: int,
reward: float,
old_value: float,
gamma: float = 0.99,
) -> torch.Tensor:
"""
        Compute the actor-critic loss for a single transition.
        The advantage is the immediate reward minus the pre-action value
        baseline old_value; no next-state value is bootstrapped, so gamma
        is unused.
        Loss = actor_loss + 0.5 * critic_loss
Args:
state: Current state (state_dim,) - SDR from Module A
action: Action taken (int, 0-9) - Selected meta-action index
reward: Reward received (float) - Performance improvement signal
old_value: Value estimated before taking action (float) - Baseline for advantage
            gamma: Discount factor (unused: no next-state value is bootstrapped)
Returns:
loss: Combined actor + critic loss (scalar tensor)
Example:
>>> controller = MetaController()
            >>> state = torch.randn(16384, device=controller.device)
>>> action = 2
>>> reward = 1.0
>>> old_value = 0.5
>>> loss = controller.compute_loss(state, action, reward, old_value)
>>> assert loss.ndim == 0 # Scalar loss
>>> assert not torch.isnan(loss) # No NaN
"""
        logits, value = self.forward(state)
        # Advantage: immediate reward minus the pre-action value baseline.
        # old_value is a plain float, so no gradient flows through the baseline.
        advantage = reward - old_value
        # Critic loss: regress the current value estimate toward the observed
        # reward; using `value` (not the float baseline) is what gives the
        # critic head a gradient signal
        critic_loss = (reward - value) ** 2
        # Actor loss: policy gradient weighted by the advantage
        log_probs = F.log_softmax(logits, dim=-1)
        actor_loss = -log_probs[action] * advantage
        # Combined loss (weight critic by 0.5, standard practice)
        loss = actor_loss + 0.5 * critic_loss
        return loss
def execute_action(self, action: int, hypergraph: Optional[object]) -> None:
"""
Execute a meta-action on the hypergraph (Module B).
Stub implementation - will be completed in later subtasks.
Applies structural modifications to the hypergraph based on the selected
meta-action (e.g., pruning edges, adjusting sparsity, boosting neuromodulators).
Args:
action: Action index (0-9) - Must be valid meta-action from META_ACTIONS
hypergraph: HypergraphManifold instance from Module B (None allowed for testing)
Returns:
None
Raises:
ValueError: If action index is out of valid range [0-9]
Example:
>>> controller = MetaController()
>>> result = controller.execute_action(2, None) # PRUNE_WEAKEST_EDGE
>>> assert result is None # Stub returns None
"""
if not 0 <= action < self.num_actions:
raise ValueError(
f"Invalid action {action}. Must be in range [0, {self.num_actions - 1}]"
)
# Stub: actual implementation will modify hypergraph based on action
# Will be implemented in subtask-3-3 through subtask-3-12
return None
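

# Minimal usage sketch (illustrative only): one full actor-critic update step.
# The reward here is a random placeholder for the real performance signal, and
# the optimizer choice (Adam, lr=3e-4) is an assumption, not part of the module.
if __name__ == "__main__":
    controller = MetaController()
    optimizer = torch.optim.Adam(controller.parameters(), lr=3e-4)

    # Stand-in for the 16384-dim SDR produced by Module A
    state = torch.randn(16384, device=controller.device)

    # 1. Select a meta-action and record the pre-action value baseline
    action, old_value = controller.select_action(state, training=True)
    print("selected:", MetaController.META_ACTIONS[action][0])

    # 2. Execute the action (stub) and observe a reward (placeholder value)
    controller.execute_action(action, hypergraph=None)
    reward = float(torch.randn(()))

    # 3. One gradient step on the combined actor-critic loss
    loss = controller.compute_loss(state, action, reward, old_value)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    print(f"loss: {loss.item():.4f}")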