lemousehunter
feat: training speed optimizations — mixed precision, vectorized augmentation, cached eval predictions
1fe1a19
"""
Multi-Task Learning Models for ASCAD
=====================================
Implements two architectures for simultaneous 16-byte key recovery:
1. **HPS (Hard Parameter Sharing)**: Shared CNN backbone + shared FC layers +
16 independent softmax heads. Based on Marquet & Oswald (COSADE 2024).
2. **MTAN-Lite (Simplified Multi-Task Attention Network)**: Shared CNN backbone
with a single soft-attention module at the final conv block + per-task heads.
Novel contribution: SNR-guided attention initialization.
Mixed Precision Compatibility:
When ``mixed_float16`` global policy is active, all intermediate layers
compute in float16 for speed, but the final softmax output layers are
explicitly set to ``dtype='float32'`` to avoid numerical instability in
cross-entropy loss computation. This follows the recommended practice
from Micikevicius et al. (ICLR 2018) and the TensorFlow mixed precision
guide.
References:
- Marquet & Oswald, "Exploring Multi-Task Learning in the Context of
Masked AES Implementations", COSADE 2024.
- Liu et al., "End-to-End Multi-Task Learning with Attention", CVPR 2019.
- Prouff et al., "Study of Deep Learning Techniques for Side-Channel
Analysis and Introduction to ASCAD Database", 2019.
- Micikevicius et al., "Mixed Precision Training", ICLR 2018.
"""
import logging
from typing import Dict, List, Optional
import numpy as np
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
from ..constants import (
BYTE_PEAK_SNR,
GLOBAL_WINDOW_SIZE,
NUM_CLASSES,
)
from .base import BaseModel
from ..gradnorm import build_per_output_loss_metrics
from ..spectral_decoupling import get_spectral_decoupling_regularizer
# Module-level logger, named after this module.
logger = logging.getLogger(name=__name__)
# Number of simultaneous classification tasks: one per key byte of the
# 16-byte key (one softmax head per task in both architectures below).
NUM_TASKS = 16
# ============================================================================
# Shared CNN Backbone (used by both HPS and MTAN-Lite)
# ============================================================================
def _build_shared_backbone(
    input_shape: tuple,
    conv_filters: List[int],
    kernel_size: int = 11,
    pool_size: int = 2,
    dropout_rate: float = 0.0,
    name: str = "shared_backbone",
) -> tuple:
    """
    Construct the 1D convolutional trunk shared by every task head.

    Each entry of ``conv_filters`` yields one block:
    Conv1D (ReLU, he_uniform, "same" padding) -> BatchNormalization ->
    AveragePooling1D, followed by Dropout when ``dropout_rate`` is positive.
    The layout follows the ASCAD CNNbest pattern (Prouff et al., 2019), with
    configurable filter counts for multi-task use.

    Args:
        input_shape: Shape of one input trace, e.g. (T, 1).
        conv_filters: Filter count of each conv block, in order.
        kernel_size: Convolution kernel size.
        pool_size: Average-pooling window size.
        dropout_rate: Dropout rate after each block (0 = no dropout).
        name: Prefix applied to every layer name.

    Returns:
        Tuple of (input_tensor, backbone_output_tensor).
    """
    trace_input = layers.Input(shape=input_shape, name=f"{name}_input")
    use_dropout = dropout_rate > 0

    features = trace_input
    for block_idx, n_filters in enumerate(conv_filters):
        conv = layers.Conv1D(
            filters=n_filters,
            kernel_size=kernel_size,
            padding="same",
            activation="relu",
            kernel_initializer="he_uniform",
            name=f"{name}_conv{block_idx}",
        )
        features = conv(features)
        features = layers.BatchNormalization(
            name=f"{name}_bn{block_idx}"
        )(features)
        features = layers.AveragePooling1D(
            pool_size=pool_size,
            name=f"{name}_pool{block_idx}",
        )(features)
        if use_dropout:
            features = layers.Dropout(
                rate=dropout_rate,
                name=f"{name}_drop{block_idx}",
            )(features)

    return trace_input, features
# ============================================================================
# Option A: Hard Parameter Sharing (HPS)
# ============================================================================
class HPSModel(BaseModel):
    """
    Hard Parameter Sharing multi-task model.

    Architecture (following Marquet & Oswald, COSADE 2024):
    - Shared CNN backbone (5 conv blocks with the default ``conv_filters``)
    - Global Average Pooling
    - Shared FC layers (all tasks share the same FC weights)
    - 16 independent softmax heads (Dense(NUM_CLASSES) per byte)

    This is the "md" (maximal dependency) configuration from Marquet & Oswald,
    which they found to be the most effective for masked AES implementations.
    """

    def __init__(
        self,
        conv_filters: Optional[List[int]] = None,
        kernel_size: int = 11,
        pool_size: int = 2,
        fc_units: int = 512,
        num_fc_layers: int = 2,
        dropout_rate: float = 0.2,
        label_smoothing: float = 0.0,
        spectral_decoupling_lambda: float = 0.0,
    ) -> None:
        """
        Args:
            conv_filters: Filters per conv block. Default: [64, 128, 256, 256, 512].
            kernel_size: Conv kernel size.
            pool_size: Pooling size.
            fc_units: Units in shared FC layers.
            num_fc_layers: Number of shared FC layers.
            dropout_rate: Dropout rate, used both after each backbone block
                and after each shared FC layer.
            label_smoothing: Label smoothing factor for the per-byte
                categorical cross-entropy losses (0.0 = no smoothing).
            spectral_decoupling_lambda: L2 logit regularization strength
                (0.0 = disabled). Penalizes large pre-softmax logits to
                prevent gradient starvation (Pezeshki et al., NeurIPS 2021).
        """
        if conv_filters is None:
            conv_filters = [64, 128, 256, 256, 512]
        self.conv_filters = conv_filters
        self.kernel_size = kernel_size
        self.pool_size = pool_size
        self.fc_units = fc_units
        self.num_fc_layers = num_fc_layers
        self.dropout_rate = dropout_rate
        self.label_smoothing = label_smoothing
        self.spectral_decoupling_lambda = spectral_decoupling_lambda

    def _build_output_head(self, features, byte_idx: int, sd_reg):
        """
        Build the softmax classification head for one key byte.

        The returned tensor always comes from a layer named ``byte_{byte_idx}``
        (the key used by the loss / metric dicts) that computes in float32, so
        softmax and cross-entropy stay in FP32 even under ``mixed_float16``.

        Keras applies ``activity_regularizer`` to a layer's *output*, i.e.
        after its activation. Spectral decoupling has to penalize the
        pre-softmax logits, so when it is enabled the head is split into a
        linear float32 Dense (``byte_{byte_idx}_logits``) carrying the
        regularizer, followed by a float32 softmax Activation. When it is
        disabled, the single softmax Dense layer is kept so layer names (and
        weights saved from this configuration) are unchanged.

        Args:
            features: Shared feature tensor feeding the head.
            byte_idx: Index of the key byte this head predicts.
            sd_reg: Regularizer returned by
                ``get_spectral_decoupling_regularizer`` (may be None).

        Returns:
            Softmax output tensor of width ``NUM_CLASSES``.
        """
        head_name = f"byte_{byte_idx}"
        if sd_reg is not None and self.spectral_decoupling_lambda > 0:
            logits = layers.Dense(
                NUM_CLASSES,
                activation=None,
                activity_regularizer=sd_reg,
                dtype="float32",
                name=f"{head_name}_logits",
            )(features)
            return layers.Activation(
                "softmax", dtype="float32", name=head_name
            )(logits)
        return layers.Dense(
            NUM_CLASSES,
            activation="softmax",
            activity_regularizer=sd_reg,
            dtype="float32",
            name=head_name,
        )(features)

    def build(self) -> keras.Model:
        """Build the HPS multi-task Keras model."""
        input_shape = (GLOBAL_WINDOW_SIZE, 1)

        # Shared backbone
        inp, backbone_out = _build_shared_backbone(
            input_shape=input_shape,
            conv_filters=self.conv_filters,
            kernel_size=self.kernel_size,
            pool_size=self.pool_size,
            dropout_rate=self.dropout_rate,
            name="hps",
        )

        # Global Average Pooling (eliminates large flatten dimensions)
        x = layers.GlobalAveragePooling1D(name="hps_gap")(backbone_out)

        # Shared FC layers (all 16 tasks share these weights)
        for i in range(self.num_fc_layers):
            x = layers.Dense(
                self.fc_units,
                activation="relu",
                kernel_initializer="he_uniform",
                name=f"hps_shared_fc{i}",
            )(x)
            x = layers.BatchNormalization(name=f"hps_shared_fc_bn{i}")(x)
            x = layers.Dropout(self.dropout_rate, name=f"hps_shared_fc_drop{i}")(x)

        # 16 independent softmax heads, with spectral decoupling (when
        # enabled) applied to the pre-softmax logits.
        sd_reg = get_spectral_decoupling_regularizer(
            self.spectral_decoupling_lambda
        )
        outputs = {
            f"byte_{byte_idx}": self._build_output_head(x, byte_idx, sd_reg)
            for byte_idx in range(NUM_TASKS)
        }

        model = keras.Model(inputs=inp, outputs=outputs, name="hps_mtl")
        logger.info(
            "Built HPS model: %s params, conv=%s, fc=%d×%d",
            f"{model.count_params():,}",
            self.conv_filters,
            self.num_fc_layers,
            self.fc_units,
        )
        return model

    def compile(self, learning_rate: float = 5e-4) -> keras.Model:
        """Build and compile the model with Adam and equal task weights.

        Args:
            learning_rate: Adam learning rate.

        Returns:
            The compiled Keras model.
        """
        model = self.build()
        if self.label_smoothing > 0:
            losses = {
                f"byte_{i}": tf.keras.losses.CategoricalCrossentropy(
                    label_smoothing=self.label_smoothing
                )
                for i in range(NUM_TASKS)
            }
        else:
            losses = {f"byte_{i}": "categorical_crossentropy" for i in range(NUM_TASKS)}

        # Equal weights for all tasks (Marquet & Oswald finding)
        loss_weights = {f"byte_{i}": 1.0 for i in range(NUM_TASKS)}

        # Per-output loss metrics for GradNorm (zero extra memory)
        metrics = build_per_output_loss_metrics(
            label_smoothing=self.label_smoothing
        )

        model.compile(
            optimizer=keras.optimizers.Adam(learning_rate=learning_rate),
            loss=losses,
            loss_weights=loss_weights,
            metrics=metrics,
        )
        return model

    def get_config(self) -> Dict:
        """Return the hyper-parameters needed to rebuild this model."""
        return {
            "model_type": "hps",
            "conv_filters": self.conv_filters,
            "kernel_size": self.kernel_size,
            "pool_size": self.pool_size,
            "fc_units": self.fc_units,
            "num_fc_layers": self.num_fc_layers,
            "dropout_rate": self.dropout_rate,
            "label_smoothing": self.label_smoothing,
            "spectral_decoupling_lambda": self.spectral_decoupling_lambda,
        }
# ============================================================================
# Option B: Simplified MTAN (MTAN-Lite)
# ============================================================================
class SoftAttentionBlock(layers.Layer):
    """
    Soft attention module applied to a single feature map.

    Implements the attention mechanism from Liu et al. (CVPR 2019),
    simplified for 1D signals:

        attention_mask = sigmoid(Conv1D_1x1(relu(Conv1D_1x1(features))))
        attended       = features * attention_mask

    The first 1x1 convolution compresses the channels to a bottleneck of
    ``max(channels // bottleneck_ratio, 16)`` units; the second expands them
    back to ``channels`` so the mask can be multiplied with the input.
    Each task gets its own attention block to learn task-specific feature
    selection from the shared representation.
    """

    def __init__(self, channels: int, bottleneck_ratio: int = 4, **kwargs):
        """
        Args:
            channels: Number of channels of the mask (expected to equal the
                channel count of the attended feature map).
            bottleneck_ratio: Channel compression ratio of the first 1x1 conv.
            **kwargs: Forwarded to ``keras.layers.Layer`` (e.g. ``name``).
        """
        super().__init__(**kwargs)
        self.channels = channels
        # Stored so get_config() can reproduce the constructor arguments.
        self.bottleneck_ratio = bottleneck_ratio
        self.bottleneck = max(channels // bottleneck_ratio, 16)

    def build(self, input_shape):
        self.conv_down = layers.Conv1D(
            self.bottleneck, 1, padding="same", activation="relu",
            kernel_initializer="he_uniform",
        )
        self.conv_up = layers.Conv1D(
            self.channels, 1, padding="same", activation="sigmoid",
            kernel_initializer="glorot_uniform",
        )
        super().build(input_shape)

    def call(self, inputs, training=None):
        att = self.conv_down(inputs)
        att = self.conv_up(att)
        return inputs * att

    def get_config(self):
        """Return the constructor arguments so the layer can be re-created.

        Needed for model serialization / cloning: tf.keras refuses to
        serialize a layer whose ``__init__`` takes extra arguments unless
        ``get_config`` is overridden.
        """
        config = super().get_config()
        config.update(
            {
                "channels": self.channels,
                "bottleneck_ratio": self.bottleneck_ratio,
            }
        )
        return config
class MTANLiteModel(BaseModel):
    """
    Simplified Multi-Task Attention Network (MTAN-Lite).

    Novel contribution: Applies soft attention at the FINAL conv block only
    (not all blocks as in the original MTAN), with optional SNR-guided
    initialization of task (loss) weights.

    Architecture:
    - Shared CNN backbone (5 conv blocks with the default ``conv_filters``,
      no attention)
    - 16 task-specific soft attention modules on the final feature map
    - Per-task: GlobalAveragePooling → FC(head_fc_units) → Softmax(NUM_CLASSES)

    This is a middle ground between the full MTAN (too heavy) and simple HPS
    (no task-specific feature selection). The hypothesis is that task-specific
    attention at the final layer allows each byte to focus on its relevant
    POI region within the global window.
    """

    def __init__(
        self,
        conv_filters: Optional[List[int]] = None,
        kernel_size: int = 11,
        pool_size: int = 2,
        head_fc_units: int = 256,
        dropout_rate: float = 0.2,
        snr_init: bool = False,
        bottleneck_ratio: int = 4,
        label_smoothing: float = 0.0,
        spectral_decoupling_lambda: float = 0.0,
    ) -> None:
        """
        Args:
            conv_filters: Filters per conv block. Default: [64, 128, 256, 256, 512].
            kernel_size: Conv kernel size.
            pool_size: Pooling size.
            head_fc_units: FC units in each task head.
            dropout_rate: Dropout rate, used after each backbone block and
                in each task head.
            snr_init: Whether to use SNR-guided loss weight initialization.
            bottleneck_ratio: Attention bottleneck compression ratio.
            label_smoothing: Label smoothing factor (0.0 = no smoothing).
            spectral_decoupling_lambda: L2 logit regularization strength
                (0.0 = disabled), applied to each head's pre-softmax logits.
        """
        if conv_filters is None:
            conv_filters = [64, 128, 256, 256, 512]
        self.conv_filters = conv_filters
        self.kernel_size = kernel_size
        self.pool_size = pool_size
        self.head_fc_units = head_fc_units
        self.dropout_rate = dropout_rate
        self.snr_init = snr_init
        self.bottleneck_ratio = bottleneck_ratio
        self.label_smoothing = label_smoothing
        self.spectral_decoupling_lambda = spectral_decoupling_lambda
        # Set by _compute_snr_weights(); reported by get_config().
        self._task_weights = None

    def _compute_snr_weights(self) -> Dict[str, float]:
        """
        Compute task weights inversely proportional to SNR.

        Bytes with lower SNR (harder to attack) get higher weight,
        encouraging the model to allocate more capacity to difficult bytes.
        The weights are normalized to mean 1.0 and also cached in
        ``self._task_weights``.

        Returns:
            Dictionary mapping "byte_i" to weight value.
        """
        snr_values = np.array([BYTE_PEAK_SNR[i] for i in range(NUM_TASKS)])
        # Inverse SNR: lower SNR → higher weight
        inv_snr = 1.0 / (snr_values + 1e-6)
        # Normalize so mean weight = 1.0
        weights = inv_snr / inv_snr.mean()
        self._task_weights = weights
        return {f"byte_{i}": float(weights[i]) for i in range(NUM_TASKS)}

    def _build_output_head(self, features, byte_idx: int, sd_reg):
        """
        Build the softmax output layer for one key byte.

        The returned tensor always comes from a layer named ``byte_{byte_idx}``
        (the key used by the loss / metric dicts) that computes in float32, so
        softmax and cross-entropy stay in FP32 even under ``mixed_float16``.

        Keras applies ``activity_regularizer`` to a layer's *output*, i.e.
        after its activation. Spectral decoupling has to penalize the
        pre-softmax logits, so when it is enabled the output is split into a
        linear float32 Dense (``byte_{byte_idx}_logits``) carrying the
        regularizer, followed by a float32 softmax Activation. When it is
        disabled, the single softmax Dense layer is kept so layer names (and
        weights saved from this configuration) are unchanged.

        Args:
            features: Per-task feature tensor feeding the output layer.
            byte_idx: Index of the key byte this head predicts.
            sd_reg: Regularizer returned by
                ``get_spectral_decoupling_regularizer`` (may be None).

        Returns:
            Softmax output tensor of width ``NUM_CLASSES``.
        """
        head_name = f"byte_{byte_idx}"
        if sd_reg is not None and self.spectral_decoupling_lambda > 0:
            logits = layers.Dense(
                NUM_CLASSES,
                activation=None,
                activity_regularizer=sd_reg,
                dtype="float32",
                name=f"{head_name}_logits",
            )(features)
            return layers.Activation(
                "softmax", dtype="float32", name=head_name
            )(logits)
        return layers.Dense(
            NUM_CLASSES,
            activation="softmax",
            activity_regularizer=sd_reg,
            dtype="float32",
            name=head_name,
        )(features)

    def build(self) -> keras.Model:
        """Build the MTAN-Lite multi-task Keras model."""
        input_shape = (GLOBAL_WINDOW_SIZE, 1)
        final_channels = self.conv_filters[-1]

        # Shared backbone (no attention in backbone)
        inp, backbone_out = _build_shared_backbone(
            input_shape=input_shape,
            conv_filters=self.conv_filters,
            kernel_size=self.kernel_size,
            pool_size=self.pool_size,
            dropout_rate=self.dropout_rate,
            name="mtan_lite",
        )

        # Per-task: attention → GAP → FC → softmax.
        # Spectral decoupling (when enabled) regularizes the pre-softmax logits.
        sd_reg = get_spectral_decoupling_regularizer(
            self.spectral_decoupling_lambda
        )
        outputs = {}
        for byte_idx in range(NUM_TASKS):
            # Task-specific soft attention on the final feature map
            att = SoftAttentionBlock(
                channels=final_channels,
                bottleneck_ratio=self.bottleneck_ratio,
                name=f"att_byte_{byte_idx}",
            )(backbone_out)

            # Global Average Pooling
            x = layers.GlobalAveragePooling1D(
                name=f"gap_byte_{byte_idx}"
            )(att)

            # Small FC head
            x = layers.Dense(
                self.head_fc_units,
                activation="relu",
                kernel_initializer="he_uniform",
                name=f"fc_byte_{byte_idx}",
            )(x)
            x = layers.BatchNormalization(name=f"bn_byte_{byte_idx}")(x)
            x = layers.Dropout(self.dropout_rate, name=f"drop_byte_{byte_idx}")(x)

            # FP32 softmax output (with optional spectral decoupling)
            outputs[f"byte_{byte_idx}"] = self._build_output_head(
                x, byte_idx, sd_reg
            )

        model = keras.Model(inputs=inp, outputs=outputs, name="mtan_lite")
        logger.info(
            "Built MTAN-Lite model: %s params, conv=%s, head_fc=%d, "
            "snr_init=%s",
            f"{model.count_params():,}",
            self.conv_filters,
            self.head_fc_units,
            self.snr_init,
        )
        return model

    def compile(self, learning_rate: float = 5e-4) -> keras.Model:
        """Build and compile the model with the Adam optimizer.

        Loss weights are SNR-guided when ``snr_init`` is set, otherwise all
        tasks are weighted 1.0.

        Args:
            learning_rate: Adam learning rate.

        Returns:
            The compiled Keras model.
        """
        model = self.build()
        if self.label_smoothing > 0:
            losses = {
                f"byte_{i}": tf.keras.losses.CategoricalCrossentropy(
                    label_smoothing=self.label_smoothing
                )
                for i in range(NUM_TASKS)
            }
        else:
            losses = {f"byte_{i}": "categorical_crossentropy" for i in range(NUM_TASKS)}

        if self.snr_init:
            loss_weights = self._compute_snr_weights()
            # The %d placeholders need integer byte indices, not the
            # "byte_i" dict keys; argmin/argmax return the first index on
            # ties, matching min()/max() over the insertion-ordered dict.
            ordered = np.array(
                [loss_weights[f"byte_{i}"] for i in range(NUM_TASKS)]
            )
            min_idx = int(np.argmin(ordered))
            max_idx = int(np.argmax(ordered))
            logger.info(
                "SNR-guided weights: min=%.2f (byte %d), max=%.2f (byte %d)",
                float(ordered[min_idx]),
                min_idx,
                float(ordered[max_idx]),
                max_idx,
            )
        else:
            loss_weights = {f"byte_{i}": 1.0 for i in range(NUM_TASKS)}

        # Per-output loss metrics for GradNorm (zero extra memory)
        metrics = build_per_output_loss_metrics(
            label_smoothing=self.label_smoothing
        )

        model.compile(
            optimizer=keras.optimizers.Adam(learning_rate=learning_rate),
            loss=losses,
            loss_weights=loss_weights,
            metrics=metrics,
        )
        return model

    def get_config(self) -> Dict:
        """Return the hyper-parameters (and any SNR task weights computed so
        far) needed to rebuild this model."""
        return {
            "model_type": "mtan_lite",
            "conv_filters": self.conv_filters,
            "kernel_size": self.kernel_size,
            "pool_size": self.pool_size,
            "head_fc_units": self.head_fc_units,
            "dropout_rate": self.dropout_rate,
            "snr_init": self.snr_init,
            "bottleneck_ratio": self.bottleneck_ratio,
            "label_smoothing": self.label_smoothing,
            "spectral_decoupling_lambda": self.spectral_decoupling_lambda,
            "task_weights": (
                self._task_weights.tolist() if self._task_weights is not None
                else None
            ),
        }