AbstractPhil committed on
Commit
3beb7ef
·
verified ·
1 Parent(s): bf792e2

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +63 -87
README.md CHANGED
@@ -94,17 +94,8 @@ def flux_shift(t, s=3.0):
94
  """Flux-style timestep shifting - biases toward data end."""
95
  return s * t / (1 + (s - 1) * t)
96
 
97
- def generate(model, t5_emb, clip_pooled, num_steps=25, cfg_scale=4.0):
98
- """
99
- Euler sampling for rectified flow.
100
-
101
- Flow matching: x_t = (1-t)*noise + t*data
102
- - t=0: pure noise
103
- - t=1: pure data
104
- - v = data - noise
105
-
106
- Integrate from t=0 → t=1
107
- """
108
  device = next(model.parameters()).device
109
  dtype = next(model.parameters()).dtype
110
 
@@ -112,22 +103,17 @@ def generate(model, t5_emb, clip_pooled, num_steps=25, cfg_scale=4.0):
112
  x = torch.randn(1, 64*64, 16, device=device, dtype=dtype)
113
  img_ids = TinyFluxDeep.create_img_ids(1, 64, 64, device)
114
 
115
- # Timesteps: 0 → 1 with Flux shift
116
- t_linear = torch.linspace(0, 1, num_steps + 1, device=device)
117
- timesteps = flux_shift(t_linear, s=3.0)
118
-
119
- # Null embeddings for CFG
120
- t5_null = torch.zeros_like(t5_emb)
121
- clip_null = torch.zeros_like(clip_pooled)
122
 
123
  for i in range(num_steps):
124
  t_curr = timesteps[i]
125
  t_next = timesteps[i + 1]
126
- dt = t_next - t_curr # Positive
127
 
128
- t_batch = t_curr.unsqueeze(0)
129
 
130
- # Predict velocity
131
  v_cond = model(
132
  hidden_states=x,
133
  encoder_hidden_states=t5_emb,
@@ -136,10 +122,11 @@ def generate(model, t5_emb, clip_pooled, num_steps=25, cfg_scale=4.0):
136
  img_ids=img_ids,
137
  )
138
 
 
139
  v_uncond = model(
140
  hidden_states=x,
141
- encoder_hidden_states=t5_null,
142
- pooled_projections=clip_null,
143
  timestep=t_batch,
144
  img_ids=img_ids,
145
  )
@@ -147,7 +134,7 @@ def generate(model, t5_emb, clip_pooled, num_steps=25, cfg_scale=4.0):
147
  # Classifier-free guidance
148
  v = v_uncond + cfg_scale * (v_cond - v_uncond)
149
 
150
- # Euler step: x_{t+dt} = x_t + v * dt
151
  x = x + v * dt
152
 
153
  return x # [1, 4096, 16] - decode with VAE
@@ -418,12 +405,9 @@ vae = AutoencoderKL.from_pretrained(
418
  model_py = hf_hub_download("AbstractPhil/tiny-flux-deep", "scripts/model_v4.py")
419
  exec(open(model_py).read())
420
 
421
- config = TinyFluxConfig(
422
- use_sol_prior=True, # Disabled until trained
423
- use_t5_vec=True, # Disabled until trained
424
- )
425
  model = TinyFluxDeep(config).to("cuda", torch.bfloat16)
426
- weights = load_file(hf_hub_download("AbstractPhil/tiny-flux-deep", "checkpoint_runs/v4_init/lailah_401434_v4_init.safetensors"))
427
  model.load_state_dict(weights, strict=False)
428
  model.eval()
429
 
@@ -444,71 +428,63 @@ def encode_prompt(prompt):
444
 
445
  return t5_emb, clip_pooled
446
 
447
-
448
- def flux_shift(t, s=3.0):
449
- """Flux-style timestep shift."""
450
- return s * t / (1 + (s - 1) * t)
451
-
452
-
453
- @torch.inference_mode()
454
  def generate_image(prompt, num_steps=25, cfg_scale=4.0, seed=None):
455
  """
456
  Euler sampling for rectified flow.
457
 
458
- Flow matching formulation:
459
- x_t = (1 - t) * noise + t * data
460
- At t=0: pure noise
461
- At t=1: pure data
462
- Velocity v = data - noise (constant)
463
-
464
- Sampling: Integrate from t=0 (noise) → t=1 (data)
465
  """
466
  if seed is not None:
467
  torch.manual_seed(seed)
468
- with torch.autocast("cuda", dtype=torch.bfloat16):
469
- t5_emb, clip_pooled = encode_prompt(prompt)
470
- t5_null, clip_null = encode_prompt("")
471
-
472
- # Start from pure noise (t=0)
473
- x = torch.randn(1, 64*64, 16, device="cuda", dtype=torch.bfloat16)
474
- img_ids = TinyFluxDeep.create_img_ids(1, 64, 64, "cuda")
475
-
476
- # Timesteps: 0 → 1 with Flux shift
477
- t_linear = torch.linspace(0, 1, num_steps + 1, device="cuda", dtype=torch.float32)
478
- timesteps = flux_shift(t_linear, s=3.0)
479
-
480
- for i in range(num_steps):
481
- t_curr = timesteps[i]
482
- t_next = timesteps[i + 1]
483
- dt = t_next - t_curr # Positive, moving toward data
484
-
485
- t_batch = t_curr.unsqueeze(0)
486
-
487
- # Predict velocity
488
- v_cond = model(x, t5_emb, clip_pooled, t_batch, img_ids)
489
- v_uncond = model(x, t5_null, clip_null, t_batch, img_ids)
490
-
491
- # Classifier-free guidance
492
- v = v_uncond + cfg_scale * (v_cond - v_uncond)
493
-
494
- # Euler step: x_{t+dt} = x_t + v * dt
495
- x = x + v * dt
496
-
497
- # Decode with VAE
498
- x = x.reshape(1, 64, 64, 16).permute(0, 3, 1, 2) # [B, C, H, W]
499
- x = x / vae.config.scaling_factor
500
- image = vae.decode(x).sample
501
-
502
- # Convert to PIL
503
- image = (image / 2 + 0.5).clamp(0, 1)
504
- image = image[0].permute(1, 2, 0).cpu().float().numpy()
505
- image = (image * 255).astype("uint8")
506
-
507
- from PIL import Image
508
- return Image.fromarray(image)
509
-
510
-
511
- # Generate
 
 
 
 
512
  image = generate_image("a photograph of a tiger in natural habitat", seed=42)
513
  image.save("tiger.png")
514
  ```
 
94
  """Flux-style timestep shifting - biases toward data end."""
95
  return s * t / (1 + (s - 1) * t)
96
 
97
+ def generate(model, t5_emb, clip_emb, clip_pooled, num_steps=25, cfg_scale=4.0):
98
+ """Euler sampling with classifier-free guidance."""
 
 
 
 
 
 
 
 
 
99
  device = next(model.parameters()).device
100
  dtype = next(model.parameters()).dtype
101
 
 
103
  x = torch.randn(1, 64*64, 16, device=device, dtype=dtype)
104
  img_ids = TinyFluxDeep.create_img_ids(1, 64, 64, device)
105
 
106
+ # Rectified flow: integrate from t=0 (noise) to t=1 (data)
107
+ timesteps = flux_shift(torch.linspace(0, 1, num_steps + 1, device=device))
 
 
 
 
 
108
 
109
  for i in range(num_steps):
110
  t_curr = timesteps[i]
111
  t_next = timesteps[i + 1]
112
+ dt = t_next - t_curr
113
 
114
+ t_batch = t_curr.expand(1)
115
 
116
+ # Conditional prediction
117
  v_cond = model(
118
  hidden_states=x,
119
  encoder_hidden_states=t5_emb,
 
122
  img_ids=img_ids,
123
  )
124
 
125
+ # Unconditional prediction (for CFG)
126
  v_uncond = model(
127
  hidden_states=x,
128
+ encoder_hidden_states=torch.zeros_like(t5_emb),
129
+ pooled_projections=torch.zeros_like(clip_pooled),
130
  timestep=t_batch,
131
  img_ids=img_ids,
132
  )
 
134
  # Classifier-free guidance
135
  v = v_uncond + cfg_scale * (v_cond - v_uncond)
136
 
137
+ # Euler step
138
  x = x + v * dt
139
 
140
  return x # [1, 4096, 16] - decode with VAE
 
405
  model_py = hf_hub_download("AbstractPhil/tiny-flux-deep", "scripts/model_v4.py")
406
  exec(open(model_py).read())
407
 
408
+ config = TinyFluxConfig()
 
 
 
409
  model = TinyFluxDeep(config).to("cuda", torch.bfloat16)
410
+ weights = load_file(hf_hub_download("AbstractPhil/tiny-flux-deep", "model.safetensors"))
411
  model.load_state_dict(weights, strict=False)
412
  model.eval()
413
 
 
428
 
429
  return t5_emb, clip_pooled
430
 
 
 
 
 
 
 
 
431
  def generate_image(prompt, num_steps=25, cfg_scale=4.0, seed=None):
432
  """
433
  Euler sampling for rectified flow.
434
 
435
+ Flow: x_t = (1-t)*noise + t*data
436
+ Integrate from t=0 (noise) to t=1 (data)
 
 
 
 
 
437
  """
438
  if seed is not None:
439
  torch.manual_seed(seed)
440
+
441
+ t5_emb, clip_pooled = encode_prompt(prompt)
442
+
443
+ # Null embeddings for CFG
444
+ t5_null, clip_null = encode_prompt("")
445
+
446
+ # Start from pure noise (t=0)
447
+ x = torch.randn(1, 64*64, 16, device="cuda", dtype=torch.bfloat16)
448
+ img_ids = TinyFluxDeep.create_img_ids(1, 64, 64, "cuda")
449
+
450
+ # Rectified flow: 0 → 1 with Flux shift
451
+ def flux_shift(t, s=3.0):
452
+ return s * t / (1 + (s - 1) * t)
453
+
454
+ timesteps = flux_shift(torch.linspace(0, 1, num_steps + 1, device="cuda"))
455
+
456
+ with torch.no_grad():
457
+ for i in range(num_steps):
458
+ t = timesteps[i].expand(1)
459
+ dt = timesteps[i + 1] - timesteps[i] # Positive
460
+
461
+ # Conditional
462
+ v_cond = model(x, t5_emb, clip_pooled, t, img_ids)
463
+
464
+ # Unconditional
465
+ v_uncond = model(x, t5_null, clip_null, t, img_ids)
466
+
467
+ # CFG
468
+ v = v_uncond + cfg_scale * (v_cond - v_uncond)
469
+
470
+ # Euler step
471
+ x = x + v * dt
472
+
473
+ # Decode with VAE
474
+ x = x.reshape(1, 64, 64, 16).permute(0, 3, 1, 2) # [B, C, H, W]
475
+ x = x / vae.config.scaling_factor
476
+ with torch.no_grad():
477
+ image = vae.decode(x).sample
478
+
479
+ # Convert to PIL
480
+ image = (image / 2 + 0.5).clamp(0, 1)
481
+ image = image[0].permute(1, 2, 0).cpu().float().numpy()
482
+ image = (image * 255).astype("uint8")
483
+
484
+ from PIL import Image
485
+ return Image.fromarray(image)
486
+
487
+ # Generate!
488
  image = generate_image("a photograph of a tiger in natural habitat", seed=42)
489
  image.save("tiger.png")
490
  ```