File size: 6,825 Bytes
# /*---------------------------------------------------------------------------------------------
# * Copyright (c) 2022-2023 STMicroelectronics.
# * All rights reserved.
# *
# * This software is licensed under terms that can be found in the LICENSE file in
# * the root directory of this software component.
# * If no LICENSE file comes with this software, it is provided AS-IS.
# *--------------------------------------------------------------------------------------------*/
import os
import pickle
import numpy as np
import tensorflow as tf
from pathlib import Path
from typing import Tuple, List
from .utils import get_train_val_ds, get_ds, get_prediction_ds
def load_custom_dataset(training_path: str = None,
                        validation_path: str = None,
                        quantization_path: str = None,
                        test_path: str = None,
                        prediction_path: str = None,
                        validation_split: float = None,
                        quantization_split: float = None,
                        class_names: list[str] = None,
                        image_size: tuple[int, int] = None,
                        interpolation: str = None,
                        aspect_ratio: str = None,
                        color_mode: str = None,
                        batch_size: int = None,
                        seed: int = None
                        ) -> "dict[str, tf.data.Dataset | None]":
    """
    Loads the images from the given dataset root directories and returns the
    training, validation, quantization, test and prediction datasets.

    The datasets have the following directory structure (checked in parse_config.py):
        dataset_root_dir:
            class_a:
                a_image_1.jpg
                a_image_2.jpg
            class_b:
                b_image_1.jpg
                b_image_2.jpg

    Every path is optional; a dataset whose source path is missing is returned
    as None. When called without any path, no loader is invoked at all.

    Args:
        training_path (str): Path to the directory containing the training images.
        validation_path (str): Path to the directory containing the validation images.
            If omitted while `training_path` is given, the validation set is split
            off the training set using `validation_split`.
        quantization_path (str): Path to the directory containing the quantization images.
            If omitted, the training directory is used instead (when a training
            set was loaded).
        test_path (str): Path to the directory containing the test images.
        prediction_path (str): Path to the directory containing the images to predict on.
        validation_split (float): Fraction of the data to use for validation.
        quantization_split (float): Fraction of the quantization data to keep.
            None or 0 means the whole set is used.
        class_names (list[str]): List of class names to use for the images.
        image_size (tuple[int, int]): Resizing (height, width) of input images.
        interpolation (str): Interpolation method to use when resizing the images.
        aspect_ratio (str): Aspect-ratio handling option, forwarded unchanged to the loaders.
        color_mode (str): Color mode to use for the images.
        batch_size (int): Batch size to use for the training, validation and test
            datasets. The quantization and prediction datasets always use a batch size of 1.
        seed (int): Seed to use for shuffling the data.

    Returns:
        dict[str, tf.data.Dataset | None]: Dictionary with the keys 'train', 'valid',
            'quantization', 'test' and 'predict'. Each value is the corresponding
            dataset, or None if it could not be built from the given arguments.
    """
    # Loader options that are identical for every dataset built below.
    common_args = dict(
        class_names=class_names,
        image_size=image_size,
        interpolation=interpolation,
        aspect_ratio=aspect_ratio,
        color_mode=color_mode,
        seed=seed,
    )

    # Get training and validation sets
    train_ds = None
    val_ds = None
    if training_path and not validation_path:
        # There is no validation set. We split the
        # training set in two to create one.
        train_ds, val_ds = get_train_val_ds(
            training_path,
            validation_split=validation_split,
            batch_size=batch_size,
            shuffle=True,
            **common_args)
    else:
        if training_path:
            train_ds = get_ds(
                training_path,
                batch_size=batch_size,
                shuffle=True,
                **common_args)
        if validation_path:
            val_ds = get_ds(
                validation_path,
                batch_size=batch_size,
                shuffle=False,
                **common_args)

    # Get quantization set: a dedicated directory takes precedence, otherwise
    # fall back to the training directory when a training set was loaded.
    if quantization_path:
        quantization_source = quantization_path
    elif train_ds is not None:
        quantization_source = training_path
    else:
        quantization_source = None

    quantization_ds = None
    if quantization_source is not None:
        # Unshuffled, one sample per batch, so the split below counts samples.
        quantization_ds = get_ds(
            quantization_source,
            batch_size=1,
            shuffle=False,
            **common_args)

    if quantization_ds is not None:
        quant_split = quantization_split if quantization_split else 1.0
        print(f'[INFO] : Quantizing by using {quant_split * 100} % of the provided dataset...')
        # NOTE(review): len() presumably relies on the loader returning a dataset
        # with a known, finite cardinality - confirm against get_ds.
        quantization_ds = quantization_ds.take(int(len(quantization_ds) * float(quant_split)))

    # Get test set
    test_ds = None
    if test_path:
        test_ds = get_ds(
            test_path,
            batch_size=batch_size,
            shuffle=False,
            **common_args)

    # Get prediction set
    predict_ds = None
    if prediction_path:
        predict_ds = get_prediction_ds(
            prediction_path,
            batch_size=1,
            shuffle=False,
            **common_args)

    return {'train': train_ds, 'valid': val_ds, 'quantization': quantization_ds, 'test': test_ds, 'predict': predict_ds}
|