# Origin: MMaDA-Parallel-A / generators/image_generation_generator.py
# (uploaded by akhaliq, HF Staff — "Upload 22 files", commit 9b58924, verified)
# -*- coding: utf-8 -*-
"""
Image generation generator (with optional debug prints/saving)
"""
import torch
import math
import os
import numpy as np
from typing import Callable, Optional
from utils.generation_utils import cosine_schedule, gumbel_max_sample, mask_by_random_topk
from model import LLaDAForMultiModalGeneration
def _unwrap_model(model):
    """Return the underlying LLaDA model, unwrapping a DDP-style ``.module`` wrapper."""
    if isinstance(model, LLaDAForMultiModalGeneration):
        return model
    return model.module  # DDP


def _build_refresh_steps(timesteps, use_cache, warmup_ratio, refresh_interval):
    """Return a bool tensor of length ``timesteps`` marking steps where the cache is refreshed.

    Without caching every step is a refresh step. With caching, every step up to and
    including ``int(timesteps * warmup_ratio)`` is a refresh step, and afterwards one
    step out of every ``refresh_interval``.
    """
    warmup_step = int(timesteps * warmup_ratio)
    refresh_steps = torch.zeros(timesteps, dtype=torch.bool)
    for step in range(timesteps):
        if not use_cache or step <= warmup_step or (step - warmup_step) % refresh_interval == 0:
            refresh_steps[step] = True
    return refresh_steps


def _to_float_numpy(t):
    """Detach a floating tensor to a float32 NumPy array.

    NumPy has no bfloat16 dtype, so calling ``.numpy()`` directly on a bf16 tensor raises.
    """
    return t.detach().cpu().float().numpy()


def _save_debug_step(debug_log_dir, step, x, vq_mask, logits, sampled_full):
    """Best-effort dump of per-step decoding state as ``step_{step}_*.npy`` files."""
    step_base = os.path.join(debug_log_dir, f"step_{step}")
    try:
        np.save(step_base + "_x.npy", x.detach().cpu().numpy())
        np.save(step_base + "_vq_mask.npy", vq_mask.detach().cpu().numpy())
        # logits may be large; save as float32
        np.save(step_base + "_logits.npy", _to_float_numpy(logits).astype(np.float32))
        np.save(step_base + "_sampled_full.npy", sampled_full.detach().cpu().numpy())
    except Exception as e:  # debug output must never abort generation
        print(f"[debug] failed to save debug npy at step {step}: {e}")


@torch.no_grad()
def generate_image(
    model,
    prompt: torch.LongTensor,
    *,
    seq_len: int = 1024,
    newline_every: int = 16,
    timesteps: int = 18,
    mask_token_id: int = 126336,
    newline_id: int = 126084,
    temperature: float = 1.0,
    cfg_scale: float = 0.0,
    uncon_ids: Optional[torch.LongTensor] = None,
    code_start: Optional[int] = None,
    codebook_size: int = 8192,
    noise_schedule: Callable[[torch.Tensor], torch.Tensor] = cosine_schedule,
    text_vocab_size: Optional[int] = None,
    generator: Optional[torch.Generator] = None,
    use_cache=False,
    cache_ratio=0.9,
    refresh_interval=5,
    warmup_ratio=0.3,
    debug: bool = True,
    debug_log_dir: Optional[str] = None,
    max_print_tokens: int = 100
) -> torch.LongTensor:
    """
    MaskGit parallel decoding to generate VQ tokens.

    Every position of ``prompt`` equal to ``mask_token_id`` is generated.

    At each step:
      1. All currently masked positions are sampled from the image-code slice of the
         logits, ``[vocab_offset, vocab_offset + codebook_size)``.
      2. The samples are written into the sequence.
      3. The ``keep_n`` lowest-confidence samples (chosen by ``mask_by_random_topk``)
         are re-masked.

    ``keep_n = max(1, floor(initial_masked * noise_schedule((step + 1) / timesteps)))``,
    and it is 0 on the last step.

    Args:
        model: LLaDAForMultiModalGeneration or a DDP-style wrapper exposing ``.module``.
        prompt: token ids of shape (1, P); only batch size 1 is supported.
        seq_len: number of image tokens expected in the output after newline removal.
        newline_every: accepted for interface compatibility; not used in this function.
        timesteps: maximum number of decoding steps (stops early once nothing is masked).
        mask_token_id: id marking positions to generate.
        newline_id: id removed from the image region before returning.
        temperature: forwarded to ``gumbel_max_sample`` and ``mask_by_random_topk``.
        cfg_scale: classifier-free guidance scale. When > 0 the logits are
            ``(1 + cfg_scale) * cond - cfg_scale * uncond``.
        uncon_ids: unconditional prefix (shape (1, U)); required when ``cfg_scale > 0``.
            The unconditional input is ``uncon_ids`` followed by ``x[:, code_start - 2:]``.
        code_start: index where the image-token region of the sequence begins; required
            when ``cfg_scale > 0``.
        codebook_size: number of VQ codes.
        noise_schedule: maps progress in (0, 1] to the fraction of initially masked
            tokens that stay masked after the step.
        text_vocab_size: offset of the image codes in the full vocabulary. If None, it is
            inferred as ``logits.size(-1) - codebook_size`` from a 1-token probe forward.
        generator: optional RNG passed to the sampling helpers.
        use_cache, cache_ratio, refresh_interval, warmup_ratio: model-side caching
            controls; see ``_build_refresh_steps`` for the refresh schedule.
        debug: when True, print detailed info each step.
        debug_log_dir: directory to save per-step npy dumps
            (x, vq_mask, logits, sampled_full).
        max_print_tokens: maximum number of tokens/logits to print for arrays
            (prevents terminal spam).

    Returns:
        LongTensor of shape (1, seq_len) holding ``x[0, code_start:-2]`` with newline
        tokens removed. Ids are in the full-vocabulary space: generated codes carry the
        ``vocab_offset`` shift.

    Raises:
        AssertionError: if the prompt batch size is not 1.
        ValueError: if ``cfg_scale > 0`` but ``uncon_ids`` or ``code_start`` is missing.
        RuntimeError: if the final image region does not hold exactly ``seq_len`` tokens.
    """
    if debug and debug_log_dir:
        os.makedirs(debug_log_dir, exist_ok=True)
    device = next(model.parameters()).device
    prompt = prompt.to(device)
    B, _ = prompt.shape
    assert B == 1, "batch>1 not supported – wrap in loop if needed"
    if cfg_scale > 0 and (uncon_ids is None or code_start is None):
        # Fail fast instead of an AttributeError/TypeError deep inside the loop.
        raise ValueError("cfg_scale > 0 requires both uncon_ids and code_start")

    x = prompt.clone()
    vq_mask = x == mask_token_id
    unknown_cnt = vq_mask.sum(dim=1, keepdim=True)
    vq_len = unknown_cnt  # initial masked count; the schedule is relative to it

    base_model = _unwrap_model(model)
    base_model.caching(use_cache)
    refresh_steps = _build_refresh_steps(timesteps, use_cache, warmup_ratio, refresh_interval)
    compute_ratio = 1 - cache_ratio

    # Infer text vocabulary size
    if text_vocab_size is None:
        # call with a minimal input to get logits size
        probe = torch.zeros(1, 1, dtype=torch.long, device=device)
        vocab_total = model(probe, infer=True).logits.size(-1)
        text_vocab_size = vocab_total - codebook_size
    vocab_offset = text_vocab_size

    if debug:
        print("=== generate_image debug start ===")
        print(f"device={device}, seq_len={seq_len}, code_start={code_start}, codebook_size={codebook_size}")
        print(f"text_vocab_size={text_vocab_size}, vocab_offset={vocab_offset}")
        print(f"Initial x.shape={x.shape}, initial unknown_cnt={int(unknown_cnt.item())}")
        print("==================================")

    for step in range(timesteps):
        if unknown_cnt.item() == 0:
            if debug:
                print(f"[step {step}] All tokens filled, breaking early.")
            break

        # Number of tokens to keep masked after this step (none on the final step).
        if step < timesteps - 1:
            frac = noise_schedule(torch.tensor([(step + 1) / timesteps], device=device))
            keep_n = (vq_len.float() * frac).floor().clamp_min(1).long()
        else:
            keep_n = torch.zeros_like(unknown_cnt)

        if use_cache and step and refresh_steps[step]:
            base_model.empty_cache()

        if debug:
            print(f"\n--- step {step} ---")
            print(f"unknown_cnt={int(unknown_cnt.item())}, keep_n={int(keep_n.item())}, refresh_step={bool(refresh_steps[step])}")
            print(f"x.shape={x.shape}, vq_mask.sum()={int(vq_mask.sum().item())}")
            # print a slice of tokens around code_start for visibility if code_start is set
            if code_start is not None:
                cs = code_start
                sample_slice = x[0, cs:cs+min(50, x.shape[1]-cs)].detach().cpu().numpy().tolist()
                print(f"x tokens at code_start (first 50): {sample_slice[:min(len(sample_slice), max_print_tokens)]}")

        # Forward pass (with/without CFG); logits end up as (B, num_masked, codebook_size).
        if cfg_scale > 0:
            uncond = torch.cat((uncon_ids.to(x.device), x[:, code_start - 2:]), dim=1)
            uncond_vq_mask = torch.cat(
                (
                    torch.zeros((1, uncon_ids.size(1)), dtype=torch.bool, device=x.device),
                    vq_mask[:, code_start - 2:],
                ),
                dim=1,
            )
            cond_out = model(x, infer=True, use_cache=use_cache)
            cond_logits = cond_out.logits[..., vocab_offset : vocab_offset + codebook_size]
            if debug:
                print(f"cond_logits shape: {cond_logits.shape}")
            cond_mask_logits = cond_logits[vq_mask].view(B, -1, codebook_size)

            uncond_out = model(uncond, infer=True, use_cache=use_cache)
            uncond_logits = uncond_out.logits[..., vocab_offset : vocab_offset + codebook_size]
            if debug:
                print(f"uncond_logits shape: {uncond_logits.shape}")
            uncond_mask_logits = uncond_logits[uncond_vq_mask].view(B, -1, codebook_size)

            logits = (1 + cfg_scale) * cond_mask_logits - cfg_scale * uncond_mask_logits
            if debug:
                print(f"combined logits shape: {logits.shape}")
        else:
            # NOTE(review): unlike the CFG branch, use_cache is not forwarded here — confirm intended.
            out = model(x, infer=True)
            # index masked positions directly by boolean mask along the sequence dim
            logits = out.logits[:, vq_mask[0], vocab_offset : vocab_offset + codebook_size]
            if debug:
                print(f"logits shape (no-cfg): {logits.shape}")
                ltmp = _to_float_numpy(logits)
                if ltmp.size:
                    print("logits[first_pos, first_10]:", ltmp[0, :min(10, ltmp.shape[1])].tolist() if ltmp.ndim == 2 else ltmp.reshape(-1, ltmp.shape[-1])[0, :min(10, ltmp.shape[-1])].tolist())

        # sample
        sampled = gumbel_max_sample(logits, temperature, generator=generator)
        sampled_full = sampled + vocab_offset  # bring to full token space
        probs = torch.softmax(logits, dim=-1)
        conf = probs.gather(-1, sampled.unsqueeze(-1)).squeeze(-1)
        if debug:
            print(f"sampled.shape={sampled.shape}, sampled_full.shape={sampled_full.shape}, conf.shape={conf.shape}")
            sf_np = sampled_full.detach().cpu().numpy().reshape(-1).tolist()
            print(f"sampled_full(first {min(len(sf_np), max_print_tokens)}): {sf_np[:min(len(sf_np), max_print_tokens)]}")

        # Column indices of masked positions; usable as flat indices into x because B == 1.
        flat_idx = vq_mask.nonzero(as_tuple=False)[:, 1]
        if debug:
            print(f"flat_idx (masked positions indices) length={flat_idx.shape[0]}")
            if flat_idx.numel() > 0:
                print(f"flat_idx first 30: {flat_idx[:min(30, flat_idx.shape[0])].detach().cpu().numpy().tolist()}")
        x.view(-1)[flat_idx] = sampled_full.view(-1)

        if debug:
            conf_np = _to_float_numpy(conf).reshape(-1)
            if conf_np.size:
                print(f"conf stats (min/mean/max): {float(conf_np.min()):.6f}/{float(conf_np.mean()):.6f}/{float(conf_np.max()):.6f}")

        # mask selection -> re-mask some tokens for next step
        mask_sel = mask_by_random_topk(keep_n.squeeze(1), conf, temperature=temperature, generator=generator)
        if debug:
            print(f"mask_sel.shape={mask_sel.shape}, mask_sel.sum()={int(mask_sel.sum().item())}")
        x.view(-1)[flat_idx[mask_sel.view(-1)]] = mask_token_id
        vq_mask = x == mask_token_id
        unknown_cnt = vq_mask.sum(dim=1, keepdim=True)
        if debug:
            print(f"after masking, vq_mask.sum()={int(vq_mask.sum().item())}, unknown_cnt={int(unknown_cnt.item())}")

        if debug and debug_log_dir:
            _save_debug_step(debug_log_dir, step, x, vq_mask, logits, sampled_full)

        # Update cond/uncond compute masks for caching only if cfg_scale>0
        if use_cache and step < timesteps - 1 and not refresh_steps[step + 1] and cfg_scale > 0:
            # NOTE(review): the *_to_compute_mask tensors are never passed to the model in
            # this function; only the debug prints below read the thresholds. Confirm
            # whether model(...) was meant to receive them.
            cond_conf = cond_logits.max(dim=-1)[0]
            cond_conf_threshold = torch.quantile(cond_conf.to(torch.float), compute_ratio, dim=-1, keepdim=True)
            cond_to_compute_mask = cond_conf <= cond_conf_threshold
            uncond_conf = uncond_logits.max(dim=-1)[0]
            uncond_conf_threshold = torch.quantile(uncond_conf.to(torch.float), compute_ratio, dim=-1, keepdim=True)
            uncond_to_compute_mask = uncond_conf <= uncond_conf_threshold
            if debug:
                print(f"cond_conf shape: {cond_conf.shape}, threshold: {cond_conf_threshold.detach().cpu().numpy().tolist()}")
                print(f"uncond_conf shape: {uncond_conf.shape}, threshold: {uncond_conf_threshold.detach().cpu().numpy().tolist()}")

    # Take the image region (dropping the last two positions — presumably end markers,
    # confirm) and remove newline tokens, leaving the flat VQ grid.
    vq_ids = x[0, code_start:-2]
    vq_ids = vq_ids[vq_ids != newline_id]
    if vq_ids.numel() != seq_len:
        raise RuntimeError(
            f"expected {seq_len} image tokens after removing newlines, got {vq_ids.numel()}"
        )
    vq_ids = vq_ids.view(1, seq_len)

    if debug:
        print("=== generate_image debug end ===")
        print(f"final vq_ids.shape={vq_ids.shape}")
        print("final vq_ids first 100:", vq_ids.detach().cpu().numpy().reshape(-1)[:min(max_print_tokens, vq_ids.numel())].tolist())
    return vq_ids