erfanzar committed on
Commit
3e7c527
·
verified ·
1 Parent(s): edea10e

Update README.md (EasyDeL auto-generated)

Files changed (1)
  1. README.md +122 -72
README.md CHANGED
@@ -1,115 +1,167 @@
1
  ---
 
 
2
  tags:
3
- - EasyDeL
4
- - KimiVLForConditionalGeneration
5
- - TaskType.IMAGE_TEXT_TO_TEXT
6
- - AttentionMechanisms.RAGGED_PAGE_ATTENTION_V3
7
- - safetensors
8
- - TPU
9
- - GPU
10
- - XLA
11
- - Flax
12
  ---
 
13
  <p align="center">
14
  <a href="https://github.com/erfanzar/EasyDeL">
15
- <img src="https://raw.githubusercontent.com/erfanzar/easydel/main/images/easydel-logo-with-text.png" height="80">
16
  </a>
17
  </p>
18
 
 
 
19
  <p align="center">
20
- <a href="https://github.com/erfanzar/EasyDeL">
21
- <img src="https://img.shields.io/badge/🤗_EasyDeL-0.2.0-blue.svg" />
 
 
 
 
22
  </a>
23
  <a href="https://github.com/erfanzar/EasyDeL">
24
- <img src="https://img.shields.io/badge/Model_Type-KimiVLForConditionalGeneration-green.svg" />
25
  </a>
 
 
 
26
  </p>
27
 
28
- # EasyDeL/Kimi-VL-A3B-Instruct
29
-
30
- A model implemented using the EasyDeL framework, designed to deliver optimal performance for large-scale natural language processing tasks.
31
 
32
  ## Overview
33
 
34
- This model is built using [EasyDeL](https://github.com/erfanzar/EasyDeL), an open-source framework designed to enhance and streamline the training and serving process of machine learning models, with a primary focus on Jax/Flax on TPU/GPU at scale.
35
 
36
- EasyDeL provides an efficient, highly-optimized, and customizable machine learning model compatible with both GPU and TPU environments. Built with JAX, this model supports advanced features such as sharded model parallelism, making it suitable for distributed training and inference and customized kernels.
37
 
38
- ## Features Provided by EasyDeL
 
 
39
 
40
- **EasyDeL Framework Features:**
41
 
42
- - **Efficient Implementation**: Built with JAX/Flax for high-performance computation.
43
- - **Modern Architecture**: Built on Flax NNX for better integration, modularity, and performance.
44
- - **Multi-Device Support**: Optimized to run on TPU, GPU, and CPU environments.
45
- - **Sharded Model Parallelism**: Supports model parallelism across multiple devices for scalability (using `auto_shard_model=True`).
46
- - **Customizable Precision**: Allows specification of `dtype`, `param_dtype`, and `precision`.
47
- - **Advanced Serving**: Includes `eSurge` LLM serving engine, `vWhisper` speech endpoints, and OpenAI-compatible APIs.
48
- - **Optimized Kernels**: Integrates multiple attention mechanisms (like `AttentionMechanisms.RAGGED_PAGE_ATTENTION_V3`) and platform-specific optimizations.
49
 
50
- ## Installation
51
 
52
- To use this model via EasyDeL, first install EasyDeL:
53
 
54
- ```bash
55
- pip install easydel
56
- ```
 
 
 
 
 
57
 
58
- ## Usage
59
 
60
- ### Loading the Pre-trained Model
61
 
62
- To load this pre-trained model with EasyDeL:
63
 
64
  ```python
65
- from easydel import AutoEasyDeLModelForCausalLM, EasyDeLBaseConfigDict, AttentionMechanisms
66
- from jax import numpy as jnp, lax
67
 
68
- # Define max_length if needed for memory optimization
69
- max_length = None
70
-
71
- # Load model and parameters
72
- # Set auto_shard_model=True to automatically distribute across devices
73
- model = AutoEasyDeLModelForCausalLM.from_pretrained(
74
- "EasyDeL/Kimi-VL-A3B-Instruct",
75
- config_kwargs=EasyDeLBaseConfigDict(
76
- # use_scan_mlp=False, # Set to True to potentially reduce memory usage
77
- attn_dtype=jnp.float16, # Or jnp.bfloat16
78
- # freq_max_position_embeddings=max_length, # Set if using RoPE and need truncation
79
- # mask_max_position_embeddings=max_length, # Set if max length is defined
80
- attn_mechanism=AttentionMechanisms.RAGGED_PAGE_ATTENTION_V3 # Matches the mechanism used by this model
81
- ),
82
- dtype=jnp.float16, # Or jnp.bfloat16 - Computation data type
83
- param_dtype=jnp.float16, # Or jnp.bfloat16 - Parameter data type
84
- precision=lax.Precision("fastest"), # Like "default", "fastest", "high", "highest"
85
- auto_shard_model=True, # Auto-shard across available devices
86
- )
87
  ```
88
 
89
- ## Supported Tasks
 
90
 
91
- The primary task for this model is **TaskType.IMAGE_TEXT_TO_TEXT**. Further specific supported tasks are not explicitly listed.
 
 
92
 
93
- ## Limitations
 
 
94
 
95
- **General Limitations:**
 
 
 
 
96
 
97
- - **Hardware Dependency**: Performance can vary significantly based on the hardware (TPU/GPU) used.
98
- - **JAX/Flax Setup Required**: The environment must support JAX/Flax for optimal use.
99
- - **Experimental Features**: Some EasyDeL features (like custom kernels) may require additional configuration.
100
 
101
- ## License 📜
 
 
 
102
 
103
- EasyDeL is released under the Apache v2 license. The license for this specific model might differ; please consult the original model repository or documentation.
104
 
105
- ```code
106
- # Apache License 2.0 (referring to EasyDeL Framework)
107
- # ... (Full license text usually included in the main repo) ...
108
  ```
109
 
110
- ## Citation
111
 
112
- If you use EasyDeL in your research or work, please cite it:
113
 
114
  ```bibtex
115
  @misc{Zare Chavoshi_2023,
@@ -119,5 +171,3 @@ If you use EasyDeL in your research or work, please cite it:
119
  year={2023}
120
  }
121
  ```
122
-
123
- Please also consider citing the original paper or source for the **EasyDeL/Kimi-VL-A3B-Instruct** model architecture if applicable.
 
1
  ---
2
+ library_name: easydel
3
+ pipeline_tag: text-generation
4
  tags:
5
+ - easydel
6
+ - jax
7
+ - "kimi_vl"
8
+ - "CausalLM"
9
+ - "auto"
 
 
 
 
10
  ---
11
+
12
  <p align="center">
13
  <a href="https://github.com/erfanzar/EasyDeL">
14
+ <img src="https://raw.githubusercontent.com/erfanzar/easydel/main/images/easydel-logo-with-text.png" height="80" alt="EasyDeL" />
15
  </a>
16
  </p>
17
 
18
+ <h1 align="center">Kimi-VL-A3B-Instruct</h1>
19
+
20
  <p align="center">
21
+ A model compatible with the EasyDeL JAX stack.
22
+ </p>
23
+
24
+ <p align="center">
25
+ <a href="https://huggingface.co/EasyDeL/Kimi-VL-A3B-Instruct">
26
+ <img src="https://img.shields.io/static/v1?label=HF&message=EasyDeL/Kimi-VL-A3B-Instruct&color=FFD21E&style=flat-square" alt="HuggingFace Repo" />
27
  </a>
28
  <a href="https://github.com/erfanzar/EasyDeL">
29
+ <img src="https://img.shields.io/static/v1?label=EasyDeL&message=v0.2.0&color=0B5FFF&style=flat-square" alt="EasyDeL Version" />
30
  </a>
31
+ <img src="https://img.shields.io/static/v1?label=Model&message=kimi_vl&color=0A66C2&style=flat-square" alt="Model Type" />
32
+ <img src="https://img.shields.io/static/v1?label=Task&message=CausalLM&color=2EAD4D&style=flat-square" alt="Task" />
33
+ <img src="https://img.shields.io/static/v1?label=Attention&message=auto&color=8A2BE2&style=flat-square" alt="Attention Mechanism" />
34
  </p>
35
 
36
+ ---
 
 
37
 
38
  ## Overview
39
 
40
+ This checkpoint is intended to be loaded with EasyDeL on JAX (CPU/GPU/TPU). It supports sharded loading with `auto_shard_model=True` and configurable precision via `dtype`, `param_dtype`, and `precision`.
41
 
42
+ ## Quickstart
43
 
44
+ ```python
45
+ import easydel as ed
46
+ from jax import numpy as jnp, lax
47
 
48
+ repo_id = "EasyDeL/Kimi-VL-A3B-Instruct"
49
+
50
+ dtype = jnp.bfloat16 # try jnp.float16 on many GPUs
51
+
52
+ model = ed.AutoEasyDeLModelForCausalLM.from_pretrained(
53
+ repo_id,
54
+ dtype=dtype,
55
+ param_dtype=dtype,
56
+ precision=lax.Precision("fastest"),
57
+ config_kwargs=ed.EasyDeLBaseConfigDict(
58
+ attn_dtype=dtype,
59
+ attn_mechanism=ed.AttentionMechanisms.AUTO,
60
+ fsdp_is_ep_bound=True,
61
+ sp_is_ep_bound=True,
62
+ moe_method=ed.MoEMethods.FUSED_MOE,
63
+ ),
64
+ auto_shard_model=True,
65
+ partition_axis=ed.PartitionAxis(),
66
+ )
67
+ ```
68
 
69
+ If the repository only provides PyTorch weights, pass `from_torch=True` to `from_pretrained(...)`.
 
 
 
 
 
 
70
 
71
+ ## Sharding & Parallelism (Multi-Device)
72
 
73
+ EasyDeL can scale to multiple devices by creating a logical device mesh. Most EasyDeL loaders use a 5D mesh:
74
 
75
+ - `dp`: data parallel (replicated parameters, different batch shards)
76
+ - `fsdp`: parameter sharding (memory saver; often the biggest axis)
77
+ - `ep`: expert parallel (MoE; keep `1` for non-MoE models)
78
+ - `tp`: tensor parallel (splits large matmuls)
79
+ - `sp`: sequence parallel (splits sequence dimension)
80
+
81
+ Use `sharding_axis_names=("dp","fsdp","ep","tp","sp")` and choose `sharding_axis_dims` so that their product equals your device count.
82
+ You can use `-1` in `sharding_axis_dims` to let EasyDeL infer the remaining dimension.
83
 
84
+ <details>
85
+ <summary>Example sharding configs</summary>
86
+
87
+ ```python
88
+ # 8 devices, pure FSDP
89
+ sharding_axis_dims = (1, 8, 1, 1, 1)
90
+
91
+ # 8 devices, 2-way DP x 4-way FSDP
92
+ sharding_axis_dims = (2, 4, 1, 1, 1)
93
+
94
+ # 8 devices, 4-way FSDP x 2-way TP
95
+ sharding_axis_dims = (1, 4, 1, 2, 1)
96
+ ```
97
+ </details>
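
The sizing rule in the sharding section (the product of `sharding_axis_dims` must equal the device count, with `-1` standing for the inferred dimension) can be checked by hand before loading a model. The following is a minimal sketch of that arithmetic only; it is not EasyDeL's implementation, and `resolve_axis_dims` is a hypothetical helper name introduced here for illustration. It also assumes at most one `-1` entry is allowed.

```python
from math import prod


def resolve_axis_dims(axis_dims, device_count):
    """Hypothetical helper (not part of EasyDeL): fill in a single -1 entry
    so that the product of all axis sizes equals device_count."""
    dims = list(axis_dims)
    unknown = [i for i, d in enumerate(dims) if d == -1]
    if len(unknown) > 1:
        raise ValueError("at most one axis may be -1 (assumption of this sketch)")
    if any(d < 1 for d in dims if d != -1):
        raise ValueError("explicit axis sizes must be positive integers")

    # Product of the axes that were given explicitly.
    known = prod(d for d in dims if d != -1)

    if unknown:
        # The remaining axis absorbs whatever devices are left over.
        if device_count % known != 0:
            raise ValueError(
                f"{device_count} devices cannot be split evenly over axes {tuple(axis_dims)}"
            )
        dims[unknown[0]] = device_count // known
    elif known != device_count:
        raise ValueError(
            f"axis sizes multiply to {known}, but {device_count} devices are available"
        )
    return tuple(dims)


# Axis order follows the section above: ("dp", "fsdp", "ep", "tp", "sp").
print(resolve_axis_dims((1, -1, 1, 1, 1), 8))  # pure FSDP over 8 devices
print(resolve_axis_dims((2, -1, 1, 2, 1), 8))  # 2-way DP, 2-way TP, FSDP inferred
```

On a real machine the device count would typically come from `jax.device_count()`; a fixed number is used here so the sketch runs without JAX installed.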
98
 
99
+ ## Using via `eLargeModel` (ELM)
100
 
101
+ `eLargeModel` is a higher-level interface that wires together loading, sharding, training, and eSurge inference from a single config.
102
 
103
  ```python
104
+ from easydel import eLargeModel
 
105
 
106
+ repo_id = "EasyDeL/Kimi-VL-A3B-Instruct"
107
+
108
+ elm = eLargeModel.from_pretrained(repo_id) # task is auto-detected
109
+ elm.set_dtype("bf16")
110
+ elm.set_sharding(axis_names=("dp", "fsdp", "ep", "tp", "sp"), axis_dims=(1, -1, 1, 1, 1))
111
+
112
+ model = elm.build_model()
113
+ # Optional: build an inference engine
114
+ # engine = elm.build_esurge()
 
 
 
 
 
 
 
 
 
 
115
  ```
116
 
117
+ <details>
118
+ <summary>ELM YAML config example</summary>
119
 
120
+ ```yaml
121
+ model:
122
+ name_or_path: "EasyDeL/Kimi-VL-A3B-Instruct"
123
 
124
+ loader:
125
+ dtype: bf16
126
+ param_dtype: bf16
127
 
128
+ sharding:
129
+ axis_dims: [1, -1, 1, 1, 1]
130
+ auto_shard_model: true
131
+ ```
132
+ </details>
133
 
134
+ ## Features
 
 
135
 
136
+ **EasyDeL:**
137
+ - JAX native implementation and sharded execution
138
+ - Configurable attention backends via `AttentionMechanisms.*`
139
+ - Precision control via `dtype`, `param_dtype`, and `precision`
140
 
141
+ ## Installation
142
 
143
+ ```bash
144
+ pip install easydel
 
145
  ```
146
 
147
+ ## Links
148
 
149
+ - EasyDeL GitHub: https://github.com/erfanzar/EasyDeL
150
+ - Docs: https://easydel.readthedocs.io/en/latest/
151
+
152
+ ## Supported Tasks
153
+
154
+ - CausalLM
155
+
156
+ ## Limitations
157
+
158
+ - Refer to the original model card for training data, evaluation, and intended use.
159
+
160
+ ## License
161
+
162
+ EasyDeL is released under the Apache-2.0 license. The license for this model's weights may differ; please consult the original repository.
163
+
164
+ ## Citation
165
 
166
  ```bibtex
167
  @misc{Zare Chavoshi_2023,
 
171
  year={2023}
172
  }
173
  ```