Upload FrawdLLMForCausalLM
- README.md +199 -0
- attention.py +162 -0
- block.py +118 -0
- config.json +24 -0
- config.py +209 -0
- embeddings.py +124 -0
- generation_config.json +7 -0
- gpt.py +223 -0
- hf_wrapper.py +258 -0
- mlp.py +105 -0
- model.safetensors +3 -0
- rope.py +153 -0
README.md
ADDED
@@ -0,0 +1,199 @@
---
library_name: transformers
tags: []
---

# Model Card for Model ID

<!-- Provide a quick summary of what the model is/does. -->

## Model Details

### Model Description

<!-- Provide a longer summary of what this model is. -->

This is the model card of a 🤗 transformers model that has been pushed on the Hub. This model card has been automatically generated.

- **Developed by:** [More Information Needed]
- **Funded by [optional]:** [More Information Needed]
- **Shared by [optional]:** [More Information Needed]
- **Model type:** [More Information Needed]
- **Language(s) (NLP):** [More Information Needed]
- **License:** [More Information Needed]
- **Finetuned from model [optional]:** [More Information Needed]

### Model Sources [optional]

<!-- Provide the basic links for the model. -->

- **Repository:** [More Information Needed]
- **Paper [optional]:** [More Information Needed]
- **Demo [optional]:** [More Information Needed]

## Uses

<!-- Address questions around how the model is intended to be used, including the foreseeable users of the model and those affected by the model. -->

### Direct Use

<!-- This section is for the model use without fine-tuning or plugging into a larger ecosystem/app. -->

[More Information Needed]

### Downstream Use [optional]

<!-- This section is for the model use when fine-tuned for a task, or when plugged into a larger ecosystem/app -->

[More Information Needed]

### Out-of-Scope Use

<!-- This section addresses misuse, malicious use, and uses that the model will not work well for. -->

[More Information Needed]

## Bias, Risks, and Limitations

<!-- This section is meant to convey both technical and sociotechnical limitations. -->

[More Information Needed]

### Recommendations

<!-- This section is meant to convey recommendations with respect to the bias, risk, and technical limitations. -->

Users (both direct and downstream) should be made aware of the risks, biases and limitations of the model. More information needed for further recommendations.

## How to Get Started with the Model

Use the code below to get started with the model.

[More Information Needed]
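
As a starting point, the sketch below follows the loading snippet in this repo's `hf_wrapper.py` docstring; the repo id and the special-token ids come from there and from `config.json`, while the rest is a minimal assumption (no tokenizer ships in this commit, so the prompt is built from a raw BOS id):

```python
import torch
from transformers import AutoModelForCausalLM

# The architecture is custom code: auto_map in config.json points at
# hf_wrapper.py, so trust_remote_code=True is required.
model = AutoModelForCausalLM.from_pretrained(
    "tsingla1998/frawdllm-100m", trust_remote_code=True
)
model.eval()

# Prompt with the raw BOS id (2, per config.json) and print token ids back.
prompt = torch.tensor([[2]])
out = model.generate(prompt, max_new_tokens=20, do_sample=True, top_k=50)
print(out[0].tolist())
```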
## Training Details

### Training Data

<!-- This should link to a Dataset Card, perhaps with a short stub of information on what the training data is all about as well as documentation related to data pre-processing or additional filtering. -->

[More Information Needed]

### Training Procedure

<!-- This relates heavily to the Technical Specifications. Content here should link to that section when it is relevant to the training procedure. -->

#### Preprocessing [optional]

[More Information Needed]

#### Training Hyperparameters

- **Training regime:** [More Information Needed] <!--fp32, fp16 mixed precision, bf16 mixed precision, bf16 non-mixed precision, fp16 non-mixed precision, fp8 mixed precision -->

#### Speeds, Sizes, Times [optional]

<!-- This section provides information about throughput, start/end time, checkpoint size if relevant, etc. -->

[More Information Needed]

## Evaluation

<!-- This section describes the evaluation protocols and provides the results. -->

### Testing Data, Factors & Metrics

#### Testing Data

<!-- This should link to a Dataset Card if possible. -->

[More Information Needed]

#### Factors

<!-- These are the things the evaluation is disaggregating by, e.g., subpopulations or domains. -->

[More Information Needed]

#### Metrics

<!-- These are the evaluation metrics being used, ideally with a description of why. -->

[More Information Needed]

### Results

[More Information Needed]

#### Summary

## Model Examination [optional]

<!-- Relevant interpretability work for the model goes here -->

[More Information Needed]

## Environmental Impact

<!-- Total emissions (in grams of CO2eq) and additional considerations, such as electricity usage, go here. Edit the suggested text below accordingly -->

Carbon emissions can be estimated using the [Machine Learning Impact calculator](https://mlco2.github.io/impact#compute) presented in [Lacoste et al. (2019)](https://arxiv.org/abs/1910.09700).

- **Hardware Type:** [More Information Needed]
- **Hours used:** [More Information Needed]
- **Cloud Provider:** [More Information Needed]
- **Compute Region:** [More Information Needed]
- **Carbon Emitted:** [More Information Needed]

## Technical Specifications [optional]

### Model Architecture and Objective

[More Information Needed]

### Compute Infrastructure

[More Information Needed]

#### Hardware

[More Information Needed]

#### Software

[More Information Needed]

## Citation [optional]

<!-- If there is a paper or blog post introducing the model, the APA and Bibtex information for that should go in this section. -->

**BibTeX:**

[More Information Needed]

**APA:**

[More Information Needed]

## Glossary [optional]

<!-- If relevant, include terms and calculations in this section that can help readers understand the model or model card. -->

[More Information Needed]

## More Information [optional]

[More Information Needed]

## Model Card Authors [optional]

[More Information Needed]

## Model Card Contact

[More Information Needed]

attention.py
ADDED
@@ -0,0 +1,162 @@
"""
Multi-Head Self-Attention for FrawdLLM.

This is the core mechanism that lets tokens "look at" each other.
Each token creates:
- Query (Q): "What am I looking for?"
- Key (K): "What do I contain?"
- Value (V): "What information do I give?"

Attention score = how well Q matches K
Output = weighted sum of V based on attention scores
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
import math

from .config import ModelConfig
from .rope import RotaryEmbedding


class CausalSelfAttention(nn.Module):
    """
    Multi-head causal (masked) self-attention.

    "Causal" means tokens can only attend to past tokens, not future ones.
    This is required for language models (we can't peek at what we're predicting!).
    """

    def __init__(self, config: ModelConfig):
        super().__init__()

        self.config = config
        self.n_head = config.n_head
        self.n_embd = config.n_embd
        self.head_dim = config.n_embd // config.n_head  # e.g., 768/12 = 64
        self.use_rope = config.use_rope

        # Linear projections to create Q, K, V
        # Each transforms [batch, seq, n_embd] -> [batch, seq, n_embd]
        # We do all three in one big matrix for efficiency, then split
        self.qkv_proj = nn.Linear(config.n_embd, 3 * config.n_embd)

        # Output projection: combines all heads back together
        self.out_proj = nn.Linear(config.n_embd, config.n_embd)

        # Dropout for regularization
        self.attn_dropout = nn.Dropout(config.dropout)
        self.resid_dropout = nn.Dropout(config.dropout)

        # RoPE for position encoding (if enabled)
        if self.use_rope:
            self.rope = RotaryEmbedding(
                dim=self.head_dim,
                max_seq_len=config.context_length * 4,  # Allow extrapolation
            )

        # Causal mask: lower triangular matrix
        # This prevents attending to future tokens
        # We register it as a buffer (saved with model, but not a parameter)
        max_len = config.context_length * 4 if self.use_rope else config.context_length
        mask = torch.tril(torch.ones(max_len, max_len))
        self.register_buffer("mask", mask.view(1, 1, max_len, max_len))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Apply multi-head causal self-attention.

        Args:
            x: [batch_size, seq_len, n_embd] - input embeddings

        Returns:
            [batch_size, seq_len, n_embd] - attended embeddings
        """
        batch_size, seq_len, n_embd = x.shape

        # Step 1: Project to Q, K, V (all at once for efficiency)
        # [batch, seq, n_embd] -> [batch, seq, 3 * n_embd]
        qkv = self.qkv_proj(x)

        # Step 2: Split into Q, K, V
        # [batch, seq, 3 * n_embd] -> 3 x [batch, seq, n_embd]
        q, k, v = qkv.chunk(3, dim=-1)

        # Step 3: Reshape for multi-head attention
        # [batch, seq, n_embd] -> [batch, n_head, seq, head_dim]
        # Example: [32, 512, 768] -> [32, 12, 512, 64]
        q = q.view(batch_size, seq_len, self.n_head, self.head_dim).transpose(1, 2)
        k = k.view(batch_size, seq_len, self.n_head, self.head_dim).transpose(1, 2)
        v = v.view(batch_size, seq_len, self.n_head, self.head_dim).transpose(1, 2)

        # Step 3.5: Apply RoPE (if enabled)
        # This rotates Q and K based on position - encodes position info
        if self.use_rope:
            q = self.rope(q)
            k = self.rope(k)
            # Note: V is not rotated - only Q and K need position info

        # Step 4: Compute attention scores
        # Q @ K^T: [batch, n_head, seq, head_dim] @ [batch, n_head, head_dim, seq]
        # = [batch, n_head, seq, seq]
        # Each (i,j) entry = "how much should position i attend to position j?"
        attn_scores = (q @ k.transpose(-2, -1)) / math.sqrt(self.head_dim)

        # Step 5: Apply causal mask (prevent attending to future)
        # Mask is 1 for allowed positions, 0 for disallowed
        # We set disallowed positions to -inf so softmax gives 0
        attn_scores = attn_scores.masked_fill(
            self.mask[:, :, :seq_len, :seq_len] == 0,
            float('-inf')
        )

        # Step 6: Softmax to get attention weights (probabilities)
        # [batch, n_head, seq, seq] - each row sums to 1
        attn_weights = F.softmax(attn_scores, dim=-1)
        attn_weights = self.attn_dropout(attn_weights)

        # Step 7: Apply attention to values
        # [batch, n_head, seq, seq] @ [batch, n_head, seq, head_dim]
        # = [batch, n_head, seq, head_dim]
        out = attn_weights @ v

        # Step 8: Reshape back: combine all heads
        # [batch, n_head, seq, head_dim] -> [batch, seq, n_head, head_dim] -> [batch, seq, n_embd]
        out = out.transpose(1, 2).contiguous().view(batch_size, seq_len, n_embd)

        # Step 9: Final output projection
        out = self.out_proj(out)
        out = self.resid_dropout(out)

        return out


if __name__ == "__main__":
    # Test the attention module
    from .config import get_config

    print("Testing CausalSelfAttention...")
    print("=" * 50)

    config = get_config("tiny")
    print(f"Config: n_embd={config.n_embd}, n_head={config.n_head}, "
          f"head_dim={config.head_dim}")

    attn = CausalSelfAttention(config)

    # Count parameters
    num_params = sum(p.numel() for p in attn.parameters())
    print(f"Attention parameters: {num_params:,}")

    # Test input: [batch=2, seq=8, n_embd=256]
    x = torch.randn(2, 8, config.n_embd)
    print(f"\nInput shape: {x.shape}")

    # Forward pass
    out = attn(x)
    print(f"Output shape: {out.shape}")

    # Verify shapes match
    assert x.shape == out.shape, "Input and output shapes should match!"
    print("\nAttention working!")

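As a sanity check (not part of this commit), the mask-and-softmax path above should agree with PyTorch 2.x's fused F.scaled_dot_product_attention when dropout is off; a minimal sketch with random Q/K/V shaped [batch, n_head, seq, head_dim]:

import math
import torch
import torch.nn.functional as F

q, k, v = (torch.randn(2, 8, 16, 64) for _ in range(3))

# Manual path, mirroring CausalSelfAttention.forward
scores = (q @ k.transpose(-2, -1)) / math.sqrt(q.size(-1))
mask = torch.tril(torch.ones(16, 16))
scores = scores.masked_fill(mask == 0, float("-inf"))
manual = F.softmax(scores, dim=-1) @ v

# Fused kernel with the same causal semantics
fused = F.scaled_dot_product_attention(q, k, v, is_causal=True)
assert torch.allclose(manual, fused, atol=1e-4)
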
block.py
ADDED
@@ -0,0 +1,118 @@
"""
Transformer Block for FrawdLLM.

A transformer block combines:
1. Multi-head self-attention (tokens gather info from each other)
2. MLP (each token processes info independently)

With two important additions:
- LayerNorm: Keeps values stable during training
- Residual connections: Add input to output ("don't lose what you had")

Structure (Pre-LN, which is more stable):

    Input
      ↓
    ┌─────────────┐
    │  LayerNorm  │
    └─────────────┘
      ↓
    ┌─────────────┐
    │  Attention  │───────┐
    └─────────────┘       │ (residual)
      ↓                   │
      + ←─────────────────┘
      ↓
    ┌─────────────┐
    │  LayerNorm  │
    └─────────────┘
      ↓
    ┌─────────────┐
    │     MLP     │───────┐
    └─────────────┘       │ (residual)
      ↓                   │
      + ←─────────────────┘
      ↓
    Output
"""

import torch
import torch.nn as nn

from .config import ModelConfig
from .attention import CausalSelfAttention
from .mlp import MLP


class TransformerBlock(nn.Module):
    """
    One transformer block = Attention + MLP with norms and residuals.

    Input: [batch_size, seq_len, n_embd]
    Output: [batch_size, seq_len, n_embd]
    """

    def __init__(self, config: ModelConfig):
        super().__init__()

        self.config = config

        # Layer norms (one before attention, one before MLP)
        self.ln1 = nn.LayerNorm(config.n_embd)
        self.ln2 = nn.LayerNorm(config.n_embd)

        # Attention and MLP
        self.attn = CausalSelfAttention(config)
        self.mlp = MLP(config)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Apply transformer block.

        Args:
            x: [batch_size, seq_len, n_embd]

        Returns:
            [batch_size, seq_len, n_embd]
        """
        # Attention with residual connection
        # x + attention(norm(x))
        # "Keep x, add attention's contribution"
        x = x + self.attn(self.ln1(x))

        # MLP with residual connection
        # x + mlp(norm(x))
        # "Keep x, add MLP's contribution"
        x = x + self.mlp(self.ln2(x))

        return x


if __name__ == "__main__":
    # Test the transformer block
    from .config import get_config

    print("Testing TransformerBlock...")
    print("=" * 50)

    config = get_config("tiny")
    print(f"Config: n_embd={config.n_embd}, n_head={config.n_head}, "
          f"n_layer={config.n_layer}")

    block = TransformerBlock(config)

    # Count parameters
    num_params = sum(p.numel() for p in block.parameters())
    print(f"Block parameters: {num_params:,}")

    # Test input: [batch=2, seq=8, n_embd=256]
    x = torch.randn(2, 8, config.n_embd)
    print(f"\nInput shape: {x.shape}")

    # Forward pass
    out = block(x)
    print(f"Output shape: {out.shape}")

    # Verify shapes match
    assert x.shape == out.shape, "Input and output shapes should match!"
    print("\nTransformerBlock working!")

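For contrast with the Pre-LN wiring above, here is a sketch of the two orderings (not part of this commit; attn, mlp, ln1, ln2 stand in for the block's submodules):

# Pre-LN (what TransformerBlock.forward does): normalize the input to
# each sublayer; the residual stream itself is never normalized away.
x = x + attn(ln1(x))
x = x + mlp(ln2(x))

# Post-LN (original Transformer): normalize after the residual add.
# This tends to train less stably in deep stacks, which is why
# Pre-LN is used here.
x = ln1(x + attn(x))
x = ln2(x + mlp(x))
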
config.json
ADDED
@@ -0,0 +1,24 @@
{
  "architectures": [
    "FrawdLLMForCausalLM"
  ],
  "auto_map": {
    "AutoConfig": "hf_wrapper.FrawdLLMConfig",
    "AutoModelForCausalLM": "hf_wrapper.FrawdLLMForCausalLM"
  },
  "bos_token_id": 2,
  "context_length": 1024,
  "dropout": 0.1,
  "dtype": "float32",
  "eos_token_id": 3,
  "model_type": "frawdllm",
  "n_embd": 768,
  "n_head": 12,
  "n_layer": 12,
  "pad_token_id": 0,
  "transformers_version": "4.57.3",
  "use_rmsnorm": false,
  "use_rope": true,
  "use_swiglu": false,
  "vocab_size": 32000
}

config.py
ADDED
@@ -0,0 +1,209 @@
"""
Model configuration for FrawdLLM.

This module defines the hyperparameters that control model architecture.
We'll define multiple sizes to experiment with.

Learning Notes:
--------------
Key hyperparameters and their effects:

1. vocab_size: Size of tokenizer vocabulary
   - Must match your trained tokenizer
   - Larger = more memory for embedding table

2. n_embd (embedding dimension): Size of hidden representations
   - Larger = more expressive, but slower and more memory
   - GPT-2 small: 768, GPT-2 large: 1280, GPT-3: 12288

3. n_layer: Number of transformer blocks
   - More layers = deeper reasoning, but harder to train
   - GPT-2 small: 12, GPT-2 large: 36

4. n_head: Number of attention heads
   - Usually n_embd / n_head = 64 (head dimension)
   - More heads = more parallel attention patterns

5. context_length: Maximum sequence length
   - Longer = can process more text, but O(n²) memory for attention
   - GPT-2: 1024, GPT-3: 2048, modern models: 4096-128K

6. dropout: Regularization to prevent overfitting
   - 0.0 for small datasets (we need all the learning we can get)
   - 0.1-0.2 for larger datasets
"""

from dataclasses import dataclass


@dataclass
class ModelConfig:
    """Configuration for FrawdLLM model."""

    # Vocabulary (must match tokenizer)
    vocab_size: int = 8192

    # Model dimensions
    n_embd: int = 768    # Embedding dimension
    n_layer: int = 12    # Number of transformer blocks
    n_head: int = 12     # Number of attention heads

    # Sequence length
    context_length: int = 512  # Maximum sequence length

    # Regularization
    dropout: float = 0.0  # Dropout probability (0 for small data)

    # Architecture choices (we'll implement both!)
    use_rope: bool = False     # Use Rotary Position Embeddings (Llama-style)
    use_rmsnorm: bool = False  # Use RMSNorm instead of LayerNorm (Llama-style)
    use_swiglu: bool = False   # Use SwiGLU activation (Llama-style)

    # Special token IDs (must match tokenizer)
    pad_token_id: int = 0
    bos_token_id: int = 2
    eos_token_id: int = 3

    def __post_init__(self):
        """Validate configuration."""
        assert self.n_embd % self.n_head == 0, \
            f"n_embd ({self.n_embd}) must be divisible by n_head ({self.n_head})"

        self.head_dim = self.n_embd // self.n_head

    @property
    def num_parameters(self) -> int:
        """Estimate total number of parameters."""
        # Token embeddings: vocab_size * n_embd
        token_emb = self.vocab_size * self.n_embd

        # Position embeddings (if not using RoPE): context_length * n_embd
        pos_emb = 0 if self.use_rope else self.context_length * self.n_embd

        # Per transformer block:
        # - Attention: 4 * n_embd^2 (Q, K, V, O projections)
        # - MLP: 8 * n_embd^2 (up, down) or 12 * n_embd^2 (SwiGLU has gate)
        # - LayerNorms: 2 * n_embd (or 4 * n_embd with biases)
        mlp_factor = 12 if self.use_swiglu else 8
        per_block = 4 * self.n_embd**2 + mlp_factor * self.n_embd**2 + 4 * self.n_embd
        total_blocks = self.n_layer * per_block

        # Output projection (tied with token embeddings usually, so not counted)
        # Final layer norm: n_embd
        final_ln = self.n_embd

        return token_emb + pos_emb + total_blocks + final_ln


# Predefined configurations for different sizes
# These are designed to be trainable on different hardware

# ~10M parameters - For quick debugging on CPU/M3
# Can train in minutes on a laptop
FRAWDLLM_TINY = ModelConfig(
    vocab_size=8192,
    n_embd=256,
    n_layer=6,
    n_head=8,
    context_length=256,
    dropout=0.0,
)

# ~50M parameters - Good for learning on M3/single GPU
# Can train in hours on M3, generates reasonable text
FRAWDLLM_SMALL = ModelConfig(
    vocab_size=8192,
    n_embd=512,
    n_layer=8,
    n_head=8,
    context_length=512,
    dropout=0.0,
)

# ~125M parameters - Similar to GPT-2 small
# Needs GPU (AWS), generates good quality text
FRAWDLLM_BASE = ModelConfig(
    vocab_size=8192,
    n_embd=768,
    n_layer=12,
    n_head=12,
    context_length=1024,
    dropout=0.1,
)


# Llama-style variants (modern architecture)
FRAWDLLM_TINY_LLAMA = ModelConfig(
    vocab_size=8192,
    n_embd=256,
    n_layer=6,
    n_head=8,
    context_length=256,
    dropout=0.0,
    use_rope=True,
    use_rmsnorm=True,
    use_swiglu=True,
)

FRAWDLLM_SMALL_LLAMA = ModelConfig(
    vocab_size=8192,
    n_embd=512,
    n_layer=8,
    n_head=8,
    context_length=512,
    dropout=0.0,
    use_rope=True,
    use_rmsnorm=True,
    use_swiglu=True,
)

# ~100M parameters - Similar to GPT-2 Small but with modern architecture
# Uses RoPE for position encoding, allowing longer context at inference
FRAWDLLM_100M = ModelConfig(
    vocab_size=32000,     # Larger vocab for diverse data
    n_embd=768,
    n_layer=12,
    n_head=12,
    context_length=1024,  # Train on 1024, can extrapolate to 2048+
    dropout=0.1,
    use_rope=True,        # Rotary position embeddings
    use_rmsnorm=False,    # Keep LayerNorm for now
    use_swiglu=False,     # Keep GELU for now
)


def get_config(name: str) -> ModelConfig:
    """Get a predefined configuration by name."""
    configs = {
        "tiny": FRAWDLLM_TINY,
        "small": FRAWDLLM_SMALL,
        "base": FRAWDLLM_BASE,
        "tiny-llama": FRAWDLLM_TINY_LLAMA,
        "small-llama": FRAWDLLM_SMALL_LLAMA,
        "100m": FRAWDLLM_100M,
    }

    if name not in configs:
        raise ValueError(f"Unknown config: {name}. Available: {list(configs.keys())}")

    return configs[name]


if __name__ == "__main__":
    # Print parameter counts for each config
    print("FrawdLLM Model Configurations")
    print("=" * 50)

    for name in ["tiny", "small", "base", "tiny-llama", "small-llama"]:
        config = get_config(name)
        params = config.num_parameters
        print(f"\n{name}:")
        print(f"  Parameters: {params:,} ({params/1e6:.1f}M)")
        print(f"  Embedding dim: {config.n_embd}")
        print(f"  Layers: {config.n_layer}")
        print(f"  Heads: {config.n_head}")
        print(f"  Context: {config.context_length}")
        if config.use_rope:
            print(f"  Style: Llama (RoPE, RMSNorm, SwiGLU)")
        else:
            print(f"  Style: GPT-2 (learned pos, LayerNorm, GELU)")

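A quick usage sketch (not part of this commit; the flat "from config import" path assumes the file is run from its own directory, since the repo modules use relative imports):

from config import ModelConfig, get_config  # import path is an assumption

cfg = get_config("100m")
print(cfg.head_dim)        # 768 // 12 = 64, derived in __post_init__
print(cfg.num_parameters)  # rough analytic estimate, not an exact count

# Invalid combinations fail fast: n_embd must divide evenly by n_head.
custom = ModelConfig(vocab_size=8192, n_embd=384, n_layer=4, n_head=6)
print(custom.head_dim)     # 384 // 6 = 64
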
embeddings.py
ADDED
@@ -0,0 +1,124 @@
"""
Token and Position Embeddings for FrawdLLM.

This is the first layer of the model - converts token IDs into vectors
that the transformer can process.

Two lookup tables:
1. Token embeddings: WHAT the token is (vocab_size x n_embd)
2. Position embeddings: WHERE the token is (context_length x n_embd)

Final output = token_emb + pos_emb (just addition!)
"""

import torch
import torch.nn as nn

from .config import ModelConfig


class Embeddings(nn.Module):
    """
    Combined token + position embeddings.

    Input: token_ids [batch_size, seq_len] - integers from tokenizer
    Output: vectors [batch_size, seq_len, n_embd] - dense representations
    """

    def __init__(self, config: ModelConfig):
        super().__init__()  # Initialize nn.Module tracking

        self.config = config
        self.use_rope = config.use_rope

        # Token embedding table: one vector per vocabulary word
        # Shape: [vocab_size, n_embd] = [8192, 768]
        self.token_emb = nn.Embedding(config.vocab_size, config.n_embd)

        # Position embedding table: one vector per position (only if NOT using RoPE)
        # Shape: [context_length, n_embd] = [512, 768]
        # With RoPE, position is encoded in attention via rotation instead
        if not self.use_rope:
            self.pos_emb = nn.Embedding(config.context_length, config.n_embd)
        else:
            self.pos_emb = None

        # Dropout for regularization (usually 0 for small datasets)
        self.dropout = nn.Dropout(config.dropout)

    def forward(self, token_ids: torch.Tensor) -> torch.Tensor:
        """
        Convert token IDs to embeddings.

        Args:
            token_ids: [batch_size, seq_len] tensor of token IDs

        Returns:
            [batch_size, seq_len, n_embd] tensor of embeddings
        """
        batch_size, seq_len = token_ids.shape

        # Safety check: don't exceed context window (relaxed for RoPE)
        max_len = self.config.context_length * 4 if self.use_rope else self.config.context_length
        if seq_len > max_len:
            raise ValueError(
                f"Sequence length {seq_len} exceeds maximum length {max_len}"
            )

        # Step 1: Look up token embeddings
        # [batch_size, seq_len] -> [batch_size, seq_len, n_embd]
        embeddings = self.token_emb(token_ids)

        # Step 2: Add position embeddings (only if NOT using RoPE)
        # With RoPE, position is encoded via rotation in attention instead
        if not self.use_rope:
            positions = torch.arange(seq_len, device=token_ids.device)
            pos_emb = self.pos_emb(positions)
            embeddings = embeddings + pos_emb

        # Step 3: Apply dropout (if any)
        embeddings = self.dropout(embeddings)

        return embeddings


if __name__ == "__main__":
    # Quick test to verify it works
    from .config import get_config

    print("Testing Embeddings...")
    print("=" * 50)

    # Use tiny config for testing
    config = get_config("tiny")
    print(f"Config: vocab={config.vocab_size}, n_embd={config.n_embd}, "
          f"context={config.context_length}")

    # Create embedding layer
    emb = Embeddings(config)

    # Count parameters
    num_params = sum(p.numel() for p in emb.parameters())
    print(f"Embedding parameters: {num_params:,}")

    # Test forward pass
    # Fake batch: 2 sequences of 4 tokens each
    token_ids = torch.tensor([
        [2, 531, 892, 12],  # Sequence 1
        [2, 100, 200, 3],   # Sequence 2
    ])

    print(f"\nInput shape: {token_ids.shape}")
    print(f"Input tokens: {token_ids.tolist()}")

    # Forward pass
    output = emb(token_ids)

    print(f"\nOutput shape: {output.shape}")
    print(f"Each token is now a {output.shape[-1]}-dimensional vector")

    # Show a snippet of the output
    print(f"\nFirst token's vector (first 10 dims):")
    print(f"  {output[0, 0, :10].tolist()}")

    print("\nEmbeddings working!")

generation_config.json
ADDED
@@ -0,0 +1,7 @@
{
  "_from_model_config": true,
  "bos_token_id": 2,
  "eos_token_id": 3,
  "pad_token_id": 0,
  "transformers_version": "4.57.3"
}

gpt.py
ADDED
@@ -0,0 +1,223 @@
"""
Full GPT Model for FrawdLLM.

This is the complete model that:
1. Takes token IDs as input
2. Converts to embeddings (token + position)
3. Passes through N transformer blocks
4. Predicts the next token

Architecture:
    Token IDs [batch, seq]
        ↓
    Embeddings [batch, seq, n_embd]
        ↓
    Transformer Block × N
        ↓
    Final LayerNorm
        ↓
    Output Head → [batch, seq, vocab_size]
        ↓
    Logits (unnormalized probabilities for each vocab word)
"""

import torch
import torch.nn as nn
import torch.nn.functional as F

from .config import ModelConfig
from .embeddings import Embeddings
from .block import TransformerBlock


class FrawdLLM(nn.Module):
    """
    The complete FrawdLLM model.

    Input: token_ids [batch_size, seq_len]
    Output: logits [batch_size, seq_len, vocab_size]
    """

    def __init__(self, config: ModelConfig):
        super().__init__()

        self.config = config

        # Token + position embeddings
        self.embeddings = Embeddings(config)

        # Stack of transformer blocks
        self.blocks = nn.ModuleList([
            TransformerBlock(config) for _ in range(config.n_layer)
        ])

        # Final layer norm (before output projection)
        self.ln_f = nn.LayerNorm(config.n_embd)

        # Output head: project from n_embd to vocab_size
        # This gives us a score for each possible next token
        self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)

        # Weight tying: share weights between token embeddings and output head
        # This is a common trick that:
        # 1. Reduces parameters
        # 2. Makes sense: similar tokens should have similar embeddings AND predictions
        self.lm_head.weight = self.embeddings.token_emb.weight

        # Initialize weights
        self.apply(self._init_weights)

    def _init_weights(self, module):
        """Initialize weights for better training."""
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)

    def forward(
        self,
        token_ids: torch.Tensor,
        targets: torch.Tensor | None = None,
    ) -> tuple[torch.Tensor, torch.Tensor | None]:
        """
        Forward pass through the model.

        Args:
            token_ids: [batch_size, seq_len] - input token IDs
            targets: [batch_size, seq_len] - target token IDs (for computing loss)

        Returns:
            logits: [batch_size, seq_len, vocab_size] - prediction scores
            loss: scalar tensor if targets provided, else None
        """
        # Step 1: Convert token IDs to embeddings
        # [batch, seq] → [batch, seq, n_embd]
        x = self.embeddings(token_ids)

        # Step 2: Pass through all transformer blocks
        for block in self.blocks:
            x = block(x)

        # Step 3: Final layer norm
        x = self.ln_f(x)

        # Step 4: Project to vocabulary size
        # [batch, seq, n_embd] → [batch, seq, vocab_size]
        logits = self.lm_head(x)

        # Step 5: Compute loss if targets provided
        loss = None
        if targets is not None:
            # Flatten for cross-entropy
            # logits: [batch * seq, vocab_size]
            # targets: [batch * seq]
            loss = F.cross_entropy(
                logits.view(-1, logits.size(-1)),
                targets.view(-1),
                ignore_index=self.config.pad_token_id,  # Don't compute loss on padding
            )

        return logits, loss

    @torch.no_grad()
    def generate(
        self,
        token_ids: torch.Tensor,
        max_new_tokens: int = 100,
        temperature: float = 1.0,
        top_k: int | None = None,
    ) -> torch.Tensor:
        """
        Generate new tokens autoregressively.

        Args:
            token_ids: [batch_size, seq_len] - starting tokens (prompt)
            max_new_tokens: How many new tokens to generate
            temperature: Higher = more random, lower = more deterministic
            top_k: If set, only sample from top k most likely tokens

        Returns:
            [batch_size, seq_len + max_new_tokens] - original + generated tokens
        """
        for _ in range(max_new_tokens):
            # Crop to context length if needed
            context = token_ids[:, -self.config.context_length:]

            # Get predictions
            logits, _ = self.forward(context)

            # Take logits for the last position only
            # [batch, vocab_size]
            logits = logits[:, -1, :]

            # Apply temperature
            logits = logits / temperature

            # Optionally apply top-k filtering
            if top_k is not None:
                # Keep only top k values, set rest to -inf
                top_values, _ = torch.topk(logits, top_k, dim=-1)
                min_top_value = top_values[:, -1].unsqueeze(-1)
                logits = torch.where(
                    logits < min_top_value,
                    torch.full_like(logits, float('-inf')),
                    logits,
                )

            # Convert to probabilities
            probs = F.softmax(logits, dim=-1)

            # Sample next token
            next_token = torch.multinomial(probs, num_samples=1)

            # Append to sequence
            token_ids = torch.cat([token_ids, next_token], dim=1)

            # Stop if we generated EOS token
            if (next_token == self.config.eos_token_id).all():
                break

        return token_ids

    def count_parameters(self) -> int:
        """Count total trainable parameters."""
        return sum(p.numel() for p in self.parameters() if p.requires_grad)


if __name__ == "__main__":
    from .config import get_config

    print("Testing FrawdLLM...")
    print("=" * 50)

    config = get_config("tiny")
    print(f"Config: vocab={config.vocab_size}, n_embd={config.n_embd}, "
          f"n_layer={config.n_layer}, n_head={config.n_head}")

    model = FrawdLLM(config)

    # Count parameters
    num_params = model.count_parameters()
    print(f"Total parameters: {num_params:,} ({num_params/1e6:.1f}M)")

    # Test forward pass
    batch_size, seq_len = 2, 16
    token_ids = torch.randint(0, config.vocab_size, (batch_size, seq_len))
    targets = torch.randint(0, config.vocab_size, (batch_size, seq_len))

    print(f"\nInput shape: {token_ids.shape}")

    logits, loss = model(token_ids, targets)

    print(f"Output logits shape: {logits.shape}")
    print(f"Loss: {loss.item():.4f}")

    # Test generation
    prompt = torch.tensor([[config.bos_token_id]])  # Start with BOS
    generated = model.generate(prompt, max_new_tokens=10)
    print(f"\nGenerated shape: {generated.shape}")
    print(f"Generated tokens: {generated[0].tolist()}")

    print("\nFrawdLLM working!")

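To make the temperature and top-k steps in generate concrete, here is a small worked sketch on a toy 5-token vocabulary (illustrative values only, not part of this commit):

import torch
import torch.nn.functional as F

logits = torch.tensor([[2.0, 1.0, 0.5, 0.1, -1.0]])

# Temperature: dividing logits by T < 1 sharpens the distribution,
# T > 1 flattens it toward uniform.
print(F.softmax(logits / 0.7, dim=-1))
print(F.softmax(logits / 1.5, dim=-1))

# Top-k (k=2): anything below the 2nd-largest logit becomes -inf,
# so softmax puts all probability mass on the top two tokens.
top_values, _ = torch.topk(logits, k=2, dim=-1)
logits = torch.where(
    logits < top_values[:, -1:],
    torch.full_like(logits, float("-inf")),
    logits,
)
probs = F.softmax(logits, dim=-1)
next_token = torch.multinomial(probs, num_samples=1)
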
hf_wrapper.py
ADDED
@@ -0,0 +1,258 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
|
| 2 |
+
"""
|
| 3 |
+
HuggingFace wrapper for FrawdLLM.
|
| 4 |
+
|
| 5 |
+
This allows the model to be loaded with:
|
| 6 |
+
from transformers import AutoModelForCausalLM
|
| 7 |
+
model = AutoModelForCausalLM.from_pretrained("tsingla1998/frawdllm-100m", trust_remote_code=True)
|
| 8 |
+
"""
|
| 9 |
+
|
| 10 |
+
from typing import Optional, Tuple, Union
|
| 11 |
+
|
| 12 |
+
import torch
|
| 13 |
+
import torch.nn as nn
|
| 14 |
+
import torch.nn.functional as F
|
| 15 |
+
from transformers import PretrainedConfig, PreTrainedModel, GenerationMixin
|
| 16 |
+
from transformers.modeling_outputs import CausalLMOutputWithPast
|
| 17 |
+
|
| 18 |
+
from .config import ModelConfig
|
| 19 |
+
from .gpt import FrawdLLM
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
class FrawdLLMConfig(PretrainedConfig):
|
| 23 |
+
"""HuggingFace-compatible configuration for FrawdLLM."""
|
| 24 |
+
|
| 25 |
+
model_type = "frawdllm"
|
| 26 |
+
|
| 27 |
+
def __init__(
|
| 28 |
+
self,
|
| 29 |
+
vocab_size: int = 32000,
|
| 30 |
+
n_embd: int = 768,
|
| 31 |
+
n_layer: int = 12,
|
| 32 |
+
n_head: int = 12,
|
| 33 |
+
context_length: int = 1024,
|
| 34 |
+
dropout: float = 0.1,
|
| 35 |
+
use_rope: bool = True,
|
| 36 |
+
use_rmsnorm: bool = False,
|
| 37 |
+
use_swiglu: bool = False,
|
| 38 |
+
pad_token_id: int = 0,
|
| 39 |
+
bos_token_id: int = 2,
|
| 40 |
+
eos_token_id: int = 3,
|
| 41 |
+
**kwargs,
|
| 42 |
+
):
|
| 43 |
+
self.vocab_size = vocab_size
|
| 44 |
+
self.n_embd = n_embd
|
| 45 |
+
self.n_layer = n_layer
|
| 46 |
+
self.n_head = n_head
|
| 47 |
+
self.context_length = context_length
|
| 48 |
+
self.dropout = dropout
|
| 49 |
+
self.use_rope = use_rope
|
| 50 |
+
self.use_rmsnorm = use_rmsnorm
|
| 51 |
+
self.use_swiglu = use_swiglu
|
| 52 |
+
|
| 53 |
+
super().__init__(
|
| 54 |
+
pad_token_id=pad_token_id,
|
| 55 |
+
bos_token_id=bos_token_id,
|
| 56 |
+
eos_token_id=eos_token_id,
|
| 57 |
+
**kwargs,
|
| 58 |
+
)
|
| 59 |
+
|
| 60 |
+
def to_model_config(self) -> ModelConfig:
|
| 61 |
+
"""Convert to internal ModelConfig for the model."""
|
| 62 |
+
return ModelConfig(
|
| 63 |
+
vocab_size=self.vocab_size,
|
| 64 |
+
n_embd=self.n_embd,
|
| 65 |
+
n_layer=self.n_layer,
|
| 66 |
+
n_head=self.n_head,
|
| 67 |
+
context_length=self.context_length,
|
| 68 |
+
dropout=self.dropout,
|
| 69 |
+
use_rope=self.use_rope,
|
| 70 |
+
use_rmsnorm=self.use_rmsnorm,
|
| 71 |
+
use_swiglu=self.use_swiglu,
|
| 72 |
+
pad_token_id=self.pad_token_id,
|
| 73 |
+
bos_token_id=self.bos_token_id,
|
| 74 |
+
eos_token_id=self.eos_token_id,
|
| 75 |
+
)
|
| 76 |
+
|
| 77 |
+
@classmethod
|
| 78 |
+
def from_model_config(cls, config: ModelConfig) -> "FrawdLLMConfig":
|
| 79 |
+
"""Create from internal ModelConfig."""
|
| 80 |
+
return cls(
|
| 81 |
+
vocab_size=config.vocab_size,
|
| 82 |
+
n_embd=config.n_embd,
|
| 83 |
+
n_layer=config.n_layer,
|
| 84 |
+
n_head=config.n_head,
|
| 85 |
+
context_length=config.context_length,
|
| 86 |
+
dropout=config.dropout,
|
| 87 |
+
use_rope=config.use_rope,
|
| 88 |
+
use_rmsnorm=config.use_rmsnorm,
|
| 89 |
+
use_swiglu=config.use_swiglu,
|
| 90 |
+
pad_token_id=config.pad_token_id,
|
| 91 |
+
bos_token_id=config.bos_token_id,
|
| 92 |
+
eos_token_id=config.eos_token_id,
|
| 93 |
+
)
|
| 94 |
+
|
| 95 |
+
|
| 96 |
+
class FrawdLLMForCausalLM(PreTrainedModel, GenerationMixin):
|
| 97 |
+
"""HuggingFace-compatible wrapper for FrawdLLM."""
|
| 98 |
+
|
| 99 |
+
config_class = FrawdLLMConfig
|
| 100 |
+
base_model_prefix = "model"
|
| 101 |
+
supports_gradient_checkpointing = False
|
| 102 |
+
_no_split_modules = ["TransformerBlock"]
|
| 103 |
+
_tied_weights_keys = ["model.lm_head.weight"]
|
| 104 |
+
|
| 105 |
+
def __init__(self, config: FrawdLLMConfig):
|
| 106 |
+
super().__init__(config)
|
| 107 |
+
|
| 108 |
+
# Convert HF config to internal config
|
| 109 |
+
model_config = config.to_model_config()
|
| 110 |
+
|
| 111 |
+
# Create the actual model
|
| 112 |
+
self.model = FrawdLLM(model_config)
|
| 113 |
+
|
| 114 |
+
# For generation
|
| 115 |
+
self.main_input_name = "input_ids"
|
| 116 |
+
|
| 117 |
+
def get_input_embeddings(self):
|
| 118 |
+
return self.model.embeddings.token_emb
|
| 119 |
+
|
| 120 |
+
def set_input_embeddings(self, value):
|
| 121 |
+
self.model.embeddings.token_emb = value
|
| 122 |
+
|
| 123 |
+
def get_output_embeddings(self):
|
| 124 |
+
return self.model.lm_head
|
| 125 |
+
|
| 126 |
+
def set_output_embeddings(self, new_embeddings):
|
| 127 |
+
self.model.lm_head = new_embeddings
|
| 128 |
+
|
| 129 |
+
def tie_weights(self):
|
| 130 |
+
"""Tie input and output embeddings."""
|
| 131 |
+
self.model.lm_head.weight = self.model.embeddings.token_emb.weight
|
| 132 |
+
|
| 133 |
+
def forward(
|
| 134 |
+
self,
|
| 135 |
+
input_ids: torch.LongTensor,
|
| 136 |
+
attention_mask: Optional[torch.Tensor] = None,
|
| 137 |
+
labels: Optional[torch.LongTensor] = None,
|
| 138 |
+
past_key_values: Optional[Tuple] = None,
|
| 139 |
+
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        **kwargs,
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        """
        Forward pass compatible with HuggingFace API.

        Note: attention_mask, past_key_values, use_cache are accepted but
        not fully implemented (our model doesn't use KV caching yet).
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # Get logits from our model
        logits, _ = self.model(input_ids, None)

        # Compute loss if labels provided
        loss = None
        if labels is not None:
            # Shift for causal LM loss
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()

            loss = F.cross_entropy(
                shift_logits.view(-1, shift_logits.size(-1)),
                shift_labels.view(-1),
                ignore_index=-100,
            )

        if not return_dict:
            output = (logits,)
            return (loss,) + output if loss is not None else output

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=None,
            hidden_states=None,
            attentions=None,
        )

    def prepare_inputs_for_generation(
        self,
        input_ids: torch.LongTensor,
        past_key_values: Optional[Tuple] = None,
        attention_mask: Optional[torch.Tensor] = None,
        **kwargs,
    ):
        """Prepare inputs for generation (called by HF generate())."""
        # Our model doesn't use KV cache yet, so just return input_ids
        return {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
        }

    @classmethod
    def from_frawdllm_checkpoint(
        cls,
        checkpoint_path: str,
        device: str = "cpu",
    ) -> "FrawdLLMForCausalLM":
        """
        Load from a FrawdLLM .pt checkpoint.

        Args:
            checkpoint_path: Path to the .pt checkpoint file
            device: Device to load the model on

        Returns:
            FrawdLLMForCausalLM instance
        """
        # Load the checkpoint
        checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=False)

        # Get the internal config
        model_config = checkpoint["config"]

        # Create HF config
        hf_config = FrawdLLMConfig.from_model_config(model_config)

        # Create the wrapper model
        model = cls(hf_config)

        # Load the weights
        model.model.load_state_dict(checkpoint["model_state_dict"])

        return model

    def save_pretrained_simple(self, save_directory: str):
        """
        Save in HuggingFace format.

        This saves:
        - config.json
        - model.safetensors (or pytorch_model.bin)
        """
        import os
        from safetensors.torch import save_file

        os.makedirs(save_directory, exist_ok=True)

        # Save config
        self.config.save_pretrained(save_directory)

        # Save model weights
        # Note: We have weight tying (token_emb.weight == lm_head.weight)
        # Remove the duplicate to avoid safetensors error
        state_dict = self.state_dict()
        if "model.lm_head.weight" in state_dict:
            del state_dict["model.lm_head.weight"]

        save_file(state_dict, os.path.join(save_directory, "model.safetensors"))

        print(f"Saved model to {save_directory}")


# Register for AutoClass - this adds auto_map to config when saving
FrawdLLMConfig.register_for_auto_class()
FrawdLLMForCausalLM.register_for_auto_class("AutoModelForCausalLM")
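Taken together, the conversion path is: training checkpoint → `from_frawdllm_checkpoint` → `save_pretrained_simple` → generic reload. A minimal round-trip sketch, assuming the file layout above; the checkpoint path and output directory here are placeholders, not files from this repo:

```python
# Hypothetical paths for illustration only.
from hf_wrapper import FrawdLLMForCausalLM  # import path depends on your layout
from transformers import AutoModelForCausalLM

# Wrap a raw FrawdLLM training checkpoint and export it in HF format.
model = FrawdLLMForCausalLM.from_frawdllm_checkpoint("ckpt/final.pt")
model.save_pretrained_simple("frawdllm-hf")

# Because register_for_auto_class() puts an auto_map in config.json, the saved
# directory can be reloaded generically (custom code needs trust_remote_code).
reloaded = AutoModelForCausalLM.from_pretrained("frawdllm-hf", trust_remote_code=True)
```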
mlp.py
ADDED
@@ -0,0 +1,105 @@
"""
MLP (Multi-Layer Perceptron) for FrawdLLM.

This is the "feed-forward" part of the transformer block.
After attention lets tokens gather information from each other,
the MLP lets each token process that information independently.

Structure:
    Input (768) → Expand (3072) → GELU → Shrink (768) → Output

The 4x expansion gives the model more "thinking room" before
compressing back to the original size.
"""

import torch
import torch.nn as nn

from .config import ModelConfig


class MLP(nn.Module):
    """
    Simple feed-forward network with GELU activation.

    Input:  [batch_size, seq_len, n_embd]
    Output: [batch_size, seq_len, n_embd]
    """

    def __init__(self, config: ModelConfig):
        super().__init__()

        self.config = config

        # Hidden dimension is 4x the embedding dimension
        # This is a common ratio used in most transformers
        hidden_dim = 4 * config.n_embd

        # Expand: 768 → 3072
        self.fc1 = nn.Linear(config.n_embd, hidden_dim)

        # Activation function
        self.act = nn.GELU()

        # Shrink: 3072 → 768
        self.fc2 = nn.Linear(hidden_dim, config.n_embd)

        # Dropout for regularization
        self.dropout = nn.Dropout(config.dropout)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Apply MLP to each token independently.

        Args:
            x: [batch_size, seq_len, n_embd]

        Returns:
            [batch_size, seq_len, n_embd]
        """
        # Step 1: Expand
        # [batch, seq, 768] → [batch, seq, 3072]
        x = self.fc1(x)

        # Step 2: Non-linearity
        # [batch, seq, 3072] → [batch, seq, 3072] (same shape, different values)
        x = self.act(x)

        # Step 3: Shrink back
        # [batch, seq, 3072] → [batch, seq, 768]
        x = self.fc2(x)

        # Step 4: Dropout
        x = self.dropout(x)

        return x


if __name__ == "__main__":
    # Test the MLP module
    from .config import get_config

    print("Testing MLP...")
    print("=" * 50)

    config = get_config("tiny")
    hidden_dim = 4 * config.n_embd
    print(f"Config: n_embd={config.n_embd}, hidden_dim={hidden_dim}")

    mlp = MLP(config)

    # Count parameters
    num_params = sum(p.numel() for p in mlp.parameters())
    print(f"MLP parameters: {num_params:,}")

    # Test input: [batch=2, seq=8, n_embd=256]
    x = torch.randn(2, 8, config.n_embd)
    print(f"\nInput shape: {x.shape}")

    # Forward pass
    out = mlp(x)
    print(f"Output shape: {out.shape}")

    # Verify shapes match
    assert x.shape == out.shape, "Input and output shapes should match!"
    print("\nMLP working!")
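Since both linear layers use the 4× hidden size, the parameter count printed by the test should equal 8·n_embd² + 5·n_embd (weights plus biases). A quick sanity check, using the n_embd=256 of the tiny config from the test above:

```python
# fc1: n_embd*(4*n_embd) weights + 4*n_embd biases
# fc2: (4*n_embd)*n_embd weights + n_embd biases
n_embd = 256  # the tiny config's embedding size used in the test above
expected = (n_embd * 4 * n_embd + 4 * n_embd) + (4 * n_embd * n_embd + n_embd)
print(f"{expected:,}")  # 525,568 == 8*n_embd**2 + 5*n_embd
```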
model.safetensors
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:a70aec25201815a785d3731a9b204a149cae2f0e7788a24c1d853f1375ad5cd8
size 1243850448
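As a rough cross-check, the 1,243,850,448-byte weight file corresponds to about 311M stored values if the tensors are float32 (an assumption; the LFS pointer does not record the dtype):

```python
size_bytes = 1_243_850_448  # "size" field from the LFS pointer above
print(size_bytes / 4)       # ≈ 3.11e8 float32 values, i.e. ~311M parameters
```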
rope.py
ADDED
@@ -0,0 +1,153 @@
"""
Rotary Position Embedding (RoPE) for FrawdLLM.

RoPE encodes position by rotating the Q and K vectors. This has several advantages:
1. No learned position embeddings (saves parameters)
2. Better length generalization (can extrapolate beyond training length)
3. Relative position encoding (attention depends on distance, not absolute position)

How it works:
- Each position gets a rotation angle based on its index
- Q and K are rotated by their position's angle
- The dot product Q·K then naturally encodes relative distance

Reference: https://arxiv.org/abs/2104.09864
"""

import torch
import torch.nn as nn
import math


def precompute_freqs(dim: int, max_seq_len: int, theta: float = 10000.0) -> torch.Tensor:
    """
    Precompute the frequency tensor for RoPE.

    Args:
        dim: Dimension of each head (must be even)
        max_seq_len: Maximum sequence length
        theta: Base for frequency computation (10000 is standard)

    Returns:
        Complex tensor of shape [max_seq_len, dim//2] containing rotation frequencies
    """
    # Frequency for each dimension pair: theta^(-2i/dim) for i = 0, 1, ..., dim/2-1
    # Low-index pairs rotate quickly, high-index pairs rotate slowly
    freqs = 1.0 / (theta ** (torch.arange(0, dim, 2).float() / dim))

    # Position indices
    positions = torch.arange(max_seq_len)

    # Outer product: [max_seq_len, dim//2]
    # Each position gets a different rotation angle for each frequency
    angles = torch.outer(positions, freqs)

    # Convert to complex numbers for easy rotation
    # e^(i*angle) = cos(angle) + i*sin(angle)
    freqs_complex = torch.polar(torch.ones_like(angles), angles)

    return freqs_complex


def apply_rope(
    x: torch.Tensor,
    freqs: torch.Tensor,
    start_pos: int = 0,
) -> torch.Tensor:
    """
    Apply rotary position embedding to Q or K tensor.

    Args:
        x: [batch, n_head, seq_len, head_dim] - Q or K tensor
        freqs: [max_seq_len, head_dim//2] - precomputed frequencies
        start_pos: Starting position (for KV cache during generation)

    Returns:
        Rotated tensor with same shape as input
    """
    batch, n_head, seq_len, head_dim = x.shape

    # Get frequencies for this sequence
    # [seq_len, head_dim//2]
    seq_freqs = freqs[start_pos:start_pos + seq_len]

    # Reshape x to pairs: [batch, n_head, seq_len, head_dim//2, 2]
    # We rotate adjacent pairs of dimensions together
    x_pairs = x.float().reshape(batch, n_head, seq_len, -1, 2)

    # Convert to complex: [batch, n_head, seq_len, head_dim//2]
    x_complex = torch.view_as_complex(x_pairs)

    # Reshape freqs for broadcasting: [1, 1, seq_len, head_dim//2]
    seq_freqs = seq_freqs.unsqueeze(0).unsqueeze(0)

    # Rotate by multiplying complex numbers
    x_rotated = x_complex * seq_freqs

    # Convert back to real: [batch, n_head, seq_len, head_dim//2, 2]
    x_out = torch.view_as_real(x_rotated)

    # Flatten back: [batch, n_head, seq_len, head_dim]
    x_out = x_out.reshape(batch, n_head, seq_len, head_dim)

    return x_out.type_as(x)


class RotaryEmbedding(nn.Module):
    """
    Module wrapper for rotary embeddings.

    Precomputes and caches the frequency tensor.
    """

    def __init__(self, dim: int, max_seq_len: int = 4096, theta: float = 10000.0):
        super().__init__()
        self.dim = dim
        self.max_seq_len = max_seq_len
        self.theta = theta

        # Precompute and register as buffer (saved with model but not trained)
        freqs = precompute_freqs(dim, max_seq_len, theta)
        self.register_buffer("freqs", freqs, persistent=False)

    def forward(self, x: torch.Tensor, start_pos: int = 0) -> torch.Tensor:
        """Apply RoPE to input tensor."""
        return apply_rope(x, self.freqs, start_pos)


if __name__ == "__main__":
    print("Testing RoPE...")
    print("=" * 50)

    # Test parameters
    batch, n_head, seq_len, head_dim = 2, 4, 16, 64

    # Create rotary embedding
    rope = RotaryEmbedding(dim=head_dim, max_seq_len=512)

    # Create random Q and K
    q = torch.randn(batch, n_head, seq_len, head_dim)
    k = torch.randn(batch, n_head, seq_len, head_dim)

    print(f"Input shape: {q.shape}")

    # Apply RoPE
    q_rotated = rope(q)
    k_rotated = rope(k)

    print(f"Output shape: {q_rotated.shape}")

    # Verify relative position property
    # Attention at (i, j) should only depend on (i - j), not absolute positions
    print("\nVerifying relative position property...")

    # Use the SAME content vector at every position - with independent random
    # vectors per position, scores at different positions aren't comparable
    q_same = torch.randn(batch, n_head, 1, head_dim).repeat(1, 1, seq_len, 1)
    k_same = torch.randn(batch, n_head, 1, head_dim).repeat(1, 1, seq_len, 1)
    q_rot = rope(q_same)
    k_rot = rope(k_same)

    # Compute attention scores for two pairs with the same relative distance of 1
    attn_0_1 = q_rot[:, :, 0:1, :] @ k_rot[:, :, 1:2, :].transpose(-2, -1)
    attn_5_6 = q_rot[:, :, 5:6, :] @ k_rot[:, :, 6:7, :].transpose(-2, -1)

    # These should match up to floating-point error (same relative distance of 1)
    diff = (attn_0_1 - attn_5_6).abs().mean().item()
    print(f"  Attention (0,1) vs (5,6) difference: {diff:.6f}")
    print(f"  (Should be very small - same relative distance)")

    print("\nRoPE working!")
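Why the check above should pass, in one line: viewing each adjacent dimension pair as a complex number, RoPE multiplies position m by e^{imθ_j}, so the real dot product between a rotated query at position m and a rotated key at position n is (notation mine, following the code above):

```latex
\langle R_m q,\ R_n k \rangle
  = \operatorname{Re}\!\Big[\sum_{j=0}^{d/2-1} q_j\,\overline{k_j}\,e^{\,i(m-n)\theta_j}\Big],
\qquad \theta_j = 10000^{-2j/d},
```

which depends on m and n only through the offset m − n — exactly what the test measures for the pairs (0, 1) and (5, 6).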