Upload Ron-110M: pretrain + summarizer + tokenizer + code
Browse files- README.md +129 -0
- code/__init__.py +0 -0
- code/ask.py +204 -0
- code/finetune_sft.py +484 -0
- code/make_cnndm_sft.py +130 -0
- code/model.py +189 -0
- code/prepare_wikitext.py +151 -0
- code/tokenizer.py +52 -0
- code/train.py +630 -0
- config.json +15 -0
- meta.json +12 -0
- pretrain.pt +3 -0
- summarizer.pt +3 -0
- tokenizer.json +0 -0
README.md
ADDED
|
@@ -0,0 +1,129 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
---
|
| 2 |
+
language:
|
| 3 |
+
- en
|
| 4 |
+
license: mit
|
| 5 |
+
tags:
|
| 6 |
+
- gpt
|
| 7 |
+
- text-generation
|
| 8 |
+
- summarization
|
| 9 |
+
- from-scratch
|
| 10 |
+
- pytorch
|
| 11 |
+
library_name: pytorch
|
| 12 |
+
---
|
| 13 |
+
|
| 14 |
+
# Ron-110M
|
| 15 |
+
|
| 16 |
+
A 110M-parameter GPT-style language model trained from scratch on a single
|
| 17 |
+
RTX 3090. Pretrained on WikiText-103, then fine-tuned on CNN/DailyMail for
|
| 18 |
+
extractive news summarization.
|
| 19 |
+
|
| 20 |
+
This is a learning / research model. It is small, the tokenizer is a custom
|
| 21 |
+
byte-level BPE, and it does not use the Hugging Face `transformers` model
|
| 22 |
+
classes. The repo includes the original PyTorch code so you can run, fine-tune,
|
| 23 |
+
or continue pretraining from these weights.
|
| 24 |
+
|
| 25 |
+
## Files
|
| 26 |
+
|
| 27 |
+
- `pretrain.pt` - base language model checkpoint (after WikiText-103 pretraining)
|
| 28 |
+
- `summarizer.pt` - SFT checkpoint for news summarization (start from this for inference)
|
| 29 |
+
- `tokenizer.json` - byte-level BPE tokenizer (32k vocab, specials: `<pad> <bos> <eos> <unk>`)
|
| 30 |
+
- `meta.json` - dataset metadata (vocab size, dtype, token counts)
|
| 31 |
+
- `code/model.py` - GPT model definition
|
| 32 |
+
- `code/tokenizer.py` - tokenizer wrapper with ByteLevel decoder fix
|
| 33 |
+
- `code/ask.py` - inference script with repetition penalty, top-p, no-repeat-ngram
|
| 34 |
+
- `code/train.py` - pretraining script
|
| 35 |
+
- `code/finetune_sft.py` - supervised fine-tuning script
|
| 36 |
+
- `code/make_cnndm_sft.py` - CNN/DailyMail SFT data builder
|
| 37 |
+
- `code/prepare_wikitext.py` - WikiText-103 tokenization + tokenizer training
|
| 38 |
+
|
| 39 |
+
## Architecture
|
| 40 |
+
|
| 41 |
+
```
|
| 42 |
+
n_layer = 12
|
| 43 |
+
n_head = 12
|
| 44 |
+
n_embd = 768
|
| 45 |
+
block_size = 512
|
| 46 |
+
vocab_size = 32000
|
| 47 |
+
parameters = 109.92M
|
| 48 |
+
```
|
| 49 |
+
|
| 50 |
+
## Training results
|
| 51 |
+
|
| 52 |
+
| Stage | Dataset | Steps | Final val loss |
|
| 53 |
+
|--------------------|---------------|--------|----------------|
|
| 54 |
+
| Pretrain | WikiText-103 | 12,000 | 3.15 |
|
| 55 |
+
| SFT (summarizer) | CNN/DailyMail | 6,000 | 2.97 |
|
| 56 |
+
|
| 57 |
+
## Quick start
|
| 58 |
+
|
| 59 |
+
```bash
|
| 60 |
+
# Clone this repo
|
| 61 |
+
git lfs install
|
| 62 |
+
git clone https://huggingface.co/endurasolution/RON-110M
|
| 63 |
+
cd RON-110M
|
| 64 |
+
|
| 65 |
+
# Install minimal deps
|
| 66 |
+
pip install torch numpy tokenizers rich
|
| 67 |
+
|
| 68 |
+
# Run inference
|
| 69 |
+
python code/ask.py \
|
| 70 |
+
--checkpoint summarizer.pt \
|
| 71 |
+
--tokenizer tokenizer.json \
|
| 72 |
+
--text "A man has been arrested in Manchester after a series of break-ins at local shops. Police said the suspect was found with stolen goods. He is due to appear in court on Monday." \
|
| 73 |
+
--max_new_tokens 80 \
|
| 74 |
+
--temperature 0.4 \
|
| 75 |
+
--top_p 0.9 \
|
| 76 |
+
--repetition_penalty 1.1 \
|
| 77 |
+
--no_repeat_ngram_size 3
|
| 78 |
+
```
|
| 79 |
+
|
| 80 |
+
Expected output (paraphrased): a short news-style summary that preserves the key
|
| 81 |
+
facts from the input.
|
| 82 |
+
|
| 83 |
+
## Continue training
|
| 84 |
+
|
| 85 |
+
To resume pretraining from `pretrain.pt`:
|
| 86 |
+
|
| 87 |
+
```bash
|
| 88 |
+
python code/train.py \
|
| 89 |
+
--resume pretrain.pt \
|
| 90 |
+
--reset_step --reset_optimizer \
|
| 91 |
+
--data_dir data/wikitext103 \
|
| 92 |
+
--out_dir runs/wikitext-gpt \
|
| 93 |
+
--preset rtx3090_8h \
|
| 94 |
+
--batch_size 16 --grad_accum 8 \
|
| 95 |
+
--max_steps 12000 \
|
| 96 |
+
--learning_rate 2e-4 --min_lr 2e-5 \
|
| 97 |
+
--warmup_steps 200 \
|
| 98 |
+
--no_gradient_checkpointing \
|
| 99 |
+
--save_optimizer
|
| 100 |
+
```
|
| 101 |
+
|
| 102 |
+
To fine-tune for a new task, prepare a JSONL file with `prompt` and `answer`
|
| 103 |
+
keys, then:
|
| 104 |
+
|
| 105 |
+
```bash
|
| 106 |
+
python code/finetune_sft.py \
|
| 107 |
+
--base_checkpoint pretrain.pt \
|
| 108 |
+
--tokenizer tokenizer.json \
|
| 109 |
+
--sft_file your_data.jsonl \
|
| 110 |
+
--out_dir runs/my-finetune \
|
| 111 |
+
--max_steps 6000 \
|
| 112 |
+
--batch_size 8 --grad_accum 8 \
|
| 113 |
+
--learning_rate 5e-5 --min_lr 5e-6 \
|
| 114 |
+
--warmup_steps 200
|
| 115 |
+
```
|
| 116 |
+
|
| 117 |
+
## Limitations
|
| 118 |
+
|
| 119 |
+
- Small (110M parameters) - knowledge is limited, hallucinations possible on
|
| 120 |
+
out-of-domain inputs.
|
| 121 |
+
- Tokenizer is custom byte-level BPE - **must** be loaded with the included
|
| 122 |
+
`tokenizer.json`. Do not substitute a GPT-2 tokenizer.
|
| 123 |
+
- Not compatible with `transformers.AutoModel`. Use the included `code/`.
|
| 124 |
+
- SFT data was CNN/DailyMail news. The model is most reliable on news-style
|
| 125 |
+
English; expect weaker output on code, math, or conversational input.
|
| 126 |
+
|
| 127 |
+
## License
|
| 128 |
+
|
| 129 |
+
MIT.
|
code/__init__.py
ADDED
|
File without changes
|
code/ask.py
ADDED
|
@@ -0,0 +1,204 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
import argparse
|
| 4 |
+
from pathlib import Path
|
| 5 |
+
|
| 6 |
+
import torch
|
| 7 |
+
import torch.nn.functional as F
|
| 8 |
+
|
| 9 |
+
from searshorai.model import GPT, GPTConfig
|
| 10 |
+
from searshorai.tokenizer import TextTokenizer
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
# Must match the prompts used in make_xsum_sft.py / make_paragraph_sft.py.
|
| 14 |
+
# Using the first template is the canonical choice at inference time.
|
| 15 |
+
DEFAULT_PROMPT_TEMPLATE = (
|
| 16 |
+
"Read the article and write a one-sentence summary.\n\n"
|
| 17 |
+
"Article:\n{passage}\n\nSummary:\n"
|
| 18 |
+
)
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
def strip_compile_prefix(state_dict):
|
| 22 |
+
cleaned = {}
|
| 23 |
+
for key, value in state_dict.items():
|
| 24 |
+
if key.startswith("_orig_mod."):
|
| 25 |
+
key = key[len("_orig_mod.") :]
|
| 26 |
+
cleaned[key] = value
|
| 27 |
+
return cleaned
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
def parse_args() -> argparse.Namespace:
|
| 31 |
+
parser = argparse.ArgumentParser(description="Ask the paragraph-explainer model.")
|
| 32 |
+
parser.add_argument("--checkpoint", type=Path, required=True)
|
| 33 |
+
parser.add_argument("--tokenizer", type=Path, default=Path("data/wikitext103/tokenizer.json"))
|
| 34 |
+
parser.add_argument("--text", type=str, required=True, help="The passage to explain.")
|
| 35 |
+
parser.add_argument("--prompt_template", type=str, default=DEFAULT_PROMPT_TEMPLATE)
|
| 36 |
+
parser.add_argument("--max_new_tokens", type=int, default=120)
|
| 37 |
+
parser.add_argument("--temperature", type=float, default=0.7)
|
| 38 |
+
parser.add_argument("--top_k", type=int, default=40)
|
| 39 |
+
parser.add_argument("--top_p", type=float, default=0.9,
|
| 40 |
+
help="Nucleus sampling cutoff. Set 1.0 to disable.")
|
| 41 |
+
parser.add_argument("--repetition_penalty", type=float, default=1.3,
|
| 42 |
+
help="Penalty for re-emitting tokens already in the context. 1.0 = off.")
|
| 43 |
+
parser.add_argument("--no_repeat_ngram_size", type=int, default=3,
|
| 44 |
+
help="Block any n-gram of this size from appearing twice. 0 = off.")
|
| 45 |
+
parser.add_argument("--device", type=str, default="auto", choices=["auto", "cuda", "cpu"])
|
| 46 |
+
parser.add_argument("--seed", type=int, default=0)
|
| 47 |
+
return parser.parse_args()
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
def banned_tokens_from_ngrams(generated: list[int], n: int) -> set[int]:
|
| 51 |
+
"""
|
| 52 |
+
For no-repeat-ngram blocking: given the tokens generated so far, return
|
| 53 |
+
the set of token ids that would close a previously-seen n-gram if emitted
|
| 54 |
+
next.
|
| 55 |
+
"""
|
| 56 |
+
if n <= 0 or len(generated) < n - 1:
|
| 57 |
+
return set()
|
| 58 |
+
prefix = tuple(generated[-(n - 1):])
|
| 59 |
+
banned: set[int] = set()
|
| 60 |
+
for i in range(len(generated) - n + 1):
|
| 61 |
+
ngram = tuple(generated[i : i + n - 1])
|
| 62 |
+
if ngram == prefix:
|
| 63 |
+
banned.add(generated[i + n - 1])
|
| 64 |
+
return banned
|
| 65 |
+
|
| 66 |
+
|
| 67 |
+
def generate(
|
| 68 |
+
model: GPT,
|
| 69 |
+
prompt_ids: list[int],
|
| 70 |
+
max_new_tokens: int,
|
| 71 |
+
temperature: float,
|
| 72 |
+
top_k: int,
|
| 73 |
+
top_p: float,
|
| 74 |
+
repetition_penalty: float,
|
| 75 |
+
no_repeat_ngram_size: int,
|
| 76 |
+
eos_id: int | None,
|
| 77 |
+
device: str,
|
| 78 |
+
) -> list[int]:
|
| 79 |
+
"""
|
| 80 |
+
Custom sampling loop with repetition penalty, top-k, top-p (nucleus),
|
| 81 |
+
and no-repeat-ngram blocking. Returns the list of newly generated token
|
| 82 |
+
ids (does not include the prompt).
|
| 83 |
+
"""
|
| 84 |
+
block_size = model.config.block_size
|
| 85 |
+
context = list(prompt_ids)
|
| 86 |
+
generated: list[int] = []
|
| 87 |
+
|
| 88 |
+
for _ in range(max_new_tokens):
|
| 89 |
+
idx_cond = context if len(context) <= block_size else context[-block_size:]
|
| 90 |
+
x = torch.tensor([idx_cond], dtype=torch.long, device=device)
|
| 91 |
+
|
| 92 |
+
with torch.no_grad():
|
| 93 |
+
logits, _ = model(x)
|
| 94 |
+
logits = logits[:, -1, :].squeeze(0).float()
|
| 95 |
+
|
| 96 |
+
if repetition_penalty != 1.0 and len(context) > 0:
|
| 97 |
+
seen = torch.tensor(list(set(context)), dtype=torch.long, device=device)
|
| 98 |
+
scores = logits[seen]
|
| 99 |
+
scores = torch.where(scores > 0, scores / repetition_penalty, scores * repetition_penalty)
|
| 100 |
+
logits[seen] = scores
|
| 101 |
+
|
| 102 |
+
if no_repeat_ngram_size > 0 and len(generated) >= no_repeat_ngram_size - 1:
|
| 103 |
+
banned = banned_tokens_from_ngrams(generated, no_repeat_ngram_size)
|
| 104 |
+
for tok_id in banned:
|
| 105 |
+
logits[tok_id] = -float("inf")
|
| 106 |
+
|
| 107 |
+
logits = logits / max(temperature, 1e-5)
|
| 108 |
+
|
| 109 |
+
if top_k is not None and top_k > 0:
|
| 110 |
+
k = min(top_k, logits.size(-1))
|
| 111 |
+
top_vals, _ = torch.topk(logits, k)
|
| 112 |
+
cutoff = top_vals[-1]
|
| 113 |
+
logits = torch.where(logits < cutoff, torch.full_like(logits, -float("inf")), logits)
|
| 114 |
+
|
| 115 |
+
if top_p < 1.0:
|
| 116 |
+
sorted_logits, sorted_idx = torch.sort(logits, descending=True)
|
| 117 |
+
probs_sorted = F.softmax(sorted_logits, dim=-1)
|
| 118 |
+
cumulative = torch.cumsum(probs_sorted, dim=-1)
|
| 119 |
+
mask = cumulative > top_p
|
| 120 |
+
mask[..., 1:] = mask[..., :-1].clone()
|
| 121 |
+
mask[..., 0] = False
|
| 122 |
+
sorted_logits = sorted_logits.masked_fill(mask, -float("inf"))
|
| 123 |
+
logits = torch.full_like(logits, -float("inf"))
|
| 124 |
+
logits.scatter_(0, sorted_idx, sorted_logits)
|
| 125 |
+
|
| 126 |
+
probs = F.softmax(logits, dim=-1)
|
| 127 |
+
if not torch.isfinite(probs).all() or probs.sum() <= 0:
|
| 128 |
+
next_tok = int(logits.argmax().item())
|
| 129 |
+
else:
|
| 130 |
+
next_tok = int(torch.multinomial(probs, num_samples=1).item())
|
| 131 |
+
|
| 132 |
+
if eos_id is not None and next_tok == eos_id:
|
| 133 |
+
break
|
| 134 |
+
|
| 135 |
+
context.append(next_tok)
|
| 136 |
+
generated.append(next_tok)
|
| 137 |
+
|
| 138 |
+
return generated
|
| 139 |
+
|
| 140 |
+
|
| 141 |
+
def main() -> None:
|
| 142 |
+
args = parse_args()
|
| 143 |
+
|
| 144 |
+
if args.seed:
|
| 145 |
+
torch.manual_seed(args.seed)
|
| 146 |
+
if torch.cuda.is_available():
|
| 147 |
+
torch.cuda.manual_seed_all(args.seed)
|
| 148 |
+
|
| 149 |
+
if args.device == "auto":
|
| 150 |
+
device = "cuda" if torch.cuda.is_available() else "cpu"
|
| 151 |
+
else:
|
| 152 |
+
device = args.device
|
| 153 |
+
|
| 154 |
+
tok = TextTokenizer(args.tokenizer)
|
| 155 |
+
|
| 156 |
+
ckpt = torch.load(args.checkpoint, map_location=device, weights_only=False)
|
| 157 |
+
config = GPTConfig(**ckpt["config"])
|
| 158 |
+
config.dropout = 0.0
|
| 159 |
+
config.gradient_checkpointing = False
|
| 160 |
+
|
| 161 |
+
model = GPT(config)
|
| 162 |
+
state_dict = strip_compile_prefix(ckpt["model"])
|
| 163 |
+
model.load_state_dict(state_dict, strict=True)
|
| 164 |
+
model.to(device)
|
| 165 |
+
model.eval()
|
| 166 |
+
|
| 167 |
+
if tok.vocab_size != model.config.vocab_size:
|
| 168 |
+
raise RuntimeError(
|
| 169 |
+
f"Tokenizer vocab_size {tok.vocab_size} != model vocab_size {model.config.vocab_size}. "
|
| 170 |
+
"Use the same tokenizer.json that was used for pretrain/SFT."
|
| 171 |
+
)
|
| 172 |
+
|
| 173 |
+
prompt = args.prompt_template.format(passage=args.text.strip())
|
| 174 |
+
prompt_ids = tok.encode(prompt, add_bos=True, add_eos=False)
|
| 175 |
+
|
| 176 |
+
max_prompt_len = model.config.block_size - args.max_new_tokens - 1
|
| 177 |
+
if max_prompt_len < 16:
|
| 178 |
+
raise RuntimeError(
|
| 179 |
+
f"max_new_tokens={args.max_new_tokens} is too large for block_size={model.config.block_size}."
|
| 180 |
+
)
|
| 181 |
+
if len(prompt_ids) > max_prompt_len:
|
| 182 |
+
bos = [prompt_ids[0]] if prompt_ids and prompt_ids[0] == tok.bos_id else []
|
| 183 |
+
tail = prompt_ids[-(max_prompt_len - len(bos)) :]
|
| 184 |
+
prompt_ids = bos + tail
|
| 185 |
+
|
| 186 |
+
new_ids = generate(
|
| 187 |
+
model=model,
|
| 188 |
+
prompt_ids=prompt_ids,
|
| 189 |
+
max_new_tokens=args.max_new_tokens,
|
| 190 |
+
temperature=args.temperature,
|
| 191 |
+
top_k=args.top_k,
|
| 192 |
+
top_p=args.top_p,
|
| 193 |
+
repetition_penalty=args.repetition_penalty,
|
| 194 |
+
no_repeat_ngram_size=args.no_repeat_ngram_size,
|
| 195 |
+
eos_id=tok.eos_id,
|
| 196 |
+
device=device,
|
| 197 |
+
)
|
| 198 |
+
|
| 199 |
+
answer = tok.decode(new_ids, skip_special_tokens=True).strip()
|
| 200 |
+
print(answer)
|
| 201 |
+
|
| 202 |
+
|
| 203 |
+
if __name__ == "__main__":
|
| 204 |
+
main()
|
code/finetune_sft.py
ADDED
|
@@ -0,0 +1,484 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
import argparse
|
| 4 |
+
import json
|
| 5 |
+
import math
|
| 6 |
+
import random
|
| 7 |
+
import time
|
| 8 |
+
import unicodedata
|
| 9 |
+
from dataclasses import dataclass
|
| 10 |
+
from pathlib import Path
|
| 11 |
+
from typing import Any
|
| 12 |
+
|
| 13 |
+
import torch
|
| 14 |
+
from rich.console import Console
|
| 15 |
+
from torch.nn.utils.rnn import pad_sequence
|
| 16 |
+
|
| 17 |
+
from searshorai.model import GPT, GPTConfig
|
| 18 |
+
from searshorai.tokenizer import TextTokenizer
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
console = Console()
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
@dataclass
|
| 25 |
+
class Example:
|
| 26 |
+
input_ids: list[int]
|
| 27 |
+
labels: list[int]
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
def parse_args() -> argparse.Namespace:
|
| 31 |
+
parser = argparse.ArgumentParser(description="Stable supervised fine-tune for paragraph explanation.")
|
| 32 |
+
|
| 33 |
+
parser.add_argument("--base_checkpoint", type=Path, default=Path("runs/wikitext-gpt/best.pt"))
|
| 34 |
+
parser.add_argument("--tokenizer", type=Path, default=Path("data/wikitext103/tokenizer.json"))
|
| 35 |
+
parser.add_argument("--sft_file", type=Path, default=Path("data/wikitext103/paragraph_sft.jsonl"))
|
| 36 |
+
parser.add_argument("--out_dir", type=Path, default=Path("runs/paragraph-explainer"))
|
| 37 |
+
|
| 38 |
+
parser.add_argument("--max_steps", type=int, default=8000)
|
| 39 |
+
parser.add_argument("--batch_size", type=int, default=8)
|
| 40 |
+
parser.add_argument("--grad_accum", type=int, default=8)
|
| 41 |
+
parser.add_argument("--learning_rate", type=float, default=2e-5)
|
| 42 |
+
parser.add_argument("--min_lr", type=float, default=2e-6)
|
| 43 |
+
parser.add_argument("--warmup_steps", type=int, default=300)
|
| 44 |
+
parser.add_argument("--weight_decay", type=float, default=0.01)
|
| 45 |
+
parser.add_argument("--grad_clip", type=float, default=1.0)
|
| 46 |
+
|
| 47 |
+
parser.add_argument("--max_answer_tokens", type=int, default=220)
|
| 48 |
+
parser.add_argument("--min_answer_tokens", type=int, default=8)
|
| 49 |
+
parser.add_argument("--val_ratio", type=float, default=0.02)
|
| 50 |
+
|
| 51 |
+
parser.add_argument("--eval_interval", type=int, default=250)
|
| 52 |
+
parser.add_argument("--eval_batches", type=int, default=40)
|
| 53 |
+
parser.add_argument("--save_interval", type=int, default=500)
|
| 54 |
+
parser.add_argument("--log_interval", type=int, default=20)
|
| 55 |
+
parser.add_argument("--seed", type=int, default=1337)
|
| 56 |
+
parser.add_argument("--compile", action="store_true")
|
| 57 |
+
parser.add_argument("--resume", type=Path, default=None)
|
| 58 |
+
|
| 59 |
+
return parser.parse_args()
|
| 60 |
+
|
| 61 |
+
|
| 62 |
+
def clean_text(text: Any) -> str:
|
| 63 |
+
if text is None:
|
| 64 |
+
return ""
|
| 65 |
+
text = str(text)
|
| 66 |
+
text = text.replace("\ufffd", " ")
|
| 67 |
+
text = unicodedata.normalize("NFKC", text)
|
| 68 |
+
text = "".join(ch if (ch in ("\n", "\t") or ord(ch) >= 32) else " " for ch in text)
|
| 69 |
+
text = "\n".join(" ".join(line.split()) for line in text.splitlines())
|
| 70 |
+
return text.strip()
|
| 71 |
+
|
| 72 |
+
|
| 73 |
+
def get_special_id(tok: TextTokenizer, name: str) -> int | None:
|
| 74 |
+
value = getattr(tok, name, None)
|
| 75 |
+
return int(value) if isinstance(value, int) else None
|
| 76 |
+
|
| 77 |
+
|
| 78 |
+
def ensure_eos(ids: list[int], eos_id: int | None) -> list[int]:
|
| 79 |
+
if eos_id is None:
|
| 80 |
+
return ids
|
| 81 |
+
if not ids or ids[-1] != eos_id:
|
| 82 |
+
return ids + [eos_id]
|
| 83 |
+
return ids
|
| 84 |
+
|
| 85 |
+
|
| 86 |
+
def get_lr(step: int, args: argparse.Namespace) -> float:
|
| 87 |
+
if step < args.warmup_steps:
|
| 88 |
+
return args.learning_rate * (step + 1) / max(1, args.warmup_steps)
|
| 89 |
+
ratio = (step - args.warmup_steps) / max(1, args.max_steps - args.warmup_steps)
|
| 90 |
+
coeff = 0.5 * (1.0 + math.cos(math.pi * min(1.0, max(0.0, ratio))))
|
| 91 |
+
return args.min_lr + coeff * (args.learning_rate - args.min_lr)
|
| 92 |
+
|
| 93 |
+
|
| 94 |
+
def read_prompt_answer(row: dict[str, Any]) -> tuple[str, str]:
|
| 95 |
+
"""
|
| 96 |
+
Supports these JSONL styles:
|
| 97 |
+
{"prompt": "...", "answer": "..."}
|
| 98 |
+
{"input": "...", "output": "..."}
|
| 99 |
+
{"paragraph": "...", "explanation": "..."}
|
| 100 |
+
{"text": "...", "answer": "..."}
|
| 101 |
+
"""
|
| 102 |
+
if "prompt" in row:
|
| 103 |
+
prompt = row.get("prompt", "")
|
| 104 |
+
elif "paragraph" in row:
|
| 105 |
+
prompt = f"Explain this paragraph in simple words:\n\n{row.get('paragraph', '')}\n\nExplanation:\n"
|
| 106 |
+
elif "text" in row:
|
| 107 |
+
prompt = f"Explain this paragraph in simple words:\n\n{row.get('text', '')}\n\nExplanation:\n"
|
| 108 |
+
else:
|
| 109 |
+
prompt = row.get("input", "")
|
| 110 |
+
|
| 111 |
+
answer = (
|
| 112 |
+
row.get("answer")
|
| 113 |
+
if row.get("answer") is not None
|
| 114 |
+
else row.get("output")
|
| 115 |
+
if row.get("output") is not None
|
| 116 |
+
else row.get("explanation", "")
|
| 117 |
+
)
|
| 118 |
+
|
| 119 |
+
return clean_text(prompt), clean_text(answer)
|
| 120 |
+
|
| 121 |
+
|
| 122 |
+
def load_examples(path: Path, tok: TextTokenizer, block_size: int, args: argparse.Namespace) -> list[Example]:
|
| 123 |
+
if not path.exists():
|
| 124 |
+
raise FileNotFoundError(f"SFT file not found: {path}")
|
| 125 |
+
|
| 126 |
+
eos_id = get_special_id(tok, "eos_id")
|
| 127 |
+
examples: list[Example] = []
|
| 128 |
+
|
| 129 |
+
skipped_empty = 0
|
| 130 |
+
skipped_too_short = 0
|
| 131 |
+
truncated_answers = 0
|
| 132 |
+
bad_json = 0
|
| 133 |
+
|
| 134 |
+
with path.open("r", encoding="utf-8", errors="replace") as f:
|
| 135 |
+
for line in f:
|
| 136 |
+
line = line.strip()
|
| 137 |
+
if not line:
|
| 138 |
+
continue
|
| 139 |
+
try:
|
| 140 |
+
row = json.loads(line)
|
| 141 |
+
except json.JSONDecodeError:
|
| 142 |
+
bad_json += 1
|
| 143 |
+
continue
|
| 144 |
+
|
| 145 |
+
prompt, answer = read_prompt_answer(row)
|
| 146 |
+
if not prompt or not answer:
|
| 147 |
+
skipped_empty += 1
|
| 148 |
+
continue
|
| 149 |
+
|
| 150 |
+
prompt_ids = tok.encode(prompt, add_bos=True, add_eos=False)
|
| 151 |
+
|
| 152 |
+
# Encode answer without EOS, then add EOS after any truncation.
|
| 153 |
+
answer_ids = tok.encode(answer, add_bos=False, add_eos=False)
|
| 154 |
+
if len(answer_ids) < args.min_answer_tokens:
|
| 155 |
+
skipped_too_short += 1
|
| 156 |
+
continue
|
| 157 |
+
if len(answer_ids) > args.max_answer_tokens:
|
| 158 |
+
answer_ids = answer_ids[: args.max_answer_tokens]
|
| 159 |
+
truncated_answers += 1
|
| 160 |
+
answer_ids = ensure_eos(answer_ids, eos_id)
|
| 161 |
+
|
| 162 |
+
# full_ids must fit in block_size + 1 (we shift to get input/target).
|
| 163 |
+
room_for_prompt = (block_size + 1) - len(answer_ids)
|
| 164 |
+
if room_for_prompt < 16:
|
| 165 |
+
# Answer is huge - cut it further but keep EOS at the end.
|
| 166 |
+
keep = max(16, block_size - 32)
|
| 167 |
+
answer_ids = answer_ids[: keep - 1]
|
| 168 |
+
answer_ids = ensure_eos(answer_ids, eos_id)
|
| 169 |
+
room_for_prompt = (block_size + 1) - len(answer_ids)
|
| 170 |
+
|
| 171 |
+
# Keep the tail of the prompt if it is too long.
|
| 172 |
+
if len(prompt_ids) > room_for_prompt:
|
| 173 |
+
# Preserve BOS at position 0 by keeping BOS + tail of body.
|
| 174 |
+
bos = [prompt_ids[0]] if prompt_ids and prompt_ids[0] == tok.bos_id else []
|
| 175 |
+
tail = prompt_ids[-(room_for_prompt - len(bos)) :] if room_for_prompt - len(bos) > 0 else []
|
| 176 |
+
prompt_ids = bos + tail
|
| 177 |
+
|
| 178 |
+
full_ids = prompt_ids + answer_ids
|
| 179 |
+
|
| 180 |
+
if len(full_ids) > block_size + 1:
|
| 181 |
+
# Final hard cap. If we have to cut, keep EOS as the last target token.
|
| 182 |
+
full_ids = full_ids[: block_size + 1]
|
| 183 |
+
if eos_id is not None and full_ids[-1] != eos_id:
|
| 184 |
+
full_ids[-1] = eos_id
|
| 185 |
+
|
| 186 |
+
if len(full_ids) < 16:
|
| 187 |
+
skipped_too_short += 1
|
| 188 |
+
continue
|
| 189 |
+
|
| 190 |
+
input_ids = full_ids[:-1]
|
| 191 |
+
next_ids = full_ids[1:]
|
| 192 |
+
|
| 193 |
+
# Loss only on answer tokens (including the final EOS target).
|
| 194 |
+
prompt_len = len(prompt_ids)
|
| 195 |
+
labels = [
|
| 196 |
+
token_id if (position + 1) >= prompt_len else -100
|
| 197 |
+
for position, token_id in enumerate(next_ids)
|
| 198 |
+
]
|
| 199 |
+
|
| 200 |
+
if any(x != -100 for x in labels):
|
| 201 |
+
examples.append(Example(input_ids=input_ids, labels=labels))
|
| 202 |
+
|
| 203 |
+
console.print(
|
| 204 |
+
f"Loaded {len(examples):,} examples | "
|
| 205 |
+
f"empty={skipped_empty:,}, short={skipped_too_short:,}, "
|
| 206 |
+
f"truncated_answers={truncated_answers:,}, bad_json={bad_json:,}"
|
| 207 |
+
)
|
| 208 |
+
if len(examples) < 10:
|
| 209 |
+
raise RuntimeError("Too few valid SFT examples. Check your JSONL keys and tokenizer.")
|
| 210 |
+
return examples
|
| 211 |
+
|
| 212 |
+
|
| 213 |
+
def make_batch(
|
| 214 |
+
examples: list[Example],
|
| 215 |
+
batch_size: int,
|
| 216 |
+
pad_id: int,
|
| 217 |
+
device: str,
|
| 218 |
+
block_size: int,
|
| 219 |
+
):
|
| 220 |
+
if len(examples) >= batch_size:
|
| 221 |
+
batch = random.sample(examples, batch_size)
|
| 222 |
+
else:
|
| 223 |
+
batch = random.choices(examples, k=batch_size)
|
| 224 |
+
|
| 225 |
+
xs = []
|
| 226 |
+
ys = []
|
| 227 |
+
for ex in batch:
|
| 228 |
+
ix = ex.input_ids[:block_size]
|
| 229 |
+
ly = ex.labels[:block_size]
|
| 230 |
+
xs.append(torch.tensor(ix, dtype=torch.long))
|
| 231 |
+
ys.append(torch.tensor(ly, dtype=torch.long))
|
| 232 |
+
|
| 233 |
+
x = pad_sequence(xs, batch_first=True, padding_value=pad_id)
|
| 234 |
+
y = pad_sequence(ys, batch_first=True, padding_value=-100)
|
| 235 |
+
|
| 236 |
+
if device == "cuda":
|
| 237 |
+
x = x.pin_memory().to(device, non_blocking=True)
|
| 238 |
+
y = y.pin_memory().to(device, non_blocking=True)
|
| 239 |
+
else:
|
| 240 |
+
x = x.to(device)
|
| 241 |
+
y = y.to(device)
|
| 242 |
+
return x, y
|
| 243 |
+
|
| 244 |
+
|
| 245 |
+
@torch.no_grad()
|
| 246 |
+
def evaluate(model, examples, args, pad_id, device, autocast_ctx, block_size) -> float:
|
| 247 |
+
model.eval()
|
| 248 |
+
losses: list[float] = []
|
| 249 |
+
for _ in range(args.eval_batches):
|
| 250 |
+
x, y = make_batch(examples, args.batch_size, pad_id, device, block_size)
|
| 251 |
+
with autocast_ctx:
|
| 252 |
+
_, loss = model(x, y)
|
| 253 |
+
if torch.isfinite(loss):
|
| 254 |
+
losses.append(float(loss.item()))
|
| 255 |
+
model.train()
|
| 256 |
+
return sum(losses) / max(1, len(losses))
|
| 257 |
+
|
| 258 |
+
|
| 259 |
+
def strip_compile_prefix(state_dict: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
|
| 260 |
+
cleaned = {}
|
| 261 |
+
for key, value in state_dict.items():
|
| 262 |
+
if key.startswith("_orig_mod."):
|
| 263 |
+
key = key[len("_orig_mod.") :]
|
| 264 |
+
cleaned[key] = value
|
| 265 |
+
return cleaned
|
| 266 |
+
|
| 267 |
+
|
| 268 |
+
def save_checkpoint(
|
| 269 |
+
path: Path,
|
| 270 |
+
model,
|
| 271 |
+
optimizer,
|
| 272 |
+
args: argparse.Namespace,
|
| 273 |
+
step: int,
|
| 274 |
+
best_val_loss: float,
|
| 275 |
+
meta: dict[str, Any],
|
| 276 |
+
) -> None:
|
| 277 |
+
raw_model = model._orig_mod if hasattr(model, "_orig_mod") else model
|
| 278 |
+
meta = dict(meta or {})
|
| 279 |
+
meta.update(
|
| 280 |
+
{
|
| 281 |
+
"task": "paragraph_explainer_sft",
|
| 282 |
+
"tokenizer": str(args.tokenizer),
|
| 283 |
+
"sft_file": str(args.sft_file),
|
| 284 |
+
"important": "Prompt tokens are masked; answer is EOS-safe truncated.",
|
| 285 |
+
}
|
| 286 |
+
)
|
| 287 |
+
torch.save(
|
| 288 |
+
{
|
| 289 |
+
"model": raw_model.state_dict(),
|
| 290 |
+
"optimizer": optimizer.state_dict(),
|
| 291 |
+
"args": {k: (str(v) if isinstance(v, Path) else v) for k, v in vars(args).items()},
|
| 292 |
+
"config": vars(raw_model.config),
|
| 293 |
+
"step": step,
|
| 294 |
+
"best_val_loss": best_val_loss,
|
| 295 |
+
"meta": meta,
|
| 296 |
+
},
|
| 297 |
+
path,
|
| 298 |
+
)
|
| 299 |
+
|
| 300 |
+
|
| 301 |
+
def main() -> None:
|
| 302 |
+
args = parse_args()
|
| 303 |
+
args.out_dir.mkdir(parents=True, exist_ok=True)
|
| 304 |
+
|
| 305 |
+
random.seed(args.seed)
|
| 306 |
+
torch.manual_seed(args.seed)
|
| 307 |
+
if torch.cuda.is_available():
|
| 308 |
+
torch.cuda.manual_seed_all(args.seed)
|
| 309 |
+
torch.backends.cuda.matmul.allow_tf32 = True
|
| 310 |
+
torch.backends.cudnn.allow_tf32 = True
|
| 311 |
+
|
| 312 |
+
device = "cuda" if torch.cuda.is_available() else "cpu"
|
| 313 |
+
device_type = "cuda" if device == "cuda" else "cpu"
|
| 314 |
+
|
| 315 |
+
if device == "cuda" and torch.cuda.is_bf16_supported():
|
| 316 |
+
amp_dtype = torch.bfloat16
|
| 317 |
+
console.print("AMP dtype: bfloat16")
|
| 318 |
+
elif device == "cuda":
|
| 319 |
+
amp_dtype = torch.float16
|
| 320 |
+
console.print("AMP dtype: float16")
|
| 321 |
+
else:
|
| 322 |
+
amp_dtype = torch.float32
|
| 323 |
+
console.print("AMP disabled on CPU")
|
| 324 |
+
|
| 325 |
+
autocast_ctx = torch.amp.autocast(
|
| 326 |
+
device_type=device_type,
|
| 327 |
+
dtype=amp_dtype,
|
| 328 |
+
enabled=(device == "cuda"),
|
| 329 |
+
)
|
| 330 |
+
|
| 331 |
+
tok = TextTokenizer(args.tokenizer)
|
| 332 |
+
pad_id = int(getattr(tok, "pad_id", 0))
|
| 333 |
+
|
| 334 |
+
if args.resume is not None:
|
| 335 |
+
ckpt_path = args.resume
|
| 336 |
+
console.print(f"Resuming SFT checkpoint: {ckpt_path}")
|
| 337 |
+
else:
|
| 338 |
+
ckpt_path = args.base_checkpoint
|
| 339 |
+
console.print(f"Starting from base checkpoint: {ckpt_path}")
|
| 340 |
+
|
| 341 |
+
ckpt = torch.load(ckpt_path, map_location=device, weights_only=False)
|
| 342 |
+
|
| 343 |
+
config = GPTConfig(**ckpt["config"])
|
| 344 |
+
# Force disable dropout for stable SFT (already 0.0 in pretrain).
|
| 345 |
+
config.dropout = 0.0
|
| 346 |
+
model = GPT(config)
|
| 347 |
+
state_dict = strip_compile_prefix(ckpt["model"])
|
| 348 |
+
model.load_state_dict(state_dict, strict=True)
|
| 349 |
+
model.to(device)
|
| 350 |
+
|
| 351 |
+
# Sanity check: tokenizer and model vocab must match.
|
| 352 |
+
if tok.vocab_size != model.config.vocab_size:
|
| 353 |
+
raise RuntimeError(
|
| 354 |
+
f"Tokenizer vocab_size {tok.vocab_size} != model vocab_size {model.config.vocab_size}. "
|
| 355 |
+
"This is the most common cause of garbled output. Use the same tokenizer that produced the pretrain data."
|
| 356 |
+
)
|
| 357 |
+
|
| 358 |
+
optimizer = model.configure_optimizers(
|
| 359 |
+
args.weight_decay,
|
| 360 |
+
args.learning_rate,
|
| 361 |
+
(0.9, 0.95),
|
| 362 |
+
device_type,
|
| 363 |
+
)
|
| 364 |
+
|
| 365 |
+
start_step = 0
|
| 366 |
+
best_val_loss = float("inf")
|
| 367 |
+
if args.resume is not None and "optimizer" in ckpt:
|
| 368 |
+
try:
|
| 369 |
+
optimizer.load_state_dict(ckpt["optimizer"])
|
| 370 |
+
start_step = int(ckpt.get("step", 0)) + 1
|
| 371 |
+
best_val_loss = float(ckpt.get("best_val_loss", float("inf")))
|
| 372 |
+
console.print(f"Resume from step {start_step}, previous best val {best_val_loss:.4f}")
|
| 373 |
+
except Exception as exc:
|
| 374 |
+
console.print(f"[yellow]Could not load optimizer state, starting fresh: {exc}[/yellow]")
|
| 375 |
+
|
| 376 |
+
try:
|
| 377 |
+
scaler = torch.amp.GradScaler("cuda", enabled=(device == "cuda" and amp_dtype == torch.float16))
|
| 378 |
+
except TypeError:
|
| 379 |
+
scaler = torch.cuda.amp.GradScaler(enabled=(device == "cuda" and amp_dtype == torch.float16))
|
| 380 |
+
|
| 381 |
+
examples = load_examples(args.sft_file, tok, model.config.block_size, args)
|
| 382 |
+
random.shuffle(examples)
|
| 383 |
+
|
| 384 |
+
val_size = max(1, int(len(examples) * args.val_ratio))
|
| 385 |
+
val_examples = examples[:val_size]
|
| 386 |
+
train_examples = examples[val_size:]
|
| 387 |
+
if not train_examples:
|
| 388 |
+
raise RuntimeError("No training examples after split.")
|
| 389 |
+
|
| 390 |
+
console.print(
|
| 391 |
+
f"Train={len(train_examples):,} | Val={len(val_examples):,} | "
|
| 392 |
+
f"Block size={model.config.block_size} | Device={device}"
|
| 393 |
+
)
|
| 394 |
+
|
| 395 |
+
if args.compile:
|
| 396 |
+
console.print("Compiling model with torch.compile...")
|
| 397 |
+
model = torch.compile(model)
|
| 398 |
+
|
| 399 |
+
model.train()
|
| 400 |
+
block_size = model.config.block_size if not hasattr(model, "_orig_mod") else model._orig_mod.config.block_size
|
| 401 |
+
last_time = time.time()
|
| 402 |
+
last_step = start_step
|
| 403 |
+
|
| 404 |
+
for step in range(start_step, args.max_steps + 1):
|
| 405 |
+
lr = get_lr(step, args)
|
| 406 |
+
for group in optimizer.param_groups:
|
| 407 |
+
group["lr"] = lr
|
| 408 |
+
|
| 409 |
+
optimizer.zero_grad(set_to_none=True)
|
| 410 |
+
loss_accum = 0.0
|
| 411 |
+
ok_micro_steps = 0
|
| 412 |
+
|
| 413 |
+
for _ in range(args.grad_accum):
|
| 414 |
+
x, y = make_batch(train_examples, args.batch_size, pad_id, device, block_size)
|
| 415 |
+
with autocast_ctx:
|
| 416 |
+
_, loss = model(x, y)
|
| 417 |
+
loss = loss / args.grad_accum
|
| 418 |
+
if not torch.isfinite(loss):
|
| 419 |
+
console.print(f"[yellow]Skipping non-finite loss at step {step}[/yellow]")
|
| 420 |
+
continue
|
| 421 |
+
scaler.scale(loss).backward()
|
| 422 |
+
loss_accum += float(loss.item())
|
| 423 |
+
ok_micro_steps += 1
|
| 424 |
+
|
| 425 |
+
if ok_micro_steps == 0:
|
| 426 |
+
scaler.update()
|
| 427 |
+
continue
|
| 428 |
+
|
| 429 |
+
scaler.unscale_(optimizer)
|
| 430 |
+
torch.nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip)
|
| 431 |
+
scaler.step(optimizer)
|
| 432 |
+
scaler.update()
|
| 433 |
+
|
| 434 |
+
if step % args.log_interval == 0:
|
| 435 |
+
now = time.time()
|
| 436 |
+
steps_done = max(1, step - last_step)
|
| 437 |
+
console.print(
|
| 438 |
+
f"step {step:6d} | loss {loss_accum:.4f} | "
|
| 439 |
+
f"lr {lr:.2e} | {(now - last_time) / steps_done:.2f}s/step"
|
| 440 |
+
)
|
| 441 |
+
last_time = now
|
| 442 |
+
last_step = step
|
| 443 |
+
|
| 444 |
+
if step > 0 and (step % args.eval_interval == 0 or step == args.max_steps):
|
| 445 |
+
val_loss = evaluate(model, val_examples, args, pad_id, device, autocast_ctx, block_size)
|
| 446 |
+
console.print(f"eval step {step}: val {val_loss:.4f}")
|
| 447 |
+
if val_loss < best_val_loss:
|
| 448 |
+
best_val_loss = val_loss
|
| 449 |
+
save_checkpoint(
|
| 450 |
+
args.out_dir / "best.pt",
|
| 451 |
+
model,
|
| 452 |
+
optimizer,
|
| 453 |
+
args,
|
| 454 |
+
step,
|
| 455 |
+
best_val_loss,
|
| 456 |
+
ckpt.get("meta", {}),
|
| 457 |
+
)
|
| 458 |
+
console.print(f"[green]saved best checkpoint: {best_val_loss:.4f}[/green]")
|
| 459 |
+
|
| 460 |
+
if step > 0 and step % args.save_interval == 0:
|
| 461 |
+
save_checkpoint(
|
| 462 |
+
args.out_dir / "latest.pt",
|
| 463 |
+
model,
|
| 464 |
+
optimizer,
|
| 465 |
+
args,
|
| 466 |
+
step,
|
| 467 |
+
best_val_loss,
|
| 468 |
+
ckpt.get("meta", {}),
|
| 469 |
+
)
|
| 470 |
+
|
| 471 |
+
save_checkpoint(
|
| 472 |
+
args.out_dir / "latest.pt",
|
| 473 |
+
model,
|
| 474 |
+
optimizer,
|
| 475 |
+
args,
|
| 476 |
+
args.max_steps,
|
| 477 |
+
best_val_loss,
|
| 478 |
+
ckpt.get("meta", {}),
|
| 479 |
+
)
|
| 480 |
+
console.print("Fine-tuning complete.")
|
| 481 |
+
|
| 482 |
+
|
| 483 |
+
if __name__ == "__main__":
|
| 484 |
+
main()
|
code/make_cnndm_sft.py
ADDED
|
@@ -0,0 +1,130 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
import argparse
|
| 4 |
+
import json
|
| 5 |
+
import random
|
| 6 |
+
import re
|
| 7 |
+
import unicodedata
|
| 8 |
+
from pathlib import Path
|
| 9 |
+
|
| 10 |
+
from datasets import load_dataset
|
| 11 |
+
from tqdm import tqdm
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
_WS = re.compile(r"\s+")
|
| 15 |
+
_BAD_CHARS = re.compile(r"[\u0000-\u001f]")
|
| 16 |
+
_REFS = re.compile(r"\[\s*\d+\s*\]")
|
| 17 |
+
# CNN/DailyMail articles often start with "(CNN) -- " or "By . SOMEBODY . PUBLISHED:"
|
| 18 |
+
_CNN_PREFIX = re.compile(r"^\s*\(CNN\)\s*--\s*", re.IGNORECASE)
|
| 19 |
+
_BYLINE = re.compile(r"^\s*By\s+\.\s+.*?PUBLISHED:.*?\s*\.\s*", re.IGNORECASE | re.DOTALL)
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
PROMPT_TEMPLATES = [
|
| 23 |
+
"Read the article and write a short summary.\n\nArticle:\n{passage}\n\nSummary:\n",
|
| 24 |
+
"Summarize the following article in a few sentences.\n\nArticle:\n{passage}\n\nShort summary:\n",
|
| 25 |
+
"Below is a news article. Give a concise summary using key facts from the article.\n\nArticle:\n{passage}\n\nSummary:\n",
|
| 26 |
+
"Provide a short summary of the article below.\n\nArticle:\n{passage}\n\nSummary:\n",
|
| 27 |
+
]
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
def normalize(text: str) -> str:
|
| 31 |
+
if text is None:
|
| 32 |
+
return ""
|
| 33 |
+
text = str(text)
|
| 34 |
+
text = text.replace("\ufffd", " ")
|
| 35 |
+
text = unicodedata.normalize("NFKC", text)
|
| 36 |
+
text = _BAD_CHARS.sub(" ", text)
|
| 37 |
+
text = _REFS.sub("", text)
|
| 38 |
+
text = _CNN_PREFIX.sub("", text)
|
| 39 |
+
text = _BYLINE.sub("", text)
|
| 40 |
+
text = _WS.sub(" ", text).strip()
|
| 41 |
+
return text
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
def join_highlights(highlights: str) -> str:
|
| 45 |
+
"""
|
| 46 |
+
CNN/DailyMail highlights come as several short lines joined by newlines.
|
| 47 |
+
We join them into a single multi-sentence string with periods.
|
| 48 |
+
"""
|
| 49 |
+
if not highlights:
|
| 50 |
+
return ""
|
| 51 |
+
pieces = [p.strip() for p in highlights.split("\n") if p.strip()]
|
| 52 |
+
# Make sure each piece ends with terminal punctuation.
|
| 53 |
+
fixed = []
|
| 54 |
+
for p in pieces:
|
| 55 |
+
if p[-1] not in ".!?":
|
| 56 |
+
p = p + "."
|
| 57 |
+
fixed.append(p)
|
| 58 |
+
return " ".join(fixed)
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
def is_good_pair(article: str, summary: str, min_article_chars: int, max_article_chars: int,
|
| 62 |
+
min_summary_chars: int, max_summary_chars: int) -> bool:
|
| 63 |
+
if not article or not summary:
|
| 64 |
+
return False
|
| 65 |
+
if not (min_article_chars <= len(article) <= max_article_chars):
|
| 66 |
+
return False
|
| 67 |
+
if not (min_summary_chars <= len(summary) <= max_summary_chars):
|
| 68 |
+
return False
|
| 69 |
+
# Reject if the summary is basically the whole article (rare here but safe).
|
| 70 |
+
if len(summary) >= 0.8 * len(article):
|
| 71 |
+
return False
|
| 72 |
+
# Must be mostly letters.
|
| 73 |
+
letters = sum(ch.isalpha() for ch in article)
|
| 74 |
+
if letters / max(1, len(article)) < 0.6:
|
| 75 |
+
return False
|
| 76 |
+
return True
|
| 77 |
+
|
| 78 |
+
|
| 79 |
+
def main() -> None:
|
| 80 |
+
parser = argparse.ArgumentParser(
|
| 81 |
+
description="Build an SFT set from CNN/DailyMail (near-extractive summaries)."
|
| 82 |
+
)
|
| 83 |
+
parser.add_argument("--out_file", type=Path, default=Path("data/wikitext103/paragraph_sft.jsonl"))
|
| 84 |
+
parser.add_argument("--dataset", type=str, default="abisee/cnn_dailymail")
|
| 85 |
+
parser.add_argument("--config", type=str, default="3.0.0")
|
| 86 |
+
parser.add_argument("--max_examples", type=int, default=100000)
|
| 87 |
+
parser.add_argument("--min_article_chars", type=int, default=400)
|
| 88 |
+
parser.add_argument("--max_article_chars", type=int, default=2200)
|
| 89 |
+
parser.add_argument("--min_summary_chars", type=int, default=80)
|
| 90 |
+
parser.add_argument("--max_summary_chars", type=int, default=400)
|
| 91 |
+
parser.add_argument("--seed", type=int, default=1337)
|
| 92 |
+
args = parser.parse_args()
|
| 93 |
+
|
| 94 |
+
args.out_file.parent.mkdir(parents=True, exist_ok=True)
|
| 95 |
+
rng = random.Random(args.seed)
|
| 96 |
+
|
| 97 |
+
print(f"Loading {args.dataset} ({args.config})...")
|
| 98 |
+
dataset = load_dataset(args.dataset, args.config, split="train")
|
| 99 |
+
|
| 100 |
+
count = 0
|
| 101 |
+
skipped = 0
|
| 102 |
+
|
| 103 |
+
with args.out_file.open("w", encoding="utf-8") as f:
|
| 104 |
+
for row in tqdm(dataset, desc="building SFT"):
|
| 105 |
+
article = normalize(row.get("article", ""))
|
| 106 |
+
summary = join_highlights(normalize(row.get("highlights", "")))
|
| 107 |
+
|
| 108 |
+
if not is_good_pair(
|
| 109 |
+
article, summary,
|
| 110 |
+
args.min_article_chars, args.max_article_chars,
|
| 111 |
+
args.min_summary_chars, args.max_summary_chars,
|
| 112 |
+
):
|
| 113 |
+
skipped += 1
|
| 114 |
+
continue
|
| 115 |
+
|
| 116 |
+
if len(article) > args.max_article_chars:
|
| 117 |
+
article = article[: args.max_article_chars].rsplit(" ", 1)[0]
|
| 118 |
+
|
| 119 |
+
template = rng.choice(PROMPT_TEMPLATES)
|
| 120 |
+
prompt = template.format(passage=article)
|
| 121 |
+
f.write(json.dumps({"prompt": prompt, "answer": summary}, ensure_ascii=False) + "\n")
|
| 122 |
+
count += 1
|
| 123 |
+
if count >= args.max_examples:
|
| 124 |
+
break
|
| 125 |
+
|
| 126 |
+
print(f"Wrote {count:,} examples to {args.out_file} (skipped={skipped:,})")
|
| 127 |
+
|
| 128 |
+
|
| 129 |
+
if __name__ == "__main__":
|
| 130 |
+
main()
|
code/model.py
ADDED
|
@@ -0,0 +1,189 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
import inspect
|
| 4 |
+
import math
|
| 5 |
+
from dataclasses import dataclass
|
| 6 |
+
|
| 7 |
+
import torch
|
| 8 |
+
import torch.nn as nn
|
| 9 |
+
import torch.utils.checkpoint
|
| 10 |
+
from torch.nn import functional as F
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
@dataclass
|
| 14 |
+
class GPTConfig:
|
| 15 |
+
vocab_size: int
|
| 16 |
+
block_size: int = 512
|
| 17 |
+
n_layer: int = 12
|
| 18 |
+
n_head: int = 12
|
| 19 |
+
n_embd: int = 768
|
| 20 |
+
dropout: float = 0.0
|
| 21 |
+
bias: bool = False
|
| 22 |
+
gradient_checkpointing: bool = False
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
class LayerNorm(nn.Module):
|
| 26 |
+
def __init__(self, ndim: int, bias: bool):
|
| 27 |
+
super().__init__()
|
| 28 |
+
self.weight = nn.Parameter(torch.ones(ndim))
|
| 29 |
+
self.bias = nn.Parameter(torch.zeros(ndim)) if bias else None
|
| 30 |
+
|
| 31 |
+
def forward(self, input: torch.Tensor) -> torch.Tensor:
|
| 32 |
+
return F.layer_norm(input, self.weight.shape, self.weight, self.bias, 1e-5)
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
class CausalSelfAttention(nn.Module):
|
| 36 |
+
def __init__(self, config: GPTConfig):
|
| 37 |
+
super().__init__()
|
| 38 |
+
assert config.n_embd % config.n_head == 0
|
| 39 |
+
self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd, bias=config.bias)
|
| 40 |
+
self.c_proj = nn.Linear(config.n_embd, config.n_embd, bias=config.bias)
|
| 41 |
+
self.attn_dropout = config.dropout
|
| 42 |
+
self.resid_dropout = nn.Dropout(config.dropout)
|
| 43 |
+
self.n_head = config.n_head
|
| 44 |
+
self.n_embd = config.n_embd
|
| 45 |
+
self.dropout = config.dropout
|
| 46 |
+
|
| 47 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
| 48 |
+
batch, seq_len, channels = x.size()
|
| 49 |
+
q, k, v = self.c_attn(x).split(self.n_embd, dim=2)
|
| 50 |
+
head_dim = channels // self.n_head
|
| 51 |
+
q = q.view(batch, seq_len, self.n_head, head_dim).transpose(1, 2)
|
| 52 |
+
k = k.view(batch, seq_len, self.n_head, head_dim).transpose(1, 2)
|
| 53 |
+
v = v.view(batch, seq_len, self.n_head, head_dim).transpose(1, 2)
|
| 54 |
+
|
| 55 |
+
y = F.scaled_dot_product_attention(
|
| 56 |
+
q,
|
| 57 |
+
k,
|
| 58 |
+
v,
|
| 59 |
+
attn_mask=None,
|
| 60 |
+
dropout_p=self.attn_dropout if self.training else 0.0,
|
| 61 |
+
is_causal=True,
|
| 62 |
+
)
|
| 63 |
+
y = y.transpose(1, 2).contiguous().view(batch, seq_len, channels)
|
| 64 |
+
return self.resid_dropout(self.c_proj(y))
|
| 65 |
+
|
| 66 |
+
|
| 67 |
+
class MLP(nn.Module):
|
| 68 |
+
def __init__(self, config: GPTConfig):
|
| 69 |
+
super().__init__()
|
| 70 |
+
self.c_fc = nn.Linear(config.n_embd, 4 * config.n_embd, bias=config.bias)
|
| 71 |
+
self.gelu = nn.GELU(approximate="tanh")
|
| 72 |
+
self.c_proj = nn.Linear(4 * config.n_embd, config.n_embd, bias=config.bias)
|
| 73 |
+
self.dropout = nn.Dropout(config.dropout)
|
| 74 |
+
|
| 75 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
| 76 |
+
return self.dropout(self.c_proj(self.gelu(self.c_fc(x))))
|
| 77 |
+
|
| 78 |
+
|
| 79 |
+
class Block(nn.Module):
|
| 80 |
+
def __init__(self, config: GPTConfig):
|
| 81 |
+
super().__init__()
|
| 82 |
+
self.ln_1 = LayerNorm(config.n_embd, bias=config.bias)
|
| 83 |
+
self.attn = CausalSelfAttention(config)
|
| 84 |
+
self.ln_2 = LayerNorm(config.n_embd, bias=config.bias)
|
| 85 |
+
self.mlp = MLP(config)
|
| 86 |
+
|
| 87 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
| 88 |
+
x = x + self.attn(self.ln_1(x))
|
| 89 |
+
x = x + self.mlp(self.ln_2(x))
|
| 90 |
+
return x
|
| 91 |
+
|
| 92 |
+
|
| 93 |
+
class GPT(nn.Module):
|
| 94 |
+
def __init__(self, config: GPTConfig):
|
| 95 |
+
super().__init__()
|
| 96 |
+
self.config = config
|
| 97 |
+
self.transformer = nn.ModuleDict(
|
| 98 |
+
{
|
| 99 |
+
"wte": nn.Embedding(config.vocab_size, config.n_embd),
|
| 100 |
+
"wpe": nn.Embedding(config.block_size, config.n_embd),
|
| 101 |
+
"drop": nn.Dropout(config.dropout),
|
| 102 |
+
"h": nn.ModuleList([Block(config) for _ in range(config.n_layer)]),
|
| 103 |
+
"ln_f": LayerNorm(config.n_embd, bias=config.bias),
|
| 104 |
+
}
|
| 105 |
+
)
|
| 106 |
+
self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
|
| 107 |
+
self.transformer.wte.weight = self.lm_head.weight
|
| 108 |
+
self.apply(self._init_weights)
|
| 109 |
+
for name, param in self.named_parameters():
|
| 110 |
+
if name.endswith("c_proj.weight"):
|
| 111 |
+
torch.nn.init.normal_(param, mean=0.0, std=0.02 / math.sqrt(2 * config.n_layer))
|
| 112 |
+
|
| 113 |
+
def _init_weights(self, module: nn.Module) -> None:
|
| 114 |
+
if isinstance(module, nn.Linear):
|
| 115 |
+
torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
|
| 116 |
+
if module.bias is not None:
|
| 117 |
+
torch.nn.init.zeros_(module.bias)
|
| 118 |
+
elif isinstance(module, nn.Embedding):
|
| 119 |
+
torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
|
| 120 |
+
|
| 121 |
+
def forward(
|
| 122 |
+
self, idx: torch.Tensor, targets: torch.Tensor | None = None
|
| 123 |
+
) -> tuple[torch.Tensor, torch.Tensor | None]:
|
| 124 |
+
batch, seq_len = idx.size()
|
| 125 |
+
if seq_len > self.config.block_size:
|
| 126 |
+
raise ValueError(f"Sequence length {seq_len} exceeds block size {self.config.block_size}")
|
| 127 |
+
|
| 128 |
+
pos = torch.arange(0, seq_len, dtype=torch.long, device=idx.device)
|
| 129 |
+
x = self.transformer.drop(self.transformer.wte(idx) + self.transformer.wpe(pos))
|
| 130 |
+
for block in self.transformer.h:
|
| 131 |
+
if self.config.gradient_checkpointing and self.training:
|
| 132 |
+
x = torch.utils.checkpoint.checkpoint(block, x, use_reentrant=False)
|
| 133 |
+
else:
|
| 134 |
+
x = block(x)
|
| 135 |
+
x = self.transformer.ln_f(x)
|
| 136 |
+
|
| 137 |
+
if targets is None:
|
| 138 |
+
logits = self.lm_head(x[:, [-1], :])
|
| 139 |
+
loss = None
|
| 140 |
+
else:
|
| 141 |
+
logits = self.lm_head(x)
|
| 142 |
+
loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-100)
|
| 143 |
+
return logits, loss
|
| 144 |
+
|
| 145 |
+
@torch.no_grad()
|
| 146 |
+
def generate(
|
| 147 |
+
self,
|
| 148 |
+
idx: torch.Tensor,
|
| 149 |
+
max_new_tokens: int,
|
| 150 |
+
temperature: float = 0.8,
|
| 151 |
+
top_k: int | None = 50,
|
| 152 |
+
eos_id: int | None = None,
|
| 153 |
+
) -> torch.Tensor:
|
| 154 |
+
for _ in range(max_new_tokens):
|
| 155 |
+
idx_cond = idx if idx.size(1) <= self.config.block_size else idx[:, -self.config.block_size :]
|
| 156 |
+
logits, _ = self(idx_cond)
|
| 157 |
+
logits = logits[:, -1, :] / max(temperature, 1e-5)
|
| 158 |
+
if top_k is not None:
|
| 159 |
+
v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
|
| 160 |
+
logits[logits < v[:, [-1]]] = -float("Inf")
|
| 161 |
+
probs = F.softmax(logits, dim=-1)
|
| 162 |
+
idx_next = torch.multinomial(probs, num_samples=1)
|
| 163 |
+
idx = torch.cat((idx, idx_next), dim=1)
|
| 164 |
+
if eos_id is not None and idx_next.item() == eos_id:
|
| 165 |
+
break
|
| 166 |
+
return idx
|
| 167 |
+
|
| 168 |
+
def crop_block_size(self, block_size: int) -> None:
|
| 169 |
+
assert block_size <= self.config.block_size
|
| 170 |
+
self.config.block_size = block_size
|
| 171 |
+
self.transformer.wpe.weight = nn.Parameter(self.transformer.wpe.weight[:block_size])
|
| 172 |
+
|
| 173 |
+
def configure_optimizers(
|
| 174 |
+
self, weight_decay: float, learning_rate: float, betas: tuple[float, float], device_type: str
|
| 175 |
+
) -> torch.optim.Optimizer:
|
| 176 |
+
param_dict = {pn: p for pn, p in self.named_parameters() if p.requires_grad}
|
| 177 |
+
decay_params = [p for _, p in param_dict.items() if p.dim() >= 2]
|
| 178 |
+
nodecay_params = [p for _, p in param_dict.items() if p.dim() < 2]
|
| 179 |
+
optim_groups = [
|
| 180 |
+
{"params": decay_params, "weight_decay": weight_decay},
|
| 181 |
+
{"params": nodecay_params, "weight_decay": 0.0},
|
| 182 |
+
]
|
| 183 |
+
fused_available = "fused" in inspect.signature(torch.optim.AdamW).parameters
|
| 184 |
+
use_fused = fused_available and device_type == "cuda"
|
| 185 |
+
extra_args = {"fused": True} if use_fused else {}
|
| 186 |
+
return torch.optim.AdamW(optim_groups, lr=learning_rate, betas=betas, **extra_args)
|
| 187 |
+
|
| 188 |
+
def num_parameters(self) -> int:
|
| 189 |
+
return sum(p.numel() for p in self.parameters())
|
code/prepare_wikitext.py
ADDED
|
@@ -0,0 +1,151 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
import argparse
|
| 4 |
+
import json
|
| 5 |
+
from pathlib import Path
|
| 6 |
+
|
| 7 |
+
import numpy as np
|
| 8 |
+
from datasets import load_dataset
|
| 9 |
+
from tokenizers import ByteLevelBPETokenizer, Tokenizer
|
| 10 |
+
from tokenizers import decoders as _decoders
|
| 11 |
+
from tqdm import tqdm
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
SPECIAL_TOKENS = ["<pad>", "<bos>", "<eos>", "<unk>"]
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
def clean_lines(dataset):
|
| 18 |
+
for row in dataset:
|
| 19 |
+
text = row["text"].strip()
|
| 20 |
+
if text:
|
| 21 |
+
yield text
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
class _TokenizerAdapter:
|
| 25 |
+
"""
|
| 26 |
+
Small adapter so the rest of the script can call .encode(text).ids and
|
| 27 |
+
.get_vocab() / .get_vocab_size() regardless of whether the tokenizer was
|
| 28 |
+
freshly trained (ByteLevelBPETokenizer) or reloaded from JSON (Tokenizer).
|
| 29 |
+
"""
|
| 30 |
+
|
| 31 |
+
def __init__(self, tokenizer):
|
| 32 |
+
self._t = tokenizer
|
| 33 |
+
|
| 34 |
+
def encode(self, text: str):
|
| 35 |
+
return self._t.encode(text)
|
| 36 |
+
|
| 37 |
+
def get_vocab(self):
|
| 38 |
+
return self._t.get_vocab()
|
| 39 |
+
|
| 40 |
+
def get_vocab_size(self):
|
| 41 |
+
return self._t.get_vocab_size()
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
def load_or_train_tokenizer(tokenizer_path: Path, train_dataset, vocab_size: int, min_frequency: int):
|
| 45 |
+
if tokenizer_path.exists():
|
| 46 |
+
print(f"Using existing tokenizer at {tokenizer_path}")
|
| 47 |
+
# Reload via the generic Tokenizer class. ByteLevelBPETokenizer does NOT
|
| 48 |
+
# accept tokenizer_file= in current tokenizers releases.
|
| 49 |
+
t = Tokenizer.from_file(str(tokenizer_path))
|
| 50 |
+
# Make sure a ByteLevel decoder is attached so downstream decoding works.
|
| 51 |
+
try:
|
| 52 |
+
current_decoder = t.decoder
|
| 53 |
+
except Exception:
|
| 54 |
+
current_decoder = None
|
| 55 |
+
if current_decoder is None:
|
| 56 |
+
t.decoder = _decoders.ByteLevel()
|
| 57 |
+
return _TokenizerAdapter(t)
|
| 58 |
+
|
| 59 |
+
print("Training byte-level BPE tokenizer...")
|
| 60 |
+
t = ByteLevelBPETokenizer()
|
| 61 |
+
t.train_from_iterator(
|
| 62 |
+
clean_lines(train_dataset),
|
| 63 |
+
vocab_size=vocab_size,
|
| 64 |
+
min_frequency=min_frequency,
|
| 65 |
+
special_tokens=SPECIAL_TOKENS,
|
| 66 |
+
)
|
| 67 |
+
t.save(str(tokenizer_path))
|
| 68 |
+
# Reopen via generic Tokenizer so we attach a decoder consistently.
|
| 69 |
+
reopened = Tokenizer.from_file(str(tokenizer_path))
|
| 70 |
+
try:
|
| 71 |
+
current_decoder = reopened.decoder
|
| 72 |
+
except Exception:
|
| 73 |
+
current_decoder = None
|
| 74 |
+
if current_decoder is None:
|
| 75 |
+
reopened.decoder = _decoders.ByteLevel()
|
| 76 |
+
return _TokenizerAdapter(reopened)
|
| 77 |
+
|
| 78 |
+
|
| 79 |
+
def write_split(tokenizer, dataset, out_file: Path, dtype, bos_id: int, eos_id: int) -> int:
|
| 80 |
+
token_count = 0
|
| 81 |
+
with out_file.open("wb") as f:
|
| 82 |
+
for text in tqdm(clean_lines(dataset), desc=f"tokenizing {out_file.name}"):
|
| 83 |
+
ids = [bos_id] + tokenizer.encode(text).ids + [eos_id]
|
| 84 |
+
arr = np.asarray(ids, dtype=dtype)
|
| 85 |
+
arr.tofile(f)
|
| 86 |
+
token_count += len(ids)
|
| 87 |
+
return token_count
|
| 88 |
+
|
| 89 |
+
|
| 90 |
+
def main() -> None:
|
| 91 |
+
parser = argparse.ArgumentParser(
|
| 92 |
+
description="Download WikiText-103, train a tokenizer, and make binary token files."
|
| 93 |
+
)
|
| 94 |
+
parser.add_argument("--data_dir", type=Path, default=Path("data/wikitext103"))
|
| 95 |
+
parser.add_argument("--dataset", type=str, default="Salesforce/wikitext")
|
| 96 |
+
parser.add_argument("--config", type=str, default="wikitext-103-raw-v1")
|
| 97 |
+
parser.add_argument("--vocab_size", type=int, default=32000)
|
| 98 |
+
parser.add_argument("--min_frequency", type=int, default=2)
|
| 99 |
+
args = parser.parse_args()
|
| 100 |
+
|
| 101 |
+
args.data_dir.mkdir(parents=True, exist_ok=True)
|
| 102 |
+
tokenizer_path = args.data_dir / "tokenizer.json"
|
| 103 |
+
|
| 104 |
+
print("Loading WikiText-103...")
|
| 105 |
+
train = load_dataset(args.dataset, args.config, split="train")
|
| 106 |
+
val = load_dataset(args.dataset, args.config, split="validation")
|
| 107 |
+
test = load_dataset(args.dataset, args.config, split="test")
|
| 108 |
+
|
| 109 |
+
tokenizer = load_or_train_tokenizer(
|
| 110 |
+
tokenizer_path=tokenizer_path,
|
| 111 |
+
train_dataset=train,
|
| 112 |
+
vocab_size=args.vocab_size,
|
| 113 |
+
min_frequency=args.min_frequency,
|
| 114 |
+
)
|
| 115 |
+
|
| 116 |
+
vocab = tokenizer.get_vocab()
|
| 117 |
+
if "<bos>" not in vocab or "<eos>" not in vocab or "<pad>" not in vocab:
|
| 118 |
+
raise RuntimeError(
|
| 119 |
+
"Tokenizer is missing required special tokens (<pad>, <bos>, <eos>). "
|
| 120 |
+
"Delete data/wikitext103/tokenizer.json and re-run to retrain."
|
| 121 |
+
)
|
| 122 |
+
|
| 123 |
+
bos_id = vocab["<bos>"]
|
| 124 |
+
eos_id = vocab["<eos>"]
|
| 125 |
+
pad_id = vocab["<pad>"]
|
| 126 |
+
|
| 127 |
+
vocab_size = tokenizer.get_vocab_size()
|
| 128 |
+
dtype = np.uint16 if vocab_size <= np.iinfo(np.uint16).max else np.uint32
|
| 129 |
+
|
| 130 |
+
train_tokens = write_split(tokenizer, train, args.data_dir / "train.bin", dtype, bos_id, eos_id)
|
| 131 |
+
val_tokens = write_split(tokenizer, val, args.data_dir / "val.bin", dtype, bos_id, eos_id)
|
| 132 |
+
test_tokens = write_split(tokenizer, test, args.data_dir / "test.bin", dtype, bos_id, eos_id)
|
| 133 |
+
|
| 134 |
+
meta = {
|
| 135 |
+
"dataset": args.dataset,
|
| 136 |
+
"config": args.config,
|
| 137 |
+
"vocab_size": vocab_size,
|
| 138 |
+
"dtype": "uint16" if dtype == np.uint16 else "uint32",
|
| 139 |
+
"bos_id": bos_id,
|
| 140 |
+
"eos_id": eos_id,
|
| 141 |
+
"pad_id": pad_id,
|
| 142 |
+
"train_tokens": train_tokens,
|
| 143 |
+
"val_tokens": val_tokens,
|
| 144 |
+
"test_tokens": test_tokens,
|
| 145 |
+
}
|
| 146 |
+
(args.data_dir / "meta.json").write_text(json.dumps(meta, indent=2), encoding="utf-8")
|
| 147 |
+
print(f"Done. Wrote tokenizer and token files to {args.data_dir}")
|
| 148 |
+
|
| 149 |
+
|
| 150 |
+
if __name__ == "__main__":
|
| 151 |
+
main()
|
code/tokenizer.py
ADDED
|
@@ -0,0 +1,52 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
from pathlib import Path
|
| 4 |
+
|
| 5 |
+
from tokenizers import Tokenizer
|
| 6 |
+
from tokenizers import decoders as _decoders
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
class TextTokenizer:
|
| 10 |
+
"""
|
| 11 |
+
Wrapper around tokenizers.Tokenizer that guarantees a ByteLevel decoder
|
| 12 |
+
is attached. ByteLevelBPETokenizer saves a JSON without a decoder block,
|
| 13 |
+
so reloading via Tokenizer.from_file() yields a tokenizer whose .decode()
|
| 14 |
+
returns raw byte-level tokens (Ġ, ä) and replacement chars (�, �)
|
| 15 |
+
instead of proper UTF-8 text. We attach the decoder here so decode is
|
| 16 |
+
always correct.
|
| 17 |
+
"""
|
| 18 |
+
|
| 19 |
+
def __init__(self, path: str | Path):
|
| 20 |
+
self.path = Path(path)
|
| 21 |
+
self.tokenizer = Tokenizer.from_file(str(self.path))
|
| 22 |
+
|
| 23 |
+
# Force a ByteLevel decoder if one is not attached.
|
| 24 |
+
try:
|
| 25 |
+
current_decoder = self.tokenizer.decoder
|
| 26 |
+
except Exception:
|
| 27 |
+
current_decoder = None
|
| 28 |
+
if current_decoder is None:
|
| 29 |
+
self.tokenizer.decoder = _decoders.ByteLevel()
|
| 30 |
+
|
| 31 |
+
vocab = self.tokenizer.get_vocab()
|
| 32 |
+
self.pad_id = vocab.get("<pad>", 0)
|
| 33 |
+
self.bos_id = vocab.get("<bos>", 1)
|
| 34 |
+
self.eos_id = vocab.get("<eos>", 2)
|
| 35 |
+
self.unk_id = vocab.get("<unk>", 3)
|
| 36 |
+
self.vocab_size = self.tokenizer.get_vocab_size()
|
| 37 |
+
|
| 38 |
+
def encode(self, text: str, add_bos: bool = False, add_eos: bool = False) -> list[int]:
|
| 39 |
+
ids = self.tokenizer.encode(text).ids
|
| 40 |
+
if add_bos:
|
| 41 |
+
ids = [self.bos_id] + ids
|
| 42 |
+
if add_eos:
|
| 43 |
+
ids = ids + [self.eos_id]
|
| 44 |
+
return ids
|
| 45 |
+
|
| 46 |
+
def decode(self, ids: list[int], skip_special_tokens: bool = True) -> str:
|
| 47 |
+
if skip_special_tokens:
|
| 48 |
+
specials = {self.pad_id, self.bos_id, self.eos_id, self.unk_id}
|
| 49 |
+
ids = [int(i) for i in ids if int(i) not in specials]
|
| 50 |
+
else:
|
| 51 |
+
ids = [int(i) for i in ids]
|
| 52 |
+
return self.tokenizer.decode(ids, skip_special_tokens=skip_special_tokens)
|
code/train.py
ADDED
|
@@ -0,0 +1,630 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
import argparse
|
| 4 |
+
import json
|
| 5 |
+
import math
|
| 6 |
+
import os
|
| 7 |
+
import random
|
| 8 |
+
import time
|
| 9 |
+
from pathlib import Path
|
| 10 |
+
from typing import Any
|
| 11 |
+
|
| 12 |
+
import numpy as np
|
| 13 |
+
import torch
|
| 14 |
+
from rich.console import Console
|
| 15 |
+
|
| 16 |
+
from searshorai.model import GPT, GPTConfig
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
console = Console()
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
PRESETS = {
|
| 23 |
+
"quick_test": dict(
|
| 24 |
+
n_layer=6,
|
| 25 |
+
n_head=6,
|
| 26 |
+
n_embd=384,
|
| 27 |
+
block_size=256,
|
| 28 |
+
batch_size=8,
|
| 29 |
+
grad_accum=8,
|
| 30 |
+
max_steps=1000,
|
| 31 |
+
),
|
| 32 |
+
"gpu_16gb": dict(
|
| 33 |
+
n_layer=10,
|
| 34 |
+
n_head=10,
|
| 35 |
+
n_embd=640,
|
| 36 |
+
block_size=512,
|
| 37 |
+
batch_size=4,
|
| 38 |
+
grad_accum=16,
|
| 39 |
+
max_steps=20000,
|
| 40 |
+
),
|
| 41 |
+
"rtx3090_8h": dict(
|
| 42 |
+
n_layer=12,
|
| 43 |
+
n_head=12,
|
| 44 |
+
n_embd=768,
|
| 45 |
+
block_size=512,
|
| 46 |
+
batch_size=8,
|
| 47 |
+
grad_accum=16,
|
| 48 |
+
max_steps=20000,
|
| 49 |
+
),
|
| 50 |
+
"rtx3090_quality": dict(
|
| 51 |
+
n_layer=16,
|
| 52 |
+
n_head=16,
|
| 53 |
+
n_embd=1024,
|
| 54 |
+
block_size=512,
|
| 55 |
+
batch_size=4,
|
| 56 |
+
grad_accum=24,
|
| 57 |
+
max_steps=30000,
|
| 58 |
+
),
|
| 59 |
+
"gpu_40gb_quality": dict(
|
| 60 |
+
n_layer=20,
|
| 61 |
+
n_head=16,
|
| 62 |
+
n_embd=1024,
|
| 63 |
+
block_size=768,
|
| 64 |
+
batch_size=4,
|
| 65 |
+
grad_accum=32,
|
| 66 |
+
max_steps=40000,
|
| 67 |
+
),
|
| 68 |
+
}
|
| 69 |
+
|
| 70 |
+
|
| 71 |
+
def parse_args() -> argparse.Namespace:
|
| 72 |
+
parser = argparse.ArgumentParser(description="Train a GPT-style language model from scratch.")
|
| 73 |
+
|
| 74 |
+
parser.add_argument("--data_dir", type=Path, default=Path("data/wikitext103"))
|
| 75 |
+
parser.add_argument("--out_dir", type=Path, default=Path("runs/wikitext-gpt"))
|
| 76 |
+
|
| 77 |
+
parser.add_argument("--preset", choices=PRESETS.keys(), default="gpu_16gb")
|
| 78 |
+
|
| 79 |
+
parser.add_argument("--resume", type=Path, default=None)
|
| 80 |
+
parser.add_argument("--reset_optimizer", action="store_true")
|
| 81 |
+
parser.add_argument("--reset_step", action="store_true",
|
| 82 |
+
help="When resuming, restart step counter at 0 (useful when restarting a fresh schedule).")
|
| 83 |
+
|
| 84 |
+
parser.add_argument("--n_layer", type=int, default=None)
|
| 85 |
+
parser.add_argument("--n_head", type=int, default=None)
|
| 86 |
+
parser.add_argument("--n_embd", type=int, default=None)
|
| 87 |
+
parser.add_argument("--block_size", type=int, default=None)
|
| 88 |
+
|
| 89 |
+
parser.add_argument("--batch_size", type=int, default=None, help="Micro-batch size.")
|
| 90 |
+
parser.add_argument("--grad_accum", type=int, default=None)
|
| 91 |
+
parser.add_argument("--max_steps", type=int, default=None)
|
| 92 |
+
|
| 93 |
+
parser.add_argument("--learning_rate", type=float, default=2.5e-4)
|
| 94 |
+
parser.add_argument("--min_lr", type=float, default=2.5e-5)
|
| 95 |
+
parser.add_argument("--warmup_steps", type=int, default=1000)
|
| 96 |
+
parser.add_argument("--weight_decay", type=float, default=0.1)
|
| 97 |
+
parser.add_argument("--dropout", type=float, default=0.0)
|
| 98 |
+
parser.add_argument("--grad_clip", type=float, default=1.0)
|
| 99 |
+
|
| 100 |
+
parser.add_argument("--eval_interval", type=int, default=500)
|
| 101 |
+
parser.add_argument("--eval_iters", type=int, default=100)
|
| 102 |
+
parser.add_argument("--save_interval", type=int, default=1000)
|
| 103 |
+
parser.add_argument("--log_interval", type=int, default=20)
|
| 104 |
+
|
| 105 |
+
parser.add_argument("--seed", type=int, default=1337)
|
| 106 |
+
|
| 107 |
+
parser.add_argument("--device", type=str, default="auto", choices=["auto", "cuda", "cpu"])
|
| 108 |
+
parser.add_argument("--dtype", type=str, default="auto", choices=["auto", "float32", "float16", "bfloat16"])
|
| 109 |
+
|
| 110 |
+
parser.add_argument("--compile", action="store_true")
|
| 111 |
+
|
| 112 |
+
parser.add_argument("--gradient_checkpointing", action="store_true")
|
| 113 |
+
parser.add_argument(
|
| 114 |
+
"--no_gradient_checkpointing",
|
| 115 |
+
"--no-gradient-checkpointing",
|
| 116 |
+
action="store_true",
|
| 117 |
+
help="Disable checkpointing when resuming from a checkpoint that was trained with it.",
|
| 118 |
+
)
|
| 119 |
+
|
| 120 |
+
parser.add_argument("--eval_only", action="store_true")
|
| 121 |
+
parser.add_argument("--always_save_checkpoint", action="store_true")
|
| 122 |
+
parser.add_argument("--save_optimizer", action="store_true")
|
| 123 |
+
|
| 124 |
+
return parser.parse_args()
|
| 125 |
+
|
| 126 |
+
|
| 127 |
+
def apply_preset(args: argparse.Namespace) -> argparse.Namespace:
|
| 128 |
+
preset = PRESETS[args.preset]
|
| 129 |
+
for key, value in preset.items():
|
| 130 |
+
if getattr(args, key) is None:
|
| 131 |
+
setattr(args, key, value)
|
| 132 |
+
return args
|
| 133 |
+
|
| 134 |
+
|
| 135 |
+
def setup_reproducibility(seed: int) -> None:
|
| 136 |
+
random.seed(seed)
|
| 137 |
+
np.random.seed(seed)
|
| 138 |
+
torch.manual_seed(seed)
|
| 139 |
+
if torch.cuda.is_available():
|
| 140 |
+
torch.cuda.manual_seed_all(seed)
|
| 141 |
+
torch.backends.cuda.matmul.allow_tf32 = True
|
| 142 |
+
torch.backends.cudnn.allow_tf32 = True
|
| 143 |
+
torch.backends.cudnn.benchmark = True
|
| 144 |
+
|
| 145 |
+
|
| 146 |
+
def choose_device(args: argparse.Namespace) -> str:
|
| 147 |
+
if args.device == "auto":
|
| 148 |
+
return "cuda" if torch.cuda.is_available() else "cpu"
|
| 149 |
+
if args.device == "cuda" and not torch.cuda.is_available():
|
| 150 |
+
raise RuntimeError("CUDA was requested, but torch.cuda.is_available() is False.")
|
| 151 |
+
return args.device
|
| 152 |
+
|
| 153 |
+
|
| 154 |
+
def choose_dtype(args: argparse.Namespace, device: str) -> torch.dtype:
|
| 155 |
+
if device == "cpu":
|
| 156 |
+
return torch.float32
|
| 157 |
+
if args.dtype == "float32":
|
| 158 |
+
return torch.float32
|
| 159 |
+
if args.dtype == "float16":
|
| 160 |
+
return torch.float16
|
| 161 |
+
if args.dtype == "bfloat16":
|
| 162 |
+
if torch.cuda.is_bf16_supported():
|
| 163 |
+
return torch.bfloat16
|
| 164 |
+
console.print("[yellow]bfloat16 requested but not supported. Falling back to float16.[/yellow]")
|
| 165 |
+
return torch.float16
|
| 166 |
+
if torch.cuda.is_bf16_supported():
|
| 167 |
+
return torch.bfloat16
|
| 168 |
+
return torch.float16
|
| 169 |
+
|
| 170 |
+
|
| 171 |
+
def make_autocast_context(device: str, dtype: torch.dtype):
|
| 172 |
+
enabled = device == "cuda" and dtype in (torch.float16, torch.bfloat16)
|
| 173 |
+
return torch.amp.autocast(device_type=device, dtype=dtype, enabled=enabled)
|
| 174 |
+
|
| 175 |
+
|
| 176 |
+
def make_grad_scaler(device: str, dtype: torch.dtype):
|
| 177 |
+
enabled = device == "cuda" and dtype == torch.float16
|
| 178 |
+
try:
|
| 179 |
+
return torch.amp.GradScaler("cuda", enabled=enabled)
|
| 180 |
+
except TypeError:
|
| 181 |
+
return torch.cuda.amp.GradScaler(enabled=enabled)
|
| 182 |
+
|
| 183 |
+
|
| 184 |
+
def get_lr(step: int, args: argparse.Namespace) -> float:
|
| 185 |
+
if step < args.warmup_steps:
|
| 186 |
+
return args.learning_rate * step / max(1, args.warmup_steps)
|
| 187 |
+
if step > args.max_steps:
|
| 188 |
+
return args.min_lr
|
| 189 |
+
decay_ratio = (step - args.warmup_steps) / max(1, args.max_steps - args.warmup_steps)
|
| 190 |
+
decay_ratio = min(1.0, max(0.0, decay_ratio))
|
| 191 |
+
coeff = 0.5 * (1.0 + math.cos(math.pi * decay_ratio))
|
| 192 |
+
return args.min_lr + coeff * (args.learning_rate - args.min_lr)
|
| 193 |
+
|
| 194 |
+
|
| 195 |
+
def load_json(path: Path) -> dict[str, Any]:
|
| 196 |
+
if not path.exists():
|
| 197 |
+
raise FileNotFoundError(f"Missing required file: {path}")
|
| 198 |
+
return json.loads(path.read_text(encoding="utf-8"))
|
| 199 |
+
|
| 200 |
+
|
| 201 |
+
def validate_meta(meta: dict[str, Any]) -> None:
|
| 202 |
+
required_keys = ["vocab_size", "dtype"]
|
| 203 |
+
for key in required_keys:
|
| 204 |
+
if key not in meta:
|
| 205 |
+
raise KeyError(f"meta.json is missing required key: {key}")
|
| 206 |
+
if meta["dtype"] not in ("uint16", "uint32"):
|
| 207 |
+
raise ValueError(f"Unsupported meta dtype: {meta['dtype']}. Expected uint16 or uint32.")
|
| 208 |
+
if int(meta["vocab_size"]) <= 0:
|
| 209 |
+
raise ValueError("meta.json vocab_size must be greater than zero.")
|
| 210 |
+
if meta["dtype"] == "uint16" and int(meta["vocab_size"]) > 65535:
|
| 211 |
+
raise ValueError("meta dtype is uint16 but vocab_size is greater than 65535. Use uint32 data files.")
|
| 212 |
+
|
| 213 |
+
|
| 214 |
+
def load_memmap(path: Path, dtype: str) -> np.memmap:
|
| 215 |
+
if not path.exists():
|
| 216 |
+
raise FileNotFoundError(f"Missing required file: {path}")
|
| 217 |
+
np_dtype = np.uint16 if dtype == "uint16" else np.uint32
|
| 218 |
+
return np.memmap(path, dtype=np_dtype, mode="r")
|
| 219 |
+
|
| 220 |
+
|
| 221 |
+
def validate_dataset(train_data: np.memmap, val_data: np.memmap, block_size: int, vocab_size: int) -> None:
|
| 222 |
+
min_required = block_size + 2
|
| 223 |
+
if len(train_data) < min_required:
|
| 224 |
+
raise ValueError(
|
| 225 |
+
f"train.bin is too small. Need at least {min_required} tokens for block_size={block_size}, "
|
| 226 |
+
f"but got {len(train_data)}."
|
| 227 |
+
)
|
| 228 |
+
if len(val_data) < min_required:
|
| 229 |
+
raise ValueError(
|
| 230 |
+
f"val.bin is too small. Need at least {min_required} tokens for block_size={block_size}, "
|
| 231 |
+
f"but got {len(val_data)}."
|
| 232 |
+
)
|
| 233 |
+
|
| 234 |
+
sample_count = min(10000, len(train_data))
|
| 235 |
+
sample_positions = np.linspace(0, len(train_data) - 1, sample_count, dtype=np.int64)
|
| 236 |
+
sample = np.asarray(train_data[sample_positions], dtype=np.int64)
|
| 237 |
+
max_token = int(sample.max())
|
| 238 |
+
min_token = int(sample.min())
|
| 239 |
+
if min_token < 0:
|
| 240 |
+
raise ValueError(f"Dataset contains negative token id: {min_token}")
|
| 241 |
+
if max_token >= vocab_size:
|
| 242 |
+
raise ValueError(
|
| 243 |
+
f"Dataset token id {max_token} is >= vocab_size {vocab_size}. "
|
| 244 |
+
"This usually means tokenizer/meta/train.bin mismatch."
|
| 245 |
+
)
|
| 246 |
+
|
| 247 |
+
|
| 248 |
+
def get_batch(
|
| 249 |
+
data: np.memmap,
|
| 250 |
+
batch_size: int,
|
| 251 |
+
block_size: int,
|
| 252 |
+
device: str,
|
| 253 |
+
) -> tuple[torch.Tensor, torch.Tensor]:
|
| 254 |
+
"""
|
| 255 |
+
Fast batch loader: one vectorized gather, then a single host->device transfer.
|
| 256 |
+
The old code did batch_size python-level numpy slices per call, which was a
|
| 257 |
+
major bottleneck.
|
| 258 |
+
"""
|
| 259 |
+
max_start = len(data) - block_size - 1
|
| 260 |
+
if max_start <= 0:
|
| 261 |
+
raise ValueError("Dataset is too small for the configured block_size.")
|
| 262 |
+
|
| 263 |
+
# Random start positions.
|
| 264 |
+
ix = np.random.randint(0, max_start, size=(batch_size,), dtype=np.int64)
|
| 265 |
+
|
| 266 |
+
# Allocate contiguous int64 arrays. memmap reads are cheap for sequential blocks.
|
| 267 |
+
x_np = np.empty((batch_size, block_size), dtype=np.int64)
|
| 268 |
+
y_np = np.empty((batch_size, block_size), dtype=np.int64)
|
| 269 |
+
for row, start in enumerate(ix):
|
| 270 |
+
x_np[row] = data[start : start + block_size]
|
| 271 |
+
y_np[row] = data[start + 1 : start + 1 + block_size]
|
| 272 |
+
|
| 273 |
+
x = torch.from_numpy(x_np)
|
| 274 |
+
y = torch.from_numpy(y_np)
|
| 275 |
+
|
| 276 |
+
if device == "cuda":
|
| 277 |
+
x = x.pin_memory().to(device, non_blocking=True)
|
| 278 |
+
y = y.pin_memory().to(device, non_blocking=True)
|
| 279 |
+
else:
|
| 280 |
+
x = x.to(device)
|
| 281 |
+
y = y.to(device)
|
| 282 |
+
return x, y
|
| 283 |
+
|
| 284 |
+
|
| 285 |
+
@torch.no_grad()
|
| 286 |
+
def estimate_loss(
|
| 287 |
+
model: GPT,
|
| 288 |
+
train_data: np.memmap,
|
| 289 |
+
val_data: np.memmap,
|
| 290 |
+
args: argparse.Namespace,
|
| 291 |
+
device: str,
|
| 292 |
+
autocast_ctx,
|
| 293 |
+
) -> dict[str, float]:
|
| 294 |
+
out: dict[str, float] = {}
|
| 295 |
+
model.eval()
|
| 296 |
+
for split, data in [("train", train_data), ("val", val_data)]:
|
| 297 |
+
losses = []
|
| 298 |
+
for _ in range(args.eval_iters):
|
| 299 |
+
x, y = get_batch(data, args.batch_size, args.block_size, device)
|
| 300 |
+
with autocast_ctx:
|
| 301 |
+
_, loss = model(x, y)
|
| 302 |
+
if torch.isfinite(loss):
|
| 303 |
+
losses.append(float(loss.item()))
|
| 304 |
+
out[split] = float(sum(losses) / max(1, len(losses)))
|
| 305 |
+
model.train()
|
| 306 |
+
return out
|
| 307 |
+
|
| 308 |
+
|
| 309 |
+
def unwrap_model(model: GPT) -> GPT:
|
| 310 |
+
if hasattr(model, "_orig_mod"):
|
| 311 |
+
return model._orig_mod
|
| 312 |
+
return model
|
| 313 |
+
|
| 314 |
+
|
| 315 |
+
def strip_compile_prefix(state_dict: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
|
| 316 |
+
cleaned = {}
|
| 317 |
+
for key, value in state_dict.items():
|
| 318 |
+
if key.startswith("_orig_mod."):
|
| 319 |
+
key = key[len("_orig_mod.") :]
|
| 320 |
+
cleaned[key] = value
|
| 321 |
+
return cleaned
|
| 322 |
+
|
| 323 |
+
|
| 324 |
+
def optimizer_to_device(optimizer: torch.optim.Optimizer, device: str) -> None:
|
| 325 |
+
for state in optimizer.state.values():
|
| 326 |
+
for key, value in state.items():
|
| 327 |
+
if isinstance(value, torch.Tensor):
|
| 328 |
+
state[key] = value.to(device)
|
| 329 |
+
|
| 330 |
+
|
| 331 |
+
def save_checkpoint(
|
| 332 |
+
path: Path,
|
| 333 |
+
model: GPT,
|
| 334 |
+
optimizer: torch.optim.Optimizer | None,
|
| 335 |
+
args: argparse.Namespace,
|
| 336 |
+
step: int,
|
| 337 |
+
best_val_loss: float,
|
| 338 |
+
meta: dict[str, Any],
|
| 339 |
+
) -> None:
|
| 340 |
+
raw_model = unwrap_model(model)
|
| 341 |
+
checkpoint: dict[str, Any] = {
|
| 342 |
+
"model": raw_model.state_dict(),
|
| 343 |
+
"args": vars(args),
|
| 344 |
+
"config": vars(raw_model.config),
|
| 345 |
+
"step": step,
|
| 346 |
+
"best_val_loss": best_val_loss,
|
| 347 |
+
"meta": meta,
|
| 348 |
+
}
|
| 349 |
+
if args.save_optimizer and optimizer is not None:
|
| 350 |
+
checkpoint["optimizer"] = optimizer.state_dict()
|
| 351 |
+
torch.save(checkpoint, path)
|
| 352 |
+
|
| 353 |
+
|
| 354 |
+
def write_run_config(args: argparse.Namespace, meta: dict[str, Any], device: str, dtype: torch.dtype) -> None:
|
| 355 |
+
config_path = args.out_dir / "run_config.json"
|
| 356 |
+
payload = {
|
| 357 |
+
"args": {k: (str(v) if isinstance(v, Path) else v) for k, v in vars(args).items()},
|
| 358 |
+
"meta": meta,
|
| 359 |
+
"device": device,
|
| 360 |
+
"dtype": str(dtype),
|
| 361 |
+
"torch_version": torch.__version__,
|
| 362 |
+
"cuda_available": torch.cuda.is_available(),
|
| 363 |
+
"cuda_device_name": torch.cuda.get_device_name(0) if torch.cuda.is_available() else None,
|
| 364 |
+
}
|
| 365 |
+
config_path.write_text(json.dumps(payload, indent=2), encoding="utf-8")
|
| 366 |
+
|
| 367 |
+
|
| 368 |
+
def build_model_from_checkpoint(
|
| 369 |
+
ckpt_path: Path,
|
| 370 |
+
device: str,
|
| 371 |
+
args: argparse.Namespace,
|
| 372 |
+
) -> tuple[GPT, int, float, dict[str, Any]]:
|
| 373 |
+
ckpt = torch.load(ckpt_path, map_location=device, weights_only=False)
|
| 374 |
+
config = GPTConfig(**ckpt["config"])
|
| 375 |
+
if hasattr(config, "gradient_checkpointing"):
|
| 376 |
+
if args.no_gradient_checkpointing:
|
| 377 |
+
config.gradient_checkpointing = False
|
| 378 |
+
elif args.gradient_checkpointing:
|
| 379 |
+
config.gradient_checkpointing = True
|
| 380 |
+
model = GPT(config)
|
| 381 |
+
state_dict = strip_compile_prefix(ckpt["model"])
|
| 382 |
+
model.load_state_dict(state_dict, strict=True)
|
| 383 |
+
start_step = int(ckpt.get("step", 0))
|
| 384 |
+
best_val_loss = float(ckpt.get("best_val_loss", float("inf")))
|
| 385 |
+
checkpoint_meta = ckpt.get("meta", {})
|
| 386 |
+
return model, start_step, best_val_loss, checkpoint_meta
|
| 387 |
+
|
| 388 |
+
|
| 389 |
+
def build_new_model(meta: dict[str, Any], args: argparse.Namespace) -> tuple[GPT, int, float]:
|
| 390 |
+
config = GPTConfig(
|
| 391 |
+
vocab_size=int(meta["vocab_size"]),
|
| 392 |
+
block_size=int(args.block_size),
|
| 393 |
+
n_layer=int(args.n_layer),
|
| 394 |
+
n_head=int(args.n_head),
|
| 395 |
+
n_embd=int(args.n_embd),
|
| 396 |
+
dropout=float(args.dropout),
|
| 397 |
+
gradient_checkpointing=bool(args.gradient_checkpointing),
|
| 398 |
+
)
|
| 399 |
+
model = GPT(config)
|
| 400 |
+
return model, 0, float("inf")
|
| 401 |
+
|
| 402 |
+
|
| 403 |
+
def print_startup_info(
|
| 404 |
+
model: GPT,
|
| 405 |
+
args: argparse.Namespace,
|
| 406 |
+
device: str,
|
| 407 |
+
dtype: torch.dtype,
|
| 408 |
+
train_data: np.memmap,
|
| 409 |
+
val_data: np.memmap,
|
| 410 |
+
start_step: int,
|
| 411 |
+
) -> None:
|
| 412 |
+
raw_model = unwrap_model(model)
|
| 413 |
+
tokens_per_step = args.batch_size * args.grad_accum * args.block_size
|
| 414 |
+
if hasattr(raw_model, "num_parameters"):
|
| 415 |
+
num_params = raw_model.num_parameters()
|
| 416 |
+
else:
|
| 417 |
+
num_params = sum(p.numel() for p in raw_model.parameters())
|
| 418 |
+
|
| 419 |
+
console.print("")
|
| 420 |
+
console.print("[bold green]Training configuration[/bold green]")
|
| 421 |
+
console.print(f"Device: {device}")
|
| 422 |
+
console.print(f"Dtype: {dtype}")
|
| 423 |
+
console.print(f"Preset: {args.preset}")
|
| 424 |
+
console.print(f"Parameters: {num_params / 1e6:.2f}M")
|
| 425 |
+
console.print(f"Layers: {args.n_layer}")
|
| 426 |
+
console.print(f"Heads: {args.n_head}")
|
| 427 |
+
console.print(f"Embedding size: {args.n_embd}")
|
| 428 |
+
console.print(f"Block size: {args.block_size}")
|
| 429 |
+
console.print(f"Batch size: {args.batch_size}")
|
| 430 |
+
console.print(f"Grad accumulation: {args.grad_accum}")
|
| 431 |
+
console.print(f"Tokens per step: {tokens_per_step:,}")
|
| 432 |
+
console.print(f"Train tokens: {len(train_data):,}")
|
| 433 |
+
console.print(f"Val tokens: {len(val_data):,}")
|
| 434 |
+
console.print(f"Start step: {start_step:,}")
|
| 435 |
+
console.print(f"Max steps: {args.max_steps:,}")
|
| 436 |
+
console.print(f"Learning rate: {args.learning_rate:.2e}")
|
| 437 |
+
console.print(f"Min LR: {args.min_lr:.2e}")
|
| 438 |
+
console.print(f"Warmup steps: {args.warmup_steps:,}")
|
| 439 |
+
console.print(f"Grad clip: {args.grad_clip}")
|
| 440 |
+
console.print("")
|
| 441 |
+
|
| 442 |
+
|
| 443 |
+
def main() -> None:
|
| 444 |
+
args = apply_preset(parse_args())
|
| 445 |
+
args.out_dir.mkdir(parents=True, exist_ok=True)
|
| 446 |
+
setup_reproducibility(args.seed)
|
| 447 |
+
|
| 448 |
+
device = choose_device(args)
|
| 449 |
+
dtype = choose_dtype(args, device)
|
| 450 |
+
autocast_ctx = make_autocast_context(device, dtype)
|
| 451 |
+
scaler = make_grad_scaler(device, dtype)
|
| 452 |
+
|
| 453 |
+
meta_path = args.data_dir / "meta.json"
|
| 454 |
+
meta = load_json(meta_path)
|
| 455 |
+
validate_meta(meta)
|
| 456 |
+
|
| 457 |
+
train_data = load_memmap(args.data_dir / "train.bin", meta["dtype"])
|
| 458 |
+
val_data = load_memmap(args.data_dir / "val.bin", meta["dtype"])
|
| 459 |
+
validate_dataset(
|
| 460 |
+
train_data=train_data,
|
| 461 |
+
val_data=val_data,
|
| 462 |
+
block_size=int(args.block_size),
|
| 463 |
+
vocab_size=int(meta["vocab_size"]),
|
| 464 |
+
)
|
| 465 |
+
|
| 466 |
+
if args.resume is not None:
|
| 467 |
+
console.print(f"[yellow]Resuming from checkpoint:[/yellow] {args.resume}")
|
| 468 |
+
model, start_step, best_val_loss, checkpoint_meta = build_model_from_checkpoint(args.resume, device, args)
|
| 469 |
+
if checkpoint_meta:
|
| 470 |
+
meta = checkpoint_meta
|
| 471 |
+
else:
|
| 472 |
+
model, start_step, best_val_loss = build_new_model(meta, args)
|
| 473 |
+
|
| 474 |
+
if args.reset_step:
|
| 475 |
+
start_step = 0
|
| 476 |
+
best_val_loss = float("inf")
|
| 477 |
+
console.print("[yellow]reset_step set: step counter restarted at 0.[/yellow]")
|
| 478 |
+
|
| 479 |
+
model.to(device)
|
| 480 |
+
|
| 481 |
+
optimizer = model.configure_optimizers(
|
| 482 |
+
args.weight_decay,
|
| 483 |
+
args.learning_rate,
|
| 484 |
+
(0.9, 0.95),
|
| 485 |
+
"cuda" if device == "cuda" else "cpu",
|
| 486 |
+
)
|
| 487 |
+
|
| 488 |
+
if args.resume is not None and not args.reset_optimizer:
|
| 489 |
+
ckpt = torch.load(args.resume, map_location=device, weights_only=False)
|
| 490 |
+
if "optimizer" in ckpt:
|
| 491 |
+
try:
|
| 492 |
+
optimizer.load_state_dict(ckpt["optimizer"])
|
| 493 |
+
optimizer_to_device(optimizer, device)
|
| 494 |
+
console.print("[green]Loaded optimizer state from checkpoint.[/green]")
|
| 495 |
+
except Exception as exc:
|
| 496 |
+
console.print(f"[yellow]Could not load optimizer state. Continuing with fresh optimizer. Error: {exc}[/yellow]")
|
| 497 |
+
else:
|
| 498 |
+
console.print("[yellow]Checkpoint has no optimizer state. Continuing with fresh optimizer.[/yellow]")
|
| 499 |
+
elif args.resume is not None and args.reset_optimizer:
|
| 500 |
+
console.print("[yellow]reset_optimizer set: starting with fresh Adam moments.[/yellow]")
|
| 501 |
+
|
| 502 |
+
if args.compile:
|
| 503 |
+
console.print("[cyan]Compiling model...[/cyan]")
|
| 504 |
+
model = torch.compile(model)
|
| 505 |
+
|
| 506 |
+
write_run_config(args, meta, device, dtype)
|
| 507 |
+
print_startup_info(model, args, device, dtype, train_data, val_data, start_step)
|
| 508 |
+
|
| 509 |
+
if args.eval_only:
|
| 510 |
+
losses = estimate_loss(model, train_data, val_data, args, device, autocast_ctx)
|
| 511 |
+
console.print(f"eval only: train {losses['train']:.4f}, val {losses['val']:.4f}")
|
| 512 |
+
return
|
| 513 |
+
|
| 514 |
+
model.train()
|
| 515 |
+
tokens_per_step = args.batch_size * args.grad_accum * args.block_size
|
| 516 |
+
|
| 517 |
+
start_time = time.time()
|
| 518 |
+
last_log_time = start_time
|
| 519 |
+
last_log_step = start_step
|
| 520 |
+
|
| 521 |
+
for completed_step in range(start_step, args.max_steps):
|
| 522 |
+
step = completed_step + 1
|
| 523 |
+
|
| 524 |
+
lr = get_lr(step, args)
|
| 525 |
+
for param_group in optimizer.param_groups:
|
| 526 |
+
param_group["lr"] = lr
|
| 527 |
+
|
| 528 |
+
optimizer.zero_grad(set_to_none=True)
|
| 529 |
+
loss_accum = 0.0
|
| 530 |
+
skipped_micro = 0
|
| 531 |
+
|
| 532 |
+
for _ in range(args.grad_accum):
|
| 533 |
+
x, y = get_batch(train_data, args.batch_size, args.block_size, device)
|
| 534 |
+
with autocast_ctx:
|
| 535 |
+
_, loss = model(x, y)
|
| 536 |
+
loss = loss / args.grad_accum
|
| 537 |
+
if not torch.isfinite(loss):
|
| 538 |
+
console.print(f"[yellow]Non-finite loss at step {step}, skipping micro-batch.[/yellow]")
|
| 539 |
+
skipped_micro += 1
|
| 540 |
+
continue
|
| 541 |
+
scaler.scale(loss).backward()
|
| 542 |
+
loss_accum += float(loss.item())
|
| 543 |
+
|
| 544 |
+
if skipped_micro == args.grad_accum:
|
| 545 |
+
# Whole step was bad. Skip the optimizer update.
|
| 546 |
+
scaler.update()
|
| 547 |
+
continue
|
| 548 |
+
|
| 549 |
+
scaler.unscale_(optimizer)
|
| 550 |
+
grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip)
|
| 551 |
+
scaler.step(optimizer)
|
| 552 |
+
scaler.update()
|
| 553 |
+
|
| 554 |
+
if step % args.log_interval == 0 or step == start_step + 1:
|
| 555 |
+
now = time.time()
|
| 556 |
+
elapsed = max(now - last_log_time, 1e-9)
|
| 557 |
+
steps_done = max(1, step - last_log_step)
|
| 558 |
+
toks_per_sec = (tokens_per_step * steps_done) / elapsed
|
| 559 |
+
last_log_time = now
|
| 560 |
+
last_log_step = step
|
| 561 |
+
console.print(
|
| 562 |
+
f"step {step:7d} | "
|
| 563 |
+
f"loss {loss_accum:.4f} | "
|
| 564 |
+
f"lr {lr:.2e} | "
|
| 565 |
+
f"grad {float(grad_norm):.2f} | "
|
| 566 |
+
f"{toks_per_sec:,.0f} tok/s"
|
| 567 |
+
)
|
| 568 |
+
|
| 569 |
+
should_eval = step % args.eval_interval == 0 or step == args.max_steps
|
| 570 |
+
if should_eval:
|
| 571 |
+
losses = estimate_loss(model, train_data, val_data, args, device, autocast_ctx)
|
| 572 |
+
console.print(
|
| 573 |
+
f"[bold]eval step {step}:[/bold] "
|
| 574 |
+
f"train {losses['train']:.4f}, val {losses['val']:.4f}"
|
| 575 |
+
)
|
| 576 |
+
if losses["val"] < best_val_loss:
|
| 577 |
+
best_val_loss = losses["val"]
|
| 578 |
+
save_checkpoint(
|
| 579 |
+
args.out_dir / "best.pt",
|
| 580 |
+
model,
|
| 581 |
+
optimizer,
|
| 582 |
+
args,
|
| 583 |
+
step,
|
| 584 |
+
best_val_loss,
|
| 585 |
+
meta,
|
| 586 |
+
)
|
| 587 |
+
console.print(f"[green]saved best checkpoint: val {best_val_loss:.4f}[/green]")
|
| 588 |
+
if args.always_save_checkpoint:
|
| 589 |
+
save_checkpoint(
|
| 590 |
+
args.out_dir / f"step_{step}.pt",
|
| 591 |
+
model,
|
| 592 |
+
optimizer,
|
| 593 |
+
args,
|
| 594 |
+
step,
|
| 595 |
+
best_val_loss,
|
| 596 |
+
meta,
|
| 597 |
+
)
|
| 598 |
+
|
| 599 |
+
if step % args.save_interval == 0:
|
| 600 |
+
save_checkpoint(
|
| 601 |
+
args.out_dir / "latest.pt",
|
| 602 |
+
model,
|
| 603 |
+
optimizer,
|
| 604 |
+
args,
|
| 605 |
+
step,
|
| 606 |
+
best_val_loss,
|
| 607 |
+
meta,
|
| 608 |
+
)
|
| 609 |
+
console.print(f"[cyan]saved latest checkpoint at step {step}[/cyan]")
|
| 610 |
+
|
| 611 |
+
save_checkpoint(
|
| 612 |
+
args.out_dir / "latest.pt",
|
| 613 |
+
model,
|
| 614 |
+
optimizer,
|
| 615 |
+
args,
|
| 616 |
+
args.max_steps,
|
| 617 |
+
best_val_loss,
|
| 618 |
+
meta,
|
| 619 |
+
)
|
| 620 |
+
|
| 621 |
+
elapsed_hours = (time.time() - start_time) / 3600.0
|
| 622 |
+
console.print("")
|
| 623 |
+
console.print(f"[bold green]Finished in {elapsed_hours:.2f} hours.[/bold green]")
|
| 624 |
+
console.print(f"[bold green]Best validation loss: {best_val_loss:.4f}[/bold green]")
|
| 625 |
+
console.print(f"[bold green]Best checkpoint: {args.out_dir / 'best.pt'}[/bold green]")
|
| 626 |
+
console.print(f"[bold green]Latest checkpoint: {args.out_dir / 'latest.pt'}[/bold green]")
|
| 627 |
+
|
| 628 |
+
|
| 629 |
+
if __name__ == "__main__":
|
| 630 |
+
main()
|
config.json
ADDED
|
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"vocab_size": 32000,
|
| 3 |
+
"block_size": 512,
|
| 4 |
+
"n_layer": 12,
|
| 5 |
+
"n_head": 12,
|
| 6 |
+
"n_embd": 768,
|
| 7 |
+
"dropout": 0.0,
|
| 8 |
+
"bias": false,
|
| 9 |
+
"gradient_checkpointing": false,
|
| 10 |
+
"model_type": "ron-gpt",
|
| 11 |
+
"architectures": [
|
| 12 |
+
"GPT"
|
| 13 |
+
],
|
| 14 |
+
"torch_dtype": "float32"
|
| 15 |
+
}
|
meta.json
ADDED
|
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"dataset": "Salesforce/wikitext",
|
| 3 |
+
"config": "wikitext-103-raw-v1",
|
| 4 |
+
"vocab_size": 32000,
|
| 5 |
+
"dtype": "uint16",
|
| 6 |
+
"bos_id": 1,
|
| 7 |
+
"eos_id": 2,
|
| 8 |
+
"pad_id": 0,
|
| 9 |
+
"train_tokens": 115671965,
|
| 10 |
+
"val_tokens": 242485,
|
| 11 |
+
"test_tokens": 276246
|
| 12 |
+
}
|
pretrain.pt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:b2b0ce466438490a11b38092ed30993111987b58f6d8e08da64c262db1e0f476
|
| 3 |
+
size 1319159633
|
summarizer.pt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:4b665451d9bdb04205cafd88b0ef46a777204584cef3a037c3bd47f0598631e8
|
| 3 |
+
size 1319159633
|
tokenizer.json
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|