| """ |
| GLADIUS v2.0 — Cognition Auxiliary Loss |
| |
| The Cognition module was dormant across ALL experiments (0.0000% weight change) |
| because its outputs are never consumed by a loss function. Gradient descent |
| cannot activate what it cannot reach. |
| |
| This module provides auxiliary loss terms that give the Cognition module |
| gradient signal by: |
| 1. Mode prediction loss — predict the modality/task of the current batch |
| 2. Attention gate loss — encourage the attention filter to actually filter |
| 3. Consistency loss — cognitive state should be stable within a batch |
| |
| Usage: |
| from kernel.cognition_loss import CognitionAuxLoss |
| |
| aux_loss = CognitionAuxLoss(config, num_tasks=5) |
| |
| # During training: |
| result = model(input_ids, timestamp=t) |
| lm_loss = cross_entropy(result['logits'], targets) |
| |
| cog_loss = aux_loss( |
| cognitive_state=result['_hidden'], # Need to expose this from kernel |
| mode_probs=result['mode_probs'], |
| task_label=task_id # 0=text, 1=math, 2=vision, 3=bytes, 4=video |
| ) |
| |
| total_loss = lm_loss + 0.1 * cog_loss['loss'] |
| """ |
|
|
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
|
|
| from .config import KernelConfig |
|
|
|
|
| class CognitionAuxLoss(nn.Module): |
| """ |
| Auxiliary loss that gives the Cognition module gradient signal. |
| |
| Three components: |
| 1. Task classification: predict which modality is being processed |
| 2. Gate utilization: encourage non-trivial attention gating |
| 3. State consistency: cognitive state should be smooth within a batch |
| """ |
| |
| def __init__(self, config: KernelConfig, num_tasks: int = 5): |
| super().__init__() |
| |
| |
| self.task_head = nn.Sequential( |
| nn.Linear(config.cognition_state_dim, config.cognition_state_dim), |
| nn.SiLU(), |
| nn.Linear(config.cognition_state_dim, num_tasks), |
| ) |
| |
| |
| self.register_buffer('target_gate_entropy', |
| torch.tensor(0.5)) |
| |
| |
| self.task_weight = 1.0 |
| self.gate_weight = 0.3 |
| self.consistency_weight = 0.1 |
| |
| def forward( |
| self, |
| cognitive_state: torch.Tensor, |
| mode_probs: torch.Tensor, |
| gate_values: torch.Tensor | None = None, |
| task_label: int | torch.Tensor | None = None, |
| ) -> dict: |
| """ |
| Compute cognition auxiliary losses. |
| |
| Args: |
| cognitive_state: (batch, cognition_state_dim) β from StateMonitor |
| mode_probs: (batch, num_modes) β from HeartbeatScheduler |
| gate_values: (batch, seq_len, 1) β from AttentionFilter (optional) |
| task_label: scalar or (batch,) β task ID for classification |
| |
| Returns: |
| dict with: |
| loss: total auxiliary loss (scalar) |
| task_loss: task classification loss |
| gate_loss: gate utilization loss |
| consistency_loss: state consistency loss |
| """ |
| losses = {} |
| total = torch.tensor(0.0, device=cognitive_state.device) |
| |
| |
| if task_label is not None: |
| task_logits = self.task_head(cognitive_state) |
| |
| if isinstance(task_label, int): |
| task_label = torch.tensor([task_label] * cognitive_state.size(0), |
| device=cognitive_state.device) |
| |
| task_loss = F.cross_entropy(task_logits, task_label) |
| losses['task_loss'] = task_loss |
| total = total + self.task_weight * task_loss |
| |
| |
| if gate_values is not None: |
| |
| mean_gate = gate_values.mean() |
| |
| gate_loss = (mean_gate - 0.5).pow(2) + (1.0 - gate_values.var()).clamp(min=0) |
| losses['gate_loss'] = gate_loss |
| total = total + self.gate_weight * gate_loss |
| |
| |
| |
| if mode_probs.size(0) > 1: |
| |
| mean_probs = mode_probs.mean(dim=0, keepdim=True) |
| consistency_loss = F.kl_div( |
| mode_probs.log().clamp(min=-10), |
| mean_probs.expand_as(mode_probs), |
| reduction='batchmean', |
| log_target=False, |
| ) |
| losses['consistency_loss'] = consistency_loss |
| total = total + self.consistency_weight * consistency_loss |
| |
| losses['loss'] = total |
| return losses |
|
|