# ============================================================================
# SD1.5-Flow-Lune Inference - CORRECT (matches trainer)
# ============================================================================
# Trainer's flow convention:
#   x_t = sigma * noise + (1 - sigma) * data
#   target = noise - data  (velocity points FROM data TO noise)
#   sigma=0 → clean, sigma=1 → noise
#
# Sampling: sigma goes 1 → 0, so we SUBTRACT velocity
#   x_{sigma - dt} = x_sigma - v * dt
# ============================================================================

import gc
import json
import subprocess
import sys

# Install dependencies through the interpreter that is running this cell, so
# the packages land in the active environment and the file stays valid Python
# (the `!pip` shell magic only parses inside IPython).
subprocess.check_call(
    [
        sys.executable, "-m", "pip", "install", "-q",
        "diffusers", "transformers", "accelerate", "safetensors",
    ]
)

import torch
from huggingface_hub import hf_hub_download
from diffusers import UNet2DConditionModel, AutoencoderKL
from transformers import CLIPTextModel, CLIPTokenizer
from safetensors.torch import load_file
from PIL import Image
import numpy as np

torch.cuda.empty_cache()
gc.collect()

# ============================================================================
# CONFIG
# ============================================================================
DEVICE = "cuda"
DTYPE = torch.float16

LUNE_REPO = "AbstractPhil/sd15-flow-lune-flux"
LUNE_WEIGHTS = "flux_t2_6_pose_t4_6_port_t1_4/checkpoint-00018765/unet/diffusion_pytorch_model.safetensors"
LUNE_CONFIG = "flux_t2_6_pose_t4_6_port_t1_4/checkpoint-00018765/unet/config.json"

# Latents are divided by this factor before being handed to the VAE decoder.
VAE_SCALING_FACTOR = 0.18215
# The UNet is conditioned on sigma * 1000 (matches trainer).
TIMESTEP_SCALE = 1000

# ============================================================================
# LOAD MODELS
# ============================================================================
print("Loading CLIP...")
clip_tok = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
clip_enc = CLIPTextModel.from_pretrained(
    "openai/clip-vit-large-patch14", torch_dtype=DTYPE
).to(DEVICE).eval()

print("Loading VAE...")
vae = AutoencoderKL.from_pretrained(
    "stable-diffusion-v1-5/stable-diffusion-v1-5",
    subfolder="vae",
    torch_dtype=DTYPE,
).to(DEVICE).eval()

# ============================================================================
# LOAD LUNE
# ============================================================================
print(f"\nLoading Lune...")
config_path = hf_hub_download(repo_id=LUNE_REPO, filename=LUNE_CONFIG)
with open(config_path, 'r') as f:
    lune_config = json.load(f)

print(f"  prediction_type: {lune_config.get('prediction_type', 'NOT SET')}")

unet = UNet2DConditionModel.from_config(lune_config).to(DEVICE).to(DTYPE).eval()

weights_path = hf_hub_download(repo_id=LUNE_REPO, filename=LUNE_WEIGHTS)
state_dict = load_file(weights_path)
# strict=False tolerates key mismatches; surface them instead of silently
# running with partially initialised weights.
load_result = unet.load_state_dict(state_dict, strict=False)
if load_result.missing_keys or load_result.unexpected_keys:
    print(
        f"  WARNING: {len(load_result.missing_keys)} missing and "
        f"{len(load_result.unexpected_keys)} unexpected keys while loading Lune weights"
    )
del state_dict
gc.collect()

# Inference only: no parameter needs gradients.
unet.requires_grad_(False)

print("✓ Lune ready!")


# ============================================================================
# HELPERS
# ============================================================================
def shift_sigma(sigma: torch.Tensor, shift: float = 3.0) -> torch.Tensor:
    """
    Apply timestep shift (same as trainer).

    sigma_shifted = shift * sigma / (1 + (shift - 1) * sigma)

    The endpoints are fixed points of the map (0 → 0, 1 → 1) and
    shift=1.0 is the identity.

    Args:
        sigma: Sigma value(s) to remap.
        shift: Shift factor.

    Returns:
        The shifted sigma value(s), same shape as ``sigma``.
    """
    return (shift * sigma) / (1 + (shift - 1) * sigma)


@torch.inference_mode()
def encode_prompt(prompt: str) -> torch.Tensor:
    """
    Encode a text prompt with the CLIP text encoder.

    The prompt is tokenized, padded/truncated to 77 tokens, and moved to
    DEVICE before encoding.

    Args:
        prompt: Text to encode.

    Returns:
        The encoder's ``last_hidden_state`` cast to DTYPE.
    """
    inputs = clip_tok(
        prompt,
        return_tensors="pt",
        padding="max_length",
        max_length=77,
        truncation=True,
    ).to(DEVICE)
    return clip_enc(**inputs).last_hidden_state.to(DTYPE)


# ============================================================================
# CORRECT SAMPLER (matches trainer exactly)
# ============================================================================
@torch.inference_mode()
def generate_lune(
    prompt: str,
    negative_prompt: str = "",
    seed: int = 42,
    steps: int = 30,
    cfg: float = 7.5,
    shift: float = 3.0,
) -> Image.Image:
    """
    Correct Lune sampler matching trainer's flow convention.

    Trainer:
        x_t = sigma * noise + (1 - sigma) * data
        target = noise - data

    Sampling:
        - Start at sigma=1 (pure noise)
        - End at sigma=0 (clean data)
        - x_{sigma - dt} = x_sigma - v * dt  (SUBTRACT because v points toward noise)

    Args:
        prompt: Text prompt for the conditional branch.
        negative_prompt: Text prompt for the unconditional branch; an empty
            prompt is encoded when this is empty.
        seed: Seed passed to ``torch.manual_seed`` before drawing the noise.
        steps: Number of Euler steps; must be at least 1.
        cfg: Classifier-free guidance scale.
        shift: Shift factor applied to the linear sigma schedule.

    Returns:
        The decoded image as an 8-bit PIL image.

    Raises:
        ValueError: If ``steps`` is less than 1.
    """
    if steps < 1:
        raise ValueError(f"steps must be >= 1, got {steps}")

    torch.manual_seed(seed)

    cond = encode_prompt(prompt)
    uncond = encode_prompt(negative_prompt or "")

    # Start from pure noise (sigma=1)
    x = torch.randn(1, 4, 64, 64, device=DEVICE, dtype=DTYPE)

    # Sigma schedule: 1 → 0 (noise → data)
    # Linear spacing then apply shift
    sigmas_linear = torch.linspace(1, 0, steps + 1, device=DEVICE)
    sigmas = shift_sigma(sigmas_linear, shift=shift)

    print(f"Lune: '{prompt[:30]}' | {steps} steps, cfg={cfg}, shift={shift}")
    print(f"  sigma range: {sigmas[0].item():.3f} → {sigmas[-1].item():.3f}")

    # Report progress about five times per run; clamp to 1 so that fewer
    # than five steps does not produce a modulo-by-zero.
    log_every = max(1, steps // 5)

    for i in range(steps):
        sigma = sigmas[i]
        sigma_next = sigmas[i + 1]
        dt = sigma - sigma_next  # Positive, going from high to low sigma

        # Timestep for UNet: sigma * 1000 (matches trainer)
        t_input = (sigma * TIMESTEP_SCALE).view(1).to(DEVICE)

        # Predict velocity v = noise - data
        v_cond = unet(x, t_input, encoder_hidden_states=cond).sample
        v_uncond = unet(x, t_input, encoder_hidden_states=uncond).sample
        v = v_uncond + cfg * (v_cond - v_uncond)

        # Euler step: SUBTRACT velocity (going from noise toward data)
        # x_{sigma - dt} = x_sigma - v * dt
        x = x - v * dt

        if (i + 1) % log_every == 0:
            print(f"  Step {i+1}/{steps}, sigma={sigma.item():.3f} → {sigma_next.item():.3f}")

    # Decode
    x = x / VAE_SCALING_FACTOR
    img = vae.decode(x).sample
    # Map the decoder output from [-1, 1] to [0, 1], then to HWC uint8.
    img = (img / 2 + 0.5).clamp(0, 1)[0].permute(1, 2, 0).cpu().float().numpy()
    return Image.fromarray((img * 255).astype(np.uint8))


# ============================================================================
# TEST
# ============================================================================
print("\n" + "="*60)
print("Testing Lune with CORRECT flow convention")
print("  x_t = sigma*noise + (1-sigma)*data")
print("  v = noise - data")
print("  Sample by SUBTRACTING v")
print("="*60)

from IPython.display import display

prompt = "a castle at sunset"

# Same prompt and seed for every run so only the shift differs.
shift_runs = [(3.0, "default"), (2.5, "trainer default"), (1.0, "no shift")]
images = []
for s, label in shift_runs:
    print(f"\n--- shift={s} ({label}) ---")
    im = generate_lune(prompt, seed=42, steps=30, cfg=7.5, shift=s)
    display(im)
    images.append(im)

# Keep the original result names available to later cells.
img, img2, img3 = images

# Grid comparison
import matplotlib.pyplot as plt

fig, axes = plt.subplots(1, 3, figsize=(15, 5))
for ax, (s, im) in zip(axes, [(3.0, img), (2.5, img2), (1.0, img3)]):
    ax.imshow(im)
    ax.set_title(f"shift={s}")
    ax.axis('off')
plt.tight_layout()
plt.show()

print("\n✓ If images look correct, the output should be beautiful.")