"""Multi-Head Attention with RoPE integration and memory optimizations.

Critical implementation details:
1. Apply RoPE only to Q and K, never to V
2. Use SDPA for Flash Attention 2 support
3. Pre-normalization architecture
4. Memory-efficient implementation
"""

import logging
import math
import sys
from typing import Optional, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F

from .rope import RotaryPositionEmbeddings

logger = logging.getLogger(__name__)


class MultiHeadAttention(nn.Module):
    """Multi-Head Attention with RoPE and Flash Attention support.
    
    This implementation:
    - Uses Rotary Position Embeddings (RoPE) on Q and K only
    - Supports Flash Attention 2 via torch.nn.functional.scaled_dot_product_attention
    - Uses no bias terms (modern approach)
    - Includes proper causal masking
    - Memory-efficient implementation
    """
    
    def __init__(
        self,
        d_model: int = 768,
        n_heads: int = 12,
        dropout: float = 0.1,
        max_seq_len: int = 2048,
        rope_base: int = 10000,
        rope_percentage: float = 0.5,
        use_flash_attention: bool = True,
    ):
        super().__init__()
        
        assert d_model % n_heads == 0, f"d_model ({d_model}) must be divisible by n_heads ({n_heads})"

        self.d_model = d_model
        self.n_heads = n_heads
        self.head_dim = d_model // n_heads

        # Windows Flash Attention: test with PyTorch 2.10+ nightly.
        # Older versions had freezing issues, but newer versions may work.

        if sys.platform == 'win32' and use_flash_attention:
            # Allow Flash Attention on Windows with PyTorch 2.10+
            # If freezing occurs, set use_flash_attention: false in config
            self.use_flash_attention = use_flash_attention
            logger.info("[Windows] Attempting Flash Attention with PyTorch 2.10+ - if freezing occurs, disable in config")
        elif sys.platform == 'win32':
            self.use_flash_attention = False
            logger.info("[Windows] Flash Attention disabled - using manual attention")
        else:
            self.use_flash_attention = use_flash_attention

        self.dropout = dropout
        self.scale = 1.0 / math.sqrt(self.head_dim)
        
        # Q, K, V projections (no bias)
        self.q_proj = nn.Linear(d_model, d_model, bias=False)
        self.k_proj = nn.Linear(d_model, d_model, bias=False)
        self.v_proj = nn.Linear(d_model, d_model, bias=False)
        self.o_proj = nn.Linear(d_model, d_model, bias=False)
        
        # RoPE for positional encoding
        # Applied to only part of each head's dimensions (typically 50%)
        rope_dim = int(self.head_dim * rope_percentage)
        self.rope_dim = rope_dim
        self.rope = RotaryPositionEmbeddings(
            head_dim=rope_dim,
            max_seq_len=max_seq_len,
            base=rope_base
        )
        
        # Dropout
        self.attn_dropout = nn.Dropout(dropout)
        self.resid_dropout = nn.Dropout(dropout)
        
        # Causal mask for the manual-attention fallback: cached lazily and
        # (re)built on demand based on the sequence length actually seen
        self.register_buffer('cached_mask', None, persistent=False)
        self.register_buffer('cached_mask_size', torch.tensor(0), persistent=False)
    
    def _get_causal_mask(self, seq_len: int, device: torch.device) -> torch.Tensor:
        """Get or create causal mask for the given sequence length.

        CRITICAL: Always returns mask on the specified device to prevent CPU OOM errors.
        """
        if self.cached_mask is None or self.cached_mask_size < seq_len:
            # Create a new mask directly on the target device
            mask = torch.triu(torch.ones(seq_len, seq_len, device=device), diagonal=1)
            mask = mask.masked_fill(mask == 1, float('-inf'))
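            # For example, for seq_len=3 the additive mask looks like:
            #   [[0., -inf, -inf],
            #    [0.,   0., -inf],
            #    [0.,   0.,   0.]]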
            self.cached_mask = mask
            self.cached_mask_size = torch.tensor(seq_len)

        # CRITICAL: Ensure the returned mask is on the correct device
        # This prevents CPU OOM when broadcasting during attn_scores + causal_mask
        return self.cached_mask[:seq_len, :seq_len].to(device)
    
    def _apply_rope(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        position_ids: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Apply RoPE to partial dimensions of Q and K.
        
        Args:
            q: Query tensor [batch, seq_len, n_heads, head_dim]
            k: Key tensor [batch, seq_len, n_heads, head_dim]
            position_ids: Optional custom position IDs
        
        Returns:
            Rotated Q and K tensors
        """
        # Split into RoPE and pass-through dimensions
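        # For example (hypothetical numbers): head_dim=64, rope_percentage=0.5 gives
        # rope_dim=32, so q[..., :32] / k[..., :32] are rotated while q[..., 32:] /
        # k[..., 32:] pass through unchanged; V is never rotated.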
        if self.rope_dim > 0:
            q_rope, q_pass = q[..., :self.rope_dim], q[..., self.rope_dim:]
            k_rope, k_pass = k[..., :self.rope_dim], k[..., self.rope_dim:]
            
            # Apply RoPE to the first part
            q_rope, k_rope = self.rope(q_rope, k_rope, position_ids)
            
            # Concatenate back
            q = torch.cat([q_rope, q_pass], dim=-1)
            k = torch.cat([k_rope, k_pass], dim=-1)
        
        return q, k
    
    def forward(
        self,
        x: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        use_cache: bool = False,
        past_kv: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
    ) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]]]:
        """Forward pass of multi-head attention.

        Args:
            x: Input tensor [batch, seq_len, d_model]
            attention_mask: Optional attention mask
            position_ids: Optional position IDs for RoPE
            use_cache: Whether to return KV cache for inference
            past_kv: Past key-value cache for inference

        Returns:
            Output tensor and optional KV cache
        """
        batch_size, seq_len, _ = x.size()

        # Project to Q, K, V
        q = self.q_proj(x)  # [batch, seq_len, d_model]
        k = self.k_proj(x)  # [batch, seq_len, d_model]
        v = self.v_proj(x)  # [batch, seq_len, d_model]

        # Reshape for multi-head attention
        # [batch, seq_len, d_model] -> [batch, seq_len, n_heads, head_dim]
        q = q.view(batch_size, seq_len, self.n_heads, self.head_dim)
        k = k.view(batch_size, seq_len, self.n_heads, self.head_dim)
        v = v.view(batch_size, seq_len, self.n_heads, self.head_dim)

        # Apply RoPE to Q and K only (not V!)
        q, k = self._apply_rope(q, k, position_ids)
        
        # Handle KV cache for inference
        if use_cache and past_kv is not None:
            past_k, past_v = past_kv
            k = torch.cat([past_k, k], dim=1)
            v = torch.cat([past_v, v], dim=1)
        
        kv_cache = (k, v) if use_cache else None
        
        # Transpose for attention computation
        # [batch, seq_len, n_heads, head_dim] -> [batch, n_heads, seq_len, head_dim]
        q = q.transpose(1, 2).contiguous()
        k = k.transpose(1, 2).contiguous()
        v = v.transpose(1, 2).contiguous()

        # Use Flash Attention 2 via SDPA when available
        # This is MUCH more memory efficient than manual attention
        if self.use_flash_attention and hasattr(F, 'scaled_dot_product_attention'):
            # Flash Attention 2 is automatically used when available
            # It handles the causal mask internally when is_causal=True
            # NOTE: Windows compatibility - skip context manager to avoid freezing
            if sys.platform == 'win32':
                # On Windows, use SDPA without explicit kernel selection
                attn_output = F.scaled_dot_product_attention(
                    q, k, v,
                    attn_mask=attention_mask,
                    dropout_p=self.dropout if self.training else 0.0,
                    is_causal=True if attention_mask is None else False,
                    scale=self.scale,
                )
            else:
                # On Linux, use explicit kernel selection for best performance.
                # NOTE: torch.backends.cuda.sdp_kernel is deprecated on recent PyTorch
                # releases in favor of torch.nn.attention.sdpa_kernel; it still works
                # but may emit a deprecation warning.
                with torch.backends.cuda.sdp_kernel(
                    enable_flash=True,  # Use Flash Attention when possible
                    enable_math=True,   # Fallback to math implementation
                    enable_mem_efficient=True  # Use memory-efficient attention
                ):
                    attn_output = F.scaled_dot_product_attention(
                        q, k, v,
                        attn_mask=attention_mask,
                        dropout_p=self.dropout if self.training else 0.0,
                        is_causal=True if attention_mask is None else False,
                        scale=self.scale,
                    )
        else:
            # Manual attention computation (fallback).
            # This is memory-intensive and should only be used for short sequences.
            # Note: the cached causal mask below assumes q and k have the same length,
            # so this path does not support the KV cache.
            attn_scores = torch.matmul(q, k.transpose(-2, -1)) * self.scale

            # Apply causal mask
            if attention_mask is None:
                causal_mask = self._get_causal_mask(seq_len, x.device)
                # Expand mask for batch and heads
                causal_mask = causal_mask.unsqueeze(0).unsqueeze(0)
                attn_scores = attn_scores + causal_mask
            else:
                attn_scores = attn_scores + attention_mask

            # Apply softmax
            attn_weights = F.softmax(attn_scores, dim=-1, dtype=torch.float32).to(q.dtype)
            attn_weights = self.attn_dropout(attn_weights)

            # Compute output
            attn_output = torch.matmul(attn_weights, v)
        
        # Reshape back
        # [batch, n_heads, seq_len, head_dim] -> [batch, seq_len, d_model]
        attn_output = attn_output.transpose(1, 2).contiguous()
        attn_output = attn_output.view(batch_size, seq_len, self.d_model)

        # Output projection
        output = self.o_proj(attn_output)
        output = self.resid_dropout(output)

        return output, kv_cache


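# A minimal consistency sketch (not called by the test below; it assumes the
# RotaryPositionEmbeddings module from .rope is importable and deterministic). It runs
# the same input through the SDPA path and the manual fallback and checks that the two
# agree. Dropout is 0.0 and eval mode is used so both paths are deterministic; toggling
# `use_flash_attention` directly works only because forward() reads it as a plain flag.
def _check_sdpa_vs_manual_paths():
    torch.manual_seed(0)
    attn = MultiHeadAttention(d_model=256, n_heads=8, dropout=0.0, max_seq_len=64)
    attn.eval()
    x = torch.randn(2, 16, 256)

    with torch.no_grad():
        attn.use_flash_attention = True
        out_sdpa, _ = attn(x)

        attn.use_flash_attention = False
        out_manual, _ = attn(x)

    assert torch.allclose(out_sdpa, out_manual, atol=1e-5), \
        "SDPA and manual attention paths diverged"
    print("✓ SDPA and manual attention paths agree")
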
# Test the attention implementation
def test_attention():
    """Test multi-head attention with various configurations."""
    print("Testing Multi-Head Attention...")
    
    # Test configuration
    batch_size = 2
    seq_len = 128
    d_model = 768
    n_heads = 12
    
    # Create attention module
    attention = MultiHeadAttention(
        d_model=d_model,
        n_heads=n_heads,
        dropout=0.1,
        max_seq_len=2048,
        rope_percentage=0.5,
        use_flash_attention=True,  # Enable Flash Attention
    )
    
    # Move to GPU if available (bfloat16 on CUDA, float32 on CPU)
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    dtype = torch.bfloat16 if device.type == 'cuda' else torch.float32
    attention = attention.to(device=device, dtype=dtype)
    attention.eval()  # Set to eval mode for testing

    # Create dummy input (must match the module's dtype, otherwise the projections fail)
    x = torch.randn(batch_size, seq_len, d_model, device=device, dtype=dtype)
    
    # Forward pass
    with torch.no_grad():
        output, _ = attention(x)
    
    # Check output shape
    assert output.shape == (batch_size, seq_len, d_model), \
        f"Expected shape {(batch_size, seq_len, d_model)}, got {output.shape}"
    
    # Check for NaN
    assert not torch.isnan(output).any(), "Output contains NaN values!"
    
    print("✓ Multi-Head Attention test passed!")
    print(f"  Input shape: {x.shape}")
    print(f"  Output shape: {output.shape}")
    print(f"  Device: {device}")
    print(f"  Memory allocated: {torch.cuda.memory_allocated(device) / 1024**3:.2f} GB")
    
    return True


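# A hedged usage sketch of the KV-cache decode path (not called by default). It prefills
# a short prompt, then feeds one new token together with the returned cache. The decode
# step passes an all-zero additive attention mask (so the new token can attend to every
# cached position) and explicit position_ids; the [batch, seq_len] LongTensor format for
# position_ids is an assumption about the RotaryPositionEmbeddings API, not something
# this file guarantees.
def _example_kv_cache_decoding():
    attn = MultiHeadAttention(d_model=256, n_heads=8, dropout=0.0)
    attn.eval()

    prompt_len = 8
    prompt = torch.randn(1, prompt_len, 256)

    with torch.no_grad():
        # Prefill: standard causal attention; cache is returned for reuse.
        _, kv_cache = attn(prompt, use_cache=True)

        # Decode one token: cached keys/values are concatenated inside forward().
        new_token = torch.randn(1, 1, 256)
        zero_mask = torch.zeros(1, 1, 1, prompt_len + 1)   # no masking needed for a single query
        position_ids = torch.tensor([[prompt_len]])        # assumed format, see note above
        out, kv_cache = attn(
            new_token,
            attention_mask=zero_mask,
            position_ids=position_ids,
            use_cache=True,
            past_kv=kv_cache,
        )

    print(f"decode output: {tuple(out.shape)}, cached keys: {tuple(kv_cache[0].shape)}")
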
if __name__ == "__main__":
    test_attention()