AbstractPhil committed on
Commit
a59941b
·
verified ·
1 Parent(s): 3fea588

way better trainer

Browse files
Files changed (1) hide show
  1. flow_leco_trainer.py +425 -396
flow_leco_trainer.py CHANGED
@@ -1,16 +1,19 @@
1
  """
2
- Lune LECO Trainer - Proper Concept Group Implementation
 
3
  """
4
 
5
  import os
6
  import json
7
  import datetime
 
8
  from dataclasses import dataclass, asdict, field
9
- from typing import List, Literal
10
  from tqdm.auto import tqdm
11
- from enum import Enum
12
 
13
  import torch
 
14
  import torch.nn.functional as F
15
  from torch.utils.tensorboard import SummaryWriter
16
  from safetensors.torch import save_file
@@ -20,100 +23,144 @@ from transformers import CLIPTextModel, CLIPTokenizer
20
  from huggingface_hub import hf_hub_download
21
 
22
 
23
- class ActionType(str, Enum):
24
- """LECO action types"""
25
- ERASE = "erase" # sources → empty
26
- ENHANCE = "enhance" # sources → amplified
27
- REPLACE = "replace" # sources → target
28
- NEUTRALIZE = "neutralize" # sources → neutral
29
-
30
 
31
- @dataclass
32
- class ConceptGroup:
33
- """
34
- A group of related concepts to transform together.
35
-
36
- Training strategy:
37
- - Sample from sources: these are the concepts to modify
38
- - Transform to target: what they should become
39
- - Use neutral as intermediate: optional neutral reference point
40
- - Preserve negatives: concepts that should NOT be affected
41
-
42
- Examples:
43
- # Erase multiple anime styles
44
- ConceptGroup(
45
- sources=["anime", "manga", "cartoon"],
46
- target="",
47
- negatives=["realistic", "photograph"],
48
- weight=1.0
49
- )
50
-
51
- # Replace artists
52
- ConceptGroup(
53
- sources=["van gogh", "picasso"],
54
- target="monet",
55
- neutral="painting",
56
- negatives=["photograph", "digital art"],
57
- weight=1.0
58
- )
59
-
60
- # Neutralize NSFW to safe
61
- ConceptGroup(
62
- sources=["nsfw", "nude", "explicit"],
63
- target="safe",
64
- neutral="person",
65
- negatives=["portrait", "art", "figure drawing"],
66
- weight=2.0
67
- )
68
- """
69
- sources: List[str] # Concepts to modify (sampled during training)
70
- target: str = "" # What to transform to (empty = erase)
71
- neutral: str = "" # Optional neutral reference point
72
- negatives: List[str] = field(default_factory=list) # Concepts to preserve
73
- weight: float = 1.0 # Group importance
74
- preservation_weight: float = 0.5 # How strongly to preserve negatives
75
 
76
 
77
  @dataclass
78
- class LECOConfig:
79
- # Model paths
80
  output_dir: str = "./leco_outputs"
81
  base_model_repo: str = "AbstractPhil/sd15-flow-lune-flux"
82
  base_checkpoint: str = "sd15_flow_flux_t2_6_pose_t4_6_port_t1_4_s18765.pt"
 
83
 
84
- # HuggingFace upload
85
- hf_repo_id: str = "AbstractPhil/lune-leco-adapters"
86
- upload_to_hub: bool = False
87
 
88
- # Training data
89
- action: ActionType = ActionType.ERASE
90
- concept_groups: List[ConceptGroup] = field(default_factory=list)
91
-
92
- # LoRA architecture
93
- lora_rank: int = 4
94
  lora_alpha: float = 1.0
95
- lora_dropout: float = 0.0
96
- training_method: Literal["full", "selfattn", "xattn", "noxattn", "innoxattn"] = "xattn"
97
 
98
- # Training hyperparameters
99
  seed: int = 42
100
- iterations: int = 1000
101
- lr: float = 1e-4
 
 
 
102
 
103
- # Sampling strategy
104
- sources_per_step: int = 2 # How many source concepts to sample per step
 
105
 
106
- # Flow-matching parameters
107
  shift: float = 2.5
108
  min_timestep: float = 0.0
109
  max_timestep: float = 1000.0
110
-
111
- # Resolution
112
  resolution: int = 512
113
 
114
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
115
  def get_target_modules(training_method: str) -> List[str]:
116
- """Get layer names to inject LoRA based on training method."""
117
  attn1 = ["attn1.to_q", "attn1.to_k", "attn1.to_v", "attn1.to_out.0"]
118
  attn2 = ["attn2.to_q", "attn2.to_k", "attn2.to_v", "attn2.to_out.0"]
119
 
@@ -127,8 +174,8 @@ def get_target_modules(training_method: str) -> List[str]:
127
  return method_map.get(training_method, attn1 + attn2)
128
 
129
 
130
- def create_lora_layers(unet: torch.nn.Module, config: LECOConfig):
131
- """Create LoRA layers in ComfyUI/A1111 compatible format."""
132
  target_modules = get_target_modules(config.training_method)
133
  lora_state = {}
134
  trainable_params = []
@@ -136,11 +183,13 @@ def create_lora_layers(unet: torch.nn.Module, config: LECOConfig):
136
  def get_lora_key(module_path: str) -> str:
137
  return f"lora_unet_{module_path.replace('.', '_')}"
138
 
 
 
139
  for name, module in unet.named_modules():
140
  if not any(target in name for target in target_modules):
141
  continue
142
 
143
- if not isinstance(module, torch.nn.Linear):
144
  continue
145
 
146
  lora_key = get_lora_key(name)
@@ -148,11 +197,11 @@ def create_lora_layers(unet: torch.nn.Module, config: LECOConfig):
148
  out_dim = module.out_features
149
  rank = config.lora_rank
150
 
151
- lora_down = torch.nn.Parameter(torch.zeros(rank, in_dim))
152
- lora_up = torch.nn.Parameter(torch.zeros(out_dim, rank))
153
 
154
- torch.nn.init.kaiming_uniform_(lora_down, a=1.0)
155
- torch.nn.init.zeros_(lora_up)
156
 
157
  lora_state[f"{lora_key}.lora_down.weight"] = lora_down
158
  lora_state[f"{lora_key}.lora_up.weight"] = lora_up
@@ -165,8 +214,8 @@ def create_lora_layers(unet: torch.nn.Module, config: LECOConfig):
165
  return lora_state, trainable_params
166
 
167
 
168
- def apply_lora_hooks(unet: torch.nn.Module, lora_state: dict, scale: float = 1.0) -> list:
169
- """Apply LoRA using forward hooks."""
170
  handles = []
171
 
172
  for key in lora_state:
@@ -201,197 +250,195 @@ def remove_lora_hooks(handles: list):
201
  handle.remove()
202
 
203
 
204
- @torch.no_grad()
205
- def encode_text(prompt: str, tokenizer, text_encoder, device) -> torch.Tensor:
206
- """Encode text to CLIP embeddings"""
207
- tokens = tokenizer(
208
- prompt,
209
- padding="max_length",
210
- max_length=tokenizer.model_max_length,
211
- truncation=True,
212
- return_tensors="pt"
213
- ).input_ids.to(device)
214
-
215
- return text_encoder(tokens)[0]
216
-
217
 
218
- def compute_concept_group_loss(
219
- unet: torch.nn.Module,
220
- lora_state: dict,
221
- group: ConceptGroup,
222
  tokenizer,
223
  text_encoder,
224
- config: LECOConfig,
225
  device: str = "cuda"
226
  ):
227
- """
228
- Compute LECO loss for a concept group.
229
-
230
- Strategy:
231
- 1. Sample source concepts from group.sources
232
- 2. Compute transformation: source → target (using neutral if provided)
233
- 3. Preserve negatives (ensure LoRA doesn't affect them)
234
-
235
- The LoRA learns to transform ALL sources to the same target.
236
- """
237
- import random
238
 
239
- # Sample source concepts for this step
240
- num_sources = min(config.sources_per_step, len(group.sources))
241
- sampled_sources = random.sample(group.sources, num_sources)
242
-
243
- # Sample timestep (shared for this group)
244
  min_sigma = config.min_timestep / 1000.0
245
  max_sigma = config.max_timestep / 1000.0
246
- sigma = min_sigma + torch.rand(1, device=device) * (max_sigma - min_sigma)
 
 
 
 
247
  sigma = (config.shift * sigma) / (1 + (config.shift - 1) * sigma)
248
  timestep = sigma * 1000.0
249
  sigma_expanded = sigma.view(1, 1, 1, 1)
250
 
251
- total_loss = 0
252
- metrics = {
253
- "source_loss": 0,
254
- "preservation_loss": 0,
255
- "sources_processed": 0,
256
- "negatives_processed": 0
257
- }
258
 
259
- # === SOURCE TRANSFORMATION LOSS ===
260
- for source_concept in sampled_sources:
261
- noise = torch.randn(1, 4, config.resolution // 8, config.resolution // 8, device=device)
262
- noisy_input = sigma_expanded * noise
263
-
264
- # Encode prompts
265
- source_emb = encode_text(source_concept, tokenizer, text_encoder, device)
266
- target_emb = encode_text(group.target, tokenizer, text_encoder, device)
267
-
268
- # Optional: use neutral as intermediate reference
269
- if group.neutral:
270
- neutral_emb = encode_text(group.neutral, tokenizer, text_encoder, device)
271
- else:
272
- neutral_emb = None
273
-
274
- # Compute target direction WITHOUT LoRA
275
- with torch.no_grad():
276
- pred_source = unet(
277
- noisy_input, timestep,
278
- encoder_hidden_states=source_emb,
279
- return_dict=False
280
- )[0]
281
-
282
- pred_target = unet(
283
- noisy_input, timestep,
284
- encoder_hidden_states=target_emb,
285
- return_dict=False
286
- )[0]
287
-
288
- # Determine transformation direction
289
- if group.neutral and neutral_emb is not None:
290
- # Use neutral as reference: source → neutral → target
291
- pred_neutral = unet(
292
- noisy_input, timestep,
293
- encoder_hidden_states=neutral_emb,
294
- return_dict=False
295
- )[0]
296
-
297
- # Two-step transformation
298
- step1 = pred_neutral - pred_source # source → neutral
299
- step2 = pred_target - pred_neutral # neutral → target
300
- target_delta = step1 + step2 # combined transformation
301
- else:
302
- # Direct transformation: source → target
303
- target_delta = pred_target - pred_source
304
-
305
- # Apply LoRA and measure its effect
306
- handles = apply_lora_hooks(unet, lora_state, scale=1.0)
307
-
308
- try:
309
- pred_with_lora = unet(
310
- noisy_input, timestep,
311
- encoder_hidden_states=source_emb,
312
- return_dict=False
313
- )[0]
314
- finally:
315
- remove_lora_hooks(handles)
316
-
317
- # LoRA contribution
318
- lora_delta = pred_with_lora - pred_source
319
 
320
- # Loss: LoRA should reproduce the transformation
321
- source_loss = F.mse_loss(lora_delta, target_delta)
322
- total_loss += source_loss * group.weight
323
- metrics["source_loss"] += source_loss.item()
324
- metrics["sources_processed"] += 1
325
-
326
- # === PRESERVATION LOSS (negatives should remain unchanged) ===
327
- for negative_concept in group.negatives:
328
- noise = torch.randn(1, 4, config.resolution // 8, config.resolution // 8, device=device)
329
- noisy_input = sigma_expanded * noise
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
330
 
331
- negative_emb = encode_text(negative_concept, tokenizer, text_encoder, device)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
332
 
333
- # Baseline without LoRA
334
- with torch.no_grad():
335
- pred_negative = unet(
336
- noisy_input, timestep,
337
- encoder_hidden_states=negative_emb,
338
- return_dict=False
339
- )[0]
340
 
341
- # With LoRA
342
- handles = apply_lora_hooks(unet, lora_state, scale=1.0)
343
 
344
- try:
345
- pred_with_lora = unet(
346
- noisy_input, timestep,
347
- encoder_hidden_states=negative_emb,
348
- return_dict=False
349
- )[0]
350
- finally:
351
- remove_lora_hooks(handles)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
352
 
353
- # Penalize any change
354
- preservation_loss = F.mse_loss(pred_with_lora, pred_negative)
355
- total_loss += preservation_loss * group.preservation_weight
356
- metrics["preservation_loss"] += preservation_loss.item()
357
- metrics["negatives_processed"] += 1
358
 
359
- # Average metrics
360
- if metrics["sources_processed"] > 0:
361
- metrics["source_loss"] /= metrics["sources_processed"]
362
- if metrics["negatives_processed"] > 0:
363
- metrics["preservation_loss"] /= metrics["negatives_processed"]
364
 
365
- metrics["timestep"] = timestep.item()
366
- metrics["sigma"] = sigma.item()
 
 
 
 
 
 
 
 
 
367
 
368
  return total_loss, metrics
369
 
370
 
371
- def train_leco(config: LECOConfig):
372
- """Main training loop with proper concept groups"""
 
 
 
 
373
  device = "cuda"
374
  torch.manual_seed(config.seed)
375
 
376
- if not config.concept_groups:
377
- raise ValueError("No concept groups specified!")
378
 
379
- # Validate concept groups
380
- for group in config.concept_groups:
381
- if not group.sources:
382
- raise ValueError("Each concept group must have at least one source concept")
383
 
384
- # Setup output
385
  timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
386
-
387
- # Create name from first group
388
- first_group = config.concept_groups[0]
389
- source_names = "_".join([s.replace(" ", "")[:10] for s in first_group.sources[:2]])
390
- if len(first_group.sources) > 2:
391
- source_names += f"_plus{len(first_group.sources)-2}"
392
-
393
- run_name = f"{config.action.value}_{source_names}_{timestamp}"
394
- output_dir = os.path.join(config.output_dir, run_name)
395
  os.makedirs(output_dir, exist_ok=True)
396
 
397
  writer = SummaryWriter(log_dir=output_dir, flush_secs=60)
@@ -400,11 +447,33 @@ def train_leco(config: LECOConfig):
400
  json.dump(asdict(config), f, indent=2)
401
 
402
  print("="*80)
403
- print(f"LECO Training: {config.action.value.upper()}")
 
 
404
  print("="*80)
405
 
406
- # Load model
407
- print("\nLoading base model...")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
408
  checkpoint_path = hf_hub_download(
409
  repo_id=config.base_model_repo,
410
  filename=config.base_checkpoint,
@@ -420,14 +489,28 @@ def train_leco(config: LECOConfig):
420
 
421
  student_dict = checkpoint["student"]
422
  cleaned_dict = {k[5:] if k.startswith("unet.") else k: v for k, v in student_dict.items()}
423
- unet.load_state_dict(cleaned_dict, strict=False)
 
 
 
424
  unet = unet.to(device)
425
  unet.requires_grad_(False)
426
  unet.eval()
427
- print("✓ Loaded UNet")
428
 
429
- # Load CLIP
430
- print("Loading CLIP text encoder...")
 
 
 
 
 
 
 
 
 
 
 
 
431
  tokenizer = CLIPTokenizer.from_pretrained(
432
  "runwayml/stable-diffusion-v1-5", subfolder="tokenizer"
433
  )
@@ -439,77 +522,99 @@ def train_leco(config: LECOConfig):
439
  text_encoder.eval()
440
  print("✓ Loaded CLIP")
441
 
442
- # Create LoRA layers
443
- print(f"\nInjecting LoRA (rank={config.lora_rank}, alpha={config.lora_alpha})...")
444
- lora_state, trainable_params = create_lora_layers(unet, config)
 
 
 
 
 
 
445
 
446
- # Move Parameters to device IN-PLACE
447
  print(f"Moving LoRA parameters to {device}...")
448
  for param in trainable_params:
449
  param.data = param.data.to(device)
450
 
451
- # Move other tensors to device
452
  for key, value in lora_state.items():
453
- if isinstance(value, torch.Tensor) and not isinstance(value, torch.nn.Parameter):
454
  lora_state[key] = value.to(device)
455
 
456
  optimizer = torch.optim.AdamW(trainable_params, lr=config.lr, weight_decay=0.01)
457
 
458
- # Print config
459
  print(f"\nTraining Configuration:")
460
- print(f" Action: {config.action.value}")
461
- print(f" Concept groups: {len(config.concept_groups)}")
462
- for i, group in enumerate(config.concept_groups, 1):
463
- print(f"\n Group {i} (weight: {group.weight}):")
464
- print(f" Sources: {', '.join(group.sources)}")
465
- print(f" Target: '{group.target}'" if group.target else " Target: (erase)")
466
- if group.neutral:
467
- print(f" Neutral: '{group.neutral}'")
468
- if group.negatives:
469
- print(f" Preserve: {', '.join(group.negatives)}")
470
 
471
  print(f"\n Iterations: {config.iterations}")
 
 
472
  print(f" Learning rate: {config.lr}")
473
- print(f" Training method: {config.training_method}")
474
- print(f" Sources per step: {config.sources_per_step}")
475
  print("="*80 + "\n")
476
 
477
- # Training loop
478
  progress = tqdm(range(config.iterations), desc="Training")
479
 
480
  for step in progress:
481
- import random
482
-
483
- # Sample a concept group
484
- group = random.choice(config.concept_groups)
485
 
486
- # Compute loss for this group
487
- loss, metrics = compute_concept_group_loss(
488
- unet, lora_state, group,
489
- tokenizer, text_encoder, config, device
 
 
490
  )
491
 
492
- # Backprop
493
  loss.backward()
494
  grad_norm = torch.nn.utils.clip_grad_norm_(trainable_params, max_norm=1.0)
495
  optimizer.step()
496
  optimizer.zero_grad()
497
 
498
- # Logging
499
  writer.add_scalar("loss/total", loss.item(), step)
500
- writer.add_scalar("loss/source", metrics["source_loss"], step)
501
- writer.add_scalar("loss/preservation", metrics["preservation_loss"], step)
502
  writer.add_scalar("grad_norm", grad_norm.item(), step)
 
503
 
504
  progress.set_postfix({
505
  "loss": f"{loss.item():.4f}",
506
- "src": f"{metrics['source_loss']:.4f}",
507
- "pres": f"{metrics['preservation_loss']:.4f}",
 
508
  "grad": f"{grad_norm.item():.3f}"
509
  })
510
 
511
- if (step + 1) % 200 == 0 or step == config.iterations - 1:
512
- save_checkpoint(lora_state, config, output_dir, step + 1, source_names)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
513
 
514
  writer.close()
515
 
@@ -521,120 +626,44 @@ def train_leco(config: LECOConfig):
521
  return output_dir
522
 
523
 
524
- def save_checkpoint(lora_state, config, output_dir, step, name_suffix):
525
- """Save LoRA in SafeTensors format"""
526
- save_dict = {}
527
-
528
- for key, value in lora_state.items():
529
- if isinstance(value, torch.Tensor) and not key.endswith("._module"):
530
- save_dict[key] = value.detach().cpu()
531
-
532
- # Build metadata
533
- all_sources = []
534
- all_targets = []
535
- all_negatives = []
536
- for group in config.concept_groups:
537
- all_sources.extend(group.sources)
538
- if group.target:
539
- all_targets.append(group.target)
540
- all_negatives.extend(group.negatives)
541
-
542
- metadata = {
543
- "ss_network_module": "networks.lora",
544
- "ss_network_dim": str(config.lora_rank),
545
- "ss_network_alpha": str(config.lora_alpha),
546
- "ss_base_model": "runwayml/stable-diffusion-v1-5",
547
- "ss_training_method": config.training_method,
548
- "leco_action": config.action.value,
549
- "leco_sources": ", ".join(all_sources),
550
- "leco_targets": ", ".join(all_targets) if all_targets else "",
551
- "leco_negatives": ", ".join(all_negatives),
552
- "leco_step": str(step),
553
- "leco_num_groups": str(len(config.concept_groups))
554
- }
555
-
556
- filename = f"leco_{name_suffix}_r{config.lora_rank}_s{step}.safetensors"
557
- filepath = os.path.join(output_dir, filename)
558
-
559
- save_file(save_dict, filepath, metadata=metadata)
560
- print(f"\n✓ Saved: {filepath}")
561
-
562
-
563
- # ============================================================================
564
- # EXAMPLE CONFIGURATIONS
565
- # ============================================================================
566
-
567
  if __name__ == "__main__":
568
 
569
- # Example 1: Erase anime styles (multiple sources → empty)
570
- config_erase_anime = LECOConfig(
571
- action=ActionType.ERASE,
572
- concept_groups=[
573
- ConceptGroup(
574
- sources=["anime", "manga", "cartoon"],
575
- target="", # Erase
576
- negatives=["realistic", "photograph", "painting"],
577
- weight=1.0
578
- )
579
- ],
580
- iterations=1000,
581
- lora_rank=4,
582
- training_method="xattn" # Cross-attention for semantic content
583
- )
584
 
585
- # Example 2: Replace artists (multiple sources → single target)
586
- config_replace_artists = LECOConfig(
587
- action=ActionType.REPLACE,
588
- concept_groups=[
589
- ConceptGroup(
590
- sources=["van gogh", "picasso", "dali"],
591
- target="monet",
592
- neutral="painting", # Use painting as neutral reference
593
- negatives=["photograph", "digital art"],
594
- weight=1.0
595
- )
596
- ],
597
- iterations=800,
598
- lora_rank=8,
599
- training_method="xattn"
600
- )
601
 
602
- # Example 3: Neutralize NSFW (multiple sources → safe target)
603
- config_nsfw = LECOConfig(
604
- action=ActionType.NEUTRALIZE,
605
- concept_groups=[
606
- ConceptGroup(
607
- sources=["nsfw", "nude", "explicit", "naked"],
608
- target="clothed",
609
- neutral="person",
610
- negatives=["portrait", "figure drawing", "classical art", "sculpture"],
611
- weight=2.0,
612
- preservation_weight=0.8 # Strong preservation
613
- )
614
- ],
615
- iterations=1200,
616
- lora_rank=4,
617
- training_method="full"
618
  )
619
 
620
- # Example 4: Your original request - weird food combos
621
- config_food = LECOConfig(
622
- action=ActionType.ERASE,
623
- concept_groups=[
624
- ConceptGroup(
625
- sources=["potato chicken sandwich", "taco pizza", "banana sushi"],
626
- target="",
627
- neutral="food",
628
- negatives=["normal sandwiches", "table", "walls", "plates", "restaurant"],
629
- weight=1.0,
630
- preservation_weight=1.5
631
- )
632
- ],
633
- iterations=1000,
634
- lora_rank=4,
635
  training_method="xattn",
636
- sources_per_step=2 # Sample 2 weird foods per training step
 
 
 
 
 
 
 
 
 
637
  )
638
 
639
- # Train
640
- train_leco(config_erase_anime)
 
1
  """
2
+ LECO Attribute Binding Trainer - COMPLETE WITH PROPER FLOW MATCHING
3
+ Complete script with correct flow matching SNR and velocity prediction
4
  """
5
 
6
  import os
7
  import json
8
  import datetime
9
+ import random
10
  from dataclasses import dataclass, asdict, field
11
+ from typing import List, Tuple
12
  from tqdm.auto import tqdm
13
+ from itertools import product
14
 
15
  import torch
16
+ import torch.nn as nn
17
  import torch.nn.functional as F
18
  from torch.utils.tensorboard import SummaryWriter
19
  from safetensors.torch import save_file
 
23
  from huggingface_hub import hf_hub_download
24
 
25
 
26
+ # ============================================================================
27
+ # DATA STRUCTURES
28
+ # ============================================================================
 
 
 
 
29
 
30
@dataclass(frozen=True)
class AttributePair:
    """One attribute combination whose binding must remain distinct.

    Frozen (immutable and hashable) so pairs can safely be shared,
    deduplicated, and reused across training steps.

    Attributes:
        attr1: First attribute phrase (typically color + noun, e.g. "red hat").
        attr2: Second attribute phrase paired with attr1.
        negatives: Wrong combinations to push away from (tuple, not list,
            to keep the dataclass hashable).
        weight: Relative importance of this pair in the loss.
    """
    attr1: str
    attr2: str
    negatives: Tuple[str, ...] = ()
    weight: float = 1.0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
37
 
38
 
39
@dataclass
class AttributeBindingConfig:
    """Config for attribute binding training"""
    # --- Model / output locations ---
    output_dir: str = "./leco_outputs"
    base_model_repo: str = "AbstractPhil/sd15-flow-lune-flux"
    base_checkpoint: str = "sd15_flow_flux_t2_6_pose_t4_6_port_t1_4_s18765.pt"
    # Prefix used when naming saved runs/checkpoints
    name_prefix: str = "leco"

    # Attribute pairs to bind; training fails fast if this is empty
    attribute_pairs: List[AttributePair] = field(default_factory=list)

    # --- LoRA architecture ---
    lora_rank: int = 8
    lora_alpha: float = 1.0
    # Which attention modules receive LoRA (see get_target_modules)
    training_method: str = "xattn"

    # --- Optimization ---
    seed: int = 42
    iterations: int = 500
    save_every: int = 250
    lr: float = 2e-4
    # How many attribute pairs are batched into each training step
    pairs_per_batch: int = 4
    # Negatives sampled per positive pair each step
    negatives_per_positive: int = 2

    # Min-SNR parameters
    use_min_snr: bool = True
    min_snr_gamma: float = 5.0

    # Flow matching parameters
    # shift warps the sampled sigma: sigma' = shift*sigma / (1 + (shift-1)*sigma)
    shift: float = 2.5
    # Timestep range in [0, 1000]; divided by 1000 to get the sigma range
    min_timestep: float = 0.0
    max_timestep: float = 1000.0

    # Image resolution in pixels; latents are resolution // 8 per side
    resolution: int = 512
69
 
70
 
71
@dataclass
class LECOConfig:
    """Minimal configuration consumed by the LoRA creation helpers.

    Only the three knobs that create_lora_layers / get_target_modules
    actually read; the full trainer uses AttributeBindingConfig instead.
    """
    lora_rank: int = 4              # LoRA bottleneck dimension
    lora_alpha: float = 1.0         # LoRA alpha scaling value
    training_method: str = "xattn"  # selects target modules (see get_target_modules)
77
+
78
+
79
+ # ============================================================================
80
+ # ATTRIBUTE COMBINATION HELPERS
81
+ # ============================================================================
82
+
83
def extract_color(text: str) -> "str | None":
    """Return the first known color appearing as a whole word in *text*.

    Colors are matched case-insensitively and scanned in a fixed priority
    order, so e.g. "red and blue" yields "red".

    Args:
        text: Free-form attribute phrase, e.g. "red hat".

    Returns:
        The matched color name (lowercase), or None when no color is found.
    """
    import re  # local import keeps this fix self-contained in the function

    colors = [
        "red", "blue", "green", "yellow", "purple", "orange", "pink",
        "black", "white", "brown", "blonde", "silver", "gold", "cyan",
        "magenta", "teal", "lavender", "gray", "grey", "beige", "navy",
        "maroon", "turquoise", "violet", "indigo", "crimson"
    ]
    text_lower = text.lower()
    for color in colors:
        # FIX: whole-word match. The previous plain substring test falsely
        # detected e.g. "red" inside "tired" or "gold" inside "goldfish".
        if re.search(rf"\b{color}\b", text_lower):
            return color
    return None
96
+
97
+
98
def generate_smart_negatives(attr1: str, attr2: str, all_negatives: "List[str] | None" = None) -> List[str]:
    """Automatically generate wrong combinations for an attribute pair.

    When both attributes contain distinct colors, emits the color-swapped
    variants (the mis-bound prompts the LoRA should push away from). Any
    universal negatives are appended after the correct pair as extra
    wrong prompts.

    Args:
        attr1: First attribute phrase (e.g. "red hat").
        attr2: Second attribute phrase (e.g. "blue shirt").
        all_negatives: Optional universal negatives appended to each pair.

    Returns:
        Deduplicated negative prompts, in deterministic insertion order.
    """
    negatives = []

    color1 = extract_color(attr1)
    color2 = extract_color(attr2)

    if color1 and color2 and color1 != color2:
        swapped_attr1 = attr1.replace(color1, color2)
        swapped_attr2 = attr2.replace(color2, color1)
        negatives.append(f"{swapped_attr1}, {swapped_attr2}")
        negatives.append(f"{attr1}, {attr2.replace(color2, color1)}")
        negatives.append(f"{attr1.replace(color1, color2)}, {attr2}")

    # Add universal negatives to combinations
    if all_negatives:
        for neg in all_negatives:
            negatives.append(f"{attr1}, {attr2}, {neg}")

    # FIX: list(set(...)) ordered the result nondeterministically across
    # runs (string-hash randomization), breaking reproducibility under a
    # fixed seed. dict.fromkeys dedupes while preserving insertion order.
    return list(dict.fromkeys(negatives))
118
+
119
+
120
def create_attribute_combinations(
    pair_attr1: List[str],
    pair_attr2: List[str],
    negatives: "List[str] | None" = None,
    weight: float = 1.0,
    auto_generate_negatives: bool = True
) -> List[AttributePair]:
    """Create AttributePairs for the cross product of two attribute lists.

    Args:
        pair_attr1: Candidate first attributes.
        pair_attr2: Candidate second attributes.
        negatives: Optional universal negatives mixed into every pair.
            (Annotation fixed: the default is None, so the type is optional.)
        weight: Training weight assigned to every generated pair.
        auto_generate_negatives: When True, derive color-swap negatives via
            generate_smart_negatives; otherwise only pair each universal
            negative with each attribute.

    Returns:
        One AttributePair per (attr1, attr2) combination.
    """
    pairs = []

    for attr1, attr2 in product(pair_attr1, pair_attr2):
        if auto_generate_negatives:
            neg_list = generate_smart_negatives(attr1, attr2, negatives)
        else:
            neg_list = []
            if negatives:
                for neg in negatives:
                    neg_list.append(f"{attr1}, {neg}")
                    neg_list.append(f"{neg}, {attr2}")

        pairs.append(AttributePair(
            attr1=attr1,
            attr2=attr2,
            negatives=tuple(neg_list),
            weight=weight
        ))

    return pairs
148
+
149
+
150
def combine_attribute_groups(*groups: List[AttributePair]) -> List[AttributePair]:
    """Flatten any number of attribute-pair groups into a single list."""
    # Equivalent to extending an accumulator list group by group.
    return [pair for group in groups for pair in group]
156
+
157
+
158
+ # ============================================================================
159
+ # LORA UTILITIES
160
+ # ============================================================================
161
+
162
  def get_target_modules(training_method: str) -> List[str]:
163
+ """Get layer names to inject LoRA"""
164
  attn1 = ["attn1.to_q", "attn1.to_k", "attn1.to_v", "attn1.to_out.0"]
165
  attn2 = ["attn2.to_q", "attn2.to_k", "attn2.to_v", "attn2.to_out.0"]
166
 
 
174
  return method_map.get(training_method, attn1 + attn2)
175
 
176
 
177
+ def create_lora_layers(unet: nn.Module, config: LECOConfig):
178
+ """Create LoRA layers"""
179
  target_modules = get_target_modules(config.training_method)
180
  lora_state = {}
181
  trainable_params = []
 
183
  def get_lora_key(module_path: str) -> str:
184
  return f"lora_unet_{module_path.replace('.', '_')}"
185
 
186
+ print(f"Creating LoRA layers (method: {config.training_method})...")
187
+
188
  for name, module in unet.named_modules():
189
  if not any(target in name for target in target_modules):
190
  continue
191
 
192
+ if not isinstance(module, nn.Linear):
193
  continue
194
 
195
  lora_key = get_lora_key(name)
 
197
  out_dim = module.out_features
198
  rank = config.lora_rank
199
 
200
+ lora_down = nn.Parameter(torch.zeros(rank, in_dim))
201
+ lora_up = nn.Parameter(torch.zeros(out_dim, rank))
202
 
203
+ nn.init.kaiming_uniform_(lora_down, a=1.0)
204
+ nn.init.zeros_(lora_up)
205
 
206
  lora_state[f"{lora_key}.lora_down.weight"] = lora_down
207
  lora_state[f"{lora_key}.lora_up.weight"] = lora_up
 
214
  return lora_state, trainable_params
215
 
216
 
217
+ def apply_lora_hooks(unet: nn.Module, lora_state: dict, scale: float = 1.0) -> list:
218
+ """Apply LoRA using forward hooks"""
219
  handles = []
220
 
221
  for key in lora_state:
 
250
  handle.remove()
251
 
252
 
253
+ # ============================================================================
254
+ # TRAINING LOSS WITH PROPER FLOW MATCHING
255
+ # ============================================================================
 
 
 
 
 
 
 
 
 
 
256
 
257
def compute_attribute_binding_loss_batched(
    unet,
    lora_state,
    positive_pairs: List[AttributePair],
    tokenizer,
    text_encoder,
    config: AttributeBindingConfig,
    device: str = "cuda"
):
    """Batched attribute binding with PROPER FLOW MATCHING.

    For one shared (sigma, noise) draw, compares the UNet's velocity
    predictions with and without the LoRA hooks applied:
      * positives ("attr1, attr2"): the LoRA delta is pulled toward 0.3x
        the (positive - neutral) direction;
      * negatives (wrong combinations): the LoRA delta is pulled toward
        0.2x the (neutral - negative) direction, i.e. away from them.
    Per-pair weights and a Min-SNR weight (computed from the flow-matching
    SNR) scale the losses. Negative loss enters the total at 0.5x.

    Args:
        unet: UNet called as unet(x, t, encoder_hidden_states=..., return_dict=False)
            (diffusers-style signature, judging by the call convention).
        lora_state: LoRA tensors consumed by apply_lora_hooks.
        positive_pairs: AttributePairs to reinforce this step.
        tokenizer: CLIP tokenizer (uses model_max_length padding).
        text_encoder: CLIP text encoder; [0] of its output is used.
        config: Hyperparameters (sigma range, shift, Min-SNR, resolution, ...).
        device: Device string for all tensors.

    Returns:
        (total_loss, metrics): scalar loss tensor plus a plain-float dict
        for logging (losses, counts, timestep/sigma, SNR weight, norms).
    """

    # 1. Sample sigma with constrained range (matching your training code)
    min_sigma = config.min_timestep / 1000.0
    max_sigma = config.max_timestep / 1000.0

    sigma = torch.rand(1, device=device)
    sigma = min_sigma + sigma * (max_sigma - min_sigma)  # Constrain to range

    # Apply shift transformation
    sigma = (config.shift * sigma) / (1 + (config.shift - 1) * sigma)
    timestep = sigma * 1000.0
    sigma_expanded = sigma.view(1, 1, 1, 1)

    # 2. Flow matching: x_t = sigma * noise + (1 - sigma) * x_0
    # For LECO: we use pure noise as x_0 (no clean latents available)
    noise = torch.randn(1, 4, config.resolution // 8, config.resolution // 8, device=device)
    noisy_input = sigma_expanded * noise  # Simplified since x_0 = 0 (centered)

    # Build prompts: one correct prompt per pair, plus sampled negatives
    positive_prompts = []
    negative_prompts = []
    pair_weights = []

    for pair in positive_pairs:
        correct = f"{pair.attr1}, {pair.attr2}"
        positive_prompts.append(correct)
        pair_weights.append(pair.weight)

        if pair.negatives:
            # Subsample negatives each step to bound the batch size
            sampled_negs = random.sample(
                list(pair.negatives),
                min(config.negatives_per_positive, len(pair.negatives))
            )
            negative_prompts.extend(sampled_negs)

    if not positive_prompts:
        # Degenerate call: nothing to train on; return a zero loss with
        # neutral metrics so the caller's logging keeps working.
        return torch.tensor(0.0, device=device), {
            "positive_loss": 0, "negative_loss": 0,
            "positive_count": 0, "negative_count": 0,
            "timestep": 0.0, "snr_weight": 1.0
        }

    # Encode neutral + positives + negatives in a single text-encoder pass
    neutral_prompt = ""
    all_prompts = [neutral_prompt] + positive_prompts + negative_prompts

    text_inputs = tokenizer(
        all_prompts,
        padding="max_length",
        max_length=tokenizer.model_max_length,
        truncation=True,
        return_tensors="pt"
    ).to(device)

    all_embeddings = text_encoder(text_inputs.input_ids)[0]

    neutral_emb = all_embeddings[0:1]
    positive_embs = all_embeddings[1:1+len(positive_prompts)]
    negative_embs = all_embeddings[1+len(positive_prompts):]

    # Repeat the single noised latent/timestep across the prompt batch
    batch_size = len(all_prompts) - 1
    noisy_input_batch = noisy_input.repeat(batch_size, 1, 1, 1)
    timestep_batch = timestep.repeat(batch_size)

    combined_embs = torch.cat([positive_embs, negative_embs], dim=0)

    # Get VELOCITY predictions (baselines are no-grad: only the LoRA pass
    # below carries gradients)
    with torch.no_grad():
        vel_neutral = unet(
            noisy_input, timestep_batch[0:1],
            encoder_hidden_states=neutral_emb,
            return_dict=False
        )[0]

        vel_baseline = unet(
            noisy_input_batch, timestep_batch,
            encoder_hidden_states=combined_embs,
            return_dict=False
        )[0]

    vel_positive_baseline = vel_baseline[:len(positive_prompts)]
    vel_negative_baseline = vel_baseline[len(positive_prompts):]

    handles = apply_lora_hooks(unet, lora_state, scale=1.0)

    try:
        vel_with_lora = unet(
            noisy_input_batch, timestep_batch,
            encoder_hidden_states=combined_embs,
            return_dict=False
        )[0]
    finally:
        # Always detach hooks so the shared UNet stays LoRA-free for the
        # next baseline pass
        remove_lora_hooks(handles)

    vel_positive_lora = vel_with_lora[:len(positive_prompts)]
    vel_negative_lora = vel_with_lora[len(positive_prompts):]

    # 3. Compute FLOW MATCHING SNR (not DDPM)
    snr_weight = 1.0
    if config.use_min_snr:
        # Flow matching SNR: ((1 - sigma)^2) / (sigma^2)
        sigma_sq = sigma.squeeze() ** 2
        snr = ((1 - sigma.squeeze()) ** 2) / (sigma_sq + 1e-8)  # eps guards sigma -> 0

        # Min-SNR clamping
        snr_clamped = torch.minimum(snr, torch.tensor(config.min_snr_gamma, device=device))
        snr_weight_tensor = snr_clamped / snr

        # Velocity prediction adjustment: divide by (SNR + 1)
        snr_weight_tensor = snr_weight_tensor / (snr + 1)

        snr_weight = snr_weight_tensor.item()
    else:
        snr_weight_tensor = torch.ones(1, device=device)

    # Compute losses: LoRA delta should be a scaled (0.3x) step from the
    # positive toward... actually from neutral toward the positive direction
    vel_neutral_expanded = vel_neutral.expand_as(vel_positive_baseline)
    target_positive_direction = vel_positive_baseline - vel_neutral_expanded
    lora_positive_delta = vel_positive_lora - vel_positive_baseline

    positive_loss_per_sample = F.mse_loss(
        lora_positive_delta,
        target_positive_direction * 0.3,
        reduction='none'
    ).mean(dim=(1,2,3))

    # Apply both pair weights and SNR weights
    pair_weights_tensor = torch.tensor(pair_weights, device=device)
    weighted_positive_loss = (positive_loss_per_sample * pair_weights_tensor * snr_weight_tensor).mean()

    negative_loss = torch.tensor(0.0, device=device)
    lora_negative_norm = 0.0

    if len(negative_prompts) > 0:
        # Negatives: push the LoRA delta toward neutral, away from the
        # wrong-combination baseline (0.2x scaled)
        vel_neutral_expanded_neg = vel_neutral.expand_as(vel_negative_baseline)
        target_negative_direction = vel_neutral_expanded_neg - vel_negative_baseline
        lora_negative_delta = vel_negative_lora - vel_negative_baseline

        negative_loss = F.mse_loss(lora_negative_delta, target_negative_direction * 0.2, reduction='mean')
        negative_loss = negative_loss * snr_weight_tensor
        lora_negative_norm = lora_negative_delta.norm().item()

    total_loss = weighted_positive_loss + negative_loss * 0.5

    metrics = {
        "positive_loss": weighted_positive_loss.item(),
        "negative_loss": negative_loss.item() if isinstance(negative_loss, torch.Tensor) else 0.0,
        "positive_count": len(positive_prompts),
        "negative_count": len(negative_prompts),
        "timestep": timestep.item(),
        "sigma": sigma.item(),
        "snr_weight": snr_weight,
        "lora_positive_norm": lora_positive_delta.norm().item(),
        "lora_negative_norm": lora_negative_norm
    }

    return total_loss, metrics
423
 
424
 
425
# ============================================================================
# TRAINING FUNCTION
# ============================================================================

def train_attribute_binding(config: AttributeBindingConfig):
    """Fast training for attribute binding with Min-SNR.

    Trains a LoRA on top of a frozen flow-matching UNet so that paired
    attributes (e.g. "red hair" + "blue shirt") bind correctly, optionally
    weighting the loss with Min-SNR (gamma from ``config.min_snr_gamma``).

    Args:
        config: AttributeBindingConfig holding the attribute pairs, LoRA
            hyper-parameters, flow-matching settings and output paths.

    Returns:
        The output directory path where checkpoints and logs were written.

    Raises:
        ValueError: if ``config.attribute_pairs`` is empty.
    """
    device = "cuda"  # training assumes a CUDA device is available
    torch.manual_seed(config.seed)

    if not config.attribute_pairs:
        raise ValueError("No attribute pairs specified!")

    # Count how many pairs carry their own negative prompts (informational only).
    pairs_with_negatives = sum(1 for p in config.attribute_pairs if p.negatives)
    print(f"Pairs with explicit negatives: {pairs_with_negatives}/{len(config.attribute_pairs)}")

    # Each run gets its own timestamped directory under config.output_dir.
    timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
    output_dir = os.path.join(config.output_dir, f"attribute_binding_{timestamp}")
    os.makedirs(output_dir, exist_ok=True)

    writer = SummaryWriter(log_dir=output_dir, flush_secs=60)
    # NOTE(review): the line(s) opening the config file are elided in this diff
    # view — presumably `with open(os.path.join(output_dir, "config.json"), "w") as f:`
    # immediately precedes the dump below; confirm against the full file.
        json.dump(asdict(config), f, indent=2)

    print("="*80)
    print("ATTRIBUTE BINDING TRAINING")
    if config.use_min_snr:
        print(f"Using Min-SNR Weighting (gamma={config.min_snr_gamma})")
    print("="*80)

    # VERIFY UNET LOADING
    # Sanity check: run the stock SD1.5 UNet on fixed random inputs so we can
    # later confirm the custom checkpoint actually changed the weights.
    print("\nVerifying UNet loading...")
    print("Loading base SD1.5 UNet for comparison...")
    unet_base = UNet2DConditionModel.from_pretrained(
        "runwayml/stable-diffusion-v1-5",
        subfolder="unet",
        torch_dtype=torch.float32
    ).to(device)

    # Create test inputs (SD1.5 latent shape 4x64x64, CLIP hidden size 768).
    test_latents = torch.randn(1, 4, 64, 64, device=device)
    test_timestep = torch.tensor([500], device=device)
    test_encoder = torch.randn(1, 77, 768, device=device)

    with torch.no_grad():
        baseline_out = unet_base(test_latents, test_timestep, encoder_hidden_states=test_encoder, return_dict=False)[0]

    print(f"Baseline output norm: {baseline_out.norm().item():.6f}")
    del unet_base  # free VRAM before loading the real model
    torch.cuda.empty_cache()

    print("\nLoading Lune flow-matching model...")
    checkpoint_path = hf_hub_download(
        repo_id=config.base_model_repo,
        filename=config.base_checkpoint,
    # NOTE(review): several lines are elided in this diff view (file lines
    # ~480-488) — presumably the close of the hf_hub_download call, a
    # `checkpoint = torch.load(checkpoint_path, ...)`, and construction of
    # `unet` (UNet2DConditionModel) before the state dict load below. Confirm
    # against the full file.

    # Checkpoint stores the student weights with a "unet." key prefix; strip it.
    student_dict = checkpoint["student"]
    cleaned_dict = {k[5:] if k.startswith("unet.") else k: v for k, v in student_dict.items()}
    missing, unexpected = unet.load_state_dict(cleaned_dict, strict=False)

    print(f"Missing keys: {len(missing)}, Unexpected keys: {len(unexpected)}")

    # The base UNet stays frozen; only the LoRA parameters train.
    unet = unet.to(device)
    unet.requires_grad_(False)
    unet.eval()

    # Verify Lune loaded correctly: same fixed inputs must now produce a
    # different output than the stock SD1.5 UNet did.
    with torch.no_grad():
        lune_out = unet(test_latents, test_timestep, encoder_hidden_states=test_encoder, return_dict=False)[0]

    print(f"Lune output norm: {lune_out.norm().item():.6f}")
    diff = (lune_out - baseline_out).abs().mean().item()
    print(f"Difference from baseline: {diff:.6f}")

    if diff < 1e-4:
        print("⚠️ WARNING: Outputs are nearly identical - checkpoint may not have loaded!")
    else:
        print("✓ Lune checkpoint loaded correctly (outputs differ)")

    print("\nLoading CLIP...")
    tokenizer = CLIPTokenizer.from_pretrained(
        "runwayml/stable-diffusion-v1-5", subfolder="tokenizer"
    )
    # NOTE(review): text_encoder construction is elided in this diff view
    # (file lines ~517-521) — presumably CLIPTextModel.from_pretrained(...)
    # moved to `device`; confirm against the full file.
    text_encoder.eval()
    print("✓ Loaded CLIP")

    print(f"\nCreating LoRA (rank={config.lora_rank})...")

    leco_config = LECOConfig(
        lora_rank=config.lora_rank,
        lora_alpha=config.lora_alpha,
        training_method=config.training_method
    )

    # lora_state maps layer keys to LoRA tensors/modules; trainable_params is
    # the flat list handed to the optimizer.
    lora_state, trainable_params = create_lora_layers(unet, leco_config)

    print(f"Moving LoRA parameters to {device}...")
    for param in trainable_params:
        param.data = param.data.to(device)

    # Plain tensors in lora_state also need moving; nn.Parameters were handled above.
    for key, value in lora_state.items():
        if isinstance(value, torch.Tensor) and not isinstance(value, nn.Parameter):
            lora_state[key] = value.to(device)

    optimizer = torch.optim.AdamW(trainable_params, lr=config.lr, weight_decay=0.01)

    print(f"\nTraining Configuration:")
    print(f"  Attribute pairs: {len(config.attribute_pairs)}")
    # Show at most three pairs to keep the banner short.
    for i, pair in enumerate(config.attribute_pairs[:3], 1):
        print(f"    {i}. {pair.attr1} + {pair.attr2} (weight: {pair.weight})")
        if pair.negatives:
            print(f"       Negatives: {len(pair.negatives)} total")
    if len(config.attribute_pairs) > 3:
        print(f"    ... and {len(config.attribute_pairs)-3} more")

    print(f"\n  Iterations: {config.iterations}")
    print(f"  Pairs per batch: {config.pairs_per_batch}")
    print(f"  Negatives per positive: {config.negatives_per_positive}")
    print(f"  Learning rate: {config.lr}")
    print("="*80 + "\n")

    progress = tqdm(range(config.iterations), desc="Training")

    for step in progress:
        # Random subset of pairs each step (without replacement within a step).
        sampled_pairs = random.sample(
            config.attribute_pairs,
            min(config.pairs_per_batch, len(config.attribute_pairs))
        )

        loss, metrics = compute_attribute_binding_loss_batched(
            unet, lora_state,
            sampled_pairs,
            tokenizer, text_encoder,
            config,
            device
        )

        loss.backward()
        grad_norm = torch.nn.utils.clip_grad_norm_(trainable_params, max_norm=1.0)
        optimizer.step()
        optimizer.zero_grad()

        # TensorBoard scalars for the main loss components.
        writer.add_scalar("loss/total", loss.item(), step)
        writer.add_scalar("loss/positive", metrics["positive_loss"], step)
        writer.add_scalar("loss/negative", metrics["negative_loss"], step)
        writer.add_scalar("grad_norm", grad_norm.item(), step)
        writer.add_scalar("snr_weight", metrics["snr_weight"], step)

        progress.set_postfix({
            "loss": f"{loss.item():.4f}",
            "pos": f"{metrics['positive_loss']:.3f}",
            "neg": f"{metrics['negative_loss']:.3f}",
            "snr": f"{metrics['snr_weight']:.2f}",
            "grad": f"{grad_norm.item():.3f}"
        })

        # Periodic checkpoint (and always on the final step).
        if (step + 1) % config.save_every == 0 or step == config.iterations - 1:
            # Collect LoRA tensors only; skip module handles stored under "._module".
            save_dict = {}
            for key, value in lora_state.items():
                if isinstance(value, torch.Tensor) and not key.endswith("._module"):
                    save_dict[key] = value.detach().cpu()

            # kohya/sd-scripts style safetensors metadata plus LECO extras.
            metadata = {
                "ss_network_module": "networks.lora",
                "ss_network_dim": str(config.lora_rank),
                "ss_network_alpha": str(config.lora_alpha),
                "ss_training_method": config.training_method,
                "leco_action": "attribute_binding",
                "leco_num_pairs": str(len(config.attribute_pairs)),
                "leco_step": str(step + 1),
                "leco_min_snr": str(config.use_min_snr),
                "leco_min_snr_gamma": str(config.min_snr_gamma)
            }

            filename = f"{config.name_prefix}_r{config.lora_rank}_s{step+1}.safetensors"
            filepath = os.path.join(output_dir, filename)

            save_file(save_dict, filepath, metadata=metadata)
            # NOTE(review): "(unknown)" looks like a redaction artifact in this
            # view — the message presumably interpolated `filepath`; confirm
            # against the full file.
            print(f"\n✓ Saved: (unknown)")

    writer.close()

    # NOTE(review): file lines ~621-625 are elided in this diff view before the
    # return (likely final summary prints); confirm against the full file.
    return output_dir
628
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
629
if __name__ == "__main__":
    # Example 1: Hair + Clothes colors
    # A single shared negative prompt string, applied to every generated pair.
    shared_negatives = ["ugly, duplicate, morbid, mutilated, blurry, fuzzy, out of frame, gross"]

    hair_attrs = ["red hair", "blue hair", "green hair"]
    shirt_attrs = ["red shirt", "blue shirt", "green shirt"]

    # Cross every hair colour with every shirt colour (3 x 3 = 9 pairs),
    # auto-generating the mismatched combinations as negatives.
    binding_pairs = create_attribute_combinations(
        pair_attr1=hair_attrs,
        pair_attr2=shirt_attrs,
        negatives=shared_negatives,
        weight=1.0,
        auto_generate_negatives=True,
    )

    print(f"Generated {len(binding_pairs)} hair+clothes pairs")

    # Training config
    run_config = AttributeBindingConfig(
        name_prefix="color_clothes_test",
        attribute_pairs=binding_pairs,
        iterations=5000,
        lora_rank=16,
        lr=2e-4,
        pairs_per_batch=4,
        negatives_per_positive=3,
        training_method="xattn",
        save_every=250,
        # Flow matching parameters
        shift=2.5,
        min_timestep=0.0,
        max_timestep=1000.0,
        # Min-SNR enabled
        use_min_snr=True,
        min_snr_gamma=5.0,
    )

    train_attribute_binding(run_config)