LH-Tech-AI committed on
Commit
f6660e6
·
verified ·
1 Parent(s): 8ab44df

Create train.py

Browse files
Files changed (1) hide show
  1. train.py +361 -0
train.py ADDED
@@ -0,0 +1,361 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
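+ # %%writefile train.py  (Jupyter cell magic, kept only as a comment: it is not valid Python when train.py is run as a script)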
2
+ # ============================================================
3
+ # Mini Math Model - T5 Seq2Seq
4
+ # ============================================================
5
+ # pip install transformers torch datasets accelerate
6
+
7
+ import random
8
+ import torch
9
+ import numpy as np
10
+ from torch.utils.data import Dataset, DataLoader
11
+ from transformers import T5Config, T5ForConditionalGeneration
12
+ from torch.optim import AdamW
13
+ from torch.optim.lr_scheduler import CosineAnnealingLR
14
+ import time
15
+
16
+ # ============================================================
17
+ # 1. CONFIG
18
+ # ============================================================
19
+
20
# --- Dataset size and difficulty ---
TRAIN_SAMPLES = 2_000_000   # number of generated training expressions
VAL_SAMPLES = 10_000        # number of generated validation expressions
MAX_DIGITS = 3              # digit budget handed to the sample generator

# --- Optimisation hyper-parameters ---
BATCH_SIZE = 512
EPOCHS = 10
LR = 3e-4

# --- Fixed sequence lengths in tokens (padding included) ---
MAX_INPUT_LEN = 20
MAX_TARGET_LEN = 12

# --- Runtime ---
# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"
SAVE_PATH = "model.pt"      # file the best checkpoint is written to

print(f"Device: {DEVICE}")
print(f"GPU: {'None' if DEVICE != 'cuda' else torch.cuda.get_device_name(0)}")
33
+
34
+ # ============================================================
35
+ # 2. TOKENIZER (Character-Level)
36
+ # ============================================================
37
+
38
# Vocabulary: the 15 expression characters followed by three special tokens.
CHARS = list("0123456789+-*/=") + ["<pad>", "<bos>", "<eos>"]
char2id = {c: i for i, c in enumerate(CHARS)}
id2char = {i: c for c, i in char2id.items()}

PAD_ID = char2id["<pad>"]
BOS_ID = char2id["<bos>"]
EOS_ID = char2id["<eos>"]
VOCAB_SIZE = len(CHARS)


def encode(text, max_len, add_bos=False, add_eos=True):
    """Convert ``text`` into a list of exactly ``max_len`` token ids.

    Args:
        text: String to tokenize, one token per character.
        max_len: Length of the returned list; shorter sequences are
            right-padded with ``PAD_ID``, longer ones are truncated.
        add_bos: Prepend ``BOS_ID`` when true.
        add_eos: Append ``EOS_ID`` when true.

    Returns:
        A list of ``max_len`` integer token ids.

    Characters that are not in the vocabulary are mapped to ``PAD_ID``.
    When ``add_eos`` is set, the content is truncated to ``max_len - 1``
    tokens *before* the terminator is appended, so an over-long text still
    ends in ``EOS_ID`` instead of silently losing it.
    """
    tokens = [BOS_ID] if add_bos else []
    tokens.extend(char2id.get(ch, PAD_ID) for ch in text)

    if add_eos:
        # Reserve the final slot for EOS so truncation can never drop it.
        tokens = tokens[:max(max_len - 1, 0)]
        tokens.append(EOS_ID)

    # Truncate, then right-pad to the fixed length.
    tokens = tokens[:max_len]
    tokens += [PAD_ID] * (max_len - len(tokens))
    return tokens


def decode(token_ids):
    """Convert an iterable of token ids back into a string.

    Decoding stops at the first ``EOS_ID``. ``PAD_ID`` and ``BOS_ID`` are
    skipped, and ids outside the vocabulary are rendered as ``"?"``.
    """
    chars = []
    for tid in token_ids:
        if tid == EOS_ID:
            break
        if tid in (PAD_ID, BOS_ID):
            continue
        chars.append(id2char.get(tid, "?"))
    return "".join(chars)
69
+
70
+ # ============================================================
71
+ # 3. DATA GENERATION
72
+ # ============================================================
73
+
74
def generate_sample(max_digits=3):
    """Generate one random integer-arithmetic problem and its answer.

    The operator is drawn uniformly from ``+ - * /``:

    * ``+`` / ``-``: both operands lie in ``[0, 10**max_digits - 1]``
      (a subtraction result may be negative).
    * ``*``: both operands lie in ``[0, 10**(max_digits-1) - 1]``.
    * ``/``: a divisor ``b >= 1`` and a quotient are drawn first and the
      dividend is built as ``b * quotient``, so the division is exact.

    Args:
        max_digits: Digit budget for the operands; must be at least 1.

    Returns:
        A tuple ``(input_str, target_str)`` where ``input_str`` has the form
        ``"<a><op><b>="`` and ``target_str`` is the decimal result.

    Raises:
        ValueError: If ``max_digits`` is smaller than 1.
    """
    if max_digits < 1:
        raise ValueError(f"max_digits must be >= 1, got {max_digits}")

    wide_hi = 10**max_digits - 1          # bound for + and - operands
    narrow_hi = 10**(max_digits - 1) - 1  # bound for * and / operands

    op = random.choice(["+", "-", "*", "/"])

    if op == "+":
        a = random.randint(0, wide_hi)
        b = random.randint(0, wide_hi)
        result = a + b
    elif op == "-":
        a = random.randint(0, wide_hi)
        b = random.randint(0, wide_hi)
        result = a - b
    elif op == "*":
        a = random.randint(0, narrow_hi)
        b = random.randint(0, narrow_hi)
        result = a * b
    else:  # op == "/"
        # narrow_hi is 0 when max_digits == 1; randint(1, 0) would raise,
        # so the divisor's upper bound is clamped to at least 1.
        b = random.randint(1, max(narrow_hi, 1))
        result = random.randint(0, narrow_hi)
        a = b * result

    input_str = f"{a}{op}{b}="
    target_str = str(result)
    return input_str, target_str
97
+
98
def generate_dataset(n_samples, max_digits=3):
    """Draw ``n_samples`` independent problems from ``generate_sample``.

    Returns:
        Two parallel lists ``(inputs, targets)`` of equal length, where
        ``inputs[i]`` is an expression string and ``targets[i]`` its answer.
    """
    pairs = [generate_sample(max_digits) for _ in range(n_samples)]
    inputs = [expression for expression, _ in pairs]
    targets = [answer for _, answer in pairs]
    return inputs, targets
105
+
106
# Build the training and validation sets up front and report how long it took.
# NOTE(review): both sets come from independent calls to the same generator
# and are not de-duplicated against each other, so a validation expression
# can also occur in the training data.
print("Generating training data...")
t0 = time.time()
train_inputs, train_targets = generate_dataset(
    n_samples=TRAIN_SAMPLES, max_digits=MAX_DIGITS
)
val_inputs, val_targets = generate_dataset(
    n_samples=VAL_SAMPLES, max_digits=MAX_DIGITS
)
print(f"Done in {time.time()-t0:.1f}s")
# Show the first generated pair as a quick sanity check.
print(f"Sample: '{train_inputs[0]}' → '{train_targets[0]}'")
112
+
113
+ # ============================================================
114
+ # 4. DATASET
115
+ # ============================================================
116
+
117
class MathDataset(Dataset):
    """Dataset over parallel lists of expression and answer strings.

    Strings are tokenized on access; each item is a dict of ``torch.long``
    tensors with the keys ``input_ids``, ``attention_mask``,
    ``decoder_input_ids`` and ``labels``.
    """

    def __init__(self, inputs, targets):
        self.inputs = inputs
        self.targets = targets

    def __len__(self):
        return len(self.inputs)

    def __getitem__(self, idx):
        source, answer = self.inputs[idx], self.targets[idx]

        # Encoder side: expression + EOS, padded to MAX_INPUT_LEN; the mask
        # is 1 on every non-pad position.
        input_ids = encode(source, MAX_INPUT_LEN, add_bos=False, add_eos=True)
        attention_mask = [int(tok != PAD_ID) for tok in input_ids]

        # Labels: answer + EOS, with pad positions replaced by -100
        # (presumably so the loss ignores them — confirm against the
        # Transformers label convention).
        label_ids = encode(answer, MAX_TARGET_LEN, add_bos=False, add_eos=True)
        labels = [-100 if tok == PAD_ID else tok for tok in label_ids]

        # Decoder input: BOS followed by the answer (no EOS), i.e. the labels
        # shifted right by one position, forced to MAX_TARGET_LEN tokens.
        decoder_input = [BOS_ID]
        decoder_input.extend(
            encode(answer, MAX_TARGET_LEN - 1, add_bos=False, add_eos=False)
        )
        decoder_input = decoder_input[:MAX_TARGET_LEN]
        decoder_input.extend([PAD_ID] * (MAX_TARGET_LEN - len(decoder_input)))

        fields = {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "decoder_input_ids": decoder_input,
            "labels": labels,
        }
        return {
            key: torch.tensor(values, dtype=torch.long)
            for key, values in fields.items()
        }
145
+
146
train_dataset = MathDataset(train_inputs, train_targets)
val_dataset = MathDataset(val_inputs, val_targets)

# Training batches are reshuffled every epoch; validation order stays fixed.
# Pinned host memory only benefits host-to-GPU copies, so it is requested
# only when the run actually targets CUDA (avoids a warning on CPU-only
# machines).
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True,
                          num_workers=2, pin_memory=(DEVICE == "cuda"))
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False,
                        num_workers=2, pin_memory=(DEVICE == "cuda"))
153
+
154
+ # ============================================================
155
+ # 5. MODEL (~1M parameters)
156
+ # ============================================================
157
+
158
# Small encoder-decoder T5 built from scratch (no pretrained weights) over
# the character-level vocabulary.
config = T5Config(
    vocab_size=VOCAB_SIZE,
    d_model=128,
    d_ff=256,
    num_heads=4,
    num_layers=3,            # Encoder layers
    num_decoder_layers=3,    # Decoder layers
    d_kv=32,
    dropout_rate=0.1,
    feed_forward_proj="relu",
    is_encoder_decoder=True,
    pad_token_id=PAD_ID,
    eos_token_id=EOS_ID,
    decoder_start_token_id=BOS_ID,
)

model = T5ForConditionalGeneration(config).to(DEVICE)

# Loss scaling is only meaningful for CUDA mixed precision. Passing an
# explicit flag makes the scaler a documented no-op on CPU instead of
# relying on it to disable itself with a warning.
scaler = torch.cuda.amp.GradScaler(enabled=(DEVICE == "cuda"))

total_params = sum(p.numel() for p in model.parameters())
trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"\nTotal parameters: {total_params/1e6:.2f}M")
print(f"Trainable: {trainable/1e6:.2f}M")
182
+
183
+ # ============================================================
184
+ # 6. OPTIMIZER & SCHEDULER
185
+ # ============================================================
186
+
187
# The cosine schedule spans every optimizer step of the whole run and decays
# the learning rate from LR down to LR / 10.
total_steps = EPOCHS * len(train_loader)
optimizer = AdamW(model.parameters(), lr=LR, weight_decay=0.01)
scheduler = CosineAnnealingLR(
    optimizer,
    T_max=total_steps,
    eta_min=LR/10,
)
190
+
191
+ # ============================================================
192
+ # 7. EVALUATION
193
+ # ============================================================
194
+
195
def evaluate(model, loader, n_examples=5):
    """Compute exact-match accuracy of generated answers over ``loader``.

    Args:
        model: Model exposing ``eval()`` and ``generate()``; it is switched
            to eval mode and left in that mode.
        loader: Iterable of batches with the keys ``input_ids``,
            ``attention_mask`` and ``labels``.
        n_examples: Maximum number of example predictions to collect.

    Returns:
        A tuple ``(accuracy, examples)``. ``accuracy`` is a percentage in
        ``[0, 100]`` (``0.0`` when the loader yields no samples), and
        ``examples`` is a list of up to ``n_examples`` tuples
        ``(input_str, true_str, pred_str, is_correct)``.
    """
    model.eval()
    correct = 0
    total = 0
    examples = []

    with torch.no_grad():
        for batch in loader:
            input_ids = batch["input_ids"].to(DEVICE)
            attention_mask = batch["attention_mask"].to(DEVICE)

            # Greedy generation
            generated = model.generate(
                input_ids=input_ids,
                attention_mask=attention_mask,
                max_new_tokens=MAX_TARGET_LEN,
                eos_token_id=EOS_ID,
                pad_token_id=PAD_ID,
            )

            # One device-to-host transfer per batch instead of one per row.
            pred_rows = generated.cpu().tolist()
            label_rows = batch["labels"].tolist()

            for i, (pred_ids, lbl) in enumerate(zip(pred_rows, label_rows)):
                pred_str = decode(pred_ids)
                # Drop the -100 positions that replaced padding in the labels.
                true_str = decode([t for t in lbl if t != -100])

                is_correct = (pred_str == true_str)
                correct += int(is_correct)
                total += 1

                if len(examples) < n_examples:
                    inp_str = decode(input_ids[i].cpu().tolist())
                    examples.append((inp_str, true_str, pred_str, is_correct))

    # An empty loader would otherwise raise ZeroDivisionError.
    accuracy = correct / total * 100 if total else 0.0
    return accuracy, examples
235
+
236
+ # ============================================================
237
+ # 8. TRAINING LOOP
238
+ # ============================================================
239
+
240
print("\n" + "="*60)
print("TRAINING START")
print("="*60)

best_accuracy = 0.0

# fp16 autocast is a CUDA feature; on CPU the forward pass runs in full
# precision rather than asking for CUDA autocast and triggering a warning.
use_amp = (DEVICE == "cuda")

for epoch in range(1, EPOCHS + 1):
    model.train()
    total_loss = 0.0
    steps = 0
    t_start = time.time()

    for batch in train_loader:
        input_ids = batch["input_ids"].to(DEVICE)
        attention_mask = batch["attention_mask"].to(DEVICE)
        decoder_input_ids = batch["decoder_input_ids"].to(DEVICE)
        labels = batch["labels"].to(DEVICE)

        optimizer.zero_grad()

        # Mixed Precision
        with torch.cuda.amp.autocast(enabled=use_amp, dtype=torch.float16):
            outputs = model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                decoder_input_ids=decoder_input_ids,
                labels=labels,
            )
            loss = outputs.loss

        # Scale the loss for the backward pass, then unscale the gradients
        # before clipping so the 1.0 threshold applies to their true norm.
        scaler.scale(loss).backward()
        scaler.unscale_(optimizer)
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        scaler.step(optimizer)
        scaler.update()
        # The scheduler advances once per batch (T_max counts batches).
        scheduler.step()

        total_loss += loss.item()
        steps += 1

        if steps % 500 == 0:
            avg_loss = total_loss / steps
            elapsed = time.time() - t_start
            print(f" Epoch {epoch} | Step {steps}/{len(train_loader)} "
                  f"| Loss: {avg_loss:.4f} | {elapsed:.0f}s")

    # max(..., 1) keeps an empty train_loader from raising ZeroDivisionError.
    avg_loss = total_loss / max(steps, 1)

    # Validation
    print(f"\nEpoch {epoch} done. Evaluating...")
    accuracy, examples = evaluate(model, val_loader)

    print(f"\n{'='*60}")
    print(f"Epoch {epoch}/{EPOCHS}")
    print(f" Train loss: {avg_loss:.4f}")
    print(f" Val accuracy: {accuracy:.2f}%")
    print("\n Samples:")
    for inp, true, pred, ok in examples:
        status = "✅" if ok else "❌"
        print(f" {status} '{inp}' → expected: '{true}', got: '{pred}'")
    print("="*60)

    # Save the best model so far (by validation accuracy).
    # NOTE(review): "config" is stored as a pickled T5Config object, so
    # reading this checkpoint back presumably needs
    # torch.load(..., weights_only=False) on recent PyTorch — confirm.
    if accuracy > best_accuracy:
        best_accuracy = accuracy
        torch.save({
            "model_state_dict": model.state_dict(),
            "config": config,
            "char2id": char2id,
            "id2char": id2char,
            "epoch": epoch,
            "accuracy": accuracy,
        }, SAVE_PATH)
        print(f" 💾 New best model saved! ({accuracy:.2f}%)")

print(f"\nTraining done! Best accuracy: {best_accuracy:.2f}%")
316
+
317
+ # ============================================================
318
+ # 9. INFERENCE - TEST
319
+ # ============================================================
320
+
321
def predict(model, expression):
    """Generate the model's answer for a single arithmetic expression.

    Args:
        model: Model exposing ``eval()`` and ``generate()``; it is switched
            to eval mode.
        expression: Expression without the trailing ``=``, e.g. ``"12*34"``;
            the ``=`` is appended here.

    Returns:
        The decoded prediction string.
    """
    model.eval()

    # Batch of one: tokenize "<expression>=" and mask out the padding.
    token_row = encode(expression + "=", MAX_INPUT_LEN, add_bos=False, add_eos=True)
    input_ids = torch.tensor([token_row], dtype=torch.long).to(DEVICE)
    attention_mask = (input_ids != PAD_ID).long()

    generation_kwargs = dict(
        input_ids=input_ids,
        attention_mask=attention_mask,
        max_new_tokens=MAX_TARGET_LEN,
        eos_token_id=EOS_ID,
        pad_token_id=PAD_ID,
    )
    with torch.no_grad():
        generated = model.generate(**generation_kwargs)

    first_sequence = generated[0].cpu().tolist()
    return decode(first_sequence)
340
+
341
print("\n" + "="*60)
print("INFERENCE TEST")
print("="*60)

# Hand-picked smoke-test expressions (written without the trailing "=").
# NOTE(review): `model` here holds the weights from the final epoch; the best
# checkpoint written to SAVE_PATH is not reloaded before this test.
test_cases = [
    "123+456",
    "999-123",
    "12*34",
    "100/5",
    "500+500",
    "77*8",
]

for expr in test_cases:
    pred = predict(model, expr)
    try:
        # eval() is tolerable only because test_cases is a hard-coded literal;
        # never feed it external input. "/" becomes "//" so the reference
        # answer is an integer.
        true = str(eval(expr.replace("/", "//")))
    except Exception:
        # Narrower than a bare `except:` so KeyboardInterrupt and SystemExit
        # still propagate.
        true = "?"
    status = "✅" if pred == true else "❌"
    print(f" {status} {expr} = {pred} (correct: {true})")