---
library_name: easydel
pipeline_tag: text-generation
tags:
- easydel
- jax
- "kimi_vl"
- "CausalLM"
- "ragged_page_attention_v3"
---
<p align="center">
<img alt="easydel" src="https://raw.githubusercontent.com/erfanzar/easydel/main/images/easydel-logo-with-text.png">
</p>
<h1 align="center">Kimi-VL-A3B-Instruct</h1>
<div align="center">
A model compatible with the EasyDeL JAX stack.
</div>

## Overview
This checkpoint is intended to be loaded with EasyDeL on JAX (CPU/GPU/TPU). It supports sharded loading with `auto_shard_model=True` and configurable precision via `dtype`, `param_dtype`, and `precision`.
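
This is a plain-JAX sketch of what the three knobs conventionally mean. It assumes EasyDeL follows the usual Flax-style convention: `param_dtype` is the storage dtype of the weights, `dtype` is the dtype the computation runs in, and `precision` is forwarded to matmuls. It makes no EasyDeL calls, and `dense` is a hypothetical helper for illustration only.

```python
import jax
import jax.numpy as jnp
from jax import lax


# Hypothetical helper: weights arrive in their storage dtype (`param_dtype`)
# and are cast to the computation dtype (`dtype`) before the matmul.
def dense(x, w, dtype=jnp.bfloat16, precision=lax.Precision("fastest")):
    return jnp.dot(x.astype(dtype), w.astype(dtype), precision=precision)


param_dtype = jnp.float32  # storage dtype of the weights
key_x, key_w = jax.random.split(jax.random.PRNGKey(0))
x = jax.random.normal(key_x, (4, 8), dtype=param_dtype)
w = jax.random.normal(key_w, (8, 16), dtype=param_dtype)

# bf16 compute with the "fastest" (DEFAULT) matmul precision.
fast = dense(x, w)
# float32 compute with the most accurate matmul precision, as a reference.
exact = dense(x, w, dtype=jnp.float32, precision=lax.Precision.HIGHEST)

# "fastest" is a string alias for the DEFAULT precision value.
print(lax.Precision("fastest") == lax.Precision.DEFAULT)
print(fast.dtype, exact.dtype)
print(float(jnp.max(jnp.abs(fast.astype(jnp.float32) - exact))))
```

The last line prints the largest absolute gap between the bf16 and float32 results, which gives a rough sense of the accuracy traded for speed.
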
## Quickstart
```python
import easydel as ed
from jax import numpy as jnp, lax
repo_id = "EasyDeL/Kimi-VL-A3B-Instruct"
dtype = jnp.bfloat16  # use jnp.float16 on GPUs without bfloat16 support
model = ed.AutoEasyDeLModelForCausalLM.from_pretrained(
    repo_id,
    dtype=dtype,
    param_dtype=dtype,
    precision=lax.Precision("fastest"),
    sharding_axis_names=("dp", "fsdp", "ep", "tp", "sp"),
    sharding_axis_dims=(1, -1, 1, 1, 1),
    config_kwargs=ed.EasyDeLBaseConfigDict(
        attn_dtype=dtype,
        attn_mechanism=ed.AttentionMechanisms.RAGGED_PAGE_ATTENTION_V3,
        fsdp_is_ep_bound=True,
        sp_is_ep_bound=True,
        moe_method=ed.MoEMethods.FUSED_MOE,
    ),
    auto_shard_model=True,
    partition_axis=ed.PartitionAxis(),
)
```
If the repository only provides PyTorch weights, pass `from_torch=True` to `from_pretrained(...)`.
## Sharding & Parallelism (Multi-Device)
EasyDeL can scale to multiple devices by creating a logical device mesh. Most EasyDeL loaders use a 5D mesh:
- `dp`: data parallel (replicated parameters, different batch shards)
- `fsdp`: parameter sharding (memory saver; often the biggest axis)
- `ep`: expert parallel (MoE; keep `1` for non-MoE models)
- `tp`: tensor parallel (splits large matmuls)
- `sp`: sequence parallel (splits sequence dimension)
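
As a rough picture of such a mesh, the sketch below uses only standard JAX APIs (`jax.sharding.Mesh`, `NamedSharding`, `PartitionSpec`). It builds a mesh with the same five axis names, places every device on `fsdp` (the `(1, -1, 1, 1, 1)` layout from the Quickstart), and shards a toy weight matrix along `fsdp`. EasyDeL builds its mesh internally, so this is an illustration rather than its actual code, and `build_mesh` is a hypothetical helper.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

AXIS_NAMES = ("dp", "fsdp", "ep", "tp", "sp")


def build_mesh(axis_dims):
    """Hypothetical helper: arrange all local devices into a 5D logical mesh."""
    devices = np.asarray(jax.devices()).reshape(axis_dims)
    return Mesh(devices, AXIS_NAMES)


# Pure FSDP: every device sits on the `fsdp` axis; the other axes have size 1.
n = jax.device_count()
mesh = build_mesh((1, n, 1, 1, 1))

# Split the rows of a toy weight matrix across `fsdp` and replicate the columns.
w = jnp.zeros((8 * n, 16))
w_sharded = jax.device_put(w, NamedSharding(mesh, P("fsdp", None)))

print(dict(mesh.shape))
print(len(w_sharded.addressable_shards), "shards")
```

On a CPU-only machine, setting `XLA_FLAGS=--xla_force_host_platform_device_count=8` before importing JAX simulates eight devices, which makes the sharding visible.
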

Use `sharding_axis_names=("dp","fsdp","ep","tp","sp")` and choose `sharding_axis_dims` so that the product of its entries equals your device count.
Setting an entry of `sharding_axis_dims` to `-1` lets EasyDeL infer that axis's size from the remaining devices.
<details>
<summary>Example sharding configs</summary>

```python
# 8 devices, pure FSDP
sharding_axis_dims = (1, 8, 1, 1, 1)
# 8 devices, 2-way DP x 4-way FSDP
sharding_axis_dims = (2, 4, 1, 1, 1)
# 8 devices, 4-way FSDP x 2-way TP
sharding_axis_dims = (1, 4, 1, 2, 1)
```
</details>
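
The sizing rule above (the entries multiply to the device count, and a `-1` is filled in from the rest) can be checked before loading anything. The helper below is a hypothetical, library-independent sketch of that arithmetic. It assumes at most one `-1` is allowed and is not EasyDeL's own resolution code.

```python
from math import prod

AXIS_NAMES = ("dp", "fsdp", "ep", "tp", "sp")


def resolve_axis_dims(axis_dims, device_count):
    """Hypothetical helper: fill in a single -1 so the dims multiply to device_count."""
    dims = list(axis_dims)
    if len(dims) != len(AXIS_NAMES):
        raise ValueError(f"expected {len(AXIS_NAMES)} dims, got {len(dims)}")
    if any(d < 1 and d != -1 for d in dims):
        raise ValueError("each dim must be a positive integer or -1")
    if dims.count(-1) > 1:
        raise ValueError("at most one axis may be -1")
    known = prod(d for d in dims if d != -1)  # product of the explicit sizes
    if -1 in dims:
        if device_count % known != 0:
            raise ValueError(f"{device_count} devices cannot be split by {known}")
        dims[dims.index(-1)] = device_count // known
    if prod(dims) != device_count:
        raise ValueError(f"product of {dims} is {prod(dims)}, not {device_count}")
    return tuple(dims)


# Check the example configs above (plus the Quickstart's -1 layout) for 8 devices.
for dims in [(1, 8, 1, 1, 1), (2, 4, 1, 1, 1), (1, 4, 1, 2, 1), (1, -1, 1, 1, 1)]:
    print(dims, "->", dict(zip(AXIS_NAMES, resolve_axis_dims(dims, 8))))
```

On a real host, pass `jax.device_count()` as `device_count`.
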

## Using via `eLargeModel` (ELM)
`eLargeModel` is a higher-level interface that wires together loading, sharding, training, and eSurge inference from a single config.
```python
from easydel import eLargeModel
repo_id = "EasyDeL/Kimi-VL-A3B-Instruct"
elm = eLargeModel.from_pretrained(repo_id) # task is auto-detected
elm.set_dtype("bf16")
elm.set_sharding(axis_names=("dp", "fsdp", "ep", "tp", "sp"), axis_dims=(1, -1, 1, 1, 1))
model = elm.build_model()
# Optional: build an inference engine
# engine = elm.build_esurge()
```
<details>
<summary>ELM YAML config example</summary>

```yaml
model:
  name_or_path: "EasyDeL/Kimi-VL-A3B-Instruct"
loader:
  dtype: bf16
  param_dtype: bf16
sharding:
  axis_dims: [1, -1, 1, 1, 1]
  auto_shard_model: true
</details>

## Features
**EasyDeL:**
- JAX-native implementation and sharded execution
- Configurable attention backends via `AttentionMechanisms.*`
- Precision control via `dtype`, `param_dtype`, and `precision`
## Installation
```bash
pip install easydel
```
## Links
- EasyDeL GitHub: https://github.com/erfanzar/EasyDeL
- Docs: https://easydel.readthedocs.io/en/latest/
## Supported Tasks
- CausalLM
## Limitations
- Refer to the original model card for training data, evaluation, and intended use.
## License
EasyDeL is released under the Apache-2.0 license. The license for this model's weights may differ; please consult the original repository.
## Citation
```bibtex
@misc{ZareChavoshi_2023,
  title={EasyDeL: An open-source library for enhancing and streamlining the training process of machine learning models},
  url={https://github.com/erfanzar/EasyDeL},
  author={Zare Chavoshi, Erfan},
  year={2023}
}
```