Spaces:
Sleeping
Sleeping
| # Copyright 2023 The TensorFlow Authors. All Rights Reserved. | |
| # | |
| # Licensed under the Apache License, Version 2.0 (the "License"); | |
| # you may not use this file except in compliance with the License. | |
| # You may obtain a copy of the License at | |
| # | |
| # http://www.apache.org/licenses/LICENSE-2.0 | |
| # | |
| # Unless required by applicable law or agreed to in writing, software | |
| # distributed under the License is distributed on an "AS IS" BASIS, | |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
| # See the License for the specific language governing permissions and | |
| # limitations under the License. | |
| """Helper functions for creating TFRecord datasets.""" | |
| import hashlib | |
| import io | |
| import itertools | |
| from absl import logging | |
| import numpy as np | |
| from PIL import Image | |
| import tensorflow as tf, tf_keras | |
| import multiprocessing as mp | |
| LOG_EVERY = 100 | |
def convert_to_feature(value, value_type=None):
  """Converts the given python object to a tf.train.Feature.

  Args:
    value: int, float, bytes or a list of them.
    value_type: optional, if specified, forces the feature to be of the given
      type. Otherwise, type is inferred automatically. Can be one of
      ['bytes', 'int64', 'float', 'bytes_list', 'int64_list', 'float_list']

  Returns:
    feature: A tf.train.Feature object.

  Raises:
    ValueError: If the type cannot be inferred (unsupported element type, or
      an empty list with no explicit `value_type`), or if `value_type` is not
      one of the supported type names.
  """
  if value_type is None:
    # Infer the type from the value (or its first element for lists).
    if isinstance(value, list):
      if not value:
        # An empty list carries no type information; the original code would
        # raise an opaque IndexError here.
        raise ValueError(
            'Cannot infer feature type from an empty list; please pass '
            'value_type explicitly.')
      element = value[0]
    else:
      element = value
    if isinstance(element, bytes):
      value_type = 'bytes'
    elif isinstance(element, (int, np.integer)):
      value_type = 'int64'
    elif isinstance(element, (float, np.floating)):
      value_type = 'float'
    else:
      raise ValueError('Cannot convert type {} to feature'.
                       format(type(element)))
    if isinstance(value, list):
      value_type = value_type + '_list'
  if value_type == 'int64':
    return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))
  elif value_type == 'int64_list':
    # Flatten and cast so nested / ndarray inputs are accepted.
    value = np.asarray(value).astype(np.int64).reshape(-1)
    return tf.train.Feature(int64_list=tf.train.Int64List(value=value))
  elif value_type == 'float':
    return tf.train.Feature(float_list=tf.train.FloatList(value=[value]))
  elif value_type == 'float_list':
    value = np.asarray(value).astype(np.float32).reshape(-1)
    return tf.train.Feature(float_list=tf.train.FloatList(value=value))
  elif value_type == 'bytes':
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))
  elif value_type == 'bytes_list':
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=value))
  else:
    raise ValueError('Unknown value_type parameter - {}'.format(value_type))
def image_info_to_feature_dict(height, width, filename, image_id,
                               encoded_str, encoded_format):
  """Builds the standard per-image feature dict for a detection TFRecord."""
  # SHA-256 of the encoded bytes serves as a stable unique key for the image.
  sha_key = hashlib.sha256(encoded_str).hexdigest()
  features = {}
  features['image/height'] = convert_to_feature(height)
  features['image/width'] = convert_to_feature(width)
  features['image/filename'] = convert_to_feature(filename.encode('utf8'))
  features['image/source_id'] = convert_to_feature(str(image_id).encode('utf8'))
  features['image/key/sha256'] = convert_to_feature(sha_key.encode('utf8'))
  features['image/encoded'] = convert_to_feature(encoded_str)
  features['image/format'] = convert_to_feature(encoded_format.encode('utf8'))
  return features
def read_image(image_path):
  """Reads an image file and returns its pixels as a numpy array.

  Args:
    image_path: Path to an image file readable by PIL.

  Returns:
    A numpy array with the decoded image contents.
  """
  # Image.open is lazy and keeps the underlying file handle open; the context
  # manager guarantees it is closed once np.asarray has forced the decode.
  with Image.open(image_path) as pil_image:
    return np.asarray(pil_image)
def encode_mask_as_png(mask):
  """Encodes a numpy mask array as PNG bytes."""
  buffer = io.BytesIO()
  Image.fromarray(mask).save(buffer, format='PNG')
  return buffer.getvalue()
def write_tf_record_dataset(output_path, annotation_iterator,
                            process_func, num_shards,
                            multiple_processes=None, unpack_arguments=True):
  """Iterates over annotations, processes them and writes into TFRecords.

  Args:
    output_path: The prefix path to create TF record files.
    annotation_iterator: An iterator of tuples containing details about the
      dataset.
    process_func: A function which takes the elements from the tuples of
      annotation_iterator as arguments and returns a tuple of
      (tf.train.Example, int). The integer indicates the number of
      annotations that were skipped.
    num_shards: int, the number of shards to write for the dataset.
    multiple_processes: integer, the number of parallel worker processes to
      use. If None, uses multi-processing with the number of processes equal
      to `os.cpu_count()`, which is Python's default behavior. If set to 0,
      multi-processing is disabled and annotations are processed inline.
    unpack_arguments: Whether to unpack the tuples from annotation_iterator
      as individual arguments to process_func, or to pass each yielded value
      as-is.

  Returns:
    num_skipped: The total number of skipped annotations.
  """
  writers = [
      tf.io.TFRecordWriter(
          output_path + '-%05d-of-%05d.tfrecord' % (i, num_shards))
      for i in range(num_shards)
  ]
  total_num_annotations_skipped = 0
  # Compute the pool decision once instead of repeating the condition.
  use_pool = multiple_processes is None or multiple_processes > 0
  pool = mp.Pool(processes=multiple_processes) if use_pool else None
  try:
    if pool is not None:
      if unpack_arguments:
        tf_example_iterator = pool.starmap(process_func, annotation_iterator)
      else:
        tf_example_iterator = pool.imap(process_func, annotation_iterator)
    else:
      if unpack_arguments:
        tf_example_iterator = itertools.starmap(process_func,
                                                annotation_iterator)
      else:
        tf_example_iterator = map(process_func, annotation_iterator)
    for idx, (tf_example, num_annotations_skipped) in enumerate(
        tf_example_iterator):
      if idx % LOG_EVERY == 0:
        logging.info('On image %d', idx)
      total_num_annotations_skipped += num_annotations_skipped
      # Round-robin examples across shards.
      writers[idx % num_shards].write(tf_example.SerializeToString())
    if pool is not None:
      pool.close()
      pool.join()
  finally:
    # Close every shard writer even if processing raised, so partially
    # written files are flushed and file handles are not leaked.
    for writer in writers:
      writer.close()
  logging.info('Finished writing, skipped %d annotations.',
               total_num_annotations_skipped)
  return total_num_annotations_skipped
def check_and_make_dir(directory):
  """Creates `directory` (and parents) unless it already exists."""
  if tf.io.gfile.isdir(directory):
    return
  tf.io.gfile.makedirs(directory)