---
library_name: easydel
pipeline_tag: text-generation
tags:
- easydel
- jax
- "llama"
- "CausalLM"
- "vanilla"
---

<p align="center">
<img alt="EasyDeL" src="https://raw.githubusercontent.com/erfanzar/easydel/main/images/easydel-logo-with-text.png" height="80">
</p>

<h1 align="center">Llama-3.1-8B</h1>

<div align="center">
EasyDeL checkpoint converted from <a href="https://huggingface.co/meta-llama/Llama-3.1-8B">meta-llama/Llama-3.1-8B</a>.
</div>

## Overview

This checkpoint is intended to be loaded with EasyDeL on JAX (CPU, GPU, or TPU). EasyDeL supports sharded loading with `auto_shard_model=True`. You can also configure precision with three settings: `dtype` (computation dtype), `param_dtype` (dtype the weights are stored in), and `precision` (matmul precision passed to JAX).

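The sketch below shows what the three precision settings do, in pure JAX. It assumes EasyDeL follows the usual Flax convention: `param_dtype` is the storage dtype of the weights, `dtype` is the compute dtype, and `precision` is forwarded to JAX matmuls. The array shapes and variable names are illustrative only, and this is not EasyDeL's internal code.

```python
import jax
import jax.numpy as jnp
from jax import lax

# Illustrative settings, assumed to mirror the loader arguments of the same name.
param_dtype = jnp.float32          # how the weights are stored
dtype = jnp.bfloat16               # what the computation runs in
precision = lax.Precision.DEFAULT  # fastest (least accurate) matmul mode

# A toy weight matrix stored in param_dtype and a toy input batch.
weights = jax.random.normal(jax.random.PRNGKey(0), (16, 8), dtype=param_dtype)
inputs = jnp.ones((4, 16), dtype=jnp.float32)

# Cast both operands to the compute dtype, then let `precision`
# control the accuracy/speed trade-off of the dot product.
outputs = jnp.dot(inputs.astype(dtype), weights.astype(dtype), precision=precision)

print(weights.dtype, outputs.dtype, outputs.shape)
```

Storing weights in a wider dtype than the compute dtype is a common mixed-precision choice. The Quickstart below sets both to `bfloat16`, which halves weight memory compared with `float32` storage.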
## Quickstart

```python
import easydel as ed
from jax import numpy as jnp, lax

repo_id = "EasyDeL/Llama-3.1-8B"

# bfloat16 is a good default on TPUs and recent GPUs;
# consider jnp.float16 on older GPUs without native bfloat16 support.
dtype = jnp.bfloat16

model = ed.AutoEasyDeLModelForCausalLM.from_pretrained(
    repo_id,
    dtype=dtype,  # computation dtype
    param_dtype=dtype,  # dtype the weights are stored in
    precision=lax.Precision("fastest"),  # same as lax.Precision.DEFAULT
    sharding_axis_names=("dp", "fsdp", "ep", "tp", "sp"),
    sharding_axis_dims=(1, -1, 1, 1, 1),  # put all devices on the FSDP axis
    config_kwargs=ed.EasyDeLBaseConfigDict(
        attn_dtype=dtype,
        attn_mechanism=ed.AttentionMechanisms.VANILLA,
        # The options below only matter for MoE models; Llama is dense.
        fsdp_is_ep_bound=True,
        sp_is_ep_bound=True,
        moe_method=ed.MoEMethods.FUSED_MOE,
    ),
    auto_shard_model=True,
    partition_axis=ed.PartitionAxis(),
)
```

If the repository only provides PyTorch weights, pass `from_torch=True` to `from_pretrained(...)`.

## Sharding & Parallelism (Multi-Device)

EasyDeL scales to multiple devices by arranging them in a logical device mesh. Most EasyDeL loaders use a 5D mesh with these axes:

- `dp`: data parallel (parameters replicated, batch split across devices)
- `fsdp`: fully sharded data parallel (parameters sharded across devices to save memory; often the largest axis)
- `ep`: expert parallel (for MoE models; keep it at `1` for dense models such as Llama)
- `tp`: tensor parallel (splits large matrix multiplications across devices)
- `sp`: sequence parallel (splits the sequence dimension)

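The sketch below shows why a larger `fsdp` axis saves memory. It estimates per-device weight memory when parameters are split evenly across the `fsdp` axis, with the other axes left at 1. It assumes roughly 8.0e9 parameters for Llama-3.1-8B and counts weights only. It leaves out activations, KV cache, optimizer state, and any extra splitting from `tp`, so treat the numbers as a lower bound. `weight_gib_per_device` is a hypothetical helper, not an EasyDeL function.

```python
def weight_gib_per_device(num_params: float, bytes_per_param: int, fsdp: int) -> float:
    """Hypothetical helper: weight memory per device (GiB) with parameters split over `fsdp` devices."""
    if fsdp < 1:
        raise ValueError("fsdp must be >= 1")
    total_bytes = num_params * bytes_per_param
    return total_bytes / fsdp / 2**30


NUM_PARAMS = 8.0e9  # approximate parameter count (assumption)

for fsdp in (1, 2, 4, 8):
    # bfloat16 uses 2 bytes per parameter
    print(f"fsdp={fsdp}: ~{weight_gib_per_device(NUM_PARAMS, 2, fsdp):.1f} GiB per device")
```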
Set `sharding_axis_names=("dp", "fsdp", "ep", "tp", "sp")` and choose `sharding_axis_dims` so that the product of its entries equals your device count.
You can set one entry of `sharding_axis_dims` to `-1` to let EasyDeL infer that axis size from the remaining devices.

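The `-1` rule can be sketched in plain Python. `resolve_axis_dims` is a hypothetical helper written for illustration, not part of EasyDeL. It assumes at most one `-1` entry and that the fixed axis sizes divide the device count.

```python
from math import prod


def resolve_axis_dims(axis_dims: tuple[int, ...], device_count: int) -> tuple[int, ...]:
    """Hypothetical helper: replace a single -1 so the product of the dims equals device_count."""
    if axis_dims.count(-1) > 1:
        raise ValueError("at most one axis may be -1")
    fixed = prod(d for d in axis_dims if d != -1)
    if device_count % fixed != 0:
        raise ValueError(f"{device_count} devices cannot be split by fixed axes of size {fixed}")
    resolved = tuple(device_count // fixed if d == -1 else d for d in axis_dims)
    if prod(resolved) != device_count:
        raise ValueError(f"axis product {prod(resolved)} != device count {device_count}")
    return resolved


print(resolve_axis_dims((1, -1, 1, 1, 1), 8))  # (1, 8, 1, 1, 1)
print(resolve_axis_dims((1, -1, 1, 2, 1), 8))  # (1, 4, 1, 2, 1)
```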
<details>
<summary>Example sharding configs</summary>

```python
# 8 devices, pure FSDP
sharding_axis_dims = (1, 8, 1, 1, 1)

# 8 devices, 2-way DP x 4-way FSDP
sharding_axis_dims = (2, 4, 1, 1, 1)

# 8 devices, 4-way FSDP x 2-way TP
sharding_axis_dims = (1, 4, 1, 2, 1)
```
</details>

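To see how a `sharding_axis_dims` tuple lays devices out, you can build the same five-axis mesh directly with JAX's `jax.sharding.Mesh`. This standalone sketch is for inspection only. It assumes pure FSDP over whatever devices are visible, and it is not how EasyDeL constructs its mesh internally.

```python
import jax
import numpy as np
from jax.sharding import Mesh

axis_names = ("dp", "fsdp", "ep", "tp", "sp")
n = jax.device_count()

# Pure FSDP: every visible device goes on the "fsdp" axis.
axis_dims = (1, n, 1, 1, 1)

# Reshape the flat device list into a 5D grid and name each axis.
devices = np.asarray(jax.devices()).reshape(axis_dims)
mesh = Mesh(devices, axis_names)

print(dict(mesh.shape))  # axis name -> size
```

On TPU pods, `jax.experimental.mesh_utils.create_device_mesh(axis_dims)` usually gives a better physical device ordering than a plain reshape.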
## Using via `eLargeModel` (ELM)

`eLargeModel` is a higher-level interface that wires together loading, sharding, training, and eSurge inference from a single config.

```python
from easydel import eLargeModel

repo_id = "EasyDeL/Llama-3.1-8B"

elm = eLargeModel.from_pretrained(repo_id)  # task is auto-detected
elm.set_dtype("bf16")
elm.set_sharding(axis_names=("dp", "fsdp", "ep", "tp", "sp"), axis_dims=(1, -1, 1, 1, 1))

model = elm.build_model()
# Optional: build an eSurge inference engine
# engine = elm.build_esurge()
```

<details>
<summary>ELM YAML config example</summary>

```yaml
model:
  name_or_path: "EasyDeL/Llama-3.1-8B"

loader:
  dtype: bf16
  param_dtype: bf16

sharding:
  axis_dims: [1, -1, 1, 1, 1]
  auto_shard_model: true
```
</details>

## Features

**EasyDeL:**
- JAX-native implementation with sharded execution
- Configurable attention backends via `AttentionMechanisms.*`
- Precision control via `dtype`, `param_dtype`, and `precision`

## Installation

```bash
pip install easydel
```

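After installing, check which backend and how many devices JAX can see, because `sharding_axis_dims` must multiply to that device count. This uses only standard JAX calls. If it reports `cpu` on a machine with accelerators, JAX was likely installed without GPU or TPU support.

```python
import jax

# Platform JAX uses by default, e.g. "cpu", "gpu", or "tpu".
backend = jax.default_backend()
print("backend:", backend)

# Total devices across all hosts; sharding_axis_dims must multiply to this.
print("devices:", jax.device_count())

# Devices attached to this host only.
print("local devices:", jax.local_device_count())
```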
## Links

- EasyDeL GitHub: https://github.com/erfanzar/EasyDeL
- Docs: https://easydel.readthedocs.io/en/latest/

## Supported Tasks

- CausalLM

## Limitations

- Refer to the original model card for training data, evaluation, and intended use.

## License

EasyDeL is released under the Apache-2.0 license. The Llama-3.1-8B weights are distributed by Meta under the Llama 3.1 Community License, not Apache-2.0; consult the original repository for the full terms.

## Citation

```bibtex
@misc{ZareChavoshi_2023,
  title={EasyDeL: An open-source library for enhancing and streamlining the training process of machine learning models},
  url={https://github.com/erfanzar/EasyDeL},
  author={Zare Chavoshi, Erfan},
  year={2023}
}
```