lemousehunter
feat: LMIC-TSBN model + persistence fixes across restarts
cbb6546
"""
Model registry for ASCAD attack architectures.

Usage::

    from src.models import create_model
    model_wrapper = create_model("mlp")
    keras_model = model_wrapper.compile(learning_rate=1e-5)

Supported architectures:

- ``mlp``: Multi-Layer Perceptron (Prouff et al., 2019)
- ``cnn``: CNNbest 1-D convolutional network (Prouff et al., 2019)
- ``mtan``: Full SNR-MTAN with per-block attention (deprecated)
- ``hps``: Hard Parameter Sharing multi-task (Marquet & Oswald, 2024)
- ``mtan_lite``: Simplified MTAN with final-block attention (novel)
- ``lmic``: Localized Multi-Input CNN (16 per-byte POI inputs)
- ``lmic_tsbn``: LMIC with Task-Specific Batch Normalization (Suteu & Serban, 2025)
"""
from typing import Dict, Type
from .base import BaseModel
from .mlp import MLPBest
from .cnn import CNNBest
from .mtan import SNRMTAN
from .mtl import HPSModel, MTANLiteModel
from .lmic import LMICModel, LMICTSBNModel
# Lookup table from architecture identifier to the class implementing it.
# Entries are grouped by the module each class is imported from; the
# insertion order is preserved when the keys are listed.
MODEL_REGISTRY: Dict[str, Type[BaseModel]] = {
    # from .mlp / .cnn
    "mlp": MLPBest,
    "cnn": CNNBest,
    # from .mtan / .mtl
    "mtan": SNRMTAN,
    "hps": HPSModel,
    "mtan_lite": MTANLiteModel,
    # from .lmic
    "lmic": LMICModel,
    "lmic_tsbn": LMICTSBNModel,
}
def create_model(model_type: str, **kwargs) -> BaseModel:
    """
    Instantiate the model class registered under ``model_type``.

    Args:
        model_type: Key to look up in ``MODEL_REGISTRY``.
        **kwargs: Forwarded unchanged to the model class constructor.

    Returns:
        A new instance of the registered class. This function does not
        call ``compile`` on it.

    Raises:
        ValueError: If ``model_type`` is not a key of ``MODEL_REGISTRY``.
    """
    if model_type in MODEL_REGISTRY:
        model_class = MODEL_REGISTRY[model_type]
        return model_class(**kwargs)

    # Unknown key: report it together with every name currently registered.
    known_types = list(MODEL_REGISTRY.keys())
    message = (
        f"Unknown model type '{model_type}'. "
        f"Available: {known_types}"
    )
    raise ValueError(message)
def register_model(name: str, model_class: Type[BaseModel]) -> None:
    """
    Register a model class in ``MODEL_REGISTRY`` under ``name``.

    An existing entry with the same name is replaced.

    Args:
        name: String identifier for the model type.
        model_class: The model class (must subclass BaseModel).

    Raises:
        TypeError: If ``model_class`` is not a class, or is a class that
            does not subclass ``BaseModel``.
    """
    # issubclass() raises its own generic TypeError when handed a non-class
    # (for example an instance), so check for a class first and report every
    # invalid argument with the same descriptive message.
    is_model_class = isinstance(model_class, type) and issubclass(
        model_class, BaseModel
    )
    if not is_model_class:
        raise TypeError(f"{model_class} must be a subclass of BaseModel")
    MODEL_REGISTRY[name] = model_class
# Names exported by a star-import of this module.
__all__ = [
    # Model classes
    "BaseModel",
    "MLPBest",
    "CNNBest",
    "SNRMTAN",
    "HPSModel",
    "MTANLiteModel",
    "LMICModel",
    "LMICTSBNModel",
    # Factory, registration helper and the registry itself
    "create_model",
    "register_model",
    "MODEL_REGISTRY",
]