endurasolution commited on
Commit
3b97420
·
verified ·
1 Parent(s): e0e9d48

Upload Ron-110M: pretrain + summarizer + tokenizer + code

Browse files
README.md ADDED
@@ -0,0 +1,129 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ language:
3
+ - en
4
+ license: mit
5
+ tags:
6
+ - gpt
7
+ - text-generation
8
+ - summarization
9
+ - from-scratch
10
+ - pytorch
11
+ library_name: pytorch
12
+ ---
13
+
14
+ # Ron-110M
15
+
16
+ A 110M-parameter GPT-style language model trained from scratch on a single
17
+ RTX 3090. Pretrained on WikiText-103, then fine-tuned on CNN/DailyMail for
18
+ extractive news summarization.
19
+
20
+ This is a learning / research model. It is small, the tokenizer is a custom
21
+ byte-level BPE, and it does not use the Hugging Face `transformers` model
22
+ classes. The repo includes the original PyTorch code so you can run, fine-tune,
23
+ or continue pretraining from these weights.
24
+
25
+ ## Files
26
+
27
+ - `pretrain.pt` - base language model checkpoint (after WikiText-103 pretraining)
28
+ - `summarizer.pt` - SFT checkpoint for news summarization (start from this for inference)
29
+ - `tokenizer.json` - byte-level BPE tokenizer (32k vocab, specials: `<pad> <bos> <eos> <unk>`)
30
+ - `meta.json` - dataset metadata (vocab size, dtype, token counts)
31
+ - `code/model.py` - GPT model definition
32
+ - `code/tokenizer.py` - tokenizer wrapper with ByteLevel decoder fix
33
+ - `code/ask.py` - inference script with repetition penalty, top-p, no-repeat-ngram
34
+ - `code/train.py` - pretraining script
35
+ - `code/finetune_sft.py` - supervised fine-tuning script
36
+ - `code/make_cnndm_sft.py` - CNN/DailyMail SFT data builder
37
+ - `code/prepare_wikitext.py` - WikiText-103 tokenization + tokenizer training
38
+
39
+ ## Architecture
40
+
41
+ ```
42
+ n_layer = 12
43
+ n_head = 12
44
+ n_embd = 768
45
+ block_size = 512
46
+ vocab_size = 32000
47
+ parameters = 109.92M
48
+ ```
49
+
50
+ ## Training results
51
+
52
+ | Stage | Dataset | Steps | Final val loss |
53
+ |--------------------|---------------|--------|----------------|
54
+ | Pretrain | WikiText-103 | 12,000 | 3.15 |
55
+ | SFT (summarizer) | CNN/DailyMail | 6,000 | 2.97 |
56
+
57
+ ## Quick start
58
+
59
+ ```bash
60
+ # Clone this repo
61
+ git lfs install
62
+ git clone https://huggingface.co/endurasolution/RON-110M
63
+ cd RON-110M
64
+
65
+ # Install minimal deps
66
+ pip install torch numpy tokenizers rich
67
+
68
+ # Run inference
69
+ python code/ask.py \
70
+ --checkpoint summarizer.pt \
71
+ --tokenizer tokenizer.json \
72
+ --text "A man has been arrested in Manchester after a series of break-ins at local shops. Police said the suspect was found with stolen goods. He is due to appear in court on Monday." \
73
+ --max_new_tokens 80 \
74
+ --temperature 0.4 \
75
+ --top_p 0.9 \
76
+ --repetition_penalty 1.1 \
77
+ --no_repeat_ngram_size 3
78
+ ```
79
+
80
+ Expected output (paraphrased): a short news-style summary that preserves the key
81
+ facts from the input.
82
+
83
+ ## Continue training
84
+
85
+ To resume pretraining from `pretrain.pt`:
86
+
87
+ ```bash
88
+ python code/train.py \
89
+ --resume pretrain.pt \
90
+ --reset_step --reset_optimizer \
91
+ --data_dir data/wikitext103 \
92
+ --out_dir runs/wikitext-gpt \
93
+ --preset rtx3090_8h \
94
+ --batch_size 16 --grad_accum 8 \
95
+ --max_steps 12000 \
96
+ --learning_rate 2e-4 --min_lr 2e-5 \
97
+ --warmup_steps 200 \
98
+ --no_gradient_checkpointing \
99
+ --save_optimizer
100
+ ```
101
+
102
+ To fine-tune for a new task, prepare a JSONL file with `prompt` and `answer`
103
+ keys, then:
104
+
105
+ ```bash
106
+ python code/finetune_sft.py \
107
+ --base_checkpoint pretrain.pt \
108
+ --tokenizer tokenizer.json \
109
+ --sft_file your_data.jsonl \
110
+ --out_dir runs/my-finetune \
111
+ --max_steps 6000 \
112
+ --batch_size 8 --grad_accum 8 \
113
+ --learning_rate 5e-5 --min_lr 5e-6 \
114
+ --warmup_steps 200
115
+ ```
116
+
117
+ ## Limitations
118
+
119
+ - Small (110M parameters) - knowledge is limited, hallucinations possible on
120
+ out-of-domain inputs.
121
+ - Tokenizer is custom byte-level BPE - **must** be loaded with the included
122
+ `tokenizer.json`. Do not substitute a GPT-2 tokenizer.
123
+ - Not compatible with `transformers.AutoModel`. Use the included `code/`.
124
+ - SFT data was CNN/DailyMail news. The model is most reliable on news-style
125
+ English; expect weaker output on code, math, or conversational input.
126
+
127
+ ## License
128
+
129
+ MIT.
code/__init__.py ADDED
File without changes
code/ask.py ADDED
@@ -0,0 +1,204 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import argparse
4
+ from pathlib import Path
5
+
6
+ import torch
7
+ import torch.nn.functional as F
8
+
9
+ from searshorai.model import GPT, GPTConfig
10
+ from searshorai.tokenizer import TextTokenizer
11
+
12
+
13
+ # Must match the prompts used in make_xsum_sft.py / make_paragraph_sft.py.
14
+ # Using the first template is the canonical choice at inference time.
15
+ DEFAULT_PROMPT_TEMPLATE = (
16
+ "Read the article and write a one-sentence summary.\n\n"
17
+ "Article:\n{passage}\n\nSummary:\n"
18
+ )
19
+
20
+
21
+ def strip_compile_prefix(state_dict):
22
+ cleaned = {}
23
+ for key, value in state_dict.items():
24
+ if key.startswith("_orig_mod."):
25
+ key = key[len("_orig_mod.") :]
26
+ cleaned[key] = value
27
+ return cleaned
28
+
29
+
30
+ def parse_args() -> argparse.Namespace:
31
+ parser = argparse.ArgumentParser(description="Ask the paragraph-explainer model.")
32
+ parser.add_argument("--checkpoint", type=Path, required=True)
33
+ parser.add_argument("--tokenizer", type=Path, default=Path("data/wikitext103/tokenizer.json"))
34
+ parser.add_argument("--text", type=str, required=True, help="The passage to explain.")
35
+ parser.add_argument("--prompt_template", type=str, default=DEFAULT_PROMPT_TEMPLATE)
36
+ parser.add_argument("--max_new_tokens", type=int, default=120)
37
+ parser.add_argument("--temperature", type=float, default=0.7)
38
+ parser.add_argument("--top_k", type=int, default=40)
39
+ parser.add_argument("--top_p", type=float, default=0.9,
40
+ help="Nucleus sampling cutoff. Set 1.0 to disable.")
41
+ parser.add_argument("--repetition_penalty", type=float, default=1.3,
42
+ help="Penalty for re-emitting tokens already in the context. 1.0 = off.")
43
+ parser.add_argument("--no_repeat_ngram_size", type=int, default=3,
44
+ help="Block any n-gram of this size from appearing twice. 0 = off.")
45
+ parser.add_argument("--device", type=str, default="auto", choices=["auto", "cuda", "cpu"])
46
+ parser.add_argument("--seed", type=int, default=0)
47
+ return parser.parse_args()
48
+
49
+
50
+ def banned_tokens_from_ngrams(generated: list[int], n: int) -> set[int]:
51
+ """
52
+ For no-repeat-ngram blocking: given the tokens generated so far, return
53
+ the set of token ids that would close a previously-seen n-gram if emitted
54
+ next.
55
+ """
56
+ if n <= 0 or len(generated) < n - 1:
57
+ return set()
58
+ prefix = tuple(generated[-(n - 1):])
59
+ banned: set[int] = set()
60
+ for i in range(len(generated) - n + 1):
61
+ ngram = tuple(generated[i : i + n - 1])
62
+ if ngram == prefix:
63
+ banned.add(generated[i + n - 1])
64
+ return banned
65
+
66
+
67
+ def generate(
68
+ model: GPT,
69
+ prompt_ids: list[int],
70
+ max_new_tokens: int,
71
+ temperature: float,
72
+ top_k: int,
73
+ top_p: float,
74
+ repetition_penalty: float,
75
+ no_repeat_ngram_size: int,
76
+ eos_id: int | None,
77
+ device: str,
78
+ ) -> list[int]:
79
+ """
80
+ Custom sampling loop with repetition penalty, top-k, top-p (nucleus),
81
+ and no-repeat-ngram blocking. Returns the list of newly generated token
82
+ ids (does not include the prompt).
83
+ """
84
+ block_size = model.config.block_size
85
+ context = list(prompt_ids)
86
+ generated: list[int] = []
87
+
88
+ for _ in range(max_new_tokens):
89
+ idx_cond = context if len(context) <= block_size else context[-block_size:]
90
+ x = torch.tensor([idx_cond], dtype=torch.long, device=device)
91
+
92
+ with torch.no_grad():
93
+ logits, _ = model(x)
94
+ logits = logits[:, -1, :].squeeze(0).float()
95
+
96
+ if repetition_penalty != 1.0 and len(context) > 0:
97
+ seen = torch.tensor(list(set(context)), dtype=torch.long, device=device)
98
+ scores = logits[seen]
99
+ scores = torch.where(scores > 0, scores / repetition_penalty, scores * repetition_penalty)
100
+ logits[seen] = scores
101
+
102
+ if no_repeat_ngram_size > 0 and len(generated) >= no_repeat_ngram_size - 1:
103
+ banned = banned_tokens_from_ngrams(generated, no_repeat_ngram_size)
104
+ for tok_id in banned:
105
+ logits[tok_id] = -float("inf")
106
+
107
+ logits = logits / max(temperature, 1e-5)
108
+
109
+ if top_k is not None and top_k > 0:
110
+ k = min(top_k, logits.size(-1))
111
+ top_vals, _ = torch.topk(logits, k)
112
+ cutoff = top_vals[-1]
113
+ logits = torch.where(logits < cutoff, torch.full_like(logits, -float("inf")), logits)
114
+
115
+ if top_p < 1.0:
116
+ sorted_logits, sorted_idx = torch.sort(logits, descending=True)
117
+ probs_sorted = F.softmax(sorted_logits, dim=-1)
118
+ cumulative = torch.cumsum(probs_sorted, dim=-1)
119
+ mask = cumulative > top_p
120
+ mask[..., 1:] = mask[..., :-1].clone()
121
+ mask[..., 0] = False
122
+ sorted_logits = sorted_logits.masked_fill(mask, -float("inf"))
123
+ logits = torch.full_like(logits, -float("inf"))
124
+ logits.scatter_(0, sorted_idx, sorted_logits)
125
+
126
+ probs = F.softmax(logits, dim=-1)
127
+ if not torch.isfinite(probs).all() or probs.sum() <= 0:
128
+ next_tok = int(logits.argmax().item())
129
+ else:
130
+ next_tok = int(torch.multinomial(probs, num_samples=1).item())
131
+
132
+ if eos_id is not None and next_tok == eos_id:
133
+ break
134
+
135
+ context.append(next_tok)
136
+ generated.append(next_tok)
137
+
138
+ return generated
139
+
140
+
141
+ def main() -> None:
142
+ args = parse_args()
143
+
144
+ if args.seed:
145
+ torch.manual_seed(args.seed)
146
+ if torch.cuda.is_available():
147
+ torch.cuda.manual_seed_all(args.seed)
148
+
149
+ if args.device == "auto":
150
+ device = "cuda" if torch.cuda.is_available() else "cpu"
151
+ else:
152
+ device = args.device
153
+
154
+ tok = TextTokenizer(args.tokenizer)
155
+
156
+ ckpt = torch.load(args.checkpoint, map_location=device, weights_only=False)
157
+ config = GPTConfig(**ckpt["config"])
158
+ config.dropout = 0.0
159
+ config.gradient_checkpointing = False
160
+
161
+ model = GPT(config)
162
+ state_dict = strip_compile_prefix(ckpt["model"])
163
+ model.load_state_dict(state_dict, strict=True)
164
+ model.to(device)
165
+ model.eval()
166
+
167
+ if tok.vocab_size != model.config.vocab_size:
168
+ raise RuntimeError(
169
+ f"Tokenizer vocab_size {tok.vocab_size} != model vocab_size {model.config.vocab_size}. "
170
+ "Use the same tokenizer.json that was used for pretrain/SFT."
171
+ )
172
+
173
+ prompt = args.prompt_template.format(passage=args.text.strip())
174
+ prompt_ids = tok.encode(prompt, add_bos=True, add_eos=False)
175
+
176
+ max_prompt_len = model.config.block_size - args.max_new_tokens - 1
177
+ if max_prompt_len < 16:
178
+ raise RuntimeError(
179
+ f"max_new_tokens={args.max_new_tokens} is too large for block_size={model.config.block_size}."
180
+ )
181
+ if len(prompt_ids) > max_prompt_len:
182
+ bos = [prompt_ids[0]] if prompt_ids and prompt_ids[0] == tok.bos_id else []
183
+ tail = prompt_ids[-(max_prompt_len - len(bos)) :]
184
+ prompt_ids = bos + tail
185
+
186
+ new_ids = generate(
187
+ model=model,
188
+ prompt_ids=prompt_ids,
189
+ max_new_tokens=args.max_new_tokens,
190
+ temperature=args.temperature,
191
+ top_k=args.top_k,
192
+ top_p=args.top_p,
193
+ repetition_penalty=args.repetition_penalty,
194
+ no_repeat_ngram_size=args.no_repeat_ngram_size,
195
+ eos_id=tok.eos_id,
196
+ device=device,
197
+ )
198
+
199
+ answer = tok.decode(new_ids, skip_special_tokens=True).strip()
200
+ print(answer)
201
+
202
+
203
+ if __name__ == "__main__":
204
+ main()
code/finetune_sft.py ADDED
@@ -0,0 +1,484 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import argparse
4
+ import json
5
+ import math
6
+ import random
7
+ import time
8
+ import unicodedata
9
+ from dataclasses import dataclass
10
+ from pathlib import Path
11
+ from typing import Any
12
+
13
+ import torch
14
+ from rich.console import Console
15
+ from torch.nn.utils.rnn import pad_sequence
16
+
17
+ from searshorai.model import GPT, GPTConfig
18
+ from searshorai.tokenizer import TextTokenizer
19
+
20
+
21
+ console = Console()
22
+
23
+
24
+ @dataclass
25
+ class Example:
26
+ input_ids: list[int]
27
+ labels: list[int]
28
+
29
+
30
+ def parse_args() -> argparse.Namespace:
31
+ parser = argparse.ArgumentParser(description="Stable supervised fine-tune for paragraph explanation.")
32
+
33
+ parser.add_argument("--base_checkpoint", type=Path, default=Path("runs/wikitext-gpt/best.pt"))
34
+ parser.add_argument("--tokenizer", type=Path, default=Path("data/wikitext103/tokenizer.json"))
35
+ parser.add_argument("--sft_file", type=Path, default=Path("data/wikitext103/paragraph_sft.jsonl"))
36
+ parser.add_argument("--out_dir", type=Path, default=Path("runs/paragraph-explainer"))
37
+
38
+ parser.add_argument("--max_steps", type=int, default=8000)
39
+ parser.add_argument("--batch_size", type=int, default=8)
40
+ parser.add_argument("--grad_accum", type=int, default=8)
41
+ parser.add_argument("--learning_rate", type=float, default=2e-5)
42
+ parser.add_argument("--min_lr", type=float, default=2e-6)
43
+ parser.add_argument("--warmup_steps", type=int, default=300)
44
+ parser.add_argument("--weight_decay", type=float, default=0.01)
45
+ parser.add_argument("--grad_clip", type=float, default=1.0)
46
+
47
+ parser.add_argument("--max_answer_tokens", type=int, default=220)
48
+ parser.add_argument("--min_answer_tokens", type=int, default=8)
49
+ parser.add_argument("--val_ratio", type=float, default=0.02)
50
+
51
+ parser.add_argument("--eval_interval", type=int, default=250)
52
+ parser.add_argument("--eval_batches", type=int, default=40)
53
+ parser.add_argument("--save_interval", type=int, default=500)
54
+ parser.add_argument("--log_interval", type=int, default=20)
55
+ parser.add_argument("--seed", type=int, default=1337)
56
+ parser.add_argument("--compile", action="store_true")
57
+ parser.add_argument("--resume", type=Path, default=None)
58
+
59
+ return parser.parse_args()
60
+
61
+
62
+ def clean_text(text: Any) -> str:
63
+ if text is None:
64
+ return ""
65
+ text = str(text)
66
+ text = text.replace("\ufffd", " ")
67
+ text = unicodedata.normalize("NFKC", text)
68
+ text = "".join(ch if (ch in ("\n", "\t") or ord(ch) >= 32) else " " for ch in text)
69
+ text = "\n".join(" ".join(line.split()) for line in text.splitlines())
70
+ return text.strip()
71
+
72
+
73
+ def get_special_id(tok: TextTokenizer, name: str) -> int | None:
74
+ value = getattr(tok, name, None)
75
+ return int(value) if isinstance(value, int) else None
76
+
77
+
78
+ def ensure_eos(ids: list[int], eos_id: int | None) -> list[int]:
79
+ if eos_id is None:
80
+ return ids
81
+ if not ids or ids[-1] != eos_id:
82
+ return ids + [eos_id]
83
+ return ids
84
+
85
+
86
+ def get_lr(step: int, args: argparse.Namespace) -> float:
87
+ if step < args.warmup_steps:
88
+ return args.learning_rate * (step + 1) / max(1, args.warmup_steps)
89
+ ratio = (step - args.warmup_steps) / max(1, args.max_steps - args.warmup_steps)
90
+ coeff = 0.5 * (1.0 + math.cos(math.pi * min(1.0, max(0.0, ratio))))
91
+ return args.min_lr + coeff * (args.learning_rate - args.min_lr)
92
+
93
+
94
+ def read_prompt_answer(row: dict[str, Any]) -> tuple[str, str]:
95
+ """
96
+ Supports these JSONL styles:
97
+ {"prompt": "...", "answer": "..."}
98
+ {"input": "...", "output": "..."}
99
+ {"paragraph": "...", "explanation": "..."}
100
+ {"text": "...", "answer": "..."}
101
+ """
102
+ if "prompt" in row:
103
+ prompt = row.get("prompt", "")
104
+ elif "paragraph" in row:
105
+ prompt = f"Explain this paragraph in simple words:\n\n{row.get('paragraph', '')}\n\nExplanation:\n"
106
+ elif "text" in row:
107
+ prompt = f"Explain this paragraph in simple words:\n\n{row.get('text', '')}\n\nExplanation:\n"
108
+ else:
109
+ prompt = row.get("input", "")
110
+
111
+ answer = (
112
+ row.get("answer")
113
+ if row.get("answer") is not None
114
+ else row.get("output")
115
+ if row.get("output") is not None
116
+ else row.get("explanation", "")
117
+ )
118
+
119
+ return clean_text(prompt), clean_text(answer)
120
+
121
+
122
+ def load_examples(path: Path, tok: TextTokenizer, block_size: int, args: argparse.Namespace) -> list[Example]:
123
+ if not path.exists():
124
+ raise FileNotFoundError(f"SFT file not found: {path}")
125
+
126
+ eos_id = get_special_id(tok, "eos_id")
127
+ examples: list[Example] = []
128
+
129
+ skipped_empty = 0
130
+ skipped_too_short = 0
131
+ truncated_answers = 0
132
+ bad_json = 0
133
+
134
+ with path.open("r", encoding="utf-8", errors="replace") as f:
135
+ for line in f:
136
+ line = line.strip()
137
+ if not line:
138
+ continue
139
+ try:
140
+ row = json.loads(line)
141
+ except json.JSONDecodeError:
142
+ bad_json += 1
143
+ continue
144
+
145
+ prompt, answer = read_prompt_answer(row)
146
+ if not prompt or not answer:
147
+ skipped_empty += 1
148
+ continue
149
+
150
+ prompt_ids = tok.encode(prompt, add_bos=True, add_eos=False)
151
+
152
+ # Encode answer without EOS, then add EOS after any truncation.
153
+ answer_ids = tok.encode(answer, add_bos=False, add_eos=False)
154
+ if len(answer_ids) < args.min_answer_tokens:
155
+ skipped_too_short += 1
156
+ continue
157
+ if len(answer_ids) > args.max_answer_tokens:
158
+ answer_ids = answer_ids[: args.max_answer_tokens]
159
+ truncated_answers += 1
160
+ answer_ids = ensure_eos(answer_ids, eos_id)
161
+
162
+ # full_ids must fit in block_size + 1 (we shift to get input/target).
163
+ room_for_prompt = (block_size + 1) - len(answer_ids)
164
+ if room_for_prompt < 16:
165
+ # Answer is huge - cut it further but keep EOS at the end.
166
+ keep = max(16, block_size - 32)
167
+ answer_ids = answer_ids[: keep - 1]
168
+ answer_ids = ensure_eos(answer_ids, eos_id)
169
+ room_for_prompt = (block_size + 1) - len(answer_ids)
170
+
171
+ # Keep the tail of the prompt if it is too long.
172
+ if len(prompt_ids) > room_for_prompt:
173
+ # Preserve BOS at position 0 by keeping BOS + tail of body.
174
+ bos = [prompt_ids[0]] if prompt_ids and prompt_ids[0] == tok.bos_id else []
175
+ tail = prompt_ids[-(room_for_prompt - len(bos)) :] if room_for_prompt - len(bos) > 0 else []
176
+ prompt_ids = bos + tail
177
+
178
+ full_ids = prompt_ids + answer_ids
179
+
180
+ if len(full_ids) > block_size + 1:
181
+ # Final hard cap. If we have to cut, keep EOS as the last target token.
182
+ full_ids = full_ids[: block_size + 1]
183
+ if eos_id is not None and full_ids[-1] != eos_id:
184
+ full_ids[-1] = eos_id
185
+
186
+ if len(full_ids) < 16:
187
+ skipped_too_short += 1
188
+ continue
189
+
190
+ input_ids = full_ids[:-1]
191
+ next_ids = full_ids[1:]
192
+
193
+ # Loss only on answer tokens (including the final EOS target).
194
+ prompt_len = len(prompt_ids)
195
+ labels = [
196
+ token_id if (position + 1) >= prompt_len else -100
197
+ for position, token_id in enumerate(next_ids)
198
+ ]
199
+
200
+ if any(x != -100 for x in labels):
201
+ examples.append(Example(input_ids=input_ids, labels=labels))
202
+
203
+ console.print(
204
+ f"Loaded {len(examples):,} examples | "
205
+ f"empty={skipped_empty:,}, short={skipped_too_short:,}, "
206
+ f"truncated_answers={truncated_answers:,}, bad_json={bad_json:,}"
207
+ )
208
+ if len(examples) < 10:
209
+ raise RuntimeError("Too few valid SFT examples. Check your JSONL keys and tokenizer.")
210
+ return examples
211
+
212
+
213
+ def make_batch(
214
+ examples: list[Example],
215
+ batch_size: int,
216
+ pad_id: int,
217
+ device: str,
218
+ block_size: int,
219
+ ):
220
+ if len(examples) >= batch_size:
221
+ batch = random.sample(examples, batch_size)
222
+ else:
223
+ batch = random.choices(examples, k=batch_size)
224
+
225
+ xs = []
226
+ ys = []
227
+ for ex in batch:
228
+ ix = ex.input_ids[:block_size]
229
+ ly = ex.labels[:block_size]
230
+ xs.append(torch.tensor(ix, dtype=torch.long))
231
+ ys.append(torch.tensor(ly, dtype=torch.long))
232
+
233
+ x = pad_sequence(xs, batch_first=True, padding_value=pad_id)
234
+ y = pad_sequence(ys, batch_first=True, padding_value=-100)
235
+
236
+ if device == "cuda":
237
+ x = x.pin_memory().to(device, non_blocking=True)
238
+ y = y.pin_memory().to(device, non_blocking=True)
239
+ else:
240
+ x = x.to(device)
241
+ y = y.to(device)
242
+ return x, y
243
+
244
+
245
+ @torch.no_grad()
246
+ def evaluate(model, examples, args, pad_id, device, autocast_ctx, block_size) -> float:
247
+ model.eval()
248
+ losses: list[float] = []
249
+ for _ in range(args.eval_batches):
250
+ x, y = make_batch(examples, args.batch_size, pad_id, device, block_size)
251
+ with autocast_ctx:
252
+ _, loss = model(x, y)
253
+ if torch.isfinite(loss):
254
+ losses.append(float(loss.item()))
255
+ model.train()
256
+ return sum(losses) / max(1, len(losses))
257
+
258
+
259
+ def strip_compile_prefix(state_dict: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
260
+ cleaned = {}
261
+ for key, value in state_dict.items():
262
+ if key.startswith("_orig_mod."):
263
+ key = key[len("_orig_mod.") :]
264
+ cleaned[key] = value
265
+ return cleaned
266
+
267
+
268
+ def save_checkpoint(
269
+ path: Path,
270
+ model,
271
+ optimizer,
272
+ args: argparse.Namespace,
273
+ step: int,
274
+ best_val_loss: float,
275
+ meta: dict[str, Any],
276
+ ) -> None:
277
+ raw_model = model._orig_mod if hasattr(model, "_orig_mod") else model
278
+ meta = dict(meta or {})
279
+ meta.update(
280
+ {
281
+ "task": "paragraph_explainer_sft",
282
+ "tokenizer": str(args.tokenizer),
283
+ "sft_file": str(args.sft_file),
284
+ "important": "Prompt tokens are masked; answer is EOS-safe truncated.",
285
+ }
286
+ )
287
+ torch.save(
288
+ {
289
+ "model": raw_model.state_dict(),
290
+ "optimizer": optimizer.state_dict(),
291
+ "args": {k: (str(v) if isinstance(v, Path) else v) for k, v in vars(args).items()},
292
+ "config": vars(raw_model.config),
293
+ "step": step,
294
+ "best_val_loss": best_val_loss,
295
+ "meta": meta,
296
+ },
297
+ path,
298
+ )
299
+
300
+
301
+ def main() -> None:
302
+ args = parse_args()
303
+ args.out_dir.mkdir(parents=True, exist_ok=True)
304
+
305
+ random.seed(args.seed)
306
+ torch.manual_seed(args.seed)
307
+ if torch.cuda.is_available():
308
+ torch.cuda.manual_seed_all(args.seed)
309
+ torch.backends.cuda.matmul.allow_tf32 = True
310
+ torch.backends.cudnn.allow_tf32 = True
311
+
312
+ device = "cuda" if torch.cuda.is_available() else "cpu"
313
+ device_type = "cuda" if device == "cuda" else "cpu"
314
+
315
+ if device == "cuda" and torch.cuda.is_bf16_supported():
316
+ amp_dtype = torch.bfloat16
317
+ console.print("AMP dtype: bfloat16")
318
+ elif device == "cuda":
319
+ amp_dtype = torch.float16
320
+ console.print("AMP dtype: float16")
321
+ else:
322
+ amp_dtype = torch.float32
323
+ console.print("AMP disabled on CPU")
324
+
325
+ autocast_ctx = torch.amp.autocast(
326
+ device_type=device_type,
327
+ dtype=amp_dtype,
328
+ enabled=(device == "cuda"),
329
+ )
330
+
331
+ tok = TextTokenizer(args.tokenizer)
332
+ pad_id = int(getattr(tok, "pad_id", 0))
333
+
334
+ if args.resume is not None:
335
+ ckpt_path = args.resume
336
+ console.print(f"Resuming SFT checkpoint: {ckpt_path}")
337
+ else:
338
+ ckpt_path = args.base_checkpoint
339
+ console.print(f"Starting from base checkpoint: {ckpt_path}")
340
+
341
+ ckpt = torch.load(ckpt_path, map_location=device, weights_only=False)
342
+
343
+ config = GPTConfig(**ckpt["config"])
344
+ # Force disable dropout for stable SFT (already 0.0 in pretrain).
345
+ config.dropout = 0.0
346
+ model = GPT(config)
347
+ state_dict = strip_compile_prefix(ckpt["model"])
348
+ model.load_state_dict(state_dict, strict=True)
349
+ model.to(device)
350
+
351
+ # Sanity check: tokenizer and model vocab must match.
352
+ if tok.vocab_size != model.config.vocab_size:
353
+ raise RuntimeError(
354
+ f"Tokenizer vocab_size {tok.vocab_size} != model vocab_size {model.config.vocab_size}. "
355
+ "This is the most common cause of garbled output. Use the same tokenizer that produced the pretrain data."
356
+ )
357
+
358
+ optimizer = model.configure_optimizers(
359
+ args.weight_decay,
360
+ args.learning_rate,
361
+ (0.9, 0.95),
362
+ device_type,
363
+ )
364
+
365
+ start_step = 0
366
+ best_val_loss = float("inf")
367
+ if args.resume is not None and "optimizer" in ckpt:
368
+ try:
369
+ optimizer.load_state_dict(ckpt["optimizer"])
370
+ start_step = int(ckpt.get("step", 0)) + 1
371
+ best_val_loss = float(ckpt.get("best_val_loss", float("inf")))
372
+ console.print(f"Resume from step {start_step}, previous best val {best_val_loss:.4f}")
373
+ except Exception as exc:
374
+ console.print(f"[yellow]Could not load optimizer state, starting fresh: {exc}[/yellow]")
375
+
376
+ try:
377
+ scaler = torch.amp.GradScaler("cuda", enabled=(device == "cuda" and amp_dtype == torch.float16))
378
+ except TypeError:
379
+ scaler = torch.cuda.amp.GradScaler(enabled=(device == "cuda" and amp_dtype == torch.float16))
380
+
381
+ examples = load_examples(args.sft_file, tok, model.config.block_size, args)
382
+ random.shuffle(examples)
383
+
384
+ val_size = max(1, int(len(examples) * args.val_ratio))
385
+ val_examples = examples[:val_size]
386
+ train_examples = examples[val_size:]
387
+ if not train_examples:
388
+ raise RuntimeError("No training examples after split.")
389
+
390
+ console.print(
391
+ f"Train={len(train_examples):,} | Val={len(val_examples):,} | "
392
+ f"Block size={model.config.block_size} | Device={device}"
393
+ )
394
+
395
+ if args.compile:
396
+ console.print("Compiling model with torch.compile...")
397
+ model = torch.compile(model)
398
+
399
+ model.train()
400
+ block_size = model.config.block_size if not hasattr(model, "_orig_mod") else model._orig_mod.config.block_size
401
+ last_time = time.time()
402
+ last_step = start_step
403
+
404
+ for step in range(start_step, args.max_steps + 1):
405
+ lr = get_lr(step, args)
406
+ for group in optimizer.param_groups:
407
+ group["lr"] = lr
408
+
409
+ optimizer.zero_grad(set_to_none=True)
410
+ loss_accum = 0.0
411
+ ok_micro_steps = 0
412
+
413
+ for _ in range(args.grad_accum):
414
+ x, y = make_batch(train_examples, args.batch_size, pad_id, device, block_size)
415
+ with autocast_ctx:
416
+ _, loss = model(x, y)
417
+ loss = loss / args.grad_accum
418
+ if not torch.isfinite(loss):
419
+ console.print(f"[yellow]Skipping non-finite loss at step {step}[/yellow]")
420
+ continue
421
+ scaler.scale(loss).backward()
422
+ loss_accum += float(loss.item())
423
+ ok_micro_steps += 1
424
+
425
+ if ok_micro_steps == 0:
426
+ scaler.update()
427
+ continue
428
+
429
+ scaler.unscale_(optimizer)
430
+ torch.nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip)
431
+ scaler.step(optimizer)
432
+ scaler.update()
433
+
434
+ if step % args.log_interval == 0:
435
+ now = time.time()
436
+ steps_done = max(1, step - last_step)
437
+ console.print(
438
+ f"step {step:6d} | loss {loss_accum:.4f} | "
439
+ f"lr {lr:.2e} | {(now - last_time) / steps_done:.2f}s/step"
440
+ )
441
+ last_time = now
442
+ last_step = step
443
+
444
+ if step > 0 and (step % args.eval_interval == 0 or step == args.max_steps):
445
+ val_loss = evaluate(model, val_examples, args, pad_id, device, autocast_ctx, block_size)
446
+ console.print(f"eval step {step}: val {val_loss:.4f}")
447
+ if val_loss < best_val_loss:
448
+ best_val_loss = val_loss
449
+ save_checkpoint(
450
+ args.out_dir / "best.pt",
451
+ model,
452
+ optimizer,
453
+ args,
454
+ step,
455
+ best_val_loss,
456
+ ckpt.get("meta", {}),
457
+ )
458
+ console.print(f"[green]saved best checkpoint: {best_val_loss:.4f}[/green]")
459
+
460
+ if step > 0 and step % args.save_interval == 0:
461
+ save_checkpoint(
462
+ args.out_dir / "latest.pt",
463
+ model,
464
+ optimizer,
465
+ args,
466
+ step,
467
+ best_val_loss,
468
+ ckpt.get("meta", {}),
469
+ )
470
+
471
+ save_checkpoint(
472
+ args.out_dir / "latest.pt",
473
+ model,
474
+ optimizer,
475
+ args,
476
+ args.max_steps,
477
+ best_val_loss,
478
+ ckpt.get("meta", {}),
479
+ )
480
+ console.print("Fine-tuning complete.")
481
+
482
+
483
+ if __name__ == "__main__":
484
+ main()
code/make_cnndm_sft.py ADDED
@@ -0,0 +1,130 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import argparse
4
+ import json
5
+ import random
6
+ import re
7
+ import unicodedata
8
+ from pathlib import Path
9
+
10
+ from datasets import load_dataset
11
+ from tqdm import tqdm
12
+
13
+
14
+ _WS = re.compile(r"\s+")
15
+ _BAD_CHARS = re.compile(r"[\u0000-\u001f]")
16
+ _REFS = re.compile(r"\[\s*\d+\s*\]")
17
+ # CNN/DailyMail articles often start with "(CNN) -- " or "By . SOMEBODY . PUBLISHED:"
18
+ _CNN_PREFIX = re.compile(r"^\s*\(CNN\)\s*--\s*", re.IGNORECASE)
19
+ _BYLINE = re.compile(r"^\s*By\s+\.\s+.*?PUBLISHED:.*?\s*\.\s*", re.IGNORECASE | re.DOTALL)
20
+
21
+
22
+ PROMPT_TEMPLATES = [
23
+ "Read the article and write a short summary.\n\nArticle:\n{passage}\n\nSummary:\n",
24
+ "Summarize the following article in a few sentences.\n\nArticle:\n{passage}\n\nShort summary:\n",
25
+ "Below is a news article. Give a concise summary using key facts from the article.\n\nArticle:\n{passage}\n\nSummary:\n",
26
+ "Provide a short summary of the article below.\n\nArticle:\n{passage}\n\nSummary:\n",
27
+ ]
28
+
29
+
30
+ def normalize(text: str) -> str:
31
+ if text is None:
32
+ return ""
33
+ text = str(text)
34
+ text = text.replace("\ufffd", " ")
35
+ text = unicodedata.normalize("NFKC", text)
36
+ text = _BAD_CHARS.sub(" ", text)
37
+ text = _REFS.sub("", text)
38
+ text = _CNN_PREFIX.sub("", text)
39
+ text = _BYLINE.sub("", text)
40
+ text = _WS.sub(" ", text).strip()
41
+ return text
42
+
43
+
44
+ def join_highlights(highlights: str) -> str:
45
+ """
46
+ CNN/DailyMail highlights come as several short lines joined by newlines.
47
+ We join them into a single multi-sentence string with periods.
48
+ """
49
+ if not highlights:
50
+ return ""
51
+ pieces = [p.strip() for p in highlights.split("\n") if p.strip()]
52
+ # Make sure each piece ends with terminal punctuation.
53
+ fixed = []
54
+ for p in pieces:
55
+ if p[-1] not in ".!?":
56
+ p = p + "."
57
+ fixed.append(p)
58
+ return " ".join(fixed)
59
+
60
+
61
+ def is_good_pair(article: str, summary: str, min_article_chars: int, max_article_chars: int,
62
+ min_summary_chars: int, max_summary_chars: int) -> bool:
63
+ if not article or not summary:
64
+ return False
65
+ if not (min_article_chars <= len(article) <= max_article_chars):
66
+ return False
67
+ if not (min_summary_chars <= len(summary) <= max_summary_chars):
68
+ return False
69
+ # Reject if the summary is basically the whole article (rare here but safe).
70
+ if len(summary) >= 0.8 * len(article):
71
+ return False
72
+ # Must be mostly letters.
73
+ letters = sum(ch.isalpha() for ch in article)
74
+ if letters / max(1, len(article)) < 0.6:
75
+ return False
76
+ return True
77
+
78
+
79
+ def main() -> None:
80
+ parser = argparse.ArgumentParser(
81
+ description="Build an SFT set from CNN/DailyMail (near-extractive summaries)."
82
+ )
83
+ parser.add_argument("--out_file", type=Path, default=Path("data/wikitext103/paragraph_sft.jsonl"))
84
+ parser.add_argument("--dataset", type=str, default="abisee/cnn_dailymail")
85
+ parser.add_argument("--config", type=str, default="3.0.0")
86
+ parser.add_argument("--max_examples", type=int, default=100000)
87
+ parser.add_argument("--min_article_chars", type=int, default=400)
88
+ parser.add_argument("--max_article_chars", type=int, default=2200)
89
+ parser.add_argument("--min_summary_chars", type=int, default=80)
90
+ parser.add_argument("--max_summary_chars", type=int, default=400)
91
+ parser.add_argument("--seed", type=int, default=1337)
92
+ args = parser.parse_args()
93
+
94
+ args.out_file.parent.mkdir(parents=True, exist_ok=True)
95
+ rng = random.Random(args.seed)
96
+
97
+ print(f"Loading {args.dataset} ({args.config})...")
98
+ dataset = load_dataset(args.dataset, args.config, split="train")
99
+
100
+ count = 0
101
+ skipped = 0
102
+
103
+ with args.out_file.open("w", encoding="utf-8") as f:
104
+ for row in tqdm(dataset, desc="building SFT"):
105
+ article = normalize(row.get("article", ""))
106
+ summary = join_highlights(normalize(row.get("highlights", "")))
107
+
108
+ if not is_good_pair(
109
+ article, summary,
110
+ args.min_article_chars, args.max_article_chars,
111
+ args.min_summary_chars, args.max_summary_chars,
112
+ ):
113
+ skipped += 1
114
+ continue
115
+
116
+ if len(article) > args.max_article_chars:
117
+ article = article[: args.max_article_chars].rsplit(" ", 1)[0]
118
+
119
+ template = rng.choice(PROMPT_TEMPLATES)
120
+ prompt = template.format(passage=article)
121
+ f.write(json.dumps({"prompt": prompt, "answer": summary}, ensure_ascii=False) + "\n")
122
+ count += 1
123
+ if count >= args.max_examples:
124
+ break
125
+
126
+ print(f"Wrote {count:,} examples to {args.out_file} (skipped={skipped:,})")
127
+
128
+
129
+ if __name__ == "__main__":
130
+ main()
code/model.py ADDED
@@ -0,0 +1,189 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import inspect
4
+ import math
5
+ from dataclasses import dataclass
6
+
7
+ import torch
8
+ import torch.nn as nn
9
+ import torch.utils.checkpoint
10
+ from torch.nn import functional as F
11
+
12
+
13
+ @dataclass
14
+ class GPTConfig:
15
+ vocab_size: int
16
+ block_size: int = 512
17
+ n_layer: int = 12
18
+ n_head: int = 12
19
+ n_embd: int = 768
20
+ dropout: float = 0.0
21
+ bias: bool = False
22
+ gradient_checkpointing: bool = False
23
+
24
+
25
+ class LayerNorm(nn.Module):
26
+ def __init__(self, ndim: int, bias: bool):
27
+ super().__init__()
28
+ self.weight = nn.Parameter(torch.ones(ndim))
29
+ self.bias = nn.Parameter(torch.zeros(ndim)) if bias else None
30
+
31
+ def forward(self, input: torch.Tensor) -> torch.Tensor:
32
+ return F.layer_norm(input, self.weight.shape, self.weight, self.bias, 1e-5)
33
+
34
+
35
+ class CausalSelfAttention(nn.Module):
36
+ def __init__(self, config: GPTConfig):
37
+ super().__init__()
38
+ assert config.n_embd % config.n_head == 0
39
+ self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd, bias=config.bias)
40
+ self.c_proj = nn.Linear(config.n_embd, config.n_embd, bias=config.bias)
41
+ self.attn_dropout = config.dropout
42
+ self.resid_dropout = nn.Dropout(config.dropout)
43
+ self.n_head = config.n_head
44
+ self.n_embd = config.n_embd
45
+ self.dropout = config.dropout
46
+
47
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
48
+ batch, seq_len, channels = x.size()
49
+ q, k, v = self.c_attn(x).split(self.n_embd, dim=2)
50
+ head_dim = channels // self.n_head
51
+ q = q.view(batch, seq_len, self.n_head, head_dim).transpose(1, 2)
52
+ k = k.view(batch, seq_len, self.n_head, head_dim).transpose(1, 2)
53
+ v = v.view(batch, seq_len, self.n_head, head_dim).transpose(1, 2)
54
+
55
+ y = F.scaled_dot_product_attention(
56
+ q,
57
+ k,
58
+ v,
59
+ attn_mask=None,
60
+ dropout_p=self.attn_dropout if self.training else 0.0,
61
+ is_causal=True,
62
+ )
63
+ y = y.transpose(1, 2).contiguous().view(batch, seq_len, channels)
64
+ return self.resid_dropout(self.c_proj(y))
65
+
66
+
67
+ class MLP(nn.Module):
68
+ def __init__(self, config: GPTConfig):
69
+ super().__init__()
70
+ self.c_fc = nn.Linear(config.n_embd, 4 * config.n_embd, bias=config.bias)
71
+ self.gelu = nn.GELU(approximate="tanh")
72
+ self.c_proj = nn.Linear(4 * config.n_embd, config.n_embd, bias=config.bias)
73
+ self.dropout = nn.Dropout(config.dropout)
74
+
75
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
76
+ return self.dropout(self.c_proj(self.gelu(self.c_fc(x))))
77
+
78
+
79
+ class Block(nn.Module):
80
+ def __init__(self, config: GPTConfig):
81
+ super().__init__()
82
+ self.ln_1 = LayerNorm(config.n_embd, bias=config.bias)
83
+ self.attn = CausalSelfAttention(config)
84
+ self.ln_2 = LayerNorm(config.n_embd, bias=config.bias)
85
+ self.mlp = MLP(config)
86
+
87
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
88
+ x = x + self.attn(self.ln_1(x))
89
+ x = x + self.mlp(self.ln_2(x))
90
+ return x
91
+
92
+
93
+ class GPT(nn.Module):
94
+ def __init__(self, config: GPTConfig):
95
+ super().__init__()
96
+ self.config = config
97
+ self.transformer = nn.ModuleDict(
98
+ {
99
+ "wte": nn.Embedding(config.vocab_size, config.n_embd),
100
+ "wpe": nn.Embedding(config.block_size, config.n_embd),
101
+ "drop": nn.Dropout(config.dropout),
102
+ "h": nn.ModuleList([Block(config) for _ in range(config.n_layer)]),
103
+ "ln_f": LayerNorm(config.n_embd, bias=config.bias),
104
+ }
105
+ )
106
+ self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
107
+ self.transformer.wte.weight = self.lm_head.weight
108
+ self.apply(self._init_weights)
109
+ for name, param in self.named_parameters():
110
+ if name.endswith("c_proj.weight"):
111
+ torch.nn.init.normal_(param, mean=0.0, std=0.02 / math.sqrt(2 * config.n_layer))
112
+
113
+ def _init_weights(self, module: nn.Module) -> None:
114
+ if isinstance(module, nn.Linear):
115
+ torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
116
+ if module.bias is not None:
117
+ torch.nn.init.zeros_(module.bias)
118
+ elif isinstance(module, nn.Embedding):
119
+ torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
120
+
121
+ def forward(
122
+ self, idx: torch.Tensor, targets: torch.Tensor | None = None
123
+ ) -> tuple[torch.Tensor, torch.Tensor | None]:
124
+ batch, seq_len = idx.size()
125
+ if seq_len > self.config.block_size:
126
+ raise ValueError(f"Sequence length {seq_len} exceeds block size {self.config.block_size}")
127
+
128
+ pos = torch.arange(0, seq_len, dtype=torch.long, device=idx.device)
129
+ x = self.transformer.drop(self.transformer.wte(idx) + self.transformer.wpe(pos))
130
+ for block in self.transformer.h:
131
+ if self.config.gradient_checkpointing and self.training:
132
+ x = torch.utils.checkpoint.checkpoint(block, x, use_reentrant=False)
133
+ else:
134
+ x = block(x)
135
+ x = self.transformer.ln_f(x)
136
+
137
+ if targets is None:
138
+ logits = self.lm_head(x[:, [-1], :])
139
+ loss = None
140
+ else:
141
+ logits = self.lm_head(x)
142
+ loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-100)
143
+ return logits, loss
144
+
145
+ @torch.no_grad()
146
+ def generate(
147
+ self,
148
+ idx: torch.Tensor,
149
+ max_new_tokens: int,
150
+ temperature: float = 0.8,
151
+ top_k: int | None = 50,
152
+ eos_id: int | None = None,
153
+ ) -> torch.Tensor:
154
+ for _ in range(max_new_tokens):
155
+ idx_cond = idx if idx.size(1) <= self.config.block_size else idx[:, -self.config.block_size :]
156
+ logits, _ = self(idx_cond)
157
+ logits = logits[:, -1, :] / max(temperature, 1e-5)
158
+ if top_k is not None:
159
+ v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
160
+ logits[logits < v[:, [-1]]] = -float("Inf")
161
+ probs = F.softmax(logits, dim=-1)
162
+ idx_next = torch.multinomial(probs, num_samples=1)
163
+ idx = torch.cat((idx, idx_next), dim=1)
164
+ if eos_id is not None and idx_next.item() == eos_id:
165
+ break
166
+ return idx
167
+
168
+ def crop_block_size(self, block_size: int) -> None:
169
+ assert block_size <= self.config.block_size
170
+ self.config.block_size = block_size
171
+ self.transformer.wpe.weight = nn.Parameter(self.transformer.wpe.weight[:block_size])
172
+
173
+ def configure_optimizers(
174
+ self, weight_decay: float, learning_rate: float, betas: tuple[float, float], device_type: str
175
+ ) -> torch.optim.Optimizer:
176
+ param_dict = {pn: p for pn, p in self.named_parameters() if p.requires_grad}
177
+ decay_params = [p for _, p in param_dict.items() if p.dim() >= 2]
178
+ nodecay_params = [p for _, p in param_dict.items() if p.dim() < 2]
179
+ optim_groups = [
180
+ {"params": decay_params, "weight_decay": weight_decay},
181
+ {"params": nodecay_params, "weight_decay": 0.0},
182
+ ]
183
+ fused_available = "fused" in inspect.signature(torch.optim.AdamW).parameters
184
+ use_fused = fused_available and device_type == "cuda"
185
+ extra_args = {"fused": True} if use_fused else {}
186
+ return torch.optim.AdamW(optim_groups, lr=learning_rate, betas=betas, **extra_args)
187
+
188
+ def num_parameters(self) -> int:
189
+ return sum(p.numel() for p in self.parameters())
code/prepare_wikitext.py ADDED
@@ -0,0 +1,151 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import argparse
4
+ import json
5
+ from pathlib import Path
6
+
7
+ import numpy as np
8
+ from datasets import load_dataset
9
+ from tokenizers import ByteLevelBPETokenizer, Tokenizer
10
+ from tokenizers import decoders as _decoders
11
+ from tqdm import tqdm
12
+
13
+
14
+ SPECIAL_TOKENS = ["<pad>", "<bos>", "<eos>", "<unk>"]
15
+
16
+
17
+ def clean_lines(dataset):
18
+ for row in dataset:
19
+ text = row["text"].strip()
20
+ if text:
21
+ yield text
22
+
23
+
24
+ class _TokenizerAdapter:
25
+ """
26
+ Small adapter so the rest of the script can call .encode(text).ids and
27
+ .get_vocab() / .get_vocab_size() regardless of whether the tokenizer was
28
+ freshly trained (ByteLevelBPETokenizer) or reloaded from JSON (Tokenizer).
29
+ """
30
+
31
+ def __init__(self, tokenizer):
32
+ self._t = tokenizer
33
+
34
+ def encode(self, text: str):
35
+ return self._t.encode(text)
36
+
37
+ def get_vocab(self):
38
+ return self._t.get_vocab()
39
+
40
+ def get_vocab_size(self):
41
+ return self._t.get_vocab_size()
42
+
43
+
44
+ def load_or_train_tokenizer(tokenizer_path: Path, train_dataset, vocab_size: int, min_frequency: int):
45
+ if tokenizer_path.exists():
46
+ print(f"Using existing tokenizer at {tokenizer_path}")
47
+ # Reload via the generic Tokenizer class. ByteLevelBPETokenizer does NOT
48
+ # accept tokenizer_file= in current tokenizers releases.
49
+ t = Tokenizer.from_file(str(tokenizer_path))
50
+ # Make sure a ByteLevel decoder is attached so downstream decoding works.
51
+ try:
52
+ current_decoder = t.decoder
53
+ except Exception:
54
+ current_decoder = None
55
+ if current_decoder is None:
56
+ t.decoder = _decoders.ByteLevel()
57
+ return _TokenizerAdapter(t)
58
+
59
+ print("Training byte-level BPE tokenizer...")
60
+ t = ByteLevelBPETokenizer()
61
+ t.train_from_iterator(
62
+ clean_lines(train_dataset),
63
+ vocab_size=vocab_size,
64
+ min_frequency=min_frequency,
65
+ special_tokens=SPECIAL_TOKENS,
66
+ )
67
+ t.save(str(tokenizer_path))
68
+ # Reopen via generic Tokenizer so we attach a decoder consistently.
69
+ reopened = Tokenizer.from_file(str(tokenizer_path))
70
+ try:
71
+ current_decoder = reopened.decoder
72
+ except Exception:
73
+ current_decoder = None
74
+ if current_decoder is None:
75
+ reopened.decoder = _decoders.ByteLevel()
76
+ return _TokenizerAdapter(reopened)
77
+
78
+
79
+ def write_split(tokenizer, dataset, out_file: Path, dtype, bos_id: int, eos_id: int) -> int:
80
+ token_count = 0
81
+ with out_file.open("wb") as f:
82
+ for text in tqdm(clean_lines(dataset), desc=f"tokenizing {out_file.name}"):
83
+ ids = [bos_id] + tokenizer.encode(text).ids + [eos_id]
84
+ arr = np.asarray(ids, dtype=dtype)
85
+ arr.tofile(f)
86
+ token_count += len(ids)
87
+ return token_count
88
+
89
+
90
+ def main() -> None:
91
+ parser = argparse.ArgumentParser(
92
+ description="Download WikiText-103, train a tokenizer, and make binary token files."
93
+ )
94
+ parser.add_argument("--data_dir", type=Path, default=Path("data/wikitext103"))
95
+ parser.add_argument("--dataset", type=str, default="Salesforce/wikitext")
96
+ parser.add_argument("--config", type=str, default="wikitext-103-raw-v1")
97
+ parser.add_argument("--vocab_size", type=int, default=32000)
98
+ parser.add_argument("--min_frequency", type=int, default=2)
99
+ args = parser.parse_args()
100
+
101
+ args.data_dir.mkdir(parents=True, exist_ok=True)
102
+ tokenizer_path = args.data_dir / "tokenizer.json"
103
+
104
+ print("Loading WikiText-103...")
105
+ train = load_dataset(args.dataset, args.config, split="train")
106
+ val = load_dataset(args.dataset, args.config, split="validation")
107
+ test = load_dataset(args.dataset, args.config, split="test")
108
+
109
+ tokenizer = load_or_train_tokenizer(
110
+ tokenizer_path=tokenizer_path,
111
+ train_dataset=train,
112
+ vocab_size=args.vocab_size,
113
+ min_frequency=args.min_frequency,
114
+ )
115
+
116
+ vocab = tokenizer.get_vocab()
117
+ if "<bos>" not in vocab or "<eos>" not in vocab or "<pad>" not in vocab:
118
+ raise RuntimeError(
119
+ "Tokenizer is missing required special tokens (<pad>, <bos>, <eos>). "
120
+ "Delete data/wikitext103/tokenizer.json and re-run to retrain."
121
+ )
122
+
123
+ bos_id = vocab["<bos>"]
124
+ eos_id = vocab["<eos>"]
125
+ pad_id = vocab["<pad>"]
126
+
127
+ vocab_size = tokenizer.get_vocab_size()
128
+ dtype = np.uint16 if vocab_size <= np.iinfo(np.uint16).max else np.uint32
129
+
130
+ train_tokens = write_split(tokenizer, train, args.data_dir / "train.bin", dtype, bos_id, eos_id)
131
+ val_tokens = write_split(tokenizer, val, args.data_dir / "val.bin", dtype, bos_id, eos_id)
132
+ test_tokens = write_split(tokenizer, test, args.data_dir / "test.bin", dtype, bos_id, eos_id)
133
+
134
+ meta = {
135
+ "dataset": args.dataset,
136
+ "config": args.config,
137
+ "vocab_size": vocab_size,
138
+ "dtype": "uint16" if dtype == np.uint16 else "uint32",
139
+ "bos_id": bos_id,
140
+ "eos_id": eos_id,
141
+ "pad_id": pad_id,
142
+ "train_tokens": train_tokens,
143
+ "val_tokens": val_tokens,
144
+ "test_tokens": test_tokens,
145
+ }
146
+ (args.data_dir / "meta.json").write_text(json.dumps(meta, indent=2), encoding="utf-8")
147
+ print(f"Done. Wrote tokenizer and token files to {args.data_dir}")
148
+
149
+
150
+ if __name__ == "__main__":
151
+ main()
code/tokenizer.py ADDED
@@ -0,0 +1,52 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ from pathlib import Path
4
+
5
+ from tokenizers import Tokenizer
6
+ from tokenizers import decoders as _decoders
7
+
8
+
9
+ class TextTokenizer:
10
+ """
11
+ Wrapper around tokenizers.Tokenizer that guarantees a ByteLevel decoder
12
+ is attached. ByteLevelBPETokenizer saves a JSON without a decoder block,
13
+ so reloading via Tokenizer.from_file() yields a tokenizer whose .decode()
14
+ returns raw byte-level tokens (Ġ, ä) and replacement chars (�, �)
15
+ instead of proper UTF-8 text. We attach the decoder here so decode is
16
+ always correct.
17
+ """
18
+
19
+ def __init__(self, path: str | Path):
20
+ self.path = Path(path)
21
+ self.tokenizer = Tokenizer.from_file(str(self.path))
22
+
23
+ # Force a ByteLevel decoder if one is not attached.
24
+ try:
25
+ current_decoder = self.tokenizer.decoder
26
+ except Exception:
27
+ current_decoder = None
28
+ if current_decoder is None:
29
+ self.tokenizer.decoder = _decoders.ByteLevel()
30
+
31
+ vocab = self.tokenizer.get_vocab()
32
+ self.pad_id = vocab.get("<pad>", 0)
33
+ self.bos_id = vocab.get("<bos>", 1)
34
+ self.eos_id = vocab.get("<eos>", 2)
35
+ self.unk_id = vocab.get("<unk>", 3)
36
+ self.vocab_size = self.tokenizer.get_vocab_size()
37
+
38
+ def encode(self, text: str, add_bos: bool = False, add_eos: bool = False) -> list[int]:
39
+ ids = self.tokenizer.encode(text).ids
40
+ if add_bos:
41
+ ids = [self.bos_id] + ids
42
+ if add_eos:
43
+ ids = ids + [self.eos_id]
44
+ return ids
45
+
46
+ def decode(self, ids: list[int], skip_special_tokens: bool = True) -> str:
47
+ if skip_special_tokens:
48
+ specials = {self.pad_id, self.bos_id, self.eos_id, self.unk_id}
49
+ ids = [int(i) for i in ids if int(i) not in specials]
50
+ else:
51
+ ids = [int(i) for i in ids]
52
+ return self.tokenizer.decode(ids, skip_special_tokens=skip_special_tokens)
code/train.py ADDED
@@ -0,0 +1,630 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import argparse
4
+ import json
5
+ import math
6
+ import os
7
+ import random
8
+ import time
9
+ from pathlib import Path
10
+ from typing import Any
11
+
12
+ import numpy as np
13
+ import torch
14
+ from rich.console import Console
15
+
16
+ from searshorai.model import GPT, GPTConfig
17
+
18
+
19
+ console = Console()
20
+
21
+
22
+ PRESETS = {
23
+ "quick_test": dict(
24
+ n_layer=6,
25
+ n_head=6,
26
+ n_embd=384,
27
+ block_size=256,
28
+ batch_size=8,
29
+ grad_accum=8,
30
+ max_steps=1000,
31
+ ),
32
+ "gpu_16gb": dict(
33
+ n_layer=10,
34
+ n_head=10,
35
+ n_embd=640,
36
+ block_size=512,
37
+ batch_size=4,
38
+ grad_accum=16,
39
+ max_steps=20000,
40
+ ),
41
+ "rtx3090_8h": dict(
42
+ n_layer=12,
43
+ n_head=12,
44
+ n_embd=768,
45
+ block_size=512,
46
+ batch_size=8,
47
+ grad_accum=16,
48
+ max_steps=20000,
49
+ ),
50
+ "rtx3090_quality": dict(
51
+ n_layer=16,
52
+ n_head=16,
53
+ n_embd=1024,
54
+ block_size=512,
55
+ batch_size=4,
56
+ grad_accum=24,
57
+ max_steps=30000,
58
+ ),
59
+ "gpu_40gb_quality": dict(
60
+ n_layer=20,
61
+ n_head=16,
62
+ n_embd=1024,
63
+ block_size=768,
64
+ batch_size=4,
65
+ grad_accum=32,
66
+ max_steps=40000,
67
+ ),
68
+ }
69
+
70
+
71
+ def parse_args() -> argparse.Namespace:
72
+ parser = argparse.ArgumentParser(description="Train a GPT-style language model from scratch.")
73
+
74
+ parser.add_argument("--data_dir", type=Path, default=Path("data/wikitext103"))
75
+ parser.add_argument("--out_dir", type=Path, default=Path("runs/wikitext-gpt"))
76
+
77
+ parser.add_argument("--preset", choices=PRESETS.keys(), default="gpu_16gb")
78
+
79
+ parser.add_argument("--resume", type=Path, default=None)
80
+ parser.add_argument("--reset_optimizer", action="store_true")
81
+ parser.add_argument("--reset_step", action="store_true",
82
+ help="When resuming, restart step counter at 0 (useful when restarting a fresh schedule).")
83
+
84
+ parser.add_argument("--n_layer", type=int, default=None)
85
+ parser.add_argument("--n_head", type=int, default=None)
86
+ parser.add_argument("--n_embd", type=int, default=None)
87
+ parser.add_argument("--block_size", type=int, default=None)
88
+
89
+ parser.add_argument("--batch_size", type=int, default=None, help="Micro-batch size.")
90
+ parser.add_argument("--grad_accum", type=int, default=None)
91
+ parser.add_argument("--max_steps", type=int, default=None)
92
+
93
+ parser.add_argument("--learning_rate", type=float, default=2.5e-4)
94
+ parser.add_argument("--min_lr", type=float, default=2.5e-5)
95
+ parser.add_argument("--warmup_steps", type=int, default=1000)
96
+ parser.add_argument("--weight_decay", type=float, default=0.1)
97
+ parser.add_argument("--dropout", type=float, default=0.0)
98
+ parser.add_argument("--grad_clip", type=float, default=1.0)
99
+
100
+ parser.add_argument("--eval_interval", type=int, default=500)
101
+ parser.add_argument("--eval_iters", type=int, default=100)
102
+ parser.add_argument("--save_interval", type=int, default=1000)
103
+ parser.add_argument("--log_interval", type=int, default=20)
104
+
105
+ parser.add_argument("--seed", type=int, default=1337)
106
+
107
+ parser.add_argument("--device", type=str, default="auto", choices=["auto", "cuda", "cpu"])
108
+ parser.add_argument("--dtype", type=str, default="auto", choices=["auto", "float32", "float16", "bfloat16"])
109
+
110
+ parser.add_argument("--compile", action="store_true")
111
+
112
+ parser.add_argument("--gradient_checkpointing", action="store_true")
113
+ parser.add_argument(
114
+ "--no_gradient_checkpointing",
115
+ "--no-gradient-checkpointing",
116
+ action="store_true",
117
+ help="Disable checkpointing when resuming from a checkpoint that was trained with it.",
118
+ )
119
+
120
+ parser.add_argument("--eval_only", action="store_true")
121
+ parser.add_argument("--always_save_checkpoint", action="store_true")
122
+ parser.add_argument("--save_optimizer", action="store_true")
123
+
124
+ return parser.parse_args()
125
+
126
+
127
+ def apply_preset(args: argparse.Namespace) -> argparse.Namespace:
128
+ preset = PRESETS[args.preset]
129
+ for key, value in preset.items():
130
+ if getattr(args, key) is None:
131
+ setattr(args, key, value)
132
+ return args
133
+
134
+
135
+ def setup_reproducibility(seed: int) -> None:
136
+ random.seed(seed)
137
+ np.random.seed(seed)
138
+ torch.manual_seed(seed)
139
+ if torch.cuda.is_available():
140
+ torch.cuda.manual_seed_all(seed)
141
+ torch.backends.cuda.matmul.allow_tf32 = True
142
+ torch.backends.cudnn.allow_tf32 = True
143
+ torch.backends.cudnn.benchmark = True
144
+
145
+
146
+ def choose_device(args: argparse.Namespace) -> str:
147
+ if args.device == "auto":
148
+ return "cuda" if torch.cuda.is_available() else "cpu"
149
+ if args.device == "cuda" and not torch.cuda.is_available():
150
+ raise RuntimeError("CUDA was requested, but torch.cuda.is_available() is False.")
151
+ return args.device
152
+
153
+
154
+ def choose_dtype(args: argparse.Namespace, device: str) -> torch.dtype:
155
+ if device == "cpu":
156
+ return torch.float32
157
+ if args.dtype == "float32":
158
+ return torch.float32
159
+ if args.dtype == "float16":
160
+ return torch.float16
161
+ if args.dtype == "bfloat16":
162
+ if torch.cuda.is_bf16_supported():
163
+ return torch.bfloat16
164
+ console.print("[yellow]bfloat16 requested but not supported. Falling back to float16.[/yellow]")
165
+ return torch.float16
166
+ if torch.cuda.is_bf16_supported():
167
+ return torch.bfloat16
168
+ return torch.float16
169
+
170
+
171
+ def make_autocast_context(device: str, dtype: torch.dtype):
172
+ enabled = device == "cuda" and dtype in (torch.float16, torch.bfloat16)
173
+ return torch.amp.autocast(device_type=device, dtype=dtype, enabled=enabled)
174
+
175
+
176
+ def make_grad_scaler(device: str, dtype: torch.dtype):
177
+ enabled = device == "cuda" and dtype == torch.float16
178
+ try:
179
+ return torch.amp.GradScaler("cuda", enabled=enabled)
180
+ except TypeError:
181
+ return torch.cuda.amp.GradScaler(enabled=enabled)
182
+
183
+
184
+ def get_lr(step: int, args: argparse.Namespace) -> float:
185
+ if step < args.warmup_steps:
186
+ return args.learning_rate * step / max(1, args.warmup_steps)
187
+ if step > args.max_steps:
188
+ return args.min_lr
189
+ decay_ratio = (step - args.warmup_steps) / max(1, args.max_steps - args.warmup_steps)
190
+ decay_ratio = min(1.0, max(0.0, decay_ratio))
191
+ coeff = 0.5 * (1.0 + math.cos(math.pi * decay_ratio))
192
+ return args.min_lr + coeff * (args.learning_rate - args.min_lr)
193
+
194
+
195
+ def load_json(path: Path) -> dict[str, Any]:
196
+ if not path.exists():
197
+ raise FileNotFoundError(f"Missing required file: {path}")
198
+ return json.loads(path.read_text(encoding="utf-8"))
199
+
200
+
201
+ def validate_meta(meta: dict[str, Any]) -> None:
202
+ required_keys = ["vocab_size", "dtype"]
203
+ for key in required_keys:
204
+ if key not in meta:
205
+ raise KeyError(f"meta.json is missing required key: {key}")
206
+ if meta["dtype"] not in ("uint16", "uint32"):
207
+ raise ValueError(f"Unsupported meta dtype: {meta['dtype']}. Expected uint16 or uint32.")
208
+ if int(meta["vocab_size"]) <= 0:
209
+ raise ValueError("meta.json vocab_size must be greater than zero.")
210
+ if meta["dtype"] == "uint16" and int(meta["vocab_size"]) > 65535:
211
+ raise ValueError("meta dtype is uint16 but vocab_size is greater than 65535. Use uint32 data files.")
212
+
213
+
214
+ def load_memmap(path: Path, dtype: str) -> np.memmap:
215
+ if not path.exists():
216
+ raise FileNotFoundError(f"Missing required file: {path}")
217
+ np_dtype = np.uint16 if dtype == "uint16" else np.uint32
218
+ return np.memmap(path, dtype=np_dtype, mode="r")
219
+
220
+
221
+ def validate_dataset(train_data: np.memmap, val_data: np.memmap, block_size: int, vocab_size: int) -> None:
222
+ min_required = block_size + 2
223
+ if len(train_data) < min_required:
224
+ raise ValueError(
225
+ f"train.bin is too small. Need at least {min_required} tokens for block_size={block_size}, "
226
+ f"but got {len(train_data)}."
227
+ )
228
+ if len(val_data) < min_required:
229
+ raise ValueError(
230
+ f"val.bin is too small. Need at least {min_required} tokens for block_size={block_size}, "
231
+ f"but got {len(val_data)}."
232
+ )
233
+
234
+ sample_count = min(10000, len(train_data))
235
+ sample_positions = np.linspace(0, len(train_data) - 1, sample_count, dtype=np.int64)
236
+ sample = np.asarray(train_data[sample_positions], dtype=np.int64)
237
+ max_token = int(sample.max())
238
+ min_token = int(sample.min())
239
+ if min_token < 0:
240
+ raise ValueError(f"Dataset contains negative token id: {min_token}")
241
+ if max_token >= vocab_size:
242
+ raise ValueError(
243
+ f"Dataset token id {max_token} is >= vocab_size {vocab_size}. "
244
+ "This usually means tokenizer/meta/train.bin mismatch."
245
+ )
246
+
247
+
248
+ def get_batch(
249
+ data: np.memmap,
250
+ batch_size: int,
251
+ block_size: int,
252
+ device: str,
253
+ ) -> tuple[torch.Tensor, torch.Tensor]:
254
+ """
255
+ Fast batch loader: one vectorized gather, then a single host->device transfer.
256
+ The old code did batch_size python-level numpy slices per call, which was a
257
+ major bottleneck.
258
+ """
259
+ max_start = len(data) - block_size - 1
260
+ if max_start <= 0:
261
+ raise ValueError("Dataset is too small for the configured block_size.")
262
+
263
+ # Random start positions.
264
+ ix = np.random.randint(0, max_start, size=(batch_size,), dtype=np.int64)
265
+
266
+ # Allocate contiguous int64 arrays. memmap reads are cheap for sequential blocks.
267
+ x_np = np.empty((batch_size, block_size), dtype=np.int64)
268
+ y_np = np.empty((batch_size, block_size), dtype=np.int64)
269
+ for row, start in enumerate(ix):
270
+ x_np[row] = data[start : start + block_size]
271
+ y_np[row] = data[start + 1 : start + 1 + block_size]
272
+
273
+ x = torch.from_numpy(x_np)
274
+ y = torch.from_numpy(y_np)
275
+
276
+ if device == "cuda":
277
+ x = x.pin_memory().to(device, non_blocking=True)
278
+ y = y.pin_memory().to(device, non_blocking=True)
279
+ else:
280
+ x = x.to(device)
281
+ y = y.to(device)
282
+ return x, y
283
+
284
+
285
+ @torch.no_grad()
286
+ def estimate_loss(
287
+ model: GPT,
288
+ train_data: np.memmap,
289
+ val_data: np.memmap,
290
+ args: argparse.Namespace,
291
+ device: str,
292
+ autocast_ctx,
293
+ ) -> dict[str, float]:
294
+ out: dict[str, float] = {}
295
+ model.eval()
296
+ for split, data in [("train", train_data), ("val", val_data)]:
297
+ losses = []
298
+ for _ in range(args.eval_iters):
299
+ x, y = get_batch(data, args.batch_size, args.block_size, device)
300
+ with autocast_ctx:
301
+ _, loss = model(x, y)
302
+ if torch.isfinite(loss):
303
+ losses.append(float(loss.item()))
304
+ out[split] = float(sum(losses) / max(1, len(losses)))
305
+ model.train()
306
+ return out
307
+
308
+
309
+ def unwrap_model(model: GPT) -> GPT:
310
+ if hasattr(model, "_orig_mod"):
311
+ return model._orig_mod
312
+ return model
313
+
314
+
315
+ def strip_compile_prefix(state_dict: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
316
+ cleaned = {}
317
+ for key, value in state_dict.items():
318
+ if key.startswith("_orig_mod."):
319
+ key = key[len("_orig_mod.") :]
320
+ cleaned[key] = value
321
+ return cleaned
322
+
323
+
324
+ def optimizer_to_device(optimizer: torch.optim.Optimizer, device: str) -> None:
325
+ for state in optimizer.state.values():
326
+ for key, value in state.items():
327
+ if isinstance(value, torch.Tensor):
328
+ state[key] = value.to(device)
329
+
330
+
331
+ def save_checkpoint(
332
+ path: Path,
333
+ model: GPT,
334
+ optimizer: torch.optim.Optimizer | None,
335
+ args: argparse.Namespace,
336
+ step: int,
337
+ best_val_loss: float,
338
+ meta: dict[str, Any],
339
+ ) -> None:
340
+ raw_model = unwrap_model(model)
341
+ checkpoint: dict[str, Any] = {
342
+ "model": raw_model.state_dict(),
343
+ "args": vars(args),
344
+ "config": vars(raw_model.config),
345
+ "step": step,
346
+ "best_val_loss": best_val_loss,
347
+ "meta": meta,
348
+ }
349
+ if args.save_optimizer and optimizer is not None:
350
+ checkpoint["optimizer"] = optimizer.state_dict()
351
+ torch.save(checkpoint, path)
352
+
353
+
354
+ def write_run_config(args: argparse.Namespace, meta: dict[str, Any], device: str, dtype: torch.dtype) -> None:
355
+ config_path = args.out_dir / "run_config.json"
356
+ payload = {
357
+ "args": {k: (str(v) if isinstance(v, Path) else v) for k, v in vars(args).items()},
358
+ "meta": meta,
359
+ "device": device,
360
+ "dtype": str(dtype),
361
+ "torch_version": torch.__version__,
362
+ "cuda_available": torch.cuda.is_available(),
363
+ "cuda_device_name": torch.cuda.get_device_name(0) if torch.cuda.is_available() else None,
364
+ }
365
+ config_path.write_text(json.dumps(payload, indent=2), encoding="utf-8")
366
+
367
+
368
+ def build_model_from_checkpoint(
369
+ ckpt_path: Path,
370
+ device: str,
371
+ args: argparse.Namespace,
372
+ ) -> tuple[GPT, int, float, dict[str, Any]]:
373
+ ckpt = torch.load(ckpt_path, map_location=device, weights_only=False)
374
+ config = GPTConfig(**ckpt["config"])
375
+ if hasattr(config, "gradient_checkpointing"):
376
+ if args.no_gradient_checkpointing:
377
+ config.gradient_checkpointing = False
378
+ elif args.gradient_checkpointing:
379
+ config.gradient_checkpointing = True
380
+ model = GPT(config)
381
+ state_dict = strip_compile_prefix(ckpt["model"])
382
+ model.load_state_dict(state_dict, strict=True)
383
+ start_step = int(ckpt.get("step", 0))
384
+ best_val_loss = float(ckpt.get("best_val_loss", float("inf")))
385
+ checkpoint_meta = ckpt.get("meta", {})
386
+ return model, start_step, best_val_loss, checkpoint_meta
387
+
388
+
389
+ def build_new_model(meta: dict[str, Any], args: argparse.Namespace) -> tuple[GPT, int, float]:
390
+ config = GPTConfig(
391
+ vocab_size=int(meta["vocab_size"]),
392
+ block_size=int(args.block_size),
393
+ n_layer=int(args.n_layer),
394
+ n_head=int(args.n_head),
395
+ n_embd=int(args.n_embd),
396
+ dropout=float(args.dropout),
397
+ gradient_checkpointing=bool(args.gradient_checkpointing),
398
+ )
399
+ model = GPT(config)
400
+ return model, 0, float("inf")
401
+
402
+
403
+ def print_startup_info(
404
+ model: GPT,
405
+ args: argparse.Namespace,
406
+ device: str,
407
+ dtype: torch.dtype,
408
+ train_data: np.memmap,
409
+ val_data: np.memmap,
410
+ start_step: int,
411
+ ) -> None:
412
+ raw_model = unwrap_model(model)
413
+ tokens_per_step = args.batch_size * args.grad_accum * args.block_size
414
+ if hasattr(raw_model, "num_parameters"):
415
+ num_params = raw_model.num_parameters()
416
+ else:
417
+ num_params = sum(p.numel() for p in raw_model.parameters())
418
+
419
+ console.print("")
420
+ console.print("[bold green]Training configuration[/bold green]")
421
+ console.print(f"Device: {device}")
422
+ console.print(f"Dtype: {dtype}")
423
+ console.print(f"Preset: {args.preset}")
424
+ console.print(f"Parameters: {num_params / 1e6:.2f}M")
425
+ console.print(f"Layers: {args.n_layer}")
426
+ console.print(f"Heads: {args.n_head}")
427
+ console.print(f"Embedding size: {args.n_embd}")
428
+ console.print(f"Block size: {args.block_size}")
429
+ console.print(f"Batch size: {args.batch_size}")
430
+ console.print(f"Grad accumulation: {args.grad_accum}")
431
+ console.print(f"Tokens per step: {tokens_per_step:,}")
432
+ console.print(f"Train tokens: {len(train_data):,}")
433
+ console.print(f"Val tokens: {len(val_data):,}")
434
+ console.print(f"Start step: {start_step:,}")
435
+ console.print(f"Max steps: {args.max_steps:,}")
436
+ console.print(f"Learning rate: {args.learning_rate:.2e}")
437
+ console.print(f"Min LR: {args.min_lr:.2e}")
438
+ console.print(f"Warmup steps: {args.warmup_steps:,}")
439
+ console.print(f"Grad clip: {args.grad_clip}")
440
+ console.print("")
441
+
442
+
443
+ def main() -> None:
444
+ args = apply_preset(parse_args())
445
+ args.out_dir.mkdir(parents=True, exist_ok=True)
446
+ setup_reproducibility(args.seed)
447
+
448
+ device = choose_device(args)
449
+ dtype = choose_dtype(args, device)
450
+ autocast_ctx = make_autocast_context(device, dtype)
451
+ scaler = make_grad_scaler(device, dtype)
452
+
453
+ meta_path = args.data_dir / "meta.json"
454
+ meta = load_json(meta_path)
455
+ validate_meta(meta)
456
+
457
+ train_data = load_memmap(args.data_dir / "train.bin", meta["dtype"])
458
+ val_data = load_memmap(args.data_dir / "val.bin", meta["dtype"])
459
+ validate_dataset(
460
+ train_data=train_data,
461
+ val_data=val_data,
462
+ block_size=int(args.block_size),
463
+ vocab_size=int(meta["vocab_size"]),
464
+ )
465
+
466
+ if args.resume is not None:
467
+ console.print(f"[yellow]Resuming from checkpoint:[/yellow] {args.resume}")
468
+ model, start_step, best_val_loss, checkpoint_meta = build_model_from_checkpoint(args.resume, device, args)
469
+ if checkpoint_meta:
470
+ meta = checkpoint_meta
471
+ else:
472
+ model, start_step, best_val_loss = build_new_model(meta, args)
473
+
474
+ if args.reset_step:
475
+ start_step = 0
476
+ best_val_loss = float("inf")
477
+ console.print("[yellow]reset_step set: step counter restarted at 0.[/yellow]")
478
+
479
+ model.to(device)
480
+
481
+ optimizer = model.configure_optimizers(
482
+ args.weight_decay,
483
+ args.learning_rate,
484
+ (0.9, 0.95),
485
+ "cuda" if device == "cuda" else "cpu",
486
+ )
487
+
488
+ if args.resume is not None and not args.reset_optimizer:
489
+ ckpt = torch.load(args.resume, map_location=device, weights_only=False)
490
+ if "optimizer" in ckpt:
491
+ try:
492
+ optimizer.load_state_dict(ckpt["optimizer"])
493
+ optimizer_to_device(optimizer, device)
494
+ console.print("[green]Loaded optimizer state from checkpoint.[/green]")
495
+ except Exception as exc:
496
+ console.print(f"[yellow]Could not load optimizer state. Continuing with fresh optimizer. Error: {exc}[/yellow]")
497
+ else:
498
+ console.print("[yellow]Checkpoint has no optimizer state. Continuing with fresh optimizer.[/yellow]")
499
+ elif args.resume is not None and args.reset_optimizer:
500
+ console.print("[yellow]reset_optimizer set: starting with fresh Adam moments.[/yellow]")
501
+
502
+ if args.compile:
503
+ console.print("[cyan]Compiling model...[/cyan]")
504
+ model = torch.compile(model)
505
+
506
+ write_run_config(args, meta, device, dtype)
507
+ print_startup_info(model, args, device, dtype, train_data, val_data, start_step)
508
+
509
+ if args.eval_only:
510
+ losses = estimate_loss(model, train_data, val_data, args, device, autocast_ctx)
511
+ console.print(f"eval only: train {losses['train']:.4f}, val {losses['val']:.4f}")
512
+ return
513
+
514
+ model.train()
515
+ tokens_per_step = args.batch_size * args.grad_accum * args.block_size
516
+
517
+ start_time = time.time()
518
+ last_log_time = start_time
519
+ last_log_step = start_step
520
+
521
+ for completed_step in range(start_step, args.max_steps):
522
+ step = completed_step + 1
523
+
524
+ lr = get_lr(step, args)
525
+ for param_group in optimizer.param_groups:
526
+ param_group["lr"] = lr
527
+
528
+ optimizer.zero_grad(set_to_none=True)
529
+ loss_accum = 0.0
530
+ skipped_micro = 0
531
+
532
+ for _ in range(args.grad_accum):
533
+ x, y = get_batch(train_data, args.batch_size, args.block_size, device)
534
+ with autocast_ctx:
535
+ _, loss = model(x, y)
536
+ loss = loss / args.grad_accum
537
+ if not torch.isfinite(loss):
538
+ console.print(f"[yellow]Non-finite loss at step {step}, skipping micro-batch.[/yellow]")
539
+ skipped_micro += 1
540
+ continue
541
+ scaler.scale(loss).backward()
542
+ loss_accum += float(loss.item())
543
+
544
+ if skipped_micro == args.grad_accum:
545
+ # Whole step was bad. Skip the optimizer update.
546
+ scaler.update()
547
+ continue
548
+
549
+ scaler.unscale_(optimizer)
550
+ grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip)
551
+ scaler.step(optimizer)
552
+ scaler.update()
553
+
554
+ if step % args.log_interval == 0 or step == start_step + 1:
555
+ now = time.time()
556
+ elapsed = max(now - last_log_time, 1e-9)
557
+ steps_done = max(1, step - last_log_step)
558
+ toks_per_sec = (tokens_per_step * steps_done) / elapsed
559
+ last_log_time = now
560
+ last_log_step = step
561
+ console.print(
562
+ f"step {step:7d} | "
563
+ f"loss {loss_accum:.4f} | "
564
+ f"lr {lr:.2e} | "
565
+ f"grad {float(grad_norm):.2f} | "
566
+ f"{toks_per_sec:,.0f} tok/s"
567
+ )
568
+
569
+ should_eval = step % args.eval_interval == 0 or step == args.max_steps
570
+ if should_eval:
571
+ losses = estimate_loss(model, train_data, val_data, args, device, autocast_ctx)
572
+ console.print(
573
+ f"[bold]eval step {step}:[/bold] "
574
+ f"train {losses['train']:.4f}, val {losses['val']:.4f}"
575
+ )
576
+ if losses["val"] < best_val_loss:
577
+ best_val_loss = losses["val"]
578
+ save_checkpoint(
579
+ args.out_dir / "best.pt",
580
+ model,
581
+ optimizer,
582
+ args,
583
+ step,
584
+ best_val_loss,
585
+ meta,
586
+ )
587
+ console.print(f"[green]saved best checkpoint: val {best_val_loss:.4f}[/green]")
588
+ if args.always_save_checkpoint:
589
+ save_checkpoint(
590
+ args.out_dir / f"step_{step}.pt",
591
+ model,
592
+ optimizer,
593
+ args,
594
+ step,
595
+ best_val_loss,
596
+ meta,
597
+ )
598
+
599
+ if step % args.save_interval == 0:
600
+ save_checkpoint(
601
+ args.out_dir / "latest.pt",
602
+ model,
603
+ optimizer,
604
+ args,
605
+ step,
606
+ best_val_loss,
607
+ meta,
608
+ )
609
+ console.print(f"[cyan]saved latest checkpoint at step {step}[/cyan]")
610
+
611
+ save_checkpoint(
612
+ args.out_dir / "latest.pt",
613
+ model,
614
+ optimizer,
615
+ args,
616
+ args.max_steps,
617
+ best_val_loss,
618
+ meta,
619
+ )
620
+
621
+ elapsed_hours = (time.time() - start_time) / 3600.0
622
+ console.print("")
623
+ console.print(f"[bold green]Finished in {elapsed_hours:.2f} hours.[/bold green]")
624
+ console.print(f"[bold green]Best validation loss: {best_val_loss:.4f}[/bold green]")
625
+ console.print(f"[bold green]Best checkpoint: {args.out_dir / 'best.pt'}[/bold green]")
626
+ console.print(f"[bold green]Latest checkpoint: {args.out_dir / 'latest.pt'}[/bold green]")
627
+
628
+
629
+ if __name__ == "__main__":
630
+ main()
config.json ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "vocab_size": 32000,
3
+ "block_size": 512,
4
+ "n_layer": 12,
5
+ "n_head": 12,
6
+ "n_embd": 768,
7
+ "dropout": 0.0,
8
+ "bias": false,
9
+ "gradient_checkpointing": false,
10
+ "model_type": "ron-gpt",
11
+ "architectures": [
12
+ "GPT"
13
+ ],
14
+ "torch_dtype": "float32"
15
+ }
meta.json ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "dataset": "Salesforce/wikitext",
3
+ "config": "wikitext-103-raw-v1",
4
+ "vocab_size": 32000,
5
+ "dtype": "uint16",
6
+ "bos_id": 1,
7
+ "eos_id": 2,
8
+ "pad_id": 0,
9
+ "train_tokens": 115671965,
10
+ "val_tokens": 242485,
11
+ "test_tokens": 276246
12
+ }
pretrain.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:b2b0ce466438490a11b38092ed30993111987b58f6d8e08da64c262db1e0f476
3
+ size 1319159633
summarizer.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:4b665451d9bdb04205cafd88b0ef46a777204584cef3a037c3bd47f0598631e8
3
+ size 1319159633
tokenizer.json ADDED
The diff for this file is too large to render. See raw diff