vigneshwar234 committed on
Commit
93d8b89
·
verified ·
1 Parent(s): a25d929

Add source: tmt/experiments/01_baseline.ipynb

Browse files
Files changed (1) hide show
  1. tmt/experiments/01_baseline.ipynb +110 -0
tmt/experiments/01_baseline.ipynb ADDED
@@ -0,0 +1,110 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "markdown",
5
+ "metadata": {},
6
+ "source": ["# Experiment 01 — Vanilla Transformer Baseline\n", "Train a standard transformer on the same data as TMT for a fair comparison."]
7
+ },
8
+ {
9
+ "cell_type": "code",
10
+ "execution_count": null,
11
+ "metadata": {},
12
+ "outputs": [],
13
+ "source": [
14
+ "import torch\n",
15
+ "import torch.nn as nn\n",
16
+ "from torch.optim import AdamW\n",
17
+ "from tmt.data.dataset import load_text_dataset\n",
18
+ "from tmt.training.scheduler import cosine_warmup_scheduler\n",
19
+ "import math\n",
20
+ "\n",
21
+ "DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'\n",
22
+ "SEQ_LEN = 256\n",
23
+ "BATCH_SIZE = 16\n",
24
+ "\n",
25
+ "loaders = load_text_dataset('wikitext-2', seq_len=SEQ_LEN, batch_size=BATCH_SIZE)\n",
26
+ "print('Train batches:', len(loaders['train']))"
27
+ ]
28
+ },
29
+ {
30
+ "cell_type": "code",
31
+ "execution_count": null,
32
+ "metadata": {},
33
+ "outputs": [],
34
+ "source": [
35
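+ "# Standard encoder-decoder transformer — same param budget as TMT (parameter-count reference only; `baseline` is rebound to BaselineGPT in the next cell)\n",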
36
+ "baseline = nn.Transformer(\n",
37
+ " d_model=512, nhead=8, num_encoder_layers=6, num_decoder_layers=6,\n",
38
+ " dim_feedforward=2048, dropout=0.1, batch_first=True\n",
39
+ ").to(DEVICE)\n",
40
+ "print(f'Baseline params: {sum(p.numel() for p in baseline.parameters())/1e6:.2f}M')"
41
+ ]
42
+ },
43
+ {
44
+ "cell_type": "code",
45
+ "execution_count": null,
46
+ "metadata": {},
47
+ "outputs": [],
48
+ "source": [
49
+ "# Simple GPT-style decoder-only baseline, built from nn.TransformerEncoder layers with a causal mask\n",
50
+ "vocab_size = 50258 # gpt2 tokenizer size; stock GPT-2 has 50257 tokens, so this presumably counts one added special token — TODO confirm\n",
51
+ "\n",
52
+ "class BaselineGPT(nn.Module):\n",
53
+ " def __init__(self, vocab=vocab_size, d=512, heads=8, layers=6, seq=256):\n",
54
+ " super().__init__()\n",
55
+ " self.embed = nn.Embedding(vocab, d)\n",
56
+ " self.pos = nn.Embedding(seq, d)\n",
57
+ " layer = nn.TransformerEncoderLayer(d, heads, dim_feedforward=2048, batch_first=True)\n",
58
+ " self.transformer = nn.TransformerEncoder(layer, num_layers=layers)\n",
59
+ " self.proj = nn.Linear(d, vocab)\n",
60
+ " self.proj.weight = self.embed.weight\n",
61
+ "\n",
62
+ " def forward(self, x):\n",
63
+ " B, S = x.shape\n",
64
+ " pos = torch.arange(S, device=x.device).unsqueeze(0)\n",
65
+ " h = self.embed(x) + self.pos(pos)\n",
66
+ " mask = nn.Transformer.generate_square_subsequent_mask(S, device=x.device)\n",
67
+ " h = self.transformer(h, mask=mask, is_causal=True)\n",
68
+ " return self.proj(h)\n",
69
+ "\n",
70
+ "baseline = BaselineGPT().to(DEVICE)\n",
71
+ "print(f'BaselineGPT params: {sum(p.numel() for p in baseline.parameters())/1e6:.2f}M')"
72
+ ]
73
+ },
74
+ {
75
+ "cell_type": "code",
76
+ "execution_count": null,
77
+ "metadata": {},
78
+ "outputs": [],
79
+ "source": [
80
+ "opt = AdamW(baseline.parameters(), lr=3e-4, weight_decay=0.1)\n",
81
+ "sched = cosine_warmup_scheduler(opt, warmup_steps=200, total_steps=2000)\n",
82
+ "baseline.train()\n",
83
+ "\n",
84
+ "losses = []\n",
85
+ "for step, batch in enumerate(loaders['train']):\n",
86
+ " if step >= 2000:\n",
87
+ " break\n",
88
+ " ids = batch['input_ids'].to(DEVICE)\n",
89
+ " x, y = ids[:, :-1], ids[:, 1:]\n",
90
+ " logits = baseline(x)\n",
91
+ " loss = nn.functional.cross_entropy(logits.reshape(-1, vocab_size), y.reshape(-1))\n",
92
+ " opt.zero_grad(); loss.backward()\n",
93
+ " nn.utils.clip_grad_norm_(baseline.parameters(), 1.0)\n",
94
+ " opt.step(); sched.step()\n",
95
+ " losses.append(loss.item())\n",
96
+ " if step % 100 == 0:\n",
97
+ " print(f'step={step:4d} loss={loss.item():.4f}')\n",
98
+ "\n",
99
+ "baseline_ppl = math.exp(sum(losses[-200:]) / 200)\n",
100
+ "print(f'\\nBaseline final perplexity: {baseline_ppl:.2f}')"
101
+ ]
102
+ }
103
+ ],
104
+ "metadata": {
105
+ "kernelspec": {"display_name": "Python 3", "language": "python", "name": "python3"},
106
+ "language_info": {"name": "python", "version": "3.10.0"}
107
+ },
108
+ "nbformat": 4,
109
+ "nbformat_minor": 4
110
+ }