File size: 6,953 Bytes
1d7a19e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
# ============================================================================
# SD1.5-Flow-Lune Inference - CORRECT (matches trainer)
# ============================================================================
# Trainer's flow convention:
#   x_t = sigma * noise + (1 - sigma) * data
#   target = noise - data  (velocity points FROM data TO noise)
#   sigma=0 → clean, sigma=1 → noise
#
# Sampling: sigma goes 1 → 0, so we SUBTRACT velocity
#   x_{sigma - dt} = x_sigma - v * dt
# ============================================================================

!pip install -q diffusers transformers accelerate safetensors

import torch
import gc
from huggingface_hub import hf_hub_download
from diffusers import UNet2DConditionModel, AutoencoderKL
from transformers import CLIPTextModel, CLIPTokenizer
from safetensors.torch import load_file
from PIL import Image
import numpy as np
import json

# Free leftover GPU memory from a previous run of this cell.
# Order matters: collect unreachable Python objects first so any CUDA tensors
# they own are handed back to PyTorch's caching allocator, and only then ask
# the allocator to release its cached blocks.
gc.collect()
torch.cuda.empty_cache()

# ============================================================================
# CONFIG
# ============================================================================
# Device that every model and the latent tensor are moved to.
DEVICE = "cuda"
# Dtype used for model weights, prompt embeddings and latents.
DTYPE = torch.float16

# Hugging Face Hub repository that holds the Lune UNet checkpoint.
LUNE_REPO = "AbstractPhil/sd15-flow-lune-flux"
# File paths inside LUNE_REPO: the UNet weights (safetensors) and the UNet
# config JSON, both taken from the same checkpoint directory.
LUNE_WEIGHTS = "flux_t2_6_pose_t4_6_port_t1_4/checkpoint-00018765/unet/diffusion_pytorch_model.safetensors"
LUNE_CONFIG = "flux_t2_6_pose_t4_6_port_t1_4/checkpoint-00018765/unet/config.json"

# ============================================================================
# LOAD MODELS
# ============================================================================
# Text encoder: tokenizer plus CLIP text model, moved to DEVICE in eval mode.
print("Loading CLIP...")
clip_tok = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
clip_enc = CLIPTextModel.from_pretrained(
    "openai/clip-vit-large-patch14",
    torch_dtype=DTYPE,
)
clip_enc = clip_enc.to(DEVICE).eval()

# VAE used to decode the sampled latents back to image space.
print("Loading VAE...")
vae = AutoencoderKL.from_pretrained(
    "stable-diffusion-v1-5/stable-diffusion-v1-5", subfolder="vae", torch_dtype=DTYPE
)
vae = vae.to(DEVICE).eval()

# ============================================================================
# LOAD LUNE
# ============================================================================
# Build the Lune UNet from its published config, then load the checkpoint
# weights into it.
print("\nLoading Lune...")
config_path = hf_hub_download(repo_id=LUNE_REPO, filename=LUNE_CONFIG)
with open(config_path, "r", encoding="utf-8") as f:
    lune_config = json.load(f)

print(f"  prediction_type: {lune_config.get('prediction_type', 'NOT SET')}")

unet = UNet2DConditionModel.from_config(lune_config).to(device=DEVICE, dtype=DTYPE).eval()

weights_path = hf_hub_download(repo_id=LUNE_REPO, filename=LUNE_WEIGHTS)
state_dict = load_file(weights_path)

# strict=False tolerates key mismatches, but a mismatch means part of the UNet
# keeps its random initialisation. Surface it instead of ignoring it silently.
load_result = unet.load_state_dict(state_dict, strict=False)
if load_result.missing_keys or load_result.unexpected_keys:
    print(
        f"  ⚠ state_dict mismatch: {len(load_result.missing_keys)} missing, "
        f"{len(load_result.unexpected_keys)} unexpected keys"
    )
    for key in load_result.missing_keys[:5]:
        print(f"    missing: {key}")
    for key in load_result.unexpected_keys[:5]:
        print(f"    unexpected: {key}")

# Drop the CPU-side copy of the weights once they live in the model.
del state_dict
gc.collect()

# Inference only: no parameter needs gradients.
unet.requires_grad_(False)

print("✓ Lune ready!")

# ============================================================================
# HELPERS
# ============================================================================
def shift_sigma(sigma: torch.Tensor, shift: float = 3.0) -> torch.Tensor:
    """
    Warp sigma values with the timestep-shift formula.

    Computes ``shift * sigma / (1 + (shift - 1) * sigma)`` element-wise.
    With ``shift == 1`` the input is returned unchanged in value; the
    endpoints 0 and 1 map to themselves for any shift.

    Args:
        sigma: Tensor of sigma values to warp.
        shift: Shift factor applied to the schedule.

    Returns:
        Tensor of shifted sigma values with the same shape as ``sigma``.
    """
    scaled = shift * sigma
    normaliser = 1 + (shift - 1) * sigma
    return scaled / normaliser

@torch.inference_mode()
def encode_prompt(prompt):
    """
    Encode a text prompt into CLIP hidden states.

    The prompt is tokenized with padding/truncation to 77 tokens, run through
    the module-level CLIP text encoder, and the last hidden state is returned
    cast to DTYPE.
    """
    tokens = clip_tok(
        prompt,
        return_tensors="pt",
        padding="max_length",
        max_length=77,
        truncation=True,
    )
    tokens = tokens.to(DEVICE)
    hidden_states = clip_enc(**tokens).last_hidden_state
    return hidden_states.to(DTYPE)

# ============================================================================
# CORRECT SAMPLER (matches trainer exactly)
# ============================================================================
@torch.inference_mode()
def generate_lune(
    prompt: str,
    negative_prompt: str = "",
    seed: int = 42,
    steps: int = 30,
    cfg: float = 7.5,
    shift: float = 3.0,
):
    """
    Sample an image with the Lune UNet using the trainer's flow convention.

    Trainer:
        x_t = sigma * noise + (1 - sigma) * data
        target = noise - data

    Sampling:
        - Start at sigma=1 (pure noise)
        - End at sigma=0 (clean data)
        - x_{sigma - dt} = x_sigma - v * dt  (SUBTRACT because v points toward noise)

    Args:
        prompt: Text prompt used for the conditional prediction.
        negative_prompt: Text used for the unconditional branch of
            classifier-free guidance; an empty value encodes "".
        seed: Seed passed to ``torch.manual_seed`` before drawing the noise.
        steps: Number of Euler steps. Any positive value works, including
            values below 5.
        cfg: Classifier-free guidance scale.
        shift: Shift factor applied to the linear sigma schedule.

    Returns:
        A ``PIL.Image`` built from the VAE-decoded latents.
    """
    torch.manual_seed(seed)

    cond = encode_prompt(prompt)
    # A falsy negative prompt falls back to the empty string.
    uncond = encode_prompt(negative_prompt or "")

    # Start from pure noise (sigma=1)
    x = torch.randn(1, 4, 64, 64, device=DEVICE, dtype=DTYPE)

    # Sigma schedule: 1 → 0 (noise → data)
    # Linear spacing then apply shift
    sigmas_linear = torch.linspace(1, 0, steps + 1, device=DEVICE)
    sigmas = shift_sigma(sigmas_linear, shift=shift)

    print(f"Lune: '{prompt[:30]}' | {steps} steps, cfg={cfg}, shift={shift}")
    print(f"  sigma range: {sigmas[0].item():.3f} → {sigmas[-1].item():.3f}")

    # Report progress roughly five times per run. Clamp to 1 so that fewer
    # than five steps does not produce a modulo-by-zero.
    log_every = max(1, steps // 5)

    for i in range(steps):
        sigma = sigmas[i]
        sigma_next = sigmas[i + 1]
        dt = sigma - sigma_next  # Positive, going from high to low sigma

        # Timestep for UNet: sigma * 1000 (matches trainer)
        t_input = (sigma * 1000).view(1)

        # Predict velocity v = noise - data, then apply classifier-free guidance
        v_cond = unet(x, t_input, encoder_hidden_states=cond).sample
        v_uncond = unet(x, t_input, encoder_hidden_states=uncond).sample
        v = v_uncond + cfg * (v_cond - v_uncond)

        # Euler step: SUBTRACT velocity (going from noise toward data)
        # x_{sigma - dt} = x_sigma - v * dt
        x = x - v * dt

        if (i + 1) % log_every == 0:
            print(f"  Step {i+1}/{steps}, sigma={sigma.item():.3f} → {sigma_next.item():.3f}")

    # Decode: undo the latent scaling (0.18215, presumably the SD1.5 VAE
    # scaling factor), map the decoder output from [-1, 1] to [0, 1], and
    # convert to an HWC uint8 image.
    x = x / 0.18215
    img = vae.decode(x).sample
    img = (img / 2 + 0.5).clamp(0, 1)[0].permute(1, 2, 0).cpu().float().numpy()
    return Image.fromarray((img * 255).astype(np.uint8))

# ============================================================================
# TEST
# ============================================================================
# Banner describing the convention the sampler implements.
banner = "=" * 60
print("\n" + banner)
print("Testing Lune with CORRECT flow convention")
print("  x_t = sigma*noise + (1-sigma)*data")
print("  v = noise - data")
print("  Sample by SUBTRACTING v")
print(banner)

from IPython.display import display

prompt = "a castle at sunset"

# Same prompt and seed for every run; only the schedule shift changes.
shift_runs = [
    ("\n--- shift=3.0 (default) ---", 3.0),
    ("\n--- shift=2.5 (trainer default) ---", 2.5),
    ("\n--- shift=1.0 (no shift) ---", 1.0),
]
rendered = []
for heading, shift_value in shift_runs:
    print(heading)
    picture = generate_lune(prompt, seed=42, steps=30, cfg=7.5, shift=shift_value)
    display(picture)
    rendered.append(picture)
img, img2, img3 = rendered

# Grid comparison
import matplotlib.pyplot as plt
fig, axes = plt.subplots(1, 3, figsize=(15, 5))
grid_shifts = (3.0, 2.5, 1.0)
grid_images = (img, img2, img3)
for axis, grid_shift, grid_image in zip(axes, grid_shifts, grid_images):
    axis.imshow(grid_image)
    axis.set_title(f"shift={grid_shift}")
    axis.axis('off')
plt.tight_layout()
plt.show()

print("\n✓ If images look correct, the output should be beautiful.")