| """ |
| Model registry for ASCAD attack architectures. |
| |
| Usage: |
| from src.models import create_model |
| model_wrapper = create_model("mlp") |
| keras_model = model_wrapper.compile(learning_rate=1e-5) |
| |
| Supported architectures: |
| - ``mlp``: Multi-Layer Perceptron (Prouff et al., 2019) |
| - ``cnn``: CNNbest 1-D convolutional network (Prouff et al., 2019) |
| - ``mtan``: Full SNR-MTAN with per-block attention (deprecated) |
| - ``hps``: Hard Parameter Sharing multi-task (Marquet & Oswald, 2024) |
| - ``mtan_lite``: Simplified MTAN with final-block attention (novel) |
| - ``lmic``: Localized Multi-Input CNN (16 per-byte POI inputs) |
| - ``lmic_tsbn``: LMIC with Task-Specific Batch Normalization (Suteu & Serban, 2025) |
| """ |
|
|
| from typing import Dict, Type |
|
|
| from .base import BaseModel |
| from .mlp import MLPBest |
| from .cnn import CNNBest |
| from .mtan import SNRMTAN |
| from .mtl import HPSModel, MTANLiteModel |
| from .lmic import LMICModel, LMICTSBNModel |
|
|
| |
# Lookup table from public architecture name to implementing class.
# Entry order matters to users: it is the order in which names are listed
# in the "Available: [...]" part of create_model()'s error message.
MODEL_REGISTRY: Dict[str, Type[BaseModel]] = {
    "mlp": MLPBest,  # from .mlp
    "cnn": CNNBest,  # from .cnn
    "mtan": SNRMTAN,  # from .mtan
    "hps": HPSModel,  # from .mtl
    "mtan_lite": MTANLiteModel,  # from .mtl
    "lmic": LMICModel,  # from .lmic
    "lmic_tsbn": LMICTSBNModel,  # from .lmic
}
|
|
|
|
def create_model(model_type: str, **kwargs) -> BaseModel:
    """
    Instantiate the model class registered under ``model_type``.

    This function only constructs the object; it does not call ``compile``
    on it.

    Args:
        model_type: Key to look up in ``MODEL_REGISTRY``.
        **kwargs: Forwarded unchanged to the selected class's constructor.

    Returns:
        A new instance of the class registered under ``model_type``.

    Raises:
        ValueError: If ``model_type`` is not a key of ``MODEL_REGISTRY``.
            The message lists the names registered at call time.
    """
    # Happy path first: known name -> look up the class and build it.
    if model_type in MODEL_REGISTRY:
        model_cls = MODEL_REGISTRY[model_type]
        return model_cls(**kwargs)

    # Unknown name: report it together with every currently valid choice.
    known_types = list(MODEL_REGISTRY.keys())
    raise ValueError(
        f"Unknown model type '{model_type}'. "
        f"Available: {known_types}"
    )
|
|
|
|
def register_model(name: str, model_class: Type[BaseModel]) -> None:
    """
    Register a model class in the registry.

    If ``name`` is already present in ``MODEL_REGISTRY``, its entry is
    replaced by ``model_class``.

    Args:
        name: String identifier for the model type.
        model_class: The model class (must subclass BaseModel).

    Raises:
        TypeError: If ``model_class`` is not a class, or is a class that
            does not subclass BaseModel.
    """
    # issubclass() raises its own TypeError ("arg 1 must be a class") when
    # handed a non-class such as an instance. Check for a class first so
    # that mistake gets the same descriptive message as a wrong base class.
    is_model_class = isinstance(model_class, type) and issubclass(
        model_class, BaseModel
    )
    if not is_model_class:
        raise TypeError(f"{model_class} must be a subclass of BaseModel")
    MODEL_REGISTRY[name] = model_class
|
|
|
|
# Names exported by ``from <package> import *``.
__all__ = [
    # Base class
    "BaseModel",
    # Model classes
    "MLPBest",
    "CNNBest",
    "SNRMTAN",
    "HPSModel",
    "MTANLiteModel",
    "LMICModel",
    "LMICTSBNModel",
    # Registry helpers and the registry itself
    "create_model",
    "register_model",
    "MODEL_REGISTRY",
]
|
|