# SimpleTransformer / transformer.py
# Uploaded to Hugging Face by eshangj ("Upload folder using huggingface_hub", commit 8066cf3, verified)
"""
Author: Eshan Jayasundara
Last Updated: 2nd of March 2025
Created: 28th of February 2025
___
About:
└── Single-head encoder-decoder transformer (self-attention, masked self-attention and cross-attention), trained with teacher forcing
___
Training:
└── Teacher Forcing (Baseline)
    ├── During training, the actual ground-truth tokens (from the dataset) are fed as input to the decoder instead of the model's own predictions.
    ├── This makes training faster and more stable, because the model always learns from correct token-to-token context.
    └── Drawback: at inference time the model never sees ground-truth inputs, so its errors can accumulate (known as exposure bias).
___
vocabulary dataset (from huggingface):
└── "yukiarimo/english-vocabulary"
___
Architecture:
Encoder
├── Input text
│   └── Eg: "Hello, how are you?"
├── Remove punctuation from input text
├── Input tokenization
├── Embedding lookup with torch.nn.Embedding
├── Positional encoding (sine, cosine)
├── Self-attention
│   ├── single-head
│   ├── Q = Wq @ Embedding
│   ├── K = Wk @ Embedding
│   └── V = Wv @ Embedding
├── Add and norm
├── Feed forward layer
│   ├── 2 linear layers (one hidden layer of width d_ff)
│   ├── ReLU as the activation in the hidden layer
│   ├── No activation at the output layer
│   └── nn.Linear(in_features=embedding_dim, out_features=d_ff), nn.ReLU(), nn.Linear(in_features=d_ff, out_features=embedding_dim)
├── Add and norm (again)
└── Save encoder output to be used in cross-attention
Decoder
├── Decoder teacher text (same as the target text but shifted right)
│   ├── Eg: Decoder teacher text - "<SOS> hello, I'm fine."
│   └── Eg: target text - "hello, I'm fine. <EOS>"
├── Remove punctuation from target text
├── Input tokenization
├── Embedding lookup with torch.nn.Embedding
├── Positional encoding (sine, cosine)
├── Masked self-attention (implemented by a separate class, DecoderMaskedAttention)
│   ├── single-head
│   ├── causal mask built from an upper-triangular matrix
│   ├── Q = Wq @ Embedding
│   ├── K = Wk @ Embedding
│   └── V = Wv @ Embedding
├── Add and norm
├── Cross-attention (reuses the SingleHeadAttention class from the encoder self-attention)
│   ├── single-head
│   ├── Q = Wq @ add-and-normalized output of masked self-attention
│   ├── K = Wk @ encoder output
│   └── V = Wv @ encoder output
├── Add and norm
├── Feed forward layer
│   ├── 2 linear layers (one hidden layer of width d_ff)
│   ├── ReLU as the activation in the hidden layer
│   ├── No activation at the output layer
│   └── nn.Linear(in_features=embedding_dim, out_features=d_ff), nn.ReLU(), nn.Linear(in_features=d_ff, out_features=embedding_dim)
├── Add and norm (again)
└── Linear layer producing raw logits (unlike 'Attention Is All You Need', no softmax is applied here, because F.cross_entropy applies it internally)
Optimization
├── Initialize the Adam optimizer with the model's parameters and a specified learning rate.
│   └── self.optimizer = torch.optim.Adam(params=self.parameters, lr=learning_rate)
├── Before computing gradients for the current example, reset any gradients left over from the previous iteration.
│   └── self.optimizer.zero_grad()
├── The model takes in `input_tokens` and `decoder_teacher_tokens` and performs a forward pass to compute `logits`
│   └── logits = self.forward(input_tokens, decoder_teacher_tokens)
├── The cross-entropy loss
│   ├── Measures the difference between the predicted token distribution (logits) and the actual target tokens (decoder_target_tokens).
│   ├── It expects logits to be raw scores (not probabilities) and applies log-softmax internally.
│   └── loss = F.cross_entropy(logits, decoder_target_tokens)
├── Compute the gradients of the loss with respect to all trainable parameters using automatic differentiation (backpropagation).
│   └── loss.backward()
└── The optimizer updates the model's weights using the computed gradients.
    └── self.optimizer.step()
After training, output tokens (and then text) are produced by autoregressive text generation, one word at a time:
├── Start the decoder with <SOS>; the encoder's input is the `prompt`.
├── The model predicts the next token.
├── Append the predicted token to the sequence.
├── Repeat until an <EOS> token is produced or the maximum length is reached.
└── For illustration, words are shown below instead of tokens (their numerical representation)
<SOS>
<SOS> hello
<SOS> hello I'm
<SOS> hello I'm good
<SOS> hello I'm good <EOS>
___
Future Improvements:
├── Multi-head attention instead of single-head attention.
├── Layer normalization instead of simple mean-variance normalization.
└── Dropout layers for better generalization.
"""
from datasets import load_dataset
import torch
import torch.nn as nn
import string
import torch.nn.functional as F
# SELECT DEVICE
# Prefer the second GPU ('cuda:1'), as before, but fall back to 'cuda:0' on
# single-GPU machines, where 'cuda:1' does not exist and every allocation on it
# would fail at runtime.
if torch.cuda.is_available():
    _gpu_index = 1 if torch.cuda.device_count() > 1 else 0
    device = torch.device(f"cuda:{_gpu_index}")
    # Report the name of the GPU actually selected (previously GPU 0's name was
    # always printed, even when 'cuda:1' was chosen).
    print(f"Using Device: {device} | Name: {torch.cuda.get_device_name(_gpu_index)}")
else:
    device = torch.device('cpu')
    print(f"Using Device: {device}")
# SINGLE HEAD ATTENTION
class SingleHeadAttention(torch.nn.Module):
    """Single-head scaled dot-product attention with learned Q/K/V projections.

    Used for encoder self-attention (q = k = v) and for decoder cross-attention
    (q from the decoder, k = v from the encoder output).

    Args:
        embedding_dim: size of the last dimension of the inputs. Each Q/K/V
            projection maps embedding_dim -> embedding_dim, and scores are
            scaled by sqrt(embedding_dim).
    """

    def __init__(self, embedding_dim):
        super().__init__()
        self.embedding_dim = embedding_dim
        self.query_layer = torch.nn.Linear(in_features=embedding_dim, out_features=embedding_dim)
        self.key_layer = torch.nn.Linear(in_features=embedding_dim, out_features=embedding_dim)
        self.value_layer = torch.nn.Linear(in_features=embedding_dim, out_features=embedding_dim)

    def forward(self, q_embedding, k_embedding, v_embedding, attention_mask=None):
        """Attend from `q_embedding` over `k_embedding` / `v_embedding`.

        Args:
            q_embedding: queries, (..., q_len, embedding_dim).
            k_embedding: keys, (..., k_len, embedding_dim).
            v_embedding: values, (..., k_len, embedding_dim).
            attention_mask: optional tensor broadcastable to (..., q_len, k_len).
                Positions where it equals 0 are excluded from attention.
                Defaults to None (attend everywhere), matching
                DecoderMaskedAttention.

        Returns:
            (attention_output, attention_weights):
                attention_output: (..., q_len, embedding_dim), in the dtype of
                    the projected values.
                attention_weights: (..., q_len, k_len), in float32.
        """
        Q = self.query_layer(q_embedding)
        K = self.key_layer(k_embedding)
        V = self.value_layer(v_embedding)
        # Scores and softmax are computed in float32 for numerical stability.
        attention_scores = (torch.matmul(Q, K.transpose(-2, -1)) / self.embedding_dim ** 0.5).float()
        if attention_mask is not None:
            # Masked positions get the most negative float, so softmax sends them to 0.
            attention_scores = attention_scores.masked_fill(
                attention_mask == 0, torch.finfo(attention_scores.dtype).min
            )
        # Normalise over the key axis (last dimension).
        attention_weights = F.softmax(attention_scores, dim=-1)
        # Cast the float32 weights back to V's dtype. Without the cast, matmul
        # raises a dtype mismatch for float64/half inputs.
        attention_output = torch.matmul(attention_weights.to(V.dtype), V)
        return attention_output, attention_weights
# FEED FORWARD NN
class FeedForwardLayer(torch.nn.Module):
    """Position-wise feed-forward network: Linear -> ReLU -> Linear.

    Maps the last dimension embedding_dim -> d_ff -> embedding_dim. No
    activation is applied to the output.
    """

    def __init__(self, embedding_dim=64, d_ff=256):
        super().__init__()
        # Creation order (fc1, fc2, activation) is kept so seeded initialisation is unchanged.
        self.fc1 = torch.nn.Linear(embedding_dim, d_ff)
        self.fc2 = torch.nn.Linear(d_ff, embedding_dim)
        self.activation = torch.nn.ReLU()

    def forward(self, x):
        """Expand `x` to d_ff, apply ReLU, then project back to embedding_dim."""
        expanded = self.fc1.forward(x)
        activated = self.activation(expanded)
        projected = self.fc2.forward(activated)
        return projected
# MASKED ATTENTION
class DecoderMaskedAttention(nn.Module):
    """Single-head causal (masked) self-attention for the decoder.

    Query position i may only attend to key positions j <= i. An optional
    extra `attention_mask` can exclude further positions.

    Args:
        embedding_dim: size of the last dimension of the inputs. Each Q/K/V
            projection maps embedding_dim -> embedding_dim, and scores are
            scaled by sqrt(embedding_dim).
    """

    def __init__(self, embedding_dim):
        super().__init__()
        self.embedding_dim = embedding_dim
        self.query_layer = nn.Linear(in_features=embedding_dim, out_features=embedding_dim)
        self.key_layer = nn.Linear(in_features=embedding_dim, out_features=embedding_dim)
        self.value_layer = nn.Linear(in_features=embedding_dim, out_features=embedding_dim)

    def forward(self, q_embedding, k_embedding, v_embedding, attention_mask=None):
        """Causal attention from `q_embedding` over `k_embedding` / `v_embedding`.

        Args:
            q_embedding: (..., q_len, embedding_dim). Both unbatched (2-D) and
                batched (3-D) inputs work.
            k_embedding: (..., k_len, embedding_dim).
            v_embedding: (..., k_len, embedding_dim).
            attention_mask: optional tensor broadcastable to (..., q_len, k_len).
                Positions where it equals 0 are excluded in addition to the
                causal mask.

        Returns:
            (attention_output, attention_weights):
                attention_output: (..., q_len, embedding_dim).
                attention_weights: (..., q_len, k_len).
        """
        Q = self.query_layer(q_embedding)
        K = self.key_layer(k_embedding)
        V = self.value_layer(v_embedding)
        attention_scores = torch.matmul(Q, K.transpose(-2, -1)) / (self.embedding_dim ** 0.5)  # (..., q_len, k_len)
        # Size the causal mask from the two sequence axes of the scores. The
        # original read q_embedding.shape[0], which is the batch size for
        # batched input. Build the mask on the scores' own device rather than a
        # module-level global.
        q_len, k_len = attention_scores.shape[-2], attention_scores.shape[-1]
        causal_mask = torch.triu(
            torch.ones(q_len, k_len, device=attention_scores.device), diagonal=1
        ).bool()  # True strictly above the diagonal, i.e. future positions
        fill_value = torch.finfo(attention_scores.dtype).min
        attention_scores = attention_scores.masked_fill(causal_mask, fill_value)
        if attention_mask is not None:
            attention_scores = attention_scores.masked_fill(attention_mask == 0, fill_value)
        # Normalise over the key axis (last dimension).
        attention_weights = F.softmax(attention_scores, dim=-1)
        attention_output = torch.matmul(attention_weights, V)
        return attention_output, attention_weights
class Transformer(torch.nn.Module):
    """Single-head, one-layer encoder-decoder transformer trained with teacher forcing.

    Works on one unbatched example at a time. Token tensors are 1-D, and the
    decoder returns logits of shape (seq_len, vocab_size). The model owns an
    Adam optimizer built over all of its registered parameters.

    Args:
        embedding_dim: width of the token embeddings and of every attention
            and feed-forward layer.
        learning_rate: Adam learning rate.
        vocab_dataset: Hugging Face dataset whose "text" column provides the
            vocabulary.
        split: dataset split to load.
    """

    def __init__(self, embedding_dim, learning_rate=1e-3, vocab_dataset="yukiarimo/english-vocabulary", split="train"):
        super().__init__()
        self.embedding_dim = embedding_dim
        # SETUP VOCABULARY
        self.vocab_df = load_dataset(vocab_dataset, split=split).to_pandas()
        remove_indices = self.vocab_df[(self.vocab_df["text"] == 'PAD') | (self.vocab_df["text"] == 'SOS') | (self.vocab_df["text"] == 'EOS')].index
        self.vocab_df = self.vocab_df.drop(remove_indices, axis=0)
        # Write the special tokens into the rows labelled 0..3.
        # NOTE(review): if any of labels 0..3 was dropped just above, .loc enlarges
        # the frame and appends that special token as a new last row. The special
        # tokens are therefore not guaranteed to receive idx 0..3. Any word that
        # occupied one of these rows is overwritten. Confirm against the dataset.
        self.vocab_df.loc[0, "text"] = '<PAD>'
        self.vocab_df.loc[1, "text"] = '<UNK>'
        self.vocab_df.loc[2, "text"] = '<SOS>'
        self.vocab_df.loc[3, "text"] = '<EOS>'
        self.vocab_size = self.vocab_df.shape[0]
        # Indices follow the current row order of the frame.
        self.vocab_df['idx'] = range(0, self.vocab_size)
        self.vocab_df = self.vocab_df.set_index("text")
        self.vocab = self.vocab_df["idx"].to_dict()  # word -> index
        # INITIALIZE ALL TRAINABLE MODELS (creation order kept, so seeded init is unchanged)
        self.embedding_fn = nn.Embedding(num_embeddings=self.vocab_size, embedding_dim=embedding_dim)
        self.encoder_self_attention = SingleHeadAttention(embedding_dim=embedding_dim)
        self.encoder_ff = FeedForwardLayer(embedding_dim=embedding_dim, d_ff=embedding_dim * 4)
        self.cross_attention = SingleHeadAttention(embedding_dim=embedding_dim)
        self.decoder_masked_attention = DecoderMaskedAttention(embedding_dim=embedding_dim)
        self.decoder_ff = FeedForwardLayer(embedding_dim=embedding_dim, d_ff=embedding_dim * 4)
        self.linear = nn.Linear(in_features=embedding_dim, out_features=self.vocab_size)
        # OPTIMIZER
        # nn.Module.parameters() already yields every registered submodule's
        # parameters, in the registration order above. The original stored them
        # in `self.parameters`, which shadowed the nn.Module method and broke
        # `model.parameters()`.
        self.optimizer = torch.optim.Adam(params=self.parameters(), lr=learning_rate)

    def _model_device(self):
        """Return the device holding the model weights; new input tensors are created there."""
        return self.embedding_fn.weight.device

    # INPUT TEXT HANDLING
    def remove_punctuation(self, text):
        """Return `text` with every character in string.punctuation removed."""
        return text.translate(str.maketrans("", "", string.punctuation))

    def tokenize(self, text, unk_token="<UNK>"):
        """Split `text` on whitespace and map each word to its vocabulary index.

        Words missing from the vocabulary map to the index of `unk_token`.
        Returns a 1-D int64 tensor on the model's device (empty for blank text).
        """
        unk_index = self.vocab.get(unk_token)
        indices = [self.vocab.get(token, unk_index) for token in text.strip().split()]
        # Explicit dtype: torch.tensor([]) would otherwise be float32, which nn.Embedding rejects.
        return torch.tensor(indices, dtype=torch.long, device=self._model_device())

    def positional_encoding(self, embedding, max_len, embedding_dim=None):
        """Add sinusoidal positional encodings to `embedding`.

        Args:
            embedding: tensor whose last two dimensions are (max_len, embedding_dim).
            max_len: number of positions to encode.
            embedding_dim: encoding width. Defaults to embedding.shape[-1].
                Previously the callers hard-coded 64, which broke any other width.

        Returns:
            embedding + PE, where PE[pos, 2i] = sin(pos / 10000^(2i/d)) and
            PE[pos, 2i+1] = cos(pos / 10000^(2i/d)).
        """
        if embedding_dim is None:
            embedding_dim = embedding.shape[-1]
        dev = embedding.device
        pe = torch.zeros(max_len, embedding_dim, device=dev)
        position = torch.arange(0, max_len, dtype=torch.float, device=dev).unsqueeze(1)
        # div_term[i] = 10000^(2i/d); dividing by it gives the standard frequencies.
        div_term = torch.exp(torch.arange(0, embedding_dim, 2, device=dev).float() * (torch.log(torch.tensor(10000.0, device=dev))) / embedding_dim)
        pe[:, 0::2] = torch.sin(position / div_term)  # even dimensions
        # An odd embedding_dim has one fewer odd column than even column, so trim div_term.
        pe[:, 1::2] = torch.cos(position / div_term[: embedding_dim // 2])  # odd dimensions
        return embedding + pe

    # ADD AND NORM
    def add_norm(self, old_tensor, new_tensor, eps=1e-6):
        """Residual add, then normalise over the last dimension to zero mean and unit (sample) std.

        `eps` keeps a constant vector (std == 0) from producing NaN.
        """
        addition = old_tensor + new_tensor
        mean = addition.mean(dim=-1, keepdim=True)
        std = addition.std(dim=-1, keepdim=True)
        return (addition - mean) / (std + eps)

    # ENCODER
    def encoder(self, encoder_input_tokens):
        """Run the encoder layer and return its output for cross-attention.

        Given 1-D tokens, the output has shape (seq_len, embedding_dim).
        """
        embeddings = self.embedding_fn(encoder_input_tokens)
        x = self.positional_encoding(embeddings, max_len=embeddings.shape[-2])
        attention_out, _ = self.encoder_self_attention(
            q_embedding=x, k_embedding=x, v_embedding=x, attention_mask=None
        )
        x = self.add_norm(old_tensor=x, new_tensor=attention_out)
        ff_out = self.encoder_ff(x)
        return self.add_norm(old_tensor=x, new_tensor=ff_out)

    # DECODER
    def decoder(self, decoder_teacher_tokens, encoder_out):
        """Run the decoder layer over the teacher tokens and return raw vocabulary logits."""
        embeddings = self.embedding_fn(decoder_teacher_tokens)
        x = self.positional_encoding(embeddings, max_len=embeddings.shape[-2])
        masked_out, _ = self.decoder_masked_attention(
            q_embedding=x, k_embedding=x, v_embedding=x, attention_mask=None
        )
        x = self.add_norm(old_tensor=x, new_tensor=masked_out)
        # Cross-attention: decoder queries attend over the encoder output.
        cross_out, _ = self.cross_attention(
            q_embedding=x, k_embedding=encoder_out, v_embedding=encoder_out, attention_mask=None
        )
        x = self.add_norm(old_tensor=x, new_tensor=cross_out)
        ff_out = self.decoder_ff(x)
        x = self.add_norm(old_tensor=x, new_tensor=ff_out)
        # No softmax: F.cross_entropy expects raw logits.
        return self.linear(x)

    # FORWARD PASS THROUGH ENCODER and DECODER
    def forward(self, encoder_input_tokens, decoder_teacher_tokens):
        """Encode the input tokens, then decode the teacher tokens; returns logits."""
        encoder_out = self.encoder(encoder_input_tokens)
        return self.decoder(decoder_teacher_tokens, encoder_out=encoder_out)

    # TRAIN the TRANSFORMER
    def train(self, dataset=True, epochs=100):
        """Train with teacher forcing on (input_text, output_text) pairs, one pair per step.

        This method overrides nn.Module.train(mode). nn.Module.eval() calls
        self.train(False), and model.train() is the standard way to enter
        training mode. A boolean argument is therefore forwarded to the base
        implementation, which sets the mode and returns self.

        Args:
            dataset: iterable of (input_text, output_text) string pairs, or a
                bool to set the module's training mode.
            epochs: number of passes over `dataset`. The summed loss is printed
                every 10 epochs.
        """
        if isinstance(dataset, bool):
            return super().train(dataset)
        super().train(True)
        for epoch in range(epochs):
            total_loss = 0
            for input_text, output_text in dataset:
                encoder_input_text = self.remove_punctuation(input_text)
                target_text = self.remove_punctuation(output_text)
                # Teacher input is the target shifted right by <SOS>; the target ends with <EOS>.
                decoder_teacher_text = "<SOS> " + target_text
                decoder_target_text = target_text + " <EOS>"
                encoder_input_tokens = self.tokenize(encoder_input_text)
                decoder_teacher_tokens = self.tokenize(decoder_teacher_text)
                decoder_target_tokens = self.tokenize(decoder_target_text)
                self.optimizer.zero_grad()
                logits = self(encoder_input_tokens=encoder_input_tokens, decoder_teacher_tokens=decoder_teacher_tokens)
                loss = F.cross_entropy(logits, decoder_target_tokens)
                loss.backward()
                self.optimizer.step()
                total_loss += loss.item()
            if (epoch+1) % 10 == 0:
                print(f"Epoch {epoch+1:04d} - Loss: {total_loss:.4f}")
        print("*** END ***\n")

    # GET PREDICTED TOKENS
    def predict_tokens(self, encoder_input_tokens, max_output_len=20):
        """Greedy autoregressive decoding starting from <SOS>.

        Stops after <EOS> or `max_output_len` generated tokens. Returns the list
        of token indices, including the leading <SOS>.
        """
        sos_index = self.vocab["<SOS>"]
        eos_index = self.vocab["<EOS>"]
        # Inference only: skip building autograd graphs.
        with torch.no_grad():
            encoder_out = self.encoder(encoder_input_tokens)
            decoder_input = [sos_index]
            for _ in range(max_output_len):
                current_decoder_tokens = torch.tensor(decoder_input, dtype=torch.long, device=encoder_out.device)
                logits = self.decoder(current_decoder_tokens, encoder_out)
                # Greedy choice from the logits at the last position.
                pred_index = torch.argmax(logits[-1, :]).item()
                decoder_input.append(pred_index)
                if pred_index == eos_index:
                    break
        return decoder_input

    # GET PREDICTED TEXT
    def predict_text(self, encoder_input_tokens):
        """Decode greedily and join the predicted words with spaces."""
        # Build index -> word once per call, instead of scanning the DataFrame for every token.
        index_to_text = {int(idx): text for idx, text in zip(self.vocab_df["idx"], self.vocab_df.index)}
        return ' '.join(
            index_to_text[token]
            for token in self.predict_tokens(encoder_input_tokens=encoder_input_tokens)
        )