"""
Utility functions for dynamic block training with variable-length blocks.
This module provides functions to extract block boundaries from <EOB> tokens
and generate attention masks for variable-length blocks.
"""
import torch
from typing import List, Tuple
def calculate_block_nums_from_eob(
    input_ids: torch.Tensor,
    num_tokens_list: List[List[int]],
    eob_token_id: int,
    lines_per_block: int = 1
) -> List[List[torch.Tensor]]:
    """
    Extract variable block lengths from <EOB> token positions, respecting
    packed sample boundaries.

    Args:
        input_ids: Token IDs tensor, shape (batch_size, seq_len).
        num_tokens_list: List of lists, where each inner list contains the
            sequence lengths for one batch item (output from
            calculate_token_nums). Entries may be ints or 0-dim tensors.
        eob_token_id: Token ID for <EOB>.
        lines_per_block: Number of lines (EOBs) to group into a single block.
            Default 1 means each line is its own block. Set to 2 or 3 to have
            multiple EOBs within a block, allowing the model to learn EOB
            positions within blocks.

    Returns:
        List of lists of tensors. Outer list is batch. Inner list is samples.
        Each tensor contains block lengths for that sample; the lengths always
        sum to the number of tokens actually present in the (possibly
        truncated) sample slice.
    """
    batch_size, seq_len = input_ids.shape
    device = input_ids.device
    all_batch_block_lengths: List[List[torch.Tensor]] = []

    for i in range(batch_size):
        current_ids = input_ids[i]
        sample_lengths = num_tokens_list[i]  # List of ints (or 0-dim tensors)
        current_sample_block_lengths: List[torch.Tensor] = []
        start_idx = 0

        for length in sample_lengths:
            # Normalize 0-dim tensors to plain ints.
            if isinstance(length, torch.Tensor):
                length = length.item()

            if start_idx >= seq_len:
                break

            # Clamp the slice to the sequence end, then use the *actual*
            # (possibly truncated) length everywhere below. The original code
            # kept using the declared `length` after clamping, so block
            # lengths could sum to more tokens than really exist.
            end_idx = min(start_idx + length, seq_len)
            actual_len = end_idx - start_idx
            sample_ids = current_ids[start_idx:end_idx]

            # Positions of <EOB> tokens relative to the sample start.
            eob_positions = torch.nonzero(sample_ids == eob_token_id).flatten()

            if len(eob_positions) == 0:
                # No EOB tokens: treat the entire sample as one block.
                block_lengths = torch.tensor([actual_len], device=device)
            else:
                # Group N consecutive lines (EOBs) into one block by keeping
                # only every Nth EOB as a block boundary.
                # Example: lines_per_block=2, EOBs at [5, 12, 18, 25]
                #          -> keep positions [12, 25].
                if lines_per_block > 1 and len(eob_positions) >= lines_per_block:
                    eob_positions = eob_positions[lines_per_block - 1::lines_per_block]

                # The EOB token belongs to the block it terminates, hence the
                # +1 on each interior boundary.
                boundaries = torch.cat([
                    torch.tensor([0], device=device),
                    eob_positions + 1,
                    torch.tensor([actual_len], device=device),
                ])
                block_lengths = torch.diff(boundaries)
                # Drop 0-length blocks (EOB exactly at the end of the sample).
                block_lengths = block_lengths[block_lengths > 0]

            current_sample_block_lengths.append(block_lengths)
            start_idx = end_idx

        all_batch_block_lengths.append(current_sample_block_lengths)

    return all_batch_block_lengths
def block_diff_mask_dynamic(b, h, q_idx, kv_idx, block_boundaries=None, n=None):
    """
    Dynamic block-diffusion attention mask using precomputed block boundaries.

    Replaces fixed block_size arithmetic with torch.searchsorted so blocks may
    have variable lengths.

    Args:
        b: Batch index (unused; kept for the flex-attention mask signature).
        h: Head index (unused; kept for the flex-attention mask signature).
        q_idx: Query indices tensor.
        kv_idx: Key-value indices tensor.
        block_boundaries: Cumulative sum of block lengths, e.g. [0, 4, 12, 16],
            mapping tokens 0-3 -> block 0, 4-11 -> block 1, 12-15 -> block 2.
        n: Number of denoised (clean) tokens; indices [0, n) are noisy and
            [n, 2n) are the clean copy of the same positions.

    Returns:
        Boolean attention mask (True = can attend), the union of:
        - M_BD  (block diagonal):      self-attention within a block
        - M_OBC (offset block causal): noisy -> strictly earlier clean blocks
        - M_BC  (block causal):        clean -> clean blocks up to its own
    """
    # Fold clean indices [n, 2n) back onto their relative position in [0, n).
    rel_q = q_idx % n
    rel_kv = kv_idx % n

    # searchsorted(..., right=True) - 1 maps position p to the interval
    # [boundaries[k], boundaries[k+1]) containing it; clamp guards the edges.
    last_block = len(block_boundaries) - 2
    blk_q = (torch.searchsorted(block_boundaries, rel_q, right=True) - 1).clamp(0, last_block)
    blk_kv = (torch.searchsorted(block_boundaries, rel_kv, right=True) - 1).clamp(0, last_block)

    # Noisy tokens live in [0, n); clean (denoised) tokens in [n, 2n).
    clean_q = q_idx >= n
    clean_kv = kv_idx >= n

    # M_BD: same block AND same kind (noisy<->noisy or clean<->clean).
    block_diagonal = (blk_q == blk_kv) & (clean_q == clean_kv)

    # M_OBC: noisy queries attend to strictly earlier clean blocks
    # (the conditional context).
    offset_block_causal = (blk_q > blk_kv) & clean_kv & ~clean_q

    # M_BC: clean queries attend causally to clean blocks up to and
    # including their own.
    block_causal = (blk_q >= blk_kv) & clean_kv & clean_q

    return block_diagonal | offset_block_causal | block_causal
def block_attn_mask_dynamic(
    nested_block_lengths_list: List[List[torch.Tensor]],
    device: torch.device
) -> torch.Tensor:
    """
    Construct attention masks for variable-length blocks, handling packed
    sequences.

    Args:
        nested_block_lengths_list: List (batch) of lists (samples) of tensors
            (block lengths), as produced by calculate_block_nums_from_eob.
        device: Device to create tensors on.

    Returns:
        Boolean attention mask tensor of shape
        (batch_size, max_total_seq_len * 2, max_total_seq_len * 2); shorter
        batch items are padded with False (no attention). An empty input
        yields an empty (0, 0, 0) tensor.
    """
    masks = []
    for sample_block_lengths_list in nested_block_lengths_list:
        sample_masks = []
        for block_lengths in sample_block_lengths_list:
            total_len = int(block_lengths.sum().item())
            if total_len == 0:
                # Degenerate sample (all blocks filtered out); skip it.
                continue
            n = total_len  # number of clean tokens == number of noisy tokens

            # Cumulative block boundaries, e.g. lengths [4, 8, 4]
            # -> boundaries [0, 4, 12, 16].
            block_boundaries = torch.cat([
                torch.tensor([0], device=device),
                torch.cumsum(block_lengths, dim=0),
            ])

            # Broadcastable 2n x 2n index grids (noisy tokens then clean).
            seq_len_doubled = 2 * total_len
            q_idx = torch.arange(seq_len_doubled, device=device)[:, None]
            kv_idx = torch.arange(seq_len_doubled, device=device)[None, :]

            mask = block_diff_mask_dynamic(
                b=None,
                h=None,
                q_idx=q_idx,
                kv_idx=kv_idx,
                block_boundaries=block_boundaries,
                n=n,
            )
            sample_masks.append(mask)

        # One block-diagonal mask per batch item: packed samples must never
        # attend across sample boundaries.
        if sample_masks:
            row_mask = torch.block_diag(*sample_masks)
        else:
            # No valid samples in this batch item (should not happen for
            # well-formed input).
            row_mask = torch.zeros((0, 0), device=device, dtype=torch.bool)
        masks.append(row_mask)

    # Guard the empty batch: max() on an empty list would raise ValueError.
    if not masks:
        return torch.zeros((0, 0, 0), device=device, dtype=torch.bool)

    # Per-item masks are usually equal-sized (2 * seq_len), but pad to the
    # largest just in case; False padding means padding tokens attend nowhere.
    max_size = max(m.shape[0] for m in masks)
    padded_masks = []
    for m in masks:
        if m.shape[0] < max_size:
            pad_size = max_size - m.shape[0]
            m = torch.nn.functional.pad(m, (0, pad_size, 0, pad_size), value=False)
        padded_masks.append(m)
    return torch.stack(padded_masks, dim=0)