import os
import random
import time
import warnings
from multiprocessing.pool import ThreadPool

import numpy as np

from keras.src import tree
from keras.src.api_export import keras_export
from keras.src.utils import io_utils
from keras.src.utils.module_utils import tensorflow as tf


@keras_export("keras.utils.split_dataset")
def split_dataset(
    dataset, left_size=None, right_size=None, shuffle=False, seed=None
):
    """Splits a dataset into a left half and a right half (e.g. train / test).

    Args:
        dataset:
            A `tf.data.Dataset`, a `torch.utils.data.Dataset` object,
            or a list/tuple of arrays with the same length.
        left_size: If float (in the range `[0, 1]`), it signifies
            the fraction of the data to pack in the left dataset. If integer,
            it signifies the number of samples to pack in the left dataset.
            If `None`, defaults to the complement to `right_size`.
            Defaults to `None`.
        right_size: If float (in the range `[0, 1]`), it signifies
            the fraction of the data to pack in the right dataset.
            If integer, it signifies the number of samples to pack
            in the right dataset.
            If `None`, defaults to the complement to `left_size`.
            Defaults to `None`.
        shuffle: Boolean, whether to shuffle the data before splitting it.
        seed: A random seed for shuffling.

    Returns:
        A tuple of two `tf.data.Dataset` objects:
        the left and right splits.

    Example:

    >>> data = np.random.random(size=(1000, 4))
    >>> left_ds, right_ds = keras.utils.split_dataset(data, left_size=0.8)
    >>> int(left_ds.cardinality())
    800
    >>> int(right_ds.cardinality())
    200
    """
    dataset_type_spec = _get_type_spec(dataset)

    if dataset_type_spec is None:
        raise TypeError(
            "The `dataset` argument must be either "
            "a `tf.data.Dataset`, a `torch.utils.data.Dataset` "
            "object, or a list/tuple of arrays. "
            f"Received: dataset={dataset} of type {type(dataset)}"
        )

    if right_size is None and left_size is None:
        raise ValueError(
            "At least one of the `left_size` or `right_size` "
            "must be specified. Received: left_size=None and "
            "right_size=None"
        )

    dataset_as_list = _convert_dataset_to_list(dataset, dataset_type_spec)

    if shuffle:
        if seed is None:
            seed = random.randint(0, int(1e6))
        random.seed(seed)
        random.shuffle(dataset_as_list)

    total_length = len(dataset_as_list)

    left_size, right_size = _rescale_dataset_split_sizes(
        left_size, right_size, total_length
    )
    left_split = list(dataset_as_list[:left_size])
    right_split = list(dataset_as_list[-right_size:])

    left_split = _restore_dataset_from_list(
        left_split, dataset_type_spec, dataset
    )
    right_split = _restore_dataset_from_list(
        right_split, dataset_type_spec, dataset
    )

    left_split = tf.data.Dataset.from_tensor_slices(left_split)
    right_split = tf.data.Dataset.from_tensor_slices(right_split)

    # Re-apply batching to the splits if the original dataset was batched.
    if dataset_type_spec is tf.data.Dataset and is_batched(dataset):
        batch_size = get_batch_size(dataset)
        if batch_size is not None:
            left_split = left_split.batch(batch_size)
            right_split = right_split.batch(batch_size)

    left_split = left_split.prefetch(tf.data.AUTOTUNE)
    right_split = right_split.prefetch(tf.data.AUTOTUNE)
    return left_split, right_split


def _convert_dataset_to_list(
    dataset,
    dataset_type_spec,
    data_size_warning_flag=True,
    ensure_shape_similarity=True,
):
    """Convert `dataset` object to a list of samples.

    Args:
        dataset: A `tf.data.Dataset`, a `torch.utils.data.Dataset` object,
            or a list/tuple of arrays.
        dataset_type_spec: The type of the dataset.
        data_size_warning_flag: If set to `True`, a warning will
            be issued if the dataset takes longer than 10 seconds to iterate.
            Defaults to `True`.
        ensure_shape_similarity: If set to `True`, the shape of
            the first sample will be used to validate the shape of the rest
            of the samples. Defaults to `True`.

    Returns:
        List: A list of samples.
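
    Example (a minimal sketch, assuming two parallel NumPy arrays of
    length 3; each yielded sample is a tuple of per-array rows):

    >>> data = [np.zeros((3, 2)), np.ones((3,))]
    >>> samples = _convert_dataset_to_list(data, list)
    >>> len(samples)
    3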
    """
    dataset_iterator = _get_data_iterator_from_dataset(
        dataset, dataset_type_spec
    )
    dataset_as_list = []

    start_time = time.time()
    for sample in _get_next_sample(
        dataset_iterator,
        ensure_shape_similarity,
        data_size_warning_flag,
        start_time,
    ):
        dataset_as_list.append(sample)

    return dataset_as_list


def _get_data_iterator_from_dataset(dataset, dataset_type_spec):
    """Get the iterator from a dataset.

    Args:
        dataset: A `tf.data.Dataset`, a `torch.utils.data.Dataset` object,
            or a list/tuple of arrays.
        dataset_type_spec: The type of the dataset.

    Returns:
        iterator: An `iterator` object.
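
    Example (a minimal sketch, assuming two equal-length arrays; the
    iterator yields one tuple of rows per index):

    >>> arrays = [np.arange(4), np.arange(4)]
    >>> it = _get_data_iterator_from_dataset(arrays, list)
    >>> len(next(it))
    2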
    """
    if dataset_type_spec is list:
        if len(dataset) == 0:
            raise ValueError(
                "Received an empty list dataset. "
                "Please provide a non-empty list of arrays."
            )

        expected_shape = None
        for i, element in enumerate(dataset):
            if not isinstance(element, np.ndarray):
                raise ValueError(
                    "Expected a list of `numpy.ndarray` objects. "
                    f"Received: {type(element)} at index {i}."
                )
            if expected_shape is None:
                expected_shape = element.shape
            elif element.shape[0] != expected_shape[0]:
                raise ValueError(
                    "Received a list of NumPy arrays with different "
                    f"lengths. Mismatch found at index {i}: "
                    f"expected shape={expected_shape}, "
                    f"received shape={element.shape}. "
                    "Please provide a list of NumPy arrays of the same "
                    "length."
                )

        return iter(zip(*dataset))
    elif dataset_type_spec is tuple:
        if len(dataset) == 0:
            raise ValueError(
                "Received an empty tuple dataset. "
                "Please provide a non-empty tuple of arrays."
            )

        expected_shape = None
        for i, element in enumerate(dataset):
            if not isinstance(element, np.ndarray):
                raise ValueError(
                    "Expected a tuple of `numpy.ndarray` objects. "
                    f"Received: {type(element)} at index {i}."
                )
            if expected_shape is None:
                expected_shape = element.shape
            elif element.shape[0] != expected_shape[0]:
                raise ValueError(
                    "Received a tuple of NumPy arrays with different "
                    f"lengths. Mismatch found at index {i}: "
                    f"expected shape={expected_shape}, "
                    f"received shape={element.shape}. "
                    "Please provide a tuple of NumPy arrays of the same "
                    "length."
                )

        return iter(zip(*dataset))
    elif dataset_type_spec is tf.data.Dataset:
        if is_batched(dataset):
            dataset = dataset.unbatch()
        return iter(dataset)
    elif is_torch_dataset(dataset):
        return iter(dataset)
    elif dataset_type_spec is np.ndarray:
        return iter(dataset)
    raise ValueError(f"Invalid dataset_type_spec: {dataset_type_spec}")


def _get_next_sample(
    dataset_iterator,
    ensure_shape_similarity,
    data_size_warning_flag,
    start_time,
):
    """Yield data samples from the `dataset_iterator`.

    Args:
        dataset_iterator: An `iterator` object.
        ensure_shape_similarity: If set to `True`, the shape of
            the first sample will be used to validate the shape of the rest
            of the samples. Defaults to `True`.
        data_size_warning_flag: If set to `True`, a warning will
            be issued if the dataset takes longer than 10 seconds to iterate.
            Defaults to `True`.
        start_time: Float, the start time of the dataset iteration. Used
            only if `data_size_warning_flag` is set to `True`.

    Yields:
        data_sample: The next sample.
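
    Example (a minimal sketch, assuming two identically shaped samples):

    >>> it = iter([np.zeros((2,)), np.zeros((2,))])
    >>> samples = list(_get_next_sample(it, True, False, time.time()))
    >>> len(samples)
    2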
    """
    from keras.src.trainers.data_adapters.data_adapter_utils import (
        is_torch_tensor,
    )

    try:
        dataset_iterator = iter(dataset_iterator)
        first_sample = next(dataset_iterator)
        if isinstance(
            first_sample, (tf.Tensor, np.ndarray)
        ) or is_torch_tensor(first_sample):
            first_sample_shape = np.array(first_sample).shape
        else:
            first_sample_shape = None
            ensure_shape_similarity = False
        yield first_sample
    except StopIteration:
        raise ValueError(
            "Received an empty dataset. Argument `dataset` must "
            "be a non-empty list/tuple of `numpy.ndarray` objects "
            "or `tf.data.Dataset` objects."
        )

    for i, sample in enumerate(dataset_iterator):
        if ensure_shape_similarity:
            if first_sample_shape != np.array(sample).shape:
                raise ValueError(
                    "All `dataset` samples must have the same shape. "
                    f"Expected shape: {first_sample_shape}, "
                    f"received shape: {np.array(sample).shape} at index "
                    f"{i}."
                )
        if data_size_warning_flag:
            if i % 10 == 0:
                cur_time = time.time()

                # Warn once if iteration exceeds 10 seconds.
                if int(cur_time - start_time) > 10:
                    warnings.warn(
                        "The dataset is taking longer than 10 seconds to "
                        "iterate over. This may be due to the size of the "
                        "dataset. Keep in mind that the `split_dataset` "
                        "utility is only for small in-memory datasets "
                        "(e.g. < 10,000 samples).",
                        category=ResourceWarning,
                        source="split_dataset",
                    )
                    data_size_warning_flag = False
        yield sample


def is_torch_dataset(dataset):
    """Check whether `dataset` is a `torch.utils.data.Dataset`.

    The check walks the class MRO by name, so `torch` does not need to be
    imported.
    """
    if hasattr(dataset, "__class__"):
        for parent in dataset.__class__.__mro__:
            if parent.__name__ == "Dataset" and str(
                parent.__module__
            ).startswith("torch.utils.data"):
                return True
    return False


def _rescale_dataset_split_sizes(left_size, right_size, total_length):
    """Rescale the dataset split sizes.

    We want to ensure that the sum of the split sizes is equal to the
    total length of the dataset.

    Args:
        left_size: The size of the left dataset split.
        right_size: The size of the right dataset split.
        total_length: The total length of the dataset.

    Returns:
        tuple: A tuple of rescaled `left_size` and `right_size` integers.
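
    Example (a minimal sketch, a float fraction rescaled to sample counts):

    >>> _rescale_dataset_split_sizes(0.8, None, 1000)
    (800, 200)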
    """
    left_size_type = type(left_size)
    right_size_type = type(right_size)

    # Both sizes have invalid types.
    if (left_size is not None and left_size_type not in [int, float]) and (
        right_size is not None and right_size_type not in [int, float]
    ):
        raise TypeError(
            "Invalid `left_size` and `right_size` Types. Expected: "
            "integer or float or None, Received: type(left_size)="
            f"{left_size_type} and type(right_size)={right_size_type}"
        )

    # Invalid `left_size` type.
    if left_size is not None and left_size_type not in [int, float]:
        raise TypeError(
            "Invalid `left_size` Type. Expected: int or float or None, "
            f"Received: type(left_size)={left_size_type}."
        )

    # Invalid `right_size` type.
    if right_size is not None and right_size_type not in [int, float]:
        raise TypeError(
            "Invalid `right_size` Type. Expected: int or float or None, "
            f"Received: type(right_size)={right_size_type}."
        )

    if left_size == 0 and right_size == 0:
        raise ValueError(
            "Both `left_size` and `right_size` are zero. "
            "At least one of the split sizes must be non-zero."
        )

    # `left_size` out of range.
    if (
        left_size_type is int
        and (left_size <= 0 or left_size >= total_length)
        or left_size_type is float
        and (left_size <= 0 or left_size >= 1)
    ):
        raise ValueError(
            "`left_size` should be either a positive integer "
            f"smaller than {total_length}, or a float "
            "within the range `[0, 1]`. Received: left_size="
            f"{left_size}"
        )

    # `right_size` out of range.
    if (
        right_size_type is int
        and (right_size <= 0 or right_size >= total_length)
        or right_size_type is float
        and (right_size <= 0 or right_size >= 1)
    ):
        raise ValueError(
            "`right_size` should be either a positive integer "
            f"smaller than {total_length}, or a float "
            "within the range `[0, 1]`. Received: right_size="
            f"{right_size}"
        )

    # Two float fractions must sum to at most 1.
    if (
        right_size_type is left_size_type is float
        and right_size + left_size > 1
    ):
        raise ValueError(
            "The sum of `left_size` and `right_size` is greater "
            "than 1. It must be less than or equal to 1."
        )

    # Convert float fractions to sample counts. Integer sizes pass through
    # (they are cast to float here and back to int below).
    if left_size_type is float:
        left_size = round(left_size * total_length)
    elif left_size_type is int:
        left_size = float(left_size)

    if right_size_type is float:
        right_size = round(right_size * total_length)
    elif right_size_type is int:
        right_size = float(right_size)

    # Fill in the missing side as the complement of the other.
    if left_size is None:
        left_size = total_length - right_size
    elif right_size is None:
        right_size = total_length - left_size

    if left_size + right_size > total_length:
        raise ValueError(
            "The sum of `left_size` and `right_size` should "
            f"be smaller than the total length ({total_length}). "
            f"Received: left_size + right_size = {left_size + right_size} "
            f"and total_length = {total_length}"
        )

    for split, side in [(left_size, "left"), (right_size, "right")]:
        if split == 0:
            raise ValueError(
                f"With `dataset` of length={total_length}, `left_size`="
                f"{left_size} and `right_size`={right_size}, the resulting "
                f"{side} side dataset split will be empty. "
                "Adjust any of the aforementioned parameters."
            )

    left_size, right_size = int(left_size), int(right_size)
    return left_size, right_size


def _restore_dataset_from_list(
    dataset_as_list, dataset_type_spec, original_dataset
):
    """Restore the dataset from the list of arrays."""
    if dataset_type_spec in [tuple, list, tf.data.Dataset] or is_torch_dataset(
        original_dataset
    ):
        # Save the structure by taking the first element.
        element_spec = dataset_as_list[0]
        # Flatten each sample into a list of leaves.
        dataset_as_list = [tree.flatten(sample) for sample in dataset_as_list]
        # Stack the respective leaves across all samples.
        dataset_as_list = [
            np.array(sample) for sample in zip(*dataset_as_list)
        ]
        # Recreate the original nested structure.
        dataset_as_list = tree.pack_sequence_as(element_spec, dataset_as_list)
        # Turn lists back into tuples, since tf.data does not accept lists.
        return tree.traverse(
            lambda x: tuple(x) if isinstance(x, list) else x,
            dataset_as_list,
            top_down=False,
        )

    return dataset_as_list


def is_batched(dataset):
    """Check if the `tf.data.Dataset` is batched.
    """
    # Note: this relies on the private `_batch_size` attribute set by
    # `Dataset.batch()`; datasets batched by other means may not expose it.
    return hasattr(dataset, "_batch_size")


def get_batch_size(dataset):
    """Get the batch size of the dataset."""
    if is_batched(dataset):
        return dataset._batch_size
    else:
        return None


def _get_type_spec(dataset):
    """Get the type spec of the dataset.
    """
    if isinstance(dataset, tuple):
        return tuple
    elif isinstance(dataset, list):
        return list
    elif isinstance(dataset, np.ndarray):
        return np.ndarray
    elif isinstance(dataset, tf.data.Dataset):
        return tf.data.Dataset
    elif is_torch_dataset(dataset):
        from torch.utils.data import Dataset as TorchDataset

        return TorchDataset
    else:
        return None


def index_directory(
    directory,
    labels,
    formats,
    class_names=None,
    shuffle=True,
    seed=None,
    follow_links=False,
    verbose=True,
):
    """List all files in `directory`, with their labels.

    Args:
        directory: Directory where the data is located.
            If `labels` is `"inferred"`, it should contain
            subdirectories, each containing files for a class.
            Otherwise, the directory structure is ignored.
        labels: Either `"inferred"`
            (labels are generated from the directory structure),
            `None` (no labels),
            or a list/tuple of integer labels of the same size as the number
            of valid files found in the directory.
            Labels should be sorted according
            to the alphanumeric order of the image file paths
            (obtained via `os.walk(directory)` in Python).
        formats: Allowlist of file extensions to index
            (e.g. `".jpg"`, `".txt"`).
        class_names: Only valid if `labels="inferred"`. This is the explicit
            list of class names (must match names of subdirectories). Used
            to control the order of the classes
            (otherwise alphanumerical order is used).
        shuffle: Whether to shuffle the data. Defaults to `True`.
            If set to `False`, sorts the data in alphanumeric order.
        seed: Optional random seed for shuffling.
        follow_links: Whether to visit subdirectories pointed to by symlinks.
        verbose: Whether the function prints the number of files found and
            classes. Defaults to `True`.

    Returns:
        tuple (file_paths, labels, class_names).
        - file_paths: list of file paths (strings).
        - labels: list of matching integer labels (same length as
          file_paths).
        - class_names: names of the classes corresponding to these labels, in
          order.
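
    Example:

    A hypothetical layout with two class subdirectories
    (`flowers/daisy/*.jpg` and `flowers/rose/*.jpg`):

    ```python
    file_paths, labels, class_names = index_directory(
        "flowers", labels="inferred", formats=(".jpg",)
    )
    # class_names -> ["daisy", "rose"]; each labels[i] indexes into
    # class_names for the file at file_paths[i].
    ```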
    """
    if labels == "inferred":
        subdirs = []
        for subdir in sorted(tf.io.gfile.listdir(directory)):
            if tf.io.gfile.isdir(tf.io.gfile.join(directory, subdir)):
                if not subdir.startswith("."):
                    if subdir.endswith("/"):
                        subdir = subdir[:-1]
                    subdirs.append(subdir)
        if class_names is not None:
            if not set(class_names).issubset(set(subdirs)):
                raise ValueError(
                    "The `class_names` passed did not match the "
                    "names of the subdirectories of the target directory. "
                    f"Expected: {subdirs} (or a subset of it), "
                    f"but received: class_names={class_names}"
                )
            subdirs = class_names
    else:
        # In the explicit-labels and no-labels cases, index the files
        # directly under `directory`.
        subdirs = [""]
        if class_names is not None:
            if labels is None:
                raise ValueError(
                    "When `labels=None` (no labels), argument `class_names` "
                    "cannot be specified."
                )
            else:
                raise ValueError(
                    "When argument `labels` is specified, argument "
                    "`class_names` cannot be specified (the `class_names` "
                    "will be the sorted list of labels)."
                )

    class_names = subdirs
    class_indices = dict(zip(class_names, range(len(class_names))))

    # Build an index of the files in the different class subfolders.
    pool = ThreadPool()
    results = []
    filenames = []

    for dirpath in (tf.io.gfile.join(directory, subdir) for subdir in subdirs):
        results.append(
            pool.apply_async(
                index_subdirectory,
                (dirpath, class_indices, follow_links, formats),
            )
        )
    labels_list = []
    for res in results:
        partial_filenames, partial_labels = res.get()
        labels_list.append(partial_labels)
        filenames += partial_filenames

    if labels == "inferred":
        # Inferred labels: concatenate the per-subdirectory label lists.
        i = 0
        labels = np.zeros((len(filenames),), dtype="int32")
        for partial_labels in labels_list:
            labels[i : i + len(partial_labels)] = partial_labels
            i += len(partial_labels)
    elif labels is None:
        class_names = None
    else:
        # Manual labels: they must match the indexed files one-to-one.
        if len(labels) != len(filenames):
            raise ValueError(
                "Expected the lengths of `labels` to match the number "
                "of files in the target directory. len(labels) is "
                f"{len(labels)} while we found {len(filenames)} files "
                f"in directory {directory}."
            )
        class_names = [str(label) for label in sorted(set(labels))]
    if verbose:
        if labels is None:
            io_utils.print_msg(f"Found {len(filenames)} files.")
        else:
            io_utils.print_msg(
                f"Found {len(filenames)} files belonging "
                f"to {len(class_names)} classes."
            )
    pool.close()
    pool.join()
    file_paths = [tf.io.gfile.join(directory, fname) for fname in filenames]

    if shuffle:
        # Shuffle globally to erase macro-structure.
        if seed is None:
            seed = np.random.randint(1e6)
        rng = np.random.RandomState(seed)
        rng.shuffle(file_paths)
        if labels is not None:
            # Re-seed so `labels` is shuffled in the same order as
            # `file_paths`.
            rng = np.random.RandomState(seed)
            rng.shuffle(labels)
    return file_paths, labels, class_names


def iter_valid_files(directory, follow_links, formats):
    """Yield `(root, filename)` pairs for files under `directory`.

    Only filenames ending in one of `formats` are yielded, in sorted order.
    """
    if not follow_links:
        walk = tf.io.gfile.walk(directory)
    else:
        walk = os.walk(directory, followlinks=follow_links)
    for root, _, files in sorted(walk, key=lambda x: x[0]):
        for fname in sorted(files):
            if fname.lower().endswith(formats):
                yield root, fname


def index_subdirectory(directory, class_indices, follow_links, formats):
    """Recursively walk `directory`, listing file paths and class indices.

    Args:
        directory: String, target directory.
        class_indices: Dict mapping class names to their index.
        follow_links: Boolean, whether to recursively follow subdirectories
            (if `False`, we only list top-level files in `directory`).
        formats: Allowlist of file extensions to index (e.g. ".jpg", ".txt").

    Returns:
        tuple `(filenames, labels)`. `filenames` is a list of relative file
        paths, and `labels` is a list of integer labels corresponding
        to these files.
    """
    dirname = os.path.basename(directory)
    valid_files = iter_valid_files(directory, follow_links, formats)
    labels = []
    filenames = []
    for root, fname in valid_files:
        labels.append(class_indices[dirname])
        absolute_path = tf.io.gfile.join(root, fname)
        relative_path = tf.io.gfile.join(
            dirname, os.path.relpath(absolute_path, directory)
        )
        filenames.append(relative_path)
    return filenames, labels


def get_training_or_validation_split(
    samples, labels, validation_split, subset
):
    """Potentially restrict samples & labels to a training or validation split.

    Args:
        samples: List of elements.
        labels: List of corresponding labels.
        validation_split: Float, fraction of data to reserve for validation.
        subset: Subset of the data to return.
            Either `"training"`, `"validation"`, or `None`.
            If `None`, we return all of the data.

    Returns:
        tuple (samples, labels), potentially restricted to the specified
        subset.
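
    Example:

    A minimal sketch with a 25% validation split (also prints a status
    message via `io_utils.print_msg`):

    ```python
    samples, labels = get_training_or_validation_split(
        ["a", "b", "c", "d"], None, validation_split=0.25, subset="training"
    )
    # samples -> ["a", "b", "c"]; the last sample is held out for validation.
    ```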
    """
    if not validation_split:
        return samples, labels

    num_val_samples = int(validation_split * len(samples))
    if subset == "training":
        io_utils.print_msg(
            f"Using {len(samples) - num_val_samples} files for training."
        )
        samples = samples[:-num_val_samples]
        if labels is not None:
            labels = labels[:-num_val_samples]
    elif subset == "validation":
        io_utils.print_msg(f"Using {num_val_samples} files for validation.")
        samples = samples[-num_val_samples:]
        if labels is not None:
            labels = labels[-num_val_samples:]
    else:
        raise ValueError(
            '`subset` must be either "training" '
            f'or "validation", received: {subset}'
        )
    return samples, labels


def labels_to_dataset(labels, label_mode, num_classes):
    """Create a `tf.data.Dataset` from the list/tuple of labels.

    Args:
        labels: list/tuple of labels to be converted into a `tf.data.Dataset`.
        label_mode: String describing the encoding of `labels`. Options are:
            - `"binary"` indicates that the labels (there can be only 2) are
              encoded as `float32` scalars with values 0 or 1
              (e.g. for `binary_crossentropy`).
            - `"categorical"` means that the labels are mapped into a
              categorical vector (e.g. for `categorical_crossentropy` loss).
        num_classes: Number of classes of labels.

    Returns:
        A `tf.data.Dataset` instance.
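
    Example:

    A minimal sketch (assumed integer labels, one-hot encoded into 3
    classes):

    ```python
    label_ds = labels_to_dataset(
        [0, 2], label_mode="categorical", num_classes=3
    )
    # Elements are float32 one-hot vectors: [1., 0., 0.] and [0., 0., 1.].
    ```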
    """
    label_ds = tf.data.Dataset.from_tensor_slices(labels)
    if label_mode == "binary":
        label_ds = label_ds.map(
            lambda x: tf.expand_dims(tf.cast(x, "float32"), axis=-1),
            num_parallel_calls=tf.data.AUTOTUNE,
        )
    elif label_mode == "categorical":
        label_ds = label_ds.map(
            lambda x: tf.one_hot(x, num_classes),
            num_parallel_calls=tf.data.AUTOTUNE,
        )
    return label_ds


def check_validation_split_arg(validation_split, subset, shuffle, seed):
    """Raise errors in case of invalid argument values.

    Args:
        validation_split: Float between 0 and 1, fraction of data to reserve
            for validation.
        subset: One of `"training"`, `"validation"`, or `"both"`. Only used
            if `validation_split` is set.
        shuffle: Whether to shuffle the data. Either `True` or `False`.
        seed: Random seed for shuffling and transformations.
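
    Example:

    A minimal sketch of an invalid combination (raises a `ValueError`
    because `subset` is set without `validation_split`):

    ```python
    check_validation_split_arg(
        validation_split=None, subset="training", shuffle=True, seed=None
    )
    ```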
    """
    if validation_split and not 0 < validation_split < 1:
        raise ValueError(
            "`validation_split` must be between 0 and 1, "
            f"received: {validation_split}"
        )
    if (validation_split or subset) and not (validation_split and subset):
        raise ValueError(
            "If `subset` is set, `validation_split` must be set, and "
            "inversely."
        )
    if subset not in ("training", "validation", "both", None):
        raise ValueError(
            '`subset` must be either "training", '
            f'"validation" or "both", received: {subset}'
        )
    if validation_split and shuffle and seed is None:
        raise ValueError(
            "If using `validation_split` and shuffling the data, you must "
            "provide a `seed` argument, to make sure that there is no "
            "overlap between the training and validation subset."
        )