lilyecho committed
Commit 2b3b63a · verified · 1 Parent(s): 674819b

Upload model_zoo/mit_customize_img_ids_bs_32_rank_512_usedataset_controlnetuse_original_size_resolution_1024_customize_img_ids_customize_txt_ids/program.py with huggingface_hub

model_zoo/mit_customize_img_ids_bs_32_rank_512_usedataset_controlnetuse_original_size_resolution_1024_customize_img_ids_customize_txt_ids/program.py ADDED
@@ -0,0 +1,707 @@
import torch
import torch.nn as nn
import torch.optim as optim
import copy
from torch import Tensor
from omegaconf import OmegaConf
import wandb
import matplotlib.pyplot as plt
import numpy as np

from src.flux.modules.layers import LoRALinearLayer, LastLayer
from src.flux.train_utils import *
from src.flux.util import load_ae, load_clip, load_flow_model2, load_t5

import datetime
import logging
import os
import torch.distributed as dist
from src.flux.fsdp_utils import setup_model, build_dataloader, save_model_checkpoint, save_optimizer_checkpoint
from tqdm import tqdm
from image_datasets.combined_dataset_ar_prepared import MultiHumanDataset
from src.flux.sampling import denoise, get_noise, get_schedule, prepare, prepare_dual, prepare_dual_train, prepare_dual_train_ar
import time
import contextlib
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from einops import rearrange
import random
import torch.nn.functional as F
import json
from pathlib import Path
from src.flux.xflux_pipeline import XFluxSampler
from PIL import Image, ImageDraw, ImageFont
import html


################################ Split head for img_in and final_layer ################################
class ImgInSplit(nn.Module):  # must be attached after the pre-trained model is loaded
    def __init__(self, old_img_in: nn.Linear, keep_ori_weights: bool = False, zero_init: bool = False, img_seq_len: int = 1024):
        super().__init__()
        assert not (keep_ori_weights and zero_init), "keep_ori_weights and zero_init cannot both be True"
        self.old_img_in = old_img_in

        self.pose_in = copy.deepcopy(old_img_in)
        if not keep_ori_weights:
            if zero_init:
                nn.init.zeros_(self.pose_in.weight)
                nn.init.zeros_(self.pose_in.bias)
            else:
                nn.init.normal_(self.pose_in.weight, mean=0.0, std=0.02)
                nn.init.zeros_(self.pose_in.bias)

        self.img_seq_len = img_seq_len

    def forward(self, x: Tensor) -> Tensor:
        assert x.dim() == 3, "x should be in shape (B, L1+L2, D)"
        B, L, D = x.shape
        pose_len = L - self.img_seq_len

        x_pose = x[:, :pose_len, :]
        x_img = x[:, pose_len:, :]

        x_img = self.old_img_in(x_img)
        x_pose = self.pose_in(x_pose)

        return torch.cat([x_pose, x_img], dim=1)

    def forward_pose_only(self, x: Tensor) -> Tensor:
        assert x.dim() == 3, "x should be in shape (B, L, D)"

        return self.pose_in(x)
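

# A minimal, hedged usage sketch of ImgInSplit (added note, not part of the
# original script). Assumption: the wrapped img_in is a Flux-style
# nn.Linear(64, 3072); the pose tokens are whatever precedes the last
# `img_seq_len` tokens in the sequence.
def _demo_img_in_split():
    old_img_in = nn.Linear(64, 3072)
    split = ImgInSplit(old_img_in, zero_init=True, img_seq_len=1024)
    x = torch.randn(2, 256 + 1024, 64)  # 256 pose tokens followed by 1024 image tokens
    out = split(x)  # pose tokens go through the (zero-initialised) pose_in head
    assert out.shape == (2, 1280, 3072)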


class LastLayerSplitTwoMod(nn.Module):  # two-vec-condition version
    """
    Same math as the original LastLayer, but with
      • two independent output heads (old_layer.linear, linear_pose_img)
      • two independent AdaLN modulators (old_layer.adaLN_modulation, adaLN_modulation_pose)

    Args
    ----
    old_layer : a *loaded* LastLayer whose weights you want to duplicate.
    """

    def __init__(self, old_layer: "LastLayer", keep_ori_weights: bool = False, zero_init: bool = False, img_seq_len: int = 1024):
        super().__init__()
        self.old_layer = old_layer

        # duplicate the AdaLN MLP for the pose branch
        self.adaLN_modulation_pose = copy.deepcopy(old_layer.adaLN_modulation)
        if not keep_ori_weights:
            if zero_init:
                nn.init.zeros_(self.adaLN_modulation_pose[1].weight)
                nn.init.zeros_(self.adaLN_modulation_pose[1].bias)
            else:
                nn.init.normal_(self.adaLN_modulation_pose[1].weight, mean=0.0, std=0.02)
                nn.init.zeros_(self.adaLN_modulation_pose[1].bias)

        # duplicate the output head for the pose branch
        self.linear_pose_img = copy.deepcopy(old_layer.linear)
        if not keep_ori_weights:
            if zero_init:
                nn.init.zeros_(self.linear_pose_img.weight)
                nn.init.zeros_(self.linear_pose_img.bias)
            else:
                nn.init.normal_(self.linear_pose_img.weight, mean=0.0, std=0.02)
                nn.init.zeros_(self.linear_pose_img.bias)
        self.img_seq_len = img_seq_len

    # ---------------------------------------------------------------------
    def forward(self, x: Tensor, vec1: Tensor, vec2: Tensor) -> Tensor:
        """
        x    : (B, L1+L2, hidden_size) – pose tokens followed by image tokens
        vec1 : (B, hidden_size) – conditioning for the image tokens
        vec2 : (B, hidden_size) – conditioning for the pose tokens
        """
        assert x.dim() == 3, "x should be in shape (B, L1+L2, D)"
        B, L, D = x.shape
        pose_len = L - self.img_seq_len

        x_pose = x[:, :pose_len, :]  # contains cond_pose and gen_pose
        x_img = x[:, pose_len:, :]

        # branch 1: image tokens through the original head
        shift, scale = self.old_layer.adaLN_modulation(vec1).chunk(2, dim=1)
        x_img = (1 + scale[:, None, :]) * self.old_layer.norm_final(x_img) + shift[:, None, :]
        x_img = self.old_layer.linear(x_img)

        # branch 2: pose tokens through the duplicated head
        shift_pose, scale_pose = self.adaLN_modulation_pose(vec2).chunk(2, dim=1)
        x_pose = (1 + scale_pose[:, None, :]) * self.old_layer.norm_final(x_pose) + shift_pose[:, None, :]  # ERROR!
        x_pose = self.linear_pose_img(x_pose)

        # print("shape of [x_pose, x_img]", x_pose.shape, x_img.shape)

        return torch.cat([x_pose, x_img], dim=1)

    def forward_pose_only(self, x: Tensor, vec2: Tensor) -> Tensor:
        """
        x    : (B, L, hidden_size) – pose tokens only
        vec2 : (B, hidden_size) – conditioning for the pose tokens
        """
        assert x.dim() == 3, "x should be in shape (B, L, D)"
        x_pose = x

        # branch 2 only
        shift_pose, scale_pose = self.adaLN_modulation_pose(vec2).chunk(2, dim=1)
        x_pose = (1 + scale_pose[:, None, :]) * self.old_layer.norm_final(x_pose) + shift_pose[:, None, :]  # ERROR!
        x_pose = self.linear_pose_img(x_pose)

        return x_pose
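

# A hedged shape-check sketch for LastLayerSplitTwoMod (added note, not original
# code). Assumption: LastLayer follows the Flux-style constructor
# LastLayer(hidden_size, patch_size, out_channels); 3072/1/64 are illustrative.
def _demo_last_layer_split():
    base = LastLayer(3072, 1, 64)
    split = LastLayerSplitTwoMod(base, zero_init=True, img_seq_len=1024)
    x = torch.randn(2, 256 + 1024, 3072)  # pose tokens first, then image tokens
    vec_img, vec_pose = torch.randn(2, 3072), torch.randn(2, 3072)
    out = split(x, vec_img, vec_pose)  # each half gets its own AdaLN + head
    assert out.shape == (2, 1280, 64)  # patch_size**2 * out_channels = 64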


def replace_split_head(dit, args):
    old_img_in = dit.img_in
    dit.img_in = ImgInSplit(old_img_in, keep_ori_weights=args.keep_ori_weights, zero_init=args.zero_init, img_seq_len=args.img_seq_len)

    old_final_layer = dit.final_layer
    dit.final_layer = LastLayerSplitTwoMod(old_final_layer, keep_ori_weights=args.keep_ori_weights, zero_init=args.zero_init, img_seq_len=args.img_seq_len)
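
# Hedged usage sketch (assumes a loaded Flux DiT exposing .img_in and .final_layer):
#   from types import SimpleNamespace
#   cfg = SimpleNamespace(keep_ori_weights=False, zero_init=True, img_seq_len=4096)
#   replace_split_head(dit, cfg)  # swaps both heads in place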


def reduce_loss(loss: torch.Tensor) -> float:
    """
    loss    : scalar tensor on *this* rank (already averaged over the local batch)
    returns : python float = mean(loss) over all ranks
    """
    with torch.no_grad():
        dist.all_reduce(loss, op=dist.ReduceOp.SUM)  # Σ over ranks
        loss /= dist.get_world_size()                # → average
    return loss.item()
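

# A hedged single-process sanity check for reduce_loss (added note, not original
# code): with world_size == 1 the all-reduce average must return the input
# unchanged. The gloo backend and the local TCP address are illustrative choices.
def _demo_reduce_loss():
    dist.init_process_group("gloo", init_method="tcp://127.0.0.1:29512", rank=0, world_size=1)
    assert abs(reduce_loss(torch.tensor(0.5)) - 0.5) < 1e-6
    dist.destroy_process_group()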


def draw_bboxes_on_image(
    image_size: tuple = (512, 512),
    background_color: str = 'black',
    bboxes: list[list[int]] = None,
    bbox_colors: list[str] = None,
    line_width: int = 3,
    title: str = "Bounding Boxes"
) -> Image.Image:
    if bboxes is None:
        bboxes = []
    if bbox_colors is None:
        bbox_colors = ["red", "green", "blue", "purple", "orange"]

    # Create the image with the specified background color
    img = Image.new('RGB', image_size, color=background_color)
    draw = ImageDraw.Draw(img)

    # Draw each bounding box on the image
    for i, bbox in enumerate(bboxes):
        x1, y1, x2, y2 = bbox
        color = bbox_colors[i % len(bbox_colors)]  # Cycle through colors
        draw.rectangle([x1, y1, x2, y2], outline=color, width=line_width)

    # Display the image
    plt.figure(figsize=(image_size[0] / 80, image_size[1] / 80))  # Adjust figsize dynamically
    plt.imshow(np.array(img))
    plt.title(title)
    plt.axis('off')
    plt.show()

    return img


def draw_bboxes_on_existing_image(
    image: Image.Image,
    bboxes: list[list[int]] = None,
    bbox_colors: list[str] = None,
    line_width: int = 3,
) -> Image.Image:
    """
    Draw bounding boxes on an existing PIL Image.
    """
    if bboxes is None:
        return image
    if bbox_colors is None:
        bbox_colors = ["red", "green", "blue", "purple", "orange"]

    # Create a copy to avoid modifying the original
    img_with_boxes = image.copy()
    draw = ImageDraw.Draw(img_with_boxes)

    # Draw each bounding box on the image
    for i, bbox in enumerate(bboxes):
        x1, y1, x2, y2 = bbox
        color = bbox_colors[i % len(bbox_colors)]  # Cycle through colors
        draw.rectangle([x1, y1, x2, y2], outline=color, width=line_width)

    return img_with_boxes
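

# A hedged usage sketch for the two bbox helpers (added note, not original
# code); the pixel coordinates below are arbitrary illustrative values.
def _demo_bbox_helpers():
    boxes = [[50, 60, 200, 400], [220, 60, 380, 400]]
    canvas = draw_bboxes_on_image(image_size=(512, 512), bboxes=boxes)  # boxes on a black canvas
    photo = Image.new('RGB', (512, 512), color='gray')
    annotated = draw_bboxes_on_existing_image(photo, boxes, line_width=2)  # boxes on a copy of `photo`
    assert canvas.size == annotated.size == (512, 512)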


# ------------------------------------------------------------------
# Utility: build / refresh an HTML gallery showing generated samples
# ------------------------------------------------------------------

def _refresh_html_gallery(base_save_dir: str, inference_dir: str, json_path: str, seeds: list[int], html_filename: str):
    """Regenerate an HTML gallery of all saved images.

    The directory layout is expected to be:
        base_save_dir / inference_dir / prompt_<idx> / variation_<var_idx> / seed_<seed>.jpg

    Args
    ----
    base_save_dir : root directory where images are stored ("save_dir")
    inference_dir : sub-directory containing the samples (args.inference_output_dir)
    json_path     : path to the prompt JSON to fetch text descriptions
    seeds         : list of seeds used (for consistent ordering)
    html_filename : full path to the output HTML file; overwritten on each call.
    """

    try:
        with open(json_path, "r") as f_json:
            prompt_data = json.load(f_json)
    except Exception as e:
        print(f"❌ Failed to load JSON for HTML refresh: {e}")
        return

    root_dir = os.path.join(base_save_dir, inference_dir)

    html_parts = [
        "<html>",
        "<head>",
        "<meta charset='utf-8' />",
        "<title>Inference Gallery</title>",
        "<style>\n",
        "body { font-family: Arial, sans-serif; }\n",
        "h2 { margin-top: 40px; border-bottom: 1px solid #ccc; padding-bottom: 4px;}\n",
        "h3 { margin-top: 20px; color: #555;}\n",
        ".img-row { display: flex; flex-wrap: wrap; gap: 8px; }\n",
        ".img-row img { max-width: 256px; height: auto; border: 1px solid #ddd;}\n",
        "</style>",
        "</head>",
        "<body>",
        f"<h1>Inference Gallery ({html.escape(os.path.basename(html_filename))})</h1>",
    ]

    for idx, item in enumerate(prompt_data):
        prompt_dir = os.path.join(root_dir, f"prompt_{idx}")
        if not os.path.isdir(prompt_dir):
            # No images yet for this prompt
            continue

        general_prompt = html.escape(item.get("general_prompt", ""))
        prompt_list_text = "<br/>".join(html.escape(t) for t in item.get("prompt_list", []))

        html_parts.append(f"<h2>Prompt {idx}: {general_prompt}</h2>")
        if prompt_list_text:
            html_parts.append(f"<p style='margin-left:20px;'>{prompt_list_text}</p>")

        num_variations = len(item.get("variations", []))
        for var_idx in range(num_variations):
            var_dir = os.path.join(prompt_dir, f"variation_{var_idx}")
            if not os.path.isdir(var_dir):
                continue  # variation not generated yet

            html_parts.append(f"<h3>Variation {var_idx}</h3>")
            html_parts.append("<div class='img-row'>")

            for seed in seeds:
                img_path_abs = os.path.join(var_dir, f"seed_{seed}.jpg")
                if os.path.exists(img_path_abs):
                    img_path_rel = os.path.relpath(img_path_abs, os.path.dirname(html_filename))
                    html_parts.append(f"<img src='{img_path_rel}' alt='prompt{idx}_var{var_idx}_seed{seed}' />")

            html_parts.append("</div>")

    html_parts.extend(["</body>", "</html>"])

    try:
        with open(html_filename, "w") as f_html:
            f_html.write("\n".join(html_parts))
    except Exception as e:
        print(f"❌ Failed to write HTML gallery: {e}")


def sample_steps_inference(dit, args, global_step, wandbrun, rank, offload=True, save_dir=None):
    """Run inference using prompts and bounding box variations defined in an external JSON file."""

    if not hasattr(args, "sample_prompts_json"):
        raise ValueError("`args.sample_prompts_json` must be provided when using JSON-based prompts.")

    # ------------------------------------------------------------------
    # Load prompt information from JSON
    # ------------------------------------------------------------------
    with open(args.sample_prompts_json, "r") as f_json:
        sample_prompts = json.load(f_json)  # List[dict]

    total_variations = sum(len(item.get("variations", [])) for item in sample_prompts)
    total_samples_to_generate = total_variations * len(args.sample_seeds)

    if rank == 0:
        print(
            f"🎯 Starting inference: {len(sample_prompts)} prompts × {len(args.sample_seeds)} seeds × "
            f"avg {total_variations/len(sample_prompts):.1f} variations ≈ {total_samples_to_generate} total samples"
        )

    sample_count = 0

    # Determine the HTML output file (named by current global_step)
    if args.use_v1_bbox:
        html_output_path = os.path.join(save_dir, f"inference_{global_step}_use_v1_bbox.html")
    else:
        html_output_path = os.path.join(save_dir, f"inference_{global_step}_normal_bbox.html")

    for prompt_idx, prompt_dict in enumerate(sample_prompts):
        # Debug option: skip selected prompt indices, e.g.
        # if prompt_idx <= 0:
        #     continue
        prompts = prompt_dict["prompt_list"]
        general_prompt = prompt_dict["general_prompt"]

        # note: the totals above count "variations", while iteration uses "annotated_variations"
        variations = prompt_dict.get("annotated_variations", [])
        if rank == 0:
            print(
                f"📝 Processing prompt {prompt_idx}: '{general_prompt[:50]}...' with {len(variations)} variations"
            )

        for var_idx, var_data in enumerate(variations):
            # Convert normalized coordinates (0-1) to absolute pixel coordinates
            bounding_boxes_in_order = [
                [
                    int(bb[0] * args.sample_width),
                    int(bb[1] * args.sample_height),
                    int(bb[2] * args.sample_width),
                    int(bb[3] * args.sample_height),
                ]
                for bb in var_data["bboxes"]
            ]
            # bounding_boxes_in_order.reverse()  # reverse the bbox order (for debugging) TODO

            bounding_boxes_image = draw_bboxes_on_image(
                image_size=(args.sample_width, args.sample_height),
                bboxes=bounding_boxes_in_order,
            )

            for seed_idx, seed in enumerate(args.sample_seeds):
                sample_count += 1
                if rank == 0:
                    print(
                        f"🌱 Generating sample {sample_count}/{total_samples_to_generate} - "
                        f"Prompt {prompt_idx}, Variation {var_idx}, Seed {seed}"
                    )

                sample_step(
                    dit,
                    args,
                    prompt_idx,
                    var_idx,
                    prompts,
                    general_prompt,
                    bounding_boxes_in_order,
                    bounding_boxes_image,
                    global_step,
                    wandbrun,
                    rank,
                    offload=offload,
                    seed_idx=seed_idx,
                    save_dir=save_dir,
                    seed=seed,
                    html_output_path=html_output_path,
                )

    if rank == 0:
        print(f"✅ Completed inference: Generated {sample_count} samples")
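

# Hedged sketch of the JSON schema this loop implies (an inference from the key
# accesses above, not a documented format); all values are illustrative only.
_EXAMPLE_SAMPLE_PROMPTS = [
    {
        "general_prompt": "Two people standing in a park.",
        "prompt_list": ["a man in a red coat", "a woman holding an umbrella"],
        "variations": [{}],  # only its length is used, for the progress totals
        "annotated_variations": [
            {"bboxes": [[0.10, 0.20, 0.45, 0.95], [0.55, 0.20, 0.90, 0.95]]},  # normalized xyxy
        ],
    }
]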


# `var_idx` parameter added to support multiple bounding box variations per prompt
def sample_step(dit, args, prompt_idx, var_idx, prompts, general_prompt, bounding_boxes_in_order, bounding_boxes_image, global_step, wandbrun, rank, offload=True, seed_idx=0, save_dir=None, seed=None, html_output_path=None):
    # Use the provided seed or fall back to the first configured seed
    if seed is None:
        seed = args.sample_seeds[0]

    if rank == 0:
        print(
            f"🔍 DEBUG: Inside sample_step - received idx={prompt_idx}, var_idx={var_idx}, seed={seed}, seed_idx={seed_idx}"
        )

    image_name = (
        f"Inference Results for step {global_step}, prompt {prompt_idx}, variation {var_idx}, seed {seed}"
    )
    local_gpu = torch.cuda.current_device()
    if rank == 0:
        print(f"🎨 Generating images: step={global_step}, prompt_idx={prompt_idx}, seed={seed}")
    sampler = XFluxSampler(clip=None, t5=None, ae=None, model=dit, device=f"cuda:{local_gpu}", offload=offload)

    all_rounds_images = []

    # Use autoregressive sampling with multiple rounds
    rounds_output, clip, t5, vae = sampler.forward_multiperson(
        prompts=prompts,
        general_prompt=general_prompt,
        width=args.sample_width,
        height=args.sample_height,
        num_steps=args.sample_steps,
        seed=seed,
        customize_img_ids=args.customize_img_ids,
        customize_txt_ids=args.customize_txt_ids,
        bounding_boxes_in_order=bounding_boxes_in_order,
        use_v1_bbox=args.use_v1_bbox
    )

    # add visualization code here
    # Helper to create a centered text banner of a given width
    def _create_text_banner(text: str, width: int, font: ImageFont.FreeTypeFont, padding: int = 10, bg_color: str = "white", text_color: str = "black"):
        draw_dummy = ImageDraw.Draw(Image.new('RGB', (1, 1)))

        # Split text into lines that fit the banner width
        max_text_width = width - 2 * padding
        words = text.split()
        lines = []
        current_line = ""
        line_h = 0  # stays 0 for empty text, so the banner degenerates gracefully
        for word in words:
            test_line = f"{current_line} {word}".strip()
            # Measure the width of the candidate line
            if hasattr(draw_dummy, "textbbox"):
                bbox = draw_dummy.textbbox((0, 0), test_line, font=font)
                line_w = bbox[2] - bbox[0]
                line_h = bbox[3] - bbox[1]
            else:
                try:
                    line_w, line_h = draw_dummy.textsize(test_line, font=font)
                except AttributeError:
                    bbox = font.getbbox(test_line)
                    line_w = bbox[2] - bbox[0]
                    line_h = bbox[3] - bbox[1]
            if line_w <= max_text_width:
                current_line = test_line
            else:
                if current_line:
                    lines.append(current_line)
                current_line = word
        if current_line:
            lines.append(current_line)

        # Determine banner height based on the number of lines
        text_height = line_h  # height of one line
        banner_height = len(lines) * text_height + (len(lines) + 1) * padding
        banner = Image.new('RGB', (width, banner_height), color=bg_color)
        draw = ImageDraw.Draw(banner)

        y = padding
        for line in lines:
            if hasattr(draw_dummy, "textbbox"):
                bbox = draw_dummy.textbbox((0, 0), line, font=font)
                text_w = bbox[2] - bbox[0]
            else:
                try:
                    text_w, _ = draw_dummy.textsize(line, font=font)
                except AttributeError:
                    bbox = font.getbbox(line)
                    text_w = bbox[2] - bbox[0]
            draw.text(((width - text_w) // 2, y), line, fill=text_color, font=font)
            y += text_height + padding

        return banner

    # Prepare font
    try:
        font = ImageFont.truetype("DejaVuSans.ttf", size=16)
    except Exception:
        font = ImageFont.load_default()

    round_keys = sorted(rounds_output.keys(), key=lambda x: int(x))
    per_round_images = []

    for round_idx, key in enumerate(round_keys):
        round_data = rounds_output[key]
        pose_img = round_data["pose_img"]
        real_img = round_data["real_img"]

        # Ensure both images are RGB PIL Images of the same height
        if pose_img.mode != 'RGB':
            pose_img = pose_img.convert('RGB')
        if real_img.mode != 'RGB':
            real_img = real_img.convert('RGB')

        # Draw bounding boxes on the pose image - show cumulative people up to this round
        if bounding_boxes_in_order is not None:
            # Round 0 shows the first person's bbox, round 1 the first two, etc.
            bboxes_to_show = bounding_boxes_in_order[:round_idx + 1]  # +1 because round 0 = 1 person
            if rank == 0:
                print(f"🎯 Round {key}: Drawing {len(bboxes_to_show)} bounding boxes on pose image")
            pose_img = draw_bboxes_on_existing_image(pose_img, bboxes_to_show, line_width=2)

        concat_width = pose_img.width + real_img.width
        concat_height = max(pose_img.height, real_img.height)
        concat_img = Image.new('RGB', (concat_width, concat_height), color='white')
        concat_img.paste(pose_img, (0, 0))
        concat_img.paste(real_img, (pose_img.width, 0))

        title_banner = _create_text_banner(f"Round {key}", concat_width, font)
        prompt_text = prompts[round_idx] if round_idx < len(prompts) else ""
        prompt_banner = _create_text_banner(prompt_text, concat_width, font)

        total_h = title_banner.height + concat_img.height + prompt_banner.height
        round_img = Image.new('RGB', (concat_width, total_h), color='white')
        y_offset = 0
        round_img.paste(title_banner, (0, y_offset)); y_offset += title_banner.height
        round_img.paste(concat_img, (0, y_offset)); y_offset += concat_img.height
        round_img.paste(prompt_banner, (0, y_offset))

        per_round_images.append(round_img)

    # Determine final composite dimensions
    final_width = max(img.width for img in per_round_images)
    general_banner = _create_text_banner(general_prompt, final_width, font)
    seed_banner = _create_text_banner(f"Seed: {seed}", final_width, font, bg_color="lightblue")
    final_height = general_banner.height + seed_banner.height + sum(img.height for img in per_round_images)

    final_img = Image.new('RGB', (final_width, final_height), color='white')
    y_offset = 0
    final_img.paste(general_banner, (0, y_offset)); y_offset += general_banner.height
    final_img.paste(seed_banner, (0, y_offset)); y_offset += seed_banner.height
    for img in per_round_images:
        final_img.paste(img, (0, y_offset))
        y_offset += img.height

    # Add the bounding-box overview image in the top-right corner
    if bounding_boxes_image is not None:
        # Resize the bounding-box image to a smaller, adaptive size
        bbox_img_size = min(200, final_width // 4, final_height // 4)
        bbox_img_resized = bounding_boxes_image.resize((bbox_img_size, bbox_img_size), Image.Resampling.LANCZOS)

        # Position in the top-right corner with some padding
        padding = 10
        bbox_x = final_width - bbox_img_size - padding
        bbox_y = padding

        # Ensure we don't go out of bounds
        bbox_x = max(0, bbox_x)
        bbox_y = max(0, bbox_y)

        # Paste the resized bounding-box image
        final_img.paste(bbox_img_resized, (bbox_x, bbox_y))

    # Log the image to wandb if available and on rank 0
    if wandbrun is not None and rank == 0:
        wandb_caption = f"{general_prompt} (Var: {var_idx}, Seed: {seed})"
        wandbrun.log({f"sample/{image_name}": wandb.Image(final_img, caption=wandb_caption)}, step=global_step)

    # Save locally for inspection (only on rank 0 to avoid conflicts)
    if rank == 0:
        print(f"🔍 DEBUG: About to save with idx={prompt_idx}, var_idx={var_idx}, seed={seed}")
        save_path = os.path.join(
            save_dir,
            args.inference_output_dir,
            f"prompt_{prompt_idx}",
            f"variation_{var_idx}",
            f"seed_{seed}.jpg",
        )
        os.makedirs(os.path.dirname(save_path), exist_ok=True)

        try:
            final_img.save(save_path, format="JPEG", quality=95)
            print(f"💾 Saved: {save_path}")
        except Exception as e:
            print(f"❌ Failed to save {save_path}: {e}")
        # After saving the image, refresh the HTML gallery (rank 0 only)
        if html_output_path is not None:
            _refresh_html_gallery(
                base_save_dir=save_dir,
                inference_dir=args.inference_output_dir,
                json_path=args.sample_prompts_json,
                seeds=args.sample_seeds,
                html_filename=html_output_path,
            )
    else:
        print(f"⏭️ Rank {rank}: Skipping save (only rank 0 saves files)")

    del clip, t5, vae
    dit.to(f"cuda:{local_gpu}")


def main():
    args = OmegaConf.load(parse_args())
    args.exp_name, args.save_dir = generate_exp_name("ar_triplelora_v0", args, "bs", "rank", "use_dataset", "resize_to_square", "resolution", "customize_img_ids", "customize_txt_ids", "generate_img_ids_type", "background_color", "loss_pose_background_lambda", "double_real_lora", "single_real_lora", "real_lr_scale")

    args.training_width = args.resolution
    args.training_height = args.resolution
    args.sample_width = args.resolution
    args.sample_height = args.resolution
    args.img_seq_len = (args.resolution // 16) * (args.resolution // 16)   # TODO check: this is 1024 in the original repo
    args.cond_seq_len = (args.resolution // 16) * (args.resolution // 16)  # TODO check: this is 1024 in the original repo
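    # Worked arithmetic (added note, not original code): each token covers a
    # 16x16 pixel patch (8x VAE downsampling, then 2x2 patchification), so
    # resolution 512 gives (512 // 16) ** 2 = 1024 tokens (the value hard-coded
    # in the original repo), while resolution 1024 gives (1024 // 16) ** 2 = 4096.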
    save_dir = Path.cwd() / args.save_dir / args.exp_name
    os.makedirs(save_dir, exist_ok=True)

    # str.replace returns a new string, so the result must be assigned back
    if args.use_v1_bbox:
        args.inference_output_dir = args.inference_output_dir.replace("samples", "samples_use_v1_bbox")
    else:
        args.inference_output_dir = args.inference_output_dir.replace("samples", "samples_train_bbox")
    # save configs
    with open(save_dir / "config.yaml", "w") as f:
        OmegaConf.save(config=args, f=f)

    # save program file
    with open(save_dir / "program.py", "w") as f:
        f.write(open(__file__).read())
    rank = dist.get_rank()
    wandb_run = setup_wandb(args, rank) if args.use_wandb else None  # avoid an unbound name when wandb is disabled
    logging.info("***** Preparing model *****")
    local_gpu = torch.cuda.current_device()

    t5 = load_t5(f"cuda:{local_gpu}", max_length=512)
    clip = load_clip(f"cuda:{local_gpu}")

    # load the DiT onto every rank's CPU: each rank now holds its own CPU copy
    dit = load_flow_model2(args.model_name, device="cpu")  # gradient checkpointing is handled in fsdp_utils.py
    ##### replace modules / add LoRA ########################################################
    if args.use_lora:
        print("Using triple LoRA version")
        replace_attn_processor_triplelora_ar(dit, args)  # add LoRA to transformer blocks (attn & mlp)
    else:
        print("not using LoRA, finetuning all parameters")

    replace_split_head(dit, args)  # split head for img_in and final_layer

    ###### set trainable parameters #########################################################
    trainable_names = args.trainable_names  # ['img_in', 'final_layer']
    if args.use_lora:
        trainable_names.append('_lora')  # attn_lora, proj_lora, mod_lora
        disable_grad(dit, trainable_names)  # dit.train() is called inside disable_grad()
    else:
        dit.train()  # train all parameters

    dit.to(torch.bfloat16)
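    # A plausible reading of disable_grad (assumption; it lives in
    # src.flux.train_utils and is not shown here): freeze every parameter,
    # then re-enable those whose qualified name contains a trainable substring:
    #   for name, p in dit.named_parameters():
    #       p.requires_grad = any(key in name for key in trainable_names)
    #   dit.train()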
    ##### FSDP setup ########################################################################
    logging.info("***** FSDP setup *****")
    dit, optimizer, global_step = setup_model(dit, args)  # TODO: parameter-group lr must be updated before every optimizer step

    logging.info("***** Sample step once before training starts *****")
    sample_steps_inference(dit, args, global_step, wandb_run, rank, offload=args.offload_when_sample, save_dir=save_dir)

    # Print a summary of what should have been generated
    if rank == 0:
        # Recompute the expected file count from the JSON structure (prompts × variations × seeds)
        with open(args.sample_prompts_json, "r") as _fjson:
            _sample_prompts_tmp = json.load(_fjson)
        expected_files = (
            sum(len(item.get("variations", [])) for item in _sample_prompts_tmp)
            * len(args.sample_seeds)
        )
        samples_dir = os.path.join(save_dir, args.inference_output_dir)
        if os.path.exists(samples_dir):
            # Count recursively: images are saved under prompt_<idx>/variation_<var_idx>/
            actual_files = sum(
                1 for _root, _dirs, _files in os.walk(samples_dir) for f in _files if f.endswith('.jpg')
            )
            print(f"📊 Summary: Expected {expected_files} files, found {actual_files} files in {samples_dir}")
        else:
            print(f"📊 Summary: Expected {expected_files} files, but {samples_dir} doesn't exist yet")


if __name__ == "__main__":
    main()


# torchrun --nproc_per_node 2 --master_port 22484 v0_ar_triplelora_infer_customize_ids_by_json2.py --config train_configs/v0/ar_inference_customize_ids_by_json1024_2.yaml