luxopes commited on
Commit
8449341
·
verified ·
1 Parent(s): 41dd3d9

Upload 7 files

Browse files
Files changed (7) hide show
  1. data_prepare.py +177 -0
  2. eval-loss.py +245 -0
  3. test-checkpoints.py +86 -0
  4. tokenizer.model +3 -0
  5. tokenizer.vocab +0 -0
  6. train.py +519 -0
  7. valid.bin +3 -0
data_prepare.py ADDED
@@ -0,0 +1,177 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
# -- coding: utf-8 --
import os
from datasets import load_dataset
from tqdm import tqdm
import sentencepiece as spm
import numpy as np

# ===========================================================
# CONFIGURATION
# ===========================================================
# Approximate token budget for the downloaded corpus (estimated at ~4 chars/token).
TARGET_TOKENS = 1_000_000_000  # 100M for a test run, can be 1_000_000_000 and more
VOCAB_SIZE = 32_000            # SentencePiece vocabulary size
RAW_TEXT_PATH = "dataset.txt"  # raw text dump of the streamed dataset

TOKENIZER_MODEL_PATH = "tokenizer.model"

# Binary token files consumed by train.py (flat int32 token ids).
BIN_TRAIN_PATH = "dataset.bin"
BIN_VALID_PATH = "valid.bin"

TRAIN_RATIO = 0.98  # 98% training, 2% validation

# Special-token ids passed to the SentencePiece trainer.
# NOTE: eos_id is 2 here (bos_id is 1) — other scripts must not assume eos == 1.
SPECIAL_TOKENS = {
    "unk_id": 0,
    "bos_id": 1,
    "eos_id": 2,
    "pad_id": 3,
}

# ===========================================================
# 1) STREAMED DOWNLOAD OF FINEWEB -> dataset.txt
# ===========================================================
def download_fineweb_streaming():
    """Stream FineWeb-Edu and dump roughly TARGET_TOKENS worth of text to disk.

    Skipped entirely when RAW_TEXT_PATH already exists. Token counts are
    estimated at ~4 characters per token.
    """
    if os.path.exists(RAW_TEXT_PATH):
        print("✔ dataset.txt už existuje, přeskočeno.")
        return

    print("📥 Stahuji FineWeb-Edu streamovacím způsobem...")

    stream = load_dataset(
        "HuggingFaceFW/fineweb-edu",
        name="sample-10BT",
        split="train",
        streaming=True,
    )

    written = 0  # approximate number of tokens written so far
    with open(RAW_TEXT_PATH, "w", encoding="utf-8") as out:
        for record in tqdm(stream, desc="Stahuji dataset"):
            chunk = record["text"].strip() + "\n\n"
            estimate = len(chunk) // 4  # ~4 characters per token heuristic

            if written + estimate > TARGET_TOKENS:
                # Final partial write: only the characters that still fit the budget.
                out.write(chunk[:(TARGET_TOKENS - written) * 4])
                print("✔ dataset.txt hotovo.")
                return

            out.write(chunk)
            written += estimate
            if written >= TARGET_TOKENS:
                print("✔ dataset.txt hotovo.")
                return
# ===========================================================
# 2) SENTENCEPIECE TOKENIZER TRAINING
# ===========================================================
def train_tokenizer():
    """Train a VOCAB_SIZE unigram SentencePiece model on RAW_TEXT_PATH.

    Skipped when TOKENIZER_MODEL_PATH already exists.
    """
    if os.path.exists(TOKENIZER_MODEL_PATH):
        print("✔ Tokenizer už existuje, přeskakuji.")
        return

    print("🔧 Trénuji SentencePiece tokenizer...")

    spm.SentencePieceTrainer.train(
        input=RAW_TEXT_PATH,
        # "tokenizer.model" -> prefix "tokenizer" (spm appends .model/.vocab).
        model_prefix=TOKENIZER_MODEL_PATH.replace(".model", ""),
        vocab_size=VOCAB_SIZE,
        model_type="unigram",
        character_coverage=1.0,
        byte_fallback=True,
        unk_id=SPECIAL_TOKENS["unk_id"],
        bos_id=SPECIAL_TOKENS["bos_id"],
        eos_id=SPECIAL_TOKENS["eos_id"],
        pad_id=SPECIAL_TOKENS["pad_id"],
        train_extremely_large_corpus=True,
    )

    print("✔ Tokenizer natrénován.")
# ===========================================================
# 3) STREAMED TOKENIZATION → BIN FILE (INT32)
# ===========================================================
def tokenize_to_bin_streaming():
    """
    Stream-tokenize a large text corpus into binary int32 files without
    holding the whole dataset in memory.

    Pass 1 counts tokens so the train/valid memmaps can be pre-sized;
    pass 2 encodes again (SentencePiece encoding is deterministic) and
    writes the token ids into the memmaps.
    """
    if os.path.exists(BIN_TRAIN_PATH) and os.path.exists(BIN_VALID_PATH):
        print("✔ dataset.bin + valid.bin už existují.")
        return

    print("🔠 Streamuji text → tokeny (int32) → dataset.bin...")

    sp = spm.SentencePieceProcessor(model_file=TOKENIZER_MODEL_PATH)
    EOS = sp.eos_id()

    # ===========================================================
    # 1) COUNT TOTAL TOKENS
    # ===========================================================
    print("🔎 Počítám tokeny...")
    total_tokens = 0
    with open(RAW_TEXT_PATH, "r", encoding="utf-8") as f:
        for line in tqdm(f, desc="Počítám tokeny"):
            line = line.strip()
            if not line:
                continue
            total_tokens += len(sp.encode(line)) + 1  # +1 for EOS

    train_tokens = int(total_tokens * TRAIN_RATIO)
    valid_tokens = total_tokens - train_tokens

    print(f"Celkem tokenů: {total_tokens:,}")
    print(f"Train: {train_tokens:,}")
    print(f"Valid: {valid_tokens:,}")

    # ===========================================================
    # 2) CREATE MEMMAP FILES
    # ===========================================================
    train_mm = np.memmap(BIN_TRAIN_PATH, dtype=np.int32, mode="w+", shape=(train_tokens,))
    valid_mm = np.memmap(BIN_VALID_PATH, dtype=np.int32, mode="w+", shape=(valid_tokens,))

    # ===========================================================
    # 3) STREAMED TOKENIZATION AND WRITE
    # ===========================================================
    # Write via numpy slice assignment (C speed) instead of the previous
    # per-token Python loop, which was the bottleneck for 1B tokens.
    print("✍ Tokenizuji a zapisují do memmap...")
    ti, vi = 0, 0  # write cursors into the train/valid memmaps

    with open(RAW_TEXT_PATH, "r", encoding="utf-8") as f:
        for line in tqdm(f, desc="Tokenizuji dataset"):
            line = line.strip()
            if not line:
                continue

            ids = np.asarray(sp.encode(line) + [EOS], dtype=np.int32)

            # Fill the train split first; a single line may straddle the boundary.
            if ti < train_tokens:
                take = min(len(ids), train_tokens - ti)
                train_mm[ti:ti + take] = ids[:take]
                ti += take
                ids = ids[take:]

            if len(ids):
                # Clamp defensively so any drift between the two encode
                # passes cannot write past the end of the valid memmap.
                take = min(len(ids), valid_tokens - vi)
                valid_mm[vi:vi + take] = ids[:take]
                vi += take

    # ===========================================================
    # 4) FLUSH MEMMAPS
    # ===========================================================
    train_mm.flush()
    valid_mm.flush()

    print("✔ Hotovo — dataset.bin + valid.bin připravené pro trénink!")
# ===========================================================
# MAIN
# ===========================================================
if __name__ == "__main__":
    # Pipeline: download raw text -> train tokenizer -> emit binary token files.
    download_fineweb_streaming()
    train_tokenizer()
    tokenize_to_bin_streaming()
    print("\n🎉 HOTOVO — dataset.bin + valid.bin připravené pro trénink!")
eval-loss.py ADDED
@@ -0,0 +1,245 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
# -- coding: utf-8 --
# Compare validation loss of multiple GPT checkpoints
# Works with old and new checkpoint formats
# Compatible with Antonín Tomeček Transformer code

import math
import torch
import torch.nn as nn
import torch.nn.functional as F
import sentencepiece as spm
import numpy as np
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm

# =========================
# CONFIG
# =========================
# Checkpoint label -> file path; evaluated in insertion order by main().
CHECKPOINTS = {
    "pretrain_900k": "checkpoints/step_900000.pt",
    "continual_100k": "checkpoints/step_100000.pt",
    "continual_200k": "checkpoints/step_200000.pt",
    "continual_300k": "checkpoints/step_300000.pt",
    "continual_400k": "checkpoints/step_400000.pt",
    "continual_500k": "checkpoints/step_500000.pt",
}

TOKENIZER_MODEL_PATH = "tokenizer.model"
VALID_BIN = "valid.bin"
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
BATCH_SIZE = 1  # can be increased depending on available VRAM
# =========================
# ModelArgs
# =========================
from dataclasses import dataclass

@dataclass
class ModelArgs:
    # Architecture hyper-parameters; must match the checkpoint being loaded.
    dim: int = 768                   # residual / embedding width
    n_layers: int = 12
    n_heads: int = 12                # query heads
    n_kv_heads: int = 4              # key/value heads (grouped-query attention)
    vocab_size: int = 32000          # overwritten from the tokenizer in main()
    multiple_of: int = 256           # FFN hidden size rounded up to this multiple
    ffn_dim_multiplier: float = 3.0
    norm_eps: float = 1e-5
    max_seq_len: int = 1024
# =========================
# Dataset
# =========================
class MemmapDataset(Dataset):
    """Sliding-window view over a flat int32 token file.

    Each item is an (input, target) pair of length ``max_seq_len`` where the
    target is the input shifted right by one token.
    """

    def __init__(self, path: str, max_seq_len: int, stride=None):
        # Read-only memmap: tokens are never fully loaded into RAM.
        self.tokens = np.memmap(path, dtype=np.int32, mode="r")
        self.max_seq_len = max_seq_len
        # Overlapping windows by default (half-window stride).
        self.stride = stride or max_seq_len // 2

        # Last valid window start; each window needs max_seq_len + 1 tokens.
        max_start = len(self.tokens) - (max_seq_len + 1)
        # BUGFIX: was "<= 0" — a file of exactly max_seq_len + 1 tokens is one
        # valid sample (start 0) but used to raise.
        if max_start < 0:
            raise ValueError("Dataset too small")

        self.starts = list(range(0, max_start, self.stride))
        # Always include the final window; the "not self.starts" guard also
        # covers max_start == 0 (previously an IndexError on starts[-1]).
        if not self.starts or self.starts[-1] != max_start:
            self.starts.append(max_start)

    def __len__(self):
        return len(self.starts)

    def __getitem__(self, idx):
        i = self.starts[idx]
        # .copy() detaches the slice from the memmap before handing it to torch.
        seq = torch.from_numpy(
            self.tokens[i:i + self.max_seq_len + 1].copy()
        ).long()
        return seq[:-1], seq[1:]
# =========================
# Transformer model
# =========================
class RMSNorm(nn.Module):
    """Root-mean-square layer norm: no mean subtraction, learned gain only."""

    def __init__(self, dim, eps=1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x):
        # Normalise by the RMS over the last dimension, then rescale.
        mean_sq = x.pow(2).mean(dim=-1, keepdim=True)
        normed = x * torch.rsqrt(mean_sq + self.eps)
        return normed * self.weight
def precompute_freqs_cis(dim, seq_len, theta=10000.0):
    """Precompute rotary-embedding cos/sin tables of shape (seq_len, dim // 2)."""
    exponents = torch.arange(0, dim, 2) / dim
    inv_freq = 1.0 / (theta ** exponents)
    positions = torch.arange(seq_len)
    angles = torch.outer(positions, inv_freq)
    return angles.cos(), angles.sin()
def apply_rotary_emb(x, cos, sin):
    """Rotate even/odd channel pairs of x by position-dependent angles (RoPE).

    x: (B, T, H, D); cos/sin: (T, D // 2), broadcast over batch and heads.
    """
    even = x[..., 0::2]
    odd = x[..., 1::2]
    # Insert batch and head axes so the tables broadcast against (B, T, H, D/2).
    cos = cos.unsqueeze(0).unsqueeze(2)
    sin = sin.unsqueeze(0).unsqueeze(2)
    rotated = torch.empty_like(x)
    rotated[..., 0::2] = even * cos - odd * sin
    rotated[..., 1::2] = even * sin + odd * cos
    return rotated
class Attention(nn.Module):
    """Grouped-query self-attention with rotary embeddings (SDPA backend)."""

    def __init__(self, args):
        super().__init__()
        self.n_heads = args.n_heads
        self.head_dim = args.dim // args.n_heads
        self.n_kv_heads = args.n_kv_heads
        # Each KV head serves this many query heads.
        self.repeat_kv = args.n_heads // args.n_kv_heads

        self.wq = nn.Linear(args.dim, args.n_heads * self.head_dim, bias=False)
        self.wk = nn.Linear(args.dim, args.n_kv_heads * self.head_dim, bias=False)
        self.wv = nn.Linear(args.dim, args.n_kv_heads * self.head_dim, bias=False)
        self.wo = nn.Linear(args.n_heads * self.head_dim, args.dim, bias=False)

    def forward(self, x, cos, sin):
        batch, seq, _ = x.shape

        queries = self.wq(x).view(batch, seq, self.n_heads, self.head_dim)
        keys = self.wk(x).view(batch, seq, self.n_kv_heads, self.head_dim)
        values = self.wv(x).view(batch, seq, self.n_kv_heads, self.head_dim)

        # Expand KV heads to match the query head count (GQA).
        keys = keys.repeat_interleave(self.repeat_kv, dim=2)
        values = values.repeat_interleave(self.repeat_kv, dim=2)

        queries = apply_rotary_emb(queries, cos, sin)
        keys = apply_rotary_emb(keys, cos, sin)

        # SDPA expects (B, H, T, D).
        queries = queries.transpose(1, 2)
        keys = keys.transpose(1, 2)
        values = values.transpose(1, 2)
        attended = F.scaled_dot_product_attention(queries, keys, values, is_causal=True)

        attended = attended.transpose(1, 2).contiguous().view(batch, seq, -1)
        return self.wo(attended)
class FeedForward(nn.Module):
    """SwiGLU feed-forward: w2(silu(w1(x)) * w3(x)), hidden width rounded up."""

    def __init__(self, dim, multiple_of, mult):
        super().__init__()
        # Round int(dim * mult) up to the nearest multiple_of.
        raw = int(dim * mult)
        hidden = ((raw + multiple_of - 1) // multiple_of) * multiple_of
        self.w1 = nn.Linear(dim, hidden, bias=False)
        self.w2 = nn.Linear(hidden, dim, bias=False)
        self.w3 = nn.Linear(dim, hidden, bias=False)

    def forward(self, x):
        gate = F.silu(self.w1(x))
        return self.w2(gate * self.w3(x))
class TransformerBlock(nn.Module):
    """Pre-norm residual block: attention sublayer followed by SwiGLU FFN."""

    def __init__(self, args):
        super().__init__()
        self.attn = Attention(args)
        self.ffn = FeedForward(args.dim, args.multiple_of, args.ffn_dim_multiplier)
        self.attn_norm = RMSNorm(args.dim, args.norm_eps)
        self.ffn_norm = RMSNorm(args.dim, args.norm_eps)

    def forward(self, x, cos, sin):
        # Residual connection around each normalised sublayer.
        h = x + self.attn(self.attn_norm(x), cos, sin)
        return h + self.ffn(self.ffn_norm(h))
class Transformer(nn.Module):
    """Decoder-only, LLaMA-style Transformer: embedding, blocks, norm, LM head."""

    def __init__(self, args):
        super().__init__()
        self.tok_emb = nn.Embedding(args.vocab_size, args.dim)
        self.layers = nn.ModuleList(TransformerBlock(args) for _ in range(args.n_layers))
        self.norm = RMSNorm(args.dim, args.norm_eps)
        self.out = nn.Linear(args.dim, args.vocab_size, bias=False)
        # Rotary tables long enough for twice the training context.
        cos, sin = precompute_freqs_cis(args.dim // args.n_heads, args.max_seq_len * 2)
        self.register_buffer("cos_cached", cos, persistent=False)
        self.register_buffer("sin_cached", sin, persistent=False)

    def forward(self, tokens):
        _, seq_len = tokens.shape
        hidden = self.tok_emb(tokens)
        cos = self.cos_cached[:seq_len]
        sin = self.sin_cached[:seq_len]
        for block in self.layers:
            hidden = block(hidden, cos, sin)
        return self.out(self.norm(hidden))
# =========================
# Eval function
# =========================
def evaluate_checkpoint(path, valid_loader, tokenizer, args):
    """Load one checkpoint and return its mean per-token validation loss."""
    ckpt = torch.load(path, map_location="cpu", weights_only=False)

    # Old checkpoints are a bare state_dict; new ones wrap it in a dict.
    if isinstance(ckpt, dict) and "model_state_dict" in ckpt:
        state_dict = ckpt["model_state_dict"]
    else:
        state_dict = ckpt

    model = Transformer(args)
    model.load_state_dict(state_dict)
    model.to(DEVICE)
    model.eval()

    loss_sum = 0.0
    token_count = 0
    pad = tokenizer.pad_id()

    with torch.no_grad():
        for inputs, targets in valid_loader:
            inputs = inputs.to(DEVICE)
            targets = targets.to(DEVICE)

            logits = model(inputs)
            # Sum (not mean) per batch so unequal batch sizes weight correctly.
            batch_loss = F.cross_entropy(
                logits.view(-1, logits.size(-1)),
                targets.view(-1),
                ignore_index=pad,
                reduction="sum",
            )

            loss_sum += batch_loss.item()
            token_count += (targets != pad).sum().item()

    return loss_sum / token_count
# =========================
# MAIN
# =========================
def main():
    """Evaluate every configured checkpoint on the validation set and report."""
    args = ModelArgs()
    tokenizer = spm.SentencePieceProcessor(model_file=TOKENIZER_MODEL_PATH)
    # Keep the embedding table in sync with the actual tokenizer vocab.
    args.vocab_size = tokenizer.vocab_size()

    # Validation data (deterministic order).
    valid_loader = DataLoader(
        MemmapDataset(VALID_BIN, args.max_seq_len),
        batch_size=BATCH_SIZE,
        shuffle=False,
        num_workers=2,
        pin_memory=True,
    )

    print("=" * 70)
    print("Checkpoint comparison (validation)")
    print("=" * 70)

    results = {}
    for name, path in CHECKPOINTS.items():
        print(f"[Eval] {name}")
        loss = evaluate_checkpoint(path, valid_loader, tokenizer, args)
        ppl = math.exp(loss)
        results[name] = (loss, ppl)
        print(f" Val loss: {loss:.6f}")
        print(f" Perplexity: {ppl:.2f}")
        print("-" * 50)

    print("\nSummary:")
    for name, (loss, ppl) in results.items():
        print(f"{name:20s} | loss {loss:.6f} | ppl {ppl:.2f}")
    print("=" * 70)

if __name__ == "__main__":
    main()
test-checkpoints.py ADDED
@@ -0,0 +1,86 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
# -- coding: utf-8 --
# Author: Antonín Tomeček
# Date: 10 Jan 2026
# Description: Standalone text generation from GPT-style checkpoint 500k

import os
import torch
import sentencepiece as spm

# import the model and helpers from the training file
from train import Transformer, ModelArgs, generate_text  # adjust to the file name

# =========================
# CONFIG
# =========================
CHECKPOINT_PATH = "checkpoints/step_500000.pt"
TOKENIZER_MODEL_PATH = "tokenizer.model"
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

MAX_NEW_TOKENS = 200
TEMPERATURE = 0.8
TOP_P = 0.95

# =========================
# Allow ModelArgs during unpickling
# =========================
torch.serialization.add_safe_globals([ModelArgs])

# =========================
# LOAD TOKENIZER
# =========================
tokenizer = spm.SentencePieceProcessor(model_file=TOKENIZER_MODEL_PATH)
vocab_size = tokenizer.vocab_size()

# BUGFIX: EOS_ID was hard-coded to 1, but data_prepare.py trains the tokenizer
# with bos_id=1 and eos_id=2, so generation used to stop on BOS instead of EOS.
# Ask the tokenizer for the real id instead of guessing.
EOS_ID = tokenizer.eos_id()

# =========================
# LOAD CHECKPOINT
# =========================
if not os.path.exists(CHECKPOINT_PATH):
    raise FileNotFoundError(f"Checkpoint {CHECKPOINT_PATH} not found")

checkpoint = torch.load(CHECKPOINT_PATH, map_location=DEVICE, weights_only=False)

# build the model from the stored args (defaults for old checkpoint files)
model_args = checkpoint.get("model_args", ModelArgs())
model_args.vocab_size = vocab_size
model = Transformer(model_args).to(DEVICE)

# load the weights
model.load_state_dict(checkpoint["model_state_dict"])
model.eval()

print(f"[Info] Loaded checkpoint from step {checkpoint.get('step', 'unknown')}")
print(f"[Info] Model has {sum(p.numel() for p in model.parameters() if p.requires_grad):,} params")

# =========================
# PROMPTS
# =========================
prompts = [
    "Once upon a time",
    "In a distant future",
    "Artificial intelligence will",
    "First step to build a rocket",
    "Capital city of France"
]

# =========================
# GENERATE TEXT
# =========================
results = generate_text(
    model,
    tokenizer,
    prompts,
    max_new_tokens=MAX_NEW_TOKENS,
    temperature=TEMPERATURE,
    top_p=TOP_P,
    eos_id=EOS_ID
)

# =========================
# PRINT RESULTS
# =========================
for prompt, text in results.items():
    print("=" * 50)
    print(f"Prompt: {prompt}")
    print(f"Generated: {text}")
tokenizer.model ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:ba603eec2affef5ce7b3826463b2839bfbdc19ebade48fecd7551f847c17f9da
3
+ size 725097
tokenizer.vocab ADDED
The diff for this file is too large to render. See raw diff
 
train.py ADDED
@@ -0,0 +1,519 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
# -- coding: utf-8 --
# Author: Antonín Tomeček
# Date: 3 Jan. 2026
# Description: GPT-style Transformer with Flash Attention 2, Memmap dataset,
# correct gradient accumulation, and clean English logging.

import os
# Must be set before any CUDA allocation happens in this process.
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"

import math
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

from dataclasses import dataclass
from typing import Optional
from torch.utils.data import Dataset, DataLoader
from accelerate import Accelerator
from tqdm import tqdm
import sentencepiece as spm

# Allow TF32 matmuls/convolutions on Ampere+ GPUs (faster, slightly lower precision).
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
# =========================
# FLASH ATTENTION 2
# =========================
# Feature-detect flash-attn; fall back to PyTorch SDPA when unavailable.
try:
    print(f"[Info] Torch version: {torch.__version__}")
    print(f"[Info] CUDA available: {torch.cuda.is_available()}")
    if torch.cuda.is_available():
        print(f"[Info] CUDA version: {torch.version.cuda}")

    from flash_attn import flash_attn_func
    FLASH_ATTENTION_2 = True
    print("[OK] Flash Attention 2 enabled")
except Exception:
    FLASH_ATTENTION_2 = False
    print("[WARN] Flash Attention 2 not available – using PyTorch SDPA")
# =========================
# CONFIG
# =========================
@dataclass
class ModelArgs:
    # Architecture hyper-parameters; must match any checkpoint being resumed.
    dim: int = 768                   # residual / embedding width
    n_layers: int = 12
    n_heads: int = 12                # query heads
    n_kv_heads: int = 4              # key/value heads (grouped-query attention)
    vocab_size: int = 32000          # overwritten from the tokenizer in __main__
    multiple_of: int = 256           # FFN hidden size rounded up to this multiple
    ffn_dim_multiplier: float = 3.0
    norm_eps: float = 1e-5
    max_seq_len: int = 1024


# NOTE(review): train() compares this against micro-batch steps, not optimizer
# steps — confirm the intended checkpoint granularity.
SAVE_EVERY_STEPS = 100_000
TOKENIZER_MODEL_PATH = "tokenizer.model"
TRAIN_BIN = "dataset.bin"
VALID_BIN = "valid.bin"
CHECKPOINT_DIR = "checkpoints"
os.makedirs(CHECKPOINT_DIR, exist_ok=True)

# =========================
# MODEL
# =========================
class RMSNorm(nn.Module):
    """RMS layer normalisation with a learnable per-channel scale."""

    def __init__(self, dim, eps=1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x):
        # Scale each vector by the reciprocal RMS of its last dimension.
        inv_rms = torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
        return self.weight * (x * inv_rms)

def precompute_freqs_cis(dim, seq_len, theta=10000.0):
    """Build rotary-embedding cos/sin lookup tables of shape (seq_len, dim // 2)."""
    inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2) / dim))
    angles = torch.outer(torch.arange(seq_len), inv_freq)
    return angles.cos(), angles.sin()


def apply_rotary_emb(x, cos, sin):
    """Apply RoPE to the interleaved even/odd channel pairs of x (B, T, H, D)."""
    # Broadcast the (T, D/2) tables over batch and head dimensions.
    c = cos.unsqueeze(0).unsqueeze(2)
    s = sin.unsqueeze(0).unsqueeze(2)
    a, b = x[..., 0::2], x[..., 1::2]
    result = torch.empty_like(x)
    result[..., 0::2] = a * c - b * s
    result[..., 1::2] = a * s + b * c
    return result


class Attention(nn.Module):
    """Grouped-query attention with RoPE; uses Flash Attention 2 when available."""

    def __init__(self, args):
        super().__init__()
        self.n_heads = args.n_heads
        self.head_dim = args.dim // args.n_heads
        self.n_kv_heads = args.n_kv_heads
        # Each KV head serves this many query heads.
        self.repeat_kv = args.n_heads // args.n_kv_heads

        self.wq = nn.Linear(args.dim, args.n_heads * self.head_dim, bias=False)
        self.wk = nn.Linear(args.dim, args.n_kv_heads * self.head_dim, bias=False)
        self.wv = nn.Linear(args.dim, args.n_kv_heads * self.head_dim, bias=False)
        self.wo = nn.Linear(args.n_heads * self.head_dim, args.dim, bias=False)

    def forward(self, x, cos, sin):
        B, T, _ = x.shape

        q = self.wq(x).view(B, T, self.n_heads, self.head_dim)
        k = self.wk(x).view(B, T, self.n_kv_heads, self.head_dim)
        v = self.wv(x).view(B, T, self.n_kv_heads, self.head_dim)

        # Expand KV heads to the query head count (grouped-query attention).
        k = k.repeat_interleave(self.repeat_kv, dim=2)
        v = v.repeat_interleave(self.repeat_kv, dim=2)

        q = apply_rotary_emb(q, cos, sin)
        k = apply_rotary_emb(k, cos, sin)

        if FLASH_ATTENTION_2:
            # BUGFIX: flash_attn_func expects (batch, seqlen, nheads, headdim).
            # Previously the tensors were transposed to (B, H, T, D) before the
            # call, silently swapping the sequence and head axes on this path.
            out = flash_attn_func(q, k, v, causal=True)  # -> (B, T, H, D)
        else:
            # SDPA expects (B, H, T, D); transpose in and back out.
            out = F.scaled_dot_product_attention(
                q.transpose(1, 2),
                k.transpose(1, 2),
                v.transpose(1, 2),
                is_causal=True,
            ).transpose(1, 2)

        out = out.contiguous().view(B, T, -1)
        return self.wo(out)


class FeedForward(nn.Module):
    """SwiGLU MLP; hidden width = int(dim * mult) rounded up to multiple_of."""

    def __init__(self, dim, multiple_of, mult):
        super().__init__()
        hidden = -(-int(dim * mult) // multiple_of) * multiple_of  # ceil-round
        self.w1 = nn.Linear(dim, hidden, bias=False)
        self.w2 = nn.Linear(hidden, dim, bias=False)
        self.w3 = nn.Linear(dim, hidden, bias=False)

    def forward(self, x):
        return self.w2(F.silu(self.w1(x)) * self.w3(x))


class TransformerBlock(nn.Module):
    """Pre-norm block: attention sublayer, then an (optionally checkpointed) FFN."""

    def __init__(self, args):
        super().__init__()
        self.attn = Attention(args)
        self.ffn = FeedForward(args.dim, args.multiple_of, args.ffn_dim_multiplier)
        self.attn_norm = RMSNorm(args.dim, args.norm_eps)
        self.ffn_norm = RMSNorm(args.dim, args.norm_eps)
        # Toggled by Transformer.gradient_checkpointing_enable().
        self.gradient_checkpointing = False

    def forward(self, x, cos, sin):
        x = x + self.attn(self.attn_norm(x), cos, sin)

        # Recompute FFN activations during backward to save memory
        # (only while training and only when explicitly enabled).
        if self.training and self.gradient_checkpointing:
            ffn_out = torch.utils.checkpoint.checkpoint(
                self._ffn, x, use_reentrant=False
            )
        else:
            ffn_out = self.ffn(self.ffn_norm(x))
        return x + ffn_out

    def _ffn(self, x):
        # Norm + FFN fused into one callable for checkpoint().
        return self.ffn(self.ffn_norm(x))


class Transformer(nn.Module):
    """Decoder-only Transformer: token embedding, N blocks, final norm, LM head."""

    def __init__(self, args):
        super().__init__()
        self.tok_emb = nn.Embedding(args.vocab_size, args.dim)
        self.layers = nn.ModuleList(TransformerBlock(args) for _ in range(args.n_layers))
        self.norm = RMSNorm(args.dim, args.norm_eps)
        self.out = nn.Linear(args.dim, args.vocab_size, bias=False)

        # Rotary tables cover twice the training context.
        cos, sin = precompute_freqs_cis(args.dim // args.n_heads, args.max_seq_len * 2)
        self.register_buffer("cos_cached", cos, persistent=False)
        self.register_buffer("sin_cached", sin, persistent=False)

        self.apply(self._init)

    def gradient_checkpointing_enable(self):
        """Turn on FFN activation checkpointing in every block."""
        for block in self.layers:
            block.gradient_checkpointing = True
        print("[OK] Gradient checkpointing enabled")

    def _init(self, module):
        # GPT-2 style init for all linear and embedding weights.
        if isinstance(module, (nn.Linear, nn.Embedding)):
            nn.init.normal_(module.weight, std=0.02)

    def forward(self, tokens):
        _, seq_len = tokens.shape
        hidden = self.tok_emb(tokens)
        cos = self.cos_cached[:seq_len]
        sin = self.sin_cached[:seq_len]

        for block in self.layers:
            hidden = block(hidden, cos, sin)

        return self.out(self.norm(hidden))

    def get_num_params(self):
        """Number of trainable parameters."""
        return sum(p.numel() for p in self.parameters() if p.requires_grad)

# =========================
# MEMMAP DATASET (FIXED)
# =========================
class MemmapDataset(Dataset):
    """Sliding-window view over a flat int32 token file.

    Each item is an (input, target) pair of length ``max_seq_len`` where the
    target is the input shifted right by one token.
    """

    def __init__(self, path: str, max_seq_len: int, stride: Optional[int] = None):
        # Read-only memmap: tokens are never fully loaded into RAM.
        self.tokens = np.memmap(path, dtype=np.int32, mode="r")
        self.max_seq_len = max_seq_len
        # Overlapping windows by default (half-window stride).
        self.stride = stride or max_seq_len // 2

        # Last valid window start; each window needs max_seq_len + 1 tokens.
        max_start = len(self.tokens) - (max_seq_len + 1)
        # BUGFIX: was "<= 0" — a file of exactly max_seq_len + 1 tokens is one
        # valid sample (start 0) but used to raise.
        if max_start < 0:
            raise ValueError("Dataset too small for the given max_seq_len")

        self.starts = list(range(0, max_start, self.stride))
        # Always include the final window; the "not self.starts" guard also
        # covers max_start == 0 (previously an IndexError on starts[-1]).
        if not self.starts or self.starts[-1] != max_start:
            self.starts.append(max_start)

    def __len__(self):
        return len(self.starts)

    def __getitem__(self, idx):
        i = self.starts[idx]
        # .copy() detaches the slice from the memmap before handing it to torch.
        seq = torch.from_numpy(
            self.tokens[i:i + self.max_seq_len + 1].copy()
        ).long()
        return seq[:-1], seq[1:]
# =========================
# TEXT GENERATION
# =========================
@torch.no_grad()
def generate_text(model, tokenizer, prompts,
                  max_new_tokens=128, temperature=0.8, top_p=0.95, eos_id=1):
    """Nucleus-sample a continuation for each prompt.

    Returns {prompt: decoded_text} (decoded text includes the prompt).

    BUGFIX: the model's train/eval mode is now restored on exit. Previously
    this function left the model in eval mode, so when train() called it for
    periodic samples, training continued with model.training == False and
    gradient checkpointing (gated on self.training) silently disabled.

    NOTE(review): the default eos_id=1 is bos_id under data_prepare.py's
    tokenizer (eos is 2) — callers should pass tokenizer.eos_id() explicitly.
    """
    was_training = model.training
    model.eval()
    device = next(model.parameters()).device
    results = {}

    try:
        for prompt in prompts:
            ids = tokenizer.encode(prompt)
            x = torch.tensor([ids], device=device)

            for _ in range(max_new_tokens):
                logits = model(x)[0, -1] / temperature

                # Nucleus (top-p) filtering: keep the smallest prefix of
                # tokens (by descending probability) whose mass exceeds top_p.
                sorted_logits, sorted_idx = torch.sort(logits, descending=True)
                probs = torch.softmax(sorted_logits, dim=0)

                cum_probs = probs.cumsum(dim=0)
                mask = cum_probs > top_p
                mask[1:] = mask[:-1].clone()  # shift right: always keep the top token
                mask[0] = False

                logits[sorted_idx[mask]] = -float("inf")
                probs = torch.softmax(logits, dim=0)

                next_tok = torch.multinomial(probs, 1)
                x = torch.cat([x, next_tok.unsqueeze(0)], dim=1)

                if next_tok.item() == eos_id:
                    break

            results[prompt] = tokenizer.decode(x[0].tolist())
    finally:
        # Restore whatever mode the caller had the model in.
        model.train(was_training)

    return results

# =========================
# TRAINING
# =========================
def train(
    model,
    train_ds,
    valid_ds,
    tokenizer,
    args,
    batch_size=1,
    grad_accum=8,
    epochs=1,
    lr=1e-5,
    warmup_steps=0,
):
    """Train the model via accelerate (mixed precision + gradient accumulation).

    Saves full checkpoints every SAVE_EVERY_STEPS micro-batches, validates at
    the end of each epoch, prints sample generations, and writes a final
    checkpoint to CHECKPOINT_DIR when done.
    """
    accelerator = Accelerator(
        # bf16 where the GPU supports it, otherwise fp16.
        mixed_precision="bf16" if torch.cuda.is_bf16_supported() else "fp16",
        gradient_accumulation_steps=grad_accum,
    )

    model.gradient_checkpointing_enable()

    train_loader = DataLoader(
        train_ds,
        batch_size=batch_size,
        shuffle=True,
        num_workers=2,
        pin_memory=True,
    )

    valid_loader = DataLoader(
        valid_ds,
        batch_size=batch_size,
        shuffle=False,
        num_workers=2,
        pin_memory=True,
    )

    optimizer = torch.optim.AdamW(
        model.parameters(),
        lr=lr,
        betas=(0.9, 0.95),
        weight_decay=0.01,
    )

    # Optimizer steps (micro-batches / grad_accum), used by the LR schedule.
    total_steps = math.ceil(len(train_loader) / grad_accum) * epochs

    def lr_lambda(step):
        # Linear warmup followed by cosine decay to zero.
        if step < warmup_steps:
            return step / max(1, warmup_steps)
        progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
        return 0.5 * (1.0 + math.cos(math.pi * progress))

    scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)

    model, optimizer, train_loader, valid_loader, scheduler = accelerator.prepare(
        model, optimizer, train_loader, valid_loader, scheduler
    )

    if accelerator.is_main_process:
        eff_bs = batch_size * grad_accum * accelerator.num_processes
        print(f"Model params: {model.get_num_params():,}")
        print(f"Effective batch size: {eff_bs}")
        print(f"Total optimizer steps: {total_steps}")
        print(f"Flash Attention: {FLASH_ATTENTION_2}")
        print("-" * 60)

    global_step = 0
    best_val = float("inf")  # NOTE(review): never updated below — appears dead.

    for epoch in range(epochs):
        model.train()
        running_loss = 0.0

        pbar = tqdm(
            train_loader,
            disable=not accelerator.is_local_main_process,
            desc=f"Epoch {epoch+1}/{epochs}",
        )

        for step, (x, y) in enumerate(pbar):
            with accelerator.accumulate(model):
                logits = model(x)
                loss = F.cross_entropy(
                    logits.view(-1, logits.size(-1)),
                    y.view(-1),
                    ignore_index=tokenizer.pad_id(),
                )
                accelerator.backward(loss)

                # Clip and step only when gradients are synchronised,
                # i.e. once per grad_accum micro-batches.
                if accelerator.sync_gradients:
                    accelerator.clip_grad_norm_(model.parameters(), 1.0)
                    optimizer.step()
                    scheduler.step()
                    optimizer.zero_grad()

            # ======== global_step counts training (micro-batch) steps ========
            # NOTE(review): this is micro-batches, not optimizer steps, so
            # checkpoint filenames and SAVE_EVERY_STEPS use that unit too.
            global_step += 1

            # ==========================================
            # PERIODIC CHECKPOINT + TEXT GENERATION
            # ==========================================
            if accelerator.is_main_process and global_step % SAVE_EVERY_STEPS == 0:
                ckpt_path = f"{CHECKPOINT_DIR}/step_{global_step}.pt"
                checkpoint = {
                    "step": global_step,
                    "model_state_dict": accelerator.unwrap_model(model).state_dict(),
                    "optimizer_state_dict": optimizer.state_dict(),
                    "scheduler_state_dict": scheduler.state_dict(),
                    "model_args": args,
                }
                torch.save(checkpoint, ckpt_path)
                print(f"[Checkpoint] Saved complete checkpoint at step {global_step}")

                prompts = [
                    "Once upon a time",
                    "In a distant future",
                    "First step to build a rocket",
                    "Capital city of France",
                    "Artificial intelligence will",
                ]

                # NOTE(review): generate_text switches the model to eval mode;
                # verify train mode is restored before the loop continues
                # (gradient checkpointing is gated on model.training).
                samples = generate_text(
                    accelerator.unwrap_model(model),
                    tokenizer,
                    prompts,
                    max_new_tokens=100,
                    temperature=0.8,
                    top_p=0.95,
                )

                print(f"[Sample generation @ step {global_step}]")
                for prompt, text in samples.items():
                    print(f"Prompt: {prompt}")
                    print(f"Generated: {text}")
                    print("-" * 50)

            # Running mean over the micro-batches seen so far this epoch.
            running_loss += loss.item()
            pbar.set_postfix(
                loss=f"{running_loss/(step+1):.4f}",
                lr=f"{scheduler.get_last_lr()[0]:.2e}",
            )

        # =========================
        # VALIDATION
        # =========================
        model.eval()
        val_loss = 0.0
        with torch.no_grad():
            for x, y in valid_loader:
                logits = model(x)
                loss = F.cross_entropy(
                    logits.view(-1, logits.size(-1)),
                    y.view(-1),
                    ignore_index=tokenizer.pad_id(),
                )
                val_loss += loss.item()

        val_loss /= len(valid_loader)

        accelerator.print(
            f"[Epoch {epoch+1}] Train Loss: {running_loss/len(train_loader):.6f} | "
            f"Val Loss: {val_loss:.6f}"
        )

        # =========================
        # END-OF-EPOCH GENERATION
        # =========================
        if accelerator.is_main_process:
            prompts = [
                "Once upon a time",
                "In a distant future",
                "First step to build a rocket",
                "Capital city of France",
                "Artificial intelligence will",
            ]

            samples = generate_text(
                accelerator.unwrap_model(model),
                tokenizer,
                prompts,
                max_new_tokens=100,
                temperature=0.8,
                top_p=0.95,
            )

            print("[Sample generation]")
            for prompt, text in samples.items():
                print(f"Prompt: {prompt}")
                print(f"Generated: {text}")
                print("-" * 50)

    # =========================
    # FINAL SAVE
    # =========================
    if accelerator.is_main_process:
        checkpoint = {
            "step": global_step,
            "model_state_dict": accelerator.unwrap_model(model).state_dict(),
            "optimizer_state_dict": optimizer.state_dict(),
            "scheduler_state_dict": scheduler.state_dict(),
            "model_args": args,
        }
        torch.save(checkpoint, f"{CHECKPOINT_DIR}/final_model.pt")
        print("Training complete.")


# =========================
# MAIN
# =========================
if __name__ == "__main__":
    args = ModelArgs()

    tokenizer = spm.SentencePieceProcessor(model_file=TOKENIZER_MODEL_PATH)
    # Keep the embedding table in sync with the actual tokenizer vocab.
    args.vocab_size = tokenizer.vocab_size()

    train_ds = MemmapDataset(TRAIN_BIN, args.max_seq_len)
    valid_ds = MemmapDataset(VALID_BIN, args.max_seq_len)

    model = Transformer(args)

    # Hard-coded resume point; training starts fresh when this file is absent.
    RESUME_FROM = "checkpoints/step_200000.pt"

    if os.path.exists(RESUME_FROM):
        print(f"[Resume] Loading checkpoint from {RESUME_FROM}")
        checkpoint = torch.load(RESUME_FROM, map_location="cpu")

        # Support both old format (direct state_dict) and new format (checkpoint dict)
        if "model_state_dict" in checkpoint:
            model.load_state_dict(checkpoint["model_state_dict"])
            print(f"[Resume] Loaded model from step {checkpoint.get('step', 'unknown')}")
        else:
            # Old format: checkpoint is directly the model state_dict
            model.load_state_dict(checkpoint)
            print(f"[Resume] Loaded model (old format)")

    # NOTE(review): only model weights are restored here — the saved
    # optimizer/scheduler state is ignored, so the LR schedule restarts.
    train(
        model,
        train_ds,
        valid_ds,
        tokenizer,
        args,
        batch_size=1,
        grad_accum=8,
        epochs=1,
        lr=1e-5,
    )
valid.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:0f593d53b5d225ba26ba5e8c48277b7eb0d3737d2a1fc3544be43871a58c963b
3
+ size 4000000