| """ |
| SNR-Guided Multi-Task Attention Network (SNR-MTAN) |
| ================================================== |
| A multi-task learning architecture for simultaneous 16-byte AES key |
| recovery from side-channel power traces. Combines three novel elements: |
| |
1. **Task-Specific Soft Attention (MTAN):** Each of the 16 byte-tasks has
   a learned attention mask over the shared features, so it can focus on
   its byte-specific temporal region. Adapted from Liu et al. (CVPR 2019).
   In this implementation the masks are applied to the output of the final
   shared convolutional block only, not at every block.
| |
| 2. **SNR-Guided Weight Initialization:** Task loss weights are initialized |
| inversely proportional to each byte's Signal-to-Noise Ratio, giving |
| harder bytes (low SNR) higher initial weight. |
| |
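3. **GradNorm Dynamic Balancing:** After each epoch, task weights are
   rescaled from each task's relative inverse training rate (current loss
   divided by its baseline loss, relative to the mean over all tasks).
   Tasks whose loss falls slowly gain weight, which prevents easy bytes
   from dominating the shared representation. This uses the training-rate
   term of Chen et al. (ICML 2018); gradient norms are not computed
   explicitly.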
| |
| Architecture: |
Input(32272, 1) → 5 shared Conv blocks [Conv1D→BatchNorm→ReLU→AvgPool]
→ after the final block: 16 attention masks [Conv1D(1×1)→ReLU→Conv1D(1×1)→Sigmoid],
  each multiplied element-wise with the shared features
→ 16 heads: GlobalAvgPool→FC(4096)→ReLU→FC(4096)→ReLU→FC(256)→Softmax
| |
| Note on architecture: |
| The global input window (32,272 samples) is ~46x larger than the |
| per-byte windows (700 samples) used in CNNbest. After 5 pooling layers, |
| the spatial dimension is 1,008 vs. 21 in CNNbest. To keep the model |
| tractable, each task head uses GlobalAveragePooling1D before the FC |
| layers, reducing the representation to the channel dimension (512). |
| This is consistent with MTAN (Liu et al., 2019) which uses global |
| pooling in the task-specific decoders. |
| |
| References: |
| [1] Liu, S., Johns, E., Davison, A. J.: End-to-end multi-task learning |
| with attention. CVPR 2019. |
| [2] Chen, Z., Badrinarayanan, V., Lee, C. Y., Rabinovich, A.: GradNorm: |
| Gradient normalization for adaptive loss balancing in deep multitask |
| networks. ICML 2018. |
| """ |
|
|
| from typing import Any, Dict, List, Optional |
|
|
| import numpy as np |
| import tensorflow as tf |
|
|
| from .base import BaseModel |
| from ..constants import ( |
| BYTE_PEAK_SNR, |
| CNN_DEFAULTS, |
| GLOBAL_WINDOW_SIZE, |
| NUM_CLASSES, |
| ) |
|
|
# One classification task (output head) per AES-128 key byte: 128 bits / 8.
NUM_TASKS: int = 128 // 8
|
|
|
|
class SoftAttentionBlock(tf.keras.layers.Layer):
    """
    Task-specific soft-attention module (MTAN-style).

    Learns a gate over the shared feature map with two 1x1 convolutions
    around a bottleneck of ``max(filters // 4, 16)`` channels, followed by
    a sigmoid. The gate has ``filters`` channels and is multiplied
    element-wise with the shared feature map.

    Args:
        filters: Number of channels in the input feature map.
        task_id: Integer identifier for this task (used in layer naming).
        block_id: Integer identifier for the conv block (used in naming).
        **kwargs: Forwarded to ``tf.keras.layers.Layer``. ``name`` defaults
            to ``attn_b{block_id}_t{task_id}`` when not given explicitly.
    """

    def __init__(self, filters: int, task_id: int, block_id: int, **kwargs):
        # setdefault (instead of always passing name=...) lets from_config(),
        # which replays the saved "name" entry, rebuild the layer without a
        # "multiple values for keyword argument 'name'" TypeError.
        kwargs.setdefault("name", f"attn_b{block_id}_t{task_id}")
        super().__init__(**kwargs)

        # Kept so get_config() can round-trip the constructor arguments.
        self.filters = filters
        self.task_id = task_id
        self.block_id = block_id

        bottleneck = max(filters // 4, 16)
        self.conv_down = tf.keras.layers.Conv1D(
            bottleneck, kernel_size=1, padding="same",
            name=f"attn_down_b{block_id}_t{task_id}",
        )
        self.relu = tf.keras.layers.ReLU(name=f"attn_relu_b{block_id}_t{task_id}")
        # Project back to `filters` channels so the gate matches the input.
        self.conv_up = tf.keras.layers.Conv1D(
            filters, kernel_size=1, padding="same",
            name=f"attn_up_b{block_id}_t{task_id}",
        )
        self.sigmoid = tf.keras.layers.Activation(
            "sigmoid", name=f"attn_sig_b{block_id}_t{task_id}"
        )

    def call(self, shared_features: tf.Tensor) -> tf.Tensor:
        """
        Compute the attention mask and apply it to the shared features.

        Args:
            shared_features: Tensor of shape (batch, time, filters).

        Returns:
            Attended features of the same shape (features * sigmoid gate).
        """
        mask = self.conv_down(shared_features)
        mask = self.relu(mask)
        mask = self.conv_up(mask)
        mask = self.sigmoid(mask)
        return shared_features * mask

    def get_config(self) -> Dict[str, Any]:
        """Return the layer config, including the constructor arguments."""
        config = super().get_config()
        config.update(
            {
                "filters": self.filters,
                "task_id": self.task_id,
                "block_id": self.block_id,
            }
        )
        return config
|
|
|
|
class GradNormCallback(tf.keras.callbacks.Callback):
    """
    Loss-ratio task re-weighting callback (GradNorm-style training rates).

    At the end of each epoch the callback reads every task's loss from
    ``logs`` and computes the GradNorm relative inverse training rate

        r_i = (L_i(t) / L_i(0)) / mean_j (L_j(t) / L_j(0)).

    It then multiplies each task weight by ``r_i ** alpha`` and renormalizes
    the weights so they sum to the number of tasks. Tasks whose loss has
    fallen more slowly than average (``r_i > 1``) therefore gain weight.
    Unlike the full GradNorm algorithm, no gradient norms are computed:
    only the training-rate term is used, once per epoch.

    The callback mutates ``task_weights`` in place. That only influences
    training if the model's losses read the variable inside the train step,
    which ``SNRMTAN.compile`` arranges when GradNorm is enabled. Weights
    that were passed to ``Model.compile(loss_weights=...)`` as Python floats
    are not affected.

    Args:
        task_weights: tf.Variable of shape (num_tasks,) holding the current
            loss weights. It is updated in place via ``assign``.
        alpha: Exponent applied to the relative rate (higher = more
            aggressive balancing).
        initial_losses: Optional baseline losses L_i(0). If None, the losses
            of the first epoch whose logs contain every (finite) task loss
            become the baseline, and no update is made on that epoch.
        loss_key_template: ``str.format`` template mapping a task index to
            the ``logs`` key that holds that task's unweighted loss.
    """

    def __init__(
        self,
        task_weights: tf.Variable,
        alpha: float = 1.5,
        initial_losses: Optional[np.ndarray] = None,
        loss_key_template: str = "byte_{}_loss",
    ):
        super().__init__()
        self.task_weights = task_weights
        self.alpha = alpha
        self.initial_losses = (
            None
            if initial_losses is None
            else np.asarray(initial_losses, dtype=np.float32)
        )
        self.loss_key_template = loss_key_template
        self._warned_missing_keys = False

    def on_epoch_end(self, epoch: int, logs: Optional[dict] = None):
        """
        Update ``task_weights`` from this epoch's per-task losses.

        The first epoch with complete, finite losses only records the
        baseline (unless ``initial_losses`` was supplied). Epochs whose logs
        lack a task loss, or contain a non-finite one, leave the weights
        unchanged.
        """
        if not logs:
            return

        num_tasks = int(self.task_weights.shape[0])
        keys = [self.loss_key_template.format(i) for i in range(num_tasks)]
        missing = [key for key in keys if key not in logs]
        if missing:
            # Warn once instead of silently doing nothing every epoch.
            if not self._warned_missing_keys:
                tf.get_logger().warning(
                    "GradNormCallback: %d task loss key(s) missing from logs "
                    "(e.g. %r); task weights are not updated.",
                    len(missing),
                    missing[0],
                )
                self._warned_missing_keys = True
            return

        current_losses = np.array([logs[key] for key in keys], dtype=np.float32)
        # A single NaN/inf would poison every weight through the mean below.
        if not np.all(np.isfinite(current_losses)):
            return

        if self.initial_losses is None:
            # Epoch-0 (first complete) losses are the training-rate baseline.
            self.initial_losses = current_losses.copy()
            return

        eps = 1e-8
        baseline = np.asarray(self.initial_losses, dtype=np.float32)
        inv_rates = current_losses / (baseline + eps)
        rel_inv_rates = inv_rates / (np.mean(inv_rates) + eps)

        new_weights = self.task_weights.numpy() * rel_inv_rates ** self.alpha
        # Renormalize so the weights keep summing to the number of tasks.
        new_weights = new_weights * (num_tasks / (np.sum(new_weights) + eps))
        self.task_weights.assign(new_weights.astype(np.float32))


# Name of the per-output metric that reports each task's *unweighted*
# categorical cross-entropy when GradNorm is enabled; GradNormCallback reads
# it from the logs as "byte_{i}_cce".
_RAW_LOSS_METRIC_NAME = "cce"


def _weighted_categorical_crossentropy(task_weights: tf.Variable, index: int):
    """
    Build a per-sample loss ``task_weights[index] * categorical_crossentropy``.

    The variable is read inside the loss on every call, so in-place updates
    (e.g. from GradNormCallback) take effect on the next training step.
    """

    def loss(y_true, y_pred):
        per_sample = tf.keras.losses.categorical_crossentropy(y_true, y_pred)
        # Cast so the weight matches the loss dtype (e.g. under mixed precision).
        return tf.cast(task_weights[index], per_sample.dtype) * per_sample

    loss.__name__ = f"weighted_cce_byte_{index}"
    return loss


class SNRMTAN(BaseModel):
    """
    SNR-Guided Multi-Task Attention Network for 16-byte AES key recovery.

    This model processes a single global trace window covering all 16 byte
    POI regions and simultaneously predicts the S-Box output for each byte.
    Task-specific attention masks, applied to the output of the final shared
    conv block, let each byte's head focus on its relevant temporal region
    within the shared feature representation.

    Args:
        input_length: Number of time samples in the global trace window.
        num_classes: Number of output classes per task (256 for AES S-Box).
        conv_filters: List of filter counts for each shared conv block.
        kernel_size: Kernel size for shared conv layers.
        pool_size: Pool size for average pooling.
        fc_units: Number of units in each task-specific FC layer.
        num_fc_layers: Number of FC layers per task head.
        use_attention: Whether to use MTAN soft-attention modules.
        use_gradnorm: Whether to enable GradNorm-style dynamic weight
            balancing. When enabled, ``compile`` makes the per-task losses
            read the task-weight variable on every step, so updates made by
            ``gradnorm_callback`` (which must be passed to ``fit``) take effect.
        snr_init: Whether to initialize task weights from SNR values.
        gradnorm_alpha: GradNorm asymmetry parameter.
    """

    def __init__(
        self,
        input_length: int = GLOBAL_WINDOW_SIZE,
        num_classes: int = NUM_CLASSES,
        conv_filters: Optional[List[int]] = None,
        kernel_size: int = CNN_DEFAULTS["kernel_size"],
        pool_size: int = CNN_DEFAULTS["pool_size"],
        fc_units: int = CNN_DEFAULTS["fc_units"],
        num_fc_layers: int = CNN_DEFAULTS["num_fc_layers"],
        use_attention: bool = True,
        use_gradnorm: bool = True,
        snr_init: bool = True,
        gradnorm_alpha: float = 1.5,
    ) -> None:
        super().__init__(input_shape=(input_length, 1), num_classes=num_classes)
        self.input_length = input_length
        # An empty/None list falls back to the shared CNN defaults.
        self.conv_filters = conv_filters or list(CNN_DEFAULTS["conv_filters"])
        self.kernel_size = kernel_size
        self.pool_size = pool_size
        self.fc_units = fc_units
        self.num_fc_layers = num_fc_layers
        self.use_attention = use_attention
        self.use_gradnorm = use_gradnorm
        self.snr_init = snr_init
        self.gradnorm_alpha = gradnorm_alpha

        # Populated by compile(). The variable may be mutated during training,
        # so a numpy snapshot of the starting weights is kept separately.
        self._task_weights: Optional[tf.Variable] = None
        self._initial_task_weights: Optional[np.ndarray] = None
        self._gradnorm_callback: Optional[GradNormCallback] = None

    def _init_task_weights(self) -> tf.Variable:
        """
        Initialize per-task loss weights.

        If snr_init is True, weights are set inversely proportional to each
        byte's peak SNR value (``BYTE_PEAK_SNR``) and normalized so they sum
        to NUM_TASKS. Otherwise, uniform weights (all 1.0) are used.

        Returns:
            A non-trainable float32 tf.Variable of shape (NUM_TASKS,).
        """
        if self.snr_init:
            snr_values = np.array(
                [BYTE_PEAK_SNR[i] for i in range(NUM_TASKS)], dtype=np.float32
            )
            inv_snr = 1.0 / (snr_values + 1e-8)
            weights = inv_snr * (NUM_TASKS / np.sum(inv_snr))
        else:
            weights = np.ones(NUM_TASKS, dtype=np.float32)

        return tf.Variable(weights, trainable=False, name="task_weights")

    def build(self) -> tf.keras.Model:
        """
        Construct the SNR-MTAN Keras model.

        The shared trunk has one [Conv1D → BatchNorm → ReLU → AveragePooling1D]
        block per entry of ``conv_filters``. Each of the NUM_TASKS heads then
        optionally applies its own SoftAttentionBlock to the output of the
        final shared block; attention is not applied at intermediate blocks.
        After that come GlobalAveragePooling1D, ``num_fc_layers`` ReLU Dense
        layers, and a softmax output layer named ``byte_{task_id}``.

        Global average pooling reduces the spatial dimension to the channel
        dimension. This keeps the FC heads tractable for the large global
        input window (32,272 samples by default).

        Returns:
            A Keras Model with one input and NUM_TASKS outputs (one per byte).
        """
        inputs = tf.keras.Input(shape=self.input_shape, name="trace_input")

        # Shared convolutional trunk.
        shared = inputs
        block_outputs = []
        for block_idx, filters in enumerate(self.conv_filters):
            shared = tf.keras.layers.Conv1D(
                filters=filters,
                kernel_size=self.kernel_size,
                padding="same",
                kernel_initializer="glorot_uniform",
                name=f"shared_conv_{block_idx}",
            )(shared)
            shared = tf.keras.layers.BatchNormalization(
                name=f"shared_bn_{block_idx}"
            )(shared)
            shared = tf.keras.layers.ReLU(
                name=f"shared_relu_{block_idx}"
            )(shared)
            shared = tf.keras.layers.AveragePooling1D(
                pool_size=self.pool_size,
                name=f"shared_pool_{block_idx}",
            )(shared)
            block_outputs.append(shared)

        final_block = block_outputs[-1]

        # Task-specific heads.
        task_outputs = []
        for task_id in range(NUM_TASKS):
            if self.use_attention:
                # One learned mask per task over the final shared block.
                task_repr = SoftAttentionBlock(
                    filters=self.conv_filters[-1],
                    task_id=task_id,
                    block_id=len(self.conv_filters) - 1,
                )(final_block)
            else:
                # Without attention every head sees identical features.
                task_repr = final_block

            x = tf.keras.layers.GlobalAveragePooling1D(
                name=f"byte_{task_id}_gap"
            )(task_repr)

            for fc_idx in range(self.num_fc_layers):
                x = tf.keras.layers.Dense(
                    self.fc_units,
                    activation="relu",
                    kernel_initializer="glorot_uniform",
                    name=f"byte_{task_id}_fc_{fc_idx}",
                )(x)

            output = tf.keras.layers.Dense(
                self.num_classes,
                activation="softmax",
                name=f"byte_{task_id}",
            )(x)
            task_outputs.append(output)

        self._model = tf.keras.Model(
            inputs=inputs,
            outputs=task_outputs,
            name="SNR_MTAN",
        )
        return self._model

    def compile(self, learning_rate: float = 1e-5) -> tf.keras.Model:
        """
        Build (if needed) and compile the model with per-task weighted losses.

        Task weights come from ``_init_task_weights`` and are re-initialized
        on every call.

        * Without GradNorm the weights never change, so they are passed as
          float ``loss_weights`` to the standard ``categorical_crossentropy``
          loss.
        * With GradNorm, each output's loss multiplies its cross-entropy by
          the live ``task_weights`` variable, so the callback's in-place
          updates actually change the training objective. Each output also
          reports its unweighted cross-entropy as the ``byte_{i}_cce``
          metric, which the callback uses to measure training rates. In this
          mode the per-output ``byte_{i}_loss`` log entries are weighted.
          Because these losses are Python closures, reloading a saved,
          compiled model may require ``compile=False`` or ``custom_objects``.

        Args:
            learning_rate: Learning rate for the RMSprop optimizer.

        Returns:
            The compiled Keras model.
        """
        if self._model is None:
            self._model = self.build()

        self._task_weights = self._init_task_weights()
        initial_weights = self._task_weights.numpy()
        # Snapshot, because the variable is mutated in place during training.
        self._initial_task_weights = initial_weights.copy()

        output_names = [f"byte_{i}" for i in range(NUM_TASKS)]
        metrics: Dict[str, List[Any]] = {name: ["accuracy"] for name in output_names}

        if self.use_gradnorm:
            loss_dict = {
                name: _weighted_categorical_crossentropy(self._task_weights, i)
                for i, name in enumerate(output_names)
            }
            # Weighting happens inside the losses, so no static loss_weights.
            loss_weights_dict = None
            for name in output_names:
                # A fresh metric instance per output; Keras prefixes the name
                # with the output name in the logs (byte_{i}_cce).
                metrics[name].append(
                    tf.keras.metrics.CategoricalCrossentropy(
                        name=_RAW_LOSS_METRIC_NAME
                    )
                )
        else:
            loss_dict = {name: "categorical_crossentropy" for name in output_names}
            loss_weights_dict = {
                name: float(initial_weights[i])
                for i, name in enumerate(output_names)
            }

        self._model.compile(
            optimizer=tf.keras.optimizers.RMSprop(learning_rate=learning_rate),
            loss=loss_dict,
            loss_weights=loss_weights_dict,
            metrics=metrics,
        )

        if self.use_gradnorm:
            self._gradnorm_callback = GradNormCallback(
                task_weights=self._task_weights,
                alpha=self.gradnorm_alpha,
                loss_key_template="byte_{}_" + _RAW_LOSS_METRIC_NAME,
            )
        else:
            self._gradnorm_callback = None

        return self._model

    @property
    def task_weights(self) -> Optional[tf.Variable]:
        """Access the current task loss weights (None before compile())."""
        return self._task_weights

    @property
    def gradnorm_callback(self) -> Optional[GradNormCallback]:
        """Access the GradNorm callback (None if GradNorm is disabled)."""
        return self._gradnorm_callback

    def get_config(self) -> Dict[str, Any]:
        """Return architecture hyperparameters (and task weights) for logging."""
        config = {
            "model_type": "mtan",
            "architecture": "SNR-MTAN",
            "input_length": self.input_length,
            "num_classes": self.num_classes,
            "num_tasks": NUM_TASKS,
            "conv_filters": self.conv_filters,
            "kernel_size": self.kernel_size,
            "pool_size": self.pool_size,
            "fc_units": self.fc_units,
            "num_fc_layers": self.num_fc_layers,
            "use_attention": self.use_attention,
            "use_gradnorm": self.use_gradnorm,
            "snr_init": self.snr_init,
            "gradnorm_alpha": self.gradnorm_alpha,
        }
        if self._model is not None:
            config["total_params"] = self.model.count_params()
        if self._task_weights is not None:
            current = self._task_weights.numpy()
            # Report the compile-time snapshot, not the possibly-updated variable.
            initial = (
                self._initial_task_weights
                if self._initial_task_weights is not None
                else current
            )
            config["initial_task_weights"] = np.asarray(initial).tolist()
            config["current_task_weights"] = current.tolist()
        return config
|
|