# frawdllm-100m/embeddings.py
"""
Token and Position Embeddings for FrawdLLM.
This is the first layer of the model - converts token IDs into vectors
that the transformer can process.
Two lookup tables:
1. Token embeddings: WHAT the token is (vocab_size x n_embd)
2. Position embeddings: WHERE the token is (context_length x n_embd)
Final output = token_emb + pos_emb (just addition!)
"""
import torch
import torch.nn as nn
from .config import ModelConfig
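
# A rough back-of-the-envelope size for these tables, using the shapes noted
# in the comments below (8192-token vocab, 768-dim embeddings, 512 positions):
#   token table:     8192 * 768 = 6,291,456 parameters
#   position table:   512 * 768 =   393,216 parameters (only without RoPE)
# so the embedding layer alone accounts for roughly 6.7M parameters.
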
class Embeddings(nn.Module):
"""
Combined token + position embeddings.
Input: token_ids [batch_size, seq_len] - integers from tokenizer
Output: vectors [batch_size, seq_len, n_embd] - dense representations
"""
def __init__(self, config: ModelConfig):
super().__init__() # Initialize nn.Module tracking
self.config = config
self.use_rope = config.use_rope
# Token embedding table: one vector per vocabulary word
# Shape: [vocab_size, n_embd] = [8192, 768]
self.token_emb = nn.Embedding(config.vocab_size, config.n_embd)
# Position embedding table: one vector per position (only if NOT using RoPE)
# Shape: [context_length, n_embd] = [512, 768]
# With RoPE, position is encoded in attention via rotation instead
if not self.use_rope:
self.pos_emb = nn.Embedding(config.context_length, config.n_embd)
else:
self.pos_emb = None
# Dropout for regularization (usually 0 for small datasets)
        self.dropout = nn.Dropout(config.dropout)

def forward(self, token_ids: torch.Tensor) -> torch.Tensor:
"""
Convert token IDs to embeddings.
Args:
token_ids: [batch_size, seq_len] tensor of token IDs
Returns:
[batch_size, seq_len, n_embd] tensor of embeddings
"""
batch_size, seq_len = token_ids.shape
# Safety check: don't exceed context window (relaxed for RoPE)
max_len = self.config.context_length * 4 if self.use_rope else self.config.context_length
if seq_len > max_len:
raise ValueError(
f"Sequence length {seq_len} exceeds maximum length {max_len}"
)
# Step 1: Look up token embeddings
# [batch_size, seq_len] -> [batch_size, seq_len, n_embd]
embeddings = self.token_emb(token_ids)
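        # nn.Embedding is just a trainable lookup table: row i of its weight
        # matrix is the vector for token id i, so this is an index, not a matmul.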
# Step 2: Add position embeddings (only if NOT using RoPE)
# With RoPE, position is encoded via rotation in attention instead
if not self.use_rope:
positions = torch.arange(seq_len, device=token_ids.device)
pos_emb = self.pos_emb(positions)
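            # pos_emb is [seq_len, n_embd]; adding it to the
            # [batch_size, seq_len, n_embd] token embeddings broadcasts the same
            # positional vector across every sequence in the batch.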
embeddings = embeddings + pos_emb
# Step 3: Apply dropout (if any)
embeddings = self.dropout(embeddings)
        return embeddings


if __name__ == "__main__":
# Quick test to verify it works
from .config import get_config
print("Testing Embeddings...")
print("=" * 50)
# Use tiny config for testing
config = get_config("tiny")
print(f"Config: vocab={config.vocab_size}, n_embd={config.n_embd}, "
f"context={config.context_length}")
# Create embedding layer
emb = Embeddings(config)
# Count parameters
num_params = sum(p.numel() for p in emb.parameters())
print(f"Embedding parameters: {num_params:,}")
# Test forward pass
# Fake batch: 2 sequences of 4 tokens each
token_ids = torch.tensor([
[2, 531, 892, 12], # Sequence 1
[2, 100, 200, 3], # Sequence 2
])
print(f"\nInput shape: {token_ids.shape}")
print(f"Input tokens: {token_ids.tolist()}")
# Forward pass
output = emb(token_ids)
print(f"\nOutput shape: {output.shape}")
print(f"Each token is now a {output.shape[-1]}-dimensional vector")
# Show a snippet of the output
print(f"\nFirst token's vector (first 10 dims):")
print(f" {output[0, 0, :10].tolist()}")
print("\nEmbeddings working!")