| """ |
| Spectral Decoupling for Multi-Task Learning |
| ============================================= |
| Implements L2 logit regularization to prevent gradient starvation in |
| multi-task learning with shared representations. |
| |
| The core idea: add a penalty term lambda * ||logits_i||^2 to each task's |
| loss. This prevents the network from becoming overly confident on any |
| subset of tasks, keeping the gradient signal alive for all tasks. |
| |
| Why this addresses our failure mode: |
| In our MTAN-Lite model, bytes 0 and 1 develop stronger feature |
| responses early in training. Without regularization, the softmax |
| outputs for these bytes become highly peaked (confident), which |
| causes their gradients to dominate the shared backbone updates. |
| The other 14 bytes receive vanishing gradients and never learn. |
| |
| Spectral Decoupling prevents this by penalizing large logit |
| magnitudes. This keeps the softmax outputs "soft" (less confident), |
| which maintains gradient flow to all tasks and prevents any single |
| task from capturing the shared representation. |
| |
| Implementation: |
| We modify the model architecture to output pre-softmax logits, |
| apply a custom loss that combines cross-entropy with L2 logit |
| regularization, and use a separate softmax layer for inference. |
| |
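| Alternatively (simpler): we add activity_regularizer=L2(lambda) |
| to the final Dense layer of each byte head. Keras applies an activity |
| regularizer to the layer's output, so this is equivalent to penalizing |
| the logits only when that Dense layer has no activation of its own |
| (i.e. softmax is applied in a separate layer). It is natively |
| supported by Keras. |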
| |
| Reference: |
| Pezeshki, M., Kaba, S.-O., Bengio, Y., Courville, A., Precup, D., |
| & Lajoie, G. (2021). Gradient Starvation: A Learning Proclivity in |
| Neural Networks. NeurIPS 2021. |
| """ |
|
|
| import logging |
| from typing import Optional |
|
|
| import tensorflow as tf |
|
|
| logger = logging.getLogger(__name__) |
|
|
|
|
def get_spectral_decoupling_regularizer(
    lambda_sd: float = 0.01,
) -> Optional[tf.keras.regularizers.Regularizer]:
    """
    Build the L2 penalty object used for spectral decoupling, if enabled.

    The result is intended to be attached as the activity regularizer of
    the final Dense(256) layer of each byte head, so that large
    pre-softmax logit values are penalized and no single task becomes
    overly confident.

    Penalty form: lambda_sd * sum(logits^2)

    Args:
        lambda_sd: Strength of the penalty; larger means stronger
            decoupling. Suggested range is 0.001 to 0.1:
            - 0.001: very mild (may not prevent starvation)
            - 0.01: moderate (recommended starting point)
            - 0.1: strong (may slow convergence)

    Returns:
        A ``tf.keras.regularizers.L2`` instance built from ``lambda_sd``,
        or ``None`` when ``lambda_sd <= 0`` (spectral decoupling off).
    """
    # The test is deliberately phrased as "<= 0" (not "> 0") so that any
    # value failing the comparison still reaches the L2 constructor.
    disabled = lambda_sd <= 0

    if not disabled:
        # Construct first, log second: nothing is logged if construction fails.
        sd_penalty = tf.keras.regularizers.L2(lambda_sd)
        logger.info(
            "Spectral Decoupling regularizer created: lambda=%.4f "
            "(L2 penalty on pre-softmax logits to prevent gradient starvation)",
            lambda_sd,
        )
        return sd_penalty

    logger.info("Spectral Decoupling DISABLED (lambda_sd=%.4f)", lambda_sd)
    return None
|
|