razor5050 commited on
Commit
ca2f8ca
·
verified ·
1 Parent(s): f15cc16

Add tokenizer, inference code, model card, and 20-query report

Browse files
20_query_inference_report.md ADDED
@@ -0,0 +1,90 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Tiny LLM 75M — 20 Query Inference Report
2
+
3
+ Repo: `razor5050/tinyllm-75m-openwebtext-chat`
4
+ Checkpoint: `final.pt`
5
+ SFT step: `781`
6
+ Parameters: `75,074,112`
7
+ State load: missing `0`, unexpected `0`
8
+ Elapsed: `218.86s`
9
+
10
+ ## Questions and answers
11
+
12
+ ### 1. What is photosynthesis?
13
+
14
+ Photosynthesis is a chemical process in which a molecule is formed and then into a form called a molecular structure. It is a chemical that is composed of the elements that are
15
+
16
+ ### 2. Why do leaves look green?
17
+
18
+ The word "green" is actually a word that is often used in various forms of environmental and environmental sciences. It is often used in various fields, such as agriculture, engineering,
19
+
20
+ ### 3. What is gravity?
21
+
22
+ Gravity is a phenomenon known as gravity. It occurs when a substance of gravity is formed in a shape or shape that is not present in a substance, and the substance is not
23
+
24
+ ### 4. Name one planet in our solar system.
25
+
26
+ The solar system was named from the first solar system, named by the name "Solar".
27
+
28
+ ### 5. What is the largest ocean on Earth?
29
+
30
+ The largest ocean on Earth is the Nusa-Eibagan, the largest ocean on Earth. It is a ocean in the southern Peninsula of the United States, and
31
+
32
+ ### 6. What gas do humans breathe in to survive?
33
+
34
+ A gas in water, the gas's core, is a fluid that is released through the water's chemical reactions, which can cause blood death and heart disease. A gas in water
35
+
36
+ ### 7. What does the heart do?
37
+
38
+ The heart is a system of organs that are designed to be designed to be designed to perform functions that perform functions that perform functions that perform functions that perform functions that perform functions that perform
39
+
40
+ ### 8. What is a volcano?
41
+
42
+ A volcano is a volcano that consists of a range of extreme events, such as volcanic eruptions, earthquakes, and earthquakes. It is a volcanic eruption that occurs within a range of
43
+
44
+ ### 9. What is the Moon?
45
+
46
+ The Moon is a vast and beautiful space that is packed with diverse cultures, diverse cultures, and diverse cultures. It is a vast, and well-designed space that can be accessed
47
+
48
+ ### 10. Why does rain fall from clouds?
49
+
50
+ The wind is a powerful force that can create storms and create storms. By rain fall, the wind causes the wind to move, which is the way the air is released from the
51
+
52
+ ### 11. What is electricity?
53
+
54
+ Electototototototototototototototototototototototototototototototototototot
55
+
56
+ ### 12. Name one continent.
57
+
58
+ A continent a continent in the America, United States, and United States.
59
+
60
+ ### 13. What is a river?
61
+
62
+ A river is a river that is a river that flows through the river. It is a river that is flowing through the river, making it a river that flows through rivers. It
63
+
64
+ ### 14. What animal is known as the king of the jungle?
65
+
66
+ The king of the jungle is the king of the jungle. This species is known as the "King of the jungle." The king of the jungle is the king of the jungle.
67
+
68
+ ### 15. What is the purpose of roots in a plant?
69
+
70
+ A plant is a plant that is a plant that is composed of a plant that is composed of a plant that is composed of a plant that is composed of a plant that is composed
71
+
72
+ ### 16. What is a computer?
73
+
74
+ A computer is a computer that is designed to monitor and monitor any specific data, including any relevant data, data, or data, or any relevant data, or any relevant data in
75
+
76
+ ### 17. What is the internet?
77
+
78
+ The Internet is a digital movement that spans various regions, each having its own unique and complex connections. The internet is a phenomenon that occurs through various methods, including internet-specific communication
79
+
80
+ ### 18. What is a vaccine?
81
+
82
+ A vaccine is an vaccine designed to promote the growth of the immune system, which regulates the immune system's activity, such as the immune system, immune cells, and immune cells
83
+
84
+ ### 19. What country is Tokyo in?
85
+
86
+ Japan is a world where Tokyo is known for its stunning water quality and its culture. It is a hub for the world's largest expesticive waste waste agency, providing a vital
87
+
88
+ ### 20. What is climate change?
89
+
90
+ Climate change is often associated with a significant increase in global temperature, associated with a rapid acceleration in the warming cycle, with a greater amount of heat and heat being added to the global
20_query_inference_results.json ADDED
@@ -0,0 +1,113 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "meta": {
3
+ "repo": "razor5050/tinyllm-75m-openwebtext-chat",
4
+ "checkpoint": "final.pt",
5
+ "step": 781,
6
+ "parameters": 75074112,
7
+ "missing_keys": 0,
8
+ "unexpected_keys": 0
9
+ },
10
+ "elapsed_seconds": 218.86,
11
+ "results": [
12
+ {
13
+ "n": 1,
14
+ "question": "What is photosynthesis?",
15
+ "answer": "Photosynthesis is a chemical process in which a molecule is formed and then into a form called a molecular structure. It is a chemical that is composed of the elements that are"
16
+ },
17
+ {
18
+ "n": 2,
19
+ "question": "Why do leaves look green?",
20
+ "answer": "The word \"green\" is actually a word that is often used in various forms of environmental and environmental sciences. It is often used in various fields, such as agriculture, engineering,"
21
+ },
22
+ {
23
+ "n": 3,
24
+ "question": "What is gravity?",
25
+ "answer": "Gravity is a phenomenon known as gravity. It occurs when a substance of gravity is formed in a shape or shape that is not present in a substance, and the substance is not"
26
+ },
27
+ {
28
+ "n": 4,
29
+ "question": "Name one planet in our solar system.",
30
+ "answer": "The solar system was named from the first solar system, named by the name \"Solar\"."
31
+ },
32
+ {
33
+ "n": 5,
34
+ "question": "What is the largest ocean on Earth?",
35
+ "answer": "The largest ocean on Earth is the Nusa-Eibagan, the largest ocean on Earth. It is a ocean in the southern Peninsula of the United States, and"
36
+ },
37
+ {
38
+ "n": 6,
39
+ "question": "What gas do humans breathe in to survive?",
40
+ "answer": "A gas in water, the gas's core, is a fluid that is released through the water's chemical reactions, which can cause blood death and heart disease. A gas in water"
41
+ },
42
+ {
43
+ "n": 7,
44
+ "question": "What does the heart do?",
45
+ "answer": "The heart is a system of organs that are designed to be designed to be designed to perform functions that perform functions that perform functions that perform functions that perform functions that perform functions that perform"
46
+ },
47
+ {
48
+ "n": 8,
49
+ "question": "What is a volcano?",
50
+ "answer": "A volcano is a volcano that consists of a range of extreme events, such as volcanic eruptions, earthquakes, and earthquakes. It is a volcanic eruption that occurs within a range of"
51
+ },
52
+ {
53
+ "n": 9,
54
+ "question": "What is the Moon?",
55
+ "answer": "The Moon is a vast and beautiful space that is packed with diverse cultures, diverse cultures, and diverse cultures. It is a vast, and well-designed space that can be accessed"
56
+ },
57
+ {
58
+ "n": 10,
59
+ "question": "Why does rain fall from clouds?",
60
+ "answer": "The wind is a powerful force that can create storms and create storms. By rain fall, the wind causes the wind to move, which is the way the air is released from the"
61
+ },
62
+ {
63
+ "n": 11,
64
+ "question": "What is electricity?",
65
+ "answer": "Electototototototototototototototototototototototototototototototototototot"
66
+ },
67
+ {
68
+ "n": 12,
69
+ "question": "Name one continent.",
70
+ "answer": "A continent a continent in the America, United States, and United States."
71
+ },
72
+ {
73
+ "n": 13,
74
+ "question": "What is a river?",
75
+ "answer": "A river is a river that is a river that flows through the river. It is a river that is flowing through the river, making it a river that flows through rivers. It"
76
+ },
77
+ {
78
+ "n": 14,
79
+ "question": "What animal is known as the king of the jungle?",
80
+ "answer": "The king of the jungle is the king of the jungle. This species is known as the \"King of the jungle.\" The king of the jungle is the king of the jungle."
81
+ },
82
+ {
83
+ "n": 15,
84
+ "question": "What is the purpose of roots in a plant?",
85
+ "answer": "A plant is a plant that is a plant that is composed of a plant that is composed of a plant that is composed of a plant that is composed of a plant that is composed"
86
+ },
87
+ {
88
+ "n": 16,
89
+ "question": "What is a computer?",
90
+ "answer": "A computer is a computer that is designed to monitor and monitor any specific data, including any relevant data, data, or data, or any relevant data, or any relevant data in"
91
+ },
92
+ {
93
+ "n": 17,
94
+ "question": "What is the internet?",
95
+ "answer": "The Internet is a digital movement that spans various regions, each having its own unique and complex connections. The internet is a phenomenon that occurs through various methods, including internet-specific communication"
96
+ },
97
+ {
98
+ "n": 18,
99
+ "question": "What is a vaccine?",
100
+ "answer": "A vaccine is an vaccine designed to promote the growth of the immune system, which regulates the immune system's activity, such as the immune system, immune cells, and immune cells"
101
+ },
102
+ {
103
+ "n": 19,
104
+ "question": "What country is Tokyo in?",
105
+ "answer": "Japan is a world where Tokyo is known for its stunning water quality and its culture. It is a hub for the world's largest expesticive waste waste agency, providing a vital"
106
+ },
107
+ {
108
+ "n": 20,
109
+ "question": "What is climate change?",
110
+ "answer": "Climate change is often associated with a significant increase in global temperature, associated with a rapid acceleration in the warming cycle, with a greater amount of heat and heat being added to the global"
111
+ }
112
+ ]
113
+ }
README.md ADDED
@@ -0,0 +1,193 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ language:
3
+ - en
4
+ license: mit
5
+ tags:
6
+ - tiny-llm
7
+ - causal-lm
8
+ - llama-like
9
+ - rope
10
+ - rmsnorm
11
+ - swiglu
12
+ - gqa
13
+ - openwebtext
14
+ - smoltalk
15
+ - pytorch
16
+ pipeline_tag: text-generation
17
+ library_name: pytorch
18
+ ---
19
+
20
+ # TinyLLM 75M OpenWebText Chat
21
+
22
+ This repository contains an experimental **75,074,112 parameter decoder-only tiny language model** trained from scratch/near-scratch and then supervised-finetuned for chat.
23
+
24
+ > **Important quality note:** This is a successful end-to-end training pipeline artifact and research toy model, not a production assistant. It can load and generate text, but factual accuracy, instruction following, arithmetic, and repetition control are weak.
25
+
26
+ ## Model summary
27
+
28
+ - **Model name:** `razor5050/tinyllm-75m-openwebtext-chat`
29
+ - **Architecture:** LLaMA/SmolLM-style decoder-only causal LM
30
+ - **Parameters:** 75,074,112
31
+ - **Context length:** 1024 tokens
32
+ - **Vocabulary:** 32,000 ByteLevel BPE tokens
33
+ - **Tokenizer:** custom ByteLevel BPE trained for this run
34
+ - **Checkpoint format:** PyTorch `.pt` checkpoints
35
+ - **Primary final checkpoint:** `final.pt`
36
+ - **Best checkpoint:** `best.pt`
37
+
38
+ ## Architecture
39
+
40
+ The model uses modern tiny-LM components:
41
+
42
+ - decoder-only causal Transformer
43
+ - RoPE positional embeddings
44
+ - RMSNorm
45
+ - SwiGLU MLP
46
+ - grouped-query/key-value reduction via fewer KV heads
47
+ - tied input/output token embeddings
48
+ - no attention/MLP bias
49
+ - PyTorch SDPA causal attention
50
+
51
+ Approximate config:
52
+
53
+ ```yaml
54
+ vocab_size: 32000
55
+ hidden_size: 576
56
+ num_hidden_layers: 16
57
+ num_attention_heads: 9
58
+ num_key_value_heads: 3
59
+ intermediate_size: 1536
60
+ max_position_embeddings: 1024
61
+ rope_theta: 10000.0
62
+ rms_norm_eps: 1e-5
63
+ tie_word_embeddings: true
64
+ attention_bias: false
65
+ mlp_bias: false
66
+ dropout: 0.0
67
+ ```
68
+
69
+ ## Training data
70
+
71
+ ### Base pretraining
72
+
73
+ - Dataset: [`Skylion007/openwebtext`](https://huggingface.co/datasets/Skylion007/openwebtext)
74
+ - Rows used: 1,000,000 selected rows
75
+ - Final tokenized train tokens: 1,143,301,833
76
+ - Final tokenized validation tokens: 34,486,473
77
+ - Epochs: 1
78
+ - Optimizer steps: 4,361
79
+
80
+ ### Chat/SFT
81
+
82
+ - Dataset: [`HuggingFaceTB/smol-smoltalk`](https://huggingface.co/datasets/HuggingFaceTB/smol-smoltalk)
83
+ - Train examples: 100,000
84
+ - Validation examples: 3,000
85
+ - Epochs: 1
86
+ - Optimizer steps: 781
87
+ - Loss masking: assistant-response tokens only
88
+
89
+ ## Training results
90
+
91
+ ### Pretraining
92
+
93
+ - Final/latest train loss near end: about `4.997`
94
+ - Latest validation loss: about `5.049` at step 4000
95
+
96
+ ### SFT
97
+
98
+ - SFT completed at step `781`
99
+ - Validation trend:
100
+ - step 250: `2.6031`
101
+ - step 500: `2.4505`
102
+ - step 750: `2.3313`
103
+
104
+ SFT improved chat formatting and response style, but the model remains very small and undertrained by modern assistant standards.
105
+
106
+ ## Hardware/run
107
+
108
+ - Cloud GPU: Vast.ai RTX 5070 Ti, 16GB VRAM
109
+ - Precision: CUDA/PyTorch mixed precision during training where supported
110
+ - Checkpointing: periodic `latest`, `best`, final, and step checkpoints
111
+ - Training artifacts were preserved separately outside the instance before teardown.
112
+
113
+ ## Files in this repo
114
+
115
+ - `final.pt` — final SFT checkpoint
116
+ - `best.pt` — best SFT checkpoint
117
+ - `latest.pt` — latest SFT checkpoint
118
+ - `metrics.jsonl` — SFT metrics
119
+ - `step_609.pt` — intermediate SFT checkpoint
120
+ - `tokenizer/vocab.json` and `tokenizer/merges.txt` — tokenizer files
121
+ - `configs/model_75m.yaml` — architecture config
122
+ - `src/tinyllm/` — minimal PyTorch model implementation
123
+ - `scripts/infer_tinyllm.py` — simple local inference helper
124
+
125
+ ## Quick inference
126
+
127
+ Clone/download the repo, install dependencies, then run:
128
+
129
+ ```bash
130
+ pip install torch tokenizers pyyaml huggingface_hub
131
+ python scripts/infer_tinyllm.py \
132
+ --checkpoint final.pt \
133
+ --prompt "What is the capital of France?"
134
+ ```
135
+
136
+ The chat prompt format used during SFT is:
137
+
138
+ ```text
139
+ <|system|>
140
+ You are a helpful, concise assistant.
141
+ <|end|>
142
+ <|user|>
143
+ USER_QUESTION
144
+ <|end|>
145
+ <|assistant|>
146
+ ```
147
+
148
+ ## Observed sample behavior
149
+
150
+ In a post-upload local inference test, the model generated text and loaded cleanly, but quality was mixed:
151
+
152
+ - Correct on: “What is the capital of France?” → answered Paris, with repetition.
153
+ - Weak on: simple science/world facts, often rambling or hallucinating.
154
+ - Weak on: arithmetic and short-answer discipline.
155
+ - Repetition and generic phrasing are common.
156
+
157
+ This is expected for a 75M-parameter scratch-trained model with about 1.14B pretraining tokens and one SFT pass.
158
+
159
+ ## Limitations
160
+
161
+ - Not suitable for factual QA or production use.
162
+ - Hallucinates frequently.
163
+ - Repetition loops occur.
164
+ - Arithmetic is unreliable.
165
+ - Safety behavior was not evaluated.
166
+ - Model is not aligned beyond basic supervised chat finetuning.
167
+ - The checkpoint is a custom PyTorch model, not a standard `transformers` model class.
168
+
169
+ ## Intended use
170
+
171
+ - Educational tiny-LLM experiment
172
+ - Pipeline validation
173
+ - Small-model architecture experimentation
174
+ - Baseline for future 150M+ runs
175
+
176
+ ## Recommended next steps
177
+
178
+ To improve quality meaningfully:
179
+
180
+ 1. Train a larger ~150M model.
181
+ 2. Use more unique pretraining tokens, e.g. ~5B+.
182
+ 3. Improve preprocessing/tokenization throughput with multiprocessing/sharding.
183
+ 4. Add stronger instruction data and possibly preference tuning.
184
+ 5. Export to a standard Hugging Face `transformers` compatible format.
185
+
186
+ ## Citation / attribution
187
+
188
+ Training datasets:
189
+
190
+ - `Skylion007/openwebtext`
191
+ - `HuggingFaceTB/smol-smoltalk`
192
+
193
+ This repository is an experimental model artifact from a custom tiny-LLM training pipeline.
configs/model_75m.yaml ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ model:
2
+ vocab_size: 32000
3
+ hidden_size: 576
4
+ num_hidden_layers: 16
5
+ num_attention_heads: 9
6
+ num_key_value_heads: 3
7
+ intermediate_size: 1536
8
+ max_position_embeddings: 1024
9
+ rope_theta: 10000.0
10
+ rms_norm_eps: 1.0e-5
11
+ tie_word_embeddings: true
12
+ attention_bias: false
13
+ mlp_bias: false
14
+ dropout: 0.0
15
+ bos_token_id: 1
16
+ eos_token_id: 2
scripts/infer_tinyllm.py ADDED
@@ -0,0 +1,70 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ import argparse, sys
3
+ from pathlib import Path
4
+ import torch
5
+ from tokenizers import ByteLevelBPETokenizer
6
+
7
+ # If running from the repo root, src/ is available locally.
8
+ sys.path.insert(0, str(Path(__file__).resolve().parents[1]))
9
+ from src.tinyllm.config import TinyConfig
10
+ from src.tinyllm.model import TinyLlamaForCausalLM
11
+
12
+
13
+ def make_prompt(user_prompt: str, system: str) -> str:
14
+ return f"<|system|>\n{system}\n<|end|>\n<|user|>\n{user_prompt}\n<|end|>\n<|assistant|>\n"
15
+
16
+
17
+ def sample_next(logits, temperature: float, top_k: int):
18
+ logits = logits.float()
19
+ if temperature <= 0:
20
+ return int(torch.argmax(logits))
21
+ logits = logits / temperature
22
+ if top_k and top_k > 0:
23
+ vals, idx = torch.topk(logits, min(top_k, logits.numel()))
24
+ probs = torch.softmax(vals, dim=-1)
25
+ return int(idx[torch.multinomial(probs, 1)])
26
+ probs = torch.softmax(logits, dim=-1)
27
+ return int(torch.multinomial(probs, 1))
28
+
29
+
30
+ def main():
31
+ ap = argparse.ArgumentParser()
32
+ ap.add_argument('--checkpoint', default='final.pt')
33
+ ap.add_argument('--config', default='configs/model_75m.yaml')
34
+ ap.add_argument('--tokenizer-dir', default='tokenizer')
35
+ ap.add_argument('--prompt', required=True)
36
+ ap.add_argument('--system', default='You are a helpful, concise assistant.')
37
+ ap.add_argument('--max-new-tokens', type=int, default=80)
38
+ ap.add_argument('--temperature', type=float, default=0.6)
39
+ ap.add_argument('--top-k', type=int, default=40)
40
+ args = ap.parse_args()
41
+
42
+ tok_dir = Path(args.tokenizer_dir)
43
+ tok = ByteLevelBPETokenizer(str(tok_dir / 'vocab.json'), str(tok_dir / 'merges.txt'))
44
+ cfg = TinyConfig.from_yaml(args.config)
45
+ model = TinyLlamaForCausalLM(cfg)
46
+ ckpt = torch.load(args.checkpoint, map_location='cpu')
47
+ state = ckpt['model'] if isinstance(ckpt, dict) and 'model' in ckpt else ckpt
48
+ model.load_state_dict(state, strict=False)
49
+ model.eval()
50
+
51
+ ids = tok.encode(make_prompt(args.prompt, args.system)).ids
52
+ prompt_len = len(ids)
53
+ end_id = tok.token_to_id('<|end|>')
54
+ for _ in range(args.max_new_tokens):
55
+ x = torch.tensor([ids[-cfg.max_position_embeddings:]], dtype=torch.long)
56
+ with torch.no_grad():
57
+ logits = model(x)['logits'][0, -1]
58
+ nxt = sample_next(logits, args.temperature, args.top_k)
59
+ ids.append(nxt)
60
+ if end_id is not None and nxt == end_id:
61
+ break
62
+
63
+ text = tok.decode(ids[prompt_len:])
64
+ for marker in ['<|end|>', '<|user|>', '<|assistant|>', '<|system|>']:
65
+ text = text.split(marker)[0]
66
+ print(text.strip())
67
+
68
+
69
+ if __name__ == '__main__':
70
+ main()
src/__init__.py ADDED
File without changes
src/tinyllm/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ __version__ = '0.1.0'
src/tinyllm/checkpoint.py ADDED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from pathlib import Path
2
+ import torch, time, shutil
3
+
4
+ def save_checkpoint(path, model, optimizer=None, scheduler=None, step=0, epoch=0, metrics=None, config=None):
5
+ path=Path(path); path.parent.mkdir(parents=True, exist_ok=True)
6
+ tmp=path.with_suffix(path.suffix+'.tmp')
7
+ state={'model':model.state_dict(),'step':step,'epoch':epoch,'metrics':metrics or {},'saved_at':time.time(),'config':config}
8
+ if optimizer is not None: state['optimizer']=optimizer.state_dict()
9
+ if scheduler is not None: state['scheduler']=scheduler.state_dict()
10
+ torch.save(state,tmp); tmp.replace(path)
11
+ latest=path.parent/'latest.pt'
12
+ if latest != path: shutil.copy2(path, latest)
13
+
14
+ def load_checkpoint(path, model, optimizer=None, scheduler=None, map_location='cpu'):
15
+ state=torch.load(path,map_location=map_location)
16
+ model.load_state_dict(state['model'])
17
+ if optimizer is not None and 'optimizer' in state: optimizer.load_state_dict(state['optimizer'])
18
+ if scheduler is not None and 'scheduler' in state: scheduler.load_state_dict(state['scheduler'])
19
+ return state
src/tinyllm/config.py ADDED
@@ -0,0 +1,31 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from dataclasses import dataclass
2
+ import yaml
3
+ from pathlib import Path
4
+
5
+ @dataclass
6
+ class TinyConfig:
7
+ vocab_size: int = 32000
8
+ hidden_size: int = 576
9
+ num_hidden_layers: int = 16
10
+ num_attention_heads: int = 9
11
+ num_key_value_heads: int = 3
12
+ intermediate_size: int = 1536
13
+ max_position_embeddings: int = 1024
14
+ rope_theta: float = 10000.0
15
+ rms_norm_eps: float = 1e-5
16
+ tie_word_embeddings: bool = True
17
+ attention_bias: bool = False
18
+ mlp_bias: bool = False
19
+ dropout: float = 0.0
20
+ bos_token_id: int = 1
21
+ eos_token_id: int = 2
22
+
23
+ @classmethod
24
+ def from_yaml(cls, path):
25
+ data = yaml.safe_load(Path(path).read_text())
26
+ if 'model' in data:
27
+ data = data['model']
28
+ return cls(**data)
29
+
30
+ def load_yaml(path):
31
+ return yaml.safe_load(Path(path).read_text())
src/tinyllm/data.py ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from pathlib import Path
2
+ import numpy as np
3
+
4
+ def save_tokens_bin(tokens, path, dtype=np.uint16):
5
+ path=Path(path).expanduser(); path.parent.mkdir(parents=True,exist_ok=True)
6
+ np.asarray(tokens,dtype=dtype).tofile(path)
7
+
8
+ def load_tokens_bin(path, dtype=np.uint16):
9
+ return np.memmap(Path(path).expanduser(), dtype=dtype, mode='r')
10
+
11
+ class TokenBlockDataset:
12
+ def __init__(self, bin_path, block_size):
13
+ self.data=load_tokens_bin(bin_path); self.block_size=block_size
14
+ def __len__(self): return max(0, len(self.data)-self.block_size-1)
15
+ def get_batch(self, batch_size, device):
16
+ import torch
17
+ ix=torch.randint(len(self),(batch_size,))
18
+ x=torch.stack([torch.from_numpy(np.array(self.data[i:i+self.block_size], dtype=np.int64)) for i in ix])
19
+ y=torch.stack([torch.from_numpy(np.array(self.data[i+1:i+1+self.block_size], dtype=np.int64)) for i in ix])
20
+ return x.to(device), y.to(device)
src/tinyllm/metrics.py ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json, csv, time
2
+ from pathlib import Path
3
+
4
+ class MetricsLogger:
5
+ def __init__(self, path):
6
+ self.path=Path(path); self.path.parent.mkdir(parents=True,exist_ok=True)
7
+ def log(self, **kw):
8
+ kw.setdefault('time', time.time())
9
+ with self.path.open('a') as f: f.write(json.dumps(kw)+'\n')
10
+
11
+ def jsonl_to_csv(jsonl_path, csv_path):
12
+ rows=[json.loads(l) for l in Path(jsonl_path).read_text().splitlines() if l.strip()]
13
+ if not rows: return
14
+ keys=sorted(set().union(*(r.keys() for r in rows)))
15
+ with open(csv_path,'w',newline='') as f:
16
+ w=csv.DictWriter(f,fieldnames=keys); w.writeheader(); w.writerows(rows)
src/tinyllm/model.py ADDED
@@ -0,0 +1,100 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ from dataclasses import asdict
5
+ from .config import TinyConfig
6
+
7
+ class RMSNorm(nn.Module):
8
+ def __init__(self, dim, eps=1e-5):
9
+ super().__init__(); self.eps=eps; self.weight=nn.Parameter(torch.ones(dim))
10
+ def forward(self, x):
11
+ return self.weight * x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
12
+
13
+ def rotate_half(x):
14
+ x1, x2 = x[..., :x.shape[-1]//2], x[..., x.shape[-1]//2:]
15
+ return torch.cat((-x2, x1), dim=-1)
16
+
17
+ class RotaryEmbedding(nn.Module):
18
+ def __init__(self, dim, max_position_embeddings=2048, base=10000.0):
19
+ super().__init__()
20
+ inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
21
+ t = torch.arange(max_position_embeddings, dtype=torch.float)
22
+ freqs = torch.einsum('i,j->ij', t, inv_freq)
23
+ emb = torch.cat((freqs, freqs), dim=-1)
24
+ self.register_buffer('cos_cached', emb.cos()[None, None, :, :], persistent=False)
25
+ self.register_buffer('sin_cached', emb.sin()[None, None, :, :], persistent=False)
26
+ def forward(self, q, k, seq_len):
27
+ cos = self.cos_cached[:, :, :seq_len, :].to(q.device, q.dtype)
28
+ sin = self.sin_cached[:, :, :seq_len, :].to(q.device, q.dtype)
29
+ return (q*cos + rotate_half(q)*sin), (k*cos + rotate_half(k)*sin)
30
+
31
+ class CausalSelfAttention(nn.Module):
32
+ def __init__(self, cfg: TinyConfig):
33
+ super().__init__()
34
+ assert cfg.hidden_size % cfg.num_attention_heads == 0
35
+ assert cfg.num_attention_heads % cfg.num_key_value_heads == 0
36
+ self.nh=cfg.num_attention_heads; self.nkv=cfg.num_key_value_heads
37
+ self.hd=cfg.hidden_size//cfg.num_attention_heads
38
+ self.q_proj=nn.Linear(cfg.hidden_size, self.nh*self.hd, bias=cfg.attention_bias)
39
+ self.k_proj=nn.Linear(cfg.hidden_size, self.nkv*self.hd, bias=cfg.attention_bias)
40
+ self.v_proj=nn.Linear(cfg.hidden_size, self.nkv*self.hd, bias=cfg.attention_bias)
41
+ self.o_proj=nn.Linear(cfg.hidden_size, cfg.hidden_size, bias=cfg.attention_bias)
42
+ self.rotary=RotaryEmbedding(self.hd, cfg.max_position_embeddings, cfg.rope_theta)
43
+ self.dropout=cfg.dropout
44
+ def forward(self, x):
45
+ B,T,C=x.shape
46
+ q=self.q_proj(x).view(B,T,self.nh,self.hd).transpose(1,2)
47
+ k=self.k_proj(x).view(B,T,self.nkv,self.hd).transpose(1,2)
48
+ v=self.v_proj(x).view(B,T,self.nkv,self.hd).transpose(1,2)
49
+ q,k=self.rotary(q,k,T)
50
+ if self.nkv != self.nh:
51
+ repeat = self.nh // self.nkv
52
+ k = k.repeat_interleave(repeat, dim=1)
53
+ v = v.repeat_interleave(repeat, dim=1)
54
+ y=F.scaled_dot_product_attention(q,k,v,dropout_p=self.dropout if self.training else 0.0,is_causal=True)
55
+ y=y.transpose(1,2).contiguous().view(B,T,C)
56
+ return self.o_proj(y)
57
+
58
+ class SwiGLU(nn.Module):
59
+ def __init__(self, cfg: TinyConfig):
60
+ super().__init__()
61
+ self.gate_proj=nn.Linear(cfg.hidden_size,cfg.intermediate_size,bias=cfg.mlp_bias)
62
+ self.up_proj=nn.Linear(cfg.hidden_size,cfg.intermediate_size,bias=cfg.mlp_bias)
63
+ self.down_proj=nn.Linear(cfg.intermediate_size,cfg.hidden_size,bias=cfg.mlp_bias)
64
+ def forward(self,x):
65
+ return self.down_proj(F.silu(self.gate_proj(x))*self.up_proj(x))
66
+
67
+ class Block(nn.Module):
68
+ def __init__(self,cfg):
69
+ super().__init__(); self.input_norm=RMSNorm(cfg.hidden_size,cfg.rms_norm_eps); self.attn=CausalSelfAttention(cfg); self.post_norm=RMSNorm(cfg.hidden_size,cfg.rms_norm_eps); self.mlp=SwiGLU(cfg)
70
+ def forward(self,x):
71
+ x=x+self.attn(self.input_norm(x)); x=x+self.mlp(self.post_norm(x)); return x
72
+
73
+ class TinyLlamaForCausalLM(nn.Module):
74
+ def __init__(self,cfg:TinyConfig):
75
+ super().__init__(); self.config=cfg
76
+ self.embed_tokens=nn.Embedding(cfg.vocab_size,cfg.hidden_size)
77
+ self.layers=nn.ModuleList([Block(cfg) for _ in range(cfg.num_hidden_layers)])
78
+ self.norm=RMSNorm(cfg.hidden_size,cfg.rms_norm_eps)
79
+ self.lm_head=nn.Linear(cfg.hidden_size,cfg.vocab_size,bias=False)
80
+ if cfg.tie_word_embeddings:
81
+ self.lm_head.weight=self.embed_tokens.weight
82
+ self.apply(self._init_weights)
83
+ def _init_weights(self,m):
84
+ if isinstance(m, nn.Linear): nn.init.normal_(m.weight, mean=0.0, std=0.02)
85
+ if isinstance(m, nn.Embedding): nn.init.normal_(m.weight, mean=0.0, std=0.02)
86
+ def forward(self,input_ids,labels=None,loss_mask=None):
87
+ x=self.embed_tokens(input_ids)
88
+ for layer in self.layers: x=layer(x)
89
+ logits=self.lm_head(self.norm(x))
90
+ loss=None
91
+ if labels is not None:
92
+ shift_logits=logits[:,:-1,:].contiguous(); shift_labels=labels[:,1:].contiguous()
93
+ per=F.cross_entropy(shift_logits.view(-1,shift_logits.size(-1)),shift_labels.view(-1),reduction='none')
94
+ if loss_mask is not None:
95
+ mask=loss_mask[:,1:].contiguous().view(-1).float(); loss=(per*mask).sum()/mask.sum().clamp_min(1.0)
96
+ else: loss=per.mean()
97
+ return {'loss':loss,'logits':logits}
98
+ def num_parameters(self): return sum(p.numel() for p in self.parameters())
99
+ def save_config(self,path):
100
+ import json; open(path,'w').write(json.dumps(asdict(self.config),indent=2))
tokenizer/merges.txt ADDED
The diff for this file is too large to render. See raw diff
 
tokenizer/vocab.json ADDED
The diff for this file is too large to render. See raw diff