llm-jax

A small GPT-style language model built from scratch with JAX/Flax.

Model Details

Param           Value
Architecture    ~256d / 4L / 4H
Vocab size      50257
Context length  128
Tokenizer       tiktoken (gpt2)
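The architecture shorthand reads as: embedding dimension 256, 4 transformer layers, 4 attention heads. A minimal sketch of the corresponding config (field names are illustrative assumptions, not taken from the llm-jax source):

```python
# Hypothetical config mirroring the card's numbers; the key names are
# illustrative, not the repo's actual code.
config = {
    "n_embd": 256,        # "~256d": embedding / hidden dimension
    "n_layer": 4,         # "4L": transformer blocks
    "n_head": 4,          # "4H": attention heads per block
    "vocab_size": 50257,  # gpt2 BPE vocabulary
    "block_size": 128,    # context length
}

# Each head attends over n_embd / n_head = 64 dimensions.
head_dim = config["n_embd"] // config["n_head"]
```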

Usage

git clone https://huggingface.co/alexdosy/llm-jax-small
cd llm-jax-small
python -m src.generate --checkpoint ./params --prompt "Hello world"
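Generation is an autoregressive loop over the 128-token context window. A rough sketch of that loop, using a dummy stand-in for the real Flax model (the stub function and the prompt token ids are assumptions for illustration):

```python
import numpy as np

CONTEXT_LEN = 128  # matches the model's context length

def dummy_logits(tokens):
    # Stand-in for the real Flax apply fn; returns one logit per vocab entry.
    rng = np.random.default_rng(0)
    return rng.normal(size=50257)

def greedy_generate(prompt_tokens, max_new_tokens=8):
    tokens = list(prompt_tokens)
    for _ in range(max_new_tokens):
        window = tokens[-CONTEXT_LEN:]  # keep only the last 128 tokens
        next_id = int(np.argmax(dummy_logits(window)))  # greedy pick
        tokens.append(next_id)
    return tokens

# [15496, 995] are assumed gpt2 BPE ids for "Hello world".
out = greedy_generate([15496, 995])
```

A real run would replace `dummy_logits` with the checkpointed model and decode the ids back to text with `tiktoken.get_encoding("gpt2").decode(out)`.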

Training

Trained on the roneneldan/TinyStories dataset for 10,000 steps with a batch size of 8.
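At these settings the run sees at most about 10M tokens, assuming every sequence is packed to the full 128-token context (an upper bound, since short stories may be padded):

```python
# Upper-bound token budget implied by the training settings above.
steps, batch_size, context_len = 10_000, 8, 128
tokens_seen = steps * batch_size * context_len  # sequences x tokens per sequence
print(tokens_seen)
```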
