import random

import numpy as np
import torch
import wandb
from datasets import load_dataset
from diffusers import DDIMScheduler
from PIL import Image
from torchvision import transforms

from utils.pipeline_controlnet import LightControlNetPipeline


def image_grid(imgs, rows, cols):
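    """Tile a list of equally sized PIL images into a single rows x cols grid image."""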
    assert len(imgs) == rows * cols

    w, h = imgs[0].size
    grid = Image.new("RGB", size=(cols * w, rows * h))

    for i, img in enumerate(imgs):
        grid.paste(img, box=(i % cols * w, i // cols * h))
    return grid


def log_validation(val_dataset, text_encoder, unet, controlnet, args, accelerator):
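    """Generate images for a few validation samples and log them to the active
    trackers as wandb.Image objects (a wandb tracker is assumed)."""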
    pipeline = LightControlNetPipeline.from_pretrained(
        args.pretrained_model_name_or_path,
        controlnet=accelerator.unwrap_model(controlnet, keep_fp32_wrapper=True),
        unet=accelerator.unwrap_model(unet, keep_fp32_wrapper=True).model,
        text_encoder=accelerator.unwrap_model(text_encoder, keep_fp32_wrapper=True),
        safety_checker=None,
        revision=args.revision,
    )

    pipeline.scheduler = DDIMScheduler.from_config(pipeline.scheduler.config)
    pipeline = pipeline.to(accelerator.device)

    pipeline.set_progress_bar_config(disable=True)

    generator = torch.Generator(device=accelerator.device).manual_seed(args.seed)

    image_logs = []

    for idx in range(args.num_validation_images):
        data = val_dataset[idx]
        validation_prompt = data["text"]
        validation_image = data["conditioning_pixel_values"]

        image = pipeline(
            validation_prompt,
            [validation_image],
            num_inference_steps=50,
            generator=generator,
        )[0][0]

        image_logs.append(
            {
                "validation_image": validation_image,
                "image": image,
                "validation_prompt": validation_prompt,
            }
        )

    for tracker in accelerator.trackers:
        formatted_images = []

        for log in image_logs:
            image = log["image"]
            validation_prompt = log["validation_prompt"]
            validation_image = log["validation_image"]

            formatted_images.append(wandb.Image(validation_image, caption="Controlnet conditioning"))

            image = wandb.Image(image, caption=validation_prompt)
            formatted_images.append(image)

        tracker.log({"validation": formatted_images})

    del pipeline
    torch.cuda.empty_cache()


def make_dataset(args, tokenizer, accelerator, split="train"):
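    """Load the dataset, resolve the image/caption/conditioning columns, and
    attach the tokenization and image transforms for the given split."""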
    # Download and load the dataset from the Hub; in distributed training,
    # load_dataset guarantees that only one local process downloads it.
    if args.dataset_name is not None:
        dataset = load_dataset(
            args.dataset_name,
            args.dataset_config_name,
            cache_dir=args.cache_dir,
        )
    else:
        if args.train_data_dir is not None:
            dataset = load_dataset(
                args.train_data_dir,
                cache_dir=args.cache_dir,
            )
        # One of --dataset_name / --train_data_dir is expected to be set,
        # presumably validated at argument parsing.

    # Figure out which columns hold the target images, the captions, and the
    # conditioning images, falling back to positional defaults.
    column_names = dataset[split].column_names

    if args.image_column is None:
        image_column = column_names[0]
    else:
        image_column = args.image_column
        if image_column not in column_names:
            raise ValueError(
                f"`--image_column` value '{args.image_column}' not found in dataset columns. Dataset columns are: {', '.join(column_names)}"
            )

    if args.caption_column is None:
        caption_column = column_names[1]
    else:
        caption_column = args.caption_column
        if caption_column not in column_names:
            raise ValueError(
                f"`--caption_column` value '{args.caption_column}' not found in dataset columns. Dataset columns are: {', '.join(column_names)}"
            )

    if args.conditioning_image_column is None:
        conditioning_image_column = column_names[2]
    else:
        conditioning_image_column = args.conditioning_image_column
        if conditioning_image_column not in column_names:
            raise ValueError(
                f"`--conditioning_image_column` value '{args.conditioning_image_column}' not found in dataset columns. Dataset columns are: {', '.join(column_names)}"
            )

    def tokenize_captions(examples, is_train=True):
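        """Tokenize the captions, randomly replacing a fraction with empty
        strings per args.proportion_empty_prompts (commonly done to support
        classifier-free guidance)."""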
        captions = []
        for caption in examples[caption_column]:
            if random.random() < args.proportion_empty_prompts:
                captions.append("")
            elif isinstance(caption, str):
                captions.append(caption)
            elif isinstance(caption, (list, np.ndarray)):
                # Take a random caption if there are multiple per image.
                captions.append(random.choice(caption) if is_train else caption[0])
            else:
                raise ValueError(
                    f"Caption column `{caption_column}` should contain either strings or lists of strings."
                )
        inputs = tokenizer(
            captions, max_length=tokenizer.model_max_length, padding="max_length", truncation=True, return_tensors="pt"
        )
        return inputs.input_ids

    image_transforms = transforms.Compose(
        [
            transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BILINEAR),
            transforms.CenterCrop(args.resolution),
            transforms.ToTensor(),
            transforms.Normalize([0.5], [0.5]),
        ]
    )

    # Conditioning images are kept in [0, 1], so no normalization here.
    conditioning_image_transforms = transforms.Compose(
        [
            transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BILINEAR),
            transforms.CenterCrop(args.resolution),
            transforms.ToTensor(),
        ]
    )

    def preprocess_train(examples):
        images = [image.convert("RGB") for image in examples[image_column]]
        images = [image_transforms(image) for image in images]

        conditioning_images = [image.convert("RGB") for image in examples[conditioning_image_column]]
        conditioning_images = [conditioning_image_transforms(image) for image in conditioning_images]

        examples["pixel_values"] = images
        examples["conditioning_pixel_values"] = conditioning_images
        examples["input_ids"] = tokenize_captions(examples)

        return examples

    with accelerator.main_process_first():
        if args.max_train_samples is not None:
            dataset[split] = dataset[split].shuffle(seed=args.seed).select(range(args.max_train_samples))

        # Set the transforms for the split.
        split_dataset = dataset[split].with_transform(preprocess_train)

    return split_dataset


def collate_fn(examples):
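    """Collate preprocessed examples into batched tensors for the dataloader."""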
    pixel_values = torch.stack([example["pixel_values"] for example in examples])
    pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()

    conditioning_pixel_values = torch.stack([example["conditioning_pixel_values"] for example in examples])
    conditioning_pixel_values = conditioning_pixel_values.to(memory_format=torch.contiguous_format).float()

    input_ids = torch.stack([example["input_ids"] for example in examples])

    return {
        "pixel_values": pixel_values,
        "conditioning_pixel_values": conditioning_pixel_values,
        "input_ids": input_ids,
    }