| | """ |
| | SD15 Flow-Matching Trainer - ControlNet Pose Edition |
| | Author: AbstractPhil |
| | |
| | Trains Lune on controlnet pose dataset with transparent backgrounds. |
| | |
| | License: MIT |
| | """ |
| |
|
| | import os |
| | import json |
| | import datetime |
| | import random |
| | from dataclasses import dataclass, asdict, field |
| | from tqdm.auto import tqdm |
| | import matplotlib.pyplot as plt |
| |
|
| | import torch |
| | import torch.nn.functional as F |
| | from torch.utils.tensorboard import SummaryWriter |
| | from torch.utils.data import DataLoader |
| |
|
| | import datasets |
| | from diffusers import UNet2DConditionModel, AutoencoderKL |
| | from transformers import CLIPTextModel, CLIPTokenizer |
| | from huggingface_hub import HfApi, create_repo, hf_hub_download |
| |
|
| |
|
@dataclass
class TrainConfig:
    """All hyperparameters and paths for one training run."""

    # --- model / data sources ---
    output_dir: str = "./outputs"       # local root; each run gets a timestamped subdir
    model_repo: str = "AbstractPhil/sd15-flow-lune"  # HF repo holding the student checkpoint
    checkpoint_filename: str = "sd15_flow_pretrain_pose_controlnet_t500_700_s8312.pt"
    dataset_name: str = "AbstractPhil/CN_pose3D_V7_512"
    use_masks: bool = True              # weight the loss by a per-pixel mask when available
    mask_column: str = "mask"           # dataset column holding the mask image

    # --- HuggingFace Hub upload ---
    hf_repo_id: str = "AbstractPhil/sd15-flow-lune"
    upload_to_hub: bool = True

    run_name: str = "pretrain_pose_controlnet_v7_v10_t400_600"

    # When True, restore optimizer/scheduler state from the checkpoint dict.
    continue_from_checkpoint: bool = False

    seed: int = 42
    batch_size: int = 64

    # --- optimization ---
    base_lr: float = 2e-6
    shift: float = 2.5                  # sigma-shift applied to sampled timesteps
    dropout: float = 0.1                # prob. of zeroing text conditioning per sample (CFG-style dropout)
    min_snr_gamma: float = 5.0          # clamp for the Min-SNR loss weighting

    # Restrict training to this timestep sub-range (on a 0..1000 scale),
    # sampled BEFORE the shift is applied.
    min_timestep: float = 400.0
    max_timestep: float = 600.0

    # --- schedule ---
    num_train_epochs: int = 1
    warmup_epochs: int = 1              # linear LR warmup length (fresh runs only)
    checkpointing_steps: int = 2500
    num_workers: int = 0                # DataLoader workers; 0 = encode in main process

    # SD1.5 VAE latent scaling factor.
    vae_scale: float = 0.18215

    # --- caption preprocessing (see preprocess_caption) ---
    delimiter: str = ","
    preserved_count: int = 2            # leading tokens kept in place when shuffling
    remove_these: list = field(default_factory=lambda: [
        "simple background",
        "white background"])
    prepend_prompt: str = "doll"
    append_prompt: str = "transparent background"
    shuffle_prompt: bool = True
| |
|
| |
|
def preprocess_caption(text: str, config: "TrainConfig") -> str:
    """
    Preprocess controlnet pose captions with config-based shuffling:
    - Lowercase and clean punctuation
    - Remove unwanted tokens from config.remove_these
    - Prepend config.prepend_prompt
    - Shuffle tokens (preserving first config.preserved_count)
    - Append config.append_prompt

    Returns the processed caption; for empty/None input, returns just
    config.append_prompt (or "" if none is configured).
    """
    if text is None or text == "":
        return config.append_prompt if config.append_prompt else ""

    text = text.lower()
    # Periods become delimiters so sentence-style captions tokenize too.
    text = text.replace(".", config.delimiter)
    text = text.strip()

    # Collapse runs of duplicate delimiters and double spaces.
    # BUGFIX: the original replaced a single space with a single space,
    # which never terminates; the intent was to collapse "  " -> " ".
    doubled_delim = config.delimiter * 2
    while doubled_delim in text:
        text = text.replace(doubled_delim, config.delimiter)
    while "  " in text:
        text = text.replace("  ", " ")
    text = text.strip()

    # Strip stray leading/trailing delimiters (handles repeats and
    # multi-character delimiters, unlike the original single-char slice).
    while text.startswith(config.delimiter):
        text = text[len(config.delimiter):].strip()
    while text.endswith(config.delimiter):
        text = text[:-len(config.delimiter)].strip()

    if config.prepend_prompt:
        text = f"{config.prepend_prompt}{config.delimiter} {text}" if text else config.prepend_prompt

    if config.shuffle_prompt and text:
        tokens = [t.strip() for t in text.split(config.delimiter) if t.strip()]

        if config.remove_these:
            tokens = [t for t in tokens if t not in config.remove_these]

        # Keep the first N tokens in place (e.g. the prepend token + subject),
        # shuffle the rest for augmentation.
        preserved = tokens[:config.preserved_count]
        shuffleable = tokens[config.preserved_count:]
        random.shuffle(shuffleable)

        text = f"{config.delimiter} ".join(preserved + shuffleable)
    else:
        # No shuffling: still honor the removal list.
        if config.remove_these and text:
            tokens = [t.strip() for t in text.split(config.delimiter) if t.strip()]
            tokens = [t for t in tokens if t not in config.remove_these]
            text = f"{config.delimiter} ".join(tokens)

    if config.append_prompt:
        text = f"{text}{config.delimiter} {config.append_prompt}" if text else config.append_prompt

    return text
| |
|
| |
|
def load_student_unet(repo_id: str, filename: str, device="cuda"):
    """Load the student UNet from a HF Hub checkpoint.

    Downloads `filename` from `repo_id`, instantiates a stock SD1.5 UNet,
    and loads the checkpoint's "student" state dict into it (stripping any
    "unet." key prefix). Returns (unet_on_device, checkpoint_dict) so the
    caller can optionally restore optimizer/scheduler state as well.
    """
    # BUGFIX: the original f-string printed the literal "(unknown)" instead
    # of the actual filename being downloaded.
    print(f"Downloading checkpoint from {repo_id}/{filename}...")
    checkpoint_path = hf_hub_download(
        repo_id=repo_id,
        filename=filename,
        repo_type="model"
    )

    # NOTE(review): torch.load unpickles arbitrary objects; only load
    # checkpoints from a trusted repo (consider weights_only=True on
    # torch >= 1.13 if the checkpoint contains only tensors/primitives).
    checkpoint = torch.load(checkpoint_path, map_location="cpu")

    print("Loading SD1.5 UNet architecture...")
    unet = UNet2DConditionModel.from_pretrained(
        "runwayml/stable-diffusion-v1-5",
        subfolder="unet",
        torch_dtype=torch.float32
    )

    student_state_dict = checkpoint["student"]

    # Checkpoint keys may be prefixed with "unet."; strip so they match
    # the bare UNet2DConditionModel state dict.
    cleaned_dict = {
        (key[5:] if key.startswith("unet.") else key): value
        for key, value in student_state_dict.items()
    }

    # strict=False tolerates architecture drift, but report what was
    # skipped instead of failing silently.
    missing, unexpected = unet.load_state_dict(cleaned_dict, strict=False)
    if missing or unexpected:
        print(f"⚠ load_state_dict: {len(missing)} missing, {len(unexpected)} unexpected keys")

    print(f"✓ Loaded UNet from step {checkpoint.get('gstep', 'unknown')}")

    return unet.to(device), checkpoint
| |
|
| |
|
def train(config: TrainConfig):
    """Run one flow-matching fine-tune described by `config`.

    Pipeline: load frozen SD1.5 VAE + CLIP -> stream the dataset, encoding
    latents/masks/text embeddings inside the collate_fn -> load the student
    UNet from a Hub checkpoint -> train with a flow-matching objective
    restricted to [min_timestep, max_timestep], Min-SNR weighted and
    optionally mask-weighted -> periodically checkpoint locally and
    (optionally) upload to the Hub.
    """
    device = "cuda"  # single-GPU trainer; no CPU fallback path
    torch.backends.cuda.matmul.allow_tf32 = True

    # Seed CPU and CUDA RNGs for reproducible noise/timestep draws.
    # NOTE(review): python's `random` (used by caption shuffling) is not
    # seeded here — caption order is not reproducible across runs.
    torch.manual_seed(config.seed)
    torch.cuda.manual_seed(config.seed)

    # Each run writes into a fresh timestamped directory under output_dir.
    date_time = datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
    real_output_dir = os.path.join(config.output_dir, date_time)
    os.makedirs(real_output_dir, exist_ok=True)
    t_writer = SummaryWriter(log_dir=real_output_dir, flush_secs=60)

    # Hub setup is best-effort: any failure disables upload for the run
    # rather than aborting training.
    hf_api = None
    if config.upload_to_hub:
        try:
            hf_api = HfApi()
            create_repo(
                repo_id=config.hf_repo_id,
                repo_type="model",
                exist_ok=True,
                private=False
            )
            print(f"✓ HuggingFace repo ready: {config.hf_repo_id}")
        except Exception as e:
            print(f"⚠ Hub upload disabled: {e}")
            config.upload_to_hub = False

    # Persist the exact config alongside the run (and on the Hub).
    config_path = os.path.join(real_output_dir, "config.json")
    with open(config_path, "w") as f:
        json.dump(asdict(config), f, indent=2)

    if config.upload_to_hub:
        hf_api.upload_file(
            path_or_fileobj=config_path,
            path_in_repo="config.json",
            repo_id=config.hf_repo_id,
            repo_type="model"
        )

    # --- frozen encoders: VAE for latents, CLIP for text conditioning ---
    print("\nLoading SD1.5 VAE and CLIP...")
    vae = AutoencoderKL.from_pretrained(
        "runwayml/stable-diffusion-v1-5",
        subfolder="vae",
        torch_dtype=torch.float32
    ).to(device)
    vae.requires_grad_(False)
    vae.eval()

    tokenizer = CLIPTokenizer.from_pretrained(
        "runwayml/stable-diffusion-v1-5",
        subfolder="tokenizer"
    )
    text_encoder = CLIPTextModel.from_pretrained(
        "runwayml/stable-diffusion-v1-5",
        subfolder="text_encoder",
        torch_dtype=torch.float32
    ).to(device)
    text_encoder.requires_grad_(False)
    text_encoder.eval()

    print("✓ VAE and CLIP loaded")

    print(f"\nLoading dataset: {config.dataset_name}")
    train_dataset = datasets.load_dataset(
        config.dataset_name,
        split="train"
    )

    print(f"✓ Loaded {len(train_dataset):,} images")
    print(f" Columns: {train_dataset.column_names}")

    # Derived schedule lengths (drop_last-style floor division).
    steps_per_epoch = len(train_dataset) // config.batch_size
    total_steps = steps_per_epoch * config.num_train_epochs
    warmup_steps = steps_per_epoch * config.warmup_epochs

    print(f"\nTraining schedule:")
    print(f" Total images: {len(train_dataset):,}")
    print(f" Batch size: {config.batch_size}")
    print(f" Steps per epoch: {steps_per_epoch:,}")
    print(f" Total epochs: {config.num_train_epochs}")
    print(f" Total steps: {total_steps:,}")
    print(f" Warmup steps: {warmup_steps:,}")
    print(f"\nTimestep range:")
    print(f" Min timestep: {config.min_timestep}")
    print(f" Max timestep: {config.max_timestep}")
    print(f" Training on: {config.max_timestep - config.min_timestep} timestep range")
    print(f"\nPrompt preprocessing:")
    print(f" Shuffle: {config.shuffle_prompt}")
    print(f" Preserved tokens: {config.preserved_count}")
    print(f" Prepend: '{config.prepend_prompt}'")
    print(f" Append: '{config.append_prompt}'")
    print(f" Remove: {config.remove_these}")

    @torch.no_grad()
    def collate_fn(examples):
        """Encode images, masks (optional), and prompts at runtime"""
        # NOTE(review): this collate runs VAE/CLIP on `device`; with
        # num_workers > 0 it would execute CUDA work in worker processes —
        # keep num_workers=0 (the config default) unless that is verified.
        import numpy as np

        images = []
        masks = []
        prompts = []
        image_ids = []

        for idx, ex in enumerate(examples):
            # Image -> float CHW tensor in [-1, 1] (VAE input range).
            img = ex['image'].convert('RGB')
            img = torch.tensor(np.array(img)).permute(2, 0, 1).float() / 255.0
            img = img * 2.0 - 1.0
            images.append(img)

            # Optional per-pixel mask in [0, 1] (grayscale).
            if config.use_masks and config.mask_column in ex:
                mask = ex[config.mask_column].convert('L')
                mask = torch.tensor(np.array(mask)).float() / 255.0
                masks.append(mask)

            raw_text = ex['text']
            processed_prompt = preprocess_caption(raw_text, config)
            prompts.append(processed_prompt)
            image_ids.append(idx)

        images = torch.stack(images).to(device)

        # Sample latents and apply the SD scaling factor.
        latents = vae.encode(images).latent_dist.sample()
        latents = latents * config.vae_scale

        if config.use_masks and masks:
            masks = torch.stack(masks).to(device)
            # Nearest-neighbor downsample keeps the mask binary at
            # latent resolution.
            masks_downsampled = F.interpolate(
                masks.unsqueeze(1),
                size=latents.shape[-2:],
                mode='nearest'
            ).squeeze(1)
        else:
            # No masks in this batch: uniform weighting.
            masks_downsampled = torch.ones(
                (latents.shape[0], latents.shape[2], latents.shape[3]),
                dtype=torch.float32
            )

        # CLIP text conditioning (fixed max_length padding).
        text_inputs = tokenizer(
            prompts,
            padding="max_length",
            max_length=tokenizer.model_max_length,
            truncation=True,
            return_tensors="pt"
        ).to(device)

        encoder_hidden_states = text_encoder(text_inputs.input_ids)[0]

        # Returned on CPU so pin_memory / the DataLoader can handle them;
        # get_prediction moves them back to the GPU.
        return (
            latents.cpu(),
            masks_downsampled.cpu(),
            encoder_hidden_states.cpu(),
            image_ids,
            prompts
        )

    train_dataloader = DataLoader(
        dataset=train_dataset,
        batch_size=config.batch_size,
        shuffle=True,
        collate_fn=collate_fn,
        num_workers=config.num_workers,
        pin_memory=True
    )

    # --- trainable student UNet ---
    print(f"\nLoading model from HuggingFace...")
    unet, checkpoint = load_student_unet(config.model_repo, config.checkpoint_filename, device=device)
    unet.requires_grad_(True)
    unet.train()

    optimizer = torch.optim.AdamW(
        unet.parameters(),
        lr=config.base_lr,
        betas=(0.9, 0.999),
        weight_decay=0.01,
        eps=1e-8
    )

    # Resumed runs skip warmup (constant LR); fresh runs warm up linearly.
    if config.continue_from_checkpoint:
        scheduler = torch.optim.lr_scheduler.LambdaLR(
            optimizer,
            lr_lambda=lambda step: 1.0
        )
    else:
        def get_lr_scale(step):
            # Linear warmup to base_lr over warmup_steps, then constant.
            if step < warmup_steps:
                return step / warmup_steps
            return 1.0

        scheduler = torch.optim.lr_scheduler.LambdaLR(
            optimizer,
            lr_lambda=get_lr_scale
        )

    start_step = 0

    # Optionally restore optimizer/scheduler state from the checkpoint.
    if config.continue_from_checkpoint:
        if "opt" in checkpoint and "scheduler" in checkpoint:
            optimizer.load_state_dict(checkpoint["opt"])
            scheduler.load_state_dict(checkpoint["scheduler"])
            start_step = checkpoint.get("gstep", 0)
            print(f"✓ Resumed optimizer and scheduler from step {start_step}")
            print(f" Will train for {config.num_train_epochs} more epoch(s) = {total_steps:,} additional steps")
        else:
            print("⚠ No optimizer/scheduler state in checkpoint, starting fresh")
    else:
        print("✓ Starting with fresh optimizer (no state loaded)")

    global_step = start_step
    end_step = start_step + total_steps  # total_steps are ADDITIONAL steps when resuming
    # Per-sample training telemetry, also persisted next to checkpoints.
    train_logs = {
        "train_step": [],
        "train_loss": [],
        "train_timestep": [],
        "trained_images": []
    }

    def get_prediction(batch, log_to=None):
        """Forward one batch through the UNet and return the scalar loss.

        Samples sigmas uniformly in the configured timestep window, applies
        the shift, builds flow-matching noisy latents and target, then
        computes a Min-SNR-weighted (and optionally mask-weighted) MSE.
        If `log_to` is given, appends per-sample loss/timestep records.
        """
        latents, masks, encoder_hidden_states, ids, prompts = batch

        latents = latents.to(dtype=torch.float32, device=device)
        if config.use_masks:
            masks = masks.to(dtype=torch.float32, device=device)
        encoder_hidden_states = encoder_hidden_states.to(dtype=torch.float32, device=device)

        batch_size = latents.shape[0]

        # CFG-style conditioning dropout: zero out text embeddings for a
        # random config.dropout fraction of the batch.
        dropout_mask = torch.rand(batch_size, device=device) < config.dropout
        encoder_hidden_states = encoder_hidden_states.clone()
        encoder_hidden_states[dropout_mask] = 0

        # Sample sigma uniformly in [min_timestep, max_timestep] / 1000.
        min_sigma = config.min_timestep / 1000.0
        max_sigma = config.max_timestep / 1000.0

        sigmas = torch.rand(batch_size, device=device)
        sigmas = min_sigma + sigmas * (max_sigma - min_sigma)

        # Timestep shift: sigma' = s*sigma / (1 + (s-1)*sigma); the UNet is
        # conditioned on the SHIFTED timestep, so the effective trained
        # range is wider than [min_timestep, max_timestep].
        sigmas = (config.shift * sigmas) / (1 + (config.shift - 1) * sigmas)
        timesteps = sigmas * 1000
        sigmas = sigmas[:, None, None, None]  # broadcast over C,H,W

        # Flow-matching interpolation: x_t = sigma*noise + (1-sigma)*x0,
        # with velocity target (noise - x0).
        noise = torch.randn_like(latents)
        noisy_latents = noise * sigmas + latents * (1 - sigmas)
        target = noise - latents

        pred = unet(noisy_latents, timesteps, encoder_hidden_states, return_dict=False)[0]

        # Per-pixel MSE, averaged over channels -> shape (B, H, W).
        loss = F.mse_loss(pred, target, reduction="none")
        loss = loss.mean(dim=1)

        # Min-SNR weighting: SNR = ((1-sigma)/sigma)^2, clamped at gamma,
        # then divided by (SNR + 1) — v-prediction-style Min-SNR weight.
        # NOTE(review): .squeeze() drops ALL size-1 dims; with batch_size=1
        # snr becomes 0-d and snr_weight[:, None, None] would fail — confirm
        # batch_size > 1 is always true here.
        snr = ((1 - sigmas.squeeze()) ** 2) / (sigmas.squeeze() ** 2 + 1e-8)
        snr_weight = torch.minimum(snr, torch.ones_like(snr) * config.min_snr_gamma) / snr

        snr_weight = snr_weight / (snr + 1)
        snr_weight = snr_weight[:, None, None]

        loss = loss * snr_weight

        if config.use_masks:
            # Restrict the loss to masked (foreground) pixels; normalize by
            # the mask area so sparse masks don't shrink the loss.
            masked_loss = loss * masks
            loss_per_sample = masked_loss.sum(dim=[1, 2]) / (masks.sum(dim=[1, 2]) + 1e-8)
        else:
            loss_per_sample = loss.mean(dim=[1, 2])

        if log_to is not None:
            # Record per-sample telemetry (uses global_step from the
            # enclosing scope).
            for i in range(batch_size):
                log_to["train_step"].append(global_step)
                log_to["train_loss"].append(loss_per_sample[i].item())
                log_to["train_timestep"].append(timesteps[i].item())
                log_to["trained_images"].append({
                    "step": global_step,
                    "id": ids[i],
                    "prompt": prompts[i]
                })

        return loss_per_sample.mean()

    def plot_logs(log_dict):
        """Scatter per-sample loss vs. timestep, colored by step; the caller
        grabs the current figure via plt.gcf() and closes it."""
        plt.figure(figsize=(10, 6))
        plt.scatter(
            log_dict["train_timestep"],
            log_dict["train_loss"],
            s=3,
            c=log_dict["train_step"],
            marker=".",
            cmap='cool'
        )
        plt.xlabel("timestep")
        plt.ylabel("loss")
        plt.yscale("log")
        plt.colorbar(label="step")

    def save_checkpoint(step, relative_epoch):
        """Save UNet (diffusers format) + full .pt training state + trained
        image metadata; best-effort upload to the Hub when enabled."""
        checkpoint_path = os.path.join(real_output_dir, f"{config.run_name}_checkpoint-{step:08}")
        os.makedirs(checkpoint_path, exist_ok=True)

        # Diffusers-format weights (safetensors) for easy inference loading.
        unet.save_pretrained(
            os.path.join(checkpoint_path, "unet"),
            safe_serialization=True
        )

        # Full .pt checkpoint: weights + optimizer + scheduler + step, in
        # the same layout load_student_unet / the resume path expect.
        pt_filename = f"sd15_flow_{config.run_name}_s{step}.pt"
        pt_path = os.path.join(checkpoint_path, pt_filename)

        torch.save({
            "cfg": asdict(config),
            "student": unet.state_dict(),
            "opt": optimizer.state_dict(),
            "scheduler": scheduler.state_dict(),
            "gstep": step,
            "relative_epoch": relative_epoch
        }, pt_path)

        # Record which images/prompts were trained on up to this step.
        metadata = {
            "step": step,
            "relative_epoch": relative_epoch,
            "trained_images": train_logs["trained_images"]
        }
        metadata_path = os.path.join(checkpoint_path, "trained_images.json")
        with open(metadata_path, "w") as f:
            json.dump(metadata, f, indent=2)

        print(f"✓ Checkpoint saved at step {step} (relative epoch {relative_epoch})")

        # Upload failures are logged, never fatal.
        if config.upload_to_hub and hf_api is not None:
            try:
                hf_api.upload_file(
                    path_or_fileobj=pt_path,
                    path_in_repo=pt_filename,
                    repo_id=config.hf_repo_id,
                    repo_type="model"
                )
                hf_api.upload_folder(
                    folder_path=os.path.join(checkpoint_path, "unet"),
                    path_in_repo=f"{config.run_name}/checkpoint-{step:08}/unet",
                    repo_id=config.hf_repo_id,
                    repo_type="model"
                )
                hf_api.upload_file(
                    path_or_fileobj=metadata_path,
                    path_in_repo=f"{config.run_name}/checkpoint-{step:08}/trained_images.json",
                    repo_id=config.hf_repo_id,
                    repo_type="model"
                )
                print(f"✓ Uploaded to hub: {config.hf_repo_id}")
            except Exception as e:
                print(f"⚠ Upload failed: {e}")

    # --- main training loop ---
    print("\nStarting training...")
    progress_bar = tqdm(total=total_steps, initial=0)  # counts ADDITIONAL steps when resuming

    epoch = 0
    while global_step < end_step:
        epoch += 1
        for batch in train_dataloader:
            if global_step >= end_step:
                break

            loss = get_prediction(batch, log_to=train_logs)
            t_writer.add_scalar("train/loss", loss.item(), global_step)
            t_writer.add_scalar("train/lr", scheduler.get_last_lr()[0], global_step)

            # Log the (shifted) timestep distribution of the last batch.
            if len(train_logs["train_timestep"]) > 0:
                recent_timesteps = train_logs["train_timestep"][-config.batch_size:]
                t_writer.add_scalar("train/mean_timestep", sum(recent_timesteps) / len(recent_timesteps), global_step)
                t_writer.add_scalar("train/min_timestep", min(recent_timesteps), global_step)
                t_writer.add_scalar("train/max_timestep", max(recent_timesteps), global_step)

            loss.backward()

            # Clip to unit grad norm and log the pre-clip norm.
            grad_norm = torch.nn.utils.clip_grad_norm_(unet.parameters(), 1.0)
            t_writer.add_scalar("train/grad_norm", grad_norm.item(), global_step)

            optimizer.step()
            scheduler.step()
            optimizer.zero_grad()

            progress_bar.update(1)
            progress_bar.set_postfix({
                "epoch": epoch,
                "loss": f"{loss.item():.4f}",
                "lr": f"{scheduler.get_last_lr()[0]:.2e}",
                "gstep": global_step
            })
            global_step += 1

            # Periodic loss-vs-timestep scatter into TensorBoard.
            if global_step % 100 == 0:
                plot_logs(train_logs)
                t_writer.add_figure("train_loss", plt.gcf(), global_step)
                plt.close()

            if global_step % config.checkpointing_steps == 0:
                save_checkpoint(global_step, epoch)

    # Final checkpoint after the loop exits.
    save_checkpoint(global_step, epoch)

    print("\n✅ Training complete!")
| |
|
| |
|
if __name__ == "__main__":
    # Entry point: run training with the default configuration.
    train(TrainConfig())