---
license: apache-2.0
language:
- en
pipeline_tag: text-generation
library_name: flax
tags:
- language-model
- tiny-model
- subword
- sentencepiece
- efficient
- llm
- tinystories
datasets:
- roneneldan/TinyStories
---

# 🌟 Tiny Stories Subword LLM: A ~10 MB Efficient Language Model

This is a **tiny, efficient, and fast** autoregressive language model trained on **10,000 TinyStories** using a custom **Selective Recurrent Layer (SRL)**, a linear-complexity alternative to Transformer attention, together with a **5k-token SentencePiece Unigram tokenizer** trained on the same dataset.

Unlike Transformers, whose attention scales as O(n²) in sequence length, this model runs in **O(n)** time and memory, making it well suited to edge devices, mobile apps, and other low-resource environments, while still generating **coherent, story-like text**.

✅ **Size**: ~10 MB
✅ **Vocabulary**: 5,000 subword tokens (SentencePiece Unigram)
✅ **Architecture**: 2-layer SRL with 256-dim hidden states
✅ **Loss**: ~2.42 after 20 epochs
✅ **Speed**: ~0.5 s per 100-token generation on CPU

---

## ✨ Usage Example

```python
import json

import jax
import jax.numpy as jnp
import numpy as np
import sentencepiece as spm
from flax import linen as nn
from flax import serialization
from jax import lax

# Load model config
with open("config.json") as f:
    config = json.load(f)

# Load SentencePiece tokenizer
tokenizer = spm.SentencePieceProcessor(model_file="tokenizer.model")

# Define model architecture (same as training)
class SelectiveRecurrentLayer(nn.Module):
    d_model: int
    d_state: int = 16
    dtype: jnp.dtype = jnp.float32

    @nn.compact
    def __call__(self, x):
        x = x.astype(self.dtype)
        B, L, D = x.shape

        # Learned, input-dependent state-space parameters
        A_log = self.param("A_log", nn.initializers.zeros, (D,))
        A = -jnp.exp(A_log.astype(self.dtype))
        delta = nn.Dense(D, dtype=self.dtype)(x)
        delta = jax.nn.softplus(delta)
        B_ssm = nn.Dense(self.d_state, dtype=self.dtype)(x)
        C_ssm = nn.Dense(self.d_state, dtype=self.dtype)(x)

        # Discretize the continuous-time system
        A_bar = jnp.exp(A * delta)
        inv_A = 1.0 / (-A)
        B_exp = B_ssm[:, :, :, None]
        A_exp = A_bar[:, :, None, :]
        x_exp = x[:, :, None, :]
        C_exp = C_ssm[:, :, :, None]
        B_bar = B_exp * ((1 - A_exp) * inv_A)

        # Scan along the time axis: an O(L) recurrence, not O(L²) attention
        inputs = (A_exp, B_bar, x_exp, C_exp)
        inputs = jax.tree.map(lambda t: t.transpose(1, 0, 2, 3), inputs)

        def ssm_op(carry, inp):
            A_curr, B_curr, x_curr, C_curr = inp
            state = A_curr * carry + B_curr * x_curr
            y = jnp.sum(C_curr * state, axis=1)
            return state, y

        init_state = jnp.zeros((B, self.d_state, D), dtype=self.dtype)
        _, y_seq = lax.scan(ssm_op, init_state, inputs)
        return y_seq.transpose(1, 0, 2)

class SubwordLLM(nn.Module):
    vocab_size: int
    d_model: int = 256
    n_layers: int = 2
    dtype: jnp.dtype = jnp.float32

    @nn.compact
    def __call__(self, input_ids):
        x = nn.Embed(self.vocab_size, self.d_model, dtype=jnp.float32)(input_ids)
        x = x.astype(self.dtype)
        for _ in range(self.n_layers):
            x = SelectiveRecurrentLayer(d_model=self.d_model, dtype=self.dtype)(x)
            x = nn.LayerNorm(dtype=self.dtype)(x)
        return nn.Dense(self.vocab_size, dtype=self.dtype)(x)

# Load weights
model = SubwordLLM(
    vocab_size=config["vocab_size"],
    d_model=config["d_model"],
    n_layers=config["n_layers"],
    dtype=jnp.dtype(config["dtype"]),
)
with open("flax_model.msgpack", "rb") as f:
    params = serialization.from_bytes(
        model.init(jax.random.key(0), jnp.ones((1, 128), dtype=jnp.int32)),
        f.read(),
    )

# Generation function
def generate(prompt, max_new_tokens=150, temperature=0.7,
             repetition_penalty=1.2, top_k=25):
    ids = tokenizer.encode(prompt)
    ids = [i for i in ids if i not in (tokenizer.pad_id(), tokenizer.eos_id())]
    generated = ids.copy()
    input_ids = jnp.array([generated], dtype=jnp.int32)
    for _ in range(max_new_tokens):
        logits = model.apply(params, input_ids)
        next_token_logits = logits[0, -1, :]

        # Repetition penalty (CTRL-style): divide positive logits and
        # multiply negative ones, so repeats are always made less likely.
        for tok in set(generated):
            val = next_token_logits[tok]
            next_token_logits = next_token_logits.at[tok].set(
                jnp.where(val > 0, val / repetition_penalty,
                          val * repetition_penalty)
            )

        # Top-k filtering: keep the k largest logits, mask out the rest
        if top_k > 0:
            top_k_vals, top_k_idx = jax.lax.top_k(
                next_token_logits, min(top_k, len(next_token_logits))
            )
            mask = jnp.full_like(next_token_logits, -1e10)
            next_token_logits = mask.at[top_k_idx].set(top_k_vals)

        # Temperature scaling, then sample the next token
        next_token_logits /= temperature
        key = jax.random.key(np.random.randint(0, 2**31 - 1))
        next_token = int(jax.random.categorical(key, next_token_logits))
        if next_token == tokenizer.eos_id():
            break
        generated.append(next_token)
        input_ids = jnp.array([generated], dtype=jnp.int32)
    return tokenizer.decode(generated)

# Generate!
print(generate("once upon a time"))
```

### 📝 Sample Output:

```
once upon a time, there was a little girl named Lily. She loved to play in the park. One day, she found a shiny rock. She showed it to her mom, who smiled and said, "That's magic!" Lily put it in her pocket and ran home. That night, the rock glowed under her pillow. She dreamed of dragons and stars, and woke up with a new friend beside her.
```
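The loop above re-runs an uncompiled forward pass at every step. A quick way to speed it up is to JIT-compile `model.apply`. Below is a minimal sketch, assuming the `model` and `params` objects from the example above; `padded_logits`, the `max_len` default, and `pad_id=0` are illustrative choices, not part of this repo. The input is right-padded to a fixed length because `jax.jit` recompiles whenever the input shape changes, and the SRL scan is strictly causal, so padding on the right cannot affect earlier positions:

```python
import jax
import jax.numpy as jnp

# Compile once; later calls with the same input shape reuse the compiled
# function. (`model` and `params` come from the usage example above.)
apply_jit = jax.jit(model.apply)

def padded_logits(ids, max_len=128, pad_id=0):
    # Right-pad to a fixed length so the compiled shape never changes.
    # `pad_id=0` is an assumption; use tokenizer.pad_id() if it differs.
    padded = jnp.full((1, max_len), pad_id, dtype=jnp.int32)
    padded = padded.at[0, :len(ids)].set(jnp.array(ids, dtype=jnp.int32))
    logits = apply_jit(params, padded)
    return logits[0, len(ids) - 1, :]  # next-token logits for the prompt
```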
---

## 🏗️ Model Architecture

- **No attention!** Uses a **Selective State Space Model (SSM)** with linear complexity.
- **Input**: Subword tokens (SentencePiece Unigram, 5k vocab)
- **Hidden layers**: 2 × SelectiveRecurrentLayer (256-dim)
- **Memory**: O(n), not O(n²), which keeps long contexts cheap
- **Training**: 10k TinyStories, 20 epochs, batch size 32

---

## 📚 Training Details

| Item | Value |
|------|-------|
| Dataset | `roneneldan/TinyStories` (10,000 samples) |
| Tokenizer | SentencePiece Unigram (vocab = 5,000) |
| Epochs | 20 |
| Final loss | ~2.42 |
| Max sequence length | 128 tokens |
| Optimizer | AdamW + cosine decay |
| Hardware | 2 × T4 GPU |
| Training time | ~15 minutes |

---
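The optimizer row above corresponds to a standard `optax` setup. Here is a minimal sketch of that configuration; the peak learning rate, decay horizon, and weight decay are illustrative assumptions, not the values used to train this checkpoint:

```python
import optax

# AdamW with a cosine learning-rate decay, as listed in the table above.
# All hyperparameter values here are illustrative assumptions.
schedule = optax.cosine_decay_schedule(init_value=3e-4, decay_steps=10_000)
optimizer = optax.adamw(learning_rate=schedule, weight_decay=0.01)

# `params` is the Flax parameter pytree from the usage example.
opt_state = optimizer.init(params)
```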