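"""Utilities for ControlNet training with `LightControlNetPipeline`.

Covers: `image_grid` for tiling PIL images, `log_validation` for sampling the
model during training and logging results to wandb, `make_dataset` for loading
and preprocessing an image / caption / conditioning-image dataset, and
`collate_fn` for batching examples.
"""
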
import random
import numpy as np
import torch
import wandb
from datasets import load_dataset
from diffusers import DDIMScheduler
from PIL import Image
from torchvision import transforms
from utils.pipeline_controlnet import LightControlNetPipeline


def image_grid(imgs, rows, cols):
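    """Paste `rows * cols` equally sized PIL images into a single grid image."""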
assert len(imgs) == rows * cols
w, h = imgs[0].size
grid = Image.new("RGB", size=(cols * w, rows * h))
for i, img in enumerate(imgs):
grid.paste(img, box=(i % cols * w, i // cols * h))
return grid


def log_validation(val_dataset, text_encoder, unet, controlnet, args, accelerator):
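    """Run inference on the first `args.num_validation_images` validation samples
    with the current (unwrapped) model weights and log the results to wandb."""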
pipeline = LightControlNetPipeline.from_pretrained(
args.pretrained_model_name_or_path,
controlnet=accelerator.unwrap_model(controlnet, keep_fp32_wrapper=True),
unet=accelerator.unwrap_model(unet, keep_fp32_wrapper=True).model,
text_encoder=accelerator.unwrap_model(text_encoder, keep_fp32_wrapper=True),
safety_checker=None,
revision=args.revision,
)
pipeline.scheduler = DDIMScheduler.from_config(pipeline.scheduler.config)
pipeline = pipeline.to(accelerator.device)
pipeline.set_progress_bar_config(disable=True)
generator = torch.Generator(device=accelerator.device).manual_seed(args.seed)
image_logs = []
for idx in range(args.num_validation_images):
data = val_dataset[idx]
validation_prompt = data["text"]
validation_image = data["conditioning_pixel_values"]
image = pipeline(
validation_prompt,
[validation_image],
num_inference_steps=50,
generator=generator,
)[0][0]
image_logs.append(
{
"validation_image": validation_image,
"image": image,
"validation_prompt": validation_prompt,
}
)
    for tracker in accelerator.trackers:
        # wandb.Image payloads are only meaningful to the wandb tracker; skip any others.
        if tracker.name != "wandb":
            continue
formatted_images = []
for log in image_logs:
image = log["image"]
validation_prompt = log["validation_prompt"]
validation_image = log["validation_image"]
            formatted_images.append(wandb.Image(validation_image, caption="Controlnet conditioning"))
            formatted_images.append(wandb.Image(image, caption=validation_prompt))
tracker.log({"validation": formatted_images})
del pipeline
torch.cuda.empty_cache()


def make_dataset(args, tokenizer, accelerator, split="train"):
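    """Load the `split` partition of the dataset and attach caption tokenization
    and image transforms."""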
# Get the datasets: you can either provide your own training and evaluation files (see below)
# or specify a Dataset from the hub (the dataset will be downloaded automatically from the datasets Hub).
# In distributed training, the load_dataset function guarantees that only one local process can concurrently
# download the dataset.
if args.dataset_name is not None:
# Downloading and loading a dataset from the hub.
dataset = load_dataset(
args.dataset_name,
args.dataset_config_name,
cache_dir=args.cache_dir,
)
    elif args.train_data_dir is not None:
        # See more about loading custom images at
        # https://huggingface.co/docs/datasets/v2.0.0/en/dataset_script
        dataset = load_dataset(
            args.train_data_dir,
            cache_dir=args.cache_dir,
        )
    else:
        raise ValueError("Either `--dataset_name` or `--train_data_dir` must be provided.")
# Preprocessing the datasets.
# We need to tokenize inputs and targets.
column_names = dataset[split].column_names
# Get the column names for input/target.
if args.image_column is None:
image_column = column_names[0]
else:
image_column = args.image_column
if image_column not in column_names:
raise ValueError(
f"`--image_column` value '{args.image_column}' not found in dataset columns. Dataset columns are: {', '.join(column_names)}"
)
if args.caption_column is None:
caption_column = column_names[1]
else:
caption_column = args.caption_column
if caption_column not in column_names:
raise ValueError(
f"`--caption_column` value '{args.caption_column}' not found in dataset columns. Dataset columns are: {', '.join(column_names)}"
)
if args.conditioning_image_column is None:
conditioning_image_column = column_names[2]
else:
conditioning_image_column = args.conditioning_image_column
if conditioning_image_column not in column_names:
raise ValueError(
f"`--conditioning_image_column` value '{args.conditioning_image_column}' not found in dataset columns. Dataset columns are: {', '.join(column_names)}"
)

    def tokenize_captions(examples, is_train=True):
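        """Tokenize one caption per example, blanking a random `proportion_empty_prompts`
        fraction (commonly done to support classifier-free guidance)."""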
captions = []
for caption in examples[caption_column]:
if random.random() < args.proportion_empty_prompts:
captions.append("")
elif isinstance(caption, str):
captions.append(caption)
elif isinstance(caption, (list, np.ndarray)):
# take a random caption if there are multiple
captions.append(random.choice(caption) if is_train else caption[0])
else:
raise ValueError(
f"Caption column `{caption_column}` should contain either strings or lists of strings."
)
inputs = tokenizer(
captions, max_length=tokenizer.model_max_length, padding="max_length", truncation=True, return_tensors="pt"
)
return inputs.input_ids
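
    # Target images are normalized to [-1, 1] (the range the Stable Diffusion VAE
    # expects); conditioning images are left in [0, 1] for the ControlNet input.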
image_transforms = transforms.Compose(
[
transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BILINEAR),
transforms.CenterCrop(args.resolution),
transforms.ToTensor(),
transforms.Normalize([0.5], [0.5]),
]
)
conditioning_image_transforms = transforms.Compose(
[
transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BILINEAR),
transforms.CenterCrop(args.resolution),
transforms.ToTensor(),
]
)

    def preprocess_train(examples):
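        """Convert a batch of examples to RGB image tensors, conditioning image
        tensors, and tokenized caption ids."""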
images = [image.convert("RGB") for image in examples[image_column]]
images = [image_transforms(image) for image in images]
conditioning_images = [image.convert("RGB") for image in examples[conditioning_image_column]]
conditioning_images = [conditioning_image_transforms(image) for image in conditioning_images]
examples["pixel_values"] = images
examples["conditioning_pixel_values"] = conditioning_images
examples["input_ids"] = tokenize_captions(examples)
return examples
with accelerator.main_process_first():
if args.max_train_samples is not None:
dataset[split] = dataset[split].shuffle(seed=args.seed).select(range(args.max_train_samples))
# Set the training transforms
split_dataset = dataset[split].with_transform(preprocess_train)
return split_dataset


def collate_fn(examples):
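    """Stack per-example tensors into batched tensors for the dataloader."""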
pixel_values = torch.stack([example["pixel_values"] for example in examples])
pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()
conditioning_pixel_values = torch.stack([example["conditioning_pixel_values"] for example in examples])
conditioning_pixel_values = conditioning_pixel_values.to(memory_format=torch.contiguous_format).float()
input_ids = torch.stack([example["input_ids"] for example in examples])
return {
"pixel_values": pixel_values,
"conditioning_pixel_values": conditioning_pixel_values,
"input_ids": input_ids,
}