---
language:
- zh
- en
license: apache-2.0
tags:
- jax
- flax
- mini-gpt
- text-generation
---
# handsongpt2

HandsOnGPT2 is a GPT-2 style language model trained on the GuoFeng Webnovel Corpus using JAX/Flax on a Kaggle TPU.

## Model Details

- **Architecture**: GPT-2 style transformer
- **Parameters**: 84.6M
- **Vocab Size**: 64,000 (Yi-1.5 tokenizer, TPU-aligned)
- **Max Length**: 256 tokens
- **Layers**: 6
- **Hidden Size**: 512
- **Attention Heads**: 8
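
The parameter count above can be sanity-checked from the other figures. The sketch below is not taken from the training code: it assumes a standard GPT-2 layout (learned position embeddings, biases on all dense layers, two LayerNorms per block plus a final one, and a 4x MLP expansion), and the function name `gpt2_param_count` and the `tied_head` flag are hypothetical.

```python
def gpt2_param_count(vocab, d_model, n_layers, max_len, mlp_ratio=4, tied_head=False):
    """Count parameters of a GPT-2 style decoder under the assumed layout."""
    # Token embeddings plus learned position embeddings.
    embed = vocab * d_model + max_len * d_model
    # One LayerNorm has a scale and a bias vector.
    layer_norm = 2 * d_model
    # Attention: fused QKV projection plus output projection, both with bias.
    attn = (d_model * 3 * d_model + 3 * d_model) + (d_model * d_model + d_model)
    # MLP: up-projection and down-projection, both with bias (assumed 4x width).
    d_ff = mlp_ratio * d_model
    mlp = (d_model * d_ff + d_ff) + (d_ff * d_model + d_model)
    block = 2 * layer_norm + attn + mlp
    # Output head: free if it shares the token embedding matrix, else vocab x d_model.
    head = 0 if tied_head else vocab * d_model
    return embed + n_layers * block + layer_norm + head


untied = gpt2_param_count(64_000, 512, 6, 256)
tied = gpt2_param_count(64_000, 512, 6, 256, tied_head=True)
head_dim = 512 // 8  # hidden size divided by attention heads

print(f"untied head: {untied / 1e6:.1f}M, tied head: {tied / 1e6:.1f}M, head dim: {head_dim}")
```

Under these assumptions the untied count comes to about 84.6M, which matches the figure listed above, while a tied output head would give about 51.8M. This suggests, but does not confirm, that the output projection is stored separately from the token embeddings.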
## Training

- **Framework**: JAX/Flax
- **Hardware**: Kaggle TPU v3-8
- **Batch Size**: 16
- **Learning Rate**: 0.0003
- **Final Loss**: 0.0005
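
For a rough sense of scale, the settings above translate into tokens per optimizer step as sketched below. This assumes the batch size of 16 is the global batch, that every sequence is filled to the 256-token maximum, and that the batch is split evenly across the 8 cores of the TPU v3-8. None of this is stated in the card, and the variable names are illustrative.

```python
# Values taken from the card.
batch_size = 16   # assumed to be the global batch, not per-core
seq_len = 256     # max length; assumes sequences are packed or padded to this size
n_cores = 8       # a TPU v3-8 exposes 8 cores

# Tokens processed per optimizer step under the assumptions above.
tokens_per_step = batch_size * seq_len

# Sequences per core if the global batch is sharded evenly (data parallelism).
per_core_batch = batch_size // n_cores

print(f"tokens per step: {tokens_per_step}, sequences per core: {per_core_batch}")
```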
## Usage

```python
import orbax.checkpoint as ocp

# Restore the saved training state as a PyTree.
# Replace the placeholder with the absolute path to the downloaded checkpoint.
checkpointer = ocp.PyTreeCheckpointer()
state = checkpointer.restore('/path/to/checkpoint')
```
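
After restoring, it can be useful to confirm that the checkpoint holds roughly the expected number of parameters. The helper below is a minimal sketch that assumes the restored object is a nested structure of dicts, lists, and array-like leaves exposing a `.shape` attribute. The actual layout of this checkpoint is not documented here, so the `FakeArray` stand-in and the key names `wte` and `ln_f` are purely hypothetical.

```python
import math


def count_params(tree):
    """Sum the element counts of all array-like leaves in a nested dict/list."""
    if isinstance(tree, dict):
        return sum(count_params(v) for v in tree.values())
    if isinstance(tree, (list, tuple)):
        return sum(count_params(v) for v in tree)
    # Leaves without a shape (plain ints, None, strings) are not counted.
    shape = getattr(tree, "shape", None)
    return math.prod(shape) if shape is not None else 0


class FakeArray:
    """Stand-in for an array so the example runs without a real checkpoint."""

    def __init__(self, *shape):
        self.shape = shape


# Hypothetical tree; the real checkpoint's keys may differ.
demo = {
    "params": {
        "wte": FakeArray(64_000, 512),
        "ln_f": {"scale": FakeArray(512), "bias": FakeArray(512)},
    },
    "step": 0,
}

print(count_params(demo))
```

With a real checkpoint, `count_params(state)` would be called on the restored object instead. If the checkpoint also stores optimizer state, pass only the parameter subtree, since optimizer moments would otherwise be counted as well.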
## License

Apache 2.0