Spaces:
Sleeping
Sleeping
trixyL committed on
Commit ·
776256b
1
Parent(s): d4ac762
refactor: single sample, intermediate steps
Browse files
app.py
CHANGED
|
@@ -2,7 +2,7 @@ import torch
|
|
| 2 |
import gradio as gr
|
| 3 |
import spaces
|
| 4 |
|
| 5 |
-
from model import
|
| 6 |
|
| 7 |
MODEL_READY = False
|
| 8 |
|
|
@@ -16,31 +16,35 @@ def ensure_model_loaded():
|
|
| 16 |
|
| 17 |
@spaces.GPU
|
| 18 |
@torch.inference_mode()
|
| 19 |
-
def predict(label: int, steps: int
|
| 20 |
ensure_model_loaded()
|
| 21 |
-
|
|
|
|
|
|
|
|
|
|
| 22 |
|
| 23 |
|
| 24 |
with gr.Blocks(title="MNIST Diffusion") as demo:
|
| 25 |
gr.Markdown("# MNIST Diffusion")
|
| 26 |
gr.Markdown(
|
| 27 |
"Discrete diffusion model for MNIST digits. "
|
| 28 |
-
"
|
|
|
|
| 29 |
)
|
| 30 |
|
| 31 |
-
grid = gr.Image(label="
|
|
|
|
| 32 |
|
| 33 |
with gr.Row():
|
| 34 |
label = gr.Dropdown([str(i) for i in range(10)], value="4", label="Label")
|
| 35 |
-
steps = gr.Slider(
|
| 36 |
-
num_samples = gr.Slider(1, 36, value=16, step=1, label="Samples")
|
| 37 |
|
| 38 |
generate_btn = gr.Button("Generate")
|
| 39 |
|
| 40 |
generate_btn.click(
|
| 41 |
fn=predict,
|
| 42 |
-
inputs=[label, steps
|
| 43 |
-
outputs=grid,
|
| 44 |
scroll_to_output=True,
|
| 45 |
)
|
| 46 |
|
|
|
|
| 2 |
import gradio as gr
|
| 3 |
import spaces
|
| 4 |
|
| 5 |
+
from model import iter_trajectory_frames, load_model
|
| 6 |
|
| 7 |
MODEL_READY = False
|
| 8 |
|
|
|
|
| 16 |
|
| 17 |
@spaces.GPU
|
| 18 |
@torch.inference_mode()
|
| 19 |
+
def predict(label: int, steps: int):
|
| 20 |
ensure_model_loaded()
|
| 21 |
+
for idx, (image, step_idx, total_steps, total) in enumerate(
|
| 22 |
+
iter_trajectory_frames(label=label, steps=steps), start=1
|
| 23 |
+
):
|
| 24 |
+
yield image, f"trajectory checkpoint {idx}/{total} | denoising step {step_idx}/{total_steps}"
|
| 25 |
|
| 26 |
|
| 27 |
with gr.Blocks(title="MNIST Diffusion") as demo:
|
| 28 |
gr.Markdown("# MNIST Diffusion")
|
| 29 |
gr.Markdown(
|
| 30 |
"Discrete diffusion model for MNIST digits. "
|
| 31 |
+
"The demo streams one sample as masked tokens are resolved with fixed CFG=2.0, "
|
| 32 |
+
"temperature=0.6, and top_p=0.99."
|
| 33 |
)
|
| 34 |
|
| 35 |
+
grid = gr.Image(label="Trajectory", show_label=True)
|
| 36 |
+
status = gr.Textbox(label="Status")
|
| 37 |
|
| 38 |
with gr.Row():
|
| 39 |
label = gr.Dropdown([str(i) for i in range(10)], value="4", label="Label")
|
| 40 |
+
steps = gr.Slider(32, 784, value=784, step=1, label="Steps")
|
|
|
|
| 41 |
|
| 42 |
generate_btn = gr.Button("Generate")
|
| 43 |
|
| 44 |
generate_btn.click(
|
| 45 |
fn=predict,
|
| 46 |
+
inputs=[label, steps],
|
| 47 |
+
outputs=[grid, status],
|
| 48 |
scroll_to_output=True,
|
| 49 |
)
|
| 50 |
|
model.py
CHANGED
|
@@ -40,6 +40,7 @@ INFER_CONFIG = {
|
|
| 40 |
"top_p": 0.99,
|
| 41 |
"cfg_scale": 2.0,
|
| 42 |
"remasking": "random",
|
|
|
|
| 43 |
}
|
| 44 |
|
| 45 |
DTYPES = {
|
|
@@ -779,6 +780,18 @@ def generate_images(label: int, steps: int, num_samples: int) -> List[Image.Imag
|
|
| 779 |
return images
|
| 780 |
|
| 781 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 782 |
def _grid_dims(num_samples: int) -> Tup[int, int]:
|
| 783 |
cols = int(np.ceil(np.sqrt(num_samples)))
|
| 784 |
rows = int(np.ceil(num_samples / cols))
|
|
@@ -799,3 +812,83 @@ def generate_grid_image(label: int, steps: int, num_samples: int) -> Image.Image
|
|
| 799 |
c = idx % cols
|
| 800 |
grid.paste(img, (c * w, r * h))
|
| 801 |
return grid
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 40 |
"top_p": 0.99,
|
| 41 |
"cfg_scale": 2.0,
|
| 42 |
"remasking": "random",
|
| 43 |
+
"trajectory_checkpoints": 32,
|
| 44 |
}
|
| 45 |
|
| 46 |
DTYPES = {
|
|
|
|
| 780 |
return images
|
| 781 |
|
| 782 |
|
| 783 |
+
def _to_image(tokens: torch.Tensor) -> Image.Image:
|
| 784 |
+
h = int(MODEL_CONFIG["image_height"])
|
| 785 |
+
w = int(MODEL_CONFIG["image_width"])
|
| 786 |
+
pixel_bins = int(MODEL_CONFIG["pixel_bins"])
|
| 787 |
+
scale = 10
|
| 788 |
+
arr = tokens.detach().cpu().to(torch.int32).numpy().reshape(h, w)
|
| 789 |
+
img = Image.fromarray(dequantize_tokens_to_uint8(arr, pixel_bins=pixel_bins), mode="L")
|
| 790 |
+
if scale > 1:
|
| 791 |
+
img = img.resize((w * scale, h * scale), resample=Image.NEAREST)
|
| 792 |
+
return img
|
| 793 |
+
|
| 794 |
+
|
| 795 |
def _grid_dims(num_samples: int) -> Tup[int, int]:
|
| 796 |
cols = int(np.ceil(np.sqrt(num_samples)))
|
| 797 |
rows = int(np.ceil(num_samples / cols))
|
|
|
|
| 812 |
c = idx % cols
|
| 813 |
grid.paste(img, (c * w, r * h))
|
| 814 |
return grid
|
| 815 |
+
|
| 816 |
+
|
| 817 |
+
@torch.inference_mode()
|
| 818 |
+
def iter_trajectory_frames(label: int, steps: int):
|
| 819 |
+
model, device, _ = load_model()
|
| 820 |
+
label = int(label)
|
| 821 |
+
steps = max(32, int(steps))
|
| 822 |
+
batch_size = 1
|
| 823 |
+
context = torch.full((batch_size,), label, device=device, dtype=torch.long)
|
| 824 |
+
prompt_len = 0
|
| 825 |
+
gen_length = int(MODEL_CONFIG["context_length"])
|
| 826 |
+
block_length = int(INFER_CONFIG["block_length"])
|
| 827 |
+
total_len = prompt_len + gen_length
|
| 828 |
+
blocks = max(1, int(np.ceil(gen_length / block_length)))
|
| 829 |
+
if steps < blocks:
|
| 830 |
+
steps = blocks
|
| 831 |
+
base_steps = steps // blocks
|
| 832 |
+
extra_steps = steps % blocks
|
| 833 |
+
|
| 834 |
+
x = torch.full((batch_size, total_len), fill_value=int(MODEL_CONFIG["mask_token_id"]), device=device, dtype=torch.long)
|
| 835 |
+
uncond_context = torch.full((batch_size,), int(MODEL_CONFIG["null_label_id"]), device=device, dtype=torch.long)
|
| 836 |
+
|
| 837 |
+
checkpoint_count = max(32, int(INFER_CONFIG["trajectory_checkpoints"]))
|
| 838 |
+
checkpoint_indices = np.linspace(1, steps, num=checkpoint_count, dtype=int).tolist()
|
| 839 |
+
checkpoint_indices = sorted(set(max(1, min(steps, idx)) for idx in checkpoint_indices))
|
| 840 |
+
checkpoint_set = set(checkpoint_indices)
|
| 841 |
+
captured = 0
|
| 842 |
+
global_step = 0
|
| 843 |
+
|
| 844 |
+
for block_idx in range(blocks):
|
| 845 |
+
block_start = prompt_len + block_idx * block_length
|
| 846 |
+
block_end = min(block_start + block_length, total_len)
|
| 847 |
+
block_steps = base_steps + (1 if block_idx < extra_steps else 0)
|
| 848 |
+
if block_steps <= 0:
|
| 849 |
+
block_steps = 1
|
| 850 |
+
block_mask = x[:, block_start:block_end] == int(MODEL_CONFIG["mask_token_id"])
|
| 851 |
+
transfer_counts = compute_transfer_schedule(block_mask, block_steps)
|
| 852 |
+
|
| 853 |
+
for step_idx in range(block_steps):
|
| 854 |
+
global_step += 1
|
| 855 |
+
mask_index = x == int(MODEL_CONFIG["mask_token_id"])
|
| 856 |
+
cfg_scale = float(INFER_CONFIG["cfg_scale"])
|
| 857 |
+
if cfg_scale > 0.0:
|
| 858 |
+
cond_logits = model(x, context=context)
|
| 859 |
+
uncond_logits = model(x, context=uncond_context)
|
| 860 |
+
logits = uncond_logits + (cfg_scale + 1.0) * (cond_logits - uncond_logits)
|
| 861 |
+
else:
|
| 862 |
+
logits = model(x, context=context)
|
| 863 |
+
|
| 864 |
+
probs = softmax(logits, dim=-1)
|
| 865 |
+
probs = top_p_filter(probs, float(INFER_CONFIG["top_p"]))
|
| 866 |
+
logits = torch.where(probs > 0, logits, torch.full_like(logits, float("-inf")))
|
| 867 |
+
|
| 868 |
+
logits_with_noise = add_gumbel_noise(logits, float(INFER_CONFIG["temperature"]), generator=None)
|
| 869 |
+
predictions = torch.argmax(logits_with_noise, dim=-1)
|
| 870 |
+
predictions = torch.where(mask_index, predictions, x)
|
| 871 |
+
|
| 872 |
+
confidence = torch.rand((batch_size, total_len), device=device, dtype=torch.float32)
|
| 873 |
+
confidence[:, block_end:] = float("-inf")
|
| 874 |
+
confidence = torch.where(mask_index, confidence, torch.full_like(confidence, float("-inf")))
|
| 875 |
+
|
| 876 |
+
transfer_mask = torch.zeros_like(mask_index)
|
| 877 |
+
for b in range(batch_size):
|
| 878 |
+
k = int(transfer_counts[b, step_idx].item())
|
| 879 |
+
if k <= 0:
|
| 880 |
+
continue
|
| 881 |
+
available = confidence[b] > float("-inf")
|
| 882 |
+
available_count = int(available.sum().item())
|
| 883 |
+
if available_count == 0:
|
| 884 |
+
continue
|
| 885 |
+
if available_count < k:
|
| 886 |
+
k = available_count
|
| 887 |
+
topk_indices = torch.topk(confidence[b], k=k, dim=-1).indices
|
| 888 |
+
transfer_mask[b, topk_indices] = True
|
| 889 |
+
|
| 890 |
+
x = torch.where(transfer_mask, predictions, x)
|
| 891 |
+
|
| 892 |
+
if global_step in checkpoint_set:
|
| 893 |
+
captured += 1
|
| 894 |
+
yield _to_image(x[0]), global_step, steps, len(checkpoint_indices)
|