File size: 3,265 Bytes
1548e47 b14b120 1548e47 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 | import jax
import jax.numpy as jnp
import flax.nnx as nnx
import grain.python as pygrain
import optax
import tiktoken
from pathlib import Path
class TransformerBlock(nnx.Module):
    """Residual multi-head self-attention layer, stacked by ``NanoLLM``.

    Args:
        embed_dim: Width of the input/output feature vectors.
        num_heads: Number of attention heads.
        ff_dim: Accepted for signature compatibility; currently unused.
        rngs: ``nnx.Rngs`` used to initialise the attention weights.
    """

    def __init__(self, embed_dim, num_heads, ff_dim, *, rngs):
        # NOTE(review): ff_dim is never used -- this block has no feed-forward
        # sub-layer and no layer normalisation, unlike a standard transformer
        # block. Adding them would change the parameter tree (and break
        # existing checkpoints), so the architecture is left as-is.
        self.attention = nnx.MultiHeadAttention(
            in_features=embed_dim,
            qkv_features=embed_dim,
            out_features=embed_dim,
            num_heads=num_heads,
            decode=False,
            rngs=rngs,
        )

    def __call__(self, x, mask=None):
        """Return ``x`` plus self-attention over ``x``, honouring ``mask``."""
        return x + self.attention(x, mask=mask)
class TokenAndPositionEmbedding(nnx.Module):
    """Sum of a learned token embedding and a learned absolute position embedding.

    Args:
        maxlen: Number of rows in the position table; inputs with more
            positions than this are rejected.
        vocab_size: Number of rows in the token table.
        embed_dim: Width of both embeddings.
        rngs: ``nnx.Rngs`` used to initialise both tables.
    """

    def __init__(self, maxlen, vocab_size, embed_dim, *, rngs):
        self.maxlen = maxlen
        # Token table is created before the position table, so rngs are
        # consumed in the same order as before.
        self.token_emb = nnx.Embed(vocab_size, embed_dim, rngs=rngs)
        self.pos_emb = nnx.Embed(maxlen, embed_dim, rngs=rngs)

    def __call__(self, x):
        """Embed token ids ``x``, indexed as ``x[batch, position]``.

        Returns:
            Token embeddings plus the embedding of each position index.

        Raises:
            ValueError: If ``x`` has more than ``maxlen`` positions.
        """
        seq_len = x.shape[1]
        # JAX gathers do not raise on out-of-range indices (they clamp or
        # fill), so an over-long input would silently produce garbage
        # position embeddings. seq_len is a static shape, so checking it in
        # Python is also safe under jit.
        if seq_len > self.maxlen:
            raise ValueError(
                f"sequence length {seq_len} exceeds maxlen {self.maxlen}"
            )
        positions = jnp.arange(seq_len)[None, :]  # (1, seq_len), broadcasts over batch
        return self.token_emb(x) + self.pos_emb(positions)
class NanoLLM(nnx.Module):
    """Decoder-only language model: embeddings, attention blocks, vocab projection.

    NOTE(review): the constructor defaults (``maxlen``, ``vocab_size``,
    ``embed_dim``, ``num_heads``, ``feed_forward_dim``,
    ``num_transformer_blocks``) are read from module-level names when the
    class is defined; they must already exist at that point, otherwise the
    class definition raises ``NameError``.
    """

    def __init__(self, maxlen=maxlen, vocab_size=vocab_size, embed_dim=embed_dim,
                 num_heads=num_heads, feed_forward_dim=feed_forward_dim,
                 num_transformer_blocks=num_transformer_blocks, *, rngs=None):
        """Build the embedding layer, the transformer stack and the output head.

        Args:
            maxlen: Maximum context length (size of the position table).
            vocab_size: Number of token ids.
            embed_dim: Model width.
            num_heads: Attention heads per block.
            feed_forward_dim: Passed to each ``TransformerBlock`` as ``ff_dim``.
            num_transformer_blocks: Number of stacked blocks.
            rngs: ``nnx.Rngs`` for parameter initialisation; defaults to a
                fresh ``nnx.Rngs(0)`` for every instance.
        """
        # nnx.Rngs is stateful (every draw advances it), so a single default
        # instance created at class-definition time would be shared by all
        # models built without ``rngs`` and give each one different weights.
        if rngs is None:
            rngs = nnx.Rngs(0)
        self.maxlen = maxlen
        self.embedding = TokenAndPositionEmbedding(maxlen, vocab_size, embed_dim, rngs=rngs)
        self.transformer_blocks = [
            TransformerBlock(embed_dim, num_heads, feed_forward_dim, rngs=rngs)
            for _ in range(num_transformer_blocks)
        ]
        self.output_layer = nnx.Linear(embed_dim, vocab_size, use_bias=False, rngs=rngs)

    def causal_attention_mask(self, seq_len):
        """Return a ``(seq_len, seq_len)`` boolean lower-triangular mask.

        Entry ``[i, j]`` is True when query position ``i`` may attend to key
        position ``j`` (i.e. ``j <= i``). Flax attention masks out weights
        whose mask value is False.
        """
        return jnp.tril(jnp.ones((seq_len, seq_len), dtype=bool))

    def __call__(self, token_ids):
        """Return next-token logits for every position of ``token_ids``.

        ``token_ids`` is indexed as ``[batch, position]``; the result has one
        ``vocab_size``-wide logit vector per input position.
        """
        seq_len = token_ids.shape[1]
        mask = self.causal_attention_mask(seq_len)
        x = self.embedding(token_ids)
        for block in self.transformer_blocks:
            x = block(x, mask=mask)
        logits = self.output_layer(x)
        return logits

    def generate_text(self, prompt, tokenizer, max_new_tokens=50, temperature=1.0,
                      rng_key=None):
        """Continue ``prompt`` token by token and return the decoded text.

        Decoding is greedy (argmax) unless ``rng_key`` is given and
        ``temperature > 0``; then each token is sampled from
        ``softmax(logits / temperature)``. Generation stops early when the
        model emits ``<|endoftext|>``, which is not included in the output.

        Args:
            prompt: Text to continue.
            tokenizer: Object providing ``encode(text, allowed_special=...)``
                and ``decode(ids)`` (e.g. a ``tiktoken`` encoding).
            max_new_tokens: Upper bound on the number of generated tokens.
            temperature: Sampling temperature, used only when sampling; must
                be >= 0, and 0 means greedy.
            rng_key: Optional ``jax.random`` key. When None (the default),
                decoding is greedy and fully deterministic.

        Returns:
            ``tokenizer.decode`` of the prompt tokens followed by the
            generated tokens.

        Raises:
            ValueError: If ``temperature`` is negative, or the prompt encodes
                to no tokens while ``max_new_tokens > 0``.
        """
        if temperature < 0:
            raise ValueError(f"temperature must be >= 0, got {temperature}")
        tokens = list(tokenizer.encode(prompt))
        if not tokens and max_new_tokens > 0:
            # With no real token there is no position whose logits to read;
            # index actual_len - 1 would silently pick a padding slot.
            raise ValueError("prompt must encode to at least one token")
        # Looked up once rather than on every loop iteration.
        end_token_id = tokenizer.encode('<|endoftext|>', allowed_special={'<|endoftext|>'})[0]
        sample = rng_key is not None and temperature > 0

        for _ in range(max_new_tokens):
            context = tokens[-self.maxlen:]
            actual_len = len(context)
            # Right-pad to the fixed length (matching training). The causal
            # mask keeps padding after position actual_len - 1 from affecting
            # the logits read there.
            padded = context + [0] * (self.maxlen - actual_len)
            logits = self(jnp.array(padded)[None, :])
            next_token_logits = logits[0, actual_len - 1, :]
            if sample:
                rng_key, subkey = jax.random.split(rng_key)
                next_token = int(jax.random.categorical(subkey, next_token_logits / temperature))
            else:
                # Dividing by a positive temperature cannot change the argmax,
                # so greedy decoding ignores it.
                next_token = int(jnp.argmax(next_token_logits))
            if next_token == end_token_id:
                break
            tokens.append(next_token)

        return tokenizer.decode(tokens)
|