|
|
""" |
|
|
TinyFlux LoRA Training - Colab Edition |
|
|
|
|
|
Simple setup for testing LoRA with a small local dataset. |
|
|
|
|
|
Directory structure expected: |
|
|
/content/drive/MyDrive/lora_dataset/ |
|
|
image1.png |
|
|
image1.txt (caption) |
|
|
image2.jpg |
|
|
image2.txt |
|
|
... |
|
|
|
|
|
Or with a single prompts file: |
|
|
/content/drive/MyDrive/lora_dataset/ |
|
|
image1.png |
|
|
image2.jpg |
|
|
prompts.txt (one line per image, alphabetical order) |
|
|
|
|
|
Usage: |
|
|
from tinyflux.examples.train_lora_colab import train_lora, LoRAConfig |
|
|
|
|
|
config = LoRAConfig( |
|
|
data_dir="/content/drive/MyDrive/lora_dataset", |
|
|
output_dir="/content/lora_output", |
|
|
hf_repo="AbstractPhil/tiny-flux-lora", |
|
|
hf_subdir="my_lora_v1", |
|
|
repeats=100, |
|
|
steps=1000, |
|
|
) |
|
|
|
|
|
train_lora(config) |
|
|
""" |
|
|
|
|
|
import os |
|
|
import torch |
|
|
from typing import Optional, List |
|
|
from dataclasses import dataclass, field |
|
|
|
|
|
|
|
|
@dataclass
class LoRAConfig:
    """Configuration for LoRA training.

    A single instance fully describes one training run: dataset paths,
    LoRA architecture, optimizer hyperparameters, HuggingFace upload
    target, and sampling options.
    """

    # --- Data ---
    data_dir: str = "/content/drive/MyDrive/lora_dataset"
    output_dir: str = "/content/lora_output"

    # How many times each image is repeated per epoch.
    repeats: int = 100

    # LoRA preset name (e.g. "minimal", "standard", "character", "full")
    # or a path to a JSON config file (detected by the ".json" suffix).
    lora_config: str = "standard"

    # Optional overrides applied on top of the preset's defaults.
    lora_rank: Optional[int] = None
    lora_alpha: Optional[float] = None

    # Extra trainable blocks appended beyond the preset's selection.
    extra_single_blocks: int = 0
    extra_double_blocks: int = 0

    # --- Training hyperparameters ---
    epochs: int = 10
    batch_size: int = 16
    lr: float = 1e-3
    warmup_epochs: float = 0.5  # fraction of an epoch spent in linear LR warmup
    train_resolution: int = 512

    # Save a local checkpoint every N epochs.
    save_every_epoch: int = 1

    # --- HuggingFace upload ---
    hf_repo: Optional[str] = "AbstractPhil/tinyflux-lailah-loras"  # None disables uploads
    hf_subdir: str = "lora_v2_man_wearing_brown_cap_single_blocks_1e-3_with_lune"
    upload_every_epoch: int = 2

    # --- Sampling ---
    # BUG FIX: the original default was missing a comma between the last two
    # prompts, so implicit string-literal concatenation silently merged them
    # into a single prompt. They are now two separate entries.
    sample_prompts: List[str] = field(default_factory=lambda: [
        "a red cube on a blue sphere",
        "a cat sitting on a table",
        "A man wearing a brown cap looking sitting at his computer with a black and brown dog resting next to him on the couch.",
        "A man wearing a brown cap looking at his computer.,",
    ])
    sample_every_epoch: bool = True
    sample_steps: int = 50
    sample_cfg: float = 7.5
    sample_seed: int = 42

    # --- Expert feature caches (lune/sol auxiliary conditioning) ---
    build_lune: bool = True
    build_sol: bool = True

    # --- Base model ---
    base_repo: str = "AbstractPhil/tiny-flux-deep"
    base_weights: str = "step_417054.pt"

    def build_lora_config(self):
        """Build TinyFluxLoRAConfig from training config.

        Loads either a JSON file or a named preset, then applies the
        optional rank/alpha overrides and extra block extensions.
        """
        from tinyflux.model.lora_config import TinyFluxLoRAConfig, LoRADefaults, BlockExtensions

        # JSON path vs. preset name, distinguished by suffix.
        if self.lora_config.endswith('.json'):
            cfg = TinyFluxLoRAConfig.load(self.lora_config)
        else:
            cfg = TinyFluxLoRAConfig.from_preset(self.lora_config)

        # Overrides take effect only when explicitly set.
        if self.lora_rank is not None:
            cfg.defaults.rank = self.lora_rank
        if self.lora_alpha is not None:
            cfg.defaults.alpha = self.lora_alpha

        # Append extra trainable blocks only when requested.
        if self.extra_single_blocks > 0 or self.extra_double_blocks > 0:
            cfg.extensions = BlockExtensions(
                single_blocks=self.extra_single_blocks,
                double_blocks=self.extra_double_blocks,
            )

        return cfg
|
|
|
|
|
|
|
|
def upload_to_hf(
    local_path: str,
    repo_id: str,
    subdir: str,
    filename: Optional[str] = None,
):
    """Upload file to HuggingFace repo.

    Args:
        local_path: Path of the local file to upload.
        repo_id: Target repo id, e.g. "user/repo".
        subdir: Subdirectory inside the repo; empty string uploads to the root.
        filename: Name to use inside the repo; defaults to the local basename.

    Failures are printed, not raised, so a flaky network cannot kill a
    long training run.
    """
    from huggingface_hub import HfApi

    api = HfApi()

    if filename is None:
        filename = os.path.basename(local_path)

    # BUG FIX: this previously interpolated the literal text "(unknown)"
    # instead of the filename, so every upload within a subdir landed on
    # the same bogus repo path.
    path_in_repo = f"{subdir}/{filename}" if subdir else filename

    try:
        api.upload_file(
            path_or_fileobj=local_path,
            path_in_repo=path_in_repo,
            repo_id=repo_id,
            repo_type="model",
        )
        print(f" ✓ Uploaded to {repo_id}/{path_in_repo}")
    except Exception as e:
        # Best-effort: log and continue training.
        print(f" ✗ Upload failed: {e}")
|
|
|
|
|
|
|
|
def train_lora(config: Optional[LoRAConfig] = None, **kwargs):
    """
    Main training function for Colab.

    Runs the full pipeline: load images, build/load the encoding cache,
    load the base model, inject LoRA adapters, then train with a flow-
    matching MSE objective, periodically checkpointing, uploading, and
    sampling. Returns (model, lora).

    Args:
        config: LoRAConfig instance, or pass kwargs directly
    """
    import torch.nn.functional as F
    from tqdm.auto import tqdm

    # Allow train_lora(data_dir=..., epochs=...) without building a config first.
    if config is None:
        config = LoRAConfig(**kwargs)

    device = "cuda" if torch.cuda.is_available() else "cpu"
    # bf16 autocast on GPU; fp32 fallback keeps CPU debug runs working.
    dtype = torch.bfloat16 if device == "cuda" else torch.float32

    print("=" * 60)
    print("TinyFlux LoRA Training")
    print("=" * 60)
    print(f"Device: {device}")
    print(f"Data: {config.data_dir}")
    print(f"Repeats: {config.repeats}")
    print(f"LoRA config: {config.lora_config}")
    rank_info = f", rank={config.lora_rank}" if config.lora_rank else ""
    print(f"Epochs: {config.epochs}{rank_info}, LR: {config.lr}")
    print(f"Train resolution: {config.train_resolution}x{config.train_resolution}")

    # The VAE downsamples by 8x; the token count is the flattened latent grid.
    latent_size = config.train_resolution // 8
    tokens = latent_size * latent_size
    print(f" Latent: {latent_size}x{latent_size} = {tokens} tokens")

    if config.hf_repo:
        print(f"HF Upload: {config.hf_repo}/{config.hf_subdir} every {config.upload_every_epoch} epochs")

    os.makedirs(config.output_dir, exist_ok=True)
    cache_dir = os.path.join(config.output_dir, "cache")
    samples_dir = os.path.join(config.output_dir, "samples")
    os.makedirs(samples_dir, exist_ok=True)

    # ------------------------------------------------------------------
    # [1/6] Load raw images + captions from the data directory.
    # ------------------------------------------------------------------
    print("\n[1/6] Loading images...")

    from tinyflux.trainer.data_directory import (
        DirectoryDataset,
        create_dataloader,
    )

    # repeats=1 here: repetition is applied later by the dataloader.
    # NOTE(review): target_size is hard-coded to 512 regardless of
    # config.train_resolution; latents are rescaled in the loop below.
    raw_dataset = DirectoryDataset(config.data_dir, repeats=1, target_size=512)
    images, prompts = raw_dataset.get_images_and_prompts()
    n_images = len(images)

    print(f" Found {n_images} images")

    # ------------------------------------------------------------------
    # [2/6] Build (or reload) the precomputed encoding cache: VAE latents,
    # T5/CLIP text embeddings, and optional lune/sol expert features.
    # ------------------------------------------------------------------
    print("\n[2/6] Building cache...")

    from tinyflux.model.zoo import ModelZoo
    from tinyflux.trainer.cache_experts import DatasetCache

    zoo = ModelZoo(device=device, dtype=dtype)

    # A meta file marks an existing cache, so repeated runs skip encoding.
    cache_meta = os.path.join(cache_dir, "meta.pt")
    if os.path.exists(cache_meta):
        print(" Loading existing cache...")
        cache = DatasetCache.load(cache_dir)
    else:
        print(" Building new cache (this takes a few minutes)...")
        cache = DatasetCache.build(
            zoo,
            images,
            prompts,
            name="lora_dataset",
            build_lune=config.build_lune,
            build_sol=config.build_sol,
            batch_size=min(4, n_images),  # don't exceed the dataset size
            sol_batch_size=1,
            dtype=torch.float16,  # cache in fp16 to halve disk/RAM footprint
            compile_experts=False,
        )
        cache.save(cache_dir)

    print(f" Cache: {len(cache)} samples")

    # Encoders are no longer needed for training itself — free GPU memory
    # before loading the base model.
    del images, raw_dataset
    zoo.unload("vae")
    zoo.unload("t5")
    zoo.unload("clip")
    zoo.unload("lune")
    zoo.unload("sol")
    torch.cuda.empty_cache()

    # ------------------------------------------------------------------
    # [3/6] Load the base TinyFlux model to be adapted.
    # ------------------------------------------------------------------
    print("\n[3/6] Loading model...")

    from tinyflux.model.lora import TinyFluxLoRA
    from tinyflux.model.lora_config import TinyFluxLoRAConfig

    model = zoo.load_tinyflux(
        source=config.base_repo,
        ema_path=config.base_weights,
        train_mode=True,
    )

    # Enable fused SDPA kernels for speed/memory.
    torch.backends.cuda.enable_flash_sdp(True)
    torch.backends.cuda.enable_mem_efficient_sdp(True)
    print(" Memory-efficient attention enabled")

    # ------------------------------------------------------------------
    # [4/6] Inject the LoRA adapters into the base model.
    # ------------------------------------------------------------------
    print(f"\n[4/6] Injecting LoRA ({config.lora_config})...")

    lora_cfg = config.build_lora_config()

    lora = TinyFluxLoRA(model, config=lora_cfg)

    # More than one LR group means per-group learning rates are in play.
    has_lr_groups = len(lora_cfg.get_lr_groups(1.0)) > 1

    # ------------------------------------------------------------------
    # [5/6] Sampler for periodic preview images (lazily initialized).
    # ------------------------------------------------------------------
    print("\n[5/6] Setting up sampler...")

    from tinyflux.trainer.sampling import Sampler, save_samples

    sampler = None

    def do_sample(epoch_num: int) -> Optional[str]:
        """Generate and save samples, loading encoders as needed."""
        nonlocal sampler

        if not config.sample_prompts:
            return None

        # Re-acquire the encoders that were unloaded/offloaded for training.
        if zoo.vae is None:
            zoo.load_vae()
        else:
            zoo.onload("vae")

        if zoo.t5 is None:
            zoo.load_t5()
        else:
            zoo.onload("t5")

        if zoo.clip is None:
            zoo.load_clip()
        else:
            zoo.onload("clip")

        # First call builds the sampler; later calls reuse it.
        if sampler is None:
            print(" Initializing sampler...")
            sampler = Sampler(
                zoo=zoo,
                model=model,
                ema=None,  # sample the live (LoRA-adapted) weights directly
                num_steps=config.sample_steps,
                guidance_scale=config.sample_cfg,
                shift=3.0,  # matches the training timestep shift below
                device=device,
                dtype=dtype,
            )

        model.eval()
        with torch.no_grad():
            sample_images = sampler.generate(
                config.sample_prompts,
                seed=config.sample_seed,  # fixed seed -> comparable across epochs
            )
        sample_path = save_samples(
            sample_images,
            config.sample_prompts,
            epoch_num,
            samples_dir,
        )
        print(f" Saved: {sample_path}")

        if config.hf_repo:
            upload_to_hf(
                sample_path,
                config.hf_repo,
                f"{config.hf_subdir}/samples",
            )

        model.train()

        # On small GPUs (< ~20 GB) move the encoders off-device again so
        # the training step has headroom; big GPUs keep them resident.
        if torch.cuda.get_device_properties(0).total_memory < 20e9:
            zoo.offload("vae")
            zoo.offload("t5")
            zoo.offload("clip")
            torch.cuda.empty_cache()

        return sample_path

    # ------------------------------------------------------------------
    # [6/6] Training loop: flow-matching velocity prediction with MSE.
    # ------------------------------------------------------------------
    print("\n[6/6] Training...")

    from tinyflux.trainer.schedules import sample_timesteps
    from tinyflux.utils.predictions import flow_x_t, flow_velocity
    from tinyflux.model.model import TinyFluxDeep

    loader = create_dataloader(
        cache,
        repeats=config.repeats,
        batch_size=config.batch_size,
        shuffle=True,
        num_workers=8
    )

    steps_per_epoch = len(loader)
    total_steps = steps_per_epoch * config.epochs
    warmup_steps = int(config.warmup_epochs * steps_per_epoch)

    print(f" {n_images} images × {config.repeats} repeats = {steps_per_epoch} steps/epoch")
    print(f" {config.epochs} epochs = {total_steps} total steps")
    print(f" Warmup: {warmup_steps} steps ({config.warmup_epochs} epochs)")

    # Per-group LRs when the LoRA config defines them; a single LR otherwise.
    if has_lr_groups:
        param_groups = lora.get_param_groups(config.lr)
        optimizer = torch.optim.AdamW(param_groups, weight_decay=0.01)
        print(f" Using {len(param_groups)} LR groups")
    else:
        optimizer = torch.optim.AdamW(lora.parameters(), lr=config.lr, weight_decay=0.01)

    # Linear warmup to 1.0, then constant LR.
    # NOTE(review): divides by warmup_steps — warmup_epochs=0 would raise
    # ZeroDivisionError on the first step; confirm 0 is never used.
    def lr_lambda(step):
        if step < warmup_steps:
            return step / warmup_steps
        return 1.0
    scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)

    model.train()
    global_step = 0
    running_loss = 0.0
    log_every = max(1, steps_per_epoch // 10)  # ~10 progress updates per epoch

    for epoch in range(1, config.epochs + 1):
        epoch_loss = 0.0
        epoch_steps = 0

        pbar = tqdm(loader, desc=f"Epoch {epoch}/{config.epochs}")

        for batch in pbar:
            # The loader yields cache indices; encodings are fetched from
            # the cache rather than carried in the batch itself.
            indices = batch['index']
            B = len(indices)

            latents, t5_embed, clip_embed = cache.get_encodings_batch(indices)
            latents = latents.to(device, dtype=dtype)
            t5_embed = t5_embed.to(device, dtype=dtype)
            clip_embed = clip_embed.to(device, dtype=dtype)

            # Cached latents were encoded at 512px; rescale in latent space
            # if training at a different resolution.
            target_latent_size = config.train_resolution // 8
            if latents.shape[-1] != target_latent_size:
                latents = torch.nn.functional.interpolate(
                    latents,
                    size=(target_latent_size, target_latent_size),
                    mode='bilinear',
                    align_corners=False,
                )

            H = W = latents.shape[-1]

            # Per-sample flow-matching timesteps (shift=3.0 biases sampling).
            t = sample_timesteps(B, device=device, dtype=dtype, shift=3.0)

            # Optional expert conditioning features (may be absent if the
            # cache was built with build_lune/build_sol disabled).
            lune_features = cache.get_lune(indices, t)
            if lune_features is not None:
                lune_features = lune_features.to(device, dtype=dtype)

            sol_stats, sol_spatial = cache.get_sol(indices, t)
            if sol_stats is not None:
                sol_stats = sol_stats.to(device, dtype=dtype)
                sol_spatial = sol_spatial.to(device, dtype=dtype)

            # Flow matching: noisy interpolant x_t and target velocity.
            noise = torch.randn_like(latents)
            x_t = flow_x_t(latents, noise, t)
            v_target = flow_velocity(latents, noise)

            # (B, C, H, W) -> (B, H*W, C) token sequences for the transformer.
            x_t_seq = x_t.flatten(2).transpose(1, 2)
            v_target_seq = v_target.flatten(2).transpose(1, 2)

            # Positional ids for the flattened latent grid.
            img_ids = TinyFluxDeep.create_img_ids(B, H, W, device)

            optimizer.zero_grad()

            with torch.autocast(device, dtype=dtype):
                v_pred = model(
                    hidden_states=x_t_seq,
                    encoder_hidden_states=t5_embed,
                    pooled_projections=clip_embed,
                    timestep=t,
                    img_ids=img_ids,
                    lune_features=lune_features,
                    sol_stats=sol_stats,
                    sol_spatial=sol_spatial,
                )

            loss = F.mse_loss(v_pred, v_target_seq)

            loss.backward()
            # Clip only the LoRA parameters — the base model is frozen.
            torch.nn.utils.clip_grad_norm_(lora.parameters(), 1.0)
            optimizer.step()
            scheduler.step()

            loss_val = loss.item()
            running_loss += loss_val
            epoch_loss += loss_val
            global_step += 1
            epoch_steps += 1

            if global_step % log_every == 0:
                avg_loss = running_loss / log_every
                pbar.set_postfix(
                    loss=f"{avg_loss:.4f}",
                    lr=f"{scheduler.get_last_lr()[0]:.2e}",
                )
                running_loss = 0.0

        avg_epoch_loss = epoch_loss / epoch_steps
        print(f" Epoch {epoch} complete | Loss: {avg_epoch_loss:.4f}")

        # Local checkpoint every save_every_epoch epochs.
        if epoch % config.save_every_epoch == 0:
            ckpt_path = os.path.join(config.output_dir, f"lora_epoch_{epoch}.safetensors")
            lora.save(ckpt_path)
            print(f" Saved: {ckpt_path}")

        # Upload cadence may differ from save cadence, so save on demand.
        if config.hf_repo and epoch % config.upload_every_epoch == 0:
            ckpt_path = os.path.join(config.output_dir, f"lora_epoch_{epoch}.safetensors")
            if not os.path.exists(ckpt_path):
                lora.save(ckpt_path)
            upload_to_hf(ckpt_path, config.hf_repo, config.hf_subdir)

        # Preview images at the end of each epoch.
        if config.sample_every_epoch and config.sample_prompts:
            print(f" Generating samples...")
            do_sample(epoch)

    # Always write a final checkpoint regardless of epoch cadence.
    final_path = os.path.join(config.output_dir, "lora_final.safetensors")
    lora.save(final_path)

    if config.hf_repo:
        upload_to_hf(final_path, config.hf_repo, config.hf_subdir, "lora_final.safetensors")

    if config.sample_prompts:
        print("\nGenerating final samples...")
        do_sample(config.epochs)

    print("\n" + "=" * 60)
    print("Training complete!")
    print(f" Epochs: {config.epochs}")
    print(f" Total steps: {total_steps}")
    print(f" Final LoRA: {final_path}")
    if config.hf_repo:
        print(f" HF Repo: https://huggingface.co/{config.hf_repo}/tree/main/{config.hf_subdir}")
    print("=" * 60)

    return model, lora
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Copy/paste reference for a Colab notebook (three cells); this string is
# documentation only and is never executed by this module.
COLAB_SETUP = """
# Cell 1: Mount Drive and install
from google.colab import drive
drive.mount('/content/drive')

!pip install -q safetensors accelerate huggingface_hub
!pip install -q git+https://github.com/AbstractPhil/tinyflux.git

# Cell 2: Login to HuggingFace (for uploads)
from huggingface_hub import login
from google.colab import userdata
login(userdata.get("HF_TOKEN"))

# Cell 3: Train!
from tinyflux.examples.train_lora_colab import train_lora, LoRAConfig

config = LoRAConfig(
    # Data
    data_dir="/content/drive/MyDrive/test_1024",
    output_dir="/content/lora_output",
    repeats=100,  # 10 images × 100 repeats = 1000 steps/epoch

    # LoRA config: preset name or path to JSON file
    # Presets: "minimal", "standard", "character", "concept", "full", "progressive"
    lora_config="character",

    # Optional: override rank from preset
    lora_rank=None,  # Set to override default

    # Training
    epochs=10,
    batch_size=1,
    lr=1e-4,
    train_resolution=512,  # 512 for A100, 256 for T4

    # HuggingFace
    hf_repo="AbstractPhil/tinyflux-lailah-loras",
    hf_subdir="my_character_v1",
    upload_every_epoch=2,

    # Sampling
    sample_prompts=[
        "a red cube on a blue sphere",
        "A man wearing a brown cap sitting at his computer with a black and brown dog resting next to him on the couch.",
    ],
    sample_every_epoch=True,
)

model, lora = train_lora(config)
"""
|
|
|
|
|
if __name__ == "__main__":
    # Script entry point (Colab only: google.colab is imported here).
    # Authenticate with the HF token stored in Colab's userdata secrets so
    # checkpoint uploads work, then launch a default run.
    from huggingface_hub import login
    from google.colab import userdata
    login(userdata.get("HF_TOKEN"))

    # Full-preset run with the expert feature caches disabled.
    config = LoRAConfig(
        data_dir="/content/drive/MyDrive/test_1024",
        output_dir="/content/lora_output3_no_experts_full",
        repeats=100,
        epochs=10,
        lora_config="full",
        build_sol=False,
        build_lune=False,
        train_resolution=512,
    )

    model, lora = train_lora(config)