# -*- coding: utf-8 -*-
"""Image generation via MaskGIT-style parallel decoding (with optional debug prints/saving)."""
import torch
import math
import os
import numpy as np
from typing import Callable, Optional
from utils.generation_utils import cosine_schedule, gumbel_max_sample, mask_by_random_topk
from model import LLaDAForMultiModalGeneration


@torch.no_grad()
def generate_image(
    model,
    prompt: torch.LongTensor,
    *,
    seq_len: int = 1024,
    newline_every: int = 16,
    timesteps: int = 18,
    mask_token_id: int = 126336,
    newline_id: int = 126084,
    temperature: float = 1.0,
    cfg_scale: float = 0.0,
    uncon_ids: Optional[torch.LongTensor] = None,
    code_start: Optional[int] = None,
    codebook_size: int = 8192,
    noise_schedule: Callable[[torch.Tensor], torch.Tensor] = cosine_schedule,
    text_vocab_size: Optional[int] = None,
    generator: Optional[torch.Generator] = None,
    use_cache: bool = False,
    cache_ratio: float = 0.9,
    refresh_interval: int = 5,
    warmup_ratio: float = 0.3,
    debug: bool = True,
    debug_log_dir: Optional[str] = None,
    max_print_tokens: int = 100,
) -> torch.LongTensor:
    """

    MaskGit parallel decoding to generate VQ tokens



    Added debug=True to print shapes and token samples per step. Optional debug_log_dir to save numpy dumps.



    Args:

        debug: when True, print detailed info each step.

        debug_log_dir: directory to save per-step npy dumps (x, vq_mask, logits, sampled_full)

        max_print_tokens: maximum number of tokens/logits to print for arrays (prevents terminal spam)

    """

    if debug and debug_log_dir:
        os.makedirs(debug_log_dir, exist_ok=True)

    device = next(model.parameters()).device
    prompt = prompt.to(device)
    B, P = prompt.shape
    assert B == 1, "batch>1 not supported – wrap in loop if needed"

    x = prompt.clone()

    vq_mask = x == mask_token_id
    unknown_cnt = vq_mask.sum(dim=1, keepdim=True)
    vq_len = unknown_cnt

    if isinstance(model, LLaDAForMultiModalGeneration):
        model.caching(use_cache)
    else:  # DDP
        model.module.caching(use_cache)

    warmup_step = int(timesteps * warmup_ratio)
    refresh_steps = torch.zeros(timesteps, dtype=torch.bool)
    for step in range(timesteps):
        if not use_cache or step <= warmup_step or (step - warmup_step) % refresh_interval == 0:
            refresh_steps[step] = True
    compute_ratio = 1 - cache_ratio
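    # Worked example (with use_cache=True): timesteps=18, warmup_ratio=0.3,
    # refresh_interval=5 give warmup_step=5, so steps 0-5, 10, and 15 are
    # marked as full-refresh steps; compute_ratio = 1 - cache_ratio = 0.1 is
    # the fraction of lowest-confidence positions flagged for recomputation
    # on cached steps (see the quantile masks near the end of the loop).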

    # Infer text vocabulary size
    if text_vocab_size is None:
        # call with a minimal input to get logits size
        vocab_total = model(torch.zeros(1, 1, dtype=torch.long, device=device), infer=True).logits.size(-1)
        text_vocab_size = vocab_total - codebook_size
    vocab_offset = text_vocab_size
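    # The logits cover the text vocabulary first, then the VQ codebook, so VQ
    # code k sits at full-vocab id text_vocab_size + k; all image sampling
    # below slices logits[..., vocab_offset : vocab_offset + codebook_size]
    # and shifts samples back up by vocab_offset.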

    if debug:
        print("=== generate_image debug start ===")
        print(f"device={device}, seq_len={seq_len}, code_start={code_start}, codebook_size={codebook_size}")
        print(f"text_vocab_size={text_vocab_size}, vocab_offset={vocab_offset}")
        print(f"Initial x.shape={x.shape}, initial unknown_cnt={int(unknown_cnt.item())}")
        print("==================================")

    for step in range(timesteps):
        if unknown_cnt.item() == 0:
            if debug:
                print(f"[step {step}] All tokens filled, breaking early.")
            break

        # Calculate number of tokens to keep (continue masking) this round
        if step < timesteps - 1:
            frac = noise_schedule(torch.tensor([(step + 1) / timesteps], device=device))
            keep_n = (vq_len.float() * frac).floor().clamp_min(1).long()
        else:
            keep_n = torch.zeros_like(unknown_cnt)
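        # Worked example (assuming cosine_schedule(t) = cos(pi*t/2), the usual
        # MaskGIT choice): with vq_len = 1024 and timesteps = 18, step 0 keeps
        # floor(1024*cos(pi/36)) = 1020 tokens masked, the penultimate step
        # keeps floor(1024*cos(17*pi/36)) = 89, and the final step keeps none.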

        if use_cache and step > 0 and refresh_steps[step]:
            if isinstance(model, LLaDAForMultiModalGeneration):
                model.empty_cache()
            else:  # DDP
                model.module.empty_cache()

        if debug:
            print(f"\n--- step {step} ---")
            print(f"unknown_cnt={int(unknown_cnt.item())}, keep_n={int(keep_n.item())}, refresh_step={bool(refresh_steps[step])}")
            print(f"x.shape={x.shape}, vq_mask.sum()={int(vq_mask.sum().item())}")
            # print a slice of tokens around code_start for visibility if code_start is set
            if code_start is not None:
                cs = code_start
                sample_slice = x[0, cs:cs+min(50, x.shape[1]-cs)].detach().cpu().numpy().tolist()
                print(f"x tokens at code_start (first 50): {sample_slice[:min(len(sample_slice), max_print_tokens)]}")
        
        # Forward pass (with/without CFG)
        if cfg_scale > 0:
            assert uncon_ids is not None and code_start is not None, "cfg_scale > 0 requires uncon_ids and code_start"
            # build uncond sequence: unconditional prompt + shared image region
            uncond = torch.cat((uncon_ids.to(x.device), x[:, code_start - 2:]), dim=1)
            uncond_vq_mask = torch.cat(
                (torch.zeros((1, uncon_ids.size(1)), dtype=torch.bool, device=x.device), vq_mask[:, code_start - 2:]), dim=1)

            # conditional logits
            cond_out = model(x, infer=True, use_cache=use_cache)
            cond_logits = cond_out.logits[..., vocab_offset : vocab_offset + codebook_size]
            if debug:
                print(f"cond_logits shape: {cond_logits.shape}")
            cond_mask_logits = cond_logits[vq_mask].view(B, -1, codebook_size)
            """

            if debug:

                print(f"cond_mask_logits shape (after vq_mask): {tuple(cond_mask_logits.shape)}")

                # print few values

                tmp = cond_mask_logits.detach().cpu().numpy()

                flat_tmp = tmp.reshape(-1, tmp.shape[-1])

                if flat_tmp.shape[0] > 0:

                    print("cond_mask_logits[first_row, first_10]:", flat_tmp[0, :min(10, flat_tmp.shape[1])].tolist())

"""
            # unconditional logits
            uncond_out = model(uncond, infer=True, use_cache=use_cache)
            uncond_logits = uncond_out.logits[..., vocab_offset : vocab_offset + codebook_size]
            if debug:
                print(f"uncond_logits shape: {uncond_logits.shape}")
            uncond_mask_logits = uncond_logits[uncond_vq_mask].view(B, -1, codebook_size)
            """

            if debug:

                print(f"uncond_mask_logits shape: {tuple(uncond_mask_logits.shape)}")

                tmpu = uncond_mask_logits.detach().cpu().numpy()

                if tmpu.size:

                    print("uncond_mask_logits[first_row, first_10]:", tmpu.reshape(-1, tmpu.shape[-1])[0, :min(10, tmpu.shape[-1])].tolist())

"""
            logits = (1 + cfg_scale) * cond_mask_logits - cfg_scale * uncond_mask_logits
            if debug:
                print(f"combined logits shape: {logits.shape}")

        else:
            out = model(x, infer=True)
            # logits for masked positions: (B, num_masked, codebook_size)
            # here we index directly by boolean mask along sequence dim
            logits = out.logits[:, vq_mask[0], vocab_offset : vocab_offset + codebook_size]
            if debug:
                print(f"logits shape (no-cfg): {logits.shape}")
                ltmp = logits.detach().cpu().numpy()
                if ltmp.size:
                    flat = ltmp.reshape(-1, ltmp.shape[-1])
                    print("logits[first_pos, first_10]:", flat[0, :min(10, flat.shape[-1])].tolist())

        # sample
        sampled = gumbel_max_sample(logits, temperature, generator=generator)
        sampled_full = sampled + vocab_offset  # bring to full token space
        probs = torch.softmax(logits, dim=-1)
        conf = probs.gather(-1, sampled.unsqueeze(-1)).squeeze(-1)
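        # conf is each sampled token's softmax probability and drives the
        # low-confidence re-masking below.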

        if debug:
            print(f"sampled.shape={sampled.shape}, sampled_full.shape={sampled_full.shape}, conf.shape={conf.shape}")
            # print some sampled tokens
            sf_np = sampled_full.detach().cpu().numpy().reshape(-1).tolist()
            print(f"sampled_full(first {min(len(sf_np), max_print_tokens)}): {sf_np[:min(len(sf_np), max_print_tokens)]}")

        # write sampled tokens into x at masked positions
        flat_idx = vq_mask.nonzero(as_tuple=False)[:, 1]
        if debug:
            print(f"flat_idx (masked positions indices) length={flat_idx.shape[0]}")
            if flat_idx.numel() > 0:
                print(f"flat_idx first 30: {flat_idx[:min(30, flat_idx.shape[0])].detach().cpu().numpy().tolist()}")

        x.view(-1)[flat_idx] = sampled_full.view(-1)

        # confidence map over the full sequence (inspection only; the
        # selection below uses conf directly)
        conf_map = torch.full_like(x, -math.inf, dtype=probs.dtype)
        conf_map.view(-1)[flat_idx] = conf.view(-1)

        if debug:
            # show some stats of conf_map in code region
            try:
                conf_np = conf.detach().cpu().numpy().reshape(-1)
                print(f"conf stats (min/mean/max): {float(conf_np.min()):.6f}/{float(conf_np.mean()):.6f}/{float(conf_np.max()):.6f}")
            except Exception:
                pass

        # mask selection -> re-mask the keep_n lowest-confidence tokens (with
        # noise when temperature > 0) so they get re-predicted next step
        mask_sel = mask_by_random_topk(keep_n.squeeze(1), conf, temperature=temperature, generator=generator)
        if debug:
            print(f"mask_sel.shape={mask_sel.shape}, mask_sel.sum()={int(mask_sel.sum().item())}")
        x.view(-1)[flat_idx[mask_sel.view(-1)]] = mask_token_id
        vq_mask = x == mask_token_id
        unknown_cnt = vq_mask.sum(dim=1, keepdim=True)

        if debug:
            print(f"after masking, vq_mask.sum()={int(vq_mask.sum().item())}, unknown_cnt={int(unknown_cnt.item())}")

        # Save debug artifacts if requested
        if debug and debug_log_dir:
            step_base = os.path.join(debug_log_dir, f"step_{step}")
            try:
                np.save(step_base + "_x.npy", x.detach().cpu().numpy())
                np.save(step_base + "_vq_mask.npy", vq_mask.detach().cpu().numpy())
                # logits may be large; save as float32
                np.save(step_base + "_logits.npy", logits.detach().cpu().numpy().astype(np.float32))
                np.save(step_base + "_sampled_full.npy", sampled_full.detach().cpu().numpy())
            except Exception as e:
                print(f"[debug] failed to save debug npy at step {step}: {e}")

        # Update cond/uncond compute masks for caching only if cfg_scale>0
        if use_cache and step < timesteps - 1 and not refresh_steps[step+1] and cfg_scale > 0:
            cond_conf = cond_logits.max(dim=-1)[0]
            cond_conf_threshold = torch.quantile(cond_conf.to(torch.float), compute_ratio, dim=-1, keepdim=True)
            cond_to_compute_mask = cond_conf <= cond_conf_threshold

            uncond_conf = uncond_logits.max(dim=-1)[0]
            uncond_conf_threshold = torch.quantile(uncond_conf.to(torch.float), compute_ratio, dim=-1, keepdim=True)
            uncond_to_compute_mask = uncond_conf <= uncond_conf_threshold

            if debug:
                print(f"cond_conf shape: {cond_conf.shape}, threshold: {cond_conf_threshold.detach().cpu().numpy().tolist()}")
                print(f"uncond_conf shape: {uncond_conf.shape}, threshold: {uncond_conf_threshold.detach().cpu().numpy().tolist()}")

    # Remove newline tokens and shape properly
    vq_ids = x[0, code_start:-2]
    vq_ids = vq_ids[vq_ids != newline_id].view(1, seq_len)
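    # With the defaults (seq_len = 1024, newline_every = 16) the code region
    # carries one newline_id per 16 VQ tokens, so stripping them leaves exactly
    # seq_len entries; the .view(1, seq_len) doubles as a sanity check.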

    if debug:
        print("=== generate_image debug end ===")
        print(f"final vq_ids.shape={vq_ids.shape}")
        try:
            print("final vq_ids first 100:", vq_ids.detach().cpu().numpy().reshape(-1)[:min(max_print_tokens, vq_ids.numel())].tolist())
        except Exception:
            pass

    return vq_ids
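
# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative: the checkpoint path, prompt layout, and
# cfg_scale value are assumptions, not this repo's exact pipeline):
#
#   model = LLaDAForMultiModalGeneration.from_pretrained("path/to/ckpt").cuda().eval()
#   # prompt: text ids followed by the image block, i.e. seq_len mask_token_id
#   # slots with a newline_id after every newline_every tokens
#   g = torch.Generator(device="cuda").manual_seed(0)
#   vq_ids = generate_image(
#       model,
#       prompt,
#       cfg_scale=3.0,             # illustrative guidance strength
#       uncon_ids=uncond_prompt,   # ids of an empty/negative prompt
#       code_start=code_start,     # index where the image block starts
#       generator=g,
#       debug=False,
#   )
#   # vq_ids: (1, seq_len) ids in the full vocab; subtract text_vocab_size
#   # to recover raw codebook indices for the VQ decoder.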