# sd15-flow-matching / sd15_flow_sol_ddpm_inference
# Author: AbstractPhil — "Create sd15_flow_sol_ddpm_inference" (commit 77fceae, verified)
# ============================================================================
# SD1.5-Flow-Sol Correct Inference (Colab Cell)
# ============================================================================
# Matches the trainer's sample() method exactly:
# - Timesteps come from a DDPM scheduler (1000 training timesteps)
# - The Sol UNet predicts velocity (v), not epsilon, so a stock SD1.5
#   epsilon-prediction pipeline would misinterpret its output
# - v is converted to epsilon before stepping the DDPM scheduler
# ============================================================================
!pip install -q diffusers transformers accelerate safetensors
import torch
import gc
from huggingface_hub import hf_hub_download
from diffusers import UNet2DConditionModel, AutoencoderKL, DDPMScheduler
from transformers import CLIPTextModel, CLIPTokenizer
from PIL import Image
import numpy as np
# Release memory left over from earlier runs of this cell. Collect Python
# garbage FIRST so that any tensors it frees are returned to PyTorch's CUDA
# caching allocator; empty_cache() can then hand those cached blocks back to
# the driver. (The reverse order leaves just-freed tensors cached.)
gc.collect()
torch.cuda.empty_cache()
# ============================================================================
# CONFIG
# ============================================================================
# Target device and precision. On machines without a GPU, fall back to CPU
# with float32, because float16 is poorly supported on CPU. With a GPU the
# values are unchanged ("cuda", float16).
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
DTYPE = torch.float16 if DEVICE == "cuda" else torch.float32
# Hugging Face Hub location of the Sol UNet checkpoint.
SOL_REPO = "AbstractPhil/sd15-flow-matching"
SOL_FILENAME = "sd15_flowmatch_david_weighted_efinal.pt"
# ============================================================================
# LOAD MODELS
# ============================================================================
# Load the SD1.5 building blocks. All modules go to DEVICE in DTYPE and are put
# in eval mode. The UNet's weights are later overwritten with the Sol checkpoint.
_CLIP_ID = "openai/clip-vit-large-patch14"
_SD15_ID = "stable-diffusion-v1-5/stable-diffusion-v1-5"

print("Loading CLIP...")
clip_tok = CLIPTokenizer.from_pretrained(_CLIP_ID)
clip_enc = CLIPTextModel.from_pretrained(_CLIP_ID, torch_dtype=DTYPE)
clip_enc = clip_enc.to(DEVICE).eval()

print("Loading VAE...")
vae = AutoencoderKL.from_pretrained(_SD15_ID, subfolder="vae", torch_dtype=DTYPE)
vae = vae.to(DEVICE).eval()

print("Loading UNet...")
unet = UNet2DConditionModel.from_pretrained(_SD15_ID, subfolder="unet", torch_dtype=DTYPE)
unet = unet.to(DEVICE).eval()

print("Loading DDPM Scheduler...")
# NOTE(review): this scheduler is built from DDPMScheduler's constructor defaults,
# not from the SD1.5 repo's scheduler config. The beta schedule and clip_sample
# setting may therefore differ from SD1.5's. Presumably this is intentional, to
# match the trainer's schedule; confirm against the trainer before changing it.
sched = DDPMScheduler(num_train_timesteps=1000)
# ============================================================================
# LOAD SOL WEIGHTS
# ============================================================================
# Download the Sol checkpoint and load its UNet weights into `unet`.
print(f"\nLoading Sol from {SOL_REPO}...")
weights_path = hf_hub_download(repo_id=SOL_REPO, filename=SOL_FILENAME)
# SECURITY NOTE(review): torch.load unpickles the file. On torch versions where
# `weights_only` defaults to False, a malicious checkpoint could execute code.
# Only load checkpoints from repos you trust.
checkpoint = torch.load(weights_path, map_location="cpu")
state_dict = checkpoint["student"]
print(f" gstep: {checkpoint.get('gstep', 'unknown')}")

# If keys are wrapped as "unet.<name>", keep only those keys and strip the
# LEADING prefix. (str.replace would also rewrite any later "unet." inside a key.)
_UNET_PREFIX = "unet."
if any(k.startswith(_UNET_PREFIX) for k in state_dict):
    state_dict = {
        k[len(_UNET_PREFIX):]: v
        for k, v in state_dict.items()
        if k.startswith(_UNET_PREFIX)
    }
# Drop auxiliary modules ("hooks.", "local_heads."). These are presumably
# training-time extras with no counterpart in UNet2DConditionModel.
state_dict = {
    k: v for k, v in state_dict.items() if not k.startswith(("hooks.", "local_heads."))
}

missing, unexpected = unet.load_state_dict(state_dict, strict=False)
print(f" Loaded: {len(state_dict)} keys, missing: {len(missing)}, unexpected: {len(unexpected)}")
# strict=False never fails on name mismatches, so surface them explicitly.
# If nothing matched, `unet` still holds the original SD1.5 weights, and the
# velocity-based sampler would silently produce wrong images.
if len(state_dict) - len(unexpected) == 0:
    raise RuntimeError(
        f"No Sol weights matched the UNet ({len(state_dict)} checkpoint keys, "
        f"{len(unexpected)} unexpected); refusing to continue with base SD1.5 weights."
    )
if missing:
    print(f" WARNING: missing keys (first 5): {list(missing)[:5]}")
if unexpected:
    print(f" WARNING: unexpected keys (first 5): {list(unexpected)[:5]}")

del checkpoint, state_dict
gc.collect()

# Inference only: freeze every UNet parameter.
unet.requires_grad_(False)
print("✓ Sol ready!")
# ============================================================================
# HELPER: Alpha/Sigma from DDPM schedule (matches trainer)
# ============================================================================
def alpha_sigma(t: torch.LongTensor, alphas_cumprod=None):
    """Return (alpha, sigma) for DDPM timesteps ``t``; matches the trainer exactly.

    alpha = sqrt(alphas_cumprod[t]) and sigma = sqrt(1 - alphas_cumprod[t]), so
    alpha**2 + sigma**2 == 1.

    Args:
        t: integer timestep indices (a 1-D tensor of shape (B,) in this file).
        alphas_cumprod: optional cumulative-product schedule to index. Defaults
            to the module-level ``sched.alphas_cumprod``.

    Returns:
        ``(alpha, sigma)``: float32 tensors of shape (B, 1, 1, 1) on ``t``'s
        device, ready to broadcast against (B, C, H, W) latents.
    """
    if alphas_cumprod is None:
        alphas_cumprod = sched.alphas_cumprod
    # Move the schedule to wherever `t` lives, rather than a hard-coded device.
    ac = alphas_cumprod.to(t.device)[t]
    alpha = ac.sqrt().view(-1, 1, 1, 1).float()
    sigma = (1.0 - ac).sqrt().view(-1, 1, 1, 1).float()
    return alpha, sigma
# ============================================================================
# CORRECT SAMPLER (matches trainer's sample() method)
# ============================================================================
def _encode_text(text):
    """Encode one string with CLIP into (1, 77, hidden) last-layer states, cast to DTYPE."""
    tokens = clip_tok(
        text, return_tensors="pt", padding="max_length", max_length=77, truncation=True
    ).to(DEVICE)
    return clip_enc(**tokens).last_hidden_state.to(DTYPE)


def _velocity_to_epsilon(x_t, v_hat, alpha, sigma):
    """Convert a velocity prediction into an epsilon prediction (same formula as the trainer).

    With x_t = alpha * x0 + sigma * eps and v = alpha * eps - sigma * x0:
        x0  = (alpha * x_t - sigma * v) / (alpha**2 + sigma**2)
        eps = (x_t - alpha * x0) / sigma
    The 1e-8 terms guard the divisions. The result is float32.
    """
    x_t = x_t.float()
    denom = alpha ** 2 + sigma ** 2
    x0_hat = (alpha * x_t - sigma * v_hat.float()) / (denom + 1e-8)
    return (x_t - alpha * x0_hat) / (sigma + 1e-8)


@torch.inference_mode()
def generate_sol(prompt, negative_prompt="", seed=42, steps=30, cfg=7.5, height=512, width=512):
    """Generate one image with the Sol UNet, matching the trainer's sample() method.

    Procedure:
      1. Take timesteps from the DDPM scheduler (``sched``).
      2. The UNet predicts velocity v; classifier-free guidance is applied to v.
      3. Convert v -> x0_hat -> eps_hat using the DDPM alpha/sigma schedule.
      4. Advance with ``sched.step(eps_hat, t, x_t)``.

    Args:
        prompt: text prompt for the conditional branch.
        negative_prompt: text for the unconditional CFG branch ("" = empty prompt).
        seed: passed to ``torch.manual_seed``; None leaves the global RNG untouched.
            (No generator is passed to ``sched.step``, so its noise presumably
            also comes from the global RNG.)
        steps: number of inference timesteps.
        cfg: classifier-free guidance scale.
        height, width: output size in pixels. Each must be a multiple of the VAE
            downsampling factor (8 for SD1.5). Defaults reproduce the original
            64x64 latent / 512x512 output.

    Returns:
        A PIL ``Image`` (8-bit RGB).

    Raises:
        ValueError: if ``height`` or ``width`` is not a multiple of the VAE factor.
    """
    # Pixel -> latent downsampling factor, derived the way diffusers does it.
    vae_scale = 2 ** (len(vae.config.block_out_channels) - 1)
    if height % vae_scale or width % vae_scale:
        raise ValueError(
            f"height and width must be multiples of {vae_scale}, got {height}x{width}"
        )

    if seed is not None:
        torch.manual_seed(seed)

    cond = _encode_text(prompt)
    uncond = _encode_text(negative_prompt)

    sched.set_timesteps(steps, device=DEVICE)

    # Start from pure Gaussian noise in latent space.
    x_t = torch.randn(
        1, unet.config.in_channels, height // vae_scale, width // vae_scale,
        device=DEVICE, dtype=DTYPE,
    )
    print(f"Sampling '{prompt[:40]}' | {steps} steps, cfg={cfg}")
    log_every = max(1, steps // 5)
    for i, t_scalar in enumerate(sched.timesteps):
        t = torch.full((1,), t_scalar, device=DEVICE, dtype=torch.long)

        # The Sol UNet predicts VELOCITY (not epsilon); CFG is applied to velocity.
        v_cond = unet(x_t.to(DTYPE), t, encoder_hidden_states=cond).sample
        v_uncond = unet(x_t.to(DTYPE), t, encoder_hidden_states=uncond).sample
        v_hat = v_uncond + cfg * (v_cond - v_uncond)

        # The stock DDPM scheduler expects epsilon, so convert before stepping.
        alpha, sigma = alpha_sigma(t)
        eps_hat = _velocity_to_epsilon(x_t, v_hat, alpha, sigma)
        x_t = sched.step(eps_hat, t_scalar, x_t.float()).prev_sample.to(DTYPE)

        if (i + 1) % log_every == 0:
            # int() so the log shows the timestep value, not a tensor repr.
            print(f" Step {i+1}/{steps}, t={int(t_scalar)}")

    # Undo the latent scaling (0.18215 for the SD1.5 VAE) and decode to [-1, 1] pixels.
    latents = x_t / vae.config.scaling_factor
    image = vae.decode(latents).sample
    image = (image / 2 + 0.5).clamp(0, 1)[0].permute(1, 2, 0).cpu().float().numpy()
    # Round (rather than truncate) to the nearest 8-bit level.
    return Image.fromarray((image * 255).round().astype(np.uint8))
# ============================================================================
# TEST
# ============================================================================
# Smoke test: render a few fixed prompts with identical settings and show them inline.
_rule = "=" * 60
print("", _rule, "Generating test images with Sol (correct sampler)", _rule, sep="\n")
from IPython.display import display

_sample_kwargs = dict(negative_prompt="", seed=42, steps=30, cfg=7.5)
prompts = [
    "a castle at sunset",
    "a portrait of a woman",
    "a city street at night",
]
for prompt in prompts:
    print()  # blank separator line before each sample's log
    img = generate_sol(prompt, **_sample_kwargs)
    display(img)

print("\n✓ Done!")