AbstractPhil commited on
Commit
40aa172
·
verified ·
1 Parent(s): ce52ba3

Create ablation_trainer.py

Browse files
Files changed (1) hide show
  1. ablation_trainer.py +617 -0
ablation_trainer.py ADDED
@@ -0,0 +1,617 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ TinyFlux LoRA Training - Colab Edition
3
+
4
+ Simple setup for testing LoRA with a small local dataset.
5
+
6
+ Directory structure expected:
7
+ /content/drive/MyDrive/lora_dataset/
8
+ image1.png
9
+ image1.txt (caption)
10
+ image2.jpg
11
+ image2.txt
12
+ ...
13
+
14
+ Or with a single prompts file:
15
+ /content/drive/MyDrive/lora_dataset/
16
+ image1.png
17
+ image2.jpg
18
+ prompts.txt (one line per image, alphabetical order)
19
+
20
+ Usage:
21
+ from tinyflux.examples.train_lora_colab import train_lora, LoRAConfig
22
+
23
+ config = LoRAConfig(
24
+ data_dir="/content/drive/MyDrive/lora_dataset",
25
+ output_dir="/content/lora_output",
26
+ hf_repo="AbstractPhil/tiny-flux-lora",
27
+ hf_subdir="my_lora_v1",
28
+ repeats=100,
29
+ epochs=10,
30
+ )
31
+
32
+ train_lora(config)
33
+ """
34
+
35
+ import os
36
+ import torch
37
+ from typing import Optional, List
38
+ from dataclasses import dataclass, field
39
+
40
+
41
@dataclass
class LoRAConfig:
    """Configuration for LoRA training.

    Every field has a default, so ``LoRAConfig()`` is directly usable.
    ``build_lora_config`` turns the LoRA-related fields into the
    ``TinyFluxLoRAConfig`` object consumed by the model code.
    """

    # Data
    data_dir: str = "/content/drive/MyDrive/lora_dataset"
    output_dir: str = "/content/lora_output"

    # Dataset inflation
    repeats: int = 100  # Repeat each image N times per epoch

    # LoRA configuration
    # Preset: "minimal", "standard", "character", "concept", "full", "progressive"
    # Or path to JSON config file
    lora_config: str = "standard"

    # Override defaults (applied on top of preset/config)
    lora_rank: Optional[int] = None
    lora_alpha: Optional[float] = None

    # Model extensions
    extra_single_blocks: int = 0
    extra_double_blocks: int = 0

    # Training (epoch-based)
    epochs: int = 10
    batch_size: int = 16
    lr: float = 1e-3
    warmup_epochs: float = 0.5
    train_resolution: int = 512

    # Checkpoints
    save_every_epoch: int = 1

    # HuggingFace upload
    hf_repo: Optional[str] = "AbstractPhil/tinyflux-lailah-loras"
    hf_subdir: str = "lora_v2_man_wearing_brown_cap_single_blocks_1e-3_with_lune"
    upload_every_epoch: int = 2

    # Sampling
    # NOTE: each prompt must be its own list element. Adjacent string literals
    # without a separating comma are silently concatenated by Python.
    sample_prompts: List[str] = field(default_factory=lambda: [
        "a red cube on a blue sphere",
        "a cat sitting on a table",
        "A man wearing a brown cap looking sitting at his computer with a black and brown dog resting next to him on the couch.",
        "A man wearing a brown cap looking at his computer.",
    ])
    sample_every_epoch: bool = True
    sample_steps: int = 50
    sample_cfg: float = 7.5
    sample_seed: int = 42

    # Experts
    build_lune: bool = True
    build_sol: bool = True

    # Base model
    base_repo: str = "AbstractPhil/tiny-flux-deep"
    base_weights: str = "step_417054.pt"

    def build_lora_config(self):
        """Build TinyFluxLoRAConfig from training config.

        Returns:
            The ``TinyFluxLoRAConfig`` loaded from ``lora_config`` (a JSON
            path when it ends with ``.json``, otherwise a preset name), with
            the optional rank/alpha overrides and block extensions applied.
        """
        # Imported lazily so the dataclass itself is usable without tinyflux.
        from tinyflux.model.lora_config import TinyFluxLoRAConfig, BlockExtensions

        # Load from preset or file
        if self.lora_config.endswith('.json'):
            cfg = TinyFluxLoRAConfig.load(self.lora_config)
        else:
            cfg = TinyFluxLoRAConfig.from_preset(self.lora_config)

        # Apply overrides
        if self.lora_rank is not None:
            cfg.defaults.rank = self.lora_rank
        if self.lora_alpha is not None:
            cfg.defaults.alpha = self.lora_alpha

        # Apply extensions
        if self.extra_single_blocks > 0 or self.extra_double_blocks > 0:
            cfg.extensions = BlockExtensions(
                single_blocks=self.extra_single_blocks,
                double_blocks=self.extra_double_blocks,
            )

        return cfg
124
+
125
+
126
def upload_to_hf(
    local_path: str,
    repo_id: str,
    subdir: str,
    filename: Optional[str] = None,
):
    """Upload file to HuggingFace repo (best effort).

    Args:
        local_path: Path of the local file to upload.
        repo_id: Target model repo id.
        subdir: Folder inside the repo; empty means the repo root. Leading and
            trailing slashes are ignored.
        filename: Name to give the file in the repo. Defaults to the basename
            of ``local_path``.

    Upload errors are printed rather than raised, so a failed upload does not
    abort the caller.
    """
    from huggingface_hub import HfApi

    api = HfApi()

    if filename is None:
        filename = os.path.basename(local_path)

    # Callers compose nested folders as f"{base}/samples"; with an empty base
    # that yields "/samples". Strip surrounding slashes so the repo path never
    # starts with "/" or contains an empty trailing segment.
    folder = subdir.strip("/") if subdir else ""
    path_in_repo = f"{folder}/{filename}" if folder else filename

    try:
        api.upload_file(
            path_or_fileobj=local_path,
            path_in_repo=path_in_repo,
            repo_id=repo_id,
            repo_type="model",
        )
        print(f" ✓ Uploaded to {repo_id}/{path_in_repo}")
    except Exception as e:
        # Deliberate best-effort: report and continue.
        print(f" ✗ Upload failed: {e}")
152
+
153
+
154
def train_lora(config: Optional[LoRAConfig] = None, **kwargs):
    """
    Main training function for Colab.

    Loads the image directory, builds or loads the encoding cache, loads the
    base model, injects LoRA, then runs an epoch-based flow-matching training
    loop with periodic checkpoints, HuggingFace uploads and sample images.

    Args:
        config: LoRAConfig instance, or pass kwargs directly
        **kwargs: Used to construct a LoRAConfig when ``config`` is None.

    Returns:
        Tuple ``(model, lora)``: the loaded base model and the LoRA wrapper.
    """
    import torch.nn.functional as F
    from tqdm.auto import tqdm

    # Build config from kwargs if not provided
    if config is None:
        config = LoRAConfig(**kwargs)

    device = "cuda" if torch.cuda.is_available() else "cpu"
    dtype = torch.bfloat16 if device == "cuda" else torch.float32

    print("=" * 60)
    print("TinyFlux LoRA Training")
    print("=" * 60)
    print(f"Device: {device}")
    print(f"Data: {config.data_dir}")
    print(f"Repeats: {config.repeats}")
    print(f"LoRA config: {config.lora_config}")
    rank_info = f", rank={config.lora_rank}" if config.lora_rank else ""
    print(f"Epochs: {config.epochs}{rank_info}, LR: {config.lr}")
    print(f"Train resolution: {config.train_resolution}x{config.train_resolution}")

    # Memory estimate
    latent_size = config.train_resolution // 8
    tokens = latent_size * latent_size
    print(f" Latent: {latent_size}x{latent_size} = {tokens} tokens")

    if config.hf_repo:
        print(f"HF Upload: {config.hf_repo}/{config.hf_subdir} every {config.upload_every_epoch} epochs")

    os.makedirs(config.output_dir, exist_ok=True)
    cache_dir = os.path.join(config.output_dir, "cache")
    samples_dir = os.path.join(config.output_dir, "samples")
    os.makedirs(samples_dir, exist_ok=True)

    # =========================================================================
    # 1. Load dataset
    # =========================================================================
    print("\n[1/6] Loading images...")

    from tinyflux.trainer.data_directory import (
        DirectoryDataset,
        create_dataloader,
    )

    raw_dataset = DirectoryDataset(config.data_dir, repeats=1, target_size=512)
    images, prompts = raw_dataset.get_images_and_prompts()
    n_images = len(images)

    print(f" Found {n_images} images")

    # =========================================================================
    # 2. Build cache
    # =========================================================================
    print("\n[2/6] Building cache...")

    from tinyflux.model.zoo import ModelZoo
    from tinyflux.trainer.cache_experts import DatasetCache

    zoo = ModelZoo(device=device, dtype=dtype)

    # An existing meta.pt marks a previously saved cache that can be reused.
    cache_meta = os.path.join(cache_dir, "meta.pt")
    if os.path.exists(cache_meta):
        print(" Loading existing cache...")
        cache = DatasetCache.load(cache_dir)
    else:
        print(" Building new cache (this takes a few minutes)...")
        cache = DatasetCache.build(
            zoo,
            images,
            prompts,
            name="lora_dataset",
            build_lune=config.build_lune,
            build_sol=config.build_sol,
            batch_size=min(4, n_images),
            sol_batch_size=1,
            dtype=torch.float16,
            compile_experts=False,
        )
        cache.save(cache_dir)

    print(f" Cache: {len(cache)} samples")

    # Free cache-building memory - unload ALL models
    del images, raw_dataset
    zoo.unload("vae")
    zoo.unload("t5")
    zoo.unload("clip")
    zoo.unload("lune")
    zoo.unload("sol")
    torch.cuda.empty_cache()

    # =========================================================================
    # 3. Load model + inject LoRA
    # =========================================================================
    print("\n[3/6] Loading model...")

    from tinyflux.model.lora import TinyFluxLoRA

    model = zoo.load_tinyflux(
        source=config.base_repo,
        ema_path=config.base_weights,
        train_mode=True,
    )

    # Memory optimizations for T4/Colab
    # Enable memory efficient attention
    torch.backends.cuda.enable_flash_sdp(True)
    torch.backends.cuda.enable_mem_efficient_sdp(True)
    print(" Memory-efficient attention enabled")

    print(f"\n[4/6] Injecting LoRA ({config.lora_config})...")

    # Build LoRA config from training config
    lora_cfg = config.build_lora_config()

    # Create LoRA with flexible config
    lora = TinyFluxLoRA(model, config=lora_cfg)

    # Use per-layer LR groups if available
    has_lr_groups = len(lora_cfg.get_lr_groups(1.0)) > 1

    # =========================================================================
    # 4. Setup sampler (lazy - will load encoders only when sampling)
    # =========================================================================
    print("\n[5/6] Setting up sampler...")

    from tinyflux.trainer.sampling import Sampler, save_samples

    # Don't load encoders yet - will load on demand for sampling
    # This saves ~3GB VRAM during training
    sampler = None  # Created lazily

    def do_sample(epoch_num: int) -> Optional[str]:
        """Generate and save samples, loading encoders as needed."""
        nonlocal sampler

        if not config.sample_prompts:
            return None

        # Ensure encoders are loaded and on GPU
        if zoo.vae is None:
            zoo.load_vae()
        else:
            zoo.onload("vae")

        if zoo.t5 is None:
            zoo.load_t5()
        else:
            zoo.onload("t5")

        if zoo.clip is None:
            zoo.load_clip()
        else:
            zoo.onload("clip")

        # Create sampler if needed
        if sampler is None:
            print(" Initializing sampler...")
            sampler = Sampler(
                zoo=zoo,
                model=model,
                ema=None,
                num_steps=config.sample_steps,
                guidance_scale=config.sample_cfg,
                shift=3.0,
                device=device,
                dtype=dtype,
            )

        model.eval()
        with torch.no_grad():
            sample_images = sampler.generate(
                config.sample_prompts,
                seed=config.sample_seed,
            )
        sample_path = save_samples(
            sample_images,
            config.sample_prompts,
            epoch_num,
            samples_dir,
        )
        print(f" Saved: {sample_path}")

        if config.hf_repo:
            upload_to_hf(
                sample_path,
                config.hf_repo,
                f"{config.hf_subdir}/samples",
            )

        model.train()

        # On A100 (40GB+), don't offload - plenty of VRAM
        # Only offload on smaller GPUs to fit training.
        # Querying device properties raises without CUDA, so skip on CPU.
        if device == "cuda" and torch.cuda.get_device_properties(0).total_memory < 20e9:
            zoo.offload("vae")
            zoo.offload("t5")
            zoo.offload("clip")
            torch.cuda.empty_cache()

        return sample_path

    # =========================================================================
    # 5. Training loop (epoch-based)
    # =========================================================================
    print("\n[6/6] Training...")

    from tinyflux.trainer.schedules import sample_timesteps
    from tinyflux.utils.predictions import flow_x_t, flow_velocity
    from tinyflux.model.model import TinyFluxDeep

    loader = create_dataloader(
        cache,
        repeats=config.repeats,
        batch_size=config.batch_size,
        shuffle=True,
        num_workers=8
    )

    # Calculate training metrics
    steps_per_epoch = len(loader)
    total_steps = steps_per_epoch * config.epochs
    warmup_steps = int(config.warmup_epochs * steps_per_epoch)

    print(f" {n_images} images × {config.repeats} repeats = {steps_per_epoch} steps/epoch")
    print(f" {config.epochs} epochs = {total_steps} total steps")
    print(f" Warmup: {warmup_steps} steps ({config.warmup_epochs} epochs)")

    # Use per-layer LR groups if config has multiple lr_scales
    if has_lr_groups:
        param_groups = lora.get_param_groups(config.lr)
        optimizer = torch.optim.AdamW(param_groups, weight_decay=0.01)
        print(f" Using {len(param_groups)} LR groups")
    else:
        optimizer = torch.optim.AdamW(lora.parameters(), lr=config.lr, weight_decay=0.01)

    def lr_lambda(step):
        # Linear warmup to the base LR, then constant. With warmup_steps == 0
        # the branch is never taken, so there is no division by zero.
        if step < warmup_steps:
            return step / warmup_steps
        return 1.0
    scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)

    model.train()
    global_step = 0
    running_loss = 0.0
    log_every = max(1, steps_per_epoch // 10)  # Log ~10 times per epoch

    # A non-positive interval disables the periodic action instead of raising
    # ZeroDivisionError in the modulo checks below.
    save_interval = config.save_every_epoch
    upload_interval = config.upload_every_epoch

    for epoch in range(1, config.epochs + 1):
        epoch_loss = 0.0
        epoch_steps = 0

        pbar = tqdm(loader, desc=f"Epoch {epoch}/{config.epochs}")

        for batch in pbar:
            indices = batch['index']
            B = len(indices)

            # Get cached encodings
            latents, t5_embed, clip_embed = cache.get_encodings_batch(indices)
            latents = latents.to(device, dtype=dtype)
            t5_embed = t5_embed.to(device, dtype=dtype)
            clip_embed = clip_embed.to(device, dtype=dtype)

            # Resize latents if training at different resolution
            # (latent_size == train_resolution // 8, computed once above)
            if latents.shape[-1] != latent_size:
                latents = F.interpolate(
                    latents,
                    size=(latent_size, latent_size),
                    mode='bilinear',
                    align_corners=False,
                )

            H = W = latents.shape[-1]

            # Sample timesteps
            t = sample_timesteps(B, device=device, dtype=dtype, shift=3.0)

            # Get expert features
            lune_features = cache.get_lune(indices, t)
            if lune_features is not None:
                lune_features = lune_features.to(device, dtype=dtype)

            sol_stats, sol_spatial = cache.get_sol(indices, t)
            if sol_stats is not None:
                sol_stats = sol_stats.to(device, dtype=dtype)
                sol_spatial = sol_spatial.to(device, dtype=dtype)

            # Flow matching
            noise = torch.randn_like(latents)
            x_t = flow_x_t(latents, noise, t)
            v_target = flow_velocity(latents, noise)

            # Reshape for model
            x_t_seq = x_t.flatten(2).transpose(1, 2)
            v_target_seq = v_target.flatten(2).transpose(1, 2)

            # Position IDs
            img_ids = TinyFluxDeep.create_img_ids(B, H, W, device)

            # Forward
            optimizer.zero_grad()

            with torch.autocast(device, dtype=dtype):
                v_pred = model(
                    hidden_states=x_t_seq,
                    encoder_hidden_states=t5_embed,
                    pooled_projections=clip_embed,
                    timestep=t,
                    img_ids=img_ids,
                    lune_features=lune_features,
                    sol_stats=sol_stats,
                    sol_spatial=sol_spatial,
                )

            loss = F.mse_loss(v_pred, v_target_seq)

            loss.backward()
            torch.nn.utils.clip_grad_norm_(lora.parameters(), 1.0)
            optimizer.step()
            scheduler.step()

            # Logging
            loss_val = loss.item()
            running_loss += loss_val
            epoch_loss += loss_val
            global_step += 1
            epoch_steps += 1

            if global_step % log_every == 0:
                avg_loss = running_loss / log_every
                pbar.set_postfix(
                    loss=f"{avg_loss:.4f}",
                    lr=f"{scheduler.get_last_lr()[0]:.2e}",
                )
                running_loss = 0.0

        # End of epoch (guard against a loader that yielded no batches)
        avg_epoch_loss = epoch_loss / max(epoch_steps, 1)
        print(f" Epoch {epoch} complete | Loss: {avg_epoch_loss:.4f}")

        # Checkpoint every N epochs
        if save_interval > 0 and epoch % save_interval == 0:
            ckpt_path = os.path.join(config.output_dir, f"lora_epoch_{epoch}.safetensors")
            lora.save(ckpt_path)
            print(f" Saved: {ckpt_path}")

        # Upload every N epochs
        if config.hf_repo and upload_interval > 0 and epoch % upload_interval == 0:
            ckpt_path = os.path.join(config.output_dir, f"lora_epoch_{epoch}.safetensors")
            # The save cadence may differ from the upload cadence, so make
            # sure the checkpoint exists before uploading it.
            if not os.path.exists(ckpt_path):
                lora.save(ckpt_path)
            upload_to_hf(ckpt_path, config.hf_repo, config.hf_subdir)

        # Sample every epoch
        if config.sample_every_epoch and config.sample_prompts:
            print(" Generating samples...")
            do_sample(epoch)

    # Final save
    final_path = os.path.join(config.output_dir, "lora_final.safetensors")
    lora.save(final_path)

    # Final upload
    if config.hf_repo:
        upload_to_hf(final_path, config.hf_repo, config.hf_subdir, "lora_final.safetensors")

    # Final sample
    if config.sample_prompts:
        print("\nGenerating final samples...")
        do_sample(config.epochs)

    print("\n" + "=" * 60)
    print("Training complete!")
    print(f" Epochs: {config.epochs}")
    print(f" Total steps: {total_steps}")
    print(f" Final LoRA: {final_path}")
    if config.hf_repo:
        print(f" HF Repo: https://huggingface.co/{config.hf_repo}/tree/main/{config.hf_subdir}")
    print("=" * 60)

    return model, lora
544
+
545
+
546
+ # =============================================================================
547
+ # Colab cell helper
548
+ # =============================================================================
549
+
550
# Copy-paste reference for a three-cell Colab notebook (mount + install,
# HuggingFace login, training run). The text is only stored in this constant;
# nothing in the visible module executes it.
COLAB_SETUP = """
# Cell 1: Mount Drive and install
from google.colab import drive
drive.mount('/content/drive')

!pip install -q safetensors accelerate huggingface_hub
!pip install -q git+https://github.com/AbstractPhil/tinyflux.git

# Cell 2: Login to HuggingFace (for uploads)
from huggingface_hub import login
from google.colab import userdata
login(userdata.get("HF_TOKEN"))

# Cell 3: Train!
from tinyflux.examples.train_lora_colab import train_lora, LoRAConfig

config = LoRAConfig(
    # Data
    data_dir="/content/drive/MyDrive/test_1024",
    output_dir="/content/lora_output",
    repeats=100, # 10 images × 100 repeats = 1000 steps/epoch

    # LoRA config: preset name or path to JSON file
    # Presets: "minimal", "standard", "character", "concept", "full", "progressive"
    lora_config="character",

    # Optional: override rank from preset
    lora_rank=None, # Set to override default

    # Training
    epochs=10,
    batch_size=1,
    lr=1e-4,
    train_resolution=512, # 512 for A100, 256 for T4

    # HuggingFace
    hf_repo="AbstractPhil/tinyflux-lailah-loras",
    hf_subdir="my_character_v1",
    upload_every_epoch=2,

    # Sampling
    sample_prompts=[
        "a red cube on a blue sphere",
        "A man wearing a brown cap sitting at his computer with a black and brown dog resting next to him on the couch.",
    ],
    sample_every_epoch=True,
)

model, lora = train_lora(config)
"""
600
+
601
if __name__ == "__main__":
    # Script entry point: log in to the HuggingFace Hub using the HF_TOKEN
    # entry from Colab's userdata store, then launch one training run with
    # both expert caches disabled and the "full" LoRA preset.
    from huggingface_hub import login
    from google.colab import userdata

    hf_token = userdata.get("HF_TOKEN")
    login(hf_token)

    run_settings = {
        "data_dir": "/content/drive/MyDrive/test_1024",
        "output_dir": "/content/lora_output3_no_experts_full",
        "repeats": 100,
        "epochs": 10,
        "lora_config": "full",
        "build_sol": False,
        "build_lune": False,
        "train_resolution": 512,
    }
    config = LoRAConfig(**run_settings)

    model, lora = train_lora(config)