import os

import numpy as np
import tensorflow as tf


def load_image(image_file, *, image_size=(256, 256)):
    """Read a JPEG file and return it as a float32 RGB tensor scaled to [-1, 1].

    Args:
        image_file: Path to a JPEG file, either a Python string or a scalar
            string tensor (as yielded by ``tf.data.Dataset.list_files``).
        image_size: ``(height, width)`` the decoded image is resized to.
            Defaults to ``(256, 256)``.

    Returns:
        A float32 tensor of shape ``(height, width, 3)`` with values in [-1, 1].
    """
    raw_bytes = tf.io.read_file(image_file)
    # channels=3 forces a 3-channel RGB result regardless of the source JPEG.
    image = tf.image.decode_jpeg(raw_bytes, channels=3)
    # uint8 -> float32, rescaled to [0, 1].
    image = tf.image.convert_image_dtype(image, tf.float32)
    # Default bilinear interpolation; output stays within [0, 1].
    image = tf.image.resize(image, image_size)
    # Map [0, 1] -> [-1, 1] (presumably to match a tanh-output model -- confirm).
    return (image * 2) - 1


def get_dataset(root_path, subset="train", *, image_size=(256, 256),
                pattern="*.jpg", shuffle=True, seed=None):
    """Build a dataset of (domain-A image, domain-B image) pairs.

    Images are read from ``<root_path>/<subset>A`` and ``<root_path>/<subset>B``.
    Each domain is listed and decoded independently (see ``load_image``), then
    the two are zipped together. ``tf.data.Dataset.zip`` stops at the shorter
    of the two datasets, so surplus images in the larger domain are not
    produced in that pass.

    Args:
        root_path: Directory containing the ``<subset>A`` / ``<subset>B`` folders.
        subset: Folder-name prefix, e.g. ``"train"`` or ``"test"``.
        image_size: ``(height, width)`` every image is resized to.
        pattern: Glob matched inside each domain folder. Defaults to ``"*.jpg"``.
        shuffle: Forwarded to ``tf.data.Dataset.list_files``. The default
            ``True`` shuffles each domain's file order; pass ``False`` for a
            deterministic order (e.g. for evaluation).
        seed: Optional shuffle seed forwarded to ``list_files``.

    Returns:
        A ``tf.data.Dataset`` yielding ``(image_a, image_b)`` float32 tensors.

    Raises:
        tf.errors.InvalidArgumentError: raised by ``list_files`` when a domain
            folder contains no file matching ``pattern``.
    """
    def _load(path):
        return load_image(path, image_size=image_size)

    def _domain(suffix):
        # Build the glob with os.path.join throughout instead of mixing in "/".
        file_glob = os.path.join(root_path, f"{subset}{suffix}", pattern)
        files = tf.data.Dataset.list_files(file_glob, shuffle=shuffle, seed=seed)
        return files.map(_load, num_parallel_calls=tf.data.AUTOTUNE)

    return tf.data.Dataset.zip((_domain("A"), _domain("B")))