AbstractPhil commited on
Commit
f6fc133
·
verified ·
1 Parent(s): b9c4369

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +92 -63
README.md CHANGED
@@ -94,26 +94,40 @@ def flux_shift(t, s=3.0):
94
  """Flux-style timestep shifting - biases toward data end."""
95
  return s * t / (1 + (s - 1) * t)
96
 
97
- def generate(model, t5_emb, clip_emb, clip_pooled, num_steps=25, cfg_scale=4.0):
98
- """Euler sampling with classifier-free guidance."""
 
 
 
 
 
 
 
 
 
99
  device = next(model.parameters()).device
100
  dtype = next(model.parameters()).dtype
101
 
102
- # Start from noise
103
  x = torch.randn(1, 64*64, 16, device=device, dtype=dtype)
104
  img_ids = TinyFluxDeep.create_img_ids(1, 64, 64, device)
105
 
106
- # Timesteps with Flux shift
107
- timesteps = flux_shift(torch.linspace(1, 0, num_steps + 1, device=device))
 
 
 
 
 
108
 
109
  for i in range(num_steps):
110
  t_curr = timesteps[i]
111
  t_next = timesteps[i + 1]
112
- dt = t_next - t_curr
113
 
114
- t_batch = t_curr.expand(1)
115
 
116
- # Conditional prediction
117
  v_cond = model(
118
  hidden_states=x,
119
  encoder_hidden_states=t5_emb,
@@ -122,11 +136,10 @@ def generate(model, t5_emb, clip_emb, clip_pooled, num_steps=25, cfg_scale=4.0):
122
  img_ids=img_ids,
123
  )
124
 
125
- # Unconditional prediction (for CFG)
126
  v_uncond = model(
127
  hidden_states=x,
128
- encoder_hidden_states=torch.zeros_like(t5_emb),
129
- pooled_projections=torch.zeros_like(clip_pooled),
130
  timestep=t_batch,
131
  img_ids=img_ids,
132
  )
@@ -134,7 +147,7 @@ def generate(model, t5_emb, clip_emb, clip_pooled, num_steps=25, cfg_scale=4.0):
134
  # Classifier-free guidance
135
  v = v_uncond + cfg_scale * (v_cond - v_uncond)
136
 
137
- # Euler step
138
  x = x + v * dt
139
 
140
  return x # [1, 4096, 16] - decode with VAE
@@ -405,9 +418,12 @@ vae = AutoencoderKL.from_pretrained(
405
  model_py = hf_hub_download("AbstractPhil/tiny-flux-deep", "scripts/model_v4.py")
406
  exec(open(model_py).read())
407
 
408
- config = TinyFluxConfig()
 
 
 
409
  model = TinyFluxDeep(config).to("cuda", torch.bfloat16)
410
- weights = load_file(hf_hub_download("AbstractPhil/tiny-flux-deep", "model.safetensors"))
411
  model.load_state_dict(weights, strict=False)
412
  model.eval()
413
 
@@ -428,58 +444,71 @@ def encode_prompt(prompt):
428
 
429
  return t5_emb, clip_pooled
430
 
 
 
 
 
 
 
 
431
  def generate_image(prompt, num_steps=25, cfg_scale=4.0, seed=None):
432
- """Generate image from text prompt."""
 
 
 
 
 
 
 
 
 
 
433
  if seed is not None:
434
  torch.manual_seed(seed)
435
-
436
- t5_emb, clip_pooled = encode_prompt(prompt)
437
-
438
- # Null embeddings for CFG
439
- t5_null, clip_null = encode_prompt("")
440
-
441
- # Start from noise
442
- x = torch.randn(1, 64*64, 16, device="cuda", dtype=torch.bfloat16)
443
- img_ids = TinyFluxDeep.create_img_ids(1, 64, 64, "cuda")
444
-
445
- # Flux-shifted timesteps
446
- def flux_shift(t, s=3.0):
447
- return s * t / (1 + (s - 1) * t)
448
-
449
- timesteps = flux_shift(torch.linspace(1, 0, num_steps + 1, device="cuda"))
450
-
451
- with torch.no_grad():
452
- for i in range(num_steps):
453
- t = timesteps[i].expand(1)
454
- dt = timesteps[i + 1] - timesteps[i]
455
-
456
- # Conditional
457
- v_cond = model(x, t5_emb, clip_pooled, t, img_ids)
458
-
459
- # Unconditional
460
- v_uncond = model(x, t5_null, clip_null, t, img_ids)
461
-
462
- # CFG
463
- v = v_uncond + cfg_scale * (v_cond - v_uncond)
464
-
465
- # Euler step
466
- x = x + v * dt
467
-
468
- # Decode with VAE
469
- x = x.reshape(1, 64, 64, 16).permute(0, 3, 1, 2) # [B, C, H, W]
470
- x = x / vae.config.scaling_factor
471
- with torch.no_grad():
472
- image = vae.decode(x).sample
473
-
474
- # Convert to PIL
475
- image = (image / 2 + 0.5).clamp(0, 1)
476
- image = image[0].permute(1, 2, 0).cpu().float().numpy()
477
- image = (image * 255).astype("uint8")
478
-
479
- from PIL import Image
480
- return Image.fromarray(image)
481
-
482
- # Generate!
483
  image = generate_image("a photograph of a tiger in natural habitat", seed=42)
484
  image.save("tiger.png")
485
  ```
 
94
  """Flux-style timestep shifting - biases toward data end."""
95
  return s * t / (1 + (s - 1) * t)
96
 
97
+ def generate(model, t5_emb, clip_pooled, num_steps=25, cfg_scale=4.0):
98
+ """
99
+ Euler sampling for rectified flow.
100
+
101
+ Flow matching: x_t = (1-t)*noise + t*data
102
+ - t=0: pure noise
103
+ - t=1: pure data
104
+ - v = data - noise
105
+
106
+ Integrate from t=0 → t=1
107
+ """
108
  device = next(model.parameters()).device
109
  dtype = next(model.parameters()).dtype
110
 
111
+ # Start from pure noise (t=0)
112
  x = torch.randn(1, 64*64, 16, device=device, dtype=dtype)
113
  img_ids = TinyFluxDeep.create_img_ids(1, 64, 64, device)
114
 
115
+ # Timesteps: 0 → 1 with Flux shift
116
+ t_linear = torch.linspace(0, 1, num_steps + 1, device=device)
117
+ timesteps = flux_shift(t_linear, s=3.0)
118
+
119
+ # Null embeddings for CFG
120
+ t5_null = torch.zeros_like(t5_emb)
121
+ clip_null = torch.zeros_like(clip_pooled)
122
 
123
  for i in range(num_steps):
124
  t_curr = timesteps[i]
125
  t_next = timesteps[i + 1]
126
+ dt = t_next - t_curr # Positive
127
 
128
+ t_batch = t_curr.unsqueeze(0)
129
 
130
+ # Predict velocity
131
  v_cond = model(
132
  hidden_states=x,
133
  encoder_hidden_states=t5_emb,
 
136
  img_ids=img_ids,
137
  )
138
 
 
139
  v_uncond = model(
140
  hidden_states=x,
141
+ encoder_hidden_states=t5_null,
142
+ pooled_projections=clip_null,
143
  timestep=t_batch,
144
  img_ids=img_ids,
145
  )
 
147
  # Classifier-free guidance
148
  v = v_uncond + cfg_scale * (v_cond - v_uncond)
149
 
150
+ # Euler step: x_{t+dt} = x_t + v * dt
151
  x = x + v * dt
152
 
153
  return x # [1, 4096, 16] - decode with VAE
 
418
  model_py = hf_hub_download("AbstractPhil/tiny-flux-deep", "scripts/model_v4.py")
419
  exec(open(model_py).read())
420
 
421
+ config = TinyFluxConfig(
422
+ use_sol_prior=True, # NOTE(review): value is True but comment said "Disabled until trained" — confirm intended setting
423
+ use_t5_vec=True, # NOTE(review): value is True but comment said "Disabled until trained" — confirm intended setting
424
+ )
425
  model = TinyFluxDeep(config).to("cuda", torch.bfloat16)
426
+ weights = load_file(hf_hub_download("AbstractPhil/tiny-flux-deep", "checkpoint_runs/v4_init/lailah_401434_v4_init.safetensors"))
427
  model.load_state_dict(weights, strict=False)
428
  model.eval()
429
 
 
444
 
445
  return t5_emb, clip_pooled
446
 
447
+
448
+ def flux_shift(t, s=3.0):
449
+ """Flux-style timestep shift."""
450
+ return s * t / (1 + (s - 1) * t)
451
+
452
+
453
+ @torch.inference_mode()
454
  def generate_image(prompt, num_steps=25, cfg_scale=4.0, seed=None):
455
+ """
456
+ Euler sampling for rectified flow.
457
+
458
+ Flow matching formulation:
459
+ x_t = (1 - t) * noise + t * data
460
+ At t=0: pure noise
461
+ At t=1: pure data
462
+ Velocity v = data - noise (constant)
463
+
464
+ Sampling: Integrate from t=0 (noise) → t=1 (data)
465
+ """
466
  if seed is not None:
467
  torch.manual_seed(seed)
468
+ with torch.autocast("cuda", dtype=torch.bfloat16):
469
+ t5_emb, clip_pooled = encode_prompt(prompt)
470
+ t5_null, clip_null = encode_prompt("")
471
+
472
+ # Start from pure noise (t=0)
473
+ x = torch.randn(1, 64*64, 16, device="cuda", dtype=torch.bfloat16)
474
+ img_ids = TinyFluxDeep.create_img_ids(1, 64, 64, "cuda")
475
+
476
+ # Timesteps: 0 → 1 with Flux shift
477
+ t_linear = torch.linspace(0, 1, num_steps + 1, device="cuda", dtype=torch.float32)
478
+ timesteps = flux_shift(t_linear, s=3.0)
479
+
480
+ for i in range(num_steps):
481
+ t_curr = timesteps[i]
482
+ t_next = timesteps[i + 1]
483
+ dt = t_next - t_curr # Positive, moving toward data
484
+
485
+ t_batch = t_curr.unsqueeze(0)
486
+
487
+ # Predict velocity
488
+ v_cond = model(x, t5_emb, clip_pooled, t_batch, img_ids)
489
+ v_uncond = model(x, t5_null, clip_null, t_batch, img_ids)
490
+
491
+ # Classifier-free guidance
492
+ v = v_uncond + cfg_scale * (v_cond - v_uncond)
493
+
494
+ # Euler step: x_{t+dt} = x_t + v * dt
495
+ x = x + v * dt
496
+
497
+ # Decode with VAE
498
+ x = x.reshape(1, 64, 64, 16).permute(0, 3, 1, 2) # [B, C, H, W]
499
+ x = x / vae.config.scaling_factor
500
+ image = vae.decode(x).sample
501
+
502
+ # Convert to PIL
503
+ image = (image / 2 + 0.5).clamp(0, 1)
504
+ image = image[0].permute(1, 2, 0).cpu().float().numpy()
505
+ image = (image * 255).astype("uint8")
506
+
507
+ from PIL import Image
508
+ return Image.fromarray(image)
509
+
510
+
511
+ # Generate
 
 
 
 
512
  image = generate_image("a photograph of a tiger in natural habitat", seed=42)
513
  image.save("tiger.png")
514
  ```