AbstractPhil commited on
Commit
905499d
·
verified ·
1 Parent(s): bd0694a

Create inference_v3.py

Browse files
Files changed (1) hide show
  1. inference_v3.py +524 -0
inference_v3.py ADDED
@@ -0,0 +1,524 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # ============================================================================
2
+ # TinyFlux-Deep Inference Cell - With ExpertPredictor
3
+ # ============================================================================
4
+ # Run the model cell before this one (defines TinyFluxDeep, TinyFluxDeepConfig)
5
+ # Loads from: AbstractPhil/tiny-flux-deep or local checkpoint
6
+ #
7
+ # The ExpertPredictor runs standalone at inference - no SD1.5-flow needed.
8
+ # It predicts timestep expertise from (time_emb, clip_pooled).
9
+ # ============================================================================
10
+
11
+ import torch
12
+ from huggingface_hub import hf_hub_download
13
+ from safetensors.torch import load_file
14
+ from transformers import T5EncoderModel, T5Tokenizer, CLIPTextModel, CLIPTokenizer
15
+ from diffusers import AutoencoderKL
16
+ from PIL import Image
17
+ import numpy as np
18
+ import os
19
+
20
+ # ============================================================================
21
+ # CONFIG
22
+ # ============================================================================
23
+ DEVICE = "cuda"
24
+ DTYPE = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
25
+
26
+ # Model loading
27
+ HF_REPO = "AbstractPhil/tiny-flux-deep"
28
+ # stable v3 step_316875
29
+ LOAD_FROM = "hub:step_346875" # "hub", "hub:step_XXXXX", "hub:step_XXXXX_ema", "local:/path/to/weights.safetensors"
30
+
31
+ # Generation settings
32
+ NUM_STEPS = 50
33
+ GUIDANCE_SCALE = 5.0 # Note: this is now just for CFG, not the broken guidance_in
34
+ HEIGHT = 512
35
+ WIDTH = 512
36
+ SEED = None
37
+ SHIFT = 3.0
38
+
39
+ # Model architecture (must match training)
40
+ USE_EXPERT_PREDICTOR = True
41
+ EXPERT_DIM = 1280
42
+ EXPERT_HIDDEN_DIM = 512
43
+
44
+ # ============================================================================
45
+ # LOAD TEXT ENCODERS
46
+ # ============================================================================
47
+ print("Loading text encoders...")
48
+ t5_tok = T5Tokenizer.from_pretrained("google/flan-t5-base")
49
+ t5_enc = T5EncoderModel.from_pretrained("google/flan-t5-base", torch_dtype=DTYPE).to(DEVICE).eval()
50
+
51
+ clip_tok = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
52
+ clip_enc = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14", torch_dtype=DTYPE).to(DEVICE).eval()
53
+
54
+ # ============================================================================
55
+ # LOAD VAE
56
+ # ============================================================================
57
+ print("Loading Flux VAE...")
58
+ vae = AutoencoderKL.from_pretrained(
59
+ "black-forest-labs/FLUX.1-schnell",
60
+ subfolder="vae",
61
+ torch_dtype=DTYPE
62
+ ).to(DEVICE).eval()
63
+
64
+ # ============================================================================
65
+ # LOAD TINYFLUX-DEEP MODEL
66
+ # ============================================================================
67
+ print(f"Loading TinyFlux-Deep from: {LOAD_FROM}")
68
+
69
+ # Config with ExpertPredictor (no guidance_embeds)
70
+ config = TinyFluxDeepConfig(
71
+ use_expert_predictor=USE_EXPERT_PREDICTOR,
72
+ expert_dim=EXPERT_DIM,
73
+ expert_hidden_dim=EXPERT_HIDDEN_DIM,
74
+ guidance_embeds=False, # Replaced by expert_predictor
75
+ )
76
+ model = TinyFluxDeep(config).to(DEVICE).to(DTYPE)
77
+
78
# Keys to handle during loading.
# These modules existed in older checkpoints but were removed from the
# current architecture (guidance_in was replaced by the ExpertPredictor).
DEPRECATED_KEYS = {
    'time_in.sin_basis',
    'guidance_in.sin_basis',
    'guidance_in.mlp.0.weight',
    'guidance_in.mlp.0.bias',
    'guidance_in.mlp.2.weight',
    'guidance_in.mlp.2.bias',
}


def load_weights(path):
    """Load a state dict from a .safetensors or .pt file.

    Handles raw state dicts as well as trainer checkpoints that nest the
    weights under a "model" or "state_dict" key, and strips the
    "_orig_mod." prefix that torch.compile adds to parameter names.

    Args:
        path: Filesystem path to the weights file.

    Returns:
        dict mapping parameter names to tensors.
    """
    if path.endswith(".safetensors"):
        state_dict = load_file(path)
    elif path.endswith(".pt"):
        # SECURITY NOTE: weights_only=False deserializes arbitrary pickled
        # objects; only load checkpoints from trusted sources.
        ckpt = torch.load(path, map_location=DEVICE, weights_only=False)
        if isinstance(ckpt, dict):
            if "model" in ckpt:
                state_dict = ckpt["model"]
            elif "state_dict" in ckpt:
                state_dict = ckpt["state_dict"]
            else:
                state_dict = ckpt
        else:
            state_dict = ckpt
    else:
        # Unknown extension: try safetensors first, then fall back to
        # torch.load. Narrow except (was a bare `except:`) so interrupts
        # and system exits still propagate.
        try:
            state_dict = load_file(path)
        except Exception:
            state_dict = torch.load(path, map_location=DEVICE, weights_only=False)

    # Strip "_orig_mod." prefix from keys (added by torch.compile)
    if any(k.startswith("_orig_mod.") for k in state_dict.keys()):
        print("  Stripping torch.compile prefix...")
        state_dict = {k.replace("_orig_mod.", ""): v for k, v in state_dict.items()}

    return state_dict


def load_model_weights(model, weights, source_name):
    """Load `weights` into `model`, tolerating architecture changes.

    Deprecated keys (see DEPRECATED_KEYS / guidance_in.*) and keys unknown
    to the model are skipped; shape mismatches are skipped and reported;
    model parameters absent from the checkpoint keep their fresh init.

    Args:
        model: nn.Module to load into (mutated in place).
        weights: checkpoint state dict.
        source_name: human-readable origin, used only for logging.
    """
    model_state = model.state_dict()

    loaded = []
    skipped_deprecated = []
    skipped_shape = []
    missing_new = []

    # Load matching weights
    for k, v in weights.items():
        if k in DEPRECATED_KEYS or k.startswith('guidance_in.'):
            skipped_deprecated.append(k)
        elif k in model_state:
            if v.shape == model_state[k].shape:
                model_state[k] = v
                loaded.append(k)
            else:
                skipped_shape.append((k, v.shape, model_state[k].shape))
        else:
            # Key not in model (maybe old architecture)
            skipped_deprecated.append(k)

    # Find new keys not in checkpoint. Hoist the deprecated-module prefix
    # set out of the loop (was recomputed per key).
    deprecated_prefixes = {d.split('.')[0] for d in DEPRECATED_KEYS if '.' in d}
    for k in model_state:
        if k not in weights and not any(k.startswith(p) for p in deprecated_prefixes):
            missing_new.append(k)

    # Apply loaded weights; strict=False so missing/new modules keep their
    # fresh initialization.
    model.load_state_dict(model_state, strict=False)

    # Report
    print(f"  ✓ Loaded: {len(loaded)} weights")
    if skipped_deprecated:
        print(f"  ✓ Skipped deprecated: {len(skipped_deprecated)} (guidance_in, etc)")
    if skipped_shape:
        print(f"  ⚠ Shape mismatch: {len(skipped_shape)}")
        for k, old, new in skipped_shape[:3]:
            print(f"    {k}: {old} vs {new}")
    if missing_new:
        # Group by module
        modules = set(k.split('.')[0] for k in missing_new)
        print(f"  ℹ New modules (fresh init): {modules}")

    print(f"✓ Loaded from {source_name}")
163
+
164
+
165
if LOAD_FROM == "hub":
    # Release weights at the repo root; prefer safetensors, fall back to .pt.
    # Narrowed from a bare `except:` so Ctrl-C still interrupts the download.
    try:
        weights_path = hf_hub_download(repo_id=HF_REPO, filename="model.safetensors")
    except Exception:
        weights_path = hf_hub_download(repo_id=HF_REPO, filename="model.pt")
    weights = load_weights(weights_path)
    load_model_weights(model, weights, HF_REPO)

elif LOAD_FROM.startswith("hub:"):
    # A specific checkpoint, e.g. "hub:step_346875" ->
    # checkpoints/step_346875.safetensors (then .pt, then extensionless).
    ckpt_name = LOAD_FROM[4:]
    if ckpt_name.endswith((".safetensors", ".pt")):
        # Full filename given: exactly one candidate (previously the same
        # file was retried three times). Prefix with checkpoints/ unless a
        # repo-relative path was supplied.
        candidates = [ckpt_name if "/" in ckpt_name else f"checkpoints/{ckpt_name}"]
    else:
        candidates = [f"checkpoints/{ckpt_name}{ext}" for ext in (".safetensors", ".pt", "")]

    for filename in candidates:
        try:
            weights_path = hf_hub_download(repo_id=HF_REPO, filename=filename)
        except Exception:
            continue  # file not present under this name; try next candidate
        # Deliberately outside the try: a corrupt/incompatible checkpoint
        # should fail loudly, not be silently skipped.
        weights = load_weights(weights_path)
        load_model_weights(model, weights, f"{HF_REPO}/{filename}")
        break
    else:
        raise ValueError(f"Could not find checkpoint: {ckpt_name}")

elif LOAD_FROM.startswith("local:"):
    weights_path = LOAD_FROM[6:]
    weights = load_weights(weights_path)
    load_model_weights(model, weights, weights_path)

else:
    raise ValueError(f"Unknown LOAD_FROM: {LOAD_FROM}")

model.eval()

# Count parameters (expert_predictor reported separately for visibility).
total_params = sum(p.numel() for p in model.parameters())
expert_params = sum(p.numel() for p in model.expert_predictor.parameters()) if model.expert_predictor is not None else 0
print(f"Model params: {total_params:,} (expert_predictor: {expert_params:,})")
204
+
205
+ # ============================================================================
206
+ # ENCODING FUNCTIONS
207
+ # ============================================================================
208
+ @torch.inference_mode()
209
+ def encode_prompt(prompt: str, max_length: int = 128):
210
+ """Encode prompt with flan-t5-base and CLIP-L."""
211
+ t5_in = t5_tok(
212
+ prompt,
213
+ max_length=max_length,
214
+ padding="max_length",
215
+ truncation=True,
216
+ return_tensors="pt"
217
+ ).to(DEVICE)
218
+ t5_out = t5_enc(
219
+ input_ids=t5_in.input_ids,
220
+ attention_mask=t5_in.attention_mask
221
+ ).last_hidden_state
222
+
223
+ clip_in = clip_tok(
224
+ prompt,
225
+ max_length=77,
226
+ padding="max_length",
227
+ truncation=True,
228
+ return_tensors="pt"
229
+ ).to(DEVICE)
230
+ clip_out = clip_enc(
231
+ input_ids=clip_in.input_ids,
232
+ attention_mask=clip_in.attention_mask
233
+ )
234
+ clip_pooled = clip_out.pooler_output
235
+
236
+ return t5_out.to(DTYPE), clip_pooled.to(DTYPE)
237
+
238
+
239
+ # ============================================================================
240
+ # FLOW MATCHING HELPERS
241
+ # ============================================================================
242
+ def flux_shift(t, s=SHIFT):
243
+ """Flux timestep shift - biases towards higher t (closer to data)."""
244
+ return s * t / (1 + (s - 1) * t)
245
+
246
+
247
+ # ============================================================================
248
+ # EULER DISCRETE FLOW MATCHING SAMPLER
249
+ # ============================================================================
250
+ @torch.inference_mode()
251
+ def euler_sample(
252
+ model,
253
+ prompt: str,
254
+ negative_prompt: str = "",
255
+ num_steps: int = 28,
256
+ guidance_scale: float = 3.5,
257
+ height: int = 512,
258
+ width: int = 512,
259
+ seed: int = None,
260
+ ):
261
+ """
262
+ Euler discrete sampler for rectified flow matching.
263
+
264
+ Flow Matching formulation:
265
+ x_t = (1 - t) * noise + t * data
266
+ At t=0: noise, At t=1: data
267
+ Velocity v = data - noise (constant)
268
+
269
+ Sampling: Integrate from t=0 (noise) to t=1 (data)
270
+
271
+ With ExpertPredictor:
272
+ - No guidance embedding needed
273
+ - Expert predictor runs internally from (time_emb, clip_pooled)
274
+ - CFG still works via positive/negative prompt difference
275
+ """
276
+ if seed is not None:
277
+ torch.manual_seed(seed)
278
+ generator = torch.Generator(device=DEVICE).manual_seed(seed)
279
+ else:
280
+ generator = None
281
+
282
+ H_lat = height // 8
283
+ W_lat = width // 8
284
+ C_lat = 16
285
+
286
+ # Encode prompts
287
+ t5_cond, clip_cond = encode_prompt(prompt)
288
+ if guidance_scale > 1.0 and negative_prompt is not None:
289
+ t5_uncond, clip_uncond = encode_prompt(negative_prompt)
290
+ else:
291
+ t5_uncond, clip_uncond = None, None
292
+
293
+ # Start from pure noise (t=0)
294
+ x = torch.randn(1, H_lat * W_lat, C_lat, device=DEVICE, dtype=DTYPE, generator=generator)
295
+
296
+ # Create image position IDs
297
+ img_ids = TinyFluxDeep.create_img_ids(1, H_lat, W_lat, DEVICE)
298
+
299
+ # Timesteps: 0 → 1 with flux shift
300
+ t_linear = torch.linspace(0, 1, num_steps + 1, device=DEVICE, dtype=DTYPE)
301
+ timesteps = flux_shift(t_linear, s=SHIFT)
302
+
303
+ print(f"Sampling with {num_steps} Euler steps (t: 0→1, shifted)...")
304
+
305
+ for i in range(num_steps):
306
+ t_curr = timesteps[i]
307
+ t_next = timesteps[i + 1]
308
+ dt = t_next - t_curr
309
+
310
+ t_batch = t_curr.unsqueeze(0)
311
+
312
+ # Predict velocity (no guidance embedding, expert_predictor runs internally)
313
+ v_cond = model(
314
+ hidden_states=x,
315
+ encoder_hidden_states=t5_cond,
316
+ pooled_projections=clip_cond,
317
+ timestep=t_batch,
318
+ img_ids=img_ids,
319
+ # No guidance parameter - ExpertPredictor handles timestep awareness
320
+ # No expert_features - predictor runs standalone at inference
321
+ )
322
+
323
+ # Classifier-free guidance (true CFG via prompt difference)
324
+ if guidance_scale > 1.0 and t5_uncond is not None:
325
+ v_uncond = model(
326
+ hidden_states=x,
327
+ encoder_hidden_states=t5_uncond,
328
+ pooled_projections=clip_uncond,
329
+ timestep=t_batch,
330
+ img_ids=img_ids,
331
+ )
332
+ v = v_uncond + guidance_scale * (v_cond - v_uncond)
333
+ else:
334
+ v = v_cond
335
+
336
+ # Euler step: x_{t+dt} = x_t + v * dt
337
+ x = x + v * dt
338
+
339
+ if (i + 1) % max(1, num_steps // 5) == 0 or i == num_steps - 1:
340
+ print(f" Step {i+1}/{num_steps}, t={t_next.item():.3f}")
341
+
342
+ # Reshape: (1, H*W, C) -> (1, C, H, W)
343
+ latents = x.reshape(1, H_lat, W_lat, C_lat).permute(0, 3, 1, 2)
344
+
345
+ return latents
346
+
347
+
348
+ # ============================================================================
349
+ # DECODE LATENTS TO IMAGE
350
+ # ============================================================================
351
+ @torch.inference_mode()
352
+ def decode_latents(latents):
353
+ """Decode VAE latents to PIL Image."""
354
+ latents = latents / vae.config.scaling_factor
355
+ image = vae.decode(latents.to(vae.dtype)).sample
356
+ image = (image / 2 + 0.5).clamp(0, 1)
357
+ image = image[0].float().permute(1, 2, 0).cpu().numpy()
358
+ image = (image * 255).astype(np.uint8)
359
+ return Image.fromarray(image)
360
+
361
+
362
+ # ============================================================================
363
+ # MAIN GENERATION FUNCTION
364
+ # ============================================================================
365
+ def generate(
366
+ prompt: str,
367
+ negative_prompt: str = "",
368
+ num_steps: int = NUM_STEPS,
369
+ guidance_scale: float = GUIDANCE_SCALE,
370
+ height: int = HEIGHT,
371
+ width: int = WIDTH,
372
+ seed: int = SEED,
373
+ save_path: str = None,
374
+ ):
375
+ """
376
+ Generate an image from a text prompt.
377
+
378
+ Args:
379
+ prompt: Text description of desired image
380
+ negative_prompt: What to avoid (empty string for none)
381
+ num_steps: Number of Euler steps (20-50 recommended)
382
+ guidance_scale: CFG scale (1.0=none, 3-7 typical)
383
+ height: Output height in pixels (divisible by 8)
384
+ width: Output width in pixels (divisible by 8)
385
+ seed: Random seed (None for random)
386
+ save_path: Path to save image (None to skip)
387
+
388
+ Returns:
389
+ PIL.Image
390
+ """
391
+ print(f"\nGenerating: '{prompt}'")
392
+ print(f"Settings: {num_steps} steps, cfg={guidance_scale}, {width}x{height}, seed={seed}")
393
+
394
+ latents = euler_sample(
395
+ model=model,
396
+ prompt=prompt,
397
+ negative_prompt=negative_prompt,
398
+ num_steps=num_steps,
399
+ guidance_scale=guidance_scale,
400
+ height=height,
401
+ width=width,
402
+ seed=seed,
403
+ )
404
+
405
+ print("Decoding latents...")
406
+ image = decode_latents(latents)
407
+
408
+ if save_path:
409
+ image.save(save_path)
410
+ print(f"✓ Saved to {save_path}")
411
+
412
+ print("✓ Done!")
413
+ return image
414
+
415
+
416
+ # ============================================================================
417
+ # BATCH GENERATION
418
+ # ============================================================================
419
+ def generate_batch(
420
+ prompts: list,
421
+ negative_prompt: str = "",
422
+ num_steps: int = NUM_STEPS,
423
+ guidance_scale: float = GUIDANCE_SCALE,
424
+ height: int = HEIGHT,
425
+ width: int = WIDTH,
426
+ seed: int = SEED,
427
+ output_dir: str = "./outputs",
428
+ ):
429
+ """Generate multiple images."""
430
+ os.makedirs(output_dir, exist_ok=True)
431
+ images = []
432
+
433
+ for i, prompt in enumerate(prompts):
434
+ img_seed = seed + i if seed is not None else None
435
+ image = generate(
436
+ prompt=prompt,
437
+ negative_prompt=negative_prompt,
438
+ num_steps=num_steps,
439
+ guidance_scale=guidance_scale,
440
+ height=height,
441
+ width=width,
442
+ seed=img_seed,
443
+ save_path=os.path.join(output_dir, f"{i:03d}.png"),
444
+ )
445
+ images.append(image)
446
+
447
+ return images
448
+
449
+
450
+ # ============================================================================
451
+ # COMPARISON FUNCTION (old vs new model)
452
+ # ============================================================================
453
+ def compare_with_without_expert(
454
+ prompt: str,
455
+ negative_prompt: str = "",
456
+ num_steps: int = 30,
457
+ guidance_scale: float = 5.0,
458
+ seed: int = 42,
459
+ save_prefix: str = "compare",
460
+ ):
461
+ """
462
+ Generate same prompt with expert_predictor enabled vs disabled.
463
+ Useful for A/B testing the effect of the distilled expert.
464
+ """
465
+ # With expert
466
+ image_with = generate(
467
+ prompt=prompt,
468
+ negative_prompt=negative_prompt,
469
+ num_steps=num_steps,
470
+ guidance_scale=guidance_scale,
471
+ seed=seed,
472
+ save_path=f"{save_prefix}_with_expert.png",
473
+ )
474
+
475
+ # Without expert (temporarily disable)
476
+ old_predictor = model.expert_predictor
477
+ model.expert_predictor = None
478
+
479
+ image_without = generate(
480
+ prompt=prompt,
481
+ negative_prompt=negative_prompt,
482
+ num_steps=num_steps,
483
+ guidance_scale=guidance_scale,
484
+ seed=seed,
485
+ save_path=f"{save_prefix}_without_expert.png",
486
+ )
487
+
488
+ # Restore
489
+ model.expert_predictor = old_predictor
490
+
491
+ # Side by side
492
+ combined = Image.new('RGB', (image_with.width * 2, image_with.height))
493
+ combined.paste(image_without, (0, 0))
494
+ combined.paste(image_with, (image_with.width, 0))
495
+ combined.save(f"{save_prefix}_comparison.png")
496
+
497
+ print(f"\n✓ Comparison saved: {save_prefix}_comparison.png")
498
+ print(f" Left: without expert | Right: with expert")
499
+
500
+ return image_without, image_with, combined
501
+
502
+
503
+ # ============================================================================
504
+ # QUICK TEST
505
+ # ============================================================================
506
+ print("\n" + "="*60)
507
+ print("TinyFlux-Deep + ExpertPredictor Inference Ready!")
508
+ print("="*60)
509
+ print(f"Config: {config.hidden_size} hidden, {config.num_attention_heads} heads")
510
+ print(f" {config.num_double_layers} double, {config.num_single_layers} single layers")
511
+ print(f" ExpertPredictor: {config.use_expert_predictor} (dim={config.expert_dim})")
512
+ print(f"Total: {total_params:,} parameters")
513
+
514
+ # Example usage:
515
+ image = generate(
516
+ prompt="subject, animal, feline, lion, natural habitat",
517
+ negative_prompt="",
518
+ num_steps=50,
519
+ guidance_scale=5.0,
520
+ seed=4545,
521
+ width=512,
522
+ height=512,
523
+ )
524
+ image