Update README.md
README.md
CHANGED
|
@@ -1,3 +1,36 @@
|
|
| 1 |
-
---
|
| 2 |
-
license: apache-2.0
|
| 3 |
-
---
|
| 1 |
+
---
|
| 2 |
+
license: apache-2.0
|
| 3 |
+
---
|
| 4 |
+
### huihui-ai/grok-2
|
| 5 |
+
This Python script processes and merges the sharded safetensors weight files of the `xai-org/grok-2` model. It performs the following steps:
|
| 6 |
+
|
| 7 |
+
1. **Collecting safetensors files**: Locates all `pytorch_model-*.safetensors` files in the specified model directory.
|
| 8 |
+
2. **Loading files into cache**: Loads all safetensors files into memory and builds a key-to-file mapping.
|
| 9 |
+
3. **Merging Tensor Parallel (TP) shards**: Concatenates the tensor-parallel shards (TP=8) of each split weight along its split dimension and verifies the shape of the merged tensor.
|
| 10 |
+
4. **Grouping weights by layer**: Organizes weights by model layer, with special weights (e.g., `lm_head.weight`, `model.embed_tokens.weight`, and `model.norm.weight`) handled separately.
|
| 11 |
+
5. **Saving merged weights**: Saves the grouped weights as new safetensors files and generates a new index file `pytorch_model.bin.index.json`.
|
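The grouping in step 4 can be illustrated with a minimal sketch. It assumes that per-layer weight names start with `model.layers.<index>.` and that every other key (such as `lm_head.weight`) belongs to the special group. The function name `group_keys_by_layer` and the example key names are hypothetical and are not taken from the script.

```python
import re
from collections import defaultdict

# Assumed key layout for per-layer weights: "model.layers.<index>.<rest>".
LAYER_PATTERN = re.compile(r"^model\.layers\.(\d+)\.")


def group_keys_by_layer(keys):
    """Split weight names into per-layer groups and a list of special keys.

    Keys that do not match the layer pattern (for example lm_head.weight,
    model.embed_tokens.weight, model.norm.weight) go to the special list.
    """
    layers = defaultdict(list)
    special = []
    for key in keys:
        match = LAYER_PATTERN.match(key)
        if match:
            layers[int(match.group(1))].append(key)
        else:
            special.append(key)
    return dict(layers), special


if __name__ == "__main__":
    # Hypothetical key names, used only to show the grouping.
    example_keys = [
        "model.layers.0.attn.weight",
        "model.layers.1.attn.weight",
        "lm_head.weight",
    ]
    layers, special = group_keys_by_layer(example_keys)
    print(layers)
    print(special)
```

Each group can then be written to its own safetensors file, with the special keys saved separately.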
| 12 |
+
|
| 13 |
+
### Features
|
| 14 |
+
- **Input**: Safetensors files in the `xai-org/grok-2` model directory.
|
| 15 |
+
- **Output**: Layer-organized safetensors files and an index file in the `huihui-ai/grok-2` directory.
|
| 16 |
+
- **Tensor Parallelism Support**: Handles TP=8 shards, merging tensors along specific dimensions (`w1.weight` and `w3.weight` along dim=0, `w2.weight` along dim=1).
|
| 17 |
+
- **Error Handling**: Warns about missing files and shape mismatches, and handles other exceptions.
|
| 18 |
+
- **Shape Validation**: Checks that merged weights, such as the MoE layer weights, match their expected shapes (e.g., `(16384, 8192)` or `(8192, 16384)`).
|
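The merge dimensions and expected shapes listed above can be checked with plain shape arithmetic. The sketch below is illustrative only: it assumes each of the 8 ranks holds an equal slice of the split dimension, and that `(16384, 8192)` is the merged shape for weights split along dim=0 and `(8192, 16384)` for weights split along dim=1. The helper `merged_shape` is hypothetical; on real tensors the concatenation itself would be `torch.cat(shards, dim=dim)`.

```python
def merged_shape(shard_shapes, dim):
    """Return the shape produced by concatenating shards along `dim`.

    `dim` is a non-negative axis index. Raises ValueError if the shards
    disagree on any other axis, which is also when torch.cat would fail.
    """
    if not shard_shapes:
        raise ValueError("no shards to merge")
    first = tuple(shard_shapes[0])
    total = 0
    for shape in shard_shapes:
        shape = tuple(shape)
        if len(shape) != len(first):
            raise ValueError(f"rank mismatch: {shape} vs {first}")
        for axis, (a, b) in enumerate(zip(shape, first)):
            if axis != dim and a != b:
                raise ValueError(f"shape mismatch on axis {axis}: {shape} vs {first}")
        total += shape[dim]
    merged = list(first)
    merged[dim] = total
    return tuple(merged)


TP_COUNT = 8

# Assumed per-shard shapes: each rank holds 1/8 of the split dimension.
w1_shards = [(16384 // TP_COUNT, 8192)] * TP_COUNT  # split along dim=0
w2_shards = [(8192, 16384 // TP_COUNT)] * TP_COUNT  # split along dim=1

print(merged_shape(w1_shards, dim=0))  # (16384, 8192)
print(merged_shape(w2_shards, dim=1))  # (8192, 16384)
```

Comparing the result against the expected shape before saving catches a wrong merge dimension early, since merging along the other axis would produce a different shape.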
| 19 |
+
|
| 20 |
+
### Usage
|
| 21 |
+
1. Install the required Python libraries:
|
| 22 |
+
```bash
|
| 23 |
+
pip install torch safetensors
|
| 24 |
+
```
|
| 25 |
+
2. Place the script in an environment with the `xai-org/grok-2` model directory.
|
| 26 |
+
3. Run the script:
|
| 27 |
+
```bash
|
| 28 |
+
python script.py
|
| 29 |
+
```
|
| 30 |
+
4. Output files will be saved in the `huihui-ai/grok-2` directory, including layer-organized safetensors files and an index file.
|
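For orientation, the index file can be pictured as a JSON mapping from weight names to the files that store them. The sketch below follows the layout commonly used for sharded Hugging Face checkpoints (`metadata.total_size` plus `weight_map`). Whether the script writes exactly these fields is an assumption, and the function name and file names shown are hypothetical.

```python
import json


def build_index(key_to_file, key_to_nbytes):
    """Build an index dict with a total byte size and a key-to-file map."""
    return {
        "metadata": {"total_size": sum(key_to_nbytes.values())},
        "weight_map": dict(sorted(key_to_file.items())),
    }


if __name__ == "__main__":
    # Hypothetical output file names and byte sizes, for illustration only.
    key_to_file = {
        "model.layers.0.attn.weight": "layer-0.safetensors",
        "lm_head.weight": "special.safetensors",
    }
    key_to_nbytes = {
        "model.layers.0.attn.weight": 1024,
        "lm_head.weight": 2048,
    }
    index = build_index(key_to_file, key_to_nbytes)
    # The README names the index file pytorch_model.bin.index.json.
    with open("pytorch_model.bin.index.json", "w", encoding="utf-8") as f:
        json.dump(index, f, indent=2)
```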
| 31 |
+
|
| 32 |
+
### Notes
|
| 33 |
+
- Ensure the input directory `xai-org/grok-2` contains valid `pytorch_model-*.safetensors` files.
|
| 34 |
+
- The script assumes a tensor parallelism degree of 8 (`tp_count = 8`). Modify the `tp_count` value in the script if needed.
|
| 35 |
+
- Because the script loads all safetensors files into memory, run it on a machine with enough RAM to hold them.
|
| 36 |
+
- If shards are missing or shapes do not match, the script prints a warning and attempts to proceed.
|
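The warn-and-proceed behavior described in the notes could look roughly like the sketch below. It assumes shards are identified by a rank in `range(tp_count)`. The function names are hypothetical, and the script's actual messages and control flow may differ.

```python
import warnings


def missing_ranks(present_ranks, tp_count=8):
    """Return the sorted TP ranks in range(tp_count) that have no shard."""
    return sorted(set(range(tp_count)) - set(present_ranks))


def check_shards(key, present_ranks, tp_count=8):
    """Warn instead of raising when shards are missing.

    Returns True if all tp_count shards are present, False otherwise.
    """
    missing = missing_ranks(present_ranks, tp_count)
    if missing:
        warnings.warn(f"{key}: missing TP shards {missing}; proceeding with the rest")
        return False
    return True
```

A caller can use the return value to decide whether to skip the shape check for that weight or to merge only the shards that were found.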