| """ |
| Utility functions for dynamic block training with variable-length blocks. |
| |
| This module provides functions to extract block boundaries from <EOB> tokens |
| and generate attention masks for variable-length blocks. |
| """ |
|
|
| import torch |
| from typing import List, Tuple |
|
|
|
|
def calculate_block_nums_from_eob(
    input_ids: torch.Tensor,
    num_tokens_list: List[List[int]],
    eob_token_id: int
) -> List[List[torch.Tensor]]:
    """
    Extract variable block lengths from <EOB> token positions, respecting packed sample boundaries.

    Args:
        input_ids: Token IDs tensor, shape (batch_size, seq_len).
        num_tokens_list: List of lists, where each inner list contains sequence
            lengths for a batch item (output from calculate_token_nums).
        eob_token_id: Token ID for <EOB>.

    Returns:
        List of lists of tensors. Outer list is batch, inner list is samples.
        Each tensor contains block lengths for that sample; the lengths sum to
        the number of tokens actually present in the (possibly truncated)
        sample.
    """
    batch_size, seq_len = input_ids.shape
    all_batch_block_lengths = []

    for i in range(batch_size):
        current_ids = input_ids[i]
        sample_lengths = num_tokens_list[i]

        current_sample_block_lengths = []
        start_idx = 0

        for length in sample_lengths:
            if isinstance(length, torch.Tensor):
                length = length.item()

            if start_idx >= seq_len:
                # All remaining packed samples were truncated away entirely.
                break

            # Clamp so a sample that overruns the padded sequence is truncated.
            end_idx = min(start_idx + length, seq_len)
            sample_ids = current_ids[start_idx:end_idx]

            # BUG FIX: use the *actual* (possibly truncated) sample length for
            # the final block boundary, not the nominal `length`. Previously a
            # truncated sample produced block lengths summing to more tokens
            # than it actually contains.
            actual_len = end_idx - start_idx

            eob_positions = torch.nonzero(sample_ids == eob_token_id).flatten()

            if len(eob_positions) == 0:
                # No <EOB> in this sample: the whole sample is a single block.
                block_lengths = torch.tensor([actual_len], device=input_ids.device)
            else:
                # Blocks end right after each <EOB>; a trailing partial block
                # (tokens after the last <EOB>) is closed at actual_len.
                boundaries = torch.cat([
                    torch.tensor([0], device=input_ids.device),
                    eob_positions + 1,
                    torch.tensor([actual_len], device=input_ids.device)
                ])
                block_lengths = torch.diff(boundaries)

            # Drop zero-length blocks (e.g. when a sample ends exactly on an
            # <EOB>, the duplicated boundary yields a zero diff).
            block_lengths = block_lengths[block_lengths > 0]

            current_sample_block_lengths.append(block_lengths)
            start_idx = end_idx

        all_batch_block_lengths.append(current_sample_block_lengths)

    return all_batch_block_lengths
|
|
|
|
def block_diff_mask_dynamic(b, h, q_idx, kv_idx, block_boundaries=None, n=None):
    """
    Dynamic block diffusion mask using precomputed block boundaries.

    Replaces fixed block_size arithmetic with torch.searchsorted so that
    blocks may have different lengths.

    Args:
        b: Batch index (ignored; kept for flex-attention-style signatures).
        h: Head index (ignored).
        q_idx: Query position indices.
        kv_idx: Key-value position indices.
        block_boundaries: Cumulative block-length boundaries, e.g. [0, 4, 12, 16]
            maps tokens 0-3 -> block 0, 4-11 -> block 1, 12-15 -> block 2.
        n: Number of denoised (clean) tokens; positions >= n are the clean copy.

    Returns:
        Boolean mask (True = attention allowed), the union of:
          - same-block attention within matching noise state (block diagonal)
          - noised queries attending to clean tokens of strictly earlier blocks
          - block-causal attention among clean tokens
    """
    # Fold the doubled sequence back onto token positions 0..n-1.
    q_pos = q_idx % n
    kv_pos = kv_idx % n

    # Bucket each position into its block via the boundary table.
    q_blk = torch.searchsorted(block_boundaries, q_pos, right=True) - 1
    kv_blk = torch.searchsorted(block_boundaries, kv_pos, right=True) - 1
    last_block = len(block_boundaries) - 2
    q_blk = q_blk.clamp(min=0, max=last_block)
    kv_blk = kv_blk.clamp(min=0, max=last_block)

    # First n positions are noised; the duplicated half (>= n) is clean.
    q_clean = q_idx >= n
    kv_clean = kv_idx >= n

    # M_BD: self-attention within a block, noised-to-noised or clean-to-clean.
    same_block_same_state = (q_blk == kv_blk) & (q_clean == kv_clean)

    # M_OBC: noised queries conditioning on clean context from earlier blocks.
    noised_to_clean_past = (~q_clean) & kv_clean & (q_blk > kv_blk)

    # M_BC: clean tokens attend block-causally to clean tokens.
    clean_causal = q_clean & kv_clean & (q_blk >= kv_blk)

    return same_block_same_state | noised_to_clean_past | clean_causal
|
|
|
|
def block_attn_mask_dynamic(
    nested_block_lengths_list: List[List[torch.Tensor]],
    device: torch.device
) -> torch.Tensor:
    """
    Construct attention masks for variable-length blocks, handling packed sequences.

    Each sample's mask covers a doubled sequence (noised half + clean half);
    packed samples within a batch row are laid out block-diagonally so they
    cannot attend across sample boundaries, and rows are padded with False up
    to the largest row in the batch.

    Args:
        nested_block_lengths_list: List (batch) of lists (samples) of tensors
            (block lengths), e.g. from calculate_block_nums_from_eob.
        device: Device to create tensors on.

    Returns:
        Boolean attention mask, shape (batch_size, max_len, max_len) where
        max_len = 2 * max total token count over batch rows. An empty batch
        yields an empty (0, 0, 0) tensor instead of raising.
    """
    masks = []

    for sample_block_lengths_list in nested_block_lengths_list:
        sample_masks = []

        for block_lengths in sample_block_lengths_list:
            total_len = block_lengths.sum().item()
            if total_len == 0:
                # Empty sample contributes nothing to the packed mask.
                continue

            n = total_len

            # Boundaries are the cumulative block offsets, prefixed with 0.
            block_boundaries = torch.cat([
                torch.tensor([0], device=device),
                torch.cumsum(block_lengths, dim=0)
            ])

            # Doubled layout: [noised tokens | clean (denoised) tokens].
            seq_len_doubled = total_len * 2
            q_idx = torch.arange(seq_len_doubled, device=device)[:, None]
            kv_idx = torch.arange(seq_len_doubled, device=device)[None, :]

            mask = block_diff_mask_dynamic(
                b=None,
                h=None,
                q_idx=q_idx,
                kv_idx=kv_idx,
                block_boundaries=block_boundaries,
                n=n
            )

            sample_masks.append(mask)

        if sample_masks:
            # Packed samples are isolated via a block-diagonal layout.
            row_mask = torch.block_diag(*sample_masks)
        else:
            row_mask = torch.zeros((0, 0), device=device, dtype=torch.bool)

        masks.append(row_mask)

    # BUG FIX: max() over an empty batch raised ValueError (and torch.stack
    # would then fail on an empty list); default=0 plus the explicit empty
    # return below make an empty input yield an empty mask instead.
    max_size = max((m.shape[0] for m in masks), default=0)

    padded_masks = []
    for m in masks:
        pad_size = max_size - m.shape[0]
        if pad_size > 0:
            # Pad right and bottom with False (no attention to padding).
            m = torch.nn.functional.pad(m, (0, pad_size, 0, pad_size), value=False)
        padded_masks.append(m)

    if not padded_masks:
        return torch.zeros((0, 0, 0), device=device, dtype=torch.bool)

    return torch.stack(padded_masks, dim=0)
|
|
|
|