Commit 6abe861 · Parent: c714a7e · committed by va1bhavagrawa1

working training code with both the stages

train/default_config.yaml CHANGED
@@ -3,12 +3,12 @@ debug: false
 distributed_type: MULTI_GPU
 main_process_port: 14121
 downcast_bf16: 'no'
-gpu_ids: 1,
+gpu_ids: 1,3,
 machine_rank: 0
 main_training_function: main
 mixed_precision: fp16
 num_machines: 1
-num_processes: 1
+num_processes: 2
 same_network: true
 tpu_env: []
 tpu_use_cluster: false
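
Aside: this config change moves training from one GPU to two, and accelerate launches one process per listed GPU, so num_processes must match the number of entries in gpu_ids. A minimal sanity check for that invariant — a sketch, not part of the commit, assuming PyYAML is installed and the config path below:

import yaml

with open("train/default_config.yaml") as f:
    cfg = yaml.safe_load(f)

# gpu_ids is a comma-separated string such as "1,3,"; drop empty trailing entries
gpu_ids = [g for g in str(cfg["gpu_ids"]).split(",") if g.strip()]
assert cfg["num_processes"] == len(gpu_ids), (
    f"num_processes={cfg['num_processes']} but {len(gpu_ids)} GPUs are listed"
)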
train/src/__pycache__/__init__.cpython-310.pyc ADDED (binary file, 149 Bytes)
train/src/__pycache__/__init__.cpython-311.pyc ADDED (binary file, 165 Bytes)
train/src/__pycache__/jsonl_datasets.cpython-310.pyc ADDED (binary file, 10.1 kB)
train/src/__pycache__/jsonl_datasets.cpython-311.pyc ADDED (binary file, 21.2 kB)
train/src/__pycache__/layers.cpython-310.pyc ADDED (binary file, 8.96 kB)
train/src/__pycache__/layers.cpython-311.pyc ADDED (binary file, 21.1 kB)
train/src/__pycache__/lora_helper.cpython-310.pyc ADDED (binary file, 6.3 kB)
train/src/__pycache__/lora_helper.cpython-311.pyc ADDED (binary file, 16.8 kB)
train/src/__pycache__/pipeline.cpython-310.pyc ADDED (binary file, 23.8 kB)
train/src/__pycache__/pipeline.cpython-311.pyc ADDED (binary file, 40.9 kB)
train/src/__pycache__/prompt_helper.cpython-310.pyc ADDED (binary file, 3.88 kB)
train/src/__pycache__/prompt_helper.cpython-311.pyc ADDED (binary file, 7.25 kB)
train/src/__pycache__/transformer_flux.cpython-310.pyc ADDED (binary file, 15.2 kB)
train/src/__pycache__/transformer_flux.cpython-311.pyc ADDED (binary file, 27.4 kB)
train/src/jsonl_datasets.py CHANGED
@@ -8,49 +8,6 @@ import os
 import os.path as osp
 import cv2
 
-def do_z_pass(seg_masks: torch.Tensor, dist_values: torch.Tensor) -> torch.Tensor:
-    """
-    Performs a z-pass on segmentation masks based on distance values to the camera.
-    For each pixel, if multiple subjects' masks are active, only the one with the smallest distance (closest) remains active.
-
-    Args:
-        seg_masks (torch.Tensor): Binary segmentation masks of shape (n_subjects, h, w) with dtype uint8.
-        dist_values (torch.Tensor): Distance values for each subject of shape (n_subjects,).
-
-    Returns:
-        torch.Tensor: Processed segmentation masks after z-pass, same shape and dtype as seg_masks.
-    """
-    # Ensure tensors are on the same device
-    device = seg_masks.device
-
-    # Get dimensions
-    n_subjects, h, w = seg_masks.shape
-
-    # Reshape distance values for broadcasting across spatial dimensions
-    dist_values_expanded = dist_values.view(n_subjects, 1, 1)
-
-    # Create a tensor where active pixels have their distance, others have a high value (1e10)
-    masked_dist = torch.where(seg_masks.bool(), dist_values_expanded, torch.tensor(1e10, device=device))
-
-    # Find the subject index with the minimum distance for each pixel (shape (h, w))
-    closest_indices = torch.argmin(masked_dist, dim=0)
-
-    # Initialize output tensor with zeros
-    output = torch.zeros_like(seg_masks)
-
-    # Scatter 1s into the output tensor where the closest subject's indices are
-    # closest_indices.unsqueeze(0) adds a dummy dimension to match scatter's expected shape
-    output.scatter_(
-        dim=0,
-        index=closest_indices.unsqueeze(0),
-        src=torch.ones_like(closest_indices.unsqueeze(0), dtype=output.dtype)
-    )
-
-    # Zero out any positions where the original mask was inactive
-    output = output * seg_masks
-
-    return output
-
 Image.MAX_IMAGE_PIXELS = None
 
 def multiple_16(num: float):
@@ -70,7 +27,7 @@ def load_image_safely(image_path, size):
         f.write(f"{image_path}\n")
     return Image.new("RGB", (size, size), (255, 255, 255))
 
-def make_train_dataset(args, tokenizer, accelerator):
+def make_train_dataset(args, tokenizer, accelerator, noise_size, only_realistic_images=False):
     if args.current_train_data_dir is not None:
         print("load_data")
         dataset = load_dataset('json', data_files=args.current_train_data_dir)
@@ -82,36 +39,16 @@ def make_train_dataset(args, tokenizer, accelerator):
 
     # 6. Get the column names for input/target.
     target_column = args.target_column
-    if args.subject_column is not None:
-        subject_columns = args.subject_column.split(",")
+    if only_realistic_images:
+        before = len(dataset["train"])
+        dataset["train"] = dataset["train"].filter(lambda example: osp.basename(example[target_column]) != "main.jpg")
+        after = len(dataset["train"])
+        print(f"[only_realistic_images] filtered out {before - after} examples")
+
     if args.spatial_column is not None:
        spatial_columns = args.spatial_column.split(",")
 
     size = args.cond_size
-    # by default the noise size would be randomly sampled from (512, 1024)
-    # noise_size = get_random_resolution(max_size=args.noise_size) # maybe 768 or higher
-    noise_size = get_random_resolution(min_size=512, max_size=512) # maybe 768 or higher
-    # subject_cond_train_transforms = transforms.Compose(
-    #     [
-    #         transforms.Lambda(lambda img: img.resize((
-    #             multiple_16(size * img.size[0] / max(img.size)),
-    #             multiple_16(size * img.size[1] / max(img.size))
-    #         ), resample=Image.BILINEAR)),
-    #         transforms.RandomHorizontalFlip(p=0.7),
-    #         transforms.RandomRotation(degrees=20),
-    #         transforms.Lambda(lambda img: transforms.Pad(
-    #             padding=(
-    #                 int((size - img.size[0]) / 2),
-    #                 int((size - img.size[1]) / 2),
-    #                 int((size - img.size[0]) / 2),
-    #                 int((size - img.size[1]) / 2)
-    #             ),
-    #             fill=0
-    #         )(img)),
-    #         transforms.ToTensor(),
-    #         transforms.Normalize([0.5], [0.5]),
-    #     ]
-    # )
     cond_train_transforms = transforms.Compose(
         [
             transforms.Resize((size, size), interpolation=transforms.InterpolationMode.BILINEAR),
@@ -120,7 +57,6 @@ def make_train_dataset(args, tokenizer, accelerator):
             transforms.Normalize([0.5], [0.5]),
         ]
     )
-    subject_cond_train_transforms = cond_train_transforms
 
     def train_transforms(image, noise_size):
         train_transforms_ = transforms.Compose(
@@ -141,11 +77,6 @@ def make_train_dataset(args, tokenizer, accelerator):
         concatenated_image = torch.cat(transformed_images, dim=1)
         return concatenated_image
 
-    def load_and_transform_subject_images(images):
-        transformed_images = [subject_cond_train_transforms(image) for image in images]
-        concatenated_image = torch.cat(transformed_images, dim=1)
-        return concatenated_image
-
     tokenizer_clip = tokenizer[0]
     tokenizer_t5 = tokenizer[1]
 
@@ -176,12 +107,12 @@ def make_train_dataset(args, tokenizer, accelerator):
             prompt_file_name = "space_prompt.pth"
         else:
             prompt_file_name = "_".join(caption.split(" ")) + ".pth"
-        if osp.exists(osp.join(args.inference_embeds_dir, prompt_file_name)):
+        if args.inference_embeds_dir is not None:
+            assert osp.exists(osp.join(args.inference_embeds_dir, prompt_file_name)), f"Prompt embeddings for '{caption}' not found in {args.inference_embeds_dir}. Please precompute and save them."
             prompt_embeds = torch.load(osp.join(args.inference_embeds_dir, prompt_file_name), map_location="cpu")
             pooled_prompt_embeds = prompt_embeds["pooled_prompt_embeds"]
             prompt_embeds = prompt_embeds["prompt_embeds"]
         else:
-            # raise FileNotFoundError(f"Prompt embeddings for '{caption}' not found in {args.inference_embeds_dir}. Please precompute and save them.")
             prompt_embeds = torch.zeros((1, 77, 768)) # Placeholder tensor
             pooled_prompt_embeds = torch.zeros((1, 768)) # Placeholder tensor
         all_prompt_embeds.append(prompt_embeds.squeeze(0))
@@ -233,9 +164,6 @@ def make_train_dataset(args, tokenizer, accelerator):
     def preprocess_train(examples):
         _examples = {}
         train_data_dir = osp.dirname(args.current_train_data_dir)
-        if args.subject_column is not None:
-            subject_images = [[load_image_safely(osp.join(train_data_dir, examples[column][i]), args.cond_size) for column in subject_columns] for i in range(len(examples[target_column]))]
-            _examples["subject_pixel_values"] = [load_and_transform_subject_images(subject) for subject in subject_images]
         if args.spatial_column is not None:
             # this now has two conditions
             spatial_images = [[load_image_safely(osp.join(train_data_dir, examples[column][i]), args.cond_size) for column in spatial_columns] for i in range(len(examples[target_column]))]
@@ -245,9 +173,7 @@ def make_train_dataset(args, tokenizer, accelerator):
         _examples["PLACEHOLDER_prompts"] = examples["PLACEHOLDER_prompts"]
         subjects = examples["subjects"]
         _examples["subjects"] = subjects
-        subjects_ = ["_".join(subject) for subject in subjects] # get the subject names with "_" instead of space
         _examples["prompts"] = []
-        # getting the prompts by replacing the PLACEHOLDER in the prompt with the actual subject names
         for i in range(len(examples["subjects"])):
             # replace the subjects string in the PLACEHOLDER
             prompt = examples["PLACEHOLDER_prompts"][i]
@@ -255,7 +181,6 @@ def make_train_dataset(args, tokenizer, accelerator):
             prompt = prompt.replace("PLACEHOLDER", placeholder_string)
             _examples["prompts"].append(prompt)
         _examples["prompt_embeds"], _examples["pooled_prompt_embeds"] = retrieve_prompt_embeds_from_disk(args, _examples)
-        # getting the z passed cuboids segmentation mask
         _examples["cuboids_segmasks"] = []
 
         def generous_resize_batch(masks, new_h, new_w):
@@ -290,9 +215,6 @@ def make_train_dataset(args, tokenizer, accelerator):
             segmasks_this_example[~mask] = 0
             segmasks_this_example = generous_resize_batch(segmasks_this_example, 32, 32)
             assert segmasks_this_example.shape == (len(subjects[i]), 32, 32), f"Segmentation masks shape {segmasks_this_example.shape} does not match expected shape {(len(subjects[i]), 32, 32)} for example {i}"
-            # z_passed_segmask = do_z_pass(segmasks_this_example, depth_values_this_example)
-            # print(f"{z_passed_segmask.shape = }, {segmasks_this_example.shape = }")
-            # _examples["cuboids_segmasks"].append(z_passed_segmask)
             _examples["cuboids_segmasks"].append(segmasks_this_example)
 
         _examples["token_ids_clip"], _examples["token_ids_t5"] = tokenize_prompt_clip_t5(_examples)
@@ -316,12 +238,6 @@ def collate_fn(examples):
         cond_pixel_values = cond_pixel_values.to(memory_format=torch.contiguous_format).float()
     else:
         cond_pixel_values = None
-    if examples[0].get("subject_pixel_values") is not None:
-        subject_pixel_values = torch.stack([example["subject_pixel_values"] for example in examples])
-        subject_pixel_values = subject_pixel_values.to(memory_format=torch.contiguous_format).float()
-    else:
-        subject_pixel_values = None
-
     target_pixel_values = torch.stack([example["pixel_values"] for example in examples])
     target_pixel_values = target_pixel_values.to(memory_format=torch.contiguous_format).float()
    token_ids_clip = torch.stack([torch.tensor(example["token_ids_clip"]) for example in examples])
@@ -335,7 +251,6 @@ def collate_fn(examples):
 
     return {
         "cond_pixel_values": cond_pixel_values,
-        "subject_pixel_values": subject_pixel_values,
         "pixel_values": target_pixel_values,
         "text_ids_1": token_ids_clip,
         "text_ids_2": token_ids_t5,
train/src/layers.py CHANGED
@@ -99,7 +99,6 @@ class MultiSingleStreamBlockLoraProcessor(nn.Module):
         use_cond = False,
         call_ids = None,
         cuboids_segmasks: torch.Tensor = None,
-        store_qk: Optional[str] = None,
     ) -> torch.FloatTensor:
 
         batch_size, seq_len, _ = hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
@@ -155,16 +154,13 @@ class MultiSingleStreamBlockLoraProcessor(nn.Module):
             for subject_idx, call_ids_this_subject in enumerate(call_ids_this_example):
                 # preparing the cuboid mask
                 cuboid_mask = cuboids_segmasks[batch_idx][subject_idx] # (h, w)
-                # assert cuboid_mask.shape == (int(math.sqrt(num_img_tokens)), int(math.sqrt(num_img_tokens))), f"{cuboid_mask.shape=}, {num_img_tokens=}"
                 cuboid_mask = cuboid_mask.to(torch.bool)
 
-                # assert scaled_block_size == scaled_cond_size + 512, f"{scaled_cond_size=}, {scaled_block_size=}"
                 for i in range(num_cond_blocks):
                     cuboid_mask = cuboids_segmasks[batch_idx][subject_idx] # (h, w)
                     cuboid_mask = cuboid_mask.to(torch.bool)
                     # masking out the condition tokens -> text token attention map
                     mask_subset = mask[batch_idx, :, scaled_block_size + i*scaled_cond_size : scaled_block_size + (i+1)*scaled_cond_size, call_ids_this_subject]
-                    # assert mask_subset.shape == (1, num_img_tokens, len(call_ids_this_subject)), f"{mask_subset.shape=}, {attn.heads=}, {num_img_tokens=}, {len(call_ids_this_subject)=}"
                     mask_subset[:, cuboid_mask.flatten()] = 0 # enable attention to cuboid regions
 
                     mask[batch_idx, :, scaled_block_size + i*scaled_cond_size : scaled_block_size + (i+1)*scaled_cond_size, call_ids_this_subject] = mask_subset
@@ -175,14 +171,6 @@ class MultiSingleStreamBlockLoraProcessor(nn.Module):
 
         hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False, attn_mask=mask)
 
-        if store_qk:
-            attn_weights = query.detach().to(torch.float16) @ key.detach().to(torch.float16).transpose(-1, -2) # (batch_size, num_heads, query_len, key_len)
-            attn_weights = attn_weights + mask
-            attn_weights = torch.mean(torch.softmax(attn_weights, dim=-1), dim=1)
-            attn_weights = attn_weights.cpu()
-            os.makedirs(osp.dirname(store_qk), exist_ok=True)
-            torch.save(attn_weights, store_qk + ".pth")
-
         hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
         hidden_states = hidden_states.to(query.dtype)
@@ -228,7 +216,6 @@ class MultiDoubleStreamBlockLoraProcessor(nn.Module):
         use_cond=False,
         call_ids = None,
         cuboids_segmasks: torch.Tensor = None,
-        store_qk: Optional[str] = None,
     ) -> torch.FloatTensor:
 
         batch_size, _, _ = hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
@@ -309,16 +296,13 @@ class MultiDoubleStreamBlockLoraProcessor(nn.Module):
             for subject_idx, call_ids_this_subject in enumerate(call_ids_this_example):
                 # preparing the cuboid mask
                 cuboid_mask = cuboids_segmasks[batch_idx][subject_idx] # (h, w)
-                # assert cuboid_mask.shape == (int(math.sqrt(num_img_tokens)), int(math.sqrt(num_img_tokens))), f"{cuboid_mask.shape=}, {num_img_tokens=}, {scaled_block_size=}"
                 cuboid_mask = cuboid_mask.to(torch.bool)
 
-                # assert scaled_block_size == scaled_cond_size + 512, f"{scaled_cond_size=}, {scaled_block_size=}"
                 for i in range(num_cond_blocks):
                     cuboid_mask = cuboids_segmasks[batch_idx][subject_idx] # (h, w)
                     cuboid_mask = cuboid_mask.to(torch.bool)
                     # masking out the condition tokens -> text token attention map
                     mask_subset = mask[batch_idx, :, scaled_block_size + i*scaled_cond_size : scaled_block_size + (i+1)*scaled_cond_size, call_ids_this_subject]
-                    # assert mask_subset.shape == (1, num_img_tokens, len(call_ids_this_subject)), f"{mask_subset.shape=}, {attn.heads=}, {num_img_tokens=}, {len(call_ids_this_subject)=}"
                     mask_subset[:, cuboid_mask.flatten()] = 0 # enable attention to cuboid regions
 
                     mask[batch_idx, :, scaled_block_size + i*scaled_cond_size : scaled_block_size + (i+1)*scaled_cond_size, call_ids_this_subject] = mask_subset
@@ -329,15 +313,6 @@ class MultiDoubleStreamBlockLoraProcessor(nn.Module):
 
         hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False, attn_mask=mask)
 
-        if store_qk:
-            attn_weights = query.detach().to(torch.float16) @ key.detach().to(torch.float16).transpose(-1, -2) # (batch_size, num_heads, query_len, key_len)
-            attn_weights = attn_weights + mask
-            attn_weights = torch.mean(torch.softmax(attn_weights, dim=-1), dim=1)
-            attn_weights = attn_weights.cpu()
-            os.makedirs(osp.dirname(store_qk), exist_ok=True)
-            torch.save(attn_weights, store_qk + ".pth")
-
-
         hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
         hidden_states = hidden_states.to(query.dtype)
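Aside: in both processors the surviving logic carves the attention mask so condition tokens can attend to a subject's tokens only where that subject's cuboid segmentation mask is active; the removed store_qk branches were debugging paths that dumped averaged attention maps to disk. A toy illustration of the masking pattern with scaled_dot_product_attention (shapes are illustrative, not the model's real ones):

import torch
import torch.nn.functional as F

n_img, n_cond, dim = 16, 4, 8            # toy token counts
q = torch.randn(1, 1, n_cond, dim)       # condition tokens as queries
k = v = torch.randn(1, 1, n_img, dim)    # image tokens as keys/values

cuboid_mask = torch.zeros(4, 4, dtype=torch.bool)
cuboid_mask[1:3, 1:3] = True             # the subject occupies a 2x2 patch

# additive mask: -inf blocks an entry, 0 allows it
mask = torch.full((1, 1, n_cond, n_img), float("-inf"))
mask[..., cuboid_mask.flatten()] = 0.0   # re-enable attention inside the cuboid

out = F.scaled_dot_product_attention(q, k, v, attn_mask=mask)
print(out.shape)  # torch.Size([1, 1, 4, 8])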
train/src/prompt_helper.py CHANGED
@@ -1,215 +0,0 @@
-import torch
-import os
-import os.path as osp
-
-
-def load_text_encoders(args, class_one, class_two):
-    text_encoder_one = class_one.from_pretrained(
-        args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision, variant=args.variant
-    )
-    text_encoder_two = class_two.from_pretrained(
-        args.pretrained_model_name_or_path, subfolder="text_encoder_2", revision=args.revision, variant=args.variant
-    )
-    return text_encoder_one, text_encoder_two
-
-
-def tokenize_prompt(tokenizer, prompt, max_sequence_length):
-    text_inputs = tokenizer(
-        prompt,
-        padding="max_length",
-        max_length=max_sequence_length,
-        truncation=True,
-        return_length=False,
-        return_overflowing_tokens=False,
-        return_tensors="pt",
-    )
-    text_input_ids = text_inputs.input_ids
-    return text_input_ids
-
-
-def tokenize_prompt_clip(tokenizer, prompt):
-    text_inputs = tokenizer(
-        prompt,
-        padding="max_length",
-        max_length=77,
-        truncation=True,
-        return_length=False,
-        return_overflowing_tokens=False,
-        return_tensors="pt",
-    )
-    text_input_ids = text_inputs.input_ids
-    return text_input_ids
-
-
-def tokenize_prompt_t5(tokenizer, prompt):
-    text_inputs = tokenizer(
-        prompt,
-        padding="max_length",
-        max_length=512,
-        truncation=True,
-        return_length=False,
-        return_overflowing_tokens=False,
-        return_tensors="pt",
-    )
-    text_input_ids = text_inputs.input_ids
-    return text_input_ids
-
-
-def _encode_prompt_with_t5(
-    text_encoder,
-    tokenizer,
-    max_sequence_length=512,
-    prompt=None,
-    num_images_per_prompt=1,
-    device=None,
-    text_input_ids=None,
-):
-    prompt = [prompt] if isinstance(prompt, str) else prompt
-    batch_size = len(prompt)
-
-    if tokenizer is not None:
-        text_inputs = tokenizer(
-            prompt,
-            padding="max_length",
-            max_length=max_sequence_length,
-            truncation=True,
-            return_length=False,
-            return_overflowing_tokens=False,
-            return_tensors="pt",
-        )
-        text_input_ids = text_inputs.input_ids
-    else:
-        if text_input_ids is None:
-            raise ValueError("text_input_ids must be provided when the tokenizer is not specified")
-
-    prompt_embeds = text_encoder(text_input_ids.to(device))[0]
-
-    dtype = text_encoder.dtype
-    prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
-
-    _, seq_len, _ = prompt_embeds.shape
-
-    # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method
-    prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
-    prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
-
-    return prompt_embeds
-
-
-def _encode_prompt_with_clip(
-    text_encoder,
-    tokenizer,
-    prompt: str,
-    device=None,
-    text_input_ids=None,
-    num_images_per_prompt: int = 1,
-):
-    prompt = [prompt] if isinstance(prompt, str) else prompt
-    batch_size = len(prompt)
-
-    if tokenizer is not None:
-        text_inputs = tokenizer(
-            prompt,
-            padding="max_length",
-            max_length=77,
-            truncation=True,
-            return_overflowing_tokens=False,
-            return_length=False,
-            return_tensors="pt",
-        )
-
-        text_input_ids = text_inputs.input_ids
-    else:
-        if text_input_ids is None:
-            raise ValueError("text_input_ids must be provided when the tokenizer is not specified")
-
-    prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=False)
-
-    # Use pooled output of CLIPTextModel
-    prompt_embeds = prompt_embeds.pooler_output
-    prompt_embeds = prompt_embeds.to(dtype=text_encoder.dtype, device=device)
-
-    # duplicate text embeddings for each generation per prompt, using mps friendly method
-    prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
-    prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, -1)
-
-    return prompt_embeds
-
-
-def encode_prompt(
-    args,
-    text_encoders,
-    tokenizers,
-    prompt: str,
-    max_sequence_length,
-    device=None,
-    num_images_per_prompt: int = 1,
-    text_input_ids_list=None,
-):
-    prompt = [prompt] if isinstance(prompt, str) else prompt
-    dtype = text_encoders[0].dtype
-
-    _prompt_ = "_".join(prompt)
-    if osp.exists(osp.join(args.inference_embeds_dir, f"{_prompt_}.pth")):
-        prompt_embeds = torch.load(osp.join(args.inference_embeds_dir, f"{_prompt_}.pth"))
-        pooled_prompt_embeds = prompt_embeds["pooled_prompt_embeds"]
-        prompt_embeds = prompt_embeds["prompt_embeds"]
-
-    else:
-        pooled_prompt_embeds = _encode_prompt_with_clip(
-            text_encoder=text_encoders[0],
-            tokenizer=tokenizers[0],
-            prompt=prompt,
-            device=device if device is not None else text_encoders[0].device,
-            num_images_per_prompt=num_images_per_prompt,
-            text_input_ids=text_input_ids_list[0] if text_input_ids_list else None,
-        )
-
-        prompt_embeds = _encode_prompt_with_t5(
-            text_encoder=text_encoders[1],
-            tokenizer=tokenizers[1],
-            max_sequence_length=max_sequence_length,
-            prompt=prompt,
-            num_images_per_prompt=num_images_per_prompt,
-            device=device if device is not None else text_encoders[1].device,
-            text_input_ids=text_input_ids_list[1] if text_input_ids_list else None,
-        )
-
-    text_ids = torch.zeros(prompt_embeds.shape[1], 3).to(device=device, dtype=dtype)
-
-    return prompt_embeds, pooled_prompt_embeds, text_ids
-
-
-def encode_token_ids(text_encoders, tokens, accelerator, num_images_per_prompt=1, device=None):
-    text_encoder_clip = text_encoders[0]
-    text_encoder_t5 = text_encoders[1]
-    tokens_clip, tokens_t5 = tokens[0], tokens[1]
-    batch_size = tokens_clip.shape[0]
-
-    if device == "cpu":
-        device = "cpu"
-    else:
-        device = accelerator.device
-
-    # clip
-    prompt_embeds = text_encoder_clip(tokens_clip.to(device), output_hidden_states=False)
-    # Use pooled output of CLIPTextModel
-    prompt_embeds = prompt_embeds.pooler_output
-    prompt_embeds = prompt_embeds.to(dtype=text_encoder_clip.dtype, device=accelerator.device)
-    # duplicate text embeddings for each generation per prompt, using mps friendly method
-    prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
-    pooled_prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, -1)
-    pooled_prompt_embeds = pooled_prompt_embeds.to(dtype=text_encoder_clip.dtype, device=accelerator.device)
-
-    # t5
-    prompt_embeds = text_encoder_t5(tokens_t5.to(device))[0]
-    dtype = text_encoder_t5.dtype
-    prompt_embeds = prompt_embeds.to(dtype=dtype, device=accelerator.device)
-    _, seq_len, _ = prompt_embeds.shape
-    # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method
-    prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
-    prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
-
-    text_ids = torch.zeros(prompt_embeds.shape[1], 3).to(device=accelerator.device, dtype=dtype)
-
-    return prompt_embeds, pooled_prompt_embeds, text_ids
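Aside: with prompt_helper.py emptied out, the training path no longer encodes prompts on the fly; jsonl_datasets.py now expects precomputed embeddings saved as .pth files named after the caption with spaces replaced by underscores, holding "prompt_embeds" and "pooled_prompt_embeds". A hedged sketch of that precompute step (the tensor shapes mirror the loader's zero placeholders and are illustrative only):

import os
import os.path as osp
import torch

def save_prompt_embeds(caption, prompt_embeds, pooled_prompt_embeds, embeds_dir):
    # naming convention matched by the dataset loader:
    # spaces in the caption become underscores, extension is .pth
    os.makedirs(embeds_dir, exist_ok=True)
    fname = "_".join(caption.split(" ")) + ".pth"
    torch.save(
        {"prompt_embeds": prompt_embeds, "pooled_prompt_embeds": pooled_prompt_embeds},
        osp.join(embeds_dir, fname),
    )

save_prompt_embeds("a red car", torch.zeros(1, 77, 768), torch.zeros(1, 768), "./embeds")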
train/src/transformer_flux.py CHANGED
@@ -395,9 +395,9 @@ class FluxTransformer2DModel(
         if self.original_attn_processors is not None:
             self.set_attn_processor(self.original_attn_processors)
 
-    def _set_gradient_checkpointing(self, module, value=False):
+    def _set_gradient_checkpointing(self, module, enable=False):
         if hasattr(module, "gradient_checkpointing"):
-            module.gradient_checkpointing = value
+            module.gradient_checkpointing = enable
 
     def forward(
         self,
@@ -416,7 +416,6 @@ class FluxTransformer2DModel(
         controlnet_blocks_repeat: bool = False,
         call_ids: list = None,
         cuboids_segmasks: torch.Tensor = None,
-        store_qk: bool = False,
     ) -> Union[torch.Tensor, Transformer2DModelOutput]:
         if cond_hidden_states is not None:
             use_condition = True
@@ -512,10 +511,6 @@ class FluxTransformer2DModel(
                 )
 
             else:
-                if store_qk:
-                    overall_block_idx = index_block
-                    joint_attention_kwargs["store_qk"] = osp.join(store_qk, f"{str(overall_block_idx).zfill(3)}")
-
                 encoder_hidden_states, hidden_states, cond_hidden_states = block(
                     hidden_states=hidden_states,
                     encoder_hidden_states=encoder_hidden_states,
@@ -566,10 +561,6 @@ class FluxTransformer2DModel(
                 )
 
             else:
-                if store_qk:
-                    overall_block_idx = index_block + len(self.transformer_blocks)
-                    joint_attention_kwargs["store_qk"] = osp.join(store_qk, f"{str(overall_block_idx).zfill(3)}")
-
                 hidden_states, cond_hidden_states = block(
                     hidden_states=hidden_states,
                     cond_hidden_states=cond_hidden_states if use_condition else None,
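
Aside: the value-to-enable rename appears to track a signature change in newer diffusers releases; the hook itself just flips a boolean on every submodule that exposes one. A minimal standalone sketch of the pattern in plain PyTorch (no diffusers dependency, names are illustrative):

import torch.nn as nn

class Block(nn.Module):
    def __init__(self):
        super().__init__()
        self.gradient_checkpointing = False

class Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.blocks = nn.ModuleList(Block() for _ in range(3))

    def _set_gradient_checkpointing(self, module, enable=False):
        if hasattr(module, "gradient_checkpointing"):
            module.gradient_checkpointing = enable

    def enable_gradient_checkpointing(self):
        # apply() visits every submodule, so each Block gets its flag set
        self.apply(lambda m: self._set_gradient_checkpointing(m, enable=True))

m = Model()
m.enable_gradient_checkpointing()
print([b.gradient_checkpointing for b in m.blocks])  # [True, True, True]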
train/train.py CHANGED
@@ -42,7 +42,6 @@ from diffusers.utils import (
     convert_unet_state_dict_to_peft
 )
 
-from src.prompt_helper import *
 from src.lora_helper import *
 from src.pipeline import FluxPipeline, resize_position_encoding, prepare_latent_subject_ids
 from src.layers import MultiDoubleStreamBlockLoraProcessor, MultiSingleStreamBlockLoraProcessor
@@ -60,64 +59,146 @@ logger = get_logger(__name__)
 import matplotlib.pyplot as plt
 import torch
 
-def create_validation_figure(output_image, spatial_image, subject_image, prompt, validation_idx, global_step):
-    """
-    Create a 2x2 matplotlib figure showing validation results.
-
-    Args:
-        output_image: Generated output image (PIL Image)
-        spatial_image: Spatial condition image (PIL Image or None)
-        subject_image: Subject condition image (PIL Image or None)
-        prompt: Text prompt string
-        validation_idx: Index of validation prompt
-        global_step: Current training step
-
-    Returns:
-        matplotlib figure
-    """
-    fig, axes = plt.subplots(2, 2, figsize=(12, 20))
-    fig.suptitle(f'Validation Results - Step {global_step} - Prompt {validation_idx}', fontsize=14)
-
-    # Output image (top-left)
-    axes[0, 0].imshow(np.array(output_image))
-    axes[0, 0].set_title('Generated Output')
-    axes[0, 0].axis('off')
-
-    # Spatial condition (top-right)
-    if spatial_image is not None:
-        axes[0, 1].imshow(np.array(spatial_image))
-        axes[0, 1].set_title('Spatial Condition')
-    else:
-        axes[0, 1].text(0.5, 0.5, 'NOT AVAILABLE',
-                        horizontalalignment='center', verticalalignment='center',
-                        transform=axes[0, 1].transAxes, fontsize=14, fontweight='bold')
-        axes[0, 1].set_title('Spatial Condition')
-    axes[0, 1].axis('off')
-
-    # Subject condition (bottom-left)
-    if subject_image is not None:
-        axes[1, 0].imshow(np.array(subject_image))
-        axes[1, 0].set_title('Subject Condition')
-    else:
-        axes[1, 0].text(0.5, 0.5, 'NOT AVAILABLE',
-                        horizontalalignment='center', verticalalignment='center',
-                        transform=axes[1, 0].transAxes, fontsize=14, fontweight='bold')
-        axes[1, 0].set_title('Subject Condition')
-    axes[1, 0].axis('off')
-
-    # Prompt and info (bottom-right)
-    info_text = f'Prompt:\n"{prompt}"\n\nStep: {global_step}\nValidation Index: {validation_idx}'
-    axes[1, 1].text(0.5, 0.5, info_text,
-                    horizontalalignment='center', verticalalignment='center',
-                    transform=axes[1, 1].transAxes, fontsize=10, wrap=True)
-    axes[1, 1].set_title('Prompt & Info')
-    axes[1, 1].axis('off')
-
-    plt.tight_layout()
-    return fig
[… added lines collapsed in the page rendering; not recoverable …]
 
-def visualize_training_data(batch, vae, model_input, noisy_model_input, cond_input, subject_input, args, global_step, accelerator):
+def visualize_training_data(batch, vae, model_input, noisy_model_input, cond_input, args, global_step, accelerator):
     """
     Visualize training data including all entities from the batch.
 
@@ -127,300 +208,155 @@ def visualize_training_data(batch, vae, model_input, noisy_model_input, cond_input, subject_input, args, global_step, accelerator):
         model_input: Clean latents before adding noise
         noisy_model_input: Noisy latents passed to transformer
         cond_input: Spatial condition latents (may be None)
-        subject_input: Subject condition latents (may be None)
         args: Training arguments
         global_step: Current training step
         accelerator: Accelerator instance
     """
-
-    # Check availability of conditions
-    has_spatial_condition = batch["cond_pixel_values"] is not None
-    has_subject_condition = batch["subject_pixel_values"] is not None
     has_cuboids_segmasks = "cuboids_segmasks" in batch and batch["cuboids_segmasks"] is not None
-    has_cuboids_segmasks_bev = "cuboids_segmasks_bev" in batch and batch["cuboids_segmasks_bev"] is not None
-
-    # Initialize variables
-    spatial_img = None
-    subject_img = None
-
     with torch.no_grad():
-        # Get VAE config for proper decoding
         vae_config_shift_factor = vae.config.shift_factor
         vae_config_scaling_factor = vae.config.scaling_factor
-        vae_dtype = vae.dtype
-        vae = vae.to(torch.float32)
-
         # Decode spatial condition if available
         if has_spatial_condition:
             cond_for_decode = (cond_input / vae_config_scaling_factor) + vae_config_shift_factor
             spatial_decoded = vae.decode(cond_for_decode.float()).sample
-            spatial_decoded = (spatial_decoded / 2 + 0.5).clamp(0, 1) # Normalize to [0,1]
             spatial_img = spatial_decoded[0].float().cpu().permute(1, 2, 0).numpy()
-
-        # Decode subject condition if available
-        if has_subject_condition:
-            subject_for_decode = (subject_input / vae_config_scaling_factor) + vae_config_shift_factor
-            subject_decoded = vae.decode(subject_for_decode.float()).sample
-            subject_decoded = (subject_decoded / 2 + 0.5).clamp(0, 1) # Normalize to [0,1]
-            subject_img = subject_decoded[0].float().cpu().permute(1, 2, 0).numpy()
-
-        # Decode clean model input
         clean_for_decode = (model_input / vae_config_scaling_factor) + vae_config_shift_factor
         clean_decoded = vae.decode(clean_for_decode.float()).sample
-        clean_decoded = (clean_decoded / 2 + 0.5).clamp(0, 1)
-
-        # Decode noisy model input
         noisy_for_decode = (noisy_model_input / vae_config_scaling_factor) + vae_config_shift_factor
         noisy_decoded = vae.decode(noisy_for_decode.float()).sample
-        noisy_decoded = (noisy_decoded / 2 + 0.5).clamp(0, 1)
-
-        # Convert to CPU and numpy for visualization (take first batch item)
-        clean_img = clean_decoded[0].float().cpu().permute(1, 2, 0).numpy()
-        noisy_img = noisy_decoded[0].float().cpu().permute(1, 2, 0).numpy()
-
-        # Get text prompt and other info
         text_prompt = batch["prompts"][0] if isinstance(batch["prompts"], list) else batch["prompts"]
         call_id = batch["call_ids"][0] if batch["call_ids"] is not None else "N/A"
-
-        # Create figure with more subplots to accommodate all entities including BEV
-        fig, axes = plt.subplots(4, 3, figsize=(18, 24))
-        # fig.suptitle(f'Training Data Visualization - Step {global_step}', fontsize=16)
-
-        # Spatial condition (0,0)
-        if has_spatial_condition and spatial_img is not None:
[… added lines collapsed in the page rendering …]
             axes[0, 0].imshow(spatial_img)
-            axes[0, 0].set_title('Spatial Condition')
         else:
-            axes[0, 0].text(0.5, 0.5, 'NOT AVAILABLE',
-                            horizontalalignment='center', verticalalignment='center',
-                            transform=axes[0, 0].transAxes, fontsize=14, fontweight='bold')
-            axes[0, 0].set_title('Spatial Condition')
         axes[0, 0].axis('off')
-
-        # Subject condition (0,1)
-        if has_subject_condition and subject_img is not None:
-            axes[0, 1].imshow(subject_img)
-            axes[0, 1].set_title('Subject Condition')
-        else:
-            axes[0, 1].text(0.5, 0.5, 'NOT AVAILABLE',
-                            horizontalalignment='center', verticalalignment='center',
-                            transform=axes[0, 1].transAxes, fontsize=14, fontweight='bold')
-            axes[0, 1].set_title('Subject Condition')
         axes[0, 1].axis('off')
-
-        # Clean model input (0,2)
-        axes[0, 2].imshow(clean_img)
-        axes[0, 2].set_title('Clean Model Input')
         axes[0, 2].axis('off')
-
-        # Noisy model input (1,0)
-        axes[1, 0].imshow(noisy_img)
-        axes[1, 0].set_title('Noisy Model Input')
-        axes[1, 0].axis('off')
-
-        # Cuboids segmentation masks with legend (1,1 and 1,2)
         if has_cuboids_segmasks:
-            segmask = batch["cuboids_segmasks"][0].float().cpu().numpy() # Shape: (n_subjects, h, w)
             n_subjects, h, w = segmask.shape
-
-            # Only use first 4 subjects for visualization
-            n_subjects_to_show = min(4, n_subjects)
-
-            # Create colored segmentation visualization
-            np.random.seed(42) # For consistent colors
-            colors = np.random.rand(n_subjects_to_show + 1, 3) # +1 for background
-            colors[0] = [0, 0, 0] # Background is black
-
-            # Create 2x2 grid of individual subject masks
-            grid_h, grid_w = 2, 2
-            combined_mask = np.zeros((h * grid_h, w * grid_w, 3))
-
-            for idx in range(n_subjects_to_show):
-                row = idx // grid_w
-                col = idx % grid_w
-
-                # Create binary mask for this subject
                 subject_mask = np.zeros((h, w, 3))
-                mask = segmask[idx] > 0.5 # Binary threshold
-                subject_mask[mask] = colors[idx + 1]
-
-                # Place in grid
-                combined_mask[row*h:(row+1)*h, col*w:(col+1)*w] = subject_mask
-
-            axes[1, 1].imshow(combined_mask)
-            axes[1, 1].set_title('Cuboids Segmentation (2x2 Grid)')
[… added lines collapsed in the page rendering …]
             axes[1, 1].axis('off')
-
-            # Create legend in the next subplot (1,2) - only for first 4 subjects
-            axes[1, 2].set_xlim(0, 1)
-            axes[1, 2].set_ylim(0, 1)
-
-            # Add legend entries
-            legend_y_positions = np.linspace(0.9, 0.1, n_subjects_to_show + 1)
-            axes[1, 2].text(0.1, legend_y_positions[0], f"Background",
-                            color=colors[0], fontsize=12, fontweight='bold')
-
-            for subject_idx in range(n_subjects_to_show):
-                axes[1, 2].text(0.1, legend_y_positions[subject_idx + 1],
-                                f"Subject {subject_idx}",
-                                color=colors[subject_idx + 1], fontsize=12, fontweight='bold')
-
-            axes[1, 2].set_title('Segmentation Legend (First 4)')
-            axes[1, 2].axis('off')
         else:
-            axes[1, 1].text(0.5, 0.5, 'NOT AVAILABLE',
-                            horizontalalignment='center', verticalalignment='center',
-                            transform=axes[1, 1].transAxes, fontsize=14, fontweight='bold')
-            axes[1, 1].set_title('Cuboids Segmentation')
-            axes[1, 1].axis('off')
-
-            axes[1, 2].text(0.5, 0.5, 'NOT AVAILABLE',
-                            horizontalalignment='center', verticalalignment='center',
-                            transform=axes[1, 2].transAxes, fontsize=14, fontweight='bold')
-            axes[1, 2].set_title('Segmentation Legend')
-            axes[1, 2].axis('off')
-
-        # BEV Cuboids segmentation masks with legend (2,0 and 2,1)
-        if has_cuboids_segmasks_bev:
-            segmask_bev = batch["cuboids_segmasks_bev"][0].float().cpu().numpy() # Shape: (n_subjects, h, w)
-            n_subjects_bev, h_bev, w_bev = segmask_bev.shape
-
-            # Create colored segmentation visualization for BEV (use different seed for different colors)
-            np.random.seed(123) # Different seed for BEV colors
-            colors_bev = np.random.rand(n_subjects_bev + 1, 3) # +1 for background
-            colors_bev[0] = [0, 0, 0] # Background is black
-
-            # Create RGB image from BEV segmentation
-            colored_segmask_bev = np.zeros((h_bev, w_bev, 3))
-            for subject_idx in range(n_subjects_bev):
-                mask_bev = segmask_bev[subject_idx] > 0.5 # Binary threshold
-                colored_segmask_bev[mask_bev] = colors_bev[subject_idx + 1]
-
-            axes[2, 0].imshow(colored_segmask_bev)
-            axes[2, 0].set_title('BEV Cuboids Segmentation')
-            axes[2, 0].axis('off')
-
-            # Create BEV legend in the next subplot (2,1)
-            axes[2, 1].set_xlim(0, 1)
-            axes[2, 1].set_ylim(0, 1)
-
-            # Add BEV legend entries
-            legend_y_positions_bev = np.linspace(0.9, 0.1, n_subjects_bev + 1)
-            axes[2, 1].text(0.1, legend_y_positions_bev[0], f"Background",
-                            color=colors_bev[0], fontsize=12, fontweight='bold')
-
-            for subject_idx in range(n_subjects_bev):
-                axes[2, 1].text(0.1, legend_y_positions_bev[subject_idx + 1],
-                                f"Subject {subject_idx}",
-                                color=colors_bev[subject_idx + 1], fontsize=12, fontweight='bold')
-
-            axes[2, 1].set_title('BEV Segmentation Legend')
-            axes[2, 1].axis('off')
-        else:
-            axes[2, 0].text(0.5, 0.5, 'NOT AVAILABLE',
-                            horizontalalignment='center', verticalalignment='center',
-                            transform=axes[2, 0].transAxes, fontsize=14, fontweight='bold')
-            axes[2, 0].set_title('BEV Cuboids Segmentation')
-            axes[2, 0].axis('off')
-
-            axes[2, 1].text(0.5, 0.5, 'NOT AVAILABLE',
-                            horizontalalignment='center', verticalalignment='center',
-                            transform=axes[2, 1].transAxes, fontsize=14, fontweight='bold')
-            axes[2, 1].set_title('BEV Segmentation Legend')
-            axes[2, 1].axis('off')
-
-        # Text prompt and call ID (2,2)
-        axes[2, 2].text(0.5, 0.5, f'Text Prompt:\n\n"{text_prompt}"\n\nCall ID: {call_id}',
-                        horizontalalignment='center', verticalalignment='center',
-                        transform=axes[2, 2].transAxes, fontsize=12, wrap=True)
-        axes[2, 2].set_title('Text Prompt & Call ID')
-        axes[2, 2].axis('off')
-
-        # Pixel values info (3,0)
-        pixel_info = f'Pixel Values Shape: {batch["pixel_values"].shape}\n'
         if has_spatial_condition:
-            pixel_info += f'Spatial Shape: {batch["cond_pixel_values"].shape}\n'
-        if has_subject_condition:
-            pixel_info += f'Subject Shape: {batch["subject_pixel_values"].shape}\n'
[… added lines collapsed in the page rendering …]
         if has_cuboids_segmasks:
-            pixel_info += f'Cuboids Segmasks: {len(batch["cuboids_segmasks"])}\n'
-        if has_cuboids_segmasks_bev:
-            pixel_info += f'BEV Segmasks: {len(batch["cuboids_segmasks_bev"])}'
[… added lines collapsed in the page rendering …]
-
-        axes[3, 0].text(0.5, 0.5, pixel_info,
-                        horizontalalignment='center', verticalalignment='center',
-                        transform=axes[3, 0].transAxes, fontsize=10, fontfamily='monospace')
-        axes[3, 0].set_title('Tensor Shapes')
-        axes[3, 0].axis('off')
-
-        # Training info (3,1)
-        training_info = f'Global Step: {global_step}\nConditions:\nSpatial: {"✓" if has_spatial_condition else "✗"}\nSubject: {"✓" if has_subject_condition else "✗"}\nSegmasks: {"✓" if has_cuboids_segmasks else "✗"}\nBEV Segmasks: {"✓" if has_cuboids_segmasks_bev else "✗"}'
-        axes[3, 1].text(0.5, 0.5, training_info,
-                        horizontalalignment='center', verticalalignment='center',
-                        transform=axes[3, 1].transAxes, fontsize=12, fontfamily='monospace')
-        axes[3, 1].set_title('Training Info')
-        axes[3, 1].axis('off')
-
-        # Additional info (3,2) - can be used for any extra debugging info
-        axes[3, 2].text(0.5, 0.5, 'Additional Info\n(Reserved)',
-                        horizontalalignment='center', verticalalignment='center',
-                        transform=axes[3, 2].transAxes, fontsize=12, fontfamily='monospace')
-        axes[3, 2].set_title('Reserved')
-        axes[3, 2].axis('off')
-
         plt.tight_layout()
-
-        # Save the visualization
         save_dir = os.path.join(args.output_dir, "visualizations")
         os.makedirs(save_dir, exist_ok=True)
         save_path = os.path.join(save_dir, f"training_vis_step_{global_step}.png")
         plt.savefig(save_path, dpi=150, bbox_inches='tight')
         plt.close()
-
         logger.info(f"Training visualization saved to {save_path}")
 
         vae = vae.to(vae_dtype)
 
-def log_validation(
-    pipeline,
-    args,
-    accelerator,
-    pipeline_args,
-    step,
-    torch_dtype,
-    is_final_validation=False,
-):
-    logger.info(
-        f"Running validation... \n Generating {args.num_validation_images} images with prompt:"
-        f" {pipeline_args['prompt']}."
-    )
-    pipeline = pipeline.to(accelerator.device)
-    pipeline.set_progress_bar_config(disable=True)
-    # run inference
-    generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None
-    # autocast_ctx = torch.autocast(accelerator.device.type) if not is_final_validation else nullcontext()
-    autocast_ctx = nullcontext()
-
-    with autocast_ctx:
-        images = [pipeline(**pipeline_args, generator=generator).images[0] for _ in range(args.num_validation_images)]
-
-    # for tracker in accelerator.trackers:
-    #     phase_name = "test" if is_final_validation else "validation"
-    #     if tracker.name == "tensorboard":
-    #         np_images = np.stack([np.asarray(img) for img in images])
-    #         tracker.writer.add_images(phase_name, np_images, step, dataformats="NHWC")
-    #     if tracker.name == "wandb":
-    #         tracker.log(
-    #             {
-    #                 phase_name: [
-    #                     wandb.Image(image, caption=f"{i}: {args.validation_prompt}") for i, image in enumerate(images)
-    #                 ]
-    #             },
-    #         )
-
-    return images
-
-
 def import_model_class_from_model_name_or_path(
     pretrained_model_name_or_path: str, revision: str, subfolder: str = "text_encoder"
 ):
@@ -444,9 +380,7 @@ def parse_args(input_args=None):
     parser = argparse.ArgumentParser(description="Simple example of a training script.")
     parser.add_argument("--lora_num", type=int, default=2, help="number of the lora.")
     parser.add_argument("--cond_size", type=int, default=512, help="size of the condition data.")
-    parser.add_argument("--test_h", type=int, default=1024, help="max side of the training data.")
     parser.add_argument("--debug", type=int, default=0, help="whether to enter debug mode -- visualizations, gradient checks, etc.")
-    parser.add_argument("--test_w", type=int, default=1024, help="max side of the training data.")
     parser.add_argument("--mode", type=str, default=None, help="The mode of the controller. Choose between ['depth', 'pose', 'canny'].")
     parser.add_argument("--run_name", type=str, required=True, help="the name of the wandb run")
     parser.add_argument(
@@ -462,7 +396,7 @@ def parse_args(input_args=None):
     parser.add_argument(
         "--inference_embeds_dir",
         type=str,
-        default="",
+        default=None,
         help=(
             "the captions for images"
         ),
@@ -474,13 +408,6 @@ def parse_args(input_args=None):
         required=False,
         help="Path to pretrained model or model identifier from huggingface.co/models.",
     )
-    parser.add_argument(
-        "--pretrained_lora_path",
-        type=str,
-        default=None,
-        required=False,
-        help="Path to pretrained model",
-    )
     parser.add_argument(
         "--revision",
         type=str,
@@ -502,14 +429,6 @@ def parse_args(input_args=None):
         "default, the standard Image Dataset maps out 'file_name' "
         "to 'image'.",
     )
-    parser.add_argument(
-        "--subject_column",
-        type=str,
-        default="image",
-        help="The column of the dataset containing the subject image. By "
-        "default, the standard Image Dataset maps out 'file_name' "
-        "to 'image'.",
-    )
     parser.add_argument(
         "--target_column",
         type=str,
@@ -531,42 +450,6 @@ def parse_args(input_args=None):
         default=512,
         help="Maximum sequence length to use with the T5 text encoder",
     )
-    parser.add_argument(
-        "--validation_prompt",
-        type=str,
-        nargs="+",
-        default="A woodenpot floating in a pool.",
-        help="A prompt that is used during validation to verify that the model is learning.",
-    )
-    parser.add_argument(
-        "--subject_test_images",
-        type=str,
-        nargs="+",
-        default=["/tiamat-NAS/zhangyuxuan/datasets/benchmark_dataset/decoritems_woodenpot/0.png"],
-        help="A list of subject test image paths.",
-    )
-    parser.add_argument(
-        "--spatial_test_images",
-        type=str,
-        nargs="+",
-        default=[],
-        help="A list of spatial test image paths.",
-    )
-    parser.add_argument(
-        "--num_validation_images",
-        type=int,
-        default=4,
-        help="Number of images that should be generated during validation with `validation_prompt`.",
-    )
-    parser.add_argument(
-        "--validation_steps",
-        type=int,
-        default=20,
-        help=(
-            "Run validation every X epochs. validation consists of running the prompt"
-            " `args.validation_prompt` multiple times: `args.num_validation_images`."
-        ),
-    )
     parser.add_argument(
         "--ranks",
         type=int,
@@ -591,13 +474,8 @@ def parse_args(input_args=None):
     parser.add_argument(
         "--train_batch_size", type=int, default=1, help="Batch size (per device) for the training dataloader."
     )
-    parser.add_argument("--num_train_epochs", type=int, default=50)
-    parser.add_argument(
-        "--max_train_steps",
-        type=int,
-        default=None,
-        help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
-    )
[… added lines collapsed in the page rendering …]
     parser.add_argument(
         "--checkpointing_steps",
         type=int,
@@ -608,12 +486,6 @@ def parse_args(input_args=None):
             " training using `--resume_from_checkpoint`."
         ),
     )
-    parser.add_argument(
-        "--checkpoints_total_limit",
-        type=int,
-        default=None,
-        help=("Max number of checkpoints to store."),
-    )
     parser.add_argument(
         "--resume_from_checkpoint",
         type=str,
@@ -761,16 +633,10 @@ def parse_args(input_args=None):
             " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
         ),
    )
-    parser.add_argument(
-        "--cache_latents",
-        action="store_true",
-        default=False,
-        help="Cache the VAE latents",
-    )
     parser.add_argument(
         "--report_to",
         type=str,
-        default="wandb",
[… replacement line collapsed in the page rendering …]
         help=(
             'The integration to report the results and logs to. Supported platforms are `"tensorboard"`'
             ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.'
@@ -811,9 +677,12 @@ def main(args):
             "Mixed precision training with bfloat16 is not supported on MPS. Please use fp16 (recommended) or fp32 instead."
         )
 
-    if args.pretrained_lora_path is not None:
[… added lines collapsed in the page rendering …]
         assert osp.exists(args.pretrained_lora_path), f"Make sure that the `pretrained_lora_path` {args.pretrained_lora_path} exists."
-        args.resume_from_checkpoint = osp.dirname(args.pretrained_lora_path)
 
     args.output_dir = osp.join(args.output_dir, args.run_name)
     args.logging_dir = osp.join(args.output_dir, args.logging_dir)
@@ -821,8 +690,6 @@ def main(args):
     os.makedirs(args.logging_dir, exist_ok=True)
     logging_dir = Path(args.output_dir, args.logging_dir)
 
-    if args.subject_column == "None":
-        args.subject_column = None
     if args.spatial_column == "None":
         args.spatial_column = None
 
@@ -831,7 +698,7 @@ def main(args):
     accelerator = Accelerator(
         gradient_accumulation_steps=args.gradient_accumulation_steps,
         mixed_precision=args.mixed_precision,
-        log_with=args.report_to,
[… replacement line collapsed in the page rendering …]
         project_config=accelerator_project_config,
         # kwargs_handlers=[kwargs],
     )
@@ -892,6 +759,17 @@ def main(args):
     noise_scheduler_copy = copy.deepcopy(noise_scheduler)
     gc.collect()
     torch.cuda.empty_cache()
[… eleven added lines collapsed in the page rendering …]
     vae = AutoencoderKL.from_pretrained(
         args.pretrained_model_name_or_path,
         subfolder="vae",
@@ -905,6 +783,9 @@ def main(args):
     # We only train the additional adapter LoRA layers
     transformer.requires_grad_(True)
     vae.requires_grad_(False)
[… three added lines collapsed in the page rendering …]
 
     # For mixed precision training we cast all non-trainable weights (vae, text_encoder and transformer) to half-precision
     # as these weights are only used for inference, keeping weights in full precision is not required.
@@ -922,6 +803,9 @@ def main(args):
 
     vae.to(accelerator.device, dtype=weight_dtype)
     transformer.to(accelerator.device, dtype=weight_dtype)
[… three added lines collapsed in the page rendering …]
 
     if args.gradient_checkpointing:
         transformer.enable_gradient_checkpointing()
@@ -1061,39 +945,43 @@ def main(args):
1061
 
1062
  tokenizers = [tokenizer_one, tokenizer_two]
1063
 
1064
- # # Dataset and DataLoaders creation:
1065
- # train_dataset = make_train_dataset(args, tokenizers, accelerator)
1066
- # train_dataloader = torch.utils.data.DataLoader(
1067
- # train_dataset,
1068
- # batch_size=args.train_batch_size,
1069
- # shuffle=True,
1070
- # collate_fn=collate_fn,
1071
- # num_workers=args.dataloader_num_workers,
1072
- # )
1073
-
1074
  # now, we will define a dataset for each epoch to make it easier to save the state
1075
  shuffled_jsonls = os.listdir(osp.dirname(args.train_data_dir))
1076
  base_jsonl_name = osp.basename(args.train_data_dir).replace(".jsonl", "")
1077
  shuffled_jsonls = sorted([_ for _ in shuffled_jsonls if _.endswith('.jsonl') and "shuffled" in _ and base_jsonl_name in _])
1078
  shuffled_jsonls = [osp.join(osp.dirname(args.train_data_dir), _) for _ in shuffled_jsonls]
1079
  print(f"{shuffled_jsonls = }")
1080
- # exit(0)
1081
  assert len(shuffled_jsonls) > 0, f"Make sure that there are shuffled jsonl files in {osp.dirname(args.train_data_dir)}"
1082
  train_dataloaders = []
1083
- for epoch in range(args.num_train_epochs): # prepare dataloader for each epoch, irrespective of the resume state
1084
  shuffled_idx = epoch % len(shuffled_jsonls)
1085
  train_data_file = shuffled_jsonls[shuffled_idx]
1086
  assert osp.exists(train_data_file), f"Make sure that the train data jsonl file {train_data_file} exists."
1087
  args.current_train_data_dir = train_data_file
1088
- train_dataset = make_train_dataset(args, tokenizers, accelerator)
1089
  train_dataloader = torch.utils.data.DataLoader(
1090
  train_dataset,
1091
  batch_size=args.train_batch_size,
1092
- shuffle=False, # yayy!! reproducible experiments!
1093
  collate_fn=collate_fn,
1094
  num_workers=args.dataloader_num_workers,
1095
  )
1096
  train_dataloaders.append(train_dataloader)
1097
 
1098
  vae_config_shift_factor = vae.config.shift_factor
1099
  vae_config_scaling_factor = vae.config.scaling_factor
@@ -1101,15 +989,14 @@ def main(args):
1101
  # Scheduler and math around the number of training steps.
1102
  overrode_max_train_steps = False
1103
  num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
1104
- if args.max_train_steps is None:
1105
- args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
1106
- overrode_max_train_steps = True
1107
 
1108
  lr_scheduler = get_scheduler(
1109
  args.lr_scheduler,
1110
  optimizer=optimizer,
1111
  num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,
1112
- num_training_steps=args.max_train_steps * accelerator.num_processes,
1113
  num_cycles=args.lr_num_cycles,
1114
  power=args.lr_power,
1115
  )
@@ -1121,15 +1008,11 @@ def main(args):
1121
  optimizer, lr_scheduler
1122
  )
1123
 
1124
- print(f"before preparation, {len(train_dataloaders[0]) = }")
1125
-
1126
  prepared_train_dataloaders = []
1127
  for train_dataloader in train_dataloaders:
1128
  prepared_train_dataloaders.append(accelerator.prepare(train_dataloader))
1129
  train_dataloaders = prepared_train_dataloaders
1130
 
1131
- print(f"after preparation, {len(train_dataloaders[0]) = }")
1132
-
1133
  if args.pretrained_lora_path is not None:
1134
  accelerator.load_state(osp.dirname(args.pretrained_lora_path))
1135
 
@@ -1144,26 +1027,14 @@ def main(args):
1144
  # We need to recalculate our total training steps as the size of the training dataloader may have changed.
1145
  num_update_steps_per_epoch = math.ceil(len(train_dataloaders[0]) / args.gradient_accumulation_steps)
1146
  if overrode_max_train_steps:
1147
- args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
1148
  # Afterwards we recalculate our number of training epochs
1149
- args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
1150
 
1151
  # We need to initialize the trackers we use, and also store our configuration.
1152
- # The trackers initializes automatically on the main process.
1153
- # if accelerator.is_main_process:
1154
- # tracker_name = "Easy_Control"
1155
- # accelerator.init_trackers(tracker_name, config=vars(args))
1156
 
1157
  if accelerator.is_main_process:
1158
- tracker_config = vars(copy.deepcopy(args))
1159
- # tracker_config.pop("validation_images")
1160
- wandb_args = {
1161
- "wandb": {
1162
- "entity": "generative_parts",
1163
- "name": args.run_name,
1164
- }
1165
- }
1166
- accelerator.init_trackers("seethrough3d", config=tracker_config, init_kwargs=wandb_args)
1167
 
1168
 
1169
  # Train!
@@ -1172,14 +1043,14 @@ def main(args):
1172
  logger.info("***** Running training *****")
1173
  logger.info(f" Num examples = {len(train_dataset)}")
1174
  logger.info(f" Num batches each epoch = {len(train_dataloader)}")
1175
- logger.info(f" Num Epochs = {args.num_train_epochs}")
1176
  logger.info(f" Instantaneous batch size per device = {args.train_batch_size}")
1177
  logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
1178
  logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
1179
- logger.info(f" Total optimization steps = {args.max_train_steps}")
1180
 
1181
  progress_bar = tqdm(
1182
- range(0, args.max_train_steps),
1183
  initial=initial_global_step,
1184
  desc="Steps",
1185
  # Only show the progress bar once on each machine.
@@ -1201,13 +1072,22 @@ def main(args):
1201
  vae_scale_factor = 16
1202
  height_cond = 2 * (args.cond_size // vae_scale_factor)
1203
  width_cond = 2 * (args.cond_size // vae_scale_factor)
1204
- offset = 64
1205
 
1206
  num_training_visualizations = 10
1207
-
1208
  skip_steps = initial_global_step - first_epoch * num_update_steps_per_epoch
1209
- print(f"{skip_steps = }")
1210
- for epoch in range(first_epoch, args.num_train_epochs):
1211
  transformer.train()
1212
  train_dataloader = train_dataloaders[epoch] # use a new dataloader for each epoch
1213
  if epoch == first_epoch and skip_steps > 0:
@@ -1218,21 +1098,48 @@ def main(args):
1218
  enumerated_dataloader = enumerate(dataloader_iterator, start=skip_steps)
1219
  else:
1220
  enumerated_dataloader = enumerate(train_dataloader)
 
 
1221
  for step, batch in enumerated_dataloader:
1222
  progress_bar.set_description(f"epoch {epoch}, dataset_ids: {batch['index']}")
1223
- torch.cuda.empty_cache()
1224
  models_to_accumulate = [transformer]
1225
  with accelerator.accumulate(models_to_accumulate):
1226
 
1227
- # tokens = [batch["text_ids_1"], batch["text_ids_2"]]
1228
- # prompt_embeds, pooled_prompt_embeds, text_ids = encode_token_ids(text_encoders, tokens, accelerator)
1229
- prompt_embeds = batch["prompt_embeds"]
1230
- pooled_prompt_embeds = batch["pooled_prompt_embeds"]
1231
- text_ids = torch.zeros((batch["prompt_embeds"].shape[1], 3))
1232
- prompt_embeds = prompt_embeds.to(dtype=vae.dtype, device=accelerator.device)
1233
- pooled_prompt_embeds = pooled_prompt_embeds.to(dtype=vae.dtype, device=accelerator.device)
1234
- text_ids = text_ids.to(dtype=vae.dtype, device=accelerator.device)
1235
-
1236
  pixel_values = batch["pixel_values"].to(dtype=vae.dtype)
1237
  height_ = 2 * (int(pixel_values.shape[-2]) // vae_scale_factor)
1238
  width_ = 2 * (int(pixel_values.shape[-1]) // vae_scale_factor)
@@ -1283,30 +1190,6 @@ def main(args):
1283
  latent_image_ids_to_concat = [latent_image_ids]
1284
  packed_cond_model_input_to_concat = []
1285
 
1286
- if args.subject_column is not None:
1287
- # in case the condition is not spatial
1288
- subject_pixel_values = batch["subject_pixel_values"].to(dtype=vae.dtype)
1289
- subject_input = vae.encode(subject_pixel_values).latent_dist.sample()
1290
- subject_input = (subject_input - vae_config_shift_factor) * vae_config_scaling_factor
1291
- subject_input = subject_input.to(dtype=weight_dtype)
1292
- # the number of subjects in the concatenated subject image
1293
- sub_number = subject_pixel_values.shape[-2] // args.cond_size
1294
- latent_subject_ids = prepare_latent_subject_ids(height_cond, width_cond, accelerator.device, weight_dtype)
1295
- latent_subject_ids[:, 1] += offset
1296
- sub_latent_image_ids = torch.concat([latent_subject_ids for _ in range(sub_number)], dim=-2)
1297
- latent_image_ids_to_concat.append(sub_latent_image_ids)
1298
-
1299
- packed_subject_model_input = FluxPipeline._pack_latents(
1300
- subject_input,
1301
- batch_size=subject_input.shape[0],
1302
- num_channels_latents=subject_input.shape[1],
1303
- height=subject_input.shape[2],
1304
- width=subject_input.shape[3],
1305
- )
1306
- packed_cond_model_input_to_concat.append(packed_subject_model_input)
1307
- else:
1308
- subject_input = None
1309
-
1310
  if args.spatial_column is not None:
1311
  # in case the condition is spatial
1312
  cond_pixel_values = batch["cond_pixel_values"].to(dtype=vae.dtype)
@@ -1347,7 +1230,6 @@ def main(args):
1347
  model_input=model_input,
1348
  noisy_model_input=noisy_model_input,
1349
  cond_input=cond_input,
1350
- subject_input=subject_input,
1351
  args=args,
1352
  global_step=global_step,
1353
  accelerator=accelerator
@@ -1408,25 +1290,6 @@ def main(args):
1408
  if accelerator.is_main_process:
1409
  if global_step % args.checkpointing_steps == 0:
1410
  # _before_ saving state, check if this save would set us over the `checkpoints_total_limit`
1411
- if args.checkpoints_total_limit is not None:
1412
- checkpoints = os.listdir(args.output_dir)
1413
- checkpoints = [d for d in checkpoints if d.startswith("checkpoint")]
1414
- checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1]))
1415
-
1416
- # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints
1417
- if len(checkpoints) >= args.checkpoints_total_limit:
1418
- num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1
1419
- removing_checkpoints = checkpoints[0:num_to_remove]
1420
-
1421
- logger.info(
1422
- f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints"
1423
- )
1424
- logger.info(f"removing checkpoints: {', '.join(removing_checkpoints)}")
1425
-
1426
- for removing_checkpoint in removing_checkpoints:
1427
- removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint)
1428
- shutil.rmtree(removing_checkpoint)
1429
-
1430
  save_path = os.path.join(args.output_dir, f"epoch-{epoch}__checkpoint-{global_step}")
1431
  os.makedirs(save_path, exist_ok=True)
1432
  unwrapped_model_state = accelerator.unwrap_model(transformer).state_dict()
 
42
  convert_unet_state_dict_to_peft
43
  )
44
 
 
45
  from src.lora_helper import *
46
  from src.pipeline import FluxPipeline, resize_position_encoding, prepare_latent_subject_ids
47
  from src.layers import MultiDoubleStreamBlockLoraProcessor, MultiSingleStreamBlockLoraProcessor
 
59
  import matplotlib.pyplot as plt
60
  import torch
61
 
62
+
63
+ def load_text_encoders(args, class_one, class_two):
64
+ text_encoder_one = class_one.from_pretrained(
65
+ args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision, variant=args.variant
66
+ )
67
+ text_encoder_two = class_two.from_pretrained(
68
+ args.pretrained_model_name_or_path, subfolder="text_encoder_2", revision=args.revision, variant=args.variant
69
+ )
70
+ return text_encoder_one, text_encoder_two
71
+
72
+
73
+ def _encode_prompt_with_t5(
74
+ text_encoder,
75
+ tokenizer,
76
+ max_sequence_length=512,
77
+ prompt=None,
78
+ num_images_per_prompt=1,
79
+ device=None,
80
+ text_input_ids=None,
81
+ ):
82
+ prompt = [prompt] if isinstance(prompt, str) else prompt
83
+ batch_size = len(prompt)
84
+
85
+ if tokenizer is not None:
86
+ text_inputs = tokenizer(
87
+ prompt,
88
+ padding="max_length",
89
+ max_length=max_sequence_length,
90
+ truncation=True,
91
+ return_length=False,
92
+ return_overflowing_tokens=False,
93
+ return_tensors="pt",
94
+ )
95
+ text_input_ids = text_inputs.input_ids
96
  else:
97
+ if text_input_ids is None:
98
+ raise ValueError("text_input_ids must be provided when the tokenizer is not specified")
99
+
100
+ prompt_embeds = text_encoder(text_input_ids.to(device))[0]
101
+
102
+ if hasattr(text_encoder, "module"):
103
+ dtype = text_encoder.module.dtype
 
 
 
104
  else:
105
+ dtype = text_encoder.dtype
106
+ prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
107
+
108
+ _, seq_len, _ = prompt_embeds.shape
109
+
110
+ # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method
111
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
112
+ prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
113
+
114
+ return prompt_embeds
115
+
116
+
117
+ def _encode_prompt_with_clip(
118
+ text_encoder,
119
+ tokenizer,
120
+ prompt: str,
121
+ device=None,
122
+ text_input_ids=None,
123
+ num_images_per_prompt: int = 1,
124
+ ):
125
+ prompt = [prompt] if isinstance(prompt, str) else prompt
126
+ batch_size = len(prompt)
127
+
128
+ if tokenizer is not None:
129
+ text_inputs = tokenizer(
130
+ prompt,
131
+ padding="max_length",
132
+ max_length=77,
133
+ truncation=True,
134
+ return_overflowing_tokens=False,
135
+ return_length=False,
136
+ return_tensors="pt",
137
+ )
138
+
139
+ text_input_ids = text_inputs.input_ids
140
+ else:
141
+ if text_input_ids is None:
142
+ raise ValueError("text_input_ids must be provided when the tokenizer is not specified")
143
+
144
+ prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=False)
145
+
146
+ if hasattr(text_encoder, "module"):
147
+ dtype = text_encoder.module.dtype
148
+ else:
149
+ dtype = text_encoder.dtype
150
+ # Use pooled output of CLIPTextModel
151
+ prompt_embeds = prompt_embeds.pooler_output
152
+ prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
153
+
154
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
155
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
156
+ prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, -1)
157
+
158
+ return prompt_embeds
159
+
160
+
161
+ def encode_prompt(
162
+ text_encoders,
163
+ tokenizers,
164
+ prompt: str,
165
+ max_sequence_length,
166
+ device=None,
167
+ num_images_per_prompt: int = 1,
168
+ text_input_ids_list=None,
169
+ ):
170
+ prompt = [prompt] if isinstance(prompt, str) else prompt
171
+
172
+ if hasattr(text_encoders[0], "module"):
173
+ dtype = text_encoders[0].module.dtype
174
+ else:
175
+ dtype = text_encoders[0].dtype
176
 
177
+ pooled_prompt_embeds = _encode_prompt_with_clip(
178
+ text_encoder=text_encoders[0],
179
+ tokenizer=tokenizers[0],
180
+ prompt=prompt,
181
+ device=device if device is not None else text_encoders[0].device,
182
+ num_images_per_prompt=num_images_per_prompt,
183
+ text_input_ids=text_input_ids_list[0] if text_input_ids_list else None,
184
+ )
185
+
186
+ prompt_embeds = _encode_prompt_with_t5(
187
+ text_encoder=text_encoders[1],
188
+ tokenizer=tokenizers[1],
189
+ max_sequence_length=max_sequence_length,
190
+ prompt=prompt,
191
+ num_images_per_prompt=num_images_per_prompt,
192
+ device=device if device is not None else text_encoders[1].device,
193
+ text_input_ids=text_input_ids_list[1] if text_input_ids_list else None,
194
+ )
195
+
196
+ text_ids = torch.zeros(prompt_embeds.shape[1], 3).to(device=device, dtype=dtype)
197
+
198
+ return prompt_embeds, pooled_prompt_embeds, text_ids
199
 
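For FLUX, both text encoders are needed: `_encode_prompt_with_clip` supplies the pooled CLIP vector, `_encode_prompt_with_t5` the per-token T5 sequence, and `text_ids` is a zero tensor of shape `(seq_len, 3)` used as position ids for the text tokens. A minimal usage sketch follows; it assumes `tokenizer_one`/`tokenizer_two` and `text_encoder_one`/`text_encoder_two` are loaded as in `main()` below, and the prompt string is only an example:

```python
# Minimal sketch of calling encode_prompt; tokenizer_one/two and
# text_encoder_one/two are assumed to be loaded as in main() below.
prompt_embeds, pooled_prompt_embeds, text_ids = encode_prompt(
    text_encoders=[text_encoder_one, text_encoder_two],
    tokenizers=[tokenizer_one, tokenizer_two],
    prompt=["a photo of a sedan and a pickup truck"],
    max_sequence_length=512,
    device=accelerator.device,
)
# prompt_embeds:        (1, 512, t5_dim)  T5 sequence embedding
# pooled_prompt_embeds: (1, clip_dim)     pooled CLIP embedding
# text_ids:             (512, 3) zeros    position ids for the text tokens
```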
200
+
201
+ def visualize_training_data(batch, vae, model_input, noisy_model_input, cond_input, args, global_step, accelerator):
202
  """
203
  Visualize training data including all entities from the batch.
204
 
 
208
  model_input: Clean latents before adding noise
209
  noisy_model_input: Noisy latents passed to transformer
210
  cond_input: Spatial condition latents (may be None)
 
211
  args: Training arguments
212
  global_step: Current training step
213
  accelerator: Accelerator instance
214
  """
215
+
216
+ has_spatial_condition = cond_input is not None
 
 
217
  has_cuboids_segmasks = "cuboids_segmasks" in batch and batch["cuboids_segmasks"] is not None
218
+
219
  with torch.no_grad():
 
220
  vae_config_shift_factor = vae.config.shift_factor
221
  vae_config_scaling_factor = vae.config.scaling_factor
222
+ vae_dtype = vae.dtype
223
+ vae = vae.to(torch.float32)
224
+
225
  # Decode spatial condition if available
226
  if has_spatial_condition:
227
  cond_for_decode = (cond_input / vae_config_scaling_factor) + vae_config_shift_factor
228
  spatial_decoded = vae.decode(cond_for_decode.float()).sample
229
+ spatial_decoded = (spatial_decoded / 2 + 0.5).clamp(0, 1)
230
  spatial_img = spatial_decoded[0].float().cpu().permute(1, 2, 0).numpy()
231
+ else:
232
+ spatial_img = None
233
+
234
+ # Decode clean and noisy model inputs
235
  clean_for_decode = (model_input / vae_config_scaling_factor) + vae_config_shift_factor
236
  clean_decoded = vae.decode(clean_for_decode.float()).sample
237
+ clean_img = (clean_decoded / 2 + 0.5).clamp(0, 1)[0].float().cpu().permute(1, 2, 0).numpy()
238
+
 
239
  noisy_for_decode = (noisy_model_input / vae_config_scaling_factor) + vae_config_shift_factor
240
  noisy_decoded = vae.decode(noisy_for_decode.float()).sample
241
+ noisy_img = (noisy_decoded / 2 + 0.5).clamp(0, 1)[0].float().cpu().permute(1, 2, 0).numpy()
242
+
243
  text_prompt = batch["prompts"][0] if isinstance(batch["prompts"], list) else batch["prompts"]
244
  call_id = batch["call_ids"][0] if batch["call_ids"] is not None else "N/A"
245
+
246
+ # 3x3 grid layout:
247
+ # Row 0: Spatial Condition | Clean Model Input | Noisy Model Input
248
+ # Row 1: Cuboids Segmentation | Segmentation Legend | Text Prompt & Call ID
249
+ # Row 2: Tensor Shapes | Training Info | (hidden)
250
+ fig, axes = plt.subplots(3, 3, figsize=(18, 18))
251
+
252
+ # --- Row 0: images ---
253
+ if has_spatial_condition:
254
  axes[0, 0].imshow(spatial_img)
 
255
  else:
256
+ axes[0, 0].text(0.5, 0.5, 'NOT AVAILABLE',
257
+ ha='center', va='center', transform=axes[0, 0].transAxes,
258
+ fontsize=14, fontweight='bold')
259
+ axes[0, 0].set_title('Spatial Condition')
260
  axes[0, 0].axis('off')
261
+
262
+ axes[0, 1].imshow(clean_img)
263
+ axes[0, 1].set_title('Clean Model Input (Target)')
264
  axes[0, 1].axis('off')
265
+
266
+ axes[0, 2].imshow(noisy_img)
267
+ axes[0, 2].set_title('Noisy Model Input')
 
268
  axes[0, 2].axis('off')
269
+
270
+ # --- Row 1: segmentation ---
271
  if has_cuboids_segmasks:
272
+ segmask = batch["cuboids_segmasks"][0].float().cpu().numpy() # (n_subjects, h, w)
273
  n_subjects, h, w = segmask.shape
274
+ n_show = min(4, n_subjects)
275
+
276
+ np.random.seed(42)
277
+ colors = np.random.rand(n_show + 1, 3)
278
+ colors[0] = [0, 0, 0] # background black
279
+
280
+ # 2x2 grid of individual subject masks
281
+ combined_mask = np.zeros((h * 2, w * 2, 3))
282
+ for idx in range(n_show):
283
+ row_i, col_i = idx // 2, idx % 2
284
  subject_mask = np.zeros((h, w, 3))
285
+ subject_mask[segmask[idx] > 0.5] = colors[idx + 1]
286
+ combined_mask[row_i*h:(row_i+1)*h, col_i*w:(col_i+1)*w] = subject_mask
287
+
288
+ axes[1, 0].imshow(combined_mask)
289
+ axes[1, 0].set_title(f'Cuboids Segmentation (first {n_show}, 2×2 grid)')
290
+ axes[1, 0].axis('off')
291
+
292
+ # Legend
293
+ axes[1, 1].set_xlim(0, 1)
294
+ axes[1, 1].set_ylim(0, 1)
295
+ y_positions = np.linspace(0.9, 0.1, n_show + 1)
296
+ axes[1, 1].text(0.1, y_positions[0], 'Background',
297
+ color='white', fontsize=12, fontweight='bold',
298
+ bbox=dict(facecolor='black', boxstyle='round,pad=0.2'))
299
+ for i in range(n_show):
300
+ axes[1, 1].text(0.1, y_positions[i + 1], f'Object {i}',
301
+ color=colors[i + 1], fontsize=12, fontweight='bold')
302
+ axes[1, 1].set_title('Segmentation Legend')
303
  axes[1, 1].axis('off')
304
  else:
305
+ for col in range(2):
306
+ axes[1, col].text(0.5, 0.5, 'NOT AVAILABLE',
307
+ ha='center', va='center', transform=axes[1, col].transAxes,
308
+ fontsize=14, fontweight='bold')
309
+ axes[1, col].axis('off')
310
+ axes[1, 0].set_title('Cuboids Segmentation')
311
+ axes[1, 1].set_title('Segmentation Legend')
312
+
313
+ # Text prompt and call ID
314
+ axes[1, 2].text(0.5, 0.5, f'Prompt:\n\n"{text_prompt}"\n\nCall ID:\n{call_id}',
315
+ ha='center', va='center', transform=axes[1, 2].transAxes,
316
+ fontsize=11, wrap=True)
317
+ axes[1, 2].set_title('Text Prompt & Call ID')
318
+ axes[1, 2].axis('off')
319
+
320
+ # --- Row 2: info ---
321
+ pixel_info = f'pixel_values: {tuple(batch["pixel_values"].shape)}\n'
322
  if has_spatial_condition:
323
+ pixel_info += f'cond_pixel_values: {tuple(batch["cond_pixel_values"].shape)}\n'
 
 
324
  if has_cuboids_segmasks:
325
+ seg = batch["cuboids_segmasks"]
326
+ pixel_info += f'cuboids_segmasks: {tuple(seg[0].shape) if hasattr(seg[0], "shape") else len(seg)} items\n'
327
+
328
+ axes[2, 0].text(0.5, 0.5, pixel_info,
329
+ ha='center', va='center', transform=axes[2, 0].transAxes,
330
+ fontsize=10, fontfamily='monospace')
331
+ axes[2, 0].set_title('Tensor Shapes')
332
+ axes[2, 0].axis('off')
333
+
334
+ training_info = (
335
+ f'Global Step: {global_step}\n\n'
336
+ f'Conditions:\n'
337
+ f' Spatial: {"✓" if has_spatial_condition else "✗"}\n'
338
+ f' Segmasks: {"✓" if has_cuboids_segmasks else "✗"}'
339
+ )
340
+ axes[2, 1].text(0.5, 0.5, training_info,
341
+ ha='center', va='center', transform=axes[2, 1].transAxes,
342
+ fontsize=12, fontfamily='monospace')
343
+ axes[2, 1].set_title('Training Info')
344
+ axes[2, 1].axis('off')
345
+
346
+ axes[2, 2].axis('off') # unused slot
347
+
 
 
348
  plt.tight_layout()
349
+
 
350
  save_dir = os.path.join(args.output_dir, "visualizations")
351
  os.makedirs(save_dir, exist_ok=True)
352
  save_path = os.path.join(save_dir, f"training_vis_step_{global_step}.png")
353
  plt.savefig(save_path, dpi=150, bbox_inches='tight')
354
  plt.close()
355
+
356
  logger.info(f"Training visualization saved to {save_path}")
357
 
358
  vae = vae.to(vae_dtype)
359
360
  def import_model_class_from_model_name_or_path(
361
  pretrained_model_name_or_path: str, revision: str, subfolder: str = "text_encoder"
362
  ):
 
380
  parser = argparse.ArgumentParser(description="Simple example of a training script.")
381
  parser.add_argument("--lora_num", type=int, default=2, help="number of the lora.")
382
  parser.add_argument("--cond_size", type=int, default=512, help="size of the condition data.")
 
383
  parser.add_argument("--debug", type=int, default=0, help="whether to enter debug mode -- visualizations, gradient checks, etc.")
 
384
  parser.add_argument("--mode",type=str,default=None,help="The mode of the controller. Choose between ['depth', 'pose', 'canny'].")
385
  parser.add_argument("--run_name",type=str,required=True,help="the name of the wandb run")
386
  parser.add_argument(
 
396
  parser.add_argument(
397
  "--inference_embeds_dir",
398
  type=str,
399
+ default=None,
400
  help=(
401
  "the captions for images"
402
  ),
 
408
  required=False,
409
  help="Path to pretrained model or model identifier from huggingface.co/models.",
410
  )
411
  parser.add_argument(
412
  "--revision",
413
  type=str,
 
429
  "default, the standard Image Dataset maps out 'file_name' "
430
  "to 'image'.",
431
  )
432
  parser.add_argument(
433
  "--target_column",
434
  type=str,
 
450
  default=512,
451
  help="Maximum sequence length to use with with the T5 text encoder",
452
  )
453
  parser.add_argument(
454
  "--ranks",
455
  type=int,
 
474
  parser.add_argument(
475
  "--train_batch_size", type=int, default=1, help="Batch size (per device) for the training dataloader."
476
  )
477
+ parser.add_argument("--stage1_epochs", type=int, default=50)
478
+ parser.add_argument("--stage2_steps", type=int, default=5000)
479
  parser.add_argument(
480
  "--checkpointing_steps",
481
  type=int,
 
486
  " training using `--resume_from_checkpoint`."
487
  ),
488
  )
489
  parser.add_argument(
490
  "--resume_from_checkpoint",
491
  type=str,
 
633
  " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
634
  ),
635
  )
636
  parser.add_argument(
637
  "--report_to",
638
  type=str,
639
+ default="tensorboard",
640
  help=(
641
  'The integration to report the results and logs to. Supported platforms are `"tensorboard"`'
642
  ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.'
 
677
  "Mixed precision training with bfloat16 is not supported on MPS. Please use fp16 (recommended) or fp32 instead."
678
  )
679
 
680
+ if args.resume_from_checkpoint is not None:
681
+ assert osp.exists(args.resume_from_checkpoint), f"Make sure that the `resume_from_checkpoint` {args.resume_from_checkpoint} exists."
682
+ args.pretrained_lora_path = osp.join(args.resume_from_checkpoint, f"lora.safetensors")
683
  assert osp.exists(args.pretrained_lora_path), f"Make sure that the `pretrained_lora_path` {args.pretrained_lora_path} exists."
684
+ else:
685
+ args.pretrained_lora_path = None
686
 
687
  args.output_dir = osp.join(args.output_dir, args.run_name)
688
  args.logging_dir = osp.join(args.output_dir, args.logging_dir)
 
690
  os.makedirs(args.logging_dir, exist_ok=True)
691
  logging_dir = Path(args.output_dir, args.logging_dir)
692
 
 
 
693
  if args.spatial_column == "None":
694
  args.spatial_column = None
695
 
 
698
  accelerator = Accelerator(
699
  gradient_accumulation_steps=args.gradient_accumulation_steps,
700
  mixed_precision=args.mixed_precision,
701
+ log_with=args.report_to,
702
  project_config=accelerator_project_config,
703
  # kwargs_handlers=[kwargs],
704
  )
 
759
  noise_scheduler_copy = copy.deepcopy(noise_scheduler)
760
  gc.collect()
761
  torch.cuda.empty_cache()
762
+
763
+ text_encoder_cls_one = import_model_class_from_model_name_or_path(
764
+ args.pretrained_model_name_or_path, args.revision, subfolder="text_encoder"
765
+ )
766
+ text_encoder_cls_two = import_model_class_from_model_name_or_path(
767
+ args.pretrained_model_name_or_path, args.revision, subfolder="text_encoder_2"
768
+ )
769
+ if args.inference_embeds_dir is None:
770
+ text_encoder_one, text_encoder_two = load_text_encoders(args, text_encoder_cls_one, text_encoder_cls_two)
771
+ else:
772
+ assert osp.exists(args.inference_embeds_dir), f"Make sure that the `inference_embeds_dir` {args.inference_embeds_dir} exists."
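When `--inference_embeds_dir` is set, the text encoders are never instantiated and each batch must already carry `prompt_embeds` and `pooled_prompt_embeds`. The cache layout implied by the (commented-out) consistency check further down in the training loop is one `.pth` file per prompt, named by the lower-cased prompt with spaces replaced by underscores. The helper below is a hedged sketch of that convention only; the actual precompute script is not part of this commit:

```python
import os.path as osp
import torch

# Hedged sketch of the cache convention implied by the commented-out check in
# the training loop; the real precompute script is not included in this commit.
def cached_embed_path(inference_embeds_dir: str, prompt: str) -> str:
    return osp.join(inference_embeds_dir, f"{'_'.join(prompt.lower().split())}.pth")

def save_prompt_cache(path: str, prompt_embeds: torch.Tensor, pooled_prompt_embeds: torch.Tensor) -> None:
    torch.save(
        {"prompt_embeds": prompt_embeds.cpu(), "pooled_prompt_embeds": pooled_prompt_embeds.cpu()},
        path,
    )

def load_prompt_cache(path: str) -> dict:
    return torch.load(path, map_location="cpu")
```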
773
  vae = AutoencoderKL.from_pretrained(
774
  args.pretrained_model_name_or_path,
775
  subfolder="vae",
 
783
  # We only train the additional adapter LoRA layers
784
  transformer.requires_grad_(True)
785
  vae.requires_grad_(False)
786
+ if args.inference_embeds_dir is None:
787
+ text_encoder_one.requires_grad_(False)
788
+ text_encoder_two.requires_grad_(False)
789
 
790
  # For mixed precision training we cast all non-trainable weights (vae, text_encoder and transformer) to half-precision
791
  # as these weights are only used for inference, keeping weights in full precision is not required.
 
803
 
804
  vae.to(accelerator.device, dtype=weight_dtype)
805
  transformer.to(accelerator.device, dtype=weight_dtype)
806
+ if args.inference_embeds_dir is None:
807
+ text_encoder_one.to(accelerator.device, dtype=torch.float32)
808
+ text_encoder_two.to(accelerator.device, dtype=torch.float32)
809
 
810
  if args.gradient_checkpointing:
811
  transformer.enable_gradient_checkpointing()
 
945
 
946
  tokenizers = [tokenizer_one, tokenizer_two]
947
 
948
  # now, we will define a dataset for each epoch to make it easier to save the state
949
  shuffled_jsonls = os.listdir(osp.dirname(args.train_data_dir))
950
  base_jsonl_name = osp.basename(args.train_data_dir).replace(".jsonl", "")
951
  shuffled_jsonls = sorted([_ for _ in shuffled_jsonls if _.endswith('.jsonl') and "shuffled" in _ and base_jsonl_name in _])
952
  shuffled_jsonls = [osp.join(osp.dirname(args.train_data_dir), _) for _ in shuffled_jsonls]
953
  print(f"{shuffled_jsonls = }")
 
954
  assert len(shuffled_jsonls) > 0, f"Make sure that there are shuffled jsonl files in {osp.dirname(args.train_data_dir)}"
955
  train_dataloaders = []
956
+ for epoch in range(args.stage1_epochs): # prepare dataloader for each epoch, irrespective of the resume state
957
  shuffled_idx = epoch % len(shuffled_jsonls)
958
  train_data_file = shuffled_jsonls[shuffled_idx]
959
  assert osp.exists(train_data_file), f"Make sure that the train data jsonl file {train_data_file} exists."
960
  args.current_train_data_dir = train_data_file
961
+ train_dataset = make_train_dataset(args, tokenizers, accelerator, 512)
962
  train_dataloader = torch.utils.data.DataLoader(
963
  train_dataset,
964
  batch_size=args.train_batch_size,
965
+ shuffle=False,
966
  collate_fn=collate_fn,
967
  num_workers=args.dataloader_num_workers,
968
  )
969
  train_dataloaders.append(train_dataloader)
970
+
971
+ if args.stage2_steps is not None:
972
+ args.current_train_data_dir = shuffled_jsonls[0]
973
+ train_dataset_stage2 = make_train_dataset(args, tokenizers, accelerator, 1024, only_realistic_images=True)
974
+ n_stage2 = min(args.stage2_steps * args.train_batch_size * args.gradient_accumulation_steps * accelerator.num_processes, len(train_dataset_stage2))
975
+ print(f"Stage2: subsetting dataset from {len(train_dataset_stage2)} to {n_stage2} examples")
976
+ train_dataset_stage2 = torch.utils.data.Subset(train_dataset_stage2, list(range(n_stage2)))
977
+ train_dataloader_stage2 = torch.utils.data.DataLoader(
978
+ train_dataset_stage2,
979
+ batch_size=args.train_batch_size,
980
+ shuffle=False,
981
+ collate_fn=collate_fn,
982
+ num_workers=args.dataloader_num_workers,
983
+ )
984
+ train_dataloaders.append(train_dataloader_stage2)
985
 
986
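Stage 2 reuses the first shuffled jsonl, but at 1024px and restricted to realistic images, and then truncates the dataset so it yields just enough samples for `stage2_steps` optimizer updates across all processes. A small worked example of that cap; only `stage2_steps` matches the script's default, the other numbers are illustrative:

```python
# Only stage2_steps matches the script default; the rest are illustrative.
stage2_steps = 5_000
train_batch_size = 1
gradient_accumulation_steps = 1
num_processes = 2
len_stage2_dataset = 80_000

n_stage2 = min(
    stage2_steps * train_batch_size * gradient_accumulation_steps * num_processes,
    len_stage2_dataset,
)
# n_stage2 == 10_000 samples: 5_000 batches per process, i.e. about 5_000 updates each
```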
  vae_config_shift_factor = vae.config.shift_factor
987
  vae_config_scaling_factor = vae.config.scaling_factor
 
989
  # Scheduler and math around the number of training steps.
990
  overrode_max_train_steps = False
991
  num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
992
+ stage1_steps = args.stage1_epochs * num_update_steps_per_epoch
993
+ overrode_max_train_steps = True
 
994
 
995
  lr_scheduler = get_scheduler(
996
  args.lr_scheduler,
997
  optimizer=optimizer,
998
  num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,
999
+ num_training_steps=stage1_steps * accelerator.num_processes,
1000
  num_cycles=args.lr_num_cycles,
1001
  power=args.lr_power,
1002
  )
 
1008
  optimizer, lr_scheduler
1009
  )
1010
 
 
 
1011
  prepared_train_dataloaders = []
1012
  for train_dataloader in train_dataloaders:
1013
  prepared_train_dataloaders.append(accelerator.prepare(train_dataloader))
1014
  train_dataloaders = prepared_train_dataloaders
1015
 
 
 
1016
  if args.pretrained_lora_path is not None:
1017
  accelerator.load_state(osp.dirname(args.pretrained_lora_path))
1018
 
 
1027
  # We need to recalculate our total training steps as the size of the training dataloader may have changed.
1028
  num_update_steps_per_epoch = math.ceil(len(train_dataloaders[0]) / args.gradient_accumulation_steps)
1029
  if overrode_max_train_steps:
1030
+ stage1_steps = args.stage1_epochs * num_update_steps_per_epoch
1031
  # Afterwards we recalculate our number of training epochs
1032
+ args.stage1_epochs = math.ceil(stage1_steps / num_update_steps_per_epoch)
1033
 
1034
  # We need to initialize the trackers we use, and also store our configuration.
1035
 
1036
  if accelerator.is_main_process:
1037
+ accelerator.init_trackers(args.run_name)
1038
 
1039
 
1040
  # Train!
 
1043
  logger.info("***** Running training *****")
1044
  logger.info(f" Num examples = {len(train_dataset)}")
1045
  logger.info(f" Num batches each epoch = {len(train_dataloader)}")
1046
+ logger.info(f" Num Epochs = {args.stage1_epochs}")
1047
  logger.info(f" Instantaneous batch size per device = {args.train_batch_size}")
1048
  logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
1049
  logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
1050
+ logger.info(f" Total optimization steps = {stage1_steps}")
1051
 
1052
  progress_bar = tqdm(
1053
+ range(0, stage1_steps + args.stage2_steps),
1054
  initial=initial_global_step,
1055
  desc="Steps",
1056
  # Only show the progress bar once on each machine.
 
1072
  vae_scale_factor = 16
1073
  height_cond = 2 * (args.cond_size // vae_scale_factor)
1074
  width_cond = 2 * (args.cond_size // vae_scale_factor)
 
1075
 
1076
  num_training_visualizations = 10
1077
+
1078
  skip_steps = initial_global_step - first_epoch * num_update_steps_per_epoch
1079
+
1080
+ # Estimate total training steps across all dataloaders
1081
+ total_steps_estimate = sum(
1082
+ math.ceil(len(dl) / args.gradient_accumulation_steps) for dl in train_dataloaders
1083
+ )
1084
+ logger.info(f"Estimated total steps across all dataloaders: {total_steps_estimate}")
1085
+ for i, dl in enumerate(train_dataloaders):
1086
+ steps_i = math.ceil(len(dl) / args.gradient_accumulation_steps)
1087
+ label = f"epoch-{i}" if i < args.stage1_epochs else "stage2"
1088
+ logger.info(f" {label}: {len(dl)} batches → {steps_i} steps")
1089
+
1090
+ for epoch in range(first_epoch, len(train_dataloaders)):
1091
  transformer.train()
1092
  train_dataloader = train_dataloaders[epoch] # use a new dataloader for each epoch
1093
  if epoch == first_epoch and skip_steps > 0:
 
1098
  enumerated_dataloader = enumerate(dataloader_iterator, start=skip_steps)
1099
  else:
1100
  enumerated_dataloader = enumerate(train_dataloader)
1101
+ if epoch == first_epoch:
1102
+ continue
1103
  for step, batch in enumerated_dataloader:
1104
  progress_bar.set_description(f"epoch {epoch}, dataset_ids: {batch['index']}")
 
1105
  models_to_accumulate = [transformer]
1106
  with accelerator.accumulate(models_to_accumulate):
1107
 
1108
+ if args.inference_embeds_dir is None:
1109
+ print(f"encoding {batch['prompts'] = }")
1110
+ prompt_embeds, pooled_prompt_embeds, text_ids = encode_prompt(
1111
+ text_encoders=[text_encoder_one, text_encoder_two],
1112
+ tokenizers=[tokenizer_one, tokenizer_two],
1113
+ prompt=batch["prompts"],
1114
+ max_sequence_length=512,
1115
+ device=accelerator.device,
1116
+ )
1117
+ # for i, prompt in enumerate(batch["prompts"]):
1118
+ # # prompt_embeds, pooled_prompt_embeds, text_ids = encode_prompt(
1119
+ # # text_encoders=[text_encoder_one, text_encoder_two],
1120
+ # # tokenizers=[tokenizer_one, tokenizer_two],
1121
+ # # prompt=prompt,
1122
+ # # max_sequence_length=512,
1123
+ # # device=accelerator.device,
1124
+ # # )
1125
+ # print(f"{prompt_embeds.shape = }, {pooled_prompt_embeds.shape = }, {text_ids.shape = }")
1126
+ # # checking if the cached embeddings match
1127
+ # inference_embeds_dir = "/archive/vaibhav.agrawal/a-bev-of-the-latents/inference_embeds_datasetv7_superhard"
1128
+ # cached_prompt_path = osp.join(inference_embeds_dir, f"{'_'.join(prompt.lower().split())}.pth")
1129
+ # assert osp.exists(cached_prompt_path), f"Make sure that the cached prompt embedding {cached_prompt_path} exists."
1130
+ # cached_prompt_embeds = torch.load(cached_prompt_path, map_location="cpu")
1131
+ # assert torch.allclose(cached_prompt_embeds["prompt_embeds"].cpu().float(), prompt_embeds[i].cpu().float(), atol=1e-3), f"Cached prompt embeds for prompt {prompt} do not match the computed prompt embeds. Make sure that the cached prompt embeds are correct., {torch.mean(torch.abs(cached_prompt_embeds['prompt_embeds'].cpu().float() - prompt_embeds[i].cpu().float())) = }, {torch.mean(torch.abs(cached_prompt_embeds['prompt_embeds'].cpu().float())) = }"
1132
+ # assert torch.allclose(cached_prompt_embeds["pooled_prompt_embeds"].cpu().float(), pooled_prompt_embeds[i].cpu().float(), atol=1e-3), f"Cached pooled prompt embeds for prompt {prompt} do not match the computed pooled prompt embeds. Make sure that the cached pooled prompt embeds are correct., {torch.mean(torch.abs(cached_prompt_embeds['pooled_prompt_embeds'].cpu().float() - pooled_prompt_embeds[i].cpu().float())) = }"
1133
+ else:
1134
+ assert "prompt_embeds" in batch and "pooled_prompt_embeds" in batch, "Make sure that the dataloader returns `prompt_embeds` and `pooled_prompt_embeds` when `inference_embeds_dir` is not None."
1135
+ prompt_embeds = batch["prompt_embeds"]
1136
+ pooled_prompt_embeds = batch["pooled_prompt_embeds"]
1137
+ text_ids = torch.zeros((batch["prompt_embeds"].shape[1], 3))
1138
+ prompt_embeds = prompt_embeds.to(dtype=vae.dtype, device=accelerator.device)
1139
+ pooled_prompt_embeds = pooled_prompt_embeds.to(dtype=vae.dtype, device=accelerator.device)
1140
+ text_ids = text_ids.to(dtype=vae.dtype, device=accelerator.device)
1141
+
1142
+
1143
  pixel_values = batch["pixel_values"].to(dtype=vae.dtype)
1144
  height_ = 2 * (int(pixel_values.shape[-2]) // vae_scale_factor)
1145
  width_ = 2 * (int(pixel_values.shape[-1]) // vae_scale_factor)
 
1190
  latent_image_ids_to_concat = [latent_image_ids]
1191
  packed_cond_model_input_to_concat = []
1192
 
1193
  if args.spatial_column is not None:
1194
  # in case the condition is spatial
1195
  cond_pixel_values = batch["cond_pixel_values"].to(dtype=vae.dtype)
 
1230
  model_input=model_input,
1231
  noisy_model_input=noisy_model_input,
1232
  cond_input=cond_input,
 
1233
  args=args,
1234
  global_step=global_step,
1235
  accelerator=accelerator
 
1290
  if accelerator.is_main_process:
1291
  if global_step % args.checkpointing_steps == 0:
1292
  # _before_ saving state, check if this save would set us over the `checkpoints_total_limit`
1293
  save_path = os.path.join(args.output_dir, f"epoch-{epoch}__checkpoint-{global_step}")
1294
  os.makedirs(save_path, exist_ok=True)
1295
  unwrapped_model_state = accelerator.unwrap_model(transformer).state_dict()
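The checkpointing block unwraps the transformer and takes its full state dict; the lines that filter out and write the LoRA tensors fall outside this hunk. Below is a hedged sketch of what such a save could look like with `safetensors` (the `lora.safetensors` filename matches what `--resume_from_checkpoint` expects above; filtering keys on "lora" is an assumption, not taken from this commit):

```python
from safetensors.torch import save_file

# Hedged sketch: keep only LoRA parameters and write them next to the
# accelerate state so --resume_from_checkpoint can find lora.safetensors.
# Filtering on "lora" in the key name is an assumption about the naming scheme.
lora_state_dict = {
    k: v.detach().cpu() for k, v in unwrapped_model_state.items() if "lora" in k
}
save_file(lora_state_dict, os.path.join(save_path, "lora.safetensors"))
accelerator.save_state(save_path)  # optimizer / scheduler / RNG state for resuming
```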
train/train.sh CHANGED
@@ -1,14 +1,3 @@
1
- #!/bin/bash
2
- #SBATCH --job-name=vaibhav
3
- #SBATCH --output=%j.out
4
- #SBATCH --ntasks=1
5
- #SBATCH --cpus-per-task=4
6
- #SBATCH --mem=150G
7
- #SBATCH --gres=gpu:4
8
- #SBATCH --partition=ada
9
-
10
- # chetna
11
- # export MODEL_DIR="black-forest-labs/FLUX.1-Kontext-dev" # your flux path
12
  export MODEL_DIR="black-forest-labs/FLUX.1-dev" # your flux path
13
  export OUTPUT_DIR="/archive/vaibhav.agrawal/a-bev-of-the-latents/easycontrol_cuboids" # your save path
14
  export CONFIG="./default_config.yaml"
@@ -16,29 +5,9 @@ export TRAIN_DATA="/archive/vaibhav.agrawal/a-bev-of-the-latents/datasetv7_super
16
  export LOG_PATH="$OUTPUT_DIR/log"
17
  export INFERENCE_EMBEDS_DIR="/archive/vaibhav.agrawal/a-bev-of-the-latents/inference_embeds_datasetv7_superhard"
18
 
19
- export WANDB_API_KEY=f27c837d8d7d0c8d79f3eb1de21fa78233c03be6
20
-
21
- # kotak
22
- # export MODEL_DIR="black-forest-labs/FLUX.1-dev" # your flux path
23
- # export OUTPUT_DIR="/archive/vaibhav.agrawal/a-bev-of-the-latents/easycontrol_cuboids" # your save path
24
- # export CONFIG="./default_config.yaml"
25
- # export TRAIN_DATA="/archive/vaibhav.agrawal/a-bev-of-the-latents/datasetv6/cuboids.jsonl" # your data jsonl file
26
- # export LOG_PATH="$OUTPUT_DIR/log"
27
- # export INFERENCE_EMBEDS_DIR="/archive/vaibhav.agrawal/a-bev-of-the-latents/inference_embeds_flux2"
28
-
29
- # kotak
30
- # export MODEL_DIR="black-forest-labs/FLUX.1-dev" # your flux path
31
- # export OUTPUT_DIR="./easycontrol_cuboids" # your save path
32
- # export CONFIG="./default_config.yaml"
33
- # export TRAIN_DATA="/home/venky/vaibhav.agrawal/a-bev-of-the-latents/datasets/actual_data/datasetv6/cuboids.jsonl" # your data jsonl file
34
- # export LOG_PATH="$OUTPUT_DIR/log"
35
- # export INFERENCE_EMBEDS_DIR="/home/venky/vaibhav.agrawal/a-bev-of-the-latents/caching/inference_embeds_flux2"
36
-
37
- # i love this.
38
  accelerate launch --config_file $CONFIG train.py \
39
  --pretrained_model_name_or_path $MODEL_DIR \
40
  --cond_size=512 \
41
- --subject_column="None" \
42
  --spatial_column="cv" \
43
  --target_column="target" \
44
  --caption_column="caption" \
@@ -47,25 +16,12 @@ accelerate launch --config_file $CONFIG train.py \
47
  --lora_num 1 \
48
  --output_dir=$OUTPUT_DIR \
49
  --logging_dir=$LOG_PATH \
50
- --run_name="rgb__r1" \
51
  --debug=1 \
52
  --mixed_precision="bf16" \
53
  --train_data_dir=$TRAIN_DATA \
54
  --learning_rate=1e-4 \
55
  --train_batch_size=1 \
56
- --inference_embeds_dir $INFERENCE_EMBEDS_DIR \
57
- --validation_prompt "a photo of sedan and pickup truck and suv amongst autumn-colored trees along a winding river" "a photo of cow and suv on a sandy beach with palm trees swaying in the breeze" "a photo of table and horse and suv in a dense pine forest with tall trees reaching the sky" \
58
- --num_train_epochs=1 \
59
- --validation_steps=5000000000000 \
60
- --checkpointing_steps=2500 \
61
- --spatial_test_images "cuboids/sedan__pickup_truck__suv/005/cuboids.png" "cuboids/cow__suv/008/cuboids.png" "cuboids/table__horse__suv/007/cuboids.png" \
62
- --subject_test_images None \
63
- --test_h 512 \
64
- --test_w 512 \
65
- --num_validation_images=1
66
-
67
- # --run_name="semantic_info_from_cuboid_cond" \
68
- # --run_name="datasetv8__0.8_0.1_0.1" \
69
- # --pretrained_lora_path="/archive/vaibhav.agrawal/a-bev-of-the-latents/easycontrol_cuboids/wireframe/epoch-0__checkpoint-5000/lora.safetensors" \
70
- # --pretrained_lora_path="/archive/vaibhav.agrawal/a-bev-of-the-latents/easycontrol_cuboids/rgb/epoch-0__checkpoint-7500/lora.safetensors" \
71
- # --pretrained_lora_path="/archive/vaibhav.agrawal/a-bev-of-the-latents/easycontrol_cuboids/datasetv9__wireframe_best_case/epoch-0__checkpoint-3888/lora.safetensors" \
1
  export MODEL_DIR="black-forest-labs/FLUX.1-dev" # your flux path
2
  export OUTPUT_DIR="/archive/vaibhav.agrawal/a-bev-of-the-latents/easycontrol_cuboids" # your save path
3
  export CONFIG="./default_config.yaml"
 
5
  export LOG_PATH="$OUTPUT_DIR/log"
6
  export INFERENCE_EMBEDS_DIR="/archive/vaibhav.agrawal/a-bev-of-the-latents/inference_embeds_datasetv7_superhard"
7
 
 
8
  accelerate launch --config_file $CONFIG train.py \
9
  --pretrained_model_name_or_path $MODEL_DIR \
10
  --cond_size=512 \
 
11
  --spatial_column="cv" \
12
  --target_column="target" \
13
  --caption_column="caption" \
 
16
  --lora_num 1 \
17
  --output_dir=$OUTPUT_DIR \
18
  --logging_dir=$LOG_PATH \
19
+ --run_name="seethrough3d" \
20
  --debug=1 \
21
  --mixed_precision="bf16" \
22
  --train_data_dir=$TRAIN_DATA \
23
  --learning_rate=1e-4 \
24
  --train_batch_size=1 \
25
+ --stage1_epochs=1 \
26
+ --stage2_steps=5000 \
27
+ --checkpointing_steps=2500
train/train_notworking.py ADDED
@@ -0,0 +1,1397 @@
1
+ import argparse
2
+ import copy
3
+ import logging
4
+ import random
5
+ import math
6
+ import os
7
+ import shutil
8
+ import gc
9
+ from contextlib import nullcontext
10
+ from pathlib import Path
11
+ import re
12
+ from safetensors.torch import save_file
13
+
14
+ from PIL import Image
15
+ import numpy as np
16
+ import torch.utils.checkpoint
17
+ import transformers
18
+ from accelerate import Accelerator
19
+ from accelerate.logging import get_logger
20
+ from accelerate.utils import DistributedDataParallelKwargs, ProjectConfiguration, set_seed
21
+
22
+ from tqdm.auto import tqdm
23
+ from transformers import CLIPTokenizer, PretrainedConfig, T5TokenizerFast
24
+
25
+ import diffusers
26
+
27
+ from diffusers import (
28
+ AutoencoderKL,
29
+ FlowMatchEulerDiscreteScheduler
30
+ )
31
+ from diffusers.optimization import get_scheduler
32
+ from diffusers.training_utils import (
33
+ cast_training_params,
34
+ compute_density_for_timestep_sampling,
35
+ compute_loss_weighting_for_sd3,
36
+ )
37
+ import os.path as osp
38
+ from diffusers.utils.torch_utils import is_compiled_module
39
+ from diffusers.utils import (
40
+ check_min_version,
41
+ is_wandb_available,
42
+ convert_unet_state_dict_to_peft
43
+ )
44
+
45
+ from src.lora_helper import *
46
+ from src.pipeline import FluxPipeline, resize_position_encoding, prepare_latent_subject_ids
47
+ from src.layers import MultiDoubleStreamBlockLoraProcessor, MultiSingleStreamBlockLoraProcessor
48
+ from src.transformer_flux import FluxTransformer2DModel
49
+ from src.jsonl_datasets import make_train_dataset, collate_fn
50
+
51
+ if is_wandb_available():
52
+ import wandb
53
+
54
+ # Will error if the minimal version of diffusers is not installed. Remove at your own risks.
55
+ check_min_version("0.31.0.dev0")
56
+
57
+ logger = get_logger(__name__)
58
+
59
+ import matplotlib.pyplot as plt
60
+ import torch
61
+
62
+
63
+ def load_text_encoders(args, class_one, class_two):
64
+ text_encoder_one = class_one.from_pretrained(
65
+ args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision, variant=args.variant
66
+ )
67
+ text_encoder_two = class_two.from_pretrained(
68
+ args.pretrained_model_name_or_path, subfolder="text_encoder_2", revision=args.revision, variant=args.variant
69
+ )
70
+ return text_encoder_one, text_encoder_two
71
+
72
+
73
+ def _encode_prompt_with_t5(
74
+ text_encoder,
75
+ tokenizer,
76
+ max_sequence_length=512,
77
+ prompt=None,
78
+ num_images_per_prompt=1,
79
+ device=None,
80
+ text_input_ids=None,
81
+ ):
82
+ prompt = [prompt] if isinstance(prompt, str) else prompt
83
+ batch_size = len(prompt)
84
+
85
+ if tokenizer is not None:
86
+ text_inputs = tokenizer(
87
+ prompt,
88
+ padding="max_length",
89
+ max_length=max_sequence_length,
90
+ truncation=True,
91
+ return_length=False,
92
+ return_overflowing_tokens=False,
93
+ return_tensors="pt",
94
+ )
95
+ text_input_ids = text_inputs.input_ids
96
+ else:
97
+ if text_input_ids is None:
98
+ raise ValueError("text_input_ids must be provided when the tokenizer is not specified")
99
+
100
+ prompt_embeds = text_encoder(text_input_ids.to(device))[0]
101
+
102
+ if hasattr(text_encoder, "module"):
103
+ dtype = text_encoder.module.dtype
104
+ else:
105
+ dtype = text_encoder.dtype
106
+ prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
107
+
108
+ _, seq_len, _ = prompt_embeds.shape
109
+
110
+ # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method
111
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
112
+ prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
113
+
114
+ return prompt_embeds
115
+
116
+
117
+ def _encode_prompt_with_clip(
118
+ text_encoder,
119
+ tokenizer,
120
+ prompt: str,
121
+ device=None,
122
+ text_input_ids=None,
123
+ num_images_per_prompt: int = 1,
124
+ ):
125
+ prompt = [prompt] if isinstance(prompt, str) else prompt
126
+ batch_size = len(prompt)
127
+
128
+ if tokenizer is not None:
129
+ text_inputs = tokenizer(
130
+ prompt,
131
+ padding="max_length",
132
+ max_length=77,
133
+ truncation=True,
134
+ return_overflowing_tokens=False,
135
+ return_length=False,
136
+ return_tensors="pt",
137
+ )
138
+
139
+ text_input_ids = text_inputs.input_ids
140
+ else:
141
+ if text_input_ids is None:
142
+ raise ValueError("text_input_ids must be provided when the tokenizer is not specified")
143
+
144
+ prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=False)
145
+
146
+ if hasattr(text_encoder, "module"):
147
+ dtype = text_encoder.module.dtype
148
+ else:
149
+ dtype = text_encoder.dtype
150
+ # Use pooled output of CLIPTextModel
151
+ prompt_embeds = prompt_embeds.pooler_output
152
+ prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
153
+
154
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
155
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
156
+ prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, -1)
157
+
158
+ return prompt_embeds
159
+
160
+
161
+ def encode_prompt(
162
+ text_encoders,
163
+ tokenizers,
164
+ prompt: str,
165
+ max_sequence_length,
166
+ device=None,
167
+ num_images_per_prompt: int = 1,
168
+ text_input_ids_list=None,
169
+ ):
170
+ prompt = [prompt] if isinstance(prompt, str) else prompt
171
+
172
+ if hasattr(text_encoders[0], "module"):
173
+ dtype = text_encoders[0].module.dtype
174
+ else:
175
+ dtype = text_encoders[0].dtype
176
+
177
+ pooled_prompt_embeds = _encode_prompt_with_clip(
178
+ text_encoder=text_encoders[0],
179
+ tokenizer=tokenizers[0],
180
+ prompt=prompt,
181
+ device=device if device is not None else text_encoders[0].device,
182
+ num_images_per_prompt=num_images_per_prompt,
183
+ text_input_ids=text_input_ids_list[0] if text_input_ids_list else None,
184
+ )
185
+
186
+ prompt_embeds = _encode_prompt_with_t5(
187
+ text_encoder=text_encoders[1],
188
+ tokenizer=tokenizers[1],
189
+ max_sequence_length=max_sequence_length,
190
+ prompt=prompt,
191
+ num_images_per_prompt=num_images_per_prompt,
192
+ device=device if device is not None else text_encoders[1].device,
193
+ text_input_ids=text_input_ids_list[1] if text_input_ids_list else None,
194
+ )
195
+
196
+ text_ids = torch.zeros(prompt_embeds.shape[1], 3).to(device=device, dtype=dtype)
197
+
198
+ return prompt_embeds, pooled_prompt_embeds, text_ids
199
+
200
+
201
+ def visualize_training_data(batch, vae, model_input, noisy_model_input, cond_input, args, global_step, accelerator):
+     """
+     Visualize training data including all entities from the batch.
+ 
+     Args:
+         batch: Training batch containing data
+         vae: VAE model for decoding latents
+         model_input: Clean latents before adding noise
+         noisy_model_input: Noisy latents passed to transformer
+         cond_input: Spatial condition latents (may be None)
+         args: Training arguments
+         global_step: Current training step
+         accelerator: Accelerator instance
+     """
+ 
+     # Check availability of conditions
+     has_spatial_condition = batch["cond_pixel_values"] is not None
+     has_cuboids_segmasks = "cuboids_segmasks" in batch and batch["cuboids_segmasks"] is not None
+     has_cuboids_segmasks_bev = "cuboids_segmasks_bev" in batch and batch["cuboids_segmasks_bev"] is not None
+ 
+     # Initialize variables
+     spatial_img = None
+ 
+     with torch.no_grad():
+         # Get VAE config for proper decoding
+         vae_config_shift_factor = vae.config.shift_factor
+         vae_config_scaling_factor = vae.config.scaling_factor
+         vae_dtype = vae.dtype
+         vae = vae.to(torch.float32)
+ 
+         # Decode spatial condition if available
+         if has_spatial_condition:
+             cond_for_decode = (cond_input / vae_config_scaling_factor) + vae_config_shift_factor
+             spatial_decoded = vae.decode(cond_for_decode.float()).sample
+             spatial_decoded = (spatial_decoded / 2 + 0.5).clamp(0, 1)  # Normalize to [0,1]
+             spatial_img = spatial_decoded[0].float().cpu().permute(1, 2, 0).numpy()
+ 
+         # Decode clean model input
+         clean_for_decode = (model_input / vae_config_scaling_factor) + vae_config_shift_factor
+         clean_decoded = vae.decode(clean_for_decode.float()).sample
+         clean_decoded = (clean_decoded / 2 + 0.5).clamp(0, 1)
+ 
+         # Decode noisy model input
+         noisy_for_decode = (noisy_model_input / vae_config_scaling_factor) + vae_config_shift_factor
+         noisy_decoded = vae.decode(noisy_for_decode.float()).sample
+         noisy_decoded = (noisy_decoded / 2 + 0.5).clamp(0, 1)
+ 
+         # Convert to CPU and numpy for visualization (take first batch item)
+         clean_img = clean_decoded[0].float().cpu().permute(1, 2, 0).numpy()
+         noisy_img = noisy_decoded[0].float().cpu().permute(1, 2, 0).numpy()
+ 
+         # Get text prompt and other info
+         text_prompt = batch["prompts"][0] if isinstance(batch["prompts"], list) else batch["prompts"]
+         call_id = batch["call_ids"][0] if batch["call_ids"] is not None else "N/A"
+ 
+         # Create figure with more subplots to accommodate all entities including BEV
+         fig, axes = plt.subplots(4, 3, figsize=(18, 24))
+         # fig.suptitle(f'Training Data Visualization - Step {global_step}', fontsize=16)
+ 
+         # Spatial condition (0,0)
+         if has_spatial_condition and spatial_img is not None:
+             axes[0, 0].imshow(spatial_img)
+             axes[0, 0].set_title('Spatial Condition')
+         else:
+             axes[0, 0].text(0.5, 0.5, 'NOT AVAILABLE',
+                             horizontalalignment='center', verticalalignment='center',
+                             transform=axes[0, 0].transAxes, fontsize=14, fontweight='bold')
+             axes[0, 0].set_title('Spatial Condition')
+         axes[0, 0].axis('off')
+ 
+         # Clean model input (0,2)
+         axes[0, 2].imshow(clean_img)
+         axes[0, 2].set_title('Clean Model Input')
+         axes[0, 2].axis('off')
+ 
+         # Noisy model input (1,0)
+         axes[1, 0].imshow(noisy_img)
+         axes[1, 0].set_title('Noisy Model Input')
+         axes[1, 0].axis('off')
+ 
+         # Cuboids segmentation masks with legend (1,1 and 1,2)
+         if has_cuboids_segmasks:
+             segmask = batch["cuboids_segmasks"][0].float().cpu().numpy()  # Shape: (n_subjects, h, w)
+             n_subjects, h, w = segmask.shape
+ 
+             # Only use first 4 subjects for visualization
+             n_subjects_to_show = min(4, n_subjects)
+ 
+             # Create colored segmentation visualization
+             np.random.seed(42)  # For consistent colors
+             colors = np.random.rand(n_subjects_to_show + 1, 3)  # +1 for background
+             colors[0] = [0, 0, 0]  # Background is black
+ 
+             # Create 2x2 grid of individual subject masks
+             grid_h, grid_w = 2, 2
+             combined_mask = np.zeros((h * grid_h, w * grid_w, 3))
+ 
+             for idx in range(n_subjects_to_show):
+                 row = idx // grid_w
+                 col = idx % grid_w
+ 
+                 # Create binary mask for this subject
+                 subject_mask = np.zeros((h, w, 3))
+                 mask = segmask[idx] > 0.5  # Binary threshold
+                 subject_mask[mask] = colors[idx + 1]
+ 
+                 # Place in grid
+                 combined_mask[row*h:(row+1)*h, col*w:(col+1)*w] = subject_mask
+ 
+             axes[1, 1].imshow(combined_mask)
+             axes[1, 1].set_title('Cuboids Segmentation (2x2 Grid)')
+             axes[1, 1].axis('off')
+ 
+             # Create legend in the next subplot (1,2) - only for first 4 subjects
+             axes[1, 2].set_xlim(0, 1)
+             axes[1, 2].set_ylim(0, 1)
+ 
+             # Add legend entries
+             legend_y_positions = np.linspace(0.9, 0.1, n_subjects_to_show + 1)
+             axes[1, 2].text(0.1, legend_y_positions[0], "Background",
+                             color=colors[0], fontsize=12, fontweight='bold')
+ 
+             for subject_idx in range(n_subjects_to_show):
+                 axes[1, 2].text(0.1, legend_y_positions[subject_idx + 1],
+                                 f"Subject {subject_idx}",
+                                 color=colors[subject_idx + 1], fontsize=12, fontweight='bold')
+ 
+             axes[1, 2].set_title('Segmentation Legend (First 4)')
+             axes[1, 2].axis('off')
+         else:
+             axes[1, 1].text(0.5, 0.5, 'NOT AVAILABLE',
+                             horizontalalignment='center', verticalalignment='center',
+                             transform=axes[1, 1].transAxes, fontsize=14, fontweight='bold')
+             axes[1, 1].set_title('Cuboids Segmentation')
+             axes[1, 1].axis('off')
+ 
+             axes[1, 2].text(0.5, 0.5, 'NOT AVAILABLE',
+                             horizontalalignment='center', verticalalignment='center',
+                             transform=axes[1, 2].transAxes, fontsize=14, fontweight='bold')
+             axes[1, 2].set_title('Segmentation Legend')
+             axes[1, 2].axis('off')
+ 
+         # BEV Cuboids segmentation masks with legend (2,0 and 2,1)
+         if has_cuboids_segmasks_bev:
+             segmask_bev = batch["cuboids_segmasks_bev"][0].float().cpu().numpy()  # Shape: (n_subjects, h, w)
+             n_subjects_bev, h_bev, w_bev = segmask_bev.shape
+ 
+             # Create colored segmentation visualization for BEV (use different seed for different colors)
+             np.random.seed(123)  # Different seed for BEV colors
+             colors_bev = np.random.rand(n_subjects_bev + 1, 3)  # +1 for background
+             colors_bev[0] = [0, 0, 0]  # Background is black
+ 
+             # Create RGB image from BEV segmentation
+             colored_segmask_bev = np.zeros((h_bev, w_bev, 3))
+             for subject_idx in range(n_subjects_bev):
+                 mask_bev = segmask_bev[subject_idx] > 0.5  # Binary threshold
+                 colored_segmask_bev[mask_bev] = colors_bev[subject_idx + 1]
+ 
+             axes[2, 0].imshow(colored_segmask_bev)
+             axes[2, 0].set_title('BEV Cuboids Segmentation')
+             axes[2, 0].axis('off')
+ 
+             # Create BEV legend in the next subplot (2,1)
+             axes[2, 1].set_xlim(0, 1)
+             axes[2, 1].set_ylim(0, 1)
+ 
+             # Add BEV legend entries
+             legend_y_positions_bev = np.linspace(0.9, 0.1, n_subjects_bev + 1)
+             axes[2, 1].text(0.1, legend_y_positions_bev[0], "Background",
+                             color=colors_bev[0], fontsize=12, fontweight='bold')
+ 
+             for subject_idx in range(n_subjects_bev):
+                 axes[2, 1].text(0.1, legend_y_positions_bev[subject_idx + 1],
+                                 f"Subject {subject_idx}",
+                                 color=colors_bev[subject_idx + 1], fontsize=12, fontweight='bold')
+ 
+             axes[2, 1].set_title('BEV Segmentation Legend')
+             axes[2, 1].axis('off')
+         else:
+             axes[2, 0].text(0.5, 0.5, 'NOT AVAILABLE',
+                             horizontalalignment='center', verticalalignment='center',
+                             transform=axes[2, 0].transAxes, fontsize=14, fontweight='bold')
+             axes[2, 0].set_title('BEV Cuboids Segmentation')
+             axes[2, 0].axis('off')
+ 
+             axes[2, 1].text(0.5, 0.5, 'NOT AVAILABLE',
+                             horizontalalignment='center', verticalalignment='center',
+                             transform=axes[2, 1].transAxes, fontsize=14, fontweight='bold')
+             axes[2, 1].set_title('BEV Segmentation Legend')
+             axes[2, 1].axis('off')
+ 
+         # Text prompt and call ID (2,2)
+         axes[2, 2].text(0.5, 0.5, f'Text Prompt:\n\n"{text_prompt}"\n\nCall ID: {call_id}',
+                         horizontalalignment='center', verticalalignment='center',
+                         transform=axes[2, 2].transAxes, fontsize=12, wrap=True)
+         axes[2, 2].set_title('Text Prompt & Call ID')
+         axes[2, 2].axis('off')
+ 
+         # Pixel values info (3,0)
+         pixel_info = f'Pixel Values Shape: {batch["pixel_values"].shape}\n'
+         if has_spatial_condition:
+             pixel_info += f'Spatial Shape: {batch["cond_pixel_values"].shape}\n'
+         if has_cuboids_segmasks:
+             pixel_info += f'Cuboids Segmasks: {len(batch["cuboids_segmasks"])}\n'
+         if has_cuboids_segmasks_bev:
+             pixel_info += f'BEV Segmasks: {len(batch["cuboids_segmasks_bev"])}'
+ 
+         axes[3, 0].text(0.5, 0.5, pixel_info,
+                         horizontalalignment='center', verticalalignment='center',
+                         transform=axes[3, 0].transAxes, fontsize=10, fontfamily='monospace')
+         axes[3, 0].set_title('Tensor Shapes')
+         axes[3, 0].axis('off')
+ 
+         # Training info (3,1)
+         training_info = f'Global Step: {global_step}\nConditions:\nSpatial: {"✓" if has_spatial_condition else "✗"}\nSegmasks: {"✓" if has_cuboids_segmasks else "✗"}\nBEV Segmasks: {"✓" if has_cuboids_segmasks_bev else "✗"}'
+         axes[3, 1].text(0.5, 0.5, training_info,
+                         horizontalalignment='center', verticalalignment='center',
+                         transform=axes[3, 1].transAxes, fontsize=12, fontfamily='monospace')
+         axes[3, 1].set_title('Training Info')
+         axes[3, 1].axis('off')
+ 
+         # Additional info (3,2) - can be used for any extra debugging info
+         axes[3, 2].text(0.5, 0.5, 'Additional Info\n(Reserved)',
+                         horizontalalignment='center', verticalalignment='center',
+                         transform=axes[3, 2].transAxes, fontsize=12, fontfamily='monospace')
+         axes[3, 2].set_title('Reserved')
+         axes[3, 2].axis('off')
+ 
+         plt.tight_layout()
+ 
+         # Save the visualization
+         save_dir = os.path.join(args.output_dir, "visualizations")
+         os.makedirs(save_dir, exist_ok=True)
+         save_path = os.path.join(save_dir, f"training_vis_step_{global_step}.png")
+         plt.savefig(save_path, dpi=150, bbox_inches='tight')
+         plt.close()
+ 
+         logger.info(f"Training visualization saved to {save_path}")
+ 
+     vae = vae.to(vae_dtype)
+ 
+ def import_model_class_from_model_name_or_path(
+     pretrained_model_name_or_path: str, revision: str, subfolder: str = "text_encoder"
+ ):
+     text_encoder_config = PretrainedConfig.from_pretrained(
+         pretrained_model_name_or_path, subfolder=subfolder, revision=revision
+     )
+     model_class = text_encoder_config.architectures[0]
+     if model_class == "CLIPTextModel":
+         from transformers import CLIPTextModel
+ 
+         return CLIPTextModel
+     elif model_class == "T5EncoderModel":
+         from transformers import T5EncoderModel
+ 
+         return T5EncoderModel
+     else:
+         raise ValueError(f"{model_class} is not supported.")
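For reference, a minimal sketch of how this dispatcher is typically used; the model id here is illustrative, not taken from this repo:

# hypothetical usage: the class is resolved from the checkpoint's text_encoder config
cls_one = import_model_class_from_model_name_or_path("black-forest-labs/FLUX.1-dev", revision=None, subfolder="text_encoder")    # -> CLIPTextModel
cls_two = import_model_class_from_model_name_or_path("black-forest-labs/FLUX.1-dev", revision=None, subfolder="text_encoder_2")  # -> T5EncoderModel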
+ 
+ 
+ def parse_args(input_args=None):
+     parser = argparse.ArgumentParser(description="Simple example of a training script.")
+     parser.add_argument("--lora_num", type=int, default=2, help="Number of LoRA adapters.")
+     parser.add_argument("--cond_size", type=int, default=512, help="Size of the condition data.")
+     parser.add_argument("--debug", type=int, default=0, help="Whether to enter debug mode -- visualizations, gradient checks, etc.")
+     parser.add_argument("--mode", type=str, default=None, help="The mode of the controller. Choose between ['depth', 'pose', 'canny'].")
+     parser.add_argument("--run_name", type=str, required=True, help="The name of the wandb run.")
+     parser.add_argument(
+         "--train_data_dir",
+         type=str,
+         default="",
+         help=(
+             "A folder containing the training data. Folder contents must follow the structure described in"
+             " https://huggingface.co/docs/datasets/image_dataset#imagefolder. In particular, a `metadata.jsonl` file"
+             " must exist to provide the captions for the images. Ignored if `dataset_name` is specified."
+         ),
+     )
+     parser.add_argument(
+         "--inference_embeds_dir",
+         type=str,
+         default=None,
+         help=(
+             "Directory with precomputed prompt embeddings; if set, the text encoders are not loaded and the"
+             " embeddings are taken from the dataloader instead."
+         ),
+     )
+     parser.add_argument(
+         "--pretrained_model_name_or_path",
+         type=str,
+         default="",
+         required=False,
+         help="Path to pretrained model or model identifier from huggingface.co/models.",
+     )
+     parser.add_argument(
+         "--revision",
+         type=str,
+         default=None,
+         required=False,
+         help="Revision of pretrained model identifier from huggingface.co/models.",
+     )
+     parser.add_argument(
+         "--variant",
+         type=str,
+         default=None,
+         help="Variant of the model files of the pretrained model identifier from huggingface.co/models, e.g. fp16",
+     )
+     parser.add_argument(
+         "--spatial_column",
+         type=str,
+         default="None",
+         help="The column of the dataset containing the spatial condition image. By "
+         "default, the standard Image Dataset maps out 'file_name' "
+         "to 'image'.",
+     )
+     parser.add_argument(
+         "--target_column",
+         type=str,
+         default="image",
+         help="The column of the dataset containing the target image. By "
+         "default, the standard Image Dataset maps out 'file_name' "
+         "to 'image'.",
+     )
+     parser.add_argument(
+         "--caption_column",
+         type=str,
+         default="caption_left,caption_right",
+         help="The column of the dataset containing the instance prompt for each image",
+     )
+     parser.add_argument("--repeats", type=int, default=1, help="How many times to repeat the training data.")
+     parser.add_argument(
+         "--max_sequence_length",
+         type=int,
+         default=512,
+         help="Maximum sequence length to use with the T5 text encoder",
+     )
+     parser.add_argument(
+         "--ranks",
+         type=int,
+         nargs="+",
+         default=[128],
+         help=("The dimension of the LoRA update matrices."),
+     )
+     parser.add_argument(
+         "--network_alphas",
+         type=int,
+         nargs="+",
+         default=[128],
+         help=("The alpha (scaling) values for the LoRA update matrices."),
+     )
+     parser.add_argument(
+         "--output_dir",
+         type=str,
+         required=True,
+         help="The output directory where the model predictions and checkpoints will be written.",
+     )
+     parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")
+     parser.add_argument(
+         "--train_batch_size", type=int, default=1, help="Batch size (per device) for the training dataloader."
+     )
+     parser.add_argument("--num_train_epochs", type=int, default=50)
+     parser.add_argument(
+         "--max_train_steps",
+         type=int,
+         default=None,
+         help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
+     )
+     parser.add_argument(
+         "--checkpointing_steps",
+         type=int,
+         default=1000,
+         help=(
+             "Save a checkpoint of the training state every X updates. These checkpoints can be used both as final"
+             " checkpoints in case they are better than the last checkpoint, and are also suitable for resuming"
+             " training using `--resume_from_checkpoint`."
+         ),
+     )
+     parser.add_argument(
+         "--resume_from_checkpoint",
+         type=str,
+         default=None,
+         help=(
+             "Whether training should be resumed from a previous checkpoint. Use a path saved by"
+             ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.'
+         ),
+     )
+     parser.add_argument(
+         "--gradient_accumulation_steps",
+         type=int,
+         default=1,
+         help="Number of update steps to accumulate before performing a backward/update pass.",
+     )
+     parser.add_argument(
+         "--gradient_checkpointing",
+         action="store_true",
+         help="Whether or not to use gradient checkpointing to save memory at the expense of a slower backward pass.",
+     )
+     parser.add_argument(
+         "--learning_rate",
+         type=float,
+         default=1e-4,
+         help="Initial learning rate (after the potential warmup period) to use.",
+     )
+ 
+     parser.add_argument(
+         "--guidance_scale",
+         type=float,
+         default=1,
+         help="Guidance scale embedding; the FLUX.1-dev variant is a guidance-distilled model.",
+     )
+     parser.add_argument(
+         "--scale_lr",
+         action="store_true",
+         default=False,
+         help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
+     )
+     parser.add_argument(
+         "--lr_scheduler",
+         type=str,
+         default="constant",
+         help=(
+             'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
+             ' "constant", "constant_with_warmup"]'
+         ),
+     )
+     parser.add_argument(
+         "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler."
+     )
+     parser.add_argument(
+         "--lr_num_cycles",
+         type=int,
+         default=1,
+         help="Number of hard resets of the lr in cosine_with_restarts scheduler.",
+     )
+     parser.add_argument("--lr_power", type=float, default=1.0, help="Power factor of the polynomial scheduler.")
+     parser.add_argument(
+         "--dataloader_num_workers",
+         type=int,
+         default=2,
+         help=(
+             "Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process."
+         ),
+     )
+     parser.add_argument(
+         "--weighting_scheme",
+         type=str,
+         default="none",
+         choices=["sigma_sqrt", "logit_normal", "mode", "cosmap", "none"],
+         help=('We default to the "none" weighting scheme for uniform sampling and uniform loss'),
+     )
+     parser.add_argument(
+         "--logit_mean", type=float, default=0.0, help="mean to use when using the `'logit_normal'` weighting scheme."
+     )
+     parser.add_argument(
+         "--logit_std", type=float, default=1.0, help="std to use when using the `'logit_normal'` weighting scheme."
+     )
+     parser.add_argument(
+         "--mode_scale",
+         type=float,
+         default=1.29,
+         help="Scale of mode weighting scheme. Only effective when using `'mode'` as the `weighting_scheme`.",
+     )
+     parser.add_argument(
+         "--optimizer",
+         type=str,
+         default="AdamW",
+         help=('The optimizer type to use. Choose between ["AdamW", "prodigy"]'),
+     )
+ 
+     parser.add_argument(
+         "--use_8bit_adam",
+         action="store_true",
+         help="Whether or not to use 8-bit Adam from bitsandbytes. Ignored if optimizer is not set to AdamW",
+     )
+ 
+     parser.add_argument(
+         "--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam and Prodigy optimizers."
+     )
+     parser.add_argument(
+         "--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam and Prodigy optimizers."
+     )
+     parser.add_argument(
+         "--prodigy_beta3",
+         type=float,
+         default=None,
+         help="Coefficient for computing the Prodigy stepsize using running averages. If set to None, "
+         "uses the square root of beta2. Ignored if optimizer is AdamW.",
+     )
+     parser.add_argument("--prodigy_decouple", type=bool, default=True, help="Use AdamW style decoupled weight decay")
+     parser.add_argument("--adam_weight_decay", type=float, default=1e-04, help="Weight decay to use for unet params")
+     parser.add_argument(
+         "--adam_weight_decay_text_encoder", type=float, default=1e-03, help="Weight decay to use for text_encoder"
+     )
+ 
+     parser.add_argument(
+         "--adam_epsilon",
+         type=float,
+         default=1e-08,
+         help="Epsilon value for the Adam and Prodigy optimizers.",
+     )
+ 
+     parser.add_argument(
+         "--prodigy_use_bias_correction",
+         type=bool,
+         default=True,
+         help="Turn on Adam's bias correction. True by default. Ignored if optimizer is AdamW.",
+     )
+     parser.add_argument(
+         "--prodigy_safeguard_warmup",
+         type=bool,
+         default=True,
+         help="Remove lr from the denominator of the D estimate to avoid issues during the warm-up stage. True by default. "
+         "Ignored if optimizer is AdamW.",
+     )
+     parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
+     parser.add_argument(
+         "--logging_dir",
+         type=str,
+         default="logs",
+         help=(
+             "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
+             " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
+         ),
+     )
+     parser.add_argument(
+         "--report_to",
+         type=str,
+         default="tensorboard",
+         help=(
+             'The integration to report the results and logs to. Supported platforms are `"tensorboard"`'
+             ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.'
+         ),
+     )
+     parser.add_argument(
+         "--mixed_precision",
+         type=str,
+         default="bf16",
+         choices=["no", "fp16", "bf16"],
+         help=(
+             "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
+             " 1.10 and an Nvidia Ampere GPU. Defaults to the value of the accelerate config of the current system or the"
+             " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."
+         ),
+     )
+     parser.add_argument(
+         "--upcast_before_saving",
+         action="store_true",
+         default=False,
+         help=(
+             "Whether to upcast the trained transformer layers to float32 before saving (at the end of training). "
+             "Defaults to the precision dtype used for training, to save memory."
+         ),
+     )
+ 
+     if input_args is not None:
+         args = parser.parse_args(input_args)
+     else:
+         args = parser.parse_args()
+     return args
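A minimal sketch of driving the parser programmatically, e.g. for a smoke test; all values below are placeholders, not paths from this repo:

# hypothetical invocation; only the required flags plus a few common ones
args = parse_args([
    "--run_name", "debug_run",
    "--output_dir", "./outputs",
    "--train_data_dir", "/path/to/train.jsonl",
    "--pretrained_model_name_or_path", "black-forest-labs/FLUX.1-dev",
])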
+ 
+ 
+ def main(args):
+     if torch.backends.mps.is_available() and args.mixed_precision == "bf16":
+         # due to pytorch#99272, MPS does not yet support bfloat16.
+         raise ValueError(
+             "Mixed precision training with bfloat16 is not supported on MPS. Please use fp16 (recommended) or fp32 instead."
+         )
+ 
+     if args.resume_from_checkpoint is not None:
+         assert osp.exists(args.resume_from_checkpoint), f"Make sure that the `resume_from_checkpoint` {args.resume_from_checkpoint} exists."
+         args.pretrained_lora_path = osp.join(args.resume_from_checkpoint, "lora.safetensors")
+         assert osp.exists(args.pretrained_lora_path), f"Make sure that the `pretrained_lora_path` {args.pretrained_lora_path} exists."
+     else:
+         args.pretrained_lora_path = None
+ 
+     args.output_dir = osp.join(args.output_dir, args.run_name)
+     args.logging_dir = osp.join(args.output_dir, args.logging_dir)
+     os.makedirs(args.output_dir, exist_ok=True)
+     os.makedirs(args.logging_dir, exist_ok=True)
+     logging_dir = Path(args.output_dir, args.logging_dir)
+ 
+     if args.spatial_column == "None":
+         args.spatial_column = None
+ 
+     accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir)
+     # kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)
+     accelerator = Accelerator(
+         gradient_accumulation_steps=args.gradient_accumulation_steps,
+         mixed_precision=args.mixed_precision,
+         log_with=args.report_to,
+         project_config=accelerator_project_config,
+         # kwargs_handlers=[kwargs],
+     )
+ 
+     def save_model_hook(models, weights, output_dir):
+         pass
+ 
+     def load_model_hook(models, input_dir):
+         pass
+ 
+     # Disable AMP for MPS.
+     if torch.backends.mps.is_available():
+         accelerator.native_amp = False
+ 
+     if args.report_to == "wandb":
+         if not is_wandb_available():
+             raise ImportError("Make sure to install wandb if you want to use it for logging during training.")
+ 
+     # Make one log on every process with the configuration for debugging.
+     logging.basicConfig(
+         format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
+         datefmt="%m/%d/%Y %H:%M:%S",
+         level=logging.INFO,
+     )
+     logger.info(accelerator.state, main_process_only=False)
+     if accelerator.is_local_main_process:
+         transformers.utils.logging.set_verbosity_warning()
+         diffusers.utils.logging.set_verbosity_info()
+     else:
+         transformers.utils.logging.set_verbosity_error()
+         diffusers.utils.logging.set_verbosity_error()
+ 
+     # If passed along, set the training seed now.
+     if args.seed is not None:
+         set_seed(args.seed)
+ 
+     # Handle the repository creation
+     if accelerator.is_main_process:
+         if args.output_dir is not None:
+             os.makedirs(args.output_dir, exist_ok=True)
+ 
+     # Load the tokenizers
+     tokenizer_one = CLIPTokenizer.from_pretrained(
+         args.pretrained_model_name_or_path,
+         subfolder="tokenizer",
+         revision=args.revision,
+     )
+     tokenizer_two = T5TokenizerFast.from_pretrained(
+         args.pretrained_model_name_or_path,
+         subfolder="tokenizer_2",
+         revision=args.revision,
+     )
+ 
+     # Load scheduler and models
+     noise_scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(
+         args.pretrained_model_name_or_path, subfolder="scheduler"
+     )
+     noise_scheduler_copy = copy.deepcopy(noise_scheduler)
+     gc.collect()
+     torch.cuda.empty_cache()
+ 
+     text_encoder_cls_one = import_model_class_from_model_name_or_path(
+         args.pretrained_model_name_or_path, args.revision, subfolder="text_encoder"
+     )
+     text_encoder_cls_two = import_model_class_from_model_name_or_path(
+         args.pretrained_model_name_or_path, args.revision, subfolder="text_encoder_2"
+     )
+     if args.inference_embeds_dir is None:
+         text_encoder_one, text_encoder_two = load_text_encoders(args, text_encoder_cls_one, text_encoder_cls_two)
+     else:
+         assert osp.exists(args.inference_embeds_dir), f"Make sure that the `inference_embeds_dir` {args.inference_embeds_dir} exists."
+     vae = AutoencoderKL.from_pretrained(
+         args.pretrained_model_name_or_path,
+         subfolder="vae",
+         revision=args.revision,
+         variant=args.variant,
+     )
+     transformer = FluxTransformer2DModel.from_pretrained(
+         args.pretrained_model_name_or_path, subfolder="transformer", revision=args.revision, variant=args.variant
+     )
+ 
+     # We only train the additional adapter LoRA layers
+     transformer.requires_grad_(True)
+     vae.requires_grad_(False)
+     if args.inference_embeds_dir is None:
+         text_encoder_one.requires_grad_(False)
+         text_encoder_two.requires_grad_(False)
+ 
+     # For mixed precision training we cast all non-trainable weights (vae, text_encoder and transformer) to half-precision
+     # as these weights are only used for inference, keeping weights in full precision is not required.
+     weight_dtype = torch.float32
+     if accelerator.mixed_precision == "fp16":
+         weight_dtype = torch.float16
+     elif accelerator.mixed_precision == "bf16":
+         weight_dtype = torch.bfloat16
+ 
+     if torch.backends.mps.is_available() and weight_dtype == torch.bfloat16:
+         # due to pytorch#99272, MPS does not yet support bfloat16.
+         raise ValueError(
+             "Mixed precision training with bfloat16 is not supported on MPS. Please use fp16 (recommended) or fp32 instead."
+         )
+ 
+     vae.to(accelerator.device, dtype=weight_dtype)
+     transformer.to(accelerator.device, dtype=weight_dtype)
+     if args.inference_embeds_dir is None:
+         text_encoder_one.to(accelerator.device, dtype=torch.float32)
+         text_encoder_two.to(accelerator.device, dtype=torch.float32)
+ 
+     if args.gradient_checkpointing:
+         transformer.enable_gradient_checkpointing()
+ 
+     #### lora_layers ####
+     if args.pretrained_lora_path is not None:
+         lora_path = args.pretrained_lora_path
+         checkpoint = load_checkpoint(lora_path)
+         lora_attn_procs = {}
+         double_blocks_idx = list(range(19))
+         single_blocks_idx = list(range(38))
+         number = 1
+         for name, attn_processor in transformer.attn_processors.items():
+             match = re.search(r'\.(\d+)\.', name)
+             if match:
+                 layer_index = int(match.group(1))
+ 
+                 if name.startswith("transformer_blocks") and layer_index in double_blocks_idx:
+                     lora_state_dicts = {}
+                     for key, value in checkpoint.items():
+                         # Match based on the layer index in the key (assuming the key contains the layer index)
+                         if re.search(r'\.(\d+)\.', key):
+                             checkpoint_layer_index = int(re.search(r'\.(\d+)\.', key).group(1))
+                             if checkpoint_layer_index == layer_index and key.startswith("transformer_blocks"):
+                                 lora_state_dicts[key] = value
+ 
+                     print("setting LoRA Processor for", name)
+                     lora_attn_procs[name] = MultiDoubleStreamBlockLoraProcessor(
+                         dim=3072, ranks=args.ranks, network_alphas=args.network_alphas, lora_weights=[1 for _ in range(args.lora_num)], device=accelerator.device, dtype=weight_dtype, cond_width=args.cond_size, cond_height=args.cond_size, n_loras=args.lora_num
+                     )
+ 
+                     # Load the weights from the checkpoint dictionary into the corresponding layers
+                     for n in range(number):
+                         lora_attn_procs[name].q_loras[n].down.weight.data = lora_state_dicts.get(f'{name}.q_loras.{n}.down.weight', None)
+                         lora_attn_procs[name].q_loras[n].up.weight.data = lora_state_dicts.get(f'{name}.q_loras.{n}.up.weight', None)
+                         lora_attn_procs[name].k_loras[n].down.weight.data = lora_state_dicts.get(f'{name}.k_loras.{n}.down.weight', None)
+                         lora_attn_procs[name].k_loras[n].up.weight.data = lora_state_dicts.get(f'{name}.k_loras.{n}.up.weight', None)
+                         lora_attn_procs[name].v_loras[n].down.weight.data = lora_state_dicts.get(f'{name}.v_loras.{n}.down.weight', None)
+                         lora_attn_procs[name].v_loras[n].up.weight.data = lora_state_dicts.get(f'{name}.v_loras.{n}.up.weight', None)
+                         lora_attn_procs[name].proj_loras[n].down.weight.data = lora_state_dicts.get(f'{name}.proj_loras.{n}.down.weight', None)
+                         lora_attn_procs[name].proj_loras[n].up.weight.data = lora_state_dicts.get(f'{name}.proj_loras.{n}.up.weight', None)
+ 
+                 elif name.startswith("single_transformer_blocks") and layer_index in single_blocks_idx:
+                     lora_state_dicts = {}
+                     for key, value in checkpoint.items():
+                         # Match based on the layer index in the key (assuming the key contains the layer index)
+                         if re.search(r'\.(\d+)\.', key):
+                             checkpoint_layer_index = int(re.search(r'\.(\d+)\.', key).group(1))
+                             if checkpoint_layer_index == layer_index and key.startswith("single_transformer_blocks"):
+                                 lora_state_dicts[key] = value
+ 
+                     print("setting LoRA Processor for", name)
+                     lora_attn_procs[name] = MultiSingleStreamBlockLoraProcessor(
+                         dim=3072, ranks=args.ranks, network_alphas=args.network_alphas, lora_weights=[1 for _ in range(args.lora_num)], device=accelerator.device, dtype=weight_dtype, cond_width=args.cond_size, cond_height=args.cond_size, n_loras=args.lora_num
+                     )
+ 
+                     # Load the weights from the checkpoint dictionary into the corresponding layers
+                     for n in range(number):
+                         lora_attn_procs[name].q_loras[n].down.weight.data = lora_state_dicts.get(f'{name}.q_loras.{n}.down.weight', None)
+                         lora_attn_procs[name].q_loras[n].up.weight.data = lora_state_dicts.get(f'{name}.q_loras.{n}.up.weight', None)
+                         lora_attn_procs[name].k_loras[n].down.weight.data = lora_state_dicts.get(f'{name}.k_loras.{n}.down.weight', None)
+                         lora_attn_procs[name].k_loras[n].up.weight.data = lora_state_dicts.get(f'{name}.k_loras.{n}.up.weight', None)
+                         lora_attn_procs[name].v_loras[n].down.weight.data = lora_state_dicts.get(f'{name}.v_loras.{n}.down.weight', None)
+                         lora_attn_procs[name].v_loras[n].up.weight.data = lora_state_dicts.get(f'{name}.v_loras.{n}.up.weight', None)
+                 else:
+                     lora_attn_procs[name] = FluxAttnProcessor2_0()
+     else:
+         lora_attn_procs = {}
+         double_blocks_idx = list(range(19))
+         single_blocks_idx = list(range(38))
+         for name, attn_processor in transformer.attn_processors.items():
+             match = re.search(r'\.(\d+)\.', name)
+             if match:
+                 layer_index = int(match.group(1))
+                 if name.startswith("transformer_blocks") and layer_index in double_blocks_idx:
+                     lora_state_dicts = {}
+                     print("setting LoRA Processor for", name)
+                     lora_attn_procs[name] = MultiDoubleStreamBlockLoraProcessor(
+                         dim=3072, ranks=args.ranks, network_alphas=args.network_alphas, lora_weights=[1 for _ in range(args.lora_num)], device=accelerator.device, dtype=weight_dtype, cond_width=args.cond_size, cond_height=args.cond_size, n_loras=args.lora_num
+                     )
+ 
+                 elif name.startswith("single_transformer_blocks") and layer_index in single_blocks_idx:
+                     print("setting LoRA Processor for", name)
+                     lora_attn_procs[name] = MultiSingleStreamBlockLoraProcessor(
+                         dim=3072, ranks=args.ranks, network_alphas=args.network_alphas, lora_weights=[1 for _ in range(args.lora_num)], device=accelerator.device, dtype=weight_dtype, cond_width=args.cond_size, cond_height=args.cond_size, n_loras=args.lora_num
+                     )
+ 
+                 else:
+                     lora_attn_procs[name] = attn_processor
+     ######################
+     transformer.set_attn_processor(lora_attn_procs)
+     transformer.train()
+     for n, param in transformer.named_parameters():
+         if '_lora' not in n:
+             param.requires_grad = False
+     print(sum([p.numel() for p in transformer.parameters() if p.requires_grad]) / 1000000, 'M parameters')
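The processor assignment above keys off the block index embedded in each attention processor name; a small illustration of the regex involved (the name below is illustrative):

import re
name = "transformer_blocks.7.attn.processor"  # illustrative processor name
layer_index = int(re.search(r'\.(\d+)\.', name).group(1))  # -> 7, routed to a double-stream LoRA processor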
+ 
+     def unwrap_model(model):
+         model = accelerator.unwrap_model(model)
+         model = model._orig_mod if is_compiled_module(model) else model
+         return model
+ 
+     # Potentially load in the weights and states from a previous save
+     if args.resume_from_checkpoint:
+         foldername = osp.basename(args.resume_from_checkpoint)
+         first_epoch = epoch = int(foldername.split("-")[1].split("__")[0])
+         initial_global_step = global_step = int(foldername.split("-")[-1])
+     else:
+         initial_global_step = 0
+         global_step = 0
+         first_epoch = 0
+ 
+     if args.scale_lr:
+         args.learning_rate = (
+             args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes
+         )
+ 
+     # Make sure the trainable params are in float32.
+     if args.mixed_precision == "fp16":
+         models = [transformer]
+         # only upcast trainable parameters (LoRA) into fp32
+         cast_training_params(models, dtype=torch.float32)
+ 
+     # Optimization parameters
+     params_to_optimize = [p for p in transformer.parameters() if p.requires_grad]
+     transformer_parameters_with_lr = {"params": params_to_optimize, "lr": args.learning_rate}
+     print(sum([p.numel() for p in transformer.parameters() if p.requires_grad]) / 1000000, 'M parameters')
+ 
+     optimizer_class = torch.optim.AdamW
+     optimizer = optimizer_class(
+         [transformer_parameters_with_lr],
+         betas=(args.adam_beta1, args.adam_beta2),
+         weight_decay=args.adam_weight_decay,
+         eps=args.adam_epsilon,
+     )
+ 
+     tokenizers = [tokenizer_one, tokenizer_two]
+ 
+     # now, we will define a dataset for each epoch to make it easier to save the state
+     shuffled_jsonls = os.listdir(osp.dirname(args.train_data_dir))
+     base_jsonl_name = osp.basename(args.train_data_dir).replace(".jsonl", "")
+     shuffled_jsonls = sorted([_ for _ in shuffled_jsonls if _.endswith('.jsonl') and "shuffled" in _ and base_jsonl_name in _])
+     shuffled_jsonls = [osp.join(osp.dirname(args.train_data_dir), _) for _ in shuffled_jsonls]
+     print(f"{shuffled_jsonls = }")
+     assert len(shuffled_jsonls) > 0, f"Make sure that there are shuffled jsonl files in {osp.dirname(args.train_data_dir)}"
+     train_dataloaders = []
+     for epoch in range(args.num_train_epochs):  # prepare a dataloader for each epoch, irrespective of the resume state
+         shuffled_idx = epoch % len(shuffled_jsonls)
+         train_data_file = shuffled_jsonls[shuffled_idx]
+         assert osp.exists(train_data_file), f"Make sure that the train data jsonl file {train_data_file} exists."
+         args.current_train_data_dir = train_data_file
+         train_dataset = make_train_dataset(args, tokenizers, accelerator)
+         train_dataloader = torch.utils.data.DataLoader(
+             train_dataset,
+             batch_size=args.train_batch_size,
+             shuffle=False,  # yayy!! reproducible experiments!
+             collate_fn=collate_fn,
+             num_workers=args.dataloader_num_workers,
+         )
+         train_dataloaders.append(train_dataloader)
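The per-epoch files are discovered purely by filename; assuming a sibling layout like train_shuffled_000.jsonl, train_shuffled_001.jsonl next to train.jsonl (this naming is inferred from the filter above, not documented), the epoch-to-file mapping cycles:

# e.g. 3 pre-shuffled files -> epochs 0,1,2,3,4,... read files 0,1,2,0,1,...
shuffled_idx = epoch % len(shuffled_jsonls)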
+ 
+     vae_config_shift_factor = vae.config.shift_factor
+     vae_config_scaling_factor = vae.config.scaling_factor
+ 
+     # Scheduler and math around the number of training steps.
+     overrode_max_train_steps = False
+     num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
+     if args.max_train_steps is None:
+         args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
+         overrode_max_train_steps = True
+ 
+     lr_scheduler = get_scheduler(
+         args.lr_scheduler,
+         optimizer=optimizer,
+         num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,
+         num_training_steps=args.max_train_steps * accelerator.num_processes,
+         num_cycles=args.lr_num_cycles,
+         power=args.lr_power,
+     )
+ 
+     accelerator.register_save_state_pre_hook(save_model_hook)
+     accelerator.register_load_state_pre_hook(load_model_hook)
+     optimizer, lr_scheduler = accelerator.prepare(
+         optimizer, lr_scheduler
+     )
+ 
+     print(f"before preparation, {len(train_dataloaders[0]) = }")
+ 
+     prepared_train_dataloaders = []
+     for train_dataloader in train_dataloaders:
+         prepared_train_dataloaders.append(accelerator.prepare(train_dataloader))
+     train_dataloaders = prepared_train_dataloaders
+ 
+     print(f"after preparation, {len(train_dataloaders[0]) = }")
+ 
+     if args.pretrained_lora_path is not None:
+         accelerator.load_state(osp.dirname(args.pretrained_lora_path))
+ 
+     # Explicitly move optimizer states to accelerator.device
+     for state in optimizer.state.values():
+         for k, v in state.items():
+             if isinstance(v, torch.Tensor):
+                 state[k] = v.to(accelerator.device)
+ 
+     transformer = accelerator.prepare(transformer)
+ 
+     # We need to recalculate our total training steps as the size of the training dataloader may have changed.
+     num_update_steps_per_epoch = math.ceil(len(train_dataloaders[0]) / args.gradient_accumulation_steps)
+     if overrode_max_train_steps:
+         args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
+     # Afterwards we recalculate our number of training epochs
+     args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
+ 
+     # We need to initialize the trackers we use, and also store our configuration.
+     if accelerator.is_main_process:
+         accelerator.init_trackers(args.run_name)
+ 
+     # Train!
+     total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
+ 
+     logger.info("***** Running training *****")
+     logger.info(f"  Num examples = {len(train_dataset)}")
+     logger.info(f"  Num batches each epoch = {len(train_dataloader)}")
+     logger.info(f"  Num Epochs = {args.num_train_epochs}")
+     logger.info(f"  Instantaneous batch size per device = {args.train_batch_size}")
+     logger.info(f"  Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
+     logger.info(f"  Gradient Accumulation steps = {args.gradient_accumulation_steps}")
+     logger.info(f"  Total optimization steps = {args.max_train_steps}")
+ 
+     progress_bar = tqdm(
+         range(0, args.max_train_steps),
+         initial=initial_global_step,
+         desc="Steps",
+         # Only show the progress bar once on each machine.
+         disable=not accelerator.is_local_main_process,
+     )
+ 
+     def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
+         sigmas = noise_scheduler_copy.sigmas.to(device=accelerator.device, dtype=dtype)
+         schedule_timesteps = noise_scheduler_copy.timesteps.to(accelerator.device)
+         timesteps = timesteps.to(accelerator.device)
+         step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps]
+ 
+         sigma = sigmas[step_indices].flatten()
+         while len(sigma.shape) < n_dim:
+             sigma = sigma.unsqueeze(-1)
+         return sigma
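A small sanity sketch of what get_sigmas returns; shapes only, since the actual values come from the scheduler's sigma table:

# sigmas are looked up per sampled timestep, then unsqueezed so they
# broadcast against (B, C, H, W) latents
sigmas = get_sigmas(timesteps, n_dim=4)
assert sigmas.shape == (timesteps.shape[0], 1, 1, 1)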
+ 
+     # some fixed parameters
+     vae_scale_factor = 16
+     height_cond = 2 * (args.cond_size // vae_scale_factor)
+     width_cond = 2 * (args.cond_size // vae_scale_factor)
+     offset = 64
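Worked numbers for the defaults: with cond_size = 512 and vae_scale_factor = 16, height_cond = width_cond = 2 * (512 // 16) = 64 positions per side in the packed latent grid.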
+ 
+     num_training_visualizations = 10
+ 
+     skip_steps = initial_global_step - first_epoch * num_update_steps_per_epoch
+     print(f"{skip_steps = }")
+     for epoch in range(first_epoch, args.num_train_epochs):
+         transformer.train()
+         train_dataloader = train_dataloaders[epoch]  # use a new dataloader for each epoch
+         if epoch == first_epoch and skip_steps > 0:
+             logger.info(f"Skipping {skip_steps} batches in epoch {epoch} due to resuming from checkpoint")
+             # dataloader_iterator = skip_first_batches_manual(train_dataloader, skip_steps)
+             dataloader_iterator = accelerator.skip_first_batches(train_dataloader, skip_steps)
+             # Convert back to enumerate format
+             enumerated_dataloader = enumerate(dataloader_iterator, start=skip_steps)
+         else:
+             enumerated_dataloader = enumerate(train_dataloader)
+         for step, batch in enumerated_dataloader:
+             progress_bar.set_description(f"epoch {epoch}, dataset_ids: {batch['index']}")
+             models_to_accumulate = [transformer]
+             with accelerator.accumulate(models_to_accumulate):
+ 
+                 if args.inference_embeds_dir is None:
+                     print(f"encoding {batch['prompts'] = }")
+                     # prompt_embeds, pooled_prompt_embeds, text_ids = text_encoding_pipeline.encode_prompt(
+                     #     prompt=batch["prompts"],
+                     #     prompt_2=batch["prompts"],
+                     # )
+                     prompt_embeds, pooled_prompt_embeds, text_ids = encode_prompt(
+                         text_encoders=[text_encoder_one, text_encoder_two],
+                         tokenizers=[tokenizer_one, tokenizer_two],
+                         prompt=batch["prompts"],
+                         max_sequence_length=512,
+                         device=accelerator.device,
+                     )
+                     for i, prompt in enumerate(batch["prompts"]):
+                         # prompt_embeds, pooled_prompt_embeds, text_ids = encode_prompt(
+                         #     text_encoders=[text_encoder_one, text_encoder_two],
+                         #     tokenizers=[tokenizer_one, tokenizer_two],
+                         #     prompt=prompt,
+                         #     max_sequence_length=512,
+                         #     device=accelerator.device,
+                         # )
+                         print(f"{prompt_embeds.shape = }, {pooled_prompt_embeds.shape = }, {text_ids.shape = }")
+                         # checking if the cached embeddings match
+                         inference_embeds_dir = "/archive/vaibhav.agrawal/a-bev-of-the-latents/inference_embeds_datasetv7_superhard"
+                         cached_prompt_path = osp.join(inference_embeds_dir, f"{'_'.join(prompt.lower().split())}.pth")
+                         assert osp.exists(cached_prompt_path), f"Make sure that the cached prompt embedding {cached_prompt_path} exists."
+                         cached_prompt_embeds = torch.load(cached_prompt_path, map_location="cpu")
+                         assert torch.allclose(cached_prompt_embeds["prompt_embeds"].cpu().float(), prompt_embeds[i].cpu().float(), atol=1e-3), f"Cached prompt embeds for prompt {prompt} do not match the computed prompt embeds. Make sure that the cached prompt embeds are correct., {torch.mean(torch.abs(cached_prompt_embeds['prompt_embeds'].cpu().float() - prompt_embeds[i].cpu().float())) = }, {torch.mean(torch.abs(cached_prompt_embeds['prompt_embeds'].cpu().float())) = }"
+                         assert torch.allclose(cached_prompt_embeds["pooled_prompt_embeds"].cpu().float(), pooled_prompt_embeds[i].cpu().float(), atol=1e-3), f"Cached pooled prompt embeds for prompt {prompt} do not match the computed pooled prompt embeds. Make sure that the cached pooled prompt embeds are correct., {torch.mean(torch.abs(cached_prompt_embeds['pooled_prompt_embeds'].cpu().float() - pooled_prompt_embeds[i].cpu().float())) = }"
+                         print("cached prompt embeddings match the freshly computed ones")
+                 else:
+                     assert "prompt_embeds" in batch and "pooled_prompt_embeds" in batch, "Make sure that the dataloader returns `prompt_embeds` and `pooled_prompt_embeds` when `inference_embeds_dir` is not None."
+                     prompt_embeds = batch["prompt_embeds"]
+                     pooled_prompt_embeds = batch["pooled_prompt_embeds"]
+                     text_ids = torch.zeros((batch["prompt_embeds"].shape[1], 3))
+                 prompt_embeds = prompt_embeds.to(dtype=vae.dtype, device=accelerator.device)
+                 pooled_prompt_embeds = pooled_prompt_embeds.to(dtype=vae.dtype, device=accelerator.device)
+                 text_ids = text_ids.to(dtype=vae.dtype, device=accelerator.device)
+ 
+                 pixel_values = batch["pixel_values"].to(dtype=vae.dtype)
+                 height_ = 2 * (int(pixel_values.shape[-2]) // vae_scale_factor)
+                 width_ = 2 * (int(pixel_values.shape[-1]) // vae_scale_factor)
+ 
+                 model_input = vae.encode(pixel_values).latent_dist.sample()
+                 model_input = (model_input - vae_config_shift_factor) * vae_config_scaling_factor
+                 model_input = model_input.to(dtype=weight_dtype)
+ 
+                 latent_image_ids, cond_latent_image_ids = resize_position_encoding(
+                     model_input.shape[0],
+                     height_,
+                     width_,
+                     height_cond,
+                     width_cond,
+                     accelerator.device,
+                     weight_dtype,
+                 )
+ 
+                 # Sample noise that we'll add to the latents
+                 noise = torch.randn_like(model_input)
+                 bsz = model_input.shape[0]
+ 
+                 # Sample a random timestep for each image
+                 # for weighting schemes where we sample timesteps non-uniformly
+                 u = compute_density_for_timestep_sampling(
+                     weighting_scheme=args.weighting_scheme,
+                     batch_size=bsz,
+                     logit_mean=args.logit_mean,
+                     logit_std=args.logit_std,
+                     mode_scale=args.mode_scale,
+                 )
+                 indices = (u * noise_scheduler_copy.config.num_train_timesteps).long()
+                 timesteps = noise_scheduler_copy.timesteps[indices].to(device=model_input.device)
+ 
+                 # Add noise according to flow matching.
+                 # zt = (1 - texp) * x + texp * z1
+                 sigmas = get_sigmas(timesteps, n_dim=model_input.ndim, dtype=model_input.dtype)
+                 noisy_model_input = (1.0 - sigmas) * model_input + sigmas * noise
+ 
+                 packed_noisy_model_input = FluxPipeline._pack_latents(
+                     noisy_model_input,
+                     batch_size=model_input.shape[0],
+                     num_channels_latents=model_input.shape[1],
+                     height=model_input.shape[2],
+                     width=model_input.shape[3],
+                 )
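FluxPipeline._pack_latents turns the (B, C, H, W) latent grid into a token sequence by folding each 2x2 patch into the channel dimension; a minimal equivalent sketch (this mirrors the diffusers implementation as I recall it, so treat it as illustrative):

def pack_latents_sketch(latents):
    # (B, C, H, W) -> (B, (H/2)*(W/2), C*4)
    b, c, h, w = latents.shape
    x = latents.view(b, c, h // 2, 2, w // 2, 2)
    x = x.permute(0, 2, 4, 1, 3, 5)
    return x.reshape(b, (h // 2) * (w // 2), c * 4)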
+ 
+                 latent_image_ids_to_concat = [latent_image_ids]
+                 packed_cond_model_input_to_concat = []
+ 
+                 if args.spatial_column is not None:
+                     # in case the condition is spatial
+                     cond_pixel_values = batch["cond_pixel_values"].to(dtype=vae.dtype)
+                     cond_input = vae.encode(cond_pixel_values).latent_dist.sample()
+                     cond_input = (cond_input - vae_config_shift_factor) * vae_config_scaling_factor
+                     cond_input = cond_input.to(dtype=weight_dtype)
+                     # number of conditions in the concatenated condition image
+                     cond_number = cond_pixel_values.shape[-2] // args.cond_size
+                     cond_latent_image_ids = torch.concat([cond_latent_image_ids for _ in range(cond_number)], dim=-2)
+                     latent_image_ids_to_concat.append(cond_latent_image_ids)
+ 
+                     packed_cond_model_input = FluxPipeline._pack_latents(
+                         cond_input,
+                         batch_size=cond_input.shape[0],
+                         num_channels_latents=cond_input.shape[1],
+                         height=cond_input.shape[2],
+                         width=cond_input.shape[3],
+                     )
+                     packed_cond_model_input_to_concat.append(packed_cond_model_input)
+                 else:
+                     cond_input = None
+ 
+                 latent_image_ids = torch.concat(latent_image_ids_to_concat, dim=-2)
+                 cond_packed_noisy_model_input = torch.concat(packed_cond_model_input_to_concat, dim=-2)
+ 
+                 # handle guidance
+                 if accelerator.unwrap_model(transformer).config.guidance_embeds:
+                     guidance = torch.tensor([args.guidance_scale], device=accelerator.device)
+                     guidance = guidance.expand(model_input.shape[0])
+                 else:
+                     guidance = None
+ 
+                 # Visualize training data before the transformer forward pass
+                 if accelerator.is_main_process and args.debug and num_training_visualizations > 0 and global_step % 5 == 0:
+                     visualize_training_data(
+                         batch=batch,
+                         vae=vae,
+                         model_input=model_input,
+                         noisy_model_input=noisy_model_input,
+                         cond_input=cond_input,
+                         args=args,
+                         global_step=global_step,
+                         accelerator=accelerator
+                     )
+                     num_training_visualizations -= 1
+ 
+                 # Predict the noise residual
+                 model_pred = transformer(
+                     hidden_states=packed_noisy_model_input,
+                     # YiYi notes: divide it by 1000 for now because we scale it by 1000 in the transformer model (we should not keep it but I want to keep the inputs the same for the model for testing)
+                     cond_hidden_states=cond_packed_noisy_model_input,
+                     timestep=timesteps / 1000,
+                     guidance=guidance,
+                     pooled_projections=pooled_prompt_embeds,
+                     encoder_hidden_states=prompt_embeds,
+                     txt_ids=text_ids,
+                     img_ids=latent_image_ids,
+                     return_dict=False,
+                     call_ids=batch["call_ids"],
+                     cuboids_segmasks=batch["cuboids_segmasks"],
+                 )[0]
+ 
+                 model_pred = FluxPipeline._unpack_latents(
+                     model_pred,
+                     height=int(pixel_values.shape[-2]),
+                     width=int(pixel_values.shape[-1]),
+                     vae_scale_factor=vae_scale_factor,
+                 )
+ 
+                 # these weighting schemes use a uniform timestep sampling
+                 # and instead post-weight the loss
+                 weighting = compute_loss_weighting_for_sd3(weighting_scheme=args.weighting_scheme, sigmas=sigmas)
+ 
+                 # flow matching loss
+                 target = noise - model_input
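Why this is the target: with the interpolation above, z_sigma = (1 - sigma) * x + sigma * eps, so d z_sigma / d sigma = eps - x, i.e. the model regresses the constant velocity noise - model_input.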
+ 
+                 # Compute regular loss.
+                 loss = torch.mean(
+                     (weighting.float() * (model_pred.float() - target.float()) ** 2).reshape(target.shape[0], -1),
+                     1,
+                 )
+ 
+                 loss = loss.mean()
+                 accelerator.backward(loss)
+                 if accelerator.sync_gradients:
+                     params_to_clip = transformer.parameters()
+                     accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
+ 
+                 optimizer.step()
+                 lr_scheduler.step()
+                 optimizer.zero_grad()
+ 
+             # Checks if the accelerator has performed an optimization step behind the scenes
+             if accelerator.sync_gradients:
+                 progress_bar.update(1)
+                 global_step += 1
+ 
+                 if accelerator.is_main_process:
+                     if global_step % args.checkpointing_steps == 0:
+                         # _before_ saving state, check if this save would set us over the `checkpoints_total_limit`
+                         save_path = os.path.join(args.output_dir, f"epoch-{epoch}__checkpoint-{global_step}")
+                         os.makedirs(save_path, exist_ok=True)
+                         unwrapped_model_state = accelerator.unwrap_model(transformer).state_dict()
+                         lora_state_dict = {k: unwrapped_model_state[k] for k in unwrapped_model_state.keys() if '_lora' in k}
+                         save_file(
+                             lora_state_dict,
+                             os.path.join(save_path, "lora.safetensors")
+                         )
+                         accelerator.save_state(save_path)
+                         os.remove(osp.join(save_path, "model.safetensors"))
+                         logger.info(f"Saved state to {save_path}")
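A checkpoint directory written here can be fed back via --resume_from_checkpoint; it is expected to contain lora.safetensors plus the accelerate state, and epoch/step are parsed back out of the epoch-{e}__checkpoint-{s} folder name (see the resume logic near the top of main).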
+ 
+             logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
+             progress_bar.set_postfix(**logs)
+             accelerator.log(logs, step=global_step)
+ 
+     save_path = os.path.join(args.output_dir, f"epoch-{epoch}__checkpoint-{global_step}")
+     os.makedirs(save_path, exist_ok=True)
+     unwrapped_model_state = accelerator.unwrap_model(transformer).state_dict()
+     lora_state_dict = {k: unwrapped_model_state[k] for k in unwrapped_model_state.keys() if '_lora' in k}
+     save_file(
+         lora_state_dict,
+         os.path.join(save_path, "lora.safetensors")
+     )
+     accelerator.save_state(save_path)
+     os.remove(osp.join(save_path, "model.safetensors"))
+     logger.info(f"Saved state to {save_path}")
+     accelerator.wait_for_everyone()
+     accelerator.end_training()
+ 
+ 
+ if __name__ == "__main__":
+     args = parse_args()
+     main(args)
train/train_working.py ADDED
@@ -0,0 +1,1397 @@
1
+ import argparse
2
+ import copy
3
+ import logging
4
+ import random
5
+ import math
6
+ import os
7
+ import shutil
8
+ import gc
9
+ from contextlib import nullcontext
10
+ from pathlib import Path
11
+ import re
12
+ from safetensors.torch import save_file
13
+
14
+ from PIL import Image
15
+ import numpy as np
16
+ import torch.utils.checkpoint
17
+ import transformers
18
+ from accelerate import Accelerator
19
+ from accelerate.logging import get_logger
20
+ from accelerate.utils import DistributedDataParallelKwargs, ProjectConfiguration, set_seed
21
+
22
+ from tqdm.auto import tqdm
23
+ from transformers import CLIPTokenizer, PretrainedConfig, T5TokenizerFast
24
+
25
+ import diffusers
26
+
27
+ from diffusers import (
28
+ AutoencoderKL,
29
+ FlowMatchEulerDiscreteScheduler
30
+ )
31
+ from diffusers.optimization import get_scheduler
32
+ from diffusers.training_utils import (
33
+ cast_training_params,
34
+ compute_density_for_timestep_sampling,
35
+ compute_loss_weighting_for_sd3,
36
+ )
37
+ import os.path as osp
38
+ from diffusers.utils.torch_utils import is_compiled_module
39
+ from diffusers.utils import (
40
+ check_min_version,
41
+ is_wandb_available,
42
+ convert_unet_state_dict_to_peft
43
+ )
44
+
45
+ from src.lora_helper import *
46
+ from src.pipeline import FluxPipeline, resize_position_encoding, prepare_latent_subject_ids
47
+ from src.layers import MultiDoubleStreamBlockLoraProcessor, MultiSingleStreamBlockLoraProcessor
48
+ from src.transformer_flux import FluxTransformer2DModel
49
+ from src.jsonl_datasets import make_train_dataset, collate_fn
50
+
51
+ if is_wandb_available():
52
+ import wandb
53
+
54
+ # Will error if the minimal version of diffusers is not installed. Remove at your own risks.
55
+ check_min_version("0.31.0.dev0")
56
+
57
+ logger = get_logger(__name__)
58
+
59
+ import matplotlib.pyplot as plt
60
+ import torch
61
+
62
+
63
+ def load_text_encoders(args, class_one, class_two):
64
+ text_encoder_one = class_one.from_pretrained(
65
+ args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision, variant=args.variant
66
+ )
67
+ text_encoder_two = class_two.from_pretrained(
68
+ args.pretrained_model_name_or_path, subfolder="text_encoder_2", revision=args.revision, variant=args.variant
69
+ )
70
+ return text_encoder_one, text_encoder_two
71
+
72
+
73
+ def _encode_prompt_with_t5(
74
+ text_encoder,
75
+ tokenizer,
76
+ max_sequence_length=512,
77
+ prompt=None,
78
+ num_images_per_prompt=1,
79
+ device=None,
80
+ text_input_ids=None,
81
+ ):
82
+ prompt = [prompt] if isinstance(prompt, str) else prompt
83
+ batch_size = len(prompt)
84
+
85
+ if tokenizer is not None:
86
+ text_inputs = tokenizer(
87
+ prompt,
88
+ padding="max_length",
89
+ max_length=max_sequence_length,
90
+ truncation=True,
91
+ return_length=False,
92
+ return_overflowing_tokens=False,
93
+ return_tensors="pt",
94
+ )
95
+ text_input_ids = text_inputs.input_ids
96
+ else:
97
+ if text_input_ids is None:
98
+ raise ValueError("text_input_ids must be provided when the tokenizer is not specified")
99
+
100
+ prompt_embeds = text_encoder(text_input_ids.to(device))[0]
101
+
102
+ if hasattr(text_encoder, "module"):
103
+ dtype = text_encoder.module.dtype
104
+ else:
105
+ dtype = text_encoder.dtype
106
+ prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
107
+
108
+ _, seq_len, _ = prompt_embeds.shape
109
+
110
+ # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method
111
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
112
+ prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
113
+
114
+ return prompt_embeds
115
+
116
+
117
+ def _encode_prompt_with_clip(
118
+ text_encoder,
119
+ tokenizer,
120
+ prompt: str,
121
+ device=None,
122
+ text_input_ids=None,
123
+ num_images_per_prompt: int = 1,
124
+ ):
125
+ prompt = [prompt] if isinstance(prompt, str) else prompt
126
+ batch_size = len(prompt)
127
+
128
+ if tokenizer is not None:
129
+ text_inputs = tokenizer(
130
+ prompt,
131
+ padding="max_length",
132
+ max_length=77,
133
+ truncation=True,
134
+ return_overflowing_tokens=False,
135
+ return_length=False,
136
+ return_tensors="pt",
137
+ )
138
+
139
+ text_input_ids = text_inputs.input_ids
140
+ else:
141
+ if text_input_ids is None:
142
+ raise ValueError("text_input_ids must be provided when the tokenizer is not specified")
143
+
144
+ prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=False)
145
+
146
+ if hasattr(text_encoder, "module"):
147
+ dtype = text_encoder.module.dtype
148
+ else:
149
+ dtype = text_encoder.dtype
150
+ # Use pooled output of CLIPTextModel
151
+ prompt_embeds = prompt_embeds.pooler_output
152
+ prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
153
+
154
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
155
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
156
+ prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, -1)
157
+
158
+ return prompt_embeds
159
+
160
+
161
+ def encode_prompt(
162
+ text_encoders,
163
+ tokenizers,
164
+ prompt: str,
165
+ max_sequence_length,
166
+ device=None,
167
+ num_images_per_prompt: int = 1,
168
+ text_input_ids_list=None,
169
+ ):
170
+ prompt = [prompt] if isinstance(prompt, str) else prompt
171
+
172
+ if hasattr(text_encoders[0], "module"):
173
+ dtype = text_encoders[0].module.dtype
174
+ else:
175
+ dtype = text_encoders[0].dtype
176
+
177
+ pooled_prompt_embeds = _encode_prompt_with_clip(
178
+ text_encoder=text_encoders[0],
179
+ tokenizer=tokenizers[0],
180
+ prompt=prompt,
181
+ device=device if device is not None else text_encoders[0].device,
182
+ num_images_per_prompt=num_images_per_prompt,
183
+ text_input_ids=text_input_ids_list[0] if text_input_ids_list else None,
184
+ )
185
+
186
+ prompt_embeds = _encode_prompt_with_t5(
187
+ text_encoder=text_encoders[1],
188
+ tokenizer=tokenizers[1],
189
+ max_sequence_length=max_sequence_length,
190
+ prompt=prompt,
191
+ num_images_per_prompt=num_images_per_prompt,
192
+ device=device if device is not None else text_encoders[1].device,
193
+ text_input_ids=text_input_ids_list[1] if text_input_ids_list else None,
194
+ )
195
+
196
+ text_ids = torch.zeros(prompt_embeds.shape[1], 3).to(device=device, dtype=dtype)
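+ # The all-zero `text_ids` above follow the FluxPipeline convention: text tokens carry no
+ # spatial position, so their (seq_len, 3) rotary-position ids are zeros, mirroring the
+ # 3-component coordinate format of the packed image ids.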
197
+
198
+ return prompt_embeds, pooled_prompt_embeds, text_ids
199
+
200
+
201
+ def visualize_training_data(batch, vae, model_input, noisy_model_input, cond_input, args, global_step, accelerator):
202
+ """
203
+ Visualize training data including all entities from the batch.
204
+
205
+ Args:
206
+ batch: Training batch containing data
207
+ vae: VAE model for decoding latents
208
+ model_input: Clean latents before adding noise
209
+ noisy_model_input: Noisy latents passed to transformer
210
+ cond_input: Spatial condition latents (may be None)
211
+ args: Training arguments
212
+ global_step: Current training step
213
+ accelerator: Accelerator instance
214
+ """
215
+
216
+ # Check availability of conditions
217
+ has_spatial_condition = batch["cond_pixel_values"] is not None
218
+ has_cuboids_segmasks = "cuboids_segmasks" in batch and batch["cuboids_segmasks"] is not None
219
+ has_cuboids_segmasks_bev = "cuboids_segmasks_bev" in batch and batch["cuboids_segmasks_bev"] is not None
220
+
221
+ # Initialize variables
222
+ spatial_img = None
223
+
224
+ with torch.no_grad():
225
+ # Get VAE config for proper decoding
226
+ vae_config_shift_factor = vae.config.shift_factor
227
+ vae_config_scaling_factor = vae.config.scaling_factor
228
+ vae_dtype = vae.dtype
229
+ vae = vae.to(torch.float32)
230
+
231
+ # Decode spatial condition if available
232
+ if has_spatial_condition:
233
+ cond_for_decode = (cond_input / vae_config_scaling_factor) + vae_config_shift_factor
234
+ spatial_decoded = vae.decode(cond_for_decode.float()).sample
235
+ spatial_decoded = (spatial_decoded / 2 + 0.5).clamp(0, 1) # Normalize to [0,1]
236
+ spatial_img = spatial_decoded[0].float().cpu().permute(1, 2, 0).numpy()
237
+
238
+ # Decode clean model input
239
+ clean_for_decode = (model_input / vae_config_scaling_factor) + vae_config_shift_factor
240
+ clean_decoded = vae.decode(clean_for_decode.float()).sample
241
+ clean_decoded = (clean_decoded / 2 + 0.5).clamp(0, 1)
242
+
243
+ # Decode noisy model input
244
+ noisy_for_decode = (noisy_model_input / vae_config_scaling_factor) + vae_config_shift_factor
245
+ noisy_decoded = vae.decode(noisy_for_decode.float()).sample
246
+ noisy_decoded = (noisy_decoded / 2 + 0.5).clamp(0, 1)
247
+
248
+ # Convert to CPU and numpy for visualization (take first batch item)
249
+ clean_img = clean_decoded[0].float().cpu().permute(1, 2, 0).numpy()
250
+ noisy_img = noisy_decoded[0].float().cpu().permute(1, 2, 0).numpy()
251
+
252
+ # Get text prompt and other info
253
+ text_prompt = batch["prompts"][0] if isinstance(batch["prompts"], list) else batch["prompts"]
254
+ call_id = batch["call_ids"][0] if batch["call_ids"] is not None else "N/A"
255
+
256
+ # Create figure with more subplots to accommodate all entities including BEV
257
+ fig, axes = plt.subplots(4, 3, figsize=(18, 24))
258
+ # fig.suptitle(f'Training Data Visualization - Step {global_step}', fontsize=16)
259
+
260
+ # Spatial condition (0,0)
261
+ if has_spatial_condition and spatial_img is not None:
262
+ axes[0, 0].imshow(spatial_img)
263
+ axes[0, 0].set_title('Spatial Condition')
264
+ else:
265
+ axes[0, 0].text(0.5, 0.5, 'NOT AVAILABLE',
266
+ horizontalalignment='center', verticalalignment='center',
267
+ transform=axes[0, 0].transAxes, fontsize=14, fontweight='bold')
268
+ axes[0, 0].set_title('Spatial Condition')
269
+ axes[0, 0].axis('off')
+ axes[0, 1].axis('off')  # cell (0, 1) is unused in this layout; hide its empty axes
270
+
271
+ # Clean model input (0,2)
272
+ axes[0, 2].imshow(clean_img)
273
+ axes[0, 2].set_title('Clean Model Input')
274
+ axes[0, 2].axis('off')
275
+
276
+ # Noisy model input (1,0)
277
+ axes[1, 0].imshow(noisy_img)
278
+ axes[1, 0].set_title('Noisy Model Input')
279
+ axes[1, 0].axis('off')
280
+
281
+ # Cuboids segmentation masks with legend (1,1 and 1,2)
282
+ if has_cuboids_segmasks:
283
+ segmask = batch["cuboids_segmasks"][0].float().cpu().numpy() # Shape: (n_subjects, h, w)
284
+ n_subjects, h, w = segmask.shape
285
+
286
+ # Only use first 4 subjects for visualization
287
+ n_subjects_to_show = min(4, n_subjects)
288
+
289
+ # Create colored segmentation visualization
290
+ np.random.seed(42) # For consistent colors
291
+ colors = np.random.rand(n_subjects_to_show + 1, 3) # +1 for background
292
+ colors[0] = [0, 0, 0] # Background is black
293
+
294
+ # Create 2x2 grid of individual subject masks
295
+ grid_h, grid_w = 2, 2
296
+ combined_mask = np.zeros((h * grid_h, w * grid_w, 3))
297
+
298
+ for idx in range(n_subjects_to_show):
299
+ row = idx // grid_w
300
+ col = idx % grid_w
301
+
302
+ # Create binary mask for this subject
303
+ subject_mask = np.zeros((h, w, 3))
304
+ mask = segmask[idx] > 0.5 # Binary threshold
305
+ subject_mask[mask] = colors[idx + 1]
306
+
307
+ # Place in grid
308
+ combined_mask[row*h:(row+1)*h, col*w:(col+1)*w] = subject_mask
309
+
310
+ axes[1, 1].imshow(combined_mask)
311
+ axes[1, 1].set_title('Cuboids Segmentation (2x2 Grid)')
312
+ axes[1, 1].axis('off')
313
+
314
+ # Create legend in the next subplot (1,2) - only for first 4 subjects
315
+ axes[1, 2].set_xlim(0, 1)
316
+ axes[1, 2].set_ylim(0, 1)
317
+
318
+ # Add legend entries
319
+ legend_y_positions = np.linspace(0.9, 0.1, n_subjects_to_show + 1)
320
+ axes[1, 2].text(0.1, legend_y_positions[0], "Background",
321
+ color=colors[0], fontsize=12, fontweight='bold')
322
+
323
+ for subject_idx in range(n_subjects_to_show):
324
+ axes[1, 2].text(0.1, legend_y_positions[subject_idx + 1],
325
+ f"Subject {subject_idx}",
326
+ color=colors[subject_idx + 1], fontsize=12, fontweight='bold')
327
+
328
+ axes[1, 2].set_title('Segmentation Legend (First 4)')
329
+ axes[1, 2].axis('off')
330
+ else:
331
+ axes[1, 1].text(0.5, 0.5, 'NOT AVAILABLE',
332
+ horizontalalignment='center', verticalalignment='center',
333
+ transform=axes[1, 1].transAxes, fontsize=14, fontweight='bold')
334
+ axes[1, 1].set_title('Cuboids Segmentation')
335
+ axes[1, 1].axis('off')
336
+
337
+ axes[1, 2].text(0.5, 0.5, 'NOT AVAILABLE',
338
+ horizontalalignment='center', verticalalignment='center',
339
+ transform=axes[1, 2].transAxes, fontsize=14, fontweight='bold')
340
+ axes[1, 2].set_title('Segmentation Legend')
341
+ axes[1, 2].axis('off')
342
+
343
+ # BEV Cuboids segmentation masks with legend (2,0 and 2,1)
344
+ if has_cuboids_segmasks_bev:
345
+ segmask_bev = batch["cuboids_segmasks_bev"][0].float().cpu().numpy() # Shape: (n_subjects, h, w)
346
+ n_subjects_bev, h_bev, w_bev = segmask_bev.shape
347
+
348
+ # Create colored segmentation visualization for BEV (use different seed for different colors)
349
+ np.random.seed(123) # Different seed for BEV colors
350
+ colors_bev = np.random.rand(n_subjects_bev + 1, 3) # +1 for background
351
+ colors_bev[0] = [0, 0, 0] # Background is black
352
+
353
+ # Create RGB image from BEV segmentation
354
+ colored_segmask_bev = np.zeros((h_bev, w_bev, 3))
355
+ for subject_idx in range(n_subjects_bev):
356
+ mask_bev = segmask_bev[subject_idx] > 0.5 # Binary threshold
357
+ colored_segmask_bev[mask_bev] = colors_bev[subject_idx + 1]
358
+
359
+ axes[2, 0].imshow(colored_segmask_bev)
360
+ axes[2, 0].set_title('BEV Cuboids Segmentation')
361
+ axes[2, 0].axis('off')
362
+
363
+ # Create BEV legend in the next subplot (2,1)
364
+ axes[2, 1].set_xlim(0, 1)
365
+ axes[2, 1].set_ylim(0, 1)
366
+
367
+ # Add BEV legend entries
368
+ legend_y_positions_bev = np.linspace(0.9, 0.1, n_subjects_bev + 1)
369
+ axes[2, 1].text(0.1, legend_y_positions_bev[0], "Background",
370
+ color=colors_bev[0], fontsize=12, fontweight='bold')
371
+
372
+ for subject_idx in range(n_subjects_bev):
373
+ axes[2, 1].text(0.1, legend_y_positions_bev[subject_idx + 1],
374
+ f"Subject {subject_idx}",
375
+ color=colors_bev[subject_idx + 1], fontsize=12, fontweight='bold')
376
+
377
+ axes[2, 1].set_title('BEV Segmentation Legend')
378
+ axes[2, 1].axis('off')
379
+ else:
380
+ axes[2, 0].text(0.5, 0.5, 'NOT AVAILABLE',
381
+ horizontalalignment='center', verticalalignment='center',
382
+ transform=axes[2, 0].transAxes, fontsize=14, fontweight='bold')
383
+ axes[2, 0].set_title('BEV Cuboids Segmentation')
384
+ axes[2, 0].axis('off')
385
+
386
+ axes[2, 1].text(0.5, 0.5, 'NOT AVAILABLE',
387
+ horizontalalignment='center', verticalalignment='center',
388
+ transform=axes[2, 1].transAxes, fontsize=14, fontweight='bold')
389
+ axes[2, 1].set_title('BEV Segmentation Legend')
390
+ axes[2, 1].axis('off')
391
+
392
+ # Text prompt and call ID (2,2)
393
+ axes[2, 2].text(0.5, 0.5, f'Text Prompt:\n\n"{text_prompt}"\n\nCall ID: {call_id}',
394
+ horizontalalignment='center', verticalalignment='center',
395
+ transform=axes[2, 2].transAxes, fontsize=12, wrap=True)
396
+ axes[2, 2].set_title('Text Prompt & Call ID')
397
+ axes[2, 2].axis('off')
398
+
399
+ # Pixel values info (3,0)
400
+ pixel_info = f'Pixel Values Shape: {batch["pixel_values"].shape}\n'
401
+ if has_spatial_condition:
402
+ pixel_info += f'Spatial Shape: {batch["cond_pixel_values"].shape}\n'
403
+ if has_cuboids_segmasks:
404
+ pixel_info += f'Cuboids Segmasks: {len(batch["cuboids_segmasks"])}\n'
405
+ if has_cuboids_segmasks_bev:
406
+ pixel_info += f'BEV Segmasks: {len(batch["cuboids_segmasks_bev"])}'
407
+
408
+ axes[3, 0].text(0.5, 0.5, pixel_info,
409
+ horizontalalignment='center', verticalalignment='center',
410
+ transform=axes[3, 0].transAxes, fontsize=10, fontfamily='monospace')
411
+ axes[3, 0].set_title('Tensor Shapes')
412
+ axes[3, 0].axis('off')
413
+
414
+ # Training info (3,1)
415
+ training_info = f'Global Step: {global_step}\nConditions:\nSpatial: {"✓" if has_spatial_condition else "✗"}\nSegmasks: {"✓" if has_cuboids_segmasks else "✗"}\nBEV Segmasks: {"✓" if has_cuboids_segmasks_bev else "✗"}'
416
+ axes[3, 1].text(0.5, 0.5, training_info,
417
+ horizontalalignment='center', verticalalignment='center',
418
+ transform=axes[3, 1].transAxes, fontsize=12, fontfamily='monospace')
419
+ axes[3, 1].set_title('Training Info')
420
+ axes[3, 1].axis('off')
421
+
422
+ # Additional info (3,2) - can be used for any extra debugging info
423
+ axes[3, 2].text(0.5, 0.5, 'Additional Info\n(Reserved)',
424
+ horizontalalignment='center', verticalalignment='center',
425
+ transform=axes[3, 2].transAxes, fontsize=12, fontfamily='monospace')
426
+ axes[3, 2].set_title('Reserved')
427
+ axes[3, 2].axis('off')
428
+
429
+ plt.tight_layout()
430
+
431
+ # Save the visualization
432
+ save_dir = os.path.join(args.output_dir, "visualizations")
433
+ os.makedirs(save_dir, exist_ok=True)
434
+ save_path = os.path.join(save_dir, f"training_vis_step_{global_step}.png")
435
+ plt.savefig(save_path, dpi=150, bbox_inches='tight')
436
+ plt.close()
437
+
438
+ logger.info(f"Training visualization saved to {save_path}")
439
+
440
+ vae = vae.to(vae_dtype)
441
+
442
+ def import_model_class_from_model_name_or_path(
443
+ pretrained_model_name_or_path: str, revision: str, subfolder: str = "text_encoder"
444
+ ):
445
+ text_encoder_config = PretrainedConfig.from_pretrained(
446
+ pretrained_model_name_or_path, subfolder=subfolder, revision=revision
447
+ )
448
+ model_class = text_encoder_config.architectures[0]
449
+ if model_class == "CLIPTextModel":
450
+ from transformers import CLIPTextModel
451
+
452
+ return CLIPTextModel
453
+ elif model_class == "T5EncoderModel":
454
+ from transformers import T5EncoderModel
455
+
456
+ return T5EncoderModel
457
+ else:
458
+ raise ValueError(f"{model_class} is not supported.")
459
+
460
+
461
+ def parse_args(input_args=None):
462
+ parser = argparse.ArgumentParser(description="Simple example of a training script.")
463
+ parser.add_argument("--lora_num", type=int, default=2, help="number of the lora.")
464
+ parser.add_argument("--cond_size", type=int, default=512, help="size of the condition data.")
465
+ parser.add_argument("--debug", type=int, default=0, help="whether to enter debug mode -- visualizations, gradient checks, etc.")
466
+ parser.add_argument("--mode",type=str,default=None,help="The mode of the controller. Choose between ['depth', 'pose', 'canny'].")
467
+ parser.add_argument("--run_name",type=str,required=True,help="the name of the wandb run")
468
+ parser.add_argument(
469
+ "--train_data_dir",
470
+ type=str,
471
+ default="",
472
+ help=(
473
+ "A folder containing the training data. Folder contents must follow the structure described in"
474
+ " https://huggingface.co/docs/datasets/image_dataset#imagefolder. In particular, a `metadata.jsonl` file"
475
+ " must exist to provide the captions for the images. Ignored if `dataset_name` is specified."
476
+ ),
477
+ )
478
+ parser.add_argument(
479
+ "--inference_embeds_dir",
480
+ type=str,
481
+ default=None,
482
+ help=(
483
+ "the captions for images"
484
+ ),
485
+ )
486
+ parser.add_argument(
487
+ "--pretrained_model_name_or_path",
488
+ type=str,
489
+ default="",
490
+ required=False,
491
+ help="Path to pretrained model or model identifier from huggingface.co/models.",
492
+ )
493
+ parser.add_argument(
494
+ "--revision",
495
+ type=str,
496
+ default=None,
497
+ required=False,
498
+ help="Revision of pretrained model identifier from huggingface.co/models.",
499
+ )
500
+ parser.add_argument(
501
+ "--variant",
502
+ type=str,
503
+ default=None,
504
+ help="Variant of the model files of the pretrained model identifier from huggingface.co/models, 'e.g.' fp16",
505
+ )
506
+ parser.add_argument(
507
+ "--spatial_column",
508
+ type=str,
509
+ default="None",
510
+ help="The column of the dataset containing the canny image. By "
511
+ "default, the standard Image Dataset maps out 'file_name' "
512
+ "to 'image'.",
513
+ )
514
+ parser.add_argument(
515
+ "--target_column",
516
+ type=str,
517
+ default="image",
518
+ help="The column of the dataset containing the target image. By "
519
+ "default, the standard Image Dataset maps out 'file_name' "
520
+ "to 'image'.",
521
+ )
522
+ parser.add_argument(
523
+ "--caption_column",
524
+ type=str,
525
+ default="caption_left,caption_right",
526
+ help="The column of the dataset containing the instance prompt for each image",
527
+ )
528
+ parser.add_argument("--repeats", type=int, default=1, help="How many times to repeat the training data.")
529
+ parser.add_argument(
530
+ "--max_sequence_length",
531
+ type=int,
532
+ default=512,
533
+ help="Maximum sequence length to use with with the T5 text encoder",
534
+ )
535
+ parser.add_argument(
536
+ "--ranks",
537
+ type=int,
538
+ nargs="+",
539
+ default=[128],
540
+ help=("The dimension of the LoRA update matrices."),
541
+ )
542
+ parser.add_argument(
543
+ "--network_alphas",
544
+ type=int,
545
+ nargs="+",
546
+ default=[128],
547
+ help=("The dimension of the LoRA update matrices."),
548
+ )
549
+ parser.add_argument(
550
+ "--output_dir",
551
+ type=str,
552
+ required=True,
553
+ help="The output directory where the model predictions and checkpoints will be written.",
554
+ )
555
+ parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")
556
+ parser.add_argument(
557
+ "--train_batch_size", type=int, default=1, help="Batch size (per device) for the training dataloader."
558
+ )
559
+ parser.add_argument("--num_train_epochs", type=int, default=50)
560
+ parser.add_argument(
561
+ "--max_train_steps",
562
+ type=int,
563
+ default=None,
564
+ help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
565
+ )
566
+ parser.add_argument(
567
+ "--checkpointing_steps",
568
+ type=int,
569
+ default=1000,
570
+ help=(
571
+ "Save a checkpoint of the training state every X updates. These checkpoints can be used both as final"
572
+ " checkpoints in case they are better than the last checkpoint, and are also suitable for resuming"
573
+ " training using `--resume_from_checkpoint`."
574
+ ),
575
+ )
576
+ parser.add_argument(
577
+ "--resume_from_checkpoint",
578
+ type=str,
579
+ default=None,
580
+ help=(
581
+ "Whether training should be resumed from a previous checkpoint. Use a path saved by"
582
+ ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.'
583
+ ),
584
+ )
585
+ parser.add_argument(
586
+ "--gradient_accumulation_steps",
587
+ type=int,
588
+ default=1,
589
+ help="Number of updates steps to accumulate before performing a backward/update pass.",
590
+ )
591
+ parser.add_argument(
592
+ "--gradient_checkpointing",
593
+ action="store_true",
594
+ help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.",
595
+ )
596
+ parser.add_argument(
597
+ "--learning_rate",
598
+ type=float,
599
+ default=1e-4,
600
+ help="Initial learning rate (after the potential warmup period) to use.",
601
+ )
602
+
603
+ parser.add_argument(
604
+ "--guidance_scale",
605
+ type=float,
606
+ default=1,
607
+ help="the FLUX.1 dev variant is a guidance distilled model",
608
+ )
609
+ parser.add_argument(
610
+ "--scale_lr",
611
+ action="store_true",
612
+ default=False,
613
+ help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
614
+ )
615
+ parser.add_argument(
616
+ "--lr_scheduler",
617
+ type=str,
618
+ default="constant",
619
+ help=(
620
+ 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
621
+ ' "constant", "constant_with_warmup"]'
622
+ ),
623
+ )
624
+ parser.add_argument(
625
+ "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler."
626
+ )
627
+ parser.add_argument(
628
+ "--lr_num_cycles",
629
+ type=int,
630
+ default=1,
631
+ help="Number of hard resets of the lr in cosine_with_restarts scheduler.",
632
+ )
633
+ parser.add_argument("--lr_power", type=float, default=1.0, help="Power factor of the polynomial scheduler.")
634
+ parser.add_argument(
635
+ "--dataloader_num_workers",
636
+ type=int,
637
+ default=2,
638
+ help=(
639
+ "Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process."
640
+ ),
641
+ )
642
+ parser.add_argument(
643
+ "--weighting_scheme",
644
+ type=str,
645
+ default="none",
646
+ choices=["sigma_sqrt", "logit_normal", "mode", "cosmap", "none"],
647
+ help=('We default to the "none" weighting scheme for uniform sampling and uniform loss'),
648
+ )
649
+ parser.add_argument(
650
+ "--logit_mean", type=float, default=0.0, help="mean to use when using the `'logit_normal'` weighting scheme."
651
+ )
652
+ parser.add_argument(
653
+ "--logit_std", type=float, default=1.0, help="std to use when using the `'logit_normal'` weighting scheme."
654
+ )
655
+ parser.add_argument(
656
+ "--mode_scale",
657
+ type=float,
658
+ default=1.29,
659
+ help="Scale of mode weighting scheme. Only effective when using the `'mode'` as the `weighting_scheme`.",
660
+ )
661
+ parser.add_argument(
662
+ "--optimizer",
663
+ type=str,
664
+ default="AdamW",
665
+ help=('The optimizer type to use. Choose between ["AdamW", "prodigy"]'),
666
+ )
667
+
668
+ parser.add_argument(
669
+ "--use_8bit_adam",
670
+ action="store_true",
671
+ help="Whether or not to use 8-bit Adam from bitsandbytes. Ignored if optimizer is not set to AdamW",
672
+ )
673
+
674
+ parser.add_argument(
675
+ "--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam and Prodigy optimizers."
676
+ )
677
+ parser.add_argument(
678
+ "--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam and Prodigy optimizers."
679
+ )
680
+ parser.add_argument(
681
+ "--prodigy_beta3",
682
+ type=float,
683
+ default=None,
684
+ help="coefficients for computing the Prodigy stepsize using running averages. If set to None, "
685
+ "uses the value of square root of beta2. Ignored if optimizer is adamW",
686
+ )
687
+ parser.add_argument("--prodigy_decouple", type=bool, default=True, help="Use AdamW style decoupled weight decay")
688
+ parser.add_argument("--adam_weight_decay", type=float, default=1e-04, help="Weight decay to use for unet params")
689
+ parser.add_argument(
690
+ "--adam_weight_decay_text_encoder", type=float, default=1e-03, help="Weight decay to use for text_encoder"
691
+ )
692
+
693
+ parser.add_argument(
694
+ "--adam_epsilon",
695
+ type=float,
696
+ default=1e-08,
697
+ help="Epsilon value for the Adam optimizer and Prodigy optimizers.",
698
+ )
699
+
700
+ parser.add_argument(
701
+ "--prodigy_use_bias_correction",
702
+ type=bool,
703
+ default=True,
704
+ help="Turn on Adam's bias correction. True by default. Ignored if optimizer is adamW",
705
+ )
706
+ parser.add_argument(
707
+ "--prodigy_safeguard_warmup",
708
+ type=bool,
709
+ default=True,
710
+ help="Remove lr from the denominator of D estimate to avoid issues during warm-up stage. True by default. "
711
+ "Ignored if optimizer is adamW",
712
+ )
713
+ parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
714
+ parser.add_argument(
715
+ "--logging_dir",
716
+ type=str,
717
+ default="logs",
718
+ help=(
719
+ "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
720
+ " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
721
+ ),
722
+ )
723
+ parser.add_argument(
724
+ "--report_to",
725
+ type=str,
726
+ default="tensorboard",
727
+ help=(
728
+ 'The integration to report the results and logs to. Supported platforms are `"tensorboard"`'
729
+ ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.'
730
+ ),
731
+ )
732
+ parser.add_argument(
733
+ "--mixed_precision",
734
+ type=str,
735
+ default="bf16",
736
+ choices=["no", "fp16", "bf16"],
737
+ help=(
738
+ "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
739
+ " 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the"
740
+ " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."
741
+ ),
742
+ )
743
+ parser.add_argument(
744
+ "--upcast_before_saving",
745
+ action="store_true",
746
+ default=False,
747
+ help=(
748
+ "Whether to upcast the trained transformer layers to float32 before saving (at the end of training). "
749
+ "Defaults to precision dtype used for training to save memory"
750
+ ),
751
+ )
752
+
753
+ if input_args is not None:
754
+ args = parser.parse_args(input_args)
755
+ else:
756
+ args = parser.parse_args()
757
+ return args
758
+
759
+
760
+ def main(args):
761
+ if torch.backends.mps.is_available() and args.mixed_precision == "bf16":
762
+ # due to pytorch#99272, MPS does not yet support bfloat16.
763
+ raise ValueError(
764
+ "Mixed precision training with bfloat16 is not supported on MPS. Please use fp16 (recommended) or fp32 instead."
765
+ )
766
+
767
+ if args.resume_from_checkpoint is not None:
768
+ assert osp.exists(args.resume_from_checkpoint), f"Make sure that the `resume_from_checkpoint` {args.resume_from_checkpoint} exists."
769
+ args.pretrained_lora_path = osp.join(args.resume_from_checkpoint, "lora.safetensors")
770
+ assert osp.exists(args.pretrained_lora_path), f"Make sure that the `pretrained_lora_path` {args.pretrained_lora_path} exists."
771
+ else:
772
+ args.pretrained_lora_path = None
773
+
774
+ args.output_dir = osp.join(args.output_dir, args.run_name)
775
+ args.logging_dir = osp.join(args.output_dir, args.logging_dir)
776
+ os.makedirs(args.output_dir, exist_ok=True)
777
+ os.makedirs(args.logging_dir, exist_ok=True)
778
+ logging_dir = Path(args.output_dir, args.logging_dir)
779
+
780
+ if args.spatial_column == "None":
781
+ args.spatial_column = None
782
+
783
+ accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir)
784
+ # kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)
785
+ accelerator = Accelerator(
786
+ gradient_accumulation_steps=args.gradient_accumulation_steps,
787
+ mixed_precision=args.mixed_precision,
788
+ log_with=args.report_to,
789
+ project_config=accelerator_project_config,
790
+ # kwargs_handlers=[kwargs],
791
+ )
792
+
793
+ def save_model_hook(models, weights, output_dir):
794
+ pass  # no-op: LoRA weights are saved explicitly with save_file during checkpointing
795
+
796
+ def load_model_hook(models, input_dir):
797
+ pass
798
+
799
+ # Disable AMP for MPS.
800
+ if torch.backends.mps.is_available():
801
+ accelerator.native_amp = False
802
+
803
+ if args.report_to == "wandb":
804
+ if not is_wandb_available():
805
+ raise ImportError("Make sure to install wandb if you want to use it for logging during training.")
806
+
807
+ # Make one log on every process with the configuration for debugging.
808
+ logging.basicConfig(
809
+ format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
810
+ datefmt="%m/%d/%Y %H:%M:%S",
811
+ level=logging.INFO,
812
+ )
813
+ logger.info(accelerator.state, main_process_only=False)
814
+ if accelerator.is_local_main_process:
815
+ transformers.utils.logging.set_verbosity_warning()
816
+ diffusers.utils.logging.set_verbosity_info()
817
+ else:
818
+ transformers.utils.logging.set_verbosity_error()
819
+ diffusers.utils.logging.set_verbosity_error()
820
+
821
+ # If passed along, set the training seed now.
822
+ if args.seed is not None:
823
+ set_seed(args.seed)
824
+
825
+ # Handle the repository creation
826
+ if accelerator.is_main_process:
827
+ if args.output_dir is not None:
828
+ os.makedirs(args.output_dir, exist_ok=True)
829
+
830
+ # Load the tokenizers
831
+ tokenizer_one = CLIPTokenizer.from_pretrained(
832
+ args.pretrained_model_name_or_path,
833
+ subfolder="tokenizer",
834
+ revision=args.revision,
835
+ )
836
+ tokenizer_two = T5TokenizerFast.from_pretrained(
837
+ args.pretrained_model_name_or_path,
838
+ subfolder="tokenizer_2",
839
+ revision=args.revision,
840
+ )
841
+
842
+ # Load scheduler and models
843
+ noise_scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(
844
+ args.pretrained_model_name_or_path, subfolder="scheduler"
845
+ )
846
+ noise_scheduler_copy = copy.deepcopy(noise_scheduler)
847
+ gc.collect()
848
+ torch.cuda.empty_cache()
849
+
850
+ text_encoder_cls_one = import_model_class_from_model_name_or_path(
851
+ args.pretrained_model_name_or_path, args.revision, subfolder="text_encoder"
852
+ )
853
+ text_encoder_cls_two = import_model_class_from_model_name_or_path(
854
+ args.pretrained_model_name_or_path, args.revision, subfolder="text_encoder_2"
855
+ )
856
+ if args.inference_embeds_dir is None:
857
+ text_encoder_one, text_encoder_two = load_text_encoders(args, text_encoder_cls_one, text_encoder_cls_two)
858
+ else:
859
+ assert osp.exists(args.inference_embeds_dir), f"Make sure that the `inference_embeds_dir` {args.inference_embeds_dir} exists."
860
+ vae = AutoencoderKL.from_pretrained(
861
+ args.pretrained_model_name_or_path,
862
+ subfolder="vae",
863
+ revision=args.revision,
864
+ variant=args.variant,
865
+ )
866
+ transformer = FluxTransformer2DModel.from_pretrained(
867
+ args.pretrained_model_name_or_path, subfolder="transformer", revision=args.revision, variant=args.variant
868
+ )
869
+
870
+ # We only train the additional adapter LoRA layers
871
+ transformer.requires_grad_(True)
872
+ vae.requires_grad_(False)
873
+ if args.inference_embeds_dir is None:
874
+ text_encoder_one.requires_grad_(False)
875
+ text_encoder_two.requires_grad_(False)
876
+
877
+ # For mixed precision training we cast all non-trainable weights (vae, text_encoder and transformer) to half-precision
878
+ # as these weights are only used for inference, keeping weights in full precision is not required.
879
+ weight_dtype = torch.float32
880
+ if accelerator.mixed_precision == "fp16":
881
+ weight_dtype = torch.float16
882
+ elif accelerator.mixed_precision == "bf16":
883
+ weight_dtype = torch.bfloat16
884
+
885
+ if torch.backends.mps.is_available() and weight_dtype == torch.bfloat16:
886
+ # due to pytorch#99272, MPS does not yet support bfloat16.
887
+ raise ValueError(
888
+ "Mixed precision training with bfloat16 is not supported on MPS. Please use fp16 (recommended) or fp32 instead."
889
+ )
890
+
891
+ vae.to(accelerator.device, dtype=weight_dtype)
892
+ transformer.to(accelerator.device, dtype=weight_dtype)
893
+ if args.inference_embeds_dir is None:
894
+ text_encoder_one.to(accelerator.device, dtype=torch.float32)
895
+ text_encoder_two.to(accelerator.device, dtype=torch.float32)
896
+
897
+ if args.gradient_checkpointing:
898
+ transformer.enable_gradient_checkpointing()
899
+
900
+ #### lora_layers ####
901
+ if args.pretrained_lora_path is not None:
902
+ lora_path = args.pretrained_lora_path
903
+ checkpoint = load_checkpoint(lora_path)
904
+ lora_attn_procs = {}
905
+ double_blocks_idx = list(range(19))
906
+ single_blocks_idx = list(range(38))
907
+ number = 1  # note: only the first LoRA module's weights are restored from the checkpoint below
908
+ for name, attn_processor in transformer.attn_processors.items():
909
+ match = re.search(r'\.(\d+)\.', name)
910
+ if match:
911
+ layer_index = int(match.group(1))
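+ # e.g. the processor name "transformer_blocks.3.attn.processor" yields layer_index 3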
912
+
913
+ if name.startswith("transformer_blocks") and layer_index in double_blocks_idx:
914
+ lora_state_dicts = {}
915
+ for key, value in checkpoint.items():
916
+ # Match based on the layer index in the key (assuming the key contains layer index)
917
+ if re.search(r'\.(\d+)\.', key):
918
+ checkpoint_layer_index = int(re.search(r'\.(\d+)\.', key).group(1))
919
+ if checkpoint_layer_index == layer_index and key.startswith("transformer_blocks"):
920
+ lora_state_dicts[key] = value
921
+
922
+ print("setting LoRA Processor for", name)
923
+ lora_attn_procs[name] = MultiDoubleStreamBlockLoraProcessor(
924
+ dim=3072, ranks=args.ranks, network_alphas=args.network_alphas, lora_weights=[1 for _ in range(args.lora_num)], device=accelerator.device, dtype=weight_dtype, cond_width=args.cond_size, cond_height=args.cond_size, n_loras=args.lora_num
925
+ )
926
+
927
+ # Load the weights from the checkpoint dictionary into the corresponding layers
928
+ for n in range(number):
929
+ lora_attn_procs[name].q_loras[n].down.weight.data = lora_state_dicts.get(f'{name}.q_loras.{n}.down.weight', None)
930
+ lora_attn_procs[name].q_loras[n].up.weight.data = lora_state_dicts.get(f'{name}.q_loras.{n}.up.weight', None)
931
+ lora_attn_procs[name].k_loras[n].down.weight.data = lora_state_dicts.get(f'{name}.k_loras.{n}.down.weight', None)
932
+ lora_attn_procs[name].k_loras[n].up.weight.data = lora_state_dicts.get(f'{name}.k_loras.{n}.up.weight', None)
933
+ lora_attn_procs[name].v_loras[n].down.weight.data = lora_state_dicts.get(f'{name}.v_loras.{n}.down.weight', None)
934
+ lora_attn_procs[name].v_loras[n].up.weight.data = lora_state_dicts.get(f'{name}.v_loras.{n}.up.weight', None)
935
+ lora_attn_procs[name].proj_loras[n].down.weight.data = lora_state_dicts.get(f'{name}.proj_loras.{n}.down.weight', None)
936
+ lora_attn_procs[name].proj_loras[n].up.weight.data = lora_state_dicts.get(f'{name}.proj_loras.{n}.up.weight', None)
937
+
938
+ elif name.startswith("single_transformer_blocks") and layer_index in single_blocks_idx:
939
+
940
+ lora_state_dicts = {}
941
+ for key, value in checkpoint.items():
942
+ # Match based on the layer index in the key (assuming the key contains layer index)
943
+ if re.search(r'\.(\d+)\.', key):
944
+ checkpoint_layer_index = int(re.search(r'\.(\d+)\.', key).group(1))
945
+ if checkpoint_layer_index == layer_index and key.startswith("single_transformer_blocks"):
946
+ lora_state_dicts[key] = value
947
+
948
+ print("setting LoRA Processor for", name)
949
+ lora_attn_procs[name] = MultiSingleStreamBlockLoraProcessor(
950
+ dim=3072, ranks=args.ranks, network_alphas=args.network_alphas, lora_weights=[1 for _ in range(args.lora_num)], device=accelerator.device, dtype=weight_dtype, cond_width=args.cond_size, cond_height=args.cond_size, n_loras=args.lora_num
951
+ )
952
+
953
+ # Load the weights from the checkpoint dictionary into the corresponding layers
954
+ for n in range(number):
955
+ lora_attn_procs[name].q_loras[n].down.weight.data = lora_state_dicts.get(f'{name}.q_loras.{n}.down.weight', None)
956
+ lora_attn_procs[name].q_loras[n].up.weight.data = lora_state_dicts.get(f'{name}.q_loras.{n}.up.weight', None)
957
+ lora_attn_procs[name].k_loras[n].down.weight.data = lora_state_dicts.get(f'{name}.k_loras.{n}.down.weight', None)
958
+ lora_attn_procs[name].k_loras[n].up.weight.data = lora_state_dicts.get(f'{name}.k_loras.{n}.up.weight', None)
959
+ lora_attn_procs[name].v_loras[n].down.weight.data = lora_state_dicts.get(f'{name}.v_loras.{n}.down.weight', None)
960
+ lora_attn_procs[name].v_loras[n].up.weight.data = lora_state_dicts.get(f'{name}.v_loras.{n}.up.weight', None)
961
+ else:
962
+ lora_attn_procs[name] = FluxAttnProcessor2_0()
963
+ else:
964
+ lora_attn_procs = {}
965
+ double_blocks_idx = list(range(19))
966
+ single_blocks_idx = list(range(38))
967
+ for name, attn_processor in transformer.attn_processors.items():
968
+ match = re.search(r'\.(\d+)\.', name)
969
+ if match:
970
+ layer_index = int(match.group(1))
971
+ if name.startswith("transformer_blocks") and layer_index in double_blocks_idx:
972
+ lora_state_dicts = {}
973
+ print("setting LoRA Processor for", name)
974
+ lora_attn_procs[name] = MultiDoubleStreamBlockLoraProcessor(
975
+ dim=3072, ranks=args.ranks, network_alphas=args.network_alphas, lora_weights=[1 for _ in range(args.lora_num)], device=accelerator.device, dtype=weight_dtype, cond_width=args.cond_size, cond_height=args.cond_size, n_loras=args.lora_num
976
+ )
977
+
978
+ elif name.startswith("single_transformer_blocks") and layer_index in single_blocks_idx:
979
+ print("setting LoRA Processor for", name)
980
+ lora_attn_procs[name] = MultiSingleStreamBlockLoraProcessor(
981
+ dim=3072, ranks=args.ranks, network_alphas=args.network_alphas, lora_weights=[1 for _ in range(args.lora_num)], device=accelerator.device, dtype=weight_dtype, cond_width=args.cond_size, cond_height=args.cond_size, n_loras=args.lora_num
982
+ )
983
+
984
+ else:
985
+ lora_attn_procs[name] = attn_processor
986
+ ######################
987
+ transformer.set_attn_processor(lora_attn_procs)
988
+ transformer.train()
989
+ for n, param in transformer.named_parameters():
990
+ if '_lora' not in n:
991
+ param.requires_grad = False
992
+ print(sum([p.numel() for p in transformer.parameters() if p.requires_grad]) / 1000000, 'M parameters')
993
+
994
+ def unwrap_model(model):
995
+ model = accelerator.unwrap_model(model)
996
+ model = model._orig_mod if is_compiled_module(model) else model
997
+ return model
998
+
999
+ # Potentially load in the weights and states from a previous save
1000
+ if args.resume_from_checkpoint:
1001
+ foldername = osp.basename(args.resume_from_checkpoint)
1002
+ first_epoch = epoch = int(foldername.split("-")[1].split("__")[0])
1003
+ initial_global_step = global_step = int(foldername.split("-")[-1])
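+ # Checkpoint folders are named "epoch-{epoch}__checkpoint-{global_step}" by the save logic
+ # below, so e.g. "epoch-3__checkpoint-1500" resumes with first_epoch=3 and global_step=1500.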
1004
+ else:
1005
+ initial_global_step = 0
1006
+ global_step = 0
1007
+ first_epoch = 0
1008
+
1009
+ if args.scale_lr:
1010
+ args.learning_rate = (
1011
+ args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes
1012
+ )
1013
+
1014
+ # Make sure the trainable params are in float32.
1015
+ if args.mixed_precision == "fp16":
1016
+ models = [transformer]
1017
+ # only upcast trainable parameters (LoRA) into fp32
1018
+ cast_training_params(models, dtype=torch.float32)
1019
+
1020
+ # Optimization parameters
1021
+ params_to_optimize = [p for p in transformer.parameters() if p.requires_grad]
1022
+ transformer_parameters_with_lr = {"params": params_to_optimize, "lr": args.learning_rate}
1023
+ print(sum([p.numel() for p in transformer.parameters() if p.requires_grad]) / 1000000, 'M parameters')
1024
+
1025
+ optimizer_class = torch.optim.AdamW
1026
+ optimizer = optimizer_class(
1027
+ [transformer_parameters_with_lr],
1028
+ betas=(args.adam_beta1, args.adam_beta2),
1029
+ weight_decay=args.adam_weight_decay,
1030
+ eps=args.adam_epsilon,
1031
+ )
1032
+
1033
+ tokenizers = [tokenizer_one, tokenizer_two]
1034
+
1035
+ # now, we will define a dataset for each epoch to make it easier to save the state
1036
+ shuffled_jsonls = os.listdir(osp.dirname(args.train_data_dir))
1037
+ base_jsonl_name = osp.basename(args.train_data_dir).replace(".jsonl", "")
1038
+ shuffled_jsonls = sorted([_ for _ in shuffled_jsonls if _.endswith('.jsonl') and "shuffled" in _ and base_jsonl_name in _])
1039
+ shuffled_jsonls = [osp.join(osp.dirname(args.train_data_dir), _) for _ in shuffled_jsonls]
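+ # Example (hypothetical filenames): with --train_data_dir data/train.jsonl, pre-shuffled
+ # files such as data/train_shuffled_0.jsonl and data/train_shuffled_1.jsonl are picked up,
+ # and one is cycled through per epoch.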
1040
+ print(f"{shuffled_jsonls = }")
1041
+ assert len(shuffled_jsonls) > 0, f"Make sure that there are shuffled jsonl files in {osp.dirname(args.train_data_dir)}"
1042
+ train_dataloaders = []
1043
+ for epoch in range(args.num_train_epochs): # prepare dataloader for each epoch, irrespective of the resume state
1044
+ shuffled_idx = epoch % len(shuffled_jsonls)
1045
+ train_data_file = shuffled_jsonls[shuffled_idx]
1046
+ assert osp.exists(train_data_file), f"Make sure that the train data jsonl file {train_data_file} exists."
1047
+ args.current_train_data_dir = train_data_file
1048
+ train_dataset = make_train_dataset(args, tokenizers, accelerator)
1049
+ train_dataloader = torch.utils.data.DataLoader(
1050
+ train_dataset,
1051
+ batch_size=args.train_batch_size,
1052
+ shuffle=False, # deterministic order; shuffling is pre-baked into the per-epoch jsonl files for reproducibility
1053
+ collate_fn=collate_fn,
1054
+ num_workers=args.dataloader_num_workers,
1055
+ )
1056
+ train_dataloaders.append(train_dataloader)
1057
+
1058
+ vae_config_shift_factor = vae.config.shift_factor
1059
+ vae_config_scaling_factor = vae.config.scaling_factor
1060
+
1061
+ # Scheduler and math around the number of training steps.
1062
+ overrode_max_train_steps = False
1063
+ num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
1064
+ if args.max_train_steps is None:
1065
+ args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
1066
+ overrode_max_train_steps = True
1067
+
1068
+ lr_scheduler = get_scheduler(
1069
+ args.lr_scheduler,
1070
+ optimizer=optimizer,
1071
+ num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,
1072
+ num_training_steps=args.max_train_steps * accelerator.num_processes,
1073
+ num_cycles=args.lr_num_cycles,
1074
+ power=args.lr_power,
1075
+ )
1076
+
1077
+
1078
+ accelerator.register_save_state_pre_hook(save_model_hook)
1079
+ accelerator.register_load_state_pre_hook(load_model_hook)
1080
+ optimizer, lr_scheduler = accelerator.prepare(
1081
+ optimizer, lr_scheduler
1082
+ )
1083
+
1084
+ print(f"before preparation, {len(train_dataloaders[0]) = }")
1085
+
1086
+ prepared_train_dataloaders = []
1087
+ for train_dataloader in train_dataloaders:
1088
+ prepared_train_dataloaders.append(accelerator.prepare(train_dataloader))
1089
+ train_dataloaders = prepared_train_dataloaders
1090
+
1091
+ print(f"after preparation, {len(train_dataloaders[0]) = }")
1092
+
1093
+ if args.pretrained_lora_path is not None:
1094
+ accelerator.load_state(osp.dirname(args.pretrained_lora_path))
1095
+
1096
+ # Explicitly move optimizer states to accelerator.device
1097
+ for state in optimizer.state.values():
1098
+ for k, v in state.items():
1099
+ if isinstance(v, torch.Tensor):
1100
+ state[k] = v.to(accelerator.device)
1101
+
1102
+ transformer = accelerator.prepare(transformer)
1103
+
1104
+ # We need to recalculate our total training steps as the size of the training dataloader may have changed.
1105
+ num_update_steps_per_epoch = math.ceil(len(train_dataloaders[0]) / args.gradient_accumulation_steps)
1106
+ if overrode_max_train_steps:
1107
+ args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
1108
+ # Afterwards we recalculate our number of training epochs
1109
+ args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
1110
+
1111
+ # We need to initialize the trackers we use, and also store our configuration.
1112
+
1113
+ if accelerator.is_main_process:
1114
+ accelerator.init_trackers(args.run_name)
1115
+
1116
+
1117
+ # Train!
1118
+ total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
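+ # e.g. train_batch_size=1 per device with 2 processes and 1 accumulation step
+ # gives an effective batch size of 2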
1119
+
1120
+ logger.info("***** Running training *****")
1121
+ logger.info(f" Num examples = {len(train_dataset)}")
1122
+ logger.info(f" Num batches each epoch = {len(train_dataloader)}")
1123
+ logger.info(f" Num Epochs = {args.num_train_epochs}")
1124
+ logger.info(f" Instantaneous batch size per device = {args.train_batch_size}")
1125
+ logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
1126
+ logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
1127
+ logger.info(f" Total optimization steps = {args.max_train_steps}")
1128
+
1129
+ progress_bar = tqdm(
1130
+ range(0, args.max_train_steps),
1131
+ initial=initial_global_step,
1132
+ desc="Steps",
1133
+ # Only show the progress bar once on each machine.
1134
+ disable=not accelerator.is_local_main_process,
1135
+ )
1136
+
1137
+ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
1138
+ sigmas = noise_scheduler_copy.sigmas.to(device=accelerator.device, dtype=dtype)
1139
+ schedule_timesteps = noise_scheduler_copy.timesteps.to(accelerator.device)
1140
+ timesteps = timesteps.to(accelerator.device)
1141
+ step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps]
1142
+
1143
+ sigma = sigmas[step_indices].flatten()
1144
+ while len(sigma.shape) < n_dim:
1145
+ sigma = sigma.unsqueeze(-1)
1146
+ return sigma
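+ # sigma is unsqueezed up to n_dim (e.g. (B,) -> (B, 1, 1, 1)) so that it
+ # broadcasts cleanly over the latent tensor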
1147
+
1148
+ # some fixed parameters
1149
+ vae_scale_factor = 16
1150
+ height_cond = 2 * (args.cond_size // vae_scale_factor)
1151
+ width_cond = 2 * (args.cond_size // vae_scale_factor)
1152
+ offset = 64
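+ # Example with the defaults: cond_size=512 and vae_scale_factor=16 give
+ # height_cond = width_cond = 2 * (512 // 16) = 64 latent positions per side.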
1153
+
1154
+ num_training_visualizations = 10
1155
+
1156
+ skip_steps = initial_global_step - first_epoch * num_update_steps_per_epoch
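+ # Example: resuming at global_step=1500 with 1000 update steps per epoch gives first_epoch=1
+ # and skip_steps=500 batches to skip (assuming gradient_accumulation_steps=1, so steps == batches).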
1157
+ print(f"{skip_steps = }")
1158
+ for epoch in range(first_epoch, args.num_train_epochs):
1159
+ transformer.train()
1160
+ train_dataloader = train_dataloaders[epoch] # use a new dataloader for each epoch
1161
+ if epoch == first_epoch and skip_steps > 0:
1162
+ logger.info(f"Skipping {skip_steps} batches in epoch {epoch} due to resuming from checkpoint")
1163
+ # dataloader_iterator = skip_first_batches_manual(train_dataloader, skip_steps)
1164
+ dataloader_iterator = accelerator.skip_first_batches(train_dataloader, skip_steps)
1165
+ # Convert back to enumerate format
1166
+ enumerated_dataloader = enumerate(dataloader_iterator, start=skip_steps)
1167
+ else:
1168
+ enumerated_dataloader = enumerate(train_dataloader)
1169
+ for step, batch in enumerated_dataloader:
1170
+ progress_bar.set_description(f"epoch {epoch}, dataset_ids: {batch['index']}")
1171
+ models_to_accumulate = [transformer]
1172
+ with accelerator.accumulate(models_to_accumulate):
1173
+
1174
+ if args.inference_embeds_dir is None:
1175
+ print(f"encoding {batch['prompts'] = }")
1176
+ # prompt_embeds, pooled_prompt_embeds, text_ids = text_encoding_pipeline.encode_prompt(
1177
+ # prompt=batch["prompts"],
1178
+ # prompt_2=batch["prompts"],
1179
+ # )
1180
+ # prompt_embeds, pooled_prompt_embeds, text_ids = encode_prompt(
1181
+ # text_encoders=[text_encoder_one, text_encoder_two],
1182
+ # tokenizers=[tokenizer_one, tokenizer_two],
1183
+ # prompt=batch["prompts"],
1184
+ # max_sequence_length=512,
1185
+ # device=accelerator.device,
1186
+ # )
1187
+ for i, prompt in enumerate(batch["prompts"]):
1188
+ prompt_embeds, pooled_prompt_embeds, text_ids = encode_prompt(
1189
+ text_encoders=[text_encoder_one, text_encoder_two],
1190
+ tokenizers=[tokenizer_one, tokenizer_two],
1191
+ prompt=prompt,
1192
+ max_sequence_length=512,
1193
+ device=accelerator.device,
1194
+ )
1195
+ print(f"{prompt_embeds.shape = }, {pooled_prompt_embeds.shape = }, {text_ids.shape = }")
1196
+ # checking if the cached embeddings match
1197
+ inference_embeds_dir = "/archive/vaibhav.agrawal/a-bev-of-the-latents/inference_embeds_datasetv7_superhard"
1198
+ cached_prompt_path = osp.join(inference_embeds_dir, f"{'_'.join(prompt.lower().split())}.pth")
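+ # e.g. the prompt "A red car on the road" maps to the cache file "a_red_car_on_the_road.pth"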
1199
+ assert osp.exists(cached_prompt_path), f"Make sure that the cached prompt embedding {cached_prompt_path} exists."
1200
+ cached_prompt_embeds = torch.load(cached_prompt_path, map_location="cpu")
1201
+ assert torch.allclose(cached_prompt_embeds["prompt_embeds"].cpu().float(), prompt_embeds.cpu().float(), atol=1e-3), f"Cached prompt embeds for prompt {prompt} do not match the computed prompt embeds. Make sure that the cached prompt embeds are correct., {torch.mean(torch.abs(cached_prompt_embeds['prompt_embeds'].cpu().float() - prompt_embeds.cpu().float())) = }, {torch.mean(torch.abs(cached_prompt_embeds['prompt_embeds'].cpu().float())) = }"
1202
+ assert torch.allclose(cached_prompt_embeds["pooled_prompt_embeds"].cpu().float(), pooled_prompt_embeds.cpu().float(), atol=1e-3), f"Cached pooled prompt embeds for prompt {prompt} do not match the computed pooled prompt embeds. Make sure that the cached pooled prompt embeds are correct., {torch.mean(torch.abs(cached_prompt_embeds['pooled_prompt_embeds'].cpu().float() - pooled_prompt_embeds.cpu().float())) = }"
1203
+ print(f"fucking passed the test!")
1204
+ else:
1205
+ assert "prompt_embeds" in batch and "pooled_prompt_embeds" in batch, "Make sure that the dataloader returns `prompt_embeds` and `pooled_prompt_embeds` when `inference_embeds_dir` is not None."
1206
+ prompt_embeds = batch["prompt_embeds"]
1207
+ pooled_prompt_embeds = batch["pooled_prompt_embeds"]
1208
+ text_ids = torch.zeros((batch["prompt_embeds"].shape[1], 3))
1209
+ prompt_embeds = prompt_embeds.to(dtype=vae.dtype, device=accelerator.device)
1210
+ pooled_prompt_embeds = pooled_prompt_embeds.to(dtype=vae.dtype, device=accelerator.device)
1211
+ text_ids = text_ids.to(dtype=vae.dtype, device=accelerator.device)
1212
+
1213
+
1214
+ pixel_values = batch["pixel_values"].to(dtype=vae.dtype)
1215
+ height_ = 2 * (int(pixel_values.shape[-2]) // vae_scale_factor)
1216
+ width_ = 2 * (int(pixel_values.shape[-1]) // vae_scale_factor)
1217
+
1218
+ model_input = vae.encode(pixel_values).latent_dist.sample()
1219
+ model_input = (model_input - vae_config_shift_factor) * vae_config_scaling_factor
1220
+ model_input = model_input.to(dtype=weight_dtype)
1221
+
1222
+ latent_image_ids, cond_latent_image_ids = resize_position_encoding(
1223
+ model_input.shape[0],
1224
+ height_,
1225
+ width_,
1226
+ height_cond,
1227
+ width_cond,
1228
+ accelerator.device,
1229
+ weight_dtype,
1230
+ )
1231
+
1232
+ # Sample noise that we'll add to the latents
1233
+ noise = torch.randn_like(model_input)
1234
+ bsz = model_input.shape[0]
1235
+
1236
+ # Sample a random timestep for each image
1237
+ # for weighting schemes where we sample timesteps non-uniformly
1238
+ u = compute_density_for_timestep_sampling(
1239
+ weighting_scheme=args.weighting_scheme,
1240
+ batch_size=bsz,
1241
+ logit_mean=args.logit_mean,
1242
+ logit_std=args.logit_std,
1243
+ mode_scale=args.mode_scale,
1244
+ )
1245
+ indices = (u * noise_scheduler_copy.config.num_train_timesteps).long()
1246
+ timesteps = noise_scheduler_copy.timesteps[indices].to(device=model_input.device)
1247
+
1248
+ # Add noise according to flow matching.
1249
+ # zt = (1 - texp) * x + texp * z1
1250
+ sigmas = get_sigmas(timesteps, n_dim=model_input.ndim, dtype=model_input.dtype)
1251
+ noisy_model_input = (1.0 - sigmas) * model_input + sigmas * noise
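+ # Flow-matching interpolation: sigma=0 keeps the clean latents and sigma=1 is pure noise;
+ # e.g. sigma=0.25 mixes 75% signal with 25% noise.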
1252
+
1253
+ packed_noisy_model_input = FluxPipeline._pack_latents(
1254
+ noisy_model_input,
1255
+ batch_size=model_input.shape[0],
1256
+ num_channels_latents=model_input.shape[1],
1257
+ height=model_input.shape[2],
1258
+ width=model_input.shape[3],
1259
+ )
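+ # _pack_latents folds each 2x2 latent patch into the channel dim, turning the
+ # (B, C, H, W) latents into a (B, (H/2)*(W/2), 4*C) token sequence for the transformer.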
1260
+
1261
+ latent_image_ids_to_concat = [latent_image_ids]
1262
+ packed_cond_model_input_to_concat = []
1263
+
1264
+ if args.spatial_column is not None:
1265
+ # in case the condition is spatial
1266
+ cond_pixel_values = batch["cond_pixel_values"].to(dtype=vae.dtype)
1267
+ cond_input = vae.encode(cond_pixel_values).latent_dist.sample()
1268
+ cond_input = (cond_input - vae_config_shift_factor) * vae_config_scaling_factor
1269
+ cond_input = cond_input.to(dtype=weight_dtype)
1270
+ # number of conditions in the concatenated condition image
1271
+ cond_number = cond_pixel_values.shape[-2] // args.cond_size
1272
+ cond_latent_image_ids = torch.concat([cond_latent_image_ids for _ in range(cond_number)], dim=-2)
1273
+ latent_image_ids_to_concat.append(cond_latent_image_ids)
1274
+
1275
+ packed_cond_model_input = FluxPipeline._pack_latents(
1276
+ cond_input,
1277
+ batch_size=cond_input.shape[0],
1278
+ num_channels_latents=cond_input.shape[1],
1279
+ height=cond_input.shape[2],
1280
+ width=cond_input.shape[3],
1281
+ )
1282
+ packed_cond_model_input_to_concat.append(packed_cond_model_input)
1283
+ else:
1284
+ cond_input = None
1285
+
1286
+ latent_image_ids = torch.concat(latent_image_ids_to_concat, dim=-2)
1287
+ cond_packed_noisy_model_input = torch.concat(packed_cond_model_input_to_concat, dim=-2)  # assumes at least one condition stream; torch.concat would fail on an empty list
1288
+
1289
+ # handle guidance
1290
+ if accelerator.unwrap_model(transformer).config.guidance_embeds:
1291
+ guidance = torch.tensor([args.guidance_scale], device=accelerator.device)
1292
+ guidance = guidance.expand(model_input.shape[0])
1293
+ else:
1294
+ guidance = None
1295
+
1296
+ # Visualize training data before transformer forward pass
1297
+ if accelerator.is_main_process and args.debug and num_training_visualizations > 0 and global_step % 5 == 0:
1298
+ visualize_training_data(
1299
+ batch=batch,
1300
+ vae=vae,
1301
+ model_input=model_input,
1302
+ noisy_model_input=noisy_model_input,
1303
+ cond_input=cond_input,
1304
+ args=args,
1305
+ global_step=global_step,
1306
+ accelerator=accelerator
1307
+ )
1308
+ num_training_visualizations -= 1
1309
+
1310
+ # Predict the noise residual
1311
+ model_pred = transformer(
1312
+ hidden_states=packed_noisy_model_input,
1313
+ # YiYi notes: divide it by 1000 for now because we scale it by 1000 in the transformer model (we should not keep it, but I want to keep the inputs the same for the model for testing)
1314
+ cond_hidden_states=cond_packed_noisy_model_input,
1315
+ timestep=timesteps / 1000,
1316
+ guidance=guidance,
1317
+ pooled_projections=pooled_prompt_embeds,
1318
+ encoder_hidden_states=prompt_embeds,
1319
+ txt_ids=text_ids,
1320
+ img_ids=latent_image_ids,
1321
+ return_dict=False,
1322
+ call_ids=batch["call_ids"],
1323
+ cuboids_segmasks=batch["cuboids_segmasks"],
1324
+ )[0]
1325
+
1326
+ model_pred = FluxPipeline._unpack_latents(
1327
+ model_pred,
1328
+ height=int(pixel_values.shape[-2]),
1329
+ width=int(pixel_values.shape[-1]),
1330
+ vae_scale_factor=vae_scale_factor,
1331
+ )
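+ # Inverse of _pack_latents: the predicted token sequence is reshaped back to
+ # (B, C, H, W) so the loss can be computed against the latent-space target.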
1332
+
1333
+ # these weighting schemes use a uniform timestep sampling
1334
+ # and instead post-weight the loss
1335
+ weighting = compute_loss_weighting_for_sd3(weighting_scheme=args.weighting_scheme, sigmas=sigmas)
1336
+
1337
+ # flow matching loss
1338
+ target = noise - model_input
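+ # Rectified-flow velocity target: d/dsigma[(1 - sigma) * x + sigma * noise] = noise - x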
1339
+
1340
+ # Compute regular loss.
1341
+ loss = torch.mean(
1342
+ (weighting.float() * (model_pred.float() - target.float()) ** 2).reshape(target.shape[0], -1),
1343
+ 1,
1344
+ )
1345
+
1346
+ loss = loss.mean()
1347
+ accelerator.backward(loss)
1348
+ if accelerator.sync_gradients:
1349
+ params_to_clip = (transformer.parameters())
1350
+ accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
1351
+
1352
+ optimizer.step()
1353
+ lr_scheduler.step()
1354
+ optimizer.zero_grad()
1355
+
1356
+ # Checks if the accelerator has performed an optimization step behind the scenes
1357
+ if accelerator.sync_gradients:
1358
+ progress_bar.update(1)
1359
+ global_step += 1
1360
+
1361
+ if accelerator.is_main_process:
1362
+ if global_step % args.checkpointing_steps == 0:
1363
+ # _before_ saving state, check if this save would set us over the `checkpoints_total_limit`
1364
+ save_path = os.path.join(args.output_dir, f"epoch-{epoch}__checkpoint-{global_step}")
1365
+ os.makedirs(save_path, exist_ok=True)
1366
+ unwrapped_model_state = accelerator.unwrap_model(transformer).state_dict()
1367
+ lora_state_dict = {k:unwrapped_model_state[k] for k in unwrapped_model_state.keys() if '_lora' in k}
1368
+ save_file(
1369
+ lora_state_dict,
1370
+ os.path.join(save_path, "lora.safetensors")
1371
+ )
1372
+ accelerator.save_state(save_path)
1373
+ os.remove(osp.join(save_path, "model.safetensors"))  # drop the full transformer weights; only lora.safetensors is kept
1374
+ logger.info(f"Saved state to {save_path}")
1375
+
1376
+ logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
1377
+ progress_bar.set_postfix(**logs)
1378
+ accelerator.log(logs, step=global_step)
1379
+
1380
+ save_path = os.path.join(args.output_dir, f"epoch-{epoch}__checkpoint-{global_step}")
1381
+ os.makedirs(save_path, exist_ok=True)
1382
+ unwrapped_model_state = accelerator.unwrap_model(transformer).state_dict()
1383
+ lora_state_dict = {k:unwrapped_model_state[k] for k in unwrapped_model_state.keys() if '_lora' in k}
1384
+ save_file(
1385
+ lora_state_dict,
1386
+ os.path.join(save_path, "lora.safetensors")
1387
+ )
1388
+ accelerator.save_state(save_path)
1389
+ os.remove(osp.join(save_path, "model.safetensors"))
1390
+ logger.info(f"Saved state to {save_path}")
1391
+ accelerator.wait_for_everyone()
1392
+ accelerator.end_training()
1393
+
1394
+
1395
+ if __name__ == "__main__":
1396
+ args = parse_args()
1397
+ main(args)