Update README.md
Browse files
README.md
CHANGED
|
@@ -94,17 +94,8 @@ def flux_shift(t, s=3.0):
|
|
| 94 |
"""Flux-style timestep shifting - biases toward data end."""
|
| 95 |
return s * t / (1 + (s - 1) * t)
|
| 96 |
|
| 97 |
-
def generate(model, t5_emb, clip_pooled, num_steps=25, cfg_scale=4.0):
|
| 98 |
-
"""
|
| 99 |
-
Euler sampling for rectified flow.
|
| 100 |
-
|
| 101 |
-
Flow matching: x_t = (1-t)*noise + t*data
|
| 102 |
-
- t=0: pure noise
|
| 103 |
-
- t=1: pure data
|
| 104 |
-
- v = data - noise
|
| 105 |
-
|
| 106 |
-
Integrate from t=0 → t=1
|
| 107 |
-
"""
|
| 108 |
device = next(model.parameters()).device
|
| 109 |
dtype = next(model.parameters()).dtype
|
| 110 |
|
|
@@ -112,22 +103,17 @@ def generate(model, t5_emb, clip_pooled, num_steps=25, cfg_scale=4.0):
|
|
| 112 |
x = torch.randn(1, 64*64, 16, device=device, dtype=dtype)
|
| 113 |
img_ids = TinyFluxDeep.create_img_ids(1, 64, 64, device)
|
| 114 |
|
| 115 |
-
#
|
| 116 |
-
|
| 117 |
-
timesteps = flux_shift(t_linear, s=3.0)
|
| 118 |
-
|
| 119 |
-
# Null embeddings for CFG
|
| 120 |
-
t5_null = torch.zeros_like(t5_emb)
|
| 121 |
-
clip_null = torch.zeros_like(clip_pooled)
|
| 122 |
|
| 123 |
for i in range(num_steps):
|
| 124 |
t_curr = timesteps[i]
|
| 125 |
t_next = timesteps[i + 1]
|
| 126 |
-
dt = t_next - t_curr
|
| 127 |
|
| 128 |
-
t_batch = t_curr.
|
| 129 |
|
| 130 |
-
#
|
| 131 |
v_cond = model(
|
| 132 |
hidden_states=x,
|
| 133 |
encoder_hidden_states=t5_emb,
|
|
@@ -136,10 +122,11 @@ def generate(model, t5_emb, clip_pooled, num_steps=25, cfg_scale=4.0):
|
|
| 136 |
img_ids=img_ids,
|
| 137 |
)
|
| 138 |
|
|
|
|
| 139 |
v_uncond = model(
|
| 140 |
hidden_states=x,
|
| 141 |
-
encoder_hidden_states=
|
| 142 |
-
pooled_projections=
|
| 143 |
timestep=t_batch,
|
| 144 |
img_ids=img_ids,
|
| 145 |
)
|
|
@@ -147,7 +134,7 @@ def generate(model, t5_emb, clip_pooled, num_steps=25, cfg_scale=4.0):
|
|
| 147 |
# Classifier-free guidance
|
| 148 |
v = v_uncond + cfg_scale * (v_cond - v_uncond)
|
| 149 |
|
| 150 |
-
# Euler step
|
| 151 |
x = x + v * dt
|
| 152 |
|
| 153 |
return x # [1, 4096, 16] - decode with VAE
|
|
@@ -418,12 +405,9 @@ vae = AutoencoderKL.from_pretrained(
|
|
| 418 |
model_py = hf_hub_download("AbstractPhil/tiny-flux-deep", "scripts/model_v4.py")
|
| 419 |
exec(open(model_py).read())
|
| 420 |
|
| 421 |
-
config = TinyFluxConfig(
|
| 422 |
-
use_sol_prior=True, # Disabled until trained
|
| 423 |
-
use_t5_vec=True, # Disabled until trained
|
| 424 |
-
)
|
| 425 |
model = TinyFluxDeep(config).to("cuda", torch.bfloat16)
|
| 426 |
-
weights = load_file(hf_hub_download("AbstractPhil/tiny-flux-deep", "
|
| 427 |
model.load_state_dict(weights, strict=False)
|
| 428 |
model.eval()
|
| 429 |
|
|
@@ -444,71 +428,63 @@ def encode_prompt(prompt):
|
|
| 444 |
|
| 445 |
return t5_emb, clip_pooled
|
| 446 |
|
| 447 |
-
|
| 448 |
-
def flux_shift(t, s=3.0):
|
| 449 |
-
"""Flux-style timestep shift."""
|
| 450 |
-
return s * t / (1 + (s - 1) * t)
|
| 451 |
-
|
| 452 |
-
|
| 453 |
-
@torch.inference_mode()
|
| 454 |
def generate_image(prompt, num_steps=25, cfg_scale=4.0, seed=None):
|
| 455 |
"""
|
| 456 |
Euler sampling for rectified flow.
|
| 457 |
|
| 458 |
-
Flow
|
| 459 |
-
|
| 460 |
-
At t=0: pure noise
|
| 461 |
-
At t=1: pure data
|
| 462 |
-
Velocity v = data - noise (constant)
|
| 463 |
-
|
| 464 |
-
Sampling: Integrate from t=0 (noise) → t=1 (data)
|
| 465 |
"""
|
| 466 |
if seed is not None:
|
| 467 |
torch.manual_seed(seed)
|
| 468 |
-
|
| 469 |
-
|
| 470 |
-
|
| 471 |
-
|
| 472 |
-
|
| 473 |
-
|
| 474 |
-
|
| 475 |
-
|
| 476 |
-
|
| 477 |
-
|
| 478 |
-
|
| 479 |
-
|
| 480 |
-
|
| 481 |
-
|
| 482 |
-
|
| 483 |
-
|
| 484 |
-
|
| 485 |
-
|
| 486 |
-
|
| 487 |
-
|
| 488 |
-
|
| 489 |
-
|
| 490 |
-
|
| 491 |
-
|
| 492 |
-
|
| 493 |
-
|
| 494 |
-
|
| 495 |
-
|
| 496 |
-
|
| 497 |
-
|
| 498 |
-
|
| 499 |
-
|
| 500 |
-
|
| 501 |
-
|
| 502 |
-
|
| 503 |
-
|
| 504 |
-
|
| 505 |
-
|
| 506 |
-
|
| 507 |
-
|
| 508 |
-
|
| 509 |
-
|
| 510 |
-
|
| 511 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 512 |
image = generate_image("a photograph of a tiger in natural habitat", seed=42)
|
| 513 |
image.save("tiger.png")
|
| 514 |
```
|
|
|
|
| 94 |
"""Flux-style timestep shifting - biases toward data end."""
|
| 95 |
return s * t / (1 + (s - 1) * t)
|
| 96 |
|
| 97 |
+
def generate(model, t5_emb, clip_emb, clip_pooled, num_steps=25, cfg_scale=4.0):
    """Euler sampling with classifier-free guidance.

    Rectified flow: x_t = (1-t)*noise + t*data, so the velocity target is
    v = data - noise and sampling integrates from t=0 (pure noise) to
    t=1 (data).

    Args:
        model: TinyFluxDeep transformer that predicts velocity.
        t5_emb: T5 sequence embeddings for the prompt.
        clip_emb: unused in this function; kept for signature compatibility.
        clip_pooled: pooled CLIP embedding for the prompt.
        num_steps: number of Euler integration steps.
        cfg_scale: classifier-free guidance scale.

    Returns:
        Latent tensor of shape [1, 4096, 16] — decode with the VAE.
    """
    device = next(model.parameters()).device
    dtype = next(model.parameters()).dtype

    x = torch.randn(1, 64*64, 16, device=device, dtype=dtype)
    img_ids = TinyFluxDeep.create_img_ids(1, 64, 64, device)

    # Rectified flow: integrate from t=0 (noise) to t=1 (data),
    # with Flux-style shift biasing steps toward the data end.
    timesteps = flux_shift(torch.linspace(0, 1, num_steps + 1, device=device))

    for i in range(num_steps):
        t_curr = timesteps[i]
        t_next = timesteps[i + 1]
        dt = t_next - t_curr

        t_batch = t_curr.expand(1)

        # Conditional prediction
        v_cond = model(
            hidden_states=x,
            encoder_hidden_states=t5_emb,
            pooled_projections=clip_pooled,
            timestep=t_batch,
            img_ids=img_ids,
        )

        # Unconditional prediction (for CFG): zeroed text conditioning
        v_uncond = model(
            hidden_states=x,
            encoder_hidden_states=torch.zeros_like(t5_emb),
            pooled_projections=torch.zeros_like(clip_pooled),
            timestep=t_batch,
            img_ids=img_ids,
        )

        # Classifier-free guidance
        v = v_uncond + cfg_scale * (v_cond - v_uncond)

        # Euler step
        x = x + v * dt

    return x  # [1, 4096, 16] - decode with VAE
|
|
|
|
| 405 |
# Download and exec the model definition (defines TinyFluxConfig / TinyFluxDeep).
# NOTE(review): exec() of downloaded code runs arbitrary Python — only do this
# for a repository you trust.
model_py = hf_hub_download("AbstractPhil/tiny-flux-deep", "scripts/model_v4.py")
with open(model_py) as f:
    exec(f.read())

config = TinyFluxConfig()
model = TinyFluxDeep(config).to("cuda", torch.bfloat16)
weights = load_file(hf_hub_download("AbstractPhil/tiny-flux-deep", "model.safetensors"))
model.load_state_dict(weights, strict=False)  # strict=False tolerates key mismatches
model.eval()
|
| 413 |
|
|
|
|
| 428 |
|
| 429 |
return t5_emb, clip_pooled
|
| 430 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 431 |
def generate_image(prompt, num_steps=25, cfg_scale=4.0, seed=None):
    """Generate a PIL image from a text prompt via Euler sampling of rectified flow.

    Flow: x_t = (1-t)*noise + t*data, velocity v = data - noise.
    Integrates from t=0 (pure noise) to t=1 (data), then decodes with the VAE.

    Args:
        prompt: text prompt to condition on.
        num_steps: number of Euler steps.
        cfg_scale: classifier-free guidance scale.
        seed: optional int for reproducible noise.

    Returns:
        PIL.Image.Image of the decoded sample.
    """
    from PIL import Image  # local import keeps the README snippet self-contained

    if seed is not None:
        torch.manual_seed(seed)

    t5_emb, clip_pooled = encode_prompt(prompt)

    # Null embeddings for CFG: encode the empty prompt
    t5_null, clip_null = encode_prompt("")

    # Start from pure noise (t=0)
    x = torch.randn(1, 64*64, 16, device="cuda", dtype=torch.bfloat16)
    img_ids = TinyFluxDeep.create_img_ids(1, 64, 64, "cuda")

    # Rectified flow schedule: 0 → 1 with Flux-style shift
    # (biases more steps toward the data end, t=1)
    def flux_shift(t, s=3.0):
        return s * t / (1 + (s - 1) * t)

    timesteps = flux_shift(torch.linspace(0, 1, num_steps + 1, device="cuda"))

    with torch.no_grad():
        for i in range(num_steps):
            t = timesteps[i].expand(1)
            dt = timesteps[i + 1] - timesteps[i]  # Positive

            # Conditional
            v_cond = model(x, t5_emb, clip_pooled, t, img_ids)

            # Unconditional
            v_uncond = model(x, t5_null, clip_null, t, img_ids)

            # CFG
            v = v_uncond + cfg_scale * (v_cond - v_uncond)

            # Euler step
            x = x + v * dt

    # Decode with VAE: unpack token sequence back into a latent image
    x = x.reshape(1, 64, 64, 16).permute(0, 3, 1, 2)  # [B, C, H, W]
    # NOTE(review): only scaling_factor is undone here; Flux-style VAEs often
    # also define a shift_factor — confirm against this VAE's config.
    x = x / vae.config.scaling_factor
    with torch.no_grad():
        image = vae.decode(x).sample

    # Convert to PIL: [-1, 1] → [0, 255] uint8 HWC
    image = (image / 2 + 0.5).clamp(0, 1)
    image = image[0].permute(1, 2, 0).cpu().float().numpy()
    image = (image * 255).astype("uint8")

    return Image.fromarray(image)
|
| 486 |
+
|
| 487 |
+
# Generate!
prompt = "a photograph of a tiger in natural habitat"
image = generate_image(prompt, seed=42)
image.save("tiger.png")
|
| 490 |
```
|