# armanakbari4's picture
# Upload folder using huggingface_hub
# 4fb7fbd verified
"""
Utility functions for dynamic block training with variable-length blocks.
This module provides functions to extract block boundaries from <EOB> tokens
and generate attention masks for variable-length blocks.
"""
import torch
from typing import List, Tuple
def calculate_block_nums_from_eob(
    input_ids: torch.Tensor,
    num_tokens_list: List[List[int]],
    eob_token_id: int,
    lines_per_block: int = 1
) -> List[List[torch.Tensor]]:
    """
    Extract variable block lengths from <EOB> token positions, respecting packed sample boundaries.

    Args:
        input_ids: Token IDs tensor, shape (batch_size, seq_len)
        num_tokens_list: List of lists, where each inner list contains sequence lengths for a batch item.
            (Output from calculate_token_nums)
        eob_token_id: Token ID for <EOB>
        lines_per_block: Number of lines (EOBs) to group into a single block.
            Default 1 means each line is its own block.
            Set to 2 or 3 to have multiple EOBs within a block,
            allowing the model to learn EOB positions within blocks.

    Returns:
        List of lists of tensors. Outer list is batch. Inner list is samples.
        Each tensor contains block lengths for that sample; lengths always sum
        to the number of tokens actually present in the (possibly truncated)
        sample slice.
    """
    batch_size, seq_len = input_ids.shape
    all_batch_block_lengths = []
    for i in range(batch_size):
        current_ids = input_ids[i]
        sample_lengths = num_tokens_list[i]  # List of integers (or 0-dim tensors)
        current_sample_block_lengths = []
        start_idx = 0
        for length in sample_lengths:
            # Handle tensor or int
            if isinstance(length, torch.Tensor):
                length = length.item()
            # Nothing left in this row: remaining declared lengths are ignored.
            if start_idx >= seq_len:
                break
            # Clamp to the row so we never index past seq_len
            # (e.g. if sum(lengths) != seq_len due to padding logic differences).
            end_idx = min(start_idx + length, seq_len)
            # BUGFIX: use the clamped span, not the declared `length`, so the
            # final boundary (and the no-EOB block) never exceeds the tokens
            # actually present when the last sample is truncated.
            actual_len = end_idx - start_idx
            sample_ids = current_ids[start_idx:end_idx]
            # Positions of <EOB> tokens relative to the start of this sample
            eob_positions = torch.nonzero(sample_ids == eob_token_id).flatten()
            if len(eob_positions) == 0:
                # No EOB tokens, treat entire sample as one block
                block_lengths = torch.tensor([actual_len], device=input_ids.device)
            else:
                # Group N consecutive lines (EOBs) into one block
                if lines_per_block > 1 and len(eob_positions) >= lines_per_block:
                    # Only use every Nth EOB as a block boundary.
                    # Example: lines_per_block=2, EOBs at [5, 12, 18, 25]
                    # -> keep indices [1, 3] -> positions [12, 25]
                    eob_positions = eob_positions[lines_per_block - 1::lines_per_block]
                # Boundaries: sample start, one past each kept EOB (so the EOB
                # token is included in its block), and the sample end.
                boundaries = torch.cat([
                    torch.tensor([0], device=input_ids.device),
                    eob_positions + 1,  # +1 to include EOB token in block
                    torch.tensor([actual_len], device=input_ids.device)
                ])
                block_lengths = torch.diff(boundaries)
                # Drop 0-length blocks (happens when EOB is the sample's last token)
                block_lengths = block_lengths[block_lengths > 0]
            current_sample_block_lengths.append(block_lengths)
            start_idx = end_idx
        all_batch_block_lengths.append(current_sample_block_lengths)
    return all_batch_block_lengths
def block_diff_mask_dynamic(b, h, q_idx, kv_idx, block_boundaries=None, n=None):
    """
    Dynamic block-diffusion attention mask built from precomputed boundaries.

    Replaces fixed block-size arithmetic with torch.searchsorted so that
    blocks may have different lengths.

    Args:
        b: Batch index (unused in mask logic)
        h: Head index (unused in mask logic)
        q_idx: Query indices tensor
        kv_idx: Key-value indices tensor
        block_boundaries: Cumulative block lengths, e.g. [0, 4, 12, 16]
            (tokens 0-3 -> block 0, 4-11 -> block 1, 12-15 -> block 2)
        n: Number of denoised (clean) tokens

    Returns:
        Boolean attention mask (True = can attend), the union of:
        - M_BD  (Block Diagonal):      self-attention within a block, same kind
        - M_OBC (Offset Block Causal): noisy queries -> clean earlier blocks
        - M_BC  (Block Causal):        clean queries -> clean blocks up to own
    """
    # Clean tokens occupy [n, 2n) and mirror the noisy positions [0, n);
    # modulo n folds both halves onto the same relative coordinate system.
    rel_q = q_idx % n
    rel_kv = kv_idx % n
    # searchsorted(..., right=True) - 1 yields 0-based block ids: with
    # boundaries [0, 4, ...], positions 0..3 land in block 0. Clamp guards
    # any stray position at or beyond the final boundary.
    last_block = len(block_boundaries) - 2
    blk_q = (torch.searchsorted(block_boundaries, rel_q, right=True) - 1).clamp(0, last_block)
    blk_kv = (torch.searchsorted(block_boundaries, rel_kv, right=True) - 1).clamp(0, last_block)
    # x0 flags: indices < n are noisy, indices >= n are clean.
    clean_q = q_idx >= n
    clean_kv = kv_idx >= n
    # M_BD: same block, same kind (noisy->noisy or clean->clean).
    diagonal = (blk_q == blk_kv) & (clean_q == clean_kv)
    # M_OBC: a noisy query attends to clean context from strictly earlier blocks.
    offset_causal = (blk_q > blk_kv) & clean_kv & ~clean_q
    # M_BC: a clean query attends to clean tokens up to and including its block.
    causal = (blk_q >= blk_kv) & clean_kv & clean_q
    return diagonal | offset_causal | causal
def block_attn_mask_dynamic(
    nested_block_lengths_list: List[List[torch.Tensor]],
    device: torch.device
) -> torch.Tensor:
    """
    Construct attention masks for variable-length blocks, handling packed sequences.

    Each sample contributes a 2n x 2n mask (noisy + clean halves) from
    block_diff_mask_dynamic; samples packed into one batch item are combined
    block-diagonally, then batch items are padded to a common size and stacked.

    Args:
        nested_block_lengths_list: List (batch) of Lists (samples) of Tensors
            (block lengths), as produced by calculate_block_nums_from_eob.
        device: Device to create tensors on

    Returns:
        Boolean attention mask tensor, shape (batch_size, max_len*2, max_len*2).
        An empty batch yields an empty (0, 0, 0) tensor instead of raising.
    """
    masks = []
    for sample_block_lengths_list in nested_block_lengths_list:
        sample_masks = []
        for block_lengths in sample_block_lengths_list:
            # Total sequence length for this sample; zero-length samples are skipped.
            total_len = int(block_lengths.sum().item())
            if total_len == 0:
                continue
            n = total_len  # number of clean tokens
            # Block boundaries as a cumulative sum prefixed with 0.
            block_boundaries = torch.cat([
                torch.tensor([0], device=device),
                torch.cumsum(block_lengths, dim=0)
            ])
            # Index grids for the full 2n x 2n (noisy + clean) mask.
            seq_len_doubled = 2 * total_len
            q_idx = torch.arange(seq_len_doubled, device=device)[:, None]
            kv_idx = torch.arange(seq_len_doubled, device=device)[None, :]
            mask = block_diff_mask_dynamic(
                b=None,
                h=None,
                q_idx=q_idx,
                kv_idx=kv_idx,
                block_boundaries=block_boundaries,
                n=n
            )
            sample_masks.append(mask)
        # Packed samples never attend across sample boundaries -> block diagonal.
        if sample_masks:
            row_mask = torch.block_diag(*sample_masks)
        else:
            # Batch item with no (non-empty) samples contributes an empty mask.
            row_mask = torch.zeros((0, 0), device=device, dtype=torch.bool)
        masks.append(row_mask)
    # BUGFIX: an empty batch previously crashed on max([]) below.
    if not masks:
        return torch.zeros((0, 0, 0), device=device, dtype=torch.bool)
    # Pad every row mask with False (no attention) to the largest size so
    # they can be stacked; sizes can differ when padding logic differs.
    max_size = max(m.shape[0] for m in masks)
    padded_masks = []
    for m in masks:
        pad_size = max_size - m.shape[0]
        if pad_size > 0:
            m = torch.nn.functional.pad(m, (0, pad_size, 0, pad_size), value=False)
        padded_masks.append(m)
    return torch.stack(padded_masks, dim=0)