lemousehunter
feat: training speed optimizations — mixed precision, vectorized augmentation, cached eval predictions
1fe1a19 | """ | |
| Multi-Task Learning Models for ASCAD | |
| ===================================== | |
| Implements two architectures for simultaneous 16-byte key recovery: | |
| 1. **HPS (Hard Parameter Sharing)**: Shared CNN backbone + shared FC layers + | |
| 16 independent softmax heads. Based on Marquet & Oswald (COSADE 2024). | |
| 2. **MTAN-Lite (Simplified Multi-Task Attention Network)**: Shared CNN backbone | |
| with task-specific soft-attention modules (one per byte) at the final conv | |
| block only + per-task heads. Novel contribution: SNR-guided task loss-weight initialization. | |
| Mixed Precision Compatibility: | |
| When ``mixed_float16`` global policy is active, all intermediate layers | |
| compute in float16 for speed, but the final softmax output layers are | |
| explicitly set to ``dtype='float32'`` to avoid numerical instability in | |
| cross-entropy loss computation. This follows the recommended practice | |
| from Micikevicius et al. (ICLR 2018) and the TensorFlow mixed precision | |
| guide. | |
| References: | |
| - Marquet & Oswald, "Exploring Multi-Task Learning in the Context of | |
| Masked AES Implementations", COSADE 2024. | |
| - Liu et al., "End-to-End Multi-Task Learning with Attention", CVPR 2019. | |
| - Prouff et al., "Study of Deep Learning Techniques for Side-Channel | |
| Analysis and Introduction to ASCAD Database", 2019. | |
| - Micikevicius et al., "Mixed Precision Training", ICLR 2018. | |
| """ | |
| import logging | |
| from typing import Dict, List, Optional | |
| import numpy as np | |
| import tensorflow as tf | |
| from tensorflow import keras | |
| from tensorflow.keras import layers | |
| from ..constants import ( | |
| BYTE_PEAK_SNR, | |
| GLOBAL_WINDOW_SIZE, | |
| NUM_CLASSES, | |
| ) | |
| from .base import BaseModel | |
| from ..gradnorm import build_per_output_loss_metrics | |
| from ..spectral_decoupling import get_spectral_decoupling_regularizer | |
# Number of simultaneous prediction tasks: one classifier head per key byte.
NUM_TASKS = 16

logger = logging.getLogger(__name__)
| # ============================================================================ | |
| # Shared CNN Backbone (used by both HPS and MTAN-Lite) | |
| # ============================================================================ | |
def _build_shared_backbone(
    input_shape: tuple,
    conv_filters: List[int],
    kernel_size: int = 11,
    pool_size: int = 2,
    dropout_rate: float = 0.0,
    name: str = "shared_backbone",
) -> tuple:
    """
    Build the 1D convolutional feature extractor shared by all task heads.

    One block is created per entry of ``conv_filters``, each consisting of
    Conv1D ("same" padding, ReLU, He-uniform init) -> BatchNormalization ->
    AveragePooling1D, optionally followed by Dropout. The layout follows the
    ASCAD CNNbest pattern (Prouff et al., 2019) with configurable filter
    counts for multi-task use.

    Args:
        input_shape: Shape of a single input trace, e.g. (T, 1).
        conv_filters: Filter count of each conv block, in order.
        kernel_size: Width of every convolution kernel.
        pool_size: Average-pooling window of every block.
        dropout_rate: Dropout applied after each block (0 = no dropout).
        name: Prefix used for every layer name.

    Returns:
        Tuple ``(input_tensor, backbone_output_tensor)``: the Keras Input and
        the output of the last block.
    """
    trace_input = layers.Input(shape=input_shape, name=f"{name}_input")
    use_dropout = dropout_rate > 0

    features = trace_input
    for block_idx, n_filters in enumerate(conv_filters):
        conv = layers.Conv1D(
            n_filters,
            kernel_size,
            padding="same",
            activation="relu",
            kernel_initializer="he_uniform",
            name=f"{name}_conv{block_idx}",
        )
        features = conv(features)
        features = layers.BatchNormalization(name=f"{name}_bn{block_idx}")(features)
        features = layers.AveragePooling1D(
            pool_size, name=f"{name}_pool{block_idx}"
        )(features)
        if use_dropout:
            features = layers.Dropout(
                dropout_rate, name=f"{name}_drop{block_idx}"
            )(features)

    return trace_input, features
| # ============================================================================ | |
| # Option A: Hard Parameter Sharing (HPS) | |
| # ============================================================================ | |
class HPSModel(BaseModel):
    """
    Hard Parameter Sharing multi-task model.

    Architecture (following Marquet & Oswald, COSADE 2024):
        - Shared CNN backbone (5 conv blocks with the default filters)
        - Global Average Pooling
        - Shared FC layers (all tasks share the same FC weights)
        - 16 independent softmax heads (Dense(NUM_CLASSES) per byte)

    This is the "md" (maximal dependency) configuration from Marquet & Oswald,
    which they found to be the most effective for masked AES implementations.
    """

    def __init__(
        self,
        conv_filters: Optional[List[int]] = None,
        kernel_size: int = 11,
        pool_size: int = 2,
        fc_units: int = 512,
        num_fc_layers: int = 2,
        dropout_rate: float = 0.2,
        label_smoothing: float = 0.0,
        spectral_decoupling_lambda: float = 0.0,
    ) -> None:
        """
        Args:
            conv_filters: Filters per conv block. Default: [64, 128, 256, 256, 512].
            kernel_size: Conv kernel size.
            pool_size: Pooling size.
            fc_units: Units in shared FC layers.
            num_fc_layers: Number of shared FC layers.
            dropout_rate: Dropout rate.
            label_smoothing: Label smoothing factor for every per-byte
                categorical cross-entropy (0.0 = no smoothing).
            spectral_decoupling_lambda: L2 logit regularization strength
                (0.0 = disabled). Penalizes large pre-softmax logits to
                prevent gradient starvation (Pezeshki et al., NeurIPS 2021).
        """
        if conv_filters is None:
            conv_filters = [64, 128, 256, 256, 512]
        self.conv_filters = conv_filters
        self.kernel_size = kernel_size
        self.pool_size = pool_size
        self.fc_units = fc_units
        self.num_fc_layers = num_fc_layers
        self.dropout_rate = dropout_rate
        self.label_smoothing = label_smoothing
        self.spectral_decoupling_lambda = spectral_decoupling_lambda

    @staticmethod
    def _softmax_head(x, byte_idx: int, sd_reg, sd_on_logits: bool):
        """Attach the float32 softmax classifier for one key byte to ``x``.

        Keras applies ``activity_regularizer`` to a layer's *output*, i.e.
        after its activation. A ``Dense(activation="softmax")`` layer would
        therefore regularize probabilities, not logits. When spectral
        decoupling is enabled (``sd_on_logits``), the logits Dense and the
        softmax are split into two layers so the penalty lands on the
        pre-softmax logits. Otherwise the original single-layer head is kept
        unchanged.

        Both layers use ``dtype="float32"`` so the head stays in FP32 under a
        ``mixed_float16`` policy. The output layer is always named
        ``byte_{byte_idx}``, so output/loss keys are unchanged.
        NOTE(review): with SD enabled, ``model.get_layer(f"byte_{i}")``
        returns the Activation layer; the Dense is ``byte_{i}_logits``.
        Confirm no caller relies on the former being a Dense.
        """
        if sd_on_logits:
            logits = layers.Dense(
                NUM_CLASSES,
                activity_regularizer=sd_reg,
                dtype="float32",
                name=f"byte_{byte_idx}_logits",
            )(x)
            return layers.Activation(
                "softmax", dtype="float32", name=f"byte_{byte_idx}"
            )(logits)
        return layers.Dense(
            NUM_CLASSES,
            activation="softmax",
            activity_regularizer=sd_reg,
            dtype="float32",
            name=f"byte_{byte_idx}",
        )(x)

    def build(self) -> keras.Model:
        """Build the HPS multi-task Keras model.

        Returns:
            An uncompiled ``keras.Model`` with one input trace of shape
            ``(GLOBAL_WINDOW_SIZE, 1)`` and a dict of outputs keyed
            ``"byte_0"`` … ``"byte_15"``.
        """
        input_shape = (GLOBAL_WINDOW_SIZE, 1)

        # Shared backbone
        inp, backbone_out = _build_shared_backbone(
            input_shape=input_shape,
            conv_filters=self.conv_filters,
            kernel_size=self.kernel_size,
            pool_size=self.pool_size,
            dropout_rate=self.dropout_rate,
            name="hps",
        )

        # Global Average Pooling (eliminates large flatten dimensions)
        x = layers.GlobalAveragePooling1D(name="hps_gap")(backbone_out)

        # Shared FC layers (all 16 tasks share these weights)
        for i in range(self.num_fc_layers):
            x = layers.Dense(
                self.fc_units,
                activation="relu",
                kernel_initializer="he_uniform",
                name=f"hps_shared_fc{i}",
            )(x)
            x = layers.BatchNormalization(name=f"hps_shared_fc_bn{i}")(x)
            x = layers.Dropout(self.dropout_rate, name=f"hps_shared_fc_drop{i}")(x)

        # 16 independent softmax heads.
        # Spectral Decoupling: L2 activity regularizer on the pre-softmax
        # logits (see _softmax_head for why the layers are split).
        sd_reg = get_spectral_decoupling_regularizer(
            self.spectral_decoupling_lambda
        )
        sd_on_logits = self.spectral_decoupling_lambda > 0
        outputs = {
            f"byte_{byte_idx}": self._softmax_head(x, byte_idx, sd_reg, sd_on_logits)
            for byte_idx in range(NUM_TASKS)
        }

        model = keras.Model(inputs=inp, outputs=outputs, name="hps_mtl")
        logger.info(
            "Built HPS model: %s params, conv=%s, fc=%d×%d",
            f"{model.count_params():,}",
            self.conv_filters,
            self.num_fc_layers,
            self.fc_units,
        )
        return model

    def compile(self, learning_rate: float = 5e-4) -> keras.Model:
        """Build and compile the model with Adam and equal task weights.

        Args:
            learning_rate: Adam learning rate.

        Returns:
            The compiled ``keras.Model``.
        """
        model = self.build()

        if self.label_smoothing > 0:
            losses = {
                f"byte_{i}": tf.keras.losses.CategoricalCrossentropy(
                    label_smoothing=self.label_smoothing
                )
                for i in range(NUM_TASKS)
            }
        else:
            losses = {f"byte_{i}": "categorical_crossentropy" for i in range(NUM_TASKS)}

        # Equal weights for all tasks (Marquet & Oswald finding)
        loss_weights = {f"byte_{i}": 1.0 for i in range(NUM_TASKS)}

        # Per-output loss metrics for GradNorm (zero extra memory)
        metrics = build_per_output_loss_metrics(
            label_smoothing=self.label_smoothing
        )

        model.compile(
            optimizer=keras.optimizers.Adam(learning_rate=learning_rate),
            loss=losses,
            loss_weights=loss_weights,
            metrics=metrics,
        )
        return model

    def get_config(self) -> Dict:
        """Return the constructor hyper-parameters plus ``model_type``."""
        return {
            "model_type": "hps",
            "conv_filters": self.conv_filters,
            "kernel_size": self.kernel_size,
            "pool_size": self.pool_size,
            "fc_units": self.fc_units,
            "num_fc_layers": self.num_fc_layers,
            "dropout_rate": self.dropout_rate,
            "label_smoothing": self.label_smoothing,
            "spectral_decoupling_lambda": self.spectral_decoupling_lambda,
        }
| # ============================================================================ | |
| # Option B: Simplified MTAN (MTAN-Lite) | |
| # ============================================================================ | |
class SoftAttentionBlock(layers.Layer):
    """
    Soft attention module applied to a single feature map.

    Implements the attention mechanism from Liu et al. (CVPR 2019),
    simplified for 1D signals:

        attention_mask = sigmoid(Conv1D_1x1(relu(Conv1D_1x1(features))))
        attended = features * attention_mask

    Each task gets its own attention block to learn task-specific feature
    selection from the shared representation.

    Args:
        channels: Channel count of the mask produced by the second 1x1 conv.
            It must match the input's channel count for the element-wise
            product.
        bottleneck_ratio: Compression factor of the first 1x1 conv. The
            bottleneck width is ``max(channels // bottleneck_ratio, 16)``.
        **kwargs: Standard Keras ``Layer`` keyword arguments (e.g. ``name``).
    """

    def __init__(self, channels: int, bottleneck_ratio: int = 4, **kwargs):
        super().__init__(**kwargs)
        self.channels = channels
        # Kept so get_config() can reproduce the constructor call.
        self.bottleneck_ratio = bottleneck_ratio
        self.bottleneck = max(channels // bottleneck_ratio, 16)

    def build(self, input_shape):
        self.conv_down = layers.Conv1D(
            self.bottleneck, 1, padding="same", activation="relu",
            kernel_initializer="he_uniform",
        )
        self.conv_up = layers.Conv1D(
            self.channels, 1, padding="same", activation="sigmoid",
            kernel_initializer="glorot_uniform",
        )
        super().build(input_shape)

    def call(self, inputs, training=None):
        att = self.conv_down(inputs)
        att = self.conv_up(att)
        # Gate the shared features with the task-specific mask in (0, 1).
        return inputs * att

    def get_config(self) -> Dict:
        """Return the layer config, including the constructor arguments.

        Without this override, Keras cannot serialize or re-instantiate a
        layer whose ``__init__`` takes extra required arguments, e.g. when
        saving the full model.
        """
        config = super().get_config()
        config.update(
            {
                "channels": self.channels,
                "bottleneck_ratio": self.bottleneck_ratio,
            }
        )
        return config
class MTANLiteModel(BaseModel):
    """
    Simplified Multi-Task Attention Network (MTAN-Lite).

    Novel contribution: Applies soft attention at the FINAL conv block only
    (not all blocks as in the original MTAN), with optional SNR-guided
    initialization of task loss weights.

    Architecture:
        - Shared CNN backbone (5 conv blocks with the default filters, no attention)
        - 16 task-specific soft attention modules on the final feature map
        - Per-task: GlobalAveragePooling → FC(head_fc_units) → Softmax(NUM_CLASSES)

    This is a middle ground between the full MTAN (too heavy) and simple HPS
    (no task-specific feature selection). The hypothesis is that task-specific
    attention at the final layer allows each byte to focus on its relevant
    POI region within the global window.
    """

    def __init__(
        self,
        conv_filters: Optional[List[int]] = None,
        kernel_size: int = 11,
        pool_size: int = 2,
        head_fc_units: int = 256,
        dropout_rate: float = 0.2,
        snr_init: bool = False,
        bottleneck_ratio: int = 4,
        label_smoothing: float = 0.0,
        spectral_decoupling_lambda: float = 0.0,
    ) -> None:
        """
        Args:
            conv_filters: Filters per conv block. Default: [64, 128, 256, 256, 512].
            kernel_size: Conv kernel size.
            pool_size: Pooling size.
            head_fc_units: FC units in each task head.
            dropout_rate: Dropout rate.
            snr_init: Whether to use SNR-guided loss weight initialization.
            bottleneck_ratio: Attention bottleneck compression ratio.
            label_smoothing: Label smoothing factor (0.0 = no smoothing).
            spectral_decoupling_lambda: L2 logit regularization strength
                (0.0 = disabled), applied to each head's pre-softmax logits.
        """
        if conv_filters is None:
            conv_filters = [64, 128, 256, 256, 512]
        self.conv_filters = conv_filters
        self.kernel_size = kernel_size
        self.pool_size = pool_size
        self.head_fc_units = head_fc_units
        self.dropout_rate = dropout_rate
        self.snr_init = snr_init
        self.bottleneck_ratio = bottleneck_ratio
        self.label_smoothing = label_smoothing
        self.spectral_decoupling_lambda = spectral_decoupling_lambda
        # Set by _compute_snr_weights(); stays None unless compile() runs
        # with snr_init=True.
        self._task_weights: Optional[np.ndarray] = None

    def _compute_snr_weights(self) -> Dict[str, float]:
        """
        Compute task weights inversely proportional to SNR.

        Bytes with lower SNR (harder to attack) get higher weight,
        encouraging the model to allocate more capacity to difficult bytes.
        The weights are normalized to mean 1.0 and cached on
        ``self._task_weights``.

        Returns:
            Dictionary mapping "byte_i" to weight value.
        """
        snr_values = np.array([BYTE_PEAK_SNR[i] for i in range(NUM_TASKS)])
        # Inverse SNR: lower SNR → higher weight (epsilon guards division by 0)
        inv_snr = 1.0 / (snr_values + 1e-6)
        # Normalize so mean weight = 1.0
        weights = inv_snr / inv_snr.mean()
        self._task_weights = weights
        return {f"byte_{i}": float(weights[i]) for i in range(NUM_TASKS)}

    @staticmethod
    def _softmax_head(x, byte_idx: int, sd_reg, sd_on_logits: bool):
        """Attach the float32 softmax classifier for one key byte to ``x``.

        Keras applies ``activity_regularizer`` to a layer's *output*, i.e.
        after its activation. A ``Dense(activation="softmax")`` layer would
        therefore regularize probabilities, not logits. When spectral
        decoupling is enabled (``sd_on_logits``), the logits Dense and the
        softmax are split into two layers so the penalty lands on the
        pre-softmax logits. Otherwise the original single-layer head is kept
        unchanged.

        Both layers use ``dtype="float32"`` so the head stays in FP32 under a
        ``mixed_float16`` policy. The output layer is always named
        ``byte_{byte_idx}``, so output/loss keys are unchanged.
        NOTE(review): with SD enabled, ``model.get_layer(f"byte_{i}")``
        returns the Activation layer; the Dense is ``byte_{i}_logits``.
        Confirm no caller relies on the former being a Dense.
        """
        if sd_on_logits:
            logits = layers.Dense(
                NUM_CLASSES,
                activity_regularizer=sd_reg,
                dtype="float32",
                name=f"byte_{byte_idx}_logits",
            )(x)
            return layers.Activation(
                "softmax", dtype="float32", name=f"byte_{byte_idx}"
            )(logits)
        return layers.Dense(
            NUM_CLASSES,
            activation="softmax",
            activity_regularizer=sd_reg,
            dtype="float32",
            name=f"byte_{byte_idx}",
        )(x)

    def build(self) -> keras.Model:
        """Build the MTAN-Lite multi-task Keras model.

        Returns:
            An uncompiled ``keras.Model`` with one input trace of shape
            ``(GLOBAL_WINDOW_SIZE, 1)`` and a dict of outputs keyed
            ``"byte_0"`` … ``"byte_15"``.
        """
        input_shape = (GLOBAL_WINDOW_SIZE, 1)
        final_channels = self.conv_filters[-1]

        # Shared backbone (no attention in backbone)
        inp, backbone_out = _build_shared_backbone(
            input_shape=input_shape,
            conv_filters=self.conv_filters,
            kernel_size=self.kernel_size,
            pool_size=self.pool_size,
            dropout_rate=self.dropout_rate,
            name="mtan_lite",
        )

        # Per-task: attention → GAP → FC → softmax.
        # Spectral Decoupling: L2 activity regularizer on the pre-softmax
        # logits (see _softmax_head for why the layers are split).
        sd_reg = get_spectral_decoupling_regularizer(
            self.spectral_decoupling_lambda
        )
        sd_on_logits = self.spectral_decoupling_lambda > 0
        outputs = {}
        for byte_idx in range(NUM_TASKS):
            # Task-specific soft attention on the final feature map
            att = SoftAttentionBlock(
                channels=final_channels,
                bottleneck_ratio=self.bottleneck_ratio,
                name=f"att_byte_{byte_idx}",
            )(backbone_out)

            # Global Average Pooling
            x = layers.GlobalAveragePooling1D(
                name=f"gap_byte_{byte_idx}"
            )(att)

            # Small FC head
            x = layers.Dense(
                self.head_fc_units,
                activation="relu",
                kernel_initializer="he_uniform",
                name=f"fc_byte_{byte_idx}",
            )(x)
            x = layers.BatchNormalization(name=f"bn_byte_{byte_idx}")(x)
            x = layers.Dropout(self.dropout_rate, name=f"drop_byte_{byte_idx}")(x)

            # FP32 softmax output (with optional spectral decoupling on logits)
            outputs[f"byte_{byte_idx}"] = self._softmax_head(
                x, byte_idx, sd_reg, sd_on_logits
            )

        model = keras.Model(inputs=inp, outputs=outputs, name="mtan_lite")
        logger.info(
            "Built MTAN-Lite model: %s params, conv=%s, head_fc=%d, "
            "snr_init=%s",
            f"{model.count_params():,}",
            self.conv_filters,
            self.head_fc_units,
            self.snr_init,
        )
        return model

    def compile(self, learning_rate: float = 5e-4) -> keras.Model:
        """Build and compile the model with Adam.

        Task loss weights are SNR-guided when ``snr_init`` is True, otherwise
        all 1.0.

        Args:
            learning_rate: Adam learning rate.

        Returns:
            The compiled ``keras.Model``.
        """
        model = self.build()

        if self.label_smoothing > 0:
            losses = {
                f"byte_{i}": tf.keras.losses.CategoricalCrossentropy(
                    label_smoothing=self.label_smoothing
                )
                for i in range(NUM_TASKS)
            }
        else:
            losses = {f"byte_{i}": "categorical_crossentropy" for i in range(NUM_TASKS)}

        if self.snr_init:
            loss_weights = self._compute_snr_weights()
            weights = self._task_weights
            # Log integer byte indices: the dict keys are strings such as
            # "byte_3", which cannot be rendered by a %d placeholder.
            lo = int(np.argmin(weights))
            hi = int(np.argmax(weights))
            logger.info(
                "SNR-guided weights: min=%.2f (byte %d), max=%.2f (byte %d)",
                weights[lo],
                lo,
                weights[hi],
                hi,
            )
        else:
            loss_weights = {f"byte_{i}": 1.0 for i in range(NUM_TASKS)}

        # Per-output loss metrics for GradNorm (zero extra memory)
        metrics = build_per_output_loss_metrics(
            label_smoothing=self.label_smoothing
        )

        model.compile(
            optimizer=keras.optimizers.Adam(learning_rate=learning_rate),
            loss=losses,
            loss_weights=loss_weights,
            metrics=metrics,
        )
        return model

    def get_config(self) -> Dict:
        """Return the constructor hyper-parameters, ``model_type`` and any
        SNR-derived task weights (``None`` if not computed)."""
        return {
            "model_type": "mtan_lite",
            "conv_filters": self.conv_filters,
            "kernel_size": self.kernel_size,
            "pool_size": self.pool_size,
            "head_fc_units": self.head_fc_units,
            "dropout_rate": self.dropout_rate,
            "snr_init": self.snr_init,
            "bottleneck_ratio": self.bottleneck_ratio,
            "label_smoothing": self.label_smoothing,
            "spectral_decoupling_lambda": self.spectral_decoupling_lambda,
            "task_weights": (
                self._task_weights.tolist() if self._task_weights is not None
                else None
            ),
        }