Update README.md
Browse files
README.md
CHANGED
|
@@ -94,26 +94,40 @@ def flux_shift(t, s=3.0):
|
|
| 94 |
"""Flux-style timestep shifting - biases toward data end."""
|
| 95 |
return s * t / (1 + (s - 1) * t)
|
| 96 |
|
| 97 |
-
def generate(model, t5_emb,
|
| 98 |
-
"""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 99 |
device = next(model.parameters()).device
|
| 100 |
dtype = next(model.parameters()).dtype
|
| 101 |
|
| 102 |
-
# Start from noise
|
| 103 |
x = torch.randn(1, 64*64, 16, device=device, dtype=dtype)
|
| 104 |
img_ids = TinyFluxDeep.create_img_ids(1, 64, 64, device)
|
| 105 |
|
| 106 |
-
# Timesteps with Flux shift
|
| 107 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 108 |
|
| 109 |
for i in range(num_steps):
|
| 110 |
t_curr = timesteps[i]
|
| 111 |
t_next = timesteps[i + 1]
|
| 112 |
-
dt = t_next - t_curr
|
| 113 |
|
| 114 |
-
t_batch = t_curr.
|
| 115 |
|
| 116 |
-
#
|
| 117 |
v_cond = model(
|
| 118 |
hidden_states=x,
|
| 119 |
encoder_hidden_states=t5_emb,
|
|
@@ -122,11 +136,10 @@ def generate(model, t5_emb, clip_emb, clip_pooled, num_steps=25, cfg_scale=4.0):
|
|
| 122 |
img_ids=img_ids,
|
| 123 |
)
|
| 124 |
|
| 125 |
-
# Unconditional prediction (for CFG)
|
| 126 |
v_uncond = model(
|
| 127 |
hidden_states=x,
|
| 128 |
-
encoder_hidden_states=
|
| 129 |
-
pooled_projections=
|
| 130 |
timestep=t_batch,
|
| 131 |
img_ids=img_ids,
|
| 132 |
)
|
|
@@ -134,7 +147,7 @@ def generate(model, t5_emb, clip_emb, clip_pooled, num_steps=25, cfg_scale=4.0):
|
|
| 134 |
# Classifier-free guidance
|
| 135 |
v = v_uncond + cfg_scale * (v_cond - v_uncond)
|
| 136 |
|
| 137 |
-
# Euler step
|
| 138 |
x = x + v * dt
|
| 139 |
|
| 140 |
return x # [1, 4096, 16] - decode with VAE
|
|
@@ -405,9 +418,12 @@ vae = AutoencoderKL.from_pretrained(
|
|
| 405 |
model_py = hf_hub_download("AbstractPhil/tiny-flux-deep", "scripts/model_v4.py")
|
| 406 |
exec(open(model_py).read())
|
| 407 |
|
| 408 |
-
config = TinyFluxConfig(
|
|
|
|
|
|
|
|
|
|
| 409 |
model = TinyFluxDeep(config).to("cuda", torch.bfloat16)
|
| 410 |
-
weights = load_file(hf_hub_download("AbstractPhil/tiny-flux-deep", "
|
| 411 |
model.load_state_dict(weights, strict=False)
|
| 412 |
model.eval()
|
| 413 |
|
|
@@ -428,58 +444,71 @@ def encode_prompt(prompt):
|
|
| 428 |
|
| 429 |
return t5_emb, clip_pooled
|
| 430 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 431 |
def generate_image(prompt, num_steps=25, cfg_scale=4.0, seed=None):
|
| 432 |
-
"""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 433 |
if seed is not None:
|
| 434 |
torch.manual_seed(seed)
|
| 435 |
-
|
| 436 |
-
|
| 437 |
-
|
| 438 |
-
|
| 439 |
-
|
| 440 |
-
|
| 441 |
-
|
| 442 |
-
|
| 443 |
-
|
| 444 |
-
|
| 445 |
-
|
| 446 |
-
|
| 447 |
-
|
| 448 |
-
|
| 449 |
-
|
| 450 |
-
|
| 451 |
-
|
| 452 |
-
|
| 453 |
-
|
| 454 |
-
|
| 455 |
-
|
| 456 |
-
|
| 457 |
-
|
| 458 |
-
|
| 459 |
-
|
| 460 |
-
|
| 461 |
-
|
| 462 |
-
|
| 463 |
-
|
| 464 |
-
|
| 465 |
-
|
| 466 |
-
|
| 467 |
-
|
| 468 |
-
|
| 469 |
-
|
| 470 |
-
|
| 471 |
-
|
| 472 |
-
|
| 473 |
-
|
| 474 |
-
|
| 475 |
-
|
| 476 |
-
|
| 477 |
-
|
| 478 |
-
|
| 479 |
-
from PIL import Image
|
| 480 |
-
return Image.fromarray(image)
|
| 481 |
-
|
| 482 |
-
# Generate!
|
| 483 |
image = generate_image("a photograph of a tiger in natural habitat", seed=42)
|
| 484 |
image.save("tiger.png")
|
| 485 |
```
|
|
|
|
| 94 |
"""Flux-style timestep shifting - biases toward data end."""
|
| 95 |
return s * t / (1 + (s - 1) * t)
|
| 96 |
|
| 97 |
+
def generate(model, t5_emb, clip_pooled, num_steps=25, cfg_scale=4.0):
|
| 98 |
+
"""
|
| 99 |
+
Euler sampling for rectified flow.
|
| 100 |
+
|
| 101 |
+
Flow matching: x_t = (1-t)*noise + t*data
|
| 102 |
+
- t=0: pure noise
|
| 103 |
+
- t=1: pure data
|
| 104 |
+
- v = data - noise
|
| 105 |
+
|
| 106 |
+
Integrate from t=0 → t=1
|
| 107 |
+
"""
|
| 108 |
device = next(model.parameters()).device
|
| 109 |
dtype = next(model.parameters()).dtype
|
| 110 |
|
| 111 |
+
# Start from pure noise (t=0)
|
| 112 |
x = torch.randn(1, 64*64, 16, device=device, dtype=dtype)
|
| 113 |
img_ids = TinyFluxDeep.create_img_ids(1, 64, 64, device)
|
| 114 |
|
| 115 |
+
# Timesteps: 0 → 1 with Flux shift
|
| 116 |
+
t_linear = torch.linspace(0, 1, num_steps + 1, device=device)
|
| 117 |
+
timesteps = flux_shift(t_linear, s=3.0)
|
| 118 |
+
|
| 119 |
+
# Null embeddings for CFG
|
| 120 |
+
t5_null = torch.zeros_like(t5_emb)
|
| 121 |
+
clip_null = torch.zeros_like(clip_pooled)
|
| 122 |
|
| 123 |
for i in range(num_steps):
|
| 124 |
t_curr = timesteps[i]
|
| 125 |
t_next = timesteps[i + 1]
|
| 126 |
+
dt = t_next - t_curr # Positive
|
| 127 |
|
| 128 |
+
t_batch = t_curr.unsqueeze(0)
|
| 129 |
|
| 130 |
+
# Predict velocity
|
| 131 |
v_cond = model(
|
| 132 |
hidden_states=x,
|
| 133 |
encoder_hidden_states=t5_emb,
|
|
|
|
| 136 |
img_ids=img_ids,
|
| 137 |
)
|
| 138 |
|
|
|
|
| 139 |
v_uncond = model(
|
| 140 |
hidden_states=x,
|
| 141 |
+
encoder_hidden_states=t5_null,
|
| 142 |
+
pooled_projections=clip_null,
|
| 143 |
timestep=t_batch,
|
| 144 |
img_ids=img_ids,
|
| 145 |
)
|
|
|
|
| 147 |
# Classifier-free guidance
|
| 148 |
v = v_uncond + cfg_scale * (v_cond - v_uncond)
|
| 149 |
|
| 150 |
+
# Euler step: x_{t+dt} = x_t + v * dt
|
| 151 |
x = x + v * dt
|
| 152 |
|
| 153 |
return x # [1, 4096, 16] - decode with VAE
|
|
|
|
| 418 |
model_py = hf_hub_download("AbstractPhil/tiny-flux-deep", "scripts/model_v4.py")
|
| 419 |
exec(open(model_py).read())
|
| 420 |
|
| 421 |
+
config = TinyFluxConfig(
|
| 422 |
+
use_sol_prior=True, # Disabled until trained
|
| 423 |
+
use_t5_vec=True, # Disabled until trained
|
| 424 |
+
)
|
| 425 |
model = TinyFluxDeep(config).to("cuda", torch.bfloat16)
|
| 426 |
+
weights = load_file(hf_hub_download("AbstractPhil/tiny-flux-deep", "checkpoint_runs/v4_init/lailah_401434_v4_init.safetensors"))
|
| 427 |
model.load_state_dict(weights, strict=False)
|
| 428 |
model.eval()
|
| 429 |
|
|
|
|
| 444 |
|
| 445 |
return t5_emb, clip_pooled
|
| 446 |
|
| 447 |
+
|
| 448 |
+
def flux_shift(t, s=3.0):
|
| 449 |
+
"""Flux-style timestep shift."""
|
| 450 |
+
return s * t / (1 + (s - 1) * t)
|
| 451 |
+
|
| 452 |
+
|
| 453 |
+
@torch.inference_mode()
|
| 454 |
def generate_image(prompt, num_steps=25, cfg_scale=4.0, seed=None):
|
| 455 |
+
"""
|
| 456 |
+
Euler sampling for rectified flow.
|
| 457 |
+
|
| 458 |
+
Flow matching formulation:
|
| 459 |
+
x_t = (1 - t) * noise + t * data
|
| 460 |
+
At t=0: pure noise
|
| 461 |
+
At t=1: pure data
|
| 462 |
+
Velocity v = data - noise (constant)
|
| 463 |
+
|
| 464 |
+
Sampling: Integrate from t=0 (noise) → t=1 (data)
|
| 465 |
+
"""
|
| 466 |
if seed is not None:
|
| 467 |
torch.manual_seed(seed)
|
| 468 |
+
with torch.autocast("cuda", dtype=torch.bfloat16):
|
| 469 |
+
t5_emb, clip_pooled = encode_prompt(prompt)
|
| 470 |
+
t5_null, clip_null = encode_prompt("")
|
| 471 |
+
|
| 472 |
+
# Start from pure noise (t=0)
|
| 473 |
+
x = torch.randn(1, 64*64, 16, device="cuda", dtype=torch.bfloat16)
|
| 474 |
+
img_ids = TinyFluxDeep.create_img_ids(1, 64, 64, "cuda")
|
| 475 |
+
|
| 476 |
+
# Timesteps: 0 → 1 with Flux shift
|
| 477 |
+
t_linear = torch.linspace(0, 1, num_steps + 1, device="cuda", dtype=torch.float32)
|
| 478 |
+
timesteps = flux_shift(t_linear, s=3.0)
|
| 479 |
+
|
| 480 |
+
for i in range(num_steps):
|
| 481 |
+
t_curr = timesteps[i]
|
| 482 |
+
t_next = timesteps[i + 1]
|
| 483 |
+
dt = t_next - t_curr # Positive, moving toward data
|
| 484 |
+
|
| 485 |
+
t_batch = t_curr.unsqueeze(0)
|
| 486 |
+
|
| 487 |
+
# Predict velocity
|
| 488 |
+
v_cond = model(x, t5_emb, clip_pooled, t_batch, img_ids)
|
| 489 |
+
v_uncond = model(x, t5_null, clip_null, t_batch, img_ids)
|
| 490 |
+
|
| 491 |
+
# Classifier-free guidance
|
| 492 |
+
v = v_uncond + cfg_scale * (v_cond - v_uncond)
|
| 493 |
+
|
| 494 |
+
# Euler step: x_{t+dt} = x_t + v * dt
|
| 495 |
+
x = x + v * dt
|
| 496 |
+
|
| 497 |
+
# Decode with VAE
|
| 498 |
+
x = x.reshape(1, 64, 64, 16).permute(0, 3, 1, 2) # [B, C, H, W]
|
| 499 |
+
x = x / vae.config.scaling_factor
|
| 500 |
+
image = vae.decode(x).sample
|
| 501 |
+
|
| 502 |
+
# Convert to PIL
|
| 503 |
+
image = (image / 2 + 0.5).clamp(0, 1)
|
| 504 |
+
image = image[0].permute(1, 2, 0).cpu().float().numpy()
|
| 505 |
+
image = (image * 255).astype("uint8")
|
| 506 |
+
|
| 507 |
+
from PIL import Image
|
| 508 |
+
return Image.fromarray(image)
|
| 509 |
+
|
| 510 |
+
|
| 511 |
+
# Generate
|
|
|
|
|
|
|
|
|
|
|
|
|
| 512 |
image = generate_image("a photograph of a tiger in natural habitat", seed=42)
|
| 513 |
image.save("tiger.png")
|
| 514 |
```
|