---
library_name: easydel
pipeline_tag: text-generation
tags:
  - easydel
  - jax
  - "kimi_vl"
  - "CausalLM"
  - "ragged_page_attention_v3"
---

<p align="center">
  <img alt="easydel" src="https://raw.githubusercontent.com/erfanzar/easydel/main/images/easydel-logo-with-text.png">
</p>

<h1 align="center">Kimi-VL-A3B-Instruct</h1>

<div align="center">
  A model compatible with the EasyDeL JAX stack.
</div>

## Overview

This checkpoint is intended to be loaded with EasyDeL on JAX (CPU/GPU/TPU). It supports sharded loading with `auto_shard_model=True` and configurable precision via `dtype`, `param_dtype`, and `precision`.

## Quickstart

```python
import easydel as ed
from jax import numpy as jnp, lax

repo_id = "EasyDeL/Kimi-VL-A3B-Instruct"

dtype = jnp.bfloat16  # on GPUs without native bfloat16 support, try jnp.float16

model = ed.AutoEasyDeLModelForCausalLM.from_pretrained(
    repo_id,
    dtype=dtype,
    param_dtype=dtype,
    precision=lax.Precision("fastest"),
    sharding_axis_names=("dp", "fsdp", "ep", "tp", "sp"),
    sharding_axis_dims=(1, -1, 1, 1, 1),
    config_kwargs=ed.EasyDeLBaseConfigDict(
        attn_dtype=dtype,
        attn_mechanism=ed.AttentionMechanisms.RAGGED_PAGE_ATTENTION_V3,
        fsdp_is_ep_bound=True,
        sp_is_ep_bound=True,
        moe_method=ed.MoEMethods.FUSED_MOE,
    ),
    auto_shard_model=True,
    partition_axis=ed.PartitionAxis(),
)
```

If the repository only provides PyTorch weights, pass `from_torch=True` to `from_pretrained(...)`.

## Sharding & Parallelism (Multi-Device)

EasyDeL can scale to multiple devices by creating a logical device mesh. Most EasyDeL loaders use a 5D mesh:

- `dp`: data parallel (replicated parameters, different batch shards)
- `fsdp`: parameter sharding (memory saver; often the biggest axis)
- `ep`: expert parallel (MoE; keep `1` for non-MoE models)
- `tp`: tensor parallel (splits large matmuls)
- `sp`: sequence parallel (splits sequence dimension)

Pass `sharding_axis_names=("dp", "fsdp", "ep", "tp", "sp")` and choose `sharding_axis_dims` so that the product of its entries equals your device count.
You can set one entry of `sharding_axis_dims` to `-1` to let EasyDeL infer that dimension from the remaining ones.

<details>
<summary>Example sharding configs</summary>

```python
# 8 devices, pure FSDP
sharding_axis_dims = (1, 8, 1, 1, 1)

# 8 devices, 2-way DP x 4-way FSDP
sharding_axis_dims = (2, 4, 1, 1, 1)

# 8 devices, 4-way FSDP x 2-way TP
sharding_axis_dims = (1, 4, 1, 2, 1)
```
</details>
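You can check the product rule and the `-1` inference before launching a job. The following is a minimal pure-Python sketch. `resolve_axis_dims` is a hypothetical helper, not part of EasyDeL. It assumes that EasyDeL resolves `-1` the way NumPy's `reshape` does: the device count divided by the product of the fixed axis sizes.

```python
from math import prod

# Axis order used throughout this card.
AXIS_NAMES = ("dp", "fsdp", "ep", "tp", "sp")


def resolve_axis_dims(axis_dims, device_count):
    """Hypothetical helper (not part of EasyDeL): resolve a single -1 entry
    and check that the mesh size matches the device count.

    Assumes -1 is inferred NumPy-reshape style: device_count divided by the
    product of the explicitly given axis sizes.
    """
    dims = list(axis_dims)
    if len(dims) != len(AXIS_NAMES):
        raise ValueError(f"expected {len(AXIS_NAMES)} axis sizes, got {len(dims)}")
    if any(d != -1 and d < 1 for d in dims):
        raise ValueError("axis sizes must be positive integers or -1")
    unknown = [i for i, d in enumerate(dims) if d == -1]
    if len(unknown) > 1:
        raise ValueError("at most one axis size may be -1")

    fixed = prod(d for d in dims if d != -1)  # product of the explicit sizes
    if unknown:
        if device_count % fixed != 0:
            raise ValueError(f"{device_count} devices cannot be split evenly by {fixed}")
        dims[unknown[0]] = device_count // fixed

    # Every axis is now explicit; the mesh must cover all devices exactly.
    if prod(dims) != device_count:
        raise ValueError(f"mesh size {prod(dims)} != device count {device_count}")
    return tuple(dims)


# The Quickstart's (1, -1, 1, 1, 1) on 8 devices resolves to pure FSDP.
print(dict(zip(AXIS_NAMES, resolve_axis_dims((1, -1, 1, 1, 1), 8))))
# {'dp': 1, 'fsdp': 8, 'ep': 1, 'tp': 1, 'sp': 1}

# 2-way DP x 2-way TP on 8 devices leaves 2-way FSDP.
print(resolve_axis_dims((2, -1, 1, 2, 1), 8))
# (2, 2, 1, 2, 1)
```

On a real host, `device_count` would typically come from `jax.device_count()`.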

## Using via `eLargeModel` (ELM)

`eLargeModel` is a higher-level interface that wires together loading, sharding, training, and eSurge inference from a single config.

```python
from easydel import eLargeModel

repo_id = "EasyDeL/Kimi-VL-A3B-Instruct"

elm = eLargeModel.from_pretrained(repo_id)  # task is auto-detected
elm.set_dtype("bf16")
elm.set_sharding(axis_names=("dp", "fsdp", "ep", "tp", "sp"), axis_dims=(1, -1, 1, 1, 1))

model = elm.build_model()
# Optional: build an inference engine
# engine = elm.build_esurge()
```

<details>
<summary>ELM YAML config example</summary>

```yaml
model:
  name_or_path: "EasyDeL/Kimi-VL-A3B-Instruct"

loader:
  dtype: bf16
  param_dtype: bf16

sharding:
  axis_dims: [1, -1, 1, 1, 1]
  auto_shard_model: true
```
</details>

## Features

**EasyDeL:**
- JAX native implementation and sharded execution
- Configurable attention backends via `AttentionMechanisms.*`
- Precision control via `dtype`, `param_dtype`, and `precision`

## Installation

```bash
pip install easydel
```

## Links

- EasyDeL GitHub: https://github.com/erfanzar/EasyDeL
- Docs: https://easydel.readthedocs.io/en/latest/

## Supported Tasks

- CausalLM

## Limitations

- Refer to the original model card for training data, evaluation, and intended use.

## License

EasyDeL is released under the Apache-2.0 license. The license for this model's weights may differ; please consult the original repository.

## Citation

```bibtex
@misc{ZareChavoshi_2023,
    title={EasyDeL: An open-source library for enhancing and streamlining the training process of machine learning models},
    url={https://github.com/erfanzar/EasyDeL},
    author={Zare Chavoshi, Erfan},
    year={2023}
}
```