# mnist-diff-demo / model.py — author: trixyL
# exp: single grid image (commit 7c3029a)
from __future__ import annotations
import os
from typing import Tuple, List, Tuple as Tup
import numpy as np
import torch
from PIL import Image
from safetensors.torch import load_file
from einops import einsum, rearrange
# Resolve the checkpoint path relative to this file so imports work from any CWD.
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
CHECKPOINT_PATH = os.path.join(BASE_DIR, "model", "model.safetensors")

# Architecture and runtime settings matching the trained checkpoint.
MODEL_CONFIG = {
    "model_type": "image",
    "label_vocab_size": 11,  # digits 0-9 plus the null label (id 10) used for CFG
    "vocab_size": 33,  # 32 pixel-bin tokens + 1 mask token (id 32)
    "pixel_bins": 32,  # pixel intensities quantized into 32 bins
    "context_length": 784,  # 28 * 28 pixel tokens per image
    "d_model": 256,
    "num_layers": 8,
    "num_heads": 16,
    "d_ff": 1024,
    "rope_theta": 10000.0,
    "attention_backend": "torch_sdpa",  # "custom" (reference einsum) or "torch_sdpa" (fused)
    "attention_sdp_backend": "auto",
    "device": "cuda",  # falls back to CPU in _resolve_device_dtype if CUDA is absent
    "dtype": "float16",  # promoted to float32 when running on CPU
    "mask_token_id": 32,
    "null_label_id": 10,  # unconditional label for classifier-free guidance
    "image_height": 28,
    "image_width": 28,
}

# Default sampling hyperparameters used by generate_images().
INFER_CONFIG = {
    "block_length": 784,  # equals context_length -> the whole image is one block
    "temperature": 0.6,
    "top_p": 0.99,
    "cfg_scale": 2.0,  # classifier-free guidance strength (0 disables CFG)
    "remasking": "random",  # or "low_confidence"
}

# Config dtype-name -> torch dtype lookup used by _resolve_device_dtype().
DTYPES = {
    "float16": torch.float16,
    "float32": torch.float32,
    "bfloat16": torch.bfloat16,
}
def _resolve_device_dtype(device: str, dtype_name: str) -> Tuple[str, torch.dtype]:
    """Map a requested (device, dtype-name) pair to what this host can run.

    Falls back to CPU when CUDA is requested but unavailable, and promotes
    float16 to float32 on CPU, where fp16 compute is poorly supported.
    """
    dev = "cpu" if (device == "cuda" and not torch.cuda.is_available()) else device
    dt = DTYPES[dtype_name]
    if dev == "cpu" and dt == torch.float16:
        dt = torch.float32
    return dev, dt
def set_sdp_backend(backend: str) -> None:
    """Restrict torch's scaled-dot-product-attention kernels to one family.

    "auto" re-enables all three kernel families; any other accepted value
    enables exactly that family. No-op on machines without CUDA.

    Raises:
        ValueError: if `backend` is not one of auto/flash/mem_efficient/math.
    """
    backend = backend.lower()
    allowed = {"auto", "flash", "mem_efficient", "math"}
    if backend not in allowed:
        raise ValueError(f"attention_sdp_backend must be one of {sorted(allowed)}")
    if not torch.cuda.is_available():
        return
    enable_all = backend == "auto"
    torch.backends.cuda.enable_flash_sdp(enable_all or backend == "flash")
    torch.backends.cuda.enable_mem_efficient_sdp(enable_all or backend == "mem_efficient")
    torch.backends.cuda.enable_math_sdp(enable_all or backend == "math")
class Linear(torch.nn.Module):
    """Bias-free linear layer; weight is stored as (out_features, in_features)."""

    def __init__(self, in_features, out_features, device=None, dtype=None):
        super().__init__()
        weight = torch.empty(out_features, in_features, device=device, dtype=dtype)
        # NOTE(review): std here is 2/(fan_in+fan_out) with no sqrt — looks like
        # an unconventional init, but values are overwritten by the checkpoint
        # load, so it only affects untrained models. Confirm before training.
        std = 2 / (in_features + out_features)
        torch.nn.init.trunc_normal_(weight, mean=0.0, std=std, a=-3 * std, b=3 * std)
        self.weight = torch.nn.Parameter(weight)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x @ W^T over the trailing feature axis, keeping leading dims intact.
        return einsum(self.weight, x, "out_features in_features, ... in_features -> ... out_features")
class Embedding(torch.nn.Module):
    """Lookup table mapping integer ids to `embedding_dim`-sized vectors."""

    def __init__(self, num_embeddings, embedding_dim, device=None, dtype=None):
        super().__init__()
        self.num_embeddings = num_embeddings
        self.embedding_dim = embedding_dim
        table = torch.empty(num_embeddings, embedding_dim, device=device, dtype=dtype)
        torch.nn.init.trunc_normal_(table, mean=0, std=1, a=-3, b=3)
        self.weight = torch.nn.Parameter(table)

    def forward(self, token_ids: torch.Tensor) -> torch.Tensor:
        # Advanced indexing: output shape is token_ids.shape + (embedding_dim,).
        return self.weight[token_ids]
class RMSNorm(torch.nn.Module):
    """Root-mean-square layer norm: no mean subtraction, learned gain only."""

    def __init__(self, d_model: int, eps: float = 1e-5, device=None, dtype=None):
        super().__init__()
        self.eps = eps
        self.d_model = d_model
        self.weight = torch.nn.Parameter(torch.empty(d_model, device=device, dtype=dtype))
        torch.nn.init.ones_(self.weight)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Normalize in float32 for stability, then cast back to the input dtype.
        orig_dtype = x.dtype
        x32 = x.to(torch.float32)
        scale = 1 / torch.sqrt(x32.pow(2).mean(dim=-1, keepdim=True) + self.eps)
        return (scale * (x32 * self.weight)).to(orig_dtype)
class SwiGLU(torch.nn.Module):
    """SwiGLU feed-forward network: w2(silu(w1 x) * (w3 x))."""

    def __init__(self, d_model: int, d_ff: int, device=None, dtype=None):
        super().__init__()
        self.w1 = Linear(d_model, d_ff, device=device, dtype=dtype)
        self.w2 = Linear(d_ff, d_model, device=device, dtype=dtype)
        self.w3 = Linear(d_model, d_ff, device=device, dtype=dtype)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        gate = self.w1(x)
        gate = gate * torch.sigmoid(gate)  # SiLU, written out explicitly
        return self.w2(gate * self.w3(x))
def softmax(x: torch.Tensor, dim: int):
    """Numerically stable softmax along `dim` (subtracts the running max)."""
    shifted = x - x.max(dim=dim, keepdim=True).values
    weights = torch.exp(shifted)
    return weights / weights.sum(dim=dim, keepdim=True)
def top_p_filter(probs: torch.Tensor, p: float) -> torch.Tensor:
    """Nucleus (top-p) filtering over the last axis.

    Zeroes every entry outside the smallest prefix of the descending-sorted
    distribution whose cumulative mass reaches `p`, then renormalizes the
    survivors. p <= 0 degenerates to an argmax one-hot; p >= 1 just
    renormalizes. Raises ValueError for inputs with fewer than 2 dims.
    """
    if probs.dim() < 2:
        raise ValueError("probs must have at least 2 dimensions")
    shape = probs.shape
    flat = probs.reshape(-1, shape[-1])
    if p <= 0:
        # Degenerate nucleus: keep only the single most likely entry.
        one_hot = torch.zeros_like(flat)
        one_hot.scatter_(-1, flat.argmax(dim=-1, keepdim=True), 1.0)
        return one_hot.reshape(shape)
    if p >= 1:
        return (flat / flat.sum(dim=-1, keepdim=True)).reshape(shape)
    desc, order = torch.sort(flat, descending=True)
    cum = torch.cumsum(desc, dim=-1)
    keep = cum <= p
    keep[..., 0] = True  # the top entry always survives
    # Also keep the first entry that pushes the cumulative mass past p, so
    # the kept prefix has total mass >= p.
    boundary = (cum >= p).float().argmax(dim=-1)
    keep[torch.arange(keep.shape[0], device=keep.device), boundary] = True
    kept = torch.where(keep, desc, torch.zeros_like(desc))
    kept = kept / kept.sum(dim=-1, keepdim=True).clamp_min(1e-12)
    out = torch.zeros_like(flat)
    out.scatter_(dim=-1, index=order, src=kept)
    return out.reshape(shape)
def add_gumbel_noise(logits: torch.Tensor, temperature: float, *, generator: torch.Generator | None = None) -> torch.Tensor:
if temperature <= 0:
return logits
noise = torch.rand(logits.shape, device=logits.device, dtype=torch.float64, generator=generator)
gumbel_noise = (-torch.log(noise)) ** temperature
logits64 = logits.to(torch.float64)
perturbed = logits64.exp() / gumbel_noise
return perturbed.to(logits.dtype)
def compute_transfer_schedule(mask: torch.Tensor, steps: int) -> torch.Tensor:
    """Split each row's masked-token count into `steps` near-equal chunks.

    Returns an int64 (batch, steps) tensor whose row sums equal the row's
    mask count; the remainder is spread one-per-step over the earliest steps.

    Raises:
        ValueError: if steps <= 0 or mask is not 2D.
    """
    if steps <= 0:
        raise ValueError("steps must be > 0")
    if mask.dim() != 2:
        raise ValueError("mask must be 2D (batch, block_length)")
    counts = mask.sum(dim=1, keepdim=True).to(torch.int64)
    base, remainder = counts // steps, counts % steps
    # Vectorized remainder spreading: step j gets one extra while j < remainder.
    step_idx = torch.arange(steps, device=mask.device)
    return base.expand(-1, steps) + (step_idx.unsqueeze(0) < remainder).to(torch.int64)
def _prepare_attention_mask(attention_mask: torch.Tensor, ref_tensor: torch.Tensor) -> torch.Tensor:
mask = attention_mask.to(device=ref_tensor.device, dtype=torch.bool)
if mask.dim() == 2:
mask = mask[:, None, None, :]
elif mask.dim() == 3:
mask = mask[:, None, :, :]
elif mask.dim() != 4:
raise ValueError("attention_mask must be 2D, 3D, or 4D")
return mask
def scaled_dot_product_attention(
    Q: torch.Tensor,
    K: torch.Tensor,
    V: torch.Tensor,
    attention_mask: torch.Tensor | None = None,
):
    """Reference attention implementation: softmax(Q K^T / sqrt(d_k)) V.

    `attention_mask` is boolean; False positions are excluded by writing
    -inf into the pre-softmax scores.
    """
    sqrt_dk = torch.tensor(Q.shape[-1], device=Q.device, dtype=Q.dtype).sqrt()
    scores = einsum(Q, K, "batch_size ... n d_k, batch_size ... m d_k -> batch_size ... n m") / sqrt_dk
    if attention_mask is not None:
        bool_mask = _prepare_attention_mask(attention_mask, scores)
        scores = scores.masked_fill(~bool_mask, float("-inf"))
    weights = softmax(scores, dim=-1)
    return einsum(weights, V, "batch_size ... n m, batch_size ... m d_k -> batch_size ... n d_k")
def torch_scaled_dot_product_attention(
Q: torch.Tensor,
K: torch.Tensor,
V: torch.Tensor,
attention_mask: torch.Tensor | None = None,
):
Q = Q.contiguous()
K = K.contiguous()
V = V.contiguous()
mask = None
if attention_mask is not None:
mask = _prepare_attention_mask(attention_mask, Q)
return torch.nn.functional.scaled_dot_product_attention(Q, K, V, attn_mask=mask, dropout_p=0.0, is_causal=False)
class RotaryPositionalEmbedding(torch.nn.Module):
    """RoPE: rotates (even, odd) channel pairs by position-dependent angles.

    A (max_seq_len, d_k // 2, 2) cos/sin table is precomputed once at
    construction; `forward` gathers rows by token position.
    """

    def __init__(self, theta: float, d_k: int, max_seq_len: int, device=None):
        super().__init__()
        self.device = device
        freq_divisor = theta ** (torch.arange(0, d_k, 2).float() / d_k)
        angles = torch.arange(max_seq_len).unsqueeze(1) / freq_divisor.unsqueeze(0)
        table = torch.stack([torch.cos(angles), torch.sin(angles)], dim=-1).to(device=device)
        # Non-persistent: rebuilt from config, never stored in checkpoints.
        self.register_buffer("phases", table, persistent=False)

    def forward(self, x: torch.Tensor, token_positions: torch.Tensor) -> torch.Tensor:
        pairs = rearrange(x, "... (d_k p) -> ... d_k p", p=2)
        even, odd = pairs[..., 0], pairs[..., 1]
        cos = self.phases[..., 0][token_positions].to(dtype=x.dtype)
        sin = self.phases[..., 1][token_positions].to(dtype=x.dtype)
        # Standard 2D rotation applied to each channel pair.
        rotated = torch.stack(
            [even * cos - odd * sin, even * sin + odd * cos],
            dim=-1,
        )
        return rotated.flatten(-2)
class MultiheadSelfAttentionRoPE(torch.nn.Module):
    """Multi-head self-attention with rotary embeddings applied to Q and K.

    `attention_backend` selects between the reference einsum implementation
    ("custom") and torch's fused SDPA kernel ("torch_sdpa").
    """

    def __init__(
        self,
        d_model: int,
        num_heads: int,
        max_seq_len: int,
        theta: float,
        attention_backend: str = "custom",
        device=None,
        dtype=None,
    ):
        super().__init__()
        self.d_model = d_model
        self.num_heads = num_heads
        self.d_k = d_model // num_heads
        self.d_v = self.d_k
        self.max_seq_len = max_seq_len
        self.theta = theta
        if attention_backend not in {"custom", "torch_sdpa"}:
            raise ValueError("attention_backend must be one of ['custom', 'torch_sdpa']")
        self.attention_backend = attention_backend
        self.q_proj = Linear(d_model, d_model, device=device, dtype=dtype)
        self.k_proj = Linear(d_model, d_model, device=device, dtype=dtype)
        self.v_proj = Linear(d_model, d_model, device=device, dtype=dtype)
        self.output_proj = Linear(d_model, d_model, device=device, dtype=dtype)
        self.rope = RotaryPositionalEmbedding(theta, self.d_k, max_seq_len, device)

    def _split_heads(self, projected: torch.Tensor) -> torch.Tensor:
        # (..., seq, num_heads * d_k) -> (..., num_heads, seq, d_k)
        return rearrange(
            projected,
            "... seq_len (num_heads d_k) -> ... num_heads seq_len d_k",
            num_heads=self.num_heads,
            d_k=self.d_k,
        )

    def forward(
        self,
        x: torch.Tensor,
        token_positions: torch.Tensor,
        attention_mask: torch.Tensor | None = None,
    ) -> torch.Tensor:
        q = self.rope(self._split_heads(self.q_proj(x)), token_positions)
        k = self.rope(self._split_heads(self.k_proj(x)), token_positions)
        v = self._split_heads(self.v_proj(x))  # values are never rotated
        if self.attention_backend == "torch_sdpa":
            heads = torch_scaled_dot_product_attention(q, k, v, attention_mask=attention_mask)
        else:
            heads = scaled_dot_product_attention(q, k, v, attention_mask=attention_mask)
        merged = rearrange(
            heads,
            "... num_heads seq_len d_v -> ... seq_len (num_heads d_v)",
            num_heads=self.num_heads,
            d_v=self.d_v,
        )
        return self.output_proj(merged)
class MultiheadCrossAttentionRoPE(torch.nn.Module):
    """Multi-head cross-attention: queries come from `x`, keys/values from
    `context`. RoPE is applied to both streams using their own positions."""

    def __init__(
        self,
        d_model: int,
        num_heads: int,
        max_seq_len: int,
        theta: float,
        attention_backend: str = "custom",
        device=None,
        dtype=None,
    ):
        super().__init__()
        self.d_model = d_model
        self.num_heads = num_heads
        self.d_k = d_model // num_heads
        self.d_v = self.d_k
        self.max_seq_len = max_seq_len
        self.theta = theta
        if attention_backend not in {"custom", "torch_sdpa"}:
            raise ValueError("attention_backend must be one of ['custom', 'torch_sdpa']")
        self.attention_backend = attention_backend
        self.q_proj = Linear(d_model, d_model, device=device, dtype=dtype)
        self.k_proj = Linear(d_model, d_model, device=device, dtype=dtype)
        self.v_proj = Linear(d_model, d_model, device=device, dtype=dtype)
        self.output_proj = Linear(d_model, d_model, device=device, dtype=dtype)
        self.rope = RotaryPositionalEmbedding(theta, self.d_k, max_seq_len, device)

    def _split_heads(self, projected: torch.Tensor) -> torch.Tensor:
        # (..., seq, num_heads * d_k) -> (..., num_heads, seq, d_k)
        return rearrange(
            projected,
            "... seq_len (num_heads d_k) -> ... num_heads seq_len d_k",
            num_heads=self.num_heads,
            d_k=self.d_k,
        )

    def forward(
        self,
        x: torch.Tensor,
        context: torch.Tensor,
        token_positions: torch.Tensor,
        context_token_positions: torch.Tensor,
        attention_mask: torch.Tensor | None = None,
    ) -> torch.Tensor:
        q = self.rope(self._split_heads(self.q_proj(x)), token_positions)
        k = self.rope(self._split_heads(self.k_proj(context)), context_token_positions)
        v = self._split_heads(self.v_proj(context))
        if self.attention_backend == "torch_sdpa":
            heads = torch_scaled_dot_product_attention(q, k, v, attention_mask=attention_mask)
        else:
            heads = scaled_dot_product_attention(q, k, v, attention_mask=attention_mask)
        merged = rearrange(
            heads,
            "... num_heads seq_len d_v -> ... seq_len (num_heads d_v)",
            num_heads=self.num_heads,
            d_v=self.d_v,
        )
        return self.output_proj(merged)
class TransformerImageBlock(torch.nn.Module):
    """Pre-norm transformer block: self-attention, cross-attention to the
    conditioning context, then a SwiGLU feed-forward — each residual."""

    def __init__(
        self,
        d_model: int,
        num_heads: int,
        max_seq_len: int,
        theta: float,
        d_ff: int,
        attention_backend: str = "custom",
        device=None,
        dtype=None,
    ):
        super().__init__()
        self.ffn = SwiGLU(d_model, d_ff, device, dtype)
        self.self_attn = MultiheadSelfAttentionRoPE(
            d_model,
            num_heads,
            max_seq_len,
            theta,
            attention_backend=attention_backend,
            device=device,
            dtype=dtype,
        )
        self.cross_attn = MultiheadCrossAttentionRoPE(
            d_model,
            num_heads,
            max_seq_len,
            theta,
            attention_backend=attention_backend,
            device=device,
            dtype=dtype,
        )
        self.ln1 = RMSNorm(d_model, device=device, dtype=dtype)
        self.ln2 = RMSNorm(d_model, device=device, dtype=dtype)
        self.ln3 = RMSNorm(d_model, device=device, dtype=dtype)

    def forward(
        self,
        x: torch.Tensor,
        token_positions: torch.Tensor,
        context: torch.Tensor,
        context_token_positions: torch.Tensor,
        attention_mask: torch.Tensor | None = None,
    ) -> torch.Tensor:
        x = x + self.self_attn(self.ln1(x), token_positions, attention_mask=attention_mask)
        # Cross-attention to the conditioning context runs unmasked by design.
        x = x + self.cross_attn(
            self.ln2(x),
            context,
            token_positions,
            context_token_positions,
            attention_mask=None,
        )
        return x + self.ffn(self.ln3(x))
class TransformerImage(torch.nn.Module):
    """Label-conditioned bidirectional transformer over pixel-token sequences.

    `forward` maps token ids plus per-sample integer labels (`context`) to
    per-position logits over the token vocabulary.
    """

    def __init__(
        self,
        vocab_size: int,
        context_length: int,
        d_model: int,
        num_layers: int,
        num_heads: int,
        d_ff: int,
        rope_theta: float,
        label_vocab_size: int,
        attention_backend: str = "custom",
        device=None,
        dtype=None,
    ):
        super().__init__()
        self.context_length = context_length
        self.token_embeddings = Embedding(vocab_size, d_model, device, dtype)
        self.label_embeddings = Embedding(label_vocab_size, d_model, device, dtype)
        self.layers = torch.nn.ModuleList(
            TransformerImageBlock(
                d_model,
                num_heads,
                context_length,
                rope_theta,
                d_ff,
                attention_backend=attention_backend,
                device=device,
                dtype=dtype,
            )
            for _ in range(num_layers)
        )
        self.ln_final = RMSNorm(d_model, device=device, dtype=dtype)
        self.lm_head = Linear(d_model, vocab_size, device, dtype)

    def forward(
        self,
        in_indices: torch.Tensor,
        attention_mask: torch.Tensor | None = None,
        context: torch.Tensor | None = None,
    ) -> torch.Tensor:
        if context is None:
            raise ValueError("context must be provided for TransformerImage")
        hidden = self.token_embeddings(in_indices)
        # Label ids -> a length-1 context sequence per sample for cross-attn.
        label_seq = self.label_embeddings(context).unsqueeze(-2)
        positions = torch.arange(hidden.shape[-2], device=hidden.device, dtype=torch.long)
        context_positions = torch.arange(label_seq.shape[-2], device=hidden.device, dtype=torch.long)
        for block in self.layers:
            hidden = block(
                hidden,
                positions,
                label_seq,
                context_positions,
                attention_mask=attention_mask,
            )
        return self.lm_head(self.ln_final(hidden))
@torch.no_grad()
def image_diffusion_generate(
    model,
    prompt_indices: torch.Tensor,
    *,
    context: torch.Tensor,
    mask_id: int,
    eos_token_id: int | None = None,
    steps: int,
    gen_length: int,
    block_length: int,
    temperature: float = 0.0,
    top_p: float | None = None,
    cfg_scale: float = 0.0,
    uncond_context: torch.Tensor | None = None,
    remasking: str = "random",
    logits_eos_inf: bool = False,
    confidence_eos_eot_inf: bool = False,
    generator: torch.Generator | None = None,
) -> torch.Tensor:
    """Iterative masked-diffusion sampling of `gen_length` tokens after a prompt.

    The generated region starts fully masked (`mask_id`) and is revealed block
    by block: each block gets a share of `steps` refinement steps, and at every
    step a scheduled number of masked positions is committed to the model's
    prediction, chosen by `remasking` confidence ("random" or "low_confidence").

    Args:
        model: callable as model(x, context=...) -> (batch, seq, vocab) logits.
        prompt_indices: (batch, prompt_len) tokens kept fixed; may be length 0.
        context: (batch,) conditioning label ids.
        mask_id: token id marking not-yet-generated positions.
        eos_token_id: only used by the logits_eos_inf / confidence_eos_eot_inf flags.
        steps: total refinement steps, distributed across blocks; must be >= blocks.
        gen_length: number of tokens to generate; <= 0 returns the prompt as-is.
        block_length: tokens revealed per block (last block may be shorter).
        temperature: gumbel-noise temperature; <= 0 means greedy argmax.
        top_p: optional nucleus filtering applied to the logits before sampling.
        cfg_scale: classifier-free guidance strength; > 0 requires uncond_context.
        uncond_context: (batch,) null-label ids for the unconditional CFG pass.
        remasking: "random" or "low_confidence" position-selection strategy.
        generator: optional RNG for reproducible sampling.

    Returns:
        (batch, prompt_len + gen_length) tensor of token ids.
    """
    if prompt_indices.dim() != 2:
        raise ValueError("prompt_indices must be 2D (batch, seq)")
    if context.dim() != 1:
        raise ValueError("context must be 1D (batch,)")
    if prompt_indices.shape[0] != context.shape[0]:
        raise ValueError("context batch size must match prompt batch size")
    if block_length <= 0:
        raise ValueError("block_length must be > 0")
    if steps <= 0:
        raise ValueError("steps must be > 0")
    if gen_length <= 0:
        return prompt_indices
    blocks = max(1, int(np.ceil(gen_length / block_length)))
    if steps < blocks:
        raise ValueError("steps must be >= number of blocks")
    # Each block gets base_steps refinements; the first extra_steps blocks get one more.
    base_steps = steps // blocks
    extra_steps = steps % blocks
    device = prompt_indices.device
    batch_size, prompt_len = prompt_indices.shape
    total_len = prompt_len + gen_length
    context_limit = getattr(model, "context_length", None)
    if context_limit is not None and total_len > int(context_limit):
        raise ValueError("prompt length + gen_length exceeds model context_length")
    # Working canvas: prompt tokens followed by an all-masked generation region.
    x = torch.full(
        (batch_size, total_len),
        fill_value=mask_id,
        device=device,
        dtype=prompt_indices.dtype,
    )
    x[:, :prompt_len] = prompt_indices
    if uncond_context is not None:
        if uncond_context.dim() != 1:
            raise ValueError("uncond_context must be 1D (batch,)")
        if uncond_context.shape[0] != batch_size:
            raise ValueError("uncond_context batch size must match prompt batch size")
        uncond_context = uncond_context.to(device=context.device, dtype=context.dtype)
    for block_idx in range(blocks):
        block_start = prompt_len + block_idx * block_length
        block_end = min(block_start + block_length, total_len)
        block_steps = base_steps + (1 if block_idx < extra_steps else 0)
        if block_steps <= 0:
            block_steps = 1
        # How many positions in this block to commit at each step (per row).
        block_mask = (x[:, block_start:block_end] == mask_id)
        transfer_counts = compute_transfer_schedule(block_mask, block_steps)
        for step_idx in range(block_steps):
            # Snapshot of still-masked positions BEFORE this step's update.
            mask_index = (x == mask_id)
            if cfg_scale > 0.0:
                if uncond_context is None:
                    raise ValueError("uncond_context must be set when cfg_scale > 0 for image_diffusion_generate")
                # Classifier-free guidance: extrapolate conditional away from
                # unconditional logits by (cfg_scale + 1).
                cond_logits = model(x, context=context)
                uncond_logits = model(x, context=uncond_context)
                logits = uncond_logits + (cfg_scale + 1.0) * (cond_logits - uncond_logits)
            else:
                logits = model(x, context=context)
            if logits_eos_inf and eos_token_id is not None:
                logits[:, :, eos_token_id] = float("-inf")
            if top_p is not None:
                # Nucleus filter: ban every token whose filtered probability is 0.
                probs = softmax(logits, dim=-1)
                probs = top_p_filter(probs, float(top_p))
                logits = torch.where(
                    probs > 0,
                    logits,
                    torch.full_like(logits, float("-inf")),
                )
            # Gumbel trick: argmax over perturbed scores == temperature sampling.
            logits_with_noise = add_gumbel_noise(logits, temperature, generator=generator)
            predictions = torch.argmax(logits_with_noise, dim=-1)
            # Already-revealed positions keep their committed tokens.
            predictions = torch.where(mask_index, predictions, x)
            if remasking == "low_confidence":
                # Confidence = model probability of the predicted token.
                probs = softmax(logits, dim=-1)
                confidence = torch.squeeze(
                    torch.gather(probs, dim=-1, index=torch.unsqueeze(predictions, -1)),
                    -1,
                )
            elif remasking == "random":
                confidence = torch.rand(
                    (batch_size, total_len),
                    device=device,
                    dtype=torch.float32,
                    generator=generator,
                )
            else:
                raise ValueError(f"Unsupported remasking strategy: {remasking}")
            if confidence_eos_eot_inf and eos_token_id is not None:
                confidence = torch.where(
                    predictions == eos_token_id,
                    torch.full_like(confidence, float("-inf")),
                    confidence,
                )
            # Only still-masked positions up to the current block's end are
            # eligible to be committed this step.
            confidence[:, block_end:] = float("-inf")
            confidence = torch.where(mask_index, confidence, torch.full_like(confidence, float("-inf")))
            transfer_mask = torch.zeros_like(mask_index)
            for b in range(batch_size):
                k = int(transfer_counts[b, step_idx].item())
                if k <= 0:
                    continue
                available = confidence[b] > float("-inf")
                available_count = int(available.sum().item())
                if available_count == 0:
                    continue
                if available_count < k:
                    # Never ask topk for more positions than remain eligible.
                    k = available_count
                topk_indices = torch.topk(confidence[b], k=k, dim=-1).indices
                transfer_mask[b, topk_indices] = True
            # Commit the selected predictions; everything else stays masked.
            x = torch.where(transfer_mask, predictions, x)
    return x
def dequantize_tokens_to_uint8(tokens: np.ndarray, *, pixel_bins: int) -> np.ndarray:
    """Map bin indices in [0, pixel_bins) back to uint8 pixel values.

    Each bin index is mapped to (approximately) its bin center in [0, 255];
    with pixel_bins == 256 the tokens already are pixel values. Out-of-range
    indices are clipped into the valid bin range first.
    """
    if pixel_bins == 256:
        return tokens.astype(np.uint8)
    clipped = np.clip(tokens.astype(np.int32), 0, int(pixel_bins) - 1)
    bin_width = 256.0 / float(pixel_bins)
    centers = np.round((clipped + 0.5) * bin_width - 0.5)
    return np.clip(centers, 0, 255).astype(np.uint8)
# Lazily-initialized process-wide singletons populated by load_model().
MODEL = None  # cached TransformerImage instance
DEVICE = None  # resolved device string, e.g. "cuda" or "cpu"
DTYPE = None  # resolved torch dtype
def load_model():
    """Build the TransformerImage, load its safetensors checkpoint, cache it.

    Returns:
        (model, device, dtype); repeat calls return the cached triple.

    Raises:
        FileNotFoundError: if the checkpoint file is missing.
    """
    global MODEL, DEVICE, DTYPE
    if MODEL is not None:
        return MODEL, DEVICE, DTYPE
    if not os.path.exists(CHECKPOINT_PATH):
        raise FileNotFoundError(f"Missing checkpoint at {CHECKPOINT_PATH}")
    device, dtype = _resolve_device_dtype(MODEL_CONFIG["device"], MODEL_CONFIG["dtype"])
    set_sdp_backend(MODEL_CONFIG["attention_sdp_backend"])
    cfg = MODEL_CONFIG
    model = TransformerImage(
        vocab_size=cfg["vocab_size"],
        context_length=cfg["context_length"],
        d_model=cfg["d_model"],
        num_layers=cfg["num_layers"],
        num_heads=cfg["num_heads"],
        d_ff=cfg["d_ff"],
        rope_theta=cfg["rope_theta"],
        label_vocab_size=cfg["label_vocab_size"],
        attention_backend=cfg["attention_backend"],
        device=device,
        dtype=dtype,
    )
    model.load_state_dict(load_file(CHECKPOINT_PATH))
    model.eval().to(device)
    MODEL, DEVICE, DTYPE = model, device, dtype
    return MODEL, DEVICE, DTYPE
@torch.inference_mode()
def generate_images(label: int, steps: int, num_samples: int) -> List[Image.Image]:
    """Sample `num_samples` digit images for `label` and return PIL images.

    Runs the masked-diffusion sampler over the full 28x28 token grid with the
    defaults in INFER_CONFIG, then dequantizes each sample to grayscale and
    upscales it 10x with nearest-neighbor so the pixel grid stays crisp.
    """
    model, device, _ = load_model()
    num_samples, label, steps = int(num_samples), int(label), int(steps)
    context = torch.full((num_samples,), label, device=device, dtype=torch.long)
    empty_prompt = torch.empty((num_samples, 0), device=device, dtype=torch.long)
    cfg_scale = float(INFER_CONFIG["cfg_scale"])
    # Classifier-free guidance needs an unconditional ("null label") pass.
    uncond_context = None
    if cfg_scale > 0.0:
        null_label = int(MODEL_CONFIG["null_label_id"])
        uncond_context = torch.full((num_samples,), null_label, device=device, dtype=torch.long)
    out_indices = image_diffusion_generate(
        model,
        empty_prompt,
        context=context,
        mask_id=int(MODEL_CONFIG["mask_token_id"]),
        eos_token_id=None,
        steps=steps,
        gen_length=int(MODEL_CONFIG["context_length"]),
        block_length=int(INFER_CONFIG["block_length"]),
        temperature=float(INFER_CONFIG["temperature"]),
        top_p=float(INFER_CONFIG["top_p"]),
        cfg_scale=cfg_scale,
        uncond_context=uncond_context,
        remasking=str(INFER_CONFIG["remasking"]),
        logits_eos_inf=False,
        confidence_eos_eot_inf=False,
        generator=None,
    )
    h, w = int(MODEL_CONFIG["image_height"]), int(MODEL_CONFIG["image_width"])
    pixel_bins = int(MODEL_CONFIG["pixel_bins"])
    upscale = 10
    images: List[Image.Image] = []
    for row in out_indices:
        tokens = row.detach().cpu().to(torch.int32).numpy().reshape(h, w)
        img = Image.fromarray(dequantize_tokens_to_uint8(tokens, pixel_bins=pixel_bins), mode="L")
        if upscale > 1:
            img = img.resize((w * upscale, h * upscale), resample=Image.NEAREST)
        images.append(img)
    return images
def _grid_dims(num_samples: int) -> Tup[int, int]:
cols = int(np.ceil(np.sqrt(num_samples)))
rows = int(np.ceil(num_samples / cols))
return rows, cols
@torch.inference_mode()
def generate_grid_image(label: int, steps: int, num_samples: int) -> Image.Image:
    """Sample digits for `label` and paste them into one near-square grid image."""
    tiles = generate_images(label=label, steps=steps, num_samples=num_samples)
    if not tiles:
        # Degenerate 1x1 black placeholder so callers always get a valid image.
        return Image.new("L", (1, 1), color=0)
    rows, cols = _grid_dims(len(tiles))
    tile_w, tile_h = tiles[0].size
    canvas = Image.new("L", (cols * tile_w, rows * tile_h))
    for idx, tile in enumerate(tiles):
        canvas.paste(tile, ((idx % cols) * tile_w, (idx // cols) * tile_h))
    return canvas