|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
!pip install -q diffusers transformers accelerate safetensors |
|
|
|
|
|
import torch |
|
|
import gc |
|
|
from huggingface_hub import hf_hub_download |
|
|
from diffusers import UNet2DConditionModel, AutoencoderKL |
|
|
from transformers import CLIPTextModel, CLIPTokenizer |
|
|
from safetensors.torch import load_file |
|
|
from PIL import Image |
|
|
import numpy as np |
|
|
import json |
|
|
|
|
|
# Release any cached (but unused) GPU allocations and run a garbage
# collection pass before the large model loads below — helps avoid OOM
# on a reused notebook runtime.  Order matters: collecting after
# empty_cache lets Python-side tensor refs be dropped as well.
torch.cuda.empty_cache()
gc.collect()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Target device and compute dtype shared by every model below (fp16 on GPU).
DEVICE = "cuda"
DTYPE = torch.float16

# Hugging Face Hub repo holding the Lune UNet checkpoint, plus the in-repo
# paths of the safetensors weight file and its diffusers config.json.
LUNE_REPO = "AbstractPhil/sd15-flow-lune-flux"
LUNE_WEIGHTS = "flux_t2_6_pose_t4_6_port_t1_4/checkpoint-00018765/unet/diffusion_pytorch_model.safetensors"
LUNE_CONFIG = "flux_t2_6_pose_t4_6_port_t1_4/checkpoint-00018765/unet/config.json"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# --- Text encoder: the CLIP model SD1.5 was trained against. ---
print("Loading CLIP...")
clip_tok = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
clip_enc = (
    CLIPTextModel.from_pretrained(
        "openai/clip-vit-large-patch14",
        torch_dtype=DTYPE,
    )
    .to(DEVICE)
    .eval()
)

# --- VAE: taken from the SD1.5 repo, used only to decode latents. ---
print("Loading VAE...")
vae = (
    AutoencoderKL.from_pretrained(
        "stable-diffusion-v1-5/stable-diffusion-v1-5",
        subfolder="vae",
        torch_dtype=DTYPE,
    )
    .to(DEVICE)
    .eval()
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# --- Lune UNet: download config + weights from the Hub and load manually. ---
print("\nLoading Lune...")

# Fetch and parse the diffusers UNet config for this checkpoint.
config_path = hf_hub_download(repo_id=LUNE_REPO, filename=LUNE_CONFIG)
with open(config_path, 'r') as f:
    lune_config = json.load(f)

print(f" prediction_type: {lune_config.get('prediction_type', 'NOT SET')}")

# Build the UNet from config (random init), then overwrite with checkpoint weights.
unet = UNet2DConditionModel.from_config(lune_config).to(DEVICE).to(DTYPE).eval()

weights_path = hf_hub_download(repo_id=LUNE_REPO, filename=LUNE_WEIGHTS)
state_dict = load_file(weights_path)

# strict=False tolerates key mismatches, but don't discard the report:
# a silent mismatch would leave parts of the UNet randomly initialized.
missing, unexpected = unet.load_state_dict(state_dict, strict=False)
if missing:
    print(f" WARNING: {len(missing)} missing keys (e.g. {missing[:3]})")
if unexpected:
    print(f" WARNING: {len(unexpected)} unexpected keys (e.g. {unexpected[:3]})")

# Free the CPU copy of the weights immediately.
del state_dict
gc.collect()

# Inference only — freeze every parameter in one call.
unet.requires_grad_(False)

print("β Lune ready!")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def shift_sigma(sigma: torch.Tensor, shift: float = 3.0) -> torch.Tensor:
    """
    Remap sigma with the trainer's timestep shift:

        sigma_shifted = shift * sigma / (1 + (shift - 1) * sigma)

    Identity when shift == 1; shift > 1 biases the schedule toward
    high-noise timesteps. Endpoints 0 and 1 are fixed points.
    """
    numerator = shift * sigma
    denominator = 1 + (shift - 1) * sigma
    return numerator / denominator
|
|
|
|
|
@torch.inference_mode()
def encode_prompt(prompt):
    """Tokenize *prompt* (padded/truncated to 77 tokens) and return the
    CLIP encoder's last_hidden_state cast to DTYPE."""
    tokens = clip_tok(
        prompt,
        return_tensors="pt",
        padding="max_length",
        max_length=77,
        truncation=True,
    ).to(DEVICE)
    hidden = clip_enc(**tokens).last_hidden_state
    return hidden.to(DTYPE)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@torch.inference_mode()
def generate_lune(
    prompt: str,
    negative_prompt: str = "",
    seed: int = 42,
    steps: int = 30,
    cfg: float = 7.5,
    shift: float = 3.0,
):
    """
    Sample an image from the Lune UNet using the trainer's flow convention.

    Trainer:
        x_t    = sigma * noise + (1 - sigma) * data
        target = noise - data          (velocity points toward noise)

    Sampling therefore integrates sigma from 1 (pure noise) down to 0
    (clean data), SUBTRACTING the predicted velocity at each Euler step:
        x_{sigma - dt} = x_sigma - v * dt

    Args:
        prompt: text condition.
        negative_prompt: CFG negative condition ("" -> unconditional).
        seed: RNG seed for the initial latent.
        steps: number of Euler steps (>= 1).
        cfg: classifier-free guidance scale.
        shift: timestep shift applied to the linear sigma schedule.

    Returns:
        PIL.Image decoded from the final latent by the VAE.
    """
    torch.manual_seed(seed)

    cond = encode_prompt(prompt)
    uncond = encode_prompt(negative_prompt) if negative_prompt else encode_prompt("")

    # Initial latent: pure Gaussian noise (sigma = 1).
    x = torch.randn(1, 4, 64, 64, device=DEVICE, dtype=DTYPE)

    # Linear sigma grid from 1 to 0, remapped by the trainer's timestep shift.
    sigmas_linear = torch.linspace(1, 0, steps + 1, device=DEVICE)
    sigmas = shift_sigma(sigmas_linear, shift=shift)

    print(f"Lune: '{prompt[:30]}' | {steps} steps, cfg={cfg}, shift={shift}")
    print(f" sigma range: {sigmas[0].item():.3f} β {sigmas[-1].item():.3f}")

    # BUG FIX: `steps // 5` is 0 whenever steps < 5, which made the
    # progress print below raise ZeroDivisionError. Clamp to at least 1.
    log_every = max(1, steps // 5)

    for i in range(steps):
        sigma = sigmas[i]
        sigma_next = sigmas[i + 1]
        dt = sigma - sigma_next

        # UNet expects timesteps scaled into [0, 1000).
        timestep = sigma * 1000
        t_input = timestep.view(1).to(DEVICE)

        # Classifier-free guidance: two forward passes, then extrapolate
        # from the unconditional prediction toward the conditional one.
        v_cond = unet(x, t_input, encoder_hidden_states=cond).sample
        v_uncond = unet(x, t_input, encoder_hidden_states=uncond).sample
        v = v_uncond + cfg * (v_cond - v_uncond)

        # Euler step: v points toward noise, so step AWAY from it.
        x = x - v * dt

        if (i + 1) % log_every == 0:
            print(f" Step {i+1}/{steps}, sigma={sigma.item():.3f} β {sigma_next.item():.3f}")

    # Undo the SD1.5 VAE latent scaling factor before decoding.
    x = x / 0.18215
    img = vae.decode(x).sample
    # [-1, 1] -> [0, 1], CHW -> HWC, then to uint8 for PIL.
    img = (img / 2 + 0.5).clamp(0, 1)[0].permute(1, 2, 0).cpu().float().numpy()
    return Image.fromarray((img * 255).astype(np.uint8))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Banner restating the flow convention the sampler above implements.
banner = "=" * 60
print("\n" + banner)
print("Testing Lune with CORRECT flow convention")
print(" x_t = sigma*noise + (1-sigma)*data")
print(" v = noise - data")
print(" Sample by SUBTRACTING v")
print(banner)
|
|
|
|
|
# `display` renders PIL images inline in the notebook output cell.
from IPython.display import display


# One shared prompt so the three shift settings are directly comparable;
# seed/steps/cfg are held fixed and only the schedule shift varies.
prompt = "a castle at sunset"


print("\n--- shift=3.0 (default) ---")
img = generate_lune(prompt, seed=42, steps=30, cfg=7.5, shift=3.0)
display(img)


print("\n--- shift=2.5 (trainer default) ---")
img2 = generate_lune(prompt, seed=42, steps=30, cfg=7.5, shift=2.5)
display(img2)


print("\n--- shift=1.0 (no shift) ---")
img3 = generate_lune(prompt, seed=42, steps=30, cfg=7.5, shift=1.0)
display(img3)
|
|
|
|
|
|
|
|
import matplotlib.pyplot as plt

# Side-by-side comparison of the three shift settings generated above.
fig, axes = plt.subplots(1, 3, figsize=(15, 5))
shift_results = [(3.0, img), (2.5, img2), (1.0, img3)]
for ax, (shift_value, image) in zip(axes, shift_results):
    ax.imshow(image)
    ax.set_title(f"shift={shift_value}")
    ax.axis('off')
plt.tight_layout()
plt.show()

print("\nβ If images look correct, the output should be beautiful.")