""" Gradient Norm Logger for W&B ============================ A Keras callback that logs per-layer gradient norms to Weights & Biases at the end of each epoch. This provides direct visibility into gradient flow across shared convolutions, per-task BatchNorm layers, per-task classification heads, and sigmoid gates. Logged metrics (all under the "gradients/" prefix in W&B): - gradients/total_norm : Global L2 norm of all gradients - gradients/shared_conv_norm : L2 norm of shared conv layer gradients - gradients/layer/{name}_norm : Per-layer L2 gradient norm - gradients/task/{byte_i}_norm : Per-task aggregate gradient norm - gradients/max_layer_norm : Maximum gradient norm across all layers - gradients/min_layer_norm : Minimum gradient norm across all layers - gradients/norm_ratio : max/min ratio (gradient imbalance indicator) Usage: The callback is automatically added when ``wandb_project`` is set in the MTLTrainer. It uses a single batch from the validation set to compute gradients via ``tf.GradientTape``, keeping the overhead minimal. Note on Keras 3 / TF 2.16+: In Keras 3, ``variable.name`` returns only the weight name (e.g. ``kernel``, ``gamma``) without a layer prefix. To obtain the full layer-qualified name we iterate over ``model.layers`` and build an ``id(variable) → layer.name`` mapping at the start of training. References: - Tang et al., "Improving Training Stability for Multitask Ranking Models", KDD 2023 — gradient norm monitoring for instability detection. - Chen et al., "GradNorm: Gradient Normalization for Adaptive Loss Balancing in Multi-Task Networks", ICML 2018. """ import logging import re from typing import Any, Dict, List, Optional, Tuple import tensorflow as tf from tensorflow import keras logger = logging.getLogger(__name__) class GradientNormLogger(keras.callbacks.Callback): """ Logs per-layer and per-task gradient norms to W&B at each epoch end. The callback computes gradients using a single validation batch to minimize training overhead (~2-5% additional time per epoch). Args: val_data: Validation data as (x, y) tuple. For LMIC models, x is a dict of per-byte inputs; for HPS/MTAN-Lite, x is a single array. log_every_n_epochs: How often to compute and log gradients. Default is 1 (every epoch). Set higher to reduce overhead. batch_size: Number of samples to use for gradient computation. Default is 128. Smaller values reduce memory; larger values give more stable gradient estimates. """ def __init__( self, val_data: Tuple, log_every_n_epochs: int = 1, batch_size: int = 128, ) -> None: super().__init__() self.val_data = val_data self.log_every_n_epochs = log_every_n_epochs self.batch_size = batch_size self._wandb = None # Built lazily in on_train_begin: maps id(variable) → qualified name self._var_id_to_qualified_name: Dict[int, str] = {} def on_train_begin(self, logs: Optional[Dict] = None) -> None: """Import wandb and build the variable-to-layer-name mapping.""" try: import wandb self._wandb = wandb except ImportError: logger.warning( "wandb not installed; GradientNormLogger will be disabled." ) # Build id(var) → "layer_name/var_name" mapping from model layers. # This is necessary because Keras 3 strips layer prefixes from # variable.name, returning only "kernel", "gamma", etc. self._var_id_to_qualified_name = {} if self.model is not None: for layer in self.model.layers: for var in layer.trainable_variables: qualified = f"{layer.name}/{var.name}" self._var_id_to_qualified_name[id(var)] = qualified logger.info( "GradientNormLogger: mapped %d trainable variables across %d layers", len(self._var_id_to_qualified_name), len([l for l in self.model.layers if l.trainable_variables]), ) def _get_val_batch(self) -> Tuple: """Extract a single batch from validation data for gradient computation.""" x_val, y_val = self.val_data n = self.batch_size if isinstance(x_val, dict): # LMIC multi-input: dict of per-byte arrays x_batch = {k: v[:n] for k, v in x_val.items()} else: # HPS/MTAN-Lite: single array x_batch = x_val[:n] if isinstance(y_val, dict): y_batch = {k: v[:n] for k, v in y_val.items()} else: y_batch = y_val[:n] return x_batch, y_batch def _get_qualified_name(self, var: tf.Variable) -> str: """ Get the fully-qualified name for a variable. Tries the pre-built ``id(var) → name`` mapping first (Keras 3 compatible). Falls back to ``var.name`` for older Keras versions where variable names already include the layer prefix. """ return self._var_id_to_qualified_name.get(id(var), var.name) @staticmethod def _classify_layer(name: str) -> Tuple[str, Optional[int]]: """ Classify a layer by its role and associated task (byte index). Supports both HPS/MTAN-Lite naming (``conv1d``, ``byte_0_bn``, ``byte_0_gate``) and LMIC-TSBN naming (``shared_conv0``, ``tsbn_L0_byte0``, ``gate_L0_byte0``). Args: name: Fully-qualified variable name, e.g. ``shared_conv0/kernel`` or ``tsbn_L1_byte5/gamma``. Returns: (category, byte_index) where category is one of: "shared_conv", "task_bn", "task_gate", "task_head", "other" and byte_index is None for shared layers. """ # Shared conv layers: # HPS/MTAN-Lite: conv1d, conv1d_1, conv1d_2 # LMIC-TSBN: shared_conv0, shared_conv1, shared_conv2 if ("conv1d" in name or "shared_conv" in name) and "byte" not in name: return "shared_conv", None # Task-specific BN: # HPS/MTAN-Lite: byte_0_bn_0, byte_3_bn_2 # LMIC-TSBN: tsbn_L0_byte0, tsbn_L2_byte15 bn_match = re.search(r"byte_(\d+)_bn|tsbn_L\d+_byte(\d+)", name) if bn_match: byte_idx = bn_match.group(1) or bn_match.group(2) return "task_bn", int(byte_idx) # Sigmoid gates: # HPS/MTAN-Lite: byte_0_gate_0, byte_5_gate_1 # LMIC-TSBN: gate_L0_byte0, gate_L1_byte5 gate_match = re.search(r"byte_(\d+)_gate|gate_L\d+_byte(\d+)", name) if gate_match: byte_idx = gate_match.group(1) or gate_match.group(2) return "task_gate", int(gate_match.group(1) or gate_match.group(2)) # Task heads: byte_0_dense, byte_0_output, byte_0/kernel, etc. head_match = re.search(r"byte_(\d+)", name) if head_match: return "task_head", int(head_match.group(1)) # Shared BN (non-task-specific): batch_normalization, etc. if "batch_norm" in name or "bn" in name: return "shared_bn", None return "other", None def on_epoch_end(self, epoch: int, logs: Optional[Dict] = None) -> None: """Compute and log gradient norms at the end of each epoch.""" if self._wandb is None: return if (epoch + 1) % self.log_every_n_epochs != 0: return try: x_batch, y_batch = self._get_val_batch() # Compute gradients via GradientTape with tf.GradientTape() as tape: predictions = self.model(x_batch, training=True) loss = self.model.compiled_loss(y_batch, predictions) trainable_vars = self.model.trainable_variables gradients = tape.gradient(loss, trainable_vars) # Compute per-layer norms and classify metrics: Dict[str, float] = {} per_layer_norms: Dict[str, float] = {} per_task_norms: Dict[int, List[float]] = {i: [] for i in range(16)} category_norms: Dict[str, List[float]] = { "shared_conv": [], "shared_bn": [], "task_bn": [], "task_gate": [], "task_head": [], "other": [], } all_grad_norms: List[float] = [] for var, grad in zip(trainable_vars, gradients): if grad is None: continue grad_norm = float(tf.norm(grad).numpy()) # Use the qualified name (layer_name/var_name) qualified_name = self._get_qualified_name(var) # Store per-layer norm # Clean the name for W&B (replace :, / with _) clean_name = qualified_name.replace(":", "_").replace("/", "_") per_layer_norms[clean_name] = grad_norm all_grad_norms.append(grad_norm) # Classify using the qualified name (which includes layer name) category, byte_idx = self._classify_layer(qualified_name) category_norms[category].append(grad_norm) if byte_idx is not None: per_task_norms[byte_idx].append(grad_norm) if not all_grad_norms: return # --- Log per-layer norms --- for name, norm in per_layer_norms.items(): metrics[f"gradients/layer/{name}"] = norm # --- Log per-task aggregate norms --- for byte_idx, norms in per_task_norms.items(): if norms: task_total = sum(n ** 2 for n in norms) ** 0.5 metrics[f"gradients/task/byte_{byte_idx}_norm"] = task_total # --- Log per-category aggregate norms --- for category, norms in category_norms.items(): if norms: cat_total = sum(n ** 2 for n in norms) ** 0.5 metrics[f"gradients/category/{category}_norm"] = cat_total # --- Log summary statistics --- total_norm = sum(n ** 2 for n in all_grad_norms) ** 0.5 max_norm = max(all_grad_norms) min_norm = min(all_grad_norms) if min(all_grad_norms) > 0 else 1e-10 norm_ratio = max_norm / min_norm metrics["gradients/total_norm"] = total_norm metrics["gradients/max_layer_norm"] = max_norm metrics["gradients/min_layer_norm"] = min(all_grad_norms) metrics["gradients/norm_ratio"] = norm_ratio # Log all metrics to W&B in a single call self._wandb.log(metrics, commit=False) # Log summary to console every 10 epochs if (epoch + 1) % 10 == 0: shared_conv_norm = ( sum(n ** 2 for n in category_norms["shared_conv"]) ** 0.5 if category_norms["shared_conv"] else 0.0 ) task_bn_norm = ( sum(n ** 2 for n in category_norms["task_bn"]) ** 0.5 if category_norms["task_bn"] else 0.0 ) task_gate_norm = ( sum(n ** 2 for n in category_norms["task_gate"]) ** 0.5 if category_norms["task_gate"] else 0.0 ) logger.info( "Gradient norms at epoch %d: total=%.4f, max=%.4f, " "min=%.6f, ratio=%.1f | shared_conv=%.4f, " "task_bn=%.4f, task_gate=%.4f", epoch + 1, total_norm, max_norm, min(all_grad_norms), norm_ratio, shared_conv_norm, task_bn_norm, task_gate_norm, ) except Exception as e: # Never crash training due to logging failure logger.warning("GradientNormLogger error at epoch %d: %s", epoch, e)