|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Install runtime dependencies (notebook cell; "!" is IPython shell magic, not Python).
!pip install -q diffusers transformers accelerate safetensors
|
|
|
|
|
import torch |
|
|
import gc |
|
|
from huggingface_hub import hf_hub_download |
|
|
from diffusers import UNet2DConditionModel, AutoencoderKL, DDPMScheduler |
|
|
from transformers import CLIPTextModel, CLIPTokenizer |
|
|
from PIL import Image |
|
|
import numpy as np |
|
|
|
|
|
# Release cached CUDA allocations and collect Python garbage before loading
# several large models — helps on memory-constrained notebook GPUs.
torch.cuda.empty_cache()
gc.collect()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Run on GPU when available; fall back to CPU instead of crashing on the
# later .to("cuda") calls. Identical to the original on any CUDA host.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# fp16 halves model memory on GPU; CPU lacks broad fp16 kernel support, so use fp32 there.
DTYPE = torch.float16 if DEVICE == "cuda" else torch.float32

# Hugging Face repo and filename for the fine-tuned flow-matching ("Sol") UNet weights.
SOL_REPO = "AbstractPhil/sd15-flow-matching"
SOL_FILENAME = "sd15_flowmatch_david_weighted_efinal.pt"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# --- Frozen SD1.5 components -------------------------------------------------
print("Loading CLIP...")
# CLIP ViT-L/14 tokenizer + text encoder (the text conditioning used by SD1.5).
clip_tok = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
clip_enc = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14", torch_dtype=DTYPE).to(DEVICE).eval()

print("Loading VAE...")
# VAE decodes the (1, 4, 64, 64) latents produced by sampling into RGB images.
vae = AutoencoderKL.from_pretrained(
    "stable-diffusion-v1-5/stable-diffusion-v1-5",
    subfolder="vae",
    torch_dtype=DTYPE
).to(DEVICE).eval()

print("Loading UNet...")
# Base SD1.5 UNet architecture; the fine-tuned "Sol" weights are loaded over it below.
unet = UNet2DConditionModel.from_pretrained(
    "stable-diffusion-v1-5/stable-diffusion-v1-5",
    subfolder="unet",
    torch_dtype=DTYPE,
).to(DEVICE).eval()

print("Loading DDPM Scheduler...")
# 1000-step DDPM schedule; its alphas_cumprod drives the alpha/sigma lookup
# used to convert velocity predictions to epsilon during sampling.
sched = DDPMScheduler(num_train_timesteps=1000)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# --- Load the fine-tuned ("Sol") UNet weights --------------------------------
print(f"\nLoading Sol from {SOL_REPO}...")
weights_path = hf_hub_download(repo_id=SOL_REPO, filename=SOL_FILENAME)
# NOTE(review): torch.load unpickles arbitrary objects; acceptable only because
# the checkpoint comes from the trusted repo above.
checkpoint = torch.load(weights_path, map_location="cpu")

# Checkpoint stores the trained weights under the "student" key.
state_dict = checkpoint["student"]
print(f" gstep: {checkpoint.get('gstep', 'unknown')}")

# If the trainer wrapped the UNet in a parent module, strip the "unet." prefix
# and drop any keys that are not part of the UNet.
if any(k.startswith("unet.") for k in state_dict.keys()):
    state_dict = {k.replace("unet.", ""): v for k, v in state_dict.items() if k.startswith("unet.")}

# Remove trainer-only auxiliary modules that do not exist in the bare UNet.
state_dict = {k: v for k, v in state_dict.items() if not k.startswith(("hooks.", "local_heads."))}

# strict=False tolerates trainer-side extras / missing buffers; the counts are
# printed so a badly mismatched load is visible.
missing, unexpected = unet.load_state_dict(state_dict, strict=False)
print(f" Loaded: {len(state_dict)} keys, missing: {len(missing)}, unexpected: {len(unexpected)}")

# Free the CPU copy of the checkpoint before sampling.
del checkpoint, state_dict
gc.collect()

# Inference only: freeze every UNet parameter.
for p in unet.parameters():
    p.requires_grad = False

print("✓ Sol ready!")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def alpha_sigma(t: torch.LongTensor) -> tuple[torch.Tensor, torch.Tensor]:
    """Look up (alpha_t, sigma_t) from the DDPM scheduler's alphas_cumprod.

    Both tensors come back as float32 with shape (-1, 1, 1, 1) so they
    broadcast against (B, C, H, W) latents. Matches the trainer's math.
    """
    abar = sched.alphas_cumprod.to(DEVICE)[t]
    bcast = (-1, 1, 1, 1)
    alpha = abar.sqrt().view(bcast).float()
    sigma = (1.0 - abar).sqrt().view(bcast).float()
    return alpha, sigma
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@torch.inference_mode()
def generate_sol(prompt, negative_prompt="", seed=42, steps=30, cfg=7.5):
    """
    Sample one image from the Sol UNet (64x64 latent, VAE-decoded to RGB).

    Matches trainer's sample() method exactly:
    1. Use DDPM scheduler timesteps
    2. Model predicts velocity v
    3. Convert v → x0_hat → eps_hat
    4. Use sched.step(eps_hat, t, x_t)

    Args:
        prompt: text prompt (single string; batch size is fixed at 1).
        negative_prompt: prompt for the unconditional CFG branch.
        seed: seed for the initial noise; None leaves the global RNG untouched.
        steps: number of scheduler timesteps.
        cfg: classifier-free guidance scale, applied to the velocity.

    Returns:
        PIL.Image.Image containing the decoded sample.
    """
    if seed is not None:
        torch.manual_seed(seed)

    # Encode prompts (fixed 77-token CLIP context, padded/truncated).
    inputs = clip_tok(prompt, return_tensors="pt", padding="max_length", max_length=77, truncation=True).to(DEVICE)
    cond = clip_enc(**inputs).last_hidden_state.to(DTYPE)

    inputs_neg = clip_tok(negative_prompt, return_tensors="pt", padding="max_length", max_length=77, truncation=True).to(DEVICE)
    uncond = clip_enc(**inputs_neg).last_hidden_state.to(DTYPE)

    # Set scheduler timesteps
    sched.set_timesteps(steps, device=DEVICE)

    # Start from pure noise in SD1.5 latent space: (1, 4, 64, 64).
    x_t = torch.randn(1, 4, 64, 64, device=DEVICE, dtype=DTYPE)

    print(f"Sampling '{prompt[:40]}' | {steps} steps, cfg={cfg}")

    for i, t_scalar in enumerate(sched.timesteps):
        t = torch.full((1,), t_scalar, device=DEVICE, dtype=torch.long)

        # Model predicts VELOCITY (not epsilon!)
        v_cond = unet(x_t.to(DTYPE), t, encoder_hidden_states=cond).sample
        v_uncond = unet(x_t.to(DTYPE), t, encoder_hidden_states=uncond).sample

        # CFG on velocity
        v_hat = v_uncond + cfg * (v_cond - v_uncond)

        # Convert velocity to epsilon (EXACTLY as trainer does)
        alpha, sigma = alpha_sigma(t)

        # v = alpha * eps - sigma * x0
        # x_t = alpha * x0 + sigma * eps
        # Solve for x0: x0 = (alpha * x_t - sigma * v) / (alpha^2 + sigma^2)
        # Then: eps = (x_t - alpha * x0) / sigma
        # (alpha^2 + sigma^2 == 1 for a DDPM schedule; 1e-8 guards division anyway.)
        denom = alpha**2 + sigma**2
        x0_hat = (alpha * x_t.float() - sigma * v_hat.float()) / (denom + 1e-8)
        eps_hat = (x_t.float() - alpha * x0_hat) / (sigma + 1e-8)

        # Step with epsilon; scheduler math runs in float32, latent kept in DTYPE.
        step_out = sched.step(eps_hat, t_scalar, x_t.float())
        x_t = step_out.prev_sample.to(DTYPE)

        # Progress report roughly five times over the run.
        if (i + 1) % max(1, steps // 5) == 0:
            print(f" Step {i+1}/{steps}, t={t_scalar}")

    # Decode: undo SD1.5 latent scaling, VAE-decode, then map [-1, 1] -> [0, 1].
    x_t = x_t / 0.18215
    img = vae.decode(x_t).sample
    img = (img / 2 + 0.5).clamp(0, 1)[0].permute(1, 2, 0).cpu().float().numpy()

    return Image.fromarray((img * 255).astype(np.uint8))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Render a few demo prompts with identical sampling settings and show each inline.
banner = "=" * 60
print("\n" + banner)
print("Generating test images with Sol (correct sampler)")
print(banner)

from IPython.display import display

prompts = (
    "a castle at sunset",
    "a portrait of a woman",
    "a city street at night",
)

for text in prompts:
    print()
    display(generate_sol(text, negative_prompt="", seed=42, steps=30, cfg=7.5))

print("\n✓ Done!")