import tensorflow as tf
import onnxruntime


def load_model_from_path(cfg, model_path):
| """
|
| Loads a model from the given file path.
|
|
|
| Supported formats:
|
| - Keras (.h5, .keras): Returns a compiled Keras model.
|
| - TFLite (.tflite): Returns a TensorFlow Lite Interpreter.
|
| - ONNX (.onnx): Returns an ONNX Runtime InferenceSession.
|
|
|
| Args:
|
| model_path (str): Path to the model file.
|
|
|
| Returns:
|
| model object: Loaded model (Keras model, TFLite Interpreter, or ONNX InferenceSession) with 'model_path' and 'input_shape' attributes.
|
|
|
| Raises:
|
| ValueError: If the file format is not supported.
|
| """
|
    file_extension = str(model_path).split('.')[-1]
    model = None
    input_shape = None

    # The compiled state (optimizer, loss) is only needed when resuming training
    if cfg.training:
        resume_training_from = getattr(cfg.model, 'resume_training_from', None)
    else:
        resume_training_from = None

    if file_extension in ['h5', 'keras']:
        if resume_training_from:
            model = tf.keras.models.load_model(model_path, compile=True)
        else:
            model = tf.keras.models.load_model(model_path, compile=False)
        # Drop the batch dimension from the Keras input shape
        input_shape = tuple(model.inputs[0].shape)[1:]

    elif file_extension == 'tflite':
        # Number of CPU threads used by the TFLite interpreter (defaults to 1)
        num_threads = getattr(cfg.general, 'num_threads_tflite', 1)
        model = tf.lite.Interpreter(model_path, num_threads=num_threads)
        model.allocate_tensors()
        input_details = model.get_input_details()
        input_shape = tuple(input_details[0]['shape'])

    elif file_extension == 'onnx':
        model = onnxruntime.InferenceSession(model_path)
        input_shape = tuple(model.get_inputs()[0].shape)

    else:
        raise ValueError(f"Unsupported model file format: {file_extension}. "
                         "Supported formats are: h5, keras, tflite, onnx.")

    setattr(model, 'model_path', model_path)
    cfg.model.input_shape = input_shape
    return model
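

# ---------------------------------------------------------------------------
# Usage sketch (assumptions): the config layout below (cfg.training,
# cfg.general.num_threads_tflite, cfg.model.resume_training_from,
# cfg.model.input_shape) mirrors only the attributes load_model_from_path()
# reads or writes; the project's real config object may differ. The model
# path is a placeholder, not a file shipped with this module.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from types import SimpleNamespace

    cfg = SimpleNamespace(
        training=False,
        general=SimpleNamespace(num_threads_tflite=2),
        model=SimpleNamespace(resume_training_from=None, input_shape=None),
    )

    # Placeholder path; point this at a real .h5/.keras/.tflite/.onnx file.
    loaded = load_model_from_path(cfg, "pretrained/model.tflite")

    # The loader attaches the source path to the model object and writes the
    # detected input shape back to the config.
    print("model_path :", loaded.model_path)
    print("input_shape:", cfg.model.input_shape)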