Update README.md
Browse files
README.md
CHANGED
|
@@ -1,3 +1,179 @@
|
|
| 1 |
-
---
|
| 2 |
-
license:
|
| 3 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
---
|
| 2 |
+
license: apache-2.0
|
| 3 |
+
language:
|
| 4 |
+
- en
|
| 5 |
+
pipeline_tag: text-generation
|
| 6 |
+
library_name: flax
|
| 7 |
+
tags:
|
| 8 |
+
- language-model
|
| 9 |
+
- tiny-model
|
| 10 |
+
- subword
|
| 11 |
+
- sentencepiece
|
| 12 |
+
- efficient
|
| 13 |
+
- llm
|
| 14 |
+
- tinystories
|
| 15 |
+
datasets:
|
| 16 |
+
- roneneldan/TinyStories
|
| 17 |
+
---
|
| 18 |
+
|
| 19 |
+
# 🧠 Tiny Stories Subword LLM — A 3MB Efficient Language Model
|
| 20 |
+
|
| 21 |
+
This is a **tiny, efficient, and fast** autoregressive language model trained on **10,000 TinyStories** using a custom **Selective Recurrent Layer (SRL)** — a linear-complexity alternative to Transformers — and a **5k-token SentencePiece Unigram tokenizer** trained on the same dataset.
|
| 22 |
+
|
| 23 |
+
Unlike Transformers (O(n²)), this model runs in **O(n)** time and memory, making it ideal for edge devices, mobile apps, or low-resource environments — while still generating **coherent, story-like text**.
|
| 24 |
+
|
| 25 |
+
✅ **Size**: ~10 MB
|
| 26 |
+
✅ **Vocabulary**: 5,000 subword tokens (SentencePiece)
|
| 27 |
+
✅ **Architecture**: 2-layer SRL with 256-dim hidden states
|
| 28 |
+
✅ **Loss**: ~2.42 after 20 epochs
|
| 29 |
+
✅ **Speed**: ~0.5s per 100-token generation on CPU
|
| 30 |
+
|
| 31 |
+
---
|
| 32 |
+
|
| 33 |
+
## ✨ Usage Example
|
| 34 |
+
|
| 35 |
+
```python
|
| 36 |
+
# --- Imports ---
# NOTE(review): the model code below uses flax.linen (as `nn`) and `lax.scan`,
# but the original snippet never imported them; they are added here so the
# example actually runs.
from transformers import AutoTokenizer  # unused below; kept from original example
import jax
import jax.numpy as jnp
from jax import lax
import flax.linen as nn
from flax import serialization
import sentencepiece as spm
import numpy as np
import json
import os  # unused below; kept from original example

# Load model hyperparameters (vocab_size, d_model, n_layers, dtype).
with open("config.json") as f:
    config = json.load(f)

# Load the SentencePiece tokenizer (5k-token Unigram model trained on the same data).
tokenizer = spm.SentencePieceProcessor(model_file="tokenizer.model")
|
| 51 |
+
|
| 52 |
+
# Define model architecture (same as training)
|
| 53 |
+
class SelectiveRecurrentLayer(nn.Module):
    """Linear-time selective state-space layer (no attention).

    Maps (batch, length, d_model) -> (batch, length, d_model) using a
    diagonal per-channel recurrence whose step size and input/output
    projections (delta, B, C) are computed from the input itself.
    """
    d_model: int
    d_state: int = 16
    dtype: jnp.dtype = jnp.float32

    @nn.compact
    def __call__(self, x):
        x = x.astype(self.dtype)
        B, L, D = x.shape
        # Learnable per-channel decay; A = -exp(A_log) is always negative,
        # so the discretized transition exp(A * delta) lies in (0, 1).
        A_log = self.param("A_log", nn.initializers.zeros, (D,))
        A = -jnp.exp(A_log.astype(self.dtype))
        # Input-dependent step size; softplus keeps it positive.
        delta = nn.Dense(D, dtype=self.dtype)(x)
        delta = jax.nn.softplus(delta)
        # Input-dependent SSM projections.
        B_ssm = nn.Dense(self.d_state, dtype=self.dtype)(x)
        C_ssm = nn.Dense(self.d_state, dtype=self.dtype)(x)
        # Discretization: A_bar = exp(A * delta); B_bar below uses
        # (1 - A_bar) / (-A), the zero-order-hold integral of exp(A t).
        A_bar = jnp.exp(A * delta)
        inv_A = 1.0 / (-A)
        # Broadcast everything to (B, L, d_state, D) so the scan carries a
        # (d_state, D) state per batch element.
        B_exp = B_ssm[:, :, :, None]
        A_exp = A_bar[:, :, None, :]
        x_exp = x[:, :, None, :]
        C_exp = C_ssm[:, :, :, None]
        B_bar = B_exp * ((1 - A_exp) * inv_A)
        inputs = (A_exp, B_bar, x_exp, C_exp)
        # lax.scan consumes the leading axis, so move time to the front.
        inputs = jax.tree.map(lambda t: t.transpose(1, 0, 2, 3), inputs)

        def ssm_op(carry, inp):
            A_curr, B_curr, x_curr, C_curr = inp
            state = carry
            state = A_curr * state + B_curr * x_curr
            # Contract over d_state (axis=1) -> per-step output (B, D).
            y = jnp.sum(C_curr * state, axis=1)
            return state, y

        init_state = jnp.zeros((B, self.d_state, D), dtype=self.dtype)
        _, y_seq = lax.scan(ssm_op, init_state, inputs)
        # Scan output is (L, B, D); restore (B, L, D).
        return y_seq.transpose(1, 0, 2)
|
| 85 |
+
|
| 86 |
+
class SubwordLLM(nn.Module):
    """Tiny autoregressive LM: embed -> n_layers x (SRL + LayerNorm) -> logits."""
    vocab_size: int
    d_model: int = 256
    n_layers: int = 2
    dtype: jnp.dtype = jnp.float32

    @nn.compact
    def __call__(self, input_ids):
        # Embedding is kept in float32, then cast to the working dtype.
        x = nn.Embed(self.vocab_size, self.d_model, dtype=jnp.float32)(input_ids)
        x = x.astype(self.dtype)
        for _ in range(self.n_layers):
            x = SelectiveRecurrentLayer(d_model=self.d_model, dtype=self.dtype)(x)
            x = nn.LayerNorm(dtype=self.dtype)(x)
        # Per-position logits over the vocabulary: (batch, length, vocab_size).
        return nn.Dense(self.vocab_size, dtype=self.dtype)(x)
|
| 99 |
+
|
| 100 |
+
# Load weights
|
| 101 |
+
# Instantiate the model from the saved config.
model = SubwordLLM(
    vocab_size=config["vocab_size"],
    d_model=config["d_model"],
    n_layers=config["n_layers"],
    dtype=jnp.dtype(config["dtype"]),
)

# Build a parameter skeleton via init(), then overwrite it in place with the
# trained weights deserialized from the msgpack checkpoint.
with open("flax_model.msgpack", "rb") as f:
    params = serialization.from_bytes(
        model.init(jax.random.key(0), jnp.ones((1, 128), dtype=jnp.int32)),
        f.read(),
    )
|
| 113 |
+
|
| 114 |
+
# Generation function
|
| 115 |
+
def generate(prompt, max_new_tokens=150, temperature=0.7, repetition_penalty=1.2, top_k=25):
    """Sample a continuation of ``prompt``, one token at a time.

    The full sequence is re-run through the model every step (no state
    cache), so generation is O(n^2) overall -- fine for short demo outputs.

    Args:
        prompt: text to continue.
        max_new_tokens: hard cap on generated tokens.
        temperature: softmax temperature (>0); lower = greedier.
        repetition_penalty: >1 discourages already-generated tokens.
        top_k: keep only the k most likely tokens before sampling (0 = off).
    """
    ids = tokenizer.encode(prompt)
    # Drop pad/EOS ids the tokenizer may have emitted for the prompt.
    ids = [i for i in ids if i not in (tokenizer.pad_id(), tokenizer.eos_id())]
    generated = ids.copy()
    input_ids = jnp.array([generated], dtype=jnp.int32)

    for _ in range(max_new_tokens):
        logits = model.apply(params, input_ids)
        next_token_logits = logits[0, -1, :]

        # Repetition penalty (CTRL-style). BUGFIX: the original divided every
        # seen token's logit by the penalty, which *raises* the probability of
        # tokens with negative logits. Positive logits are divided; negative
        # logits are multiplied, so seen tokens always become less likely.
        for tok in set(generated):
            logit = next_token_logits[tok]
            penalized = jnp.where(
                logit > 0,
                logit / repetition_penalty,
                logit * repetition_penalty,
            )
            next_token_logits = next_token_logits.at[tok].set(penalized)

        # Top-k filtering: keep the k best logits, mask everything else.
        if top_k > 0:
            top_k_vals, top_k_idx = jax.lax.top_k(
                next_token_logits, min(top_k, len(next_token_logits))
            )
            mask = jnp.full_like(next_token_logits, -1e10)
            next_token_logits = mask.at[top_k_idx].set(top_k_vals)

        # Temperature-scaled categorical sampling with a fresh key per step.
        next_token_logits /= temperature
        key = jax.random.key(np.random.randint(0, 2**31 - 1))
        next_token = int(jax.random.categorical(key, next_token_logits))

        if next_token == tokenizer.eos_id():
            break
        generated.append(next_token)
        input_ids = jnp.array([generated], dtype=jnp.int32)

    return tokenizer.decode(generated)


# Generate!
print(generate("once upon a time"))
|
| 147 |
+
```
|
| 148 |
+
|
| 149 |
+
### 📝 Sample Output:
|
| 150 |
+
```
|
| 151 |
+
once upon a time, there was a little girl named Lily. She loved to play in the park. One day, she found a shiny rock. She showed it to her mom, who smiled and said, "That's magic!" Lily put it in her pocket and ran home. That night, the rock glowed under her pillow. She dreamed of dragons and stars — and woke up with a new friend beside her.
|
| 152 |
+
```
|
| 153 |
+
|
| 154 |
+
---
|
| 155 |
+
|
| 156 |
+
## 🏗️ Model Architecture
|
| 157 |
+
|
| 158 |
+
- **No attention!** Uses a **Selective State Space Model (SSM)** with linear complexity.
|
| 159 |
+
- **Input**: Subword tokens (SentencePiece Unigram, 5k vocab)
|
| 160 |
+
- **Hidden layers**: 2 × SelectiveRecurrentLayer (256-dim)
|
| 161 |
+
- **Memory**: O(n), not O(n²) — ideal for long contexts
|
| 162 |
+
- **Training**: 10k TinyStories, 20 epochs, batch size 32
|
| 163 |
+
|
| 164 |
+
---
|
| 165 |
+
|
| 166 |
+
## 📊 Training Details
|
| 167 |
+
|
| 168 |
+
| Item | Value |
|
| 169 |
+
|------|-------|
|
| 170 |
+
| Dataset | `roneneldan/TinyStories` (50,000 samples) |
|
| 171 |
+
| Tokenizer | SentencePiece Unigram (vocab=5000) |
|
| 172 |
+
| Epochs | 50 |
|
| 173 |
+
| Loss | ~2.42 |
|
| 174 |
+
| Max Length | 128 |
|
| 175 |
+
| Optimizer | AdamW + Cosine Decay |
|
| 176 |
+
| Hardware | T4 GPU (x2) |
|
| 177 |
+
| Training Time | ~15 minutes |
|
| 178 |
+
|
| 179 |
+
---
|