File size: 3,797 Bytes
747451d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
#  /*---------------------------------------------------------------------------------------------
#  * Copyright (c) 2022-2023 STMicroelectronics.
#  * All rights reserved.
#  *
#  * This software is licensed under terms that can be found in the LICENSE file in
#  * the root directory of this software component.
#  * If no LICENSE file comes with this software, it is provided AS-IS.
#  *--------------------------------------------------------------------------------------------*/

import os
import tensorflow as tf
from omegaconf import DictConfig
from tensorflow.keras import layers

def prepare_kwargs_for_model(cfg: DictConfig):
    """
    Gather the keyword arguments used to build a model from the configuration.

    Each value is read with `getattr`, so an attribute that cannot be found
    falls back to the default shown in parentheses below.

    Args:
        cfg: Configuration; reads `cfg.model`, `cfg.dataset`,
            `cfg.postprocessing` and `cfg.training`.

    Returns:
        dict with keys, in this order:
            'pretrained'  : cfg.model.pretrained (True)
            'num_classes' : cfg.dataset.num_classes (80)
            'model_type', 'width_mul', 'depth_mul', 'input_shape':
                the same-named cfg.model attributes (None)
            'num_anchors' : cfg.postprocessing.num_anchors (None)
            'dropout'     : cfg.training.dropout, or None when cfg.training is
                falsy or has no 'dropout' entry
    """
    # Dropout is only taken when a non-empty training section defines it.
    training_cfg = cfg.training
    if training_cfg and 'dropout' in training_cfg:
        dropout = training_cfg.dropout
    else:
        dropout = None

    model_cfg = cfg.model
    kwargs = {
        'pretrained': getattr(model_cfg, 'pretrained', True),
        'num_classes': getattr(cfg.dataset, 'num_classes', 80),
    }
    for attr_name in ('model_type', 'width_mul', 'depth_mul', 'input_shape'):
        kwargs[attr_name] = getattr(model_cfg, attr_name, None)
    kwargs['num_anchors'] = getattr(cfg.postprocessing, 'num_anchors', None)
    kwargs['dropout'] = dropout
    return kwargs


# Exact model type -> model family. A dict is used instead of `in ("name")`
# tests because a parenthesised single string is not a tuple: `in` would then
# do a substring match (e.g. "ssd" would match "ssd_mobilenet_v2_fpnlite").
_MODEL_FAMILIES = {
    "ssd_mobilenet_v2_fpnlite": "ssd_mobilenet_v2_fpnlite",
    "yolov2t": "yolo",
    "st_yololcv1": "yolo",
    "yolov8n": "yolov8n",
    "yolov11n": "yolov8n",
    "yolov5u": "yolov8n",
    "st_yoloxn": "st_yoloxn",
    "yolov4t": "yolov4",
    "yolov4": "yolov4",
    "face_detect_front": "face_detect_front",
    "st_yolod": "st_yolod",
    "ssd": "ssd",
}


def model_family(model_type: str) -> str:
    """
    Return the model family that a model type belongs to.

    Only exact model type names are recognized.

    Args:
        model_type: Model type name, e.g. "yolov11n".

    Returns:
        The family name, e.g. "yolov8n" for "yolov8n", "yolov11n" and "yolov5u".

    Raises:
        ValueError: If `model_type` is not a known model type (including
            None or non-hashable values).
    """
    try:
        return _MODEL_FAMILIES[model_type]
    except (KeyError, TypeError):
        # TypeError covers unhashable inputs; report them like any unknown type.
        raise ValueError(f"Internal error: unknown model type {model_type}") from None


def load_model_for_training(cfg: DictConfig) -> "tf.keras.Model | None":
    """
    Load the model of a previously interrupted training, if one is requested.

    The model to train can be a model from the Model Zoo, a user model (BYOM),
    or a model from a previously interrupted training. This function only
    handles the last case. If `cfg.model.resume_training_from` is not set, it
    returns None and the model has to be obtained by other means.

    The resume directory is `<resume_training_from>/<general.saved_models_dir>`
    and must contain:
        base_model.keras:
            Model saved before the training started. Weights are random.
        last_weights.weights.h5:
            Weights saved at the end of the last completed epoch.

    To resume, the base model is reloaded without compiling it and the last
    weights are loaded into it.

    Args:
        cfg: Configuration; reads `cfg.model.resume_training_from` and
            `cfg.general.saved_models_dir`.

    Returns:
        The reloaded, uncompiled Keras model carrying the last weights, or None
        if `cfg.model.resume_training_from` is not set.

    Raises:
        FileNotFoundError: If the resume directory, the base model file or the
            last weights file does not exist.
    """
    # Nothing to resume: the caller builds or loads the model itself.
    if not cfg.model.resume_training_from:
        return None

    resume_dir = os.path.join(cfg.model.resume_training_from, cfg.general.saved_models_dir)
    print(f"[INFO] : Resuming training from directory {resume_dir}\n")

    # Check every required file up front so the error names the missing one.
    message = "\nUnable to resume training."
    if not os.path.isdir(resume_dir):
        raise FileNotFoundError(f"\nCould not find resume directory {resume_dir}{message}")
    model_path = os.path.join(resume_dir, "base_model.keras")
    if not os.path.isfile(model_path):
        raise FileNotFoundError(f"\nCould not find model file {model_path}{message}\n")
    last_weights_path = os.path.join(resume_dir, "last_weights.weights.h5")
    if not os.path.isfile(last_weights_path):
        raise FileNotFoundError(f"\nCould not find model weights file {last_weights_path}{message}\n")

    # compile=False: the training code is expected to compile the model itself.
    model = tf.keras.models.load_model(model_path, compile=False)
    model.load_weights(last_weights_path)
    return model