AbstractPhil committed on
Commit
083a11a
·
verified ·
1 Parent(s): d4b69df

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +56 -25
app.py CHANGED
@@ -370,7 +370,7 @@ print(f"✓ VAE loaded (scale={VAE_SCALE})")
370
 
371
 
372
  # ============================================================================
373
- # EULER DISCRETE FLOW MATCHING SAMPLER
374
  # Training uses: x_t = (1-t)*noise + t*data, v = data - noise
375
  # So t=0 is noise, t=1 is data. We sample from t=0 to t=1.
376
  # ============================================================================
@@ -403,51 +403,82 @@ def generate(
403
  vae.to(DEVICE)
404
 
405
  with torch.inference_mode(), torch.autocast(device_type="cuda", dtype=DTYPE):
406
- # Encode prompt
407
  t5_in = t5_tok(prompt, max_length=128, padding="max_length",
408
  truncation=True, return_tensors="pt").to(DEVICE)
409
- t5_out = t5_enc(**t5_in).last_hidden_state
410
 
411
  clip_in = clip_tok(prompt, max_length=77, padding="max_length",
412
  truncation=True, return_tensors="pt").to(DEVICE)
413
- clip_out = clip_enc(**clip_in).pooler_output
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
414
 
415
  # Latent dimensions
416
  H_lat = height // 8
417
  W_lat = width // 8
418
  C = 16
419
- L = 128 # T5 sequence length
420
 
421
  # Start from noise (t=0 in this convention)
422
  x = torch.randn(1, H_lat * W_lat, C, device=DEVICE, dtype=DTYPE, generator=generator)
423
 
424
  # Position IDs
425
  img_ids = TinyFluxDeep.create_img_ids(1, H_lat, W_lat, DEVICE)
426
- txt_ids = TinyFluxDeep.create_txt_ids(L, DEVICE)
427
 
428
  # Timesteps: 0 -> 1 (noise to data) with Flux shift
429
  t_linear = torch.linspace(0, 1, num_inference_steps + 1, device=DEVICE)
430
- timesteps = flux_shift(t_linear, shift=SHIFT).clamp(1e-4, 1 - 1e-4)
431
 
432
  # Euler flow matching: x_{t+dt} = x_t + v * dt
433
- # v predicts direction from noise to data
434
  for i in range(num_inference_steps):
435
  t_curr = timesteps[i]
436
  t_next = timesteps[i + 1]
437
- dt = t_next - t_curr # Positive since going 0->1
438
 
439
  t_batch = t_curr.unsqueeze(0)
440
- guidance = torch.tensor([guidance_scale], device=DEVICE, dtype=DTYPE)
441
-
442
- v = model(
443
- hidden_states=x,
444
- encoder_hidden_states=t5_out,
445
- pooled_projections=clip_out,
446
- timestep=t_batch,
447
- img_ids=img_ids,
448
- txt_ids=txt_ids,
449
- guidance=guidance,
450
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
451
  x = x + v * dt
452
 
453
  # Decode latents
@@ -509,8 +540,8 @@ with gr.Blocks(css=css) as demo:
509
  negative_prompt = gr.Text(
510
  label="Negative prompt",
511
  max_lines=1,
512
- placeholder="(not used)",
513
- visible=False,
514
  )
515
  seed = gr.Slider(label="Seed", minimum=0, maximum=MAX_SEED, step=1, value=42)
516
  randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
@@ -520,14 +551,14 @@ with gr.Blocks(css=css) as demo:
520
  height = gr.Slider(label="Height", minimum=256, maximum=1024, step=64, value=512)
521
 
522
  with gr.Row():
523
- guidance_scale = gr.Slider(label="Guidance", minimum=1.0, maximum=10.0, step=0.5, value=3.5)
524
- num_inference_steps = gr.Slider(label="Steps", minimum=10, maximum=50, step=1, value=28)
525
 
526
  gr.Examples(examples=examples, inputs=[prompt])
527
 
528
  gr.Markdown("""
529
  ---
530
- **Notes:** Trained at 512×512. Best results at guidance 3.0-5.0, 20-30 steps.
531
  """)
532
 
533
  gr.on(
 
370
 
371
 
372
  # ============================================================================
373
+ # EULER DISCRETE FLOW MATCHING SAMPLER WITH CFG
374
  # Training uses: x_t = (1-t)*noise + t*data, v = data - noise
375
  # So t=0 is noise, t=1 is data. We sample from t=0 to t=1.
376
  # ============================================================================
 
403
  vae.to(DEVICE)
404
 
405
  with torch.inference_mode(), torch.autocast(device_type="cuda", dtype=DTYPE):
406
+ # Encode prompts
407
  t5_in = t5_tok(prompt, max_length=128, padding="max_length",
408
  truncation=True, return_tensors="pt").to(DEVICE)
409
+ t5_cond = t5_enc(**t5_in).last_hidden_state
410
 
411
  clip_in = clip_tok(prompt, max_length=77, padding="max_length",
412
  truncation=True, return_tensors="pt").to(DEVICE)
413
+ clip_cond = clip_enc(**clip_in).pooler_output
414
+
415
+ # Encode negative prompt for CFG
416
+ do_cfg = guidance_scale > 1.0
417
+ if do_cfg:
418
+ neg_prompt = negative_prompt if negative_prompt else ""
419
+ t5_neg_in = t5_tok(neg_prompt, max_length=128, padding="max_length",
420
+ truncation=True, return_tensors="pt").to(DEVICE)
421
+ t5_uncond = t5_enc(**t5_neg_in).last_hidden_state
422
+
423
+ clip_neg_in = clip_tok(neg_prompt, max_length=77, padding="max_length",
424
+ truncation=True, return_tensors="pt").to(DEVICE)
425
+ clip_uncond = clip_enc(**clip_neg_in).pooler_output
426
+
427
+ # Batch for efficient forward pass
428
+ t5_batch = torch.cat([t5_uncond, t5_cond], dim=0)
429
+ clip_batch = torch.cat([clip_uncond, clip_cond], dim=0)
430
 
431
  # Latent dimensions
432
  H_lat = height // 8
433
  W_lat = width // 8
434
  C = 16
 
435
 
436
  # Start from noise (t=0 in this convention)
437
  x = torch.randn(1, H_lat * W_lat, C, device=DEVICE, dtype=DTYPE, generator=generator)
438
 
439
  # Position IDs
440
  img_ids = TinyFluxDeep.create_img_ids(1, H_lat, W_lat, DEVICE)
 
441
 
442
  # Timesteps: 0 -> 1 (noise to data) with Flux shift
443
  t_linear = torch.linspace(0, 1, num_inference_steps + 1, device=DEVICE)
444
+ timesteps = flux_shift(t_linear, shift=SHIFT)
445
 
446
  # Euler flow matching: x_{t+dt} = x_t + v * dt
 
447
  for i in range(num_inference_steps):
448
  t_curr = timesteps[i]
449
  t_next = timesteps[i + 1]
450
+ dt = t_next - t_curr
451
 
452
  t_batch = t_curr.unsqueeze(0)
453
+ guidance_embed = torch.tensor([guidance_scale], device=DEVICE, dtype=DTYPE)
454
+
455
+ if do_cfg:
456
+ # Batched forward pass for efficiency
457
+ x_batch = x.repeat(2, 1, 1)
458
+ img_ids_batch = img_ids
459
+ t_batch_2 = t_batch.repeat(2)
460
+ guidance_batch = guidance_embed.repeat(2)
461
+
462
+ v_batch = model(
463
+ hidden_states=x_batch,
464
+ encoder_hidden_states=t5_batch,
465
+ pooled_projections=clip_batch,
466
+ timestep=t_batch_2,
467
+ img_ids=img_ids_batch,
468
+ guidance=guidance_batch,
469
+ )
470
+ v_uncond, v_cond = v_batch.chunk(2, dim=0)
471
+ v = v_uncond + guidance_scale * (v_cond - v_uncond)
472
+ else:
473
+ v = model(
474
+ hidden_states=x,
475
+ encoder_hidden_states=t5_cond,
476
+ pooled_projections=clip_cond,
477
+ timestep=t_batch,
478
+ img_ids=img_ids,
479
+ guidance=guidance_embed,
480
+ )
481
+
482
  x = x + v * dt
483
 
484
  # Decode latents
 
540
  negative_prompt = gr.Text(
541
  label="Negative prompt",
542
  max_lines=1,
543
+ placeholder="blurry, distorted, low quality",
544
+ value="",
545
  )
546
  seed = gr.Slider(label="Seed", minimum=0, maximum=MAX_SEED, step=1, value=42)
547
  randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
 
551
  height = gr.Slider(label="Height", minimum=256, maximum=1024, step=64, value=512)
552
 
553
  with gr.Row():
554
+ guidance_scale = gr.Slider(label="CFG Scale", minimum=1.0, maximum=10.0, step=0.5, value=5.0)
555
+ num_inference_steps = gr.Slider(label="Steps", minimum=10, maximum=50, step=1, value=25)
556
 
557
  gr.Examples(examples=examples, inputs=[prompt])
558
 
559
  gr.Markdown("""
560
  ---
561
+ **Notes:** Trained at 512×512. CFG 3.0-7.0 recommended, 20-30 steps.
562
  """)
563
 
564
  gr.on(