# FBAGSTM's picture
# STM32 AI Experimentation Hub
# 747451d
# /*---------------------------------------------------------------------------------------------
# * Copyright (c) 2022-2023 STMicroelectronics.
# * All rights reserved.
# *
# * This software is licensed under terms that can be found in the LICENSE file in
# * the root directory of this software component.
# * If no LICENSE file comes with this software, it is provided AS-IS.
# *--------------------------------------------------------------------------------------------*/
import os
import sys
from hydra.core.hydra_config import HydraConfig
import hydra
import warnings
from timm import utils as timm_utils
from types import SimpleNamespace
warnings.filterwarnings("ignore")
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
os.environ['TF_ENABLE_ONEDNN_OPTS'] = '0'
import tensorflow as tf
import torch
from omegaconf import DictConfig
import mlflow
import argparse
import logging
from typing import Optional
from clearml import Task
from clearml.backend_config.defs import get_active_config_file
SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__))
sys.path.append(os.path.dirname(SCRIPT_DIR))
from api import get_model, get_dataloaders, get_trainer, get_quantizer, get_evaluator, get_predictor
from common.utils import mlflow_ini, set_gpu_memory_limit, get_random_seed, display_figures, log_to_file
from common.benchmarking import benchmark, cloud_connect
from common.evaluation import gen_load_val
from common.prediction import gen_load_val_predict
from image_classification.tf.src.utils import get_config
from image_classification.tf.src.deployment import deploy, deploy_mpu
from common.onnx_utils.onnx_model_convertor import torch_model_export_static
def _train_model(cfg: DictConfig, model, dataloaders):
    """Create a trainer with get_trainer() and return the result of its train() call."""
    trainer = get_trainer(cfg=cfg, model=model, dataloaders=dataloaders)
    return trainer.train()


def _evaluate_model(cfg: DictConfig, model, dataloaders):
    """Create an evaluator with get_evaluator() and return the result of its evaluate() call."""
    evaluator = get_evaluator(cfg=cfg, model=model, dataloaders=dataloaders)
    return evaluator.evaluate()


def _quantize_model(cfg: DictConfig, model, dataloaders):
    """Create a quantizer with get_quantizer() and return the result of its quantize() call."""
    quantizer = get_quantizer(cfg=cfg, model=model, dataloaders=dataloaders)
    return quantizer.quantize()


def _process_mode(cfg: DictConfig = None) -> None:
    """
    Process the selected mode of operation.

    The model and the dataloaders are created once, then the service (or chain of
    services) named by `cfg.operation_mode` is executed. Handled modes are:
    'training', 'evaluation', 'deployment', 'quantization', 'prediction',
    'benchmarking', 'chain_tqe', 'chain_tqeb', 'chain_eqe', 'chain_eqeb',
    'chain_qb' and 'chain_qd'.

    When the STEdgeAI tools are configured to run on the cloud, the credentials
    returned by cloud_connect() are forwarded to every benchmark/deploy call,
    standalone modes included, so that all modes behave consistently.

    Args:
        cfg (DictConfig): The configuration object.

    Returns:
        None

    Raises:
        ValueError: If `cfg.operation_mode` is not one of the handled modes.
    """
    mode = cfg.operation_mode
    mlflow.log_param("model_path", cfg.model.model_path)

    # Logging the operation_mode in the output_dir/stm32ai_main.log file
    log_to_file(cfg.output_dir, f'operation_mode: {mode}')

    # Connect to STM32Cube.AI Developer Cloud if needed
    credentials = None
    if cfg.tools and cfg.tools.stedgeai and cfg.tools.stedgeai.on_cloud:
        _, _, credentials = cloud_connect(stedgeai_core_version=cfg.tools.stedgeai.version)

    # Creates model
    model = get_model(cfg=cfg)
    saved_model_dir = os.path.join(cfg.output_dir, cfg.general.saved_models_dir)
    os.makedirs(saved_model_dir, exist_ok=True)

    # Export Torch models in onnx format for all services but training
    # (export to onnx is also handled at the end of the trainer.train() method)
    # NOTE(review): 'chain_tb' and 'chain_tbqeb' are listed here but have no branch
    # below, so they end up in the ValueError path - confirm whether that is intended.
    training_modes = ['training', 'chain_tb', 'chain_tqe', 'chain_tqeb', 'chain_tbqeb']
    if (cfg.model.framework == 'torch'
            and isinstance(model, torch.nn.Module)
            and mode not in training_modes):
        model = torch_model_export_static(cfg=cfg,
                                          model_dir=saved_model_dir,
                                          model=model)

    # Creates dataloaders
    dataloaders = get_dataloaders(cfg=cfg)

    # Executes Services
    if mode == 'training':
        trained_model = _train_model(cfg, model, dataloaders)
        display_figures(cfg)
        _evaluate_model(cfg, trained_model, dataloaders)
    elif mode == 'evaluation':
        gen_load_val(cfg=cfg, model=model)
        # Move back to this script's directory before evaluating
        os.chdir(os.path.dirname(os.path.realpath(__file__)))
        _evaluate_model(cfg, model, dataloaders)
    elif mode == 'deployment':
        # NOTE(review): this branch selects on "MPU" while 'chain_qd' selects on "MCU";
        # the two only differ for other hardware_type values - confirm which is wanted.
        if cfg.hardware_type == "MPU":
            deploy_mpu(cfg=cfg, model_path_to_deploy=model.model_path, credentials=credentials)
        else:
            deploy(cfg=cfg, model_path_to_deploy=model.model_path, credentials=credentials)
    elif mode == 'quantization':
        _quantize_model(cfg, model, dataloaders)
    elif mode == 'prediction':
        gen_load_val_predict(cfg=cfg, model=model)
        # Move back to this script's directory before predicting
        os.chdir(os.path.dirname(os.path.realpath(__file__)))
        predictor = get_predictor(cfg=cfg,
                                  model=model,
                                  dataloaders=dataloaders)
        predictor.predict()
    elif mode == 'benchmarking':
        benchmark(cfg=cfg, model_path_to_benchmark=model.model_path, credentials=credentials)
    elif mode == 'chain_tqe':
        trained_model = _train_model(cfg, model, dataloaders)
        _evaluate_model(cfg, trained_model, dataloaders)
        quantized_model = _quantize_model(cfg, trained_model, dataloaders)
        _evaluate_model(cfg, quantized_model, dataloaders)
        print('[INFO] : chain_tqe complete.')
    elif mode == 'chain_tqeb':
        trained_model = _train_model(cfg, model, dataloaders)
        _evaluate_model(cfg, trained_model, dataloaders)
        quantized_model = _quantize_model(cfg, trained_model, dataloaders)
        _evaluate_model(cfg, quantized_model, dataloaders)
        benchmark(cfg=cfg, model_path_to_benchmark=quantized_model.model_path, credentials=credentials)
        print('[INFO] : chain_tqeb complete.')
    elif mode == 'chain_eqe':
        _evaluate_model(cfg, model, dataloaders)
        quantized_model = _quantize_model(cfg, model, dataloaders)
        _evaluate_model(cfg, quantized_model, dataloaders)
        print('[INFO] : chain_eqe complete.')
    elif mode == 'chain_eqeb':
        _evaluate_model(cfg, model, dataloaders)
        quantized_model = _quantize_model(cfg, model, dataloaders)
        _evaluate_model(cfg, quantized_model, dataloaders)
        benchmark(cfg=cfg, model_path_to_benchmark=quantized_model.model_path, credentials=credentials)
        print('[INFO] : chain_eqeb complete.')
    elif mode == 'chain_qb':
        quantized_model = _quantize_model(cfg, model, dataloaders)
        benchmark(cfg=cfg, model_path_to_benchmark=quantized_model.model_path, credentials=credentials)
        print('[INFO] : chain_qb complete.')
    elif mode == 'chain_qd':
        quantized_model = _quantize_model(cfg, model, dataloaders)
        if cfg.hardware_type == "MCU":
            deploy(cfg=cfg, model_path_to_deploy=quantized_model.model_path, credentials=credentials)
        else:
            deploy_mpu(cfg=cfg, model_path_to_deploy=quantized_model.model_path, credentials=credentials)
        print('[INFO] : chain_qd complete.')
    # Raise an error if an invalid mode is selected
    else:
        raise ValueError(f"Invalid mode: {mode}")

    # Record the whole hydra working directory to get all info
    mlflow.log_artifact(cfg.output_dir)
    if mode in ['benchmarking', 'chain_qb', 'chain_eqeb', 'chain_tqeb']:
        mlflow.log_param("stedgeai_core_version", cfg.tools.stedgeai.version)
        mlflow.log_param("target", cfg.benchmarking.board)

    # Logging the completion of the chain
    log_to_file(cfg.output_dir, f'operation finished: {mode}')

    # ClearML - Example how to get task's context anywhere in the file.
    # Checks if there's a valid ClearML configuration file
    if get_active_config_file() is not None:
        print("[INFO] : ClearML task connection")
        task = Task.current_task()
        task.connect(cfg)
def _fw_agnostic_initializations(cfg: DictConfig = None) -> DictConfig:
    """
    Run the initializations that do not depend on the deep learning framework.

    The configuration is passed through get_config(), its `output_dir` is set to
    the Hydra run directory, MLFlow is initialized through mlflow_ini(), and a
    ClearML task is created when an active ClearML configuration file is found.

    Args:
        cfg (DictConfig): Configuration object.

    Returns:
        DictConfig: The configuration object returned by get_config(), with
        `output_dir` set.
    """
    # Parse the configuration and point the output directory at the Hydra run dir
    cfg = get_config(cfg)
    cfg.output_dir = HydraConfig.get().run.dir

    # Set up MLFlow experiment tracking
    mlflow_ini(cfg)

    # Nothing more to do when no ClearML configuration file is active
    print("[INFO] : ClearML config check")
    if get_active_config_file() is None:
        return cfg

    # Create the ClearML task, then attach the configuration to it under the
    # name of the selected operation mode
    print("[INFO] : ClearML initialization and configuration")
    clearml_task = Task.init(project_name=cfg.general.project_name,
                             task_name='ic_modelzoo_task')
    clearml_task.connect_configuration(name=cfg.operation_mode,
                                       configuration=cfg)
    return cfg
def _tf_specific_initializations(cfg: DictConfig = None) -> None:
    """
    Run the TensorFlow-specific initializations.

    Applies the GPU memory limit from the 'general' section when one is given
    (and prints a warning when the section exists without a limit), then seeds
    TensorFlow/Keras with the value returned by get_random_seed() unless it is None.

    Args:
        cfg (DictConfig): Configuration object.
    """
    # GPU memory is only looked at when a non-empty 'general' section is present
    # NOTE(review): nesting was reconstructed from a listing without indentation;
    # the warning is paired with the 'gpu_memory_limit' check - confirm.
    has_general_section = "general" in cfg and cfg.general
    if has_general_section:
        limit_is_set = "gpu_memory_limit" in cfg.general and cfg.general.gpu_memory_limit
        if not limit_is_set:
            # No limit configured: tell the user how to set one
            print("[WARNING] The usable GPU memory is unlimited.\n"
                  "Please consider setting the 'gpu_memory_limit' attribute "
                  "in the 'general' section of your configuration file.")
        else:
            set_gpu_memory_limit(cfg.general.gpu_memory_limit)
            print(f"[INFO] : Setting upper limit of usable GPU memory to {int(cfg.general.gpu_memory_limit)}GBytes.")

    # Seed the random generators for reproducibility
    random_seed = get_random_seed(cfg)
    print(f'[INFO] : The random seed for this simulation is {random_seed}')
    if random_seed is None:
        return
    tf.keras.utils.set_random_seed(random_seed)
def _torch_specific_initializations(cfg: DictConfig = None) -> None:
    """
    Run the PyTorch-specific initializations.

    Picks 'cuda' when available (otherwise 'cpu'), hands that choice to timm's
    init_distributed_device() through a throw-away namespace, and copies the
    resulting `device`, `world_size`, `rank`, `local_rank` and `distributed`
    attributes of that namespace onto the configuration.

    Args:
        cfg (DictConfig): Configuration object; updated in place.
    """
    initial_device = 'cuda' if torch.cuda.is_available() else 'cpu'
    dist_args = SimpleNamespace(device=initial_device)

    # The returned device object is not needed: only the attributes the call
    # writes onto the namespace are read back below.
    timm_utils.init_distributed_device(dist_args)

    for field in ('device', 'world_size', 'rank', 'local_rank', 'distributed'):
        setattr(cfg, field, getattr(dist_args, field))
@hydra.main(version_base=None, config_path="", config_name="user_config")
def main(cfg: DictConfig) -> None:
    """
    Main entry point of the script.

    Runs the framework-agnostic initializations, then the initializations of the
    framework named by `cfg.model.framework`, and finally the requested service.

    Args:
        cfg: Configuration dictionary.

    Returns:
        None

    Raises:
        ValueError: If `cfg.model.framework` is neither "tf" nor "torch".
    """
    # Framework agnostic initializations
    cfg = _fw_agnostic_initializations(cfg)

    # Framework specific initializations, selected by framework name
    framework_initializers = {
        "tf": _tf_specific_initializations,
        "torch": _torch_specific_initializations,
    }
    initialize_framework = framework_initializers.get(cfg.model.framework)
    if initialize_framework is None:
        raise ValueError(f"Invalid framework used: {cfg.model.framework}")
    initialize_framework(cfg)

    # Executes the required service
    _process_mode(cfg=cfg)
if __name__ == "__main__":
    # Command-line options whose values are recorded in MLFlow after the run,
    # plus a catch-all positional for configuration overrides.
    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--config-path',
        type=str,
        default='./',
        help='Path to folder containing configuration file',
    )
    parser.add_argument(
        '--config-name',
        type=str,
        default='user_config.yaml',
        help='name of the configuration file',
    )
    parser.add_argument(
        'params',
        nargs='*',
        help='List of parameters to over-ride in config.yaml',
    )
    args = parser.parse_args()

    # Run the requested service
    main()

    # Log the config_path and config_name parameters, then close the MLFlow run
    for logged_option in ('config_path', 'config_name'):
        mlflow.log_param(logged_option, getattr(args, logged_option))
    mlflow.end_run()