# ascad-training-pipeline/src/spectral_decoupling.py
# Commit 283a882 (lemousehunter): v3: Add DTP + Spectral Decoupling, fix GradNorm OOM, fix _fail_job cancel
"""
Spectral Decoupling for Multi-Task Learning
===========================================

Implements L2 logit regularization to prevent gradient starvation in
multi-task learning with shared representations.

The core idea: add a penalty term lambda * ||logits_i||^2 to each task's
loss. This prevents the network from becoming overly confident on any
subset of tasks, keeping the gradient signal alive for all tasks.

Why this addresses our failure mode:
    In our MTAN-Lite model, bytes 0 and 1 develop stronger feature
    responses early in training. Without regularization, the softmax
    outputs for these bytes become highly peaked (confident), which
    causes their gradients to dominate the shared backbone updates.
    The other 14 bytes receive vanishing gradients and never learn.

    Spectral Decoupling prevents this by penalizing large logit
    magnitudes. This keeps the softmax outputs "soft" (less confident),
    which maintains gradient flow to all tasks and prevents any single
    task from capturing the shared representation.

Implementation:
    We modify the model architecture to output pre-softmax logits,
    apply a custom loss that combines cross-entropy with L2 logit
    regularization, and use a separate softmax layer for inference.
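
    Alternatively (simpler): we add activity_regularizer=L2(lambda)
    to the final Dense layer of each byte head. Keras applies an
    activity regularizer to the layer's output, so this penalizes the
    logits only if that layer has a linear activation and softmax is
    applied in a separate layer. On a Dense(..., activation="softmax")
    layer it would penalize the probabilities instead.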

Reference:
    Pezeshki, M., Kaba, S.-O., Bengio, Y., Courville, A., Precup, D.,
    & Lajoie, G. (2021). Gradient Starvation: A Learning Proclivity
    in Neural Networks. NeurIPS 2021.
"""
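# Illustration only: a minimal, framework-free sketch of the custom loss
# described in the Implementation note above (cross-entropy plus the L2
# logit penalty). The names log_softmax, sd_loss and total_sd_loss are
# hypothetical and are not part of this module; the count of 16 byte heads
# is inferred from "bytes 0 and 1" plus "the other 14 bytes".

```python
import math
from typing import List, Sequence


def log_softmax(logits: Sequence[float]) -> List[float]:
    """Numerically stable log-softmax over one logit vector."""
    m = max(logits)
    log_z = m + math.log(sum(math.exp(z - m) for z in logits))
    return [z - log_z for z in logits]


def sd_loss(logits: Sequence[float], label: int, lambda_sd: float = 0.01) -> float:
    """Cross-entropy for one sample plus lambda_sd * sum(logits^2)."""
    cross_entropy = -log_softmax(logits)[label]
    penalty = lambda_sd * sum(z * z for z in logits)
    return cross_entropy + penalty


def total_sd_loss(
    per_byte_logits: Sequence[Sequence[float]],
    labels: Sequence[int],
    lambda_sd: float = 0.01,
) -> float:
    """Sum the penalized loss over all byte heads (assumed: one head per key byte)."""
    return sum(sd_loss(z, y, lambda_sd) for z, y in zip(per_byte_logits, labels))


if __name__ == "__main__":
    # A confident head pays a penalty that grows with the squared logit norm.
    confident = [8.0] + [0.0] * 255
    print("without penalty:", sd_loss(confident, 0, lambda_sd=0.0))
    print("with penalty:   ", sd_loss(confident, 0, lambda_sd=0.01))
```

# The penalty is added per head, so every task is regularized with the same
# lambda_sd regardless of how confident the other heads already are.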
import logging
from typing import Optional

import tensorflow as tf

logger = logging.getLogger(__name__)
def get_spectral_decoupling_regularizer(
    lambda_sd: float = 0.01,
) -> Optional[tf.keras.regularizers.Regularizer]:
    """
    Return a Keras L2 activity regularizer for spectral decoupling.

    When applied to the final Dense(256) layer of each byte head,
    this penalizes large pre-softmax logit values, preventing the
    network from becoming overly confident on any single task.
    That layer must use a linear activation (softmax applied in a
    separate layer); otherwise the penalty falls on probabilities.

    The regularization loss is: lambda_sd * sum(logits^2)

    Args:
        lambda_sd: Regularization strength. Higher values = stronger
            decoupling. Recommended range: 0.001 to 0.1.
            - 0.001: Very mild (may not prevent starvation)
            - 0.01: Moderate (recommended starting point)
            - 0.1: Strong (may slow convergence)

    Returns:
        A Keras L2 regularizer instance, or None if lambda_sd <= 0.
    """
    if lambda_sd <= 0:
        logger.info("Spectral Decoupling DISABLED (lambda_sd=%.4f)", lambda_sd)
        return None

    regularizer = tf.keras.regularizers.L2(lambda_sd)
    logger.info(
        "Spectral Decoupling regularizer created: lambda=%.4f "
        "(L2 penalty on pre-softmax logits to prevent gradient starvation)",
        lambda_sd,
    )
    return regularizer
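

# Illustration only: a framework-free sketch of why the regularizer has to
# see logits rather than softmax outputs. Keras applies an activity
# regularizer to a layer's output, i.e. after the layer's activation. The
# helper names softmax and l2_penalty are hypothetical; l2_penalty mirrors
# the lambda_sd * sum(x^2) form documented above. Since probabilities sum
# to 1, their squared sum never exceeds 1, so a penalty on softmax outputs
# is capped at lambda_sd and cannot restrain logit growth.

```python
import math
from typing import List, Sequence


def softmax(logits: Sequence[float]) -> List[float]:
    """Numerically stable softmax over one logit vector."""
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


def l2_penalty(values: Sequence[float], lambda_sd: float) -> float:
    """L2 penalty in the same form as the regularizer: lambda_sd * sum(values^2)."""
    return lambda_sd * sum(v * v for v in values)


lambda_sd = 0.01
for scale in (1.0, 5.0, 25.0):
    # One dominant logit out of 256 classes, growing in magnitude.
    logits = [scale] + [0.0] * 255
    on_logits = l2_penalty(logits, lambda_sd)
    on_probs = l2_penalty(softmax(logits), lambda_sd)
    print(f"scale={scale:5.1f}  on logits={on_logits:.4f}  on softmax={on_probs:.6f}")
```

# The penalty on logits grows quadratically with the dominant logit, while
# the penalty on softmax outputs stays at or below lambda_sd.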