AbstractPhil commited on
Commit
1d7a19e
·
verified ·
1 Parent(s): 78499ec

Create inference_sd15_flow_lune.py

Browse files
Files changed (1) hide show
  1. inference_sd15_flow_lune.py +195 -0
inference_sd15_flow_lune.py ADDED
@@ -0,0 +1,195 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
# ============================================================================
# SD1.5-Flow-Lune Inference - CORRECT (matches trainer)
# ============================================================================
# Trainer's flow convention:
#   x_t = sigma * noise + (1 - sigma) * data
#   target = noise - data (velocity points FROM data TO noise)
#   sigma=0 → clean, sigma=1 → noise
#
# Sampling: sigma goes 1 → 0, so we SUBTRACT velocity
#   x_{sigma - dt} = x_sigma - v * dt
# ============================================================================

# NOTE: IPython/Colab shell magic — this file is a notebook cell, not an
# importable module; keep it when running in a notebook environment.
!pip install -q diffusers transformers accelerate safetensors

import torch
import gc
from huggingface_hub import hf_hub_download
from diffusers import UNet2DConditionModel, AutoencoderKL
from transformers import CLIPTextModel, CLIPTokenizer
from safetensors.torch import load_file
from PIL import Image
import numpy as np
import json

# Free any leftover GPU memory from a previous run of this cell.
torch.cuda.empty_cache()
gc.collect()

# ============================================================================
# CONFIG
# ============================================================================
DEVICE = "cuda"
DTYPE = torch.float16

# Hugging Face repo holding the fine-tuned flow-matching UNet checkpoint.
LUNE_REPO = "AbstractPhil/sd15-flow-lune-flux"
# Paths inside the repo to the UNet weights and its config.json.
LUNE_WEIGHTS = "flux_t2_6_pose_t4_6_port_t1_4/checkpoint-00018765/unet/diffusion_pytorch_model.safetensors"
LUNE_CONFIG = "flux_t2_6_pose_t4_6_port_t1_4/checkpoint-00018765/unet/config.json"

# ============================================================================
# LOAD MODELS
# ============================================================================
# Text encoder: CLIP ViT-L/14, tokenized to a fixed 77-token context below.
print("Loading CLIP...")
clip_tok = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
clip_enc = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14", torch_dtype=DTYPE).to(DEVICE).eval()

# VAE: decodes the 4-channel latents produced by the sampler back to pixels.
print("Loading VAE...")
vae = AutoencoderKL.from_pretrained(
    "stable-diffusion-v1-5/stable-diffusion-v1-5",
    subfolder="vae",
    torch_dtype=DTYPE
).to(DEVICE).eval()

# ============================================================================
# LOAD LUNE
# ============================================================================
print(f"\nLoading Lune...")
# Fetch the checkpoint's UNet config and instantiate the architecture from it.
config_path = hf_hub_download(repo_id=LUNE_REPO, filename=LUNE_CONFIG)
with open(config_path, 'r') as f:
    lune_config = json.load(f)

print(f" prediction_type: {lune_config.get('prediction_type', 'NOT SET')}")

unet = UNet2DConditionModel.from_config(lune_config).to(DEVICE).to(DTYPE).eval()

# Load the fine-tuned weights from safetensors.
weights_path = hf_hub_download(repo_id=LUNE_REPO, filename=LUNE_WEIGHTS)
state_dict = load_file(weights_path)
# NOTE(review): strict=False silently ignores missing/unexpected keys —
# consider inspecting the returned result to confirm the checkpoint matches.
unet.load_state_dict(state_dict, strict=False)

# Drop the CPU copy of the weights once they live in the module.
del state_dict
gc.collect()

# Inference only: freeze all parameters.
for p in unet.parameters():
    p.requires_grad = False

print("✓ Lune ready!")

# ============================================================================
# HELPERS
# ============================================================================
def shift_sigma(sigma: torch.Tensor, shift: float = 3.0) -> torch.Tensor:
    """Apply the trainer's timestep shift to a sigma (noise-level) tensor.

    sigma_shifted = shift * sigma / (1 + (shift - 1) * sigma)

    The map fixes the endpoints (0 -> 0, 1 -> 1) for every shift value and
    only bends the interior of the schedule.
    """
    numerator = shift * sigma
    denominator = 1 + (shift - 1) * sigma
    return numerator / denominator

@torch.inference_mode()
def encode_prompt(prompt):
    """Tokenize *prompt* to a fixed 77-token sequence and return the CLIP
    encoder's last_hidden_state cast to DTYPE."""
    tokens = clip_tok(
        prompt,
        return_tensors="pt",
        padding="max_length",
        max_length=77,
        truncation=True,
    ).to(DEVICE)
    hidden = clip_enc(**tokens).last_hidden_state
    return hidden.to(DTYPE)

# ============================================================================
# CORRECT SAMPLER (matches trainer exactly)
# ============================================================================
@torch.inference_mode()
def generate_lune(
    prompt: str,
    negative_prompt: str = "",
    seed: int = 42,
    steps: int = 30,
    cfg: float = 7.5,
    shift: float = 3.0,
):
    """
    Correct Lune sampler matching trainer's flow convention.

    Trainer:
        x_t = sigma * noise + (1 - sigma) * data
        target = noise - data

    Sampling:
      - Start at sigma=1 (pure noise)
      - End at sigma=0 (clean data)
      - x_{sigma - dt} = x_sigma - v * dt (SUBTRACT because v points toward noise)

    Args:
        prompt: text condition.
        negative_prompt: CFG negative condition ("" falls back to the
            unconditional empty-prompt embedding).
        seed: RNG seed for the initial latent.
        steps: number of Euler steps (>= 1).
        cfg: classifier-free guidance scale.
        shift: timestep shift; 1.0 leaves the linear schedule unchanged.

    Returns:
        PIL.Image.Image decoded from the final latent.
    """
    torch.manual_seed(seed)

    cond = encode_prompt(prompt)
    uncond = encode_prompt(negative_prompt) if negative_prompt else encode_prompt("")

    # Start from pure noise (sigma=1) in the 4x64x64 latent space.
    x = torch.randn(1, 4, 64, 64, device=DEVICE, dtype=DTYPE)

    # Sigma schedule: 1 → 0 (noise → data). Linear spacing, then apply shift;
    # shift_sigma fixes the endpoints so we still start at 1 and end at 0.
    sigmas_linear = torch.linspace(1, 0, steps + 1, device=DEVICE)
    sigmas = shift_sigma(sigmas_linear, shift=shift)

    print(f"Lune: '{prompt[:30]}' | {steps} steps, cfg={cfg}, shift={shift}")
    print(f" sigma range: {sigmas[0].item():.3f} → {sigmas[-1].item():.3f}")

    # BUG FIX: the original used `(i + 1) % (steps // 5)`, which raises
    # ZeroDivisionError whenever steps < 5. Clamp the log interval to >= 1.
    log_every = max(1, steps // 5)

    for i in range(steps):
        sigma = sigmas[i]
        sigma_next = sigmas[i + 1]
        dt = sigma - sigma_next  # Positive, going from high to low sigma

        # Timestep for UNet: sigma * 1000 (matches trainer)
        timestep = sigma * 1000
        t_input = timestep.view(1).to(DEVICE)

        # Predict velocity v = noise - data, with classifier-free guidance.
        v_cond = unet(x, t_input, encoder_hidden_states=cond).sample
        v_uncond = unet(x, t_input, encoder_hidden_states=uncond).sample
        v = v_uncond + cfg * (v_cond - v_uncond)

        # Euler step: SUBTRACT velocity (going from noise toward data)
        # x_{sigma - dt} = x_sigma - v * dt
        x = x - v * dt

        if (i + 1) % log_every == 0:
            print(f" Step {i+1}/{steps}, sigma={sigma.item():.3f} → {sigma_next.item():.3f}")

    # Decode: undo the SD1.5 VAE latent scaling factor (0.18215), then map
    # the decoder output from [-1, 1] to an 8-bit RGB image.
    x = x / 0.18215
    img = vae.decode(x).sample
    img = (img / 2 + 0.5).clamp(0, 1)[0].permute(1, 2, 0).cpu().float().numpy()
    return Image.fromarray((img * 255).astype(np.uint8))

# ============================================================================
# TEST
# ============================================================================
print("\n" + "="*60)
print("Testing Lune with CORRECT flow convention")
print(" x_t = sigma*noise + (1-sigma)*data")
print(" v = noise - data")
print(" Sample by SUBTRACTING v")
print("="*60)

from IPython.display import display

prompt = "a castle at sunset"

# Same seed for all three runs so differences come only from the shift value.
print("\n--- shift=3.0 (default) ---")
img = generate_lune(prompt, seed=42, steps=30, cfg=7.5, shift=3.0)
display(img)

print("\n--- shift=2.5 (trainer default) ---")
img2 = generate_lune(prompt, seed=42, steps=30, cfg=7.5, shift=2.5)
display(img2)

print("\n--- shift=1.0 (no shift) ---")
img3 = generate_lune(prompt, seed=42, steps=30, cfg=7.5, shift=1.0)
display(img3)

# Grid comparison: the three shift settings side by side.
import matplotlib.pyplot as plt
fig, axes = plt.subplots(1, 3, figsize=(15, 5))
for ax, (s, im) in zip(axes, [(3.0, img), (2.5, img2), (1.0, img3)]):
    ax.imshow(im)
    ax.set_title(f"shift={s}")
    ax.axis('off')
plt.tight_layout()
plt.show()

print("\n✓ If images look correct, the output should be beautiful.")