| """ |
| Training Demo: Teach the Transformer to copy sequences. |
| |
| The "copy task" is the classic smoke test for sequence-to-sequence models: |
| Input: [5, 3, 8, 2, 7, EOS, PAD, PAD, PAD] |
| Output: [BOS, 5, 3, 8, 2, 7, EOS, PAD, PAD] |
| |
| If the model learns to copy, it proves that: |
| - Encoder correctly represents the source sequence |
| - Cross-attention correctly attends to source positions |
| - Decoder correctly generates autoregressively |
| - Masking (causal + padding) works correctly |
| - The training loop (optimizer, gradient clipping, padding-masked loss) works |
| |
| We use a small model (d_model=64, 2 layers) so it trains in seconds on CPU. |
| """ |
|
|
| import torch |
| import torch.nn as nn |
| from transformer import Transformer, TransformerLRScheduler, greedy_decode |
|
|
| |
# Reserved special-token ids: padding, beginning-of-sequence, end-of-sequence.
PAD_IDX, BOS_IDX, EOS_IDX = 0, 1, 2

# Total vocabulary size, reserved ids included; the copy task samples its
# content tokens from the ids 3 .. VOCAB_SIZE - 1.
VOCAB_SIZE = 15
|
|
|
|
def generate_copy_batch(batch_size: int, seq_len: int, device: torch.device):
    """
    Generate one random batch for the copy task.

    Every row holds a random-length run of content tokens (ids in
    ``[3, VOCAB_SIZE)``) laid out as:

        Source: [tokens..., EOS, PAD...]
        Target: [BOS, tokens..., EOS, PAD...]

    Args:
        batch_size: Number of sequences in the batch.
        seq_len: Width of both returned tensors. Must be at least 5 so that
            the shortest sample (3 tokens) plus BOS and EOS fits in a
            target row.
        device: Device on which the returned tensors are allocated.

    Returns:
        A ``(src, tgt)`` tuple of ``torch.long`` tensors, each of shape
        ``(batch_size, seq_len)``.

    Raises:
        ValueError: If ``seq_len`` is smaller than 5.
    """
    min_tokens = 3       # shortest content run that is ever sampled
    first_token_id = 3   # ids 0..2 are reserved for PAD / BOS / EOS

    # The target needs room for BOS + tokens + EOS. Without this check the
    # randint call below fails with an opaque "from >= to" error.
    if seq_len < min_tokens + 2:
        raise ValueError(
            f"seq_len must be >= {min_tokens + 2} to fit BOS, "
            f"{min_tokens} tokens and EOS; got {seq_len}"
        )

    # randint's upper bound is exclusive, so lengths lie in [3, seq_len - 2],
    # which keeps the EOS written at index length + 1 inside the target row.
    lengths = torch.randint(min_tokens, seq_len - 1, (batch_size,))

    src = torch.full((batch_size, seq_len), PAD_IDX, dtype=torch.long, device=device)
    tgt = torch.full((batch_size, seq_len), PAD_IDX, dtype=torch.long, device=device)

    for row, n_tokens in enumerate(lengths.tolist()):
        tokens = torch.randint(first_token_id, VOCAB_SIZE, (n_tokens,))

        # Source: the tokens followed by EOS; the rest stays PAD.
        src[row, :n_tokens] = tokens
        src[row, n_tokens] = EOS_IDX

        # Target: the same tokens shifted right by one to make room for BOS.
        tgt[row, 0] = BOS_IDX
        tgt[row, 1:n_tokens + 1] = tokens
        tgt[row, n_tokens + 1] = EOS_IDX

    return src, tgt
|
|
|
|
def train():
    """
    Train a small Transformer on the copy task, then spot-check it.

    Builds the model on CPU and optimizes it for a fixed number of steps on
    freshly generated random batches, using teacher forcing and a loss that
    ignores padding positions. Afterwards it greedy-decodes a handful of new
    samples and prints each result plus an overall copy accuracy.

    All progress is reported via ``print``; nothing is returned.
    """
    device = torch.device('cpu')

    config = {
        'src_vocab_size': VOCAB_SIZE,
        'tgt_vocab_size': VOCAB_SIZE,
        'd_model': 64,
        'n_heads': 4,
        'n_layers': 2,
        'd_ff': 256,
        'dropout': 0.0,
        'max_len': 100,
        'pad_idx': PAD_IDX,
        'tie_weights': True,
    }

    model = Transformer(**config).to(device)
    total_params = sum(p.numel() for p in model.parameters())
    print(f"Model parameters: {total_params:,}")

    optimizer = torch.optim.Adam(
        model.parameters(), lr=1e-3, betas=(0.9, 0.98), eps=1e-9
    )
    # Padding positions contribute nothing to the loss.
    criterion = nn.CrossEntropyLoss(ignore_index=PAD_IDX)

    batch_size = 32
    seq_len = 10
    n_steps = 3000

    print(f"\nTraining copy task for {n_steps} steps...")
    print(f"Batch size: {batch_size}, Seq length: {seq_len}")
    print("-" * 50)

    model.train()
    for step in range(1, n_steps + 1):
        src, tgt = generate_copy_batch(batch_size, seq_len, device)

        # Teacher forcing: the decoder sees tgt[:-1] and predicts tgt[1:].
        tgt_input = tgt[:, :-1]
        tgt_label = tgt[:, 1:]

        logits = model(src, tgt_input)
        loss = criterion(
            logits.reshape(-1, logits.size(-1)),
            tgt_label.reshape(-1),
        )

        optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()

        if step % 300 == 0 or step == 1:
            # Token accuracy on this training batch, counting only
            # non-padding label positions.
            preds = logits.argmax(dim=-1)
            mask = tgt_label != PAD_IDX
            correct = ((preds == tgt_label) & mask).sum().item()
            total = mask.sum().item()
            acc = correct / total * 100
            lr = optimizer.param_groups[0]['lr']
            print(f"Step {step:>5d} | Loss: {loss.item():.4f} | Acc: {acc:.1f}% | LR: {lr:.6f}")

    print("\n" + "=" * 50)
    print("EVALUATION: Greedy Decode Examples")
    print("=" * 50)

    model.eval()
    n_correct = 0
    n_total = 10

    # Decoding is inference only: disable autograd so no graph is recorded.
    with torch.no_grad():
        for _ in range(n_total):
            src, _ = generate_copy_batch(1, seq_len, device)

            # Content tokens of the single source row, up to the first EOS/PAD.
            src_tokens = []
            for tok in src[0].tolist():
                if tok in (PAD_IDX, EOS_IDX):
                    break
                src_tokens.append(tok)

            # NOTE(review): assumes greedy_decode returns a 1-D token sequence
            # whose first element is BOS -- confirm against its definition.
            decoded = greedy_decode(model, src, max_len=seq_len + 2, bos_idx=BOS_IDX, eos_idx=EOS_IDX)
            decoded_tokens = []
            for t in decoded[1:]:
                if t.item() == EOS_IDX:
                    break
                decoded_tokens.append(t.item())

            match = src_tokens == decoded_tokens
            n_correct += int(match)

            status = "✅" if match else "❌"
            print(f" {status} Source: {src_tokens}")
            print(f" Output: {decoded_tokens}")
            if not match:
                print(f" Expected: {src_tokens}")
            print()

    print(f"Copy accuracy: {n_correct}/{n_total} ({n_correct/n_total*100:.0f}%)")

    if n_correct >= 8:
        print("\n🎉 The Transformer learned the copy task! Model is working correctly.")
    else:
        print("\n⚠️ Model didn't fully converge. Try more steps or adjust hyperparameters.")
|
|
|
|
# Script entry point: run the demo only when executed directly, not on import.
if __name__ == "__main__":
    train()
|
|