File size: 816 Bytes
d1bfee5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
import tensorflow as tf
import os
import numpy as np

def load_image(image_file):
    """Load one JPEG file as a 256x256, 3-channel float32 image tensor.

    Args:
        image_file: Path of the file to read, passed straight to
            ``tf.io.read_file``.

    Returns:
        The decoded image after resizing to 256x256 and applying the
        affine map ``x * 2 - 1``.
    """
    raw_bytes = tf.io.read_file(image_file)
    # channels=3 forces an RGB result regardless of how the JPEG is stored.
    rgb = tf.image.decode_jpeg(raw_bytes, channels=3)
    # NOTE(review): per the TF docs, converting an integer image to float32
    # rescales it into [0, 1]; the final step relies on that to land in
    # a nominal [-1, 1] range.
    as_float = tf.image.convert_image_dtype(rgb, tf.float32)
    resized = tf.image.resize(as_float, [256, 256])
    return (resized * 2) - 1

def get_dataset(root_path, subset="train"):
    """Build a zipped dataset of images from the two domain folders.

    Images are read from ``<root_path>/<subset>A/*.jpg`` and
    ``<root_path>/<subset>B/*.jpg``, each file is decoded with
    ``load_image``, and the two resulting datasets are zipped together.

    Args:
        root_path: Directory containing the ``<subset>A`` and ``<subset>B``
            folders.
        subset: Folder-name prefix selecting the split; defaults to
            ``"train"``.

    Returns:
        A ``tf.data.Dataset`` whose elements are ``(image_a, image_b)``
        tuples.

    NOTE(review): ``list_files`` is called with its default arguments, which
    per the TF docs shuffles the file order, and ``zip`` presumably stops at
    the shorter of the two inputs -- confirm both are intended for this data.
    """
    # Build both glob patterns first, then both listings, then both maps,
    # so the A-domain is always handled before the B-domain at each stage.
    patterns = [
        os.path.join(root_path, f"{subset}{domain}") + "/*.jpg"
        for domain in ("A", "B")
    ]
    listings = [tf.data.Dataset.list_files(pattern) for pattern in patterns]
    decoded = tuple(
        files.map(load_image, num_parallel_calls=tf.data.AUTOTUNE)
        for files in listings
    )
    return tf.data.Dataset.zip(decoded)