---
language:
- en
- fr
- es
- zh
- hi
- ja
- ru
license: apache-2.0
library_name: flax
tags:
- jax
- flax
- tpu
- text-generation
- base-model
- custom-architecture
datasets:
- HuggingFaceFW/fineweb-edu
- bigcode/starcoderdata
- HuggingFaceFW/fineweb-2
- open-web-math/open-web-math
metrics:
- loss
- perplexity
pipeline_tag: text-generation
inference: false
---

# Zenyx-Base-220M: High-Density Nano Foundation Model

**Zenyx-Base-220M** is a 220 million parameter causal language model built from scratch using JAX/Flax on a Kaggle TPU v5e-8.

Unlike typical small models trained on limited data, Zenyx-Base was trained on **~153 billion tokens**, far exceeding the Chinchilla-optimal token budget for this parameter count. This "over-training" strategy was employed to maximize the information density and logical capability of the weights, creating a robust foundation for reasoning tasks.
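
To put that in perspective, the back-of-the-envelope sketch below uses the common approximation of ~20 training tokens per parameter from the Chinchilla scaling laws (an assumption, not a figure from this model's training logs) to estimate the over-training factor:

```python
# Rough estimate only: assumes the ~20 tokens-per-parameter Chinchilla heuristic.
params = 220e6                       # 220M parameters
chinchilla_tokens = 20 * params      # ~4.4B tokens would be compute-optimal
trained_tokens = 153e9               # ~153B tokens actually seen in training

print(f"Chinchilla-optimal budget: ~{chinchilla_tokens / 1e9:.1f}B tokens")
print(f"Actual training budget:    ~{trained_tokens / 1e9:.0f}B tokens")
print(f"Over-training factor:      ~{trained_tokens / chinchilla_tokens:.0f}x")
```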

## 🧠 Model Description

* **Architecture:** Custom Llama-style Transformer (RoPE, SwiGLU, RMSNorm, Grouped Query Attention; see the GQA sketch after this list).
* **Tokenizer:** Qwen 2.5 tokenizer (151,646-token vocabulary) for high compression efficiency.
* **Context Window:** 2048 tokens.
* **Training Hardware:** TPU v5e-8.
* **Final Validation Loss:** **~2.38** (strong convergence for a 220M-parameter model).
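
The Grouped Query Attention layout (12 query heads sharing 4 key/value heads, per the specifications below) can be pictured with a tiny, illustrative snippet; the tensor layout here is an assumption for demonstration, not the model's actual attention code:

```python
import jax.numpy as jnp

# Illustrative GQA head sharing: 12 query heads, 4 KV heads,
# so each KV head is broadcast to 12 // 4 = 3 query heads.
# head_dim = hidden_dim // num_heads = 768 // 12 = 64.
num_heads, num_kv_heads, head_dim, seq_len = 12, 4, 64, 8
k = jnp.ones((1, seq_len, num_kv_heads, head_dim))      # (batch, seq, kv_heads, head_dim)
k = jnp.repeat(k, num_heads // num_kv_heads, axis=2)    # -> (1, 8, 12, 64), aligned with queries
print(k.shape)
```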

### Technical Specifications

| Hyperparameter | Value |
| :--- | :--- |
| **Layers** | 12 |
| **Hidden Dim** | 768 |
| **MLP Dim** | 3072 |
| **Attention Heads** | 12 |
| **KV Heads** | 4 (GQA) |
| **Vocab Size** | 151,646 |
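
As a companion to the table, here is a minimal, illustrative Flax sketch of the SwiGLU feed-forward block implied by the Hidden Dim and MLP Dim values; the layer names and exact structure are assumptions, not the model's actual source code:

```python
import jax
import jax.numpy as jnp
import flax.linen as nn

class SwiGLUMLP(nn.Module):
    """Illustrative Llama-style SwiGLU block; not the official Zenyx implementation."""
    hidden_dim: int = 768   # "Hidden Dim" above
    mlp_dim: int = 3072     # "MLP Dim" above

    @nn.compact
    def __call__(self, x):
        gate = nn.Dense(self.mlp_dim, use_bias=False, name="gate_proj")(x)
        up = nn.Dense(self.mlp_dim, use_bias=False, name="up_proj")(x)
        # SwiGLU: SiLU-gated linear unit, projected back down to the hidden size.
        return nn.Dense(self.hidden_dim, use_bias=False, name="down_proj")(jax.nn.silu(gate) * up)

# Quick shape check: (batch=1, seq=4, hidden=768) in -> same shape out.
x = jnp.ones((1, 4, 768))
variables = SwiGLUMLP().init(jax.random.PRNGKey(0), x)
print(SwiGLUMLP().apply(variables, x).shape)  # (1, 4, 768)
```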

## 📚 Training Curriculum (The "Omni-Mix")

The model was trained using a rigorous 4-stage curriculum designed to layer capabilities sequentially:

1. **Phase 1: Fundamentals (FineWeb-Edu)**
   * Focus on high-quality educational English text to establish linguistic baselines.
2. **Phase 2: Logic & Structure (StarCoder - Python)**
   * Introduction of code data to enforce logical indentation, syntax, and structured thinking.
3. **Phase 3: Multilingualism (FineWeb-2)**
   * Exposure to six major languages (Hindi, Chinese, Russian, Japanese, French, Spanish) to expand the semantic embedding space.
4. **Phase 4: The Infinite Polish (Omni-Mix)**
   * A weighted interleaving of all previous datasets plus **OpenWebMath** to converge the model's logic and language capabilities (see the sketch after this list).
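
As a rough illustration of what the Phase 4 "weighted interleaving" can look like in practice, here is a sketch built on 🤗 `datasets.interleave_datasets`; the sampling probabilities and dataset configurations are invented for the example and are not the actual Omni-Mix recipe:

```python
from datasets import load_dataset, interleave_datasets

# Streaming loads so nothing is fully downloaded; config names are examples only.
fineweb = load_dataset("HuggingFaceFW/fineweb-edu", split="train", streaming=True)
code = load_dataset("bigcode/starcoderdata", data_dir="python", split="train", streaming=True)
multi = load_dataset("HuggingFaceFW/fineweb-2", name="fra_Latn", split="train", streaming=True)
math = load_dataset("open-web-math/open-web-math", split="train", streaming=True)

# Hypothetical mixture weights, chosen purely for illustration.
omni_mix = interleave_datasets(
    [fineweb, code, multi, math],
    probabilities=[0.4, 0.25, 0.2, 0.15],
    seed=42,
)

for example in omni_mix.take(3):
    print(list(example.keys()))
```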

## 💻 Usage

This model is a raw **JAX/Flax** checkpoint saved in `.safetensors` format. It uses a custom architecture definition and requires `jax` and `flax` to run.

### Loading with JAX/Flax

```python
import jax
import jax.numpy as jnp
import flax.linen as nn
from huggingface_hub import hf_hub_download
from safetensors.flax import load_file
from transformers import AutoTokenizer

# 1. Define the architecture (must match the training config)
class TransformerLM(nn.Module):
    vocab_size: int
    embed_dim: int = 768
    num_layers: int = 12
    num_heads: int = 12
    num_kv_heads: int = 4
    mlp_dim: int = 3072
    max_length: int = 2048
    dropout_rate: float = 0.0

    # ... (insert the full model class definition here from the training script) ...

# 2. Load resources
repo_id = "Arko007/Zenyx_Base_220M"
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct", trust_remote_code=True)

# 3. Initialize the module and load the weights
model = TransformerLM(vocab_size=len(tokenizer))
dummy_input = jnp.ones((1, 1), dtype=jnp.int32)
params = model.init(jax.random.PRNGKey(0), dummy_input)["params"]

# Download the safetensors checkpoint from the Hub and load it
weights_path = hf_hub_download(repo_id=repo_id, filename="model.safetensors")
loaded_params = load_file(weights_path)
print("Weights loaded successfully!")
```
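
Continuing from the loading snippet above, once the full model definition is filled in, a minimal greedy-decoding loop might look like the sketch below. It assumes the module's `__call__` returns logits of shape `(batch, seq_len, vocab_size)` and that the restored weights have been re-nested into the same structure as `params`; both are assumptions about the training script, not guarantees:

```python
def greedy_generate(model, params, tokenizer, prompt, max_new_tokens=32):
    """Toy greedy decoding loop; assumes model.apply returns per-position logits."""
    ids = jnp.asarray(tokenizer(prompt, return_tensors="np")["input_ids"], dtype=jnp.int32)
    for _ in range(max_new_tokens):
        logits = model.apply({"params": params}, ids)        # (1, seq_len, vocab_size)
        next_id = jnp.argmax(logits[:, -1, :], axis=-1)      # most likely next token
        ids = jnp.concatenate([ids, next_id[:, None]], axis=-1)
        if int(next_id[0]) == tokenizer.eos_token_id:
            break
    return tokenizer.decode(ids[0], skip_special_tokens=True)

# Example call (requires the complete TransformerLM and restored weights in `params`):
# print(greedy_generate(model, params, tokenizer, "The capital of France is"))
```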

## ⚠️ Limitations

- **Size:** At 220M parameters, the model's knowledge retrieval capacity is limited compared to 7B+ models.
- **Base Model:** This is a pre-trained base. It has not been fine-tuned for chat or instruction following (see Zenyx-DeepSeek-220M for the instruct version).
- **Hallucinations:** While logically consistent, it may generate factually incorrect statements.

## 📜 Citation

```bibtex
@misc{ZenyxBase220M,
  title     = {Zenyx-Base-220M: High-Density Foundation Model},
  author    = {Arko007},
  year      = {2025},
  publisher = {HuggingFace},
  url       = {https://huggingface.co/Arko007/Zenyx_Base_220M}
}
```