Spaces: Running on Zero
Update app.py
Browse files
app.py
CHANGED
|
@@ -371,6 +371,8 @@ print(f"✓ VAE loaded (scale={VAE_SCALE})")
|
|
| 371 |
|
| 372 |
# ============================================================================
|
| 373 |
# EULER DISCRETE FLOW MATCHING SAMPLER
|
|
|
|
|
|
|
| 374 |
# ============================================================================
|
| 375 |
def flux_shift(t, shift=SHIFT):
|
| 376 |
"""Flux time shift: s*t / (1 + (s-1)*t)"""
|
|
@@ -416,22 +418,23 @@ def generate(
|
|
| 416 |
C = 16
|
| 417 |
L = 128 # T5 sequence length
|
| 418 |
|
| 419 |
-
# Start from noise (t=
|
| 420 |
x = torch.randn(1, H_lat * W_lat, C, device=DEVICE, dtype=DTYPE, generator=generator)
|
| 421 |
|
| 422 |
# Position IDs
|
| 423 |
img_ids = TinyFluxDeep.create_img_ids(1, H_lat, W_lat, DEVICE)
|
| 424 |
txt_ids = TinyFluxDeep.create_txt_ids(L, DEVICE)
|
| 425 |
|
| 426 |
-
# Timesteps:
|
| 427 |
-
t_linear = torch.linspace(
|
| 428 |
-
timesteps = flux_shift(t_linear, shift=SHIFT)
|
| 429 |
|
| 430 |
-
# Euler
|
|
|
|
| 431 |
for i in range(num_inference_steps):
|
| 432 |
t_curr = timesteps[i]
|
| 433 |
t_next = timesteps[i + 1]
|
| 434 |
-
dt = t_next - t_curr #
|
| 435 |
|
| 436 |
t_batch = t_curr.unsqueeze(0)
|
| 437 |
guidance = torch.tensor([guidance_scale], device=DEVICE, dtype=DTYPE)
|
|
@@ -493,7 +496,6 @@ with gr.Blocks(css=css) as demo:
|
|
| 493 |
with gr.Row():
|
| 494 |
prompt = gr.Text(
|
| 495 |
label="Prompt",
|
| 496 |
-
value="cat",
|
| 497 |
show_label=False,
|
| 498 |
max_lines=2,
|
| 499 |
placeholder="Enter your prompt...",
|
|
|
|
| 371 |
|
| 372 |
# ============================================================================
|
| 373 |
# EULER DISCRETE FLOW MATCHING SAMPLER
|
| 374 |
+
# Training uses: x_t = (1-t)*noise + t*data, v = data - noise
|
| 375 |
+
# So t=0 is noise, t=1 is data. We sample from t=0 to t=1.
|
| 376 |
# ============================================================================
|
| 377 |
def flux_shift(t, shift=SHIFT):
|
| 378 |
"""Flux time shift: s*t / (1 + (s-1)*t)"""
|
|
|
|
| 418 |
C = 16
|
| 419 |
L = 128 # T5 sequence length
|
| 420 |
|
| 421 |
+
# Start from noise (t=0 in this convention)
|
| 422 |
x = torch.randn(1, H_lat * W_lat, C, device=DEVICE, dtype=DTYPE, generator=generator)
|
| 423 |
|
| 424 |
# Position IDs
|
| 425 |
img_ids = TinyFluxDeep.create_img_ids(1, H_lat, W_lat, DEVICE)
|
| 426 |
txt_ids = TinyFluxDeep.create_txt_ids(L, DEVICE)
|
| 427 |
|
| 428 |
+
# Timesteps: 0 -> 1 (noise to data) with Flux shift
|
| 429 |
+
t_linear = torch.linspace(0, 1, num_inference_steps + 1, device=DEVICE)
|
| 430 |
+
timesteps = flux_shift(t_linear, shift=SHIFT).clamp(1e-4, 1 - 1e-4)
|
| 431 |
|
| 432 |
+
# Euler flow matching: x_{t+dt} = x_t + v * dt
|
| 433 |
+
# v predicts direction from noise to data
|
| 434 |
for i in range(num_inference_steps):
|
| 435 |
t_curr = timesteps[i]
|
| 436 |
t_next = timesteps[i + 1]
|
| 437 |
+
dt = t_next - t_curr # Positive since going 0->1
|
| 438 |
|
| 439 |
t_batch = t_curr.unsqueeze(0)
|
| 440 |
guidance = torch.tensor([guidance_scale], device=DEVICE, dtype=DTYPE)
|
|
|
|
| 496 |
with gr.Row():
|
| 497 |
prompt = gr.Text(
|
| 498 |
label="Prompt",
|
|
|
|
| 499 |
show_label=False,
|
| 500 |
max_lines=2,
|
| 501 |
placeholder="Enter your prompt...",
|