Spaces:
Running
on
Zero
Running
on
Zero
Update app.py
Browse files
app.py
CHANGED
|
@@ -370,7 +370,7 @@ print(f"✓ VAE loaded (scale={VAE_SCALE})")
|
|
| 370 |
|
| 371 |
|
| 372 |
# ============================================================================
|
| 373 |
-
# EULER DISCRETE FLOW MATCHING SAMPLER
|
| 374 |
# Training uses: x_t = (1-t)*noise + t*data, v = data - noise
|
| 375 |
# So t=0 is noise, t=1 is data. We sample from t=0 to t=1.
|
| 376 |
# ============================================================================
|
|
@@ -403,51 +403,82 @@ def generate(
|
|
| 403 |
vae.to(DEVICE)
|
| 404 |
|
| 405 |
with torch.inference_mode(), torch.autocast(device_type="cuda", dtype=DTYPE):
|
| 406 |
-
# Encode
|
| 407 |
t5_in = t5_tok(prompt, max_length=128, padding="max_length",
|
| 408 |
truncation=True, return_tensors="pt").to(DEVICE)
|
| 409 |
-
|
| 410 |
|
| 411 |
clip_in = clip_tok(prompt, max_length=77, padding="max_length",
|
| 412 |
truncation=True, return_tensors="pt").to(DEVICE)
|
| 413 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 414 |
|
| 415 |
# Latent dimensions
|
| 416 |
H_lat = height // 8
|
| 417 |
W_lat = width // 8
|
| 418 |
C = 16
|
| 419 |
-
L = 128 # T5 sequence length
|
| 420 |
|
| 421 |
# Start from noise (t=0 in this convention)
|
| 422 |
x = torch.randn(1, H_lat * W_lat, C, device=DEVICE, dtype=DTYPE, generator=generator)
|
| 423 |
|
| 424 |
# Position IDs
|
| 425 |
img_ids = TinyFluxDeep.create_img_ids(1, H_lat, W_lat, DEVICE)
|
| 426 |
-
txt_ids = TinyFluxDeep.create_txt_ids(L, DEVICE)
|
| 427 |
|
| 428 |
# Timesteps: 0 -> 1 (noise to data) with Flux shift
|
| 429 |
t_linear = torch.linspace(0, 1, num_inference_steps + 1, device=DEVICE)
|
| 430 |
-
timesteps = flux_shift(t_linear, shift=SHIFT)
|
| 431 |
|
| 432 |
# Euler flow matching: x_{t+dt} = x_t + v * dt
|
| 433 |
-
# v predicts direction from noise to data
|
| 434 |
for i in range(num_inference_steps):
|
| 435 |
t_curr = timesteps[i]
|
| 436 |
t_next = timesteps[i + 1]
|
| 437 |
-
dt = t_next - t_curr
|
| 438 |
|
| 439 |
t_batch = t_curr.unsqueeze(0)
|
| 440 |
-
|
| 441 |
-
|
| 442 |
-
|
| 443 |
-
|
| 444 |
-
|
| 445 |
-
|
| 446 |
-
|
| 447 |
-
|
| 448 |
-
|
| 449 |
-
|
| 450 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 451 |
x = x + v * dt
|
| 452 |
|
| 453 |
# Decode latents
|
|
@@ -509,8 +540,8 @@ with gr.Blocks(css=css) as demo:
|
|
| 509 |
negative_prompt = gr.Text(
|
| 510 |
label="Negative prompt",
|
| 511 |
max_lines=1,
|
| 512 |
-
placeholder="
|
| 513 |
-
|
| 514 |
)
|
| 515 |
seed = gr.Slider(label="Seed", minimum=0, maximum=MAX_SEED, step=1, value=42)
|
| 516 |
randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
|
|
@@ -520,14 +551,14 @@ with gr.Blocks(css=css) as demo:
|
|
| 520 |
height = gr.Slider(label="Height", minimum=256, maximum=1024, step=64, value=512)
|
| 521 |
|
| 522 |
with gr.Row():
|
| 523 |
-
guidance_scale = gr.Slider(label="
|
| 524 |
-
num_inference_steps = gr.Slider(label="Steps", minimum=10, maximum=50, step=1, value=
|
| 525 |
|
| 526 |
gr.Examples(examples=examples, inputs=[prompt])
|
| 527 |
|
| 528 |
gr.Markdown("""
|
| 529 |
---
|
| 530 |
-
**Notes:** Trained at 512×512.
|
| 531 |
""")
|
| 532 |
|
| 533 |
gr.on(
|
|
|
|
| 370 |
|
| 371 |
|
| 372 |
# ============================================================================
|
| 373 |
+
# EULER DISCRETE FLOW MATCHING SAMPLER WITH CFG
|
| 374 |
# Training uses: x_t = (1-t)*noise + t*data, v = data - noise
|
| 375 |
# So t=0 is noise, t=1 is data. We sample from t=0 to t=1.
|
| 376 |
# ============================================================================
|
|
|
|
| 403 |
vae.to(DEVICE)
|
| 404 |
|
| 405 |
with torch.inference_mode(), torch.autocast(device_type="cuda", dtype=DTYPE):
|
| 406 |
+
# Encode prompts
|
| 407 |
t5_in = t5_tok(prompt, max_length=128, padding="max_length",
|
| 408 |
truncation=True, return_tensors="pt").to(DEVICE)
|
| 409 |
+
t5_cond = t5_enc(**t5_in).last_hidden_state
|
| 410 |
|
| 411 |
clip_in = clip_tok(prompt, max_length=77, padding="max_length",
|
| 412 |
truncation=True, return_tensors="pt").to(DEVICE)
|
| 413 |
+
clip_cond = clip_enc(**clip_in).pooler_output
|
| 414 |
+
|
| 415 |
+
# Encode negative prompt for CFG
|
| 416 |
+
do_cfg = guidance_scale > 1.0
|
| 417 |
+
if do_cfg:
|
| 418 |
+
neg_prompt = negative_prompt if negative_prompt else ""
|
| 419 |
+
t5_neg_in = t5_tok(neg_prompt, max_length=128, padding="max_length",
|
| 420 |
+
truncation=True, return_tensors="pt").to(DEVICE)
|
| 421 |
+
t5_uncond = t5_enc(**t5_neg_in).last_hidden_state
|
| 422 |
+
|
| 423 |
+
clip_neg_in = clip_tok(neg_prompt, max_length=77, padding="max_length",
|
| 424 |
+
truncation=True, return_tensors="pt").to(DEVICE)
|
| 425 |
+
clip_uncond = clip_enc(**clip_neg_in).pooler_output
|
| 426 |
+
|
| 427 |
+
# Batch for efficient forward pass
|
| 428 |
+
t5_batch = torch.cat([t5_uncond, t5_cond], dim=0)
|
| 429 |
+
clip_batch = torch.cat([clip_uncond, clip_cond], dim=0)
|
| 430 |
|
| 431 |
# Latent dimensions
|
| 432 |
H_lat = height // 8
|
| 433 |
W_lat = width // 8
|
| 434 |
C = 16
|
|
|
|
| 435 |
|
| 436 |
# Start from noise (t=0 in this convention)
|
| 437 |
x = torch.randn(1, H_lat * W_lat, C, device=DEVICE, dtype=DTYPE, generator=generator)
|
| 438 |
|
| 439 |
# Position IDs
|
| 440 |
img_ids = TinyFluxDeep.create_img_ids(1, H_lat, W_lat, DEVICE)
|
|
|
|
| 441 |
|
| 442 |
# Timesteps: 0 -> 1 (noise to data) with Flux shift
|
| 443 |
t_linear = torch.linspace(0, 1, num_inference_steps + 1, device=DEVICE)
|
| 444 |
+
timesteps = flux_shift(t_linear, shift=SHIFT)
|
| 445 |
|
| 446 |
# Euler flow matching: x_{t+dt} = x_t + v * dt
|
|
|
|
| 447 |
for i in range(num_inference_steps):
|
| 448 |
t_curr = timesteps[i]
|
| 449 |
t_next = timesteps[i + 1]
|
| 450 |
+
dt = t_next - t_curr
|
| 451 |
|
| 452 |
t_batch = t_curr.unsqueeze(0)
|
| 453 |
+
guidance_embed = torch.tensor([guidance_scale], device=DEVICE, dtype=DTYPE)
|
| 454 |
+
|
| 455 |
+
if do_cfg:
|
| 456 |
+
# Batched forward pass for efficiency
|
| 457 |
+
x_batch = x.repeat(2, 1, 1)
|
| 458 |
+
img_ids_batch = img_ids
|
| 459 |
+
t_batch_2 = t_batch.repeat(2)
|
| 460 |
+
guidance_batch = guidance_embed.repeat(2)
|
| 461 |
+
|
| 462 |
+
v_batch = model(
|
| 463 |
+
hidden_states=x_batch,
|
| 464 |
+
encoder_hidden_states=t5_batch,
|
| 465 |
+
pooled_projections=clip_batch,
|
| 466 |
+
timestep=t_batch_2,
|
| 467 |
+
img_ids=img_ids_batch,
|
| 468 |
+
guidance=guidance_batch,
|
| 469 |
+
)
|
| 470 |
+
v_uncond, v_cond = v_batch.chunk(2, dim=0)
|
| 471 |
+
v = v_uncond + guidance_scale * (v_cond - v_uncond)
|
| 472 |
+
else:
|
| 473 |
+
v = model(
|
| 474 |
+
hidden_states=x,
|
| 475 |
+
encoder_hidden_states=t5_cond,
|
| 476 |
+
pooled_projections=clip_cond,
|
| 477 |
+
timestep=t_batch,
|
| 478 |
+
img_ids=img_ids,
|
| 479 |
+
guidance=guidance_embed,
|
| 480 |
+
)
|
| 481 |
+
|
| 482 |
x = x + v * dt
|
| 483 |
|
| 484 |
# Decode latents
|
|
|
|
| 540 |
negative_prompt = gr.Text(
|
| 541 |
label="Negative prompt",
|
| 542 |
max_lines=1,
|
| 543 |
+
placeholder="blurry, distorted, low quality",
|
| 544 |
+
value="",
|
| 545 |
)
|
| 546 |
seed = gr.Slider(label="Seed", minimum=0, maximum=MAX_SEED, step=1, value=42)
|
| 547 |
randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
|
|
|
|
| 551 |
height = gr.Slider(label="Height", minimum=256, maximum=1024, step=64, value=512)
|
| 552 |
|
| 553 |
with gr.Row():
|
| 554 |
+
guidance_scale = gr.Slider(label="CFG Scale", minimum=1.0, maximum=10.0, step=0.5, value=5.0)
|
| 555 |
+
num_inference_steps = gr.Slider(label="Steps", minimum=10, maximum=50, step=1, value=25)
|
| 556 |
|
| 557 |
gr.Examples(examples=examples, inputs=[prompt])
|
| 558 |
|
| 559 |
gr.Markdown("""
|
| 560 |
---
|
| 561 |
+
**Notes:** Trained at 512×512. CFG 3.0-7.0 recommended, 20-30 steps.
|
| 562 |
""")
|
| 563 |
|
| 564 |
gr.on(
|