ascad-training-pipeline / src /gradient_logger.py
lemousehunter
fix(gradient_logger): use layer.name→var mapping for Keras 3 compatibility
94e09b3
"""
Gradient Norm Logger for W&B
============================
A Keras callback that logs per-layer gradient norms to Weights & Biases
at the end of each epoch. This provides direct visibility into gradient
flow across shared convolutions, per-task BatchNorm layers, per-task
classification heads, and sigmoid gates.
Logged metrics (all under the "gradients/" prefix in W&B):
- gradients/total_norm : Global L2 norm of all gradients
- gradients/category/{category}_norm : Aggregate L2 norm per layer category (e.g. gradients/category/shared_conv_norm for the shared conv layers)
- gradients/layer/{layer}_{variable} : L2 gradient norm of each trainable variable (":" and "/" in the qualified name are replaced by "_")
- gradients/task/byte_{i}_norm : Per-task aggregate gradient norm for byte index i
- gradients/max_layer_norm : Maximum gradient norm across all trainable variables
- gradients/min_layer_norm : Minimum gradient norm across all trainable variables
- gradients/norm_ratio : max/min ratio (gradient imbalance indicator)
Usage:
The callback is automatically added when ``wandb_project`` is set in
the MTLTrainer. It uses a single batch from the validation set to
compute gradients via ``tf.GradientTape``, keeping the overhead minimal.
Note on Keras 3 / TF 2.16+:
In Keras 3, ``variable.name`` returns only the weight name (e.g.
``kernel``, ``gamma``) without a layer prefix. To obtain the full
layer-qualified name we iterate over ``model.layers`` and build an
``id(variable) → "layer.name/variable.name"`` mapping at the start of training.
References:
- Tang et al., "Improving Training Stability for Multitask Ranking
Models", KDD 2023 — gradient norm monitoring for instability detection.
- Chen et al., "GradNorm: Gradient Normalization for Adaptive Loss
Balancing in Multi-Task Networks", ICML 2018.
"""
import logging
import re
from typing import Any, Dict, List, Optional, Tuple
import tensorflow as tf
from tensorflow import keras
# Module-level logger, named after this module via ``__name__``.
logger = logging.getLogger(__name__)
class GradientNormLogger(keras.callbacks.Callback):
    """
    Log per-variable, per-category and per-task gradient norms to W&B.

    At the end of selected epochs the callback takes the first
    ``batch_size`` samples of the validation data, runs one forward and
    backward pass under ``tf.GradientTape``, and logs the L2 norms of the
    resulting gradients. The original docstring estimates ~2-5% additional
    time per epoch; that figure is not measured here.

    Args:
        val_data: Validation data as an ``(x, y)`` tuple. ``x`` and ``y``
            may each be either a dict of arrays (e.g. LMIC per-byte
            inputs) or a single sliceable array (e.g. HPS/MTAN-Lite).
        log_every_n_epochs: How often to compute and log gradients.
            Default is 1 (every epoch). Must be >= 1.
        batch_size: Number of leading validation samples used for the
            gradient computation. Default is 128. Must be >= 1.

    Raises:
        ValueError: If ``log_every_n_epochs`` or ``batch_size`` is < 1.
    """

    # Every category ``_classify_layer`` can return. Used to pre-seed the
    # per-category accumulators.
    _CATEGORIES = (
        "shared_conv",
        "shared_bn",
        "task_bn",
        "task_gate",
        "task_head",
        "other",
    )

    def __init__(
        self,
        val_data: Tuple,
        log_every_n_epochs: int = 1,
        batch_size: int = 128,
    ) -> None:
        super().__init__()
        # Fail fast on invalid settings: a zero interval would otherwise
        # raise ZeroDivisionError in on_epoch_end, outside its try block,
        # and abort training after the first epoch.
        if log_every_n_epochs < 1:
            raise ValueError(
                f"log_every_n_epochs must be >= 1, got {log_every_n_epochs}"
            )
        if batch_size < 1:
            raise ValueError(f"batch_size must be >= 1, got {batch_size}")

        self.val_data = val_data
        self.log_every_n_epochs = log_every_n_epochs
        self.batch_size = batch_size
        # Set to the imported wandb module in on_train_begin; None keeps
        # the callback disabled.
        self._wandb = None
        # Built lazily in on_train_begin: maps id(variable) → qualified name
        self._var_id_to_qualified_name: Dict[int, str] = {}

    def on_train_begin(self, logs: Optional[Dict] = None) -> None:
        """Import wandb and build the variable-to-layer-name mapping."""
        try:
            import wandb
        except ImportError:
            logger.warning(
                "wandb not installed; GradientNormLogger will be disabled."
            )
        else:
            self._wandb = wandb

        # Build id(var) → "layer_name/var_name" mapping from model layers.
        # This is necessary because Keras 3 strips layer prefixes from
        # variable.name, returning only "kernel", "gamma", etc.
        self._var_id_to_qualified_name = {}
        if self.model is None:
            return

        layers_with_vars = 0
        for layer in self.model.layers:
            layer_vars = layer.trainable_variables
            if layer_vars:
                layers_with_vars += 1
            for var in layer_vars:
                self._var_id_to_qualified_name[id(var)] = (
                    f"{layer.name}/{var.name}"
                )

        logger.info(
            "GradientNormLogger: mapped %d trainable variables across %d layers",
            len(self._var_id_to_qualified_name),
            layers_with_vars,
        )

    def _get_val_batch(self) -> Tuple:
        """
        Return the first ``batch_size`` validation samples as ``(x, y)``.

        Dict-valued inputs/targets are sliced per key; anything else is
        sliced directly with ``[:batch_size]``.
        """
        x_val, y_val = self.val_data
        n = self.batch_size

        def _head(data: Any) -> Any:
            # Dicts (multi-input / multi-output) are sliced value by value.
            if isinstance(data, dict):
                return {key: value[:n] for key, value in data.items()}
            return data[:n]

        return _head(x_val), _head(y_val)

    def _get_qualified_name(self, var: tf.Variable) -> str:
        """
        Get the fully-qualified name for a variable.

        Looks the variable up in the ``id(var) → name`` mapping built in
        ``on_train_begin`` and falls back to ``var.name`` when the
        variable is not in that mapping.
        """
        return self._var_id_to_qualified_name.get(id(var), var.name)

    @staticmethod
    def _classify_layer(name: str) -> Tuple[str, Optional[int]]:
        """
        Classify a variable by its role and associated task (byte index).

        Classification is purely name-based and supports both
        HPS/MTAN-Lite naming (``conv1d``, ``byte_0_bn``, ``byte_0_gate``)
        and LMIC-TSBN naming (``shared_conv0``, ``tsbn_L0_byte0``,
        ``gate_L0_byte0``). Checks run in order; the first match wins.

        Args:
            name: Fully-qualified variable name, e.g.
                ``shared_conv0/kernel`` or ``tsbn_L1_byte5/gamma``.

        Returns:
            ``(category, byte_index)`` where category is one of
            "shared_conv", "task_bn", "task_gate", "task_head",
            "shared_bn", "other", and byte_index is None for the
            non-task categories ("shared_conv", "shared_bn", "other").
        """
        # Shared conv layers:
        #   HPS/MTAN-Lite: conv1d, conv1d_1, conv1d_2
        #   LMIC-TSBN:     shared_conv0, shared_conv1, shared_conv2
        if ("conv1d" in name or "shared_conv" in name) and "byte" not in name:
            return "shared_conv", None

        # Task-specific BN:
        #   HPS/MTAN-Lite: byte_0_bn_0, byte_3_bn_2
        #   LMIC-TSBN:     tsbn_L0_byte0, tsbn_L2_byte15
        bn_match = re.search(r"byte_(\d+)_bn|tsbn_L\d+_byte(\d+)", name)
        if bn_match:
            # Exactly one alternative matched, so one group is None.
            return "task_bn", int(bn_match.group(1) or bn_match.group(2))

        # Sigmoid gates:
        #   HPS/MTAN-Lite: byte_0_gate_0, byte_5_gate_1
        #   LMIC-TSBN:     gate_L0_byte0, gate_L1_byte5
        gate_match = re.search(r"byte_(\d+)_gate|gate_L\d+_byte(\d+)", name)
        if gate_match:
            return "task_gate", int(gate_match.group(1) or gate_match.group(2))

        # Task heads: byte_0_dense, byte_0_output, byte_0/kernel, etc.
        head_match = re.search(r"byte_(\d+)", name)
        if head_match:
            return "task_head", int(head_match.group(1))

        # Shared BN (non-task-specific): batch_normalization, etc.
        if "batch_norm" in name or "bn" in name:
            return "shared_bn", None

        return "other", None

    @staticmethod
    def _l2_norm(norms: List[float]) -> float:
        """Combine individual L2 norms into one L2 norm (0.0 if empty)."""
        return sum(n ** 2 for n in norms) ** 0.5

    def _compute_gradients(self) -> Tuple[List, List]:
        """Run one forward/backward pass on the validation batch.

        Returns:
            ``(trainable_vars, gradients)`` in matching order; a gradient
            entry may be None when the loss does not depend on a variable.
        """
        x_batch, y_batch = self._get_val_batch()
        with tf.GradientTape() as tape:
            # NOTE(review): training=True runs layers in training mode on
            # validation data. BatchNorm moving statistics are presumably
            # updated (and dropout applied) by this pass — confirm this
            # side effect is intended.
            predictions = self.model(x_batch, training=True)
            # NOTE(review): Keras 3 appears to deprecate compiled_loss in
            # favour of compute_loss — confirm before migrating, since the
            # two may treat regularization losses differently.
            loss = self.model.compiled_loss(y_batch, predictions)
        trainable_vars = self.model.trainable_variables
        gradients = tape.gradient(loss, trainable_vars)
        return trainable_vars, gradients

    def on_epoch_end(self, epoch: int, logs: Optional[Dict] = None) -> None:
        """Compute and log gradient norms at the end of each epoch.

        Does nothing when wandb is unavailable or when ``epoch + 1`` is
        not a multiple of ``log_every_n_epochs``. Any exception raised
        while computing or logging is caught and reported as a warning so
        that a logging failure cannot abort training.
        """
        if self._wandb is None:
            return
        if (epoch + 1) % self.log_every_n_epochs != 0:
            return

        try:
            trainable_vars, gradients = self._compute_gradients()

            # Compute per-variable norms and classify them.
            metrics: Dict[str, float] = {}
            per_layer_norms: Dict[str, float] = {}
            # Keyed by whatever byte indices appear in the model, so the
            # number of tasks is not hard-coded.
            per_task_norms: Dict[int, List[float]] = {}
            category_norms: Dict[str, List[float]] = {
                category: [] for category in self._CATEGORIES
            }
            all_grad_norms: List[float] = []

            for var, grad in zip(trainable_vars, gradients):
                if grad is None:
                    continue
                grad_norm = float(tf.norm(grad).numpy())

                # Use the qualified name (layer_name/var_name)
                qualified_name = self._get_qualified_name(var)

                # Clean the name for W&B (replace :, / with _)
                clean_name = qualified_name.replace(":", "_").replace("/", "_")
                per_layer_norms[clean_name] = grad_norm
                all_grad_norms.append(grad_norm)

                # Classify using the qualified name (which includes layer name)
                category, byte_idx = self._classify_layer(qualified_name)
                category_norms.setdefault(category, []).append(grad_norm)
                if byte_idx is not None:
                    per_task_norms.setdefault(byte_idx, []).append(grad_norm)

            if not all_grad_norms:
                return

            # --- Log per-layer norms ---
            for name, norm in per_layer_norms.items():
                metrics[f"gradients/layer/{name}"] = norm

            # --- Log per-task aggregate norms ---
            for byte_idx, norms in per_task_norms.items():
                if norms:
                    metrics[f"gradients/task/byte_{byte_idx}_norm"] = (
                        self._l2_norm(norms)
                    )

            # --- Log per-category aggregate norms ---
            for category, norms in category_norms.items():
                if norms:
                    metrics[f"gradients/category/{category}_norm"] = (
                        self._l2_norm(norms)
                    )

            # --- Log summary statistics ---
            total_norm = self._l2_norm(all_grad_norms)
            max_norm = max(all_grad_norms)
            min_norm = min(all_grad_norms)
            # Guard the ratio against division by zero; the raw minimum is
            # still what gets logged as min_layer_norm.
            norm_ratio = max_norm / (min_norm if min_norm > 0 else 1e-10)

            metrics["gradients/total_norm"] = total_norm
            metrics["gradients/max_layer_norm"] = max_norm
            metrics["gradients/min_layer_norm"] = min_norm
            metrics["gradients/norm_ratio"] = norm_ratio

            # Log all metrics to W&B in a single call. commit=False is
            # passed so this call does not finalize the step itself
            # (presumably another callback commits it — confirm).
            self._wandb.log(metrics, commit=False)

            # Log summary to console every 10 epochs
            if (epoch + 1) % 10 == 0:
                logger.info(
                    "Gradient norms at epoch %d: total=%.4f, max=%.4f, "
                    "min=%.6f, ratio=%.1f | shared_conv=%.4f, "
                    "task_bn=%.4f, task_gate=%.4f",
                    epoch + 1,
                    total_norm,
                    max_norm,
                    min_norm,
                    norm_ratio,
                    self._l2_norm(category_norms["shared_conv"]),
                    self._l2_norm(category_norms["task_bn"]),
                    self._l2_norm(category_norms["task_gate"]),
                )
        except Exception as e:
            # Never crash training due to logging failure. The epoch is
            # reported 1-based, matching the info log above.
            logger.warning(
                "GradientNormLogger error at epoch %d: %s", epoch + 1, e
            )