AbstractPhil committed on
Commit
3fea588
·
verified ·
1 Parent(s): f1b7957

Update flow_leco_trainer.py

Browse files
Files changed (1) hide show
  1. flow_leco_trainer.py +309 -199
flow_leco_trainer.py CHANGED
@@ -1,5 +1,5 @@
1
  """
2
- Lune LECO Trainer - Fixed
3
  """
4
 
5
  import os
@@ -22,36 +22,56 @@ from huggingface_hub import hf_hub_download
22
 
23
class ActionType(str, Enum):
    """Kinds of concept edit a LECO adapter can be trained for.

    Members also subclass ``str``, so each one compares equal to its raw
    string value.
    """

    ERASE = "erase"
    ENHANCE = "enhance"
    REPLACE = "replace"
    SUPPRESS = "suppress"
29
 
30
 
31
@dataclass
class ConceptPair:
    """
    One concept/anchor prompt pair used as a LECO training target.

    The LoRA learns: pred(concept) - pred(anchor)

    Examples:
        Erase:    ConceptPair("anime style", "")
        Enhance:  ConceptPair("masterpiece", "")
        Replace:  ConceptPair("van gogh", "monet")
        Suppress: ConceptPair("nsfw", "sfw")
    """
    # Prompt whose prediction is the starting point of the learned delta.
    concept: str
    # Prompt subtracted from the concept prediction; defaults to the empty prompt.
    anchor: str = ""
    # Multiplier for this pair's loss contribution.
    weight: float = 1.0
    # NOTE(review): presumably the suggested LoRA strength at inference time;
    # not read anywhere in the visible code — confirm.
    inference_weight: float = -1.0
48
-
49
-
50
@dataclass
class PreservationSet:
    """Prompts whose predictions should remain unchanged by the LoRA."""
    # Each instance gets its own list (no shared mutable default).
    prompts: List[str] = field(default_factory=list)
    # Multiplier for the preservation loss term.
    weight: float = 0.3
55
 
56
 
57
  @dataclass
@@ -61,28 +81,29 @@ class LECOConfig:
61
  base_model_repo: str = "AbstractPhil/sd15-flow-lune-flux"
62
  base_checkpoint: str = "sd15_flow_flux_t2_6_pose_t4_6_port_t1_4_s18765.pt"
63
 
64
- # HuggingFace
65
  hf_repo_id: str = "AbstractPhil/lune-leco-adapters"
66
  upload_to_hub: bool = False
67
 
68
  # Training data
69
  action: ActionType = ActionType.ERASE
70
- concept_pairs: List[ConceptPair] = field(default_factory=list)
71
- preservation: PreservationSet = field(default_factory=PreservationSet)
72
 
73
  # LoRA architecture
74
  lora_rank: int = 4
75
  lora_alpha: float = 1.0
76
  lora_dropout: float = 0.0
77
- training_method: Literal["full", "selfattn", "xattn", "noxattn", "innoxattn"] = "full"
78
 
79
- # Training
80
  seed: int = 42
81
  iterations: int = 1000
82
  lr: float = 1e-4
83
- pairs_per_step: int = 1
84
 
85
- # Flow-matching
 
 
 
86
  shift: float = 2.5
87
  min_timestep: float = 0.0
88
  max_timestep: float = 1000.0
@@ -92,7 +113,7 @@ class LECOConfig:
92
 
93
 
94
  def get_target_modules(training_method: str) -> List[str]:
95
- """Get layer names for LoRA injection"""
96
  attn1 = ["attn1.to_q", "attn1.to_k", "attn1.to_v", "attn1.to_out.0"]
97
  attn2 = ["attn2.to_q", "attn2.to_k", "attn2.to_v", "attn2.to_out.0"]
98
 
@@ -107,7 +128,7 @@ def get_target_modules(training_method: str) -> List[str]:
107
 
108
 
109
  def create_lora_layers(unet: torch.nn.Module, config: LECOConfig):
110
- """Create LoRA layers in ComfyUI/A1111 format"""
111
  target_modules = get_target_modules(config.training_method)
112
  lora_state = {}
113
  trainable_params = []
@@ -115,9 +136,6 @@ def create_lora_layers(unet: torch.nn.Module, config: LECOConfig):
115
  def get_lora_key(module_path: str) -> str:
116
  return f"lora_unet_{module_path.replace('.', '_')}"
117
 
118
- print(f"Creating LoRA layers (method: {config.training_method})...")
119
- layer_count = 0
120
-
121
  for name, module in unet.named_modules():
122
  if not any(target in name for target in target_modules):
123
  continue
@@ -130,9 +148,6 @@ def create_lora_layers(unet: torch.nn.Module, config: LECOConfig):
130
  out_dim = module.out_features
131
  rank = config.lora_rank
132
 
133
- # LoRA matrices
134
- # down: [rank, in_features]
135
- # up: [out_features, rank]
136
  lora_down = torch.nn.Parameter(torch.zeros(rank, in_dim))
137
  lora_up = torch.nn.Parameter(torch.zeros(out_dim, rank))
138
 
@@ -145,21 +160,13 @@ def create_lora_layers(unet: torch.nn.Module, config: LECOConfig):
145
  lora_state[f"{lora_key}._module"] = module
146
 
147
  trainable_params.extend([lora_down, lora_up])
148
- layer_count += 1
149
 
150
- print(f"✓ Created {layer_count} LoRA layers ({len(trainable_params)} parameters)")
151
  return lora_state, trainable_params
152
 
153
 
154
  def apply_lora_hooks(unet: torch.nn.Module, lora_state: dict, scale: float = 1.0) -> list:
155
- """
156
- Apply LoRA using forward hooks.
157
-
158
- LoRA computation: out = out + scale * (x @ down.T @ up.T)
159
- Using F.linear: F.linear(x, W) computes x @ W.T
160
-
161
- So: F.linear(F.linear(x, down), up) gives x @ down.T @ up.T ✓
162
- """
163
  handles = []
164
 
165
  for key in lora_state:
@@ -178,9 +185,6 @@ def apply_lora_hooks(unet: torch.nn.Module, lora_state: dict, scale: float = 1.0
178
  def make_hook(down, up, s):
179
  def forward_hook(mod, inp, out):
180
  x = inp[0]
181
- # F.linear handles transpose internally
182
- # down is [rank, in_features], F.linear does x @ down.T
183
- # up is [out_features, rank], F.linear does result @ up.T
184
  lora_out = F.linear(F.linear(x, down), up)
185
  return out + lora_out * s
186
  return forward_hook
@@ -211,145 +215,182 @@ def encode_text(prompt: str, tokenizer, text_encoder, device) -> torch.Tensor:
211
  return text_encoder(tokens)[0]
212
 
213
 
214
def compute_leco_loss(
    unet: torch.nn.Module,
    lora_state: dict,
    pair: ConceptPair,
    tokenizer,
    text_encoder,
    config: LECOConfig,
    device: str = "cuda"
):
    """
    LECO loss for a single concept pair.

    The target is the frozen-model difference pred(concept) - pred(anchor);
    the loss is the MSE between that target and the change the LoRA hooks
    introduce on the concept prediction.

    Returns:
        (loss, stats) where stats holds "timestep", "sigma", "target_norm"
        and "lora_norm" as Python floats.
    """
    # Draw one sigma uniformly in the configured range, then warp it by `shift`.
    lo = config.min_timestep / 1000.0
    hi = config.max_timestep / 1000.0
    raw = lo + torch.rand(1, device=device) * (hi - lo)
    shifted = (config.shift * raw) / (1 + (config.shift - 1) * raw)
    timestep = shifted * 1000.0
    sigma = shifted.view(1, 1, 1, 1)

    # Sigma-scaled Gaussian noise at 1/8 of the configured resolution.
    latent_side = config.resolution // 8
    latents = sigma * torch.randn(1, 4, latent_side, latent_side, device=device)

    def predict(embedding):
        return unet(
            latents, timestep,
            encoder_hidden_states=embedding,
            return_dict=False
        )[0]

    concept_emb = encode_text(pair.concept, tokenizer, text_encoder, device)
    anchor_emb = encode_text(pair.anchor, tokenizer, text_encoder, device)

    # Reference predictions are taken before any LoRA hook is attached.
    with torch.no_grad():
        base_pred = predict(concept_emb)
        anchor_pred = predict(anchor_emb)
        target_delta = base_pred - anchor_pred

    hooks = apply_lora_hooks(unet, lora_state, scale=1.0)
    try:
        adapted_pred = predict(concept_emb)
        lora_delta = adapted_pred - base_pred
        loss = F.mse_loss(lora_delta, target_delta)
    finally:
        # Hooks are always detached, even if the forward pass fails.
        remove_lora_hooks(hooks)

    stats = {
        "timestep": timestep.item(),
        "sigma": sigma.item(),
        "target_norm": target_delta.norm().item(),
        "lora_norm": lora_delta.norm().item()
    }
    return loss, stats
282
-
283
-
284
def compute_preservation_loss(
    unet: torch.nn.Module,
    lora_state: dict,
    preservation: PreservationSet,
    tokenizer,
    text_encoder,
    config: LECOConfig,
    device: str = "cuda"
):
    """
    Mean MSE between hooked and un-hooked predictions on preservation prompts.

    All prompts share one sampled sigma/timestep; each prompt gets fresh noise.

    Returns:
        (0.0, {}) when there are no prompts, otherwise (mean_loss, stats)
        with stats keys "count" and "avg".
    """
    prompts = preservation.prompts
    if not prompts:
        return 0.0, {}

    lo = config.min_timestep / 1000.0
    hi = config.max_timestep / 1000.0
    raw = lo + torch.rand(1, device=device) * (hi - lo)
    shifted = (config.shift * raw) / (1 + (config.shift - 1) * raw)
    timestep = shifted * 1000.0
    sigma = shifted.view(1, 1, 1, 1)

    latent_side = config.resolution // 8

    def predict(latents, embedding):
        return unet(
            latents, timestep,
            encoder_hidden_states=embedding,
            return_dict=False
        )[0]

    per_prompt = []
    for text in prompts:
        latents = sigma * torch.randn(1, 4, latent_side, latent_side, device=device)
        embedding = encode_text(text, tokenizer, text_encoder, device)

        # Reference prediction without LoRA hooks.
        with torch.no_grad():
            reference = predict(latents, embedding)

        hooks = apply_lora_hooks(unet, lora_state, scale=1.0)
        try:
            adapted = predict(latents, embedding)
        finally:
            remove_lora_hooks(hooks)

        per_prompt.append(F.mse_loss(adapted, reference))

    # sum() starts from 0 and adds in order, same as a running total.
    mean_loss = sum(per_prompt) / len(prompts)
    return mean_loss, {"count": len(prompts), "avg": mean_loss.item()}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
333
 
334
 
335
  def train_leco(config: LECOConfig):
336
- """Main training loop"""
337
  device = "cuda"
338
  torch.manual_seed(config.seed)
339
 
340
- if not config.concept_pairs:
341
- raise ValueError("No concept pairs specified!")
 
 
 
 
 
342
 
343
  # Setup output
344
  timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
345
- concept_names = "_".join([
346
- p.concept.replace(" ", "")[:12]
347
- for p in config.concept_pairs[:2]
348
- ])
349
- if len(config.concept_pairs) > 2:
350
- concept_names += f"_plus{len(config.concept_pairs)-2}"
351
-
352
- run_name = f"{config.action.value}_{concept_names}_{timestamp}"
353
  output_dir = os.path.join(config.output_dir, run_name)
354
  os.makedirs(output_dir, exist_ok=True)
355
 
@@ -398,30 +439,39 @@ def train_leco(config: LECOConfig):
398
  text_encoder.eval()
399
  print("✓ Loaded CLIP")
400
 
401
- # Create LoRA
402
  print(f"\nInjecting LoRA (rank={config.lora_rank}, alpha={config.lora_alpha})...")
403
  lora_state, trainable_params = create_lora_layers(unet, config)
404
 
405
- for key in lora_state:
406
- if isinstance(lora_state[key], torch.Tensor):
407
- lora_state[key] = lora_state[key].to(device)
 
 
 
 
 
 
408
 
409
  optimizer = torch.optim.AdamW(trainable_params, lr=config.lr, weight_decay=0.01)
410
 
411
  # Print config
412
  print(f"\nTraining Configuration:")
413
  print(f" Action: {config.action.value}")
414
- print(f" Concept pairs: {len(config.concept_pairs)}")
415
- for i, pair in enumerate(config.concept_pairs, 1):
416
- anchor_str = f" '{pair.anchor}'" if pair.anchor else "(none)"
417
- print(f" {i}. '{pair.concept}' {anchor_str} (weight: {pair.weight})")
418
-
419
- if config.preservation.prompts:
420
- print(f" Preservation: {len(config.preservation.prompts)} prompts")
 
 
421
 
422
  print(f"\n Iterations: {config.iterations}")
423
  print(f" Learning rate: {config.lr}")
424
  print(f" Training method: {config.training_method}")
 
425
  print("="*80 + "\n")
426
 
427
  # Training loop
@@ -429,50 +479,37 @@ def train_leco(config: LECOConfig):
429
 
430
  for step in progress:
431
  import random
432
- if config.pairs_per_step >= len(config.concept_pairs):
433
- active_pairs = config.concept_pairs
434
- else:
435
- active_pairs = random.sample(config.concept_pairs, config.pairs_per_step)
436
-
437
- total_loss = 0
438
- all_metrics = []
439
 
440
- for pair in active_pairs:
441
- loss, metrics = compute_leco_loss(
442
- unet, lora_state, pair,
443
- tokenizer, text_encoder, config, device
444
- )
445
- total_loss += loss * pair.weight
446
- all_metrics.append(metrics)
447
 
448
- if config.preservation.prompts:
449
- pres_loss, pres_metrics = compute_preservation_loss(
450
- unet, lora_state, config.preservation,
451
- tokenizer, text_encoder, config, device
452
- )
453
- total_loss += pres_loss * config.preservation.weight
454
- else:
455
- pres_loss = 0
456
 
457
- total_loss.backward()
 
458
  grad_norm = torch.nn.utils.clip_grad_norm_(trainable_params, max_norm=1.0)
459
  optimizer.step()
460
  optimizer.zero_grad()
461
 
462
  # Logging
463
- writer.add_scalar("loss/total", total_loss.item(), step)
464
- writer.add_scalar("loss/preservation", pres_loss if isinstance(pres_loss, (float, int)) else pres_loss.item(), step)
 
465
  writer.add_scalar("grad_norm", grad_norm.item(), step)
466
 
467
- avg_target = sum(m["target_norm"] for m in all_metrics) / len(all_metrics)
468
  progress.set_postfix({
469
- "loss": f"{total_loss.item():.4f}",
470
- "grad": f"{grad_norm.item():.3f}",
471
- "target": f"{avg_target:.3f}"
 
472
  })
473
 
474
  if (step + 1) % 200 == 0 or step == config.iterations - 1:
475
- save_checkpoint(lora_state, config, output_dir, step + 1, concept_names)
476
 
477
  writer.close()
478
 
@@ -492,8 +529,15 @@ def save_checkpoint(lora_state, config, output_dir, step, name_suffix):
492
  if isinstance(value, torch.Tensor) and not key.endswith("._module"):
493
  save_dict[key] = value.detach().cpu()
494
 
495
- concepts_str = ", ".join([p.concept for p in config.concept_pairs])
496
- anchors_str = ", ".join([p.anchor for p in config.concept_pairs if p.anchor])
 
 
 
 
 
 
 
497
 
498
  metadata = {
499
  "ss_network_module": "networks.lora",
@@ -502,9 +546,11 @@ def save_checkpoint(lora_state, config, output_dir, step, name_suffix):
502
  "ss_base_model": "runwayml/stable-diffusion-v1-5",
503
  "ss_training_method": config.training_method,
504
  "leco_action": config.action.value,
505
- "leco_concepts": concepts_str,
506
- "leco_anchors": anchors_str,
507
- "leco_step": str(step)
 
 
508
  }
509
 
510
  filename = f"leco_{name_suffix}_r{config.lora_rank}_s{step}.safetensors"
@@ -514,17 +560,81 @@ def save_checkpoint(lora_state, config, output_dir, step, name_suffix):
514
  print(f"\n✓ Saved: {filename}")
515
 
516
 
 
 
 
 
517
if __name__ == "__main__":
    # Script entry point: train an ENHANCE adapter on three prompt terms,
    # each paired with the empty anchor prompt.
    config = LECOConfig(
        training_method="selfattn",
        lora_rank=4,
        iterations=600,
        action=ActionType.ENHANCE,
        concept_pairs=[
            ConceptPair("masterpiece", "", weight=1.0),
            ConceptPair("best quality", "", weight=1.0),
            ConceptPair("highly detailed", "", weight=0.8),
        ],
    )

    train_leco(config)
 
 
1
  """
2
+ Lune LECO Trainer - Proper Concept Group Implementation
3
  """
4
 
5
  import os
 
22
 
23
class ActionType(str, Enum):
    """Kinds of concept edit a LECO adapter can be trained for.

    Members also subclass ``str``, so each one compares equal to its raw
    string value.
    """

    # sources → empty
    ERASE = "erase"
    # sources → amplified
    ENHANCE = "enhance"
    # sources → target
    REPLACE = "replace"
    # sources → neutral
    NEUTRALIZE = "neutralize"
29
 
30
 
31
@dataclass
class ConceptGroup:
    """
    A group of related concepts to transform together.

    Training strategy:
    - Sample from sources: these are the concepts to modify
    - Transform to target: what they should become
    - Use neutral as intermediate: optional neutral reference point
    - Preserve negatives: concepts that should NOT be affected

    Examples:
        # Erase multiple anime styles
        ConceptGroup(
            sources=["anime", "manga", "cartoon"],
            target="",
            negatives=["realistic", "photograph"],
            weight=1.0
        )

        # Replace artists
        ConceptGroup(
            sources=["van gogh", "picasso"],
            target="monet",
            neutral="painting",
            negatives=["photograph", "digital art"],
            weight=1.0
        )

        # Neutralize NSFW to safe
        ConceptGroup(
            sources=["nsfw", "nude", "explicit"],
            target="safe",
            neutral="person",
            negatives=["portrait", "art", "figure drawing"],
            weight=2.0
        )

    Raises:
        TypeError: if ``sources`` or ``negatives`` is a single string rather
            than a list of prompts.
    """
    sources: List[str]  # Concepts to modify (sampled during training)
    target: str = ""  # What to transform to (empty = erase)
    neutral: str = ""  # Optional neutral reference point
    negatives: List[str] = field(default_factory=list)  # Concepts to preserve
    weight: float = 1.0  # Group importance
    preservation_weight: float = 0.5  # How strongly to preserve negatives

    def __post_init__(self):
        # A bare string is itself a sequence, so it would pass the emptiness
        # checks and then be sampled/iterated one character at a time.
        # Reject it up front instead of training on single letters.
        for field_name in ("sources", "negatives"):
            if isinstance(getattr(self, field_name), str):
                raise TypeError(
                    f"ConceptGroup.{field_name} must be a list of prompts, "
                    f"not a single string"
                )
 
 
 
 
 
75
 
76
 
77
  @dataclass
 
81
  base_model_repo: str = "AbstractPhil/sd15-flow-lune-flux"
82
  base_checkpoint: str = "sd15_flow_flux_t2_6_pose_t4_6_port_t1_4_s18765.pt"
83
 
84
+ # HuggingFace upload
85
  hf_repo_id: str = "AbstractPhil/lune-leco-adapters"
86
  upload_to_hub: bool = False
87
 
88
  # Training data
89
  action: ActionType = ActionType.ERASE
90
+ concept_groups: List[ConceptGroup] = field(default_factory=list)
 
91
 
92
  # LoRA architecture
93
  lora_rank: int = 4
94
  lora_alpha: float = 1.0
95
  lora_dropout: float = 0.0
96
+ training_method: Literal["full", "selfattn", "xattn", "noxattn", "innoxattn"] = "xattn"
97
 
98
+ # Training hyperparameters
99
  seed: int = 42
100
  iterations: int = 1000
101
  lr: float = 1e-4
 
102
 
103
+ # Sampling strategy
104
+ sources_per_step: int = 2 # How many source concepts to sample per step
105
+
106
+ # Flow-matching parameters
107
  shift: float = 2.5
108
  min_timestep: float = 0.0
109
  max_timestep: float = 1000.0
 
113
 
114
 
115
  def get_target_modules(training_method: str) -> List[str]:
116
+ """Get layer names to inject LoRA based on training method."""
117
  attn1 = ["attn1.to_q", "attn1.to_k", "attn1.to_v", "attn1.to_out.0"]
118
  attn2 = ["attn2.to_q", "attn2.to_k", "attn2.to_v", "attn2.to_out.0"]
119
 
 
128
 
129
 
130
  def create_lora_layers(unet: torch.nn.Module, config: LECOConfig):
131
+ """Create LoRA layers in ComfyUI/A1111 compatible format."""
132
  target_modules = get_target_modules(config.training_method)
133
  lora_state = {}
134
  trainable_params = []
 
136
  def get_lora_key(module_path: str) -> str:
137
  return f"lora_unet_{module_path.replace('.', '_')}"
138
 
 
 
 
139
  for name, module in unet.named_modules():
140
  if not any(target in name for target in target_modules):
141
  continue
 
148
  out_dim = module.out_features
149
  rank = config.lora_rank
150
 
 
 
 
151
  lora_down = torch.nn.Parameter(torch.zeros(rank, in_dim))
152
  lora_up = torch.nn.Parameter(torch.zeros(out_dim, rank))
153
 
 
160
  lora_state[f"{lora_key}._module"] = module
161
 
162
  trainable_params.extend([lora_down, lora_up])
 
163
 
164
+ print(f"✓ Created {len(trainable_params)//2} LoRA layers ({len(trainable_params)} parameters)")
165
  return lora_state, trainable_params
166
 
167
 
168
  def apply_lora_hooks(unet: torch.nn.Module, lora_state: dict, scale: float = 1.0) -> list:
169
+ """Apply LoRA using forward hooks."""
 
 
 
 
 
 
 
170
  handles = []
171
 
172
  for key in lora_state:
 
185
  def make_hook(down, up, s):
186
  def forward_hook(mod, inp, out):
187
  x = inp[0]
 
 
 
188
  lora_out = F.linear(F.linear(x, down), up)
189
  return out + lora_out * s
190
  return forward_hook
 
215
  return text_encoder(tokens)[0]
216
 
217
 
218
def _sample_shifted_sigma(config, device):
    """Draw one sigma uniformly in the configured range and apply the shift warp."""
    min_sigma = config.min_timestep / 1000.0
    max_sigma = config.max_timestep / 1000.0
    sigma = min_sigma + torch.rand(1, device=device) * (max_sigma - min_sigma)
    return (config.shift * sigma) / (1 + (config.shift - 1) * sigma)


def _predict(unet, noisy_input, timestep, embedding):
    """Run one UNet forward pass and return its first output."""
    return unet(
        noisy_input, timestep,
        encoder_hidden_states=embedding,
        return_dict=False
    )[0]


def _predict_with_lora(unet, lora_state, noisy_input, timestep, embedding):
    """Run one UNet forward pass with the LoRA hooks attached, always detaching them."""
    handles = apply_lora_hooks(unet, lora_state, scale=1.0)
    try:
        return _predict(unet, noisy_input, timestep, embedding)
    finally:
        remove_lora_hooks(handles)


def compute_concept_group_loss(
    unet: torch.nn.Module,
    lora_state: dict,
    group: ConceptGroup,
    tokenizer,
    text_encoder,
    config: LECOConfig,
    device: str = "cuda"
):
    """
    Compute LECO loss for a concept group.

    Strategy:
    1. Sample up to ``config.sources_per_step`` prompts from ``group.sources``.
    2. For each, train the LoRA so that its change to the source prediction
       matches pred(target) - pred(source) of the un-hooked model.
    3. For each prompt in ``group.negatives``, penalize any difference
       between the hooked and un-hooked predictions.

    One sigma/timestep is shared by the whole call; every prompt gets its own
    noise sample.

    Returns:
        (total_loss, metrics). ``total_loss`` is always a tensor: the sum of
        source losses scaled by ``group.weight`` plus preservation losses
        scaled by ``group.preservation_weight``. ``metrics`` holds the
        per-prompt averages "source_loss" and "preservation_loss", the
        counts "sources_processed" and "negatives_processed", and
        "timestep" / "sigma" as floats.
    """
    import random

    # Sample source concepts for this step
    num_sources = min(config.sources_per_step, len(group.sources))
    sampled_sources = random.sample(group.sources, num_sources)

    # Sample timestep (shared for this group)
    sigma = _sample_shifted_sigma(config, device)
    timestep = sigma * 1000.0
    sigma_expanded = sigma.view(1, 1, 1, 1)

    latent_size = config.resolution // 8

    # Start from a tensor so callers can always use .item() on the result,
    # even when no prompt was processed.
    total_loss = torch.zeros((), device=device)
    metrics = {
        "source_loss": 0,
        "preservation_loss": 0,
        "sources_processed": 0,
        "negatives_processed": 0
    }

    # The target prompt is the same for every source: encode it once.
    target_emb = encode_text(group.target, tokenizer, text_encoder, device)

    # NOTE(review): group.neutral is intentionally not evaluated here. The
    # previous two-step delta, (neutral - source) + (target - neutral),
    # cancels the neutral prediction, leaving target - source, so the extra
    # text encode and UNet pass only cost time. Giving `neutral` a real
    # effect needs a different target formula — a design decision for the
    # author.

    # === SOURCE TRANSFORMATION LOSS ===
    for source_concept in sampled_sources:
        noise = torch.randn(1, 4, latent_size, latent_size, device=device)
        noisy_input = sigma_expanded * noise

        source_emb = encode_text(source_concept, tokenizer, text_encoder, device)

        # Target direction is measured on the model WITHOUT LoRA hooks.
        with torch.no_grad():
            pred_source = _predict(unet, noisy_input, timestep, source_emb)
            pred_target = _predict(unet, noisy_input, timestep, target_emb)
            target_delta = pred_target - pred_source

        pred_with_lora = _predict_with_lora(
            unet, lora_state, noisy_input, timestep, source_emb
        )

        # The LoRA's contribution should reproduce the transformation.
        lora_delta = pred_with_lora - pred_source
        source_loss = F.mse_loss(lora_delta, target_delta)
        total_loss = total_loss + source_loss * group.weight
        metrics["source_loss"] += source_loss.item()
        metrics["sources_processed"] += 1

    # === PRESERVATION LOSS (negatives should remain unchanged) ===
    for negative_concept in group.negatives:
        noise = torch.randn(1, 4, latent_size, latent_size, device=device)
        noisy_input = sigma_expanded * noise

        negative_emb = encode_text(negative_concept, tokenizer, text_encoder, device)

        # Baseline without LoRA
        with torch.no_grad():
            pred_negative = _predict(unet, noisy_input, timestep, negative_emb)

        pred_with_lora = _predict_with_lora(
            unet, lora_state, noisy_input, timestep, negative_emb
        )

        # Penalize any change
        preservation_loss = F.mse_loss(pred_with_lora, pred_negative)
        total_loss = total_loss + preservation_loss * group.preservation_weight
        metrics["preservation_loss"] += preservation_loss.item()
        metrics["negatives_processed"] += 1

    # Turn the accumulated sums into per-prompt averages.
    if metrics["sources_processed"] > 0:
        metrics["source_loss"] /= metrics["sources_processed"]
    if metrics["negatives_processed"] > 0:
        metrics["preservation_loss"] /= metrics["negatives_processed"]

    metrics["timestep"] = timestep.item()
    metrics["sigma"] = sigma.item()

    return total_loss, metrics
369
 
370
 
371
  def train_leco(config: LECOConfig):
372
+ """Main training loop with proper concept groups"""
373
  device = "cuda"
374
  torch.manual_seed(config.seed)
375
 
376
+ if not config.concept_groups:
377
+ raise ValueError("No concept groups specified!")
378
+
379
+ # Validate concept groups
380
+ for group in config.concept_groups:
381
+ if not group.sources:
382
+ raise ValueError("Each concept group must have at least one source concept")
383
 
384
  # Setup output
385
  timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
386
+
387
+ # Create name from first group
388
+ first_group = config.concept_groups[0]
389
+ source_names = "_".join([s.replace(" ", "")[:10] for s in first_group.sources[:2]])
390
+ if len(first_group.sources) > 2:
391
+ source_names += f"_plus{len(first_group.sources)-2}"
392
+
393
+ run_name = f"{config.action.value}_{source_names}_{timestamp}"
394
  output_dir = os.path.join(config.output_dir, run_name)
395
  os.makedirs(output_dir, exist_ok=True)
396
 
 
439
  text_encoder.eval()
440
  print("✓ Loaded CLIP")
441
 
442
+ # Create LoRA layers
443
  print(f"\nInjecting LoRA (rank={config.lora_rank}, alpha={config.lora_alpha})...")
444
  lora_state, trainable_params = create_lora_layers(unet, config)
445
 
446
+ # Move Parameters to device IN-PLACE
447
+ print(f"Moving LoRA parameters to {device}...")
448
+ for param in trainable_params:
449
+ param.data = param.data.to(device)
450
+
451
+ # Move other tensors to device
452
+ for key, value in lora_state.items():
453
+ if isinstance(value, torch.Tensor) and not isinstance(value, torch.nn.Parameter):
454
+ lora_state[key] = value.to(device)
455
 
456
  optimizer = torch.optim.AdamW(trainable_params, lr=config.lr, weight_decay=0.01)
457
 
458
  # Print config
459
  print(f"\nTraining Configuration:")
460
  print(f" Action: {config.action.value}")
461
+ print(f" Concept groups: {len(config.concept_groups)}")
462
+ for i, group in enumerate(config.concept_groups, 1):
463
+ print(f"\n Group {i} (weight: {group.weight}):")
464
+ print(f" Sources: {', '.join(group.sources)}")
465
+ print(f" Target: '{group.target}'" if group.target else " Target: (erase)")
466
+ if group.neutral:
467
+ print(f" Neutral: '{group.neutral}'")
468
+ if group.negatives:
469
+ print(f" Preserve: {', '.join(group.negatives)}")
470
 
471
  print(f"\n Iterations: {config.iterations}")
472
  print(f" Learning rate: {config.lr}")
473
  print(f" Training method: {config.training_method}")
474
+ print(f" Sources per step: {config.sources_per_step}")
475
  print("="*80 + "\n")
476
 
477
  # Training loop
 
479
 
480
  for step in progress:
481
  import random
 
 
 
 
 
 
 
482
 
483
+ # Sample a concept group
484
+ group = random.choice(config.concept_groups)
 
 
 
 
 
485
 
486
+ # Compute loss for this group
487
+ loss, metrics = compute_concept_group_loss(
488
+ unet, lora_state, group,
489
+ tokenizer, text_encoder, config, device
490
+ )
 
 
 
491
 
492
+ # Backprop
493
+ loss.backward()
494
  grad_norm = torch.nn.utils.clip_grad_norm_(trainable_params, max_norm=1.0)
495
  optimizer.step()
496
  optimizer.zero_grad()
497
 
498
  # Logging
499
+ writer.add_scalar("loss/total", loss.item(), step)
500
+ writer.add_scalar("loss/source", metrics["source_loss"], step)
501
+ writer.add_scalar("loss/preservation", metrics["preservation_loss"], step)
502
  writer.add_scalar("grad_norm", grad_norm.item(), step)
503
 
 
504
  progress.set_postfix({
505
+ "loss": f"{loss.item():.4f}",
506
+ "src": f"{metrics['source_loss']:.4f}",
507
+ "pres": f"{metrics['preservation_loss']:.4f}",
508
+ "grad": f"{grad_norm.item():.3f}"
509
  })
510
 
511
  if (step + 1) % 200 == 0 or step == config.iterations - 1:
512
+ save_checkpoint(lora_state, config, output_dir, step + 1, source_names)
513
 
514
  writer.close()
515
 
 
529
  if isinstance(value, torch.Tensor) and not key.endswith("._module"):
530
  save_dict[key] = value.detach().cpu()
531
 
532
+ # Build metadata
533
+ all_sources = []
534
+ all_targets = []
535
+ all_negatives = []
536
+ for group in config.concept_groups:
537
+ all_sources.extend(group.sources)
538
+ if group.target:
539
+ all_targets.append(group.target)
540
+ all_negatives.extend(group.negatives)
541
 
542
  metadata = {
543
  "ss_network_module": "networks.lora",
 
546
  "ss_base_model": "runwayml/stable-diffusion-v1-5",
547
  "ss_training_method": config.training_method,
548
  "leco_action": config.action.value,
549
+ "leco_sources": ", ".join(all_sources),
550
+ "leco_targets": ", ".join(all_targets) if all_targets else "",
551
+ "leco_negatives": ", ".join(all_negatives),
552
+ "leco_step": str(step),
553
+ "leco_num_groups": str(len(config.concept_groups))
554
  }
555
 
556
  filename = f"leco_{name_suffix}_r{config.lora_rank}_s{step}.safetensors"
 
560
  print(f"\n✓ Saved: {filename}")
561
 
562
 
563
+ # ============================================================================
564
+ # EXAMPLE CONFIGURATIONS
565
+ # ============================================================================
566
+
567
  if __name__ == "__main__":
568
+
569
+ # Example 1: Erase anime styles (multiple sources → empty)
570
+ config_erase_anime = LECOConfig(
571
+ action=ActionType.ERASE,
572
+ concept_groups=[
573
+ ConceptGroup(
574
+ sources=["anime", "manga", "cartoon"],
575
+ target="", # Erase
576
+ negatives=["realistic", "photograph", "painting"],
577
+ weight=1.0
578
+ )
579
+ ],
580
+ iterations=1000,
581
+ lora_rank=4,
582
+ training_method="xattn" # Cross-attention for semantic content
583
+ )
584
+
585
+ # Example 2: Replace artists (multiple sources → single target)
586
+ config_replace_artists = LECOConfig(
587
+ action=ActionType.REPLACE,
588
+ concept_groups=[
589
+ ConceptGroup(
590
+ sources=["van gogh", "picasso", "dali"],
591
+ target="monet",
592
+ neutral="painting", # Use painting as neutral reference
593
+ negatives=["photograph", "digital art"],
594
+ weight=1.0
595
+ )
596
+ ],
597
+ iterations=800,
598
+ lora_rank=8,
599
+ training_method="xattn"
600
+ )
601
+
602
+ # Example 3: Neutralize NSFW (multiple sources → safe target)
603
+ config_nsfw = LECOConfig(
604
+ action=ActionType.NEUTRALIZE,
605
+ concept_groups=[
606
+ ConceptGroup(
607
+ sources=["nsfw", "nude", "explicit", "naked"],
608
+ target="clothed",
609
+ neutral="person",
610
+ negatives=["portrait", "figure drawing", "classical art", "sculpture"],
611
+ weight=2.0,
612
+ preservation_weight=0.8 # Strong preservation
613
+ )
614
+ ],
615
+ iterations=1200,
616
+ lora_rank=4,
617
+ training_method="full"
618
+ )
619
+
620
+ # Example 4: Your original request - weird food combos
621
+ config_food = LECOConfig(
622
+ action=ActionType.ERASE,
623
+ concept_groups=[
624
+ ConceptGroup(
625
+ sources=["potato chicken sandwich", "taco pizza", "banana sushi"],
626
+ target="",
627
+ neutral="food",
628
+ negatives=["normal sandwiches", "table", "walls", "plates", "restaurant"],
629
+ weight=1.0,
630
+ preservation_weight=1.5
631
+ )
632
  ],
633
+ iterations=1000,
634
  lora_rank=4,
635
+ training_method="xattn",
636
+ sources_per_step=2 # Sample 2 weird foods per training step
637
  )
638
 
639
+ # Train
640
+ train_leco(config_erase_anime)