File size: 8,056 Bytes
4d217d6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
import argparse
import json
from pathlib import Path
from datetime import datetime
import numpy as np
import torch
from PIL import Image, ImageDraw, ImageFont
from torchvision.utils import make_grid
from diffusers import StableDiffusionXLPipeline, AutoencoderKL

# Ensure pytorch-msssim is importable; install it on the fly if missing.
# (ssim / ms_ssim are imported for use by downstream evaluation code.)
try:
    from pytorch_msssim import ssim, ms_ssim
except ImportError:
    print("Installing pytorch-msssim...")
    import subprocess
    import sys

    # Invoke pip through the *current* interpreter so the package is
    # installed into the active environment — a bare "pip" executable may
    # belong to a different Python installation.
    subprocess.check_call([sys.executable, "-m", "pip", "install", "pytorch-msssim"])
    from pytorch_msssim import ssim, ms_ssim


def add_caption_to_image(image, caption, font_size=20):
    """Render *caption* centered below *image* and return a float tensor.

    Args:
        image: a PIL Image, or a CHW float tensor with values in [0, 1].
        caption: text drawn beneath the image.
        font_size: pixel height of the caption font.

    Returns:
        A CHW float tensor in [0, 1] containing the image plus caption strip.
    """
    # Convert a CHW [0, 1] tensor into a PIL Image if needed.
    if isinstance(image, torch.Tensor):
        image = (image * 255).clamp(0, 255).to(torch.uint8)
        image = image.permute(1, 2, 0).cpu().numpy()
        image = Image.fromarray(image)

    # Create a new canvas with extra white space below for the caption.
    margin = 10
    width = image.width
    height = image.height + font_size + 2*margin
    new_image = Image.new('RGB', (width, height), 'white')
    new_image.paste(image, (0, 0))

    # Draw the caption text.
    draw = ImageDraw.Draw(new_image)
    try:
        font = ImageFont.truetype("/usr/share/fonts/truetype/dejavu/DejaVuSans.ttf", font_size)
    except (OSError, ImportError):
        # Narrowed from a bare `except:` (which also swallowed
        # KeyboardInterrupt/SystemExit). truetype() raises OSError when the
        # font file is missing and ImportError when FreeType support is
        # unavailable; fall back to PIL's built-in bitmap font either way.
        font = ImageFont.load_default()

    # Center the text horizontally within the caption strip.
    text_width = draw.textlength(caption, font=font)
    x = (width - text_width) // 2
    y = height - font_size - margin

    draw.text((x, y), caption, fill='black', font=font)

    # Convert back to a CHW float tensor in [0, 1].
    new_image = torch.from_numpy(np.array(new_image)).permute(2, 0, 1).float() / 255.0
    return new_image


def create_image_grid(images, prompts, images_per_prompt, font_size=20):
    """Assemble captioned images into one grid tensor.

    Each image is labeled "<prompt> (i/images_per_prompt)"; each grid row
    holds the `images_per_prompt` images generated for a single prompt.
    """
    captioned = []
    for idx, pil_img in enumerate(images):
        # Map the flat image index back to its prompt and position in the row.
        label = f"{prompts[idx // images_per_prompt]} ({idx % images_per_prompt + 1}/{images_per_prompt})"
        as_tensor = torch.from_numpy(np.array(pil_img)).permute(2, 0, 1).float() / 255.0
        captioned.append(add_caption_to_image(as_tensor, label, font_size))

    # Stack into a batch and tile: one row per prompt, 10px padding between cells.
    return make_grid(torch.stack(captioned), nrow=images_per_prompt, padding=10)


def parse_args():
    """Parse and return the command-line arguments for this script."""
    arg_parser = argparse.ArgumentParser()
    add = arg_parser.add_argument
    add("--output_path", type=str, required=True, help="path to save the images")
    add("--content_LoRA", type=str, default=None, help="path for the content LoRA")
    add("--content_alpha", type=float, default=1.0, help="scale factor for content LoRA weights")
    add("--style_LoRA", type=str, default=None, help="path for the style LoRA")
    add("--style_alpha", type=float, default=1.0, help="scale factor for style LoRA weights")
    add("--num_images_per_prompt", type=int, default=4, help="number of images per prompt")
    add("--evaluation_prompt_file", type=str, required=True, help="path to evaluation prompts file")
    add("--placeholder_style", type=str, required=True, help="placeholder for the style prompt")
    add("--placeholder_content", type=str, required=True, help="placeholder for the content prompt")
    add("--name_concept", type=str, required=True, help="name of the concept being evaluated")
    add("--font_size", type=int, default=20, help="font size for image captions")
    return arg_parser.parse_args()


def process_prompts(pipeline, prompts, output_dir, args, prompt_type, lora_type, start_idx=0):
    """Generate images for each prompt template and save configs + outputs.

    Args:
        pipeline: generation pipeline, called as
            ``pipeline(prompt, num_images_per_prompt=...)`` returning an
            object with an ``.images`` list.
        prompts: prompt templates containing a ``{}`` placeholder.
        output_dir: Path under which per-prompt configs and images are written.
        args: parsed CLI namespace (placeholders, LoRA paths/alphas, counts).
        prompt_type: label embedded in the per-prompt output directory names.
        lora_type: "style" or "content"; selects the placeholder substituted
            into prompts and which LoRA settings are recorded in the config.
        start_idx: first prompt index, so numbering continues across calls.

    Returns:
        (all generated images, formatted prompt strings, next free index).
    """
    # Resolve the placeholder once; previously the formatted prompts were
    # recomputed with duplicated replace() logic in the return statement.
    placeholder = args.placeholder_style if lora_type == "style" else args.placeholder_content

    all_images = []
    formatted_prompts = []
    current_idx = start_idx

    for prompt in prompts:
        formatted_prompt = prompt.replace("{}", placeholder)
        formatted_prompts.append(formatted_prompt)

        # Record the generation parameters next to the outputs for provenance.
        config = {
            "gen_prompt": formatted_prompt,
            "content_LoRA": args.content_LoRA if lora_type == "content" else None,
            "content_alpha": args.content_alpha if lora_type == "content" else None,
            "style_LoRA": args.style_LoRA if lora_type == "style" else None,
            "style_alpha": args.style_alpha if lora_type == "style" else None
        }

        # Save config with consecutive numbering
        config_path = output_dir / f'prompt_{current_idx}_params.json'
        with open(config_path, 'w') as f:
            json.dump(config, f, indent=4)

        # Generate images
        images = pipeline(formatted_prompt, num_images_per_prompt=args.num_images_per_prompt).images
        all_images.extend(images)

        # Save individual images with consecutive numbering
        prompt_dir = output_dir / 'output' / 'ours' / f'prompt_{current_idx}_{prompt_type}'
        prompt_dir.mkdir(parents=True, exist_ok=True)

        for img_idx, img in enumerate(images):
            img.save(prompt_dir / f'{img_idx:03d}.jpg')

        current_idx += 1

    return all_images, formatted_prompts, current_idx


if __name__ == '__main__':
    # Entry point: generate evaluation images with SDXL + optional
    # content/style LoRAs, save per-prompt outputs, and build caption grids.
    args = parse_args()
    
    # Create timestamped output directory so repeated runs never collide.
    timestamp = datetime.now().strftime("%Y%m%d%H%M%S")
    result_dir = Path(args.output_path) / f'{args.name_concept}_{timestamp}'
    result_dir.mkdir(parents=True, exist_ok=True)

    # Load benchmark prompts. Expected JSON layout (from usage below):
    # {"content": {<category>: [prompt, ...], ...}, "style": [prompt, ...]}
    with open(args.evaluation_prompt_file, 'r') as f:
        benchmark_prompts = json.load(f)

    # Initialize pipeline with the fp16-safe SDXL VAE (avoids NaN artifacts
    # of the stock VAE under float16).
    vae = AutoencoderKL.from_pretrained("madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16)
    pipeline = StableDiffusionXLPipeline.from_pretrained(
        "stabilityai/stable-diffusion-xl-base-1.0",
        vae=vae,
        torch_dtype=torch.float16
    ).to("cuda")

    # Prompt numbering is shared across content and style phases.
    current_prompt_idx = 0

    # Process content prompts if content LoRA is provided
    if args.content_LoRA is not None:
        print("Loading content LoRA...")
        # NOTE(review): diffusers' load_lora_weights may not accept a `scale`
        # kwarg; scale is typically applied at inference time
        # (cross_attention_kwargs={"scale": ...}) or via fuse_lora — confirm
        # against the installed diffusers version.
        pipeline.load_lora_weights(args.content_LoRA, scale=args.content_alpha)

        # One grid per content category (e.g. different prompt themes).
        for category, prompts in benchmark_prompts["content"].items():
            print(f"Processing content {category} prompts...")
            images, formatted_prompts, current_prompt_idx = process_prompts(
                pipeline, prompts, result_dir, args, f"content_{category}", "content",
                start_idx=current_prompt_idx
            )
            
            # CHW float grid tensor -> HWC uint8 -> PNG on disk.
            grid = create_image_grid(images, formatted_prompts, args.num_images_per_prompt, args.font_size)
            grid_image = Image.fromarray((grid.permute(1, 2, 0).numpy() * 255).astype(np.uint8))
            grid_path = result_dir / f'grid_content_{category}.png'
            grid_image.save(grid_path)

        # Unload content LoRA so the style LoRA (if any) starts clean.
        pipeline.unload_lora_weights()

    # Process style prompts if style LoRA is provided
    if args.style_LoRA is not None:
        print("Loading style LoRA...")
        # NOTE(review): same `scale` kwarg concern as for the content LoRA.
        pipeline.load_lora_weights(args.style_LoRA, scale=args.style_alpha)

        print("Processing style prompts...")
        images, formatted_prompts, _ = process_prompts(
            pipeline, benchmark_prompts["style"], result_dir, args, "style", "style",
            start_idx=current_prompt_idx
        )
        
        # Single grid for all style prompts.
        grid = create_image_grid(images, formatted_prompts, args.num_images_per_prompt, args.font_size)
        grid_image = Image.fromarray((grid.permute(1, 2, 0).numpy() * 255).astype(np.uint8))
        grid_path = result_dir / 'grid_style.png'
        grid_image.save(grid_path)

    print(f"Results saved to {result_dir}")