from __future__ import annotations import os from typing import Tuple, List, Tuple as Tup import numpy as np import torch from PIL import Image from safetensors.torch import load_file from einops import einsum, rearrange BASE_DIR = os.path.dirname(os.path.abspath(__file__)) CHECKPOINT_PATH = os.path.join(BASE_DIR, "model", "model.safetensors") MODEL_CONFIG = { "model_type": "image", "label_vocab_size": 11, "vocab_size": 33, "pixel_bins": 32, "context_length": 784, "d_model": 256, "num_layers": 8, "num_heads": 16, "d_ff": 1024, "rope_theta": 10000.0, "attention_backend": "torch_sdpa", "attention_sdp_backend": "auto", "device": "cuda", "dtype": "float16", "mask_token_id": 32, "null_label_id": 10, "image_height": 28, "image_width": 28, } INFER_CONFIG = { "block_length": 784, "temperature": 0.6, "top_p": 0.99, "cfg_scale": 2.0, "remasking": "random", } DTYPES = { "float16": torch.float16, "float32": torch.float32, "bfloat16": torch.bfloat16, } def _resolve_device_dtype(device: str, dtype_name: str) -> Tuple[str, torch.dtype]: resolved_device = device if device == "cuda" and not torch.cuda.is_available(): resolved_device = "cpu" resolved_dtype = DTYPES[dtype_name] if resolved_device == "cpu" and resolved_dtype == torch.float16: resolved_dtype = torch.float32 return resolved_device, resolved_dtype def set_sdp_backend(backend: str) -> None: backend = backend.lower() allowed = {"auto", "flash", "mem_efficient", "math"} if backend not in allowed: raise ValueError(f"attention_sdp_backend must be one of {sorted(allowed)}") if not torch.cuda.is_available(): return if backend == "auto": torch.backends.cuda.enable_flash_sdp(True) torch.backends.cuda.enable_mem_efficient_sdp(True) torch.backends.cuda.enable_math_sdp(True) return torch.backends.cuda.enable_flash_sdp(backend == "flash") torch.backends.cuda.enable_mem_efficient_sdp(backend == "mem_efficient") torch.backends.cuda.enable_math_sdp(backend == "math") class Linear(torch.nn.Module): def __init__(self, in_features, out_features, device=None, dtype=None): super().__init__() self.weight = torch.nn.Parameter(torch.empty(out_features, in_features, device=device, dtype=dtype)) mean = 0.0 std = 2 / (in_features + out_features) a = mean - 3 * std b = mean + 3 * std torch.nn.init.trunc_normal_(self.weight, mean=mean, std=std, a=a, b=b) def forward(self, x: torch.Tensor) -> torch.Tensor: y = einsum(self.weight, x, "out_features in_features, ... in_features -> ... out_features") return y class Embedding(torch.nn.Module): def __init__(self, num_embeddings, embedding_dim, device=None, dtype=None): super().__init__() self.num_embeddings = num_embeddings self.embedding_dim = embedding_dim self.weight = torch.nn.Parameter(torch.empty(num_embeddings, embedding_dim, device=device, dtype=dtype)) torch.nn.init.trunc_normal_(self.weight, mean=0, std=1, a=-3, b=3) def forward(self, token_ids: torch.Tensor) -> torch.Tensor: embeds = self.weight[token_ids] return embeds class RMSNorm(torch.nn.Module): def __init__(self, d_model: int, eps: float = 1e-5, device=None, dtype=None): super().__init__() self.eps = eps self.d_model = d_model self.weight = torch.nn.Parameter(torch.empty(d_model, device=device, dtype=dtype)) torch.nn.init.ones_(self.weight) def forward(self, x: torch.Tensor) -> torch.Tensor: in_dtype = x.dtype x = x.to(torch.float32) rms = torch.sqrt(torch.mean(x ** 2, dim=-1) + self.eps).unsqueeze(-1) x = (1 / rms) * (x * self.weight) return x.to(in_dtype) class SwiGLU(torch.nn.Module): def __init__(self, d_model: int, d_ff: int, device=None, dtype=None): super().__init__() self.w1 = Linear(d_model, d_ff, device=device, dtype=dtype) self.w2 = Linear(d_ff, d_model, device=device, dtype=dtype) self.w3 = Linear(d_model, d_ff, device=device, dtype=dtype) def forward(self, x: torch.Tensor) -> torch.Tensor: w1x = self.w1(x) w3x = self.w3(x) silu = w1x * torch.sigmoid(w1x) glu = silu * w3x w2x = self.w2(glu) return w2x def softmax(x: torch.Tensor, dim: int): x_max = x.max(dim=dim, keepdim=True).values x_stable = x - x_max exp_x = torch.exp(x_stable) sum_exp_x = exp_x.sum(dim=dim, keepdim=True) return exp_x / sum_exp_x def top_p_filter(probs: torch.Tensor, p: float) -> torch.Tensor: if probs.dim() < 2: raise ValueError("probs must have at least 2 dimensions") orig_shape = probs.shape vocab = orig_shape[-1] probs = probs.reshape(-1, vocab) if p <= 0: argmax = probs.argmax(dim=-1) out = torch.zeros_like(probs) out.scatter_(-1, argmax.unsqueeze(-1), 1.0) return out.reshape(orig_shape) if p >= 1: return (probs / probs.sum(dim=-1, keepdim=True)).reshape(orig_shape) sorted_probs, sorted_indices = torch.sort(probs, descending=True) cumulative = torch.cumsum(sorted_probs, dim=-1) keep = cumulative <= p keep[..., 0] = True first_ge = (cumulative >= p).float().argmax(dim=-1) rows = torch.arange(keep.shape[0], device=keep.device) keep[rows, first_ge] = True filtered_sorted = torch.where(keep, sorted_probs, torch.zeros_like(sorted_probs)) norm = filtered_sorted.sum(dim=-1, keepdim=True).clamp_min(1e-12) filtered_sorted = filtered_sorted / norm filtered = torch.zeros_like(probs) filtered.scatter_(dim=-1, index=sorted_indices, src=filtered_sorted) return filtered.reshape(orig_shape) def add_gumbel_noise(logits: torch.Tensor, temperature: float, *, generator: torch.Generator | None = None) -> torch.Tensor: if temperature <= 0: return logits noise = torch.rand(logits.shape, device=logits.device, dtype=torch.float64, generator=generator) gumbel_noise = (-torch.log(noise)) ** temperature logits64 = logits.to(torch.float64) perturbed = logits64.exp() / gumbel_noise return perturbed.to(logits.dtype) def compute_transfer_schedule(mask: torch.Tensor, steps: int) -> torch.Tensor: if steps <= 0: raise ValueError("steps must be > 0") if mask.dim() != 2: raise ValueError("mask must be 2D (batch, block_length)") counts = mask.sum(dim=1, keepdim=True).to(torch.int64) base = counts // steps remainder = counts % steps schedule = base.expand(-1, steps).clone() for idx in range(schedule.size(0)): r = remainder[idx, 0].item() if r > 0: schedule[idx, :r] += 1 return schedule def _prepare_attention_mask(attention_mask: torch.Tensor, ref_tensor: torch.Tensor) -> torch.Tensor: mask = attention_mask.to(device=ref_tensor.device, dtype=torch.bool) if mask.dim() == 2: mask = mask[:, None, None, :] elif mask.dim() == 3: mask = mask[:, None, :, :] elif mask.dim() != 4: raise ValueError("attention_mask must be 2D, 3D, or 4D") return mask def scaled_dot_product_attention( Q: torch.Tensor, K: torch.Tensor, V: torch.Tensor, attention_mask: torch.Tensor | None = None, ): scale = torch.tensor(Q.shape[-1], device=Q.device, dtype=Q.dtype).sqrt() qk_score = einsum(Q, K, "batch_size ... n d_k, batch_size ... m d_k -> batch_size ... n m") / scale if attention_mask is not None: mask = _prepare_attention_mask(attention_mask, qk_score) qk_score = qk_score.masked_fill(~mask, float("-inf")) softmax_qk_score = softmax(qk_score, dim=-1) attn = einsum(softmax_qk_score, V, "batch_size ... n m, batch_size ... m d_k -> batch_size ... n d_k") return attn def torch_scaled_dot_product_attention( Q: torch.Tensor, K: torch.Tensor, V: torch.Tensor, attention_mask: torch.Tensor | None = None, ): Q = Q.contiguous() K = K.contiguous() V = V.contiguous() mask = None if attention_mask is not None: mask = _prepare_attention_mask(attention_mask, Q) return torch.nn.functional.scaled_dot_product_attention(Q, K, V, attn_mask=mask, dropout_p=0.0, is_causal=False) class RotaryPositionalEmbedding(torch.nn.Module): def __init__(self, theta: float, d_k: int, max_seq_len: int, device=None): super().__init__() self.device = device theta_i = theta ** (torch.arange(0, d_k, 2).float() / d_k) position = torch.arange(max_seq_len) phases = position.unsqueeze(1) / theta_i.unsqueeze(0) phases_cos = torch.cos(phases) phases_sin = torch.sin(phases) phases_combined = torch.stack([phases_cos, phases_sin], dim=-1).to(device=device) self.register_buffer("phases", phases_combined, persistent=False) def forward(self, x: torch.Tensor, token_positions: torch.Tensor) -> torch.Tensor: x = rearrange(x, "... (d_k p) -> ... d_k p", p=2) x1 = x[..., 0] x2 = x[..., 1] phases_cos = self.phases[..., 0][token_positions].to(dtype=x.dtype) phases_sin = self.phases[..., 1][token_positions].to(dtype=x.dtype) x_rotated = torch.stack([ x1 * phases_cos - x2 * phases_sin, x1 * phases_sin + x2 * phases_cos, ], dim=-1) return x_rotated.flatten(-2) class MultiheadSelfAttentionRoPE(torch.nn.Module): def __init__( self, d_model: int, num_heads: int, max_seq_len: int, theta: float, attention_backend: str = "custom", device=None, dtype=None, ): super().__init__() self.d_model = d_model self.num_heads = num_heads self.d_k = self.d_model // self.num_heads self.d_v = self.d_k self.max_seq_len = max_seq_len self.theta = theta if attention_backend not in {"custom", "torch_sdpa"}: raise ValueError("attention_backend must be one of ['custom', 'torch_sdpa']") self.attention_backend = attention_backend self.q_proj = Linear(self.d_model, self.d_model, device=device, dtype=dtype) self.k_proj = Linear(self.d_model, self.d_model, device=device, dtype=dtype) self.v_proj = Linear(self.d_model, self.d_model, device=device, dtype=dtype) self.output_proj = Linear(self.d_model, self.d_model, device=device, dtype=dtype) self.rope = RotaryPositionalEmbedding(self.theta, self.d_k, self.max_seq_len, device) def forward( self, x: torch.Tensor, token_positions: torch.Tensor, attention_mask: torch.Tensor | None = None, ) -> torch.Tensor: wqx = self.q_proj(x) wqx_rearr = rearrange(wqx, "... seq_len (num_heads d_k) -> ... num_heads seq_len d_k", num_heads=self.num_heads, d_k=self.d_k) wqx_rearr_rope = self.rope(wqx_rearr, token_positions) wkx = self.k_proj(x) wkx_rearr = rearrange(wkx, "... seq_len (num_heads d_k) -> ... num_heads seq_len d_k", num_heads=self.num_heads, d_k=self.d_k) wkx_rearr_rope = self.rope(wkx_rearr, token_positions) wvx = self.v_proj(x) wvx_rearr = rearrange(wvx, "... seq_len (num_heads d_v) -> ... num_heads seq_len d_v", num_heads=self.num_heads, d_v=self.d_v) if self.attention_backend == "torch_sdpa": attn = torch_scaled_dot_product_attention( wqx_rearr_rope, wkx_rearr_rope, wvx_rearr, attention_mask=attention_mask, ) else: attn = scaled_dot_product_attention( wqx_rearr_rope, wkx_rearr_rope, wvx_rearr, attention_mask=attention_mask, ) attn_rearr = rearrange(attn, "... num_heads seq_len d_v -> ... seq_len (num_heads d_v)", num_heads=self.num_heads, d_v=self.d_v) attn_rearr_proj = self.output_proj(attn_rearr) return attn_rearr_proj class MultiheadCrossAttentionRoPE(torch.nn.Module): def __init__( self, d_model: int, num_heads: int, max_seq_len: int, theta: float, attention_backend: str = "custom", device=None, dtype=None, ): super().__init__() self.d_model = d_model self.num_heads = num_heads self.d_k = self.d_model // self.num_heads self.d_v = self.d_k self.max_seq_len = max_seq_len self.theta = theta if attention_backend not in {"custom", "torch_sdpa"}: raise ValueError("attention_backend must be one of ['custom', 'torch_sdpa']") self.attention_backend = attention_backend self.q_proj = Linear(self.d_model, self.d_model, device=device, dtype=dtype) self.k_proj = Linear(self.d_model, self.d_model, device=device, dtype=dtype) self.v_proj = Linear(self.d_model, self.d_model, device=device, dtype=dtype) self.output_proj = Linear(self.d_model, self.d_model, device=device, dtype=dtype) self.rope = RotaryPositionalEmbedding(self.theta, self.d_k, self.max_seq_len, device) def forward( self, x: torch.Tensor, context: torch.Tensor, token_positions: torch.Tensor, context_token_positions: torch.Tensor, attention_mask: torch.Tensor | None = None, ) -> torch.Tensor: wqx = self.q_proj(x) wqx_rearr = rearrange(wqx, "... seq_len (num_heads d_k) -> ... num_heads seq_len d_k", num_heads=self.num_heads, d_k=self.d_k) wqx_rearr_rope = self.rope(wqx_rearr, token_positions) wkx = self.k_proj(context) wkx_rearr = rearrange(wkx, "... seq_len (num_heads d_k) -> ... num_heads seq_len d_k", num_heads=self.num_heads, d_k=self.d_k) wkx_rearr_rope = self.rope(wkx_rearr, context_token_positions) wvx = self.v_proj(context) wvx_rearr = rearrange(wvx, "... seq_len (num_heads d_v) -> ... num_heads seq_len d_v", num_heads=self.num_heads, d_v=self.d_v) if self.attention_backend == "torch_sdpa": attn = torch_scaled_dot_product_attention( wqx_rearr_rope, wkx_rearr_rope, wvx_rearr, attention_mask=attention_mask, ) else: attn = scaled_dot_product_attention( wqx_rearr_rope, wkx_rearr_rope, wvx_rearr, attention_mask=attention_mask, ) attn_rearr = rearrange(attn, "... num_heads seq_len d_v -> ... seq_len (num_heads d_v)", num_heads=self.num_heads, d_v=self.d_v) attn_rearr_proj = self.output_proj(attn_rearr) return attn_rearr_proj class TransformerImageBlock(torch.nn.Module): def __init__( self, d_model: int, num_heads: int, max_seq_len: int, theta: float, d_ff: int, attention_backend: str = "custom", device=None, dtype=None, ): super().__init__() self.ffn = SwiGLU(d_model, d_ff, device, dtype) self.self_attn = MultiheadSelfAttentionRoPE( d_model, num_heads, max_seq_len, theta, attention_backend=attention_backend, device=device, dtype=dtype, ) self.cross_attn = MultiheadCrossAttentionRoPE( d_model, num_heads, max_seq_len, theta, attention_backend=attention_backend, device=device, dtype=dtype, ) self.ln1 = RMSNorm(d_model, device=device, dtype=dtype) self.ln2 = RMSNorm(d_model, device=device, dtype=dtype) self.ln3 = RMSNorm(d_model, device=device, dtype=dtype) def forward( self, x: torch.Tensor, token_positions: torch.Tensor, context: torch.Tensor, context_token_positions: torch.Tensor, attention_mask: torch.Tensor | None = None, ) -> torch.Tensor: ln1x = self.ln1(x) x = x + self.self_attn(ln1x, token_positions, attention_mask=attention_mask) ln2x = self.ln2(x) x = x + self.cross_attn( ln2x, context, token_positions, context_token_positions, attention_mask=None, ) ln3x = self.ln3(x) x = x + self.ffn(ln3x) return x class TransformerImage(torch.nn.Module): def __init__( self, vocab_size: int, context_length: int, d_model: int, num_layers: int, num_heads: int, d_ff: int, rope_theta: float, label_vocab_size: int, attention_backend: str = "custom", device=None, dtype=None, ): super().__init__() self.context_length = context_length self.token_embeddings = Embedding(vocab_size, d_model, device, dtype) self.label_embeddings = Embedding(label_vocab_size, d_model, device, dtype) self.layers = torch.nn.ModuleList( [ TransformerImageBlock( d_model, num_heads, context_length, rope_theta, d_ff, attention_backend=attention_backend, device=device, dtype=dtype, ) for _ in range(num_layers) ] ) self.ln_final = RMSNorm(d_model, device=device, dtype=dtype) self.lm_head = Linear(d_model, vocab_size, device, dtype) def forward( self, in_indices: torch.Tensor, attention_mask: torch.Tensor | None = None, context: torch.Tensor | None = None, ) -> torch.Tensor: if context is None: raise ValueError("context must be provided for TransformerImage") output_seq = self.token_embeddings(in_indices) context_emb = self.label_embeddings(context).unsqueeze(-2) token_positions = torch.arange(output_seq.shape[-2], device=output_seq.device, dtype=torch.long) context_token_positions = torch.arange(context_emb.shape[-2], device=output_seq.device, dtype=torch.long) for layer in self.layers: output_seq = layer( output_seq, token_positions, context_emb, context_token_positions, attention_mask=attention_mask, ) normed_output_seq = self.ln_final(output_seq) logits = self.lm_head(normed_output_seq) return logits @torch.no_grad() def image_diffusion_generate( model, prompt_indices: torch.Tensor, *, context: torch.Tensor, mask_id: int, eos_token_id: int | None = None, steps: int, gen_length: int, block_length: int, temperature: float = 0.0, top_p: float | None = None, cfg_scale: float = 0.0, uncond_context: torch.Tensor | None = None, remasking: str = "random", logits_eos_inf: bool = False, confidence_eos_eot_inf: bool = False, generator: torch.Generator | None = None, ) -> torch.Tensor: if prompt_indices.dim() != 2: raise ValueError("prompt_indices must be 2D (batch, seq)") if context.dim() != 1: raise ValueError("context must be 1D (batch,)") if prompt_indices.shape[0] != context.shape[0]: raise ValueError("context batch size must match prompt batch size") if block_length <= 0: raise ValueError("block_length must be > 0") if steps <= 0: raise ValueError("steps must be > 0") if gen_length <= 0: return prompt_indices blocks = max(1, int(np.ceil(gen_length / block_length))) if steps < blocks: raise ValueError("steps must be >= number of blocks") base_steps = steps // blocks extra_steps = steps % blocks device = prompt_indices.device batch_size, prompt_len = prompt_indices.shape total_len = prompt_len + gen_length context_limit = getattr(model, "context_length", None) if context_limit is not None and total_len > int(context_limit): raise ValueError("prompt length + gen_length exceeds model context_length") x = torch.full( (batch_size, total_len), fill_value=mask_id, device=device, dtype=prompt_indices.dtype, ) x[:, :prompt_len] = prompt_indices if uncond_context is not None: if uncond_context.dim() != 1: raise ValueError("uncond_context must be 1D (batch,)") if uncond_context.shape[0] != batch_size: raise ValueError("uncond_context batch size must match prompt batch size") uncond_context = uncond_context.to(device=context.device, dtype=context.dtype) for block_idx in range(blocks): block_start = prompt_len + block_idx * block_length block_end = min(block_start + block_length, total_len) block_steps = base_steps + (1 if block_idx < extra_steps else 0) if block_steps <= 0: block_steps = 1 block_mask = (x[:, block_start:block_end] == mask_id) transfer_counts = compute_transfer_schedule(block_mask, block_steps) for step_idx in range(block_steps): mask_index = (x == mask_id) if cfg_scale > 0.0: if uncond_context is None: raise ValueError("uncond_context must be set when cfg_scale > 0 for image_diffusion_generate") cond_logits = model(x, context=context) uncond_logits = model(x, context=uncond_context) logits = uncond_logits + (cfg_scale + 1.0) * (cond_logits - uncond_logits) else: logits = model(x, context=context) if logits_eos_inf and eos_token_id is not None: logits[:, :, eos_token_id] = float("-inf") if top_p is not None: probs = softmax(logits, dim=-1) probs = top_p_filter(probs, float(top_p)) logits = torch.where( probs > 0, logits, torch.full_like(logits, float("-inf")), ) logits_with_noise = add_gumbel_noise(logits, temperature, generator=generator) predictions = torch.argmax(logits_with_noise, dim=-1) predictions = torch.where(mask_index, predictions, x) if remasking == "low_confidence": probs = softmax(logits, dim=-1) confidence = torch.squeeze( torch.gather(probs, dim=-1, index=torch.unsqueeze(predictions, -1)), -1, ) elif remasking == "random": confidence = torch.rand( (batch_size, total_len), device=device, dtype=torch.float32, generator=generator, ) else: raise ValueError(f"Unsupported remasking strategy: {remasking}") if confidence_eos_eot_inf and eos_token_id is not None: confidence = torch.where( predictions == eos_token_id, torch.full_like(confidence, float("-inf")), confidence, ) confidence[:, block_end:] = float("-inf") confidence = torch.where(mask_index, confidence, torch.full_like(confidence, float("-inf"))) transfer_mask = torch.zeros_like(mask_index) for b in range(batch_size): k = int(transfer_counts[b, step_idx].item()) if k <= 0: continue available = confidence[b] > float("-inf") available_count = int(available.sum().item()) if available_count == 0: continue if available_count < k: k = available_count topk_indices = torch.topk(confidence[b], k=k, dim=-1).indices transfer_mask[b, topk_indices] = True x = torch.where(transfer_mask, predictions, x) return x def dequantize_tokens_to_uint8(tokens: np.ndarray, *, pixel_bins: int) -> np.ndarray: if pixel_bins == 256: return tokens.astype(np.uint8) vals = np.clip(tokens.astype(np.int32), 0, int(pixel_bins) - 1) scale = 256.0 / float(pixel_bins) restored = np.round((vals + 0.5) * scale - 0.5) return np.clip(restored, 0, 255).astype(np.uint8) MODEL = None DEVICE = None DTYPE = None def load_model(): global MODEL, DEVICE, DTYPE if MODEL is not None: return MODEL, DEVICE, DTYPE if not os.path.exists(CHECKPOINT_PATH): raise FileNotFoundError(f"Missing checkpoint at {CHECKPOINT_PATH}") device, dtype = _resolve_device_dtype(MODEL_CONFIG["device"], MODEL_CONFIG["dtype"]) set_sdp_backend(MODEL_CONFIG["attention_sdp_backend"]) model = TransformerImage( vocab_size=MODEL_CONFIG["vocab_size"], context_length=MODEL_CONFIG["context_length"], d_model=MODEL_CONFIG["d_model"], num_layers=MODEL_CONFIG["num_layers"], num_heads=MODEL_CONFIG["num_heads"], d_ff=MODEL_CONFIG["d_ff"], rope_theta=MODEL_CONFIG["rope_theta"], label_vocab_size=MODEL_CONFIG["label_vocab_size"], attention_backend=MODEL_CONFIG["attention_backend"], device=device, dtype=dtype, ) model_state = load_file(CHECKPOINT_PATH) model.load_state_dict(model_state) model.eval().to(device) MODEL = model DEVICE = device DTYPE = dtype return MODEL, DEVICE, DTYPE @torch.inference_mode() def generate_images(label: int, steps: int, num_samples: int) -> List[Image.Image]: model, device, _ = load_model() num_samples = int(num_samples) label = int(label) steps = int(steps) context = torch.full((num_samples,), label, device=device, dtype=torch.long) prompt = torch.empty((num_samples, 0), device=device, dtype=torch.long) cfg_scale = float(INFER_CONFIG["cfg_scale"]) uncond_context = None if cfg_scale > 0.0: null_label_id = int(MODEL_CONFIG["null_label_id"]) uncond_context = torch.full((num_samples,), null_label_id, device=device, dtype=torch.long) out_indices = image_diffusion_generate( model, prompt, context=context, mask_id=int(MODEL_CONFIG["mask_token_id"]), eos_token_id=None, steps=steps, gen_length=int(MODEL_CONFIG["context_length"]), block_length=int(INFER_CONFIG["block_length"]), temperature=float(INFER_CONFIG["temperature"]), top_p=float(INFER_CONFIG["top_p"]), cfg_scale=cfg_scale, uncond_context=uncond_context, remasking=str(INFER_CONFIG["remasking"]), logits_eos_inf=False, confidence_eos_eot_inf=False, generator=None, ) h = int(MODEL_CONFIG["image_height"]) w = int(MODEL_CONFIG["image_width"]) pixel_bins = int(MODEL_CONFIG["pixel_bins"]) images: List[Image.Image] = [] scale = 10 for i in range(num_samples): tokens = out_indices[i].detach().cpu().to(torch.int32).numpy().reshape(h, w) arr = dequantize_tokens_to_uint8(tokens, pixel_bins=pixel_bins) img = Image.fromarray(arr, mode="L") if scale > 1: img = img.resize((w * scale, h * scale), resample=Image.NEAREST) images.append(img) return images def _grid_dims(num_samples: int) -> Tup[int, int]: cols = int(np.ceil(np.sqrt(num_samples))) rows = int(np.ceil(num_samples / cols)) return rows, cols @torch.inference_mode() def generate_grid_image(label: int, steps: int, num_samples: int) -> Image.Image: images = generate_images(label=label, steps=steps, num_samples=num_samples) if not images: return Image.new("L", (1, 1), color=0) rows, cols = _grid_dims(len(images)) w, h = images[0].size grid = Image.new("L", (cols * w, rows * h)) for idx, img in enumerate(images): r = idx // cols c = idx % cols grid.paste(img, (c * w, r * h)) return grid