Arko007 committed on
Commit bebb3b0 · verified · 1 Parent(s): 8c23bf8

Update README.md

Files changed (1):
  1. README.md +104 -175

README.md CHANGED
@@ -1,206 +1,135 @@
  ---
  license: apache-2.0
- language: en
  datasets:
- - open-web-math/open-web-math
- - bigcode/starcoderdata
- - HuggingFaceFW/fineweb
  - HuggingFaceFW/fineweb-edu
  - HuggingFaceFW/fineweb-2
  metrics:
  - perplexity
- - accuracy
- - precision
  pipeline_tag: text-generation
- library_name: transformers
- tags:
- - text_generaion
- - next_token_prediction
- - Zenyx
- - transformers
- - modern
- ---
-
- # finetuned-fineweb-model
-
- > Short summary:
- > - Architecture: Custom GPT-like Transformer (Flax/Linen)
- > - Tokenizer: Qwen/Qwen2-0.5B-Instruct
- > - Latest saved checkpoint: step 146500
- > - License: Apache-2.0
-
- ---
-
- ## Model Card
-
- - Model name: Zenyx-Base-220M
- - Repository: Arko007/Zenyx-Base-220M
-
- Purpose
- - Final polishing run of a fineweb pretraining/fine-tuning pipeline using an "Infinite Omni Mix" of web / code / multilingual / math datasets sampled proportionally.
- - Intended for research / downstream fine-tuning or evaluation where a compact, well-polished Flax causal LM is desirable.
-
- Caveats and recommendations
- - This model was trained/continued on TPU v5e-8 hardware using Flax/JAX. Loading and inference on CPU/GPU is possible but may require converting weights or using the Flax runtime.
-
  ---
 
- ## Model architecture (specs)
-
- These are the properties of the exact model that was trained in the provided training session:
-
- - Model type: Causal Transformer Language Model (Flax / Linen implementation)
- - Tokenizer: Qwen/Qwen2-0.5B-Instruct (uses the tokenizer from Hugging Face; pad token added if missing)
- - Vocab size (configured): 151,646 (the script sets VOCAB_SIZE initially; the effective vocab is taken from the tokenizer at runtime)
- - Context length (block size / max sequence length): 2048 tokens
- - Number of Transformer layers: 12
- - Embedding dimension: 768
- - Number of attention heads: 12
- - Head dimension: 64 (embed_dim / num_heads)
- - MLP hidden dim: 3072
- - Number of KV heads: 4 (grouped-query attention)
- - Rotary embeddings: RoPE caching per-head-dim
- - Normalization: RMSNorm
- - Activation in FFN: SwiGLU (SiLU gating)
- - Dropout: 0.1 (training)
- - Approximate total parameters: ~140M (order-of-magnitude estimate; exact parameter count computed at init in training logs)
-
- Notes:
- - The implementation uses grouped-query attention with num_kv_heads=4 (k/v maps are shared / repeated to match q heads).
- - The output head is tied to the token embedding (logits computed with embedding matrix transpose).
 
- ---
 
- ## Training configuration
 
- This README documents the exact training settings used when the last checkpoint (step 146500) was saved.
 
- General
- - Seed: 42
- - Hardware used: TPU v5e-8
 
- Batching
- - Micro-batch size (per step): 16
- - Gradient accumulation steps: 32
- - Global batch size: 512 (MICRO_BATCH_SIZE * GRADIENT_ACCUM_STEPS)
- - Per-core batch (MICRO_BATCH_SIZE // TPU_CORES): derived at runtime from detected TPU_CORES
 
- Optimizer
- - Optimizer: AdamW (Optax)
- - Learning rate: 1e-6 (constant schedule — final polish)
- - Beta1 / Beta2: 0.9 / 0.95
- - Epsilon: 1e-8
- - Weight decay: 0.1
- - Gradient clipping: clip_by_global_norm(1.0)
 
- Training schedule & safeguards
- - Max training steps: 150,000
- - Validation set size (blocks): 2048 blocks (each block = 2048 tokens)
- - Safety checks include train/val gap threshold, min perplexity, and gradient norms (warnings; training continues)
 
- Checkpointing
- - Checkpoints saved to: `checkpoints/state_step{STEP}.msgpack` and `training_metadata.json` in the model repo
- - The training script uploads checkpoints and metadata to the Hugging Face Hub (requires authentication)
 
- Latest checkpoint
- - EVAL Step: 146500
- - Validation Loss: 2.3880
- - Validation Perplexity (PPL): 10.9
 
- ---
-
- ## Datasets used for mixing
-
- Primary datasets used
- - HuggingFaceFW/fineweb-edu (name: sample-350BT) — training split (text)
- - HuggingFaceFW/fineweb (name: sample-350BT) — training split (text)
- - bigcode/starcoderdata (python) — training split (column: content -> renamed to text)
- - open-web-math/open-web-math — training split (text)
- - HuggingFaceFW/fineweb-2 (multilingual subsets):
-   - hin_Deva (Hindi Devanagari)
-   - cmn_Hani (Chinese Han)
-   - rus_Cyrl (Russian Cyrillic)
-   - jpn_Jpan (Japanese)
-   - fra_Latn (French)
-   - spa_Latn (Spanish)
-
- Mix ratios (as used in the script for the main interleave)
- - edu: 16.5%
- - raw web: 16.5%
- - code (StarCoder python): 33%
- - multilingual mix: 24% (interleaved across listed languages)
- - math: 10%
-
- ---
 
- ## Evaluation summary
 
- - Validation configuration: 2048 validation blocks (saved to /tmp/finetune_fineweb/val_set_final_v2.npy during run)
- - Latest evaluation (EVAL Step 146500):
-   - Mean validation loss: 2.3880
-   - Validation perplexity: 10.9
- - Model was checkpointed at this step and uploaded to the hub (per training script behavior)
-
- ---
 
- ## How to use / inference
 
- Note: this model is implemented in Flax. The easiest way to run inference is to use the provided tokenizer and a compatible model loader, or to use the training script's loading utilities if you keep the same Python package layout.
 
- 1) Install prerequisites (example):
- ```bash
- pip install transformers datasets flax jax[tpu]  # or jaxlib for CPU/GPU
- pip install huggingface_hub
- ```
-
- 2) Authenticate (if the repo is private):
- ```bash
- huggingface-cli login
- # or
- export HF_TOKEN="your_token_here"
- ```
-
- 3) Example (recommended approach — load tokenizer and run a simple tokenization + scoring):
  ```python
  from transformers import AutoTokenizer
- from huggingface_hub import hf_hub_download
-
- # Load tokenizer from Qwen
- tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-0.5B-Instruct", trust_remote_code=True)
-
- # Download a particular checkpoint file from this model repo (example)
- # NOTE: adjust the filename to match what is in the model repo (e.g. checkpoints/state_step146500.msgpack)
- # hf_hub_download(repo_id="Arko007/finetuned-fineweb-model", filename="checkpoints/state_step146500.msgpack", repo_type="model")
-
- # If you provide or write a Flax model loader (matching the training class), you can then:
- # - initialize the same model class
- # - load checkpoint bytes with flax.serialization.from_bytes on TrainState (see training script)
- # - run model.apply to get logits and decode outputs
  ```
 
- Because this model uses a custom Flax class defined in the training script, to perform autoregressive generation you should reuse the same model class definition (TransformerLM, attention wrapper, etc.) and load the checkpoint with flax.serialization.from_bytes exactly as the training script does.
-
- If you prefer PyTorch inference, convert Flax weights to PyTorch using the Hugging Face conversion utilities (if appropriate and if the repository includes consistent naming/structure). Care is required to ensure parameter name mappings match.
-
- ---
-
- ## Files of interest in the repo
 
- - training_metadata.json — metadata uploaded with checkpoints (step, best_val_loss, no_improve, train_losses).
- - checkpoints/state_step{STEP}.msgpack — Flax serialized train state (to be loaded with flax.serialization).
 
- ---
-
- ## Licensing and citation
-
- This model and the files in this repository are licensed under the Apache License 2.0.
-
- If you use this model in academic work or production, please cite this model and include a short description of the training data and settings used in your documentation (see sections above).
-
- ---
-
- ## Contact / Maintainers
-
- - Maintainer: Account that owns the model repo on the Hugging Face Hub (Arko007)
- - For questions about training, dataset choices, or how to reuse the training script, open an issue on the model repo or contact the maintainer directly via Hugging Face profile.
-
- ---
  ---
+ language:
+ - en
+ - fr
+ - es
+ - zh
+ - hi
+ - ja
+ - ru
  license: apache-2.0
+ library_name: flax
+ tags:
+ - jax
+ - flax
+ - tpu
+ - text-generation
+ - base-model
+ - custom-architecture
  datasets:
  - HuggingFaceFW/fineweb-edu
+ - bigcode/starcoderdata
  - HuggingFaceFW/fineweb-2
+ - open-web-math/open-web-math
  metrics:
+ - loss
  - perplexity
  pipeline_tag: text-generation
+ inference: false
  ---
 
+ # Zenyx-Base-220M: High-Density Nano Foundation Model
 
+ <div align="center">
 
+ ![Model Architecture](https://img.shields.io/badge/Model-Zenyx_Base-blue?style=for-the-badge)
+ ![Parameter Count](https://img.shields.io/badge/Params-220M-orange?style=for-the-badge)
+ ![Training Tokens](https://img.shields.io/badge/Tokens-153B-green?style=for-the-badge)
+ ![Format](https://img.shields.io/badge/Weights-Safetensors-yellow?style=for-the-badge)
 
+ </div>
 
+ **Zenyx-Base-220M** is a 220 million parameter causal language model built from scratch with JAX/Flax on a Kaggle TPU v5e-8.
 
+ Unlike typical small models trained on limited data, Zenyx-Base was trained on **~153 billion tokens**, far beyond the Chinchilla-optimal budget for this parameter count. This "over-training" strategy was chosen to maximize the information density and logic capability of the weights, creating a robust foundation for reasoning tasks.
+ ## 🧠 Model Description
 
+ * **Architecture:** Custom Llama-style Transformer (RoPE, SwiGLU, RMSNorm, grouped-query attention).
+ * **Tokenizer:** Qwen 2.5 tokenizer (151,646 effective vocab size) for high compression efficiency.
+ * **Context Window:** 2048 tokens.
+ * **Training Hardware:** TPU v5e-8.
+ * **Final Validation Loss:** **~2.38** (validation perplexity ≈ 10.9; strong convergence for 220M).
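With grouped-query attention, the 12 query heads share only 4 K/V heads, so each K/V head is broadcast to 3 query heads before the attention product. A minimal NumPy shape sketch of that sharing (illustrative only; not the repo's actual code):

```python
import numpy as np

# Shapes from the spec: 12 query heads, 4 KV heads, head_dim = 768 // 12 = 64.
seq_len, num_heads, num_kv_heads, head_dim = 8, 12, 4, 64

q = np.zeros((seq_len, num_heads, head_dim))
k = np.zeros((seq_len, num_kv_heads, head_dim))

# Repeat each KV head for the num_heads // num_kv_heads = 3 query heads it serves.
k_expanded = np.repeat(k, num_heads // num_kv_heads, axis=1)

print(q.shape, k_expanded.shape)  # (8, 12, 64) (8, 12, 64)
```

The same repeat applies to `v`; the memory win is that only the 4-head `k`/`v` need to be cached.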
 
+ ### Technical Specifications
 
+ | Hyperparameter | Value |
+ | :--- | :--- |
+ | **Layers** | 12 |
+ | **Hidden Dim** | 768 |
+ | **MLP Dim** | 3072 |
+ | **Attention Heads** | 12 |
+ | **KV Heads** | 4 (GQA) |
+ | **Vocab Size** | 151,646 |
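As a sanity check, these hyperparameters reproduce the advertised 220M total, assuming the tied embedding/output head and a three-matrix SwiGLU MLP as described elsewhere in this card (norm and bias terms are negligible and omitted):

```python
vocab, d_model, layers, heads, kv_heads, d_mlp = 151_646, 768, 12, 12, 4, 3072
head_dim = d_model // heads  # 64

embed = vocab * d_model  # token embedding, tied with the output head
attn = 2 * d_model * d_model + 2 * d_model * (kv_heads * head_dim)  # Wq, Wo + Wk, Wv (GQA)
swiglu = 3 * d_model * d_mlp  # gate, up, and down projections
total = embed + layers * (attn + swiglu)

print(f"~{total / 1e6:.0f}M parameters")  # ~220M parameters
```

More than half of the budget sits in the embedding table, which is typical for a sub-1B model paired with a 151k-entry vocabulary.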
 
+ ## 📚 Training Curriculum (The "Omni-Mix")
 
+ The model was trained using a rigorous 4-stage curriculum designed to layer capabilities sequentially:
 
+ 1. **Phase 1: Fundamentals (FineWeb-Edu)**
+    * Focus on high-quality educational English text to establish linguistic baselines.
+ 2. **Phase 2: Logic & Structure (StarCoder - Python)**
+    * Introduction of code data to enforce logical indentation, syntax, and structured thinking.
+ 3. **Phase 3: Multilingualism (FineWeb-2)**
+    * Exposure to six major languages (Hindi, Chinese, Russian, Japanese, French, Spanish) to expand the semantic embedding space.
+ 4. **Phase 4: The Infinite Polish (Omni-Mix)**
+    * A weighted interleaving of all previous datasets plus **OpenWebMath** to converge the model's logic and language capabilities.
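The Phase-4 mix can be pictured as a weighted sampler over the component streams. The weights below come from the previous revision of this card (edu 16.5%, raw web 16.5%, code 33%, multilingual 24%, math 10%); in the real pipeline `datasets.interleave_datasets(..., probabilities=...)` plays this role, but a stdlib sketch shows the idea:

```python
import itertools
import random

def omni_mix(streams, probs, n_samples, seed=42):
    """Draw n_samples examples, choosing the source stream by weight at each step."""
    rng = random.Random(seed)
    names = list(streams)
    weights = [probs[name] for name in names]
    return [next(streams[rng.choices(names, weights=weights)[0]])
            for _ in range(n_samples)]

# Stand-in endless streams; the real run iterates tokenized HF datasets instead.
sources = ["edu", "web", "code", "multilingual", "math"]
streams = {name: itertools.cycle([name]) for name in sources}
probs = {"edu": 0.165, "web": 0.165, "code": 0.33, "multilingual": 0.24, "math": 0.10}

batch = omni_mix(streams, probs, 1000)
print(round(batch.count("code") / len(batch), 2))  # close to 0.33
```

Sampling per step (rather than fixed-size shards) keeps every batch a fresh draw from the full mix, which is what the "infinite" in the phase name refers to.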
 
+ ## 💻 Usage
 
+ This model is a raw **JAX/Flax** checkpoint saved in `.safetensors` format. It uses a custom architecture definition and requires `flax` and `jax` to run.
 
+ ### Loading with JAX/Flax
 
  ```python
+ import jax
+ import jax.numpy as jnp
+ from safetensors.flax import load_file
  from transformers import AutoTokenizer
+ import flax.linen as nn
+
+ # 1. Define the architecture (must match the training config)
+ class TransformerLM(nn.Module):
+     vocab_size: int
+     embed_dim: int = 768
+     num_layers: int = 12
+     num_heads: int = 12
+     num_kv_heads: int = 4
+     mlp_dim: int = 3072
+     max_length: int = 2048
+     dropout_rate: float = 0.0
+
+     # ... (Insert full model class definition here from the training script) ...
+
+ # 2. Load resources
+ repo_id = "Arko007/Zenyx_Base_220M"
+ tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct", trust_remote_code=True)
+
+ # 3. Initialize the model and load the weights
+ model = TransformerLM(vocab_size=len(tokenizer))
+ dummy_input = jnp.ones((1, 1), dtype=jnp.int32)
+ params = model.init(jax.random.PRNGKey(0), dummy_input)['params']
+
+ # Load safetensors (ensure model.safetensors has been downloaded locally first)
+ loaded_params = load_file("model.safetensors")
+ print("Weights loaded successfully!")
  ```
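With the full class definition restored, generation is an ordinary greedy loop over the forward pass: feed the ids, take the argmax of the final position's logits, append, repeat. The sketch below is framework-neutral, with a stub standing in for the real `model.apply` call (the stub, names, and shapes are illustrative, not this repo's code):

```python
import numpy as np

def greedy_generate(logits_fn, prompt_ids, max_new_tokens, eos_id=None):
    """logits_fn maps a list of token ids to a (seq_len, vocab) logits array."""
    ids = list(prompt_ids)
    for _ in range(max_new_tokens):
        next_id = int(np.argmax(logits_fn(ids)[-1]))  # greedy pick at the last position
        ids.append(next_id)
        if next_id == eos_id:
            break
    return ids

# Stub forward pass: deterministically prefers token (last_id + 1) % vocab.
def stub_logits(ids, vocab=16):
    logits = np.zeros((len(ids), vocab))
    logits[-1, (ids[-1] + 1) % vocab] = 1.0
    return logits

print(greedy_generate(stub_logits, [3], max_new_tokens=4))  # [3, 4, 5, 6, 7]
```

For the real model you would wrap the Flax call, e.g. `logits_fn = lambda ids: model.apply({'params': loaded_params}, jnp.array([ids]))[0]`; the exact signature depends on the training script's `__call__`.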
 
+ ## ⚠️ Limitations
 
+ - **Size:** At 220M parameters, the model's knowledge retrieval capacity is limited compared to 7B+ models.
+ - **Base model:** This is a pre-trained base; it has not been fine-tuned for chat or instruction following (see Zenyx-DeepSeek-220M for the instruct version).
+ - **Hallucinations:** While logically consistent, it may generate factually incorrect statements.
 
+ ## 📜 Citation
 
+ ```bibtex
+ @misc{ZenyxBase220M,
+   title     = {Zenyx-Base-220M: High-Density Foundation Model},
+   author    = {Arko007},
+   year      = {2025},
+   publisher = {Hugging Face},
+   url       = {https://huggingface.co/Arko007/Zenyx_Base_220M}
+ }
+ ```