lemousehunter
v3: Add DTP + Spectral Decoupling, fix GradNorm OOM, fix _fail_job cancel
283a882
"""
SNR-Guided Multi-Task Attention Network (SNR-MTAN)
==================================================
A multi-task learning architecture for simultaneous 16-byte AES key
recovery from side-channel power traces. Combines three novel elements:
1. **Task-Specific Soft Attention (MTAN):** After the final shared
convolutional block, each of the 16 byte-tasks applies a learned attention
mask that focuses on its byte-specific temporal region. Adapted from Liu et al. (CVPR 2019).
2. **SNR-Guided Weight Initialization:** Task loss weights are initialized
inversely proportional to each byte's Signal-to-Noise Ratio, giving
harder bytes (low SNR) higher initial weight.
3. **GradNorm-style Dynamic Balancing:** During training, task weights are
updated once per epoch from each byte's relative loss decrease (an
epoch-level approximation of GradNorm that does not compute per-task
gradient norms), preventing easy bytes from dominating the shared
representation. Based on Chen et al. (ICML 2018).
Architecture:
Input(32272, 1) → 5 shared Conv blocks [Conv1D→BatchNorm→ReLU→AvgPool]
At the final block: 16 attention masks [Conv1D(1×1)→ReLU→Conv1D(1×1)→Sigmoid]
→ element-wise multiply with shared features
→ 16 heads: GlobalAvgPool→FC(4096)→ReLU→FC(4096)→ReLU→FC(256)→Softmax
Note on architecture:
The global input window (32,272 samples) is ~46x larger than the
per-byte windows (700 samples) used in CNNbest. After 5 pooling layers,
the spatial dimension is 1,008 vs. 21 in CNNbest. To keep the model
tractable, each task head uses GlobalAveragePooling1D before the FC
layers, reducing the representation to the channel dimension (512).
This is consistent with MTAN (Liu et al., 2019) which uses global
pooling in the task-specific decoders.
References:
[1] Liu, S., Johns, E., Davison, A. J.: End-to-end multi-task learning
with attention. CVPR 2019.
[2] Chen, Z., Badrinarayanan, V., Lee, C. Y., Rabinovich, A.: GradNorm:
Gradient normalization for adaptive loss balancing in deep multitask
networks. ICML 2018.
"""
from typing import Any, Dict, List, Optional
import numpy as np
import tensorflow as tf
from .base import BaseModel
from ..constants import (
BYTE_PEAK_SNR,
CNN_DEFAULTS,
GLOBAL_WINDOW_SIZE,
NUM_CLASSES,
)
NUM_TASKS = 16
class SoftAttentionBlock(tf.keras.layers.Layer):
    """
    Task-specific soft-attention module (MTAN-style).

    Learns an attention mask via two 1x1 convolutions with a bottleneck of
    ``max(filters // 4, 16)`` channels, producing a sigmoid gate that is
    element-wise multiplied with the shared feature map.

    Args:
        filters: Number of channels in the input feature map (the mask has
            the same number of channels).
        task_id: Integer identifier for this task (used in layer naming).
        block_id: Integer identifier for the conv block (used in naming).
        **kwargs: Forwarded to ``tf.keras.layers.Layer``. A ``name`` passed
            here overrides the default ``attn_b{block_id}_t{task_id}``; Keras
            passes ``name`` back this way when re-creating the layer from
            ``get_config()``.
    """

    def __init__(self, filters: int, task_id: int, block_id: int, **kwargs):
        # Use setdefault instead of an explicit name= argument: from_config()
        # calls cls(**config) with "name" in the config, which would otherwise
        # raise "got multiple values for keyword argument 'name'".
        kwargs.setdefault("name", f"attn_b{block_id}_t{task_id}")
        super().__init__(**kwargs)
        # Constructor arguments are stored so get_config() can serialize them.
        self.filters = filters
        self.task_id = task_id
        self.block_id = block_id

        bottleneck = max(filters // 4, 16)
        self.conv_down = tf.keras.layers.Conv1D(
            bottleneck, kernel_size=1, padding="same",
            name=f"attn_down_b{block_id}_t{task_id}",
        )
        self.relu = tf.keras.layers.ReLU(name=f"attn_relu_b{block_id}_t{task_id}")
        self.conv_up = tf.keras.layers.Conv1D(
            filters, kernel_size=1, padding="same",
            name=f"attn_up_b{block_id}_t{task_id}",
        )
        self.sigmoid = tf.keras.layers.Activation(
            "sigmoid", name=f"attn_sig_b{block_id}_t{task_id}"
        )

    def call(self, shared_features: tf.Tensor) -> tf.Tensor:
        """
        Compute the attention mask and apply it to the shared features.

        Args:
            shared_features: Tensor of shape (batch, time, filters).

        Returns:
            Attended features of the same shape (input gated by the mask).
        """
        mask = self.conv_down(shared_features)
        mask = self.relu(mask)
        mask = self.conv_up(mask)
        mask = self.sigmoid(mask)
        return shared_features * mask

    def get_config(self) -> Dict[str, Any]:
        """Return the constructor arguments needed to re-create this layer."""
        config = super().get_config()
        config.update(
            {
                "filters": self.filters,
                "task_id": self.task_id,
                "block_id": self.block_id,
            }
        )
        return config
class GradNormCallback(tf.keras.callbacks.Callback):
    """
    Epoch-level, GradNorm-inspired callback for dynamic task weight balancing.

    Unlike the original GradNorm algorithm (Chen et al., 2018), this callback
    does not compute per-task gradient norms. At the end of each epoch it
    reads the per-task losses ``byte_{i}_loss`` from ``logs``, compares them
    with a baseline, and rescales ``task_weights`` so that tasks whose loss
    has decreased least (relative to the mean) receive larger weights.

    Update rule, applied once the baseline L_i(0) is known:
        r_i  = L_i(t) / L_i(0)            (inverse training rate)
        q_i  = r_i / mean_j(r_j)          (relative inverse training rate)
        w_i <- w_i * q_i ** alpha
        w   <- w * NUM_TASKS / sum(w)     (weights sum to NUM_TASKS)

    This callback only writes to ``task_weights``; whether the training loss
    actually uses the new values depends on how the model was compiled.

    Args:
        task_weights: tf.Variable of shape (16,) holding current loss weights.
            Updated in place via ``assign``.
        alpha: Asymmetry exponent (higher = more aggressive balancing).
        initial_losses: Optional baseline losses L_i(0), one per task (any
            array-like). If omitted, the losses logged at the first epoch with
            index >= 1 that reports all 16 per-task losses become the
            baseline, and no weight update happens on that epoch.
    """

    def __init__(
        self,
        task_weights: tf.Variable,
        alpha: float = 1.5,
        initial_losses: Optional[np.ndarray] = None,
    ):
        super().__init__()
        self.task_weights = task_weights
        self.alpha = alpha
        # Coerce to ndarray: the update arithmetic (`+ eps`, element-wise
        # division) fails on plain Python lists.
        self.initial_losses = (
            None
            if initial_losses is None
            else np.asarray(initial_losses, dtype=np.float32)
        )

    def on_epoch_end(self, epoch: int, logs: Optional[dict] = None):
        """Rescale task weights from each task's relative loss decrease."""
        # Epoch 0 is always skipped, so the earliest possible baseline is the
        # loss logged at the end of epoch 1.
        if logs is None or epoch < 1:
            return

        loss_keys = [f"byte_{i}_loss" for i in range(NUM_TASKS)]
        if not all(key in logs for key in loss_keys):
            return  # Not all task losses available yet
        current_losses = np.array(
            [logs[key] for key in loss_keys], dtype=np.float32
        )

        # A diverged epoch (NaN/inf loss) would otherwise propagate through
        # the renormalization and turn *every* weight into NaN permanently.
        if not np.all(np.isfinite(current_losses)):
            return

        # Initialize baseline losses on first valid epoch
        if self.initial_losses is None:
            self.initial_losses = current_losses.copy()
            return

        eps = 1e-8
        # Inverse training rates r_i = L_i(t) / L_i(0), then relative to mean.
        inv_rates = current_losses / (self.initial_losses + eps)
        rel_inv_rates = inv_rates / (np.mean(inv_rates) + eps)
        # Multiplicative update towards r_i^alpha.
        targets = rel_inv_rates ** self.alpha
        new_weights = self.task_weights.numpy() * targets

        # Renormalize so the weights sum to NUM_TASKS. (Dividing by the mean
        # first, as earlier versions did, cancels out and is not needed.)
        total = np.sum(new_weights)
        if not np.isfinite(total) or total <= 0:
            return  # Degenerate update; keep the previous weights.
        new_weights = new_weights * (NUM_TASKS / total)
        self.task_weights.assign(new_weights.astype(np.float32))
class SNRMTAN(BaseModel):
    """
    SNR-Guided Multi-Task Attention Network for 16-byte AES key recovery.

    This model processes a single global trace window covering all 16 byte
    POI regions and simultaneously predicts the S-Box output for each byte.
    Task-specific attention masks, applied after the final shared conv
    block, allow each byte's head to focus on its relevant temporal region
    within the shared feature representation.

    Args:
        input_length: Number of time samples in the global trace window.
        num_classes: Number of output classes per task (256 for AES S-Box).
        conv_filters: List of filter counts for each shared conv block.
        kernel_size: Kernel size for shared conv layers.
        pool_size: Pool size for average pooling.
        fc_units: Number of units in each task-specific FC layer.
        num_fc_layers: Number of FC layers per task head.
        use_attention: Whether to use MTAN soft-attention modules.
        use_gradnorm: Whether to enable GradNorm dynamic weight balancing.
        snr_init: Whether to initialize task weights from SNR values.
        gradnorm_alpha: GradNorm asymmetry parameter.
    """

    def __init__(
        self,
        input_length: int = GLOBAL_WINDOW_SIZE,
        num_classes: int = NUM_CLASSES,
        conv_filters: Optional[List[int]] = None,
        kernel_size: int = CNN_DEFAULTS["kernel_size"],
        pool_size: int = CNN_DEFAULTS["pool_size"],
        fc_units: int = CNN_DEFAULTS["fc_units"],
        num_fc_layers: int = CNN_DEFAULTS["num_fc_layers"],
        use_attention: bool = True,
        use_gradnorm: bool = True,
        snr_init: bool = True,
        gradnorm_alpha: float = 1.5,
    ) -> None:
        super().__init__(input_shape=(input_length, 1), num_classes=num_classes)
        self.input_length = input_length
        # An empty or missing list falls back to the shared CNN defaults.
        self.conv_filters = conv_filters or list(CNN_DEFAULTS["conv_filters"])
        self.kernel_size = kernel_size
        self.pool_size = pool_size
        self.fc_units = fc_units
        self.num_fc_layers = num_fc_layers
        self.use_attention = use_attention
        self.use_gradnorm = use_gradnorm
        self.snr_init = snr_init
        self.gradnorm_alpha = gradnorm_alpha
        # Task weights are created in compile() via _init_task_weights().
        self._task_weights: Optional[tf.Variable] = None
        # Copy of the weights as first created in compile(). GradNormCallback
        # mutates self._task_weights in place, so this copy is the only record
        # of the true initial values.
        self._initial_task_weights: Optional[np.ndarray] = None
        self._gradnorm_callback: Optional[GradNormCallback] = None

    def _init_task_weights(self) -> tf.Variable:
        """
        Initialize per-task loss weights.

        If snr_init is True, weights are set inversely proportional to each
        byte's peak SNR value, normalized so they sum to NUM_TASKS.
        Otherwise, uniform weights (all 1.0) are used.

        Returns:
            A non-trainable float32 tf.Variable of shape (NUM_TASKS,).
        """
        if self.snr_init:
            snr_values = np.array(
                [BYTE_PEAK_SNR[i] for i in range(NUM_TASKS)], dtype=np.float32
            )
            # eps guards against a zero SNR entry.
            inv_snr = 1.0 / (snr_values + 1e-8)
            weights = inv_snr * (NUM_TASKS / np.sum(inv_snr))
        else:
            weights = np.ones(NUM_TASKS, dtype=np.float32)
        return tf.Variable(weights, trainable=False, name="task_weights")

    def build(self) -> tf.keras.Model:
        """
        Construct the SNR-MTAN Keras model.

        The architecture uses GlobalAveragePooling1D before the FC layers
        in each task head to reduce the spatial dimension to the channel
        dimension. This keeps the model tractable for the large global
        input window (32,272 samples) while preserving the attention
        mechanism's ability to focus on byte-specific temporal regions.

        Returns:
            A Keras Model with one input and 16 outputs (one per byte). The
            outputs are named ``byte_0`` ... ``byte_15``; compile() keys its
            loss/metric dicts on these names.
        """
        inputs = tf.keras.Input(shape=self.input_shape, name="trace_input")

        # ── Shared Backbone ──────────────────────────────────────────────
        features = inputs
        for block_idx, filters in enumerate(self.conv_filters):
            features = tf.keras.layers.Conv1D(
                filters=filters,
                kernel_size=self.kernel_size,
                padding="same",
                kernel_initializer="glorot_uniform",
                name=f"shared_conv_{block_idx}",
            )(features)
            features = tf.keras.layers.BatchNormalization(
                name=f"shared_bn_{block_idx}"
            )(features)
            features = tf.keras.layers.ReLU(
                name=f"shared_relu_{block_idx}"
            )(features)
            features = tf.keras.layers.AveragePooling1D(
                pool_size=self.pool_size,
                name=f"shared_pool_{block_idx}",
            )(features)

        # ── Task-Specific Branches ───────────────────────────────────────
        # Attention is applied at the final conv block only: applying it at
        # all 5 blocks would create 80 attention modules, and the final block
        # carries the most abstract features, where byte-specific focus
        # matters most.
        final_filters = self.conv_filters[-1]
        final_block_id = len(self.conv_filters) - 1
        task_outputs = []
        for task_id in range(NUM_TASKS):
            if self.use_attention:
                task_repr = SoftAttentionBlock(
                    filters=final_filters,
                    task_id=task_id,
                    block_id=final_block_id,
                )(features)
            else:
                # No attention: all tasks share the same final representation.
                task_repr = features

            # (batch, time, channels) -> (batch, channels). Without this, a
            # flatten over the final feature map (~1008 x 512 with the
            # default window) would make the FC layers infeasible.
            x = tf.keras.layers.GlobalAveragePooling1D(
                name=f"byte_{task_id}_gap"
            )(task_repr)

            # Task-specific classification head
            for fc_idx in range(self.num_fc_layers):
                x = tf.keras.layers.Dense(
                    self.fc_units,
                    activation="relu",
                    kernel_initializer="glorot_uniform",
                    name=f"byte_{task_id}_fc_{fc_idx}",
                )(x)
            output = tf.keras.layers.Dense(
                self.num_classes,
                activation="softmax",
                name=f"byte_{task_id}",
            )(x)
            task_outputs.append(output)

        self._model = tf.keras.Model(
            inputs=inputs,
            outputs=task_outputs,
            name="SNR_MTAN",
        )
        return self._model

    def compile(self, learning_rate: float = 1e-5) -> tf.keras.Model:
        """
        Build (if needed) and compile the model with per-task weighted losses.

        Overrides the base class to set up multi-task loss weighting and,
        when ``use_gradnorm`` is True, a fresh GradNorm callback bound to the
        newly created task-weight variable.

        Args:
            learning_rate: Learning rate for the RMSprop optimizer.

        Returns:
            The compiled Keras model.
        """
        if self._model is None:
            self._model = self.build()

        # Fresh weights on every compile; keep an immutable copy for logging.
        self._task_weights = self._init_task_weights()
        initial_weights = self._task_weights.numpy()
        self._initial_task_weights = initial_weights.copy()

        output_names = [f"byte_{i}" for i in range(NUM_TASKS)]
        loss_dict = {name: "categorical_crossentropy" for name in output_names}
        # NOTE(review): Keras receives plain Python floats copied from
        # self._task_weights here; the compiled model holds no reference to
        # the variable. In-place updates made later by GradNormCallback do
        # therefore not change these loss weights unless the model is
        # re-compiled with the new values. TODO confirm how the trainer
        # propagates GradNorm updates.
        loss_weights_dict = {
            name: float(weight)
            for name, weight in zip(output_names, initial_weights)
        }

        self._model.compile(
            optimizer=tf.keras.optimizers.RMSprop(learning_rate=learning_rate),
            loss=loss_dict,
            loss_weights=loss_weights_dict,
            metrics={name: ["accuracy"] for name in output_names},
        )

        if self.use_gradnorm:
            self._gradnorm_callback = GradNormCallback(
                task_weights=self._task_weights,
                alpha=self.gradnorm_alpha,
            )
        else:
            # Drop a callback left over from an earlier compile() so that
            # gradnorm_callback reflects the current use_gradnorm setting.
            self._gradnorm_callback = None
        return self._model

    @property
    def task_weights(self) -> Optional[tf.Variable]:
        """Access the current task loss weights (None before compile())."""
        return self._task_weights

    @property
    def gradnorm_callback(self) -> Optional[GradNormCallback]:
        """Access the GradNorm callback (None if GradNorm is disabled)."""
        return self._gradnorm_callback

    def get_config(self) -> Dict[str, Any]:
        """
        Return architecture hyperparameters for logging.

        After compile(), ``initial_task_weights`` holds the weights as first
        created, and ``current_task_weights`` holds the present contents of
        the task-weight variable (which GradNormCallback may have changed).
        """
        config = {
            "model_type": "mtan",
            "architecture": "SNR-MTAN",
            "input_length": self.input_length,
            "num_classes": self.num_classes,
            "num_tasks": NUM_TASKS,
            "conv_filters": self.conv_filters,
            "kernel_size": self.kernel_size,
            "pool_size": self.pool_size,
            "fc_units": self.fc_units,
            "num_fc_layers": self.num_fc_layers,
            "use_attention": self.use_attention,
            "use_gradnorm": self.use_gradnorm,
            "snr_init": self.snr_init,
            "gradnorm_alpha": self.gradnorm_alpha,
        }
        if self._model is not None:
            config["total_params"] = self.model.count_params()
        if self._task_weights is not None:
            current = self._task_weights.numpy().tolist()
            # Prefer the compile-time snapshot: the variable itself is
            # mutated in place by GradNorm during training.
            config["initial_task_weights"] = (
                self._initial_task_weights.tolist()
                if self._initial_task_weights is not None
                else current
            )
            config["current_task_weights"] = current
        return config