| """ |
| SD15 Flow-Matching trainer |
| Author: AbstractPhil |
| |
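| Loads a student UNet from a .pt checkpoint in the current format (a dict whose "student" entry is the UNet state_dict) and runs several checks to confirm the weights are loaded correctly before training starts. |
| |
| Then fine-tunes SD1.5 with a flow-matching objective. |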
| |
| License: MIT |
| If you use my work, a citation wouldn't hurt. |
| |
| """ |
|
|
| import os |
| import json |
| import datetime |
| from dataclasses import dataclass, asdict |
| from tqdm.auto import tqdm |
| import matplotlib.pyplot as plt |
|
|
| import torch |
| import torch.nn.functional as F |
| from torch.utils.tensorboard import SummaryWriter |
| from torch.utils.data import DataLoader |
|
|
| import datasets |
| from diffusers import UNet2DConditionModel |
| from huggingface_hub import HfApi, create_repo, hf_hub_download |
|
|
|
|
| @dataclass |
| class TrainConfig: |
| output_dir: str = "./outputs" |
| model_repo: str = "AbstractPhil/sd15-flow-matching-try2" |
| checkpoint_filename: str = "sd15_flowmatch_david_weighted_2_e34.pt" |
| dataset_name: str = "AbstractPhil/sd15-latent-distillation-500k" |
| |
| |
| hf_repo_id: str = "AbstractPhil/sd15-flow-lune" |
| upload_to_hub: bool = True |
| |
| seed: int = 42 |
| batch_size: int = 16 |
| base_lr: float = 2e-6 |
| shift: float = 2.0 |
| dropout: float = 0.1 |
| |
| max_train_steps: int = 50_000 |
| checkpointing_steps: int = 1000 |
| num_workers: int = 0 |
| |
| |
| vae_scale: float = 0.18215 |
|
|
|
|
def _strip_key_prefix(state_dict, prefix="unet."):
    """Return a new dict with ``prefix`` removed from every key that starts with it.

    Keys without the prefix are copied unchanged. If stripping makes two keys
    collide, the one that appears later in ``state_dict`` wins.
    """
    return {
        (key[len(prefix):] if key.startswith(prefix) else key): value
        for key, value in state_dict.items()
    }


def _tensor_diff_stats(reference, candidate, keys, atol=1e-6):
    """Compare ``candidate`` against ``reference`` tensor by tensor over ``keys``.

    Both sides are compared in float32. A tensor counts as "different" when any
    element differs by more than ``atol``; its whole ``numel()`` is then added to
    ``different`` (the granularity is per tensor, not per element).

    Returns:
        dict with ``total`` (elements compared), ``different`` (elements in
        differing tensors), ``abs_diff_sum`` (sum of |diff| over differing
        tensors), ``max_diff``, ``differing_keys`` and ``shape_mismatches``
        (keys skipped because their shapes differ).
    """
    stats = {
        "total": 0,
        "different": 0,
        "abs_diff_sum": 0.0,
        "max_diff": 0.0,
        "differing_keys": [],
        "shape_mismatches": [],
    }
    for key in sorted(keys):
        ref = reference[key].float()
        cand = candidate[key].float()
        if ref.shape != cand.shape:
            stats["shape_mismatches"].append(key)
            continue
        stats["total"] += ref.numel()
        if ref.numel() == 0:
            continue  # .max() is undefined on an empty tensor
        diff = (ref - cand).abs()
        peak = diff.max().item()
        if peak > atol:
            stats["different"] += ref.numel()
            stats["abs_diff_sum"] += diff.sum().item()
            stats["max_diff"] = max(stats["max_diff"], peak)
            stats["differing_keys"].append(key)
    return stats


def load_student_unet(repo_id: str, filename: str, device="cuda") -> UNet2DConditionModel:
    """Load the student UNet from a ``.pt`` checkpoint on the Hugging Face Hub.

    The checkpoint is expected to be a dict whose ``"student"`` entry is a UNet
    state_dict (keys optionally prefixed with ``"unet."``). ``"gstep"`` is read
    for reporting if present. The weights are loaded into the SD1.5 UNet
    architecture (``runwayml/stable-diffusion-v1-5``, float32) and verified by:

    1. how much the student differs from the base weights (is it trained at all?),
    2. how much the model changed from base after loading,
    3. whether every matching tensor in the model now equals the checkpoint value
       (did the load actually apply?).

    Results are printed; problems produce warnings rather than exceptions,
    except for the cases listed under Raises.

    Args:
        repo_id: Hub model repo that holds the checkpoint.
        filename: checkpoint file inside ``repo_id``.
        device: device the returned UNet is moved to.

    Returns:
        The UNet holding the student weights, on ``device``.

    Raises:
        KeyError: if the checkpoint has no ``"student"`` entry.
        RuntimeError: from ``load_state_dict`` if a matching key has a different
            shape (``strict=False`` does not suppress size mismatches).
    """
    print(f"Downloading checkpoint from {repo_id}/{filename}...")
    checkpoint_path = hf_hub_download(
        repo_id=repo_id,
        filename=filename,
        repo_type="model",
    )
    print(f"✓ Downloaded to: {checkpoint_path}")

    # NOTE(security): torch.load unpickles the file, which can execute arbitrary
    # code when weights_only is False (the default on older torch versions).
    # Only load checkpoints from repos you trust.
    checkpoint = torch.load(checkpoint_path, map_location="cpu")

    print("Loading SD1.5 UNet architecture...")
    unet = UNet2DConditionModel.from_pretrained(
        "runwayml/stable-diffusion-v1-5",
        subfolder="unet",
        torch_dtype=torch.float32,
    )

    # Snapshot the base weights; load_state_dict below overwrites them in place.
    original_state_dict = {k: v.clone() for k, v in unet.state_dict().items()}
    cleaned_student_dict = _strip_key_prefix(checkpoint["student"], "unet.")

    print(f"\n{'='*70}")
    print("WEIGHT VERIFICATION")
    print(f"{'='*70}")

    original_keys = set(original_state_dict)
    student_keys = set(cleaned_student_dict)
    matching_keys = original_keys & student_keys

    print(f"Original UNet keys: {len(original_keys)}")
    print(f"Student checkpoint keys: {len(student_keys)}")
    print(f"Matching keys: {len(matching_keys)}")

    # 1) Student vs base, before loading.
    before = _tensor_diff_stats(original_state_dict, cleaned_student_dict, matching_keys)
    for key in before["shape_mismatches"]:
        print(f"⚠ Shape mismatch for {key}: {original_state_dict[key].shape} vs {cleaned_student_dict[key].shape}")

    total_params = before["total"]
    different_params = before["different"]
    pct_different = (different_params / total_params * 100) if total_params > 0 else 0
    avg_diff = before["abs_diff_sum"] / different_params if different_params > 0 else 0

    print("\nStudent vs Original (BEFORE loading):")
    print(f"  Total parameters: {total_params:,}")
    print(f"  Parameters different: {different_params:,} ({pct_different:.1f}%)")
    print(f"  Average difference: {avg_diff:.6f}")
    print(f"  Max difference: {before['max_diff']:.6f}")

    # strict=False tolerates missing/unexpected keys, but PyTorch still raises
    # RuntimeError for any key whose shape differs (reported above).
    load_result = unet.load_state_dict(cleaned_student_dict, strict=False)

    if load_result.missing_keys:
        print(f"\n⚠ Missing keys during load: {len(load_result.missing_keys)}")
        for key in load_result.missing_keys[:3]:
            print(f"  - {key}")

    if load_result.unexpected_keys:
        print(f"⚠ Unexpected keys during load: {len(load_result.unexpected_keys)}")
        for key in load_result.unexpected_keys[:3]:
            print(f"  - {key}")

    # 2) Model vs base, after loading (informational).
    loaded_state_dict = unet.state_dict()
    after = _tensor_diff_stats(original_state_dict, loaded_state_dict, matching_keys)
    changed_params = after["different"]
    total_params_after = after["total"]
    pct_changed = (changed_params / total_params_after * 100) if total_params_after > 0 else 0
    avg_diff_after = after["abs_diff_sum"] / changed_params if changed_params > 0 else 0

    print("\nOriginal vs Loaded (AFTER loading):")
    print(f"  Parameters changed: {changed_params:,} ({pct_changed:.1f}%)")
    print(f"  Average difference: {avg_diff_after:.6f}")
    print(f"  Max difference: {after['max_diff']:.6f}")

    # 3) Model vs checkpoint, after loading: the actual load check. Comparing
    # with the base (step 2) cannot tell a failed load apart from a checkpoint
    # that legitimately leaves some tensors at their base values.
    not_applied = _tensor_diff_stats(cleaned_student_dict, loaded_state_dict, matching_keys)["differing_keys"]
    print(f"  Tensors not matching checkpoint: {len(not_applied)}")

    print(f"\n{'='*70}")

    if not matching_keys:
        print("⚠️ WARNING: No checkpoint keys match the UNet; nothing was loaded.")
    elif pct_different < 50:
        print(f"⚠️ WARNING: Student weights only {pct_different:.1f}% different from base!")
        print("   This checkpoint may not be trained.")
    elif not_applied:
        print(f"⚠️ WARNING: {len(not_applied)} tensors do not match the checkpoint after loading!")
        for key in not_applied[:3]:
            print(f"  - {key}")
        print("   The load may have failed.")
    else:
        print("✅ Weights loaded successfully!")
        print(f"   Checkpoint step: {checkpoint.get('gstep', 'unknown')}")
        print(f"   {pct_different:.1f}% of weights differ from base SD1.5")

    print(f"{'='*70}\n")

    return unet.to(device)
|
|
|
|
def train(config: TrainConfig):
    """Fine-tune the student UNet with a flow-matching (velocity) objective.

    Steps:

    1. Seed the RNGs.
    2. Create a timestamped run directory under ``config.output_dir`` for
       TensorBoard logs, ``config.json`` and checkpoints.
    3. Optionally prepare the HF Hub repo.
    4. Stream ``config.dataset_name`` and load the student UNet.
    5. Run Adam until ``config.max_train_steps`` optimizer steps or until the
       streamed dataset is exhausted, whichever comes first.

    The last step reached is always checkpointed, exactly once, even when it
    coincides with a periodic checkpoint.

    Args:
        config: run settings. ``config.upload_to_hub`` is set to False in place
            if the Hub repo cannot be created.
    """
    device = "cuda"
    torch.backends.cuda.matmul.allow_tf32 = True

    torch.manual_seed(config.seed)
    torch.cuda.manual_seed(config.seed)

    # Every run writes into its own timestamped directory.
    date_time = datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
    real_output_dir = os.path.join(config.output_dir, date_time)
    os.makedirs(real_output_dir, exist_ok=True)

    # Hub uploads are best-effort: if the repo cannot be prepared, disable them.
    hf_api = None
    if config.upload_to_hub:
        try:
            hf_api = HfApi()
            create_repo(
                repo_id=config.hf_repo_id,
                repo_type="model",
                exist_ok=True,
                private=False,
            )
            print(f"✓ HuggingFace repo ready: {config.hf_repo_id}")
        except Exception as e:
            print(f"⚠ Hub upload disabled: {e}")
            config.upload_to_hub = False

    config_path = os.path.join(real_output_dir, "config.json")
    with open(config_path, "w") as f:
        json.dump(asdict(config), f, indent=2)

    if config.upload_to_hub:
        hf_api.upload_file(
            path_or_fileobj=config_path,
            path_in_repo="config.json",
            repo_id=config.hf_repo_id,
            repo_type="model",
        )

    print(f"\nLoading dataset (streaming): {config.dataset_name}")
    # NOTE(review): trust_remote_code=True lets the dataset repo run its own
    # loading code; keep it only for dataset repos you trust.
    train_dataset = datasets.load_dataset(
        config.dataset_name,
        split="train",
        streaming=True,
        trust_remote_code=True,
    )
    train_dataset = train_dataset.shuffle(seed=config.seed, buffer_size=1000)
    print("✓ Dataset loaded in streaming mode")

    def collate_fn(examples):
        """Stack dataset rows into ``(latents, clip_embeddings, ids, prompts)``.

        Latents are multiplied by ``config.vae_scale`` here, so every consumer
        of the dataloader sees scaled latents.
        """
        latents = torch.stack([torch.tensor(ex["latent"]) for ex in examples])
        latents = latents * config.vae_scale

        clip_embeddings = torch.stack([torch.tensor(ex["clip_embedding"]) for ex in examples])
        ids = [ex["id"] for ex in examples]
        prompts = [ex["prompt"] for ex in examples]

        return latents, clip_embeddings, ids, prompts

    train_dataloader = DataLoader(
        dataset=train_dataset,
        batch_size=config.batch_size,
        collate_fn=collate_fn,
        num_workers=config.num_workers,
    )

    # Sanity-check the value range of one batch. These latents have already
    # been multiplied by vae_scale in collate_fn. The training loop below
    # iterates the dataloader afresh.
    print("\nVerifying latent scaling on first batch...")
    first_batch = next(iter(train_dataloader))
    latents_check, _, _, _ = first_batch
    print(f"Raw latent range: [{latents_check.min():.3f}, {latents_check.max():.3f}]")
    latents_check = latents_check.to(device)
    print(f"After GPU transfer: [{latents_check.min():.3f}, {latents_check.max():.3f}]")
    print("Expected: ~[-1, 1] for properly scaled latents")
    del first_batch, latents_check

    print("\nLoading model from HuggingFace...")
    unet = load_student_unet(config.model_repo, config.checkpoint_filename, device=device)
    unet.requires_grad_(True)
    unet.enable_gradient_checkpointing()
    unet.train()

    optimizer = torch.optim.Adam(
        unet.parameters(),
        lr=config.base_lr * (config.batch_size ** 0.5),
    )

    global_step = 0
    # Per-sample logs. They grow for the whole run and are re-plotted and
    # re-dumped in full.
    train_logs = {
        "train_step": [],
        "train_loss": [],
        "train_timestep": [],
        "trained_images": [],
    }

    def get_prediction(batch, log_to=None):
        """Return the mean flow-matching MSE loss for one batch.

        Samples s ~ U[0, 1) and applies the timestep shift to get sigma. Builds
        x = sigma * noise + (1 - sigma) * latents and regresses the UNet output
        onto the target noise - latents. When ``log_to`` is given, per-sample
        step, loss, timestep and (id, prompt) records are appended to it.
        """
        latents, encoder_hidden_states, ids, prompts = batch

        latents = latents.to(dtype=torch.float32, device=device)
        encoder_hidden_states = encoder_hidden_states.to(dtype=torch.float32, device=device)

        batch_size = latents.shape[0]

        # Conditioning dropout: zero the text embeddings of a random subset.
        # NOTE(review): the "unconditional" input here is an all-zero embedding,
        # not the encoding of an empty prompt. Confirm it matches sampling.
        dropout_mask = torch.rand(batch_size, device=device) < config.dropout
        encoder_hidden_states = encoder_hidden_states.clone()
        encoder_hidden_states[dropout_mask] = 0

        # Shifted uniform sigmas; the UNet receives them scaled by 1000.
        sigmas = torch.rand(batch_size, device=device)
        sigmas = (config.shift * sigmas) / (1 + (config.shift - 1) * sigmas)
        timesteps = sigmas * 1000
        sigmas = sigmas[:, None, None, None]

        noise = torch.randn_like(latents)
        noisy_latents = noise * sigmas + latents * (1 - sigmas)
        target = noise - latents

        pred = unet(noisy_latents, timesteps, encoder_hidden_states, return_dict=False)[0]

        # Per-sample MSE: average over every non-batch dimension.
        loss = F.mse_loss(pred, target, reduction="none")
        loss = loss.mean(dim=list(range(1, len(loss.shape))))

        if log_to is not None:
            # One device->host transfer per tensor instead of one .item() per sample.
            sample_losses = loss.detach().cpu().tolist()
            sample_timesteps = timesteps.detach().cpu().tolist()
            for sample_id, prompt, sample_loss, sample_t in zip(ids, prompts, sample_losses, sample_timesteps):
                log_to["train_step"].append(global_step)
                log_to["train_loss"].append(sample_loss)
                log_to["train_timestep"].append(sample_t)
                log_to["trained_images"].append({
                    "step": global_step,
                    "id": sample_id,
                    "prompt": prompt,
                })

        return loss.mean()

    def plot_logs(log_dict):
        """Scatter per-sample loss against timestep, coloured by step, on a new figure."""
        plt.figure(figsize=(10, 6))
        plt.scatter(
            log_dict["train_timestep"],
            log_dict["train_loss"],
            s=3,
            c=log_dict["train_step"],
            marker=".",
            cmap='cool',
        )
        plt.xlabel("timestep")
        plt.ylabel("loss")
        plt.yscale("log")
        plt.colorbar(label="step")

    def save_checkpoint(step):
        """Write ``checkpoint-{step:08}/`` locally, then upload it (best-effort).

        Contents: a diffusers ``unet/`` folder (safetensors), a ``.pt`` bundle
        (config, student state_dict, optimizer state, step) and
        ``trained_images.json``. The JSON holds every sample logged so far.
        """
        checkpoint_path = os.path.join(real_output_dir, f"checkpoint-{step:08}")
        os.makedirs(checkpoint_path, exist_ok=True)

        unet.save_pretrained(
            os.path.join(checkpoint_path, "unet"),
            safe_serialization=True,
        )

        pt_filename = f"sd15_flow_lune_e{step//1000}_s{step}.pt"
        pt_path = os.path.join(checkpoint_path, pt_filename)

        torch.save({
            "cfg": asdict(config),
            "student": unet.state_dict(),
            "opt": optimizer.state_dict(),
            "gstep": step,
        }, pt_path)

        metadata = {
            "step": step,
            "trained_images": train_logs["trained_images"],
        }
        metadata_path = os.path.join(checkpoint_path, "trained_images.json")
        with open(metadata_path, "w") as f:
            json.dump(metadata, f, indent=2)

        print(f"✓ Checkpoint saved at step {step}")

        if config.upload_to_hub and hf_api is not None:
            try:
                hf_api.upload_file(
                    path_or_fileobj=pt_path,
                    path_in_repo=pt_filename,
                    repo_id=config.hf_repo_id,
                    repo_type="model",
                )

                hf_api.upload_folder(
                    folder_path=os.path.join(checkpoint_path, "unet"),
                    path_in_repo=f"checkpoint-{step:08}/unet",
                    repo_id=config.hf_repo_id,
                    repo_type="model",
                )

                hf_api.upload_file(
                    path_or_fileobj=metadata_path,
                    path_in_repo=f"checkpoint-{step:08}/trained_images.json",
                    repo_id=config.hf_repo_id,
                    repo_type="model",
                )

                print(f"✓ Uploaded to hub: {config.hf_repo_id}")
            except Exception as e:
                print(f"⚠ Upload failed: {e}")

    print("\nStarting training...")
    # Created here so the try/finally below is guaranteed to close it.
    t_writer = SummaryWriter(log_dir=real_output_dir, flush_secs=60)
    progress_bar = tqdm(total=config.max_train_steps)
    last_saved_step = None

    try:
        for batch in train_dataloader:
            loss = get_prediction(batch, log_to=train_logs)
            loss_value = loss.detach().item()
            t_writer.add_scalar("train/loss", loss_value, global_step)

            loss.backward()

            grad_norm = torch.nn.utils.clip_grad_norm_(unet.parameters(), 2.0)
            t_writer.add_scalar("train/grad_norm", grad_norm.detach().item(), global_step)

            optimizer.step()
            optimizer.zero_grad()

            progress_bar.update(1)
            progress_bar.set_postfix({"loss": f"{loss_value:.4f}"})
            global_step += 1

            if global_step % 100 == 0:
                plot_logs(train_logs)
                t_writer.add_figure("train_loss", plt.gcf(), global_step)
                plt.close()

            if global_step % config.checkpointing_steps == 0:
                save_checkpoint(global_step)
                last_saved_step = global_step

            if global_step >= config.max_train_steps:
                break
        else:
            # The finite streamed dataset ran out before max_train_steps.
            print(f"\n⚠ Dataset exhausted at step {global_step} "
                  f"(max_train_steps={config.max_train_steps}).")

        # Always persist the final weights, but never save the same step twice
        # (the last step often coincides with a periodic checkpoint).
        if global_step > 0 and last_saved_step != global_step:
            save_checkpoint(global_step)
        print("\n✅ Training complete!")
    finally:
        progress_bar.close()
        t_writer.close()
|
|
|
|
if __name__ == "__main__":
    # Script entry point: train with the defaults declared on TrainConfig.
    train(TrainConfig())