# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved.
#
# NVIDIA CORPORATION and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto. Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION is strictly prohibited.

import torch


def find_prefix_seq_length_by_pe(
    pe: torch.Tensor
) -> torch.Tensor:
    """Find, per row, the index at which the position encoding first drops.

    A drop (``pe[b, i] < pe[b, i - 1]``) is taken as the prefix boundary.

    Args:
        pe: Position-index tensor of shape [batch_size, seq_len].

    Returns:
        torch.Tensor: int64 tensor of shape [batch_size], on ``pe.device``,
        holding for each row the index ``i`` of the first drop, or ``-1`` if
        the row never drops.
    """
    batch_size, seq_len = pe.shape
    # Allocate on pe's device so the result can be combined with other tensors
    # on that device (e.g. passed as x0_len_list alongside CUDA position ids).
    no_drop = torch.full((batch_size,), -1, dtype=torch.long, device=pe.device)
    if seq_len < 2:
        # Fewer than two positions: there is no adjacent pair that could drop.
        return no_drop

    drop_mask = pe[:, 1:] < pe[:, :-1]  # [batch_size, seq_len - 1]
    has_drop = drop_mask.any(dim=1)
    # argmax returns the first maximal index, i.e. the first True entry.
    # +1 maps the index in the shifted comparison back to an index into ``pe``.
    first_drop = drop_mask.int().argmax(dim=1) + 1
    return torch.where(has_drop, first_drop, no_drop)


def update_causal_mask_with_pad_non_visible_2d(
    input_ids: torch.Tensor,
    attn_mask_2d: torch.Tensor,
    text_mask_token_id: int,
    block_size: int = 4,
    causal_attn: bool = False
) -> torch.Tensor:
    """Update a 2D attention mask over the whole sequence around mask-token runs.

    A "masked column" is any position holding ``text_mask_token_id`` or the
    position immediately before such a token. For every masked column ``j``:

    * rows at or after the first non-masked position following ``j`` are set
      to ``-inf`` (later tokens cannot attend to the masked run);
    * if ``causal_attn`` is False, rows ``i`` with ``prev_non_mask[j] < i < j``
      (earlier positions of the same run) are set to ``0.0``, giving
      bidirectional attention inside the run.

    Args:
        input_ids: 1-D token-id tensor of shape [seq_len]; used to locate mask
            tokens.
        attn_mask_2d: Float mask of shape [seq_len, seq_len] where ``0.0``
            allows attention and ``-inf`` blocks it. Modified in place.
        text_mask_token_id: Token id that marks masked positions.
        block_size: Not referenced by the current implementation; kept for
            signature compatibility.
        causal_attn: If True, only the ``-inf`` step is applied and no future
            positions are opened up.

    Returns:
        torch.Tensor: ``attn_mask_2d`` itself, after in-place modification.
    """
    seq_len = input_ids.shape[0]
    device = input_ids.device

    # Masked columns: mask tokens plus the token right before each mask token.
    input_mask = input_ids.eq(text_mask_token_id)
    input_before_mask = torch.zeros_like(input_mask)
    input_before_mask[:-1] = input_mask[1:]
    mask_cols = input_mask | input_before_mask
    non_mask = ~mask_cols

    rows = torch.arange(seq_len, device=device)[:, None]
    cols = torch.arange(seq_len, device=device)
    indices = torch.arange(seq_len, device=device)

    # prev_non_mask[j]: largest non-masked index <= j (0 when none exists).
    # NOTE(review): 0 is also produced when index 0 itself is masked, so a run
    # starting at index 0 is handled slightly differently - confirm intent.
    prev_non_mask = (indices * non_mask).cummax(dim=0).values

    # next_non_mask[j]: smallest non-masked index >= j (int64 max when none),
    # computed as a reversed cumulative minimum.
    max_value = torch.iinfo(indices.dtype).max
    mask_indices = torch.where(non_mask, indices, torch.full_like(indices, max_value))
    reversed_cummin = torch.flip(mask_indices, dims=[0]).cummin(dim=0).values
    next_non_mask = torch.flip(reversed_cummin, dims=[0])

    # Hide each masked column from rows at/after the run's closing non-masked token.
    # (cols > prev_non_mask) is a per-column [seq_len] condition broadcast over rows.
    infra_mask = (
        (cols > prev_non_mask)
        & (rows >= next_non_mask[None, :])
        & mask_cols[None, :]
    )
    attn_mask_2d.masked_fill_(infra_mask, -float('inf'))

    if not causal_attn:
        # Let earlier positions of the same run see later masked columns.
        visible_mask = (
            (rows > prev_non_mask[None, :])
            & (rows < cols)
            & mask_cols[None, :]
        )
        attn_mask_2d.masked_fill_(visible_mask, 0.0)

    return attn_mask_2d


def update_causal_mask_for_one_gen_window_2d(
    input_ids: torch.Tensor,
    attn_mask_2d: torch.Tensor,
    block_size: int = 4,
    use_cache: bool = True,
    causal_attn: bool = False
) -> torch.Tensor:
    """Open up the trailing generation window of a 2D attention mask.

    Args:
        input_ids: Not referenced by the current implementation.
        attn_mask_2d: Float mask of shape [q_len, kv_len] where ``0.0`` allows
            attention and ``-inf`` blocks it. Modified in place.
        block_size: Size of the trailing window (last ``block_size`` rows and
            columns). Must be positive when ``causal_attn`` is False.
        use_cache: If True, the key column just before the window is masked
            for the window's rows. The original author's note: this prevents
            recomputing the previous round's last token and keeps generation
            consistent.
        causal_attn: If True, the mask is returned unchanged.

    Returns:
        torch.Tensor: ``attn_mask_2d`` itself, after in-place modification.

    Raises:
        ValueError: If ``causal_attn`` is False and ``block_size`` <= 0
            (``[-0:]`` would otherwise silently select the whole matrix).
    """
    if causal_attn:
        # Strict causal masking is preserved: nothing to change.
        return attn_mask_2d
    if block_size <= 0:
        raise ValueError(f"block_size must be positive, got {block_size}")

    # Bidirectional attention inside the trailing window.
    attn_mask_2d[-block_size:, -block_size:] = 0.0

    # Only mask the preceding column if one exists; otherwise the negative index
    # would be out of range (IndexError when kv_len <= block_size).
    if use_cache and attn_mask_2d.shape[1] > block_size:
        attn_mask_2d[-block_size:, -block_size - 1] = -float('inf')
    return attn_mask_2d


def create_block_diff_mask_by_pe_4d(
    block_size: int,
    x0_len_list: torch.Tensor,
    position_ids: torch.Tensor,
    causal_attn: bool = False,
    dtype: torch.dtype = torch.bfloat16,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Generate a 4D block-wise attention mask for ``x0`` / non-``x0`` tokens.

    Per batch row, the first ``x0_len`` tokens are ``x0`` tokens. The remaining
    tokens are grouped into blocks of ``block_size`` starting at ``x0_len``. The
    allowed-attention mask is the union of:

    1. Causal block (top-left): ``x0`` queries attend causally to ``x0`` keys.
    2. Mutual block (bottom-right): non-``x0`` queries attend to non-``x0`` keys
       of the same block (causally within the block if ``causal_attn``).
    3. Prefix block (bottom-left): a non-``x0`` query attends to ``x0`` keys
       whose index is below ``position_ids`` at the start of the query's block.
       NOTE(review): this compares a key *index* with a *position id*, which
       assumes ``x0`` token ``k`` has position id ``k`` - confirm with callers.

    Args:
        block_size: Block length for non-``x0`` tokens; must be positive.
        x0_len_list: Integer tensor of shape [batch_size] with the ``x0`` length
            per row. It is moved to ``position_ids.device`` if needed.
        position_ids: Integer tensor of shape [batch_size, seq_len].
        causal_attn: If True, enforce causal masking inside mutual blocks.
        dtype: Floating dtype of the additive mask (default ``torch.bfloat16``).

    Returns:
        tuple[torch.Tensor, torch.Tensor]:
            - Additive float mask of shape [batch_size, 1, seq_len, seq_len],
              ``0.0`` where attention is allowed and ``-inf`` elsewhere.
            - Boolean mask of the same shape, True where attention is allowed.

    Raises:
        ValueError: If ``block_size`` <= 0.
    """
    if block_size <= 0:
        raise ValueError(f"block_size must be positive, got {block_size}")

    batch_size, seq_len = position_ids.shape
    device = position_ids.device

    q_idx = torch.arange(seq_len, device=device).view(1, seq_len, 1)   # [1, L, 1]
    kv_idx = torch.arange(seq_len, device=device).view(1, 1, seq_len)  # [1, 1, L]

    # Bring lengths onto the mask device as int64 (gather needs int64 indices).
    x0_len_flat = x0_len_list.to(device=device, dtype=torch.long).reshape(batch_size)
    x0_len = x0_len_flat.view(batch_size, 1, 1)  # [B, 1, 1]

    is_x0_q = q_idx < x0_len    # [B, L, 1]
    is_x0_kv = kv_idx < x0_len  # [B, 1, L]

    # Block index relative to the first non-x0 token (floor division).
    q_block = torch.div(q_idx - x0_len, block_size, rounding_mode='floor')    # [B, L, 1]
    kv_block = torch.div(kv_idx - x0_len, block_size, rounding_mode='floor')  # [B, 1, L]

    # 1. Causal attention among x0 tokens.
    block_causal = is_x0_q & is_x0_kv & (q_idx >= kv_idx)

    # 2. Attention within the same non-x0 block.
    block_mutual = ~is_x0_q & ~is_x0_kv & (q_block == kv_block)
    if causal_attn:
        block_mutual = block_mutual & (q_idx >= kv_idx)

    # 3. Prefix of x0 keys visible to each non-x0 block. Clamping only matters
    #    for x0 queries, whose prefix result is discarded by ~is_x0_q anyway.
    q_blk_start = (x0_len_flat.view(batch_size, 1) + q_block[:, :, 0] * block_size).clamp(
        min=0, max=seq_len - 1
    )  # [B, L]
    prefix_len = position_ids.gather(1, q_blk_start).unsqueeze(2)  # [B, L, 1]
    block_prefix = ~is_x0_q & is_x0_kv & (kv_idx < prefix_len)

    final_mask = block_causal | block_mutual | block_prefix  # [B, L, L]

    customized_mask = torch.full(final_mask.shape, float('-inf'), dtype=dtype, device=device)
    customized_mask.masked_fill_(final_mask, 0.0)
    return customized_mask.unsqueeze(1), final_mask.unsqueeze(1)


def find_pred_pos_from_input_ids(
    input_ids: torch.LongTensor = None,
    text_mask_token_id: int = None,
    dtype: torch.dtype = torch.int8,
) -> torch.Tensor:
    """Compute the relative prediction position of each masked token.

    Non-masked positions get 0. Each masked position gets its 1-based index
    within its run of consecutive mask tokens, i.e. how many steps ahead of
    the last non-masked token it is. A run starting at index 0 is counted
    from 1 like any other run.

    Args:
        input_ids: Token ids of shape [batch_size, seq_len]. Required.
        text_mask_token_id: Token id representing masked positions. Required
            (there is no built-in default).
        dtype: Output dtype. Defaults to ``torch.int8`` for backward
            compatibility; runs longer than 127 do not fit in int8, so pass a
            wider dtype (e.g. ``torch.long``) for long mask runs.

    Returns:
        torch.Tensor: Tensor of shape [batch_size, seq_len] on
        ``input_ids.device``: 0 for non-masked tokens, n > 0 for the n-th
        consecutive masked token.

    Raises:
        ValueError: If ``input_ids`` or ``text_mask_token_id`` is None.
    """
    if input_ids is None or text_mask_token_id is None:
        raise ValueError("input_ids and text_mask_token_id must both be provided")

    batch_size, seq_len = input_ids.shape
    device = input_ids.device
    if seq_len == 0:
        return torch.zeros((batch_size, 0), dtype=dtype, device=device)

    is_mask = input_ids == text_mask_token_id
    positions = torch.arange(seq_len, device=device).expand(batch_size, seq_len)

    # Index of the latest non-masked token at or before each position; -1 when
    # every token so far is masked. Vectorized replacement for a per-element
    # Python loop (which also forced a host sync per element on GPU).
    reset_at = torch.where(is_mask, torch.full_like(positions, -1), positions)
    last_non_mask = reset_at.cummax(dim=1).values

    # Distance to that token: 0 on non-masked tokens by construction.
    return (positions - last_non_mask).to(dtype)