lemousehunter's picture
fix: LMIC base class multi_bit support in build() and compile()
1056030
"""
Localized Multi-Input CNN (LMIC) for ASCAD
==========================================
A multi-task architecture that uses 16 separate per-byte POI inputs
instead of a single global window. Each byte receives only its 700-sample
POI window, eliminating the gradient starvation and task dominance problems
documented in HISTORY_LEDGER Sections 1-8.
Architecture (original LMIC):
- 16 independent Input layers, each (700, 1)
- A shared convolutional feature extractor (weight-shared across all 16)
- 16 independent classification heads (Dense → Softmax)
Architecture (LMIC-TSBN — Task-Specific Batch Normalization):
- 16 independent Input layers, each (700, 1)
- Shared Conv1D weights applied to all 16 inputs
- 16 TASK-SPECIFIC BatchNorm layers per conv layer (one per byte)
- Optional sigmoid gates per task per filter (TSσBN soft capacity allocation)
- 16 independent classification heads (Dense → Softmax)
The TSBN variant fixes the val_loss volatility observed in LMIC v5
(HISTORY_LEDGER Section 10) by preventing shared BN running statistics
from being corrupted by gradient interference between bytes.
Design rationale (from literature):
- Marquet & Oswald (COSADE 2024): Low-level parameter sharing achieves
0% failure rate and 100% key recovery in 25-45 epochs.
- Zaid et al. (TCHES 2020): Compact CNNs with few filters and small
kernels are optimal for ASCAD.
- Perin et al. (TCHES 2022): Per-byte POI windows of ~700 samples
capture all leakage information.
- Suteu & Serban (ICLR 2025): Task-Specific Sigmoid Batch Normalization
(TSσBN) eliminates gradient interference in MTL with <0.5% parameter
overhead by replacing shared BN with per-task BN + sigmoid gates.
Mixed Precision Compatibility:
When ``mixed_float16`` global policy is active, all intermediate layers
compute in float16 for speed, but the final softmax output layers are
explicitly set to ``dtype='float32'`` to avoid numerical instability.
References:
- Marquet & Oswald, "Exploring Multi-Task Learning in the Context of
Masked AES Implementations", COSADE 2024.
- Zaid et al., "Methodology for Efficient CNN Architectures in Profiling
Attacks", TCHES 2020.
- Perin et al., "Learning When to Stop: A Mutual Information Approach
to Fight Overfitting in Profiled Side-Channel Analysis", TCHES 2022.
- Micikevicius et al., "Mixed Precision Training", ICLR 2018.
- Suteu & Serban, "Simplifying Multi-Task Architectures Through
Task-Specific Normalization", ICLR 2025.
"""
import logging
from typing import Any, Dict, List, Optional
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
from ..constants import (
BYTE_POI_WINDOWS,
NUM_CLASSES,
WINDOW_SIZE,
)
from .base import BaseModel
from ..spectral_decoupling import get_spectral_decoupling_regularizer
logger = logging.getLogger(__name__)
# Number of jointly modelled tasks: one input branch and one output head per
# AES key byte (all 16 key bytes are recovered simultaneously).
NUM_TASKS = 16
# ---------------------------------------------------------------------------
# Sigmoid Gate Layer (for TSσBN)
# ---------------------------------------------------------------------------
class SigmoidGate(layers.Layer):
    """
    Learnable per-filter sigmoid gate for soft capacity allocation.

    Each filter gets one scalar gate parameter initialised to ``init_value``.
    With the default of 0.0, sigmoid(0) = 0.5, so all filters start equally
    active. During training a gate can learn to suppress filters that a task
    does not need, which enables soft specialisation.

    This implements the σ-gate from Suteu & Serban (ICLR 2025):
        output = σ(gate) * input

    Args:
        num_filters: Number of filters (channels) to gate. This must equal
            the size of the input's last axis.
        init_value: Initial pre-sigmoid gate value (0.0 → sigmoid = 0.5).

    Raises:
        ValueError: From ``build`` when the input's last dimension is known
            and differs from ``num_filters``.
    """

    def __init__(self, num_filters: int, init_value: float = 0.0, **kwargs):
        super().__init__(**kwargs)
        self.num_filters = num_filters
        self.init_value = init_value

    def build(self, input_shape):
        # Fail loudly on a channel mismatch. Otherwise, with num_filters == 1,
        # broadcasting would silently apply a single gate to every filter.
        channels = input_shape[-1]
        if channels is not None and int(channels) != self.num_filters:
            raise ValueError(
                f"SigmoidGate '{self.name}' expects {self.num_filters} "
                f"channels but received input with last dimension {channels}."
            )
        self.gate = self.add_weight(
            name="gate",
            shape=(self.num_filters,),
            initializer=keras.initializers.Constant(self.init_value),
            trainable=True,
        )
        super().build(input_shape)

    def call(self, inputs):
        # gate has shape (num_filters,) and broadcasts over the leading axes,
        # e.g. (batch, time, filters) for Conv1D feature maps.
        return inputs * tf.sigmoid(self.gate)

    def get_config(self):
        config = super().get_config()
        config.update({
            "num_filters": self.num_filters,
            "init_value": self.init_value,
        })
        return config
# ---------------------------------------------------------------------------
# Original LMIC Model (shared BN — kept for backward compatibility)
# ---------------------------------------------------------------------------
class LMICModel(BaseModel):
    """
    Localized Multi-Input CNN (LMIC) for simultaneous 16-byte key recovery.

    Instead of processing a single 32,272-sample global window, this model
    receives 16 separate per-byte POI inputs of ``input_length`` samples.
    A shared convolutional feature extractor processes each input identically.
    Both the Conv1D weights and the BatchNorm layers are shared across all
    bytes. Sixteen independent Dense heads then produce per-byte predictions.

    This design addresses:
        - Gradient starvation: each byte has its own gradient path.
        - Task dominance: there is no shared FC bottleneck.
        - GAP information loss: global average pooling only pools within a
          single byte's POI window, never across a global trace.
        - Domain shift memorization: each input is only one short POI window.

    Output heads:
        - ``multi_bit=False``: a ``num_classes``-way softmax per byte, trained
          with categorical cross-entropy (the focal variant if ``focal_loss``).
        - ``multi_bit=True``: 8 independent sigmoid units per byte, trained
          with binary cross-entropy. ``num_classes`` and ``focal_loss`` do
          not affect the heads or the loss in this mode.

    Args:
        input_length: Length of each per-byte POI window
            (default: ``WINDOW_SIZE``).
        num_classes: Number of output classes (256 for AES S-Box).
        conv_filters: Filter counts for the shared conv blocks. Defaults to
            ``[64, 128, 256]``; an empty list also selects the default. The
            list is copied.
        kernel_size: Kernel size for Conv1D layers.
        pool_size: Pool size for AveragePooling1D layers.
        head_units: Number of units in each per-byte head's Dense layer.
        dropout_rate: Dropout rate applied before each per-byte head.
        label_smoothing: Label smoothing factor passed to the loss.
        spectral_decoupling_lambda: Strength passed to
            ``get_spectral_decoupling_regularizer``. The resulting regularizer
            is attached to each output layer's kernel (disabled when <= 0).
        focal_loss: Use focal categorical cross-entropy (softmax mode only).
        focal_gamma: Focusing parameter for the focal loss.
        clipnorm: Gradient-norm clipping for Adam (disabled when <= 0).
        jit_compile: Forwarded to ``keras.Model.compile`` (XLA compilation).
        multi_bit: Use 8 sigmoid outputs per byte instead of a softmax.
    """

    def __init__(
        self,
        input_length: int = WINDOW_SIZE,
        num_classes: int = NUM_CLASSES,
        conv_filters: Optional[List[int]] = None,
        kernel_size: int = 11,
        pool_size: int = 2,
        head_units: int = 128,
        dropout_rate: float = 0.2,
        label_smoothing: float = 0.0,
        spectral_decoupling_lambda: float = 0.0,
        focal_loss: bool = False,
        focal_gamma: float = 2.0,
        clipnorm: float = 0.0,
        jit_compile: bool = False,
        multi_bit: bool = False,
    ) -> None:
        # BaseModel expects a single input_shape, but LMIC has multiple inputs.
        # We pass the per-byte input shape and override build().
        super().__init__(
            input_shape=(input_length, 1),
            num_classes=num_classes,
        )
        self.input_length = input_length
        # Copy so that later mutation of the caller's list cannot change this
        # model's architecture (or its reported config).
        self.conv_filters = list(conv_filters) if conv_filters else [64, 128, 256]
        self.kernel_size = kernel_size
        self.pool_size = pool_size
        self.head_units = head_units
        self.dropout_rate = dropout_rate
        self.label_smoothing = label_smoothing
        self.spectral_decoupling_lambda = spectral_decoupling_lambda
        self.focal_loss = focal_loss
        self.focal_gamma = focal_gamma
        self.clipnorm = clipnorm
        self.jit_compile = jit_compile
        self.multi_bit = multi_bit

    def _build_shared_conv_block(self, name_prefix: str = "shared") -> keras.Model:
        """
        Build the shared convolutional feature extractor.

        This is a compact CNN following Zaid et al. (TCHES 2020):
            [Conv1D → BN → ReLU → AvgPool] per entry of ``conv_filters``
            → GlobalAveragePooling1D

        The block is instantiated once and called on each of the 16 inputs,
        so all bytes share the same conv AND BatchNorm weights/statistics.

        Returns:
            A Keras Model mapping (batch, input_length, 1) → (batch, F),
            where F = ``conv_filters[-1]``.
        """
        inp = layers.Input(shape=(self.input_length, 1), name=f"{name_prefix}_input")
        x = inp
        for i, filters in enumerate(self.conv_filters):
            x = layers.Conv1D(
                filters=filters,
                kernel_size=self.kernel_size,
                padding="same",
                name=f"{name_prefix}_conv{i}",
            )(x)
            x = layers.BatchNormalization(name=f"{name_prefix}_bn{i}")(x)
            x = layers.Activation("relu", name=f"{name_prefix}_relu{i}")(x)
            x = layers.AveragePooling1D(
                pool_size=self.pool_size,
                name=f"{name_prefix}_pool{i}",
            )(x)
        # GlobalAveragePooling reduces (batch, time, filters) → (batch, filters),
        # so the feature dimension equals the last conv's filter count and the
        # per-byte heads stay compact. Unlike the MTAN-Lite architecture, where
        # GAP on a 32K global window destroyed byte-specific spatial
        # information, each input here is already a byte-specific POI window.
        # GAP therefore only pools within a single byte's features.
        x = layers.GlobalAveragePooling1D(name=f"{name_prefix}_gap")(x)
        return keras.Model(inputs=inp, outputs=x, name=f"{name_prefix}_conv_block")

    def build(self) -> keras.Model:
        """
        Construct the full LMIC model with 16 inputs and 16 outputs.

        Architecture:
            16 × Input(input_length, 1) → shared_conv_block → per-byte head
            → softmax (or 8 sigmoids in multi_bit mode)

        The model is also stored on ``self._model``.

        Returns:
            An uncompiled Keras Model. Its inputs are a dict keyed
            ``byte_{i}_input`` and its outputs a dict keyed ``byte_{i}``.
        """
        # The shared conv block is instantiated once and called 16 times.
        shared_conv = self._build_shared_conv_block()
        feature_dim = shared_conv.output_shape[-1]
        logger.info(
            "LMIC shared conv block: %d params, output_dim=%d",
            shared_conv.count_params(),
            feature_dim,
        )

        sd_reg = None
        if self.spectral_decoupling_lambda > 0:
            sd_reg = get_spectral_decoupling_regularizer(
                self.spectral_decoupling_lambda
            )
            logger.info(
                "Spectral Decoupling ENABLED: lambda=%.4f",
                self.spectral_decoupling_lambda,
            )

        if self.multi_bit:
            # Multi-bit mode: 8 independent binary outputs per byte.
            out_units, out_activation = 8, "sigmoid"
        else:
            # Standard mode: num_classes-way softmax.
            out_units, out_activation = self.num_classes, "softmax"

        # Dict inputs/outputs make model.predict() return a dict with named keys.
        all_inputs = {}
        all_outputs = {}
        for byte_idx in range(NUM_TASKS):
            input_name = f"byte_{byte_idx}_input"
            byte_input = layers.Input(
                shape=(self.input_length, 1),
                name=input_name,
            )
            all_inputs[input_name] = byte_input

            features = shared_conv(byte_input)

            # Per-byte independent head.
            x = layers.Dropout(
                self.dropout_rate,
                name=f"byte_{byte_idx}_dropout",
            )(features)
            x = layers.Dense(
                self.head_units,
                activation="relu",
                name=f"byte_{byte_idx}_dense",
            )(x)

            # Output layer pinned to float32 for mixed-precision safety.
            output_name = f"byte_{byte_idx}"
            all_outputs[output_name] = layers.Dense(
                out_units,
                activation=out_activation,
                dtype="float32",
                kernel_regularizer=sd_reg,
                name=output_name,
            )(x)

        model = keras.Model(
            inputs=all_inputs,
            outputs=all_outputs,
            name="LMIC" + ("_MB" if self.multi_bit else ""),
        )
        self._model = model
        return model

    def compile(self, learning_rate: float = 5e-4) -> keras.Model:
        """
        Build (if needed) and compile the LMIC model with per-byte losses.

        Loss per output head:
            - multi_bit: BinaryCrossentropy, metric ``binary_accuracy``.
            - focal_loss: FocalCategoricalCrossentropy, metric ``accuracy``.
            - otherwise: CategoricalCrossentropy, metric ``accuracy``.
        The optimizer is Adam, with ``clipnorm`` applied when > 0.

        Args:
            learning_rate: Learning rate for the Adam optimizer.

        Returns:
            The compiled Keras model.
        """
        if self._model is None:
            self._model = self.build()

        if self.multi_bit:
            if self.focal_loss:
                # The focal loss is categorical and cannot be applied to the
                # independent sigmoid bits, so tell the user it is not used.
                logger.warning(
                    "LMIC: focal_loss=True is ignored in multi_bit mode; "
                    "using BinaryCrossentropy instead."
                )
            loss_fn = keras.losses.BinaryCrossentropy(
                label_smoothing=self.label_smoothing,
            )
            logger.info(
                "LMIC Multi-Bit mode: BinaryCrossentropy, label_smoothing=%.2f",
                self.label_smoothing,
            )
            metric_name = "binary_accuracy"
        elif self.focal_loss:
            from ..focal_loss import FocalCategoricalCrossentropy
            loss_fn = FocalCategoricalCrossentropy(
                gamma=self.focal_gamma,
                label_smoothing=self.label_smoothing,
            )
            logger.info(
                "Focal Loss ENABLED: gamma=%.1f, label_smoothing=%.2f",
                self.focal_gamma, self.label_smoothing,
            )
            metric_name = "accuracy"
        else:
            loss_fn = keras.losses.CategoricalCrossentropy(
                label_smoothing=self.label_smoothing,
            )
            metric_name = "accuracy"

        output_names = [f"byte_{i}" for i in range(NUM_TASKS)]
        losses = {name: loss_fn for name in output_names}
        metrics_dict = {name: [metric_name] for name in output_names}

        opt_kwargs = {"learning_rate": learning_rate}
        if self.clipnorm > 0:
            opt_kwargs["clipnorm"] = self.clipnorm
            logger.info("Gradient clipping ENABLED: clipnorm=%.2f", self.clipnorm)

        self._model.compile(
            optimizer=keras.optimizers.Adam(**opt_kwargs),
            loss=losses,
            metrics=metrics_dict,
            jit_compile=self.jit_compile,
        )
        if self.jit_compile:
            logger.info("XLA jit_compile ENABLED")
        return self._model

    def get_config(self) -> Dict[str, Any]:
        """Return architecture hyperparameters for logging (a fresh dict)."""
        return {
            "architecture": "LMIC",
            "input_length": self.input_length,
            "num_classes": self.num_classes,
            "conv_filters": list(self.conv_filters),
            "kernel_size": self.kernel_size,
            "pool_size": self.pool_size,
            "head_units": self.head_units,
            "dropout_rate": self.dropout_rate,
            "label_smoothing": self.label_smoothing,
            "spectral_decoupling_lambda": self.spectral_decoupling_lambda,
            "focal_loss": self.focal_loss,
            "focal_gamma": self.focal_gamma,
            "clipnorm": self.clipnorm,
            "jit_compile": self.jit_compile,
            "multi_bit": self.multi_bit,
            "num_tasks": NUM_TASKS,
        }
# ---------------------------------------------------------------------------
# LMIC-TSBN: Task-Specific Batch Normalization variant
# ---------------------------------------------------------------------------
class LMICTSBNModel(BaseModel):
    """
    LMIC with Task-Specific Batch Normalization (LMIC-TSBN).

    This variant replaces the shared BatchNorm layers in the LMIC conv block
    with 16 task-specific BN layers (one per byte). The Conv1D weights remain
    fully shared across all 16 tasks. Only the BN parameters (γ, β, moving
    mean, moving variance) become per-task.

    This targets the val_loss volatility seen in the original LMIC
    (HISTORY_LEDGER Section 10). There, shared BN running statistics were
    corrupted by conflicting gradient signals from different bytes.

    The optional sigmoid gates (TSσBN) add a learnable per-filter gate for
    each task. This enables soft capacity allocation: each byte can learn to
    suppress filters it does not need, so the shared conv weights can develop
    both shared and specialised features.

    Parameter overhead:
        - Per-task BN: NUM_TASKS × 2 × sum(conv_filters) trainable params
          (γ, β per filter per task), plus the same number of non-trainable
          params (moving mean/variance per task). For [64, 128, 256] this is
          16 × 4 × 448 ≈ 28.7K params in total.
        - Sigmoid gates: NUM_TASKS × sum(conv_filters) trainable params.
          For [64, 128, 256] this is 16 × 448 ≈ 7.2K params.
        - With the default configuration (256-class heads), the overhead is
          ≈1.9% of total params without gates and ≈2.3% with gates. The
          relative share is larger with the much smaller multi_bit heads.

    Architecture:
        16 × Input(input_length, 1)
        → For each byte i:
            shared_conv0(input_i) → task_bn0_i → [σ_gate0_i] → relu → avgpool
            → shared_conv1(...) → task_bn1_i → [σ_gate1_i] → relu → avgpool
            → shared_conv2(...) → task_bn2_i → [σ_gate2_i] → relu → avgpool
            → GAP → dropout → dense_i → softmax_i (8 sigmoids if multi_bit)

    References:
        - Suteu & Serban, "Simplifying Multi-Task Architectures Through
          Task-Specific Normalization", ICLR 2025.
        - Zaid et al., "Methodology for Efficient CNN Architectures in
          Profiling Attacks", TCHES 2020.

    Args:
        input_length: Length of each per-byte POI window
            (default: ``WINDOW_SIZE``).
        num_classes: Number of output classes (256 for AES S-Box).
        conv_filters: Filter counts for the shared conv layers. Defaults to
            ``[64, 128, 256]``; an empty list also selects the default. The
            list is copied.
        kernel_size: Kernel size for Conv1D layers.
        pool_size: Pool size for AveragePooling1D layers.
        head_units: Number of units in each per-byte head's Dense layer.
        dropout_rate: Dropout rate applied before each per-byte head.
        label_smoothing: Label smoothing factor passed to the loss.
        spectral_decoupling_lambda: Strength passed to
            ``get_spectral_decoupling_regularizer``. The resulting regularizer
            is attached to each output layer's kernel (disabled when <= 0).
        use_sigmoid_gates: Add σ-gates for soft capacity allocation.
        focal_loss: Use focal categorical cross-entropy (softmax mode only).
        focal_gamma: Focusing parameter for the focal loss.
        clipnorm: Gradient-norm clipping for Adam (disabled when <= 0).
        jit_compile: Forwarded to ``keras.Model.compile`` (XLA compilation).
        multi_bit: Use 8 sigmoid outputs per byte instead of a softmax.
    """

    def __init__(
        self,
        input_length: int = WINDOW_SIZE,
        num_classes: int = NUM_CLASSES,
        conv_filters: Optional[List[int]] = None,
        kernel_size: int = 11,
        pool_size: int = 2,
        head_units: int = 128,
        dropout_rate: float = 0.2,
        label_smoothing: float = 0.0,
        spectral_decoupling_lambda: float = 0.0,
        use_sigmoid_gates: bool = False,
        focal_loss: bool = False,
        focal_gamma: float = 2.0,
        clipnorm: float = 0.0,
        jit_compile: bool = False,
        multi_bit: bool = False,
    ) -> None:
        super().__init__(
            input_shape=(input_length, 1),
            num_classes=num_classes,
        )
        self.input_length = input_length
        # Copy so that later mutation of the caller's list cannot change this
        # model's architecture (or its reported config).
        self.conv_filters = list(conv_filters) if conv_filters else [64, 128, 256]
        self.kernel_size = kernel_size
        self.pool_size = pool_size
        self.head_units = head_units
        self.dropout_rate = dropout_rate
        self.label_smoothing = label_smoothing
        self.spectral_decoupling_lambda = spectral_decoupling_lambda
        self.use_sigmoid_gates = use_sigmoid_gates
        self.focal_loss = focal_loss
        self.focal_gamma = focal_gamma
        self.clipnorm = clipnorm
        self.jit_compile = jit_compile
        self.multi_bit = multi_bit

    def build(self) -> keras.Model:
        """
        Construct the LMIC-TSBN model with shared conv + per-task BN.

        The conv block is not wrapped as a sub-Model, because that would
        force the BN layers to be shared. Instead the graph is built
        explicitly. Shared Conv1D (and weightless pooling) layers are
        instantiated once and called on every byte's input. BatchNorm and the
        optional sigmoid gates are instantiated per task.

        The model is also stored on ``self._model``.

        Returns:
            An uncompiled Keras Model. Its inputs are a dict keyed
            ``byte_{i}_input`` and its outputs a dict keyed ``byte_{i}``.
        """
        num_conv_layers = len(self.conv_filters)

        # --- Shared layers (instantiated once, called for every byte) ---
        shared_convs = [
            layers.Conv1D(
                filters=filters,
                kernel_size=self.kernel_size,
                padding="same",
                name=f"shared_conv{i}",
            )
            for i, filters in enumerate(self.conv_filters)
        ]
        shared_pools = [
            layers.AveragePooling1D(
                pool_size=self.pool_size,
                name=f"shared_pool{i}",
            )
            for i in range(num_conv_layers)
        ]
        shared_gap = layers.GlobalAveragePooling1D(name="shared_gap")

        # --- Per-task BN layers and optional sigmoid gates ---
        # task_bns[layer_idx][byte_idx]   = BatchNormalization layer
        # task_gates[layer_idx][byte_idx] = SigmoidGate (empty lists if disabled)
        task_bns = []
        task_gates = []
        for i, filters in enumerate(self.conv_filters):
            bn_list = []
            gate_list = []
            for byte_idx in range(NUM_TASKS):
                bn_list.append(layers.BatchNormalization(
                    name=f"tsbn_L{i}_byte{byte_idx}",
                ))
                if self.use_sigmoid_gates:
                    gate_list.append(SigmoidGate(
                        num_filters=filters,
                        name=f"gate_L{i}_byte{byte_idx}",
                    ))
            task_bns.append(bn_list)
            task_gates.append(gate_list)

        sd_reg = None
        if self.spectral_decoupling_lambda > 0:
            sd_reg = get_spectral_decoupling_regularizer(
                self.spectral_decoupling_lambda
            )
            logger.info(
                "Spectral Decoupling ENABLED: lambda=%.4f",
                self.spectral_decoupling_lambda,
            )

        if self.multi_bit:
            # Multi-bit mode: 8 independent binary outputs per byte
            # (Wu et al., TCHES 2024).
            out_units, out_activation = 8, "sigmoid"
        else:
            # Standard mode: num_classes-way softmax.
            out_units, out_activation = self.num_classes, "softmax"

        # --- Build 16 input branches ---
        all_inputs = {}
        all_outputs = {}
        for byte_idx in range(NUM_TASKS):
            input_name = f"byte_{byte_idx}_input"
            byte_input = layers.Input(
                shape=(self.input_length, 1),
                name=input_name,
            )
            all_inputs[input_name] = byte_input

            x = byte_input
            for layer_idx in range(num_conv_layers):
                x = shared_convs[layer_idx](x)           # shared weights
                x = task_bns[layer_idx][byte_idx](x)     # task-specific BN
                if self.use_sigmoid_gates:
                    # Soft per-task capacity allocation over filters.
                    x = task_gates[layer_idx][byte_idx](x)
                # ReLU is a new (weightless) layer per byte; pooling is shared.
                x = layers.Activation(
                    "relu",
                    name=f"relu_L{layer_idx}_byte{byte_idx}",
                )(x)
                x = shared_pools[layer_idx](x)

            features = shared_gap(x)

            # Per-byte independent head.
            x = layers.Dropout(
                self.dropout_rate,
                name=f"byte_{byte_idx}_dropout",
            )(features)
            x = layers.Dense(
                self.head_units,
                activation="relu",
                name=f"byte_{byte_idx}_dense",
            )(x)

            # Output layer pinned to float32 for mixed-precision safety.
            output_name = f"byte_{byte_idx}"
            all_outputs[output_name] = layers.Dense(
                out_units,
                activation=out_activation,
                dtype="float32",
                kernel_regularizer=sd_reg,
                name=output_name,
            )(x)

        model = keras.Model(
            inputs=all_inputs,
            outputs=all_outputs,
            name="LMIC_TSBN" + ("_MB" if self.multi_bit else ""),
        )

        # Parameter breakdown. count_params() sums the sizes of all weights
        # (trainable and non-trainable) without copying any values to host.
        total_params = model.count_params()
        num_shared_conv_params = sum(conv.count_params() for conv in shared_convs)
        num_tsbn_params = sum(
            bn.count_params() for bn_list in task_bns for bn in bn_list
        )
        # Gate lists are empty when gates are disabled, so this sums to 0.
        num_gate_params = sum(
            gate.count_params() for gate_list in task_gates for gate in gate_list
        )
        logger.info(
            "LMIC-TSBN built: total=%d params, shared_conv=%d, "
            "task_specific_BN=%d, sigmoid_gates=%d, heads+other=%d",
            total_params,
            num_shared_conv_params,
            num_tsbn_params,
            num_gate_params,
            total_params - num_shared_conv_params - num_tsbn_params - num_gate_params,
        )

        self._model = model
        return model

    def compile(self, learning_rate: float = 5e-4) -> keras.Model:
        """
        Build (if needed) and compile the LMIC-TSBN model with per-byte losses.

        Loss per output head:
            - multi_bit: BinaryCrossentropy, metric ``binary_accuracy``.
            - focal_loss: FocalCategoricalCrossentropy, metric ``accuracy``.
            - otherwise: CategoricalCrossentropy, metric ``accuracy``.
        The optimizer is Adam, with ``clipnorm`` applied when > 0.

        Args:
            learning_rate: Learning rate for the Adam optimizer.

        Returns:
            The compiled Keras model.
        """
        if self._model is None:
            self._model = self.build()

        if self.multi_bit:
            if self.focal_loss:
                # The focal loss is categorical and cannot be applied to the
                # independent sigmoid bits, so tell the user it is not used.
                logger.warning(
                    "TSBN: focal_loss=True is ignored in multi_bit mode; "
                    "using BinaryCrossentropy instead."
                )
            loss_fn = keras.losses.BinaryCrossentropy(
                label_smoothing=self.label_smoothing,
            )
            logger.info(
                "TSBN Multi-Bit mode: BinaryCrossentropy, label_smoothing=%.2f",
                self.label_smoothing,
            )
            metric_name = "binary_accuracy"
        elif self.focal_loss:
            from ..focal_loss import FocalCategoricalCrossentropy
            loss_fn = FocalCategoricalCrossentropy(
                gamma=self.focal_gamma,
                label_smoothing=self.label_smoothing,
            )
            logger.info(
                "TSBN Focal Loss ENABLED: gamma=%.1f, label_smoothing=%.2f",
                self.focal_gamma, self.label_smoothing,
            )
            metric_name = "accuracy"
        else:
            loss_fn = keras.losses.CategoricalCrossentropy(
                label_smoothing=self.label_smoothing,
            )
            metric_name = "accuracy"

        output_names = [f"byte_{i}" for i in range(NUM_TASKS)]
        losses = {name: loss_fn for name in output_names}
        metrics_dict = {name: [metric_name] for name in output_names}

        opt_kwargs = {"learning_rate": learning_rate}
        if self.clipnorm > 0:
            opt_kwargs["clipnorm"] = self.clipnorm
            logger.info("TSBN Gradient clipping ENABLED: clipnorm=%.2f", self.clipnorm)

        self._model.compile(
            optimizer=keras.optimizers.Adam(**opt_kwargs),
            loss=losses,
            metrics=metrics_dict,
            jit_compile=self.jit_compile,
        )
        if self.jit_compile:
            logger.info("TSBN XLA jit_compile ENABLED")
        return self._model

    def get_config(self) -> Dict[str, Any]:
        """Return architecture hyperparameters for logging (a fresh dict)."""
        return {
            "architecture": "LMIC_TSBN",
            "input_length": self.input_length,
            "num_classes": self.num_classes,
            "conv_filters": list(self.conv_filters),
            "kernel_size": self.kernel_size,
            "pool_size": self.pool_size,
            "head_units": self.head_units,
            "dropout_rate": self.dropout_rate,
            "label_smoothing": self.label_smoothing,
            "spectral_decoupling_lambda": self.spectral_decoupling_lambda,
            "use_sigmoid_gates": self.use_sigmoid_gates,
            "focal_loss": self.focal_loss,
            "focal_gamma": self.focal_gamma,
            "clipnorm": self.clipnorm,
            "jit_compile": self.jit_compile,
            "multi_bit": self.multi_bit,
            "num_tasks": NUM_TASKS,
        }