"""
Token and Position Embeddings for FrawdLLM.

This is the first layer of the model: it converts token IDs into vectors
that the transformer can process.

Two lookup tables:
1. Token embeddings: WHAT the token is (vocab_size x n_embd)
2. Position embeddings: WHERE the token is (context_length x n_embd)

Final output = token_emb + pos_emb (just addition!)
"""

import torch
import torch.nn as nn

from .config import ModelConfig

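# The idea in isolation (illustrative sketch with toy tensors; not used by the
# model below - the real lookup happens in Embeddings.forward):
#
#   tok_table = torch.randn(vocab_size, n_embd)      # WHAT each token means
#   pos_table = torch.randn(context_length, n_embd)  # WHERE each slot sits
#   x = tok_table[token_ids] + pos_table[torch.arange(seq_len)]

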
class Embeddings(nn.Module): |
    """
    Combined token + position embeddings.

    Input:  token_ids [batch_size, seq_len] - integers from the tokenizer
    Output: vectors [batch_size, seq_len, n_embd] - dense representations
    """

def __init__(self, config: ModelConfig): |
        super().__init__()

        self.config = config
        self.use_rope = config.use_rope

        # Token embedding table: one learned n_embd-dim vector per vocab entry.
        self.token_emb = nn.Embedding(config.vocab_size, config.n_embd)

        # Learned absolute position embeddings are only needed when RoPE is
        # disabled; RoPE injects position information inside attention instead.
        if not self.use_rope:
            self.pos_emb = nn.Embedding(config.context_length, config.n_embd)
        else:
            self.pos_emb = None

        self.dropout = nn.Dropout(config.dropout)

def forward(self, token_ids: torch.Tensor) -> torch.Tensor: |
        """
        Convert token IDs to embeddings.

        Args:
            token_ids: [batch_size, seq_len] tensor of token IDs

        Returns:
            [batch_size, seq_len, n_embd] tensor of embeddings
        """
batch_size, seq_len = token_ids.shape |

        # RoPE can extrapolate somewhat past the training context, so allow up
        # to 4x context_length; learned position embeddings are a hard limit.
        max_len = self.config.context_length * 4 if self.use_rope else self.config.context_length
        if seq_len > max_len:
            raise ValueError(
                f"Sequence length {seq_len} exceeds maximum length {max_len}"
            )

        # Look up each token ID: [batch, seq] -> [batch, seq, n_embd].
        embeddings = self.token_emb(token_ids)

        # Add learned position embeddings when RoPE is off. pos_emb has shape
        # [seq_len, n_embd] and broadcasts across the batch dimension.
        if not self.use_rope:
            positions = torch.arange(seq_len, device=token_ids.device)
            pos_emb = self.pos_emb(positions)
            embeddings = embeddings + pos_emb

        # Dropout on the summed embeddings regularizes training; it is a no-op
        # in eval mode.
        embeddings = self.dropout(embeddings)

        return embeddings


if __name__ == "__main__": |
    from .config import get_config

    print("Testing Embeddings...")
    print("=" * 50)

    config = get_config("tiny")
    print(f"Config: vocab={config.vocab_size}, n_embd={config.n_embd}, "
          f"context={config.context_length}")

    emb = Embeddings(config)

    num_params = sum(p.numel() for p in emb.parameters())
    print(f"Embedding parameters: {num_params:,}")

    # A tiny fake batch: two sequences of four token IDs each.
    token_ids = torch.tensor([
        [2, 531, 892, 12],
        [2, 100, 200, 3],
    ])

    print(f"\nInput shape: {token_ids.shape}")
    print(f"Input tokens: {token_ids.tolist()}")

    output = emb(token_ids)
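
    # The docstring promises [batch_size, seq_len, n_embd]; check that it holds.
    assert output.shape == (*token_ids.shape, config.n_embd)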

    print(f"\nOutput shape: {output.shape}")
    print(f"Each token is now a {output.shape[-1]}-dimensional vector")

    print("\nFirst token's vector (first 10 dims):")
    print(f"  {output[0, 0, :10].tolist()}")

    print("\nEmbeddings working!")