# FBAGSTM's picture
# STM32 AI Experimentation Hub
# 747451d
# /*---------------------------------------------------------------------------------------------
# * Copyright (c) 2025 STMicroelectronics.
# * All rights reserved.
# * This software is licensed under terms that can be found in the LICENSE file in
# * the root directory of this software component.
# * If no LICENSE file comes with this software, it is provided AS-IS.
# *--------------------------------------------------------------------------------------------*/
from common.registries.model_registry import MODEL_WRAPPER_REGISTRY
from image_classification.tf.src.models import *
from image_classification.tf.src.models import prepare_kwargs_for_model
# Class count used as a fallback when the configuration supplies none
# (the name suggests the 1000 ImageNet classes).
NUM_IMAGENET_CLASSES = 1000

# Table of custom TensorFlow models to register.
#   key   -> model name the wrapper is registered under
#   value -> (model-building function, dict of static kwargs for that function)
# NOTE(review): the get_* builders are presumably provided by the wildcard
# import from image_classification.tf.src.models - confirm.
TF_CUSTOM_MODEL_FNS = {
    'custom_model': (get_custom_model, {}),
    'st_efficientnetlcv1': (get_st_efficientnetlcv1, {}),
    'st_fdmobilenetv1': (get_st_fdmobilenetv1, {}),
    'st_mnistv1': (get_st_mnistv1, {}),
}
def _register_tf_model_wrapper(model_fn, model_name, **model_init_kwargs):
    """
    Wrap a TensorFlow model builder and add the wrapper to MODEL_WRAPPER_REGISTRY.

    The wrapper is registered for framework 'tf' and use case
    'image_classification' under ``model_name``.

    Args:
        model_fn (callable): Model-building function. It is called with
            ``num_classes``, ``pretrained`` and the merged keyword arguments.
        model_name (str): Name to register the model under.
        **model_init_kwargs: Static keyword arguments forwarded to ``model_fn``.
            They take precedence over same-named config-derived arguments.

    Returns:
        function: The object returned by the registry decorator, with its
            ``__name__`` set to ``'<model_name>_tf'``.
    """
    @MODEL_WRAPPER_REGISTRY.register(
        framework='tf',
        model_name=model_name,
        use_case='image_classification',
    )
    def build_model_fn(cfg):
        """
        Build the model described by the top-level configuration.

        Args:
            cfg: Top-level config object, read by attribute access
                (``cfg.dataset``, ``cfg.model``).

        Returns:
            The value returned by ``model_fn`` (presumably a keras.Model).
        """
        # Keyword arguments derived from the config by the project helper.
        config_kwargs = prepare_kwargs_for_model(cfg)

        # Use the dataset's class count when a dataset section is present,
        # otherwise fall back to the module-level default.
        if cfg.dataset:
            num_classes = getattr(cfg.dataset, 'num_classes', NUM_IMAGENET_CLASSES)
        else:
            num_classes = NUM_IMAGENET_CLASSES

        # Pretrained weights are opt-in: default to False when not specified.
        pretrained = getattr(cfg.model, 'pretrained', False)

        # Static kwargs are unpacked last so they win on key conflicts.
        call_kwargs = {**config_kwargs, **model_init_kwargs}
        return model_fn(num_classes=num_classes, pretrained=pretrained, **call_kwargs)

    # Rename only after registration so each wrapper gets a distinct,
    # model-specific function name.
    build_model_fn.__name__ = f'{model_name}_tf'
    return build_model_fn
# Create and register one wrapper per entry of TF_CUSTOM_MODEL_FNS, then
# publish each wrapper as a module-level attribute named '<model_name>_tf'.
for _model_name, (_model_fn, _model_kwargs) in TF_CUSTOM_MODEL_FNS.items():
    wrapper_fn = _register_tf_model_wrapper(
        model_fn=_model_fn,
        model_name=_model_name,
        **_model_kwargs,
    )
    # The wrapper's __name__ is the key under which it is exposed.
    globals()[wrapper_fn.__name__] = wrapper_fn