d-e-e-k-11 commited on
Commit
61a06ac
·
verified ·
1 Parent(s): b0898f2

Upload tf_dataset.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. tf_dataset.py +23 -0
tf_dataset.py ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import tensorflow as tf
2
+ import os
3
+ import numpy as np
4
+
5
+ def load_image(image_file):
6
+ image = tf.io.read_file(image_file)
7
+ image = tf.image.decode_jpeg(image, channels=3)
8
+ image = tf.image.convert_image_dtype(image, tf.float32)
9
+ image = tf.image.resize(image, [256, 256])
10
+ image = (image * 2) - 1
11
+ return image
12
+
13
+ def get_dataset(root_path, subset="train"):
14
+ path_a = os.path.join(root_path, f"{subset}A")
15
+ path_b = os.path.join(root_path, f"{subset}B")
16
+
17
+ list_a = tf.data.Dataset.list_files(path_a + "/*.jpg")
18
+ list_b = tf.data.Dataset.list_files(path_b + "/*.jpg")
19
+
20
+ ds_a = list_a.map(load_image, num_parallel_calls=tf.data.AUTOTUNE)
21
+ ds_b = list_b.map(load_image, num_parallel_calls=tf.data.AUTOTUNE)
22
+
23
+ return tf.data.Dataset.zip((ds_a, ds_b))