| """ |
| Gradient Norm Logger for W&B |
| ============================ |
| A Keras callback that logs per-layer gradient norms to Weights & Biases |
| at the end of each epoch. This provides direct visibility into gradient |
| flow across shared convolutions, per-task BatchNorm layers, per-task |
| classification heads, and sigmoid gates. |
| |
| Logged metrics (all under the "gradients/" prefix in W&B): |
| - gradients/total_norm : Global L2 norm of all gradients |
| - gradients/layer/{name} : Per-variable L2 gradient norm, keyed by the flattened layer-qualified variable name |
| - gradients/category/{category}_norm : Aggregate L2 norm per category (shared_conv, shared_bn, task_bn, task_gate, task_head, other) |
| - gradients/task/byte_{i}_norm : Per-task (byte index i) aggregate gradient norm |
| - gradients/max_layer_norm : Maximum per-variable gradient norm |
| - gradients/min_layer_norm : Minimum per-variable gradient norm |
| - gradients/norm_ratio : max/min ratio (gradient imbalance indicator) |
| |
| Usage: |
| The callback is automatically added when ``wandb_project`` is set in |
| the MTLTrainer. It computes gradients via ``tf.GradientTape`` on the |
| first ``batch_size`` samples of the validation set, keeping the overhead minimal. |
| |
| Note on Keras 3 / TF 2.16+: |
| In Keras 3, ``variable.name`` returns only the weight name (e.g. |
| ``kernel``, ``gamma``) without a layer prefix. To obtain the full |
| layer-qualified name we iterate over ``model.layers`` and build an |
| ``id(variable) → layer.name`` mapping at the start of training. |
| |
| References: |
| - Tang et al., "Improving Training Stability for Multitask Ranking |
| Models", KDD 2023 — gradient norm monitoring for instability detection. |
| - Chen et al., "GradNorm: Gradient Normalization for Adaptive Loss |
| Balancing in Multi-Task Networks", ICML 2018. |
| """ |
|
|
| import logging |
| import re |
| from typing import Any, Dict, List, Optional, Tuple |
|
|
| import tensorflow as tf |
| from tensorflow import keras |
|
|
# Module-level logger, named after this module so applications can tune its
# verbosity through the standard ``logging`` hierarchy.
logger: logging.Logger = logging.getLogger(name=__name__)
|
|
|
|
class GradientNormLogger(keras.callbacks.Callback):
    """
    Logs per-layer, per-category and per-task gradient norms to W&B.

    At the end of every ``log_every_n_epochs``-th epoch the callback runs a
    single forward/backward pass on the first ``batch_size`` validation
    samples and sends the resulting gradient norms to W&B with
    ``wandb.log(..., commit=False)``.

    The forward pass uses ``training=True`` so gradients flow through layers
    such as BatchNorm the same way they do inside ``fit``. That mode can also
    update non-trainable state (e.g. BatchNorm moving statistics), so every
    non-trainable variable is snapshotted before the pass and restored
    afterwards. As a result, the callback never alters the model it monitors.

    Args:
        val_data: Validation data as an ``(x, y)`` tuple. ``x`` and ``y``
            may each be a single array, a dict of arrays (e.g. per-byte
            inputs for LMIC models), or a list/tuple of arrays (Keras'
            multi-input / multi-output convention).
        log_every_n_epochs: Compute and log gradients every N epochs.
            Must be >= 1. Default is 1 (every epoch).
        batch_size: Number of leading validation samples used for the
            gradient computation. Must be >= 1. Default is 128. Smaller
            values reduce memory; larger values give more stable estimates.

    Raises:
        ValueError: If ``log_every_n_epochs`` or ``batch_size`` is < 1.
    """

    def __init__(
        self,
        val_data: Tuple,
        log_every_n_epochs: int = 1,
        batch_size: int = 128,
    ) -> None:
        super().__init__()
        # A zero period would raise ZeroDivisionError at epoch end, outside
        # the best-effort try block, and abort training.
        if log_every_n_epochs < 1:
            raise ValueError(
                f"log_every_n_epochs must be >= 1, got {log_every_n_epochs}"
            )
        if batch_size < 1:
            raise ValueError(f"batch_size must be >= 1, got {batch_size}")
        self.val_data = val_data
        self.log_every_n_epochs = log_every_n_epochs
        self.batch_size = batch_size
        # The imported ``wandb`` module; stays None (logging disabled) until
        # on_train_begin imports it successfully.
        self._wandb: Optional[Any] = None
        # id(variable) -> "layer_name/var_name"; rebuilt in on_train_begin.
        self._var_id_to_qualified_name: Dict[int, str] = {}

    def on_train_begin(self, logs: Optional[Dict] = None) -> None:
        """Import wandb and build the variable-to-layer-name mapping.

        If wandb cannot be imported, a warning is logged and ``on_epoch_end``
        becomes a no-op.
        """
        try:
            import wandb
        except ImportError:
            logger.warning(
                "wandb not installed; GradientNormLogger will be disabled."
            )
        else:
            self._wandb = wandb

        # Keras 3 variable names lack the layer prefix, so record the owning
        # top-level layer of each trainable variable, keyed by identity.
        self._var_id_to_qualified_name = {}
        if self.model is None:
            return

        layers_with_vars = 0
        for layer in self.model.layers:
            layer_vars = layer.trainable_variables
            if layer_vars:
                layers_with_vars += 1
            for var in layer_vars:
                self._var_id_to_qualified_name[id(var)] = (
                    f"{layer.name}/{var.name}"
                )

        logger.info(
            "GradientNormLogger: mapped %d trainable variables across %d layers",
            len(self._var_id_to_qualified_name),
            layers_with_vars,
        )

    @staticmethod
    def _take_first(data: Any, n: int) -> Any:
        """Return the first ``n`` samples of ``data``.

        Dicts are sliced per value. Lists/tuples whose items are all arrays
        with at least one dimension are sliced per item, following Keras'
        multi-input/multi-output convention. Anything else is sliced
        directly with ``[:n]``.
        """
        if isinstance(data, dict):
            return {k: v[:n] for k, v in data.items()}
        if (
            isinstance(data, (list, tuple))
            and data
            and all(len(getattr(item, "shape", ())) > 0 for item in data)
        ):
            sliced = [item[:n] for item in data]
            return tuple(sliced) if isinstance(data, tuple) else sliced
        return data[:n]

    def _get_val_batch(self) -> Tuple:
        """Extract the first ``batch_size`` validation samples as ``(x, y)``."""
        x_val, y_val = self.val_data
        n = self.batch_size
        return self._take_first(x_val, n), self._take_first(y_val, n)

    def _get_qualified_name(self, var: tf.Variable) -> str:
        """
        Get the fully-qualified name for a variable.

        Tries the pre-built ``id(var) → name`` mapping first (Keras 3
        compatible). Falls back to ``var.name`` for older Keras versions
        where variable names already include the layer prefix.
        """
        return self._var_id_to_qualified_name.get(id(var), var.name)

    @staticmethod
    def _classify_layer(name: str) -> Tuple[str, Optional[int]]:
        """
        Classify a variable by its layer role and associated task (byte index).

        Supports both HPS/MTAN-Lite naming (``conv1d``, ``byte_0_bn``,
        ``byte_0_gate``) and LMIC-TSBN naming (``shared_conv0``,
        ``tsbn_L0_byte0``, ``gate_L0_byte0``).

        Args:
            name: Fully-qualified variable name, e.g.
                ``shared_conv0/kernel`` or ``tsbn_L1_byte5/gamma``.

        Returns:
            (category, byte_index) where category is one of:
            "shared_conv", "task_bn", "task_gate", "task_head", "shared_bn",
            "other"; byte_index is None for shared and "other" variables.
        """
        # Convolutions not tied to a specific byte are shared across tasks.
        if ("conv1d" in name or "shared_conv" in name) and "byte" not in name:
            return "shared_conv", None

        # Per-task BatchNorm: "byte_{i}_bn" or "tsbn_L{l}_byte{i}".
        bn_match = re.search(r"byte_(\d+)_bn|tsbn_L\d+_byte(\d+)", name)
        if bn_match:
            return "task_bn", int(bn_match.group(1) or bn_match.group(2))

        # Per-task sigmoid gates: "byte_{i}_gate" or "gate_L{l}_byte{i}".
        gate_match = re.search(r"byte_(\d+)_gate|gate_L\d+_byte(\d+)", name)
        if gate_match:
            return "task_gate", int(gate_match.group(1) or gate_match.group(2))

        # Any other "byte_{i}" variable is attributed to that task's head.
        head_match = re.search(r"byte_(\d+)", name)
        if head_match:
            return "task_head", int(head_match.group(1))

        # Remaining BatchNorm variables are shared. The "bn" substring test
        # is loose and may also catch unrelated names containing "bn".
        if "batch_norm" in name or "bn" in name:
            return "shared_bn", None

        return "other", None

    @staticmethod
    def _l2_norm(norms: List[float]) -> float:
        """Combine per-variable L2 norms into one aggregate L2 norm."""
        return sum(n ** 2 for n in norms) ** 0.5

    def _compute_gradient_norms(
        self, x_batch: Any, y_batch: Any
    ) -> List[Tuple[str, float]]:
        """Run one training-mode pass and return (qualified name, grad L2 norm).

        Variables that receive no gradient are omitted. Non-trainable state
        is restored afterwards, even if the pass fails.
        """
        model = self.model
        # training=True updates e.g. BatchNorm moving statistics as a side
        # effect; snapshot them so validation data never leaks into the model.
        snapshot = [(var, var.numpy()) for var in model.non_trainable_variables]
        try:
            with tf.GradientTape() as tape:
                predictions = model(x_batch, training=True)
                loss = model.compiled_loss(y_batch, predictions)
            trainable_vars = model.trainable_variables
            gradients = tape.gradient(loss, trainable_vars)
        finally:
            for var, value in snapshot:
                var.assign(value)

        return [
            (self._get_qualified_name(var), float(tf.norm(grad).numpy()))
            for var, grad in zip(trainable_vars, gradients)
            if grad is not None
        ]

    def _build_metrics(
        self, named_norms: List[Tuple[str, float]]
    ) -> Tuple[Dict[str, float], Dict[str, float]]:
        """Aggregate per-variable gradient norms into W&B metrics.

        Args:
            named_norms: Non-empty list of (qualified variable name, norm).

        Returns:
            (metrics, category_totals): ``metrics`` maps W&B keys to values;
            ``category_totals`` maps each observed category to its aggregate
            L2 norm.
        """
        metrics: Dict[str, float] = {}
        per_task_norms: Dict[int, List[float]] = {}
        category_norms: Dict[str, List[float]] = {}

        for qualified_name, grad_norm in named_norms:
            # W&B treats "/" as a panel separator, so flatten the name.
            flat = qualified_name.replace(":", "_").replace("/", "_")
            base_key = f"gradients/layer/{flat}"
            key = base_key
            suffix = 1
            # Distinct variables can flatten to the same name (e.g. nested
            # sub-layers under Keras 3); suffix instead of overwriting.
            while key in metrics:
                key = f"{base_key}_{suffix}"
                suffix += 1
            metrics[key] = grad_norm

            category, byte_idx = self._classify_layer(qualified_name)
            category_norms.setdefault(category, []).append(grad_norm)
            # Grouped on demand: any number of tasks is supported.
            if byte_idx is not None:
                per_task_norms.setdefault(byte_idx, []).append(grad_norm)

        for byte_idx in sorted(per_task_norms):
            metrics[f"gradients/task/byte_{byte_idx}_norm"] = self._l2_norm(
                per_task_norms[byte_idx]
            )

        category_totals = {
            category: self._l2_norm(norms)
            for category, norms in category_norms.items()
        }
        for category, total in category_totals.items():
            metrics[f"gradients/category/{category}_norm"] = total

        all_norms = [norm for _, norm in named_norms]
        max_norm = max(all_norms)
        min_norm = min(all_norms)
        metrics["gradients/total_norm"] = self._l2_norm(all_norms)
        metrics["gradients/max_layer_norm"] = max_norm
        metrics["gradients/min_layer_norm"] = min_norm
        # Floor the denominator so an exactly-zero gradient norm cannot
        # raise ZeroDivisionError.
        metrics["gradients/norm_ratio"] = max_norm / (
            min_norm if min_norm > 0 else 1e-10
        )
        return metrics, category_totals

    def on_epoch_end(self, epoch: int, logs: Optional[Dict] = None) -> None:
        """Compute and log gradient norms every ``log_every_n_epochs`` epochs.

        Failures are logged and swallowed: monitoring must never interrupt
        training.
        """
        if self._wandb is None:
            return
        if (epoch + 1) % self.log_every_n_epochs != 0:
            return

        try:
            x_batch, y_batch = self._get_val_batch()
            named_norms = self._compute_gradient_norms(x_batch, y_batch)
            if not named_norms:
                return

            metrics, category_totals = self._build_metrics(named_norms)
            # commit=False attaches these values to the step committed by
            # the next non-deferred wandb.log call.
            self._wandb.log(metrics, commit=False)

            if (epoch + 1) % 10 == 0:
                logger.info(
                    "Gradient norms at epoch %d: total=%.4f, max=%.4f, "
                    "min=%.6f, ratio=%.1f | shared_conv=%.4f, "
                    "task_bn=%.4f, task_gate=%.4f",
                    epoch + 1,
                    metrics["gradients/total_norm"],
                    metrics["gradients/max_layer_norm"],
                    metrics["gradients/min_layer_norm"],
                    metrics["gradients/norm_ratio"],
                    category_totals.get("shared_conv", 0.0),
                    category_totals.get("task_bn", 0.0),
                    category_totals.get("task_gate", 0.0),
                )
        except Exception as e:
            logger.warning(
                "GradientNormLogger error at epoch %d: %s", epoch + 1, e
            )
            # Full traceback only at debug level to avoid per-epoch noise.
            logger.debug("GradientNormLogger traceback", exc_info=True)
|
|