github-actions[bot]
🚀 Deploy method comparison app from GH action
495c4e6
Raw
History Blame Contribute Delete
3.82 kB
# Copyright 2026-present the HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Data handling for the image generation benchmark."""
import numpy as np
import torchvision.transforms as T
from datasets import load_dataset
from PIL import Image
from PIL.ImageOps import exif_transpose
def _to_rgb(image) -> Image.Image:
    """Return ``image`` as an RGB PIL image.

    Args:
        image: Either a ``PIL.Image.Image`` or an object accepted by ``Image.fromarray``.

    Returns:
        The result of ``.convert("RGB")`` on the (possibly newly created) PIL image.
    """
    # Anything that is not already a PIL image is handed to Image.fromarray first.
    pil_image = image if isinstance(image, Image.Image) else Image.fromarray(image)
    return pil_image.convert("RGB")
def _build_train_pixel_values(images: list[Image.Image], resolution: int):
    """Apply the training preprocessing pipeline to every image.

    Each image is passed through ``exif_transpose``, resized to ``resolution`` x ``resolution`` with bilinear
    interpolation, converted with ``T.ToTensor`` and normalized with mean 0.5 and std 0.5.

    Args:
        images: PIL images to preprocess.
        resolution: Edge length used for both height and width of the resized image.

    Returns:
        A list holding one transformed item per input image, in input order.
    """
    target_size = (resolution, resolution)  # hard-code square
    resize = T.Resize(target_size, interpolation=T.InterpolationMode.BILINEAR)
    pipeline = T.Compose([resize, T.ToTensor(), T.Normalize([0.5], [0.5])])

    # map() is lazy, so each image is still transposed right before it is transformed.
    upright_images = map(exif_transpose, images)
    return [pipeline(upright) for upright in upright_images]
def _make_eval_records(images, prompts):
    """Pair each image (after ``exif_transpose``) with its prompt as a ``{"raw_image", "prompt"}`` dict."""
    return [{"raw_image": exif_transpose(image), "prompt": prompt} for image, prompt in zip(images, prompts)]


def get_train_valid_test_datasets(*, train_config, print_fn=print):
    """Load the configured dataset and split it into train, validation and test parts.

    The rows are shuffled with a fixed seed (0) so the split is reproducible, then cut into
    ``train | valid | test`` in that order. The shuffle uses a private ``RandomState`` so that the
    process-wide NumPy random state (e.g. a seed set by the caller) is left untouched.

    Args:
        train_config: Configuration object. The attributes read here are ``dataset_id``, ``dataset_split``,
            ``image_column``, ``valid_size``, ``test_size``, ``instance_prompts`` (a single string that is
            reused for every row, or a sequence with one prompt per row), ``resolution`` and ``repeats``.
        print_fn: Callable used to report the dataset summary; defaults to ``print``.

    Returns:
        A tuple ``(train_dataset, valid_dataset, test_dataset)`` where

        - ``train_dataset`` is a dict with the keys ``"pixel_values"`` (preprocessed training images),
          ``"prompts"`` and ``"repeats"``;
        - ``valid_dataset`` and ``test_dataset`` are lists of ``{"raw_image": ..., "prompt": ...}`` dicts.

    Raises:
        ValueError: If the number of prompts does not match the number of rows, if ``valid_size`` or
            ``test_size`` is negative, or if fewer than one row would be left for training.
    """
    ds = load_dataset(train_config.dataset_id, split=train_config.dataset_split)
    image_column = train_config.image_column
    num_rows = len(ds)
    valid_size = train_config.valid_size
    test_size = train_config.test_size

    prompts = train_config.instance_prompts
    if isinstance(prompts, str):
        # A single prompt is shared by every image.
        prompts = [prompts] * num_rows
    elif len(prompts) != num_rows:
        raise ValueError(f"Need 1 instance prompt per sample image, found {len(prompts)} and {num_rows} instead.")

    # Negative sizes would make the slices below overlap or come out empty without any error.
    if valid_size < 0 or test_size < 0:
        raise ValueError(
            f"valid_size and test_size must be non-negative, found {valid_size} and {test_size} instead."
        )

    train_size = num_rows - valid_size - test_size
    if train_size < 1:
        raise ValueError(
            f"Dataset too small: need at least {1 + valid_size + test_size} rows, "
            f"found {num_rows}"
        )

    # RandomState(0) yields the same permutation as np.random.seed(0) + np.random.shuffle, but does not
    # reseed the global NumPy RNG as a side effect.
    rng = np.random.RandomState(0)
    indices = np.arange(num_rows)
    rng.shuffle(indices)

    valid_end = train_size + valid_size
    idx_train = indices[:train_size]
    idx_valid = indices[train_size:valid_end]
    idx_test = indices[valid_end : valid_end + test_size]

    train_images = [_to_rgb(img) for img in ds.select(idx_train)[image_column]]
    valid_images = [_to_rgb(img) for img in ds.select(idx_valid)[image_column]]
    test_images = [_to_rgb(img) for img in ds.select(idx_test)[image_column]]

    train_prompts = [prompts[i] for i in idx_train]
    valid_prompts = [prompts[i] for i in idx_valid]
    test_prompts = [prompts[i] for i in idx_test]

    train_dataset = {
        "pixel_values": _build_train_pixel_values(train_images, train_config.resolution),
        "prompts": train_prompts,
        "repeats": train_config.repeats,
    }
    # Validation and test images are kept as raw PIL images; only the EXIF orientation is applied.
    valid_dataset = _make_eval_records(valid_images, valid_prompts)
    test_dataset = _make_eval_records(test_images, test_prompts)

    print_fn(f"Dataset: {train_config.dataset_id}")
    print_fn(f"Raw rows: {num_rows}")
    print_fn(f"Train rows: {len(train_dataset['prompts']) * train_dataset['repeats']}")
    print_fn(f"Valid rows: {len(valid_dataset)}")
    print_fn(f"Test rows: {len(test_dataset)}")
    return train_dataset, valid_dataset, test_dataset