AbstractPhil committed on
Commit
e9ef39b
·
verified ·
1 Parent(s): 27fa1b7

Update inference_colab.py

Browse files
Files changed (1) hide show
  1. inference_colab.py +149 -61
inference_colab.py CHANGED
@@ -59,25 +59,68 @@ print(f"Loading TinyFlux from: {LOAD_FROM}")
59
  config = TinyFluxConfig()
60
  model = TinyFlux(config).to(DEVICE).to(DTYPE)
61
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
62
  if LOAD_FROM == "hub":
63
- # Load best model from hub
64
- weights_path = hf_hub_download(repo_id=HF_REPO, filename="model.safetensors")
65
- weights = load_file(weights_path)
 
 
 
66
  model.load_state_dict(weights)
67
- print(f"✓ Loaded from {HF_REPO}/model.safetensors")
68
  elif LOAD_FROM.startswith("hub:"):
69
  # Load specific checkpoint from hub
70
  ckpt_name = LOAD_FROM[4:]
71
- if not ckpt_name.endswith(".safetensors"):
72
- ckpt_name = f"checkpoints/{ckpt_name}.safetensors"
73
- weights_path = hf_hub_download(repo_id=HF_REPO, filename=ckpt_name)
74
- weights = load_file(weights_path)
75
- model.load_state_dict(weights)
76
- print(f"✓ Loaded from {HF_REPO}/{ckpt_name}")
 
 
 
 
 
 
 
 
 
 
77
  elif LOAD_FROM.startswith("local:"):
78
  # Load local file
79
  weights_path = LOAD_FROM[6:]
80
- weights = load_file(weights_path)
81
  model.load_state_dict(weights)
82
  print(f"✓ Loaded from {weights_path}")
83
  else:
@@ -121,6 +164,15 @@ def encode_prompt(prompt: str, max_length: int = 128):
121
 
122
  return t5_out, clip_pooled
123
 
 
 
 
 
 
 
 
 
 
124
  # ============================================================================
125
  # EULER DISCRETE FLOW MATCHING SAMPLER
126
  # ============================================================================
@@ -134,19 +186,20 @@ def euler_sample(
134
  height: int = 512,
135
  width: int = 512,
136
  seed: int = None,
 
 
137
  ):
138
  """
139
  Euler discrete sampler for flow matching.
140
 
141
- Flow matching formulation:
142
- x_t = (1 - t) * x_0 + t * x_1
143
- where x_0 = noise, x_1 = data
144
- velocity v = x_1 - x_0 = data - noise
145
-
146
- Sampling (t: 0 -> 1, noise -> data):
147
- x_{t+dt} = x_t + v_pred * dt
148
 
149
- With Flux shift for improved sampling distribution.
 
 
 
150
  """
151
  # Set seed
152
  if seed is not None:
@@ -156,42 +209,54 @@ def euler_sample(
156
  generator = None
157
 
158
  # Latent dimensions (VAE downscales by 8)
159
- H_lat = height // 8 # 64 for 512
160
- W_lat = width // 8 # 64 for 512
161
- C_lat = 16 # Flux VAE channels
162
 
163
- # Encode prompts
164
  t5_cond, clip_cond = encode_prompt(prompt)
 
 
165
  if guidance_scale > 1.0 and negative_prompt is not None:
166
  t5_uncond, clip_uncond = encode_prompt(negative_prompt)
 
 
167
  else:
168
  t5_uncond, clip_uncond = None, None
169
 
170
- # Start from pure noise (t=0 in flow matching convention)
171
- # Shape: (1, H*W, C)
172
  x = torch.randn(1, H_lat * W_lat, C_lat, device=DEVICE, dtype=DTYPE, generator=generator)
173
 
174
  # Create image position IDs for RoPE
175
  img_ids = TinyFlux.create_img_ids(1, H_lat, W_lat, DEVICE)
176
 
177
- # Timesteps: 0 -> 1 (noise -> data)
178
- # We use uniform spacing, model handles flux shift internally for training
179
- # For inference, linear timesteps work well
180
- timesteps = torch.linspace(0, 1, num_steps + 1, device=DEVICE, dtype=DTYPE)
 
 
 
 
 
 
 
 
 
 
 
181
 
182
- print(f"Sampling with {num_steps} Euler steps...")
183
 
184
  for i in range(num_steps):
185
  t_curr = timesteps[i]
186
  t_next = timesteps[i + 1]
187
  dt = t_next - t_curr
188
 
189
- t_batch = t_curr.unsqueeze(0) # (1,)
190
-
191
- # Guidance embedding (used during training with random values 1-5)
192
  guidance_embed = torch.tensor([guidance_scale], device=DEVICE, dtype=DTYPE)
193
 
194
- # Conditional prediction
195
  v_cond = model(
196
  hidden_states=x,
197
  encoder_hidden_states=t5_cond,
@@ -211,14 +276,16 @@ def euler_sample(
211
  img_ids=img_ids,
212
  guidance=guidance_embed,
213
  )
 
214
  v = v_uncond + guidance_scale * (v_cond - v_uncond)
215
  else:
216
  v = v_cond
217
 
218
- # Euler step: x_{t+dt} = x_t + v * dt
 
219
  x = x + v * dt
220
 
221
- if (i + 1) % 5 == 0 or i == num_steps - 1:
222
  print(f" Step {i+1}/{num_steps}, t={t_next.item():.3f}")
223
 
224
  # Reshape to image format: (1, H*W, C) -> (1, C, H, W)
@@ -235,14 +302,14 @@ def decode_latents(latents):
235
  # Flux VAE scaling
236
  latents = latents / vae.config.scaling_factor
237
 
238
- # Decode
239
- image = vae.decode(latents.float()).sample
240
 
241
  # Normalize to [0, 1]
242
  image = (image / 2 + 0.5).clamp(0, 1)
243
 
244
- # To PIL
245
- image = image[0].permute(1, 2, 0).cpu().numpy()
246
  image = (image * 255).astype(np.uint8)
247
 
248
  return Image.fromarray(image)
@@ -259,6 +326,8 @@ def generate(
259
  width: int = WIDTH,
260
  seed: int = SEED,
261
  save_path: str = None,
 
 
262
  ):
263
  """
264
  Generate an image from a text prompt.
@@ -272,14 +341,16 @@ def generate(
272
  width: Output width in pixels (must be divisible by 8)
273
  seed: Random seed (None for random)
274
  save_path: Path to save image (None to skip saving)
 
 
275
 
276
  Returns:
277
  PIL.Image
278
  """
279
  print(f"\nGenerating: '{prompt}'")
280
- print(f"Settings: {num_steps} steps, cfg={guidance_scale}, {width}x{height}, seed={seed}")
281
 
282
- # Sample latents
283
  latents = euler_sample(
284
  model=model,
285
  prompt=prompt,
@@ -289,6 +360,8 @@ def generate(
289
  height=height,
290
  width=width,
291
  seed=seed,
 
 
292
  )
293
 
294
  # Decode to image
@@ -315,6 +388,8 @@ def generate_batch(
315
  width: int = WIDTH,
316
  seed: int = SEED,
317
  output_dir: str = "./outputs",
 
 
318
  ):
319
  """Generate multiple images."""
320
  os.makedirs(output_dir, exist_ok=True)
@@ -333,6 +408,8 @@ def generate_batch(
333
  width=width,
334
  seed=img_seed,
335
  save_path=os.path.join(output_dir, f"{i:03d}.png"),
 
 
336
  )
337
  images.append(image)
338
 
@@ -345,28 +422,39 @@ if __name__ == "__main__" or True: # Always run in Colab
345
  print("\n" + "="*60)
346
  print("TinyFlux Inference Ready!")
347
  print("="*60)
348
- print(f"""
349
- Usage:
350
- # Single image
351
- image = generate("a photo of a cat")
352
- image.show()
353
-
354
- # With options
355
  image = generate(
356
- prompt="a beautiful sunset over mountains",
357
  negative_prompt="blurry, low quality",
358
- num_steps=30,
359
- guidance_scale=4.0,
360
  height=512,
361
  width=512,
362
- seed=42,
363
  save_path="output.png"
364
  )
365
-
366
- # Batch generation
367
- images = generate_batch([
368
- "a red sports car",
369
- "a blue ocean wave",
370
- "a green forest path",
371
- ], output_dir="./my_outputs")
372
- """)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
59
  config = TinyFluxConfig()
60
  model = TinyFlux(config).to(DEVICE).to(DTYPE)
61
 
62
+ def load_weights(path):
63
+ """Load weights from .safetensors or .pt file."""
64
+ if path.endswith(".safetensors"):
65
+ state_dict = load_file(path)
66
+ elif path.endswith(".pt"):
67
+ ckpt = torch.load(path, map_location=DEVICE, weights_only=False)
68
+ # Handle different checkpoint formats
69
+ if isinstance(ckpt, dict):
70
+ if "model" in ckpt:
71
+ state_dict = ckpt["model"]
72
+ elif "state_dict" in ckpt:
73
+ state_dict = ckpt["state_dict"]
74
+ else:
75
+ state_dict = ckpt
76
+ else:
77
+ state_dict = ckpt
78
+ else:
79
+ # Try safetensors first, then pt
80
+ try:
81
+ state_dict = load_file(path)
82
+ except:
83
+ state_dict = torch.load(path, map_location=DEVICE, weights_only=False)
84
+
85
+ # Strip "_orig_mod." prefix from keys (added by torch.compile)
86
+ if any(k.startswith("_orig_mod.") for k in state_dict.keys()):
87
+ print(" Stripping torch.compile prefix from state_dict keys...")
88
+ state_dict = {k.replace("_orig_mod.", ""): v for k, v in state_dict.items()}
89
+
90
+ return state_dict
91
+
92
  if LOAD_FROM == "hub":
93
+ # Load best model from hub - try safetensors first, then pt
94
+ try:
95
+ weights_path = hf_hub_download(repo_id=HF_REPO, filename="model.safetensors")
96
+ except:
97
+ weights_path = hf_hub_download(repo_id=HF_REPO, filename="model.pt")
98
+ weights = load_weights(weights_path)
99
  model.load_state_dict(weights)
100
+ print(f"✓ Loaded from {HF_REPO}")
101
  elif LOAD_FROM.startswith("hub:"):
102
  # Load specific checkpoint from hub
103
  ckpt_name = LOAD_FROM[4:]
104
+ # Try multiple extensions
105
+ for ext in [".safetensors", ".pt", ""]:
106
+ try:
107
+ if ckpt_name.endswith((".safetensors", ".pt")):
108
+ filename = ckpt_name if "/" in ckpt_name else f"checkpoints/{ckpt_name}"
109
+ else:
110
+ filename = f"checkpoints/{ckpt_name}{ext}"
111
+ weights_path = hf_hub_download(repo_id=HF_REPO, filename=filename)
112
+ weights = load_weights(weights_path)
113
+ model.load_state_dict(weights)
114
+ print(f"✓ Loaded from {HF_REPO}/{filename}")
115
+ break
116
+ except Exception as e:
117
+ continue
118
+ else:
119
+ raise ValueError(f"Could not find checkpoint: {ckpt_name}")
120
  elif LOAD_FROM.startswith("local:"):
121
  # Load local file
122
  weights_path = LOAD_FROM[6:]
123
+ weights = load_weights(weights_path)
124
  model.load_state_dict(weights)
125
  print(f"✓ Loaded from {weights_path}")
126
  else:
 
164
 
165
  return t5_out, clip_pooled
166
 
167
+ # ============================================================================
168
+ # FLOW MATCHING HELPERS
169
+ # ============================================================================
170
+ SHIFT = 3.0 # Flux shift parameter (must match training)
171
+
172
+ def flux_shift(t, s=SHIFT):
173
+ """Flux timestep shift - biases towards higher t (closer to data)."""
174
+ return s * t / (1 + (s - 1) * t)
175
+
176
  # ============================================================================
177
  # EULER DISCRETE FLOW MATCHING SAMPLER
178
  # ============================================================================
 
186
  height: int = 512,
187
  width: int = 512,
188
  seed: int = None,
189
+ direction: str = "forward",
190
+ use_shift: bool = True,
191
  ):
192
  """
193
  Euler discrete sampler for flow matching.
194
 
195
+ Args:
196
+ direction: "forward" (t:0→1, correct) or "reverse" (t:1→0, for old models)
197
+ use_shift: Whether to apply flux_shift to timesteps
 
 
 
 
198
 
199
+ Flow Matching formulation:
200
+ x_t = (1 - t) * noise + t * data
201
+ At t=0: noise, At t=1: data
202
+ Velocity v = data - noise
203
  """
204
  # Set seed
205
  if seed is not None:
 
209
  generator = None
210
 
211
  # Latent dimensions (VAE downscales by 8)
212
+ H_lat = height // 8
213
+ W_lat = width // 8
214
+ C_lat = 16
215
 
216
+ # Encode prompts (ensure correct dtype)
217
  t5_cond, clip_cond = encode_prompt(prompt)
218
+ t5_cond = t5_cond.to(DTYPE)
219
+ clip_cond = clip_cond.to(DTYPE)
220
  if guidance_scale > 1.0 and negative_prompt is not None:
221
  t5_uncond, clip_uncond = encode_prompt(negative_prompt)
222
+ t5_uncond = t5_uncond.to(DTYPE)
223
+ clip_uncond = clip_uncond.to(DTYPE)
224
  else:
225
  t5_uncond, clip_uncond = None, None
226
 
227
+ # Start from pure noise
 
228
  x = torch.randn(1, H_lat * W_lat, C_lat, device=DEVICE, dtype=DTYPE, generator=generator)
229
 
230
  # Create image position IDs for RoPE
231
  img_ids = TinyFlux.create_img_ids(1, H_lat, W_lat, DEVICE)
232
 
233
+ # Build timesteps based on direction
234
+ if direction == "forward":
235
+ t_linear = torch.linspace(0, 1, num_steps + 1, device=DEVICE, dtype=DTYPE)
236
+ dir_str = "0→1"
237
+ else: # reverse
238
+ t_linear = torch.linspace(1, 0, num_steps + 1, device=DEVICE, dtype=DTYPE)
239
+ dir_str = "1→0"
240
+
241
+ # Apply flux_shift if requested
242
+ if use_shift:
243
+ timesteps = flux_shift(t_linear)
244
+ shift_str = ", shifted"
245
+ else:
246
+ timesteps = t_linear
247
+ shift_str = ""
248
 
249
+ print(f"Sampling with {num_steps} Euler steps (t: {dir_str}{shift_str})...")
250
 
251
  for i in range(num_steps):
252
  t_curr = timesteps[i]
253
  t_next = timesteps[i + 1]
254
  dt = t_next - t_curr
255
 
256
+ t_batch = t_curr.unsqueeze(0)
 
 
257
  guidance_embed = torch.tensor([guidance_scale], device=DEVICE, dtype=DTYPE)
258
 
259
+ # Predict velocity: v = data - noise direction
260
  v_cond = model(
261
  hidden_states=x,
262
  encoder_hidden_states=t5_cond,
 
276
  img_ids=img_ids,
277
  guidance=guidance_embed,
278
  )
279
+ # CFG formula: v = v_uncond + scale * (v_cond - v_uncond)
280
  v = v_uncond + guidance_scale * (v_cond - v_uncond)
281
  else:
282
  v = v_cond
283
 
284
+ # Euler integration step: x_{t+dt} = x_t + v * dt
285
+ # v points towards data, dt > 0, so we move towards data
286
  x = x + v * dt
287
 
288
+ if (i + 1) % max(1, num_steps // 5) == 0 or i == num_steps - 1:
289
  print(f" Step {i+1}/{num_steps}, t={t_next.item():.3f}")
290
 
291
  # Reshape to image format: (1, H*W, C) -> (1, C, H, W)
 
302
  # Flux VAE scaling
303
  latents = latents / vae.config.scaling_factor
304
 
305
+ # Decode (match VAE dtype)
306
+ image = vae.decode(latents.to(vae.dtype)).sample
307
 
308
  # Normalize to [0, 1]
309
  image = (image / 2 + 0.5).clamp(0, 1)
310
 
311
+ # To PIL (need float32 for numpy)
312
+ image = image[0].float().permute(1, 2, 0).cpu().numpy()
313
  image = (image * 255).astype(np.uint8)
314
 
315
  return Image.fromarray(image)
 
326
  width: int = WIDTH,
327
  seed: int = SEED,
328
  save_path: str = None,
329
+ direction: str = "forward",
330
+ use_shift: bool = True,
331
  ):
332
  """
333
  Generate an image from a text prompt.
 
341
  width: Output width in pixels (must be divisible by 8)
342
  seed: Random seed (None for random)
343
  save_path: Path to save image (None to skip saving)
344
+ direction: "forward" (t:0→1) or "reverse" (t:1→0) for old models
345
+ use_shift: Whether to apply flux_shift to timesteps
346
 
347
  Returns:
348
  PIL.Image
349
  """
350
  print(f"\nGenerating: '{prompt}'")
351
+ print(f"Settings: {num_steps} steps, cfg={guidance_scale}, {width}x{height}, seed={seed}, dir={direction}, shift={use_shift}")
352
 
353
+ # Sample latents using Euler flow matching
354
  latents = euler_sample(
355
  model=model,
356
  prompt=prompt,
 
360
  height=height,
361
  width=width,
362
  seed=seed,
363
+ direction=direction,
364
+ use_shift=use_shift,
365
  )
366
 
367
  # Decode to image
 
388
  width: int = WIDTH,
389
  seed: int = SEED,
390
  output_dir: str = "./outputs",
391
+ direction: str = "forward",
392
+ use_shift: bool = True,
393
  ):
394
  """Generate multiple images."""
395
  os.makedirs(output_dir, exist_ok=True)
 
408
  width=width,
409
  seed=img_seed,
410
  save_path=os.path.join(output_dir, f"{i:03d}.png"),
411
+ direction=direction,
412
+ use_shift=use_shift,
413
  )
414
  images.append(image)
415
 
 
422
  print("\n" + "="*60)
423
  print("TinyFlux Inference Ready!")
424
  print("="*60)
 
 
 
 
 
 
 
425
  image = generate(
426
+ prompt="a cat in a tree by a sidewalk",
427
  negative_prompt="blurry, low quality",
428
+ num_steps=1,
429
+ guidance_scale=5.0,
430
  height=512,
431
  width=512,
432
+ seed=1024,
433
  save_path="output.png"
434
  )
435
+
436
+ # print(f"""
437
+ #Usage:
438
+ # # Single image
439
+ # image = generate("a photo of a cat")
440
+ # image.show()
441
+ #
442
+ # # With options
443
+ # image = generate(
444
+ # prompt="a beautiful sunset over mountains",
445
+ # negative_prompt="blurry, low quality",
446
+ # num_steps=30,
447
+ # guidance_scale=4.0,
448
+ # height=512,
449
+ # width=512,
450
+ # seed=42,
451
+ # save_path="output.png"
452
+ # )
453
+ #
454
+ # # Batch generation
455
+ # images = generate_batch([
456
+ # "a red sports car",
457
+ # "a blue ocean wave",
458
+ # "a green forest path",
459
+ # ], output_dir="./my_outputs")
460
+ #""")