# image-Style / tf_dataset.py
# Uploaded by d-e-e-k-11 using huggingface_hub (commit d1bfee5, verified).
import tensorflow as tf
import os
import numpy as np
def load_image(image_file, *, size=(256, 256)):
    """Read one image file and return it as a float32 tensor scaled to [-1, 1].

    Args:
        image_file: Path to the image file. This is a Python string or a scalar
            string tensor, as supplied by ``tf.data.Dataset.map``.
        size: Target ``(height, width)`` for ``tf.image.resize``. The default
            ``(256, 256)`` matches the size this function always used before.

    Returns:
        A float32 tensor of shape ``(height, width, 3)`` with values in [-1, 1].
    """
    raw = tf.io.read_file(image_file)
    # decode_image handles JPEG exactly as before and also accepts PNG/BMP/GIF.
    # expand_animations=False guarantees a single 3-D image (first GIF frame),
    # which is what tf.image.resize below expects.
    image = tf.io.decode_image(raw, channels=3, expand_animations=False)
    # decode_image yields uint8 here; convert_image_dtype rescales it to [0, 1].
    image = tf.image.convert_image_dtype(image, tf.float32)
    image = tf.image.resize(image, size)
    # Map [0, 1] -> [-1, 1].
    return image * 2.0 - 1.0
def get_dataset(root_path, subset="train", *, pattern="*.jpg", shuffle=None):
    """Build an unpaired (domain A, domain B) image dataset.

    Files are listed from ``<root_path>/<subset>A`` and ``<root_path>/<subset>B``.
    Each file is decoded with ``load_image``.

    Args:
        root_path: Directory that contains the ``<subset>A`` and ``<subset>B``
            folders.
        subset: Folder-name prefix, e.g. ``"train"`` or ``"test"``.
        pattern: Glob matched inside each folder. Defaults to ``"*.jpg"``.
        shuffle: Forwarded to ``tf.data.Dataset.list_files``. ``None`` keeps
            that function's default, which shuffles the file order. Pass
            ``False`` for a deterministic order, e.g. for evaluation.

    Returns:
        An unbatched ``tf.data.Dataset`` of ``(image_a, image_b)`` tuples.
        ``Dataset.zip`` stops at the shorter of the two folders, so surplus
        files in the larger folder are not used in a given pass.
    """
    path_a = os.path.join(root_path, f"{subset}A")
    path_b = os.path.join(root_path, f"{subset}B")

    # The A and B folders are listed (and, by default, shuffled) independently,
    # so the zipped pairs are not aligned by filename.
    list_a = tf.data.Dataset.list_files(os.path.join(path_a, pattern), shuffle=shuffle)
    list_b = tf.data.Dataset.list_files(os.path.join(path_b, pattern), shuffle=shuffle)

    ds_a = list_a.map(load_image, num_parallel_calls=tf.data.AUTOTUNE)
    ds_b = list_b.map(load_image, num_parallel_calls=tf.data.AUTOTUNE)
    return tf.data.Dataset.zip((ds_a, ds_b))