"""
Localized Multi-Input CNN (LMIC) for ASCAD
==========================================

A multi-task architecture that uses 16 separate per-byte POI inputs instead
of a single global window. Each byte receives only its 700-sample POI window,
eliminating the gradient starvation and task dominance problems documented in
HISTORY_LEDGER Sections 1-8.

Architecture (original LMIC):
    - 16 independent Input layers, each (700, 1)
    - A shared convolutional feature extractor (weight-shared across all 16)
    - 16 independent classification heads (Dense → Softmax)

Architecture (LMIC-TSBN — Task-Specific Batch Normalization):
    - 16 independent Input layers, each (700, 1)
    - Shared Conv1D weights applied to all 16 inputs
    - 16 TASK-SPECIFIC BatchNorm layers per conv layer (one per byte)
    - Optional sigmoid gates per task per filter (TSσBN soft capacity
      allocation)
    - 16 independent classification heads (Dense → Softmax)

The TSBN variant fixes the val_loss volatility observed in LMIC v5
(HISTORY_LEDGER Section 10) by preventing shared BN running statistics from
being corrupted by gradient interference between bytes.

Design rationale (from literature):
    - Marquet & Oswald (COSADE 2024): Low-level parameter sharing achieves
      0% failure rate and 100% key recovery in 25-45 epochs.
    - Zaid et al. (TCHES 2020): Compact CNNs with few filters and small
      kernels are optimal for ASCAD.
    - Perin et al. (TCHES 2022): Per-byte POI windows of ~700 samples capture
      all leakage information.
    - Suteu & Serban (ICLR 2025): Task-Specific Sigmoid Batch Normalization
      (TSσBN) eliminates gradient interference in MTL with <0.5% parameter
      overhead by replacing shared BN with per-task BN + sigmoid gates.

Mixed Precision Compatibility:
    When ``mixed_float16`` global policy is active, all intermediate layers
    compute in float16 for speed, but the final output layers (softmax, or
    sigmoid in multi-bit mode) are explicitly set to ``dtype='float32'`` to
    avoid numerical instability.

References:
    - Marquet & Oswald, "Exploring Multi-Task Learning in the Context of
      Masked AES Implementations", COSADE 2024.
    - Zaid et al., "Methodology for Efficient CNN Architectures in Profiling
      Attacks", TCHES 2020.
    - Perin et al., "Learning When to Stop: A Mutual Information Approach to
      Fight Overfitting in Profiled Side-Channel Analysis", TCHES 2022.
    - Micikevicius et al., "Mixed Precision Training", ICLR 2018.
    - Suteu & Serban, "Simplifying Multi-Task Architectures Through
      Task-Specific Normalization", ICLR 2025.
"""

import logging
from typing import Any, Dict, List, Optional

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers

from ..constants import (
    BYTE_POI_WINDOWS,
    NUM_CLASSES,
    WINDOW_SIZE,
)
from .base import BaseModel
from ..spectral_decoupling import get_spectral_decoupling_regularizer

logger = logging.getLogger(__name__)

NUM_TASKS = 16


# ---------------------------------------------------------------------------
# Private helpers shared by LMICModel and LMICTSBNModel
# ---------------------------------------------------------------------------

def _make_sd_regularizer(sd_lambda: float):
    """
    Return the spectral-decoupling regularizer, or ``None`` when disabled.

    Args:
        sd_lambda: Regularization strength; values <= 0 disable it.

    Returns:
        Whatever ``get_spectral_decoupling_regularizer`` returns for a
        positive ``sd_lambda``, otherwise ``None``.
    """
    if sd_lambda > 0:
        regularizer = get_spectral_decoupling_regularizer(sd_lambda)
        logger.info("Spectral Decoupling ENABLED: lambda=%.4f", sd_lambda)
        return regularizer
    return None


def _build_byte_head(
    features,
    byte_idx: int,
    *,
    head_units: int,
    dropout_rate: float,
    num_classes: int,
    multi_bit: bool,
    sd_reg,
):
    """
    Attach the independent Dropout → Dense(relu) → output head for one byte.

    Layer names (``byte_{i}_dropout``, ``byte_{i}_dense``, ``byte_{i}``) are
    part of the model's public naming: the loss/metric dicts built in
    ``compile()`` are keyed by the output layer name ``byte_{i}``.

    Args:
        features: Symbolic feature tensor produced by the conv trunk.
        byte_idx: Index of the key byte this head predicts.
        head_units: Width of the hidden Dense layer.
        dropout_rate: Dropout rate applied to ``features``.
        num_classes: Output width in standard (softmax) mode.
        multi_bit: If True, emit 8 sigmoid outputs instead of a softmax.
        sd_reg: Regularizer for the output layer, or ``None``.

    Returns:
        The symbolic output tensor of the head.
    """
    x = layers.Dropout(dropout_rate, name=f"byte_{byte_idx}_dropout")(features)
    x = layers.Dense(
        head_units,
        activation="relu",
        name=f"byte_{byte_idx}_dense",
    )(x)

    if multi_bit:
        # Multi-bit mode: 8 independent binary outputs per byte.
        units, activation = 8, "sigmoid"
    else:
        # Standard mode: one softmax over all classes.
        units, activation = num_classes, "softmax"

    # Output layer is pinned to float32 for mixed precision safety.
    # NOTE(review): the regularizer is attached as ``kernel_regularizer``
    # (i.e. it sees the kernel, not the logits) — confirm this matches what
    # get_spectral_decoupling_regularizer expects.
    return layers.Dense(
        units,
        activation=activation,
        dtype="float32",
        kernel_regularizer=sd_reg,
        name=f"byte_{byte_idx}",
    )(x)


def _per_byte_losses_and_metrics(loss_fn, multi_bit: bool):
    """
    Build the ``loss`` and ``metrics`` dicts keyed by output name ``byte_{i}``.

    Args:
        loss_fn: Loss object reused for every one of the NUM_TASKS outputs.
        multi_bit: Selects ``binary_accuracy`` instead of ``accuracy``.

    Returns:
        Tuple ``(losses, metrics)`` of dicts with one entry per byte.
    """
    metric_name = "binary_accuracy" if multi_bit else "accuracy"
    output_names = [f"byte_{i}" for i in range(NUM_TASKS)]
    losses = {name: loss_fn for name in output_names}
    metrics = {name: [metric_name] for name in output_names}
    return losses, metrics


# ---------------------------------------------------------------------------
# Sigmoid Gate Layer (for TSσBN)
# ---------------------------------------------------------------------------

class SigmoidGate(layers.Layer):
    """
    Learnable per-filter sigmoid gate for soft capacity allocation.

    Each filter gets a scalar gate parameter initialized to 0
    (sigmoid(0)=0.5), so all filters start equally active. During training,
    the gate learns to suppress irrelevant filters for this task, enabling
    soft specialization.

    This implements the σ-gate from Suteu & Serban (ICLR 2025):
        output = σ(gate) * input

    Args:
        num_filters: Number of filters (channels) to gate.
        init_value: Initial gate value (0.0 → sigmoid=0.5, all filters
            active).
    """

    def __init__(self, num_filters: int, init_value: float = 0.0, **kwargs):
        super().__init__(**kwargs)
        self.num_filters = num_filters
        self.init_value = init_value

    def build(self, input_shape):
        """Create the trainable gate vector of shape ``(num_filters,)``."""
        self.gate = self.add_weight(
            name="gate",
            shape=(self.num_filters,),
            initializer=tf.keras.initializers.Constant(self.init_value),
            trainable=True,
        )
        super().build(input_shape)

    def call(self, inputs):
        """Scale ``inputs`` by ``sigmoid(gate)``, broadcast over the last axis."""
        # gate shape: (num_filters,) → broadcast over (batch, time, filters)
        return inputs * tf.sigmoid(self.gate)

    def get_config(self):
        """Return the layer config including the constructor arguments."""
        config = super().get_config()
        config.update({
            "num_filters": self.num_filters,
            "init_value": self.init_value,
        })
        return config


# ---------------------------------------------------------------------------
# Original LMIC Model (shared BN — kept for backward compatibility)
# ---------------------------------------------------------------------------

class LMICModel(BaseModel):
    """
    Localized Multi-Input CNN (LMIC) for simultaneous 16-byte key recovery.

    Instead of processing a single 32,272-sample global window, this model
    receives 16 separate 700-sample inputs (one per byte's POI window). A
    shared convolutional feature extractor processes each input identically,
    then 16 independent Dense+Softmax heads produce per-byte predictions.

    This design eliminates:
        - Gradient starvation: each byte has its own gradient path
        - Task dominance: no shared FC bottleneck
        - GAP information loss: global average pooling is applied only
          within a single byte's own POI window
        - Domain shift memorization: each input is only 700 samples

    Args:
        input_length: Length of each per-byte POI window (default: 700).
        num_classes: Number of output classes (256 for AES S-Box).
        conv_filters: List of filter counts for shared conv blocks
            (defaults to ``[64, 128, 256]`` when None or empty).
        kernel_size: Kernel size for Conv1D layers.
        pool_size: Pool size for AveragePooling1D layers.
        head_units: Number of units in each per-byte head's Dense layer.
        dropout_rate: Dropout rate for regularization.
        label_smoothing: Label smoothing factor for cross-entropy loss.
        spectral_decoupling_lambda: L2 logit regularization strength
            (enabled only when > 0).
        focal_loss: Use FocalCategoricalCrossentropy instead of
            CategoricalCrossentropy (ignored when ``multi_bit`` is True).
        focal_gamma: ``gamma`` passed to the focal loss.
        clipnorm: Adam ``clipnorm``; applied only when > 0.
        jit_compile: Forwarded to ``Model.compile(jit_compile=...)``.
        multi_bit: If True, each head emits 8 sigmoid outputs trained with
            BinaryCrossentropy instead of a ``num_classes`` softmax.
    """

    def __init__(
        self,
        input_length: int = WINDOW_SIZE,
        num_classes: int = NUM_CLASSES,
        conv_filters: Optional[List[int]] = None,
        kernel_size: int = 11,
        pool_size: int = 2,
        head_units: int = 128,
        dropout_rate: float = 0.2,
        label_smoothing: float = 0.0,
        spectral_decoupling_lambda: float = 0.0,
        focal_loss: bool = False,
        focal_gamma: float = 2.0,
        clipnorm: float = 0.0,
        jit_compile: bool = False,
        multi_bit: bool = False,
    ) -> None:
        # BaseModel expects input_shape, but LMIC has multiple inputs.
        # We store the per-byte input shape and override build().
        super().__init__(
            input_shape=(input_length, 1),
            num_classes=num_classes,
        )
        self.input_length = input_length
        self.conv_filters = conv_filters or [64, 128, 256]
        self.kernel_size = kernel_size
        self.pool_size = pool_size
        self.head_units = head_units
        self.dropout_rate = dropout_rate
        self.label_smoothing = label_smoothing
        self.spectral_decoupling_lambda = spectral_decoupling_lambda
        self.focal_loss = focal_loss
        self.focal_gamma = focal_gamma
        self.clipnorm = clipnorm
        self.jit_compile = jit_compile
        self.multi_bit = multi_bit

    def _build_shared_conv_block(self, name_prefix: str = "shared") -> keras.Model:
        """
        Build the shared convolutional feature extractor.

        This is a compact CNN following Zaid et al. (TCHES 2020):
            Conv1D → BN → ReLU → AvgPool (repeated per filter count)
            → GlobalAveragePooling1D

        The block is instantiated once and called on each of the 16 inputs,
        so all bytes share the same convolutional weights (and, in this
        variant, the same BatchNorm layers).

        Args:
            name_prefix: Prefix used for every layer name in the block.

        Returns:
            A Keras Model that maps (batch, input_length, 1) → (batch, F),
            where F is the last entry of ``conv_filters``.
        """
        inp = layers.Input(
            shape=(self.input_length, 1),
            name=f"{name_prefix}_input",
        )

        x = inp
        for i, filters in enumerate(self.conv_filters):
            x = layers.Conv1D(
                filters=filters,
                kernel_size=self.kernel_size,
                padding="same",
                name=f"{name_prefix}_conv{i}",
            )(x)
            x = layers.BatchNormalization(name=f"{name_prefix}_bn{i}")(x)
            x = layers.Activation("relu", name=f"{name_prefix}_relu{i}")(x)
            x = layers.AveragePooling1D(
                pool_size=self.pool_size,
                name=f"{name_prefix}_pool{i}",
            )(x)

        # GlobalAveragePooling reduces (batch, time, filters) → (batch, filters)
        # This keeps the feature dimension equal to the last conv's filter count,
        # making per-byte heads compact. Unlike the MTAN-Lite architecture where
        # GAP on a 32K global window destroyed byte-specific spatial information,
        # here each input is already a 700-sample byte-specific POI window, so
        # GAP is safe — it only pools within a single byte's features.
        x = layers.GlobalAveragePooling1D(name=f"{name_prefix}_gap")(x)

        return keras.Model(inputs=inp, outputs=x, name=f"{name_prefix}_conv_block")

    def build(self) -> keras.Model:
        """
        Construct the full LMIC model with 16 inputs and 16 outputs.

        Architecture:
            16 × Input(700, 1) → shared_conv_block → per-byte head → softmax
            (or 8 sigmoid outputs per byte when ``multi_bit`` is set)

        Returns:
            A compiled-ready Keras Model with named inputs
            (``byte_{i}_input``) and named outputs (``byte_{i}``).
        """
        # Build the shared conv block (instantiated once, called 16 times)
        shared_conv = self._build_shared_conv_block()

        # Log the shared block summary
        feature_dim = shared_conv.output_shape[-1]
        logger.info(
            "LMIC shared conv block: %d params, output_dim=%d",
            shared_conv.count_params(),
            feature_dim,
        )

        # Get spectral decoupling regularizer if enabled
        sd_reg = _make_sd_regularizer(self.spectral_decoupling_lambda)

        # Build 16 input branches and heads.
        # Using dicts ensures model.predict() returns a dict with named keys.
        all_inputs = {}
        all_outputs = {}

        for byte_idx in range(NUM_TASKS):
            # Per-byte input
            byte_input = layers.Input(
                shape=(self.input_length, 1),
                name=f"byte_{byte_idx}_input",
            )
            all_inputs[f"byte_{byte_idx}_input"] = byte_input

            # Pass through shared conv block, then the independent head
            features = shared_conv(byte_input)
            all_outputs[f"byte_{byte_idx}"] = _build_byte_head(
                features,
                byte_idx,
                head_units=self.head_units,
                dropout_rate=self.dropout_rate,
                num_classes=self.num_classes,
                multi_bit=self.multi_bit,
                sd_reg=sd_reg,
            )

        # Build the multi-input, multi-output model
        model = keras.Model(
            inputs=all_inputs,
            outputs=all_outputs,
            name="LMIC" + ("_MB" if self.multi_bit else ""),
        )

        self._model = model
        return model

    def compile(self, learning_rate: float = 5e-4) -> keras.Model:
        """
        Build (if needed) and compile the LMIC model with per-byte losses.

        Uses the Adam optimizer. The same loss object is attached to each of
        the 16 output heads and is selected in this order:
            1. ``multi_bit``  → BinaryCrossentropy (+ ``binary_accuracy``)
            2. ``focal_loss`` → FocalCategoricalCrossentropy (+ ``accuracy``)
            3. otherwise      → CategoricalCrossentropy (+ ``accuracy``)

        Args:
            learning_rate: Learning rate for Adam optimizer.

        Returns:
            The compiled Keras model.
        """
        if self._model is None:
            self._model = self.build()

        if self.multi_bit:
            # Multi-bit mode: binary crossentropy for 8 independent bits
            loss_fn = keras.losses.BinaryCrossentropy(
                label_smoothing=self.label_smoothing,
            )
            logger.info(
                "LMIC Multi-Bit mode: BinaryCrossentropy, label_smoothing=%.2f",
                self.label_smoothing,
            )
        elif self.focal_loss:
            # Imported lazily, only when focal loss is actually requested.
            from ..focal_loss import FocalCategoricalCrossentropy
            loss_fn = FocalCategoricalCrossentropy(
                gamma=self.focal_gamma,
                label_smoothing=self.label_smoothing,
            )
            logger.info(
                "Focal Loss ENABLED: gamma=%.1f, label_smoothing=%.2f",
                self.focal_gamma,
                self.label_smoothing,
            )
        else:
            loss_fn = keras.losses.CategoricalCrossentropy(
                label_smoothing=self.label_smoothing,
            )

        losses, metrics_dict = _per_byte_losses_and_metrics(
            loss_fn, self.multi_bit
        )

        # Build optimizer with optional gradient clipping
        opt_kwargs = {"learning_rate": learning_rate}
        if self.clipnorm > 0:
            opt_kwargs["clipnorm"] = self.clipnorm
            logger.info("Gradient clipping ENABLED: clipnorm=%.2f", self.clipnorm)

        self._model.compile(
            optimizer=keras.optimizers.Adam(**opt_kwargs),
            loss=losses,
            metrics=metrics_dict,
            jit_compile=self.jit_compile,
        )
        if self.jit_compile:
            logger.info("XLA jit_compile ENABLED")

        return self._model

    def get_config(self) -> Dict[str, Any]:
        """Return architecture hyperparameters for logging."""
        return {
            "architecture": "LMIC",
            "input_length": self.input_length,
            "num_classes": self.num_classes,
            "conv_filters": self.conv_filters,
            "kernel_size": self.kernel_size,
            "pool_size": self.pool_size,
            "head_units": self.head_units,
            "dropout_rate": self.dropout_rate,
            "label_smoothing": self.label_smoothing,
            "spectral_decoupling_lambda": self.spectral_decoupling_lambda,
            "focal_loss": self.focal_loss,
            "focal_gamma": self.focal_gamma,
            "clipnorm": self.clipnorm,
            "jit_compile": self.jit_compile,
            "multi_bit": self.multi_bit,
            "num_tasks": NUM_TASKS,
        }


# ---------------------------------------------------------------------------
# LMIC-TSBN: Task-Specific Batch Normalization variant
# ---------------------------------------------------------------------------

class LMICTSBNModel(BaseModel):
    """
    LMIC with Task-Specific Batch Normalization (LMIC-TSBN).

    This variant replaces the shared BatchNorm layers in the LMIC conv block
    with 16 task-specific BN layers (one per byte). The Conv1D weights remain
    fully shared across all 16 tasks — only the BN parameters (γ, β, running
    mean, running variance) become per-task.

    This eliminates the gradient interference that caused val_loss volatility
    in the original LMIC (HISTORY_LEDGER Section 10), where shared BN running
    statistics were corrupted by conflicting gradient signals from different
    bytes.

    The optional sigmoid gates (TSσBN) add a learnable per-filter gate per
    task, enabling soft capacity allocation: each byte can learn to suppress
    filters it doesn't need, allowing the shared conv weights to develop both
    shared and specialized features.

    Parameter overhead:
        - Per-task BN: 16 × 2 × sum(conv_filters) extra trainable params
          (γ, β per filter per task) + 16 × 2 × sum(conv_filters)
          non-trainable (running mean/var per task).
          For [64, 128, 256]: ~28K extra params.
        - Sigmoid gates: 16 × sum(conv_filters) extra trainable params.
          For [64, 128, 256]: ~7K extra params.
        - Total overhead: <2% of model size.

    Architecture:
        16 × Input(700, 1) →
        For each byte i:
            shared_conv0(input_i) → task_bn0_i → [σ_gate0_i] → relu → avgpool →
            shared_conv1(...)     → task_bn1_i → [σ_gate1_i] → relu → avgpool →
            shared_conv2(...)     → task_bn2_i → [σ_gate2_i] → relu → avgpool →
            GAP → dropout → dense_i → softmax_i

    References:
        - Suteu & Serban, "Simplifying Multi-Task Architectures Through
          Task-Specific Normalization", ICLR 2025.
        - Zaid et al., "Methodology for Efficient CNN Architectures in
          Profiling Attacks", TCHES 2020.

    Args:
        input_length: Length of each per-byte POI window (default: 700).
        num_classes: Number of output classes (256 for AES S-Box).
        conv_filters: List of filter counts for shared conv blocks
            (defaults to ``[64, 128, 256]`` when None or empty).
        kernel_size: Kernel size for Conv1D layers.
        pool_size: Pool size for AveragePooling1D layers.
        head_units: Number of units in each per-byte head's Dense layer.
        dropout_rate: Dropout rate for regularization.
        label_smoothing: Label smoothing factor for cross-entropy loss.
        spectral_decoupling_lambda: L2 logit regularization strength
            (enabled only when > 0).
        use_sigmoid_gates: Whether to add σ-gates for soft capacity
            allocation.
        focal_loss: Use FocalCategoricalCrossentropy instead of
            CategoricalCrossentropy (ignored when ``multi_bit`` is True).
        focal_gamma: ``gamma`` passed to the focal loss.
        clipnorm: Adam ``clipnorm``; applied only when > 0.
        jit_compile: Forwarded to ``Model.compile(jit_compile=...)``.
        multi_bit: If True, each head emits 8 sigmoid outputs trained with
            BinaryCrossentropy instead of a ``num_classes`` softmax.
    """

    def __init__(
        self,
        input_length: int = WINDOW_SIZE,
        num_classes: int = NUM_CLASSES,
        conv_filters: Optional[List[int]] = None,
        kernel_size: int = 11,
        pool_size: int = 2,
        head_units: int = 128,
        dropout_rate: float = 0.2,
        label_smoothing: float = 0.0,
        spectral_decoupling_lambda: float = 0.0,
        use_sigmoid_gates: bool = False,
        focal_loss: bool = False,
        focal_gamma: float = 2.0,
        clipnorm: float = 0.0,
        jit_compile: bool = False,
        multi_bit: bool = False,
    ) -> None:
        super().__init__(
            input_shape=(input_length, 1),
            num_classes=num_classes,
        )
        self.input_length = input_length
        self.conv_filters = conv_filters or [64, 128, 256]
        self.kernel_size = kernel_size
        self.pool_size = pool_size
        self.head_units = head_units
        self.dropout_rate = dropout_rate
        self.label_smoothing = label_smoothing
        self.spectral_decoupling_lambda = spectral_decoupling_lambda
        self.use_sigmoid_gates = use_sigmoid_gates
        self.focal_loss = focal_loss
        self.focal_gamma = focal_gamma
        self.clipnorm = clipnorm
        self.jit_compile = jit_compile
        self.multi_bit = multi_bit

    def build(self) -> keras.Model:
        """
        Construct the LMIC-TSBN model with shared conv + per-task BN.

        Instead of wrapping the conv block as a sub-Model (which would force
        shared BN), we build the computation graph explicitly: shared Conv1D
        layers are instantiated once and called on each byte's input, while
        BatchNorm (and optional sigmoid gates) are instantiated per-task.

        Returns:
            A compiled-ready Keras Model with named inputs
            (``byte_{i}_input``) and named outputs (``byte_{i}``).
        """
        # --- Create shared layers (instantiated once, reused by every byte) ---
        shared_convs = [
            layers.Conv1D(
                filters=filters,
                kernel_size=self.kernel_size,
                padding="same",
                name=f"shared_conv{i}",
            )
            for i, filters in enumerate(self.conv_filters)
        ]
        shared_pools = [
            layers.AveragePooling1D(
                pool_size=self.pool_size,
                name=f"shared_pool{i}",
            )
            for i, _ in enumerate(self.conv_filters)
        ]
        shared_gap = layers.GlobalAveragePooling1D(name="shared_gap")

        # --- Create per-task BN layers and optional sigmoid gates ---
        # task_bns[layer_idx][byte_idx]   = BatchNormalization layer
        # task_gates[layer_idx][byte_idx] = SigmoidGate layer (if enabled;
        #                                   otherwise the inner list is empty)
        task_bns = []
        task_gates = []
        for i, filters in enumerate(self.conv_filters):
            bn_list = []
            gate_list = []
            for byte_idx in range(NUM_TASKS):
                bn_list.append(
                    layers.BatchNormalization(name=f"tsbn_L{i}_byte{byte_idx}")
                )
                if self.use_sigmoid_gates:
                    gate_list.append(
                        SigmoidGate(
                            num_filters=filters,
                            name=f"gate_L{i}_byte{byte_idx}",
                        )
                    )
            task_bns.append(bn_list)
            task_gates.append(gate_list)

        # Get spectral decoupling regularizer if enabled
        sd_reg = _make_sd_regularizer(self.spectral_decoupling_lambda)

        # --- Build 16 input branches ---
        all_inputs = {}
        all_outputs = {}

        for byte_idx in range(NUM_TASKS):
            byte_input = layers.Input(
                shape=(self.input_length, 1),
                name=f"byte_{byte_idx}_input",
            )
            all_inputs[f"byte_{byte_idx}_input"] = byte_input

            x = byte_input

            # Pass through shared conv layers with per-task BN
            for layer_idx, conv in enumerate(shared_convs):
                # Shared conv weights
                x = conv(x)

                # Task-specific BatchNorm
                x = task_bns[layer_idx][byte_idx](x)

                # Optional sigmoid gate for soft capacity allocation
                if self.use_sigmoid_gates:
                    x = task_gates[layer_idx][byte_idx](x)

                # Activation and pooling (stateless)
                x = layers.Activation(
                    "relu",
                    name=f"relu_L{layer_idx}_byte{byte_idx}",
                )(x)
                x = shared_pools[layer_idx](x)

            # Global average pooling, then the per-byte independent head
            features = shared_gap(x)
            all_outputs[f"byte_{byte_idx}"] = _build_byte_head(
                features,
                byte_idx,
                head_units=self.head_units,
                dropout_rate=self.dropout_rate,
                num_classes=self.num_classes,
                multi_bit=self.multi_bit,
                sd_reg=sd_reg,
            )

        model = keras.Model(
            inputs=all_inputs,
            outputs=all_outputs,
            name="LMIC_TSBN" + ("_MB" if self.multi_bit else ""),
        )

        # Log architecture summary. count_params() reads weight shapes only,
        # so no weight tensor has to be materialized just to count elements.
        total_params = model.count_params()
        num_shared_conv_params = sum(
            conv.count_params() for conv in shared_convs
        )
        num_tsbn_params = sum(
            bn.count_params() for bn_list in task_bns for bn in bn_list
        )
        # Inner lists are empty when gates are disabled, so this sums to 0.
        num_gate_params = sum(
            gate.count_params()
            for gate_list in task_gates
            for gate in gate_list
        )

        logger.info(
            "LMIC-TSBN built: total=%d params, shared_conv=%d, "
            "task_specific_BN=%d, sigmoid_gates=%d, heads+other=%d",
            total_params,
            num_shared_conv_params,
            num_tsbn_params,
            num_gate_params,
            total_params - num_shared_conv_params - num_tsbn_params - num_gate_params,
        )

        self._model = model
        return model

    def compile(self, learning_rate: float = 5e-4) -> keras.Model:
        """
        Build (if needed) and compile the LMIC-TSBN model with per-byte losses.

        Uses the Adam optimizer. The same loss object is attached to each of
        the 16 output heads and is selected in this order:
            1. ``multi_bit``  → BinaryCrossentropy (+ ``binary_accuracy``)
            2. ``focal_loss`` → FocalCategoricalCrossentropy (+ ``accuracy``)
            3. otherwise      → CategoricalCrossentropy (+ ``accuracy``)

        Args:
            learning_rate: Learning rate for Adam optimizer.

        Returns:
            The compiled Keras model.
        """
        if self._model is None:
            self._model = self.build()

        if self.multi_bit:
            # Multi-bit mode: binary crossentropy for 8 independent bits
            loss_fn = keras.losses.BinaryCrossentropy(
                label_smoothing=self.label_smoothing,
            )
            logger.info(
                "TSBN Multi-Bit mode: BinaryCrossentropy, label_smoothing=%.2f",
                self.label_smoothing,
            )
        elif self.focal_loss:
            # Imported lazily, only when focal loss is actually requested.
            from ..focal_loss import FocalCategoricalCrossentropy
            loss_fn = FocalCategoricalCrossentropy(
                gamma=self.focal_gamma,
                label_smoothing=self.label_smoothing,
            )
            logger.info(
                "TSBN Focal Loss ENABLED: gamma=%.1f, label_smoothing=%.2f",
                self.focal_gamma,
                self.label_smoothing,
            )
        else:
            loss_fn = keras.losses.CategoricalCrossentropy(
                label_smoothing=self.label_smoothing,
            )

        losses, metrics_dict = _per_byte_losses_and_metrics(
            loss_fn, self.multi_bit
        )

        # Build optimizer with optional gradient clipping
        opt_kwargs = {"learning_rate": learning_rate}
        if self.clipnorm > 0:
            opt_kwargs["clipnorm"] = self.clipnorm
            logger.info("TSBN Gradient clipping ENABLED: clipnorm=%.2f", self.clipnorm)

        self._model.compile(
            optimizer=keras.optimizers.Adam(**opt_kwargs),
            loss=losses,
            metrics=metrics_dict,
            jit_compile=self.jit_compile,
        )
        if self.jit_compile:
            logger.info("TSBN XLA jit_compile ENABLED")

        return self._model

    def get_config(self) -> Dict[str, Any]:
        """Return architecture hyperparameters for logging."""
        return {
            "architecture": "LMIC_TSBN",
            "input_length": self.input_length,
            "num_classes": self.num_classes,
            "conv_filters": self.conv_filters,
            "kernel_size": self.kernel_size,
            "pool_size": self.pool_size,
            "head_units": self.head_units,
            "dropout_rate": self.dropout_rate,
            "label_smoothing": self.label_smoothing,
            "spectral_decoupling_lambda": self.spectral_decoupling_lambda,
            "use_sigmoid_gates": self.use_sigmoid_gates,
            "focal_loss": self.focal_loss,
            "focal_gamma": self.focal_gamma,
            "clipnorm": self.clipnorm,
            "jit_compile": self.jit_compile,
            "multi_bit": self.multi_bit,
            "num_tasks": NUM_TASKS,
        }