AbstractPhil committed on
Commit
77fceae
·
verified ·
1 Parent(s): 64b8c56

Create sd15_flow_sol_ddpm_inference

Browse files
Files changed (1) hide show
  1. sd15_flow_sol_ddpm_inference +178 -0
sd15_flow_sol_ddpm_inference ADDED
@@ -0,0 +1,178 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# ============================================================================
# SD1.5-Flow-Sol Correct Inference (Colab Cell)
# ============================================================================
# Matches trainer's sample() method exactly:
# - DDPM scheduler timesteps
# - Specifically aligned for the SOL training pipeline to ensure accurate inference.
# - Model predicts velocity
# - Convert velocity → epsilon for scheduler stepping
# ============================================================================

# Notebook-only magic: installs runtime deps quietly. Must run in Jupyter/Colab.
!pip install -q diffusers transformers accelerate safetensors

import torch
import gc
from huggingface_hub import hf_hub_download
from diffusers import UNet2DConditionModel, AutoencoderKL, DDPMScheduler
from transformers import CLIPTextModel, CLIPTokenizer
from PIL import Image
import numpy as np

# Release any GPU memory held over from a previous notebook cell before
# loading three fp16 models onto the device.
torch.cuda.empty_cache()
gc.collect()

# ============================================================================
# CONFIG
# ============================================================================
DEVICE = "cuda"        # this cell assumes a CUDA runtime is available
DTYPE = torch.float16  # half precision for CLIP/VAE/UNet forward passes

# Hugging Face repo + filename holding the fine-tuned ("Sol") UNet checkpoint.
SOL_REPO = "AbstractPhil/sd15-flow-matching"
SOL_FILENAME = "sd15_flowmatch_david_weighted_efinal.pt"

# ============================================================================
# LOAD MODELS
# ============================================================================
print("Loading CLIP...")
clip_tok = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
clip_enc = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14", torch_dtype=DTYPE).to(DEVICE).eval()

print("Loading VAE...")
vae = AutoencoderKL.from_pretrained(
    "stable-diffusion-v1-5/stable-diffusion-v1-5",
    subfolder="vae",
    torch_dtype=DTYPE
).to(DEVICE).eval()

print("Loading UNet...")
unet = UNet2DConditionModel.from_pretrained(
    "stable-diffusion-v1-5/stable-diffusion-v1-5",
    subfolder="unet",
    torch_dtype=DTYPE,
).to(DEVICE).eval()

print("Loading DDPM Scheduler...")
# NOTE(review): constructed with default betas and prediction_type="epsilon";
# assumes the trainer used the same DDPMScheduler defaults — confirm against
# the training config, since a mismatched beta schedule silently degrades output.
sched = DDPMScheduler(num_train_timesteps=1000)

# ============================================================================
# LOAD SOL WEIGHTS
# ============================================================================
print(f"\nLoading Sol from {SOL_REPO}...")
weights_path = hf_hub_download(repo_id=SOL_REPO, filename=SOL_FILENAME)
# NOTE(review): torch.load without weights_only=True unpickles arbitrary
# objects from the downloaded file — only run against a trusted repo.
checkpoint = torch.load(weights_path, map_location="cpu")

# Checkpoint layout: {"student": state_dict, "gstep": int, ...} per the keys
# accessed below; "student" is the distilled UNet's weights.
state_dict = checkpoint["student"]
print(f" gstep: {checkpoint.get('gstep', 'unknown')}")

# If the checkpoint was saved from a wrapper module, keys carry a "unet."
# prefix — keep only those keys and strip the prefix.
# NOTE(review): str.replace removes every "unet." occurrence in a key, not
# just the leading one; k.removeprefix("unet.") would be strictly safer.
if any(k.startswith("unet.") for k in state_dict.keys()):
    state_dict = {k.replace("unet.", ""): v for k, v in state_dict.items() if k.startswith("unet.")}

# Drop trainer-only auxiliary modules that have no counterpart in the UNet.
state_dict = {k: v for k, v in state_dict.items() if not k.startswith(("hooks.", "local_heads."))}

# strict=False tolerates residual missing/unexpected keys after the filtering
# above; the counts printed below are the sanity check.
missing, unexpected = unet.load_state_dict(state_dict, strict=False)
print(f" Loaded: {len(state_dict)} keys, missing: {len(missing)}, unexpected: {len(unexpected)}")

del checkpoint, state_dict
gc.collect()

# Inference-only cell: freeze all UNet parameters.
for p in unet.parameters():
    p.requires_grad = False

print("✓ Sol ready!")
82
+
83
# ============================================================================
# HELPER: Alpha/Sigma from DDPM schedule (matches trainer)
# ============================================================================
def alpha_sigma(t: torch.LongTensor):
    """Look up the DDPM signal/noise coefficients for timestep batch *t*.

    Returns (alpha, sigma) as float32 tensors broadcastable over a latent
    batch, where alpha = sqrt(alphas_cumprod[t]) and
    sigma = sqrt(1 - alphas_cumprod[t]) — the same lookup the trainer uses.
    """
    cumprod = sched.alphas_cumprod.to(DEVICE)
    ac = cumprod[t]
    bshape = (-1, 1, 1, 1)  # one coefficient per batch element, broadcast over CHW
    alpha = torch.sqrt(ac).reshape(bshape).float()
    sigma = torch.sqrt(1.0 - ac).reshape(bshape).float()
    return alpha, sigma
92
+
93
# ============================================================================
# CORRECT SAMPLER (matches trainer's sample() method)
# ============================================================================
@torch.inference_mode()
def generate_sol(prompt, negative_prompt="", seed=42, steps=30, cfg=7.5):
    """Sample one 512x512 image from the Sol UNet, mirroring the trainer.

    Procedure (identical to the trainer's sample()):
      1. Walk DDPM scheduler timesteps.
      2. The UNet predicts velocity v (not epsilon).
      3. Convert v -> x0_hat -> eps_hat via the alpha/sigma schedule.
      4. Advance with sched.step(eps_hat, t, x_t).

    Returns a PIL.Image decoded through the SD1.5 VAE.
    """
    def _encode(text):
        # Tokenize to a fixed 77-token window and run CLIP; cast to fp16.
        tokens = clip_tok(text, return_tensors="pt", padding="max_length", max_length=77, truncation=True).to(DEVICE)
        return clip_enc(**tokens).last_hidden_state.to(DTYPE)

    if seed is not None:
        torch.manual_seed(seed)

    # Conditional and unconditional text embeddings for CFG.
    cond = _encode(prompt)
    uncond = _encode(negative_prompt)

    # Configure the scheduler's timestep trajectory for this run.
    sched.set_timesteps(steps, device=DEVICE)

    # Initial latent: pure Gaussian noise at SD1.5's 64x64x4 latent shape.
    x_t = torch.randn(1, 4, 64, 64, device=DEVICE, dtype=DTYPE)

    print(f"Sampling '{prompt[:40]}' | {steps} steps, cfg={cfg}")

    for step_idx, t_scalar in enumerate(sched.timesteps):
        t = torch.full((1,), t_scalar, device=DEVICE, dtype=torch.long)

        # Two UNet passes — the model outputs VELOCITY, not epsilon.
        latent_in = x_t.to(DTYPE)
        v_cond = unet(latent_in, t, encoder_hidden_states=cond).sample
        v_uncond = unet(latent_in, t, encoder_hidden_states=uncond).sample

        # Classifier-free guidance applied in velocity space.
        v_hat = v_uncond + cfg * (v_cond - v_uncond)

        # Velocity -> epsilon conversion, in float32, exactly as the trainer:
        #   v   = alpha * eps - sigma * x0
        #   x_t = alpha * x0  + sigma * eps
        #   x0  = (alpha * x_t - sigma * v) / (alpha^2 + sigma^2)
        #   eps = (x_t - alpha * x0) / sigma
        alpha, sigma = alpha_sigma(t)
        x_f = x_t.float()
        denom = alpha ** 2 + sigma ** 2  # ==1 for DDPM; guard kept to match trainer
        x0_hat = (alpha * x_f - sigma * v_hat.float()) / (denom + 1e-8)
        eps_hat = (x_f - alpha * x0_hat) / (sigma + 1e-8)

        # Epsilon drives the scheduler update (its default prediction type).
        x_t = sched.step(eps_hat, t_scalar, x_f).prev_sample.to(DTYPE)

        if (step_idx + 1) % max(1, steps // 5) == 0:
            print(f" Step {step_idx + 1}/{steps}, t={t_scalar}")

    # VAE decode: undo the SD latent scaling, map [-1, 1] -> [0, 255] RGB.
    latents = x_t / 0.18215
    decoded = vae.decode(latents).sample
    pixels = (decoded / 2 + 0.5).clamp(0, 1)[0].permute(1, 2, 0).cpu().float().numpy()

    return Image.fromarray((pixels * 255).astype(np.uint8))
157
+
158
# ============================================================================
# TEST
# ============================================================================
banner = "=" * 60
print("\n" + banner)
print("Generating test images with Sol (correct sampler)")
print(banner)

from IPython.display import display

# A few varied prompts to eyeball composition, portrait, and low-light output.
prompts = [
    "a castle at sunset",
    "a portrait of a woman",
    "a city street at night",
]

for prompt in prompts:
    print()
    # Fixed seed/steps/cfg so reruns are comparable across checkpoints.
    display(generate_sol(prompt, negative_prompt="", seed=42, steps=30, cfg=7.5))

print("\n✓ Done!")