# Source: tinyflux-experts / inference_sd15_flow_lune.py
# Author: AbstractPhil
# Commit: "Create inference_sd15_flow_lune.py" (1d7a19e, verified)
# ============================================================================
# SD1.5-Flow-Lune Inference - CORRECT (matches trainer)
# ============================================================================
# Trainer's flow convention:
# x_t = sigma * noise + (1 - sigma) * data
# target = noise - data (velocity points FROM data TO noise)
# sigma=0 β†’ clean, sigma=1 β†’ noise
#
# Sampling: sigma goes 1 β†’ 0, so we SUBTRACT velocity
# x_{sigma - dt} = x_sigma - v * dt
# ============================================================================
# Notebook cell magic: quiet-install runtime deps (requires Jupyter/Colab).
!pip install -q diffusers transformers accelerate safetensors
import torch
import gc
from huggingface_hub import hf_hub_download
from diffusers import UNet2DConditionModel, AutoencoderKL
from transformers import CLIPTextModel, CLIPTokenizer
from safetensors.torch import load_file
from PIL import Image
import numpy as np
import json
# Free GPU memory left over from a previous cell run before loading models.
torch.cuda.empty_cache()
gc.collect()
# ============================================================================
# CONFIG
# ============================================================================
DEVICE = "cuda"
DTYPE = torch.float16
# HF repo and in-repo paths for the Lune flow-matching UNet checkpoint.
LUNE_REPO = "AbstractPhil/sd15-flow-lune-flux"
LUNE_WEIGHTS = "flux_t2_6_pose_t4_6_port_t1_4/checkpoint-00018765/unet/diffusion_pytorch_model.safetensors"
LUNE_CONFIG = "flux_t2_6_pose_t4_6_port_t1_4/checkpoint-00018765/unet/config.json"
# ============================================================================
# LOAD MODELS
# ============================================================================
print("Loading CLIP...")
# SD1.5 text stack: CLIP ViT-L/14 tokenizer + text encoder, fp16 on GPU.
clip_tok = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
clip_enc = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14", torch_dtype=DTYPE).to(DEVICE).eval()
print("Loading VAE...")
# SD1.5 VAE: decodes 4-channel latents back to RGB at the end of sampling.
vae = AutoencoderKL.from_pretrained(
    "stable-diffusion-v1-5/stable-diffusion-v1-5",
    subfolder="vae",
    torch_dtype=DTYPE
).to(DEVICE).eval()
# ============================================================================
# LOAD LUNE
# ============================================================================
print(f"\nLoading Lune...")
# Build the UNet from the checkpoint's own config, then load its weights.
config_path = hf_hub_download(repo_id=LUNE_REPO, filename=LUNE_CONFIG)
with open(config_path, 'r') as f:
    lune_config = json.load(f)
# Sanity check: the trainer should have stamped a flow prediction type.
print(f" prediction_type: {lune_config.get('prediction_type', 'NOT SET')}")
unet = UNet2DConditionModel.from_config(lune_config).to(DEVICE).to(DTYPE).eval()
weights_path = hf_hub_download(repo_id=LUNE_REPO, filename=LUNE_WEIGHTS)
state_dict = load_file(weights_path)
# NOTE(review): strict=False silently ignores missing/unexpected keys —
# confirm the checkpoint actually matches the config it ships with.
unet.load_state_dict(state_dict, strict=False)
del state_dict
gc.collect()
# Inference only: freeze all UNet parameters.
for p in unet.parameters():
    p.requires_grad = False
print("βœ“ Lune ready!")
# ============================================================================
# HELPERS
# ============================================================================
def shift_sigma(sigma: torch.Tensor, shift: float = 3.0) -> torch.Tensor:
    """
    Warp a linear sigma schedule with the trainer's timestep shift.

    Mapping (identical to the trainer's):
        sigma_shifted = shift * sigma / (1 + (shift - 1) * sigma)

    shift=1.0 is the identity; larger shifts concentrate steps at high sigma.
    """
    numerator = shift * sigma
    denominator = 1 + (shift - 1) * sigma
    return numerator / denominator
@torch.inference_mode()
def encode_prompt(prompt):
    """Tokenize *prompt* (padded/truncated to 77 tokens) and return CLIP's
    last_hidden_state cast to the inference dtype."""
    tokens = clip_tok(
        prompt,
        return_tensors="pt",
        padding="max_length",
        max_length=77,
        truncation=True,
    ).to(DEVICE)
    hidden = clip_enc(**tokens).last_hidden_state
    return hidden.to(DTYPE)
# ============================================================================
# CORRECT SAMPLER (matches trainer exactly)
# ============================================================================
@torch.inference_mode()
def generate_lune(
    prompt: str,
    negative_prompt: str = "",
    seed: int = 42,
    steps: int = 30,
    cfg: float = 7.5,
    shift: float = 3.0,
    height: int = 512,
    width: int = 512,
):
    """
    Euler sampler matching the trainer's flow convention exactly.

    Trainer:
        x_t = sigma * noise + (1 - sigma) * data
        target v = noise - data  (velocity points FROM data TO noise)

    Sampling therefore integrates sigma from 1 (pure noise) down to 0
    (clean data), SUBTRACTING the predicted velocity each step:
        x_{sigma - dt} = x_sigma - v * dt

    Args:
        prompt: text conditioning.
        negative_prompt: CFG negative conditioning ("" = unconditional).
        seed: RNG seed for the initial latent noise.
        steps: number of Euler steps (any value >= 1 is now safe).
        cfg: classifier-free guidance scale.
        shift: timestep shift applied to the linear sigma schedule.
        height, width: output resolution in pixels (multiples of 8; latents
            are 8x smaller — defaults preserve the original 64x64 latent).

    Returns:
        PIL.Image.Image of the decoded sample.
    """
    torch.manual_seed(seed)
    cond = encode_prompt(prompt)
    # Empty negative prompt falls back to the unconditional ("") embedding.
    uncond = encode_prompt(negative_prompt or "")
    # Start from pure noise (sigma=1) in SD1.5's 4-channel latent space.
    x = torch.randn(1, 4, height // 8, width // 8, device=DEVICE, dtype=DTYPE)
    # Sigma schedule: linear 1 -> 0 (noise -> data), warped by the shift.
    sigmas_linear = torch.linspace(1, 0, steps + 1, device=DEVICE)
    sigmas = shift_sigma(sigmas_linear, shift=shift)
    print(f"Lune: '{prompt[:30]}' | {steps} steps, cfg={cfg}, shift={shift}")
    print(f" sigma range: {sigmas[0].item():.3f} β†’ {sigmas[-1].item():.3f}")
    # BUGFIX: original used `steps // 5` directly as the modulus, which
    # raises ZeroDivisionError for steps < 5. Clamp the log interval to >= 1.
    log_every = max(1, steps // 5)
    for i in range(steps):
        sigma = sigmas[i]
        sigma_next = sigmas[i + 1]
        dt = sigma - sigma_next  # Positive: sigma decreases each step.
        # UNet timestep convention matches the trainer: sigma * 1000.
        timestep = sigma * 1000
        t_input = timestep.view(1).to(DEVICE)
        # Predict velocity v = noise - data, with classifier-free guidance.
        v_cond = unet(x, t_input, encoder_hidden_states=cond).sample
        v_uncond = unet(x, t_input, encoder_hidden_states=uncond).sample
        v = v_uncond + cfg * (v_cond - v_uncond)
        # Euler step toward data: x_{sigma - dt} = x_sigma - v * dt.
        x = x - v * dt
        if (i + 1) % log_every == 0:
            print(f" Step {i+1}/{steps}, sigma={sigma.item():.3f} β†’ {sigma_next.item():.3f}")
    # Decode: undo the SD1.5 VAE latent scale (0.18215), then map the
    # [-1, 1] image range to [0, 255] uint8.
    x = x / 0.18215
    img = vae.decode(x).sample
    img = (img / 2 + 0.5).clamp(0, 1)[0].permute(1, 2, 0).cpu().float().numpy()
    return Image.fromarray((img * 255).astype(np.uint8))
# ============================================================================
# TEST
# ============================================================================
print("\n" + "="*60)
print("Testing Lune with CORRECT flow convention")
print(" x_t = sigma*noise + (1-sigma)*data")
print(" v = noise - data")
print(" Sample by SUBTRACTING v")
print("="*60)
from IPython.display import display
# Same prompt/seed across three shift values so only the schedule warping
# differs between the three renders.
prompt = "a castle at sunset"
print("\n--- shift=3.0 (default) ---")
img = generate_lune(prompt, seed=42, steps=30, cfg=7.5, shift=3.0)
display(img)
print("\n--- shift=2.5 (trainer default) ---")
img2 = generate_lune(prompt, seed=42, steps=30, cfg=7.5, shift=2.5)
display(img2)
print("\n--- shift=1.0 (no shift) ---")
img3 = generate_lune(prompt, seed=42, steps=30, cfg=7.5, shift=1.0)
display(img3)
# Grid comparison: the three shift settings side by side.
import matplotlib.pyplot as plt
fig, axes = plt.subplots(1, 3, figsize=(15, 5))
for ax, (s, im) in zip(axes, [(3.0, img), (2.5, img2), (1.0, img3)]):
    ax.imshow(im)
    ax.set_title(f"shift={s}")
    ax.axis('off')
plt.tight_layout()
plt.show()
print("\nβœ“ If images look correct, the output should be beautiful.")