import copy

import torch
| |
|
def find_prefix_seq_length_by_pe(
    pe: torch.Tensor
) -> torch.Tensor:
    """
    Find the sequence length where position encoding drops (indicating prefix boundary).

    A "drop" is an index ``i >= 1`` with ``pe[b, i] < pe[b, i - 1]``. Only the
    first drop of each sequence is reported.

    Args:
        pe: Position encoding tensor of shape [batch_size, seq_len] containing
            the position index of each token in the sequence.

    Returns:
        torch.Tensor: An int64 tensor of shape [batch_size], allocated on CPU,
        containing:
            - The index of the first position whose value is lower than the
              value right before it, for each sequence
            - -1 if no drop occurs in the sequence
    """
    batch_size, num_positions = pe.shape

    # With fewer than two positions there are no adjacent pairs to compare.
    if batch_size == 0 or num_positions < 2:
        return torch.full((batch_size,), -1, dtype=torch.long)

    # drop_mask[b, k] is True when position k + 1 is lower than position k.
    drop_mask = pe[:, 1:] < pe[:, :-1]

    # Candidate answers are 1..num_positions-1; non-drop slots are replaced by
    # the out-of-range sentinel `num_positions` so a row-wise min yields the
    # first drop, or the sentinel when the row has no drop at all.
    candidates = torch.arange(1, num_positions, device=pe.device).expand_as(drop_mask)
    first_drop = candidates.masked_fill(~drop_mask, num_positions).min(dim=1).values
    first_drop = first_drop.masked_fill(first_drop == num_positions, -1)

    # The result is returned on CPU as int64, matching the previous behaviour.
    return first_drop.to(device="cpu", dtype=torch.long)
| |
|
| |
|
| |
|
def update_causal_mask_with_pad_non_visible_2d(
    input_ids: torch.Tensor,
    attn_mask_2d: torch.Tensor,
    text_mask_token_id: int = 151666,
    block_size: int = 4,
    causal_attn: bool = False
) -> torch.Tensor:
    """
    Update a 2D attention mask in place around runs of mask tokens in ``input_ids``.

    A column is treated as a "mask column" when its token equals
    ``text_mask_token_id`` or when the token right after it does. For every
    mask column the function:
        1. blocks (-inf) all query rows at or after the next non-mask column;
        2. unless ``causal_attn`` is True, opens (0.0) the query rows that lie
           strictly after the preceding non-mask column and strictly before
           the column itself.

    Args:
        input_ids: Token IDs indexed along dim 0 (assumed 1-D of shape
            [seq_len]); used to locate the mask tokens.
        attn_mask_2d: 2D attention mask matrix of shape [seq_len, seq_len] where:
            - 0.0 indicates allowed attention
            - -inf indicates masked attention
            The tensor is modified in place.
        text_mask_token_id: ID representing masked tokens.
        block_size: Size of the diffusion window. Accepted for interface
            compatibility; not used by the current implementation.
        causal_attn: If True, step 2 above is skipped so no position above the
            diagonal is opened.

    Returns:
        The same ``attn_mask_2d`` object, after the in-place updates.
    """
    seq_len = input_ids.shape[0]
    device = input_ids.device

    # Nothing to update for an empty sequence.
    if seq_len == 0:
        return attn_mask_2d

    # A column is a mask column if it holds a mask token or directly precedes one.
    input_mask = input_ids.eq(text_mask_token_id)
    input_before_mask = torch.zeros_like(input_mask)
    input_before_mask[:-1] = input_mask[1:]
    mask_cols = input_mask | input_before_mask
    non_mask = ~mask_cols

    indices = torch.arange(seq_len, device=device)
    rows = indices[:, None]  # query index, shape [seq_len, 1]
    cols = indices           # key index, shape [seq_len]

    # prev_non_mask[j]: largest non-mask index <= j (0 when there is none yet).
    prev_non_mask = (indices * non_mask).cummax(dim=0).values

    # next_non_mask[j]: smallest non-mask index >= j, obtained with a reversed
    # running minimum; the dtype's max value marks "no such index".
    max_value = torch.iinfo(indices.dtype).max
    mask_indices = torch.where(non_mask, indices, torch.full_like(indices, max_value))
    reversed_cummin = torch.flip(mask_indices, dims=[0]).cummin(dim=0).values
    next_non_mask = torch.flip(reversed_cummin, dims=[0])

    # Mask columns become invisible to every row from the next non-mask
    # position onwards.
    infra_mask = (
        (cols > prev_non_mask) &
        (rows >= next_non_mask[None, :]) &
        mask_cols[None, :]
    )
    attn_mask_2d.masked_fill_(infra_mask, -float('inf'))

    # In the non-causal case, rows inside the same run may look ahead at the
    # later mask columns of that run.
    if not causal_attn:
        visible_mask = (
            (rows > prev_non_mask[None, :]) &
            (rows < cols) &
            mask_cols[None, :]
        )
        attn_mask_2d.masked_fill_(visible_mask, 0.0)

    return attn_mask_2d
| |
|
| |
|
def update_causal_mask_for_one_gen_window_2d(
    input_ids: torch.Tensor,
    attn_mask_2d: torch.Tensor,
    block_size: int = 4,
    use_cache: bool = True,
    causal_attn: bool = False
) -> torch.Tensor:
    """
    Updates a 2D attention mask for a diffusion window in transformer inference.

    The last ``block_size`` rows/columns of the mask are treated as the current
    generation window.

    Args:
        input_ids: Input token IDs (unused in current implementation).
        attn_mask_2d: 2D attention mask matrix of shape [seq_len, seq_len] where:
            - 0.0 indicates allowed attention
            - -inf indicates masked attention
            The tensor is modified in place.
        block_size: Size of the diffusion window.
        use_cache: Whether key-value cache is being used. When True, the column
            immediately before the window is hidden from the window rows.
        causal_attn: If True, the window is not opened up and keeps whatever
            masking ``attn_mask_2d`` already had.

    Returns:
        The same ``attn_mask_2d`` object, after the in-place updates.
    """
    if not causal_attn:
        # Let every token of the window attend to every other token of it.
        attn_mask_2d[-block_size:, -block_size:] = 0.0

    # NOTE(review): this step is applied regardless of `causal_attn`; confirm
    # that this matches the intended nesting.
    # The column right before the window only exists when the mask is wider
    # than the window; without this check the index is out of range.
    if use_cache and attn_mask_2d.shape[1] > block_size:
        attn_mask_2d[-block_size:, -block_size - 1] = -float('inf')

    return attn_mask_2d
| |
|
| |
|
def create_block_diff_mask_by_pe_1d(
    b: int,
    h: int,
    q_idx: torch.Tensor,
    kv_idx: torch.Tensor,
    block_size: int,
    x0_len_list: torch.Tensor,
    position_ids_list: torch.Tensor,
    causal_attn: bool = False,
) -> torch.Tensor:
    """Computes attention mask for a single query-key position in Flex Attention.

    Attention is allowed when any of the following holds:
        1. Causal block: both positions are `x0` tokens and ``q_idx >= kv_idx``.
        2. Mutual block: both positions are non-`x0` tokens of the same block
           (additionally ``q_idx >= kv_idx`` when ``causal_attn`` is True).
        3. Prefix block: the query is a non-`x0` token, the key is an `x0`
           token, and ``kv_idx`` is below the position ID stored at the first
           token of the query's block.

    Args:
        b (int): Batch index (0 <= b < batch_size).
        h (int): Head index (unused in current implementation).
        q_idx (torch.Tensor): Query position index (scalar or 0D tensor).
        kv_idx (torch.Tensor): Key/Value position index (scalar or 0D tensor).
        block_size (int): Size of processing blocks for non-`x0` tokens.
        x0_len_list (torch.Tensor): Tensor of shape [batch_size] with `x0` segment lengths.
        position_ids_list (torch.Tensor): Tensor of shape [batch_size, seq_len] with position IDs.
        causal_attn (bool, optional): Enforces causal masking in mutual blocks if True. Defaults to False.

    Returns:
        torch.Tensor: Boolean indicating whether attention is allowed (True = allowed).
    """
    x0_len = x0_len_list[b]
    position_ids = position_ids_list[b]

    x0_flag_q = q_idx < x0_len
    x0_flag_kv = kv_idx < x0_len

    # Region 1: causal attention among the x0 tokens.
    block_causal = x0_flag_q & x0_flag_kv & (q_idx >= kv_idx)

    q_ith_block = (q_idx - x0_len) // block_size
    kv_ith_block = (kv_idx - x0_len) // block_size

    # Region 2: attention inside one non-x0 block. The causal constraint is
    # applied as a separate boolean step so the result stays a bool tensor.
    block_mutual = (~x0_flag_q & ~x0_flag_kv) & (q_ith_block == kv_ith_block)
    if causal_attn:
        block_mutual = block_mutual & (q_idx >= kv_idx)

    # Region 3: non-x0 queries see a prefix of the x0 tokens. For x0 queries
    # the block start falls below 0; it is clamped into range (as the 4D
    # variant does) and the looked-up value is discarded by ~x0_flag_q anyway.
    block_start = torch.clamp(
        x0_len + q_ith_block * block_size,
        min=0,
        max=position_ids.shape[0] - 1,
    )
    prefix_len = position_ids[block_start]
    block_prefix = (~x0_flag_q & x0_flag_kv) & (kv_idx < prefix_len)

    mask_val = block_causal | block_mutual | block_prefix
    return mask_val.to(torch.bool)
| |
|
| |
|
def create_block_diff_mask_by_pe_4d(
    block_size: int,
    x0_len_list: torch.Tensor,
    position_ids: torch.Tensor,
    causal_attn: bool = False
) -> tuple[torch.Tensor, torch.Tensor]:
    """Generates a 4D attention mask for block-difference attention patterns.

    The mask consists of three regions:
        1. Causal block (top-left): Standard causal attention for `x0` tokens.
        2. Mutual block (bottom-right): Non-causal attention within the same block for non-`x0` tokens.
        3. Prefix block (bottom-left): Non-`x0` tokens can attend to a prefix of `x0` tokens.
           The prefix length of a block is the position ID stored at the
           block's first token.

    Args:
        block_size (int): Size of processing blocks for non-`x0` tokens.
        x0_len_list (torch.Tensor): Tensor of shape [B] containing lengths of `x0` segments per batch.
            It may live on a different device than ``position_ids``; it is
            moved to that device internally.
        position_ids (torch.Tensor): Tensor of shape [B, seq_len] containing position IDs.
        causal_attn (bool, optional): If True, enforces causal masking in mutual blocks. Defaults to False.

    Returns:
        tuple[torch.Tensor, torch.Tensor]:
            - A bfloat16 mask of shape [batch_size, 1, seq_len, seq_len] with 0.0 for
              allowed positions and `-inf` for masked (non-visible) positions.
            - A boolean mask of shape [batch_size, 1, seq_len, seq_len] indicating allowed attention positions.
    """
    batch_size, seq_len = position_ids.shape
    device = position_ids.device

    # Query index varies along dim 1, key/value index along dim 2.
    q_idx = torch.arange(seq_len, device=device).view(1, seq_len, 1)
    kv_idx = torch.arange(seq_len, device=device).view(1, 1, seq_len)

    # Bring the lengths onto the working device as int64 so the comparisons
    # and the gather below do not hit a device or dtype mismatch.
    x0_len = x0_len_list.to(device=device, dtype=torch.long).view(batch_size, 1, 1)
    x0_flag_q = q_idx < x0_len
    x0_flag_kv = kv_idx < x0_len

    # Block index of each position relative to the end of the x0 segment
    # (negative for x0 positions, which the flags below exclude).
    q_block_idx = torch.div(q_idx - x0_len, block_size, rounding_mode='floor')
    kv_block_idx = torch.div(kv_idx - x0_len, block_size, rounding_mode='floor')

    # Region 1: causal attention among the x0 tokens.
    block_causal = x0_flag_q & x0_flag_kv & (q_idx >= kv_idx)

    # Region 2: attention inside one non-x0 block, optionally causal.
    block_mutual = ~x0_flag_q & ~x0_flag_kv & (q_block_idx == kv_block_idx)
    if causal_attn:
        block_mutual = block_mutual & (q_idx >= kv_idx)

    # Region 3: look up the position ID at the first token of each query's
    # block; the clamp keeps the gather index valid for x0 queries, whose
    # block start would otherwise be negative.
    q_blk_start = (
        x0_len.view(batch_size, 1) + q_block_idx[:, :, 0] * block_size
    ).clamp(min=0, max=seq_len - 1)
    prefix_len = position_ids.gather(1, q_blk_start).unsqueeze(2)
    block_prefix = (~x0_flag_q & x0_flag_kv) & (kv_idx < prefix_len)

    final_mask = block_causal | block_mutual | block_prefix

    # Additive form of the mask: 0.0 where allowed, -inf elsewhere.
    customized_mask = torch.full_like(final_mask, float('-inf'), dtype=torch.bfloat16)
    customized_mask.masked_fill_(final_mask, 0.0)

    return customized_mask.unsqueeze(1), final_mask.unsqueeze(1)
| |
|
| |
|
def find_pred_pos_from_input_ids(
    input_ids: torch.LongTensor = None,
    text_mask_token_id: int = 151666,
) -> torch.Tensor:
    """Compute the relative prediction positions for masked tokens in a sequence.

    For non-masked positions, the output is 0. For masked positions, the value increments
    by 1 for each consecutive mask token, indicating how many steps ahead the prediction is.
    Position 0 of every sequence is always reported as 0, even when it holds a
    mask token; a run of masks starting there is counted from position 1.

    Args:
        input_ids (torch.LongTensor): Input token IDs of shape [batch_size, seq_len].
        text_mask_token_id (int, optional): Token ID representing masked positions. Defaults to 151666.

    Returns:
        torch.Tensor: An int8 tensor of shape [batch_size, seq_len], on the
        device of ``input_ids``, where:
            - 0 indicates a non-masked token (or position 0).
            - n > 0 indicates the nth consecutive masked token (e.g., 1 = first mask, 2 = second mask, etc.).
        Because the dtype is int8, runs longer than 127 masks cannot be
        represented exactly.
    """
    batch_size, seq_len = input_ids.shape
    device = input_ids.device

    if batch_size == 0 or seq_len == 0:
        return torch.zeros((batch_size, seq_len), dtype=torch.int8, device=device)

    is_mask = input_ids == text_mask_token_id
    # Position 0 never receives a count, so treat it as a non-mask anchor.
    # `is_mask` is a fresh tensor, so this does not touch `input_ids`.
    is_mask[:, 0] = False

    positions = torch.arange(seq_len, device=device).expand(batch_size, seq_len)

    # Index of the closest non-mask position at or before each position: keep
    # the index on non-mask slots, zero it on mask slots, then take a running max.
    last_non_mask = positions.masked_fill(is_mask, 0).cummax(dim=1).values

    # Distance to that anchor is 0 on non-mask slots and the 1-based rank
    # inside the current run of masks otherwise.
    return (positions - last_non_mask).to(torch.int8)
| |
|