from pathlib import Path

import flax.nnx as nnx
import grain.python as pygrain
import jax
import jax.numpy as jnp
import optax
import tiktoken


class TransformerBlock(nnx.Module):
    """Self-attention sub-layer wrapped in a residual connection.

    Only the attention path is built here: ``ff_dim`` is accepted so the
    constructor signature stays stable for callers, but no feed-forward
    sub-layer or normalisation layer is created from it.
    """

    def __init__(self, embed_dim, num_heads, ff_dim, *, rngs):
        """Create the attention layer.

        Args:
            embed_dim: Feature size used for the attention input, the
                query/key/value projections and the output.
            num_heads: Number of attention heads.
            ff_dim: Currently unused (see class docstring).
            rngs: ``nnx.Rngs`` used to initialise the attention parameters.
        """
        self.attention = nnx.MultiHeadAttention(
            num_heads=num_heads,
            in_features=embed_dim,
            qkv_features=embed_dim,
            out_features=embed_dim,
            decode=False,
            rngs=rngs,
        )

    def __call__(self, x, mask=None):
        """Return ``x`` plus the self-attention output computed from ``x``.

        Args:
            x: Input activations passed to the attention layer.
            mask: Optional attention mask forwarded unchanged to the
                attention layer.
        """
        attn_out = self.attention(x, mask=mask)
        return x + attn_out


class TokenAndPositionEmbedding(nnx.Module):
    """Sum of a learned token embedding and a learned position embedding."""

    def __init__(self, maxlen, vocab_size, embed_dim, *, rngs):
        """Create the two embedding tables.

        Args:
            maxlen: Number of rows in the position embedding table.
            vocab_size: Number of rows in the token embedding table.
            embed_dim: Width of both embedding tables.
            rngs: ``nnx.Rngs`` used to initialise both tables.
        """
        self.token_emb = nnx.Embed(vocab_size, embed_dim, rngs=rngs)
        self.pos_emb = nnx.Embed(maxlen, embed_dim, rngs=rngs)

    def __call__(self, x):
        """Embed token ids and add the embedding of each position index.

        Args:
            x: Integer token ids; axis 1 is taken as the sequence axis.

        Returns:
            Token embeddings plus position embeddings. The position term has
            a leading axis of size 1, so it broadcasts over axis 0 of ``x``.
        """
        seq_len = x.shape[1]
        positions = jnp.arange(seq_len)[None, :]
        return self.token_emb(x) + self.pos_emb(positions)


class NanoLLM(nnx.Module):
    """Small decoder-style language model.

    The model is an embedding layer, a stack of ``TransformerBlock`` layers
    applied with a lower-triangular (causal) mask, and a bias-free linear
    projection to vocabulary logits.
    """

    def __init__(
        self,
        maxlen,
        vocab_size,
        embed_dim,
        num_heads,
        feed_forward_dim,
        num_transformer_blocks,
        *,
        rngs=None,
    ):
        """Build the model.

        The hyperparameters are required. They used to default to
        module-level names that this file never defines, which made the
        class definition itself fail with ``NameError``.

        Args:
            maxlen: Size of the position embedding table; also the fixed
                context length used by ``generate_text``.
            vocab_size: Number of token ids; also the width of the logits.
            embed_dim: Embedding / hidden width.
            num_heads: Attention heads per block.
            feed_forward_dim: Forwarded to each ``TransformerBlock`` as
                ``ff_dim``.
            num_transformer_blocks: Number of blocks in the stack.
            rngs: ``nnx.Rngs`` for parameter initialisation. When omitted, a
                fresh ``nnx.Rngs(0)`` is created for this instance rather
                than sharing one stateful default object across instances.
        """
        if rngs is None:
            rngs = nnx.Rngs(0)

        self.maxlen = maxlen
        self.embedding = TokenAndPositionEmbedding(
            maxlen, vocab_size, embed_dim, rngs=rngs
        )
        # NOTE(review): newer Flax releases may require wrapping a list of
        # sub-modules in nnx.List -- confirm against the pinned Flax version.
        self.transformer_blocks = [
            TransformerBlock(embed_dim, num_heads, feed_forward_dim, rngs=rngs)
            for _ in range(num_transformer_blocks)
        ]
        self.output_layer = nnx.Linear(
            embed_dim, vocab_size, use_bias=False, rngs=rngs
        )

    def causal_attention_mask(self, seq_len):
        """Return a ``(seq_len, seq_len)`` lower-triangular matrix of ones.

        Entry ``[i, j]`` is 1 when ``j <= i`` and 0 otherwise.
        """
        return jnp.tril(jnp.ones((seq_len, seq_len)))

    def __call__(self, token_ids):
        """Compute next-token logits for every position.

        Args:
            token_ids: Integer token ids; axis 1 is the sequence axis.

        Returns:
            Output of the final linear layer, whose last axis has size
            ``vocab_size``.
        """
        seq_len = token_ids.shape[1]
        mask = self.causal_attention_mask(seq_len)
        x = self.embedding(token_ids)
        for block in self.transformer_blocks:
            x = block(x, mask=mask)
        return self.output_layer(x)

    # Pass the tokenizer as an argument to break the global dependency
    def generate_text(self, prompt, tokenizer, max_new_tokens=50, temperature=1.0):
        """Greedily extend ``prompt`` and return the decoded text.

        Args:
            prompt: Text to continue. It must encode to at least one token.
            tokenizer: Object providing ``encode(text)``,
                ``encode(text, allowed_special=...)`` and ``decode(ids)``.
                NOTE(review): presumably a tiktoken ``Encoding`` -- confirm.
            max_new_tokens: Upper bound on the number of generated tokens.
            temperature: Divisor applied to the logits before the argmax.
                Because the next token is chosen by argmax, a positive value
                does not change which token is picked. ``0`` is treated as
                plain greedy decoding and skips the division.

        Returns:
            The decoded prompt tokens followed by the generated tokens. The
            end-of-text token is not included when it stops generation.

        Raises:
            ValueError: If ``prompt`` encodes to no tokens.
        """
        # 1. Encode the string prompt into integer token IDs
        tokens = list(tokenizer.encode(prompt))
        if not tokens:
            # With no tokens, `actual_len - 1` would be -1 and the prediction
            # would silently be read from the last padding position.
            raise ValueError("prompt must encode to at least one token")

        # Cache the end token ID so we aren't encoding it on every single
        # loop iteration
        end_token_id = tokenizer.encode(
            '<|endoftext|>', allowed_special={'<|endoftext|>'}
        )[0]

        for _ in range(max_new_tokens):
            context = tokens[-self.maxlen:]
            actual_len = len(context)

            # RIGHT-pad to match training (not left-pad!). The causal mask
            # keeps the padding from affecting position `actual_len - 1`.
            if actual_len < self.maxlen:
                context = context + [0] * (self.maxlen - actual_len)

            context_array = jnp.array(context)[None, :]
            logits = self(context_array)

            # Logits at the last real (non-padding) position
            next_token_logits = logits[0, actual_len - 1, :]
            if temperature != 0:
                # Dividing by zero would turn the logits into inf/nan
                next_token_logits = next_token_logits / temperature
            next_token = int(jnp.argmax(next_token_logits))

            if next_token == end_token_id:
                break
            tokens.append(next_token)

        # 2. Decode the final list of token IDs back into a string
        return tokenizer.decode(tokens)