File size: 8,056 Bytes
4d217d6 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 |
import argparse
import json
from pathlib import Path
from datetime import datetime
import numpy as np
import torch
from PIL import Image, ImageDraw, ImageFont
from torchvision.utils import make_grid
from diffusers import StableDiffusionXLPipeline, AutoencoderKL
try:
from pytorch_msssim import ssim, ms_ssim
except ImportError:
print("Installing pytorch-msssim...")
import subprocess
subprocess.check_call(["pip", "install", "pytorch-msssim"])
from pytorch_msssim import ssim, ms_ssim
def add_caption_to_image(image, caption, font_size=20):
    """Append a white caption strip below an image and return it as a tensor.

    Args:
        image: A PIL image, or a ``torch.Tensor``. A tensor is multiplied by
            255, clamped to [0, 255], cast to uint8 and permuted from
            channels-first to channels-last before being handed to PIL
            (assumes a 3-channel tensor with values in [0, 1] -- TODO confirm
            against callers).
        caption: Text drawn, horizontally centred, inside the strip.
        font_size: Requested font size; also determines the strip height
            (``font_size + 2 * margin`` with a margin of 10 pixels).

    Returns:
        A float tensor in channels-first layout with values in [0, 1],
        holding the original image on top and the caption underneath.
    """
    # Convert tensor to PIL Image if needed
    if isinstance(image, torch.Tensor):
        image = (image * 255).clamp(0, 255).to(torch.uint8)
        image = image.permute(1, 2, 0).cpu().numpy()
        image = Image.fromarray(image)

    # Create new image with space for caption
    margin = 10
    width = image.width
    height = image.height + font_size + 2 * margin
    new_image = Image.new('RGB', (width, height), 'white')
    new_image.paste(image, (0, 0))

    # Add caption
    draw = ImageDraw.Draw(new_image)
    try:
        font = ImageFont.truetype("/usr/share/fonts/truetype/dejavu/DejaVuSans.ttf", font_size)
    except (OSError, ImportError):
        # Only the failures a font load can legitimately produce are caught
        # (missing/unreadable font file, or Pillow built without FreeType);
        # anything else, e.g. KeyboardInterrupt, now propagates instead of
        # being silently swallowed by a bare ``except``.
        font = ImageFont.load_default()

    # Center the text
    text_width = draw.textlength(caption, font=font)
    x = (width - text_width) // 2
    y = height - font_size - margin
    draw.text((x, y), caption, fill='black', font=font)

    # Convert back to tensor
    new_image = torch.from_numpy(np.array(new_image)).permute(2, 0, 1).float() / 255.0
    return new_image
def create_image_grid(images, prompts, images_per_prompt, font_size=20):
    """Caption every image with its prompt and tile the results into one grid.

    Args:
        images: Flat sequence of images, grouped so that each consecutive run
            of ``images_per_prompt`` entries belongs to the same prompt
            (assumes each entry converts via ``np.array`` to a
            height x width x channels array -- TODO confirm).
        prompts: Prompt strings; entry ``i // images_per_prompt`` labels
            image ``i``.
        images_per_prompt: Group size; also used as ``nrow`` for the grid.
        font_size: Caption font size forwarded to ``add_caption_to_image``.

    Returns:
        The tensor produced by ``make_grid`` (padding of 10).
    """
    labelled = []
    for position, picture in enumerate(images):
        # Quotient selects the prompt, remainder is the 0-based slot in it.
        which_prompt, slot = divmod(position, images_per_prompt)
        label = f"{prompts[which_prompt]} ({slot + 1}/{images_per_prompt})"

        pixels = np.array(picture)
        as_tensor = torch.from_numpy(pixels).permute(2, 0, 1).float() / 255.0
        labelled.append(add_caption_to_image(as_tensor, label, font_size))

    batch = torch.stack(labelled)
    return make_grid(batch, nrow=images_per_prompt, padding=10)
def parse_args():
    """Build the command-line parser for the evaluation script and parse argv.

    Returns:
        The ``argparse.Namespace`` holding every option declared below.
    """
    # (flag, type, extra add_argument keywords, help text) in declaration order.
    option_table = (
        ("--output_path", str, {"required": True}, "path to save the images"),
        ("--content_LoRA", str, {"default": None}, "path for the content LoRA"),
        ("--content_alpha", float, {"default": 1.0}, "scale factor for content LoRA weights"),
        ("--style_LoRA", str, {"default": None}, "path for the style LoRA"),
        ("--style_alpha", float, {"default": 1.0}, "scale factor for style LoRA weights"),
        ("--num_images_per_prompt", int, {"default": 4}, "number of images per prompt"),
        ("--evaluation_prompt_file", str, {"required": True}, "path to evaluation prompts file"),
        ("--placeholder_style", str, {"required": True}, "placeholder for the style prompt"),
        ("--placeholder_content", str, {"required": True}, "placeholder for the content prompt"),
        ("--name_concept", str, {"required": True}, "name of the concept being evaluated"),
        ("--font_size", int, {"default": 20}, "font size for image captions"),
    )

    cli = argparse.ArgumentParser()
    for flag, value_type, extra, description in option_table:
        cli.add_argument(flag, type=value_type, help=description, **extra)
    return cli.parse_args()
def process_prompts(pipeline, prompts, output_dir, args, prompt_type, lora_type, start_idx=0):
    """Generate images for a set of prompts and write configs and images to disk.

    For every prompt, ``{}`` is replaced by the style placeholder when
    ``lora_type == "style"`` and by the content placeholder otherwise. A JSON
    file ``prompt_<idx>_params.json`` describing the run is written to
    ``output_dir``, the pipeline is called, and each returned image is saved
    as ``output/ours/prompt_<idx>_<prompt_type>/<NNN>.jpg``.

    Args:
        pipeline: Callable invoked as ``pipeline(prompt, num_images_per_prompt=...,
            cross_attention_kwargs=...)`` whose result exposes ``.images``.
        prompts: Iterable of prompt templates (consumed exactly once).
        output_dir: Root directory for configs and images (``str`` or ``Path``).
        args: Namespace providing ``placeholder_style``, ``placeholder_content``,
            ``content_LoRA``, ``content_alpha``, ``style_LoRA``, ``style_alpha``
            and ``num_images_per_prompt``.
        prompt_type: Suffix used in the per-prompt image directory name.
        lora_type: ``"content"`` or ``"style"``; selects placeholder, recorded
            LoRA fields and the LoRA scale applied at generation time.
        start_idx: Index assigned to the first prompt.

    Returns:
        ``(all_images, formatted_prompts, next_idx)`` where ``next_idx`` is
        ``start_idx`` plus the number of prompts processed.
    """
    output_dir = Path(output_dir)
    is_style = lora_type == "style"
    is_content = lora_type == "content"

    # Loop-invariant choices, resolved once.
    placeholder = args.placeholder_style if is_style else args.placeholder_content
    if is_style:
        lora_scale = args.style_alpha
    elif is_content:
        lora_scale = args.content_alpha
    else:
        lora_scale = None

    call_kwargs = {"num_images_per_prompt": args.num_images_per_prompt}
    if lora_scale is not None:
        # The alpha recorded in the config is actually applied here: diffusers
        # pipelines take the LoRA strength per call through
        # cross_attention_kwargs["scale"].
        call_kwargs["cross_attention_kwargs"] = {"scale": lora_scale}

    all_images = []
    formatted_prompts = []
    current_idx = start_idx

    for prompt in prompts:
        formatted_prompt = prompt.replace("{}", placeholder)
        # Collected while iterating so one-shot iterables are handled and the
        # substitution is not performed a second time at return.
        formatted_prompts.append(formatted_prompt)

        config = {
            "gen_prompt": formatted_prompt,
            "content_LoRA": args.content_LoRA if is_content else None,
            "content_alpha": args.content_alpha if is_content else None,
            "style_LoRA": args.style_LoRA if is_style else None,
            "style_alpha": args.style_alpha if is_style else None,
        }

        # Save config with consecutive numbering
        config_path = output_dir / f'prompt_{current_idx}_params.json'
        with open(config_path, 'w') as f:
            json.dump(config, f, indent=4)

        # Generate images
        images = pipeline(formatted_prompt, **call_kwargs).images
        all_images.extend(images)

        # Save individual images with consecutive numbering
        prompt_dir = output_dir / 'output' / 'ours' / f'prompt_{current_idx}_{prompt_type}'
        prompt_dir.mkdir(parents=True, exist_ok=True)
        for img_idx, img in enumerate(images):
            img.save(prompt_dir / f'{img_idx:03d}.jpg')

        current_idx += 1

    return all_images, formatted_prompts, current_idx
if __name__ == '__main__':
    args = parse_args()

    # Each run writes into <output_path>/<name_concept>_<YYYYmmddHHMMSS>.
    run_stamp = datetime.now().strftime("%Y%m%d%H%M%S")
    result_dir = Path(args.output_path) / f'{args.name_concept}_{run_stamp}'
    result_dir.mkdir(parents=True, exist_ok=True)

    # The prompt file is JSON; below, its "content" entry is iterated as a
    # mapping of category -> prompts and its "style" entry as a prompt list.
    with open(args.evaluation_prompt_file, 'r') as prompt_file:
        benchmark_prompts = json.load(prompt_file)

    def _write_grid(pictures, captions, file_name):
        """Build the captioned grid for one batch and save it in result_dir."""
        sheet = create_image_grid(pictures, captions, args.num_images_per_prompt, args.font_size)
        sheet_pixels = (sheet.permute(1, 2, 0).numpy() * 255).astype(np.uint8)
        Image.fromarray(sheet_pixels).save(result_dir / file_name)

    # SDXL base pipeline in fp16 with a separately loaded VAE, moved to CUDA.
    fp16_vae = AutoencoderKL.from_pretrained("madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16)
    pipeline = StableDiffusionXLPipeline.from_pretrained(
        "stabilityai/stable-diffusion-xl-base-1.0",
        vae=fp16_vae,
        torch_dtype=torch.float16,
    )
    pipeline = pipeline.to("cuda")

    # Prompt numbering continues across categories and into the style pass.
    next_prompt_idx = 0

    # Content pass: only runs when a content LoRA was supplied.
    if args.content_LoRA is not None:
        print("Loading content LoRA...")
        # NOTE(review): `scale` is handed to load_lora_weights as an extra
        # keyword -- confirm the installed diffusers release honours it there.
        pipeline.load_lora_weights(args.content_LoRA, scale=args.content_alpha)

        for category, category_prompts in benchmark_prompts["content"].items():
            print(f"Processing content {category} prompts...")
            pictures, captions, next_prompt_idx = process_prompts(
                pipeline,
                category_prompts,
                result_dir,
                args,
                f"content_{category}",
                "content",
                start_idx=next_prompt_idx,
            )
            _write_grid(pictures, captions, f'grid_content_{category}.png')

        # Drop the content LoRA so it does not leak into the style pass.
        pipeline.unload_lora_weights()

    # Style pass: only runs when a style LoRA was supplied.
    if args.style_LoRA is not None:
        print("Loading style LoRA...")
        pipeline.load_lora_weights(args.style_LoRA, scale=args.style_alpha)

        print("Processing style prompts...")
        pictures, captions, _ = process_prompts(
            pipeline,
            benchmark_prompts["style"],
            result_dir,
            args,
            "style",
            "style",
            start_idx=next_prompt_idx,
        )
        _write_grid(pictures, captions, 'grid_style.png')

    print(f"Results saved to {result_dir}")
|