"""
Multi-Task Learning Models for ASCAD
=====================================

Implements two architectures for simultaneous 16-byte key recovery:

1. **HPS (Hard Parameter Sharing)**: Shared CNN backbone + shared FC layers
   + 16 independent softmax heads. Based on Marquet & Oswald (COSADE 2024).

2. **MTAN-Lite (Simplified Multi-Task Attention Network)**: Shared CNN
   backbone with a single soft-attention module at the final conv block
   + per-task heads. Novel contribution: SNR-guided attention initialization.

Mixed Precision Compatibility:
    When ``mixed_float16`` global policy is active, all intermediate layers
    compute in float16 for speed, but the final softmax output layers are
    explicitly set to ``dtype='float32'`` to avoid numerical instability
    in cross-entropy loss computation. This follows the recommended practice
    from Micikevicius et al. (ICLR 2018) and the TensorFlow mixed precision
    guide.

References:
    - Marquet & Oswald, "Exploring Multi-Task Learning in the Context of
      Masked AES Implementations", COSADE 2024.
    - Liu et al., "End-to-End Multi-Task Learning with Attention", CVPR 2019.
    - Prouff et al., "Study of Deep Learning Techniques for Side-Channel
      Analysis and Introduction to ASCAD Database", 2019.
    - Micikevicius et al., "Mixed Precision Training", ICLR 2018.
"""

import logging
from typing import Dict, List, Optional

import numpy as np
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers

from ..constants import (
    BYTE_PEAK_SNR,
    GLOBAL_WINDOW_SIZE,
    NUM_CLASSES,
)
from .base import BaseModel
from ..gradnorm import build_per_output_loss_metrics
from ..spectral_decoupling import get_spectral_decoupling_regularizer

logger = logging.getLogger(__name__)

NUM_TASKS = 16


# ============================================================================
# Shared helpers (used by both HPS and MTAN-Lite)
# ============================================================================

def _build_task_losses(label_smoothing: float) -> Dict[str, object]:
    """
    Build the per-output loss mapping shared by both model types.

    Args:
        label_smoothing: Label smoothing factor. When greater than zero a
            ``CategoricalCrossentropy`` instance with that smoothing is
            created per output; otherwise the string identifier
            ``"categorical_crossentropy"`` is used.

    Returns:
        Dictionary mapping "byte_i" to a loss object or loss identifier.
    """
    if label_smoothing > 0:
        return {
            f"byte_{i}": tf.keras.losses.CategoricalCrossentropy(
                label_smoothing=label_smoothing
            )
            for i in range(NUM_TASKS)
        }
    return {f"byte_{i}": "categorical_crossentropy" for i in range(NUM_TASKS)}


def _build_softmax_head(features, byte_idx: int, sd_lambda: float, sd_reg):
    """
    Attach the softmax classification head for one key byte.

    Keras applies ``activity_regularizer`` to a layer's *output*. On a
    ``Dense(..., activation="softmax")`` layer that output is the probability
    vector, not the logits, so a regularizer placed there cannot penalize
    pre-softmax logits. When spectral decoupling is enabled the head is
    therefore split into a linear logits layer (which carries the
    regularizer) followed by a separate softmax activation.

    Args:
        features: Keras tensor feeding the head.
        byte_idx: Index of the key byte; the output layer is named
            ``byte_{byte_idx}``.
        sd_lambda: Spectral decoupling strength; values <= 0 keep the
            single-layer head.
        sd_reg: Regularizer object obtained from
            ``get_spectral_decoupling_regularizer``.

    Returns:
        Keras tensor with the float32 softmax probabilities for this byte.
    """
    head_name = f"byte_{byte_idx}"

    if sd_lambda > 0:
        # NOTE: this layout uses two layers (``byte_i_logits`` + ``byte_i``),
        # so weights saved from the former single-Dense layout with spectral
        # decoupling enabled will not match by layer name.
        # The logits stay in float32 so both the regularizer and the softmax
        # operate in full precision under mixed_float16.
        logits = layers.Dense(
            NUM_CLASSES,
            activation=None,
            activity_regularizer=sd_reg,
            dtype="float32",
            name=f"{head_name}_logits",
        )(features)
        return layers.Activation(
            "softmax",
            dtype="float32",
            name=head_name,
        )(logits)

    # dtype='float32' ensures softmax outputs remain in FP32 even
    # when mixed precision (mixed_float16) is active. This prevents
    # numerical instability in cross-entropy loss computation.
    return layers.Dense(
        NUM_CLASSES,
        activation="softmax",
        activity_regularizer=sd_reg,
        dtype="float32",
        name=head_name,
    )(features)


# ============================================================================
# Shared CNN Backbone (used by both HPS and MTAN-Lite)
# ============================================================================

def _build_shared_backbone(
    input_shape: tuple,
    conv_filters: List[int],
    kernel_size: int = 11,
    pool_size: int = 2,
    dropout_rate: float = 0.0,
    name: str = "shared_backbone",
) -> tuple:
    """
    Build a shared 1D CNN backbone.

    Architecture follows the ASCAD CNNbest pattern (Prouff et al., 2019)
    but with configurable filter counts for multi-task use. Each block is
    Conv1D (ReLU) -> BatchNormalization -> AveragePooling1D, optionally
    followed by Dropout.

    Args:
        input_shape: Shape of input traces (T, 1).
        conv_filters: Number of filters per conv block.
        kernel_size: Convolution kernel size.
        pool_size: Average pooling size.
        dropout_rate: Dropout rate after each block (0 = no dropout).
        name: Name prefix for layers.

    Returns:
        Tuple of (input_tensor, backbone_output_tensor).
    """
    inp = layers.Input(shape=input_shape, name=f"{name}_input")
    x = inp

    for i, filters in enumerate(conv_filters):
        x = layers.Conv1D(
            filters,
            kernel_size,
            padding="same",
            activation="relu",
            kernel_initializer="he_uniform",
            name=f"{name}_conv{i}",
        )(x)
        x = layers.BatchNormalization(name=f"{name}_bn{i}")(x)
        x = layers.AveragePooling1D(pool_size, name=f"{name}_pool{i}")(x)
        if dropout_rate > 0:
            x = layers.Dropout(dropout_rate, name=f"{name}_drop{i}")(x)

    return inp, x


# ============================================================================
# Option A: Hard Parameter Sharing (HPS)
# ============================================================================

class HPSModel(BaseModel):
    """
    Hard Parameter Sharing multi-task model.

    Architecture (following Marquet & Oswald, COSADE 2024):
    - Shared CNN backbone (5 conv blocks)
    - Global Average Pooling
    - Shared FC layers (all tasks share the same FC weights)
    - 16 independent softmax heads (Dense(256) per byte)

    This is the "md" (maximal dependency) configuration from Marquet & Oswald,
    which they found to be the most effective for masked AES implementations.
    """

    def __init__(
        self,
        conv_filters: Optional[List[int]] = None,
        kernel_size: int = 11,
        pool_size: int = 2,
        fc_units: int = 512,
        num_fc_layers: int = 2,
        dropout_rate: float = 0.2,
        label_smoothing: float = 0.0,
        spectral_decoupling_lambda: float = 0.0,
    ) -> None:
        """
        Args:
            conv_filters: Filters per conv block.
                Default: [64, 128, 256, 256, 512].
            kernel_size: Conv kernel size.
            pool_size: Pooling size.
            fc_units: Units in shared FC layers.
            num_fc_layers: Number of shared FC layers.
            dropout_rate: Dropout rate.
            label_smoothing: Label smoothing factor (0.0 = no smoothing).
            spectral_decoupling_lambda: L2 logit regularization strength
                (0.0 = disabled). Penalizes large pre-softmax logits to
                prevent gradient starvation (Pezeshki et al., NeurIPS 2021).
        """
        if conv_filters is None:
            conv_filters = [64, 128, 256, 256, 512]

        self.conv_filters = conv_filters
        self.kernel_size = kernel_size
        self.pool_size = pool_size
        self.fc_units = fc_units
        self.num_fc_layers = num_fc_layers
        self.dropout_rate = dropout_rate
        self.label_smoothing = label_smoothing
        self.spectral_decoupling_lambda = spectral_decoupling_lambda

    def build(self) -> keras.Model:
        """
        Build the HPS multi-task Keras model.

        Returns:
            Uncompiled functional model named "hps_mtl" whose outputs are a
            dictionary keyed "byte_0" ... "byte_15".
        """
        input_shape = (GLOBAL_WINDOW_SIZE, 1)

        # Shared backbone
        inp, backbone_out = _build_shared_backbone(
            input_shape=input_shape,
            conv_filters=self.conv_filters,
            kernel_size=self.kernel_size,
            pool_size=self.pool_size,
            dropout_rate=self.dropout_rate,
            name="hps",
        )

        # Global Average Pooling (eliminates large flatten dimensions)
        x = layers.GlobalAveragePooling1D(name="hps_gap")(backbone_out)

        # Shared FC layers (all 16 tasks share these weights)
        for i in range(self.num_fc_layers):
            x = layers.Dense(
                self.fc_units,
                activation="relu",
                kernel_initializer="he_uniform",
                name=f"hps_shared_fc{i}",
            )(x)
            x = layers.BatchNormalization(name=f"hps_shared_fc_bn{i}")(x)
            x = layers.Dropout(
                self.dropout_rate, name=f"hps_shared_fc_drop{i}"
            )(x)

        # 16 independent softmax heads.
        # Spectral Decoupling: L2 activity regularizer on pre-softmax logits
        # (attached to a dedicated logits layer when enabled).
        sd_reg = get_spectral_decoupling_regularizer(
            self.spectral_decoupling_lambda
        )

        outputs = {
            f"byte_{byte_idx}": _build_softmax_head(
                x, byte_idx, self.spectral_decoupling_lambda, sd_reg
            )
            for byte_idx in range(NUM_TASKS)
        }

        model = keras.Model(inputs=inp, outputs=outputs, name="hps_mtl")

        logger.info(
            "Built HPS model: %s params, conv=%s, fc=%d×%d",
            f"{model.count_params():,}",
            self.conv_filters,
            self.num_fc_layers,
            self.fc_units,
        )
        return model

    def compile(self, learning_rate: float = 5e-4) -> keras.Model:
        """
        Compile the model with Adam optimizer and equal task weights.

        Args:
            learning_rate: Learning rate passed to the Adam optimizer.

        Returns:
            Freshly built and compiled Keras model.
        """
        model = self.build()

        losses = _build_task_losses(self.label_smoothing)

        # Equal weights for all tasks (Marquet & Oswald finding)
        loss_weights = {f"byte_{i}": 1.0 for i in range(NUM_TASKS)}

        # Per-output loss metrics for GradNorm (zero extra memory)
        metrics = build_per_output_loss_metrics(
            label_smoothing=self.label_smoothing
        )

        model.compile(
            optimizer=keras.optimizers.Adam(learning_rate=learning_rate),
            loss=losses,
            loss_weights=loss_weights,
            metrics=metrics,
        )
        return model

    def get_config(self) -> Dict:
        """Return the constructor hyper-parameters plus a model-type tag."""
        return {
            "model_type": "hps",
            "conv_filters": self.conv_filters,
            "kernel_size": self.kernel_size,
            "pool_size": self.pool_size,
            "fc_units": self.fc_units,
            "num_fc_layers": self.num_fc_layers,
            "dropout_rate": self.dropout_rate,
            "label_smoothing": self.label_smoothing,
            "spectral_decoupling_lambda": self.spectral_decoupling_lambda,
        }


# ============================================================================
# Option B: Simplified MTAN (MTAN-Lite)
# ============================================================================

class SoftAttentionBlock(layers.Layer):
    """
    Soft attention module applied to a single feature map.

    Implements the attention mechanism from Liu et al. (CVPR 2019),
    simplified for 1D signals:

        attention_mask = sigmoid(Conv1D_1x1(Conv1D_1x1(features)))
        attended = features * attention_mask

    Each task gets its own attention block to learn task-specific
    feature selection from the shared representation.
    """

    def __init__(self, channels: int, bottleneck_ratio: int = 4, **kwargs):
        """
        Args:
            channels: Number of channels of the attended feature map; the
                mask is projected back to this width.
            bottleneck_ratio: Channel compression ratio of the first 1x1
                convolution. The bottleneck width never drops below 16.
            **kwargs: Forwarded to ``keras.layers.Layer``.
        """
        super().__init__(**kwargs)
        self.channels = channels
        # Kept so that get_config() can round-trip the constructor arguments.
        self.bottleneck_ratio = bottleneck_ratio
        self.bottleneck = max(channels // bottleneck_ratio, 16)

    def build(self, input_shape):
        """Create the two 1x1 convolutions that produce the attention mask."""
        self.conv_down = layers.Conv1D(
            self.bottleneck,
            1,
            padding="same",
            activation="relu",
            kernel_initializer="he_uniform",
        )
        self.conv_up = layers.Conv1D(
            self.channels,
            1,
            padding="same",
            activation="sigmoid",
            kernel_initializer="glorot_uniform",
        )
        super().build(input_shape)

    def call(self, inputs, training=None):
        """Return ``inputs`` scaled element-wise by the sigmoid mask."""
        att = self.conv_down(inputs)
        att = self.conv_up(att)
        return inputs * att

    def get_config(self) -> Dict:
        """
        Return the layer configuration including constructor arguments.

        Required so that models containing this custom layer can be
        serialized and re-created via ``from_config``.
        """
        config = super().get_config()
        config.update(
            {
                "channels": self.channels,
                "bottleneck_ratio": self.bottleneck_ratio,
            }
        )
        return config


class MTANLiteModel(BaseModel):
    """
    Simplified Multi-Task Attention Network (MTAN-Lite).

    Novel contribution: Applies soft attention at the FINAL conv block only
    (not all blocks as in the original MTAN), with optional SNR-guided
    initialization of task weights.

    Architecture:
    - Shared CNN backbone (5 conv blocks, no attention)
    - 16 task-specific soft attention modules on the final feature map
    - Per-task: GlobalAveragePooling → FC(256) → Softmax(256)

    This is a middle ground between the full MTAN (too heavy) and
    simple HPS (no task-specific feature selection). The hypothesis is
    that task-specific attention at the final layer allows each byte to
    focus on its relevant POI region within the global window.
    """

    def __init__(
        self,
        conv_filters: Optional[List[int]] = None,
        kernel_size: int = 11,
        pool_size: int = 2,
        head_fc_units: int = 256,
        dropout_rate: float = 0.2,
        snr_init: bool = False,
        bottleneck_ratio: int = 4,
        label_smoothing: float = 0.0,
        spectral_decoupling_lambda: float = 0.0,
    ) -> None:
        """
        Args:
            conv_filters: Filters per conv block.
                Default: [64, 128, 256, 256, 512].
            kernel_size: Conv kernel size.
            pool_size: Pooling size.
            head_fc_units: FC units in each task head.
            dropout_rate: Dropout rate.
            snr_init: Whether to use SNR-guided loss weight initialization.
            bottleneck_ratio: Attention bottleneck compression ratio.
            label_smoothing: Label smoothing factor (0.0 = no smoothing).
            spectral_decoupling_lambda: L2 logit regularization strength
                (0.0 = disabled).
        """
        if conv_filters is None:
            conv_filters = [64, 128, 256, 256, 512]

        self.conv_filters = conv_filters
        self.kernel_size = kernel_size
        self.pool_size = pool_size
        self.head_fc_units = head_fc_units
        self.dropout_rate = dropout_rate
        self.snr_init = snr_init
        self.bottleneck_ratio = bottleneck_ratio
        self.label_smoothing = label_smoothing
        self.spectral_decoupling_lambda = spectral_decoupling_lambda
        # Populated by _compute_snr_weights(); None until then.
        self._task_weights = None

    def _compute_snr_weights(self) -> Dict[str, float]:
        """
        Compute task weights inversely proportional to SNR.

        Bytes with lower SNR (harder to attack) get higher weight,
        encouraging the model to allocate more capacity to difficult bytes.
        The raw weight array is also cached on ``self._task_weights``.

        Returns:
            Dictionary mapping "byte_i" to weight value.
        """
        snr_values = np.array([BYTE_PEAK_SNR[i] for i in range(NUM_TASKS)])
        # Inverse SNR: lower SNR → higher weight (epsilon guards against /0)
        inv_snr = 1.0 / (snr_values + 1e-6)
        # Normalize so mean weight = 1.0
        weights = inv_snr / inv_snr.mean()
        self._task_weights = weights
        return {f"byte_{i}": float(weights[i]) for i in range(NUM_TASKS)}

    def build(self) -> keras.Model:
        """
        Build the MTAN-Lite multi-task Keras model.

        Returns:
            Uncompiled functional model named "mtan_lite" whose outputs are
            a dictionary keyed "byte_0" ... "byte_15".
        """
        input_shape = (GLOBAL_WINDOW_SIZE, 1)
        final_channels = self.conv_filters[-1]

        # Shared backbone (no attention in backbone)
        inp, backbone_out = _build_shared_backbone(
            input_shape=input_shape,
            conv_filters=self.conv_filters,
            kernel_size=self.kernel_size,
            pool_size=self.pool_size,
            dropout_rate=self.dropout_rate,
            name="mtan_lite",
        )

        # Per-task: attention → GAP → FC → softmax
        # Spectral Decoupling: L2 activity regularizer on pre-softmax logits
        # (attached to a dedicated logits layer when enabled).
        sd_reg = get_spectral_decoupling_regularizer(
            self.spectral_decoupling_lambda
        )

        outputs = {}
        for byte_idx in range(NUM_TASKS):
            # Task-specific soft attention on the final feature map
            att = SoftAttentionBlock(
                channels=final_channels,
                bottleneck_ratio=self.bottleneck_ratio,
                name=f"att_byte_{byte_idx}",
            )(backbone_out)

            # Global Average Pooling
            x = layers.GlobalAveragePooling1D(
                name=f"gap_byte_{byte_idx}"
            )(att)

            # Small FC head
            x = layers.Dense(
                self.head_fc_units,
                activation="relu",
                kernel_initializer="he_uniform",
                name=f"fc_byte_{byte_idx}",
            )(x)
            x = layers.BatchNormalization(name=f"bn_byte_{byte_idx}")(x)
            x = layers.Dropout(
                self.dropout_rate, name=f"drop_byte_{byte_idx}"
            )(x)

            # Softmax output (with optional spectral decoupling)
            outputs[f"byte_{byte_idx}"] = _build_softmax_head(
                x, byte_idx, self.spectral_decoupling_lambda, sd_reg
            )

        model = keras.Model(inputs=inp, outputs=outputs, name="mtan_lite")

        logger.info(
            "Built MTAN-Lite model: %s params, conv=%s, head_fc=%d, "
            "snr_init=%s",
            f"{model.count_params():,}",
            self.conv_filters,
            self.head_fc_units,
            self.snr_init,
        )
        return model

    def compile(self, learning_rate: float = 5e-4) -> keras.Model:
        """
        Compile the model with Adam optimizer.

        Task loss weights are SNR-guided when ``snr_init`` is set and
        uniform (1.0) otherwise.

        Args:
            learning_rate: Learning rate passed to the Adam optimizer.

        Returns:
            Freshly built and compiled Keras model.
        """
        model = self.build()

        losses = _build_task_losses(self.label_smoothing)

        if self.snr_init:
            loss_weights = self._compute_snr_weights()
            # The placeholders are %d, so log integer byte indices rather
            # than the "byte_i" string keys (which would make the logging
            # module report a formatting error instead of the message).
            weights = self._task_weights
            min_idx = int(np.argmin(weights))
            max_idx = int(np.argmax(weights))
            logger.info(
                "SNR-guided weights: min=%.2f (byte %d), max=%.2f (byte %d)",
                float(weights[min_idx]),
                min_idx,
                float(weights[max_idx]),
                max_idx,
            )
        else:
            loss_weights = {f"byte_{i}": 1.0 for i in range(NUM_TASKS)}

        # Per-output loss metrics for GradNorm (zero extra memory)
        metrics = build_per_output_loss_metrics(
            label_smoothing=self.label_smoothing
        )

        model.compile(
            optimizer=keras.optimizers.Adam(learning_rate=learning_rate),
            loss=losses,
            loss_weights=loss_weights,
            metrics=metrics,
        )
        return model

    def get_config(self) -> Dict:
        """
        Return the constructor hyper-parameters plus a model-type tag.

        ``task_weights`` holds the SNR-derived weights as a plain list once
        they have been computed, and ``None`` before that.
        """
        return {
            "model_type": "mtan_lite",
            "conv_filters": self.conv_filters,
            "kernel_size": self.kernel_size,
            "pool_size": self.pool_size,
            "head_fc_units": self.head_fc_units,
            "dropout_rate": self.dropout_rate,
            "snr_init": self.snr_init,
            "bottleneck_ratio": self.bottleneck_ratio,
            "label_smoothing": self.label_smoothing,
            "spectral_decoupling_lambda": self.spectral_decoupling_lambda,
            "task_weights": (
                self._task_weights.tolist()
                if self._task_weights is not None
                else None
            ),
        }