trislee02 commited on
Commit
4d217d6
·
verified ·
1 Parent(s): c4d7110

Create train_lora.py

Browse files
Files changed (1) hide show
  1. train_lora.py +207 -0
train_lora.py ADDED
@@ -0,0 +1,207 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import json
3
+ from pathlib import Path
4
+ from datetime import datetime
5
+ import numpy as np
6
+ import torch
7
+ from PIL import Image, ImageDraw, ImageFont
8
+ from torchvision.utils import make_grid
9
+ from diffusers import StableDiffusionXLPipeline, AutoencoderKL
10
+
11
# Soft dependency: install pytorch-msssim on the fly if it is missing so the
# script runs in fresh environments (e.g. notebooks / Spaces).
try:
    from pytorch_msssim import ssim, ms_ssim
except ImportError:
    print("Installing pytorch-msssim...")
    import subprocess
    import sys
    # Use the current interpreter's pip (`python -m pip`) so the package lands
    # in the active environment rather than whichever `pip` is first on PATH.
    subprocess.check_call([sys.executable, "-m", "pip", "install", "pytorch-msssim"])
    from pytorch_msssim import ssim, ms_ssim
18
+
19
+
20
def add_caption_to_image(image, caption, font_size=20):
    """Render *caption* in a white margin below *image* and return a tensor.

    Args:
        image: A PIL ``Image`` or a CHW float tensor in [0, 1] (tensor input
            is scaled by 255, clamped, and converted to a PIL image first).
        caption: Text drawn horizontally centered in the added bottom margin.
        font_size: Pixel height of the caption font.

    Returns:
        A CHW ``torch.FloatTensor`` in [0, 1] with shape
        ``(3, H + font_size + 2*margin, W)`` where margin is 10 px.
    """
    # Accept tensors as well as PIL images: CHW float -> HWC uint8 -> PIL.
    if isinstance(image, torch.Tensor):
        image = (image * 255).clamp(0, 255).to(torch.uint8)
        image = image.permute(1, 2, 0).cpu().numpy()
        image = Image.fromarray(image)

    # Extend the canvas downwards to make room for one line of text.
    margin = 10
    width = image.width
    height = image.height + font_size + 2 * margin
    new_image = Image.new('RGB', (width, height), 'white')
    new_image.paste(image, (0, 0))

    # Add caption. Prefer a scalable TrueType font; fall back to PIL's
    # built-in bitmap font when DejaVu is not installed (the fallback
    # ignores font_size). Only catch OSError — the error truetype() raises
    # for a missing/unreadable font file — instead of a bare except that
    # would also swallow KeyboardInterrupt and real bugs.
    draw = ImageDraw.Draw(new_image)
    try:
        font = ImageFont.truetype("/usr/share/fonts/truetype/dejavu/DejaVuSans.ttf", font_size)
    except OSError:
        font = ImageFont.load_default()

    # Center the text horizontally within the bottom margin.
    text_width = draw.textlength(caption, font=font)
    x = (width - text_width) // 2
    y = height - font_size - margin

    draw.text((x, y), caption, fill='black', font=font)

    # Back to a CHW float tensor in [0, 1].
    new_image = torch.from_numpy(np.array(new_image)).permute(2, 0, 1).float() / 255.0
    return new_image
52
+
53
+
54
def create_image_grid(images, prompts, images_per_prompt, font_size=20):
    """Arrange captioned images into one grid tensor.

    Each image gets a caption "<prompt> (<k>/<images_per_prompt>)", and the
    grid lays out ``images_per_prompt`` images per row with 10 px padding.
    """
    captioned = []
    for idx, pil_img in enumerate(images):
        # Images arrive grouped by prompt: idx maps to (prompt, sample) pair.
        which_prompt, which_sample = divmod(idx, images_per_prompt)
        label = f"{prompts[which_prompt]} ({which_sample + 1}/{images_per_prompt})"
        as_tensor = torch.from_numpy(np.array(pil_img)).permute(2, 0, 1).float() / 255.0
        captioned.append(add_caption_to_image(as_tensor, label, font_size))

    # Stack into (N, C, H, W) and tile with images_per_prompt columns.
    return make_grid(torch.stack(captioned), nrow=images_per_prompt, padding=10)
71
+
72
+
73
def parse_args():
    """Build and parse the command-line arguments for this script."""
    parser = argparse.ArgumentParser()

    # (flag, type, required, default, help) — registered in a single pass.
    specs = [
        ("--output_path", str, True, None, "path to save the images"),
        ("--content_LoRA", str, False, None, "path for the content LoRA"),
        ("--content_alpha", float, False, 1.0, "scale factor for content LoRA weights"),
        ("--style_LoRA", str, False, None, "path for the style LoRA"),
        ("--style_alpha", float, False, 1.0, "scale factor for style LoRA weights"),
        ("--num_images_per_prompt", int, False, 4, "number of images per prompt"),
        ("--evaluation_prompt_file", str, True, None, "path to evaluation prompts file"),
        ("--placeholder_style", str, True, None, "placeholder for the style prompt"),
        ("--placeholder_content", str, True, None, "placeholder for the content prompt"),
        ("--name_concept", str, True, None, "name of the concept being evaluated"),
        ("--font_size", int, False, 20, "font size for image captions"),
    ]
    for flag, flag_type, required, default, help_text in specs:
        if required:
            parser.add_argument(flag, type=flag_type, required=True, help=help_text)
        else:
            parser.add_argument(flag, type=flag_type, default=default, help=help_text)

    return parser.parse_args()
109
+
110
+
111
def process_prompts(pipeline, prompts, output_dir, args, prompt_type, lora_type, start_idx=0):
    """Generate images for each prompt, saving per-prompt configs and files.

    For every prompt the "{}" placeholder is filled with the style or content
    token (depending on *lora_type*), a params JSON is written, a batch of
    images is generated, and each image is saved under
    ``output/ours/prompt_<idx>_<prompt_type>/``.

    Returns:
        (all generated images, the formatted prompts, the next free index).
    """
    # Pick the substitution token once; it depends only on lora_type.
    placeholder = args.placeholder_style if lora_type == "style" else args.placeholder_content

    generated = []
    idx = start_idx

    for prompt in prompts:
        filled = prompt.replace("{}", placeholder)

        # Record exactly which prompt/LoRA combination produced this batch.
        run_config = {
            "gen_prompt": filled,
            "content_LoRA": args.content_LoRA if lora_type == "content" else None,
            "content_alpha": args.content_alpha if lora_type == "content" else None,
            "style_LoRA": args.style_LoRA if lora_type == "style" else None,
            "style_alpha": args.style_alpha if lora_type == "style" else None
        }
        with open(output_dir / f'prompt_{idx}_params.json', 'w') as f:
            json.dump(run_config, f, indent=4)

        # Run the diffusion pipeline and collect its images.
        batch = pipeline(filled, num_images_per_prompt=args.num_images_per_prompt).images
        generated.extend(batch)

        # One directory per prompt; images numbered 000.jpg, 001.jpg, ...
        save_dir = output_dir / 'output' / 'ours' / f'prompt_{idx}_{prompt_type}'
        save_dir.mkdir(parents=True, exist_ok=True)
        for image_number, image in enumerate(batch):
            image.save(save_dir / f'{image_number:03d}.jpg')

        idx += 1

    formatted = [p.replace("{}", placeholder) for p in prompts]
    return generated, formatted, idx
147
+
148
+
149
if __name__ == '__main__':
    args = parse_args()

    # Timestamped output directory so repeated runs never collide.
    timestamp = datetime.now().strftime("%Y%m%d%H%M%S")
    result_dir = Path(args.output_path) / f'{args.name_concept}_{timestamp}'
    result_dir.mkdir(parents=True, exist_ok=True)

    # Load benchmark prompts. Presumably the JSON has the shape
    # {"content": {category: [prompts...]}, "style": [prompts...]} — that is
    # how it is consumed below; verify against the prompt file.
    with open(args.evaluation_prompt_file, 'r') as f:
        benchmark_prompts = json.load(f)

    # SDXL base pipeline with the fp16-safe VAE, in half precision on CUDA.
    vae = AutoencoderKL.from_pretrained("madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16)
    pipeline = StableDiffusionXLPipeline.from_pretrained(
        "stabilityai/stable-diffusion-xl-base-1.0",
        vae=vae,
        torch_dtype=torch.float16
    ).to("cuda")

    def save_grid(images, prompts, grid_name):
        """Build a captioned grid from `images` and save it in result_dir.

        Extracted because the content and style branches previously
        duplicated this tensor->PIL->save sequence verbatim.
        """
        grid = create_image_grid(images, prompts, args.num_images_per_prompt, args.font_size)
        grid_image = Image.fromarray((grid.permute(1, 2, 0).numpy() * 255).astype(np.uint8))
        grid_image.save(result_dir / grid_name)

    current_prompt_idx = 0

    # Process content prompts if a content LoRA is provided.
    if args.content_LoRA is not None:
        print("Loading content LoRA...")
        # NOTE(review): `scale=` is forwarded to load_lora_weights here —
        # confirm the installed diffusers version applies it at load time;
        # newer releases expect the LoRA scale at inference instead.
        pipeline.load_lora_weights(args.content_LoRA, scale=args.content_alpha)

        for category, prompts in benchmark_prompts["content"].items():
            print(f"Processing content {category} prompts...")
            images, formatted_prompts, current_prompt_idx = process_prompts(
                pipeline, prompts, result_dir, args, f"content_{category}", "content",
                start_idx=current_prompt_idx
            )
            save_grid(images, formatted_prompts, f'grid_content_{category}.png')

        # Unload so the style LoRA (if any) starts from a clean pipeline.
        pipeline.unload_lora_weights()

    # Process style prompts if a style LoRA is provided.
    if args.style_LoRA is not None:
        print("Loading style LoRA...")
        pipeline.load_lora_weights(args.style_LoRA, scale=args.style_alpha)

        print("Processing style prompts...")
        images, formatted_prompts, _ = process_prompts(
            pipeline, benchmark_prompts["style"], result_dir, args, "style", "style",
            start_idx=current_prompt_idx
        )
        save_grid(images, formatted_prompts, 'grid_style.png')

    print(f"Results saved to {result_dir}")