trixyL committed on
Commit
776256b
·
1 Parent(s): d4ac762

refactor: single sample, intermediate steps

Browse files
Files changed (2) hide show
  1. app.py +13 -9
  2. model.py +93 -0
app.py CHANGED
@@ -2,7 +2,7 @@ import torch
2
  import gradio as gr
3
  import spaces
4
 
5
- from model import generate_grid_image, load_model
6
 
7
  MODEL_READY = False
8
 
@@ -16,31 +16,35 @@ def ensure_model_loaded():
16
 
17
  @spaces.GPU
18
  @torch.inference_mode()
19
- def predict(label: int, steps: int, num_samples: int):
20
  ensure_model_loaded()
21
- return generate_grid_image(label=label, steps=steps, num_samples=num_samples)
 
 
 
22
 
23
 
24
  with gr.Blocks(title="MNIST Diffusion") as demo:
25
  gr.Markdown("# MNIST Diffusion")
26
  gr.Markdown(
27
  "Discrete diffusion model for MNIST digits. "
28
- "Sampling uses fixed CFG=2.0, temperature=0.6, top_p=0.99."
 
29
  )
30
 
31
- grid = gr.Image(label="Samples", show_label=True)
 
32
 
33
  with gr.Row():
34
  label = gr.Dropdown([str(i) for i in range(10)], value="4", label="Label")
35
- steps = gr.Slider(1, 784, value=784, step=1, label="Steps")
36
- num_samples = gr.Slider(1, 36, value=16, step=1, label="Samples")
37
 
38
  generate_btn = gr.Button("Generate")
39
 
40
  generate_btn.click(
41
  fn=predict,
42
- inputs=[label, steps, num_samples],
43
- outputs=grid,
44
  scroll_to_output=True,
45
  )
46
 
 
2
  import gradio as gr
3
  import spaces
4
 
5
+ from model import iter_trajectory_frames, load_model
6
 
7
  MODEL_READY = False
8
 
 
16
 
17
@spaces.GPU
@torch.inference_mode()
def predict(label: int, steps: int):
    """Stream one sample's denoising trajectory to the Gradio UI.

    Lazily loads the model, then relays each checkpoint frame produced by
    ``iter_trajectory_frames`` as an ``(image, status_text)`` pair so Gradio
    can render the sample as it is progressively unmasked.
    """
    ensure_model_loaded()
    frame_no = 0
    for image, step_idx, total_steps, total in iter_trajectory_frames(label=label, steps=steps):
        frame_no += 1
        status = f"trajectory checkpoint {frame_no}/{total} | denoising step {step_idx}/{total_steps}"
        yield image, status
25
 
26
 
27
  with gr.Blocks(title="MNIST Diffusion") as demo:
28
  gr.Markdown("# MNIST Diffusion")
29
  gr.Markdown(
30
  "Discrete diffusion model for MNIST digits. "
31
+ "The demo streams one sample as masked tokens are resolved with fixed CFG=2.0, "
32
+ "temperature=0.6, and top_p=0.99."
33
  )
34
 
35
+ grid = gr.Image(label="Trajectory", show_label=True)
36
+ status = gr.Textbox(label="Status")
37
 
38
  with gr.Row():
39
  label = gr.Dropdown([str(i) for i in range(10)], value="4", label="Label")
40
+ steps = gr.Slider(32, 784, value=784, step=1, label="Steps")
 
41
 
42
  generate_btn = gr.Button("Generate")
43
 
44
  generate_btn.click(
45
  fn=predict,
46
+ inputs=[label, steps],
47
+ outputs=[grid, status],
48
  scroll_to_output=True,
49
  )
50
 
model.py CHANGED
@@ -40,6 +40,7 @@ INFER_CONFIG = {
40
  "top_p": 0.99,
41
  "cfg_scale": 2.0,
42
  "remasking": "random",
 
43
  }
44
 
45
  DTYPES = {
@@ -779,6 +780,18 @@ def generate_images(label: int, steps: int, num_samples: int) -> List[Image.Imag
779
  return images
780
 
781
 
 
 
 
 
 
 
 
 
 
 
 
 
782
  def _grid_dims(num_samples: int) -> Tup[int, int]:
783
  cols = int(np.ceil(np.sqrt(num_samples)))
784
  rows = int(np.ceil(num_samples / cols))
@@ -799,3 +812,83 @@ def generate_grid_image(label: int, steps: int, num_samples: int) -> Image.Image
799
  c = idx % cols
800
  grid.paste(img, (c * w, r * h))
801
  return grid
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
40
  "top_p": 0.99,
41
  "cfg_scale": 2.0,
42
  "remasking": "random",
43
+ "trajectory_checkpoints": 32,
44
  }
45
 
46
  DTYPES = {
 
780
  return images
781
 
782
 
783
def _to_image(tokens: torch.Tensor, scale: int = 10) -> Image.Image:
    """Render a tensor of quantized pixel tokens as a greyscale PIL image.

    Args:
        tokens: Tensor of pixel-bin token ids with ``image_height * image_width``
            elements (reshaped to ``(h, w)``; moved to CPU and cast to int32).
        scale: Nearest-neighbour display upscaling factor. Defaults to 10,
            matching the previously hard-coded value; ``scale <= 1`` returns
            the image at native resolution.

    Returns:
        An ``L``-mode (8-bit greyscale) ``PIL.Image.Image``.
    """
    h = int(MODEL_CONFIG["image_height"])
    w = int(MODEL_CONFIG["image_width"])
    pixel_bins = int(MODEL_CONFIG["pixel_bins"])
    # Dequantize token bins back to uint8 intensities before building the image.
    arr = tokens.detach().cpu().to(torch.int32).numpy().reshape(h, w)
    img = Image.fromarray(dequantize_tokens_to_uint8(arr, pixel_bins=pixel_bins), mode="L")
    if scale > 1:
        # NEAREST keeps the blocky per-token look instead of blurring pixels.
        img = img.resize((w * scale, h * scale), resample=Image.NEAREST)
    return img
793
+
794
+
795
  def _grid_dims(num_samples: int) -> Tup[int, int]:
796
  cols = int(np.ceil(np.sqrt(num_samples)))
797
  rows = int(np.ceil(num_samples / cols))
 
812
  c = idx % cols
813
  grid.paste(img, (c * w, r * h))
814
  return grid
815
+
816
+
817
@torch.inference_mode()
def iter_trajectory_frames(label: int, steps: int):
    """Generate one class-conditional sample and yield trajectory snapshots.

    Runs block-wise masked-token denoising: the full sequence starts as mask
    tokens and is resolved block by block over ``steps`` total denoising
    steps. At ~``trajectory_checkpoints`` evenly spaced steps, yields a tuple
    ``(image, global_step, steps, num_checkpoints)`` where ``image`` is the
    current (partially denoised) sample rendered via ``_to_image``.

    Args:
        label: Class label id used as the conditional context (cast to int).
        steps: Requested number of denoising steps; clamped to at least 32
            and to at least one step per block.
    """
    model, device, _ = load_model()
    label = int(label)
    # Floor of 32 steps mirrors the UI slider's minimum.
    steps = max(32, int(steps))
    batch_size = 1
    context = torch.full((batch_size,), label, device=device, dtype=torch.long)
    # No text/prompt prefix: the whole sequence is generated pixels.
    prompt_len = 0
    gen_length = int(MODEL_CONFIG["context_length"])
    block_length = int(INFER_CONFIG["block_length"])
    total_len = prompt_len + gen_length
    blocks = max(1, int(np.ceil(gen_length / block_length)))
    if steps < blocks:
        # Every block needs at least one denoising step.
        steps = blocks
    # Distribute steps across blocks; the first `extra_steps` blocks get one extra.
    base_steps = steps // blocks
    extra_steps = steps % blocks

    # Start fully masked; uncond_context is the null label used for the CFG branch.
    x = torch.full((batch_size, total_len), fill_value=int(MODEL_CONFIG["mask_token_id"]), device=device, dtype=torch.long)
    uncond_context = torch.full((batch_size,), int(MODEL_CONFIG["null_label_id"]), device=device, dtype=torch.long)

    # Pick ~checkpoint_count evenly spaced global steps at which to emit a frame.
    # linspace can repeat indices when steps is small, hence the sorted-set dedupe.
    checkpoint_count = max(32, int(INFER_CONFIG["trajectory_checkpoints"]))
    checkpoint_indices = np.linspace(1, steps, num=checkpoint_count, dtype=int).tolist()
    checkpoint_indices = sorted(set(max(1, min(steps, idx)) for idx in checkpoint_indices))
    checkpoint_set = set(checkpoint_indices)
    captured = 0
    global_step = 0

    for block_idx in range(blocks):
        block_start = prompt_len + block_idx * block_length
        block_end = min(block_start + block_length, total_len)
        block_steps = base_steps + (1 if block_idx < extra_steps else 0)
        if block_steps <= 0:
            block_steps = 1
        # How many tokens of this block to commit at each of its steps.
        block_mask = x[:, block_start:block_end] == int(MODEL_CONFIG["mask_token_id"])
        transfer_counts = compute_transfer_schedule(block_mask, block_steps)

        for step_idx in range(block_steps):
            global_step += 1
            mask_index = x == int(MODEL_CONFIG["mask_token_id"])
            cfg_scale = float(INFER_CONFIG["cfg_scale"])
            if cfg_scale > 0.0:
                # Classifier-free guidance: two forward passes, then extrapolate
                # from the unconditional logits toward the conditional ones.
                # NOTE(review): the (cfg_scale + 1.0) multiplier means cfg_scale=0
                # via this branch would still equal the conditional logits.
                cond_logits = model(x, context=context)
                uncond_logits = model(x, context=uncond_context)
                logits = uncond_logits + (cfg_scale + 1.0) * (cond_logits - uncond_logits)
            else:
                logits = model(x, context=context)

            # Nucleus filtering: zero-probability tokens get -inf logits so the
            # Gumbel argmax below can never select them.
            probs = softmax(logits, dim=-1)
            probs = top_p_filter(probs, float(INFER_CONFIG["top_p"]))
            logits = torch.where(probs > 0, logits, torch.full_like(logits, float("-inf")))

            # Gumbel-noise argmax ~ temperature sampling; already-resolved
            # positions keep their current tokens.
            logits_with_noise = add_gumbel_noise(logits, float(INFER_CONFIG["temperature"]), generator=None)
            predictions = torch.argmax(logits_with_noise, dim=-1)
            predictions = torch.where(mask_index, predictions, x)

            # "random" remasking (see INFER_CONFIG): candidate order is a uniform
            # random score, not model confidence. Positions after the current
            # block, and non-masked positions, are excluded with -inf.
            confidence = torch.rand((batch_size, total_len), device=device, dtype=torch.float32)
            confidence[:, block_end:] = float("-inf")
            confidence = torch.where(mask_index, confidence, torch.full_like(confidence, float("-inf")))

            # Commit the scheduled number of tokens per batch row (batch_size is
            # 1 here, but the loop keeps the general shape).
            transfer_mask = torch.zeros_like(mask_index)
            for b in range(batch_size):
                k = int(transfer_counts[b, step_idx].item())
                if k <= 0:
                    continue
                # -inf > -inf is False, so only eligible positions count.
                available = confidence[b] > float("-inf")
                available_count = int(available.sum().item())
                if available_count == 0:
                    continue
                if available_count < k:
                    k = available_count
                topk_indices = torch.topk(confidence[b], k=k, dim=-1).indices
                transfer_mask[b, topk_indices] = True

            x = torch.where(transfer_mask, predictions, x)

            if global_step in checkpoint_set:
                captured += 1
                # Frame plus progress metadata for the UI status line.
                yield _to_image(x[0]), global_step, steps, len(checkpoint_indices)