lemousehunter
v3: Add DTP + Spectral Decoupling, fix GradNorm OOM, fix _fail_job cancel
283a882
"""
Abstract base class for all ASCAD attack models.
Provides a common interface for building, compiling, and describing models
so that the trainer and evaluation modules can work with any architecture
without knowing its internals.
"""
from abc import ABC, abstractmethod
from typing import Any, Dict, Optional

import tensorflow as tf
class BaseModel(ABC):
    """Abstract base class for ASCAD side-channel attack models.

    Subclasses implement :meth:`build` (the architecture) and
    :meth:`get_config` (its hyperparameters). The Keras model is built lazily
    on first access through :attr:`model` or :meth:`compile`, then cached on
    the instance.
    """

    def __init__(self, input_shape: tuple, num_classes: int = 256) -> None:
        """
        Args:
            input_shape: Shape of a single input sample (excluding batch dim).
            num_classes: Number of output classes (256 for AES S-Box).
        """
        self.input_shape = input_shape
        self.num_classes = num_classes
        # Populated lazily by the ``model`` property; None until first build.
        self._model: Optional[tf.keras.Model] = None

    @abstractmethod
    def build(self) -> tf.keras.Model:
        """
        Construct and return the Keras model.

        Subclasses must implement this to define the architecture.
        The returned model should NOT be compiled yet.
        """
        raise NotImplementedError

    def compile(self, learning_rate: float = 1e-5) -> tf.keras.Model:
        """
        Build (if needed) and compile the model with standard SCA settings.

        Compiles with RMSprop, categorical cross-entropy loss and the
        accuracy metric. Every call creates a new RMSprop optimizer and
        recompiles the cached model.

        Args:
            learning_rate: Learning rate for the RMSprop optimizer.

        Returns:
            The compiled Keras model (the same object exposed by ``model``).

        Raises:
            TypeError: If the subclass's ``build()`` returned None.
        """
        # Reuse the single lazy-build path rather than duplicating it here.
        keras_model = self.model
        keras_model.compile(
            optimizer=tf.keras.optimizers.RMSprop(learning_rate=learning_rate),
            loss="categorical_crossentropy",
            metrics=["accuracy"],
        )
        return keras_model

    @property
    def model(self) -> tf.keras.Model:
        """Access the underlying Keras model, building and caching it on first use.

        Raises:
            TypeError: If the subclass's ``build()`` returned None (for
                example, a missing ``return`` statement).
        """
        if self._model is None:
            built = self.build()
            # Fail loudly here instead of with an opaque AttributeError later.
            if built is None:
                raise TypeError(
                    f"{type(self).__name__}.build() returned None; "
                    "it must return a tf.keras.Model."
                )
            self._model = built
        return self._model

    @abstractmethod
    def get_config(self) -> Dict[str, Any]:
        """
        Return a dictionary of architecture hyperparameters.

        Used for logging, reproducibility, and results.json metadata.
        """
        raise NotImplementedError

    def summary(self) -> None:
        """Print the model summary (builds the model first if needed)."""
        self.model.summary()

    @property
    def name(self) -> str:
        """Human-readable model name (the concrete subclass's name)."""
        return self.__class__.__name__