Spaces:
Runtime error
Runtime error
| # src/dataloader.py | |
| import tensorflow as tf | |
| import pandas as pd | |
| import cv2 | |
| import numpy as np | |
| import os | |
# Leading two dimensions of every image tensor produced by this module;
# parse_pair() appends a single channel axis to it when fixing static shapes.
IMG_SIZE = (128, 128)
def load_image(path, size=None):
    """Load one grayscale image as a normalised ``float32`` array.

    Args:
        path: Location of the image file, either as ``bytes`` (decoded as
            UTF-8, which is how ``parse_pair`` hands paths over through
            ``tf.numpy_function``) or as a plain ``str``.
        size: Optional ``(height, width)`` of the returned array. When
            omitted, the module-level ``IMG_SIZE`` is used so the result
            matches the static shape that ``parse_pair`` declares.

    Returns:
        A ``float32`` array of shape ``(height, width, 1)`` with pixel
        values scaled by ``1 / 255``. If the file does not exist or OpenCV
        cannot decode it, an all-zero array of the same shape is returned
        instead of raising.
    """
    height, width = IMG_SIZE if size is None else size

    # Only bytes need decoding; a str path is used as-is.
    if isinstance(path, bytes):
        path = path.decode("utf-8")

    # Best-effort loading: a missing or unreadable file yields a blank image.
    img = cv2.imread(path, cv2.IMREAD_GRAYSCALE) if os.path.exists(path) else None
    if img is None:
        return np.zeros((height, width, 1), dtype=np.float32)

    # cv2.resize expects dsize as (width, height); the resulting array has
    # shape (height, width).
    img = cv2.resize(img, (width, height))
    img = img / 255.0

    # Add the trailing channel axis expected downstream.
    return np.expand_dims(img, axis=-1).astype(np.float32)
def parse_pair(img1_path, img2_path, label):
    """Turn one (path, path, label) record into model-ready tensors.

    Both paths are loaded through ``load_image`` via ``tf.numpy_function``.
    Because that wrapper loses static shape information, each image tensor
    is explicitly given the shape ``(*IMG_SIZE, 1)``.

    Returns:
        ``((img1, img2), label)`` where both images are ``tf.float32``
        tensors and ``label`` has been cast to ``tf.float32``.
    """
    image_shape = (*IMG_SIZE, 1)

    images = []
    for image_path in (img1_path, img2_path):
        tensor = tf.numpy_function(load_image, [image_path], tf.float32)
        # numpy_function outputs have unknown shape; restore it here.
        tensor.set_shape(image_shape)
        images.append(tensor)

    first, second = images
    return (first, second), tf.cast(label, tf.float32)
def create_dataset(csv_file, batch_size=16, validation_split=0.2, seed=None):
    """Build batched training and validation datasets of image pairs.

    The CSV is read with pandas and must provide the columns ``img1``,
    ``img2`` and ``label``. Rows are shuffled once, then the first
    ``1 - validation_split`` fraction becomes the training set and the
    remainder the validation set.

    Args:
        csv_file: Path (or anything ``pandas.read_csv`` accepts) of the CSV.
        batch_size: Number of pairs per batch.
        validation_split: Fraction of rows held out for validation; must lie
            in ``[0, 1]``.
        seed: Optional random state for the one-off shuffle. Pass a fixed
            value to get the same train/validation split on every run;
            ``None`` keeps the shuffle non-deterministic.

    Returns:
        A ``(train_ds, val_ds)`` tuple of ``tf.data.Dataset`` objects whose
        elements are ``((img1, img2), label)`` batches.

    Raises:
        ValueError: If ``validation_split`` is outside ``[0, 1]``.
    """
    # A value outside [0, 1] would produce a negative or oversized split
    # index and silently yield a meaningless split.
    if not 0.0 <= validation_split <= 1.0:
        raise ValueError(
            f"validation_split must be between 0 and 1, got {validation_split!r}"
        )

    df = pd.read_csv(csv_file)

    # Shuffle once, before splitting, so both subsets are drawn at random.
    df = df.sample(frac=1, random_state=seed).reset_index(drop=True)

    split_idx = int(len(df) * (1 - validation_split))
    train_df = df.iloc[:split_idx]
    val_df = df.iloc[split_idx:]

    def make_ds(dataframe):
        """Create the map -> batch -> cache -> prefetch pipeline for one subset."""
        ds = tf.data.Dataset.from_tensor_slices(
            (dataframe["img1"], dataframe["img2"], dataframe["label"])
        )
        ds = ds.map(parse_pair, num_parallel_calls=tf.data.AUTOTUNE)
        ds = ds.batch(batch_size)
        # NOTE: cache() sits after batch(), so the cached elements are whole
        # batches and their composition is the same on every pass.
        return ds.cache().prefetch(tf.data.AUTOTUNE)

    return make_ds(train_df), make_ds(val_df)