""" Utility functions for dynamic block training with variable-length blocks. This module provides functions to extract block boundaries from tokens and generate attention masks for variable-length blocks. """ import torch from typing import List, Tuple def calculate_block_nums_from_eob( input_ids: torch.Tensor, num_tokens_list: List[List[int]], eob_token_id: int, lines_per_block: int = 1 ) -> List[List[torch.Tensor]]: """ Extract variable block lengths from token positions, respecting packed sample boundaries. Args: input_ids: Token IDs tensor, shape (batch_size, seq_len) num_tokens_list: List of lists, where each inner list contains sequence lengths for a batch item. (Output from calculate_token_nums) eob_token_id: Token ID for lines_per_block: Number of lines (EOBs) to group into a single block. Default 1 means each line is its own block. Set to 2 or 3 to have multiple EOBs within a block, allowing the model to learn EOB positions within blocks. Returns: List of lists of tensors. Outer list is batch. Inner list is samples. Each tensor contains block lengths for that sample. """ batch_size, seq_len = input_ids.shape all_batch_block_lengths = [] for i in range(batch_size): current_ids = input_ids[i] sample_lengths = num_tokens_list[i] # List of integers current_sample_block_lengths = [] start_idx = 0 for length in sample_lengths: # Handle tensor or int if isinstance(length, torch.Tensor): length = length.item() # Extract sample tokens end_idx = start_idx + length # Ensure we don't go out of bounds (e.g. if sum(lengths) != seq_len due to padding logic differences) # But typically sum(lengths) == seq_len for packed data + padding end_idx = min(end_idx, seq_len) if start_idx >= seq_len: break sample_ids = current_ids[start_idx:end_idx] # Find positions of tokens in this sample eob_positions = torch.nonzero(sample_ids == eob_token_id).flatten() # Calculate block lengths for this sample if len(eob_positions) == 0: # No EOB tokens, treat entire sample as one block block_lengths = torch.tensor([length], device=input_ids.device) else: # Group N consecutive lines (EOBs) into one block if lines_per_block > 1 and len(eob_positions) >= lines_per_block: # Only use every Nth EOB as a block boundary # Example: lines_per_block=2, EOBs at [5, 12, 18, 25] # Use EOBs at indices [1, 3] -> positions [12, 25] eob_positions = eob_positions[lines_per_block-1::lines_per_block] # Add start and end positions # EOB is included in its block (boundary marker) boundaries = torch.cat([ torch.tensor([0], device=input_ids.device), eob_positions + 1, # +1 to include EOB token in block torch.tensor([length], device=input_ids.device) ]) block_lengths = torch.diff(boundaries) # Filter out 0-length blocks (happens when EOB is at the end of the sample) block_lengths = block_lengths[block_lengths > 0] current_sample_block_lengths.append(block_lengths) start_idx = end_idx all_batch_block_lengths.append(current_sample_block_lengths) return all_batch_block_lengths def block_diff_mask_dynamic(b, h, q_idx, kv_idx, block_boundaries=None, n=None): """ Dynamic block diffusion mask using precomputed block boundaries. This replaces the fixed block_size arithmetic with torch.searchsorted to support variable-length blocks. Args: b: Batch index (unused in mask logic) h: Head index (unused in mask logic) q_idx: Query indices tensor kv_idx: Key-value indices tensor block_boundaries: Cumulative sum of block lengths, e.g., [0, 4, 12, 16] This maps: tokens 0-3 → block 0, 4-11 → block 1, 12-15 → block 2 n: Number of denoised (clean) tokens Returns: Boolean attention mask (True = can attend) The mask combines three types: - M_BD (Block Diagonal): Self-attention within noised blocks - M_OBC (Offset Block Causal): Cross-attention from noised to conditional context - M_BC (Block Causal): Attention to denoised blocks """ # Map indices to block IDs (handling both Noisy 0..n-1 and Clean n..2n-1) # We use modulo n to map Clean tokens back to their relative position q_mod = q_idx % n kv_mod = kv_idx % n # Use searchsorted to find which block each index belongs to # right=True ensures that [0, 4] maps 0,1,2,3 to the first interval # We subtract 1 to get 0-based block indices q_block_id = torch.searchsorted(block_boundaries, q_mod, right=True) - 1 kv_block_id = torch.searchsorted(block_boundaries, kv_mod, right=True) - 1 # Clamp to handle edge cases q_block_id = torch.clamp(q_block_id, 0, len(block_boundaries) - 2) kv_block_id = torch.clamp(kv_block_id, 0, len(block_boundaries) - 2) # Identify Noisy vs Clean # Noisy: < n (x0_flag = False) # Clean: >= n (x0_flag = True) is_clean_q = q_idx >= n is_clean_kv = kv_idx >= n # **1. Block Diagonal Mask (M_BD) ** # Self-attention within blocks (Noisy->Noisy, Clean->Clean) M_BD = (q_block_id == kv_block_id) & (is_clean_q == is_clean_kv) # **2. Offset Block-Causal Mask (M_OBC) ** # Noisy i attends to Clean j < i # (Original code: block_q > block_kv & clean_kv & noisy_q) M_OBC = (q_block_id > kv_block_id) & (is_clean_kv) & (~is_clean_q) # **3. Block-Causal Mask (M_BC) ** # Clean i attends to Clean j <= i M_BC = (q_block_id >= kv_block_id) & (is_clean_kv) & (is_clean_q) # **4. Combine Masks ** return M_BD | M_OBC | M_BC def block_attn_mask_dynamic( nested_block_lengths_list: List[List[torch.Tensor]], device: torch.device ) -> torch.Tensor: """ Construct attention masks for variable-length blocks, handling packed sequences. Args: nested_block_lengths_list: List (batch) of Lists (samples) of Tensors (block lengths). device: Device to create tensors on Returns: Attention mask tensor, shape (batch_size, total_seq_len*2, total_seq_len*2) """ masks = [] for sample_block_lengths_list in nested_block_lengths_list: sample_masks = [] for block_lengths in sample_block_lengths_list: # Calculate total sequence length for this sample total_len = block_lengths.sum().item() if total_len == 0: continue n = total_len # Number of clean tokens # Create block boundaries (cumulative sum) block_boundaries = torch.cat([ torch.tensor([0], device=device), torch.cumsum(block_lengths, dim=0) ]) # Create index tensors for the full 2n x 2n mask seq_len_doubled = total_len * 2 q_idx = torch.arange(seq_len_doubled, device=device)[:, None] kv_idx = torch.arange(seq_len_doubled, device=device)[None, :] # Generate mask using dynamic block boundaries mask = block_diff_mask_dynamic( b=None, h=None, q_idx=q_idx, kv_idx=kv_idx, block_boundaries=block_boundaries, n=n ) sample_masks.append(mask) # Combine sample masks into a single block-diagonal mask for the batch item if sample_masks: row_mask = torch.block_diag(*sample_masks) else: # Should not happen if input is valid row_mask = torch.zeros((0, 0), device=device, dtype=torch.bool) masks.append(row_mask) # Stack into batch # We assume all row_masks have the same size (2 * seq_len) # If not (due to padding differences?), we might need to pad them. # But calculate_token_nums usually covers the whole seq_len including padding. # Check sizes sizes = [m.shape[0] for m in masks] max_size = max(sizes) padded_masks = [] for m in masks: if m.shape[0] < max_size: # Pad with False (no attention) pad_size = max_size - m.shape[0] m = torch.nn.functional.pad(m, (0, pad_size, 0, pad_size), value=False) padded_masks.append(m) masks = torch.stack(padded_masks, dim=0) return masks