AbstractPhil commited on
Commit
d4b69df
·
verified ·
1 Parent(s): a29d3c5

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +9 -7
app.py CHANGED
@@ -371,6 +371,8 @@ print(f"✓ VAE loaded (scale={VAE_SCALE})")
371
 
372
  # ============================================================================
373
  # EULER DISCRETE FLOW MATCHING SAMPLER
 
 
374
  # ============================================================================
375
  def flux_shift(t, shift=SHIFT):
376
  """Flux time shift: s*t / (1 + (s-1)*t)"""
@@ -416,22 +418,23 @@ def generate(
416
  C = 16
417
  L = 128 # T5 sequence length
418
 
419
- # Start from noise (t=1 in flow matching)
420
  x = torch.randn(1, H_lat * W_lat, C, device=DEVICE, dtype=DTYPE, generator=generator)
421
 
422
  # Position IDs
423
  img_ids = TinyFluxDeep.create_img_ids(1, H_lat, W_lat, DEVICE)
424
  txt_ids = TinyFluxDeep.create_txt_ids(L, DEVICE)
425
 
426
- # Timesteps: 1 -> 0 with Flux shift
427
- t_linear = torch.linspace(1, 0, num_inference_steps + 1, device=DEVICE)
428
- timesteps = flux_shift(t_linear, shift=SHIFT)
429
 
430
- # Euler discrete flow matching: x_{t-dt} = x_t + v * dt
 
431
  for i in range(num_inference_steps):
432
  t_curr = timesteps[i]
433
  t_next = timesteps[i + 1]
434
- dt = t_next - t_curr # Negative since going 1->0
435
 
436
  t_batch = t_curr.unsqueeze(0)
437
  guidance = torch.tensor([guidance_scale], device=DEVICE, dtype=DTYPE)
@@ -493,7 +496,6 @@ with gr.Blocks(css=css) as demo:
493
  with gr.Row():
494
  prompt = gr.Text(
495
  label="Prompt",
496
- value="cat",
497
  show_label=False,
498
  max_lines=2,
499
  placeholder="Enter your prompt...",
 
371
 
372
  # ============================================================================
373
  # EULER DISCRETE FLOW MATCHING SAMPLER
374
+ # Training uses: x_t = (1-t)*noise + t*data, v = data - noise
375
+ # So t=0 is noise, t=1 is data. We sample from t=0 to t=1.
376
  # ============================================================================
377
  def flux_shift(t, shift=SHIFT):
378
  """Flux time shift: s*t / (1 + (s-1)*t)"""
 
418
  C = 16
419
  L = 128 # T5 sequence length
420
 
421
+ # Start from noise (t=0 in this convention)
422
  x = torch.randn(1, H_lat * W_lat, C, device=DEVICE, dtype=DTYPE, generator=generator)
423
 
424
  # Position IDs
425
  img_ids = TinyFluxDeep.create_img_ids(1, H_lat, W_lat, DEVICE)
426
  txt_ids = TinyFluxDeep.create_txt_ids(L, DEVICE)
427
 
428
+ # Timesteps: 0 -> 1 (noise to data) with Flux shift
429
+ t_linear = torch.linspace(0, 1, num_inference_steps + 1, device=DEVICE)
430
+ timesteps = flux_shift(t_linear, shift=SHIFT).clamp(1e-4, 1 - 1e-4)
431
 
432
+ # Euler flow matching: x_{t+dt} = x_t + v * dt
433
+ # v predicts direction from noise to data
434
  for i in range(num_inference_steps):
435
  t_curr = timesteps[i]
436
  t_next = timesteps[i + 1]
437
+ dt = t_next - t_curr # Positive since going 0->1
438
 
439
  t_batch = t_curr.unsqueeze(0)
440
  guidance = torch.tensor([guidance_scale], device=DEVICE, dtype=DTYPE)
 
496
  with gr.Row():
497
  prompt = gr.Text(
498
  label="Prompt",
 
499
  show_label=False,
500
  max_lines=2,
501
  placeholder="Enter your prompt...",