File size: 6,294 Bytes
747451d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
# /*---------------------------------------------------------------------------------------------
#  * Copyright (c) 2022-2023 STMicroelectronics.
#  * All rights reserved.
#  *
#  * This software is licensed under terms that can be found in the LICENSE file in
#  * the root directory of this software component.
#  * If no LICENSE file comes with this software, it is provided AS-IS.
#  *--------------------------------------------------------------------------------------------*/

import os
import numpy as np
import tensorflow as tf
from pathlib import Path
from omegaconf import DictConfig
from .utils import load_subset_dataloaders
from munch import DefaultMunch


def _require_existing_path(dataset_cfg, key: str, label: str, mode: str) -> str:
    """
    Return ``dataset_cfg.<key>`` after checking that it is set and exists on disk.

    Args:
        dataset_cfg: The ``cfg.dataset`` node to read the path from.
        key (str): Attribute name of the path (e.g. ``"training_path"``).
        label (str): Capitalized name used in the "does not exist" message.
        mode (str): Operation mode name used in the "must be specified" message.

    Returns:
        str: The validated path.

    Raises:
        ValueError: If the attribute is missing, empty/None, or the path does not exist.
    """
    # getattr with a default covers both a missing key and a key set to None,
    # so os.path.exists() is never called with a non-path value.
    path = getattr(dataset_cfg, key, None)
    if not path:
        raise ValueError(f"cfg.dataset.{key} must be specified in {mode} mode")
    if not os.path.exists(path):
        raise ValueError(f"{label} path {path} does not exist")
    return path


def load_tfs_like(cfg: DictConfig, 
                  image_size: tuple[int],
                  val_batch_size: int) -> dict:
    """
    Handles datasets where images and labels are already in TFS format (JPEG + TFS files)
    and the dataset is already split into several subsets.

    The function only validates the configuration and the dataset paths required by the
    selected operation mode; the actual loading is delegated to `load_subset_dataloaders`.

    Args:
        cfg (DictConfig): Configuration object containing dataset parameters including:
            - dataset.format: The corresponding dataset format (tfs, darknet_yolo, coco, pascal_voc).
            - dataset.training_path: Path for processed training data (required in training modes)
            - dataset.test_path: Path for processed test data (required in evaluation modes)
            - dataset.validation_path: Path for processed validation data (optional, checked if set)
            - dataset.quantization_path: Path for quantization data (required if quantization is enabled)
            - dataset.prediction_path: Path for prediction data (required if prediction is enabled)
            - dataset.class_names: List of class names to use
            - dataset.download_data: Whether to download dataset (unsupported for TFS format)
            - settings.max_detections: Optional maximum number of detections per image
            - settings.exclude_unlabeled_images: Whether to exclude images without labels
            - quantization.operating_mode: When set to "full_auto", the quantization,
              chain_qd and chain_qb modes are also treated as evaluation modes.
            - operation_mode (str): One of the supported modes or chains (e.g., chain_eqeb, training, evaluation, etc.)
        image_size (tuple[int]): Forwarded unchanged to `load_subset_dataloaders`.
        val_batch_size (int): Forwarded unchanged to `load_subset_dataloaders`.

    Returns:
        dict[str, tf.data.Dataset]: Dictionary containing training, validation, test,
            quantization and prediction datasets as TensorFlow datasets
            (whatever `load_subset_dataloaders` returns).

    Raises:
        ValueError: If the operation mode, the class names or a path required by the
            operation mode is missing, or if a configured path does not exist.
        NotImplementedError: If `cfg.dataset.download_data` is enabled.
    """

    operation_mode = getattr(cfg, 'operation_mode', None)
    if not operation_mode:
        raise ValueError("cfg.operation_mode must be specified")

    # Check for unsupported download_data option
    if getattr(cfg.dataset, 'download_data', None):
        raise NotImplementedError("Downloading dataset is unsupported for TFS-ready format. "
                                  "Please prepare the dataset manually in the expected format.")

    mode_str = operation_mode.lower()

    # Operation modes (and chains) that need each kind of dataset. Only the groups
    # that drive path validation and loading are listed; benchmarking, deployment
    # and compression do not influence anything in this function.
    mode_groups = {
        "training": {"training", "chain_tqeb", "chain_tqe"},
        "evaluation": {"evaluation", "chain_tqeb", "chain_tqe", "chain_eqe", "chain_eqeb"},
        "quantization": {"quantization", "chain_tqeb", "chain_tqe", "chain_eqe",
                         "chain_qb", "chain_eqeb", "chain_qd"},
        "prediction": {"prediction"},
    }

    # Conditional addition based on cfg.quantization.operating_mode.
    # The nested getattr tolerates a configuration without a quantization section.
    quantization_cfg = getattr(cfg, "quantization", None)
    if getattr(quantization_cfg, "operating_mode", None) == "full_auto":
        mode_groups["evaluation"] |= {"quantization", "chain_qd", "chain_qb"}

    is_training = mode_str in mode_groups["training"]
    is_evaluation = mode_str in mode_groups["evaluation"]
    is_quantization = mode_str in mode_groups["quantization"]
    is_prediction = mode_str in mode_groups["prediction"]

    # Verify required class names
    if not hasattr(cfg.dataset, 'class_names'):
        raise ValueError("Class names must be specified in cfg.dataset.class_names")

    # Validate paths depending on operation mode
    if is_training:
        training_path = _require_existing_path(cfg.dataset, 'training_path', "Training", "training")
        # No-op for an existing directory; kept from the original flow.
        os.makedirs(training_path, exist_ok=True)
        print("Skipping dataset analysis as for TFS dataset format.\n")

    # The validation path is optional, but if it is configured it must exist.
    validation_path = getattr(cfg.dataset, 'validation_path', None)
    if validation_path:
        if not os.path.exists(validation_path):
            raise ValueError(f"Validation path {validation_path} does not exist")
        # No-op for an existing directory; kept from the original flow.
        os.makedirs(validation_path, exist_ok=True)

    if is_evaluation:
        _require_existing_path(cfg.dataset, 'test_path', "Test", "evaluation")

    if is_prediction:
        _require_existing_path(cfg.dataset, 'prediction_path', "Prediction", "prediction")

    if is_quantization:
        _require_existing_path(cfg.dataset, 'quantization_path', "Quantization", "quantization")

    print("Loading datasets in TFS format...")
    return load_subset_dataloaders(cfg, is_training, is_evaluation,
                                   is_prediction, is_quantization,
                                   image_size=image_size, val_batch_size=val_batch_size)