---
library_name: easydel
pipeline_tag: text-generation
tags:
  - easydel
  - jax
  - "glm4v"
  - "CausalLM"
  - "blocksparse"
---

<p align="center">
  <img alt="easydel" src="https://raw.githubusercontent.com/erfanzar/easydel/main/images/easydel-logo-with-text.png">
</p>

<h1 align="center">GLM-4.6V-Flash</h1>

<div align="center">
  A model compatible with the EasyDeL JAX stack.
</div>

## Overview

This checkpoint is intended to be loaded with EasyDeL on JAX (CPU/GPU/TPU). It supports sharded loading with `auto_shard_model=True` and configurable precision via `dtype`, `param_dtype`, and `precision`.

## Quickstart

```python
import easydel as ed
from jax import numpy as jnp, lax

repo_id = "EasyDeL/GLM-4.6V-Flash"

dtype = jnp.bfloat16  # try jnp.float16 on many GPUs

model = ed.AutoEasyDeLModelForCausalLM.from_pretrained(
    repo_id,
    dtype=dtype,
    param_dtype=dtype,
    precision=lax.Precision("fastest"),
    sharding_axis_names=("dp", "fsdp", "ep", "tp", "sp"),
    sharding_axis_dims=(1, -1, 1, 1, 1),  # -1 is inferred from the device count (pure FSDP here)
    config_kwargs=ed.EasyDeLBaseConfigDict(
        attn_dtype=dtype,
        attn_mechanism=ed.AttentionMechanisms.SPLASH,
        fsdp_is_ep_bound=True,
        sp_is_ep_bound=True,
        moe_method=ed.MoEMethods.FUSED_MOE,
    ),
    auto_shard_model=True,
    partition_axis=ed.PartitionAxis(),
)
```

If the repository only provides PyTorch weights, pass `from_torch=True` to `from_pretrained(...)`.

## Sharding & Parallelism (Multi-Device)

EasyDeL can scale to multiple devices by creating a logical device mesh. Most EasyDeL loaders use a 5D mesh:

- `dp`: data parallel (replicated parameters, different batch shards)
- `fsdp`: parameter sharding (memory saver; often the biggest axis)
- `ep`: expert parallel (MoE; keep `1` for non-MoE models)
- `tp`: tensor parallel (splits large matmuls)
- `sp`: sequence parallel (splits sequence dimension)

Use `sharding_axis_names=("dp","fsdp","ep","tp","sp")` and choose `sharding_axis_dims` so that their product equals your device count.
You can use `-1` in `sharding_axis_dims` to let EasyDeL infer the remaining dimension.
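
The `-1` inference rule above can be sketched as a small pure-Python helper. This is a hypothetical illustration of the bookkeeping, not part of the easydel API:

```python
from math import prod

def resolve_axis_dims(dims, device_count):
    """Resolve a single -1 entry so prod(dims) == device_count.

    Mirrors the rule described above: at most one axis may be -1,
    and the product of all axis dims must equal the device count.
    """
    dims = list(dims)
    if dims.count(-1) > 1:
        raise ValueError("at most one axis may be -1")
    if -1 in dims:
        known = prod(d for d in dims if d != -1)
        if device_count % known:
            raise ValueError(f"{device_count} devices not divisible by {known}")
        dims[dims.index(-1)] = device_count // known
    if prod(dims) != device_count:
        raise ValueError(f"product {prod(dims)} != device count {device_count}")
    return tuple(dims)

# On 8 devices, (1, -1, 1, 1, 1) resolves to pure 8-way FSDP:
print(resolve_axis_dims((1, -1, 1, 1, 1), 8))  # (1, 8, 1, 1, 1)
print(resolve_axis_dims((2, -1, 1, 2, 1), 8))  # (2, 2, 1, 2, 1)
```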

<details>
<summary>Example sharding configs</summary>

```python
# 8 devices, pure FSDP
sharding_axis_dims = (1, 8, 1, 1, 1)

# 8 devices, 2-way DP x 4-way FSDP
sharding_axis_dims = (2, 4, 1, 1, 1)

# 8 devices, 4-way FSDP x 2-way TP
sharding_axis_dims = (1, 4, 1, 2, 1)
```
</details>

## Using via `eLargeModel` (ELM)

`eLargeModel` is a higher-level interface that wires together loading, sharding, training, and eSurge inference from a single config.

```python
from easydel import eLargeModel

repo_id = "EasyDeL/GLM-4.6V-Flash"

elm = eLargeModel.from_pretrained(repo_id)  # task is auto-detected
elm.set_dtype("bf16")
elm.set_sharding(axis_names=("dp", "fsdp", "ep", "tp", "sp"), axis_dims=(1, -1, 1, 1, 1))

model = elm.build_model()
# Optional: build an inference engine
# engine = elm.build_esurge()
```

<details>
<summary>ELM YAML config example</summary>

```yaml
model:
  name_or_path: "EasyDeL/GLM-4.6V-Flash"

loader:
  dtype: bf16
  param_dtype: bf16

sharding:
  axis_dims: [1, -1, 1, 1, 1]
  auto_shard_model: true
```
</details>

## Features

**EasyDeL:**
- JAX native implementation and sharded execution
- Configurable attention backends via `AttentionMechanisms.*`
- Precision control via `dtype`, `param_dtype`, and `precision`

## Installation

```bash
pip install easydel
```

## Links

- EasyDeL GitHub: https://github.com/erfanzar/EasyDeL
- Docs: https://easydel.readthedocs.io/en/latest/

## Supported Tasks

- CausalLM

## Limitations

- Refer to the original model card for training data, evaluation, and intended use.

## License

EasyDeL is released under the Apache-2.0 license. The license for this model's weights may differ; please consult the original repository.

## Citation

```bibtex
@misc{ZareChavoshi_2023,
    title={EasyDeL: An open-source library for enhancing and streamlining the training process of machine learning models},
    url={https://github.com/erfanzar/EasyDeL},
    author={Zare Chavoshi, Erfan},
    year={2023}
}
```