transformer-from-scratch / train_copy_task.py
syedmohaiminulhoque's picture
Add copy task training demo
10835d2 verified
"""
Training Demo: Teach the Transformer to copy sequences.
The "copy task" is the classic smoke test for sequence-to-sequence models:
Input: [5, 3, 8, 2, 7, EOS, PAD, PAD, PAD]
Output: [BOS, 5, 3, 8, 2, 7, EOS, PAD, PAD]
If the model learns to copy, it proves that:
- Encoder correctly represents the source sequence
- Cross-attention correctly attends to source positions
- Decoder correctly generates autoregressively
- Masking (causal + padding) works correctly
- The training loop (Adam optimizer, gradient clipping, cross-entropy loss) works
We use a small model (d_model=64, 2 layers) so it trains in seconds on CPU.
"""
import torch
import torch.nn as nn
from transformer import Transformer, TransformerLRScheduler, greedy_decode
# Special tokens
# PAD_IDX fills unused positions in the generated batches; it is also passed
# to the model as ``pad_idx`` and to the loss as ``ignore_index``.
PAD_IDX = 0
# BOS_IDX is written at position 0 of every target row and handed to
# greedy_decode as ``bos_idx``.
BOS_IDX = 1
# EOS_IDX is written right after the payload tokens in both source and target
# rows and handed to greedy_decode as ``eos_idx``.
EOS_IDX = 2
# Used both as the model's source/target vocabulary size and as the exclusive
# upper bound when sampling payload tokens.
VOCAB_SIZE = 15 # Small vocabulary: tokens 3..14 are "real" tokens
def generate_copy_batch(batch_size: int, seq_len: int, device: torch.device):
    """
    Generate a batch for the copy task.

    Source: [random tokens, EOS, PAD...]
    Target: [BOS, same random tokens, EOS, PAD...]

    Every row draws its own payload length, between 3 and ``seq_len - 2``
    tokens inclusive, so the target row always has room for BOS, the payload
    and EOS. Remaining positions are filled with PAD_IDX.

    Args:
        batch_size: Number of (source, target) rows to generate.
        seq_len: Width of both returned tensors; must be at least 5.
        device: Device on which the returned tensors are allocated.

    Returns:
        A ``(src, tgt)`` tuple of ``torch.long`` tensors, each of shape
        ``(batch_size, seq_len)``.

    Raises:
        ValueError: If ``seq_len`` is too small to hold BOS, a minimum-length
            payload and EOS.
    """
    min_payload = 3       # shortest payload a row may carry
    first_real_token = 3  # ids 0..2 are PAD/BOS/EOS

    # Without this check torch.randint would be asked for an empty range
    # [3, seq_len - 1) and fail with an opaque RuntimeError.
    if seq_len < min_payload + 2:
        raise ValueError(
            f"seq_len must be at least {min_payload + 2} to fit BOS, "
            f"{min_payload} tokens and EOS; got {seq_len}"
        )

    # The upper bound of torch.randint is exclusive, so lengths fall in
    # [min_payload, seq_len - 2].
    lengths = torch.randint(min_payload, seq_len - 1, (batch_size,))
    src = torch.full((batch_size, seq_len), PAD_IDX, dtype=torch.long, device=device)
    tgt = torch.full((batch_size, seq_len), PAD_IDX, dtype=torch.long, device=device)

    for row, length in enumerate(lengths.tolist()):
        tokens = torch.randint(first_real_token, VOCAB_SIZE, (length,))

        # Source row: payload followed by EOS.
        src[row, :length] = tokens
        src[row, length] = EOS_IDX

        # Target row: the same payload shifted right by one to make room for BOS.
        tgt[row, 0] = BOS_IDX
        tgt[row, 1:length + 1] = tokens
        tgt[row, length + 1] = EOS_IDX

    return src, tgt
def train():
    """
    Train a small Transformer on the copy task, then spot-check it.

    The function builds the model on CPU, runs a fixed number of Adam steps
    on freshly generated copy-task batches (teacher forcing, cross-entropy
    loss that ignores PAD positions, gradient-norm clipping at 1.0), prints
    loss/accuracy periodically, and finally greedy-decodes ten random
    examples and prints how many were copied exactly.

    Returns:
        None. All results are reported via ``print``.
    """
    device = torch.device('cpu')

    config = {
        'src_vocab_size': VOCAB_SIZE,
        'tgt_vocab_size': VOCAB_SIZE,
        'd_model': 64,
        'n_heads': 4,
        'n_layers': 2,
        'd_ff': 256,
        'dropout': 0.0,
        'max_len': 100,
        'pad_idx': PAD_IDX,
        'tie_weights': True,
    }
    model = Transformer(**config).to(device)

    total_params = sum(p.numel() for p in model.parameters())
    print(f"Model parameters: {total_params:,}")

    optimizer = torch.optim.Adam(
        model.parameters(), lr=1e-3, betas=(0.9, 0.98), eps=1e-9
    )
    # PAD positions carry no information, so they are excluded from the loss.
    criterion = nn.CrossEntropyLoss(ignore_index=PAD_IDX)

    batch_size = 32
    seq_len = 10
    n_steps = 3000

    print(f"\nTraining copy task for {n_steps} steps...")
    print(f"Batch size: {batch_size}, Seq length: {seq_len}")
    print("-" * 50)

    model.train()
    for step in range(1, n_steps + 1):
        src, tgt = generate_copy_batch(batch_size, seq_len, device)

        # Teacher forcing: the decoder sees tgt[:-1] and is trained to
        # predict tgt[1:], i.e. the same sequence shifted left by one.
        tgt_input = tgt[:, :-1]
        tgt_label = tgt[:, 1:]

        logits = model(src, tgt_input)
        # Flatten every position into one row so CrossEntropyLoss sees
        # (N, num_classes) logits against (N,) labels.
        loss = criterion(
            logits.reshape(-1, logits.size(-1)),
            tgt_label.reshape(-1),
        )

        optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()

        if step % 300 == 0 or step == 1:
            # Token-level accuracy on the current batch, PAD positions excluded.
            preds = logits.argmax(dim=-1)
            mask = tgt_label != PAD_IDX
            correct = ((preds == tgt_label) & mask).sum().item()
            total = mask.sum().item()
            acc = correct / total * 100
            lr = optimizer.param_groups[0]['lr']
            print(f"Step {step:>5d} | Loss: {loss.item():.4f} | Acc: {acc:.1f}% | LR: {lr:.6f}")

    print("\n" + "=" * 50)
    print("EVALUATION: Greedy Decode Examples")
    print("=" * 50)

    model.eval()
    n_correct = 0
    n_total = 10
    # Fraction of exact copies required to call the run a success
    # (8 out of the 10 examples decoded here).
    required_ratio = 0.8

    # Decoding is inference only; disabling autograd avoids building graphs
    # that would never be back-propagated through.
    with torch.no_grad():
        for _ in range(n_total):
            src, _unused_tgt = generate_copy_batch(1, seq_len, device)

            # Payload of the single source row: everything before EOS/PAD.
            src_tokens = []
            for t in src[0]:
                if t.item() in (PAD_IDX, EOS_IDX):
                    break
                src_tokens.append(t.item())

            decoded = greedy_decode(model, src, max_len=seq_len + 2, bos_idx=BOS_IDX, eos_idx=EOS_IDX)

            # NOTE(review): assumes greedy_decode returns a 1-D token tensor
            # whose first element is BOS -- confirm against its implementation.
            decoded_tokens = []
            for t in decoded[1:]:
                if t.item() == EOS_IDX:
                    break
                decoded_tokens.append(t.item())

            match = src_tokens == decoded_tokens
            n_correct += int(match)
            status = "✅" if match else "❌"
            print(f" {status} Source: {src_tokens}")
            print(f" Output: {decoded_tokens}")
            if not match:
                print(f" Expected: {src_tokens}")
            print()

    print(f"Copy accuracy: {n_correct}/{n_total} ({n_correct/n_total*100:.0f}%)")
    if n_correct / n_total >= required_ratio:
        print("\n🎉 The Transformer learned the copy task! Model is working correctly.")
    else:
        print("\n⚠️ Model didn't fully converge. Try more steps or adjust hyperparameters.")
# Script entry point: run the demo only when executed directly, not on import.
if __name__ == "__main__":
    train()