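r"""Evaluate content and style LoRAs on Stable Diffusion XL.

The script loads SDXL with the fp16-fix VAE, optionally applies a content
LoRA and/or a style LoRA, generates ``--num_images_per_prompt`` images for
each benchmark prompt, and saves the per-prompt generation configs, the
individual images, and a captioned grid per prompt category.

Example invocation (script name, paths, and placeholder tokens are
illustrative):

    python evaluate_loras.py \
        --output_path results \
        --evaluation_prompt_file prompts.json \
        --placeholder_content "<content>" \
        --placeholder_style "<style>" \
        --name_concept my_concept \
        --content_LoRA path/to/content_lora \
        --style_LoRA path/to/style_lora
"""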
import argparse
import json
from datetime import datetime
from pathlib import Path

import numpy as np
import torch
from PIL import Image, ImageDraw, ImageFont
from torchvision.utils import make_grid
from diffusers import StableDiffusionXLPipeline, AutoencoderKL

try:
    from pytorch_msssim import ssim, ms_ssim
except ImportError:
    # pytorch-msssim is missing; install it into the running interpreter's
    # environment and retry the import.
    print("Installing pytorch-msssim...")
    import subprocess
    import sys
    subprocess.check_call([sys.executable, "-m", "pip", "install", "pytorch-msssim"])
    from pytorch_msssim import ssim, ms_ssim


def add_caption_to_image(image, caption, font_size=20):
    """Add a caption below an image and return the result as a CHW float tensor."""
    if isinstance(image, torch.Tensor):
        # Convert a CHW float tensor in [0, 1] back to a PIL image.
        image = (image * 255).clamp(0, 255).to(torch.uint8)
        image = image.permute(1, 2, 0).cpu().numpy()
        image = Image.fromarray(image)

    # Extend the canvas downward to leave room for the caption strip.
    margin = 10
    width = image.width
    height = image.height + font_size + 2 * margin
    new_image = Image.new('RGB', (width, height), 'white')
    new_image.paste(image, (0, 0))

    draw = ImageDraw.Draw(new_image)
    try:
        font = ImageFont.truetype("/usr/share/fonts/truetype/dejavu/DejaVuSans.ttf", font_size)
    except OSError:
        # Fall back to PIL's built-in font if DejaVu Sans is not available.
        font = ImageFont.load_default()

    # Center the caption horizontally within the caption strip.
    text_width = draw.textlength(caption, font=font)
    x = (width - text_width) // 2
    y = height - font_size - margin
    draw.text((x, y), caption, fill='black', font=font)

    # Convert back to a CHW float tensor in [0, 1].
    new_image = torch.from_numpy(np.array(new_image)).permute(2, 0, 1).float() / 255.0
    return new_image


def create_image_grid(images, prompts, images_per_prompt, font_size=20):
    """Create a grid of captioned images, one row per prompt."""
    captioned_images = []
    for i, img in enumerate(images):
        prompt_idx = i // images_per_prompt
        img_idx = i % images_per_prompt + 1
        caption = f"{prompts[prompt_idx]} ({img_idx}/{images_per_prompt})"
        img_tensor = torch.from_numpy(np.array(img)).permute(2, 0, 1).float() / 255.0
        captioned_img = add_caption_to_image(img_tensor, caption, font_size)
        captioned_images.append(captioned_img)

    # Stack into a single batch and tile it, one row per prompt.
    image_tensor = torch.stack(captioned_images)
    grid = make_grid(image_tensor, nrow=images_per_prompt, padding=10)
    return grid


def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--output_path", type=str, required=True, help="path to save the generated images"
    )
    parser.add_argument(
        "--content_LoRA", type=str, default=None, help="path to the content LoRA"
    )
    parser.add_argument(
        "--content_alpha", type=float, default=1.0, help="scale factor for the content LoRA weights"
    )
    parser.add_argument(
        "--style_LoRA", type=str, default=None, help="path to the style LoRA"
    )
    parser.add_argument(
        "--style_alpha", type=float, default=1.0, help="scale factor for the style LoRA weights"
    )
    parser.add_argument(
        "--num_images_per_prompt", type=int, default=4, help="number of images to generate per prompt"
    )
    parser.add_argument(
        "--evaluation_prompt_file", type=str, required=True, help="path to the evaluation prompts JSON file"
    )
    parser.add_argument(
        "--placeholder_style", type=str, required=True, help="placeholder token substituted into style prompts"
    )
    parser.add_argument(
        "--placeholder_content", type=str, required=True, help="placeholder token substituted into content prompts"
    )
    parser.add_argument(
        "--name_concept", type=str, required=True, help="name of the concept being evaluated"
    )
    parser.add_argument(
        "--font_size", type=int, default=20, help="font size for image captions"
    )
    return parser.parse_args()


def process_prompts(pipeline, prompts, output_dir, args, prompt_type, lora_type, start_idx=0):
    """Generate images for a set of prompts, saving the images and their generation configs."""
    all_images = []
    formatted_prompts = []
    current_idx = start_idx

    # Substitute the style or content placeholder token into each prompt template.
    placeholder = args.placeholder_style if lora_type == "style" else args.placeholder_content

    for prompt in prompts:
        formatted_prompt = prompt.replace("{}", placeholder)
        formatted_prompts.append(formatted_prompt)

        # Record the generation parameters alongside the images.
        config = {
            "gen_prompt": formatted_prompt,
            "content_LoRA": args.content_LoRA if lora_type == "content" else None,
            "content_alpha": args.content_alpha if lora_type == "content" else None,
            "style_LoRA": args.style_LoRA if lora_type == "style" else None,
            "style_alpha": args.style_alpha if lora_type == "style" else None
        }
        config_path = output_dir / f'prompt_{current_idx}_params.json'
        with open(config_path, 'w') as f:
            json.dump(config, f, indent=4)

        # Generate and collect the images for this prompt.
        images = pipeline(formatted_prompt, num_images_per_prompt=args.num_images_per_prompt).images
        all_images.extend(images)

        # Save each image under output/ours/prompt_<idx>_<type>/.
        prompt_dir = output_dir / 'output' / 'ours' / f'prompt_{current_idx}_{prompt_type}'
        prompt_dir.mkdir(parents=True, exist_ok=True)
        for img_idx, img in enumerate(images):
            img.save(prompt_dir / f'{img_idx:03d}.jpg')

        current_idx += 1

    return all_images, formatted_prompts, current_idx


if __name__ == '__main__':
    args = parse_args()

    # Create a timestamped output directory for this evaluation run.
    timestamp = datetime.now().strftime("%Y%m%d%H%M%S")
    result_dir = Path(args.output_path) / f'{args.name_concept}_{timestamp}'
    result_dir.mkdir(parents=True, exist_ok=True)

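    # Expected shape of the evaluation prompt file, inferred from how it is
    # used below: a JSON object whose "content" key maps category names to
    # lists of prompt templates and whose "style" key is a list of prompt
    # templates, each containing "{}" for the placeholder token, e.g.
    #   {"content": {"object": ["a photo of {} on a beach"]},
    #    "style": ["a house in the style of {}"]}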
    with open(args.evaluation_prompt_file, 'r') as f:
        benchmark_prompts = json.load(f)

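    # Use the fp16-fix VAE: the stock SDXL VAE can produce NaNs when run in float16.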
    vae = AutoencoderKL.from_pretrained("madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16)
    pipeline = StableDiffusionXLPipeline.from_pretrained(
        "stabilityai/stable-diffusion-xl-base-1.0",
        vae=vae,
        torch_dtype=torch.float16
    ).to("cuda")

    # Running index so prompt directories and configs are numbered
    # consecutively across the content and style passes.
    current_prompt_idx = 0

    if args.content_LoRA is not None:
        print("Loading content LoRA...")
        pipeline.load_lora_weights(args.content_LoRA, scale=args.content_alpha)

        for category, prompts in benchmark_prompts["content"].items():
            print(f"Processing content {category} prompts...")
            images, formatted_prompts, current_prompt_idx = process_prompts(
                pipeline, prompts, result_dir, args, f"content_{category}", "content",
                start_idx=current_prompt_idx
            )

            # Save a captioned grid (one row per prompt) for this category.
            grid = create_image_grid(images, formatted_prompts, args.num_images_per_prompt, args.font_size)
            grid_image = Image.fromarray((grid.permute(1, 2, 0).numpy() * 255).astype(np.uint8))
            grid_path = result_dir / f'grid_content_{category}.png'
            grid_image.save(grid_path)

        # Remove the content LoRA before (optionally) loading the style LoRA.
        pipeline.unload_lora_weights()

    if args.style_LoRA is not None:
        print("Loading style LoRA...")
        pipeline.load_lora_weights(args.style_LoRA, scale=args.style_alpha)

        print("Processing style prompts...")
        images, formatted_prompts, _ = process_prompts(
            pipeline, benchmark_prompts["style"], result_dir, args, "style", "style",
            start_idx=current_prompt_idx
        )

        grid = create_image_grid(images, formatted_prompts, args.num_images_per_prompt, args.font_size)
        grid_image = Image.fromarray((grid.permute(1, 2, 0).numpy() * 255).astype(np.uint8))
        grid_path = result_dir / 'grid_style.png'
        grid_image.save(grid_path)

    print(f"Results saved to {result_dir}")