# ModularBrainAgent — CLEANED_modular_brain_agent.py
# NOTE: the original paste carried Hugging Face web-page residue here
# ("Almusawee's picture / Upload CLEANED_modular_brain_agent (1).py /
#  55bc86e verified / raw / history blame / 304 kB"); it is preserved as a
# comment because it is not Python source.
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset
from torchvision import datasets, transforms
from torch.utils.tensorboard import SummaryWriter
import numpy as np
import random
from collections import deque
import os # For creating log directory
# --- 0. Global Setup and Vocabulary ---
# Simulated language data
# Five toy sentences: the whole corpus used to build the vocabulary and to
# synthesize the binary "sentiment" dataset below.
sample_sentences = [
"I love AI",
"Deep learning is fun",
"Spiking neurons are cool",
"Brain inspired models rock",
"Replay buffer helps learning"
]
def simple_tokenize(text: str) -> list:
    """Lowercase, whitespace-split, and truncate to the first 10 tokens."""
    return text.lower().split()[:10]
# Build vocab manually
# Vocabulary built manually from the sample corpus; id 0 is reserved for
# unknown/padding tokens.
vocab = {"<unk>": 0}
word_counter = 1
for sentence in sample_sentences:
    for word in simple_tokenize(sentence):
        if word not in vocab:
            vocab[word] = word_counter
            word_counter += 1
# Task IDs (using an Enum-like class for clarity)
class Task:
    """Enum-like task identifiers used to route encoders, heads, and losses."""
    REGRESSION = 3
    BINARY = 4
    VISION = 5
# --- 1. Base Classes and Spiking Neuron Implementations ---
class Module(nn.Module):
    """Base class for brain-inspired modules; subclasses must implement forward."""
    def __init__(self):
        super().__init__()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        raise NotImplementedError
class SurrogateLIF(torch.autograd.Function):
    """
    Surrogate gradient function for LIF neurons.

    Forward emits a hard spike (Heaviside step of the input); backward
    substitutes a box-shaped surrogate derivative so gradients can flow
    through the non-differentiable spiking function.
    """
    @staticmethod
    def forward(ctx, input: torch.Tensor) -> torch.Tensor:
        ctx.save_for_backward(input)
        return (input > 0).float()

    @staticmethod
    def backward(ctx, grad_output: torch.Tensor) -> torch.Tensor:
        input, = ctx.saved_tensors
        # Box surrogate derivative: pass the upstream gradient through where
        # the pre-activation is close to zero (|input| < 1) and block it
        # elsewhere.  (The original overwrote the gradient with a constant
        # 1.0 inside the window and leaked the raw upstream gradient outside
        # it, which discards the chain rule.)
        grad_input = grad_output.clone()
        grad_input[input.abs() >= 1] = 0.0
        return grad_input
class SpikingNeuron(Module):
    """
    Leaky Integrate-and-Fire (LIF) neuron model.

    Maintains a decaying membrane potential across forward calls and
    hard-resets it to zero wherever a spike is emitted.
    """
    def __init__(self, threshold: float = 1.0, decay: float = 0.95):
        super().__init__()
        self.threshold = threshold
        self.decay = decay
        # Membrane potential buffer; lazily (re)initialized to the input shape.
        self.register_buffer('mem', None)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # (Re)initialize the membrane potential if the input shape changed.
        if self.mem is None or self.mem.shape != x.shape:
            self.mem = torch.zeros_like(x)
        # Detach state carried over from earlier steps so backward never
        # walks into an already-freed graph from a previous batch.
        self.mem = self.mem.detach()
        # Leaky integration of the input current.
        self.mem = self.decay * self.mem + x
        # Spike via the surrogate function applied to (mem - threshold) so
        # gradients reach the membrane potential.  (The original thresholded
        # first and applied the surrogate to the detached 0/1 spike tensor,
        # which blocked all gradient flow through the neuron.)
        spike = SurrogateLIF.apply(self.mem - self.threshold)
        # Hard reset of the membrane wherever the neuron fired.
        self.mem = torch.where(spike.bool(), torch.zeros_like(self.mem), self.mem)
        return spike
class AdaptiveLIFNeuron(SpikingNeuron):
    """
    LIF neuron whose threshold adapts toward the recent mean firing rate
    via an exponential moving average.
    """
    def __init__(self, threshold: float = 1.0, decay: float = 0.95):
        super().__init__(threshold=threshold, decay=decay)
        # Adaptive threshold state (scalar buffer), updated outside autograd.
        self.register_buffer('threshold_state', torch.tensor(threshold))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if self.mem is None or self.mem.shape != x.shape:
            self.mem = torch.zeros_like(x)
        # Detach carried-over state (see SpikingNeuron.forward).
        self.mem = self.mem.detach()
        self.mem = self.decay * self.mem + x
        # Surrogate spike against the *adaptive* threshold so gradients flow.
        spike = SurrogateLIF.apply(self.mem - self.threshold_state)
        self.mem = torch.where(spike.bool(), torch.zeros_like(self.mem), self.mem)
        # Adapt the threshold outside the autograd graph: slow moving average
        # toward the current scalar spike rate.
        with torch.no_grad():
            spike_rate = spike.mean().item()
            self.threshold_state = self.threshold_state * 0.99 + spike_rate * 0.01
        return spike
# --- 2. Encoder Modules (Task-Specific Input Processing) ---
class SharedEncoder(nn.Module):
    """Generic encoder for 2-D numerical data: linear -> norm -> ReLU -> dropout."""
    def __init__(self, input_size: int, hidden_size: int = 4):
        super().__init__()
        self.encoder = nn.Sequential(
            nn.Linear(input_size, hidden_size),
            nn.LayerNorm(hidden_size),
            nn.ReLU(),
            nn.Dropout(0.3),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        assert x.dim() == 2, f"Expected 2D input for SharedEncoder, got {x.shape}"
        return self.encoder(x.float())
class CNNVision(nn.Module):
    """CNN encoder for single-channel images (e.g. 28x28 MNIST digits)."""
    def __init__(self, output_features: int = 4):
        super().__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(1, 4, kernel_size=5),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Conv2d(4, 8, kernel_size=5),
            nn.ReLU(),
            nn.AdaptiveAvgPool2d((1, 1)),   # -> [batch, 8, 1, 1]
            nn.Flatten(),                   # -> [batch, 8]
            nn.Linear(8, output_features),  # map to the requested feature size
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.conv(x)
class GRULanguage(nn.Module):
    """GRU encoder for token-id sequences; returns the last hidden state."""
    def __init__(self, vocab_size: int, embedding_dim: int = 4,
                 hidden_dim: int = 4):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, embedding_dim)
        self.gru = nn.GRU(embedding_dim, hidden_dim, batch_first=True)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if x.dim() == 1:
            x = x.unsqueeze(0)  # add a batch dimension for a single sequence
        emb = self.embed(x.long())  # embedding requires integer (long) ids
        out, _ = self.gru(emb)
        return out[:, -1, :]  # last time step as the sequence representation
# --- 3. Brain-Inspired Modules with Specific Neuron Counts ---
class SensoryProcessor(Module):
    """Processes initial sensory input: linear -> norm -> dropout -> spikes."""
    def __init__(self, input_dim: int, output_dim: int):
        super().__init__()
        self.linear = nn.Linear(input_dim, output_dim)
        self.norm = nn.LayerNorm(output_dim)
        self.dropout = nn.Dropout(0.3)
        self.neuron = SpikingNeuron()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        z = self.dropout(self.norm(self.linear(x)))
        return self.neuron(z)
class RelayLayer(Module):
    """
    Routes information with a self-attention step over a simulated short
    time dimension, then spikes the time-pooled result.
    """
    def __init__(self, input_dim: int, output_dim: int):
        super().__init__()
        self.linear = nn.Linear(input_dim, output_dim)
        self.norm = nn.LayerNorm(output_dim)
        # MultiheadAttention requires embed_dim to be divisible by num_heads.
        self.attn = nn.MultiheadAttention(embed_dim=output_dim,
                                          num_heads=2, batch_first=True)
        self.dropout = nn.Dropout(0.3)
        self.neuron = SpikingNeuron()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        z = self.dropout(self.norm(self.linear(x)))
        # Create a fake time dimension by repeating the feature vector.
        seq_len = 4
        z_seq = z.unsqueeze(1).repeat(1, seq_len, 1)
        # Self-attention across the simulated time steps.
        z_out, _ = self.attn(z_seq, z_seq, z_seq)
        # Mean-pool over time to get a single representation.
        routed = z_out.mean(dim=1)
        return self.neuron(routed)
class InterneuronLogic(Module):
    """Core processing unit: linear -> norm -> dropout -> adaptive spiking."""
    def __init__(self, input_dim: int, output_dim: int):
        super().__init__()
        self.linear = nn.Linear(input_dim, output_dim)
        self.norm = nn.LayerNorm(output_dim)
        self.dropout = nn.Dropout(0.3)
        self.neuron = AdaptiveLIFNeuron()  # adaptive-threshold variant

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        z = self.dropout(self.norm(self.linear(x)))
        return self.neuron(z)
class NeuroendocrineModulator(Module):
    """
    Signal modulator: scales its projected input by a learned sigmoid gain
    in (0, 1), mimicking hormonal gain control.
    """
    def __init__(self, input_dim: int, output_dim: int):
        super().__init__()
        self.linear = nn.Linear(input_dim, output_dim)
        self.norm = nn.LayerNorm(output_dim)
        self.dropout = nn.Dropout(0.3)
        self.gain_control = nn.Linear(output_dim, output_dim)  # gain-factor head

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        z = self.dropout(self.norm(self.linear(x)))
        gain = torch.sigmoid(self.gain_control(z))  # gain factor in (0, 1)
        return z * gain
class AutonomicProcessor(Module):
    """
    Manages internal state with one GRU step over the projected input,
    scaled by a fixed feedback gain.
    """
    def __init__(self, input_dim: int, output_dim: int):
        super().__init__()
        self.linear = nn.Linear(input_dim, output_dim)
        self.norm = nn.LayerNorm(output_dim)
        self.recurrent = nn.GRU(output_dim, output_dim, batch_first=True)
        self.feedback_gain = 0.9

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        z = self.norm(self.linear(x))
        # GRU expects (batch, seq_len, features); run a length-1 sequence.
        h, _ = self.recurrent(z.unsqueeze(1))
        # Squeeze the sequence dimension back out and apply the gain.
        return self.feedback_gain * h.squeeze(1)
class MirrorComparator(Module):
    """
    Compares the current state to a reference (self/other distinction or
    goal comparison).  Returns spikes plus an optional cosine similarity.
    """
    def __init__(self, input_dim: int, output_dim: int):
        super().__init__()
        self.linear = nn.Linear(input_dim, output_dim)
        self.norm = nn.LayerNorm(output_dim)
        self.comparison_layer = nn.CosineSimilarity(dim=1)
        self.neuron = SpikingNeuron()

    def compare(self, a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
        """Project both inputs and return their cosine similarity."""
        a_proj = self.linear(a)
        b_proj = self.linear(b)
        return self.comparison_layer(a_proj, b_proj)

    def forward(self, x: torch.Tensor, reference: torch.Tensor = None) -> (torch.Tensor, torch.Tensor):
        z = self.norm(self.linear(x))
        spike = self.neuron(z)
        similarity = None
        if reference is not None:
            # NOTE(review): compare() projects its arguments again, while z
            # has already been through linear+norm — confirm the intended
            # comparison space matches on the caller side.
            similarity = self.compare(z, reference)
        return spike, similarity
class PlaceGridMemory(Module):
    """
    Spatial memory system: population-codes a scalar summary of the input
    and feeds it through an LSTM for sequential memory.
    """
    def __init__(self, input_dim: int, output_dim: int):
        super().__init__()
        self.linear = nn.Linear(input_dim, output_dim)
        self.norm = nn.LayerNorm(output_dim)
        self.positional_encoder = self.population_code  # alias to the method below
        self.memory_cell = nn.LSTM(output_dim, output_dim, batch_first=True)

    def population_code(self, x: torch.Tensor, pop_size: int) -> torch.Tensor:
        """Expand each row's mean into a one-hot population-coded vector."""
        x_scalar = x.mean(dim=1)
        x_normalized = torch.sigmoid(x_scalar)  # squash to [0, 1]
        # Map the normalized value to an index within the population.
        idx = (x_normalized * (pop_size - 1)).long().clamp(0, pop_size - 1)
        encoded = torch.zeros(x.size(0), pop_size, device=x.device)
        encoded.scatter_(1, idx.unsqueeze(1), 1.0)
        return encoded

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        encoded = self.norm(self.linear(x))
        # Population coding sized to the configured output dimension.
        pos_encoded = self.positional_encoder(encoded,
                                              pop_size=self.linear.out_features)
        # LSTM expects (batch, seq_len, features); run a length-1 sequence.
        memory_out, _ = self.memory_cell(pos_encoded.unsqueeze(1))
        return memory_out.squeeze(1)
# --- 4. Task Heads (Task-Specific Output Layers) ---
class TaskHead(nn.Module):
    """Generic task head: linear + dropout, squeezed for scalar-output tasks."""
    def __init__(self, input_dim: int, output_dim: int, task_type: str):
        super().__init__()
        self.task_type = task_type
        self.head = nn.Sequential(
            nn.Linear(input_dim, output_dim),
            nn.Dropout(0.3)
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Regression and binary outputs are squeezed to [batch];
        # vision keeps [batch, num_classes].
        out = self.head(x)
        if self.task_type in ['binary', 'regression']:
            return out.squeeze(-1)
        return out
# --- 5. Replay Buffer for Continual Learning ---
class TaskReplayBuffer:
    """
    Replay buffer storing (state, label) pairs tagged by task id.

    Samples are drawn across up to two tasks at once.  Because different
    tasks store differently-shaped states (e.g. [2] vs [1, 28, 28]), a
    heterogeneous draw is reduced to a single task before stacking —
    torch.stack would otherwise raise on mismatched shapes.
    """
    def __init__(self, buffer_size: int = 1000, device: str = "cpu"):
        self.buffer_size = buffer_size
        self.task_buffers = {task_id: deque(maxlen=buffer_size)
                             for task_id in [Task.REGRESSION, Task.BINARY, Task.VISION]}
        self.device = device

    def add(self, task_id: int, states: torch.Tensor, labels: torch.Tensor):
        """Append (state, label) pairs for `task_id`, detached on CPU."""
        if task_id not in self.task_buffers:
            # Defensive: all known task ids are pre-initialized in __init__.
            self.task_buffers[task_id] = deque(maxlen=self.buffer_size)
        for state, label in zip(states, labels):
            # Detach and move to CPU so stored tensors do not pin the
            # autograd graph or GPU memory.
            self.task_buffers[task_id].append((state.detach().cpu(),
                                               label.detach().cpu()))

    def sample(self, batch_size: int = 32) -> (torch.Tensor, torch.Tensor, list):
        """
        Sample a batch of experiences, preferring a mix of two tasks.

        Returns:
            (states, labels, task_ids) — stacked tensors plus a per-sample
            list of task ids, or (None, None, None) when the buffer is empty.
        """
        active_tasks = [tid for tid, buffer in self.task_buffers.items()
                        if len(buffer) > 0]
        if not active_tasks:
            return None, None, None
        # Sample from at most two distinct tasks.
        num_tasks_to_sample = min(2, len(active_tasks))
        sampled_tasks_ids = random.sample(active_tasks, num_tasks_to_sample)
        batch_states, batch_labels, batch_task_ids = [], [], []
        # Distribute batch_size across the sampled tasks.
        per_task_samples_base = batch_size // num_tasks_to_sample
        remainder = batch_size % num_tasks_to_sample
        for i, task_id in enumerate(sampled_tasks_ids):
            task_list = list(self.task_buffers[task_id])
            if not task_list:
                continue
            # Earlier tasks absorb the remainder; never over-sample.
            k = per_task_samples_base + (1 if i < remainder else 0)
            k = min(k, len(task_list))
            if k == 0:
                continue
            for state, label in random.sample(task_list, k):
                batch_states.append(state)
                batch_labels.append(label)
                batch_task_ids.append(task_id)
        if not batch_states:
            return None, None, None
        # torch.stack needs homogeneous shapes; if the draw mixed tasks with
        # different state shapes, keep only the first sampled task's items.
        ref_shape = batch_states[0].shape
        if any(s.shape != ref_shape for s in batch_states):
            keep_task = batch_task_ids[0]
            kept = [(s, l, t) for s, l, t in
                    zip(batch_states, batch_labels, batch_task_ids)
                    if t == keep_task]
            batch_states, batch_labels, batch_task_ids = (
                [s for s, _, _ in kept],
                [l for _, l, _ in kept],
                [t for _, _, t in kept],
            )
        return (
            torch.stack(batch_states),
            torch.stack(batch_labels),
            batch_task_ids  # list preserves per-sample task identity
        )
# --- 6. Full Agent Model ---
class ModularBrainAgent(nn.Module):
    """
    A modular neural agent inspired by brain architecture, integrating
    spiking neurons, attention-based relay, and recurrent memory for
    multi-task learning (regression, binary language, vision).
    """
    def __init__(self, neuron_counts: dict = None):
        super().__init__()
        # --- Parameterization of per-region neuron counts ---
        if neuron_counts is None:
            neuron_counts = {
                'sensory': 4,
                'relay': 12,
                'interneurons': 2,
                'neuroendocrine': 8,
                'autonomic': 10,
                'mirror': 14,
                'place_grid': 16
            }
        self.neuron_counts = neuron_counts
        nc = self.neuron_counts  # local shorthand
        # --- Encoders (input modalities) ---
        self.encoders = nn.ModuleDict({
            'regression': SharedEncoder(2, nc['sensory']),
            'language': GRULanguage(len(vocab),
                                    embedding_dim=nc['sensory'],
                                    hidden_dim=nc['sensory']),
            'vision': CNNVision(output_features=nc['sensory'])
        })
        # --- Brain modules chained sensory -> ... -> place_grid ---
        self.sensory = SensoryProcessor(input_dim=nc['sensory'],
                                        output_dim=nc['sensory'])
        self.relay = RelayLayer(input_dim=nc['sensory'],
                                output_dim=nc['relay'])
        self.interneurons = InterneuronLogic(input_dim=nc['relay'],
                                             output_dim=nc['interneurons'])
        self.neuroendocrine = NeuroendocrineModulator(
            input_dim=nc['interneurons'],
            output_dim=nc['neuroendocrine'])
        self.autonomic = AutonomicProcessor(input_dim=nc['neuroendocrine'],
                                            output_dim=nc['autonomic'])
        self.mirror = MirrorComparator(input_dim=nc['autonomic'],
                                       output_dim=nc['mirror'])
        self.place_grid = PlaceGridMemory(input_dim=nc['mirror'],
                                          output_dim=nc['place_grid'])
        # --- Interconnection weights between brain modules ---
        self.connect_sensory_to_relay = nn.Linear(nc['sensory'], nc['relay'])
        self.connect_relay_to_inter = nn.Linear(nc['relay'], nc['interneurons'])
        self.connect_inter_to_modulators = nn.Linear(nc['interneurons'],
                                                     nc['neuroendocrine'])
        self.connect_modulators_to_auto = nn.Linear(nc['neuroendocrine'],
                                                    nc['autonomic'])
        self.connect_auto_to_mirror = nn.Linear(nc['autonomic'], nc['mirror'])
        self.connect_mirror_to_place = nn.Linear(nc['mirror'], nc['place_grid'])
        # Recurrent loop from place_grid back to relay.
        self.connect_place_to_relay = nn.Linear(nc['place_grid'], nc['relay'])
        # --- Optional feedback connections ---
        self.feedback_relay_to_sensory = nn.Linear(nc['relay'], nc['sensory'])
        self.feedback_inter_to_relay = nn.Linear(nc['interneurons'], nc['relay'])
        # --- Task heads (outputs for specific tasks) ---
        self.task_heads = nn.ModuleDict({
            'binary': TaskHead(nc['place_grid'], 1, 'binary'),
            'vision': TaskHead(nc['place_grid'], 10, 'vision'),  # MNIST: 10 classes
            'regression': TaskHead(nc['place_grid'], 1, 'regression')
        })

    def route_modules(self, x: torch.Tensor) -> tuple[torch.Tensor, ...]:
        """
        Forward pass through the interconnected brain modules.

        Returns:
            (h_relay, h_place, h_mirror, h_auto, h_modulate) intermediate
            activations for monitoring or downstream heads.
        """
        h_sensory = self.sensory(x)
        h_relay = self.relay(self.connect_sensory_to_relay(h_sensory))
        h_inter = self.interneurons(self.connect_relay_to_inter(h_relay))
        h_modulate = self.neuroendocrine(self.connect_inter_to_modulators(h_inter))
        h_auto = self.autonomic(self.connect_modulators_to_auto(h_modulate))
        # MirrorComparator returns (spikes, similarity); similarity unused here.
        h_mirror, _ = self.mirror(self.connect_auto_to_mirror(h_auto))
        h_place = self.place_grid(self.connect_mirror_to_place(h_mirror))
        # Recurrent loop place_grid -> relay; ReLU keeps the feedback
        # non-negative before it reaches spiking neurons downstream.
        h_relay = h_relay + torch.relu(self.connect_place_to_relay(h_place))
        # Optional feedback connections.
        h_sensory = h_sensory + torch.relu(self.feedback_relay_to_sensory(h_relay))
        h_relay = h_relay + torch.relu(self.feedback_inter_to_relay(h_inter))
        return h_relay, h_place, h_mirror, h_auto, h_modulate

    def encode(self, x: torch.Tensor, task_id: int) -> torch.Tensor:
        """Select and apply the encoder matching the task id and input shape."""
        if task_id == Task.REGRESSION and x.dim() == 2 and x.size(1) == 2:
            return self.encoders['regression'](x.float())
        elif task_id == Task.BINARY and x.dim() == 2 and x.size(1) == 10:
            return self.encoders['language'](x.long())
        elif task_id == Task.VISION:
            if x.dim() == 4:  # already (batch, 1, H, W)
                return self.encoders['vision'](x.float())
            elif x.dim() == 2 and x.size(1) == 784:  # flattened 28x28
                return self.encoders['vision'](x.view(x.size(0), 1, 28, 28))
            else:
                raise ValueError(f"Unexpected vision input shape: {x.shape}")
        else:
            # Fallback: assume x is already in the sensory feature space.
            if x.size(-1) != self.neuron_counts['sensory']:
                raise ValueError(
                    f"Input {x.shape} does not match sensory input dim "
                    f"{self.neuron_counts['sensory']} and no specific encoder "
                    f"found for task {task_id}.")
            return x.float()

    def forward(self, x: torch.Tensor, task_id: int) -> torch.Tensor:
        """
        Main forward pass: encode input, route through the brain modules,
        and apply the task-appropriate head to the place-grid output.
        """
        if x.dim() == 1:
            x = x.unsqueeze(0)  # ensure a batch dimension
        encoded = self.encode(x, task_id)
        # Only the place-grid activation feeds the task heads.
        _, h_place, _, _, _ = self.route_modules(encoded)
        head_name = {
            Task.REGRESSION: 'regression',
            Task.BINARY: 'binary',
            Task.VISION: 'vision'
        }.get(task_id, 'regression')  # default head for unexpected ids
        return self.task_heads[head_name](h_place)
# --- 7. Data Loaders (Synthetic and Real) ---
def get_mnist_loader() -> (DataLoader, int):
    """Return the MNIST DataLoader and its task id (downloads on first use)."""
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5,), (0.5,))
    ])
    dataset = datasets.MNIST(root='./data', train=True, download=True,
                             transform=transform)
    return DataLoader(dataset, batch_size=4, shuffle=True), Task.VISION
def get_imdb_loader() -> (DataLoader, int):
    """Return a synthetic IMDB-like (token ids, binary label) DataLoader."""
    texts = []
    labels = []
    for i in range(1000):  # 1000 synthetic samples
        sentence = random.choice(sample_sentences)
        tokens = [vocab.get(word, vocab["<unk>"])
                  for word in simple_tokenize(sentence)]
        # Pad or truncate to a fixed length of 10 tokens.
        while len(tokens) < 10:
            tokens.append(vocab["<unk>"])
        tokens = tokens[:10]
        texts.append(tokens)
        labels.append(i % 2)  # alternating binary "sentiment": 0 or 1
    X = torch.tensor(texts).long()
    Y = torch.tensor(labels).float()
    dataset = TensorDataset(X, Y)
    return DataLoader(dataset, batch_size=4, shuffle=True), Task.BINARY
def get_regression_loader() -> (DataLoader, int):
    """Return a synthetic regression DataLoader: target = mean of 2 features."""
    X = torch.randn(1000, 2)  # 1000 samples, 2 features
    Y = ((X[:, 0] + X[:, 1]) / 2).float()  # simple linear relationship
    dataset = TensorDataset(X, Y)
    return DataLoader(dataset, batch_size=4, shuffle=True), Task.REGRESSION
# --- 8. Helper Function: Shape Inspector ---
def inspect_shapes(agent: ModularBrainAgent):
    """
    Push one small batch of each modality through the agent and print the
    tensor shapes at key points, to verify architectural correctness.
    Raises on failure so shape bugs are not silently ignored.
    """
    print("\n=== SHAPE INSPECTOR ===")
    agent.cpu()   # inspect on CPU to avoid CUDA issues while debugging
    agent.eval()  # deterministic behavior (dropout off)
    try:
        # Regression path.
        sample_input_reg = torch.randn(4, 2)
        encoded_reg = agent.encode(sample_input_reg, Task.REGRESSION)
        print(f"Encoded (Regression): {encoded_reg.shape}")
        h_relay_reg, h_place_reg, _, _, _ = agent.route_modules(encoded_reg)
        out_reg = agent.task_heads['regression'](h_place_reg)
        print(f"Regression Path: Sensory({agent.sensory(encoded_reg).shape}) -> "
              f"Relay({h_relay_reg.shape}) -> Place({h_place_reg.shape}) -> "
              f"Output({out_reg.shape})")
        # Language path (the binary task uses the language encoder).
        sample_input_lang = torch.randint(0, len(vocab), (4, 10))
        encoded_lang = agent.encode(sample_input_lang, Task.BINARY)
        print(f"Encoded (Language): {encoded_lang.shape}")
        h_relay_lang, h_place_lang, _, _, _ = agent.route_modules(encoded_lang)
        out_lang = agent.task_heads['binary'](h_place_lang)
        print(f"Language Path: Sensory({agent.sensory(encoded_lang).shape}) -> "
              f"Relay({h_relay_lang.shape}) -> Place({h_place_lang.shape}) -> "
              f"Output({out_lang.shape})")
        # Vision path.
        sample_input_vis = torch.randn(4, 1, 28, 28)
        encoded_vis = agent.encode(sample_input_vis, Task.VISION)
        print(f"Encoded (Vision): {encoded_vis.shape}")
        h_relay_vis, h_place_vis, _, _, _ = agent.route_modules(encoded_vis)
        out_vis = agent.task_heads['vision'](h_place_vis)
        print(f"Vision Path: Sensory({agent.sensory(encoded_vis).shape}) -> "
              f"Relay({h_relay_vis.shape}) -> Place({h_place_vis.shape}) -> "
              f"Output({out_vis.shape})")
    except Exception as e:
        print(f"Shape inspection failed: {e}")
        raise  # do not hide architectural errors from the caller
    finally:
        agent.train()  # restore training mode
        print("=========================\n")
# --- 9. Training Function ---
def train(agent: ModularBrainAgent, episodes: int = 14400,
          buffer_size: int = 1000, replay_freq: int = 5):
    """
    Train the ModularBrainAgent with a task curriculum plus experience replay.

    Args:
        agent: the model to train (moved to CUDA when available).
        episodes: total number of single-batch training steps.
        buffer_size: per-task replay buffer capacity.
        replay_freq: run a replay step every `replay_freq` episodes.

    Returns:
        The best (lowest) single-step loss observed during training.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    agent.to(device)
    print(f"Using device: {device}")
    print("🧠 Running shape inspector...")
    inspect_shapes(agent)

    replay_buffer = TaskReplayBuffer(buffer_size=buffer_size, device=device)
    optimizer = torch.optim.AdamW(agent.parameters(), lr=0.001,
                                  weight_decay=1e-5)
    # NOTE: `verbose=` is deprecated/removed in recent torch releases, so it
    # is intentionally not passed here.
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer,
                                                           patience=100,
                                                           factor=0.5)

    def init_weights(m):
        """Xavier init for Linear, Kaiming for Conv2d, zero biases."""
        if isinstance(m, nn.Linear):
            torch.nn.init.xavier_uniform_(m.weight)
            if m.bias is not None:
                torch.nn.init.zeros_(m.bias)
        elif isinstance(m, nn.Conv2d):
            torch.nn.init.kaiming_uniform_(m.weight, nonlinearity='relu')
            if m.bias is not None:
                torch.nn.init.zeros_(m.bias)

    agent.apply(init_weights)

    # TensorBoard setup.
    log_dir = "runs/modular_brain_agent"
    os.makedirs(log_dir, exist_ok=True)
    writer = SummaryWriter(log_dir=log_dir)
    print(f"TensorBoard logs are being saved to: {log_dir}")

    best_loss = float('inf')
    no_improvement = 0
    save_path = "best_model.pth"
    task_history = {i: [] for i in [Task.REGRESSION, Task.BINARY, Task.VISION]}
    # Curriculum: a task set becomes active once its start episode passes.
    curriculum_stages = {
        0: [Task.REGRESSION],
        300: [Task.REGRESSION, Task.BINARY],   # binary joins after 300 episodes
        600: [Task.REGRESSION, Task.BINARY, Task.VISION]  # vision after 600
    }
    raw_loaders = {
        Task.REGRESSION: get_regression_loader(),
        Task.BINARY: get_imdb_loader(),
        Task.VISION: get_mnist_loader()
    }
    loaders = {k: v[0] for k, v in raw_loaders.items()}
    # Persistent per-task iterators, reset individually when exhausted.
    loaders_iter = {tid: iter(loader) for tid, loader in loaders.items()}

    def compute_loss(out: torch.Tensor, Y: torch.Tensor,
                     task_id: int) -> torch.Tensor:
        """Task-appropriate loss with defensive shape alignment."""
        if task_id == Task.VISION:
            Y = Y.long()  # CrossEntropy targets must be long
            if out.dim() != 2 or Y.dim() != 1:
                raise ValueError(f"Vision task expects out [batch, "
                                 f"num_classes], Y [batch]. Got {out.shape} "
                                 f"vs {Y.shape}")
            return F.cross_entropy(out, Y)
        elif task_id == Task.BINARY:
            if out.shape != Y.shape:
                if out.numel() == Y.numel():
                    out = out.view_as(Y)  # align shapes when sizes match
                else:
                    raise ValueError(f"Binary task shape mismatch: out "
                                     f"{out.shape} vs Y {Y.shape}")
            return F.binary_cross_entropy_with_logits(out, Y.float())
        else:  # REGRESSION
            if out.shape != Y.shape:
                if out.numel() == Y.numel():
                    out = out.view_as(Y)
                else:
                    raise ValueError(f"Regression task shape mismatch: out "
                                     f"{out.shape} vs Y {Y.shape}")
            return F.smooth_l1_loss(out, Y.float())

    loss_weights = {Task.REGRESSION: 1.5, Task.BINARY: 1.2, Task.VISION: 1.2}
    global_step = 0  # TensorBoard step counter

    for ep in range(episodes):
        agent.train()
        # Pick the latest curriculum stage whose start episode has passed.
        current_stage_episodes = 0
        for stage_start_ep in sorted(curriculum_stages):
            if ep >= stage_start_ep:
                current_stage_episodes = stage_start_ep
        current_tasks = curriculum_stages[current_stage_episodes]
        # Randomly select a task from the active curriculum stage.
        task_id = int(np.random.choice(current_tasks))
        # Fetch data, resetting the iterator on exhaustion.
        try:
            X, Y = next(loaders_iter[task_id])
        except StopIteration:
            loaders_iter[task_id] = iter(loaders[task_id])
            X, Y = next(loaders_iter[task_id])
        X, Y = X.to(device), Y.to(device)

        # --- Primary training step ---
        out = agent(X, task_id)
        loss = compute_loss(out, Y, task_id) * loss_weights.get(task_id, 1.0)
        optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(agent.parameters(), max_norm=1.0)
        optimizer.step()
        # Store the current experience for later replay.
        replay_buffer.add(task_id, X, Y)

        # --- Accuracy metric ---
        acc = 0.0
        with torch.no_grad():
            if task_id == Task.VISION:
                pred = out.argmax(dim=1)
                acc = (pred == Y).float().mean().item()
            elif task_id == Task.BINARY:
                pred = (torch.sigmoid(out) > 0.5).float()
                acc = (pred == Y).float().mean().item()
            else:  # regression: fraction within 0.2 absolute error
                acc = ((out - Y).abs() < 0.2).float().mean().item()
        task_history[task_id].append((loss.item(), acc))

        # --- Replay phase ---
        if ep % replay_freq == 0:
            replay_X, replay_Y, replay_task_ids = replay_buffer.sample(batch_size=4)
            if replay_X is not None:
                replay_X = replay_X.to(device)
                replay_Y = replay_Y.to(device)
                total_replay_loss = 0.0
                num_replay_samples = 0
                # Process each unique task with its own encoder/head.
                for rep_task_id in set(replay_task_ids):
                    idxs = [i for i, tid in enumerate(replay_task_ids)
                            if tid == rep_task_id]
                    if not idxs:
                        continue
                    rep_out = agent(replay_X[idxs], rep_task_id)
                    rep_loss = compute_loss(rep_out, replay_Y[idxs], rep_task_id)
                    # Weight each task's loss by its sample count.
                    total_replay_loss = total_replay_loss + rep_loss * len(idxs)
                    num_replay_samples += len(idxs)
                if num_replay_samples > 0:
                    total_replay_loss = total_replay_loss / num_replay_samples
                    optimizer.zero_grad()
                    total_replay_loss.backward()
                    torch.nn.utils.clip_grad_norm_(agent.parameters(),
                                                   max_norm=1.0)
                    optimizer.step()
                    writer.add_scalar('Loss/Replay_Loss',
                                      total_replay_loss.item(), global_step)

        # --- Logging, scheduling, early stopping ---
        scheduler.step(loss.item())
        writer.add_scalar(f'Loss/Current_Task_{task_id}', loss.item(),
                          global_step)
        writer.add_scalar(f'Accuracy/Current_Task_{task_id}', acc, global_step)
        writer.add_scalar('LearningRate', optimizer.param_groups[0]['lr'],
                          global_step)
        global_step += 1

        if loss.item() < best_loss:
            best_loss = loss.item()
            no_improvement = 0
            torch.save(agent.state_dict(), save_path)
        else:
            no_improvement += 1

        # Periodic per-task summary.
        if ep % 200 == 0:
            print(f"\n--- Episode {ep} (Curriculum Stage: "
                  f"{current_stage_episodes}) ---")
            for t_id in sorted(task_history.keys()):
                if task_history[t_id]:
                    recent = task_history[t_id][-200:]
                    avg_loss = float(np.mean([x[0] for x in recent]))
                    avg_acc = float(np.mean([x[1] for x in recent]))
                    task_name = {Task.REGRESSION: "Regression",
                                 Task.BINARY: "Binary",
                                 Task.VISION: "Vision"}.get(t_id,
                                                            f"Unknown_{t_id}")
                    print(f"Task {task_name} | Avg Loss: {avg_loss:.3f} | "
                          f"Avg Acc: {avg_acc:.2f} ({len(recent)} samples)")
            print(f"Overall Current Loss: {loss.item():.4f} | Best Loss: "
                  f"{best_loss:.4f} | No Improvement: {no_improvement}")
            print("--------------------\n")

        if no_improvement >= 1000:  # early stopping threshold
            print(f"Early stopping at episode {ep} due to no improvement "
                  f"for {no_improvement} steps.")
            break

    writer.close()
    print("✅ Training finished.")
    return best_loss
# --- 10. Main Execution Block ---
if __name__ == "__main__":
    print("🚀 Initializing Modular Brain Agent...")
    # Optional: pass custom neuron counts, e.g.
    # agent = ModularBrainAgent(neuron_counts={
    #     'sensory': 8, 'relay': 24, 'interneurons': 4, 'neuroendocrine': 16,
    #     'autonomic': 20, 'mirror': 28, 'place_grid': 32})
    agent = ModularBrainAgent()  # default neuron counts

    # Model summary.
    total_params = sum(p.numel() for p in agent.parameters())
    trainable_params = sum(p.numel() for p in agent.parameters()
                           if p.requires_grad)
    print(f"📊 Model Summary:")
    print(f" Total parameters: {total_params:,}")
    print(f" Trainable parameters: {trainable_params:,}")
    print(f" Model size: ~{total_params * 4 / (1024**2):.2f} MB "
          f"(approx. float32)")

    # Smoke-test one forward pass per task before committing to training.
    print("\n🔍 Testing forward passes...")
    try:
        test_reg = torch.randn(2, 2)
        out_reg = agent(test_reg, Task.REGRESSION)
        print(f"✅ Regression test passed: {test_reg.shape} -> {out_reg.shape}")
        test_bin = torch.randint(0, len(vocab), (2, 10))
        out_bin = agent(test_bin, Task.BINARY)
        print(f"✅ Binary classification test passed: {test_bin.shape} -> "
              f"{out_bin.shape}")
        test_vis = torch.randn(2, 1, 28, 28)
        out_vis = agent(test_vis, Task.VISION)
        print(f"✅ Vision test passed: {test_vis.shape} -> {out_vis.shape}")
    except Exception as e:
        # A failing forward pass is fundamental — abort before training.
        print(f"❌ Forward pass test failed: {e}")
        import sys
        sys.exit(1)

    print("\n🎯 Starting training...")
    try:
        final_best_loss = train(
            agent=agent,
            episodes=14400,    # total training episodes
            buffer_size=1000,  # replay buffer size
            replay_freq=5      # replay every N episodes
        )
        print(f"\n🎉 Training completed successfully!")
        print(f"📈 Best loss achieved during training: {final_best_loss:.4f}")
        print(f"💾 Best model saved to: best_model.pth")
    except KeyboardInterrupt:
        print("\n⏹ Training interrupted by user.")
    except Exception as e:
        print(f"\n❌ Training failed with error: {e}")
        raise e
    print("\n🧠 Modular Brain Agent execution completed!")