English
nanoLLM / NanoLLM.py
samairtimer's picture
Rename MiniLLM.py to NanoLLM.py
b14b120 verified
Raw
History Blame Contribute Delete
3.27 kB
import jax
import jax.numpy as jnp
import flax.nnx as nnx
import grain.python as pygrain
import optax
import tiktoken
from pathlib import Path
class TransformerBlock(nnx.Module):
    """Residual multi-head self-attention block.

    The block runs self-attention over its input and adds the result back
    onto that input. No feed-forward sub-layer is built: ``ff_dim`` is
    accepted by the constructor but never read.
    """

    def __init__(self, embed_dim, num_heads, ff_dim, *, rngs):
        """Create the attention layer.

        Args:
            embed_dim: Feature width used for the attention input, the
                query/key/value projections and the output.
            num_heads: Number of attention heads.
            ff_dim: Unused by this block; kept in the signature so callers
                that pass it keep working.
            rngs: ``nnx.Rngs`` handed to ``nnx.MultiHeadAttention`` for
                parameter initialisation.
        """
        # Every feature width is tied to embed_dim so the attention output
        # can be added directly onto the block input in __call__.
        attention_config = dict(
            num_heads=num_heads,
            in_features=embed_dim,
            qkv_features=embed_dim,
            out_features=embed_dim,
            decode=False,
        )
        self.attention = nnx.MultiHeadAttention(**attention_config, rngs=rngs)

    def __call__(self, x, mask=None):
        """Return ``x`` plus the self-attention output computed from ``x``.

        Args:
            x: Input activations, forwarded unchanged to the attention layer.
            mask: Optional attention mask forwarded to the attention layer.
        """
        return x + self.attention(x, mask=mask)
class TokenAndPositionEmbedding(nnx.Module):
    """Sum of a learned token embedding and a learned position embedding."""

    def __init__(self, maxlen, vocab_size, embed_dim, *, rngs):
        """Create both embedding tables.

        Args:
            maxlen: Number of rows in the position table, i.e. the longest
                sequence this layer can embed.
            vocab_size: Number of rows in the token table.
            embed_dim: Width of every embedding vector.
            rngs: ``nnx.Rngs`` used to initialise both tables.
        """
        # Kept so __call__ can reject sequences longer than the position table.
        self.maxlen = maxlen
        self.token_emb = nnx.Embed(vocab_size, embed_dim, rngs=rngs)
        self.pos_emb = nnx.Embed(maxlen, embed_dim, rngs=rngs)

    def __call__(self, x):
        """Embed token IDs and add the embedding of each position.

        Args:
            x: Integer token IDs with the sequence on axis 1.

        Returns:
            ``token_emb(x) + pos_emb(positions)`` where ``positions`` is
            ``0 .. seq_len - 1`` with a leading broadcast axis.

        Raises:
            ValueError: If the sequence is longer than ``maxlen``.
        """
        seq_len = x.shape[1]
        # Shapes are static, so this check also works under jit. Without it,
        # positions >= maxlen would be looked up outside the position table,
        # and JAX does not raise on out-of-bounds indexing.
        if seq_len > self.maxlen:
            raise ValueError(
                f"sequence length {seq_len} exceeds maxlen {self.maxlen}"
            )
        positions = jnp.arange(seq_len)[None, :]
        return self.token_emb(x) + self.pos_emb(positions)
class NanoLLM(nnx.Module):
    """Small decoder-only language model.

    Token and position embeddings feed a stack of ``TransformerBlock``
    layers under a causal mask; a bias-free linear layer maps the result to
    vocabulary logits.
    """

    def __init__(self, maxlen, vocab_size, embed_dim, num_heads,
                 feed_forward_dim, num_transformer_blocks, *, rngs=None):
        """Build the model.

        The hyperparameters are required. They previously defaulted to
        module-level names (``maxlen``, ``vocab_size``, ...) that this file
        never defines, which raised ``NameError`` as soon as the class body
        was executed.

        Args:
            maxlen: Context length; also the fixed input length used by
                ``generate_text``.
            vocab_size: Size of the token vocabulary / number of logits.
            embed_dim: Embedding and attention feature width.
            num_heads: Attention heads per block.
            feed_forward_dim: Forwarded to each ``TransformerBlock`` as
                ``ff_dim``.
            num_transformer_blocks: Number of stacked blocks.
            rngs: ``nnx.Rngs`` for parameter initialisation. When omitted, a
                fresh ``nnx.Rngs(0)`` is created for this instance.
        """
        # A default of `nnx.Rngs(0)` in the signature is evaluated once and
        # then shared (and advanced) by every model built without `rngs`.
        if rngs is None:
            rngs = nnx.Rngs(0)

        self.maxlen = maxlen
        self.embedding = TokenAndPositionEmbedding(maxlen, vocab_size, embed_dim, rngs=rngs)
        # NOTE(review): newer Flax NNX releases may require wrapping a list
        # of sub-modules (e.g. in nnx.List) — confirm against the installed
        # version.
        self.transformer_blocks = [
            TransformerBlock(embed_dim, num_heads, feed_forward_dim, rngs=rngs)
            for _ in range(num_transformer_blocks)
        ]
        self.output_layer = nnx.Linear(embed_dim, vocab_size, use_bias=False, rngs=rngs)

    def causal_attention_mask(self, seq_len):
        """Return a ``(seq_len, seq_len)`` lower-triangular matrix of ones."""
        return jnp.tril(jnp.ones((seq_len, seq_len)))

    def __call__(self, token_ids):
        """Compute next-token logits for every position.

        Args:
            token_ids: Integer token IDs with the sequence on axis 1
                (``generate_text`` passes shape ``(1, maxlen)``).

        Returns:
            The output of the final linear layer applied to the last block's
            activations.
        """
        seq_len = token_ids.shape[1]
        mask = self.causal_attention_mask(seq_len)
        x = self.embedding(token_ids)
        for block in self.transformer_blocks:
            x = block(x, mask=mask)
        return self.output_layer(x)

    # The tokenizer is passed as an argument to avoid a global dependency.
    def generate_text(self, prompt, tokenizer, max_new_tokens=50, temperature=1.0):
        """Greedily extend ``prompt`` and return the decoded text.

        Args:
            prompt: Text to continue.
            tokenizer: Object with ``encode`` (accepting ``allowed_special``)
                and ``decode`` methods; ``encode`` must return a list.
            max_new_tokens: Upper bound on the number of generated tokens.
            temperature: Accepted for interface compatibility. Decoding is
                greedy (argmax), which positive temperature scaling cannot
                change, so the value is not used. Dividing by it only
                mattered for zero or negative values, where it corrupted the
                argmax.

        Returns:
            The prompt plus generated tokens, decoded to a string. Generation
            stops early when ``<|endoftext|>`` is predicted; that token is
            not included.

        Raises:
            ValueError: If ``prompt`` encodes to no tokens.
        """
        tokens = tokenizer.encode(prompt)
        if not tokens:
            # With zero tokens the read index below would be -1, i.e. the
            # logits of a padding position rather than of any real token.
            raise ValueError("prompt must encode to at least one token")

        # Encode the end token once instead of on every loop iteration.
        end_token_id = tokenizer.encode('<|endoftext|>', allowed_special={'<|endoftext|>'})[0]

        for _ in range(max_new_tokens):
            context = tokens[-self.maxlen:]
            actual_len = len(context)
            # Right-pad (not left-pad) to the fixed length maxlen. The causal
            # mask keeps the padding from influencing the logits read at
            # position actual_len - 1.
            padded = context + [0] * (self.maxlen - actual_len)
            logits = self(jnp.array(padded)[None, :])
            next_token = int(jnp.argmax(logits[0, actual_len - 1, :]))
            if next_token == end_token_id:
                break
            tokens.append(next_token)

        return tokenizer.decode(tokens)