"""
Spectral Decoupling for Multi-Task Learning
=============================================

Implements L2 logit regularization to prevent gradient starvation in
multi-task learning with shared representations.

The core idea: add a penalty term lambda * ||logits_i||^2 to each task's loss.
This prevents the network from becoming overly confident on any subset of
tasks, keeping the gradient signal alive for all tasks.

Why this addresses our failure mode:
    In our MTAN-Lite model, bytes 0 and 1 develop stronger feature responses
    early in training. Without regularization, the softmax outputs for these
    bytes become highly peaked (confident), which causes their gradients to
    dominate the shared backbone updates. The other 14 bytes receive vanishing
    gradients and never learn.

    Spectral Decoupling prevents this by penalizing large logit magnitudes.
    This keeps the softmax outputs "soft" (less confident), which maintains
    gradient flow to all tasks and prevents any single task from capturing
    the shared representation.

Implementation:
    We modify the model architecture to output pre-softmax logits, apply a
    custom loss that combines cross-entropy with L2 logit regularization, and
    use a separate softmax layer for inference.

    Alternatively (simpler): we add activity_regularizer=L2(lambda) to the
    final Dense layer of each byte head. Keras applies an activity regularizer
    to the layer's output, so this penalizes the logits only if that layer
    uses a linear activation and softmax is applied afterwards. With
    activation="softmax" on the same layer, the penalty would fall on the
    probabilities instead. The regularizer itself is natively supported by
    Keras.

Reference:
    Pezeshki, M., Kaba, S.-O., Bengio, Y., Courville, A., Precup, D., &
    Lajoie, G. (2021). Gradient Starvation: A Learning Proclivity in Neural
    Networks. NeurIPS 2021.
"""

import logging
from typing import Optional

import tensorflow as tf

logger = logging.getLogger(__name__)


def get_spectral_decoupling_regularizer(
    lambda_sd: float = 0.01,
) -> Optional[tf.keras.regularizers.Regularizer]:
    """
    Return a Keras L2 activity regularizer for spectral decoupling.

    When applied to the final Dense(256) layer of each byte head, and provided
    that layer outputs pre-softmax logits (linear activation), this penalizes
    large logit values, preventing the network from becoming overly confident
    on any single task.

    The regularization loss is: lambda_sd * sum(logits^2)

    Args:
        lambda_sd: Regularization strength. Higher values = stronger
            decoupling. Recommended range: 0.001 to 0.1.
            - 0.001: Very mild (may not prevent starvation)
            - 0.01: Moderate (recommended starting point)
            - 0.1: Strong (may slow convergence)

    Returns:
        A Keras L2 regularizer instance, or None if lambda_sd <= 0.
    """
    if lambda_sd <= 0:
        logger.info("Spectral Decoupling DISABLED (lambda_sd=%.4f)", lambda_sd)
        return None

    regularizer = tf.keras.regularizers.L2(lambda_sd)
    logger.info(
        "Spectral Decoupling regularizer created: lambda=%.4f "
        "(L2 penalty on pre-softmax logits to prevent gradient starvation)",
        lambda_sd,
    )
    return regularizer
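

# ---------------------------------------------------------------------------
# Illustrative sketch, not part of the original module: a plain-Python
# reference for the per-sample objective described in the docstrings above,
# i.e. softmax cross-entropy plus lambda_sd * sum(logits^2). The helper name
# `reference_sd_loss` is hypothetical. It ignores any batch-size normalization
# that Keras may apply to activity-regularization losses, and is only meant
# for checking small cases by hand against the Keras-based setup.
# ---------------------------------------------------------------------------
import math
from typing import Sequence


def reference_sd_loss(
    logits: Sequence[float], label: int, lambda_sd: float = 0.01
) -> float:
    """Cross-entropy on softmax(logits) plus an L2 penalty on the logits."""
    # Numerically stable log-sum-exp: subtract the max logit before exp().
    max_logit = max(logits)
    log_sum_exp = max_logit + math.log(
        sum(math.exp(z - max_logit) for z in logits)
    )
    # Cross-entropy for the true class: -log softmax(logits)[label].
    cross_entropy = log_sum_exp - logits[label]
    # Spectral-decoupling term: grows quadratically with logit magnitude,
    # so very confident (large-logit) predictions are discouraged.
    penalty = lambda_sd * sum(z * z for z in logits)
    return cross_entropy + penalty


# Example (assumed values): for logits [2.0, 0.0], label 0 and lambda_sd=0.1,
# the penalty term alone is 0.1 * (2.0**2 + 0.0**2) = 0.4.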