File size: 9,266 Bytes
4fb7fbd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
"""
Utility functions for dynamic block training with variable-length blocks.

This module provides functions to extract block boundaries from <EOB> tokens
and generate attention masks for variable-length blocks.
"""

import torch
from typing import List, Tuple


def calculate_block_nums_from_eob(
    input_ids: torch.Tensor,
    num_tokens_list: List[List[int]],
    eob_token_id: int,
    lines_per_block: int = 1
) -> List[List[torch.Tensor]]:
    """
    Extract variable block lengths from <EOB> token positions, respecting packed sample boundaries.
    
    Args:
        input_ids: Token IDs tensor, shape (batch_size, seq_len)
        num_tokens_list: List of lists, where each inner list contains sequence lengths for a batch item.
                        (Output from calculate_token_nums)
        eob_token_id: Token ID for <EOB>
        lines_per_block: Number of lines (EOBs) to group into a single block.
                        Default 1 means each line is its own block.
                        Set to 2 or 3 to have multiple EOBs within a block,
                        allowing the model to learn EOB positions within blocks.
        
    Returns:
        List of lists of tensors. Outer list is batch. Inner list is samples.
        Each tensor contains the block lengths for that sample; the lengths sum to
        the number of tokens actually present for the sample (clamped to seq_len).
    """
    batch_size, seq_len = input_ids.shape
    all_batch_block_lengths: List[List[torch.Tensor]] = []
    
    for i in range(batch_size):
        current_ids = input_ids[i]
        sample_lengths = num_tokens_list[i]  # declared lengths of the packed samples
        
        current_sample_block_lengths = []
        start_idx = 0
        
        for length in sample_lengths:
            # num_tokens_list entries may be 0-dim tensors; normalize to int
            if isinstance(length, torch.Tensor):
                length = length.item()
            
            if start_idx >= seq_len:
                break
            
            # Clamp the sample end to seq_len in case sum(lengths) overruns the
            # actual packed sequence (padding/truncation differences).
            end_idx = min(start_idx + length, seq_len)
            # BUGFIX: use the clamped length below, not the declared `length`.
            # The declared length may exceed the tokens actually present, which
            # previously produced block lengths that overran the sample.
            actual_len = end_idx - start_idx
            
            sample_ids = current_ids[start_idx:end_idx]
            
            # Positions of <EOB> tokens, relative to the start of this sample
            eob_positions = torch.nonzero(sample_ids == eob_token_id).flatten()
            
            if len(eob_positions) == 0:
                # No EOB tokens: treat the entire sample as one block
                block_lengths = torch.tensor([actual_len], device=input_ids.device)
            else:
                # Group N consecutive lines (EOBs) into one block by keeping only
                # every Nth EOB as a block boundary.
                # Example: lines_per_block=2, EOBs at [5, 12, 18, 25]
                #          -> keep indices [1, 3] -> positions [12, 25]
                if lines_per_block > 1 and len(eob_positions) >= lines_per_block:
                    eob_positions = eob_positions[lines_per_block - 1::lines_per_block]
                
                # Boundaries span the whole sample; each EOB is included in the
                # block it terminates (hence the +1).
                boundaries = torch.cat([
                    torch.tensor([0], device=input_ids.device),
                    eob_positions + 1,
                    torch.tensor([actual_len], device=input_ids.device)
                ])
                block_lengths = torch.diff(boundaries)
            
            # Drop 0-length blocks (an EOB exactly at the sample end creates one)
            block_lengths = block_lengths[block_lengths > 0]
            
            current_sample_block_lengths.append(block_lengths)
            start_idx = end_idx
        
        all_batch_block_lengths.append(current_sample_block_lengths)
    
    return all_batch_block_lengths


def block_diff_mask_dynamic(b, h, q_idx, kv_idx, block_boundaries=None, n=None):
    """
    Dynamic block diffusion mask using precomputed block boundaries.

    Variable-length blocks are supported by mapping each token index to a
    block id via torch.searchsorted instead of fixed block_size arithmetic.

    Args:
        b: Batch index (present for mask-function signature compatibility; unused)
        h: Head index (present for mask-function signature compatibility; unused)
        q_idx: Query indices tensor
        kv_idx: Key-value indices tensor
        block_boundaries: Cumulative sum of block lengths, e.g., [0, 4, 12, 16]
                         mapping tokens 0-3 → block 0, 4-11 → block 1, 12-15 → block 2
        n: Number of denoised (clean) tokens; indices >= n are the clean copy

    Returns:
        Boolean attention mask (True = can attend), the union of:
        - block-diagonal self-attention within a block (matching noisy/clean halves),
        - offset block-causal attention from noisy queries to strictly earlier clean blocks,
        - block-causal attention among clean tokens (same or earlier block).
    """
    num_blocks = len(block_boundaries) - 1

    def _block_of(idx):
        # Fold clean indices (n..2n-1) onto their relative position, then locate
        # the enclosing block. right=True maps e.g. positions 0..3 into [0, 4);
        # the -1 converts interval index to 0-based block id, clamped for safety.
        bid = torch.searchsorted(block_boundaries, idx % n, right=True) - 1
        return torch.clamp(bid, 0, num_blocks - 1)

    blk_q = _block_of(q_idx)
    blk_kv = _block_of(kv_idx)

    # Noisy tokens live in [0, n), clean tokens in [n, 2n)
    clean_q = q_idx >= n
    clean_kv = kv_idx >= n

    # M_BD: self-attention inside a block, noisy->noisy or clean->clean only
    diag = (blk_q == blk_kv) & (clean_q == clean_kv)

    # M_OBC: a noisy query attends to clean context in strictly earlier blocks
    offset_causal = (blk_q > blk_kv) & clean_kv & ~clean_q

    # M_BC: a clean query attends to clean tokens in the same or earlier blocks
    causal = (blk_q >= blk_kv) & clean_q & clean_kv

    return diag | offset_causal | causal


def block_attn_mask_dynamic(
    nested_block_lengths_list: List[List[torch.Tensor]],
    device: torch.device
) -> torch.Tensor:
    """
    Construct attention masks for variable-length blocks, handling packed sequences.
    
    Each packed sample gets its own 2n x 2n mask (n noisy + n clean tokens) from
    block_diff_mask_dynamic; per-sample masks are combined block-diagonally within
    each batch item (packed samples never attend across sample boundaries), then
    padded with False to a common size and stacked.
    
    Args:
        nested_block_lengths_list: List (batch) of Lists (samples) of Tensors (block lengths).
        device: Device to create tensors on
        
    Returns:
        Boolean attention mask tensor, shape (batch_size, max_total_len*2, max_total_len*2).
        An empty batch yields a (0, 0, 0) boolean tensor.
    """
    masks = []
    
    for sample_block_lengths_list in nested_block_lengths_list:
        sample_masks = []
        
        for block_lengths in sample_block_lengths_list:
            # Total sequence length for this sample
            total_len = block_lengths.sum().item()
            if total_len == 0:
                # Skip degenerate samples with no tokens
                continue
                
            n = total_len  # Number of clean tokens
            
            # Cumulative block boundaries, e.g. lengths [4, 8, 4] -> [0, 4, 12, 16]
            block_boundaries = torch.cat([
                torch.tensor([0], device=device),
                torch.cumsum(block_lengths, dim=0)
            ])
            
            # Index grids for the full 2n x 2n (noisy + clean) mask
            seq_len_doubled = total_len * 2
            q_idx = torch.arange(seq_len_doubled, device=device)[:, None]
            kv_idx = torch.arange(seq_len_doubled, device=device)[None, :]
            
            # Generate mask using dynamic block boundaries
            mask = block_diff_mask_dynamic(
                b=None,
                h=None,
                q_idx=q_idx,
                kv_idx=kv_idx,
                block_boundaries=block_boundaries,
                n=n
            )
            
            sample_masks.append(mask)
        
        # Combine sample masks into a single block-diagonal mask for the batch item
        if sample_masks:
            row_mask = torch.block_diag(*sample_masks)
        else:
            # No valid samples for this batch item; empty mask, padded below
            row_mask = torch.zeros((0, 0), device=device, dtype=torch.bool)
            
        masks.append(row_mask)
    
    # BUGFIX: an empty batch previously raised ValueError at max() on an empty
    # sequence; return an empty boolean mask instead.
    if not masks:
        return torch.zeros((0, 0, 0), device=device, dtype=torch.bool)
    
    # Row masks can differ in size when packed lengths differ across batch items
    # (calculate_token_nums usually covers the whole seq_len including padding,
    # so sizes typically already match). Pad with False (no attention) to stack.
    max_size = max(m.shape[0] for m in masks)
    
    padded_masks = []
    for m in masks:
        pad_size = max_size - m.shape[0]
        if pad_size > 0:
            m = torch.nn.functional.pad(m, (0, pad_size, 0, pad_size), value=False)
        padded_masks.append(m)
        
    return torch.stack(padded_masks, dim=0)