| """ |
| Localized Multi-Input CNN (LMIC) for ASCAD |
| ========================================== |
| A multi-task architecture that uses 16 separate per-byte POI inputs |
| instead of a single global window. Each byte receives only its 700-sample |
| POI window, eliminating the gradient starvation and task dominance problems |
| documented in HISTORY_LEDGER Sections 1-8. |
| |
| Architecture (original LMIC): |
| - 16 independent Input layers, each (700, 1) |
| - A shared convolutional feature extractor (weight-shared across all 16) |
| - 16 independent classification heads (Dense → Softmax) |
| |
| Architecture (LMIC-TSBN — Task-Specific Batch Normalization): |
| - 16 independent Input layers, each (700, 1) |
| - Shared Conv1D weights applied to all 16 inputs |
| - 16 TASK-SPECIFIC BatchNorm layers per conv layer (one per byte) |
| - Optional sigmoid gates per task per filter (TSσBN soft capacity allocation) |
| - 16 independent classification heads (Dense → Softmax) |
| |
| The TSBN variant fixes the val_loss volatility observed in LMIC v5 |
| (HISTORY_LEDGER Section 10) by preventing shared BN running statistics |
| from being corrupted by gradient interference between bytes. |
| |
| Design rationale (from literature): |
| - Marquet & Oswald (COSADE 2024): Low-level parameter sharing achieves |
| 0% failure rate and 100% key recovery in 25-45 epochs. |
| - Zaid et al. (TCHES 2020): Compact CNNs with few filters and small |
| kernels are optimal for ASCAD. |
| - Perin et al. (TCHES 2022): Per-byte POI windows of ~700 samples |
| capture all leakage information. |
| - Suteu & Serban (ICLR 2025): Task-Specific Sigmoid Batch Normalization |
| (TSσBN) eliminates gradient interference in MTL with <0.5% parameter |
| overhead by replacing shared BN with per-task BN + sigmoid gates. |
| |
| Mixed Precision Compatibility: |
| When ``mixed_float16`` global policy is active, all intermediate layers |
| compute in float16 for speed, but the final softmax output layers are |
| explicitly set to ``dtype='float32'`` to avoid numerical instability. |
| |
| References: |
| - Marquet & Oswald, "Exploring Multi-Task Learning in the Context of |
| Masked AES Implementations", COSADE 2024. |
| - Zaid et al., "Methodology for Efficient CNN Architectures in Profiling |
| Attacks", TCHES 2020. |
| - Perin et al., "Learning When to Stop: A Mutual Information Approach |
| to Fight Overfitting in Profiled Side-Channel Analysis", TCHES 2022. |
| - Micikevicius et al., "Mixed Precision Training", ICLR 2018. |
| - Suteu & Serban, "Simplifying Multi-Task Architectures Through |
| Task-Specific Normalization", ICLR 2025. |
| """ |
|
|
| import logging |
| from typing import Any, Dict, List, Optional |
|
|
| import tensorflow as tf |
| from tensorflow import keras |
| from tensorflow.keras import layers |
|
|
| from ..constants import ( |
| BYTE_POI_WINDOWS, |
| NUM_CLASSES, |
| WINDOW_SIZE, |
| ) |
| from .base import BaseModel |
| from ..spectral_decoupling import get_spectral_decoupling_regularizer |
|
|
# Module-level logger, named after this module's import path.
logger = logging.getLogger(__name__)


# Number of per-byte tasks: the models below create this many inputs and heads.
NUM_TASKS = 16
|
|
|
|
| |
| |
| |
|
|
class SigmoidGate(layers.Layer):
    """
    Per-filter multiplicative gate squashed through a sigmoid.

    One trainable scalar is stored per filter and the layer computes

        output = sigmoid(gate) * input

    which is the σ-gate described by Suteu & Serban (ICLR 2025). With the
    default ``init_value`` of 0.0 every gate starts at sigmoid(0) = 0.5, so
    all filters are scaled by the same factor until training moves the
    individual gates apart.

    Args:
        num_filters: Number of filters (channels) to gate.
        init_value: Starting value of every raw (pre-sigmoid) gate parameter.
    """

    def __init__(self, num_filters: int, init_value: float = 0.0, **kwargs):
        super().__init__(**kwargs)
        self.num_filters = num_filters
        self.init_value = init_value

    def build(self, input_shape):
        """Create the trainable gate vector of shape ``(num_filters,)``."""
        # The gate size comes from the constructor argument, not from
        # ``input_shape``; the caller must pass a matching filter count.
        gate_initializer = tf.keras.initializers.Constant(self.init_value)
        self.gate = self.add_weight(
            name="gate",
            shape=(self.num_filters,),
            initializer=gate_initializer,
            trainable=True,
        )
        super().build(input_shape)

    def call(self, inputs):
        """Scale ``inputs`` by the sigmoid of the gate parameters."""
        scale = tf.sigmoid(self.gate)
        return inputs * scale

    def get_config(self):
        """Return the base layer config extended with the gate settings."""
        return {
            **super().get_config(),
            "num_filters": self.num_filters,
            "init_value": self.init_value,
        }
|
|
|
|
| |
| |
| |
|
|
class LMICModel(BaseModel):
    """
    Localized Multi-Input CNN (LMIC) with one input and one head per byte.

    The model takes ``NUM_TASKS`` separate inputs of shape
    ``(input_length, 1)`` (one POI window per key byte). A single
    convolutional sub-model is applied to every input, so its Conv1D and
    BatchNorm weights are shared by all bytes. Each byte then gets its own
    Dropout → Dense(relu) → output Dense head.

    The design is intended to give every byte its own gradient path and to
    avoid a shared fully-connected bottleneck between tasks.

    Args:
        input_length: Length of each per-byte POI window
            (defaults to ``WINDOW_SIZE``).
        num_classes: Number of softmax classes per head
            (defaults to ``NUM_CLASSES``).
        conv_filters: Filter counts of the shared conv stages; a falsy
            value selects ``[64, 128, 256]``.
        kernel_size: Kernel size of every Conv1D layer.
        pool_size: Pool size of every AveragePooling1D layer.
        head_units: Width of the hidden Dense layer in each per-byte head.
        dropout_rate: Dropout rate applied to the shared features.
        label_smoothing: Label smoothing passed to the selected loss.
        spectral_decoupling_lambda: Strength handed to
            ``get_spectral_decoupling_regularizer``; the regularizer is only
            created when this is > 0.
        focal_loss: Use ``FocalCategoricalCrossentropy`` instead of plain
            categorical cross-entropy (ignored when ``multi_bit`` is set).
        focal_gamma: Gamma passed to the focal loss.
        clipnorm: Adam ``clipnorm``; only applied when > 0.
        jit_compile: Forwarded to ``Model.compile(jit_compile=...)``.
        multi_bit: Replace each softmax head with an 8-unit sigmoid head
            trained with binary cross-entropy.
    """

    def __init__(
        self,
        input_length: int = WINDOW_SIZE,
        num_classes: int = NUM_CLASSES,
        conv_filters: Optional[List[int]] = None,
        kernel_size: int = 11,
        pool_size: int = 2,
        head_units: int = 128,
        dropout_rate: float = 0.2,
        label_smoothing: float = 0.0,
        spectral_decoupling_lambda: float = 0.0,
        focal_loss: bool = False,
        focal_gamma: float = 2.0,
        clipnorm: float = 0.0,
        jit_compile: bool = False,
        multi_bit: bool = False,
    ) -> None:
        super().__init__(
            input_shape=(input_length, 1),
            num_classes=num_classes,
        )

        # Shared feature-extractor geometry.
        self.input_length = input_length
        self.conv_filters = conv_filters if conv_filters else [64, 128, 256]
        self.kernel_size = kernel_size
        self.pool_size = pool_size

        # Per-byte head and regularization settings.
        self.head_units = head_units
        self.dropout_rate = dropout_rate
        self.label_smoothing = label_smoothing
        self.spectral_decoupling_lambda = spectral_decoupling_lambda

        # Loss / optimizer / compilation switches.
        self.focal_loss = focal_loss
        self.focal_gamma = focal_gamma
        self.clipnorm = clipnorm
        self.jit_compile = jit_compile
        self.multi_bit = multi_bit

    def _build_shared_conv_block(self, name_prefix: str = "shared") -> keras.Model:
        """
        Assemble the convolutional feature extractor shared by all bytes.

        For every entry in ``conv_filters`` the block stacks
        Conv1D(padding="same") → BatchNorm → ReLU → AveragePooling1D, and
        the final feature map is reduced with GlobalAveragePooling1D.

        The block is built once and called on each per-byte input, so all
        bytes share its weights (including the BatchNorm statistics).

        Args:
            name_prefix: Prefix used for every layer name in the block.

        Returns:
            A Keras Model mapping ``(batch, input_length, 1)`` to
            ``(batch, conv_filters[-1])``.
        """
        block_input = layers.Input(
            shape=(self.input_length, 1),
            name=f"{name_prefix}_input",
        )
        tensor = block_input

        for stage, n_filters in enumerate(self.conv_filters):
            stage_layers = (
                layers.Conv1D(
                    filters=n_filters,
                    kernel_size=self.kernel_size,
                    padding="same",
                    name=f"{name_prefix}_conv{stage}",
                ),
                layers.BatchNormalization(name=f"{name_prefix}_bn{stage}"),
                layers.Activation("relu", name=f"{name_prefix}_relu{stage}"),
                layers.AveragePooling1D(
                    pool_size=self.pool_size,
                    name=f"{name_prefix}_pool{stage}",
                ),
            )
            for stage_layer in stage_layers:
                tensor = stage_layer(tensor)

        # Collapse the time axis: one averaged value per filter.
        tensor = layers.GlobalAveragePooling1D(name=f"{name_prefix}_gap")(tensor)

        return keras.Model(
            inputs=block_input,
            outputs=tensor,
            name=f"{name_prefix}_conv_block",
        )

    def build(self) -> keras.Model:
        """
        Construct the full LMIC graph with ``NUM_TASKS`` inputs and outputs.

        Architecture per byte ``i``:
            Input(input_length, 1) → shared conv block → Dropout
            → Dense(head_units, relu) → output Dense named ``byte_i``

        Inputs are keyed ``byte_{i}_input`` and outputs ``byte_{i}``. The
        output layers are forced to ``float32`` regardless of the global
        dtype policy.

        Returns:
            The (uncompiled) Keras Model; it is also stored on
            ``self._model``.
        """
        extractor = self._build_shared_conv_block()
        logger.info(
            "LMIC shared conv block: %d params, output_dim=%d",
            extractor.count_params(),
            extractor.output_shape[-1],
        )

        # Optional regularizer attached to each output layer's kernel.
        # NOTE(review): exact semantics live in get_spectral_decoupling_regularizer.
        head_regularizer = None
        if self.spectral_decoupling_lambda > 0:
            head_regularizer = get_spectral_decoupling_regularizer(
                self.spectral_decoupling_lambda
            )
            logger.info(
                "Spectral Decoupling ENABLED: lambda=%.4f",
                self.spectral_decoupling_lambda,
            )

        # Output head shape depends only on the mode, so decide it once.
        if self.multi_bit:
            out_units, out_activation = 8, "sigmoid"
        else:
            out_units, out_activation = self.num_classes, "softmax"

        inputs_by_name = {}
        outputs_by_name = {}

        for task in range(NUM_TASKS):
            in_name = f"byte_{task}_input"
            out_name = f"byte_{task}"

            task_input = layers.Input(
                shape=(self.input_length, 1),
                name=in_name,
            )
            inputs_by_name[in_name] = task_input

            # Same sub-model instance for every byte => shared weights.
            hidden = extractor(task_input)

            hidden = layers.Dropout(
                self.dropout_rate,
                name=f"byte_{task}_dropout",
            )(hidden)
            hidden = layers.Dense(
                self.head_units,
                activation="relu",
                name=f"byte_{task}_dense",
            )(hidden)

            outputs_by_name[out_name] = layers.Dense(
                out_units,
                activation=out_activation,
                dtype="float32",
                kernel_regularizer=head_regularizer,
                name=out_name,
            )(hidden)

        model = keras.Model(
            inputs=inputs_by_name,
            outputs=outputs_by_name,
            name="LMIC_MB" if self.multi_bit else "LMIC",
        )
        self._model = model
        return model

    def compile(self, learning_rate: float = 5e-4) -> keras.Model:
        """
        Compile the LMIC model, building it first if necessary.

        One loss instance is shared by all heads and is chosen as follows:
        binary cross-entropy when ``multi_bit`` is set, otherwise focal
        categorical cross-entropy when ``focal_loss`` is set, otherwise
        categorical cross-entropy. The optimizer is Adam, with ``clipnorm``
        applied only when it is > 0.

        Args:
            learning_rate: Learning rate for the Adam optimizer.

        Returns:
            The compiled Keras model.
        """
        if self._model is None:
            self._model = self.build()

        if self.multi_bit:
            per_head_loss = keras.losses.BinaryCrossentropy(
                label_smoothing=self.label_smoothing,
            )
            logger.info(
                "LMIC Multi-Bit mode: BinaryCrossentropy, label_smoothing=%.2f",
                self.label_smoothing,
            )
            metric_name = "binary_accuracy"
        elif self.focal_loss:
            # Imported lazily so the focal-loss module is only needed when used.
            from ..focal_loss import FocalCategoricalCrossentropy
            per_head_loss = FocalCategoricalCrossentropy(
                gamma=self.focal_gamma,
                label_smoothing=self.label_smoothing,
            )
            logger.info(
                "Focal Loss ENABLED: gamma=%.1f, label_smoothing=%.2f",
                self.focal_gamma, self.label_smoothing,
            )
            metric_name = "accuracy"
        else:
            per_head_loss = keras.losses.CategoricalCrossentropy(
                label_smoothing=self.label_smoothing,
            )
            metric_name = "accuracy"

        head_names = [f"byte_{i}" for i in range(NUM_TASKS)]
        metrics_dict = {head: [metric_name] for head in head_names}
        losses = {head: per_head_loss for head in head_names}

        adam_kwargs = {"learning_rate": learning_rate}
        if self.clipnorm > 0:
            adam_kwargs["clipnorm"] = self.clipnorm
            logger.info("Gradient clipping ENABLED: clipnorm=%.2f", self.clipnorm)
        optimizer = keras.optimizers.Adam(**adam_kwargs)

        self._model.compile(
            optimizer=optimizer,
            loss=losses,
            metrics=metrics_dict,
            jit_compile=self.jit_compile,
        )
        if self.jit_compile:
            logger.info("XLA jit_compile ENABLED")
        return self._model

    def get_config(self) -> Dict[str, Any]:
        """Return the architecture hyperparameters as a dict for logging."""
        logged_attrs = (
            "input_length",
            "num_classes",
            "conv_filters",
            "kernel_size",
            "pool_size",
            "head_units",
            "dropout_rate",
            "label_smoothing",
            "spectral_decoupling_lambda",
            "focal_loss",
            "focal_gamma",
            "clipnorm",
            "jit_compile",
            "multi_bit",
        )
        config: Dict[str, Any] = {"architecture": "LMIC"}
        for attr in logged_attrs:
            config[attr] = getattr(self, attr)
        config["num_tasks"] = NUM_TASKS
        return config
|
|
|
|
| |
| |
| |
|
|
class LMICTSBNModel(BaseModel):
    """
    LMIC with Task-Specific Batch Normalization (LMIC-TSBN).

    The Conv1D layers are instantiated once and applied to all
    ``NUM_TASKS`` per-byte inputs, so the convolution weights are fully
    shared. Every conv layer is followed by a *separate* BatchNormalization
    layer per byte, so γ, β and the running mean/variance are per-task.

    When ``use_sigmoid_gates`` is set, a ``SigmoidGate`` (one learnable
    scalar per filter) is inserted after each task-specific BN (TSσBN),
    letting each byte scale individual filters of the shared convolution.

    Parameter overhead of the task-specific parts:
        - Per-task BN: NUM_TASKS × 2 × sum(conv_filters) trainable params
          (γ, β) plus the same number of non-trainable params (running
          mean/variance). For [64, 128, 256]: 28,672 in total.
        - Sigmoid gates: NUM_TASKS × sum(conv_filters) trainable params.
          For [64, 128, 256]: 7,168.
        The exact split for a built model is logged by ``build``.

    Architecture:
        NUM_TASKS × Input(input_length, 1)
        → For each byte i:
            shared_conv0(input_i) → task_bn0_i → [σ_gate0_i] → relu → avgpool
            → shared_conv1(...)   → task_bn1_i → [σ_gate1_i] → relu → avgpool
            → ... one stage per entry of conv_filters ...
            → GAP → dropout → dense_i → output_i

    References:
        - Suteu & Serban, "Simplifying Multi-Task Architectures Through
          Task-Specific Normalization", ICLR 2025.
        - Zaid et al., "Methodology for Efficient CNN Architectures in
          Profiling Attacks", TCHES 2020.

    Args:
        input_length: Length of each per-byte POI window
            (defaults to ``WINDOW_SIZE``).
        num_classes: Number of softmax classes per head
            (defaults to ``NUM_CLASSES``).
        conv_filters: Filter counts of the shared conv layers; a falsy
            value selects ``[64, 128, 256]``.
        kernel_size: Kernel size of every Conv1D layer.
        pool_size: Pool size of every AveragePooling1D layer.
        head_units: Width of the hidden Dense layer in each per-byte head.
        dropout_rate: Dropout rate applied to the pooled features.
        label_smoothing: Label smoothing passed to the selected loss.
        spectral_decoupling_lambda: Strength handed to
            ``get_spectral_decoupling_regularizer``; the regularizer is only
            created when this is > 0.
        use_sigmoid_gates: Whether to add a per-task ``SigmoidGate`` after
            each task-specific BN.
        focal_loss: Use ``FocalCategoricalCrossentropy`` instead of plain
            categorical cross-entropy (ignored when ``multi_bit`` is set).
        focal_gamma: Gamma passed to the focal loss.
        clipnorm: Adam ``clipnorm``; only applied when > 0.
        jit_compile: Forwarded to ``Model.compile(jit_compile=...)``.
        multi_bit: Replace each softmax head with an 8-unit sigmoid head
            trained with binary cross-entropy.
    """

    def __init__(
        self,
        input_length: int = WINDOW_SIZE,
        num_classes: int = NUM_CLASSES,
        conv_filters: Optional[List[int]] = None,
        kernel_size: int = 11,
        pool_size: int = 2,
        head_units: int = 128,
        dropout_rate: float = 0.2,
        label_smoothing: float = 0.0,
        spectral_decoupling_lambda: float = 0.0,
        use_sigmoid_gates: bool = False,
        focal_loss: bool = False,
        focal_gamma: float = 2.0,
        clipnorm: float = 0.0,
        jit_compile: bool = False,
        multi_bit: bool = False,
    ) -> None:
        super().__init__(
            input_shape=(input_length, 1),
            num_classes=num_classes,
        )
        self.input_length = input_length
        self.conv_filters = conv_filters or [64, 128, 256]
        self.kernel_size = kernel_size
        self.pool_size = pool_size
        self.head_units = head_units
        self.dropout_rate = dropout_rate
        self.label_smoothing = label_smoothing
        self.spectral_decoupling_lambda = spectral_decoupling_lambda
        self.use_sigmoid_gates = use_sigmoid_gates
        self.focal_loss = focal_loss
        self.focal_gamma = focal_gamma
        self.clipnorm = clipnorm
        self.jit_compile = jit_compile
        self.multi_bit = multi_bit

    def build(self) -> keras.Model:
        """
        Construct the LMIC-TSBN graph with shared conv + per-task BN.

        The conv stack is not wrapped in a sub-Model (that would also share
        the BatchNorm layers). Instead the shared Conv1D / pooling layers
        are instantiated once and called on every byte's tensor, while the
        BatchNorm layers (and optional sigmoid gates) are instantiated once
        per (conv layer, byte) pair.

        A breakdown of the parameter count is logged after construction.

        Returns:
            The (uncompiled) Keras Model with inputs keyed
            ``byte_{i}_input`` and outputs keyed ``byte_{i}``; it is also
            stored on ``self._model``.
        """
        num_layers = len(self.conv_filters)

        # --- Layers shared by all bytes ---------------------------------
        shared_convs = [
            layers.Conv1D(
                filters=filters,
                kernel_size=self.kernel_size,
                padding="same",
                name=f"shared_conv{i}",
            )
            for i, filters in enumerate(self.conv_filters)
        ]
        shared_pools = [
            layers.AveragePooling1D(
                pool_size=self.pool_size,
                name=f"shared_pool{i}",
            )
            for i in range(num_layers)
        ]
        shared_gap = layers.GlobalAveragePooling1D(name="shared_gap")

        # --- Task-specific layers: indexed [layer_idx][byte_idx] ---------
        # task_gates[layer_idx] stays empty when gates are disabled.
        task_bns = []
        task_gates = []
        for i, filters in enumerate(self.conv_filters):
            bn_list = []
            gate_list = []
            for byte_idx in range(NUM_TASKS):
                bn_list.append(
                    layers.BatchNormalization(
                        name=f"tsbn_L{i}_byte{byte_idx}",
                    )
                )
                if self.use_sigmoid_gates:
                    gate_list.append(
                        SigmoidGate(
                            num_filters=filters,
                            name=f"gate_L{i}_byte{byte_idx}",
                        )
                    )
            task_bns.append(bn_list)
            task_gates.append(gate_list)

        # Optional regularizer attached to each output layer's kernel.
        # NOTE(review): exact semantics live in get_spectral_decoupling_regularizer.
        sd_reg = None
        if self.spectral_decoupling_lambda > 0:
            sd_reg = get_spectral_decoupling_regularizer(
                self.spectral_decoupling_lambda
            )
            logger.info(
                "Spectral Decoupling ENABLED: lambda=%.4f",
                self.spectral_decoupling_lambda,
            )

        # --- Per-byte graph ---------------------------------------------
        all_inputs = {}
        all_outputs = {}

        for byte_idx in range(NUM_TASKS):
            byte_input = layers.Input(
                shape=(self.input_length, 1),
                name=f"byte_{byte_idx}_input",
            )
            all_inputs[f"byte_{byte_idx}_input"] = byte_input

            x = byte_input
            for layer_idx, conv in enumerate(shared_convs):
                # Shared convolution weights ...
                x = conv(x)
                # ... followed by this byte's own normalization.
                x = task_bns[layer_idx][byte_idx](x)
                if self.use_sigmoid_gates:
                    x = task_gates[layer_idx][byte_idx](x)
                x = layers.Activation(
                    "relu",
                    name=f"relu_L{layer_idx}_byte{byte_idx}",
                )(x)
                x = shared_pools[layer_idx](x)

            features = shared_gap(x)

            x = layers.Dropout(
                self.dropout_rate,
                name=f"byte_{byte_idx}_dropout",
            )(features)
            x = layers.Dense(
                self.head_units,
                activation="relu",
                name=f"byte_{byte_idx}_dense",
            )(x)

            # Output layers are pinned to float32 regardless of the global
            # dtype policy.
            if self.multi_bit:
                all_outputs[f"byte_{byte_idx}"] = layers.Dense(
                    8,
                    activation="sigmoid",
                    dtype="float32",
                    kernel_regularizer=sd_reg,
                    name=f"byte_{byte_idx}",
                )(x)
            else:
                all_outputs[f"byte_{byte_idx}"] = layers.Dense(
                    self.num_classes,
                    activation="softmax",
                    dtype="float32",
                    kernel_regularizer=sd_reg,
                    name=f"byte_{byte_idx}",
                )(x)

        model = keras.Model(
            inputs=all_inputs,
            outputs=all_outputs,
            name="LMIC_TSBN" + ("_MB" if self.multi_bit else ""),
        )

        # --- Parameter accounting ---------------------------------------
        # Layer.count_params() derives the count from weight shapes, so no
        # weight tensor has to be copied to host memory just to be counted.
        total_params = model.count_params()
        num_shared_conv_params = sum(
            conv.count_params() for conv in shared_convs
        )
        num_tsbn_params = sum(
            bn.count_params()
            for bn_list in task_bns
            for bn in bn_list
        )
        # Gate lists are empty when gates are disabled, giving 0 here.
        num_gate_params = sum(
            gate.count_params()
            for gate_list in task_gates
            for gate in gate_list
        )

        logger.info(
            "LMIC-TSBN built: total=%d params, shared_conv=%d, "
            "task_specific_BN=%d, sigmoid_gates=%d, heads+other=%d",
            total_params,
            num_shared_conv_params,
            num_tsbn_params,
            num_gate_params,
            total_params - num_shared_conv_params - num_tsbn_params - num_gate_params,
        )

        self._model = model
        return model

    def compile(self, learning_rate: float = 5e-4) -> keras.Model:
        """
        Compile the LMIC-TSBN model, building it first if necessary.

        One loss instance is shared by all heads and is chosen as follows:
        binary cross-entropy when ``multi_bit`` is set, otherwise focal
        categorical cross-entropy when ``focal_loss`` is set, otherwise
        categorical cross-entropy. The optimizer is Adam, with ``clipnorm``
        applied only when it is > 0.

        Args:
            learning_rate: Learning rate for the Adam optimizer.

        Returns:
            The compiled Keras model.
        """
        if self._model is None:
            self._model = self.build()

        if self.multi_bit:
            loss_fn = keras.losses.BinaryCrossentropy(
                label_smoothing=self.label_smoothing,
            )
            logger.info(
                "TSBN Multi-Bit mode: BinaryCrossentropy, label_smoothing=%.2f",
                self.label_smoothing,
            )
            metric_name = "binary_accuracy"
        elif self.focal_loss:
            # Imported lazily so the focal-loss module is only needed when used.
            from ..focal_loss import FocalCategoricalCrossentropy
            loss_fn = FocalCategoricalCrossentropy(
                gamma=self.focal_gamma,
                label_smoothing=self.label_smoothing,
            )
            logger.info(
                "TSBN Focal Loss ENABLED: gamma=%.1f, label_smoothing=%.2f",
                self.focal_gamma, self.label_smoothing,
            )
            metric_name = "accuracy"
        else:
            loss_fn = keras.losses.CategoricalCrossentropy(
                label_smoothing=self.label_smoothing,
            )
            metric_name = "accuracy"

        metrics_dict = {
            f"byte_{i}": [metric_name] for i in range(NUM_TASKS)
        }
        losses = {f"byte_{i}": loss_fn for i in range(NUM_TASKS)}

        opt_kwargs = {"learning_rate": learning_rate}
        if self.clipnorm > 0:
            opt_kwargs["clipnorm"] = self.clipnorm
            logger.info("TSBN Gradient clipping ENABLED: clipnorm=%.2f", self.clipnorm)

        self._model.compile(
            optimizer=keras.optimizers.Adam(**opt_kwargs),
            loss=losses,
            metrics=metrics_dict,
            jit_compile=self.jit_compile,
        )
        if self.jit_compile:
            logger.info("TSBN XLA jit_compile ENABLED")
        return self._model

    def get_config(self) -> Dict[str, Any]:
        """Return the architecture hyperparameters as a dict for logging."""
        return {
            "architecture": "LMIC_TSBN",
            "input_length": self.input_length,
            "num_classes": self.num_classes,
            "conv_filters": self.conv_filters,
            "kernel_size": self.kernel_size,
            "pool_size": self.pool_size,
            "head_units": self.head_units,
            "dropout_rate": self.dropout_rate,
            "label_smoothing": self.label_smoothing,
            "spectral_decoupling_lambda": self.spectral_decoupling_lambda,
            "use_sigmoid_gates": self.use_sigmoid_gates,
            "focal_loss": self.focal_loss,
            "focal_gamma": self.focal_gamma,
            "clipnorm": self.clipnorm,
            "jit_compile": self.jit_compile,
            "multi_bit": self.multi_bit,
            "num_tasks": NUM_TASKS,
        }
|
|