# Source: gladius-v2-kernel / kernel/cognition_loss.py
# Author: amuzetnoM
# WYRM kernel source (v27 FINAL)
# Commit: 9463e5c (verified)
"""
GLADIUS v2.0 — Cognition Auxiliary Loss
The Cognition module was dormant across ALL experiments (0.0000% weight change)
because its outputs are never consumed by a loss function. Gradient descent
cannot activate what it cannot reach.
This module provides auxiliary loss terms that give the Cognition module
gradient signal by:
1. Mode prediction loss — predict the modality/task of the current batch
2. Attention gate loss — encourage the attention filter to actually filter
3. Consistency loss — per-sample mode distributions should agree within a batch
Usage:
from kernel.cognition_loss import CognitionAuxLoss
aux_loss = CognitionAuxLoss(config, num_tasks=5)
# During training:
result = model(input_ids, timestamp=t)
lm_loss = cross_entropy(result['logits'], targets)
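cog_loss = aux_loss(
    cognitive_state=result['_hidden'],  # Need to expose this from kernel; must be (batch, cognition_state_dim)
    mode_probs=result['mode_probs'],
    task_label=task_id,  # 0=text, 1=math, 2=vision, 3=bytes, 4=video
)['loss']  # forward() returns a dict; 'loss' is the weighted total
total_loss = lm_loss + 0.1 * cog_loss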
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
from .config import KernelConfig
class CognitionAuxLoss(nn.Module):
"""
Auxiliary loss that gives the Cognition module gradient signal.
Three components:
1. Task classification: predict which modality is being processed
2. Gate utilization: encourage non-trivial attention gating
3. State consistency: cognitive state should be smooth within a batch
"""
def __init__(self, config: KernelConfig, num_tasks: int = 5):
super().__init__()
# Task classifier head β€” maps cognitive state to task prediction
self.task_head = nn.Sequential(
nn.Linear(config.cognition_state_dim, config.cognition_state_dim),
nn.SiLU(),
nn.Linear(config.cognition_state_dim, num_tasks),
)
# Gate diversity target
self.register_buffer('target_gate_entropy',
torch.tensor(0.5)) # Don't want uniform, don't want collapsed
# Loss weights
self.task_weight = 1.0
self.gate_weight = 0.3
self.consistency_weight = 0.1
def forward(
self,
cognitive_state: torch.Tensor,
mode_probs: torch.Tensor,
gate_values: torch.Tensor | None = None,
task_label: int | torch.Tensor | None = None,
) -> dict:
"""
Compute cognition auxiliary losses.
Args:
cognitive_state: (batch, cognition_state_dim) β€” from StateMonitor
mode_probs: (batch, num_modes) β€” from HeartbeatScheduler
gate_values: (batch, seq_len, 1) β€” from AttentionFilter (optional)
task_label: scalar or (batch,) β€” task ID for classification
Returns:
dict with:
loss: total auxiliary loss (scalar)
task_loss: task classification loss
gate_loss: gate utilization loss
consistency_loss: state consistency loss
"""
losses = {}
total = torch.tensor(0.0, device=cognitive_state.device)
# 1. Task classification loss
if task_label is not None:
task_logits = self.task_head(cognitive_state) # (B, num_tasks)
if isinstance(task_label, int):
task_label = torch.tensor([task_label] * cognitive_state.size(0),
device=cognitive_state.device)
task_loss = F.cross_entropy(task_logits, task_label)
losses['task_loss'] = task_loss
total = total + self.task_weight * task_loss
# 2. Gate utilization loss (encourage non-trivial gating)
if gate_values is not None:
# Mean gate activation β€” should not be all-pass (1.0) or all-block (0.0)
mean_gate = gate_values.mean()
# Penalize deviation from target entropy
gate_loss = (mean_gate - 0.5).pow(2) + (1.0 - gate_values.var()).clamp(min=0)
losses['gate_loss'] = gate_loss
total = total + self.gate_weight * gate_loss
# 3. Mode consistency loss
# Mode predictions within a batch should agree (same task = same mode)
if mode_probs.size(0) > 1:
# KL divergence between each sample's mode distribution and batch mean
mean_probs = mode_probs.mean(dim=0, keepdim=True)
consistency_loss = F.kl_div(
mode_probs.log().clamp(min=-10),
mean_probs.expand_as(mode_probs),
reduction='batchmean',
log_target=False,
)
losses['consistency_loss'] = consistency_loss
total = total + self.consistency_weight * consistency_loss
losses['loss'] = total
return losses