File size: 15,509 Bytes
fb67af8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
"""Transformer block with pre-normalization architecture and memory optimizations.

Critical implementation details:
1. Pre-normalization: RMSNorm BEFORE attention and FFN
2. Residual connections after each sub-layer
3. Modern component stack: RoPE + RMSNorm + SwiGLU
4. Gradient checkpointing support for memory efficiency
"""

import torch
import torch.nn as nn
from typing import Optional, Tuple, Dict, Any
from torch.utils.checkpoint import checkpoint

from .rmsnorm import RMSNorm
from .attention import MultiHeadAttention
from .swiglu import SwiGLU


class TransformerBlock(nn.Module):
    """Single transformer block with pre-normalization.
    
    This follows the modern architecture used in LLaMA, Mistral, etc:
    - Pre-normalization with RMSNorm
    - Multi-head attention with RoPE
    - SwiGLU activation in FFN
    - Residual connections
    - Gradient checkpointing support
    """
    
    def __init__(
        self,
        d_model: int = 768,
        n_heads: int = 12,
        d_ffn: Optional[int] = None,
        dropout: float = 0.1,
        max_seq_len: int = 2048,
        rope_base: int = 10000,
        rope_percentage: float = 0.5,
        rms_norm_eps: float = 1e-6,
        use_flash_attention: bool = True,
        use_gradient_checkpointing: bool = False,
    ):
        super().__init__()
        
        self.d_model = d_model
        self.n_heads = n_heads
        self.use_gradient_checkpointing = use_gradient_checkpointing
        
        # Pre-normalization layers
        self.attn_norm = RMSNorm(d_model, eps=rms_norm_eps)
        self.ffn_norm = RMSNorm(d_model, eps=rms_norm_eps)
        
        # Multi-head attention with RoPE
        self.attention = MultiHeadAttention(
            d_model=d_model,
            n_heads=n_heads,
            dropout=dropout,
            max_seq_len=max_seq_len,
            rope_base=rope_base,
            rope_percentage=rope_percentage,
            use_flash_attention=use_flash_attention,
        )
        
        # SwiGLU FFN
        # Default hidden dimension: 8/3 * d_model for parameter parity
        if d_ffn is None:
            d_ffn = int(8 * d_model / 3)
            # Round to multiple of 256 for hardware efficiency
            d_ffn = 256 * ((d_ffn + 255) // 256)
        
        self.ffn = SwiGLU(
            input_dim=d_model,
            hidden_dim=d_ffn,
            output_dim=d_model,
            bias=False,
        )
    
    def _attention_block(
        self,
        x: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        use_cache: bool = False,
        past_kv: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
    ) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]]]:
        """Attention sub-block with pre-norm."""
        # Pre-normalization
        x_norm = self.attn_norm(x)
        
        # Multi-head attention
        attn_output, kv_cache = self.attention(
            x_norm,
            attention_mask=attention_mask,
            position_ids=position_ids,
            use_cache=use_cache,
            past_kv=past_kv,
        )
        
        # Residual connection
        return attn_output, kv_cache
    
    def _ffn_block(self, x: torch.Tensor) -> torch.Tensor:
        """Feed-forward sub-block with pre-norm."""
        # Pre-normalization
        x_norm = self.ffn_norm(x)
        
        # Feed-forward
        ffn_output = self.ffn(x_norm)
        
        return ffn_output
    
    def forward(
        self,
        x: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        use_cache: bool = False,
        past_kv: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
    ) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]]]:
        """Forward pass of transformer block.
        
        Args:
            x: Input tensor [batch, seq_len, d_model]
            attention_mask: Optional attention mask
            position_ids: Optional position IDs for RoPE
            use_cache: Whether to return KV cache
            past_kv: Past key-value cache
        
        Returns:
            Output tensor and optional KV cache
        """
        # Attention block with residual
        if self.use_gradient_checkpointing and self.training:
            # Use gradient checkpointing to save memory during training
            def attention_fn(x_in):
                attn_out, _ = self._attention_block(
                    x_in,
                    attention_mask=attention_mask,
                    position_ids=position_ids,
                    use_cache=False,  # Can't use cache with checkpointing
                    past_kv=None,
                )
                return attn_out
            
            attn_output = checkpoint(attention_fn, x, use_reentrant=False)
            kv_cache = None
        else:
            attn_output, kv_cache = self._attention_block(
                x,
                attention_mask=attention_mask,
                position_ids=position_ids,
                use_cache=use_cache,
                past_kv=past_kv,
            )
        
        # Add residual for attention
        x = x + attn_output
        
        # FFN block with residual
        if self.use_gradient_checkpointing and self.training:
            # Use gradient checkpointing for FFN as well
            ffn_output = checkpoint(self._ffn_block, x, use_reentrant=False)
        else:
            ffn_output = self._ffn_block(x)
        
        # Add residual for FFN
        x = x + ffn_output
        
        return x, kv_cache


class WikiMiniModel(nn.Module):
    """Complete WikiMini 95M language model.
    
    Architecture:
    - Token embeddings with weight tying
    - Stack of transformer blocks
    - Final RMSNorm
    - LM head (tied with embeddings)
    """
    
    def __init__(self, config: Dict[str, Any]):
        super().__init__()
        
        # Extract config values with defaults
        self.vocab_size = config.get('vocab_size', 32000)
        self.d_model = config.get('d_model', 768)
        self.n_layers = config.get('n_layers', 12)
        self.n_heads = config.get('n_heads', 12)
        self.d_ffn = config.get('d_ffn', None)
        self.max_seq_len = config.get('max_seq_len', 2048)
        self.dropout = config.get('dropout', 0.1)
        self.rope_percentage = config.get('rope_percentage', 0.5)
        self.rope_base = config.get('rope_base', 10000)
        self.rms_norm_eps = config.get('rms_norm_eps', 1e-6)
        self.tie_embeddings = config.get('tie_embeddings', True)
        self.use_flash_attention = config.get('use_flash_attention', True)
        self.use_gradient_checkpointing = config.get('gradient_checkpointing', False)
        
        # Token embeddings
        self.token_embedding = nn.Embedding(self.vocab_size, self.d_model)
        
        # Transformer blocks
        self.blocks = nn.ModuleList([
            TransformerBlock(
                d_model=self.d_model,
                n_heads=self.n_heads,
                d_ffn=self.d_ffn,
                dropout=self.dropout,
                max_seq_len=self.max_seq_len,
                rope_base=self.rope_base,
                rope_percentage=self.rope_percentage,
                rms_norm_eps=self.rms_norm_eps,
                use_flash_attention=self.use_flash_attention,
                use_gradient_checkpointing=self.use_gradient_checkpointing,
            )
            for _ in range(self.n_layers)
        ])
        
        # Final normalization
        self.final_norm = RMSNorm(self.d_model, eps=self.rms_norm_eps)
        
        # Language modeling head
        self.lm_head = nn.Linear(self.d_model, self.vocab_size, bias=False)
        
        # Weight tying
        if self.tie_embeddings:
            self.lm_head.weight = self.token_embedding.weight
        
        # Initialize weights
        self._init_weights()
        
    def _init_weights(self):
        """Initialize weights with scaled normal distribution."""
        for module in self.modules():
            if isinstance(module, nn.Linear):
                torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
                if module.bias is not None:
                    torch.nn.init.zeros_(module.bias)
            elif isinstance(module, nn.Embedding):
                torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
    
    def enable_gradient_checkpointing(self):
        """Enable gradient checkpointing for all transformer blocks."""
        self.use_gradient_checkpointing = True
        for block in self.blocks:
            block.use_gradient_checkpointing = True
    
    def disable_gradient_checkpointing(self):
        """Disable gradient checkpointing for all transformer blocks."""
        self.use_gradient_checkpointing = False
        for block in self.blocks:
            block.use_gradient_checkpointing = False

    def count_parameters(self) -> dict:
        """Count model parameters by component.

        Returns:
            Dictionary with parameter counts for each component
        """
        # Count by component type
        embedding_params = sum(p.numel() for p in self.token_embedding.parameters())

        attention_params = 0
        ffn_params = 0
        norm_params = 0

        for block in self.blocks:
            # Attention parameters
            attention_params += sum(p.numel() for p in block.attention.parameters())
            # FFN parameters
            ffn_params += sum(p.numel() for p in block.ffn.parameters())
            # Norm parameters (attention + ffn norms)
            norm_params += sum(p.numel() for p in block.attn_norm.parameters())
            norm_params += sum(p.numel() for p in block.ffn_norm.parameters())

        # Final norm
        norm_params += sum(p.numel() for p in self.final_norm.parameters())

        # LM head (only if not tied)
        if not self.tie_embeddings:
            lm_head_params = sum(p.numel() for p in self.lm_head.parameters())
        else:
            lm_head_params = 0  # Shared with embeddings

        total_params = sum(p.numel() for p in self.parameters())

        return {
            'total': total_params,
            'total_millions': total_params / 1e6,
            'embedding': embedding_params,
            'attention': attention_params,
            'ffn': ffn_params,
            'norm': norm_params,
            'lm_head': lm_head_params,
        }
    
    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        use_cache: bool = False,
        past_key_values: Optional[list] = None,
    ) -> Dict[str, torch.Tensor]:
        """Forward pass of the model.

        Args:
            input_ids: Token IDs [batch, seq_len]
            attention_mask: Optional attention mask
            position_ids: Optional position IDs
            labels: Optional labels for language modeling loss
            use_cache: Whether to return KV cache
            past_key_values: Past KV cache for inference

        Returns:
            Dictionary with 'logits' and optionally 'loss' and 'past_key_values'
        """
        batch_size, seq_len = input_ids.shape

        # Token embeddings
        x = self.token_embedding(input_ids)
        
        # Apply dropout to embeddings
        x = nn.functional.dropout(x, p=self.dropout, training=self.training)

        # Process through transformer blocks
        past_key_values_out = [] if use_cache else None

        for i, block in enumerate(self.blocks):
            # Get past KV for this layer if available
            past_kv = past_key_values[i] if past_key_values is not None else None

            # Process through block
            x, kv_cache = block(
                x,
                attention_mask=attention_mask,
                position_ids=position_ids,
                use_cache=use_cache,
                past_kv=past_kv,
            )

            # Store KV cache if needed
            if use_cache:
                past_key_values_out.append(kv_cache)

        # Final normalization
        x = self.final_norm(x)

        # Language modeling head
        logits = self.lm_head(x)
        
        # Prepare output
        output = {'logits': logits}

        # Calculate loss if labels provided
        if labels is not None:
            # Shift for next-token prediction
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()

            # Flatten for cross-entropy
            shift_logits = shift_logits.view(-1, self.vocab_size)
            shift_labels = shift_labels.view(-1)

            # Calculate cross-entropy loss
            loss = nn.functional.cross_entropy(
                shift_logits,
                shift_labels,
                ignore_index=-100,  # Standard ignore index
            )

            output['loss'] = loss

        # Add KV cache to output if requested
        if use_cache:
            output['past_key_values'] = past_key_values_out

        return output


def create_model(config: Dict[str, Any]) -> WikiMiniModel:
    """Create a WikiMini model from configuration.
    
    Args:
        config: Model configuration dictionary
    
    Returns:
        WikiMiniModel instance
    """
    return WikiMiniModel(config)


# Test the complete model
if __name__ == "__main__":
    # Test configuration for ~95M parameters
    config = {
        'vocab_size': 32000,
        'd_model': 768,
        'n_layers': 12,
        'n_heads': 12,
        'd_ffn': 2048,  # Adjusted for SwiGLU
        'max_seq_len': 2048,
        'dropout': 0.1,
        'rope_percentage': 0.5,
        'rope_base': 10000,
        'rms_norm_eps': 1e-6,
        'tie_embeddings': True,
        'use_flash_attention': True,
        'gradient_checkpointing': True,  # Enable for memory efficiency
    }
    
    # Create model
    model = WikiMiniModel(config)
    
    # Count parameters
    total_params = sum(p.numel() for p in model.parameters())
    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    
    print(f"WikiMini Model:")
    print(f"  Total parameters: {total_params:,} ({total_params/1e6:.2f}M)")
    print(f"  Trainable parameters: {trainable_params:,} ({trainable_params/1e6:.2f}M)")
    print(f"  Layers: {model.n_layers}")
    print(f"  Hidden size: {model.d_model}")
    print(f"  Attention heads: {model.n_heads}")
    
    # Test forward pass
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = model.to(device)
    model.eval()
    
    # Small test batch
    batch_size = 2
    seq_len = 128
    input_ids = torch.randint(0, config['vocab_size'], (batch_size, seq_len), device=device)
    
    # Enable gradient checkpointing
    model.enable_gradient_checkpointing()
    
    # Forward pass
    with torch.no_grad():
        outputs = model(input_ids=input_ids)
    
    print(f"\nTest forward pass:")
    print(f"  Input shape: {input_ids.shape}")
    print(f"  Output logits shape: {outputs['logits'].shape}")
    print(f"  Device: {device}")
    
    if torch.cuda.is_available():
        print(f"  Memory allocated: {torch.cuda.memory_allocated(device) / 1024**3:.2f} GB")
    
    print("\n✓ Model test passed!")