Architecture code included.
Browse files- .gitattributes +1 -0
- architecture/README.md +82 -0
- architecture/__init__.py +2 -0
- architecture/architecture.png +3 -0
- architecture/gemma3.py +130 -0
- architecture/model_config.py +16 -0
.gitattributes
CHANGED
|
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
| 36 |
+
architecture/architecture.png filter=lfs diff=lfs merge=lfs -text
|
architecture/README.md
ADDED
|
@@ -0,0 +1,82 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Architecture Module
|
| 2 |
+
|
| 3 |
+
This module contains the main Gemma3 model implementation and configuration management.
|
| 4 |
+
|
| 5 |
+
## Files
|
| 6 |
+
|
| 7 |
+
### `gemma3.py`
|
| 8 |
+
The core Gemma3Model class implementation featuring:
|
| 9 |
+
|
| 10 |
+
- **Token Embeddings**: Scaled embedding layer with vocabulary size of 50,257
|
| 11 |
+
- **Transformer Blocks**: 18 layers with mixed attention patterns (sliding window and full attention)
|
| 12 |
+
- **Dual RoPE**: Two sets of rotary position embeddings for local and global context
|
| 13 |
+
- **Attention Masks**: Dynamic generation of causal and sliding window masks
|
| 14 |
+
- **Output Head**: Linear projection to vocabulary size for next-token prediction
|
| 15 |
+
- **Generation Method**: Temperature-controlled sampling with top-k filtering
|
| 16 |
+
|
| 17 |
+
Key components:
|
| 18 |
+
- `__init__`: Initializes model layers, embeddings, and precomputes RoPE parameters
|
| 19 |
+
- `_create_masks`: Generates causal and sliding window attention masks
|
| 20 |
+
- `forward`: Main forward pass with optional loss computation
|
| 21 |
+
- `generate`: Autoregressive text generation with temperature and top-k sampling
|
| 22 |
+
|
| 23 |
+
### `model_config.py`
|
| 24 |
+
Configuration loader that reads model hyperparameters from `config/model_config.json`.
|
| 25 |
+
|
| 26 |
+
### `__init__.py`
|
| 27 |
+
Module initialization that exports:
|
| 28 |
+
- `model_config`: Dictionary containing all model hyperparameters
|
| 29 |
+
- `Gemma3Model`: The main model class
|
| 30 |
+
|
| 31 |
+
## Model Architecture Details
|
| 32 |
+
|
| 33 |
+
### Layer Configuration
|
| 34 |
+
The model uses a strategic mix of attention types across 18 layers:
|
| 35 |
+
- **Layers 1-5**: Sliding window attention (512 token window)
|
| 36 |
+
- **Layer 6**: Full attention (checkpoint layer)
|
| 37 |
+
- **Layers 7-11**: Sliding window attention
|
| 38 |
+
- **Layer 12**: Full attention (checkpoint layer)
|
| 39 |
+
- **Layers 13-17**: Sliding window attention
|
| 40 |
+
- **Layer 18**: Full attention (final layer)
|
| 41 |
+
|
| 42 |
+
This pattern allows the model to:
|
| 43 |
+
- Efficiently process local context with sliding windows
|
| 44 |
+
- Capture long-range dependencies at strategic checkpoints
|
| 45 |
+
- Balance computational efficiency with modeling capability
|
| 46 |
+
|
| 47 |
+
### Embedding and Normalization
|
| 48 |
+
- **Embedding Scaling**: Input embeddings are scaled by √(embedding_dim) for training stability
|
| 49 |
+
- **Final Normalization**: RMS normalization before the output projection
|
| 50 |
+
- **Weight Tying**: Output projection weights are separate from input embeddings
|
| 51 |
+
|
| 52 |
+
### Position Encoding
|
| 53 |
+
The model uses dual RoPE (Rotary Position Embeddings):
|
| 54 |
+
- **Local RoPE**: θ_base = 10,000 for sliding window attention
|
| 55 |
+
- **Global RoPE**: θ_base = 1,000,000 for full attention layers
|
| 56 |
+
|
| 57 |
+
This dual approach allows different attention patterns to use position encodings optimized for their respective context ranges.
|
| 58 |
+
|
| 59 |
+
## Usage Example
|
| 60 |
+
|
| 61 |
+
```python
|
| 62 |
+
from architecture import Gemma3Model, model_config
|
| 63 |
+
import torch
|
| 64 |
+
|
| 65 |
+
# Initialize model
|
| 66 |
+
model = Gemma3Model(model_config)
|
| 67 |
+
|
| 68 |
+
# Forward pass
|
| 69 |
+
input_ids = torch.randint(0, 50257, (2, 128)) # batch_size=2, seq_len=128
|
| 70 |
+
logits, loss = model(input_ids, targets=None)
|
| 71 |
+
|
| 72 |
+
# Generation
|
| 73 |
+
prompt = torch.randint(0, 50257, (1, 10)) # Single prompt
|
| 74 |
+
generated = model.generate(prompt, max_new_tokens=50, temperature=0.8, top_k=40)
|
| 75 |
+
```
|
| 76 |
+
|
| 77 |
+
## Design Decisions
|
| 78 |
+
|
| 79 |
+
1. **Mixed Attention**: Combines efficiency of sliding windows with the modeling power of full attention
|
| 80 |
+
2. **Separate RoPE Bases**: Optimizes position encoding for different attention ranges
|
| 81 |
+
3. **Grouped Query Attention**: Reduces KV cache memory while maintaining performance
|
| 82 |
+
4. **Gemma3-style Normalization**: Uses (1 + weight) scaling for better training dynamics
|
architecture/__init__.py
ADDED
|
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from .gemma3 import Gemma3Model
|
| 2 |
+
from .model_config import model_config
|
architecture/architecture.png
ADDED
|
Git LFS Details
|
architecture/gemma3.py
ADDED
|
@@ -0,0 +1,130 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os, sys
|
| 2 |
+
from os.path import dirname as up
|
| 3 |
+
|
| 4 |
+
sys.path.append(os.path.abspath(os.path.join(up(__file__), os.pardir)))
|
| 5 |
+
|
| 6 |
+
import torch
|
| 7 |
+
import torch.nn as nn
|
| 8 |
+
import torch.nn.functional as F
|
| 9 |
+
|
| 10 |
+
from block.transformer import TransformerBlock
|
| 11 |
+
from block.rms_norm import RMSNorm
|
| 12 |
+
from block.rope import compute_rope_params
|
| 13 |
+
|
| 14 |
+
class Gemma3Model(nn.Module):
|
| 15 |
+
def __init__(self, cfg):
|
| 16 |
+
super().__init__()
|
| 17 |
+
assert cfg["layer_types"] is not None and len(cfg["layer_types"]) == cfg["n_layers"]
|
| 18 |
+
|
| 19 |
+
# Main model parameters
|
| 20 |
+
self.tok_emb = nn.Embedding(cfg["vocab_size"], cfg["emb_dim"], dtype=cfg["dtype"])
|
| 21 |
+
|
| 22 |
+
self.blocks = nn.ModuleList([
|
| 23 |
+
TransformerBlock(cfg, attn_type)for attn_type in cfg["layer_types"]
|
| 24 |
+
])
|
| 25 |
+
|
| 26 |
+
self.final_norm = RMSNorm(cfg["emb_dim"], eps=1e-6)
|
| 27 |
+
self.out_head = nn.Linear(cfg["emb_dim"], cfg["vocab_size"], bias=False, dtype=cfg["dtype"])
|
| 28 |
+
self.cfg = cfg
|
| 29 |
+
|
| 30 |
+
# Reusuable utilities
|
| 31 |
+
cos_local, sin_local = compute_rope_params(
|
| 32 |
+
head_dim=cfg["head_dim"],
|
| 33 |
+
theta_base=cfg["rope_local_base"],
|
| 34 |
+
context_length=cfg["context_length"],
|
| 35 |
+
dtype=torch.float32,
|
| 36 |
+
)
|
| 37 |
+
cos_global, sin_global = compute_rope_params(
|
| 38 |
+
head_dim=cfg["head_dim"],
|
| 39 |
+
theta_base=cfg["rope_base"],
|
| 40 |
+
context_length=cfg["context_length"],
|
| 41 |
+
dtype=torch.float32,
|
| 42 |
+
)
|
| 43 |
+
self.register_buffer("cos_local", cos_local, persistent=False)
|
| 44 |
+
self.register_buffer("sin_local", sin_local, persistent=False)
|
| 45 |
+
self.register_buffer("cos_global", cos_global, persistent=False)
|
| 46 |
+
self.register_buffer("sin_global", sin_global, persistent=False)
|
| 47 |
+
|
| 48 |
+
def _create_masks(self, seq_len, device):
|
| 49 |
+
ones = torch.ones((seq_len, seq_len), dtype=torch.bool, device=device)
|
| 50 |
+
|
| 51 |
+
# mask_global (future is masked: j > i)
|
| 52 |
+
# j: 0 1 2 3 4 5 6 7
|
| 53 |
+
# i
|
| 54 |
+
# 0: 0 1 1 1 1 1 1 1
|
| 55 |
+
# 1: 0 0 1 1 1 1 1 1
|
| 56 |
+
# 2: 0 0 0 1 1 1 1 1
|
| 57 |
+
# 3: 0 0 0 0 1 1 1 1
|
| 58 |
+
# 4: 0 0 0 0 0 1 1 1
|
| 59 |
+
# 5: 0 0 0 0 0 0 1 1
|
| 60 |
+
# 6: 0 0 0 0 0 0 0 1
|
| 61 |
+
# 7: 0 0 0 0 0 0 0 0
|
| 62 |
+
mask_global = torch.triu(ones, diagonal=1)
|
| 63 |
+
|
| 64 |
+
# far_past (too far back is masked: i - j >= sliding_window)
|
| 65 |
+
# where sliding_window = 4
|
| 66 |
+
# j: 0 1 2 3 4 5 6 7
|
| 67 |
+
# i
|
| 68 |
+
# 0: 0 0 0 0 0 0 0 0
|
| 69 |
+
# 1: 0 0 0 0 0 0 0 0
|
| 70 |
+
# 2: 0 0 0 0 0 0 0 0
|
| 71 |
+
# 3: 0 0 0 0 0 0 0 0
|
| 72 |
+
# 4: 1 0 0 0 0 0 0 0
|
| 73 |
+
# 5: 1 1 0 0 0 0 0 0
|
| 74 |
+
# 6: 1 1 1 0 0 0 0 0
|
| 75 |
+
# 7: 1 1 1 1 0 0 0 0
|
| 76 |
+
far_past = torch.triu(ones, diagonal=self.cfg["sliding_window"]).T
|
| 77 |
+
|
| 78 |
+
# Local (sliding_window) = future OR far-past
|
| 79 |
+
# mask_local
|
| 80 |
+
# j: 0 1 2 3 4 5 6 7
|
| 81 |
+
# i
|
| 82 |
+
# 0: 0 1 1 1 1 1 1 1
|
| 83 |
+
# 1: 0 0 1 1 1 1 1 1
|
| 84 |
+
# 2: 0 0 0 1 1 1 1 1
|
| 85 |
+
# 3: 0 0 0 0 1 1 1 1
|
| 86 |
+
# 4: 1 0 0 0 0 1 1 1
|
| 87 |
+
# 5: 1 1 0 0 0 0 1 1
|
| 88 |
+
# 6: 1 1 1 0 0 0 0 1
|
| 89 |
+
# 7: 1 1 1 1 0 0 0 0
|
| 90 |
+
mask_local = mask_global | far_past
|
| 91 |
+
return mask_global, mask_local
|
| 92 |
+
|
| 93 |
+
def forward(self, input_ids, targets=None):
|
| 94 |
+
b, seq_len = input_ids.shape
|
| 95 |
+
x = self.tok_emb(input_ids) * (self.cfg["emb_dim"] ** 0.5)
|
| 96 |
+
mask_global, mask_local = self._create_masks(seq_len, x.device)
|
| 97 |
+
|
| 98 |
+
for block in self.blocks:
|
| 99 |
+
x = block(
|
| 100 |
+
x,
|
| 101 |
+
mask_global=mask_global,
|
| 102 |
+
mask_local=mask_local,
|
| 103 |
+
cos_global=self.cos_global,
|
| 104 |
+
sin_global=self.sin_global,
|
| 105 |
+
cos_local=self.cos_local,
|
| 106 |
+
sin_local=self.sin_local,
|
| 107 |
+
)
|
| 108 |
+
|
| 109 |
+
x = self.final_norm(x)
|
| 110 |
+
logits = self.out_head(x.to(self.cfg["dtype"]))
|
| 111 |
+
loss = None
|
| 112 |
+
if targets is not None:
|
| 113 |
+
loss = F.cross_entropy(logits.reshape(-1, logits.size(-1)), targets.reshape(-1))
|
| 114 |
+
return logits, loss
|
| 115 |
+
|
| 116 |
+
@torch.no_grad()
|
| 117 |
+
def generate(self, idx, max_new_tokens, temperature=1.0, top_k=None):
|
| 118 |
+
for _ in range(max_new_tokens):
|
| 119 |
+
ctx_len = self.cfg["context_length"]
|
| 120 |
+
idx_cond = idx if idx.size(1) <= ctx_len else idx[:, -ctx_len:]
|
| 121 |
+
logits, _ = self(idx_cond) # targets=None by default
|
| 122 |
+
logits = logits[:, -1, :] / temperature
|
| 123 |
+
if top_k is not None:
|
| 124 |
+
v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
|
| 125 |
+
logits[logits < v[:, [-1]]] = float("-inf")
|
| 126 |
+
probs = F.softmax(logits, dim=-1)
|
| 127 |
+
idx_next = torch.multinomial(probs, num_samples=1)
|
| 128 |
+
idx = torch.cat((idx, idx_next), dim=1)
|
| 129 |
+
return idx
|
| 130 |
+
|
architecture/model_config.py
ADDED
|
@@ -0,0 +1,16 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os, sys
|
| 2 |
+
from os.path import dirname as up
|
| 3 |
+
|
| 4 |
+
sys.path.append(os.path.abspath(os.path.join(up(__file__), os.pardir)))
|
| 5 |
+
|
| 6 |
+
import json
|
| 7 |
+
|
| 8 |
+
MODEL_CONFIG_PATH = 'config/model_config.json'
|
| 9 |
+
|
| 10 |
+
with open(MODEL_CONFIG_PATH, 'r') as f:
|
| 11 |
+
model_config = json.load(f)
|
| 12 |
+
|
| 13 |
+
# print(model_config)
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
|