File size: 34,059 Bytes
747451d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
#  /*---------------------------------------------------------------------------------------------
#  * Copyright (c) 2022-2023 STMicroelectronics.
#  * All rights reserved.
#  *
#  * This software is licensed under terms that can be found in the LICENSE file in
#  * the root directory of this software component.
#  * If no LICENSE file comes with this software, it is provided AS-IS.
#  *--------------------------------------------------------------------------------------------*/

import os
from typing import Dict, List
from munch import DefaultMunch
from omegaconf import DictConfig
from pathlib import Path
import re
import numpy as np
import requests
from hydra.core.hydra_config import HydraConfig


# Maps `aspect_ratio` configuration values to their ASPECT_RATIO_* identifier strings.
aspect_ratio_dict = dict(
    fit="ASPECT_RATIO_FIT",
    crop="ASPECT_RATIO_CROP",
    padding="ASPECT_RATIO_PADDING",
    full_screen="ASPECT_RATIO_FULLSCREEN",
)

# Maps color-mode configuration values to their COLOR_* identifier strings.
color_mode_n6_dict = dict(
    rgb="COLOR_RGB",
    bgr="COLOR_BGR",
)
                      


def download_file(url: str, local_path: str):
    """
    Downloads a file from the given URL and saves it to the specified local path.

    This is a best-effort download: errors are printed, not raised. If the transfer
    fails after the local file was opened, the partially written file is deleted so it
    cannot be mistaken for a complete model file.
    args:
        url (str): URL of the file to download
        local_path(str): Local path where the file should be saved
    """
    # Set once `local_path` has been opened (and therefore truncated) for writing.
    opened = False
    try:
        # stream=True keeps the connection open until the body is consumed; using the
        # response as a context manager guarantees the connection is released.
        with requests.get(url, stream=True, timeout=20) as response:  # timeout of 20 seconds
            # Check if the request was successful
            response.raise_for_status()

            # Open the local file in write-binary mode
            with open(local_path, 'wb') as file:
                opened = True
                # Write the response content chunk by chunk
                for chunk in response.iter_content(chunk_size=8192):
                    file.write(chunk)

        print(f"[INFO] : Pretrained model file downloaded successfully and saved as :\n\t{local_path}")

    except requests.exceptions.HTTPError as http_err:
        # raise_for_status() fails before the file is opened: nothing to clean up.
        print(f"[ERROR] : HTTP error occurred :\n\t{http_err}")
    except Exception as err:
        if opened:
            # 'wb' already truncated any previous copy, so removing the partial file
            # loses nothing; cleanup itself is best-effort.
            try:
                os.remove(local_path)
            except OSError:
                pass
        print(f"[ERROR] : An error occurred while downloading the pretrained model \n\t: {err}")


def check_attributes(cfg: Dict,
                     expected: List[str] = None,
                     optional: List[str] = None,
                     section: str = None) -> None:
    '''
    Checks that all the expected attributes are present in the configuration dictionary and that
    there are no unknown attributes. Optional attributes may also be present.
    args:
        cfg (dict): The configuration dictionary (a DefaultMunch), or None for an empty section.
        expected (list, optional): A list of expected (required) attributes. Defaults to None (none required).
        optional (list, optional): A list of optional attributes. Defaults to None (none allowed).
        section (str, optional): The name of the config file section to check. Defaults to None.
    raises:
        ValueError: If `cfg` is not a DefaultMunch, if an unknown or unsupported attribute is found,
            or if a required attribute is missing or has no value.
    '''
    # None defaults avoid a shared mutable default list and make an omitted
    # `expected` behave as "nothing required" (iterating None raised TypeError).
    expected = [] if expected is None else expected
    optional = [] if optional is None else optional

    if section is not None:
        message = "\nPlease check the '{}' section of your configuration file.".format(section)
    else:
        message = "\nPlease check your configuration file."

    if cfg is not None:
        if type(cfg) is not DefaultMunch:
            raise ValueError("Expecting an attribute. Received {}{}".format(cfg, message))
        # Check that each attribute name is legal.
        for attr in cfg.keys():
            if attr not in expected and attr not in optional:
                raise ValueError("\nUnknown or unsupported attribute. Received `{}`{}".format(attr, message))
        # Get the list of used attributes
        used = list(cfg.keys())
    else:
        used = []

    # Check that all the mandatory attributes are present and have a value.
    for attr in expected:
        if attr not in used:
            raise ValueError("\nMissing `{}` attribute{}".format(attr, message))
        if cfg[attr] is None:
            raise ValueError("\nMissing a value for attribute `{}`{}".format(attr, message))
        

def collect_callback_args(name,
                          args=None,
                          message=None) -> str:
    '''
    Renders callback arguments as the text of a Python call argument list.

    String values are emitted as raw-string literals (r"...") unless they start with
    "lambda "; lambda strings and all non-string values are inserted as-is.
    For example {"monitor": "val_loss", "patience": 3} renders as
    '(monitor=r"val_loss", patience=3)'.
    args:
        name (str): callback name, used in the error message.
        args (DefaultMunch, optional): callback arguments. None or empty gives "()".
        message (str, optional): suffix appended to the error message.
    returns:
        str: the parenthesized argument list.
    raises:
        ValueError: if `args` is non-empty and is not a DefaultMunch.
    '''
    if not args:
        return "()"
    if type(args) != DefaultMunch:
        raise ValueError(f"\nInvalid syntax for `{name}` callback arguments{message}")
    rendered = []
    for key, value in args.items():
        if type(value) is str and not value.startswith("lambda "):
            rendered.append(f'{key}=r"{value}"')
        else:
            rendered.append(f'{key}={value}')
    return "(" + ", ".join(rendered) + ")"


def get_random_seed(cfg: DictConfig = None):
    '''
    Returns the random seed configured in the 'general' section.
    args:
        cfg (DictConfig): The configuration object; it must provide a `general` section.
    returns:
        int or None: `int(general.global_seed)` when a seed is set; None when the seed is
        None or the string "None" (no seeding); 0 when `global_seed` is absent.
    '''
    if "global_seed" not in cfg.general:
        return 0
    seed = cfg.general["global_seed"]
    # A YAML null and the literal string "None" both mean "no seed";
    # int(None) would otherwise raise a TypeError.
    if seed is None or seed == "None":
        return None
    return int(seed)


def replace_none_string(dico: dict) -> dict:
    '''
    Replaces, in place, every value equal to the string "None" with the Python None.
    Other values (including "none" in other casings) are left unchanged.
    args:
        dico (dict): any dictionary; it is modified in place.
    returns:
        dict: the same dictionary object that was passed in.
    '''
    none_keys = [key for key, value in dico.items() if value == "None"]
    for key in none_keys:
        dico[key] = None
    return dico


def expand_env_vars(string: str) -> str:
    '''
    Expands environment variables in a string if any. The syntax for variables
    is ${variable_name}. An error is thrown if a variable is found in the string
    but is not set.
    Values are inserted literally (backslashes, e.g. in Windows paths, are kept as-is)
    and are not themselves re-expanded.
    args:
        string (str): The string to expand environment variables in.
    returns:
        string (str): The original string with expanded variables.
    raises:
        OSError: if a referenced environment variable is not set.
    '''
    def _substitute(match: "re.Match") -> str:
        # Look up one ${NAME} occurrence; a function replacement makes re.sub insert
        # the value verbatim instead of parsing it as a template.
        var_name = match.group(1)
        var_value = os.environ.get(var_name)
        if var_value is None:
            raise OSError("\nCould not find an environment variable named `{}`\n"
                          "Please check your configuration file.".format(var_name))
        return var_value

    return re.sub(r'\$\{(\w+)\}', _substitute, string)


def postprocess_config_dict(config: DictConfig, replace_none_string=False) -> None:
    '''
    The YAML loader outputs some attribute values as strings although they
    are different Python types. This function walks the config dictionary
    tree (recursing into nested dictionaries, not into lists) and restores
    some of these types in place, including booleans, integers, floats and tuples.
    For example:
    - "True" is converted to boolean True
    - "1./255" is converted to a float (equal to 0.00392156)
    - "(128, 128, 3)" is converted to a tuple.
    Strings starting with "lambda " are kept as strings, with newlines removed
    and runs of blanks collapsed to one.
    The function also replaces environment variables that appear in strings
    with their values.
    args:
        config (DictConfig): dictionary containing the entire configuration file.
        replace_none_string (bool): if True, string values equal to "none"
            (case-insensitive) are replaced with None.
    raises:
        OSError: from expand_env_vars, if a referenced environment variable is not set.
    '''
    for k in config.keys():
        v = config[k]
        if isinstance(v, dict):
            postprocess_config_dict(v, replace_none_string=replace_none_string)
        elif isinstance(v, str):
            if replace_none_string and v.lower() == "none":
                config[k] = None
                continue
            # Expand environment variables if any
            v_exp = expand_env_vars(v)
            if v_exp != v:
                config[k] = v_exp
                v = v_exp
            if v[:7] == "lambda ":
                # The value is a lambda function. Remove the \n characters
                # and multiple blanks that get inserted by the YAML loader
                # if the function is written on several lines.
                v = re.sub("\n", "", v)
                config[k] = re.sub(" +", " ", v)
            else:
                # SECURITY: eval() executes arbitrary Python found in the config file;
                # only trusted configuration files should be processed here.
                try:
                    v_eval = eval(v)
                except Exception:
                    # Not a valid Python expression: keep the plain string.
                    v_eval = v
                if isinstance(v_eval, (bool, int, float, tuple)):
                    config[k] = v_eval


def check_config_attributes(cfg: DictConfig,
                            specs: Dict = None,
                            section: str = None) -> None:
    '''
    Validates attribute usage in one section of the configuration file.

    `specs` may define three rule lists; a missing or empty rule is skipped:
      - legal: the only attribute names allowed in the section,
      - all: attributes that must be present with a non-None value,
      - one_or_more: at least one of these must be present with a non-None value.
    args:
        cfg (DictConfig): dictionary containing the configuration file section to check
        specs (Dict): the rules above; converted with DefaultMunch.fromDict
        section (str): name of the section; "top_level" switches messages to top-level wording
    raises:
        ValueError: on the first rule violation found (legal, then all, then one_or_more).
    '''
    rules = DefaultMunch.fromDict(specs)
    top_level = section == "top_level"
    if top_level:
        message = "\nPlease check the top-level of your configuration file."
    else:
        message = f"\nPlease check the '{section}' section of your configuration file."

    allowed = rules.legal
    if allowed:
        for name in cfg.keys():
            if name not in allowed:
                raise ValueError(f"\nUnknown attribute `{name}`{message}")

    mandatory = rules.all
    if mandatory:
        kind = "section" if top_level else "attribute"
        for name in mandatory:
            if name not in cfg:
                raise ValueError(f"\nMissing `{name}` {kind}{message}")
            if cfg[name] is None:
                if top_level:
                    raise ValueError(f"\nMissing body of `{name}` section{message}")
                raise ValueError(f"\nExpecting a value for `{name}` attribute{message}")

    candidates = rules.one_or_more
    if candidates:
        provided = [name for name in candidates if name in cfg and cfg[name] is not None]
        if not provided:
            raise ValueError(f"\nMissing one or more attributes from {candidates}{message}")


def check_model_file_extension(ml_path, mode, mode_groups, field_name):
    """
    Validates the extension and existence of a model file for the current operation mode.

    Allowed extensions (case-insensitive) per mode:
      - modes in `mode_groups.training`: .h5, .keras
      - modes in `mode_groups.quantization`: .h5, .keras, .onnx
      - evaluation, prediction, benchmarking, deployment: .h5, .keras, .tflite, .onnx
      - any other mode: no extension constraint.

    Args:
        ml_path (str): Path to the model file to check.
        mode (str): The current operation mode (e.g., 'training', 'quantization', etc.).
        mode_groups (Any): Object exposing `training` and `quantization` lists of mode names.
        field_name (str): The name of the config field being checked (for error messages).

    Raises:
        ValueError: If the path is empty or its extension is not allowed for the mode.
        FileNotFoundError: If no file exists at the given path.
    """
    prefix = f"\nExpecting `{field_name}` attribute to be set to a path to "
    suffix = "\nPlease check the 'model' section of your configuration file."
    if not ml_path:
        raise ValueError(prefix + "a valid file path" + suffix)

    extension = Path(ml_path).suffix.lower()
    if mode in mode_groups.training:
        allowed = [".h5", ".keras"]
    elif mode in mode_groups.quantization:
        allowed = [".h5", ".keras", ".onnx"]
    elif mode in ("evaluation", "prediction", "benchmarking", "deployment"):
        allowed = [".h5", ".keras", ".tflite", ".onnx"]
    else:
        allowed = None

    if allowed is not None and extension not in allowed:
        raise ValueError(prefix + ", ".join(allowed) + suffix)

    if not os.path.isfile(ml_path):
        raise FileNotFoundError(
            f"\nUnable to find file {ml_path}\nPlease check the '{field_name}' attribute in your configuration file"
        )
        

def parse_tools_section(cfg: DictConfig,
                        operation_mode: str,
                        hardware_type: str = "MCU") -> None:
    '''
    Parses the 'tools' section of the configuration file.

    When a `stedgeai` sub-section is present, its attributes are validated and a legacy
    `stm32ai` entry is added to `cfg` that refers to the same object as `cfg.stedgeai`,
    with `optimization`, `on_cloud`, `path_to_stm32ai` and `version` filled in.
    args:
        cfg (DictConfig): 'tools' section of configuration file
        operation_mode (str): service or operation mode used
        hardware_type (str): type of hardware targeted
    raises:
        ValueError: if an attribute is unknown or missing, if `stedgeai` is used locally
            (on_cloud false) without an existing executable, or if `path_to_cubeIDE` does not
            point to an existing file for an MCU deployment.
    '''
    required = []
    if hardware_type == "MCU" and operation_mode not in ("evaluation", "prediction"):
        required += ["path_to_cubeIDE"]

    if cfg.stedgeai:
        legal = ["stedgeai", "path_to_cubeIDE"]
        check_config_attributes(cfg, specs={"legal": legal, "all": required}, section="tools")

        # stedgeai usage
        legal = ["optimization", "on_cloud", "path_to_stedgeai"]
        check_config_attributes(cfg.stedgeai,
                                specs={"legal": legal, "all": []}, section="tools.stedgeai")
        stedgeai_path = cfg.stedgeai.path_to_stedgeai
        if not cfg.stedgeai.on_cloud:
            # Guard against an unset path so the user gets this explicit error instead
            # of a TypeError from os.path.isfile(None).
            if not stedgeai_path or not os.path.isfile(stedgeai_path):
                print(stedgeai_path)
                raise ValueError("Path for `stedgeai.exe` does not exist.\n"
                                 "Please check the cfg.tools.stedgeai section!")

        # The version is read from the 4th path component from the end. Presumably the
        # install layout is <...>/<version>/<dir>/<dir>/<executable> (verify).
        # It is None when no path is given (e.g. cloud-only use) or the path is too short,
        # instead of crashing in Path(None) / parts[-4].
        stedgeai_version = None
        if stedgeai_path:
            parts = Path(stedgeai_path).parts
            if len(parts) >= 4:
                stedgeai_version = parts[-4]

        # Patch to support stedgeai with legacy naming stm32ai: `stm32ai` is the same
        # object as `stedgeai`, so the assignments below update both names.
        cfg["stm32ai"] = cfg.stedgeai
        cfg.stm32ai["optimization"] = cfg.stedgeai.optimization if cfg.stedgeai.optimization else "balanced"
        cfg.stm32ai["on_cloud"] = cfg.stedgeai.on_cloud
        cfg.stm32ai["path_to_stm32ai"] = stedgeai_path if stedgeai_path else None
        cfg.stm32ai["version"] = stedgeai_version
        cfg.stedgeai["version"] = stedgeai_version

    # Path to cubeIDE only needed for MCU in deployment service
    if hardware_type == "MCU" and operation_mode == "deployment":
        if not cfg.path_to_cubeIDE or not os.path.isfile(cfg.path_to_cubeIDE):
            raise ValueError("Path for `path_to_cubeIDE` does not exist.\n"
                             "Please check the cfg.tools section!")


def parse_benchmarking_section(cfg: DictConfig) -> None:
    '''
    Validates the 'benchmarking' section: `board` is the only allowed attribute,
    and it is mandatory.
    args:
        cfg (DictConfig): 'benchmarking' section of the configuration file
    '''
    board_only = ["board"]
    rules = {"legal": board_only, "all": board_only}
    check_config_attributes(cfg, specs=rules, section="benchmarking")


def parse_quantization_section(cfg: DictConfig,
                               legal: List) -> None:
    '''
    Parses the 'quantization' section of the configuration file. It checks attribute
    usage, fills in defaults for optional attributes and validates their values.
    args:
        cfg (DictConfig): 'quantization' section of the configuration file
        legal (List): UC specific usable attributes. Every legal attribute that is not
            in the optional list below is required.
    raises:
        ValueError: if an attribute is unknown or missing, or has an unsupported value.
    '''
    optional = ["export_dir", "granularity", "optimize", "target_opset", "operating_mode",
                "onnx_quant_parameters", "op_types_to_quantize", "onnx_extra_options",
                "iterative_quant_parameters"]
    required = [x for x in legal if x not in optional]
    check_config_attributes(cfg, specs={"legal": legal, "all": required}, section="quantization")

    # Set default values of missing optional arguments
    if not cfg.export_dir:
        cfg.export_dir = "quantized_models"
    if not cfg.granularity:
        cfg.granularity = "per_channel"
    cfg.optimize = cfg.optimize if cfg.optimize is not None else False
    cfg.target_opset = cfg.target_opset if cfg.target_opset is not None else 17
    cfg.operating_mode = cfg.operating_mode if cfg.operating_mode else 'default'

    # Check the quantizer name
    if cfg.quantizer.lower() not in ["tflite_converter", "onnx_quantizer"]:
        raise ValueError(f"\nUnknown or unsupported quantizer. Received `{cfg.quantizer}`\n"
                         "Supported quantizers are : TFlite_converter or Onnx_quantizer\n"
                         "Please check the 'quantization.quantizer' attribute in your configuration file.")

    # Check the granularity value
    if cfg.granularity not in ['per_channel', 'per_tensor']:
        raise ValueError(f"\nUnknown or unsupported granularity value. Received `{cfg.granularity}`\n"
                         "Supported granularity: 'per_channel', or 'per_tensor'\n"
                         "Please check the 'quantization.granularity' attribute in your configuration file.")
    # Check optimization
    if cfg.optimize not in [True, False]:
        raise ValueError(f"\nUnknown or unsupported optimize value. Received `{cfg.optimize}`\n"
                         "Supported optimize parameters: 'True', or 'False'\n"
                         "Please check the 'quantization.optimize' attribute in your configuration file.")

    # Check the target opset (the error reports the offending target_opset value)
    if not isinstance(cfg.target_opset, int):
        raise ValueError(f"\nUnknown or unsupported target_opset value. Received `{cfg.target_opset}`\n"
                         "Supported target_opset parameters: 'int' up to latest onnx_opset\n"
                         "Please check the 'quantization.target_opset' attribute in your configuration file.")
    # Check the quantizer type
    if cfg.quantization_type.lower() not in ["ptq"]:
        raise ValueError(f"\nUnknown or unsupported quantization type. Received `{cfg.quantization_type}`\n"
                         "Supported type: PTQ\n"
                         "Please check the 'quantization.quantization_type' attribute in your configuration file.")


def parse_evaluation_section(cfg: DictConfig,
                             legal: List) -> None:
    '''
    Validates the 'evaluation' section (no attribute is mandatory) and replaces every
    falsy (missing, None, empty) optional attribute with its default value.
    args:
        cfg (DictConfig): 'evaluation' section of the configuration file
        legal (List): UC specific usable attributes
    '''
    check_config_attributes(cfg, specs={"legal": legal, "all": []}, section="evaluation")

    # (attribute, default) pairs, applied in this order.
    defaults = (
        ("gen_npy_input", False),
        ("gen_npy_output", False),
        ("profile", "profile_O3"),
        ("input_type", "uint8"),
        ("output_type", "int8"),
        ("input_chpos", "chlast"),
        ("output_chpos", "chlast"),
        ("target", "host"),
    )
    for attribute, default in defaults:
        if not getattr(cfg, attribute):
            setattr(cfg, attribute, default)


def parse_top_level(cfg: DictConfig,
                    mode_groups: DictConfig = None,
                    mode_choices: List = None,
                    legal: List = None) -> None:
    '''
    Validates the top level of the configuration file.

    Checks that `operation_mode` is set to one of `mode_choices`. From the selected
    mode it derives which top-level sections are required, then checks them together
    with the `legal` top-level names.
    args:
        cfg (DictConfig): configuration dictionary containing entire configuration file
        mode_groups (DictConfig): lists of modes per group (training, quantization,
            benchmarking, deployment, compression)
        mode_choices (List): currently supported modes
        legal (List): UC specific usable attributes
    raises:
        ValueError: if `operation_mode` is missing, empty or unsupported, or if a required
            section is missing or an unknown top-level attribute is used.
    '''
    hint = "\nPlease check the top-level of your configuration file."
    supported = f"Supported modes: {mode_choices}{hint}"

    if "operation_mode" not in cfg:
        raise ValueError("\nMissing `operation_mode` attribute\n" + supported)
    mode = cfg.operation_mode
    if mode is None:
        raise ValueError("\nExpecting a value for `operation_mode` attribute\n" + supported)
    if mode not in mode_choices:
        raise ValueError(f"\nUnknown value for `operation_mode` attribute. Received {mode}\n" + supported)

    # TODO [KH]: also require the 'model' section once every use case provides it.
    required = ["mlflow"]
    if mode != "benchmarking":
        # Every mode but benchmarking needs preprocessing, plus feature extraction
        # when the use case supports it.
        required.append("preprocessing")
        if "feature_extraction" in legal:
            required.append("feature_extraction")
    if mode not in ("prediction", "quantization", "benchmarking", "deployment", "chain_qb", "chain_qd"):
        required.append("dataset")

    # Sections required by each mode group, appended in this order.
    group_requirements = (
        (mode_groups.training, ("training",)),
        (mode_groups.quantization, ("quantization",)),
        (mode_groups.benchmarking, ("benchmarking", "tools")),
        (mode_groups.deployment, ("deployment", "tools")),
        # Compression includes a fine-tuning step, hence the 'training' section.
        (mode_groups.compression, ("compression", "training")),
    )
    for group, sections in group_requirements:
        if mode in group:
            required.extend(sections)

    check_config_attributes(cfg, specs={"legal": legal, "all": required}, section="top_level")


def parse_general_section(cfg: DictConfig,
                          mode: str = None,
                          mode_groups: str = None,
                          legal: List = None,
                          required: List = None,
                          output_dir: str = '') -> None:
    '''
    Parses the general section of the configuration file: checks attribute usage and
    fills in defaults for unset optional attributes.
    args:
        cfg (DictConfig): configuration dictionary
        mode (str): operation mode (currently unused)
        mode_groups (str): operation mode group (currently unused)
        legal (List): UC specific usable attributes
        required (List): UC specific required attributes
        output_dir (str): output directory for the current run (currently unused)
    '''
    check_config_attributes(cfg, specs={"legal": legal, "all": required}, section="general")

    # Set default values of missing optional attributes
    if not cfg.project_name:
        cfg.project_name = "<unnamed>"
    if not cfg.logs_dir:
        cfg.logs_dir = "logs"
    if not cfg.saved_models_dir:
        cfg.saved_models_dir = "saved_models"
    cfg.deterministic_ops = cfg.deterministic_ops if cfg.deterministic_ops is not None else False
    cfg.display_figures = cfg.display_figures if cfg.display_figures is not None else True
    # Only replace a seed that is really unset: 0 is a valid seed and must be kept
    # (a plain truthiness test silently turned it into 123).
    if cfg.global_seed is None or cfg.global_seed in ("", "None"):
        cfg.global_seed = 123

    if not cfg.num_threads_tflite:
        cfg.num_threads_tflite = 1
        


def parse_random_periodic_resizing(cfg, output_stride):
    '''
    Validates the 'random_periodic_resizing' section and returns its image sizes.
    args:
        cfg: section contents; must contain `period` and `image_sizes`. Each item of
            `image_sizes` is a 2-element size given as a list or a tuple.
        output_stride (int or sequence of int): network stride(s). Both components of
            every image size must be multiples of each stride.
    returns:
        list: the image sizes as a list of 2-element lists of ints.
    raises:
        ValueError: if an attribute is missing, `image_sizes` is not a list of 2-element
            sizes, or a size is not a multiple of a stride.
    '''
    message = "Please check the 'random_periodic_resizing' section of your configuration file."
    if "period" not in cfg:
        raise ValueError(f"\nMissing `period` argument\n{message}")
    if "image_sizes" not in cfg:
        raise ValueError(f"\nMissing `image_sizes` argument\n{message}")

    # Image sizes can be given using tuples, arrays or a mix of tuples and arrays.
    # All sizes are rewritten as tuples in a Python list literal.
    items = []
    for size in cfg.image_sizes:
        if isinstance(size, (list, tuple)):
            items.append('(' + ','.join(str(x) for x in size) + ')')
        else:
            items.append(str(size))
    sizes_str = '[' + ','.join(items) + ']'

    sizes_message = "\nInvalid syntax for `image_sizes` argument\n"
    try:
        # NOTE: eval() on config-derived text; only acceptable for trusted config files.
        random_sizes = np.array(eval(sizes_str), dtype=np.int32)
    except Exception as err:
        raise ValueError(sizes_message + message) from err

    # Must be a 2-D (N, 2) array; a flat list of numbers is rejected here instead of
    # failing with an IndexError on shape[1].
    if random_sizes.ndim != 2 or random_sizes.shape[1] != 2:
        raise ValueError(sizes_message + message)

    # Check that the image sizes are compatible with the network stride(s).
    strides = [output_stride] if np.shape(output_stride) == () else output_stride
    for size in random_sizes:
        for stride in strides:
            if (size[0] % stride != 0) or (size[1] % stride != 0):
                raise ValueError(
                    f"Image sizes must be multiples of the network stride.\n"
                    f"Network stride: {stride}\n"
                    f"Invalid image size: {size}\n"
                    f"{message}")

    return random_sizes.tolist()
    

def parse_compression_section(cfg: DictConfig,
                               legal: List) -> None:
    '''
    Validates the 'compression' section and fills in its optional attributes.
    Every legal attribute except `factor` and `strong_optimization` is mandatory.
    Defaults: `factor` = 0.5 and `strong_optimization` = False when unset or falsy.
    args:
        cfg (DictConfig): 'compression' section of the configuration file
        legal (List): UC specific usable attributes
    raises:
        ValueError: if an attribute is unknown or missing, `factor` is not a float, or
            `strong_optimization` is not a boolean.
    '''
    optional = ("factor", "strong_optimization")
    mandatory = [name for name in legal if name not in optional]
    check_config_attributes(cfg, specs={"legal": legal, "all": mandatory}, section="compression")

    cfg.factor = cfg.factor or 0.5
    cfg.strong_optimization = cfg.strong_optimization or False

    if not isinstance(cfg.factor, float):
        raise ValueError(f"\nUnknown or unsupported factor value. Received `{cfg.factor}`\n"
                         "Supported factor parameters: 'float'\n"
                         "Please check the 'compression.factor' attribute in your configuration file.")

    if cfg.strong_optimization not in (True, False):
        raise ValueError(f"\nUnknown or unsupported strong_optimization value. Received `{cfg.strong_optimization}`\n"
                         "Supported optimize parameters: 'True', or 'False'\n"
                         "Please check the 'compression.strong_optimization' attribute in your configuration file.")


def parse_training_section(cfg: DictConfig,
                           legal: List = None) -> None:
    '''
    Validates the 'training' section of the configuration file.
    `batch_size`, `epochs` and `optimizer` are mandatory. An optimizer written on one
    line (e.g. "optimizer: Adam") is normalized to a DefaultMunch of the form {name: None}.
    args:
        cfg (DictConfig): 'training' section of the configuration file
        legal (List): usable attributes
    '''
    mandatory = ["batch_size", "epochs", "optimizer"]
    check_config_attributes(cfg, specs={"legal": legal, "all": mandatory}, section="training")

    optimizer = cfg.optimizer
    if type(optimizer) is str:
        cfg.optimizer = DefaultMunch.fromDict({optimizer: None})

def parse_model_section(cfg: DictConfig, mode: str, mode_groups, legal: list, required: list = None) -> None:
    """
    Validates and normalises the root-level 'model' section of the config file.

    At least one model source ('model_name' or 'model_path') must be set.
    A 'model_path' starting with "http" is downloaded into a sub-directory of
    the Hydra output directory and replaced by the local file path. The
    'framework' attribute is then derived from the model file extension or,
    when no path is given, from a '_pt' suffix on 'model_name'. For
    TensorFlow models the two sources are mutually exclusive.

    Args:
        cfg (DictConfig): The model configuration parameters (updated in place).
        mode (str): The current operation mode.
        mode_groups: The mode groups object.
        legal (list): List of legal model attributes for this use case.
        required (list, optional): List of required model attributes. If None, no required fields are enforced.

    Returns:
        None

    Raises:
        ValueError: if no model source is set, or if several are set for a
            TensorFlow model.
    """
    required_attrs = list(required) if required is not None else []
    if cfg.model_name:
        # A model referenced by name additionally needs its input shape.
        required_attrs.append("input_shape")

    # Mutually exclusive model sources
    model_sources = ["model_name", "model_path"]
    set_sources = [source for source in model_sources if getattr(cfg, source, None)]
    if not set_sources:
        raise ValueError(
            "\nExpecting one of the following model source attributes to be set: "
            f"{', '.join(model_sources)}\nPlease check your configuration file.")

    check_config_attributes(cfg, specs={"legal": legal, "all": required_attrs}, section="model")

    # Remote model: fetch it locally and point model_path at the downloaded copy.
    if cfg.model_path and cfg.model_path[:4].lower() == "http":
        print('[INFO] : A URL found for model.model_path variable!')
        remote_url = cfg.model_path
        run_output_dir = HydraConfig.get().runtime.output_dir
        download_dir = os.path.join(run_output_dir,
                                    os.path.splitext(os.path.basename(remote_url))[0])
        os.makedirs(download_dir, exist_ok=True)
        downloaded_path = os.path.join(download_dir, remote_url.split('/')[-1])
        download_file(remote_url, downloaded_path)
        cfg.model_path = downloaded_path

    if not cfg.model_path:
        # No model file: a '_pt' suffix on the model name selects PyTorch.
        cfg.framework = "torch" if cfg.model_name[-3:] == '_pt' else "tf"
    else:
        suffix = Path(cfg.model_path).suffix.lower()
        # ONNX files default to TensorFlow only when no framework was chosen.
        is_tf_file = suffix in (".h5", ".keras", ".tflite") or (suffix == ".onnx" and not cfg.framework)
        if is_tf_file:
            cfg.framework = "tf"
            check_model_file_extension(cfg.model_path, mode, mode_groups, "model_path")
        elif suffix in (".pt", ".pth"):
            cfg.framework = "torch"
        # Any other extension leaves 'framework' as it was.

    if cfg.framework == "tf" and len(set_sources) > 1:
        raise ValueError(
            "\nThe following model source attributes are mutually exclusive and more than one is set: "
            f"{', '.join(set_sources)}\nPlease check your configuration file.")

        
def parse_prediction_section(cfg: DictConfig) -> None:
    '''
    Validates the 'prediction' section of the configuration file and fills in
    a default value for every optional attribute that is missing or falsy.
    No attribute of this section is mandatory.
    args:
        cfg (DictConfig): 'prediction' section of the configuration file
    '''
    allowed_attrs = ["seed", "target", "reid_distance_metric",
                     "profile", "input_type", "output_type", "input_chpos", "output_chpos"]
    check_config_attributes(cfg, specs={"legal": allowed_attrs, "all": []}, section="prediction")

    # Fallback used for each optional attribute left unset (applied in this order).
    fallbacks = {
        "profile": "profile_O3",
        "input_type": "uint8",
        "output_type": "int8",
        "input_chpos": "chlast",
        "output_chpos": "chlast",
        "target": "host",
    }
    for attr_name, fallback in fallbacks.items():
        if not getattr(cfg, attr_name):
            setattr(cfg, attr_name, fallback)

def parse_deployment_section(cfg: DictConfig,
                             legal: List = None,
                             legal_hw: List = None) -> None:
    '''
    Parses the 'deployment' section of the configuration file.

    Every attribute listed in `legal` is required in the section itself.
    In the 'hardware_setup' sub-section, every attribute of `legal_hw` is
    required except 'stlink_serial_number', which is optional.

    args:
        cfg (DictConfig): 'deployment' section of the configuration file
        legal (List): usable attributes (all of them are required)
        legal_hw (List): usable attributes for the HW setup part
    '''
    # NOTE(review): despite the None defaults, both `legal` and `legal_hw`
    # must be provided; iterating None raises TypeError.
    required = list(legal)
    check_config_attributes(cfg, specs={"legal": legal, "all": required}, section="deployment")

    required_hw = [x for x in legal_hw if x != 'stlink_serial_number']
    check_config_attributes(cfg.hardware_setup, specs={"legal": legal_hw, "all": required_hw},
                            section="deployment.hardware_setup")


def parse_mlflow_section(cfg: DictConfig) -> None:
    '''
    Checks the 'mlflow' section of the configuration. 'uri' is the only
    accepted attribute and is also listed as required.
    args:
        cfg (DictConfig): 'mlflow' section of the configuration
    '''
    uri_only = ["uri"]
    check_config_attributes(cfg,
                            specs={"legal": uri_only, "all": uri_only},
                            section="mlflow")


def check_hardware_type(cfg: DictConfig,
                        mode_groups: DictConfig = None) -> None:
    '''
    Sets `cfg.hardware_type` to "MCU" or "MPU" from the configured target.

    For benchmarking modes the target is `cfg.benchmarking.board`; for
    deployment modes it is `cfg.deployment.hardware_setup.serie`. The
    hardware type is "MPU" when that target contains "STM32MP", and "MCU"
    otherwise (including for modes in neither group).

    args:
        cfg (DictConfig): dictionary containing the configuration file section to check
        mode_groups (DictConfig): configuration about a given mode; must expose
            'benchmarking' and 'deployment' collections of operation modes.
            Although it defaults to None, it is dereferenced unconditionally.
    '''
    # By default MCU is selected
    cfg["hardware_type"] = "MCU"

    # Pick the attribute that names the target for the current mode.
    target = None
    if cfg.operation_mode in mode_groups.benchmarking:
        target = cfg.benchmarking.board
    elif cfg.operation_mode in mode_groups.deployment:
        target = cfg.deployment.hardware_setup.serie

    if target is not None and "STM32MP" in target:
        cfg.hardware_type = "MPU"


def get_class_names_from_file(cfg: DictConfig) -> List[str]:
    '''
    Reads class names from the text file at `cfg.classes_file_path`.

    Each line of the file gives one class name, with surrounding whitespace
    stripped; the file order is preserved.
    args:
        cfg (DictConfig): configuration section holding 'classes_file_path'
    returns:
        List[str]: the class names, or an empty list when no
        'classes_file_path' is set.
    '''
    # Without a file path there is nothing to read; avoid returning an
    # unbound name.
    if not cfg.classes_file_path:
        return []
    with open(cfg.classes_file_path, 'r', encoding='utf-8') as file:
        return [line.strip() for line in file]

from omegaconf import DictConfig, OmegaConf

def flatten_config(cfg, preserve_keys=("class_map",)):
    """
    Flattens a nested configuration into a single-level dict.

    Nested dicts are walked depth-first and their leaf key/value pairs are
    hoisted to the top level. When the same key appears more than once, the
    last one visited wins. Keys listed in `preserve_keys` keep their dict
    value intact instead of being flattened.

    Args:
        cfg: a DictConfig, a DefaultMunch, or anything accepted by `dict()`.
        preserve_keys: keys whose dict values are copied as-is.

    Returns:
        dict: the flattened configuration.
    """
    # Normalise the input to a plain dict first.
    if isinstance(cfg, DictConfig):
        # Interpolations are left unresolved (resolve=False).
        root = OmegaConf.to_container(cfg, resolve=False)
    elif "DefaultMunch" in str(type(cfg)):
        root = cfg.toDict()
    else:
        root = dict(cfg)

    kept_whole = set(preserve_keys)
    flat_config = {}

    # Explicit stack of item iterators: descending into a nested dict pushes
    # its iterator, so entries are visited in the same depth-first order as
    # a recursive walk.
    pending = [iter(root.items())]
    while pending:
        try:
            key, value = next(pending[-1])
        except StopIteration:
            pending.pop()
            continue
        if isinstance(value, dict) and key not in kept_whole:
            pending.append(iter(value.items()))
        else:
            flat_config[key] = value

    return flat_config