Spaces:
Running
on
Zero
Running
on
Zero
| from __future__ import annotations | |
| import os | |
| from typing import Tuple, List, Tuple as Tup | |
| import numpy as np | |
| import torch | |
| from PIL import Image | |
| from safetensors.torch import load_file | |
| from einops import einsum, rearrange | |
# Paths are anchored at this file's directory rather than the working directory.
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
CHECKPOINT_PATH = os.path.join(BASE_DIR, "model", "model.safetensors")

# Architecture and tokenization settings. Parameter shapes derived from these
# must agree with the checkpoint, since load_model() uses a strict load_state_dict.
MODEL_CONFIG = dict(
    model_type="image",
    label_vocab_size=11,  # label ids 0..10; id 10 is used as the null (unconditional) label
    vocab_size=33,  # pixel tokens plus the mask token
    pixel_bins=32,  # pixel intensities are quantized into this many bins
    context_length=784,  # 28 * 28 tokens, one per pixel
    d_model=256,
    num_layers=8,
    num_heads=16,
    d_ff=1024,
    rope_theta=10000.0,
    attention_backend="torch_sdpa",
    attention_sdp_backend="auto",
    device="cuda",  # falls back to CPU when CUDA is unavailable (_resolve_device_dtype)
    dtype="float16",
    mask_token_id=32,  # last id of the 33-token vocabulary
    null_label_id=10,
    image_height=28,
    image_width=28,
)

# Default sampling settings passed to image_diffusion_generate.
INFER_CONFIG = dict(
    block_length=784,
    temperature=0.6,
    top_p=0.99,
    cfg_scale=2.0,
    remasking="random",
)

# Config dtype name -> torch dtype.
DTYPES = dict(
    float16=torch.float16,
    float32=torch.float32,
    bfloat16=torch.bfloat16,
)
def _resolve_device_dtype(device: str, dtype_name: str) -> Tuple[str, torch.dtype]:
    """Turn the configured device/dtype names into what will actually be used.

    A request for "cuda" becomes "cpu" when CUDA is unavailable, and float16 is
    replaced by float32 whenever the resolved device is "cpu".

    Raises:
        KeyError: if ``dtype_name`` is not a key of ``DTYPES``.
    """
    cuda_missing = device == "cuda" and not torch.cuda.is_available()
    target_device = "cpu" if cuda_missing else device
    target_dtype = DTYPES[dtype_name]
    if target_device == "cpu" and target_dtype == torch.float16:
        target_dtype = torch.float32
    return target_device, target_dtype
def set_sdp_backend(backend: str) -> None:
    """Choose which CUDA scaled-dot-product-attention kernels PyTorch may use.

    ``backend`` is matched case-insensitively. "auto" enables the flash,
    memory-efficient and math kernels; any other allowed name enables only that
    kernel and disables the other two. After validation, nothing is changed when
    CUDA is not available.

    Raises:
        ValueError: if ``backend`` is not one of the allowed names.
    """
    # name -> (flash, mem_efficient, math)
    flags_by_backend = {
        "auto": (True, True, True),
        "flash": (True, False, False),
        "mem_efficient": (False, True, False),
        "math": (False, False, True),
    }
    choice = backend.lower()
    if choice not in flags_by_backend:
        raise ValueError(f"attention_sdp_backend must be one of {sorted(flags_by_backend)}")
    if not torch.cuda.is_available():
        return
    use_flash, use_mem_efficient, use_math = flags_by_backend[choice]
    torch.backends.cuda.enable_flash_sdp(use_flash)
    torch.backends.cuda.enable_mem_efficient_sdp(use_mem_efficient)
    torch.backends.cuda.enable_math_sdp(use_math)
class Linear(torch.nn.Module):
    """Bias-free linear layer computing ``y = x @ W^T`` over the last dimension.

    The weight has shape ``(out_features, in_features)``. It is initialized from
    a normal distribution with variance ``2 / (in_features + out_features)``,
    truncated at +/- 3 standard deviations. The init only matters when weights
    are not subsequently loaded from a checkpoint.
    """

    def __init__(self, in_features, out_features, device=None, dtype=None):
        super().__init__()
        self.weight = torch.nn.Parameter(torch.empty(out_features, in_features, device=device, dtype=dtype))
        mean = 0.0
        # 2 / (fan_in + fan_out) is the *variance*; the std is its square root.
        # (The previous code used the variance as the std, giving ~10x too small weights.)
        std = (2 / (in_features + out_features)) ** 0.5
        a = mean - 3 * std
        b = mean + 3 * std
        torch.nn.init.trunc_normal_(self.weight, mean=mean, std=std, a=a, b=b)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Contract the last dimension of ``x`` (size ``in_features``) with the weight."""
        y = einsum(self.weight, x, "out_features in_features, ... in_features -> ... out_features")
        return y
class Embedding(torch.nn.Module):
    """Lookup table mapping integer ids to rows of a learned weight matrix.

    The weight has shape ``(num_embeddings, embedding_dim)`` and is initialized
    from a standard normal truncated to [-3, 3].
    """

    def __init__(self, num_embeddings, embedding_dim, device=None, dtype=None):
        super().__init__()
        self.num_embeddings = num_embeddings
        self.embedding_dim = embedding_dim
        table = torch.empty(num_embeddings, embedding_dim, device=device, dtype=dtype)
        self.weight = torch.nn.Parameter(table)
        torch.nn.init.trunc_normal_(self.weight, mean=0, std=1, a=-3, b=3)

    def forward(self, token_ids: torch.Tensor) -> torch.Tensor:
        """Gather one row per id; the output shape is ``token_ids.shape + (embedding_dim,)``."""
        return self.weight[token_ids]
class RMSNorm(torch.nn.Module):
    """Root-mean-square normalization over the last dimension with a learned gain.

    The computation runs in float32 and the result is cast back to the input dtype.
    """

    def __init__(self, d_model: int, eps: float = 1e-5, device=None, dtype=None):
        super().__init__()
        self.eps = eps
        self.d_model = d_model
        # Gain starts at 1 so the layer initially only rescales.
        self.weight = torch.nn.Parameter(torch.ones(d_model, device=device, dtype=dtype))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        orig_dtype = x.dtype
        x32 = x.to(torch.float32)
        mean_sq = x32.pow(2).mean(dim=-1, keepdim=True)
        rms = torch.sqrt(mean_sq + self.eps)
        normalized = (1 / rms) * (x32 * self.weight)
        return normalized.to(orig_dtype)
class SwiGLU(torch.nn.Module):
    """Gated feed-forward block: ``w2(silu(w1 x) * w3 x)``.

    ``w1`` and ``w3`` project ``d_model -> d_ff``; ``w2`` projects back to ``d_model``.
    """

    def __init__(self, d_model: int, d_ff: int, device=None, dtype=None):
        super().__init__()
        # Construction order (w1, w2, w3) is kept so random initialization is unchanged.
        self.w1 = Linear(d_model, d_ff, device=device, dtype=dtype)
        self.w2 = Linear(d_ff, d_model, device=device, dtype=dtype)
        self.w3 = Linear(d_model, d_ff, device=device, dtype=dtype)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        gate = self.w1(x)
        value = self.w3(x)
        # silu(g) = g * sigmoid(g), then gated by the w3 branch.
        return self.w2(gate * torch.sigmoid(gate) * value)
def softmax(x: torch.Tensor, dim: int):
    """Numerically stable softmax of ``x`` along ``dim``.

    The maximum of each slice is subtracted before exponentiating, so large
    inputs do not overflow. A slice that is entirely -inf produces NaN.
    """
    shifted = x - x.amax(dim=dim, keepdim=True)
    numer = shifted.exp()
    return numer / numer.sum(dim=dim, keepdim=True)
def top_p_filter(probs: torch.Tensor, p: float) -> torch.Tensor:
    """Nucleus (top-p) filtering of probability distributions over the last dimension.

    Tokens are ranked by probability. A token is kept if the cumulative mass up
    to and including it is <= ``p``. The most likely token and the first token
    at which the cumulative mass reaches ``p`` are always kept. Survivors are
    renormalized and everything else becomes zero.

    Special cases: ``p <= 0`` returns a one-hot at the argmax; ``p >= 1``
    returns ``probs`` renormalized.

    Raises:
        ValueError: if ``probs`` has fewer than 2 dimensions.
    """
    if probs.dim() < 2:
        raise ValueError("probs must have at least 2 dimensions")
    shape = probs.shape
    flat = probs.reshape(-1, shape[-1])

    if p <= 0:
        one_hot = torch.zeros_like(flat)
        one_hot.scatter_(-1, flat.argmax(dim=-1, keepdim=True), 1.0)
        return one_hot.reshape(shape)
    if p >= 1:
        return (flat / flat.sum(dim=-1, keepdim=True)).reshape(shape)

    ranked, order = torch.sort(flat, descending=True)
    running = ranked.cumsum(dim=-1)
    keep = running <= p
    keep[:, 0] = True
    # argmax over a 0/1 mask returns the first position where the mass reaches p.
    crossing = (running >= p).float().argmax(dim=-1)
    keep[torch.arange(keep.shape[0], device=keep.device), crossing] = True
    kept = torch.where(keep, ranked, torch.zeros_like(ranked))
    kept = kept / kept.sum(dim=-1, keepdim=True).clamp_min(1e-12)
    # Undo the sort by scattering back into the original token order.
    return torch.zeros_like(flat).scatter_(-1, order, kept).reshape(shape)
def add_gumbel_noise(logits: torch.Tensor, temperature: float, *, generator: torch.Generator | None = None) -> torch.Tensor:
    """Perturb logits so that ``argmax`` of the result samples from ``softmax(logits / temperature)``.

    Computes ``exp(logits) / E**temperature`` with ``E = -log(U)``, ``U ~ Uniform[0, 1)``.
    This is the Gumbel-max trick written multiplicatively. The result is only
    meaningful for ``argmax``; its magnitudes are not logits.

    Args:
        logits: scores over the last dimension.
        temperature: noise strength; ``<= 0`` returns ``logits`` unchanged (greedy).
        generator: optional RNG for reproducible noise.

    Returns:
        A float64 tensor shaped like ``logits``, or ``logits`` itself when
        ``temperature <= 0``.
    """
    if temperature <= 0:
        return logits
    noise = torch.rand(logits.shape, device=logits.device, dtype=torch.float64, generator=generator)
    gumbel_noise = (-torch.log(noise)) ** temperature
    perturbed = logits.to(torch.float64).exp() / gumbel_noise
    # Deliberately stay in float64. Casting back to float16 (max ~65504) turns
    # exp(logit) into inf for any logit above ~11; several infs tie in argmax,
    # which always picks the first one and biases sampling.
    return perturbed
def compute_transfer_schedule(mask: torch.Tensor, steps: int) -> torch.Tensor:
    """Split each row's count of True entries in ``mask`` as evenly as possible over ``steps``.

    Row ``b`` with ``n`` True entries gets ``n // steps`` per step. The first
    ``n % steps`` steps get one extra, so each row sums to ``n``.

    Args:
        mask: 2-D boolean tensor ``(batch, block_length)``.
        steps: number of steps (> 0).

    Returns:
        int64 tensor ``(batch, steps)`` on ``mask``'s device.

    Raises:
        ValueError: if ``steps <= 0`` or ``mask`` is not 2-D.
    """
    if steps <= 0:
        raise ValueError("steps must be > 0")
    if mask.dim() != 2:
        raise ValueError("mask must be 2D (batch, block_length)")
    counts = mask.sum(dim=1, keepdim=True).to(torch.int64)
    base = counts // steps
    remainder = counts % steps
    # Vectorized replacement for a per-row Python loop with .item() (one host sync
    # per row): step j gets an extra unit exactly when j < remainder.
    step_ids = torch.arange(steps, device=mask.device).unsqueeze(0)
    return base + (step_ids < remainder).to(torch.int64)
| def _prepare_attention_mask(attention_mask: torch.Tensor, ref_tensor: torch.Tensor) -> torch.Tensor: | |
| mask = attention_mask.to(device=ref_tensor.device, dtype=torch.bool) | |
| if mask.dim() == 2: | |
| mask = mask[:, None, None, :] | |
| elif mask.dim() == 3: | |
| mask = mask[:, None, :, :] | |
| elif mask.dim() != 4: | |
| raise ValueError("attention_mask must be 2D, 3D, or 4D") | |
| return mask | |
def scaled_dot_product_attention(
    Q: torch.Tensor,
    K: torch.Tensor,
    V: torch.Tensor,
    attention_mask: torch.Tensor | None = None,
):
    """Reference attention: ``softmax(Q K^T / sqrt(d_k)) V`` built from einsum.

    When ``attention_mask`` is given, positions where it is False get a score of
    -inf before the softmax. The mask is expanded by ``_prepare_attention_mask``.
    """
    d_k = Q.shape[-1]
    scores = einsum(Q, K, "batch_size ... n d_k, batch_size ... m d_k -> batch_size ... n m")
    scores = scores / torch.tensor(d_k, device=Q.device, dtype=Q.dtype).sqrt()
    if attention_mask is not None:
        allowed = _prepare_attention_mask(attention_mask, scores)
        scores = scores.masked_fill(~allowed, float("-inf"))
    weights = softmax(scores, dim=-1)
    return einsum(weights, V, "batch_size ... n m, batch_size ... m d_k -> batch_size ... n d_k")
def torch_scaled_dot_product_attention(
    Q: torch.Tensor,
    K: torch.Tensor,
    V: torch.Tensor,
    attention_mask: torch.Tensor | None = None,
):
    """Attention through ``torch.nn.functional.scaled_dot_product_attention``.

    Inputs are made contiguous first. A boolean mask (True = attend), if given,
    is expanded by ``_prepare_attention_mask``. Dropout is off and no causal
    mask is applied.
    """
    q, k, v = (t.contiguous() for t in (Q, K, V))
    attn_mask = None if attention_mask is None else _prepare_attention_mask(attention_mask, q)
    return torch.nn.functional.scaled_dot_product_attention(
        q, k, v, attn_mask=attn_mask, dropout_p=0.0, is_causal=False
    )
class RotaryPositionalEmbedding(torch.nn.Module):
    """Rotary position embedding applied to interleaved pairs of the last dimension.

    Feature pair ``(2i, 2i+1)`` at position ``t`` is rotated by the angle
    ``t / theta**(2i / d_k)``. A ``(max_seq_len, d_k // 2, 2)`` table of
    (cos, sin) values is precomputed and stored as a non-persistent buffer, so it
    is not part of the state dict.
    """

    def __init__(self, theta: float, d_k: int, max_seq_len: int, device=None):
        super().__init__()
        self.device = device
        freq_denoms = theta ** (torch.arange(0, d_k, 2).float() / d_k)
        angles = torch.arange(max_seq_len).unsqueeze(1) / freq_denoms.unsqueeze(0)
        table = torch.stack([torch.cos(angles), torch.sin(angles)], dim=-1)
        self.register_buffer("phases", table.to(device=device), persistent=False)

    def forward(self, x: torch.Tensor, token_positions: torch.Tensor) -> torch.Tensor:
        """Rotate ``x`` (last dim ``d_k``) using the angles for ``token_positions``."""
        even = x[..., 0::2]
        odd = x[..., 1::2]
        table = self.phases[token_positions].to(dtype=x.dtype)
        cos = table[..., 0]
        sin = table[..., 1]
        rotated = torch.stack((even * cos - odd * sin, even * sin + odd * cos), dim=-1)
        # Re-interleave the pairs back into the original feature layout.
        return rotated.flatten(-2)
class MultiheadSelfAttentionRoPE(torch.nn.Module):
    """Multi-head self-attention with rotary embeddings on queries and keys.

    ``d_model`` is split into ``num_heads`` heads of size ``d_model // num_heads``.
    ``attention_backend`` selects the kernel. "custom" uses the einsum-based
    ``scaled_dot_product_attention``; "torch_sdpa" uses
    ``torch_scaled_dot_product_attention``.
    """

    def __init__(
        self,
        d_model: int,
        num_heads: int,
        max_seq_len: int,
        theta: float,
        attention_backend: str = "custom",
        device=None,
        dtype=None,
    ):
        super().__init__()
        if attention_backend not in {"custom", "torch_sdpa"}:
            raise ValueError("attention_backend must be one of ['custom', 'torch_sdpa']")
        head_dim = d_model // num_heads
        self.d_model = d_model
        self.num_heads = num_heads
        self.d_k = head_dim
        self.d_v = head_dim
        self.max_seq_len = max_seq_len
        self.theta = theta
        self.attention_backend = attention_backend
        # Submodule names are checkpoint keys; construction order fixes random init order.
        self.q_proj = Linear(d_model, d_model, device=device, dtype=dtype)
        self.k_proj = Linear(d_model, d_model, device=device, dtype=dtype)
        self.v_proj = Linear(d_model, d_model, device=device, dtype=dtype)
        self.output_proj = Linear(d_model, d_model, device=device, dtype=dtype)
        self.rope = RotaryPositionalEmbedding(theta, head_dim, max_seq_len, device)

    def forward(
        self,
        x: torch.Tensor,
        token_positions: torch.Tensor,
        attention_mask: torch.Tensor | None = None,
    ) -> torch.Tensor:
        q_heads = rearrange(
            self.q_proj(x),
            "... seq_len (num_heads d_k) -> ... num_heads seq_len d_k",
            num_heads=self.num_heads,
            d_k=self.d_k,
        )
        k_heads = rearrange(
            self.k_proj(x),
            "... seq_len (num_heads d_k) -> ... num_heads seq_len d_k",
            num_heads=self.num_heads,
            d_k=self.d_k,
        )
        v_heads = rearrange(
            self.v_proj(x),
            "... seq_len (num_heads d_v) -> ... num_heads seq_len d_v",
            num_heads=self.num_heads,
            d_v=self.d_v,
        )
        q_heads = self.rope(q_heads, token_positions)
        k_heads = self.rope(k_heads, token_positions)
        attend = torch_scaled_dot_product_attention if self.attention_backend == "torch_sdpa" else scaled_dot_product_attention
        per_head = attend(q_heads, k_heads, v_heads, attention_mask=attention_mask)
        merged = rearrange(
            per_head,
            "... num_heads seq_len d_v -> ... seq_len (num_heads d_v)",
            num_heads=self.num_heads,
            d_v=self.d_v,
        )
        return self.output_proj(merged)
class MultiheadCrossAttentionRoPE(torch.nn.Module):
    """Multi-head attention from ``x`` (queries) to ``context`` (keys/values) with RoPE.

    Queries are rotated with ``token_positions`` and keys with
    ``context_token_positions``. Backend selection and head layout match
    ``MultiheadSelfAttentionRoPE``.
    """

    def __init__(
        self,
        d_model: int,
        num_heads: int,
        max_seq_len: int,
        theta: float,
        attention_backend: str = "custom",
        device=None,
        dtype=None,
    ):
        super().__init__()
        if attention_backend not in {"custom", "torch_sdpa"}:
            raise ValueError("attention_backend must be one of ['custom', 'torch_sdpa']")
        head_dim = d_model // num_heads
        self.d_model = d_model
        self.num_heads = num_heads
        self.d_k = head_dim
        self.d_v = head_dim
        self.max_seq_len = max_seq_len
        self.theta = theta
        self.attention_backend = attention_backend
        # Submodule names are checkpoint keys; construction order fixes random init order.
        self.q_proj = Linear(d_model, d_model, device=device, dtype=dtype)
        self.k_proj = Linear(d_model, d_model, device=device, dtype=dtype)
        self.v_proj = Linear(d_model, d_model, device=device, dtype=dtype)
        self.output_proj = Linear(d_model, d_model, device=device, dtype=dtype)
        self.rope = RotaryPositionalEmbedding(theta, head_dim, max_seq_len, device)

    def forward(
        self,
        x: torch.Tensor,
        context: torch.Tensor,
        token_positions: torch.Tensor,
        context_token_positions: torch.Tensor,
        attention_mask: torch.Tensor | None = None,
    ) -> torch.Tensor:
        q_heads = rearrange(
            self.q_proj(x),
            "... seq_len (num_heads d_k) -> ... num_heads seq_len d_k",
            num_heads=self.num_heads,
            d_k=self.d_k,
        )
        k_heads = rearrange(
            self.k_proj(context),
            "... seq_len (num_heads d_k) -> ... num_heads seq_len d_k",
            num_heads=self.num_heads,
            d_k=self.d_k,
        )
        v_heads = rearrange(
            self.v_proj(context),
            "... seq_len (num_heads d_v) -> ... num_heads seq_len d_v",
            num_heads=self.num_heads,
            d_v=self.d_v,
        )
        q_heads = self.rope(q_heads, token_positions)
        k_heads = self.rope(k_heads, context_token_positions)
        attend = torch_scaled_dot_product_attention if self.attention_backend == "torch_sdpa" else scaled_dot_product_attention
        per_head = attend(q_heads, k_heads, v_heads, attention_mask=attention_mask)
        merged = rearrange(
            per_head,
            "... num_heads seq_len d_v -> ... seq_len (num_heads d_v)",
            num_heads=self.num_heads,
            d_v=self.d_v,
        )
        return self.output_proj(merged)
class TransformerImageBlock(torch.nn.Module):
    """Pre-norm transformer block with three residual sub-layers.

    The sub-layers run in order: self-attention (using ``attention_mask``),
    cross-attention to ``context`` (unmasked), and a SwiGLU feed-forward. Each
    reads an RMSNorm of the running hidden state.
    """

    def __init__(
        self,
        d_model: int,
        num_heads: int,
        max_seq_len: int,
        theta: float,
        d_ff: int,
        attention_backend: str = "custom",
        device=None,
        dtype=None,
    ):
        super().__init__()
        # Order of construction is kept (ffn first) so random initialization is unchanged.
        self.ffn = SwiGLU(d_model, d_ff, device, dtype)
        attn_kwargs = dict(attention_backend=attention_backend, device=device, dtype=dtype)
        self.self_attn = MultiheadSelfAttentionRoPE(d_model, num_heads, max_seq_len, theta, **attn_kwargs)
        self.cross_attn = MultiheadCrossAttentionRoPE(d_model, num_heads, max_seq_len, theta, **attn_kwargs)
        self.ln1 = RMSNorm(d_model, device=device, dtype=dtype)
        self.ln2 = RMSNorm(d_model, device=device, dtype=dtype)
        self.ln3 = RMSNorm(d_model, device=device, dtype=dtype)

    def forward(
        self,
        x: torch.Tensor,
        token_positions: torch.Tensor,
        context: torch.Tensor,
        context_token_positions: torch.Tensor,
        attention_mask: torch.Tensor | None = None,
    ) -> torch.Tensor:
        hidden = x + self.self_attn(self.ln1(x), token_positions, attention_mask=attention_mask)
        hidden = hidden + self.cross_attn(
            self.ln2(hidden),
            context,
            token_positions,
            context_token_positions,
            attention_mask=None,
        )
        return hidden + self.ffn(self.ln3(hidden))
class TransformerImage(torch.nn.Module):
    """Label-conditioned transformer over image token sequences.

    Token ids are embedded and passed through ``num_layers`` ``TransformerImageBlock``
    layers. Each layer cross-attends to a single embedded label. A final RMSNorm
    and a linear head then produce per-position logits over ``vocab_size``.
    """

    def __init__(
        self,
        vocab_size: int,
        context_length: int,
        d_model: int,
        num_layers: int,
        num_heads: int,
        d_ff: int,
        rope_theta: float,
        label_vocab_size: int,
        attention_backend: str = "custom",
        device=None,
        dtype=None,
    ):
        super().__init__()
        self.context_length = context_length
        # Submodule names are checkpoint keys; construction order fixes random init order.
        self.token_embeddings = Embedding(vocab_size, d_model, device, dtype)
        self.label_embeddings = Embedding(label_vocab_size, d_model, device, dtype)
        self.layers = torch.nn.ModuleList(
            TransformerImageBlock(
                d_model,
                num_heads,
                context_length,
                rope_theta,
                d_ff,
                attention_backend=attention_backend,
                device=device,
                dtype=dtype,
            )
            for _ in range(num_layers)
        )
        self.ln_final = RMSNorm(d_model, device=device, dtype=dtype)
        self.lm_head = Linear(d_model, vocab_size, device, dtype)

    def forward(
        self,
        in_indices: torch.Tensor,
        attention_mask: torch.Tensor | None = None,
        context: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Return logits for every position of ``in_indices``.

        ``context`` holds label ids; each label is embedded as a one-token
        sequence (``unsqueeze(-2)``) for the cross-attention layers.

        Raises:
            ValueError: if ``context`` is None.
        """
        if context is None:
            raise ValueError("context must be provided for TransformerImage")
        hidden = self.token_embeddings(in_indices)
        label_ctx = self.label_embeddings(context).unsqueeze(-2)
        positions = torch.arange(hidden.shape[-2], device=hidden.device, dtype=torch.long)
        ctx_positions = torch.arange(label_ctx.shape[-2], device=hidden.device, dtype=torch.long)
        for block in self.layers:
            hidden = block(hidden, positions, label_ctx, ctx_positions, attention_mask=attention_mask)
        return self.lm_head(self.ln_final(hidden))
def _diffusion_logits(
    model,
    x: torch.Tensor,
    context: torch.Tensor,
    uncond_context: torch.Tensor | None,
    cfg_scale: float,
) -> torch.Tensor:
    """Return model logits for ``x``, applying classifier-free guidance when ``cfg_scale > 0``."""
    if cfg_scale > 0.0:
        cond_logits = model(x, context=context)
        uncond_logits = model(x, context=uncond_context)
        # Guidance is uncond + (cfg_scale + 1) * (cond - uncond).
        return uncond_logits + (cfg_scale + 1.0) * (cond_logits - uncond_logits)
    return model(x, context=context)


def _select_transfer_mask(confidence: torch.Tensor, counts: torch.Tensor) -> torch.Tensor:
    """Mark, in each row ``b``, the ``counts[b]`` positions with the highest confidence.

    Positions whose confidence is -inf are never selected. If a row has fewer
    such candidates than requested, all of them are selected.
    """
    transfer_mask = torch.zeros(confidence.shape, dtype=torch.bool, device=confidence.device)
    for b in range(confidence.shape[0]):
        k = int(counts[b].item())
        if k <= 0:
            continue
        row = confidence[b]
        available_count = int((row > float("-inf")).sum().item())
        if available_count == 0:
            continue
        topk_indices = torch.topk(row, k=min(k, available_count), dim=-1).indices
        transfer_mask[b, topk_indices] = True
    return transfer_mask


def image_diffusion_generate(
    model,
    prompt_indices: torch.Tensor,
    *,
    context: torch.Tensor,
    mask_id: int,
    eos_token_id: int | None = None,
    steps: int,
    gen_length: int,
    block_length: int,
    temperature: float = 0.0,
    top_p: float | None = None,
    cfg_scale: float = 0.0,
    uncond_context: torch.Tensor | None = None,
    remasking: str = "random",
    logits_eos_inf: bool = False,
    confidence_eos_eot_inf: bool = False,
    generator: torch.Generator | None = None,
) -> torch.Tensor:
    """Generate tokens after a prompt by iterative, block-wise unmasking.

    The working sequence is ``prompt_indices`` followed by ``gen_length`` copies
    of ``mask_id``. The generated region is processed in blocks of
    ``block_length`` tokens, and each block gets ``steps // blocks`` steps (the
    first ``steps % blocks`` blocks get one more). At each step:

    1. The model predicts every position.
    2. Predictions are sampled via ``add_gumbel_noise`` + argmax.
    3. A scheduled number of still-masked positions up to the current block's
       end are committed, chosen by the ``remasking`` strategy:
       "low_confidence" picks the highest predicted probability, "random" uses
       uniform random scores.

    Args:
        model: callable ``model(x, context=...) -> logits`` of shape (batch, seq, vocab).
        prompt_indices: (batch, prompt_len) token ids.
        context: (batch,) conditioning ids.
        mask_id: id of the mask token.
        eos_token_id: optional token suppressed by the two ``*_inf`` flags.
        steps: total denoising steps; must be >= the number of blocks.
        gen_length: number of tokens to generate (<= 0 returns the prompt).
        block_length: tokens per block (> 0).
        temperature: Gumbel noise temperature; <= 0 is greedy.
        top_p: optional nucleus threshold applied before sampling.
        cfg_scale: classifier-free guidance strength; > 0 requires ``uncond_context``.
        uncond_context: (batch,) ids for the unconditional pass.
        remasking: "low_confidence" or "random".
        logits_eos_inf: set the EOS logit to -inf before sampling.
        confidence_eos_eot_inf: never commit positions predicted as EOS.
        generator: optional RNG used for all sampling noise.

    Returns:
        (batch, prompt_len + gen_length) token ids.

    Raises:
        ValueError: on inconsistent shapes or options.
    """
    if prompt_indices.dim() != 2:
        raise ValueError("prompt_indices must be 2D (batch, seq)")
    if context.dim() != 1:
        raise ValueError("context must be 1D (batch,)")
    if prompt_indices.shape[0] != context.shape[0]:
        raise ValueError("context batch size must match prompt batch size")
    if block_length <= 0:
        raise ValueError("block_length must be > 0")
    if steps <= 0:
        raise ValueError("steps must be > 0")
    if gen_length <= 0:
        return prompt_indices
    blocks = max(1, int(np.ceil(gen_length / block_length)))
    if steps < blocks:
        raise ValueError("steps must be >= number of blocks")
    base_steps, extra_steps = divmod(steps, blocks)
    device = prompt_indices.device
    batch_size, prompt_len = prompt_indices.shape
    total_len = prompt_len + gen_length
    context_limit = getattr(model, "context_length", None)
    if context_limit is not None and total_len > int(context_limit):
        raise ValueError("prompt length + gen_length exceeds model context_length")
    if uncond_context is not None:
        if uncond_context.dim() != 1:
            raise ValueError("uncond_context must be 1D (batch,)")
        if uncond_context.shape[0] != batch_size:
            raise ValueError("uncond_context batch size must match prompt batch size")
        uncond_context = uncond_context.to(device=context.device, dtype=context.dtype)
    # Validate option combinations before any (expensive) forward pass rather
    # than inside the sampling loop.
    if cfg_scale > 0.0 and uncond_context is None:
        raise ValueError("uncond_context must be set when cfg_scale > 0 for image_diffusion_generate")
    if remasking not in ("low_confidence", "random"):
        raise ValueError(f"Unsupported remasking strategy: {remasking}")

    x = torch.full(
        (batch_size, total_len),
        fill_value=mask_id,
        device=device,
        dtype=prompt_indices.dtype,
    )
    x[:, :prompt_len] = prompt_indices

    # Sampling never needs gradients. Without no_grad every forward pass would
    # record an autograd graph and keep activations alive.
    with torch.no_grad():
        for block_idx in range(blocks):
            block_start = prompt_len + block_idx * block_length
            block_end = min(block_start + block_length, total_len)
            # steps >= blocks already guarantees >= 1; max() keeps a defensive floor.
            block_steps = max(1, base_steps + (1 if block_idx < extra_steps else 0))
            block_mask = x[:, block_start:block_end] == mask_id
            transfer_counts = compute_transfer_schedule(block_mask, block_steps)
            for step_idx in range(block_steps):
                mask_index = x == mask_id
                logits = _diffusion_logits(model, x, context, uncond_context, cfg_scale)
                if logits_eos_inf and eos_token_id is not None:
                    logits[:, :, eos_token_id] = float("-inf")
                if top_p is not None:
                    probs = top_p_filter(softmax(logits, dim=-1), float(top_p))
                    logits = torch.where(probs > 0, logits, torch.full_like(logits, float("-inf")))
                logits_with_noise = add_gumbel_noise(logits, temperature, generator=generator)
                predictions = torch.argmax(logits_with_noise, dim=-1)
                # Already-decoded positions keep their current token.
                predictions = torch.where(mask_index, predictions, x)
                if remasking == "low_confidence":
                    probs = softmax(logits, dim=-1)
                    confidence = torch.gather(probs, dim=-1, index=predictions.unsqueeze(-1)).squeeze(-1)
                else:  # "random"
                    confidence = torch.rand(
                        (batch_size, total_len),
                        device=device,
                        dtype=torch.float32,
                        generator=generator,
                    )
                if confidence_eos_eot_inf and eos_token_id is not None:
                    confidence = torch.where(
                        predictions == eos_token_id,
                        torch.full_like(confidence, float("-inf")),
                        confidence,
                    )
                # Exclude positions past the current block and positions already unmasked.
                confidence[:, block_end:] = float("-inf")
                confidence = torch.where(mask_index, confidence, torch.full_like(confidence, float("-inf")))
                transfer_mask = _select_transfer_mask(confidence, transfer_counts[:, step_idx])
                x = torch.where(transfer_mask, predictions, x)
    return x
def dequantize_tokens_to_uint8(tokens: np.ndarray, *, pixel_bins: int) -> np.ndarray:
    """Map quantized pixel tokens in ``[0, pixel_bins)`` back to uint8 intensities.

    Each token is clipped into ``[0, pixel_bins - 1]`` and mapped to the center of
    its bin: ``round((v + 0.5) * 256 / pixel_bins - 0.5)``, clipped to [0, 255].
    For ``pixel_bins == 256`` this is the identity on in-range values.
    Out-of-range tokens (e.g. a mask id) are clipped instead of wrapping
    modulo 256 as a raw ``astype(np.uint8)`` would.

    Raises:
        ValueError: if ``pixel_bins`` is not positive.
    """
    pixel_bins = int(pixel_bins)
    if pixel_bins <= 0:
        raise ValueError("pixel_bins must be > 0")
    # int64 avoids wrap-around of large ids before clipping.
    vals = np.clip(np.asarray(tokens).astype(np.int64), 0, pixel_bins - 1)
    scale = 256.0 / float(pixel_bins)
    restored = np.round((vals + 0.5) * scale - 0.5)
    return np.clip(restored, 0, 255).astype(np.uint8)
# Process-wide cache filled by load_model(); all three stay None until its first call.
MODEL = DEVICE = DTYPE = None
def load_model():
    """Return the cached ``(model, device, dtype)``, building it on the first call.

    The first call does the following:
    - resolves device and dtype from ``MODEL_CONFIG``;
    - configures the CUDA SDPA kernels;
    - builds a ``TransformerImage`` and loads the safetensors weights from
      ``CHECKPOINT_PATH``;
    - puts the model in eval mode on the device;
    - stores everything in the ``MODEL``/``DEVICE``/``DTYPE`` globals.

    Later calls just return those globals.

    Raises:
        FileNotFoundError: if the checkpoint file is missing.
    """
    global MODEL, DEVICE, DTYPE
    if MODEL is None:
        if not os.path.exists(CHECKPOINT_PATH):
            raise FileNotFoundError(f"Missing checkpoint at {CHECKPOINT_PATH}")
        cfg = MODEL_CONFIG
        device, dtype = _resolve_device_dtype(cfg["device"], cfg["dtype"])
        set_sdp_backend(cfg["attention_sdp_backend"])
        model = TransformerImage(
            vocab_size=cfg["vocab_size"],
            context_length=cfg["context_length"],
            d_model=cfg["d_model"],
            num_layers=cfg["num_layers"],
            num_heads=cfg["num_heads"],
            d_ff=cfg["d_ff"],
            rope_theta=cfg["rope_theta"],
            label_vocab_size=cfg["label_vocab_size"],
            attention_backend=cfg["attention_backend"],
            device=device,
            dtype=dtype,
        )
        model.load_state_dict(load_file(CHECKPOINT_PATH))
        model.eval().to(device)
        MODEL, DEVICE, DTYPE = model, device, dtype
    return MODEL, DEVICE, DTYPE
def generate_images(label: int, steps: int, num_samples: int) -> List[Image.Image]:
    """Generate ``num_samples`` grayscale images conditioned on ``label``.

    Uses the cached model from ``load_model``, the sampling settings in
    ``INFER_CONFIG`` and, when guidance is enabled, ``MODEL_CONFIG["null_label_id"]``
    as the unconditional label. Token grids are dequantized to uint8 and upscaled
    10x with nearest-neighbour resampling.

    Returns:
        A list of PIL images; empty when ``num_samples == 0``.

    Raises:
        ValueError: if ``num_samples`` is negative.
    """
    num_samples = int(num_samples)
    if num_samples < 0:
        raise ValueError("num_samples must be >= 0")
    if num_samples == 0:
        # Nothing to generate; avoid running the model on an empty batch.
        return []
    model, device, _ = load_model()
    label = int(label)
    steps = int(steps)
    context = torch.full((num_samples,), label, device=device, dtype=torch.long)
    prompt = torch.empty((num_samples, 0), device=device, dtype=torch.long)
    cfg_scale = float(INFER_CONFIG["cfg_scale"])
    uncond_context = None
    if cfg_scale > 0.0:
        null_label_id = int(MODEL_CONFIG["null_label_id"])
        uncond_context = torch.full((num_samples,), null_label_id, device=device, dtype=torch.long)
    # Inference only: don't record autograd graphs for the many forward passes.
    with torch.no_grad():
        out_indices = image_diffusion_generate(
            model,
            prompt,
            context=context,
            mask_id=int(MODEL_CONFIG["mask_token_id"]),
            eos_token_id=None,
            steps=steps,
            gen_length=int(MODEL_CONFIG["context_length"]),
            block_length=int(INFER_CONFIG["block_length"]),
            temperature=float(INFER_CONFIG["temperature"]),
            top_p=float(INFER_CONFIG["top_p"]),
            cfg_scale=cfg_scale,
            uncond_context=uncond_context,
            remasking=str(INFER_CONFIG["remasking"]),
            logits_eos_inf=False,
            confidence_eos_eot_inf=False,
            generator=None,
        )
    h = int(MODEL_CONFIG["image_height"])
    w = int(MODEL_CONFIG["image_width"])
    pixel_bins = int(MODEL_CONFIG["pixel_bins"])
    scale = 10
    # One device->host copy for the whole batch instead of one per sample.
    token_grids = out_indices.detach().cpu().to(torch.int32).numpy().reshape(num_samples, h, w)
    images: List[Image.Image] = []
    for tokens in token_grids:
        arr = dequantize_tokens_to_uint8(tokens, pixel_bins=pixel_bins)
        # A 2-D uint8 array is inferred as mode "L"; passing mode= is deprecated in Pillow.
        img = Image.fromarray(arr)
        if scale > 1:
            img = img.resize((w * scale, h * scale), resample=Image.NEAREST)
        images.append(img)
    return images
| def _grid_dims(num_samples: int) -> Tup[int, int]: | |
| cols = int(np.ceil(np.sqrt(num_samples))) | |
| rows = int(np.ceil(num_samples / cols)) | |
| return rows, cols | |
def generate_grid_image(label: int, steps: int, num_samples: int) -> Image.Image:
    """Generate images for ``label`` and tile them row-major into one grayscale grid.

    Grid shape comes from ``_grid_dims``; every tile uses the size of the first
    image. If no images are produced, a 1x1 black image is returned.
    """
    tiles = generate_images(label=label, steps=steps, num_samples=num_samples)
    if len(tiles) == 0:
        return Image.new("L", (1, 1), color=0)
    rows, cols = _grid_dims(len(tiles))
    tile_w, tile_h = tiles[0].size
    canvas = Image.new("L", (cols * tile_w, rows * tile_h))
    for idx, tile in enumerate(tiles):
        row, col = divmod(idx, cols)
        canvas.paste(tile, (col * tile_w, row * tile_h))
    return canvas