import os
import random
from collections import deque
from typing import Optional

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset
from torch.utils.tensorboard import SummaryWriter
from torchvision import datasets, transforms
# Tiny corpus used both to build the vocabulary and to generate the synthetic
# language dataset.
sample_sentences = [
    "I love AI",
    "Deep learning is fun",
    "Spiking neurons are cool",
    "Brain inspired models rock",
    "Replay buffer helps learning",
]
def simple_tokenize(text, max_tokens=10):
    """Lower-case ``text``, split on whitespace and keep the leading tokens.

    Args:
        text: Sentence to tokenize.
        max_tokens: Maximum number of tokens returned (defaults to 10, the
            sequence length used by the language task in this file).

    Returns:
        A list of at most ``max_tokens`` lower-cased tokens.
    """
    return text.lower().split()[:max_tokens]
# Vocabulary: maps every distinct token of ``sample_sentences`` to an integer
# id in order of first appearance. Id 0 is reserved for unknown tokens.
vocab = {"<unk>": 0}
word_counter = 1  # next free id
for sentence in sample_sentences:
    for word in simple_tokenize(sentence):
        if word not in vocab:
            vocab[word] = word_counter
            word_counter += 1
class Task:
    """Integer identifiers used to select encoders, heads and replay buffers."""

    REGRESSION = 3
    BINARY = 4
    VISION = 5
class Module(nn.Module):
    """Base class for brain-inspired modules.

    Subclasses must override :meth:`forward`.
    """

    def __init__(self):
        super().__init__()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Abstract forward pass; always raises ``NotImplementedError`` here."""
        raise NotImplementedError
class SurrogateLIF(torch.autograd.Function):
    """Heaviside spike function with a rectangular surrogate gradient.

    Forward emits 1.0 where the input is strictly positive and 0.0 elsewhere.
    Backward lets the upstream gradient through unchanged where ``|input| < 1``
    and blocks it everywhere else, which makes the otherwise
    non-differentiable spike usable with backpropagation.
    """

    @staticmethod
    def forward(ctx, x: torch.Tensor) -> torch.Tensor:
        # Keep the pre-threshold value: backward needs it to build the window.
        ctx.save_for_backward(x)
        return (x > 0).float()

    @staticmethod
    def backward(ctx, grad_output: torch.Tensor) -> torch.Tensor:
        (x,) = ctx.saved_tensors
        # Mask (rather than overwrite) the upstream gradient: inside the
        # window it is passed through, outside it is zeroed.
        window = (x.abs() < 1).to(grad_output.dtype)
        return grad_output * window
class SpikingNeuron(Module):
    """Leaky Integrate-and-Fire (LIF) neuron.

    The membrane potential decays by ``decay`` each call, integrates the
    input, and is reset to zero wherever a spike is emitted. The potential is
    kept between forward calls.

    Args:
        threshold: Potential above which the neuron fires.
        decay: Multiplicative leak applied to the stored potential per call.
    """

    def __init__(self, threshold: float = 1.0, decay: float = 0.95):
        super().__init__()
        self.threshold = threshold
        self.decay = decay
        # Membrane potential carried across forward calls. Registered as a
        # buffer so it follows .to()/.cuda(); non-persistent because its shape
        # depends on the last batch and must not end up in checkpoints.
        self.register_buffer('mem', None, persistent=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Integrate ``x`` and return a 0/1 spike tensor of the same shape."""
        # (Re)initialise the state on first use or when the input layout changes.
        if (self.mem is None or self.mem.shape != x.shape
                or self.mem.device != x.device or self.mem.dtype != x.dtype):
            self.mem = torch.zeros_like(x)

        # Leaky integration.
        mem = self.decay * self.mem + x

        # Threshold through the surrogate function applied to the
        # pre-threshold potential, so gradients can reach ``x``.
        # Note: fires when the potential strictly exceeds the threshold.
        spike = SurrogateLIF.apply(mem - self.threshold)

        # Hard reset where a spike occurred. The stored state is detached so
        # the next call does not backpropagate through this call's graph.
        self.mem = torch.where(spike.bool(), torch.zeros_like(mem), mem).detach()

        return spike
class AdaptiveLIFNeuron(SpikingNeuron):
    """LIF neuron whose threshold adapts to the recent firing rate.

    After every call the threshold is moved towards the mean spike rate of
    that call with an exponential moving average (weights 0.99 / 0.01).

    Args:
        threshold: Initial firing threshold.
        decay: Multiplicative leak applied to the stored potential per call.
    """

    def __init__(self, threshold: float = 1.0, decay: float = 0.95):
        super().__init__(threshold=threshold, decay=decay)
        # Adaptive threshold kept as a buffer: it is state, not a learnable
        # parameter, but should follow the module across devices.
        self.register_buffer('threshold_state', torch.tensor(threshold))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Integrate ``x``, emit spikes and update the adaptive threshold."""
        if (self.mem is None or self.mem.shape != x.shape
                or self.mem.device != x.device or self.mem.dtype != x.dtype):
            self.mem = torch.zeros_like(x)

        mem = self.decay * self.mem + x

        # Surrogate applied to the pre-threshold potential so gradients can
        # reach ``x`` (fires when the potential strictly exceeds the threshold).
        spike = SurrogateLIF.apply(mem - self.threshold_state)

        # Hard reset; detach so the stored state carries no autograd graph.
        self.mem = torch.where(spike.bool(), torch.zeros_like(mem), mem).detach()

        # Threshold adaptation is bookkeeping only and must not be tracked.
        with torch.no_grad():
            spike_rate = spike.mean().to(self.threshold_state.dtype)
            self.threshold_state = (
                self.threshold_state * 0.99 + spike_rate * 0.01
            )

        return spike
class SharedEncoder(nn.Module):
    """Generic encoder for numerical data: Linear -> LayerNorm -> ReLU -> Dropout.

    Args:
        input_size: Number of input features.
        hidden_size: Number of output features.
    """

    def __init__(self, input_size: int, hidden_size: int = 4):
        super().__init__()
        self.encoder = nn.Sequential(
            nn.Linear(input_size, hidden_size),
            nn.LayerNorm(hidden_size),
            nn.ReLU(),
            nn.Dropout(0.3),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Encode a 2D batch.

        Raises:
            ValueError: If ``x`` is not 2-dimensional.
        """
        # Explicit exception instead of ``assert`` so the check survives -O.
        if x.dim() != 2:
            raise ValueError(
                f"Expected 2D input for SharedEncoder, got {x.shape}")
        return self.encoder(x.float())
class CNNVision(nn.Module):
    """CNN encoder for single-channel image data (e.g., MNIST).

    Two 5x5 convolutions followed by global average pooling and a linear
    projection to ``output_features``.
    """

    def __init__(self, output_features: int = 4):
        super().__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(1, 4, kernel_size=5),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Conv2d(4, 8, kernel_size=5),
            nn.ReLU(),
            # Global pooling makes the flattened size independent of the
            # spatial input size: always 8 channels.
            nn.AdaptiveAvgPool2d((1, 1)),
            nn.Flatten(),
            nn.Linear(8, output_features),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Map a (batch, 1, H, W) image batch to (batch, output_features)."""
        return self.conv(x)
class GRULanguage(nn.Module):
    """GRU encoder for token-id sequences.

    Args:
        vocab_size: Number of distinct token ids.
        embedding_dim: Size of the token embeddings.
        hidden_dim: Size of the GRU hidden state (and of the output).
    """

    def __init__(self, vocab_size: int, embedding_dim: int = 4,
                 hidden_dim: int = 4):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, embedding_dim)
        self.gru = nn.GRU(embedding_dim, hidden_dim, batch_first=True)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the GRU output at the last time step, shape (batch, hidden_dim)."""
        # A single unbatched sequence is treated as a batch of one.
        if x.dim() == 1:
            x = x.unsqueeze(0)
        emb = self.embed(x.long())  # (batch, seq, embedding_dim)
        out, _ = self.gru(emb)
        return out[:, -1, :]  # last step as the sequence representation
class SensoryProcessor(Module):
    """Processes initial sensory input, often with low-level features.

    Pipeline: Linear -> LayerNorm -> Dropout -> spiking neuron.
    """

    def __init__(self, input_dim: int, output_dim: int):
        super().__init__()
        self.linear = nn.Linear(input_dim, output_dim)
        self.norm = nn.LayerNorm(output_dim)
        self.dropout = nn.Dropout(0.3)
        self.neuron = SpikingNeuron()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the spike tensor produced from the projected input."""
        z = self.dropout(self.norm(self.linear(x)))
        return self.neuron(z)
class RelayLayer(Module):
    """Routes information through a self-attention step.

    The projected input vector is repeated ``seq_len`` times to form a small
    sequence for ``nn.MultiheadAttention``; the attended sequence is averaged
    back to one vector and passed through a spiking neuron.

    Args:
        input_dim: Number of input features.
        output_dim: Number of output features; must be divisible by
            ``num_heads``.
        num_heads: Number of attention heads (default 2).
        seq_len: Length of the simulated sequence (default 4).
    """

    def __init__(self, input_dim: int, output_dim: int,
                 num_heads: int = 2, seq_len: int = 4):
        super().__init__()
        self.seq_len = seq_len
        self.linear = nn.Linear(input_dim, output_dim)
        self.norm = nn.LayerNorm(output_dim)
        self.attn = nn.MultiheadAttention(embed_dim=output_dim,
                                          num_heads=num_heads,
                                          batch_first=True)
        self.dropout = nn.Dropout(0.3)
        self.neuron = SpikingNeuron()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return spikes for a (batch, input_dim) input."""
        z = self.dropout(self.norm(self.linear(x)))

        # Simulate a short sequence for attention: (batch, seq_len, dim).
        z_seq = z.unsqueeze(1).repeat(1, self.seq_len, 1)

        # Self-attention: query, key and value are the same sequence.
        z_out, _ = self.attn(z_seq, z_seq, z_seq)

        # Collapse the sequence dimension again.
        routed = z_out.mean(dim=1)

        return self.neuron(routed)
class InterneuronLogic(Module):
    """Core processing unit, potentially for decision making or high-level
    abstraction.

    Pipeline: Linear -> LayerNorm -> Dropout -> adaptive-threshold LIF neuron.
    """

    def __init__(self, input_dim: int, output_dim: int):
        super().__init__()
        self.linear = nn.Linear(input_dim, output_dim)
        self.norm = nn.LayerNorm(output_dim)
        self.dropout = nn.Dropout(0.3)
        self.neuron = AdaptiveLIFNeuron()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the spike tensor produced from the projected input."""
        z = self.dropout(self.norm(self.linear(x)))
        return self.neuron(z)
class NeuroendocrineModulator(Module):
    """Modulates signals, potentially for gain control or emotional states.

    Applies a learned, sigmoid-based gain to its projected input.
    """

    def __init__(self, input_dim: int, output_dim: int):
        super().__init__()
        self.linear = nn.Linear(input_dim, output_dim)
        self.norm = nn.LayerNorm(output_dim)
        self.dropout = nn.Dropout(0.3)
        # Maps the projected signal to a per-feature gain factor.
        self.gain_control = nn.Linear(output_dim, output_dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the projected input scaled element-wise by its gain."""
        z = self.dropout(self.norm(self.linear(x)))
        gain = torch.sigmoid(self.gain_control(z))  # each gain lies in (0, 1)
        return z * gain
class AutonomicProcessor(Module):
    """Manages internal states with a recurrent (GRU) step.

    The GRU is run on a length-one sequence and no hidden state is carried
    between calls; its output is scaled by ``feedback_gain``.
    """

    def __init__(self, input_dim: int, output_dim: int):
        super().__init__()
        self.linear = nn.Linear(input_dim, output_dim)
        self.norm = nn.LayerNorm(output_dim)
        self.recurrent = nn.GRU(output_dim, output_dim, batch_first=True)
        self.feedback_gain = 0.9

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the scaled GRU output, shape (batch, output_dim)."""
        z = self.norm(self.linear(x))
        # Add a sequence dimension of length 1: the GRU processes a single step.
        h, _ = self.recurrent(z.unsqueeze(1))
        # Drop the sequence dimension again and apply the fixed gain.
        return self.feedback_gain * h.squeeze(1)
class MirrorComparator(Module):
    """Compares the current state to a reference.

    Produces spikes from the projected input and, when a reference is given,
    a cosine-similarity score between the projections of input and reference.
    """

    def __init__(self, input_dim: int, output_dim: int):
        super().__init__()
        self.linear = nn.Linear(input_dim, output_dim)
        self.norm = nn.LayerNorm(output_dim)
        # Cosine similarity along the feature dimension.
        self.comparison_layer = nn.CosineSimilarity(dim=1)
        self.neuron = SpikingNeuron()

    def compare(self, a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
        """Project ``a`` and ``b`` with the shared linear layer and return
        their cosine similarity, one value per row.

        Both arguments must be in the module's *input* space
        (last dimension ``input_dim``).
        """
        a_proj = self.linear(a)
        b_proj = self.linear(b)
        return self.comparison_layer(a_proj, b_proj)

    def forward(self, x: torch.Tensor,
                reference: Optional[torch.Tensor] = None
                ) -> "tuple[torch.Tensor, Optional[torch.Tensor]]":
        """Return ``(spike, similarity)``.

        Args:
            x: Input of shape (batch, input_dim).
            reference: Optional tensor in the same input space as ``x``.

        Returns:
            The spike tensor and, if ``reference`` was given, the cosine
            similarity between the projections of ``x`` and ``reference``;
            otherwise ``None`` in the second position.
        """
        z = self.norm(self.linear(x))
        spike = self.neuron(z)
        similarity = None
        if reference is not None:
            # compare() projects both operands itself, so it must receive the
            # raw input ``x``; passing the already-projected ``z`` would apply
            # the linear layer twice and fail when input_dim != output_dim.
            similarity = self.compare(x, reference)
        return spike, similarity
class PlaceGridMemory(Module):
    """Spatial memory system using population codes and an LSTM.

    The projected input is reduced to one scalar per sample, expanded into a
    one-hot population code, and passed through a single LSTM step.
    """

    def __init__(self, input_dim: int, output_dim: int):
        super().__init__()
        self.linear = nn.Linear(input_dim, output_dim)
        self.norm = nn.LayerNorm(output_dim)
        # Alias kept for callers that use the encoder through this attribute.
        self.positional_encoder = self.population_code
        self.memory_cell = nn.LSTM(output_dim, output_dim, batch_first=True)

    def population_code(self, x: torch.Tensor, pop_size: int) -> torch.Tensor:
        """Expand each row of ``x`` into a one-hot vector of length ``pop_size``.

        The row mean is squashed to (0, 1) with a sigmoid and mapped to a bin
        index. The forward value is an exact one-hot code; gradients are
        supplied by a soft (Gaussian-bump softmax) code through a
        straight-through estimator, so modules upstream of this layer still
        receive a training signal.

        Args:
            x: Floating-point tensor of shape (batch, features).
            pop_size: Number of bins in the population code.

        Returns:
            Tensor of shape (batch, pop_size).
        """
        # Reduce every sample to a scalar in (0, 1).
        x_normalized = torch.sigmoid(x.mean(dim=1))

        # Hard code: index of the active bin.
        idx = (x_normalized * (pop_size - 1)).long().clamp(0, pop_size - 1)
        hard = torch.zeros(x.size(0), pop_size, device=x.device, dtype=x.dtype)
        hard.scatter_(1, idx.unsqueeze(1), 1.0)

        # Soft code used only for its gradient.
        centers = torch.linspace(0.0, 1.0, pop_size,
                                 device=x.device, dtype=x.dtype)
        width = 1.0 / pop_size
        distance = x_normalized.unsqueeze(1) - centers
        soft = torch.softmax(-(distance ** 2) / (2 * width ** 2), dim=1)

        # Straight-through: (soft - soft.detach()) is exactly zero in the
        # forward pass, so the returned values equal ``hard``.
        return hard + (soft - soft.detach())

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the LSTM output for one step, shape (batch, output_dim)."""
        encoded = self.norm(self.linear(x))
        pos_encoded = self.positional_encoder(
            encoded, pop_size=self.linear.out_features)
        # Length-one sequence: the LSTM processes a single step.
        memory_out, _ = self.memory_cell(pos_encoded.unsqueeze(1))
        return memory_out.squeeze(1)
class TaskHead(nn.Module):
    """Generic linear task head.

    Args:
        input_dim: Number of input features.
        output_dim: Number of outputs (1 for binary/regression, number of
            classes otherwise).
        task_type: ``'binary'``, ``'regression'`` or any other string; for the
            first two the trailing output dimension is squeezed away.
    """

    def __init__(self, input_dim: int, output_dim: int, task_type: str):
        super().__init__()
        self.task_type = task_type
        # Dropout regularises the *input features*. Applying it to the
        # head's outputs would randomly zero/rescale logits and regression
        # predictions during training.
        self.dropout = nn.Dropout(0.3)
        # Kept as a Sequential so parameter names stay ``head.0.*``.
        self.head = nn.Sequential(
            nn.Linear(input_dim, output_dim),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return predictions: (batch,) for binary/regression, else
        (batch, output_dim)."""
        out = self.head(self.dropout(x))
        if self.task_type in ('binary', 'regression'):
            return out.squeeze(-1)
        return out
class TaskReplayBuffer:
    """Replay buffer that stores (state, label) pairs per task id.

    Each task has its own bounded FIFO queue. Sampling draws from up to two
    tasks at once, but only combines tasks whose samples can be stacked into
    one tensor.

    Args:
        buffer_size: Maximum number of samples kept per task.
        device: Stored for callers; samples themselves are kept on the CPU.
    """

    def __init__(self, buffer_size: int = 1000, device: str = "cpu"):
        self.buffer_size = buffer_size
        self.task_buffers = {
            task_id: deque(maxlen=buffer_size)
            for task_id in [Task.REGRESSION, Task.BINARY, Task.VISION]
        }
        self.device = device

    def add(self, task_id: int, states: torch.Tensor, labels: torch.Tensor):
        """Append every (state, label) pair of a batch to ``task_id``'s queue."""
        if task_id not in self.task_buffers:
            # Task ids that were not pre-initialised get a queue on demand.
            self.task_buffers[task_id] = deque(maxlen=self.buffer_size)

        queue = self.task_buffers[task_id]
        for state, label in zip(states, labels):
            # Detach and move to the CPU so stored samples neither keep
            # autograd graphs alive nor occupy accelerator memory.
            queue.append((state.detach().cpu(), label.detach().cpu()))

    @staticmethod
    def _signature(sample):
        """Shape/dtype signature deciding whether samples can be stacked."""
        state, label = sample
        return (tuple(state.shape), state.dtype,
                tuple(label.shape), label.dtype)

    def sample(self, batch_size: int = 32):
        """Draw a batch of stored samples.

        Up to two non-empty tasks are chosen at random. The second task is
        kept only if its samples have the same shapes and dtypes as the
        first one's; otherwise the batch comes from the first task alone,
        because tensors of different shapes cannot be stacked.

        Returns:
            ``(states, labels, task_ids)`` where ``states`` and ``labels`` are
            stacked tensors and ``task_ids`` is a list with one task id per
            row, or ``(None, None, None)`` when nothing can be sampled.
        """
        active_tasks = [tid for tid, buf in self.task_buffers.items() if buf]
        if not active_tasks:
            return None, None, None

        num_tasks_to_sample = min(2, len(active_tasks))
        candidates = random.sample(active_tasks, num_tasks_to_sample)

        # Keep only tasks that are stackable with the first chosen task.
        reference = self._signature(self.task_buffers[candidates[0]][0])
        chosen = [
            tid for tid in candidates
            if self._signature(self.task_buffers[tid][0]) == reference
        ]

        per_task_base, remainder = divmod(batch_size, len(chosen))

        batch_states, batch_labels, batch_task_ids = [], [], []
        for i, task_id in enumerate(chosen):
            task_list = list(self.task_buffers[task_id])
            # The first ``remainder`` tasks take one extra sample; never ask
            # for more samples than the task has stored.
            k = min(per_task_base + (1 if i < remainder else 0),
                    len(task_list))
            if k <= 0:
                continue
            for state, label in random.sample(task_list, k):
                batch_states.append(state)
                batch_labels.append(label)
                batch_task_ids.append(task_id)

        if not batch_states:
            return None, None, None

        return (
            torch.stack(batch_states),
            torch.stack(batch_labels),
            batch_task_ids,
        )
class ModularBrainAgent(nn.Module):
    """Modular neural agent for multi-task learning.

    Task-specific encoders feed a chain of brain-inspired modules
    (sensory -> relay -> interneurons -> neuroendocrine -> autonomic ->
    mirror -> place/grid memory). Consecutive modules are joined by
    ``connect_*`` linear layers, and the place/grid output drives the task
    heads.

    Args:
        neuron_counts: Optional mapping with the width of every module. Keys:
            ``sensory``, ``relay``, ``interneurons``, ``neuroendocrine``,
            ``autonomic``, ``mirror``, ``place_grid``. ``relay`` must be
            divisible by the relay layer's number of attention heads.
    """

    def __init__(self, neuron_counts: Optional[dict] = None):
        super().__init__()

        if neuron_counts is None:
            neuron_counts = {
                'sensory': 4,
                'relay': 12,
                'interneurons': 2,
                'neuroendocrine': 8,
                'autonomic': 10,
                'mirror': 14,
                'place_grid': 16,
            }
        self.neuron_counts = neuron_counts
        n = self.neuron_counts

        # Task-specific encoders; all emit ``sensory``-wide features.
        self.encoders = nn.ModuleDict({
            'regression': SharedEncoder(2, n['sensory']),
            'language': GRULanguage(len(vocab),
                                    embedding_dim=n['sensory'],
                                    hidden_dim=n['sensory']),
            'vision': CNNVision(output_features=n['sensory']),
        })

        # Brain modules. route_modules() projects between stages with the
        # connect_* layers below, so every module receives an input that is
        # already at its own width; input_dim therefore equals output_dim.
        self.sensory = SensoryProcessor(
            input_dim=n['sensory'], output_dim=n['sensory'])
        self.relay = RelayLayer(
            input_dim=n['relay'], output_dim=n['relay'])
        self.interneurons = InterneuronLogic(
            input_dim=n['interneurons'], output_dim=n['interneurons'])
        self.neuroendocrine = NeuroendocrineModulator(
            input_dim=n['neuroendocrine'], output_dim=n['neuroendocrine'])
        self.autonomic = AutonomicProcessor(
            input_dim=n['autonomic'], output_dim=n['autonomic'])
        self.mirror = MirrorComparator(
            input_dim=n['mirror'], output_dim=n['mirror'])
        self.place_grid = PlaceGridMemory(
            input_dim=n['place_grid'], output_dim=n['place_grid'])

        # Feed-forward connections between consecutive modules.
        self.connect_sensory_to_relay = nn.Linear(n['sensory'], n['relay'])
        self.connect_relay_to_inter = nn.Linear(n['relay'], n['interneurons'])
        self.connect_inter_to_modulators = nn.Linear(
            n['interneurons'], n['neuroendocrine'])
        self.connect_modulators_to_auto = nn.Linear(
            n['neuroendocrine'], n['autonomic'])
        self.connect_auto_to_mirror = nn.Linear(n['autonomic'], n['mirror'])
        self.connect_mirror_to_place = nn.Linear(n['mirror'], n['place_grid'])
        # Memory output looping back to the relay.
        self.connect_place_to_relay = nn.Linear(n['place_grid'], n['relay'])

        # Feedback connections.
        self.feedback_relay_to_sensory = nn.Linear(n['relay'], n['sensory'])
        self.feedback_inter_to_relay = nn.Linear(n['interneurons'], n['relay'])

        # One output head per task, all reading the place/grid output.
        self.task_heads = nn.ModuleDict({
            'binary': TaskHead(n['place_grid'], 1, 'binary'),
            'vision': TaskHead(n['place_grid'], 10, 'vision'),
            'regression': TaskHead(n['place_grid'], 1, 'regression'),
        })

    def route_modules(self, x: torch.Tensor) -> "tuple[torch.Tensor, ...]":
        """Run the encoded input through the interconnected brain modules.

        Args:
            x: Encoded input of shape (batch, sensory width).

        Returns:
            ``(h_relay, h_place, h_mirror, h_auto, h_modulate)``. ``h_relay``
            includes the feedback terms; ``h_place`` is computed before any
            feedback is applied.
        """
        h_sensory = self.sensory(x)
        h_relay = self.relay(self.connect_sensory_to_relay(h_sensory))
        h_inter = self.interneurons(self.connect_relay_to_inter(h_relay))
        h_modulate = self.neuroendocrine(
            self.connect_inter_to_modulators(h_inter))
        h_auto = self.autonomic(self.connect_modulators_to_auto(h_modulate))

        # MirrorComparator.forward always returns (spike, similarity); no
        # reference is supplied here, so only the spikes are used.
        h_mirror, _ = self.mirror(self.connect_auto_to_mirror(h_auto))

        h_place = self.place_grid(self.connect_mirror_to_place(h_mirror))

        # Memory output fed back into the relay representation.
        h_relay = h_relay + torch.relu(self.connect_place_to_relay(h_place))

        # Feedback paths. NOTE: the updated h_sensory is not returned or
        # used further; only h_relay carries feedback out of this method.
        h_sensory = h_sensory + torch.relu(
            self.feedback_relay_to_sensory(h_relay))
        h_relay = h_relay + torch.relu(self.feedback_inter_to_relay(h_inter))

        return h_relay, h_place, h_mirror, h_auto, h_modulate

    def encode(self, x: torch.Tensor, task_id: int) -> torch.Tensor:
        """Select and apply the encoder matching ``task_id`` and input shape.

        Raises:
            ValueError: For a vision input of unexpected shape, or when no
                encoder applies and the last dimension of ``x`` does not
                equal the sensory width.
        """
        if task_id == Task.REGRESSION and x.dim() == 2 and x.size(1) == 2:
            return self.encoders['regression'](x.float())
        if task_id == Task.BINARY and x.dim() == 2 and x.size(1) == 10:
            return self.encoders['language'](x.long())
        if task_id == Task.VISION:
            if x.dim() == 4:
                return self.encoders['vision'](x.float())
            if x.dim() == 2 and x.size(1) == 784:
                # Flattened 28x28 images: restore (batch, 1, 28, 28).
                return self.encoders['vision'](
                    x.float().view(x.size(0), 1, 28, 28))
            raise ValueError(f"Unexpected vision input shape: {x.shape}")

        # No encoder matched: accept the input only if it already has the
        # sensory width, i.e. it is treated as pre-encoded.
        if x.size(-1) != self.neuron_counts['sensory']:
            raise ValueError(
                f"Input {x.shape} does not match sensory input dim "
                f"{self.neuron_counts['sensory']} and no specific encoder "
                f"found for task {task_id}.")
        return x.float()

    def forward(self, x: torch.Tensor, task_id: int) -> torch.Tensor:
        """Encode ``x``, route it through the modules and apply the task head.

        Unknown task ids fall back to the regression head.
        """
        if x.dim() == 1:
            x = x.unsqueeze(0)

        encoded = self.encode(x, task_id)
        # Only the place/grid output is used by the task heads.
        _, h_place, _, _, _ = self.route_modules(encoded)

        head_name = {
            Task.REGRESSION: 'regression',
            Task.BINARY: 'binary',
            Task.VISION: 'vision',
        }.get(task_id, 'regression')

        return self.task_heads[head_name](h_place)
def get_mnist_loader(batch_size: int = 4) -> "tuple[DataLoader, int]":
    """Return a shuffled MNIST training DataLoader and the vision task id.

    Downloads MNIST into ``./data`` when it is not present. Images are
    converted to tensors and normalised with mean 0.5 and std 0.5.

    Args:
        batch_size: Samples per batch (default 4).
    """
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5,), (0.5,)),
    ])
    dataset = datasets.MNIST(root='./data', train=True, download=True,
                             transform=transform)
    loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
    return loader, Task.VISION
def get_imdb_loader(batch_size: int = 4) -> "tuple[DataLoader, int]":
    """Return a synthetic IMDB-like DataLoader and the binary task id.

    Builds 1000 samples: each is a randomly chosen sentence from
    ``sample_sentences`` converted to token ids and padded to length 10 with
    the ``<unk>`` id. Labels alternate 0/1 with the sample index.

    NOTE(review): the labels do not depend on the sentence, so this task
    carries no learnable signal — confirm this is intended.

    Args:
        batch_size: Samples per batch (default 4).
    """
    seq_len = 10
    pad_id = vocab["<unk>"]
    texts = []
    labels = []

    for i in range(1000):
        sentence = random.choice(sample_sentences)
        tokens = [vocab.get(word, pad_id)
                  for word in simple_tokenize(sentence)]
        # Truncate, then pad to exactly ``seq_len`` ids.
        tokens = tokens[:seq_len]
        tokens.extend([pad_id] * (seq_len - len(tokens)))

        texts.append(tokens)
        labels.append(i % 2)

    X = torch.tensor(texts).long()
    Y = torch.tensor(labels).float()
    dataset = TensorDataset(X, Y)
    return DataLoader(dataset, batch_size=batch_size, shuffle=True), Task.BINARY
def get_regression_loader(batch_size: int = 4) -> "tuple[DataLoader, int]":
    """Return a synthetic regression DataLoader and the regression task id.

    Inputs are 1000 standard-normal 2D points; the target is the mean of the
    two coordinates.

    Args:
        batch_size: Samples per batch (default 4).
    """
    X = torch.randn(1000, 2)
    Y = ((X[:, 0] + X[:, 1]) / 2).float()
    dataset = TensorDataset(X, Y)
    loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
    return loader, Task.REGRESSION
def inspect_shapes(agent: ModularBrainAgent):
    """Print tensor shapes along the agent's forward pass for every task.

    The inspection runs on the CPU in eval mode (dropout off) and without
    gradient tracking. The agent's original device and train/eval mode are
    restored afterwards, also when the inspection fails.

    Raises:
        Exception: Any error raised during inspection is printed and
            re-raised.
    """
    print("\n=== SHAPE INSPECTOR ===")

    # Remember where and how the agent was configured so it can be restored.
    first_param = next(agent.parameters(), None)
    original_device = first_param.device if first_param is not None else None
    was_training = agent.training

    agent.cpu()
    agent.eval()

    try:
        with torch.no_grad():
            # (label, task id, head name, sample input). The language
            # encoder is exercised through the BINARY task.
            probes = [
                ("Regression", Task.REGRESSION, 'regression',
                 torch.randn(4, 2)),
                ("Language", Task.BINARY, 'binary',
                 torch.randint(0, len(vocab), (4, 10))),
                ("Vision", Task.VISION, 'vision',
                 torch.randn(4, 1, 28, 28)),
            ]
            for label, task_id, head_name, sample in probes:
                encoded = agent.encode(sample, task_id)
                print(f"Encoded ({label}): {encoded.shape}")
                h_relay, h_place, _, _, _ = agent.route_modules(encoded)
                out = agent.task_heads[head_name](h_place)
                print(f"{label} Path: "
                      f"Sensory({agent.sensory(encoded).shape}) -> "
                      f"Relay({h_relay.shape}) -> Place({h_place.shape}) -> "
                      f"Output({out.shape})")
    except Exception as e:
        print(f"Shape inspection failed: {e}")
        # A shape error is fatal for training: propagate it.
        raise
    finally:
        agent.train(was_training)
        if original_device is not None:
            agent.to(original_device)
        print("=========================\n")
def train(agent: ModularBrainAgent, episodes: int = 14400,
          buffer_size: int = 1000, replay_freq: int = 5):
    """Train the agent with a task curriculum and experience replay.

    Each episode trains on one batch of a task drawn from the current
    curriculum stage; every ``replay_freq`` episodes an extra update is made
    on samples from the replay buffer. Metrics go to TensorBoard under
    ``runs/modular_brain_agent`` and the weights with the lowest single-batch
    loss are saved to ``best_model.pth``.

    Args:
        agent: Model to train (moved to CUDA when available).
        episodes: Maximum number of training episodes.
        buffer_size: Per-task capacity of the replay buffer.
        replay_freq: Replay every N episodes; values <= 0 disable replay.

    Returns:
        The lowest (weighted) single-batch loss observed.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    agent.to(device)
    print(f"Using device: {device}")

    print("🧠 Running shape inspector...")
    inspect_shapes(agent)
    # The inspector runs on the CPU; make sure the agent is back on the
    # training device before any batch is fed to it.
    agent.to(device)

    replay_buffer = TaskReplayBuffer(buffer_size=buffer_size, device=device)

    optimizer = torch.optim.AdamW(agent.parameters(), lr=0.001,
                                  weight_decay=1e-5)
    # ``verbose`` is deprecated in recent PyTorch releases; the learning rate
    # is logged to TensorBoard below instead.
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, patience=100, factor=0.5)

    def init_weights(m):
        """Xavier init for linear layers, Kaiming for convolutions."""
        if isinstance(m, nn.Linear):
            torch.nn.init.xavier_uniform_(m.weight)
            if m.bias is not None:
                torch.nn.init.zeros_(m.bias)
        elif isinstance(m, nn.Conv2d):
            torch.nn.init.kaiming_uniform_(m.weight, nonlinearity='relu')
            if m.bias is not None:
                torch.nn.init.zeros_(m.bias)

    agent.apply(init_weights)

    log_dir = "runs/modular_brain_agent"
    os.makedirs(log_dir, exist_ok=True)
    writer = SummaryWriter(log_dir=log_dir)
    print(f"TensorBoard logs are being saved to: {log_dir}")

    best_loss = float('inf')
    no_improvement = 0
    save_path = "best_model.pth"
    all_tasks = [Task.REGRESSION, Task.BINARY, Task.VISION]
    task_history = {t: [] for t in all_tasks}
    task_names = {Task.REGRESSION: "Regression", Task.BINARY: "Binary",
                  Task.VISION: "Vision"}
    # Episode at which a stage starts -> tasks available from then on.
    curriculum_stages = {
        0: [Task.REGRESSION],
        300: [Task.REGRESSION, Task.BINARY],
        600: [Task.REGRESSION, Task.BINARY, Task.VISION],
    }

    loaders = {
        Task.REGRESSION: get_regression_loader()[0],
        Task.BINARY: get_imdb_loader()[0],
        Task.VISION: get_mnist_loader()[0],
    }
    loaders_iter = {task: iter(loader) for task, loader in loaders.items()}

    def compute_loss(out: torch.Tensor, Y: torch.Tensor,
                     task_id: int) -> torch.Tensor:
        """Compute the task-specific loss, aligning shapes where possible."""
        if task_id == Task.VISION:
            Y = Y.long()
            if out.dim() != 2 or Y.dim() != 1:
                raise ValueError(
                    f"Vision task expects out [batch, num_classes], "
                    f"Y [batch]. Got {out.shape} vs {Y.shape}")
            return F.cross_entropy(out, Y)

        kind = "Binary" if task_id == Task.BINARY else "Regression"
        if out.shape != Y.shape:
            # Reshape to the target only when the element counts agree.
            if out.numel() != Y.numel():
                raise ValueError(
                    f"{kind} task shape mismatch: out {out.shape} "
                    f"vs Y {Y.shape}")
            out = out.view_as(Y)
        if task_id == Task.BINARY:
            # Outputs are logits.
            return F.binary_cross_entropy_with_logits(out, Y.float())
        return F.smooth_l1_loss(out, Y.float())

    loss_weights = {Task.REGRESSION: 1.5, Task.BINARY: 1.2, Task.VISION: 1.2}
    global_step = 0

    # try/finally guarantees the TensorBoard writer is flushed and closed
    # even when training is interrupted or fails.
    try:
        for ep in range(episodes):
            agent.train()

            # Latest curriculum stage whose start episode has been reached.
            current_stage_episodes = max(
                start for start in curriculum_stages if ep >= start)
            current_tasks = curriculum_stages[current_stage_episodes]

            # Plain int rather than a numpy scalar.
            task_id = int(np.random.choice(current_tasks))

            try:
                X, Y = next(loaders_iter[task_id])
            except StopIteration:
                # Loader exhausted: start a new epoch for this task.
                loaders_iter[task_id] = iter(loaders[task_id])
                X, Y = next(loaders_iter[task_id])

            X, Y = X.to(device), Y.to(device)

            out = agent(X, task_id)
            loss = compute_loss(out, Y, task_id) * loss_weights.get(task_id, 1.0)

            optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(agent.parameters(), max_norm=1.0)
            optimizer.step()

            replay_buffer.add(task_id, X, Y)

            # Accuracy of the pre-update predictions on this batch.
            with torch.no_grad():
                if task_id == Task.VISION:
                    pred = out.argmax(dim=1)
                    acc = (pred == Y).float().mean().item()
                elif task_id == Task.BINARY:
                    pred = (torch.sigmoid(out) > 0.5).float()
                    acc = (pred == Y).float().mean().item()
                else:
                    # Regression counts as correct within 0.2 absolute error.
                    acc = ((out - Y).abs() < 0.2).float().mean().item()

            task_history[task_id].append((loss.item(), acc))

            # Experience replay.
            if replay_freq > 0 and ep % replay_freq == 0:
                replay_X, replay_Y, replay_task_ids = \
                    replay_buffer.sample(batch_size=4)

                if replay_X is not None:
                    replay_X = replay_X.to(device)
                    replay_Y = replay_Y.to(device)

                    weighted_losses = []
                    num_replay_samples = 0

                    # Each task needs its own encoder and head, so the batch
                    # is processed task by task.
                    for replay_task in set(replay_task_ids):
                        idx = [i for i, tid in enumerate(replay_task_ids)
                               if tid == replay_task]
                        replay_out = agent(replay_X[idx], replay_task)
                        replay_loss = compute_loss(
                            replay_out, replay_Y[idx], replay_task)
                        weighted_losses.append(replay_loss * len(idx))
                        num_replay_samples += len(idx)

                    if num_replay_samples > 0:
                        # Sample-weighted mean over all replayed samples.
                        total_replay_loss = (
                            torch.stack(weighted_losses).sum()
                            / num_replay_samples)
                        optimizer.zero_grad()
                        total_replay_loss.backward()
                        torch.nn.utils.clip_grad_norm_(
                            agent.parameters(), max_norm=1.0)
                        optimizer.step()
                        writer.add_scalar('Loss/Replay_Loss',
                                          total_replay_loss.item(),
                                          global_step)

            # The scheduler follows the current task loss.
            scheduler.step(loss.item())

            writer.add_scalar(f'Loss/Current_Task_{task_id}', loss.item(),
                              global_step)
            writer.add_scalar(f'Accuracy/Current_Task_{task_id}', acc,
                              global_step)
            writer.add_scalar('LearningRate',
                              optimizer.param_groups[0]['lr'], global_step)

            global_step += 1

            # Checkpoint on improvement of the single-batch loss.
            if loss.item() < best_loss:
                best_loss = loss.item()
                no_improvement = 0
                torch.save(agent.state_dict(), save_path)
            else:
                no_improvement += 1

            # Print summary periodically.
            if ep % 200 == 0:
                print(f"\n--- Episode {ep} (Curriculum Stage: "
                      f"{current_stage_episodes}) ---")
                for t_id in sorted(task_history):
                    recent = task_history[t_id][-200:]
                    if not recent:
                        continue
                    recent_losses = [entry[0] for entry in recent]
                    recent_accs = [entry[1] for entry in recent]
                    avg_loss = np.mean(recent_losses)
                    avg_acc = np.mean(recent_accs)
                    task_name = task_names.get(t_id, f"Unknown_{t_id}")
                    print(f"Task {task_name} | Avg Loss: {avg_loss:.3f} | "
                          f"Avg Acc: {avg_acc:.2f} "
                          f"({len(recent_losses)} samples)")
                print(f"Overall Current Loss: {loss.item():.4f} | "
                      f"Best Loss: {best_loss:.4f} | "
                      f"No Improvement: {no_improvement}")
                print("--------------------\n")

            if no_improvement >= 1000:  # Early stopping threshold
                print(f"Early stopping at episode {ep} due to no "
                      f"improvement for {no_improvement} steps.")
                break
    finally:
        writer.close()

    print("✅ Training finished.")
    return best_loss
# --- 10. Main Execution Block ---

if __name__ == "__main__":
    print("🚀 Initializing Modular Brain Agent...")

    # Optional: pass custom widths, e.g.
    # ModularBrainAgent(neuron_counts={'sensory': 8, 'relay': 24,
    #                                  'interneurons': 4, 'neuroendocrine': 16,
    #                                  'autonomic': 20, 'mirror': 28,
    #                                  'place_grid': 32})
    agent = ModularBrainAgent()  # Using default neuron counts for now

    # Print model summary
    total_params = sum(p.numel() for p in agent.parameters())
    trainable_params = sum(
        p.numel() for p in agent.parameters() if p.requires_grad)

    print("📊 Model Summary:")
    print(f" Total parameters: {total_params:,}")
    print(f" Trainable parameters: {trainable_params:,}")
    print(f" Model size: ~{total_params * 4 / (1024**2):.2f} MB "
          f"(approx. float32)")

    # Smoke-test a forward pass on each task to verify shapes. No gradients
    # are needed for this check.
    print("\n🔍 Testing forward passes...")
    try:
        with torch.no_grad():
            # Test regression task
            test_reg = torch.randn(2, 2)
            out_reg = agent(test_reg, Task.REGRESSION)
            print(f"✅ Regression test passed: {test_reg.shape} -> "
                  f"{out_reg.shape}")

            # Test binary classification task (language task)
            test_bin = torch.randint(0, len(vocab), (2, 10))
            out_bin = agent(test_bin, Task.BINARY)
            print(f"✅ Binary classification test passed: {test_bin.shape} "
                  f"-> {out_bin.shape}")

            # Test vision task
            test_vis = torch.randn(2, 1, 28, 28)
            out_vis = agent(test_vis, Task.VISION)
            print(f"✅ Vision test passed: {test_vis.shape} -> "
                  f"{out_vis.shape}")

    except Exception as e:
        print(f"❌ Forward pass test failed: {e}")
        # If the forward pass fails there is a fundamental issue, so exit.
        import sys
        sys.exit(1)

    print("\n🎯 Starting training...")

    # Train the agent
    try:
        final_best_loss = train(
            agent=agent,
            episodes=14400,    # Total training episodes
            buffer_size=1000,  # Replay buffer size
            replay_freq=5,     # Replay every N episodes
        )

        print("\n🎉 Training completed successfully!")
        print(f"📈 Best loss achieved during training: "
              f"{final_best_loss:.4f}")
        print("💾 Best model saved to: best_model.pth")

    except KeyboardInterrupt:
        print("\n⏹ Training interrupted by user.")
    except Exception as e:
        print(f"\n❌ Training failed with error: {e}")
        raise

    print("\n🧠 Modular Brain Agent execution completed!")