"""Generate evaluation images for content/style LoRAs with SDXL.

For each prompt in the evaluation file the script writes the generation
parameters as JSON, saves the individual images, and finally saves one
captioned grid image per prompt group.
"""

import argparse
import json
import subprocess
import sys
from datetime import datetime
from pathlib import Path

import numpy as np
import torch
from PIL import Image, ImageDraw, ImageFont
from torchvision.utils import make_grid
from diffusers import StableDiffusionXLPipeline, AutoencoderKL

try:
    from pytorch_msssim import ssim, ms_ssim
except ImportError:
    print("Installing pytorch-msssim...")
    # Install with the interpreter that is running this script; a bare `pip`
    # on PATH may belong to a different Python environment.
    subprocess.check_call([sys.executable, "-m", "pip", "install", "pytorch-msssim"])
    from pytorch_msssim import ssim, ms_ssim


def _tensor_to_pil(tensor):
    """Convert a CHW float tensor with values in [0, 1] to a PIL image."""
    # Round instead of truncating so floating-point error in the 0-255
    # scaling cannot shift a pixel down by one level.
    pixels = (tensor * 255).round().clamp(0, 255).to(torch.uint8)
    return Image.fromarray(pixels.permute(1, 2, 0).cpu().numpy())


def add_caption_to_image(image, caption, font_size=20):
    """Add a caption strip below an image and return the result as a tensor.

    Args:
        image: A PIL image, or a CHW float tensor with values in [0, 1].
        caption: Text drawn centered in a white strip under the image.
        font_size: Font size; also determines the height of the strip.

    Returns:
        A CHW float tensor in [0, 1] of the image with the caption appended.
    """
    # Convert tensor to PIL Image if needed
    if isinstance(image, torch.Tensor):
        image = _tensor_to_pil(image)

    # Create new image with space for caption
    margin = 10
    width = image.width
    height = image.height + font_size + 2 * margin
    new_image = Image.new('RGB', (width, height), 'white')
    new_image.paste(image, (0, 0))

    # Add caption
    draw = ImageDraw.Draw(new_image)
    try:
        font = ImageFont.truetype(
            "/usr/share/fonts/truetype/dejavu/DejaVuSans.ttf", font_size
        )
    except (OSError, ImportError):
        # Font file missing/unreadable or FreeType unavailable: fall back to
        # Pillow's built-in font rather than failing the whole run.
        font = ImageFont.load_default()

    # Center the text
    text_width = draw.textlength(caption, font=font)
    x = (width - text_width) // 2
    y = height - font_size - margin
    draw.text((x, y), caption, fill='black', font=font)

    # Convert back to tensor
    return torch.from_numpy(np.array(new_image)).permute(2, 0, 1).float() / 255.0


def create_image_grid(images, prompts, images_per_prompt, font_size=20):
    """Create a grid of captioned images, one row per prompt.

    Args:
        images: Flat sequence of images (PIL images or HWC arrays), ordered so
            that each consecutive run of ``images_per_prompt`` images belongs
            to the same prompt.
        prompts: Prompt strings; ``prompts[i // images_per_prompt]`` captions
            ``images[i]``.
        images_per_prompt: Number of images per prompt (grid row length).
        font_size: Font size for the captions.

    Returns:
        The grid as a CHW float tensor produced by ``make_grid``.

    Raises:
        ValueError: If ``images`` is empty.
    """
    if not images:
        raise ValueError("create_image_grid requires at least one image")

    captioned_images = []
    for i, img in enumerate(images):
        prompt_idx, offset = divmod(i, images_per_prompt)
        caption = f"{prompts[prompt_idx]} ({offset + 1}/{images_per_prompt})"
        if not isinstance(img, Image.Image):
            # Accept HWC array-likes as well as PIL images.
            img = Image.fromarray(np.asarray(img, dtype=np.uint8))
        # Pass the PIL image straight through; captioning handles PIL input,
        # so no intermediate float-tensor round trip is needed.
        captioned_images.append(add_caption_to_image(img, caption, font_size))

    # Convert to tensor and create grid
    image_tensor = torch.stack(captioned_images)
    return make_grid(image_tensor, nrow=images_per_prompt, padding=10)


def parse_args():
    """Parse the command-line arguments of the evaluation script."""
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--output_path", type=str, required=True,
        help="path to save the images",
    )
    parser.add_argument(
        "--content_LoRA", type=str, default=None,
        help="path for the content LoRA",
    )
    parser.add_argument(
        "--content_alpha", type=float, default=1.0,
        help="scale factor for content LoRA weights",
    )
    parser.add_argument(
        "--style_LoRA", type=str, default=None,
        help="path for the style LoRA",
    )
    parser.add_argument(
        "--style_alpha", type=float, default=1.0,
        help="scale factor for style LoRA weights",
    )
    parser.add_argument(
        "--num_images_per_prompt", type=int, default=4,
        help="number of images per prompt",
    )
    parser.add_argument(
        "--evaluation_prompt_file", type=str, required=True,
        help="path to evaluation prompts file",
    )
    parser.add_argument(
        "--placeholder_style", type=str, required=True,
        help="placeholder for the style prompt",
    )
    parser.add_argument(
        "--placeholder_content", type=str, required=True,
        help="placeholder for the content prompt",
    )
    parser.add_argument(
        "--name_concept", type=str, required=True,
        help="name of the concept being evaluated",
    )
    parser.add_argument(
        "--font_size", type=int, default=20,
        help="font size for image captions",
    )
    return parser.parse_args()


def process_prompts(pipeline, prompts, output_dir, args, prompt_type, lora_type, start_idx=0):
    """Generate, save and collect images for a list of prompt templates.

    Args:
        pipeline: Callable diffusion pipeline whose result exposes ``.images``.
        prompts: Prompt templates; every ``{}`` is replaced by the style or
            content placeholder depending on ``lora_type``.
        output_dir: ``Path`` under which configs and images are written.
        args: Parsed command-line arguments (see ``parse_args``).
        prompt_type: Suffix used in the per-prompt image directory name.
        lora_type: ``"style"`` or ``"content"``; selects placeholder and alpha.
        start_idx: Index of the first prompt, for consecutive numbering
            across several calls.

    Returns:
        Tuple ``(all_images, formatted_prompts, next_idx)``.
    """
    is_style = lora_type == "style"
    placeholder = args.placeholder_style if is_style else args.placeholder_content
    lora_scale = args.style_alpha if is_style else args.content_alpha

    all_images = []
    formatted_prompts = []
    current_idx = start_idx

    for prompt in prompts:
        formatted_prompt = prompt.replace("{}", placeholder)
        formatted_prompts.append(formatted_prompt)

        config = {
            "gen_prompt": formatted_prompt,
            "content_LoRA": args.content_LoRA if lora_type == "content" else None,
            "content_alpha": args.content_alpha if lora_type == "content" else None,
            "style_LoRA": args.style_LoRA if is_style else None,
            "style_alpha": args.style_alpha if is_style else None,
        }

        # Save config with consecutive numbering
        config_path = output_dir / f'prompt_{current_idx}_params.json'
        with open(config_path, 'w') as f:
            json.dump(config, f, indent=4)

        # Generate images. The LoRA strength has to be supplied at inference
        # time; passing it to load_lora_weights() has no effect.
        images = pipeline(
            formatted_prompt,
            num_images_per_prompt=args.num_images_per_prompt,
            cross_attention_kwargs={"scale": lora_scale},
        ).images
        all_images.extend(images)

        # Save individual images with consecutive numbering
        prompt_dir = output_dir / 'output' / 'ours' / f'prompt_{current_idx}_{prompt_type}'
        prompt_dir.mkdir(parents=True, exist_ok=True)
        for img_idx, img in enumerate(images):
            img.save(prompt_dir / f'{img_idx:03d}.jpg')

        current_idx += 1

    return all_images, formatted_prompts, current_idx


def _save_grid(images, formatted_prompts, args, grid_path):
    """Build the captioned grid for one prompt group and write it to disk."""
    if not images:
        # An empty prompt list yields no images; skip instead of crashing.
        print(f"No images generated, skipping {grid_path.name}")
        return
    grid = create_image_grid(
        images, formatted_prompts, args.num_images_per_prompt, args.font_size
    )
    _tensor_to_pil(grid).save(grid_path)


def main():
    """Run content and/or style LoRA evaluation as requested on the CLI."""
    args = parse_args()

    # Create timestamped output directory
    timestamp = datetime.now().strftime("%Y%m%d%H%M%S")
    result_dir = Path(args.output_path) / f'{args.name_concept}_{timestamp}'
    result_dir.mkdir(parents=True, exist_ok=True)

    # Load benchmark prompts
    with open(args.evaluation_prompt_file, 'r', encoding='utf-8') as f:
        benchmark_prompts = json.load(f)

    # Initialize pipeline
    vae = AutoencoderKL.from_pretrained(
        "madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16
    )
    pipeline = StableDiffusionXLPipeline.from_pretrained(
        "stabilityai/stable-diffusion-xl-base-1.0",
        vae=vae,
        torch_dtype=torch.float16,
    ).to("cuda")

    current_prompt_idx = 0

    # Process content prompts if content LoRA is provided
    if args.content_LoRA is not None:
        print("Loading content LoRA...")
        pipeline.load_lora_weights(args.content_LoRA)

        for category, prompts in benchmark_prompts["content"].items():
            print(f"Processing content {category} prompts...")
            images, formatted_prompts, current_prompt_idx = process_prompts(
                pipeline, prompts, result_dir, args,
                f"content_{category}", "content",
                start_idx=current_prompt_idx,
            )
            _save_grid(
                images, formatted_prompts, args,
                result_dir / f'grid_content_{category}.png',
            )

        # Unload content LoRA so it does not affect the style run
        pipeline.unload_lora_weights()

    # Process style prompts if style LoRA is provided
    if args.style_LoRA is not None:
        print("Loading style LoRA...")
        pipeline.load_lora_weights(args.style_LoRA)

        print("Processing style prompts...")
        images, formatted_prompts, _ = process_prompts(
            pipeline, benchmark_prompts["style"], result_dir, args,
            "style", "style",
            start_idx=current_prompt_idx,
        )
        _save_grid(images, formatted_prompts, args, result_dir / 'grid_style.png')

    print(f"Results saved to {result_dir}")


if __name__ == '__main__':
    main()