| import jax |
| import jax.numpy as jnp |
| import flax.nnx as nnx |
| import grain.python as pygrain |
| import optax |
| import tiktoken |
| from pathlib import Path |
|
|
class TransformerBlock(nnx.Module):
    """Self-attention layer wrapped in a residual connection.

    Args:
        embed_dim: Feature width used for the attention input, QKV
            projections and output.
        num_heads: Number of attention heads.
        ff_dim: Accepted but not used by this block.
        rngs: ``nnx.Rngs`` passed to the attention layer for parameter
            initialisation.
    """

    def __init__(self, embed_dim, num_heads, ff_dim, *, rngs):
        # Every projection keeps the same width so the residual add in
        # ``__call__`` is shape-compatible.
        attention_config = dict(
            num_heads=num_heads,
            in_features=embed_dim,
            qkv_features=embed_dim,
            out_features=embed_dim,
            decode=False,
        )
        self.attention = nnx.MultiHeadAttention(**attention_config, rngs=rngs)

    def __call__(self, x, mask=None):
        """Return ``x`` plus the self-attention output computed from ``x``.

        Args:
            x: Input activations fed to the attention layer.
            mask: Optional attention mask forwarded unchanged to the
                attention layer.
        """
        return x + self.attention(x, mask=mask)
|
|
class TokenAndPositionEmbedding(nnx.Module):
    """Sum of a learned token embedding and a learned position embedding.

    Args:
        maxlen: Number of rows in the position embedding table.
        vocab_size: Number of rows in the token embedding table.
        embed_dim: Width of both embedding tables.
        rngs: ``nnx.Rngs`` used to initialise both tables.
    """

    def __init__(self, maxlen, vocab_size, embed_dim, *, rngs):
        # Token table is created first, then the position table.
        self.token_emb = nnx.Embed(vocab_size, embed_dim, rngs=rngs)
        self.pos_emb = nnx.Embed(maxlen, embed_dim, rngs=rngs)

    def __call__(self, x):
        """Embed token ids ``x`` and add position embeddings.

        ``x.shape[1]`` is taken as the sequence length, so ``x`` is
        presumably shaped (batch, seq_len) -- confirm against callers.
        """
        # Leading axis of size 1 lets the position vectors broadcast over
        # the first axis of the token embeddings.
        position_ids = jnp.arange(x.shape[1])[None, :]
        token_vectors = self.token_emb(x)
        position_vectors = self.pos_emb(position_ids)
        return token_vectors + position_vectors
|
|
class NanoLLM(nnx.Module):
    """Small causal language model.

    The model is a token + position embedding, a stack of
    ``TransformerBlock`` layers, and a bias-free linear projection back to
    vocabulary logits.

    The size arguments default to module-level names (``maxlen``,
    ``vocab_size``, ``embed_dim``, ``num_heads``, ``feed_forward_dim``,
    ``num_transformer_blocks``). Those defaults are evaluated when this
    class statement runs, so the names must already be defined at that point.

    Args:
        maxlen: Maximum context length; also the fixed length used during
            generation.
        vocab_size: Size of the token vocabulary / width of the logits.
        embed_dim: Embedding and attention feature width.
        num_heads: Number of attention heads per block.
        feed_forward_dim: Forwarded to each ``TransformerBlock``.
        num_transformer_blocks: Number of stacked blocks.
        rngs: ``nnx.Rngs`` for parameter initialisation. When omitted, a
            fresh ``nnx.Rngs(0)`` is created for this instance.
    """

    def __init__(self, maxlen=maxlen, vocab_size=vocab_size, embed_dim=embed_dim, num_heads=num_heads,
                 feed_forward_dim=feed_forward_dim, num_transformer_blocks=num_transformer_blocks, *, rngs=None):
        # A default of ``nnx.Rngs(0)`` in the signature would be created once
        # and shared (with its advancing state) by every model built without
        # an explicit ``rngs``. Creating it here keeps default construction
        # reproducible per instance.
        if rngs is None:
            rngs = nnx.Rngs(0)

        self.maxlen = maxlen

        self.embedding = TokenAndPositionEmbedding(maxlen, vocab_size, embed_dim, rngs=rngs)

        # NOTE(review): some Flax releases may expect containers of modules
        # to be wrapped (e.g. ``nnx.List``) -- confirm against the installed
        # version before changing this.
        self.transformer_blocks = [
            TransformerBlock(embed_dim, num_heads, feed_forward_dim, rngs=rngs)
            for _ in range(num_transformer_blocks)
        ]

        self.output_layer = nnx.Linear(embed_dim, vocab_size, use_bias=False, rngs=rngs)

    def causal_attention_mask(self, seq_len):
        """Return a ``(seq_len, seq_len)`` lower-triangular matrix of ones.

        Row ``i`` is non-zero only in columns ``0..i``, so position ``i``
        may attend to itself and earlier positions only.
        """
        return jnp.tril(jnp.ones((seq_len, seq_len)))

    def __call__(self, token_ids):
        """Compute vocabulary logits for every position of ``token_ids``.

        ``token_ids.shape[1]`` is used as the sequence length.
        """
        seq_len = token_ids.shape[1]
        mask = self.causal_attention_mask(seq_len)

        x = self.embedding(token_ids)

        for block in self.transformer_blocks:
            x = block(x, mask=mask)

        logits = self.output_layer(x)

        return logits

    def generate_text(self, prompt, tokenizer, max_new_tokens=50, temperature=1.0, rng_key=None):
        """Autoregressively extend ``prompt`` and return the decoded text.

        Args:
            prompt: Text to continue. Must encode to at least one token.
            tokenizer: Object providing ``encode`` (accepting
                ``allowed_special``) and ``decode``.
            max_new_tokens: Upper bound on the number of generated tokens.
            temperature: Softmax temperature. Only used when ``rng_key`` is
                given and ``temperature > 0``; greedy decoding is unaffected
                by it.
            rng_key: Optional ``jax.random`` key. When provided (and
                ``temperature > 0``) the next token is sampled from the
                temperature-scaled distribution; otherwise the most likely
                token is chosen.

        Returns:
            The decoded prompt plus generated tokens. Generation stops early
            when the ``<|endoftext|>`` token is produced; that token is not
            included in the output.

        Raises:
            ValueError: If ``prompt`` encodes to zero tokens.
        """
        tokens = tokenizer.encode(prompt)
        if not tokens:
            # With no tokens the "last real position" index would be -1 and
            # silently read the logits of a padding position.
            raise ValueError("prompt must encode to at least one token")

        end_token_id = tokenizer.encode('<|endoftext|>', allowed_special={'<|endoftext|>'})[0]

        # Dividing by temperature before an argmax changes nothing for
        # positive values and yields inf/NaN at zero, so scaling is applied
        # only on the sampling path.
        use_sampling = rng_key is not None and temperature > 0

        for _ in range(max_new_tokens):
            context = tokens[-self.maxlen:]

            # Right-pad to a fixed length. The causal mask stops the last
            # real position from attending to the padding after it.
            actual_len = len(context)
            if actual_len < self.maxlen:
                context = context + [0] * (self.maxlen - actual_len)

            context_array = jnp.array(context)[None, :]
            logits = self(context_array)

            # Logits at the last real (non-padding) position predict the
            # next token.
            next_token_logits = logits[0, actual_len - 1, :]
            if use_sampling:
                rng_key, step_key = jax.random.split(rng_key)
                next_token = int(jax.random.categorical(step_key, next_token_logits / temperature))
            else:
                next_token = int(jnp.argmax(next_token_logits))

            if next_token == end_token_id:
                break

            tokens.append(next_token)

        return tokenizer.decode(tokens)
|
|