Add tokenizer, inference code, model card, and 20-query report
- 20_query_inference_report.md +90 -0
- 20_query_inference_results.json +113 -0
- README.md +193 -0
- configs/model_75m.yaml +16 -0
- scripts/infer_tinyllm.py +70 -0
- src/__init__.py +0 -0
- src/tinyllm/__init__.py +1 -0
- src/tinyllm/checkpoint.py +19 -0
- src/tinyllm/config.py +31 -0
- src/tinyllm/data.py +20 -0
- src/tinyllm/metrics.py +16 -0
- src/tinyllm/model.py +100 -0
- tokenizer/merges.txt +0 -0
- tokenizer/vocab.json +0 -0
20_query_inference_report.md
ADDED
|
@@ -0,0 +1,90 @@
| 1 |
+
# Tiny LLM 75M — 20 Query Inference Report
|
| 2 |
+
|
| 3 |
+
Repo: `razor5050/tinyllm-75m-openwebtext-chat`
|
| 4 |
+
Checkpoint: `final.pt`
|
| 5 |
+
SFT step: `781`
|
| 6 |
+
Parameters: `75,074,112`
|
| 7 |
+
State load: missing `0`, unexpected `0`
|
| 8 |
+
Elapsed: `218.86s`
|
| 9 |
+
|
| 10 |
+
## Questions and answers
|
| 11 |
+
|
| 12 |
+
### 1. What is photosynthesis?
|
| 13 |
+
|
| 14 |
+
Photosynthesis is a chemical process in which a molecule is formed and then into a form called a molecular structure. It is a chemical that is composed of the elements that are
|
| 15 |
+
|
| 16 |
+
### 2. Why do leaves look green?
|
| 17 |
+
|
| 18 |
+
The word "green" is actually a word that is often used in various forms of environmental and environmental sciences. It is often used in various fields, such as agriculture, engineering,
|
| 19 |
+
|
| 20 |
+
### 3. What is gravity?
|
| 21 |
+
|
| 22 |
+
Gravity is a phenomenon known as gravity. It occurs when a substance of gravity is formed in a shape or shape that is not present in a substance, and the substance is not
|
| 23 |
+
|
| 24 |
+
### 4. Name one planet in our solar system.
|
| 25 |
+
|
| 26 |
+
The solar system was named from the first solar system, named by the name "Solar".
|
| 27 |
+
|
| 28 |
+
### 5. What is the largest ocean on Earth?
|
| 29 |
+
|
| 30 |
+
The largest ocean on Earth is the Nusa-Eibagan, the largest ocean on Earth. It is a ocean in the southern Peninsula of the United States, and
|
| 31 |
+
|
| 32 |
+
### 6. What gas do humans breathe in to survive?
|
| 33 |
+
|
| 34 |
+
A gas in water, the gas's core, is a fluid that is released through the water's chemical reactions, which can cause blood death and heart disease. A gas in water
|
| 35 |
+
|
| 36 |
+
### 7. What does the heart do?
|
| 37 |
+
|
| 38 |
+
The heart is a system of organs that are designed to be designed to be designed to perform functions that perform functions that perform functions that perform functions that perform functions that perform functions that perform
|
| 39 |
+
|
| 40 |
+
### 8. What is a volcano?
|
| 41 |
+
|
| 42 |
+
A volcano is a volcano that consists of a range of extreme events, such as volcanic eruptions, earthquakes, and earthquakes. It is a volcanic eruption that occurs within a range of
|
| 43 |
+
|
| 44 |
+
### 9. What is the Moon?
|
| 45 |
+
|
| 46 |
+
The Moon is a vast and beautiful space that is packed with diverse cultures, diverse cultures, and diverse cultures. It is a vast, and well-designed space that can be accessed
|
| 47 |
+
|
| 48 |
+
### 10. Why does rain fall from clouds?
|
| 49 |
+
|
| 50 |
+
The wind is a powerful force that can create storms and create storms. By rain fall, the wind causes the wind to move, which is the way the air is released from the
|
| 51 |
+
|
| 52 |
+
### 11. What is electricity?
|
| 53 |
+
|
| 54 |
+
Electototototototototototototototototototototototototototototototototototot
|
| 55 |
+
|
| 56 |
+
### 12. Name one continent.
|
| 57 |
+
|
| 58 |
+
A continent a continent in the America, United States, and United States.
|
| 59 |
+
|
| 60 |
+
### 13. What is a river?
|
| 61 |
+
|
| 62 |
+
A river is a river that is a river that flows through the river. It is a river that is flowing through the river, making it a river that flows through rivers. It
|
| 63 |
+
|
| 64 |
+
### 14. What animal is known as the king of the jungle?
|
| 65 |
+
|
| 66 |
+
The king of the jungle is the king of the jungle. This species is known as the "King of the jungle." The king of the jungle is the king of the jungle.
|
| 67 |
+
|
| 68 |
+
### 15. What is the purpose of roots in a plant?
|
| 69 |
+
|
| 70 |
+
A plant is a plant that is a plant that is composed of a plant that is composed of a plant that is composed of a plant that is composed of a plant that is composed
|
| 71 |
+
|
| 72 |
+
### 16. What is a computer?
|
| 73 |
+
|
| 74 |
+
A computer is a computer that is designed to monitor and monitor any specific data, including any relevant data, data, or data, or any relevant data, or any relevant data in
|
| 75 |
+
|
| 76 |
+
### 17. What is the internet?
|
| 77 |
+
|
| 78 |
+
The Internet is a digital movement that spans various regions, each having its own unique and complex connections. The internet is a phenomenon that occurs through various methods, including internet-specific communication
|
| 79 |
+
|
| 80 |
+
### 18. What is a vaccine?
|
| 81 |
+
|
| 82 |
+
A vaccine is an vaccine designed to promote the growth of the immune system, which regulates the immune system's activity, such as the immune system, immune cells, and immune cells
|
| 83 |
+
|
| 84 |
+
### 19. What country is Tokyo in?
|
| 85 |
+
|
| 86 |
+
Japan is a world where Tokyo is known for its stunning water quality and its culture. It is a hub for the world's largest expesticive waste waste agency, providing a vital
|
| 87 |
+
|
| 88 |
+
### 20. What is climate change?
|
| 89 |
+
|
| 90 |
+
Climate change is often associated with a significant increase in global temperature, associated with a rapid acceleration in the warming cycle, with a greater amount of heat and heat being added to the global
|
20_query_inference_results.json
ADDED
|
@@ -0,0 +1,113 @@
| 1 |
+
{
|
| 2 |
+
"meta": {
|
| 3 |
+
"repo": "razor5050/tinyllm-75m-openwebtext-chat",
|
| 4 |
+
"checkpoint": "final.pt",
|
| 5 |
+
"step": 781,
|
| 6 |
+
"parameters": 75074112,
|
| 7 |
+
"missing_keys": 0,
|
| 8 |
+
"unexpected_keys": 0
|
| 9 |
+
},
|
| 10 |
+
"elapsed_seconds": 218.86,
|
| 11 |
+
"results": [
|
| 12 |
+
{
|
| 13 |
+
"n": 1,
|
| 14 |
+
"question": "What is photosynthesis?",
|
| 15 |
+
"answer": "Photosynthesis is a chemical process in which a molecule is formed and then into a form called a molecular structure. It is a chemical that is composed of the elements that are"
|
| 16 |
+
},
|
| 17 |
+
{
|
| 18 |
+
"n": 2,
|
| 19 |
+
"question": "Why do leaves look green?",
|
| 20 |
+
"answer": "The word \"green\" is actually a word that is often used in various forms of environmental and environmental sciences. It is often used in various fields, such as agriculture, engineering,"
|
| 21 |
+
},
|
| 22 |
+
{
|
| 23 |
+
"n": 3,
|
| 24 |
+
"question": "What is gravity?",
|
| 25 |
+
"answer": "Gravity is a phenomenon known as gravity. It occurs when a substance of gravity is formed in a shape or shape that is not present in a substance, and the substance is not"
|
| 26 |
+
},
|
| 27 |
+
{
|
| 28 |
+
"n": 4,
|
| 29 |
+
"question": "Name one planet in our solar system.",
|
| 30 |
+
"answer": "The solar system was named from the first solar system, named by the name \"Solar\"."
|
| 31 |
+
},
|
| 32 |
+
{
|
| 33 |
+
"n": 5,
|
| 34 |
+
"question": "What is the largest ocean on Earth?",
|
| 35 |
+
"answer": "The largest ocean on Earth is the Nusa-Eibagan, the largest ocean on Earth. It is a ocean in the southern Peninsula of the United States, and"
|
| 36 |
+
},
|
| 37 |
+
{
|
| 38 |
+
"n": 6,
|
| 39 |
+
"question": "What gas do humans breathe in to survive?",
|
| 40 |
+
"answer": "A gas in water, the gas's core, is a fluid that is released through the water's chemical reactions, which can cause blood death and heart disease. A gas in water"
|
| 41 |
+
},
|
| 42 |
+
{
|
| 43 |
+
"n": 7,
|
| 44 |
+
"question": "What does the heart do?",
|
| 45 |
+
"answer": "The heart is a system of organs that are designed to be designed to be designed to perform functions that perform functions that perform functions that perform functions that perform functions that perform functions that perform"
|
| 46 |
+
},
|
| 47 |
+
{
|
| 48 |
+
"n": 8,
|
| 49 |
+
"question": "What is a volcano?",
|
| 50 |
+
"answer": "A volcano is a volcano that consists of a range of extreme events, such as volcanic eruptions, earthquakes, and earthquakes. It is a volcanic eruption that occurs within a range of"
|
| 51 |
+
},
|
| 52 |
+
{
|
| 53 |
+
"n": 9,
|
| 54 |
+
"question": "What is the Moon?",
|
| 55 |
+
"answer": "The Moon is a vast and beautiful space that is packed with diverse cultures, diverse cultures, and diverse cultures. It is a vast, and well-designed space that can be accessed"
|
| 56 |
+
},
|
| 57 |
+
{
|
| 58 |
+
"n": 10,
|
| 59 |
+
"question": "Why does rain fall from clouds?",
|
| 60 |
+
"answer": "The wind is a powerful force that can create storms and create storms. By rain fall, the wind causes the wind to move, which is the way the air is released from the"
|
| 61 |
+
},
|
| 62 |
+
{
|
| 63 |
+
"n": 11,
|
| 64 |
+
"question": "What is electricity?",
|
| 65 |
+
"answer": "Electototototototototototototototototototototototototototototototototototot"
|
| 66 |
+
},
|
| 67 |
+
{
|
| 68 |
+
"n": 12,
|
| 69 |
+
"question": "Name one continent.",
|
| 70 |
+
"answer": "A continent a continent in the America, United States, and United States."
|
| 71 |
+
},
|
| 72 |
+
{
|
| 73 |
+
"n": 13,
|
| 74 |
+
"question": "What is a river?",
|
| 75 |
+
"answer": "A river is a river that is a river that flows through the river. It is a river that is flowing through the river, making it a river that flows through rivers. It"
|
| 76 |
+
},
|
| 77 |
+
{
|
| 78 |
+
"n": 14,
|
| 79 |
+
"question": "What animal is known as the king of the jungle?",
|
| 80 |
+
"answer": "The king of the jungle is the king of the jungle. This species is known as the \"King of the jungle.\" The king of the jungle is the king of the jungle."
|
| 81 |
+
},
|
| 82 |
+
{
|
| 83 |
+
"n": 15,
|
| 84 |
+
"question": "What is the purpose of roots in a plant?",
|
| 85 |
+
"answer": "A plant is a plant that is a plant that is composed of a plant that is composed of a plant that is composed of a plant that is composed of a plant that is composed"
|
| 86 |
+
},
|
| 87 |
+
{
|
| 88 |
+
"n": 16,
|
| 89 |
+
"question": "What is a computer?",
|
| 90 |
+
"answer": "A computer is a computer that is designed to monitor and monitor any specific data, including any relevant data, data, or data, or any relevant data, or any relevant data in"
|
| 91 |
+
},
|
| 92 |
+
{
|
| 93 |
+
"n": 17,
|
| 94 |
+
"question": "What is the internet?",
|
| 95 |
+
"answer": "The Internet is a digital movement that spans various regions, each having its own unique and complex connections. The internet is a phenomenon that occurs through various methods, including internet-specific communication"
|
| 96 |
+
},
|
| 97 |
+
{
|
| 98 |
+
"n": 18,
|
| 99 |
+
"question": "What is a vaccine?",
|
| 100 |
+
"answer": "A vaccine is an vaccine designed to promote the growth of the immune system, which regulates the immune system's activity, such as the immune system, immune cells, and immune cells"
|
| 101 |
+
},
|
| 102 |
+
{
|
| 103 |
+
"n": 19,
|
| 104 |
+
"question": "What country is Tokyo in?",
|
| 105 |
+
"answer": "Japan is a world where Tokyo is known for its stunning water quality and its culture. It is a hub for the world's largest expesticive waste waste agency, providing a vital"
|
| 106 |
+
},
|
| 107 |
+
{
|
| 108 |
+
"n": 20,
|
| 109 |
+
"question": "What is climate change?",
|
| 110 |
+
"answer": "Climate change is often associated with a significant increase in global temperature, associated with a rapid acceleration in the warming cycle, with a greater amount of heat and heat being added to the global"
|
| 111 |
+
}
|
| 112 |
+
]
|
| 113 |
+
}
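The results file above can be summarized with a few lines of Python. This is a minimal sketch, not part of the commit: the `summarize` helper and the two-entry `sample` dictionary are hypothetical and only mirror the key names visible in the JSON (`elapsed_seconds`, `results`, `answer`).

```python
import json
from pathlib import Path


def summarize(report: dict) -> dict:
    """Return simple aggregate statistics for a results dictionary."""
    results = report["results"]
    n = len(results)
    if n == 0:
        return {"queries": 0, "seconds_per_query": 0.0, "mean_answer_words": 0.0}
    total_words = sum(len(r["answer"].split()) for r in results)
    return {
        "queries": n,
        "seconds_per_query": report["elapsed_seconds"] / n,
        "mean_answer_words": total_words / n,
    }


# Hypothetical two-entry stand-in with the same shape as the real file.
sample = {
    "meta": {"checkpoint": "final.pt"},
    "elapsed_seconds": 10.0,
    "results": [
        {"n": 1, "question": "Q1?", "answer": "one two three"},
        {"n": 2, "question": "Q2?", "answer": "four five six seven eight"},
    ],
}
print(summarize(sample))

# To run on the real file (assumed to be in the working directory):
# report = json.loads(Path("20_query_inference_results.json").read_text())
# print(summarize(report))
```

On the real file, 218.86 seconds over 20 queries works out to roughly 10.9 seconds per query. Whether that figure includes checkpoint loading is not stated, so it should be read as an upper bound on generation time.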
|
README.md
ADDED
|
@@ -0,0 +1,193 @@
| 1 |
+
---
|
| 2 |
+
language:
|
| 3 |
+
- en
|
| 4 |
+
license: mit
|
| 5 |
+
tags:
|
| 6 |
+
- tiny-llm
|
| 7 |
+
- causal-lm
|
| 8 |
+
- llama-like
|
| 9 |
+
- rope
|
| 10 |
+
- rmsnorm
|
| 11 |
+
- swiglu
|
| 12 |
+
- gqa
|
| 13 |
+
- openwebtext
|
| 14 |
+
- smoltalk
|
| 15 |
+
- pytorch
|
| 16 |
+
pipeline_tag: text-generation
|
| 17 |
+
library_name: pytorch
|
| 18 |
+
---
|
| 19 |
+
|
| 20 |
+
# TinyLLM 75M OpenWebText Chat
|
| 21 |
+
|
| 22 |
+
This repository contains an experimental **75,074,112-parameter, decoder-only tiny language model** trained from scratch/near-scratch and then fine-tuned for chat with supervised fine-tuning (SFT).
|
| 23 |
+
|
| 24 |
+
> **Important quality note:** This is a successful end-to-end training pipeline artifact and research toy model, not a production assistant. It can load and generate text, but factual accuracy, instruction following, arithmetic, and repetition control are weak.
|
| 25 |
+
|
| 26 |
+
## Model summary
|
| 27 |
+
|
| 28 |
+
- **Model name:** `razor5050/tinyllm-75m-openwebtext-chat`
|
| 29 |
+
- **Architecture:** LLaMA/SmolLM-style decoder-only causal LM
|
| 30 |
+
- **Parameters:** 75,074,112
|
| 31 |
+
- **Context length:** 1024 tokens
|
| 32 |
+
- **Vocabulary:** 32,000 ByteLevel BPE tokens
|
| 33 |
+
- **Tokenizer:** custom ByteLevel BPE trained for this run
|
| 34 |
+
- **Checkpoint format:** PyTorch `.pt` checkpoints
|
| 35 |
+
- **Primary final checkpoint:** `final.pt`
|
| 36 |
+
- **Best checkpoint:** `best.pt`
|
| 37 |
+
|
| 38 |
+
## Architecture
|
| 39 |
+
|
| 40 |
+
The model uses modern tiny-LM components:
|
| 41 |
+
|
| 42 |
+
- decoder-only causal Transformer
|
| 43 |
+
- RoPE positional embeddings
|
| 44 |
+
- RMSNorm
|
| 45 |
+
- SwiGLU MLP
|
| 46 |
+
- grouped-query attention (GQA): 9 query heads share 3 key/value heads
|
| 47 |
+
- tied input/output token embeddings
|
| 48 |
+
- no attention/MLP bias
|
| 49 |
+
- PyTorch SDPA causal attention
|
| 50 |
+
|
| 51 |
+
Approximate config:
|
| 52 |
+
|
| 53 |
+
```yaml
|
| 54 |
+
vocab_size: 32000
|
| 55 |
+
hidden_size: 576
|
| 56 |
+
num_hidden_layers: 16
|
| 57 |
+
num_attention_heads: 9
|
| 58 |
+
num_key_value_heads: 3
|
| 59 |
+
intermediate_size: 1536
|
| 60 |
+
max_position_embeddings: 1024
|
| 61 |
+
rope_theta: 10000.0
|
| 62 |
+
rms_norm_eps: 1e-5
|
| 63 |
+
tie_word_embeddings: true
|
| 64 |
+
attention_bias: false
|
| 65 |
+
mlp_bias: false
|
| 66 |
+
dropout: 0.0
|
| 67 |
+
```
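The reported parameter count can be checked against the config above with simple arithmetic. The sketch below is illustrative and rests on assumptions about the layer layout that the config alone does not confirm: a SwiGLU MLP with three weight matrices (gate, up, down), two RMSNorm weight vectors per layer, one final RMSNorm, and no biases. The function name `count_parameters` is hypothetical.

```python
def count_parameters(
    vocab_size: int = 32000,
    hidden_size: int = 576,
    num_hidden_layers: int = 16,
    num_attention_heads: int = 9,
    num_key_value_heads: int = 3,
    intermediate_size: int = 1536,
    tie_word_embeddings: bool = True,
) -> int:
    """Estimate the parameter count of a bias-free LLaMA-style decoder."""
    head_dim = hidden_size // num_attention_heads
    kv_dim = num_key_value_heads * head_dim

    # q_proj and o_proj are hidden x hidden; k_proj and v_proj are hidden x kv_dim.
    attention = 2 * hidden_size * hidden_size + 2 * hidden_size * kv_dim
    # Assumption: SwiGLU MLP with gate, up, and down projections.
    mlp = 3 * hidden_size * intermediate_size
    # Assumption: one RMSNorm before attention and one before the MLP.
    norms = 2 * hidden_size

    per_layer = attention + mlp + norms
    embeddings = vocab_size * hidden_size
    # With tied embeddings the output head reuses the embedding matrix.
    lm_head = 0 if tie_word_embeddings else vocab_size * hidden_size
    final_norm = hidden_size
    return num_hidden_layers * per_layer + embeddings + lm_head + final_norm


print(f"{count_parameters():,}")
```

With the defaults, the estimate is 75,074,112, which matches the reported count, so the assumed layout is at least consistent with the checkpoint. Note that the tied embedding matrix alone accounts for 18,432,000 of those parameters.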
|
| 68 |
+
|
| 69 |
+
## Training data
|
| 70 |
+
|
| 71 |
+
### Base pretraining
|
| 72 |
+
|
| 73 |
+
- Dataset: [`Skylion007/openwebtext`](https://huggingface.co/datasets/Skylion007/openwebtext)
|
| 74 |
+
- Rows used: 1,000,000 selected rows
|
| 75 |
+
- Final tokenized train tokens: 1,143,301,833
|
| 76 |
+
- Final tokenized validation tokens: 34,486,473
|
| 77 |
+
- Epochs: 1
|
| 78 |
+
- Optimizer steps: 4,361
|
| 79 |
+
|
| 80 |
+
### Chat/SFT
|
| 81 |
+
|
| 82 |
+
- Dataset: [`HuggingFaceTB/smol-smoltalk`](https://huggingface.co/datasets/HuggingFaceTB/smol-smoltalk)
|
| 83 |
+
- Train examples: 100,000
|
| 84 |
+
- Validation examples: 3,000
|
| 85 |
+
- Epochs: 1
|
| 86 |
+
- Optimizer steps: 781
|
| 87 |
+
- Loss masking: assistant-response tokens only
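The step counts in the two lists above imply an effective batch size, which the model card does not state directly. The following sketch derives it from the listed figures. The interpretation (256 sequences of 1024 tokens per pretraining step, 128 examples per SFT step) is an inference from the numbers, not a documented setting, and it says nothing about how the batch was split across micro-batches or gradient accumulation.

```python
# Figures copied from the training-data lists above.
PRETRAIN_TOKENS = 1_143_301_833
PRETRAIN_STEPS = 4_361
SFT_EXAMPLES = 100_000
SFT_STEPS = 781
CONTEXT_LENGTH = 1024
PARAMETERS = 75_074_112

tokens_per_step = PRETRAIN_TOKENS / PRETRAIN_STEPS
sequences_per_step = tokens_per_step / CONTEXT_LENGTH
examples_per_sft_step = SFT_EXAMPLES / SFT_STEPS
tokens_per_parameter = PRETRAIN_TOKENS / PARAMETERS

print(f"pretraining tokens per step:  {tokens_per_step:,.0f}")
print(f"full-context sequences/step:  {sequences_per_step:.2f}")
print(f"SFT examples per step:        {examples_per_sft_step:.2f}")
print(f"pretraining tokens/parameter: {tokens_per_parameter:.2f}")
```

The ratios land very close to 256 sequences per pretraining step and 128 examples per SFT step, and the run saw roughly 15 pretraining tokens per parameter.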
|
| 88 |
+
|
| 89 |
+
## Training results
|
| 90 |
+
|
| 91 |
+
### Pretraining
|
| 92 |
+
|
| 93 |
+
- Train loss near the end of the run: about `4.997`
|
| 94 |
+
- Last logged validation loss: about `5.049` at step 4,000
|
| 95 |
+
|
| 96 |
+
### SFT
|
| 97 |
+
|
| 98 |
+
- SFT completed at step `781`
|
| 99 |
+
- Validation trend:
|
| 100 |
+
- step 250: `2.6031`
|
| 101 |
+
- step 500: `2.4505`
|
| 102 |
+
- step 750: `2.3313`
|
| 103 |
+
|
| 104 |
+
SFT improved chat formatting and response style, but the model remains very small and undertrained by modern assistant standards.
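The losses above are easier to interpret as perplexities. The sketch below assumes the logged values are mean cross-entropy in nats per token, which is the default for PyTorch's cross-entropy loss but is not confirmed by the model card. The `perplexity` helper is hypothetical.

```python
import math


def perplexity(loss_nats: float) -> float:
    """Convert a mean per-token cross-entropy (in nats) to perplexity."""
    return math.exp(loss_nats)


# Validation losses as reported in the lists above.
validation_losses = {
    "pretraining, step 4000": 5.049,
    "SFT, step 250": 2.6031,
    "SFT, step 500": 2.4505,
    "SFT, step 750": 2.3313,
}

for name, loss in validation_losses.items():
    print(f"{name}: loss {loss:.4f} -> perplexity {perplexity(loss):.1f}")
```

The pretraining and SFT figures are not directly comparable: they are measured on different datasets, and the SFT loss is computed on assistant-response tokens only.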
|
| 105 |
+
|
| 106 |
+
## Hardware/run
|
| 107 |
+
|
| 108 |
+
- Cloud GPU: Vast.ai RTX 5070 Ti, 16GB VRAM
|
| 109 |
+
- Precision: CUDA/PyTorch mixed precision during training where supported
|
| 110 |
+
- Checkpointing: periodic `latest`, `best`, final, and step checkpoints
|
| 111 |
+
- Training artifacts were preserved separately outside the instance before teardown.
|
| 112 |
+
|
| 113 |
+
## Files in this repo
|
| 114 |
+
|
| 115 |
+
- `final.pt` — final SFT checkpoint
|
| 116 |
+
- `best.pt` — best SFT checkpoint
|
| 117 |
+
- `latest.pt` — latest SFT checkpoint
|
| 118 |
+
- `metrics.jsonl` — SFT metrics
|
| 119 |
+
- `step_609.pt` — intermediate SFT checkpoint
|
| 120 |
+
- `tokenizer/vocab.json` and `tokenizer/merges.txt` — tokenizer files
|
| 121 |
+
- `configs/model_75m.yaml` — architecture config
|
| 122 |
+
- `src/tinyllm/` — minimal PyTorch model implementation
|
| 123 |
+
- `scripts/infer_tinyllm.py` — simple local inference helper
|
| 124 |
+
|
| 125 |
+
## Quick inference
|
| 126 |
+
|
| 127 |
+
Clone/download the repo, install dependencies, then run:
|
| 128 |
+
|
| 129 |
+
```bash
|
| 130 |
+
pip install torch tokenizers pyyaml huggingface_hub
|
| 131 |
+
python scripts/infer_tinyllm.py \
|
| 132 |
+
--checkpoint final.pt \
|
| 133 |
+
--prompt "What is the capital of France?"
|
| 134 |
+
```
|
| 135 |
+
|
| 136 |
+
The chat prompt format used during SFT is:
|
| 137 |
+
|
| 138 |
+
```text
|
| 139 |
+
<|system|>
|
| 140 |
+
You are a helpful, concise assistant.
|
| 141 |
+
<|end|>
|
| 142 |
+
<|user|>
|
| 143 |
+
USER_QUESTION
|
| 144 |
+
<|end|>
|
| 145 |
+
<|assistant|>
|
| 146 |
+
```
|
| 147 |
+
|
| 148 |
+
## Observed sample behavior
|
| 149 |
+
|
| 150 |
+
In a post-upload local inference test, the model loaded cleanly and generated text, but quality was mixed:
|
| 151 |
+
|
| 152 |
+
- Correct on: “What is the capital of France?” → answered Paris, with repetition.
|
| 153 |
+
- Weak on: simple science/world facts, often rambling or hallucinating.
|
| 154 |
+
- Weak on: arithmetic and short-answer discipline.
|
| 155 |
+
- Repetition and generic phrasing are common.
|
| 156 |
+
|
| 157 |
+
This is expected for a 75M-parameter scratch-trained model with about 1.14B pretraining tokens and one SFT pass.
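The repetition noted above can be quantified rather than judged by eye. One simple option, sketched here as an illustration and not something the repository ships, is the fraction of word n-grams in an answer that occur more than once. The function name `repeated_ngram_fraction` and the choice of trigrams are assumptions.

```python
from collections import Counter


def repeated_ngram_fraction(text: str, n: int = 3) -> float:
    """Fraction of word n-grams that belong to an n-gram seen more than once."""
    words = text.lower().split()
    ngrams = [tuple(words[i:i + n]) for i in range(len(words) - n + 1)]
    if not ngrams:
        return 0.0
    counts = Counter(ngrams)
    repeated = sum(c for c in counts.values() if c > 1)
    return repeated / len(ngrams)


# Two answers taken from the 20-query report in this commit.
looping = (
    "A plant is a plant that is a plant that is composed of a plant that is "
    "composed of a plant that is composed of a plant that is composed"
)
varied = (
    "Climate change is often associated with a significant increase in "
    "global temperature, associated with a rapid acceleration in the warming cycle"
)

print(f"looping answer: {repeated_ngram_fraction(looping):.2f}")
print(f"varied answer:  {repeated_ngram_fraction(varied):.2f}")
```

A score near 1.0 indicates a loop, while a score near 0.0 indicates little phrase-level repetition. Because the metric splits on whitespace, it does not catch sub-word loops such as the "Electotot..." answer to question 11.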
|
| 158 |
+
|
| 159 |
+
## Limitations
|
| 160 |
+
|
| 161 |
+
- Not suitable for factual QA or production use.
|
| 162 |
+
- Hallucinates frequently.
|
| 163 |
+
- Repetition loops occur.
|
| 164 |
+
- Arithmetic is unreliable.
|
| 165 |
+
- Safety behavior was not evaluated.
|
| 166 |
+
- Model is not aligned beyond basic supervised chat finetuning.
|
| 167 |
+
- The checkpoint is a custom PyTorch model, not a standard `transformers` model class.
|
| 168 |
+
|
| 169 |
+
## Intended use
|
| 170 |
+
|
| 171 |
+
- Educational tiny-LLM experiment
|
| 172 |
+
- Pipeline validation
|
| 173 |
+
- Small-model architecture experimentation
|
| 174 |
+
- Baseline for future 150M+ runs
|
| 175 |
+
|
| 176 |
+
## Recommended next steps
|
| 177 |
+
|
| 178 |
+
To improve quality meaningfully:
|
| 179 |
+
|
| 180 |
+
1. Train a larger ~150M model.
|
| 181 |
+
2. Use more unique pretraining tokens, e.g. ~5B+.
|
| 182 |
+
3. Improve preprocessing/tokenization throughput with multiprocessing/sharding.
|
| 183 |
+
4. Add stronger instruction data and possibly preference tuning.
|
| 184 |
+
5. Export to a standard Hugging Face `transformers`-compatible format.
|
| 185 |
+
|
| 186 |
+
## Citation / attribution
|
| 187 |
+
|
| 188 |
+
Training datasets:
|
| 189 |
+
|
| 190 |
+
- `Skylion007/openwebtext`
|
| 191 |
+
- `HuggingFaceTB/smol-smoltalk`
|
| 192 |
+
|
| 193 |
+
This repository is an experimental model artifact from a custom tiny-LLM training pipeline.
|
configs/model_75m.yaml
ADDED
|
@@ -0,0 +1,16 @@
| 1 |
+
model:
|
| 2 |
+
vocab_size: 32000
|
| 3 |
+
hidden_size: 576
|
| 4 |
+
num_hidden_layers: 16
|
| 5 |
+
num_attention_heads: 9
|
| 6 |
+
num_key_value_heads: 3
|
| 7 |
+
intermediate_size: 1536
|
| 8 |
+
max_position_embeddings: 1024
|
| 9 |
+
rope_theta: 10000.0
|
| 10 |
+
rms_norm_eps: 1.0e-5
|
| 11 |
+
tie_word_embeddings: true
|
| 12 |
+
attention_bias: false
|
| 13 |
+
mlp_bias: false
|
| 14 |
+
dropout: 0.0
|
| 15 |
+
bos_token_id: 1
|
| 16 |
+
eos_token_id: 2
|
scripts/infer_tinyllm.py
ADDED
|
@@ -0,0 +1,70 @@
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
import argparse, sys
|
| 3 |
+
from pathlib import Path
|
| 4 |
+
import torch
|
| 5 |
+
from tokenizers import ByteLevelBPETokenizer
|
| 6 |
+
|
| 7 |
+
# If running from the repo root, src/ is available locally.
|
| 8 |
+
sys.path.insert(0, str(Path(__file__).resolve().parents[1]))
|
| 9 |
+
from src.tinyllm.config import TinyConfig
|
| 10 |
+
from src.tinyllm.model import TinyLlamaForCausalLM
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
def make_prompt(user_prompt: str, system: str) -> str:
|
| 14 |
+
return f"<|system|>\n{system}\n<|end|>\n<|user|>\n{user_prompt}\n<|end|>\n<|assistant|>\n"
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
def sample_next(logits, temperature: float, top_k: int):
|
| 18 |
+
logits = logits.float()
|
| 19 |
+
if temperature <= 0:
|
| 20 |
+
return int(torch.argmax(logits))
|
| 21 |
+
logits = logits / temperature
|
| 22 |
+
if top_k and top_k > 0:
|
| 23 |
+
vals, idx = torch.topk(logits, min(top_k, logits.numel()))
|
| 24 |
+
probs = torch.softmax(vals, dim=-1)
|
| 25 |
+
return int(idx[torch.multinomial(probs, 1)])
|
| 26 |
+
probs = torch.softmax(logits, dim=-1)
|
| 27 |
+
return int(torch.multinomial(probs, 1))
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
def main():
|
| 31 |
+
ap = argparse.ArgumentParser()
|
| 32 |
+
ap.add_argument('--checkpoint', default='final.pt')
|
| 33 |
+
ap.add_argument('--config', default='configs/model_75m.yaml')
|
| 34 |
+
ap.add_argument('--tokenizer-dir', default='tokenizer')
|
| 35 |
+
ap.add_argument('--prompt', required=True)
|
| 36 |
+
ap.add_argument('--system', default='You are a helpful, concise assistant.')
|
| 37 |
+
ap.add_argument('--max-new-tokens', type=int, default=80)
|
| 38 |
+
ap.add_argument('--temperature', type=float, default=0.6)
|
| 39 |
+
ap.add_argument('--top-k', type=int, default=40)
|
| 40 |
+
args = ap.parse_args()
|
| 41 |
+
|
| 42 |
+
tok_dir = Path(args.tokenizer_dir)
|
| 43 |
+
tok = ByteLevelBPETokenizer(str(tok_dir / 'vocab.json'), str(tok_dir / 'merges.txt'))
|
| 44 |
+
cfg = TinyConfig.from_yaml(args.config)
|
| 45 |
+
model = TinyLlamaForCausalLM(cfg)
|
| 46 |
+
ckpt = torch.load(args.checkpoint, map_location='cpu')
|
| 47 |
+
state = ckpt['model'] if isinstance(ckpt, dict) and 'model' in ckpt else ckpt
|
| 48 |
+
model.load_state_dict(state, strict=False)
|
| 49 |
+
model.eval()
|
| 50 |
+
|
| 51 |
+
ids = tok.encode(make_prompt(args.prompt, args.system)).ids
|
| 52 |
+
prompt_len = len(ids)
|
| 53 |
+
end_id = tok.token_to_id('<|end|>')
|
| 54 |
+
for _ in range(args.max_new_tokens):
|
| 55 |
+
x = torch.tensor([ids[-cfg.max_position_embeddings:]], dtype=torch.long)
|
| 56 |
+
with torch.no_grad():
|
| 57 |
+
logits = model(x)['logits'][0, -1]
|
| 58 |
+
nxt = sample_next(logits, args.temperature, args.top_k)
|
| 59 |
+
ids.append(nxt)
|
| 60 |
+
if end_id is not None and nxt == end_id:
|
| 61 |
+
break
|
| 62 |
+
|
| 63 |
+
text = tok.decode(ids[prompt_len:])
|
| 64 |
+
for marker in ['<|end|>', '<|user|>', '<|assistant|>', '<|system|>']:
|
| 65 |
+
text = text.split(marker)[0]
|
| 66 |
+
print(text.strip())
|
| 67 |
+
|
| 68 |
+
|
| 69 |
+
if __name__ == '__main__':
|
| 70 |
+
main()
|
src/__init__.py
ADDED
|
File without changes
|
src/tinyllm/__init__.py
ADDED
|
@@ -0,0 +1 @@
| 1 |
+
__version__ = '0.1.0'
|
src/tinyllm/checkpoint.py
ADDED
|
@@ -0,0 +1,19 @@
| 1 |
+
from pathlib import Path
|
| 2 |
+
import torch, time, shutil
|
| 3 |
+
|
| 4 |
+
def save_checkpoint(path, model, optimizer=None, scheduler=None, step=0, epoch=0, metrics=None, config=None):
|
| 5 |
+
path=Path(path); path.parent.mkdir(parents=True, exist_ok=True)
|
| 6 |
+
tmp=path.with_suffix(path.suffix+'.tmp')
|
| 7 |
+
state={'model':model.state_dict(),'step':step,'epoch':epoch,'metrics':metrics or {},'saved_at':time.time(),'config':config}
|
| 8 |
+
if optimizer is not None: state['optimizer']=optimizer.state_dict()
|
| 9 |
+
if scheduler is not None: state['scheduler']=scheduler.state_dict()
|
| 10 |
+
torch.save(state,tmp); tmp.replace(path)
|
| 11 |
+
latest=path.parent/'latest.pt'
|
| 12 |
+
if latest != path: shutil.copy2(path, latest)
|
| 13 |
+
|
| 14 |
+
def load_checkpoint(path, model, optimizer=None, scheduler=None, map_location='cpu'):
|
| 15 |
+
state=torch.load(path,map_location=map_location)
|
| 16 |
+
model.load_state_dict(state['model'])
|
| 17 |
+
if optimizer is not None and 'optimizer' in state: optimizer.load_state_dict(state['optimizer'])
|
| 18 |
+
if scheduler is not None and 'scheduler' in state: scheduler.load_state_dict(state['scheduler'])
|
| 19 |
+
return state
|
src/tinyllm/config.py
ADDED
|
@@ -0,0 +1,31 @@
| 1 |
+
from dataclasses import dataclass
|
| 2 |
+
import yaml
|
| 3 |
+
from pathlib import Path
|
| 4 |
+
|
| 5 |
+
@dataclass
|
| 6 |
+
class TinyConfig:
|
| 7 |
+
vocab_size: int = 32000
|
| 8 |
+
hidden_size: int = 576
|
| 9 |
+
num_hidden_layers: int = 16
|
| 10 |
+
num_attention_heads: int = 9
|
| 11 |
+
num_key_value_heads: int = 3
|
| 12 |
+
intermediate_size: int = 1536
|
| 13 |
+
max_position_embeddings: int = 1024
|
| 14 |
+
rope_theta: float = 10000.0
|
| 15 |
+
rms_norm_eps: float = 1e-5
|
| 16 |
+
tie_word_embeddings: bool = True
|
| 17 |
+
attention_bias: bool = False
|
| 18 |
+
mlp_bias: bool = False
|
| 19 |
+
dropout: float = 0.0
|
| 20 |
+
bos_token_id: int = 1
|
| 21 |
+
eos_token_id: int = 2
|
| 22 |
+
|
| 23 |
+
@classmethod
|
| 24 |
+
def from_yaml(cls, path):
|
| 25 |
+
data = yaml.safe_load(Path(path).read_text())
|
| 26 |
+
if 'model' in data:
|
| 27 |
+
data = data['model']
|
| 28 |
+
return cls(**data)
|
| 29 |
+
|
| 30 |
+
def load_yaml(path):
|
| 31 |
+
return yaml.safe_load(Path(path).read_text())
|
src/tinyllm/data.py
ADDED
|
@@ -0,0 +1,20 @@
| 1 |
+
from pathlib import Path
|
| 2 |
+
import numpy as np
|
| 3 |
+
|
| 4 |
+
def save_tokens_bin(tokens, path, dtype=np.uint16):
|
| 5 |
+
path=Path(path).expanduser(); path.parent.mkdir(parents=True,exist_ok=True)
|
| 6 |
+
np.asarray(tokens,dtype=dtype).tofile(path)
|
| 7 |
+
|
| 8 |
+
def load_tokens_bin(path, dtype=np.uint16):
|
| 9 |
+
return np.memmap(Path(path).expanduser(), dtype=dtype, mode='r')
|
| 10 |
+
|
| 11 |
+
class TokenBlockDataset:
|
| 12 |
+
def __init__(self, bin_path, block_size):
|
| 13 |
+
self.data=load_tokens_bin(bin_path); self.block_size=block_size
|
| 14 |
+
def __len__(self): return max(0, len(self.data)-self.block_size-1)
|
| 15 |
+
def get_batch(self, batch_size, device):
|
| 16 |
+
import torch
|
| 17 |
+
ix=torch.randint(len(self),(batch_size,))
|
| 18 |
+
x=torch.stack([torch.from_numpy(np.array(self.data[i:i+self.block_size], dtype=np.int64)) for i in ix])
|
| 19 |
+
y=torch.stack([torch.from_numpy(np.array(self.data[i+1:i+1+self.block_size], dtype=np.int64)) for i in ix])
|
| 20 |
+
return x.to(device), y.to(device)
|
src/tinyllm/metrics.py
ADDED
|
@@ -0,0 +1,16 @@
| 1 |
+
import json, csv, time
|
| 2 |
+
from pathlib import Path
|
| 3 |
+
|
| 4 |
+
class MetricsLogger:
|
| 5 |
+
def __init__(self, path):
|
| 6 |
+
self.path=Path(path); self.path.parent.mkdir(parents=True,exist_ok=True)
|
| 7 |
+
def log(self, **kw):
|
| 8 |
+
kw.setdefault('time', time.time())
|
| 9 |
+
with self.path.open('a') as f: f.write(json.dumps(kw)+'\n')
|
| 10 |
+
|
| 11 |
+
def jsonl_to_csv(jsonl_path, csv_path):
|
| 12 |
+
rows=[json.loads(l) for l in Path(jsonl_path).read_text().splitlines() if l.strip()]
|
| 13 |
+
if not rows: return
|
| 14 |
+
keys=sorted(set().union(*(r.keys() for r in rows)))
|
| 15 |
+
with open(csv_path,'w',newline='') as f:
|
| 16 |
+
w=csv.DictWriter(f,fieldnames=keys); w.writeheader(); w.writerows(rows)
|
src/tinyllm/model.py
ADDED
|
@@ -0,0 +1,100 @@
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
import torch.nn.functional as F
|
| 4 |
+
from dataclasses import asdict
|
| 5 |
+
from .config import TinyConfig
|
| 6 |
+
|
| 7 |
+
class RMSNorm(nn.Module):
|
| 8 |
+
def __init__(self, dim, eps=1e-5):
|
| 9 |
+
super().__init__(); self.eps=eps; self.weight=nn.Parameter(torch.ones(dim))
|
| 10 |
+
def forward(self, x):
|
| 11 |
+
return self.weight * x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
|
| 12 |
+
|
| 13 |
+
def rotate_half(x):
|
| 14 |
+
x1, x2 = x[..., :x.shape[-1]//2], x[..., x.shape[-1]//2:]
|
| 15 |
+
return torch.cat((-x2, x1), dim=-1)
|
| 16 |
+
|
| 17 |
+
class RotaryEmbedding(nn.Module):
|
| 18 |
+
def __init__(self, dim, max_position_embeddings=2048, base=10000.0):
|
| 19 |
+
super().__init__()
|
| 20 |
+
inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
|
| 21 |
+
t = torch.arange(max_position_embeddings, dtype=torch.float)
|
| 22 |
+
freqs = torch.einsum('i,j->ij', t, inv_freq)
|
| 23 |
+
emb = torch.cat((freqs, freqs), dim=-1)
|
| 24 |
+
self.register_buffer('cos_cached', emb.cos()[None, None, :, :], persistent=False)
|
| 25 |
+
self.register_buffer('sin_cached', emb.sin()[None, None, :, :], persistent=False)
|
| 26 |
+
def forward(self, q, k, seq_len):
|
| 27 |
+
cos = self.cos_cached[:, :, :seq_len, :].to(q.device, q.dtype)
|
| 28 |
+
sin = self.sin_cached[:, :, :seq_len, :].to(q.device, q.dtype)
|
| 29 |
+
return (q*cos + rotate_half(q)*sin), (k*cos + rotate_half(k)*sin)
|
| 30 |
+
|
| 31 |
+
class CausalSelfAttention(nn.Module):
|
| 32 |
+
def __init__(self, cfg: TinyConfig):
|
| 33 |
+
super().__init__()
|
| 34 |
+
assert cfg.hidden_size % cfg.num_attention_heads == 0
|
| 35 |
+
assert cfg.num_attention_heads % cfg.num_key_value_heads == 0
|
| 36 |
+
self.nh=cfg.num_attention_heads; self.nkv=cfg.num_key_value_heads
|
| 37 |
+
self.hd=cfg.hidden_size//cfg.num_attention_heads
|
| 38 |
+
self.q_proj=nn.Linear(cfg.hidden_size, self.nh*self.hd, bias=cfg.attention_bias)
|
| 39 |
+
self.k_proj=nn.Linear(cfg.hidden_size, self.nkv*self.hd, bias=cfg.attention_bias)
|
| 40 |
+
self.v_proj=nn.Linear(cfg.hidden_size, self.nkv*self.hd, bias=cfg.attention_bias)
|
| 41 |
+
self.o_proj=nn.Linear(cfg.hidden_size, cfg.hidden_size, bias=cfg.attention_bias)
|
| 42 |
+
self.rotary=RotaryEmbedding(self.hd, cfg.max_position_embeddings, cfg.rope_theta)
|
| 43 |
+
self.dropout=cfg.dropout
|
| 44 |
+
def forward(self, x):
|
| 45 |
+
B,T,C=x.shape
|
| 46 |
+
q=self.q_proj(x).view(B,T,self.nh,self.hd).transpose(1,2)
|
| 47 |
+
k=self.k_proj(x).view(B,T,self.nkv,self.hd).transpose(1,2)
|
| 48 |
+
v=self.v_proj(x).view(B,T,self.nkv,self.hd).transpose(1,2)
|
| 49 |
+
q,k=self.rotary(q,k,T)
|
| 50 |
+
if self.nkv != self.nh:
|
| 51 |
+
repeat = self.nh // self.nkv
|
| 52 |
+
k = k.repeat_interleave(repeat, dim=1)
|
| 53 |
+
v = v.repeat_interleave(repeat, dim=1)
|
| 54 |
+
y=F.scaled_dot_product_attention(q,k,v,dropout_p=self.dropout if self.training else 0.0,is_causal=True)
|
| 55 |
+
y=y.transpose(1,2).contiguous().view(B,T,C)
|
| 56 |
+
return self.o_proj(y)
|
| 57 |
+
|
| 58 |
+
class SwiGLU(nn.Module):
|
| 59 |
+
def __init__(self, cfg: TinyConfig):
|
| 60 |
+
super().__init__()
|
| 61 |
+
self.gate_proj=nn.Linear(cfg.hidden_size,cfg.intermediate_size,bias=cfg.mlp_bias)
|
| 62 |
+
self.up_proj=nn.Linear(cfg.hidden_size,cfg.intermediate_size,bias=cfg.mlp_bias)
|
| 63 |
+
self.down_proj=nn.Linear(cfg.intermediate_size,cfg.hidden_size,bias=cfg.mlp_bias)
|
| 64 |
+
def forward(self,x):
|
| 65 |
+
return self.down_proj(F.silu(self.gate_proj(x))*self.up_proj(x))
|
| 66 |
+
|
| 67 |
+
class Block(nn.Module):
|
| 68 |
+
def __init__(self,cfg):
|
| 69 |
+
super().__init__(); self.input_norm=RMSNorm(cfg.hidden_size,cfg.rms_norm_eps); self.attn=CausalSelfAttention(cfg); self.post_norm=RMSNorm(cfg.hidden_size,cfg.rms_norm_eps); self.mlp=SwiGLU(cfg)
|
| 70 |
+
def forward(self,x):
|
| 71 |
+
x=x+self.attn(self.input_norm(x)); x=x+self.mlp(self.post_norm(x)); return x
|
| 72 |
+
|
| 73 |
+
class TinyLlamaForCausalLM(nn.Module):
|
| 74 |
+
def __init__(self,cfg:TinyConfig):
|
| 75 |
+
super().__init__(); self.config=cfg
|
| 76 |
+
self.embed_tokens=nn.Embedding(cfg.vocab_size,cfg.hidden_size)
|
| 77 |
+
self.layers=nn.ModuleList([Block(cfg) for _ in range(cfg.num_hidden_layers)])
|
| 78 |
+
self.norm=RMSNorm(cfg.hidden_size,cfg.rms_norm_eps)
|
| 79 |
+
self.lm_head=nn.Linear(cfg.hidden_size,cfg.vocab_size,bias=False)
|
| 80 |
+
if cfg.tie_word_embeddings:
|
| 81 |
+
self.lm_head.weight=self.embed_tokens.weight
|
| 82 |
+
self.apply(self._init_weights)
|
| 83 |
+
def _init_weights(self,m):
|
| 84 |
+
if isinstance(m, nn.Linear): nn.init.normal_(m.weight, mean=0.0, std=0.02)
|
| 85 |
+
if isinstance(m, nn.Embedding): nn.init.normal_(m.weight, mean=0.0, std=0.02)
|
| 86 |
+
def forward(self,input_ids,labels=None,loss_mask=None):
|
| 87 |
+
x=self.embed_tokens(input_ids)
|
| 88 |
+
for layer in self.layers: x=layer(x)
|
| 89 |
+
logits=self.lm_head(self.norm(x))
|
| 90 |
+
loss=None
|
| 91 |
+
if labels is not None:
|
| 92 |
+
shift_logits=logits[:,:-1,:].contiguous(); shift_labels=labels[:,1:].contiguous()
|
| 93 |
+
per=F.cross_entropy(shift_logits.view(-1,shift_logits.size(-1)),shift_labels.view(-1),reduction='none')
|
| 94 |
+
if loss_mask is not None:
|
| 95 |
+
mask=loss_mask[:,1:].contiguous().view(-1).float(); loss=(per*mask).sum()/mask.sum().clamp_min(1.0)
|
| 96 |
+
else: loss=per.mean()
|
| 97 |
+
return {'loss':loss,'logits':logits}
|
| 98 |
+
def num_parameters(self): return sum(p.numel() for p in self.parameters())
|
| 99 |
+
def save_config(self,path):
|
| 100 |
+
import json; open(path,'w').write(json.dumps(asdict(self.config),indent=2))
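As a cross-check on `num_parameters`, the parameter count of the architecture in `src/tinyllm/model.py` can be worked out by hand from the layer shapes. The sketch below does that arithmetic in plain Python. The function name `count_parameters` and the configuration values passed to it are assumptions, not read from `configs/model_75m.yaml`; they were chosen because they reproduce the `75,074,112` parameters stated in the inference report, so the real config should be consulted before relying on them.

```python
def count_parameters(vocab_size, hidden_size, num_layers, num_heads, num_kv_heads,
                     intermediate_size, tie_word_embeddings=True,
                     attention_bias=False, mlp_bias=False):
    """Count trainable parameters for the layer layout in src/tinyllm/model.py."""
    head_dim = hidden_size // num_heads
    kv_dim = num_kv_heads * head_dim

    # q_proj and o_proj are hidden x hidden; k_proj and v_proj are hidden x kv_dim.
    attn = 2 * hidden_size * hidden_size + 2 * hidden_size * kv_dim
    if attention_bias:
        attn += 2 * hidden_size + 2 * kv_dim

    # gate_proj, up_proj and down_proj each hold hidden x intermediate weights.
    mlp = 3 * hidden_size * intermediate_size
    if mlp_bias:
        mlp += 2 * intermediate_size + hidden_size

    norms = 2 * hidden_size                  # input_norm and post_norm in each Block
    per_layer = attn + mlp + norms

    embed = vocab_size * hidden_size
    # A tied lm_head shares the embedding matrix, so it adds nothing; lm_head has no bias.
    head = 0 if tie_word_embeddings else vocab_size * hidden_size
    final_norm = hidden_size
    return embed + num_layers * per_layer + final_norm + head

# Hypothetical configuration (an assumption, see the note above).
total = count_parameters(vocab_size=32000, hidden_size=576, num_layers=16,
                         num_heads=9, num_kv_heads=3, intermediate_size=1536)
print(f"{total:,}")
```

The rotary `cos_cached` and `sin_cached` tensors are registered as buffers, not parameters, so they do not enter the count.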
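`RotaryEmbedding` and `rotate_half` in `src/tinyllm/model.py` rotate each pair of channels `(i, i + dim/2)` by an angle proportional to the token position. A consequence is that the query-key dot product depends only on the distance between the two positions. The plain-Python sketch below mirrors that formulation on a single vector to make the property easy to check; the helper names `rope` and `dot` and the sample vectors are illustrative assumptions, not part of the repository.

```python
import math

def rope(vec, pos, base=10000.0):
    """Apply rotary position embedding to one vector at position `pos`.

    Mirrors the model code: angles are pos * inv_freq, duplicated across both
    halves (cat((freqs, freqs))), combined with rotate_half(x) = cat(-x2, x1).
    """
    dim = len(vec)
    half = dim // 2
    inv_freq = [1.0 / (base ** (2 * i / dim)) for i in range(half)]
    angles = [pos * f for f in inv_freq] * 2
    rotated = [-x for x in vec[half:]] + list(vec[:half])
    return [v * math.cos(a) + r * math.sin(a)
            for v, r, a in zip(vec, rotated, angles)]

def dot(a, b):
    return sum(x * y for x, y in zip(a, b))

# Arbitrary sample vectors (illustration only).
q = [0.5, -1.0, 2.0, 0.25]
k = [1.0, 0.5, -0.5, 2.0]

# Same relative offset (2) at two different absolute positions.
score_near = dot(rope(q, 3), rope(k, 1))
score_far = dot(rope(q, 10), rope(k, 8))
print(abs(score_near - score_far) < 1e-9)
```

Since each channel pair undergoes a pure rotation, the vector norm is also unchanged by `rope`, which is why the attention code can apply it to `q` and `k` without any rescaling.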
tokenizer/merges.txt
ADDED
The diff for this file is too large to render.
tokenizer/vocab.json
ADDED
The diff for this file is too large to render.