import torch import torch.nn.functional as F from tqdm import tqdm import math import numpy as np def add_gumbel_noise(logits, temperature=1.0, generator=None): """Add Gumbel noise to logits for sampling""" if temperature == 0: return logits if generator is not None: uniform_noise = torch.rand(logits.shape, dtype=logits.dtype, device=logits.device, generator=generator) else: uniform_noise = torch.rand_like(logits) gumbel_noise = -torch.log(-torch.log(uniform_noise + 1e-10) + 1e-10) return logits + temperature * gumbel_noise def mask_by_random_topk(mask_len, probs, temperature=1.0, generator=None): """ Mask tokens by random top-k selection based on confidence probs: [batch, L] confidence scores (higher = more confident) mask_len: tensor shape [batch, 1] or scalar, number of tokens to keep masked (lowest-confidence) returns: boolean mask [batch, L] True where token should REMAIN masked """ if generator is not None: noise = torch.randn(probs.shape, dtype=probs.dtype, device=probs.device, generator=generator) else: noise = torch.randn_like(probs) # Add small noise to jitter confidences according to temperature confidence = torch.log(probs + 1e-10) + temperature * noise # higher = more confident # We want to mask lowest-confidence tokens -> find cutoff sorted_confidence, sorted_indices = torch.sort(confidence, dim=-1, descending=False) # ascending # mask_len may be float or tensor; ensure integer per-batch if isinstance(mask_len, torch.Tensor): mask_len_clamped = torch.clamp(mask_len, 0, probs.shape[-1] - 1) mask_len_clamped = mask_len_clamped.long().squeeze(-1) # shape [batch] else: mask_len_clamped = int(mask_len) # Build boolean mask: True for tokens to KEEP masked (lowest confidence) if isinstance(mask_len_clamped, torch.Tensor): batch = probs.shape[0] masking = torch.zeros_like(probs, dtype=torch.bool, device=probs.device) for b in range(batch): k = mask_len_clamped[b].item() if k <= 0: continue low_idx = sorted_indices[b, :k] # indices of lowest k confidences masking[b, low_idx] = True else: # scalar k k = mask_len_clamped if k <= 0: masking = torch.zeros_like(probs, dtype=torch.bool, device=probs.device) else: low_idx = sorted_indices[:, :k] masking = torch.zeros_like(probs, dtype=torch.bool, device=probs.device) batch = probs.shape[0] for b in range(batch): masking[b, low_idx[b]] = True return masking def cosine_schedule(t): """Cosine noise schedule""" return torch.cos(t * math.pi / 2) def get_num_transfer_tokens(text_masked_indices, text_steps): """ Calculate number of tokens to unmask at each step Returns: [batch_size, text_steps] """ batch_size = text_masked_indices.shape[0] initial_masks = text_masked_indices.sum(dim=1) # [batch_size] num_transfer = torch.zeros(batch_size, text_steps, dtype=torch.long, device=text_masked_indices.device) for b in range(batch_size): total_masks = initial_masks[b].item() remaining = total_masks for step in range(text_steps): ratio = (step + 1) / text_steps target_remaining = int(total_masks * (1 - ratio)) tokens_to_unmask = max(0, remaining - target_remaining) num_transfer[b, step] = tokens_to_unmask remaining -= tokens_to_unmask return num_transfer def generate_ti2ti( model, input_ids, text_start, text_end, image_start, seq_len, newline_every, text_steps=100, text_gen_length=256, text_block_length=64, timesteps=100, temperature=1.0, text_temperature=0.7, cfg_scale=0.0, cfg_img=4.0, uncon_text=None, uncon_image=None, tokenizer=None, remasking='low_confidence', noise_schedule=cosine_schedule, generator=None, text_vocab_size=126356, codebook_size=8192, ): """ Generate text and image jointly with interleaved generation. Text generation uses cond logits only (text_cfg assumed 0). Image generation (at scheduled steps) uses two CFGs: - uncond_text (if provided) : guidance that relates to text part - uncond_image (if provided): guidance that relates to image part """ device = input_ids.device MASK_TOKEN = 126336 NEW_LINE = 126084 # Clone input for modification combined_input_ids = input_ids.clone() # Calculate total image region length (including newlines) num_vq_tokens = seq_len total_image_len = seq_len + seq_len // newline_every image_end = image_start + total_image_len print(f"Interleaved generation: {text_steps} total steps") print(f" - Text generation range: [{text_start}, {text_end})") print(f" - Image generation range: [{image_start}, {image_end}) (total {total_image_len} including newlines)") print(f" - VQ tokens: {num_vq_tokens}") # Calculate number of tokens to unmask at each step for text text_masked_indices = combined_input_ids[:, text_start:text_end] == MASK_TOKEN num_transfer_tokens = get_num_transfer_tokens(text_masked_indices, text_steps) # Schedule: when to perform image generation steps image_generation_step_indices = torch.linspace( text_steps // 4, text_steps - 1, timesteps ).round().int().tolist() print(f" - Image generation at steps: {image_generation_step_indices[:5]}...{image_generation_step_indices[-5:]}") # Build position mapping for image (excluding newlines) image_position_mapping = [] for i in range(image_start, image_end): if combined_input_ids[0, i] != NEW_LINE: image_position_mapping.append(i) assert len(image_position_mapping) == num_vq_tokens, f"Expected {num_vq_tokens} VQ tokens, got {len(image_position_mapping)}" batch_size = combined_input_ids.shape[0] # ========== Interleaved Generation Loop ========== for step in tqdm(range(text_steps), desc="Interleaved generation"): # ===== Forward pass: compute conditional logits once per step ===== with torch.no_grad(): cond_logits = model(combined_input_ids, infer=True, use_cache=False).logits # [B, L, V] # ===== Text Generation Step (no CFG for text; use cond_logits directly) ===== text_masked_indices = combined_input_ids[:, text_start:text_end] == MASK_TOKEN if text_masked_indices.sum() > 0: # Extract text logits from cond (no guidance) text_logits = cond_logits[:, text_start:text_end, :] # Apply temperature & gumbel logits_with_noise = add_gumbel_noise(text_logits, temperature=text_temperature, generator=generator) x0 = torch.argmax(logits_with_noise, dim=-1) # [B, text_len] # Compute confidence for remasking if remasking == 'low_confidence': p = F.softmax(text_logits.to(torch.float64), dim=-1) x0_p = torch.squeeze(torch.gather(p, dim=-1, index=torch.unsqueeze(x0, -1)), -1) # [B, text_len] elif remasking == 'random': if generator is not None: x0_p = torch.rand(x0.shape, dtype=x0.dtype, device=x0.device, generator=generator) else: x0_p = torch.rand((x0.shape[0], x0.shape[1]), device=x0.device) else: raise NotImplementedError(remasking) # keep already-unmasked tokens x0 = torch.where(text_masked_indices, x0, combined_input_ids[:, text_start:text_end]) confidence = torch.where(text_masked_indices, x0_p, -np.inf) # Select tokens to unmask based on confidence (top-k per batch element) transfer_index = torch.zeros_like(x0, dtype=torch.bool, device=x0.device) for j in range(confidence.shape[0]): k = num_transfer_tokens[j, step].item() if k > 0: _, select_index = torch.topk(confidence[j], k=k) transfer_index[j, select_index] = True # Unmask selected tokens into combined_input_ids # Note: transfer_index is [B, text_len] boolean; place into full combined_input_ids combined_input_ids[:, text_start:text_end][transfer_index] = x0[transfer_index] # ===== Image Generation Step (scheduled) ===== if step in image_generation_step_indices: # Build vq token list from current combined_input_ids (placeholder -1 for masked) vq_tokens_list = [] for pos in image_position_mapping: token = combined_input_ids[0, pos].item() if token == MASK_TOKEN: vq_tokens_list.append(-1) else: vq_token = token - text_vocab_size vq_token = max(0, min(vq_token, codebook_size - 1)) vq_tokens_list.append(vq_token) vq_tokens_tensor = torch.tensor(vq_tokens_list, device=device).unsqueeze(0) # [1, num_vq_tokens] unknown_map = vq_tokens_tensor == -1 # True where masked # Extract cond_vq_logits from cond_logits (for VQ positions and vocab offset) cond_image_logits_list = [] for pos in image_position_mapping: cond_image_logits_list.append(cond_logits[:, pos:pos+1, text_vocab_size:text_vocab_size+codebook_size]) cond_vq_logits = torch.cat(cond_image_logits_list, dim=1) # [B, num_vq_tokens, codebook_size] # Prepare uncond logits only when needed (for image CFG) # Create combined_uncond_text and combined_uncond_img by replacing prefix with uncon_text/uncon_image if (cfg_scale > 0.0 and uncon_text is not None) or (cfg_img > 0.0 and uncon_image is not None): # clone base input # IMPORTANT: uncon_text/uncon_image expected to be on the same device or will be moved # If uncon_text / uncon_image is None, create copies to avoid errors if uncon_text is None: combined_uncond_text = combined_input_ids.clone() else: combined_uncond_text = combined_input_ids.clone() prefix_len = uncon_text.shape[1] combined_uncond_text[:, :prefix_len] = uncon_text.to(device) if uncon_image is None: combined_uncond_img = combined_input_ids.clone() else: combined_uncond_img = combined_input_ids.clone() prefix_len_img = uncon_image.shape[1] combined_uncond_img[:, :prefix_len_img] = uncon_image.to(device) # Forward for unconds with torch.no_grad(): uncond_text_logits_full = model(combined_uncond_text, infer=True, use_cache=False).logits uncond_img_logits_full = model(combined_uncond_img, infer=True, use_cache=False).logits # Extract VQ ranges for each image position uncond_text_vq_list = [] uncond_img_vq_list = [] for pos in image_position_mapping: uncond_text_vq_list.append(uncond_text_logits_full[:, pos:pos+1, text_vocab_size:text_vocab_size+codebook_size]) uncond_img_vq_list.append(uncond_img_logits_full[:, pos:pos+1, text_vocab_size:text_vocab_size+codebook_size]) uncond_text_vq_logits = torch.cat(uncond_text_vq_list, dim=1) # [B, num_vq_tokens, codebook_size] uncond_img_vq_logits = torch.cat(uncond_img_vq_list, dim=1) # [B, num_vq_tokens, codebook_size] else: # no unconds provided or scales are zero -> set uncond logits to zeros so (cond - 0) works if used uncond_text_vq_logits = torch.zeros_like(cond_vq_logits) uncond_img_vq_logits = torch.zeros_like(cond_vq_logits) # Compose guided image logits: # image_logits = cond_vq + cfg_scale * (cond_vq - uncond_text_vq) + cfg_img * (cond_vq - uncond_img_vq) if cfg_scale == 0.0 and cfg_img == 0.0: image_logits = cond_vq_logits else: image_logits = cond_vq_logits if cfg_scale != 0.0: image_logits = image_logits + cfg_scale * (cond_vq_logits - uncond_text_vq_logits) if cfg_img != 0.0: image_logits = image_logits + cfg_img * (cond_vq_logits - uncond_img_vq_logits) # Sample from image_logits probs = F.softmax(image_logits, dim=-1) # [B, num_vq, codebook] if temperature == 0: sampled_ids = probs.argmax(dim=-1) else: # flatten batch*num_vq x vocab for multinomial sampled = probs.reshape(-1, image_logits.size(-1)) if generator is not None: sampled_ids = torch.multinomial(sampled, 1, generator=generator)[:, 0].view(*image_logits.shape[:-1]) else: sampled_ids = torch.multinomial(sampled, 1)[:, 0].view(*image_logits.shape[:-1]) # Keep already-unmasked tokens unchanged sampled_ids = torch.where(unknown_map, sampled_ids, vq_tokens_tensor) # Clamp safety sampled_ids = torch.clamp(sampled_ids, 0, codebook_size - 1) # Confidence for sampled tokens selected_probs = torch.gather(probs, -1, sampled_ids.long()[..., None]).squeeze(-1) # [B, num_vq] # If token was previously unmasked, give it very high confidence so we don't remask it high_val = torch.finfo(selected_probs.dtype).max selected_probs = torch.where(unknown_map, selected_probs, high_val) # Masking ratio and mask_len calculation ratio = 1.0 * (step + 1) / text_steps mask_ratio = noise_schedule(torch.tensor(ratio, device=device)) # compute how many tokens to keep masked (lowest confidences) unknown_counts = unknown_map.sum(dim=-1, keepdim=True) # [B,1] mask_len = (num_vq_tokens * mask_ratio).floor().unsqueeze(0).to(device) # shape [1,] maybe # clamp mask_len to [1, unknown_counts-1] mask_len = torch.max(torch.tensor([1], device=device), torch.min(unknown_counts - 1, mask_len.to(device).long())) # ensure shape [B,1] if mask_len.ndim == 1: mask_len = mask_len.unsqueeze(1) # temperature decay for image sampling (optional) img_temp = temperature * (1.0 - ratio) # masking boolean: True where should remain masked masking = mask_by_random_topk(mask_len, selected_probs, img_temp, generator=generator) # final_vq_tokens: -1 means remain masked, else sampled id final_vq_tokens = torch.where(masking, torch.tensor(-1, device=device), sampled_ids) # Write back into combined_input_ids (convert vq id -> full vocab id by adding offset) for idx, pos in enumerate(image_position_mapping): v = final_vq_tokens[0, idx].item() if v == -1: combined_input_ids[0, pos] = MASK_TOKEN else: combined_input_ids[0, pos] = int(v + text_vocab_size) # ===== Extract final results ===== # Extract text tokens text_tokens = combined_input_ids[0, text_start:text_end].cpu().tolist() text_tokens = [t for t in text_tokens if t != MASK_TOKEN] generated_text = tokenizer.decode(text_tokens, skip_special_tokens=True) if tokenizer is not None else text_tokens # Extract image VQ tokens image_tokens = [] for pos in image_position_mapping: token = combined_input_ids[0, pos].item() if token != MASK_TOKEN: vq_token = token - text_vocab_size vq_token = max(0, min(vq_token, codebook_size - 1)) image_tokens.append(vq_token) else: # still masked -> sample randomly image_tokens.append(int(torch.randint(0, codebook_size, (1,)).item())) print(f"Interleaved generation complete.") print(f" - Generated text: {len(text_tokens)} tokens") print(f" - Generated image: {len(image_tokens)} VQ tokens (range [0, {codebook_size}))") return image_tokens, generated_text