"""
SNR-Guided Multi-Task Attention Network (SNR-MTAN)
==================================================

A multi-task learning architecture for simultaneous 16-byte AES key recovery
from side-channel power traces. Combines three elements:

1. **Task-Specific Soft Attention (MTAN):** Each of the 16 byte-tasks has a
   learned attention mask over the output of the final shared convolutional
   block, letting each head focus on its byte-specific temporal region.
   Adapted from Liu et al. (CVPR 2019), which places attention modules at
   every block; here they are applied at the final block only, so the model
   has 16 attention modules instead of 80.

2. **SNR-Guided Weight Initialization:** Task loss weights are initialized
   inversely proportional to each byte's Signal-to-Noise Ratio, giving harder
   bytes (low SNR) higher initial weight.

3. **GradNorm-Inspired Dynamic Balancing:** At the end of each epoch, task
   weights are rescaled by each task's relative inverse training rate
   L_i(t) / L_i(0), so tasks whose loss decreases slowly gain weight and easy
   bytes do not dominate the shared representation. This follows the
   training-rate target of Chen et al. (ICML 2018) but is a loss-ratio
   heuristic: per-task gradient norms are not computed.

Architecture:
    Input(32272, 1)
    → 5 shared Conv blocks [Conv1D→BatchNorm→ReLU→AvgPool]
    → after the final block: 16 attention masks
      [Conv1D(1×1)→ReLU→Conv1D(1×1)→Sigmoid]
      → element-wise multiply with the shared features
    → 16 heads: GlobalAvgPool→FC(4096)→ReLU→FC(4096)→ReLU→FC(256)→Softmax

Note on architecture:
    The global input window (32,272 samples) is ~46x larger than the per-byte
    windows (700 samples) used in CNNbest. After 5 pooling layers, the spatial
    dimension is 1,008 vs. 21 in CNNbest. To keep the model tractable, each
    task head uses GlobalAveragePooling1D before the FC layers, reducing the
    representation to the channel dimension (512). This is consistent with
    MTAN (Liu et al., 2019) which uses global pooling in the task-specific
    decoders.

References:
    [1] Liu, S., Johns, E., Davison, A. J.: End-to-end multi-task learning
        with attention. CVPR 2019.
    [2] Chen, Z., Badrinarayanan, V., Lee, C. Y., Rabinovich, A.: GradNorm:
        Gradient normalization for adaptive loss balancing in deep multitask
        networks. ICML 2018.
"""

from typing import Any, Dict, List, Optional

import numpy as np
import tensorflow as tf

from .base import BaseModel
from ..constants import (
    BYTE_PEAK_SNR,
    CNN_DEFAULTS,
    GLOBAL_WINDOW_SIZE,
    NUM_CLASSES,
)

NUM_TASKS = 16


class SoftAttentionBlock(tf.keras.layers.Layer):
    """
    Task-specific soft-attention module (MTAN-style).

    Learns an attention mask via two 1x1 convolutions with a bottleneck of
    ``max(filters // 4, 16)`` channels, producing a sigmoid gate that is
    element-wise multiplied with the shared feature map. Because both
    convolutions are 1x1, the gate at each time step is computed from that
    step's channel vector alone.

    Args:
        filters: Number of channels in the input feature map.
        task_id: Integer identifier for this task (used in layer naming).
        block_id: Integer identifier for the conv block (used in naming).
        **kwargs: Forwarded to ``tf.keras.layers.Layer``. ``name`` defaults to
            ``attn_b{block_id}_t{task_id}`` when not supplied.
    """

    def __init__(self, filters: int, task_id: int, block_id: int, **kwargs):
        # setdefault instead of an explicit name= argument so that
        # from_config(), which passes the serialized "name" back in kwargs,
        # does not raise "multiple values for keyword argument 'name'".
        kwargs.setdefault("name", f"attn_b{block_id}_t{task_id}")
        super().__init__(**kwargs)
        # Kept so get_config() can reproduce the constructor arguments.
        self.filters = filters
        self.task_id = task_id
        self.block_id = block_id

        bottleneck = max(filters // 4, 16)
        self.conv_down = tf.keras.layers.Conv1D(
            bottleneck,
            kernel_size=1,
            padding="same",
            name=f"attn_down_b{block_id}_t{task_id}",
        )
        self.relu = tf.keras.layers.ReLU(name=f"attn_relu_b{block_id}_t{task_id}")
        self.conv_up = tf.keras.layers.Conv1D(
            filters,
            kernel_size=1,
            padding="same",
            name=f"attn_up_b{block_id}_t{task_id}",
        )
        self.sigmoid = tf.keras.layers.Activation(
            "sigmoid", name=f"attn_sig_b{block_id}_t{task_id}"
        )

    def call(self, shared_features: tf.Tensor) -> tf.Tensor:
        """
        Compute the attention mask and apply it to the shared features.

        Args:
            shared_features: Tensor of shape (batch, time, filters).

        Returns:
            Attended features of the same shape.
        """
        mask = self.conv_down(shared_features)
        mask = self.relu(mask)
        mask = self.conv_up(mask)
        mask = self.sigmoid(mask)
        return shared_features * mask

    def get_config(self) -> Dict[str, Any]:
        """Return the layer config, including the constructor arguments."""
        config = super().get_config()
        config.update(
            {
                "filters": self.filters,
                "task_id": self.task_id,
                "block_id": self.block_id,
            }
        )
        return config


class GradNormCallback(tf.keras.callbacks.Callback):
    """
    GradNorm-inspired callback for dynamic task weight balancing.

    At the end of each epoch, the callback reads the per-task losses
    ``byte_{i}_loss`` from the Keras logs. It multiplies each task weight by
    ``(r_i / mean(r)) ** alpha``, where ``r_i = L_i(t) / L_i(0)`` is the
    task's inverse training rate. Tasks whose loss has fallen least relative
    to their baseline therefore gain weight. The weights are then
    renormalized to sum to the number of tasks.

    Epoch 0 is skipped as a warm-up. If no ``initial_losses`` are given, the
    first eligible epoch only records the baseline. Each update multiplies
    the previous weights, so adjustments compound across epochs.

    This uses GradNorm's training-rate target (Chen et al., 2018) but does
    not compute per-task gradient norms; it is a loss-ratio heuristic.

    Args:
        task_weights: tf.Variable of shape (num_tasks,) holding the current
            loss weights. Updated in place via ``assign``.
        alpha: Asymmetry parameter (higher = more aggressive balancing).
        initial_losses: Optional baseline (unweighted) loss per task used as
            ``L_i(0)``. Any array-like is accepted.
        losses_are_weighted: Set True when the logged per-task losses already
            include the task weight (the loss function multiplies by
            ``task_weights``). The weight is then divided out so training
            rates are computed from the raw losses.
    """

    def __init__(
        self,
        task_weights: tf.Variable,
        alpha: float = 1.5,
        initial_losses: Optional[np.ndarray] = None,
        losses_are_weighted: bool = False,
    ):
        super().__init__()
        self.task_weights = task_weights
        self.alpha = alpha
        # Convert eagerly so list/tuple baselines support the arithmetic below.
        self.initial_losses = (
            None
            if initial_losses is None
            else np.asarray(initial_losses, dtype=np.float32)
        )
        self.losses_are_weighted = losses_are_weighted

    def on_epoch_end(self, epoch: int, logs: Optional[dict] = None):
        """Update task weights based on relative inverse training rates."""
        if logs is None or epoch < 1:
            return

        eps = 1e-8
        num_tasks = int(self.task_weights.shape[0])
        loss_keys = [f"byte_{i}_loss" for i in range(num_tasks)]
        if any(key not in logs for key in loss_keys):
            return  # Not all task losses are available in the logs.

        current_losses = np.array([logs[key] for key in loss_keys], dtype=np.float32)
        # The weights are only changed here, at epoch end, so the current
        # variable value is the weight that was in effect for this epoch.
        weights = self.task_weights.numpy()
        if self.losses_are_weighted:
            current_losses = current_losses / np.maximum(weights, eps)

        # A diverged epoch (NaN/inf loss) would otherwise poison the weights
        # or the baseline permanently.
        if not np.all(np.isfinite(current_losses)):
            return

        # Record baseline losses on the first eligible epoch.
        if self.initial_losses is None:
            self.initial_losses = current_losses.copy()
            return

        # Inverse training rates: r_i = L_i(t) / L_i(0)
        inv_rates = current_losses / (self.initial_losses + eps)
        # Relative inverse training rates, raised to alpha.
        rel_inv_rates = inv_rates / (np.mean(inv_rates) + eps)
        targets = rel_inv_rates ** self.alpha

        # w_i <- w_i * target_i, then renormalize so the weights sum to num_tasks.
        new_weights = weights * targets
        total = float(np.sum(new_weights))
        if not np.isfinite(total) or total <= 0.0:
            return  # Degenerate update (e.g. all-zero losses); keep old weights.
        new_weights = new_weights * (num_tasks / total)
        self.task_weights.assign(new_weights.astype(np.float32))


class SNRMTAN(BaseModel):
    """
    SNR-Guided Multi-Task Attention Network for 16-byte AES key recovery.

    This model processes a single global trace window covering all 16 byte
    POI regions and simultaneously predicts the S-Box output for each byte.
    Task-specific attention masks, applied to the output of the final shared
    conv block, let each byte's head focus on its relevant temporal region
    within the shared feature representation.

    Args:
        input_length: Number of time samples in the global trace window.
        num_classes: Number of output classes per task (256 for AES S-Box).
        conv_filters: List of filter counts for each shared conv block.
        kernel_size: Kernel size for shared conv layers.
        pool_size: Pool size for average pooling.
        fc_units: Number of units in each task-specific FC layer.
        num_fc_layers: Number of FC layers per task head.
        use_attention: Whether to use MTAN soft-attention modules.
        use_gradnorm: Whether to enable GradNorm-style dynamic weight
            balancing.
        snr_init: Whether to initialize task weights from SNR values.
        gradnorm_alpha: GradNorm asymmetry parameter.
    """

    def __init__(
        self,
        input_length: int = GLOBAL_WINDOW_SIZE,
        num_classes: int = NUM_CLASSES,
        conv_filters: Optional[List[int]] = None,
        kernel_size: int = CNN_DEFAULTS["kernel_size"],
        pool_size: int = CNN_DEFAULTS["pool_size"],
        fc_units: int = CNN_DEFAULTS["fc_units"],
        num_fc_layers: int = CNN_DEFAULTS["num_fc_layers"],
        use_attention: bool = True,
        use_gradnorm: bool = True,
        snr_init: bool = True,
        gradnorm_alpha: float = 1.5,
    ) -> None:
        super().__init__(input_shape=(input_length, 1), num_classes=num_classes)
        self.input_length = input_length
        # Copy so later mutation of the caller's list does not alter the model.
        self.conv_filters = list(conv_filters or CNN_DEFAULTS["conv_filters"])
        self.kernel_size = kernel_size
        self.pool_size = pool_size
        self.fc_units = fc_units
        self.num_fc_layers = num_fc_layers
        self.use_attention = use_attention
        self.use_gradnorm = use_gradnorm
        self.snr_init = snr_init
        self.gradnorm_alpha = gradnorm_alpha

        # Task weights (created in compile()).
        self._task_weights: Optional[tf.Variable] = None
        # Snapshot of the weights at compile time (GradNorm mutates the variable).
        self._initial_task_weights: Optional[np.ndarray] = None
        self._gradnorm_callback: Optional[GradNormCallback] = None

    def _init_task_weights(self) -> tf.Variable:
        """
        Initialize per-task loss weights.

        If snr_init is True, weights are set inversely proportional to each
        byte's peak SNR value, normalized so they sum to NUM_TASKS. Otherwise,
        uniform weights (all 1.0) are used.
        """
        if self.snr_init:
            snr_values = np.array(
                [BYTE_PEAK_SNR[i] for i in range(NUM_TASKS)], dtype=np.float32
            )
            inv_snr = 1.0 / (snr_values + 1e-8)
            weights = inv_snr * (NUM_TASKS / np.sum(inv_snr))
        else:
            weights = np.ones(NUM_TASKS, dtype=np.float32)

        return tf.Variable(weights, trainable=False, name="task_weights")

    @staticmethod
    def _make_weighted_loss(task_weights: tf.Variable, task_id: int):
        """Return a categorical cross-entropy scaled by ``task_weights[task_id]``."""

        def weighted_cce(y_true, y_pred):
            per_sample = tf.keras.losses.categorical_crossentropy(y_true, y_pred)
            # Read inside the loss so assign() calls made during fit() are
            # seen by every subsequent training step.
            weight = tf.cast(task_weights[task_id], per_sample.dtype)
            return weight * per_sample

        weighted_cce.__name__ = f"weighted_cce_byte_{task_id}"
        return weighted_cce

    def build(self) -> tf.keras.Model:
        """
        Construct the SNR-MTAN Keras model.

        The architecture uses GlobalAveragePooling1D before the FC layers in
        each task head to reduce the spatial dimension to the channel
        dimension. This keeps the model tractable for the large global input
        window (32,272 samples) while preserving the attention mechanism's
        ability to focus on byte-specific temporal regions.

        Returns:
            A Keras Model with one input and 16 outputs (one per byte), named
            ``byte_0`` … ``byte_15``.
        """
        inputs = tf.keras.Input(shape=self.input_shape, name="trace_input")

        # ── Shared Backbone ──────────────────────────────────────────────
        shared = inputs
        for block_idx, filters in enumerate(self.conv_filters):
            shared = tf.keras.layers.Conv1D(
                filters=filters,
                kernel_size=self.kernel_size,
                padding="same",
                kernel_initializer="glorot_uniform",
                name=f"shared_conv_{block_idx}",
            )(shared)
            shared = tf.keras.layers.BatchNormalization(
                name=f"shared_bn_{block_idx}"
            )(shared)
            shared = tf.keras.layers.ReLU(name=f"shared_relu_{block_idx}")(shared)
            shared = tf.keras.layers.AveragePooling1D(
                pool_size=self.pool_size,
                name=f"shared_pool_{block_idx}",
            )(shared)

        # ── Task-Specific Branches ───────────────────────────────────────
        task_outputs = []
        for task_id in range(NUM_TASKS):
            if self.use_attention:
                # Apply attention at the final conv block only (applying at
                # all 5 blocks would create 80 attention modules, which is
                # expensive; the final block captures the most abstract
                # features and is where byte-specific focus matters most).
                task_repr = SoftAttentionBlock(
                    filters=self.conv_filters[-1],
                    task_id=task_id,
                    block_id=len(self.conv_filters) - 1,
                )(shared)
            else:
                # No attention: all tasks share the same final representation.
                task_repr = shared

            # GlobalAveragePooling1D reduces (batch, time, channels) ->
            # (batch, channels). Without it, flatten would produce
            # 1008*512 = 516K features per head, making FC layers infeasible.
            x = tf.keras.layers.GlobalAveragePooling1D(
                name=f"byte_{task_id}_gap"
            )(task_repr)

            # Task-specific classification head.
            for fc_idx in range(self.num_fc_layers):
                x = tf.keras.layers.Dense(
                    self.fc_units,
                    activation="relu",
                    kernel_initializer="glorot_uniform",
                    name=f"byte_{task_id}_fc_{fc_idx}",
                )(x)

            output = tf.keras.layers.Dense(
                self.num_classes,
                activation="softmax",
                name=f"byte_{task_id}",
            )(x)
            task_outputs.append(output)

        self._model = tf.keras.Model(
            inputs=inputs,
            outputs=task_outputs,
            name="SNR_MTAN",
        )
        return self._model

    def compile(self, learning_rate: float = 1e-5) -> tf.keras.Model:
        """
        Build and compile the model with per-task weighted losses.

        Overrides the base class to set up multi-task loss weighting and
        optionally the GradNorm callback.

        When GradNorm is enabled, each output's loss reads its weight from the
        ``task_weights`` variable at every training step, so the callback's
        updates take effect. Keras treats float ``loss_weights`` as constants
        of the compiled loss, so in that mode they are not used. Without
        GradNorm, the static initial weights are passed as ``loss_weights``.

        Args:
            learning_rate: Learning rate for the RMSprop optimizer.

        Returns:
            The compiled Keras model.
        """
        if self._model is None:
            self._model = self.build()

        self._task_weights = self._init_task_weights()
        self._initial_task_weights = self._task_weights.numpy().copy()

        output_names = [f"byte_{i}" for i in range(NUM_TASKS)]
        if self.use_gradnorm:
            loss_dict = {
                name: self._make_weighted_loss(self._task_weights, i)
                for i, name in enumerate(output_names)
            }
            loss_weights_dict = None
        else:
            loss_dict = {name: "categorical_crossentropy" for name in output_names}
            loss_weights_dict = {
                name: float(self._initial_task_weights[i])
                for i, name in enumerate(output_names)
            }

        self._model.compile(
            optimizer=tf.keras.optimizers.RMSprop(learning_rate=learning_rate),
            loss=loss_dict,
            loss_weights=loss_weights_dict,
            metrics={name: ["accuracy"] for name in output_names},
        )

        if self.use_gradnorm:
            # Logged per-output losses include the task weight (see
            # _make_weighted_loss), so the callback must divide it out.
            self._gradnorm_callback = GradNormCallback(
                task_weights=self._task_weights,
                alpha=self.gradnorm_alpha,
                losses_are_weighted=True,
            )
        else:
            self._gradnorm_callback = None

        return self._model

    @property
    def task_weights(self) -> Optional[tf.Variable]:
        """Access the current task loss weights."""
        return self._task_weights

    @property
    def gradnorm_callback(self) -> Optional[GradNormCallback]:
        """Access the GradNorm callback (None if GradNorm is disabled)."""
        return self._gradnorm_callback

    def get_config(self) -> Dict[str, Any]:
        """Return architecture hyperparameters for logging."""
        config = {
            "model_type": "mtan",
            "architecture": "SNR-MTAN",
            "input_length": self.input_length,
            "num_classes": self.num_classes,
            "num_tasks": NUM_TASKS,
            "conv_filters": self.conv_filters,
            "kernel_size": self.kernel_size,
            "pool_size": self.pool_size,
            "fc_units": self.fc_units,
            "num_fc_layers": self.num_fc_layers,
            "use_attention": self.use_attention,
            "use_gradnorm": self.use_gradnorm,
            "snr_init": self.snr_init,
            "gradnorm_alpha": self.gradnorm_alpha,
        }
        if self._model is not None:
            config["total_params"] = self._model.count_params()
        if self._initial_task_weights is not None:
            config["initial_task_weights"] = self._initial_task_weights.tolist()
        if self._task_weights is not None:
            # May differ from the initial weights once GradNorm has run.
            config["current_task_weights"] = self._task_weights.numpy().tolist()
        return config