umm-dev commited on
Commit
0cad2e8
·
verified ·
1 Parent(s): 8e85857

Create gclm_train_example.py

Browse files
Files changed (1) hide show
  1. gclm_train_example.py +287 -0
gclm_train_example.py ADDED
@@ -0,0 +1,287 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ print("Starting...")
2
+
3
+ ###############################################
4
+ # CONFIGURATION — CUSTOMIZE EVERYTHING HERE
5
+ ###############################################
6
+
7
+ # ---- data / vocab ----
8
+ TXT_PATH = "data.txt"
9
+ TOKENIZER_NAME = "gpt2"
10
+ REDUCE_VOCAB = True
11
+ VOCAB_SAVE_PATH = "vocab_map.pt"
12
+
13
+ # ---- training ----
14
+ EPOCHS = 25
15
+ MICRO_BATCH_SIZE = 1
16
+ GRAD_ACCUM_STEPS = 8
17
+ LEARNING_RATE = 3e-4
18
+
19
+ # ---- model ----
20
+ D_MODEL = 256
21
+ N_LAYERS = 4
22
+ MAX_SEQ_LEN = 8192
23
+
24
+ LOCAL_KERNEL_SIZE = 5
25
+ GLOBAL_KERNEL_SIZE = 256
26
+ USE_GLOBAL_EVERY_N_LAYERS = 2
27
+
28
+ # ---- FFT conv ----
29
+ FFT_SIZE = 1024 # must be power of 2 and > GLOBAL_KERNEL_SIZE
30
+
31
+ # ---- checkpointing ----
32
+ SAVE_PATH = "model.pt"
33
+ SAVE_N_EPOCHS = 1
34
+
35
+ # ---- device ----
36
+ USE_DEVICE = "cuda"
37
+ USE_AMP = True
38
+ USE_ACTIVATION_CHECKPOINTING = False
39
+
40
+ # ---- torch.compile ----
41
+ COMPILE = False
42
+ COMPILE_MODE = "reduce-overhead"
43
+ COMPILE_BACKEND = "eager"
44
+
45
+ ###############################################
46
+ # END CONFIG
47
+ ###############################################
48
+
49
+ import os
50
+
51
+ # Windows cannot use expandable_segments — only enable on Linux.
52
+ if os.name != "nt":
53
+ os.environ.setdefault("PYTORCH_CUDA_ALLOC_CONF", "expandable_segments:True")
54
+
55
+ import torch
56
+ import torch.nn as nn
57
+ import torch.nn.functional as F
58
+ from torch.utils.data import Dataset, DataLoader
59
+ from tqdm import tqdm
60
+ import tiktoken
61
+
62
+ # performance settings
63
+ torch.set_float32_matmul_precision("high")
64
+ torch.backends.cuda.matmul.allow_tf32 = True
65
+ torch.backends.cudnn.allow_tf32 = True
66
+
67
+ ###############################################################
68
+ # SPECIAL TOKENS
69
+ ###############################################################
70
+
71
+ PAD_ID = 0
72
+ SEP_ID = 1
73
+ EOS_ID = 2
74
+ OFFSET = 3
75
+
76
+ ###############################################################
77
+ # VOCAB
78
+ ###############################################################
79
+
80
+ def build_dataset_vocab(txt_path, tokenizer, save_path):
81
+ text = open(txt_path, "r", encoding="utf-8").read()
82
+ token_ids = tokenizer.encode(text)
83
+ used = sorted(set(token_ids))
84
+
85
+ id2new = {tok: i + OFFSET for i, tok in enumerate(used)}
86
+
87
+ torch.save({
88
+ "used_tokens": used,
89
+ "id2new": id2new,
90
+ "PAD_ID": PAD_ID,
91
+ "SEP_ID": SEP_ID,
92
+ "EOS_ID": EOS_ID,
93
+ }, save_path)
94
+
95
+ print(f"[OK] Vocab size: {len(used) + OFFSET}")
96
+ return used, id2new
97
+
98
+
99
+ ###############################################################
100
+ # DATASET
101
+ ###############################################################
102
+
103
+ class RemappedTextDataset(Dataset):
104
+ def __init__(self, path, tokenizer, id2new, max_len):
105
+ text = open(path, "r", encoding="utf-8").read()
106
+ raw = tokenizer.encode(text)
107
+ self.ids = [id2new.get(i, PAD_ID) for i in raw] + [EOS_ID]
108
+ self.max_len = max_len
109
+
110
+ def __len__(self):
111
+ return len(self.ids) - self.max_len - 1
112
+
113
+ def __getitem__(self, i):
114
+ x = self.ids[i:i+self.max_len]
115
+ y = self.ids[i+1:i+self.max_len+1]
116
+ return torch.tensor(x), torch.tensor(y)
117
+
118
+
119
+ ###############################################################
120
+ # GLOBAL + LOCAL CONVOLUTION
121
+ ###############################################################
122
+
123
+ class GlobalConv1D(nn.Module):
124
+ def __init__(self, d_model, kernel_size, fft_size):
125
+ super().__init__()
126
+ self.kernel = nn.Parameter(torch.randn(d_model, kernel_size) * 0.01)
127
+ self.kernel_size = kernel_size
128
+ self.fft_size = fft_size
129
+
130
+ def forward(self, x):
131
+ B, C, T = x.shape
132
+ K = min(self.kernel_size, T)
133
+
134
+ overlap = K - 1
135
+ block = self.fft_size - overlap
136
+
137
+ x = F.pad(x, (overlap, 0))
138
+ k = self.kernel[:, :K]
139
+ k = F.pad(k, (0, self.fft_size - K))
140
+ k_f = torch.fft.rfft(k, n=self.fft_size)
141
+
142
+ outs = []
143
+ pos = 0
144
+ while pos < T:
145
+ seg = x[..., pos:pos+self.fft_size]
146
+ if seg.shape[-1] < self.fft_size:
147
+ seg = F.pad(seg, (0, self.fft_size - seg.shape[-1]))
148
+
149
+ y = torch.fft.irfft(
150
+ torch.fft.rfft(seg, n=self.fft_size) * k_f.unsqueeze(0),
151
+ n=self.fft_size
152
+ )
153
+ outs.append(y[..., overlap:overlap+block])
154
+ pos += block
155
+
156
+ return torch.cat(outs, dim=-1)[..., :T]
157
+
158
+
159
+ class LocalConv1D(nn.Module):
160
+ def __init__(self, d_model, k):
161
+ super().__init__()
162
+ self.k = k
163
+ self.dw = nn.Conv1d(d_model, d_model, k, groups=d_model)
164
+ self.pw = nn.Conv1d(d_model, d_model, 1)
165
+
166
+ def forward(self, x):
167
+ x = F.pad(x, (self.k - 1, 0))
168
+ return self.pw(F.relu(self.dw(x)))
169
+
170
+
171
+ class Block(nn.Module):
172
+ def __init__(self, d_model, use_global):
173
+ super().__init__()
174
+ self.use_global = use_global
175
+
176
+ self.ln1 = nn.LayerNorm(d_model)
177
+ self.local = LocalConv1D(d_model, LOCAL_KERNEL_SIZE)
178
+
179
+ if use_global:
180
+ self.ln2 = nn.LayerNorm(d_model)
181
+ self.global_conv = GlobalConv1D(d_model, GLOBAL_KERNEL_SIZE, FFT_SIZE)
182
+
183
+ self.ln3 = nn.LayerNorm(d_model)
184
+ self.ff = nn.Sequential(
185
+ nn.Linear(d_model, d_model*4),
186
+ nn.GELU(),
187
+ nn.Linear(d_model*4, d_model)
188
+ )
189
+
190
+ def forward(self, x):
191
+ x = x + self.local(self.ln1(x).transpose(1,2)).transpose(1,2)
192
+ if self.use_global:
193
+ x = x + self.global_conv(self.ln2(x).transpose(1,2)).transpose(1,2)
194
+ return x + self.ff(self.ln3(x))
195
+
196
+
197
+ class GCLM(nn.Module):
198
+ def __init__(self, vocab):
199
+ super().__init__()
200
+ self.emb = nn.Embedding(vocab, D_MODEL)
201
+ self.pos = nn.Embedding(MAX_SEQ_LEN, D_MODEL)
202
+
203
+ self.layers = nn.ModuleList([
204
+ Block(D_MODEL, i % USE_GLOBAL_EVERY_N_LAYERS == 0)
205
+ for i in range(N_LAYERS)
206
+ ])
207
+
208
+ self.ln = nn.LayerNorm(D_MODEL)
209
+ self.head = nn.Linear(D_MODEL, vocab)
210
+
211
+ def forward(self, x):
212
+ T = x.size(1)
213
+ h = self.emb(x) + self.pos(torch.arange(T, device=x.device))
214
+ for layer in self.layers:
215
+ h = layer(h)
216
+ return self.head(self.ln(h))
217
+
218
+
219
+ ###############################################################
220
+ # TRAINING LOOP
221
+ ###############################################################
222
+
223
+ def train():
224
+ device = USE_DEVICE if torch.cuda.is_available() else "cpu"
225
+ print("[INFO] Device:", device)
226
+
227
+ tok = tiktoken.get_encoding(TOKENIZER_NAME)
228
+ used, id2new = build_dataset_vocab(TXT_PATH, tok, VOCAB_SAVE_PATH)
229
+ vocab = len(used) + OFFSET
230
+
231
+ ds = RemappedTextDataset(TXT_PATH, tok, id2new, MAX_SEQ_LEN)
232
+ dl = DataLoader(ds, batch_size=MICRO_BATCH_SIZE, shuffle=True)
233
+
234
+ model = GCLM(vocab).to(device)
235
+
236
+ # 🔁 RESUME IF CHECKPOINT EXISTS
237
+ if os.path.exists(SAVE_PATH):
238
+ model.load_state_dict(torch.load(SAVE_PATH, map_location=device))
239
+ print(f"[RESUME] Loaded existing checkpoint from {SAVE_PATH}")
240
+
241
+ if device == "cuda" and COMPILE:
242
+ print("[INFO] Compiling model with torch.compile...")
243
+ model = torch.compile(
244
+ model,
245
+ mode=COMPILE_MODE,
246
+ fullgraph=False,
247
+ backend=COMPILE_BACKEND
248
+ )
249
+
250
+ opt = torch.optim.AdamW(model.parameters(), lr=LEARNING_RATE)
251
+ loss_fn = nn.CrossEntropyLoss(ignore_index=PAD_ID)
252
+
253
+ scaler = torch.amp.GradScaler("cuda", enabled=(device=="cuda" and USE_AMP))
254
+
255
+ for ep in range(EPOCHS):
256
+ print(f"\nEpoch {ep+1}/{EPOCHS}")
257
+ opt.zero_grad(set_to_none=True)
258
+
259
+ for i, (x, y) in enumerate(tqdm(dl)):
260
+ x, y = x.to(device), y.to(device)
261
+
262
+ with torch.amp.autocast("cuda", enabled=(device=="cuda" and USE_AMP)):
263
+ logits = model(x)
264
+ loss = loss_fn(logits.reshape(-1, vocab), y.reshape(-1))
265
+ loss = loss / GRAD_ACCUM_STEPS
266
+
267
+ scaler.scale(loss).backward()
268
+
269
+ if (i+1) % GRAD_ACCUM_STEPS == 0:
270
+ scaler.step(opt)
271
+ scaler.update()
272
+ opt.zero_grad(set_to_none=True)
273
+
274
+ if SAVE_N_EPOCHS and (ep+1) % SAVE_N_EPOCHS == 0:
275
+ torch.save(model.state_dict(), SAVE_PATH)
276
+ print("[OK] Saved checkpoint.")
277
+
278
+ torch.save(model.state_dict(), SAVE_PATH)
279
+ print("[DONE] Training complete.")
280
+
281
+
282
+ ###############################################################
283
+ # ENTRY POINT
284
+ ###############################################################
285
+
286
+ if __name__ == "__main__":
287
+ train()