""" GLADIUS v2.0 — Cognition Auxiliary Loss The Cognition module was dormant across ALL experiments (0.0000% weight change) because its outputs are never consumed by a loss function. Gradient descent cannot activate what it cannot reach. This module provides auxiliary loss terms that give the Cognition module gradient signal by: 1. Mode prediction loss — predict the modality/task of the current batch 2. Attention gate loss — encourage the attention filter to actually filter 3. Consistency loss — cognitive state should be stable within a batch Usage: from kernel.cognition_loss import CognitionAuxLoss aux_loss = CognitionAuxLoss(config, num_tasks=5) # During training: result = model(input_ids, timestamp=t) lm_loss = cross_entropy(result['logits'], targets) cog_loss = aux_loss( hidden=result['_hidden'], # Need to expose this from kernel mode_probs=result['mode_probs'], task_label=task_id # 0=text, 1=math, 2=vision, 3=bytes, 4=video ) total_loss = lm_loss + 0.1 * cog_loss """ import torch import torch.nn as nn import torch.nn.functional as F from .config import KernelConfig class CognitionAuxLoss(nn.Module): """ Auxiliary loss that gives the Cognition module gradient signal. Three components: 1. Task classification: predict which modality is being processed 2. Gate utilization: encourage non-trivial attention gating 3. State consistency: cognitive state should be smooth within a batch """ def __init__(self, config: KernelConfig, num_tasks: int = 5): super().__init__() # Task classifier head — maps cognitive state to task prediction self.task_head = nn.Sequential( nn.Linear(config.cognition_state_dim, config.cognition_state_dim), nn.SiLU(), nn.Linear(config.cognition_state_dim, num_tasks), ) # Gate diversity target self.register_buffer('target_gate_entropy', torch.tensor(0.5)) # Don't want uniform, don't want collapsed # Loss weights self.task_weight = 1.0 self.gate_weight = 0.3 self.consistency_weight = 0.1 def forward( self, cognitive_state: torch.Tensor, mode_probs: torch.Tensor, gate_values: torch.Tensor | None = None, task_label: int | torch.Tensor | None = None, ) -> dict: """ Compute cognition auxiliary losses. Args: cognitive_state: (batch, cognition_state_dim) — from StateMonitor mode_probs: (batch, num_modes) — from HeartbeatScheduler gate_values: (batch, seq_len, 1) — from AttentionFilter (optional) task_label: scalar or (batch,) — task ID for classification Returns: dict with: loss: total auxiliary loss (scalar) task_loss: task classification loss gate_loss: gate utilization loss consistency_loss: state consistency loss """ losses = {} total = torch.tensor(0.0, device=cognitive_state.device) # 1. Task classification loss if task_label is not None: task_logits = self.task_head(cognitive_state) # (B, num_tasks) if isinstance(task_label, int): task_label = torch.tensor([task_label] * cognitive_state.size(0), device=cognitive_state.device) task_loss = F.cross_entropy(task_logits, task_label) losses['task_loss'] = task_loss total = total + self.task_weight * task_loss # 2. Gate utilization loss (encourage non-trivial gating) if gate_values is not None: # Mean gate activation — should not be all-pass (1.0) or all-block (0.0) mean_gate = gate_values.mean() # Penalize deviation from target entropy gate_loss = (mean_gate - 0.5).pow(2) + (1.0 - gate_values.var()).clamp(min=0) losses['gate_loss'] = gate_loss total = total + self.gate_weight * gate_loss # 3. Mode consistency loss # Mode predictions within a batch should agree (same task = same mode) if mode_probs.size(0) > 1: # KL divergence between each sample's mode distribution and batch mean mean_probs = mode_probs.mean(dim=0, keepdim=True) consistency_loss = F.kl_div( mode_probs.log().clamp(min=-10), mean_probs.expand_as(mode_probs), reduction='batchmean', log_target=False, ) losses['consistency_loss'] = consistency_loss total = total + self.consistency_weight * consistency_loss losses['loss'] = total return losses