Files changed (4)
  1. chatgclm_base_2.9M.pt +3 -0
  2. sample.py +281 -0
  3. train_gclm_base.py +378 -0
  4. vocab_map.pt +3 -0
chatgclm_base_2.9M.pt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:ce848df8f00a516cfd77fd92f91f683adfde65a448635a52f126a81928ef43db
+ size 11769469
sample.py ADDED
@@ -0,0 +1,281 @@
+
+ import os
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ import tiktoken
+
+ # ----------------- CONFIG -----------------
+ MODEL_PATH = "chatgclm_base_2.9M.pt"
+ VOCAB_PATH = "vocab_map.pt"
+ TOKENIZER_NAME = "gpt2"
+
+ # Defined in the training script; must match train_gclm_base.py
+ D_MODEL = 256
+ N_LAYERS = 4
+ MAX_SEQ_LEN = 1024
+ LOCAL_KERNEL_SIZE = 5
+ GLOBAL_KERNEL_SIZE = 256
+ USE_GLOBAL_EVERY_N_LAYERS = 2
+ FFT_SIZE = 1024
+
+ PAD_ID = 0
+ SEP_ID = 1
+ EOS_ID = 2
+ OFFSET = 3
+ # ------------------------------------------
+
+ # ----------------- MODEL DEF -----------------
+ class GlobalConv1D(nn.Module):
+     def __init__(self, d_model, kernel_size, fft_size):
+         super().__init__()
+         self.kernel = nn.Parameter(torch.randn(d_model, kernel_size) * 0.01)
+         self.kernel_size = kernel_size
+         self.fft_size = fft_size
+
+     def forward(self, x):
+         B, C, T = x.shape
+         K = min(self.kernel_size, T)
+
+         overlap = K - 1
+         block = self.fft_size - overlap
+
+         # Causal left-padding, then overlap-save convolution block by block.
+         x = F.pad(x, (overlap, 0))
+         k = self.kernel[:, :K]
+         k = F.pad(k, (0, self.fft_size - K))
+         k_f = torch.fft.rfft(k, n=self.fft_size)
+
+         outs = []
+         pos = 0
+         while pos < T:
+             seg = x[..., pos:pos+self.fft_size]
+             if seg.shape[-1] < self.fft_size:
+                 seg = F.pad(seg, (0, self.fft_size - seg.shape[-1]))
+
+             y = torch.fft.irfft(
+                 torch.fft.rfft(seg, n=self.fft_size) * k_f.unsqueeze(0),
+                 n=self.fft_size
+             )
+             outs.append(y[..., overlap:overlap+block])
+             pos += block
+
+         return torch.cat(outs, dim=-1)[..., :T]
+
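+ # (This class mirrors the definition in train_gclm_base.py; the two copies must
+ # stay in sync for the checkpoint weights to keep their meaning.)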
+
+ class LocalConv1D(nn.Module):
+     def __init__(self, d_model, k):
+         super().__init__()
+         self.k = k
+         self.dw = nn.Conv1d(d_model, d_model, k, groups=d_model)
+         self.pw = nn.Conv1d(d_model, d_model, 1)
+
+     def forward(self, x):
+         # Left-pad so the convolution stays causal.
+         x = F.pad(x, (self.k - 1, 0))
+         return self.pw(F.relu(self.dw(x)))
+
+
+ class Block(nn.Module):
+     def __init__(self, d_model, use_global):
+         super().__init__()
+         self.use_global = use_global
+
+         self.ln1 = nn.LayerNorm(d_model)
+         self.local = LocalConv1D(d_model, LOCAL_KERNEL_SIZE)
+
+         if use_global:
+             self.ln2 = nn.LayerNorm(d_model)
+             self.global_conv = GlobalConv1D(d_model, GLOBAL_KERNEL_SIZE, FFT_SIZE)
+
+         self.ln3 = nn.LayerNorm(d_model)
+         self.ff = nn.Sequential(
+             nn.Linear(d_model, d_model*4),
+             nn.GELU(),
+             nn.Linear(d_model*4, d_model)
+         )
+
+     def forward(self, x):
+         x = x + self.local(self.ln1(x).transpose(1,2)).transpose(1,2)
+         if self.use_global:
+             x = x + self.global_conv(self.ln2(x).transpose(1,2)).transpose(1,2)
+         return x + self.ff(self.ln3(x))
+
+
+ class GCLM(nn.Module):
+     def __init__(self, vocab):
+         super().__init__()
+         self.emb = nn.Embedding(vocab, D_MODEL)
+         self.pos = nn.Embedding(MAX_SEQ_LEN, D_MODEL)
+
+         self.layers = nn.ModuleList([
+             Block(D_MODEL, i % USE_GLOBAL_EVERY_N_LAYERS == 0)
+             for i in range(N_LAYERS)
+         ])
+
+         self.ln = nn.LayerNorm(D_MODEL)
+         self.head = nn.Linear(D_MODEL, vocab)
+
+         # Weight tying
+         self.head.weight = self.emb.weight
+
+     def forward(self, x):
+         T = x.size(1)
+         h = self.emb(x) + self.pos(torch.arange(T, device=x.device))
+         for layer in self.layers:
+             h = layer(h)
+         return self.head(self.ln(h))
+
+
+ # ----------------- UTILS -----------------
+ def load_model_and_vocab(device):
+     if not os.path.exists(VOCAB_PATH):
+         print(f"[ERROR] Vocab file not found: {VOCAB_PATH}")
+         return None, None, None
+
+     vocab_data = torch.load(VOCAB_PATH, map_location="cpu")
+     used_tokens = vocab_data["used_tokens"]
+     id2new = vocab_data["id2new"]
+     vocab_size = len(used_tokens) + OFFSET
+
+     print(f"[INFO] Vocab loaded. Size: {vocab_size}")
+
+     model = GCLM(vocab_size).to(device)
+
+     if os.path.exists(MODEL_PATH):
+         print(f"[INFO] Loading model from {MODEL_PATH}...")
+         state_dict = torch.load(MODEL_PATH, map_location=device)
+         model.load_state_dict(state_dict)
+         model.eval()
+     else:
+         print(f"[ERROR] Model file not found: {MODEL_PATH}")
+         return None, None, None
+
+     return model, used_tokens, id2new
+
+ @torch.no_grad()
+ def generate(model, prompt, tokenizer, id2new, used_tokens, device, max_new_tokens=200, temperature=0.8, top_k=50):
+     model.eval()
+
+     # Encode the prompt, mapping raw GPT-2 ids to the reduced vocab and
+     # skipping any token that was not seen during training.
+     raw_ids = tokenizer.encode(prompt)
+     input_ids = [id2new[rid] for rid in raw_ids if rid in id2new]
+
+     if not input_ids:
+         print("[WARN] No known tokens in prompt.")
+         input_ids = [PAD_ID]  # Should not happen ideally
+
+     x = torch.tensor([input_ids], dtype=torch.long, device=device)
+
+     generated = []
+
+     for _ in range(max_new_tokens):
+         # Crop the context to the model's maximum sequence length.
+         ctx = x[:, -MAX_SEQ_LEN:] if x.size(1) > MAX_SEQ_LEN else x
+
+         logits = model(ctx)
+         next_token_logits = logits[:, -1, :] / temperature
+
+         # Optional top-k sampling: mask every logit below the k-th largest.
+         if top_k is not None:
+             v, _ = torch.topk(next_token_logits, min(top_k, next_token_logits.size(-1)))
+             next_token_logits[next_token_logits < v[:, [-1]]] = -float('Inf')
+
+         probs = F.softmax(next_token_logits, dim=-1)
+         next_token = torch.multinomial(probs, num_samples=1)
+
+         idx = next_token.item()
+
+         if idx == EOS_ID:
+             break
+
+         x = torch.cat((x, next_token), dim=1)
+         generated.append(idx)
+
+     return decoder(generated, used_tokens, tokenizer)
+
+ def decoder(ids, used_tokens, tokenizer):
+     # Map reduced ids back to raw GPT-2 ids; special tokens (< OFFSET) are dropped.
+     raw_ids = [used_tokens[i - OFFSET] for i in ids if i >= OFFSET]
+     return tokenizer.decode(raw_ids)
+
+
+
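+ # For one-shot use, the generate() helper above can be called directly; a minimal
+ # sketch (assuming the checkpoint and vocab files sit next to this script, and
+ # with an arbitrary illustrative prompt):
+ #
+ #     model, used_tokens, id2new = load_model_and_vocab("cpu")
+ #     enc = tiktoken.get_encoding(TOKENIZER_NAME)
+ #     print(generate(model, "Once upon a time", enc, id2new, used_tokens, "cpu"))
+ #
+ # The main block below instead re-implements the sampling loop inline so tokens
+ # can be streamed to stdout as they are drawn.
+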
+ # ----------------- MAIN -----------------
+ if __name__ == "__main__":
+     if torch.cuda.is_available():
+         device = "cuda"
+     elif torch.backends.mps.is_available():
+         device = "mps"
+     else:
+         device = "cpu"
+
+     print(f"Using device: {device}")
+
+     model, used_tokens, id2new = load_model_and_vocab(device)
+     enc = tiktoken.get_encoding(TOKENIZER_NAME)
+
+     if model:
+         # Find a good starting token id (newline), falling back to the first real token.
+         newline_id = id2new.get(enc.encode("\n")[0], OFFSET)
+
+         while True:
+             print("\n--- Generating Sample (Temp=0.8, TopK=50) ---")
+             print("-" * 20)
+
+             x = torch.tensor([[newline_id]], dtype=torch.long, device=device)
+             generated = []
+
+             with torch.no_grad():
+                 for _ in range(500):
+                     # Crop to max seq len
+                     ctx = x[:, -MAX_SEQ_LEN:] if x.size(1) > MAX_SEQ_LEN else x
+
+                     logits = model(ctx)
+                     logits = logits[:, -1, :] / 0.8  # Temperature
+
+                     # Top-k: mask every logit below the 50th largest.
+                     v, _ = torch.topk(logits, min(50, logits.size(-1)))
+                     logits[logits < v[:, [-1]]] = -float('Inf')
+
+                     probs = F.softmax(logits, dim=-1)
+                     next_token = torch.multinomial(probs, num_samples=1)
+
+                     idx = next_token.item()
+                     x = torch.cat((x, next_token), dim=1)
+                     generated.append(idx)
+
+                     if idx == EOS_ID:
+                         print("[EOS]", end="", flush=True)
+                         break
+
+                     # Stream the decoded token to stdout immediately.
+                     if idx >= OFFSET:
+                         raw_id = used_tokens[idx - OFFSET]
+                         token_text = enc.decode([raw_id])
+                         print(token_text, end="", flush=True)
+                     elif idx == PAD_ID:
+                         print("[PAD]", end="", flush=True)
+                     elif idx == SEP_ID:
+                         print("[SEP]", end="", flush=True)
+
+             print("\n" + "-"*20)
+             cont = input("\nPress [Enter] to generate again, or type 'exit': ")
+             if cont.lower() == 'exit':
+                 break
+
+
train_gclm_base.py ADDED
@@ -0,0 +1,378 @@
+ print("Starting...")
+
+ ###############################################
+ # CONFIGURATION — CUSTOMIZE EVERYTHING HERE
+ ###############################################
+
+ # ---- data / vocab ----
+ TXT_PATH = "data.txt"
+ DATA_PCT = 0.001  # small fraction of the corpus, for testing purposes
+ TOKENIZER_NAME = "gpt2"
+ REDUCE_VOCAB = True
+ VOCAB_SAVE_PATH = "vocab_map.pt"
+
+ # ---- training ----
+ EPOCHS = 25
+ MICRO_BATCH_SIZE = 1
+ GRAD_ACCUM_STEPS = 8  # effective batch size = MICRO_BATCH_SIZE * GRAD_ACCUM_STEPS
+ LEARNING_RATE = 3e-4
+
+ # ---- model ----
+ D_MODEL = 256
+ N_LAYERS = 4
+ MAX_SEQ_LEN = 1024
+
+ LOCAL_KERNEL_SIZE = 5
+ GLOBAL_KERNEL_SIZE = 256
+ USE_GLOBAL_EVERY_N_LAYERS = 2
+
+ # ---- FFT conv ----
+ FFT_SIZE = 1024  # must be a power of 2 and > GLOBAL_KERNEL_SIZE
+
+ # ---- checkpointing ----
+ SAVE_PATH = "model.pt"  # note: train() derives its actual save path from the param count
+ SAVE_N_EPOCHS = 1
+
+ # ---- device ----
+ USE_DEVICE = "cuda"  # note: train() auto-detects the device; this value is not read below
+ USE_AMP = True
+ USE_ACTIVATION_CHECKPOINTING = False
+
+ # ---- torch.compile ----
+ COMPILE = False
+ COMPILE_MODE = "reduce-overhead"
+ COMPILE_BACKEND = "eager"
+
+ ###############################################
+ # END CONFIG
+ ###############################################
+
+ import os
+
+ # Windows cannot use expandable_segments — only enable it on Linux.
+ if os.name != "nt":
+     os.environ.setdefault("PYTORCH_CUDA_ALLOC_CONF", "expandable_segments:True")
+
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ from torch.utils.data import Dataset, DataLoader
+ from tqdm import tqdm
+ import tiktoken
+
+ # performance settings
+ if torch.cuda.is_available():
+     torch.set_float32_matmul_precision("high")
+     torch.backends.cuda.matmul.allow_tf32 = True
+     torch.backends.cudnn.allow_tf32 = True
+
+ ###############################################################
+ # SPECIAL TOKENS
+ ###############################################################
+
+ PAD_ID = 0
+ SEP_ID = 1
+ EOS_ID = 2
+ OFFSET = 3
+
+ ###############################################################
+ # VOCAB
+ ###############################################################
+
+ def build_dataset_vocab(txt_path, tokenizer, save_path):
+     text = open(txt_path, "r", encoding="utf-8").read()
+     if DATA_PCT < 1.0:
+         text = text[:int(len(text) * DATA_PCT)]
+     token_ids = tokenizer.encode(text)
+     used = sorted(set(token_ids))
+
+     # Remap each used GPT-2 id to a compact id, reserving 0..2 for the specials.
+     id2new = {tok: i + OFFSET for i, tok in enumerate(used)}
+
+     torch.save({
+         "used_tokens": used,
+         "id2new": id2new,
+         "PAD_ID": PAD_ID,
+         "SEP_ID": SEP_ID,
+         "EOS_ID": EOS_ID,
+     }, save_path)
+
+     print(f"[OK] Vocab size: {len(used) + OFFSET}")
+     return used, id2new
+
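+ # A quick illustration of the round-trip (the ids here are hypothetical, for
+ # clarity only): if used == [198, 262, 318, ...], then id2new[198] == 3 (== OFFSET),
+ # and decoding inverts this via used[3 - OFFSET] == 198.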
+
+ ###############################################################
+ # DATASET
+ ###############################################################
+
+ class RemappedTextDataset(Dataset):
+     def __init__(self, ids, max_len):
+         self.ids = ids
+         self.max_len = max_len
+
+     def __len__(self):
+         return max(0, len(self.ids) - self.max_len - 1)
+
+     def __getitem__(self, i):
+         # x is a window of max_len tokens; y is the same window shifted by one.
+         x = self.ids[i:i+self.max_len]
+         y = self.ids[i+1:i+self.max_len+1]
+         return torch.tensor(x, dtype=torch.long), torch.tensor(y, dtype=torch.long)
+
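+ # Note that consecutive dataset indices yield windows shifted by a single token,
+ # so one shuffled epoch visits roughly len(ids) - MAX_SEQ_LEN - 1 heavily
+ # overlapping sequences.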
+
+ ###############################################################
+ # GLOBAL + LOCAL CONVOLUTION
+ ###############################################################
+
+ class GlobalConv1D(nn.Module):
+     def __init__(self, d_model, kernel_size, fft_size):
+         super().__init__()
+         self.kernel = nn.Parameter(torch.randn(d_model, kernel_size) * 0.01)
+         self.kernel_size = kernel_size
+         self.fft_size = fft_size
+
+     def forward(self, x):
+         B, C, T = x.shape
+         K = min(self.kernel_size, T)
+
+         overlap = K - 1
+         block = self.fft_size - overlap
+
+         # Causal left-padding, then overlap-save FFT convolution block by block:
+         # each FFT of size fft_size yields `block` valid output positions.
+         x = F.pad(x, (overlap, 0))
+         k = self.kernel[:, :K]
+         k = F.pad(k, (0, self.fft_size - K))
+         k_f = torch.fft.rfft(k, n=self.fft_size)
+
+         outs = []
+         pos = 0
+         while pos < T:
+             seg = x[..., pos:pos+self.fft_size]
+             if seg.shape[-1] < self.fft_size:
+                 seg = F.pad(seg, (0, self.fft_size - seg.shape[-1]))
+
+             y = torch.fft.irfft(
+                 torch.fft.rfft(seg, n=self.fft_size) * k_f.unsqueeze(0),
+                 n=self.fft_size
+             )
+             # Discard the first `overlap` samples, which carry circular wrap-around.
+             outs.append(y[..., overlap:overlap+block])
+             pos += block
+
+         return torch.cat(outs, dim=-1)[..., :T]
+
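+ # A minimal sanity check for the overlap-save path above (a sketch for
+ # verification, not part of the training run): pointwise FFT products implement
+ # true convolution, so a direct causal conv1d with the kernel flipped should
+ # agree up to float32 FFT error.
+ #
+ #     conv = GlobalConv1D(d_model=8, kernel_size=16, fft_size=64)
+ #     x = torch.randn(2, 8, 100)
+ #     ref = F.conv1d(F.pad(x, (15, 0)), conv.kernel.flip(-1).unsqueeze(1), groups=8)
+ #     assert torch.allclose(conv(x), ref, atol=1e-4)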
+
+ class LocalConv1D(nn.Module):
+     def __init__(self, d_model, k):
+         super().__init__()
+         self.k = k
+         # Depthwise (per-channel) conv followed by a pointwise mixing conv.
+         self.dw = nn.Conv1d(d_model, d_model, k, groups=d_model)
+         self.pw = nn.Conv1d(d_model, d_model, 1)
+
+     def forward(self, x):
+         # Left-pad so the convolution stays causal.
+         x = F.pad(x, (self.k - 1, 0))
+         return self.pw(F.relu(self.dw(x)))
+
+
+ class Block(nn.Module):
+     def __init__(self, d_model, use_global):
+         super().__init__()
+         self.use_global = use_global
+
+         self.ln1 = nn.LayerNorm(d_model)
+         self.local = LocalConv1D(d_model, LOCAL_KERNEL_SIZE)
+
+         if use_global:
+             self.ln2 = nn.LayerNorm(d_model)
+             self.global_conv = GlobalConv1D(d_model, GLOBAL_KERNEL_SIZE, FFT_SIZE)
+
+         self.ln3 = nn.LayerNorm(d_model)
+         self.ff = nn.Sequential(
+             nn.Linear(d_model, d_model*4),
+             nn.GELU(),
+             nn.Linear(d_model*4, d_model)
+         )
+
+     def forward(self, x):
+         # Pre-norm residual blocks; the convs expect (B, C, T), hence the transposes.
+         x = x + self.local(self.ln1(x).transpose(1,2)).transpose(1,2)
+         if self.use_global:
+             x = x + self.global_conv(self.ln2(x).transpose(1,2)).transpose(1,2)
+         return x + self.ff(self.ln3(x))
+
+
+ class GCLM(nn.Module):
+     def __init__(self, vocab):
+         super().__init__()
+         self.emb = nn.Embedding(vocab, D_MODEL)
+         self.pos = nn.Embedding(MAX_SEQ_LEN, D_MODEL)
+
+         # Every USE_GLOBAL_EVERY_N_LAYERS-th layer gets the global FFT conv.
+         self.layers = nn.ModuleList([
+             Block(D_MODEL, i % USE_GLOBAL_EVERY_N_LAYERS == 0)
+             for i in range(N_LAYERS)
+         ])
+
+         self.ln = nn.LayerNorm(D_MODEL)
+         self.head = nn.Linear(D_MODEL, vocab)
+
+         # Weight tying: SIGNIFICANTLY reduces parameter count
+         self.head.weight = self.emb.weight
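+         # Rough arithmetic behind that claim: the embedding and the output head
+         # would each need vocab x D_MODEL weights; sharing one matrix saves
+         # vocab x 256 parameters here (only the head's bias still adds vocab more).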
+
+     def forward(self, x):
+         T = x.size(1)
+         h = self.emb(x) + self.pos(torch.arange(T, device=x.device))
+         for layer in self.layers:
+             h = layer(h)
+         return self.head(self.ln(h))
+
+
+ ###############################################################
+ # TRAINING LOOP
+ ###############################################################
+
+ def format_params(num):
+     if num >= 1_000_000_000:
+         return f"{num/1_000_000_000:.1f}B"
+     elif num >= 1_000_000:
+         return f"{num/1_000_000:.1f}M"
+     else:
+         return f"{num/1_000:.1f}K"
+
+ @torch.no_grad()
+ def estimate_loss(model, dl, device, ctx):
+     model.eval()
+     losses = []
+     # Evaluate at most 50 batches to keep validation fast.
+     limit = 50
+     for i, (x, y) in enumerate(dl):
+         if i >= limit:
+             break
+         x, y = x.to(device), y.to(device)
+         with ctx:
+             logits = model(x)
+             loss = F.cross_entropy(logits.reshape(-1, logits.size(-1)), y.reshape(-1), ignore_index=PAD_ID)
+         losses.append(loss.item())
+     model.train()
+     return sum(losses) / len(losses) if losses else 0.0
+
+ def train():
+     if torch.cuda.is_available():
+         device = "cuda"
+     elif torch.backends.mps.is_available():
+         device = "mps"
+     else:
+         device = "cpu"
+     print("[INFO] Device:", device)
+
+     # 1. Prepare vocab & data
+     tok = tiktoken.get_encoding(TOKENIZER_NAME)
+
+     # This generates (and saves) the vocab map
+     used, id2new = build_dataset_vocab(TXT_PATH, tok, VOCAB_SAVE_PATH)
+     vocab = len(used) + OFFSET
+
+     # Load and process the full text
+     print("[INFO] Loading and tokenizing text...")
+     text = open(TXT_PATH, "r", encoding="utf-8").read()
+     if DATA_PCT < 1.0:
+         text = text[:int(len(text) * DATA_PCT)]
+
+     raw_ids = tok.encode(text)
+     # Map to the reduced ids
+     ids = [id2new.get(i, PAD_ID) for i in raw_ids] + [EOS_ID]
+
+     # Split train/val (90/10)
+     n = len(ids)
+     split_idx = int(n * 0.9)
+     train_ids = ids[:split_idx]
+     val_ids = ids[split_idx:]
+
+     print(f"[INFO] Tokens: {n} | Train: {len(train_ids)} | Val: {len(val_ids)}")
+
+     train_ds = RemappedTextDataset(train_ids, MAX_SEQ_LEN)
+     val_ds = RemappedTextDataset(val_ids, MAX_SEQ_LEN)
+
+     train_dl = DataLoader(train_ds, batch_size=MICRO_BATCH_SIZE, shuffle=True)
+     val_dl = DataLoader(val_ds, batch_size=MICRO_BATCH_SIZE, shuffle=False)
+
+     model = GCLM(vocab).to(device)
+
+     # The parameter count determines the checkpoint name
+     num_params = sum(p.numel() for p in model.parameters())
+     param_str = format_params(num_params)
+     save_path = f"chatgclm_base_{param_str}.pt"
+     print(f"[INFO] Model parameters: {num_params:,} ({param_str})")
+     print(f"[INFO] Save path: {save_path}")
+
+     # 🔁 RESUME IF CHECKPOINT EXISTS
+     if os.path.exists(save_path):
+         model.load_state_dict(torch.load(save_path, map_location=device))
+         print(f"[RESUME] Loaded existing checkpoint from {save_path}")
+
+     if device == "cuda" and COMPILE:
+         print("[INFO] Compiling model with torch.compile...")
+         model = torch.compile(
+             model,
+             mode=COMPILE_MODE,
+             fullgraph=False,
+             backend=COMPILE_BACKEND
+         )
+
+     opt = torch.optim.AdamW(model.parameters(), lr=LEARNING_RATE)
+     loss_fn = nn.CrossEntropyLoss(ignore_index=PAD_ID)
+
+     # AMP context
+     if device == "cuda" and USE_AMP:
+         ctx = torch.amp.autocast(device)
+         scaler = torch.amp.GradScaler(device)
+     else:
+         # Dummy context for cpu/mps
+         import contextlib
+         ctx = contextlib.nullcontext()
+         scaler = None
+
+     for ep in range(EPOCHS):
+         print(f"\nEpoch {ep+1}/{EPOCHS}")
+         opt.zero_grad(set_to_none=True)
+
+         pbar = tqdm(train_dl, desc="Training")
+         running_loss = 0.0
+
+         for i, (x, y) in enumerate(pbar):
+             x, y = x.to(device), y.to(device)
+
+             with ctx:
+                 logits = model(x)
+                 loss = loss_fn(logits.reshape(-1, vocab), y.reshape(-1))
+                 loss_val = loss.item()
+                 # Scale down so gradients average over the accumulation window.
+                 loss = loss / GRAD_ACCUM_STEPS
+
+             if scaler:
+                 scaler.scale(loss).backward()
+             else:
+                 loss.backward()
+
+             # Step the optimizer once every GRAD_ACCUM_STEPS micro-batches.
+             if (i+1) % GRAD_ACCUM_STEPS == 0:
+                 if scaler:
+                     scaler.step(opt)
+                     scaler.update()
+                 else:
+                     opt.step()
+                 opt.zero_grad(set_to_none=True)
+
+             # Exponential moving average of the loss for the progress bar
+             running_loss = 0.9 * running_loss + 0.1 * loss_val if running_loss > 0 else loss_val
+             pbar.set_postfix(loss=f"{running_loss:.4f}")
+
+         # Validate at end of epoch
+         val_loss = estimate_loss(model, val_dl, device, ctx)
+         print(f"Epoch {ep+1} finished. Train Loss: {running_loss:.4f} | Val Loss: {val_loss:.4f}")
+
+         if SAVE_N_EPOCHS and (ep+1) % SAVE_N_EPOCHS == 0:
+             torch.save(model.state_dict(), save_path)
+             print(f"[OK] Saved checkpoint to {save_path}")
+
+     torch.save(model.state_dict(), save_path)
+     print("[DONE] Training complete.")
+
+
+ ###############################################################
+ # ENTRY POINT
+ ###############################################################
+
+ if __name__ == "__main__":
+     train()
vocab_map.pt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:cde6c659b6f26ead43cd4a301ea1c6da6eb3975e47ffeb00ef740722b1a3cf39
+ size 6841