diff --git "a/CLEANED_modular_brain_agent.py" "b/CLEANED_modular_brain_agent.py" new file mode 100644--- /dev/null +++ "b/CLEANED_modular_brain_agent.py" @@ -0,0 +1,1274 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch.utils.data import DataLoader, TensorDataset +from torchvision import datasets, transforms +from torch.utils.tensorboard import SummaryWriter +import numpy as np +import random +from collections import deque +import os # For creating log directory + +# --- 0. Global Setup and Vocabulary --- + +# Simulated language data +sample_sentences = [ +"I love AI", +"Deep learning is fun", +"Spiking neurons are cool", +"Brain inspired models rock", +"Replay buffer helps learning" +] + + +def simple_tokenize(text): + """Basic tokenizer for sample sentences.""" + return text.lower().split()[:10] + + # Build vocab manually + vocab = {"": 0} + word_counter = 1 + for sentence in sample_sentences: + for word in simple_tokenize(sentence): + if word not in vocab: + vocab[word] = word_counter + word_counter += 1 + + # Task IDs (using an Enum-like class for clarity) + class Task: + REGRESSION = 3 + BINARY = 4 + VISION = 5 + + # --- 1. Base Classes and Spiking Neuron Implementations --- + + class Module(nn.Module): + """Base class for brain-inspired modules.""" + + def __init__(self): + super().__init__() + + def forward(self, x: torch.Tensor) -> torch.Tensor: + raise NotImplementedError + + class SurrogateLIF(torch.autograd.Function): + """ + Surrogate gradient function for LIF neurons. + Allows backpropagation through the non-differentiable spiking + function. + """ + @staticmethod + def forward( + ctx, input: torch.Tensor) -> torch.Tensor: + ctx.save_for_backward(input) + return (input > 0).float() + + @staticmethod + def backward( + ctx, grad_output: torch.Tensor) -> torch.Tensor: + input, = ctx.saved_tensors + grad_input = grad_output.clone() + # Approximate derivative: a + # constant value where + # input is + close to 0 + grad_input[input.abs( + ) < 1] = 1.0 + return grad_input + + class SpikingNeuron( + Module): + """ + Leaky Integrate-and-Fire (LIF) neuron model. + Resets membrane potential to zero after spiking. + """ + + def __init__( + self, threshold: float = 1.0, decay: float = 0.95): + super().__init__() + self.threshold = threshold + self.decay = decay + # Register buffer + # for membrane + # potential, + # persists across + forward calls + # type: + # torch.Tensor + self.register_buffer( + 'mem', None) + + def forward( + self, x: torch.Tensor) -> torch.Tensor: + # Initialize or + # re-initialize + # membrane + # potential if + # shape + changes + if self.mem is None or self.mem.shape != x.shape: + self.mem = torch.zeros_like( + x) + + # Update + # membrane + # potential + self.mem = self.decay * self.mem + x + # Generate + # spike + spike = ( + self.mem >= self.threshold).float() + + # Reset + # membrane + # potential + # for + # spiking + # neurons + self.mem = torch.where(spike.bool(), + torch.zeros_like(self.mem), self.mem) + + # Apply + # surrogate + # gradient + # for + # backpropagation + return SurrogateLIF.apply( + spike) + + class AdaptiveLIFNeuron( + SpikingNeuron): + """ + LIF neuron with an adaptive threshold based on recent firing rate. + """ + + def __init__( + self, threshold: float = 1.0, decay: float = 0.95): + super().__init__(threshold=threshold, decay=decay) + # Register + # buffer + # for + # adaptive + # threshold + # state + self.register_buffer('threshold_state', + torch.tensor(threshold)) + + def forward( + self, x: torch.Tensor) -> torch.Tensor: + if self.mem is None or self.mem.shape != x.shape: + self.mem = torch.zeros_like( + x) + + self.mem = self.decay * self.mem + x + spike = ( + self.mem >= self.threshold_state).float() + self.mem = torch.where(spike.bool(), + torch.zeros_like(self.mem), self.mem) + + # Adapt + # threshold + # based + # on + # scalar + # mean + # spike + # rate + # (no_grad + # for + threshold update) + with torch.no_grad(): + spike_rate = spike.mean().item() # Scalar rate + # Simple + # moving + # average + # for + # threshold + # adaptation + self.threshold_state = self.threshold_state * 0.99 + + spike_rate * 0.01 + + return SurrogateLIF.apply( + spike) + + # --- 2. Encoder Modules (Task-Specific Input Processing) --- + + class SharedEncoder( + nn.Module): + """Generic encoder for numerical data.""" + def __init__( + self, input_size: int, hidden_size: int = 4): + super().__init__() + self.encoder = nn.Sequential( + nn.Linear( + input_size, hidden_size), + nn.LayerNorm( + hidden_size), + nn.ReLU(), + nn.Dropout( + 0.3) + + ) + + def forward( + self, x: torch.Tensor) -> torch.Tensor: + assert x.dim() == 2, f"Expected 2D input for SharedEncoder, + got {x.shape}" + return self.encoder(x.float()) + + + class CNNVision(nn.Module): + """CNN encoder for image data (e.g., MNIST).""" + def __init__(self, output_features: int = 4): + super().__init__() + self.conv = nn.Sequential( + nn.Conv2d(1, 4, kernel_size=5), + nn.ReLU(), + nn.MaxPool2d(2), + nn.Conv2d(4, 8, kernel_size=5), + nn.ReLU(), + nn.AdaptiveAvgPool2d((1, 1)), # Outputs [batch, 8, 1, 1] + nn.Flatten(), # Outputs [batch, 8] + nn.Linear(8, output_features) # Maps to desired output + size + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.conv(x) + + + class GRULanguage(nn.Module): + """GRU encoder for sequential language data.""" + def __init__(self, vocab_size: int, embedding_dim: int = 4, + hidden_dim: int = 4): + super().__init__() + self.embed = nn.Embedding(vocab_size, embedding_dim) + self.gru = nn.GRU(embedding_dim, hidden_dim, batch_first=True) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + if x.dim() == 1: + x = x.unsqueeze(0) # Add batch dimension if missing + emb = self.embed(x.long()) # Ensure input is long type for + embedding + out, _ = self.gru(emb) + return out[:, -1, :] # Use last hidden state as sequence + representation + + # --- 3. Brain-Inspired Modules with Specific Neuron Counts --- + + class SensoryProcessor(Module): + + """Processes initial sensory input, often with low-level + features.""" + def __init__(self, input_dim: int, output_dim: int): + super().__init__() + self.linear = nn.Linear(input_dim, output_dim) + self.norm = nn.LayerNorm(output_dim) + self.dropout = nn.Dropout(0.3) + self.neuron = SpikingNeuron() + + def forward(self, x: torch.Tensor) -> torch.Tensor: + z = self.dropout(self.norm(self.linear(x))) + return self.neuron(z) + + + class RelayLayer(Module): + """ + Routes information, potentially performing attention-like + operations. + Simulates a small sequence for MultiheadAttention. + """ + def __init__(self, input_dim: int, output_dim: int): + super().__init__() + self.linear = nn.Linear(input_dim, output_dim) + self.norm = nn.LayerNorm(output_dim) + # MultiheadAttention requires embed_dim to be divisible by + num_heads + self.attn = nn.MultiheadAttention(embed_dim=output_dim, + num_heads=2, batch_first=True) + self.dropout = nn.Dropout(0.3) + self.neuron = SpikingNeuron() + + def forward(self, x: torch.Tensor) -> torch.Tensor: + z = self.dropout(self.norm(self.linear(x))) + + # Create fake time dimension by repeating the input for + attention + seq_len = 4 + z_seq = z.unsqueeze(1).repeat(1, seq_len, 1) + + # Apply attention across the simulated time steps + z_out, _ = self.attn(z_seq, z_seq, z_seq) + + # Pool over the time dimension to get a single representation + routed = z_out.mean(dim=1) + + return self.neuron(routed) + + + + class InterneuronLogic(Module): + """Core processing unit, potentially for decision making or + high-level abstraction.""" + def __init__(self, input_dim: int, output_dim: int): + super().__init__() + self.linear = nn.Linear(input_dim, output_dim) + self.norm = nn.LayerNorm(output_dim) + self.dropout = nn.Dropout(0.3) + self.neuron = AdaptiveLIFNeuron() # Using adaptive neuron here + + def forward(self, x: torch.Tensor) -> torch.Tensor: + z = self.dropout(self.norm(self.linear(x))) + return self.neuron(z) + + + class NeuroendocrineModulator(Module): + """ + Modulates signals, potentially for gain control or emotional + states. + Applies a sigmoid-based gain to its input. + """ + def __init__(self, input_dim: int, output_dim: int): + super().__init__() + self.linear = nn.Linear(input_dim, output_dim) + self.norm = nn.LayerNorm(output_dim) + self.dropout = nn.Dropout(0.3) + self.gain_control = nn.Linear(output_dim, output_dim) # Maps + output to gain factor + + def forward(self, x: torch.Tensor) -> torch.Tensor: + z = self.dropout(self.norm(self.linear(x))) + gain = torch.sigmoid(self.gain_control(z)) # Gain factor + between 0 and 1 + return z * gain + + + class AutonomicProcessor(Module): + """ + Manages internal states, often involves recurrent processing. + Uses a GRU for sequential processing. + """ + def __init__(self, input_dim: int, output_dim: int): + super().__init__() + self.linear = nn.Linear(input_dim, output_dim) + self.norm = nn.LayerNorm(output_dim) + self.recurrent = nn.GRU(output_dim, output_dim, + batch_first=True) + self.feedback_gain = 0.9 + + + def forward(self, x: torch.Tensor) -> torch.Tensor: + z = self.norm(self.linear(x)) + # GRU expects (batch, seq_len, features) -> seq_len=1 for + single step + h, _ = self.recurrent(z.unsqueeze(1)) + # Squeeze the sequence dimension back out + return self.feedback_gain * h.squeeze(1) + + + class MirrorComparator(Module): + """ + Compares current state to a reference, potentially for self-other + distinction or goal comparison. + Outputs spikes and an optional similarity score. + """ + def __init__(self, input_dim: int, output_dim: int): + super().__init__() + self.linear = nn.Linear(input_dim, output_dim) + self.norm = nn.LayerNorm(output_dim) + self.comparison_layer = nn.CosineSimilarity(dim=1) # Cosine + similarity for comparison + self.neuron = SpikingNeuron() + + def compare(self, a: torch.Tensor, b: torch.Tensor) -> + torch.Tensor: + """Compare pre-spike representations for similarity.""" + a_proj = self.linear(a) + b_proj = self.linear(b) + return self.comparison_layer(a_proj, b_proj) + + def forward(self, x: torch.Tensor, reference: torch.Tensor = None) + -> (torch.Tensor, torch.Tensor): + z = self.norm(self.linear(x)) + spike = self.neuron(z) + similarity = None + if reference is not None: + # Note: The 'reference' input would need to be processed + similarly before comparison + # For simplicity, assuming 'reference' is already in the + correct feature space for comparison_layer + similarity = self.compare(z, reference) # Pass processed + 'z' for comparison + return spike, similarity + + + class PlaceGridMemory(Module): + """ + + Spatial memory system using population codes and LSTM for + sequential memory. + """ + def __init__(self, input_dim: int, output_dim: int): + super().__init__() + self.linear = nn.Linear(input_dim, output_dim) + self.norm = nn.LayerNorm(output_dim) + self.positional_encoder = self.population_code # Reference to + internal method + self.memory_cell = nn.LSTM(output_dim, output_dim, + batch_first=True) + + def population_code(self, x: torch.Tensor, pop_size: int) -> + torch.Tensor: + """Expands a scalar value into a population-coded (one-hot + like) vector.""" + # Calculate a scalar representation (e.g., mean) + x_scalar = x.mean(dim=1) + x_normalized = torch.sigmoid(x_scalar) # Normalize to [0, 1] + # Map normalized value to an index within the population size + idx = (x_normalized * (pop_size - 1)).long().clamp(0, pop_size + - 1) + + encoded = torch.zeros(x.size(0), pop_size, device=x.device) + # Set the corresponding index to 1.0 + encoded.scatter_(1, idx.unsqueeze(1), 1.0) + return encoded + + def forward(self, x: torch.Tensor) -> torch.Tensor: + encoded = self.norm(self.linear(x)) + # Apply population coding based on configured output_dim + pos_encoded = self.positional_encoder(encoded, + pop_size=self.linear.out_features) + # LSTM expects (batch, seq_len, features) -> seq_len=1 for + single step + memory_out, _ = self.memory_cell(pos_encoded.unsqueeze(1)) + # Squeeze the sequence dimension back out + return memory_out.squeeze(1) + + + # --- 4. Task Heads (Task-Specific Output Layers) --- + + class TaskHead(nn.Module): + """Generic task head for different output types.""" + def __init__(self, input_dim: int, output_dim: int, task_type: + str): + super().__init__() + self.task_type = task_type + + self.head = nn.Sequential( + nn.Linear(input_dim, output_dim), + nn.Dropout(0.3) + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + # Regression and Binary will be squeezed to [batch] + # Vision will remain [batch, num_classes] + return self.head(x).squeeze(-1) if self.task_type in + ['binary', 'regression'] else self.head(x) + + + # --- 5. Replay Buffer for Continual Learning --- + + class TaskReplayBuffer: + """ + A replay buffer that stores experiences tagged with their task ID. + Samples are drawn proportionally or randomly across tasks. + """ + def __init__(self, buffer_size: int = 1000, device: str = "cpu"): + self.buffer_size = buffer_size + self.task_buffers = {task_id: deque(maxlen=buffer_size) for + task_id in [Task.REGRESSION, Task.BINARY, Task.VISION]} + self.device = device + + def add(self, task_id: int, states: torch.Tensor, labels: + torch.Tensor): + """Adds experiences (state-label pairs) to the buffer for a + specific task.""" + if task_id not in self.task_buffers: + # Should not happen if self.task_buffers is + pre-initialized with all Task IDs + self.task_buffers[task_id] = + deque(maxlen=self.buffer_size) + + for state, label in zip(states, labels): + # Detach from graph and move to CPU to prevent memory + leaks + self.task_buffers[task_id].append((state.detach().cpu(), + label.detach().cpu())) + + def sample(self, batch_size: int = 32) -> (torch.Tensor, + torch.Tensor, list): + """ + Samples a batch of experiences from the buffer, ensuring mixed + tasks. + Returns: (states, labels, list of corresponding task_ids) + """ + + active_tasks = [tid for tid, buffer in + self.task_buffers.items() if len(buffer) > 0] + if not active_tasks: + return None, None, None + + # Sample at least two distinct tasks if possible + num_tasks_to_sample = min(2, len(active_tasks)) + sampled_tasks_ids = random.sample(active_tasks, + num_tasks_to_sample) + + batch_states, batch_labels, batch_task_ids = [], [], [] + + # Distribute batch_size across sampled tasks + per_task_samples_base = batch_size // num_tasks_to_sample + remainder = batch_size % num_tasks_to_sample + + for i, task_id in enumerate(sampled_tasks_ids): + task_list = list(self.task_buffers[task_id]) # Convert + deque to list for random.sample/choices + if not task_list: + continue + + # Add remainder to first few tasks + k = per_task_samples_base + (1 if i < remainder else 0) + k = min(k, len(task_list)) # Ensure we don't sample more + than available + + if k == 0: continue + + try: + samples = random.sample(task_list, k) + except ValueError: + # If k is larger than task_list length (should be + caught by min(k, len(task_list)) + # but good fallback for robustness) + samples = random.choices(task_list, k=k) + + for state, label in samples: + batch_states.append(state) + batch_labels.append(label) + batch_task_ids.append(task_id) + + if not batch_states: + return None, None, None + + # Stack tensors and return task IDs for individual processing + return ( + torch.stack(batch_states), + + torch.stack(batch_labels), + batch_task_ids # Return as list to preserve individual + task identities + ) + + + # --- 6. Full Agent Model --- + + class ModularBrainAgent(nn.Module): + """ + A comprehensive modular neural agent inspired by brain + architecture, + integrating spiking neurons, attention, and recurrent memory for + multi-task learning. + """ + def __init__(self, neuron_counts: dict = None): + super().__init__() + + # --- Parameterization of Neuron Counts --- + if neuron_counts is None: + # Default neuron counts for each brain region + neuron_counts = { + 'sensory': 4, + 'relay': 12, + 'interneurons': 2, + 'neuroendocrine': 8, + 'autonomic': 10, + 'mirror': 14, + 'place_grid': 16 + } + self.neuron_counts = neuron_counts + + # --- Encoders (Input Modalities) --- + self.encoders = nn.ModuleDict({ + 'regression': SharedEncoder(2, + self.neuron_counts['sensory']), + 'language': GRULanguage(len(vocab), + embedding_dim=self.neuron_counts['sensory'], + hidden_dim=self.neuron_counts['sensory']), + 'vision': + CNNVision(output_features=self.neuron_counts['sensory']) + }) + + # --- Brain Modules with Realistic Neuron Counts --- + self.sensory = SensoryProcessor( + input_dim=self.neuron_counts['sensory'], + output_dim=self.neuron_counts['sensory'] + ) + + self.relay = RelayLayer( + input_dim=self.neuron_counts['sensory'], # Input from + SensoryProcessor + output_dim=self.neuron_counts['relay'] + ) + self.interneurons = InterneuronLogic( + input_dim=self.neuron_counts['relay'], # Input from + RelayLayer + output_dim=self.neuron_counts['interneurons'] + ) + self.neuroendocrine = NeuroendocrineModulator( + input_dim=self.neuron_counts['interneurons'], + output_dim=self.neuron_counts['neuroendocrine'] + ) + self.autonomic = AutonomicProcessor( + input_dim=self.neuron_counts['neuroendocrine'], + output_dim=self.neuron_counts['autonomic'] + ) + self.mirror = MirrorComparator( + input_dim=self.neuron_counts['autonomic'], + output_dim=self.neuron_counts['mirror'] + ) + self.place_grid = PlaceGridMemory( + input_dim=self.neuron_counts['mirror'], + output_dim=self.neuron_counts['place_grid'] + ) + + # --- Interconnection Weights (between brain modules) --- + # Ensure input/output dimensions match the neuron_counts + self.connect_sensory_to_relay = + nn.Linear(self.neuron_counts['sensory'], self.neuron_counts['relay']) + self.connect_relay_to_inter = + nn.Linear(self.neuron_counts['relay'], + self.neuron_counts['interneurons']) + self.connect_inter_to_modulators = + nn.Linear(self.neuron_counts['interneurons'], + self.neuron_counts['neuroendocrine']) + self.connect_modulators_to_auto = + nn.Linear(self.neuron_counts['neuroendocrine'], + self.neuron_counts['autonomic']) + self.connect_auto_to_mirror = + nn.Linear(self.neuron_counts['autonomic'], + self.neuron_counts['mirror']) + self.connect_mirror_to_place = + nn.Linear(self.neuron_counts['mirror'], + self.neuron_counts['place_grid']) + # Recurrent loop from place_grid back to relay + self.connect_place_to_relay = + + nn.Linear(self.neuron_counts['place_grid'], + self.neuron_counts['relay']) + + # --- Optional Feedback Connections --- + self.feedback_relay_to_sensory = + nn.Linear(self.neuron_counts['relay'], self.neuron_counts['sensory']) + self.feedback_inter_to_relay = + nn.Linear(self.neuron_counts['interneurons'], + self.neuron_counts['relay']) + + # --- Task Heads (Outputs for specific tasks) --- + self.task_heads = nn.ModuleDict({ + 'binary': TaskHead(self.neuron_counts['place_grid'], 1, + 'binary'), + 'vision': TaskHead(self.neuron_counts['place_grid'], 10, + 'vision'), # MNIST has 10 classes + 'regression': TaskHead(self.neuron_counts['place_grid'], + 1, 'regression') + }) + + def route_modules(self, x: torch.Tensor) -> tuple[torch.Tensor, + ...]: + """ + Defines the forward pass through the interconnected brain + modules. + Returns intermediate activations for potential monitoring or + later use. + """ + # Forward pass through brain circuits + h_sensory = self.sensory(x) + h_relay = self.connect_sensory_to_relay(h_sensory) + h_relay = self.relay(h_relay) + + h_inter = self.connect_relay_to_inter(h_relay) + h_inter = self.interneurons(h_inter) + + h_modulate = self.connect_inter_to_modulators(h_inter) + h_modulate = self.neuroendocrine(h_modulate) + + h_auto = self.connect_modulators_to_auto(h_modulate) + h_auto = self.autonomic(h_auto) + + h_mirror_result, _ = + self.mirror(self.connect_auto_to_mirror(h_auto)) # Mirror returns + tuple + # if isinstance(mirror_result, tuple): # No longer needed as + forward always returns tuple + # h_mirror, _ = mirror_result + + # else: + # h_mirror = mirror_result + h_mirror = h_mirror_result # Renamed for clarity after tuple + unpack + + h_place = self.connect_mirror_to_place(h_mirror) + h_place = self.place_grid(h_place) + + # Recurrent loop from hippocampus (place_grid) back to + thalamus (relay) + # Note: Added ReLU here to prevent negative values from + feeding back into spiking neurons potentially + h_relay = h_relay + + torch.relu(self.connect_place_to_relay(h_place)) + + # Optional feedback connections + h_sensory = h_sensory + + torch.relu(self.feedback_relay_to_sensory(h_relay)) + h_relay = h_relay + + torch.relu(self.feedback_inter_to_relay(h_inter)) + + return h_relay, h_place, h_mirror, h_auto, h_modulate + + + def encode(self, x: torch.Tensor, task_id: int) -> torch.Tensor: + """Selects and applies the appropriate encoder based on task + ID and input shape.""" + # Check task ID and input shape for specific encoders + if task_id == Task.REGRESSION and x.dim() == 2 and x.size(1) + == 2: + return self.encoders['regression'](x.float()) + elif task_id == Task.BINARY and x.dim() == 2 and x.size(1) == + 10: + return self.encoders['language'](x.long()) + elif task_id == Task.VISION: + if x.dim() == 4: # Already 4D (batch, 1, H, W) + return self.encoders['vision'](x.float()) + elif x.dim() == 2 and x.size(1) == 784: # Flattened 2D + (batch, 784) + return self.encoders['vision'](x.view(x.size(0), 1, + 28, 28)) + else: + raise ValueError(f"Unexpected vision input shape: + {x.shape}") + else: + # Fallback for unexpected task_id/shape combination, or if + x is already "encoded" + # In a real system, you might want a specific error or + + default encoder here. + # For now, assumes x is already in the 'sensory' input + dimension if it falls through. + if x.size(-1) != self.neuron_counts['sensory']: + raise ValueError(f"Input {x.shape} does not match + sensory input dim {self.neuron_counts['sensory']} and no specific + encoder found for task {task_id}.") + return x.float() + + + def forward(self, x: torch.Tensor, task_id: int) -> torch.Tensor: + """ + Main forward pass of the Modular Brain Agent. + Encodes input, routes through brain modules, and selects task + head. + """ + if x.dim() == 1: + x = x.unsqueeze(0) # Ensure batch dimension + + encoded = self.encode(x, task_id) + # Route through brain modules, focusing on the outputs used + for task heads + _, h_place, _, _, _ = self.route_modules(encoded) # Only + h_place is used for task heads + + # Map task IDs to the appropriate task head name + head_name = { + Task.REGRESSION: 'regression', + Task.BINARY: 'binary', + Task.VISION: 'vision' + }.get(task_id, 'regression') # Default to regression head if + task_id is unexpected + + out = self.task_heads[head_name](h_place) # All tasks use the + PlaceGridMemory output + return out + + + # --- 7. Data Loaders (Synthetic and Real) --- + + def get_mnist_loader() -> (DataLoader, int): + """Returns MNIST DataLoader and its Task ID.""" + transform = transforms.Compose([ + transforms.ToTensor(), + transforms.Normalize((0.5,), (0.5,)) + ]) + dataset = datasets.MNIST(root='./data', train=True, download=True, + transform=transform) + + return DataLoader(dataset, batch_size=4, shuffle=True), + Task.VISION + + + def get_imdb_loader() -> (DataLoader, int): + """Returns a synthetic IMDB-like DataLoader and its Task ID.""" + texts = [] + labels = [] + + for i in range(1000): # Create 1000 synthetic samples + sentence = random.choice(sample_sentences) + tokens = [vocab.get(word, vocab[""]) for word in + simple_tokenize(sentence)] + + # Pad or truncate tokens to a fixed length + while len(tokens) < 10: + tokens.append(vocab[""]) + if len(tokens) > 10: + tokens = tokens[:10] + + texts.append(tokens) + labels.append(i % 2) # Binary sentiment: 0 or 1 + + X = torch.tensor(texts).long() + Y = torch.tensor(labels).float() + dataset = TensorDataset(X, Y) + return DataLoader(dataset, batch_size=4, shuffle=True), + Task.BINARY + + + def get_regression_loader() -> (DataLoader, int): + """Returns a synthetic regression DataLoader and its Task ID.""" + X = torch.randn(1000, 2) # 1000 samples, 2 features + Y = ((X[:, 0] + X[:, 1]) / 2).float() # Simple linear relationship + dataset = TensorDataset(X, Y) + return DataLoader(dataset, batch_size=4, shuffle=True), + Task.REGRESSION + + + # --- 8. Helper Function: Shape Inspector --- + + def inspect_shapes(agent: ModularBrainAgent): + """ + Prints the shapes of tensors at various points in the agent's + forward pass + to help verify architectural correctness. + """ + print("\n=== SHAPE INSPECTOR ===") + + # Use CPU for inspection to avoid CUDA memory issues during + debugging + agent.cpu() + agent.eval() # Set to eval mode for consistent behavior (e.g., + dropout off) + + try: + # Test regression path + sample_input_reg = torch.randn(4, 2) + encoded_reg = agent.encode(sample_input_reg, Task.REGRESSION) + print(f"Encoded (Regression): {encoded_reg.shape}") + h_relay_reg, h_place_reg, h_mirror_reg, h_auto_reg, + h_modulate_reg = agent.route_modules(encoded_reg) + out_reg = agent.task_heads['regression'](h_place_reg) + print(f"Regression Path: + Sensory({agent.sensory(encoded_reg).shape}) -> + Relay({h_relay_reg.shape}) -> Place({h_place_reg.shape}) -> + Output({out_reg.shape})") + + # Test language path + # Using fixed input_dim for language to map to sensory input + dim + sample_input_lang = torch.randint(0, len(vocab), (4, 10)) + encoded_lang = agent.encode(sample_input_lang, Task.BINARY) # + Use BINARY task for language + print(f"Encoded (Language): {encoded_lang.shape}") + h_relay_lang, h_place_lang, h_mirror_lang, h_auto_lang, + h_modulate_lang = agent.route_modules(encoded_lang) + out_lang = agent.task_heads['binary'](h_place_lang) + print(f"Language Path: + Sensory({agent.sensory(encoded_lang).shape}) -> + Relay({h_relay_lang.shape}) -> Place({h_place_lang.shape}) -> + Output({out_lang.shape})") + + # Test vision path + sample_input_vis = torch.randn(4, 1, 28, 28) + encoded_vis = agent.encode(sample_input_vis, Task.VISION) + print(f"Encoded (Vision): {encoded_vis.shape}") + h_relay_vis, h_place_vis, h_mirror_vis, h_auto_vis, + h_modulate_vis = agent.route_modules(encoded_vis) + out_vis = agent.task_heads['vision'](h_place_vis) + print(f"Vision Path: + Sensory({agent.sensory(encoded_vis).shape}) -> + Relay({h_relay_vis.shape}) -> Place({h_place_vis.shape}) -> + Output({out_vis.shape})") + + except Exception as e: + print(f"Shape inspection failed: {e}") + + # Re-raise to ensure the error is not ignored in main + execution + raise e + finally: + agent.train() # Set back to train mode + print("=========================\n") + + + # --- 9. Training Function --- + + def train(agent: ModularBrainAgent, episodes: int = 14400, + buffer_size: int = 1000, replay_freq: int = 5): + """ + Trains the ModularBrainAgent using a curriculum learning strategy + and experience replay. + """ + device = torch.device("cuda" if torch.cuda.is_available() else + "cpu") + agent.to(device) + print(f"Using device: {device}") + + print("🧠 Running shape inspector...") + inspect_shapes(agent) + + replay_buffer = TaskReplayBuffer(buffer_size=buffer_size, + device=device) + + optimizer = torch.optim.AdamW(agent.parameters(), lr=0.001, + weight_decay=1e-5) + scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, + patience=100, factor=0.5, verbose=True) + + # Initialize weights + def init_weights(m): + if isinstance(m, nn.Linear): + torch.nn.init.xavier_uniform_(m.weight) + if m.bias is not None: + torch.nn.init.zeros_(m.bias) + elif isinstance(m, nn.Conv2d): + torch.nn.init.kaiming_uniform_(m.weight, + nonlinearity='relu') + if m.bias is not None: + torch.nn.init.zeros_(m.bias) + agent.apply(init_weights) + + # Setup TensorBoard writer + log_dir = "runs/modular_brain_agent" + os.makedirs(log_dir, exist_ok=True) + + writer = SummaryWriter(log_dir=log_dir) + print(f"TensorBoard logs are being saved to: {log_dir}") + + + best_loss = float('inf') + no_improvement = 0 + save_path = "best_model.pth" + task_history = {i: [] for i in [Task.REGRESSION, Task.BINARY, + Task.VISION]} + curriculum_stages = { + 0: [Task.REGRESSION], + 300: [Task.REGRESSION, Task.BINARY], # Start binary after 300 + episodes + 600: [Task.REGRESSION, Task.BINARY, Task.VISION] # Start + vision after 600 episodes + } + + raw_loaders = { + Task.REGRESSION: get_regression_loader(), + Task.BINARY: get_imdb_loader(), + Task.VISION: get_mnist_loader() + } + + loaders = {k: v[0] for k, v in raw_loaders.items()} + # task_ids are fixed as keys for loaders in this setup + + # Initialize persistent iterators for each task's DataLoader + loaders_iter = { + Task.REGRESSION: iter(loaders[Task.REGRESSION]), + Task.BINARY: iter(loaders[Task.BINARY]), + Task.VISION: iter(loaders[Task.VISION]) + } + + # Helper for loss computation + def compute_loss(out: torch.Tensor, Y: torch.Tensor, task_id: int) + -> torch.Tensor: + """Computes loss based on task type, handling shape + alignments.""" + if task_id == Task.VISION: + Y = Y.long() # Target labels for CrossEntropy must be long + if out.dim() != 2 or Y.dim() != 1: + raise ValueError(f"Vision task expects out [batch, + num_classes], Y [batch]. Got {out.shape} vs {Y.shape}") + return F.cross_entropy(out, Y) + + elif task_id == Task.BINARY: + # Ensure output and target have matching shapes for + BCEWithLogitsLoss + + if out.shape != Y.shape: + if out.numel() == Y.numel(): + out = out.view_as(Y) # Reshape output to match + target if num elements are same + else: + raise ValueError(f"Binary task shape mismatch: out + {out.shape} vs Y {Y.shape}") + return F.binary_cross_entropy_with_logits(out, Y.float()) + + else: # REGRESSION + if out.shape != Y.shape: + if out.numel() == Y.numel(): + out = out.view_as(Y) # Reshape output to match + target if num elements are same + else: + raise ValueError(f"Regression task shape mismatch: + out {out.shape} vs Y {Y.shape}") + return F.smooth_l1_loss(out, Y.float()) + + loss_weights = {Task.REGRESSION: 1.5, Task.BINARY: 1.2, + Task.VISION: 1.2} + global_step = 0 # For TensorBoard logging + + for ep in range(episodes): + agent.train() # Ensure model is in training mode + + # Determine current curriculum stage + current_stage_episodes = 0 + for stage_start_ep, tasks in + sorted(curriculum_stages.items()): + if ep >= stage_start_ep: + current_stage_episodes = stage_start_ep + current_tasks = curriculum_stages[current_stage_episodes] + + # Randomly select a task from the active curriculum stage + task_id = np.random.choice(current_tasks) + + # Fetch data using persistent iterator + try: + X, Y = next(loaders_iter[task_id]) + except StopIteration: + # Reset iterator when exhausted for that task + loaders_iter[task_id] = iter(loaders[task_id]) + X, Y = next(loaders_iter[task_id]) + + X, Y = X.to(device), Y.to(device) + + # --- Primary Training Step --- + + out = agent(X, task_id) + loss = compute_loss(out, Y, task_id) * + loss_weights.get(task_id, 1.0) + + optimizer.zero_grad() + loss.backward() + torch.nn.utils.clip_grad_norm_(agent.parameters(), + max_norm=1.0) # Clip gradients + optimizer.step() + + # Store current experience in replay buffer + replay_buffer.add(task_id, X, Y) + + # --- Accuracy Metric Calculation --- + acc = 0.0 + with torch.no_grad(): # Don't track gradients for accuracy + calculation + if task_id == Task.VISION: + pred = out.argmax(dim=1) + acc = (pred == Y).float().mean().item() + elif task_id == Task.BINARY: + pred = (torch.sigmoid(out) > 0.5).float() + acc = (pred == Y).float().mean().item() + else: # REGRESSION + acc = ((out - Y).abs() < 0.2).float().mean().item() # + within 0.2 error margin + + task_history[task_id].append((loss.item(), acc)) + + # --- Replay Phase (Fixed Logic) --- + if ep % replay_freq == 0: + replay_X_mixed, replay_Y_mixed, replay_task_ids_list = + replay_buffer.sample(batch_size=4) + + if replay_X_mixed is not None: + replay_X_mixed = replay_X_mixed.to(device) + replay_Y_mixed = replay_Y_mixed.to(device) + + unique_replay_tasks = list(set(replay_task_ids_list)) + total_replay_loss = 0.0 + num_replay_samples = 0 + + # Iterate through each unique task in the sampled + batch + for current_replay_task_id in unique_replay_tasks: + # Filter samples belonging to this specific task + indices_for_this_task = [i for i, tid in + enumerate(replay_task_ids_list) if tid == current_replay_task_id] + + + if not indices_for_this_task: + continue + + task_replay_X = + replay_X_mixed[indices_for_this_task] + task_replay_Y = + replay_Y_mixed[indices_for_this_task] + + # Process these samples using the correct encoder + and task head + replay_out_for_task = agent(task_replay_X, + current_replay_task_id) + + # Compute loss for this task's samples + replay_loss_for_task = + compute_loss(replay_out_for_task, task_replay_Y, + current_replay_task_id) + total_replay_loss += replay_loss_for_task * + len(indices_for_this_task) # Weighted sum + num_replay_samples += len(indices_for_this_task) + + if num_replay_samples > 0: + total_replay_loss /= num_replay_samples # Average + over all replayed samples + optimizer.zero_grad() + total_replay_loss.backward() + torch.nn.utils.clip_grad_norm_(agent.parameters(), + max_norm=1.0) + optimizer.step() + # Log replay loss + writer.add_scalar('Loss/Replay_Loss', + total_replay_loss.item(), global_step) + + + # --- Logging and Early Stopping --- + scheduler.step(loss.item()) # Step scheduler based on current + task loss + + # TensorBoard Logging + writer.add_scalar(f'Loss/Current_Task_{task_id}', loss.item(), + global_step) + writer.add_scalar(f'Accuracy/Current_Task_{task_id}', acc, + global_step) + writer.add_scalar('LearningRate', + optimizer.param_groups[0]['lr'], global_step) + + global_step += 1 + + + # Check for best model and early stopping + if loss.item() < best_loss: + best_loss = loss.item() + no_improvement = 0 + torch.save(agent.state_dict(), save_path) + # print(f"New best model saved at episode {ep} with loss: + {best_loss:.4f}") + else: + no_improvement += 1 + + # Print summary periodically + if ep % 200 == 0: + print(f"\n--- Episode {ep} (Curriculum Stage: + {current_stage_episodes}) ---") + for t_id in sorted(task_history.keys()): + if task_history[t_id]: + # Get recent history, default to empty if not + enough entries + recent_losses = [x[0] for x in + task_history[t_id][-200:]] + recent_accs = [x[1] for x in + task_history[t_id][-200:]] + + avg_loss = np.mean(recent_losses) if recent_losses + else 0.0 + avg_acc = np.mean(recent_accs) if recent_accs else + 0.0 + + # Map task ID to a readable name + task_name = {Task.REGRESSION: "Regression", + Task.BINARY: "Binary", Task.VISION: "Vision"}.get(t_id, + f"Unknown_{t_id}") + print(f"Task {task_name} | Avg Loss: + {avg_loss:.3f} | Avg Acc: {avg_acc:.2f} ({len(recent_losses)} + samples)") + print(f"Overall Current Loss: {loss.item():.4f} | Best + Loss: {best_loss:.4f} | No Improvement: {no_improvement}") + print("--------------------\n") + + if no_improvement >= 1000: # Early stopping threshold + print(f"Early stopping at episode {ep} due to no + improvement for {no_improvement} steps.") + break + + writer.close() + print("āœ… Training finished.") + return best_loss + + + + # --- 10. Main Execution Block --- + + if __name__ == "__main__": + print("šŸš€ Initializing Modular Brain Agent...") + + # Optional: Define custom neuron counts here + # custom_neuron_config = { + # 'sensory': 8, + # 'relay': 24, + # 'interneurons': 4, + # 'neuroendocrine': 16, + # 'autonomic': 20, + # 'mirror': 28, + # 'place_grid': 32 + # } + # agent = ModularBrainAgent(neuron_counts=custom_neuron_config) + + agent = ModularBrainAgent() # Using default neuron counts for now + + # Print model summary + total_params = sum(p.numel() for p in agent.parameters()) + trainable_params = sum(p.numel() for p in agent.parameters() if + p.requires_grad) + + print(f"šŸ“Š Model Summary:") + print(f" Total parameters: {total_params:,}") + print(f" Trainable parameters: {trainable_params:,}") + print(f" Model size: ~{total_params * 4 / (1024**2):.2f} MB + (approx. float32)") + + # Test forward pass on each task to verify shapes + print("\nšŸ” Testing forward passes...") + try: + # Test regression task + test_reg = torch.randn(2, 2) + out_reg = agent(test_reg, Task.REGRESSION) + print(f"āœ… Regression test passed: {test_reg.shape} -> + {out_reg.shape}") + + # Test binary classification task (language task) + test_bin = torch.randint(0, len(vocab), (2, 10)) + out_bin = agent(test_bin, Task.BINARY) + print(f"āœ… Binary classification test passed: {test_bin.shape} + -> {out_bin.shape}") + + # Test vision task + + test_vis = torch.randn(2, 1, 28, 28) + out_vis = agent(test_vis, Task.VISION) + print(f"āœ… Vision test passed: {test_vis.shape} -> + {out_vis.shape}") + + except Exception as e: + print(f"āŒ Forward pass test failed: {e}") + # If forward pass fails, there's a fundamental issue, so exit + import sys + sys.exit(1) + + print("\nšŸŽÆ Starting training...") + + # Train the agent + try: + final_best_loss = train( + agent=agent, + episodes=14400, # Total training episodes + buffer_size=1000, # Replay buffer size + replay_freq=5 # Replay every N episodes + ) + + print(f"\nšŸŽ‰ Training completed successfully!") + print(f"šŸ“ˆ Best loss achieved during training: + {final_best_loss:.4f}") + print(f"šŸ’¾ Best model saved to: best_model.pth") + + except KeyboardInterrupt: + print("\nā¹ Training interrupted by user.") + except Exception as e: + print(f"\nāŒ Training failed with error: {e}") + raise e + + print("\n🧠 Modular Brain Agent execution completed!") \ No newline at end of file