AbstractPhil committed on
Commit
f1b7957
·
verified ·
1 Parent(s): 9ff340a

Create flow_leco_trainer.py

Browse files
Files changed (1) hide show
  1. flow_leco_trainer.py +530 -0
flow_leco_trainer.py ADDED
@@ -0,0 +1,530 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Lune LECO Trainer - Fixed
3
+ """
4
+
5
+ import os
6
+ import json
7
+ import datetime
8
+ from dataclasses import dataclass, asdict, field
9
+ from typing import List, Literal
10
+ from tqdm.auto import tqdm
11
+ from enum import Enum
12
+
13
+ import torch
14
+ import torch.nn.functional as F
15
+ from torch.utils.tensorboard import SummaryWriter
16
+ from safetensors.torch import save_file
17
+
18
+ from diffusers import UNet2DConditionModel
19
+ from transformers import CLIPTextModel, CLIPTokenizer
20
+ from huggingface_hub import hf_hub_download
21
+
22
+
23
class ActionType(str, Enum):
    """Label for the kind of concept edit a LECO run performs.

    Members also subclass ``str``, so each one compares equal to its
    lowercase string value.
    """

    ERASE = "erase"
    ENHANCE = "enhance"
    REPLACE = "replace"
    SUPPRESS = "suppress"
29
+
30
+
31
@dataclass
class ConceptPair:
    """
    One concept/anchor prompt pair used as a training target.

    The LoRA is trained to reproduce the prediction difference
    pred(concept) - pred(anchor).

    Examples:
        Erase:    ConceptPair("anime style", "")
        Enhance:  ConceptPair("masterpiece", "")
        Replace:  ConceptPair("van gogh", "monet")
        Suppress: ConceptPair("nsfw", "sfw")
    """

    # Prompt describing the concept being targeted.
    concept: str
    # Reference prompt the concept is measured against; "" means the empty prompt.
    anchor: str = ""
    # Multiplier applied to this pair's loss during training.
    weight: float = 1.0
    # NOTE(review): not read by the training code in this file; presumably the
    # suggested LoRA strength at inference time - confirm with the consumer.
    inference_weight: float = -1.0
48
+
49
+
50
@dataclass
class PreservationSet:
    """Prompts whose predictions the LoRA should leave unchanged."""

    # Each instance gets its own list (default_factory), never a shared one.
    prompts: List[str] = field(default_factory=list)
    # Multiplier applied to the preservation loss when it is added to the total.
    weight: float = 0.3
55
+
56
+
57
@dataclass
class LECOConfig:
    """All settings for one LECO LoRA training run."""

    # --- Model paths ---
    # Root folder; each run writes into its own timestamped subdirectory.
    output_dir: str = "./leco_outputs"
    # Hub repository and file name of the base checkpoint to download.
    base_model_repo: str = "AbstractPhil/sd15-flow-lune-flux"
    base_checkpoint: str = "sd15_flow_flux_t2_6_pose_t4_6_port_t1_4_s18765.pt"

    # --- HuggingFace ---
    # NOTE(review): neither field is read by the code in this file.
    hf_repo_id: str = "AbstractPhil/lune-leco-adapters"
    upload_to_hub: bool = False

    # --- Training data ---
    # Used for run naming and checkpoint metadata.
    action: ActionType = ActionType.ERASE
    concept_pairs: List[ConceptPair] = field(default_factory=list)
    preservation: PreservationSet = field(default_factory=PreservationSet)

    # --- LoRA architecture ---
    lora_rank: int = 4
    lora_alpha: float = 1.0
    # NOTE(review): not read by the code in this file.
    lora_dropout: float = 0.0
    training_method: Literal["full", "selfattn", "xattn", "noxattn", "innoxattn"] = "full"

    # --- Training ---
    seed: int = 42
    iterations: int = 1000
    lr: float = 1e-4
    # Number of concept pairs drawn per step; all pairs are used when this is
    # at least the number of configured pairs.
    pairs_per_step: int = 1

    # --- Flow-matching ---
    # Applied to the sampled sigma as shift*s / (1 + (shift-1)*s).
    shift: float = 2.5
    # Timestep bounds; divided by 1000 to obtain the sigma sampling range.
    min_timestep: float = 0.0
    max_timestep: float = 1000.0

    # --- Resolution ---
    # Image-space size; the sampled latents are resolution // 8 on each side.
    resolution: int = 512
92
+
93
+
94
def get_target_modules(training_method: str) -> List[str]:
    """Return the module-name fragments that should receive LoRA weights.

    Only attention projections are ever targeted: ``attn1`` is self-attention
    and ``attn2`` is cross-attention.

    Args:
        training_method: One of "full", "selfattn", "xattn", "noxattn",
            "innoxattn".

    Returns:
        A new list of name fragments. Unrecognised methods fall back to the
        "full" set (both attention types).
    """
    attn1 = ["attn1.to_q", "attn1.to_k", "attn1.to_v", "attn1.to_out.0"]
    attn2 = ["attn2.to_q", "attn2.to_k", "attn2.to_v", "attn2.to_out.0"]

    method_map = {
        "full": attn1 + attn2,
        "selfattn": attn1,
        "xattn": attn2,
        # Both "no cross-attention" variants must exclude attn2. Within the
        # attention-only layer set used here, that leaves self-attention.
        # ("innoxattn" previously returned attn2, i.e. the same as "xattn".)
        "noxattn": attn1,
        "innoxattn": attn1,
    }
    return list(method_map.get(training_method, attn1 + attn2))
107
+
108
+
109
def create_lora_layers(unet: torch.nn.Module, config: LECOConfig):
    """Create LoRA down/up matrices for every targeted Linear layer of `unet`.

    Keys follow the ``lora_unet_<module path with dots as underscores>``
    naming (intended to be ComfyUI/A1111 compatible).

    Args:
        unet: Model whose ``torch.nn.Linear`` sub-modules are scanned by name.
        config: Supplies ``training_method``, ``lora_rank`` and ``lora_alpha``.

    Returns:
        ``(lora_state, trainable_params)``. For each layer ``lora_state``
        holds ``<key>.lora_down.weight`` ([rank, in_features]),
        ``<key>.lora_up.weight`` ([out_features, rank]), ``<key>.alpha`` and
        ``<key>._module`` (the wrapped layer). ``trainable_params`` lists the
        very same Parameter objects, ready to hand to an optimizer.

    Raises:
        ValueError: If ``config.lora_rank`` is not positive.
    """
    rank = config.lora_rank
    if rank <= 0:
        # A zero rank would later divide by zero when computing alpha / rank.
        raise ValueError(f"lora_rank must be positive, got {rank}")

    target_modules = get_target_modules(config.training_method)
    lora_state = {}
    trainable_params = []

    def get_lora_key(module_path: str) -> str:
        return f"lora_unet_{module_path.replace('.', '_')}"

    print(f"Creating LoRA layers (method: {config.training_method})...")
    layer_count = 0

    for name, module in unet.named_modules():
        if not any(target in name for target in target_modules):
            continue

        if not isinstance(module, torch.nn.Linear):
            continue

        lora_key = get_lora_key(name)
        # Allocate directly on the wrapped layer's device. Moving a Parameter
        # afterwards with .to() yields a non-leaf copy, so the optimizer would
        # update one tensor while the forward pass reads another.
        layer_device = module.weight.device

        # down: [rank, in_features], up: [out_features, rank]
        lora_down = torch.nn.Parameter(
            torch.zeros(rank, module.in_features, device=layer_device)
        )
        lora_up = torch.nn.Parameter(
            torch.zeros(module.out_features, rank, device=layer_device)
        )

        # Random down / zero up: the LoRA contribution starts at exactly zero.
        torch.nn.init.kaiming_uniform_(lora_down, a=1.0)
        torch.nn.init.zeros_(lora_up)

        lora_state[f"{lora_key}.lora_down.weight"] = lora_down
        lora_state[f"{lora_key}.lora_up.weight"] = lora_up
        lora_state[f"{lora_key}.alpha"] = torch.tensor(config.lora_alpha)
        lora_state[f"{lora_key}._module"] = module

        trainable_params.extend([lora_down, lora_up])
        layer_count += 1

    print(f"✓ Created {layer_count} LoRA layers ({len(trainable_params)} parameters)")
    return lora_state, trainable_params
152
+
153
+
154
def apply_lora_hooks(unet: torch.nn.Module, lora_state: dict, scale: float = 1.0) -> list:
    """
    Attach one forward hook per LoRA layer and return the hook handles.

    Each hook adds ``scaling * (x @ down.T @ up.T)`` to the wrapped module's
    output, where ``scaling = (alpha / rank) * scale``. ``F.linear(x, W)``
    computes ``x @ W.T``, so two nested calls give the product directly.
    The wrapped modules are taken from ``lora_state``; ``unet`` itself is not
    inspected.
    """
    down_suffix = ".lora_down.weight"

    def build_hook(down, up, factor):
        # Factory so each hook captures its own layer's tensors.
        def hook(_module, inputs, output):
            low_rank = F.linear(inputs[0], down)   # [..., rank]
            delta = F.linear(low_rank, up)         # [..., out_features]
            return output + delta * factor
        return hook

    registered = []
    for state_key in lora_state:
        if state_key.endswith(down_suffix):
            prefix = state_key.replace(down_suffix, "")
            target = lora_state[f"{prefix}._module"]
            down_weight = lora_state[f"{prefix}.lora_down.weight"]
            up_weight = lora_state[f"{prefix}.lora_up.weight"]
            alpha_value = lora_state[f"{prefix}.alpha"].item()

            # Rank is the first dimension of the down matrix.
            factor = (alpha_value / down_weight.shape[0]) * scale
            registered.append(
                target.register_forward_hook(build_hook(down_weight, up_weight, factor))
            )

    return registered
192
+
193
+
194
def remove_lora_hooks(handles: list):
    """Detach every hook in `handles` by calling its ``remove()`` method."""
    for hook_handle in handles:
        hook_handle.remove()
198
+
199
+
200
@torch.no_grad()
def encode_text(prompt: str, tokenizer, text_encoder, device) -> torch.Tensor:
    """Tokenize `prompt` and return the text encoder's first output.

    The prompt is padded/truncated to ``tokenizer.model_max_length`` tokens,
    the token ids are moved to `device`, and the whole call runs without
    gradient tracking.
    """
    encoded = tokenizer(
        prompt,
        padding="max_length",
        max_length=tokenizer.model_max_length,
        truncation=True,
        return_tensors="pt",
    )
    input_ids = encoded.input_ids.to(device)
    encoder_outputs = text_encoder(input_ids)
    return encoder_outputs[0]
212
+
213
+
214
def compute_leco_loss(
    unet: torch.nn.Module,
    lora_state: dict,
    pair: ConceptPair,
    tokenizer,
    text_encoder,
    config: LECOConfig,
    device: str = "cuda"
):
    """
    LECO loss for one concept pair at a single random timestep.

    The frozen UNet provides the target direction
    pred(concept) - pred(anchor); the LoRA-hooked UNet is then run on the
    concept prompt and the change it causes is regressed onto that target
    with MSE.

    Returns:
        ``(loss, metrics)`` where ``metrics`` has the keys "timestep",
        "sigma", "target_norm" and "lora_norm" as Python floats.
    """
    # Draw sigma uniformly in the configured range, then apply the shift.
    sigma_lo = config.min_timestep / 1000.0
    sigma_hi = config.max_timestep / 1000.0
    sigma = sigma_lo + torch.rand(1, device=device) * (sigma_hi - sigma_lo)
    sigma = (config.shift * sigma) / (1 + (config.shift - 1) * sigma)
    timestep = sigma * 1000.0
    sigma = sigma.view(1, 1, 1, 1)  # broadcastable over the latent tensor

    # Pure noise scaled by sigma serves as the UNet input.
    latent_side = config.resolution // 8
    noise = torch.randn(1, 4, latent_side, latent_side, device=device)
    noisy_input = sigma * noise

    concept_emb = encode_text(pair.concept, tokenizer, text_encoder, device)
    anchor_emb = encode_text(pair.anchor, tokenizer, text_encoder, device)

    def predict(embedding):
        return unet(
            noisy_input, timestep,
            encoder_hidden_states=embedding,
            return_dict=False
        )[0]

    # Target direction from the base model (no hooks attached, no gradients).
    with torch.no_grad():
        pred_concept = predict(concept_emb)
        pred_anchor = predict(anchor_emb)
        target_delta = pred_concept - pred_anchor

    # Same input through the hooked model; hooks are always detached afterwards.
    handles = apply_lora_hooks(unet, lora_state, scale=1.0)
    try:
        pred_with_lora = predict(concept_emb)
        lora_delta = pred_with_lora - pred_concept
        loss = F.mse_loss(lora_delta, target_delta)
    finally:
        remove_lora_hooks(handles)

    metrics = {
        "timestep": timestep.item(),
        "sigma": sigma.item(),
        "target_norm": target_delta.norm().item(),
        "lora_norm": lora_delta.norm().item()
    }
    return loss, metrics
282
+
283
+
284
def compute_preservation_loss(
    unet: torch.nn.Module,
    lora_state: dict,
    preservation: PreservationSet,
    tokenizer,
    text_encoder,
    config: LECOConfig,
    device: str = "cuda"
):
    """Mean MSE between hooked and base predictions over the preservation prompts.

    One timestep is sampled and shared by all prompts; each prompt gets its
    own noise sample.

    Returns:
        ``(0.0, {})`` when there are no prompts, otherwise
        ``(avg_loss, {"count": ..., "avg": ...})``.
    """
    prompts = preservation.prompts
    if not prompts:
        return 0.0, {}

    sigma_lo = config.min_timestep / 1000.0
    sigma_hi = config.max_timestep / 1000.0
    sigma = sigma_lo + torch.rand(1, device=device) * (sigma_hi - sigma_lo)
    sigma = (config.shift * sigma) / (1 + (config.shift - 1) * sigma)
    timestep = sigma * 1000.0
    sigma = sigma.view(1, 1, 1, 1)

    latent_side = config.resolution // 8
    per_prompt_losses = []

    for prompt in prompts:
        noise = torch.randn(1, 4, latent_side, latent_side, device=device)
        noisy_input = sigma * noise
        prompt_emb = encode_text(prompt, tokenizer, text_encoder, device)

        # Reference prediction from the unmodified model.
        with torch.no_grad():
            pred_base = unet(
                noisy_input, timestep,
                encoder_hidden_states=prompt_emb,
                return_dict=False
            )[0]

        # Prediction with the LoRA hooks attached; always detach afterwards.
        handles = apply_lora_hooks(unet, lora_state, scale=1.0)
        try:
            pred_with_lora = unet(
                noisy_input, timestep,
                encoder_hidden_states=prompt_emb,
                return_dict=False
            )[0]
        finally:
            remove_lora_hooks(handles)

        per_prompt_losses.append(F.mse_loss(pred_with_lora, pred_base))

    # sum() starts from 0 and adds left to right.
    avg_loss = sum(per_prompt_losses) / len(prompts)
    return avg_loss, {"count": len(prompts), "avg": avg_loss.item()}
333
+
334
+
335
def _load_frozen_unet(config, device):
    """Download the base checkpoint and return a frozen, eval-mode UNet on `device`."""
    checkpoint_path = hf_hub_download(
        repo_id=config.base_model_repo,
        filename=config.base_checkpoint,
        repo_type="model"
    )
    # NOTE(review): torch.load unpickles the file; only point base_model_repo
    # at repositories you trust.
    checkpoint = torch.load(checkpoint_path, map_location="cpu")

    unet = UNet2DConditionModel.from_pretrained(
        "runwayml/stable-diffusion-v1-5",
        subfolder="unet",
        torch_dtype=torch.float32
    )

    student_dict = checkpoint["student"]
    # Strip the leading "unet." from keys that have it.
    cleaned_dict = {k[5:] if k.startswith("unet.") else k: v for k, v in student_dict.items()}
    load_result = unet.load_state_dict(cleaned_dict, strict=False)
    if load_result.missing_keys or load_result.unexpected_keys:
        # strict=False would otherwise hide a checkpoint that does not match.
        print(
            f"⚠ Checkpoint key mismatch: {len(load_result.missing_keys)} missing, "
            f"{len(load_result.unexpected_keys)} unexpected"
        )
    unet = unet.to(device)
    unet.requires_grad_(False)
    unet.eval()
    return unet


def _load_frozen_clip(device):
    """Return ``(tokenizer, text_encoder)`` with the encoder frozen on `device`."""
    tokenizer = CLIPTokenizer.from_pretrained(
        "runwayml/stable-diffusion-v1-5", subfolder="tokenizer"
    )
    text_encoder = CLIPTextModel.from_pretrained(
        "runwayml/stable-diffusion-v1-5", subfolder="text_encoder",
        torch_dtype=torch.float32
    ).to(device)
    text_encoder.requires_grad_(False)
    text_encoder.eval()
    return tokenizer, text_encoder


def train_leco(config: LECOConfig):
    """Train a LECO LoRA as described by `config` and return the run directory.

    Creates a timestamped run folder under ``config.output_dir``, writes
    ``config.json`` and TensorBoard logs there, loads the frozen UNet and
    CLIP text encoder, optimises the LoRA matrices with AdamW, and saves a
    checkpoint every 200 steps and after the final step.

    Args:
        config: Full run configuration.

    Returns:
        Path of the run's output directory.

    Raises:
        ValueError: If no concept pairs are configured or
            ``config.pairs_per_step`` is below 1.
    """
    import random  # imported once here rather than on every training step

    device = "cuda"
    torch.manual_seed(config.seed)
    # Dedicated seeded generator so pair sampling is reproducible too.
    pair_rng = random.Random(config.seed)

    if not config.concept_pairs:
        raise ValueError("No concept pairs specified!")
    if config.pairs_per_step < 1:
        # Otherwise an empty sample reaches `.backward()` on the int 0.
        raise ValueError("pairs_per_step must be at least 1")

    # Setup output
    timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
    concept_names = "_".join([
        p.concept.replace(" ", "")[:12]
        for p in config.concept_pairs[:2]
    ])
    # The name is embedded in directory and file names; path separators in a
    # concept would point the checkpoint at a non-existent subdirectory.
    concept_names = concept_names.replace("/", "_").replace("\\", "_")
    if len(config.concept_pairs) > 2:
        concept_names += f"_plus{len(config.concept_pairs)-2}"

    run_name = f"{config.action.value}_{concept_names}_{timestamp}"
    output_dir = os.path.join(config.output_dir, run_name)
    os.makedirs(output_dir, exist_ok=True)

    writer = SummaryWriter(log_dir=output_dir, flush_secs=60)
    try:
        with open(os.path.join(output_dir, "config.json"), "w") as f:
            json.dump(asdict(config), f, indent=2)

        print("="*80)
        print(f"LECO Training: {config.action.value.upper()}")
        print("="*80)

        # Load model
        print("\nLoading base model...")
        unet = _load_frozen_unet(config, device)
        print("✓ Loaded UNet")

        # Load CLIP
        print("Loading CLIP text encoder...")
        tokenizer, text_encoder = _load_frozen_clip(device)
        print("✓ Loaded CLIP")

        # Create LoRA
        print(f"\nInjecting LoRA (rank={config.lora_rank}, alpha={config.lora_alpha})...")
        lora_state, trainable_params = create_lora_layers(unet, config)

        # Move LoRA tensors to the training device. Parameters are moved in
        # place via `.data`: rebinding `param.to(device)` would put a non-leaf
        # copy into lora_state while the optimizer kept updating the original,
        # so the forward pass and checkpoints would never see trained weights.
        for key, value in lora_state.items():
            if isinstance(value, torch.nn.Parameter):
                value.data = value.data.to(device)
            elif isinstance(value, torch.Tensor):
                lora_state[key] = value.to(device)

        optimizer = torch.optim.AdamW(trainable_params, lr=config.lr, weight_decay=0.01)

        # Print config
        print("\nTraining Configuration:")
        print(f" Action: {config.action.value}")
        print(f" Concept pairs: {len(config.concept_pairs)}")
        for i, pair in enumerate(config.concept_pairs, 1):
            anchor_str = f"→ '{pair.anchor}'" if pair.anchor else "(none)"
            print(f" {i}. '{pair.concept}' {anchor_str} (weight: {pair.weight})")

        if config.preservation.prompts:
            print(f" Preservation: {len(config.preservation.prompts)} prompts")

        print(f"\n Iterations: {config.iterations}")
        print(f" Learning rate: {config.lr}")
        print(f" Training method: {config.training_method}")
        print("="*80 + "\n")

        # Training loop
        progress = tqdm(range(config.iterations), desc="Training")

        for step in progress:
            if config.pairs_per_step >= len(config.concept_pairs):
                active_pairs = config.concept_pairs
            else:
                active_pairs = pair_rng.sample(config.concept_pairs, config.pairs_per_step)

            total_loss = 0
            all_metrics = []

            for pair in active_pairs:
                loss, metrics = compute_leco_loss(
                    unet, lora_state, pair,
                    tokenizer, text_encoder, config, device
                )
                total_loss += loss * pair.weight
                all_metrics.append(metrics)

            if config.preservation.prompts:
                pres_loss, _ = compute_preservation_loss(
                    unet, lora_state, config.preservation,
                    tokenizer, text_encoder, config, device
                )
                total_loss += pres_loss * config.preservation.weight
            else:
                pres_loss = 0

            total_loss.backward()
            grad_norm = torch.nn.utils.clip_grad_norm_(trainable_params, max_norm=1.0)
            optimizer.step()
            optimizer.zero_grad()

            # Logging
            pres_value = pres_loss if isinstance(pres_loss, (float, int)) else pres_loss.item()
            writer.add_scalar("loss/total", total_loss.item(), step)
            writer.add_scalar("loss/preservation", pres_value, step)
            writer.add_scalar("grad_norm", grad_norm.item(), step)

            avg_target = sum(m["target_norm"] for m in all_metrics) / len(all_metrics)
            progress.set_postfix({
                "loss": f"{total_loss.item():.4f}",
                "grad": f"{grad_norm.item():.3f}",
                "target": f"{avg_target:.3f}"
            })

            if (step + 1) % 200 == 0 or step == config.iterations - 1:
                save_checkpoint(lora_state, config, output_dir, step + 1, concept_names)
    finally:
        # Flush and close the event file even if loading or training fails.
        writer.close()

    print("\n" + "="*80)
    print("✅ Training complete!")
    print(f"Output: {output_dir}")
    print("="*80)

    return output_dir
485
+
486
+
487
def save_checkpoint(lora_state, config, output_dir, step, name_suffix):
    """Write the LoRA tensors to a ``.safetensors`` file in `output_dir`.

    Every tensor entry of `lora_state` is detached and copied to CPU; the
    ``._module`` entries are left out. Run details (rank, alpha, method,
    action, concepts, anchors, step) are stored as string metadata. The file
    is named ``leco_<name_suffix>_r<rank>_s<step>.safetensors``.
    """
    tensors = {
        key: value.detach().cpu()
        for key, value in lora_state.items()
        if isinstance(value, torch.Tensor) and not key.endswith("._module")
    }

    pairs = config.concept_pairs
    metadata = {
        "ss_network_module": "networks.lora",
        "ss_network_dim": str(config.lora_rank),
        "ss_network_alpha": str(config.lora_alpha),
        "ss_base_model": "runwayml/stable-diffusion-v1-5",
        "ss_training_method": config.training_method,
        "leco_action": config.action.value,
        "leco_concepts": ", ".join(p.concept for p in pairs),
        # Empty anchors are omitted from the list.
        "leco_anchors": ", ".join(p.anchor for p in pairs if p.anchor),
        "leco_step": str(step)
    }

    filename = f"leco_{name_suffix}_r{config.lora_rank}_s{step}.safetensors"
    save_file(tensors, os.path.join(output_dir, filename), metadata=metadata)
    print(f"\n✓ Saved: {filename}")
515
+
516
+
517
if __name__ == "__main__":
    # Example run: a rank-4, self-attention-only "enhance" LoRA trained for
    # 600 iterations on three quality prompts against the empty anchor.
    config = LECOConfig(
        action=ActionType.ENHANCE,
        concept_pairs=[
            ConceptPair(concept="masterpiece", anchor="", weight=1.0),
            ConceptPair(concept="best quality", anchor="", weight=1.0),
            ConceptPair(concept="highly detailed", anchor="", weight=0.8),
        ],
        iterations=600,
        lora_rank=4,
        training_method="selfattn",
    )

    train_leco(config)