Add files using upload-large-folder tool
- .gitattributes +18 -0
- LICENSE +55 -0
- NOTICE +1 -0
- README.md +52 -0
- _CHECKPOINT_METADATA +1 -0
- _METADATA +1 -0
- commit_success.txt +1 -0
- convert.py +469 -0
- d/6fa366570758d6d92bac0192e972f081 +0 -0
- manifest.ocdbt +0 -0
- ocdbt.process_0/d/0c600442de678f199ab5e8d0152c8bbc +3 -0
- ocdbt.process_0/d/124c8d55117974805f4a148ebd670601 +3 -0
- ocdbt.process_0/d/156c1d2fa9763c506b9a075a5729b7de +3 -0
- ocdbt.process_0/d/2f806db10cde1c0f439f288f34ee8b95 +3 -0
- ocdbt.process_0/d/44d6c5046ef1352f82da17c3daffc66d +3 -0
- ocdbt.process_0/d/4f87c6488b21c3c03780afba253c8af0 +3 -0
- ocdbt.process_0/d/78efb519de44bcebc6eb6fd345ff0af0 +3 -0
- ocdbt.process_0/d/82c6ee1c6166b74a75b93badf224480e +3 -0
- ocdbt.process_0/d/88b7b61e8b4c3f22ec014bd4fcd5cbe5 +3 -0
- ocdbt.process_0/d/8d2385834c5546d282b60ec4c29d4378 +3 -0
- ocdbt.process_0/d/97411cb3209e6a07d86332228942a45c +3 -0
- ocdbt.process_0/d/a4f076e56e9ec995bf74b8f5d19e4101 +3 -0
- ocdbt.process_0/d/ae742711b7e4d0ae7a8eaa239393ec58 +3 -0
- ocdbt.process_0/d/ca081ceb253837d87f9bc28d6dfe2b0a +3 -0
- ocdbt.process_0/d/ce563dd53d0cc8894641b5996dc5453f +3 -0
- ocdbt.process_0/d/df6d4f33d5f964745802e5eedfe1425e +3 -0
- ocdbt.process_0/d/e61206338bedb2c7eecd25dc76511b00 +3 -0
- ocdbt.process_0/d/f43da177788619ad3cd686b1345608ef +3 -0
- ocdbt.process_0/d/fd9008384544af4c40cc6878c2bbb5f0 +0 -0
- ocdbt.process_0/manifest.ocdbt +0 -0
.gitattributes
CHANGED

@@ -33,3 +33,21 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+ocdbt.process_0/d/4f87c6488b21c3c03780afba253c8af0 filter=lfs diff=lfs merge=lfs -text
+ocdbt.process_0/d/f43da177788619ad3cd686b1345608ef filter=lfs diff=lfs merge=lfs -text
+ocdbt.process_0/d/ce563dd53d0cc8894641b5996dc5453f filter=lfs diff=lfs merge=lfs -text
+ocdbt.process_0/d/156c1d2fa9763c506b9a075a5729b7de filter=lfs diff=lfs merge=lfs -text
+ocdbt.process_0/d/8d2385834c5546d282b60ec4c29d4378 filter=lfs diff=lfs merge=lfs -text
+ocdbt.process_0/d/ca081ceb253837d87f9bc28d6dfe2b0a filter=lfs diff=lfs merge=lfs -text
+ocdbt.process_0/d/44d6c5046ef1352f82da17c3daffc66d filter=lfs diff=lfs merge=lfs -text
+ocdbt.process_0/d/ae742711b7e4d0ae7a8eaa239393ec58 filter=lfs diff=lfs merge=lfs -text
+ocdbt.process_0/d/88b7b61e8b4c3f22ec014bd4fcd5cbe5 filter=lfs diff=lfs merge=lfs -text
+ocdbt.process_0/d/78efb519de44bcebc6eb6fd345ff0af0 filter=lfs diff=lfs merge=lfs -text
+ocdbt.process_0/d/df6d4f33d5f964745802e5eedfe1425e filter=lfs diff=lfs merge=lfs -text
+ocdbt.process_0/d/124c8d55117974805f4a148ebd670601 filter=lfs diff=lfs merge=lfs -text
+ocdbt.process_0/d/82c6ee1c6166b74a75b93badf224480e filter=lfs diff=lfs merge=lfs -text
+ocdbt.process_0/d/97411cb3209e6a07d86332228942a45c filter=lfs diff=lfs merge=lfs -text
+ocdbt.process_0/d/e61206338bedb2c7eecd25dc76511b00 filter=lfs diff=lfs merge=lfs -text
+ocdbt.process_0/d/2f806db10cde1c0f439f288f34ee8b95 filter=lfs diff=lfs merge=lfs -text
+ocdbt.process_0/d/0c600442de678f199ab5e8d0152c8bbc filter=lfs diff=lfs merge=lfs -text
+ocdbt.process_0/d/a4f076e56e9ec995bf74b8f5d19e4101 filter=lfs diff=lfs merge=lfs -text
LICENSE
ADDED
@@ -0,0 +1,55 @@
Gemma Terms of Use

Last modified: April 1, 2024

Thank you for using Gemma! These terms of use ("Terms") are entered into by and between Google LLC ("Google") and you and govern your use of Gemma. This is an agreement between you and Google. If you are accepting these Terms and using Gemma on behalf of another person or entity, you represent and warrant that you have full authority to bind that person or entity to these Terms.

1. THE AGREEMENT

1.1 These Terms apply only to Gemma and any use or access you obtain to Gemma. For the purpose of these Terms, "Gemma" means the Gemma model and software and tools made available by Google, and any software, documentation, or other materials provided by Google relating to Gemma. Gemma may be: (a) a standalone product, or (b) a component of other software products. Gemma is not a Google Service or a Google Product as described in Google’s Terms of Service (https://policies.google.com/terms). These Terms do not apply to any Third Party Products (defined in Section 2.3) or any other Google products or services.

1.2 “You” means you as an individual or entity, or your authorized representative. If you or your organization have a separate, written agreement with Google governing your use of Gemma, that agreement will supersede these Terms.

2. LICENSE AND USE

2.1 Subject to your compliance with these Terms, Google grants you a perpetual, non-exclusive, worldwide, non-transferable, non-sublicensable license to use, reproduce, distribute, perform, and display Gemma, to prepare Derivatives of Gemma, and otherwise use Gemma in your Derivative Models.

2.2 You may publicly distribute or make available to third parties, without restriction or charge, Gemma, Derivative Models, and Derivative Works (collectively, the “Artifacts”), provided you: (a) comply with these Terms; (b) provide a copy of these Terms; (c) provide attribution to Google as the source of Gemma in a manner that is reasonable and customary within the industry or context; (d) include a prominent notice with any Derivative Works stating that you have modified Gemma; and (e) include the following notice: "Gemma is provided under and subject to the Gemma Terms of Use found at ai.google.dev/gemma/terms".

2.3 You may use Third Party Products with Gemma. "Third Party Products" means third-party software, data, or other resources. Third Party Products may be subject to their own terms and Google is not responsible for Third Party Products.

3. RESTRICTIONS AND PROHIBITED USE POLICY

3.1 You will comply with the Gemma Prohibited Use Policy (provided at ai.google.dev/gemma/terms). Google may update the Gemma Prohibited Use Policy from time to time by publishing a new version at ai.google.dev/gemma/terms.

3.2 You will not use Gemma or Derivatives, or allow or facilitate its use, for any of the following:

(a) High Risk Activities: any use that could lead to death, personal injury, or environmental damage, such as in: (i) emergency services; (ii) military, law enforcement, or surveillance; (iii) management or operation of critical infrastructure; or (iv) safety applications.

(b) Development or Use of Weapons. Prohibited uses include, but are not limited to, research, development, production, testing, maintenance or use of: (i) weapons; (ii) technologies that cause physical injury or death; or (iii) agents intended to cause harm to humans, animals, or property.

(c) Malicious, Abusive, or Unlawful Activities. Prohibited uses include, but are not limited to, use of Gemma or Derivatives to: (i) develop malware, including spyware; (ii) engage in phishing, fraud, or other deceptive activities; (iii) engage in cyber abuse, including denial of service or hacking; (iv) exploit vulnerabilities; (v) generate or disseminate spam; or (vi) facilitate or encourage violations of law.

(d) Automated Decision-Making. Prohibited uses include, but are not limited to, use of Gemma or Derivatives to: (i) make decisions on behalf of decision makers within government; (ii) make decisions that determine an individual’s eligibility or selection within education, employment, housing, insurance, or credit; or (iii) make decisions, recommendations, or rankings that determine a person’s eligibility or access within healthcare or legal services.

(e) Harmful or Scams. Prohibited uses include, but are not limited to, use of Gemma or Derivatives to: (i) facilitate, encourage, or further scams; (ii) generate or disseminate content that incites violence or hatred; (iii) assist in exploitation or abuse of minors; or (iv) promote or facilitate discrimination.

4. TERM AND TERMINATION

4.1 These Terms are effective upon your first use, access, or download of Gemma and will continue until terminated.

4.2 Google may suspend or terminate your use of Gemma at any time if it believes you have breached these Terms or to comply with applicable law. If Google suspends or terminates your use, you must stop using Gemma.

4.3 Upon termination, you must cease all use of Gemma and delete Gemma and all Derivatives.

5. DISCLAIMERS

5.1 GEMMA IS PROVIDED "AS IS" WITHOUT WARRANTIES OF ANY KIND. GOOGLE DISCLAIMS ALL WARRANTIES AND CONDITIONS, EXPRESS OR IMPLIED, INCLUDING ANY WARRANTIES OR CONDITIONS OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE, OR NON-INFRINGEMENT.

5.2 YOU ARE RESPONSIBLE FOR YOUR USE OF GEMMA AND DERIVATIVES.

5.3 IN NO EVENT WILL GOOGLE OR ITS AFFILIATES BE LIABLE FOR ANY INDIRECT, INCIDENTAL, SPECIAL, CONSEQUENTIAL, OR EXEMPLARY DAMAGES OR LOST PROFITS ARISING FROM OR IN CONNECTION WITH THESE TERMS OR USE OF GEMMA.

If you have any questions regarding these Terms, you may contact Google at gemma@google.com.

Please read these Terms carefully and keep a copy for your reference.
NOTICE
ADDED
@@ -0,0 +1 @@
Gemma is provided under and subject to the Gemma Terms of Use found at ai.google.dev/gemma/terms
README.md
ADDED
@@ -0,0 +1,52 @@
---
license: gemma
base_model: google/gemma-3-27b-pt
tags:
- jax
- orbax
- tensorstore
- gemma
- conversion
---

# Gemma 3 27B PT – Orbax/TensorStore (OCDBT)

This repository contains the **Orbax/TensorStore (OCDBT)** checkpoint converted from
the `google/gemma-3-27b-pt` Hugging Face safetensors. The conversion stacks transformer
layer weights along the depth axis for efficient JAX/Orbax loading.

## What’s included

- Orbax/TensorStore checkpoint files (OCDBT)
- `convert.py` (the conversion script used)
- `LICENSE` and `NOTICE` per the Gemma Terms of Use

## Conversion details

- Source: `google/gemma-3-27b-pt`
- Format change only: **no weight changes** beyond layout/format
- Layer weights are stacked under `layers_stacked/...`

## Loading example (JAX/Orbax)

```python
from huggingface_hub import snapshot_download
import orbax.checkpoint as ocp

path = snapshot_download("dffarr/gemma-3-27b-pt-orbax")
ckpt = ocp.StandardCheckpointer().restore(path)

# Example access
embed_tokens = ckpt["language_model/model/embed_tokens"]
q_proj = ckpt["layers_stacked/language_model/model/self_attn/q_proj/weight"]
```

## License & Use

This repository is provided under the **Gemma Terms of Use**. Please read `LICENSE`
and comply with the **Gemma Prohibited Use Policy**:
https://ai.google.dev/gemma/terms

## Conversion script

The conversion script used to generate this checkpoint is included as `convert.py`.
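Because each suffix's layer weights live in a single depth-stacked array, an individual layer is recovered by indexing the leading axis. A minimal sketch with a toy stand-in array (the shapes below are illustrative assumptions, not the real 27B dimensions; only the key name follows this repository's layout):

```python
import numpy as np

# Toy stand-in for a restored checkpoint entry: 4 layers of an (out, in)
# projection, stacked along a leading depth axis as convert.py does.
num_layers, d_out, d_in = 4, 8, 16
ckpt = {
    "layers_stacked/language_model/model/self_attn/q_proj/weight": np.zeros(
        (num_layers, d_out, d_in), dtype=np.float32
    )
}

stacked = ckpt["layers_stacked/language_model/model/self_attn/q_proj/weight"]
per_layer = [stacked[i] for i in range(stacked.shape[0])]  # unstack by depth

assert stacked.shape == (4, 8, 16)
assert per_layer[0].shape == (8, 16)
```

The same indexing works under `jax.vmap`/`jax.lax.scan`, which is the usual reason for storing depth-stacked weights in the first place.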
|
_CHECKPOINT_METADATA
ADDED
@@ -0,0 +1 @@
{"item_handlers": "orbax.checkpoint._src.handlers.standard_checkpoint_handler.StandardCheckpointHandler", "metrics": {}, "performance_metrics": {}, "init_timestamp_nsecs": 1770325910424047211, "commit_timestamp_nsecs": 1770326068343296057, "custom_metadata": {}}
_METADATA
ADDED
@@ -0,0 +1 @@
{"tree_metadata": {"('language_model/model/embed_tokens',)": {"key_metadata": [{"key": "language_model/model/embed_tokens", "key_type": 2}], "value_metadata": {"value_type": "np.ndarray", "skip_deserialize": false}}, "('language_model/model/norm',)": {"key_metadata": [{"key": "language_model/model/norm", "key_type": 2}], "value_metadata": {"value_type": "np.ndarray", "skip_deserialize": false}}, "('layers_stacked/input_layernorm/weight',)": {"key_metadata": [{"key": "layers_stacked/input_layernorm/weight", "key_type": 2}], "value_metadata": {"value_type": "np.ndarray", "skip_deserialize": false}}, "('layers_stacked/language_model/model/self_attn/k_proj/weight',)": {"key_metadata": [{"key": "layers_stacked/language_model/model/self_attn/k_proj/weight", "key_type": 2}], "value_metadata": {"value_type": "np.ndarray", "skip_deserialize": false}}, "('layers_stacked/language_model/model/self_attn/q_proj/weight',)": {"key_metadata": [{"key": "layers_stacked/language_model/model/self_attn/q_proj/weight", "key_type": 2}], "value_metadata": {"value_type": "np.ndarray", "skip_deserialize": false}}, "('layers_stacked/language_model/model/self_attn/v_proj/weight',)": {"key_metadata": [{"key": "layers_stacked/language_model/model/self_attn/v_proj/weight", "key_type": 2}], "value_metadata": {"value_type": "np.ndarray", "skip_deserialize": false}}, "('layers_stacked/layer_norm1/bias',)": {"key_metadata": [{"key": "layers_stacked/layer_norm1/bias", "key_type": 2}], "value_metadata": {"value_type": "np.ndarray", "skip_deserialize": false}}, "('layers_stacked/layer_norm1/weight',)": {"key_metadata": [{"key": "layers_stacked/layer_norm1/weight", "key_type": 2}], "value_metadata": {"value_type": "np.ndarray", "skip_deserialize": false}}, "('layers_stacked/layer_norm2/bias',)": {"key_metadata": [{"key": "layers_stacked/layer_norm2/bias", "key_type": 2}], "value_metadata": {"value_type": "np.ndarray", "skip_deserialize": false}}, "('layers_stacked/layer_norm2/weight',)": {"key_metadata": 
[{"key": "layers_stacked/layer_norm2/weight", "key_type": 2}], "value_metadata": {"value_type": "np.ndarray", "skip_deserialize": false}}, "('layers_stacked/mlp/down_proj/weight',)": {"key_metadata": [{"key": "layers_stacked/mlp/down_proj/weight", "key_type": 2}], "value_metadata": {"value_type": "np.ndarray", "skip_deserialize": false}}, "('layers_stacked/mlp/fc1/bias',)": {"key_metadata": [{"key": "layers_stacked/mlp/fc1/bias", "key_type": 2}], "value_metadata": {"value_type": "np.ndarray", "skip_deserialize": false}}, "('layers_stacked/mlp/fc1/weight',)": {"key_metadata": [{"key": "layers_stacked/mlp/fc1/weight", "key_type": 2}], "value_metadata": {"value_type": "np.ndarray", "skip_deserialize": false}}, "('layers_stacked/mlp/fc2/bias',)": {"key_metadata": [{"key": "layers_stacked/mlp/fc2/bias", "key_type": 2}], "value_metadata": {"value_type": "np.ndarray", "skip_deserialize": false}}, "('layers_stacked/mlp/fc2/weight',)": {"key_metadata": [{"key": "layers_stacked/mlp/fc2/weight", "key_type": 2}], "value_metadata": {"value_type": "np.ndarray", "skip_deserialize": false}}, "('layers_stacked/mlp/gate_proj/weight',)": {"key_metadata": [{"key": "layers_stacked/mlp/gate_proj/weight", "key_type": 2}], "value_metadata": {"value_type": "np.ndarray", "skip_deserialize": false}}, "('layers_stacked/mlp/up_proj/weight',)": {"key_metadata": [{"key": "layers_stacked/mlp/up_proj/weight", "key_type": 2}], "value_metadata": {"value_type": "np.ndarray", "skip_deserialize": false}}, "('layers_stacked/post_attention_layernorm/weight',)": {"key_metadata": [{"key": "layers_stacked/post_attention_layernorm/weight", "key_type": 2}], "value_metadata": {"value_type": "np.ndarray", "skip_deserialize": false}}, "('layers_stacked/post_feedforward_layernorm/weight',)": {"key_metadata": [{"key": "layers_stacked/post_feedforward_layernorm/weight", "key_type": 2}], "value_metadata": {"value_type": "np.ndarray", "skip_deserialize": false}}, 
"('layers_stacked/pre_feedforward_layernorm/weight',)": {"key_metadata": [{"key": "layers_stacked/pre_feedforward_layernorm/weight", "key_type": 2}], "value_metadata": {"value_type": "np.ndarray", "skip_deserialize": false}}, "('layers_stacked/self_attn/k_norm/weight',)": {"key_metadata": [{"key": "layers_stacked/self_attn/k_norm/weight", "key_type": 2}], "value_metadata": {"value_type": "np.ndarray", "skip_deserialize": false}}, "('layers_stacked/self_attn/k_proj/bias',)": {"key_metadata": [{"key": "layers_stacked/self_attn/k_proj/bias", "key_type": 2}], "value_metadata": {"value_type": "np.ndarray", "skip_deserialize": false}}, "('layers_stacked/self_attn/o_proj/weight',)": {"key_metadata": [{"key": "layers_stacked/self_attn/o_proj/weight", "key_type": 2}], "value_metadata": {"value_type": "np.ndarray", "skip_deserialize": false}}, "('layers_stacked/self_attn/out_proj/bias',)": {"key_metadata": [{"key": "layers_stacked/self_attn/out_proj/bias", "key_type": 2}], "value_metadata": {"value_type": "np.ndarray", "skip_deserialize": false}}, "('layers_stacked/self_attn/out_proj/weight',)": {"key_metadata": [{"key": "layers_stacked/self_attn/out_proj/weight", "key_type": 2}], "value_metadata": {"value_type": "np.ndarray", "skip_deserialize": false}}, "('layers_stacked/self_attn/q_norm/weight',)": {"key_metadata": [{"key": "layers_stacked/self_attn/q_norm/weight", "key_type": 2}], "value_metadata": {"value_type": "np.ndarray", "skip_deserialize": false}}, "('layers_stacked/self_attn/q_proj/bias',)": {"key_metadata": [{"key": "layers_stacked/self_attn/q_proj/bias", "key_type": 2}], "value_metadata": {"value_type": "np.ndarray", "skip_deserialize": false}}, "('layers_stacked/self_attn/v_proj/bias',)": {"key_metadata": [{"key": "layers_stacked/self_attn/v_proj/bias", "key_type": 2}], "value_metadata": {"value_type": "np.ndarray", "skip_deserialize": false}}, "('layers_stacked/vision_tower/vision_model/encoder/self_attn/k_proj/weight',)": {"key_metadata": [{"key": 
"layers_stacked/vision_tower/vision_model/encoder/self_attn/k_proj/weight", "key_type": 2}], "value_metadata": {"value_type": "np.ndarray", "skip_deserialize": false}}, "('layers_stacked/vision_tower/vision_model/encoder/self_attn/q_proj/weight',)": {"key_metadata": [{"key": "layers_stacked/vision_tower/vision_model/encoder/self_attn/q_proj/weight", "key_type": 2}], "value_metadata": {"value_type": "np.ndarray", "skip_deserialize": false}}, "('layers_stacked/vision_tower/vision_model/encoder/self_attn/v_proj/weight',)": {"key_metadata": [{"key": "layers_stacked/vision_tower/vision_model/encoder/self_attn/v_proj/weight", "key_type": 2}], "value_metadata": {"value_type": "np.ndarray", "skip_deserialize": false}}, "('multi_modal_projector/mm_input_projection_weight',)": {"key_metadata": [{"key": "multi_modal_projector/mm_input_projection_weight", "key_type": 2}], "value_metadata": {"value_type": "np.ndarray", "skip_deserialize": false}}, "('multi_modal_projector/mm_soft_emb_norm',)": {"key_metadata": [{"key": "multi_modal_projector/mm_soft_emb_norm", "key_type": 2}], "value_metadata": {"value_type": "np.ndarray", "skip_deserialize": false}}, "('vision_tower/vision_model/embeddings/patch_embedding',)": {"key_metadata": [{"key": "vision_tower/vision_model/embeddings/patch_embedding", "key_type": 2}], "value_metadata": {"value_type": "np.ndarray", "skip_deserialize": false}}, "('vision_tower/vision_model/embeddings/patch_embedding/bias',)": {"key_metadata": [{"key": "vision_tower/vision_model/embeddings/patch_embedding/bias", "key_type": 2}], "value_metadata": {"value_type": "np.ndarray", "skip_deserialize": false}}, "('vision_tower/vision_model/embeddings/position_embedding',)": {"key_metadata": [{"key": "vision_tower/vision_model/embeddings/position_embedding", "key_type": 2}], "value_metadata": {"value_type": "np.ndarray", "skip_deserialize": false}}, "('vision_tower/vision_model/post_layernorm',)": {"key_metadata": [{"key": "vision_tower/vision_model/post_layernorm", 
"key_type": 2}], "value_metadata": {"value_type": "np.ndarray", "skip_deserialize": false}}, "('vision_tower/vision_model/post_layernorm/bias',)": {"key_metadata": [{"key": "vision_tower/vision_model/post_layernorm/bias", "key_type": 2}], "value_metadata": {"value_type": "np.ndarray", "skip_deserialize": false}}}, "use_ocdbt": true, "use_zarr3": false, "store_array_data_equal_to_fill_value": true, "custom_metadata": null}
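The `_METADATA` tree above is plain JSON, so the parameter names it records can be enumerated without Orbax. A sketch using a small, abridged excerpt of the structure (the excerpt below is not the full tree):

```python
import json

# Abridged excerpt of the _METADATA structure shown above.
metadata = json.loads("""
{"tree_metadata": {
  "('language_model/model/embed_tokens',)": {
    "key_metadata": [{"key": "language_model/model/embed_tokens", "key_type": 2}],
    "value_metadata": {"value_type": "np.ndarray", "skip_deserialize": false}},
  "('layers_stacked/input_layernorm/weight',)": {
    "key_metadata": [{"key": "layers_stacked/input_layernorm/weight", "key_type": 2}],
    "value_metadata": {"value_type": "np.ndarray", "skip_deserialize": false}}},
 "use_ocdbt": true, "use_zarr3": false}
""")

# Each entry's key_metadata carries the checkpoint-relative parameter name.
names = [
    entry["key_metadata"][0]["key"]
    for entry in metadata["tree_metadata"].values()
]
print(names)
```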
commit_success.txt
ADDED
@@ -0,0 +1 @@
Checkpoint commit was successful to gs://gemma-3-27b-pt-orbax-b76114af/gemma-3-27b-pt-orbax
convert.py
ADDED
@@ -0,0 +1,469 @@
#!/usr/bin/env python3
"""Convert Hugging Face safetensors to Orbax (TensorStore) with smart stacking.

This program streams a Hugging Face safetensors snapshot, discovers transformer
layer parameters, stacks them by suffix across depth, and writes a topology-
agnostic Orbax checkpoint to GCS. It enforces CPU execution, processes one stack
at a time to limit peak RAM, and keeps non-layer parameters unstacked with clean
names for downstream JAX/Orbax training loops.
"""

from __future__ import annotations

import gc
import fnmatch
import json
import logging
import os
import shutil
import sys
import tempfile
from pathlib import Path
from typing import Dict, Iterable, List, Optional, Tuple

# Force CPU before importing JAX.
os.environ["JAX_PLATFORM_NAME"] = "cpu"

import numpy as np
import typer
import ml_dtypes  # register bfloat16 dtype for numpy  # noqa: F401
from huggingface_hub import hf_hub_download, list_repo_files, snapshot_download
from safetensors import safe_open
from tqdm import tqdm

import orbax.checkpoint as ocp


APP_NAME = "hf-safetensors-to-orbax"
ALLOWED_LAYER_PREFIXES = {"layers", "layer", "h", "blocks", "block"}
PREFIXES_TO_STRIP = ("model.", "transformer.", "module.")
WEIGHT_SUFFIX = ".weight"


def normalize_output_path(path: str, local: bool) -> str:
    """Validate and normalize the output location.

    By default the converter writes to GCS and enforces a `gs://` prefix. When
    the local flag is set, the path is treated as a local filesystem destination
    and `gs://` paths are rejected to avoid accidental uploads.
    """
    if local:
        if path.startswith("gs://"):
            raise typer.BadParameter("--local requires a filesystem path")
        return os.path.abspath(path)
    if not path.startswith("gs://") or len(path) <= 5:
        raise typer.BadParameter(
            "--output must be a valid gs:// path unless --local is set"
        )
    return path.rstrip("/")


def configure_logging() -> None:
    """Configure structured logging for CLI execution.

    The goal is to surface high-level progress milestones and critical context
    (like stack shapes and layer ranges) without drowning the user in per-tensor
    noise. The format is timestamped for long-running conversions.
    """
    logging.basicConfig(
        level=logging.INFO,
        format="%(asctime)s | %(levelname)s | %(message)s",
    )


def configure_hf_transfer(byte_progress: bool) -> None:
    """Enable or disable byte-level download progress via hf_transfer.

    Hugging Face's default snapshot progress bar is per-file, which makes large
    single-file models look stuck at 0/1 for a long time. Enabling hf_transfer
    switches to a byte-level stream so progress reflects actual download volume.
    """
    os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1" if byte_progress else "0"


def infer_snapshot_dir(file_path: str) -> str:
    """Infer the snapshot directory from a cached Hugging Face file path.

    Cached files live under `.../snapshots/<commit>/...`. This helper finds that
    segment and returns the snapshot directory so subsequent code can operate on
    a stable root without assuming a particular cache location.
    """
    parts = Path(file_path).parts
    for idx, part in enumerate(parts):
        if part == "snapshots" and idx + 1 < len(parts):
            return str(Path(*parts[: idx + 2]))
    raise ValueError(f"Unable to infer snapshot dir from path: {file_path}")


def download_snapshot(repo_id: str, token: Optional[str], byte_progress: bool) -> str:
    """Download the safetensors snapshot from Hugging Face.

    The fast path uses `snapshot_download`, which only exposes file-count
    progress. When byte-level progress is enabled, this function lists matching
    files and downloads each with `hf_hub_download` to get per-byte visibility.
    """
    allow_patterns = ["*.safetensors", "*.safetensors.index.json"]
    logging.info("Downloading snapshot for %s", repo_id)

    if not byte_progress:
        return snapshot_download(
            repo_id=repo_id,
            token=token,
            allow_patterns=allow_patterns,
        )

    repo_files = list_repo_files(repo_id=repo_id, token=token)
    filtered = [
        name
        for name in repo_files
        if any(fnmatch.fnmatch(name, pat) for pat in allow_patterns)
    ]
    if not filtered:
        raise ValueError("No matching safetensors files found in repository")

    snapshot_dir: Optional[str] = None
    for name in tqdm(sorted(filtered), desc="Downloading files"):
        path = hf_hub_download(repo_id=repo_id, filename=name, token=token)
        inferred = infer_snapshot_dir(path)
        if snapshot_dir is None:
            snapshot_dir = inferred
        elif snapshot_dir != inferred:
            raise ValueError("Downloaded files belong to different snapshot dirs")

    if snapshot_dir is None:
        raise ValueError("Unable to determine snapshot directory")
    return snapshot_dir


def find_index_json(snapshot_dir: str) -> Optional[str]:
    """Locate a safetensors index JSON, if present.

    Sharded HF checkpoints publish a single index JSON that maps tensor names
    to shard files. We treat multiple index files as an error because it is not
    a standard layout and likely indicates a mixed or malformed snapshot.
    """
    matches = list(Path(snapshot_dir).rglob("*.safetensors.index.json"))
    if not matches:
        return None
    if len(matches) > 1:
        raise ValueError("Multiple safetensors index JSON files found")
    return str(matches[0])


def build_key_to_file(snapshot_dir: str) -> Dict[str, str]:
    """Build a mapping from tensor key to safetensors file path.

    Prefer the official shard index when available to avoid redundant scans.
    If no index exists (common for single-file models), scan the safetensors
    files directly. Duplicate keys are treated as fatal to prevent corruption.
    """
    index_path = find_index_json(snapshot_dir)
    if index_path:
        with open(index_path, "r", encoding="utf-8") as handle:
            data = json.load(handle)
        weight_map = data.get("weight_map")
        if not weight_map:
            raise ValueError(f"Index JSON missing weight_map: {index_path}")
        mapping: Dict[str, str] = {}
        for key, rel_path in weight_map.items():
            abs_path = os.path.join(snapshot_dir, rel_path)
            if key in mapping:
                raise ValueError(f"Duplicate key in index: {key}")
            mapping[key] = abs_path
        return mapping

    safetensor_files = list(Path(snapshot_dir).rglob("*.safetensors"))
    if not safetensor_files:
        raise ValueError("No .safetensors files found in snapshot")

    mapping = {}
    for file_path in safetensor_files:
        with safe_open(str(file_path), framework="numpy") as handle:
            for key in handle.keys():
                if key in mapping:
                    raise ValueError(f"Duplicate key across files: {key}")
                mapping[key] = str(file_path)
    return mapping


def parse_layer_key(key: str) -> Optional[Tuple[str, int, str]]:
    """Parse a parameter key into (prefix, layer_index, suffix) if it looks stackable.

    The heuristic looks for an integer segment that is preceded by a known layer
    container name (layers, h, blocks, ...). The suffix is everything after the
    index, which becomes the grouping key for stacking. The prefix is everything
    before the layer container; it disambiguates multi-tower models that reuse
    the same suffixes across different component stacks.
    """
    parts = key.split(".")
    for idx in range(1, len(parts) - 1):
        if parts[idx].isdigit() and parts[idx - 1] in ALLOWED_LAYER_PREFIXES:
            prefix = ".".join(parts[: idx - 1]).strip(".")
            suffix = ".".join(parts[idx + 1 :]).strip(".")
+
if suffix:
|
| 204 |
+
return prefix, int(parts[idx]), suffix
|
| 205 |
+
return None
|
| 206 |
+
|
| 207 |
+
|
| 208 |
+
def group_layer_keys(
    keys: Iterable[str],
) -> Tuple[Dict[Tuple[str, str], Dict[int, str]], List[str], Dict[str, set[str]]]:
    """Group layer parameters by (prefix, suffix) and separate non-layer parameters.

    This pass identifies all stackable keys, buckets them by (prefix, suffix),
    and filters the remaining parameters into a non-layer list. Only groups with
    two or more distinct layer indices qualify for stacking, which reduces false
    positives. It also tracks which prefixes appear for a given suffix to drive
    naming.
    """
    # Materialize the input: the iterable is traversed twice (once to bucket
    # layer keys, once to collect the non-layer remainder), which would
    # silently drop keys if a generator were passed.
    keys = list(keys)
    suffix_groups: Dict[Tuple[str, str], Dict[int, str]] = {}
    suffix_prefixes: Dict[str, set[str]] = {}

    for key in keys:
        parsed = parse_layer_key(key)
        if not parsed:
            continue
        prefix, layer_idx, suffix = parsed
        group = suffix_groups.setdefault((prefix, suffix), {})
        if layer_idx in group:
            raise ValueError(
                f"Duplicate layer index {layer_idx} for prefix '{prefix}' and suffix '{suffix}'"
            )
        group[layer_idx] = key
        suffix_prefixes.setdefault(suffix, set()).add(prefix)

    stackable_suffixes = {
        key: group for key, group in suffix_groups.items() if len(group) >= 2
    }

    stackable_keys = {
        key for group in stackable_suffixes.values() for key in group.values()
    }
    non_layer_keys = [key for key in keys if key not in stackable_keys]

    return stackable_suffixes, non_layer_keys, suffix_prefixes


def load_tensor(file_path: str, key: str) -> np.ndarray:
    """Load a single tensor from a safetensors file.

    Safetensors reads are isolated to the specific key to keep IO and memory
    tight. Callers handle validation and stacking; this only returns the array.
    """
    with safe_open(file_path, framework="numpy") as handle:
        return handle.get_tensor(key)


def validate_contiguous(indices: List[int], suffix: str) -> None:
    """Assert that a suffix group has a contiguous layer index range.

    Stacking assumes a dense layer axis. Missing indices likely mean a keying
    bug or an unexpected model layout, so we fail fast with a helpful message.
    """
    expected = list(range(indices[0], indices[-1] + 1))
    if indices != expected:
        raise ValueError(
            f"Missing layer indices for suffix '{suffix}': got {indices}, expected {expected}"
        )


def make_stacked_name(prefix: str, suffix: str, suffix_prefixes: Dict[str, set[str]]) -> str:
    """Translate a (prefix, suffix) pair into the output stacked parameter name.

    The suffix path is preserved while converting dots to slashes, and it is
    rooted under `layers_stacked/`. If the same suffix appears under multiple
    prefixes, the prefix path is included to avoid collisions and make the
    output explicit for multi-tower models.
    """
    suffix_path = suffix.replace(".", "/")
    if len(suffix_prefixes.get(suffix, set())) > 1:
        prefix_path = prefix.replace(".", "/").strip("/")
        if prefix_path:
            return f"layers_stacked/{prefix_path}/{suffix_path}"
    return f"layers_stacked/{suffix_path}"


def clean_global_name(key: str) -> str:
    """Normalize a non-layer parameter key into a clean output name.

    This strips common framework prefixes and a trailing `.weight`, then
    converts dot segments into path separators for a stable, readable tree.
    """
    cleaned = key
    for prefix in PREFIXES_TO_STRIP:
        if cleaned.startswith(prefix):
            cleaned = cleaned[len(prefix) :]
            break
    if cleaned.endswith(WEIGHT_SUFFIX):
        cleaned = cleaned[: -len(WEIGHT_SUFFIX)]
    cleaned = cleaned.strip(".")
    if not cleaned:
        raise ValueError(f"Invalid empty name after cleaning: {key}")
    return cleaned.replace(".", "/")


def save_array_to_temp(temp_dir: str, name: str, array: np.ndarray) -> np.ndarray:
    """Persist an array to disk and reopen it as a memory-mapped view.

    Building the full output tree in RAM can be expensive for large models. By
    staging arrays to disk and reopening with mmap, we limit peak memory while
    still handing Orbax a stable array-like object for checkpoint writing. We
    store raw bytes instead of `.npy` so non-standard dtypes like bfloat16 are
    preserved when memory-mapped.
    """
    file_path = os.path.join(temp_dir, f"{name}.bin")
    os.makedirs(os.path.dirname(file_path), exist_ok=True)
    if not array.flags["C_CONTIGUOUS"]:
        array = np.ascontiguousarray(array)
    array.tofile(file_path)
    return np.memmap(
        file_path,
        mode="r",
        dtype=array.dtype,
        shape=array.shape,
    )


def convert(
    hf_repo: str,
    hf_token: Optional[str],
    output_path: str,
    stacking_config: str,
    byte_progress: bool,
    local: bool,
) -> None:
    """Run the full conversion pipeline from HF snapshot to Orbax checkpoint.

    The flow is: download snapshot, build key-to-file map, identify stackable
    groups, stack one suffix group at a time with shape validation, write staged
    arrays into an output tree, then save via Orbax. This function enforces all
    invariants and ensures temporary storage is cleaned up on failure.
    """
    if stacking_config != "auto":
        raise typer.BadParameter("--stacking-config only supports 'auto'")

    output_path = normalize_output_path(output_path, local)
    configure_hf_transfer(byte_progress)
    snapshot_dir = download_snapshot(hf_repo, hf_token, byte_progress)
    key_to_file = build_key_to_file(snapshot_dir)

    keys = sorted(key_to_file.keys())
    stackable, non_layer_keys, suffix_prefixes = group_layer_keys(keys)

    logging.info("Found %d stackable groups", len(stackable))
    logging.info("Found %d non-layer params", len(non_layer_keys))

    output_tree: Dict[str, np.ndarray] = {}
    used_names = set()

    temp_dir = tempfile.mkdtemp(prefix=f"{APP_NAME}-")
    logging.info("Using temp dir: %s", temp_dir)

    try:
        for prefix, suffix in tqdm(sorted(stackable.keys()), desc="Stacking groups"):
            group = stackable[(prefix, suffix)]
            indices = sorted(group.keys())
            validate_contiguous(indices, suffix)

            first_key = group[indices[0]]
            first_arr = load_tensor(key_to_file[first_key], first_key)
            stacked = np.empty((len(indices),) + first_arr.shape, dtype=first_arr.dtype)
            stacked[0] = first_arr
            base_shape = first_arr.shape
            logging.info(
                "Stacking %s | prefix %s | layers %d-%d | shape %s | dtype %s",
                suffix,
                prefix or "<root>",
                indices[0],
                indices[-1],
                base_shape,
                stacked.dtype,
            )
            del first_arr

            for pos, layer_idx in enumerate(indices[1:], start=1):
                key = group[layer_idx]
                arr = load_tensor(key_to_file[key], key)
                if arr.shape != base_shape:
                    raise ValueError(
                        f"Shape mismatch for suffix '{suffix}': {arr.shape} vs {base_shape}"
                    )
                stacked[pos] = arr

            stacked_name = make_stacked_name(prefix, suffix, suffix_prefixes)
            if stacked_name in used_names:
                raise ValueError(f"Output name collision: {stacked_name}")
            used_names.add(stacked_name)

            output_tree[stacked_name] = save_array_to_temp(
                temp_dir, stacked_name, stacked
            )
            del stacked
            gc.collect()

        for key in tqdm(non_layer_keys, desc="Saving non-layer params"):
            global_name = clean_global_name(key)
            if global_name in used_names:
                raise ValueError(f"Output name collision: {global_name}")
            used_names.add(global_name)

            arr = load_tensor(key_to_file[key], key)
            output_tree[global_name] = save_array_to_temp(temp_dir, global_name, arr)
            del arr

        logging.info("Saving Orbax checkpoint to %s", output_path)
        checkpointer = ocp.StandardCheckpointer()
        checkpointer.save(output_path, output_tree)
        checkpointer.wait_until_finished()
        checkpointer.close()
        logging.info("Save complete")
    finally:
        shutil.rmtree(temp_dir, ignore_errors=True)


app = typer.Typer(add_completion=False)


@app.command()
def main(
    hf_repo: str = typer.Option(..., help="Hugging Face repo ID"),
    hf_token: Optional[str] = typer.Option(None, help="Hugging Face auth token"),
    output: str = typer.Option(
        ...,
        "--output",
        help="Output path (gs://... unless --local)",
    ),
    stacking_config: str = typer.Option("auto", help="Stacking strategy (only 'auto')"),
    byte_progress: bool = typer.Option(
        True,
        "--byte-progress/--no-byte-progress",
        help="Use byte-level progress via hf_transfer",
    ),
    local: bool = typer.Option(
        False,
        "--local/--no-local",
        help="Write to a local filesystem path instead of GCS",
    ),
) -> None:
    """CLI entrypoint that wires arguments to the conversion routine.

    This keeps the CLI surface thin while centralizing the conversion logic in
    `convert`, which makes it easier to test or reuse programmatically.
    """
    configure_logging()
    convert(
        hf_repo=hf_repo,
        hf_token=hf_token,
        output_path=output,
        stacking_config=stacking_config,
        byte_progress=byte_progress,
        local=local,
    )


if __name__ == "__main__":
    try:
        app()
    except Exception as exc:  # pragma: no cover - CLI failure handling
        logging.error("Conversion failed: %s", exc)
        sys.exit(1)
d/6fa366570758d6d92bac0192e972f081
ADDED
Binary file (2.12 kB)

manifest.ocdbt
ADDED
Binary file (118 Bytes)

ocdbt.process_0/d/0c600442de678f199ab5e8d0152c8bbc
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:8c76d615abe2ee81f238c013b0c08bcba9a670b533ed4db511cf9db9f46ea401
size 2754031616

ocdbt.process_0/d/124c8d55117974805f4a148ebd670601
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:0517cabc606bdc13d666180e4c6b8c87f79561733cc17e9ffcc40eef4cc2f6a1
size 2780028928

ocdbt.process_0/d/156c1d2fa9763c506b9a075a5729b7de
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:0e30ef7664e63a6d0f57cd3504ff41b6a4ef6069b885beeb55c3e4a4bd89e666
size 2825613312

ocdbt.process_0/d/2f806db10cde1c0f439f288f34ee8b95
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:2b01357e5cc845093e19d68445e120595d5f601655f5eb9be19bb4bd59724332
size 2905985024

ocdbt.process_0/d/44d6c5046ef1352f82da17c3daffc66d
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:e4e105c721ad7c6acc26f4649a392a83573424fb9aa2cd0626c35b98afeaf571
size 2779865088

ocdbt.process_0/d/4f87c6488b21c3c03780afba253c8af0
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:cd0ce378f0c0abbc96571437f4abbedc881082cfb4ab99e1a217d4728f69a9ea
size 2906173440

ocdbt.process_0/d/78efb519de44bcebc6eb6fd345ff0af0
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:456831e79cd197fcf3a54bb82d96cc29d5d12c2cd1fa985e2839479ecd5c5d55
size 2754076672

ocdbt.process_0/d/82c6ee1c6166b74a75b93badf224480e
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:8f59acf6703a68ed463ffdeb6167cc96c5c1a2c929230624e74c8789b916559e
size 19894272

ocdbt.process_0/d/88b7b61e8b4c3f22ec014bd4fcd5cbe5
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:0338a2136ca337df63f08ebfb302dc5cede3d86ba4575d38bd7c9589ae849c07
size 477347

ocdbt.process_0/d/8d2385834c5546d282b60ec4c29d4378
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:2a0d8676748e0392c1cb5f555e19e3f31a7cfd3177e79fe934a3a7594ab84842
size 2779918336

ocdbt.process_0/d/97411cb3209e6a07d86332228942a45c
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:6d210bbe1657bfc07fc0fa3d3d113c9c8cf4a3cf772a51daffdc00b7fe069335
size 2303647744

ocdbt.process_0/d/a4f076e56e9ec995bf74b8f5d19e4101
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:1b277181db8e3ad9504f4ab24e6039b3d0a553237d7c1b9f887f676c53b550c8
size 2906382336

ocdbt.process_0/d/ae742711b7e4d0ae7a8eaa239393ec58
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:2a3620e6e429f02c3ed717690a1baed22654e9fae8a89cfc02323a035590006f
size 2779967488

ocdbt.process_0/d/ca081ceb253837d87f9bc28d6dfe2b0a
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:34504dc0900d79fff51a81e34dc12b56d1e1b4daaee31e70b23e6c3e26bd8d84
size 2906353664

ocdbt.process_0/d/ce563dd53d0cc8894641b5996dc5453f
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:811d04411b96a8c78452d8f8348f03a9aa8bab73aca80953fdfdda74ac990d82
size 3030417408

ocdbt.process_0/d/df6d4f33d5f964745802e5eedfe1425e
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:e545021423358467819c42e0dfa71bf7ab7f0f5bd117dac4884a3a092a6921fa
size 2303664128

ocdbt.process_0/d/e61206338bedb2c7eecd25dc76511b00
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:b89841d5d259a3d222eda23ee072c1daf9382418fdc3f52eb1d581c87e83ae75
size 926654464

ocdbt.process_0/d/f43da177788619ad3cd686b1345608ef
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:b8a56b172509bc843430abfcf32526a185ee59d312f1ded36ee2c317fb740399
size 2754002944

ocdbt.process_0/d/fd9008384544af4c40cc6878c2bbb5f0
ADDED
Binary file (265 Bytes)

ocdbt.process_0/manifest.ocdbt
ADDED
Binary file (327 Bytes)