| | """ |
| | Utility functions for dynamic block training with variable-length blocks. |
| | |
| | This module provides functions to extract block boundaries from <EOB> tokens |
| | and generate attention masks for variable-length blocks. |
| | """ |
| |
|
| | import torch |
| | from typing import List, Tuple |
| |
|
| |
|
def calculate_block_nums_from_eob(
    input_ids: torch.Tensor,
    num_tokens_list: List[List[int]],
    eob_token_id: int,
    lines_per_block: int = 1
) -> List[List[torch.Tensor]]:
    """
    Extract variable block lengths from <EOB> token positions, respecting packed sample boundaries.

    Args:
        input_ids: Token IDs tensor, shape (batch_size, seq_len)
        num_tokens_list: List of lists, where each inner list contains sequence lengths for a batch item.
            (Output from calculate_token_nums)
        eob_token_id: Token ID for <EOB>
        lines_per_block: Number of lines (EOBs) to group into a single block.
            Default 1 means each line is its own block.
            Set to 2 or 3 to have multiple EOBs within a block,
            allowing the model to learn EOB positions within blocks.

    Returns:
        List of lists of tensors. Outer list is batch. Inner list is samples.
        Each tensor contains block lengths for that sample; the lengths of one
        sample always sum to the number of tokens actually sliced for it
        (i.e. truncation at seq_len is respected).
    """
    batch_size, seq_len = input_ids.shape
    all_batch_block_lengths = []

    for i in range(batch_size):
        current_ids = input_ids[i]
        sample_lengths = num_tokens_list[i]

        current_sample_block_lengths = []
        start_idx = 0

        for length in sample_lengths:
            # Lengths may arrive as 0-dim tensors; normalize to Python int.
            if isinstance(length, torch.Tensor):
                length = length.item()

            # Nothing left in the packed row for further samples.
            if start_idx >= seq_len:
                break

            # A sample may be cut off by the packed-sequence boundary.
            end_idx = min(start_idx + length, seq_len)
            # BUGFIX: use the actual (possibly truncated) span length instead
            # of the nominal `length`, so block lengths never claim more
            # tokens than were sliced for this sample.
            actual_len = end_idx - start_idx

            sample_ids = current_ids[start_idx:end_idx]

            # Positions of <EOB> tokens, relative to the sample start.
            eob_positions = torch.nonzero(sample_ids == eob_token_id).flatten()

            if len(eob_positions) == 0:
                # No <EOB> in this sample: the whole span is one block.
                block_lengths = torch.tensor([actual_len], device=input_ids.device)
            else:
                if lines_per_block > 1 and len(eob_positions) >= lines_per_block:
                    # Keep only every lines_per_block-th EOB so that each
                    # block contains lines_per_block lines; intermediate EOBs
                    # remain inside the block for the model to learn.
                    eob_positions = eob_positions[lines_per_block - 1::lines_per_block]

                # Boundaries: [0, eob+1 ..., actual_len]; a block ends right
                # after its (grouped) EOB token.
                boundaries = torch.cat([
                    torch.tensor([0], device=input_ids.device),
                    eob_positions + 1,
                    torch.tensor([actual_len], device=input_ids.device)
                ])
                block_lengths = torch.diff(boundaries)
                # Drop zero-length blocks (e.g. when the sample ends on an EOB).
                block_lengths = block_lengths[block_lengths > 0]

            current_sample_block_lengths.append(block_lengths)
            start_idx = end_idx

        all_batch_block_lengths.append(current_sample_block_lengths)

    return all_batch_block_lengths
| |
|
| |
|
def block_diff_mask_dynamic(b, h, q_idx, kv_idx, block_boundaries=None, n=None):
    """
    Dynamic block diffusion mask using precomputed block boundaries.

    Variable-length blocks are resolved with torch.searchsorted instead of
    fixed block_size arithmetic.

    Args:
        b: Batch index (unused in mask logic)
        h: Head index (unused in mask logic)
        q_idx: Query indices tensor
        kv_idx: Key-value indices tensor
        block_boundaries: Cumulative sum of block lengths, e.g., [0, 4, 12, 16]
            This maps: tokens 0-3 → block 0, 4-11 → block 1, 12-15 → block 2
        n: Number of denoised (clean) tokens

    Returns:
        Boolean attention mask (True = can attend)

    The mask is the union of three components:
        - M_BD (Block Diagonal): self-attention inside a block, same half
        - M_OBC (Offset Block Causal): noised queries attending to clean
          context in strictly earlier blocks
        - M_BC (Block Causal): clean-to-clean attention, causal over blocks
    """
    # Positions are laid out as [noised copy | clean copy]; wrap both index
    # grids back into the length-n sample.
    q_pos = q_idx % n
    kv_pos = kv_idx % n

    # searchsorted(..., right=True) - 1 maps a position to the index of the
    # block whose half-open range [boundaries[k], boundaries[k+1]) contains it.
    last_block = len(block_boundaries) - 2
    q_blk = torch.searchsorted(block_boundaries, q_pos, right=True) - 1
    kv_blk = torch.searchsorted(block_boundaries, kv_pos, right=True) - 1
    q_blk = torch.clamp(q_blk, 0, last_block)
    kv_blk = torch.clamp(kv_blk, 0, last_block)

    # The second half of the doubled sequence holds the clean (denoised) tokens.
    clean_q = q_idx >= n
    clean_kv = kv_idx >= n

    # M_BD: same block, same half (noised↔noised or clean↔clean).
    same_block_same_half = (q_blk == kv_blk) & (clean_q == clean_kv)

    # M_OBC: noised query looks at clean context strictly before its block.
    noisy_sees_earlier_clean = (~clean_q) & clean_kv & (q_blk > kv_blk)

    # M_BC: clean tokens attend causally (block-wise) to clean tokens.
    clean_block_causal = clean_q & clean_kv & (q_blk >= kv_blk)

    return same_block_same_half | noisy_sees_earlier_clean | clean_block_causal
| |
|
| |
|
def block_attn_mask_dynamic(
    nested_block_lengths_list: List[List[torch.Tensor]],
    device: torch.device
) -> torch.Tensor:
    """
    Construct attention masks for variable-length blocks, handling packed sequences.

    Args:
        nested_block_lengths_list: List (batch) of Lists (samples) of Tensors (block lengths).
        device: Device to create tensors on

    Returns:
        Boolean attention mask tensor, shape (batch_size, S, S) where S is
        twice the largest per-row token count (noised + clean copies), padded
        with False for shorter rows. An empty batch yields shape (0, 0, 0).
    """
    masks = []

    for sample_block_lengths_list in nested_block_lengths_list:
        sample_masks = []

        for block_lengths in sample_block_lengths_list:
            total_len = block_lengths.sum().item()
            # Skip degenerate samples with no tokens.
            if total_len == 0:
                continue

            n = total_len

            # Prefix-sum boundaries: [0, l0, l0+l1, ...] for searchsorted.
            block_boundaries = torch.cat([
                torch.tensor([0], device=device),
                torch.cumsum(block_lengths, dim=0)
            ])

            # Sequence is doubled: noised copy followed by clean copy.
            seq_len_doubled = total_len * 2
            q_idx = torch.arange(seq_len_doubled, device=device)[:, None]
            kv_idx = torch.arange(seq_len_doubled, device=device)[None, :]

            mask = block_diff_mask_dynamic(
                b=None,
                h=None,
                q_idx=q_idx,
                kv_idx=kv_idx,
                block_boundaries=block_boundaries,
                n=n
            )

            sample_masks.append(mask)

        # Packed samples in one row never attend across each other:
        # block-diagonal composition keeps them isolated.
        if sample_masks:
            row_mask = torch.block_diag(*sample_masks)
        else:
            row_mask = torch.zeros((0, 0), device=device, dtype=torch.bool)

        masks.append(row_mask)

    # BUGFIX: an empty batch previously raised ValueError on max([]);
    # return an empty, well-shaped mask instead.
    if not masks:
        return torch.zeros((0, 0, 0), device=device, dtype=torch.bool)

    # Pad every row mask (bottom/right, with False) to the largest size so
    # they can be stacked into one batch tensor.
    sizes = [m.shape[0] for m in masks]
    max_size = max(sizes)

    padded_masks = []
    for m in masks:
        if m.shape[0] < max_size:
            pad_size = max_size - m.shape[0]
            m = torch.nn.functional.pad(m, (0, pad_size, 0, pad_size), value=False)
        padded_masks.append(m)

    masks = torch.stack(padded_masks, dim=0)
    return masks
| |
|
| |
|