"""
Abstract base class for all ASCAD attack models.

Provides a common interface for building, compiling, and describing models
so that the trainer and evaluation modules can work with any architecture
without knowing its internals.
"""

from abc import ABC, abstractmethod
from typing import Any, Dict, Optional

import tensorflow as tf


class BaseModel(ABC):
    """Abstract base class for ASCAD side-channel attack models.

    Subclasses implement :meth:`build` (the architecture) and
    :meth:`get_config` (the hyperparameters). The Keras model is built
    lazily the first time it is needed (via :attr:`model`, :meth:`compile`
    or :meth:`summary`) and the built instance is cached on the object.
    """

    def __init__(self, input_shape: tuple, num_classes: int = 256) -> None:
        """
        Args:
            input_shape: Shape of a single input sample (excluding batch dim).
            num_classes: Number of output classes (256 for AES S-Box).
        """
        self.input_shape = input_shape
        self.num_classes = num_classes
        # Populated lazily by the ``model`` property; None until first built.
        self._model: Optional[tf.keras.Model] = None

    @abstractmethod
    def build(self) -> tf.keras.Model:
        """
        Construct and return the Keras model.

        Subclasses must implement this to define the architecture.
        The returned model should NOT be compiled yet; :meth:`compile`
        takes care of that. It is invoked by the :attr:`model` property
        when no model has been cached yet, and must not return None.
        """
        raise NotImplementedError

    def compile(self, learning_rate: float = 1e-5) -> tf.keras.Model:
        """
        Build (if needed) and compile the model with standard SCA settings.

        The model is compiled with an RMSprop optimizer, the Keras
        ``categorical_crossentropy`` loss and the ``accuracy`` metric.
        Each call creates a fresh RMSprop optimizer and recompiles the
        cached model.

        Args:
            learning_rate: Learning rate for the RMSprop optimizer.

        Returns:
            The compiled Keras model.

        Raises:
            TypeError: If the subclass's :meth:`build` returns None.
        """
        # Reuse the property so lazy building lives in exactly one place.
        model = self.model
        model.compile(
            optimizer=tf.keras.optimizers.RMSprop(learning_rate=learning_rate),
            loss="categorical_crossentropy",
            metrics=["accuracy"],
        )
        return model

    @property
    def model(self) -> tf.keras.Model:
        """
        Access the underlying Keras model, building and caching it if needed.

        Raises:
            TypeError: If the subclass's :meth:`build` returns None.
        """
        if self._model is None:
            built = self.build()
            # Fail loudly here rather than with an opaque AttributeError
            # later (e.g. in compile() or summary()).
            if built is None:
                raise TypeError(
                    f"{type(self).__name__}.build() returned None; "
                    "it must return a tf.keras.Model."
                )
            self._model = built
        return self._model

    @abstractmethod
    def get_config(self) -> Dict[str, Any]:
        """
        Return a dictionary of architecture hyperparameters.

        Used for logging, reproducibility, and results.json metadata.
        """
        raise NotImplementedError

    def summary(self) -> None:
        """Print the model summary, building the model first if necessary."""
        self.model.summary()

    @property
    def name(self) -> str:
        """Human-readable model name (the concrete subclass's name)."""
        return type(self).__name__