trixyL committed
Commit 3617dac · Parent(s): 886ef62

dump: initial dump
Files changed:
- README.md +4 -4
- app.py +47 -0
- model.py +776 -0
- model/model.safetensors +3 -0
- requirements.txt +7 -0
README.md
CHANGED
@@ -1,14 +1,14 @@
 ---
-title:
-emoji:
+title: MNIST Diffusion (TransformerLM)
+emoji: 🧪
 colorFrom: red
-colorTo:
+colorTo: yellow
 sdk: gradio
 sdk_version: 6.5.1
 python_version: '3.12'
 app_file: app.py
 pinned: false
-short_description:
+short_description: Discrete diffusion MNIST digit generation
 ---
 
 Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py
ADDED
@@ -0,0 +1,47 @@
+import torch
+import gradio as gr
+import spaces
+
+from model import generate_images, load_model
+
+MODEL_READY = False
+
+
+def ensure_model_loaded():
+    global MODEL_READY
+    if not MODEL_READY:
+        load_model()
+        MODEL_READY = True
+
+
+@spaces.GPU
+@torch.inference_mode()
+def predict(label: int, steps: int, num_samples: int):
+    ensure_model_loaded()
+    return generate_images(label=label, steps=steps, num_samples=num_samples)
+
+
+with gr.Blocks(title="MNIST Diffusion") as demo:
+    gr.Markdown("# MNIST Diffusion")
+    gr.Markdown(
+        "Discrete diffusion model for MNIST digits. "
+        "Sampling uses fixed CFG=2.0, temperature=0.6, top_p=0.99."
+    )
+
+    gallery = gr.Gallery(label="Samples", show_label=True, columns=6, rows=3, height=360)
+
+    with gr.Row():
+        label = gr.Dropdown([str(i) for i in range(10)], value="6", label="Label")
+        steps = gr.Slider(1, 784, value=784, step=1, label="Steps")
+        num_samples = gr.Slider(1, 36, value=16, step=1, label="Samples")
+
+    generate_btn = gr.Button("Generate")
+
+    generate_btn.click(
+        fn=predict,
+        inputs=[label, steps, num_samples],
+        outputs=gallery,
+    )
+
+if __name__ == "__main__":
+    demo.launch()
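
For orientation (not part of the commit): the UI above only forwards to load_model / generate_images from model.py, so a minimal local smoke test can bypass Gradio entirely. A sketch, assuming the checkpoint is in place; the save path is illustrative:

    # Sketch: exercise the same entry points app.py wires to the UI.
    from model import load_model, generate_images

    load_model()  # loads the safetensors checkpoint once and caches it globally
    images = generate_images(label=6, steps=64, num_samples=4)
    images[0].save("sample_0.png")  # PIL images, 28x28 grayscale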
model.py
ADDED
@@ -0,0 +1,776 @@
+from __future__ import annotations
+
+import os
+from typing import Tuple, List
+
+import numpy as np
+import torch
+from PIL import Image
+from safetensors.torch import load_file
+from einops import einsum, rearrange
+
+
+BASE_DIR = os.path.dirname(os.path.abspath(__file__))
+CHECKPOINT_PATH = os.path.join(BASE_DIR, "model", "model.safetensors")
+
+MODEL_CONFIG = {
+    "model_type": "image",
+    "label_vocab_size": 11,
+    "vocab_size": 33,
+    "pixel_bins": 32,
+    "context_length": 784,
+    "d_model": 256,
+    "num_layers": 8,
+    "num_heads": 16,
+    "d_ff": 1024,
+    "rope_theta": 10000.0,
+    "attention_backend": "torch_sdpa",
+    "attention_sdp_backend": "auto",
+    "device": "cuda",
+    "dtype": "float16",
+    "mask_token_id": 32,
+    "null_label_id": 10,
+    "image_height": 28,
+    "image_width": 28,
+}
+
+INFER_CONFIG = {
+    "block_length": 784,
+    "temperature": 0.6,
+    "top_p": 0.99,
+    "cfg_scale": 2.0,
+    "remasking": "random",
+}
+
+DTYPES = {
+    "float16": torch.float16,
+    "float32": torch.float32,
+    "bfloat16": torch.bfloat16,
+}
+
+
+def _resolve_device_dtype(device: str, dtype_name: str) -> Tuple[str, torch.dtype]:
+    resolved_device = device
+    if device == "cuda" and not torch.cuda.is_available():
+        resolved_device = "cpu"
+
+    resolved_dtype = DTYPES[dtype_name]
+    if resolved_device == "cpu" and resolved_dtype == torch.float16:
+        resolved_dtype = torch.float32
+
+    return resolved_device, resolved_dtype
+
+
+def set_sdp_backend(backend: str) -> None:
+    backend = backend.lower()
+    allowed = {"auto", "flash", "mem_efficient", "math"}
+    if backend not in allowed:
+        raise ValueError(f"attention_sdp_backend must be one of {sorted(allowed)}")
+    if not torch.cuda.is_available():
+        return
+    if backend == "auto":
+        torch.backends.cuda.enable_flash_sdp(True)
+        torch.backends.cuda.enable_mem_efficient_sdp(True)
+        torch.backends.cuda.enable_math_sdp(True)
+        return
+    torch.backends.cuda.enable_flash_sdp(backend == "flash")
+    torch.backends.cuda.enable_mem_efficient_sdp(backend == "mem_efficient")
+    torch.backends.cuda.enable_math_sdp(backend == "math")
+
+
+class Linear(torch.nn.Module):
+    def __init__(self, in_features, out_features, device=None, dtype=None):
+        super().__init__()
+        self.weight = torch.nn.Parameter(torch.empty(out_features, in_features, device=device, dtype=dtype))
+        mean = 0.0
+        std = 2 / (in_features + out_features)
+        a = mean - 3 * std
+        b = mean + 3 * std
+        torch.nn.init.trunc_normal_(self.weight, mean=mean, std=std, a=a, b=b)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        y = einsum(self.weight, x, "out_features in_features, ... in_features -> ... out_features")
+        return y
+
+
+class Embedding(torch.nn.Module):
+    def __init__(self, num_embeddings, embedding_dim, device=None, dtype=None):
+        super().__init__()
+        self.num_embeddings = num_embeddings
+        self.embedding_dim = embedding_dim
+        self.weight = torch.nn.Parameter(torch.empty(num_embeddings, embedding_dim, device=device, dtype=dtype))
+        torch.nn.init.trunc_normal_(self.weight, mean=0, std=1, a=-3, b=3)
+
+    def forward(self, token_ids: torch.Tensor) -> torch.Tensor:
+        embeds = self.weight[token_ids]
+        return embeds
+
+
+class RMSNorm(torch.nn.Module):
+    def __init__(self, d_model: int, eps: float = 1e-5, device=None, dtype=None):
+        super().__init__()
+        self.eps = eps
+        self.d_model = d_model
+        self.weight = torch.nn.Parameter(torch.empty(d_model, device=device, dtype=dtype))
+        torch.nn.init.ones_(self.weight)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        in_dtype = x.dtype
+        x = x.to(torch.float32)
+        rms = torch.sqrt(torch.mean(x ** 2, dim=-1) + self.eps).unsqueeze(-1)
+        x = (1 / rms) * (x * self.weight)
+        return x.to(in_dtype)
+
+
+class SwiGLU(torch.nn.Module):
+    def __init__(self, d_model: int, d_ff: int, device=None, dtype=None):
+        super().__init__()
+        self.w1 = Linear(d_model, d_ff, device=device, dtype=dtype)
+        self.w2 = Linear(d_ff, d_model, device=device, dtype=dtype)
+        self.w3 = Linear(d_model, d_ff, device=device, dtype=dtype)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        w1x = self.w1(x)
+        w3x = self.w3(x)
+        silu = w1x * torch.sigmoid(w1x)
+        glu = silu * w3x
+        w2x = self.w2(glu)
+        return w2x
+
+
+def softmax(x: torch.Tensor, dim: int):
+    x_max = x.max(dim=dim, keepdim=True).values
+    x_stable = x - x_max
+    exp_x = torch.exp(x_stable)
+    sum_exp_x = exp_x.sum(dim=dim, keepdim=True)
+    return exp_x / sum_exp_x
+
+
+def top_p_filter(probs: torch.Tensor, p: float) -> torch.Tensor:
+    if probs.dim() < 2:
+        raise ValueError("probs must have at least 2 dimensions")
+    orig_shape = probs.shape
+    vocab = orig_shape[-1]
+    probs = probs.reshape(-1, vocab)
+    if p <= 0:
+        argmax = probs.argmax(dim=-1)
+        out = torch.zeros_like(probs)
+        out.scatter_(-1, argmax.unsqueeze(-1), 1.0)
+        return out.reshape(orig_shape)
+    if p >= 1:
+        return (probs / probs.sum(dim=-1, keepdim=True)).reshape(orig_shape)
+
+    sorted_probs, sorted_indices = torch.sort(probs, descending=True)
+    cumulative = torch.cumsum(sorted_probs, dim=-1)
+
+    keep = cumulative <= p
+    keep[..., 0] = True
+    first_ge = (cumulative >= p).float().argmax(dim=-1)
+    rows = torch.arange(keep.shape[0], device=keep.device)
+    keep[rows, first_ge] = True
+
+    filtered_sorted = torch.where(keep, sorted_probs, torch.zeros_like(sorted_probs))
+    norm = filtered_sorted.sum(dim=-1, keepdim=True).clamp_min(1e-12)
+    filtered_sorted = filtered_sorted / norm
+
+    filtered = torch.zeros_like(probs)
+    filtered.scatter_(dim=-1, index=sorted_indices, src=filtered_sorted)
+    return filtered.reshape(orig_shape)
+
+
+def add_gumbel_noise(logits: torch.Tensor, temperature: float, *, generator: torch.Generator | None = None) -> torch.Tensor:
+    if temperature <= 0:
+        return logits
+
+    noise = torch.rand(logits.shape, device=logits.device, dtype=torch.float64, generator=generator)
+    gumbel_noise = (-torch.log(noise)) ** temperature
+    logits64 = logits.to(torch.float64)
+    perturbed = logits64.exp() / gumbel_noise
+    return perturbed.to(logits.dtype)
+
+
+def compute_transfer_schedule(mask: torch.Tensor, steps: int) -> torch.Tensor:
+    if steps <= 0:
+        raise ValueError("steps must be > 0")
+    if mask.dim() != 2:
+        raise ValueError("mask must be 2D (batch, block_length)")
+
+    counts = mask.sum(dim=1, keepdim=True).to(torch.int64)
+    base = counts // steps
+    remainder = counts % steps
+
+    schedule = base.expand(-1, steps).clone()
+    for idx in range(schedule.size(0)):
+        r = remainder[idx, 0].item()
+        if r > 0:
+            schedule[idx, :r] += 1
+    return schedule
+
+
+def _prepare_attention_mask(attention_mask: torch.Tensor, ref_tensor: torch.Tensor) -> torch.Tensor:
+    mask = attention_mask.to(device=ref_tensor.device, dtype=torch.bool)
+    if mask.dim() == 2:
+        mask = mask[:, None, None, :]
+    elif mask.dim() == 3:
+        mask = mask[:, None, :, :]
+    elif mask.dim() != 4:
+        raise ValueError("attention_mask must be 2D, 3D, or 4D")
+    return mask
+
+
+def scaled_dot_product_attention(
+    Q: torch.Tensor,
+    K: torch.Tensor,
+    V: torch.Tensor,
+    attention_mask: torch.Tensor | None = None,
+):
+    scale = torch.tensor(Q.shape[-1], device=Q.device, dtype=Q.dtype).sqrt()
+    qk_score = einsum(Q, K, "batch_size ... n d_k, batch_size ... m d_k -> batch_size ... n m") / scale
+    if attention_mask is not None:
+        mask = _prepare_attention_mask(attention_mask, qk_score)
+        qk_score = qk_score.masked_fill(~mask, float("-inf"))
+    softmax_qk_score = softmax(qk_score, dim=-1)
+    attn = einsum(softmax_qk_score, V, "batch_size ... n m, batch_size ... m d_k -> batch_size ... n d_k")
+    return attn
+
+
+def torch_scaled_dot_product_attention(
+    Q: torch.Tensor,
+    K: torch.Tensor,
+    V: torch.Tensor,
+    attention_mask: torch.Tensor | None = None,
+):
+    Q = Q.contiguous()
+    K = K.contiguous()
+    V = V.contiguous()
+    mask = None
+    if attention_mask is not None:
+        mask = _prepare_attention_mask(attention_mask, Q)
+    return torch.nn.functional.scaled_dot_product_attention(Q, K, V, attn_mask=mask, dropout_p=0.0, is_causal=False)
+
+
+class RotaryPositionalEmbedding(torch.nn.Module):
+    def __init__(self, theta: float, d_k: int, max_seq_len: int, device=None):
+        super().__init__()
+        self.device = device
+
+        theta_i = theta ** (torch.arange(0, d_k, 2).float() / d_k)
+        position = torch.arange(max_seq_len)
+
+        phases = position.unsqueeze(1) / theta_i.unsqueeze(0)
+        phases_cos = torch.cos(phases)
+        phases_sin = torch.sin(phases)
+        phases_combined = torch.stack([phases_cos, phases_sin], dim=-1).to(device=device)
+
+        self.register_buffer("phases", phases_combined, persistent=False)
+
+    def forward(self, x: torch.Tensor, token_positions: torch.Tensor) -> torch.Tensor:
+        x = rearrange(x, "... (d_k p) -> ... d_k p", p=2)
+        x1 = x[..., 0]
+        x2 = x[..., 1]
+
+        phases_cos = self.phases[..., 0][token_positions].to(dtype=x.dtype)
+        phases_sin = self.phases[..., 1][token_positions].to(dtype=x.dtype)
+
+        x_rotated = torch.stack([
+            x1 * phases_cos - x2 * phases_sin,
+            x1 * phases_sin + x2 * phases_cos,
+        ], dim=-1)
+
+        return x_rotated.flatten(-2)
+
+
+class MultiheadSelfAttentionRoPE(torch.nn.Module):
+    def __init__(
+        self,
+        d_model: int,
+        num_heads: int,
+        max_seq_len: int,
+        theta: float,
+        attention_backend: str = "custom",
+        device=None,
+        dtype=None,
+    ):
+        super().__init__()
+        self.d_model = d_model
+        self.num_heads = num_heads
+        self.d_k = self.d_model // self.num_heads
+        self.d_v = self.d_k
+        self.max_seq_len = max_seq_len
+        self.theta = theta
+        if attention_backend not in {"custom", "torch_sdpa"}:
+            raise ValueError("attention_backend must be one of ['custom', 'torch_sdpa']")
+        self.attention_backend = attention_backend
+
+        self.q_proj = Linear(self.d_model, self.d_model, device=device, dtype=dtype)
+        self.k_proj = Linear(self.d_model, self.d_model, device=device, dtype=dtype)
+        self.v_proj = Linear(self.d_model, self.d_model, device=device, dtype=dtype)
+        self.output_proj = Linear(self.d_model, self.d_model, device=device, dtype=dtype)
+
+        self.rope = RotaryPositionalEmbedding(self.theta, self.d_k, self.max_seq_len, device)
+
+    def forward(
+        self,
+        x: torch.Tensor,
+        token_positions: torch.Tensor,
+        attention_mask: torch.Tensor | None = None,
+    ) -> torch.Tensor:
+        wqx = self.q_proj(x)
+        wqx_rearr = rearrange(wqx, "... seq_len (num_heads d_k) -> ... num_heads seq_len d_k", num_heads=self.num_heads, d_k=self.d_k)
+        wqx_rearr_rope = self.rope(wqx_rearr, token_positions)
+
+        wkx = self.k_proj(x)
+        wkx_rearr = rearrange(wkx, "... seq_len (num_heads d_k) -> ... num_heads seq_len d_k", num_heads=self.num_heads, d_k=self.d_k)
+        wkx_rearr_rope = self.rope(wkx_rearr, token_positions)
+
+        wvx = self.v_proj(x)
+        wvx_rearr = rearrange(wvx, "... seq_len (num_heads d_v) -> ... num_heads seq_len d_v", num_heads=self.num_heads, d_v=self.d_v)
+
+        if self.attention_backend == "torch_sdpa":
+            attn = torch_scaled_dot_product_attention(
+                wqx_rearr_rope,
+                wkx_rearr_rope,
+                wvx_rearr,
+                attention_mask=attention_mask,
+            )
+        else:
+            attn = scaled_dot_product_attention(
+                wqx_rearr_rope,
+                wkx_rearr_rope,
+                wvx_rearr,
+                attention_mask=attention_mask,
+            )
+        attn_rearr = rearrange(attn, "... num_heads seq_len d_v -> ... seq_len (num_heads d_v)", num_heads=self.num_heads, d_v=self.d_v)
+        attn_rearr_proj = self.output_proj(attn_rearr)
+        return attn_rearr_proj
+
+
+class MultiheadCrossAttentionRoPE(torch.nn.Module):
+    def __init__(
+        self,
+        d_model: int,
+        num_heads: int,
+        max_seq_len: int,
+        theta: float,
+        attention_backend: str = "custom",
+        device=None,
+        dtype=None,
+    ):
+        super().__init__()
+        self.d_model = d_model
+        self.num_heads = num_heads
+        self.d_k = self.d_model // self.num_heads
+        self.d_v = self.d_k
+        self.max_seq_len = max_seq_len
+        self.theta = theta
+        if attention_backend not in {"custom", "torch_sdpa"}:
+            raise ValueError("attention_backend must be one of ['custom', 'torch_sdpa']")
+        self.attention_backend = attention_backend
+
+        self.q_proj = Linear(self.d_model, self.d_model, device=device, dtype=dtype)
+        self.k_proj = Linear(self.d_model, self.d_model, device=device, dtype=dtype)
+        self.v_proj = Linear(self.d_model, self.d_model, device=device, dtype=dtype)
+        self.output_proj = Linear(self.d_model, self.d_model, device=device, dtype=dtype)
+
+        self.rope = RotaryPositionalEmbedding(self.theta, self.d_k, self.max_seq_len, device)
+
+    def forward(
+        self,
+        x: torch.Tensor,
+        context: torch.Tensor,
+        token_positions: torch.Tensor,
+        context_token_positions: torch.Tensor,
+        attention_mask: torch.Tensor | None = None,
+    ) -> torch.Tensor:
+        wqx = self.q_proj(x)
+        wqx_rearr = rearrange(wqx, "... seq_len (num_heads d_k) -> ... num_heads seq_len d_k", num_heads=self.num_heads, d_k=self.d_k)
+        wqx_rearr_rope = self.rope(wqx_rearr, token_positions)
+
+        wkx = self.k_proj(context)
+        wkx_rearr = rearrange(wkx, "... seq_len (num_heads d_k) -> ... num_heads seq_len d_k", num_heads=self.num_heads, d_k=self.d_k)
+        wkx_rearr_rope = self.rope(wkx_rearr, context_token_positions)
+
+        wvx = self.v_proj(context)
+        wvx_rearr = rearrange(wvx, "... seq_len (num_heads d_v) -> ... num_heads seq_len d_v", num_heads=self.num_heads, d_v=self.d_v)
+
+        if self.attention_backend == "torch_sdpa":
+            attn = torch_scaled_dot_product_attention(
+                wqx_rearr_rope,
+                wkx_rearr_rope,
+                wvx_rearr,
+                attention_mask=attention_mask,
+            )
+        else:
+            attn = scaled_dot_product_attention(
+                wqx_rearr_rope,
+                wkx_rearr_rope,
+                wvx_rearr,
+                attention_mask=attention_mask,
+            )
+        attn_rearr = rearrange(attn, "... num_heads seq_len d_v -> ... seq_len (num_heads d_v)", num_heads=self.num_heads, d_v=self.d_v)
+        attn_rearr_proj = self.output_proj(attn_rearr)
+        return attn_rearr_proj
+
+
+class TransformerImageBlock(torch.nn.Module):
+    def __init__(
+        self,
+        d_model: int,
+        num_heads: int,
+        max_seq_len: int,
+        theta: float,
+        d_ff: int,
+        attention_backend: str = "custom",
+        device=None,
+        dtype=None,
+    ):
+        super().__init__()
+        self.ffn = SwiGLU(d_model, d_ff, device, dtype)
+        self.self_attn = MultiheadSelfAttentionRoPE(
+            d_model,
+            num_heads,
+            max_seq_len,
+            theta,
+            attention_backend=attention_backend,
+            device=device,
+            dtype=dtype,
+        )
+        self.cross_attn = MultiheadCrossAttentionRoPE(
+            d_model,
+            num_heads,
+            max_seq_len,
+            theta,
+            attention_backend=attention_backend,
+            device=device,
+            dtype=dtype,
+        )
+        self.ln1 = RMSNorm(d_model, device=device, dtype=dtype)
+        self.ln2 = RMSNorm(d_model, device=device, dtype=dtype)
+        self.ln3 = RMSNorm(d_model, device=device, dtype=dtype)
+
+    def forward(
+        self,
+        x: torch.Tensor,
+        token_positions: torch.Tensor,
+        context: torch.Tensor,
+        context_token_positions: torch.Tensor,
+        attention_mask: torch.Tensor | None = None,
+    ) -> torch.Tensor:
+        ln1x = self.ln1(x)
+        x = x + self.self_attn(ln1x, token_positions, attention_mask=attention_mask)
+        ln2x = self.ln2(x)
+        x = x + self.cross_attn(
+            ln2x,
+            context,
+            token_positions,
+            context_token_positions,
+            attention_mask=None,
+        )
+        ln3x = self.ln3(x)
+        x = x + self.ffn(ln3x)
+        return x
+
+
+class TransformerImage(torch.nn.Module):
+    def __init__(
+        self,
+        vocab_size: int,
+        context_length: int,
+        d_model: int,
+        num_layers: int,
+        num_heads: int,
+        d_ff: int,
+        rope_theta: float,
+        label_vocab_size: int,
+        attention_backend: str = "custom",
+        device=None,
+        dtype=None,
+    ):
+        super().__init__()
+        self.context_length = context_length
+        self.token_embeddings = Embedding(vocab_size, d_model, device, dtype)
+        self.label_embeddings = Embedding(label_vocab_size, d_model, device, dtype)
+        self.layers = torch.nn.ModuleList(
+            [
+                TransformerImageBlock(
+                    d_model,
+                    num_heads,
+                    context_length,
+                    rope_theta,
+                    d_ff,
+                    attention_backend=attention_backend,
+                    device=device,
+                    dtype=dtype,
+                )
+                for _ in range(num_layers)
+            ]
+        )
+        self.ln_final = RMSNorm(d_model, device=device, dtype=dtype)
+        self.lm_head = Linear(d_model, vocab_size, device, dtype)
+
+    def forward(
+        self,
+        in_indices: torch.Tensor,
+        attention_mask: torch.Tensor | None = None,
+        context: torch.Tensor | None = None,
+    ) -> torch.Tensor:
+        if context is None:
+            raise ValueError("context must be provided for TransformerImage")
+        output_seq = self.token_embeddings(in_indices)
+        context_emb = self.label_embeddings(context).unsqueeze(-2)
+        token_positions = torch.arange(output_seq.shape[-2], device=output_seq.device, dtype=torch.long)
+        context_token_positions = torch.arange(context_emb.shape[-2], device=output_seq.device, dtype=torch.long)
+        for layer in self.layers:
+            output_seq = layer(
+                output_seq,
+                token_positions,
+                context_emb,
+                context_token_positions,
+                attention_mask=attention_mask,
+            )
+        normed_output_seq = self.ln_final(output_seq)
+        logits = self.lm_head(normed_output_seq)
+        return logits
+
+
+@torch.no_grad()
+def image_diffusion_generate(
+    model,
+    prompt_indices: torch.Tensor,
+    *,
+    context: torch.Tensor,
+    mask_id: int,
+    eos_token_id: int | None = None,
+    steps: int,
+    gen_length: int,
+    block_length: int,
+    temperature: float = 0.0,
+    top_p: float | None = None,
+    cfg_scale: float = 0.0,
+    uncond_context: torch.Tensor | None = None,
+    remasking: str = "random",
+    logits_eos_inf: bool = False,
+    confidence_eos_eot_inf: bool = False,
+    generator: torch.Generator | None = None,
+) -> torch.Tensor:
+    if prompt_indices.dim() != 2:
+        raise ValueError("prompt_indices must be 2D (batch, seq)")
+    if context.dim() != 1:
+        raise ValueError("context must be 1D (batch,)")
+    if prompt_indices.shape[0] != context.shape[0]:
+        raise ValueError("context batch size must match prompt batch size")
+    if block_length <= 0:
+        raise ValueError("block_length must be > 0")
+    if steps <= 0:
+        raise ValueError("steps must be > 0")
+
+    if gen_length <= 0:
+        return prompt_indices
+
+    blocks = max(1, int(np.ceil(gen_length / block_length)))
+    if steps < blocks:
+        raise ValueError("steps must be >= number of blocks")
+    base_steps = steps // blocks
+    extra_steps = steps % blocks
+
+    device = prompt_indices.device
+    batch_size, prompt_len = prompt_indices.shape
+    total_len = prompt_len + gen_length
+
+    context_limit = getattr(model, "context_length", None)
+    if context_limit is not None and total_len > int(context_limit):
+        raise ValueError("prompt length + gen_length exceeds model context_length")
+
+    x = torch.full(
+        (batch_size, total_len),
+        fill_value=mask_id,
+        device=device,
+        dtype=prompt_indices.dtype,
+    )
+    x[:, :prompt_len] = prompt_indices
+
+    if uncond_context is not None:
+        if uncond_context.dim() != 1:
+            raise ValueError("uncond_context must be 1D (batch,)")
+        if uncond_context.shape[0] != batch_size:
+            raise ValueError("uncond_context batch size must match prompt batch size")
+        uncond_context = uncond_context.to(device=context.device, dtype=context.dtype)
+
+    for block_idx in range(blocks):
+        block_start = prompt_len + block_idx * block_length
+        block_end = min(block_start + block_length, total_len)
+        block_steps = base_steps + (1 if block_idx < extra_steps else 0)
+        if block_steps <= 0:
+            block_steps = 1
+        block_mask = (x[:, block_start:block_end] == mask_id)
+        transfer_counts = compute_transfer_schedule(block_mask, block_steps)
+
+        for step_idx in range(block_steps):
+            mask_index = (x == mask_id)
+            if cfg_scale > 0.0:
+                if uncond_context is None:
+                    raise ValueError("uncond_context must be set when cfg_scale > 0 for image_diffusion_generate")
+                cond_logits = model(x, context=context)
+                uncond_logits = model(x, context=uncond_context)
+                logits = uncond_logits + (cfg_scale + 1.0) * (cond_logits - uncond_logits)
+            else:
+                logits = model(x, context=context)
+
+            if logits_eos_inf and eos_token_id is not None:
+                logits[:, :, eos_token_id] = float("-inf")
+
+            if top_p is not None:
+                probs = softmax(logits, dim=-1)
+                probs = top_p_filter(probs, float(top_p))
+                logits = torch.where(
+                    probs > 0,
+                    logits,
+                    torch.full_like(logits, float("-inf")),
+                )
+
+            logits_with_noise = add_gumbel_noise(logits, temperature, generator=generator)
+            predictions = torch.argmax(logits_with_noise, dim=-1)
+            predictions = torch.where(mask_index, predictions, x)
+
+            if remasking == "low_confidence":
+                probs = softmax(logits, dim=-1)
+                confidence = torch.squeeze(
+                    torch.gather(probs, dim=-1, index=torch.unsqueeze(predictions, -1)),
+                    -1,
+                )
+            elif remasking == "random":
+                confidence = torch.rand(
+                    (batch_size, total_len),
+                    device=device,
+                    dtype=torch.float32,
+                    generator=generator,
+                )
+            else:
+                raise ValueError(f"Unsupported remasking strategy: {remasking}")
+
+            if confidence_eos_eot_inf and eos_token_id is not None:
+                confidence = torch.where(
+                    predictions == eos_token_id,
+                    torch.full_like(confidence, float("-inf")),
+                    confidence,
+                )
+
+            confidence[:, block_end:] = float("-inf")
+            confidence = torch.where(mask_index, confidence, torch.full_like(confidence, float("-inf")))
+
+            transfer_mask = torch.zeros_like(mask_index)
+            for b in range(batch_size):
+                k = int(transfer_counts[b, step_idx].item())
+                if k <= 0:
+                    continue
+                available = confidence[b] > float("-inf")
+                available_count = int(available.sum().item())
+                if available_count == 0:
+                    continue
+                if available_count < k:
+                    k = available_count
+                topk_indices = torch.topk(confidence[b], k=k, dim=-1).indices
+                transfer_mask[b, topk_indices] = True
+
+            x = torch.where(transfer_mask, predictions, x)
+
+    return x
+
+
+def dequantize_tokens_to_uint8(tokens: np.ndarray, *, pixel_bins: int) -> np.ndarray:
+    if pixel_bins == 256:
+        return tokens.astype(np.uint8)
+    vals = np.clip(tokens.astype(np.int32), 0, int(pixel_bins) - 1)
+    scale = 256.0 / float(pixel_bins)
+    restored = np.round((vals + 0.5) * scale - 0.5)
+    return np.clip(restored, 0, 255).astype(np.uint8)
+
+
+MODEL = None
+DEVICE = None
+DTYPE = None
+
+
+def load_model():
+    global MODEL, DEVICE, DTYPE
+    if MODEL is not None:
+        return MODEL, DEVICE, DTYPE
+
+    if not os.path.exists(CHECKPOINT_PATH):
+        raise FileNotFoundError(f"Missing checkpoint at {CHECKPOINT_PATH}")
+
+    device, dtype = _resolve_device_dtype(MODEL_CONFIG["device"], MODEL_CONFIG["dtype"])
+    set_sdp_backend(MODEL_CONFIG["attention_sdp_backend"])
+
+    model = TransformerImage(
+        vocab_size=MODEL_CONFIG["vocab_size"],
+        context_length=MODEL_CONFIG["context_length"],
+        d_model=MODEL_CONFIG["d_model"],
+        num_layers=MODEL_CONFIG["num_layers"],
+        num_heads=MODEL_CONFIG["num_heads"],
+        d_ff=MODEL_CONFIG["d_ff"],
+        rope_theta=MODEL_CONFIG["rope_theta"],
+        label_vocab_size=MODEL_CONFIG["label_vocab_size"],
+        attention_backend=MODEL_CONFIG["attention_backend"],
+        device=device,
+        dtype=dtype,
+    )
+
+    model_state = load_file(CHECKPOINT_PATH)
+    model.load_state_dict(model_state)
+    model.eval().to(device)
+
+    MODEL = model
+    DEVICE = device
+    DTYPE = dtype
+    return MODEL, DEVICE, DTYPE
+
+
+@torch.inference_mode()
+def generate_images(label: int, steps: int, num_samples: int) -> List[Image.Image]:
+    model, device, _ = load_model()
+
+    num_samples = int(num_samples)
+    label = int(label)
+    steps = int(steps)
+
+    context = torch.full((num_samples,), label, device=device, dtype=torch.long)
+    prompt = torch.empty((num_samples, 0), device=device, dtype=torch.long)
+
+    cfg_scale = float(INFER_CONFIG["cfg_scale"])
+    uncond_context = None
+    if cfg_scale > 0.0:
+        null_label_id = int(MODEL_CONFIG["null_label_id"])
+        uncond_context = torch.full((num_samples,), null_label_id, device=device, dtype=torch.long)
+
+    out_indices = image_diffusion_generate(
+        model,
+        prompt,
+        context=context,
+        mask_id=int(MODEL_CONFIG["mask_token_id"]),
+        eos_token_id=None,
+        steps=steps,
+        gen_length=int(MODEL_CONFIG["context_length"]),
+        block_length=int(INFER_CONFIG["block_length"]),
+        temperature=float(INFER_CONFIG["temperature"]),
+        top_p=float(INFER_CONFIG["top_p"]),
+        cfg_scale=cfg_scale,
+        uncond_context=uncond_context,
+        remasking=str(INFER_CONFIG["remasking"]),
+        logits_eos_inf=False,
+        confidence_eos_eot_inf=False,
+        generator=None,
+    )
+
+    h = int(MODEL_CONFIG["image_height"])
+    w = int(MODEL_CONFIG["image_width"])
+    pixel_bins = int(MODEL_CONFIG["pixel_bins"])
+
+    images: List[Image.Image] = []
+    for i in range(num_samples):
+        tokens = out_indices[i].detach().cpu().to(torch.int32).numpy().reshape(h, w)
+        arr = dequantize_tokens_to_uint8(tokens, pixel_bins=pixel_bins)
+        img = Image.fromarray(arr, mode="L")
+        images.append(img)
+
+    return images
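
A quick worked check of the pixel codec above (a sketch, not part of the commit): with pixel_bins=32 each token covers 8 gray levels, and dequantize_tokens_to_uint8 maps a token v back to roughly its bucket midpoint via round((v + 0.5) * 256/32 - 0.5):

    # Sketch: round-trip the 32-bin pixel quantization used by generate_images.
    import numpy as np
    from model import dequantize_tokens_to_uint8

    tokens = np.array([[0, 1, 16, 31]])  # valid pixel tokens are 0..31; 32 is the mask token
    print(dequantize_tokens_to_uint8(tokens, pixel_bins=32))
    # -> [[  4  12 132 252]]  (midpoints of the 8-level buckets)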
model/model.safetensors
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:f863ca7bfd2fc11fc6cf4f3df57567655a43bf4cf9ccaa66f254ed6ed248c9e0
+size 42058920
requirements.txt
ADDED
@@ -0,0 +1,7 @@
+gradio
+spaces
+torch
+einops
+safetensors
+numpy
+pillow
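
Usage note (not part of the commit): with these unpinned dependencies installed, e.g. via pip install -r requirements.txt, running python app.py should serve the demo locally; the spaces package only supplies the @spaces.GPU decorator, which is expected to be a no-op outside ZeroGPU hardware.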